├── img
├── README.md
└── vit.png
├── presentation.pdf
├── README.md
└── vit.py
/img/README.md:
--------------------------------------------------------------------------------
1 | # Images
2 | This folder contains all the images.
3 |
--------------------------------------------------------------------------------
/img/vit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Vision-Transformer-ViT-in-TensorFlow/HEAD/img/vit.png
--------------------------------------------------------------------------------
/presentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nikhilroxtomar/Vision-Transformer-ViT-in-TensorFlow/HEAD/presentation.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vision Transformer (ViT) in TensorFlow.
2 | The repository contains the code for the implementation of the Vision Transformer in the TensorFlow framework.
3 |
4 | - arXiv Paper: [AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE](https://arxiv.org/pdf/2010.11929.pdf)
5 | - Blog Post: [Vision Transformer](https://idiotdeveloper.com/vision-transformer-an-image-is-worth-16x16-words-transformers-for-image-recognition-at-scale/) by Idiot Developer
6 | - YouTube Tutorial: [Vision Transformer Implementation In TensorFlow](https://youtu.be/Fb1xsTXT4P8)
7 |
8 | ## Architecture
9 | | ![Vision Transformer architecture](img/vit.png) |
10 | | :--: |
11 | | *The block diagram of the Vision Transformer along with the Transformer Encoder.* |
12 |
13 |
14 | ## Contact:
15 | For more, follow me on:
16 |
17 | - YouTube
18 | - Facebook
19 | - Twitter
20 | - Instagram
21 | - Telegram
22 |
--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
4 |
5 | import tensorflow as tf
6 | from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization, MultiHeadAttention, Add
7 | from tensorflow.keras.layers import Input, Embedding, Concatenate
8 | from tensorflow.keras.models import Model
9 |
class ClassToken(Layer):
    """Layer that emits one learnable classification token per batch element.

    The layer owns a single weight of shape ``(1, 1, D)``, where ``D`` is the
    last dimension of the input it is built on. Calling it ignores the input
    values and returns that weight broadcast to ``(batch, 1, D)`` so it can be
    concatenated in front of the input sequence.
    """

    def __init__(self, **kwargs):
        # Forward the standard Layer options (name, dtype, trainable, ...)
        # so the layer can be named and configured like any built-in layer.
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the token weight, sized from the last input dimension."""
        # add_weight (rather than a bare tf.Variable) registers the token with
        # the layer, so it is reported in trainable_weights and checkpointed.
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, input_shape[-1]),
            initializer=tf.random_normal_initializer(),
            dtype="float32",
            trainable=True,
        )

    def call(self, inputs):
        """Return the token repeated along the batch axis of ``inputs``.

        Args:
            inputs: Tensor whose first axis is the batch axis; only its
                dynamic batch size and dtype are used.

        Returns:
            Tensor of shape ``(batch, 1, D)`` with the dtype of ``inputs``.
        """
        # The batch size is only known at run time, hence tf.shape.
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        # Match the input dtype so the token can be concatenated with it.
        return tf.cast(cls, dtype=inputs.dtype)
28 |
def mlp(x, cf):
    """Run ``x`` through the two-layer feed-forward block.

    The block is Dense(``mlp_dim``, GELU) -> Dropout -> Dense(``hidden_dim``)
    -> Dropout, with both dropouts using ``cf["dropout_rate"]``.

    Args:
        x: Input tensor.
        cf: Config mapping providing "mlp_dim", "hidden_dim" and
            "dropout_rate".

    Returns:
        The tensor produced by the last dropout layer.
    """
    # Layers are created in the order they are applied.
    stack = (
        Dense(cf["mlp_dim"], activation="gelu"),
        Dropout(cf["dropout_rate"]),
        Dense(cf["hidden_dim"]),
        Dropout(cf["dropout_rate"]),
    )
    out = x
    for layer in stack:
        out = layer(out)
    return out
35 |
def transformer_encoder(x, cf):
    """Apply one pre-norm Transformer encoder block to ``x``.

    The block is LayerNorm -> multi-head self-attention -> residual add,
    followed by LayerNorm -> MLP -> residual add.

    Args:
        x: Sequence tensor. Its last dimension must equal
            ``cf["hidden_dim"]`` so the MLP output can be added back to it.
        cf: Config dict. Reads "num_heads" and "hidden_dim" here, and
            "mlp_dim" / "dropout_rate" through ``mlp``. An optional "key_dim"
            entry overrides the per-head attention size.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # MultiHeadAttention's key_dim is the size of EACH head, not the model
    # width, so the default splits hidden_dim evenly across the heads.
    # Set cf["key_dim"] = cf["hidden_dim"] to get the previous, wider heads.
    key_dim = cf.get("key_dim", cf["hidden_dim"] // cf["num_heads"])

    # Self-attention sub-block with residual connection.
    skip_1 = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(
        num_heads=cf["num_heads"], key_dim=key_dim
    )(x, x)  # query and value are the same tensor: self-attention
    x = Add()([x, skip_1])

    # Feed-forward sub-block with residual connection.
    skip_2 = x
    x = LayerNormalization()(x)
    x = mlp(x, cf)
    x = Add()([x, skip_2])

    return x
50 |
class _PositionEmbedding(Layer):
    """Add a learnable per-position vector to a (batch, positions, dim) input."""

    def build(self, input_shape):
        # One vector per position, sized from the static input shape.
        # "random_uniform" matches the default initializer of Embedding.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(input_shape[-2], input_shape[-1]),
            initializer="random_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Broadcasts over the batch axis.
        return inputs + tf.cast(self.pos_embed, inputs.dtype)


def ViT(cf):
    """Build a Vision Transformer classifier as a Keras functional model.

    The model takes already-flattened patches, projects them to
    ``hidden_dim``, adds a learnable position embedding, prepends a class
    token, runs ``num_layers`` encoder blocks and classifies from the
    class-token output.

    Args:
        cf: Config dict. Reads "num_patches", "patch_size", "num_channels",
            "hidden_dim" and "num_layers", plus the keys consumed by
            ``transformer_encoder``. An optional "num_classes" entry sets the
            size of the softmax output (default 10).

    Returns:
        A ``Model`` mapping inputs of shape
        ``(batch, num_patches, patch_size * patch_size * num_channels)`` to
        softmax outputs of shape ``(batch, num_classes)``.
    """
    # Inputs: one flattened vector per patch.
    patch_dim = cf["patch_size"] * cf["patch_size"] * cf["num_channels"]
    inputs = Input((cf["num_patches"], patch_dim))

    # Patch + position embeddings.
    patch_embed = Dense(cf["hidden_dim"])(inputs)
    # The position embedding must be a layer applied to a symbolic tensor.
    # Calling Embedding on an eager tf.range result evaluates it immediately
    # and bakes a fixed constant into the graph, so it would never be trained.
    embed = _PositionEmbedding()(patch_embed)

    # Prepend the class token along the sequence axis.
    token = ClassToken()(embed)
    x = Concatenate(axis=1)([token, embed])

    # Transformer encoder stack.
    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Classification head: normalize, then keep only the class-token output.
    x = LayerNormalization()(x)
    x = x[:, 0, :]
    x = Dropout(0.1)(x)
    x = Dense(cf.get("num_classes", 10), activation="softmax")(x)

    return Model(inputs, x)
79 |
80 |
if __name__ == "__main__":
    # Hyperparameters for the demo model built when the file is run directly.
    config = {
        "num_layers": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
        "num_heads": 12,
        "dropout_rate": 0.1,
        "num_patches": 256,
        "patch_size": 32,
        "num_channels": 3,
    }

    # Build the model and print its layer-by-layer summary.
    model = ViT(config)
    model.summary()
94 |
--------------------------------------------------------------------------------