├── gittransformer.png ├── LICENSE ├── README.md ├── GPT_1.py ├── VisionImageTransformer.py ├── GPT_2.py ├── BERT.py ├── TRANSFORMERS.py └── PERFORMER.py /gittransformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShivamRajSharma/Transformer-Architectures-From-Scratch/HEAD/gittransformer.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Shivam Raj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Architecture From Scratch Using PyTorch 2 | 3 |

4 | 5 |

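## USAGE -
Each script in this repository is self-contained and ends with a small demo under `if __name__ == '__main__'`. Below is a minimal usage sketch for the encoder-decoder Transformer; the parameter values mirror the demo block in TRANSFORMERS.py, and it assumes the repository root is on the Python path:

```python
import torch
from TRANSFORMERS import Transformers

model = Transformers(
    input_vocab_size=100,     # depends on the source tokenizer
    output_vocab_size=200,    # depends on the target tokenizer
    pad_idx=0,
    embedding_out=512,
    num_layers=6,
    forward_expansion=4,
    head=8,
    dropout=0.1,
    max_len=512,
)

src = torch.randint(1, 100, (8, 64))    # (batch, source length)
tgt = torch.randint(1, 200, (8, 32))    # (batch, target length)
logits = model(src, tgt)                # (8, 32, output_vocab_size)
```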
6 | 7 | 8 | 9 | ## 1) TRANSFORMER - 10 | A Self-attention based Encoder-Decoder Architecture. It is mostly used for 11 | 1) Machine Translation 12 | 2) Document Summarization 13 | 3) Text Extraction 14 | 15 | Paper - https://arxiv.org/abs/1706.03762 16 | 17 | ## 2) BERT - 18 | A Self-attention based Encoder Architecture. It is mostly used for 19 | 1) Sentiment Classification 20 | 2) Named Entity Recognition 21 | 3) Question and Answering 22 | 4) Sentence Embedding Extraction 23 | 5) Document Matching 24 | 25 | Paper - https://arxiv.org/abs/1810.04805 26 | 27 | ## 3) GPT-1 - 28 | A Self-attention based, Decoder-only Autoregressive model. It is mostly used for 29 | 1) Sentence Completion 30 | 2) Generating Text 31 | 3) Sentiment Classification 32 | 33 | Paper - https://paperswithcode.com/method/gpt 34 | 35 | ## 4) GPT-2 - 36 | A Self-attention based, Decoder-only Autoregressive model with slight architectural changes, trained on a larger corpus of text than GPT-1. It is mostly used for 37 | 1) Sentence Completion 38 | 2) Generating Text 39 | 3) Sentiment Classification 40 | 41 | Paper - https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf 42 | 43 | ## 5) ViT - 44 | A state-of-the-art Self-attention based Encoder Architecture for Computer Vision applications. It is mostly used for 45 | 1) Image Classification 46 | 2) Image Encoding 47 | 3) Backbone for Object Detection 48 | 49 | Paper - https://arxiv.org/abs/2006.03677 50 | 51 | ## 6) PERFORMER - 52 | A Self-attention based Encoder-Decoder Architecture with linear time complexity, unlike the Transformer, which has quadratic time complexity. It is mostly used for 53 | 1) Machine Translation 54 | 2) Document Summarization 55 | 3) Text Extraction 56 | 57 | Paper - https://arxiv.org/abs/2009.14794 58 | -------------------------------------------------------------------------------- /GPT_1.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SelfAttention(nn.Module): 6 | def __init__(self, input_dims, heads): 7 | super(SelfAttention, self).__init__() 8 | self.heads = heads 9 | self.head_dims = int(input_dims/heads) 10 | self.input_dims = input_dims 11 | 12 | self.query = nn.Linear(self.head_dims, self.head_dims) 13 | self.key = nn.Linear(self.head_dims, self.head_dims) 14 | self.value = nn.Linear(self.head_dims, self.head_dims) 15 | self.fc = nn.Linear(self.head_dims*heads, self.input_dims) 16 | 17 | def forward(self, query, key, value, mask): 18 | Batch, Seq_len, embed = query.shape 19 | query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1] 20 | 21 | query = query.reshape(Batch, query_len, self.heads, self.head_dims) 22 | key = key.reshape(Batch, key_len, self.heads, self.head_dims) 23 | value = value.reshape(Batch, value_len, self.heads, self.head_dims) 24 | 25 | query = self.query(query) 26 | key = self.key(key) 27 | value = self.value(value) 28 | 29 | score = torch.einsum('bqhd,bkhd->bhqk', [query, key]) 30 | if mask is not None: 31 | score = score.masked_fill(mask == 0, float('-1e20')) 32 | 33 | attention_score = nn.Softmax(dim=-1)(score/((self.head_dims)**(1/2))) 34 | out = torch.einsum('bhqv,bvhd->bqhd', [attention_score, value]).reshape(Batch, query_len, self.head_dims*self.heads) 35 | out = self.fc(out) 36 | 37 | return out 38 | 39 | 40 | class GPTBlock(nn.Module): 41 | def __init__( 42 | self, 43 | heads, 44 | embedding_dims, 45 | dropout, 46 | forward_expansion, 47 |
layer_norm_eps 48 | ): 49 | super(GPTBlock, self).__init__() 50 | self.embedding_dims = embedding_dims 51 | self.attention = SelfAttention(embedding_dims, heads) 52 | self.layer_norm1 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 53 | self.layer_norm2 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 54 | self.feed_forward = nn.Sequential( 55 | *[ 56 | nn.Linear(embedding_dims, embedding_dims*forward_expansion), 57 | nn.GELU(), 58 | nn.Linear(embedding_dims*forward_expansion, embedding_dims) 59 | ] 60 | ) 61 | self.dropout = nn.Dropout(dropout) 62 | 63 | def forward(self, x, mask): 64 | attention_block = self.attention(x, x, x, mask) 65 | add = self.dropout(self.layer_norm1(attention_block + x)) 66 | feed_forward = self.feed_forward(add) 67 | out = self.dropout(self.layer_norm2(feed_forward + add)) 68 | return out 69 | 70 | 71 | class GPT(nn.Module): 72 | def __init__( 73 | self, 74 | vocab_size, 75 | embedding_dims, 76 | dropout, 77 | heads, 78 | num_of_layers, 79 | forward_expansion, 80 | max_len, 81 | layer_norm_eps = 1e-5 82 | ): 83 | super(GPT, self).__init__() 84 | self.embedding_dims = embedding_dims 85 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dims) 86 | self.positional_embeddings = nn.Parameter(torch.zeros(1, max_len, embedding_dims)) 87 | self.dropout = nn.Dropout(dropout) 88 | self.gpt_blocks = nn.ModuleList( 89 | [ 90 | GPTBlock( 91 | heads, 92 | embedding_dims, 93 | dropout, 94 | forward_expansion, 95 | layer_norm_eps 96 | 97 | ) 98 | for _ in range(num_of_layers) 99 | ] 100 | ) 101 | 102 | self.layer_norm = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 103 | self.fc = nn.Linear(embedding_dims, vocab_size) 104 | 105 | self.apply(self._init_weights) 106 | 107 | #From @HuggingFace 108 | def _init_weights(self, module): 109 | if isinstance(module, (nn.Linear, nn.Embedding)): 110 | module.weight.data.normal_(mean=0.0, std=0.02) 111 | 112 | elif isinstance(module, nn.LayerNorm): 113 | module.bias.data.zero_() 114 | module.weight.data.fill_(1.0) 115 | 116 | if isinstance(module, nn.Linear) and module.bias is not None: 117 | module.bias.data.zero_() 118 | 119 | def casual_mask(self, x): 120 | mask = torch.tril(torch.ones((x.shape[0], x.shape[-1], x.shape[-1]))).unsqueeze(1) 121 | return mask 122 | 123 | def forward(self, x): 124 | casual_mask = self.casual_mask(x) 125 | seq_len = x.shape[-1] 126 | word_embeddings = self.word_embeddings(x) 127 | x = self.dropout(word_embeddings + self.positional_embeddings[:, :seq_len, :]) 128 | for block in self.gpt_blocks: 129 | x = block(x, casual_mask) 130 | x = self.layer_norm(x) 131 | out = self.fc(x) 132 | return out 133 | 134 | 135 | if __name__ == '__main__': 136 | #DEFAULT GPT PARAMETERS :- 137 | vocab_size = 40478 138 | embedding_dims = 768 139 | dropout = 0.1 140 | heads = 12 141 | num_of_layers = 12 142 | forward_expansion = 4 143 | max_len = 512 144 | 145 | 146 | a = torch.randint(1, 100, (1, 300)) 147 | model = GPT( 148 | vocab_size, 149 | embedding_dims, 150 | dropout, 151 | heads, 152 | num_of_layers, 153 | forward_expansion, 154 | max_len, 155 | ) 156 | 157 | start = time() 158 | y = model(a) 159 | print(f'INFERENCE TIME = {time() - start}sec') 160 | x = sum(p.numel() for p in model.parameters() if p.requires_grad) 161 | print(f'NUMBER OF PARAMETERS ARE = {x}') 162 | -------------------------------------------------------------------------------- /VisionImageTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 |
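# Shape walk-through for the demo configuration at the bottom of this file (32x32 RGB images, 16x16 patches):
#   images:                          (B, 3, 32, 32)
#   unfold H and W with patch 16  -> (B, 3, 2, 2, 16, 16)
#   permute + flatten each patch  -> (B, 2, 2, 3*16*16) = (B, 2, 2, 768)
#   collapse the 2x2 patch grid   -> (B, 4, 768), i.e. max_len = (32*32)/(16*16) = 4 patch tokens
#   prepend the CLS token         -> (B, 5, 768), plus positional embeddings with max_len + 1 positions
# Note: patch_embeddings is nn.Linear(embedding_dims, embedding_dims), so the flattened patch size
# (3 * patch_height * patch_width) is expected to equal embedding_dims (768 here).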
class SelfAttention(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_dims, 9 | heads, 10 | dropout 11 | ): 12 | super(SelfAttention, self).__init__() 13 | self.heads = heads 14 | self.embedding_dims = embedding_dims 15 | self.head_dims = int(embedding_dims/heads) 16 | 17 | self.key = nn.Linear(self.head_dims, self.head_dims) 18 | self.query = nn.Linear(self.head_dims, self.head_dims) 19 | self.value = nn.Linear(self.head_dims, self.head_dims) 20 | 21 | self.fc = nn.Linear(self.head_dims*self.heads, self.embedding_dims) 22 | 23 | self.dropout = nn.Dropout(dropout) 24 | 25 | def forward(self, query, key, value, mask): 26 | Batch = query.shape[0] 27 | 28 | query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1] 29 | 30 | query = query.reshape(Batch, query_len, self.heads, self.head_dims) 31 | key = key.reshape(Batch, key_len, self.heads, self.head_dims) 32 | value = value.reshape(Batch, value_len, self.heads, self.head_dims) 33 | 34 | query = self.query(query) 35 | key = self.key(key) 36 | value = self.value(value) 37 | 38 | attention_score = torch.einsum('bqhd,bkhd->bhqk', [query, key]) 39 | 40 | if mask is not None: 41 | attention_score = attention_score.masked_fill(mask==0, float('-1e20')) 42 | 43 | attention_score = attention_score/((self.head_dims)**(1/2)) 44 | attention_score = torch.softmax(attention_score, dim=-1) 45 | 46 | out = torch.einsum('bhqv,bvhd->bqhd', [attention_score, value]).reshape( 47 | Batch, query_len, self.heads*self.head_dims 48 | ) 49 | 50 | out = self.dropout(self.fc(out)) 51 | 52 | return out 53 | 54 | 55 | 56 | class TransformerBlock(nn.Module): 57 | def __init__( 58 | self, 59 | embedding_dims, 60 | heads, 61 | dropout, 62 | forward_expansion, 63 | layer_norm_eps 64 | ): 65 | super(TransformerBlock, self).__init__() 66 | self.layer_norm1 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 67 | self.layer_norm2 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 68 | self.attention = SelfAttention(embedding_dims, heads, dropout) 69 | self.feed_forward = nn.Sequential( 70 | nn.Linear(embedding_dims, embedding_dims*forward_expansion), 71 | nn.GELU(), 72 | nn.Dropout(dropout), 73 | nn.Linear(embedding_dims*forward_expansion, embedding_dims), 74 | nn.Dropout(dropout) 75 | ) 76 | self.dropout = nn.Dropout(dropout) 77 | 78 | def forward(self, x, mask): 79 | norm = self.layer_norm1(x) 80 | attention_block = self.attention(norm, norm, norm, mask) 81 | add = x + attention_block 82 | norm = self.layer_norm2(add) 83 | feed_forward = self.feed_forward(norm) 84 | out = feed_forward + add 85 | return out 86 | 87 | 88 | class ViT(nn.Module): 89 | def __init__( 90 | self, 91 | patch_height, 92 | patch_width, 93 | max_len, 94 | embedding_dims, 95 | heads, 96 | forward_expansion, 97 | num_layers, 98 | dropout, 99 | layer_norm_eps, 100 | num_classes 101 | ): 102 | super(ViT, self).__init__() 103 | 104 | self.vit_blocks = nn.Sequential( 105 | *[ 106 | TransformerBlock( 107 | embedding_dims, 108 | heads, 109 | dropout, 110 | forward_expansion, 111 | layer_norm_eps 112 | ) 113 | for _ in range(num_layers) 114 | ] 115 | 116 | ) 117 | self.patch_height = patch_height 118 | self.patch_width = patch_width 119 | self.cls_embedding = nn.Parameter(torch.zeros(1, 1, embedding_dims)) 120 | self.patch_embeddings = nn.Linear(embedding_dims, embedding_dims) 121 | self.postional_embedding = nn.Parameter(torch.zeros(1, max_len+1, embedding_dims)) 122 | self.to_cls_token = nn.Identity() 123 | self.classifier = nn.Sequential( 124 | nn.LayerNorm(embedding_dims), 125 | 
nn.Linear(embedding_dims, num_classes*4), 126 | nn.GELU(), 127 | nn.Dropout(dropout), 128 | nn.Linear(num_classes*4, num_classes) 129 | ) 130 | self.dropout = nn.Dropout(dropout) 131 | 132 | 133 | def forward(self, images): 134 | patches = images.unfold(2, self.patch_height, self.patch_width).unfold(3, self.patch_height, self.patch_width) 135 | patches = patches.permute(0, 2, 3, 1, 4, 5) 136 | patches = patches.reshape( 137 | patches.shape[0], 138 | patches.shape[1], 139 | patches.shape[2], 140 | patches.shape[3]*patches.shape[4]*patches.shape[5] 141 | ) 142 | patches = patches.view(patches.shape[0], -1, patches.shape[-1]) 143 | 144 | x = self.cls_embedding.expand(patches.shape[0], -1, -1) 145 | patch_embeddings = self.patch_embeddings(patches) 146 | x = torch.cat((x, patch_embeddings), dim=1) + self.postional_embedding 147 | x = self.dropout(x) 148 | mask = None 149 | for block in self.vit_blocks: 150 | x = block(x, mask) 151 | out = self.to_cls_token(x[:, 0]) 152 | out = self.classifier(out) 153 | return out 154 | 155 | 156 | 157 | if __name__ == "__main__": 158 | 159 | model = ViT( 160 | patch_height = 16, 161 | patch_width = 16, 162 | embedding_dims = 768, 163 | dropout = 0.1, 164 | heads = 4, 165 | num_layers = 4, 166 | forward_expansion = 4, 167 | max_len = int((32*32)/(16*16)), 168 | layer_norm_eps = 1e-5, 169 | num_classes = 10, 170 | ) 171 | 172 | a = torch.randn(32, 3, 32, 32) 173 | output = model(a) 174 | print(output.shape) -------------------------------------------------------------------------------- /GPT_2.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SelfAttention(nn.Module): 6 | def __init__(self, input_dims, heads): 7 | super(SelfAttention, self).__init__() 8 | self.heads = heads 9 | self.head_dims = int(input_dims/heads) 10 | self.input_dims = input_dims 11 | 12 | self.expand = nn.Linear(self.input_dims, self.input_dims*3) 13 | self.fc = nn.Linear(self.head_dims*heads, self.input_dims) 14 | 15 | def split_(self, x): 16 | query, key, value = x.split(self.input_dims, dim=-1) 17 | return query, key, value 18 | 19 | def forward(self, x, mask, past): 20 | Batch, seq_len, embed = x.shape 21 | expand = self.expand(x) 22 | query, key, value = self.split_(expand) 23 | 24 | query = query.reshape(Batch, seq_len, self.heads, self.head_dims) 25 | key = key.reshape(Batch, seq_len, self.heads, self.head_dims) 26 | value = value.reshape(Batch, seq_len, self.heads, self.head_dims) 27 | 28 | present = torch.cat((key.unsqueeze(0), value.unsqueeze(0)), dim=0) 29 | 30 | if past is not None: 31 | past_key, past_value = past 32 | key = torch.cat((past_key, key), dim=1) 33 | value = torch.cat((past_value, value), dim=1) 34 | 35 | score = torch.einsum('bqhd,bkhd->bhqk', [query, key]) 36 | if mask is not None: 37 | score = score.masked_fill(mask == 0, float('-1e20')) 38 | 39 | attention_score = nn.Softmax(dim=-1)(score/((self.head_dims)**(1/2))) 40 | out = torch.einsum('bhqv,bvhd->bqhd', [attention_score, value]).reshape(Batch, seq_len, self.head_dims*self.heads) 41 | out = self.fc(out) 42 | 43 | return out, present 44 | 45 | 46 | class GPTBlock(nn.Module): 47 | def __init__( 48 | self, 49 | heads, 50 | embedding_dims, 51 | dropout, 52 | forward_expansion, 53 | layer_norm_eps 54 | ): 55 | super(GPTBlock, self).__init__() 56 | self.embedding_dims = embedding_dims 57 | self.attention = SelfAttention(embedding_dims, heads) 58 | self.layer_norm1 = nn.LayerNorm(embedding_dims, 
eps=layer_norm_eps) 59 | self.layer_norm2 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 60 | self.feed_forward = nn.Sequential( 61 | *[ 62 | nn.Linear(embedding_dims, embedding_dims*forward_expansion), 63 | nn.GELU(), 64 | nn.Linear(embedding_dims*forward_expansion, embedding_dims) 65 | ] 66 | ) 67 | self.dropout = nn.Dropout(dropout) 68 | 69 | def forward(self, x, mask, past): 70 | attention_block, present = self.attention(self.layer_norm1(x), mask, past) 71 | add = self.dropout(self.layer_norm2(attention_block + x)) 72 | feed_forward = self.feed_forward(add) 73 | out = self.dropout(feed_forward + add) 74 | return out, present 75 | 76 | 77 | class GPT2(nn.Module): 78 | def __init__( 79 | self, 80 | vocab_size, 81 | embedding_dims, 82 | dropout, 83 | heads, 84 | num_of_layers, 85 | forward_expansion, 86 | max_len, 87 | layer_norm_eps = 1e-5 88 | ): 89 | super(GPT2, self).__init__() 90 | self.embedding_dims = embedding_dims 91 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dims) 92 | self.positional_embeddings = nn.Parameter(torch.zeros(1, max_len, embedding_dims)) 93 | self.dropout = nn.Dropout(dropout) 94 | self.gpt_blocks = nn.ModuleList( 95 | [ 96 | GPTBlock( 97 | heads, 98 | embedding_dims, 99 | dropout, 100 | forward_expansion, 101 | layer_norm_eps 102 | 103 | ) 104 | for _ in range(num_of_layers) 105 | ] 106 | ) 107 | 108 | self.fc = nn.Linear(embedding_dims, vocab_size) 109 | 110 | self.apply(self._init_weights) 111 | 112 | #From @HuggingFace 113 | def _init_weights(self, module): 114 | if isinstance(module, (nn.Linear, nn.Embedding)): 115 | module.weight.data.normal_(mean=0.0, std=0.02) 116 | 117 | elif isinstance(module, nn.LayerNorm): 118 | module.bias.data.zero_() 119 | module.weight.data.fill_(1.0) 120 | 121 | if isinstance(module, nn.Linear) and module.bias is not None: 122 | module.bias.data.zero_() 123 | 124 | 125 | def casual_mask(self, x, past): 126 | ones_matix = torch.ones((x.shape[-1], x.shape[-1])) 127 | mask = torch.tril(ones_matix) 128 | if past is not None: 129 | mask = torch.cat((ones_matix, mask), dim=1) 130 | mask = mask.unsqueeze(0).unsqueeze(1) 131 | return mask 132 | 133 | 134 | def forward(self, x, past=None): 135 | casual_mask = self.casual_mask(x, past) 136 | seq_len = x.shape[-1] 137 | word_embeddings = self.word_embeddings(x) 138 | x = self.dropout(word_embeddings + self.positional_embeddings[:, :seq_len, :]) 139 | presents = [] 140 | past_layer = None 141 | for num, block in enumerate(self.gpt_blocks): 142 | if past is not None: 143 | past_layer = past[num] 144 | x, present = block(x, casual_mask, past_layer) 145 | presents.append(present) 146 | return self.fc(x), presents 147 | 148 | 149 | if __name__ == '__main__': 150 | #DEFAULT GPT-2 PARAMETERS :- 151 | vocab_size = 50257 152 | embedding_dims = 768 153 | dropout = 0.1 154 | heads = 12 155 | num_of_layers = 12 156 | forward_expansion = 4 157 | max_len = 1024 158 | 159 | 160 | a = torch.randint(1, 100, (3, 300)) 161 | model = GPT2( 162 | vocab_size, 163 | embedding_dims, 164 | dropout, 165 | heads, 166 | num_of_layers, 167 | forward_expansion, 168 | max_len, 169 | ) 170 | 171 | start = time() 172 | past_key_value = None 173 | for i in range(2): 174 | y, past_key_value = model(a, past_key_value) 175 | print(f'INFERENCE TIME = {time() - start}sec') 176 | x = sum(p.numel() for p in model.parameters() if p.requires_grad) 177 | print(f'NUMBER OF PARAMETERS ARE = {x}') 178 | -------------------------------------------------------------------------------- /BERT.py:
-------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SelfAttention(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dims, 10 | heads 11 | ): 12 | super(SelfAttention, self).__init__() 13 | self.heads = heads 14 | self.embedding_dims = embedding_dims 15 | self.head_dims = int(embedding_dims/heads) 16 | 17 | self.key = nn.Linear(self.head_dims, self.head_dims) 18 | self.query = nn.Linear(self.head_dims, self.head_dims) 19 | self.value = nn.Linear(self.head_dims, self.head_dims) 20 | 21 | self.fc = nn.Linear(self.head_dims*self.heads, self.embedding_dims) 22 | 23 | def forward(self, query, key, value, mask): 24 | Batch = query.shape[0] 25 | 26 | query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1] 27 | 28 | query = query.reshape(Batch, query_len, self.heads, self.head_dims) 29 | key = key.reshape(Batch, key_len, self.heads, self.head_dims) 30 | value = value.reshape(Batch, value_len, self.heads, self.head_dims) 31 | 32 | query = self.query(query) 33 | key = self.key(key) 34 | value = self.value(value) 35 | 36 | attention_score = torch.einsum('bqhd,bkhd->bhqk', [query, key]) 37 | 38 | if mask is not None: 39 | attention_score = attention_score.masked_fill(mask==0, float('-1e20')) 40 | 41 | attention_score = attention_score/((self.head_dims)**(1/2)) 42 | attention_score = torch.softmax(attention_score, dim=-1) 43 | 44 | out = torch.einsum('bhqv,bvhd->bqhd', [attention_score, value]).reshape( 45 | Batch, query_len, self.heads*self.head_dims 46 | ) 47 | 48 | out = self.fc(out) 49 | 50 | return out 51 | 52 | 53 | 54 | class BertBlock(nn.Module): 55 | def __init__( 56 | self, 57 | embedding_dims, 58 | heads, 59 | dropout, 60 | forward_expansion, 61 | layer_norm_eps 62 | ): 63 | super(BertBlock, self).__init__() 64 | self.layer_norm1 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 65 | self.layer_norm2 = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 66 | self.attention = SelfAttention(embedding_dims, heads) 67 | self.feed_forward = nn.Sequential( 68 | nn.Linear(embedding_dims, embedding_dims*forward_expansion), 69 | nn.GELU(), 70 | nn.Linear(embedding_dims*forward_expansion, embedding_dims) 71 | ) 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | def forward(self, x, mask): 75 | attention_block = self.attention(x, x, x, mask) 76 | add = self.dropout(self.layer_norm1(x + attention_block)) 77 | feed_forward = self.feed_forward(add) 78 | out = self.dropout(self.layer_norm2(feed_forward + add)) 79 | return out 80 | 81 | 82 | 83 | class Embeddings(nn.Module): 84 | def __init__( 85 | self, 86 | vocab_size, 87 | max_len, 88 | embedding_dims 89 | ): 90 | super(Embeddings, self).__init__() 91 | self.word_embeddings = nn.Embedding(vocab_size, embedding_dims) 92 | self.positional_embeddings = nn.Parameter( 93 | torch.zeros(1, max_len, embedding_dims) 94 | ) 95 | self.segment_embeddings = nn.Embedding(3, embedding_dims) 96 | 97 | def forward(self, x, segment_x): 98 | sentence_len = x.shape[1] 99 | word_embeddings = self.word_embeddings(x) 100 | positional_embeddings = self.positional_embeddings[:, :sentence_len, :] 101 | segment_embeddings = self.segment_embeddings(segment_x) 102 | return word_embeddings + positional_embeddings + segment_embeddings 103 | 104 | 105 | 106 | class BERT(nn.Module): 107 | def __init__( 108 | self, 109 | vocab_size, 110 | max_len, 111 | mask_idx, 112 | embedding_dims, 113 | heads, 114 | forward_expansion, 115 | num_layers, 116 | 
dropout, 117 | layer_norm_eps 118 | ): 119 | super(BERT, self).__init__() 120 | self.embedding = Embeddings( 121 | vocab_size, 122 | max_len, 123 | embedding_dims 124 | ) 125 | 126 | self.bert_blocks = nn.Sequential( 127 | *[ 128 | BertBlock( 129 | embedding_dims, 130 | heads, 131 | dropout, 132 | forward_expansion, 133 | layer_norm_eps 134 | ) 135 | for _ in range(num_layers) 136 | ] 137 | 138 | ) 139 | 140 | self.layer_norm = nn.LayerNorm(embedding_dims, eps=layer_norm_eps) 141 | self.fc = nn.Linear(embedding_dims, vocab_size) 142 | self.dropout = nn.Dropout(dropout) 143 | self.mask_idx = mask_idx 144 | 145 | self.apply(self._init_weight) 146 | 147 | # @hugging_face 148 | def _init_weight(self, module): 149 | if isinstance(module, (nn.Linear, nn.Embedding)): 150 | module.weight.data.normal_(mean=0.0, std=0.02) 151 | 152 | elif isinstance(module, nn.LayerNorm): 153 | module.weight.data.fill_(1.0) 154 | 155 | if isinstance(module, nn.Linear) and module.bias is not None: 156 | module.bias.data.zero_() 157 | 158 | def create_mask(self, mask): 159 | mask = (mask != self.mask_idx).unsqueeze(1).unsqueeze(2) 160 | return mask 161 | 162 | def forward(self, x, segment_x, mask): 163 | mask = self.create_mask(mask) 164 | x = self.dropout(self.embedding(x, segment_x)) 165 | for block in self.bert_blocks: 166 | x = block(x, mask) 167 | return x 168 | 169 | 170 | if __name__ == '__main__': 171 | #DEFAULT BERT PARAMETERS :- 172 | vocab_size = 30522 173 | embedding_dims = 768 174 | dropout = 0.1 175 | heads = 12 176 | num_layers = 12 177 | forward_expansion = 4 178 | max_len = 512 179 | layer_norm_eps = 1e-12 180 | mask_idx = 0 181 | 182 | x = torch.randint(1, 100, (32, 100)) 183 | x_segment = torch.randint(0, 2, (32, 100)) 184 | 185 | model = BERT( 186 | vocab_size, 187 | max_len, 188 | mask_idx, 189 | embedding_dims, 190 | heads, 191 | forward_expansion, 192 | num_layers, 193 | dropout, 194 | layer_norm_eps 195 | ) 196 | 197 | mask = torch.randint(0, 2, (32, 100)) 198 | start = time() 199 | y = model(x, x_segment, mask) 200 | print(f'INFERENCE TIME = {time() - start}sec') 201 | x = sum(p.numel() for p in model.parameters() if p.requires_grad) 202 | print(f'NUMBER OF PARAMETERS ARE = {x}') 203 | -------------------------------------------------------------------------------- /TRANSFORMERS.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Attention(nn.Module): 6 | def __init__(self, input_shape, head): 7 | super(Attention, self).__init__() 8 | self.head = head 9 | self.input_shape = input_shape 10 | self.head_dims = int(input_shape // head) 11 | 12 | self.query = nn.Linear(self.head_dims, self.head_dims) 13 | self.key = nn.Linear(self.head_dims, self.head_dims) 14 | self.value = nn.Linear(self.head_dims, self.head_dims) 15 | self.fc = nn.Linear(self.head_dims*head, input_shape) 16 | 17 | def forward(self, query, key, value, mask=None): 18 | batch = query.shape[0] 19 | query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1] 20 | 21 | query = query.reshape(batch, query_len, self.head, self.head_dims) 22 | key = key.reshape(batch, key_len, self.head, self.head_dims) 23 | value = value.reshape(batch, value_len, self.head, self.head_dims) 24 | 25 | query = self.query(query) 26 | key = self.key(key) 27 | value = self.value(value) 28 | 29 | score = torch.einsum("bqhd,bkhd->bhqk", [query, key]) 30 | 31 | if mask is not None: 32 | score = score.masked_fill(mask == 0, float("-1e20")) 33
| score = torch.softmax(score/((self.head_dims)**(1/2)), dim=-1) 34 | 35 | out = torch.einsum("bhqv,bvhd->bqhd", [score, value]) 36 | out = out.reshape(batch, query_len, self.head*self.head_dims) 37 | out = self.fc(out) 38 | 39 | return out 40 | 41 | 42 | 43 | class TransformerBlock(nn.Module): 44 | def __init__(self, input_shape, head, dropout, forward_expansion): 45 | super(TransformerBlock, self).__init__() 46 | self.attention = Attention(input_shape, head) 47 | self.feed_forward = nn.Sequential( 48 | nn.Linear(input_shape, input_shape*forward_expansion), 49 | nn.GELU(), 50 | nn.Linear(input_shape*forward_expansion, input_shape) 51 | ) 52 | self.layernorm1 = nn.LayerNorm(input_shape) 53 | self.layernorm2 = nn.LayerNorm(input_shape) 54 | 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | def forward(self, query, key, value, mask): 58 | attention = self.attention(query, key, value, mask) 59 | add = attention + query 60 | regulazation = self.dropout(self.layernorm1(add)) 61 | forward = self.feed_forward(regulazation) 62 | out = self.dropout(self.layernorm2(forward + regulazation)) 63 | return out 64 | 65 | 66 | 67 | class Encoder(nn.Module): 68 | def __init__( 69 | self, 70 | vocab_size, 71 | embedding_out, 72 | num_layers, 73 | heads, 74 | forward_expansion, 75 | dropout, 76 | max_len 77 | ): 78 | super(Encoder, self).__init__() 79 | self.word_embedding = nn.Embedding(vocab_size, embedding_out) 80 | self.postional_embedding = nn.Parameter(torch.zeros(1, max_len, embedding_out)) 81 | self.dropout = nn.Dropout(dropout) 82 | self.layers = nn.Sequential( 83 | *[ 84 | TransformerBlock( 85 | embedding_out, 86 | heads, 87 | dropout, 88 | forward_expansion 89 | ) 90 | for _ in range(num_layers) 91 | ] 92 | ) 93 | 94 | def forward(self, x, mask): 95 | word_embedding = self.word_embedding(x) 96 | postional_embedding = self.postional_embedding[:, :x.shape[1], :] 97 | out = self.dropout(word_embedding + postional_embedding) 98 | for layer in self.layers: 99 | out = layer(out, out, out, mask) 100 | return out 101 | 102 | 103 | 104 | class DecoderBlock(nn.Module): 105 | def __init__( 106 | self, 107 | embedding_out, 108 | head, 109 | forward_expansion, 110 | dropout 111 | ): 112 | super(DecoderBlock, self).__init__() 113 | self.attention = Attention(embedding_out, head) 114 | self.transformer_block = TransformerBlock( 115 | embedding_out, 116 | head, 117 | dropout, 118 | forward_expansion 119 | ) 120 | self.dropout = nn.Dropout(dropout) 121 | self.norm = nn.LayerNorm(embedding_out) 122 | 123 | def forward(self, query, key, value, src_mask, causal_mask): 124 | attention = self.attention(query, query, query, causal_mask) 125 | query = self.dropout(self.norm(attention + query)) 126 | out = self.transformer_block(query, key, value, src_mask) 127 | return out 128 | 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__( 133 | self, 134 | vocab_size, 135 | embedding_out, 136 | num_layers, 137 | head, 138 | forward_expansion, 139 | dropout, 140 | max_len 141 | ): 142 | super(Decoder, self).__init__() 143 | self.word_embedding = nn.Embedding(vocab_size, embedding_out) 144 | self.positional_embedding = nn.Parameter(torch.zeros(1, max_len, embedding_out)) 145 | self.layers = nn.Sequential( 146 | *[ 147 | DecoderBlock( 148 | embedding_out, 149 | head, 150 | forward_expansion, 151 | dropout 152 | ) 153 | for _ in range(num_layers) 154 | ] 155 | ) 156 | self.fc = nn.Linear(embedding_out, vocab_size) 157 | self.dropout = nn.Dropout(dropout) 158 | 159 | def forward(self, x, encoder_output, src_mask, 
casual_mask): 160 | x = self.dropout(self.word_embedding(x) + self.positional_embedding[:, :x.shape[1], :]) 161 | for layer in self.layers: 162 | x = layer( 163 | x, 164 | encoder_output, 165 | encoder_output, 166 | src_mask, 167 | casual_mask 168 | ) 169 | out = self.fc(x) 170 | return out 171 | 172 | 173 | 174 | class Transformers(nn.Module): 175 | def __init__( 176 | self, 177 | input_vocab_size, 178 | output_vocab_size, 179 | pad_idx, 180 | embedding_out, 181 | num_layers, 182 | forward_expansion, 183 | head, 184 | dropout, 185 | max_len 186 | ): 187 | super(Transformers, self).__init__() 188 | self.encoder = Encoder( 189 | input_vocab_size, 190 | embedding_out, 191 | num_layers, 192 | head, 193 | forward_expansion, 194 | dropout, 195 | max_len 196 | ) 197 | 198 | self.decoder = Decoder( 199 | output_vocab_size, 200 | embedding_out, 201 | num_layers, 202 | head, 203 | forward_expansion, 204 | dropout, 205 | max_len 206 | ) 207 | 208 | self.pad_idx = pad_idx 209 | self.apply(self._init_weights) 210 | 211 | #From @HuggingFace 212 | def _init_weights(self, module): 213 | if isinstance(module, (nn.Linear, nn.Embedding)): 214 | module.weight.data.normal_(mean=0.0, std=0.02) 215 | 216 | elif isinstance(module, nn.LayerNorm): 217 | module.weight.data.fill_(1.0) 218 | 219 | if isinstance(module, nn.Linear) and module.bias is not None: 220 | module.bias.data.zero_() 221 | 222 | def pad_mask(self, inputs): 223 | pad_mask = (inputs != self.pad_idx).unsqueeze(1).unsqueeze(2) 224 | return pad_mask 225 | 226 | def causal_mask(self, target): 227 | N, target_len = target.shape 228 | target_mask = torch.tril(torch.ones((N, target_len, target_len))).unsqueeze(1) 229 | return target_mask 230 | 231 | def forward(self, inputs, target): 232 | pad_mask = self.pad_mask(inputs) 233 | causal_mask = self.causal_mask(target) 234 | encoder_output = self.encoder(inputs, pad_mask) 235 | decoder_out = self.decoder(target, encoder_output, pad_mask, causal_mask) 236 | return decoder_out 237 | 238 | 239 | 240 | if __name__ == "__main__": 241 | #Depends on the Tokenizer 242 | input_vocab_size = 100 243 | output_vocab_size = 200 244 | 245 | #DEFAULT TRANSFORMERS PARAMETERS:- 246 | pad_idx = 0 247 | embedding_out = 512 248 | num_layers = 6 249 | forward_expansion = 4 250 | head = 8 251 | dropout = 0.1 252 | max_len = 512 253 | 254 | inputs = torch.randint(0, 100, (32, 200)) 255 | targets = torch.randint(0, 100, (32,100)) 256 | 257 | model = Transformers( 258 | input_vocab_size, 259 | output_vocab_size, 260 | pad_idx, 261 | embedding_out, 262 | num_layers, 263 | forward_expansion, 264 | head, 265 | dropout, 266 | max_len 267 | ) 268 | 269 | start = time() 270 | y = model(inputs, targets) 271 | print(f'INFERENCE TIME = {time() - start}sec') 272 | x = sum(p.numel() for p in model.parameters() if p.requires_grad) 273 | print(f'NUMBER OF PARAMETERS ARE = {x}') 274 | -------------------------------------------------------------------------------- /PERFORMER.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class FastAttention(nn.Module): 7 | def __init__(self, input_shape, head, n_features): 8 | super(FastAttention, self).__init__() 9 | self.head = head 10 | self.input_shape = input_shape 11 | self.depth = int(input_shape // head) 12 | self.n_features = n_features 13 | self.key_ORF = self.OrthogonalRandomFeature() 14 | self.query_ORF = self.OrthogonalRandomFeature() 15 | 16 | self.query = nn.Linear(self.depth, 
self.depth) 17 | self.key = nn.Linear(self.depth, self.depth) 18 | self.value = nn.Linear(self.depth, self.depth) 19 | self.fc = nn.Linear(self.depth*head, input_shape) 20 | 21 | def kernel_function(self, x, flag): 22 | ORF = self.query_ORF if flag == 'query' else self.key_ORF 23 | normalization_factor = 1/ORF.shape[-1]**0.25 24 | x *= normalization_factor 25 | out = torch.einsum('nhsd, fd -> nhsf', x, ORF) 26 | kernel_fn = nn.ReLU()(out) + 1e-3 27 | return kernel_fn 28 | 29 | def OrthogonalRandomFeature(self): 30 | n = self.n_features//self.depth 31 | remainder = self.n_features%self.depth 32 | orthogonal_features = [] 33 | for _ in range(n): 34 | normal_feature = torch.rand(self.depth, self.depth) 35 | orthogonal_feature, _ = torch.qr(normal_feature) 36 | orthogonal_features.append(orthogonal_feature) 37 | 38 | if remainder > 0 : 39 | normal_feature = torch.rand(self.depth, self.depth) 40 | orthogonal_feature, _ = torch.qr(normal_feature) 41 | orthogonal_features.append(orthogonal_feature[0: remainder]) 42 | 43 | orthogonal_features = torch.cat(orthogonal_features) 44 | mutilplier = torch.randn(self.n_features, self.depth).norm(dim=1) 45 | final_features = torch.matmul(torch.diag(mutilplier), orthogonal_features) 46 | 47 | return final_features 48 | 49 | def causal_attention(self, q, k, v): 50 | denominator = 1/torch.einsum('nhqf, nhkf -> nhqf', q, k.cumsum(dim=-2)) 51 | x = torch.einsum('nhkf, nhkd -> nhkfd', k, v) 52 | x = x.cumsum(dim=-3) 53 | out = torch.einsum('nhqfd, nhqf, nhqf -> nhqd', x, q, denominator) 54 | return out 55 | 56 | 57 | def bidirectional_attention(self, q, k, v): 58 | kt_i = torch.einsum('nhkf -> nhf', k) 59 | normalization_factor = 1/(torch.einsum('nhqf, nhf -> nhq', q, kt_i)) 60 | k_v = torch.einsum('nhkf, nhkd -> nhfd', k, v) 61 | attention = torch.einsum('nhfd, nhqf, nhq-> nhqd', k_v, q, normalization_factor) 62 | return attention 63 | 64 | 65 | def forward(self, query, key, value, mask=None, casual_mask=False): 66 | batch = query.shape[0] 67 | query_len, key_len, value_len = query.shape[1], key.shape[1], value.shape[1] 68 | 69 | 70 | query = query.reshape(batch, query_len, self.head, self.depth) 71 | key = key.reshape(batch, key_len, self.head, self.depth) 72 | value = value.reshape(batch, value_len, self.head, self.depth) 73 | 74 | query = query.permute(0, 2, 1, 3) 75 | key = key.permute(0, 2, 1, 3) 76 | value = value.permute(0, 2, 1, 3) 77 | 78 | query = self.query(query) 79 | key = self.key(key) 80 | value = self.value(value) 81 | 82 | if mask is not None: 83 | key = key.masked_fill(mask == 0, float("-1e20")) 84 | 85 | query = self.kernel_function(query, 'query') 86 | key = self.kernel_function(key, 'key') 87 | 88 | if casual_mask: 89 | out = self.causal_attention(query, key, value) 90 | else: 91 | out = self.bidirectional_attention(query, key, value) 92 | 93 | out = out.permute(0, 2, 1, 3) 94 | out = out.reshape(batch, query_len, self.head*self.depth) 95 | out = self.fc(out) 96 | 97 | return out 98 | 99 | 100 | 101 | class PerformerBlock(nn.Module): 102 | def __init__(self, input_shape, head, n_features, dropout, forward_expansion): 103 | super(PerformerBlock, self).__init__() 104 | self.attention = FastAttention(input_shape, head, n_features) 105 | self.feed_forward = nn.Sequential( 106 | nn.Linear(input_shape, input_shape*forward_expansion), 107 | nn.GELU(), 108 | nn.Linear(input_shape*forward_expansion, input_shape) 109 | ) 110 | self.layernorm1 = nn.LayerNorm(input_shape) 111 | self.layernorm2 = nn.LayerNorm(input_shape) 112 | 113 | self.dropout =
nn.Dropout(dropout) 114 | 115 | def forward(self, query, key, value, mask): 116 | attention = self.attention(query, key, value, mask) 117 | add = attention + query 118 | regulazation = self.dropout(self.layernorm1(add)) 119 | forward = self.feed_forward(regulazation) 120 | out = self.dropout(self.layernorm2(forward + regulazation)) 121 | return out 122 | 123 | 124 | 125 | class Encoder(nn.Module): 126 | def __init__( 127 | self, 128 | vocab_size, 129 | embedding_out, 130 | num_layers, 131 | heads, 132 | n_features, 133 | forward_expansion, 134 | dropout, 135 | max_len 136 | ): 137 | super(Encoder, self).__init__() 138 | self.word_embedding = nn.Embedding(vocab_size, embedding_out) 139 | self.postional_embedding = nn.Parameter(torch.zeros(1, max_len, embedding_out)) 140 | self.dropout = nn.Dropout(dropout) 141 | self.layers = nn.Sequential( 142 | *[ 143 | PerformerBlock( 144 | embedding_out, 145 | heads, 146 | n_features, 147 | dropout, 148 | forward_expansion 149 | ) 150 | for _ in range(num_layers) 151 | ] 152 | ) 153 | 154 | def forward(self, x, mask): 155 | word_embedding = self.word_embedding(x) 156 | postional_embedding = self.postional_embedding[:, :x.shape[1], :] 157 | out = self.dropout(word_embedding + postional_embedding) 158 | for layer in self.layers: 159 | out = layer(out, out, out, mask) 160 | return out 161 | 162 | 163 | 164 | class DecoderBlock(nn.Module): 165 | def __init__( 166 | self, 167 | embedding_out, 168 | head, 169 | n_features, 170 | forward_expansion, 171 | dropout 172 | ): 173 | super(DecoderBlock, self).__init__() 174 | self.attention = FastAttention(embedding_out, head, n_features) 175 | self.Performer_block = PerformerBlock( 176 | embedding_out, 177 | head, 178 | n_features, 179 | dropout, 180 | forward_expansion 181 | ) 182 | self.dropout = nn.Dropout(dropout) 183 | self.norm = nn.LayerNorm(embedding_out) 184 | 185 | def forward(self, query, key, value, src_mask): 186 | attention = self.attention(query, query, query, src_mask, True) 187 | query = self.dropout(self.norm(attention + query)) 188 | out = self.Performer_block(query, key, value, src_mask) 189 | return out 190 | 191 | 192 | 193 | class Decoder(nn.Module): 194 | def __init__( 195 | self, 196 | vocab_size, 197 | embedding_out, 198 | num_layers, 199 | head, 200 | n_features, 201 | forward_expansion, 202 | dropout, 203 | max_len 204 | ): 205 | super(Decoder, self).__init__() 206 | self.word_embedding = nn.Embedding(vocab_size, embedding_out) 207 | self.positional_embedding = nn.Parameter(torch.zeros(1, max_len, embedding_out)) 208 | self.layers = nn.Sequential( 209 | *[ 210 | DecoderBlock( 211 | embedding_out, 212 | head, 213 | n_features, 214 | forward_expansion, 215 | dropout 216 | ) 217 | for _ in range(num_layers) 218 | ] 219 | ) 220 | self.fc = nn.Linear(embedding_out, vocab_size) 221 | self.dropout = nn.Dropout(dropout) 222 | 223 | def forward(self, x, encoder_output, src_mask): 224 | x = self.dropout(self.word_embedding(x) + self.positional_embedding[:, :x.shape[1], :]) 225 | for layer in self.layers: 226 | x = layer( 227 | x, 228 | encoder_output, 229 | encoder_output, 230 | src_mask 231 | ) 232 | out = self.fc(x) 233 | return out 234 | 235 | 236 | 237 | class Performers(nn.Module): 238 | def __init__( 239 | self, 240 | input_vocab_size, 241 | output_vocab_size, 242 | pad_idx, 243 | embedding_out, 244 | num_layers, 245 | forward_expansion, 246 | head, 247 | n_features, 248 | dropout, 249 | max_len 250 | ): 251 | super(Performers, self).__init__() 252 | self.encoder = Encoder( 253 | 
input_vocab_size, 254 | embedding_out, 255 | num_layers, 256 | head, 257 | n_features, 258 | forward_expansion, 259 | dropout, 260 | max_len 261 | ) 262 | 263 | self.decoder = Decoder( 264 | output_vocab_size, 265 | embedding_out, 266 | num_layers, 267 | head, 268 | n_features, 269 | forward_expansion, 270 | dropout, 271 | max_len 272 | ) 273 | 274 | self.pad_idx = pad_idx 275 | self.apply(self._init_weights) 276 | 277 | #From @HuggingFace 278 | def _init_weights(self, module): 279 | if isinstance(module, (nn.Linear, nn.Embedding)): 280 | module.weight.data.normal_(mean=0.0, std=0.02) 281 | 282 | elif isinstance(module, nn.LayerNorm): 283 | module.weight.data.fill_(1.0) 284 | 285 | if isinstance(module, nn.Linear) and module.bias is not None: 286 | module.bias.data.zero_() 287 | 288 | def input_pad_mask(self, inputs): 289 | pad_mask = (inputs != self.pad_idx).unsqueeze(1).unsqueeze(3) 290 | return pad_mask 291 | 292 | def output_pad_mask(self, targets): 293 | pad_mask = (targets != self.pad_idx).unsqueeze(1).unsqueeze(3) 294 | 295 | def forward(self, inputs, target): 296 | input_pad_mask = self.input_pad_mask(inputs) 297 | output_pad_mask = self.output_pad_mask(target) 298 | encoder_output = self.encoder(inputs, input_pad_mask) 299 | decoder_out = self.decoder(target, encoder_output, output_pad_mask) 300 | return decoder_out 301 | 302 | 303 | 304 | if __name__ == "__main__": 305 | #Depends on the Tokenizer 306 | input_vocab_size = 100 307 | output_vocab_size = 200 308 | 309 | #DEFAULT PERFORMER PARAMETERS:- 310 | pad_idx = 0 311 | embedding_out = 512 312 | num_layers = 6 313 | forward_expansion = 4 314 | head = 8 315 | n_features = 256 316 | dropout = 0.1 317 | max_len = 512 318 | 319 | inputs = torch.randint(0, 100, (32, 200)) 320 | targets = torch.randint(0, 100, (32,100)) 321 | 322 | model = Performers( 323 | input_vocab_size, 324 | output_vocab_size, 325 | pad_idx, 326 | embedding_out, 327 | num_layers, 328 | forward_expansion, 329 | head, 330 | n_features, 331 | dropout, 332 | max_len 333 | ) 334 | 335 | start = time() 336 | y = model(inputs, targets) 337 | print(f'INFERENCE TIME = {time() - start}sec') 338 | x = sum(p.numel() for p in model.parameters() if p.requires_grad) 339 | print(f'NUMBER OF PARAMETERS ARE = {x}') --------------------------------------------------------------------------------
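The README above notes that the Performer replaces the Transformer's quadratic-time softmax attention with a linear-time approximation. A rough way to see this with the modules in this repository is to time the two attention layers directly on a long sequence. This is a minimal sketch, not a benchmark: single CPU run, no warm-up, and it assumes TRANSFORMERS.py and PERFORMER.py are importable from the repository root.

```python
import torch
from time import time

from TRANSFORMERS import Attention      # standard softmax attention, O(L^2) in sequence length L
from PERFORMER import FastAttention     # Performer-style random-feature attention, O(L)

embed, heads, n_features = 512, 8, 256
batch, seq_len = 1, 2048

x = torch.randn(batch, seq_len, embed)

softmax_attn = Attention(embed, heads)
fast_attn = FastAttention(embed, heads, n_features)

with torch.no_grad():
    start = time()
    softmax_attn(x, x, x, mask=None)
    print(f'softmax attention  : {time() - start:.3f}s')

    start = time()
    fast_attn(x, x, x, mask=None, casual_mask=False)
    print(f'performer attention: {time() - start:.3f}s')
```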