├── feedback_transformer_pytorch
│   ├── __init__.py
│   └── feedback_transformer_pytorch.py
├── setup.py
├── .github
│   └── workflows
│       └── python-publish.yml
├── LICENSE
├── .gitignore
└── README.md

/feedback_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from feedback_transformer_pytorch.feedback_transformer_pytorch import FeedbackTransformer
2 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |   name = 'feedback-transformer-pytorch',
5 |   packages = find_packages(),
6 |   version = '0.0.11',
7 |   license='MIT',
8 |   description = 'Implementation of Feedback Transformer in Pytorch',
9 |   author = 'Phil Wang',
10 |   author_email = 'lucidrains@gmail.com',
11 |   url = 'https://github.com/lucidrains/feedback-transformer-pytorch',
12 |   keywords = [
13 |     'attention',
14 |     'artificial intelligence',
15 |     'transformer',
16 |     'deep learning',
17 |     'memory'
18 |   ],
19 |   install_requires=[
20 |     'torch>=1.6',
21 |     'einops'
22 |   ],
23 |   classifiers=[
24 |     'Development Status :: 4 - Beta',
25 |     'Intended Audience :: Developers',
26 |     'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 |     'License :: OSI Approved :: MIT License',
28 |     'Programming Language :: Python :: 3.6',
29 |   ],
30 | )
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 | 
4 | name: Upload Python Package
5 | 
6 | on:
7 |   release:
8 |     types: [created]
9 | 
10 | jobs:
11 |   deploy:
12 | 
13 |     runs-on: ubuntu-latest
14 | 
15 |     steps:
16 |     - uses: actions/checkout@v2
17 |     - name: Set up Python
18 |       uses: actions/setup-python@v2
19 |       with:
20 |         python-version: '3.x'
21 |     - name: Install dependencies
22 |       run: |
23 |         python -m pip install --upgrade pip
24 |         pip install setuptools wheel twine
25 |     - name: Build and publish
26 |       env:
27 |         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 |         TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 |       run: |
30 |         python setup.py sdist bdist_wheel
31 |         twine upload dist/*
32 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Feedback Transformer - Pytorch 2 | 3 | Simple implementation of Feedback Transformer in Pytorch. They improve on Transformer-XL by having each token have access to the representations of all previous layers through time. 
This is achieved by aggregating the outputs of all layers into a shared memory, which each token across layers can attend to at each time step.
4 | 
5 | The main drawback is longer training time, due to its non-parallel nature. But I thought I would build it anyway, to encourage further exploration and research into this line of work.
6 | 
7 | Yannic Kilcher video
8 | 
9 | I also took the liberty of adding various enhancements, including pre-normalization, GLU-gated feedforwards, and simplified T5 relative positional embeddings.
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install feedback-transformer-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | ```python
20 | import torch
21 | from feedback_transformer_pytorch import FeedbackTransformer
22 | 
23 | model = FeedbackTransformer(
24 |     num_tokens = 20000,      # number of tokens
25 |     dim = 512,               # dimension
26 |     depth = 6,               # depth
27 |     seq_len = 2,             # the sequence length of each segment or window
28 |     mem_len = 256,           # length of the memory buffer
29 |     dim_head = 64,           # dimension of each head
30 |     heads = 8,               # number of heads
31 |     attn_dropout = 0.1,      # attention dropout
32 |     ff_dropout = 0.1         # feedforward dropout
33 | ).cuda()
34 | 
35 | x = torch.randint(0, 20000, (2, 64)).cuda()
36 | model(x)  # (2, 64, 20000)
37 | ```
38 | 
39 | If you would like fine control over the memory (when to detach, etc.), you can do so with some extra keyword arguments on `.forward`; a sketch of detaching the memory between segments also follows the source listing at the end of this document.
40 | 
41 | ```python
42 | import torch
43 | from feedback_transformer_pytorch import FeedbackTransformer
44 | 
45 | model = FeedbackTransformer(
46 |     num_tokens = 20000,
47 |     dim = 512,
48 |     depth = 6,
49 |     seq_len = 32,
50 |     mem_len = 256
51 | ).cuda()
52 | 
53 | x1 = torch.randint(0, 20000, (2, 32)).cuda()
54 | x2 = torch.randint(0, 20000, (2, 32)).cuda()
55 | x3 = torch.randint(0, 20000, (2, 32)).cuda()
56 | 
57 | out1, mem1 = model(x1, return_memory = True)
58 | out2, mem2 = model(x2, memory = mem1, return_memory = True)
59 | out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
60 | ```
61 | 
62 | ## Citations
63 | 
64 | ```bibtex
65 | @misc{fan2021addressing,
66 |     title = {Addressing Some Limitations of Transformers with Feedback Memory},
67 |     author = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
68 |     year = {2021},
69 |     eprint = {2002.09402},
70 |     archivePrefix = {arXiv},
71 |     primaryClass = {cs.LG}
72 | }
73 | ```
74 | 
--------------------------------------------------------------------------------
/feedback_transformer_pytorch/feedback_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import namedtuple
3 | 
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | 
9 | # constants
10 | 
11 | Memory = namedtuple('Memory', ['keys', 'values'])
12 | 
13 | # helpers
14 | 
15 | def exists(val):
16 |     return val is not None
17 | 
18 | def default(val, d):
19 |     return val if exists(val) else d
20 | 
21 | def safe_cat(arr, el, dim = 1):
22 |     if not exists(arr):
23 |         return el
24 |     return torch.cat((arr, el), dim = dim)
25 | 
26 | # positional embedding
27 | 
28 | class RelativePositionBias(nn.Module):
29 |     def __init__(
30 |         self,
31 |         causal = False,
32 |         num_buckets = 32,
33 |         max_distance = 128,
34 |         heads = 8
35 |     ):
36 |         super().__init__()
37 |         self.causal = causal
38 |         self.num_buckets = num_buckets
39 |         self.max_distance = max_distance
40 |         self.relative_attention_bias = 
nn.Embedding(num_buckets, heads) 41 | 42 | @staticmethod 43 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): 44 | ret = 0 45 | n = -relative_position 46 | if not causal: 47 | num_buckets //= 2 48 | ret += (n < 0).long() * num_buckets 49 | n = torch.abs(n) 50 | else: 51 | n = torch.max(n, torch.zeros_like(n)) 52 | 53 | max_exact = num_buckets // 2 54 | is_small = n < max_exact 55 | 56 | val_if_large = max_exact + ( 57 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 58 | ).long() 59 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 60 | 61 | ret += torch.where(is_small, n, val_if_large) 62 | return ret 63 | 64 | def forward(self, qk_dots): 65 | i, j, device = *qk_dots.shape[-2:], qk_dots.device 66 | q_pos = torch.arange(i, dtype = torch.long, device = device) 67 | k_pos = torch.arange(j, dtype = torch.long, device = device) 68 | rel_pos = k_pos[None, :] - q_pos[:, None] 69 | rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) 70 | values = self.relative_attention_bias(rp_bucket) 71 | bias = rearrange(values, 'i j h -> () h i j') 72 | return bias 73 | 74 | # helper classes 75 | 76 | class Residual(nn.Module): 77 | def __init__(self, fn): 78 | super().__init__() 79 | self.fn = fn 80 | 81 | def forward(self, x, **kwargs): 82 | return self.fn(x, **kwargs) + x 83 | 84 | class PreNorm(nn.Module): 85 | def __init__(self, dim, fn): 86 | super().__init__() 87 | self.fn = fn 88 | self.norm = nn.LayerNorm(dim) 89 | 90 | def forward(self, x, **kwargs): 91 | x = self.norm(x) 92 | return self.fn(x, **kwargs) 93 | 94 | class SkipIf(nn.Module): 95 | def __init__(self, cond, fn): 96 | super().__init__() 97 | self.cond = cond 98 | self.fn = fn 99 | 100 | def forward(self, x, *args, **kwargs): 101 | if self.cond(x, *args, **kwargs): 102 | return x 103 | return self.fn(x, *args, **kwargs) 104 | 105 | # feedforward 106 | 107 | class GEGLU(nn.Module): 108 | def forward(self, x): 109 | x, gate = x.chunk(2, dim = -1) 110 | return F.gelu(gate) * x 111 | 112 | class FeedForward(nn.Module): 113 | def __init__( 114 | self, 115 | *, 116 | dim, 117 | mult = 4, 118 | dropout = 0. 119 | ): 120 | super().__init__() 121 | self.net = nn.Sequential( 122 | nn.Linear(dim, dim * mult * 2), 123 | GEGLU(), 124 | nn.Dropout(dropout), 125 | nn.Linear(dim * mult, dim) 126 | ) 127 | 128 | def forward(self, x): 129 | return self.net(x) 130 | 131 | # attention 132 | 133 | class Attention(nn.Module): 134 | def __init__( 135 | self, 136 | *, 137 | dim, 138 | heads = 8, 139 | dim_head = 64, 140 | dropout = 0. 
141 | ): 142 | super().__init__() 143 | self.heads = heads 144 | self.scale = dim_head ** -0.5 145 | 146 | inner_dim = dim_head * heads 147 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 148 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 149 | self.to_out = nn.Linear(inner_dim, dim) 150 | 151 | self.dropout = nn.Dropout(dropout) 152 | 153 | def forward(self, x, memory, pos_emb = None): 154 | h, n, device = self.heads, x.shape[1], x.device 155 | 156 | self_attend = n > 1 # only self attend if going at greater than 1 token at a time 157 | 158 | q = self.to_q(x) * self.scale 159 | 160 | k, v = memory if exists(memory) else (None, None) 161 | 162 | if self_attend: 163 | self_k, self_v = self.to_kv(x).chunk(2, dim = -1) 164 | k = safe_cat(k, self_k, dim = 1) 165 | v = safe_cat(v, self_v, dim = 1) 166 | 167 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 168 | 169 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 170 | i, j = sim.shape[-2:] 171 | 172 | if exists(pos_emb): 173 | sim = sim + pos_emb(sim) 174 | 175 | if self_attend: 176 | causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool() 177 | causal_mask = rearrange(causal_mask, 'i j -> () () i j') 178 | mask_value = -torch.finfo(q.dtype).max 179 | sim.masked_fill_(causal_mask, mask_value) 180 | 181 | attn = sim.softmax(dim = -1) 182 | attn = self.dropout(attn) 183 | 184 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 185 | out = rearrange(out, 'b h n d -> b n (h d)') 186 | return self.to_out(out) 187 | 188 | # main class 189 | 190 | class FeedbackTransformer(nn.Module): 191 | def __init__( 192 | self, 193 | *, 194 | num_tokens, 195 | dim, 196 | depth, 197 | mem_len, 198 | seq_len = 2, 199 | heads = 8, 200 | dim_head = 64, 201 | attn_dropout = 0., 202 | ff_dropout = 0., 203 | keep_last_hidden = False 204 | ): 205 | super().__init__() 206 | self.seq_len = seq_len 207 | self.mem_len = mem_len 208 | 209 | self.token_emb = nn.Embedding(num_tokens, dim) 210 | self.pos_emb = RelativePositionBias(causal = True, heads = heads) 211 | 212 | # main layers 213 | 214 | self.layers = nn.ModuleList([]) 215 | shared_kv_proj = None 216 | 217 | for _ in range(depth): 218 | attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = attn_dropout) 219 | ff = FeedForward(dim = dim, dropout = ff_dropout) 220 | 221 | shared_kv_proj = default(shared_kv_proj, attn.to_kv) 222 | attn.to_kv = shared_kv_proj 223 | 224 | attn, ff = map(lambda fn: Residual(PreNorm(dim, fn)), (attn, ff)) 225 | 226 | if seq_len == 1: 227 | memory_is_empty = lambda *args, **kwargs: not exists(kwargs['memory']) 228 | attn = SkipIf(memory_is_empty, attn) 229 | 230 | self.layers.append(nn.ModuleList([ 231 | attn, 232 | ff 233 | ])) 234 | 235 | # memory parameters 236 | 237 | self.layer_weight = nn.Parameter(torch.ones(depth + 1)) 238 | self.shared_kv_proj = shared_kv_proj 239 | self.keep_last_hidden = keep_last_hidden 240 | 241 | # final projection to logits 242 | 243 | self.to_logits = nn.Sequential( 244 | nn.LayerNorm(dim), 245 | nn.Linear(dim, num_tokens) 246 | ) 247 | 248 | def forward(self, x, memory = None, return_memory = False): 249 | b, n, device = *x.shape, x.device 250 | 251 | x = self.token_emb(x) 252 | 253 | memory_keys = None 254 | memory_values = None 255 | 256 | if exists(memory): 257 | memory_keys, memory_values = memory 258 | 259 | outputs = [] 260 | 261 | # calculate weighting of layers for storing to memory 262 | 263 | layer_weight = self.layer_weight.softmax(dim = -1) 264 | 
layer_weight = rearrange(layer_weight, 'd -> d () () ()') 265 | 266 | for x in x.split(self.seq_len, dim = 1): 267 | hiddens = [x] 268 | 269 | # prepare memory for attention, if it exists 270 | 271 | memory = None 272 | if exists(memory_keys): 273 | memory = (memory_keys, memory_values) 274 | 275 | for attn, ff in self.layers: 276 | 277 | x = attn(x, memory = memory, pos_emb = self.pos_emb) 278 | x = ff(x) 279 | 280 | hiddens.append(x) 281 | 282 | outputs.append(x) 283 | 284 | # calculate new memory key / values and store to FIFO queue 285 | 286 | if self.keep_last_hidden: # secret option for only keeping last hidden layer, as in paper 287 | agg_hiddens = hiddens[-1] 288 | else: 289 | hiddens = torch.stack(hiddens) 290 | agg_hiddens = (hiddens * layer_weight).sum(dim = 0) 291 | 292 | # pre-calculate memory key / values and store to buffer 293 | 294 | mem_k, mem_v = self.shared_kv_proj(agg_hiddens).chunk(2, dim = -1) 295 | memory_keys = safe_cat(memory_keys, mem_k, dim = 1) 296 | memory_values = safe_cat(memory_values, mem_v, dim = 1) 297 | 298 | # enforce max length on memory buffer 299 | 300 | memory_keys = memory_keys[:, -self.mem_len:] 301 | memory_values = memory_values[:, -self.mem_len:] 302 | 303 | x = torch.cat((outputs), dim = 1) 304 | out = self.to_logits(x) 305 | 306 | if not return_memory: 307 | return out 308 | 309 | return out, Memory(memory_keys, memory_values) 310 | --------------------------------------------------------------------------------
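The README above mentions fine control over when to detach the memory. Below is a minimal sketch of one way to do that, assuming you want truncated backpropagation through time across segments: the returned memory keys and values are detached before being fed back in. Passing a plain `(keys, values)` tuple in place of the `Memory` namedtuple is an assumption that relies on `.forward` simply unpacking `memory` into keys and values, as the source above does.

```python
import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
)

# three consecutive segments of the same batch of sequences
segments = [torch.randint(0, 20000, (2, 32)) for _ in range(3)]

memory = None
for segment in segments:
    logits, memory = model(segment, memory = memory, return_memory = True)

    # compute a loss on `logits` and call .backward() here, then detach the
    # memory key / value tensors so gradients stop at the segment boundary
    memory = (memory.keys.detach(), memory.values.detach())
```

Note that `.forward` itself never detaches the memory, so skipping the detach keeps the autograd graph growing across all segments.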