├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── htm.png
├── htm_pytorch
│   ├── __init__.py
│   └── htm_pytorch.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./htm.png"/>
2 |
3 | ## Hierarchical Transformer Memory (HTM) - Pytorch
4 |
5 | Implementation of Hierarchical Transformer Memory (HTM) for Pytorch. This [DeepMind paper](https://arxiv.org/abs/2105.14039) proposes a simple method that allows transformers to attend to memories of the past efficiently. An official Jax implementation from the authors is also available.
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install htm-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from htm_pytorch import HTMAttention
18 |
19 | attn = HTMAttention(
20 | dim = 512,
21 | heads = 8, # number of heads for within-memory attention
22 | dim_head = 64, # dimension per head for within-memory attention
23 |     topk_mems = 8,            # how many memory chunks to select (top-k) for each query
24 | mem_chunk_size = 32, # number of tokens in each memory chunk
25 | add_pos_enc = True # whether to add positional encoding to the memories
26 | )
27 |
28 | queries = torch.randn(1, 128, 512) # queries
29 | memories = torch.randn(1, 20000, 512) # memories, of any size
30 | mask = torch.ones(1, 20000).bool() # memory mask
31 |
32 | attended = attn(queries, memories, mask = mask) # (1, 128, 512)
33 | ```
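
Under the hood, the memories are padded and split into chunks of `mem_chunk_size` tokens, a coarse single-head attention over mean-pooled chunk summaries picks the `topk_mems` most relevant chunks for each query, and a finer multi-head attention then attends only within those selected chunks, with the per-chunk outputs combined by a softmax over the top-k scores. The memory mask marks which slots are valid, which is useful when memories are accumulated into a fixed-size buffer over time. Below is a minimal, illustrative sketch of that pattern; the buffer, `max_mems`, and `num_written` are made up for the example and are not part of the library.

```python
import torch
from htm_pytorch import HTMAttention

attn = HTMAttention(
    dim = 512,
    heads = 8,
    topk_mems = 8,
    mem_chunk_size = 32
)

# hypothetical fixed-size memory buffer, filled from the left as new memories arrive
max_mems = 4096
memories = torch.zeros(1, max_mems, 512)
valid = torch.zeros(1, max_mems).bool()

# pretend 1000 memory tokens have been written so far
num_written = 1000
memories[:, :num_written] = torch.randn(1, num_written, 512)
valid[:, :num_written] = True

queries = torch.randn(1, 128, 512)

# only the filled slots participate; empty slots are masked out of both
# the chunk summaries and the within-chunk attention
attended = attn(queries, memories, mask = valid)  # (1, 128, 512)
```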
34 |
35 | If you want the entire HTM block (which applies a layernorm to the input and adds a skip connection around the attention), just import `HTMBlock` instead
36 |
37 | ```python
38 | import torch
39 | from htm_pytorch import HTMBlock
40 |
41 | block = HTMBlock(
42 | dim = 512,
43 | topk_mems = 8,
44 | mem_chunk_size = 32
45 | )
46 |
47 | queries = torch.randn(1, 128, 512)
48 | memories = torch.randn(1, 20000, 512)
49 | mask = torch.ones(1, 20000).bool()
50 |
51 | out = block(queries, memories, mask = mask) # (1, 128, 512)
52 | ```
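
As a sketch of how the block might be wired into a model, one could interleave `HTMBlock` with an ordinary feedforward and pass the running memory at each call. The `MemoryReaderLayer` wrapper below and its feedforward are hypothetical, shown only for illustration; they are not part of this repository.

```python
import torch
from torch import nn
from htm_pytorch import HTMBlock

# hypothetical wrapper: hierarchical memory attention followed by a feedforward
class MemoryReaderLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.read_memory = HTMBlock(
            dim = dim,
            heads = 8,
            topk_mems = 8,
            mem_chunk_size = 32
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x, memories, mask = None):
        x = self.read_memory(x, memories, mask = mask)  # residual is handled inside HTMBlock
        x = self.ff(x) + x                              # residual around the feedforward
        return x

layer = MemoryReaderLayer(dim = 512)
tokens = torch.randn(1, 128, 512)
memories = torch.randn(1, 20000, 512)
mask = torch.ones(1, 20000).bool()

out = layer(tokens, memories, mask = mask)  # (1, 128, 512)
```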
53 |
54 | ## Citations
55 |
56 | ```bibtex
57 | @misc{lampinen2021mental,
58 | title = {Towards mental time travel: a hierarchical memory for reinforcement learning agents},
59 | author = {Andrew Kyle Lampinen and Stephanie C. Y. Chan and Andrea Banino and Felix Hill},
60 | year = {2021},
61 | eprint = {2105.14039},
62 | archivePrefix = {arXiv},
63 | primaryClass = {cs.LG}
64 | }
65 | ```
66 |
--------------------------------------------------------------------------------
/htm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/HTM-pytorch/db4deb58cba750c8fc66a74b98e513dc66bd5296/htm.png
--------------------------------------------------------------------------------
/htm_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from htm_pytorch.htm_pytorch import HTMAttention, HTMBlock
2 |
--------------------------------------------------------------------------------
/htm_pytorch/htm_pytorch.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange, repeat
7 |
8 | # helpers
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | def default(val, d):
14 | return val if exists(val) else d
15 |
16 | def pad_to_multiple(t, multiple, dim = -2, value = 0.):
17 | seq_len = t.shape[dim]
18 | pad_to_len = ceil(seq_len / multiple) * multiple
19 | remainder = pad_to_len - seq_len
20 |
21 | if remainder == 0:
22 | return t
23 |
24 | zeroes = (0, 0) * (-dim - 1)
25 | padded_t = F.pad(t, (*zeroes, remainder, 0), value = value)
26 | return padded_t
27 |
28 | # positional encoding
29 |
30 | class SinusoidalPosition(nn.Module):
31 | def __init__(
32 | self,
33 | dim,
34 | min_timescale = 2.,
35 | max_timescale = 1e4
36 | ):
37 | super().__init__()
38 | freqs = torch.arange(0, dim, min_timescale)
39 | inv_freqs = max_timescale ** (-freqs / dim)
40 | self.register_buffer('inv_freqs', inv_freqs)
41 |
42 | def forward(self, x):
43 | seq_len = x.shape[-2]
44 |         seq = torch.arange(seq_len - 1, -1, -1., device = x.device)  # reversed positions, created on the input's device
45 | sinusoidal_inp = rearrange(seq, 'n -> n ()') * rearrange(self.inv_freqs, 'd -> () d')
46 | pos_emb = torch.cat((sinusoidal_inp.sin(), sinusoidal_inp.cos()), dim = -1)
47 | return pos_emb
48 |
49 | # multi-head attention
50 |
51 | class Attention(nn.Module):
52 | def __init__(
53 | self,
54 | dim,
55 | dim_head = 64,
56 | heads = 8,
57 | ):
58 | super().__init__()
59 | self.scale = dim_head ** -0.5
60 | self.heads = heads
61 | inner_dim = dim_head * heads
62 |
63 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
64 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
65 | self.to_out = nn.Linear(inner_dim, dim)
66 |
67 | def forward(
68 | self,
69 | x,
70 | mems,
71 | mask = None
72 | ):
73 | h = self.heads
74 | q, k, v = self.to_q(x), *self.to_kv(mems).chunk(2, dim = -1)
75 |
76 | q, k, v = map(lambda t: rearrange(t, 'b ... (h d) -> (b h) ... d', h = h), (q, k, v))
77 | q = q * self.scale
78 |
79 | sim = einsum('b m i d, b m i j d -> b m i j', q, k)
80 |
81 | if exists(mask):
82 | mask = repeat(mask, 'b ... -> (b h) ...', h = h)
83 | mask_value = -torch.finfo(sim.dtype).max
84 | sim = sim.masked_fill(~mask, mask_value)
85 |
86 | attn = sim.softmax(dim = -1)
87 |
88 | out = einsum('... i j, ... i j d -> ... i d', attn, v)
89 | out = rearrange(out, '(b h) ... d -> b ... (h d)', h = h)
90 | return self.to_out(out)
91 |
92 | # main class
93 |
94 | class HTMAttention(nn.Module):
95 | def __init__(
96 | self,
97 | dim,
98 |         heads = 8,
99 | topk_mems = 2,
100 | mem_chunk_size = 32,
101 | dim_head = 64,
102 | add_pos_enc = True,
103 | eps = 1e-5
104 | ):
105 | super().__init__()
106 | self.dim = dim
107 | self.eps = eps
108 | self.scale = dim ** -0.5
109 |
110 | self.to_summary_queries = nn.Linear(dim, dim)
111 | self.to_summary_keys = nn.Linear(dim, dim)
112 |
113 | self.attn = Attention(dim = dim, heads = heads, dim_head = dim_head)
114 |
115 | self.topk_mems = topk_mems
116 | self.mem_chunk_size = mem_chunk_size
117 | self.pos_emb = SinusoidalPosition(dim = dim) if add_pos_enc else None
118 |
119 | def forward(
120 | self,
121 | queries,
122 | memories,
123 | mask = None,
124 | chunk_attn_mask = None
125 | ):
126 | dim, query_len, mem_chunk_size, topk_mems, scale, eps = self.dim, queries.shape[1], self.mem_chunk_size, self.topk_mems, self.scale, self.eps
127 |
128 | # pad memories, and the memory mask, if needed
129 | # and then divide into chunks
130 |
131 | memories = pad_to_multiple(memories, mem_chunk_size, dim = -2, value = 0.)
132 | memories = rearrange(memories, 'b (n c) d -> b n c d', c = mem_chunk_size)
133 |
134 | if exists(mask):
135 | mask = pad_to_multiple(mask, mem_chunk_size, dim = -1, value = False)
136 | mask = rearrange(mask, 'b (n c) -> b n c', c = mem_chunk_size)
137 |
138 | # summarize memories through mean-pool, accounting for mask
139 |
140 | if exists(mask):
141 | mean_mask = rearrange(mask, '... -> ... ()')
142 | memories = memories.masked_fill(~mean_mask, 0.)
143 | numer = memories.sum(dim = 2)
144 | denom = mean_mask.sum(dim = 2)
145 | summarized_memories = numer / (denom + eps)
146 | else:
147 | summarized_memories = memories.mean(dim = 2)
148 |
149 | # derive queries and summarized memory keys
150 |
151 | summary_queries = self.to_summary_queries(queries)
152 | summary_keys = self.to_summary_keys(summarized_memories.detach())
153 |
154 | # do a single head attention over summary keys
155 |
156 | sim = einsum('b i d, b j d -> b i j', summary_queries, summary_keys) * scale
157 | mask_value = -torch.finfo(sim.dtype).max
158 |
159 | if exists(mask):
160 | chunk_mask = mask.any(dim = 2)
161 | chunk_mask = rearrange(chunk_mask, 'b j -> b () j')
162 | sim = sim.masked_fill(~chunk_mask, mask_value)
163 |
164 | if exists(chunk_attn_mask):
165 | sim = sim.masked_fill(~chunk_attn_mask, mask_value)
166 |
167 | topk_logits, topk_indices = sim.topk(k = topk_mems, dim = -1)
168 | weights = topk_logits.softmax(dim = -1)
169 |
170 | # ready queries for in-memory attention
171 |
172 | queries = repeat(queries, 'b n d -> b k n d', k = topk_mems)
173 |
174 | # select the topk memories
175 |
176 | memories = repeat(memories, 'b m j d -> b m i j d', i = query_len)
177 | mem_topk_indices = repeat(topk_indices, 'b i m -> b m i j d', j = mem_chunk_size, d = dim)
178 | selected_memories = memories.gather(1, mem_topk_indices)
179 |
180 | # positional encoding
181 |
182 | if exists(self.pos_emb):
183 | pos_emb = self.pos_emb(memories)
184 | selected_memories = selected_memories + rearrange(pos_emb, 'n d -> () () () n d')
185 |
186 | # select the mask
187 |
188 | selected_mask = None
189 | if exists(mask):
190 | mask = repeat(mask, 'b m j -> b m i j', i = query_len)
191 | mask_topk_indices = repeat(topk_indices, 'b i m -> b m i j', j = mem_chunk_size)
192 | selected_mask = mask.gather(1, mask_topk_indices)
193 |
194 | # now do in-memory attention
195 |
196 | within_mem_output = self.attn(
197 | queries,
198 | selected_memories.detach(),
199 | mask = selected_mask
200 | )
201 |
202 | # weight the in-memory attention outputs
203 |
204 | weighted_output = within_mem_output * rearrange(weights, 'b i m -> b m i ()')
205 | output = weighted_output.sum(dim = 1)
206 | return output
207 |
208 | # HTM Block
209 |
210 | class HTMBlock(nn.Module):
211 | def __init__(self, dim, **kwargs):
212 | super().__init__()
213 | self.norm = nn.LayerNorm(dim)
214 | self.attn = HTMAttention(dim = dim, **kwargs)
215 | def forward(
216 | self,
217 | queries,
218 | memories,
219 | **kwargs
220 | ):
221 | queries = self.norm(queries)
222 | out = self.attn(queries, memories, **kwargs) + queries
223 | return out
224 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'htm-pytorch',
5 | packages = find_packages(),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'Hierarchical Transformer Memory - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/htm-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'deep learning',
15 | 'attention-mechanism',
16 | 'memory'
17 | ],
18 | install_requires=[
19 | 'einops>=0.3',
20 | 'torch>=1.6'
21 | ],
22 | classifiers=[
23 | 'Development Status :: 4 - Beta',
24 | 'Intended Audience :: Developers',
25 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26 | 'License :: OSI Approved :: MIT License',
27 | 'Programming Language :: Python :: 3.6',
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------