├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── recurrent_memory_transformer_pytorch
│   ├── __init__.py
│   ├── attend.py
│   └── recurrent_memory_transformer.py
├── rmt.png
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Recurrent Memory Transformer - Pytorch
4 |
5 | Implementation of Recurrent Memory Transformer (openreview) in Pytorch. They had a short follow-up paper recently demonstrating that it can copy information across at least 1 million tokens.
6 |
7 | There is no doubt in my mind that RMT would make a stronger RL agent than AdA, which is just a Transformer-XL. Update: see the Recurrent Action Transformer with Memory (RATE)
8 |
9 | Yannic Kilcher paper review
10 |
11 | ## Appreciation
12 |
13 | - Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
14 |
15 | ## Install
16 |
17 | ```bash
18 | $ pip install recurrent-memory-transformer-pytorch
19 | ```
20 |
21 | ## Usage
22 |
23 | ```python
24 | import torch
25 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer
26 |
27 | model = RecurrentMemoryTransformer(
28 | num_tokens = 20000, # number of tokens
29 |     num_memory_tokens = 128,          # number of memory tokens; this determines the bottleneck for information being passed on to future segments
30 | dim = 512, # model dimensions
31 | depth = 6, # transformer depth
32 | causal = True, # autoregressive or not
33 | dim_head = 64, # dimension per head
34 | heads = 8, # heads
35 | seq_len = 1024, # sequence length of a segment
36 | use_flash_attn = True # whether to use flash attention
37 | )
38 |
39 | x = torch.randint(0, 256, (1, 1024))
40 |
41 | logits1, mem1, _ = model(x)        # (1, 1024, 20000), (1, 128, 512), xl memories (discarded here)
42 | logits2, mem2, _ = model(x, mem1)  # (1, 1024, 20000), (1, 128, 512), xl memories (discarded here)
43 | logits3, mem3, _ = model(x, mem2)  # (1, 1024, 20000), (1, 128, 512), xl memories (discarded here)
44 |
45 | # and so on ...
46 |
47 | ```
48 |
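The loss can also be obtained directly by passing in `labels` - a minimal sketch reusing the model above (the random data and shift-by-one labels are just for illustration):

```python
# when labels are given, the first return value is the cross entropy loss instead of logits

seq = torch.randint(0, 20000, (1, 1025))
inputs, labels = seq[:, :-1], seq[:, 1:]

loss, mems, _ = model(inputs, labels = labels)  # scalar loss, (1, 128, 512) memories for the next segment
loss.backward()
```
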
49 | With XL memories
50 |
51 | ```python
52 | import torch
53 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer
54 |
55 | model = RecurrentMemoryTransformer(
56 | num_tokens = 20000,
57 | num_memory_tokens = 128,
58 | dim = 512,
59 | depth = 6,
60 | causal = True,
61 | dim_head = 64,
62 | heads = 8,
63 | seq_len = 1024,
64 | use_flash_attn = True,
65 | use_xl_memories = True, # set this to True
66 |     xl_mem_len = 512                        # can be shorter than the seq len - i think just having a bit of the past will prevent the RMT memories from simply memorizing the immediately preceding text
67 | )
68 |
69 | x = torch.randint(0, 256, (1, 1024))
70 |
71 | logits1, mem1, xl_mem1 = model(x)                               # (1, 1024, 20000), (1, 128, 512), list of 6 detached kv memories, each (2, 1, 8, 512, 64)
72 | logits2, mem2, xl_mem2 = model(x, mem1, xl_memories = xl_mem1)  # (1, 1024, 20000), (1, 128, 512), list of 6 detached kv memories, each (2, 1, 8, 512, 64)
73 | logits3, mem3, xl_mem3 = model(x, mem2, xl_memories = xl_mem2)  # (1, 1024, 20000), (1, 128, 512), list of 6 detached kv memories, each (2, 1, 8, 512, 64)
74 |
75 | # and so on ...
76 | ```
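
The xl memories returned above come back already detached, while the RMT memories do not - so if you unroll over segments yourself instead of using the wrapper below, you would detach the memories between steps for truncated BPTT. A rough sketch under those assumptions, reusing the xl-memory model above:

```python
# hypothetical manual unroll over a longer sequence of 4 segments

seq = torch.randint(0, 256, (1, 4096))

mems = xl_mems = None

for segment in seq.split(1024, dim = -1):
    inputs, labels = segment[:, :-1], segment[:, 1:]

    loss, mems, xl_mems = model(inputs, mems, xl_memories = xl_mems, labels = labels)
    loss.backward()

    mems = mems.detach()  # truncated BPTT across segment boundaries
```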
77 |
78 | Train on an absurdly long sequence
79 |
80 | ```python
81 | import torch
82 | from recurrent_memory_transformer_pytorch import (
83 | RecurrentMemoryTransformer,
84 | RecurrentMemoryTransformerWrapper
85 | )
86 |
87 | model = RecurrentMemoryTransformer(
88 | num_tokens = 256,
89 | num_memory_tokens = 128,
90 | dim = 512,
91 | depth = 6,
92 | seq_len = 1024,
93 | use_flash_attn = True,
94 | causal = True
95 | )
96 |
97 | model = RecurrentMemoryTransformerWrapper(model).cuda()
98 |
99 | seq = torch.randint(0, 256, (4, 65536)).cuda()  # absurdly long sequence - in reality, they curriculum learned this, starting with 1 segment and working up to about 7 - 8 segments
100 |
101 | loss = model(seq, memory_replay_backprop = True) # memory efficient training from memformer paper
102 |
103 | ```
104 |
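Note that with `memory_replay_backprop = True` the backward pass is performed internally, segment by segment, so you only need to step your optimizer afterwards. The wrapper also exposes a `generate` method that samples autoregressively across segment boundaries, rolling the memories (and xl memories, if enabled) forward for you - a small sketch continuing from the wrapped model above (the prime tokens and lengths are arbitrary):

```python
# prime with 128 tokens, then sample out to a total length of 2048

prime = torch.randint(0, 256, (1, 128)).cuda()

sampled = model.generate(prime, length = 2048)  # (1, 1920) - the sampled continuation, prime excluded
```
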
105 | ## Todo
106 |
107 | - [ ] move the memory replay backprop into a torch.autograd.Function, test out bidirectional, then test on a real problem
108 |
109 | - [x] get rotary embeddings working properly with xl memories
110 | - [x] add xl memories, detached
111 | - [x] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift
112 | - [x] make memories being causally masked an option
113 | - [x] add the memory replay backprop technique from memformer paper
114 | - [x] relative positional encoding
115 |
116 | ## Alternatives
117 |
118 | - Block Recurrent Transformer
119 |
120 | - Memformer
121 |
122 | ## Citations
123 |
124 | ```bibtex
125 | @inproceedings{bulatov2022recurrent,
126 | title = {Recurrent Memory Transformer},
127 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail Burtsev},
128 | booktitle = {Advances in Neural Information Processing Systems},
129 | editor = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
130 | year = {2022},
131 | url = {https://openreview.net/forum?id=Uynr3iPhksa}
132 | }
133 | ```
134 |
135 | ```bibtex
136 | @misc{bulatov2023scaling,
137 | title = {Scaling Transformer to 1M tokens and beyond with RMT},
138 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
139 | year = {2023},
140 | eprint = {2304.11062},
141 | archivePrefix = {arXiv},
142 | primaryClass = {cs.CL}
143 | }
144 | ```
145 |
146 | ```bibtex
147 | @inproceedings{dao2022flashattention,
148 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
149 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
150 | booktitle = {Advances in Neural Information Processing Systems},
151 | year = {2022}
152 | }
153 | ```
154 |
155 | ```bibtex
156 | @misc{shazeer2020glu,
157 | title = {GLU Variants Improve Transformer},
158 | author = {Noam Shazeer},
159 | year = {2020},
160 | url = {https://arxiv.org/abs/2002.05202}
161 | }
162 | ```
163 |
164 | ```bibtex
165 | @misc{su2021roformer,
166 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
167 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
168 | year = {2021},
169 | eprint = {2104.09864},
170 | archivePrefix = {arXiv},
171 | primaryClass = {cs.CL}
172 | }
173 | ```
174 |
175 | ```bibtex
176 | @inproceedings{Wu2020MemformerAM,
177 | title = {Memformer: A Memory-Augmented Transformer for Sequence Modeling},
178 | author = {Qingyang Wu and Zhenzhong Lan and Kun Qian and Jing Gu and Alborz Geramifard and Zhou Yu},
179 | booktitle = {AACL/IJCNLP},
180 | year = {2020}
181 | }
182 | ```
183 |
184 | ```bibtex
185 | @software{peng_bo_2021_5196578,
186 | author = {PENG Bo},
187 | title = {BlinkDL/RWKV-LM: 0.01},
188 | month = {aug},
189 | year = {2021},
190 | publisher = {Zenodo},
191 | version = {0.01},
192 | doi = {10.5281/zenodo.5196578},
193 | url = {https://doi.org/10.5281/zenodo.5196578}
194 | }
195 | ```
196 |
197 | ```bibtex
198 | @misc{ding2021cogview,
199 | title = {CogView: Mastering Text-to-Image Generation via Transformers},
200 | author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
201 | year = {2021},
202 | eprint = {2105.13290},
203 | archivePrefix = {arXiv},
204 | primaryClass = {cs.CV}
205 | }
206 | ```
207 |
208 | ```bibtex
209 | @software{Dayma_DALLE_Mini_2021,
210 | author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
211 | doi = {10.5281/zenodo.5146400},
212 | license = {Apache-2.0},
213 | month = {jul},
214 | title = {{DALL·E Mini}},
215 | url = {https://github.com/borisdayma/dalle-mini},
216 | version = {v0.1-alpha},
217 | year = {2021}}
218 | ```
219 |
220 | ```bibtex
221 | @inproceedings{anonymous2022normformer,
222 | title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
223 | author = {Anonymous},
224 | booktitle = {Submitted to The Tenth International Conference on Learning Representations },
225 | year = {2022},
226 | url = {https://openreview.net/forum?id=GMYWzWztDx5},
227 | note = {under review}
228 | }
229 | ```
230 |
231 | ```bibtex
232 | @misc{ding2021erniedoc,
233 | title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
234 | author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
235 | year = {2021},
236 | eprint = {2012.15688},
237 | archivePrefix = {arXiv},
238 | primaryClass = {cs.CL}
239 | }
240 | ```
241 |
242 | ```bibtex
243 | @article{Zhu2024HyperConnections,
244 | title = {Hyper-Connections},
245 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
246 | journal = {ArXiv},
247 | year = {2024},
248 | volume = {abs/2409.19606},
249 | url = {https://api.semanticscholar.org/CorpusID:272987528}
250 | }
251 | ```
252 |
253 | ```bibtex
254 | @inproceedings{Zhou2024ValueRL,
255 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
256 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
257 | year = {2024},
258 | url = {https://api.semanticscholar.org/CorpusID:273532030}
259 | }
260 | ```
261 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/recurrent-memory-transformer-pytorch/520a3574c5a00e452d2af3fb1c26f15a3779c8bb/data/enwik8.gz
--------------------------------------------------------------------------------
/recurrent_memory_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from recurrent_memory_transformer_pytorch.recurrent_memory_transformer import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper
2 |
--------------------------------------------------------------------------------
/recurrent_memory_transformer_pytorch/attend.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from functools import wraps
3 | from packaging import version
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from einops import rearrange
10 |
11 | # constants
12 |
13 | Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def once(fn):
21 | called = False
22 | @wraps(fn)
23 | def inner(x):
24 | nonlocal called
25 | if called:
26 | return
27 | called = True
28 | return fn(x)
29 | return inner
30 |
31 | print_once = once(print)
32 |
33 | # main class
34 |
35 | class Attend(nn.Module):
36 | def __init__(
37 | self,
38 | dropout = 0.,
39 | causal = False,
40 | use_flash = False
41 | ):
42 | super().__init__()
43 | self.dropout = dropout
44 | self.attn_dropout = nn.Dropout(dropout)
45 |
46 | self.causal = causal
47 | self.register_buffer("mask", None, persistent=False)
48 |
49 | self.use_flash = use_flash
50 | assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
51 |
52 | # determine efficient attention configs for cuda and cpu
53 |
54 | self.cpu_config = Config(True, True, True)
55 | self.cuda_config = None
56 |
57 | if not torch.cuda.is_available() or not use_flash:
58 | return
59 |
60 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
61 |
62 | if device_properties.major == 8 and device_properties.minor == 0:
63 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
64 | self.cuda_config = Config(True, False, False)
65 | else:
66 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
67 | self.cuda_config = Config(False, True, True)
68 |
69 | def get_mask(self, n, device):
70 | if exists(self.mask) and self.mask.shape[-1] >= n:
71 | return self.mask[:n, :n]
72 |
73 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
74 | self.register_buffer("mask", mask, persistent=False)
75 | return mask
76 |
77 | def flash_attn(self, q, k, v, mask = None):
78 | _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
79 |
80 | # Check if mask exists and expand to compatible shape
81 | # The mask is B L, so it would have to be expanded to B H N L
82 |
83 | if exists(mask):
84 | if mask.ndim != 4:
85 | mask = rearrange(mask, 'b j -> b 1 1 j')
86 |
87 | mask = mask.expand(-1, heads, q_len, -1)
88 |
89 | # Check if there is a compatible device for flash attention
90 |
91 | config = self.cuda_config if is_cuda else self.cpu_config
92 |
93 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
94 |
95 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
96 | out = F.scaled_dot_product_attention(
97 | q, k, v,
98 | attn_mask = mask,
99 | dropout_p = self.dropout if self.training else 0.,
100 | is_causal = self.causal
101 | )
102 |
103 | return out
104 |
105 | def forward(self, q, k, v, mask = None):
106 | """
107 | einstein notation
108 | b - batch
109 | h - heads
110 | n, i, j - sequence length (base sequence length, source, target)
111 | d - feature dimension
112 | """
113 |
114 | n, device = q.shape[-2], q.device
115 |
116 | scale = q.shape[-1] ** -0.5
117 |
118 | if self.use_flash:
119 | return self.flash_attn(q, k, v, mask = mask)
120 |
121 | # similarity
122 |
123 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
124 |
125 | # key padding mask
126 |
127 | if exists(mask):
128 | if mask.ndim != 4:
129 | mask = rearrange(mask, 'b j -> b 1 1 j')
130 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
131 |
132 | # causal mask
133 |
134 | if self.causal:
135 | causal_mask = self.get_mask(n, device)
136 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
137 |
138 | # attention
139 |
140 | attn = sim.softmax(dim=-1)
141 | attn = self.attn_dropout(attn)
142 |
143 | # aggregate values
144 |
145 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
146 |
147 | return out
148 |
--------------------------------------------------------------------------------
/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from functools import partial
5 | from itertools import zip_longest
6 | from contextlib import nullcontext
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.nn import Module, ModuleList
11 | from torch import nn, einsum, Tensor
12 |
13 | from einops import rearrange, repeat, pack, unpack
14 | from einops.layers.torch import Rearrange
15 |
16 | from recurrent_memory_transformer_pytorch.attend import Attend
17 |
18 | from hyper_connections import get_init_and_expand_reduce_stream_functions
19 |
20 | # constants
21 |
22 | Linear = partial(nn.Linear, bias = False)
23 |
24 | # helpers
25 |
26 | def exists(val):
27 | return val is not None
28 |
29 | def identity(t, *args, **kwargs):
30 | return t
31 |
32 | def default(*vals):
33 | for val in vals:
34 | if exists(val):
35 | return val
36 | return None
37 |
38 | def eval_decorator(fn):
39 | def inner(self, *args, **kwargs):
40 | was_training = self.training
41 | self.eval()
42 | out = fn(self, *args, **kwargs)
43 | self.train(was_training)
44 | return out
45 | return inner
46 |
47 | def divisible_by(numer, denom):
48 | return (numer % denom) == 0
49 |
50 | # sampling helpers
51 |
52 | def log(t, eps = 1e-20):
53 | return torch.log(t.clamp(min = eps))
54 |
55 | def gumbel_noise(t):
56 | noise = torch.zeros_like(t).uniform_(0, 1)
57 | return -log(-log(noise))
58 |
59 | def gumbel_sample(t, temperature = 1., dim = -1):
60 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
61 |
62 | def top_k(logits, thres = 0.9):
63 | k = math.ceil((1 - thres) * logits.shape[-1])
64 | val, ind = torch.topk(logits, k)
65 | probs = torch.full_like(logits, float('-inf'))
66 | probs.scatter_(1, ind, val)
67 | return probs
68 |
69 | def frac_gradient(t, frac = 1.):
70 | if frac == 1.:
71 | return t
72 |
73 | return t * frac + t.detach() * (1. - frac)
74 |
75 | # rotary embedding
76 |
77 | class RotaryEmbedding(Module):
78 | def __init__(self, dim, theta = 32768):
79 | super().__init__()
80 | inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
81 | self.register_buffer('inv_freq', inv_freq)
82 |
83 | def forward(self, positions):
84 | freqs = torch.einsum('i , j -> i j', positions, self.inv_freq)
85 | freqs = torch.cat((freqs, freqs), dim = -1)
86 | return freqs
87 |
88 | def rotate_half(x):
89 | x1, x2 = x.chunk(2, dim=-1)
90 | return torch.cat((-x2, x1), dim=-1)
91 |
92 | def apply_rotary_pos_emb(pos, t):
93 | return (t * pos.cos()) + (rotate_half(t) * pos.sin())
94 |
95 | # feedforward
96 |
97 | class GEGLU(Module):
98 | def forward(self, x):
99 | x, gate = x.chunk(2, dim = -1)
100 | return x * F.gelu(gate)
101 |
102 | def FeedForward(dim, mult = 4, dropout = 0.):
103 | dim_inner = int(dim * mult * 2 / 3)
104 |
105 | return nn.Sequential(
106 | nn.RMSNorm(dim),
107 | Linear(dim, dim_inner * 2),
108 | GEGLU(),
109 | nn.Dropout(dropout),
110 | Linear(dim_inner, dim)
111 | )
112 |
113 | # attention
114 |
115 | class Attention(Module):
116 | def __init__(
117 | self,
118 | *,
119 | dim,
120 | causal = False,
121 | dim_head = 64,
122 | heads = 8,
123 | dropout = 0.,
124 | accept_value_residual = False,
125 | use_flash_attn = False,
126 | use_custom_causal_attn_mask = False
127 | ):
128 | super().__init__()
129 | self.norm = nn.RMSNorm(dim)
130 |
131 | dim_inner = dim_head * heads
132 | self.heads = heads
133 |
134 | self.attend = Attend(
135 | causal = causal and not use_custom_causal_attn_mask,
136 | dropout = dropout,
137 | use_flash = use_flash_attn
138 | )
139 |
140 | self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))
141 |
142 | self.to_q = Linear(dim, dim_inner)
143 | self.to_kv = Linear(dim, dim_inner * 2)
144 | self.to_out = Linear(dim_inner, dim)
145 |
146 | # learned value residual mixing
147 |
148 | self.learned_value_residual_mix = None
149 |
150 | if accept_value_residual:
151 | self.learned_value_residual_mix = nn.Sequential(
152 | Linear(dim, heads),
153 | Rearrange('b n h -> b h n 1'),
154 | nn.Sigmoid()
155 | )
156 |
157 | def forward(
158 | self,
159 | x,
160 | rotary_emb: tuple[Tensor, Tensor] | None = None,
161 | mask = None,
162 | xl_memories = None,
163 | value_residual = None
164 | ):
165 | assert not (exists(value_residual) ^ exists(self.learned_value_residual_mix))
166 |
167 | h = self.heads
168 | x = self.norm(x)
169 |
170 | q = self.to_q(x)
171 | k, v = self.to_kv(x).chunk(2, dim = -1)
172 |
173 | # split heads
174 |
175 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
176 |
177 | # handle value residual
178 |
179 | orig_v = v
180 |
181 | if exists(self.learned_value_residual_mix):
182 | mix = self.learned_value_residual_mix(x)
183 | v = v.lerp(value_residual, mix)
184 |
185 | # add a null key / value
186 | # to protect against an entirely masked out sequence
187 | # as well as giving attention ability to attend to nothing
188 |
189 | nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = x.shape[0]), self.null_kv)
190 |
191 | k = torch.cat((nk, k), dim = -2)
192 | v = torch.cat((nv, v), dim = -2)
193 |
194 | if exists(mask):
195 | mask = F.pad(mask, (1, 0), value = True)
196 |
197 | # manage memories
198 |
199 | next_xl_memories = torch.stack((k, v))
200 |
201 | if exists(xl_memories):
202 | kx, vx = xl_memories
203 | k = torch.cat((kx, k), dim = -2)
204 | v = torch.cat((vx, v), dim = -2)
205 |
206 | if exists(mask):
207 | mask = F.pad(mask, (xl_memories.shape[-2], 0), value = True)
208 |
209 | if exists(rotary_emb):
210 | q_rotary_emb, k_rotary_emb = rotary_emb
211 |
212 | q = apply_rotary_pos_emb(q_rotary_emb, q)
213 | k = apply_rotary_pos_emb(k_rotary_emb, k)
214 |
215 | out = self.attend(q, k, v, mask = mask)
216 |
217 | out = rearrange(out, 'b h n d -> b n (h d)')
218 |
219 | return self.to_out(out), next_xl_memories, orig_v
220 |
221 | # transformer
222 |
223 | class RecurrentMemoryTransformer(Module):
224 | def __init__(
225 | self,
226 | dim,
227 | *,
228 | num_tokens,
229 | depth,
230 | num_memory_tokens,
231 | seq_len,
232 | causal = True,
233 | dim_head = 64,
234 | heads = 8,
235 | ff_mult = 4,
236 | attn_dropout = 0.,
237 | ff_dropout = 0.,
238 | use_flash_attn = False,
239 | ignore_index = -1,
240 | abs_pos_emb = True,
241 | rotary_pos_emb = False,
242 | use_xl_memories = True,
243 | xl_mem_len = None,
244 | enhanced_xl_recurrence = False, # add simple method for enhancing receptive field of xl memories, from ernie-doc paper
245 | emb_gradient_frac = 0.1, # trick from cogview paper that leads to a bit more stability
246 | memory_not_causal = True, # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
247 | add_write_to_next_write_mem = False, # add the write memories of previous step to the next write step - thanks to @IcarusWizard for pointing out this discrepancy
248 | next_write_mem_stop_grad = True, # whether to stop gradient of previous read memory -> next write memory
249 | always_have_read_memories = True, # whether to always have read memories, even on the first step, so to make the model onnx-able
250 | num_residual_streams = 4 # number of residual streams for hyper connections
251 | ):
252 | super().__init__()
253 | self.causal = causal
254 | self.seq_len = seq_len
255 |
256 | self.emb_gradient_frac = emb_gradient_frac
257 |
258 | assert num_memory_tokens > 0
259 |
260 | self.token_emb = nn.Embedding(num_tokens, dim)
261 |
262 | # positions
263 |
264 | assert any([abs_pos_emb, rotary_pos_emb])
265 |
266 | self.pos_emb = nn.Embedding(seq_len, dim) if abs_pos_emb else None
267 |
268 | self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else None
269 |
270 | # memory related
271 |
272 | self.num_memory_tokens = num_memory_tokens
273 |
274 | self.read_memory_emb = nn.Parameter(torch.zeros(num_memory_tokens, dim))
275 | nn.init.normal_(self.read_memory_emb, std = 0.02)
276 |
277 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
278 | nn.init.normal_(self.memory_tokens, std = 0.02)
279 |
280 | # xl memories
281 |
282 | xl_mem_len = default(xl_mem_len, seq_len)
283 | assert xl_mem_len <= seq_len
284 | self.xl_mem_len = xl_mem_len
285 |
286 | self.use_xl_memories = use_xl_memories
287 | self.enhanced_xl_recurrence = enhanced_xl_recurrence
288 |
289 | # hyper connections
290 |
291 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
292 |
293 | # layers
294 |
295 | self.layers = ModuleList([])
296 |
297 | for layer_index in range(depth):
298 | is_first = layer_index == 0
299 |
300 | self.layers.append(ModuleList([
301 | init_hyper_conn(dim = dim, branch = Attention(
302 | dim = dim,
303 | dim_head = dim_head,
304 | causal = causal,
305 | heads = heads,
306 | use_flash_attn = use_flash_attn,
307 | accept_value_residual = not is_first,
308 | use_custom_causal_attn_mask = memory_not_causal,
309 | dropout = attn_dropout
310 | )),
311 | init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
312 | ]))
313 |
314 | self.norm = nn.RMSNorm(dim)
315 | self.to_logits = nn.Linear(dim, num_tokens)
316 |
317 | self.ignore_index = ignore_index
318 |
319 | # whether to use custom attention mask if causal and memory should not be causal
320 |
321 | self.use_custom_causal_attn_mask = causal and memory_not_causal
322 |
323 | # in the paper, they actually also use the previous write memories for the next write memories
324 |
325 | self.add_write_to_next_write_mem = add_write_to_next_write_mem
326 | self.next_write_mem_stop_grad = next_write_mem_stop_grad
327 |
328 | # allow for attending to raw read memory positional embeddings on first step
329 | # hack to make it onnx-able and should not hurt
330 |
331 | self.always_have_read_memories = always_have_read_memories
332 |
333 | def init_memory(self, batch):
334 | return repeat(self.memory_tokens, 'm d -> b m d', b = batch)
335 |
336 | def forward(
337 | self,
338 | x,
339 | read_memories = None,
340 | *,
341 | mask = None,
342 | labels = None,
343 | xl_memories: list[Tensor] | None = None,
344 | mask_out_read_memories = False # in the case one is passing in 0s for read memories, for onnx-able model
345 | ):
346 | has_xl_memories = exists(xl_memories) and len(xl_memories) > 0
347 |
348 | b, n, device, mem_length, return_loss = *x.shape, x.device, self.num_memory_tokens, exists(labels)
349 |
350 | assert n <= self.seq_len
351 |
352 | pos = torch.arange(n, device = device)
353 |
354 | x = self.token_emb(x)
355 |
356 | # maybe absolute positional embedding
357 |
358 | if exists(self.pos_emb):
359 | x = x + self.pos_emb(pos)
360 |
361 | # trick from cogview paper
362 |
363 | x = frac_gradient(x, self.emb_gradient_frac)
364 |
365 | # prepare write memories, as in paper
366 |
367 | write_memories = self.init_memory(b)
368 |
369 | if exists(read_memories) and self.add_write_to_next_write_mem:
370 | maybe_detach = torch.detach if self.next_write_mem_stop_grad else identity
371 | write_memories = write_memories + maybe_detach(read_memories)
372 |
373 | # prepare read memories
374 |
375 | if exists(read_memories):
376 | if read_memories.ndim == 2:
377 | read_memories = repeat(read_memories, 'n d -> b n d', b = b)
378 |
379 | read_mem_length = mem_length
380 | read_memories = read_memories + self.read_memory_emb
381 | elif self.always_have_read_memories:
382 | read_mem_length = mem_length
383 | read_memories = repeat(self.read_memory_emb, 'n d -> b n d', b = b)
384 | else:
385 | read_mem_length = 0
386 | read_memories = x[:, 0:0]
387 |
388 | # concat to main sequence using einop's pack
389 |
390 | x, ps = pack([read_memories, x, write_memories], 'b * d')
391 |
392 | # take care of mask
393 |
394 | if exists(mask):
395 | mask = F.pad(mask, (read_mem_length, mem_length), value = True)
396 |
397 | # custom causal mask, if needed
398 |
399 | if self.use_custom_causal_attn_mask:
400 | causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).tril()
401 |
402 | causal_mask = F.pad(causal_mask, (0, mem_length, read_mem_length, 0), value = False)
403 | causal_mask = F.pad(causal_mask, (read_mem_length, 0, 0, mem_length), value = True)
404 |
405 | causal_mask = rearrange(causal_mask, 'i j -> 1 1 i j')
406 |
407 | if exists(mask):
408 | mask = rearrange(mask, 'b j -> b 1 1 j')
409 | mask = mask & causal_mask
410 | else:
411 | mask = causal_mask
412 |
413 | # masking out read memories, either for passing in 0s for read memories on first step, or if you are doing some regularization game on the memories
414 |
415 | if read_mem_length > 0 and mask_out_read_memories:
416 | read_mem_mask = torch.arange(x.shape[-2], device = device) < read_mem_length
417 |
418 | if exists(mask):
419 | mask = mask & ~read_mem_mask
420 | else:
421 | mask = read_mem_mask
422 |
423 | # rotary embedding - offset main positions by 10000, and keep all memories at position 0
424 |
425 | rotary_emb = None
426 |
427 | if exists(self.rotary_pos_emb):
428 | mem_rel_dist = 10000
429 |
430 | q_pos = pos + mem_rel_dist
431 |
432 | if has_xl_memories:
433 | xl_mem_length = xl_memories[0].shape[-2]
434 | q_pos += xl_mem_length
435 |
436 | q_pos = F.pad(q_pos, (read_mem_length, mem_length), value = 0)
437 | q_rotary_emb = self.rotary_pos_emb(q_pos)
438 |
439 | # kind of confusing at the moment
440 | # but the order of the keys are - [xl memories] [read memories] [main sequence] [ write memories]
441 | # so the positions are (say xl memory length of 256) - [10001, 10002, 10003 ...] [0, 0, ...] [10256, 10257, ...] [0, 0, ...]
442 |
443 | if has_xl_memories:
444 | k_pos = torch.arange(xl_mem_length, device = device) + mem_rel_dist
445 | k_pos = torch.cat((k_pos, q_pos), dim = -1)
446 | else:
447 | k_pos = q_pos
448 |
449 | # account for null key / value
450 |
451 | k_pos = F.pad(k_pos, (1, 0), value = mem_rel_dist - 1) # give a null memory token, to allow for attending to nothing
452 |
453 | k_rotary_emb = self.rotary_pos_emb(k_pos)
454 |
455 | rotary_emb = (q_rotary_emb, k_rotary_emb)
456 |
457 | # prepare xl memories
458 |
459 | xl_memories = default(xl_memories, [])
460 | xl_memories_iter = iter(xl_memories)
461 | new_xl_memories = []
462 |
463 | if has_xl_memories and self.enhanced_xl_recurrence and len(xl_memories) > 1: # simply shift all the xl memories down by one, so lower layer gets access to representations from layer above
464 | xl_memories = [*xl_memories[1:], xl_memories[0]]
465 |
466 | # value residual
467 |
468 | value_residual = None
469 |
470 | # expand streams for hyper connections
471 |
472 | x = self.expand_streams(x)
473 |
474 | # attention and feedforward
475 |
476 | for attn, ff in self.layers:
477 | x, xl_memories, attn_values = attn(x, mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb, value_residual = value_residual)
478 |
479 | value_residual = default(value_residual, attn_values)
480 | new_xl_memories.append(xl_memories)
481 |
482 | x = ff(x)
483 |
484 | # reduce streams for hyper connections
485 |
486 | x = self.reduce_streams(x)
487 |
488 | # final norm
489 |
490 | x = self.norm(x)
491 |
492 | # whether to return xl memories
493 |
494 | next_xl_memories = None
495 |
496 | if self.use_xl_memories:
497 | next_xl_memories = list(map(lambda t: torch.detach(t[..., -self.xl_mem_len:, :]), new_xl_memories))
498 |
499 | # split out memories using unpack
500 |
501 | read_memories, x, write_memories = unpack(x, ps, 'b * d')
502 |
503 | # to logits
504 |
505 | logits = self.to_logits(x)
506 |
507 | if not return_loss:
508 | return logits, write_memories, next_xl_memories
509 |
510 | loss = F.cross_entropy(
511 | rearrange(logits, 'b n c -> b c n'),
512 | labels,
513 | ignore_index = self.ignore_index
514 | )
515 |
516 | return loss, write_memories, next_xl_memories
517 |
518 | # wrapper to manage many segments
519 |
520 | class RecurrentMemoryTransformerWrapper(Module):
521 | def __init__(
522 | self,
523 | transformer: RecurrentMemoryTransformer,
524 |         truncate_at_step = None  # number of steps before detaching memories (truncated bptt). with memory replay checkpointing there should be no memory issues, but this is offered in case of instability, as reported in the initial paper
525 | ):
526 | super().__init__()
527 | self.transformer = transformer
528 | self.seq_len = transformer.seq_len
529 | self.truncate_at_step = truncate_at_step
530 |
531 | @torch.no_grad()
532 | @eval_decorator
533 | def generate(
534 | self,
535 | prime,
536 | *,
537 | length,
538 | memories = None,
539 | xl_memories: list[Tensor] | None = None,
540 | temperature = 1.,
541 | filter_thres = 0.9
542 | ):
543 | assert self.transformer.causal, 'only autoregressive transformers can generate'
544 |
545 | start_len, seq_len = prime.shape[-1], self.seq_len
546 |
547 | assert length >= start_len
548 |
549 | *past_segments, curr_segment = prime.split(seq_len, dim = -1)
550 |
551 | # catch memories up to the current segment
552 |
553 | for past_segment in past_segments:
554 | _, memories, xl_memories = self.transformer(past_segment, memories, xl_memories = xl_memories)
555 |
556 | # sample for the remaining length
557 |
558 | for ind in range(length - start_len):
559 | logits, next_memories, next_xl_memories = self.transformer(curr_segment, memories, xl_memories = xl_memories)
560 |
561 | logits = logits[:, -1]
562 |
563 | filtered_logits = top_k(logits, thres = filter_thres)
564 | sampled = gumbel_sample(filtered_logits, temperature = temperature)
565 | sampled = rearrange(sampled, 'b -> b 1')
566 |
567 | curr_segment = torch.cat((curr_segment, sampled), dim = -1)
568 |
569 | if divisible_by(curr_segment.shape[-1] - 1, seq_len):
570 | memories = next_memories
571 | xl_memories = next_xl_memories
572 |
573 | past_segment, curr_segment = curr_segment[..., :seq_len], curr_segment[..., -1:]
574 | past_segments.append(past_segment)
575 |
576 | # add current segment to all segments
577 |
578 | past_segments.append(curr_segment)
579 |
580 | # reconcat all segments
581 |
582 | output = torch.cat(past_segments, dim = -1)
583 |
584 | output = output[:, start_len:]
585 | return output
586 |
587 | def forward(
588 | self,
589 | x,
590 | memories = None,
591 | *,
592 | mask = None,
593 | xl_memories: list[Tensor] | None = None,
594 | return_loss = False,
595 | labels = None,
596 | truncate_at_step = None, # if set, this would override the truncate_at_step at init
597 | memory_replay_backprop = False, # whether to have the class do the backwards pass memory efficiently
598 |         mrbp_loss_weight = 1.          # if using memory replay backprop with gradient accumulation, scale the loss by this factor, ex. (1. / number of gradient accumulation steps)
599 | ):
600 | seq_len, truncate_at_step = self.seq_len, default(truncate_at_step, self.truncate_at_step)
601 |
602 |         # derive labels from the input if they were not explicitly passed in
603 |         if (return_loss or memory_replay_backprop) and not exists(labels):
604 |             x, labels = x[:, :-1], x[:, 1:]
605 |
606 | # segment input
607 |
608 | segments = x.split(seq_len, dim = -1)
609 | total_length = x.shape[-1]
610 | num_segments = len(segments)
611 | segment_length_frac = tuple(map(lambda t: t.shape[-1] / total_length, segments))
612 |
613 | # default values
614 |
615 | label_segments = mask_segments = (None,)
616 |
617 | # take care of labels
618 |
619 | if exists(labels):
620 | label_segments = labels.split(seq_len, dim = -1)
621 |
622 | # take care of the mask
623 |
624 | if exists(mask):
625 | mask_segments = mask.split(seq_len, dim = -1)
626 |
627 | # keep replay buffer
628 |
629 | replay_buffer = [memories]
630 |
631 | # replay buffer for xl memories
632 |
633 | xl_segments = [xl_memories]
634 |
635 | # decide context of forward depending on whether doing memory-replay-backprop
636 |
637 | forward_context = nullcontext if not memory_replay_backprop else torch.no_grad
638 |
639 | # forward and get all outputs (can be either loss or logits)
640 |
641 | logits = []
642 | losses = []
643 |
644 | for step, (segment, mask_segment, label_segment, loss_weight) in enumerate(zip_longest(segments, mask_segments, label_segments, segment_length_frac)):
645 |
646 | with forward_context():
647 | output, memories, xl_memories = self.transformer(segment, memories, mask = mask_segment, labels = label_segment)
648 |
649 | if exists(truncate_at_step) and divisible_by(step + 1, truncate_at_step):
650 | memories = memories.detach()
651 |
652 | replay_buffer.append(memories)
653 |
654 | xl_segments.append(xl_memories)
655 |
656 | if return_loss:
657 | losses.append(output * loss_weight)
658 | else:
659 | logits.append(output)
660 |
661 | # whether to do memory replay backpropagation
662 |
663 | # https://arxiv.org/abs/2010.06891
664 | # algorithm 1
665 |
666 | if memory_replay_backprop:
667 | memories_grad = torch.zeros_like(replay_buffer[-1])
668 |
669 | reversed_inputs = zip_longest(*map(reversed, [
670 | range(num_segments),
671 | segments,
672 | replay_buffer[:-1],
673 | xl_segments[:-1],
674 | mask_segments,
675 | label_segments,
676 | segment_length_frac,
677 | ]))
678 |
679 | total_loss = 0.
680 |
681 | for step, segment, segment_memories, segment_xl_memories, mask_segment, label_segment, loss_weight in reversed_inputs:
682 | is_first = step == 0
683 |
684 | if exists(segment_memories):
685 | segment_memories.requires_grad_()
686 |
687 | loss, next_segment_memories, _ = self.transformer(segment, segment_memories, mask = mask_segment, xl_memories = segment_xl_memories, labels = label_segment)
688 |
689 | weighted_loss = loss * loss_weight * mrbp_loss_weight
690 |
691 | weighted_loss.backward(retain_graph = True)
692 |
693 | next_segment_memories.backward(memories_grad)
694 |
695 | total_loss += weighted_loss
696 |
697 | if is_first:
698 | continue
699 |
700 | if exists(truncate_at_step) and divisible_by(step, truncate_at_step):
701 | memories_grad.zero_()
702 | else:
703 | memories_grad.copy_(segment_memories.grad.data)
704 |
705 | return total_loss
706 |
707 | # return logits if needed
708 |
709 | if not return_loss:
710 | logits = torch.cat(logits, dim = -2)
711 | return logits, memories
712 |
713 | # otherwise return losses
714 |
715 | return sum(losses), memories
716 |
--------------------------------------------------------------------------------
/rmt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/recurrent-memory-transformer-pytorch/520a3574c5a00e452d2af3fb1c26f15a3779c8bb/rmt.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'recurrent-memory-transformer-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.7.0',
7 | license='MIT',
8 | description = 'Recurrent Memory Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/recurrent-memory-transformer-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'recurrence',
19 | 'memory',
20 | 'long-context'
21 | ],
22 | install_requires=[
23 | 'einops>=0.8.0',
24 | 'hyper-connections>=0.1.7',
25 | 'torch>=2.3',
26 | ],
27 | classifiers=[
28 | 'Development Status :: 4 - Beta',
29 | 'Intended Audience :: Developers',
30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
31 | 'License :: OSI Approved :: MIT License',
32 |     'Programming Language :: Python :: 3.8',
33 | ],
34 | )
35 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import random
3 | import tqdm
4 | import numpy as np
5 |
6 | import torch
7 | from torch.optim import Adam
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer, RecurrentMemoryTransformerWrapper
12 |
13 | # constants
14 |
15 | NUM_BATCHES = int(1e5)
16 | BATCH_SIZE = 4
17 | GRADIENT_ACCUMULATE_EVERY = 4
18 | LEARNING_RATE = 1e-4
19 | VALIDATE_EVERY = 100
20 | PRIME_LENGTH = 128
21 | GENERATE_EVERY = 250
22 | GENERATE_LENGTH = 2048
23 | SEQ_LEN = 2048
24 |
25 | # helpers
26 |
27 | def cycle(loader):
28 | while True:
29 | for data in loader:
30 | yield data
31 |
32 | def decode_token(token):
33 | return str(chr(max(32, token)))
34 |
35 | def decode_tokens(tokens):
36 | return "".join(list(map(decode_token, tokens)))
37 |
38 |
39 | # instantiate recurrent memory transformer
40 |
41 | model = RecurrentMemoryTransformer(
42 | num_tokens = 256,
43 | dim = 512,
44 | depth = 6,
45 | dim_head = 64,
46 | heads = 8,
47 | seq_len = 512,
48 | use_flash_attn = True,
49 | num_memory_tokens = 128,
50 | use_xl_memories = True,
51 | xl_mem_len = 256
52 | )
53 |
54 | model = RecurrentMemoryTransformerWrapper(model)
55 |
56 | model.cuda()
57 |
58 | # prepare enwik8 data
59 |
60 | with gzip.open("./data/enwik8.gz") as file:
61 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
62 | np_train, np_valid = np.split(data, [int(90e6)])
63 | data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
64 |
65 | class TextSamplerDataset(Dataset):
66 | def __init__(self, data, seq_len):
67 | super().__init__()
68 | self.data = data
69 | self.seq_len = seq_len
70 |
71 | def __getitem__(self, index):
72 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
73 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
74 | return full_seq.cuda()
75 |
76 | def __len__(self):
77 | return self.data.size(0) // self.seq_len
78 |
79 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
80 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
81 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
82 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
83 |
84 | # optimizer
85 |
86 | optim = Adam(model.parameters(), lr = LEARNING_RATE)
87 |
88 | # training
89 |
90 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
91 | model.train()
92 |
93 | total_loss = 0.
94 | for _ in range(GRADIENT_ACCUMULATE_EVERY):
95 | loss = model(
96 | next(train_loader),
97 | memory_replay_backprop = True,
98 | mrbp_loss_weight = 1. / GRADIENT_ACCUMULATE_EVERY
99 | )
100 |
101 | total_loss += loss
102 |
103 | print(f"training loss: {total_loss.item()}")
104 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
105 |
106 | optim.step()
107 | optim.zero_grad()
108 |
109 | if i % VALIDATE_EVERY == 0:
110 | model.eval()
111 | with torch.no_grad():
112 | loss, _ = model(next(val_loader), return_loss = True)
113 | print(f"validation loss: {loss.item()}")
114 |
115 | if i % GENERATE_EVERY == 0:
116 | model.eval()
117 | inp = random.choice(val_dataset)[:PRIME_LENGTH]
118 | prime = decode_tokens(inp)
119 | print(f"%s \n\n %s", (prime, "*" * 100))
120 |
121 | sample = model.generate(inp[None, :], length = GENERATE_LENGTH)
122 | output_str = decode_tokens(sample[0])
123 | print(output_str, "\n")
124 |
--------------------------------------------------------------------------------