├── img ├── README.md └── vit.png ├── presentation.pdf ├── README.md └── vit.py /img/README.md: -------------------------------------------------------------------------------- 1 | # Images 2 | This folder contains the images used in this repository. 3 | -------------------------------------------------------------------------------- /img/vit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Vision-Transformer-ViT-in-TensorFlow/HEAD/img/vit.png -------------------------------------------------------------------------------- /presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilroxtomar/Vision-Transformer-ViT-in-TensorFlow/HEAD/presentation.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision Transformer (ViT) in TensorFlow 2 | This repository contains an implementation of the Vision Transformer (ViT) in TensorFlow.
3 | 4 | - Arxiv Paper: [AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929.pdf) 5 | - Blog Post: [Vision Transformer](https://idiotdeveloper.com/vision-transformer-an-image-is-worth-16x16-words-transformers-for-image-recognition-at-scale/) by Idiot Developer 6 | - YouTube Tutorial: [Vision Transformer Implementation In TensorFlow](https://youtu.be/Fb1xsTXT4P8) 7 | 8 | ## Architecture 9 | | ![The block diagram of the Vision Transformer](img/vit.png) | 10 | | :--: | 11 | | *The block diagram of the Vision Transformer along with the Transformer Encoder.* | 12 | 13 | 14 | ## Contact: 15 | For more follow me on: 16 | 17 | - YouTube 18 | - Facebook 19 | - Twitter 20 | - Instagram 21 | - Telegram 22 | -------------------------------------------------------------------------------- /vit.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization, MultiHeadAttention, Add 7 | from tensorflow.keras.layers import Input, Embedding, Concatenate 8 | from tensorflow.keras.models import Model 9 | 10 | class ClassToken(Layer): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def build(self, input_shape): 15 | w_init = tf.random_normal_initializer() 16 | self.w = tf.Variable( 17 | initial_value = w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32), 18 | trainable = True 19 | ) 20 | 21 | def call(self, inputs): 22 | batch_size = tf.shape(inputs)[0] 23 | hidden_dim = self.w.shape[-1] 24 | 25 | cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim]) 26 | cls = tf.cast(cls, dtype=inputs.dtype) 27 | return cls 28 | 29 | def mlp(x, cf): 30 | x = Dense(cf["mlp_dim"], activation="gelu")(x) 31 | x = Dropout(cf["dropout_rate"])(x) 32 | x = Dense(cf["hidden_dim"])(x) 33 | x = Dropout(cf["dropout_rate"])(x) 34 | return x 35 
def transformer_encoder(x, cf):
    """One pre-norm Transformer encoder layer (ViT paper, Eq. 2-3).

    x -> LayerNorm -> multi-head self-attention -> + x
      -> LayerNorm -> mlp -> + (result of the first residual)

    Args:
        x: input tensor, presumably (batch, seq_len, hidden_dim) -- confirm
            against callers other than ``ViT``.
        cf: config dict; reads "num_heads", "hidden_dim", the keys used by
            ``mlp``, and optionally "key_dim".

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Per-head projection size. The paper sets D_h = D / k (64 for ViT-Base);
    # giving every head the full hidden_dim multiplies the attention
    # projection parameters by num_heads. cf["key_dim"] overrides this.
    key_dim = cf.get("key_dim", cf["hidden_dim"] // cf["num_heads"])

    skip_1 = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(num_heads=cf["num_heads"], key_dim=key_dim)(x, x)
    x = Add()([x, skip_1])

    skip_2 = x
    x = LayerNormalization()(x)
    x = mlp(x, cf)
    x = Add()([x, skip_2])

    return x


class _PositionEmbedding(Layer):
    """Add a learned position embedding to a sequence of patch embeddings.

    Doing the Embedding lookup inside ``call`` keeps it in the functional
    graph, so the position table is a trainable weight of the model. Calling
    ``Embedding(...)`` directly on a concrete ``tf.range`` tensor while
    building the model runs it eagerly once and bakes the result in as a
    constant instead.
    """

    def __init__(self, num_patches, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.hidden_dim = hidden_dim
        self.position_embedding = Embedding(input_dim=num_patches, output_dim=hidden_dim)

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (num_patches, hidden_dim) broadcasts over the leading batch axis.
        return inputs + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches, "hidden_dim": self.hidden_dim})
        return config


def ViT(cf):
    """Build a Vision Transformer classifier as a Keras functional Model.

    The model takes already-extracted, flattened patches of shape
    (batch, num_patches, patch_size * patch_size * num_channels) and returns
    softmax class probabilities of shape (batch, num_classes).

    Args:
        cf: config dict. Required keys: "num_patches", "patch_size",
            "num_channels", "hidden_dim", "num_layers", plus the keys read by
            ``transformer_encoder`` and ``mlp``. Optional: "num_classes"
            (default 10) and "key_dim".

    Returns:
        An uncompiled ``Model``.
    """
    # Inputs: (None, num_patches, P*P*C), e.g. (None, 256, 3072).
    input_shape = (cf["num_patches"], cf["patch_size"] * cf["patch_size"] * cf["num_channels"])
    inputs = Input(input_shape)

    # Patch + position embeddings: (None, num_patches, hidden_dim).
    patch_embed = Dense(cf["hidden_dim"])(inputs)
    embed = _PositionEmbedding(cf["num_patches"], cf["hidden_dim"])(patch_embed)

    # Prepend the class token: (None, num_patches + 1, hidden_dim).
    token = ClassToken()(embed)
    x = Concatenate(axis=1)([token, embed])

    # Transformer encoder stack.
    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Classification head, applied to the class-token position only.
    x = LayerNormalization()(x)
    x = x[:, 0, :]  # (None, hidden_dim)
    x = Dropout(0.1)(x)
    x = Dense(cf.get("num_classes", 10), activation="softmax")(x)

    return Model(inputs, x)


if __name__ == "__main__":
    config = {
        "num_layers": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
        "num_heads": 12,
        "dropout_rate": 0.1,
        "num_patches": 256,
        "patch_size": 32,
        "num_channels": 3,
    }

    model = ViT(config)
    model.summary()