├── .gitignore ├── LICENSE ├── README.md ├── examples └── imagenet │ ├── README.md │ └── main.py ├── jax_to_pytorch ├── README.md └── convert_jax_to_pt │ └── load_jax_weights.py ├── setup.py ├── test.py ├── tests └── test_model.py └── vision_transformer_pytorch ├── __init__.py ├── model.py ├── resnet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | tmp 3 | *.pkl 4 | .vscode 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .DS_STORE 111 | 112 | # PyCharm 113 | .idea* 114 | *.xml 115 | 116 | # Custom 117 | examples/imagenet/data/ 118 | jax_to_pytorch/pretrained_jax 119 | jax_to_pytorch/pretrained_pytorch -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [ZHANG Zhi] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision Transformer PyTorch 2 | This project is modified from [lukemelas](https://github.com/lukemelas)/[EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch) and [asyml](https://github.com/asyml)/[vision-transformer-pytorch](https://github.com/asyml/vision-transformer-pytorch) to provide an out-of-the-box API that lets you use VisionTransformer as easily as EfficientNet. 3 | 4 | ### Quickstart 5 | 6 | Install with `pip install vision_transformer_pytorch` and load a pretrained VisionTransformer with: 7 | 8 | ``` 9 | from vision_transformer_pytorch import VisionTransformer 10 | model = VisionTransformer.from_pretrained('ViT-B_16') 11 | ``` 12 | 13 | ### About Vision Transformer PyTorch 14 | 15 | Vision Transformer PyTorch is a PyTorch re-implementation of Vision Transformer, based on one of the best-practice examples among commonly used deep learning libraries, [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch), and on an elegant implementation of Vision Transformer, [vision-transformer-pytorch](https://github.com/asyml/vision-transformer-pytorch). In this project, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible. 16 | 17 | If you have any feature requests or questions, feel free to leave them as GitHub issues! 18 | 19 | ### Installation 20 | 21 | Install via pip: 22 | 23 | ``` 24 | pip install vision_transformer_pytorch 25 | ``` 26 | 27 | Or install from source: 28 | 29 | ``` 30 | git clone https://github.com/tczhangzhi/VisionTransformer-Pytorch 31 | cd VisionTransformer-Pytorch 32 | pip install -e . 33 | ``` 34 | 35 | ### Usage 36 | 37 | #### Loading pretrained models 38 | 39 | Load a Vision Transformer: 40 | 41 | ``` 42 | from vision_transformer_pytorch import VisionTransformer 43 | model = VisionTransformer.from_name('ViT-B_16') # or 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16' 44 | ``` 45 | 46 | Load a pretrained Vision Transformer: 47 | 48 | ``` 49 | from vision_transformer_pytorch import VisionTransformer 50 | model = VisionTransformer.from_pretrained('ViT-B_16') # or 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16' 51 | # inputs = torch.randn(1, 3, *model.image_size) 52 | # model(inputs) 53 | # model.extract_features(inputs) 54 | ``` 55 | 56 | Default hyperparameters: 57 | 58 | | Param\Model | ViT-B_16 | ViT-B_32 | ViT-L_16 | ViT-L_32 | R50+ViT-B_16 | 59 | | ----------------- | -------- | -------- | -------- | -------- | ------------ | 60 | | image_size | 384 | 384 | 384 | 384 | 384 | 61 | | patch_size | 16 | 32 | 16 | 32 | 1 | 62 | | emb_dim | 768 | 768 | 1024 | 1024 | 768 | 63 | | mlp_dim | 3072 | 3072 | 4096 | 4096 | 3072 | 64 | | num_heads | 12 | 12 | 16 | 16 | 12 | 65 | | num_layers | 12 | 12 | 24 | 24 | 12 | 66 | | num_classes | 1000 | 1000 | 1000 | 1000 | 1000 | 67 | | attn_dropout_rate | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 68 | | dropout_rate | 0.1 | 0.1 | 0.1 | 0.1 | 0.1 | 69 | 70 | If you need to modify these hyperparameters, use: 71 | 72 | ``` 73 | from vision_transformer_pytorch import VisionTransformer 74 | model = VisionTransformer.from_name('ViT-B_16', image_size=256, patch_size=64, ...) 75 | ``` 76 | 77 | #### ImageNet 78 | 79 | See `examples/imagenet` for details about evaluating on ImageNet. 80 | 81 | ### Contributing 82 | 83 | If you find a bug, create a GitHub issue, or even better, submit a pull request.
Similarly, if you have questions, simply post them as GitHub issues. 84 | 85 | I look forward to seeing what the community does with these models! -------------------------------------------------------------------------------- /examples/imagenet/README.md: -------------------------------------------------------------------------------- 1 | ### Imagenet 2 | 3 | This is a preliminary directory for evaluating the model on ImageNet. It is adapted from the standard PyTorch Imagenet script. 4 | 5 | For now, only evaluation is supported, but I am currently building scripts to assist with training new models on Imagenet. 6 | 7 | To run on Imagenet, place your `train` and `val` directories in `data`. 8 | 9 | Example commands: 10 | 11 | ``` 12 | # Evaluate small VisionTransformer on CPU 13 | python main.py data -e -a 'ViT-B_16' --pretrained 14 | # Evaluate large VisionTransformer on GPU 15 | python main.py data -e -a 'ViT-L_32' --pretrained --gpu 0 --batch-size 128 16 | # Evaluate ResNet-50 for comparison 17 | python main.py data -e -a 'resnet50' --pretrained --gpu 0 18 | ``` -------------------------------------------------------------------------------- /examples/imagenet/main.py: -------------------------------------------------------------------------------- 1 | # MODIFIED FROM 2 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/examples/imagenet/main.py 3 | 4 | import argparse 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | import PIL 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim 18 | import torch.multiprocessing as mp 19 | import torch.utils.data 20 | import torch.utils.data.distributed 21 | import torchvision.transforms as transforms 22 | import torchvision.datasets as datasets 23 | import torchvision.models as models 24 | 25 | from vision_transformer_pytorch import VisionTransformer 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 28 | parser.add_argument('data', metavar='DIR', 29 | help='path to dataset') 30 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 31 | help='model architecture (default: resnet18)') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', 40 | help='mini-batch size (default: 256), this is the total ' 41 | 'batch size of all GPUs on the current node when ' 42 | 'using Data Parallel or Distributed Data Parallel') 43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | metavar='LR', help='initial learning rate', dest='lr') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)', 49 | dest='weight_decay') 50 | parser.add_argument('-p', '--print-freq', default=10, type=int, 51 | metavar='N', help='print frequency (default: 10)') 52 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 53 | help='path 
to latest checkpoint (default: none)') 54 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 55 | help='evaluate model on validation set') 56 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 57 | help='use pre-trained model') 58 | parser.add_argument('--world-size', default=-1, type=int, 59 | help='number of nodes for distributed training') 60 | parser.add_argument('--rank', default=-1, type=int, 61 | help='node rank for distributed training') 62 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 63 | help='url used to set up distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', type=str, 65 | help='distributed backend') 66 | parser.add_argument('--seed', default=None, type=int, 67 | help='seed for initializing training. ') 68 | parser.add_argument('--gpu', default=None, type=int, 69 | help='GPU id to use.') 70 | parser.add_argument('--image_size', default=224, type=int, 71 | help='image size') 72 | parser.add_argument('--advprop', default=False, action='store_true', 73 | help='use advprop or not') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | 80 | best_acc1 = 0 81 | 82 | 83 | def main(): 84 | args = parser.parse_args() 85 | 86 | if args.seed is not None: 87 | random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | cudnn.deterministic = True 90 | warnings.warn('You have chosen to seed training. ' 91 | 'This will turn on the CUDNN deterministic setting, ' 92 | 'which can slow down your training considerably! ' 93 | 'You may see unexpected behavior when restarting ' 94 | 'from checkpoints.') 95 | 96 | if args.gpu is not None: 97 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 98 | 'disable data parallelism.') 99 | 100 | if args.dist_url == "env://" and args.world_size == -1: 101 | args.world_size = int(os.environ["WORLD_SIZE"]) 102 | 103 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 104 | 105 | ngpus_per_node = torch.cuda.device_count() 106 | if args.multiprocessing_distributed: 107 | # Since we have ngpus_per_node processes per node, the total world_size 108 | # needs to be adjusted accordingly 109 | args.world_size = ngpus_per_node * args.world_size 110 | # Use torch.multiprocessing.spawn to launch distributed processes: the 111 | # main_worker process function 112 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 113 | else: 114 | # Simply call main_worker function 115 | main_worker(args.gpu, ngpus_per_node, args) 116 | 117 | 118 | def main_worker(gpu, ngpus_per_node, args): 119 | global best_acc1 120 | args.gpu = gpu 121 | 122 | if args.gpu is not None: 123 | print("Use GPU: {} for training".format(args.gpu)) 124 | 125 | if args.distributed: 126 | if args.dist_url == "env://" and args.rank == -1: 127 | args.rank = int(os.environ["RANK"]) 128 | if args.multiprocessing_distributed: 129 | # For multiprocessing distributed training, rank needs to be the 130 | # global rank among all the processes 131 | args.rank = args.rank * ngpus_per_node + gpu 132 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 133 | world_size=args.world_size, rank=args.rank) 134 | # create model 135 | if 'ViT' in args.arch: # NEW 136 | if args.pretrained: 137 | model = VisionTransformer.from_pretrained(args.arch, advprop=args.advprop) 138 | print("=> using pre-trained model '{}'".format(args.arch)) 139 | else: 140 | print("=> creating model '{}'".format(args.arch)) 141 | model = VisionTransformer.from_name(args.arch) 142 | 143 | else: 144 | if args.pretrained: 145 | print("=> using pre-trained model '{}'".format(args.arch)) 146 | model = models.__dict__[args.arch](pretrained=True) 147 | else: 148 | print("=> creating model '{}'".format(args.arch)) 149 | model = models.__dict__[args.arch]() 150 | 151 | if args.distributed: 152 | # For multiprocessing distributed, DistributedDataParallel constructor 153 | # should always set the single device scope, otherwise, 154 | # DistributedDataParallel will use all available devices. 
155 | if args.gpu is not None: 156 | torch.cuda.set_device(args.gpu) 157 | model.cuda(args.gpu) 158 | # When using a single GPU per process and per 159 | # DistributedDataParallel, we need to divide the batch size 160 | # ourselves based on the total number of GPUs we have 161 | args.batch_size = int(args.batch_size / ngpus_per_node) 162 | args.workers = int(args.workers / ngpus_per_node) 163 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 164 | else: 165 | model.cuda() 166 | # DistributedDataParallel will divide and allocate batch_size to all 167 | # available GPUs if device_ids are not set 168 | model = torch.nn.parallel.DistributedDataParallel(model) 169 | elif args.gpu is not None: 170 | torch.cuda.set_device(args.gpu) 171 | model = model.cuda(args.gpu) 172 | else: 173 | # DataParallel will divide and allocate batch_size to all available GPUs 174 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 175 | model.features = torch.nn.DataParallel(model.features) 176 | model.cuda() 177 | else: 178 | model = torch.nn.DataParallel(model).cuda() 179 | 180 | # define loss function (criterion) and optimizer 181 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 182 | 183 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 184 | momentum=args.momentum, 185 | weight_decay=args.weight_decay) 186 | 187 | # optionally resume from a checkpoint 188 | if args.resume: 189 | if os.path.isfile(args.resume): 190 | print("=> loading checkpoint '{}'".format(args.resume)) 191 | checkpoint = torch.load(args.resume) 192 | args.start_epoch = checkpoint['epoch'] 193 | best_acc1 = checkpoint['best_acc1'] 194 | if args.gpu is not None: 195 | # best_acc1 may be from a checkpoint from a different GPU 196 | best_acc1 = best_acc1.to(args.gpu) 197 | model.load_state_dict(checkpoint['state_dict']) 198 | optimizer.load_state_dict(checkpoint['optimizer']) 199 | print("=> loaded checkpoint '{}' (epoch {})" 200 | .format(args.resume, checkpoint['epoch'])) 201 | else: 202 | print("=> no checkpoint found at '{}'".format(args.resume)) 203 | 204 | cudnn.benchmark = True 205 | 206 | # Data loading code 207 | traindir = os.path.join(args.data, 'train') 208 | valdir = os.path.join(args.data, 'val') 209 | if args.advprop: 210 | normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0) 211 | else: 212 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 213 | std=[0.229, 0.224, 0.225]) 214 | 215 | if 'ViT' in args.arch: 216 | image_size = model.image_size() 217 | else: 218 | image_size = args.image_size 219 | 220 | train_dataset = datasets.ImageFolder( 221 | traindir, 222 | transforms.Compose([ 223 | transforms.RandomResizedCrop(image_size), 224 | transforms.RandomHorizontalFlip(), 225 | transforms.ToTensor(), 226 | normalize, 227 | ])) 228 | 229 | if args.distributed: 230 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 231 | else: 232 | train_sampler = None 233 | 234 | train_loader = torch.utils.data.DataLoader( 235 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 236 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 237 | 238 | val_transforms = transforms.Compose([ 239 | transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC), 240 | transforms.CenterCrop(image_size), 241 | transforms.ToTensor(), 242 | normalize, 243 | ]) 244 | print('Using image size', image_size) 245 | 246 | val_loader = torch.utils.data.DataLoader( 247 | datasets.ImageFolder(valdir, val_transforms), 248 | 
batch_size=args.batch_size, shuffle=False, 249 | num_workers=args.workers, pin_memory=True) 250 | 251 | if args.evaluate: 252 | res = validate(val_loader, model, criterion, args) 253 | with open('res.txt', 'w') as f: 254 | print(res, file=f) 255 | return 256 | 257 | for epoch in range(args.start_epoch, args.epochs): 258 | if args.distributed: 259 | train_sampler.set_epoch(epoch) 260 | adjust_learning_rate(optimizer, epoch, args) 261 | 262 | # train for one epoch 263 | train(train_loader, model, criterion, optimizer, epoch, args) 264 | 265 | # evaluate on validation set 266 | acc1 = validate(val_loader, model, criterion, args) 267 | 268 | # remember best acc@1 and save checkpoint 269 | is_best = acc1 > best_acc1 270 | best_acc1 = max(acc1, best_acc1) 271 | 272 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 273 | and args.rank % ngpus_per_node == 0): 274 | save_checkpoint({ 275 | 'epoch': epoch + 1, 276 | 'arch': args.arch, 277 | 'state_dict': model.state_dict(), 278 | 'best_acc1': best_acc1, 279 | 'optimizer' : optimizer.state_dict(), 280 | }, is_best) 281 | 282 | 283 | def train(train_loader, model, criterion, optimizer, epoch, args): 284 | batch_time = AverageMeter('Time', ':6.3f') 285 | data_time = AverageMeter('Data', ':6.3f') 286 | losses = AverageMeter('Loss', ':.4e') 287 | top1 = AverageMeter('Acc@1', ':6.2f') 288 | top5 = AverageMeter('Acc@5', ':6.2f') 289 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 290 | top5, prefix="Epoch: [{}]".format(epoch)) 291 | 292 | # switch to train mode 293 | model.train() 294 | 295 | end = time.time() 296 | for i, (images, target) in enumerate(train_loader): 297 | # measure data loading time 298 | data_time.update(time.time() - end) 299 | 300 | if args.gpu is not None: 301 | images = images.cuda(args.gpu, non_blocking=True) 302 | target = target.cuda(args.gpu, non_blocking=True) 303 | 304 | # compute output 305 | output = model(images) 306 | loss = criterion(output, target) 307 | 308 | # measure accuracy and record loss 309 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 310 | losses.update(loss.item(), images.size(0)) 311 | top1.update(acc1[0], images.size(0)) 312 | top5.update(acc5[0], images.size(0)) 313 | 314 | # compute gradient and do SGD step 315 | optimizer.zero_grad() 316 | loss.backward() 317 | optimizer.step() 318 | 319 | # measure elapsed time 320 | batch_time.update(time.time() - end) 321 | end = time.time() 322 | 323 | if i % args.print_freq == 0: 324 | progress.print(i) 325 | 326 | 327 | def validate(val_loader, model, criterion, args): 328 | batch_time = AverageMeter('Time', ':6.3f') 329 | losses = AverageMeter('Loss', ':.4e') 330 | top1 = AverageMeter('Acc@1', ':6.2f') 331 | top5 = AverageMeter('Acc@5', ':6.2f') 332 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 333 | prefix='Test: ') 334 | 335 | # switch to evaluate mode 336 | model.eval() 337 | 338 | with torch.no_grad(): 339 | end = time.time() 340 | for i, (images, target) in enumerate(val_loader): 341 | if args.gpu is not None: 342 | images = images.cuda(args.gpu, non_blocking=True) 343 | target = target.cuda(args.gpu, non_blocking=True) 344 | 345 | # compute output 346 | output = model(images) 347 | loss = criterion(output, target) 348 | 349 | # measure accuracy and record loss 350 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 351 | losses.update(loss.item(), images.size(0)) 352 | top1.update(acc1[0], images.size(0)) 353 | top5.update(acc5[0], images.size(0)) 354 | 355 | 
# measure elapsed time 356 | batch_time.update(time.time() - end) 357 | end = time.time() 358 | 359 | if i % args.print_freq == 0: 360 | progress.print(i) 361 | 362 | # TODO: this should also be done with the ProgressMeter 363 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 364 | .format(top1=top1, top5=top5)) 365 | 366 | return top1.avg 367 | 368 | 369 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 370 | torch.save(state, filename) 371 | if is_best: 372 | shutil.copyfile(filename, 'model_best.pth.tar') 373 | 374 | 375 | class AverageMeter(object): 376 | """Computes and stores the average and current value""" 377 | def __init__(self, name, fmt=':f'): 378 | self.name = name 379 | self.fmt = fmt 380 | self.reset() 381 | 382 | def reset(self): 383 | self.val = 0 384 | self.avg = 0 385 | self.sum = 0 386 | self.count = 0 387 | 388 | def update(self, val, n=1): 389 | self.val = val 390 | self.sum += val * n 391 | self.count += n 392 | self.avg = self.sum / self.count 393 | 394 | def __str__(self): 395 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 396 | return fmtstr.format(**self.__dict__) 397 | 398 | 399 | class ProgressMeter(object): 400 | def __init__(self, num_batches, *meters, prefix=""): 401 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 402 | self.meters = meters 403 | self.prefix = prefix 404 | 405 | def print(self, batch): 406 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 407 | entries += [str(meter) for meter in self.meters] 408 | print('\t'.join(entries)) 409 | 410 | def _get_batch_fmtstr(self, num_batches): 411 | num_digits = len(str(num_batches // 1)) 412 | fmt = '{:' + str(num_digits) + 'd}' 413 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 414 | 415 | 416 | def adjust_learning_rate(optimizer, epoch, args): 417 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 418 | lr = args.lr * (0.1 ** (epoch // 30)) 419 | for param_group in optimizer.param_groups: 420 | param_group['lr'] = lr 421 | 422 | 423 | def accuracy(output, target, topk=(1,)): 424 | """Computes the accuracy over the k top predictions for the specified values of k""" 425 | with torch.no_grad(): 426 | maxk = max(topk) 427 | batch_size = target.size(0) 428 | 429 | _, pred = output.topk(maxk, 1, True, True) 430 | pred = pred.t() 431 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 432 | 433 | res = [] 434 | for k in topk: 435 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 436 | res.append(correct_k.mul_(100.0 / batch_size)) 437 | return res 438 | 439 | 440 | if __name__ == '__main__': 441 | main() -------------------------------------------------------------------------------- /jax_to_pytorch/README.md: -------------------------------------------------------------------------------- 1 | ### JAX to PyTorch Conversion 2 | 3 | This directory is used to convert JAX weights to PyTorch. It was hacked together fairly quickly, so the code is not the most beautiful (just a warning!), but it does the job. I will be refactoring it soon. 4 | 5 | I should also emphasize that you do *not* need to run any of this code to load pretrained weights. Simply use `VisionTransformer.from_pretrained(...)`. 6 | 7 | That being said, the main script here is `convert_to_jax/load_jax_weights.py`. In order to use it, you should first download the pre-trained JAX weights following the description official repository. 
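The converter writes out an ordinary PyTorch `state_dict`, so once the commands below have been run you can sanity-check the result without even building a model. A minimal sketch (assuming you run it from the `jax_to_pytorch` directory and used the `pretrained_pytorch/ViT-B_16.pth` output path shown below):

```
import torch

# Converted keys follow this package's PyTorch naming, e.g.
# 'transformer.encoder_layers.0.attn.query.weight' with shape (768, 12, 64) for ViT-B_16.
state_dict = torch.load('pretrained_pytorch/ViT-B_16.pth', map_location='cpu')
for key, value in list(state_dict.items())[:5]:
    print(key, tuple(value.shape))
```

The official repository describes where to download the JAX checkpoints: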
8 | 9 | >You can find all these models in the following storage bucket: 10 | > 11 | >https://console.cloud.google.com/storage/vit_models/ 12 | > 13 | >For example, if you would like to download the ViT-B/16 pre-trained on imagenet21k run the following command: 14 | 15 | ``` 16 | mkdir pretrained_jax 17 | cd pretrained_jax 18 | wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz 19 | cd .. 20 | ``` 21 | 22 | Then 23 | 24 | ``` 25 | mkdir pretrained_pytorch 26 | cd convert_jax_to_pt 27 | python load_jax_weights.py \ 28 | --jax_checkpoint ../pretrained_jax/ViT-B_16.npz \ 29 | --output_file ../pretrained_pytorch/ViT-B_16.pth 30 | ``` -------------------------------------------------------------------------------- /jax_to_pytorch/convert_jax_to_pt/load_jax_weights.py: -------------------------------------------------------------------------------- 1 | # MODIFIED FROM 2 | # https://github.com/asyml/vision-transformer-pytorch/blob/92b8deb1ce99e83e0a182fefc866ab0485d76f1b/src/check_jax.py 3 | 4 | import torch 5 | import argparse 6 | import numpy as np 7 | 8 | from tensorflow.io import gfile 9 | 10 | def load_jax(path): 11 | """ Loads params from a npz checkpoint previously stored with `save()` in jax implemetation """ 12 | with gfile.GFile(path, 'rb') as f: 13 | ckpt_dict = np.load(f, allow_pickle=False) 14 | keys, values = zip(*list(ckpt_dict.items())) 15 | return keys, values 16 | 17 | def replace_names(names): 18 | """ Replace jax model names with pytorch model names """ 19 | new_names = [] 20 | for name in names: 21 | if name == 'Transformer': 22 | new_names.append('transformer') 23 | elif name == 'encoder_norm': 24 | new_names.append('norm') 25 | elif 'encoderblock' in name: 26 | num = name.split('_')[-1] 27 | new_names.append('encoder_layers') 28 | new_names.append(num) 29 | elif 'LayerNorm' in name: 30 | num = name.split('_')[-1] 31 | if num == '0': 32 | new_names.append('norm{}'.format(1)) 33 | elif num == '2': 34 | new_names.append('norm{}'.format(2)) 35 | elif 'MlpBlock' in name: 36 | new_names.append('mlp') 37 | elif 'Dense' in name: 38 | num = name.split('_')[-1] 39 | new_names.append('fc{}'.format(int(num) + 1)) 40 | elif 'MultiHeadDotProductAttention' in name: 41 | new_names.append('attn') 42 | elif name == 'kernel' or name == 'scale': 43 | new_names.append('weight') 44 | elif name == 'bias': 45 | new_names.append(name) 46 | elif name == 'posembed_input': 47 | new_names.append('pos_embedding') 48 | elif name == 'pos_embedding': 49 | new_names.append('pos_embedding') 50 | elif name == 'embedding': 51 | new_names.append('embedding') 52 | elif name == 'head': 53 | new_names.append('classifier') 54 | elif name == 'cls': 55 | new_names.append('cls_token') 56 | elif name == 'block1': 57 | new_names.append('resnet.body.block1') 58 | elif name == 'block2': 59 | new_names.append('resnet.body.block2') 60 | elif name == 'block3': 61 | new_names.append('resnet.body.block3') 62 | elif name == 'conv_root': 63 | new_names.append('resnet.root.conv') 64 | elif name == 'gn_root': 65 | new_names.append('resnet.root.gn') 66 | elif name == 'conv_proj': 67 | new_names.append('downsample') 68 | else: 69 | new_names.append(name) 70 | return new_names 71 | 72 | def convert_jax_pytorch(keys, values): 73 | """ Convert jax model parameters with pytorch model parameters """ 74 | state_dict = {} 75 | for key, value in zip(keys, values): 76 | 77 | # convert name to torch names 78 | names = key.split('/') 79 | torch_names = replace_names(names) 80 | torch_key = '.'.join(w for w in torch_names) 
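        # Note on the shape handling below (as implemented by the branches that follow):
        # - 2-D dense kernels are stored as (in, out) in JAX and are transposed to
        #   PyTorch's (out, in) layout;
        # - 3-D query/key/value kernels of shape (emb_dim, num_heads, head_dim) and the
        #   (num_heads, head_dim, emb_dim) attention output kernel are kept as-is, since
        #   LinearGeneral in model.py stores its weights in exactly that layout;
        # - GroupNorm scale/bias tensors are flattened to 1-D;
        # - 4-D convolution kernels are permuted from JAX's HWIO to PyTorch's OIHW.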
81 | 82 | # convert values to tensor and check shapes 83 | tensor_value = torch.tensor(value, dtype=torch.float) 84 | # check shape 85 | num_dim = len(tensor_value.shape) 86 | 87 | if num_dim == 1: 88 | tensor_value = tensor_value.squeeze() 89 | elif num_dim == 2 and torch_names[-1] == 'weight': 90 | # for normal weight, transpose it 91 | tensor_value = tensor_value.T 92 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] in ['query', 'key', 'value']: 93 | feat_dim, num_heads, head_dim = tensor_value.shape 94 | # for multi head attention q/k/v weight 95 | tensor_value = tensor_value 96 | elif torch_names[-1] == 'weight' and 'gn' in torch_names[-2]: 97 | # for multi head attention q/k/v weight 98 | tensor_value = tensor_value.reshape(tensor_value.shape[-1]) 99 | elif num_dim == 2 and torch_names[-1] == 'bias' and torch_names[-2] in ['query', 'key', 'value']: 100 | # for multi head attention q/k/v bias 101 | tensor_value = tensor_value 102 | elif torch_names[-1] == 'bias' and 'gn' in torch_names[-2]: 103 | # for multi head attention q/k/v weight 104 | tensor_value = tensor_value.reshape(tensor_value.shape[-1]) 105 | elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] == 'out': 106 | # for multi head attention out weight 107 | tensor_value = tensor_value 108 | elif num_dim == 4 and torch_names[-1] == 'weight': 109 | tensor_value = tensor_value.permute(3, 2, 0, 1) 110 | 111 | # print("{}: {}".format(torch_key, tensor_value.shape)) 112 | state_dict[torch_key] = tensor_value 113 | return state_dict 114 | 115 | if __name__ == '__main__': 116 | parser = argparse.ArgumentParser( 117 | description='Convert JAX model to PyTorch model and save for easier future loading') 118 | parser.add_argument('--jax_checkpoint', type=str, default='pretrained_jax/ViT-B_32.npz', 119 | help='jax checkpoint file path') 120 | parser.add_argument('--output_file', type=str, default='pretrained_pytorch/ViT-B_32.pth', 121 | help='output PyTorch model file name') 122 | args = parser.parse_args() 123 | 124 | keys, values = load_jax(args.jax_checkpoint) 125 | state_dict = convert_jax_pytorch(keys, values) 126 | torch.save(state_dict, args.output_file) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # MODIFIED FROM 5 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/setup.py 6 | # Note: To use the 'upload' functionality of this file, you must: 7 | # $ pipenv install twine --dev 8 | 9 | import io 10 | import os 11 | import sys 12 | from shutil import rmtree 13 | 14 | from setuptools import find_packages, setup, Command 15 | 16 | # Package meta-data. 17 | NAME = 'vision_transformer_pytorch' 18 | DESCRIPTION = 'VisionTransformer implemented in PyTorch.' 19 | URL = 'https://github.com/tczhangzhi/VisionTransformer-PyTorch' 20 | EMAIL = 'zhangzhi2018@email.szu.edu.cn' 21 | AUTHOR = 'ZHANG Zhi' 22 | REQUIRES_PYTHON = '>=3.5.0' 23 | VERSION = '1.0.3' 24 | 25 | # What packages are required for this module to be executed? 26 | REQUIRED = [ 27 | 'torch>=1.5.0', # require torch.nn.GELU() 28 | 'numpy' 29 | ] 30 | 31 | # What packages are optional? 32 | EXTRAS = { 33 | # 'fancy feature': ['django'], 34 | } 35 | 36 | # The rest you shouldn't have to touch too much :) 37 | # ------------------------------------------------ 38 | # Except, perhaps the License and Trove Classifiers! 
39 | # If you do change the License, remember to change the Trove Classifier for that! 40 | 41 | here = os.path.abspath(os.path.dirname(__file__)) 42 | 43 | # Import the README and use it as the long-description. 44 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 45 | try: 46 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 47 | long_description = '\n' + f.read() 48 | except FileNotFoundError: 49 | long_description = DESCRIPTION 50 | 51 | # Load the package's __version__.py module as a dictionary. 52 | about = {} 53 | if not VERSION: 54 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 55 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 56 | exec(f.read(), about) 57 | else: 58 | about['__version__'] = VERSION 59 | 60 | 61 | class UploadCommand(Command): 62 | """Support setup.py upload.""" 63 | 64 | description = 'Build and publish the package.' 65 | user_options = [] 66 | 67 | @staticmethod 68 | def status(s): 69 | """Prints things in bold.""" 70 | print('\033[1m{0}\033[0m'.format(s)) 71 | 72 | def initialize_options(self): 73 | pass 74 | 75 | def finalize_options(self): 76 | pass 77 | 78 | def run(self): 79 | try: 80 | self.status('Removing previous builds…') 81 | rmtree(os.path.join(here, 'dist')) 82 | except OSError: 83 | pass 84 | 85 | self.status('Building Source and Wheel (universal) distribution…') 86 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 87 | 88 | self.status('Uploading the package to PyPI via Twine…') 89 | os.system('twine upload dist/*') 90 | 91 | self.status('Pushing git tags…') 92 | os.system('git tag v{0}'.format(about['__version__'])) 93 | os.system('git push --tags') 94 | 95 | sys.exit() 96 | 97 | 98 | # Where the magic happens: 99 | setup( 100 | name=NAME, 101 | version=about['__version__'], 102 | description=DESCRIPTION, 103 | long_description=long_description, 104 | long_description_content_type='text/markdown', 105 | author=AUTHOR, 106 | author_email=EMAIL, 107 | python_requires=REQUIRES_PYTHON, 108 | url=URL, 109 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 110 | # py_modules=['model'], # If your package is a single module, use this instead of 'packages' 111 | install_requires=REQUIRED, 112 | extras_require=EXTRAS, 113 | include_package_data=True, 114 | license='Apache', 115 | classifiers=[ 116 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 117 | 'License :: OSI Approved :: Apache Software License', 118 | 'Programming Language :: Python', 119 | 'Programming Language :: Python :: 3', 120 | 'Programming Language :: Python :: 3.6', 121 | ], 122 | # $ setup.py publish support. 
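    # Registering UploadCommand below (via cmdclass) means `python setup.py upload` builds
    # the sdist and wheel, uploads them to PyPI via twine, and then creates and pushes a
    # v<version> git tag.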
123 | cmdclass={ 124 | 'upload': UploadCommand, 125 | }, 126 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vision_transformer_pytorch import VisionTransformer 5 | 6 | net = VisionTransformer.from_pretrained('R50+ViT-B_16') 7 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # MODIFIED FROM 2 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/tests/test_model.py 3 | 4 | from collections import OrderedDict 5 | 6 | import pytest 7 | import torch 8 | import torch.nn as nn 9 | 10 | from vision_transformer_pytorch import VisionTransformer 11 | 12 | # -- fixtures ------------------------------------------------------------------------------------- 13 | 14 | 15 | @pytest.fixture( 16 | scope='module', 17 | params=['ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16']) 18 | def model(request): 19 | return request.param 20 | 21 | 22 | @pytest.fixture(scope='module', params=[True, False]) 23 | def pretrained(request): 24 | return request.param 25 | 26 | 27 | @pytest.fixture(scope='function') 28 | def net(model, pretrained): 29 | return VisionTransformer.from_pretrained( 30 | model) if pretrained else VisionTransformer.from_name(model) 31 | 32 | 33 | # -- tests ---------------------------------------------------------------------------------------- 34 | 35 | 36 | def test_forward(net): 37 | """Test `.forward()` doesn't throw an error""" 38 | data = torch.zeros((1, 3, *net.image_size)) 39 | output = net(data) 40 | assert not torch.isnan(output).any() 41 | 42 | 43 | @pytest.mark.parametrize('img_size', [224, 256, 512]) 44 | def test_hyper_params(model, img_size): 45 | """Test `.forward()` doesn't throw an error with different input size""" 46 | data = torch.zeros((1, 3, img_size, img_size)) 47 | net = VisionTransformer.from_name(model, image_size=img_size) 48 | output = net(data) 49 | assert not torch.isnan(output).any() 50 | 51 | 52 | def test_modify_classifier(net): 53 | """Test ability to modify fc modules of network""" 54 | classifier = nn.Linear(net._params.emb_dim, net._params.num_classes) 55 | 56 | net.classifier = classifier 57 | 58 | data = torch.zeros((2, 3, *net.image_size)) 59 | output = net(data) 60 | assert not torch.isnan(output).any() -------------------------------------------------------------------------------- /vision_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.3" 2 | from .model import VisionTransformer, VALID_MODELS 3 | from .utils import ( 4 | Params, 5 | vision_transformer, 6 | get_model_params, 7 | ) -------------------------------------------------------------------------------- /vision_transformer_pytorch/model.py: -------------------------------------------------------------------------------- 1 | # MODIFIED FROM 2 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py 3 | # https://github.com/asyml/vision-transformer-pytorch/blob/main/src/model.py 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .resnet import StdConv2d 11 | from .utils import (get_width_and_height_from_size, load_pretrained_weights, 12 | get_model_params) 13 | 14 
| VALID_MODELS = ('ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32', 'R50+ViT-B_16') 15 | 16 | 17 | class PositionEmbs(nn.Module): 18 | def __init__(self, num_patches, emb_dim, dropout_rate=0.1): 19 | super(PositionEmbs, self).__init__() 20 | self.pos_embedding = nn.Parameter( 21 | torch.randn(1, num_patches + 1, emb_dim)) 22 | if dropout_rate > 0: 23 | self.dropout = nn.Dropout(dropout_rate) 24 | else: 25 | self.dropout = None 26 | 27 | def forward(self, x): 28 | out = x + self.pos_embedding 29 | 30 | if self.dropout: 31 | out = self.dropout(out) 32 | 33 | return out 34 | 35 | 36 | class MlpBlock(nn.Module): 37 | """ Transformer Feed-Forward Block """ 38 | def __init__(self, in_dim, mlp_dim, out_dim, dropout_rate=0.1): 39 | super(MlpBlock, self).__init__() 40 | 41 | # init layers 42 | self.fc1 = nn.Linear(in_dim, mlp_dim) 43 | self.fc2 = nn.Linear(mlp_dim, out_dim) 44 | self.act = nn.GELU() 45 | if dropout_rate > 0.0: 46 | self.dropout1 = nn.Dropout(dropout_rate) 47 | self.dropout2 = nn.Dropout(dropout_rate) 48 | else: 49 | self.dropout1 = None 50 | self.dropout2 = None 51 | 52 | def forward(self, x): 53 | 54 | out = self.fc1(x) 55 | out = self.act(out) 56 | if self.dropout1: 57 | out = self.dropout1(out) 58 | 59 | out = self.fc2(out) 60 | out = self.dropout2(out) 61 | return out 62 | 63 | 64 | class LinearGeneral(nn.Module): 65 | def __init__(self, in_dim=(768, ), feat_dim=(12, 64)): 66 | super(LinearGeneral, self).__init__() 67 | 68 | self.weight = nn.Parameter(torch.randn(*in_dim, *feat_dim)) 69 | self.bias = nn.Parameter(torch.zeros(*feat_dim)) 70 | 71 | def forward(self, x, dims): 72 | a = torch.tensordot(x, self.weight, dims=dims) + self.bias 73 | return a 74 | 75 | 76 | class SelfAttention(nn.Module): 77 | def __init__(self, in_dim, heads=8, dropout_rate=0.1): 78 | super(SelfAttention, self).__init__() 79 | self.heads = heads 80 | self.head_dim = in_dim // heads 81 | self.scale = self.head_dim**0.5 82 | 83 | self.query = LinearGeneral((in_dim, ), (self.heads, self.head_dim)) 84 | self.key = LinearGeneral((in_dim, ), (self.heads, self.head_dim)) 85 | self.value = LinearGeneral((in_dim, ), (self.heads, self.head_dim)) 86 | self.out = LinearGeneral((self.heads, self.head_dim), (in_dim, )) 87 | 88 | if dropout_rate > 0: 89 | self.dropout = nn.Dropout(dropout_rate) 90 | else: 91 | self.dropout = None 92 | 93 | def forward(self, x): 94 | b, n, _ = x.shape 95 | 96 | q = self.query(x, dims=([2], [0])) 97 | k = self.key(x, dims=([2], [0])) 98 | v = self.value(x, dims=([2], [0])) 99 | 100 | q = q.permute(0, 2, 1, 3) 101 | k = k.permute(0, 2, 1, 3) 102 | v = v.permute(0, 2, 1, 3) 103 | 104 | attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale 105 | attn_weights = F.softmax(attn_weights, dim=-1) 106 | out = torch.matmul(attn_weights, v) 107 | out = out.permute(0, 2, 1, 3) 108 | 109 | out = self.out(out, dims=([2, 3], [0, 1])) 110 | 111 | return out 112 | 113 | 114 | class EncoderBlock(nn.Module): 115 | def __init__(self, 116 | in_dim, 117 | mlp_dim, 118 | num_heads, 119 | dropout_rate=0.1, 120 | attn_dropout_rate=0.1): 121 | super(EncoderBlock, self).__init__() 122 | 123 | self.norm1 = nn.LayerNorm(in_dim) 124 | self.attn = SelfAttention(in_dim, 125 | heads=num_heads, 126 | dropout_rate=attn_dropout_rate) 127 | if dropout_rate > 0: 128 | self.dropout = nn.Dropout(dropout_rate) 129 | else: 130 | self.dropout = None 131 | self.norm2 = nn.LayerNorm(in_dim) 132 | self.mlp = MlpBlock(in_dim, mlp_dim, in_dim, dropout_rate) 133 | 134 | def forward(self, x): 135 | residual = x 136 | out = 
self.norm1(x) 137 | out = self.attn(out) 138 | if self.dropout: 139 | out = self.dropout(out) 140 | out += residual 141 | residual = out 142 | 143 | out = self.norm2(out) 144 | out = self.mlp(out) 145 | out += residual 146 | return out 147 | 148 | 149 | class Encoder(nn.Module): 150 | def __init__(self, 151 | num_patches, 152 | emb_dim, 153 | mlp_dim, 154 | num_layers=12, 155 | num_heads=12, 156 | dropout_rate=0.1, 157 | attn_dropout_rate=0.0): 158 | super(Encoder, self).__init__() 159 | 160 | # positional embedding 161 | self.pos_embedding = PositionEmbs(num_patches, emb_dim, dropout_rate) 162 | 163 | # encoder blocks 164 | in_dim = emb_dim 165 | self.encoder_layers = nn.ModuleList() 166 | for i in range(num_layers): 167 | layer = EncoderBlock(in_dim, mlp_dim, num_heads, dropout_rate, 168 | attn_dropout_rate) 169 | self.encoder_layers.append(layer) 170 | self.norm = nn.LayerNorm(in_dim) 171 | 172 | def forward(self, x): 173 | 174 | out = self.pos_embedding(x) 175 | 176 | for layer in self.encoder_layers: 177 | out = layer(out) 178 | 179 | out = self.norm(out) 180 | return out 181 | 182 | 183 | class VisionTransformer(nn.Module): 184 | """ Vision Transformer. 185 | Most easily loaded with the .from_name or .from_pretrained methods. 186 | Args: 187 | params (namedtuple): A set of Params. 188 | References: 189 | [1] https://arxiv.org/abs/2010.11929 (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale) 190 | Example: 191 | 192 | 193 | import torch 194 | >>> from vision_transformer_pytorch import VisionTransformer 195 | >>> inputs = torch.rand(1, 3, 256, 256) 196 | >>> model = VisionTransformer.from_pretrained('ViT-B_16') 197 | >>> model.eval() 198 | >>> outputs = model(inputs) 199 | """ 200 | def __init__(self, params=None): 201 | super(VisionTransformer, self).__init__() 202 | self._params = params 203 | 204 | if self._params.resnet: 205 | self.resnet = self._params.resnet() 206 | self.embedding = nn.Conv2d(self.resnet.width * 16, 207 | self._params.emb_dim, 208 | kernel_size=1, 209 | stride=1) 210 | else: 211 | self.embedding = nn.Conv2d(3, 212 | self._params.emb_dim, 213 | kernel_size=self.patch_size, 214 | stride=self.patch_size) 215 | # class token 216 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self._params.emb_dim)) 217 | 218 | # transformer 219 | self.transformer = Encoder( 220 | num_patches=self.num_patches, 221 | emb_dim=self._params.emb_dim, 222 | mlp_dim=self._params.mlp_dim, 223 | num_layers=self._params.num_layers, 224 | num_heads=self._params.num_heads, 225 | dropout_rate=self._params.dropout_rate, 226 | attn_dropout_rate=self._params.attn_dropout_rate) 227 | 228 | # classfier 229 | self.classifier = nn.Linear(self._params.emb_dim, 230 | self._params.num_classes) 231 | 232 | @property 233 | def image_size(self): 234 | return get_width_and_height_from_size(self._params.image_size) 235 | 236 | @property 237 | def patch_size(self): 238 | return get_width_and_height_from_size(self._params.patch_size) 239 | 240 | @property 241 | def num_patches(self): 242 | h, w = self.image_size 243 | fh, fw = self.patch_size 244 | if hasattr(self, 'resnet'): 245 | gh, gw = h // fh // self.resnet.downsample, w // fw // self.resnet.downsample 246 | else: 247 | gh, gw = h // fh, w // fw 248 | return gh * gw 249 | 250 | def extract_features(self, x): 251 | if hasattr(self, 'resnet'): 252 | x = self.resnet(x) 253 | 254 | emb = self.embedding(x) # (n, c, gh, gw) 255 | emb = emb.permute(0, 2, 3, 1) # (n, gh, hw, c) 256 | b, h, w, c = emb.shape 257 | emb = emb.reshape(b, h * 
w, c) 258 | 259 | # prepend class token 260 | cls_token = self.cls_token.repeat(b, 1, 1) 261 | emb = torch.cat([cls_token, emb], dim=1) 262 | 263 | # transformer 264 | feat = self.transformer(emb) 265 | return feat 266 | 267 | def forward(self, x): 268 | feat = self.extract_features(x) 269 | 270 | # classifier 271 | logits = self.classifier(feat[:, 0]) 272 | return logits 273 | 274 | @classmethod 275 | def from_name(cls, model_name, in_channels=3, **override_params): 276 | """create an vision transformer model according to name. 277 | Args: 278 | model_name (str): Name for vision transformer. 279 | in_channels (int): Input data's channel number. 280 | override_params (other key word params): 281 | Params to override model's global_params. 282 | Optional key: 283 | 'image_size', 'patch_size', 284 | 'emb_dim', 'mlp_dim', 285 | 'num_heads', 'num_layers', 286 | 'num_classes', 'attn_dropout_rate', 287 | 'dropout_rate' 288 | Returns: 289 | An vision transformer model. 290 | """ 291 | cls._check_model_name_is_valid(model_name) 292 | params = get_model_params(model_name, override_params) 293 | model = cls(params) 294 | model._change_in_channels(in_channels) 295 | return model 296 | 297 | @classmethod 298 | def from_pretrained(cls, 299 | model_name, 300 | weights_path=None, 301 | in_channels=3, 302 | num_classes=1000, 303 | **override_params): 304 | """create an vision transformer model according to name. 305 | Args: 306 | model_name (str): Name for vision transformer. 307 | weights_path (None or str): 308 | str: path to pretrained weights file on the local disk. 309 | None: use pretrained weights downloaded from the Internet. 310 | in_channels (int): Input data's channel number. 311 | num_classes (int): 312 | Number of categories for classification. 313 | It controls the output size for final linear layer. 314 | override_params (other key word params): 315 | Params to override model's global_params. 316 | Optional key: 317 | 'image_size', 'patch_size', 318 | 'emb_dim', 'mlp_dim', 319 | 'num_heads', 'num_layers', 320 | 'num_classes', 'attn_dropout_rate', 321 | 'dropout_rate' 322 | Returns: 323 | A pretrained vision transformer model. 324 | """ 325 | model = cls.from_name(model_name, 326 | num_classes=num_classes, 327 | **override_params) 328 | load_pretrained_weights(model, 329 | model_name, 330 | weights_path=weights_path, 331 | load_fc=(num_classes == 1000)) 332 | model._change_in_channels(in_channels) 333 | return model 334 | 335 | @classmethod 336 | def _check_model_name_is_valid(cls, model_name): 337 | """Validates model name. 338 | Args: 339 | model_name (str): Name for vision transformer. 340 | Returns: 341 | bool: Is a valid name or not. 342 | """ 343 | if model_name not in VALID_MODELS: 344 | raise ValueError('model_name should be one of: ' + 345 | ', '.join(VALID_MODELS)) 346 | 347 | def _change_in_channels(self, in_channels): 348 | """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. 349 | Args: 350 | in_channels (int): Input data's channel number. 
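        Note:
            For hybrid R50+ViT models this replaces the ResNet stem convolution;
            otherwise it rebuilds the patch-embedding convolution. Nothing changes
            when in_channels == 3.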
351 |         """
352 |         if in_channels != 3:
353 |             if hasattr(self, 'resnet'):
354 |                 self.resnet.root.conv = StdConv2d(in_channels,
355 |                                                   self.resnet.width,
356 |                                                   kernel_size=7,
357 |                                                   stride=2,
358 |                                                   bias=False,
359 |                                                   padding=3)
360 |             else:
361 |                 self.embedding = nn.Conv2d(in_channels,
362 |                                            self._params.emb_dim,
363 |                                            kernel_size=self.patch_size,
364 |                                            stride=self.patch_size)
365 | 
--------------------------------------------------------------------------------
/vision_transformer_pytorch/resnet.py:
--------------------------------------------------------------------------------
1 | # MODIFIED FROM
2 | # https://github.com/google-research/vision_transformer/blob/master/vit_jax/models_resnet.py
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | from os.path import join as pjoin
9 | from collections import OrderedDict
10 | 
11 | 
12 | def weight_standardize(w, dim, eps):
13 |     """Subtracts mean and divides by standard deviation."""
14 |     w = w - torch.mean(w, dim=dim)
15 |     w = w / (torch.std(w, dim=dim) + eps)
16 |     return w
17 | 
18 | 
19 | def np2th(weights, conv=False):
20 |     """Possibly convert HWIO to OIHW."""
21 |     if conv:
22 |         weights = weights.transpose([3, 2, 0, 1])
23 |     return torch.from_numpy(weights)
24 | 
25 | 
26 | class StdConv2d(nn.Conv2d):
27 |     def forward(self, x):
28 |         w = weight_standardize(self.weight, [0, 1, 2], 1e-5)
29 |         return F.conv2d(x, w, self.bias, self.stride, self.padding,
30 |                         self.dilation, self.groups)
31 | 
32 | 
33 | def conv3x3(in_channels, out_channels, stride=1, groups=1, bias=False):
34 |     return StdConv2d(in_channels,
35 |                      out_channels,
36 |                      kernel_size=3,
37 |                      stride=stride,
38 |                      padding=1,
39 |                      bias=bias,
40 |                      groups=groups)
41 | 
42 | 
43 | def conv1x1(in_channels, out_channels, stride=1, bias=False):
44 |     return StdConv2d(in_channels,
45 |                      out_channels,
46 |                      kernel_size=1,
47 |                      stride=stride,
48 |                      padding=0,
49 |                      bias=bias)
50 | 
51 | 
52 | class PreActBottleneck(nn.Module):
53 |     """Pre-activation (v2) bottleneck block.
54 |     """
55 |     def __init__(self,
56 |                  in_channels,
57 |                  out_channels=None,
58 |                  mid_channels=None,
59 |                  stride=1):
60 |         super().__init__()
61 |         out_channels = out_channels or in_channels
62 |         mid_channels = mid_channels or out_channels // 4
63 | 
64 |         self.gn1 = nn.GroupNorm(32, mid_channels, eps=1e-6)
65 |         self.conv1 = conv1x1(in_channels, mid_channels, bias=False)
66 |         self.gn2 = nn.GroupNorm(32, mid_channels, eps=1e-6)
67 |         self.conv2 = conv3x3(mid_channels, mid_channels, stride,
68 |                              bias=False)  # Original code has it on conv1!!
69 |         self.gn3 = nn.GroupNorm(32, out_channels, eps=1e-6)
70 |         self.conv3 = conv1x1(mid_channels, out_channels, bias=False)
71 |         self.relu = nn.ReLU(inplace=True)
72 | 
73 |         if (stride != 1 or in_channels != out_channels):
74 |             # Projection also with pre-activation according to paper.
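            # In the ResNetV2 below, this projection path is taken only by the
            # first unit of each block: block1's unit1 widens the channels
            # (width -> 4 * width), and the first units of block2 and block3
            # additionally downsample with stride 2.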
75 |             self.downsample = conv1x1(in_channels,
76 |                                        out_channels,
77 |                                        stride,
78 |                                        bias=False)
79 |             self.gn_proj = nn.GroupNorm(out_channels, out_channels)
80 | 
81 |     def forward(self, x):
82 | 
83 |         # Residual branch
84 |         residual = x
85 |         if hasattr(self, 'downsample'):
86 |             residual = self.downsample(x)
87 |             residual = self.gn_proj(residual)
88 | 
89 |         # Unit's branch
90 |         y = self.relu(self.gn1(self.conv1(x)))
91 |         y = self.relu(self.gn2(self.conv2(y)))
92 |         y = self.gn3(self.conv3(y))
93 | 
94 |         y = self.relu(residual + y)
95 |         return y
96 | 
97 | 
98 | class ResNetV2(nn.Module):
99 |     """Implementation of Pre-activation (v2) ResNet model."""
100 |     def __init__(self, block_units, width_factor):
101 |         super().__init__()
102 |         width = int(64 * width_factor)
103 |         self.width = width
104 |         self.downsample = 16  # total stride: root conv, max pool, block2 and block3 each downsample by 2
105 | 
106 |         # The following will be unreadable if we split lines.
107 |         # pylint: disable=line-too-long
108 |         self.root = nn.Sequential(
109 |             OrderedDict([('conv',
110 |                           StdConv2d(3,
111 |                                     width,
112 |                                     kernel_size=7,
113 |                                     stride=2,
114 |                                     bias=False,
115 |                                     padding=3)),
116 |                          ('gn', nn.GroupNorm(32, width, eps=1e-6)),
117 |                          ('relu', nn.ReLU(inplace=True)),
118 |                          ('pool',
119 |                           nn.MaxPool2d(kernel_size=3, stride=2, padding=0))]))
120 | 
121 |         self.body = nn.Sequential(
122 |             OrderedDict([
123 |                 ('block1',
124 |                  nn.Sequential(
125 |                      OrderedDict([('unit1',
126 |                                    PreActBottleneck(in_channels=width,
127 |                                                     out_channels=width * 4,
128 |                                                     mid_channels=width))] +
129 |                                  [(f'unit{i:d}',
130 |                                    PreActBottleneck(in_channels=width * 4,
131 |                                                     out_channels=width * 4,
132 |                                                     mid_channels=width))
133 |                                   for i in range(2, block_units[0] + 1)], ))),
134 |                 ('block2',
135 |                  nn.Sequential(
136 |                      OrderedDict([('unit1',
137 |                                    PreActBottleneck(in_channels=width * 4,
138 |                                                     out_channels=width * 8,
139 |                                                     mid_channels=width * 2,
140 |                                                     stride=2))] +
141 |                                  [(f'unit{i:d}',
142 |                                    PreActBottleneck(in_channels=width * 8,
143 |                                                     out_channels=width * 8,
144 |                                                     mid_channels=width * 2))
145 |                                   for i in range(2, block_units[1] + 1)], ))),
146 |                 ('block3',
147 |                  nn.Sequential(
148 |                      OrderedDict([('unit1',
149 |                                    PreActBottleneck(in_channels=width * 8,
150 |                                                     out_channels=width * 16,
151 |                                                     mid_channels=width * 4,
152 |                                                     stride=2))] +
153 |                                  [(f'unit{i:d}',
154 |                                    PreActBottleneck(in_channels=width * 16,
155 |                                                     out_channels=width * 16,
156 |                                                     mid_channels=width * 4))
157 |                                   for i in range(2, block_units[2] + 1)], ))),
158 |             ]))
159 | 
160 |     def forward(self, x):
161 |         x = self.root(x)
162 |         x = self.body(x)
163 |         return x
164 | 
165 | 
166 | def resnet50():
167 |     return ResNetV2(block_units=(3, 4, 9), width_factor=1)
168 | 
--------------------------------------------------------------------------------
/vision_transformer_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | # MODIFIED FROM
2 | # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py
3 | 
4 | import re
5 | import math
6 | import torch
7 | import collections
8 | 
9 | from torch import nn
10 | from functools import partial
11 | from torch.utils import model_zoo
12 | from torch.nn import functional as F
13 | 
14 | from .resnet import resnet50
15 | 
16 | ################################################################################
17 | ### Helper functions for model architecture
18 | ################################################################################
19 | 
20 | # Params: namedtuple of model hyper-parameters
21 | # get_width_and_height_from_size: normalize an int or (H, W) size
22 | 
23 | # Parameters for the entire model (stem, all blocks, and head)
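# Field meanings, illustrated with the ViT-B_16 values from params_dict below:
#   image_size=384, patch_size=16, emb_dim=768, mlp_dim=3072, num_heads=12,
#   num_layers=12, num_classes=1000, attn_dropout_rate=0.0, dropout_rate=0.1,
#   resnet=None (set to the resnet50 constructor only for the R50+ViT-B_16 hybrid).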
24 | Params = collections.namedtuple('Params', [
25 |     'image_size', 'patch_size', 'emb_dim', 'mlp_dim', 'num_heads', 'num_layers',
26 |     'num_classes', 'attn_dropout_rate', 'dropout_rate', 'resnet'
27 | ])
28 | 
29 | # Set Params' defaults
30 | Params.__new__.__defaults__ = (None, ) * len(Params._fields)
31 | 
32 | 
33 | def get_width_and_height_from_size(x):
34 |     """Obtain height and width from x.
35 |     Args:
36 |         x (int, tuple or list): Data size.
37 |     Returns:
38 |         size: A tuple or list (H, W).
39 |     """
40 |     if isinstance(x, int):
41 |         return x, x
42 |     if isinstance(x, list) or isinstance(x, tuple):
43 |         return x
44 |     else:
45 |         raise TypeError()
46 | 
47 | 
48 | ################################################################################
49 | ### Helper functions for loading model params
50 | ################################################################################
51 | 
52 | # vision_transformer and get_model_params:
53 | #     Functions to build the Params namedtuple for a given model name
54 | # url_map: Dict of urls for pretrained weights
55 | # load_pretrained_weights: A function to load pretrained weights
56 | 
57 | 
58 | def vision_transformer(model_name):
59 |     """Create Params for a vision transformer model.
60 |     Args:
61 |         model_name (str): Model name to be queried.
62 |     Returns:
63 |         Params(params_dict[model_name])
64 |     """
65 | 
66 |     params_dict = {
67 |         'ViT-B_16': (384, 16, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
68 |         'ViT-B_32': (384, 32, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
69 |         'ViT-L_16': (384, 16, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
70 |         'ViT-L_32': (384, 32, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
71 |         'R50+ViT-B_16': (384, 1, 768, 3072, 12, 12, 1000, 0.0, 0.1, resnet50),
72 |     }
73 |     image_size, patch_size, emb_dim, mlp_dim, num_heads, num_layers, num_classes, attn_dropout_rate, dropout_rate, resnet = params_dict[
74 |         model_name]
75 |     params = Params(image_size=image_size,
76 |                     patch_size=patch_size,
77 |                     emb_dim=emb_dim,
78 |                     mlp_dim=mlp_dim,
79 |                     num_heads=num_heads,
80 |                     num_layers=num_layers,
81 |                     num_classes=num_classes,
82 |                     attn_dropout_rate=attn_dropout_rate,
83 |                     dropout_rate=dropout_rate,
84 |                     resnet=resnet)
85 | 
86 |     return params
87 | 
88 | 
89 | def get_model_params(model_name, override_params):
90 |     """Get the params for a given model name.
91 |     Args:
92 |         model_name (str): Model's name.
93 |         override_params (dict): A dict to modify params.
94 |     Returns:
95 |         params
96 |     """
97 |     params = vision_transformer(model_name)
98 | 
99 |     if override_params:
100 |         # ValueError will be raised here if override_params has fields not included in params.
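        # Params is a namedtuple, so _replace returns a new Params with the
        # given fields overridden, e.g. (hypothetical values):
        #     get_model_params('ViT-B_16', {'image_size': 224, 'num_classes': 10})
        # A key that is not a Params field raises the ValueError mentioned above.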
101 |         params = params._replace(**override_params)
102 |     return params
103 | 
104 | 
105 | # The released weights were pretrained on ImageNet-21k and fine-tuned on ImageNet-2012 with standard methods;
106 | # see the paper (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale) for details.
107 | url_map = {
108 |     'ViT-B_16':
109 |     'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-B_16_imagenet21k_imagenet2012.pth',
110 |     'ViT-B_32':
111 |     'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-B_32_imagenet21k_imagenet2012.pth',
112 |     'ViT-L_16':
113 |     'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-L_16_imagenet21k_imagenet2012.pth',
114 |     'ViT-L_32':
115 |     'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-L_32_imagenet21k_imagenet2012.pth',
116 |     'R50+ViT-B_16':
117 |     'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/R50+ViT-B_16_imagenet21k_imagenet2012.pth',
118 | }
119 | 
120 | 
121 | def load_pretrained_weights(model,
122 |                             model_name,
123 |                             weights_path=None,
124 |                             load_fc=True,
125 |                             advprop=False):
126 |     """Loads pretrained weights from weights_path or downloads them using url_map.
127 |     Args:
128 |         model (Module): The whole model of vision transformer.
129 |         model_name (str): Model name of vision transformer.
130 |         weights_path (None or str):
131 |             str: path to pretrained weights file on the local disk.
132 |             None: use pretrained weights downloaded from the Internet.
133 |         load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
134 |     """
135 |     if isinstance(weights_path, str):
136 |         state_dict = torch.load(weights_path)
137 |     else:
138 |         state_dict = model_zoo.load_url(url_map[model_name])
139 | 
140 |     if load_fc:
141 |         ret = model.load_state_dict(state_dict, strict=False)
142 |         assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(
143 |             ret.missing_keys)
144 |     else:
145 |         state_dict.pop('classifier.weight')
146 |         state_dict.pop('classifier.bias')
147 |         ret = model.load_state_dict(state_dict, strict=False)
148 |         assert set(ret.missing_keys) == set([
149 |             'classifier.weight', 'classifier.bias'
150 |         ]), 'Missing keys when loading pretrained weights: {}'.format(
151 |             ret.missing_keys)
152 |     assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(
153 |         ret.unexpected_keys)
154 | 
155 |     print('Loaded pretrained weights for {}'.format(model_name))
--------------------------------------------------------------------------------
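
Taken together, utils.py and model.py give the following end-to-end flow. The snippet below is a minimal usage sketch based only on the APIs shown above (VisionTransformer.from_name / from_pretrained and the 384x384 input size from params_dict); the 10-class head is an illustrative assumption, and downloading the released weights requires network access.

import torch
from vision_transformer_pytorch import VisionTransformer

# Randomly initialized model built from a config name (no download).
scratch_model = VisionTransformer.from_name('ViT-B_16')

# Pretrained backbone with a fresh 10-class head: because num_classes != 1000,
# load_pretrained_weights() is called with load_fc=False and skips the
# classifier weights, so only the backbone is restored.
model = VisionTransformer.from_pretrained('ViT-B_16', num_classes=10)
model.eval()

with torch.no_grad():
    images = torch.rand(2, 3, 384, 384)  # ViT-B_16 checkpoints expect 384x384 inputs
    logits = model(images)               # shape: (2, 10)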