├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── model1.PNG
│   ├── model2.PNG
│   └── model3.PNG
├── ceit.py
├── configs
│   └── default.yaml
├── data_utils.py
├── module.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Rishikesh (ऋषिकेश)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CeiT : Convolutional enhanced image Transformer
This is an unofficial PyTorch implementation of [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/abs/2103.11816).
![](assets/model1.PNG)
![](assets/model2.PNG)


## Training :
```
python train.py -c configs/default.yaml --name "name_of_exp"
```
## Usage :
```python
import torch
from ceit import CeiT

img = torch.ones([1, 3, 224, 224])

model = CeiT(image_size = 224, patch_size = 4, num_classes = 100)
out = model(img)

print("Shape of out :", out.shape)  # [B, num_classes]

model = CeiT(image_size = 224, patch_size = 4, num_classes = 100, with_lca = True)
out = model(img)

print("Shape of out :", out.shape)  # [B, num_classes]

```


## Note :
* The LCA (Layer-wise Class token Attention) module might not be implemented exactly as described in the paper.
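* `LeFF` in `module.py` hard-codes a 14×14 token grid (and `padding=1` for the depth-wise convolution), which matches the default `image_size = 224`, `patch_size = 4` configuration only; other input sizes require changing those constants.

The 14×14 grid follows from the I2T stem (a stride-2 convolution followed by stride-2 max-pooling, i.e. a 4× downsampling) and the patch embedding. A quick sanity check of the token arithmetic:

```python
image_size, patch_size = 224, 4

feature_size = image_size // 4        # stride-2 conv + stride-2 max-pool -> 56
grid = feature_size // patch_size     # 56 // 4 = 14 tokens per side
num_patches = grid ** 2               # 196 patch tokens (plus 1 cls token)
print(grid, num_patches)              # 14 196
```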

## Citation :
```
@misc{yuan2021incorporating,
      title={Incorporating Convolution Designs into Visual Transformers},
      author={Kun Yuan and Shaopeng Guo and Ziwei Liu and Aojun Zhou and Fengwei Yu and Wei Wu},
      year={2021},
      eprint={2103.11816},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Acknowledgement :
* Base ViT code is borrowed from [@lucidrains](https://github.com/lucidrains)'s repo: https://github.com/lucidrains/vit-pytorch
* Training and dataloader code is borrowed from [@jeonsworld](https://github.com/jeonsworld)'s repo: https://github.com/jeonsworld/ViT-pytorch

--------------------------------------------------------------------------------
/assets/model1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/CeiT-pytorch/d41cb23743fbb34c4354564a958f1d78de4e2770/assets/model1.PNG

--------------------------------------------------------------------------------
/assets/model2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/CeiT-pytorch/d41cb23743fbb34c4354564a958f1d78de4e2770/assets/model2.PNG

--------------------------------------------------------------------------------
/assets/model3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/CeiT-pytorch/d41cb23743fbb34c4354564a958f1d78de4e2770/assets/model3.PNG

--------------------------------------------------------------------------------
/ceit.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from module import Residual, Attention, PreNorm, LeFF, FeedForward, LCAttention
import numpy as np

class TransformerLeFF(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel)))
            ]))
    def forward(self, x):
        c = list()  # collects the class token after every layer (consumed later by LCA)
        for attn, leff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            c.append(cls_tokens)
            x = leff(x[:, 1:])  # LeFF runs on patch tokens only; the class token bypasses it
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1)
        return x, torch.stack(c).transpose(0, 1)  # x: [B, N+1, dim], c: [B, depth, dim]



class LCA(nn.Module):
    # The residual connection is removed here: the paper's authors didn't explicitly mention
    # using one, so it was left out, although the code also works with a residual connection.
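    # Shape sketch (matching the defaults in this file, dim = 192, depth = 12): the input
    # is the stack of per-layer class tokens produced by TransformerLeFF, i.e. x: [B, depth, dim].
    # LCAttention uses only the last (final-layer) class token as the query, so the attention
    # output is [B, 1, dim]; after the feed-forward step the module returns a single
    # aggregated token of shape [B, 1, dim].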
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.layers.append(nn.ModuleList([
            PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
            PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
        ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x[:, -1].unsqueeze(1)

            x = x[:, -1].unsqueeze(1) + ff(x)
        return x




class CeiT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim = 192, depth = 12, heads = 3, pool = 'cls', in_channels = 3, out_channels = 32, dim_head = 64, dropout = 0.,
                 emb_dropout = 0., conv_kernel = 7, stride = 2, depth_kernel = 3, pool_kernel = 3, scale_dim = 4, with_lca = False, lca_heads = 4, lca_dim_head = 48, lca_mlp_dim = 384):
        super().__init__()

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # I2T (Image-to-Tokens): convolutional stem that downsamples the image 4x before patch embedding
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(pool_kernel, stride)
        )

        feature_size = image_size // 4

        assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (feature_size // patch_size) ** 2
        patch_dim = out_channels * patch_size ** 2
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = TransformerLeFF(dim, depth, heads, dim_head, scale_dim, depth_kernel, dropout)

        self.with_lca = with_lca
        if with_lca:
            self.LCA = LCA(dim, lca_heads, lca_dim_head, lca_mlp_dim)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.conv(img)
        x = self.to_patch_embedding(x)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, c = self.transformer(x)

        if self.with_lca:
            x = self.LCA(c)[:, 0]
        else:
            x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)




if __name__ == "__main__":

    img = torch.ones([1, 3, 224, 224])

    model = CeiT(224, 4, 100)

    out = model(img)

    print("Shape of out :", out.shape)  # [B, num_classes]

    model = CeiT(224, 4, 1000, with_lca = True)
    out = model(img)

    print("Shape of out :", out.shape)  # [B, num_classes]

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)

--------------------------------------------------------------------------------
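How `train.py` wires the configuration below into the model, as a minimal sketch (it mirrors `main()`; note that some `model:` keys, e.g. `lca_dim`, `pool_kernel` and `pool_stride`, are defined in the config but not currently forwarded to the `CeiT` constructor):

```python
from utils import HParam
from ceit import CeiT

hp = HParam("configs/default.yaml")   # Dotdict-style access to the YAML below
model = CeiT(image_size = hp.data.image_size, patch_size = hp.model.patch_size,
             num_classes = hp.model.num_classes, dim = hp.model.dim,
             depth = hp.model.depth, heads = hp.model.heads, pool = hp.model.pool,
             in_channels = hp.model.in_channels, out_channels = hp.model.out_channels,
             with_lca = hp.model.with_lca)
```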
/configs/default.yaml:
--------------------------------------------------------------------------------
data:
  dataset: "cifar100"        # ["cifar10", "cifar100"]
  path: "./data"
  image_size: 224
  outdir: "./output"

model:
  patch_size: 4
  num_classes: 100
  in_channels: 3
  out_channels: 32
  dim: 192
  dim_head: 64
  depth: 12
  heads: 3
  conv_kernel: 7
  stride: 2
  depth_kernel: 3
  pool_kernel: 2
  pool_stride: 2
  scale_dim: 4
  with_lca: True
  lca_dim: 192
  lca_heads: 4
  lca_dim_head: 48
  lca_mlp_dim: 384
  dropout: 0.
  pool: 'cls'                # ['cls', 'mean']
  emb_dropout: 0.

train:
  batch: 64
  valid_batch: 32
  valid_step: 500
  num_steps: 100000
  accum_grad: 1
  lr: 0.001
  weight_decay: 0.05
  decay_type: "cosine"       # ["cosine", "linear"]
  grad_clip: 1.0
  warmup_steps: 3500
  seed: 1992
  ngpu: 1

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import logging

import torch

from torchvision import transforms, datasets
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler


logger = logging.getLogger(__name__)


def get_loader(local_rank, hp):
    # if local_rank not in [-1, 0]:
    #     torch.distributed.barrier()

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop((hp.data.image_size, hp.data.image_size), scale=(0.05, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    transform_test = transforms.Compose([
        transforms.Resize((hp.data.image_size, hp.data.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if hp.data.dataset == "cifar10":
        trainset = datasets.CIFAR10(root=hp.data.path,
                                    train=True,
                                    download=True,
                                    transform=transform_train)
        testset = datasets.CIFAR10(root=hp.data.path,
                                   train=False,
                                   download=True,
                                   transform=transform_test) if local_rank in [-1, 0] else None

    else:
        trainset = datasets.CIFAR100(root=hp.data.path,
                                     train=True,
                                     download=True,
                                     transform=transform_train)
        testset = datasets.CIFAR100(root=hp.data.path,
                                    train=False,
                                    download=True,
                                    transform=transform_test) if local_rank in [-1, 0] else None
    # if local_rank == 0:
    #     torch.distributed.barrier()

    # Use a DistributedSampler on every rank when training with DDP; a plain RandomSampler
    # is only correct in the single-process case. (The original `local_rank == 0` check
    # gave rank 0 the full dataset under DDP.)
    train_sampler = RandomSampler(trainset) if hp.train.ngpu == 1 else DistributedSampler(trainset)
    test_sampler = SequentialSampler(testset)
    train_loader = DataLoader(trainset,
                              sampler=train_sampler,
                              batch_size=hp.train.batch,
                              num_workers=4,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=hp.train.valid_batch,
                             num_workers=4,
                             pin_memory=True) if testset is not None else None

    return train_loader, test_loader

--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


class ReAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

        self.reattn_norm = nn.Sequential(
            Rearrange('b h i j -> b i j h'),
            nn.LayerNorm(heads),
            Rearrange('b i j h -> b h i j')
        )

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # attention

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        # re-attention

        attn = einsum('b h i j, h g -> b g i j', attn, self.reattn_weights)
        attn = self.reattn_norm(attn)

        # aggregate and out

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

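# NOTE: ReAttention above is not imported by ceit.py and is unused in this repo;
# it appears to follow the re-attention design from DeepViT and is presumably
# kept here for experimentation.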
class LeFF(nn.Module):
    # NOTE: the 14x14 grid (h=14, w=14) and padding=1 below are hard-coded and only match
    # the default config (image_size 224 / patch_size 4 -> 196 tokens, depth_kernel 3).
    def __init__(self, dim = 192, scale = 4, depth_kernel = 3):
        super().__init__()

        scale_dim = dim*scale
        self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim),
                                     Rearrange('b n c -> b c n'),
                                     nn.BatchNorm1d(scale_dim),
                                     nn.GELU(),
                                     Rearrange('b c (h w) -> b c h w', h=14, w=14)
                                     )

        self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False),
                                        nn.BatchNorm2d(scale_dim),
                                        nn.GELU(),
                                        Rearrange('b c h w -> b (h w) c', h=14, w=14)
                                        )

        self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim),
                                       Rearrange('b n c -> b c n'),
                                       nn.BatchNorm1d(dim),
                                       nn.GELU(),
                                       Rearrange('b c n -> b n c')
                                       )

    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        x = self.down_proj(x)
        return x


class LCAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        q = q[:, :, -1, :].unsqueeze(2)  # only the last (Lth) token is used as the query

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

--------------------------------------------------------------------------------
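A quick shape check for the two CeiT-specific modules above — a standalone sketch, not part of the repo; 196 = 14 × 14 patch tokens and dim = 192 as in the default config:

```python
import torch
from module import LeFF, LCAttention

leff = LeFF(dim = 192, scale = 4, depth_kernel = 3).eval()  # eval(): BatchNorm uses running stats
tokens = torch.randn(2, 196, 192)    # [B, 14*14 patch tokens, dim]
print(leff(tokens).shape)            # torch.Size([2, 196, 192])

lca_attn = LCAttention(dim = 192, heads = 4, dim_head = 48)
cls_stack = torch.randn(2, 12, 192)  # class tokens collected from all 12 layers
print(lca_attn(cls_stack).shape)     # torch.Size([2, 1, 192]); the last token attends over all layers
```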
/train.py:
--------------------------------------------------------------------------------
# coding=utf-8


from __future__ import absolute_import, division, print_function
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from utils import HParam
from ceit import CeiT
from utils import WarmupLinearSchedule, WarmupCosineSchedule
from data_utils import get_loader
from utils import get_world_size, get_rank


logger = logging.getLogger(__name__)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def save_model(name, outdir, model):
    model_to_save = model.module if hasattr(model, 'module') else model
    model_checkpoint = os.path.join(outdir, "%s_checkpoint.pyt" % name)
    torch.save(model_to_save.state_dict(), model_checkpoint)
    logger.info("Saved model checkpoint to [DIR: %s]", outdir)




def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params/1000000


def set_seed(hp):
    random.seed(hp.train.seed)
    np.random.seed(hp.train.seed)
    torch.manual_seed(hp.train.seed)
    if hp.train.ngpu > 0:
        torch.cuda.manual_seed_all(hp.train.seed)


def valid(device, local_rank, hp, model, writer, test_loader, global_step):
    # Validation!
    eval_losses = AverageMeter()

    logger.info("***** Running Validation *****")
    logger.info("  Num steps = %d", len(test_loader))
    logger.info("  Batch size = %d", hp.train.valid_batch)

    model.eval()
    all_preds, all_label = [], []
    epoch_iterator = tqdm(test_loader,
                          desc="Validating... (loss=X.X)",
                          bar_format="{l_bar}{r_bar}",
                          dynamic_ncols=True,
                          disable=local_rank not in [-1, 0])
    loss_fct = torch.nn.CrossEntropyLoss()
    for step, batch in enumerate(epoch_iterator):
        batch = tuple(t.to(device) for t in batch)
        x, y = batch
        with torch.no_grad():
            logits = model(x)

            eval_loss = loss_fct(logits, y)
            eval_losses.update(eval_loss.item())

            preds = torch.argmax(logits, dim=-1)

        if len(all_preds) == 0:
            all_preds.append(preds.detach().cpu().numpy())
            all_label.append(y.detach().cpu().numpy())
        else:
            all_preds[0] = np.append(
                all_preds[0], preds.detach().cpu().numpy(), axis=0
            )
            all_label[0] = np.append(
                all_label[0], y.detach().cpu().numpy(), axis=0
            )
        epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

    all_preds, all_label = all_preds[0], all_label[0]
    accuracy = simple_accuracy(all_preds, all_label)

    logger.info("\n")
    logger.info("Validation Results")
    logger.info("Global Steps: %d" % global_step)
    logger.info("Valid Loss: %2.5f" % eval_losses.avg)
    logger.info("Valid Accuracy: %2.5f" % accuracy)

    writer.add_scalar("test/accuracy", scalar_value=accuracy, global_step=global_step)
    return accuracy


def train(local_rank, args, hp, model):

    if hp.train.ngpu > 1:
        dist.init_process_group(backend="nccl", init_method="tcp://localhost:54321",
                                world_size=hp.train.ngpu, rank=local_rank)

    torch.cuda.manual_seed(hp.train.seed)
    device = torch.device('cuda:{:d}'.format(local_rank))
    model = model.to(device)


    """ Train the model """
    if local_rank in [-1, 0]:
        os.makedirs(hp.data.outdir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))
        print("Loading dataset :")

    hp.train.batch = hp.train.batch // hp.train.accum_grad

    # Prepare dataset
    train_loader, test_loader = get_loader(local_rank, hp)

    # Prepare optimizer and scheduler

    optimizer = torch.optim.AdamW(model.parameters(), hp.train.lr, betas=[0.8, 0.99],
                                  weight_decay=hp.train.weight_decay)  # was hard-coded to 0.05; use the config value
    t_total = hp.train.num_steps
    if hp.train.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer, warmup_steps=hp.train.warmup_steps, t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=hp.train.warmup_steps, t_total=t_total)



    # Distributed training
    if hp.train.ngpu > 1:
        model = DDP(model, device_ids=[local_rank])

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", hp.train.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", hp.train.batch)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                hp.train.batch * hp.train.accum_grad * (
                    hp.train.ngpu if local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", hp.train.accum_grad)

    model.zero_grad()
    set_seed(hp)  # Added here for reproducibility (even between python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0

    loss_fct = torch.nn.CrossEntropyLoss()
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            x, y = batch
            logits = model(x)

            loss = loss_fct(logits.view(-1, hp.model.num_classes), y.view(-1))

            if hp.train.accum_grad > 1:
                loss = loss / hp.train.accum_grad
            loss.backward()

            if (step + 1) % hp.train.accum_grad == 0:
                losses.update(loss.item()*hp.train.accum_grad)
                torch.nn.utils.clip_grad_norm_(model.parameters(), hp.train.grad_clip)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, t_total, losses.val)
                )
                if local_rank in [-1, 0]:
                    writer.add_scalar("train/loss", scalar_value=losses.val, global_step=global_step)
                    writer.add_scalar("train/lr", scalar_value=scheduler.get_lr()[0], global_step=global_step)
                if global_step % hp.train.valid_step == 0 and local_rank in [-1, 0]:
                    accuracy = valid(device, local_rank, hp, model, writer, test_loader, global_step)
                    if best_acc < accuracy:
                        save_model(args.name, hp.data.outdir, model)
                        best_acc = accuracy
                    model.train()

                if global_step % t_total == 0:
                    break
        losses.reset()
        if global_step % t_total == 0:
            break

    if local_rank in [-1, 0]:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
                        help="path of checkpoint pt file to resume training")
    parser.add_argument("--name", required=True,
                        help="Name of this run. Used for monitoring.")
    args = parser.parse_args()

    hp = HParam(args.config)
    with open(args.config, 'r') as f:
        hp_str = ''.join(f.readlines())

    # Setup CUDA, GPU & distributed training
    if hp.train.ngpu > 1:
        torch.cuda.manual_seed(hp.train.seed)
        hp.train.ngpu = torch.cuda.device_count()
        hp.train.batch = int(hp.train.batch / hp.train.ngpu)
        print('Batch size per GPU :', hp.train.batch)


    # Set seed
    set_seed(hp)

    # Model setup
    model = CeiT(image_size = hp.data.image_size, patch_size = hp.model.patch_size, num_classes = hp.model.num_classes,
                 dim = hp.model.dim, depth = hp.model.depth, heads = hp.model.heads, pool = hp.model.pool,
                 in_channels = hp.model.in_channels, out_channels = hp.model.out_channels, with_lca=hp.model.with_lca)
    # NOTE: --checkpoint_path is parsed above but resuming from a checkpoint is not implemented yet.

    num_params = count_parameters(model)

    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    print(num_params)

    # Training
    if hp.train.ngpu > 1:
        mp.spawn(train, nprocs=hp.train.ngpu, args=(args, hp, model,))
    else:
        train(0, args, hp, model)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
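Worked example of the batch-size bookkeeping in `main()` and `train()` above (illustrative numbers, not the repo defaults):

```python
batch, ngpu, accum_grad = 64, 2, 2             # e.g. values a user might set in the config

per_process = batch // ngpu                    # main(): 32 per GPU process
micro_batch = per_process // accum_grad        # train(): 16 per forward/backward pass
effective = micro_batch * accum_grad * ngpu    # the logged "total train batch size": 64
print(micro_batch, effective)                  # 16 64
```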
/utils.py:
--------------------------------------------------------------------------------

# modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification

import os
import yaml
import torch.distributed as dist
import logging
import math

from torch.optim.lr_scheduler import LambdaLR

logger = logging.getLogger(__name__)

class ConstantLRSchedule(LambdaLR):
    """ Constant learning rate schedule.
    """
    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)


class WarmupConstantSchedule(LambdaLR):
    """ Linear warmup and then constant.
        Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
        Keeps learning rate schedule equal to 1. after warmup_steps.
    """
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        return 1.


class WarmupLinearSchedule(LambdaLR):
    """ Linear warmup and then linear decay.
        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
        Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))


class WarmupCosineSchedule(LambdaLR):
    """ Linear warmup and then cosine decay.
        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
        Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
        If `cycles` differs from the default of 0.5, the learning rate instead follows `cycles` full cosine periods after warmup.
    """
    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))

def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()

def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

def is_main_process():
    return get_rank() == 0

def format_step(step):
    if isinstance(step, str):
        return step
    s = ""
    if len(step) > 0:
        s += "Training Epoch: {} ".format(step[0])
    if len(step) > 1:
        s += "Training Iteration: {} ".format(step[1])
    if len(step) > 2:
        s += "Validation Iteration: {} ".format(step[2])
    return s

def load_hparam_str(hp_str):
    path = 'temp-restore.yaml'
    with open(path, 'w') as f:
        f.write(hp_str)
    ret = HParam(path)
    os.remove(path)
    return ret


def load_hparam(filename):
    stream = open(filename, 'r')
    docs = yaml.load_all(stream, Loader=yaml.Loader)
    hparam_dict = dict()
    for doc in docs:
        for k, v in doc.items():
            hparam_dict[k] = v
    return hparam_dict


def merge_dict(user, default):
    if isinstance(user, dict) and isinstance(default, dict):
        for k, v in default.items():
            if k not in user:
                user[k] = v
            else:
                user[k] = merge_dict(user[k], v)
    return user


class Dotdict(dict):
    """
    a dictionary that supports dot notation
    as well as dictionary access notation
    usage: d = Dotdict() or d = Dotdict({'val1': 'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, dct=None):
        dct = dict() if not dct else dct
        for key, value in dct.items():
            if hasattr(value, 'keys'):
                value = Dotdict(value)
            self[key] = value


class HParam(Dotdict):

    def __init__(self, file):
        super(Dotdict, self).__init__()
        hp_dict = load_hparam(file)
        hp_dotdict = Dotdict(hp_dict)
        for k, v in hp_dotdict.items():
            setattr(self, k, v)

    __getattr__ = Dotdict.__getitem__
    __setattr__ = Dotdict.__setitem__
    __delattr__ = Dotdict.__delitem__

--------------------------------------------------------------------------------
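For reference, a minimal sketch (not part of the repo) of how `WarmupCosineSchedule` behaves; it is stepped once per optimizer update, exactly as `train.py` does, and the step counts here are scaled down for illustration:

```python
import torch
from utils import WarmupCosineSchedule

param = torch.nn.Parameter(torch.zeros(1))       # toy parameter purely to drive the schedule
optimizer = torch.optim.AdamW([param], lr=0.001)
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=35, t_total=1000)

for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step in (0, 34, 500, 999):
        # The LR ramps linearly up to 0.001 over the first 35 steps,
        # then follows a half cosine down to ~0 at t_total.
        print(step, scheduler.get_lr()[0])
```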