├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── htm.png
├── htm_pytorch
│   ├── __init__.py
│   └── htm_pytorch.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Hierarchical Transformer Memory (HTM) - Pytorch

Implementation of Hierarchical Transformer Memory (HTM) for Pytorch. This DeepMind paper proposes a simple method that allows transformers to attend to memories of the past efficiently.
Original Jax repository

## Install

```bash
$ pip install htm-pytorch
```

## Usage

```python
import torch
from htm_pytorch import HTMAttention

attn = HTMAttention(
    dim = 512,
    heads = 8,               # number of heads for within-memory attention
    dim_head = 64,           # dimension per head for within-memory attention
    topk_mems = 8,           # how many memory chunks to select for each query
    mem_chunk_size = 32,     # number of tokens in each memory chunk
    add_pos_enc = True       # whether to add positional encoding to the memories
)

queries = torch.randn(1, 128, 512)     # queries
memories = torch.randn(1, 20000, 512)  # memories, of any size
mask = torch.ones(1, 20000).bool()     # memory mask

attended = attn(queries, memories, mask = mask) # (1, 128, 512)
```
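The forward pass also accepts an optional `chunk_attn_mask`, which masks out particular query / memory-chunk pairs before the top-k chunk selection. Below is a minimal sketch of how it could be passed; the `(batch, queries, chunks)` shape follows from the chunk similarity logits, and the sizes are simply the ones from the example above (20000 memory tokens / 32 tokens per chunk = 625 chunks).

```python
import torch
from htm_pytorch import HTMAttention

attn = HTMAttention(
    dim = 512,
    heads = 8,
    topk_mems = 8,
    mem_chunk_size = 32
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

# one boolean per (query, memory chunk) pair - False removes that chunk from consideration
chunk_attn_mask = torch.ones(1, 128, 625).bool()
chunk_attn_mask[:, :, -100:] = False   # e.g. forbid the last 100 chunks for every query

attended = attn(queries, memories, mask = mask, chunk_attn_mask = chunk_attn_mask) # (1, 128, 512)
```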
If you want the entire HTM block (which layernorms the input and adds the skip connection for you), just import `HTMBlock` instead

```python
import torch
from htm_pytorch import HTMBlock

block = HTMBlock(
    dim = 512,
    heads = 8,           # required by the underlying HTMAttention
    topk_mems = 8,
    mem_chunk_size = 32
)

queries = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

out = block(queries, memories, mask = mask) # (1, 128, 512)
```
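Since `HTMBlock` is a residual block, several of them can be stacked and pointed at the same memory bank. The snippet below is only a sketch of one way to wire that up - the depth of 4 and the single shared memory bank are illustrative choices, not something this repository prescribes.

```python
import torch
from torch import nn
from htm_pytorch import HTMBlock

# a toy stack of HTM blocks, all reading from one shared memory bank
blocks = nn.ModuleList([
    HTMBlock(
        dim = 512,
        heads = 8,
        topk_mems = 8,
        mem_chunk_size = 32
    ) for _ in range(4)
])

x = torch.randn(1, 128, 512)           # current hidden states, used as queries
memories = torch.randn(1, 20000, 512)  # episodic memory bank
mask = torch.ones(1, 20000).bool()

for block in blocks:
    x = block(x, memories, mask = mask)  # layernorm + HTM attention + skip, per block

# x is still (1, 128, 512)
```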
## Citations

```bibtex
@misc{lampinen2021mental,
    title = {Towards mental time travel: a hierarchical memory for reinforcement learning agents},
    author = {Andrew Kyle Lampinen and Stephanie C. Y. Chan and Andrea Banino and Felix Hill},
    year = {2021},
    eprint = {2105.14039},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```
--------------------------------------------------------------------------------
/htm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/HTM-pytorch/db4deb58cba750c8fc66a74b98e513dc66bd5296/htm.png
--------------------------------------------------------------------------------
/htm_pytorch/__init__.py:
--------------------------------------------------------------------------------
from htm_pytorch.htm_pytorch import HTMAttention, HTMBlock
--------------------------------------------------------------------------------
/htm_pytorch/htm_pytorch.py:
--------------------------------------------------------------------------------
from math import ceil
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pad_to_multiple(t, multiple, dim = -2, value = 0.):
    seq_len = t.shape[dim]
    pad_to_len = ceil(seq_len / multiple) * multiple
    remainder = pad_to_len - seq_len

    if remainder == 0:
        return t

    zeroes = (0, 0) * (-dim - 1)
    padded_t = F.pad(t, (*zeroes, remainder, 0), value = value)
    return padded_t

# positional encoding

class SinusoidalPosition(nn.Module):
    def __init__(
        self,
        dim,
        min_timescale = 2.,
        max_timescale = 1e4
    ):
        super().__init__()
        freqs = torch.arange(0, dim, min_timescale)
        inv_freqs = max_timescale ** (-freqs / dim)
        self.register_buffer('inv_freqs', inv_freqs)

    def forward(self, x):
        seq_len = x.shape[-2]
        seq = torch.arange(seq_len - 1, -1, -1.)
        sinusoidal_inp = rearrange(seq, 'n -> n ()') * rearrange(self.inv_freqs, 'd -> () d')
        pos_emb = torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim = -1)
        return pos_emb

# multi-head attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        mems,
        mask = None
    ):
        h = self.heads
        q, k, v = self.to_q(x), *self.to_kv(mems).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b m i d, b m i j d -> b m i j', q, k)

        if exists(mask):
            mask = repeat(mask, 'b ... -> (b h) ...', h = h)
            mask_value = -torch.finfo(sim.dtype).max
            sim = sim.masked_fill(~mask, mask_value)

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... i j d -> ... i d', attn, v)
        out = rearrange(out, '(b h) ... d -> b ... (h d)', h = h)
        return self.to_out(out)
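# overview of the procedure implemented by HTMAttention below:
#   1. pad the memories and split them into fixed size chunks
#   2. mean-pool each chunk into a summary vector, respecting the memory mask
#   3. score each chunk summary against every query with a single-head dot product
#   4. keep the top-k highest scoring chunks per query
#   5. run the multi-head Attention above within just the selected chunks
#   6. sum the per-chunk outputs, weighted by a softmax over the top-k scores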
# main class

class HTMAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        topk_mems = 2,
        mem_chunk_size = 32,
        dim_head = 64,
        add_pos_enc = True,
        eps = 1e-5
    ):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.scale = dim ** -0.5

        self.to_summary_queries = nn.Linear(dim, dim)
        self.to_summary_keys = nn.Linear(dim, dim)

        self.attn = Attention(dim = dim, heads = heads, dim_head = dim_head)

        self.topk_mems = topk_mems
        self.mem_chunk_size = mem_chunk_size
        self.pos_emb = SinusoidalPosition(dim = dim) if add_pos_enc else None

    def forward(
        self,
        queries,
        memories,
        mask = None,
        chunk_attn_mask = None
    ):
        dim, query_len, mem_chunk_size, topk_mems, scale, eps = self.dim, queries.shape[1], self.mem_chunk_size, self.topk_mems, self.scale, self.eps

        # pad memories, and the memory mask, if needed
        # and then divide into chunks

        memories = pad_to_multiple(memories, mem_chunk_size, dim = -2, value = 0.)
        memories = rearrange(memories, 'b (n c) d -> b n c d', c = mem_chunk_size)

        if exists(mask):
            mask = pad_to_multiple(mask, mem_chunk_size, dim = -1, value = False)
            mask = rearrange(mask, 'b (n c) -> b n c', c = mem_chunk_size)

        # summarize memories through mean-pool, accounting for mask

        if exists(mask):
            mean_mask = rearrange(mask, '... -> ... ()')
            memories = memories.masked_fill(~mean_mask, 0.)
            numer = memories.sum(dim = 2)
            denom = mean_mask.sum(dim = 2)
            summarized_memories = numer / (denom + eps)
        else:
            summarized_memories = memories.mean(dim = 2)

        # derive queries and summarized memory keys

        summary_queries = self.to_summary_queries(queries)
        summary_keys = self.to_summary_keys(summarized_memories.detach())

        # do a single head attention over summary keys

        sim = einsum('b i d, b j d -> b i j', summary_queries, summary_keys) * scale
        mask_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            chunk_mask = mask.any(dim = 2)
            chunk_mask = rearrange(chunk_mask, 'b j -> b () j')
            sim = sim.masked_fill(~chunk_mask, mask_value)

        if exists(chunk_attn_mask):
            sim = sim.masked_fill(~chunk_attn_mask, mask_value)

        topk_logits, topk_indices = sim.topk(k = topk_mems, dim = -1)
        weights = topk_logits.softmax(dim = -1)

        # ready queries for in-memory attention

        queries = repeat(queries, 'b n d -> b k n d', k = topk_mems)

        # select the topk memories

        memories = repeat(memories, 'b m j d -> b m i j d', i = query_len)
        mem_topk_indices = repeat(topk_indices, 'b i m -> b m i j d', j = mem_chunk_size, d = dim)
        selected_memories = memories.gather(1, mem_topk_indices)

        # positional encoding

        if exists(self.pos_emb):
            pos_emb = self.pos_emb(memories)
            selected_memories = selected_memories + rearrange(pos_emb, 'n d -> () () () n d')

        # select the mask

        selected_mask = None
        if exists(mask):
            mask = repeat(mask, 'b m j -> b m i j', i = query_len)
            mask_topk_indices = repeat(topk_indices, 'b i m -> b m i j', j = mem_chunk_size)
            selected_mask = mask.gather(1, mask_topk_indices)

        # now do in-memory attention

        within_mem_output = self.attn(
            queries,
            selected_memories.detach(),
            mask = selected_mask
        )

        # weight the in-memory attention outputs

        weighted_output = within_mem_output * rearrange(weights, 'b i m -> b m i ()')
        output = weighted_output.sum(dim = 1)
        return output

# HTM Block

class HTMBlock(nn.Module):
    def __init__(self, dim, **kwargs):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = HTMAttention(dim = dim, **kwargs)

    def forward(
        self,
        queries,
        memories,
        **kwargs
    ):
        queries = self.norm(queries)
        out = self.attn(queries, memories, **kwargs) + queries
        return out
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'htm-pytorch',
  packages = find_packages(),
  version = '0.0.4',
  license='MIT',
  description = 'Hierarchical Transformer Memory - Pytorch',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/htm-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'attention-mechanism',
    'memory'
  ],
  install_requires=[
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------