├── .github └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── infini-attention.png ├── infini_transformer_pytorch ├── __init__.py ├── infini_transformer.py └── wrapper.py ├── pyproject.toml ├── tests └── test_readme.py └── train.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests the examples in README 2 | on: push 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | - name: Install Python 10 | uses: actions/setup-python@v4 11 | - name: Install the latest version of rye 12 | uses: eifinger/setup-rye@v2 13 | - name: Use UV instead of pip 14 | run: rye config --set-bool behavior.use-uv=true 15 | - name: Install dependencies 16 | run: | 17 | rye sync 18 | - name: Run pytest 19 | run: rye run pytest tests/test_readme.py 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/astral-sh/ruff-pre-commit 4 | rev: v0.0.278 5 | hooks: 6 | - id: ruff 7 | args: [ --fix, --exit-non-zero-on-fix] 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Infini-Transformer - Pytorch 4 | 5 | Implementation of Infini-Transformer in Pytorch. They use a linear attention scheme to compress past memories and demonstrate multiple SOTAs for long context benchmarks. 6 | 7 | Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal. 8 | 9 | Yannic Kilcher's explanation 10 | 11 | ## Install 12 | 13 | ```bash 14 | $ pip install infini-transformer-pytorch 15 | ``` 16 | 17 | ## Usage 18 | 19 | ```python 20 | import torch 21 | from infini_transformer_pytorch import InfiniTransformer 22 | 23 | transformer = InfiniTransformer( 24 | num_tokens = 256, 25 | dim = 512, 26 | depth = 8, 27 | dim_head = 128, # high head dimension may be part of the reason they got good results (kv has high capacity) 28 | heads = 8, 29 | use_mem_delta_rule = True 30 | ) 31 | 32 | x = torch.randint(0, 256, (1, 1024)) 33 | 34 | logits1, _, mem1 = transformer(x, return_new_memories = False) 35 | logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False) 36 | logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True) 37 | 38 | ``` 39 | 40 | Training a transformer with recurrence usually trips up a lot of researchers, so to make it easy, just wrap it with `InfiniTransformerWrapper` 41 | 42 | ```python 43 | import torch 44 | 45 | from infini_transformer_pytorch import ( 46 | InfiniTransformer, 47 | InfiniTransformerWrapper 48 | ) 49 | 50 | # model and wrapper 51 | 52 | model = InfiniTransformer( 53 | num_tokens = 256, 54 | dim = 512, 55 | depth = 8, 56 | dim_head = 128, 57 | heads = 8, 58 | use_mem_delta_rule = True 59 | ) 60 | 61 | wrapper = InfiniTransformerWrapper( 62 | model, 63 | segment_length = 512, 64 | detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories 65 | ).cuda() 66 | 67 | # mock input 68 | 69 | seq = torch.randint(0, 256, (2, 10000)).cuda() # can be arbitrarily long sequence 70 | 71 | # training 72 | 73 | loss = wrapper( 74 | seq, 75 | backward = True # will automatically segment and accumulate gradients when it detaches the memories 76 | ) 77 | 78 | # after much data... 79 | 80 | # calculating eval loss 81 | 82 | with torch.no_grad(): 83 | wrapper.eval() 84 | eval_loss = wrapper(seq) 85 | 86 | # generating is as easy as 87 | 88 | output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1]) 89 | 90 | output.shape # (2, 8192 - 1) 91 | ``` 92 | 93 | ## Testing 94 | 95 | Train an autoregressive enwik8 96 | 97 | ```bash 98 | $ python train.py 99 | ``` 100 | 101 | ## Todo 102 | 103 | - [ ] `detach_mems_every_num_segments` hyperparameter is too confusing, get rid of it 104 | - [ ] experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv or linear projection on k, v separately) before sending the memories to the layer before 105 | - [x] working example with enwik8 106 | 107 | ## Citations 108 | 109 | ```bibtex 110 | @inproceedings{Munkhdalai2024LeaveNC, 111 | title = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention}, 112 | author = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal}, 113 | year = {2024}, 114 | url = {https://api.semanticscholar.org/CorpusID:269033427} 115 | } 116 | ``` 117 | 118 | ```bibtex 119 | @article{Yang2024ParallelizingLT, 120 | title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, 121 | author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim}, 122 | journal = {ArXiv}, 123 | year = {2024}, 124 | volume = {abs/2406.06484}, 125 | url = {https://api.semanticscholar.org/CorpusID:270371554} 126 | } 127 | ``` 128 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/infini-transformer-pytorch/d5a2d89ef8da9340bae0761d18630d6fc0e90460/data/enwik8.gz -------------------------------------------------------------------------------- /infini-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/infini-transformer-pytorch/d5a2d89ef8da9340bae0761d18630d6fc0e90460/infini-attention.png -------------------------------------------------------------------------------- /infini_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from infini_transformer_pytorch.infini_transformer import ( 2 | InfiniTransformer, 3 | FastweightMemory, 4 | detach_memories_, 5 | detach_cached_kv_ 6 | ) 7 | 8 | from infini_transformer_pytorch.wrapper import ( 9 | InfiniTransformerWrapper 10 | ) 11 | 12 | __all__ = [ 13 | InfiniTransformer, 14 | FastweightMemory, 15 | InfiniTransformerWrapper, 16 | detach_memories_, 17 | detach_cached_kv_ 18 | ] 19 | -------------------------------------------------------------------------------- /infini_transformer_pytorch/infini_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Tuple, List, NamedTuple 3 | 4 | import torch 5 | from torch import nn, Tensor 6 | import torch.nn.functional as F 7 | from torch.nn import Module, ModuleList 8 | 9 | from einops import einsum, rearrange, reduce 10 | from einops.layers.torch import Rearrange 11 | 12 | from rotary_embedding_torch import RotaryEmbedding 13 | 14 | # constants 15 | 16 | class Memories(NamedTuple): 17 | kv_mem: Tensor 18 | k_norm: Tensor 19 | 20 | class TransformerReturn(NamedTuple): 21 | logits: Tensor 22 | cached_kvs: List[Tensor] | None 23 | past_memories: List[Memories] | None 24 | 25 | # helpers 26 | 27 | def exists(v): 28 | return v is not None 29 | 30 | def default(v, d): 31 | return v if exists(v) else d 32 | 33 | def detach_memories_(memories: List[Memories]): 34 | for (mem_kv, mem_norm) in memories: 35 | mem_kv.detach_() 36 | mem_norm.detach_() 37 | 38 | def detach_cached_kv_(cached_kvs: List[Tensor]): 39 | for cached_kv in cached_kvs: 40 | cached_kv.detach_() 41 | 42 | # classes 43 | 44 | class RMSNorm(Module): 45 | def __init__(self, dim): 46 | super().__init__() 47 | self.scale = dim ** 0.5 48 | self.gamma = nn.Parameter(torch.ones(dim)) 49 | 50 | def forward(self, x): 51 | return F.normalize(x, dim = -1) * self.scale * self.gamma 52 | 53 | class FeedForward(Module): 54 | def __init__( 55 | self, 56 | dim, 57 | mult = 4, 58 | dropout = 0. 59 | ): 60 | super().__init__() 61 | dim_inner = int(mult * dim * 2 / 3) 62 | 63 | self.norm = RMSNorm(dim) 64 | self.proj_in = nn.Linear(dim, dim_inner * 2) 65 | self.proj_out = nn.Linear(dim_inner, dim) 66 | 67 | self.dropout = nn.Dropout(dropout) 68 | 69 | def forward(self, x): 70 | x = self.norm(x) 71 | x, gates = self.proj_in(x).chunk(2, dim = -1) 72 | x = F.gelu(gates) * x 73 | x = self.dropout(x) 74 | return self.proj_out(x) 75 | 76 | # fastweight memory 77 | 78 | def retrieve_from_kv_memories(t, past_memories: Memories, eps = 1e-10): 79 | past_memories_kv, past_memories_norm = past_memories 80 | 81 | numer = einsum(t, past_memories_kv, 'b h n dk, b h dk dv -> b h n dv') 82 | denom = einsum(t, past_memories_norm, 'b h n d, b h d -> b h n') 83 | 84 | denom = rearrange(denom, '... -> ... 1') 85 | return numer / denom.clamp(min = eps) # eq (3) 86 | 87 | class FastweightMemory(Module): 88 | def __init__( 89 | self, 90 | heads: int, 91 | head_gate_init_value = 10., 92 | use_mem_delta_rule = False, 93 | ): 94 | super().__init__() 95 | self.use_mem_delta_rule = use_mem_delta_rule 96 | self.head_gates = nn.Parameter(torch.ones(heads) * head_gate_init_value) 97 | 98 | def create_new_memories( 99 | self, 100 | keys: Tensor, 101 | values: Tensor, 102 | past_memories: Memories | None = None, 103 | weights: Tensor | None = None 104 | ) -> Memories: 105 | 106 | # Katharopoulos linear attention activation 107 | 108 | keys = F.elu(keys) + 1 109 | 110 | # create the next memories 111 | 112 | if exists(past_memories) and self.use_mem_delta_rule: 113 | delta_v = retrieve_from_kv_memories(keys, past_memories) 114 | 115 | # eq (5) - the delta rule 116 | 117 | if exists(weights): 118 | # `weights` correspond to beta from deltanet, controlling how much of a given value is stored to fastweight memories 119 | values = values.lerp(delta_v, weights) 120 | 121 | diff_values = values - delta_v 122 | else: 123 | 124 | if exists(weights): 125 | values = values * (1. - weights) 126 | 127 | diff_values = values 128 | 129 | new_memories_kv = einsum(keys, diff_values, '... n dk, ... n dv -> ... dk dv') 130 | new_memories_norm = reduce(keys, 'b h n d -> b h d', 'sum') 131 | 132 | if exists(past_memories): 133 | past_memories_kv, past_memories_norm = past_memories 134 | 135 | new_memories_kv = new_memories_kv + past_memories_kv # eq (4) 136 | new_memories_norm = new_memories_norm + past_memories_norm # eq (4) 137 | 138 | return Memories(new_memories_kv, new_memories_norm) 139 | 140 | def retrieve_and_add_to_output( 141 | self, 142 | out: Tensor, 143 | queries: Tensor, 144 | past_memories: Memories | None = None 145 | ) -> Tensor: 146 | 147 | if not exists(past_memories): 148 | return out 149 | 150 | # the main contribution of the paper 151 | # Katharopoulos linear attention to kv memory of shape (batch, heads, dim keys, dim values) 152 | # it makes sense the author would try this, as he is ex-shmidhuber lab (linear transformers are fast weights paper) 153 | 154 | queries = F.elu(queries) + 1 155 | 156 | # retrieve from past memories 157 | 158 | mem_out = retrieve_from_kv_memories(queries, past_memories) 159 | 160 | # combine the current timestep output of queries with the outputs querying the past 'compressed' key/value memories 161 | # in paper, they use a sigmoid gating scheme with learned gate per head 162 | 163 | gates = rearrange(self.head_gates, 'h -> h 1 1') 164 | gates = gates.sigmoid() 165 | 166 | out = out * gates + mem_out * (1. - gates) # eq (6) - figure 3 shows how heads emergently specialize to look either at the present, past, or a bit of both 167 | 168 | return out 169 | 170 | # attention 171 | 172 | class CausalAttention(Module): 173 | def __init__( 174 | self, 175 | dim, 176 | *, 177 | dim_head = 128, 178 | heads = 8, 179 | dropout = 0., 180 | head_gate_init_value = 10., 181 | use_mem_delta_rule = False, 182 | learned_delta_update = False 183 | ): 184 | super().__init__() 185 | dim_inner = dim_head * heads 186 | self.scale = dim_head ** -0.5 187 | self.norm = RMSNorm(dim) 188 | 189 | self.rotary_emb = RotaryEmbedding(dim_head) 190 | 191 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False) 192 | self.to_out = nn.Linear(dim_inner, dim, bias = False) 193 | 194 | self.dropout = nn.Dropout(dropout) 195 | 196 | self.split_heads = Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads) 197 | self.merge_heads = Rearrange('b h n d -> b n (h d)') 198 | 199 | # this corresponds to the learned beta in Deltanet from Yang et al. https://arxiv.org/abs/2406.06484 200 | 201 | self.to_learned_delta_update_weights = nn.Sequential( 202 | nn.Linear(dim, heads, bias = False), 203 | Rearrange('b n h -> b h n 1'), 204 | nn.Sigmoid() 205 | ) if learned_delta_update else None 206 | 207 | self.fastweight_mem = FastweightMemory( 208 | heads = heads, 209 | head_gate_init_value = head_gate_init_value, 210 | use_mem_delta_rule = use_mem_delta_rule, 211 | ) 212 | 213 | def forward( 214 | self, 215 | x, 216 | cached_kv: Tensor | None = None, 217 | past_memories: Memories | None = None, 218 | return_new_memories = False, 219 | eps = 1e-10 220 | ) -> Tuple[Tensor, Tensor, Memories]: 221 | """ 222 | ein notation: 223 | 224 | b - batch 225 | h - heads 226 | n - sequence 227 | i - source sequence (q) 228 | j - target sequence (kv) 229 | d - feature dimension 230 | dk - feature dimension keys (and queries) 231 | dv - feature dimension of values 232 | """ 233 | 234 | x = self.norm(x) 235 | 236 | qkv = self.to_qkv(x) 237 | q, k, v = self.split_heads(qkv) 238 | 239 | # handle cached key / values 240 | 241 | if exists(cached_kv): 242 | cached_k, cached_v = cached_kv 243 | k = torch.cat((cached_k, k), dim = -2) 244 | v = torch.cat((cached_v, v), dim = -2) 245 | 246 | # similarity 247 | 248 | q_scaled = q * self.scale 249 | q_rotated, k_rotated = self.rotary_emb.rotate_queries_with_cached_keys(q_scaled, k) 250 | 251 | sim = einsum(q_rotated, k_rotated, '... i d, ... j d -> ... i j') 252 | 253 | # causal mask 254 | 255 | i, j = sim.shape[-2:] 256 | causal_mask = torch.ones((i, j), device = sim.device, dtype = torch.bool).triu(j - i + 1) 257 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 258 | 259 | # attend 260 | 261 | attn = sim.softmax(dim = -1) 262 | 263 | # dropout 264 | 265 | attn = self.dropout(attn) 266 | 267 | # aggregate values 268 | 269 | out = einsum(attn, v, '... i j, ... j d -> ... i d') 270 | 271 | out = self.fastweight_mem.retrieve_and_add_to_output(out, q, past_memories) 272 | 273 | # merge heads and combine 274 | 275 | out = self.merge_heads(out) 276 | out = self.to_out(out) 277 | 278 | # if new memories are not needed, early return 279 | # at inference time, kv cache up to segment length and then compress memories into kv 280 | 281 | if not return_new_memories: 282 | cached_kv = torch.stack((k, v)) 283 | 284 | return out, cached_kv, past_memories 285 | 286 | # having the network learn the strength at which to apply the delta update rule 287 | # learnt per token / head 288 | 289 | delta_update_weights = None 290 | 291 | if exists(self.to_learned_delta_update_weights): 292 | delta_update_weights = self.to_learned_delta_update_weights(x) 293 | 294 | new_memories = self.fastweight_mem.create_new_memories(k, v, past_memories, delta_update_weights) 295 | 296 | return out, None, new_memories 297 | 298 | # main class 299 | 300 | class InfiniTransformer(Module): 301 | def __init__( 302 | self, 303 | *, 304 | num_tokens, 305 | dim, 306 | depth, 307 | dim_head = 128, 308 | heads = 8, 309 | attn_dropout = 0., 310 | ff_mult = 4, 311 | ff_dropout = 0., 312 | use_mem_delta_rule = False, # in the paper, the delta rule didn't seem to do that much, but will include for completeness 313 | learned_delta_update = False, # whether to use learned delta rule 314 | ): 315 | super().__init__() 316 | 317 | self.token_emb = nn.Embedding(num_tokens, dim) 318 | 319 | self.layers = ModuleList([]) 320 | 321 | for _ in range(depth): 322 | 323 | attn = CausalAttention( 324 | dim = dim, 325 | dim_head = dim_head, 326 | heads = heads, 327 | use_mem_delta_rule = use_mem_delta_rule, 328 | learned_delta_update = learned_delta_update, 329 | dropout = attn_dropout 330 | ) 331 | 332 | ff = FeedForward( 333 | dim = dim, 334 | mult = ff_mult, 335 | dropout = ff_dropout 336 | ) 337 | 338 | self.layers.append(ModuleList([attn, ff])) 339 | 340 | self.norm = RMSNorm(dim) 341 | self.to_logits = nn.Linear(dim, num_tokens) 342 | 343 | def forward( 344 | self, 345 | x, 346 | past_memories: List[Memories] | None = None, 347 | cached_kv: List[Tensor] | None = None, 348 | return_new_memories = False, 349 | detach_memories = False 350 | ) -> TransformerReturn: 351 | 352 | x = self.token_emb(x) 353 | 354 | # handle cached key values 355 | 356 | if exists(cached_kv): 357 | x = x[:, -1:] 358 | 359 | new_cached_kv = [] 360 | cached_kv_iter = iter(default(cached_kv, [])) 361 | 362 | # iterator for past compressed memories 363 | 364 | new_memories = [] 365 | past_memories_iter = iter(default(past_memories, [])) 366 | 367 | # going through layers of infini-transformer 368 | 369 | for attn, ff in self.layers: 370 | 371 | attn_out, layer_cached_kv, layer_new_memories = attn( 372 | x, 373 | cached_kv = next(cached_kv_iter, None), 374 | past_memories = next(past_memories_iter, None), 375 | return_new_memories = return_new_memories 376 | ) 377 | 378 | x = attn_out + x 379 | x = ff(x) + x 380 | 381 | new_cached_kv.append(layer_cached_kv) 382 | new_memories.append(layer_new_memories) 383 | 384 | # final norm 385 | 386 | embed = self.norm(x) 387 | 388 | # logits 389 | 390 | logits = self.to_logits(embed) 391 | 392 | if detach_memories: 393 | detach_cached_kv_(new_cached_kv) 394 | 395 | if not return_new_memories: 396 | return TransformerReturn(logits, new_cached_kv, past_memories) 397 | 398 | if detach_memories: 399 | detach_memories_(new_memories) 400 | 401 | return TransformerReturn(logits, None, new_memories) 402 | -------------------------------------------------------------------------------- /infini_transformer_pytorch/wrapper.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from math import ceil 3 | from typing import Callable 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import Module 8 | import torch.nn.functional as F 9 | 10 | from einops import pack, rearrange 11 | from tqdm import tqdm 12 | 13 | from infini_transformer_pytorch.infini_transformer import ( 14 | InfiniTransformer, 15 | detach_memories_ 16 | ) 17 | 18 | # helper functions 19 | 20 | def exists(v): 21 | return v is not None 22 | 23 | def default(v, d): 24 | return v if exists(v) else d 25 | 26 | def divisible_by(num, den): 27 | return (num % den) == 0 28 | 29 | def is_empty(t: Tensor): 30 | return t.numel() == 0 31 | 32 | def round_down_multiple(n, mult): 33 | return n // mult * mult 34 | 35 | # sampling helpers 36 | 37 | def log(t, eps = 1e-20): 38 | return torch.log(t.clamp(min = eps)) 39 | 40 | def gumbel_noise(t): 41 | noise = torch.zeros_like(t).uniform_(0, 1) 42 | return -log(-log(noise)) 43 | 44 | def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10): 45 | return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim) 46 | 47 | # nucleus 48 | 49 | def top_p(logits, thres = 0.9): 50 | sorted_logits, sorted_indices = torch.sort(logits, descending = True) 51 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1) 52 | 53 | sorted_indices_to_remove = cum_probs > thres 54 | sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False) 55 | 56 | sorted_logits[sorted_indices_to_remove] = float('-inf') 57 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 58 | 59 | # topk 60 | 61 | def top_k(logits, frac_num_tokens = 0.1, k: int | None = None): 62 | num_tokens = logits.shape[-1] 63 | 64 | k = default(k, ceil(frac_num_tokens * num_tokens)) 65 | k = min(k, num_tokens) 66 | 67 | val, ind = torch.topk(logits, k) 68 | probs = torch.full_like(logits, float('-inf')) 69 | probs.scatter_(1, ind, val) 70 | return probs 71 | 72 | # class 73 | 74 | class InfiniTransformerWrapper(Module): 75 | def __init__( 76 | self, 77 | model: InfiniTransformer, 78 | segment_length = 512, 79 | detach_mems_every_num_segments = 2, 80 | ignore_index = -1 81 | ): 82 | super().__init__() 83 | self.model = model 84 | 85 | self.segment_length = segment_length 86 | self.detach_mems_every_num_segments = detach_mems_every_num_segments 87 | 88 | # loss related 89 | 90 | self.ignore_index = ignore_index 91 | 92 | @property 93 | def device(self): 94 | return next(self.model.parameters()).device 95 | 96 | @torch.no_grad() 97 | def generate( 98 | self, 99 | *, 100 | seq_len, 101 | prompt = None, 102 | batch_size = 1, 103 | temperature = 1., 104 | filter_fn: Callable = top_p, 105 | filter_kwargs: dict = dict(thres = 0.9), 106 | exclude_prompt = True, 107 | segment_length = None 108 | ): 109 | segment_length = default(segment_length, self.segment_length) 110 | device, train_state = self.device, self.training 111 | self.eval() 112 | 113 | out = default(prompt, torch.empty((batch_size, 0), device = device, dtype = torch.long)) 114 | init_len = out.shape[-1] 115 | 116 | # sample from the model token by token 117 | # keeping track of kv cache and when to compress into new memories 118 | 119 | cached_kv = None 120 | past_memories = None 121 | 122 | for curr_len in tqdm(range(init_len, seq_len)): 123 | 124 | # what is fed into the model is always at the start of the very last segment 125 | 126 | start_ind = round_down_multiple(curr_len - 1, segment_length) 127 | model_input = out[:, start_ind:] 128 | 129 | # forward the model with cached key / values and past memories 130 | 131 | logits, cached_kv, past_memories = self.model( 132 | model_input, 133 | cached_kv = cached_kv, 134 | past_memories = past_memories, 135 | return_new_memories = divisible_by(curr_len, segment_length) 136 | ) 137 | 138 | # grab the last logit 139 | 140 | logits = logits[:, -1] 141 | 142 | # filter by either topk or nucleus 143 | # and sample 144 | 145 | filtered_logits = filter_fn(logits, **filter_kwargs) 146 | sampled = gumbel_sample(filtered_logits, temperature = temperature) 147 | 148 | # concat sampled token 149 | 150 | out, _ = pack((out, sampled), 'b *') 151 | 152 | # return output 153 | 154 | if exclude_prompt: 155 | out = out[:, init_len:] 156 | 157 | self.train(train_state) 158 | return out 159 | 160 | def forward( 161 | self, 162 | seq, 163 | segment_length = None, 164 | backward = False, 165 | grad_accum_scale = 1. 166 | ): 167 | segment_length = default(segment_length, self.segment_length) 168 | 169 | seq, label = seq[:, :-1], seq[:, 1:] 170 | 171 | # put into train mode if doing backwards within forward call 172 | 173 | if backward: 174 | self.model.train() 175 | 176 | total_tokens = (label != self.ignore_index).sum().item() 177 | 178 | # split the sequence by segment length 179 | 180 | split_seq = seq.split(segment_length, dim = -1) 181 | split_label = label.split(segment_length, dim = -1) 182 | 183 | num_segments = len(split_seq) 184 | 185 | # go over each segment length and calculate cross entropy loss 186 | 187 | total_loss = 0. 188 | past_memories = None 189 | 190 | running_loss = 0. 191 | 192 | for ind, (segment_seq, segment_label) in enumerate(zip(split_seq, split_label)): 193 | segment_num = ind + 1 194 | is_last = segment_num == num_segments 195 | 196 | should_detach_memories = divisible_by(segment_num, self.detach_mems_every_num_segments) 197 | should_backward = backward and (is_last or should_detach_memories) 198 | 199 | # model forwards for logits and past memories 200 | 201 | logits, _, past_memories = self.model( 202 | segment_seq, 203 | past_memories = past_memories, 204 | return_new_memories = True 205 | ) 206 | 207 | # calculate cross entropy loss for segment 208 | 209 | segment_loss = F.cross_entropy( 210 | rearrange(logits, 'b n c -> b c n'), 211 | segment_label, 212 | reduction = 'none' 213 | ) 214 | 215 | # make sure segment losses do not include ignored index 216 | # then also make sure the segment loss is scaled 217 | 218 | segment_mask = segment_label != self.ignore_index 219 | num_segment_tokens = segment_mask.sum() 220 | frac_tokens = num_segment_tokens / total_tokens 221 | 222 | segment_loss = segment_loss[segment_mask] 223 | segment_scaled_loss = segment_loss.mean() * frac_tokens 224 | 225 | total_loss = total_loss + segment_scaled_loss 226 | running_loss = running_loss + segment_scaled_loss 227 | 228 | # perform backwards every `(num_segment * detach_mems_every_num_segments)` 229 | 230 | if should_backward: 231 | (running_loss / grad_accum_scale).backward() 232 | running_loss = 0. 233 | 234 | # detach memories if need be 235 | 236 | if should_detach_memories and not is_last: 237 | detach_memories_(past_memories) 238 | 239 | return total_loss 240 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "infini-transformer-pytorch" 3 | version = "0.2.1" 4 | description = "Infini-Transformer in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.8" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'attention mechanism', 16 | 'long context', 17 | 'memory' 18 | ] 19 | classifiers=[ 20 | 'Development Status :: 4 - Beta', 21 | 'Intended Audience :: Developers', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Programming Language :: Python :: 3.8', 25 | ] 26 | 27 | dependencies = [ 28 | 'einops>=0.8.0', 29 | 'rotary-embedding-torch>=0.6.0', 30 | 'torch>=2.0', 31 | 'tqdm' 32 | ] 33 | 34 | [project.urls] 35 | Homepage = "https://pypi.org/project/infini-transformer-pytorch/" 36 | Repository = "https://github.com/lucidrains/infini-transformer-pytorch" 37 | 38 | [project.optional-dependencies] 39 | examples = [ 40 | "tqdm", 41 | "numpy" 42 | ] 43 | 44 | [build-system] 45 | requires = ["hatchling"] 46 | build-backend = "hatchling.build" 47 | 48 | [tool.pytest.ini_options] 49 | pythonpath = ["."] 50 | 51 | [tool.rye] 52 | managed = true 53 | dev-dependencies = [ 54 | "ruff>=0.4.2", 55 | "pytest>=8.2.0", 56 | ] 57 | 58 | [tool.ruff] 59 | line-length = 1000 60 | ignore-init-module-imports = true 61 | 62 | [tool.hatch.metadata] 63 | allow-direct-references = true 64 | 65 | [tool.hatch.build.targets.wheel] 66 | packages = ["infini_transformer_pytorch"] 67 | -------------------------------------------------------------------------------- /tests/test_readme.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from infini_transformer_pytorch import ( 5 | InfiniTransformer, 6 | InfiniTransformerWrapper 7 | ) 8 | 9 | def test_readme(): 10 | transformer = InfiniTransformer( 11 | num_tokens = 256, 12 | dim = 64, 13 | depth = 1, 14 | dim_head = 128, 15 | heads = 8, 16 | use_mem_delta_rule = True 17 | ) 18 | 19 | x = torch.randint(0, 256, (1, 1024)) 20 | 21 | logits1, _, mem1 = transformer(x, return_new_memories = False) 22 | logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False) 23 | logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True) 24 | 25 | def test_generate(): 26 | # model and wrapper 27 | 28 | model = InfiniTransformer( 29 | num_tokens = 256, 30 | dim = 64, 31 | depth = 1, 32 | dim_head = 128, 33 | heads = 8, 34 | use_mem_delta_rule = True 35 | ) 36 | 37 | wrapper = InfiniTransformerWrapper( 38 | model, 39 | segment_length = 32, 40 | detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories 41 | ) 42 | 43 | # mock input 44 | 45 | seq = torch.randint(0, 256, (2, 128)) # can be arbitrarily long sequence 46 | 47 | # training 48 | 49 | wrapper( 50 | seq, 51 | backward = True # will automatically segment and accumulate gradients when it detaches the memories 52 | ) 53 | 54 | # after much data... 55 | 56 | # calculating eval loss 57 | 58 | with torch.no_grad(): 59 | wrapper.eval() 60 | wrapper(seq) 61 | 62 | # generating is as easy as 63 | 64 | wrapper.generate(seq_len = 128, prompt = seq[:, :1]) 65 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from infini_transformer_pytorch import ( 2 | InfiniTransformer, 3 | InfiniTransformerWrapper 4 | ) 5 | 6 | import tqdm 7 | import gzip 8 | import numpy as np 9 | import torch 10 | from torch.optim import Adam 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 4 17 | GRADIENT_ACCUMULATE_EVERY = 4 18 | LEARNING_RATE = 2e-4 19 | VALIDATE_EVERY = 100 20 | GENERATE_EVERY = 250 21 | PRIME_LEN = 100 22 | SEQ_LEN = 1024 23 | SEGMENT_LENGTH = 128 24 | 25 | # helpers 26 | 27 | def cycle(loader): 28 | while True: 29 | for data in loader: 30 | yield data 31 | 32 | def decode_token(token): 33 | return str(chr(max(32, token))) 34 | 35 | def decode_tokens(tokens): 36 | return ''.join(list(map(decode_token, tokens))) 37 | 38 | # instantiate GPT-like decoder model 39 | 40 | model = InfiniTransformer( 41 | num_tokens = 256, 42 | dim = 512, 43 | depth = 8, 44 | dim_head = 64, 45 | heads = 8, 46 | use_mem_delta_rule = True, 47 | learned_delta_update = True 48 | ) 49 | 50 | wrapper = InfiniTransformerWrapper( 51 | model, 52 | segment_length = SEGMENT_LENGTH, 53 | detach_mems_every_num_segments = 2 54 | ).cuda() 55 | 56 | # prepare enwik8 data 57 | 58 | with gzip.open('./data/enwik8.gz') as file: 59 | x = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 60 | train_x, valid_x = np.split(x, [int(90e6)]) 61 | data_train, data_val = map(torch.from_numpy, (train_x, valid_x)) 62 | 63 | class TextSamplerDataset(Dataset): 64 | def __init__(self, data, seq_len): 65 | super().__init__() 66 | self.data = data 67 | self.seq_len = seq_len 68 | 69 | def __getitem__(self, index): 70 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 71 | full_seq = self.data[rand_start: rand_start + self.seq_len].long() 72 | return full_seq.cuda() 73 | 74 | def __len__(self): 75 | return self.data.size(0) // self.seq_len 76 | 77 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 78 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 79 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE)) 80 | val_loader = cycle(DataLoader(val_dataset, batch_size = 1)) 81 | 82 | # optimizer 83 | 84 | optim = Adam(model.parameters(), lr = LEARNING_RATE) 85 | 86 | # training 87 | 88 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.): 89 | 90 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 91 | loss = wrapper( 92 | next(train_loader), 93 | backward = True, 94 | grad_accum_scale = GRADIENT_ACCUMULATE_EVERY ** -1. 95 | ) 96 | 97 | print(f'training loss: {loss.item()}') 98 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 99 | optim.step() 100 | optim.zero_grad() 101 | 102 | if i % VALIDATE_EVERY == 0: 103 | with torch.no_grad(): 104 | wrapper.eval() 105 | loss = wrapper(next(val_loader)) 106 | print(f'validation loss: {loss.item()}') 107 | 108 | if i % GENERATE_EVERY == 0: 109 | ids = next(val_loader)[:, :PRIME_LEN] 110 | prime = decode_tokens(ids.flatten()) 111 | print('%s \n\n %s', (prime, '*' * 100)) 112 | 113 | sample = wrapper.generate( 114 | prompt = ids, 115 | seq_len = SEQ_LEN 116 | ) 117 | 118 | decoded_string = decode_tokens(sample.flatten()) 119 | print(decoded_string) 120 | print("\n") 121 | --------------------------------------------------------------------------------