├── heinsen_routing ├── __init__.py └── heinsen_routing.py ├── assets └── fig_sample_credit_assignments_vision_from_paper.png ├── setup.py ├── LICENSE ├── .gitignore └── README.md /heinsen_routing/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from .heinsen_routing import EfficientVectorRouting, DefinableVectorRouting, GenerativeMatrixRouting 3 | -------------------------------------------------------------------------------- /assets/fig_sample_credit_assignments_vision_from_paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/glassroom/heinsen_routing/HEAD/assets/fig_sample_credit_assignments_vision_from_paper.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from setuptools import setup 3 | 4 | setup(name='heinsen_routing', 5 | version='1.0.4', 6 | description='Implementation of the routing algorithm proposed by Franz A. Heinsen, 2019 and 2022 variants.', 7 | url='https://github.com/glassroom/heinsen_routing', 8 | author='Franz A. Heinsen', 9 | author_email='franz@glassroom.com', 10 | license='MIT', 11 | packages=['heinsen_routing'], 12 | install_requires='torch', 13 | zip_safe=False) 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019-Present GlassRoom Software LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | python 3 | 4 | # Dataset directory 5 | .data 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /heinsen_routing/heinsen_routing.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from __future__ import annotations 4 | from typing import Union, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import einsum 9 | 10 | 11 | class EfficientVectorRouting(nn.Module): 12 | """ 13 | Routes input vectors to the output vectors that maximize "bang per bit" 14 | by best predicting them, with optimizations that reduce parameter count, 15 | memory use, and computation by orders of magnitude. Each vector is a 16 | capsule representing an entity in a context (e.g., a word in a paragraph, 17 | an object in an image). 
See "An Algorithm for Routing Vectors in 18 | Sequences" (Heinsen, 2022), https://arxiv.org/abs/2211.11754. 19 | 20 | Args: 21 | n_inp: int, number of input vectors. If -1, the number is variable. 22 | n_out: int, number of output vectors. 23 | d_inp: int, size of input vectors. 24 | d_out: int, size of output vectors. 25 | n_iters: (optional) int, number of iterations. Default: 2. 26 | normalize: (optional) bool, if True and d_out > 1, normalize each 27 | output vector's elements to mean 0 and variance 1. Default: True. 28 | memory_efficient: (optional) bool, if True, compute votes lazily to 29 | reduce memory use by O(n_inp * n_out * d_out), while increasing 30 | computation by only O(n_iters). Default: True. 31 | return_dict: (optional) bool, if True, return a dictionary with the 32 | final state of all internal and output tensors. Default: False. 33 | 34 | Input: 35 | x_inp: tensor of input vectors [..., n_inp, d_inp]. 36 | mask: (optional) bool tensor of shape [n_inp, n_out], True if an input 37 | vector will be ignored by an output vector, False everywhere else. 38 | 39 | Output: 40 | x_out: tensor of output vectors [..., n_out, d_out] by default, 41 | or a dict with output vectors as 'x_out' if return_dict is True. 42 | 43 | Sample usage: 44 | >>> # Route 100 vectors of size 1024 to 10 vectors of size 4096. 45 | >>> m = EfficientVectorRouting(n_inp=100, n_out=10, d_inp=1024, d_out=4096) 46 | >>> x_inp = torch.randn(100, 1024) # 100 vectors of size 1024 47 | >>> x_out = m(x_inp) # 10 vectors of size 4096 48 | """ 49 | def __init__(self, n_inp: int, n_out: int, d_inp: int, d_out: int, n_iters: int = 2, 50 | normalize: bool = True, memory_efficient: bool = True, return_dict: bool = False) -> None: 51 | super().__init__() 52 | assert n_inp > 0 or n_inp == -1, "Number of input vectors must be > 0 or -1 (variable)." 53 | assert n_out >= 2, "Number of output vectors must be at least 2." 
54 | one_or_n_inp = max(1, n_inp) 55 | self.n_inp, self.n_out, self.d_inp, self.d_out, self.n_iters = (n_inp, n_out, d_inp, d_out, n_iters) 56 | self.normalize, self.memory_efficient, self.return_dict = (normalize, memory_efficient, return_dict) 57 | self.register_buffer('CONST_ones_over_n_out', torch.ones(n_out) / n_out) 58 | self.W_A = nn.Parameter(torch.empty(one_or_n_inp, d_inp).normal_(std=2.0 * d_inp**-0.5)) 59 | self.B_A = nn.Parameter(torch.zeros(one_or_n_inp)) 60 | self.W_F1 = nn.Parameter(torch.empty(n_out, d_inp).normal_()) 61 | self.W_F2 = nn.Parameter(torch.empty(d_inp, d_out).normal_(std=2.0 * d_inp**-0.5)) 62 | self.B_F2 = nn.Parameter(torch.zeros(n_out, d_out)) 63 | self.W_G1 = nn.Parameter(torch.empty(d_out, d_inp).normal_(std=d_out**-0.5)) 64 | self.W_G2 = nn.Parameter(torch.empty(n_out, d_inp).normal_()) 65 | self.B_G2 = nn.Parameter(torch.zeros(n_out, d_inp)) 66 | self.W_S = nn.Parameter(torch.empty(one_or_n_inp, n_out).normal_(std=d_inp**-0.5)) 67 | self.B_S = nn.Parameter(torch.zeros(one_or_n_inp, n_out)) 68 | if n_inp > 0: 69 | self.beta_use = nn.Parameter(torch.empty(n_inp, n_out).normal_()) 70 | self.beta_ign = nn.Parameter(torch.empty(n_inp, n_out).normal_()) 71 | else: 72 | self.compute_beta_use = nn.Linear(d_inp, n_out) 73 | self.compute_beta_ign = nn.Linear(d_inp, n_out) 74 | self.N = nn.LayerNorm(d_out, elementwise_affine=False) if d_out > 1 else nn.Identity() 75 | self.f, self.log_f, self.softmax = (nn.Sigmoid(), nn.LogSigmoid(), nn.Softmax(dim=-1)) 76 | 77 | def __repr__(self) -> str: 78 | cfg_str = ', '.join(f'{s}={getattr(self, s)}' for s in 'n_inp n_out d_inp d_out n_iters normalize memory_efficient return_dict'.split()) 79 | return '{}({})'.format(self._get_name(), cfg_str) 80 | 81 | def forward(self, x_inp: torch.Tensor, mask: Union[torch.BoolTensor, None] = None) -> Union[torch.Tensor, dict]: 82 | beta_use, beta_ign = (self.beta_use, self.beta_ign) if hasattr(self, 'beta_use') else (self.compute_beta_use(x_inp), self.compute_beta_ign(x_inp)) 83 | scaled_x_inp = x_inp * x_inp.shape[-2]**-0.5 # [...id] 84 | a_inp = (scaled_x_inp * self.W_A).sum(dim=-1) + self.B_A # [...i] 85 | V = None if self.memory_efficient else einsum('...id,jd,dh->...ijh', scaled_x_inp, self.W_F1, self.W_F2) + self.B_F2 86 | f_a_inp = self.f(a_inp).unsqueeze(-1) if mask is None else self.f(a_inp.unsqueeze(-1).masked_fill(mask, float('-inf'))) # [...i1] or [...ij] 87 | for iter_num in range(self.n_iters): 88 | 89 | # E-step. 90 | if iter_num == 0: 91 | R = self.CONST_ones_over_n_out if mask is None else self.softmax(self.CONST_ones_over_n_out.log().masked_fill(mask, float('-inf'))) 92 | else: 93 | pred_x_inp = einsum('...jh,hd,jd->...jd', self.N(x_out), self.W_G1, self.W_G2) + self.B_G2 94 | S = self.log_f(einsum('...id,...jd->...ij', x_inp, pred_x_inp) * self.W_S + self.B_S) 95 | R = self.softmax(S) if mask is None else self.softmax(S.masked_fill(mask, float('-inf'))) # [...ij] 96 | 97 | # D-step. 98 | D_use = f_a_inp * R # [...ij] 99 | D_ign = f_a_inp - D_use # [...ij] 100 | 101 | # M-step. 
102 | phi = beta_use * D_use - beta_ign * D_ign # [...ij] "bang per bit" coefficients 103 | if V is None: 104 | _einsum_phi_scaled_x_inp = einsum('...ij,...id->...jd', phi, scaled_x_inp) 105 | x_out = einsum('...jd,jd,dh->...jh', _einsum_phi_scaled_x_inp, self.W_F1, self.W_F2) + einsum('...ij,jh->...jh', phi, self.B_F2) 106 | else: 107 | x_out = einsum('...ij,...ijh->...jh', phi, V) 108 | 109 | if self.normalize: 110 | x_out = self.N(x_out) 111 | 112 | if self.return_dict: 113 | return { 'a_inp': a_inp, 'V': V, 'pred_x_inp': pred_x_inp, 'S': S, 'R': R, 'D_use': D_use, 'D_ign': D_ign, 'phi': phi, 'x_out': x_out } 114 | else: 115 | return x_out 116 | 117 | 118 | class DefinableVectorRouting(nn.Module): 119 | """ 120 | Routes input vectors to the output vectors that maximize "bang per bit" 121 | by best predicting them, as specified by four neural networks provided 122 | as PyTorch module instances. Each vector is a capsule representing an 123 | entity in a context (e.g., a word in a sentence, an object in an image). 124 | See "An Algorithm for Routing Vectors in Sequences" (Heinsen, 2022). 125 | https://arxiv.org/abs/2211.11754. 126 | 127 | Args: 128 | A: nn.Module instance that accepts input vectors and computes input 129 | vector activation scores: [..., n_inp, d_inp] -> [..., n_inp, 1]. 130 | F: nn.Module instance that accepts input vectors and proposes output 131 | sequences: [..., n_inp, d_inp] -> [..., n_inp, n_out, d_out]. 132 | G: nn.Module instance that accepts output vectors and predicts input 133 | vectors: [..., n_out, d_out] -> [..., n_out, d_inp]. G can be a 134 | generative model that samples from a parametrized distribution. 135 | S: nn.Module instance that accepts actual and predicted input vectors 136 | (the latter computed by G) and computes their similary scores: 137 | [..., n_inp, d_inp], [...,n_out, d_inp] -> [..., n_inp , n_out]. 138 | If G is generative, S should ideally compute the log-probability 139 | density of each actual input vector given each predicted one. 140 | n_inp: int, number of input vectors. If -1, the number is variable (in 141 | which case A, F, and S must be able to handle a variable number). 142 | n_out: int, number of output vectors. 143 | n_iters: (optional) int, number of iterations. Default: 2. 144 | return_dict: (optional) bool, if True, return a dictionary with the 145 | final state of all internal and output tensors. Default: False. 146 | 147 | Input: 148 | x_inp: tensor of input vectors [..., n_inp, d_inp]. 149 | mask: (optional) bool tensor of shape [n_inp, n_out], True if an input 150 | vector will be ignored by an output vector, False everywhere else. 151 | 152 | Output: 153 | x_out: tensor of output vectors [..., n_out, d_out] by default, 154 | or a dict with output vectors as 'x_out' if return_dict is True. 155 | 156 | Sample usage: 157 | >>> class LearnedMemories(nn.Module): 158 | >>> def __init__(self, n_inp, n_out, d_out): 159 | >>> super().__init__() 160 | >>> self.W_mem = nn.Parameter(torch.randn(n_inp, n_out, d_out)) 161 | >>> def forward(self, x_inp): 162 | >>> return self.W_mem.expand(*x_inp.shape[:-2], -1, -1, -1) 163 | >>> 164 | >>> class DotProductSimilarities(nn.Module): 165 | >>> def __init__(self): 166 | >>> super().__init__() 167 | >>> def forward(self, true_x_inp, pred_x_inp): 168 | >>> scaling = true_x_inp.shape[-1]**-0.5 169 | >>> return true_x_inp @ pred_x_inp.transpose(-2, -1) * scaling 170 | >>> 171 | >>> # Route 100 vectors of size 1024 to 10 vectors of size 4096. 
172 | >>> m = DefinableVectorRouting( 173 | >>> A=nn.Linear(1024, 1), 174 | >>> F=LearnedMemories(100, 10, 4096), 175 | >>> G=nn.Linear(4096, 1024), 176 | >>> S=DotProductSimilarities(), 177 | >>> n_inp=100, n_out=10) 178 | >>> 179 | >>> x_inp = torch.randn(100, 1024) # 100 vectors of size 1024 180 | >>> x_out = m(x_inp) # 10 vectors of size 4096 181 | """ 182 | def __init__(self, A: nn.Module, F: nn.Module, G: nn.Module, S: nn.Module, n_inp: int, n_out: int, n_iters: int = 2, return_dict: bool = False) -> None: 183 | super().__init__() 184 | assert n_inp > 0 or n_inp == -1, "Number of input vectors must be > 0 or -1 (variable)." 185 | assert n_out >= 2, "Number of output vectors must be at least 2." 186 | self.n_inp, self.n_out = (n_inp, n_out) 187 | self.A, self.F, self.G, self.S, self.n_iters, self.return_dict = (A, F, G, S, n_iters, return_dict) 188 | self.register_buffer('CONST_ones_over_n_out', torch.ones(n_out) / n_out) 189 | if n_inp > 0: 190 | self.beta_use = nn.Parameter(torch.empty(n_inp, n_out).normal_()) 191 | self.beta_ign = nn.Parameter(torch.empty(n_inp, n_out).normal_()) 192 | else: 193 | self.compute_beta_use = nn.Linear(d_inp, n_out) 194 | self.compute_beta_ign = nn.Linear(d_inp, n_out) 195 | self.f, self.softmax = (nn.Sigmoid(), nn.Softmax(dim=-1)) 196 | 197 | def __repr__(self) -> str: 198 | cfg_str = ',\n '.join(f'{s}={getattr(self, s)}' for s in 'A F G S n_inp n_out n_iters return_dict'.split()) 199 | return '{}({})'.format(self._get_name(), cfg_str) 200 | 201 | def forward(self, x_inp: torch.Tensor, mask: Union[torch.BoolTensor, None] = None) -> Union[torch.Tensor, dict]: 202 | beta_use, beta_ign = (self.beta_use, self.beta_ign) if hasattr(self, 'beta_use') else (self.compute_beta_use(x_inp), self.compute_beta_ign(x_inp)) 203 | a_inp = self.A(x_inp).view(*x_inp.shape[:-1]) # [...i] 204 | V = self.F(x_inp).view(*a_inp.shape, self.n_out, -1) # [...ijh] 205 | f_a_inp = self.f(a_inp).unsqueeze(-1) if mask is None else self.f(a_inp.unsqueeze(-1).masked_fill(mask, float('-inf'))) # [...i1] or [...ij] 206 | for iter_num in range(self.n_iters): 207 | 208 | # E-step. 209 | if iter_num == 0: 210 | R = self.CONST_ones_over_n_out if mask is None else self.softmax(self.CONST_ones_over_n_out.log().masked_fill(mask, float('-inf'))) 211 | else: 212 | pred_x_inp = self.G(x_out) # [...jd] 213 | S = self.S(x_inp, pred_x_inp) # [...ij] 214 | R = self.softmax(S) if mask is None else self.softmax(S.masked_fill(mask, float('-inf'))) # [...ij] 215 | 216 | # D-step. 217 | D_use = f_a_inp * R # [...ij] 218 | D_ign = f_a_inp - D_use # [...ij] 219 | 220 | # M-step. 221 | phi = beta_use * D_use - beta_ign * D_ign # [...ij] "bang per bit" coefficients 222 | x_out = einsum('...ij,...ijh->...jh', phi, V) 223 | 224 | if self.return_dict: 225 | return { 'a_inp': a_inp, 'V': V, 'pred_x_inp': pred_x_inp, 'S': S, 'R': R, 'D_use': D_use, 'D_ign': D_ign, 'phi': phi, 'x_out': x_out } 226 | else: 227 | return x_out 228 | 229 | 230 | class GenerativeMatrixRouting(nn.Module): 231 | """ 232 | Routes input matrices to output matrices generated by probabilistic models 233 | that best explain certain transformations of the given input matrices. Each 234 | matrix is a capsule representing an entity in a context (e.g., a word in a 235 | sentence, an object in an image). Both input and output matrices are paired 236 | with an activation score quantifying detection of the corresponding entity. 237 | See "An Algorithm for Routing Capsules in All Domains" (Heinsen, 2019), 238 | https://arxiv.org/abs/1911.00792. 
239 | 240 | Args: 241 | n_inp: int, number of input matrices. If -1, the number is variable. 242 | n_out: int, number of output matrices. 243 | d_cov: int, dimension 1 of input and output matrices. 244 | d_inp: int, dimension 2 of input matrices. 245 | d_out: int, dimension 2 of output matrices. 246 | n_iters: (optional) int, number of routing iterations. Default is 3. 247 | single_beta: (optional) bool, if True, beta_use (net benefits per unit 248 | of data) and beta_ign (net costs) are the same. Default: False. 249 | p_model: (optional) str, specifies how to compute probability of input 250 | votes at each output matrix. Choices are 'gaussian' for Gaussian 251 | mixtures and 'skm' for soft k-means. Default: 'gaussian'. 252 | eps: (optional) small positive float << 1.0 for numerical stability. 253 | return_dict: (optional) bool, if True, return a dictionary with the 254 | final state of all internal and output tensors. Default: False. 255 | Input: 256 | a_inp: [..., n_inp] input scores. 257 | mu_inp: [..., n_inp, d_cov, d_inp] matrices of shape d_cov x d_inp. 258 | Output: 259 | If return_dict is False (default): 260 | a_out: [..., n_out] output scores. 261 | mu_out: [..., n_out, d_cov, d_out] matrices, each d_cov x d_out. 262 | sig2_out: [..., n_out, d_cov, d_out] matrices, each d_cov x d_out. 263 | Otherwise: 264 | Python dict with multiple tensors, including the default output 265 | tensors as keys 'a_out', 'mu_out', and 'sig2_out'. 266 | Sample usage: 267 | >>> a_inp = torch.randn(100) # 100 input scores 268 | >>> mu_inp = torch.randn(100, 4, 4) # 100 capsules of shape 4 x 4 269 | >>> m = GenerativeMatrixRouting( 270 | >>> d_cov=4, d_inp=4, d_out=4, n_inp=100, n_out=10) 271 | >>> a_out, mu_out, sig2_out = m(a_inp, mu_inp) 272 | >>> print(a_out) # 10 activation scores 273 | >>> print(mu_out) # 10 matrices of shape 4 x 4 (means) 274 | >>> print(sig2_out) # 10 matrices of shape 4 x 4 (variances) 275 | """ 276 | def __init__(self, n_inp: int, n_out: int, d_cov: int, d_inp: int, d_out: int, n_iters: int = 3, 277 | single_beta: bool = False, p_model: str ='gaussian', eps: float = 1e-5, return_dict: bool = False) -> None: 278 | super().__init__() 279 | assert n_inp > 0 or n_inp == -1, "Number of input matrices must be > 0 or -1 (variable)." 280 | assert n_out >= 2, "Number of output matrices must be at least 2." 281 | assert p_model in ['gaussian', 'skm'], 'Unrecognized value for p_model.' 
282 | one_or_n_inp = max(1, n_inp) 283 | self.n_inp, self.n_out, self.d_cov, self.d_inp, self.d_out, self.n_iters = (n_inp, n_out, d_cov, d_inp, d_out, n_iters) 284 | self.single_beta, self.p_model, self.eps, self.return_dict = (single_beta, p_model, eps, return_dict) 285 | self.n_inp_is_fixed = n_inp > 0 286 | self.register_buffer('CONST_one', torch.tensor(1.0)) 287 | self.W = nn.Parameter(torch.empty(one_or_n_inp, n_out, d_inp, d_out).normal_() / d_inp) 288 | self.B = nn.Parameter(torch.zeros(one_or_n_inp, n_out, d_cov, d_out)) 289 | self.beta_use = nn.Parameter(torch.zeros(one_or_n_inp, n_out)) 290 | self.beta_ign = nn.Parameter(torch.zeros(one_or_n_inp, n_out)) if not single_beta else self.beta_use 291 | self.f, self.log_f, self.softmax, self.log_softmax = (nn.Sigmoid(), nn.LogSigmoid(), nn.Softmax(dim=-1), nn.LogSoftmax(dim=-1)) 292 | 293 | def __repr__(self) -> str: 294 | cfg_str = ', '.join(f'{s}={getattr(self, s)}' for s in 'n_inp n_out d_cov d_inp d_out n_iters single_beta p_model eps return_dict'.split()) 295 | return '{}({})'.format(self._get_name(), cfg_str) 296 | 297 | def forward(self, a_inp: torch.Tensor, mu_inp: torch.Tensor) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], dict]: 298 | W = self.W if self.n_inp_is_fixed else self.W.expand(a_inp.shape[-1], -1, -1, -1) 299 | V = einsum('...icd,ijdh->...ijch', mu_inp, W) + self.B 300 | f_a_inp = self.f(a_inp).unsqueeze(-1) # [...i1] 301 | for iter_num in range(self.n_iters): 302 | 303 | # E-step. 304 | if iter_num == 0: 305 | R = (self.CONST_one / self.n_out).expand(V.shape[:-2]) # [...ij] 306 | else: 307 | if self.p_model == 'gaussian': 308 | log_p = - einsum('...ijch,...jch->...ij', V_less_mu_out_2, 1.0 / (2.0 * sig2_out)) - sig2_out.sqrt().log().sum((-2, -1)).unsqueeze(-2) 309 | else: 310 | log_p = self.log_softmax(-V_less_mu_out_2.sum((-2, -1))) # soft k-means 311 | R = self.softmax(self.log_f(a_out).unsqueeze(-2) + log_p) # [...ij] 312 | 313 | # D-step. 314 | D_use = f_a_inp * R # [...ij] 315 | D_ign = f_a_inp - D_use # [...ij] 316 | 317 | # M-step. 318 | a_out = (self.beta_use * D_use).sum(dim=-2) - (self.beta_ign * D_ign).sum(dim=-2) # [...j] "bang per bit" activation scores 319 | normalized_D_use = D_use / (D_use.sum(dim=-2, keepdims=True) + self.eps) # [...ij] 320 | mu_out = einsum('...ij,...ijch->...jch', normalized_D_use, V) 321 | V_less_mu_out_2 = (V - mu_out.unsqueeze(-4)) ** 2 # [...ijch] 322 | sig2_out = einsum('...ij,...ijch->...jch', normalized_D_use, V_less_mu_out_2) + self.eps 323 | 324 | if self.return_dict: 325 | return { 'V': V, 'log_p': log_p, 'R': R, 'D_use': D_use, 'D_ign': D_ign, 'a_out': a_out, 'mu_out': mu_out, 'sig2_out': sig2_out } 326 | else: 327 | return a_out, mu_out, sig2_out 328 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # heinsen_routing 2 | 3 | Reference implementation of "[An Algorithm for Routing Vectors in Sequences](https://arxiv.org/abs/2211.11754)" (Heinsen, 2022), and an earlier variant, "[An Algorithm for Routing Capsules in All Domains](https://arxiv.org/abs/1911.00792)" (Heinsen, 2019), for computing a sequence of capsules (e.g., vectors, matrices) that best explain (e.g., predict, generate) a given input sequence. 
4 | 5 | A toy example is helpful for conveying quickly what the algorithm does: 6 | 7 | ```python 8 | import torch 9 | from heinsen_routing import EfficientVectorRouting as Routing 10 | 11 | model = torch.nn.Sequential( 12 | Routing(n_inp=10000, n_out=1000, d_inp=1024, d_out=2048), 13 | Routing(n_inp= 1000, n_out= 100, d_inp=2048, d_out=3072), 14 | Routing(n_inp= 100, n_out= 10, d_inp=3072, d_out=4096), 15 | ) 16 | 17 | x_inp = torch.randn(10_000, 1024) # 10,000 vectors of size 1024 (e.g., embedded tokens) 18 | x_out = model(x_inp) # 10 vectors of size 4096 that best explain x_inp 19 | ``` 20 | 21 | For instructions to route very long sequences (e.g., 1,000,000+ token embeddings in 18GB of VRAM), see [here](#routing-very-long-sequences). For an example of end-to-end credit assignment, see [here](#assigning-credit-end-to-end). For replicating published results, see [here](#replicating-published-results). 22 | 23 | 24 | ## Table of Contents 25 | 26 | * [Installing](#installing) 27 | 28 | * [How Does it Work?](#how-does-it-work) 29 | 30 | * [Variants of the Algorithm in this Repository](#variants-of-the-algorithm-in-this-repository) 31 | * [EfficientVectorRouting](#efficientvectorrouting) 32 | * [DefinableVectorRouting](#definablevectorrouting) 33 | * [GenerativeMatrixRouting](#generativematrixrouting) 34 | 35 | * [Sample Usage of EfficientVectorRouting](#sample-usage-of-efficientvectorrouting) 36 | * [Sequence to Sequence](#sequence-to-sequence) 37 | * [Sequence to Vector](#sequence-to-vector) 38 | * [Routing Sequences of Varying Length](#routing-sequences-of-varying-length) 39 | * [Routing Very Long Sequences](#routing-very-long-sequences) 40 | * [Recurrent Routings](#recurrent-routings) 41 | * [Masking Input from Output Vectors](#masking-input-from-output-vectors) 42 | * [Assigning Credit to Input Vectors](#assigning-credit-to-input-vectors) 43 | * [Assigning Credit End-to-End](#assigning-credit-end-to-end) ([Example](#example-of-end-to-end-credit-assignment)) 44 | 45 | * [Frequently Asked Questions](#frequently-Asked-Questions) 46 | 47 | * [Replicating Published Results](#replicating-published-results) 48 | 49 | * [Notes](#notes) 50 | 51 | * [Citing](#citing) 52 | 53 | 54 | ## Installing 55 | 56 | ``` 57 | pip install git+https://github.com/glassroom/heinsen_routing 58 | ``` 59 | 60 | Alternatively, you can download a single file to your project directory: [heinsen_routing.py](heinsen_routing/heinsen_routing.py). 61 | 62 | The only dependency is PyTorch. 63 | 64 | 65 | ## How Does it Work? 66 | 67 | Our routing algorithm takes a sequence of `n_inp` input capsules and computes a new sequence of `n_out` output capsules. A capsule is a group of artificial neurons, such as a vector or a matrix, representing the properties of an entity in a context (e.g., a word in a paragraph, an object in an image, a topic in a conversation). Each input and output capsule represents (i.e., is a *symbol* for) a different entity. 68 | 69 | The algorithm is iterative. In each iteration, we update the state of all output capsules in parallel. Each output capsule maximizes "bang per bit," the difference between a net benefit to use and net cost to ignore data, by better explaining (e.g., predicting, generating) the input capsules. The output sequence's final state maximizes "bang per bit" by best explaining the input sequence. 70 | 71 | The algorithm is differentiable. 
When you train it with stochastic gradient descent, it learns to compute the output sequence that best explains the input sequence, in order to minimize the training loss you specify. 72 | 73 | 74 | ## Variants of the Algorithm in this Repository 75 | 76 | This repository contains three variants of our routing algorithm, implemented as PyTorch modules: 77 | 78 | ### EfficientVectorRouting 79 | 80 | `EfficientVectorRouting` is the efficient implementation proposed in "[An Algorithm for Routing Vectors in Sequences](https://arxiv.org/abs/2211.11754)" (Heinsen, 2022). It incorporates optimizations that reduce parameter count, memory use, and computation by orders of magnitude compared to the other two variants, making it the best choice for most use cases. This README focuses primarily on this PyTorch module. *If you're not sure which module you should use, we recommend this one.* [See the next section for sample usage](#sample-usage-of-efficientvectorrouting). 81 | 82 | ### DefinableVectorRouting 83 | 84 | `DefinableVectorRouting` implements the general form of "[An Algorithm for Routing Vectors in Sequences](https://arxiv.org/abs/2211.11754)" (Heinsen, 2022). It requires you to define, instantiate, and pass at initialization four PyTorch modules named A, F, G, and S (corresponding to neural networks A, F, G, and S in the paper), which specify routing behavior. In principle, you could define A, F, G, and S to replicate `EfficientVectorRouting`'s behavior -- but not its optimizations. See the module's docstring for sample usage. 85 | 86 | ### GenerativeMatrixRouting 87 | 88 | `GenerativeMatrixRouting` implements the original variant, "[An Algorithm for Routing Capsules in All Domains](https://arxiv.org/abs/1911.00792)" (Heinsen, 2019). It routes matrices as the capsules, and uses Gaussian mixture models to generate the output matrices, weighted by separate activations that maximize "bang per bit." This module is the least scalable of the three, so we now recommend using it mainly for small-scale tasks. See the module's docstring for sample usage. 89 | 90 | 91 | ## Sample Usage of EfficientVectorRouting 92 | 93 | ### Sequence to Sequence 94 | 95 | `EfficientVectorRouting` takes a sequence of input vectors `[..., n_inp, d_inp]` and computes a sequence of output vectors `[..., n_out, d_out]`, where "`...`" denotes zero or more preserved dimensions. Each vector is a capsule representing a different entity. The output sequence maximizes "bang per bit" by best predicting (i.e., explaining) the given input sequence: 96 | 97 | ```python 98 | import torch 99 | from heinsen_routing import EfficientVectorRouting as Routing 100 | 101 | batch_sz = 4 102 | n_inp, d_inp = (2000, 1024) # input seqs will have 2000 vectors of size 1024 103 | n_out, d_out = (1000, 2048) # we will route them to 1000 vectors of size 2048 104 | 105 | model = Routing(n_inp=n_inp, n_out=n_out, d_inp=d_inp, d_out=d_out) 106 | 107 | x_inp = torch.randn(batch_sz, n_inp, d_inp) # shape is [batch_sz, n_inp, d_inp] 108 | x_out = model(x_inp) # shape is [batch_sz, n_out, d_out] 109 | ``` 110 | 111 | 112 | ### Sequence to Vector 113 | 114 | If you set `d_out` equal to 1, `EfficientVectorRouting` routes each sequence of input vectors to a sequence of one-dimensional vectors, or scalars, which you can concatenate into a single vector. 
Each scalar value is a capsule representing a different entity: 115 | 116 | ```python 117 | import torch 118 | from heinsen_routing import EfficientVectorRouting as Routing 119 | 120 | batch_sz = 4 121 | n_inp, d_inp = (1000, 1024) # input seqs will have 1000 vectors of size 1024 122 | d_vec = 2048 # we will route each seq to a vector of size 2048 123 | 124 | model = Routing(n_inp=n_inp, n_out=d_vec, d_inp=d_inp, d_out=1) 125 | 126 | x_inp = torch.randn(batch_sz, n_inp, d_inp) # shape is [batch_sz, n_inp, d_inp] 127 | x_out = model(x_inp).squeeze(-1) # shape is [batch_sz, d_vec] 128 | ``` 129 | 130 | 131 | ### Routing Sequences of Varying Length 132 | 133 | If you set `n_inp` equal to -1, `EfficientVectorRouting` routes input sequences of *any* length, limited only by available memory, to output sequences of fixed length. *If the order of input vectors matters, embed position information in them beforehand.* Train the module with input sequences of varying lengths and it will learn to predict (explain) them. Here, we route sequences of random length: 134 | 135 | 136 | ```python 137 | import torch 138 | from heinsen_routing import EfficientVectorRouting as Routing 139 | 140 | batch_sz = 4 141 | n_inp, d_inp = ( -1, 1024) # length of input seq will vary 142 | n_out, d_out = (1000, 1024) # length of output seq is fixed 143 | 144 | model = Routing(n_inp=n_inp, n_out=n_out, d_inp=d_inp, d_out=d_out) 145 | 146 | random_len = torch.randint(1, 10_000, []).item() 147 | x_inp = torch.randn(batch_sz, random_len, d_inp) # seqs of random length; order unimportant 148 | x_out = model(x_inp) # seqs of fixed length 149 | ``` 150 | 151 | Note: When `n_inp` is set to -1, `EfficientVectorRouting` treats every input vector as being in the same shared feature space (i.e., each element represents the same feature in all input vectors). In contrast, when `n_inp` is fixed, the module treats each input vector differently, so each input vector *may* be in a different feature space (i.e., the same element may represent a different feature in each input vector). 152 | 153 | 154 | ### Routing Very Long Sequences 155 | 156 | `EfficientVectorRouting`'s memory footprint is linear in each of `n_inp`, `n_out`, `d_inp`, and `d_out`, giving you fine-grained control over memory consumption. To route input sequences of greater length, you can reduce the length of the output sequence, and vice versa. For example, here we route an input sequence with 1,000,000 vectors at full (32-bit) precision, keeping track of gradients for backpropagation, consuming under ~18GB of memory (on a recent Nvidia GPU, excluding ~1GB of PyTorch and CUDA overhead): 157 | 158 | ```python 159 | import torch 160 | from heinsen_routing import EfficientVectorRouting as Routing 161 | 162 | n_inp, d_inp = (1_000_000, 1024) # very long input seq 163 | n_out, d_out = ( 100, 1024) # short output seq 164 | 165 | model = Routing(n_inp=n_inp, n_out=n_out, d_inp=d_inp, d_out=d_out) # uses ~6GB of memory 166 | 167 | x_inp = torch.randn(n_inp, d_inp) # uses ~4GB of memory 168 | x_out = model(x_inp) # uses ~8GB of memory 169 | ``` 170 | 171 | A handy technique for routing very long sequences is to route them *twice*: First to a short hidden sequence of larger vectors representing "summaries," and then to an output sequence with the desired shape. The motivation is to reduce the computation incurred and memory allocated by the execution of `n_inp` Softmax functions, each over `n_out` proposed output vectors, in each iteration of the first routing. 
Here, we apply this technique to route 250,000 to 1,000 vectors of size 1024 at 32-bit precision, keeping track of gradients, requiring only ~5.6GB of memory (on a recent Nvidia GPU, excluding ~1GB of PyTorch and CUDA overhead): 172 | 173 | ```python 174 | import torch 175 | import torch.nn as nn 176 | from heinsen_routing import EfficientVectorRouting as Routing 177 | 178 | n_inp, d_inp = (250_000, 1024) # very long input seq 179 | n_hid, d_hid = ( 100, 10240) # short hidden seq of higher-dimensional "summaries" 180 | n_out, d_out = ( 1000, 1024) # output seq with final desired shape 181 | 182 | model = nn.Sequential( 183 | Routing(n_inp=n_inp, n_out=n_hid, d_inp=d_inp, d_out=d_hid), # "summarize" 184 | Routing(n_inp=n_hid, n_out=n_out, d_inp=d_hid, d_out=d_out), # "rewrite" 185 | ) 186 | 187 | x_inp = torch.randn(n_inp, d_inp) 188 | x_out = model(x_inp) 189 | ``` 190 | 191 | Note: If the long input sequences have varying lengths, or if your application will work well with all input vectors in each sequence being in the same feature space, you can set `n_inp` to -1 to reduce parameter count and memory footprint. For example, in the code snippet above, you can replace `model` with: 192 | 193 | ```python 194 | model = nn.Sequential( 195 | Routing(n_inp=-1, n_out=n_hid, d_inp=d_inp, d_out=d_hid), # "summarize" 196 | Routing(n_inp=-1, n_out=n_out, d_inp=d_hid, d_out=d_out), # "rewrite" 197 | ) 198 | ``` 199 | 200 | and the memory required to route 250,000 to 1,000 vectors of size 1024 at 32-bit precision, keeping track of gradients, shrinks to ~3.9GB (on a recent Nvidia GPU, excluding ~1GB of PyTorch and CUDA overhead). 201 | 202 | 203 | ### Recurrent Routings 204 | 205 | You can apply `EfficientVectorRouting` recurrently, inducing it each time to compute a new sequence that best predicts (explains) the previously computed sequence. You can also apply recurrent routings as *residuals*, inducing the module to compute new residual sequences that best predict (explain) the recurrent accumulation of all previous sequences. For example: 206 | 207 | ```python 208 | import torch 209 | import torch.nn as nn 210 | from heinsen_routing import EfficientVectorRouting as Routing 211 | 212 | class RecurrentResidualRoutings(nn.Module): 213 | 214 | def __init__(self, n_emb, d_emb, n_iters, seq_compression_factor=8): 215 | super().__init__() 216 | n_hid = max(2, n_emb // seq_compression_factor) 217 | self.n_iters = n_iters 218 | self.normalize = nn.LayerNorm(d_emb) 219 | self.residualize = nn.Sequential( 220 | Routing(n_inp=n_emb, n_out=n_hid, d_inp=d_emb, d_out=d_emb), 221 | Routing(n_inp=n_hid, n_out=n_emb, d_inp=d_emb, d_out=d_emb), 222 | ) 223 | 224 | def forward(self, x): 225 | for _ in range(self.n_iters): 226 | x = self.normalize(x) 227 | x = x + self.residualize(x) 228 | return x 229 | 230 | batch_sz = 4 231 | n_emb, d_emb = (1000, 1024) 232 | 233 | model = RecurrentResidualRoutings(n_emb, d_emb, n_iters=5) 234 | 235 | x_inp = torch.randn(batch_sz, n_emb, d_emb) 236 | x_out = model(x_inp) 237 | ``` 238 | 239 | 240 | ### Masking Input from Output Vectors 241 | 242 | You can mask input vectors differently for each output vector by passing a boolean mask of shape `[n_inp, n_out]` as an input to the forward pass. True values mask input vector data and False values don't. For example, here we use masking to restrict `EfficientVectorRouting` to model only causal relationships over an ordered sequence of vectors, each representing a token. 
For each output vector in a given position, the module can route data only from input vectors located at or before that position: 243 | 244 | ```python 245 | import torch 246 | from heinsen_routing import EfficientVectorRouting as Routing 247 | 248 | batch_sz = 4 249 | n_tok, d_tok = (100, 1024) 250 | causal_mask = torch.tril(torch.ones(n_tok, n_tok), diagonal=-1).bool() 251 | 252 | model = Routing(n_inp=n_tok, n_out=n_tok, d_inp=d_tok, d_out=d_tok) 253 | 254 | tok_embs = torch.randn(batch_sz, n_tok, d_tok) # normally provided by a model 255 | pos_embs = torch.randn(n_tok, d_tok) # normally provided by a model 256 | 257 | x_inp = torch.nn.functional.layer_norm(tok_embs + pos_embs, [d_tok]) 258 | x_out = model(x_inp, mask=causal_mask) 259 | ``` 260 | 261 | Note: The mask does *not* have to be square. For example, you can specify a causal mask of shape `[n_ltc + n_tok, n_tok]`, where `n_ltc` is a specified number of vectors providing long-term context to all tokens, `n_ltc + n_tok` is the number of input vectors, and `n_tok` is the number of output vectors.
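Here is a minimal sketch of how such a rectangular mask might be built (the value of `n_ltc` and the tensor names below are illustrative choices, not part of the library):

```python
import torch
from heinsen_routing import EfficientVectorRouting as Routing

n_ltc = 16                  # illustrative number of long-term-context vectors
n_tok, d_tok = (100, 1024)

ltc_rows = torch.zeros(n_ltc, n_tok, dtype=torch.bool)               # context vectors are never masked
tok_rows = torch.tril(torch.ones(n_tok, n_tok), diagonal=-1).bool()  # token vectors are masked causally
rect_mask = torch.cat([ltc_rows, tok_rows], dim=0)                   # shape is [n_ltc + n_tok, n_tok]

model = Routing(n_inp=n_ltc + n_tok, n_out=n_tok, d_inp=d_tok, d_out=d_tok)

x_inp = torch.randn(n_ltc + n_tok, d_tok)  # long-term context followed by token embeddings
x_out = model(x_inp, mask=rect_mask)       # shape is [n_tok, d_tok]
```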
262 | 263 | 264 | ### Assigning Credit to Input Vectors 265 | 266 | Each instance of `EfficientVectorRouting` internally computes a credit assignment matrix of shape `[..., n_inp, n_out]`, consisting of the credit assigned to each input vector by each output vector. To obtain the credit assignments, instantiate the module with `return_dict=True`, and it will return a dictionary with output vectors as key `'x_out'` and credit assignments as key `'phi'`. For example: 267 | 268 | ```python 269 | import torch 270 | from heinsen_routing import EfficientVectorRouting as Routing 271 | 272 | batch_sz = 4 273 | n_inp, d_inp = (100, 1024) 274 | n_out, d_out = ( 10, 1024) 275 | 276 | model = Routing(n_inp=n_inp, n_out=n_out, d_inp=d_inp, d_out=d_out, return_dict=True) 277 | 278 | x_inp = torch.randn(batch_sz, n_inp, d_inp) 279 | outputs = model(x_inp) 280 | 281 | x_out = outputs['x_out'] # [batch_sz, n_out, d_out] output vectors 282 | phi = outputs['phi'] # [batch_sz, n_inp, n_out] credit assigned to input by output vecs 283 | ``` 284 | 285 | ### Assigning Credit End-to-End 286 | 287 | The credit assignments are additive, like Shapley values, and composable on their own, independently of data transformations, making it possible to compute end-to-end credit assignments over a network of routings, as explained in Subsection 3.2 of [the paper](https://arxiv.org/abs/2211.11754). For "how-to" recipes to compute end-to-end credit assignments over common compositions, including residual layers, see Appendix A of the same paper. For example: 288 | 289 | ```python 290 | import torch 291 | import torch.nn as nn 292 | from heinsen_routing import EfficientVectorRouting as Routing 293 | 294 | class SequentialRoutingsWithCreditAssignments(nn.Module): 295 | """ 296 | Apply routings sequentially and compute end-to-end credit assignments 297 | by following the recipe for sequential routings in Appendix A of "An 298 | Algorithm for Routing Vectors in Sequences" (Heinsen, 2022). 299 | """ 300 | def __init__(self, kwds_by_routing, eps=1e-5): 301 | super().__init__() 302 | self.eps = eps 303 | self.routings = nn.ModuleList( 304 | [Routing(**kwds, return_dict=True) for kwds in kwds_by_routing] 305 | ) 306 | 307 | def forward(self, x): 308 | prod = None 309 | for routing in self.routings: 310 | outputs = routing(x) 311 | x = outputs['x_out'] 312 | phi = outputs['phi'] 313 | prod = phi if prod is None else prod @ phi # chain of matrix products 314 | prod = prod / (prod.std() + self.eps) # scale the matrix products 315 | return x, prod 316 | 317 | kwds_by_routing = [ 318 | { 'n_inp': 500, 'n_out': 400, 'd_inp': 1024, 'd_out': 1024, }, 319 | { 'n_inp': 400, 'n_out': 300, 'd_inp': 1024, 'd_out': 1024, }, 320 | { 'n_inp': 300, 'n_out': 200, 'd_inp': 1024, 'd_out': 1024, }, 321 | { 'n_inp': 200, 'n_out': 100, 'd_inp': 1024, 'd_out': 1024, }, 322 | ] 323 | model = SequentialRoutingsWithCreditAssignments(kwds_by_routing) 324 | 325 | x_inp = torch.randn(kwds_by_routing[0]['n_inp'], kwds_by_routing[0]['d_inp']) 326 | x_out, credit_assignments = model(x_inp) 327 | ``` 328 | 329 | If you run the code above, `x_out` will have shape `[100, 1024]`, computed by the last routing, and `credit_assignments` will have shape `[500, 100]`, consisting of the end-to-end credit assigned to the first routing's 500 input vectors by the final routing's 100 output vectors. 330 | 331 | 332 | #### Example of End-to-End Credit Assignment 333 | 334 | Here is a typical example of end-to-end credit assignment, in this case obtained from three sequential routings trained to route a sequence of hidden states at all depths in a BEiT Transformer to a sequence of predicted scores for ImageNet-1k classification. We show the end-to-end credit assigned to every pixel patch at every level of Transformer depth by the top predicted score, corresponding to the label "Cardigan Welsh Corgi." The credit assignments are additive, so we sum them over all depths to obtain the credit assigned to each pixel patch. *As you can see, our algorithm assigns credit end-to-end to the dog's body in shallower layers, and to its nose, mouth, ears, and paws in deeper layers, explaining the top predicted score ("Cardigan Welsh Corgi"):* 335 | 336 | > ![Sample end-to-end credit assignment in vision](assets/fig_sample_credit_assignments_vision_from_paper.png) 337 | 338 | The code and a pretrained model necessary for recreating the visualization above are online (see [here](#replicating-published-results)). For a higher-resolution version and a more in-depth explanation of this visualization, see [the paper](https://arxiv.org/abs/2211.11754). 339 | 340 | 341 | ## Frequently Asked Questions 342 | 343 | *Q: "Is it true that `EfficientVectorRouting` can route sequences with 1,000,000+ vectors on a single GPU/TPU?"* 344 | 345 | A: Yes. See [here](#routing-very-long-sequences). 346 | 347 | 348 | *Q: "Can I use `EfficientVectorRouting` in a Transformer? CNN? RNN? Autoencoder? Generative model? Any model?"* 349 | 350 | A: Yes. `EfficientVectorRouting` is a general-purpose PyTorch module. 351 | 352 | 353 | *Q: "Can I use `EfficientVectorRouting` instead of self-attention as a component of models?"* 354 | 355 | A: Yes. There is in fact a connection between the query-key-value self-attention mechanism used in Transformers and the algorithm implemented by `EfficientVectorRouting`: Transformer self-attention is a special case of modern Hopfield networks with bipartite structure, a class of dense associative memories which are in turn a special case of the routing algorithm we propose in "[An Algorithm for Routing Vectors in Sequences](https://arxiv.org/abs/2211.11754)." `EfficientVectorRouting` is one possible implementation of our algorithm. 356 | 357 | 358 | *Q: "Can I use `EfficientVectorRouting` to classify a sequence of token embeddings?"* 359 | 360 | A: Yes. Route the sequence to a vector (see [here](#sequence-to-vector)) and use the vector's elements as the predicted classification scores. In training, the module will learn to compute scores that minimize classification error and simultaneously best predict (explain) the sequence being classified. 361 | 362 | Note: If the number of classes is large (say, more than a few thousand classes), it may be more efficient to route the sequence to a hidden vector, and then apply a linear transformation to that hidden vector to predict classification scores, as is conventional. 363 |
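As a minimal sketch of this approach (the sizes and random targets below are purely illustrative):

```python
import torch
from heinsen_routing import EfficientVectorRouting as Routing

batch_sz = 8
n_tok, d_tok, n_classes = (512, 1024, 1000)  # illustrative sizes

classifier = Routing(n_inp=n_tok, n_out=n_classes, d_inp=d_tok, d_out=1)

tok_embs = torch.randn(batch_sz, n_tok, d_tok)      # normally provided by a model
logits = classifier(tok_embs).squeeze(-1)           # [batch_sz, n_classes] classification scores
targets = torch.randint(0, n_classes, (batch_sz,))  # toy labels, for illustration only
loss = torch.nn.functional.cross_entropy(logits, targets)  # train end-to-end as usual
```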
364 | 365 | *Q: "Can I use `EfficientVectorRouting` to build 'deep autoencoders for sequences'?"* 366 | 367 | A: Yes. You can build deep autoencoders that apply multiple `EfficientVectorRouting` layers to encode an input sequence to progressively shorter sequences and then progressively decode the shortest sequence back to the original length, in a typical "bowtie" arrangement. The autoencoders can of course be variational, using the reparametrization trick to sample the inner shortest sequence from a specified distribution. 368 |
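As a minimal, non-variational sketch of such a bowtie arrangement (all sizes below are arbitrary, illustrative choices):

```python
import torch
import torch.nn as nn
from heinsen_routing import EfficientVectorRouting as Routing

n_tok, d_tok = (1000, 1024)  # illustrative sizes

autoencoder = nn.Sequential(
    Routing(n_inp=n_tok, n_out=250,   d_inp=d_tok, d_out=d_tok),  # encode to a shorter sequence
    Routing(n_inp=250,   n_out=50,    d_inp=d_tok, d_out=d_tok),  # shortest "bottleneck" sequence
    Routing(n_inp=50,    n_out=250,   d_inp=d_tok, d_out=d_tok),  # decode to a longer sequence
    Routing(n_inp=250,   n_out=n_tok, d_inp=d_tok, d_out=d_tok),  # back to the original length
)

x_inp = torch.randn(n_tok, d_tok)
x_rec = autoencoder(x_inp)  # reconstruction of x_inp, e.g., for a reconstruction loss
```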
369 | 370 | *Q: "Is it true that each output vector computed by `EfficientVectorRouting` can have its own feature space?"* 371 | 372 | A: Yes. Each output vector is computed in a different basis, possibly representing different features (i.e., the same element may represent different features in different vectors). The number of representable features can be as large as `n_out` × `d_out`. This increase in representational capacity may enable you to work with shorter sequences and/or smaller vector sizes than otherwise necessary. See Subsection 3.1 of [the paper](https://arxiv.org/abs/2211.11754). 373 | 374 | Note: If you treat every output vector as being in the same shared feature space (e.g., if you always apply the same transformations to all output vectors, instead of different transformations to each one), you *can* induce all vector bases to represent the same features. If that's what you want, great -- but if not, please exercise a bit of care to avoid doing it unintentionally! 375 | 376 | 377 | *Q: "Can I set `n_inp` to -1 in `EfficientVectorRouting` even if all input sequences have the same length?"* 378 | 379 | A: Yes, but note that when you set `n_inp` to -1, `EfficientVectorRouting` treats all input vectors as being in the same shared feature space (i.e., each element represents the same feature in all input vectors). If all input vectors are indeed in the same shared feature space, then, yes, it makes sense to set `n_inp` to -1 even if input sequence length is fixed. Note: Setting `n_inp` to -1 may cause the module to incur more computation, depending on the shape of input and output sequences. 380 | 381 | 382 | *Q: "Is it true that I can get end-to-end credit assignments over a network of `EfficientVectorRouting` layers?"* 383 | 384 | A: Yes. To compute end-to-end credit assignments, follow the "how-to" recipes in Appendix A of [the paper](https://arxiv.org/abs/2211.11754) -- or make sure you thoroughly understand how the credit assignments work before straying away from the proven recipes. For a discussion of credit assignments, see Subsection 3.2 of the same paper. For a concrete example of end-to-end credit assignment, see [here](#example-of-end-to-end-credit-assignment). 385 | 386 | 387 | *Q: "Is it true that `EfficientVectorRouting` implements a model of associative memory?"* 388 | 389 | A: Yes. In Subsection 3.3 of [the paper](https://arxiv.org/abs/2211.11754), we describe input vectors as *keys* to content-addressable memory values and biases, and output vectors as *queries* whose states are iteratively updated until they stabilize in a local maximum of a "bang per bit" landscape (or, equivalently, a local minimum of an energy landscape). We also show that with significant simplifications, the routing algorithm implemented by `EfficientVectorRouting` reduces to modern Hopfield networks with bipartite structure, a class of dense associative memories of which Transformer self-attention is a notable special case. 390 | 391 | 392 | *Q: "Is it true that `EfficientVectorRouting` implements a "block" in a model of a Society of Mind (Minsky, 1986)?"* 393 | 394 | A: Yes. In Subsection 3.4 of [the paper](https://arxiv.org/abs/2211.11754), we describe output vectors as multidimensional agents competing against each other to use or ignore scarce resources in a block via knowledge lines, or K-lines, in a model of a Society of Mind. Agents iteratively improve the shares of each scarce resource they use or ignore by better predicting (i.e., explaining) it. Note that when Minsky was working on "The Society of Mind," published in 1986, he was certainly aware of early models of associative memory, including Hopfield networks (formulated by John Hopfield in 1982 and 1984, building on prior work going back to the early 1970s) and restricted Boltzmann machines (first proposed as the "Harmonium" by Paul Smolensky in 1986). 395 | 396 | 397 | ## Replicating Published Results 398 | 399 | To replicate the results published in "[An Algorithm for Routing Vectors in Sequences](https://arxiv.org/abs/2211.11754)" (Heinsen, 2022), follow the instructions here: . 400 | 401 | To replicate the results published in "[An Algorithm for Routing Capsules in All Domains](https://arxiv.org/abs/1911.00792)" (Heinsen, 2019), follow the instructions here: . 402 | 403 | 404 | ## Notes 405 | 406 | We have tested the code in this repository only on Ubuntu Linux 20.04 with Python 3.8+. 407 | 408 | 409 | ## Citing 410 | 411 | If our work is helpful to your research, please cite it: 412 | 413 | ``` 414 | @misc{heinsen2022algorithm, 415 | title={An Algorithm for Routing Vectors in Sequences}, 416 | author={Franz A. Heinsen}, 417 | year={2022}, 418 | eprint={2211.11754}, 419 | archivePrefix={arXiv}, 420 | primaryClass={cs.LG} 421 | } 422 | 423 | @misc{heinsen2019algorithm, 424 | title={An Algorithm for Routing Capsules in All Domains}, 425 | author={Franz A. Heinsen}, 426 | year={2019}, 427 | eprint={1911.00792}, 428 | archivePrefix={arXiv}, 429 | primaryClass={cs.LG} 430 | } 431 | ``` 432 | 433 | ## How is this used at GlassRoom? 434 | 435 | We conceived and implemented this routing algorithm to be a component (i.e., a layer) of larger models that are in turn part of our AI software, nicknamed Graham.
Most of the original work we do at GlassRoom is either proprietary or tightly coupled to internal code, so we cannot share it with outsiders. In this case, however, we were able to isolate our code and release it as stand-alone open-source software without having to disclose any key intellectual property. We hope others find our work and our code useful. 436 | 437 | 438 | --------------------------------------------------------------------------------