├── .gitignore ├── LICENSE ├── README.md ├── img ├── figure1.png ├── figure2.png └── figure3.png ├── models ├── configs.py ├── modeling.py └── modeling_resnet.py ├── requirements.txt ├── train.py ├── utils ├── data_utils.py ├── dist_util.py └── scheduler.py └── visualize_attention_map.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | !/checkpoint/ 143 | !/data/ 144 | !/logs/ 145 | !/output/ 146 | .idea 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jeonsworld 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision Transformer 2 | PyTorch reimplementation of [Google's repository for the ViT model](https://github.com/google-research/vision_transformer) that was released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 3 | 4 | This paper shows that Transformers applied directly to image patches and pre-trained on large datasets work really well on image recognition tasks. 5 | 6 | ![fig1](./img/figure1.png) 7 | 8 | Vision Transformer achieves state-of-the-art results on image recognition tasks with a standard Transformer encoder and fixed-size patches. To perform classification, the authors use the standard approach of adding an extra learnable "classification token" to the sequence. 9 | 10 | ![fig2](./img/figure2.png) 11 | 12 | 13 | ## Usage 14 | ### 1.
Download Pre-trained model (Google's Official Checkpoint) 15 | * [Available models](https://console.cloud.google.com/storage/vit_models/): ViT-B_16(**85.8M**), R50+ViT-B_16(**97.96M**), ViT-B_32(**87.5M**), ViT-L_16(**303.4M**), ViT-L_32(**305.5M**), ViT-H_14(**630.8M**) 16 | * imagenet21k pre-train models 17 | * ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14 18 | * imagenet21k pre-train + imagenet2012 fine-tuned models 19 | * ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32 20 | * Hybrid Model([Resnet50](https://github.com/google-research/big_transfer) + Transformer) 21 | * R50-ViT-B_16 22 | ``` 23 | # imagenet21k pre-train 24 | wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz 25 | 26 | # imagenet21k pre-train + imagenet2012 fine-tuning 27 | wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz 28 | 29 | ``` 30 | 31 | ### 2. Train Model 32 | ``` 33 | python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz 34 | ``` 35 | CIFAR-10 and CIFAR-100 are downloaded and trained on automatically. To use a different dataset, you need to customize [data_utils.py](./utils/data_utils.py). 36 | 37 | The default batch size is 512. If GPU memory is insufficient, you can still train by increasing the value of `--gradient_accumulation_steps`. 38 | 39 | You can also use [Automatic Mixed Precision (Amp)](https://nvidia.github.io/apex/amp.html) to reduce memory usage and train faster: 40 | ``` 41 | python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --fp16 --fp16_opt_level O2 42 | ``` 43 | 44 | 45 | 46 | ## Results 47 | To verify that the converted model weights are correct, we simply compare our results with the authors' reported results. We trained using mixed precision, with `--fp16_opt_level` set to O2. 48 | 49 | ### imagenet-21k 50 | * [**tensorboard**](https://tensorboard.dev/experiment/Oz9GmmQIQCOEr4xbdr8O3Q) 51 | 52 | | model | dataset | resolution | acc(official) | acc(this repo) | time | 53 | |:------------:|:---------:|:----------:|:-------------:|:--------------:|:-------:| 54 | | ViT-B_16 | CIFAR-10 | 224x224 | - | 0.9908 | 3h 13m | 55 | | ViT-B_16 | CIFAR-10 | 384x384 | 0.9903 | 0.9906 | 12h 25m | 56 | | ViT-B_16 | CIFAR-100 | 224x224 | - | 0.923 | 3h 9m | 57 | | ViT-B_16 | CIFAR-100 | 384x384 | 0.9264 | 0.9228 | 12h 31m | 58 | | R50-ViT-B_16 | CIFAR-10 | 224x224 | - | 0.9892 | 4h 23m | 59 | | R50-ViT-B_16 | CIFAR-10 | 384x384 | 0.99 | 0.9904 | 15h 40m | 60 | | R50-ViT-B_16 | CIFAR-100 | 224x224 | - | 0.9231 | 4h 18m | 61 | | R50-ViT-B_16 | CIFAR-100 | 384x384 | 0.9231 | 0.9197 | 15h 53m | 62 | | ViT-L_32 | CIFAR-10 | 224x224 | - | 0.9903 | 2h 11m | 63 | | ViT-L_32 | CIFAR-100 | 224x224 | - | 0.9276 | 2h 9m | 64 | | ViT-H_14 | CIFAR-100 | 224x224 | - | WIP | | 65 | 66 | 67 | ### imagenet-21k + imagenet2012 68 | * [**tensorboard**](https://tensorboard.dev/experiment/CXOzjFRqTM6aLCk0jNXgAw/#scalars) 69 | 70 | | model | dataset | resolution | acc | 71 | |:------------:|:---------:|:----------:|:------:| 72 | | ViT-B_16-224 | CIFAR-10 | 224x224 | 0.99 | 73 | | ViT-B_16-224 | CIFAR-100 | 224x224 | 0.9245 | 74 | | ViT-L_32 | CIFAR-10 | 224x224 | 0.9903 | 75 | | ViT-L_32 | CIFAR-100 | 224x224 | 0.9285 | 76 | 77 | 78 | ### shorter train 79 | * In the experiments below, we used a resolution of 224x224.
80 | * [**tensorboard**](https://tensorboard.dev/experiment/lpknnMpHRT2qpVrSZi10Ag/#scalars) 81 | 82 | | upstream | model | dataset | total_steps /warmup_steps | acc(official) | acc(this repo) | 83 | |:-----------:|:--------:|:---------:|:-------------------------:|:-------------:|:--------------:| 84 | | imagenet21k | ViT-B_16 | CIFAR-10 | 500/100 | 0.9859 | 0.9859 | 85 | | imagenet21k | ViT-B_16 | CIFAR-10 | 1000/100 | 0.9886 | 0.9878 | 86 | | imagenet21k | ViT-B_16 | CIFAR-100 | 500/100 | 0.8917 | 0.9072 | 87 | | imagenet21k | ViT-B_16 | CIFAR-100 | 1000/100 | 0.9115 | 0.9216 | 88 | 89 | 90 | ## Visualization 91 | The ViT consists of a Standard Transformer Encoder, and the encoder consists of Self-Attention and MLP module. 92 | The attention map for the input image can be visualized through the attention score of self-attention. 93 | 94 | Visualization code can be found at [visualize_attention_map](./visualize_attention_map.ipynb). 95 | 96 | ![fig3](./img/figure3.png) 97 | 98 | 99 | ## Reference 100 | * [Google ViT](https://github.com/google-research/vision_transformer) 101 | * [Pytorch Image Models(timm)](https://github.com/rwightman/pytorch-image-models) 102 | 103 | 104 | ## Citations 105 | 106 | ```bibtex 107 | @article{dosovitskiy2020, 108 | title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, 109 | author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil}, 110 | journal={arXiv preprint arXiv:2010.11929}, 111 | year={2020} 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /img/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeonsworld/ViT-pytorch/460a162767de1722a014ed2261463dbbc01196b6/img/figure1.png -------------------------------------------------------------------------------- /img/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeonsworld/ViT-pytorch/460a162767de1722a014ed2261463dbbc01196b6/img/figure2.png -------------------------------------------------------------------------------- /img/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeonsworld/ViT-pytorch/460a162767de1722a014ed2261463dbbc01196b6/img/figure3.png -------------------------------------------------------------------------------- /models/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
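# Model hyperparameters (patch size, hidden size, MLP dim, attention heads, depth) for each ViT variant; these ConfigDicts are consumed through the CONFIGS mapping in models/modeling.py.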
14 | 15 | import ml_collections 16 | 17 | 18 | def get_testing(): 19 | """Returns a minimal configuration for testing.""" 20 | config = ml_collections.ConfigDict() 21 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 22 | config.hidden_size = 1 23 | config.transformer = ml_collections.ConfigDict() 24 | config.transformer.mlp_dim = 1 25 | config.transformer.num_heads = 1 26 | config.transformer.num_layers = 1 27 | config.transformer.attention_dropout_rate = 0.0 28 | config.transformer.dropout_rate = 0.1 29 | config.classifier = 'token' 30 | config.representation_size = None 31 | return config 32 | 33 | 34 | def get_b16_config(): 35 | """Returns the ViT-B/16 configuration.""" 36 | config = ml_collections.ConfigDict() 37 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 38 | config.hidden_size = 768 39 | config.transformer = ml_collections.ConfigDict() 40 | config.transformer.mlp_dim = 3072 41 | config.transformer.num_heads = 12 42 | config.transformer.num_layers = 12 43 | config.transformer.attention_dropout_rate = 0.0 44 | config.transformer.dropout_rate = 0.1 45 | config.classifier = 'token' 46 | config.representation_size = None 47 | return config 48 | 49 | 50 | def get_r50_b16_config(): 51 | """Returns the Resnet50 + ViT-B/16 configuration.""" 52 | config = get_b16_config() 53 | del config.patches.size 54 | config.patches.grid = (14, 14) 55 | config.resnet = ml_collections.ConfigDict() 56 | config.resnet.num_layers = (3, 4, 9) 57 | config.resnet.width_factor = 1 58 | return config 59 | 60 | 61 | def get_b32_config(): 62 | """Returns the ViT-B/32 configuration.""" 63 | config = get_b16_config() 64 | config.patches.size = (32, 32) 65 | return config 66 | 67 | 68 | def get_l16_config(): 69 | """Returns the ViT-L/16 configuration.""" 70 | config = ml_collections.ConfigDict() 71 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 72 | config.hidden_size = 1024 73 | config.transformer = ml_collections.ConfigDict() 74 | config.transformer.mlp_dim = 4096 75 | config.transformer.num_heads = 16 76 | config.transformer.num_layers = 24 77 | config.transformer.attention_dropout_rate = 0.0 78 | config.transformer.dropout_rate = 0.1 79 | config.classifier = 'token' 80 | config.representation_size = None 81 | return config 82 | 83 | 84 | def get_l32_config(): 85 | """Returns the ViT-L/32 configuration.""" 86 | config = get_l16_config() 87 | config.patches.size = (32, 32) 88 | return config 89 | 90 | 91 | def get_h14_config(): 92 | """Returns the ViT-H/14 configuration.""" 93 | config = ml_collections.ConfigDict() 94 | config.patches = ml_collections.ConfigDict({'size': (14, 14)}) 95 | config.hidden_size = 1280 96 | config.transformer = ml_collections.ConfigDict() 97 | config.transformer.mlp_dim = 5120 98 | config.transformer.num_heads = 16 99 | config.transformer.num_layers = 32 100 | config.transformer.attention_dropout_rate = 0.0 101 | config.transformer.dropout_rate = 0.1 102 | config.classifier = 'token' 103 | config.representation_size = None 104 | return config 105 | -------------------------------------------------------------------------------- /models/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import copy 7 | import logging 8 | import math 9 | 10 | from os.path import join as pjoin 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15
| 16 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 17 | from torch.nn.modules.utils import _pair 18 | from scipy import ndimage 19 | 20 | import models.configs as configs 21 | 22 | from .modeling_resnet import ResNetV2 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 29 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 30 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 31 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 32 | FC_0 = "MlpBlock_3/Dense_0" 33 | FC_1 = "MlpBlock_3/Dense_1" 34 | ATTENTION_NORM = "LayerNorm_0" 35 | MLP_NORM = "LayerNorm_2" 36 | 37 | 38 | def np2th(weights, conv=False): 39 | """Possibly convert HWIO to OIHW.""" 40 | if conv: 41 | weights = weights.transpose([3, 2, 0, 1]) 42 | return torch.from_numpy(weights) 43 | 44 | 45 | def swish(x): 46 | return x * torch.sigmoid(x) 47 | 48 | 49 | ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, config, vis): 54 | super(Attention, self).__init__() 55 | self.vis = vis 56 | self.num_attention_heads = config.transformer["num_heads"] 57 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 58 | self.all_head_size = self.num_attention_heads * self.attention_head_size 59 | 60 | self.query = Linear(config.hidden_size, self.all_head_size) 61 | self.key = Linear(config.hidden_size, self.all_head_size) 62 | self.value = Linear(config.hidden_size, self.all_head_size) 63 | 64 | self.out = Linear(config.hidden_size, config.hidden_size) 65 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 66 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 67 | 68 | self.softmax = Softmax(dim=-1) 69 | 70 | def transpose_for_scores(self, x): 71 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 72 | x = x.view(*new_x_shape) 73 | return x.permute(0, 2, 1, 3) 74 | 75 | def forward(self, hidden_states): 76 | mixed_query_layer = self.query(hidden_states) 77 | mixed_key_layer = self.key(hidden_states) 78 | mixed_value_layer = self.value(hidden_states) 79 | 80 | query_layer = self.transpose_for_scores(mixed_query_layer) 81 | key_layer = self.transpose_for_scores(mixed_key_layer) 82 | value_layer = self.transpose_for_scores(mixed_value_layer) 83 | 84 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 85 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 86 | attention_probs = self.softmax(attention_scores) 87 | weights = attention_probs if self.vis else None 88 | attention_probs = self.attn_dropout(attention_probs) 89 | 90 | context_layer = torch.matmul(attention_probs, value_layer) 91 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 92 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 93 | context_layer = context_layer.view(*new_context_layer_shape) 94 | attention_output = self.out(context_layer) 95 | attention_output = self.proj_dropout(attention_output) 96 | return attention_output, weights 97 | 98 | 99 | class Mlp(nn.Module): 100 | def __init__(self, config): 101 | super(Mlp, self).__init__() 102 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 103 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 104 | self.act_fn = ACT2FN["gelu"] 105 | self.dropout = 
Dropout(config.transformer["dropout_rate"]) 106 | 107 | self._init_weights() 108 | 109 | def _init_weights(self): 110 | nn.init.xavier_uniform_(self.fc1.weight) 111 | nn.init.xavier_uniform_(self.fc2.weight) 112 | nn.init.normal_(self.fc1.bias, std=1e-6) 113 | nn.init.normal_(self.fc2.bias, std=1e-6) 114 | 115 | def forward(self, x): 116 | x = self.fc1(x) 117 | x = self.act_fn(x) 118 | x = self.dropout(x) 119 | x = self.fc2(x) 120 | x = self.dropout(x) 121 | return x 122 | 123 | 124 | class Embeddings(nn.Module): 125 | """Construct the embeddings from patch, position embeddings. 126 | """ 127 | def __init__(self, config, img_size, in_channels=3): 128 | super(Embeddings, self).__init__() 129 | self.hybrid = None 130 | img_size = _pair(img_size) 131 | 132 | if config.patches.get("grid") is not None: 133 | grid_size = config.patches["grid"] 134 | patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) 135 | n_patches = (img_size[0] // 16) * (img_size[1] // 16) 136 | self.hybrid = True 137 | else: 138 | patch_size = _pair(config.patches["size"]) 139 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 140 | self.hybrid = False 141 | 142 | if self.hybrid: 143 | self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, 144 | width_factor=config.resnet.width_factor) 145 | in_channels = self.hybrid_model.width * 16 146 | self.patch_embeddings = Conv2d(in_channels=in_channels, 147 | out_channels=config.hidden_size, 148 | kernel_size=patch_size, 149 | stride=patch_size) 150 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size)) 151 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 152 | 153 | self.dropout = Dropout(config.transformer["dropout_rate"]) 154 | 155 | def forward(self, x): 156 | B = x.shape[0] 157 | cls_tokens = self.cls_token.expand(B, -1, -1) 158 | 159 | if self.hybrid: 160 | x = self.hybrid_model(x) 161 | x = self.patch_embeddings(x) 162 | x = x.flatten(2) 163 | x = x.transpose(-1, -2) 164 | x = torch.cat((cls_tokens, x), dim=1) 165 | 166 | embeddings = x + self.position_embeddings 167 | embeddings = self.dropout(embeddings) 168 | return embeddings 169 | 170 | 171 | class Block(nn.Module): 172 | def __init__(self, config, vis): 173 | super(Block, self).__init__() 174 | self.hidden_size = config.hidden_size 175 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 176 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 177 | self.ffn = Mlp(config) 178 | self.attn = Attention(config, vis) 179 | 180 | def forward(self, x): 181 | h = x 182 | x = self.attention_norm(x) 183 | x, weights = self.attn(x) 184 | x = x + h 185 | 186 | h = x 187 | x = self.ffn_norm(x) 188 | x = self.ffn(x) 189 | x = x + h 190 | return x, weights 191 | 192 | def load_from(self, weights, n_block): 193 | ROOT = f"Transformer/encoderblock_{n_block}" 194 | with torch.no_grad(): 195 | query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() 196 | key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() 197 | value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() 198 | out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() 199 | 200 | query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) 201 | key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, 
"bias")]).view(-1) 202 | value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) 203 | out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) 204 | 205 | self.attn.query.weight.copy_(query_weight) 206 | self.attn.key.weight.copy_(key_weight) 207 | self.attn.value.weight.copy_(value_weight) 208 | self.attn.out.weight.copy_(out_weight) 209 | self.attn.query.bias.copy_(query_bias) 210 | self.attn.key.bias.copy_(key_bias) 211 | self.attn.value.bias.copy_(value_bias) 212 | self.attn.out.bias.copy_(out_bias) 213 | 214 | mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() 215 | mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() 216 | mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() 217 | mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() 218 | 219 | self.ffn.fc1.weight.copy_(mlp_weight_0) 220 | self.ffn.fc2.weight.copy_(mlp_weight_1) 221 | self.ffn.fc1.bias.copy_(mlp_bias_0) 222 | self.ffn.fc2.bias.copy_(mlp_bias_1) 223 | 224 | self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) 225 | self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) 226 | self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) 227 | self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) 228 | 229 | 230 | class Encoder(nn.Module): 231 | def __init__(self, config, vis): 232 | super(Encoder, self).__init__() 233 | self.vis = vis 234 | self.layer = nn.ModuleList() 235 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 236 | for _ in range(config.transformer["num_layers"]): 237 | layer = Block(config, vis) 238 | self.layer.append(copy.deepcopy(layer)) 239 | 240 | def forward(self, hidden_states): 241 | attn_weights = [] 242 | for layer_block in self.layer: 243 | hidden_states, weights = layer_block(hidden_states) 244 | if self.vis: 245 | attn_weights.append(weights) 246 | encoded = self.encoder_norm(hidden_states) 247 | return encoded, attn_weights 248 | 249 | 250 | class Transformer(nn.Module): 251 | def __init__(self, config, img_size, vis): 252 | super(Transformer, self).__init__() 253 | self.embeddings = Embeddings(config, img_size=img_size) 254 | self.encoder = Encoder(config, vis) 255 | 256 | def forward(self, input_ids): 257 | embedding_output = self.embeddings(input_ids) 258 | encoded, attn_weights = self.encoder(embedding_output) 259 | return encoded, attn_weights 260 | 261 | 262 | class VisionTransformer(nn.Module): 263 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 264 | super(VisionTransformer, self).__init__() 265 | self.num_classes = num_classes 266 | self.zero_head = zero_head 267 | self.classifier = config.classifier 268 | 269 | self.transformer = Transformer(config, img_size, vis) 270 | self.head = Linear(config.hidden_size, num_classes) 271 | 272 | def forward(self, x, labels=None): 273 | x, attn_weights = self.transformer(x) 274 | logits = self.head(x[:, 0]) 275 | 276 | if labels is not None: 277 | loss_fct = CrossEntropyLoss() 278 | loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) 279 | return loss 280 | else: 281 | return logits, attn_weights 282 | 283 | def load_from(self, weights): 284 | with torch.no_grad(): 285 | if self.zero_head: 286 | nn.init.zeros_(self.head.weight) 287 | nn.init.zeros_(self.head.bias) 288 | else: 289 | self.head.weight.copy_(np2th(weights["head/kernel"]).t()) 290 | self.head.bias.copy_(np2th(weights["head/bias"]).t()) 291 | 292 | 
self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) 293 | self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) 294 | self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"])) 295 | self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) 296 | self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) 297 | 298 | posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) 299 | posemb_new = self.transformer.embeddings.position_embeddings 300 | if posemb.size() == posemb_new.size(): 301 | self.transformer.embeddings.position_embeddings.copy_(posemb) 302 | else: 303 | logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) 304 | ntok_new = posemb_new.size(1) 305 | 306 | if self.classifier == "token": 307 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 308 | ntok_new -= 1 309 | else: 310 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 311 | 312 | gs_old = int(np.sqrt(len(posemb_grid))) 313 | gs_new = int(np.sqrt(ntok_new)) 314 | print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) 315 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 316 | 317 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 318 | posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) 319 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 320 | posemb = np.concatenate([posemb_tok, posemb_grid], axis=1) 321 | self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) 322 | 323 | for bname, block in self.transformer.encoder.named_children(): 324 | for uname, unit in block.named_children(): 325 | unit.load_from(weights, n_block=uname) 326 | 327 | if self.transformer.embeddings.hybrid: 328 | self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True)) 329 | gn_weight = np2th(weights["gn_root/scale"]).view(-1) 330 | gn_bias = np2th(weights["gn_root/bias"]).view(-1) 331 | self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) 332 | self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) 333 | 334 | for bname, block in self.transformer.embeddings.hybrid_model.body.named_children(): 335 | for uname, unit in block.named_children(): 336 | unit.load_from(weights, n_block=bname, n_unit=uname) 337 | 338 | 339 | CONFIGS = { 340 | 'ViT-B_16': configs.get_b16_config(), 341 | 'ViT-B_32': configs.get_b32_config(), 342 | 'ViT-L_16': configs.get_l16_config(), 343 | 'ViT-L_32': configs.get_l32_config(), 344 | 'ViT-H_14': configs.get_h14_config(), 345 | 'R50-ViT-B_16': configs.get_r50_b16_config(), 346 | 'testing': configs.get_testing(), 347 | } 348 | -------------------------------------------------------------------------------- /models/modeling_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" 17 | import math 18 | 19 | from os.path import join as pjoin 20 | 21 | from collections import OrderedDict # pylint: disable=g-importing-member 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | 28 | def np2th(weights, conv=False): 29 | """Possibly convert HWIO to OIHW.""" 30 | if conv: 31 | weights = weights.transpose([3, 2, 0, 1]) 32 | return torch.from_numpy(weights) 33 | 34 | 35 | class StdConv2d(nn.Conv2d): 36 | 37 | def forward(self, x): 38 | w = self.weight 39 | v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) 40 | w = (w - m) / torch.sqrt(v + 1e-5) 41 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 42 | self.dilation, self.groups) 43 | 44 | 45 | def conv3x3(cin, cout, stride=1, groups=1, bias=False): 46 | return StdConv2d(cin, cout, kernel_size=3, stride=stride, 47 | padding=1, bias=bias, groups=groups) 48 | 49 | 50 | def conv1x1(cin, cout, stride=1, bias=False): 51 | return StdConv2d(cin, cout, kernel_size=1, stride=stride, 52 | padding=0, bias=bias) 53 | 54 | 55 | class PreActBottleneck(nn.Module): 56 | """Pre-activation (v2) bottleneck block. 57 | """ 58 | 59 | def __init__(self, cin, cout=None, cmid=None, stride=1): 60 | super().__init__() 61 | cout = cout or cin 62 | cmid = cmid or cout//4 63 | 64 | self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) 65 | self.conv1 = conv1x1(cin, cmid, bias=False) 66 | self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) 67 | self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! 68 | self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) 69 | self.conv3 = conv1x1(cmid, cout, bias=False) 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | if (stride != 1 or cin != cout): 73 | # Projection also with pre-activation according to paper. 
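# Note: the GroupNorm below uses num_groups == cout, i.e. one group per channel (instance-style normalization) on the projection shortcut, unlike the 32-group norms used on the main branch.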
74 | self.downsample = conv1x1(cin, cout, stride, bias=False) 75 | self.gn_proj = nn.GroupNorm(cout, cout) 76 | 77 | def forward(self, x): 78 | 79 | # Residual branch 80 | residual = x 81 | if hasattr(self, 'downsample'): 82 | residual = self.downsample(x) 83 | residual = self.gn_proj(residual) 84 | 85 | # Unit's branch 86 | y = self.relu(self.gn1(self.conv1(x))) 87 | y = self.relu(self.gn2(self.conv2(y))) 88 | y = self.gn3(self.conv3(y)) 89 | 90 | y = self.relu(residual + y) 91 | return y 92 | 93 | def load_from(self, weights, n_block, n_unit): 94 | conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) 95 | conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) 96 | conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) 97 | 98 | gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) 99 | gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) 100 | 101 | gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) 102 | gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) 103 | 104 | gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) 105 | gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) 106 | 107 | self.conv1.weight.copy_(conv1_weight) 108 | self.conv2.weight.copy_(conv2_weight) 109 | self.conv3.weight.copy_(conv3_weight) 110 | 111 | self.gn1.weight.copy_(gn1_weight.view(-1)) 112 | self.gn1.bias.copy_(gn1_bias.view(-1)) 113 | 114 | self.gn2.weight.copy_(gn2_weight.view(-1)) 115 | self.gn2.bias.copy_(gn2_bias.view(-1)) 116 | 117 | self.gn3.weight.copy_(gn3_weight.view(-1)) 118 | self.gn3.bias.copy_(gn3_bias.view(-1)) 119 | 120 | if hasattr(self, 'downsample'): 121 | proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) 122 | proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) 123 | proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) 124 | 125 | self.downsample.weight.copy_(proj_conv_weight) 126 | self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) 127 | self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) 128 | 129 | class ResNetV2(nn.Module): 130 | """Implementation of Pre-activation (v2) ResNet mode.""" 131 | 132 | def __init__(self, block_units, width_factor): 133 | super().__init__() 134 | width = int(64 * width_factor) 135 | self.width = width 136 | 137 | # The following will be unreadable if we split lines. 
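# Stem: 7x7 stride-2 conv followed by a 3x3 stride-2 max-pool; body: three bottleneck stages with output widths 4x, 8x and 16x the base width, downsampling by stride 2 at the start of block2 and block3.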
138 | # pylint: disable=line-too-long 139 | self.root = nn.Sequential(OrderedDict([ 140 | ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), 141 | ('gn', nn.GroupNorm(32, width, eps=1e-6)), 142 | ('relu', nn.ReLU(inplace=True)), 143 | ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) 144 | ])) 145 | 146 | self.body = nn.Sequential(OrderedDict([ 147 | ('block1', nn.Sequential(OrderedDict( 148 | [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + 149 | [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], 150 | ))), 151 | ('block2', nn.Sequential(OrderedDict( 152 | [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + 153 | [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], 154 | ))), 155 | ('block3', nn.Sequential(OrderedDict( 156 | [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + 157 | [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], 158 | ))), 159 | ])) 160 | 161 | def forward(self, x): 162 | x = self.root(x) 163 | x = self.body(x) 164 | return x 165 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | tqdm 4 | tensorboard 5 | ml-collections 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import logging 5 | import argparse 6 | import os 7 | import random 8 | import numpy as np 9 | 10 | from datetime import timedelta 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | from tqdm import tqdm 16 | from torch.utils.tensorboard import SummaryWriter 17 | from apex import amp 18 | from apex.parallel import DistributedDataParallel as DDP 19 | 20 | from models.modeling import VisionTransformer, CONFIGS 21 | from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule 22 | from utils.data_utils import get_loader 23 | from utils.dist_util import get_world_size 24 | 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class AverageMeter(object): 30 | """Computes and stores the average and current value""" 31 | def __init__(self): 32 | self.reset() 33 | 34 | def reset(self): 35 | self.val = 0 36 | self.avg = 0 37 | self.sum = 0 38 | self.count = 0 39 | 40 | def update(self, val, n=1): 41 | self.val = val 42 | self.sum += val * n 43 | self.count += n 44 | self.avg = self.sum / self.count 45 | 46 | 47 | def simple_accuracy(preds, labels): 48 | return (preds == labels).mean() 49 | 50 | 51 | def save_model(args, model): 52 | model_to_save = model.module if hasattr(model, 'module') else model 53 | model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) 54 | torch.save(model_to_save.state_dict(), model_checkpoint) 55 | logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) 56 | 57 | 58 | def setup(args): 59 | # Prepare model 60 | config = CONFIGS[args.model_type] 61 | 62 | num_classes = 10 if args.dataset == "cifar10" else 100 63 | 64 | model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes) 65 | 
model.load_from(np.load(args.pretrained_dir)) 66 | model.to(args.device) 67 | num_params = count_parameters(model) 68 | 69 | logger.info("{}".format(config)) 70 | logger.info("Training parameters %s", args) 71 | logger.info("Total Parameter: \t%2.1fM" % num_params) 72 | print(num_params) 73 | return args, model 74 | 75 | 76 | def count_parameters(model): 77 | params = sum(p.numel() for p in model.parameters() if p.requires_grad) 78 | return params/1000000 79 | 80 | 81 | def set_seed(args): 82 | random.seed(args.seed) 83 | np.random.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | if args.n_gpu > 0: 86 | torch.cuda.manual_seed_all(args.seed) 87 | 88 | 89 | def valid(args, model, writer, test_loader, global_step): 90 | # Validation! 91 | eval_losses = AverageMeter() 92 | 93 | logger.info("***** Running Validation *****") 94 | logger.info(" Num steps = %d", len(test_loader)) 95 | logger.info(" Batch size = %d", args.eval_batch_size) 96 | 97 | model.eval() 98 | all_preds, all_label = [], [] 99 | epoch_iterator = tqdm(test_loader, 100 | desc="Validating... (loss=X.X)", 101 | bar_format="{l_bar}{r_bar}", 102 | dynamic_ncols=True, 103 | disable=args.local_rank not in [-1, 0]) 104 | loss_fct = torch.nn.CrossEntropyLoss() 105 | for step, batch in enumerate(epoch_iterator): 106 | batch = tuple(t.to(args.device) for t in batch) 107 | x, y = batch 108 | with torch.no_grad(): 109 | logits = model(x)[0] 110 | 111 | eval_loss = loss_fct(logits, y) 112 | eval_losses.update(eval_loss.item()) 113 | 114 | preds = torch.argmax(logits, dim=-1) 115 | 116 | if len(all_preds) == 0: 117 | all_preds.append(preds.detach().cpu().numpy()) 118 | all_label.append(y.detach().cpu().numpy()) 119 | else: 120 | all_preds[0] = np.append( 121 | all_preds[0], preds.detach().cpu().numpy(), axis=0 122 | ) 123 | all_label[0] = np.append( 124 | all_label[0], y.detach().cpu().numpy(), axis=0 125 | ) 126 | epoch_iterator.set_description("Validating... 
(loss=%2.5f)" % eval_losses.val) 127 | 128 | all_preds, all_label = all_preds[0], all_label[0] 129 | accuracy = simple_accuracy(all_preds, all_label) 130 | 131 | logger.info("\n") 132 | logger.info("Validation Results") 133 | logger.info("Global Steps: %d" % global_step) 134 | logger.info("Valid Loss: %2.5f" % eval_losses.avg) 135 | logger.info("Valid Accuracy: %2.5f" % accuracy) 136 | 137 | writer.add_scalar("test/accuracy", scalar_value=accuracy, global_step=global_step) 138 | return accuracy 139 | 140 | 141 | def train(args, model): 142 | """ Train the model """ 143 | if args.local_rank in [-1, 0]: 144 | os.makedirs(args.output_dir, exist_ok=True) 145 | writer = SummaryWriter(log_dir=os.path.join("logs", args.name)) 146 | 147 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 148 | 149 | # Prepare dataset 150 | train_loader, test_loader = get_loader(args) 151 | 152 | # Prepare optimizer and scheduler 153 | optimizer = torch.optim.SGD(model.parameters(), 154 | lr=args.learning_rate, 155 | momentum=0.9, 156 | weight_decay=args.weight_decay) 157 | t_total = args.num_steps 158 | if args.decay_type == "cosine": 159 | scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 160 | else: 161 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 162 | 163 | if args.fp16: 164 | model, optimizer = amp.initialize(models=model, 165 | optimizers=optimizer, 166 | opt_level=args.fp16_opt_level) 167 | amp._amp_state.loss_scalers[0]._loss_scale = 2**20 168 | 169 | # Distributed training 170 | if args.local_rank != -1: 171 | model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size()) 172 | 173 | # Train! 174 | logger.info("***** Running training *****") 175 | logger.info(" Total optimization steps = %d", args.num_steps) 176 | logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) 177 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 178 | args.train_batch_size * args.gradient_accumulation_steps * ( 179 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 180 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 181 | 182 | model.zero_grad() 183 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 184 | losses = AverageMeter() 185 | global_step, best_acc = 0, 0 186 | while True: 187 | model.train() 188 | epoch_iterator = tqdm(train_loader, 189 | desc="Training (X / X Steps) (loss=X.X)", 190 | bar_format="{l_bar}{r_bar}", 191 | dynamic_ncols=True, 192 | disable=args.local_rank not in [-1, 0]) 193 | for step, batch in enumerate(epoch_iterator): 194 | batch = tuple(t.to(args.device) for t in batch) 195 | x, y = batch 196 | loss = model(x, y) 197 | 198 | if args.gradient_accumulation_steps > 1: 199 | loss = loss / args.gradient_accumulation_steps 200 | if args.fp16: 201 | with amp.scale_loss(loss, optimizer) as scaled_loss: 202 | scaled_loss.backward() 203 | else: 204 | loss.backward() 205 | 206 | if (step + 1) % args.gradient_accumulation_steps == 0: 207 | losses.update(loss.item()*args.gradient_accumulation_steps) 208 | if args.fp16: 209 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 210 | else: 211 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 212 | scheduler.step() 213 | optimizer.step() 214 | optimizer.zero_grad() 215 | global_step += 1 216 | 217 | epoch_iterator.set_description( 218 | "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) 219 | ) 220 | if args.local_rank in [-1, 0]: 221 | writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) 222 | writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) 223 | if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: 224 | accuracy = valid(args, model, writer, test_loader, global_step) 225 | if best_acc < accuracy: 226 | save_model(args, model) 227 | best_acc = accuracy 228 | model.train() 229 | 230 | if global_step % t_total == 0: 231 | break 232 | losses.reset() 233 | if global_step % t_total == 0: 234 | break 235 | 236 | if args.local_rank in [-1, 0]: 237 | writer.close() 238 | logger.info("Best Accuracy: \t%f" % best_acc) 239 | logger.info("End Training!") 240 | 241 | 242 | def main(): 243 | parser = argparse.ArgumentParser() 244 | # Required parameters 245 | parser.add_argument("--name", required=True, 246 | help="Name of this run. 
Used for monitoring.") 247 | parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10", 248 | help="Which downstream task.") 249 | parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16", 250 | "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"], 251 | default="ViT-B_16", 252 | help="Which variant to use.") 253 | parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz", 254 | help="Where to search for pretrained ViT models.") 255 | parser.add_argument("--output_dir", default="output", type=str, 256 | help="The output directory where checkpoints will be written.") 257 | 258 | parser.add_argument("--img_size", default=224, type=int, 259 | help="Input image resolution.") 260 | parser.add_argument("--train_batch_size", default=512, type=int, 261 | help="Total batch size for training.") 262 | parser.add_argument("--eval_batch_size", default=64, type=int, 263 | help="Total batch size for eval.") 264 | parser.add_argument("--eval_every", default=100, type=int, 265 | help="Run prediction on validation set every so many steps. " 266 | "Will always run one evaluation at the end of training.") 267 | 268 | parser.add_argument("--learning_rate", default=3e-2, type=float, 269 | help="The initial learning rate for SGD.") 270 | parser.add_argument("--weight_decay", default=0, type=float, 271 | help="Weight decay if we apply some.") 272 | parser.add_argument("--num_steps", default=10000, type=int, 273 | help="Total number of training steps to perform.") 274 | parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", 275 | help="How to decay the learning rate.") 276 | parser.add_argument("--warmup_steps", default=500, type=int, 277 | help="Steps of training to perform learning rate warmup for.") 278 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 279 | help="Max gradient norm.") 280 | 281 | parser.add_argument("--local_rank", type=int, default=-1, 282 | help="local_rank for distributed training on gpus") 283 | parser.add_argument('--seed', type=int, default=42, 284 | help="random seed for initialization") 285 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 286 | help="Number of update steps to accumulate before performing a backward/update pass.") 287 | parser.add_argument('--fp16', action='store_true', 288 | help="Whether to use 16-bit float precision instead of 32-bit") 289 | parser.add_argument('--fp16_opt_level', type=str, default='O2', 290 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. " 291 | "See details at https://nvidia.github.io/apex/amp.html") 292 | parser.add_argument('--loss_scale', type=float, default=0, 293 | help="Loss scaling to improve fp16 numeric stability.
Only used when fp16 set to True.\n" 294 | "0 (default value): dynamic loss scaling.\n" 295 | "Positive power of 2: static loss scaling value.\n") 296 | args = parser.parse_args() 297 | 298 | # Setup CUDA, GPU & distributed training 299 | if args.local_rank == -1: 300 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 301 | args.n_gpu = torch.cuda.device_count() 302 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 303 | torch.cuda.set_device(args.local_rank) 304 | device = torch.device("cuda", args.local_rank) 305 | torch.distributed.init_process_group(backend='nccl', 306 | timeout=timedelta(minutes=60)) 307 | args.n_gpu = 1 308 | args.device = device 309 | 310 | # Setup logging 311 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 312 | datefmt='%m/%d/%Y %H:%M:%S', 313 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 314 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" % 315 | (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16)) 316 | 317 | # Set seed 318 | set_seed(args) 319 | 320 | # Model & Tokenizer Setup 321 | args, model = setup(args) 322 | 323 | # Training 324 | train(args, model) 325 | 326 | 327 | if __name__ == "__main__": 328 | main() 329 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from torchvision import transforms, datasets 6 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def get_loader(args): 13 | if args.local_rank not in [-1, 0]: 14 | torch.distributed.barrier() 15 | 16 | transform_train = transforms.Compose([ 17 | transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 20 | ]) 21 | transform_test = transforms.Compose([ 22 | transforms.Resize((args.img_size, args.img_size)), 23 | transforms.ToTensor(), 24 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 25 | ]) 26 | 27 | if args.dataset == "cifar10": 28 | trainset = datasets.CIFAR10(root="./data", 29 | train=True, 30 | download=True, 31 | transform=transform_train) 32 | testset = datasets.CIFAR10(root="./data", 33 | train=False, 34 | download=True, 35 | transform=transform_test) if args.local_rank in [-1, 0] else None 36 | 37 | else: 38 | trainset = datasets.CIFAR100(root="./data", 39 | train=True, 40 | download=True, 41 | transform=transform_train) 42 | testset = datasets.CIFAR100(root="./data", 43 | train=False, 44 | download=True, 45 | transform=transform_test) if args.local_rank in [-1, 0] else None 46 | if args.local_rank == 0: 47 | torch.distributed.barrier() 48 | 49 | train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) 50 | test_sampler = SequentialSampler(testset) 51 | train_loader = DataLoader(trainset, 52 | sampler=train_sampler, 53 | batch_size=args.train_batch_size, 54 | num_workers=4, 55 | pin_memory=True) 56 | test_loader = DataLoader(testset, 57 | sampler=test_sampler, 58 | batch_size=args.eval_batch_size, 59 | num_workers=4, 60 | pin_memory=True) if testset is not None else None 61 | 62 | return train_loader, test_loader 
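# Usage sketch (illustrative values only; train.py builds `args` with argparse):
#     from argparse import Namespace
#     args = Namespace(local_rank=-1, dataset="cifar10", img_size=224,
#                      train_batch_size=512, eval_batch_size=64)
#     train_loader, test_loader = get_loader(args)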
63 | -------------------------------------------------------------------------------- /utils/dist_util.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | def get_rank(): 4 | if not dist.is_available(): 5 | return 0 6 | if not dist.is_initialized(): 7 | return 0 8 | return dist.get_rank() 9 | 10 | def get_world_size(): 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | def is_main_process(): 18 | return get_rank() == 0 19 | 20 | def format_step(step): 21 | if isinstance(step, str): 22 | return step 23 | s = "" 24 | if len(step) > 0: 25 | s += "Training Epoch: {} ".format(step[0]) 26 | if len(step) > 1: 27 | s += "Training Iteration: {} ".format(step[1]) 28 | if len(step) > 2: 29 | s += "Validation Iteration: {} ".format(step[2]) 30 | return s 31 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class ConstantLRSchedule(LambdaLR): 9 | """ Constant learning rate schedule. 10 | """ 11 | def __init__(self, optimizer, last_epoch=-1): 12 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 13 | 14 | 15 | class WarmupConstantSchedule(LambdaLR): 16 | """ Linear warmup and then constant. 17 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 18 | Keeps learning rate schedule equal to 1. after warmup_steps. 19 | """ 20 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 21 | self.warmup_steps = warmup_steps 22 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 23 | 24 | def lr_lambda(self, step): 25 | if step < self.warmup_steps: 26 | return float(step) / float(max(1.0, self.warmup_steps)) 27 | return 1. 28 | 29 | 30 | class WarmupLinearSchedule(LambdaLR): 31 | """ Linear warmup and then linear decay. 32 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 33 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 34 | """ 35 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 36 | self.warmup_steps = warmup_steps 37 | self.t_total = t_total 38 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 39 | 40 | def lr_lambda(self, step): 41 | if step < self.warmup_steps: 42 | return float(step) / float(max(1, self.warmup_steps)) 43 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 44 | 45 | 46 | class WarmupCosineSchedule(LambdaLR): 47 | """ Linear warmup and then cosine decay. 48 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 49 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 50 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 
51 | """ 52 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 53 | self.warmup_steps = warmup_steps 54 | self.t_total = t_total 55 | self.cycles = cycles 56 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1.0, self.warmup_steps)) 61 | # progress after warmup 62 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 63 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 64 | --------------------------------------------------------------------------------
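For reference, a minimal sketch (not part of the repository) of how the schedules in `utils/scheduler.py` are driven; the dummy parameter, SGD settings, and step counts below are illustrative assumptions:

```python
import torch

from utils.scheduler import WarmupCosineSchedule

# Throwaway parameter so the optimizer has something to hold (illustrative only).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=3e-2)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=500, t_total=10000)

for step in range(10000):
    optimizer.step()
    scheduler.step()
    if step in (0, 499, 5000, 9999):
        # The LR ramps linearly from 0 to 3e-2 over the first 500 steps,
        # then decays back towards 0 along a half cosine.
        print(step, optimizer.param_groups[0]["lr"])
```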