├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── fast-transformer.png
├── fast_transformer_pytorch
│   ├── __init__.py
│   └── fast_transformer_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Fast Transformer - Pytorch
4 |
5 | Implementation of Fast Transformer in Pytorch. This only works as an encoder.
6 |
7 | Video explanation by Yannic Kilcher
8 |
9 | Video explanation by The AI Epiphany
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install fast-transformer-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | from fast_transformer_pytorch import FastTransformer
22 |
23 | model = FastTransformer(
24 | num_tokens = 20000,
25 | dim = 512,
26 | depth = 2,
27 | max_seq_len = 4096,
28 |     absolute_pos_emb = True # default uses rotary (relative) positional encoding; if that does not work well for your task, set this to True to use absolute positional embeddings instead
29 | )
30 |
31 | x = torch.randint(0, 20000, (1, 4096))
32 | mask = torch.ones(1, 4096).bool()
33 |
34 | logits = model(x, mask = mask) # (1, 4096, 20000)
35 | ```
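
The logits can be plugged into any standard encoder objective. Below is a minimal, hypothetical training sketch using a BERT-style masked-token loss; the mask token id, masking rate, and learning rate are illustrative assumptions and are not part of this repository.

```python
import torch
import torch.nn.functional as F
from fast_transformer_pytorch import FastTransformer

NUM_TOKENS = 20000
MASK_ID = 0  # hypothetical id reserved for the mask token in this sketch

model = FastTransformer(
    num_tokens = NUM_TOKENS,
    dim = 512,
    depth = 2,
    max_seq_len = 4096
)

opt = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(1, NUM_TOKENS, (1, 4096))  # toy data, avoiding the mask id
mask = torch.ones(1, 4096).bool()              # padding mask (all positions are real tokens here)

# corrupt 15% of the positions and have the encoder reconstruct them
mlm_mask = torch.rand(seq.shape) < 0.15
corrupted = seq.masked_fill(mlm_mask, MASK_ID)

logits = model(corrupted, mask = mask)         # (1, 4096, NUM_TOKENS)

loss = F.cross_entropy(logits[mlm_mask], seq[mlm_mask])  # loss only at corrupted positions
loss.backward()
opt.step()
opt.zero_grad()
```

Since the model is a bidirectional encoder, a denoising objective like this is a more natural fit than left-to-right language modeling.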
36 |
37 | ## Citations
38 |
39 | ```bibtex
40 | @misc{wu2021fastformer,
41 | title = {Fastformer: Additive Attention is All You Need},
42 | author = {Chuhan Wu and Fangzhao Wu and Tao Qi and Yongfeng Huang},
43 | year = {2021},
44 | eprint = {2108.09084},
45 | archivePrefix = {arXiv},
46 | primaryClass = {cs.CL}
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
/fast-transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/fast-transformer-pytorch/8fb775fb939731fabd59e312758a57f57fa8db1e/fast-transformer.png
--------------------------------------------------------------------------------
/fast_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from fast_transformer_pytorch.fast_transformer_pytorch import FastTransformer
2 |
--------------------------------------------------------------------------------
/fast_transformer_pytorch/fast_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, einsum
4 |
5 | from einops import rearrange, reduce
6 | from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding
7 |
8 | # helper functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | def default(val, d):
14 | return val if exists(val) else d
15 |
16 | # helper classes
17 |
18 | class PreNorm(nn.Module):
19 | def __init__(self, dim, fn):
20 | super().__init__()
21 | self.norm = nn.LayerNorm(dim)
22 | self.fn = fn
23 |
24 | def forward(self, x, **kwargs):
25 | x = self.norm(x)
26 | return self.fn(x, **kwargs)
27 |
28 | # blocks
29 |
30 | def FeedForward(dim, mult = 4):
31 | return nn.Sequential(
32 | nn.Linear(dim, dim * mult),
33 | nn.GELU(),
34 | nn.Linear(dim * mult, dim)
35 | )
36 |
37 | class FastAttention(nn.Module):
38 | def __init__(
39 | self,
40 | dim,
41 | *,
42 | heads = 8,
43 | dim_head = 64,
44 | max_seq_len = None,
45 | pos_emb = None
46 | ):
47 | super().__init__()
48 | inner_dim = heads * dim_head
49 | self.heads = heads
50 | self.scale = dim_head ** -0.5
51 |
52 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
53 |
54 | # rotary positional embedding
55 |
56 |         assert not (exists(pos_emb) and not exists(max_seq_len)), 'max_seq_len must be passed in when using rotary positional embeddings'
57 |
58 | self.pos_emb = pos_emb
59 | self.max_seq_len = max_seq_len
60 |
61 |         # when using rotary positional embeddings, pairs of consecutive feature dimensions are summed before the projection to attention logits
62 |
63 | kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2
64 |
65 | self.to_q_attn_logits = nn.Linear(dim_head, 1, bias = False) # for projecting queries to query attention logits
66 | self.to_k_attn_logits = nn.Linear(dim_head // kv_attn_proj_divisor, 1, bias = False) # for projecting keys to key attention logits
67 |
68 | # final transformation of values to "r" as in the paper
69 |
70 | self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)
71 |
72 | self.to_out = nn.Linear(inner_dim, dim)
73 |
74 | def forward(self, x, mask = None):
75 | n, device, h, use_rotary_emb = x.shape[1], x.device, self.heads, exists(self.pos_emb)
76 |
77 | qkv = self.to_qkv(x).chunk(3, dim = -1)
78 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
79 |
80 | mask_value = -torch.finfo(x.dtype).max
81 |         mask = rearrange(default(mask, torch.ones(x.shape[:2], device = device).bool()), 'b n -> b () n') # attend over all positions if no mask is passed
82 |
83 | # if relative positional encoding is needed
84 |
85 | if use_rotary_emb:
86 | freqs = self.pos_emb(torch.arange(self.max_seq_len, device = device), cache_key = self.max_seq_len)
87 | freqs = rearrange(freqs[:n], 'n d -> () () n d')
88 | q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
89 | else:
90 | q_aggr, k_aggr, v_aggr = q, k, v
91 |
92 | # calculate query attention logits
93 |
94 | q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
95 | q_attn_logits = q_attn_logits.masked_fill(~mask, mask_value)
96 | q_attn = q_attn_logits.softmax(dim = -1)
97 |
98 | # calculate global query token
99 |
100 | global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
101 | global_q = rearrange(global_q, 'b h d -> b h () d')
102 |
103 | # bias keys with global query token
104 |
105 | k = k * global_q
106 |
107 | # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
108 |
109 | if use_rotary_emb:
110 | k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r = 2)
111 |
112 | # now calculate key attention logits
113 |
114 | k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
115 | k_attn_logits = k_attn_logits.masked_fill(~mask, mask_value)
116 | k_attn = k_attn_logits.softmax(dim = -1)
117 |
118 | # calculate global key token
119 |
120 | global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
121 | global_k = rearrange(global_k, 'b h d -> b h () d')
122 |
123 | # bias the values
124 |
125 | u = v_aggr * global_k
126 |
127 | # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
128 |
129 | if use_rotary_emb:
130 | u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r = 2)
131 |
132 | # transformation step
133 |
134 | r = self.to_r(u)
135 |
136 | # paper then says to add the queries as a residual
137 |
138 | r = r + q
139 |
140 | # combine heads
141 |
142 | r = rearrange(r, 'b h n d -> b n (h d)')
143 | return self.to_out(r)
144 |
145 | # main class
146 |
147 | class FastTransformer(nn.Module):
148 | def __init__(
149 | self,
150 | *,
151 | num_tokens,
152 | dim,
153 | depth,
154 | max_seq_len,
155 | heads = 8,
156 | dim_head = 64,
157 | ff_mult = 4,
158 | absolute_pos_emb = False
159 | ):
160 | super().__init__()
161 | self.token_emb = nn.Embedding(num_tokens, dim)
162 |
163 | # positional embeddings
164 |
165 | self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None
166 |
167 | layer_pos_emb = None
168 | if not absolute_pos_emb:
169 | assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
170 | layer_pos_emb = RotaryEmbedding(dim_head // 2)
171 |
172 | # layers
173 |
174 | self.layers = nn.ModuleList([])
175 |
176 | for _ in range(depth):
177 | attn = FastAttention(dim, dim_head = dim_head, heads = heads, pos_emb = layer_pos_emb, max_seq_len = max_seq_len)
178 | ff = FeedForward(dim, mult = ff_mult)
179 |
180 | self.layers.append(nn.ModuleList([
181 | PreNorm(dim, attn),
182 | PreNorm(dim, ff)
183 | ]))
184 |
185 | # weight tie projections across all layers
186 |
187 | first_block, _ = self.layers[0]
188 | for block, _ in self.layers[1:]:
189 | block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
190 | block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits
191 |
192 | # to logits
193 |
194 | self.to_logits = nn.Sequential(
195 | nn.LayerNorm(dim),
196 | nn.Linear(dim, num_tokens)
197 | )
198 |
199 | def forward(
200 | self,
201 | x,
202 | mask = None
203 | ):
204 | n, device = x.shape[1], x.device
205 | x = self.token_emb(x)
206 |
207 | if exists(self.abs_pos_emb):
208 | pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
209 | x = x + rearrange(pos_emb, 'n d -> () n d')
210 |
211 | for attn, ff in self.layers:
212 | x = attn(x, mask = mask) + x
213 | x = ff(x) + x
214 |
215 | return self.to_logits(x)
216 |
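# a quick smoke test (illustrative sketch, not exercised anywhere by the package):
# checks that FastTransformer returns logits of shape (batch, seq_len, num_tokens)

if __name__ == '__main__':
    model = FastTransformer(
        num_tokens = 256,
        dim = 64,
        depth = 2,
        max_seq_len = 128
    )

    x = torch.randint(0, 256, (2, 128))
    mask = torch.ones(2, 128).bool()

    logits = model(x, mask = mask)
    assert logits.shape == (2, 128, 256)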
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'fast-transformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'Fast Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/fast-transformer-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'deep learning',
15 | 'transformers'
16 | ],
17 | install_requires=[
18 | 'einops>=0.3',
19 | 'rotary-embedding-torch',
20 | 'torch>=1.6'
21 | ],
22 | classifiers=[
23 | 'Development Status :: 4 - Beta',
24 | 'Intended Audience :: Developers',
25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 | 'License :: OSI Approved :: MIT License',
27 | 'Programming Language :: Python :: 3.6',
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------