├── .gitignore ├── LICENSE ├── README.md ├── layers └── attention.py └── models ├── bert.py └── gpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Benjamin Warner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Commented Transformers 2 | 3 | Highly commented implementations of Transformers in PyTorch for *Creating a Transformer From Scratch* series: 4 | 5 | 1. [The Attention Mechanism](https://benjaminwarner.dev/2023/07/01/attention-mechanism.html) 6 | 2. [The Rest of the Transformer](https://benjaminwarner.dev/2023/07/28/rest-of-the-transformer.html) 7 | 8 | 9 | The layers folder contains implementations for Bidirectional Attention, Causal Attention, and CausalCrossAttention. 10 | 11 | The models folder contains single file implementations for GPT-2 and BERT. Both models are compatible with `torch.compile(..., fullgraph=True)`. -------------------------------------------------------------------------------- /layers/attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor, BoolTensor 8 | from torch.nn import functional as F 9 | 10 | 11 | class BidirectionalAttention(nn.Module): 12 | def __init__(self, hidden_size:int, num_heads:int, attn_drop:float=0.1, 13 | out_drop:float=0.1, bias:bool=True): 14 | super().__init__() 15 | # input dimension must be divisible by num_heads 16 | assert hidden_size % num_heads == 0 17 | # number of Attention heads 18 | self.nh = num_heads 19 | 20 | # linear layer to project queries, keys, values 21 | self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias) 22 | 23 | # attention dropout layer to prevent overfitting 24 | self.attn_drop = nn.Dropout(attn_drop) 25 | 26 | # linear layer to project final output 27 | self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias) 28 | 29 | # final output dropout layer to prevent overfitting 30 | self.out_drop = nn.Dropout(out_drop) 31 | 32 | # boolean `mask` of shape (batch_size, sequence_length) 33 | # where True is masked and False is unmasked 34 | def forward(self, x: Tensor, mask: BoolTensor|None = None): 35 | # batch size, sequence length, input dimension 36 | B, S, C = x.shape 37 | 38 | # split into queries, keys, & values of shape 39 | # batch size (B), num_heads (NH), sequence length (S), head size (HS) 40 | x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh) 41 | q, k, v = x.transpose(3, 1).unbind(dim=2) 42 | 43 | # dot product queries and keys for each head 44 | # (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S) 45 | attn = q @ k.transpose(-2, -1) 46 | 47 | # scale by square root of output dimension 48 | attn = attn / math.sqrt(k.size(-1)) 49 | 50 | # reshape and mask attention scores 51 | if mask is not None: 52 | attn = attn.masked_fill(mask.view(B, 1, 1, S), float('-inf')) 53 | 54 | # apply softmax to get attention weights 55 | attn = attn.softmax(dim=-1) 56 | 57 | # apply dropout to attention weight 58 | attn = self.attn_drop(attn) 59 | 60 | # dot product attention weights with values of shape 61 | # (B, NH, S, HS) = (B, NH, S, S) @ (B, NH, HS, S) 62 | x = attn @ v 63 | 64 | # and transpose heads & sequence and reshape back to (B, S, C) 65 | x = x.transpose(1, 2).reshape(B, S, C) 66 | 67 | # apply final linear layer and dropout to get output (B, S, C) 68 | return self.out_drop(self.Wo(x)) 69 | 70 | 71 | class CausalAttention(nn.Module): 72 | def __init__(self, hidden_size:int, num_heads:int, context_size:int, 73 | attn_drop:float=0.1, out_drop:float=0.1, bias:bool=True): 74 | super().__init__() 75 | # input dimension must be divisible by num_heads 76 | assert hidden_size % num_heads == 0 77 | # number of Attention heads 78 | self.nh = num_heads 79 | 80 | # linear layer to project queries, keys, values 81 | self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias) 82 | 83 | # attention dropout layer to prevent overfitting 84 | self.attn_drop = nn.Dropout(attn_drop) 85 | 86 | # linear layer to project final output 87 | self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias) 88 | 89 | # final output dropout layer to prevent overfitting 90 | self.out_drop = nn.Dropout(out_drop) 91 | 92 | # causal mask to ensure that Attention is not applied to future tokens where 93 | # context_size is the maximum sequence length of the transformer 94 | self.register_buffer('causal_mask', 95 | torch.triu(torch.ones([context_size, context_size], dtype=torch.bool), diagonal=1) 96 | .view(1, 1, context_size, context_size), persistent=False 97 | ) 98 | 99 | # boolean `mask` of shape (batch_size, sequence_length) 100 | # where True is masked and False is unmasked 101 | def forward(self, x: Tensor, mask: BoolTensor|None = None): 102 | # batch size, sequence length, input dimension 103 | B, S, C = x.shape 104 | 105 | # split into queries, keys, & values of shape 106 | # batch size (B), num_heads (NH), sequence length (S), head size (HS) 107 | x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh) 108 | q, k, v = x.transpose(3, 1).unbind(dim=2) 109 | 110 | # dot product queries and keys for each head 111 | # (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S) 112 | attn = q @ k.transpose(-2, -1) 113 | 114 | # scale by square root of output dimension 115 | attn = attn / math.sqrt(k.size(-1)) 116 | 117 | # apply input and causal mask 118 | combined_mask = self.causal_mask[:, :, :S, :S] 119 | if mask is not None: 120 | combined_mask += mask.view(B, 1, 1, S) 121 | attn = attn.masked_fill(combined_mask, float('-inf')) 122 | 123 | # apply softmax to get attention weights 124 | attn = attn.softmax(dim=-1) 125 | 126 | # apply dropout to attention weight 127 | attn = self.attn_drop(attn) 128 | 129 | # dot product attention weights with values of shape 130 | # (B, NH, S, HS) = (B, NH, S, S) @ (B, NH, HS, S) 131 | x = attn @ v 132 | 133 | # and transpose heads & sequence and reshape back to (B, S, C) 134 | x = x.transpose(1, 2).reshape(B, S, C) 135 | 136 | # apply final linear layer and dropout to get output (B, S, C) 137 | return self.out_drop(self.Wo(x)) 138 | 139 | class CausalCrossAttention(nn.Module): 140 | def __init__(self, 141 | hidden_size: int, 142 | num_heads: int, 143 | context_size: int, 144 | attn_drop: float = 0.1, 145 | out_drop: float = 0.1, 146 | bias: bool = True, 147 | ): 148 | super().__init__() 149 | # input dimension must be divisible by num_heads 150 | assert hidden_size % num_heads == 0 151 | # number of Attention heads 152 | self.nh = num_heads 153 | 154 | # linear layer to project queries from decoder input 155 | self.Wq = nn.Linear(hidden_size, hidden_size, bias=bias) 156 | 157 | # linear layer to project keys and values from encoder output 158 | self.Wkv = nn.Linear(hidden_size, hidden_size * 2, bias=bias) 159 | 160 | # attention dropout layer to prevent overfitting 161 | self.attn_drop = nn.Dropout(attn_drop) 162 | 163 | # linear layer to project final output 164 | self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias) 165 | 166 | # final output dropout layer to prevent overfitting 167 | self.out_drop = nn.Dropout(out_drop) 168 | 169 | # causal mask to ensure that Attention is not applied to future tokens where 170 | # context_size is the maximum sequence length of the transformer 171 | self.register_buffer('causal_mask', 172 | torch.triu(torch.ones([context_size, context_size], dtype=torch.bool), diagonal=1) 173 | .view(1, 1, context_size, context_size), persistent=False 174 | ) 175 | 176 | 177 | # boolean `mask` of shape (batch_size, sequence_length) 178 | # where True is masked and False is unmasked 179 | def forward(self, x: Tensor, y: Tensor, mask: BoolTensor|None = None): 180 | # batch size, sequence length, input dimension 181 | B, S, C = x.shape 182 | 183 | # split into queries of shape (B, NH, S, HS) from decoder input 184 | q = self.Wq(x).reshape(B, S, self.nh, C//self.nh).transpose(1, 2) 185 | 186 | # split into keys and values of shape (B, NH, S, HS) from encoder output 187 | y = self.Wkv(y).reshape(B, S, 2, self.nh, C//self.nh) 188 | k, v = y.transpose(3, 1).unbind(dim=2) 189 | 190 | # dot product queries and keys for each head 191 | # (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S) 192 | attn = q @ k.transpose(-2, -1) 193 | 194 | # scale by square root of output dimension 195 | attn = attn / math.sqrt(k.size(-1)) 196 | 197 | # apply input and causal mask 198 | combined_mask = self.causal_mask[:, :, :S, :S] 199 | if mask is not None: 200 | combined_mask += mask.view(B, 1, 1, S) 201 | attn = attn.masked_fill(combined_mask, float('-inf')) 202 | 203 | # apply softmax to get attention weights 204 | attn = attn.softmax(dim=-1) 205 | 206 | # apply dropout to attention weight 207 | attn = self.attn_drop(attn) 208 | 209 | # dot product attention weights with values of shape 210 | # (B,NH,S,S) @ (B,NH,S,HS) -> (B,NH,S,HS) 211 | x = attn @ v 212 | 213 | # and transpose heads & sequence and reshape back to (B,S,C) 214 | x = x.transpose(1, 2).reshape(B, S, C) 215 | 216 | # apply final linear layer and dropout to get output (B,S,C) 217 | return self.out_drop(self.Wo(x)) -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | bert.py is a highly commented implementation of a modern BERT encoder Transformer 3 | 4 | The codebase for bert.py is inspired by: 5 | nanoGPT https://github.com/karpathy/nanoGPT - Copyright (c) 2022 Andrej Karpathy - MIT License 6 | cramming https://github.com/JonasGeiping/cramming - Copyright (c) 2022 Jonas Geiping - MIT License 7 | """ 8 | from __future__ import annotations 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch import Tensor, BoolTensor 15 | from torch.nn import functional as F 16 | 17 | 18 | class PositionalEncoding(nn.Module): 19 | def __init__(self, context_size: int, hidden_size: int): 20 | super().__init__() 21 | # create the positional encoding tensor of shape 22 | # maximum sequence length (MS) by embedding dimension (C) 23 | pe = torch.zeros(context_size, hidden_size, dtype=torch.float) 24 | 25 | # pre-populate the position and the div_terms 26 | position = torch.arange(context_size).unsqueeze(1) 27 | div_term = torch.exp( 28 | torch.arange(0, hidden_size, 2) * (-math.log(10000) / hidden_size) 29 | ) 30 | 31 | # even positional encodings use sine, odd cosine 32 | pe[:, 0::2] = torch.sin(position * div_term) 33 | pe[:, 1::2] = torch.cos(position * div_term) 34 | 35 | # register as a buffer so autograd doesn't modify 36 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) 37 | 38 | def forward(self, x: Tensor): 39 | # return the pre-calculated positional encodings 40 | # up to sequence length (S). output shape (1, S, C) 41 | return self.pe[:, :x.shape[1], :] 42 | 43 | 44 | class FeedForward(nn.Module): 45 | def __init__(self, hidden_size:int, expand_size:int, act:nn.Module=nn.GELU, 46 | drop:float=0.1, bias:bool=True): 47 | super().__init__() 48 | # project input to expanded dimension 49 | self.fc1 = nn.Linear(hidden_size, expand_size, bias=bias) 50 | 51 | # activation function to introduce non-linearity 52 | self.act = act() 53 | 54 | # project back to the input dimension 55 | self.fc2 = nn.Linear(expand_size, hidden_size, bias=bias) 56 | 57 | # optional dropout layer to prevent overfitting 58 | self.drop = nn.Dropout(drop) 59 | 60 | def forward(self, x:Tensor): 61 | x = self.fc1(x) # apply first linear layer 62 | x = self.act(x) # apply activation function 63 | x = self.fc2(x) # apply second linear layer 64 | x = self.drop(x) # optionally apply dropout layer 65 | return x 66 | 67 | 68 | class BidirectionalAttention(nn.Module): 69 | def __init__(self, hidden_size:int, num_heads:int, attn_drop:float=0.1, 70 | out_drop:float=0.1, bias:bool=True): 71 | super().__init__() 72 | # input dimension must be divisible by num_heads 73 | assert hidden_size % num_heads == 0 74 | # number of Attention heads 75 | self.nh = num_heads 76 | 77 | # linear layer to project queries, keys, values 78 | self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias) 79 | 80 | # attention dropout layer to prevent overfitting 81 | self.attn_drop = nn.Dropout(attn_drop) 82 | 83 | # linear layer to project final output 84 | self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias) 85 | 86 | # final output dropout layer to prevent overfitting 87 | self.out_drop = nn.Dropout(out_drop) 88 | 89 | # boolean `mask` of shape (batch_size, sequence_length) 90 | # where True is masked and False is unmasked 91 | def forward(self, x: Tensor, mask: BoolTensor|None = None): 92 | # batch size, sequence length, input dimension 93 | B, S, C = x.shape 94 | 95 | # split into queries, keys, & values of shape 96 | # batch size (B), num_heads (NH), sequence length (S), head size (HS) 97 | x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh) 98 | q, k, v = x.transpose(3, 1).unbind(dim=2) 99 | 100 | # dot product queries and keys for each head 101 | # (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S) 102 | attn = q @ k.transpose(-2, -1) 103 | 104 | # scale by square root of output dimension 105 | attn = attn / math.sqrt(k.size(-1)) 106 | 107 | # reshape and mask attention scores 108 | if mask is not None: 109 | attn = attn.masked_fill(mask.view(B, 1, 1, S), float('-inf')) 110 | 111 | # apply softmax to get attention weights 112 | attn = attn.softmax(dim=-1) 113 | 114 | # apply dropout to attention weight 115 | attn = self.attn_drop(attn) 116 | 117 | # dot product attention weights with values of shape 118 | # (B, NH, S, HS) = (B, NH, S, S) @ (B, NH, HS, S) 119 | x = attn @ v 120 | 121 | # and transpose heads & sequence and reshape back to (B, S, C) 122 | x = x.transpose(1, 2).reshape(B, S, C) 123 | 124 | # apply final linear layer and dropout to get output (B, S, C) 125 | return self.out_drop(self.Wo(x)) 126 | 127 | 128 | class TransformerBlock(nn.Module): 129 | def __init__(self, hidden_size:int, num_heads:int, expand_size:int, 130 | attention:nn.Module=BidirectionalAttention, act:nn.Module=nn.GELU, 131 | attn_drop:float=0.1, out_drop:float=0.1, ffn_drop:float=0.1, 132 | bias:bool=True): 133 | super().__init__() 134 | # first pre-norm layer 135 | self.norm1 = nn.LayerNorm(hidden_size) 136 | # initialize the attention layer 137 | self.attn = attention( 138 | hidden_size=hidden_size, num_heads=num_heads, attn_drop=attn_drop, 139 | out_drop=out_drop, bias=bias 140 | ) 141 | 142 | # second pre-norm layer 143 | self.norm2 = nn.LayerNorm(hidden_size) 144 | # initialize the feed forward network (MLP) 145 | self.ffn = FeedForward( 146 | hidden_size=hidden_size, expand_size=expand_size, act=act, 147 | drop=ffn_drop, bias=bias, 148 | ) 149 | 150 | def forward(self, x: Tensor): 151 | # normalize input then add residual to attention output 152 | x = x + self.attn(self.norm1(x)) 153 | 154 | # normalize input then add residual to feedforward output 155 | return x + self.ffn(self.norm2(x)) 156 | 157 | 158 | class BERT(nn.Module): 159 | def __init__(self, num_layers:int, vocab_size:int, hidden_size:int, num_heads:int, 160 | context_size:int, expand_size:int, attention:nn.Module=BidirectionalAttention, 161 | act:nn.Module=nn.GELU, embed_drop:float=0.1, attn_drop:float=0.1, 162 | out_drop:float=0.1, ffn_drop:float=0.1, head_norm:bool=True, 163 | tie_weights:bool=True, head_bias:bool=True, bias:bool=True): 164 | super().__init__() 165 | # initialize vocab & positional encodings to convert numericalied tokens 166 | # & position indicies to token and position vectors, with optional dropout 167 | self.vocab_embed = nn.Embedding(vocab_size, hidden_size) 168 | self.pos_encode = PositionalEncoding(context_size, hidden_size) 169 | self.embed_drop = nn.Dropout(embed_drop) 170 | 171 | # initialize num_layers of transformer layers 172 | self.tfm_blocks = nn.ModuleList([TransformerBlock( 173 | hidden_size=hidden_size, num_heads=num_heads, expand_size=expand_size, 174 | attention=attention, act=act, bias=bias, attn_drop=attn_drop, 175 | out_drop=out_drop, ffn_drop=ffn_drop) 176 | for _ in range(num_layers)]) 177 | 178 | # optional pre-head normalization 179 | self.head_norm = nn.LayerNorm(hidden_size) if head_norm else nn.Identity() 180 | 181 | # predicts the next token in the sequence 182 | self.head = nn.Linear(hidden_size, vocab_size, bias=head_bias) 183 | 184 | # optionally set the vocab embedding and prediction head to share weights 185 | if tie_weights: 186 | self.head.weight = self.vocab_embed.weight 187 | 188 | self.apply(self._init_weights) 189 | 190 | def forward(self, x: Tensor, return_preds:bool=True): 191 | # convert numericalized tokens of shape (B, S) 192 | # into token embeddings of shape (B, S, C) 193 | tokens = self.vocab_embed(x) 194 | # positional encodings are shape (S, C) 195 | pos = self.pos_encode(x) 196 | 197 | # positional encodings are added to token embeddings 198 | x = self.embed_drop(tokens + pos) 199 | 200 | # pass token vectors through all transformer layers 201 | for block in self.tfm_blocks: 202 | x = block(x) 203 | 204 | # apply optional pre-head normalization 205 | x = self.head_norm(x) 206 | 207 | # if MLM pretraining, don't predict outputs here 208 | if return_preds: 209 | # converts input token vectors of shape (B, S, C) to probability 210 | # distribution of shape batch, sequence length, vocabulary size (B, S, VS) 211 | return self.head(x) 212 | else: 213 | return x 214 | 215 | def _init_weights(self, module): 216 | if isinstance(module, nn.Linear): 217 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 218 | if module.bias is not None: 219 | torch.nn.init.zeros_(module.bias) 220 | elif isinstance(module, nn.Embedding): 221 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 222 | 223 | 224 | class BERTForMaskedLM(BERT): 225 | def __init__(self, loss_fn:nn.Module=nn.CrossEntropyLoss(), 226 | mlm_prob:float|None=None, **kwargs): 227 | super().__init__(**kwargs) 228 | self.loss_fn = loss_fn 229 | self.mlm_prob = mlm_prob 230 | 231 | def forward(self, input_ids: Tensor, labels: Tensor, mlm_prob: float|None = None): 232 | x = super().forward(input_ids, False) 233 | 234 | # flatten both the labels and the intermediate outputs 235 | labels = labels.view(-1) 236 | x = x.view(labels.shape[0], -1) 237 | 238 | # only select the masked tokens for predictions 239 | mask_tokens = labels != self.loss_fn.ignore_index 240 | 241 | # torch.compile with fullgraph cannot have dynamic shapes 242 | # if `mlm_prob` is set, this will create workable indicies 243 | # if `mlm_prob` is None, then fullgraph=True cannot be used 244 | mlm_prob = self.mlm_prob if mlm_prob is None else mlm_prob 245 | if mlm_prob is not None: 246 | num_masks = math.floor(self.mlm_prob * mask_tokens.shape[0]) 247 | else: 248 | num_masks = mask_tokens.sum().int() 249 | indices = torch.argsort(mask_tokens.int())[-num_masks:] 250 | 251 | # selecting the masked tokens reshapes x to (B*S, VS) and labels to (B*S) 252 | x = x[indices] 253 | labels = labels[indices] 254 | 255 | # converts input token vectors of shape (B*S, C) 256 | # to probability distribution of shape (B*S, VS) 257 | logits = self.head(x) 258 | 259 | # return both the logits and the loss 260 | return {'logits': logits, 'loss': self.loss_fn(logits, labels)} -------------------------------------------------------------------------------- /models/gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | gpt.py is a highly commented implementation of the GPT-2 causal decoder Transformer 3 | 4 | The codebase for gpt.py is inspired by: 5 | nanoGPT https://github.com/karpathy/nanoGPT - Copyright (c) 2022 Andrej Karpathy - MIT License 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch import Tensor, BoolTensor 15 | from torch.nn import functional as F 16 | 17 | 18 | class FeedForward(nn.Module): 19 | def __init__(self, hidden_size:int, expand_size:int, act:nn.Module=nn.GELU, 20 | drop:float=0.1, bias:bool=True): 21 | super().__init__() 22 | # project input to expanded dimension 23 | self.fc1 = nn.Linear(hidden_size, expand_size, bias=bias) 24 | 25 | # activation function to introduce non-linearity 26 | self.act = act() 27 | 28 | # project back to the input dimension 29 | self.fc2 = nn.Linear(expand_size, hidden_size, bias=bias) 30 | 31 | # optional dropout layer to prevent overfitting 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x:Tensor): 35 | x = self.fc1(x) # apply first linear layer 36 | x = self.act(x) # apply activation function 37 | x = self.fc2(x) # apply second linear layer 38 | x = self.drop(x) # optionally apply dropout layer 39 | return x 40 | 41 | 42 | class CausalAttention(nn.Module): 43 | def __init__(self, hidden_size:int, num_heads:int, context_size:int, 44 | attn_drop:float=0.1, out_drop:float=0.1, bias:bool=True): 45 | super().__init__() 46 | # input dimension must be divisible by num_heads 47 | assert hidden_size % num_heads == 0 48 | # number of Attention heads 49 | self.nh = num_heads 50 | 51 | # linear layer to project queries, keys, values 52 | self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias) 53 | 54 | # attention dropout layer to prevent overfitting 55 | self.attn_drop = nn.Dropout(attn_drop) 56 | 57 | # linear layer to project final output 58 | self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias) 59 | 60 | # final output dropout layer to prevent overfitting 61 | self.out_drop = nn.Dropout(out_drop) 62 | 63 | # causal mask to ensure that Attention is not applied to future tokens where 64 | # context_size is the maximum sequence length of the transformer 65 | self.register_buffer('causal_mask', 66 | torch.triu(torch.ones([context_size, context_size], dtype=torch.bool), diagonal=1) 67 | .view(1, 1, context_size, context_size), persistent=False 68 | ) 69 | 70 | # boolean `mask` of shape (batch_size, sequence_length) 71 | # where True is masked and False is unmasked 72 | def forward(self, x: Tensor, mask: BoolTensor|None = None): 73 | # batch size, sequence length, input dimension 74 | B, S, C = x.shape 75 | 76 | # split into queries, keys, & values of shape 77 | # batch size (B), num_heads (NH), sequence length (S), head size (HS) 78 | x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh) 79 | q, k, v = x.transpose(3, 1).unbind(dim=2) 80 | 81 | # dot product queries and keys for each head 82 | # (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S) 83 | attn = q @ k.transpose(-2, -1) 84 | 85 | # scale by square root of output dimension 86 | attn = attn / math.sqrt(k.size(-1)) 87 | 88 | # apply input and causal mask 89 | combined_mask = self.causal_mask[:, :, :S, :S] 90 | if mask is not None: 91 | combined_mask += mask.view(B, 1, 1, S) 92 | attn = attn.masked_fill(combined_mask, float('-inf')) 93 | 94 | # apply softmax to get attention weights 95 | attn = attn.softmax(dim=-1) 96 | 97 | # apply dropout to attention weight 98 | attn = self.attn_drop(attn) 99 | 100 | # dot product attention weights with values of shape 101 | # (B, NH, S, HS) = (B, NH, S, S) @ (B, NH, HS, S) 102 | x = attn @ v 103 | 104 | # and transpose heads & sequence and reshape back to (B, S, C) 105 | x = x.transpose(1, 2).reshape(B, S, C) 106 | 107 | # apply final linear layer and dropout to get output (B, S, C) 108 | return self.out_drop(self.Wo(x)) 109 | 110 | 111 | class TransformerBlock(nn.Module): 112 | def __init__(self, hidden_size:int, num_heads:int, context_size:int, expand_size:int, 113 | attention:nn.Module=CausalAttention, act:nn.Module=nn.GELU, 114 | attn_drop:float=0.1, out_drop:float=0.1, ffn_drop:float=0.1, 115 | bias:bool=True): 116 | super().__init__() 117 | # first pre-norm layer 118 | self.norm1 = nn.LayerNorm(hidden_size) 119 | # initialize the attention layer 120 | self.attn = attention( 121 | hidden_size=hidden_size, num_heads=num_heads, context_size=context_size, 122 | attn_drop=attn_drop, out_drop=out_drop, bias=bias 123 | ) 124 | 125 | # second pre-norm layer 126 | self.norm2 = nn.LayerNorm(hidden_size) 127 | # initialize the feed forward network (MLP) 128 | self.ffn = FeedForward( 129 | hidden_size=hidden_size, expand_size=expand_size, act=act, 130 | drop=ffn_drop, bias=bias, 131 | ) 132 | 133 | def forward(self, x: Tensor): 134 | # normalize input then add residual to attention output 135 | x = x + self.attn(self.norm1(x)) 136 | 137 | # normalize input then add residual to feedforward output 138 | return x + self.ffn(self.norm2(x)) 139 | 140 | 141 | class GPT(nn.Module): 142 | def __init__(self, num_layers:int, vocab_size:int, hidden_size:int, num_heads:int, 143 | context_size:int, expand_size:int, attention:nn.Module=CausalAttention, 144 | act:nn.Module=nn.GELU, embed_drop:float=0.1, attn_drop:float=0.1, 145 | out_drop:float=0.1, ffn_drop:float=0.1, head_norm:bool=True, 146 | tie_weights:bool=True, head_bias:bool=True, bias:bool=True): 147 | super().__init__() 148 | # initialize vocab & positional embeddings to convert numericalied tokens 149 | # & position indicies to token and position vectors, with optional dropout 150 | self.vocab_embed = nn.Embedding(vocab_size, hidden_size) 151 | self.pos_embed = nn.Embedding(context_size, hidden_size) 152 | self.embed_drop = nn.Dropout(embed_drop) 153 | 154 | # initialize num_layers of transformer layers 155 | self.tfm_blocks = nn.ModuleList([TransformerBlock( 156 | hidden_size=hidden_size, num_heads=num_heads, context_size=context_size, 157 | expand_size=expand_size, attention=attention, act=act, bias=bias, 158 | attn_drop=attn_drop, out_drop=out_drop, ffn_drop=ffn_drop) 159 | for _ in range(num_layers)]) 160 | 161 | # optional pre-head normalization 162 | self.head_norm = nn.LayerNorm(hidden_size) if head_norm else nn.Identity() 163 | 164 | # predicts the next token in the sequence 165 | self.head = nn.Linear(hidden_size, vocab_size, bias=head_bias) 166 | 167 | # optionally set the vocab embedding and prediction head to share weights 168 | if tie_weights: 169 | self.head.weight = self.vocab_embed.weight 170 | 171 | # precreate positional indices for the positional embedding 172 | pos = torch.arange(0, context_size, dtype=torch.long) 173 | self.register_buffer('pos', pos, persistent=False) 174 | 175 | self.apply(self._init_weights) 176 | 177 | def forward(self, x: Tensor): 178 | # convert numericalized tokens of shape (B, S) 179 | # into token embeddings of shape (B, S, C) 180 | tokens = self.vocab_embed(x) 181 | # positional embeddings are shape (S, C) 182 | pos = self.pos_embed(self.pos[:x.shape[1]]) 183 | 184 | # positional embeddings are added to token embeddings 185 | x = self.embed_drop(tokens + pos) 186 | 187 | # pass token vectors through all transformer layers 188 | for block in self.tfm_blocks: 189 | x = block(x) 190 | 191 | # apply optional pre-head normalization 192 | x = self.head_norm(x) 193 | 194 | # converts input token vectors of shape (B, S, C) to probability 195 | # distribution of shape batch, sequence length, vocabulary size (B, S, VS) 196 | return self.head(x) 197 | 198 | def _init_weights(self, module): 199 | if isinstance(module, nn.Linear): 200 | if module._get_name() == 'fc2': 201 | # GPT-2 style FFN init 202 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02/math.sqrt(2 * self.num_layers)) 203 | else: 204 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 205 | if module.bias is not None: 206 | torch.nn.init.zeros_(module.bias) 207 | elif isinstance(module, nn.Embedding): 208 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 209 | 210 | 211 | class GPTForCausalLM(GPT): 212 | def __init__(self, loss_fn:nn.Module=nn.CrossEntropyLoss(), **kwargs): 213 | super().__init__(**kwargs) 214 | self.loss_fn = loss_fn 215 | 216 | def forward(self, x: Tensor): 217 | # the labels are the next token, so shift the labels over one 218 | # & resize inputs to same length as labels by dropping last token 219 | inputs = x[:, :-1] 220 | labels = x[:, 1:] 221 | 222 | # logits are of shape batch, sequence length, vocab size (B, S, VS), 223 | # labels are of shape batch, vocab size (B, S) 224 | logits = super().forward(inputs) 225 | 226 | # flatten logits into (B*S, VS) and labels into (B*S) & calculate loss 227 | loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) 228 | 229 | # return both the logits and the loss 230 | return {'logits': logits, 'loss': loss} --------------------------------------------------------------------------------