├── src
│   ├── __init__.py
│   ├── transformer.py
│   ├── multi_head_attention.py
│   ├── utils.py
│   ├── encoder.py
│   └── decoder.py
├── tests
│   ├── __init__.py
│   └── test_transformer.py
├── requirements.txt
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── test.yaml
├── setup.cfg
├── README.md
├── LICENSE
└── .gitignore

/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch>=1.8
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
# These are supported funding model platforms

github: [fkodom]
custom: [fkodom.substack.com]
# patreon: # Replace with a single Patreon username
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[flake8]
ignore = E203, E266, E501, W503, F401
max-line-length = 88
max-complexity = 18
select = B,C,E,F,W,T4

[isort]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88

[mypy]
files=src,tests
ignore_missing_imports=true

[tool:pytest]
testpaths=tests/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# transformer-from-scratch
Code for my blog post: [Transformers from Scratch in PyTorch](https://fkodom.substack.com/p/transformers-from-scratch-in-pytorch)

**Note:** This Transformer code does **not** include masked attention. That was intentional, because it led to a much cleaner implementation. This repository is intended for educational purposes only. I believe that everything here is correct, but I make no guarantees if you decide to use it in your own project.
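
## Usage
A minimal sketch of how the pieces fit together. It assumes you run it from the repository root (so that `src` is importable), and the tensor sizes below are small and arbitrary:

```python
import torch

from src.transformer import Transformer

model = Transformer(
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_model=8,
    num_heads=2,
    dim_feedforward=16,
)

src = torch.randn(4, 10, 8)  # (batch_size, src_len, dim_model)
tgt = torch.randn(4, 7, 8)   # (batch_size, tgt_len, dim_model)

out = model(src, tgt)
print(out.shape)  # torch.Size([4, 7, 8]): batch and dim_model are preserved, seq_len follows tgt
```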

## Citations
```
@misc{vaswani2023attention,
    title={Attention Is All You Need},
    author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year={2023},
    eprint={1706.03762},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
name: Test

on:
  workflow_dispatch: {}
  push: {}

jobs:
  test:
    name: Test
    runs-on: ubuntu-latest

    strategy:
      matrix:
        python: ["3.7", "3.8", "3.9"]
        torch: ["1.8", "1.9", "1.10"]

    steps:
      - name: Checkout
        uses: actions/checkout@v2

      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python }}

      - name: Install Package
        run: |
          pip install torch==${{ matrix.torch }}
          pip install -r requirements.txt
          pip install black flake8 isort pytest pytest-cov

      - name: Test
        run: |
          pytest --cov --cov-fail-under 100 tests/
          black --check .
          isort --diff .
          flake8 .
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Frank Odom

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/src/transformer.py:
--------------------------------------------------------------------------------
from torch import Tensor, nn

from src.decoder import TransformerDecoder
from src.encoder import TransformerEncoder


class Transformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        return self.decoder(tgt, self.encoder(src))
--------------------------------------------------------------------------------
/src/multi_head_attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as f
from torch import Tensor, nn


def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    temp = query.bmm(key.transpose(1, 2))
    scale = query.size(-1) ** 0.5
    softmax = f.softmax(temp / scale, dim=-1)
    return softmax.bmm(value)


class AttentionHead(nn.Module):
    def __init__(self, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return scaled_dot_product_attention(self.q(query), self.k(key), self.v(value))


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )
--------------------------------------------------------------------------------
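The attention shapes above are easiest to see with a tiny example. This snippet is illustrative only, not part of the repository; it assumes it is run from the repository root so that `src` is importable, and all sizes are arbitrary:

```python
import torch

from src.multi_head_attention import MultiHeadAttention, scaled_dot_product_attention

batch_size, seq_len, dim_model = 2, 5, 8
x = torch.randn(batch_size, seq_len, dim_model)

# scaled_dot_product_attention operates on any (batch, seq, dim) tensors;
# here query = key = value = x, i.e. self-attention without projections.
print(scaled_dot_product_attention(x, x, x).shape)  # torch.Size([2, 5, 8])

# Two heads, each projecting dim_model=8 down to dim_q = dim_k = 4; the
# concatenated heads (2 * 4 = 8) are mapped back to dim_in by the final linear.
attention = MultiHeadAttention(num_heads=2, dim_in=dim_model, dim_q=4, dim_k=4)
print(attention(x, x, x).shape)  # torch.Size([2, 5, 8])
```
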
/src/utils.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor, nn


def position_encoding(
    seq_len: int,
    dim_model: int,
    device: torch.device = torch.device("cpu"),
) -> Tensor:
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.float, device=device).reshape(1, 1, -1)
    # Frequency term 10000 ** (2i / dim_model) from "Attention Is All You Need",
    # where i = dim // 2 indexes each sin/cos pair.
    phase = pos / (1e4 ** (2 * torch.div(dim, 2, rounding_mode="floor") / dim_model))

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))


def feed_forward(dim_input: int = 512, dim_feedforward: int = 2048) -> nn.Module:
    return nn.Sequential(
        nn.Linear(dim_input, dim_feedforward),
        nn.ReLU(),
        nn.Linear(dim_feedforward, dim_input),
    )


class Residual(nn.Module):
    def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors: Tensor) -> Tensor:
        # Assume that the "query" tensor is given first, so we can compute the
        # residual.  This matches the signature of 'MultiHeadAttention'.
        return self.norm(tensors[0] + self.dropout(self.sublayer(*tensors)))
--------------------------------------------------------------------------------
/src/encoder.py:
--------------------------------------------------------------------------------
from torch import Tensor, nn

from src.multi_head_attention import MultiHeadAttention
from src.utils import Residual, feed_forward, position_encoding


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        dim_q = dim_k = max(dim_model // num_heads, 1)
        self.attention = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )

    def forward(self, src: Tensor) -> Tensor:
        src = self.attention(src, src, src)
        return self.feed_forward(src)


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, src: Tensor) -> Tensor:
        seq_len, dimension = src.size(1), src.size(2)
        # Add out-of-place so the caller's tensor is not mutated, and build the
        # encoding on the same device as the input.
        src = src + position_encoding(seq_len, dimension, device=src.device)
        for layer in self.layers:
            src = layer(src)

        return src
--------------------------------------------------------------------------------
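For intuition about the pieces above: `position_encoding` returns a `(1, seq_len, dim_model)` tensor that broadcasts over the batch when added to the input, and `TransformerEncoder` preserves the `(batch, seq_len, dim_model)` shape. A small illustrative check, again not part of the repository, with arbitrary sizes:

```python
import torch

from src.encoder import TransformerEncoder
from src.utils import position_encoding

# One row per position, one column per model dimension; the leading 1 lets the
# encoding broadcast across the batch dimension when it is added to the input.
pe = position_encoding(seq_len=16, dim_model=8)
print(pe.shape)  # torch.Size([1, 16, 8])

encoder = TransformerEncoder(
    num_layers=2, dim_model=8, num_heads=2, dim_feedforward=16
)
src = torch.randn(4, 16, 8)
print(encoder(src).shape)  # torch.Size([4, 16, 8]): same shape as the input
```
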
/tests/test_transformer.py:
--------------------------------------------------------------------------------
import pytest
import torch

from src.transformer import Transformer


@pytest.mark.parametrize("num_encoder_layers", [1, 6])
@pytest.mark.parametrize("num_decoder_layers", [1, 6])
@pytest.mark.parametrize("dim_model", [2, 8])
@pytest.mark.parametrize("num_heads", [1, 6])
@pytest.mark.parametrize("dim_feedforward", [2, 8])
def test_init(
    num_encoder_layers: int,
    num_decoder_layers: int,
    dim_model: int,
    num_heads: int,
    dim_feedforward: int,
):
    _ = Transformer(
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_model=dim_model,
        num_heads=num_heads,
        dim_feedforward=dim_feedforward,
    )


@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("src_len", [2, 8])
@pytest.mark.parametrize("tgt_len", [2, 8])
@pytest.mark.parametrize("num_features", [2, 8])
@pytest.mark.parametrize("num_encoder_layers", [1, 6])
@pytest.mark.parametrize("num_decoder_layers", [1, 6])
@pytest.mark.parametrize("num_heads", [1, 6])
@pytest.mark.parametrize("dim_feedforward", [2, 8])
def test_forward(
    batch_size: int,
    src_len: int,
    tgt_len: int,
    num_features: int,
    num_encoder_layers: int,
    num_decoder_layers: int,
    num_heads: int,
    dim_feedforward: int,
):
    model = Transformer(
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_model=num_features,
        num_heads=num_heads,
        dim_feedforward=dim_feedforward,
    )

    src = torch.randn(batch_size, src_len, num_features)
    tgt = torch.randn(batch_size, tgt_len, num_features)
    out = model(src, tgt)

    _batch_size, seq_len, _num_features = out.shape
    assert batch_size == _batch_size
    assert seq_len == tgt_len
    assert _num_features == num_features
--------------------------------------------------------------------------------
/src/decoder.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor, nn

from src.multi_head_attention import MultiHeadAttention
from src.utils import Residual, feed_forward, position_encoding


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        dim_q = dim_k = max(dim_model // num_heads, 1)
        self.attention_1 = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.attention_2 = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        # Self-attention over the target, then cross-attention over the encoder
        # output ("memory"), then the position-wise feed-forward network.
        tgt = self.attention_1(tgt, tgt, tgt)
        tgt = self.attention_2(tgt, memory, memory)
        return self.feed_forward(tgt)


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )
        self.linear = nn.Linear(dim_model, dim_model)

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        seq_len, dimension = tgt.size(1), tgt.size(2)
        # Same device-aware, out-of-place addition as in the encoder.
        tgt = tgt + position_encoding(seq_len, dimension, device=tgt.device)
        for layer in self.layers:
            tgt = layer(tgt, memory)

        return torch.softmax(self.linear(tgt), dim=-1)
--------------------------------------------------------------------------------
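One property of `TransformerDecoder` above that is easy to miss: the final `torch.softmax` normalizes each output position over the feature dimension, so every output row sums to one. A quick illustrative check, not part of the repository; the random `memory` tensor simply stands in for an encoder output:

```python
import torch

from src.decoder import TransformerDecoder

decoder = TransformerDecoder(
    num_layers=2, dim_model=8, num_heads=2, dim_feedforward=16
)

tgt = torch.randn(2, 7, 8)      # (batch, tgt_len, dim_model)
memory = torch.randn(2, 10, 8)  # stands in for the encoder output

out = decoder(tgt, memory)
print(out.shape)        # torch.Size([2, 7, 8]): follows the target length
print(out.sum(dim=-1))  # ~1.0 everywhere, because of the final softmax
```
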
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------