├── .github └── workflows │ ├── python-publish.yml │ └── test.yaml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── fig1.png ├── fig2.png ├── pyproject.toml ├── tests └── test_titans.py ├── titans_pytorch ├── __init__.py ├── mac_transformer.py ├── memory_models.py └── neural_memory.py └── train_mac.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Tests the examples in README 2 | on: [push, pull_request] 3 | 4 | env: 5 | TYPECHECK: True 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Install Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.11" 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install uv 19 | python -m uv pip install --upgrade pip 20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu 21 | python -m uv pip install -e .[test] 22 | - name: Test with pytest 23 | run: | 24 | python -m pytest tests/ 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | train_local.py 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ## Titans - Pytorch 6 | 7 | Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree. 
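The memory model is pluggable - below is a minimal sketch, mirroring `tests/test_titans.py`, of swapping in the attention-based memory in place of the default MLP:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryAttention

# swap the default MLP memory for the attention-based memory model
# (dimensions mirror tests/test_titans.py::test_titans_attn_memory)
mem = NeuralMemory(
    dim = 16,
    chunk_size = 64,
    model = MemoryAttention(dim = 16)
)

seq = torch.randn(2, 1024, 16)
retrieved, _ = mem(seq)

assert seq.shape == retrieved.shape
```

The other exported variants (`MemoryMLP`, `FactorizedMemoryMLP`, `MemorySwiGluMLP`, `GatedResidualMemoryMLP`) can be passed in the same way; see the test file for the keyword arguments each expects.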
8 | 9 | ## Appreciation 10 | 11 | - [Eryk](https://github.com/sentialx) for sharing his early experimental results with me, positive for 2 layer MLP 12 | 13 | ## Install 14 | 15 | ```bash 16 | $ pip install titans-pytorch 17 | ``` 18 | 19 | ## Usage 20 | 21 | ```python 22 | import torch 23 | from titans_pytorch import NeuralMemory 24 | 25 | mem = NeuralMemory( 26 | dim = 384, 27 | chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage) 28 | ).cuda() 29 | 30 | seq = torch.randn(2, 1024, 384).cuda() 31 | retrieved, mem_state = mem(seq) 32 | 33 | assert seq.shape == retrieved.shape 34 | ``` 35 | 36 | A transformer with the `MAC` configuration can be used as 37 | 38 | ```python 39 | import torch 40 | from titans_pytorch import MemoryAsContextTransformer 41 | 42 | transformer = MemoryAsContextTransformer( 43 | num_tokens = 256, 44 | dim = 256, 45 | depth = 2, 46 | segment_len = 128, # local attention window size 47 | num_persist_mem_tokens = 4, 48 | num_longterm_mem_tokens = 16, 49 | ) 50 | 51 | token_ids = torch.randint(0, 256, (1, 1023)) 52 | 53 | loss = transformer(token_ids, return_loss = True) # (1, 1023, 256) 54 | loss.backward() 55 | 56 | # after much training 57 | 58 | sampled = transformer.sample(token_ids[:, :4], 512) 59 | ``` 60 | 61 | ## Experiments 62 | 63 | ```bash 64 | $ pip install .[examples] 65 | ``` 66 | 67 | Then modify `train_mac.py` and run it to query nature 68 | 69 | ```bash 70 | $ python train_mac.py 71 | ``` 72 | 73 | ## Citations 74 | 75 | ```bibtex 76 | @inproceedings{Behrouz2024TitansLT, 77 | title = {Titans: Learning to Memorize at Test Time}, 78 | author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni}, 79 | year = {2024}, 80 | url = {https://api.semanticscholar.org/CorpusID:275212078} 81 | } 82 | ``` 83 | 84 | ```bibtex 85 | @article{Sun2024LearningT, 86 | title = {Learning to (Learn at Test Time): RNNs with Expressive Hidden States}, 87 | author = {Yu Sun and Xinhao Li and Karan Dalal and Jiarui Xu and Arjun Vikram and Genghan Zhang and Yann Dubois and Xinlei Chen and Xiaolong Wang and Oluwasanmi Koyejo and Tatsunori Hashimoto and Carlos Guestrin}, 88 | journal = {ArXiv}, 89 | year = {2024}, 90 | volume = {abs/2407.04620}, 91 | url = {https://api.semanticscholar.org/CorpusID:271039606} 92 | } 93 | ``` 94 | 95 | ```bibtex 96 | @inproceedings{Yang2024GatedDN, 97 | title = {Gated Delta Networks: Improving Mamba2 with Delta Rule}, 98 | author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh}, 99 | year = {2024}, 100 | url = {https://api.semanticscholar.org/CorpusID:274598177} 101 | } 102 | ``` 103 | 104 | ```bibtex 105 | @inproceedings{Nguyen2024TurningUT, 106 | title = {Turning Up the Heat: Min-p Sampling for Creative and Coherent LLM Outputs}, 107 | author = {Minh Nguyen and Andrew Baker and Clement Neo and Allen Roush and Andreas Kirsch and Ravid Shwartz-Ziv}, 108 | year = {2024}, 109 | url = {https://api.semanticscholar.org/CorpusID:270870613} 110 | } 111 | ``` 112 | 113 | ```bibtex 114 | @article{Zhu2024HyperConnections, 115 | title = {Hyper-Connections}, 116 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 117 | journal = {ArXiv}, 118 | year = {2024}, 119 | volume = {abs/2409.19606}, 120 | url = {https://api.semanticscholar.org/CorpusID:272987528} 121 | } 122 | ``` 123 | 124 | ```bibtex 125 | @article{Zhou2024ValueRL, 126 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, 127 | 
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 128 | journal = {ArXiv}, 129 | year = {2024}, 130 | volume = {abs/2410.17897}, 131 | url = {https://api.semanticscholar.org/CorpusID:273532030} 132 | } 133 | ``` 134 | 135 | ```bibtex 136 | @software{Kyrylov_Accelerated_Scan_2024, 137 | author = {Kyrylov, Volodymyr}, 138 | doi = {10.5281/zenodo.10600962}, 139 | title = {Accelerated Scan}, 140 | version = {0.1.2}, 141 | year = {2024} 142 | } 143 | ``` 144 | 145 | ```bibtex 146 | @misc{wang2025testtimeregressionunifyingframework, 147 | title = {Test-time regression: a unifying framework for designing sequence models with associative memory}, 148 | author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox}, 149 | year = {2025}, 150 | eprint = {2501.12352}, 151 | archivePrefix = {arXiv}, 152 | primaryClass = {cs.LG}, 153 | url = {https://arxiv.org/abs/2501.12352}, 154 | } 155 | ``` 156 | 157 | ```bibtex 158 | @misc{jordan2024muon, 159 | author = {Keller Jordan and Yuchen Jin and Vlado Boza and Jiacheng You and 160 | Franz Cesista and Laker Newhouse and Jeremy Bernstein}, 161 | title = {Muon: An optimizer for hidden layers in neural networks}, 162 | year = {2024}, 163 | url = {https://kellerjordan.github.io/posts/muon/} 164 | } 165 | ``` 166 | 167 | ```bibtex 168 | @inproceedings{Zhang2025TestTimeTD, 169 | title = {Test-Time Training Done Right}, 170 | author = {Tianyuan Zhang and Sai Bi and Yicong Hong and Kai Zhang and Fujun Luan and Songlin Yang and Kalyan Sunkavalli and William T. Freeman and Hao Tan}, 171 | year = {2025}, 172 | url = {https://api.semanticscholar.org/CorpusID:279071244} 173 | } 174 | ``` 175 | 176 | ```bibtex 177 | @inproceedings{Behrouz2025ATLASLT, 178 | title = {ATLAS: Learning to Optimally Memorize the Context at Test Time}, 179 | author = {Ali Behrouz and Ze-Minghui Li and Praneeth Kacham and Majid Daliri and Yuan Deng and Peilin Zhong and Meisam Razaviyayn and Vahab S. 
Mirrokni}, 180 | year = {2025}, 181 | url = {https://api.semanticscholar.org/CorpusID:278996373} 182 | } 183 | ``` 184 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/data/enwik8.gz -------------------------------------------------------------------------------- /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/fig1.png -------------------------------------------------------------------------------- /fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/titans-pytorch/1d9ad1417ee7d8ac5e7288d5e86765fa6453651d/fig2.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "titans-pytorch" 3 | version = "0.4.10" 4 | description = "Titans" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'test time training', 15 | 'linear attention', 16 | 'memory', 17 | ] 18 | 19 | classifiers=[ 20 | 'Development Status :: 4 - Beta', 21 | 'Intended Audience :: Developers', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Programming Language :: Python :: 3.9', 25 | ] 26 | 27 | dependencies = [ 28 | "assoc-scan", 29 | "axial_positional_embedding>=0.3.10", 30 | "einops>=0.8.0", 31 | "einx>=0.3.0", 32 | "hyper-connections>=0.1.11", 33 | "Ninja", 34 | "rotary-embedding-torch", 35 | "tensordict", 36 | "torch>=2.2", 37 | "tqdm", 38 | "x-transformers" 39 | ] 40 | 41 | [project.urls] 42 | Homepage = "https://pypi.org/project/titans-pytorch/" 43 | Repository = "https://github.com/lucidrains/titans-pytorch" 44 | 45 | [project.optional-dependencies] 46 | 47 | examples = [ 48 | "adam-atan2-pytorch>=0.1.18", 49 | "wandb" 50 | ] 51 | 52 | test = [ 53 | "pytest" 54 | ] 55 | 56 | [tool.pytest.ini_options] 57 | pythonpath = [ 58 | "." 
59 | ] 60 | 61 | [build-system] 62 | requires = ["hatchling"] 63 | build-backend = "hatchling.build" 64 | 65 | [tool.rye] 66 | managed = true 67 | dev-dependencies = [] 68 | 69 | [tool.hatch.metadata] 70 | allow-direct-references = true 71 | 72 | [tool.hatch.build.targets.wheel] 73 | packages = ["titans_pytorch"] 74 | -------------------------------------------------------------------------------- /tests/test_titans.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import pytest 7 | from titans_pytorch import NeuralMemory 8 | from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer 9 | 10 | # functions 11 | 12 | def exists(v): 13 | return v is not None 14 | 15 | def diff(x, y): 16 | return (x - y).abs().amax() 17 | 18 | @contextmanager 19 | def torch_default_dtype(dtype): 20 | prev_dtype = torch.get_default_dtype() 21 | torch.set_default_dtype(dtype) 22 | yield 23 | torch.set_default_dtype(prev_dtype) 24 | 25 | # main test 26 | 27 | @pytest.mark.parametrize('seq_len', (32, 512, 77)) 28 | @pytest.mark.parametrize('silu', (False, True)) 29 | @pytest.mark.parametrize('chunk_size, attn_pool_chunks', ((64, True), (64, False), (1, False))) 30 | @pytest.mark.parametrize('momentum', (False, True)) 31 | @pytest.mark.parametrize('qk_rmsnorm', (False, True)) 32 | @pytest.mark.parametrize('heads', (1, 4)) 33 | @pytest.mark.parametrize('max_grad_norm', (None, 2.)) 34 | @pytest.mark.parametrize('num_kv_per_token', (1, 2)) 35 | @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True)) 36 | @pytest.mark.parametrize('per_head_learned_parameters', (False, True)) 37 | @pytest.mark.parametrize('test_store_mask', (False, True)) 38 | def test_titans( 39 | seq_len, 40 | silu, 41 | attn_pool_chunks, 42 | chunk_size, 43 | momentum, 44 | qk_rmsnorm, 45 | heads, 46 | max_grad_norm, 47 | num_kv_per_token, 48 | per_parameter_lr_modulation, 49 | per_head_learned_parameters, 50 | test_store_mask 51 | ): 52 | mem = NeuralMemory( 53 | dim = 16, 54 | chunk_size = chunk_size, 55 | activation = nn.SiLU() if silu else None, 56 | attn_pool_chunks = attn_pool_chunks, 57 | max_grad_norm = max_grad_norm, 58 | num_kv_per_token = num_kv_per_token, 59 | momentum = momentum, 60 | qk_rmsnorm = qk_rmsnorm, 61 | heads = heads, 62 | per_parameter_lr_modulation = per_parameter_lr_modulation, 63 | per_head_learned_parameters = per_head_learned_parameters 64 | ) 65 | 66 | seq = torch.randn(2, seq_len, 16) 67 | 68 | store_mask = None 69 | 70 | if test_store_mask: 71 | store_mask = torch.randint(0, 2, (2, seq_len)).bool() 72 | 73 | retrieved, _ = mem(seq, store_mask = store_mask) 74 | 75 | assert seq.shape == retrieved.shape 76 | 77 | def test_return_surprises(): 78 | 79 | mem = NeuralMemory( 80 | dim = 384, 81 | chunk_size = 2, 82 | dim_head = 64, 83 | heads = 4, 84 | ) 85 | 86 | seq = torch.randn(4, 64, 384) 87 | 88 | _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True) 89 | 90 | assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)]) 91 | 92 | @pytest.mark.parametrize('learned_momentum_combine', (False, True)) 93 | @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True)) 94 | def test_titans_second_order_momentum( 95 | learned_momentum_combine, 96 | learned_combine_include_zeroth 97 | ): 98 | 99 | mem = NeuralMemory( 100 | dim = 384, 101 | dim_head = 64, 102 | heads = 2, 103 | chunk_size = 1, 104 | 
batch_size = 2, 105 | momentum_order = 2, 106 | learned_momentum_combine = learned_momentum_combine, 107 | learned_combine_include_zeroth = learned_combine_include_zeroth 108 | ) 109 | 110 | seq = torch.randn(2, 5, 384) 111 | 112 | parallel_retrieved, state = mem(seq) 113 | assert seq.shape == parallel_retrieved.shape 114 | 115 | def test_titans_attn_memory(): 116 | from titans_pytorch.memory_models import MemoryAttention 117 | 118 | mem = NeuralMemory( 119 | dim = 16, 120 | chunk_size = 64, 121 | model = MemoryAttention( 122 | dim = 16 123 | ) 124 | ) 125 | 126 | seq = torch.randn(2, 1024, 16) 127 | retrieved, _ = mem(seq) 128 | 129 | assert seq.shape == retrieved.shape 130 | 131 | def test_swiglu_ff_memory(): 132 | from titans_pytorch.memory_models import MemorySwiGluMLP 133 | 134 | mem = NeuralMemory( 135 | dim = 16, 136 | chunk_size = 2, 137 | mem_model_norm_add_residual = False, 138 | model = MemorySwiGluMLP( 139 | dim = 16, 140 | depth = 2 141 | ) 142 | ) 143 | 144 | seq = torch.randn(2, 64, 16) 145 | retrieved, _ = mem(seq) 146 | 147 | assert seq.shape == retrieved.shape 148 | 149 | @pytest.mark.parametrize('gated_transition', (True, False)) 150 | def test_neural_mem_chaining_chunks( 151 | gated_transition 152 | ): 153 | mem = NeuralMemory( 154 | dim = 16, 155 | dim_head = 16, 156 | heads = 2, 157 | chunk_size = 16, 158 | gated_transition = gated_transition 159 | ) 160 | 161 | seq = torch.randn(2, 48, 16) 162 | 163 | parallel_retrieved, state = mem(seq) 164 | 165 | seq_first, seq_second, seq_third = seq.split(16, dim = 1) 166 | 167 | first_retrieved, state = mem(seq_first) 168 | second_retrieved, state = mem(seq_second, state = state) 169 | third_retrieved, state = mem(seq_third, state = state) 170 | 171 | assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1), atol = 1e-5) 172 | 173 | def test_neural_mem_chaining_with_weight_residual(): 174 | mem = NeuralMemory( 175 | dim = 16, 176 | dim_head = 16, 177 | heads = 2, 178 | chunk_size = 64 179 | ) 180 | 181 | mem2 = NeuralMemory( 182 | dim = 16, 183 | dim_head = 16, 184 | heads = 2, 185 | chunk_size = 64, 186 | accept_weight_residual = True 187 | ) 188 | 189 | seq = torch.randn(2, 256, 16) 190 | 191 | seq, state = mem(seq) 192 | 193 | parallel_retrieved, _ = mem2(seq, prev_weights = state.updates) 194 | 195 | seq_first, seq_second = seq[:, :128], seq[:, 128:] 196 | 197 | first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates) 198 | second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates) 199 | 200 | assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5) 201 | 202 | def test_neural_mem_chaining_with_batch_size(): 203 | mem = NeuralMemory( 204 | dim = 16, 205 | dim_head = 16, 206 | heads = 2, 207 | chunk_size = 16, 208 | batch_size = 64 209 | ) 210 | 211 | seq = torch.randn(2, 112, 16) 212 | 213 | parallel_retrieved, state = mem(seq) 214 | 215 | seq_first, seq_second, seq_third = seq[:, :16], seq[:, 16:64], seq[:, 64:] 216 | 217 | first_retrieved, state = mem(seq_first) 218 | second_retrieved, state = mem(seq_second, state = state) 219 | third_retrieved, state = mem(seq_third, state = state) 220 | 221 | parallel_part_retrieved = torch.cat((first_retrieved, second_retrieved, third_retrieved), dim = 1) 222 | 223 | assert torch.allclose(parallel_retrieved, parallel_part_retrieved, atol = 1e-5) 224 | 225 | @pytest.mark.parametrize('seq_len', (1023, 17)) 226 | 
@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16)) 227 | @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16)) 228 | @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True)) 229 | @pytest.mark.parametrize('neural_mem_segment_len', (8, 16)) 230 | @pytest.mark.parametrize('neural_mem_weight_residual', (False, True)) 231 | @pytest.mark.parametrize('neural_mem_batch_size', (None, 64)) 232 | @pytest.mark.parametrize('neural_mem_qkv_receives_diff_views', (False, True)) 233 | @pytest.mark.parametrize('neural_mem_momentum', (False, True)) 234 | def test_mac( 235 | seq_len, 236 | num_persist_mem_tokens, 237 | num_longterm_mem_tokens, 238 | neural_mem_gate_attn_output, 239 | neural_mem_segment_len, 240 | neural_mem_weight_residual, 241 | neural_mem_batch_size, 242 | neural_mem_qkv_receives_diff_views, 243 | neural_mem_momentum 244 | ): 245 | transformer = MemoryAsContextTransformer( 246 | num_tokens = 256, 247 | dim = 16, 248 | depth = 2, 249 | num_persist_mem_tokens = num_persist_mem_tokens, 250 | num_longterm_mem_tokens = num_longterm_mem_tokens, 251 | segment_len = 128, 252 | neural_mem_gate_attn_output = neural_mem_gate_attn_output, 253 | neural_memory_segment_len = neural_mem_segment_len, 254 | neural_memory_batch_size = neural_mem_batch_size, 255 | neural_memory_qkv_receives_diff_views = neural_mem_qkv_receives_diff_views, 256 | neural_mem_weight_residual = neural_mem_weight_residual, 257 | neural_memory_kwargs = dict( 258 | momentum = neural_mem_momentum 259 | ) 260 | ) 261 | 262 | x = torch.randint(0, 256, (1, seq_len)) 263 | 264 | logits = transformer(x) 265 | assert logits.shape == (1, seq_len, 256) 266 | 267 | @pytest.mark.parametrize('sliding', (False, True)) 268 | @pytest.mark.parametrize('mem_layers', ((), None)) 269 | @pytest.mark.parametrize('longterm_mems', (0, 4, 16)) 270 | @pytest.mark.parametrize('prompt_len', (4, 16)) 271 | @torch_default_dtype(torch.float64) 272 | def test_mac_sampling( 273 | sliding, 274 | mem_layers, 275 | longterm_mems, 276 | prompt_len 277 | ): 278 | transformer = MemoryAsContextTransformer( 279 | num_tokens = 256, 280 | dim = 16, 281 | depth = 4, 282 | segment_len = 32, 283 | num_persist_mem_tokens = 4, 284 | num_longterm_mem_tokens = longterm_mems, 285 | sliding_window_attn = sliding, 286 | neural_memory_layers = mem_layers, 287 | neural_mem_gate_attn_output = False 288 | ) 289 | 290 | ids = torch.randint(0, 256, (1, 1023)) 291 | 292 | # after much training 293 | 294 | prompt = ids[:, :prompt_len] 295 | 296 | sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.) 297 | sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.) 
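    # with temperature = 0., gumbel_sample skips the noise and reduces to a pure argmax,
    # so the cached and uncached decoding paths should produce identical tokens
    # (the float64 default dtype from the decorator above helps keep the two numerically aligned)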
298 | 299 | assert torch.allclose(sampled, sampled_with_cache) 300 | 301 | @pytest.mark.parametrize('seq_len', (2, 64, 256)) 302 | @pytest.mark.parametrize('prompt_len', (0, 65)) 303 | @pytest.mark.parametrize('mem_chunk_size', (2, 32, 64)) 304 | @pytest.mark.parametrize('gated_transition', (False, True)) 305 | @torch_default_dtype(torch.float64) 306 | def test_neural_mem_inference( 307 | seq_len, 308 | prompt_len, 309 | mem_chunk_size, 310 | gated_transition 311 | ): 312 | 313 | mem = NeuralMemory( 314 | dim = 16, 315 | chunk_size = mem_chunk_size, 316 | gated_transition = gated_transition 317 | ) 318 | 319 | seq = torch.randn(2, seq_len, 16) 320 | parallel_retrieved, _ = mem(seq) 321 | 322 | assert seq.shape == parallel_retrieved.shape 323 | 324 | state = None 325 | sequential_retrieved = [] 326 | 327 | # test initial parallel prompt 328 | 329 | test_parallel_prompt = prompt_len > 0 and prompt_len < seq_len 330 | 331 | if test_parallel_prompt: 332 | prompt, seq = seq[:, :prompt_len], seq[:, prompt_len:] 333 | retrieved_prompt, state = mem(prompt) 334 | sequential_retrieved.append(retrieved_prompt) 335 | 336 | # sequential inference 337 | 338 | for token in seq.unbind(dim = 1): 339 | 340 | one_retrieved, state = mem.forward( 341 | token, 342 | state = state, 343 | ) 344 | 345 | sequential_retrieved.append(one_retrieved) 346 | 347 | sequential_retrieved = torch.cat(sequential_retrieved, dim = -2) 348 | 349 | assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-6) 350 | 351 | @pytest.mark.parametrize('seq_len', (1023, 17)) 352 | @pytest.mark.parametrize('sliding', (True, False)) 353 | def test_flex( 354 | seq_len, 355 | sliding 356 | ): 357 | if not (torch.cuda.is_available() and exists(flex_attention)): 358 | pytest.skip() 359 | 360 | attn = SegmentedAttention( 361 | dim = 16, 362 | segment_len = 32, 363 | num_persist_mem_tokens = 1, 364 | num_longterm_mem_tokens = 1, 365 | use_flex_attn = True, 366 | sliding = sliding 367 | ).cuda() 368 | 369 | seq = torch.randn(1, seq_len, 16).cuda() 370 | 371 | out_flex, _ = attn(seq) 372 | out_non_flex, _ = attn(seq, disable_flex_attn = True) 373 | 374 | assert torch.allclose(out_flex, out_non_flex, atol = 1e-5) 375 | 376 | @pytest.mark.parametrize('use_accelerated', (True, False)) 377 | def test_assoc_scan( 378 | use_accelerated 379 | ): 380 | from titans_pytorch.neural_memory import AssocScan 381 | 382 | if use_accelerated and not torch.cuda.is_available(): 383 | pytest.skip() 384 | 385 | scan = AssocScan(use_accelerated = use_accelerated) 386 | 387 | seq_len = 128 388 | mid_point = seq_len // 2 389 | 390 | gates = torch.randn(2, seq_len, 16).sigmoid() 391 | inputs = torch.randn(2, seq_len, 16) 392 | 393 | if use_accelerated: 394 | gates = gates.cuda() 395 | inputs = inputs.cuda() 396 | 397 | output = scan(gates, inputs) 398 | 399 | gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:] 400 | inputs1, inputs2 = inputs[:, :mid_point], inputs[:, mid_point:] 401 | 402 | first_half = scan(gates1, inputs1) 403 | 404 | second_half = scan(gates2, inputs2, prev = first_half[:, -1]) 405 | assert second_half.shape == inputs2.shape 406 | 407 | assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5) 408 | 409 | def test_mem_state_detach(): 410 | from titans_pytorch.neural_memory import mem_state_detach 411 | 412 | mem = NeuralMemory( 413 | dim = 384, 414 | chunk_size = 2, 415 | qk_rmsnorm = True, 416 | dim_head = 64, 417 | heads = 4, 418 | ) 419 | 420 | seq = torch.randn(4, 64, 384) 421 | 422 | state = None 423 | 424 | for 
_ in range(2): 425 | parallel_retrieved, state = mem(seq, state = state) 426 | state = mem_state_detach(state) 427 | parallel_retrieved.sum().backward() 428 | -------------------------------------------------------------------------------- /titans_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from titans_pytorch.neural_memory import ( 2 | NeuralMemory, 3 | NeuralMemState, 4 | mem_state_detach 5 | ) 6 | 7 | from titans_pytorch.memory_models import ( 8 | MemoryMLP, 9 | MemoryAttention, 10 | FactorizedMemoryMLP, 11 | MemorySwiGluMLP, 12 | GatedResidualMemoryMLP 13 | ) 14 | 15 | from titans_pytorch.mac_transformer import ( 16 | MemoryAsContextTransformer 17 | ) 18 | -------------------------------------------------------------------------------- /titans_pytorch/mac_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | from math import ceil 5 | from copy import deepcopy 6 | from functools import partial 7 | from collections import namedtuple 8 | 9 | import tqdm 10 | 11 | import torch 12 | from torch import nn, stack, cat 13 | import torch.nn.functional as F 14 | from torch.nn import Module, ModuleList, Linear 15 | 16 | # flex attention 17 | # https://pytorch.org/blog/flexattention/ 18 | 19 | flex_attention = None 20 | 21 | try: 22 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask 23 | if torch.cuda.is_available(): 24 | flex_attention = torch.compile(flex_attention) 25 | except ImportError: 26 | pass 27 | 28 | def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False): 29 | 30 | def create_mac_mask(_, __, q_idx, kv_idx): 31 | is_persist_mem = kv_idx < persist_mem_len 32 | kv_without_mem = kv_idx - persist_mem_len 33 | causal_mask = q_idx >= kv_without_mem 34 | 35 | if not sliding: 36 | block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size) 37 | causal_mask = causal_mask & block_diagonal 38 | else: 39 | sliding_mask = (q_idx - kv_without_mem) <= window_size 40 | causal_mask = causal_mask & sliding_mask 41 | 42 | return is_persist_mem | (~is_persist_mem & causal_mask) 43 | 44 | block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True) 45 | return block_mask 46 | 47 | # einstein notation related 48 | 49 | from einops import repeat, rearrange, pack, unpack, einsum 50 | from einops.layers.torch import Rearrange 51 | 52 | # b - batch 53 | # n - sequence 54 | # h - heads 55 | # d - feature dimension 56 | 57 | # absolute and relative positions 58 | 59 | from axial_positional_embedding import ContinuousAxialPositionalEmbedding 60 | from rotary_embedding_torch import RotaryEmbedding 61 | 62 | # hyper connections / attend from x-transformers, which handles different queries and key lengths better 63 | 64 | from x_transformers.attend import Attend 65 | 66 | from hyper_connections import get_init_and_expand_reduce_stream_functions 67 | 68 | # proposed neural memory 69 | 70 | from titans_pytorch.neural_memory import NeuralMemory 71 | 72 | # constants 73 | 74 | LinearNoBias = partial(Linear, bias = False) 75 | 76 | AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values')) 77 | 78 | # helpers 79 | 80 | def exists(v): 81 | return v is not None 82 | 83 | def default(v, d): 84 | return v if exists(v) else d 85 | 86 | def identity(t): 87 | return t 
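# padding / segmenting helpers
# `pad_and_segment_with_inverse` right-pads the sequence up to a multiple of the window length,
# optionally folds each window into the batch dimension, and returns an `inverse` callable
# that undoes the fold and (by default) strips the padding on the way back out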
88 | 89 | def divisible_by(num, den): 90 | return (num % den) == 0 91 | 92 | def round_up_multiple(seq, mult): 93 | return ceil(seq / mult) * mult 94 | 95 | def round_down_multiple(seq, mult): 96 | return seq // mult * mult 97 | 98 | def pack_with_inverse(t, pattern): 99 | packed, packed_shape = pack(t, pattern) 100 | 101 | def inverse(out, inv_pattern = None): 102 | return unpack(out, packed_shape, default(inv_pattern, pattern)) 103 | 104 | return packed, inverse 105 | 106 | def pad_at_dim(t, pad, dim = -1, value = 0.): 107 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 108 | zeros = ((0, 0) * dims_from_right) 109 | return F.pad(t, (*zeros, *pad), value = value) 110 | 111 | def pad_and_segment_with_inverse( 112 | seq, 113 | segment_len, 114 | fold_into_batch = True, 115 | inverse_remove_pad = True 116 | ): 117 | batch, seq_len = seq.shape[:2] 118 | next_seq_len_mult = round_up_multiple(seq_len, segment_len) 119 | 120 | padding = next_seq_len_mult - seq_len 121 | needs_pad = padding > 0 122 | 123 | if needs_pad: 124 | seq = F.pad(seq, (0, 0, 0, padding)) 125 | 126 | if fold_into_batch: 127 | seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len) 128 | 129 | def inverse(out): 130 | 131 | if fold_into_batch: 132 | out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch) 133 | 134 | if needs_pad and inverse_remove_pad: 135 | out = out[..., :-padding, :] 136 | 137 | return out 138 | 139 | return seq, inverse 140 | 141 | # sampling related 142 | 143 | def log(t, eps = 1e-20): 144 | return torch.log(t.clamp(min = eps)) 145 | 146 | def gumbel_noise(t): 147 | noise = torch.rand_like(t) 148 | return -log(-log(noise)) 149 | 150 | def gumbel_sample(t, temperature = 1.): 151 | if temperature > 0.: 152 | t = t / temperature + gumbel_noise(t) 153 | return t.argmax(dim = -1, keepdim = True) 154 | 155 | # min_p 156 | # https://arxiv.org/abs/2407.01082 157 | 158 | def min_p_filter(logits, min_p = 0.1): 159 | probs = logits.softmax(dim = -1) 160 | max_probs = probs.amax(dim = -1, keepdim = True) 161 | limit = min_p * max_probs 162 | return torch.where(probs < limit, float('-inf'), logits) 163 | 164 | # feedforward and attention 165 | 166 | class GEGLU(Module): 167 | def forward(self, x): 168 | x, gate = x.chunk(2, dim = -1) 169 | return F.silu(gate) * x 170 | 171 | def FeedForward(dim, mult = 4): 172 | dim_inner = int(dim * mult * 2 / 3) 173 | 174 | return nn.Sequential( 175 | nn.RMSNorm(dim), 176 | nn.Linear(dim, dim_inner * 2), 177 | GEGLU(), 178 | nn.Linear(dim_inner, dim) 179 | ) 180 | 181 | class SegmentedAttention(Module): 182 | def __init__( 183 | self, 184 | dim, 185 | segment_len, 186 | num_persist_mem_tokens = 0, 187 | num_longterm_mem_tokens = 0, 188 | dim_head = 64, 189 | heads = 8, 190 | sliding = False, 191 | accept_value_residual = False, 192 | attend_kwargs: dict = dict(), 193 | use_flex_attn = False 194 | ): 195 | super().__init__() 196 | self.norm = nn.RMSNorm(dim) 197 | 198 | dim_inner = dim_head * heads 199 | 200 | self.rotary_emb = RotaryEmbedding(dim_head) 201 | 202 | self.attend = Attend(causal = True, **attend_kwargs) 203 | 204 | self.to_qkv = LinearNoBias(dim, dim_inner * 3) 205 | self.to_out = LinearNoBias(dim_inner, dim) 206 | 207 | self.to_learned_v_mix = nn.Sequential( 208 | nn.Linear(dim, heads), 209 | Rearrange('b n h -> b h n 1'), 210 | nn.Sigmoid() 211 | ) if accept_value_residual else None 212 | 213 | self.segment_len = segment_len 214 | self.num_longterm_mem_tokens = num_longterm_mem_tokens 215 | 216 | total_segment_len = segment_len + 
num_longterm_mem_tokens 217 | self.total_segment_len = total_segment_len 218 | 219 | self.sliding = sliding # sliding window attn - doubt their non-sliding results being the best. local attention with overlapping windows is very strong 220 | 221 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads) 222 | self.merge_heads = Rearrange('b h n d -> b n (h d)') 223 | 224 | self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head)) 225 | 226 | # flex attn related 227 | 228 | assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available' 229 | self.use_flex_attn = use_flex_attn 230 | 231 | self.segment_len = segment_len 232 | self.num_persist_mem_tokens = num_persist_mem_tokens 233 | 234 | def forward_inference( 235 | self, 236 | token, 237 | cache, 238 | value_residual = None, 239 | output_gating = None, 240 | ): 241 | batch = token.shape[0] 242 | 243 | # attention 244 | 245 | token = self.norm(token) 246 | 247 | q, k, v = self.to_qkv(token).chunk(3, dim = -1) 248 | q, k, v = map(self.split_heads, (q, k, v)) 249 | 250 | # value residual 251 | 252 | orig_v = v 253 | 254 | if exists(self.to_learned_v_mix): 255 | mix = self.to_learned_v_mix(token) 256 | v = v.lerp(value_residual, mix) 257 | 258 | # caching 259 | 260 | ck, cv = cache 261 | k = cat((ck, k), dim = -2) 262 | v = cat((cv, v), dim = -2) 263 | 264 | next_cache = (k, v) 265 | 266 | # relative positions 267 | 268 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k) 269 | 270 | # fold 271 | 272 | q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v)) 273 | 274 | # take care of persistent memory key / values 275 | 276 | pmk, pmv = repeat(self.persistent_memory, 'kv ... 
-> kv b ...', b = k.shape[0]) 277 | 278 | # persistent memory 279 | 280 | k = cat((pmk, k), dim = -2) 281 | v = cat((pmv, v), dim = -2) 282 | 283 | # attention 284 | 285 | out, _ = self.attend(q, k, v) 286 | 287 | out = self.merge_heads(out) 288 | 289 | out = self.to_out(out) 290 | 291 | if exists(output_gating): 292 | out = out * output_gating 293 | 294 | return out, AttnIntermediates(orig_v, next_cache) 295 | 296 | def forward_flex( 297 | self, 298 | seq, 299 | value_residual = None, 300 | flex_attn_fn: Callable | None = None, 301 | output_gating = None, 302 | cache = None 303 | ): 304 | 305 | assert not (exists(value_residual) ^ exists(self.to_learned_v_mix)) 306 | 307 | batch, seq_len = seq.shape[:2] 308 | 309 | # attention 310 | 311 | seq = self.norm(seq) 312 | 313 | q, k, v = self.to_qkv(seq).chunk(3, dim = -1) 314 | q, k, v = map(self.split_heads, (q, k, v)) 315 | 316 | # value residual 317 | 318 | orig_v = v 319 | 320 | if exists(self.to_learned_v_mix): 321 | mix = self.to_learned_v_mix(seq) 322 | v = v.lerp(value_residual, mix) 323 | 324 | # caching 325 | 326 | next_cache = (k, v) 327 | 328 | # take care of persistent memory key / values 329 | 330 | pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch) 331 | 332 | # relative positions 333 | 334 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k) 335 | 336 | # persistent memory 337 | 338 | k = cat((pmk, k), dim = -2) 339 | v = cat((pmv, v), dim = -2) 340 | 341 | # prep flex attention 342 | 343 | if not exists(flex_attn_fn): 344 | block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens, self.sliding) 345 | 346 | flex_attn_fn = partial(flex_attention, block_mask = block_mask) 347 | 348 | # attention 349 | 350 | out = flex_attn_fn(q, k, v) 351 | 352 | out = self.merge_heads(out) 353 | 354 | out = self.to_out(out) 355 | 356 | if exists(output_gating): 357 | out = out * output_gating 358 | 359 | return out, AttnIntermediates(orig_v, next_cache) 360 | 361 | def forward( 362 | self, 363 | seq, 364 | value_residual = None, 365 | flex_attn_fn: Callable | None = None, 366 | disable_flex_attn = False, 367 | output_gating = None, 368 | cache = None 369 | ): 370 | is_inferencing = exists(cache) 371 | 372 | if is_inferencing: 373 | assert seq.shape[-2] == 1 374 | return self.forward_inference(seq, cache, value_residual, output_gating = output_gating) 375 | 376 | if seq.is_cuda and self.use_flex_attn and not disable_flex_attn: 377 | return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache) 378 | 379 | assert not (exists(value_residual) ^ exists(self.to_learned_v_mix)) 380 | 381 | segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens 382 | total_segment_len = segment_len + num_longterm_mem_tokens 383 | 384 | batch, seq_len = seq.shape[:2] 385 | 386 | # auto pad to multiple 387 | 388 | seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len, fold_into_batch = False) 389 | 390 | # attention 391 | 392 | seq = self.norm(seq) 393 | 394 | q, k, v = self.to_qkv(seq).chunk(3, dim = -1) 395 | q, k, v = map(self.split_heads, (q, k, v)) 396 | 397 | # value residual 398 | 399 | orig_v = v 400 | 401 | if exists(self.to_learned_v_mix): 402 | mix = self.to_learned_v_mix(seq) 403 | v = v.lerp(value_residual, mix) 404 | 405 | # caching 406 | 407 | next_cache = tuple(map(inverse_segment, (k, v))) 408 | 409 | # relative positions 410 | 411 | q, k = 
self.rotary_emb.rotate_queries_with_cached_keys(q, k) 412 | 413 | # fold 414 | 415 | q, k, v = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = total_segment_len) for t in (q, k, v)) 416 | 417 | # maybe sliding for cpu 418 | 419 | attend_kwargs = dict() 420 | 421 | if self.sliding: 422 | k, v = tuple(rearrange(t, '(b w) ... -> b w ...', b = batch) for t in (k, v)) 423 | k, v = tuple(pad_at_dim(t, (1, 0), value = 0., dim = 1) for t in (k, v)) 424 | k = cat((k[:, :-1], k[:, 1:]), dim = -2) 425 | v = cat((v[:, :-1], v[:, 1:]), dim = -2) 426 | k, v = tuple(rearrange(t, 'b w ... -> (b w) ...') for t in (k, v)) 427 | 428 | # take care of masking 429 | 430 | idx = torch.arange(seq.shape[-2], device = seq.device) 431 | q_idx = rearrange(idx, '(w n) -> w n', n = total_segment_len) 432 | k_idx = pad_at_dim(q_idx, (1, 0), dim = 0, value = -1e4) 433 | k_idx = cat((k_idx[:-1], k_idx[1:]), dim = -1) 434 | 435 | q_idx = rearrange(q_idx, 'w i -> w i 1') 436 | k_idx = rearrange(k_idx, 'w j -> w 1 j') 437 | 438 | sliding_mask = (q_idx - k_idx) <= total_segment_len 439 | sliding_mask = F.pad(sliding_mask, (self.num_persist_mem_tokens, 0), value = True) 440 | 441 | sliding_mask = repeat(sliding_mask, 'w i j -> (b w) 1 i j', b = batch) 442 | attend_kwargs.update(mask = sliding_mask) 443 | 444 | # take care of persistent memory key / values 445 | 446 | pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0]) 447 | 448 | # persistent memory 449 | 450 | k = cat((pmk, k), dim = -2) 451 | v = cat((pmv, v), dim = -2) 452 | 453 | # attention 454 | 455 | out, _ = self.attend(q, k, v, **attend_kwargs) 456 | 457 | out = self.merge_heads(out) 458 | 459 | out = self.to_out(out) 460 | 461 | out = rearrange(out, '(b w) n d -> b (w n) d', b = batch) 462 | 463 | out = inverse_segment(out) 464 | 465 | if exists(output_gating): 466 | out = out * output_gating 467 | 468 | return out, AttnIntermediates(orig_v, next_cache) 469 | 470 | # MAC transformer 471 | 472 | class MemoryAsContextTransformer(Module): 473 | def __init__( 474 | self, 475 | *, 476 | num_tokens, 477 | dim, 478 | depth, 479 | segment_len, 480 | neural_memory_segment_len = None, 481 | neural_mem_gate_attn_output = False, 482 | neural_memory_add_value_residual = False, 483 | num_longterm_mem_tokens = 0, 484 | num_persist_mem_tokens = 0, 485 | neural_memory_batch_size = None, 486 | neural_memory_qkv_receives_diff_views = False, 487 | dim_head = 64, 488 | heads = 8, 489 | ff_mult = 4, 490 | num_residual_streams = 4, 491 | neural_memory_model: Module | None = None, 492 | neural_memory_kwargs: dict = dict(), 493 | neural_memory_layers: tuple[int, ...] 
| None = None, 494 | use_flex_attn = False, 495 | sliding_window_attn = False, 496 | neural_mem_weight_residual = False, 497 | token_emb: Module | None = None, 498 | ): 499 | super().__init__() 500 | 501 | if not exists(token_emb): 502 | token_emb = nn.Embedding(num_tokens, dim) 503 | 504 | self.token_emb = token_emb 505 | 506 | # absolute positions 507 | 508 | self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2) 509 | 510 | # long term mem tokens 511 | 512 | self.segment_len = segment_len 513 | 514 | self.num_longterm_mem_tokens = num_longterm_mem_tokens 515 | has_longterm_mems = num_longterm_mem_tokens > 0 516 | 517 | self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02) 518 | 519 | # maybe sliding window attn 520 | 521 | self.sliding_window_attn = sliding_window_attn 522 | self.attn_window_size = segment_len + num_longterm_mem_tokens 523 | 524 | # hyper connection 525 | 526 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1) 527 | 528 | self.layers = ModuleList([]) 529 | 530 | self.neural_memory_segment_len = default(neural_memory_segment_len, num_longterm_mem_tokens + segment_len) 531 | 532 | layers = tuple(range(1, depth + 1)) 533 | 534 | neural_memory_layers = default(neural_memory_layers, layers) 535 | 536 | # weight residual related 537 | 538 | self.neural_mem_weight_residual = neural_mem_weight_residual 539 | is_first_neural_mem = True 540 | 541 | # mem, attn, and feedforward layers 542 | 543 | for layer in layers: 544 | is_first = layer == 1 545 | 546 | # attention and feedforward 547 | 548 | attn = SegmentedAttention( 549 | dim = dim, 550 | dim_head = dim_head, 551 | heads = heads, 552 | segment_len = segment_len, 553 | use_flex_attn = use_flex_attn, 554 | accept_value_residual = not is_first, 555 | num_longterm_mem_tokens = num_longterm_mem_tokens, 556 | num_persist_mem_tokens = num_persist_mem_tokens, 557 | sliding = sliding_window_attn 558 | ) 559 | 560 | mem = None 561 | mem_qkv_layer_selector = None 562 | mem_hyper_conn = None 563 | 564 | if layer in neural_memory_layers: 565 | mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output) 566 | 567 | if not is_first and neural_memory_qkv_receives_diff_views: 568 | num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input) 569 | 570 | mem_qkv_layer_selector = nn.Sequential( 571 | nn.RMSNorm(dim), 572 | nn.Linear(dim, 3 * num_layer_choices), 573 | Rearrange('... (views layers) -> views ... 
layers', views = 3), 574 | nn.Softmax(dim = -1) 575 | ) 576 | 577 | mem = NeuralMemory( 578 | dim = dim, 579 | chunk_size = self.neural_memory_segment_len, 580 | batch_size = neural_memory_batch_size, 581 | model = deepcopy(neural_memory_model), 582 | qkv_receives_diff_views = True, 583 | accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem, 584 | **neural_memory_kwargs 585 | ) 586 | 587 | is_first_neural_mem = False 588 | 589 | ff = FeedForward(dim = dim, mult = ff_mult) 590 | 591 | self.layers.append(ModuleList([ 592 | mem_hyper_conn, 593 | init_hyper_conn(), 594 | init_hyper_conn(), 595 | mem_qkv_layer_selector, 596 | mem, 597 | attn, 598 | ff, 599 | ])) 600 | 601 | self.norm = nn.RMSNorm(dim) 602 | 603 | self.to_logits = LinearNoBias(dim, num_tokens) 604 | 605 | # whether to gate the attention output with the retrieved memories 606 | 607 | self.gate_attn_output = neural_mem_gate_attn_output 608 | 609 | # zero for maybe aux loss + device 610 | 611 | self.register_buffer('zero', torch.tensor(0.), persistent = False) 612 | 613 | # flex attn related 614 | 615 | assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available' 616 | self.use_flex_attn = use_flex_attn 617 | 618 | self.num_persist_mem_tokens = num_persist_mem_tokens 619 | 620 | def seq_index_is_longterm( 621 | self, 622 | seq_index 623 | ): 624 | total_segment_len, segment_len = self.attn_window_size, self.segment_len 625 | return ((seq_index % total_segment_len + 1) - segment_len) > 0 626 | 627 | def seq_len_with_longterm_mem( 628 | self, 629 | seq_len 630 | ): 631 | assert seq_len > 0 632 | 633 | segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens 634 | return ((seq_len - 1) // segment_len) * num_mem + seq_len 635 | 636 | @torch.no_grad() 637 | def sample( 638 | self, 639 | prompt: Tensor, 640 | seq_len: int, 641 | temperature = 1.5, 642 | filter_fn: Callable = min_p_filter, 643 | filter_kwargs: dict = dict( 644 | min_p = 0.1, 645 | ), 646 | show_progress = True, 647 | use_cache = False 648 | ): 649 | was_training = self.training 650 | self.eval() 651 | 652 | prompt_seq_len, out = prompt.shape[-1], prompt.clone() 653 | sample_num_times = max(0, seq_len - prompt_seq_len) 654 | 655 | # cache for axial pos, attention, and neural memory 656 | 657 | cache = None 658 | factorized_pos_emb = None 659 | 660 | # precompute factorized pos emb 661 | 662 | if use_cache: 663 | seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len) 664 | 665 | axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,)) 666 | 667 | factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True) 668 | 669 | # sample 670 | 671 | with tqdm.tqdm(total = sample_num_times, disable = not show_progress) as pbar: 672 | 673 | while out.shape[-1] < seq_len: 674 | 675 | logits, next_cache = self.forward( 676 | out, 677 | disable_flex_attn = True, 678 | cache = cache, 679 | return_cache = True, 680 | factorized_pos_emb = factorized_pos_emb 681 | ) 682 | 683 | if use_cache: 684 | cache = next_cache 685 | 686 | if not exists(logits): 687 | continue 688 | 689 | logits = logits[:, -1] 690 | 691 | logits = filter_fn(logits, **filter_kwargs) 692 | sample = gumbel_sample(logits, temperature = temperature) 693 | 694 | out = torch.cat((out, sample), dim = -1) 695 | pbar.update(1) 696 | 697 | self.train(was_training) 698 | 699 | return out[..., prompt_seq_len:] 700 | 701 | def forward( 702 | self, 703 | x, 704 
| return_loss = False, 705 | return_loss_breakdown = False, 706 | disable_flex_attn = False, 707 | cache = None, 708 | return_cache = False, 709 | factorized_pos_emb = None 710 | ): 711 | 712 | if return_loss: 713 | x, labels = x[:, :-1], x[:, 1:] 714 | 715 | # math 716 | 717 | batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size 718 | 719 | seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len) 720 | 721 | # token embedding 722 | 723 | x = self.token_emb(x) 724 | 725 | # intersperse longterm memory 726 | 727 | x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False) 728 | 729 | mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0]) 730 | x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d') 731 | 732 | x = inverse_segment(x) 733 | 734 | # splice out unneeded tokens from padding for longterm mems 735 | 736 | x = x[:, :seq_len_with_mem] 737 | 738 | # apply axial positional embedding 739 | # so intra and inter segment can be more easily discerned by the network 740 | 741 | pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb) 742 | 743 | x = x + pos_emb 744 | 745 | # prep flex attention 746 | 747 | use_flex_attn = x.is_cuda and self.use_flex_attn and not disable_flex_attn 748 | 749 | flex_attn_fn = None 750 | 751 | if use_flex_attn: 752 | block_mask = create_mac_block_mask(seq_len_with_mem, self.attn_window_size, self.num_persist_mem_tokens, self.sliding_window_attn) 753 | flex_attn_fn = partial(flex_attention, block_mask = block_mask) 754 | 755 | # kv caching 756 | 757 | is_inferencing = exists(cache) 758 | 759 | if not exists(cache): 760 | cache = (seq_len_with_mem - 1, None, None) 761 | 762 | inference_seq_index, kv_caches, neural_mem_caches = cache 763 | 764 | kv_caches = iter(default(kv_caches, [])) 765 | neural_mem_caches = iter(default(neural_mem_caches, [])) 766 | 767 | next_kv_caches = [] 768 | next_neural_mem_caches = [] 769 | 770 | # value residual 771 | 772 | value_residual = None 773 | 774 | # neural mem weight residual 775 | 776 | mem_weight_residual = None 777 | 778 | # layers for the neural mem to select the qkv inputs from 779 | 780 | mem_input_layers = [] 781 | 782 | # when inferencing, only do one token at a time 783 | 784 | if is_inferencing: 785 | ind = inference_seq_index 786 | x = x[:, ind:(ind + 1)] 787 | 788 | # expand and reduce streams for hyper connections 789 | 790 | x = self.expand_streams(x) 791 | 792 | for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers: 793 | 794 | retrieved = None 795 | attn_out_gates = None 796 | next_neural_mem_cache = None 797 | 798 | # maybe neural memory 799 | 800 | if exists(mem): 801 | 802 | mem_input, add_residual = mem_hyper_conn(x) 803 | 804 | if not exists(mem_qkv_layer_selector): 805 | qkv_mem_input = stack((mem_input, mem_input, mem_input)) 806 | else: 807 | layers_to_choose_from = stack((mem_input, *mem_input_layers)) 808 | 809 | # let the current `mem_input` select the 3 layers for qkv 810 | 811 | selected = mem_qkv_layer_selector(mem_input) 812 | 813 | qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d') 814 | 815 | retrieved, next_neural_mem_cache = mem.forward( 816 | qkv_mem_input, 817 | state = next(neural_mem_caches, None), 818 | prev_weights = 
mem_weight_residual 819 | ) 820 | 821 | if self.neural_mem_weight_residual: 822 | mem_weight_residual = next_neural_mem_cache.updates 823 | 824 | if self.gate_attn_output: 825 | attn_out_gates = retrieved.sigmoid() 826 | else: 827 | x = add_residual(retrieved) 828 | 829 | # attention 830 | 831 | attn_in, add_residual = attn_hyper_conn(x) 832 | 833 | mem_input_layers.append(attn_in) 834 | 835 | attn_out, (values, next_kv_cache) = attn( 836 | attn_in, 837 | value_residual = value_residual, 838 | disable_flex_attn = disable_flex_attn, 839 | flex_attn_fn = flex_attn_fn, 840 | output_gating = attn_out_gates, 841 | cache = next(kv_caches, None) 842 | ) 843 | 844 | mem_input_layers.append(attn_out) 845 | 846 | value_residual = default(value_residual, values) 847 | 848 | x = add_residual(attn_out) 849 | 850 | # caches 851 | 852 | next_kv_caches.append(next_kv_cache) 853 | next_neural_mem_caches.append(next_neural_mem_cache) 854 | 855 | # feedforward 856 | 857 | ff_in, add_ff_residual = ff_hyper_conn(x) 858 | 859 | mem_input_layers.append(ff_in) 860 | 861 | ff_out = ff(ff_in) 862 | 863 | mem_input_layers.append(ff_out) 864 | 865 | x = add_ff_residual(ff_out) 866 | 867 | # taking care of cache first 868 | # for early return when processing long term mem tokens during inference 869 | 870 | if return_cache: 871 | next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches]) 872 | 873 | # handle kv cache length depending on local attention type 874 | 875 | next_kv_caches = next_kv_caches[..., -attn_window_size:, :] 876 | 877 | kv_cache_length = next_kv_caches.shape[-2] 878 | 879 | if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size): 880 | next_kv_caches = next_kv_caches[..., 0:0, :] 881 | 882 | next_cache = ( 883 | inference_seq_index + 1, 884 | next_kv_caches, 885 | next_neural_mem_caches 886 | ) 887 | 888 | is_longterm_mem = self.seq_index_is_longterm(inference_seq_index) 889 | 890 | if is_inferencing and is_longterm_mem: 891 | return None, next_cache 892 | 893 | # hyper connection reducing of streams 894 | 895 | x = self.reduce_streams(x) 896 | 897 | # excise out the memories 898 | 899 | if not is_inferencing: 900 | 901 | x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False) 902 | 903 | x, _ = inverse_pack_mems(x) 904 | 905 | x = inverse_segment(x) 906 | 907 | x = x[:, :seq_len] 908 | 909 | # to logits 910 | 911 | x = self.norm(x) 912 | 913 | logits = self.to_logits(x) 914 | 915 | if not return_loss: 916 | if not return_cache: 917 | return logits 918 | 919 | return logits, next_cache 920 | 921 | return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels) 922 | -------------------------------------------------------------------------------- /titans_pytorch/memory_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, cat 3 | import torch.nn.functional as F 4 | from torch.nn import Module, ModuleList, Parameter, ParameterList 5 | 6 | from einops import rearrange 7 | 8 | # functions 9 | 10 | def l2norm(t): 11 | return F.normalize(t, dim = -1) 12 | 13 | # norms 14 | 15 | class LayerNorm(Module): 16 | def __init__( 17 | self, 18 | dim 19 | ): 20 | super().__init__() 21 | 22 | self.ln = nn.LayerNorm(dim, elementwise_affine = False) 23 | self.gamma = Parameter(torch.zeros(dim)) 24 | 25 | def forward(self, x): 26 | gamma = self.gamma 27 | 28 | if gamma.ndim == 2: 29 | gamma = rearrange(gamma, 'b d -> b 1 d') 30 | 31 | return 
self.ln(x) * (gamma + 1.) 32 | 33 | # norm + residual wrapper, as used in original TTT paper 34 | # but could be removed 35 | 36 | class ResidualNorm(Module): 37 | def __init__( 38 | self, 39 | dim, 40 | model: Module 41 | ): 42 | super().__init__() 43 | self.norm = LayerNorm(dim) 44 | self.model = model 45 | 46 | def forward(self, x): 47 | 48 | out = self.model(x) 49 | 50 | return self.norm(out) + x 51 | 52 | # memory mlp proposed in TTT 53 | 54 | class MemoryMLP(Module): 55 | def __init__( 56 | self, 57 | dim, 58 | depth, 59 | expansion_factor = 2. 60 | ): 61 | super().__init__() 62 | dim_hidden = int(dim * expansion_factor) 63 | dims = (dim, *((dim_hidden,) * (depth - 1)), dim) 64 | 65 | self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 66 | 67 | for weight in self.weights: 68 | nn.init.xavier_uniform_(weight) 69 | 70 | def forward( 71 | self, 72 | x 73 | ): 74 | for ind, weight in enumerate(self.weights): 75 | is_first = ind == 0 76 | 77 | if not is_first: 78 | x = F.gelu(x) 79 | 80 | x = x @ weight 81 | 82 | return x 83 | 84 | # memory mlp, but with gated residual + final projection 85 | 86 | class GatedResidualMemoryMLP(Module): 87 | def __init__( 88 | self, 89 | dim, 90 | depth, 91 | expansion_factor = 4. 92 | ): 93 | super().__init__() 94 | dim_hidden = int(dim * expansion_factor) 95 | 96 | self.weights = ParameterList([ 97 | ParameterList([ 98 | Parameter(torch.randn(dim, dim_hidden)), 99 | Parameter(torch.randn(dim_hidden, dim)), 100 | Parameter(torch.randn(dim * 2, dim)), 101 | ]) for _ in range(depth) 102 | ]) 103 | 104 | self.final_proj = Parameter(torch.randn(dim, dim)) 105 | 106 | for param in self.parameters(): 107 | nn.init.xavier_uniform_(param) 108 | 109 | def forward( 110 | self, 111 | x 112 | ): 113 | 114 | for weight1, weight2, to_gates in self.weights: 115 | res = x 116 | 117 | hidden = x @ weight1 118 | hidden = F.gelu(hidden) 119 | branch_out = hidden @ weight2 120 | 121 | # gated residual 122 | 123 | gates = cat((branch_out, res), dim = -1) @ to_gates 124 | x = res.lerp(branch_out, gates.sigmoid()) 125 | 126 | return x @ self.final_proj 127 | 128 | # memory mlp with factorized weights 129 | # so can tradeoff capacity for smaller chunk sizes 130 | 131 | class FactorizedMemoryMLP(Module): 132 | def __init__( 133 | self, 134 | dim, 135 | depth, 136 | k = 32 137 | ): 138 | super().__init__() 139 | self.weights = ParameterList([ 140 | ParameterList([ 141 | Parameter(torch.randn(dim, k)), 142 | Parameter(torch.randn(k, dim)), 143 | ]) for _ in range(depth) 144 | ]) 145 | 146 | for weight1, weight2 in self.weights: 147 | nn.init.xavier_uniform_(weight1) 148 | nn.init.xavier_uniform_(weight2) 149 | 150 | def forward( 151 | self, 152 | x 153 | ): 154 | 155 | for ind, (weight1, weight2) in enumerate(self.weights): 156 | is_first = ind == 0 157 | 158 | if not is_first: 159 | x = F.gelu(x) 160 | 161 | x = x @ weight1 @ weight2 162 | 163 | return x 164 | 165 | # an MLP modelled after the popular swiglu ff in modern transformers 166 | 167 | class MemorySwiGluMLP(Module): 168 | def __init__( 169 | self, 170 | dim, 171 | depth = 1, # default to 2 layer MLP from TTT, depth of 2 would be 4 layer MLP, but done as 2 feedforwards with residual 172 | expansion_factor = 4. 
173 | ): 174 | super().__init__() 175 | 176 | dim_inner = int(dim * expansion_factor * 2 / 3) 177 | 178 | weights = [] 179 | 180 | for _ in range(depth): 181 | weights.append(ParameterList([ 182 | Parameter(torch.randn(dim, dim_inner * 2)), 183 | Parameter(torch.randn(dim_inner, dim)), 184 | ])) 185 | 186 | self.weights = ParameterList(weights) 187 | self.norm = LayerNorm(dim) 188 | 189 | def forward(self, x): 190 | 191 | for w1, w2 in self.weights: 192 | residual = x 193 | 194 | x, gates = (x @ w1).chunk(2, dim = -1) 195 | 196 | x = x * F.gelu(gates) 197 | 198 | x = x @ w2 199 | 200 | x = x + residual 201 | 202 | return self.norm(x) 203 | 204 | # improvised attention as memory module 205 | 206 | class MemoryAttention(Module): 207 | def __init__( 208 | self, 209 | dim, 210 | scale = 8., 211 | expansion_factor = 2. 212 | ): 213 | super().__init__() 214 | self.scale = scale 215 | dim_ff_hidden = int(dim * expansion_factor) 216 | 217 | self.weights = ParameterList([ 218 | Parameter(torch.randn(dim, dim)), # queries 219 | Parameter(torch.randn(dim, dim)), # keys 220 | Parameter(torch.randn(dim, dim)), # values 221 | Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1 222 | Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2 223 | ]) 224 | 225 | for weight in self.weights: 226 | nn.init.xavier_uniform_(weight) 227 | 228 | def forward(self, x): 229 | 230 | wq, wk, wv, ffw1, ffw2 = self.weights 231 | 232 | q = l2norm(x @ wq) 233 | k = l2norm(x @ wk) 234 | v = x @ wv 235 | 236 | attn_out = F.scaled_dot_product_attention( 237 | q, k, v, 238 | scale = self.scale, 239 | is_causal = True 240 | ) 241 | 242 | # parallel attention + feedforward block 243 | # as in PaLM + Gpt-J 244 | 245 | h = F.gelu(x @ ffw1) 246 | ff_out = h @ ffw2 247 | 248 | return attn_out + ff_out 249 | -------------------------------------------------------------------------------- /titans_pytorch/neural_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | 4 | import math 5 | from functools import partial 6 | from itertools import zip_longest 7 | from collections import namedtuple 8 | 9 | import torch 10 | from torch import nn, stack, cat, is_tensor, tensor, Tensor 11 | import torch.nn.functional as F 12 | from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict 13 | from torch.func import functional_call, vmap, grad 14 | from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten 15 | 16 | from tensordict import TensorDict 17 | 18 | from assoc_scan import AssocScan 19 | 20 | from titans_pytorch.memory_models import( 21 | MemoryMLP, 22 | ResidualNorm 23 | ) 24 | 25 | import einx 26 | from einops import einsum, rearrange, repeat, reduce, pack, unpack 27 | from einops.layers.torch import Rearrange, Reduce 28 | 29 | """ 30 | ein notation: 31 | b - batch 32 | h - heads 33 | bh - batch and heads 34 | n - sequence 35 | d - feature dimension 36 | c - intra-chunk 37 | w - num memory network weight parameters 38 | o - momentum orders 39 | u - key / value updates - allowing a token to emit multiple key / values 40 | """ 41 | 42 | LinearNoBias = partial(Linear, bias = False) 43 | 44 | # neural mem state related 45 | 46 | NeuralMemState = namedtuple('NeuralMemState', [ 47 | 'seq_index', 48 | 'weights', 49 | 'cache_store_segment', 50 | 'states', 51 | 'updates', 52 | ]) 53 | 54 | def mem_state_detach( 55 | state: NeuralMemState 56 | ): 57 | assert isinstance(state, NeuralMemState) 58 | state = 
tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state)) 59 | return NeuralMemState(*state) 60 | 61 | # functions 62 | 63 | def exists(v): 64 | return v is not None 65 | 66 | def default(*args): 67 | for arg in args: 68 | if exists(arg): 69 | return arg 70 | return None 71 | 72 | def identity(t): 73 | return t 74 | 75 | def xnor(x, y): 76 | return not (x ^ y) 77 | 78 | def divisible_by(num, den): 79 | return (num % den) == 0 80 | 81 | def safe_cat(inputs, dim = -2): 82 | inputs = tuple(filter(exists, inputs)) 83 | 84 | if len(inputs) == 0: 85 | return None 86 | elif len(inputs) == 1: 87 | return inputs[0] 88 | 89 | return cat(inputs, dim = dim) 90 | 91 | def is_empty_tensor(t): 92 | return t.numel() == 0 93 | 94 | def dict_get_value_shapes(td): 95 | return [v.shape for k, v in td.items()] 96 | 97 | def rearrange_dict_values(td, pattern, **kwargs): 98 | return td.apply(lambda t: rearrange(t, pattern, **kwargs)) 99 | 100 | def repeat_dict_values(td, pattern, **kwargs): 101 | return td.apply(lambda t: repeat(t, pattern, **kwargs)) 102 | 103 | def pair(v): 104 | return (v, v) if not isinstance(v, tuple) else v 105 | 106 | def round_down_multiple(seq, mult): 107 | return seq // mult * mult 108 | 109 | def round_up_multiple(seq, mult): 110 | return math.ceil(seq / mult) * mult 111 | 112 | def pad_at_dim(t, pad, dim = -1, value = 0.): 113 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 114 | zeros = ((0, 0) * dims_from_right) 115 | return F.pad(t, (*zeros, *pad), value = value) 116 | 117 | def pack_one_with_inverse(t, pattern): 118 | packed, packed_shape = pack([t], pattern) 119 | 120 | def inverse(out, inv_pattern = None): 121 | inv_pattern = default(inv_pattern, pattern) 122 | return unpack(out, packed_shape, inv_pattern)[0] 123 | 124 | return packed, inverse 125 | 126 | def Sequential(*modules): 127 | modules = [*filter(exists, modules)] 128 | 129 | if len(modules) == 0: 130 | return nn.Identity() 131 | 132 | if len(modules) == 1: 133 | return modules[0] 134 | 135 | return nn.Sequential(*modules) 136 | 137 | # softclamping gradients 138 | 139 | def softclamp_max(t, max_value): 140 | half_max_value = max_value / 2 141 | return ((t / half_max_value).tanh() * half_max_value) + half_max_value 142 | 143 | def softclamp_grad_norm(t, max_value): 144 | if is_empty_tensor(t): 145 | return t 146 | 147 | t, inverse = pack_one_with_inverse(t, 'bn *') 148 | 149 | norm = t.norm(dim = -1, keepdim = True) 150 | clamped_norm = softclamp_max(norm, max_value) 151 | 152 | t = t * (clamped_norm / norm) 153 | return inverse(t) 154 | 155 | # spectral norming the surprise update w/ newton schulz matrix iter 156 | # Keller Jordan et al. 
from OSS w/ nanogpt, now being used for two works, Atlas and 'TTT done right' 157 | 158 | def newtonschulz5( 159 | t, 160 | steps = 5, 161 | eps = 1e-7, 162 | coefs = (3.4445, -4.7750, 2.0315) 163 | ): 164 | if t.ndim <= 3: 165 | return t 166 | 167 | shape = t.shape 168 | should_transpose = shape[-2] > shape[-1] 169 | 170 | if should_transpose: 171 | t = t.transpose(-1, -2) 172 | 173 | t, inv_pack = pack_one_with_inverse(t, '* i j') 174 | t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps) 175 | 176 | a, b, c = coefs 177 | 178 | for _ in range(steps): 179 | A = t @ t.transpose(-1, -2) 180 | B = b * A + c * A @ A 181 | t = a * t + B @ t 182 | 183 | if should_transpose: 184 | t = t.transpose(-1, -2) 185 | 186 | return inv_pack(t) 187 | 188 | # multi head rmsnorm 189 | 190 | class MultiheadRMSNorm(Module): 191 | def __init__(self, dim, heads): 192 | super().__init__() 193 | self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False) 194 | self.gamma = Parameter(torch.zeros(heads, 1, dim)) 195 | 196 | def forward(self, x): 197 | return self.rmsnorm(x) * (self.gamma + 1.) 198 | 199 | # chunk pooling 200 | 201 | class AveragePool(Module): 202 | def __init__( 203 | self, 204 | chunk_size 205 | ): 206 | super().__init__() 207 | self.chunk_size = chunk_size 208 | 209 | def forward( 210 | self, 211 | x, 212 | chunk_size = None 213 | ): 214 | chunk_size = default(chunk_size, self.chunk_size) 215 | return reduce(x, 'b (n c) d -> b n d', 'mean', c = chunk_size) 216 | 217 | class AttentionPool(Module): 218 | def __init__( 219 | self, 220 | dim, 221 | chunk_size 222 | ): 223 | """ 224 | taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else 225 | """ 226 | super().__init__() 227 | self.chunk_size = chunk_size 228 | self.to_attn_logits = nn.Linear(dim, dim) 229 | 230 | # default to average pool 231 | 232 | nn.init.zeros_(self.to_attn_logits.weight) 233 | nn.init.zeros_(self.to_attn_logits.bias) 234 | 235 | def forward( 236 | self, 237 | x, 238 | chunk_size = None 239 | ): 240 | chunk_size = default(chunk_size, self.chunk_size) 241 | 242 | x = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size) 243 | 244 | attn_logits = self.to_attn_logits(x) 245 | 246 | attn = attn_logits.softmax(dim = -2) 247 | 248 | return reduce(x * attn, 'b n c d -> b n d', 'sum') 249 | 250 | # main neural memory 251 | 252 | def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2): 253 | return adaptive_step.sigmoid() * max_lr 254 | 255 | def default_loss_fn(pred, target): 256 | return (pred - target).pow(2).mean(dim = -1) 257 | 258 | class NeuralMemory(Module): 259 | def __init__( 260 | self, 261 | dim, 262 | chunk_size: int | tuple[int, int] = 1, 263 | batch_size = None, 264 | dim_head = None, 265 | heads = 1, 266 | model: Module | None = None, 267 | store_memory_loss_fn: Callable = default_loss_fn, 268 | adaptive_step_transform: Callable | None = None, 269 | default_step_transform_max_lr = 1., 270 | per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network 271 | max_mem_layer_modulation = 1., # max of 10. 
272 | per_head_learned_parameters = True, 273 | attn_pool_chunks = False, 274 | momentum = True, 275 | momentum_order = 1, 276 | learned_momentum_combine = False, 277 | learned_combine_include_zeroth = False, 278 | num_kv_per_token = 1, # whether a single token can do multiple updates to the memory model 279 | qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything 280 | pre_rmsnorm = True, 281 | post_rmsnorm = False, 282 | qk_rmsnorm = False, 283 | max_grad_norm: float | None = None, 284 | use_accelerated_scan = False, 285 | activation: Module | None = None, 286 | init_adaptive_step_bias = None, 287 | init_momentum_bias = None, 288 | init_decay_bias = None, 289 | accept_weight_residual = False, 290 | spectral_norm_surprises = False, 291 | gated_transition = False, 292 | mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed 293 | default_model_kwargs: dict = dict( 294 | depth = 2, 295 | expansion_factor = 4. 296 | ) 297 | ): 298 | super().__init__() 299 | dim_head = default(dim_head, dim) 300 | assert not (heads == 1 and dim_head != dim) 301 | 302 | self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size) 303 | 304 | # batch size 305 | 306 | if exists(batch_size): 307 | assert divisible_by(batch_size, self.store_chunk_size) 308 | 309 | self.batch_size = batch_size 310 | 311 | # associative scan 312 | 313 | self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan) 314 | 315 | # key values receiving different views 316 | 317 | self.qkv_receives_diff_views = qkv_receives_diff_views 318 | 319 | # norms 320 | 321 | self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity() 322 | self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity() 323 | 324 | self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity() 325 | 326 | self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity() 327 | self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity() 328 | 329 | # maybe multi-headed 330 | 331 | dim_inner = dim_head * heads 332 | 333 | self.heads = heads 334 | 335 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads) 336 | self.split_kv_heads = Rearrange('b n (h u d) -> b h (n u) d', h = heads, u = num_kv_per_token) 337 | 338 | self.merge_heads = Rearrange('b h n d -> b n (h d)') 339 | self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity() 340 | 341 | self.retrieve_gate = Sequential( 342 | LinearNoBias(dim, heads), 343 | Rearrange('b n h -> b h n 1'), 344 | nn.Sigmoid() 345 | ) if heads > 1 else None 346 | 347 | # memory model 348 | 349 | if not exists(model): 350 | model = MemoryMLP(dim_head, **default_model_kwargs) 351 | 352 | # validate memory model 353 | 354 | assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now' 355 | 356 | test_shape = (3, 2, dim_head) 357 | 358 | with torch.no_grad(): 359 | try: 360 | test_input = torch.randn(test_shape) 361 | mem_model_output = model(test_input) 362 | except: 363 | raise RuntimeError(f'memory model unable to accept a tensor of shape {test_shape}') 364 | 365 | assert 
mem_model_output.shape == test_shape, 'output of memory model needs to be same shape as input' 366 | 367 | # the memory is the weights of the model 368 | 369 | if mem_model_norm_add_residual: 370 | model = ResidualNorm(dim = dim_head, model = model) 371 | 372 | self.memory_model = model 373 | 374 | mem_model_params = dict(model.named_parameters()) 375 | 376 | self.num_memory_parameter_tensors = len(mem_model_params) 377 | 378 | self.memory_model_parameter_names = [*mem_model_params.keys()] 379 | 380 | memory_model_parameters = [*mem_model_params.values()] 381 | 382 | if per_head_learned_parameters: 383 | memory_model_parameters = [repeat(p, '... -> h ...', h = heads) for p in memory_model_parameters] 384 | 385 | self.init_weight_shape = [p.shape for p in memory_model_parameters] 386 | 387 | self.memory_model_parameters = ParameterList(memory_model_parameters) 388 | self.per_head_learned_parameters = per_head_learned_parameters 389 | 390 | # the chunk size within the paper where adaptive step, momentum, weight decay are shared 391 | 392 | self.chunk_size = chunk_size 393 | 394 | # prepare function for per sample gradients from model above, using torch.func 395 | 396 | def forward_and_loss(params, inputs, loss_weights, target): 397 | pred = functional_call(self.memory_model, params, inputs) 398 | loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|² 399 | weighted_loss = loss * loss_weights 400 | return weighted_loss.sum(), loss 401 | 402 | # two functions 403 | 404 | grad_fn = grad(forward_and_loss, has_aux = True) 405 | 406 | self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0)) 407 | 408 | # queries for retrieving from the model 409 | 410 | self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation) 411 | 412 | # keys and values for storing to the model 413 | 414 | assert num_kv_per_token > 0 415 | 416 | self.to_keys = Sequential( 417 | LinearNoBias(dim, dim_inner * num_kv_per_token), 418 | activation, 419 | ) 420 | 421 | self.to_values = Sequential( 422 | LinearNoBias(dim, dim_inner * num_kv_per_token), 423 | activation, 424 | ) 425 | 426 | self.store_memory_loss_fn = store_memory_loss_fn 427 | 428 | self.num_kv_per_token = num_kv_per_token 429 | 430 | # `chunk_size` refers to chunk size used for storing to memory model weights 431 | 432 | chunk_size = self.store_chunk_size 433 | 434 | # whether to use averaging of chunks, or attention pooling 435 | 436 | assert not (attn_pool_chunks and chunk_size == 1), '`attn_pool_chunks` cannot be set to True if `chunk_size` is set to 1' 437 | 438 | if not attn_pool_chunks: 439 | self.reduce_to_chunk_rep = AveragePool(chunk_size = chunk_size) 440 | else: 441 | self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size) 442 | 443 | # learned adaptive learning rate 444 | 445 | self.to_adaptive_step = Sequential( 446 | nn.Linear(dim, heads * num_kv_per_token), 447 | Rearrange('b n (h u) -> (b h) (n u)', u = num_kv_per_token) 448 | ) 449 | 450 | if not exists(adaptive_step_transform): 451 | adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr) 452 | 453 | self.adaptive_step_transform = adaptive_step_transform 454 | 455 | # momentum related 456 | 457 | self.to_momentum = Sequential( 458 | nn.Linear(dim, heads * momentum_order), 459 | Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order) 460 | ) if momentum else None 461 | 462 | self.momentum_order = momentum_order 463 | self.to_learned_momentum_combine = None 464 | 465 | if 
learned_momentum_combine: 466 | assert momentum 467 | assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth' 468 | 469 | if learned_combine_include_zeroth: 470 | momentum_order += 1 471 | 472 | self.to_learned_momentum_combine = Sequential( 473 | nn.Linear(dim, heads * momentum_order), 474 | Rearrange('b n (h o) -> o (b h) n', h = heads), 475 | nn.Softmax(dim = 0), 476 | ) 477 | 478 | self.learned_combine_include_zeroth = learned_combine_include_zeroth 479 | 480 | # per layer learning rate modulation 481 | 482 | self.to_layer_modulation = Sequential( 483 | nn.Linear(dim, heads * self.num_memory_parameter_tensors), 484 | Rearrange('b n (h w) -> w (b h) n', h = heads), 485 | nn.Sigmoid() 486 | ) if per_parameter_lr_modulation else None 487 | 488 | self.max_mem_layer_modulation = max_mem_layer_modulation 489 | 490 | # learned weight residual 491 | 492 | self.to_learned_weight_residual_mix = Sequential( 493 | nn.Linear(dim, heads), 494 | Rearrange('b n h -> b h n'), 495 | nn.Sigmoid() 496 | ) if accept_weight_residual else None 497 | 498 | # allow for softclamp the gradient norms for storing memories 499 | 500 | self.max_grad_norm = max_grad_norm 501 | 502 | # spectral norming the surprises before update, a la Muon from Jordan et al. 503 | 504 | self.spectral_norm_surprises = spectral_norm_surprises 505 | 506 | # weight decay factor 507 | 508 | self.to_decay_factor = Sequential( 509 | nn.Linear(dim, heads), 510 | Rearrange('b n h -> (b h) n 1') 511 | ) 512 | 513 | # learned transition, as seeing instability when decreasing neural mem batch size 514 | # perhaps it can slowly learn to adjust from early residual to fully transitioning to new weights every batch size 515 | 516 | self.transition_gate = nn.Parameter(tensor(-5.)) if gated_transition else None 517 | 518 | # inits 519 | 520 | if exists(init_adaptive_step_bias): 521 | linear = self.to_adaptive_step[0] 522 | nn.init.zeros_(linear.weight) 523 | nn.init.constant_(linear.bias, init_adaptive_step_bias) 524 | 525 | if exists(init_momentum_bias): 526 | linear = self.to_momentum[0] 527 | nn.init.zeros_(linear.weight) 528 | nn.init.constant_(linear.bias, init_momentum_bias) 529 | 530 | if exists(init_decay_bias): 531 | linear = self.to_decay_factor[0] 532 | nn.init.zeros_(linear.weight) 533 | nn.init.constant_(linear.bias, init_decay_bias) 534 | 535 | # maybe use accelerated scan 536 | 537 | self.use_accelerated_scan = use_accelerated_scan 538 | 539 | self.register_buffer('zero', torch.tensor(0.), persistent = False) 540 | 541 | @property 542 | def memory_model_parameter_dict(self): 543 | return TensorDict(dict(zip(self.memory_model_parameter_names, self.memory_model_parameters))) 544 | 545 | def init_weights( 546 | self, 547 | batch, 548 | ): 549 | if self.per_head_learned_parameters: 550 | weights = repeat_dict_values(self.memory_model_parameter_dict, 'h ... -> (b h) ...', b = batch) 551 | else: 552 | weights = repeat_dict_values(self.memory_model_parameter_dict, '... -> bh ...', bh = batch * self.heads) 553 | 554 | return weights 555 | 556 | def init_momentum( 557 | self, 558 | batch, 559 | ): 560 | zeros = self.memory_model_parameter_dict.clone().zero_() 561 | 562 | if self.per_head_learned_parameters: 563 | zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order) 564 | else: 565 | zeros = repeat_dict_values(zeros, '... 
-> o bh ...', bh = batch * self.heads, o = self.momentum_order) 566 | 567 | return zeros 568 | 569 | def store_memories( 570 | self, 571 | seq, 572 | weights: dict[str, Tensor] | None = None, 573 | past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None, 574 | seq_index = 0, 575 | prev_weights = None, 576 | mask: Tensor | None = None, 577 | return_surprises = True 578 | ): 579 | if self.qkv_receives_diff_views: 580 | _, batch, seq_len = seq.shape[:3] 581 | else: 582 | batch, seq_len = seq.shape[:2] 583 | 584 | # shapes and variables 585 | 586 | heads, chunk_size, num_updates = self.heads, self.store_chunk_size, self.num_kv_per_token 587 | 588 | # curtail sequence by multiple of the chunk size 589 | # only a complete chunk of the sequence provides the memory for the next chunk 590 | 591 | round_down_seq_len = round_down_multiple(seq_len, chunk_size) 592 | num_chunks = round_down_seq_len // chunk_size 593 | 594 | seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :] 595 | 596 | next_seq_len_index = seq_index + round_down_seq_len 597 | 598 | # init weights if needed 599 | # weights of the memory network 600 | 601 | if not exists(weights): 602 | weights = self.init_weights(batch) 603 | 604 | weights = TensorDict(weights) 605 | 606 | # allow for neural memory of a previous layer to influence surprise of current layer 607 | 608 | weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks) 609 | 610 | # initial norm 611 | 612 | seq = self.store_norm(seq) 613 | 614 | # handle keys and values coming from different sequences from hyper connection 615 | 616 | values_seq = seq 617 | 618 | if self.qkv_receives_diff_views: 619 | seq, values_seq = seq 620 | 621 | # derive learned hparams for optimization of memory network 622 | 623 | adaptive_lr = self.to_adaptive_step(seq) 624 | adaptive_lr = self.adaptive_step_transform(adaptive_lr) 625 | 626 | chunked_seq = self.reduce_to_chunk_rep(seq, chunk_size = chunk_size) 627 | 628 | decay_factor = self.to_decay_factor(chunked_seq).sigmoid() 629 | 630 | need_layer_lr_mod = exists(self.to_layer_modulation) and num_chunks > 0 631 | has_momentum = exists(self.to_momentum) 632 | 633 | if has_momentum: 634 | adaptive_momentum = self.to_momentum(chunked_seq).sigmoid() 635 | 636 | learned_combine = exists(self.to_learned_momentum_combine) 637 | 638 | if learned_combine: 639 | combine_momentums = self.to_learned_momentum_combine(chunked_seq) 640 | 641 | if need_layer_lr_mod: 642 | layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation 643 | 644 | # keys and values 645 | 646 | keys = self.to_keys(seq) 647 | values = self.to_values(values_seq) 648 | 649 | # maybe multi head 650 | 651 | keys, values = map(self.split_kv_heads, (keys, values)) 652 | 653 | # maybe keys rmsnorm 654 | 655 | keys = self.k_norm(keys) 656 | 657 | # take care of chunking 658 | 659 | keys, values = tuple(rearrange(t, 'b h (n c u) d -> (b h n) (c u) d', c = chunk_size, u = num_updates) for t in (keys, values)) 660 | 661 | # adaptive lr 662 | 663 | adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates) 664 | 665 | # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions 666 | 667 | if exists(mask): 668 | mask = mask[..., :round_down_seq_len] 669 | mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size) 670 | 671 | adaptive_lr = torch.where(mask, adaptive_lr, 0.) 
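
        # note: the (optionally masked) adaptive lr above is consumed as a per key/value loss
        # weight inside `forward_and_loss`, so the per-sample gradients computed further down
        # already carry the learned step size - the memory update is then just the negated gradient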
672 | 673 | # maybe add previous layer weight 674 | 675 | assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights)) 676 | 677 | if exists(prev_weights): 678 | 679 | start_index = math.ceil(seq_index / chunk_size) 680 | end_index = start_index + num_chunks 681 | 682 | prev_weights = prev_weights.apply(lambda t: t[:, start_index:end_index]) 683 | 684 | if exists(self.to_learned_weight_residual_mix) and num_chunks > 0: 685 | mix = self.to_learned_weight_residual_mix(chunked_seq) 686 | mix = rearrange(mix, 'b h n -> (b h) n') 687 | prev_weights = prev_weights.apply(lambda t: einx.multiply('bh n, bh n ... -> bh n ...', mix, t)) 688 | 689 | weights_for_surprise = weights_for_surprise + prev_weights 690 | 691 | # flatten batch and time if surprise depends on previous layer memory model 692 | 693 | weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...') 694 | 695 | # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module) 696 | 697 | grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values) 698 | 699 | grads = TensorDict(grads) 700 | 701 | # surprises 702 | 703 | adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads) 704 | unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads) 705 | 706 | # maybe softclamp grad norm 707 | 708 | if exists(self.max_grad_norm): 709 | grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm)) 710 | 711 | # restore batch and sequence dimension 712 | 713 | grads = rearrange_dict_values(grads, '(b n) ... -> b n ...', b = batch * heads) 714 | 715 | # maybe per layer modulation 716 | 717 | if need_layer_lr_mod: 718 | grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())}) 719 | 720 | # negative gradients, adaptive lr already applied as loss weight 721 | 722 | surprises = grads.mul(-1) 723 | 724 | # past states 725 | 726 | if not exists(past_state): 727 | # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper 728 | 729 | minibatch_init_weight = weights 730 | init_momentum = self.init_momentum(batch) 731 | 732 | past_state = (minibatch_init_weight, init_momentum) 733 | 734 | past_last_update, past_last_momentum = past_state 735 | 736 | # early return if sequence length less than chunk size 737 | 738 | if num_chunks == 0: 739 | updates = rearrange_dict_values(weights, 'bh ... 
-> bh 1 ...') 740 | next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates) 741 | 742 | output = (updates, next_store_state) 743 | 744 | if not return_surprises: 745 | return output 746 | 747 | return (*output, (unweighted_mem_model_loss, adaptive_lr)) 748 | 749 | # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates 750 | 751 | updates = TensorDict() 752 | 753 | next_last_update = TensorDict() 754 | next_last_momentum = TensorDict() 755 | 756 | for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()): 757 | 758 | update = surprise 759 | 760 | # derive momentum with associative scan - eq (10) 761 | 762 | if has_momentum: 763 | momentum = surprise 764 | 765 | momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum 766 | 767 | last_momentum = past_last_momentum[param_name] 768 | 769 | # go from first order momentum all the way to the Nth 770 | 771 | for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum): 772 | momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper 773 | 774 | momentums.append(momentum) 775 | 776 | momentums = stack(momentums) 777 | 778 | next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1'] 779 | 780 | if learned_combine and self.learned_combine_include_zeroth: 781 | # add the original surprise if learned combination of momentums 782 | momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0) 783 | 784 | if not learned_combine: 785 | update = momentums[-1] 786 | else: 787 | update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...') 788 | 789 | # maybe spectral norm surprises 790 | 791 | if self.spectral_norm_surprises: 792 | update = newtonschulz5(update) 793 | 794 | # use associative scan again for learned forgetting (weight decay) - eq (13) 795 | 796 | update = self.assoc_scan(1. 
- decay_factor, update, prev = last_update, remove_prev = False) 797 | 798 | updates[param_name] = update 799 | next_last_update[param_name] = update[:, -1] 800 | 801 | # determine next state for the storing of memories 802 | 803 | next_state = (next_last_update, next_last_momentum) 804 | 805 | next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates) 806 | 807 | # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back 808 | 809 | if not return_surprises: 810 | return updates, next_store_state 811 | 812 | return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr) 813 | 814 | def retrieve_memories( 815 | self, 816 | seq, 817 | weights: dict[str, Tensor], 818 | ): 819 | chunk_size = self.retrieve_chunk_size 820 | 821 | weights_have_expanded_shape = dict_get_value_shapes(weights) != self.init_weight_shape 822 | 823 | batch, seq_len = seq.shape[:2] 824 | 825 | # auto infer single token decoding, if there are only 1 set of weights and 1 token 826 | 827 | is_one_token = seq_len == 1 828 | is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1 829 | 830 | is_single_token_decode = is_one_token and is_one_weight 831 | 832 | if is_single_token_decode: 833 | chunk_size = 1 834 | 835 | # padding related, for chunked processing 836 | 837 | need_pad = chunk_size > 1 or not is_one_weight 838 | 839 | if need_pad: 840 | seq = pad_at_dim(seq, (1, 0), dim = 1) 841 | 842 | seq_len_plus_one = seq.shape[-2] 843 | 844 | next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size) 845 | 846 | padding = next_seq_len - seq_len_plus_one 847 | seq = pad_at_dim(seq, (0, padding), dim = 1) 848 | 849 | # the parameters of the memory model stores the memories of the key / values 850 | # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper 851 | 852 | weights = TensorDict(weights) 853 | 854 | # pre norm 855 | 856 | seq = self.retrieve_norm(seq) 857 | 858 | # sequence Float['b n d'] to queries 859 | 860 | queries = self.to_queries(seq) 861 | 862 | # maybe multihead 863 | 864 | queries = self.split_heads(queries) 865 | 866 | # maybe qk rmsnorm 867 | 868 | queries = self.q_norm(queries) 869 | 870 | # fetch values from memory model 871 | 872 | if weights_have_expanded_shape: 873 | weights = rearrange_dict_values(weights, 'b n ... 
-> (b n) ...') 874 | 875 | queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size) 876 | 877 | # forward functional call 878 | 879 | values = functional_call(self.memory_model, dict(weights), queries) 880 | 881 | # reconstitute batch dimension 882 | 883 | values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads) 884 | 885 | values = self.multihead_rmsnorm(values) 886 | 887 | # maybe gate 888 | 889 | if exists(self.retrieve_gate): 890 | values = values * self.retrieve_gate(seq) 891 | 892 | # maybe merge heads and combine 893 | 894 | values = self.merge_heads(values) 895 | 896 | values = self.combine_heads(values) 897 | 898 | # restore, pad with empty memory embed 899 | 900 | if need_pad: 901 | values = values[:, 1:] 902 | 903 | return values[:, :seq_len] 904 | 905 | def forward( 906 | self, 907 | seq, 908 | store_seq = None, 909 | state: NeuralMemState | None = None, 910 | detach_mem_state = False, 911 | prev_weights = None, 912 | store_mask: Tensor | None = None, 913 | return_surprises = False, 914 | ttt_batch_size: int | None = None 915 | ): 916 | is_multi_input = self.qkv_receives_diff_views 917 | 918 | # handle single token 919 | 920 | if seq.ndim == 2 or (is_multi_input and seq.ndim == 3): 921 | seq = rearrange(seq, '... b d -> ... b 1 d') 922 | 923 | is_single_token = seq.shape[-2] == 1 924 | 925 | # if different views for qkv, then 926 | 927 | if is_multi_input: 928 | retrieve_seq, seq = seq[0], seq[1:] 929 | else: 930 | retrieve_seq = seq 931 | 932 | # handle previous state init 933 | 934 | if not exists(state): 935 | state = (0, None, None, None, None) 936 | 937 | seq_index, weights, cache_store_seq, past_state, updates = state 938 | 939 | # store 940 | 941 | store_seq = default(store_seq, seq) 942 | 943 | # take care of cache 944 | 945 | if exists(cache_store_seq): 946 | store_seq = safe_cat((cache_store_seq, store_seq)) 947 | 948 | # compute split sizes of sequence 949 | # for now manually update weights to last update at the correct boundaries 950 | 951 | store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size) 952 | 953 | need_update_weights = exists(batch_size) 954 | 955 | # determine split sizes and when to update 956 | 957 | if need_update_weights: 958 | update_after_final_store = divisible_by(seq_index + store_seq_len, batch_size) 959 | 960 | seq_range = torch.arange(store_seq_len) + seq_index + 1 961 | batch_boundary = divisible_by(seq_range, batch_size) 962 | 963 | indices = seq_range[batch_boundary] - seq_index 964 | 965 | indices = F.pad(indices, (1, 0), value = 0) 966 | 967 | if indices[-1] != store_seq_len: 968 | indices = F.pad(indices, (0, 1), value = store_seq_len) 969 | 970 | split_sizes = (indices[1:] - indices[:-1]).tolist() 971 | 972 | assert sum(split_sizes) == store_seq_len 973 | else: 974 | split_sizes = (store_seq_len,) 975 | update_after_final_store = False 976 | 977 | # accumulate updates 978 | 979 | updates = None 980 | 981 | def accum_updates(past_updates, future_updates): 982 | if not exists(past_updates): 983 | return future_updates 984 | 985 | return TensorDict({param_name: cat((past_update[:, :-1], future_update), dim = 1) for (param_name, past_update), (_, future_update) in zip(past_updates.items(), future_updates.items())}) 986 | 987 | # loop through chunks of store sequences 988 | 989 | store_seqs = store_seq.split(split_sizes, dim = -2) 990 | 991 | if exists(store_mask): 992 | store_masks = store_mask.split(split_sizes, dim = -1) 993 | 
else: 994 | store_masks = (None,) * len(split_sizes) 995 | 996 | # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch 997 | 998 | surprises = (None, None) 999 | gate = None 1000 | 1001 | if exists(self.transition_gate): 1002 | gate = self.transition_gate.sigmoid() 1003 | 1004 | for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)): 1005 | is_last = ind == (len(store_seqs) - 1) 1006 | 1007 | # store 1008 | 1009 | next_updates, next_neural_mem_state, chunk_surprises = self.store_memories( 1010 | store_seq_chunk, 1011 | weights, 1012 | seq_index = seq_index, 1013 | past_state = past_state, 1014 | prev_weights = prev_weights, 1015 | mask = maybe_store_mask, 1016 | return_surprises = True 1017 | ) 1018 | 1019 | weights = next_neural_mem_state.weights 1020 | seq_index = next_neural_mem_state.seq_index 1021 | past_state = next_neural_mem_state.states 1022 | 1023 | updates = accum_updates(updates, next_updates) 1024 | 1025 | surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises)) 1026 | 1027 | if is_last and not update_after_final_store: 1028 | continue 1029 | 1030 | # update weights once batch size is fulfilled 1031 | 1032 | last_update, last_momentum = past_state 1033 | 1034 | if exists(gate): 1035 | last_update = TensorDict({param_name: one_weight.lerp(one_last_update, gate) for (param_name, one_weight), (_, one_last_update) in zip(weights.items(), last_update.items())}) 1036 | 1037 | past_state = (last_update, last_momentum) 1038 | 1039 | # set weights to the last updated weights for the last minibatch 1040 | 1041 | weights = last_update 1042 | 1043 | next_neural_mem_state = next_neural_mem_state._replace( 1044 | weights = weights, 1045 | states = past_state, 1046 | ) 1047 | 1048 | next_neural_mem_state = next_neural_mem_state._replace(updates = updates) 1049 | 1050 | # retrieve 1051 | 1052 | if is_single_token: 1053 | last_update, _ = next_neural_mem_state.states 1054 | updates = rearrange_dict_values(last_update, 'b ... 
-> b 1 ...') 1055 | 1056 | retrieved = self.retrieve_memories( 1057 | retrieve_seq, 1058 | updates 1059 | ) 1060 | 1061 | # maybe detach 1062 | 1063 | if detach_mem_state: 1064 | next_neural_mem_state = mem_state_detach(next_neural_mem_state) 1065 | 1066 | # returning 1067 | 1068 | if not return_surprises: 1069 | return retrieved, next_neural_mem_state 1070 | 1071 | return retrieved, next_neural_mem_state, surprises 1072 | -------------------------------------------------------------------------------- /train_mac.py: -------------------------------------------------------------------------------- 1 | import random 2 | import tqdm 3 | import gzip 4 | import numpy as np 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from adam_atan2_pytorch import AdoptAtan2 12 | 13 | from titans_pytorch import ( 14 | MemoryAsContextTransformer, 15 | MemoryMLP, 16 | MemoryAttention 17 | ) 18 | 19 | # constants 20 | 21 | NUM_BATCHES = int(1e5) 22 | BATCH_SIZE = 4 23 | GRADIENT_ACCUMULATE_EVERY = 4 24 | LEARNING_RATE = 2e-4 25 | VALIDATE_EVERY = 100 26 | GENERATE_EVERY = 500 27 | PRIME_LENGTH = 100 28 | GENERATE_LENGTH = 512 29 | SHOULD_GENERATE = True 30 | SEQ_LEN = 512 31 | 32 | # neural memory related 33 | 34 | NEURAL_MEMORY_DEPTH = 2 35 | NUM_PERSIST_MEM = 4 36 | NUM_LONGTERM_MEM = 4 37 | NEURAL_MEM_LAYERS = (2, 4, 6) # layers 2, 4, 6 have neural memory, can add more 38 | NEURAL_MEM_GATE_ATTN_OUTPUT = False 39 | NEURAL_MEM_MOMENTUM = True 40 | NEURAL_MEM_MOMENTUM_ORDER = 1 41 | NEURAL_MEM_QK_NORM = True 42 | NEURAL_MEM_MAX_LR = 1e-1 43 | USE_MEM_ATTENTION_MODEL = False 44 | WINDOW_SIZE = 32 45 | NEURAL_MEM_SEGMENT_LEN = 4 # set smaller for more granularity for learning rate / momentum etc 46 | NEURAL_MEM_BATCH_SIZE = 128 # set smaller to update the neural memory weights more often as it traverses the sequence 47 | SLIDING_WINDOWS = True 48 | STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay 49 | MEMORY_MODEL_PER_LAYER_LEARNED_LR = True 50 | NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper 51 | NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. 
this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS 52 | NEURAL_MEM_SPEC_NORM_SURPRISES = True # applying lessons from Muon optimizer to surprise updates, by spectral norming the surprises 53 | 54 | # experiment related 55 | 56 | PROJECT_NAME = 'titans-mac-transformer' 57 | RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}' 58 | WANDB_ONLINE = False # turn this on to pipe experiment to cloud 59 | 60 | # perf related 61 | 62 | USE_ACCELERATED_SCAN = True 63 | USE_FLEX_ATTN = True 64 | USE_FAST_INFERENCE = False 65 | 66 | # wandb experiment tracker 67 | 68 | import wandb 69 | wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online') 70 | wandb.run.name = RUN_NAME 71 | wandb.run.save() 72 | 73 | # helpers 74 | 75 | def cycle(loader): 76 | while True: 77 | for data in loader: 78 | yield data 79 | 80 | def decode_token(token): 81 | return str(chr(max(32, token))) 82 | 83 | def decode_tokens(tokens): 84 | return ''.join(list(map(decode_token, tokens))) 85 | 86 | # memory model 87 | 88 | if USE_MEM_ATTENTION_MODEL: 89 | neural_memory_model = MemoryAttention( 90 | dim = 64 91 | ) 92 | else: 93 | neural_memory_model = MemoryMLP( 94 | dim = 64, 95 | depth = NEURAL_MEMORY_DEPTH 96 | ) 97 | 98 | # instantiate memory-as-context transformer 99 | 100 | model = MemoryAsContextTransformer( 101 | num_tokens = 256, 102 | dim = 384, 103 | depth = 8, 104 | segment_len = WINDOW_SIZE, 105 | num_persist_mem_tokens = NUM_PERSIST_MEM, 106 | num_longterm_mem_tokens = NUM_LONGTERM_MEM, 107 | neural_memory_layers = NEURAL_MEM_LAYERS, 108 | neural_memory_segment_len = NEURAL_MEM_SEGMENT_LEN, 109 | neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE, 110 | neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT, 111 | neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL, 112 | neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW, 113 | use_flex_attn = USE_FLEX_ATTN, 114 | sliding_window_attn = SLIDING_WINDOWS, 115 | neural_memory_model = neural_memory_model, 116 | neural_memory_kwargs = dict( 117 | dim_head = 64, 118 | heads = 4, 119 | attn_pool_chunks = STORE_ATTN_POOL_CHUNKS, 120 | qk_rmsnorm = NEURAL_MEM_QK_NORM, 121 | momentum = NEURAL_MEM_MOMENTUM, 122 | momentum_order = NEURAL_MEM_MOMENTUM_ORDER, 123 | default_step_transform_max_lr = NEURAL_MEM_MAX_LR, 124 | use_accelerated_scan = USE_ACCELERATED_SCAN, 125 | per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR, 126 | spectral_norm_surprises = NEURAL_MEM_SPEC_NORM_SURPRISES 127 | ) 128 | ).cuda() 129 | 130 | # prepare enwik8 data 131 | 132 | with gzip.open('./data/enwik8.gz') as file: 133 | data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy() 134 | data_train, data_val = np.split(data, [int(90e6)]) 135 | data_train, data_val = map(torch.from_numpy, (data_train, data_val)) 136 | 137 | class TextSamplerDataset(Dataset): 138 | def __init__(self, data, seq_len): 139 | super().__init__() 140 | self.data = data 141 | self.seq_len = seq_len 142 | 143 | def __getitem__(self, index): 144 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 145 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 146 | return full_seq.cuda() 147 | 148 | def __len__(self): 149 | return self.data.size(0) // self.seq_len 150 | 151 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 152 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 153 | train_loader = cycle(DataLoader(train_dataset, 
batch_size = BATCH_SIZE))
154 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
155 | 
156 | # optimizer
157 | 
158 | optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)
159 | 
160 | # training
161 | 
162 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
163 |     model.train()
164 | 
165 |     for __ in range(GRADIENT_ACCUMULATE_EVERY):
166 |         loss = model(next(train_loader), return_loss = True)
167 |         loss.backward()
168 | 
169 |     print(f'training loss: {loss.item()}')
170 |     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
171 |     optim.step()
172 |     optim.zero_grad()
173 |     wandb.log(dict(loss = loss.item()))
174 | 
175 |     if i % VALIDATE_EVERY == 0:
176 |         model.eval()
177 |         with torch.no_grad():
178 |             loss = model(next(val_loader), return_loss = True)
179 |             print(f'validation loss: {loss.item()}')
180 | 
181 |     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
182 |         model.eval()
183 |         inp = random.choice(val_dataset)[:PRIME_LENGTH]
184 |         prime = decode_tokens(inp)
185 |         print(f'{prime} \n\n {"*" * 100}')
186 | 
187 |         sample = model.sample(inp[None, ...], GENERATE_LENGTH, use_cache = USE_FAST_INFERENCE)
188 |         output_str = decode_tokens(sample[0])
189 |         print(output_str)
190 | 
--------------------------------------------------------------------------------
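
A minimal usage sketch of the standalone NeuralMemory module defined in titans_pytorch/neural_memory.py: it runs a chunked forward pass, then continues from the returned NeuralMemState on a second call. The dimensions, chunk size and sequence lengths below are illustrative assumptions, not values taken from the repository; MemoryMLP is imported from the package the same way train_mac.py does.

import torch

from titans_pytorch import MemoryMLP
from titans_pytorch.neural_memory import NeuralMemory

dim = 384  # illustrative sizes only

mem = NeuralMemory(
    dim = dim,
    chunk_size = 64,                     # chunk size used for both storing and retrieving
    model = MemoryMLP(dim, depth = 2)    # the memory network whose weights act as the memory
)

seq = torch.randn(2, 1024, dim)

# forward returns the retrieved memories and a NeuralMemState namedtuple

retrieved, state = mem(seq)

# a later segment can continue from the previous state, as done during chunked processing

next_seq = torch.randn(2, 256, dim)
next_retrieved, state = mem(next_seq, state = state)

print(retrieved.shape, next_retrieved.shape)  # torch.Size([2, 1024, 384]) torch.Size([2, 256, 384])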