├── feedback_transformer_pytorch
│   ├── __init__.py
│   └── feedback_transformer_pytorch.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── .gitignore
└── README.md
/feedback_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from feedback_transformer_pytorch.feedback_transformer_pytorch import FeedbackTransformer
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'feedback-transformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.11',
7 | license='MIT',
8 | description = 'Implementation of Feedback Transformer in Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/feedback-transformer-pytorch',
12 | keywords = [
13 | 'attention',
14 | 'artificial intelligence',
15 | 'transformer',
16 | 'deep learning',
17 | 'memory'
18 | ],
19 | install_requires=[
20 | 'torch>=1.6',
21 | 'einops'
22 | ],
23 | classifiers=[
24 | 'Development Status :: 4 - Beta',
25 | 'Intended Audience :: Developers',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 | 'License :: OSI Approved :: MIT License',
28 | 'Programming Language :: Python :: 3.6',
29 | ],
30 | )
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Feedback Transformer - Pytorch
2 |
3 | Simple implementation of the Feedback Transformer in Pytorch. It improves on Transformer-XL by giving each token access to the representations of all previous layers through time. This is achieved by aggregating the outputs of all layers into a shared memory, which every token, across all layers, can attend to at each time step.
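
To make the mechanism concrete, here is a minimal sketch of the memory update, paraphrasing what the implementation in `feedback_transformer_pytorch.py` below does (this is not a separate API): the hidden states of every layer, plus the token embedding, are combined with a learned softmax weighting and appended to a fixed-length FIFO memory that all layers read from.

```python
import torch

# minimal sketch of the shared-memory update, assuming:
#   hiddens      - list of (batch, seq, dim) tensors, one per layer plus the input embedding
#   layer_weight - learnable tensor of shape (num_layers + 1,)
#   memory       - previously stored memory of shape (batch, mem, dim), or None
def update_memory(hiddens, layer_weight, memory, mem_len):
    weights = layer_weight.softmax(dim = -1)                     # learned weighting over layers
    stacked = torch.stack(hiddens)                               # (layers, batch, seq, dim)
    agg = (stacked * weights[:, None, None, None]).sum(dim = 0)  # weighted sum over layers
    memory = agg if memory is None else torch.cat((memory, agg), dim = 1)
    return memory[:, -mem_len:]                                  # FIFO: keep only the last mem_len steps
```

The actual implementation stores pre-projected keys and values rather than the raw hidden states, but the aggregation is the same idea.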
4 |
5 | The main drawback is longer training time, due to its non-parallel nature. But I thought I'd build it to facilitate further exploration and research into this line of work.
6 |
7 | Yannic Kilcher video
8 |
9 | I also took the liberty of adding a few enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install feedback-transformer-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | from feedback_transformer_pytorch import FeedbackTransformer
22 |
23 | model = FeedbackTransformer(
24 | num_tokens = 20000, # number of tokens
25 | dim = 512, # dimension
26 | depth = 6, # depth
27 | seq_len = 2, # the sequence length of each segment or window
28 | mem_len = 256, # length of the memory buffer
29 | dim_head = 64, # dimension of each head
30 | heads = 8, # number of heads
31 | attn_dropout = 0.1, # attention dropout
32 | ff_dropout = 0.1 # feedforward dropout
33 | ).cuda()
34 |
35 | x = torch.randint(0, 20000, (2, 64)).cuda()
36 | model(x) # (2, 64, 20000)
37 | ```
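
The output is a tensor of logits over the vocabulary at every position, so a standard language-modeling loss applies. A minimal training-step sketch (the optimizer, learning rate, and dummy data below are placeholders, not part of the library):

```python
import torch
import torch.nn.functional as F

# hypothetical training step; `model` is the FeedbackTransformer instance from above
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 20000, (2, 65)).cuda()   # dummy data, one extra token for the targets
inp, targets = seq[:, :-1], seq[:, 1:]          # predict the next token at every position

logits = model(inp)                                       # (2, 64, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), targets)  # cross entropy expects (batch, classes, seq)
loss.backward()
optim.step()
optim.zero_grad()
```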
38 |
39 | If you would like finer control over the memory (when to detach it, etc.), you can manage it with a couple of extra keyword arguments on `.forward`
40 |
41 | ```python
42 | import torch
43 | from feedback_transformer_pytorch import FeedbackTransformer
44 |
45 | model = FeedbackTransformer(
46 | num_tokens = 20000,
47 | dim = 512,
48 | depth = 6,
49 | seq_len = 32,
50 | mem_len = 256
51 | ).cuda()
52 |
53 | x1 = torch.randint(0, 20000, (2, 32)).cuda()
54 | x2 = torch.randint(0, 20000, (2, 32)).cuda()
55 | x3 = torch.randint(0, 20000, (2, 32)).cuda()
56 |
57 | out1, mem1 = model(x1, return_memory = True)
58 | out2, mem2 = model(x2, memory = mem1, return_memory = True)
59 | out3, mem3 = model(x3, memory = mem2, return_memory = True) # (2, 32, 20000)
60 | ```
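
Continuing the example above, you could, for instance, detach the returned memory before feeding it back in, so gradients stay confined to the current segment (a sketch; the returned memory is a `(keys, values)` namedtuple of tensors):

```python
out1, mem1 = model(x1, return_memory = True)

# stop gradients from flowing back into the previous segment
mem1 = (mem1.keys.detach(), mem1.values.detach())

out2, mem2 = model(x2, memory = mem1, return_memory = True)
```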
61 |
62 | ## Citations
63 |
64 | ```bibtex
65 | @misc{fan2021addressing,
66 | title = {Addressing Some Limitations of Transformers with Feedback Memory},
67 | author = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
68 | year = {2021},
69 | eprint = {2002.09402},
70 | archivePrefix = {arXiv},
71 | primaryClass = {cs.LG}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/feedback_transformer_pytorch/feedback_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import namedtuple
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 | # constants
10 |
11 | Memory = namedtuple('Memory', ['keys', 'values'])
12 |
13 | # helpers
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def default(val, d):
19 | return val if exists(val) else d
20 |
21 | def safe_cat(arr, el, dim = 1):
22 | if not exists(arr):
23 | return el
24 | return torch.cat((arr, el), dim = dim)
25 |
26 | # positional embedding
27 |
28 | class RelativePositionBias(nn.Module):
29 | def __init__(
30 | self,
31 | causal = False,
32 | num_buckets = 32,
33 | max_distance = 128,
34 | heads = 8
35 | ):
36 | super().__init__()
37 | self.causal = causal
38 | self.num_buckets = num_buckets
39 | self.max_distance = max_distance
40 | self.relative_attention_bias = nn.Embedding(num_buckets, heads)
41 |
42 | @staticmethod
43 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
44 | ret = 0
45 | n = -relative_position
46 | if not causal:
47 | num_buckets //= 2
48 | ret += (n < 0).long() * num_buckets
49 | n = torch.abs(n)
50 | else:
51 | n = torch.max(n, torch.zeros_like(n))
52 |
53 | max_exact = num_buckets // 2
54 | is_small = n < max_exact
55 |
56 | val_if_large = max_exact + (
57 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
58 | ).long()
59 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
60 |
61 | ret += torch.where(is_small, n, val_if_large)
62 | return ret
63 |
64 | def forward(self, qk_dots):
65 | i, j, device = *qk_dots.shape[-2:], qk_dots.device
66 | q_pos = torch.arange(i, dtype = torch.long, device = device)
67 | k_pos = torch.arange(j, dtype = torch.long, device = device)
68 | rel_pos = k_pos[None, :] - q_pos[:, None]
69 | rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
70 | values = self.relative_attention_bias(rp_bucket)
71 | bias = rearrange(values, 'i j h -> () h i j')
72 | return bias
73 |
74 | # helper classes
75 |
76 | class Residual(nn.Module):
77 | def __init__(self, fn):
78 | super().__init__()
79 | self.fn = fn
80 |
81 | def forward(self, x, **kwargs):
82 | return self.fn(x, **kwargs) + x
83 |
84 | class PreNorm(nn.Module):
85 | def __init__(self, dim, fn):
86 | super().__init__()
87 | self.fn = fn
88 | self.norm = nn.LayerNorm(dim)
89 |
90 | def forward(self, x, **kwargs):
91 | x = self.norm(x)
92 | return self.fn(x, **kwargs)
93 |
94 | class SkipIf(nn.Module):
95 | def __init__(self, cond, fn):
96 | super().__init__()
97 | self.cond = cond
98 | self.fn = fn
99 |
100 | def forward(self, x, *args, **kwargs):
101 | if self.cond(x, *args, **kwargs):
102 | return x
103 | return self.fn(x, *args, **kwargs)
104 |
105 | # feedforward
106 |
107 | class GEGLU(nn.Module):
108 | def forward(self, x):
109 | x, gate = x.chunk(2, dim = -1)
110 | return F.gelu(gate) * x
111 |
112 | class FeedForward(nn.Module):
113 | def __init__(
114 | self,
115 | *,
116 | dim,
117 | mult = 4,
118 | dropout = 0.
119 | ):
120 | super().__init__()
121 | self.net = nn.Sequential(
122 | nn.Linear(dim, dim * mult * 2),
123 | GEGLU(),
124 | nn.Dropout(dropout),
125 | nn.Linear(dim * mult, dim)
126 | )
127 |
128 | def forward(self, x):
129 | return self.net(x)
130 |
131 | # attention
132 |
133 | class Attention(nn.Module):
134 | def __init__(
135 | self,
136 | *,
137 | dim,
138 | heads = 8,
139 | dim_head = 64,
140 | dropout = 0.
141 | ):
142 | super().__init__()
143 | self.heads = heads
144 | self.scale = dim_head ** -0.5
145 |
146 | inner_dim = dim_head * heads
147 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
148 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
149 | self.to_out = nn.Linear(inner_dim, dim)
150 |
151 | self.dropout = nn.Dropout(dropout)
152 |
153 | def forward(self, x, memory, pos_emb = None):
154 | h, n, device = self.heads, x.shape[1], x.device
155 |
156 |         self_attend = n > 1 # only self-attend when processing more than one token at a time
157 |
158 | q = self.to_q(x) * self.scale
159 |
160 | k, v = memory if exists(memory) else (None, None)
161 |
162 | if self_attend:
163 | self_k, self_v = self.to_kv(x).chunk(2, dim = -1)
164 | k = safe_cat(k, self_k, dim = 1)
165 | v = safe_cat(v, self_v, dim = 1)
166 |
167 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
168 |
169 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
170 | i, j = sim.shape[-2:]
171 |
172 | if exists(pos_emb):
173 | sim = sim + pos_emb(sim)
174 |
175 | if self_attend:
176 | causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
177 | causal_mask = rearrange(causal_mask, 'i j -> () () i j')
178 | mask_value = -torch.finfo(q.dtype).max
179 | sim.masked_fill_(causal_mask, mask_value)
180 |
181 | attn = sim.softmax(dim = -1)
182 | attn = self.dropout(attn)
183 |
184 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
185 | out = rearrange(out, 'b h n d -> b n (h d)')
186 | return self.to_out(out)
187 |
188 | # main class
189 |
190 | class FeedbackTransformer(nn.Module):
191 | def __init__(
192 | self,
193 | *,
194 | num_tokens,
195 | dim,
196 | depth,
197 | mem_len,
198 | seq_len = 2,
199 | heads = 8,
200 | dim_head = 64,
201 | attn_dropout = 0.,
202 | ff_dropout = 0.,
203 | keep_last_hidden = False
204 | ):
205 | super().__init__()
206 | self.seq_len = seq_len
207 | self.mem_len = mem_len
208 |
209 | self.token_emb = nn.Embedding(num_tokens, dim)
210 | self.pos_emb = RelativePositionBias(causal = True, heads = heads)
211 |
212 | # main layers
213 |
214 | self.layers = nn.ModuleList([])
215 | shared_kv_proj = None
216 |
217 | for _ in range(depth):
218 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)
219 | ff = FeedForward(dim = dim, dropout = ff_dropout)
220 |
221 | shared_kv_proj = default(shared_kv_proj, attn.to_kv)
222 | attn.to_kv = shared_kv_proj
223 |
224 | attn, ff = map(lambda fn: Residual(PreNorm(dim, fn)), (attn, ff))
225 |
226 | if seq_len == 1:
227 | memory_is_empty = lambda *args, **kwargs: not exists(kwargs['memory'])
228 | attn = SkipIf(memory_is_empty, attn)
229 |
230 | self.layers.append(nn.ModuleList([
231 | attn,
232 | ff
233 | ]))
234 |
235 | # memory parameters
236 |
237 | self.layer_weight = nn.Parameter(torch.ones(depth + 1))
238 | self.shared_kv_proj = shared_kv_proj
239 | self.keep_last_hidden = keep_last_hidden
240 |
241 | # final projection to logits
242 |
243 | self.to_logits = nn.Sequential(
244 | nn.LayerNorm(dim),
245 | nn.Linear(dim, num_tokens)
246 | )
247 |
248 | def forward(self, x, memory = None, return_memory = False):
249 | b, n, device = *x.shape, x.device
250 |
251 | x = self.token_emb(x)
252 |
253 | memory_keys = None
254 | memory_values = None
255 |
256 | if exists(memory):
257 | memory_keys, memory_values = memory
258 |
259 | outputs = []
260 |
261 | # calculate weighting of layers for storing to memory
262 |
263 | layer_weight = self.layer_weight.softmax(dim = -1)
264 | layer_weight = rearrange(layer_weight, 'd -> d () () ()')
265 |
266 | for x in x.split(self.seq_len, dim = 1):
267 | hiddens = [x]
268 |
269 | # prepare memory for attention, if it exists
270 |
271 | memory = None
272 | if exists(memory_keys):
273 | memory = (memory_keys, memory_values)
274 |
275 | for attn, ff in self.layers:
276 |
277 | x = attn(x, memory = memory, pos_emb = self.pos_emb)
278 | x = ff(x)
279 |
280 | hiddens.append(x)
281 |
282 | outputs.append(x)
283 |
284 | # calculate new memory key / values and store to FIFO queue
285 |
286 | if self.keep_last_hidden: # secret option for only keeping last hidden layer, as in paper
287 | agg_hiddens = hiddens[-1]
288 | else:
289 | hiddens = torch.stack(hiddens)
290 | agg_hiddens = (hiddens * layer_weight).sum(dim = 0)
291 |
292 | # pre-calculate memory key / values and store to buffer
293 |
294 | mem_k, mem_v = self.shared_kv_proj(agg_hiddens).chunk(2, dim = -1)
295 | memory_keys = safe_cat(memory_keys, mem_k, dim = 1)
296 | memory_values = safe_cat(memory_values, mem_v, dim = 1)
297 |
298 | # enforce max length on memory buffer
299 |
300 | memory_keys = memory_keys[:, -self.mem_len:]
301 | memory_values = memory_values[:, -self.mem_len:]
302 |
303 |         x = torch.cat(outputs, dim = 1)
304 | out = self.to_logits(x)
305 |
306 | if not return_memory:
307 | return out
308 |
309 | return out, Memory(memory_keys, memory_values)
310 |
--------------------------------------------------------------------------------