├── RETRO.png ├── retro_pytorch ├── __init__.py ├── utils.py ├── optimizer.py ├── data.py ├── training.py ├── retrieval.py └── retro_pytorch.py ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE └── README.md /RETRO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/RETRO-pytorch/HEAD/RETRO.png -------------------------------------------------------------------------------- /retro_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from retro_pytorch.retro_pytorch import RETRO 2 | from retro_pytorch.data import RETRODataset 3 | from retro_pytorch.training import TrainingWrapper 4 | -------------------------------------------------------------------------------- /retro_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from contextlib import contextmanager 7 | 8 | def is_true_env_flag(env_flag): 9 | return os.getenv(env_flag, 'false').lower() in ('true', '1', 't') 10 | 11 | def reset_folder_(p): 12 | path = Path(p) 13 | rmtree(path, ignore_errors = True) 14 | path.mkdir(exist_ok = True, parents = True) 15 | 16 | @contextmanager 17 | def memmap(*args, **kwargs): 18 | pointer = np.memmap(*args, **kwargs) 19 | yield pointer 20 | del pointer 21 | -------------------------------------------------------------------------------- /retro_pytorch/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import AdamW 2 | 3 | def separate_weight_decayable_params(params): 4 | no_wd_params = set([param for param in params if param.ndim < 2]) 5 | wd_params = set(params) - no_wd_params 6 | return wd_params, no_wd_params 7 | 8 | def get_optimizer(params, lr = 3e-4, wd = 1e-1, filter_by_requires_grad = False): 9 | if filter_by_requires_grad: 10 | params = list(filter(lambda t: t.requires_grad, params)) 11 | 12 | params = set(params) 13 | wd_params, no_wd_params = separate_weight_decayable_params(params) 14 | 15 | param_groups = [ 16 | {'params': list(wd_params)}, 17 | {'params': list(no_wd_params), 'weight_decay': 0}, 18 | ] 19 | 20 | return AdamW(param_groups, lr = lr, weight_decay = wd) 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'retro-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.3.9', 7 | license='MIT', 8 | description = 'RETRO - Retrieval Enhanced Transformer - Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/RETRO-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention-mechanism', 18 | 'retrieval', 19 | ], 20 | install_requires=[ 21 | 'autofaiss', 22 | 'einops>=0.3', 23 | 'numpy', 24 | 'sentencepiece', 25 | 'torch>=1.6', 26 | 'tqdm' 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.6', 34 | ], 35 | ) 36 
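As a brief aside on `retro_pytorch/optimizer.py` above: a minimal sketch, using a hypothetical toy module that is not part of this repository, of how `get_optimizer` puts 2-D weights into a weight-decayed group and 1-D parameters (biases, normalization scales) into a group with weight decay disabled.

```python
from torch import nn
from retro_pytorch.optimizer import get_optimizer

# hypothetical toy module: a Linear (2-D weight, 1-D bias) plus a LayerNorm (1-D weight and bias)
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))

optim = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-1)

# AdamW fills in the default weight decay for the first group; the second group overrides it to 0
for group in optim.param_groups:
    total = sum(p.numel() for p in group['params'])
    print(f"weight_decay = {group['weight_decay']}, params = {total}")
```

This mirrors the common practice of excluding biases and normalization parameters from weight decay.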
| -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /retro_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from retro_pytorch.retrieval import EOS_ID 7 | from retro_pytorch.utils import memmap 8 | 9 | # knn to retrieved chunks 10 | 11 | def knn_to_retrieved_chunks( 12 | knns, 13 | chunks_memmap, 14 | *, 15 | add_continuations, 16 | num_chunks, 17 | pad_id = 0, 18 | eos_id = EOS_ID, 19 | ): 20 | 21 | # derive mask for no neighbors found (-1) 22 | 23 | no_neighbor_mask = knns == -1 24 | knns = np.maximum(knns, 0) 25 | 26 | # get neighbor and continuation chunks 27 | 28 | knn_chunks = chunks_memmap[knns] 29 | is_last_document_chunk = np.any(knn_chunks == eos_id, axis = -1, keepdims = True) 30 | 31 | # use presence of [EOS] in chunk as way to detect document boundaries 32 | # [EOS] in BERT tokenizer is 102 33 | 34 | retrieved = knn_chunks[..., :-1] 35 | 36 | if add_continuations: 37 | continuation_indices = np.clip(knns + 1, 0, num_chunks - 1) # chunks are stored contiguously 38 | continuation_chunks = chunks_memmap[continuation_indices][..., :-1] 39 | continuation_chunks *= ~is_last_document_chunk 40 | 41 | # combine neighbors with continuations 42 | 43 | retrieved = np.concatenate((retrieved, continuation_chunks), axis = -1) 44 | 45 | # mask out any nearest neighbor chunks that was -1 (not found at index time) to padding id 46 | 47 | retrieved = np.where(~no_neighbor_mask[..., None], retrieved, pad_id) 48 | return retrieved 49 | 50 | # dataset 51 | 52 | class RETRODataset(Dataset): 53 | def __init__( 54 | self, 55 | *, 56 | num_chunks, 57 | chunk_size, 58 | seq_len, 59 | num_sequences, 60 | num_neighbors, 61 | chunk_memmap_path, 62 | chunk_nn_memmap_path, 63 | seq_memmap_path, 64 | eos_id = EOS_ID, 65 | pad_id = 0., 66 | add_continuations = True 67 | ): 68 | super().__init__() 69 | self.num_chunks = num_chunks 70 | self.num_sequences = num_sequences 71 | self.seq_num_chunks = seq_len // chunk_size 72 | self.eos_id = eos_id 73 | self.pad_id = pad_id 74 | 75 | num_chunks_with_padding = num_chunks + self.seq_num_chunks 76 | 77 | chunks_shape = (num_chunks_with_padding, chunk_size + 1) 78 | knn_shape = (num_chunks_with_padding, num_neighbors) 79 | 80 | self.add_continuations = add_continuations 81 | self.get_chunks = partial(memmap, chunk_memmap_path, dtype = np.int32, shape = chunks_shape) 82 | self.get_knns = partial(memmap, chunk_nn_memmap_path, dtype = np.int32, shape = knn_shape) 83 | self.get_seqs = partial(memmap, seq_memmap_path, dtype = np.int32, shape = (num_sequences,)) 84 | 85 | def __len__(self): 86 | return self.num_sequences 87 | 88 | def __getitem__(self, ind): 89 | with self.get_chunks() as chunks_memmap, self.get_knns() as knns_memmap, self.get_seqs() as seqs_memmap: 90 | begin_chunk_index = 
seqs_memmap[ind] 91 | chunk_range = slice(begin_chunk_index, (begin_chunk_index + self.seq_num_chunks)) 92 | 93 | chunks = chunks_memmap[chunk_range] 94 | 95 | # excise the last token, except for last token of last chunk 96 | 97 | seq_tokens = np.concatenate((chunks[:, :-1].flatten(), chunks[-1, -1:])) 98 | 99 | # mask out (with padding tokens) any token following an | disallow having more than 1 document in a sequence, as it would break RETRO's CCA 100 | 101 | seq_mask = np.cumsum(seq_tokens == self.eos_id, axis = 0) 102 | seq_mask = np.pad(seq_mask, (1, 0))[:-1] == 0. 103 | seq_tokens = np.where(seq_mask, seq_tokens, 0.) 104 | 105 | # derive retrieved tokens 106 | 107 | knns = knns_memmap[chunk_range] 108 | 109 | retrieved = knn_to_retrieved_chunks( 110 | knns, 111 | chunks_memmap, 112 | add_continuations = self.add_continuations, 113 | eos_id = self.eos_id, 114 | num_chunks = self.num_chunks 115 | ) 116 | 117 | seq_tokens_torch = torch.from_numpy(seq_tokens).long() 118 | retrieved_torch = torch.from_numpy(retrieved).long() 119 | return seq_tokens_torch, retrieved_torch 120 | -------------------------------------------------------------------------------- /retro_pytorch/training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | import json 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | 11 | from retro_pytorch import RETRO, RETRODataset 12 | from retro_pytorch.data import knn_to_retrieved_chunks 13 | from retro_pytorch.optimizer import get_optimizer 14 | from retro_pytorch.retrieval import text_folder_to_chunks_, chunks_to_precalculated_knn_, bert_embed, SOS_ID, EOS_ID 15 | from retro_pytorch.utils import memmap, is_true_env_flag 16 | 17 | from einops import rearrange 18 | 19 | # helpers 20 | 21 | def exists(val): 22 | return val is not None 23 | 24 | def eval_decorator(fn): 25 | def inner(model, *args, **kwargs): 26 | was_training = model.training 27 | model.eval() 28 | out = fn(model, *args, **kwargs) 29 | model.train(was_training) 30 | return out 31 | return inner 32 | 33 | def safe_cat(accum, t, dim = -1): 34 | if not exists(accum): 35 | return t 36 | return torch.cat((accum, t), dim = dim) 37 | 38 | # sampling helpers 39 | 40 | def log(t, eps = 1e-20): 41 | return torch.log(t.clamp(min = eps)) 42 | 43 | def gumbel_noise(t): 44 | noise = torch.zeros_like(t).uniform_(0, 1) 45 | return -log(-log(noise)) 46 | 47 | def gumbel_sample(t, temperature = 1., dim = -1): 48 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 49 | 50 | def top_k(logits, thres = 0.9): 51 | num_logits = logits.shape[-1] 52 | k = max(int((1 - thres) * num_logits), 1) 53 | val, ind = torch.topk(logits, k) 54 | probs = torch.full_like(logits, float('-inf')) 55 | probs.scatter_(1, ind, val) 56 | return probs 57 | 58 | def top_p(logits, thres = 0.9): 59 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 60 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 61 | 62 | sorted_indices_to_remove = cum_probs > (1 - thres) 63 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 64 | sorted_indices_to_remove[:, 0] = 0 65 | 66 | sorted_logits[sorted_indices_to_remove] = float('-inf') 67 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 68 | 69 | # function that returns knn chunks from seq chunks 70 | # 71 | # 1. 
adds sos and eos to seq chunks 72 | # 2. embeds the seq chunks with special tokens with frozen BERT 73 | # 3. fetches the knn indices with faiss 74 | # 4. gets the knn chunks as well as the continuation from a reference to the chunks data (memmap) 75 | # 76 | 77 | def knn_chunks_from_seq_chunks( 78 | seq_chunks, 79 | *, 80 | knn, 81 | faiss_index, 82 | num_chunks, 83 | chunk_size, 84 | chunks_memmap_path, 85 | ): 86 | b, device = seq_chunks.shape[0], seq_chunks.device 87 | 88 | # prepare last chunk with sos and eos tokens for BERT embed 89 | 90 | ones = torch.ones((b, 1), dtype = torch.bool, device = device) 91 | sos = ones * SOS_ID 92 | eos = ones * EOS_ID 93 | 94 | seq_chunks = torch.cat((sos, seq_chunks, eos), dim = 1) 95 | 96 | # embed with frozen BERT 97 | 98 | embeds = bert_embed(seq_chunks.cpu()) # fetch embeds on CPU for now 99 | 100 | # retrieval of knn with faiss 101 | 102 | _, knn_indices = faiss_index.search(embeds.cpu().numpy(), k = knn) 103 | 104 | # numpy to torch 105 | 106 | with memmap(chunks_memmap_path, dtype = np.int32, shape = (num_chunks + 1, chunk_size + 1)) as chunk_memmap: 107 | knn_chunks = knn_to_retrieved_chunks( 108 | knn_indices, 109 | chunk_memmap, 110 | add_continuations = True, 111 | num_chunks = num_chunks 112 | ) 113 | 114 | knn_chunks_torch = torch.from_numpy(knn_chunks).to(device) 115 | 116 | return knn_chunks_torch 117 | 118 | # training wrapper class 119 | 120 | class TrainingWrapper(nn.Module): 121 | def __init__( 122 | self, 123 | *, 124 | retro, 125 | chunk_size, 126 | documents_path, 127 | knn, 128 | glob = '**/*.txt', 129 | chunks_memmap_path = './train.chunks.dat', 130 | seqs_memmap_path = './train.seq.dat', 131 | doc_ids_memmap_path = './train.doc_ids.dat', 132 | max_chunks = 1_000_000, 133 | max_seqs = 100_000, 134 | knn_extra_neighbors = 100, 135 | processed_stats_json_path = './processed-stats.json', 136 | faiss_index_filename = 'knn.index', 137 | **index_kwargs 138 | ): 139 | super().__init__() 140 | assert isinstance(retro, RETRO), 'retro must be instance of RETRO' 141 | self.retro = retro 142 | 143 | force_reprocess = is_true_env_flag('REPROCESS') 144 | 145 | # store the processed training data statistics 146 | # number of chunks, number of sequences 147 | 148 | stats_path = Path(processed_stats_json_path) 149 | 150 | # if the statistics file does not exist, process folders of text 151 | # force reprocess by setting REPROCESS=1 when running training script 152 | 153 | if not stats_path.exists() or force_reprocess: 154 | self.stats = text_folder_to_chunks_( 155 | folder = documents_path, 156 | glob = glob, 157 | chunks_memmap_path = chunks_memmap_path, 158 | seqs_memmap_path = seqs_memmap_path, 159 | doc_ids_memmap_path = doc_ids_memmap_path, 160 | chunk_size = chunk_size, 161 | seq_len = retro.seq_len, 162 | max_chunks = max_chunks, 163 | max_seqs = max_seqs 164 | ) 165 | with open(processed_stats_json_path, 'w') as f: 166 | json.dump(self.stats, f) 167 | else: 168 | print(f'found to be previously processed at {str(stats_path)}') 169 | self.stats = json.loads(stats_path.read_text()) 170 | 171 | # get number of chunks and number of sequences 172 | 173 | num_chunks = self.stats['chunks'] 174 | num_seqs = self.stats['seqs'] 175 | 176 | # calculate knn memmap path and get the faiss index 177 | # todo - make sure if faiss_index_filename is found, do not reprocess unless flag is given 178 | 179 | knn_memmap_path, faiss_index = chunks_to_precalculated_knn_( 180 | num_chunks = num_chunks, 181 | chunk_size = chunk_size, 182 | chunk_memmap_path = 
chunks_memmap_path, 183 | doc_ids_memmap_path = doc_ids_memmap_path, 184 | num_nearest_neighbors = knn, 185 | num_extra_neighbors = knn_extra_neighbors, 186 | index_file = faiss_index_filename, 187 | force_reprocess = force_reprocess, 188 | **index_kwargs 189 | ) 190 | 191 | # retro dataset 192 | 193 | self.ds = RETRODataset( 194 | num_sequences = num_seqs, 195 | num_chunks = num_chunks, 196 | num_neighbors = knn, 197 | chunk_size = chunk_size, 198 | seq_len = retro.seq_len, 199 | chunk_memmap_path = chunks_memmap_path, 200 | chunk_nn_memmap_path = knn_memmap_path, 201 | seq_memmap_path = seqs_memmap_path 202 | ) 203 | 204 | # params needed for generation 205 | 206 | self.chunk_size = chunk_size 207 | self.max_seq_len = self.retro.seq_len 208 | 209 | self.fetch_knn_chunks_fn = partial( 210 | knn_chunks_from_seq_chunks, 211 | knn = knn, 212 | chunk_size = chunk_size, 213 | num_chunks = num_chunks, 214 | chunks_memmap_path = chunks_memmap_path, 215 | faiss_index = faiss_index 216 | ) 217 | 218 | @torch.no_grad() 219 | @eval_decorator 220 | def generate( 221 | self, 222 | start = None, 223 | retrieved = None, 224 | filter_fn = top_k, 225 | filter_thres = 0.9, 226 | temperature = 1.0, 227 | ): 228 | assert filter_fn in {top_k, top_p}, 'filter function must be either top-k or nucleus' 229 | 230 | device = next(self.retro.parameters()).device 231 | 232 | # if not prime tokens given, assume sampling from SOS token with batch size of 1 233 | 234 | if not exists(start): 235 | start = torch.full((1, 1), SOS_ID, device = device).long() 236 | 237 | b, start_seq_len = start.shape 238 | 239 | # move onto same device as RETRO 240 | 241 | start = start.to(device) 242 | 243 | # prepare retrieval related variables 244 | 245 | if start_seq_len >= self.chunk_size: 246 | seq_index = (start_seq_len // self.chunk_size) * self.chunk_size 247 | past_seq_chunks = rearrange(start[:, :seq_index], 'b (n c) -> (b n) c', c = self.chunk_size) 248 | 249 | retrieved = self.fetch_knn_chunks_fn(past_seq_chunks) 250 | retrieved = rearrange(retrieved, '(b n) k c -> b n k c', b = b) 251 | 252 | # get starting sequence index 253 | 254 | out = start 255 | 256 | # sampling loop 257 | 258 | for i in range(start_seq_len - 1, self.max_seq_len): 259 | 260 | logits = self.retro(out, retrieved = retrieved) 261 | logits = logits[:, i] 262 | 263 | logits = filter_fn(logits, thres = filter_thres) 264 | sampled = gumbel_sample(logits, temperature = temperature, dim = -1) 265 | sampled = rearrange(sampled, 'b -> b 1') 266 | 267 | out = torch.cat((out, sampled), dim = 1) 268 | 269 | # early terminate if all EOS 270 | 271 | is_eos_tokens = (out == EOS_ID) 272 | 273 | if is_eos_tokens.any(dim = -1).all(): 274 | 275 | # mask out everything after the eos tokens 276 | 277 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 278 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 279 | out = out.masked_fill(mask, self.retro.pad_id) 280 | break 281 | 282 | # when the sequence length is a multiple of the chunk size 283 | # retrieve the next set of knns 284 | 285 | curr_seq_len = out.shape[-1] 286 | 287 | if (curr_seq_len % self.chunk_size) == 0: 288 | last_chunk = rearrange(out, 'b (c n) -> b c n', n = self.chunk_size)[:, -1] 289 | 290 | knn_chunks = self.fetch_knn_chunks_fn(last_chunk) 291 | 292 | # concat retrieved knn chunks to all retrieved 293 | # to be sent to Retro for chunked cross attention at the next iteration 294 | 295 | knn_chunks = rearrange(knn_chunks, 'b k r -> b 1 k r') 296 | retrieved = safe_cat(retrieved, knn_chunks, 
dim = 1) 297 | 298 | print(f'retrieved at {curr_seq_len} / {self.max_seq_len}') 299 | 300 | return out 301 | 302 | def get_dataloader(self, **kwargs): 303 | return DataLoader(self.ds, **kwargs) 304 | 305 | def get_optimizer(self, **kwargs): 306 | return get_optimizer(self.retro.parameters(), **kwargs) 307 | 308 | def forward(self): 309 | raise NotImplemented 310 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /retro_pytorch/retrieval.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from math import ceil 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import logging 7 | import numpy as np 8 | from einops import rearrange 9 | 10 | import faiss 11 | from autofaiss import build_index 12 | 13 | from retro_pytorch.utils import memmap, reset_folder_ 14 | 15 | # constants 16 | 17 | SOS_ID = 101 18 | EOS_ID = 102 19 | BERT_MODEL_DIM = 768 20 | BERT_VOCAB_SIZE = 28996 21 | 22 | TMP_PATH = Path('./.tmp') 23 | INDEX_FOLDER_PATH = TMP_PATH / '.index' 24 | EMBEDDING_TMP_SUBFOLDER = 'embeddings' 25 | 26 | # helper functions 27 | 28 | def exists(val): 29 | return val is not None 30 | 31 | def range_chunked(max_value, *, batch_size): 32 | counter = 0 33 | while counter < max_value: 34 | curr = counter + batch_size 35 | curr = min(curr, max_value) 36 | yield slice(counter, curr) 37 | counter = curr 38 | 39 | # indexing helper functions 40 | 41 | def faiss_read_index(path): 42 | return faiss.read_index(str(path), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY) 43 | 44 | # singleton globals 45 | 46 | MODEL = None 47 | TOKENIZER = None 48 | 49 | def get_tokenizer(): 50 | global TOKENIZER 51 | if not exists(TOKENIZER): 52 | TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') 53 | return TOKENIZER 54 | 55 | def get_bert(): 56 | global MODEL 57 | if not exists(MODEL): 58 | MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased') 59 | if torch.cuda.is_available(): 60 | MODEL = MODEL.cuda() 61 | 62 | return MODEL 63 | 64 | # tokenize 65 | 66 | def tokenize(texts, add_special_tokens = True): 67 | if not isinstance(texts, (list, tuple)): 68 | texts = [texts] 69 | 70 | tokenizer = get_tokenizer() 71 | 72 | encoding = tokenizer.batch_encode_plus( 73 | texts, 74 | add_special_tokens = add_special_tokens, 75 | padding = True, 76 | return_tensors = 'pt' 77 | ) 78 | 79 | token_ids = encoding.input_ids 80 | return token_ids 81 | 82 | # text to chunks 83 | 84 | def doc_text_to_chunks_and_seq_indices( 85 | *, 86 | doc_text, 87 | chunk_size = 64, 88 | seq_len = 2048, 89 | pad_id = 0 90 | ): 91 | assert (seq_len % chunk_size) == 0, 'sequence length must be divisible by chunk size' 92 | 93 | ids = tokenize(doc_text) 94 | ids = rearrange(ids, '1 ... 
-> ...') 95 | 96 | text_len = ids.shape[-1] 97 | 98 | # pad to multiple of chunk size with an extra token 99 | 100 | padding = chunk_size - ((text_len - 1) % chunk_size) 101 | ids = F.pad(ids, (0, padding)) 102 | 103 | # split out very last token 104 | 105 | ids, last_token = ids[:-1], ids[-1:] 106 | ids = rearrange(ids, '(n c) -> n c', c = chunk_size) 107 | 108 | # first tokens of chunk [2:] and on will become the last token of chunk [1:] 109 | 110 | last_token_per_chunk = ids[1:, 0] 111 | all_last_tokens = torch.cat((last_token_per_chunk, last_token), dim = 0) 112 | all_last_tokens = rearrange(all_last_tokens, 'n -> n 1') 113 | 114 | # append all last tokens to ids for (num_chunks, chunk_size + 1) 115 | 116 | chunks_with_extra_token = torch.cat((ids, all_last_tokens), dim = -1) 117 | 118 | # calculate chunk indices starting at 0, spaced number of chunks of seq len apart 119 | 120 | total_chunks = ids.shape[0] 121 | num_chunks_per_seq = seq_len // chunk_size 122 | seq = torch.arange(0, total_chunks, num_chunks_per_seq) 123 | 124 | return chunks_with_extra_token, seq 125 | 126 | def text_folder_to_chunks_( 127 | *, 128 | folder, 129 | chunks_memmap_path, 130 | seqs_memmap_path, 131 | doc_ids_memmap_path, 132 | chunk_size = 64, 133 | seq_len = 2048, 134 | glob = '**/*.txt', 135 | max_chunks = 1_000_000, 136 | max_seqs = 100_000 137 | ): 138 | paths = sorted([*Path(folder).glob(glob)]) 139 | 140 | total_chunks = 0 141 | total_docs = 0 142 | total_seqs = 0 143 | 144 | chunks_shape = (max_chunks, chunk_size + 1) 145 | seqs_shape = (max_seqs,) 146 | doc_ids_shape = (max_chunks,) 147 | 148 | with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32, mode = 'w+') as chunks_memmap\ 149 | , memmap(seqs_memmap_path, shape = seqs_shape, dtype = np.int32, mode = 'w+') as seqs_memmap\ 150 | , memmap(doc_ids_memmap_path, shape = doc_ids_shape, dtype = np.int32, mode = 'w+') as doc_ids_memmap: 151 | 152 | for path in paths: 153 | print(f'processing {path}') 154 | 155 | chunks, seq = doc_text_to_chunks_and_seq_indices( 156 | doc_text = path.read_text(), 157 | chunk_size = chunk_size, 158 | seq_len = seq_len 159 | ) 160 | 161 | doc_chunk_len = chunks.shape[0] 162 | doc_seq_len = seq.shape[0] 163 | 164 | chunks_memmap[total_chunks:(total_chunks + doc_chunk_len)] = chunks.numpy() 165 | seqs_memmap[total_seqs:(total_seqs + doc_seq_len)] = seq.numpy() + total_chunks 166 | doc_ids_memmap[total_chunks:(total_chunks + doc_chunk_len)] = np.full((doc_chunk_len,), total_docs) 167 | 168 | total_chunks += doc_chunk_len 169 | total_seqs += doc_seq_len 170 | total_docs += 1 171 | 172 | return dict( 173 | chunks = total_chunks, 174 | docs = total_docs, 175 | seqs = total_seqs 176 | ) 177 | 178 | # embedding function 179 | 180 | @torch.no_grad() 181 | def bert_embed( 182 | token_ids, 183 | return_cls_repr = False, 184 | eps = 1e-8, 185 | pad_id = 0. 
186 | ): 187 | model = get_bert() 188 | mask = token_ids != pad_id 189 | 190 | if torch.cuda.is_available(): 191 | token_ids = token_ids.cuda() 192 | mask = mask.cuda() 193 | 194 | outputs = model( 195 | input_ids = token_ids, 196 | attention_mask = mask, 197 | output_hidden_states = True 198 | ) 199 | 200 | hidden_state = outputs.hidden_states[-1] 201 | 202 | if return_cls_repr: 203 | return hidden_state[:, 0] # return [cls] as representation 204 | 205 | if not exists(mask): 206 | return hidden_state.mean(dim = 1) 207 | 208 | mask = mask[:, 1:] # mean all tokens excluding [cls], accounting for length 209 | mask = rearrange(mask, 'b n -> b n 1') 210 | 211 | numer = (hidden_state[:, 1:] * mask).sum(dim = 1) 212 | denom = mask.sum(dim = 1) 213 | masked_mean = numer / (denom + eps) 214 | return masked_mean 215 | 216 | # chunks to knn 217 | 218 | def chunks_to_embeddings_( 219 | *, 220 | num_chunks, 221 | chunks_memmap_path, 222 | embeddings_memmap_path, 223 | chunk_size = 64, 224 | embed_dim = BERT_MODEL_DIM, 225 | batch_size = 16, 226 | use_cls_repr = False, 227 | pad_id = 0. 228 | ): 229 | chunks_shape = (num_chunks, chunk_size + 1) 230 | embed_shape = (num_chunks, embed_dim) 231 | 232 | with memmap(chunks_memmap_path, shape = chunks_shape, dtype = np.int32) as chunks\ 233 | , memmap(embeddings_memmap_path, shape = embed_shape, dtype = np.float32, mode = 'w+') as embeddings: 234 | 235 | for dim_slice in range_chunked(num_chunks, batch_size = batch_size): 236 | batch_chunk_npy = chunks[dim_slice] 237 | 238 | batch_chunk = torch.from_numpy(batch_chunk_npy) 239 | 240 | cls_tokens = torch.full((batch_chunk.shape[0], 1), SOS_ID) 241 | batch_chunk = torch.cat((cls_tokens, batch_chunk), dim = 1) 242 | 243 | batch_chunk = batch_chunk[:, :-1] # omit last token, the first token of the next chunk, used for autoregressive training 244 | 245 | batch_embed = bert_embed( 246 | batch_chunk, 247 | return_cls_repr = use_cls_repr 248 | ) 249 | 250 | embeddings[dim_slice] = batch_embed.detach().cpu().numpy() 251 | print(f'embedded {dim_slice.stop} / {num_chunks}') 252 | 253 | 254 | def memmap_file_to_chunks_( 255 | memmap_path, 256 | *, 257 | folder, 258 | shape, 259 | dtype, 260 | max_rows_per_file = 500 261 | ): 262 | rows, _ = shape 263 | 264 | with memmap(memmap_path, shape = shape, dtype = dtype, mode = 'r') as f: 265 | root_path = TMP_PATH / folder 266 | reset_folder_(root_path) 267 | 268 | for ind, dim_slice in enumerate(range_chunked(rows, batch_size = max_rows_per_file)): 269 | filename = root_path / f'{ind:05d}.npy' 270 | data_slice = f[dim_slice] 271 | 272 | np.save(str(filename), f[dim_slice]) 273 | print(f'saved {str(filename)}') 274 | 275 | def index_embeddings( 276 | embeddings_folder, 277 | *, 278 | index_file = 'knn.index', 279 | index_infos_file = 'index_infos.json', 280 | max_index_memory_usage = '100m', 281 | current_memory_available = '1G' 282 | ): 283 | embeddings_path = TMP_PATH / embeddings_folder 284 | index_path = INDEX_FOLDER_PATH / index_file 285 | 286 | reset_folder_(INDEX_FOLDER_PATH) 287 | 288 | build_index( 289 | embeddings = str(embeddings_path), 290 | index_path = str(index_path), 291 | index_infos_path = str(INDEX_FOLDER_PATH / index_infos_file), 292 | metric_type = "l2", 293 | max_index_memory_usage = max_index_memory_usage, 294 | current_memory_available = current_memory_available, 295 | make_direct_map = True, 296 | should_be_memory_mappable = False, 297 | use_gpu = torch.cuda.is_available(), 298 | ) 299 | 300 | index = faiss_read_index(index_path) 301 | return index 302 | 
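# the three helpers above are composed in order by chunks_to_index_and_embed below:
# chunks_to_embeddings_ writes a float32 embedding memmap, memmap_file_to_chunks_ shards it
# into .npy files under ./.tmp for autofaiss, and index_embeddings builds the faiss index from those shards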
303 | def chunks_to_index_and_embed( 304 | *, 305 | num_chunks, 306 | chunk_size, 307 | chunk_memmap_path, 308 | use_cls_repr = False, 309 | max_rows_per_file = 500, 310 | chunks_to_embeddings_batch_size = 16, 311 | embed_dim = BERT_MODEL_DIM, 312 | index_file = 'knn.index', 313 | **index_kwargs 314 | ): 315 | embedding_path = f'{chunk_memmap_path}.embedded' 316 | embed_shape = (num_chunks, embed_dim) 317 | 318 | chunks_to_embeddings_( 319 | num_chunks = num_chunks, 320 | chunk_size = chunk_size, 321 | chunks_memmap_path = chunk_memmap_path, 322 | embeddings_memmap_path = embedding_path, 323 | use_cls_repr = use_cls_repr, 324 | batch_size = chunks_to_embeddings_batch_size, 325 | embed_dim = embed_dim 326 | ) 327 | 328 | memmap_file_to_chunks_( 329 | embedding_path, 330 | shape = embed_shape, 331 | dtype = np.float32, 332 | folder = EMBEDDING_TMP_SUBFOLDER, 333 | max_rows_per_file = max_rows_per_file 334 | ) 335 | 336 | index = index_embeddings( 337 | embeddings_folder = EMBEDDING_TMP_SUBFOLDER, 338 | index_file = index_file, 339 | **index_kwargs 340 | ) 341 | 342 | embeddings = np.memmap(embedding_path, shape = embed_shape, dtype = np.float32, mode = 'r') 343 | return index, embeddings 344 | 345 | def chunks_to_precalculated_knn_( 346 | *, 347 | num_nearest_neighbors, 348 | num_chunks, 349 | chunk_size, 350 | chunk_memmap_path, 351 | doc_ids_memmap_path, 352 | use_cls_repr = False, 353 | max_rows_per_file = 500, 354 | chunks_to_embeddings_batch_size = 16, 355 | embed_dim = BERT_MODEL_DIM, 356 | num_extra_neighbors = 10, 357 | force_reprocess = False, 358 | index_file = 'knn.index', 359 | **index_kwargs 360 | ): 361 | chunk_path = Path(chunk_memmap_path) 362 | knn_path = chunk_path.parents[0] / f'{chunk_path.stem}.knn{chunk_path.suffix}' 363 | index_path = INDEX_FOLDER_PATH / index_file 364 | 365 | # early return knn path and faiss index 366 | # unless if force_reprocess is True 367 | 368 | if index_path.exists() and knn_path.exists() and not force_reprocess: 369 | print(f'preprocessed knn found at {str(knn_path)}, faiss index reconstituted from {str(index_path)}') 370 | index = faiss_read_index(index_path) 371 | return knn_path, index 372 | 373 | # fetch the faiss index and calculated embeddings for the chunks 374 | 375 | index, embeddings = chunks_to_index_and_embed( 376 | num_chunks = num_chunks, 377 | chunk_size = chunk_size, 378 | chunk_memmap_path = chunk_memmap_path, 379 | index_file = index_file, 380 | **index_kwargs 381 | ) 382 | 383 | total_neighbors_to_fetch = num_extra_neighbors + num_nearest_neighbors + 1 384 | 385 | with memmap(knn_path, shape = (num_chunks, num_nearest_neighbors), dtype = np.int32, mode = 'w+') as knns\ 386 | , memmap(doc_ids_memmap_path, shape = (num_chunks,), dtype = np.int32, mode = 'r') as doc_ids: 387 | 388 | for dim_slice in range_chunked(num_chunks, batch_size = max_rows_per_file): 389 | query_vector = embeddings[dim_slice] 390 | 391 | distances, indices = index.search(query_vector, k = total_neighbors_to_fetch) 392 | 393 | # remove self from distances and indices 394 | 395 | distances = distances[:, 1:] 396 | indices = indices[:, 1:] 397 | 398 | # mask out any neighbors that belong to the same document to -1 399 | 400 | query_doc_ids = doc_ids[dim_slice] 401 | neighbor_doc_ids = doc_ids[indices] 402 | neighbor_from_same_doc = query_doc_ids[..., None] == neighbor_doc_ids 403 | 404 | indices = np.where(neighbor_from_same_doc, -1, indices) 405 | distances = np.where(neighbor_from_same_doc, 1e3, distances) 406 | 407 | # re-sort indices by updated 
distances 408 | 409 | indices = np.take_along_axis(indices, np.argsort(distances, axis = 1), axis = 1) 410 | 411 | # store nearest neighbors to knn memmap 412 | 413 | knns[dim_slice] = indices[:, :num_nearest_neighbors] 414 | 415 | print(f'knns calculated for {dim_slice.stop} / {num_chunks}') 416 | 417 | print(f'knn saved to {knn_path}') 418 | return knn_path, index 419 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## RETRO - Pytorch 4 | 5 | Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as Faiss library instead of Scann. 6 | 7 | This library leverages autofaiss for building the index and calculating the k-nearest neighbors for all chunks. 8 | 9 | Jay Alammar explanatory blogpost 10 | 11 | The selling point of this retriever approach is reaching GPT-3 performance at 10x less parameters. More research is definitely deserved in this area. 12 | 13 | I have also included the features necessary to scale the retrieval transformer to 1000 layers, if the claims of DeepNet paper is to be believed. 14 | 15 | Update: Someone on Reddit has gifted me a Gold Award. Not sure what it is, but thank you! 🙏 16 | 17 | Update: Deepnorm has been validated at scale in a 130B model out of Tsinghua. It is now recommended that you train with `use_deepnet` set to `True` 18 | 19 | ## Install 20 | 21 | ```bash 22 | $ pip install retro-pytorch 23 | ```` 24 | 25 | ## Usage 26 | 27 | ```python 28 | import torch 29 | from retro_pytorch import RETRO 30 | 31 | retro = RETRO( 32 | chunk_size = 64, # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention) 33 | max_seq_len = 2048, # max sequence length 34 | enc_dim = 896, # encoder model dim 35 | enc_depth = 2, # encoder depth 36 | dec_dim = 796, # decoder model dim 37 | dec_depth = 12, # decoder depth 38 | dec_cross_attn_layers = (3, 6, 9, 12), # decoder cross attention layers (with causal chunk cross attention) 39 | heads = 8, # attention heads 40 | dim_head = 64, # dimension per head 41 | dec_attn_dropout = 0.25, # decoder attention dropout 42 | dec_ff_dropout = 0.25, # decoder feedforward dropout 43 | use_deepnet = True # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers 44 | ) 45 | 46 | seq = torch.randint(0, 20000, (2, 2048 + 1)) # plus one since it is split into input and labels for training 47 | retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation) 48 | 49 | loss = retro(seq, retrieved, return_loss = True) 50 | loss.backward() 51 | 52 | # do above for many steps 53 | ``` 54 | 55 | 56 | ## RETRO Training Wrapper 57 | 58 | The aim of the `TrainingWrapper` is to process a folder of text documents into the necessary memmapped numpy arrays to begin training `RETRO`. 
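On the first run, the wrapper tokenizes and chunks the text files, embeds each chunk with frozen BERT, builds a faiss index with autofaiss, and precomputes the nearest-neighbor chunk indices; later runs reuse these artifacts unless the `REPROCESS=1` flag described below is set. With the default paths used in the example below, the files left on disk look roughly like this (a sketch only; exact names depend on the arguments you pass):

```
./train.chunks.dat            # tokenized chunks, int32 memmap of shape (num chunks, chunk_size + 1)
./train.chunks.dat.embedded   # BERT embedding per chunk, float32 memmap
./train.chunks.knn.dat        # precomputed nearest-neighbor chunk indices per chunk
./train.seq.dat               # index of the first chunk of each training sequence
./train.doc_ids.dat           # document id per chunk, used to filter out neighbors from the same document
./processed-stats.json        # counts of processed chunks / docs / seqs
./.tmp/.index/knn.index       # faiss index built by autofaiss
```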
59 | 60 | ```python 61 | import torch 62 | from retro_pytorch import RETRO, TrainingWrapper 63 | 64 | # instantiate RETRO, fit it into the TrainingWrapper with correct settings 65 | 66 | retro = RETRO( 67 | max_seq_len = 2048, # max sequence length 68 | enc_dim = 896, # encoder model dimension 69 | enc_depth = 3, # encoder depth 70 | dec_dim = 768, # decoder model dimensions 71 | dec_depth = 12, # decoder depth 72 | dec_cross_attn_layers = (1, 3, 6, 9), # decoder cross attention layers (with causal chunk cross attention) 73 | heads = 8, # attention heads 74 | dim_head = 64, # dimension per head 75 | dec_attn_dropout = 0.25, # decoder attention dropout 76 | dec_ff_dropout = 0.25 # decoder feedforward dropout 77 | ).cuda() 78 | 79 | wrapper = TrainingWrapper( 80 | retro = retro, # path to retro instance 81 | knn = 2, # knn (2 in paper was sufficient) 82 | chunk_size = 64, # chunk size (64 in paper) 83 | documents_path = './text_folder', # path to folder of text 84 | glob = '**/*.txt', # text glob 85 | chunks_memmap_path = './train.chunks.dat', # path to chunks 86 | seqs_memmap_path = './train.seq.dat', # path to sequence data 87 | doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids per chunk (used for filtering neighbors belonging to same document) 88 | max_chunks = 1_000_000, # maximum cap to chunks 89 | max_seqs = 100_000, # maximum seqs 90 | knn_extra_neighbors = 100, # num extra neighbors to fetch 91 | max_index_memory_usage = '100m', 92 | current_memory_available = '1G' 93 | ) 94 | 95 | # get the dataloader and optimizer (AdamW with all the correct settings) 96 | 97 | train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True)) 98 | optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01) 99 | 100 | # now do your training 101 | # ex. one gradient step 102 | 103 | seq, retrieved = map(lambda t: t.cuda(), next(train_dl)) 104 | 105 | # seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:] 106 | # retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens 107 | 108 | loss = retro( 109 | seq, 110 | retrieved, 111 | return_loss = True 112 | ) 113 | 114 | # one gradient step 115 | 116 | loss.backward() 117 | optim.step() 118 | optim.zero_grad() 119 | 120 | # do above for many steps, then ... 121 | 122 | # topk sampling with retrieval at chunk boundaries 123 | 124 | sampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all 125 | 126 | # or you can generate with a prompt, knn retrieval for initial chunks all taken care of 127 | 128 | prompt = torch.randint(0, 1000, (1, 128)) # start with two chunks worth of sequence 129 | sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0) # (1, <2049) terminates early if all 130 | 131 | ``` 132 | 133 | If you wish to force a reprocess of the training data, simply run your script with a `REPROCESS=1` environment flag as so 134 | 135 | ```bash 136 | $ REPROCESS=1 python train.py 137 | ``` 138 | 139 | ## RETRO Datasets 140 | 141 | The `RETRODataset` class accepts paths to a number of memmapped numpy arrays containing the chunks, the index of the first chunk in the sequence to be trained on (in RETRO decoder), and the pre-calculated indices of the k-nearest neighbors per chunk. 142 | 143 | You can use this to easily assemble the data for `RETRO` training, if you do not wish to use the `TrainingWrapper` from above. 144 | 145 | Furthermore, all the functions needed to create the necessary memmapped data is in the sections to follow. 
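Each item yielded by the dataset is a pair of long tensors: the token sequence of length `seq_len + 1` (one extra token, since it is split into input and labels during training), and the retrieved neighbors of shape `(seq_len // chunk_size, num_neighbors, chunk_size * 2)`, where each neighbor chunk is concatenated with its continuation chunk. The mock-data example below builds the memmapped inputs from which these are assembled.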
146 | 147 | 148 | ```python 149 | import torch 150 | from torch.utils.data import DataLoader 151 | from retro_pytorch import RETRO, RETRODataset 152 | 153 | # mock data constants 154 | 155 | import numpy as np 156 | 157 | NUM_CHUNKS = 1000 158 | CHUNK_SIZE = 64 159 | NUM_SEQS = 100 160 | NUM_NEIGHBORS = 2 161 | 162 | def save_memmap(path, tensor): 163 | f = np.memmap(path, dtype = tensor.dtype, mode = 'w+', shape = tensor.shape) 164 | f[:] = tensor 165 | del f 166 | 167 | # generate mock chunk data 168 | 169 | save_memmap( 170 | './train.chunks.dat', 171 | np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1))) 172 | ) 173 | 174 | # generate nearest neighbors for each chunk 175 | 176 | save_memmap( 177 | './train.chunks.knn.dat', 178 | np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS))) 179 | ) 180 | 181 | # generate seq data 182 | 183 | save_memmap( 184 | './train.seq.dat', 185 | np.int32(np.random.randint(0, 128, size = (NUM_SEQS,))) 186 | ) 187 | 188 | # instantiate dataset class 189 | # which constructs the sequence and neighbors from memmapped chunk and neighbor information 190 | 191 | train_ds = RETRODataset( 192 | num_sequences = NUM_SEQS, 193 | num_chunks = NUM_CHUNKS, 194 | num_neighbors = NUM_NEIGHBORS, 195 | chunk_size = CHUNK_SIZE, 196 | seq_len = 2048, 197 | chunk_memmap_path = './train.chunks.dat', 198 | chunk_nn_memmap_path = './train.chunks.knn.dat', 199 | seq_memmap_path = './train.seq.dat' 200 | ) 201 | 202 | train_dl = iter(DataLoader(train_ds, batch_size = 2)) 203 | 204 | # one forwards and backwards 205 | 206 | retro = RETRO( 207 | max_seq_len = 2048, # max sequence length 208 | enc_dim = 896, # encoder model dimension 209 | enc_depth = 3, # encoder depth 210 | dec_dim = 768, # decoder model dimensions 211 | dec_depth = 12, # decoder depth 212 | dec_cross_attn_layers = (1, 3, 6, 9), # decoder cross attention layers (with causal chunk cross attention) 213 | heads = 8, # attention heads 214 | dim_head = 64, # dimension per head 215 | dec_attn_dropout = 0.25, # decoder attention dropout 216 | dec_ff_dropout = 0.25 # decoder feedforward dropout 217 | ).cuda() 218 | 219 | seq, retrieved = map(lambda t: t.cuda(), next(train_dl)) 220 | 221 | # seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:] 222 | # retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens 223 | 224 | loss = retro( 225 | seq, 226 | retrieved, 227 | return_loss = True 228 | ) 229 | 230 | loss.backward() 231 | 232 | ``` 233 | 234 | ## Retrieval related tools 235 | 236 | This repository will use the default tokenizer (sentencepiece) for the cased version of BERT. Embeddings will be fetched from the vanilla BERT, and can either be masked mean pooled representation, or the CLS token. 237 | 238 | ex. masked mean pooled representation 239 | 240 | ```python 241 | from retro_pytorch.retrieval import bert_embed, tokenize 242 | 243 | ids = tokenize([ 244 | 'hello world', 245 | 'foo bar' 246 | ]) 247 | 248 | embeds = bert_embed(ids) # (2, 768) - 768 is hidden dimension of BERT 249 | ``` 250 | 251 | ex. 
CLS token representation 252 | 253 | 254 | ```python 255 | from retro_pytorch.retrieval import bert_embed, tokenize 256 | 257 | ids = tokenize([ 258 | 'hello world', 259 | 'foo bar' 260 | ]) 261 | 262 | embeds = bert_embed(ids, return_cls_repr = True) # (2, 768) 263 | ``` 264 | 265 | Create your chunks and chunk start indices (for calculating sequence ranges for autoregressive training) using `text_folder_to_chunks_` 266 | 267 | ```python 268 | from retro_pytorch.retrieval import text_folder_to_chunks_ 269 | 270 | stats = text_folder_to_chunks_( 271 | folder = './text_folder', 272 | glob = '**/*.txt', 273 | chunks_memmap_path = './train.chunks.dat', 274 | seqs_memmap_path = './train.seq.dat', 275 | doc_ids_memmap_path = './train.doc_ids.dat', # document ids are needed for filtering out neighbors belonging to same document appropriately during computation of nearest neighbors 276 | chunk_size = 64, 277 | seq_len = 2048, 278 | max_chunks = 1_000_000, 279 | max_seqs = 100_000 280 | ) 281 | 282 | # {'chunks': , 'docs': , 'seqs': } 283 | ``` 284 | 285 | ## Fetching Nearest Neighbors 286 | 287 | You can turn your memmapped chunks numpy array into embeddings and a faiss index with one command 288 | 289 | ```python 290 | from retro_pytorch.retrieval import chunks_to_index_and_embed 291 | 292 | index, embeddings = chunks_to_index_and_embed( 293 | num_chunks = 1000, 294 | chunk_size = 64, 295 | chunk_memmap_path = './train.chunks.dat' 296 | ) 297 | 298 | query_vector = embeddings[:1] # use first embedding as query 299 | _, indices = index.search(query_vector, k = 2) # fetch 2 neighbors, first indices should be self 300 | 301 | neighbor_embeddings = embeddings[indices] # (1, 2, 768) 302 | 303 | ``` 304 | 305 | You can also directly calculate the nearest neighbor file necessary for training, with `chunks_to_precalculated_knn_` command 306 | 307 | ```python 308 | from retro_pytorch.retrieval import chunks_to_precalculated_knn_ 309 | 310 | chunks_to_precalculated_knn_( 311 | num_chunks = 1000, 312 | chunk_size = 64, 313 | chunk_memmap_path = './train.chunks.dat', # path to main chunks dataset 314 | doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids created by text_folder_to_chunks_, used for filtering out neighbors that belong to the same document 315 | num_nearest_neighbors = 2, # number of nearest neighbors you'd like to use 316 | num_extra_neighbors = 10 # fetch 10 extra neighbors, in the case that fetched neighbors are frequently from same document (filtered out) 317 | ) 318 | 319 | # nearest neighbor info saved to ./train.chunks.knn.dat 320 | 321 | ``` 322 | 323 | ## Citations 324 | 325 | ```bibtex 326 | @misc{borgeaud2022improving, 327 | title = {Improving language models by retrieving from trillions of tokens}, 328 | author = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. 
Rae and Erich Elsen and Laurent Sifre}, 329 | year = {2022}, 330 | eprint = {2112.04426}, 331 | archivePrefix = {arXiv}, 332 | primaryClass = {cs.CL} 333 | } 334 | ``` 335 | 336 | ```bibtex 337 | @misc{su2021roformer, 338 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 339 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 340 | year = {2021}, 341 | eprint = {2104.09864}, 342 | archivePrefix = {arXiv}, 343 | primaryClass = {cs.CL} 344 | } 345 | ``` 346 | 347 | ```bibtex 348 | @article{Wang2022DeepNetST, 349 | title = {DeepNet: Scaling Transformers to 1, 000 Layers}, 350 | author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei}, 351 | journal = {ArXiv}, 352 | year = {2022}, 353 | volume = {abs/2203.00555} 354 | } 355 | ``` 356 | 357 | ```bibtex 358 | @misc{zhang2021sparse, 359 | title = {Sparse Attention with Linear Units}, 360 | author = {Biao Zhang and Ivan Titov and Rico Sennrich}, 361 | year = {2021}, 362 | eprint = {2104.07012}, 363 | archivePrefix = {arXiv}, 364 | primaryClass = {cs.CL} 365 | } 366 | ``` 367 | 368 | *I consider always the adult life to be the continuous retrieval of childhood.* - Umberto Eco 369 | -------------------------------------------------------------------------------- /retro_pytorch/retro_pytorch.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | 7 | from retro_pytorch.retrieval import BERT_VOCAB_SIZE 8 | from einops import rearrange, repeat 9 | 10 | # constants 11 | 12 | MIN_DIM_HEAD = 32 13 | 14 | # helper functions 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | def divisible_by(val, divisor): 23 | return (val / divisor).is_integer() 24 | 25 | def cast_tuple(val, num = 1): 26 | return val if isinstance(val, tuple) else ((val,) * num) 27 | 28 | # deepnet init 29 | 30 | def deepnorm_init(transformer, beta, module_name_match_list = ['.ff.', '.to_v', '.to_out']): 31 | for name, module in transformer.named_modules(): 32 | if type(module) != nn.Linear: 33 | continue 34 | 35 | needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list)) 36 | gain = beta if needs_beta_gain else 1 37 | nn.init.xavier_normal_(module.weight.data, gain = gain) 38 | 39 | if exists(module.bias): 40 | nn.init.constant_(module.bias.data, 0) 41 | 42 | # normalization 43 | 44 | class RMSNorm(nn.Module): 45 | def __init__( 46 | self, 47 | dim, 48 | *, 49 | eps = 1e-8, 50 | gated = False 51 | ): 52 | super().__init__() 53 | self.eps = eps 54 | self.scale = dim ** -0.5 55 | self.gamma = nn.Parameter(torch.ones(dim)) 56 | self.weight = nn.Parameter(torch.ones(dim)) if gated else None 57 | 58 | def forward(self, x): 59 | norm = x.norm(keepdim = True, dim = -1) * self.scale 60 | out = (x / norm.clamp(min = self.eps)) * self.gamma 61 | 62 | if not exists(self.weight): 63 | return out 64 | 65 | return out * (x * self.weight).sigmoid() 66 | 67 | # pre and post norm residual wrapper modules 68 | 69 | class PreNorm(nn.Module): 70 | def __init__(self, dim, fn, norm_klass = RMSNorm): 71 | super().__init__() 72 | self.fn = fn 73 | self.norm = norm_klass(dim) 74 | 75 | def forward(self, x, *args, **kwargs): 76 | return self.fn(self.norm(x), *args, **kwargs) + x 77 | 78 | class PostNorm(nn.Module): 79 | def __init__(self, dim, fn, 
scale_residual = 1, norm_klass = RMSNorm): 80 | super().__init__() 81 | self.fn = fn 82 | self.scale_residual = scale_residual 83 | self.norm = norm_klass(dim) 84 | 85 | def forward(self, x, *args, **kwargs): 86 | residual = x * self.scale_residual 87 | out = self.fn(x, *args, **kwargs) + residual 88 | return self.norm(out) 89 | 90 | # positional embedding 91 | 92 | class RotaryEmbedding(nn.Module): 93 | def __init__(self, dim): 94 | super().__init__() 95 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 96 | self.register_buffer('inv_freq', inv_freq) 97 | 98 | def forward(self, max_seq_len, *, device, offset = 0): 99 | seq = torch.arange(max_seq_len, device = device) + offset 100 | freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) 101 | emb = torch.cat((freqs, freqs), dim = -1) 102 | return rearrange(emb, 'n d -> 1 1 n d') 103 | 104 | def rotate_half(x): 105 | x = rearrange(x, '... (j d) -> ... j d', j = 2) 106 | x1, x2 = x.unbind(dim = -2) 107 | return torch.cat((-x2, x1), dim = -1) 108 | 109 | def apply_rotary_pos_emb(t, freqs): 110 | seq_len, rot_dim = t.shape[-2], freqs.shape[-1] 111 | t, t_pass = t[..., :rot_dim], t[..., rot_dim:] 112 | t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 113 | return torch.cat((t, t_pass), dim = -1) 114 | 115 | # feedforward 116 | 117 | class FeedForward(nn.Module): 118 | def __init__(self, dim, mult = 4, dropout = 0.): 119 | super().__init__() 120 | inner_dim = int(mult * dim) 121 | 122 | self.ff = nn.Sequential( 123 | nn.Linear(dim, inner_dim), 124 | nn.GELU(), 125 | nn.Dropout(dropout), 126 | nn.Linear(inner_dim, dim) 127 | ) 128 | 129 | def forward(self, x): 130 | return self.ff(x) 131 | 132 | # attention 133 | 134 | class Attention(nn.Module): 135 | def __init__( 136 | self, 137 | dim, 138 | *, 139 | context_dim = None, 140 | dim_head = 64, 141 | heads = 8, 142 | causal = False, 143 | dropout = 0., 144 | null_kv = False 145 | ): 146 | super().__init__() 147 | context_dim = default(context_dim, dim) 148 | 149 | self.heads = heads 150 | self.scale = dim_head ** -0.5 151 | self.causal = causal 152 | inner_dim = dim_head * heads 153 | 154 | self.dropout = nn.Dropout(dropout) 155 | 156 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 157 | self.to_k = nn.Linear(context_dim, inner_dim, bias = False) 158 | self.to_v = nn.Linear(context_dim, inner_dim, bias = False) 159 | self.to_out = nn.Linear(inner_dim, dim) 160 | 161 | # allowing for attending to nothing (null function) 162 | # and to save attention from breaking if all retrieved chunks are padded out 163 | self.null_k = nn.Parameter(torch.randn(inner_dim)) if null_kv else None 164 | self.null_v = nn.Parameter(torch.randn(inner_dim)) if null_kv else None 165 | 166 | def forward(self, x, mask = None, context = None, pos_emb = None): 167 | b, device, h, scale = x.shape[0], x.device, self.heads, self.scale 168 | 169 | kv_input = default(context, x) 170 | 171 | q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input) 172 | 173 | # split heads 174 | 175 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 176 | 177 | # scale 178 | 179 | q = q * scale 180 | 181 | # apply relative positional encoding (rotary embeddings) 182 | 183 | if exists(pos_emb): 184 | q_pos_emb, k_pos_emb = cast_tuple(pos_emb, num = 2) 185 | 186 | q = apply_rotary_pos_emb(q, q_pos_emb) 187 | k = apply_rotary_pos_emb(k, k_pos_emb) 188 | 189 | # add null key / values 190 | 191 | if exists(self.null_k): 192 | nk, nv = self.null_k, self.null_v 193 | 
nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h = h), (nk, nv)) 194 | k = torch.cat((nk, k), dim = -2) 195 | v = torch.cat((nv, v), dim = -2) 196 | 197 | # derive query key similarities 198 | 199 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 200 | 201 | # masking 202 | 203 | mask_value = -torch.finfo(sim.dtype).max 204 | 205 | if exists(mask): 206 | if exists(self.null_k): 207 | mask = F.pad(mask, (1, 0), value = True) 208 | 209 | mask = rearrange(mask, 'b j -> b 1 1 j') 210 | sim = sim.masked_fill(~mask, mask_value) 211 | 212 | if self.causal: 213 | i, j = sim.shape[-2:] 214 | causal_mask = torch.ones(i, j, device = device, dtype = torch.bool).triu(j - i + 1) 215 | sim = sim.masked_fill(causal_mask, mask_value) 216 | 217 | # attention 218 | 219 | attn = sim.softmax(dim = -1) 220 | 221 | attn = self.dropout(attn) 222 | 223 | # aggregate 224 | 225 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 226 | 227 | # merge heads 228 | 229 | out = rearrange(out, 'b h n d -> b n (h d)') 230 | 231 | # combine heads linear out 232 | 233 | return self.to_out(out) 234 | 235 | 236 | class ChunkedCrossAttention(nn.Module): 237 | def __init__( 238 | self, 239 | chunk_size, 240 | **kwargs 241 | ): 242 | super().__init__() 243 | self.chunk_size = chunk_size 244 | self.cross_attn = Attention(null_kv = True, **kwargs) 245 | 246 | def forward(self, x, *, context_mask = None, context, pos_emb = None): 247 | # derive variables 248 | chunk_size = self.chunk_size 249 | 250 | b, n, num_chunks, num_retrieved = x.shape[0], x.shape[-2], *context.shape[-4:-2] 251 | 252 | # if sequence length less than chunk size, do an early return 253 | 254 | if n < self.chunk_size: 255 | return torch.zeros_like(x) 256 | 257 | # causal padding 258 | 259 | causal_padding = chunk_size - 1 260 | 261 | x = F.pad(x, (0, 0, -causal_padding, causal_padding), value = 0.) 262 | 263 | # remove sequence which is ahead of the neighbors retrieved (during inference) 264 | 265 | seq_index = (n // chunk_size) * chunk_size 266 | x, x_remainder = x[:, :seq_index], x[:, seq_index:] 267 | 268 | seq_remain_len = x_remainder.shape[-2] 269 | 270 | # take care of rotary positional embedding 271 | # make sure queries positions are properly shifted to the future 272 | 273 | q_pos_emb, k_pos_emb = pos_emb 274 | q_pos_emb = F.pad(q_pos_emb, (0, 0, -causal_padding, causal_padding), value = 0.) 275 | 276 | k_pos_emb = repeat(k_pos_emb, 'b h n d -> b h (r n) d', r = num_retrieved) 277 | pos_emb = (q_pos_emb, k_pos_emb) 278 | 279 | # reshape so we have chunk to chunk attention, without breaking causality 280 | 281 | x = rearrange(x, 'b (k n) d -> (b k) n d', k = num_chunks) 282 | context = rearrange(context, 'b k r n d -> (b k) (r n) d') 283 | 284 | if exists(context_mask): 285 | context_mask = rearrange(context_mask, 'b k r n -> (b k) (r n)') 286 | 287 | # cross attention 288 | 289 | out = self.cross_attn(x, context = context, mask = context_mask, pos_emb = pos_emb) 290 | 291 | # reshape back to original sequence 292 | 293 | out = rearrange(out, '(b k) n d -> b (k n) d', b = b) 294 | 295 | # pad back to original, with 0s at the beginning (which will be added to the residual and be fine) 296 | 297 | out = F.pad(out, (0, 0, causal_padding, -causal_padding + seq_remain_len), value = 0.) 
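# note: after this final un-shift, the first (chunk_size - 1) positions (tokens that appear before the
# end of the first chunk) are filled with zeros, i.e. they receive no retrieval contribution - the
# residual added by the surrounding PreNorm / PostNorm wrapper carries those positions through unchanged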
298 | return out 299 | 300 | # encoder and decoder classes 301 | 302 | class Encoder(nn.Module): 303 | def __init__( 304 | self, 305 | dim, 306 | *, 307 | depth, 308 | context_dim = None, 309 | causal = False, 310 | heads = 8, 311 | dim_head = 64, 312 | attn_dropout = 0., 313 | ff_mult = 4, 314 | ff_dropout = 0., 315 | final_norm = True, 316 | cross_attn_layers = None, 317 | post_norm = False, 318 | output_dim = None, 319 | norm_klass = RMSNorm, 320 | scale_residual = 1. 321 | ): 322 | super().__init__() 323 | self.layers = nn.ModuleList([]) 324 | 325 | # partial rotary embeddings, which is better than full rotary 326 | # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/ 327 | 328 | rotary_emb_dim = min(dim_head, MIN_DIM_HEAD) 329 | self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) 330 | 331 | wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass) 332 | 333 | for layer_num in range(1, depth + 1): 334 | has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers 335 | 336 | self.layers.append(nn.ModuleList([ 337 | wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)), 338 | wrapper(Attention(dim = dim, context_dim = context_dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None, 339 | wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)), 340 | ])) 341 | 342 | self.norm_out = norm_klass(dim) if final_norm and not post_norm else nn.Identity() 343 | self.project_out = nn.Linear(dim, output_dim) if exists(output_dim) else nn.Identity() 344 | 345 | def forward(self, x, *, mask = None, chunked_seq): 346 | device, chunk_size, seq_len = x.device, x.shape[-2], chunked_seq.shape[-2] 347 | 348 | q_pos_emb = self.rotary_pos_emb(chunk_size, device = device) 349 | k_pos_emb = self.rotary_pos_emb(seq_len, device = device) 350 | 351 | for attn, cross_attn, ff in self.layers: 352 | x = attn(x, mask = mask, pos_emb = q_pos_emb) 353 | 354 | if exists(cross_attn): 355 | x = cross_attn(x, context = chunked_seq, pos_emb = (q_pos_emb, k_pos_emb)) 356 | 357 | x = ff(x) 358 | 359 | x = self.norm_out(x) 360 | return self.project_out(x) 361 | 362 | class Decoder(nn.Module): 363 | def __init__( 364 | self, 365 | dim, 366 | *, 367 | depth, 368 | heads = 8, 369 | dim_head = 64, 370 | attn_dropout = 0., 371 | ff_mult = 4, 372 | ff_dropout = 0., 373 | final_norm = True, 374 | cross_attn_layers = None, 375 | chunk_size = 64, 376 | post_norm = False, 377 | norm_klass = RMSNorm, 378 | scale_residual = 1. 
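        # note on the arguments above: cross_attn_layers is 1-indexed (layer numbers run from 1 to depth),
        # and when it is None every decoder layer gets chunked cross attention; chunk_size is the decoding
        # chunk size - the retrieved chunks passed to RETRO.forward must be at least this long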
379 | ): 380 | super().__init__() 381 | self.layers = nn.ModuleList([]) 382 | 383 | # partial rotary embeddings, which is better than full rotary 384 | # Wang and Komatsuzaki et al https://github.com/kingoflolz/mesh-transformer-jax/ 385 | 386 | rotary_emb_dim = min(dim_head, MIN_DIM_HEAD) 387 | self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) 388 | 389 | wrapper = partial(PreNorm, dim, norm_klass = norm_klass) if not post_norm else partial(PostNorm, dim, scale_residual = scale_residual, norm_klass = norm_klass) 390 | 391 | self.chunk_size = chunk_size 392 | 393 | for layer_num in range(1, depth + 1): 394 | has_cross_attn = not exists(cross_attn_layers) or layer_num in cross_attn_layers 395 | 396 | self.layers.append(nn.ModuleList([ 397 | wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = True)), 398 | wrapper(ChunkedCrossAttention(chunk_size = chunk_size, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)) if has_cross_attn else None, 399 | wrapper(FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)), 400 | ])) 401 | 402 | self.norm_out = norm_klass(dim) if final_norm and not post_norm else nn.Identity() 403 | 404 | def forward(self, x, *, encoder = None, encoder_retrieved_mask = None, context_mask = None, retrieved = None): 405 | device, seq_len = x.device, x.shape[-2] 406 | self_attn_pos_emb = self.rotary_pos_emb(seq_len, device = device) 407 | 408 | # calculate seq index 409 | 410 | num_seq_chunks = seq_len // self.chunk_size 411 | seq_index = num_seq_chunks * self.chunk_size 412 | 413 | # rotary positions on the retrieved chunks 414 | 415 | if exists(retrieved): 416 | num_chunks, num_neighbors, chunk_size = retrieved.shape[-4:-1] 417 | 418 | cross_attn_q_pos_emb = self.rotary_pos_emb(self.chunk_size, device = device, offset = self.chunk_size - 1) # need to add extra chunk size, since it will be shifted 419 | cross_attn_k_pos_emb = self.rotary_pos_emb(chunk_size, device = device) 420 | 421 | cross_attn_pos_emb = (cross_attn_q_pos_emb, cross_attn_k_pos_emb) 422 | 423 | # keep track of whether retrieved tokens are encoded yet 424 | 425 | retrieved_encoded = False 426 | 427 | # go through the decoder layers 428 | 429 | for attn, cross_attn, ff in self.layers: 430 | x = attn(x, pos_emb = self_attn_pos_emb) 431 | 432 | if exists(cross_attn) and exists(retrieved): 433 | if not retrieved_encoded: 434 | retrieved = rearrange(retrieved, 'b k r n d -> (b k r) n d') 435 | seq_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors) 436 | 437 | retrieved = encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = seq_as_context) 438 | retrieved = rearrange(retrieved, '(b k r) n d -> b k r n d', k = num_chunks, r = num_neighbors) 439 | retrieved_encoded = True 440 | 441 | x = cross_attn( 442 | x, 443 | context = retrieved, 444 | context_mask = context_mask, 445 | pos_emb = cross_attn_pos_emb 446 | ) 447 | 448 | x = ff(x) 449 | 450 | return self.norm_out(x) 451 | 452 | # main class 453 | 454 | class RETRO(nn.Module): 455 | def __init__( 456 | self, 457 | *, 458 | num_tokens = BERT_VOCAB_SIZE, 459 | max_seq_len = 2048, 460 | enc_dim = 896, 461 | enc_depth = 2, 462 | enc_cross_attn_layers = None, 463 | dec_depth = 12, 464 | dec_cross_attn_layers = (1, 3, 6, 9), 465 | heads = 8, 466 | dec_dim = 768, 467 | dim_head = 64, 468 | enc_attn_dropout = 0., 469 | enc_ff_dropout = 0., 470 | dec_attn_dropout = 0., 471 | dec_ff_dropout = 0., 472 | chunk_size = 64, 473 | pad_id = 0, 
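        # the arguments below configure DeepNet-style training: when use_deepnet = True, residual branches
        # are scaled, post-norm with LayerNorm replaces pre-norm RMSNorm, and weights get the deepnorm init
        # (https://arxiv.org/abs/2203.00555); gated_rmsnorm swaps in the gated RMSNorm variant instead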
474 | enc_scale_residual = None, 475 | dec_scale_residual = None, 476 | norm_klass = None, 477 | gated_rmsnorm = False, 478 | use_deepnet = False 479 | ): 480 | super().__init__() 481 | assert dim_head >= MIN_DIM_HEAD, f'dimension per head must be greater than {MIN_DIM_HEAD}' 482 | self.seq_len = max_seq_len 483 | self.pad_id = pad_id 484 | 485 | self.token_emb = nn.Embedding(num_tokens, enc_dim) 486 | self.pos_emb = nn.Embedding(max_seq_len, enc_dim) 487 | 488 | self.chunk_size = chunk_size 489 | 490 | self.to_decoder_model_dim = nn.Linear(enc_dim, dec_dim) if enc_dim != dec_dim else nn.Identity() 491 | 492 | # for deepnet, residual scales 493 | # follow equation in Figure 2. in https://arxiv.org/abs/2203.00555 494 | 495 | norm_klass = default(norm_klass, RMSNorm) 496 | 497 | if use_deepnet: 498 | enc_scale_residual = default(enc_scale_residual, 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625) 499 | dec_scale_residual = default(dec_scale_residual, (3 * dec_depth) ** 0.25) 500 | norm_klass = nn.LayerNorm 501 | 502 | # allow for gated rmsnorm 503 | 504 | if gated_rmsnorm: 505 | norm_klass = partial(RMSNorm, gated = True) 506 | 507 | # define encoder and decoders 508 | 509 | self.encoder = Encoder( 510 | dim = enc_dim, 511 | context_dim = dec_dim, 512 | dim_head = dim_head, 513 | depth = enc_depth, 514 | attn_dropout = enc_attn_dropout, 515 | ff_dropout = enc_ff_dropout, 516 | cross_attn_layers = enc_cross_attn_layers, 517 | post_norm = use_deepnet, 518 | norm_klass = norm_klass, 519 | scale_residual = enc_scale_residual, 520 | output_dim = dec_dim 521 | ) 522 | 523 | self.decoder = Decoder( 524 | dim = dec_dim, 525 | depth = dec_depth, 526 | dim_head = dim_head, 527 | attn_dropout = dec_attn_dropout, 528 | ff_dropout = dec_ff_dropout, 529 | cross_attn_layers = dec_cross_attn_layers, 530 | chunk_size = chunk_size, 531 | post_norm = use_deepnet, 532 | norm_klass = norm_klass, 533 | scale_residual = dec_scale_residual 534 | ) 535 | 536 | self.to_logits = nn.Linear(dec_dim, num_tokens) 537 | 538 | # deepnet has special init of weight matrices 539 | 540 | if use_deepnet: 541 | deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625) 542 | deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25) 543 | 544 | def forward_without_retrieval( 545 | self, 546 | seq 547 | ): 548 | # embed sequence 549 | 550 | embed = self.token_emb(seq) 551 | embed = embed[:, :self.seq_len] 552 | 553 | # get absolute positional embedding 554 | 555 | pos_emb = self.pos_emb(torch.arange(embed.shape[1], device = embed.device)) 556 | pos_emb = rearrange(pos_emb, 'n d -> 1 n d') 557 | embed = embed + pos_emb 558 | 559 | embed = self.to_decoder_model_dim(embed) 560 | embed = self.decoder(embed) 561 | 562 | # project to logits 563 | 564 | return self.to_logits(embed) 565 | 566 | def forward( 567 | self, 568 | seq, 569 | retrieved = None, 570 | return_loss = False 571 | ): 572 | """ 573 | b - batch 574 | n - sequence length / chunk length 575 | k - number of chunks 576 | d - feature dimension 577 | r - num retrieved neighbors 578 | """ 579 | 580 | if not exists(retrieved): 581 | return self.forward_without_retrieval(seq) 582 | 583 | assert not (return_loss and not self.training), 'must be training if returning loss' 584 | 585 | # assume padding token id (usually 0.) 
is to be masked out 586 | 587 | mask = retrieved != self.pad_id 588 | 589 | # handle some user inputs 590 | 591 | if retrieved.ndim == 3: 592 | retrieved = rearrange(retrieved, 'b k n -> b k 1 n') # 1 neighbor retrieved 593 | 594 | # if training, derive labels 595 | 596 | if return_loss: 597 | seq, labels = seq[:, :-1], seq[:, 1:] 598 | 599 | # variables 600 | 601 | n, num_chunks, num_neighbors, chunk_size, retrieved_shape, device = seq.shape[-1], *retrieved.shape[-3:], retrieved.shape, seq.device 602 | 603 | assert chunk_size >= self.chunk_size, 'chunk size of retrieval input must be greater or equal to the designated chunk_size on RETRO initialization' 604 | 605 | num_seq_chunks = n // self.chunk_size 606 | assert num_chunks == num_seq_chunks, f'sequence requires {num_seq_chunks} retrieved chunks, but only {num_chunks} passed in' 607 | 608 | # sequence index at which k-nearest neighbors have not been fetched yet after 609 | 610 | seq_index = num_seq_chunks * self.chunk_size 611 | 612 | # embed both sequence and retrieved chunks 613 | 614 | embed = self.token_emb(seq) 615 | retrieved = self.token_emb(retrieved) 616 | 617 | # get absolute positional embedding 618 | 619 | pos_emb = self.pos_emb(torch.arange(n, device = device)) 620 | pos_emb = rearrange(pos_emb, 'n d -> 1 n d') 621 | embed = embed + pos_emb 622 | 623 | # handle masks for encoder and decoder, if needed 624 | 625 | encoder_retrieved_mask = decoder_retrieved_mask = None 626 | 627 | if exists(mask): 628 | assert mask.shape == retrieved_shape, 'retrieval mask must be of the same shape as the retrieval tokens' 629 | encoder_retrieved_mask = rearrange(mask, 'b k r n -> (b k r) n') 630 | decoder_retrieved_mask = mask 631 | 632 | # project both sequence embedding and retrieved embedding to decoder dimension if necessary 633 | 634 | embed = self.to_decoder_model_dim(embed) 635 | 636 | # decode 637 | 638 | embed = self.decoder( 639 | embed, 640 | encoder = self.encoder, 641 | context_mask = decoder_retrieved_mask, 642 | encoder_retrieved_mask = encoder_retrieved_mask, 643 | retrieved = retrieved 644 | ) 645 | 646 | # project to logits 647 | 648 | logits = self.to_logits(embed) 649 | 650 | if not return_loss: 651 | return logits 652 | 653 | # cross entropy loss 654 | 655 | loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = self.pad_id) 656 | return loss 657 | --------------------------------------------------------------------------------