├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── diagram.png ├── memorizing_transformers_pytorch ├── __init__.py ├── knn_memory.py └── memorizing_transformers_pytorch.py ├── setup.py └── train.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 
21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Memorizing Transformers - Pytorch 4 | 5 | Implementation of Memorizing Transformers (ICLR 2022), an attention network augmented with indexing and retrieval of memories using approximate nearest neighbors, in Pytorch 6 | 7 | This repository deviates from the paper slightly: instead of the paper's sigmoid gate, it combines the local and distant (retrieved) attention logits in a single softmax. It also uses cosine similarity attention (with learned temperature) for the KNN attention layer. 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install memorizing-transformers-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | from memorizing_transformers_pytorch import MemorizingTransformer 20 | 21 | model = MemorizingTransformer( 22 | num_tokens = 20000, # number of tokens 23 | dim = 512, # dimension 24 | dim_head = 64, # dimension per attention head 25 | depth = 8, # number of layers 26 | memorizing_layers = (4, 5), # which layers have ANN memories 27 | max_knn_memories = 64000, # maximum ANN memories to keep (once it hits this capacity, it will be reset for now, due to limitations in faiss' ability to remove entries) 28 | num_retrieved_memories = 32, # number of ANN memories to retrieve 29 | clear_memories_on_sos_token_id = 1, # clear passed in ANN memories automatically for batch indices which contain this specified SOS token id - otherwise, you can also manually iterate through the ANN memories and clear the indices before the next iteration 30 | ) 31 | 32 | data = torch.randint(0, 20000, (2, 1024)) # mock data 33 | 34 | knn_memories = model.create_knn_memories(batch_size = 2) # create collection of KNN memories with the correct batch size (2 in example) 35 | 36 | logits = model(data, knn_memories = knn_memories) # (2, 1024, 20000) 37 | ``` 38 | 39 | You can make the KNN
memories read-only by setting `add_knn_memory` on forward to `False` 40 | 41 | ex. 42 | 43 | ```python 44 | logits = model(data, knn_memories = knn_memories, add_knn_memory = False) # knn memories will not be updated 45 | ``` 46 | 47 | With Transformer-XL memories (only the memories that will be discarded will be added to the KNN memory) 48 | 49 | ```python 50 | import torch 51 | from memorizing_transformers_pytorch import MemorizingTransformer 52 | 53 | model = MemorizingTransformer( 54 | num_tokens = 20000, 55 | dim = 512, 56 | depth = 8, 57 | memorizing_layers = (4, 5), 58 | max_knn_memories = 64000, 59 | num_retrieved_memories = 32, 60 | clear_memories_on_sos_token_id = 1, 61 | xl_memory_layers = (2, 3, 4, 5), # xl memory layers - (https://arxiv.org/abs/2007.03356 shows you do not need XL memory on all layers, just the latter ones) - if a KNNAttention layer ends up using XL memories, only the XL memories that will be discarded will be added to long term memory 62 | xl_max_memories = 512, # number of xl memories to keep 63 | shift_knn_memories_down = 1, # let a layer look at the KNN memories this number of layers above 64 | shift_xl_memories_down = 1, # let a layer look at the XL memories this number of layers above, shown to enhance receptive field in ernie-doc paper 65 | ) 66 | 67 | data = torch.randint(0, 20000, (2, 1024)) # mock data 68 | 69 | xl_memories = None 70 | 71 | with model.knn_memories_context(batch_size = 2) as knn_memories: 72 | logits1, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories) 73 | logits2, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories) 74 | logits3, xl_memories = model(data, knn_memories = knn_memories, xl_memories = xl_memories) 75 | 76 | # ... 
and so on 77 | ``` 78 | 79 | ## KNN Memory 80 | 81 | This repository contains a wrapper around Faiss that can automatically store and retrieve key / value pairs 82 | 83 | ```python 84 | import torch 85 | from memorizing_transformers_pytorch import KNNMemory 86 | 87 | memory = KNNMemory( 88 | dim = 64, # dimension of key / values 89 | max_memories = 64000, # maximum number of memories to keep per batch index (for now, a batch index's memories are reset once adding more would exceed this) 90 | num_indices = 2 # this should equal the batch size, as each batch index keeps track of its own memories, expiring when it sees a new document 91 | ) 92 | 93 | memory.add(torch.randn(2, 512, 2, 64)) # (batch, seq, key | value, feature dim) 94 | memory.add(torch.randn(2, 512, 2, 64)) 95 | 96 | memory.clear([0]) # clear the memories of batch index 0, e.g. because it started a new document 97 | 98 | memory.add(torch.randn(2, 512, 2, 64)) 99 | memory.add(torch.randn(2, 512, 2, 64)) 100 | 101 | key_values, mask = memory.search(torch.randn(2, 512, 64), topk = 32) 102 | ``` 103 | 104 | ## Training 105 | 106 | Enwik8 training 107 | 108 | ```bash 109 | $ python train.py 110 | ``` 111 | 112 | ## Todo 113 | 114 | - [x] switch to ivfhnsw and just remember all memories 115 | - [x] enwik8 demo 116 | - [x] validation for enwik8 117 | - [x] solve gradient accumulation problem by offering some way to scope reads and writes to knn memories with another indices array 118 | - [ ] setup text generation with memories 119 | - [ ] figure out how to deal with memories efficiently once capacity has been hit 120 | - [ ] try to speed up reading and writing to knn memories collection with multiprocessing 121 | 122 | ## Citations 123 | 124 | ```bibtex 125 | @article{wu2022memorizing, 126 | title = {Memorizing transformers}, 127 | author = {Wu, Yuhuai and Rabe, Markus N and Hutchins, DeLesley and Szegedy, Christian}, 128 | journal = {arXiv preprint arXiv:2203.08913}, 129 | year = {2022} 130 | } 131 | ``` 132 | 133 | ```bibtex 134 | @article{Shazeer2019FastTD, 135 | title = {Fast Transformer
Decoding: One Write-Head is All You Need}, 136 | author = {Noam M. Shazeer}, 137 | journal = {ArXiv}, 138 | year = {2019}, 139 | volume = {abs/1911.02150} 140 | } 141 | ``` 142 | 143 | ```bibtex 144 | @Article{AlphaFold2021, 145 | author = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis}, 146 | journal = {Nature}, 147 | title = {Highly accurate protein structure prediction with {AlphaFold}}, 148 | year = {2021}, 149 | doi = {10.1038/s41586-021-03819-2}, 150 | note = {(Accelerated article preview)}, 151 | } 152 | ``` 153 | 154 | ```bibtex 155 | @inproceedings{Rae2020DoTN, 156 | title = {Do Transformers Need Deep Long-Range Memory?}, 157 | author = {Jack W. 
Rae and Ali Razavi}, 158 | booktitle = {ACL}, 159 | year = {2020} 160 | } 161 | ``` 162 | 163 | ```bibtex 164 | @misc{ding2021erniedoc, 165 | title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer}, 166 | author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang}, 167 | year = {2021}, 168 | eprint = {2012.15688}, 169 | archivePrefix = {arXiv}, 170 | primaryClass = {cs.CL} 171 | } 172 | ``` 173 | 174 | ```bibtex 175 | @misc{henry2020querykey, 176 | title = {Query-Key Normalization for Transformers}, 177 | author = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen}, 178 | year = {2020}, 179 | eprint = {2010.04245}, 180 | archivePrefix = {arXiv}, 181 | primaryClass = {cs.CL} 182 | } 183 | ``` 184 | 185 | *Memory is Attention through Time* - Alex Graves 186 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/memorizing-transformers-pytorch/272e39bafd2a507d21ac896bd7cf4b593ee9acb7/data/enwik8.gz -------------------------------------------------------------------------------- /diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/memorizing-transformers-pytorch/272e39bafd2a507d21ac896bd7cf4b593ee9acb7/diagram.png -------------------------------------------------------------------------------- /memorizing_transformers_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from 
memorizing_transformers_pytorch.memorizing_transformers_pytorch import MemorizingTransformer, KNNAttention 2 | from memorizing_transformers_pytorch.knn_memory import KNNMemory 3 | -------------------------------------------------------------------------------- /memorizing_transformers_pytorch/knn_memory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import faiss 5 | import numpy as np 6 | from pathlib import Path 7 | from functools import wraps 8 | 9 | from contextlib import ExitStack, contextmanager 10 | 11 | from einops import rearrange, pack, unpack 12 | 13 | # multiprocessing 14 | 15 | from joblib import Parallel, delayed, cpu_count 16 | 17 | # constants 18 | 19 | FAISS_INDEX_GPU_ID = int(os.getenv('FAISS_INDEX_GPU_ID', 0)) 20 | 21 | DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY = './.tmp/knn.memories' 22 | 23 | # helper functions 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | def default(val, d): 29 | return val if exists(val) else d 30 | 31 | def cast_list(val): 32 | return val if isinstance(val, list) else [val] 33 | 34 | def all_el_unique(arr): 35 | return len(set(arr)) == len(arr) 36 | 37 | @contextmanager 38 | def multi_context(*cms): 39 | with ExitStack() as stack: 40 | yield [stack.enter_context(cls) for cls in cms] 41 | 42 | def count_intersect(x, y): 43 | # returns an array that shows how many times an element in x is contained in tensor y 44 | return np.sum(rearrange(x, 'i -> i 1') == rearrange(y, 'j -> 1 j'), axis = -1) 45 | 46 | def check_shape(tensor, pattern, **kwargs): 47 | return rearrange(tensor, f"{pattern} -> {pattern}", **kwargs) 48 | 49 | # a wrapper around faiss IndexIVFFlat 50 | # taking care of expiring old keys automagically 51 | 52 | class KNN(): 53 | def __init__( 54 | self, 55 | dim, 56 | max_num_entries, 57 | cap_num_entries = False, 58 | M = 15, 59 | keep_stats = False 60 | ): 61 | index = faiss.IndexHNSWFlat(dim, M, 
faiss.METRIC_INNER_PRODUCT) 62 | self.index = index 63 | self.max_num_entries = max_num_entries 64 | self.cap_num_entries = cap_num_entries 65 | self.is_trained = False 66 | self.keep_stats = keep_stats 67 | 68 | self.reset() 69 | 70 | def __del__(self): 71 | if hasattr(self, 'index'): 72 | del self.index 73 | 74 | def reset(self): 75 | self.ids = np.empty((0,), dtype = np.int32) 76 | 77 | if self.keep_stats: 78 | self.hits = np.empty((0,), dtype = np.int32) 79 | self.age_num_iterations = np.empty((0,), dtype = np.int32) 80 | self.ages_since_last_hit = np.empty((0,), dtype = np.int32) 81 | 82 | self.index.reset() 83 | self.is_trained = False 84 | 85 | def train(self, x): 86 | self.index.train(x) 87 | self.is_trained = True 88 | 89 | def add(self, x, ids): 90 | if not self.is_trained: 91 | self.train(x) 92 | 93 | self.ids = np.concatenate((ids, self.ids)) 94 | 95 | if self.keep_stats: 96 | self.hits = np.concatenate((np.zeros_like(ids), self.hits)) 97 | self.age_num_iterations = np.concatenate((np.zeros_like(ids), self.age_num_iterations)) 98 | self.ages_since_last_hit = np.concatenate((np.zeros_like(ids), self.ages_since_last_hit)) 99 | 100 | if self.cap_num_entries and len(self.ids) > self.max_num_entries: 101 | self.reset() 102 | 103 | return self.index.add(x) 104 | 105 | def search( 106 | self, 107 | x, 108 | topk, 109 | nprobe = 8, 110 | return_distances = False, 111 | increment_hits = False, 112 | increment_age = True 113 | ): 114 | if not self.is_trained: 115 | return np.full((x.shape[0], topk), -1) 116 | 117 | distances, indices = self.index.search(x, k = topk) 118 | 119 | if increment_hits and self.keep_stats: 120 | hits = count_intersect(self.ids, rearrange(indices, '... 
-> (...)')) 121 | self.hits += hits 122 | 123 | self.ages_since_last_hit += 1 124 | self.ages_since_last_hit *= (hits == 0) 125 | 126 | if increment_age and self.keep_stats: 127 | self.age_num_iterations += 1 128 | 129 | if return_distances: 130 | return indices, distances 131 | 132 | return indices 133 | 134 | # KNN memory layer, where one can store key / value memories 135 | # can automatically take care of a collection of faiss indices (across batch dimension) 136 | 137 | class KNNMemory(): 138 | def __init__( 139 | self, 140 | dim, 141 | max_memories = 16000, 142 | num_indices = 1, 143 | memmap_filename = './knn.memory.memmap', 144 | multiprocessing = True 145 | ): 146 | self.dim = dim 147 | self.num_indices = num_indices 148 | self.scoped_indices = list(range(num_indices)) 149 | 150 | self.max_memories = max_memories 151 | self.shape = (num_indices, max_memories, 2, dim) 152 | self.db_offsets = np.zeros(num_indices, dtype = np.int32) 153 | 154 | self.db = np.memmap(memmap_filename, mode = 'w+', dtype = np.float32, shape = self.shape) 155 | self.knns = [KNN(dim = dim, max_num_entries = max_memories, cap_num_entries = True) for _ in range(num_indices)] 156 | 157 | self.n_jobs = cpu_count() if multiprocessing else 1 158 | 159 | def set_scoped_indices(self, indices): 160 | indices = list(indices) 161 | assert all_el_unique(indices), f'all scoped batch indices must be unique, received: {indices}' 162 | assert all([0 <= i < self.num_indices for i in indices]), f'each batch index must be between 0 and less than {self.num_indices}: received {indices}' 163 | self.scoped_indices = indices 164 | 165 | @contextmanager 166 | def at_batch_indices(self, indices): 167 | prev_indices = self.scoped_indices 168 | self.set_scoped_indices(indices) 169 | yield self 170 | self.set_scoped_indices(prev_indices) 171 | 172 | def clear(self, batch_indices = None): 173 | if not exists(batch_indices): 174 | batch_indices = list(range(self.num_indices)) 175 | 176 | batch_indices = 
cast_list(batch_indices) 177 | 178 | for index in batch_indices: 179 | knn = self.knns[index] 180 | knn.reset() 181 | 182 | self.db_offsets[batch_indices] = 0 183 | 184 | def add(self, memories): 185 | check_shape(memories, 'b n kv d', d = self.dim, kv = 2, b = len(self.scoped_indices)) 186 | 187 | memories = memories.detach().cpu().numpy() 188 | memories = memories[:, -self.max_memories:] 189 | num_memories = memories.shape[1] 190 | 191 | knn_insert_ids = np.arange(num_memories) 192 | 193 | keys = np.ascontiguousarray(memories[..., 0, :]) 194 | knns = [self.knns[i] for i in self.scoped_indices] 195 | db_offsets = [self.db_offsets[i] for i in self.scoped_indices] 196 | 197 | # use joblib to insert new key / value memories into faiss index 198 | 199 | @delayed 200 | def knn_add(knn, key, db_offset): 201 | knn.add(key, ids = knn_insert_ids + db_offset) 202 | return knn 203 | 204 | updated_knns = Parallel(n_jobs = self.n_jobs)(knn_add(*args) for args in zip(knns, keys, db_offsets)) 205 | for knn_idx, scoped_idx in enumerate(self.scoped_indices): 206 | self.knns[scoped_idx] = updated_knns[knn_idx] 207 | 208 | # add the new memories to the memmap "database" 209 | 210 | add_indices = (rearrange(np.arange(num_memories), 'j -> 1 j') + rearrange(self.db_offsets[list(self.scoped_indices)], 'i -> i 1')) % self.max_memories 211 | self.db[rearrange(np.array(self.scoped_indices), 'i -> i 1'), add_indices] = memories 212 | self.db.flush() 213 | 214 | self.db_offsets += num_memories 215 | 216 | def search( 217 | self, 218 | queries, 219 | topk, 220 | nprobe = 8, 221 | increment_hits = True, 222 | increment_age = True 223 | ): 224 | check_shape(queries, 'b ... 
d', d = self.dim, b = len(self.scoped_indices)) 225 | queries, ps = pack([queries], 'b * d') 226 | 227 | device = queries.device 228 | queries = queries.detach().cpu().numpy() 229 | 230 | all_masks = [] 231 | all_key_values = [] 232 | 233 | knns = [self.knns[i] for i in self.scoped_indices] 234 | 235 | # parallelize faiss search 236 | 237 | @delayed 238 | def knn_search(knn, query): 239 | return knn.search(query, topk, nprobe, increment_hits = increment_hits, increment_age = increment_age) 240 | 241 | fetched_indices = Parallel(n_jobs = self.n_jobs)(knn_search(*args) for args in zip(knns, queries)) 242 | 243 | # get all the memory key / values from memmap 'database' 244 | # todo - remove for loop below 245 | 246 | for batch_index, indices in zip(self.scoped_indices, fetched_indices): 247 | mask = indices != -1 248 | db_indices = np.where(mask, indices, 0) 249 | 250 | all_masks.append(torch.from_numpy(mask)) 251 | 252 | key_values = self.db[batch_index, db_indices % self.max_memories] 253 | all_key_values.append(torch.from_numpy(key_values)) 254 | 255 | all_masks = torch.stack(all_masks) 256 | all_key_values = torch.stack(all_key_values) 257 | all_key_values = all_key_values.masked_fill(~rearrange(all_masks, '... -> ... 1 1'), 0.) 
258 | 259 | all_key_values, = unpack(all_key_values, ps, 'b * n kv d') 260 | all_masks, = unpack(all_masks, ps, 'b * n') 261 | 262 | return all_key_values.to(device), all_masks.to(device) 263 | 264 | def __del__(self): 265 | if hasattr(self, 'knns'): 266 | for knn in self.knns: 267 | del knn 268 | del self.db 269 | 270 | # extends list with some extra methods for collections of KNN memories 271 | 272 | class KNNMemoryList(list): 273 | def cleanup(self): 274 | for memory in self: 275 | del memory 276 | 277 | @classmethod 278 | def create_memories( 279 | self, 280 | *, 281 | batch_size, 282 | num_memory_layers, 283 | memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY 284 | ): 285 | memories_path = Path(memories_directory) 286 | memories_path.mkdir(exist_ok = True, parents = True) 287 | 288 | def inner(*args, **kwargs): 289 | return self([KNNMemory(*args, num_indices = batch_size, memmap_filename = str(memories_path / f'knn.memory.layer.{ind + 1}.memmap'), **kwargs) for ind in range(num_memory_layers)]) 290 | return inner 291 | 292 | @contextmanager 293 | def at_batch_indices( 294 | self, 295 | indices 296 | ): 297 | knn_batch_indices_contexts = [memory.at_batch_indices(indices) for memory in self] 298 | with multi_context(*knn_batch_indices_contexts): 299 | yield 300 | 301 | def clear_memory( 302 | self, 303 | batch_indices = None, 304 | memory_indices = None 305 | ): 306 | memory_indices = default(memory_indices, tuple(range(len(self)))) 307 | 308 | for memory_index in memory_indices: 309 | memory = self[memory_index] 310 | memory.clear(batch_indices) 311 | -------------------------------------------------------------------------------- /memorizing_transformers_pytorch/memorizing_transformers_pytorch.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from contextlib import contextmanager 4 | from pathlib import Path 5 | from filelock import FileLock 6 | 7 | import torch 8 | 
import torch.nn.functional as F 9 | from torch import nn, einsum 10 | 11 | from einops import rearrange, repeat 12 | from einops_exts import repeat_many 13 | from einops.layers.torch import Rearrange 14 | 15 | from memorizing_transformers_pytorch.knn_memory import KNNMemoryList, DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY 16 | 17 | # helper functions 18 | 19 | def identity(t): 20 | return t 21 | 22 | def exists(val): 23 | return val is not None 24 | 25 | def unique(arr): # order-preserving de-duplication (dict keys keep insertion order) 26 | return list({el: True for el in arr}.keys()) 27 | 28 | def default(val, d): 29 | return val if exists(val) else d 30 | 31 | def cast_tuple(val, length = 1): # tuples pass through; any other value is repeated `length` times 32 | return val if isinstance(val, tuple) else ((val,) * length) 33 | 34 | def l2norm(t): # L2-normalize along the last dimension 35 | return F.normalize(t, dim = -1) 36 | 37 | # helper classes 38 | 39 | class PreNormResidual(nn.Module): # LayerNorm -> fn -> residual add; when fn returns a tuple, only its first element gets the residual 40 | def __init__(self, dim, fn): 41 | super().__init__() 42 | self.fn = fn 43 | self.norm = nn.LayerNorm(dim) 44 | 45 | def forward(self, x, **kwargs): 46 | out = self.fn(self.norm(x), **kwargs) 47 | 48 | if not isinstance(out, tuple): 49 | return out + x 50 | 51 | head, *tail = out 52 | return (head + x, *tail) 53 | 54 | # t5 relative positional bias 55 | 56 | class T5RelativePositionBias(nn.Module): # learned per-head bias looked up by bucketed relative position (key position - query position), multiplied by `scale` 57 | def __init__( 58 | self, 59 | scale, 60 | num_buckets = 32, 61 | max_distance = 128, 62 | heads = 8 63 | ): 64 | super().__init__() 65 | self.scale = scale 66 | self.num_buckets = num_buckets 67 | self.max_distance = max_distance 68 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 69 | 70 | @staticmethod 71 | def _relative_position_bucket( 72 | relative_position, 73 | num_buckets = 32, 74 | max_distance = 128 75 | ): 76 | n = -relative_position # distance looking back from the query 77 | n = torch.max(n, torch.zeros_like(n)) # keys after the query (positive relative position) all land in bucket 0 78 | 79 | max_exact = num_buckets // 2 80 | is_small = n < max_exact # the first num_buckets // 2 distances each get their own bucket 81 | 82 | val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() # larger distances are log-spaced up to max_distance (n == 0 gives log(0) here, but is_small always selects n for it) 83 | val_if_large =
torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) # clamp into the last bucket, which is reached around max_distance 84 | return torch.where(is_small, n, val_if_large) 85 | 86 | def forward(self, i, j, *, device): # returns a (1, heads, i, j) bias with query and key positions both counted from 0 - NOTE(review): queries are not offset by j - i when j > i, unlike the causal masks in the attention layers; confirm the callers' (i, j) account for this 87 | q_pos = torch.arange(i, dtype = torch.long, device = device) 88 | k_pos = torch.arange(j, dtype = torch.long, device = device) 89 | rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') 90 | rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) 91 | values = self.relative_attention_bias(rp_bucket) 92 | bias = rearrange(values, 'i j h -> () h i j') 93 | return bias * self.scale 94 | 95 | # feedforward 96 | 97 | class FeedForward(nn.Module): # Linear -> GELU -> Dropout -> Linear with hidden width dim * mult 98 | def __init__(self, dim, mult = 4, dropout = 0.): 99 | super().__init__() 100 | self.net = nn.Sequential( 101 | nn.Linear(dim, dim * mult), 102 | nn.GELU(), 103 | nn.Dropout(dropout), 104 | nn.Linear(dim * mult, dim) 105 | ) 106 | 107 | def forward(self, x): 108 | return self.net(x) 109 | 110 | # attention 111 | 112 | class Attention(nn.Module): # causal attention in which all query heads share one key / value head, with optional Transformer-XL memory 113 | def __init__( 114 | self, 115 | *, 116 | dim, 117 | heads = 8, 118 | dim_head = 64, 119 | dropout = 0., 120 | xl_max_memories = 0., # NOTE(review): used as a slice bound in forward, so a non-zero value must be an int 121 | ): 122 | super().__init__() 123 | self.heads = heads 124 | self.scale = dim_head ** -0.5 125 | inner_dim = heads * dim_head 126 | self.xl_max_memories = xl_max_memories 127 | 128 | self.dropout = nn.Dropout(dropout) 129 | 130 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 131 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) # a single key / value head shared by all query heads 132 | self.to_out = nn.Linear(inner_dim, dim) 133 | 134 | def forward(self, x, *, xl_memory = None, rel_pos_bias = None): # returns (output, new XL memories or None) 135 | h, device = self.heads, x.device 136 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 137 | 138 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 139 | 140 | q = q * self.scale 141 | 142 | if exists(xl_memory): 143 | k_xl_mem, v_xl_mem = xl_memory.unbind(dim = -2) # keys / values are stacked on dim -2, the same layout as new_kv_memories below 144 | k = torch.cat((k_xl_mem, k), dim = -2) 145 | v =
torch.cat((v_xl_mem, v), dim = -2) 146 | 147 | sim = einsum('b h i d, b j d -> b h i j', q, k) 148 | i, j = sim.shape[-2:] 149 | 150 | if exists(rel_pos_bias): 151 | sim = rel_pos_bias[..., -i:, -j:] + sim # use the bottom-right i x j window of the bias 152 | 153 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) # queries line up with the last i keys, so any XL memory prefix stays visible 154 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 155 | 156 | attn = sim.softmax(dim = -1) 157 | attn = self.dropout(attn) 158 | 159 | out = einsum('b h i j, b j d -> b h i d', attn, v) 160 | out = rearrange(out, 'b h n d -> b n (h d)') 161 | 162 | # new xl memories 163 | 164 | new_kv_memories = torch.stack((k, v), dim = -2).detach() # k / v already include any XL memory passed in, so the window slides forward 165 | 166 | if self.xl_max_memories > 0: 167 | new_xl_kv_memories = new_kv_memories[:, -self.xl_max_memories:] 168 | else: 169 | new_xl_kv_memories = None 170 | 171 | return self.to_out(out), new_xl_kv_memories 172 | 173 | # approximate nearest neighbor attention 174 | 175 | class KNNAttention(nn.Module): # cosine-similarity attention whose local (+ XL) logits share one softmax with the top-k key / values retrieved from a KNN memory 176 | def __init__( 177 | self, 178 | *, 179 | dim, 180 | heads = 8, 181 | dim_head = 64, 182 | dropout = 0., 183 | num_retrieved_memories = 32, 184 | xl_max_memories = 0., # NOTE(review): used as a slice bound in forward, so a non-zero value must be an int 185 | attn_scale_init = 20, 186 | gate_output = False 187 | ): 188 | super().__init__() 189 | self.heads = heads 190 | self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init)) # learned per-head temperature, stored as a log and exponentiated in forward 191 | 192 | inner_dim = heads * dim_head 193 | self.xl_max_memories = xl_max_memories 194 | 195 | self.num_retrieved_memories = num_retrieved_memories 196 | 197 | self.dropout = nn.Dropout(dropout) 198 | self.knn_mem_dropout = nn.Dropout(dropout) # NOTE(review): not used in forward 199 | 200 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 201 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 202 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 203 | 204 | self.output_gate = nn.Parameter(torch.zeros(1)) if gate_output else None # zero-initialised, so a gated layer starts out contributing nothing (tanh(0) = 0) 205 | 206 | def forward( 207 | self, 208 | x, 209 | *, 210 | knn_memory, 211 | xl_memory = None, 212 | add_knn_memory = True, 213 |
rel_pos_bias = None 214 | ): 215 | b, n, h, device = *x.shape[:2], self.heads, x.device # b and n are not used below 216 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 217 | 218 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 219 | 220 | # in paper, they showed normalizing of keys led to more stable training 221 | # we'll just go with full cosine sim attention https://arxiv.org/abs/2010.04245 222 | 223 | q, k = map(l2norm, (q, k)) 224 | 225 | # handle xl memory 226 | 227 | if exists(xl_memory): 228 | k_xl_mem, v_xl_mem = xl_memory.unbind(dim = -2) 229 | k = torch.cat((k_xl_mem, k), dim = -2) 230 | v = torch.cat((v_xl_mem, v), dim = -2) 231 | 232 | # calculate local attention 233 | 234 | scale = self.scale.exp() 235 | 236 | sim = einsum('b h i d, b j d -> b h i j', q, k) * scale 237 | i, j = sim.shape[-2:] 238 | 239 | if exists(rel_pos_bias): 240 | sim = rel_pos_bias[..., -i:, -j:] + sim 241 | 242 | mask_value = -torch.finfo(sim.dtype).max 243 | 244 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1) 245 | sim = sim.masked_fill(causal_mask, mask_value) 246 | 247 | # calculate knn attention over the memories retrieved from knn_memory (it is always searched) 248 | 249 | mem_kv, mem_mask = knn_memory.search(q, self.num_retrieved_memories) # searched before this call adds its own key / values below, so only memories added by earlier calls can be retrieved 250 | mem_k, mem_v = mem_kv.unbind(dim = -2) 251 | 252 | sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale 253 | sim_mem = sim_mem.masked_fill(~mem_mask, mask_value) # retrieval slots without a stored memory get no attention 254 | 255 | # calculate new XL memories, as well as memories to be discarded 256 | 257 | new_kv_memories = torch.stack((k, v), dim = -2).detach() 258 | 259 | if self.xl_max_memories > 0: 260 | new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories[:, :-self.xl_max_memories], new_kv_memories[:, -self.xl_max_memories:] # everything older than the newest xl_max_memories leaves the XL window 261 | else: 262 | new_kv_memories_discarded, new_xl_kv_memories = new_kv_memories, None 263 | 264 | # add memories to be discarded into KNN memory 265 | 266 | if add_knn_memory and new_kv_memories_discarded.numel() > 0: 267 |
knn_memory.add(new_kv_memories_discarded) 268 | 269 | # attention (combining local and distant) 270 | 271 | sim = torch.cat((sim_mem, sim), dim = -1) 272 | attn = sim.softmax(dim = -1) 273 | attn = self.dropout(attn) 274 | 275 | local_attn, mem_attn = attn[..., self.num_retrieved_memories:], attn[..., :self.num_retrieved_memories] # memory logits were concatenated first, so they occupy the first num_retrieved_memories columns 276 | local_out = einsum('b h i j, b j d -> b h i d', local_attn, v) 277 | mem_out = einsum('b h i j, b h i j d -> b h i d', mem_attn, mem_v) 278 | 279 | out = local_out + mem_out 280 | 281 | # combine heads and project out 282 | 283 | out = rearrange(out, 'b h n d -> b n (h d)') 284 | out = self.to_out(out) 285 | 286 | # use flamingo styled gating of output, so that memorizing transformers can be gated into an existing LLM 287 | # preparation to add this to block-recurrent-transformer-pytorch, for the pinnacle of long context attention network 288 | 289 | if exists(self.output_gate): 290 | out = out * self.output_gate.tanh() 291 | 292 | return out, new_xl_kv_memories 293 | 294 | # main class 295 | 296 | class MemorizingTransformer(nn.Module): 297 | def __init__( 298 | self, 299 | *, 300 | num_tokens, 301 | dim, 302 | depth, 303 | dim_head = 64, 304 | heads = 8, 305 | knn_attn_heads = None, 306 | attn_dropout = 0., 307 | ff_mult = 4, 308 | ff_dropout = 0., 309 | memorizing_layers = None, 310 | max_knn_memories = 250000, 311 | num_retrieved_memories = 32, 312 | clear_memories_on_sos_token_id = None, 313 | clear_memories_on_eos_token_id = None, 314 | knn_memories_directory = DEFAULT_KNN_MEMORY_MEMMAP_DIRECTORY, 315 | shift_knn_memories_down = 0., 316 | pad_id = 0, 317 | xl_max_memories = 0, 318 | xl_memory_layers = None, 319 | shift_xl_memories_down = 0., 320 | knn_memory_multiprocessing = False 321 | ): 322 | super().__init__() 323 | self.token_emb = nn.Embedding(num_tokens, dim) 324 | self.pad_id = pad_id 325 | 326 | block_wrapper = partial(PreNormResidual, dim) 327 | valid_layers = set(range(1, depth + 1)) 328 | 329 | memorizing_layers =
default(memorizing_layers, (depth // 2,)) # default KNN attention layer to midpoint of transformer 330 | memorizing_layers = cast_tuple(memorizing_layers) 331 | memorizing_layers = tuple(filter(lambda i: i in valid_layers, memorizing_layers)) 332 | 333 | self.dim_head = dim_head 334 | 335 | knn_attn_heads = default(knn_attn_heads, heads) 336 | 337 | # xl memory hyperparameter 338 | 339 | if xl_max_memories > 0: 340 | xl_memory_layers = default(xl_memory_layers, tuple(range(1, depth + 1))) 341 | xl_memory_layers = unique(xl_memory_layers) 342 | self.xl_memory_layers = tuple(filter(lambda i: i in valid_layers, xl_memory_layers)) 343 | self.num_xl_memory_layers = len(self.xl_memory_layers) 344 | else: 345 | self.xl_memory_layers = tuple() 346 | self.num_xl_memory_layers = 0 347 | 348 | # knn memory hyperparameters 349 | 350 | self.max_knn_memories = max_knn_memories 351 | self.knn_memories_directory = knn_memories_directory 352 | self.memorizing_layers = unique(memorizing_layers) 353 | self.num_memory_layers = len(memorizing_layers) 354 | 355 | self.clear_memories_on_sos_token_id = clear_memories_on_sos_token_id 356 | self.clear_memories_on_eos_token_id = clear_memories_on_eos_token_id 357 | 358 | # relative positional bias 359 | 360 | self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads) 361 | self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads) 362 | 363 | # layers 364 | 365 | self.layers = nn.ModuleList([]) 366 | for idx in range(depth): 367 | layer_num = idx + 1 368 | 369 | use_xl_memories = layer_num in self.xl_memory_layers 370 | use_knn_attention = layer_num in memorizing_layers 371 | xl_max_memories_layer = 0 if not use_xl_memories else xl_max_memories 372 | 373 | if use_knn_attention: 374 | attn = KNNAttention(dim = dim, dim_head = dim_head, heads = knn_attn_heads, dropout = attn_dropout, num_retrieved_memories = num_retrieved_memories, xl_max_memories = xl_max_memories_layer) 375 | else: 376 | 
attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, xl_max_memories = xl_max_memories_layer) 377 | 378 | self.layers.append(nn.ModuleList([ 379 | block_wrapper(attn), 380 | block_wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)), 381 | ])) 382 | 383 | # memory layer shifting 384 | # from a little known paper https://arxiv.org/abs/2012.15688 385 | 386 | self.shift_knn_memories_down = shift_knn_memories_down 387 | self.shift_xl_memories_down = shift_xl_memories_down 388 | 389 | # to logits 390 | 391 | self.to_logits = nn.Sequential( 392 | nn.LayerNorm(dim), 393 | nn.Linear(dim, num_tokens) 394 | ) 395 | 396 | # knn memories init 397 | 398 | self.knn_mem_kwargs = dict( 399 | dim = self.dim_head, 400 | max_memories = self.max_knn_memories, 401 | multiprocessing = knn_memory_multiprocessing 402 | ) 403 | 404 | def create_knn_memories( 405 | self, 406 | *, 407 | batch_size 408 | ): 409 | return KNNMemoryList.create_memories( 410 | batch_size = batch_size, 411 | num_memory_layers = self.num_memory_layers, 412 | memories_directory = self.knn_memories_directory, 413 | )(**self.knn_mem_kwargs) 414 | 415 | @contextmanager 416 | def knn_memories_context( 417 | self, 418 | **kwargs 419 | ): 420 | knn_dir = Path(self.knn_memories_directory) 421 | knn_dir.mkdir(exist_ok = True, parents = True) 422 | lock = FileLock(str(knn_dir / 'mutex')) 423 | 424 | with lock: 425 | knn_memories = self.create_knn_memories(**kwargs) 426 | yield knn_memories 427 | knn_memories.cleanup() 428 | 429 | def clear_memory(self, x, token_id): 430 | """ clears the KNN memories based on if the batch row contains the specified token id """ 431 | """ for auto-clearing KNN memories based on start and end of strings """ 432 | 433 | clear_memory = (x == token_id).any(dim = -1) 434 | batch_indices, _ = clear_memory.nonzero(as_tuple = True) 435 | batch_indices_to_clear = batch_indices.tolist() 436 | 437 | if len(batch_indices_to_clear) == 0: 438 | return 439 
| 440 | knn_memories.clear_memory(batch_indices_to_clear) 441 | 442 | def forward( 443 | self, 444 | x, 445 | knn_memories, 446 | xl_memories = None, 447 | labels = None, 448 | add_knn_memory = True 449 | ): 450 | batch_size, seq_len, *_, device = *x.shape, x.device 451 | x = self.token_emb(x) 452 | 453 | # validate KNN memories to have enough indices for batch size 454 | 455 | assert all([memory.num_indices == batch_size for memory in knn_memories]), f'you passed in an input with batch size {batch_size} but your memories were not instantiated with that number of KNN indices' 456 | 457 | # if KNN memories are passed in, and researcher wants memories auto-cleared on token detection 458 | # do the appropriate logic 459 | 460 | if exists(self.clear_memories_on_sos_token_id): 461 | self.clear_memory(x, self.clear_memories_on_sos_token_id) 462 | 463 | # handle XL memories 464 | 465 | xl_memories = default(xl_memories, (None,) * self.num_xl_memory_layers) 466 | assert len(xl_memories) == self.num_xl_memory_layers 467 | has_xl_memories = len(xl_memories) > 0 468 | 469 | # shifting memories a number of layers down, little known technique shown to enhance memories from Ernie-Doc paper 470 | 471 | if len(knn_memories) > 0 and self.shift_knn_memories_down > 0: 472 | knn_memories = [*knn_memories[self.shift_knn_memories_down:], *knn_memories[:self.shift_knn_memories_down]] 473 | 474 | if len(xl_memories) > 0 and self.shift_xl_memories_down > 0: 475 | xl_memories = [*xl_memories[self.shift_xl_memories_down:], *xl_memories[:self.shift_xl_memories_down]] 476 | 477 | # iterate through the memories in order of the ascending layers that contain KNNAttention 478 | 479 | xl_memories_iter = iter(xl_memories) 480 | knn_memories_iter = iter(knn_memories) 481 | 482 | # positional bias 483 | 484 | max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)]) 485 | 486 | rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device) 
487 | knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device) 488 | 489 | # keep track of new xl memories 490 | 491 | new_xl_memories = [] if has_xl_memories else None 492 | 493 | # go through all layers 494 | 495 | for ind, (attn, ff) in enumerate(self.layers): 496 | layer_num = ind + 1 497 | 498 | is_memorizing_layer = layer_num in self.memorizing_layers 499 | is_xl_memory_layer = layer_num in self.xl_memory_layers 500 | 501 | attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias) 502 | 503 | if is_memorizing_layer: 504 | attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory} 505 | 506 | if is_xl_memory_layer: 507 | attn_kwargs = {**attn_kwargs, 'xl_memory': next(xl_memories_iter)} 508 | 509 | # attention 510 | 511 | x, xl_mem = attn(x, **attn_kwargs) 512 | 513 | # add new XL memories if needed 514 | 515 | if exists(xl_mem): 516 | new_xl_memories.append(xl_mem) 517 | 518 | # feedforward 519 | 520 | x = ff(x) 521 | 522 | # to logits 523 | 524 | logits = self.to_logits(x) 525 | 526 | # auto-clear KNN memories on end of string token 527 | 528 | if exists(self.clear_memories_on_eos_token_id): 529 | self.clear_memory(x, self.clear_memories_on_eos_token_id) 530 | 531 | # for training 532 | 533 | if not exists(labels): 534 | if exists(new_xl_memories): 535 | return logits, new_xl_memories 536 | 537 | return logits 538 | 539 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = self.pad_id) 540 | 541 | if exists(new_xl_memories): 542 | return loss, new_xl_memories 543 | 544 | return loss 545 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'memorizing-transformers-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = 
'0.4.1', 7 | license='MIT', 8 | description = 'Memorizing Transformer - Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/memorizing-transformers-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'memory', 18 | 'retrieval' 19 | ], 20 | install_requires=[ 21 | 'einops>=0.6', 22 | 'filelock', 23 | 'joblib', 24 | 'faiss-gpu', 25 | 'numpy', 26 | 'torch>=1.6', 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.6', 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from memorizing_transformers_pytorch import MemorizingTransformer 2 | 3 | import random 4 | import tqdm 5 | import gzip 6 | import numpy as np 7 | import torch 8 | import torch.optim as optim 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | # constants 13 | 14 | NUM_BATCHES = int(1e5) 15 | BATCH_SIZE = 16 16 | SEQ_LEN = 512 17 | SEGMENTS = 5 18 | 19 | LEARNING_RATE = 2e-4 20 | MAX_GRAD_CLIP_NORM = 0.5 21 | 22 | VALIDATE_EVERY = 100 23 | GENERATE_EVERY = 500 24 | GENERATE_LENGTH = 512 25 | 26 | # helpers 27 | 28 | def cycle(loader): 29 | while True: 30 | for data in loader: 31 | yield data 32 | 33 | def decode_token(token): 34 | return str(chr(max(32, token))) 35 | 36 | def decode_tokens(tokens): 37 | return ''.join(list(map(decode_token, tokens))) 38 | 39 | # instantiate GPT-like decoder model 40 | 41 | model = MemorizingTransformer( 42 | num_tokens = 256, 43 | dim = 512, 44 | depth = 8, 45 | memorizing_layers = 4, 46 | 
max_knn_memories = 512 * 15, 47 | num_retrieved_memories = 32, 48 | xl_memory_layers = (7, 8), 49 | xl_max_memories = 512, 50 | ).cuda() 51 | 52 | # prepare enwik8 data 53 | 54 | with gzip.open('./data/enwik8.gz') as file: 55 | X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) 56 | trX, vaX = np.split(X, [int(90e6)]) 57 | data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX) 58 | 59 | class TextSamplerDataset(Dataset): 60 | def __init__(self, data, seq_len): 61 | super().__init__() 62 | self.data = data 63 | self.seq_len = seq_len 64 | 65 | def __getitem__(self, index): 66 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 67 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 68 | return full_seq.cuda() 69 | 70 | def __len__(self): 71 | return self.data.size(0) // self.seq_len 72 | 73 | # dataset and dataloader 74 | 75 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN * SEGMENTS) 76 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True)) 77 | valid_dataset = TextSamplerDataset(data_val, SEQ_LEN * SEGMENTS) 78 | valid_loader = cycle(DataLoader(valid_dataset, batch_size = BATCH_SIZE, drop_last = True)) 79 | 80 | # optimizer 81 | 82 | optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE) 83 | 84 | # training 85 | 86 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'): 87 | model.train() 88 | 89 | data = next(train_loader) 90 | 91 | train_loss = 0. 
92 | with model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories: 93 | xl_memories = None 94 | seq, labels = data[:, :-1], data[:, 1:] 95 | 96 | for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)): 97 | loss, xl_memories = model( 98 | seq_segment, 99 | labels = labels_segment, 100 | knn_memories = knn_memories, 101 | xl_memories = xl_memories 102 | ) 103 | 104 | train_loss += loss.item() / SEGMENTS 105 | (loss / SEGMENTS).backward() 106 | 107 | print(f'training loss: {train_loss}') 108 | torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM) 109 | optim.step() 110 | optim.zero_grad() 111 | 112 | if not (i % VALIDATE_EVERY): 113 | model.eval() 114 | 115 | valid_data = next(valid_loader) 116 | valid_loss = 0. 117 | 118 | with torch.no_grad(), model.knn_memories_context(batch_size = BATCH_SIZE) as knn_memories: 119 | xl_memories = None 120 | seq, labels = data[:, :-1], data[:, 1:] 121 | 122 | for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1), labels.chunk(SEGMENTS, dim = -1)): 123 | loss, xl_memories = model( 124 | seq_segment, 125 | labels = labels_segment, 126 | knn_memories = knn_memories, 127 | xl_memories = xl_memories 128 | ) 129 | 130 | valid_loss += loss.item() / SEGMENTS 131 | 132 | print(f'valid loss: {valid_loss}') 133 | --------------------------------------------------------------------------------