├── .gitignore ├── LICENSE ├── README.md ├── img ├── mixer_figure.png └── mixer_result.png ├── models ├── configs.py └── modeling.py ├── train.py └── utils ├── data_utils.py ├── dist_utils.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | !/checkpoint/ 143 | !/data/ 144 | !/logs/ 145 | !/output/ 146 | .idea 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Eunkwang Jeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MLP-Mixer 2 | 3 | Pytorch reimplementation of [Google's repository for the MLP-Mixer](https://github.com/google-research/vision_transformer/tree/linen) (Not yet updated on the master branch) that was released with the paper [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601) by Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy. 4 | 5 | In this paper, the authors show a performance close to SotA in an image classification benchmark using MLP(Multi-layer perceptron) without using CNN and Transformer. 6 | 7 | ![mixer_fig](./img/mixer_figure.png) 8 | 9 | MLP-Mixer (Mixer for short) consists of per-patch linear embeddings, Mixer layers, and a classifier head. Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include: skip-connections, dropout, and linear classifier head. 10 | 11 | ![mixer_result](./img/mixer_result.png) 12 | 13 | 14 | ## Usage 15 | ### 1. 
Download Pre-trained model (Google's Official Checkpoint) 16 | * [Available models](https://console.cloud.google.com/storage/browser/mixer_models): Mixer-B_16, Mixer-L_16 17 | * imagenet pre-train models 18 | * Mixer-B_16, Mixer-L_16 19 | * imagenet-21k pre-train models 20 | * Mixer-B_16, Mixer-L_16 21 | ``` 22 | # imagenet pre-train 23 | wget https://storage.googleapis.com/mixer_models/imagenet1k/{MODEL_NAME}.npz 24 | 25 | # imagenet-21k pre-train 26 | wget https://storage.googleapis.com/mixer_models/imagenet21k/{MODEL_NAME}.npz 27 | ``` 28 | 29 | ### 2. Fine-tuning 30 | ``` 31 | python3 train.py --name cifar10-100_500 --model_type Mixer-B_16 --pretrained_dir checkpoint/Mixer-B_16.npz 32 | ``` 33 | 34 | 35 | 36 | 37 | 38 | ## Reproducing Mixer results 39 | | upstream | model | dataset | acc(official) | acc(this repo) | 40 | |:------------:|:----------:|:-------:|:-------------:|:--------------:| 41 | | ImageNet | Mixer-B/16 | cifar10 | 96.72 | | 42 | | ImageNet | Mixer-L/16 | cifar10 | 96.59 | | 43 | | ImageNet-21k | Mixer-B/16 | cifar10 | 96.82 | | 44 | | ImageNet-21k | Mixer-L/16 | cifar10 | 96.34 | | 45 | 46 | 47 | 48 | ## Reference 49 | * [Google's Vision Transformer and MLP-Mixer](https://github.com/google-research/vision_transformer) 50 | 51 | 52 | ## Citations 53 | ```bibtex 54 | @article{tolstikhin2021, 55 | title={MLP-Mixer: An all-MLP Architecture for Vision}, 56 | author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey}, 57 | journal={arXiv preprint arXiv:2105.01601}, 58 | year={2021} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /img/mixer_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeonsworld/MLP-Mixer-Pytorch/793ccdeb73fd482e3fd5dacbb1cfc075af54fcd7/img/mixer_figure.png -------------------------------------------------------------------------------- /img/mixer_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeonsworld/MLP-Mixer-Pytorch/793ccdeb73fd482e3fd5dacbb1cfc075af54fcd7/img/mixer_result.png -------------------------------------------------------------------------------- /models/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import ml_collections 16 | 17 | 18 | def get_mixer_b16_config(): 19 | """Returns Mixer-B/16 configuration.""" 20 | config = ml_collections.ConfigDict() 21 | config.name = 'Mixer-B_16' 22 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 23 | config.hidden_dim = 768 24 | config.num_blocks = 12 25 | config.tokens_mlp_dim = 384 26 | config.channels_mlp_dim = 3072 27 | return config 28 | 29 | 30 | def get_mixer_l16_config(): 31 | """Returns Mixer-L/16 configuration.""" 32 | config = ml_collections.ConfigDict() 33 | config.name = 'Mixer-L_16' 34 | config.patches = ml_collections.ConfigDict({'size': (16, 16)}) 35 | config.hidden_dim = 1024 36 | config.num_blocks = 24 37 | config.tokens_mlp_dim = 512 38 | config.channels_mlp_dim = 4096 39 | return config 40 | -------------------------------------------------------------------------------- /models/modeling.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | import models.configs as configs 6 | 7 | from os.path import join as pjoin 8 | 9 | from torch import nn 10 | from torch.nn.modules.utils import _pair 11 | 12 | TOK_FC_0 = "token_mixing/Dense_0" 13 | TOK_FC_1 = "token_mixing/Dense_1" 14 | CHA_FC_0 = "channel_mixing/Dense_0" 15 | CHA_FC_1 = "channel_mixing/Dense_1" 16 | PRE_NORM = "LayerNorm_0" 17 | POST_NORM = "LayerNorm_1" 18 | 19 | 20 | def np2th(weights, conv=False): 21 | """Possibly convert HWIO to OIHW.""" 22 | if conv: 23 | weights = weights.transpose([3, 2, 0, 1]) 24 | return torch.from_numpy(weights) 25 | 26 | 27 | class MlpBlock(nn.Module): 28 | def __init__(self, hidden_dim, ff_dim): 29 | super(MlpBlock, self).__init__() 30 | self.fc0 = nn.Linear(hidden_dim, ff_dim, bias=True) 31 | self.fc1 = nn.Linear(ff_dim, hidden_dim, bias=True) 32 | self.act_fn = nn.GELU() 33 | 34 | def forward(self, x): 35 | x = self.fc0(x) 36 | x = self.act_fn(x) 37 | x = self.fc1(x) 38 | return x 39 | 40 | 41 | class MixerBlock(nn.Module): 42 | def __init__(self, config): 43 | super(MixerBlock, self).__init__() 44 | self.token_mlp_block = MlpBlock(config.n_patches, config.tokens_mlp_dim) 45 | self.channel_mlp_block = MlpBlock(config.hidden_dim, config.channels_mlp_dim) 46 | self.pre_norm = nn.LayerNorm(config.hidden_dim, eps=1e-6) 47 | self.post_norm = nn.LayerNorm(config.hidden_dim, eps=1e-6) 48 | 49 | def forward(self, x): 50 | h = x 51 | x = self.pre_norm(x) 52 | x = x.transpose(-1, -2) 53 | x = self.token_mlp_block(x) 54 | x = x.transpose(-1, -2) 55 | x = x + h 56 | 57 | h = x 58 | x = self.post_norm(x) 59 | x = self.channel_mlp_block(x) 60 | x = x + h 61 | return x 62 | 63 | def load_from(self, weights, n_block): 64 | ROOT = f"MixerBlock_{n_block}" 65 | with torch.no_grad(): 66 | self.token_mlp_block.fc0.weight.copy_( 67 | np2th(weights[pjoin(ROOT, TOK_FC_0, "kernel")]).t()) 68 | self.token_mlp_block.fc1.weight.copy_( 69 | np2th(weights[pjoin(ROOT, TOK_FC_1, "kernel")]).t()) 70 | self.token_mlp_block.fc0.bias.copy_( 71 | np2th(weights[pjoin(ROOT, TOK_FC_0, "bias")]).t()) 72 | self.token_mlp_block.fc1.bias.copy_( 73 | np2th(weights[pjoin(ROOT, TOK_FC_1, "bias")]).t()) 74 | 75 | self.channel_mlp_block.fc0.weight.copy_( 76 | np2th(weights[pjoin(ROOT, CHA_FC_0, "kernel")]).t()) 77 | self.channel_mlp_block.fc1.weight.copy_( 78 | np2th(weights[pjoin(ROOT, CHA_FC_1, "kernel")]).t()) 79 | self.channel_mlp_block.fc0.bias.copy_( 80 | np2th(weights[pjoin(ROOT, CHA_FC_0, "bias")]).t()) 81 | self.channel_mlp_block.fc1.bias.copy_( 82 | np2th(weights[pjoin(ROOT, 
CHA_FC_1, "bias")]).t()) 83 | 84 | self.pre_norm.weight.copy_(np2th(weights[pjoin(ROOT, PRE_NORM, "scale")])) 85 | self.pre_norm.bias.copy_(np2th(weights[pjoin(ROOT, PRE_NORM, "bias")])) 86 | self.post_norm.weight.copy_(np2th(weights[pjoin(ROOT, POST_NORM, "scale")])) 87 | self.post_norm.bias.copy_(np2th(weights[pjoin(ROOT, POST_NORM, "bias")])) 88 | 89 | 90 | class MlpMixer(nn.Module): 91 | def __init__(self, config, img_size=224, num_classes=1000, patch_size=16, zero_head=False): 92 | super(MlpMixer, self).__init__() 93 | self.zero_head = zero_head 94 | self.num_classes = num_classes 95 | patch_size = _pair(patch_size) 96 | n_patches = (img_size // patch_size[0]) * (img_size // patch_size[1]) 97 | config.n_patches = n_patches 98 | 99 | self.stem = nn.Conv2d(in_channels=3, 100 | out_channels=config.hidden_dim, 101 | kernel_size=patch_size, 102 | stride=patch_size) 103 | self.head = nn.Linear(config.hidden_dim, num_classes, bias=True) 104 | self.pre_head_ln = nn.LayerNorm(config.hidden_dim, eps=1e-6) 105 | 106 | 107 | self.layer = nn.ModuleList() 108 | for _ in range(config.num_blocks): 109 | layer = MixerBlock(config) 110 | self.layer.append(copy.deepcopy(layer)) 111 | 112 | def forward(self, x, labels=None): 113 | x = self.stem(x) 114 | x = x.flatten(2) 115 | x = x.transpose(-1, -2) 116 | 117 | for block in self.layer: 118 | x = block(x) 119 | x = self.pre_head_ln(x) 120 | x = torch.mean(x, dim=1) 121 | logits = self.head(x) 122 | 123 | if labels is not None: 124 | loss_fct = nn.CrossEntropyLoss() 125 | loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) 126 | return loss 127 | else: 128 | return logits 129 | 130 | def load_from(self, weights): 131 | with torch.no_grad(): 132 | if self.zero_head: 133 | nn.init.zeros_(self.head.weight) 134 | nn.init.zeros_(self.head.bias) 135 | else: 136 | self.head.weight.copy_(np2th(weights["head/kernel"]).t()) 137 | self.head.bias.copy_(np2th(weights["head/bias"]).t()) 138 | self.stem.weight.copy_(np2th(weights["stem/kernel"], conv=True)) 139 | self.stem.bias.copy_(np2th(weights["stem/bias"])) 140 | self.pre_head_ln.weight.copy_(np2th(weights["pre_head_layer_norm/scale"])) 141 | self.pre_head_ln.bias.copy_(np2th(weights["pre_head_layer_norm/bias"])) 142 | 143 | for bname, block in self.layer.named_children(): 144 | block.load_from(weights, n_block=bname) 145 | 146 | 147 | CONFIGS = { 148 | 'Mixer-B_16': configs.get_mixer_b16_config(), 149 | 'Mixer-L_16': configs.get_mixer_l16_config(), 150 | 'Mixer-B_16-21k': configs.get_mixer_b16_config(), 151 | 'Mixer-L_16-21k': configs.get_mixer_l16_config() 152 | } 153 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import logging 5 | import argparse 6 | import os 7 | import random 8 | import numpy as np 9 | 10 | from datetime import timedelta 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | from tqdm import tqdm 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from models.modeling import MlpMixer, CONFIGS 19 | from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule 20 | from utils.data_utils import get_loader 21 | from utils.dist_utils import get_world_size 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class AverageMeter(object): 28 | """Computes and stores the average and current value""" 29 | def 
__init__(self): 30 | self.reset() 31 | 32 | def reset(self): 33 | self.val = 0 34 | self.avg = 0 35 | self.sum = 0 36 | self.count = 0 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / self.count 43 | 44 | 45 | def simple_accuracy(preds, labels): 46 | return (preds == labels).mean() 47 | 48 | 49 | def save_model(args, model): 50 | model_to_save = model.module if hasattr(model, 'module') else model 51 | model_checkpoint = os.path.join(args.output_dir, "%s_checkpoint.bin" % args.name) 52 | torch.save(model_to_save.state_dict(), model_checkpoint) 53 | logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir) 54 | 55 | 56 | def setup(args): 57 | # Prepare model 58 | config = CONFIGS[args.model_type] 59 | 60 | num_classes = 10 61 | 62 | model = MlpMixer(config, args.img_size, num_classes=num_classes, patch_size=16, zero_head=True) 63 | model.load_from(np.load(args.pretrained_dir)) 64 | model.to(args.device) 65 | num_params = count_parameters(model) 66 | 67 | logger.info("{}".format(config)) 68 | logger.info("Training parameters %s", args) 69 | logger.info("Total Parameter: \t%2.1fM" % num_params) 70 | return args, model 71 | 72 | 73 | def count_parameters(model): 74 | params = sum(p.numel() for p in model.parameters() if p.requires_grad) 75 | return params/1000000 76 | 77 | 78 | def set_seed(args): 79 | random.seed(args.seed) 80 | np.random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | if args.n_gpu > 0: 83 | torch.cuda.manual_seed_all(args.seed) 84 | 85 | 86 | def valid(args, model, writer, test_loader, global_step): 87 | # Validation! 88 | eval_losses = AverageMeter() 89 | 90 | logger.info("***** Running Validation *****") 91 | logger.info(" Num steps = %d", len(test_loader)) 92 | logger.info(" Batch size = %d", args.eval_batch_size) 93 | 94 | model.eval() 95 | all_preds, all_label = [], [] 96 | epoch_iterator = tqdm(test_loader, 97 | desc="Validating... (loss=X.X)", 98 | bar_format="{l_bar}{r_bar}", 99 | dynamic_ncols=True, 100 | disable=args.local_rank not in [-1, 0]) 101 | loss_fct = torch.nn.CrossEntropyLoss() 102 | for step, batch in enumerate(epoch_iterator): 103 | batch = tuple(t.to(args.device) for t in batch) 104 | x, y = batch 105 | with torch.no_grad(): 106 | logits = model(x)[0] 107 | 108 | eval_loss = loss_fct(logits, y) 109 | eval_losses.update(eval_loss.item()) 110 | 111 | preds = torch.argmax(logits, dim=-1) 112 | 113 | if len(all_preds) == 0: 114 | all_preds.append(preds.detach().cpu().numpy()) 115 | all_label.append(y.detach().cpu().numpy()) 116 | else: 117 | all_preds[0] = np.append( 118 | all_preds[0], preds.detach().cpu().numpy(), axis=0 119 | ) 120 | all_label[0] = np.append( 121 | all_label[0], y.detach().cpu().numpy(), axis=0 122 | ) 123 | epoch_iterator.set_description("Validating... 
(loss=%2.5f)" % eval_losses.val) 124 | 125 | all_preds, all_label = all_preds[0], all_label[0] 126 | accuracy = simple_accuracy(all_preds, all_label) 127 | 128 | logger.info("\n") 129 | logger.info("Validation Results") 130 | logger.info("Global Steps: %d" % global_step) 131 | logger.info("Valid Loss: %2.5f" % eval_losses.avg) 132 | logger.info("Valid Accuracy: %2.5f" % accuracy) 133 | 134 | writer.add_scalar("test/accuracy", scalar_value=accuracy, global_step=global_step) 135 | return accuracy 136 | 137 | 138 | def train(args, model): 139 | """ Train the model """ 140 | if args.local_rank in [-1, 0]: 141 | os.makedirs(args.output_dir, exist_ok=True) 142 | writer = SummaryWriter(log_dir=os.path.join("logs", args.name)) 143 | 144 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 145 | 146 | # Prepare dataset 147 | train_loader, test_loader = get_loader(args) 148 | 149 | # Prepare optimizer and scheduler 150 | optimizer = torch.optim.SGD(model.parameters(), 151 | lr=args.learning_rate, 152 | momentum=0.9, 153 | weight_decay=args.weight_decay) 154 | t_total = args.num_steps 155 | if args.decay_type == "cosine": 156 | scheduler = WarmupCosineSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 157 | else: 158 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 159 | 160 | if args.fp16: 161 | model, optimizer = amp.initialize(models=model, 162 | optimizers=optimizer, 163 | opt_level=args.fp16_opt_level) 164 | amp._amp_state.loss_scalers[0]._loss_scale = 2**20 165 | 166 | # Distributed training 167 | if args.local_rank != -1: 168 | model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size()) 169 | 170 | # Train! 171 | logger.info("***** Running training *****") 172 | logger.info(" Total optimization steps = %d", args.num_steps) 173 | logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size) 174 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 175 | args.train_batch_size * args.gradient_accumulation_steps * ( 176 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 177 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 178 | 179 | model.zero_grad() 180 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 181 | losses = AverageMeter() 182 | global_step, best_acc = 0, 0 183 | while True: 184 | model.train() 185 | epoch_iterator = tqdm(train_loader, 186 | desc="Training (X / X Steps) (loss=X.X)", 187 | bar_format="{l_bar}{r_bar}", 188 | dynamic_ncols=True, 189 | disable=args.local_rank not in [-1, 0]) 190 | for step, batch in enumerate(epoch_iterator): 191 | batch = tuple(t.to(args.device) for t in batch) 192 | x, y = batch 193 | loss = model(x, y) 194 | 195 | if args.gradient_accumulation_steps > 1: 196 | loss = loss / args.gradient_accumulation_steps 197 | if args.fp16: 198 | with amp.scale_loss(loss, optimizer) as scaled_loss: 199 | scaled_loss.backward() 200 | else: 201 | loss.backward() 202 | 203 | if (step + 1) % args.gradient_accumulation_steps == 0: 204 | losses.update(loss.item()*args.gradient_accumulation_steps) 205 | if args.fp16: 206 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 207 | else: 208 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 209 | scheduler.step() 210 | optimizer.step() 211 | optimizer.zero_grad() 212 | global_step += 1 213 | 214 | epoch_iterator.set_description( 215 | "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val) 216 | ) 217 | if args.local_rank in [-1, 0]: 218 | writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step) 219 | writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step) 220 | if global_step % args.eval_every == 0 and args.local_rank in [-1, 0]: 221 | accuracy = valid(args, model, writer, test_loader, global_step) 222 | if best_acc < accuracy: 223 | save_model(args, model) 224 | best_acc = accuracy 225 | model.train() 226 | 227 | if global_step % t_total == 0: 228 | break 229 | losses.reset() 230 | if global_step % t_total == 0: 231 | break 232 | 233 | if args.local_rank in [-1, 0]: 234 | writer.close() 235 | logger.info("Best Accuracy: \t%f" % best_acc) 236 | logger.info("End Training!") 237 | 238 | 239 | def main(): 240 | parser = argparse.ArgumentParser() 241 | # Required parameters 242 | parser.add_argument("--name", required=True, 243 | help="Name of this run. Used for monitoring.") 244 | parser.add_argument("--model_type", choices=["Mixer-B_16", "Mixer-L_16", 245 | "Mixer-B_16-21k", "Mixer-L_16-21k"], 246 | default="Mixer-B_16", 247 | help="Which model to use.") 248 | parser.add_argument("--pretrained_dir", type=str, default="checkpoint/Mixer-B_16.npz", 249 | help="Where to search for pretrained ViT models.") 250 | parser.add_argument("--output_dir", default="output", type=str, 251 | help="The output directory where checkpoints will be written.") 252 | 253 | parser.add_argument("--train_batch_size", default=512, type=int, 254 | help="Total batch size for training.") 255 | parser.add_argument("--eval_batch_size", default=512, type=int, 256 | help="Total batch size for eval.") 257 | parser.add_argument("--eval_every", default=100, type=int, 258 | help="Run prediction on validation set every so many steps." 
259 | "Will always run one evaluation at the end of training.") 260 | 261 | parser.add_argument("--learning_rate", default=3e-2, type=float, 262 | help="The initial learning rate for SGD.") 263 | parser.add_argument("--weight_decay", default=0, type=float, 264 | help="Weight deay if we apply some.") 265 | parser.add_argument("--num_steps", default=10000, type=int, 266 | help="Total number of training epochs to perform.") 267 | parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine", 268 | help="How to decay the learning rate.") 269 | parser.add_argument("--warmup_steps", default=500, type=int, 270 | help="Step of training to perform learning rate warmup for.") 271 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 272 | help="Max gradient norm.") 273 | 274 | parser.add_argument("--local_rank", type=int, default=-1, 275 | help="local_rank for distributed training on gpus") 276 | parser.add_argument('--seed', type=int, default=42, 277 | help="random seed for initialization") 278 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 279 | help="Number of updates steps to accumulate before performing a backward/update pass.") 280 | parser.add_argument('--fp16', action='store_true', 281 | help="Whether to use 16-bit float precision instead of 32-bit") 282 | parser.add_argument('--fp16_opt_level', type=str, default='O2', 283 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 284 | "See details at https://nvidia.github.io/apex/amp.html") 285 | parser.add_argument('--loss_scale', type=float, default=0, 286 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 287 | "0 (default value): dynamic loss scaling.\n" 288 | "Positive power of 2: static loss scaling value.\n") 289 | args = parser.parse_args() 290 | 291 | # Setup CUDA, GPU & distributed training 292 | if args.local_rank == -1: 293 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 294 | args.n_gpu = torch.cuda.device_count() 295 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 296 | torch.cuda.set_device(args.local_rank) 297 | device = torch.device("cuda", args.local_rank) 298 | torch.distributed.init_process_group(backend='nccl', 299 | timeout=timedelta(minutes=60)) 300 | args.n_gpu = 1 301 | args.device = device 302 | args.img_size = 224 303 | if args.fp16: 304 | from apex import amp 305 | if args.local_rank != -1: 306 | from apex.parallel import DistributedDataParallel as DDP 307 | 308 | # Setup logging 309 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 310 | datefmt='%m/%d/%Y %H:%M:%S', 311 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 312 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s" % 313 | (args.local_rank, args.device, args.n_gpu, bool(args.local_rank != -1), args.fp16)) 314 | 315 | # Set seed 316 | set_seed(args) 317 | 318 | # Model & Tokenizer Setup 319 | args, model = setup(args) 320 | 321 | # Training 322 | train(args, model) 323 | 324 | 325 | if __name__ == "__main__": 326 | main() 327 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from torchvision import transforms, datasets 6 | from torch.utils.data import DataLoader, 
RandomSampler, DistributedSampler, SequentialSampler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def get_loader(args): 13 | if args.local_rank not in [-1, 0]: 14 | torch.distributed.barrier() 15 | 16 | transform_train = transforms.Compose([ 17 | transforms.RandomResizedCrop((args.img_size, args.img_size), scale=(0.05, 1.0)), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 20 | ]) 21 | transform_test = transforms.Compose([ 22 | transforms.Resize((args.img_size, args.img_size)), 23 | transforms.ToTensor(), 24 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 25 | ]) 26 | 27 | trainset = datasets.CIFAR10(root="./data", 28 | train=True, 29 | download=True, 30 | transform=transform_train) 31 | testset = datasets.CIFAR10(root="./data", 32 | train=False, 33 | download=True, 34 | transform=transform_test) if args.local_rank in [-1, 0] else None 35 | if args.local_rank == 0: 36 | torch.distributed.barrier() 37 | 38 | train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset) 39 | test_sampler = SequentialSampler(testset) 40 | train_loader = DataLoader(trainset, 41 | sampler=train_sampler, 42 | batch_size=args.train_batch_size, 43 | num_workers=4, 44 | pin_memory=True) 45 | test_loader = DataLoader(testset, 46 | sampler=test_sampler, 47 | batch_size=args.eval_batch_size, 48 | num_workers=4, 49 | pin_memory=True) if testset is not None else None 50 | 51 | return train_loader, test_loader 52 | -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | def get_rank(): 4 | if not dist.is_available(): 5 | return 0 6 | if not dist.is_initialized(): 7 | return 0 8 | return dist.get_rank() 9 | 10 | def get_world_size(): 11 | if not dist.is_available(): 12 | return 1 13 | if not dist.is_initialized(): 14 | return 1 15 | return dist.get_world_size() 16 | 17 | def is_main_process(): 18 | return get_rank() == 0 19 | 20 | def format_step(step): 21 | if isinstance(step, str): 22 | return step 23 | s = "" 24 | if len(step) > 0: 25 | s += "Training Epoch: {} ".format(step[0]) 26 | if len(step) > 1: 27 | s += "Training Iteration: {} ".format(step[1]) 28 | if len(step) > 2: 29 | s += "Validation Iteration: {} ".format(step[2]) 30 | return s 31 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | class ConstantLRSchedule(LambdaLR): 9 | """ Constant learning rate schedule. 10 | """ 11 | def __init__(self, optimizer, last_epoch=-1): 12 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 13 | 14 | 15 | class WarmupConstantSchedule(LambdaLR): 16 | """ Linear warmup and then constant. 17 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 18 | Keeps learning rate schedule equal to 1. after warmup_steps. 
19 | """ 20 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 21 | self.warmup_steps = warmup_steps 22 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 23 | 24 | def lr_lambda(self, step): 25 | if step < self.warmup_steps: 26 | return float(step) / float(max(1.0, self.warmup_steps)) 27 | return 1. 28 | 29 | 30 | class WarmupLinearSchedule(LambdaLR): 31 | """ Linear warmup and then linear decay. 32 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 33 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 34 | """ 35 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 36 | self.warmup_steps = warmup_steps 37 | self.t_total = t_total 38 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 39 | 40 | def lr_lambda(self, step): 41 | if step < self.warmup_steps: 42 | return float(step) / float(max(1, self.warmup_steps)) 43 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 44 | 45 | 46 | class WarmupCosineSchedule(LambdaLR): 47 | """ Linear warmup and then cosine decay. 48 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 49 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 50 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 51 | """ 52 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 53 | self.warmup_steps = warmup_steps 54 | self.t_total = t_total 55 | self.cycles = cycles 56 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1.0, self.warmup_steps)) 61 | # progress after warmup 62 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 63 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 64 | --------------------------------------------------------------------------------