├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yaml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── fig1.png
├── fig2.png
├── pyproject.toml
├── tests
│   └── test_titans.py
├── titans_pytorch
│   ├── __init__.py
│   ├── mac_transformer.py
│   ├── memory_models.py
│   └── neural_memory.py
└── train_mac.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Tests the examples in README
2 | on: [push, pull_request]
3 |
4 | env:
5 | TYPECHECK: True
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Install Python
13 | uses: actions/setup-python@v5
14 | with:
15 | python-version: "3.11"
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install uv
19 | python -m uv pip install --upgrade pip
20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21 | python -m uv pip install -e .[test]
22 | - name: Test with pytest
23 | run: |
24 | python -m pytest tests/
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | train_local.py
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # UV
101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | #uv.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119 | .pdm.toml
120 | .pdm-python
121 | .pdm-build/
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # PyCharm
167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169 | # and can be added to the global gitignore or merged into this file. For a more nuclear
170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171 | #.idea/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## Titans - Pytorch
6 |
7 | Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
8 |
9 | ## Appreciation
10 |
11 | - [Eryk](https://github.com/sentialx) for sharing his early experimental results with me, which were positive for a 2-layer MLP
12 |
13 | ## Install
14 |
15 | ```bash
16 | $ pip install titans-pytorch
17 | ```
18 |
19 | ## Usage
20 |
21 | ```python
22 | import torch
23 | from titans_pytorch import NeuralMemory
24 |
25 | mem = NeuralMemory(
26 | dim = 384,
27 | chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
28 | ).cuda()
29 |
30 | seq = torch.randn(2, 1024, 384).cuda()
31 | retrieved, mem_state = mem(seq)
32 |
33 | assert seq.shape == retrieved.shape
34 | ```
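
The second return value is the neural memory state, which can be passed back on the next call, so a long sequence can be processed in successive chunks. The sketch below mirrors the chaining pattern exercised in `tests/test_titans.py`.

```python
# continuing the example above - the same sequence, processed in chunks,
# threading the returned memory state through each call

state = None
chunks = []

for chunk in seq.split(256, dim = 1):
    retrieved_chunk, state = mem(chunk, state = state)
    chunks.append(retrieved_chunk)

chunked_retrieved = torch.cat(chunks, dim = 1) # matches `retrieved` up to numerical tolerance
```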
35 |
36 | A transformer with the `MAC` configuration can be used as
37 |
38 | ```python
39 | import torch
40 | from titans_pytorch import MemoryAsContextTransformer
41 |
42 | transformer = MemoryAsContextTransformer(
43 | num_tokens = 256,
44 | dim = 256,
45 | depth = 2,
46 | segment_len = 128, # local attention window size
47 | num_persist_mem_tokens = 4,
48 | num_longterm_mem_tokens = 16,
49 | )
50 |
51 | token_ids = torch.randint(0, 256, (1, 1023))
52 |
53 | loss = transformer(token_ids, return_loss = True) # scalar cross entropy loss
54 | loss.backward()
55 |
56 | # after much training
57 |
58 | sampled = transformer.sample(token_ids[:, :4], 512)
59 | ```
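
Calling the transformer without `return_loss = True` returns per-token logits, and sampling can reuse the kv / neural memory cache. A brief sketch, following the calls exercised in `tests/test_titans.py`:

```python
# per-token logits, for evaluation

logits = transformer(token_ids) # (1, 1023, 256)

# sampling with the cache enabled

sampled = transformer.sample(token_ids[:, :4], 512, use_cache = True)
```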
60 |
61 | ## Experiments
62 |
63 | ```bash
64 | $ pip install .[examples]
65 | ```
66 |
67 | Then modify `train_mac.py` and run it to query nature
68 |
69 | ```bash
70 | $ python train_mac.py
71 | ```
72 |
73 | ## Citations
74 |
75 | ```bibtex
76 | @inproceedings{Behrouz2024TitansLT,
77 | title = {Titans: Learning to Memorize at Test Time},
78 | author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
79 | year = {2024},
80 | url = {https://api.semanticscholar.org/CorpusID:275212078}
81 | }
82 | ```
83 |
84 | ```bibtex
85 | @article{Sun2024LearningT,
86 | title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States},
87 | author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin},
88 | journal = {ArXiv},
89 | year = {2024},
90 | volume = {abs/2407.04620},
91 | url = {https://api.semanticscholar.org/CorpusID:271039606}
92 | }
93 | ```
94 |
95 | ```bibtex
96 | @inproceedings{Yang2024GatedDN,
97 | title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
98 | author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
99 | year = {2024},
100 | url = {https://api.semanticscholar.org/CorpusID:274598177}
101 | }
102 | ```
103 |
104 | ```bibtex
105 | @inproceedings{Nguyen2024TurningUT,
106 | title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs},
107 | author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv},
108 | year = {2024},
109 | url = {https://api.semanticscholar.org/CorpusID:270870613}
110 | }
111 | ```
112 |
113 | ```bibtex
114 | @article{Zhu2024HyperConnections,
115 | title = {Hyper-Connections},
116 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
117 | journal = {ArXiv},
118 | year = {2024},
119 | volume = {abs/2409.19606},
120 | url = {https://api.semanticscholar.org/CorpusID:272987528}
121 | }
122 | ```
123 |
124 | ```bibtex
125 | @article{Zhou2024ValueRL,
126 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
127 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
128 | journal = {ArXiv},
129 | year = {2024},
130 | volume = {abs/2410.17897},
131 | url = {https://api.semanticscholar.org/CorpusID:273532030}
132 | }
133 | ```
134 |
135 | ```bibtex
136 | @software{Kyrylov_Accelerated_Scan_2024,
137 | author = {Kyrylov, Volodymyr},
138 | doi = {10.5281/zenodo.10600962},
139 | title = {Accelerated Scan},
140 | version = {0.1.2},
141 | year = {2024}
142 | }
143 | ```
144 |
145 | ```bibtex
146 | @misc{wang2025testtimeregressionunifyingframework,
147 | title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
148 | author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
149 | year = {2025},
150 | eprint = {2501.12352},
151 | archivePrefix = {arXiv},
152 | primaryClass = {cs.LG},
153 | url = {https://arxiv.org/abs/2501.12352},
154 | }
155 | ```
156 |
157 | ```bibtex
158 | @misc{jordan2024muon,
159 | author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and
160 | Franz Cesista and Laker Newhouse and Jeremy Bernstein},
161 | title = {Muon: An optimizer for hidden layers in neural networks},
162 | year = {2024},
163 | url = {https://kellerjordan.github.io/posts/muon/}
164 | }
165 | ```
166 |
167 | ```bibtex
168 | @inproceedings{Zhang2025TestTimeTD,
169 | title = {Test-Time Training Done Right},
170 | author = {Tianyuan Zhang and Sai Bi and Yicong Hong and Kai Zhang and Fujun Luan and Songlin Yang and Kalyan Sunkavalli and William T. Freeman and Hao Tan},
171 | year = {2025},
172 | url = {https://api.semanticscholar.org/CorpusID:279071244}
173 | }
174 | ```
175 |
176 | ```bibtex
177 | @inproceedings{Behrouz2025ATLASLT,
178 | title = {ATLAS: Learning to Optimally Memorize the Context at Test Time},
179 | author = {Ali Behrouz and Ze-Minghui Li and Praneeth Kacham and Majid Daliri and Yuan Deng and Peilin Zhong and Meisam Razaviyayn and Vahab S. Mirrokni},
180 | year = {2025},
181 | url = {https://api.semanticscholar.org/CorpusID:278996373}
182 | }
183 | ```
184 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
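
A minimal sketch of loading the archive for language modeling. The `95e6` byte read and `90e6` / `5e6` train / validation split are assumptions borrowed from common enwik8 setups, not something documented here.

```python
import gzip
import numpy as np
import torch

# decompress the first ~95 million bytes and split into train / validation
with gzip.open('./data/enwik8.gz') as f:
    data = np.frombuffer(f.read(int(95e6)), dtype = np.uint8).copy()

train_data, valid_data = map(torch.from_numpy, np.split(data, [int(90e6)]))
```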
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/data/enwik8.gz
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/fig1.png
--------------------------------------------------------------------------------
/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/fig2.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "titans-pytorch"
3 | version = "0.4.10"
4 | description = "Titans"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'test time training',
15 | 'linear attention',
16 | 'memory',
17 | ]
18 |
19 | classifiers=[
20 | 'Development Status :: 4 - Beta',
21 | 'Intended Audience :: Developers',
22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Programming Language :: Python :: 3.9',
25 | ]
26 |
27 | dependencies = [
28 | "assoc-scan",
29 | "axial_positional_embedding>=0.3.10",
30 | "einops>=0.8.0",
31 | "einx>=0.3.0",
32 | "hyper-connections>=0.1.11",
33 | "Ninja",
34 | "rotary-embedding-torch",
35 | "tensordict",
36 | "torch>=2.2",
37 | "tqdm",
38 | "x-transformers"
39 | ]
40 |
41 | [project.urls]
42 | Homepage = "https://pypi.org/project/titans-pytorch/"
43 | Repository = "https://github.com/lucidrains/titans-pytorch"
44 |
45 | [project.optional-dependencies]
46 |
47 | examples = [
48 | "adam-atan2-pytorch>=0.1.18",
49 | "wandb"
50 | ]
51 |
52 | test = [
53 | "pytest"
54 | ]
55 |
56 | [tool.pytest.ini_options]
57 | pythonpath = [
58 | "."
59 | ]
60 |
61 | [build-system]
62 | requires = ["hatchling"]
63 | build-backend = "hatchling.build"
64 |
65 | [tool.rye]
66 | managed = true
67 | dev-dependencies = []
68 |
69 | [tool.hatch.metadata]
70 | allow-direct-references = true
71 |
72 | [tool.hatch.build.targets.wheel]
73 | packages = ["titans_pytorch"]
74 |
--------------------------------------------------------------------------------
/tests/test_titans.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | import torch
4 | from torch import nn
5 |
6 | import pytest
7 | from titans_pytorch import NeuralMemory
8 | from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer
9 |
10 | # functions
11 |
12 | def exists(v):
13 | return v is not None
14 |
15 | def diff(x, y):
16 | return (x - y).abs().amax()
17 |
18 | @contextmanager
19 | def torch_default_dtype(dtype):
20 | prev_dtype = torch.get_default_dtype()
21 | torch.set_default_dtype(dtype)
22 | yield
23 | torch.set_default_dtype(prev_dtype)
24 |
25 | # main test
26 |
27 | @pytest.mark.parametrize('seq_len', (32, 512, 77))
28 | @pytest.mark.parametrize('silu', (False, True))
29 | @pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False)))
30 | @pytest.mark.parametrize('momentum', (False, True))
31 | @pytest.mark.parametrize('qk_rmsnorm', (False, True))
32 | @pytest.mark.parametrize('heads', (1, 4))
33 | @pytest.mark.parametrize('max_grad_norm', (None, 2.))
34 | @pytest.mark.parametrize('num_kv_per_token', (1, 2))
35 | @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
36 | @pytest.mark.parametrize('per_head_learned_parameters', (False, True))
37 | @pytest.mark.parametrize('test_store_mask', (False, True))
38 | def test_titans(
39 | seq_len,
40 | silu,
41 | attn_pool_chunks,
42 | chunk_size,
43 | momentum,
44 | qk_rmsnorm,
45 | heads,
46 | max_grad_norm,
47 | num_kv_per_token,
48 | per_parameter_lr_modulation,
49 | per_head_learned_parameters,
50 | test_store_mask
51 | ):
52 | mem = NeuralMemory(
53 | dim = 16,
54 | chunk_size = chunk_size,
55 | activation = nn.SiLU() if silu else None,
56 | attn_pool_chunks = attn_pool_chunks,
57 | max_grad_norm = max_grad_norm,
58 | num_kv_per_token = num_kv_per_token,
59 | momentum = momentum,
60 | qk_rmsnorm = qk_rmsnorm,
61 | heads = heads,
62 | per_parameter_lr_modulation = per_parameter_lr_modulation,
63 | per_head_learned_parameters = per_head_learned_parameters
64 | )
65 |
66 | seq = torch.randn(2, seq_len, 16)
67 |
68 | store_mask = None
69 |
70 | if test_store_mask:
71 | store_mask = torch.randint(0, 2, (2, seq_len)).bool()
72 |
73 | retrieved, _ = mem(seq, store_mask = store_mask)
74 |
75 | assert seq.shape == retrieved.shape
76 |
77 | def test_return_surprises():
78 |
79 | mem = NeuralMemory(
80 | dim = 384,
81 | chunk_size = 2,
82 | dim_head = 64,
83 | heads = 4,
84 | )
85 |
86 | seq = torch.randn(4, 64, 384)
87 |
88 | _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
89 |
90 | assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
91 |
92 | @pytest.mark.parametrize('learned_momentum_combine', (False, True))
93 | @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
94 | def test_titans_second_order_momentum(
95 | learned_momentum_combine,
96 | learned_combine_include_zeroth
97 | ):
98 |
99 | mem = NeuralMemory(
100 | dim = 384,
101 | dim_head = 64,
102 | heads = 2,
103 | chunk_size = 1,
104 | batch_size = 2,
105 | momentum_order = 2,
106 | learned_momentum_combine = learned_momentum_combine,
107 | learned_combine_include_zeroth = learned_combine_include_zeroth
108 | )
109 |
110 | seq = torch.randn(2, 5, 384)
111 |
112 | parallel_retrieved, state = mem(seq)
113 | assert seq.shape == parallel_retrieved.shape
114 |
115 | def test_titans_attn_memory():
116 | from titans_pytorch.memory_models import MemoryAttention
117 |
118 | mem = NeuralMemory(
119 | dim = 16,
120 | chunk_size = 64,
121 | model = MemoryAttention(
122 | dim = 16
123 | )
124 | )
125 |
126 | seq = torch.randn(2, 1024, 16)
127 | retrieved, _ = mem(seq)
128 |
129 | assert seq.shape == retrieved.shape
130 |
131 | def test_swiglu_ff_memory():
132 | from titans_pytorch.memory_models import MemorySwiGluMLP
133 |
134 | mem = NeuralMemory(
135 | dim = 16,
136 | chunk_size = 2,
137 | mem_model_norm_add_residual = False,
138 | model = MemorySwiGluMLP(
139 | dim = 16,
140 | depth = 2
141 | )
142 | )
143 |
144 | seq = torch.randn(2, 64, 16)
145 | retrieved, _ = mem(seq)
146 |
147 | assert seq.shape == retrieved.shape
148 |
149 | @pytest.mark.parametrize('gated_transition', (True, False))
150 | def test_neural_mem_chaining_chunks(
151 | gated_transition
152 | ):
153 | mem = NeuralMemory(
154 | dim = 16,
155 | dim_head = 16,
156 | heads = 2,
157 | chunk_size = 16,
158 | gated_transition = gated_transition
159 | )
160 |
161 | seq = torch.randn(2, 48, 16)
162 |
163 | parallel_retrieved, state = mem(seq)
164 |
165 | seq_first, seq_second, seq_third = seq.split(16, dim = 1)
166 |
167 | first_retrieved, state = mem(seq_first)
168 | second_retrieved, state = mem(seq_second, state = state)
169 | third_retrieved, state = mem(seq_third, state = state)
170 |
171 | assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5)
172 |
173 | def test_neural_mem_chaining_with_weight_residual():
174 | mem = NeuralMemory(
175 | dim = 16,
176 | dim_head = 16,
177 | heads = 2,
178 | chunk_size = 64
179 | )
180 |
181 | mem2 = NeuralMemory(
182 | dim = 16,
183 | dim_head = 16,
184 | heads = 2,
185 | chunk_size = 64,
186 | accept_weight_residual = True
187 | )
188 |
189 | seq = torch.randn(2, 256, 16)
190 |
191 | seq, state = mem(seq)
192 |
193 | parallel_retrieved, _ = mem2(seq, prev_weights = state.updates)
194 |
195 | seq_first, seq_second = seq[:, :128], seq[:, 128:]
196 |
197 | first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
198 | second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
199 |
200 | assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)
201 |
202 | def test_neural_mem_chaining_with_batch_size():
203 | mem = NeuralMemory(
204 | dim = 16,
205 | dim_head = 16,
206 | heads = 2,
207 | chunk_size = 16,
208 | batch_size = 64
209 | )
210 |
211 | seq = torch.randn(2, 112, 16)
212 |
213 | parallel_retrieved, state = mem(seq)
214 |
215 | seq_first, seq_second, seq_third = seq[:, :16], seq[:, 16:64], seq[:, 64:]
216 |
217 | first_retrieved, state = mem(seq_first)
218 | second_retrieved, state = mem(seq_second, state = state)
219 | third_retrieved, state = mem(seq_third, state = state)
220 |
221 | parallel_part_retrieved = torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1)
222 |
223 | assert torch.allclose(parallel_retrieved, parallel_part_retrieved, atol = 1e-5)
224 |
225 | @pytest.mark.parametrize('seq_len', (1023, 17))
226 | @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
227 | @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
228 | @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
229 | @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
230 | @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
231 | @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
232 | @pytest.mark.parametrize('neural_mem_qkv_receives_diff_views', (False, True))
233 | @pytest.mark.parametrize('neural_mem_momentum', (False, True))
234 | def test_mac(
235 | seq_len,
236 | num_persist_mem_tokens,
237 | num_longterm_mem_tokens,
238 | neural_mem_gate_attn_output,
239 | neural_mem_segment_len,
240 | neural_mem_weight_residual,
241 | neural_mem_batch_size,
242 | neural_mem_qkv_receives_diff_views,
243 | neural_mem_momentum
244 | ):
245 | transformer = MemoryAsContextTransformer(
246 | num_tokens = 256,
247 | dim = 16,
248 | depth = 2,
249 | num_persist_mem_tokens = num_persist_mem_tokens,
250 | num_longterm_mem_tokens = num_longterm_mem_tokens,
251 | segment_len = 128,
252 | neural_mem_gate_attn_output = neural_mem_gate_attn_output,
253 | neural_memory_segment_len = neural_mem_segment_len,
254 | neural_memory_batch_size = neural_mem_batch_size,
255 | neural_memory_qkv_receives_diff_views = neural_mem_qkv_receives_diff_views,
256 | neural_mem_weight_residual = neural_mem_weight_residual,
257 | neural_memory_kwargs = dict(
258 | momentum = neural_mem_momentum
259 | )
260 | )
261 |
262 | x = torch.randint(0, 256, (1, seq_len))
263 |
264 | logits = transformer(x)
265 | assert logits.shape == (1, seq_len, 256)
266 |
267 | @pytest.mark.parametrize('sliding', (False, True))
268 | @pytest.mark.parametrize('mem_layers', ((), None))
269 | @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
270 | @pytest.mark.parametrize('prompt_len', (4, 16))
271 | @torch_default_dtype(torch.float64)
272 | def test_mac_sampling(
273 | sliding,
274 | mem_layers,
275 | longterm_mems,
276 | prompt_len
277 | ):
278 | transformer = MemoryAsContextTransformer(
279 | num_tokens = 256,
280 | dim = 16,
281 | depth = 4,
282 | segment_len = 32,
283 | num_persist_mem_tokens = 4,
284 | num_longterm_mem_tokens = longterm_mems,
285 | sliding_window_attn = sliding,
286 | neural_memory_layers = mem_layers,
287 | neural_mem_gate_attn_output = False
288 | )
289 |
290 | ids = torch.randint(0, 256, (1, 1023))
291 |
292 | # after much training
293 |
294 | prompt = ids[:, :prompt_len]
295 |
296 | sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
297 | sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)
298 |
299 | assert torch.allclose(sampled, sampled_with_cache)
300 |
301 | @pytest.mark.parametrize('seq_len', (2, 64, 256))
302 | @pytest.mark.parametrize('prompt_len', (0, 65))
303 | @pytest.mark.parametrize('mem_chunk_size', (2, 32, 64))
304 | @pytest.mark.parametrize('gated_transition', (False, True))
305 | @torch_default_dtype(torch.float64)
306 | def test_neural_mem_inference(
307 | seq_len,
308 | prompt_len,
309 | mem_chunk_size,
310 | gated_transition
311 | ):
312 |
313 | mem = NeuralMemory(
314 | dim = 16,
315 | chunk_size = mem_chunk_size,
316 | gated_transition = gated_transition
317 | )
318 |
319 | seq = torch.randn(2, seq_len, 16)
320 | parallel_retrieved, _ = mem(seq)
321 |
322 | assert seq.shape == parallel_retrieved.shape
323 |
324 | state = None
325 | sequential_retrieved = []
326 |
327 | # test initial parallel prompt
328 |
329 | test_parallel_prompt = prompt_len > 0 and prompt_len < seq_len
330 |
331 | if test_parallel_prompt:
332 | prompt, seq = seq[:, :prompt_len], seq[:, prompt_len:]
333 | retrieved_prompt, state = mem(prompt)
334 | sequential_retrieved.append(retrieved_prompt)
335 |
336 | # sequential inference
337 |
338 | for token in seq.unbind(dim = 1):
339 |
340 | one_retrieved, state = mem.forward(
341 | token,
342 | state = state,
343 | )
344 |
345 | sequential_retrieved.append(one_retrieved)
346 |
347 | sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)
348 |
349 | assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-6)
350 |
351 | @pytest.mark.parametrize('seq_len', (1023, 17))
352 | @pytest.mark.parametrize('sliding', (True, False))
353 | def test_flex(
354 | seq_len,
355 | sliding
356 | ):
357 | if not (torch.cuda.is_available() and exists(flex_attention)):
358 | pytest.skip()
359 |
360 | attn = SegmentedAttention(
361 | dim = 16,
362 | segment_len = 32,
363 | num_persist_mem_tokens = 1,
364 | num_longterm_mem_tokens = 1,
365 | use_flex_attn = True,
366 | sliding = sliding
367 | ).cuda()
368 |
369 | seq = torch.randn(1, seq_len, 16).cuda()
370 |
371 | out_flex, _ = attn(seq)
372 | out_non_flex, _ = attn(seq, disable_flex_attn = True)
373 |
374 | assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
375 |
376 | @pytest.mark.parametrize('use_accelerated', (True, False))
377 | def test_assoc_scan(
378 | use_accelerated
379 | ):
380 | from titans_pytorch.neural_memory import AssocScan
381 |
382 | if use_accelerated and not torch.cuda.is_available():
383 | pytest.skip()
384 |
385 | scan = AssocScan(use_accelerated = use_accelerated)
386 |
387 | seq_len = 128
388 | mid_point = seq_len // 2
389 |
390 | gates = torch.randn(2, seq_len, 16).sigmoid()
391 | inputs = torch.randn(2, seq_len, 16)
392 |
393 | if use_accelerated:
394 | gates = gates.cuda()
395 | inputs = inputs.cuda()
396 |
397 | output = scan(gates, inputs)
398 |
399 | gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:]
400 | inputs1, inputs2 = inputs[:, :mid_point], inputs[:, mid_point:]
401 |
402 | first_half = scan(gates1, inputs1)
403 |
404 | second_half = scan(gates2, inputs2, prev = first_half[:, -1])
405 | assert second_half.shape == inputs2.shape
406 |
407 | assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
408 |
409 | def test_mem_state_detach():
410 | from titans_pytorch.neural_memory import mem_state_detach
411 |
412 | mem = NeuralMemory(
413 | dim = 384,
414 | chunk_size = 2,
415 | qk_rmsnorm = True,
416 | dim_head = 64,
417 | heads = 4,
418 | )
419 |
420 | seq = torch.randn(4, 64, 384)
421 |
422 | state = None
423 |
424 | for _ in range(2):
425 | parallel_retrieved, state = mem(seq, state = state)
426 | state = mem_state_detach(state)
427 | parallel_retrieved.sum().backward()
428 |
--------------------------------------------------------------------------------
/titans_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from titans_pytorch.neural_memory import (
2 | NeuralMemory,
3 | NeuralMemState,
4 | mem_state_detach
5 | )
6 |
7 | from titans_pytorch.memory_models import (
8 | MemoryMLP,
9 | MemoryAttention,
10 | FactorizedMemoryMLP,
11 | MemorySwiGluMLP,
12 | GatedResidualMemoryMLP
13 | )
14 |
15 | from titans_pytorch.mac_transformer import (
16 | MemoryAsContextTransformer
17 | )
18 |
--------------------------------------------------------------------------------
/titans_pytorch/mac_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | from math import ceil
5 | from copy import deepcopy
6 | from functools import partial
7 | from collections import namedtuple
8 |
9 | import tqdm
10 |
11 | import torch
12 | from torch import nn, stack, cat, Tensor
13 | import torch.nn.functional as F
14 | from torch.nn import Module, ModuleList, Linear
15 |
16 | # flex attention
17 | # https://pytorch.org/blog/flexattention/
18 |
19 | flex_attention = None
20 |
21 | try:
22 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask
23 | if torch.cuda.is_available():
24 | flex_attention = torch.compile(flex_attention)
25 | except ImportError:
26 | pass
27 |
28 | def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False):
29 |
30 | def create_mac_mask(_, __, q_idx, kv_idx):
31 | is_persist_mem = kv_idx < persist_mem_len
32 | kv_without_mem = kv_idx - persist_mem_len
33 | causal_mask = q_idx >= kv_without_mem
34 |
35 | if not sliding:
36 | block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
37 | causal_mask = causal_mask & block_diagonal
38 | else:
39 | sliding_mask = (q_idx - kv_without_mem) <= window_size
40 | causal_mask = causal_mask & sliding_mask
41 |
42 | return is_persist_mem | (~is_persist_mem & causal_mask)
43 |
44 | block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
45 | return block_mask
46 |
47 | # einstein notation related
48 |
49 | from einops import repeat, rearrange, pack, unpack, einsum
50 | from einops.layers.torch import Rearrange
51 |
52 | # b - batch
53 | # n - sequence
54 | # h - heads
55 | # d - feature dimension
56 |
57 | # absolute and relative positions
58 |
59 | from axial_positional_embedding import ContinuousAxialPositionalEmbedding
60 | from rotary_embedding_torch import RotaryEmbedding
61 |
62 | # hyper connections / attend from x-transformers, which handles different queries and key lengths better
63 |
64 | from x_transformers.attend import Attend
65 |
66 | from hyper_connections import get_init_and_expand_reduce_stream_functions
67 |
68 | # proposed neural memory
69 |
70 | from titans_pytorch.neural_memory import NeuralMemory
71 |
72 | # constants
73 |
74 | LinearNoBias = partial(Linear, bias = False)
75 |
76 | AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
77 |
78 | # helpers
79 |
80 | def exists(v):
81 | return v is not None
82 |
83 | def default(v, d):
84 | return v if exists(v) else d
85 |
86 | def identity(t):
87 | return t
88 |
89 | def divisible_by(num, den):
90 | return (num % den) == 0
91 |
92 | def round_up_multiple(seq, mult):
93 | return ceil(seq / mult) * mult
94 |
95 | def round_down_multiple(seq, mult):
96 | return seq // mult * mult
97 |
98 | def pack_with_inverse(t, pattern):
99 | packed, packed_shape = pack(t, pattern)
100 |
101 | def inverse(out, inv_pattern = None):
102 | return unpack(out, packed_shape, default(inv_pattern, pattern))
103 |
104 | return packed, inverse
105 |
106 | def pad_at_dim(t, pad, dim = -1, value = 0.):
107 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
108 | zeros = ((0, 0) * dims_from_right)
109 | return F.pad(t, (*zeros, *pad), value = value)
110 |
111 | def pad_and_segment_with_inverse(
112 | seq,
113 | segment_len,
114 | fold_into_batch = True,
115 | inverse_remove_pad = True
116 | ):
117 | batch, seq_len = seq.shape[:2]
118 | next_seq_len_mult = round_up_multiple(seq_len, segment_len)
119 |
120 | padding = next_seq_len_mult - seq_len
121 | needs_pad = padding > 0
122 |
123 | if needs_pad:
124 | seq = F.pad(seq, (0, 0, 0, padding))
125 |
126 | if fold_into_batch:
127 | seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
128 |
129 | def inverse(out):
130 |
131 | if fold_into_batch:
132 | out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
133 |
134 | if needs_pad and inverse_remove_pad:
135 | out = out[..., :-padding, :]
136 |
137 | return out
138 |
139 | return seq, inverse
140 |
141 | # sampling related
142 |
143 | def log(t, eps = 1e-20):
144 | return torch.log(t.clamp(min = eps))
145 |
146 | def gumbel_noise(t):
147 | noise = torch.rand_like(t)
148 | return -log(-log(noise))
149 |
150 | def gumbel_sample(t, temperature = 1.):
151 | if temperature > 0.:
152 | t = t / temperature + gumbel_noise(t)
153 | return t.argmax(dim = -1, keepdim = True)
154 |
155 | # min_p
156 | # https://arxiv.org/abs/2407.01082
157 |
158 | def min_p_filter(logits, min_p = 0.1):
159 | probs = logits.softmax(dim = -1)
160 | max_probs = probs.amax(dim = -1, keepdim = True)
161 | limit = min_p * max_probs
162 | return torch.where(probs < limit, float('-inf'), logits)
163 |
164 | # feedforward and attention
165 |
166 | class GEGLU(Module):
167 | def forward(self, x):
168 | x, gate = x.chunk(2, dim = -1)
169 | return F.silu(gate) * x
170 |
171 | def FeedForward(dim, mult = 4):
172 | dim_inner = int(dim * mult * 2 / 3)
173 |
174 | return nn.Sequential(
175 | nn.RMSNorm(dim),
176 | nn.Linear(dim, dim_inner * 2),
177 | GEGLU(),
178 | nn.Linear(dim_inner, dim)
179 | )
180 |
181 | class SegmentedAttention(Module):
182 | def __init__(
183 | self,
184 | dim,
185 | segment_len,
186 | num_persist_mem_tokens = 0,
187 | num_longterm_mem_tokens = 0,
188 | dim_head = 64,
189 | heads = 8,
190 | sliding = False,
191 | accept_value_residual = False,
192 | attend_kwargs: dict = dict(),
193 | use_flex_attn = False
194 | ):
195 | super().__init__()
196 | self.norm = nn.RMSNorm(dim)
197 |
198 | dim_inner = dim_head * heads
199 |
200 | self.rotary_emb = RotaryEmbedding(dim_head)
201 |
202 | self.attend = Attend(causal = True, **attend_kwargs)
203 |
204 | self.to_qkv = LinearNoBias(dim, dim_inner * 3)
205 | self.to_out = LinearNoBias(dim_inner, dim)
206 |
207 | self.to_learned_v_mix = nn.Sequential(
208 | nn.Linear(dim, heads),
209 | Rearrange('b n h -> b h n 1'),
210 | nn.Sigmoid()
211 | ) if accept_value_residual else None
212 |
213 | self.segment_len = segment_len
214 | self.num_longterm_mem_tokens = num_longterm_mem_tokens
215 |
216 | total_segment_len = segment_len + num_longterm_mem_tokens
217 | self.total_segment_len = total_segment_len
218 |
219 | self.sliding = sliding # sliding window attn - doubt their non-sliding results being the best. local attention with overlapping windows is very strong
220 |
221 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
222 | self.merge_heads = Rearrange('b h n d -> b n (h d)')
223 |
224 | self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
225 |
226 | # flex attn related
227 |
228 | assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
229 | self.use_flex_attn = use_flex_attn
230 |
231 | self.segment_len = segment_len
232 | self.num_persist_mem_tokens = num_persist_mem_tokens
233 |
234 | def forward_inference(
235 | self,
236 | token,
237 | cache,
238 | value_residual = None,
239 | output_gating = None,
240 | ):
241 | batch = token.shape[0]
242 |
243 | # attention
244 |
245 | token = self.norm(token)
246 |
247 | q, k, v = self.to_qkv(token).chunk(3, dim = -1)
248 | q, k, v = map(self.split_heads, (q, k, v))
249 |
250 | # value residual
251 |
252 | orig_v = v
253 |
254 | if exists(self.to_learned_v_mix):
255 | mix = self.to_learned_v_mix(token)
256 | v = v.lerp(value_residual, mix)
257 |
258 | # caching
259 |
260 | ck, cv = cache
261 | k = cat((ck, k), dim = -2)
262 | v = cat((cv, v), dim = -2)
263 |
264 | next_cache = (k, v)
265 |
266 | # relative positions
267 |
268 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
269 |
270 | # fold
271 |
272 | q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
273 |
274 | # take care of persistent memory key / values
275 |
276 | pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
277 |
278 | # persistent memory
279 |
280 | k = cat((pmk, k), dim = -2)
281 | v = cat((pmv, v), dim = -2)
282 |
283 | # attention
284 |
285 | out, _ = self.attend(q, k, v)
286 |
287 | out = self.merge_heads(out)
288 |
289 | out = self.to_out(out)
290 |
291 | if exists(output_gating):
292 | out = out * output_gating
293 |
294 | return out, AttnIntermediates(orig_v, next_cache)
295 |
296 | def forward_flex(
297 | self,
298 | seq,
299 | value_residual = None,
300 | flex_attn_fn: Callable | None = None,
301 | output_gating = None,
302 | cache = None
303 | ):
304 |
305 | assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
306 |
307 | batch, seq_len = seq.shape[:2]
308 |
309 | # attention
310 |
311 | seq = self.norm(seq)
312 |
313 | q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
314 | q, k, v = map(self.split_heads, (q, k, v))
315 |
316 | # value residual
317 |
318 | orig_v = v
319 |
320 | if exists(self.to_learned_v_mix):
321 | mix = self.to_learned_v_mix(seq)
322 | v = v.lerp(value_residual, mix)
323 |
324 | # caching
325 |
326 | next_cache = (k, v)
327 |
328 | # take care of persistent memory key / values
329 |
330 | pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
331 |
332 | # relative positions
333 |
334 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
335 |
336 | # persistent memory
337 |
338 | k = cat((pmk, k), dim = -2)
339 | v = cat((pmv, v), dim = -2)
340 |
341 | # prep flex attention
342 |
343 | if not exists(flex_attn_fn):
344 | block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens, self.sliding)
345 |
346 | flex_attn_fn = partial(flex_attention, block_mask = block_mask)
347 |
348 | # attention
349 |
350 | out = flex_attn_fn(q, k, v)
351 |
352 | out = self.merge_heads(out)
353 |
354 | out = self.to_out(out)
355 |
356 | if exists(output_gating):
357 | out = out * output_gating
358 |
359 | return out, AttnIntermediates(orig_v, next_cache)
360 |
361 | def forward(
362 | self,
363 | seq,
364 | value_residual = None,
365 | flex_attn_fn: Callable | None = None,
366 | disable_flex_attn = False,
367 | output_gating = None,
368 | cache = None
369 | ):
370 | is_inferencing = exists(cache)
371 |
372 | if is_inferencing:
373 | assert seq.shape[-2] == 1
374 | return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
375 |
376 | if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
377 | return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)
378 |
379 | assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
380 |
381 | segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
382 | total_segment_len = segment_len + num_longterm_mem_tokens
383 |
384 | batch, seq_len = seq.shape[:2]
385 |
386 | # auto pad to multiple
387 |
388 | seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len, fold_into_batch = False)
389 |
390 | # attention
391 |
392 | seq = self.norm(seq)
393 |
394 | q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
395 | q, k, v = map(self.split_heads, (q, k, v))
396 |
397 | # value residual
398 |
399 | orig_v = v
400 |
401 | if exists(self.to_learned_v_mix):
402 | mix = self.to_learned_v_mix(seq)
403 | v = v.lerp(value_residual, mix)
404 |
405 | # caching
406 |
407 | next_cache = tuple(map(inverse_segment, (k, v)))
408 |
409 | # relative positions
410 |
411 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
412 |
413 | # fold
414 |
415 | q, k, v = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = total_segment_len) for t in (q, k, v))
416 |
417 | # maybe sliding for cpu
418 |
419 | attend_kwargs = dict()
420 |
421 | if self.sliding:
422 | k, v = tuple(rearrange(t, '(b w) ... -> b w ...', b = batch) for t in (k, v))
423 | k, v = tuple(pad_at_dim(t, (1, 0), value = 0., dim = 1) for t in (k, v))
424 | k = cat((k[:, :-1], k[:, 1:]), dim = -2)
425 | v = cat((v[:, :-1], v[:, 1:]), dim = -2)
426 | k, v = tuple(rearrange(t, 'b w ... -> (b w) ...') for t in (k, v))
427 |
428 | # take care of masking
429 |
430 | idx = torch.arange(seq.shape[-2], device = seq.device)
431 | q_idx = rearrange(idx, '(w n) -> w n', n = total_segment_len)
432 | k_idx = pad_at_dim(q_idx, (1, 0), dim = 0, value = -1e4)
433 | k_idx = cat((k_idx[:-1], k_idx[1:]), dim = -1)
434 |
435 | q_idx = rearrange(q_idx, 'w i -> w i 1')
436 | k_idx = rearrange(k_idx, 'w j -> w 1 j')
437 |
438 | sliding_mask = (q_idx - k_idx) <= total_segment_len
439 | sliding_mask = F.pad(sliding_mask, (self.num_persist_mem_tokens, 0), value = True)
440 |
441 | sliding_mask = repeat(sliding_mask, 'w i j -> (b w) 1 i j', b = batch)
442 | attend_kwargs.update(mask = sliding_mask)
443 |
444 | # take care of persistent memory key / values
445 |
446 | pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
447 |
448 | # persistent memory
449 |
450 | k = cat((pmk, k), dim = -2)
451 | v = cat((pmv, v), dim = -2)
452 |
453 | # attention
454 |
455 | out, _ = self.attend(q, k, v, **attend_kwargs)
456 |
457 | out = self.merge_heads(out)
458 |
459 | out = self.to_out(out)
460 |
461 | out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
462 |
463 | out = inverse_segment(out)
464 |
465 | if exists(output_gating):
466 | out = out * output_gating
467 |
468 | return out, AttnIntermediates(orig_v, next_cache)
469 |
470 | # MAC transformer
471 |
472 | class MemoryAsContextTransformer(Module):
473 | def __init__(
474 | self,
475 | *,
476 | num_tokens,
477 | dim,
478 | depth,
479 | segment_len,
480 | neural_memory_segment_len = None,
481 | neural_mem_gate_attn_output = False,
482 | neural_memory_add_value_residual = False,
483 | num_longterm_mem_tokens = 0,
484 | num_persist_mem_tokens = 0,
485 | neural_memory_batch_size = None,
486 | neural_memory_qkv_receives_diff_views = False,
487 | dim_head = 64,
488 | heads = 8,
489 | ff_mult = 4,
490 | num_residual_streams = 4,
491 | neural_memory_model: Module | None = None,
492 | neural_memory_kwargs: dict = dict(),
493 | neural_memory_layers: tuple[int, ...] | None = None,
494 | use_flex_attn = False,
495 | sliding_window_attn = False,
496 | neural_mem_weight_residual = False,
497 | token_emb: Module | None = None,
498 | ):
499 | super().__init__()
500 |
501 | if not exists(token_emb):
502 | token_emb = nn.Embedding(num_tokens, dim)
503 |
504 | self.token_emb = token_emb
505 |
506 | # absolute positions
507 |
508 | self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
509 |
510 | # long term mem tokens
511 |
512 | self.segment_len = segment_len
513 |
514 | self.num_longterm_mem_tokens = num_longterm_mem_tokens
515 | has_longterm_mems = num_longterm_mem_tokens > 0
516 |
517 | self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
518 |
519 | # maybe sliding window attn
520 |
521 | self.sliding_window_attn = sliding_window_attn
522 | self.attn_window_size = segment_len + num_longterm_mem_tokens
523 |
524 | # hyper connection
525 |
526 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
527 |
528 | self.layers = ModuleList([])
529 |
530 | self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len)
531 |
532 | layers = tuple(range(1, depth + 1))
533 |
534 | neural_memory_layers = default(neural_memory_layers, layers)
535 |
536 | # weight residual related
537 |
538 | self.neural_mem_weight_residual = neural_mem_weight_residual
539 | is_first_neural_mem = True
540 |
541 | # mem, attn, and feedforward layers
542 |
543 | for layer in layers:
544 | is_first = layer == 1
545 |
546 | # attention and feedforward
547 |
548 | attn = SegmentedAttention(
549 | dim = dim,
550 | dim_head = dim_head,
551 | heads = heads,
552 | segment_len = segment_len,
553 | use_flex_attn = use_flex_attn,
554 | accept_value_residual = not is_first,
555 | num_longterm_mem_tokens = num_longterm_mem_tokens,
556 | num_persist_mem_tokens = num_persist_mem_tokens,
557 | sliding = sliding_window_attn
558 | )
559 |
560 | mem = None
561 | mem_qkv_layer_selector = None
562 | mem_hyper_conn = None
563 |
564 | if layer in neural_memory_layers:
565 | mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
566 |
567 | if not is_first and neural_memory_qkv_receives_diff_views:
568 | num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
569 |
570 | mem_qkv_layer_selector = nn.Sequential(
571 | nn.RMSNorm(dim),
572 | nn.Linear(dim, 3 * num_layer_choices),
573 | Rearrange('... (views layers) -> views ... layers', views = 3),
574 | nn.Softmax(dim = -1)
575 | )
576 |
577 | mem = NeuralMemory(
578 | dim = dim,
579 | chunk_size = self.neural_memory_segment_len,
580 | batch_size = neural_memory_batch_size,
581 | model = deepcopy(neural_memory_model),
582 | qkv_receives_diff_views = True,
583 | accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
584 | **neural_memory_kwargs
585 | )
586 |
587 | is_first_neural_mem = False
588 |
589 | ff = FeedForward(dim = dim, mult = ff_mult)
590 |
591 | self.layers.append(ModuleList([
592 | mem_hyper_conn,
593 | init_hyper_conn(),
594 | init_hyper_conn(),
595 | mem_qkv_layer_selector,
596 | mem,
597 | attn,
598 | ff,
599 | ]))
600 |
601 | self.norm = nn.RMSNorm(dim)
602 |
603 | self.to_logits = LinearNoBias(dim, num_tokens)
604 |
605 | # whether to gate the attention output with the retrieved memories
606 |
607 | self.gate_attn_output = neural_mem_gate_attn_output
608 |
609 | # zero for maybe aux loss + device
610 |
611 | self.register_buffer('zero', torch.tensor(0.), persistent = False)
612 |
613 | # flex attn related
614 |
615 | assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
616 | self.use_flex_attn = use_flex_attn
617 |
618 | self.num_persist_mem_tokens = num_persist_mem_tokens
619 |
620 | def seq_index_is_longterm(
621 | self,
622 | seq_index
623 | ):
624 | total_segment_len, segment_len = self.attn_window_size, self.segment_len
625 | return ((seq_index % total_segment_len + 1) - segment_len) > 0
626 |
627 | def seq_len_with_longterm_mem(
628 | self,
629 | seq_len
630 | ):
631 | assert seq_len > 0
632 |
633 | segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
634 | return ((seq_len - 1) // segment_len) * num_mem + seq_len
635 |
636 | @torch.no_grad()
637 | def sample(
638 | self,
639 | prompt: Tensor,
640 | seq_len: int,
641 | temperature = 1.5,
642 | filter_fn: Callable = min_p_filter,
643 | filter_kwargs: dict = dict(
644 | min_p = 0.1,
645 | ),
646 | show_progress = True,
647 | use_cache = False
648 | ):
649 | was_training = self.training
650 | self.eval()
651 |
652 | prompt_seq_len, out = prompt.shape[-1], prompt.clone()
653 | sample_num_times = max(0, seq_len - prompt_seq_len)
654 |
655 | # cache for axial pos, attention, and neural memory
656 |
657 | cache = None
658 | factorized_pos_emb = None
659 |
660 | # precompute factorized pos emb
661 |
662 | if use_cache:
663 | seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
664 |
665 | axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
666 |
667 | factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
668 |
669 | # sample
670 |
671 | with tqdm.tqdm(total = sample_num_times, disable = not show_progress) as pbar:
672 |
673 | while out.shape[-1] < seq_len:
674 |
675 | logits, next_cache = self.forward(
676 | out,
677 | disable_flex_attn = True,
678 | cache = cache,
679 | return_cache = True,
680 | factorized_pos_emb = factorized_pos_emb
681 | )
682 |
683 | if use_cache:
684 | cache = next_cache
685 |
686 | if not exists(logits):
687 | continue
688 |
689 | logits = logits[:, -1]
690 |
691 | logits = filter_fn(logits, **filter_kwargs)
692 | sample = gumbel_sample(logits, temperature = temperature)
693 |
694 | out = torch.cat((out, sample), dim = -1)
695 | pbar.update(1)
696 |
697 | self.train(was_training)
698 |
699 | return out[..., prompt_seq_len:]
700 |
701 | def forward(
702 | self,
703 | x,
704 | return_loss = False,
705 | return_loss_breakdown = False,
706 | disable_flex_attn = False,
707 | cache = None,
708 | return_cache = False,
709 | factorized_pos_emb = None
710 | ):
711 |
712 | if return_loss:
713 | x, labels = x[:, :-1], x[:, 1:]
714 |
715 | # math
716 |
717 | batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
718 |
719 | seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
720 |
721 | # token embedding
722 |
723 | x = self.token_emb(x)
724 |
725 | # intersperse longterm memory
726 |
727 | x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False)
728 |
729 | mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
730 | x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
731 |
732 | x = inverse_segment(x)
733 |
734 | # splice out unneeded tokens from padding for longterm mems
735 |
736 | x = x[:, :seq_len_with_mem]
737 |
738 | # apply axial positional embedding
739 | # so intra and inter segment can be more easily discerned by the network
740 |
741 | pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
742 |
743 | x = x + pos_emb
744 |
745 | # prep flex attention
746 |
747 | use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn
748 |
749 | flex_attn_fn = None
750 |
751 | if use_flex_attn:
752 | block_mask = create_mac_block_mask(seq_len_with_mem, self.attn_window_size, self.num_persist_mem_tokens, self.sliding_window_attn)
753 | flex_attn_fn = partial(flex_attention, block_mask = block_mask)
754 |
755 | # kv caching
756 |
757 | is_inferencing = exists(cache)
758 |
759 | if not exists(cache):
760 | cache = (seq_len_with_mem - 1, None, None)
761 |
762 | inference_seq_index, kv_caches, neural_mem_caches = cache
763 |
764 | kv_caches = iter(default(kv_caches, []))
765 | neural_mem_caches = iter(default(neural_mem_caches, []))
766 |
767 | next_kv_caches = []
768 | next_neural_mem_caches = []
769 |
770 | # value residual
771 |
772 | value_residual = None
773 |
774 | # neural mem weight residual
775 |
776 | mem_weight_residual = None
777 |
778 | # layers for the neural mem to select the qkv inputs from
779 |
780 | mem_input_layers = []
781 |
782 | # when inferencing, only do one token at a time
783 |
784 | if is_inferencing:
785 | ind = inference_seq_index
786 | x = x[:, ind:(ind + 1)]
787 |
788 | # expand and reduce streams for hyper connections
789 |
790 | x = self.expand_streams(x)
791 |
792 | for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
793 |
794 | retrieved = None
795 | attn_out_gates = None
796 | next_neural_mem_cache = None
797 |
798 | # maybe neural memory
799 |
800 | if exists(mem):
801 |
802 | mem_input, add_residual = mem_hyper_conn(x)
803 |
804 | if not exists(mem_qkv_layer_selector):
805 | qkv_mem_input = stack((mem_input, mem_input, mem_input))
806 | else:
807 | layers_to_choose_from = stack((mem_input, *mem_input_layers))
808 |
809 | # let the current `mem_input` select the 3 layers for qkv
810 |
811 | selected = mem_qkv_layer_selector(mem_input)
812 |
813 | qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
814 |
815 | retrieved, next_neural_mem_cache = mem.forward(
816 | qkv_mem_input,
817 | state = next(neural_mem_caches, None),
818 | prev_weights = mem_weight_residual
819 | )
820 |
821 | if self.neural_mem_weight_residual:
822 | mem_weight_residual = next_neural_mem_cache.updates
823 |
824 | if self.gate_attn_output:
825 | attn_out_gates = retrieved.sigmoid()
826 | else:
827 | x = add_residual(retrieved)
828 |
829 | # attention
830 |
831 | attn_in, add_residual = attn_hyper_conn(x)
832 |
833 | mem_input_layers.append(attn_in)
834 |
835 | attn_out, (values, next_kv_cache) = attn(
836 | attn_in,
837 | value_residual = value_residual,
838 | disable_flex_attn = disable_flex_attn,
839 | flex_attn_fn = flex_attn_fn,
840 | output_gating = attn_out_gates,
841 | cache = next(kv_caches, None)
842 | )
843 |
844 | mem_input_layers.append(attn_out)
845 |
846 | value_residual = default(value_residual, values)
847 |
848 | x = add_residual(attn_out)
849 |
850 | # caches
851 |
852 | next_kv_caches.append(next_kv_cache)
853 | next_neural_mem_caches.append(next_neural_mem_cache)
854 |
855 | # feedforward
856 |
857 | ff_in, add_ff_residual = ff_hyper_conn(x)
858 |
859 | mem_input_layers.append(ff_in)
860 |
861 | ff_out = ff(ff_in)
862 |
863 | mem_input_layers.append(ff_out)
864 |
865 | x = add_ff_residual(ff_out)
866 |
867 | # taking care of cache first
868 | # for early return when processing long term mem tokens during inference
869 |
870 | if return_cache:
871 | next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
872 |
873 | # handle kv cache length depending on local attention type
874 |
875 | next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
876 |
877 | kv_cache_length = next_kv_caches.shape[-2]
878 |
879 | if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size):
880 | next_kv_caches = next_kv_caches[..., 0:0, :]
881 |
882 | next_cache = (
883 | inference_seq_index + 1,
884 | next_kv_caches,
885 | next_neural_mem_caches
886 | )
887 |
888 | is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
889 |
890 | if is_inferencing and is_longterm_mem:
891 | return None, next_cache
892 |
893 | # hyper connection reducing of streams
894 |
895 | x = self.reduce_streams(x)
896 |
897 | # excise out the memories
898 |
899 | if not is_inferencing:
900 |
901 | x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False)
902 |
903 | x, _ = inverse_pack_mems(x)
904 |
905 | x = inverse_segment(x)
906 |
907 | x = x[:, :seq_len]
908 |
909 | # to logits
910 |
911 | x = self.norm(x)
912 |
913 | logits = self.to_logits(x)
914 |
915 | if not return_loss:
916 | if not return_cache:
917 | return logits
918 |
919 | return logits, next_cache
920 |
921 | return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
922 |
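# editor's sketch (hedged, not part of the file): how the cache returned above can drive
# incremental decoding. the repository's own `sample` method (used in train_mac.py) is the
# reference; this only illustrates the (logits | None, next_cache) contract, where logits is
# None whenever the current index lands on a long term memory token

@torch.no_grad()
def greedy_decode_sketch(model, ids, num_tokens):
    # prime on the full prompt, asking for the cache
    logits, cache = model(ids, return_cache = True)

    for _ in range(num_tokens):
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        ids = torch.cat((ids, next_token), dim = -1)

        logits, cache = model(ids, cache = cache, return_cache = True)

        # long term memory positions only advance the cache - call again until real logits
        while not exists(logits):
            logits, cache = model(ids, cache = cache, return_cache = True)

    return ids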
--------------------------------------------------------------------------------
/titans_pytorch/memory_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, cat
3 | import torch.nn.functional as F
4 | from torch.nn import Module, ModuleList, Parameter, ParameterList
5 |
6 | from einops import rearrange
7 |
8 | # functions
9 |
10 | def l2norm(t):
11 | return F.normalize(t, dim = -1)
12 |
13 | # norms
14 |
15 | class LayerNorm(Module):
16 | def __init__(
17 | self,
18 | dim
19 | ):
20 | super().__init__()
21 |
22 | self.ln = nn.LayerNorm(dim, elementwise_affine = False)
23 | self.gamma = Parameter(torch.zeros(dim))
24 |
25 | def forward(self, x):
26 | gamma = self.gamma
27 |
28 | if gamma.ndim == 2:
29 | gamma = rearrange(gamma, 'b d -> b 1 d')
30 |
31 | return self.ln(x) * (gamma + 1.)
32 |
33 | # norm + residual wrapper, as used in original TTT paper
34 | # but could be removed
35 |
36 | class ResidualNorm(Module):
37 | def __init__(
38 | self,
39 | dim,
40 | model: Module
41 | ):
42 | super().__init__()
43 | self.norm = LayerNorm(dim)
44 | self.model = model
45 |
46 | def forward(self, x):
47 |
48 | out = self.model(x)
49 |
50 | return self.norm(out) + x
51 |
52 | # memory mlp proposed in TTT
53 |
54 | class MemoryMLP(Module):
55 | def __init__(
56 | self,
57 | dim,
58 | depth,
59 | expansion_factor = 2.
60 | ):
61 | super().__init__()
62 | dim_hidden = int(dim * expansion_factor)
63 | dims = (dim, *((dim_hidden,) * (depth - 1)), dim)
64 |
65 | self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
66 |
67 | for weight in self.weights:
68 | nn.init.xavier_uniform_(weight)
69 |
70 | def forward(
71 | self,
72 | x
73 | ):
74 | for ind, weight in enumerate(self.weights):
75 | is_first = ind == 0
76 |
77 | if not is_first:
78 | x = F.gelu(x)
79 |
80 | x = x @ weight
81 |
82 | return x
83 |
84 | # memory mlp, but with gated residual + final projection
85 |
86 | class GatedResidualMemoryMLP(Module):
87 | def __init__(
88 | self,
89 | dim,
90 | depth,
91 | expansion_factor = 4.
92 | ):
93 | super().__init__()
94 | dim_hidden = int(dim * expansion_factor)
95 |
96 | self.weights = ParameterList([
97 | ParameterList([
98 | Parameter(torch.randn(dim, dim_hidden)),
99 | Parameter(torch.randn(dim_hidden, dim)),
100 | Parameter(torch.randn(dim * 2, dim)),
101 | ]) for _ in range(depth)
102 | ])
103 |
104 | self.final_proj = Parameter(torch.randn(dim, dim))
105 |
106 | for param in self.parameters():
107 | nn.init.xavier_uniform_(param)
108 |
109 | def forward(
110 | self,
111 | x
112 | ):
113 |
114 | for weight1, weight2, to_gates in self.weights:
115 | res = x
116 |
117 | hidden = x @ weight1
118 | hidden = F.gelu(hidden)
119 | branch_out = hidden @ weight2
120 |
121 | # gated residual
122 |
123 | gates = cat((branch_out, res), dim = -1) @ to_gates
124 | x = res.lerp(branch_out, gates.sigmoid())
125 |
126 | return x @ self.final_proj
127 |
128 | # memory mlp with factorized weights
129 | # so can tradeoff capacity for smaller chunk sizes
130 |
131 | class FactorizedMemoryMLP(Module):
132 | def __init__(
133 | self,
134 | dim,
135 | depth,
136 | k = 32
137 | ):
138 | super().__init__()
139 | self.weights = ParameterList([
140 | ParameterList([
141 | Parameter(torch.randn(dim, k)),
142 | Parameter(torch.randn(k, dim)),
143 | ]) for _ in range(depth)
144 | ])
145 |
146 | for weight1, weight2 in self.weights:
147 | nn.init.xavier_uniform_(weight1)
148 | nn.init.xavier_uniform_(weight2)
149 |
150 | def forward(
151 | self,
152 | x
153 | ):
154 |
155 | for ind, (weight1, weight2) in enumerate(self.weights):
156 | is_first = ind == 0
157 |
158 | if not is_first:
159 | x = F.gelu(x)
160 |
161 | x = x @ weight1 @ weight2
162 |
163 | return x
164 |
165 | # an MLP modelled after the popular swiglu ff in modern transformers
166 |
167 | class MemorySwiGluMLP(Module):
168 | def __init__(
169 | self,
170 | dim,
171 |         depth = 1, # the default of 1 gives the 2 layer MLP from TTT; a depth of 2 would be a 4 layer MLP, realized as 2 feedforwards with residuals
172 | expansion_factor = 4.
173 | ):
174 | super().__init__()
175 |
176 | dim_inner = int(dim * expansion_factor * 2 / 3)
177 |
178 | weights = []
179 |
180 | for _ in range(depth):
181 | weights.append(ParameterList([
182 | Parameter(torch.randn(dim, dim_inner * 2)),
183 | Parameter(torch.randn(dim_inner, dim)),
184 | ]))
185 |
186 | self.weights = ParameterList(weights)
187 | self.norm = LayerNorm(dim)
188 |
189 | def forward(self, x):
190 |
191 | for w1, w2 in self.weights:
192 | residual = x
193 |
194 | x, gates = (x @ w1).chunk(2, dim = -1)
195 |
196 | x = x * F.gelu(gates)
197 |
198 | x = x @ w2
199 |
200 | x = x + residual
201 |
202 | return self.norm(x)
203 |
204 | # improvised attention as memory module
205 |
206 | class MemoryAttention(Module):
207 | def __init__(
208 | self,
209 | dim,
210 | scale = 8.,
211 | expansion_factor = 2.
212 | ):
213 | super().__init__()
214 | self.scale = scale
215 | dim_ff_hidden = int(dim * expansion_factor)
216 |
217 | self.weights = ParameterList([
218 | Parameter(torch.randn(dim, dim)), # queries
219 | Parameter(torch.randn(dim, dim)), # keys
220 | Parameter(torch.randn(dim, dim)), # values
221 | Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
222 | Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
223 | ])
224 |
225 | for weight in self.weights:
226 | nn.init.xavier_uniform_(weight)
227 |
228 | def forward(self, x):
229 |
230 | wq, wk, wv, ffw1, ffw2 = self.weights
231 |
232 | q = l2norm(x @ wq)
233 | k = l2norm(x @ wk)
234 | v = x @ wv
235 |
236 | attn_out = F.scaled_dot_product_attention(
237 | q, k, v,
238 | scale = self.scale,
239 | is_causal = True
240 | )
241 |
242 | # parallel attention + feedforward block
243 |         # as in PaLM + GPT-J
244 |
245 | h = F.gelu(x @ ffw1)
246 | ff_out = h @ ffw2
247 |
248 | return attn_out + ff_out
249 |
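# a quick, hedged sanity check (editor's addition, not part of the library): every memory
# model above preserves the trailing feature dimension and keeps its state purely in
# parameters (no buffers) - the only contract `NeuralMemory` in neural_memory.py validates
# before wrapping a custom model

if __name__ == '__main__':
    dim = 64
    x = torch.randn(3, 2, dim)

    candidates = (
        MemoryMLP(dim, depth = 2),
        GatedResidualMemoryMLP(dim, depth = 2),
        FactorizedMemoryMLP(dim, depth = 2),
        MemorySwiGluMLP(dim),
        MemoryAttention(dim),
    )

    for model in candidates:
        assert model(x).shape == x.shape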
--------------------------------------------------------------------------------
/titans_pytorch/neural_memory.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Callable
3 |
4 | import math
5 | from functools import partial
6 | from itertools import zip_longest
7 | from collections import namedtuple
8 |
9 | import torch
10 | from torch import nn, stack, cat, is_tensor, tensor, Tensor
11 | import torch.nn.functional as F
12 | from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
13 | from torch.func import functional_call, vmap, grad
14 | from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
15 |
16 | from tensordict import TensorDict
17 |
18 | from assoc_scan import AssocScan
19 |
20 | from titans_pytorch.memory_models import (
21 | MemoryMLP,
22 | ResidualNorm
23 | )
24 |
25 | import einx
26 | from einops import einsum, rearrange, repeat, reduce, pack, unpack
27 | from einops.layers.torch import Rearrange, Reduce
28 |
29 | """
30 | ein notation:
31 | b - batch
32 | h - heads
33 | bh - batch and heads
34 | n - sequence
35 | d - feature dimension
36 | c - intra-chunk
37 | w - num memory network weight parameters
38 | o - momentum orders
39 | u - key / value updates - allowing a token to emit multiple key / values
40 | """
41 |
42 | LinearNoBias = partial(Linear, bias = False)
43 |
44 | # neural mem state related
45 |
46 | NeuralMemState = namedtuple('NeuralMemState', [
47 | 'seq_index',
48 | 'weights',
49 | 'cache_store_segment',
50 | 'states',
51 | 'updates',
52 | ])
53 |
54 | def mem_state_detach(
55 | state: NeuralMemState
56 | ):
57 | assert isinstance(state, NeuralMemState)
58 | state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
59 | return NeuralMemState(*state)
60 |
61 | # functions
62 |
63 | def exists(v):
64 | return v is not None
65 |
66 | def default(*args):
67 | for arg in args:
68 | if exists(arg):
69 | return arg
70 | return None
71 |
72 | def identity(t):
73 | return t
74 |
75 | def xnor(x, y):
76 | return not (x ^ y)
77 |
78 | def divisible_by(num, den):
79 | return (num % den) == 0
80 |
81 | def safe_cat(inputs, dim = -2):
82 | inputs = tuple(filter(exists, inputs))
83 |
84 | if len(inputs) == 0:
85 | return None
86 | elif len(inputs) == 1:
87 | return inputs[0]
88 |
89 | return cat(inputs, dim = dim)
90 |
91 | def is_empty_tensor(t):
92 | return t.numel() == 0
93 |
94 | def dict_get_value_shapes(td):
95 | return [v.shape for k, v in td.items()]
96 |
97 | def rearrange_dict_values(td, pattern, **kwargs):
98 | return td.apply(lambda t: rearrange(t, pattern, **kwargs))
99 |
100 | def repeat_dict_values(td, pattern, **kwargs):
101 | return td.apply(lambda t: repeat(t, pattern, **kwargs))
102 |
103 | def pair(v):
104 | return (v, v) if not isinstance(v, tuple) else v
105 |
106 | def round_down_multiple(seq, mult):
107 | return seq // mult * mult
108 |
109 | def round_up_multiple(seq, mult):
110 | return math.ceil(seq / mult) * mult
111 |
112 | def pad_at_dim(t, pad, dim = -1, value = 0.):
113 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
114 | zeros = ((0, 0) * dims_from_right)
115 | return F.pad(t, (*zeros, *pad), value = value)
116 |
117 | def pack_one_with_inverse(t, pattern):
118 | packed, packed_shape = pack([t], pattern)
119 |
120 | def inverse(out, inv_pattern = None):
121 | inv_pattern = default(inv_pattern, pattern)
122 | return unpack(out, packed_shape, inv_pattern)[0]
123 |
124 | return packed, inverse
125 |
126 | def Sequential(*modules):
127 | modules = [*filter(exists, modules)]
128 |
129 | if len(modules) == 0:
130 | return nn.Identity()
131 |
132 | if len(modules) == 1:
133 | return modules[0]
134 |
135 | return nn.Sequential(*modules)
136 |
137 | # softclamping gradients
138 |
139 | def softclamp_max(t, max_value):
140 | half_max_value = max_value / 2
141 | return ((t / half_max_value).tanh() * half_max_value) + half_max_value
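# note (editor's): softclamp_max smoothly squashes any real input into (0, max_value) -
# an input of 0. maps to max_value / 2, large positive inputs saturate at max_value, and
# large negative inputs approach 0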
142 |
143 | def softclamp_grad_norm(t, max_value):
144 | if is_empty_tensor(t):
145 | return t
146 |
147 | t, inverse = pack_one_with_inverse(t, 'bn *')
148 |
149 | norm = t.norm(dim = -1, keepdim = True)
150 | clamped_norm = softclamp_max(norm, max_value)
151 |
152 | t = t * (clamped_norm / norm)
153 | return inverse(t)
154 |
155 | # spectral norming the surprise update w/ newton schulz matrix iter
156 | # from Keller Jordan et al. (Muon, open sourced with the nanogpt speedruns), since adopted by two related works, Atlas and 'TTT done right'
157 |
158 | def newtonschulz5(
159 | t,
160 | steps = 5,
161 | eps = 1e-7,
162 | coefs = (3.4445, -4.7750, 2.0315)
163 | ):
164 | if t.ndim <= 3:
165 | return t
166 |
167 | shape = t.shape
168 | should_transpose = shape[-2] > shape[-1]
169 |
170 | if should_transpose:
171 | t = t.transpose(-1, -2)
172 |
173 | t, inv_pack = pack_one_with_inverse(t, '* i j')
174 | t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)
175 |
176 | a, b, c = coefs
177 |
178 | for _ in range(steps):
179 | A = t @ t.transpose(-1, -2)
180 | B = b * A + c * A @ A
181 | t = a * t + B @ t
182 |
183 | if should_transpose:
184 | t = t.transpose(-1, -2)
185 |
186 | return inv_pack(t)
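# rough illustration (editor's, hedged): applied to a batch of weight-shaped surprise
# updates, e.g. of shape (batch * heads, chunks, d_in, d_out), the iteration above
# approximately orthogonalizes each matrix - the update direction is kept while its
# singular values are pushed toward 1, the Muon-style normalization referenced by
# `spectral_norm_surprises` below. tensors with 3 or fewer dims (bias / gain vectors)
# pass through untouched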
187 |
188 | # multi head rmsnorm
189 |
190 | class MultiheadRMSNorm(Module):
191 | def __init__(self, dim, heads):
192 | super().__init__()
193 | self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
194 | self.gamma = Parameter(torch.zeros(heads, 1, dim))
195 |
196 | def forward(self, x):
197 | return self.rmsnorm(x) * (self.gamma + 1.)
198 |
199 | # chunk pooling
200 |
201 | class AveragePool(Module):
202 | def __init__(
203 | self,
204 | chunk_size
205 | ):
206 | super().__init__()
207 | self.chunk_size = chunk_size
208 |
209 | def forward(
210 | self,
211 | x,
212 | chunk_size = None
213 | ):
214 | chunk_size = default(chunk_size, self.chunk_size)
215 | return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size)
216 |
217 | class AttentionPool(Module):
218 | def __init__(
219 | self,
220 | dim,
221 | chunk_size
222 | ):
223 | """
224 |         taken from Enformer https://www.nature.com/articles/s41592-021-01252-x, in turn taken from earlier work
225 | """
226 | super().__init__()
227 | self.chunk_size = chunk_size
228 | self.to_attn_logits = nn.Linear(dim, dim)
229 |
230 | # default to average pool
231 |
232 | nn.init.zeros_(self.to_attn_logits.weight)
233 | nn.init.zeros_(self.to_attn_logits.bias)
234 |
235 | def forward(
236 | self,
237 | x,
238 | chunk_size = None
239 | ):
240 | chunk_size = default(chunk_size, self.chunk_size)
241 |
242 | x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)
243 |
244 | attn_logits = self.to_attn_logits(x)
245 |
246 | attn = attn_logits.softmax(dim = -2)
247 |
248 | return reduce(x * attn, 'b n c d -> b n d', 'sum')
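# shape sketch (editor's addition): both poolers collapse each chunk of `chunk_size`
# tokens into a single representative token, e.g.
#
#   pool = AveragePool(chunk_size = 4)
#   pool(torch.randn(2, 16, 64)).shape   # -> (2, 4, 64)
#
# this chunked representation is what the momentum, per-parameter lr modulation and
# weight decay projections in NeuralMemory are computed from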
249 |
250 | # main neural memory
251 |
252 | def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
253 | return adaptive_step.sigmoid() * max_lr
254 |
255 | def default_loss_fn(pred, target):
256 | return (pred - target).pow(2).mean(dim = -1)
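# note (editor's): these defaults follow the paper - the memory is trained online with a
# per-token mse "surprise" loss |M(k) - v|^2 (eq 12), and the raw adaptive step is squashed
# by a sigmoid into (0, max_lr), e.g. a raw value of 0. becomes 0.5 * max_lr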
257 |
258 | class NeuralMemory(Module):
259 | def __init__(
260 | self,
261 | dim,
262 | chunk_size: int | tuple[int, int] = 1,
263 | batch_size = None,
264 | dim_head = None,
265 | heads = 1,
266 | model: Module | None = None,
267 | store_memory_loss_fn: Callable = default_loss_fn,
268 | adaptive_step_transform: Callable | None = None,
269 | default_step_transform_max_lr = 1.,
270 | per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
271 |         max_mem_layer_modulation = 1., # maximum value of the sigmoid-gated per-parameter lr modulation (the sigmoid output is scaled by this)
272 | per_head_learned_parameters = True,
273 | attn_pool_chunks = False,
274 | momentum = True,
275 | momentum_order = 1,
276 | learned_momentum_combine = False,
277 | learned_combine_include_zeroth = False,
278 | num_kv_per_token = 1, # whether a single token can do multiple updates to the memory model
279 |         qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green): the memory MLP may only be learning a Wk @ Wv linear mapping, which may not be expressive enough. hyper connections let the network choose different previous layer inputs as keys / values, to see whether that helps
280 | pre_rmsnorm = True,
281 | post_rmsnorm = False,
282 | qk_rmsnorm = False,
283 | max_grad_norm: float | None = None,
284 | use_accelerated_scan = False,
285 | activation: Module | None = None,
286 | init_adaptive_step_bias = None,
287 | init_momentum_bias = None,
288 | init_decay_bias = None,
289 | accept_weight_residual = False,
290 | spectral_norm_surprises = False,
291 | gated_transition = False,
292 | mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
293 | default_model_kwargs: dict = dict(
294 | depth = 2,
295 | expansion_factor = 4.
296 | )
297 | ):
298 | super().__init__()
299 | dim_head = default(dim_head, dim)
300 | assert not (heads == 1 and dim_head != dim)
301 |
302 | self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
303 |
304 | # batch size
305 |
306 | if exists(batch_size):
307 | assert divisible_by(batch_size, self.store_chunk_size)
308 |
309 | self.batch_size = batch_size
310 |
311 | # associative scan
312 |
313 | self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
314 |
315 | # key values receiving different views
316 |
317 | self.qkv_receives_diff_views = qkv_receives_diff_views
318 |
319 | # norms
320 |
321 | self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
322 | self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
323 |
324 | self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
325 |
326 | self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
327 | self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
328 |
329 | # maybe multi-headed
330 |
331 | dim_inner = dim_head * heads
332 |
333 | self.heads = heads
334 |
335 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
336 | self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token)
337 |
338 | self.merge_heads = Rearrange('b h n d -> b n (h d)')
339 | self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
340 |
341 | self.retrieve_gate = Sequential(
342 | LinearNoBias(dim, heads),
343 | Rearrange('b n h -> b h n 1'),
344 | nn.Sigmoid()
345 | ) if heads > 1 else None
346 |
347 | # memory model
348 |
349 | if not exists(model):
350 | model = MemoryMLP(dim_head, **default_model_kwargs)
351 |
352 | # validate memory model
353 |
354 | assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
355 |
356 | test_shape = (3, 2, dim_head)
357 |
358 | with torch.no_grad():
359 | try:
360 | test_input = torch.randn(test_shape)
361 | mem_model_output = model(test_input)
362 |             except Exception as e:
363 |                 raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}') from e
364 |
365 | assert mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input'
366 |
367 | # the memory is the weights of the model
368 |
369 | if mem_model_norm_add_residual:
370 | model = ResidualNorm(dim = dim_head, model = model)
371 |
372 | self.memory_model = model
373 |
374 | mem_model_params = dict(model.named_parameters())
375 |
376 | self.num_memory_parameter_tensors = len(mem_model_params)
377 |
378 | self.memory_model_parameter_names = [*mem_model_params.keys()]
379 |
380 | memory_model_parameters = [*mem_model_params.values()]
381 |
382 | if per_head_learned_parameters:
383 | memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters]
384 |
385 | self.init_weight_shape = [p.shape for p in memory_model_parameters]
386 |
387 | self.memory_model_parameters = ParameterList(memory_model_parameters)
388 | self.per_head_learned_parameters = per_head_learned_parameters
389 |
390 | # the chunk size within the paper where adaptive step, momentum, weight decay are shared
391 |
392 | self.chunk_size = chunk_size
393 |
394 | # prepare function for per sample gradients from model above, using torch.func
395 |
396 | def forward_and_loss(params, inputs, loss_weights, target):
397 | pred = functional_call(self.memory_model, params, inputs)
398 | loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
399 | weighted_loss = loss * loss_weights
400 | return weighted_loss.sum(), loss
401 |
402 | # two functions
403 |
404 | grad_fn = grad(forward_and_loss, has_aux = True)
405 |
406 | self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
407 |
408 | # queries for retrieving from the model
409 |
410 | self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
411 |
412 | # keys and values for storing to the model
413 |
414 | assert num_kv_per_token > 0
415 |
416 | self.to_keys = Sequential(
417 | LinearNoBias(dim, dim_inner * num_kv_per_token),
418 | activation,
419 | )
420 |
421 | self.to_values = Sequential(
422 | LinearNoBias(dim, dim_inner * num_kv_per_token),
423 | activation,
424 | )
425 |
426 | self.store_memory_loss_fn = store_memory_loss_fn
427 |
428 | self.num_kv_per_token = num_kv_per_token
429 |
430 | # `chunk_size` refers to chunk size used for storing to memory model weights
431 |
432 | chunk_size = self.store_chunk_size
433 |
434 | # whether to use averaging of chunks, or attention pooling
435 |
436 | assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1'
437 |
438 | if not attn_pool_chunks:
439 | self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size)
440 | else:
441 | self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
442 |
443 | # learned adaptive learning rate
444 |
445 | self.to_adaptive_step = Sequential(
446 | nn.Linear(dim, heads * num_kv_per_token),
447 | Rearrange('b n (h u) -> (b h) (n u)', u = num_kv_per_token)
448 | )
449 |
450 | if not exists(adaptive_step_transform):
451 | adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
452 |
453 | self.adaptive_step_transform = adaptive_step_transform
454 |
455 | # momentum related
456 |
457 | self.to_momentum = Sequential(
458 | nn.Linear(dim, heads * momentum_order),
459 | Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
460 | ) if momentum else None
461 |
462 | self.momentum_order = momentum_order
463 | self.to_learned_momentum_combine = None
464 |
465 | if learned_momentum_combine:
466 | assert momentum
467 |             assert momentum_order > 1, 'learned momentum combination requires momentum_order greater than 1 for now; combining only with the zeroth order (raw surprise) may be allowed later'
468 |
469 | if learned_combine_include_zeroth:
470 | momentum_order += 1
471 |
472 | self.to_learned_momentum_combine = Sequential(
473 | nn.Linear(dim, heads * momentum_order),
474 | Rearrange('b n (h o) -> o (b h) n', h = heads),
475 | nn.Softmax(dim = 0),
476 | )
477 |
478 | self.learned_combine_include_zeroth = learned_combine_include_zeroth
479 |
480 | # per layer learning rate modulation
481 |
482 | self.to_layer_modulation = Sequential(
483 | nn.Linear(dim, heads * self.num_memory_parameter_tensors),
484 | Rearrange('b n (h w) -> w (b h) n', h = heads),
485 | nn.Sigmoid()
486 | ) if per_parameter_lr_modulation else None
487 |
488 | self.max_mem_layer_modulation = max_mem_layer_modulation
489 |
490 | # learned weight residual
491 |
492 | self.to_learned_weight_residual_mix = Sequential(
493 | nn.Linear(dim, heads),
494 | Rearrange('b n h -> b h n'),
495 | nn.Sigmoid()
496 | ) if accept_weight_residual else None
497 |
498 | # allow for softclamp the gradient norms for storing memories
499 |
500 | self.max_grad_norm = max_grad_norm
501 |
502 | # spectral norming the surprises before update, a la Muon from Jordan et al.
503 |
504 | self.spectral_norm_surprises = spectral_norm_surprises
505 |
506 | # weight decay factor
507 |
508 | self.to_decay_factor = Sequential(
509 | nn.Linear(dim, heads),
510 | Rearrange('b n h -> (b h) n 1')
511 | )
512 |
513 | # learned transition, as seeing instability when decreasing neural mem batch size
514 | # perhaps it can slowly learn to adjust from early residual to fully transitioning to new weights every batch size
515 |
516 | self.transition_gate = nn.Parameter(tensor(-5.)) if gated_transition else None
517 |
518 | # inits
519 |
520 | if exists(init_adaptive_step_bias):
521 | linear = self.to_adaptive_step[0]
522 | nn.init.zeros_(linear.weight)
523 | nn.init.constant_(linear.bias, init_adaptive_step_bias)
524 |
525 | if exists(init_momentum_bias):
526 | linear = self.to_momentum[0]
527 | nn.init.zeros_(linear.weight)
528 | nn.init.constant_(linear.bias, init_momentum_bias)
529 |
530 | if exists(init_decay_bias):
531 | linear = self.to_decay_factor[0]
532 | nn.init.zeros_(linear.weight)
533 | nn.init.constant_(linear.bias, init_decay_bias)
534 |
535 | # maybe use accelerated scan
536 |
537 | self.use_accelerated_scan = use_accelerated_scan
538 |
539 | self.register_buffer('zero', torch.tensor(0.), persistent = False)
540 |
541 | @property
542 | def memory_model_parameter_dict(self):
543 | return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters)))
544 |
545 | def init_weights(
546 | self,
547 | batch,
548 | ):
549 | if self.per_head_learned_parameters:
550 | weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch)
551 | else:
552 | weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads)
553 |
554 | return weights
555 |
556 | def init_momentum(
557 | self,
558 | batch,
559 | ):
560 | zeros = self.memory_model_parameter_dict.clone().zero_()
561 |
562 | if self.per_head_learned_parameters:
563 | zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
564 | else:
565 | zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
566 |
567 | return zeros
568 |
569 | def store_memories(
570 | self,
571 | seq,
572 | weights: dict[str, Tensor] | None = None,
573 | past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
574 | seq_index = 0,
575 | prev_weights = None,
576 | mask: Tensor | None = None,
577 | return_surprises = True
578 | ):
579 | if self.qkv_receives_diff_views:
580 | _, batch, seq_len = seq.shape[:3]
581 | else:
582 | batch, seq_len = seq.shape[:2]
583 |
584 | # shapes and variables
585 |
586 | heads, chunk_size, num_updates = self.heads, self.store_chunk_size, self.num_kv_per_token
587 |
588 | # curtail sequence by multiple of the chunk size
589 | # only a complete chunk of the sequence provides the memory for the next chunk
590 |
591 | round_down_seq_len = round_down_multiple(seq_len, chunk_size)
592 | num_chunks = round_down_seq_len // chunk_size
593 |
594 | seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :]
595 |
596 | next_seq_len_index = seq_index + round_down_seq_len
597 |
598 | # init weights if needed
599 | # weights of the memory network
600 |
601 | if not exists(weights):
602 | weights = self.init_weights(batch)
603 |
604 | weights = TensorDict(weights)
605 |
606 | # allow for neural memory of a previous layer to influence surprise of current layer
607 |
608 | weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)
609 |
610 | # initial norm
611 |
612 | seq = self.store_norm(seq)
613 |
614 | # handle keys and values coming from different sequences from hyper connection
615 |
616 | values_seq = seq
617 |
618 | if self.qkv_receives_diff_views:
619 | seq, values_seq = seq
620 |
621 | # derive learned hparams for optimization of memory network
622 |
623 | adaptive_lr = self.to_adaptive_step(seq)
624 | adaptive_lr = self.adaptive_step_transform(adaptive_lr)
625 |
626 | chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size)
627 |
628 | decay_factor = self.to_decay_factor(chunked_seq).sigmoid()
629 |
630 | need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0
631 | has_momentum = exists(self.to_momentum)
632 |
633 | if has_momentum:
634 | adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
635 |
636 | learned_combine = exists(self.to_learned_momentum_combine)
637 |
638 | if learned_combine:
639 | combine_momentums = self.to_learned_momentum_combine(chunked_seq)
640 |
641 | if need_layer_lr_mod:
642 | layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
643 |
644 | # keys and values
645 |
646 | keys = self.to_keys(seq)
647 | values = self.to_values(values_seq)
648 |
649 | # maybe multi head
650 |
651 | keys, values = map(self.split_kv_heads, (keys, values))
652 |
653 | # maybe keys rmsnorm
654 |
655 | keys = self.k_norm(keys)
656 |
657 | # take care of chunking
658 |
659 | keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values))
660 |
661 | # adaptive lr
662 |
663 | adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
664 |
665 | # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions
666 |
667 | if exists(mask):
668 | mask = mask[..., :round_down_seq_len]
669 | mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
670 |
671 | adaptive_lr = torch.where(mask, adaptive_lr, 0.)
672 |
673 | # maybe add previous layer weight
674 |
675 | assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
676 |
677 | if exists(prev_weights):
678 |
679 | start_index = math.ceil(seq_index / chunk_size)
680 | end_index = start_index + num_chunks
681 |
682 | prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index])
683 |
684 | if exists(self.to_learned_weight_residual_mix) and num_chunks > 0:
685 | mix = self.to_learned_weight_residual_mix(chunked_seq)
686 | mix = rearrange(mix, 'b h n -> (b h) n')
687 | prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t))
688 |
689 | weights_for_surprise = weights_for_surprise + prev_weights
690 |
691 | # flatten batch and time if surprise depends on previous layer memory model
692 |
693 | weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
694 |
695 | # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
696 |
697 | grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
698 |
699 | grads = TensorDict(grads)
700 |
701 | # surprises
702 |
703 | adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
704 | unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
705 |
706 | # maybe softclamp grad norm
707 |
708 | if exists(self.max_grad_norm):
709 | grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
710 |
711 | # restore batch and sequence dimension
712 |
713 | grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads)
714 |
715 | # maybe per layer modulation
716 |
717 | if need_layer_lr_mod:
718 | grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
719 |
720 | # negative gradients, adaptive lr already applied as loss weight
721 |
722 | surprises = grads.mul(-1)
723 |
724 | # past states
725 |
726 | if not exists(past_state):
727 | # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
728 |
729 | minibatch_init_weight = weights
730 | init_momentum = self.init_momentum(batch)
731 |
732 | past_state = (minibatch_init_weight, init_momentum)
733 |
734 | past_last_update, past_last_momentum = past_state
735 |
736 | # early return if sequence length less than chunk size
737 |
738 | if num_chunks == 0:
739 | updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
740 | next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
741 |
742 | output = (updates, next_store_state)
743 |
744 | if not return_surprises:
745 | return output
746 |
747 | return (*output, (unweighted_mem_model_loss, adaptive_lr))
748 |
749 | # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
750 |
751 | updates = TensorDict()
752 |
753 | next_last_update = TensorDict()
754 | next_last_momentum = TensorDict()
755 |
756 | for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()):
757 |
758 | update = surprise
759 |
760 | # derive momentum with associative scan - eq (10)
761 |
762 | if has_momentum:
763 | momentum = surprise
764 |
765 | momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
766 |
767 | last_momentum = past_last_momentum[param_name]
768 |
769 | # go from first order momentum all the way to the Nth
770 |
771 | for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
772 | momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
773 |
774 | momentums.append(momentum)
775 |
776 | momentums = stack(momentums)
777 |
778 |                 next_last_momentum[param_name] = momentums[:, :, -1] # momentums is Float['o bh n *param_shape'] - keep the last chunk per momentum order
779 |
780 | if learned_combine and self.learned_combine_include_zeroth:
781 | # add the original surprise if learned combination of momentums
782 | momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
783 |
784 | if not learned_combine:
785 | update = momentums[-1]
786 | else:
787 | update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
788 |
789 | # maybe spectral norm surprises
790 |
791 | if self.spectral_norm_surprises:
792 | update = newtonschulz5(update)
793 |
794 | # use associative scan again for learned forgetting (weight decay) - eq (13)
795 |
796 | update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
797 |
798 | updates[param_name] = update
799 | next_last_update[param_name] = update[:, -1]
800 |
801 | # determine next state for the storing of memories
802 |
803 | next_state = (next_last_update, next_last_momentum)
804 |
805 | next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
806 |
807 | # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
808 |
809 | if not return_surprises:
810 | return updates, next_store_state
811 |
812 | return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
813 |
814 | def retrieve_memories(
815 | self,
816 | seq,
817 | weights: dict[str, Tensor],
818 | ):
819 | chunk_size = self.retrieve_chunk_size
820 |
821 | weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape
822 |
823 | batch, seq_len = seq.shape[:2]
824 |
825 | # auto infer single token decoding, if there are only 1 set of weights and 1 token
826 |
827 | is_one_token = seq_len == 1
828 | is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
829 |
830 | is_single_token_decode = is_one_token and is_one_weight
831 |
832 | if is_single_token_decode:
833 | chunk_size = 1
834 |
835 | # padding related, for chunked processing
836 |
837 | need_pad = chunk_size > 1 or not is_one_weight
838 |
839 | if need_pad:
840 | seq = pad_at_dim(seq, (1, 0), dim = 1)
841 |
842 | seq_len_plus_one = seq.shape[-2]
843 |
844 | next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
845 |
846 | padding = next_seq_len - seq_len_plus_one
847 | seq = pad_at_dim(seq, (0, padding), dim = 1)
848 |
849 | # the parameters of the memory model stores the memories of the key / values
850 |         # when the MLP has only 1 weight matrix, it is equivalent to the `kv` fast weight memories from the linear attention literature (recall fetching of memories is q @ (kv)), going back to Schmidhuber
851 |
852 | weights = TensorDict(weights)
853 |
854 | # pre norm
855 |
856 | seq = self.retrieve_norm(seq)
857 |
858 | # sequence Float['b n d'] to queries
859 |
860 | queries = self.to_queries(seq)
861 |
862 | # maybe multihead
863 |
864 | queries = self.split_heads(queries)
865 |
866 | # maybe qk rmsnorm
867 |
868 | queries = self.q_norm(queries)
869 |
870 | # fetch values from memory model
871 |
872 | if weights_have_expanded_shape:
873 | weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
874 |
875 | queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
876 |
877 | # forward functional call
878 |
879 | values = functional_call(self.memory_model, dict(weights), queries)
880 |
881 | # reconstitute batch dimension
882 |
883 | values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
884 |
885 | values = self.multihead_rmsnorm(values)
886 |
887 | # maybe gate
888 |
889 | if exists(self.retrieve_gate):
890 | values = values * self.retrieve_gate(seq)
891 |
892 | # maybe merge heads and combine
893 |
894 | values = self.merge_heads(values)
895 |
896 | values = self.combine_heads(values)
897 |
898 |         # undo the padding - drop the leading pad position, then trim back to the original sequence length
899 |
900 | if need_pad:
901 | values = values[:, 1:]
902 |
903 | return values[:, :seq_len]
904 |
905 | def forward(
906 | self,
907 | seq,
908 | store_seq = None,
909 | state: NeuralMemState | None = None,
910 | detach_mem_state = False,
911 | prev_weights = None,
912 | store_mask: Tensor | None = None,
913 | return_surprises = False,
914 | ttt_batch_size: int | None = None
915 | ):
916 | is_multi_input = self.qkv_receives_diff_views
917 |
918 | # handle single token
919 |
920 | if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
921 | seq = rearrange(seq, '... b d -> ... b 1 d')
922 |
923 | is_single_token = seq.shape[-2] == 1
924 |
925 | # if different views for qkv, then
926 |
927 | if is_multi_input:
928 | retrieve_seq, seq = seq[0], seq[1:]
929 | else:
930 | retrieve_seq = seq
931 |
932 | # handle previous state init
933 |
934 | if not exists(state):
935 | state = (0, None, None, None, None)
936 |
937 | seq_index, weights, cache_store_seq, past_state, updates = state
938 |
939 | # store
940 |
941 | store_seq = default(store_seq, seq)
942 |
943 | # take care of cache
944 |
945 | if exists(cache_store_seq):
946 | store_seq = safe_cat((cache_store_seq, store_seq))
947 |
948 | # compute split sizes of sequence
949 | # for now manually update weights to last update at the correct boundaries
950 |
951 | store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size)
952 |
953 | need_update_weights = exists(batch_size)
954 |
955 | # determine split sizes and when to update
956 |
957 | if need_update_weights:
958 | update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size)
959 |
960 | seq_range = torch.arange(store_seq_len) + seq_index + 1
961 | batch_boundary = divisible_by(seq_range, batch_size)
962 |
963 | indices = seq_range[batch_boundary] - seq_index
964 |
965 | indices = F.pad(indices, (1, 0), value = 0)
966 |
967 | if indices[-1] != store_seq_len:
968 | indices = F.pad(indices, (0, 1), value = store_seq_len)
969 |
970 | split_sizes = (indices[1:] - indices[:-1]).tolist()
971 |
972 | assert sum(split_sizes) == store_seq_len
973 | else:
974 | split_sizes = (store_seq_len,)
975 | update_after_final_store = False
976 |
977 | # accumulate updates
978 |
979 | updates = None
980 |
981 | def accum_updates(past_updates, future_updates):
982 | if not exists(past_updates):
983 | return future_updates
984 |
985 | return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())})
986 |
987 | # loop through chunks of store sequences
988 |
989 | store_seqs = store_seq.split(split_sizes, dim = -2)
990 |
991 | if exists(store_mask):
992 | store_masks = store_mask.split(split_sizes, dim = -1)
993 | else:
994 | store_masks = (None,) * len(split_sizes)
995 |
996 | # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
997 |
998 | surprises = (None, None)
999 | gate = None
1000 |
1001 | if exists(self.transition_gate):
1002 | gate = self.transition_gate.sigmoid()
1003 |
1004 | for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
1005 | is_last = ind == (len(store_seqs) - 1)
1006 |
1007 | # store
1008 |
1009 | next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
1010 | store_seq_chunk,
1011 | weights,
1012 | seq_index = seq_index,
1013 | past_state = past_state,
1014 | prev_weights = prev_weights,
1015 | mask = maybe_store_mask,
1016 | return_surprises = True
1017 | )
1018 |
1019 | weights = next_neural_mem_state.weights
1020 | seq_index = next_neural_mem_state.seq_index
1021 | past_state = next_neural_mem_state.states
1022 |
1023 | updates = accum_updates(updates, next_updates)
1024 |
1025 | surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
1026 |
1027 | if is_last and not update_after_final_store:
1028 | continue
1029 |
1030 | # update weights once batch size is fulfilled
1031 |
1032 | last_update, last_momentum = past_state
1033 |
1034 | if exists(gate):
1035 | last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())})
1036 |
1037 | past_state = (last_update, last_momentum)
1038 |
1039 | # set weights to the last updated weights for the last minibatch
1040 |
1041 | weights = last_update
1042 |
1043 | next_neural_mem_state = next_neural_mem_state._replace(
1044 | weights = weights,
1045 | states = past_state,
1046 | )
1047 |
1048 | next_neural_mem_state = next_neural_mem_state._replace(updates = updates)
1049 |
1050 | # retrieve
1051 |
1052 | if is_single_token:
1053 | last_update, _ = next_neural_mem_state.states
1054 | updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
1055 |
1056 | retrieved = self.retrieve_memories(
1057 | retrieve_seq,
1058 | updates
1059 | )
1060 |
1061 | # maybe detach
1062 |
1063 | if detach_mem_state:
1064 | next_neural_mem_state = mem_state_detach(next_neural_mem_state)
1065 |
1066 | # returning
1067 |
1068 | if not return_surprises:
1069 | return retrieved, next_neural_mem_state
1070 |
1071 | return retrieved, next_neural_mem_state, surprises
1072 |
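# hedged usage sketch (editor's addition, in the spirit of the repository README): NeuralMemory
# can be used standalone as a sequence layer - `retrieved` matches the input shape, and the
# returned state can be fed back in to continue across segments

if __name__ == '__main__':
    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64   # chunk size shared for storing and retrieving
    )

    seq = torch.randn(2, 1024, 384)
    retrieved, mem_state = mem(seq)

    assert retrieved.shape == seq.shape

    # continue on the next segment, carrying the memory state forward
    next_seq = torch.randn(2, 512, 384)
    retrieved, mem_state = mem(next_seq, state = mem_state)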
--------------------------------------------------------------------------------
/train_mac.py:
--------------------------------------------------------------------------------
1 | import random
2 | import tqdm
3 | import gzip
4 | import numpy as np
5 |
6 | import torch
7 | from torch import nn, Tensor
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from adam_atan2_pytorch import AdoptAtan2
12 |
13 | from titans_pytorch import (
14 | MemoryAsContextTransformer,
15 | MemoryMLP,
16 | MemoryAttention
17 | )
18 |
19 | # constants
20 |
21 | NUM_BATCHES = int(1e5)
22 | BATCH_SIZE = 4
23 | GRADIENT_ACCUMULATE_EVERY = 4
24 | LEARNING_RATE = 2e-4
25 | VALIDATE_EVERY = 100
26 | GENERATE_EVERY = 500
27 | PRIME_LENGTH = 100
28 | GENERATE_LENGTH = 512
29 | SHOULD_GENERATE = True
30 | SEQ_LEN = 512
31 |
32 | # neural memory related
33 |
34 | NEURAL_MEMORY_DEPTH = 2
35 | NUM_PERSIST_MEM = 4
36 | NUM_LONGTERM_MEM = 4
37 | NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more
38 | NEURAL_MEM_GATE_ATTN_OUTPUT = False
39 | NEURAL_MEM_MOMENTUM = True
40 | NEURAL_MEM_MOMENTUM_ORDER = 1
41 | NEURAL_MEM_QK_NORM = True
42 | NEURAL_MEM_MAX_LR = 1e-1
43 | USE_MEM_ATTENTION_MODEL = False
44 | WINDOW_SIZE = 32
45 | NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc
46 | NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence
47 | SLIDING_WINDOWS = True
48 | STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
49 | MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
50 | NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper
51 | NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS
52 | NEURAL_MEM_SPEC_NORM_SURPRISES = True # applying lessons from Muon optimizer to surprise updates, by spectral norming the surprises
53 |
54 | # experiment related
55 |
56 | PROJECT_NAME = 'titans-mac-transformer'
57 | RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'
58 | WANDB_ONLINE = False # turn this on to pipe experiment to cloud
59 |
60 | # perf related
61 |
62 | USE_ACCELERATED_SCAN = True
63 | USE_FLEX_ATTN = True
64 | USE_FAST_INFERENCE = False
65 |
66 | # wandb experiment tracker
67 |
68 | import wandb
69 | wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
70 | wandb.run.name = RUN_NAME
71 | wandb.run.save()
72 |
73 | # helpers
74 |
75 | def cycle(loader):
76 | while True:
77 | for data in loader:
78 | yield data
79 |
80 | def decode_token(token):
81 | return str(chr(max(32, token)))
82 |
83 | def decode_tokens(tokens):
84 | return ''.join(list(map(decode_token, tokens)))
85 |
86 | # memory model
87 |
88 | if USE_MEM_ATTENTION_MODEL:
89 | neural_memory_model = MemoryAttention(
90 | dim = 64
91 | )
92 | else:
93 | neural_memory_model = MemoryMLP(
94 | dim = 64,
95 | depth = NEURAL_MEMORY_DEPTH
96 | )
97 |
98 | # instantiate memory-as-context transformer
99 |
100 | model = MemoryAsContextTransformer(
101 | num_tokens = 256,
102 | dim = 384,
103 | depth = 8,
104 | segment_len = WINDOW_SIZE,
105 | num_persist_mem_tokens = NUM_PERSIST_MEM,
106 | num_longterm_mem_tokens = NUM_LONGTERM_MEM,
107 | neural_memory_layers = NEURAL_MEM_LAYERS,
108 | neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN,
109 | neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
110 | neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
111 | neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
112 | neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW,
113 | use_flex_attn = USE_FLEX_ATTN,
114 | sliding_window_attn = SLIDING_WINDOWS,
115 | neural_memory_model = neural_memory_model,
116 | neural_memory_kwargs = dict(
117 | dim_head = 64,
118 | heads = 4,
119 | attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
120 | qk_rmsnorm = NEURAL_MEM_QK_NORM,
121 | momentum = NEURAL_MEM_MOMENTUM,
122 | momentum_order = NEURAL_MEM_MOMENTUM_ORDER,
123 | default_step_transform_max_lr = NEURAL_MEM_MAX_LR,
124 | use_accelerated_scan = USE_ACCELERATED_SCAN,
125 | per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR,
126 | spectral_norm_surprises = NEURAL_MEM_SPEC_NORM_SURPRISES
127 | )
128 | ).cuda()
129 |
130 | # prepare enwik8 data
131 |
132 | with gzip.open('./data/enwik8.gz') as file:
133 | data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
134 | data_train, data_val = np.split(data, [int(90e6)])
135 | data_train, data_val = map(torch.from_numpy, (data_train, data_val))
136 |
137 | class TextSamplerDataset(Dataset):
138 | def __init__(self, data, seq_len):
139 | super().__init__()
140 | self.data = data
141 | self.seq_len = seq_len
142 |
143 | def __getitem__(self, index):
144 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
145 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
146 | return full_seq.cuda()
147 |
148 | def __len__(self):
149 | return self.data.size(0) // self.seq_len
150 |
151 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
152 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
153 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
154 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
155 |
156 | # optimizer
157 |
158 | optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)
159 |
160 | # training
161 |
162 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
163 | model.train()
164 |
165 | for __ in range(GRADIENT_ACCUMULATE_EVERY):
166 | loss = model(next(train_loader), return_loss = True)
167 | loss.backward()
168 |
169 | print(f'training loss: {loss.item()}')
170 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
171 | optim.step()
172 | optim.zero_grad()
173 | wandb.log(dict(loss = loss.item()))
174 |
175 | if i % VALIDATE_EVERY == 0:
176 | model.eval()
177 | with torch.no_grad():
178 | loss = model(next(val_loader), return_loss = True)
179 | print(f'validation loss: {loss.item()}')
180 |
181 | if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
182 | model.eval()
183 | inp = random.choice(val_dataset)[:PRIME_LENGTH]
184 | prime = decode_tokens(inp)
185 |         print(f'{prime} \n\n {"*" * 100}')
186 |
187 | sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
188 | output_str = decode_tokens(sample[0])
189 | print(output_str)
190 |
--------------------------------------------------------------------------------