├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── diagram.png
├── memorizing_transformers_pytorch
│   ├── __init__.py
│   ├── knn_memory.py
│   └── memorizing_transformers_pytorch.py
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    # the job only needs to read the repository; publishing authenticates with the PyPI token below
    permissions:
      contents: read

    steps:
    # v2 of checkout / setup-python run on the deprecated Node 12 runtime
    - uses: actions/checkout@v4
    - name: Set up Python
      uses: actions/setup-python@v5
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      # third-party action pinned to a full commit SHA
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Memorizing Transformers - Pytorch
4 |
5 | Implementation of Memorizing Transformers (ICLR 2022), attention net augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch
6 |
7 | This repository deviates from the paper slightly: instead of the paper's sigmoid gate, it uses a single hybrid attention whose softmax spans both the local and the distant (retrieved) attention logits. It also uses cosine similarity attention (with a learned temperature) for the KNN attention layer.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install memorizing-transformers-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | from memorizing_transformers_pytorch import MemorizingTransformer
20 |
21 | model = MemorizingTransformer(
22 | num_tokens = 20000, # number of tokens
23 | dim = 512, # dimension
24 | dim_head = 64, # dimension per attention head
25 | depth = 8, # number of layers
26 | memorizing_layers = (4, 5), # which layers to have ANN memories
27 | max_knn_memories = 64000, # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries)
28 | num_retrieved_memories = 32, # number of ANN memories to retrieve
29 | clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration
30 | )
31 |
32 | data = torch.randint(0, 20000, (2, 1024)) # mock data
33 |
34 | knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example)
35 |
36 | logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000)
37 | ```
38 |
39 | You can make the KNN memories read-only by setting `add_knn_memory` on forward to `False`
40 |
41 | ex.
42 |
43 | ```python
44 | logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated
45 | ```
46 |
47 | With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory)
48 |
49 | ```python
50 | import torch
51 | from memorizing_transformers_pytorch import MemorizingTransformer
52 |
53 | model = MemorizingTransformer(
54 | num_tokens = 20000,
55 | dim = 512,
56 | depth = 8,
57 | memorizing_layers = (4, 5),
58 | max_knn_memories = 64000,
59 | num_retrieved_memories = 32,
60 | clear_memories_on_sos_token_id = 1,
61 | xl_memory_layers = (2, 3, 4, 5), # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory
62 | xl_max_memories = 512, # number of xl memories to keep
63 | shift_knn_memories_down = 1, # let a layer look at the KNN memories this number of layers above
64 | shift_xl_memories_down = 1, # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper
65 | )
66 |
67 | data = torch.randint(0, 20000, (2, 1024)) # mock data
68 |
69 | xl_memories = None
70 |
71 | with model.knn_memories_context(batch_size = 2) as knn_memories:
72 | logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
73 | logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
74 | logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories)
75 |
76 | # ... and so on
77 | ```
78 |
79 | ## KNN Memory
80 |
81 | This repository contains a wrapper around Faiss that can automatically store and retrieve key / values
82 |
83 | ```python
84 | import torch
85 | from memorizing_transformers_pytorch import KNNMemory
86 |
87 | memory = KNNMemory(
88 | dim = 64, # dimension of key / values
89 | max_memories = 64000, # maximum number of memories to keep (for now, the memories are reset once this capacity is exceeded, as faiss cannot remove individual entries)
90 | num_indices = 2 # this should equal the batch size, as each batch element keeps track of its own memories, which should be cleared when it starts a new document
91 | )
92 |
93 | memory.add(torch.randn(2, 512, 2, 64)) # (batch, seq, key | value, feature dim)
94 | memory.add(torch.randn(2, 512, 2, 64))
95 |
96 | memory.clear([0]) # clear batch 0, e.g. if it saw an SOS token marking the start of a new document
97 |
98 | memory.add(torch.randn(2, 512, 2, 64))
99 | memory.add(torch.randn(2, 512, 2, 64))
100 |
101 | key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32)
102 | ```
103 |
104 | ## Training
105 |
106 | Enwik8 training
107 |
108 | ```bash
109 | $ python train.py
110 | ```
111 |
112 | ## Todo
113 |
114 | - [x] switch to ivfhnsw and just remember all memories
115 | - [x] enwik8 demo
116 | - [x] validation for enwik8
117 | - [x] solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array
118 | - [ ] setup text generation with memories
119 | - [ ] figure out how to deal with memories efficiently once capacity has been hit
120 | - [ ] try to speed up reading and writing to knn memories collection with multiprocessing
121 |
122 | ## Citations
123 |
124 | ```bibtex
125 | @article{wu2022memorizing,
126 | title = {Memorizing transformers},
127 | author = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian},
128 | journal = {arXiv preprint arXiv:2203.08913},
129 | year = {2022}
130 | }
131 | ```
132 |
133 | ```bibtex
134 | @article{Shazeer2019FastTD,
135 | title = {Fast Transformer Decoding: One Write-Head is All You Need},
136 | author = {Noam M. Shazeer},
137 | journal = {ArXiv},
138 | year = {2019},
139 | volume = {abs/1911.02150}
140 | }
141 | ```
142 |
143 | ```bibtex
144 | @Article{AlphaFold2021,
145 | author = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
146 | journal = {Nature},
147 | title = {Highly accurate protein structure prediction with {AlphaFold}},
148 | year = {2021},
149 | doi = {10.1038/s41586-021-03819-2},
150 | note = {(Accelerated article preview)},
151 | }
152 | ```
153 |
154 | ```bibtex
155 | @inproceedings{Rae2020DoTN,
156 | title = {Do Transformers Need Deep Long-Range Memory?},
157 | author = {Jack W. Rae and Ali Razavi},
158 | booktitle = {ACL},
159 | year = {2020}
160 | }
161 | ```
162 |
163 | ```bibtex
164 | @misc{ding2021erniedoc,
165 | title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
166 | author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
167 | year = {2021},
168 | eprint = {2012.15688},
169 | archivePrefix = {arXiv},
170 | primaryClass = {cs.CL}
171 | }
172 | ```
173 |
174 | ```bibtex
175 | @misc{henry2020querykey,
176 | title = {Query-Key Normalization for Transformers},
177 | author = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
178 | year = {2020},
179 | eprint = {2010.04245},
180 | archivePrefix = {arXiv},
181 | primaryClass = {cs.CL}
182 | }
183 | ```
184 |
185 | *Memory is Attention through Time* - Alex Graves
186 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/memorizing-transformers-pytorch/272e39bafd2a507d21ac896bd7cf4b593ee9acb7/data/enwik8.gz
--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/memorizing-transformers-pytorch/272e39bafd2a507d21ac896bd7cf4b593ee9acb7/diagram.png
--------------------------------------------------------------------------------
/memorizing_transformers_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from memorizing_transformers_pytorch.memorizing_transformers_pytorch import MemorizingTransformer, KNNAttention
2 | from memorizing_transformers_pytorch.knn_memory import KNNMemory
3 |
--------------------------------------------------------------------------------
/memorizing_transformers_pytorch/knn_memory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import faiss
5 | import numpy as np
6 | from pathlib import Path
7 | from functools import wraps
8 |
9 | from contextlib import ExitStack, contextmanager
10 |
11 | from einops import rearrange, pack, unpack
12 |
13 | # multiprocessing
14 |
15 | from joblib import Parallel, delayed, cpu_count
16 |
17 | # constants
18 |
# GPU id for faiss, read from the FAISS_INDEX_GPU_ID environment variable (defaults to 0)
# NOTE(review): not referenced elsewhere in this module
FAISS_INDEX_GPU_ID = int(os.environ.get('FAISS_INDEX_GPU_ID', 0))

# default directory for the per-layer KNN memory memmap files (used by KNNMemoryList.create_memories)
DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY = './.tmp/knn.memories'
22 |
23 | # helper functions
24 |
def exists(val):
    """True unless `val` is None (falsy values such as 0 or '' still exist)."""
    if val is None:
        return False
    return True
27 |
def default(val, d):
    """Return `val`, falling back to `d` only when `val` is None."""
    if val is None:
        return d
    return val
30 |
def cast_list(val):
    """Normalize `val` into a list.

    - a list is returned as-is (same object)
    - a tuple or range is converted element-wise into a list, so e.g. a tuple of batch
      indices is treated as several indices rather than as one tuple-valued index
    - anything else is wrapped as a single-element list
    """
    if isinstance(val, list):
        return val
    if isinstance(val, (tuple, range)):
        return list(val)
    return [val]
33 |
def all_el_unique(arr):
    """Whether no element of the (sized, hashable-element) collection `arr` is repeated."""
    distinct = frozenset(arr)
    return len(arr) == len(distinct)
36 |
@contextmanager
def multi_context(*cms):
    """Enter every context manager in `cms`, in order, and yield the list of their entered values.

    They are all exited through a single `ExitStack`, i.e. in reverse order of entry,
    including when the body raises.
    """
    with ExitStack() as stack:
        entered = []
        for cm in cms:
            entered.append(stack.enter_context(cm))
        yield entered
41 |
def count_intersect(x, y):
    # for each element of the 1-d array x, count how many times it occurs in the 1-d array y
    matches = x[:, None] == y[None, :]
    return matches.sum(axis = -1)
45 |
def check_shape(tensor, pattern, **kwargs):
    # an identity einops rearrange: it validates `tensor` against `pattern` (and any axis
    # sizes given in kwargs), raising on mismatch, and otherwise hands the data back unchanged
    identity_pattern = f"{pattern} -> {pattern}"
    return rearrange(tensor, identity_pattern, **kwargs)
48 |
49 | # a wrapper around faiss IndexHNSWFlat
50 | # that (optionally) resets itself once it exceeds its maximum number of entries
51 |
class KNN():
    """Wrapper around a faiss `IndexHNSWFlat` using the inner product metric.

    Alongside the faiss index it keeps `self.ids`, the caller-supplied ids of every inserted
    vector (most recent batch first), plus optional per-entry statistics when `keep_stats`
    is set. With `cap_num_entries`, the whole index is reset once inserting a batch would
    take it past `max_num_entries`, since individual entries cannot be removed.
    """
    def __init__(
        self,
        dim,
        max_num_entries,
        cap_num_entries = False,
        M = 15,
        keep_stats = False
    ):
        index = faiss.IndexHNSWFlat(dim, M, faiss.METRIC_INNER_PRODUCT)
        self.index = index
        self.max_num_entries = max_num_entries
        self.cap_num_entries = cap_num_entries
        self.is_trained = False
        self.keep_stats = keep_stats

        self.reset()

    def __del__(self):
        if hasattr(self, 'index'):
            del self.index

    def reset(self):
        """Drop every entry (faiss index, ids and statistics) and mark the index as untrained."""
        self.ids = np.empty((0,), dtype = np.int32)

        if self.keep_stats:
            self.hits = np.empty((0,), dtype = np.int32)
            self.age_num_iterations = np.empty((0,), dtype = np.int32)
            self.ages_since_last_hit = np.empty((0,), dtype = np.int32)

        self.index.reset()
        self.is_trained = False

    def train(self, x):
        self.index.train(x)
        self.is_trained = True

    def add(self, x, ids):
        """Insert vectors `x` and record them under `ids`.

        Assumes `x` is a (num_entries, dim) float32 array, as faiss expects, with one id per row.
        If capping is enabled and this insertion would exceed `max_num_entries`, the index is
        reset *before* inserting. Afterwards it holds exactly this batch, and the faiss index,
        `ids`, `is_trained` and the statistics stay consistent (so the batch is searchable).
        """
        if self.cap_num_entries and len(self.ids) + len(ids) > self.max_num_entries:
            self.reset()

        if not self.is_trained:
            self.train(x)

        # newest ids go first; the statistics arrays are kept in the same order
        self.ids = np.concatenate((ids, self.ids))

        if self.keep_stats:
            self.hits = np.concatenate((np.zeros_like(ids), self.hits))
            self.age_num_iterations = np.concatenate((np.zeros_like(ids), self.age_num_iterations))
            self.ages_since_last_hit = np.concatenate((np.zeros_like(ids), self.ages_since_last_hit))

        return self.index.add(x)

    def search(
        self,
        x,
        topk,
        nprobe = 8,
        return_distances = False,
        increment_hits = False,
        increment_age = True
    ):
        """Return the faiss labels of the `topk` nearest entries for every row of `x`.

        `nprobe` is accepted for interface compatibility but is not used by this HNSW index.
        Before anything has been inserted (untrained), every label is -1. With
        `return_distances`, an `(indices, distances)` pair is always returned; in the
        untrained case the distances are -inf.
        """
        if not self.is_trained:
            indices = np.full((x.shape[0], topk), -1)
            if return_distances:
                return indices, np.full((x.shape[0], topk), -np.inf, dtype = np.float32)
            return indices

        distances, indices = self.index.search(x, k = topk)

        if increment_hits and self.keep_stats:
            hits = count_intersect(self.ids, rearrange(indices, '... -> (...)'))
            self.hits += hits

            # entries that were hit this round restart their "since last hit" counter at 0
            self.ages_since_last_hit += 1
            self.ages_since_last_hit *= (hits == 0)

        if increment_age and self.keep_stats:
            self.age_num_iterations += 1

        if return_distances:
            return indices, distances

        return indices
133 |
134 | # KNN memory layer, where one can store key / value memories
135 | # can automatically take care of a collection of faiss indices (across batch dimension)
136 |
class KNNMemory():
    """Per-batch-element key / value memories: faiss indices for search, a numpy memmap for storage.

    Each of the `num_indices` slots (one per batch element) owns a `KNN` index over the keys and a
    `(max_memories, 2, dim)` block of the memmap holding the key / value pairs. The faiss labels
    returned on search are used directly as row indices into that block. So for every slot,
    `db_offsets[slot]` (rows written since the last clear) must track the number of entries in
    the slot's index. Reads and writes can be restricted to a subset of slots via
    `set_scoped_indices` / `at_batch_indices`.
    """
    def __init__(
        self,
        dim,
        max_memories = 16000,
        num_indices = 1,
        memmap_filename = './knn.memory.memmap',
        multiprocessing = True
    ):
        self.dim = dim
        self.num_indices = num_indices
        self.scoped_indices = list(range(num_indices))

        self.max_memories = max_memories
        self.shape = (num_indices, max_memories, 2, dim)
        self.db_offsets = np.zeros(num_indices, dtype = np.int32)

        self.db = np.memmap(memmap_filename, mode = 'w+', dtype = np.float32, shape = self.shape)
        self.knns = [KNN(dim = dim, max_num_entries = max_memories, cap_num_entries = True) for _ in range(num_indices)]

        self.n_jobs = cpu_count() if multiprocessing else 1

    def set_scoped_indices(self, indices):
        """Restrict subsequent `add` / `search` calls to the given (unique, in-range) slots."""
        indices = list(indices)
        assert all_el_unique(indices), f'all scoped batch indices must be unique, received: {indices}'
        assert all([0 <= i < self.num_indices for i in indices]), f'each batch index must be between 0 and less than {self.num_indices}: received {indices}'
        self.scoped_indices = indices

    @contextmanager
    def at_batch_indices(self, indices):
        """Temporarily scope to `indices`; the previous scope is restored on exit, even on error."""
        prev_indices = self.scoped_indices
        self.set_scoped_indices(indices)
        try:
            yield self
        finally:
            self.set_scoped_indices(prev_indices)

    def clear(self, batch_indices = None):
        """Reset the faiss index and write offset of the given slots (all slots when None)."""
        if not exists(batch_indices):
            batch_indices = list(range(self.num_indices))

        batch_indices = cast_list(batch_indices)

        for index in batch_indices:
            knn = self.knns[index]
            knn.reset()

        self.db_offsets[batch_indices] = 0

    def add(self, memories):
        """Store key / value `memories` of shape (scoped batch, n, 2, dim) into the scoped slots.

        Only the last `max_memories` entries along n are kept. A slot that cannot fit the new
        entries is cleared first (for now, memories are reset at capacity), which keeps the faiss
        labels aligned with the memmap rows.
        """
        check_shape(memories, 'b n kv d', d = self.dim, kv = 2, b = len(self.scoped_indices))

        memories = memories.detach().cpu().numpy()
        memories = memories[:, -self.max_memories:]
        num_memories = memories.shape[1]

        scoped_indices = list(self.scoped_indices)

        # clear (faiss index + offset) any slot whose capacity would be exceeded, *before* inserting,
        # so faiss labels (restarting at 0 after a reset) keep pointing at the rows written below
        overflowing = [i for i in scoped_indices if self.db_offsets[i] + num_memories > self.max_memories]
        if len(overflowing) > 0:
            self.clear(overflowing)

        knn_insert_ids = np.arange(num_memories)

        keys = np.ascontiguousarray(memories[..., 0, :])
        knns = [self.knns[i] for i in scoped_indices]
        db_offsets = [self.db_offsets[i] for i in scoped_indices]

        # use joblib to insert new key / value memories into faiss index

        @delayed
        def knn_add(knn, key, db_offset):
            knn.add(key, ids = knn_insert_ids + db_offset)
            return knn

        updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets))
        for scoped_idx, updated_knn in zip(scoped_indices, updated_knns):
            self.knns[scoped_idx] = updated_knn

        # add the new memories to the memmap "database"

        add_indices = (rearrange(np.arange(num_memories), 'j -> 1 j') + rearrange(self.db_offsets[scoped_indices], 'i -> i 1')) % self.max_memories
        self.db[rearrange(np.array(scoped_indices), 'i -> i 1'), add_indices] = memories
        self.db.flush()

        # only the slots that actually received memories advance (scoped indices are unique)
        self.db_offsets[scoped_indices] += num_memories

    def search(
        self,
        queries,
        topk,
        nprobe = 8,
        increment_hits = True,
        increment_age = True
    ):
        """Retrieve the `topk` nearest stored key / value pairs for every query.

        `queries` has the scoped batch as its first dim and `dim` as its last (any dims in between
        are packed and restored). Returns `(key_values, mask)` on the queries' device, shaped
        (b, ..., topk, 2, dim) and (b, ..., topk). Slots without a result are False in `mask`
        and zero in `key_values`.
        """
        check_shape(queries, 'b ... d', d = self.dim, b = len(self.scoped_indices))
        queries, ps = pack([queries], 'b * d')

        device = queries.device
        queries = queries.detach().cpu().numpy()

        all_masks = []
        all_key_values = []

        knns = [self.knns[i] for i in self.scoped_indices]

        # parallelize faiss search

        @delayed
        def knn_search(knn, query):
            return knn.search(query, topk, nprobe, increment_hits = increment_hits, increment_age = increment_age)

        fetched_indices = Parallel(n_jobs = self.n_jobs)(knn_search(*args) for args in zip(knns, queries))

        # get all the memory key / values from memmap 'database'
        # todo - remove for loop below

        for batch_index, indices in zip(self.scoped_indices, fetched_indices):
            mask = indices != -1
            db_indices = np.where(mask, indices, 0)

            all_masks.append(torch.from_numpy(mask))

            key_values = self.db[batch_index, db_indices % self.max_memories]
            all_key_values.append(torch.from_numpy(key_values))

        all_masks = torch.stack(all_masks)
        all_key_values = torch.stack(all_key_values)
        all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.)

        all_key_values, = unpack(all_key_values, ps, 'b * n kv d')
        all_masks, = unpack(all_masks, ps, 'b * n')

        return all_key_values.to(device), all_masks.to(device)

    def __del__(self):
        # guard each attribute: __init__ may have failed before assigning them
        if hasattr(self, 'knns'):
            del self.knns
        if hasattr(self, 'db'):
            del self.db
269 |
270 | # extends list with some extra methods for collections of KNN memories
271 |
class KNNMemoryList(list):
    """A list of `KNNMemory` instances (one per memorizing layer) with helpers to create,
    scope, clear and release them together."""

    def cleanup(self):
        """Release every memory held by this list; the list is empty afterwards.

        Dropping the list's references lets each `KNNMemory` be garbage collected (running its
        `__del__`), assuming no other references are held elsewhere. Deleting a loop variable
        would only unbind a local name.
        """
        self.clear()

    @classmethod
    def create_memories(
        cls,
        *,
        batch_size,
        num_memory_layers,
        memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
    ):
        """Return a factory building `num_memory_layers` `KNNMemory` instances.

        Each memory gets `num_indices = batch_size` and its own memmap file
        `knn.memory.layer.{i}.memmap` (i starting at 1) inside `memories_directory`, which is
        created if missing. Extra positional / keyword arguments of the factory are forwarded
        to every `KNNMemory`.
        """
        memories_path = Path(memories_directory)
        memories_path.mkdir(exist_ok = True, parents = True)

        def inner(*args, **kwargs):
            return cls([KNNMemory(*args, num_indices = batch_size, memmap_filename = str(memories_path / f'knn.memory.layer.{ind + 1}.memmap'), **kwargs) for ind in range(num_memory_layers)])
        return inner

    @contextmanager
    def at_batch_indices(
        self,
        indices
    ):
        """Scope every memory in the list to `indices` for the duration of the `with` block."""
        knn_batch_indices_contexts = [memory.at_batch_indices(indices) for memory in self]
        with multi_context(*knn_batch_indices_contexts):
            yield

    def clear_memory(
        self,
        batch_indices = None,
        memory_indices = None
    ):
        """Clear `batch_indices` (all when None) in the memories at `memory_indices` (all when None)."""
        memory_indices = default(memory_indices, tuple(range(len(self))))

        for memory_index in memory_indices:
            memory = self[memory_index]
            memory.clear(batch_indices)
311 |
--------------------------------------------------------------------------------
/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | from contextlib import contextmanager
4 | from pathlib import Path
5 | from filelock import FileLock
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn, einsum
10 |
11 | from einops import rearrange, repeat
12 | from einops_exts import repeat_many
13 | from einops.layers.torch import Rearrange
14 |
15 | from memorizing_transformers_pytorch.knn_memory import KNNMemoryList, DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY
16 |
17 | # helper functions
18 |
def identity(t):
    """No-op: hand back the argument itself."""
    result = t
    return result
21 |
def exists(val):
    """Whether `val` is set, i.e. anything but None (0, '' and [] count as set)."""
    return not (val is None)
24 |
def unique(arr):
    """De-duplicate the hashable items of `arr`, keeping the first occurrence order."""
    return list(dict.fromkeys(arr))
27 |
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    return d if val is None else val
30 |
def cast_tuple(val, length = 1):
    """Return tuples unchanged; otherwise repeat `val` `length` times inside a new tuple."""
    if isinstance(val, tuple):
        return val
    return (val,) * length
33 |
def l2norm(t):
    """Scale `t` to unit L2 norm along its last dimension (all-zero vectors stay zero)."""
    return F.normalize(t, p = 2., dim = -1)
36 |
37 | # helper classes
38 |
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns `fn(LayerNorm(x), **kwargs) + x`.

    When `fn` returns a tuple, the residual is added to its first element only and the
    remaining elements are passed through unchanged.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        result = self.fn(self.norm(x), **kwargs)

        if isinstance(result, tuple):
            first, *rest = result
            return (first + x, *rest)

        return result + x
53 |
54 | # t5 relative positional bias
55 |
class T5RelativePositionBias(nn.Module):
    """Learned T5-style relative position bias (causal variant), one scalar per head per bucket.

    `forward(i, j, device = ...)` returns a bias of shape (1, heads, i, j), multiplied by `scale`.
    """
    def __init__(
        self,
        scale,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        # only distances looking backwards are distinguished; keys after the query collapse to 0
        distance = (-relative_position).clamp(min = 0)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        # larger distances share log-spaced buckets up to max_distance, then saturate at the last bucket
        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        large_bucket = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        large_bucket = large_bucket.clamp(max = num_buckets - 1)

        return torch.where(is_small, distance, large_bucket)

    def forward(self, i, j, *, device):
        query_positions = torch.arange(i, dtype = torch.long, device = device)
        key_positions = torch.arange(j, dtype = torch.long, device = device)
        relative_positions = key_positions[None, :] - query_positions[:, None]

        buckets = self._relative_position_bucket(relative_positions, num_buckets = self.num_buckets, max_distance = self.max_distance)
        bias = self.relative_attention_bias(buckets)      # (i, j, heads)
        bias = bias.permute(2, 0, 1).unsqueeze(0)         # (1, heads, i, j)
        return bias * self.scale
94 |
95 | # feedforward
96 |
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> dim * mult) -> GELU -> Dropout -> Linear(dim * mult -> dim)."""
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden_dim = dim * mult
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        # kept as a `net` Sequential in this exact order so state_dict keys stay net.0.* / net.3.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
109 |
110 | # attention
111 |
class Attention(nn.Module):
    """Causal attention with per-head queries and a single key / value head shared by all heads.

    Optional Transformer-XL style memories `xl_memory`, stacked as (..., 2, dim_head) key / value
    pairs, are prepended to the current keys / values. The forward pass returns
    `(output, new_xl_memories)`. `new_xl_memories` holds the last `xl_max_memories` key / value
    pairs (detached), or None when `xl_max_memories` is not positive.
    """
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        xl_max_memories = 0.,
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.xl_max_memories = xl_max_memories

        self.dropout = nn.Dropout(dropout)

        # queries get one projection per head, keys / values a single shared one
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, *, xl_memory = None, rel_pos_bias = None):
        batch, seq_len = x.shape[:2]
        device = x.device

        queries = self.to_q(x)
        keys, values = self.to_kv(x).chunk(2, dim = -1)

        # (b, n, h * d) -> (b, h, n, d), then pre-scale
        queries = queries.reshape(batch, seq_len, self.heads, -1).transpose(1, 2)
        queries = queries * self.scale

        if xl_memory is not None:
            mem_keys, mem_values = xl_memory.unbind(dim = -2)
            keys = torch.cat((mem_keys, keys), dim = -2)
            values = torch.cat((mem_values, values), dim = -2)

        sim = einsum('b h i d, b j d -> b h i j', queries, keys)
        q_len, k_len = sim.shape[-2:]

        if rel_pos_bias is not None:
            sim = rel_pos_bias[..., -q_len:, -k_len:] + sim

        # the first k_len - q_len keys are memories; each query may not attend to later keys
        causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b j d -> b h i d', attn, values)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)

        # next xl memories: the most recent key / value pairs (including carried-over memories)
        stacked_kv = torch.stack((keys, values), dim = -2).detach()
        next_xl_memories = stacked_kv[:, -self.xl_max_memories:] if self.xl_max_memories > 0 else None

        return self.to_out(out), next_xl_memories
172 |
173 | # approximate nearest neighbor attention
174 |
class KNNAttention(nn.Module):
    """
    Causal attention layer that mixes
      1. local attention over the current segment's keys / values (optionally prefixed
         with Transformer-XL style memories passed in as `xl_memory`), and
      2. attention over key / values retrieved from an external KNN memory.
    Both sets of logits share one softmax.

    Keys and values are produced by a single projection of width `dim_head * 2`, so
    one key / value head is shared across all `heads` query heads.

    Args:
        dim: model dimension of the input.
        heads: number of query heads.
        dim_head: dimension per head.
        dropout: dropout applied to the combined attention weights.
        num_retrieved_memories: top-k passed to `knn_memory.search`.
        xl_max_memories: number of most recent key / values returned as the next XL memory
            (0 disables XL memories). It is used as a slice bound, so it is converted to int.
        attn_scale_init: initial value of the learned per-head logit scale.
        gate_output: if True, multiply the output by tanh of a learned scalar initialised at 0.
    """
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        num_retrieved_memories = 32,
        xl_max_memories = 0.,
        attn_scale_init = 20,
        gate_output = False
    ):
        super().__init__()
        self.heads = heads

        # learned per-head scale, stored in log space (exponentiated in forward)
        self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))

        inner_dim = heads * dim_head

        # used as a slice bound in forward - must be an int even though the default is the float `0.`
        self.xl_max_memories = int(xl_max_memories)

        self.num_retrieved_memories = num_retrieved_memories

        self.dropout = nn.Dropout(dropout)
        self.knn_mem_dropout = nn.Dropout(dropout)  # NOTE(review): not used in forward

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)  # single key / value head shared by all query heads
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # flamingo-style output gate; initialised at zero so tanh(0) = 0 and the layer starts closed
        self.output_gate = nn.Parameter(torch.zeros(1)) if gate_output else None

    def forward(
        self,
        x,
        *,
        knn_memory,
        xl_memory = None,
        add_knn_memory = True,
        rel_pos_bias = None
    ):
        """
        Args:
            x: input; assumed (batch, seq, dim) - TODO confirm against callers.
            knn_memory: object providing `search(queries, topk) -> (kv, mask)` and `add(kv)`.
                The search is done BEFORE this segment's key / values are added.
            xl_memory: optional cached key / values stacked along dim -2 as (k, v). They are
                prepended to this segment's keys / values.
            add_knn_memory: whether key / values falling out of the XL window are written
                into `knn_memory`.
            rel_pos_bias: optional bias added to the local logits, cropped to the last (i, j).

        Returns:
            (out, new_xl_kv_memories). The second item is None when xl_max_memories == 0.
        """
        h, device = self.heads, x.device
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # in paper, they showed normalizing of keys led to more stable training
        # we'll just go with full cosine sim attention https://arxiv.org/abs/2010.04245

        q, k = map(l2norm, (q, k))

        # handle xl memory - prepend cached keys / values to the current ones

        if exists(xl_memory):
            k_xl_mem, v_xl_mem = xl_memory.unbind(dim = -2)
            k = torch.cat((k_xl_mem, k), dim = -2)
            v = torch.cat((v_xl_mem, v), dim = -2)

        # calculate local attention

        scale = self.scale.exp()

        sim = einsum('b h i d, b j d -> b h i j', q, k) * scale
        i, j = sim.shape[-2:]

        if exists(rel_pos_bias):
            sim = rel_pos_bias[..., -i:, -j:] + sim

        mask_value = -torch.finfo(sim.dtype).max

        # queries are aligned to the last i keys (any XL memory comes first), hence the j - i offset
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, mask_value)

        # calculate knn attention over memory (searched before this segment is added, so no leakage)

        mem_kv, mem_mask = knn_memory.search(q, self.num_retrieved_memories)
        mem_k, mem_v = mem_kv.unbind(dim = -2)

        sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
        sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)  # mask is True for valid retrieved entries

        # calculate new XL memories, as well as memories to be discarded

        new_kv_memories = torch.stack((k, v), dim = -2).detach()

        if self.xl_max_memories > 0:
            new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories[:, :-self.xl_max_memories], new_kv_memories[:, -self.xl_max_memories:]
        else:
            new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories, None

        # add memories to be discarded into KNN memory

        if add_knn_memory and new_kv_memories_discarded.numel() > 0:
            knn_memory.add(new_kv_memories_discarded)

        # attention (combining local and distant) under a single softmax

        sim = torch.cat((sim_mem, sim), dim = -1)
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # split at the actual number of retrieved memories rather than assuming the requested top-k
        num_mem = sim_mem.shape[-1]
        mem_attn, local_attn = attn[..., :num_mem], attn[..., num_mem:]

        local_out = einsum('b h i j, b j d -> b h i d', local_attn, v)
        mem_out = einsum('b h i j, b h i j d -> b h i d', mem_attn, mem_v)

        out = local_out + mem_out

        # combine heads and project out

        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # use flamingo styled gating of output, so that memorizing transformers can be gated into an existing LLM
        # preparation to add this to block-recurrent-transformer-pytorch, for the pinnacle of long context attention network

        if exists(self.output_gate):
            out = out * self.output_gate.tanh()

        return out, new_xl_kv_memories
293 |
294 | # main class
295 |
class MemorizingTransformer(nn.Module):
    """
    Decoder-only language model built from `depth` pre-norm residual (attention, feedforward) blocks.

    - Layers listed in `memorizing_layers` use KNNAttention, which also attends to key / values
      retrieved from KNN memories passed to `forward`. The default is the single layer depth // 2.
    - Layers listed in `xl_memory_layers` (only when xl_max_memories > 0) return Transformer-XL
      style memories, which the caller feeds back on the next call.

    Layer numbers are 1-based; numbers outside 1..depth are silently dropped.
    """
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        knn_attn_heads = None,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        memorizing_layers = None,
        max_knn_memories = 250000,
        num_retrieved_memories = 32,
        clear_memories_on_sos_token_id = None,
        clear_memories_on_eos_token_id = None,
        knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY,
        shift_knn_memories_down = 0.,
        pad_id = 0,
        xl_max_memories = 0,
        xl_memory_layers = None,
        shift_xl_memories_down = 0.,
        knn_memory_multiprocessing = False
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pad_id = pad_id

        block_wrapper = partial(PreNormResidual, dim)
        valid_layers = set(range(1, depth + 1))

        memorizing_layers = default(memorizing_layers, (depth // 2,)) # default KNN attention layer to midpoint of transformer
        memorizing_layers = cast_tuple(memorizing_layers)
        memorizing_layers = tuple(filter(lambda i: i in valid_layers, memorizing_layers))

        self.dim_head = dim_head

        knn_attn_heads = default(knn_attn_heads, heads)

        # xl memory hyperparameter

        if xl_max_memories > 0:
            xl_memory_layers = default(xl_memory_layers, tuple(range(1, depth + 1)))
            xl_memory_layers = unique(xl_memory_layers)
            self.xl_memory_layers = tuple(filter(lambda i: i in valid_layers, xl_memory_layers))
            self.num_xl_memory_layers = len(self.xl_memory_layers)
        else:
            self.xl_memory_layers = tuple()
            self.num_xl_memory_layers = 0

        # knn memory hyperparameters

        self.max_knn_memories = max_knn_memories
        self.knn_memories_directory = knn_memories_directory
        self.memorizing_layers = unique(memorizing_layers)

        # forward() consumes one KNN memory per distinct memorizing layer, so count the
        # de-duplicated layers (duplicates in the argument would otherwise over-count)
        self.num_memory_layers = len(self.memorizing_layers)

        self.clear_memories_on_sos_token_id = clear_memories_on_sos_token_id
        self.clear_memories_on_eos_token_id = clear_memories_on_eos_token_id

        # relative positional bias
        # the KNN bias is added to the logits of KNNAttention layers, which have knn_attn_heads heads

        self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)
        self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = knn_attn_heads)

        # layers

        self.layers = nn.ModuleList([])
        for idx in range(depth):
            layer_num = idx + 1

            use_xl_memories = layer_num in self.xl_memory_layers
            use_knn_attention = layer_num in memorizing_layers
            xl_max_memories_layer = 0 if not use_xl_memories else xl_max_memories

            if use_knn_attention:
                attn = KNNAttention(dim = dim, dim_head = dim_head, heads = knn_attn_heads, dropout = attn_dropout, num_retrieved_memories = num_retrieved_memories, xl_max_memories = xl_max_memories_layer)
            else:
                attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, xl_max_memories = xl_max_memories_layer)

            self.layers.append(nn.ModuleList([
                block_wrapper(attn),
                block_wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)),
            ]))

        # memory layer shifting
        # from a little known paper https://arxiv.org/abs/2012.15688
        # these are used as list slice offsets in forward, so they must be ints (defaults are the float 0.)

        self.shift_knn_memories_down = int(shift_knn_memories_down)
        self.shift_xl_memories_down = int(shift_xl_memories_down)

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        )

        # knn memories init

        self.knn_mem_kwargs = dict(
            dim = self.dim_head,
            max_memories = self.max_knn_memories,
            multiprocessing = knn_memory_multiprocessing
        )

    def create_knn_memories(
        self,
        *,
        batch_size
    ):
        """ builds a KNNMemoryList with `num_memory_layers` memories, each holding `batch_size` indices """
        return KNNMemoryList.create_memories(
            batch_size = batch_size,
            num_memory_layers = self.num_memory_layers,
            memories_directory = self.knn_memories_directory,
        )(**self.knn_mem_kwargs)

    @contextmanager
    def knn_memories_context(
        self,
        **kwargs
    ):
        """
        Context manager that creates KNN memories (kwargs are forwarded to create_knn_memories,
        e.g. batch_size) while holding a file lock in the memories directory.
        The memories are cleaned up on exit, including when the body raises.
        """
        knn_dir = Path(self.knn_memories_directory)
        knn_dir.mkdir(exist_ok = True, parents = True)
        lock = FileLock(str(knn_dir / 'mutex'))

        with lock:
            knn_memories = self.create_knn_memories(**kwargs)
            try:
                yield knn_memories
            finally:
                knn_memories.cleanup()

    def clear_memory(self, x, token_id, knn_memories = None):
        """
        Clears the KNN memories for every batch row of `x` that contains `token_id`.
        Used for auto-clearing KNN memories based on start and end of strings.

        Args:
            x: token ids of shape (batch, seq) - NOT embeddings.
            token_id: the token id that triggers clearing.
            knn_memories: the memory collection to clear. It must provide
                `clear_memory(batch_indices)`, as KNNMemoryList does.

        Raises:
            ValueError: if `knn_memories` is not given.
        """
        if not exists(knn_memories):
            raise ValueError('knn_memories must be passed in to clear_memory')

        rows_to_clear = (x == token_id).any(dim = -1)

        # a 1-d mask gives a 1-tuple from nonzero(as_tuple = True)
        batch_indices = rows_to_clear.nonzero(as_tuple = True)[0]
        batch_indices_to_clear = batch_indices.tolist()

        if len(batch_indices_to_clear) == 0:
            return

        knn_memories.clear_memory(batch_indices_to_clear)

    def forward(
        self,
        x,
        knn_memories,
        xl_memories = None,
        labels = None,
        add_knn_memory = True
    ):
        """
        Args:
            x: token ids, (batch, seq).
            knn_memories: one KNN memory per memorizing layer, each built for this batch size.
            xl_memories: XL memories returned by the previous call (or None for a fresh start).
            labels: optional target ids; when given, the cross-entropy loss is returned instead of logits.
            add_knn_memory: whether KNNAttention layers write evicted key / values into their memory.

        Returns:
            logits (or loss when labels are given), paired with the new XL memories when any
            XL memory layers exist.
        """
        batch_size, seq_len, *_, device = *x.shape, x.device

        # keep the raw token ids - sos / eos memory clearing must look at ids, not embeddings
        token_ids = x
        x = self.token_emb(x)

        # validate KNN memories to have enough indices for batch size

        assert all([memory.num_indices == batch_size for memory in knn_memories]), f'you passed in an input with batch size {batch_size} but your memories were not instantiated with that number of KNN indices'

        # if KNN memories are passed in, and researcher wants memories auto-cleared on token detection
        # do the appropriate logic

        if exists(self.clear_memories_on_sos_token_id):
            self.clear_memory(token_ids, self.clear_memories_on_sos_token_id, knn_memories)

        # handle XL memories

        xl_memories = default(xl_memories, (None,) * self.num_xl_memory_layers)
        assert len(xl_memories) == self.num_xl_memory_layers
        has_xl_memories = len(xl_memories) > 0

        # shifting memories a number of layers down, little known technique shown to enhance memories from Ernie-Doc paper
        # the shifted copy is only used for per-layer assignment; the original collection is kept for clearing

        layer_knn_memories = knn_memories
        if len(knn_memories) > 0 and self.shift_knn_memories_down > 0:
            shift = self.shift_knn_memories_down
            layer_knn_memories = [*knn_memories[shift:], *knn_memories[:shift]]

        if len(xl_memories) > 0 and self.shift_xl_memories_down > 0:
            shift = self.shift_xl_memories_down
            xl_memories = [*xl_memories[shift:], *xl_memories[:shift]]

        # iterate through the memories in order of the ascending layers that contain KNNAttention

        xl_memories_iter = iter(xl_memories)
        knn_memories_iter = iter(layer_knn_memories)

        # positional bias - context may be extended by XL memories (dim -3 of a memory is its length)

        max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)])

        rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device)
        knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device)

        # keep track of new xl memories

        new_xl_memories = [] if has_xl_memories else None

        # go through all layers

        for ind, (attn, ff) in enumerate(self.layers):
            layer_num = ind + 1

            is_memorizing_layer = layer_num in self.memorizing_layers
            is_xl_memory_layer = layer_num in self.xl_memory_layers

            attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias)

            if is_memorizing_layer:
                attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory}

            if is_xl_memory_layer:
                attn_kwargs = {**attn_kwargs, 'xl_memory': next(xl_memories_iter)}

            # attention

            x, xl_mem = attn(x, **attn_kwargs)

            # add new XL memories if needed

            if exists(xl_mem):
                new_xl_memories.append(xl_mem)

            # feedforward

            x = ff(x)

        # to logits

        logits = self.to_logits(x)

        # auto-clear KNN memories on end of string token

        if exists(self.clear_memories_on_eos_token_id):
            self.clear_memory(token_ids, self.clear_memories_on_eos_token_id, knn_memories)

        # for training

        if not exists(labels):
            if exists(new_xl_memories):
                return logits, new_xl_memories

            return logits

        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = self.pad_id)

        if exists(new_xl_memories):
            return loss, new_xl_memories

        return loss
545 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# package metadata for memorizing-transformers-pytorch

KEYWORDS = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'memory',
    'retrieval'
]

# runtime dependencies installed alongside the package
INSTALL_REQUIRES = [
    'einops>=0.6',
    'filelock',
    'joblib',
    'faiss-gpu',
    'numpy',
    'torch>=1.6',
]

CLASSIFIERS = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
]

setup(
    name = 'memorizing-transformers-pytorch',
    version = '0.4.1',
    description = 'Memorizing Transformer - Pytorch',
    long_description_content_type = 'text/markdown',
    license = 'MIT',
    author = 'Phil Wang',
    author_email = 'lucidrains@gmail.com',
    url = 'https://github.com/lucidrains/memorizing-transformers-pytorch',
    packages = find_packages(exclude = []),
    keywords = KEYWORDS,
    install_requires = INSTALL_REQUIRES,
    classifiers = CLASSIFIERS,
)
36 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from memorizing_transformers_pytorch import MemorizingTransformer
2 |
3 | import random
4 | import tqdm
5 | import gzip
6 | import numpy as np
7 | import torch
8 | import torch.optim as optim
9 | from torch.nn import functional as F
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 | # constants
13 |
14 | NUM_BATCHES = int(1e5)
15 | BATCH_SIZE = 16
16 | SEQ_LEN = 512
17 | SEGMENTS = 5
18 |
19 | LEARNING_RATE = 2e-4
20 | MAX_GRAD_CLIP_NORM = 0.5
21 |
22 | VALIDATE_EVERY = 100
23 | GENERATE_EVERY = 500
24 | GENERATE_LENGTH = 512
25 |
26 | # helpers
27 |
def cycle(loader):
    """Yield items from `loader` forever, starting a fresh pass over it whenever one pass ends."""
    while True:
        yield from loader
32 |
def decode_token(token):
    """Convert a byte value to a character; values below 32 (control codes) are shown as a space."""
    code_point = max(32, token)
    return chr(code_point)
35 |
def decode_tokens(tokens):
    """Decode an iterable of byte values into a string, rendering values below 32 as spaces."""
    return ''.join(chr(max(32, token)) for token in tokens)
38 |
# instantiate GPT-like decoder model

model = MemorizingTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    memorizing_layers = 4,
    max_knn_memories = 512 * 15,
    num_retrieved_memories = 32,
    xl_memory_layers = (7, 8),
    xl_max_memories = 512,
).cuda()

# prepare enwik8 data: read up to the first 95M bytes, split into 90M train / remainder validation

with gzip.open('./data/enwik8.gz') as file:
    # binary-mode np.fromstring is deprecated in favour of np.frombuffer; frombuffer returns a
    # read-only view of the bytes, so copy it to give torch.from_numpy a writable array
    X = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
58 |
class TextSamplerDataset(Dataset):
    """
    Map-style dataset that ignores the requested index. Each item is a random contiguous
    window of `seq_len + 1` tokens from `data`, cast to int64 and moved to the GPU.
    The extra token lets the caller split the window into inputs and next-token labels.
    """
    def __init__(self, data, seq_len):
        super().__init__()
        self.data, self.seq_len = data, seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        return self.data[start: start + window].long().cuda()

    def __len__(self):
        # nominal epoch size: how many seq_len-sized chunks fit in the data
        num_tokens = self.data.size(0)
        return num_tokens // self.seq_len
72 |
# dataset and dataloader

def _infinite_loader(dataset):
    # wrap a batching DataLoader in `cycle` so batches can be drawn with next() indefinitely
    return cycle(DataLoader(dataset, batch_size = BATCH_SIZE, drop_last = True))

sample_len = SEQ_LEN * SEGMENTS  # each sample spans all segments of one training sequence

train_dataset = TextSamplerDataset(data_train, sample_len)
train_loader = _infinite_loader(train_dataset)
valid_dataset = TextSamplerDataset(data_val, sample_len)
valid_loader = _infinite_loader(valid_dataset)

# optimizer (note: this name shadows the `torch.optim as optim` module import)

optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
83 |
# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    data = next(train_loader)

    train_loss = 0.
    # fresh KNN memories for every batch; XL memories are threaded through the segments of one sequence
    with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
        xl_memories = None
        seq, labels = data[:, :-1], data[:, 1:]

        for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
            loss, xl_memories = model(
                seq_segment,
                labels = labels_segment,
                knn_memories = knn_memories,
                xl_memories = xl_memories
            )

            # gradients accumulate over the segments; the loss is averaged so one optimizer step covers them all
            train_loss += loss.item() / SEGMENTS
            (loss / SEGMENTS).backward()

    print(f'training loss: {train_loss}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optim.step()
    optim.zero_grad()

    if not (i % VALIDATE_EVERY):
        model.eval()

        valid_data = next(valid_loader)
        valid_loss = 0.

        with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories:
            xl_memories = None
            # evaluate on the held-out batch, not the training batch `data`
            seq, labels = valid_data[:, :-1], valid_data[:, 1:]

            for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)):
                loss, xl_memories = model(
                    seq_segment,
                    labels = labels_segment,
                    knn_memories = knn_memories,
                    xl_memories = xl_memories
                )

                valid_loss += loss.item() / SEGMENTS

        print(f'valid loss: {valid_loss}')
133 |
--------------------------------------------------------------------------------