├── heinsen_attention ├── __init__.py └── heinsen_attention.py ├── setup.py ├── LICENSE ├── .gitignore ├── generative_language_model.py └── README.md /heinsen_attention/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from .heinsen_attention import LogAttention 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from setuptools import setup 3 | 4 | setup(name='heinsen_attention', 5 | version='1.0.0', 6 | description='Reference implementation of "Softmax Attention with Constant Cost per Token" (Heinsen, 2024).', 7 | url='https://github.com/glassroom/heinsen_attention', 8 | author='Franz A. Heinsen', 9 | author_email='franz@glassroom.com', 10 | license='MIT', 11 | packages=['heinsen_attention'], 12 | install_requires='torch', 13 | zip_safe=False) 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 GlassRoom 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /heinsen_attention/heinsen_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LogAttention(nn.Module): 5 | """ 6 | As proposed by Franz A. Heinsen, March 2024. 7 | 8 | Args: 9 | is_causal: (optional) bool, if True, compute causal log-attention. 10 | 11 | Input shapes: 12 | Q: [..., n_queries, d_key] queries. 13 | K: [..., n_context, d_key] keys. 14 | log_V: [..., n_context, d_val] log-values. 15 | 16 | Output shapes: 17 | log_attention: [..., n_queries, d_val] log of Softmax mixtures of values. 
18 | """ 19 | 20 | def __init__(self, is_causal=True): 21 | super().__init__() 22 | self.is_causal = is_causal 23 | 24 | def forward(self, Q, K, log_V, using_prev_context=False): 25 | Q = Q.unsqueeze(-1) # [..., n_queries, d_key, 1] 26 | K = K.unsqueeze(-1) # [..., n_context, d_key, 1] 27 | log_V = log_V.unsqueeze(-2) # [..., n_context, 1, d_val] 28 | 29 | if self.is_causal: 30 | K = K.to(torch.float32) if self.training else K # work-around for PyTorch 2.2 cuda issue 31 | H_S = torch.logcumsumexp(K + log_V, dim=-3).to(Q.dtype) # [..., n_context, d_key, d_val] 32 | H_Z = torch.logcumsumexp(K , dim=-3).to(Q.dtype) # [..., n_context, d_key, 1] 33 | else: 34 | H_S = torch.logsumexp(K + log_V, dim=-3, keepdim=True) # [..., 1, d_key, d_val] 35 | H_Z = torch.logsumexp(K , dim=-3, keepdim=True) # [..., 1, d_key, 1] 36 | 37 | if using_prev_context: 38 | H_S = self.prev_H_S.logaddexp(H_S) # [..., :, d_key, d_val] 39 | H_Z = self.prev_H_Z.logaddexp(H_Z) # [..., :, d_key, 1] 40 | 41 | self.prev_H_S = H_S[..., -1:, :, :].detach() # [..., 1, d_key, d_val] 42 | self.prev_H_Z = H_Z[..., -1:, :, :].detach() # [..., 1, d_key, d_val] 43 | 44 | log_S = torch.logsumexp(Q + H_S, dim=-2) # [..., n_queries, d_val] 45 | log_Z = torch.logsumexp(Q + H_Z, dim=-2) # [..., n_queries, 1] 46 | 47 | return log_S - log_Z 48 | 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | python 3 | 4 | # Dataset directory 5 | .data 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
101 | __pypackages__/
102 | 
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 | 
107 | # SageMath parsed files
108 | *.sage.py
109 | 
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 | 
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 | 
123 | # Rope project settings
124 | .ropeproject
125 | 
126 | # mkdocs documentation
127 | /site
128 | 
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 | 
134 | # Pyre type checker
135 | .pyre/
--------------------------------------------------------------------------------
/generative_language_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from tqdm import tqdm
6 | from dataclasses import dataclass
7 | from heinsen_attention import LogAttention
8 | from math import sqrt
9 | 
10 | @dataclass
11 | class ModelConfig:
12 |     vocab_sz: int = 50304  # vocab size of 50257, padded up for efficiency
13 |     d_emb: int = 768       # number of embedding features
14 |     n_layers: int = 24     # number of residual layers
15 |     n_heads: int = 24      # number of heads per token
16 |     d_key: int = 32        # number of key features per head
17 |     d_val: int = 32        # number of value features per head
18 | 
19 | 
20 | class ResidualLayer(nn.Module):
21 |     """
22 |     A simple causal (autoregressive) residual layer.
23 | 
24 |     Input shapes:
25 |         tokens: [..., n_tok, d_emb].
26 | 
27 |     Output shapes:
28 |         tokens: [..., n_tok, d_emb].
29 |     """
30 |     def __init__(self, d_emb, n_heads, d_key, d_val):
31 |         super().__init__()
32 |         self.d_emb, self.n_heads, self.d_key, self.d_val = (d_emb, n_heads, d_key, d_val)
33 |         self.feedforward1 = nn.Sequential(
34 |             nn.LayerNorm(d_emb),
35 |             nn.Linear(d_emb, n_heads * (d_key + d_key + d_val)),
36 |         )
37 |         self.log_attention = LogAttention(is_causal=True)
38 |         self.feedforward2 = nn.Sequential(
39 |             nn.Linear(n_heads * d_val, d_emb * 2),
40 |             nn.GLU(dim=-1),
41 |             nn.Linear(d_emb, d_emb, bias=False),
42 |         )
43 | 
44 |     def extra_repr(self):
45 |         return ', '.join('{}={}'.format(s, getattr(self, s)) for s in 'd_emb n_heads d_key d_val'.split(' '))
46 | 
47 |     def forward(self, inp, using_prev_context=False):
48 |         x = self.feedforward1(inp)                                 # [..., n_toks, n_heads * (d_key + d_key + d_val)]
49 |         x = x.view(*x.shape[:-1], self.n_heads, -1)                # [..., n_toks, n_heads, d_key + d_key + d_val]
50 |         x = x.transpose(-3, -2)                                    # [..., n_heads, n_toks, d_key + d_key + d_val]
51 |         x = x.split([self.d_key, self.d_key, self.d_val], dim=-1)  # tuple of three tensors
52 |         x = self.log_attention(*x, using_prev_context)             # [..., n_heads, n_toks, d_val]
53 |         x = x.transpose(-3, -2).flatten(-2)                        # [..., n_toks, n_heads * d_val]
54 |         x = self.feedforward2(x)                                   # [..., n_toks, d_emb]
55 |         return inp + x
56 | 
57 | 
58 | class EmbedPosition(nn.Module):
59 |     """
60 |     As proposed by Franz A. Heinsen, March 2024.
61 | 
62 |     Input shapes:
63 |         tokens: [..., n_tok, d_emb].
64 | 
65 |     Output shapes:
66 |         tokens: [..., n_tok, d_emb].
67 |     """
68 |     def __init__(self, d_emb):
69 |         super().__init__()
70 |         self.d_emb = d_emb
71 |         self.dense = nn.Linear(d_emb, d_emb * 2)
72 | 
73 |     def extra_repr(self):
74 |         return 'd_emb={}'.format(self.d_emb)
75 | 
76 |     def _log_linear_recurrence(self, log_coeffs, prepended_logits):
77 |         "Applies method proposed in https://arxiv.org/abs/2311.06281."
78 | a_star = F.pad(log_coeffs.cumsum(dim=-2), (0,0, 1,0), value=0) # [..., 1 + n_tok, d_emb] 79 | logit0_plus_b_star = torch.logcumsumexp(prepended_logits - a_star, dim=-2) # [..., 1 + n_tok, d_emb] 80 | log_linear_recurrence = a_star + logit0_plus_b_star # [..., 1 + n_tok, d_emb] 81 | return log_linear_recurrence[..., 1:, :] # [..., n_tok, d_emb] 82 | 83 | def forward(self, tokens, using_prev_context): 84 | tup = self.dense(tokens).split(self.d_emb, dim=-1) # [..., n_tok, d_emb] x 2 85 | log_coeffs, logits = (F.logsigmoid(tup[0]), tup[1]) # [..., n_tok, d_emb] x 2 86 | if using_prev_context: 87 | prepended_logits = torch.cat([self.prev_context, logits], dim=-2) # [..., 1 + n_tok, d_emb] 88 | else: 89 | prepended_logits = F.pad(logits, (0,0, 1,0), value=0) # [..., 1 + n_tok, d] 90 | pos_embs = self._log_linear_recurrence(log_coeffs, prepended_logits) # [..., n_tok, d_emb] 91 | self.prev_context = pos_embs[..., -1:, :].detach() # [..., 1, d_emb] 92 | return tokens + pos_embs # [..., n_tok, d_emb] 93 | 94 | 95 | class GenerativeLanguageModel(nn.Module): 96 | """ 97 | Given a sequence of token ids, predict each next token id. 98 | 99 | Input shape: 100 | token_ids: [..., n_toks], sequence of token ids. 101 | 102 | Output shape: 103 | predicted logits [..., n_toks, vocab_sz]. 104 | """ 105 | def __init__(self, config: ModelConfig) -> None: 106 | super().__init__() 107 | _initial_embs = torch.empty(config.vocab_sz, config.d_emb).uniform_(-1, 1) / sqrt(config.d_emb) 108 | self.embed = nn.Embedding(*_initial_embs.shape, _weight=_initial_embs) 109 | self.embed_pos = EmbedPosition(config.d_emb) 110 | self.layers = nn.Sequential(*[ 111 | ResidualLayer(config.d_emb, config.n_heads, config.d_key, config.d_val) 112 | for _ in range(config.n_layers) 113 | ]) 114 | self.lnorm = nn.LayerNorm(config.d_emb) 115 | self.config = config 116 | 117 | def extra_repr(self): 118 | return 'config={}'.format(self.config) 119 | 120 | def body(self, token_ids, using_prev_context=False): 121 | x = self.embed(token_ids) 122 | x = self.embed_pos(x, using_prev_context) 123 | for layer in self.layers: 124 | x = layer(x, using_prev_context) 125 | x = self.lnorm(x) 126 | return x 127 | 128 | def head(self, x): 129 | return x @ self.embed.weight.T 130 | 131 | def forward(self, token_ids, using_prev_context=False): 132 | x = self.body(token_ids, using_prev_context) 133 | x = self.head(x) 134 | return x 135 | 136 | # Convenience methods: 137 | 138 | def get_param_groups(self, weight_decay): 139 | decay_attrs = { nn.Embedding: ['weight'], nn.Linear: ['weight'], } 140 | decay_modules = set(m for m in self.modules() if type(m) in decay_attrs.keys()) 141 | decay_ids = set(id(getattr(m, attr)) for m in decay_modules for attr in decay_attrs[type(m)]) 142 | return [ 143 | { 'params': [p for p in self.parameters() if id(p) in decay_ids], 'weight_decay': weight_decay, }, 144 | { 'params': [p for p in self.parameters() if id(p) not in decay_ids], 'weight_decay': 0.0, }, 145 | ] 146 | 147 | @torch.no_grad() 148 | def generate(self, token_ids, n_new, temp=1.0, topk=None, using_prev_context=False, show_progress=False): 149 | assert self.training is False, "Model should be in eval mode." 
150 | generated_ids = [] 151 | upc_states = [using_prev_context] + [True] * (n_new - 1) 152 | for upc_state in (tqdm(upc_states) if show_progress else upc_states): 153 | hidden_states = self.body(token_ids, using_prev_context=upc_state) 154 | logits = self.head(hidden_states[..., -1, :]) / temp 155 | if topk is not None: 156 | min_of_topk = logits.topk(topk, dim=-1).values.min(dim=-1, keepdim=True).values 157 | logits[logits < min_of_topk] = float('-inf') 158 | token_ids = torch.multinomial(logits.softmax(dim=-1), num_samples=1) 159 | generated_ids.append(token_ids) 160 | return torch.cat(generated_ids, dim=-1) 161 | 162 | 163 | def build_model(**model_config_kwds): 164 | return GenerativeLanguageModel(ModelConfig(**model_config_kwds)) 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # heinsen_attention 2 | 3 | Reference implementation of "[Softmax Attention with Constant Cost per Token](http://arxiv.org/abs/2404.05843)" (Heinsen, 2024). 4 | 5 | We propose a simple modification to the conventional Softmax attention mechanism applied by Transformers: Instead of quantifying pairwise query-key similarity with scaled dot-products, we quantify it with the logarithms of scaled dot-products of exponentials: 6 | 7 | $$\overset{\text{modified}}{\text{Attention}}(Q, K, V) := \displaystyle \text{Softmax}\left( \log \frac{\exp(Q) \exp(K)^T}{\exp(c)} \right) V,$$ 8 | 9 | where $c$ is a scaling constant. This simple modification [linearizes attention](https://arxiv.org/abs/2006.16236) with exponential kernel feature maps and makes it expressible as a composition of log-sums of exponentials, with a latent space of constant size, enabling application with constant time and space complexity per token. 10 | 11 | Note that the feature function corresponding to an exponential kernel is infinite dimensional. 12 | 13 | 14 | ## Table of Contents 15 | 16 | * [How Does it Work?](#how-does-it-work) 17 | 18 | * [Frequently Asked Questions](#frequently-asked-questions) 19 | 20 | * [Installation and Usage](#installation-and-usage) 21 | 22 | * [Important Limitations](#important-limitations) <-- make sure to read them! 23 | 24 | * [Replicating Published Results](#replicating-published-results) 25 | 26 | * [Notes](#notes) 27 | 28 | * [Citing](#citing) 29 | 30 | 31 | ## How Does it Work? 32 | 33 | It's best to _see it in action_ with a toy example. First, we will show how to compute causal (autoregressive) Softmax attention with our modification using the familiar quadratic-cost formulation. Then, we will show how we linearize computation as a composition of log-sums of exponentials, obtaining the same results. Finally, we will split the sequence in chunks and compute attention sequentially, chunk by chunk, incurring constant cost per token, again obtaining the same results. 34 | 35 | 36 | ### Our Toy Example 37 | 38 | Start by importing all dependencies we will need: 39 | 40 | ```python 41 | import torch 42 | import torch.nn as nn 43 | import torch.nn.functional as F 44 | ``` 45 | 46 | Now, let's create toy queries `Q`, keys `K`, and values `V`. Our method requires computing the logarithm of `V`. If there are any negative values in `V`, their logarithms will be complex numbers, which are not uniformly well-supported in PyTorch. To avoid having to deal with them in our toy example, we will limit `V`'s elements to positive numbers. 
Also, we will keep the number of tokens `n_tok`, key features `d_key`, and value features `d_val` tiny so that when we print results, they can fit on a single screen: 47 | 48 | ```python 49 | # Setup for our toy example: 50 | n_tok = 10 51 | d_key = 4 52 | d_val = 4 53 | 54 | Q = torch.randn(n_tok, d_key) 55 | K = torch.randn(n_tok, d_key) 56 | 57 | log_V = torch.randn(n_tok, d_val) # real 58 | V = torch.exp(log_V) # positive only 59 | ``` 60 | 61 | 62 | ### First, Causal Softmax Attention with Quadratic Cost 63 | 64 | Here is a PyTorch module that computes our attention mechanism with its quadratic-cost formulation, 65 | 66 | $$\text{Softmax} \left( \log \frac{\exp(Q) \exp(K)^T}{\exp(c)} \right) V,$$ 67 | 68 | using $c = c_1 + c_2$ as the scaling constant, with $c_1 = \max(Q)$ and $c_2 = \max(K)$: 69 | 70 | ```python 71 | class QuadraticCostCausalAttention(nn.Module): 72 | 73 | def __init__(self): 74 | super().__init__() 75 | 76 | def forward(self, Q, K, V): 77 | c1, c2 = (Q.detach().max(), K.detach().max()) # scaling constants 78 | sims = torch.log((Q - c1).exp() @ (K - c2).exp().transpose(-2, -1)) # [n_tok, n_tok] 79 | mask = sims.new_ones(sims.shape[-2:], dtype=torch.bool).tril() # [n_tok, n_tok] 80 | sims = sims.masked_fill(mask.logical_not(), float('-inf')) # [n_tok, n_tok] 81 | Y = F.softmax(sims, dim=-1) @ V # [n_tok, d_val] 82 | return Y 83 | ``` 84 | 85 | Try it: 86 | 87 | ```python 88 | quadratic_attn = QuadraticCostCausalAttention() 89 | Y1 = quadratic_attn(Q, K, V) 90 | print(Y1) 91 | ``` 92 | 93 | 94 | ### Second, Linearized Causal Softmax Attention 95 | 96 | Here is a PyTorch module that computes the same output, using a linearized formulation that consists _entirely of log-sums of exponentials_. Note that the module accepts `log_V` instead of `V` as an input: 97 | 98 | ```python 99 | class LinearizedCausalAttention(nn.Module): 100 | 101 | def __init__(self): 102 | super().__init__() 103 | 104 | def forward(self, Q, K, log_V): 105 | Q, K, log_V = (Q.unsqueeze(-1), K.unsqueeze(-1), log_V.unsqueeze(-2)) 106 | 107 | H_S = torch.logcumsumexp(K + log_V, dim=-3) # [n_tok, d_key, d_val] eq. (6) in paper 108 | H_Z = torch.logcumsumexp(K , dim=-3) # [n_tok, d_key, 1] eq. (6) 109 | 110 | log_S = torch.logsumexp(Q + H_S, dim=-2) # [n_tok, d_val] eq. (5) 111 | log_Z = torch.logsumexp(Q + H_Z, dim=-2) # [n_tok, d_val] eq. (5) 112 | 113 | Y = torch.exp(log_S - log_Z) # [n_tok, d_val] eq. (2) 114 | return Y 115 | ``` 116 | 117 | Try it: 118 | 119 | ```python 120 | linearized_attn = LinearizedCausalAttention() 121 | Y2 = linearized_attn(Q, K, log_V) 122 | print(Y2) 123 | ``` 124 | 125 | You can confirm the results are the same as with the quadratic formulation: 126 | 127 | ```python 128 | print('Do Y1 and Y2 match?', torch.allclose(Y1, Y2)) 129 | ``` 130 | 131 | 132 | ### Finally, Sequential Causal Softmax Attention with Constant Cost per Token 133 | 134 | We now sequentialize the computation by caching our attention mechanism's latent state, which has a constant size, enabling us to apply attention over a stream of tokens that arrive in chunks, with constant time and space complexity per token: 135 | 136 | ```python 137 | class SequentialCausalAttention(nn.Module): 138 | 139 | def __init__(self): 140 | super().__init__() 141 | 142 | def forward(self, Q, K, log_V, using_prev_context=False): 143 | Q, K, log_V = (Q.unsqueeze(-1), K.unsqueeze(-1), log_V.unsqueeze(-2)) 144 | 145 | H_S = torch.logcumsumexp(K + log_V, dim=-3) # [n_tok, d_key, d_val] eq. 
(6) in paper
146 |         H_Z = torch.logcumsumexp(K        , dim=-3)       # [n_tok, d_key, 1]     eq. (6)
147 | 
148 |         if using_prev_context:
149 |             H_S = self.prev_H_S.logaddexp(H_S)             # [n_tok, d_key, d_val] use cache
150 |             H_Z = self.prev_H_Z.logaddexp(H_Z)             # [n_tok, d_key, 1]     use cache
151 | 
152 |         self.prev_H_S = H_S[..., -1:, :, :].detach()       # [1, d_key, d_val]     cache end-state
153 |         self.prev_H_Z = H_Z[..., -1:, :, :].detach()       # [1, d_key, 1]         cache end-state
154 | 
155 |         log_S = torch.logsumexp(Q + H_S, dim=-2)           # [n_tok, d_val]        eq. (5)
156 |         log_Z = torch.logsumexp(Q + H_Z, dim=-2)           # [n_tok, 1]            eq. (5)
157 | 
158 |         Y = torch.exp(log_S - log_Z)                       # [n_tok, d_val]        eq. (2)
159 |         return Y
160 | ```
161 | 
162 | Try it:
163 | 
164 | ```python
165 | # Split sequence into a stream of chunks:
166 | chunk_len = 3
167 | chunks = zip(
168 |     Q.split(chunk_len, dim=-2),
169 |     K.split(chunk_len, dim=-2),
170 |     log_V.split(chunk_len, dim=-2),
171 | )
172 | 
173 | # Instantiate the module:
174 | sequential_attn = SequentialCausalAttention()
175 | 
176 | # Compute attention over the first chunk:
177 | chunk = next(chunks)
178 | print('Processing a chunk with {} tokens.'.format(chunk[0].size(-2)))
179 | Y3 = [sequential_attn(*chunk)]  # saves latent state
180 | 
181 | # Compute attention over remaining chunks, using prev context for each one:
182 | for chunk in chunks:
183 |     print('Processing a chunk with {} tokens.'.format(chunk[0].size(-2)))
184 |     Y3.append(sequential_attn(*chunk, using_prev_context=True))
185 | 
186 | print('---\nConcatenated:')
187 | Y3 = torch.cat(Y3, dim=-2)
188 | print(Y3)
189 | ```
190 | 
191 | You can confirm the results are the same as before:
192 | 
193 | ```python
194 | print('Do Y1 and Y3 match?', torch.allclose(Y1, Y3))
195 | ```
196 | 
197 | At each step, the above module is computing attention over all tokens in the input context! Remarkably, the stream of chunks could be _never-ending_! That's right: We can compute Softmax attention over input contexts of unlimited length!
198 | 
199 | 
200 | ### The Key Insight
201 | 
202 | Take a single query vector $\mathbf{q}$ and a single key vector $\mathbf{k}$ in $\mathbb{R}^{d}$.
203 | 
204 | $$\mathbf{q} = \begin{bmatrix} q_1 \\\ q_2 \\\ \vdots \\\ q_d \end{bmatrix}, \quad \mathbf{k} = \begin{bmatrix} k_1 \\\ k_2 \\\ \vdots \\\ k_d \end{bmatrix}.$$
205 | 
206 | The logarithm of the dot-product $\langle \cdot, \cdot \rangle$ of their exponentials is:
207 | 
208 | $$\begin{aligned}
209 | \log \langle \exp(\mathbf{q}), \exp(\mathbf{k}) \rangle
210 | & = \log ( e^{q_1} e^{k_1} + e^{q_2} e^{k_2} + \dots + e^{q_d} e^{k_d} ) \\
211 | & = \log \sum \left( \begin{bmatrix} e^{q_1} \\\ e^{q_2} \\\ \vdots \\\ e^{q_d} \end{bmatrix} \odot \begin{bmatrix} e^{k_1} \\\ e^{k_2} \\\ \vdots \\\ e^{k_d} \end{bmatrix} \right) \\
212 | & = \log \sum \left( \begin{bmatrix} e^{q_1} e^{k_1} \\\ e^{q_2} e^{k_2} \\\ \vdots \\\ e^{q_d} e^{k_d} \end{bmatrix} \right) \\
213 | & = \log \sum \left( \begin{bmatrix} e^{q_1 + k_1} \\\ e^{q_2 + k_2} \\\ \vdots \\\ e^{q_d + k_d} \end{bmatrix} \right) \\
214 | & = \log \sum \exp \left( \begin{bmatrix} q_1 \\\ q_2 \\\ \vdots \\\ q_d \end{bmatrix} + \begin{bmatrix} k_1 \\\ k_2 \\\ \vdots \\\ k_d \end{bmatrix} \right) \\
215 | & = \log\sum\exp ( \mathbf{q} + \mathbf{k} ) \\
216 | & = \text{LSE} ( \mathbf{q} + \mathbf{k} ), \\
217 | \end{aligned}$$
218 | 
219 | where $\text{LSE}$ is shorthand for "Logarithm of a Sum of Exponentials."
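The identity is easy to check numerically. Here is a minimal sanity check on random vectors (not part of the library; it only verifies the derivation above):

```python
import torch

q = torch.randn(8)
k = torch.randn(8)

lhs = torch.log(torch.dot(q.exp(), k.exp()))  # log of the dot-product of exponentials
rhs = torch.logsumexp(q + k, dim=-1)          # LSE of the elementwise sum

print(torch.allclose(lhs, rhs))  # True, up to floating-point error
```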
220 | 221 | Armed with this insight, we prove that our Softmax attention mechanism is expressible as a composition of log-sums of exponentials that is linearizable, with a latent space of constant size, enabling sequential application with constant time and space complexity per token. For details, please see our paper. 222 | 223 | 224 | ## Frequently Asked Questions 225 | 226 | *Q: "Is this method a special case of ``linear attention'' as proposed by [Katharopoulos et al (2020)](https://arxiv.org/abs/2006.16236)?"* 227 | 228 | A: Yes. The quadratic-cost formulation is expressible as a special case of linear attention. It's the special case that applies exponential kernel feature maps, whose corresponding feature function is infinite dimensional: 229 | 230 | $$\text{Softmax}\left( \log \frac{\exp(Q) \exp(K)^T}{\exp(c)} \right) V = \begin{bmatrix} \displaystyle \frac{\exp(Q) \exp(K)^T}{\sum_{[n_K]} \exp(Q) \exp(K)^T} \end{bmatrix} V,$$ 231 | 232 | where $\sum_{[n_K]}$ sums over the dimension indexed by the number of keys. The gram matrix is symmetric and positive semi-definite, giving us a kernel (Mercer's theorem). Expressed in code: 233 | 234 | ```python 235 | class NumericallyUnstableCausalAttention(nn.Module): 236 | 237 | def __init__(self): 238 | super().__init__() 239 | 240 | def forward(self, Q, K, V): 241 | exp_sims = Q.exp() @ K.exp().transpose(-2, -1) # [n_tok, n_tok] 242 | mask = exp_sims.new_ones(exp_sims.shape[-2:], dtype=torch.bool).tril() # [n_tok, n_tok] 243 | exp_sims = exp_sims.masked_fill(mask.logical_not(), 0.0) # [n_tok, n_tok] 244 | Y = (exp_sims / exp_sims.sum(dim=-1, keepdim=True)) @ V # [n_tok, d_val] 245 | return Y 246 | ``` 247 | 248 | It turns out this special case is expressible _entirely as a composition of log-sums of exponentials_. 249 | 250 | Initially, we didn't realize our modification was a special case of linear attention. In hindsight, we're a bit embarrassed that we didn't see it right away. Maybe our gray matter was temporarily stuck on subpar local optima? Please see shaochenze's comment [here](https://github.com/glassroom/heinsen_attention/issues/1). 251 | 252 | 253 | *Q: "Can this be generalized to functions other than _exp()_ and _log()_?"* 254 | 255 | A: Yes. If we define $\phi = \exp$, we have: 256 | 257 | $$\overset{\text{modified}}{\text{Attention}}(Q, K, V) := \displaystyle \text{Softmax}\left( \phi^{-1} \left( \frac{\phi(Q) \phi(K)^T}{\phi(c)} \right) \right) V.$$ 258 | 259 | The question is whether there are other functions $\phi$ that are not $\exp$ (and do not exponentiate) which (a) are invertible, and (b) enable linearization of the Softmax function as a composition of (log-) sums. We suspect the answer is no. It might be possible to replace $\exp$ and $\log$ with two functions that are not each other's inverses and together enable linearization of the Softmax function as a composition of sums, but the result might not work as well or be... as elegant. 260 | 261 | 262 | *Q: "How can I help?"* 263 | 264 | A: Glad you asked! The most helpful thing anyone could do is write code that addresses the two [self-imposed limitations](#important-limitations) of our implementation with efficiency and numerical stability in PyTorch. Another thing that would be helpful is implementing our method in other software frameworks (e.g., JAX, TensorFlow) and languages (e.g., Julia, Mojo) that maybe could make it easier to address both limitations. Finally, our method has yet to be tested on a diverse set of tasks and benchmarks with larger models. 
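As a quick sanity check of the equivalence claimed in the first answer above, you can run the `NumericallyUnstableCausalAttention` module on the toy tensors from the walkthrough (this assumes `Q`, `K`, `V`, and `Y1` from that example are still in scope); on such small, well-scaled inputs the exponentials do not overflow and the results should match:

```python
unstable_attn = NumericallyUnstableCausalAttention()
Y4 = unstable_attn(Q, K, V)
print('Do Y1 and Y4 match?', torch.allclose(Y1, Y4))  # expect True on these toy inputs
```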
265 | 266 | 267 | ## Installation and Usage 268 | 269 | ``` 270 | pip install git+https://github.com/glassroom/heinsen_attention 271 | ``` 272 | 273 | Alternatively, you can download a single file to your project directory: [heinsen_attention.py](https://github.com/glassroom/heinsen_attention/blob/main/heinsen_attention/heinsen_attention.py). 274 | 275 | The only dependency is a recent version of [PyTorch](https://pytorch.org/). 276 | 277 | 278 | ### Usage 279 | 280 | Our implementation returns _the logarithm of Softmax attention_, which is a float tensor like `log_V`. In practice, we have found that computing `log_V` as a float tensor directly from token states and using the logarithm of our attention mechanism as input to subsequent model components works well! 281 | 282 | ```python 283 | # Load PyTorch module: 284 | from heinsen_attention import LogAttention 285 | 286 | # Instantiate PyTorch module: 287 | log_attn = LogAttention(is_causal=True) 288 | 289 | # Compute log(Attention(...)): 290 | log_Y = log_attn(Q, K, log_V) 291 | ``` 292 | 293 | To compute attention over additional tokens in the same sequence, pass `using_prev_context=True` to the module's forward pass: 294 | 295 | ```python 296 | log_Y = log_attn(Q, K, log_V, using_prev_context=True) 297 | ``` 298 | 299 | For a concrete example of how we do this, see the residual layer of the generative language model we use in our experiments, defined in the file `generative_language_model.py`. 300 | 301 | 302 | ## Important Limitations 303 | 304 | For simplicity and expediency, we limit our implementation in two significant ways: 305 | 306 | 1. We restrict the values $V$ to positive numbers to avoid dealing with complex floating-point numbers, which incur greater overhead and presently are more cumbersome to manipulate than real floating-point numbers. In practice, we have found this isn't an issue: We work with the logarithm of attention, which is in the same space as $\log V$. For a concrete example of how we do this, see the residual layer of the generative language model we use in our experiments, defined in the file `generative_language_model.py`. 307 | 308 | 2. When computing autoregressive attention in parallel over all tokens in a sequence, we first compute all latent states with two parallel scans (`logcumsumexp`'s), keeping all latent states simultaneously in memory as intermediate values, and then reduce them, which is memory-inefficient but easier to write than a memory-efficient implementation. In practice, this impacts the amount of memory required for training. 309 | 310 | Neither limitation is intrinsic to our attention mechanism. Both can be addressed with code. 311 | 312 | 313 | ## Replicating Published Results 314 | 315 | The generative language model we use in our experiments is defined in the file `generative_language_model.py`. The only additional requirement is [tqdm](https://tqdm.github.io/), for displaying a progress bar when generating tokens. 
316 | 
317 | Build the model with:
318 | 
319 | ```python
320 | from generative_language_model import build_model
321 | model = build_model()
322 | ```
323 | 
324 | To replicate our results, train the model on 300B tokens from The Pile ([Gao et al, 2020](https://arxiv.org/abs/2101.00027)) using a conventional setup: AdamW optimizer with weight decay 1e-1 and betas (0.90, 0.95), and one-cycle lr schedule with short warm-up, max lr 6e-4, min lr 6e-5 (e.g., you could use [this training script](https://github.com/karpathy/nanoGPT/blob/master/train.py) by Andrej Karpathy with minor modifications). For convenience, the model splits its parameters into groups with/without weight decay:
325 | 
326 | ```python
327 | param_groups = model.get_param_groups(weight_decay=1e-1)
328 | optimizer = torch.optim.AdamW(param_groups)
329 | ```
330 | 
331 | For tokenization, we use [tiktoken](https://github.com/openai/tiktoken) with the 'gpt2' vocabulary.
332 | 
333 | For training hardware, we would recommend at least an 8XA100 40GB.
334 | 
335 | 
336 | ## Notes
337 | 
338 | We have tested the code in this repository only on Ubuntu Linux 22.04 with Python 3.10+.
339 | 
340 | 
341 | ## Citing
342 | 
343 | ```
344 | @misc{heinsen2024softmax,
345 |     title={Softmax Attention with Constant Cost per Token},
346 |     author={Franz A. Heinsen},
347 |     year={2024},
348 |     eprint={2404.05843},
349 |     archivePrefix={arXiv},
350 |     primaryClass={cs.LG}
351 | }
352 | ```
353 | 
354 | 
355 | ## How is this used at GlassRoom?
356 | 
357 | We conceived and implemented our attention mechanism for proprietary use. Most of the original work we do at GlassRoom tends to be tightly coupled to internal code, so we cannot share it with outsiders. In this case, however, we were able to isolate our code and release it as stand-alone open-source software without having to disclose any key intellectual property. We hope others find our work and our code useful.
358 | 
--------------------------------------------------------------------------------