├── requirements.txt
├── .gitattributes
├── config.txt
├── module
│   └── Layers.py
├── .gitignore
├── LICENSE
├── README.md
├── model.py
└── train.py

/requirements.txt:
--------------------------------------------------------------------------------
# torch and torchvision should be installed separately with CUDA support (see README)
einops
numpy
tqdm
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/config.txt:
--------------------------------------------------------------------------------
32
500
0.0001
0.0001
224
16
100
768
12
12
3072
3
0.
cls
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
cifar100/cifar-100-python.tar.gz
cifar100/cifar-100-python/train
cifar100/cifar-100-python/test
cifar100/cifar-100-python/meta
cifar100/cifar-100-python/file.txt~
__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Nguyen Hoang Quan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ViT - Vision Transformer

This is an implementation of ViT (Vision Transformer), introduced by the Google Research team in the paper [**"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"**](https://arxiv.org/abs/2010.11929).

**Please install PyTorch with CUDA support by following this [link](https://pytorch.org/get-started/locally/).**

## ViT Architecture
![Architecture of Vision Transformer](https://neurohive.io/wp-content/uploads/2020/10/rsz_cov.png)

## Configs
You can configure the network yourself through the `config.txt` file. The values below are the ones shipped with the repository; the file is purely positional, so keep the lines in this order (the `#` annotations are for reference only and must not appear in the actual file):

```
32     #batch_size
500    #epochs
0.0001 #learning_rate
0.0001 #gamma
224    #img_size
16     #patch_size
100    #num_class
768    #d_model
12     #n_head
12     #n_layers
3072   #d_mlp
3      #channels
0.     #dropout
cls    #pool
```

## Training
Currently, you can only train this model on CIFAR-100, with the following commands:

`> git clone https://github.com/quanmario0311/ViT_PyTorch.git`\
`> cd ViT_PyTorch`\
`> pip3 install -r requirements.txt`\
`> python3 train.py`

***Support for other standard datasets and custom datasets will be added later.***
--------------------------------------------------------------------------------
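The README only shows command-line training. As a quick sanity check, the model can also be built and run directly in Python; the sketch below is illustrative (not part of the repository) and assumes it is executed from the repository root with the values from the shipped `config.txt`:

```python
import torch
from model import ViT

# Build the ViT described by config.txt: 224x224 RGB inputs, 16x16 patches,
# 12 encoder layers with 12-head attention, d_model = 768, 100 output classes.
vit = ViT(img_size=224, patch_size=16, num_class=100, d_model=768,
          n_head=12, n_layers=12, d_mlp=3072, channels=3,
          dropout=0., pool='cls')

# A dummy batch of two images; the forward pass returns one 100-way logit vector per image.
dummy = torch.randn(2, 3, 224, 224)
logits = vit(dummy)
print(logits.shape)  # torch.Size([2, 100])
```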
/module/Layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from einops import rearrange


# Positional Encoding: learned positional embeddings for the class token plus every patch token
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 768, num_patches: int = None, dropout: float = 0.):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, n):
        x = x + self.pos_embed[:, :(n + 1)]
        x = self.dropout(x)
        return x


# Norm Layer: pre-norm wrapper that applies LayerNorm before the wrapped sub-layer
class Norm(nn.Module):
    def __init__(self, d_model: int = 768, next_layer: nn.Module = None):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.next_layer = next_layer

    def forward(self, x: torch.Tensor, **kwargs):
        x = self.norm(x)
        return self.next_layer(x, **kwargs)


# Feed Forward MLP
class FeedForward(nn.Module):
    def __init__(self, d_model: int = 768, d_mlp: int = 3072, dropout: float = 0.):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, d_mlp),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_mlp, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor):
        return self.layers(x)


# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 768, n_head: int = 12, dropout: float = 0.):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)

        d_head = d_model // n_head
        # The output projection is only skipped when a single head already spans d_model
        project_out = not (n_head == 1 and d_head == d_model)

        self.scale = d_head ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.w_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        self.fc_out = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.n_head
        # Project to queries, keys and values in one matmul, then split per head
        qkv = self.w_qkv(x).chunk(3, dim=-1)
        queries, keys, values = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
        # Compute scaled dot-product attention scores
        scores = torch.einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale
        attention = self.dropout(self.softmax(scores))

        x = torch.einsum('b h i j, b h j d -> b h i d', attention, values)
        x = rearrange(x, 'b h n d -> b n (h d)')
        # Mix the concatenated heads through the output projection
        return self.fc_out(x)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
### Author: Quan Nguyen

import torch
import torch.nn as nn

from einops import repeat
from einops.layers.torch import Rearrange

from module.Layers import PositionalEncoding, Norm, FeedForward, MultiHeadAttention


# Vision Transformer
class ViT(nn.Module):
    def __init__(self, img_size: int = 256, patch_size: int = 16,
                 num_class: int = 1000, d_model: int = 768, n_head: int = 12,
                 n_layers: int = 12, d_mlp: int = 3072, channels: int = 3,
                 dropout: float = 0., pool: str = 'cls'):
        super().__init__()

        img_h, img_w = img_size, img_size
        patch_h, patch_w = patch_size, patch_size

        assert img_h % patch_h == 0, 'image dimension must be divisible by patch dimension'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        num_patches = (img_h // patch_h) * (img_w // patch_w)
        patch_dim = channels * patch_h * patch_w

        # Cut the image into flattened patches and project each patch to d_model
        self.patches_embed = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_h, p2=patch_w),
            nn.Linear(patch_dim, d_model)
        )

        self.pos_embed = PositionalEncoding(d_model, num_patches, dropout)
        self.class_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pool = pool

        self.transformer = Transformer(d_model, n_head, n_layers, d_mlp, dropout)
        self.dropout = nn.Dropout(dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_class)
        )

    def forward(self, img):
        x = self.patches_embed(img)
        b, n, _ = x.shape
        class_token = repeat(self.class_token, '() n d -> b n d', b=b)
        # Concatenate the class token with the image patch tokens
        x = torch.cat((class_token, x), dim=1)
        # Add positional encoding
        x = self.pos_embed(x, n)
        x = self.transformer(x)
        # Either mean-pool all tokens or take the class token
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        # MLP head
        x = self.mlp_head(x)
        return x


# Transformer encoder
class Transformer(nn.Module):
    def __init__(self, d_model: int = 768, n_head: int = 12, n_layers: int = 12,
                 d_mlp: int = 3072, dropout: float = 0.):
        super().__init__()

        # Build an independent pre-norm attention + MLP block for every layer,
        # so parameters are not shared across depth.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Norm(d_model, MultiHeadAttention(d_model, n_head, dropout)),
                Norm(d_model, FeedForward(d_model, d_mlp, dropout))
            ])
            for _ in range(n_layers)
        ])

    def forward(self, x):
        for attention, mlp in self.layers:
            x = attention(x) + x
            x = mlp(x) + x
        return x
--------------------------------------------------------------------------------
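As a worked example of the patch embedding performed by `ViT.patches_embed` (illustrative only, not part of the repository): with the shipped configuration, a 224×224 RGB image yields (224/16)² = 196 patches of 16×16×3 = 768 raw values each, which are then linearly projected to `d_model`:

```python
import torch
from einops.layers.torch import Rearrange

# Same rearrange pattern as in ViT.patches_embed, with 16x16 patches
patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)

img = torch.randn(1, 3, 224, 224)   # one 224x224 RGB image
patches = patchify(img)
print(patches.shape)                 # torch.Size([1, 196, 768]) -> 196 patches, 768 values each
```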
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR100
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
import random
import numpy as np
import os
from tqdm import tqdm
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from model import ViT

seed = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs: int = 100):
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_accuracy = 0
        model.train()
        for data, label in tqdm(train_loader):
            optimizer.zero_grad()
            # Move the batch to the training device
            data = data.to(device)
            label = label.to(device)
            # Forward pass
            output = model(data)
            loss = criterion(output, label)
            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()
            # Accumulate running accuracy and loss for this epoch
            acc = (output.argmax(dim=1) == label).float().mean()
            epoch_accuracy += acc.item() / len(train_loader)
            epoch_loss += loss.item() / len(train_loader)
        # Decay the learning rate once per epoch
        scheduler.step()
        if val_loader is not None:
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            model.eval()
            with torch.no_grad():
                for data, label in val_loader:
                    # Move the validation batch to the device
                    data = data.to(device)
                    label = label.to(device)
                    val_output = model(data)
                    val_loss = criterion(val_output, label)
                    # Accumulate validation accuracy and loss
                    acc = (val_output.argmax(dim=1) == label).float().mean()
                    epoch_val_accuracy += acc.item() / len(val_loader)
                    epoch_val_loss += val_loss.item() / len(val_loader)
            print(
                f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
            )
        else:
            print(
                f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\n"
            )


def read_config(config_path: str = "config.txt"):
    # config.txt is purely positional: one value per line, in the order below
    with open(config_path) as f:
        lines = f.readlines()
        lines = [word.strip('\n') for word in lines]
        return {'batch_size': int(lines[0]),
                'epochs': int(lines[1]),
                'learning_rate': float(lines[2]),
                'gamma': float(lines[3]),
                'img_size': int(lines[4]),
                'patch_size': int(lines[5]),
                'num_class': int(lines[6]),
                'd_model': int(lines[7]),
                'n_head': int(lines[8]),
                'n_layers': int(lines[9]),
                'd_mlp': int(lines[10]),
                'channels': int(lines[11]),
                'dropout': float(lines[12]),
                'pool': lines[13]}
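# For reference (an illustrative note, not executed code): with the config.txt shipped
# in this repository, read_config() returns
#   {'batch_size': 32, 'epochs': 500, 'learning_rate': 0.0001, 'gamma': 0.0001,
#    'img_size': 224, 'patch_size': 16, 'num_class': 100, 'd_model': 768,
#    'n_head': 12, 'n_layers': 12, 'd_mlp': 3072, 'channels': 3,
#    'dropout': 0.0, 'pool': 'cls'}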
def seed_everything(seed):
    # Make runs as reproducible as possible across Python, NumPy and PyTorch
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == "__main__":
    seed_everything(seed)
    configs = read_config()
    img_size = configs['img_size']

    train_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transforms = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    # CIFAR-100: the official test set is split in half below into a validation set
    # and a held-out test set
    train_data = CIFAR100(download=True, root="./cifar100", train=True, transform=train_transforms)
    test_val_data = CIFAR100(root="./cifar100", train=False, transform=test_transforms)
    train_len = len(train_data)
    val_len = test_len = int(len(test_val_data) / 2)
    test_data, val_data = torch.utils.data.random_split(test_val_data, [test_len, val_len])
    num_class = len(np.unique(train_data.targets))
    train_loader = DataLoader(dataset=train_data, batch_size=configs['batch_size'], shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=configs['batch_size'], shuffle=True)
    valid_loader = DataLoader(dataset=val_data, batch_size=configs['batch_size'], shuffle=True)

    vision_transformer = ViT(img_size=configs['img_size'],
                             patch_size=configs['patch_size'],
                             num_class=configs['num_class'],
                             d_model=configs['d_model'],
                             n_head=configs['n_head'],
                             n_layers=configs['n_layers'],
                             d_mlp=configs['d_mlp'],
                             channels=configs['channels'],
                             dropout=configs['dropout'],
                             pool=configs['pool']).to(device)
    # epochs
    epochs = configs['epochs']
    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(vision_transformer.parameters(), lr=configs['learning_rate'])
    # scheduler
    scheduler = StepLR(optimizer, step_size=10, gamma=0.7)

    train(vision_transformer, train_loader, valid_loader, criterion, optimizer, scheduler, epochs)
--------------------------------------------------------------------------------
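`train.py` builds a `test_loader` from the held-out half of the CIFAR-100 test set but never uses it. A possible follow-up, sketched under the same assumptions as the script (illustrative, not part of the repository), is to measure test accuracy once `train(...)` returns:

```python
# Evaluate the trained model on the unused held-out test split.
vision_transformer.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, label in test_loader:
        data, label = data.to(device), label.to(device)
        preds = vision_transformer(data).argmax(dim=1)
        correct += (preds == label).sum().item()
        total += label.size(0)
print(f"Test accuracy: {correct / total:.4f}")
```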