├── .github
│   └── workflows
│       ├── publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── infini-attention.png
├── infini_transformer_pytorch
│   ├── __init__.py
│   ├── infini_transformer.py
│   └── wrapper.py
├── pyproject.toml
├── tests
│   └── test_readme.py
└── train.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests the examples in README
2 | on: push
3 |
4 | jobs:
5 | test:
6 | runs-on: ubuntu-latest
7 | steps:
8 | - uses: actions/checkout@v4
9 | - name: Install Python
10 | uses: actions/setup-python@v4
11 | - name: Install the latest version of rye
12 | uses: eifinger/setup-rye@v2
13 | - name: Use UV instead of pip
14 | run: rye config --set-bool behavior.use-uv=true
15 | - name: Install dependencies
16 | run: |
17 | rye sync
18 | - name: Run pytest
19 | run: rye run pytest tests/test_readme.py
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 | - repo: https://github.com/astral-sh/ruff-pre-commit
4 | rev: v0.0.278
5 | hooks:
6 | - id: ruff
7 | args: [ --fix, --exit-non-zero-on-fix]
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Infini-Transformer - Pytorch
4 |
5 | Implementation of Infini-Transformer in Pytorch. The paper uses a linear attention scheme to compress past key / values into a fixed-size memory per layer and head, and reports state-of-the-art results on several long-context benchmarks.
6 |
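For intuition, the compressed memory kept per layer and head is just an accumulated key / value outer product plus a key normalizer, written and read with linear attention. A minimal standalone sketch of equations (3) and (4), mirroring `retrieve_from_kv_memories` and `FastweightMemory.create_new_memories` in `infini_transformer.py` (toy shapes and variable names are mine):

```python
import torch
import torch.nn.functional as F
from einops import einsum, reduce

# toy shapes: batch 1, 8 heads, 128 tokens, 128-dim keys / values
q = torch.randn(1, 8, 128, 128)
k = torch.randn(1, 8, 128, 128)
v = torch.randn(1, 8, 128, 128)

# Katharopoulos linear attention feature map
phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1

# eq (4) - write: accumulate key / value outer products and a key normalizer
mem_kv   = einsum(phi_k, v, 'b h n dk, b h n dv -> b h dk dv')
mem_norm = reduce(phi_k, 'b h n d -> b h d', 'sum')

# eq (3) - read: query the compressed memory with linear attention
numer = einsum(phi_q, mem_kv, 'b h n dk, b h dk dv -> b h n dv')
denom = einsum(phi_q, mem_norm, 'b h n d, b h d -> b h n')

retrieved = numer / denom[..., None].clamp(min = 1e-10)  # (1, 8, 128, 128)
```
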
7 | Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.
8 |
9 | Yannic Kilcher's explanation
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install infini-transformer-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | ```python
20 | import torch
21 | from infini_transformer_pytorch import InfiniTransformer
22 |
23 | transformer = InfiniTransformer(
24 | num_tokens = 256,
25 | dim = 512,
26 | depth = 8,
27 | dim_head = 128, # high head dimension may be part of the reason they got good results (kv has high capacity)
28 | heads = 8,
29 | use_mem_delta_rule = True
30 | )
31 |
32 | x = torch.randint(0, 256, (1, 1024))
33 |
34 | logits1, _, mem1 = transformer(x, return_new_memories = False)
35 | logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
36 | logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)
37 |
38 | ```
39 |
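Long sequences are handled by feeding the model one segment at a time and carrying the compressed memories forward. A minimal sketch of that recurrence, which the wrapper below automates (the segment size of 256 and the variable names are illustrative, and gradient detachment is omitted):

```python
long_seq = torch.randint(0, 256, (1, 4096))

past_memories = None

for segment in long_seq.split(256, dim = -1):
    # each forward compresses the segment's key / values into the memories
    logits, _, past_memories = transformer(
        segment,
        past_memories = past_memories,
        return_new_memories = True
    )
```
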
40 | Training a transformer with recurrence usually trips up a lot of researchers. To make it easy, just wrap the model with `InfiniTransformerWrapper`:
41 |
42 | ```python
43 | import torch
44 |
45 | from infini_transformer_pytorch import (
46 | InfiniTransformer,
47 | InfiniTransformerWrapper
48 | )
49 |
50 | # model and wrapper
51 |
52 | model = InfiniTransformer(
53 | num_tokens = 256,
54 | dim = 512,
55 | depth = 8,
56 | dim_head = 128,
57 | heads = 8,
58 | use_mem_delta_rule = True
59 | )
60 |
61 | wrapper = InfiniTransformerWrapper(
62 | model,
63 | segment_length = 512,
64 | detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories
65 | ).cuda()
66 |
67 | # mock input
68 |
69 | seq = torch.randint(0, 256, (2, 10000)).cuda() # can be arbitrarily long sequence
70 |
71 | # training
72 |
73 | loss = wrapper(
74 | seq,
75 | backward = True # will automatically segment and accumulate gradients when it detaches the memories
76 | )
77 |
78 | # after much data...
79 |
80 | # calculating eval loss
81 |
82 | with torch.no_grad():
83 | wrapper.eval()
84 | eval_loss = wrapper(seq)
85 |
86 | # generating is as easy as
87 |
88 | output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])
89 |
90 | output.shape # (2, 8192 - 1)
91 | ```
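
`generate` defaults to nucleus (top-p) filtering, but any `filter_fn` over the final logits can be swapped in through `filter_kwargs`. A small sketch using the `top_k` helper that ships in `infini_transformer_pytorch.wrapper` (the `k = 20` value is purely illustrative):

```python
from infini_transformer_pytorch.wrapper import top_k

output = wrapper.generate(
    seq_len = 8192,
    prompt = seq[:, :1],
    filter_fn = top_k,                # swap nucleus sampling for top-k
    filter_kwargs = dict(k = 20)      # illustrative value
)
```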
92 |
93 | ## Testing
94 |
95 | Train a character-level autoregressive model on enwik8:
96 |
97 | ```bash
98 | $ python train.py
99 | ```
100 |
101 | ## Todo
102 |
103 | - [ ] `detach_mems_every_num_segments` hyperparameter is too confusing, get rid of it
104 | - [ ] experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv or linear projection on k, v separately) before sending the memories to the layer before
105 | - [x] working example with enwik8
106 |
107 | ## Citations
108 |
109 | ```bibtex
110 | @inproceedings{Munkhdalai2024LeaveNC,
111 | title = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
112 | author = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
113 | year = {2024},
114 | url = {https://api.semanticscholar.org/CorpusID:269033427}
115 | }
116 | ```
117 |
118 | ```bibtex
119 | @article{Yang2024ParallelizingLT,
120 | title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
121 | author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
122 | journal = {ArXiv},
123 | year = {2024},
124 | volume = {abs/2406.06484},
125 | url = {https://api.semanticscholar.org/CorpusID:270371554}
126 | }
127 | ```
128 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/infini-transformer-pytorch/d5a2d89ef8da9340bae0761d18630d6fc0e90460/data/enwik8.gz
--------------------------------------------------------------------------------
/infini-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/infini-transformer-pytorch/d5a2d89ef8da9340bae0761d18630d6fc0e90460/infini-attention.png
--------------------------------------------------------------------------------
/infini_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from infini_transformer_pytorch.infini_transformer import (
2 | InfiniTransformer,
3 | FastweightMemory,
4 | detach_memories_,
5 | detach_cached_kv_
6 | )
7 |
8 | from infini_transformer_pytorch.wrapper import (
9 | InfiniTransformerWrapper
10 | )
11 |
12 | __all__ = [
13 |     'InfiniTransformer',
14 |     'FastweightMemory',
15 |     'InfiniTransformerWrapper',
16 |     'detach_memories_',
17 |     'detach_cached_kv_'
18 | ]
19 |
--------------------------------------------------------------------------------
/infini_transformer_pytorch/infini_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Tuple, List, NamedTuple
3 |
4 | import torch
5 | from torch import nn, Tensor
6 | import torch.nn.functional as F
7 | from torch.nn import Module, ModuleList
8 |
9 | from einops import einsum, rearrange, reduce
10 | from einops.layers.torch import Rearrange
11 |
12 | from rotary_embedding_torch import RotaryEmbedding
13 |
14 | # named tuples
15 |
16 | class Memories(NamedTuple):
17 | kv_mem: Tensor
18 | k_norm: Tensor
19 |
20 | class TransformerReturn(NamedTuple):
21 | logits: Tensor
22 | cached_kvs: List[Tensor] | None
23 | past_memories: List[Memories] | None
24 |
25 | # helpers
26 |
27 | def exists(v):
28 | return v is not None
29 |
30 | def default(v, d):
31 | return v if exists(v) else d
32 |
33 | def detach_memories_(memories: List[Memories]):
34 | for (mem_kv, mem_norm) in memories:
35 | mem_kv.detach_()
36 | mem_norm.detach_()
37 |
38 | def detach_cached_kv_(cached_kvs: List[Tensor]):
39 | for cached_kv in cached_kvs:
40 | cached_kv.detach_()
41 |
42 | # classes
43 |
44 | class RMSNorm(Module):
45 | def __init__(self, dim):
46 | super().__init__()
47 | self.scale = dim ** 0.5
48 | self.gamma = nn.Parameter(torch.ones(dim))
49 |
50 | def forward(self, x):
51 | return F.normalize(x, dim = -1) * self.scale * self.gamma
52 |
53 | class FeedForward(Module):
54 | def __init__(
55 | self,
56 | dim,
57 | mult = 4,
58 | dropout = 0.
59 | ):
60 | super().__init__()
61 | dim_inner = int(mult * dim * 2 / 3)
62 |
63 | self.norm = RMSNorm(dim)
64 | self.proj_in = nn.Linear(dim, dim_inner * 2)
65 | self.proj_out = nn.Linear(dim_inner, dim)
66 |
67 | self.dropout = nn.Dropout(dropout)
68 |
69 | def forward(self, x):
70 | x = self.norm(x)
71 | x, gates = self.proj_in(x).chunk(2, dim = -1)
72 | x = F.gelu(gates) * x
73 | x = self.dropout(x)
74 | return self.proj_out(x)
75 |
76 | # fastweight memory
77 |
78 | def retrieve_from_kv_memories(t, past_memories: Memories, eps = 1e-10):
79 | past_memories_kv, past_memories_norm = past_memories
80 |
81 | numer = einsum(t, past_memories_kv, 'b h n dk, b h dk dv -> b h n dv')
82 | denom = einsum(t, past_memories_norm, 'b h n d, b h d -> b h n')
83 |
84 | denom = rearrange(denom, '... -> ... 1')
85 | return numer / denom.clamp(min = eps) # eq (3)
86 |
87 | class FastweightMemory(Module):
88 | def __init__(
89 | self,
90 | heads: int,
91 | head_gate_init_value = 10.,
92 | use_mem_delta_rule = False,
93 | ):
94 | super().__init__()
95 | self.use_mem_delta_rule = use_mem_delta_rule
96 | self.head_gates = nn.Parameter(torch.ones(heads) * head_gate_init_value)
97 |
98 | def create_new_memories(
99 | self,
100 | keys: Tensor,
101 | values: Tensor,
102 | past_memories: Memories | None = None,
103 | weights: Tensor | None = None
104 | ) -> Memories:
105 |
106 | # Katharopoulos linear attention activation
107 |
108 | keys = F.elu(keys) + 1
109 |
110 | # create the next memories
111 |
112 | if exists(past_memories) and self.use_mem_delta_rule:
113 | delta_v = retrieve_from_kv_memories(keys, past_memories)
114 |
115 | # eq (5) - the delta rule
116 |
117 | if exists(weights):
118 | # `weights` correspond to beta from deltanet, controlling how much of a given value is stored to fastweight memories
119 | values = values.lerp(delta_v, weights)
120 |
121 | diff_values = values - delta_v
122 | else:
123 |
124 | if exists(weights):
125 | values = values * (1. - weights)
126 |
127 | diff_values = values
128 |
129 | new_memories_kv = einsum(keys, diff_values, '... n dk, ... n dv -> ... dk dv')
130 | new_memories_norm = reduce(keys, 'b h n d -> b h d', 'sum')
131 |
132 | if exists(past_memories):
133 | past_memories_kv, past_memories_norm = past_memories
134 |
135 | new_memories_kv = new_memories_kv + past_memories_kv # eq (4)
136 | new_memories_norm = new_memories_norm + past_memories_norm # eq (4)
137 |
138 | return Memories(new_memories_kv, new_memories_norm)
139 |
140 | def retrieve_and_add_to_output(
141 | self,
142 | out: Tensor,
143 | queries: Tensor,
144 | past_memories: Memories | None = None
145 | ) -> Tensor:
146 |
147 | if not exists(past_memories):
148 | return out
149 |
150 | # the main contribution of the paper
151 | # Katharopoulos linear attention to kv memory of shape (batch, heads, dim keys, dim values)
152 |         # it makes sense the author would try this, as he is ex-Schmidhuber lab (the 'linear transformers are secretly fast weight programmers' paper)
153 |
154 | queries = F.elu(queries) + 1
155 |
156 | # retrieve from past memories
157 |
158 | mem_out = retrieve_from_kv_memories(queries, past_memories)
159 |
160 | # combine the current timestep output of queries with the outputs querying the past 'compressed' key/value memories
161 |         # in the paper, they use a sigmoid gating scheme with a learned gate per head
162 |
163 | gates = rearrange(self.head_gates, 'h -> h 1 1')
164 | gates = gates.sigmoid()
165 |
166 | out = out * gates + mem_out * (1. - gates) # eq (6) - figure 3 shows how heads emergently specialize to look either at the present, past, or a bit of both
167 |
168 | return out
169 |
170 | # attention
171 |
172 | class CausalAttention(Module):
173 | def __init__(
174 | self,
175 | dim,
176 | *,
177 | dim_head = 128,
178 | heads = 8,
179 | dropout = 0.,
180 | head_gate_init_value = 10.,
181 | use_mem_delta_rule = False,
182 | learned_delta_update = False
183 | ):
184 | super().__init__()
185 | dim_inner = dim_head * heads
186 | self.scale = dim_head ** -0.5
187 | self.norm = RMSNorm(dim)
188 |
189 | self.rotary_emb = RotaryEmbedding(dim_head)
190 |
191 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
192 | self.to_out = nn.Linear(dim_inner, dim, bias = False)
193 |
194 | self.dropout = nn.Dropout(dropout)
195 |
196 | self.split_heads = Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
197 | self.merge_heads = Rearrange('b h n d -> b n (h d)')
198 |
199 | # this corresponds to the learned beta in Deltanet from Yang et al. https://arxiv.org/abs/2406.06484
200 |
201 | self.to_learned_delta_update_weights = nn.Sequential(
202 | nn.Linear(dim, heads, bias = False),
203 | Rearrange('b n h -> b h n 1'),
204 | nn.Sigmoid()
205 | ) if learned_delta_update else None
206 |
207 | self.fastweight_mem = FastweightMemory(
208 | heads = heads,
209 | head_gate_init_value = head_gate_init_value,
210 | use_mem_delta_rule = use_mem_delta_rule,
211 | )
212 |
213 | def forward(
214 | self,
215 | x,
216 | cached_kv: Tensor | None = None,
217 | past_memories: Memories | None = None,
218 | return_new_memories = False,
219 | eps = 1e-10
220 |     ) -> Tuple[Tensor, Tensor | None, Memories | None]:
221 | """
222 | ein notation:
223 |
224 | b - batch
225 | h - heads
226 | n - sequence
227 | i - source sequence (q)
228 | j - target sequence (kv)
229 | d - feature dimension
230 | dk - feature dimension keys (and queries)
231 | dv - feature dimension of values
232 | """
233 |
234 | x = self.norm(x)
235 |
236 | qkv = self.to_qkv(x)
237 | q, k, v = self.split_heads(qkv)
238 |
239 | # handle cached key / values
240 |
241 | if exists(cached_kv):
242 | cached_k, cached_v = cached_kv
243 | k = torch.cat((cached_k, k), dim = -2)
244 | v = torch.cat((cached_v, v), dim = -2)
245 |
246 | # similarity
247 |
248 | q_scaled = q * self.scale
249 | q_rotated, k_rotated = self.rotary_emb.rotate_queries_with_cached_keys(q_scaled, k)
250 |
251 | sim = einsum(q_rotated, k_rotated, '... i d, ... j d -> ... i j')
252 |
253 | # causal mask
254 |
255 | i, j = sim.shape[-2:]
256 | causal_mask = torch.ones((i, j), device = sim.device, dtype = torch.bool).triu(j - i + 1)
257 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
258 |
259 | # attend
260 |
261 | attn = sim.softmax(dim = -1)
262 |
263 | # dropout
264 |
265 | attn = self.dropout(attn)
266 |
267 | # aggregate values
268 |
269 | out = einsum(attn, v, '... i j, ... j d -> ... i d')
270 |
271 | out = self.fastweight_mem.retrieve_and_add_to_output(out, q, past_memories)
272 |
273 | # merge heads and combine
274 |
275 | out = self.merge_heads(out)
276 | out = self.to_out(out)
277 |
278 |         # if new memories are not needed, early return
279 |         # at inference time, key / values are cached up to the segment length, then compressed into the memories
280 |
281 | if not return_new_memories:
282 | cached_kv = torch.stack((k, v))
283 |
284 | return out, cached_kv, past_memories
285 |
286 | # having the network learn the strength at which to apply the delta update rule
287 | # learnt per token / head
288 |
289 | delta_update_weights = None
290 |
291 | if exists(self.to_learned_delta_update_weights):
292 | delta_update_weights = self.to_learned_delta_update_weights(x)
293 |
294 | new_memories = self.fastweight_mem.create_new_memories(k, v, past_memories, delta_update_weights)
295 |
296 | return out, None, new_memories
297 |
298 | # main class
299 |
300 | class InfiniTransformer(Module):
301 | def __init__(
302 | self,
303 | *,
304 | num_tokens,
305 | dim,
306 | depth,
307 | dim_head = 128,
308 | heads = 8,
309 | attn_dropout = 0.,
310 | ff_mult = 4,
311 | ff_dropout = 0.,
312 | use_mem_delta_rule = False, # in the paper, the delta rule didn't seem to do that much, but will include for completeness
313 | learned_delta_update = False, # whether to use learned delta rule
314 | ):
315 | super().__init__()
316 |
317 | self.token_emb = nn.Embedding(num_tokens, dim)
318 |
319 | self.layers = ModuleList([])
320 |
321 | for _ in range(depth):
322 |
323 | attn = CausalAttention(
324 | dim = dim,
325 | dim_head = dim_head,
326 | heads = heads,
327 | use_mem_delta_rule = use_mem_delta_rule,
328 | learned_delta_update = learned_delta_update,
329 | dropout = attn_dropout
330 | )
331 |
332 | ff = FeedForward(
333 | dim = dim,
334 | mult = ff_mult,
335 | dropout = ff_dropout
336 | )
337 |
338 | self.layers.append(ModuleList([attn, ff]))
339 |
340 | self.norm = RMSNorm(dim)
341 | self.to_logits = nn.Linear(dim, num_tokens)
342 |
343 | def forward(
344 | self,
345 | x,
346 | past_memories: List[Memories] | None = None,
347 | cached_kv: List[Tensor] | None = None,
348 | return_new_memories = False,
349 | detach_memories = False
350 | ) -> TransformerReturn:
351 |
352 | x = self.token_emb(x)
353 |
354 | # handle cached key values
355 |
356 | if exists(cached_kv):
357 | x = x[:, -1:]
358 |
359 | new_cached_kv = []
360 | cached_kv_iter = iter(default(cached_kv, []))
361 |
362 | # iterator for past compressed memories
363 |
364 | new_memories = []
365 | past_memories_iter = iter(default(past_memories, []))
366 |
367 | # going through layers of infini-transformer
368 |
369 | for attn, ff in self.layers:
370 |
371 | attn_out, layer_cached_kv, layer_new_memories = attn(
372 | x,
373 | cached_kv = next(cached_kv_iter, None),
374 | past_memories = next(past_memories_iter, None),
375 | return_new_memories = return_new_memories
376 | )
377 |
378 | x = attn_out + x
379 | x = ff(x) + x
380 |
381 | new_cached_kv.append(layer_cached_kv)
382 | new_memories.append(layer_new_memories)
383 |
384 | # final norm
385 |
386 | embed = self.norm(x)
387 |
388 | # logits
389 |
390 | logits = self.to_logits(embed)
391 |
392 | if detach_memories:
393 | detach_cached_kv_(new_cached_kv)
394 |
395 | if not return_new_memories:
396 | return TransformerReturn(logits, new_cached_kv, past_memories)
397 |
398 | if detach_memories:
399 | detach_memories_(new_memories)
400 |
401 | return TransformerReturn(logits, None, new_memories)
402 |
--------------------------------------------------------------------------------
/infini_transformer_pytorch/wrapper.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from math import ceil
3 | from typing import Callable
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import Module
8 | import torch.nn.functional as F
9 |
10 | from einops import pack, rearrange
11 | from tqdm import tqdm
12 |
13 | from infini_transformer_pytorch.infini_transformer import (
14 | InfiniTransformer,
15 | detach_memories_
16 | )
17 |
18 | # helper functions
19 |
20 | def exists(v):
21 | return v is not None
22 |
23 | def default(v, d):
24 | return v if exists(v) else d
25 |
26 | def divisible_by(num, den):
27 | return (num % den) == 0
28 |
29 | def is_empty(t: Tensor):
30 | return t.numel() == 0
31 |
32 | def round_down_multiple(n, mult):
33 | return n // mult * mult
34 |
35 | # sampling helpers
36 |
37 | def log(t, eps = 1e-20):
38 | return torch.log(t.clamp(min = eps))
39 |
40 | def gumbel_noise(t):
41 | noise = torch.zeros_like(t).uniform_(0, 1)
42 | return -log(-log(noise))
43 |
44 | def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):
45 | return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)
46 |
47 | # nucleus
48 |
49 | def top_p(logits, thres = 0.9):
50 | sorted_logits, sorted_indices = torch.sort(logits, descending = True)
51 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
52 |
53 | sorted_indices_to_remove = cum_probs > thres
54 | sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
55 |
56 | sorted_logits[sorted_indices_to_remove] = float('-inf')
57 | return sorted_logits.scatter(1, sorted_indices, sorted_logits)
58 |
59 | # topk
60 |
61 | def top_k(logits, frac_num_tokens = 0.1, k: int | None = None):
62 | num_tokens = logits.shape[-1]
63 |
64 | k = default(k, ceil(frac_num_tokens * num_tokens))
65 | k = min(k, num_tokens)
66 |
67 | val, ind = torch.topk(logits, k)
68 | probs = torch.full_like(logits, float('-inf'))
69 | probs.scatter_(1, ind, val)
70 | return probs
71 |
72 | # class
73 |
74 | class InfiniTransformerWrapper(Module):
75 | def __init__(
76 | self,
77 | model: InfiniTransformer,
78 | segment_length = 512,
79 | detach_mems_every_num_segments = 2,
80 | ignore_index = -1
81 | ):
82 | super().__init__()
83 | self.model = model
84 |
85 | self.segment_length = segment_length
86 | self.detach_mems_every_num_segments = detach_mems_every_num_segments
87 |
88 | # loss related
89 |
90 | self.ignore_index = ignore_index
91 |
92 | @property
93 | def device(self):
94 | return next(self.model.parameters()).device
95 |
96 | @torch.no_grad()
97 | def generate(
98 | self,
99 | *,
100 | seq_len,
101 | prompt = None,
102 | batch_size = 1,
103 | temperature = 1.,
104 | filter_fn: Callable = top_p,
105 | filter_kwargs: dict = dict(thres = 0.9),
106 | exclude_prompt = True,
107 | segment_length = None
108 | ):
109 | segment_length = default(segment_length, self.segment_length)
110 | device, train_state = self.device, self.training
111 | self.eval()
112 |
113 | out = default(prompt, torch.empty((batch_size, 0), device = device, dtype = torch.long))
114 | init_len = out.shape[-1]
115 |
116 | # sample from the model token by token
117 | # keeping track of kv cache and when to compress into new memories
118 |
119 | cached_kv = None
120 | past_memories = None
121 |
122 | for curr_len in tqdm(range(init_len, seq_len)):
123 |
124 |             # the model input always starts at the beginning of the last (current) segment
125 |
126 | start_ind = round_down_multiple(curr_len - 1, segment_length)
127 | model_input = out[:, start_ind:]
128 |
129 | # forward the model with cached key / values and past memories
130 |
131 | logits, cached_kv, past_memories = self.model(
132 | model_input,
133 | cached_kv = cached_kv,
134 | past_memories = past_memories,
135 | return_new_memories = divisible_by(curr_len, segment_length)
136 | )
137 |
138 | # grab the last logit
139 |
140 | logits = logits[:, -1]
141 |
142 | # filter by either topk or nucleus
143 | # and sample
144 |
145 | filtered_logits = filter_fn(logits, **filter_kwargs)
146 | sampled = gumbel_sample(filtered_logits, temperature = temperature)
147 |
148 | # concat sampled token
149 |
150 | out, _ = pack((out, sampled), 'b *')
151 |
152 | # return output
153 |
154 | if exclude_prompt:
155 | out = out[:, init_len:]
156 |
157 | self.train(train_state)
158 | return out
159 |
160 | def forward(
161 | self,
162 | seq,
163 | segment_length = None,
164 | backward = False,
165 | grad_accum_scale = 1.
166 | ):
167 | segment_length = default(segment_length, self.segment_length)
168 |
169 | seq, label = seq[:, :-1], seq[:, 1:]
170 |
171 | # put into train mode if doing backwards within forward call
172 |
173 | if backward:
174 | self.model.train()
175 |
176 | total_tokens = (label != self.ignore_index).sum().item()
177 |
178 | # split the sequence by segment length
179 |
180 | split_seq = seq.split(segment_length, dim = -1)
181 | split_label = label.split(segment_length, dim = -1)
182 |
183 | num_segments = len(split_seq)
184 |
185 | # go over each segment length and calculate cross entropy loss
186 |
187 | total_loss = 0.
188 | past_memories = None
189 |
190 | running_loss = 0.
191 |
192 | for ind, (segment_seq, segment_label) in enumerate(zip(split_seq, split_label)):
193 | segment_num = ind + 1
194 | is_last = segment_num == num_segments
195 |
196 | should_detach_memories = divisible_by(segment_num, self.detach_mems_every_num_segments)
197 | should_backward = backward and (is_last or should_detach_memories)
198 |
199 | # model forwards for logits and past memories
200 |
201 | logits, _, past_memories = self.model(
202 | segment_seq,
203 | past_memories = past_memories,
204 | return_new_memories = True
205 | )
206 |
207 | # calculate cross entropy loss for segment
208 |
209 | segment_loss = F.cross_entropy(
210 | rearrange(logits, 'b n c -> b c n'),
211 | segment_label,
212 | reduction = 'none'
213 | )
214 |
215 | # make sure segment losses do not include ignored index
216 | # then also make sure the segment loss is scaled
217 |
218 | segment_mask = segment_label != self.ignore_index
219 | num_segment_tokens = segment_mask.sum()
220 | frac_tokens = num_segment_tokens / total_tokens
221 |
222 | segment_loss = segment_loss[segment_mask]
223 | segment_scaled_loss = segment_loss.mean() * frac_tokens
224 |
225 | total_loss = total_loss + segment_scaled_loss
226 | running_loss = running_loss + segment_scaled_loss
227 |
228 |             # perform backward every `detach_mems_every_num_segments` segments, as well as on the last segment
229 |
230 | if should_backward:
231 | (running_loss / grad_accum_scale).backward()
232 | running_loss = 0.
233 |
234 | # detach memories if need be
235 |
236 | if should_detach_memories and not is_last:
237 | detach_memories_(past_memories)
238 |
239 | return total_loss
240 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "infini-transformer-pytorch"
3 | version = "0.2.1"
4 | description = "Infini-Transformer in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.8"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'attention mechanism',
16 | 'long context',
17 | 'memory'
18 | ]
19 | classifiers=[
20 | 'Development Status :: 4 - Beta',
21 | 'Intended Audience :: Developers',
22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Programming Language :: Python :: 3.8',
25 | ]
26 |
27 | dependencies = [
28 | 'einops>=0.8.0',
29 | 'rotary-embedding-torch>=0.6.0',
30 | 'torch>=2.0',
31 | 'tqdm'
32 | ]
33 |
34 | [project.urls]
35 | Homepage = "https://pypi.org/project/infini-transformer-pytorch/"
36 | Repository = "https://github.com/lucidrains/infini-transformer-pytorch"
37 |
38 | [project.optional-dependencies]
39 | examples = [
40 | "tqdm",
41 | "numpy"
42 | ]
43 |
44 | [build-system]
45 | requires = ["hatchling"]
46 | build-backend = "hatchling.build"
47 |
48 | [tool.pytest.ini_options]
49 | pythonpath = ["."]
50 |
51 | [tool.rye]
52 | managed = true
53 | dev-dependencies = [
54 | "ruff>=0.4.2",
55 | "pytest>=8.2.0",
56 | ]
57 |
58 | [tool.ruff]
59 | line-length = 1000
60 | ignore-init-module-imports = true
61 |
62 | [tool.hatch.metadata]
63 | allow-direct-references = true
64 |
65 | [tool.hatch.build.targets.wheel]
66 | packages = ["infini_transformer_pytorch"]
67 |
--------------------------------------------------------------------------------
/tests/test_readme.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from infini_transformer_pytorch import (
5 | InfiniTransformer,
6 | InfiniTransformerWrapper
7 | )
8 |
9 | def test_readme():
10 | transformer = InfiniTransformer(
11 | num_tokens = 256,
12 | dim = 64,
13 | depth = 1,
14 | dim_head = 128,
15 | heads = 8,
16 | use_mem_delta_rule = True
17 | )
18 |
19 | x = torch.randint(0, 256, (1, 1024))
20 |
21 | logits1, _, mem1 = transformer(x, return_new_memories = False)
22 | logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
23 | logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)
24 |
25 | def test_generate():
26 | # model and wrapper
27 |
28 | model = InfiniTransformer(
29 | num_tokens = 256,
30 | dim = 64,
31 | depth = 1,
32 | dim_head = 128,
33 | heads = 8,
34 | use_mem_delta_rule = True
35 | )
36 |
37 | wrapper = InfiniTransformerWrapper(
38 | model,
39 | segment_length = 32,
40 | detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories
41 | )
42 |
43 | # mock input
44 |
45 | seq = torch.randint(0, 256, (2, 128)) # can be arbitrarily long sequence
46 |
47 | # training
48 |
49 | wrapper(
50 | seq,
51 | backward = True # will automatically segment and accumulate gradients when it detaches the memories
52 | )
53 |
54 | # after much data...
55 |
56 | # calculating eval loss
57 |
58 | with torch.no_grad():
59 | wrapper.eval()
60 | wrapper(seq)
61 |
62 | # generating is as easy as
63 |
64 | wrapper.generate(seq_len = 128, prompt = seq[:, :1])
65 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from infini_transformer_pytorch import (
2 | InfiniTransformer,
3 | InfiniTransformerWrapper
4 | )
5 |
6 | import tqdm
7 | import gzip
8 | import numpy as np
9 | import torch
10 | from torch.optim import Adam
11 | from torch.utils.data import DataLoader, Dataset
12 |
13 | # constants
14 |
15 | NUM_BATCHES = int(1e5)
16 | BATCH_SIZE = 4
17 | GRADIENT_ACCUMULATE_EVERY = 4
18 | LEARNING_RATE = 2e-4
19 | VALIDATE_EVERY = 100
20 | GENERATE_EVERY = 250
21 | PRIME_LEN = 100
22 | SEQ_LEN = 1024
23 | SEGMENT_LENGTH = 128
24 |
25 | # helpers
26 |
27 | def cycle(loader):
28 | while True:
29 | for data in loader:
30 | yield data
31 |
32 | def decode_token(token):
33 | return str(chr(max(32, token)))
34 |
35 | def decode_tokens(tokens):
36 | return ''.join(list(map(decode_token, tokens)))
37 |
38 | # instantiate GPT-like decoder model
39 |
40 | model = InfiniTransformer(
41 | num_tokens = 256,
42 | dim = 512,
43 | depth = 8,
44 | dim_head = 64,
45 | heads = 8,
46 | use_mem_delta_rule = True,
47 | learned_delta_update = True
48 | )
49 |
50 | wrapper = InfiniTransformerWrapper(
51 | model,
52 | segment_length = SEGMENT_LENGTH,
53 | detach_mems_every_num_segments = 2
54 | ).cuda()
55 |
56 | # prepare enwik8 data
57 |
58 | with gzip.open('./data/enwik8.gz') as file:
59 | x = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
60 | train_x, valid_x = np.split(x, [int(90e6)])
61 | data_train, data_val = map(torch.from_numpy, (train_x, valid_x))
62 |
63 | class TextSamplerDataset(Dataset):
64 | def __init__(self, data, seq_len):
65 | super().__init__()
66 | self.data = data
67 | self.seq_len = seq_len
68 |
69 | def __getitem__(self, index):
70 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
71 | full_seq = self.data[rand_start: rand_start + self.seq_len].long()
72 | return full_seq.cuda()
73 |
74 | def __len__(self):
75 | return self.data.size(0) // self.seq_len
76 |
77 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
78 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
79 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
80 | val_loader = cycle(DataLoader(val_dataset, batch_size = 1))
81 |
82 | # optimizer
83 |
84 | optim = Adam(model.parameters(), lr = LEARNING_RATE)
85 |
86 | # training
87 |
88 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.):
89 |
90 | for __ in range(GRADIENT_ACCUMULATE_EVERY):
91 | loss = wrapper(
92 | next(train_loader),
93 | backward = True,
94 | grad_accum_scale = GRADIENT_ACCUMULATE_EVERY ** -1.
95 | )
96 |
97 | print(f'training loss: {loss.item()}')
98 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
99 | optim.step()
100 | optim.zero_grad()
101 |
102 | if i % VALIDATE_EVERY == 0:
103 | with torch.no_grad():
104 | wrapper.eval()
105 | loss = wrapper(next(val_loader))
106 | print(f'validation loss: {loss.item()}')
107 |
108 | if i % GENERATE_EVERY == 0:
109 | ids = next(val_loader)[:, :PRIME_LEN]
110 | prime = decode_tokens(ids.flatten())
111 |         print('%s \n\n %s' % (prime, '*' * 100))
112 |
113 | sample = wrapper.generate(
114 | prompt = ids,
115 | seq_len = SEQ_LEN
116 | )
117 |
118 | decoded_string = decode_tokens(sample.flatten())
119 | print(decoded_string)
120 | print("\n")
121 |
--------------------------------------------------------------------------------