├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── memory_efficient_attention ├── __init__.py ├── attention_jax.py ├── attention_torch.py └── utils.py ├── setup.py └── test ├── __init__.py └── test_computation.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v2 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: '__token__' 23 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | # Project 3 | .models/ 4 | 5 | # IntelliJ project files 6 | .idea 7 | *.iml 8 | out 9 | gen 10 | ### Python template 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 
100 | #Pipfile.lock 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Amin Rezaei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Memory Efficient Attention 2 | [![arXiv](https://img.shields.io/badge/arXiv-2112.05682v2-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2112.05682v2) 3 | [![PyPI version](https://img.shields.io/badge/memory--efficient--attention-0.1.3-informational?style=flat-square&color=C51BA3)](https://pypi.org/project/memory-efficient-attention/) 4 | 5 | This is an **unofficial** implementation of [Self-attention Does Not Need O(n^2) Memory](https://arxiv.org/abs/2112.05682v2) for Jax and PyTorch. 6 | 7 | The implementation is almost the same as the one proposed in the paper, with additional **masking and bias support**, **batch dimensions support** and a **PyTorch implementation**. 8 | For computing attention, the proposed method requires only O(sqrt(n)) memory, and the provided functions can be used as a drop-in replacement for attention calculation. 9 | 10 | **Important Note:** This implementation is a trade-off between memory requirements and runtime, so you should adjust the `key_chunk_size` and `query_chunk_size` parameters to achieve the best configuration for your use case; illustrative chunk-size tuning and per-chunk callback sketches are appended at the end of this document. Here is a note from the paper's authors: 11 | 12 | >While a constant chunk size for the queries and a chunk size of sqrt(n) 13 | >for the keys and values is optimal for memory 14 | >consumption, the runtime is also affected by the choice of chunk size 15 | >in practice, which is heavily affected by the 16 | >choice of hardware.
Ultimately, we have to leave this trade-off to the 17 | >programmer, and expose the chunk sizes as 18 | >arguments query_chunk_size and key_chunk_size. In Figure 1 we provide default values for the chunk sizes that 19 | >lead to minimal runtime impact (on TPUv2), while still providing significant memory savings. 20 | 21 | 22 | ## Quick Start 23 | 1. Install the library 24 | 25 | ```bash 26 | # for Jax 27 | pip install memory-efficient-attention[jax] 28 | # for PyTorch 29 | pip install memory-efficient-attention[torch] 30 | # for Running Tests 31 | pip install memory-efficient-attention[testing] 32 | ``` 33 | 34 | 2. Compute attention with the proper function 35 | 36 | ```python 37 | import numpy as np 38 | # for PyTorch 39 | from memory_efficient_attention import efficient_dot_product_attention_pt 40 | # or for Jax 41 | from memory_efficient_attention import efficient_dot_product_attention_jax 42 | 43 | # Random Data (batch dimensions are not necessary) 44 | b = 8 45 | query = np.random.rand(1, b, 128, 16, 8).astype("float32") 46 | key = np.random.rand(1, b, 128, 16, 8).astype("float32") 47 | value = np.random.rand(1, b, 128, 16, 8).astype("float32") 48 | # optional, for causal tasks, ... 49 | mask = np.random.rand(1, b, 16, 128, 128) > 0.5 50 | bias = np.random.rand(1, b, 16, 128, 128).astype("float32") / 100 51 | 52 | # Adjust chunk sizes 53 | efficient_dot_product_attention_jax(query, key, value, mask, bias, key_chunk_size=..., query_chunk_size=...) 54 | ``` 55 | 56 | ## Citation 57 | Please cite this implementation if it helps your research. You can use the following BibTeX entry: 58 | 59 | ```bibtex 60 | @misc{memory_efficient_attention, 61 | title = {Memory Efficient Attention}, 62 | author = {Rezaei, Amin}, 63 | howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}}, 64 | year = {2021} 65 | } 66 | ``` 67 | Also, for the paper: 68 | ```bibtex 69 | @misc{rabe2021selfattention, 70 | title={Self-attention Does Not Need $O(n^2)$ Memory}, 71 | author={Markus N.
Rabe and Charles Staats}, 72 | year={2021}, 73 | eprint={2112.05682}, 74 | archivePrefix={arXiv}, 75 | primaryClass={cs.LG} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /memory_efficient_attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention_jax import efficient_dot_product_attention as efficient_dot_product_attention_jax 2 | from .attention_torch import efficient_dot_product_attention as efficient_dot_product_attention_pt 3 | -------------------------------------------------------------------------------- /memory_efficient_attention/attention_jax.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import jax 3 | import math 4 | 5 | from jax import numpy as jnp 6 | 7 | 8 | def _query_chunk_attention(query_idx, query, key, value, 9 | mask, bias, precision, 10 | key_chunk_size=4096, 11 | mask_calc_fn=None, 12 | bias_calc_fn=None, 13 | weights_calc_fn=None, 14 | calc_fn_data=None): 15 | num_kv, num_heads, k_features = key.shape[-3:] 16 | v_features = value.shape[-1] 17 | num_q = query.shape[-3] 18 | key_chunk_size = min(key_chunk_size, num_kv) 19 | query = query / jnp.sqrt(k_features) 20 | 21 | @functools.partial(jax.checkpoint, prevent_cse=False) 22 | def summarize_chunk(chunk_idx, query, key, value, mask, bias): 23 | attn_weights = jnp.einsum('...qhd,...khd->...qhk', query, key, precision=precision) 24 | if bias_calc_fn is not None: 25 | bias = bias_calc_fn(query_idx, chunk_idx, bias, attn_weights, calc_fn_data) 26 | if bias is not None: 27 | bias = jnp.einsum('...hqk->...qhk', bias) 28 | attn_weights = attn_weights + bias 29 | if mask_calc_fn is not None: 30 | mask = mask_calc_fn(query_idx, chunk_idx, mask, attn_weights, calc_fn_data) 31 | if mask is not None: 32 | big_neg = jnp.finfo(attn_weights.dtype).min 33 | mask = jnp.einsum('...hqk->...qhk', mask) 34 | attn_weights = jnp.where(mask, attn_weights, big_neg) 35 | if weights_calc_fn is not None: 36 | attn_weights = weights_calc_fn(query_idx, chunk_idx, attn_weights, calc_fn_data) 37 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True) 38 | max_score = jax.lax.stop_gradient(max_score) 39 | exp_weights = jnp.exp(attn_weights - max_score) 40 | exp_values = jnp.einsum('...vhf,...qhv->...qhf', value, exp_weights, precision=precision) 41 | max_score = jnp.einsum('...qhk->...qh', max_score) 42 | return exp_values, exp_weights.sum(axis=-1), max_score 43 | 44 | def chunk_scanner(chunk_idx): 45 | key_chunk = jax.lax.dynamic_slice( 46 | key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), 47 | slice_sizes=tuple(key.shape[:-3]) + (key_chunk_size, num_heads, k_features)) 48 | value_chunk = jax.lax.dynamic_slice( 49 | value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), 50 | slice_sizes=tuple(value.shape[:-3]) + (key_chunk_size, num_heads, v_features)) 51 | 52 | if bias is None: 53 | bias_chunk = None 54 | elif bias.shape[-1] == 1: 55 | bias_chunk = bias 56 | elif bias.shape[-1] == num_kv: 57 | bias_chunk = jax.lax.dynamic_slice( 58 | bias, tuple([0] * (bias.ndim - 3)) + (0, 0, chunk_idx), 59 | slice_sizes=tuple(bias.shape[:-3]) + (bias.shape[-3], bias.shape[-2], key_chunk_size)) 60 | else: 61 | raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') 62 | 63 | if mask is None: 64 | mask_chunk = None 65 | elif mask.shape[-1] == 1: 66 | mask_chunk = mask 67 | elif mask.shape[-1] == num_kv: 68 | mask_chunk =
jax.lax.dynamic_slice( 69 | mask, tuple([0] * (mask.ndim - 3)) + (0, 0, chunk_idx), 70 | slice_sizes=tuple(mask.shape[:-3]) + (mask.shape[-3], mask.shape[-2], key_chunk_size)) 71 | else: 72 | raise TypeError(f'mask.shape[-1] == {mask.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') 73 | 74 | return summarize_chunk(chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) 75 | 76 | chunk_values, chunk_weights, chunk_max = jax.lax.map( 77 | chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) 78 | 79 | global_max = jnp.max(chunk_max, axis=0, keepdims=True) 80 | max_diffs = jnp.exp(chunk_max - global_max) 81 | chunk_values *= jnp.expand_dims(max_diffs, axis=-1) 82 | chunk_weights *= max_diffs 83 | 84 | all_values = chunk_values.sum(axis=0) 85 | all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) 86 | return all_values / all_weights 87 | 88 | 89 | def efficient_dot_product_attention(query, key, value, 90 | mask=None, bias=None, 91 | precision=jax.lax.Precision.HIGHEST, 92 | query_chunk_size=1024, 93 | key_chunk_size=4096, 94 | bias_calc_fn=None, 95 | mask_calc_fn=None, 96 | weights_calc_fn=None, 97 | calc_fn_data=None): 98 | """Computes efficient dot-product attention given query, key, and value. 99 | This is efficient version of attention presented in 100 | https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. 101 | Note: query, key, value needn't have any batch dimensions. 102 | Args: 103 | query: queries for calculating attention with shape of 104 | `[batch..., q_length, num_heads, qk_depth_per_head]`. 105 | key: keys for calculating attention with shape of 106 | `[batch..., kv_length, num_heads, qk_depth_per_head]`. 107 | value: values to be used in attention with shape of 108 | `[batch..., kv_length, num_heads, v_depth_per_head]`. 109 | bias: bias for the attention weights. This should be broadcastable to the 110 | shape `[batch..., num_heads, q_length, kv_length]`. 111 | This can be used for incorporating padding masks, proximity bias, etc. 112 | mask: mask for the attention weights. This should be broadcastable to the 113 | shape `[batch..., num_heads, q_length, kv_length]`. 114 | Attention weights are masked out if their corresponding mask value 115 | is `False`. 116 | query_chunk_size: int: query chunks size 117 | key_chunk_size: int: key chunks size 118 | bias_calc_fn: a bias calculation callback for each chunk, of form 119 | `(q_offset, k_offset, bias_chunk, attn_weights, calc_fn_data) -> bias`. 120 | This can be used for incorporating causal masks, padding masks, 121 | proximity bias, etc. 122 | mask_calc_fn: a mask calculation callback for each chunk, of form 123 | `(q_offset, k_offset, mask_chunk, attn_weights, calc_fn_data) -> mask`. 124 | This can be used for incorporating causal or other large masks. 125 | Attention weights are masked out if their corresponding mask value 126 | is `False`. 127 | weights_calc_fn: a general attn_weights callback for each chunk, of form 128 | `(q_offset, k_offset, attn_weights, calc_fn_data) -> attn_weights`. 129 | attn_weights has shape of 130 | `[batch..., q_chunk_size, num_heads, k_chunk_size]`. 131 | This can be used to implement complex weights processing in a memory 132 | efficient way. 133 | calc_fn_data: optional pure data to pass to each per-chunk call of 134 | bias_calc_fn, mask_calc_fn, and weights_calc_fn. 135 | precision: numerical precision of the computation see `jax.lax.Precision` 136 | for details. 
137 | Returns: 138 | Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. 139 | """ 140 | num_q, num_heads, q_features = query.shape[-3:] 141 | num_kv = key.shape[-3] 142 | 143 | def chunk_scanner(chunk_idx, _): 144 | query_chunk = jax.lax.dynamic_slice( 145 | query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), 146 | slice_sizes=tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) 147 | 148 | if mask is None: 149 | mask_chunk = None 150 | elif mask.shape[-2] == 1: 151 | mask_chunk = mask 152 | elif mask.shape[-2] == num_q: 153 | mask_chunk = jax.lax.dynamic_slice( 154 | mask, tuple([0] * (mask.ndim - 3)) + (0, chunk_idx, 0), 155 | slice_sizes=tuple(mask.shape[:-3]) + (mask.shape[-3], min(query_chunk_size, num_q), mask.shape[-1])) 156 | else: 157 | raise TypeError(f'mask.shape[-2] == {mask.shape[-2]} must broadcast with query.shape[-3] == {num_q}') 158 | 159 | if bias is None: 160 | bias_chunk = None 161 | elif bias.shape[-2] == 1: 162 | bias_chunk = bias 163 | elif bias.shape[-2] == num_q: 164 | bias_chunk = jax.lax.dynamic_slice( 165 | bias, tuple([0] * (bias.ndim - 3)) + (0, chunk_idx, 0), 166 | slice_sizes=tuple(bias.shape[:-3]) + (bias.shape[-3], min(query_chunk_size, num_q), bias.shape[-1])) 167 | else: 168 | raise TypeError(f'bias.shape[-2] == {bias.shape[-2]} must broadcast with query.shape[-3] == {num_q}') 169 | 170 | return (chunk_idx + query_chunk_size, 171 | _query_chunk_attention(chunk_idx, query_chunk, key, value, mask_chunk, bias_chunk, 172 | precision=precision, key_chunk_size=key_chunk_size, 173 | bias_calc_fn=bias_calc_fn, mask_calc_fn=mask_calc_fn, 174 | weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data)) 175 | 176 | _, res = jax.lax.scan( 177 | chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) 178 | return jnp.concatenate(res, axis=-3) 179 | -------------------------------------------------------------------------------- /memory_efficient_attention/attention_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.checkpoint import checkpoint 3 | from .utils import dynamic_slice, map_pt, scan 4 | import math 5 | 6 | 7 | def _query_chunk_attention(query_idx, query, key, value, 8 | mask, bias, key_chunk_size=4096, 9 | mask_calc_fn=None, 10 | bias_calc_fn=None, 11 | weights_calc_fn=None, 12 | calc_fn_data=None): 13 | num_kv, num_heads, k_features = key.shape[-3:] 14 | v_features = value.shape[-1] 15 | num_q = query.shape[-3] 16 | key_chunk_size = min(key_chunk_size, num_kv) 17 | query = query / math.sqrt(k_features) 18 | 19 | def summarize_chunk(key_idx, query, key, value, mask, bias): 20 | attn_weights = torch.einsum('...qhd,...khd->...qhk', query, key) 21 | if bias_calc_fn is not None: 22 | bias = bias_calc_fn(query_idx, key_idx, bias, attn_weights, calc_fn_data) 23 | if bias is not None: 24 | bias = torch.einsum('...hqk->...qhk', bias) 25 | attn_weights = attn_weights + bias 26 | if mask_calc_fn is not None: 27 | mask = mask_calc_fn(query_idx, key_idx, mask, attn_weights, calc_fn_data) 28 | if mask is not None: 29 | big_neg = torch.finfo(attn_weights.dtype).min 30 | big_neg = torch.tensor(big_neg, device=mask.device, dtype=torch.float32) 31 | mask = torch.einsum('...hqk->...qhk', mask) 32 | attn_weights = torch.where(mask, attn_weights, big_neg) 33 | if weights_calc_fn is not None: 34 | attn_weights = weights_calc_fn(query_idx, key_idx, attn_weights, calc_fn_data) 35 | max_score, _ = torch.max(attn_weights, -1,
keepdim=True) 36 | max_score = max_score.detach() 37 | exp_weights = torch.exp(attn_weights - max_score) 38 | exp_values = torch.einsum('...vhf,...qhv->...qhf', value, exp_weights) 39 | max_score = torch.einsum('...qhk->...qh', max_score) 40 | return exp_values, exp_weights.sum(dim=-1), max_score 41 | 42 | def chunk_scanner(chunk_idx): 43 | key_chunk = dynamic_slice(key, tuple([0] * (key.ndim - 3)) + (chunk_idx, 0, 0), 44 | tuple(key.shape[:-3]) + (key_chunk_size, num_heads, k_features)) 45 | value_chunk = dynamic_slice(value, tuple([0] * (value.ndim - 3)) + (chunk_idx, 0, 0), 46 | tuple(value.shape[:-3]) + (key_chunk_size, num_heads, v_features)) 47 | 48 | if bias is None: 49 | bias_chunk = None 50 | elif bias.shape[-1] == 1: 51 | bias_chunk = bias 52 | elif bias.shape[-1] == num_kv: 53 | bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, 0, chunk_idx), 54 | tuple(bias.shape[:-3]) + (bias.shape[-3], bias.shape[-2], key_chunk_size)) 55 | else: 56 | raise TypeError(f'bias.shape[-1] == {bias.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') 57 | 58 | if mask is None: 59 | mask_chunk = None 60 | elif mask.shape[-1] == 1: 61 | mask_chunk = mask 62 | elif mask.shape[-1] == num_kv: 63 | mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, 0, chunk_idx), 64 | tuple(mask.shape[:-3]) + (mask.shape[-3], mask.shape[-2], key_chunk_size)) 65 | else: 66 | raise TypeError(f'mask.shape[-1] == {mask.shape[-1]} must broadcast with key.shape[-3] == {num_kv}') 67 | 68 | return checkpoint(summarize_chunk, chunk_idx, query, key_chunk, value_chunk, mask_chunk, bias_chunk) 69 | 70 | chunk_values, chunk_weights, chunk_max = map_pt( 71 | chunk_scanner, xs=torch.arange(0, num_kv, key_chunk_size)) 72 | 73 | global_max, _ = torch.max(chunk_max, 0, keepdim=True) 74 | max_diffs = torch.exp(chunk_max - global_max) 75 | chunk_values *= torch.unsqueeze(max_diffs, -1) 76 | chunk_weights *= max_diffs 77 | 78 | all_values = chunk_values.sum(dim=0) 79 | all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) 80 | return all_values / all_weights 81 | 82 | 83 | def efficient_dot_product_attention(query, key, value, 84 | mask=None, bias=None, 85 | query_chunk_size=1024, 86 | key_chunk_size=4096, 87 | bias_calc_fn=None, 88 | mask_calc_fn=None, 89 | weights_calc_fn=None, 90 | calc_fn_data=None): 91 | """Computes efficient dot-product attention given query, key, and value. 92 | This is an efficient version of attention presented in 93 | https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. 94 | Note: query, key, value needn't have any batch dimensions. 95 | Args: 96 | query: queries for calculating attention with shape of 97 | `[batch..., q_length, num_heads, qk_depth_per_head]`. 98 | key: keys for calculating attention with shape of 99 | `[batch..., kv_length, num_heads, qk_depth_per_head]`. 100 | value: values to be used in attention with shape of 101 | `[batch..., kv_length, num_heads, v_depth_per_head]`. 102 | bias: bias for the attention weights. This should be broadcastable to the 103 | shape `[batch..., num_heads, q_length, kv_length]`. 104 | This can be used for incorporating padding masks, proximity bias, etc. 105 | mask: mask for the attention weights. This should be broadcastable to the 106 | shape `[batch..., num_heads, q_length, kv_length]`. 107 | Attention weights are masked out if their corresponding mask value 108 | is `False`.
109 | query_chunk_size: int: query chunks size 110 | key_chunk_size: int: key chunks size 111 | bias_calc_fn: a bias calculation callback for each chunk, of form 112 | `(q_offset, k_offset, bias_chunk, attn_weights, calc_fn_data) -> bias`. 113 | This can be used for incorporating causal masks, padding masks, 114 | proximity bias, etc. 115 | mask_calc_fn: a mask calculation callback for each chunk, of form 116 | `(q_offset, k_offset, mask_chunk, attn_weights, calc_fn_data) -> mask`. 117 | This can be used for incorporating causal or other large masks. 118 | Attention weights are masked out if their corresponding mask value 119 | is `False`. 120 | weights_calc_fn: a general attn_weights callback for each chunk, of form 121 | `(q_offset, k_offset, attn_weights, calc_fn_data) -> attn_weights`. 122 | attn_weights has shape of 123 | `[batch..., q_chunk_size, num_heads, k_chunk_size]`. 124 | This can be used to implement complex weights processing in a memory 125 | efficient way. 126 | calc_fn_data: optional pure data to pass to each per-chunk call of 127 | bias_calc_fn, mask_calc_fn, and weights_calc_fn. 128 | weights_calc_data: pure_data to pass with each call to weights_calc_fn 129 | Returns: 130 | Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. 131 | """ 132 | num_q, num_heads, q_features = query.shape[-3:] 133 | num_kv = key.shape[-3] 134 | 135 | def chunk_scanner(chunk_idx, _): 136 | query_chunk = dynamic_slice(query, tuple([0] * (query.ndim - 3)) + (chunk_idx, 0, 0), 137 | tuple(query.shape[:-3]) + (min(query_chunk_size, num_q), num_heads, q_features)) 138 | 139 | if mask is None: 140 | mask_chunk = None 141 | elif mask.shape[-2] == 1: 142 | mask_chunk = mask 143 | elif mask.shape[-2] == num_q: 144 | mask_chunk = dynamic_slice(mask, tuple([0] * (mask.ndim - 3)) + (0, chunk_idx, 0), 145 | tuple(mask.shape[:-3]) + (mask.shape[-3], min(query_chunk_size, num_q), mask.shape[-1])) 146 | else: 147 | raise TypeError(f'mask.shape[-2] == {mask.shape[-2]} must broadcast with query.shape[-3] == {num_q}') 148 | 149 | if bias is None: 150 | bias_chunk = None 151 | elif bias.shape[-2] == 1: 152 | bias_chunk = bias 153 | elif bias.shape[-2] == num_q: 154 | bias_chunk = dynamic_slice(bias, tuple([0] * (bias.ndim - 3)) + (0, chunk_idx, 0), 155 | tuple(bias.shape[:-3]) + (bias.shape[-3], min(query_chunk_size, num_q), bias.shape[-1])) 156 | else: 157 | raise TypeError(f'bias.shape[-2] == {bias.shape[-2]} must broadcast with query.shape[-3] == {num_q}') 158 | return (chunk_idx + query_chunk_size, 159 | _query_chunk_attention(chunk_idx, query_chunk, key, value, mask_chunk, bias_chunk, key_chunk_size=key_chunk_size, 160 | bias_calc_fn=bias_calc_fn, mask_calc_fn=mask_calc_fn, 161 | weights_calc_fn=weights_calc_fn, calc_fn_data=calc_fn_data)) 162 | 163 | _, res = scan(chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)) 164 | rl = [res[i] for i in range(res.shape[0])] 165 | return torch.cat(rl, dim=-3) 166 | -------------------------------------------------------------------------------- /memory_efficient_attention/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def dynamic_slice(x, starts, sizes): 6 | # start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i]) 7 | starts = [np.clip(starts[i], 0, x.shape[i] - sizes[i]) for i in range(len(starts))] 8 | for i, (start, size) in enumerate(zip(starts, sizes)): 9 | x = torch.index_select(x, i, 
torch.tensor(range(start, start + size), device=x.device)) 10 | return x 11 | 12 | 13 | def map_pt(f, xs): 14 | t = [f(x) for x in xs] 15 | return tuple(map(torch.stack, zip(*t))) 16 | 17 | 18 | def scan(f, init, xs, length=None): 19 | if xs is None: 20 | xs = [None] * length 21 | carry = init 22 | ys = [] 23 | for x in xs: 24 | carry, y = f(carry, x) 25 | ys.append(y) 26 | return carry, torch.stack(ys) 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='memory-efficient-attention', 5 | version='0.1.3', 6 | description='Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch', 7 | license='MIT', 8 | packages=find_packages(), 9 | author='Amin Rezaei', 10 | author_email='AminRezaei0x443@gmail.com', 11 | keywords=['attention', 'pytorch', 'jax'], 12 | url='https://github.com/AminRezaei0x443/memory-efficient-attention', 13 | install_requires=['numpy'], 14 | extras_require={ 15 | 'jax': ['jax'], 16 | 'torch': ['torch'], 17 | 'testing': ['jax', 'torch', 'flax'] 18 | } 19 | ) 20 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AminRezaei0x443/memory-efficient-attention/882c5e70d6b06b3215355b35435211c5b8caa75d/test/__init__.py -------------------------------------------------------------------------------- /test/test_computation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import jax, jax.numpy as jnp 4 | from memory_efficient_attention import efficient_dot_product_attention_pt, efficient_dot_product_attention_jax 5 | from flax.linen.attention import dot_product_attention 6 | from memory_efficient_attention.utils import dynamic_slice 7 | 8 | efficient_dot_product_attention_jax = jax.jit(efficient_dot_product_attention_jax, static_argnames=('mask_calc_fn', 'bias_calc_fn', 'weights_calc_fn')) 9 | 10 | class ComputationTest(unittest.TestCase): 11 | @staticmethod 12 | def data(): 13 | b = 8 14 | key = jax.random.PRNGKey(0) 15 | Qb = jax.random.uniform(key, (1, b, 128, 16, 8), dtype=jnp.float32) 16 | Kb = jax.random.uniform(key, (1, b, 128, 16, 8), dtype=jnp.float32) 17 | Vb = jax.random.uniform(key, (1, b, 128, 16, 8), dtype=jnp.float32) 18 | Mb = jax.random.uniform(key, (1, b, 16, 128, 128)) > 0.5 19 | Bb = jax.random.uniform(key, (1, b, 16, 128, 128), dtype=jnp.float32) / 100 20 | 21 | # calc_fn bias & mask 22 | def mask_calc_fn_jax(query_offset, key_offset, mask_chunk, attn_weights, MbBb): 23 | if MbBb is not None: 24 | Mb, Bb = MbBb 25 | return jax.lax.dynamic_slice(Mb, tuple([0] * (Mb.ndim - 2)) + (query_offset, key_offset), 26 | slice_sizes=tuple(Mb.shape[:-2]) + (attn_weights.shape[-3], attn_weights.shape[-1])) 27 | else: 28 | return mask_chunk 29 | def bias_calc_fn_jax(query_offset, key_offset, bias_chunk, attn_weights, MbBb): 30 | if MbBb is not None: 31 | Mb, Bb = MbBb 32 | return jax.lax.dynamic_slice(Bb, tuple([0] * (Bb.ndim - 2)) + (query_offset, key_offset), 33 | slice_sizes=tuple(Bb.shape[:-2]) + (attn_weights.shape[-3], attn_weights.shape[-1])) 34 | else: 35 | return bias_chunk 36 | def weights_calc_fn_jax(query_offset, key_offset, attn_weights, MbBb): 37 | if MbBb is not None: 38 | bias = bias_calc_fn_jax(query_offset, key_offset, None, 
attn_weights, MbBb) 39 | bias = jnp.einsum('...hqk->...qhk', bias) 40 | attn_weights = attn_weights + bias 41 | 42 | mask = mask_calc_fn_jax(query_offset, key_offset, None, attn_weights, MbBb) 43 | big_neg = jnp.finfo(attn_weights.dtype).min 44 | mask = jnp.einsum('...hqk->...qhk', mask) 45 | attn_weights = jnp.where(mask, attn_weights, big_neg) 46 | 47 | return attn_weights 48 | def mask_calc_fn_pt(query_offset, key_offset, mask_chunk, attn_weights, MbBb): 49 | if MbBb is not None: 50 | Mb, Bb = MbBb 51 | return dynamic_slice(torch.tensor(Mb.to_py()), tuple([0] * (Mb.ndim - 2)) + (query_offset, key_offset), 52 | tuple(Mb.shape[:-2]) + (attn_weights.shape[-3], attn_weights.shape[-1])) 53 | else: 54 | return mask_chunk 55 | def bias_calc_fn_pt(query_offset, key_offset, bias_chunk, attn_weights, MbBb): 56 | if MbBb is not None: 57 | Mb, Bb = MbBb 58 | return dynamic_slice(torch.tensor(Bb.to_py()), tuple([0] * (Bb.ndim - 2)) + (query_offset, key_offset), 59 | tuple(Bb.shape[:-2]) + (attn_weights.shape[-3], attn_weights.shape[-1])) 60 | else: 61 | return bias_chunk 62 | def weights_calc_fn_pt(query_offset, key_offset, attn_weights, MbBb): 63 | if MbBb is not None: 64 | bias = bias_calc_fn_pt(query_offset, key_offset, None, attn_weights, MbBb) 65 | bias = torch.einsum('...hqk->...qhk', bias) 66 | attn_weights = attn_weights + bias 67 | 68 | mask = mask_calc_fn_pt(query_offset, key_offset, None, attn_weights, MbBb) 69 | big_neg = torch.finfo(attn_weights.dtype).min 70 | big_neg = torch.tensor(big_neg, dtype=torch.float32) 71 | mask = torch.einsum('...hqk->...qhk', mask) 72 | attn_weights = torch.where(mask, attn_weights, big_neg) 73 | return attn_weights 74 | 75 | return [ 76 | ('broadcasting simple mask and bias', 77 | (Qb, Kb, Vb, Mb[:,:,:1,:1,:1], Bb[:,:,:1,:1,:1], dict(), None)), 78 | ('full mask and bias, negates memory savings', 79 | (Qb, Kb, Vb, Mb, Bb, dict(), None)), 80 | ('broadcasting mask and bias passed through callbacks', 81 | (Qb, Kb, Vb, Mb[:,:,:1,:1,:1], Bb[:,:,:1,:1,:1], dict( 82 | pt_mask=mask_calc_fn_pt, 83 | pt_bias=bias_calc_fn_pt, 84 | pt_weights=weights_calc_fn_pt, 85 | jax_mask=mask_calc_fn_jax, 86 | jax_bias=bias_calc_fn_jax, 87 | jax_weights=weights_calc_fn_jax, 88 | ), None)), 89 | ('mask and bias generated per-chunk by callbacks', 90 | (Qb, Kb, Vb, Mb, Bb, dict( 91 | pt_mask=mask_calc_fn_pt, 92 | pt_bias=bias_calc_fn_pt, 93 | jax_mask=mask_calc_fn_jax, 94 | jax_bias=bias_calc_fn_jax, 95 | ), (Mb, Bb))), 96 | ('mask and bias manually applied by custom callback', 97 | (Qb, Kb, Vb, Mb, Bb, dict( 98 | pt_weights=weights_calc_fn_pt, 99 | jax_weights=weights_calc_fn_jax, 100 | ), (Mb, Bb))), 101 | ('no mask nor bias', 102 | (Qb, Kb, Vb, None, None, dict(), None)), 103 | ] 104 | 105 | @staticmethod 106 | def calc_pt(data): 107 | Qb, Kb, Vb, Mb, Bb, Cf, Cb = data 108 | Qbt = torch.tensor(Qb.to_py(), requires_grad=True) 109 | Kbt = torch.tensor(Kb.to_py(), requires_grad=True) 110 | Vbt = torch.tensor(Vb.to_py(), requires_grad=True) 111 | if Mb is not None and Cb is None: 112 | Mbt = torch.tensor(Mb.to_py(), requires_grad=False) 113 | else: 114 | Mbt = None 115 | if Bb is not None and Cb is None: 116 | Bbt = torch.tensor(Bb.to_py(), requires_grad=False) 117 | else: 118 | Bbt = None 119 | Mf, Bf, Wf = Cf.get('pt_mask'), Cf.get('pt_bias'), Cf.get('pt_weights') 120 | return efficient_dot_product_attention_pt(Qbt, Kbt, Vbt, Mbt, Bbt, 121 | mask_calc_fn=Mf, bias_calc_fn=Bf, 122 | weights_calc_fn=Wf, calc_fn_data=Cb).detach().numpy() 123 | 124 | @staticmethod 125 | def 
calc_jax(data): 126 | Qb, Kb, Vb, Mb, Bb, Cf, Cb = data 127 | Mf, Bf, Wf = Cf.get('jax_mask'), Cf.get('jax_bias'), Cf.get('jax_weights') 128 | if Cb is not None: 129 | Mb, Bb = None, None 130 | return jnp.asarray(efficient_dot_product_attention_jax(Qb, Kb, Vb, Mb, Bb, 131 | mask_calc_fn=Mf, bias_calc_fn=Bf, 132 | weights_calc_fn=Wf, calc_fn_data=Cb)) 133 | 134 | @staticmethod 135 | def calc_flax(data): 136 | Qb, Kb, Vb, Mb, Bb, Cf, Pd = data 137 | return jnp.asarray(dot_product_attention(Qb, Kb, Vb, Bb, Mb)) 138 | 139 | def test_pt(self): 140 | for msg, data in ComputationTest.data(): 141 | with self.subTest(msg=msg): 142 | res_pt = ComputationTest.calc_pt(data) 143 | res_flax = ComputationTest.calc_flax(data) 144 | self.assertTrue(jnp.allclose(res_pt, res_flax)) 145 | 146 | def test_jax(self): 147 | for msg, data in ComputationTest.data(): 148 | with self.subTest(msg=msg): 149 | res_jax = ComputationTest.calc_jax(data) 150 | res_flax = ComputationTest.calc_flax(data) 151 | self.assertTrue(jnp.allclose(res_jax, res_flax)) 152 | 153 | def test_jax_and_pt(self): 154 | for msg, data in ComputationTest.data(): 155 | with self.subTest(msg=msg): 156 | res_pt = ComputationTest.calc_pt(data) 157 | res_jax = ComputationTest.calc_jax(data) 158 | res_flax = ComputationTest.calc_flax(data) 159 | self.assertTrue(jnp.allclose(res_pt, res_jax)) 160 | self.assertTrue(jnp.allclose(res_pt, res_flax)) 161 | self.assertTrue(jnp.allclose(res_jax, res_flax)) 162 | 163 | 164 | if __name__ == '__main__': 165 | unittest.main() 166 | --------------------------------------------------------------------------------
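
The README's important note leaves the chunk-size trade-off to the caller. Below is a minimal, hypothetical tuning sketch for the PyTorch entry point: the tensor shapes and the chunk-size grid are invented for illustration, and wall-clock timing with `time.perf_counter` is only a rough stand-in for the profiling you would do on your own hardware.

```python
import time
import torch
from memory_efficient_attention import efficient_dot_product_attention_pt

# Illustrative shapes: [batch, length, num_heads, depth_per_head].
query = torch.rand(2, 2048, 8, 64)
key = torch.rand(2, 2048, 8, 64)
value = torch.rand(2, 2048, 8, 64)

# Smaller chunks lower the peak memory but usually add runtime overhead;
# the grid below is arbitrary, so sweep values that fit your sequence lengths.
for q_chunk, k_chunk in [(1024, 4096), (512, 1024), (128, 256)]:
    start = time.perf_counter()
    out = efficient_dot_product_attention_pt(query, key, value,
                                             query_chunk_size=q_chunk,
                                             key_chunk_size=k_chunk)
    elapsed = time.perf_counter() - start
    print(f"query_chunk_size={q_chunk:5d} key_chunk_size={k_chunk:5d} "
          f"-> {elapsed:.3f}s, output shape {tuple(out.shape)}")
```

Chunk sizes larger than the actual sequence length are clamped internally, so overshooting them is harmless; the sweep simply degenerates to plain attention for that configuration.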
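The `mask_calc_fn`, `bias_calc_fn`, and `weights_calc_fn` callbacks documented in both attention modules let each chunk build its own mask or bias instead of materializing a full `[q_length, kv_length]` array, which is how the test suite exercises them. The sketch below is an illustrative, non-canonical causal mask for the PyTorch version; the callback name `causal_mask_fn` and the shapes are assumptions made for this example. Per the docstrings, the returned chunk must broadcast to `[..., num_heads, q_chunk, k_chunk]`, with `False` marking positions to mask out.

```python
import torch
from memory_efficient_attention import efficient_dot_product_attention_pt

def causal_mask_fn(q_offset, k_offset, mask_chunk, attn_weights, _data):
    # attn_weights arrives as [batch..., q_chunk, num_heads, k_chunk];
    # build the matching causal slice from the absolute chunk offsets.
    q_len, k_len = attn_weights.shape[-3], attn_weights.shape[-1]
    q_pos = torch.arange(int(q_offset), int(q_offset) + q_len).unsqueeze(-1)
    k_pos = torch.arange(int(k_offset), int(k_offset) + k_len).unsqueeze(0)
    # True where a query may attend (key position <= query position);
    # the leading singleton dimension broadcasts over heads.
    return (k_pos <= q_pos).unsqueeze(0)

query = torch.rand(2, 1024, 8, 64)
key = torch.rand(2, 1024, 8, 64)
value = torch.rand(2, 1024, 8, 64)

# No full-size mask tensor is ever passed in; each chunk's slice of the
# causal mask is generated on the fly inside the callback.
out = efficient_dot_product_attention_pt(query, key, value,
                                         mask_calc_fn=causal_mask_fn,
                                         query_chunk_size=256,
                                         key_chunk_size=256)
print(out.shape)  # torch.Size([2, 1024, 8, 64])
```

The same idea carries over to the Jax version, where the callback would slice or construct the chunk with `jax.lax.dynamic_slice`-style offsets, as the test file's `mask_calc_fn_jax` does.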