├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── recurrent_memory_transformer_pytorch ├── __init__.py ├── attend.py └── recurrent_memory_transformer.py ├── rmt.png ├── setup.py └── train.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Recurrent Memory Transformer - Pytorch 4 | 5 | Implementation of Recurrent Memory Transformer (openreview) in Pytorch. They had a short follow-up paper recently demonstrating that it can copy information across at least 1 million tokens. 6 | 7 | There is no doubt in my mind that RMT would make a stronger RL agent than AdA, which is just a Transformer-XL. Update: see the Recurrent Action Transformer with Memory (RATE) 8 | 9 | Yannic Kilcher paper review 10 | 11 | ## Appreciation 12 | 13 | - Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting-edge artificial intelligence research 14 | 15 | ## Install 16 | 17 | ```bash 18 | $ pip install recurrent-memory-transformer-pytorch 19 | ``` 20 | 21 | ## Usage 22 | 23 | ```python 24 | import torch 25 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer 26 | 27 | model = RecurrentMemoryTransformer( 28 | num_tokens = 20000, # number of tokens 29 | num_memory_tokens = 128, # number of memory tokens - this will determine the bottleneck for information being passed to the future 30 | dim = 512, # model dimensions 31 | depth = 6, # transformer depth 32 | causal = True, # autoregressive or not 33 | dim_head = 64, # dimension per head 34 | heads = 8, # heads 35 | seq_len = 1024, # sequence length of a segment 36 | use_flash_attn = True # whether to use flash attention 37 | ) 38 | 39 | x = torch.randint(0, 256, (1, 1024)) 40 | 41 | logits1, mem1, _ = model(x) # (1, 1024, 20000), (1, 128, 512), None 42 | logits2, mem2, _ = model(x, mem1) # (1, 1024, 20000), (1, 128, 512), None 43 | logits3, mem3, _ = model(x, mem2) # (1, 1024, 20000), (1, 128, 512), None 44 | 45 | # and so on ... 46 | 47 | ``` 48 | 49 | With XL memories 50 | 51 | ```python 52 | import torch 53 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer 54 | 55 | model = RecurrentMemoryTransformer( 56 | num_tokens = 20000, 57 | num_memory_tokens = 128, 58 | dim = 512, 59 | depth = 6, 60 | causal = True, 61 | dim_head = 64, 62 | heads = 8, 63 | seq_len = 1024, 64 | use_flash_attn = True, 65 | use_xl_memories = True, # set this to True 66 | xl_mem_len = 512 # can be shorter than the seq len - I think just having a bit of the past will keep the RMT memories from simply memorizing the immediately preceding text 67 | ) 68 | 69 | x = torch.randint(0, 256, (1, 1024)) 70 | 71 | logits1, mem1, xl_mem1 = model(x) # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)] 72 | logits2, mem2, xl_mem2 = model(x, mem1, xl_memories = xl_mem1) # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)] 73 | logits3, mem3, xl_mem3 = model(x, mem2, xl_memories = xl_mem2) # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)] 74 | 75 | # and so on ...
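# the xl memories returned above are already detached and truncated to xl_mem_len inside the model,
# so rolling over a longer stream is just a matter of threading both kinds of memory forward.
# a minimal sketch follows - the `segments` list is only an illustrative stand-in for real data

segments = [torch.randint(0, 256, (1, 1024)) for _ in range(4)]

mem, xl_mem = None, None

for segment in segments:
    logits, mem, xl_mem = model(segment, mem, xl_memories = xl_mem)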
76 | ``` 77 | 78 | Train on an absurdly long sequence 79 | 80 | ```python 81 | import torch 82 | from recurrent_memory_transformer_pytorch import ( 83 | RecurrentMemoryTransformer, 84 | RecurrentMemoryTransformerWrapper 85 | ) 86 | 87 | model = RecurrentMemoryTransformer( 88 | num_tokens = 256, 89 | num_memory_tokens = 128, 90 | dim = 512, 91 | depth = 6, 92 | seq_len = 1024, 93 | use_flash_attn = True, 94 | causal = True 95 | ) 96 | 97 | model = RecurrentMemoryTransformerWrapper(model).cuda() 98 | 99 | seq = torch.randint(0, 256, (4, 65536)).cuda() # absurdly long sequence, in reality, they curriculum learned this starting with 1 segment to about 7-8 segments 100 | 101 | loss = model(seq, memory_replay_backprop = True) # memory efficient training from memformer paper 102 | 103 | ``` 104 | 105 | ## Todo 106 | 107 | - [ ] move the memory replay backprop into a torch.function, test out bidirectional, then test on a real problem 108 | 109 | - [x] get rotary embeddings working properly with xl memories 110 | - [x] add xl memories, detached 111 | - [x] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift 112 | - [x] make memories being causally masked an option 113 | - [x] add the memory replay backprop technique from memformer paper 114 | - [x] relative positional encoding 115 | 116 | ## Alternatives 117 | 118 | - Block Recurrent Transformer 119 | 120 | - Memformer 121 | 122 | ## Citations 123 | 124 | ```bibtex 125 | @inproceedings{bulatov2022recurrent, 126 | title = {Recurrent Memory Transformer}, 127 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail Burtsev}, 128 | booktitle = {Advances in Neural Information Processing Systems}, 129 | editor = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 130 | year = {2022}, 131 | url = {https://openreview.net/forum?id=Uynr3iPhksa} 132 | } 133 | ``` 134 | 135 | ```bibtex 136 | @misc{bulatov2023scaling, 137 | title = {Scaling Transformer to 1M tokens and beyond with RMT}, 138 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev}, 139 | year = {2023}, 140 | eprint = {2304.11062}, 141 | archivePrefix = {arXiv}, 142 | primaryClass = {cs.CL} 143 | } 144 | ``` 145 | 146 | ```bibtex 147 | @inproceedings{dao2022flashattention, 148 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 149 | author = {Dao, Tri and Fu, Daniel Y. 
and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 150 | booktitle = {Advances in Neural Information Processing Systems}, 151 | year = {2022} 152 | } 153 | ``` 154 | 155 | ```bibtex 156 | @misc{shazeer2020glu, 157 | title = {GLU Variants Improve Transformer}, 158 | author = {Noam Shazeer}, 159 | year = {2020}, 160 | url = {https://arxiv.org/abs/2002.05202} 161 | } 162 | ``` 163 | 164 | ```bibtex 165 | @misc{su2021roformer, 166 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 167 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 168 | year = {2021}, 169 | eprint = {2104.09864}, 170 | archivePrefix = {arXiv}, 171 | primaryClass = {cs.CL} 172 | } 173 | ``` 174 | 175 | ```bibtex 176 | @inproceedings{Wu2020MemformerAM, 177 | title = {Memformer: A Memory-Augmented Transformer for Sequence Modeling}, 178 | author = {Qingyang Wu and Zhenzhong Lan and Kun Qian and Jing Gu and Alborz Geramifard and Zhou Yu}, 179 | booktitle = {AACL/IJCNLP}, 180 | year = {2020} 181 | } 182 | ``` 183 | 184 | ```bibtex 185 | @software{peng_bo_2021_5196578, 186 | author = {PENG Bo}, 187 | title = {BlinkDL/RWKV-LM: 0.01}, 188 | month = {aug}, 189 | year = {2021}, 190 | publisher = {Zenodo}, 191 | version = {0.01}, 192 | doi = {10.5281/zenodo.5196578}, 193 | url = {https://doi.org/10.5281/zenodo.5196578} 194 | } 195 | ``` 196 | 197 | ```bibtex 198 | @misc{ding2021cogview, 199 | title = {CogView: Mastering Text-to-Image Generation via Transformers}, 200 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang}, 201 | year = {2021}, 202 | eprint = {2105.13290}, 203 | archivePrefix = {arXiv}, 204 | primaryClass = {cs.CV} 205 | } 206 | ``` 207 | 208 | ```bibtex 209 | @software{Dayma_DALLE_Mini_2021, 210 | author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata}, 211 | doi = {10.5281/zenodo.5146400}, 212 | license = {Apache-2.0}, 213 | month = {jul}, 214 | title = {{DALL·E Mini}}, 215 | url = {https://github.com/borisdayma/dalle-mini}, 216 | version = {v0.1-alpha}, 217 | year = {2021}} 218 | ``` 219 | 220 | ```bibtex 221 | @inproceedings{anonymous2022normformer, 222 | title = {NormFormer: Improved Transformer Pretraining with Extra Normalization}, 223 | author = {Anonymous}, 224 | booktitle = {Submitted to The Tenth International Conference on Learning Representations }, 225 | year = {2022}, 226 | url = {https://openreview.net/forum?id=GMYWzWztDx5}, 227 | note = {under review} 228 | } 229 | ``` 230 | 231 | ```bibtex 232 | @misc{ding2021erniedoc, 233 | title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer}, 234 | author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang}, 235 | year = {2021}, 236 | eprint = {2012.15688}, 237 | archivePrefix = {arXiv}, 238 | primaryClass = {cs.CL} 239 | } 240 | ``` 241 | 242 | ```bibtex 243 | @article{Zhu2024HyperConnections, 244 | title = {Hyper-Connections}, 245 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 246 | journal = {ArXiv}, 247 | year = {2024}, 248 | volume = {abs/2409.19606}, 249 | url = {https://api.semanticscholar.org/CorpusID:272987528} 250 | } 251 | ``` 252 | 253 | ```bibtex 254 | @inproceedings{Zhou2024ValueRL, 255 | title = {Value Residual Learning For 
Alleviating Attention Concentration In Transformers}, 256 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 257 | year = {2024}, 258 | url = {https://api.semanticscholar.org/CorpusID:273532030} 259 | } 260 | ``` 261 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/recurrent-memory-transformer-pytorch/520a3574c5a00e452d2af3fb1c26f15a3779c8bb/data/enwik8.gz -------------------------------------------------------------------------------- /recurrent_memory_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from recurrent_memory_transformer_pytorch.recurrent_memory_transformer import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper 2 | -------------------------------------------------------------------------------- /recurrent_memory_transformer_pytorch/attend.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import wraps 3 | from packaging import version 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange 10 | 11 | # constants 12 | 13 | Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def once(fn): 21 | called = False 22 | @wraps(fn) 23 | def inner(x): 24 | nonlocal called 25 | if called: 26 | return 27 | called = True 28 | return fn(x) 29 | return inner 30 | 31 | print_once = once(print) 32 | 33 | # main class 34 | 35 | class Attend(nn.Module): 36 | def __init__( 37 | self, 38 | dropout = 0., 39 | causal = False, 40 | use_flash = False 41 | ): 42 | super().__init__() 43 | self.dropout = dropout 44 | self.attn_dropout = nn.Dropout(dropout) 45 | 46 | self.causal = causal 47 | self.register_buffer("mask", None, persistent=False) 48 | 49 | self.use_flash = use_flash 50 | assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 51 | 52 | # determine efficient attention configs for cuda and cpu 53 | 54 | self.cpu_config = Config(True, True, True) 55 | self.cuda_config = None 56 | 57 | if not torch.cuda.is_available() or not use_flash: 58 | return 59 | 60 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 61 | 62 | if device_properties.major == 8 and device_properties.minor == 0: 63 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 64 | self.cuda_config = Config(True, False, False) 65 | else: 66 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 67 | self.cuda_config = Config(False, True, True) 68 | 69 | def get_mask(self, n, device): 70 | if exists(self.mask) and self.mask.shape[-1] >= n: 71 | return self.mask[:n, :n] 72 | 73 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 74 | 
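# cache the freshly built causal mask as a non-persistent buffer (kept out of the state dict),
# so later calls with n no larger than the cached size can reuse the sliced mask returned at the top of this method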
self.register_buffer("mask", mask, persistent=False) 75 | return mask 76 | 77 | def flash_attn(self, q, k, v, mask = None): 78 | _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda 79 | 80 | # Check if mask exists and expand to compatible shape 81 | # The mask is B L, so it would have to be expanded to B H N L 82 | 83 | if exists(mask): 84 | if mask.ndim != 4: 85 | mask = rearrange(mask, 'b j -> b 1 1 j') 86 | 87 | mask = mask.expand(-1, heads, q_len, -1) 88 | 89 | # Check if there is a compatible device for flash attention 90 | 91 | config = self.cuda_config if is_cuda else self.cpu_config 92 | 93 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 94 | 95 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 96 | out = F.scaled_dot_product_attention( 97 | q, k, v, 98 | attn_mask = mask, 99 | dropout_p = self.dropout if self.training else 0., 100 | is_causal = self.causal 101 | ) 102 | 103 | return out 104 | 105 | def forward(self, q, k, v, mask = None): 106 | """ 107 | einstein notation 108 | b - batch 109 | h - heads 110 | n, i, j - sequence length (base sequence length, source, target) 111 | d - feature dimension 112 | """ 113 | 114 | n, device = q.shape[-2], q.device 115 | 116 | scale = q.shape[-1] ** -0.5 117 | 118 | if self.use_flash: 119 | return self.flash_attn(q, k, v, mask = mask) 120 | 121 | # similarity 122 | 123 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale 124 | 125 | # key padding mask 126 | 127 | if exists(mask): 128 | if mask.ndim != 4: 129 | mask = rearrange(mask, 'b j -> b 1 1 j') 130 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 131 | 132 | # causal mask 133 | 134 | if self.causal: 135 | causal_mask = self.get_mask(n, device) 136 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 137 | 138 | # attention 139 | 140 | attn = sim.softmax(dim=-1) 141 | attn = self.attn_dropout(attn) 142 | 143 | # aggregate values 144 | 145 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 146 | 147 | return out 148 | -------------------------------------------------------------------------------- /recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from functools import partial 5 | from itertools import zip_longest 6 | from contextlib import nullcontext 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn import Module, ModuleList 11 | from torch import nn, einsum, Tensor 12 | 13 | from einops import rearrange, repeat, pack, unpack 14 | from einops.layers.torch import Rearrange 15 | 16 | from recurrent_memory_transformer_pytorch.attend import Attend 17 | 18 | from hyper_connections import get_init_and_expand_reduce_stream_functions 19 | 20 | # constants 21 | 22 | Linear = partial(nn.Linear, bias = False) 23 | 24 | # helpers 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | def identity(t, *args, **kwargs): 30 | return t 31 | 32 | def default(*vals): 33 | for val in vals: 34 | if exists(val): 35 | return val 36 | return None 37 | 38 | def eval_decorator(fn): 39 | def inner(self, *args, **kwargs): 40 | was_training = self.training 41 | self.eval() 42 | out = fn(self, *args, **kwargs) 43 | self.train(was_training) 44 | return out 45 | return inner 46 | 47 | def divisible_by(numer, denom): 48 | return (numer % denom) == 0 49 | 50 | # sampling helpers 51 | 52 | def log(t, eps = 1e-20): 53 | return 
torch.log(t.clamp(min = eps)) 54 | 55 | def gumbel_noise(t): 56 | noise = torch.zeros_like(t).uniform_(0, 1) 57 | return -log(-log(noise)) 58 | 59 | def gumbel_sample(t, temperature = 1., dim = -1): 60 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) 61 | 62 | def top_k(logits, thres = 0.9): 63 | k = math.ceil((1 - thres) * logits.shape[-1]) 64 | val, ind = torch.topk(logits, k) 65 | probs = torch.full_like(logits, float('-inf')) 66 | probs.scatter_(1, ind, val) 67 | return probs 68 | 69 | def frac_gradient(t, frac = 1.): 70 | if frac == 1.: 71 | return t 72 | 73 | return t * frac + t.detach() * (1. - frac) 74 | 75 | # rotary embedding 76 | 77 | class RotaryEmbedding(Module): 78 | def __init__(self, dim, theta = 32768): 79 | super().__init__() 80 | inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) 81 | self.register_buffer('inv_freq', inv_freq) 82 | 83 | def forward(self, positions): 84 | freqs = torch.einsum('i , j -> i j', positions, self.inv_freq) 85 | freqs = torch.cat((freqs, freqs), dim = -1) 86 | return freqs 87 | 88 | def rotate_half(x): 89 | x1, x2 = x.chunk(2, dim=-1) 90 | return torch.cat((-x2, x1), dim=-1) 91 | 92 | def apply_rotary_pos_emb(pos, t): 93 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 94 | 95 | # feedforward 96 | 97 | class GEGLU(Module): 98 | def forward(self, x): 99 | x, gate = x.chunk(2, dim = -1) 100 | return x * F.gelu(gate) 101 | 102 | def FeedForward(dim, mult = 4, dropout = 0.): 103 | dim_inner = int(dim * mult * 2 / 3) 104 | 105 | return nn.Sequential( 106 | nn.RMSNorm(dim), 107 | Linear(dim, dim_inner * 2), 108 | GEGLU(), 109 | nn.Dropout(dropout), 110 | Linear(dim_inner, dim) 111 | ) 112 | 113 | # attention 114 | 115 | class Attention(Module): 116 | def __init__( 117 | self, 118 | *, 119 | dim, 120 | causal = False, 121 | dim_head = 64, 122 | heads = 8, 123 | dropout = 0., 124 | accept_value_residual = False, 125 | use_flash_attn = False, 126 | use_custom_causal_attn_mask = False 127 | ): 128 | super().__init__() 129 | self.norm = nn.RMSNorm(dim) 130 | 131 | dim_inner = dim_head * heads 132 | self.heads = heads 133 | 134 | self.attend = Attend( 135 | causal = causal and not use_custom_causal_attn_mask, 136 | dropout = dropout, 137 | use_flash = use_flash_attn 138 | ) 139 | 140 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head)) 141 | 142 | self.to_q = Linear(dim, dim_inner) 143 | self.to_kv = Linear(dim, dim_inner * 2) 144 | self.to_out = Linear(dim_inner, dim) 145 | 146 | # learned value residual mixing 147 | 148 | self.learned_value_residual_mix = None 149 | 150 | if accept_value_residual: 151 | self.learned_value_residual_mix = nn.Sequential( 152 | Linear(dim, heads), 153 | Rearrange('b n h -> b h n 1'), 154 | nn.Sigmoid() 155 | ) 156 | 157 | def forward( 158 | self, 159 | x, 160 | rotary_emb: tuple[Tensor, Tensor] | None = None, 161 | mask = None, 162 | xl_memories = None, 163 | value_residual = None 164 | ): 165 | assert not (exists(value_residual) ^ exists(self.learned_value_residual_mix)) 166 | 167 | h = self.heads 168 | x = self.norm(x) 169 | 170 | q = self.to_q(x) 171 | k, v = self.to_kv(x).chunk(2, dim = -1) 172 | 173 | # split heads 174 | 175 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 176 | 177 | # handle value residual 178 | 179 | orig_v = v 180 | 181 | if exists(self.learned_value_residual_mix): 182 | mix = self.learned_value_residual_mix(x) 183 | v = v.lerp(value_residual, mix) 184 | 185 | # add a null key / value 186 | # to protect 
against an entirely masked out sequence 187 | # as well as giving attention ability to attend to nothing 188 | 189 | nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = x.shape[0]), self.null_kv) 190 | 191 | k = torch.cat((nk, k), dim = -2) 192 | v = torch.cat((nv, v), dim = -2) 193 | 194 | if exists(mask): 195 | mask = F.pad(mask, (1, 0), value = True) 196 | 197 | # manage memories 198 | 199 | next_xl_memories = torch.stack((k, v)) 200 | 201 | if exists(xl_memories): 202 | kx, vx = xl_memories 203 | k = torch.cat((kx, k), dim = -2) 204 | v = torch.cat((vx, v), dim = -2) 205 | 206 | if exists(mask): 207 | mask = F.pad(mask, (xl_memories.shape[-2], 0), value = True) 208 | 209 | if exists(rotary_emb): 210 | q_rotary_emb, k_rotary_emb = rotary_emb 211 | 212 | q = apply_rotary_pos_emb(q_rotary_emb, q) 213 | k = apply_rotary_pos_emb(k_rotary_emb, k) 214 | 215 | out = self.attend(q, k, v, mask = mask) 216 | 217 | out = rearrange(out, 'b h n d -> b n (h d)') 218 | 219 | return self.to_out(out), next_xl_memories, orig_v 220 | 221 | # transformer 222 | 223 | class RecurrentMemoryTransformer(Module): 224 | def __init__( 225 | self, 226 | dim, 227 | *, 228 | num_tokens, 229 | depth, 230 | num_memory_tokens, 231 | seq_len, 232 | causal = True, 233 | dim_head = 64, 234 | heads = 8, 235 | ff_mult = 4, 236 | attn_dropout = 0., 237 | ff_dropout = 0., 238 | use_flash_attn = False, 239 | ignore_index = -1, 240 | abs_pos_emb = True, 241 | rotary_pos_emb = False, 242 | use_xl_memories = True, 243 | xl_mem_len = None, 244 | enhanced_xl_recurrence = False, # add simple method for enhancing receptive field of xl memories, from ernie-doc paper 245 | emb_gradient_frac = 0.1, # trick from cogview paper that leads to a bit more stability 246 | memory_not_causal = True, # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on 247 | add_write_to_next_write_mem = False, # add the write memories of previous step to the next write step - thanks to @IcarusWizard for pointing out this discrepancy 248 | next_write_mem_stop_grad = True, # whether to stop gradient of previous read memory -> next write memory 249 | always_have_read_memories = True, # whether to always have read memories, even on the first step, so to make the model onnx-able 250 | num_residual_streams = 4 # number of residual streams for hyper connections 251 | ): 252 | super().__init__() 253 | self.causal = causal 254 | self.seq_len = seq_len 255 | 256 | self.emb_gradient_frac = emb_gradient_frac 257 | 258 | assert num_memory_tokens > 0 259 | 260 | self.token_emb = nn.Embedding(num_tokens, dim) 261 | 262 | # positions 263 | 264 | assert any([abs_pos_emb, rotary_pos_emb]) 265 | 266 | self.pos_emb = nn.Embedding(seq_len, dim) if abs_pos_emb else None 267 | 268 | self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else None 269 | 270 | # memory related 271 | 272 | self.num_memory_tokens = num_memory_tokens 273 | 274 | self.read_memory_emb = nn.Parameter(torch.zeros(num_memory_tokens, dim)) 275 | nn.init.normal_(self.read_memory_emb, std = 0.02) 276 | 277 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 278 | nn.init.normal_(self.memory_tokens, std = 0.02) 279 | 280 | # xl memories 281 | 282 | xl_mem_len = default(xl_mem_len, seq_len) 283 | assert xl_mem_len <= seq_len 284 | self.xl_mem_len = xl_mem_len 285 | 286 | self.use_xl_memories = use_xl_memories 287 | self.enhanced_xl_recurrence = 
enhanced_xl_recurrence 288 | 289 | # hyper connections 290 | 291 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 292 | 293 | # layers 294 | 295 | self.layers = ModuleList([]) 296 | 297 | for layer_index in range(depth): 298 | is_first = layer_index == 0 299 | 300 | self.layers.append(ModuleList([ 301 | init_hyper_conn(dim = dim, branch = Attention( 302 | dim = dim, 303 | dim_head = dim_head, 304 | causal = causal, 305 | heads = heads, 306 | use_flash_attn = use_flash_attn, 307 | accept_value_residual = not is_first, 308 | use_custom_causal_attn_mask = memory_not_causal, 309 | dropout = attn_dropout 310 | )), 311 | init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)), 312 | ])) 313 | 314 | self.norm = nn.RMSNorm(dim) 315 | self.to_logits = nn.Linear(dim, num_tokens) 316 | 317 | self.ignore_index = ignore_index 318 | 319 | # whether to use custom attention mask if causal and memory should not be causal 320 | 321 | self.use_custom_causal_attn_mask = causal and memory_not_causal 322 | 323 | # in the paper, they actually also use the previous write memories for the next write memories 324 | 325 | self.add_write_to_next_write_mem = add_write_to_next_write_mem 326 | self.next_write_mem_stop_grad = next_write_mem_stop_grad 327 | 328 | # allow for attending to raw read memory positional embeddings on first step 329 | # hack to make it onnx-able and should not hurt 330 | 331 | self.always_have_read_memories = always_have_read_memories 332 | 333 | def init_memory(self, batch): 334 | return repeat(self.memory_tokens, 'm d -> b m d', b = batch) 335 | 336 | def forward( 337 | self, 338 | x, 339 | read_memories = None, 340 | *, 341 | mask = None, 342 | labels = None, 343 | xl_memories: list[Tensor] | None = None, 344 | mask_out_read_memories = False # in the case one is passing in 0s for read memories, for onnx-able model 345 | ): 346 | has_xl_memories = exists(xl_memories) and len(xl_memories) > 0 347 | 348 | b, n, device, mem_length, return_loss = *x.shape, x.device, self.num_memory_tokens, exists(labels) 349 | 350 | assert n <= self.seq_len 351 | 352 | pos = torch.arange(n, device = device) 353 | 354 | x = self.token_emb(x) 355 | 356 | # maybe absolute positional embedding 357 | 358 | if exists(self.pos_emb): 359 | x = x + self.pos_emb(pos) 360 | 361 | # trick from cogview paper 362 | 363 | x = frac_gradient(x, self.emb_gradient_frac) 364 | 365 | # prepare write memories, as in paper 366 | 367 | write_memories = self.init_memory(b) 368 | 369 | if exists(read_memories) and self.add_write_to_next_write_mem: 370 | maybe_detach = torch.detach if self.next_write_mem_stop_grad else identity 371 | write_memories = write_memories + maybe_detach(read_memories) 372 | 373 | # prepare read memories 374 | 375 | if exists(read_memories): 376 | if read_memories.ndim == 2: 377 | read_memories = repeat(read_memories, 'n d -> b n d', b = b) 378 | 379 | read_mem_length = mem_length 380 | read_memories = read_memories + self.read_memory_emb 381 | elif self.always_have_read_memories: 382 | read_mem_length = mem_length 383 | read_memories = repeat(self.read_memory_emb, 'n d -> b n d', b = b) 384 | else: 385 | read_mem_length = 0 386 | read_memories = x[:, 0:0] 387 | 388 | # concat to main sequence using einop's pack 389 | 390 | x, ps = pack([read_memories, x, write_memories], 'b * d') 391 | 392 | # take care of mask 393 | 394 | if exists(mask): 395 | mask = 
F.pad(mask, (read_mem_length, mem_length), value = True) 396 | 397 | # custom causal mask, if needed 398 | 399 | if self.use_custom_causal_attn_mask: 400 | causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).tril() 401 | 402 | causal_mask = F.pad(causal_mask, (0, mem_length, read_mem_length, 0), value = False) 403 | causal_mask = F.pad(causal_mask, (read_mem_length, 0, 0, mem_length), value = True) 404 | 405 | causal_mask = rearrange(causal_mask, 'i j -> 1 1 i j') 406 | 407 | if exists(mask): 408 | mask = rearrange(mask, 'b j -> b 1 1 j') 409 | mask = mask & causal_mask 410 | else: 411 | mask = causal_mask 412 | 413 | # masking out read memories, either for passing in 0s for read memories on first step, or if you are doing some regularization game on the memories 414 | 415 | if read_mem_length > 0 and mask_out_read_memories: 416 | read_mem_mask = torch.arange(x.shape[-2], device = device) < read_mem_length 417 | 418 | if exists(mask): 419 | mask = mask & ~read_mem_mask 420 | else: 421 | mask = read_mem_mask 422 | 423 | # rotary embedding - offset main positions by 10000, and keep all memories at position 0 424 | 425 | rotary_emb = None 426 | 427 | if exists(self.rotary_pos_emb): 428 | mem_rel_dist = 10000 429 | 430 | q_pos = pos + mem_rel_dist 431 | 432 | if has_xl_memories: 433 | xl_mem_length = xl_memories[0].shape[-2] 434 | q_pos += xl_mem_length 435 | 436 | q_pos = F.pad(q_pos, (read_mem_length, mem_length), value = 0) 437 | q_rotary_emb = self.rotary_pos_emb(q_pos) 438 | 439 | # kind of confusing at the moment 440 | # but the order of the keys are - [xl memories] [read memories] [main sequence] [ write memories] 441 | # so the positions are (say xl memory length of 256) - [10001, 10002, 10003 ...] [0, 0, ...] [10256, 10257, ...] [0, 0, ...] 
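# rotary embedding at position 0 is the identity (cos 0 = 1, sin 0 = 0), so pinning the read / write memories
# to position 0 keeps them free of any relative positional bias, while the xl memories and the main sequence
# remain correctly ordered with respect to one another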
442 | 443 | if has_xl_memories: 444 | k_pos = torch.arange(xl_mem_length, device = device) + mem_rel_dist 445 | k_pos = torch.cat((k_pos, q_pos), dim = -1) 446 | else: 447 | k_pos = q_pos 448 | 449 | # account for null key / value 450 | 451 | k_pos = F.pad(k_pos, (1, 0), value = mem_rel_dist - 1) # give a null memory token, to allow for attending to nothing 452 | 453 | k_rotary_emb = self.rotary_pos_emb(k_pos) 454 | 455 | rotary_emb = (q_rotary_emb, k_rotary_emb) 456 | 457 | # prepare xl memories 458 | 459 | xl_memories = default(xl_memories, []) 460 | xl_memories_iter = iter(xl_memories) 461 | new_xl_memories = [] 462 | 463 | if has_xl_memories and self.enhanced_xl_recurrence and len(xl_memories) > 1: # simply shift all the xl memories down by one, so lower layer gets access to representations from layer above 464 | xl_memories = [*xl_memories[1:], xl_memories[0]] 465 | 466 | # value residual 467 | 468 | value_residual = None 469 | 470 | # expand streams for hyper connections 471 | 472 | x = self.expand_streams(x) 473 | 474 | # attention and feedforward 475 | 476 | for attn, ff in self.layers: 477 | x, xl_memories, attn_values = attn(x, mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb, value_residual = value_residual) 478 | 479 | value_residual = default(value_residual, attn_values) 480 | new_xl_memories.append(xl_memories) 481 | 482 | x = ff(x) 483 | 484 | # reduce streams for hyper connections 485 | 486 | x = self.reduce_streams(x) 487 | 488 | # final norm 489 | 490 | x = self.norm(x) 491 | 492 | # whether to return xl memories 493 | 494 | next_xl_memories = None 495 | 496 | if self.use_xl_memories: 497 | next_xl_memories = list(map(lambda t: torch.detach(t[..., -self.xl_mem_len:, :]), new_xl_memories)) 498 | 499 | # split out memories using unpack 500 | 501 | read_memories, x, write_memories = unpack(x, ps, 'b * d') 502 | 503 | # to logits 504 | 505 | logits = self.to_logits(x) 506 | 507 | if not return_loss: 508 | return logits, write_memories, next_xl_memories 509 | 510 | loss = F.cross_entropy( 511 | rearrange(logits, 'b n c -> b c n'), 512 | labels, 513 | ignore_index = self.ignore_index 514 | ) 515 | 516 | return loss, write_memories, next_xl_memories 517 | 518 | # wrapper to manage many segments 519 | 520 | class RecurrentMemoryTransformerWrapper(Module): 521 | def __init__( 522 | self, 523 | transformer: RecurrentMemoryTransformer, 524 | truncate_at_step = None # number of steps before detaching memories (truncated bptt). 
with memory replay checkpointing, there should be no memory issues, but in case of instability, as reported in initial paper 525 | ): 526 | super().__init__() 527 | self.transformer = transformer 528 | self.seq_len = transformer.seq_len 529 | self.truncate_at_step = truncate_at_step 530 | 531 | @torch.no_grad() 532 | @eval_decorator 533 | def generate( 534 | self, 535 | prime, 536 | *, 537 | length, 538 | memories = None, 539 | xl_memories: list[Tensor] | None = None, 540 | temperature = 1., 541 | filter_thres = 0.9 542 | ): 543 | assert self.transformer.causal, 'only autoregressive transformers can generate' 544 | 545 | start_len, seq_len = prime.shape[-1], self.seq_len 546 | 547 | assert length >= start_len 548 | 549 | *past_segments, curr_segment = prime.split(seq_len, dim = -1) 550 | 551 | # catch memories up to the current segment 552 | 553 | for past_segment in past_segments: 554 | _, memories, xl_memories = self.transformer(past_segment, memories, xl_memories = xl_memories) 555 | 556 | # sample for the remaining length 557 | 558 | for ind in range(length - start_len): 559 | logits, next_memories, next_xl_memories = self.transformer(curr_segment, memories, xl_memories = xl_memories) 560 | 561 | logits = logits[:, -1] 562 | 563 | filtered_logits = top_k(logits, thres = filter_thres) 564 | sampled = gumbel_sample(filtered_logits, temperature = temperature) 565 | sampled = rearrange(sampled, 'b -> b 1') 566 | 567 | curr_segment = torch.cat((curr_segment, sampled), dim = -1) 568 | 569 | if divisible_by(curr_segment.shape[-1] - 1, seq_len): 570 | memories = next_memories 571 | xl_memories = next_xl_memories 572 | 573 | past_segment, curr_segment = curr_segment[..., :seq_len], curr_segment[..., -1:] 574 | past_segments.append(past_segment) 575 | 576 | # add current segment to all segments 577 | 578 | past_segments.append(curr_segment) 579 | 580 | # reconcat all segments 581 | 582 | output = torch.cat(past_segments, dim = -1) 583 | 584 | output = output[:, start_len:] 585 | return output 586 | 587 | def forward( 588 | self, 589 | x, 590 | memories = None, 591 | *, 592 | mask = None, 593 | xl_memories: list[Tensor] | None = None, 594 | return_loss = False, 595 | labels = None, 596 | truncate_at_step = None, # if set, this would override the truncate_at_step at init 597 | memory_replay_backprop = False, # whether to have the class do the backwards pass memory efficiently 598 | mrbp_loss_weight = 1. # if using memory replay backprop with gradient accumulation, scale loss by this factor ex. (1. 
/ number of gradient accumulation steps) 599 | ): 600 | seq_len, truncate_at_step = self.seq_len, default(truncate_at_step, self.truncate_at_step) 601 | 602 | # derive labels from the input by shifting, if labels were not passed in 603 | if (return_loss or memory_replay_backprop) and not exists(labels): 604 | x, labels = x[:, :-1], x[:, 1:] 605 | 606 | # segment input 607 | 608 | segments = x.split(seq_len, dim = -1) 609 | total_length = x.shape[-1] 610 | num_segments = len(segments) 611 | segment_length_frac = tuple(map(lambda t: t.shape[-1] / total_length, segments)) 612 | 613 | # default values 614 | 615 | label_segments = mask_segments = (None,) 616 | 617 | # take care of labels 618 | 619 | if exists(labels): 620 | label_segments = labels.split(seq_len, dim = -1) 621 | 622 | # take care of the mask 623 | 624 | if exists(mask): 625 | mask_segments = mask.split(seq_len, dim = -1) 626 | 627 | # keep replay buffer 628 | 629 | replay_buffer = [memories] 630 | 631 | # replay buffer for xl memories 632 | 633 | xl_segments = [xl_memories] 634 | 635 | # decide context of forward depending on whether doing memory-replay-backprop 636 | 637 | forward_context = nullcontext if not memory_replay_backprop else torch.no_grad 638 | 639 | # forward and get all outputs (can be either loss or logits) 640 | 641 | logits = [] 642 | losses = [] 643 | 644 | for step, (segment, mask_segment, label_segment, loss_weight) in enumerate(zip_longest(segments, mask_segments, label_segments, segment_length_frac)): 645 | 646 | with forward_context(): 647 | output, memories, xl_memories = self.transformer(segment, memories, mask = mask_segment, labels = label_segment) 648 | 649 | if exists(truncate_at_step) and divisible_by(step + 1, truncate_at_step): 650 | memories = memories.detach() 651 | 652 | replay_buffer.append(memories) 653 | 654 | xl_segments.append(xl_memories) 655 | 656 | if return_loss: 657 | losses.append(output * loss_weight) 658 | else: 659 | logits.append(output) 660 | 661 | # whether to do memory replay backpropagation 662 | 663 | # https://arxiv.org/abs/2010.06891 664 | # algorithm 1 665 | 666 | if memory_replay_backprop: 667 | memories_grad = torch.zeros_like(replay_buffer[-1]) 668 | 669 | reversed_inputs = zip_longest(*map(reversed, [ 670 | range(num_segments), 671 | segments, 672 | replay_buffer[:-1], 673 | xl_segments[:-1], 674 | mask_segments, 675 | label_segments, 676 | segment_length_frac, 677 | ])) 678 | 679 | total_loss = 0.
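# the segments are then replayed in reverse with gradients enabled: each segment's loss is backpropagated,
# and memories_grad carries the gradient flowing back from the later segments through this segment's
# write memories (memory replay backprop, algorithm 1 of the paper linked above)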
680 | 681 | for step, segment, segment_memories, segment_xl_memories, mask_segment, label_segment, loss_weight in reversed_inputs: 682 | is_first = step == 0 683 | 684 | if exists(segment_memories): 685 | segment_memories.requires_grad_() 686 | 687 | loss, next_segment_memories, _ = self.transformer(segment, segment_memories, mask = mask_segment, xl_memories = segment_xl_memories, labels = label_segment) 688 | 689 | weighted_loss = loss * loss_weight * mrbp_loss_weight 690 | 691 | weighted_loss.backward(retain_graph = True) 692 | 693 | next_segment_memories.backward(memories_grad) 694 | 695 | total_loss += weighted_loss 696 | 697 | if is_first: 698 | continue 699 | 700 | if exists(truncate_at_step) and divisible_by(step, truncate_at_step): 701 | memories_grad.zero_() 702 | else: 703 | memories_grad.copy_(segment_memories.grad.data) 704 | 705 | return total_loss 706 | 707 | # return logits if needed 708 | 709 | if not return_loss: 710 | logits = torch.cat(logits, dim = -2) 711 | return logits, memories 712 | 713 | # otherwise return losses 714 | 715 | return sum(losses), memories 716 | -------------------------------------------------------------------------------- /rmt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/recurrent-memory-transformer-pytorch/520a3574c5a00e452d2af3fb1c26f15a3779c8bb/rmt.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'recurrent-memory-transformer-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.7.0', 7 | license='MIT', 8 | description = 'Recurrent Memory Transformer - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/recurrent-memory-transformer-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'recurrence', 19 | 'memory', 20 | 'long-context' 21 | ], 22 | install_requires=[ 23 | 'einops>=0.8.0', 24 | 'hyper-connections>=0.1.7', 25 | 'torch>=2.3', 26 | ], 27 | classifiers=[ 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Developers', 30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3.6', 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import random 3 | import tqdm 4 | import numpy as np 5 | 6 | import torch 7 | from torch.optim import Adam 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 4 17 | GRADIENT_ACCUMULATE_EVERY = 4 18 | LEARNING_RATE = 1e-4 19 | VALIDATE_EVERY = 100 20 | PRIME_LENGTH = 128 21 | GENERATE_EVERY = 250 22 | GENERATE_LENGTH = 2048 23 | SEQ_LEN = 2048 24 | 25 | # helpers 26 | 27 | def cycle(loader): 28 | while True: 29 | for data in loader: 30 | yield data 31 | 32 | def decode_token(token): 33 | return str(chr(max(32, 
token))) 34 | 35 | def decode_tokens(tokens): 36 | return "".join(list(map(decode_token, tokens))) 37 | 38 | 39 | # instantiate recurrent memory transformer and wrapper 40 | 41 | model = RecurrentMemoryTransformer( 42 | num_tokens = 256, 43 | dim = 512, 44 | depth = 6, 45 | dim_head = 64, 46 | heads = 8, 47 | seq_len = 512, 48 | use_flash_attn = True, 49 | num_memory_tokens = 128, 50 | use_xl_memories = True, 51 | xl_mem_len = 256 52 | ) 53 | 54 | model = RecurrentMemoryTransformerWrapper(model) 55 | 56 | model.cuda() 57 | 58 | # prepare enwik8 data 59 | 60 | with gzip.open("./data/enwik8.gz") as file: 61 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 62 | np_train, np_valid = np.split(data, [int(90e6)]) 63 | data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid) 64 | 65 | class TextSamplerDataset(Dataset): 66 | def __init__(self, data, seq_len): 67 | super().__init__() 68 | self.data = data 69 | self.seq_len = seq_len 70 | 71 | def __getitem__(self, index): 72 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 73 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() 74 | return full_seq.cuda() 75 | 76 | def __len__(self): 77 | return self.data.size(0) // self.seq_len 78 | 79 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 80 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 81 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE)) 82 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE)) 83 | 84 | # optimizer 85 | 86 | optim = Adam(model.parameters(), lr = LEARNING_RATE) 87 | 88 | # training 89 | 90 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): 91 | model.train() 92 | 93 | total_loss = 0. 94 | for _ in range(GRADIENT_ACCUMULATE_EVERY): 95 | loss = model( 96 | next(train_loader), 97 | memory_replay_backprop = True, 98 | mrbp_loss_weight = 1. / GRADIENT_ACCUMULATE_EVERY 99 | ) 100 | 101 | total_loss += loss 102 | 103 | print(f"training loss: {total_loss.item()}") 104 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 105 | 106 | optim.step() 107 | optim.zero_grad() 108 | 109 | if i % VALIDATE_EVERY == 0: 110 | model.eval() 111 | with torch.no_grad(): 112 | loss, _ = model(next(val_loader), return_loss = True) 113 | print(f"validation loss: {loss.item()}") 114 | 115 | if i % GENERATE_EVERY == 0: 116 | model.eval() 117 | inp = random.choice(val_dataset)[:PRIME_LENGTH] 118 | prime = decode_tokens(inp) 119 | print(f"{prime} \n\n {'*' * 100}") 120 | 121 | sample = model.generate(inp[None, :], length = GENERATE_LENGTH) 122 | output_str = decode_tokens(sample[0]) 123 | print(output_str, "\n") 124 | --------------------------------------------------------------------------------