├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── examples ├── __init__.py ├── __pycache__ │ ├── create_mask.cpython-39.pyc │ ├── masked_self_attention.cpython-39.pyc │ └── mha.cpython-39.pyc ├── create_mask.py ├── mha.py ├── self_attention.py └── toy_examples.py ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── decoder.cpython-37.pyc │ ├── decoder.cpython-38.pyc │ ├── decoder.cpython-39.pyc │ ├── decoder_OLD.cpython-39.pyc │ ├── decoder_layer.cpython-37.pyc │ ├── decoder_layer.cpython-38.pyc │ ├── decoder_layer.cpython-39.pyc │ ├── decoder_layer_OLD.cpython-39.pyc │ ├── efficient_mha.cpython-39.pyc │ ├── embed.cpython-37.pyc │ ├── embed.cpython-38.pyc │ ├── embed.cpython-39.pyc │ ├── encoder.cpython-37.pyc │ ├── encoder.cpython-38.pyc │ ├── encoder.cpython-39.pyc │ ├── encoder_layer.cpython-37.pyc │ ├── encoder_layer.cpython-38.pyc │ ├── encoder_layer.cpython-39.pyc │ ├── mha.cpython-37.pyc │ ├── mha.cpython-38.pyc │ ├── mha.cpython-39.pyc │ ├── positional_encoding.cpython-37.pyc │ ├── positional_encoding.cpython-38.pyc │ ├── positional_encoding.cpython-39.pyc │ ├── pwffn.cpython-37.pyc │ ├── pwffn.cpython-38.pyc │ ├── pwffn.cpython-39.pyc │ ├── residual_layer_norm.cpython-37.pyc │ ├── residual_layer_norm.cpython-38.pyc │ ├── residual_layer_norm.cpython-39.pyc │ ├── transformers.cpython-37.pyc │ ├── transformers.cpython-38.pyc │ └── transformers.cpython-39.pyc ├── decoder.py ├── decoder_layer.py ├── efficient_mha.py ├── embed.py ├── encoder.py ├── encoder_layer.py ├── mha.py ├── positional_encoding.py ├── pwffn.py ├── residual_layer_norm.py └── transformers.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | .data 2 | .vscode 3 | lightning_logs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # transformers-tutorial
2 | The code for the video tutorial series on building a Transformer from scratch: https://www.youtube.com/watch?v=XR4VDnJzB8o
3 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import layers
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feather-ai/transformers-tutorial/64bfae21851333855862552e870a7ada3db99e7b/examples/__init__.py
--------------------------------------------------------------------------------
/examples/create_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def create_mask(size):
5 |     # since this mask is the same for every sequence in a batch being fed into the model,
6 |     # we will fill the mask Tensor with a batch size of 1.
7 |     # Broadcasting will allow us to replicate this mask across all the other elements in the batch
8 |     mask = torch.ones((1, size, size)).triu(1)
9 |     mask = mask == 0
10 |     return mask
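A quick check of what create_mask returns (not one of the repo's files; it assumes being run from inside examples/): True marks positions a query token is allowed to attend to, so everything on or below the diagonal is kept and future positions are masked out.

import torch
from create_mask import create_mask

print(create_mask(4))
# tensor([[[ True, False, False, False],
#          [ True,  True, False, False],
#          [ True,  True,  True, False],
#          [ True,  True,  True,  True]]])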
--------------------------------------------------------------------------------
/examples/mha.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torch.nn as nn
4 | import math as m
5 | import torch.nn.functional as F
6 | 
7 | # %%
8 | class MultiHeadAttention(nn.Module):
9 |     def __init__(self, d_model=4, num_heads=2, dropout=0.3):
10 |         super().__init__()
11 | 
12 |         # d_q, d_k, d_v
13 |         self.d = d_model//num_heads
14 | 
15 |         self.d_model = d_model
16 |         self.num_heads = num_heads
17 | 
18 |         self.dropout = nn.Dropout(dropout)
19 | 
20 |         ## create a list of projection layers for Q, a list for K, and a list for V
21 | 
22 |         self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d)
23 |                                         for _ in range(num_heads)])
24 |         self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d)
25 |                                         for _ in range(num_heads)])
26 |         self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d)
27 |                                         for _ in range(num_heads)])
28 | 
29 |         self.mha_linear = nn.Linear(d_model, d_model)
30 | 
31 |     def scaled_dot_product_attention(self, Q, K, V):
32 |         # shape(Q) = [B x seq_len x D/num_heads] = [B x T x d_k]
33 |         # shape(K, V) = [B x T x d_k]
34 | 
35 |         Q_K_matmul = torch.matmul(Q, K.permute(0, 2, 1))
36 |         scores = Q_K_matmul/m.sqrt(self.d)
37 |         # shape(scores) = [B x seq_len x seq_len]
38 | 
39 |         attention_weights = F.softmax(scores, dim=-1)
40 |         # shape(attention_weights) = [B x seq_len x seq_len]
41 | 
42 |         output = torch.matmul(attention_weights, V)
43 |         # shape(output) = [B x seq_len x D/num_heads]
44 | 
45 |         return output, attention_weights
46 | 
47 |     def forward(self, x):
48 |         # shape(x) = [B x seq_len x D]
49 | 
50 |         Q = [linear_Q(x) for linear_Q in self.linear_Qs]
51 |         K = [linear_K(x) for linear_K in self.linear_Ks]
52 |         V = [linear_V(x) for linear_V in self.linear_Vs]
53 |         # shape(Q, K, V) = [B x seq_len x D/num_heads] * num_heads
54 | 
55 |         output_per_head = []
56 |         attn_weights_per_head = []
57 |         # shape(output_per_head) = [B x seq_len x D/num_heads] * num_heads
58 |         # shape(attn_weights_per_head) = [B x seq_len x seq_len] * num_heads
59 |         for Q_, K_, V_ in zip(Q, K, V):
60 | 
61 |             ## run scaled_dot_product_attention
62 |             output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_)
63 |             # shape(output) = [B x seq_len x D/num_heads]
64 |             # shape(attn_weight) = [B x seq_len x seq_len]
65 |             output_per_head.append(output)
66 |             attn_weights_per_head.append(attn_weight)
67 | 
68 |         output = torch.cat(output_per_head, -1)
69 |         attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3)
70 |         # shape(output) = [B x seq_len x D]
71 |         # shape(attn_weights) = [B x num_heads x seq_len x seq_len]
72 | 
73 |         projection = self.dropout(self.mha_linear(output))
74 | 
75 |         return projection, attn_weights
76 | 
77 | 
--------------------------------------------------------------------------------
/examples/self_attention.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torch.nn.functional as F
4 | import math
5 | import numpy as np
6 | 
7 | # %%
8 | def scaled_dot_product_attention(Q, K, V, dk=4):
9 |     ## matmul Q and K
10 |     QK = torch.matmul(Q, K.T)
11 | 
12 |     ## scale QK by the sqrt of dk
13 |     matmul_scaled = QK / math.sqrt(dk)
14 | 
15 |     attention_weights = F.softmax(matmul_scaled, dim=-1)
16 | 
17 |     ## matmul attention_weights by V
18 |     output = torch.matmul(attention_weights, V)
19 | 
20 |     return output, attention_weights
21 | 
22 | # %%
23 | def print_attention(Q, K, V, n_digits = 3):
24 |     temp_out, temp_attn = scaled_dot_product_attention(Q, K, V)
25 |     temp_out, temp_attn = temp_out.numpy(), temp_attn.numpy()
26 |     print ('Attention weights are:')
27 |     print (np.round(temp_attn, n_digits))
28 |     print()
29 |     print ('Output is:')
30 |     print (np.around(temp_out, n_digits))
31 | 
32 | # %%
33 | temp_k = torch.Tensor([[10,0,0],
34 |                        [0,10,0],
35 |                        [0,0,10],
36 |                        [0,0,10]])  # (4, 3)
37 | 
38 | temp_v = torch.Tensor([[   1,0, 1],
39 |                        [  10,0, 2],
40 |                        [ 100,5, 0],
41 |                        [1000,6, 0]])  # (4, 3)
42 | 
43 | # %%
44 | # This `query` aligns with the second `key`,
45 | # so the second `value` is returned.
46 | temp_q = torch.Tensor([[0, 10, 0]])  # (1, 3)
47 | print_attention(temp_q, temp_k, temp_v)
48 | 
49 | # %%
50 | # This query aligns with a repeated key (third and fourth),
51 | # so all associated values get averaged.
52 | temp_q = torch.Tensor([[0, 0, 10]])  # (1, 3)
53 | print_attention(temp_q, temp_k, temp_v)
54 | 
55 | # %%
56 | # This query aligns equally with the first and second key,
57 | # so their values get averaged.
58 | temp_q = torch.Tensor([[10, 10, 0]])  # (1, 3)
59 | print_attention(temp_q, temp_k, temp_v)
60 | 
61 | # %%
62 | temp_q = torch.Tensor([[0, 10, 0], [0, 0, 10], [10, 10, 0]])  # (3, 3)
63 | print_attention(temp_q, temp_k, temp_v)
--------------------------------------------------------------------------------
/examples/toy_examples.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | from mha import MultiHeadAttention
4 | import os
5 | import sys
6 | import pathlib
7 | 
8 | PACKAGE_PARENT = pathlib.Path.cwd().parent
9 | sys.path.append(str(PACKAGE_PARENT))
10 | 
11 | 
12 | # %%
13 | toy_encodings = torch.Tensor([[[0.0, 0.1, 0.2, 0.3], [1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]]])
14 | # shape(toy_encodings) = [B, T, D] = (1, 3, 4)
15 | print("Toy Encodings:\n", toy_encodings)
16 | 
17 | toy_MHA_layer = MultiHeadAttention(d_model=4, num_heads=2)
18 | toy_MHA, _ = toy_MHA_layer(toy_encodings)
19 | print("Toy MHA: \n", toy_MHA)
20 | print("Toy MHA Shape: \n", toy_MHA.shape)
21 | 
22 | # %%
23 | from layers.residual_layer_norm import ResidualLayerNorm
24 | toy_prev_x = torch.randn(1, 3, 4)
25 | toy_norm_layer = ResidualLayerNorm(d_model=4)
26 | toy_norm = toy_norm_layer(toy_encodings, toy_prev_x)
27 | print("Toy Norm: \n", toy_norm)
28 | print("Toy Norm shape: \n", toy_norm.shape)
29 | 
30 | # %%
31 | from layers.pwffn import PWFFN
32 | toy_PWFFN_layer = PWFFN(d_model=4, d_ff=16)
33 | toy_PWFFN = toy_PWFFN_layer(toy_norm)
34 | print("Toy PWFFN: \n", toy_PWFFN)
35 | print("Toy PWFFN Shape: \n", toy_PWFFN.shape)
36 | 
37 | # %%
38 | from layers.embed import Embeddings
39 | toy_vocab = torch.LongTensor([[1, 2, 3, 4, 0, 0]])
40 | 
41 | toy_embedding_layer = Embeddings(5, padding_idx=0, d_model=4)
42 | toy_embeddings = toy_embedding_layer(toy_vocab)
43 | 
44 | print("Toy Embeddings: \n", toy_embeddings)
45 | print("Toy Embeddings Shape: \n", toy_embeddings.shape)
46 | 
47 | # %%
48 | from layers.positional_encoding import PositionalEncoding
49 | toy_PE_layer = PositionalEncoding(d_model=4)
50 | toy_PEs = toy_PE_layer(toy_embeddings)
51 | 
52 | print("Toy PE: \n", toy_PEs)
53 | print("Toy PE Shape: \n", toy_PEs.shape)
54 | 
55 | print(toy_PE_layer.pe[0, 0])
56 | 
57 | # %%
58 | # from layers.encoder import Encoder
59 | # toy_encoder = Encoder(toy_embedding_layer, 4, 2, 2, 8)
60 | # toy_encoder_output, toy_encoder_attn = toy_encoder(toy_vocab)
61 | 
62 | # print("Toy Encodings: \n", toy_encoder_output)
63 | # print("Toy Encoder Attn Weights: \n", toy_encoder_attn)
64 | # print("Toy Encodings Shape: \n", toy_encoder_output.shape)
65 | # print("Toy Encodings Attn Weights Shape: \n", toy_encoder_attn.shape)
66 | 
67 | # %%
68 | from create_mask import create_mask
69 | toy_mask = create_mask(10)
70 | print("Toy Mask: \n", toy_mask)
71 | print("Toy Mask Shape: \n", toy_mask.shape)
72 | 
73 | # %%
74 | toy_scores = torch.arange(100).reshape(1, 10, 10)
75 | print("Toy Scores: \n", toy_scores)
76 | print("Toy Scores Shape: \n", toy_scores.shape)
77 | 
78 | toy_scores = toy_scores.masked_fill(toy_mask == 0, -1)
79 | print("Toy Scores Masked: \n", toy_scores)
80 | 
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feather-ai/transformers-tutorial/64bfae21851333855862552e870a7ada3db99e7b/layers/__init__.py
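Before the layer implementations, a minimal usage sketch for the Decoder they build up to (not part of the repo; the toy sizes are assumptions and it assumes being run from the repository root):

import torch
from layers.embed import Embeddings
from layers.decoder import Decoder

embedding = Embeddings(vocab_size=10, padding_idx=0, d_model=8)
decoder = Decoder(embedding, d_model=8, num_heads=2, num_layers=2, d_ff=32)

trg = torch.LongTensor([[1, 2, 3, 4, 0, 0]])     # [B x TRG_seq_len]
encoder_output = torch.randn(1, 5, 8)            # [B x SRC_seq_len x D]
trg_mask = torch.ones(1, 6, 6).triu(1) == 0      # causal mask, as in examples/create_mask.py
src_mask = torch.ones(1, 1, 5).bool()            # no source padding in this toy batch

out, self_attn, cross_attn = decoder(trg, encoder_output, trg_mask, src_mask)
print(out.shape)         # torch.Size([1, 6, 8])
print(cross_attn.shape)  # torch.Size([1, 2, 6, 5])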
--------------------------------------------------------------------------------
/layers/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .decoder_layer import DecoderLayer
4 | from .embed import Embeddings
5 | from .positional_encoding import PositionalEncoding
6 | 
7 | 
8 | class Decoder(nn.Module):
9 |     def __init__(self, Embedding: Embeddings, d_model,
10 |                  num_heads, num_layers,
11 |                  d_ff, device="cpu", dropout=0.3, efficient_mha=False):
12 |         super().__init__()
13 | 
14 |         self.embedding = Embedding
15 | 
16 |         self.PE = PositionalEncoding(
17 |             d_model, device=device)
18 | 
19 |         self.dropout = nn.Dropout(dropout)
20 | 
21 |         self.decoders = nn.ModuleList([DecoderLayer(
22 |             d_model,
23 |             num_heads,
24 |             d_ff,
25 |             dropout,
26 |             efficient_mha
27 |         ) for _ in range(num_layers)])
28 | 
29 |     def forward(self, x, encoder_output, trg_mask, src_mask):
30 |         # shape(x) = [B x TRG_seq_len]
31 | 
32 |         embeddings = self.embedding(x)
33 |         encoding = self.PE(embeddings)
34 |         # shape(embeddings) = [B x TRG_seq_len x D]
35 |         # shape(encoding) = [B x TRG_seq_len x D]
36 | 
37 |         for decoder in self.decoders:
38 |             encoding, masked_mha_attn_weights, enc_dec_mha_attn_weights = decoder(encoding, encoder_output, trg_mask, src_mask)
39 |             # shape(encoding) = [B x TRG_seq_len x D]
40 |             # shape(masked_mha_attn_weights) = [B x num_heads x TRG_seq_len x TRG_seq_len]
41 |             # shape(enc_dec_mha_attn_weights) = [B x num_heads x TRG_seq_len x SRC_seq_len]
42 | 
43 |         return encoding, masked_mha_attn_weights, enc_dec_mha_attn_weights
44 | 
--------------------------------------------------------------------------------
/layers/decoder_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .mha import MultiHeadAttention
3 | from .efficient_mha import MultiHeadAttention as EfficientMultiHeadAttention
4 | from .pwffn import PWFFN
5 | from .residual_layer_norm import ResidualLayerNorm
6 | 
7 | 
8 | class DecoderLayer(nn.Module):
9 |     def __init__(self, d_model, num_heads, d_ff, dropout=0.3, efficient_mha=False):
10 |         super().__init__()
11 |         self.norm_1 = ResidualLayerNorm(d_model)
12 |         self.norm_2 = ResidualLayerNorm(d_model)
13 |         self.norm_3 = ResidualLayerNorm(d_model)
14 | 
15 |         if efficient_mha:
16 |             self.masked_mha = EfficientMultiHeadAttention(d_model, num_heads, dropout)
17 |             self.enc_dec_mha = EfficientMultiHeadAttention(d_model, num_heads, dropout)
18 |         else:
19 |             self.masked_mha = MultiHeadAttention(d_model, num_heads, dropout)
20 |             self.enc_dec_mha = MultiHeadAttention(d_model, num_heads, dropout)
21 | 
22 |         self.ff = PWFFN(d_model, d_ff)
23 | 
24 |     def forward(self, x, encoder_outputs, trg_mask, src_mask):
25 |         # shape(x) = [B x TRG_seq_len x D]
26 |         # shape(encoder_outputs) = [B x SRC_seq_len x D]
27 | 
28 |         masked_mha, masked_mha_attn_weights = self.masked_mha(x, x, x, mask=trg_mask)
29 |         # shape(masked_mha) = [B x TRG_seq_len x D]
30 |         # shape(masked_mha_attn_weights) = [B x num_heads x TRG_seq_len x TRG_seq_len]
31 | 
32 |         norm1 = self.norm_1(masked_mha, x)
33 |         # shape(norm1) = [B x TRG_seq_len x D]
34 | 
35 |         enc_dec_mha, enc_dec_mha_attn_weights = self.enc_dec_mha(norm1, encoder_outputs, encoder_outputs, mask=src_mask)
36 |         # shape(enc_dec_mha) = [B x TRG_seq_len x D]
37 |         # shape(enc_dec_mha_attn_weights) = [B x num_heads x TRG_seq_len x SRC_seq_len]
38 | 
39 |         norm2 = self.norm_2(enc_dec_mha, norm1)
40 |         # shape(norm2) = [B x TRG_seq_len x D]
41 | 
42 |         ff = self.ff(norm2)
43 |         norm3 = self.norm_3(ff, norm2)
44 |         # shape(ff) = [B x TRG_seq_len x D]
45 |         # shape(norm3) = [B x TRG_seq_len x D]
46 | 
47 |         return norm3, masked_mha_attn_weights, enc_dec_mha_attn_weights
48 | 
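The efficient implementation that follows collapses the per-head Linears into one d_model-by-d_model projection and then splits the result into heads. A small check (not from the repo; toy sizes chosen for illustration) of why that split has to go through [B, seq_len, num_heads, d] followed by a transpose: reshaping straight to [B, num_heads, seq_len, d] puts fragments of different tokens on a head's sequence axis.

import torch

B, L, H, d = 1, 3, 2, 2                                # toy sizes
x = torch.arange(B * L * H * d).reshape(B, L, H * d)   # rows = tokens 0, 1, 2

print(x.reshape(B, H, L, d)[0, 0])
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])   <- pieces of tokens 0 and 1, not three whole tokens

print(x.reshape(B, L, H, d).permute(0, 2, 1, 3)[0, 0])
# tensor([[0, 1],
#         [4, 5],
#         [8, 9]])   <- the first half of every token, i.e. head 0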
--------------------------------------------------------------------------------
/layers/efficient_mha.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torch.nn as nn
4 | import math as m
5 | import torch.nn.functional as F
6 | 
7 | # %%
8 | 
9 | 
10 | class MultiHeadAttention(nn.Module):
11 |     def __init__(self, d_model=4, num_heads=2, dropout=0.3):
12 |         super().__init__()
13 | 
14 |         # d_q, d_k, d_v
15 |         self.d = d_model//num_heads
16 | 
17 |         self.d_model = d_model
18 |         self.num_heads = num_heads
19 | 
20 |         self.dropout = nn.Dropout(dropout)
21 | 
22 |         self.linear_Q = nn.Linear(d_model, d_model)
23 |         self.linear_K = nn.Linear(d_model, d_model)
24 |         self.linear_V = nn.Linear(d_model, d_model)
25 | 
26 |         self.mha_linear = nn.Linear(d_model, d_model)
27 | 
28 |     def scaled_dot_product_attention(self, Q, K, V, mask=None):
29 |         # shape(Q) = [B x num_heads x Q_len x D/num_heads]
30 |         # shape(K, V) = [B x num_heads x KV_len x D/num_heads]
31 | 
32 |         # reshaped(K) = [B x num_heads x D/num_heads x KV_len]
33 |         Q_K_matmul = torch.matmul(Q, K.permute(0, 1, 3, 2))
34 |         scores = Q_K_matmul/m.sqrt(self.d)
35 |         # shape(scores) = [B x num_heads x Q_len x KV_len]
36 | 
37 |         if mask is not None:
38 |             scores = scores.masked_fill(mask == 0, -1e9)
39 | 
40 |         attention_weights = F.softmax(scores, dim=-1)
41 |         # shape(attention_weights) = [B x num_heads x Q_len x KV_len]
42 | 
43 |         output = torch.matmul(attention_weights, V)
44 |         # shape(output) = [B x num_heads x Q_len x D/num_heads]
45 | 
46 |         return output, attention_weights
47 | 
48 |     def forward(self, pre_q, pre_k, pre_v, mask=None):
49 |         # shape(pre_q, pre_k, pre_v) = [B x seq_len x D]
50 | 
51 |         Q = self.linear_Q(pre_q)
52 |         K = self.linear_K(pre_k)
53 |         V = self.linear_V(pre_v)
54 |         # shape(Q) = [B x seq_len x D] (if in encoder, seq_len = SRC_seq_len; if in decoder, seq_len = TRG_seq_len)
55 |         # shape(K, V) = [B x seq_len x D] (always SRC_seq_len unless in masked-multihead-attention)
56 | 
57 |         batch_size = pre_q.shape[0]
58 | 
59 |         # split D into num_heads chunks of size D/num_heads for every token, then move the
60 |         # head axis in front of the sequence axis; reshaping straight to
61 |         # [B x num_heads x seq_len x D/num_heads] would scatter tokens across heads
62 |         Q = Q.reshape(batch_size, -1, self.num_heads, self.d).permute(0, 2, 1, 3)
63 |         K = K.reshape(batch_size, -1, self.num_heads, self.d).permute(0, 2, 1, 3)
64 |         V = V.reshape(batch_size, -1, self.num_heads, self.d).permute(0, 2, 1, 3)
65 |         # shape(Q) = [B x num_heads x seq_len x D/num_heads]
66 |         # shape(K, V) = [B x num_heads x seq_len x D/num_heads]
67 | 
68 |         # run scaled_dot_product_attention
69 |         output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
70 |         # shape(output) = [B x num_heads x Q_len x D/num_heads]
71 |         # shape(attn_weights) = [B x num_heads x Q_len x KV_len]
72 | 
73 |         # undo the head split: [B x num_heads x seq_len x D/num_heads] -> [B x seq_len x D]
74 |         output = output.permute(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)
75 |         # shape(output) = [B x seq_len x D]
76 | 
77 |         projection = self.dropout(self.mha_linear(output))
78 | 
79 |         return projection, attn_weights
80 | 
--------------------------------------------------------------------------------
/layers/embed.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math as m
3 | 
4 | 
5 | class Embeddings(nn.Module):
6 |     def __init__(self, vocab_size, padding_idx, d_model):
7 |         super().__init__()
8 |         self.d_model = d_model
9 |         self.embed = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
10 | 
11 |     def forward(self, x):
12 |         # shape(x) = [B x seq_len]
13 | 
14 |         embedding = self.embed(x)
15 |         # shape(embedding) = [B x seq_len x D]
16 | 
17 |         return embedding * m.sqrt(self.d_model)
18 | 
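The commented-out block at the end of examples/toy_examples.py sketches this; a minimal usage example for the Encoder defined next (not part of the repo, toy sizes assumed, and it assumes being run from the repository root):

import torch
from layers.embed import Embeddings
from layers.encoder import Encoder

toy_vocab = torch.LongTensor([[1, 2, 3, 4, 0, 0]])   # [B x SRC_seq_len]
toy_embedding_layer = Embeddings(5, padding_idx=0, d_model=4)
toy_encoder = Encoder(toy_embedding_layer, d_model=4, num_heads=2, num_layers=2, d_ff=8)

src_mask = (toy_vocab != 0).unsqueeze(1)             # [B x 1 x SRC_seq_len]
encoding, attn = toy_encoder(toy_vocab, src_mask)
print(encoding.shape)  # torch.Size([1, 6, 4])
print(attn.shape)      # torch.Size([1, 2, 6, 6])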
--------------------------------------------------------------------------------
/layers/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .encoder_layer import EncoderLayer
4 | from .positional_encoding import PositionalEncoding
5 | from .embed import Embeddings
6 | 
7 | 
8 | class Encoder(nn.Module):
9 |     def __init__(self, Embedding: Embeddings, d_model,
10 |                  num_heads, num_layers,
11 |                  d_ff, device="cpu", dropout=0.3, efficient_mha=False):
12 |         super().__init__()
13 | 
14 |         self.embedding = Embedding
15 | 
16 |         self.PE = PositionalEncoding(
17 |             d_model, device=device)
18 | 
19 |         self.encoders = nn.ModuleList([EncoderLayer(
20 |             d_model,
21 |             num_heads,
22 |             d_ff,
23 |             dropout,
24 |             efficient_mha
25 |         ) for _ in range(num_layers)])
26 | 
27 |     def forward(self, x, mask=None):
28 |         # shape(x) = [B x SRC_seq_len]
29 | 
30 |         embeddings = self.embedding(x)
31 |         encoding = self.PE(embeddings)
32 |         # shape(embeddings) = [B x SRC_seq_len x D]
33 |         # shape(encoding) = [B x SRC_seq_len x D]
34 | 
35 |         for encoder in self.encoders:
36 |             encoding, encoder_attention_weights = encoder(encoding, mask)
37 |             # shape(encoding) = [B x SRC_seq_len x D]
38 |             # shape(encoder_attention_weights) = [B x num_heads x SRC_seq_len x SRC_seq_len]
39 | 
40 |         return encoding, encoder_attention_weights
41 | 
--------------------------------------------------------------------------------
/layers/encoder_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .residual_layer_norm import ResidualLayerNorm
3 | from .mha import MultiHeadAttention
4 | from .efficient_mha import MultiHeadAttention as EfficientMultiHeadAttention
5 | from .pwffn import PWFFN
6 | 
7 | 
8 | class EncoderLayer(nn.Module):
9 |     def __init__(self, d_model, num_heads, d_ff, dropout=0.3, efficient_mha=False):
10 |         super().__init__()
11 | 
12 |         # initialize the two residual layer-norm blocks
13 |         self.norm_1 = ResidualLayerNorm(d_model, dropout)
14 |         self.norm_2 = ResidualLayerNorm(d_model, dropout)
15 | 
16 |         if efficient_mha:
17 |             self.mha = EfficientMultiHeadAttention(d_model, num_heads, dropout)
18 |         else:
19 |             self.mha = MultiHeadAttention(d_model, num_heads, dropout)
20 | 
21 |         self.ff = PWFFN(d_model, d_ff, dropout)
22 | 
23 |     def forward(self, x, mask):
24 |         # shape(x) = [B x seq_len x D]
25 | 
26 |         mha, encoder_attention_weights = self.mha(x, x, x, mask)
27 |         # shape(mha) = [B x seq_len x D]
28 |         # shape(encoder_attention_weights) = [B x num_heads x seq_len x seq_len]
29 | 
30 |         norm1 = self.norm_1(mha, x)
31 |         # shape(norm1) = [B x seq_len x D]
32 | 
33 |         ff = self.ff(norm1)
34 |         norm2 = self.norm_2(ff, norm1)
35 |         # shape(ff) = [B x seq_len x D]
36 |         # shape(norm2) = [B x seq_len x D]
37 | 
38 |         return norm2, encoder_attention_weights
39 | 
--------------------------------------------------------------------------------
/layers/mha.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import torch
3 | import torch.nn as nn
4 | import math as m
5 | import torch.nn.functional as F
6 | 
7 | # %%
8 | class MultiHeadAttention(nn.Module):
9 |     def __init__(self, d_model=4, num_heads=2, dropout=0.3):
10 |         super().__init__()
11 | 
12 |         # d_q, d_k, d_v
13 |         self.d = 
d_model//num_heads 14 | 15 | 16 | self.d_model = d_model 17 | self.num_heads = num_heads 18 | 19 | self.dropout = nn.Dropout(dropout) 20 | 21 | ##create a list of layers for K, and a list of layers for V 22 | 23 | self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d) 24 | for _ in range(num_heads)]) 25 | self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d) 26 | for _ in range(num_heads)]) 27 | self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d) 28 | for _ in range(num_heads)]) 29 | 30 | self.mha_linear = nn.Linear(d_model, d_model) 31 | 32 | def scaled_dot_product_attention(self, Q, K, V, mask=None): 33 | # shape(Q) = [B x seq_len x D/num_heads] 34 | # shape(K, V) = [B x seq_len x D/num_heads] 35 | 36 | Q_K_matmul = torch.matmul(Q, K.permute(0, 2, 1)) 37 | scores = Q_K_matmul/m.sqrt(self.d) 38 | # shape(scores) = [B x seq_len x seq_len] 39 | 40 | if mask is not None: 41 | scores = scores.masked_fill(mask == 0, -1e9) 42 | 43 | attention_weights = F.softmax(scores, dim=-1) 44 | # shape(attention_weights) = [B x seq_len x seq_len] 45 | 46 | output = torch.matmul(attention_weights, V) 47 | # shape(output) = [B x seq_len x D/num_heads] 48 | 49 | return output, attention_weights 50 | 51 | def forward(self, pre_q, pre_k, pre_v, mask=None): 52 | # shape(x) = [B x seq_len x D] 53 | 54 | Q = [linear_Q(pre_q) for linear_Q in self.linear_Qs] 55 | K = [linear_K(pre_k) for linear_K in self.linear_Ks] 56 | V = [linear_V(pre_v) for linear_V in self.linear_Vs] 57 | # shape(Q, K, V) = [B x seq_len x D/num_heads] * num_heads 58 | 59 | output_per_head = [] 60 | attn_weights_per_head = [] 61 | # shape(output_per_head) = [B x seq_len x D/num_heads] * num_heads 62 | # shape(attn_weights_per_head) = [B x seq_len x seq_len] * num_heads 63 | 64 | for Q_, K_, V_ in zip(Q, K, V): 65 | 66 | ##run scaled_dot_product_attention 67 | output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask) 68 | # shape(output) = [B x seq_len x D/num_heads] 69 | # shape(attn_weights_per_head) = [B x seq_len x seq_len] 70 | output_per_head.append(output) 71 | attn_weights_per_head.append(attn_weight) 72 | 73 | output = torch.cat(output_per_head, -1) 74 | attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3) 75 | # shape(output) = [B x seq_len x D] 76 | # shape(attn_weights) = [B x num_heads x seq_len x seq_len] 77 | 78 | projection = self.dropout(self.mha_linear(output)) 79 | 80 | return projection, attn_weights 81 | 82 | -------------------------------------------------------------------------------- /layers/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class PositionalEncoding(nn.Module): 6 | def __init__(self, d_model, dropout=0.3, max_seq_len=200, device="cpu"): 7 | super().__init__() 8 | self.d_model = d_model 9 | self.dropout = nn.Dropout(dropout) 10 | 11 | pe = torch.zeros(max_seq_len, d_model).to(device) 12 | pos = torch.arange(0, max_seq_len).unsqueeze(1).float() 13 | 14 | two_i = torch.arange(0, d_model, step=2).float() 15 | div_term = torch.pow(10000, (two_i/torch.Tensor([d_model]))).float() 16 | pe[:, 0::2] = torch.sin(pos/div_term) 17 | pe[:, 1::2] = torch.cos(pos/div_term) 18 | 19 | pe = pe.unsqueeze(0) 20 | 21 | # assigns the first argument to a class variable 22 | # i.e. 
self.pe 23 | self.register_buffer("pe", pe) 24 | 25 | def forward(self, x): 26 | # shape(x) = [B x seq_len x D] 27 | one_batch_pe: torch.Tensor = self.pe[:, :x.shape[1]].detach() 28 | repeated_pe = one_batch_pe.repeat([x.shape[0], 1, 1]).detach() 29 | x = x.add(repeated_pe) 30 | # shape(x) = [B x seq_len x D] 31 | return self.dropout(x) 32 | -------------------------------------------------------------------------------- /layers/pwffn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class PWFFN(nn.Module): 4 | def __init__(self, d_model, d_ff, dropout=0.3): 5 | super().__init__() 6 | 7 | self.ff = nn.Sequential( 8 | nn.Linear(d_model, d_ff), 9 | nn.ReLU(), 10 | nn.Dropout(dropout), 11 | nn.Linear(d_ff, d_model) 12 | ) 13 | 14 | def forward(self, x): 15 | # shape(x) = [B x seq_len x D] 16 | 17 | ff = self.ff(x) 18 | # shape(ff) = [B x seq_len x D] 19 | 20 | return ff -------------------------------------------------------------------------------- /layers/residual_layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ResidualLayerNorm(nn.Module): 5 | def __init__(self, d_model, dropout=0.3): 6 | super().__init__() 7 | self.layer_norm = nn.LayerNorm(d_model) 8 | self.dropout = nn.Dropout(dropout) 9 | 10 | def forward(self, x, residual): 11 | # In the video this was: 12 | # ln = self.layer_norm(x + residual) 13 | # return self.dropout(ln) 14 | # The above does not lead to convergence. We must dropout x for convergence. 15 | # Why doesn't this work? Because we send the output of the layernorm to an attention block. 16 | # So some values would be zeroed out if dropout is enabled. Obviously MHA doesn't know what to attend to then. 
17 | # We can dropout(x) though because 1) we're adding a residual to it, so dropped out values won't be zero 18 | # and 2), the layernorm has an additive beta parameter which provides a non-zero value to a tensor 19 | ln = self.layer_norm(self.dropout(x) + residual) 20 | return ln 21 | -------------------------------------------------------------------------------- /layers/transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .embed import Embeddings 4 | from .encoder import Encoder 5 | from .decoder import Decoder 6 | 7 | 8 | class Transformer(nn.Module): 9 | def __init__(self, src_vocab_len, trg_vocab_len, d_model, d_ff, 10 | num_layers, num_heads, src_pad_idx, trg_pad_idx, dropout=0.3, device="cpu", efficient_mha=False): 11 | super().__init__() 12 | 13 | self.num_heads = num_heads 14 | self.device = device 15 | self.efficient_mha = efficient_mha 16 | 17 | encoder_Embedding = Embeddings( 18 | src_vocab_len, src_pad_idx, d_model) 19 | decoder_Embedding = Embeddings( 20 | trg_vocab_len, trg_pad_idx, d_model) 21 | 22 | self.src_pad_idx = src_pad_idx 23 | self.trg_pad_idx = trg_pad_idx 24 | 25 | self.encoder = Encoder(encoder_Embedding, d_model, 26 | num_heads, num_layers, d_ff, device, dropout, efficient_mha) 27 | self.decoder = Decoder(decoder_Embedding, d_model, 28 | num_heads, num_layers, d_ff, device, dropout, efficient_mha) 29 | 30 | self.linear_layer = nn.Linear(d_model, trg_vocab_len) 31 | 32 | for p in self.parameters(): 33 | if p.dim() > 1: 34 | nn.init.xavier_uniform_(p) 35 | 36 | def create_src_mask(self, src): 37 | src_mask = (src != self.src_pad_idx).unsqueeze(1) 38 | if self.efficient_mha: 39 | src_mask = src_mask.unsqueeze(2) 40 | return src_mask 41 | 42 | def create_trg_mask(self, trg): 43 | if self.efficient_mha: 44 | trg_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2) 45 | mask = torch.ones((1, self.num_heads, trg.shape[1], trg.shape[1])).triu(1).to(self.device) 46 | else: 47 | trg_mask = (trg != self.trg_pad_idx).unsqueeze(1) 48 | mask = torch.ones((1, trg.shape[1], trg.shape[1])).triu(1).to(self.device) 49 | mask = mask == 0 50 | trg_mask = trg_mask & mask 51 | return trg_mask 52 | 53 | def forward(self, src, trg): 54 | # shape(src) = [B x SRC_seq_len] 55 | # shape(trg) = [B x TRG_seq_len] 56 | 57 | src_mask = self.create_src_mask(src) 58 | trg_mask = self.create_trg_mask(trg) 59 | # shape(src_mask) = [B x 1 x SRC_seq_len] 60 | # shape(trg_mask) = [B x 1 x TRG_seq_len] 61 | 62 | encoder_outputs, encoder_mha_attn_weights = self.encoder(src, src_mask) 63 | # shape(encoder_outputs) = [B x SRC_seq_len x D] 64 | # shape(encoder_mha_attn_weights) = [B x num_heads x SRC_seq_len x SRC_seq_len] 65 | 66 | decoder_outputs, _, enc_dec_mha_attn_weights = self.decoder( 67 | trg, encoder_outputs, trg_mask, src_mask) 68 | # shape(decoder_outputs) = [B x SRC_seq_len x D] 69 | # shape(enc_dec_mha_attn_weights) = [B x num_heads x TRG_seq_len x SRC_seq_len] 70 | 71 | logits = self.linear_layer(decoder_outputs) 72 | # shape(logits) = [B x TRG_seq_len x TRG_vocab_size] 73 | 74 | return logits 75 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | from layers.transformers import Transformer 5 | from torchtext.datasets import Multi30k, IWSLT2016 6 | from torchtext.data.utils import 
get_tokenizer
7 | from torchtext.vocab import build_vocab_from_iterator, Vocab
8 | from torch.utils.data import DataLoader
9 | import math as m
10 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
11 | 
12 | 
13 | class TransformerTrainer(pl.LightningModule):
14 |     def __init__(self, src_vocab: Vocab, trg_vocab: Vocab, warmup_steps=4000, d_model=512, d_ff=2048, num_layers=6, num_heads=8, device="cpu", dropout=0.3):
15 |         super().__init__()
16 | 
17 |         self.model = Transformer(
18 |             src_vocab_len=len(src_vocab),
19 |             trg_vocab_len=len(trg_vocab),
20 |             d_model=d_model,
21 |             d_ff=d_ff,
22 |             num_layers=num_layers,
23 |             num_heads=num_heads,
24 |             src_pad_idx=src_vocab.__getitem__("<pad>"),
25 |             trg_pad_idx=trg_vocab.__getitem__("<pad>"),
26 |             dropout=dropout,
27 |             device=device,
28 |             efficient_mha=True
29 |         )
30 |         self.src_vocab = src_vocab
31 |         self.trg_vocab = trg_vocab
32 |         self.device_ = device
33 |         self.d_model = d_model
34 |         self.warmup_steps = warmup_steps
35 | 
36 |         self.criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab.__getitem__("<pad>"))
37 | 
38 |     def training_step(self, batch, batch_idx):
39 |         src = batch[0].to(self.device_)
40 |         trg = batch[1].to(self.device_)
41 | 
42 |         trg_input = trg[:, :-1]
43 |         ys = trg[:, 1:].reshape(-1)
44 | 
45 |         logits = self.model(src, trg_input)
46 | 
47 |         loss = self.criterion(logits.reshape(-1, len(self.trg_vocab)), ys)
48 | 
49 |         self.change_lr_in_optimizer()
50 |         self.log("train loss", loss)
51 | 
52 |         if batch_idx == 0:
53 |             for idx in range(len(src)):
54 |                 print("(train) SRC:\t", self.clean_and_print_tokens(src[idx], "src"))
55 |                 print("(train) TRG:\t", self.clean_and_print_tokens(trg[idx], "trg"))
56 |                 print("(train) PRED:\t", self.clean_and_print_tokens(torch.argmax(logits[idx], dim=-1), "trg"))
57 |                 print("")
58 | 
59 |         return loss
60 | 
61 |     def validation_step(self, batch, batch_idx):
62 |         src = batch[0].to(self.device_)
63 |         trg = batch[1].to(self.device_)
64 |         trg_input = trg[:, :-1]
65 | 
66 |         logits = self.model(src, trg_input)
67 |         # shape(logits) = (batch_size, trg_len, vocab_size)
68 | 
69 |         ys = trg[:, 1:].reshape(-1)
70 |         val_loss = self.criterion(logits.reshape(-1, len(self.trg_vocab)), ys)
71 | 
72 |         self.log("val loss", val_loss)
73 | 
74 |         for idx in range(len(src)):
75 |             print(" SRC:\t", self.clean_and_print_tokens(src[idx], "src"))
76 |             print(" TRG:\t", self.clean_and_print_tokens(trg[idx], "trg"))
77 |             print("PRED:\t", self.clean_and_print_tokens(torch.argmax(logits[idx], dim=-1), "trg"))
78 |             print("")
79 |         print("Val Loss:", val_loss)
80 | 
81 |     def configure_optimizers(self):
82 |         return torch.optim.Adam(self.parameters(), lr=0.001)
83 | 
84 |     def change_lr_in_optimizer(self):
85 |         min_arg1 = m.sqrt(1/(self.global_step+1))
86 |         min_arg2 = self.global_step * (self.warmup_steps**-1.5)
87 |         lr = m.sqrt(1/self.d_model) * min(min_arg1, min_arg2)
88 |         self.trainer.lightning_optimizers[0].param_groups[0]['lr'] = lr
89 | 
90 |     def clean_and_print_tokens(self, tokens, src_or_trg):
91 |         if src_or_trg == "src":
92 |             vocab = self.src_vocab
93 |         elif src_or_trg == "trg":
94 |             vocab = self.trg_vocab
95 | 
96 |         return " ".join(vocab.lookup_tokens(tokens.tolist()))
97 | 
98 | 
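# Not part of the original script: change_lr_in_optimizer above implements the warmup
# schedule from "Attention Is All You Need",
#     lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
# The helper below only tabulates that schedule and is never called during training.
# With the defaults used here (d_model=512, warmup_steps=4000) the rate climbs to
# roughly 7e-4 around step 4000 and then decays to about 2.2e-4 by step 40000.
def _show_lr_schedule(d_model=512, warmup_steps=4000):
    for step in (1, 100, 4000, 40000):
        lr = m.sqrt(1 / d_model) * min(m.sqrt(1 / (step + 1)), step * warmup_steps ** -1.5)
        print(step, lr)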
99 | if __name__ == "__main__":
100 |     device = ("cuda:0" if torch.cuda.is_available() else "cpu")
101 | 
102 |     # train_iter, val_iter, test_iter = Multi30k()
103 |     train_iter, val_iter, test_iter = IWSLT2016(language_pair=('de', 'en'))
104 |     src_tokenizer = get_tokenizer("basic_english")
105 |     trg_tokenizer = get_tokenizer("basic_english")
106 | 
107 |     def yield_tokens(data_iter, src_or_trg):
108 |         for batch in data_iter:
109 |             if src_or_trg == "src":
110 |                 yield src_tokenizer(batch[0])
111 |             elif src_or_trg == "trg":
112 |                 yield trg_tokenizer(batch[1])
113 | 
114 |     src_vocab = build_vocab_from_iterator(yield_tokens(train_iter, "src"), specials=["<unk>", "<pad>", "<sos>", "<eos>"])
115 |     src_vocab.set_default_index(src_vocab["<unk>"])
116 | 
117 |     # train_iter, val_iter, test_iter = Multi30k()
118 |     train_iter, val_iter, test_iter = IWSLT2016(language_pair=('de', 'en'))
119 | 
120 |     trg_vocab = build_vocab_from_iterator(yield_tokens(train_iter, "trg"), specials=["<unk>", "<pad>", "<sos>", "<eos>"])
121 |     trg_vocab.set_default_index(trg_vocab["<unk>"])
122 | 
123 |     # train_iter, val_iter, test_iter = Multi30k()
124 |     train_iter, val_iter, test_iter = IWSLT2016(language_pair=('de', 'en'))
125 | 
126 |     MAX_SEQ_LEN = 30
127 | 
128 |     def pad_to_max(tokens):
129 |         return tokens[:MAX_SEQ_LEN] + ["<pad>"] * max(0, MAX_SEQ_LEN - len(tokens))
130 | 
131 |     def collate_fn(batch):
132 |         # batch = [(src_sentence, trg_sentence), (src_sentence, trg_sentence), ...]
133 |         srcs = []
134 |         trgs = []
135 |         for pair in batch:
136 |             src = pair[0]
137 |             trg = pair[1]
138 | 
139 |             tokenized_src = src_vocab(pad_to_max(src_tokenizer("<sos> " + src + " <eos>")))
140 |             tokenized_trg = trg_vocab(pad_to_max(trg_tokenizer("<sos> " + trg + " <eos>")))
141 | 
142 |             srcs.append(tokenized_src)
143 |             trgs.append(tokenized_trg)
144 | 
145 |         srcs = torch.tensor(srcs, dtype=torch.long)
146 |         trgs = torch.tensor(trgs, dtype=torch.long)
147 |         return srcs, trgs
148 | 
149 |     dataloader = DataLoader(list(train_iter), batch_size=64, shuffle=False, collate_fn=collate_fn)
150 |     val_dataloader = DataLoader(list(val_iter), batch_size=64, shuffle=False, collate_fn=collate_fn)
151 |     test_dataloader = DataLoader(list(test_iter), batch_size=64, shuffle=False, collate_fn=collate_fn)
152 | 
153 |     transformer = TransformerTrainer(src_vocab, trg_vocab, device=device)
154 |     trainer = pl.Trainer(gpus=1, min_epochs=20, callbacks=[EarlyStopping(monitor="val loss", patience=5, mode="min")])
155 |     trainer.fit(transformer, dataloader, val_dataloader)
156 | 
--------------------------------------------------------------------------------
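The training script stops at trainer.fit, and the repository ships no inference helper. A minimal greedy-decoding sketch built on the Transformer class (not part of the original code; sos_idx and eos_idx are whatever indices the vocab assigns to its start and end specials, and the single-sentence batch is an assumption) could look like:

import torch
from layers.transformers import Transformer

@torch.no_grad()
def greedy_decode(model: Transformer, src, sos_idx, eos_idx, max_len=30):
    # src: [1 x SRC_seq_len] of token ids; returns the generated target ids
    model.eval()                                    # disable dropout
    trg = torch.LongTensor([[sos_idx]])             # start the target with <sos>
    for _ in range(max_len):
        logits = model(src, trg)                    # [1 x trg_len x trg_vocab]
        next_token = logits[:, -1].argmax(dim=-1)   # most likely next token
        trg = torch.cat([trg, next_token.unsqueeze(0)], dim=1)
        if next_token.item() == eos_idx:
            break
    return trg.squeeze(0).tolist()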