├── ViT.png
├── README.md
└── Google_ViT.py

/ViT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tahmid0007/VisionTransformer/HEAD/ViT.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# VisionTransformer (PyTorch)
A complete, easy-to-follow PyTorch implementation of Google's Vision Transformer, proposed in "AN IMAGE IS WORTH 16X16 WORDS". The code is commented throughout to make it easier to follow.

**An image is worth 16x16 words: transformers for image recognition at scale:**
Find the original paper [here](https://arxiv.org/pdf/2010.11929.pdf).
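For a quick sense of the "16x16 words" idea: the image is cut into a grid of fixed-size patches and each patch becomes one token for the transformer. A back-of-the-envelope check (plain arithmetic, using the CIFAR-10 defaults from `Google_ViT.py`):

```python
image_size, patch_size, channels = 32, 4, 3    # CIFAR-10 defaults in this repo
num_patches = (image_size // patch_size) ** 2  # 8 * 8 = 64 patch tokens
patch_dim = channels * patch_size ** 2         # 3 * 16 = 48 raw values per patch
print(num_patches, patch_dim)                  # 64 48
```

With the paper's ImageNet setting (224x224 images, 16x16 patches) the same arithmetic gives 196 tokens per image.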


- This PyTorch implementation is based on [this repo](https://github.com/lucidrains/vit-pytorch) but follows the original paper more closely in terms of the first patch embedding and the initializations. The default dataset is CIFAR-10, which can easily be swapped for ImageNet or anything else.
- You might need to install einops (`pip install einops`).
- As reported in the paper, accuracy when training from scratch might not match state-of-the-art CNNs such as ResNet. Pretrain on a larger dataset to exploit the full potential of the Vision Transformer.
- Easy-to-understand comments are available in the code, especially for beginners in attention and transformer models.
- The standalone script `Google_ViT.py` is sufficient to run this code; a minimal usage sketch is shown below.
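The sketch below assumes the `ImageTransformer` class from `Google_ViT.py` is importable; note that the script has no `__main__` guard, so importing it also starts the CIFAR-10 download and training loop. The hyperparameters mirror the ones used at the bottom of the script.

```python
import torch
# Importing Google_ViT also executes its module-level training code (no __main__ guard).
from Google_ViT import ImageTransformer

# Same hyperparameters as the CIFAR-10 setup in the script.
model = ImageTransformer(image_size=32, patch_size=4, num_classes=10, channels=3,
                         dim=64, depth=6, heads=8, mlp_dim=128)

dummy_batch = torch.randn(8, 3, 32, 32)  # eight CIFAR-10-sized images
logits = model(dummy_batch)              # shape (8, 10): one score per class
print(logits.shape)
```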
--------------------------------------------------------------------------------
/Google_ViT.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 16 11:37:52 2020

@author: mthossain
"""
import PIL
import time
import torch
import torchvision
import torch.nn.functional as F
from einops import rearrange
from torch import nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x  # skip connection around the wrapped module

class LayerNormalize(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)  # pre-norm: LayerNorm before the wrapped module

class MLP_Block(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.1):
        super().__init__()
        self.nn1 = nn.Linear(dim, hidden_dim)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
        self.af1 = nn.GELU()
        self.do1 = nn.Dropout(dropout)
        self.nn2 = nn.Linear(hidden_dim, dim)
        torch.nn.init.xavier_uniform_(self.nn2.weight)
        torch.nn.init.normal_(self.nn2.bias, std = 1e-6)
        self.do2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.nn1(x)
        x = self.af1(x)
        x = self.do1(x)
        x = self.nn2(x)
        x = self.do2(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dropout = 0.1):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5  # 1/sqrt(dim); note the paper scales by the per-head dimension

        self.to_qkv = nn.Linear(dim, dim * 3, bias = True)  # Wq, Wk and Wv in a single projection, hence the *3
        torch.nn.init.xavier_uniform_(self.to_qkv.weight)
        torch.nn.init.zeros_(self.to_qkv.bias)

        self.nn1 = nn.Linear(dim, dim)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.zeros_(self.nn1.bias)
        self.do1 = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)  # computes q = x Wq, k = x Wk, v = x Wv in one projection
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)  # split into multi-head attention

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale  # scaled dot products q k^T / sqrt(dim)

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float('-inf'))
            del mask

        attn = dots.softmax(dim=-1)  # attention weights: softmax(q k^T / sqrt(dim)), the softmax(q, k, v) equation in the paper

        out = torch.einsum('bhij,bhjd->bhid', attn, v)  # weighted sum of the values v
        out = rearrange(out, 'b h n d -> b n (h d)')  # concat heads into one matrix, ready for the next encoder block
        out = self.nn1(out)
        out = self.do1(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(LayerNormalize(dim, Attention(dim, heads = heads, dropout = dropout))),
                Residual(LayerNormalize(dim, MLP_Block(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attention, mlp in self.layers:
            x = attention(x, mask = mask)  # go to attention
            x = mlp(x)                     # go to MLP_Block
        return x

class ImageTransformer(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0.1, emb_dropout = 0.1):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2  # e.g. (32 // 4) ** 2 = 64
        patch_dim = channels * patch_size ** 2  # e.g. 3 * 4 ** 2 = 48; only used by the commented-out matmul variant

        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.empty(1, (num_patches + 1), dim))
        torch.nn.init.normal_(self.pos_embedding, std = .02)  # initialized based on the paper
        self.patch_conv = nn.Conv2d(channels, dim, patch_size, stride = patch_size)  # equivalent to x matmul E, E = embedding matrix; this is the linear patch projection

        # self.E = nn.Parameter(nn.init.normal_(torch.empty(BATCH_SIZE_TRAIN, patch_dim, dim)), requires_grad = True)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # initialized based on the paper
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)

        self.to_cls_token = nn.Identity()

        self.nn1 = nn.Linear(dim, num_classes)  # if finetuning, just use a linear layer without further hidden layers (paper)
        torch.nn.init.xavier_uniform_(self.nn1.weight)
        torch.nn.init.normal_(self.nn1.bias, std = 1e-6)
        # self.af1 = nn.GELU()  # use additional hidden layers only when training on large datasets
        # self.do1 = nn.Dropout(dropout)
        # self.nn2 = nn.Linear(mlp_dim, num_classes)
        # torch.nn.init.xavier_uniform_(self.nn2.weight)
        # torch.nn.init.normal_(self.nn2.bias)
        # self.do2 = nn.Dropout(dropout)

    def forward(self, img, mask = None):
        p = self.patch_size

        x = self.patch_conv(img)  # linear patch projection: each patch is transformed, equivalent to multiplying by E
        # x = torch.matmul(x, self.E)
        x = rearrange(x, 'b c h w -> b (h w) c')  # one row per patch, each of length dim

        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # prepend the learnable classification token
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x, mask)  # main game

        x = self.to_cls_token(x[:, 0])  # classify from the cls token representation only

        x = self.nn1(x)
        # x = self.af1(x)
        # x = self.do1(x)
        # x = self.nn2(x)
        # x = self.do2(x)

        return x
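
# Illustrative shape walkthrough (not executed), assuming the CIFAR-10 settings
# used below (image_size=32, patch_size=4, dim=64, num_classes=10), batch size B:
#   img:             (B, 3, 32, 32)
#   patch_conv(img): (B, 64, 8, 8)   -> 8x8 = 64 patches, each projected to dim=64
#   rearrange:       (B, 64, 64)     -> sequence of 64 patch tokens
#   + cls token:     (B, 65, 64)     -> 64 patch tokens plus one classification token
#   transformer:     (B, 65, 64)     -> shape preserved by every encoder block
#   x[:, 0]:         (B, 64)         -> representation of the cls token
#   nn1:             (B, 10)         -> class scores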

BATCH_SIZE_TRAIN = 100
BATCH_SIZE_TEST = 100

DL_PATH = r"C:\Pytorch\Spyder\CIFAR10_data"  # Use your own path

# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
transform = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.RandomRotation(10, resample=PIL.Image.BILINEAR),  # newer torchvision uses interpolation= instead of resample=
     torchvision.transforms.RandomAffine(8, translate=(.15, .15)),
     torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

train_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=True,
                                             download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR10(DL_PATH, train=False,
                                            download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST,
                                          shuffle=False)

def train(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[' + '{:5}'.format(i * len(data)) + '/' + '{:5}'.format(total_samples) +
                  ' (' + '{:3.0f}'.format(100 * i / len(data_loader)) + '%)] Loss: ' +
                  '{:6.4f}'.format(loss.item()))
            loss_history.append(loss.item())

def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            loss = F.nll_loss(output, target, reduction='sum')
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print('\nAverage test loss: ' + '{:.4f}'.format(avg_loss) +
          '  Accuracy:' + '{:5}'.format(correct_samples) + '/' +
          '{:5}'.format(total_samples) + ' (' +
          '{:4.2f}'.format(100.0 * correct_samples / total_samples) + '%)\n')

N_EPOCHS = 150

model = ImageTransformer(image_size=32, patch_size=4, num_classes=10, channels=3,
                         dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    start_time = time.time()
    train(model, optimizer, train_loader, train_loss_history)
    print('Execution time:', '{:5.2f}'.format(time.time() - start_time), 'seconds')
    evaluate(model, test_loader, test_loss_history)

PATH = "./ViTnet_Cifar10_4x4_aug_1.pt"  # Use your own path
torch.save(model.state_dict(), PATH)

# =============================================================================
# model = ImageTransformer(image_size=32, patch_size=4, num_classes=10, channels=3,
#                          dim=64, depth=6, heads=8, mlp_dim=128)
# model.load_state_dict(torch.load(PATH))
# model.eval()
# =============================================================================
--------------------------------------------------------------------------------