├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── fast-transformer.png
├── fast_transformer_pytorch
│   ├── __init__.py
│   └── fast_transformer_pytorch.py
└── setup.py

--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Fast Transformer - Pytorch

Implementation of Fast Transformer in Pytorch. This only works as an encoder.

Yannic video

AI Epiphany

## Install

```bash
$ pip install fast-transformer-pytorch
```

## Usage

```python
import torch
from fast_transformer_pytorch import FastTransformer

model = FastTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096,
    absolute_pos_emb = True   # defaults to relative (rotary) positional encoding; set this to True if that is not working for you and you want absolute positional embeddings instead
)

x = torch.randint(0, 20000, (1, 4096))
mask = torch.ones(1, 4096).bool()

logits = model(x, mask = mask) # (1, 4096, 20000)
```
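
Since the model only works as an encoder and returns per-token logits, a natural way to train it is with a masked-token (BERT-style) objective. The sketch below is not part of this repository's API — the 15% corruption rate, the reserved `mask_token_id`, and the `-100` ignore index are illustrative assumptions — it only relies on plain PyTorch plus the `FastTransformer` call shown above.

```python
import torch
import torch.nn.functional as F
from fast_transformer_pytorch import FastTransformer

model = FastTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096
)

seq = torch.randint(0, 20000, (1, 4096))      # ground-truth token ids
mask = torch.ones(1, 4096).bool()             # attend to every position

# hypothetical masking step - assumes token id 0 is reserved as a [MASK] token
mask_token_id = 0
corrupted = torch.rand(seq.shape) < 0.15      # corrupt ~15% of positions
inp = seq.masked_fill(corrupted, mask_token_id)
labels = seq.masked_fill(~corrupted, -100)    # only score the corrupted positions

logits = model(inp, mask = mask)              # (1, 4096, 20000)

loss = F.cross_entropy(
    logits.transpose(1, 2),                   # cross_entropy expects (batch, classes, seq)
    labels,
    ignore_index = -100
)
loss.backward()
```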

## Citations

```bibtex
@misc{wu2021fastformer,
    title   = {Fastformer: Additive Attention is All You Need},
    author  = {Chuhan Wu and Fangzhao Wu and Tao Qi and Yongfeng Huang},
    year    = {2021},
    eprint  = {2108.09084},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

--------------------------------------------------------------------------------
/fast-transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/fast-transformer-pytorch/8fb775fb939731fabd59e312758a57f57fa8db1e/fast-transformer.png

--------------------------------------------------------------------------------
/fast_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
from fast_transformer_pytorch.fast_transformer_pytorch import FastTransformer

--------------------------------------------------------------------------------
/fast_transformer_pytorch/fast_transformer_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, reduce
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# helper classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

# blocks

def FeedForward(dim, mult = 4):
    return nn.Sequential(
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Linear(dim * mult, dim)
    )

class FastAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 8,
        dim_head = 64,
        max_seq_len = None,
        pos_emb = None
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # rotary positional embedding

        assert not (exists(pos_emb) and not exists(max_seq_len)), 'max_seq_len must be passed in if using rotary positional embeddings'

        self.pos_emb = pos_emb
        self.max_seq_len = max_seq_len

        # if using relative positional encoding, make sure to reduce pairs of consecutive feature dimension before doing projection to attention logits

        kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2

        self.to_q_attn_logits = nn.Linear(dim_head, 1, bias = False)  # for projecting queries to query attention logits
        self.to_k_attn_logits = nn.Linear(dim_head // kv_attn_proj_divisor, 1, bias = False)  # for projecting keys to key attention logits

        # final transformation of values to "r" as in the paper

        self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, mask = None):
        n, device, h, use_rotary_emb = x.shape[1], x.device, self.heads, exists(self.pos_emb)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        mask_value = -torch.finfo(x.dtype).max

        # if no mask is given, default to attending to all positions, since the mask is used unconditionally below

        if not exists(mask):
            mask = torch.ones(x.shape[:2], device = device, dtype = torch.bool)

        mask = rearrange(mask, 'b n -> b () n')

        # if relative positional encoding is needed

        if use_rotary_emb:
            freqs = self.pos_emb(torch.arange(self.max_seq_len, device = device), cache_key = self.max_seq_len)
            freqs = rearrange(freqs[:n], 'n d -> () () n d')
            q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
        else:
            q_aggr, k_aggr, v_aggr = q, k, v

        # calculate query attention logits

        q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
        q_attn_logits = q_attn_logits.masked_fill(~mask, mask_value)
        q_attn = q_attn_logits.softmax(dim = -1)

        # calculate global query token

        global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
        global_q = rearrange(global_q, 'b h d -> b h () d')

        # bias keys with global query token

        k = k * global_q

        # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension

        if use_rotary_emb:
            k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r = 2)

        # now calculate key attention logits

        k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
        k_attn_logits = k_attn_logits.masked_fill(~mask, mask_value)
        k_attn = k_attn_logits.softmax(dim = -1)

        # calculate global key token

        global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
        global_k = rearrange(global_k, 'b h d -> b h () d')

        # bias the values

        u = v_aggr * global_k

        # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension

        if use_rotary_emb:
            u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r = 2)

        # transformation step

        r = self.to_r(u)

        # paper then says to add the queries as a residual

        r = r + q

        # combine heads

        r = rearrange(r, 'b h n d -> b n (h d)')
        return self.to_out(r)
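
# note: shape summary of FastAttention.forward above (b = batch, h = heads, n = sequence length, d = dim_head)
#   q, k, v            : (b, h, n, d)
#   q_attn             : (b, h, n)     softmax of per-token query logits over the sequence
#   global_q           : (b, h, 1, d)  attention-weighted sum of the (rotated) queries
#   k * global_q       : (b, h, n, d)  keys modulated elementwise by the global query
#                        (summed over adjacent feature pairs down to d // 2 when rotary embeddings are used)
#   global_k           : (b, h, 1, d)  attention-weighted sum of the (rotated) keys
#   u = v * global_k   : (b, h, n, d)  values modulated by the global key, reduced the same way
#   r = to_r(u) + q    : (b, h, n, d)  projected back to dim_head with queries added as a residual
#   output             : (b, n, h * d) -> (b, n, dim) after combining heads and applying to_out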

# main class

class FastTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        heads = 8,
        dim_head = 64,
        ff_mult = 4,
        absolute_pos_emb = False
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        # positional embeddings

        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None

        layer_pos_emb = None
        if not absolute_pos_emb:
            assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
            layer_pos_emb = RotaryEmbedding(dim_head // 2)

        # layers

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            attn = FastAttention(dim, dim_head = dim_head, heads = heads, pos_emb = layer_pos_emb, max_seq_len = max_seq_len)
            ff = FeedForward(dim, mult = ff_mult)

            self.layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, ff)
            ]))

        # weight tie projections across all layers

        first_block, _ = self.layers[0]
        for block, _ in self.layers[1:]:
            block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
            block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        n, device = x.shape[1], x.device
        x = self.token_emb(x)

        if exists(self.abs_pos_emb):
            pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
            x = x + rearrange(pos_emb, 'n d -> () n d')

        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return self.to_logits(x)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'fast-transformer-pytorch',
  packages = find_packages(),
  version = '0.0.4',
  license='MIT',
  description = 'Fast Transformer - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/fast-transformer-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers'
  ],
  install_requires=[
    'einops>=0.3',
    'rotary-embedding-torch',
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

--------------------------------------------------------------------------------