├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── figures
│   └── teaser.png
├── finetune
│   ├── dataset_folder.py
│   ├── datasets.py
│   ├── engine_for_finetuning.py
│   ├── modeling_finetune.py
│   ├── optim_factory.py
│   ├── run_class_finetuning.py
│   └── utils.py
├── lib
│   ├── __init__.py
│   ├── augment.py
│   ├── builder.py
│   ├── dataload_optim.py
│   ├── logger.py
│   └── misc.py
├── main.py
├── main_lincls.py
└── vits.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Aa][Rr][Mm]/
27 | [Aa][Rr][Mm]64/
28 | bld/
29 | [Bb]in/
30 | [Oo]bj/
31 | [Ll]og/
32 | [Ll]ogs/
33 |
34 | # Visual Studio 2015/2017 cache/options directory
35 | .vs/
36 | # Uncomment if you have tasks that create the project's static files in wwwroot
37 | #wwwroot/
38 |
39 | # Visual Studio 2017 auto generated files
40 | Generated\ Files/
41 |
42 | # MSTest test Results
43 | [Tt]est[Rr]esult*/
44 | [Bb]uild[Ll]og.*
45 |
46 | # NUnit
47 | *.VisualState.xml
48 | TestResult.xml
49 | nunit-*.xml
50 |
51 | # Build Results of an ATL Project
52 | [Dd]ebugPS/
53 | [Rr]eleasePS/
54 | dlldata.c
55 |
56 | # Benchmark Results
57 | BenchmarkDotNet.Artifacts/
58 |
59 | # .NET Core
60 | project.lock.json
61 | project.fragment.lock.json
62 | artifacts/
63 |
64 | # StyleCop
65 | StyleCopReport.xml
66 |
67 | # Files built by Visual Studio
68 | *_i.c
69 | *_p.c
70 | *_h.h
71 | *.ilk
72 | *.meta
73 | *.obj
74 | *.iobj
75 | *.pch
76 | *.pdb
77 | *.ipdb
78 | *.pgc
79 | *.pgd
80 | *.rsp
81 | *.sbr
82 | *.tlb
83 | *.tli
84 | *.tlh
85 | *.tmp
86 | *.tmp_proj
87 | *_wpftmp.csproj
88 | *.log
89 | *.vspscc
90 | *.vssscc
91 | .builds
92 | *.pidb
93 | *.svclog
94 | *.scc
95 |
96 | # Chutzpah Test files
97 | _Chutzpah*
98 |
99 | # Visual C++ cache files
100 | ipch/
101 | *.aps
102 | *.ncb
103 | *.opendb
104 | *.opensdf
105 | *.sdf
106 | *.cachefile
107 | *.VC.db
108 | *.VC.VC.opendb
109 |
110 | # Visual Studio profiler
111 | *.psess
112 | *.vsp
113 | *.vspx
114 | *.sap
115 |
116 | # Visual Studio Trace Files
117 | *.e2e
118 |
119 | # TFS 2012 Local Workspace
120 | $tf/
121 |
122 | # Guidance Automation Toolkit
123 | *.gpState
124 |
125 | # ReSharper is a .NET coding add-in
126 | _ReSharper*/
127 | *.[Rr]e[Ss]harper
128 | *.DotSettings.user
129 |
130 | # TeamCity is a build add-in
131 | _TeamCity*
132 |
133 | # DotCover is a Code Coverage Tool
134 | *.dotCover
135 |
136 | # AxoCover is a Code Coverage Tool
137 | .axoCover/*
138 | !.axoCover/settings.json
139 |
140 | # Visual Studio code coverage results
141 | *.coverage
142 | *.coveragexml
143 |
144 | # NCrunch
145 | _NCrunch_*
146 | .*crunch*.local.xml
147 | nCrunchTemp_*
148 |
149 | # MightyMoose
150 | *.mm.*
151 | AutoTest.Net/
152 |
153 | # Web workbench (sass)
154 | .sass-cache/
155 |
156 | # Installshield output folder
157 | [Ee]xpress/
158 |
159 | # DocProject is a documentation generator add-in
160 | DocProject/buildhelp/
161 | DocProject/Help/*.HxT
162 | DocProject/Help/*.HxC
163 | DocProject/Help/*.hhc
164 | DocProject/Help/*.hhk
165 | DocProject/Help/*.hhp
166 | DocProject/Help/Html2
167 | DocProject/Help/html
168 |
169 | # Click-Once directory
170 | publish/
171 |
172 | # Publish Web Output
173 | *.[Pp]ublish.xml
174 | *.azurePubxml
175 | # Note: Comment the next line if you want to checkin your web deploy settings,
176 | # but database connection strings (with potential passwords) will be unencrypted
177 | *.pubxml
178 | *.publishproj
179 |
180 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
181 | # checkin your Azure Web App publish settings, but sensitive information contained
182 | # in these scripts will be unencrypted
183 | PublishScripts/
184 |
185 | # NuGet Packages
186 | *.nupkg
187 | # NuGet Symbol Packages
188 | *.snupkg
189 | # The packages folder can be ignored because of Package Restore
190 | **/[Pp]ackages/*
191 | # except build/, which is used as an MSBuild target.
192 | !**/[Pp]ackages/build/
193 | # Uncomment if necessary however generally it will be regenerated when needed
194 | #!**/[Pp]ackages/repositories.config
195 | # NuGet v3's project.json files produces more ignorable files
196 | *.nuget.props
197 | *.nuget.targets
198 |
199 | # Microsoft Azure Build Output
200 | csx/
201 | *.build.csdef
202 |
203 | # Microsoft Azure Emulator
204 | ecf/
205 | rcf/
206 |
207 | # Windows Store app package directories and files
208 | AppPackages/
209 | BundleArtifacts/
210 | Package.StoreAssociation.xml
211 | _pkginfo.txt
212 | *.appx
213 | *.appxbundle
214 | *.appxupload
215 |
216 | # Visual Studio cache files
217 | # files ending in .cache can be ignored
218 | *.[Cc]ache
219 | # but keep track of directories ending in .cache
220 | !?*.[Cc]ache/
221 |
222 | # Others
223 | ClientBin/
224 | ~$*
225 | *~
226 | *.dbmdl
227 | *.dbproj.schemaview
228 | *.jfm
229 | *.pfx
230 | *.publishsettings
231 | orleans.codegen.cs
232 |
233 | # Including strong name files can present a security risk
234 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
235 | #*.snk
236 |
237 | # Since there are multiple workflows, uncomment next line to ignore bower_components
238 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
239 | #bower_components/
240 |
241 | # RIA/Silverlight projects
242 | Generated_Code/
243 |
244 | # Backup & report files from converting an old project file
245 | # to a newer Visual Studio version. Backup files are not needed,
246 | # because we have git ;-)
247 | _UpgradeReport_Files/
248 | Backup*/
249 | UpgradeLog*.XML
250 | UpgradeLog*.htm
251 | ServiceFabricBackup/
252 | *.rptproj.bak
253 |
254 | # SQL Server files
255 | *.mdf
256 | *.ldf
257 | *.ndf
258 |
259 | # Business Intelligence projects
260 | *.rdl.data
261 | *.bim.layout
262 | *.bim_*.settings
263 | *.rptproj.rsuser
264 | *- [Bb]ackup.rdl
265 | *- [Bb]ackup ([0-9]).rdl
266 | *- [Bb]ackup ([0-9][0-9]).rdl
267 |
268 | # Microsoft Fakes
269 | FakesAssemblies/
270 |
271 | # GhostDoc plugin setting file
272 | *.GhostDoc.xml
273 |
274 | # Node.js Tools for Visual Studio
275 | .ntvs_analysis.dat
276 | node_modules/
277 |
278 | # Visual Studio 6 build log
279 | *.plg
280 |
281 | # Visual Studio 6 workspace options file
282 | *.opt
283 |
284 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
285 | *.vbw
286 |
287 | # Visual Studio LightSwitch build output
288 | **/*.HTMLClient/GeneratedArtifacts
289 | **/*.DesktopClient/GeneratedArtifacts
290 | **/*.DesktopClient/ModelManifest.xml
291 | **/*.Server/GeneratedArtifacts
292 | **/*.Server/ModelManifest.xml
293 | _Pvt_Extensions
294 |
295 | # Paket dependency manager
296 | .paket/paket.exe
297 | paket-files/
298 |
299 | # FAKE - F# Make
300 | .fake/
301 |
302 | # CodeRush personal settings
303 | .cr/personal
304 |
305 | # Python Tools for Visual Studio (PTVS)
306 | __pycache__/
307 | *.pyc
308 |
309 | # Cake - Uncomment if you are using it
310 | # tools/**
311 | # !tools/packages.config
312 |
313 | # Tabs Studio
314 | *.tss
315 |
316 | # Telerik's JustMock configuration file
317 | *.jmconfig
318 |
319 | # BizTalk build output
320 | *.btp.cs
321 | *.btm.cs
322 | *.odx.cs
323 | *.xsd.cs
324 |
325 | # OpenCover UI analysis results
326 | OpenCover/
327 |
328 | # Azure Stream Analytics local run output
329 | ASALocalRun/
330 |
331 | # MSBuild Binary and Structured Log
332 | *.binlog
333 |
334 | # NVidia Nsight GPU debugger configuration file
335 | *.nvuser
336 |
337 | # MFractors (Xamarin productivity tool) working folder
338 | .mfractor/
339 |
340 | # Local History for Visual Studio
341 | .localhistory/
342 |
343 | # BeatPulse healthcheck temp database
344 | healthchecksdb
345 |
346 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
347 | MigrationBackup/
348 |
349 | # Ionide (cross platform F# VS Code tools) working folder
350 | .ionide/
351 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Extreme Masking for Learning Instance and Distributed Visual Representations
2 |
3 | This repo contains the official PyTorch implementation of the ExtreMA paper [(arxiv)](https://arxiv.org/abs/2206.04667). ExtreMA treats spatial token masking as a data augmentation for siamese representation learning. It follows the plain BYOL model, with supervision created by the masking operation. ExtreMA not only learns a strong instance representation that captures the holistic image, but also meaningful distributed representations for each individual token. Multi-masking, which processes multiple masks in parallel, is introduced to greatly accelerate training.
4 |
5 | ![ExtreMA teaser](figures/teaser.png)
9 |
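The sketch below illustrates the multi-masking idea (an illustrative sketch only; the actual implementation lives under `lib/`, and the function name here is hypothetical). Each mask keeps a different random subset of patch tokens, and all masked views are processed in parallel by the student encoder:

```
import torch

def sample_masked_views(tokens, mask_ratio=0.8, num_masks=5):
    """Keep a random (1 - mask_ratio) subset of patch tokens for each mask.

    tokens: (B, N, C) patch embeddings -> (num_masks, B, N_keep, C).
    """
    B, N, C = tokens.shape
    n_keep = max(1, int(N * (1 - mask_ratio)))
    views = []
    for _ in range(num_masks):
        # an independent random permutation of token indices per image
        keep = torch.rand(B, N, device=tokens.device).argsort(dim=1)[:, :n_keep]
        views.append(torch.gather(tokens, 1, keep.unsqueeze(-1).expand(-1, -1, C)))
    return torch.stack(views)
```

With `--mask-ratio 0.8 --num-masks 5`, each image contributes five 80%-masked student views, all matched against the teacher's (EMA) representation of the intact image, which is where the speedup over single-view siamese training comes from.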
10 | ## Released Pretrained Model
11 |
12 | We release the following 4 representative models at the moment. The wall time is measured on a single node of 8xV100 GPUs with PyTorch 1.13. ExtreMA is significantly more efficient and faster than competing masked modeling and siamese representation learning approaches.
13 |
14 | | name | pretrain dataset | epochs | masking | color-aug | wall time | linear | finetune | link |
15 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
16 | | ViT-Base | ImageNet-1K | 300 | 80%x5 | No | 50 hrs | 67.1 | 82.9| [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_1k_300ep.pth) |
17 | | ViT-Base | ImageNet-1K | 300 | 80%x5 | Yes | 50 hrs | 73.3 | 83.7| [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_color_1k_300ep.pth) |
18 | | ViT-Base | ImageNet-1K | 300 | 90%x8 | Yes | 46 hrs | 68.4 | 83.5| [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask09x8_color_1k_300ep.pth) |
19 | | ViT-Base | ImageNet-22K | 30 | 80%x5 | Yes | 56 hrs | 74.5 | 83.9| [full ckpt](https://frontiers.blob.core.windows.net/pretraining/checkpoints/extrema/extrema_mask08x5_color_22k_30ep.pth) |
20 |
21 | ## Pre-training
22 |
23 | To train ExtreMA, run the following command:
24 |
25 | ```
26 | python -m torch.distributed.launch --nproc_per_node=8 main.py \
27 | -a vit_base -b 2048 \
28 | --lr=1.5e-4 --weight-decay=.1 --weight-decay-end=.1 \
29 | --opt=adamw \
30 | --aug-spatialconsistent-color \
31 | --loss byol \
32 | --epochs=300 --warmup-epochs=40 --save-freq 5 \
33 | --opt-betas 0.9 0.95 \
34 | --drop_path_rate 0.1 --attn_drop_rate 0. \
35 | --layer_scale_init_value 0.1 --class_attention_layers 2 \
36 | --mask-ratio 0.8 --num-masks 5 \
37 | --ema-momentum 0.996 \
38 | --proj-dim 256 \
39 | --dist-url 'tcp://localhost:10001' \
40 | --multiprocessing-distributed \
41 | --seed 0 \
42 | --log_dir $LOG_DIR \
43 | --output_dir $SAVE_DIR \
44 | $DATA_DIR \
45 | ```
46 |
47 | ## Linear and Finetuning Evaluations on ImageNet-1k
48 |
49 | To linear probe a pretrained model,
50 | ```
51 | python -m torch.distributed.launch --nproc_per_node=8 main_lincls.py \
52 | -a vit_base --lr 0.1 \
53 | -b 4096 --optimizer sgd --warmup-epochs 10 \
54 | --log_dir ./ --eval_momentum \
55 | --dist-url 'tcp://localhost:10001' \
56 | --multiprocessing-distributed \
57 | --pretrained $MODEL \
58 | $DATA
59 | ```
60 |
61 | To finetune the model end-to-end,
62 | ```
63 | python -m torch.distributed.launch --nproc_per_node=8 finetune/run_class_finetuning.py \
64 | --model vit_base_patch16_224 \
65 | --data_path $DATA_DIR \
66 | --use_mean_pooling \
67 | --color_jitter 0.4 --reprob 0.25 \
68 | --finetune $MODEL --output_dir $LOG_DIR \
69 | --layer_decay 0.65 \
70 | --lr 5e-4 \
71 | --batch_size 128 --update_freq 1 --opt adamw --opt_betas 0.9 0.999 \
72 | --weight_decay 0.05 --warmup_epochs 5 --drop_path 0.2 --epochs 100 \
73 | --dist_eval \
74 | ```
75 | The finetuning code is based on BEiT, with the important modification of removing the [cls] token at the ViT input.
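Concretely, since ExtreMA pretrains without a [cls] token, the finetuning head classifies from mean-pooled patch tokens (cf. `--use_mean_pooling` and `forward_features` in `finetune/modeling_finetune.py`). A minimal sketch, with illustrative tensor sizes:

```
import torch
import torch.nn as nn

B, N, D, num_classes = 2, 196, 768, 1000     # 14x14 patches for ViT-B/16 at 224px
patch_tokens = torch.randn(B, N, D)          # encoder output; no [cls] token prepended
fc_norm, head = nn.LayerNorm(D), nn.Linear(D, num_classes)

logits = head(fc_norm(patch_tokens.mean(dim=1)))  # mean-pool tokens, then classify
print(logits.shape)                               # torch.Size([2, 1000])
```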
76 |
77 | ## Other Downstream Evaluations
78 |
79 | For semantic segmentation and instance detection, we follow [the CAE codebase](https://github.com/lxtGH/CAE). Care must be taken to remove the [cls] token at the input for ExtreMA.
80 |
81 | ## Acknowledgement
82 |
83 | The ExtreMA code significantly borrows from MoCo-v3, MAE, BEiT, and the timm library.
84 |
85 | ## Citation
86 |
87 | ```
88 | @article{wu2022extreme,
89 | title={Extreme Masking for Learning Instance and Distributed Visual Representations},
90 | author={Wu, Zhirong and Lai, Zihang and Sun, Xiao and Lin, Stephen},
91 | journal={arXiv preprint arXiv:2206.04667},
92 | year={2022}
93 | }
94 | ```
95 |
96 | ## Contributing
97 |
98 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
99 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
100 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
101 |
102 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
103 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
104 | provided by the bot. You will only need to do this once across all repos using our CLA.
105 |
106 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
107 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
108 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
109 |
110 | ## Trademarks
111 |
112 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
113 | trademarks or logos is subject to and must follow
114 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
115 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
116 | Any use of third-party trademarks or logos is subject to those third parties' policies.
117 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/ExtreMA/bf30b64ec0046524d0c373b761df323b9328f8dc/figures/teaser.png
--------------------------------------------------------------------------------
/finetune/dataset_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 | from PIL import Image
11 |
12 | import os
13 | import os.path
14 | import random
15 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple
16 |
17 |
18 | def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
19 | """Checks if a file is an allowed extension.
20 |
21 | Args:
22 | filename (string): path to a file
23 | extensions (tuple of strings): extensions to consider (lowercase)
24 |
25 | Returns:
26 | bool: True if the filename ends with one of given extensions
27 | """
28 | return filename.lower().endswith(extensions)
29 |
30 |
31 | def is_image_file(filename: str) -> bool:
32 | """Checks if a file is an allowed image extension.
33 |
34 | Args:
35 | filename (string): path to a file
36 |
37 | Returns:
38 | bool: True if the filename ends with a known image extension
39 | """
40 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
41 |
42 |
43 | def make_dataset(
44 | directory: str,
45 | class_to_idx: Dict[str, int],
46 | extensions: Optional[Tuple[str, ...]] = None,
47 | is_valid_file: Optional[Callable[[str], bool]] = None,
48 | ) -> List[Tuple[str, int]]:
49 | instances = []
50 | directory = os.path.expanduser(directory)
51 | both_none = extensions is None and is_valid_file is None
52 | both_something = extensions is not None and is_valid_file is not None
53 | if both_none or both_something:
54 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
55 | if extensions is not None:
56 | def is_valid_file(x: str) -> bool:
57 | return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
58 | is_valid_file = cast(Callable[[str], bool], is_valid_file)
59 | for target_class in sorted(class_to_idx.keys()):
60 | class_index = class_to_idx[target_class]
61 | target_dir = os.path.join(directory, target_class)
62 | if not os.path.isdir(target_dir):
63 | continue
64 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
65 | for fname in sorted(fnames):
66 | path = os.path.join(root, fname)
67 | if is_valid_file(path):
68 | item = path, class_index
69 | instances.append(item)
70 | return instances
71 |
72 |
73 | class DatasetFolder(VisionDataset):
74 | """A generic data loader where the samples are arranged in this way: ::
75 |
76 | root/class_x/xxx.ext
77 | root/class_x/xxy.ext
78 | root/class_x/xxz.ext
79 |
80 | root/class_y/123.ext
81 | root/class_y/nsdf3.ext
82 | root/class_y/asd932_.ext
83 |
84 | Args:
85 | root (string): Root directory path.
86 | loader (callable): A function to load a sample given its path.
87 |         extensions (tuple[string]): A list of allowed extensions.
88 |             Both extensions and is_valid_file should not be passed.
89 |         transform (callable, optional): A function/transform that takes in
90 |             a sample and returns a transformed version.
91 |             E.g., ``transforms.RandomCrop`` for images.
92 |         target_transform (callable, optional): A function/transform that takes
93 |             in the target and transforms it.
94 |         is_valid_file (callable, optional): A function that takes the path of a file
95 |             and checks if the file is a valid file (used to check for corrupt files).
96 |             Both extensions and is_valid_file should not be passed.
97 |
98 | Attributes:
99 | classes (list): List of the class names sorted alphabetically.
100 | class_to_idx (dict): Dict with items (class_name, class_index).
101 | samples (list): List of (sample path, class_index) tuples
102 | targets (list): The class_index value for each image in the dataset
103 | """
104 |
105 | def __init__(
106 | self,
107 | root: str,
108 | loader: Callable[[str], Any],
109 | extensions: Optional[Tuple[str, ...]] = None,
110 | transform: Optional[Callable] = None,
111 | target_transform: Optional[Callable] = None,
112 | is_valid_file: Optional[Callable[[str], bool]] = None,
113 | ) -> None:
114 | super(DatasetFolder, self).__init__(root, transform=transform,
115 | target_transform=target_transform)
116 | classes, class_to_idx = self._find_classes(self.root)
117 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
118 | if len(samples) == 0:
119 | msg = "Found 0 files in subfolders of: {}\n".format(self.root)
120 | if extensions is not None:
121 | msg += "Supported extensions are: {}".format(",".join(extensions))
122 | raise RuntimeError(msg)
123 |
124 | self.loader = loader
125 | self.extensions = extensions
126 |
127 | self.classes = classes
128 | self.class_to_idx = class_to_idx
129 | self.samples = samples
130 | self.targets = [s[1] for s in samples]
131 |
132 | def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
133 | """
134 | Finds the class folders in a dataset.
135 |
136 | Args:
137 | dir (string): Root directory path.
138 |
139 | Returns:
140 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
141 |
142 | Ensures:
143 | No class is a subdirectory of another.
144 | """
145 | classes = [d.name for d in os.scandir(dir) if d.is_dir()]
146 | classes.sort()
147 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
148 | return classes, class_to_idx
149 |
150 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
151 | """
152 | Args:
153 | index (int): Index
154 |
155 | Returns:
156 | tuple: (sample, target) where target is class_index of the target class.
157 | """
158 | while True:
159 | try:
160 | path, target = self.samples[index]
161 | sample = self.loader(path)
162 | break
163 | except Exception as e:
164 | print(e)
165 | index = random.randint(0, len(self.samples) - 1)
166 |
167 | if self.transform is not None:
168 | sample = self.transform(sample)
169 | if self.target_transform is not None:
170 | target = self.target_transform(target)
171 |
172 | return sample, target
173 |
174 | def __len__(self) -> int:
175 | return len(self.samples)
176 |
177 |
178 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
179 |
180 |
181 | def pil_loader(path: str) -> Image.Image:
182 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
183 | with open(path, 'rb') as f:
184 | img = Image.open(f)
185 | return img.convert('RGB')
186 |
187 |
188 | # TODO: specify the return type
189 | def accimage_loader(path: str) -> Any:
190 | import accimage
191 | try:
192 | return accimage.Image(path)
193 | except IOError:
194 | # Potentially a decoding problem, fall back to PIL.Image
195 | return pil_loader(path)
196 |
197 |
198 | def default_loader(path: str) -> Any:
199 | from torchvision import get_image_backend
200 | if get_image_backend() == 'accimage':
201 | return accimage_loader(path)
202 | else:
203 | return pil_loader(path)
204 |
205 |
206 | class ImageFolder(DatasetFolder):
207 | """A generic data loader where the images are arranged in this way: ::
208 |
209 | root/dog/xxx.png
210 | root/dog/xxy.png
211 | root/dog/xxz.png
212 |
213 | root/cat/123.png
214 | root/cat/nsdf3.png
215 | root/cat/asd932_.png
216 |
217 | Args:
218 | root (string): Root directory path.
219 |         transform (callable, optional): A function/transform that takes in a PIL image
220 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
221 |         target_transform (callable, optional): A function/transform that takes in the
222 |             target and transforms it.
223 |         loader (callable, optional): A function to load an image given its path.
224 |         is_valid_file (callable, optional): A function that takes the path of an image file
225 |             and checks if the file is a valid file (used to check for corrupt files)
226 |
227 | Attributes:
228 | classes (list): List of the class names sorted alphabetically.
229 | class_to_idx (dict): Dict with items (class_name, class_index).
230 | imgs (list): List of (image path, class_index) tuples
231 | """
232 |
233 | def __init__(
234 | self,
235 | root: str,
236 | transform: Optional[Callable] = None,
237 | target_transform: Optional[Callable] = None,
238 | loader: Callable[[str], Any] = default_loader,
239 | is_valid_file: Optional[Callable[[str], bool]] = None,
240 | ):
241 | super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
242 | transform=transform,
243 | target_transform=target_transform,
244 | is_valid_file=is_valid_file)
245 | self.imgs = self.samples
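
# Usage sketch (illustrative; the path below is hypothetical):
#
#   dataset = ImageFolder('/data/imagenet/train')   # expects root/class_x/*.jpg layout
#   img, target = dataset[0]                        # PIL image and integer class index
#   print(len(dataset), len(dataset.classes))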
246 |
--------------------------------------------------------------------------------
/finetune/datasets.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import os
9 | import torch
10 |
11 | from torchvision import datasets, transforms
12 |
13 | from timm.data.constants import \
14 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
15 |
16 | from timm.data import create_transform
17 | #
18 | #from finetune.dataset_folder import ImageFolder
19 |
20 | def build_dataset(is_train, args):
21 | transform = build_transform(is_train, args)
22 |
23 | print("Transform = ")
24 | if isinstance(transform, tuple):
25 | for trans in transform:
26 | print(" - - - - - - - - - - ")
27 | for t in trans.transforms:
28 | print(t)
29 | else:
30 | for t in transform.transforms:
31 | print(t)
32 | print("---------------------------")
33 |
34 | if args.data_set == 'CIFAR':
35 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
36 | nb_classes = 100
37 | elif args.data_set == 'IMNET':
38 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
39 | dataset = datasets.ImageFolder(root, transform=transform)
40 | nb_classes = 1000
41 | else:
42 | raise NotImplementedError()
43 | assert nb_classes == args.nb_classes
44 |     print("Number of classes = %d" % args.nb_classes)
45 |
46 | return dataset, nb_classes
47 |
48 |
49 | def build_transform(is_train, args):
50 | resize_im = args.input_size > 32
51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
54 |
55 | if is_train:
56 | # this should always dispatch to transforms_imagenet_train
57 | transform = create_transform(
58 | input_size=args.input_size,
59 | is_training=True,
60 | color_jitter=args.color_jitter,
61 | auto_augment=args.aa,
62 | interpolation=args.train_interpolation,
63 | re_prob=args.reprob,
64 | re_mode=args.remode,
65 | re_count=args.recount,
66 | mean=mean,
67 | std=std,
68 | )
69 | if not resize_im:
70 | # replace RandomResizedCropAndInterpolation with
71 | # RandomCrop
72 | transform.transforms[0] = transforms.RandomCrop(
73 | args.input_size, padding=4)
74 | return transform
75 |
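    # Evaluation transform: resize so that a center crop of `input_size` matches the
    # standard crop percentage (e.g. 224 / 256 = 0.875 for 224px inputs), then normalize.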
76 | t = []
77 | if resize_im:
78 | if args.crop_pct is None:
79 | if args.input_size < 384:
80 | args.crop_pct = 224 / 256
81 | else:
82 | args.crop_pct = 1.0
83 | size = int(args.input_size / args.crop_pct)
84 | t.append(
85 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
86 | )
87 | t.append(transforms.CenterCrop(args.input_size))
88 |
89 | t.append(transforms.ToTensor())
90 | t.append(transforms.Normalize(mean, std))
91 | return transforms.Compose(t)
92 |
--------------------------------------------------------------------------------
/finetune/engine_for_finetuning.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import math
9 | import sys
10 | from typing import Iterable, Optional
11 |
12 | import torch
13 |
14 | from timm.data import Mixup
15 | from timm.utils import accuracy, ModelEma
16 |
17 | import utils
18 |
19 |
20 | def train_class_batch(model, samples, target, criterion):
21 | outputs = model(samples)
22 | loss = criterion(outputs, target)
23 | return loss, outputs
24 |
25 |
26 | def get_loss_scale_for_deepspeed(model):
27 | optimizer = model.optimizer
28 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
29 |
30 |
31 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
32 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
33 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
34 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
35 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
36 | num_training_steps_per_epoch=None, update_freq=None):
37 | model.train(True)
38 | metric_logger = utils.MetricLogger(delimiter=" ")
39 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
40 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
41 | header = 'Epoch: [{}]'.format(epoch)
42 | print_freq = 10
43 |
44 | if loss_scaler is None:
45 | model.zero_grad()
46 | model.micro_steps = 0
47 | else:
48 | optimizer.zero_grad()
49 |
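    # `update_freq` implements gradient accumulation: gradients from `update_freq`
    # consecutive micro-batches are accumulated before each optimizer step, so the
    # effective batch size is batch_size * update_freq * (number of GPUs).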
50 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
51 | step = data_iter_step // update_freq
52 | if step >= num_training_steps_per_epoch:
53 | continue
54 | it = start_steps + step # global training iteration
55 |         # update LR & WD once at the start of each gradient-accumulation cycle
56 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
57 | for i, param_group in enumerate(optimizer.param_groups):
58 | if lr_schedule_values is not None:
59 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
60 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
61 | param_group["weight_decay"] = wd_schedule_values[it]
62 |
63 | samples = samples.to(device, non_blocking=True)
64 | targets = targets.to(device, non_blocking=True)
65 |
66 | if mixup_fn is not None:
67 | samples, targets = mixup_fn(samples, targets)
68 |
69 | if loss_scaler is None:
70 | samples = samples.half()
71 | loss, output = train_class_batch(
72 | model, samples, targets, criterion)
73 | else:
74 | with torch.cuda.amp.autocast():
75 | loss, output = train_class_batch(
76 | model, samples, targets, criterion)
77 |
78 | loss_value = loss.item()
79 |
80 | if not math.isfinite(loss_value):
81 | print("Loss is {}, stopping training".format(loss_value))
82 | sys.exit(1)
83 |
84 | if loss_scaler is None:
85 | loss /= update_freq
86 | model.backward(loss)
87 | model.step()
88 |
89 | if (data_iter_step + 1) % update_freq == 0:
90 | # model.zero_grad()
91 |             # DeepSpeed will call step() & model.zero_grad() automatically
92 | if model_ema is not None:
93 | model_ema.update(model)
94 | grad_norm = None
95 | loss_scale_value = get_loss_scale_for_deepspeed(model)
96 | else:
97 | # this attribute is added by timm on one optimizer (adahessian)
98 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
99 | loss /= update_freq
100 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
101 | parameters=model.parameters(), create_graph=is_second_order,
102 | update_grad=(data_iter_step + 1) % update_freq == 0)
103 | if (data_iter_step + 1) % update_freq == 0:
104 | optimizer.zero_grad()
105 | if model_ema is not None:
106 | model_ema.update(model)
107 | loss_scale_value = loss_scaler.state_dict()["scale"]
108 |
109 | torch.cuda.synchronize()
110 |
111 | if mixup_fn is None:
112 | class_acc = (output.max(-1)[-1] == targets).float().mean()
113 | else:
114 | class_acc = None
115 | metric_logger.update(loss=loss_value)
116 | metric_logger.update(class_acc=class_acc)
117 | metric_logger.update(loss_scale=loss_scale_value)
118 | min_lr = 10.
119 | max_lr = 0.
120 | for group in optimizer.param_groups:
121 | min_lr = min(min_lr, group["lr"])
122 | max_lr = max(max_lr, group["lr"])
123 |
124 | metric_logger.update(lr=max_lr)
125 | metric_logger.update(min_lr=min_lr)
126 | weight_decay_value = None
127 | for group in optimizer.param_groups:
128 | if group["weight_decay"] > 0:
129 | weight_decay_value = group["weight_decay"]
130 | metric_logger.update(weight_decay=weight_decay_value)
131 | metric_logger.update(grad_norm=grad_norm)
132 |
133 | if log_writer is not None:
134 | log_writer.update(loss=loss_value, head="loss")
135 | log_writer.update(class_acc=class_acc, head="loss")
136 | log_writer.update(loss_scale=loss_scale_value, head="opt")
137 | log_writer.update(lr=max_lr, head="opt")
138 | log_writer.update(min_lr=min_lr, head="opt")
139 | log_writer.update(weight_decay=weight_decay_value, head="opt")
140 | log_writer.update(grad_norm=grad_norm, head="opt")
141 |
142 | log_writer.set_step()
143 |
144 | # gather the stats from all processes
145 | metric_logger.synchronize_between_processes()
146 | print("Averaged stats:", metric_logger)
147 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
148 |
149 |
150 | @torch.no_grad()
151 | def evaluate(data_loader, model, device):
152 | criterion = torch.nn.CrossEntropyLoss()
153 |
154 | metric_logger = utils.MetricLogger(delimiter=" ")
155 | header = 'Test:'
156 |
157 | # switch to evaluation mode
158 | model.eval()
159 |
160 | for batch in metric_logger.log_every(data_loader, 10, header):
161 | images = batch[0]
162 | target = batch[-1]
163 | images = images.to(device, non_blocking=True)
164 | target = target.to(device, non_blocking=True)
165 |
166 | # compute output
167 | with torch.cuda.amp.autocast():
168 | output = model(images)
169 | loss = criterion(output, target)
170 |
171 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
172 |
173 | batch_size = images.shape[0]
174 | metric_logger.update(loss=loss.item())
175 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
176 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
177 | # gather the stats from all processes
178 | metric_logger.synchronize_between_processes()
179 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
180 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
181 |
182 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
183 |
--------------------------------------------------------------------------------
/finetune/modeling_finetune.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import math
9 | from functools import partial
10 | import numpy as np
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16 | from timm.models.registry import register_model
17 |
18 |
19 | def _cfg(url='', **kwargs):
20 | return {
21 | 'url': url,
22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
23 | 'crop_pct': .9, 'interpolation': 'bicubic',
24 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
25 | **kwargs
26 | }
27 |
28 |
29 | class DropPath(nn.Module):
30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31 | """
32 | def __init__(self, drop_prob=None):
33 | super(DropPath, self).__init__()
34 | self.drop_prob = drop_prob
35 |
36 | def forward(self, x):
37 | return drop_path(x, self.drop_prob, self.training)
38 |
39 | def extra_repr(self) -> str:
40 | return 'p={}'.format(self.drop_prob)
41 |
42 |
43 | class Mlp(nn.Module):
44 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
45 | super().__init__()
46 | out_features = out_features or in_features
47 | hidden_features = hidden_features or in_features
48 | self.fc1 = nn.Linear(in_features, hidden_features)
49 | self.act = act_layer()
50 | self.fc2 = nn.Linear(hidden_features, out_features)
51 | self.drop = nn.Dropout(drop)
52 |
53 | def forward(self, x):
54 | x = self.fc1(x)
55 | x = self.act(x)
56 | # x = self.drop(x)
57 |         # commented out, following the original BERT implementation
58 | x = self.fc2(x)
59 | x = self.drop(x)
60 | return x
61 |
62 |
63 | class Attention(nn.Module):
64 | def __init__(
65 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
66 | proj_drop=0., attn_head_dim=None):
67 | super().__init__()
68 | self.num_heads = num_heads
69 | head_dim = dim // num_heads
70 | if attn_head_dim is not None:
71 | head_dim = attn_head_dim
72 | all_head_dim = head_dim * self.num_heads
73 | self.scale = qk_scale or head_dim ** -0.5
74 |
75 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
76 |
77 | self.attn_drop = nn.Dropout(attn_drop)
78 | self.proj = nn.Linear(all_head_dim, dim)
79 | self.proj_drop = nn.Dropout(proj_drop)
80 |
81 | def forward(self, x):
82 | B, N, C = x.shape
83 |
84 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # -1 = head_dim, robust to attn_head_dim overrides
85 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
86 |
87 | attn = (q @ k.transpose(-2, -1)) * self.scale
88 | attn = attn.softmax(dim=-1)
89 | attn = self.attn_drop(attn)
90 |
91 |         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
92 | x = self.proj(x)
93 | x = self.proj_drop(x)
94 |
95 | return x
96 |
97 |
98 | class Block(nn.Module):
99 |
100 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
101 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
102 | attn_head_dim=None):
103 | super().__init__()
104 | self.norm1 = norm_layer(dim)
105 | self.attn = Attention(
106 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
107 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
108 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
109 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
110 | self.norm2 = norm_layer(dim)
111 | mlp_hidden_dim = int(dim * mlp_ratio)
112 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
113 |
114 |         if init_values is not None and init_values > 0:
115 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
116 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
117 | else:
118 | self.gamma_1, self.gamma_2 = None, None
119 |
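    # gamma_1/gamma_2 implement LayerScale (as in CaiT): small initial per-channel
    # scales damp each residual branch early in training, stabilizing deeper ViTs.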
120 | def forward(self, x):
121 | if self.gamma_1 is None:
122 | x = x + self.drop_path(self.attn(self.norm1(x)))
123 | x = x + self.drop_path(self.mlp(self.norm2(x)))
124 | else:
125 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
126 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
127 | return x
128 |
129 |
130 | class PatchEmbed(nn.Module):
131 | """ Image to Patch Embedding
132 | """
133 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
134 | super().__init__()
135 | img_size = to_2tuple(img_size)
136 | patch_size = to_2tuple(patch_size)
137 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
138 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
139 | self.img_size = img_size
140 | self.patch_size = patch_size
141 | self.num_patches = num_patches
142 |
143 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
144 |
145 | def forward(self, x, **kwargs):
146 | B, C, H, W = x.shape
147 | # FIXME look at relaxing size constraints
148 | assert H == self.img_size[0] and W == self.img_size[1], \
149 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
150 | x = self.proj(x).flatten(2).transpose(1, 2)
151 | return x
152 |
153 | # sin-cos position encoding
154 | # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
155 | def get_sinusoid_encoding_table(n_position, d_hid):
156 | ''' Sinusoid position encoding table '''
157 | # TODO: make it with torch instead of numpy
158 | def get_position_angle_vec(position):
159 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
160 |
161 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
162 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
163 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
164 |
165 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
166 |
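# e.g. get_sinusoid_encoding_table(196, 768) -> a fixed (non-learnable) table of
# shape (1, 196, 768): one embedding per 16x16 patch of a 224x224 image.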
167 |
168 | class VisionTransformer(nn.Module):
169 | """ Vision Transformer with support for patch or hybrid CNN input stage
170 | """
171 | def __init__(self,
172 | img_size=224,
173 | patch_size=16,
174 | in_chans=3,
175 | num_classes=1000,
176 | embed_dim=768,
177 | depth=12,
178 | num_heads=12,
179 | mlp_ratio=4.,
180 | qkv_bias=False,
181 | qk_scale=None,
182 | drop_rate=0.,
183 | attn_drop_rate=0.,
184 | drop_path_rate=0.,
185 | norm_layer=nn.LayerNorm,
186 | init_values=0.,
187 | use_learnable_pos_emb=False,
188 | init_scale=0.,
189 | use_mean_pooling=True):
190 | super().__init__()
191 | self.num_classes = num_classes
192 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
193 |
194 | self.patch_embed = PatchEmbed(
195 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
196 | num_patches = self.patch_embed.num_patches
197 |
198 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
199 | if use_learnable_pos_emb:
200 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
201 | #self.pos_embed.requires_grad = False
202 | else:
203 |             # fixed sine-cosine positional embeddings are used instead
204 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
205 |
206 | self.pos_drop = nn.Dropout(p=drop_rate)
207 |
208 |
209 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
210 | self.blocks = nn.ModuleList([
211 | Block(
212 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
213 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
214 | init_values=init_values)
215 | for i in range(depth)])
216 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
217 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
218 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
219 |
220 | if use_learnable_pos_emb:
221 | trunc_normal_(self.pos_embed, std=.02)
222 |
223 | trunc_normal_(self.cls_token, std=.02)
224 | trunc_normal_(self.head.weight, std=.02)
225 | self.apply(self._init_weights)
226 |
227 | self.head.weight.data.mul_(init_scale)
228 | self.head.bias.data.mul_(init_scale)
229 |
230 | def _init_weights(self, m):
231 | if isinstance(m, nn.Linear):
232 | trunc_normal_(m.weight, std=.02)
233 | if isinstance(m, nn.Linear) and m.bias is not None:
234 | nn.init.constant_(m.bias, 0)
235 | elif isinstance(m, nn.LayerNorm):
236 | nn.init.constant_(m.bias, 0)
237 | nn.init.constant_(m.weight, 1.0)
238 |
239 | def get_num_layers(self):
240 | return len(self.blocks)
241 |
242 | @torch.jit.ignore
243 | def no_weight_decay(self):
244 | return {'pos_embed', 'cls_token'}
245 |
246 | def get_classifier(self):
247 | return self.head
248 |
249 | def reset_classifier(self, num_classes, global_pool=''):
250 | self.num_classes = num_classes
251 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
252 |
253 | def forward_features(self, x):
254 | x = self.patch_embed(x)
255 | B, _, _ = x.size()
256 |
257 | #cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
258 | #x = torch.cat((cls_tokens, x), dim=1)
259 | if self.pos_embed is not None:
260 | #x = x + self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
261 | x = x + self.pos_embed
262 | x = self.pos_drop(x)
263 |
264 | for blk in self.blocks:
265 | x = blk(x)
266 |
267 | if self.fc_norm is not None:
268 | return self.fc_norm(x.mean(1))
269 | else:
270 | x = self.norm(x)
271 | return x[:, 0]
272 |
273 | def forward(self, x):
274 | x = self.forward_features(x)
275 | x = self.head(x)
276 | return x
277 |
278 | @register_model
279 | def vit_small_patch16_224(pretrained=False, **kwargs):
280 | model = VisionTransformer(
281 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
282 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
283 | model.default_cfg = _cfg()
284 | return model
285 |
286 | @register_model
287 | def vit_base_patch16_224(pretrained=False, **kwargs):
288 | model = VisionTransformer(
289 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
290 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
291 | model.default_cfg = _cfg()
292 | return model
293 |
294 |
295 | @register_model
296 | def vit_base_patch16_384(pretrained=False, **kwargs):
297 | model = VisionTransformer(
298 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
299 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
300 | model.default_cfg = _cfg()
301 | return model
302 |
303 |
304 | @register_model
305 | def vit_large_patch16_224(pretrained=False, **kwargs):
306 | model = VisionTransformer(
307 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
308 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
309 | model.default_cfg = _cfg()
310 | return model
311 |
312 |
313 | @register_model
314 | def vit_large_patch16_384(pretrained=False, **kwargs):
315 | model = VisionTransformer(
316 | img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
317 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
318 | model.default_cfg = _cfg()
319 | return model
320 |
321 |
322 | @register_model
323 | def vit_large_patch16_512(pretrained=False, **kwargs):
324 | model = VisionTransformer(
325 | img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
326 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
327 | model.default_cfg = _cfg()
328 | return model
329 |
--------------------------------------------------------------------------------
/finetune/optim_factory.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import torch
9 | from torch import optim as optim
10 |
11 | from timm.optim.adafactor import Adafactor
12 | from timm.optim.adahessian import Adahessian
13 | from timm.optim.adamp import AdamP
14 | from timm.optim.lookahead import Lookahead
15 | from timm.optim.nadam import Nadam
16 | from timm.optim.novograd import NovoGrad
17 | from timm.optim.nvnovograd import NvNovoGrad
18 | from timm.optim.radam import RAdam
19 | from timm.optim.rmsprop_tf import RMSpropTF
20 | from timm.optim.sgdp import SGDP
21 |
22 | import json
23 |
24 | try:
25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
26 | has_apex = True
27 | except ImportError:
28 | has_apex = False
29 |
30 |
31 | def get_num_layer_for_vit(var_name, num_max_layer):
32 | if var_name in ("cls_token", "mask_token", "pos_embed"):
33 | return 0
34 | elif var_name.startswith("patch_embed"):
35 | return 0
36 | elif var_name.startswith("rel_pos_bias"):
37 | return num_max_layer - 1
38 | elif var_name.startswith("blocks"):
39 | layer_id = int(var_name.split('.')[1])
40 | return layer_id + 1
41 | else:
42 | return num_max_layer - 1
43 |
44 |
45 | class LayerDecayValueAssigner(object):
46 | def __init__(self, values):
47 | self.values = values
48 |
49 | def get_scale(self, layer_id):
50 | return self.values[layer_id]
51 |
52 | def get_layer_id(self, var_name):
53 | return get_num_layer_for_vit(var_name, len(self.values))
54 |
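# Sketch of how the decay values are typically constructed for a 12-block ViT
# (cf. BEiT's run_class_finetuning.py; `layer_decay` here is illustrative):
#
#   values = [layer_decay ** (12 + 1 - i) for i in range(12 + 2)]
#
# so the patch embedding (layer id 0) gets the smallest lr scale and the
# head (layer id 13) gets lr scale 1.0.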
55 |
56 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
57 | parameter_group_names = {}
58 | parameter_group_vars = {}
59 |
60 | for name, param in model.named_parameters():
61 | if not param.requires_grad:
62 | continue # frozen weights
63 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
64 | group_name = "no_decay"
65 | this_weight_decay = 0.
66 | else:
67 | group_name = "decay"
68 | this_weight_decay = weight_decay
69 | if get_num_layer is not None:
70 | layer_id = get_num_layer(name)
71 | group_name = "layer_%d_%s" % (layer_id, group_name)
72 | else:
73 | layer_id = None
74 |
75 | if group_name not in parameter_group_names:
76 | if get_layer_scale is not None:
77 | scale = get_layer_scale(layer_id)
78 | else:
79 | scale = 1.
80 |
81 | parameter_group_names[group_name] = {
82 | "weight_decay": this_weight_decay,
83 | "params": [],
84 | "lr_scale": scale
85 | }
86 | parameter_group_vars[group_name] = {
87 | "weight_decay": this_weight_decay,
88 | "params": [],
89 | "lr_scale": scale
90 | }
91 |
92 | parameter_group_vars[group_name]["params"].append(param)
93 | parameter_group_names[group_name]["params"].append(name)
94 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
95 | return list(parameter_group_vars.values())
96 |
97 |
98 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
99 | opt_lower = args.opt.lower()
100 | weight_decay = args.weight_decay
101 | if weight_decay and filter_bias_and_bn:
102 | skip = {}
103 | if skip_list is not None:
104 | skip = skip_list
105 | elif hasattr(model, 'no_weight_decay'):
106 | skip = model.no_weight_decay()
107 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
108 | weight_decay = 0.
109 | else:
110 | parameters = model.parameters()
111 |
112 | if 'fused' in opt_lower:
113 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
114 |
115 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
116 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
117 | opt_args['eps'] = args.opt_eps
118 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
119 | opt_args['betas'] = args.opt_betas
120 |
121 | print("optimizer settings:", opt_args)
122 |
123 | opt_split = opt_lower.split('_')
124 | opt_lower = opt_split[-1]
125 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
126 | opt_args.pop('eps', None)
127 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
128 | elif opt_lower == 'momentum':
129 | opt_args.pop('eps', None)
130 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
131 | elif opt_lower == 'adam':
132 | optimizer = optim.Adam(parameters, **opt_args)
133 | elif opt_lower == 'adamw':
134 | optimizer = optim.AdamW(parameters, **opt_args)
135 | elif opt_lower == 'nadam':
136 | optimizer = Nadam(parameters, **opt_args)
137 | elif opt_lower == 'radam':
138 | optimizer = RAdam(parameters, **opt_args)
139 | elif opt_lower == 'adamp':
140 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
141 | elif opt_lower == 'sgdp':
142 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
143 | elif opt_lower == 'adadelta':
144 | optimizer = optim.Adadelta(parameters, **opt_args)
145 | elif opt_lower == 'adafactor':
146 | if not args.lr:
147 | opt_args['lr'] = None
148 | optimizer = Adafactor(parameters, **opt_args)
149 | elif opt_lower == 'adahessian':
150 | optimizer = Adahessian(parameters, **opt_args)
151 | elif opt_lower == 'rmsprop':
152 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
153 | elif opt_lower == 'rmsproptf':
154 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
155 | elif opt_lower == 'novograd':
156 | optimizer = NovoGrad(parameters, **opt_args)
157 | elif opt_lower == 'nvnovograd':
158 | optimizer = NvNovoGrad(parameters, **opt_args)
159 | elif opt_lower == 'fusedsgd':
160 | opt_args.pop('eps', None)
161 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
162 | elif opt_lower == 'fusedmomentum':
163 | opt_args.pop('eps', None)
164 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
165 | elif opt_lower == 'fusedadam':
166 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
167 | elif opt_lower == 'fusedadamw':
168 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
169 | elif opt_lower == 'fusedlamb':
170 | optimizer = FusedLAMB(parameters, **opt_args)
171 | elif opt_lower == 'fusednovograd':
172 | opt_args.setdefault('betas', (0.95, 0.98))
173 | optimizer = FusedNovoGrad(parameters, **opt_args)
174 | else:
175 |         # no matching optimizer name
176 |         raise ValueError("Invalid optimizer: %s" % opt_lower)
177 |
178 | if len(opt_split) > 1:
179 | if opt_split[0] == 'lookahead':
180 | optimizer = Lookahead(optimizer)
181 |
182 | return optimizer
183 |
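184 |
185 | if __name__ == "__main__":
186 |     # Usage sketch (illustrative only): wire layer-wise LR decay into
187 |     # create_optimizer(). The model name and the SimpleNamespace fields
188 |     # below are assumptions standing in for the real training config; only
189 |     # the fields this module actually reads are populated.
190 |     from types import SimpleNamespace
191 |     import timm
192 |
193 |     model = timm.create_model('vit_base_patch16_224', num_classes=10)
194 |     num_layers = len(model.blocks)
195 |     # scales decay geometrically from the head (scale 1.0) down to patch_embed
196 |     assigner = LayerDecayValueAssigner(
197 |         [0.75 ** (num_layers + 1 - i) for i in range(num_layers + 2)])
198 |     args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05,
199 |                            opt_eps=1e-8, opt_betas=None, momentum=0.9)
200 |     optimizer = create_optimizer(
201 |         args, model,
202 |         get_num_layer=assigner.get_layer_id,
203 |         get_layer_scale=assigner.get_scale)
204 |     print("%d param groups" % len(optimizer.param_groups))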
--------------------------------------------------------------------------------
/finetune/run_class_finetuning.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------
8 |
9 | import argparse
10 | import datetime
11 | import numpy as np
12 | import time
13 | import torch
14 | import torch.backends.cudnn as cudnn
15 | import json
16 | import os
17 |
18 | from pathlib import Path
19 | from collections import OrderedDict
20 |
21 | from timm.data.mixup import Mixup
22 | from timm.models import create_model
23 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
24 | from timm.utils import ModelEma
25 | from optim_factory import create_optimizer, get_parameter_groups, LayerDecayValueAssigner
26 |
27 | from datasets import build_dataset
28 | from engine_for_finetuning import train_one_epoch, evaluate
29 | from utils import NativeScalerWithGradNormCount as NativeScaler
30 | import utils
31 | #from scipy import interpolate
32 | import modeling_finetune
33 |
34 |
35 | def get_args():
36 | parser = argparse.ArgumentParser('MAE fine-tuning and evaluation script for image classification', add_help=False)
37 | parser.add_argument('--batch_size', default=64, type=int)
38 | parser.add_argument('--epochs', default=30, type=int)
39 | parser.add_argument('--update_freq', default=1, type=int)
40 | parser.add_argument('--save_ckpt_freq', default=20, type=int)
41 |
42 | # Model parameters
43 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
44 | help='Name of model to train')
45 |
46 | parser.add_argument('--input_size', default=224, type=int,
47 | help='images input size')
48 |
49 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
50 | help='Dropout rate (default: 0.)')
51 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, metavar='PCT',
52 | help='Attention dropout rate (default: 0.)')
53 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
54 | help='Drop path rate (default: 0.1)')
55 |
56 | parser.add_argument('--disable_eval_during_finetuning', action='store_true', default=False)
57 |
58 | parser.add_argument('--model_ema', action='store_true', default=False)
59 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
60 | parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
61 |
62 | # Optimizer parameters
63 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
64 |                         help='Optimizer (default: "adamw")')
65 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
66 | help='Optimizer Epsilon (default: 1e-8)')
67 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
68 | help='Optimizer Betas (default: None, use opt default)')
69 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
70 | help='Clip gradient norm (default: None, no clipping)')
71 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
72 | help='SGD momentum (default: 0.9)')
73 | parser.add_argument('--weight_decay', type=float, default=0.05,
74 | help='weight decay (default: 0.05)')
75 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
76 |                         weight decay. We use a cosine schedule for WD; a larger decay by the
77 |                         end of training improves performance for ViTs.""")
78 |
79 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
80 | help='learning rate (default: 1e-3)')
81 | parser.add_argument('--layer_decay', type=float, default=0.75)
82 |
83 | parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
84 | help='warmup learning rate (default: 1e-6)')
85 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
86 |                         help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
87 |
88 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
89 | help='epochs to warmup LR, if scheduler supports')
90 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
91 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
92 |
93 | # Augmentation parameters
94 | # TODO: color jitter differs in the MAE paper.
95 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
96 | help='Color jitter factor (default: 0.4)')
97 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
98 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
99 | parser.add_argument('--smoothing', type=float, default=0.1,
100 | help='Label smoothing (default: 0.1)')
101 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
102 |                         help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
103 |
104 | # Evaluation parameters
105 | parser.add_argument('--crop_pct', type=float, default=None)
106 |
107 | # * Random Erase params
108 | # TODO: This is not used in MAE.
109 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
110 | help='Random erase prob (default: 0.25)')
111 | parser.add_argument('--remode', type=str, default='pixel',
112 | help='Random erase mode (default: "pixel")')
113 | parser.add_argument('--recount', type=int, default=1,
114 | help='Random erase count (default: 1)')
115 | parser.add_argument('--resplit', action='store_true', default=False,
116 | help='Do not random erase first (clean) augmentation split')
117 |
118 | # * Mixup params
119 | parser.add_argument('--mixup', type=float, default=0.8,
120 | help='mixup alpha, mixup enabled if > 0.')
121 | parser.add_argument('--cutmix', type=float, default=1.0,
122 | help='cutmix alpha, cutmix enabled if > 0.')
123 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
124 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
125 | parser.add_argument('--mixup_prob', type=float, default=1.0,
126 | help='Probability of performing mixup or cutmix when either/both is enabled')
127 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
128 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
129 | parser.add_argument('--mixup_mode', type=str, default='batch',
130 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
131 |
132 | # * Finetuning params
133 | parser.add_argument('--finetune', default='', help='finetune from checkpoint')
134 | parser.add_argument('--model_key', default='model|module|state_dict|teacher', type=str)
135 | parser.add_argument('--eval_momentum', action='store_true',
136 | help='evaluation momentum encoder')
137 | parser.add_argument('--model_prefix', default='', type=str)
138 | parser.add_argument('--init_scale', default=0.001, type=float)
139 | parser.add_argument('--use_mean_pooling', action='store_true')
140 | parser.set_defaults(use_mean_pooling=True)
141 | parser.add_argument('--use_cls', action='store_false', dest='use_mean_pooling')
142 | # TODO: we need to try both. We also need to consider adding batch norm.
143 | # TODO: What does init scale mean?
144 |
145 | # Dataset parameters
146 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
147 | help='dataset path')
148 | parser.add_argument('--eval_data_path', default=None, type=str,
149 | help='dataset path for evaluation')
150 | parser.add_argument('--nb_classes', default=1000, type=int,
151 | help='number of the classification types')
152 | parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
153 |
154 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'],
155 |                         type=str, help='dataset to use')
156 | parser.add_argument('--output_dir', default='',
157 | help='path where to save, empty for no saving')
158 | parser.add_argument('--log_dir', default=None,
159 | help='path where to tensorboard log')
160 | parser.add_argument('--device', default='cuda',
161 | help='device to use for training / testing')
162 | parser.add_argument('--seed', default=0, type=int)
163 | parser.add_argument('--resume', default='',
164 | help='resume from checkpoint')
165 | parser.add_argument('--auto_resume', action='store_true')
166 | parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
167 | parser.set_defaults(auto_resume=True)
168 |
169 | parser.add_argument('--save_ckpt', action='store_true')
170 | parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
171 | parser.set_defaults(save_ckpt=True)
172 |
173 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
174 | help='start epoch')
175 | parser.add_argument('--eval', action='store_true',
176 | help='Perform evaluation only')
177 | parser.add_argument('--dist_eval', action='store_true', default=False,
178 | help='Enabling distributed evaluation')
179 | parser.add_argument('--num_workers', default=4, type=int)
180 | parser.add_argument('--pin_mem', action='store_true',
181 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
182 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
183 | parser.set_defaults(pin_mem=True)
184 |
185 | # distributed training parameters
186 | parser.add_argument('--world_size', default=1, type=int,
187 | help='number of distributed processes')
188 | parser.add_argument('--local_rank', default=-1, type=int)
189 | parser.add_argument('--dist_on_itp', action='store_true')
190 | parser.add_argument('--dist_url', default='env://',
191 | help='url used to set up distributed training')
192 |
193 | parser.add_argument('--enable_deepspeed', action='store_true', default=False)
194 |
195 | known_args, _ = parser.parse_known_args()
196 |
197 | if known_args.enable_deepspeed:
198 | try:
199 | import deepspeed
200 | from deepspeed import DeepSpeedConfig
201 | parser = deepspeed.add_config_arguments(parser)
202 | ds_init = deepspeed.initialize
203 |         except ImportError:
204 | print("Please 'pip install deepspeed==0.4.0'")
205 | exit(0)
206 | else:
207 | ds_init = None
208 |
209 | return parser.parse_args(), ds_init
210 |
211 |
212 | def main(args, ds_init):
213 | utils.init_distributed_mode(args)
214 |
215 | if ds_init is not None:
216 | utils.create_ds_config(args)
217 |
218 | print(args)
219 |
220 | device = torch.device(args.device)
221 |
222 | # fix the seed for reproducibility
223 | seed = args.seed + utils.get_rank()
224 | torch.manual_seed(seed)
225 | np.random.seed(seed)
226 | # random.seed(seed)
227 |
228 | cudnn.benchmark = True
229 |
230 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
231 | if args.disable_eval_during_finetuning:
232 | dataset_val = None
233 | else:
234 | dataset_val, _ = build_dataset(is_train=False, args=args)
235 |
236 | if True: # args.distributed:
237 | num_tasks = utils.get_world_size()
238 | global_rank = utils.get_rank()
239 | sampler_train = torch.utils.data.DistributedSampler(
240 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
241 | )
242 | print("Sampler_train = %s" % str(sampler_train))
243 | if args.dist_eval:
244 | if len(dataset_val) % num_tasks != 0:
245 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
246 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
247 | 'equal num of samples per-process.')
248 | sampler_val = torch.utils.data.DistributedSampler(
249 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
250 | else:
251 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
252 | else:
253 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
254 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
255 |
256 | if global_rank == 0 and args.log_dir is not None:
257 | os.makedirs(args.log_dir, exist_ok=True)
258 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
259 | else:
260 | log_writer = None
261 |
262 | data_loader_train = torch.utils.data.DataLoader(
263 | dataset_train, sampler=sampler_train,
264 | batch_size=args.batch_size,
265 | num_workers=args.num_workers,
266 | pin_memory=args.pin_mem,
267 | drop_last=True,
268 | )
269 |
270 | if dataset_val is not None:
271 | data_loader_val = torch.utils.data.DataLoader(
272 | dataset_val, sampler=sampler_val,
273 | batch_size=int(1.5 * args.batch_size),
274 | num_workers=args.num_workers,
275 | pin_memory=args.pin_mem,
276 | drop_last=False
277 | )
278 | else:
279 | data_loader_val = None
280 |
281 | mixup_fn = None
282 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
283 | if mixup_active:
284 | print("Mixup is activated!")
285 | mixup_fn = Mixup(
286 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
287 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
288 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
289 |
290 | model = create_model(
291 | args.model,
292 | pretrained=False,
293 | num_classes=args.nb_classes,
294 | drop_rate=args.drop,
295 | drop_path_rate=args.drop_path,
296 | attn_drop_rate=args.attn_drop_rate,
297 | drop_block_rate=None,
298 | use_mean_pooling=args.use_mean_pooling,
299 | init_scale=args.init_scale,
300 | use_learnable_pos_emb=True,
301 | )
302 |
303 | patch_size = model.patch_embed.patch_size
304 | print("Patch size = %s" % str(patch_size))
305 | args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
306 | args.patch_size = patch_size
307 |
308 | if args.finetune:
309 | if args.finetune.startswith('https'):
310 | checkpoint = torch.hub.load_state_dict_from_url(
311 | args.finetune, map_location='cpu', check_hash=True)
312 | else:
313 | checkpoint = torch.load(args.finetune, map_location='cpu')
314 |
315 | print("Load ckpt from %s" % args.finetune)
316 | checkpoint_model = None
317 | for model_key in args.model_key.split('|'):
318 | if model_key in checkpoint:
319 | checkpoint_model = checkpoint[model_key]
320 | print("Load state_dict by model_key = %s" % model_key)
321 | break
322 | if checkpoint_model is None:
323 | checkpoint_model = checkpoint
324 | state_dict = model.state_dict()
325 | for k in ['head.weight', 'head.bias']:
326 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
327 | print(f"Removing key {k} from pretrained checkpoint")
328 | del checkpoint_model[k]
329 |
330 | all_keys = list(checkpoint_model.keys())
331 | new_dict = OrderedDict()
332 | for key in all_keys:
333 | if args.eval_momentum:
334 | if key.startswith('module.momentum_encoder.'):
335 | new_dict[key[24:]] = checkpoint_model[key]
336 | else:
337 | if key.startswith('backbone.'):
338 | new_dict[key[9:]] = checkpoint_model[key]
339 | elif key.startswith('encoder.'):
340 | new_dict[key[8:]] = checkpoint_model[key]
341 | elif key.startswith('module.base_encoder.'):
342 | new_dict[key[20:]] = checkpoint_model[key]
343 | elif key.startswith('module.visual.'):
344 | new_dict[key[14:]] = checkpoint_model[key]
345 | else:
346 | new_dict[key] = checkpoint_model[key]
347 | checkpoint_model = new_dict
348 |
349 | # interpolate position embedding
350 | if 'pos_embed' in checkpoint_model:
351 | pos_embed_checkpoint = checkpoint_model['pos_embed']
352 | embedding_size = pos_embed_checkpoint.shape[-1]
353 | num_patches = model.patch_embed.num_patches
354 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
355 | # height (== width) for the checkpoint position embedding
356 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
357 | # height (== width) for the new position embedding
358 | new_size = int(num_patches ** 0.5)
359 | # class_token and dist_token are kept unchanged
360 | if orig_size != new_size:
361 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
362 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
363 | # only the position tokens are interpolated
364 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
365 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
366 | pos_tokens = torch.nn.functional.interpolate(
367 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
368 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
369 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
370 | checkpoint_model['pos_embed'] = new_pos_embed
371 |
372 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
373 | # model.load_state_dict(checkpoint_model, strict=False)
374 |
375 | model.to(device)
376 |
377 | model_ema = None
378 | if args.model_ema:
379 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
380 | model_ema = ModelEma(
381 | model,
382 | decay=args.model_ema_decay,
383 | device='cpu' if args.model_ema_force_cpu else '',
384 | resume='')
385 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
386 |
387 | model_without_ddp = model
388 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
389 |
390 | print("Model = %s" % str(model_without_ddp))
391 | print('number of params:', n_parameters)
392 |
393 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
394 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
395 |     args.lr = args.lr * total_batch_size / 256  # linear LR scaling rule (see note at end of file)
396 | print("LR = %.8f" % args.lr)
397 | print("Batch size = %d" % total_batch_size)
398 |     print("Update frequency = %d" % args.update_freq)
399 | print("Number of training examples = %d" % len(dataset_train))
400 |     print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
401 |
402 | num_layers = model_without_ddp.get_num_layers()
403 | if args.layer_decay < 1.0:
404 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
405 | else:
406 | assigner = None
407 |
408 | if assigner is not None:
409 | print("Assigned values = %s" % str(assigner.values))
410 |
411 | skip_weight_decay_list = model.no_weight_decay()
412 | print("Skip weight decay list: ", skip_weight_decay_list)
413 |
414 | if args.enable_deepspeed:
415 | loss_scaler = None
416 | optimizer_params = get_parameter_groups(
417 | model, args.weight_decay, skip_weight_decay_list,
418 | assigner.get_layer_id if assigner is not None else None,
419 | assigner.get_scale if assigner is not None else None)
420 | model, optimizer, _, _ = ds_init(
421 | args=args, model=model, model_parameters=optimizer_params, dist_init_required=not args.distributed,
422 | )
423 |
424 | print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
425 | assert model.gradient_accumulation_steps() == args.update_freq
426 | else:
427 | if args.distributed:
428 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
429 | model_without_ddp = model.module
430 |
431 | optimizer = create_optimizer(
432 | args, model_without_ddp, skip_list=skip_weight_decay_list,
433 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
434 | get_layer_scale=assigner.get_scale if assigner is not None else None)
435 | loss_scaler = NativeScaler()
436 |
437 | print("Use step level LR scheduler!")
438 | lr_schedule_values = utils.cosine_scheduler(
439 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
440 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
441 | )
442 | if args.weight_decay_end is None:
443 | args.weight_decay_end = args.weight_decay
444 | wd_schedule_values = utils.cosine_scheduler(
445 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
446 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
447 |
448 | if mixup_fn is not None:
449 | # smoothing is handled with mixup label transform
450 | criterion = SoftTargetCrossEntropy()
451 | elif args.smoothing > 0.:
452 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
453 | else:
454 | criterion = torch.nn.CrossEntropyLoss()
455 |
456 | print("criterion = %s" % str(criterion))
457 |
458 | utils.auto_load_model(
459 | args=args, model=model, model_without_ddp=model_without_ddp,
460 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
461 |
462 | if args.eval:
463 | test_stats = evaluate(data_loader_val, model, device)
464 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
465 | exit(0)
466 |
467 | print(f"Start training for {args.epochs} epochs")
468 | start_time = time.time()
469 | max_accuracy = 0.0
470 | for epoch in range(args.start_epoch, args.epochs):
471 | if args.distributed:
472 | data_loader_train.sampler.set_epoch(epoch)
473 | if log_writer is not None:
474 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
475 | train_stats = train_one_epoch(
476 | model, criterion, data_loader_train, optimizer,
477 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
478 | log_writer=log_writer, start_steps=epoch * num_training_steps_per_epoch,
479 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
480 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
481 | )
482 | if args.output_dir and args.save_ckpt:
483 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
484 | utils.save_model(
485 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
486 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
487 | if data_loader_val is not None:
488 | test_stats = evaluate(data_loader_val, model, device)
489 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
490 | if max_accuracy < test_stats["acc1"]:
491 | max_accuracy = test_stats["acc1"]
492 | if args.output_dir and args.save_ckpt:
493 | utils.save_model(
494 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
495 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
496 |
497 | print(f'Max accuracy: {max_accuracy:.2f}%')
498 | if log_writer is not None:
499 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
500 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
501 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
502 |
503 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
504 | **{f'test_{k}': v for k, v in test_stats.items()},
505 | 'epoch': epoch,
506 | 'n_parameters': n_parameters}
507 | else:
508 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
509 | # **{f'test_{k}': v for k, v in test_stats.items()},
510 | 'epoch': epoch,
511 | 'n_parameters': n_parameters}
512 |
513 | if args.log_dir and utils.is_main_process():
514 | if log_writer is not None:
515 | log_writer.flush()
516 | with open(os.path.join(args.log_dir, "log.txt"), mode="a", encoding="utf-8") as f:
517 | f.write(json.dumps(log_stats) + "\n")
518 |
519 | total_time = time.time() - start_time
520 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
521 | print('Training time {}'.format(total_time_str))
522 |
523 |
524 | if __name__ == '__main__':
525 | opts, ds_init = get_args()
526 | if opts.output_dir:
527 | Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
528 | main(opts, ds_init)
529 |
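530 | # Note on the LR scaling in main() above (illustrative arithmetic): the base
531 | # LR follows the linear scaling rule, effective_lr = lr * total_batch_size / 256,
532 | # where total_batch_size = batch_size * update_freq * world_size. For example,
533 | # with --batch_size 64, --update_freq 1 and 8 GPUs, total_batch_size = 512 and
534 | # the default 1e-3 becomes an effective LR of 2e-3.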
--------------------------------------------------------------------------------
/finetune/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Based on BEiT, timm, DINO and DeiT code bases
3 | # https://github.com/microsoft/unilm/tree/master/beit
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------
8 |
9 | import io
10 | import os
11 | import math
12 | import time
13 | import json
14 | from collections import defaultdict, deque
15 | import datetime
16 | import numpy as np
17 | from timm.utils import get_state_dict
18 |
19 | from pathlib import Path
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch._six import inf
24 |
25 | import random
26 |
27 | from tensorboardX import SummaryWriter
28 |
29 |
30 | class SmoothedValue(object):
31 | """Track a series of values and provide access to smoothed values over a
32 | window or the global series average.
33 | """
34 |
35 | def __init__(self, window_size=20, fmt=None):
36 | if fmt is None:
37 | fmt = "{median:.4f} ({global_avg:.4f})"
38 | self.deque = deque(maxlen=window_size)
39 | self.total = 0.0
40 | self.count = 0
41 | self.fmt = fmt
42 |
43 | def update(self, value, n=1):
44 | self.deque.append(value)
45 | self.count += n
46 | self.total += value * n
47 |
48 | def synchronize_between_processes(self):
49 | """
50 | Warning: does not synchronize the deque!
51 | """
52 | if not is_dist_avail_and_initialized():
53 | return
54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55 | dist.barrier()
56 | dist.all_reduce(t)
57 | t = t.tolist()
58 | self.count = int(t[0])
59 | self.total = t[1]
60 |
61 | @property
62 | def median(self):
63 | d = torch.tensor(list(self.deque))
64 | return d.median().item()
65 |
66 | @property
67 | def avg(self):
68 | d = torch.tensor(list(self.deque), dtype=torch.float32)
69 | return d.mean().item()
70 |
71 | @property
72 | def global_avg(self):
73 | return self.total / self.count
74 |
75 | @property
76 | def max(self):
77 | return max(self.deque)
78 |
79 | @property
80 | def value(self):
81 | return self.deque[-1]
82 |
83 | def __str__(self):
84 | return self.fmt.format(
85 | median=self.median,
86 | avg=self.avg,
87 | global_avg=self.global_avg,
88 | max=self.max,
89 | value=self.value)
90 |
91 |
92 | class MetricLogger(object):
93 | def __init__(self, delimiter="\t"):
94 | self.meters = defaultdict(SmoothedValue)
95 | self.delimiter = delimiter
96 |
97 | def update(self, **kwargs):
98 | for k, v in kwargs.items():
99 | if v is None:
100 | continue
101 | if isinstance(v, torch.Tensor):
102 | v = v.item()
103 | assert isinstance(v, (float, int))
104 | self.meters[k].update(v)
105 |
106 | def __getattr__(self, attr):
107 | if attr in self.meters:
108 | return self.meters[attr]
109 | if attr in self.__dict__:
110 | return self.__dict__[attr]
111 | raise AttributeError("'{}' object has no attribute '{}'".format(
112 | type(self).__name__, attr))
113 |
114 | def __str__(self):
115 | loss_str = []
116 | for name, meter in self.meters.items():
117 | loss_str.append(
118 | "{}: {}".format(name, str(meter))
119 | )
120 | return self.delimiter.join(loss_str)
121 |
122 | def synchronize_between_processes(self):
123 | for meter in self.meters.values():
124 | meter.synchronize_between_processes()
125 |
126 | def add_meter(self, name, meter):
127 | self.meters[name] = meter
128 |
129 | def log_every(self, iterable, print_freq, header=None):
130 | i = 0
131 | if not header:
132 | header = ''
133 | start_time = time.time()
134 | end = time.time()
135 | iter_time = SmoothedValue(fmt='{avg:.4f}')
136 | data_time = SmoothedValue(fmt='{avg:.4f}')
137 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
138 | log_msg = [
139 | header,
140 | '[{0' + space_fmt + '}/{1}]',
141 | 'eta: {eta}',
142 | '{meters}',
143 | 'time: {time}',
144 | 'data: {data}'
145 | ]
146 | if torch.cuda.is_available():
147 | log_msg.append('max mem: {memory:.0f}')
148 | log_msg = self.delimiter.join(log_msg)
149 | MB = 1024.0 * 1024.0
150 | for obj in iterable:
151 | data_time.update(time.time() - end)
152 | yield obj
153 | iter_time.update(time.time() - end)
154 | if i % print_freq == 0 or i == len(iterable) - 1:
155 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
156 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
157 | if torch.cuda.is_available():
158 | print(log_msg.format(
159 | i, len(iterable), eta=eta_string,
160 | meters=str(self),
161 | time=str(iter_time), data=str(data_time),
162 | memory=torch.cuda.max_memory_allocated() / MB))
163 | else:
164 | print(log_msg.format(
165 | i, len(iterable), eta=eta_string,
166 | meters=str(self),
167 | time=str(iter_time), data=str(data_time)))
168 | i += 1
169 | end = time.time()
170 | total_time = time.time() - start_time
171 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
172 | print('{} Total time: {} ({:.4f} s / it)'.format(
173 | header, total_time_str, total_time / len(iterable)))
174 |
175 |
176 | class TensorboardLogger(object):
177 | def __init__(self, log_dir):
178 | self.writer = SummaryWriter(logdir=log_dir)
179 | self.step = 0
180 |
181 | def set_step(self, step=None):
182 | if step is not None:
183 | self.step = step
184 | else:
185 | self.step += 1
186 |
187 | def update(self, head='scalar', step=None, **kwargs):
188 | for k, v in kwargs.items():
189 | if v is None:
190 | continue
191 | if isinstance(v, torch.Tensor):
192 | v = v.item()
193 | assert isinstance(v, (float, int))
194 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
195 |
196 | def flush(self):
197 | self.writer.flush()
198 |
199 | def seed_worker(worker_id):
200 | worker_seed = torch.initial_seed() % 2**32
201 | np.random.seed(worker_seed)
202 | random.seed(worker_seed)
203 |
204 | def _load_checkpoint_for_ema(model_ema, checkpoint):
205 | """
206 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
207 | """
208 | mem_file = io.BytesIO()
209 | torch.save(checkpoint, mem_file)
210 | mem_file.seek(0)
211 | model_ema._load_checkpoint(mem_file)
212 |
213 |
214 | def setup_for_distributed(is_master):
215 | """
216 | This function disables printing when not in master process
217 | """
218 | import builtins as __builtin__
219 | builtin_print = __builtin__.print
220 |
221 | def print(*args, **kwargs):
222 | force = kwargs.pop('force', False)
223 | if is_master or force:
224 | builtin_print(*args, **kwargs)
225 |
226 | __builtin__.print = print
227 |
228 |
229 | def is_dist_avail_and_initialized():
230 | if not dist.is_available():
231 | return False
232 | if not dist.is_initialized():
233 | return False
234 | return True
235 |
236 |
237 | def get_world_size():
238 | if not is_dist_avail_and_initialized():
239 | return 1
240 | return dist.get_world_size()
241 |
242 |
243 | def get_rank():
244 | if not is_dist_avail_and_initialized():
245 | return 0
246 | return dist.get_rank()
247 |
248 |
249 | def is_main_process():
250 | return get_rank() == 0
251 |
252 |
253 | def save_on_master(*args, **kwargs):
254 | if is_main_process():
255 | torch.save(*args, **kwargs)
256 |
257 |
258 | def init_distributed_mode(args):
259 | if args.dist_on_itp:
260 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
261 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
262 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
263 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
264 | os.environ['LOCAL_RANK'] = str(args.gpu)
265 | os.environ['RANK'] = str(args.rank)
266 | os.environ['WORLD_SIZE'] = str(args.world_size)
267 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
268 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
269 | args.rank = int(os.environ["RANK"])
270 | args.world_size = int(os.environ['WORLD_SIZE'])
271 | args.gpu = int(os.environ['LOCAL_RANK'])
272 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
273 | elif 'SLURM_PROCID' in os.environ:
274 | args.rank = int(os.environ['SLURM_PROCID'])
275 | args.gpu = args.rank % torch.cuda.device_count()
276 | else:
277 | print('Not using distributed mode')
278 | args.distributed = False
279 | return
280 |
281 | args.distributed = True
282 |
283 | torch.cuda.set_device(args.gpu)
284 | args.dist_backend = 'nccl'
285 | print('| distributed init (rank {}): {}, gpu {}'.format(
286 | args.rank, args.dist_url, args.gpu), flush=True)
287 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
288 | world_size=args.world_size, rank=args.rank)
289 | torch.distributed.barrier()
290 | setup_for_distributed(args.rank == 0)
291 |
292 |
293 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
294 | missing_keys = []
295 | unexpected_keys = []
296 | error_msgs = []
297 | # copy state_dict so _load_from_state_dict can modify it
298 | metadata = getattr(state_dict, '_metadata', None)
299 | state_dict = state_dict.copy()
300 | if metadata is not None:
301 | state_dict._metadata = metadata
302 |
303 | def load(module, prefix=''):
304 | local_metadata = {} if metadata is None else metadata.get(
305 | prefix[:-1], {})
306 | module._load_from_state_dict(
307 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
308 | for name, child in module._modules.items():
309 | if child is not None:
310 | load(child, prefix + name + '.')
311 |
312 | load(model, prefix=prefix)
313 |
314 | warn_missing_keys = []
315 | ignore_missing_keys = []
316 | for key in missing_keys:
317 | keep_flag = True
318 | for ignore_key in ignore_missing.split('|'):
319 | if ignore_key in key:
320 | keep_flag = False
321 | break
322 | if keep_flag:
323 | warn_missing_keys.append(key)
324 | else:
325 | ignore_missing_keys.append(key)
326 |
327 | missing_keys = warn_missing_keys
328 |
329 | if len(missing_keys) > 0:
330 | print("Weights of {} not initialized from pretrained model: {}".format(
331 | model.__class__.__name__, missing_keys))
332 | if len(unexpected_keys) > 0:
333 | print("Weights from pretrained model not used in {}: {}".format(
334 | model.__class__.__name__, unexpected_keys))
335 | if len(ignore_missing_keys) > 0:
336 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
337 | model.__class__.__name__, ignore_missing_keys))
338 | if len(error_msgs) > 0:
339 | print('\n'.join(error_msgs))
340 |
341 |
342 | class NativeScalerWithGradNormCount:
343 | state_dict_key = "amp_scaler"
344 |
345 | def __init__(self):
346 | self._scaler = torch.cuda.amp.GradScaler()
347 |
348 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
349 | self._scaler.scale(loss).backward(create_graph=create_graph)
350 | if update_grad:
351 | if clip_grad is not None:
352 | assert parameters is not None
353 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
354 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
355 | else:
356 | self._scaler.unscale_(optimizer)
357 | norm = get_grad_norm_(parameters)
358 | self._scaler.step(optimizer)
359 | self._scaler.update()
360 | else:
361 | norm = None
362 | return norm
363 |
364 | def state_dict(self):
365 | return self._scaler.state_dict()
366 |
367 | def load_state_dict(self, state_dict):
368 | self._scaler.load_state_dict(state_dict)
369 |
370 |
371 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
372 | if isinstance(parameters, torch.Tensor):
373 | parameters = [parameters]
374 | parameters = [p for p in parameters if p.grad is not None]
375 | norm_type = float(norm_type)
376 | if len(parameters) == 0:
377 | return torch.tensor(0.)
378 | device = parameters[0].grad.device
379 | if norm_type == inf:
380 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
381 | else:
382 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
383 | return total_norm
384 |
385 |
386 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
387 | start_warmup_value=0, warmup_steps=-1):
388 | warmup_schedule = np.array([])
389 | warmup_iters = warmup_epochs * niter_per_ep
390 | if warmup_steps > 0:
391 | warmup_iters = warmup_steps
392 | print("Set warmup steps = %d" % warmup_iters)
393 |     if warmup_iters > 0:  # warmup may be given via warmup_epochs or warmup_steps
394 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
395 |
396 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
397 | schedule = np.array(
398 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
399 |
400 | schedule = np.concatenate((warmup_schedule, schedule))
401 |
402 | assert len(schedule) == epochs * niter_per_ep
403 | return schedule
404 |
405 |
406 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
407 | output_dir = Path(args.output_dir)
408 | epoch_name = str(epoch)
409 | if loss_scaler is not None:
410 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
411 | for checkpoint_path in checkpoint_paths:
412 | to_save = {
413 | 'model': model_without_ddp.state_dict(),
414 | 'optimizer': optimizer.state_dict(),
415 | 'epoch': epoch,
416 | 'scaler': loss_scaler.state_dict(),
417 | 'args': args,
418 | }
419 |
420 | if model_ema is not None:
421 | to_save['model_ema'] = get_state_dict(model_ema)
422 |
423 | save_on_master(to_save, checkpoint_path)
424 | else:
425 | client_state = {'epoch': epoch}
426 | if model_ema is not None:
427 | client_state['model_ema'] = get_state_dict(model_ema)
428 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
429 |
430 |
431 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
432 | output_dir = Path(args.output_dir)
433 | if loss_scaler is not None:
434 | # torch.amp
435 | if args.auto_resume and len(args.resume) == 0:
436 | import glob
437 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
438 | latest_ckpt = -1
439 | for ckpt in all_checkpoints:
440 | t = ckpt.split('-')[-1].split('.')[0]
441 | if t.isdigit():
442 | latest_ckpt = max(int(t), latest_ckpt)
443 | if latest_ckpt >= 0:
444 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
445 | print("Auto resume checkpoint: %s" % args.resume)
446 |
447 | if args.resume:
448 | if args.resume.startswith('https'):
449 | checkpoint = torch.hub.load_state_dict_from_url(
450 | args.resume, map_location='cpu', check_hash=True)
451 | else:
452 | checkpoint = torch.load(args.resume, map_location='cpu')
453 | model_without_ddp.load_state_dict(checkpoint['model'])
454 | print("Resume checkpoint %s" % args.resume)
455 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
456 | optimizer.load_state_dict(checkpoint['optimizer'])
457 | args.start_epoch = checkpoint['epoch'] + 1
458 | if hasattr(args, 'model_ema') and args.model_ema:
459 | _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
460 | if 'scaler' in checkpoint:
461 | loss_scaler.load_state_dict(checkpoint['scaler'])
462 | print("With optim & sched!")
463 | else:
464 | # deepspeed, only support '--auto_resume'.
465 | if args.auto_resume:
466 | import glob
467 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
468 | latest_ckpt = -1
469 | for ckpt in all_checkpoints:
470 | t = ckpt.split('-')[-1].split('.')[0]
471 | if t.isdigit():
472 | latest_ckpt = max(int(t), latest_ckpt)
473 | if latest_ckpt >= 0:
474 | args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
475 | print("Auto resume checkpoint: %d" % latest_ckpt)
476 | _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
477 | args.start_epoch = client_states['epoch'] + 1
478 | if model_ema is not None:
479 | if args.model_ema:
480 | _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
481 |
482 |
483 | def create_ds_config(args):
484 | args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
485 | with open(args.deepspeed_config, mode="w") as writer:
486 | ds_config = {
487 | "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
488 | "train_micro_batch_size_per_gpu": args.batch_size,
489 | "steps_per_print": 1000,
490 | "optimizer": {
491 | "type": "Adam",
492 | "adam_w_mode": True,
493 | "params": {
494 | "lr": args.lr,
495 | "weight_decay": args.weight_decay,
496 | "bias_correction": True,
497 | "betas": [
498 | 0.9,
499 | 0.999
500 | ],
501 | "eps": 1e-8
502 | }
503 | },
504 | "fp16": {
505 | "enabled": True,
506 | "loss_scale": 0,
507 | "initial_scale_power": 7,
508 | "loss_scale_window": 128
509 | }
510 | }
511 |
512 | writer.write(json.dumps(ds_config, indent=2))
513 |
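514 |
515 | if __name__ == "__main__":
516 |     # Sanity-check sketch for cosine_scheduler (illustrative values, not the
517 |     # project's defaults): a 5-epoch linear warmup from 0 to the base value,
518 |     # followed by a half-cosine decay to the final value.
519 |     sched = cosine_scheduler(base_value=1e-3, final_value=1e-6,
520 |                              epochs=30, niter_per_ep=1000, warmup_epochs=5)
521 |     assert len(sched) == 30 * 1000
522 |     print(sched[0], sched[5 * 1000 - 1], sched[-1])  # 0.0, 1e-3, ~1e-6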
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/ExtreMA/bf30b64ec0046524d0c373b761df323b9328f8dc/lib/__init__.py
--------------------------------------------------------------------------------
/lib/augment.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | import numpy as np
3 |
4 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
5 | std=[0.229, 0.224, 0.225])
6 |
7 | class SpatialConsistentColorAug:
8 |     """Take one random spatial crop of an image and apply two independent color augmentations"""
9 |
10 | def __init__(self, crop_min):
11 | self.base_transform = transforms.Compose([
12 | transforms.RandomResizedCrop(224, scale=(crop_min, 1.)),
13 | transforms.RandomHorizontalFlip()])
14 | self.color_aug = transforms.Compose([transforms.RandomApply([
15 | transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
16 | ], p=0.8),
17 | transforms.RandomGrayscale(p=0.2),
18 | transforms.ToTensor(),
19 | normalize])
20 | self.to_tensor = transforms.Compose([transforms.ToTensor(),
21 | normalize])
22 |
23 | def __call__(self, x):
24 | im = self.base_transform(x)
25 | im1 = self.color_aug(im)
26 | im2 = self.color_aug(im)
27 | return [im1, im2]
28 |
29 | def create_augmentation(args):
30 |     augmentation = None
31 | if args.aug_centercrop:
32 | augmentation = [
33 | transforms.Resize(256),
34 | transforms.CenterCrop(224),
35 | transforms.ToTensor(),
36 | normalize
37 | ]
38 |
39 | if args.aug_spatialconsistent_color:
40 | return SpatialConsistentColorAug(args.crop_min)
41 |
42 | if args.aug_spatial:
43 | augmentation = [
44 | transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)),
45 | transforms.RandomHorizontalFlip(),
46 | transforms.ToTensor(),
47 | normalize
48 | ]
49 |     assert augmentation is not None, 'no augmentation flag set in args'
50 | return transforms.Compose(augmentation)
51 |
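52 |
53 | if __name__ == "__main__":
54 |     # Usage sketch (illustrative): both returned views share one random
55 |     # crop/flip but receive independent color jitter and grayscale, which
56 |     # is what "spatially consistent" means here; crop_min=0.2 is an
57 |     # assumed value, not a project default.
58 |     from PIL import Image
59 |
60 |     aug = SpatialConsistentColorAug(crop_min=0.2)
61 |     im1, im2 = aug(Image.new('RGB', (256, 256)))
62 |     print(im1.shape, im2.shape)  # torch.Size([3, 224, 224]) twice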
--------------------------------------------------------------------------------
/lib/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import random
5 | import torch.nn.functional as F
6 | import torch.distributed as dist
7 |
8 | class ExtreMA(nn.Module):
9 | """
10 | Build the model
11 | """
12 | def __init__(self, base_encoder, ema_encoder, proj_dim=256, mlp_dim=4096, T=1., mask_ratio=0.8, num_masks=1, disjoint=True):
13 | """
14 |         proj_dim: projection output dimension (default: 256)
15 |         mlp_dim: hidden dimension in the projector/predictor MLPs (default: 4096)
16 |         T: softmax temperature (default: 1.0)
17 | """
18 | super(ExtreMA, self).__init__()
19 |
20 | self.T = T
21 | self.mask_ratio = mask_ratio
22 | self.num_masks = num_masks
23 | self.disjoint_sampling = disjoint
24 |
25 | # build encoders
26 | self.base_encoder = base_encoder(num_classes=mlp_dim)
27 | self.momentum_encoder = ema_encoder(num_classes=mlp_dim)
28 | self.base_encoder.student=True
29 | self.momentum_encoder.student=False
30 |
31 | hidden_dim = self.base_encoder.norm.weight.data.shape[0]
32 | del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer
33 | # projectors
34 | self.base_encoder.head = self._build_proj(3, hidden_dim, mlp_dim, proj_dim)
35 | self.momentum_encoder.head = self._build_proj(3, hidden_dim, mlp_dim, proj_dim)
36 | # predictor
37 | self.predictor = self._build_pred(2, proj_dim, mlp_dim, proj_dim)
38 |
39 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
40 | param_m.data.copy_(param_b.data) # initialize
41 | param_m.requires_grad = False # not update by gradient
42 |
43 | def _build_pred(self, num_layers, input_dim, mlp_dim, output_dim):
44 | layers = []
45 | for l in range(num_layers):
46 | dim1 = input_dim if l == 0 else mlp_dim
47 | dim2 = output_dim if l == num_layers - 1 else mlp_dim
48 |
49 | layers.append(nn.Linear(dim1, dim2, bias=True))
50 |
51 | if l < num_layers - 1:
52 | layers.append(nn.LayerNorm(dim2))
53 | layers.append(nn.GELU())
54 |
55 | mlp = nn.Sequential(*layers)
56 | mlp.apply(self._init_weights)
57 | return mlp
58 |
59 | def _build_proj(self, num_layers, input_dim, mlp_dim, output_dim):
60 | layers = []
61 | for l in range(num_layers):
62 | dim1 = input_dim if l == 0 else mlp_dim
63 | dim2 = output_dim if l == num_layers - 1 else mlp_dim
64 |
65 | layers.append(nn.Linear(dim1, dim2, bias=True))
66 |
67 | if l < num_layers - 1:
68 | layers.append(nn.LayerNorm(dim2))
69 | layers.append(nn.GELU())
70 |
71 | mlp = nn.Sequential(*layers)
72 | mlp.apply(self._init_weights)
73 | return mlp
74 |
75 | @torch.no_grad()
76 | def _update_momentum_encoder(self, m):
77 | """Momentum update of the momentum encoder"""
78 | for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
79 | param_m.data = param_m.data * m + param_b.data * (1. - m)
80 |
81 | def contrastive_loss(self, q, k):
82 | # normalize
83 | q = nn.functional.normalize(q, dim=1)
84 | k = nn.functional.normalize(k, dim=1)
85 | # gather all targets
86 | k = concat_all_gather(k)
87 | # Einstein sum is more intuitive
88 | logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
89 | N = logits.shape[0] # batch size per GPU
90 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
91 | return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)
92 |
93 | def byol_loss(self, q, k):
94 | # normalize
95 | q = nn.functional.normalize(q, dim=1)
96 | k = nn.functional.normalize(k, dim=1)
97 | loss = ((q - k) ** 2).sum(dim=-1)
98 | return loss.mean()
99 |
100 | def _init_weights(self, m):
101 | if isinstance(m, nn.Linear):
102 | # we use xavier_uniform following official JAX ViT:
103 | torch.nn.init.xavier_uniform_(m.weight)
104 | if isinstance(m, nn.Linear) and m.bias is not None:
105 | nn.init.constant_(m.bias, 0)
106 | elif isinstance(m, nn.LayerNorm):
107 | if m.weight is not None:
108 | nn.init.constant_(m.bias, 0)
109 | nn.init.constant_(m.weight, 1.0)
110 |
111 | @torch.jit.ignore
112 | def no_weight_decay(self):
113 | return {'base_encoder.' + k for k in self.base_encoder.no_weight_decay()}
114 |
115 | @torch.no_grad()
116 | def generate_mask(self, x, num_masks, mask_ratio=0.):
117 | if mask_ratio > 0:
118 | if self.disjoint_sampling:
119 | view_size = int(196 * (1 - mask_ratio))
120 | B = x.size(0)
121 | device = x.get_device()
122 | noise = torch.rand(B, 196, device=device)
123 |             mask_index = torch.argsort(noise, dim=1)  # random permutation of the 196 patch indices
124 | masks = []
125 | for i in range(num_masks):
126 | # 196 patches are hard-coded
127 | mask = mask_index[:, view_size*i:view_size*(i+1)]
128 | mask = mask.long()
129 | masks.append(mask)
130 | else:
131 | masks = []
132 | for i in range(num_masks):
133 | # 196 patches are hard-coded
134 | B = x.size(0)
135 | device = x.get_device()
136 | noise = torch.rand(B, 196, device=device)
137 | mask_index = torch.argsort(noise, dim=1)
138 |                 mask = mask_index[:, :int(196*(1-mask_ratio))]  # keep only the visible patch indices
139 | mask = mask.long()
140 | masks.append(mask)
141 | else:
142 | masks = None
143 |
144 | return masks
145 |
146 | def forward(self, x, m, loss='byol'):
147 | """
148 | Input:
149 |             x: an image batch, or a list of two augmented views
150 |             m: ema momentum coefficient
151 |             loss: loss type, 'byol' or 'infonce'
152 | Output:
153 | loss
154 | """
155 | if isinstance(x, list):
156 | x1 = x[0]
157 | x2 = x[1]
158 | else:
159 | x1 = x
160 | x2 = x
161 |
162 | # compute features
163 | B,_,_,_ = x1.size()
164 | device = x1.get_device()
165 |
166 | mask_s = self.generate_mask(x1, self.num_masks, self.mask_ratio)
167 | mask_t = self.generate_mask(x1, 1, 0.)
168 |
169 | q1 = self.predictor(self.base_encoder(x1, mask_s, self.mask_ratio))
170 | with torch.no_grad(): # no gradient
171 | self._update_momentum_encoder(m) # update the momentum encoder
172 | k2 = self.momentum_encoder(x2, mask_t)
173 |
174 | if loss == "byol":
175 | q1 = torch.chunk(q1, self.num_masks, dim=0)
176 | loss = 0.
177 | for q1i in q1:
178 | loss += self.byol_loss(q1i, k2)
179 | return loss / self.num_masks
180 |
181 | elif loss == 'infonce':
182 |             q1 = torch.chunk(q1, self.num_masks, dim=0)
183 | loss = 0.
184 | for q1i in q1:
185 | loss += self.contrastive_loss(q1i, k2)
186 | return loss / self.num_masks
187 |         raise ValueError("unknown loss type: %s" % loss)
188 | # utils
189 | @torch.no_grad()
190 | def concat_all_gather(tensor):
191 | """
192 | Performs all_gather operation on the provided tensors.
193 | *** Warning ***: torch.distributed.all_gather has no gradient.
194 | """
195 | tensors_gather = [torch.ones_like(tensor)
196 | for _ in range(torch.distributed.get_world_size())]
197 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
198 |
199 | output = torch.cat(tensors_gather, dim=0)
200 | return output
201 |
202 | # utils
203 | @torch.no_grad()
204 | def mean_all_gather(tensor):
205 | """
206 | Performs all_gather operation on the provided tensors.
207 | *** Warning ***: torch.distributed.all_gather has no gradient.
208 | """
209 | # print(tensor.size())
210 | tensors_gather = [torch.ones_like(tensor)
211 | for _ in range(torch.distributed.get_world_size())]
212 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
213 |
214 | output = torch.mean(torch.cat(tensors_gather,dim=0))
215 | return output
216 |
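217 |
218 | if __name__ == "__main__":
219 |     # Standalone illustration of the disjoint sampling in generate_mask
220 |     # (a re-derivation with assumed sizes, not a call into the class):
221 |     # each view keeps int(196 * (1 - mask_ratio)) patch indices taken as
222 |     # consecutive slices of one random permutation, so no patch index
223 |     # repeats across views.
224 |     B, num_masks, mask_ratio = 2, 2, 0.8
225 |     view_size = int(196 * (1 - mask_ratio))  # 39 visible patches per view
226 |     mask_index = torch.argsort(torch.rand(B, 196), dim=1)
227 |     masks = [mask_index[:, view_size * i:view_size * (i + 1)] for i in range(num_masks)]
228 |     overlap = (masks[0].unsqueeze(-1) == masks[1].unsqueeze(1)).any()
229 |     print(view_size, overlap.item())  # 39 False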
--------------------------------------------------------------------------------
/lib/dataload_optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class _RepeatSampler(object):
4 |
5 | def __init__(self, sampler):
6 | self.sampler = sampler
7 |
8 | def __iter__(self):
9 | while True:
10 | yield from iter(self.sampler)
11 |
12 | class PersistentDataLoader(torch.utils.data.dataloader.DataLoader):
13 |
14 | def __init__(self, *args, **kwargs):
15 | print('persistent dataloader')
16 | super().__init__(*args, **kwargs)
17 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
18 | self.iterator = super().__iter__()
19 |
20 | def __len__(self):
21 | return len(self.batch_sampler.sampler)
22 |
23 | def __iter__(self):
24 | for i in range(len(self)):
25 | yield next(self.iterator)
26 |
27 | class SoftwarePipeline(object):
28 |
29 | def __init__(self, dataloader):
30 | self.dataloader = dataloader
31 | self.stream = None
32 |
33 | def __len__(self):
34 | return len(self.dataloader)
35 |
36 | def __iter__(self):
37 | if self.stream is None:
38 | self.stream = torch.cuda.Stream()
39 | first = True
40 | for next_input, next_target in self.dataloader:
41 | with torch.cuda.stream(self.stream):
42 | next_target = next_target.cuda(non_blocking=True)
43 | if isinstance(next_input, list):
44 | for i in range(len(next_input)):
45 | next_input[i] = next_input[i].cuda(non_blocking=True)
46 | else:
47 | next_input = next_input.cuda(non_blocking=True)
48 | if not first:
49 | yield input, target
50 | else:
51 | first = False
52 | torch.cuda.current_stream().wait_stream(self.stream)
53 | input = next_input
54 | target = next_target
55 | yield input, target
56 |
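57 |
58 | if __name__ == "__main__":
59 |     # Usage sketch (illustrative): chain the two wrappers so that worker
60 |     # processes persist across epochs and host-to-device copies overlap
61 |     # with compute on a side CUDA stream. The TensorDataset is a
62 |     # hypothetical stand-in for a real dataset, and SoftwarePipeline
63 |     # requires a CUDA device.
64 |     dataset = torch.utils.data.TensorDataset(
65 |         torch.randn(64, 3, 224, 224), torch.zeros(64, dtype=torch.long))
66 |     loader = PersistentDataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
67 |     loader = SoftwarePipeline(loader)
68 |     for images, labels in loader:
69 |         print(images.device, labels.device)  # both on cuda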
--------------------------------------------------------------------------------
/lib/logger.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import os
4 | import sys
5 | from termcolor import colored
6 |
7 |
8 | class _ColorfulFormatter(logging.Formatter):
9 | def __init__(self, *args, **kwargs):
10 | self._root_name = kwargs.pop("root_name") + "."
11 | self._abbrev_name = kwargs.pop("abbrev_name", "")
12 | if len(self._abbrev_name):
13 | self._abbrev_name = self._abbrev_name + "."
14 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
15 |
16 | def formatMessage(self, record):
17 | record.name = record.name.replace(self._root_name, self._abbrev_name)
18 | log = super(_ColorfulFormatter, self).formatMessage(record)
19 | if record.levelno == logging.WARNING:
20 | prefix = colored("WARNING", "red", attrs=["blink"])
21 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
22 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
23 | else:
24 | return log
25 | return prefix + " " + log
26 |
27 | # so that calling setup_logger multiple times won't add many handlers
28 | @functools.lru_cache()
29 | def setup_logger(
30 | output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None
31 | ):
32 | """
33 |     Initialize the logger (adapted from detectron2) and set its verbosity level to "DEBUG".
34 |
35 | Args:
36 | output (str): a file name or a directory to save log. If None, will not save log file.
37 | If ends with ".txt" or ".log", assumed to be a file name.
38 | Otherwise, logs will be saved to `output/log.txt`.
39 | name (str): the root module name of this logger
40 |
41 | Returns:
42 | logging.Logger: a logger
43 | """
44 | logger = logging.getLogger(name)
45 | logger.setLevel(logging.DEBUG)
46 | logger.propagate = False
47 |
48 | if abbrev_name is None:
49 | abbrev_name = name
50 |
51 | plain_formatter = logging.Formatter(
52 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
53 | )
54 | # stdout logging: master only
55 | if distributed_rank == 0:
56 | ch = logging.StreamHandler(stream=sys.stdout)
57 | ch.setLevel(logging.DEBUG)
58 | if color:
59 | formatter = _ColorfulFormatter(
60 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
61 | datefmt="%m/%d %H:%M:%S",
62 | root_name=name,
63 | abbrev_name=str(abbrev_name),
64 | )
65 | else:
66 | formatter = plain_formatter
67 | ch.setFormatter(formatter)
68 | logger.addHandler(ch)
69 |
70 | # file logging: all workers
71 | if output is not None:
72 | if output.endswith(".txt") or output.endswith(".log"):
73 | filename = output
74 | else:
75 | filename = os.path.join(output, "log.txt")
76 | if distributed_rank > 0:
77 | filename = filename + f".rank{distributed_rank}"
78 | os.makedirs(os.path.dirname(filename), exist_ok=True)
79 |
80 | fh = logging.StreamHandler(_cached_log_stream(filename))
81 | fh.setLevel(logging.DEBUG)
82 | fh.setFormatter(plain_formatter)
83 | logger.addHandler(fh)
84 |
85 | return logger
86 |
87 |
88 | # cache the opened file object, so that different calls to `setup_logger`
89 | # with the same file name can safely write to the same file.
90 | @functools.lru_cache(maxsize=None)
91 | def _cached_log_stream(filename):
92 | return open(filename, "a")
--------------------------------------------------------------------------------
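Because `setup_logger` is wrapped in `functools.lru_cache`, repeated calls with the same arguments return the same logger without stacking duplicate handlers, and `_cached_log_stream` lets multiple calls append safely to one file. A small usage sketch (the output path is illustrative):

from lib.logger import setup_logger

# rank 0: colored stdout + ./logs/log.txt; ranks > 0: file only, rank-suffixed
logger = setup_logger(output="./logs", distributed_rank=0, name="byol")
logger.info("starting pretraining")
logger.warning("this prefix is rendered in red on rank 0")

# same arguments -> same cached logger, no duplicated handlers
assert setup_logger(output="./logs", distributed_rank=0, name="byol") is logger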
/lib/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import builtins
4 | import datetime
5 |
6 | import torch
7 | from math import inf  # torch._six was removed in recent PyTorch releases
8 | import torch.distributed as dist
9 |
10 | # --------------------------------------------------------
11 | # 2D sine-cosine position embedding
12 | # References:
13 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
14 | # MoCo v3: https://github.com/facebookresearch/moco-v3
15 | # --------------------------------------------------------
16 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
17 | """
18 | grid_size: int of the grid height and width
19 | return:
20 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
21 | """
22 | grid_h = np.arange(grid_size, dtype=np.float32)
23 | grid_w = np.arange(grid_size, dtype=np.float32)
24 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
25 | grid = np.stack(grid, axis=0)
26 |
27 | grid = grid.reshape([2, 1, grid_size, grid_size])
28 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
29 | if cls_token:
30 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
31 | return pos_embed
32 |
33 |
34 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
35 | assert embed_dim % 2 == 0
36 |
37 | # use half of dimensions to encode grid_h
38 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
39 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
40 |
41 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
42 | return emb
43 |
44 |
45 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
46 | """
47 | embed_dim: output dimension for each position
48 | pos: a list of positions to be encoded: size (M,)
49 | out: (M, D)
50 | """
51 | assert embed_dim % 2 == 0
52 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24
53 | omega /= embed_dim / 2.
54 | omega = 1. / 10000**omega # (D/2,)
55 |
56 | pos = pos.reshape(-1) # (M,)
57 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
58 |
59 | emb_sin = np.sin(out) # (M, D/2)
60 | emb_cos = np.cos(out) # (M, D/2)
61 |
62 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
63 | return emb
64 |
65 |
66 | # --------------------------------------------------------
67 | # Interpolate position embeddings for high-resolution
68 | # References:
69 | # DeiT: https://github.com/facebookresearch/deit
70 | # --------------------------------------------------------
71 | def interpolate_pos_embed(model, checkpoint_model):
72 | if 'pos_embed' in checkpoint_model:
73 | pos_embed_checkpoint = checkpoint_model['pos_embed']
74 | embedding_size = pos_embed_checkpoint.shape[-1]
75 | num_patches = model.patch_embed.num_patches
76 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
77 | # height (== width) for the checkpoint position embedding
78 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
79 | # height (== width) for the new position embedding
80 | new_size = int(num_patches ** 0.5)
81 | # class_token and dist_token are kept unchanged
82 | if orig_size != new_size:
83 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
84 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
85 | # only the position tokens are interpolated
86 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
87 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
88 | pos_tokens = torch.nn.functional.interpolate(
89 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
90 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
91 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
92 | checkpoint_model['pos_embed'] = new_pos_embed
93 |
94 | def clip_gradients(parameters, clip):
95 | norms = []
96 | for p in parameters:
97 | if p.grad is not None:
98 | param_norm = p.grad.data.norm(2)
99 | norms.append(param_norm)
100 | clip_coef = clip / (param_norm + 1e-6)
101 | if clip_coef < 1:
102 | p.grad.data.mul_(clip_coef)
103 | return norms
104 |
105 | class NativeScalerWithGradNormCount:
106 | state_dict_key = "amp_scaler"
107 |
108 | def __init__(self):
109 | self._scaler = torch.cuda.amp.GradScaler()
110 |
111 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
112 | self._scaler.scale(loss).backward(create_graph=create_graph)
113 | if update_grad:
114 | if clip_grad is not None:
115 | assert parameters is not None
116 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
117 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
118 | #norms = clip_gradients(parameters, clip_grad)
119 | #norm = torch.norm(torch.stack(norms), 2)
120 | else:
121 | self._scaler.unscale_(optimizer)
122 | norm = get_grad_norm_(parameters)
123 | self._scaler.step(optimizer)
124 | self._scaler.update()
125 | else:
126 | norm = None
127 | return norm
128 |
129 | def state_dict(self):
130 | return self._scaler.state_dict()
131 |
132 | def load_state_dict(self, state_dict):
133 | self._scaler.load_state_dict(state_dict)
134 |
135 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
136 | if isinstance(parameters, torch.Tensor):
137 | parameters = [parameters]
138 | parameters = [p for p in parameters if p.grad is not None]
139 | norm_type = float(norm_type)
140 | if len(parameters) == 0:
141 | return torch.tensor(0.)
142 | device = parameters[0].grad.device
143 | if norm_type == inf:
144 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
145 | else:
146 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
147 | return total_norm
148 |
149 | def init_distributed_mode(args):
150 | if args.dist_on_itp:
151 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
152 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
153 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
154 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
155 | os.environ['LOCAL_RANK'] = str(args.gpu)
156 | os.environ['RANK'] = str(args.rank)
157 | os.environ['WORLD_SIZE'] = str(args.world_size)
158 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
159 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
160 | args.rank = int(os.environ["RANK"])
161 | args.world_size = int(os.environ['WORLD_SIZE'])
162 | args.gpu = int(os.environ['LOCAL_RANK'])
163 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
164 | elif 'SLURM_PROCID' in os.environ:
165 | args.rank = int(os.environ['SLURM_PROCID'])
166 | args.gpu = args.rank % torch.cuda.device_count()
167 | else:
168 | print('Not using distributed mode')
169 | setup_for_distributed(is_master=True) # hack
170 | args.distributed = False
171 | return
172 |
173 | args.distributed = True
174 |
175 | torch.cuda.set_device(args.gpu)
176 | args.dist_backend = 'nccl'
177 | print('| distributed init (rank {}): {}, gpu {}'.format(
178 | args.rank, args.dist_url, args.gpu), flush=True)
179 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
180 | world_size=args.world_size, rank=args.rank)
181 | torch.distributed.barrier()
182 | setup_for_distributed(args.rank == 0)
183 |
184 | def setup_for_distributed(is_master):
185 | """
186 | This function disables printing when not in master process
187 | """
188 | builtin_print = builtins.print
189 |
190 | def print(*args, **kwargs):
191 | force = kwargs.pop('force', False)
192 | force = force or (get_world_size() > 8)
193 | if is_master or force:
194 | now = datetime.datetime.now().time()
195 | builtin_print('[{}] '.format(now), end='') # print with time stamp
196 | builtin_print(*args, **kwargs)
197 |
198 | builtins.print = print
199 |
200 |
201 | def is_dist_avail_and_initialized():
202 | if not dist.is_available():
203 | return False
204 | if not dist.is_initialized():
205 | return False
206 | return True
207 |
208 |
209 | def get_world_size():
210 | if not is_dist_avail_and_initialized():
211 | return 1
212 | return dist.get_world_size()
213 |
214 |
215 | def get_rank():
216 | if not is_dist_avail_and_initialized():
217 | return 0
218 | return dist.get_rank()
219 |
220 |
221 | def is_main_process():
222 | return get_rank() == 0
223 |
224 |
225 | def save_on_master(*args, **kwargs):
226 | if is_main_process():
227 | torch.save(*args, **kwargs)
228 |
--------------------------------------------------------------------------------
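A quick shape check of the sin-cos embedding above: for a ViT-B/16 at 224x224 the token grid is 14x14 (224/16), each spatial axis receives half of the 768 channels, and `cls_token=True` prepends one all-zero row for the class token.

import numpy as np
from lib.misc import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
assert pos.shape == (14 * 14, 768)

pos_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
assert pos_cls.shape == (1 + 14 * 14, 768)
assert np.all(pos_cls[0] == 0)   # the class token gets a zero embedding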
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import builtins
5 | import math
6 | import os
7 | import random
8 | import shutil
9 | import time
10 | import warnings
11 | import json
12 | import logging
13 | from functools import partial
14 | import numpy as np
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | import torchvision.models as torchvision_models
28 | from torch.utils.tensorboard import SummaryWriter
29 |
30 | import lib.builder
31 | from lib.logger import setup_logger
32 | from lib.augment import create_augmentation
33 | from lib.misc import NativeScalerWithGradNormCount as NativeScaler
34 | from lib.dataload_optim import PersistentDataLoader, SoftwarePipeline
35 | import lib.misc as misc
36 |
37 | from timm.optim import optim_factory, create_optimizer
38 | import vits
39 |
40 | model_names = ['vit_small', 'vit_base', 'vit_large']
41 |
42 | parser = argparse.ArgumentParser(description='ExtreMA Arguments')
43 | parser.add_argument('data', metavar='DIR',
44 | help='path to dataset')
45 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base',
46 | choices=model_names,
47 | help='model architecture: ' +
48 | ' | '.join(model_names) +
49 | ' (default: vit_base)')
50 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
51 |                     help='number of data loading workers per gpu (default: 8)')
52 | parser.add_argument('--epochs', default=300, type=int, metavar='N',
53 | help='number of total epochs to run')
54 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
55 | help='manual epoch number (useful on restarts)')
56 | parser.add_argument('-b', '--batch-size', default=2048, type=int,
57 | metavar='N',
58 | help='mini-batch size (default: 2048), this is the total '
59 | 'batch size of all GPUs on the current node when '
60 | 'using Data Parallel or Distributed Data Parallel')
61 | parser.add_argument('--lr', '--learning-rate', default=1.5e-4, type=float,
62 | metavar='LR', help='initial (base) learning rate', dest='lr')
63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
64 |                     help='lower lr bound for the cosine schedule (default: 0.0)')
65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
66 | help='momentum')
67 | parser.add_argument('--wd', '--weight-decay', default=0.1, type=float,
68 |                     metavar='W', help='weight decay (default: 0.1)',
69 | dest='weight_decay')
70 | parser.add_argument('--weight-decay-end', default=None, type=float,
71 |                     metavar='W', help='final weight decay for the cosine schedule (default: None, i.e. constant)',
72 | dest='weight_decay_end')
73 | parser.add_argument('-p', '--print-freq', default=20, type=int,
74 |                     metavar='N', help='print frequency (default: 20)')
75 | parser.add_argument('--save-freq', default=5, type=int,
76 | metavar='N', help='save frequency (default: 5)')
77 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
78 | help='path to latest checkpoint (default: none)')
79 | parser.add_argument('--world-size', default=-1, type=int,
80 | help='number of nodes for distributed training')
81 | parser.add_argument('--local_rank', default=-1, type=int,
82 |                     help='local rank for distributed training')
83 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
84 | help='url used to set up distributed training')
85 | parser.add_argument('--dist-backend', default='nccl', type=str,
86 | help='distributed backend')
87 | parser.add_argument('--seed', default=None, type=int,
88 | help='seed for initializing training. ')
89 | parser.add_argument('--gpu', default=None, type=int,
90 | help='GPU id to use.')
91 | parser.add_argument('--device', default='cuda',
92 | help='device to use for training / testing')
93 | parser.add_argument('--dist_on_itp', action='store_true')
94 | parser.add_argument('--multiprocessing-distributed', action='store_true',
95 | help='Use multi-processing distributed training to launch '
96 | 'N processes per node, which has N GPUs. This is the '
97 | 'fastest way to use PyTorch for either single node or '
98 | 'multi node data parallel training')
99 | parser.add_argument('--log_dir', default="tf_logs", type=str,
100 | help='dir of logs')
101 | parser.add_argument('--output_dir', default="results", type=str,
102 | help='dir of checkpoints')
103 |
104 | # siamese specific configs:
105 | parser.add_argument('--proj-dim', default=256, type=int,
106 | help='feature dimension (default: 256)')
107 | parser.add_argument('--mlp-dim', default=4096, type=int,
108 | help='hidden dimension in MLPs (default: 4096)')
109 | parser.add_argument('--ema-momentum', default=0.996, type=float,
110 | help='momentum of updating momentum encoder (default: 0.996)')
111 | parser.add_argument('--contrast-temp', default=1.0, type=float,
112 | help='contrastive softmax temperature (default: 1.0)')
113 |
114 | # vit specific configs:
115 | parser.add_argument('--drop_path_rate', type=float, default=0.0, help="stochastic depth rate")
116 | parser.add_argument('--attn_drop_rate', type=float, default=0.0, help="attention dropout rate")
117 | parser.add_argument('--layer_scale_init_value', default=0.1, type=float,
118 | help="0.1 for base, 1e-5 for large. set 0 to disable layer scale")
119 | parser.add_argument('--class_attention_layers', default=2, type=int)
120 |
121 | # other hyper-params
122 | parser.add_argument('--opt', default='adamw', type=str,
123 | choices=['lars', 'adamw'],
124 | help='optimizer used (default: adamw)')
125 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
126 | help='Clip gradient norm (default: None, no clipping)')
127 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
128 | help='Optimizer Epsilon (default: 1e-8)')
129 | parser.add_argument('--opt-betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
130 |                     help='Optimizer Betas (default: 0.9 0.999)')
131 | parser.add_argument('--adjust-weight-decay', action='store_true',
132 | help='cosine weight decay')
133 | parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
134 | help='number of warmup epochs')
135 | parser.add_argument('--crop-min', default=0.2, type=float,
136 | help='minimum scale for random cropping student (default: 0.2)')
137 |
138 | # augmentation options
139 | parser.add_argument('--aug-spatial', action='store_true',
140 | help='use spatial data augmentation')
141 | parser.add_argument('--aug-centercrop', action='store_true',
142 | help='use centercrop data augmentation')
143 | parser.add_argument('--aug-spatialconsistent-color', action='store_true',
144 | help='use spatial consistent with colorjitter data augmentation')
145 | parser.add_argument('--loss', default='byol', type=str,
146 | choices=['infonce', 'byol'],
147 | help='loss function to use')
148 |
149 | # add mask options
150 | parser.add_argument('--mask-ratio', default=0.8, type=float,
151 | help='mask ratio for student augmentation')
152 | parser.add_argument('--num-masks', default=1, type=int)
153 | parser.add_argument('--disjoint', action='store_true',
154 | help='use disjoint sampling of patches')
155 | parser.set_defaults(disjoint=True)
156 |
157 |
158 | def main_worker(args):
159 | misc.init_distributed_mode(args)
160 | global_rank = misc.get_rank()
161 |
162 | os.makedirs(args.log_dir, exist_ok=True)
163 | os.makedirs(args.output_dir, exist_ok=True)
164 | print("args.output_dir", args.output_dir)
165 | print("args.log_dir", args.log_dir)
166 |
167 | if global_rank == 0 and args.log_dir is not None:
168 | with open(args.log_dir + '/config.json', "w") as config_file:
169 | json.dump(vars(args), config_file)
170 | os.makedirs(args.log_dir, exist_ok=True)
171 | summary_writer = SummaryWriter(log_dir=args.log_dir)
172 | else:
173 | summary_writer = None
174 | logger = setup_logger(output=args.log_dir, distributed_rank=global_rank, name="byol")
175 |
176 | device = torch.device(args.device)
177 |
178 | if args.seed is not None:
179 | seed = args.seed + misc.get_rank()
180 | random.seed(seed)
181 | torch.manual_seed(seed)
182 | torch.cuda.manual_seed_all(seed)
183 | np.random.seed(seed)
184 |
185 | cudnn.benchmark = True
186 |
187 | local_batch_size = int(args.batch_size / misc.get_world_size())
188 | augmentation = create_augmentation(args)
189 | logger.info(augmentation)
190 |
191 | # Data loading
192 | traindir = os.path.join(args.data, 'train')
193 | train_dataset = datasets.ImageFolder(
194 | traindir,
195 | transform=augmentation,
196 | )
197 |
198 | if True: #args.distributed:
199 | num_tasks = misc.get_world_size()
200 | train_sampler = torch.utils.data.distributed.DistributedSampler(
201 | train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
202 | )
203 |
204 | train_loader = SoftwarePipeline(PersistentDataLoader(
205 | train_dataset, batch_size=local_batch_size, shuffle=(train_sampler is None),
206 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True))
207 |
208 | # create model
209 | logger.info("=> creating model '{}'".format(args.arch))
210 | base_encoder = partial(vits.__dict__[args.arch], drop_path_rate=args.drop_path_rate, attn_drop_rate=args.attn_drop_rate, init_values=args.layer_scale_init_value, class_attention_layers=args.class_attention_layers)
211 | ema_encoder = partial(vits.__dict__[args.arch], init_values=args.layer_scale_init_value, class_attention_layers=args.class_attention_layers)
212 | model = lib.builder.ExtreMA(
213 | base_encoder, ema_encoder,
214 | args.proj_dim, args.mlp_dim, args.contrast_temp, args.mask_ratio, args.num_masks, args.disjoint)
215 | model.to(device)
216 |
217 |     # infer learning rate before changing batch size
218 | args.lr = args.lr * args.batch_size / 256
219 |
220 | if True:
221 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
222 |
223 | logger.info(model)
224 |
225 | param_groups = optim_factory.add_weight_decay(model.module, args.weight_decay, model.module.no_weight_decay())
226 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=args.opt_betas)
227 |
228 | logger.info(optimizer)
229 | scaler = NativeScaler()
230 |
231 | # auto resume from a checkpoint
232 | args.resume = os.path.join(args.output_dir, 'current.pth.tar')
233 | if args.resume:
234 | if os.path.isfile(args.resume):
235 | logger.info("=> loading checkpoint '{}'".format(args.resume))
236 | if args.gpu is None:
237 | checkpoint = torch.load(args.resume)
238 | else:
239 | # Map model to be loaded to specified single gpu.
240 | loc = 'cuda:{}'.format(args.gpu)
241 | checkpoint = torch.load(args.resume, map_location=loc)
242 | args.start_epoch = checkpoint['epoch']
243 | model.load_state_dict(checkpoint['state_dict'])
244 | optimizer.load_state_dict(checkpoint['optimizer'])
245 | scaler.load_state_dict(checkpoint['scaler'])
246 | logger.info("=> loaded checkpoint '{}' (epoch {})"
247 | .format(args.resume, checkpoint['epoch']))
248 | del checkpoint
249 | torch.cuda.empty_cache()
250 | else:
251 | logger.info("=> no checkpoint found at '{}'".format(args.resume))
252 |
253 |
254 | for epoch in range(args.start_epoch, args.epochs):
255 | if args.distributed:
256 | train_sampler.set_epoch(epoch)
257 |
258 | # train for one epoch
259 | train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
260 |
261 |         # only rank 0 saves checkpoints, every save_freq epochs
262 |         if misc.get_rank() == 0 and (epoch + 1) % args.save_freq == 0:
263 | save_checkpoint({
264 | 'epoch': epoch + 1,
265 | 'arch': args.arch,
266 | 'state_dict': model.state_dict(),
267 | 'optimizer' : optimizer.state_dict(),
268 | 'scaler': scaler.state_dict(),
269 | }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.output_dir, epoch))
270 | shutil.copyfile('{}/checkpoint_{:04d}.pth.tar'.format(args.output_dir, epoch), '{}/current.pth.tar'.format(args.output_dir))
271 |
272 | if misc.get_rank() == 0:
273 | summary_writer.close()
274 |
275 | def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
276 | batch_time = AverageMeter('Time', ':6.3f')
277 | data_time = AverageMeter('Data', ':6.3f')
278 | learning_rates = AverageMeter('LR', ':.4e')
279 | loss_scales = AverageMeter('LossScale', ':.4e')
280 | weight_decays = AverageMeter('WeightDecay', ':.4e')
281 | grad_norms = AverageMeter('GradNorm', ':.4e')
282 | losses = AverageMeter('Loss', ':.4e')
283 |
284 | progress = ProgressMeter(
285 | len(train_loader),
286 | [batch_time, data_time, losses, grad_norms, loss_scales, weight_decays, learning_rates],
287 | prefix="Epoch: [{}]".format(epoch))
288 | logger = logging.getLogger('byol')
289 |
290 | # switch to train mode
291 | model.train()
292 |
293 | end = time.time()
294 | iters_per_epoch = len(train_loader)
295 | ema_momentum = args.ema_momentum
296 | for i, (images, labels) in enumerate(train_loader):
297 | # measure data loading time
298 | data_time.update(time.time() - end)
299 |
300 | # adjust learning rate and momentum coefficient per iteration
301 | lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args)
302 | ema_momentum = adjust_ema_momentum(epoch + i / iters_per_epoch, args)
303 | if args.adjust_weight_decay:
304 | wd = adjust_decay_rate(optimizer, epoch + i / iters_per_epoch, args)
305 |
306 | if args.gpu is not None:
307 | if isinstance(images, list):
308 | images[0] = images[0].cuda(args.gpu, non_blocking=True)
309 | images[1] = images[1].cuda(args.gpu, non_blocking=True)
310 | bsz = images[0].size(0)
311 | else:
312 | images = images.cuda(args.gpu, non_blocking=True)
313 | bsz = images.size(0)
314 |
315 | # compute output
316 | with torch.cuda.amp.autocast(True):
317 | loss = model(images, ema_momentum, args.loss)
318 | losses.update(loss.item(), bsz)
319 |
320 | # compute gradient and do SGD step
321 | optimizer.zero_grad()
322 |
323 | norm = scaler(loss, optimizer, parameters=model.parameters(), clip_grad=args.clip_grad)
324 |
325 | # measure elapsed time
326 | batch_time.update(time.time() - end)
327 | end = time.time()
328 |
329 | if i % args.print_freq == 0:
330 | grad_norms.update(norm)
331 | loss_scales.update(scaler.state_dict()["scale"])
332 | learning_rates.update(optimizer.param_groups[1]['lr'])
333 | weight_decays.update(optimizer.param_groups[1]['weight_decay'])
334 | progress.display(logger,i)
335 |
336 | if misc.get_rank() == 0:
337 | summary_writer.add_scalar("losses", losses.avg, epoch )
338 | summary_writer.add_scalar("opt/grad_norm", grad_norms.avg, epoch )
339 | summary_writer.add_scalar("opt/loss_scale", loss_scales.avg, epoch )
340 | summary_writer.add_scalar("opt/lr", learning_rates.avg, epoch )
341 | summary_writer.add_scalar("opt/wd", weight_decays.avg, epoch )
342 |
343 |
344 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
345 | torch.save(state, filename)
346 | if is_best:
347 | shutil.copyfile(filename, 'model_best.pth.tar')
348 |
349 | class AverageMeter(object):
350 | """Computes and stores the average and current value"""
351 | def __init__(self, name, fmt=':f'):
352 | self.name = name
353 | self.fmt = fmt
354 | self.reset()
355 |
356 | def reset(self):
357 | self.val = 0
358 | self.avg = 0
359 | self.sum = 0
360 | self.count = 0
361 |
362 | def update(self, val, n=1):
363 | self.val = val
364 | self.sum += val * n
365 | self.count += n
366 | self.avg = self.sum / self.count
367 |
368 | def __str__(self):
369 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
370 | return fmtstr.format(**self.__dict__)
371 |
372 | class ProgressMeter(object):
373 | def __init__(self, num_batches, meters, prefix=""):
374 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
375 | self.meters = meters
376 | self.prefix = prefix
377 |
378 | def display(self, logger, batch):
379 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
380 | entries += [str(meter) for meter in self.meters]
381 | logger.info('\t'.join(entries))
382 |
383 | def _get_batch_fmtstr(self, num_batches):
384 | num_digits = len(str(num_batches // 1))
385 | fmt = '{:' + str(num_digits) + 'd}'
386 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
387 |
388 | def adjust_learning_rate(optimizer, epoch, args):
389 | """Decays the learning rate with half-cycle cosine after warmup"""
390 | if epoch < args.warmup_epochs:
391 | lr = args.lr * epoch / args.warmup_epochs
392 | else:
393 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
394 | for param_group in optimizer.param_groups:
395 | param_group['lr'] = lr
396 | return lr
397 |
398 | def adjust_decay_rate(optimizer, epoch, args):
399 |     """Adjusts the weight decay with a half-cycle cosine schedule"""
400 | if args.weight_decay_end is None:
401 | args.weight_decay_end = args.weight_decay
402 | wd = args.weight_decay_end + 0.5 * (args.weight_decay - args.weight_decay_end) * (1. + math.cos(math.pi * (epoch / args.epochs)))
403 | for param_group in optimizer.param_groups:
404 | if param_group['weight_decay'] > 0:
405 | param_group['weight_decay'] = wd
406 | return wd
407 |
408 | def adjust_ema_momentum(epoch, args):
409 |     """Anneals the EMA momentum toward 1 with a half-cycle cosine"""
410 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.ema_momentum)
411 | return m
412 |
413 | def accuracy(output, target, topk=(1,)):
414 | """Computes the accuracy over the k top predictions for the specified values of k"""
415 | with torch.no_grad():
416 | maxk = max(topk)
417 | batch_size = target.size(0)
418 |
419 | _, pred = output.topk(maxk, 1, True, True)
420 | pred = pred.t()
421 | correct = pred.eq(target.view(1, -1).expand_as(pred))
422 |
423 | res = []
424 | for k in topk:
425 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
426 | res.append(correct_k.mul_(100.0 / batch_size))
427 | return res
428 |
429 | if __name__ == '__main__':
430 | args = parser.parse_args()
431 | main_worker(args)
432 |
--------------------------------------------------------------------------------
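The three `adjust_*` helpers above share one half-cycle cosine: the learning rate warms up linearly for `warmup_epochs`, then decays from `lr` to `min_lr`, while the EMA momentum is annealed from `--ema-momentum` up to 1 over training. A quick numeric sanity check of that arithmetic, with illustrative values (the scaled lr below is made up, not a repo default):

import math

lr, min_lr, warmup, epochs = 1.5e-3, 0.0, 10, 300

def cosine_lr(e):
    if e < warmup:
        return lr * e / warmup
    return min_lr + (lr - min_lr) * 0.5 * (1. + math.cos(math.pi * (e - warmup) / (epochs - warmup)))

assert cosine_lr(0) == 0.0                        # start of linear warmup
assert abs(cosine_lr(warmup) - lr) < 1e-12        # peak right after warmup
assert abs(cosine_lr(155) - lr / 2) < 1e-9        # midpoint of the cosine decay
assert abs(cosine_lr(epochs) - min_lr) < 1e-12    # floor at the end of training

def ema_momentum(e, m0=0.996):
    return 1. - 0.5 * (1. + math.cos(math.pi * e / epochs)) * (1. - m0)

assert abs(ema_momentum(0) - 0.996) < 1e-12       # starts at --ema-momentum
assert abs(ema_momentum(epochs) - 1.0) < 1e-12    # momentum encoder fully frozen at the end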
/main_lincls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import builtins
5 | import math
6 | import os
7 | import random
8 | import shutil
9 | import time
10 | import warnings
11 | import json
12 | import numpy as np  # used for per-rank seeding in main_worker
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.parallel
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.optim
19 | import torch.multiprocessing as mp
20 | import torch.utils.data
21 | import torch.utils.data.distributed
22 | import torchvision.transforms as transforms
23 | import torchvision.datasets as datasets
24 | import torchvision.models as torchvision_models
25 |
26 | import vits
27 | import lib.misc as misc
28 | import lib.builder
29 | from lib.dataload_optim import PersistentDataLoader
30 |
31 | model_names = ['vit_small', 'vit_base', 'vit_large']
32 |
33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
34 | parser.add_argument('data', metavar='DIR',
35 | help='path to dataset')
36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='vit_base',
37 | choices=model_names,
38 | help='model architecture: ' +
39 | ' | '.join(model_names) +
40 | ' (default: vit_base)')
41 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
42 |                     help='number of data loading workers (default: 4)')
43 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
44 | help='number of total epochs to run')
45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
46 | help='manual epoch number (useful on restarts)')
47 | parser.add_argument('-b', '--batch-size', default=1024, type=int,
48 | metavar='N',
49 | help='mini-batch size (default: 1024), this is the total '
50 | 'batch size of all GPUs on the current node when '
51 | 'using Data Parallel or Distributed Data Parallel')
52 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
53 | metavar='LR', help='initial (base) learning rate', dest='lr')
54 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
55 | help='momentum')
56 | parser.add_argument('--wd', '--weight-decay', default=0., type=float,
57 | metavar='W', help='weight decay (default: 0.)',
58 | dest='weight_decay')
59 | parser.add_argument('--optimizer', default='sgd', type=str,
60 | choices=['lars', 'sgd'],
61 | help='optimizer used (default: sgd)')
62 | parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
63 | help='number of warmup epochs')
64 | parser.add_argument('-p', '--print-freq', default=10, type=int,
65 | metavar='N', help='print frequency (default: 10)')
66 | parser.add_argument('--save-freq', default=5, type=int,
67 | metavar='N', help='save frequency (default: 5)')
68 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
69 | help='path to latest checkpoint (default: none)')
70 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
71 | help='evaluate model on validation set')
72 | parser.add_argument('--world-size', default=-1, type=int,
73 | help='number of nodes for distributed training')
74 | parser.add_argument('--rank', default=-1, type=int,
75 | help='node rank for distributed training')
76 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
77 | help='url used to set up distributed training')
78 | parser.add_argument('--dist-backend', default='nccl', type=str,
79 | help='distributed backend')
80 | parser.add_argument('--seed', default=None, type=int,
81 | help='seed for initializing training. ')
82 | parser.add_argument('--gpu', default=None, type=int,
83 | help='GPU id to use.')
84 | parser.add_argument('--multiprocessing-distributed', action='store_true',
85 | help='Use multi-processing distributed training to launch '
86 | 'N processes per node, which has N GPUs. This is the '
87 | 'fastest way to use PyTorch for either single node or '
88 | 'multi node data parallel training')
89 | parser.add_argument('--dist_on_itp', action='store_true')
90 | parser.add_argument('--local_rank', default=-1, type=int,
91 |                     help='local rank for distributed training')
92 |
93 | # additional configs:
94 | parser.add_argument('--pretrained', default='', type=str,
95 | help='path to moco pretrained checkpoint')
96 | parser.add_argument('--log_dir', default=None,
97 | help='path where to tensorboard log')
98 | parser.add_argument('--eval_momentum', action='store_true',
99 | help='evaluation momentum encoder')
100 |
101 | best_acc1 = 0
102 |
103 |
104 | def main():  # legacy mp.spawn entry point; unused (see __main__ below, which calls main_worker directly)
105 | args = parser.parse_args()
106 |
107 | if args.seed is not None:
108 | random.seed(args.seed)
109 | torch.manual_seed(args.seed)
110 | cudnn.deterministic = True
111 | warnings.warn('You have chosen to seed training. '
112 | 'This will turn on the CUDNN deterministic setting, '
113 | 'which can slow down your training considerably! '
114 | 'You may see unexpected behavior when restarting '
115 | 'from checkpoints.')
116 |
117 | if args.gpu is not None:
118 | warnings.warn('You have chosen a specific GPU. This will completely '
119 | 'disable data parallelism.')
120 |
121 | if args.dist_url == "env://" and args.world_size == -1:
122 | args.world_size = int(os.environ["WORLD_SIZE"])
123 |
124 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
125 |
126 | ngpus_per_node = torch.cuda.device_count()
127 | if args.multiprocessing_distributed:
128 | # Since we have ngpus_per_node processes per node, the total world_size
129 | # needs to be adjusted accordingly
130 | args.world_size = ngpus_per_node * args.world_size
131 | # Use torch.multiprocessing.spawn to launch distributed processes: the
132 | # main_worker process function
133 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
134 | else:
135 | # Simply call main_worker function
136 | main_worker(args.gpu, ngpus_per_node, args)
137 |
138 |
139 | def main_worker(args):
140 | if 'LOCAL_RANK' not in os.environ:
141 |         os.environ['LOCAL_RANK'] = str(args.local_rank)  # environment values must be strings
142 |
143 | misc.init_distributed_mode(args)
144 | global_rank = misc.get_rank()
145 |
146 | if args.seed is not None:
147 | seed = args.seed + misc.get_rank()
148 | random.seed(seed)
149 | torch.manual_seed(seed)
150 | torch.cuda.manual_seed_all(seed)
151 | np.random.seed(seed)
152 |
153 | cudnn.benchmark = True
154 |
155 | global best_acc1
156 | #args.gpu = gpu
157 |
158 | # suppress printing if not master
159 | if args.multiprocessing_distributed and global_rank != 0:
160 | def print_pass(*args):
161 | pass
162 | builtins.print = print_pass
163 |
164 | if args.gpu is not None:
165 | print("Use GPU: {} for training".format(args.gpu))
166 |
167 | print("=> creating model '{}'".format(args.arch))
168 | if args.arch.startswith('vit'):
169 | model = vits.__dict__[args.arch](init_values=0.1)
170 | linear_keyword = 'head'
171 | else:
172 | model = torchvision_models.__dict__[args.arch]()
173 | linear_keyword = 'fc'
174 |
175 | # freeze all layers but the last fc
176 | for name, param in model.named_parameters():
177 | if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
178 | param.requires_grad = False
179 | model.head = nn.Sequential(nn.SyncBatchNorm(model.head.in_features, affine=False), model.head)
180 | # init the fc layer
181 | model.head[1].weight.data.normal_(mean=0.0, std=0.01)
182 | model.head[1].bias.data.zero_()
183 |
184 | # load from pre-trained, before DistributedDataParallel constructor
185 | if args.pretrained:
186 | if os.path.isfile(args.pretrained):
187 | print("=> loading checkpoint '{}'".format(args.pretrained))
188 | checkpoint = torch.load(args.pretrained, map_location="cpu")
189 |
190 | # rename moco pre-trained keys
191 | state_dict = checkpoint['state_dict']
192 | if args.eval_momentum:
193 | prefix_key = 'module.momentum_encoder'
194 | else:
195 | prefix_key = 'module.base_encoder'
196 | for k in list(state_dict.keys()):
197 | # retain only base_encoder up to before the embedding layer
198 | if k.startswith(prefix_key) and not k.startswith('{}.{}'.format(prefix_key,linear_keyword)):
199 | # remove prefix
200 | state_dict[k[len(prefix_key)+1:]] = state_dict[k]
201 | # delete renamed or unused k
202 | del state_dict[k]
203 |
204 | args.start_epoch = 0
205 | msg = model.load_state_dict(state_dict, strict=False)
206 | #assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}
207 | print(msg)
208 |
209 | print("=> loaded pre-trained model '{}'".format(args.pretrained))
210 | else:
211 | print("=> no checkpoint found at '{}'".format(args.pretrained))
212 |
213 | model.norm = nn.Identity()
214 |
215 | # infer learning rate before changing batch size
216 | args.lr = args.lr * args.batch_size / 256
217 |
218 | if True:
219 | # For multiprocessing distributed, DistributedDataParallel constructor
220 | # should always set the single device scope, otherwise,
221 | # DistributedDataParallel will use all available devices.
222 | torch.cuda.set_device(args.gpu)
223 | model.cuda(args.gpu)
224 | # When using a single GPU per process and per
225 | # DistributedDataParallel, we need to divide the batch size
226 | # ourselves based on the total number of GPUs we have
227 | args.batch_size = int(args.batch_size / misc.get_world_size())
228 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
229 |
230 | # define loss function (criterion) and optimizer
231 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
232 |
233 | # optimize only the linear classifier
234 | parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
235 | assert len(parameters) == 2 # weight, bias
236 |
237 | if global_rank == 0 and args.log_dir is not None:
238 | os.makedirs(args.log_dir, exist_ok=True)
239 |
240 |     if args.optimizer == 'lars':
241 |         # NOTE: core torch ships no LARS optimizer; this branch needs an
242 |         # external implementation (e.g. MoCo v3's lars.py) patched in
243 |         optimizer = torch.optim.LARS(parameters, args.lr, weight_decay=0, momentum=args.momentum)
244 | else:
245 | optimizer = torch.optim.SGD(parameters, args.lr,
246 | momentum=args.momentum,
247 | weight_decay=0)
248 | print(optimizer)
249 |
250 | # optionally resume from a checkpoint
251 | if args.resume:
252 | if os.path.isfile(args.resume):
253 | print("=> loading checkpoint '{}'".format(args.resume))
254 | if args.gpu is None:
255 | checkpoint = torch.load(args.resume)
256 | else:
257 | # Map model to be loaded to specified single gpu.
258 | loc = 'cuda:{}'.format(args.gpu)
259 | checkpoint = torch.load(args.resume, map_location=loc)
260 | args.start_epoch = checkpoint['epoch']
261 | best_acc1 = checkpoint['best_acc1']
262 | if args.gpu is not None:
263 | # best_acc1 may be from a checkpoint from a different GPU
264 | best_acc1 = best_acc1.to(args.gpu)
265 | model.load_state_dict(checkpoint['state_dict'])
266 | optimizer.load_state_dict(checkpoint['optimizer'])
267 | print("=> loaded checkpoint '{}' (epoch {})"
268 | .format(args.resume, checkpoint['epoch']))
269 | else:
270 | print("=> no checkpoint found at '{}'".format(args.resume))
271 |
272 | print(model)
273 | cudnn.benchmark = True
274 |
275 | # Data loading code
276 | traindir = os.path.join(args.data, 'train')
277 | valdir = os.path.join(args.data, 'val')
278 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
279 | std=[0.229, 0.224, 0.225])
280 |
281 | train_dataset = datasets.ImageFolder(
282 | traindir,
283 | transforms.Compose([
284 | transforms.RandomResizedCrop(224),
285 | transforms.RandomHorizontalFlip(),
286 | transforms.ToTensor(),
287 | normalize,
288 | ]))
289 | val_dataset = datasets.ImageFolder(
290 | valdir,
291 | transforms.Compose([
292 | transforms.Resize(256),
293 | transforms.CenterCrop(224),
294 | transforms.ToTensor(),
295 | normalize
296 | ]))
297 |
298 | if args.distributed:
299 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
300 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
301 |     else:
302 |         train_sampler = val_sampler = None
303 |
304 | train_loader = PersistentDataLoader(
305 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
306 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
307 |
308 | val_loader = PersistentDataLoader(
309 | val_dataset,
310 | batch_size=args.batch_size, shuffle=False,
311 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
312 |
313 | if args.evaluate:
314 | validate(val_loader, model, criterion, args)
315 | return
316 |
317 | for epoch in range(args.start_epoch, args.epochs):
318 | if args.distributed:
319 | train_sampler.set_epoch(epoch)
320 | # train for one epoch
321 | train(train_loader, model, criterion, optimizer, epoch, args)
322 |
323 | # evaluate on validation set
324 | acc1 = validate(val_loader, model, criterion, args)
325 | # remember best acc@1 and save checkpoint
326 | is_best = acc1 > best_acc1
327 | best_acc1 = max(acc1, best_acc1)
328 |
329 | log_stats = {'epoch': epoch, 'acc': acc1.item()}
330 |         # only rank 0 appends to the log file
331 |         if global_rank == 0:
332 | with open(os.path.join(args.log_dir, "log.txt"), mode="a", encoding="utf-8") as f:
333 | f.write(json.dumps(log_stats) + "\n")
334 |
335 |         # only rank 0 saves checkpoints, every save_freq epochs
336 |         if global_rank == 0 and (epoch + 1) % args.save_freq == 0:
337 | save_checkpoint({
338 | 'epoch': epoch + 1,
339 | 'arch': args.arch,
340 | 'state_dict': model.state_dict(),
341 | 'best_acc1': best_acc1,
342 | 'optimizer' : optimizer.state_dict(),
343 | }, is_best)
344 | if epoch == args.start_epoch:
345 | sanity_check(model.state_dict(), args.pretrained, linear_keyword)
346 |
347 |
348 | def train(train_loader, model, criterion, optimizer, epoch, args):
349 | batch_time = AverageMeter('Time', ':6.3f')
350 | data_time = AverageMeter('Data', ':6.3f')
351 | losses = AverageMeter('Loss', ':.4e')
352 | top1 = AverageMeter('Acc@1', ':6.2f')
353 | top5 = AverageMeter('Acc@5', ':6.2f')
354 | progress = ProgressMeter(
355 | len(train_loader),
356 | [batch_time, data_time, losses, top1, top5],
357 | prefix="Epoch: [{}]".format(epoch))
358 |
359 | """
360 | Switch to eval mode:
361 | Under the protocol of linear classification on frozen features/models,
362 | it is not legitimate to change any part of the pre-trained model.
363 | BatchNorm in train mode may revise running mean/std (even if it receives
364 | no gradient), which are part of the model parameters too.
365 | """
366 | model.eval()
367 | model.module.head.train()
368 |
369 | end = time.time()
370 | for i, (images, target) in enumerate(train_loader):
371 | # measure data loading time
372 | data_time.update(time.time() - end)
373 | lr = adjust_learning_rate(optimizer, epoch + i / len(train_loader), args)
374 | if args.gpu is not None:
375 | images = images.cuda(args.gpu, non_blocking=True)
376 | if torch.cuda.is_available():
377 | target = target.cuda(args.gpu, non_blocking=True)
378 |
379 | # compute output
380 | with torch.cuda.amp.autocast(True):
381 | output = model(images)
382 | loss = criterion(output, target)
383 |
384 | # measure accuracy and record loss
385 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
386 | losses.update(loss.item(), images.size(0))
387 | top1.update(acc1[0], images.size(0))
388 | top5.update(acc5[0], images.size(0))
389 |
390 | # compute gradient and do SGD step
391 | optimizer.zero_grad()
392 | loss.backward()
393 | optimizer.step()
394 |
395 | # measure elapsed time
396 | batch_time.update(time.time() - end)
397 | end = time.time()
398 |
399 | if i % args.print_freq == 0:
400 | progress.display(i)
401 |
402 |
403 | def validate(val_loader, model, criterion, args):
404 | batch_time = AverageMeter('Time', ':6.3f')
405 | losses = AverageMeter('Loss', ':.4e')
406 | top1 = AverageMeter('Acc@1', ':6.2f')
407 | top5 = AverageMeter('Acc@5', ':6.2f')
408 | progress = ProgressMeter(
409 | len(val_loader),
410 | [batch_time, losses, top1, top5],
411 | prefix='Test: ')
412 |
413 | # switch to evaluate mode
414 | model.eval()
415 |
416 | with torch.no_grad():
417 | end = time.time()
418 | for i, (images, target) in enumerate(val_loader):
419 | if args.gpu is not None:
420 | images = images.cuda(args.gpu, non_blocking=True)
421 | if torch.cuda.is_available():
422 | target = target.cuda(args.gpu, non_blocking=True)
423 |
424 | # compute output
425 | with torch.cuda.amp.autocast(True):
426 | output = model(images)
427 | loss = criterion(output, target)
428 |
429 | # measure accuracy and record loss
430 |             output = lib.builder.concat_all_gather(output.to("cuda"))
431 |             target = lib.builder.concat_all_gather(target.to("cuda"))
432 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
433 |
434 | losses.update(loss.item(), images.size(0))
435 | top1.update(acc1[0], output.size(0))
436 | top5.update(acc5[0], target.size(0))
437 |
438 | # measure elapsed time
439 | batch_time.update(time.time() - end)
440 | end = time.time()
441 |
442 | if i % args.print_freq == 0:
443 | progress.display(i)
444 |
445 | # TODO: this should also be done with the ProgressMeter
446 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
447 | .format(top1=top1, top5=top5))
448 |
449 | return top1.avg
450 |
451 |
452 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
453 | torch.save(state, filename)
454 | if is_best:
455 | shutil.copyfile(filename, 'model_best.pth.tar')
456 |
457 |
458 | def sanity_check(state_dict, pretrained_weights, linear_keyword):
459 | """
460 | Linear classifier should not change any weights other than the linear layer.
461 | This sanity check asserts nothing wrong happens (e.g., BN stats updated).
462 | """
463 | print("=> loading '{}' for sanity check".format(pretrained_weights))
464 | checkpoint = torch.load(pretrained_weights, map_location="cpu")
465 | state_dict_pre = checkpoint['state_dict']
466 |
467 | for k in list(state_dict.keys()):
468 | # only ignore linear layer
469 | if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k:
470 | continue
471 |
472 | # name in pretrained model
473 | k_pre = 'module.base_encoder.' + k[len('module.'):] \
474 | if k.startswith('module.') else 'module.base_encoder.' + k
475 |
476 | assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
477 | '{} is changed in linear classifier training.'.format(k)
478 |
479 | print("=> sanity check passed.")
480 |
481 |
482 | class AverageMeter(object):
483 | """Computes and stores the average and current value"""
484 | def __init__(self, name, fmt=':f'):
485 | self.name = name
486 | self.fmt = fmt
487 | self.reset()
488 |
489 | def reset(self):
490 | self.val = 0
491 | self.avg = 0
492 | self.sum = 0
493 | self.count = 0
494 |
495 | def update(self, val, n=1):
496 | self.val = val
497 | self.sum += val * n
498 | self.count += n
499 | self.avg = self.sum / self.count
500 |
501 | def __str__(self):
502 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
503 | return fmtstr.format(**self.__dict__)
504 |
505 |
506 | class ProgressMeter(object):
507 | def __init__(self, num_batches, meters, prefix=""):
508 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
509 | self.meters = meters
510 | self.prefix = prefix
511 |
512 | def display(self, batch):
513 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
514 | entries += [str(meter) for meter in self.meters]
515 | print('\t'.join(entries))
516 |
517 | def _get_batch_fmtstr(self, num_batches):
518 | num_digits = len(str(num_batches // 1))
519 | fmt = '{:' + str(num_digits) + 'd}'
520 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
521 |
522 | def adjust_learning_rate(optimizer, epoch, args):
523 | """Decays the learning rate with half-cycle cosine after warmup"""
524 | if epoch < args.warmup_epochs:
525 | lr = args.lr * epoch / args.warmup_epochs
526 | else:
527 | lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
528 | for param_group in optimizer.param_groups:
529 | param_group['lr'] = lr
530 | return lr
531 |
532 | def accuracy(output, target, topk=(1,)):
533 | """Computes the accuracy over the k top predictions for the specified values of k"""
534 | with torch.no_grad():
535 | maxk = max(topk)
536 | batch_size = target.size(0)
537 |
538 | _, pred = output.topk(maxk, 1, True, True)
539 | pred = pred.t()
540 | correct = pred.eq(target.view(1, -1).expand_as(pred))
541 |
542 | res = []
543 | for k in topk:
544 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
545 | res.append(correct_k.mul_(100.0 / batch_size))
546 | return res
547 |
548 |
549 | if __name__ == '__main__':
550 | args = parser.parse_args()
551 | main_worker(args)
552 | #main()
553 |
--------------------------------------------------------------------------------
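The pretrained-checkpoint loading in `main_worker` above keeps only the chosen encoder branch and strips its `module.<branch>.` prefix so the keys match the bare ViT; the linear head and the other branch are dropped. A self-contained sketch of that renaming, with made-up keys standing in for real checkpoint entries:

state_dict = {
    'module.base_encoder.patch_embed.proj.weight': 'enc-w',
    'module.base_encoder.head.weight': 'head-w',      # linear head: dropped, re-initialized later
    'module.momentum_encoder.cls_token': 'ema-w',     # other branch: dropped
}
prefix_key = 'module.base_encoder'
for k in list(state_dict.keys()):
    if k.startswith(prefix_key) and not k.startswith(prefix_key + '.head'):
        state_dict[k[len(prefix_key) + 1:]] = state_dict[k]   # strip the prefix
    del state_dict[k]                                         # drop the old key

assert state_dict == {'patch_embed.proj.weight': 'enc-w'}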
/vits.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from timm.models.vision_transformer import _cfg
8 | from lib.misc import get_2d_sincos_pos_embed
9 |
10 | class PatchEmbed(nn.Module):
11 | """ Image to Patch Embedding
12 | """
13 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
14 | super().__init__()
15 | num_patches = (img_size // patch_size) * (img_size // patch_size)
16 | self.img_size = img_size
17 | self.patch_size = patch_size
18 | self.num_patches = num_patches
19 |
20 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
21 |
22 | def forward(self, x):
23 | B, C, H, W = x.shape
24 | x = self.proj(x).flatten(2).transpose(1, 2)
25 | return x
26 |
27 | def drop_path(x, drop_prob: float = 0., training: bool = False):
28 | if drop_prob == 0. or not training:
29 | return x
30 | keep_prob = 1 - drop_prob
31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33 | random_tensor.floor_() # binarize
34 | output = x.div(keep_prob) * random_tensor
35 | return output
36 |
37 | class DropPath(nn.Module):
38 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
39 | """
40 | def __init__(self, drop_prob=None):
41 | super(DropPath, self).__init__()
42 | self.drop_prob = drop_prob
43 |
44 | def forward(self, x):
45 | return drop_path(x, self.drop_prob, self.training)
46 |
47 |
48 | class Mlp(nn.Module):
49 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50 | super().__init__()
51 | out_features = out_features or in_features
52 | hidden_features = hidden_features or in_features
53 | self.fc1 = nn.Linear(in_features, hidden_features)
54 | self.act = act_layer()
55 | self.fc2 = nn.Linear(hidden_features, out_features)
56 | self.drop = nn.Dropout(drop)
57 |
58 | def forward(self, x):
59 | x = self.fc1(x)
60 | x = self.act(x)
61 | x = self.drop(x)
62 | x = self.fc2(x)
63 | x = self.drop(x)
64 | return x
65 |
66 | class Attention(nn.Module):
67 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
68 | super().__init__()
69 | self.num_heads = num_heads
70 | head_dim = dim // num_heads
71 | self.scale = qk_scale or head_dim ** -0.5
72 |
73 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
74 | self.attn_drop = nn.Dropout(attn_drop)
75 | self.proj = nn.Linear(dim, dim)
76 | self.proj_drop = nn.Dropout(proj_drop)
77 |
78 | def forward(self, x):
79 | B, N, C = x.shape
80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
81 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
82 |
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 |
87 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
88 | x = self.proj(x)
89 | x = self.proj_drop(x)
90 | return x
91 |
92 | class Class_Attention(nn.Module):
93 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
94 | # with slight modifications to do CA
95 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
96 | super().__init__()
97 | self.num_heads = num_heads
98 | head_dim = dim // num_heads
99 | self.scale = qk_scale or head_dim ** -0.5
100 |
101 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
102 | self.k = nn.Linear(dim, dim, bias=qkv_bias)
103 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
104 | self.attn_drop = nn.Dropout(attn_drop)
105 | self.proj = nn.Linear(dim, dim)
106 | self.proj_drop = nn.Dropout(proj_drop)
107 |
108 |
109 |     def forward(self, x):
110 |
111 | B, N, C = x.shape
112 | q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
113 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
114 |
115 | q = q * self.scale
116 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
117 |
118 | attn = (q @ k.transpose(-2, -1))
119 | attn = attn.softmax(dim=-1)
120 | attn = self.attn_drop(attn)
121 |
122 | x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
123 | x_cls = self.proj(x_cls)
124 | x_cls = self.proj_drop(x_cls)
125 |
126 | return x_cls
127 |
128 | class Block_CA(nn.Module):
129 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
130 | # with slight modifications to add CA and LayerScale
131 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
132 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block = Class_Attention,
133 | Mlp_block=Mlp,init_values=1e-4):
134 | super().__init__()
135 | self.norm1 = norm_layer(dim)
136 | self.attn = Attention_block(
137 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
138 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
139 | self.norm2 = norm_layer(dim)
140 | mlp_hidden_dim = int(dim * mlp_ratio)
141 | self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
142 |
143 | if init_values is not None:
144 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
145 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
146 | else:
147 | self.gamma_1, self.gamma_2 = None, None
148 |
149 |
150 | def forward(self, x, x_cls):
151 |
152 |         u = torch.cat((x_cls, x), dim=1)  # prepend the class token so it can attend over all patch tokens
153 |
154 | if self.gamma_1 is None:
155 | x_cls = x_cls + self.drop_path(self.attn(self.norm1(u)))
156 | x_cls = x_cls + self.drop_path(self.mlp(self.norm2(x_cls)))
157 | else:
158 | x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
159 | x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
160 |
161 | return x_cls
162 |
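# Note on gamma_1/gamma_2 above: this is LayerScale from CaiT
# (https://arxiv.org/abs/2103.17239). Each residual branch is multiplied by a
# learnable per-channel vector initialized to a small constant (1e-4 by
# default), which helps optimize deeper transformers; init_values=None falls
# back to the plain pre-norm residual update.
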
163 | class Block(nn.Module):
164 |
165 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
166 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
167 |                  window_size=None, attn_head_dim=None):  # window_size and attn_head_dim are accepted but unused
168 | super().__init__()
169 | self.norm1 = norm_layer(dim)
170 | self.attn = Attention(
171 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
172 | attn_drop=attn_drop, proj_drop=drop)
173 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
174 | self.norm2 = norm_layer(dim)
175 | mlp_hidden_dim = int(dim * mlp_ratio)
176 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
177 |
178 | if init_values is not None:
179 |             self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
180 |             self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
181 | else:
182 | self.gamma_1, self.gamma_2 = None, None
183 |
184 | def forward(self, x):
185 | if self.gamma_1 is None:
186 | x = x + self.drop_path(self.attn(self.norm1(x)))
187 | x = x + self.drop_path(self.mlp(self.norm2(x)))
188 | else:
189 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
190 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
191 | return x
192 |
193 | class VisionTransformer(nn.Module):
194 | """ Vision Transformer
195 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
196 | - https://arxiv.org/abs/2010.11929
197 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
198 | - https://arxiv.org/abs/2012.12877
199 | """
200 |
201 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
202 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
203 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, init_values=None,
204 | act_layer=None, weight_init='', class_attention_layers=2):
205 |         """
206 |         Args:
207 |             img_size (int, tuple): input image size
208 |             patch_size (int, tuple): patch size
209 |             in_chans (int): number of input channels
210 |             num_classes (int): number of classes for classification head
211 |             embed_dim (int): embedding dimension
212 |             depth (int): depth of transformer
213 |             num_heads (int): number of attention heads
214 |             mlp_ratio (float): ratio of mlp hidden dim to embedding dim
215 |             qkv_bias (bool): enable bias for qkv if True
216 |             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
217 |             init_values (float): LayerScale init value for the class-attention blocks; None disables LayerScale
218 |             drop_rate (float): dropout rate
219 |             attn_drop_rate (float): attention dropout rate
220 |             drop_path_rate (float): stochastic depth rate
221 |             embed_layer (nn.Module): patch embedding layer
222 |             norm_layer (nn.Module): normalization layer
223 |             class_attention_layers (int): number of CaiT-style class-attention blocks
224 |         """
225 | super().__init__()
226 | self.num_classes = num_classes
227 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
228 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
229 | act_layer = act_layer or nn.GELU
230 |
231 | self.patch_embed = embed_layer(
232 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
233 | num_patches = self.patch_embed.num_patches
234 |
235 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
236 |
237 | self.rel_pos_bias = None
238 | self.pos_embed = nn.Parameter(
239 |             torch.zeros(1, self.patch_embed.num_patches, self.embed_dim), requires_grad=False)  # fixed (non-learnable) sin-cos embedding
240 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=False)
241 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
242 |
243 | self.pos_drop = nn.Dropout(p=drop_rate)
244 |
245 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
246 | self.blocks = nn.Sequential(*[
247 | Block(
248 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
249 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], init_values=None, norm_layer=norm_layer, act_layer=act_layer)
250 | for i in range(depth)])
251 |
252 | self.blocks_token = nn.ModuleList([
253 | Block_CA(
254 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
255 | drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
256 |                 act_layer=act_layer, init_values=init_values)
257 |             for _ in range(class_attention_layers)])
258 |
259 | self.norm = norm_layer(embed_dim)
260 |
261 | self.pre_logits = nn.Identity()
262 |
263 | # Classifier head(s)
264 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
265 |
266 | # weight init
267 | nn.init.normal_(self.cls_token, std=0.02)
268 |         # pos_embed is a fixed 2d sin-cos table (built above), so it is not randomly initialized
269 | w = self.patch_embed.proj.weight.data
270 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
271 | self.apply(self._init_weights)
272 |
273 | def _init_weights(self, m):
274 | if isinstance(m, nn.Linear):
275 | # we use xavier_uniform following official JAX ViT:
276 | torch.nn.init.xavier_uniform_(m.weight)
277 | if isinstance(m, nn.Linear) and m.bias is not None:
278 | nn.init.constant_(m.bias, 0)
279 | elif isinstance(m, nn.LayerNorm):
280 | nn.init.constant_(m.bias, 0)
281 | nn.init.constant_(m.weight, 1.0)
282 |
283 | @torch.jit.ignore
284 | def no_weight_decay(self):
285 | return {'pos_embed', 'cls_token'}
286 |
287 |     def forward_features(self, x, masks=None, mask_ratio=0.):  # mask_ratio is accepted for interface compatibility but unused here
288 | x = self.patch_embed(x)
289 |         # ExtreMA follows the CaiT-style architecture: the class token is not
290 |         # concatenated here; it is injected later through the class-attention blocks
291 | x = self.pos_drop(x + self.pos_embed)
292 |         # student branch: keep only the visible tokens selected by each mask
293 |         if masks is not None and self.student:  # note: `student` is not set in this class; it is expected to be assigned externally
294 | cls_token = self.cls_token.expand(x.shape[0] * len(masks), -1, -1) # stole cls_tokens impl from Phil Wang, thanks
295 | x_list = []
296 | for mask in masks:
297 |                 mask = mask.view(x.shape[0], -1, 1).repeat(1, 1, x.shape[2])  # broadcast patch indices across channels
298 |                 x_list.append(torch.gather(x, 1, mask))  # gather the visible patch tokens
299 | x_multi = torch.cat(x_list, dim=0)
300 | else:
301 | x_multi = x
302 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
303 |
304 | x_multi = self.blocks(x_multi)
305 |
306 |         for blk in self.blocks_token:
307 |             cls_token = blk(x_multi, cls_token)
308 |
309 | x_multi = torch.cat((cls_token, x_multi), dim=1)
310 |
311 | x = self.norm(x_multi)
312 | return self.pre_logits(x[:, 0])
313 |
314 | def forward(self, x, masks=None, mask_ratio=0.):
315 | x = self.forward_features(x, masks, mask_ratio)
316 | out = self.head(x)
317 | return out
318 |
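# Illustrative sketch of the masked (student) forward path. It assumes the
# `student` flag has been set on the module externally and that each entry of
# `masks` holds the indices of the visible patches; all sizes are arbitrary.
#
#   model = VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6, num_classes=0)
#   model.student = True
#   imgs = torch.randn(2, 3, 224, 224)   # 14 x 14 = 196 patches per image
#   keep = torch.stack([torch.randperm(196)[:49] for _ in range(2)])  # keep 25% of patches
#   feats = model(imgs, masks=[keep])    # -> (2, 384) class-token features
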
319 | def vit_small(**kwargs):
320 | model = VisionTransformer(
321 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
322 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
323 | model.default_cfg = _cfg()
324 | return model
325 |
326 | def vit_base(**kwargs):
327 | model = VisionTransformer(
328 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
329 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
330 | model.default_cfg = _cfg()
331 | return model
332 |
333 | def vit_large(**kwargs):
334 | model = VisionTransformer(
335 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
336 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
337 | model.default_cfg = _cfg()
338 | return model
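
# Quick smoke test (illustrative): a standard dense forward pass through the
# classification head.
#
#   model = vit_base()
#   logits = model(torch.randn(2, 3, 224, 224))
#   print(logits.shape)                  # torch.Size([2, 1000])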
--------------------------------------------------------------------------------