├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── README.md └── enwik8.gz ├── memory_efficient_attention_pytorch ├── __init__.py ├── autoregressive_wrapper.py ├── cosine_sim_flash_attention.py ├── flash_attention.py ├── memory_efficient_attention.py ├── memory_efficient_cosine_sim_attention.py ├── reversible.py └── transformer.py ├── setup.cfg ├── setup.py ├── tests └── test.py └── train.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.8, 3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install pytest 30 | python -m pip install pytest torch==1.10.0 31 | - name: Test with pytest 32 | run: | 33 | python setup.py test 34 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Memory Efficient Attention Pytorch (obsolete) 2 | 3 | Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention. 4 | 5 | This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao with his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and building long-context transformers. 
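Besides the `Attention` module shown under Usage below, the package also exports a functional `memory_efficient_attention` (see `memory_efficient_attention_pytorch/__init__.py`) that operates directly on query / key / value tensors. The snippet below is a usage sketch rather than official documentation; the `(batch, heads, seq, dim_head)` shapes and the default bucket sizes are inferred from the source in `memory_efficient_attention.py`.

```python
import torch
from memory_efficient_attention_pytorch import memory_efficient_attention

# query / key / value tensors, shaped (batch, heads, seq, dim_head)
q = torch.randn(1, 8, 4096, 64)
k = torch.randn(1, 8, 4096, 64)
v = torch.randn(1, 8, 4096, 64)

mask = torch.ones(1, 4096).bool()   # boolean padding mask over the key sequence

out = memory_efficient_attention(
    q, k, v,
    mask = mask,
    causal = True,          # apply an autoregressive (causal) mask
    q_bucket_size = 512,    # queries are processed in chunks of 512
    k_bucket_size = 1024    # keys / values are processed in chunks of 1024
)  # (1, 8, 4096, 64)
```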
6 | 7 | Update: from now on, you should just be using the `F.scaled_dot_product_attention` function in Pytorch 2.0 for built-in Flash Attention v1 support - or use Flash Attention v2 at the official repository 8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install memory-efficient-attention-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | For autoregressive language model 18 | 19 | ```python 20 | import torch 21 | from memory_efficient_attention_pytorch import Attention 22 | 23 | attn = Attention( 24 | dim = 512, 25 | dim_head = 64, # dimension per head 26 | heads = 8, # number of attention heads 27 | causal = True, # autoregressive or not 28 | memory_efficient = True, # whether to use memory efficient attention (can be turned off to test against normal attention) 29 | q_bucket_size = 1024, # bucket size along queries dimension 30 | k_bucket_size = 2048 # bucket size along key / values dimension 31 | ).cuda() 32 | 33 | x = torch.randn(1, 65536, 512).cuda() 34 | out = attn(x) # (1, 65536, 512) 35 | ``` 36 | 37 | Cross attention 38 | 39 | ```python 40 | import torch 41 | from memory_efficient_attention_pytorch import Attention 42 | 43 | cross_attn = Attention( 44 | dim = 512, 45 | dim_head = 64, 46 | heads = 8, 47 | memory_efficient = True, 48 | q_bucket_size = 1024, 49 | k_bucket_size = 2048 50 | ).cuda() 51 | 52 | x = torch.randn(1, 65536, 512).cuda() 53 | context = torch.randn(1, 65536, 512).cuda() 54 | mask = torch.ones(1, 65536).bool().cuda() 55 | 56 | out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) 57 | ``` 58 | 59 | ## Citations 60 | 61 | ```bibtex 62 | @misc{rabe2021selfattention, 63 | title = {Self-attention Does Not Need $O(n^2)$ Memory}, 64 | author = {Markus N. Rabe and Charles Staats}, 65 | year = {2021}, 66 | eprint = {2112.05682}, 67 | archivePrefix = {arXiv}, 68 | primaryClass = {cs.LG} 69 | } 70 | ``` 71 | 72 | ```bibtex 73 | @misc{liu2021swin, 74 | title = {Swin Transformer V2: Scaling Up Capacity and Resolution}, 75 | author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo}, 76 | year = {2021}, 77 | eprint = {2111.09883}, 78 | archivePrefix = {arXiv}, 79 | primaryClass = {cs.CV} 80 | } 81 | ``` 82 | 83 | ```bibtex 84 | @article{Dao2022FlashAttentionFA, 85 | title = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, 86 | author = {Tri Dao and Daniel Y. 
Fu and Stefano Ermon and Atri Rudra and Christopher Ré}, 87 | journal = {ArXiv}, 88 | year = {2022}, 89 | volume = {abs/2205.14135} 90 | } 91 | ``` 92 | 93 | ```bibtex 94 | @article{dao2023flashattention2, 95 | title = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, 96 | author = {Dao, Tri}, 97 | year = {2023} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/memory-efficient-attention-pytorch/d54f391370ecbf843a871f0e260425d076995550/data/enwik8.gz -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention 2 | from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention 3 | from memory_efficient_attention_pytorch.flash_attention import FlashAttention 4 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/autoregressive_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | # helper function 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def eval_decorator(fn): 11 | def inner(model, *args, **kwargs): 12 | was_training = model.training 13 | model.eval() 14 | out = fn(model, *args, **kwargs) 15 | model.train(was_training) 16 | return out 17 | return inner 18 | 19 | # top k filtering 20 | 21 | def top_k(logits, thres = 0.9): 22 | k = int((1 - thres) * logits.shape[-1]) 23 | val, ind = torch.topk(logits, k) 24 | probs = torch.full_like(logits, float('-inf')) 25 | probs.scatter_(1, ind, val) 26 | return probs 27 | 28 | class AutoregressiveWrapper(nn.Module): 29 | def __init__(self, net, pad_value = 0): 30 | super().__init__() 31 | self.pad_value = pad_value 32 | self.net = net 33 | self.max_seq_len = net.max_seq_len 34 | 35 | @torch.no_grad() 36 | @eval_decorator 37 | def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs): 38 | b, t, device = *start_tokens.shape, start_tokens.device 39 | 40 | out = start_tokens 41 | 42 | for _ in range(seq_len): 43 | x = out[:, -self.max_seq_len:] 44 | 45 | logits = self.net(x, **kwargs)[:, -1, :] 46 | 47 | filtered_logits = top_k(logits, thres = filter_thres) 48 | probs = F.softmax(filtered_logits / temperature, dim=-1) 49 | 50 | sample = torch.multinomial(probs, 1) 51 | 52 | out = torch.cat((out, sample), dim=-1) 53 | 54 | if exists(eos_token): 55 | is_eos_token = (out == eos_token) 56 | 57 | if is_eos_token.any(dim = -1).all(): 58 | # mask out everything after the eos tokens 59 | shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1)) 60 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 61 | out = out.masked_fill(mask,
self.pad_value) 62 | break 63 | 64 | out = out[:, t:] 65 | return out 66 | 67 | def forward(self, x, **kwargs): 68 | x_inp, x_labels = x[:, :-1], x[:, 1:] 69 | return self.net(x_inp, labels = x_labels, **kwargs) 70 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/cosine_sim_flash_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | from torch.autograd.function import Function 7 | 8 | from einops import rearrange 9 | 10 | # constants 11 | 12 | EPSILON = 1e-6 13 | 14 | # helper functions 15 | 16 | def exists(val): 17 | return val is not None 18 | 19 | def default(val, d): 20 | return val if exists(val) else d 21 | 22 | def l2norm(t): 23 | return F.normalize(t, dim = -1) 24 | 25 | # flash attention forwards and backwards 26 | 27 | class FlashAttentionFunction(Function): 28 | @staticmethod 29 | @torch.no_grad() 30 | def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size): 31 | device = q.device 32 | max_neg_value = -torch.finfo(q.dtype).max 33 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 34 | 35 | k_len = k.shape[-2] # in cosine sim attention, row sums are bounded by key / values sequence length 36 | 37 | o = torch.zeros_like(q) 38 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device) 39 | 40 | if not exists(mask): 41 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) 42 | else: 43 | mask = rearrange(mask, 'b n -> b 1 1 n') 44 | mask = mask.split(q_bucket_size, dim = -1) 45 | 46 | row_splits = zip( 47 | q.split(q_bucket_size, dim = -2), 48 | o.split(q_bucket_size, dim = -2), 49 | mask, 50 | all_row_sums.split(q_bucket_size, dim = -2), 51 | ) 52 | 53 | for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits): 54 | q_start_index = ind * q_bucket_size - qk_len_diff 55 | 56 | col_splits = zip( 57 | k.split(k_bucket_size, dim = -2), 58 | v.split(k_bucket_size, dim = -2), 59 | ) 60 | 61 | for k_ind, (kc, vc) in enumerate(col_splits): 62 | k_start_index = k_ind * k_bucket_size 63 | 64 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale 65 | 66 | if exists(row_mask): 67 | attn_weights.masked_fill_(~row_mask, max_neg_value) 68 | 69 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 70 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 71 | attn_weights.masked_fill_(causal_mask, max_neg_value) 72 | 73 | attn_weights -= scale 74 | exp_weights = torch.exp(attn_weights) 75 | 76 | if exists(row_mask): 77 | exp_weights.masked_fill_(~row_mask, 0.) 78 | 79 | block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON) 80 | 81 | exp_values = einsum('... i j, ... j d -> ... 
i d', exp_weights, vc) 82 | 83 | oc.add_(exp_values / k_len) 84 | row_sums.add_(block_row_sums) 85 | 86 | ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size) 87 | ctx.save_for_backward(q, k, v, o, all_row_sums) 88 | 89 | o.mul_(k_len / all_row_sums) 90 | 91 | return o 92 | 93 | @staticmethod 94 | @torch.no_grad() 95 | def backward(ctx, do): 96 | scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args 97 | q, k, v, o, l = ctx.saved_tensors 98 | 99 | device = q.device 100 | 101 | max_neg_value = -torch.finfo(q.dtype).max 102 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 103 | 104 | dq = torch.zeros_like(q) 105 | dk = torch.zeros_like(k) 106 | dv = torch.zeros_like(v) 107 | 108 | row_splits = zip( 109 | q.split(q_bucket_size, dim = -2), 110 | o.split(q_bucket_size, dim = -2), 111 | do.split(q_bucket_size, dim = -2), 112 | mask, 113 | l.split(q_bucket_size, dim = -2), 114 | dq.split(q_bucket_size, dim = -2) 115 | ) 116 | 117 | for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits): 118 | q_start_index = ind * q_bucket_size - qk_len_diff 119 | 120 | col_splits = zip( 121 | k.split(k_bucket_size, dim = -2), 122 | v.split(k_bucket_size, dim = -2), 123 | dk.split(k_bucket_size, dim = -2), 124 | dv.split(k_bucket_size, dim = -2), 125 | ) 126 | 127 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): 128 | k_start_index = k_ind * k_bucket_size 129 | 130 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale 131 | 132 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 133 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 134 | attn_weights.masked_fill_(causal_mask, max_neg_value) 135 | 136 | exp_attn_weights = torch.exp(attn_weights - scale) 137 | 138 | if exists(row_mask): 139 | exp_attn_weights.masked_fill_(~row_mask, 0.) 140 | 141 | p = exp_attn_weights / lc 142 | 143 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) 144 | dp = einsum('... i d, ... j d -> ... i j', doc, vc) 145 | 146 | D = (doc * oc).sum(dim = -1, keepdims = True) 147 | ds = p * scale * (dp - D) 148 | 149 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) 150 | dk_chunk = einsum('... i j, ... i d -> ... 
j d', ds, qc) 151 | 152 | dqc.add_(dq_chunk) 153 | dkc.add_(dk_chunk) 154 | dvc.add_(dv_chunk) 155 | 156 | return dq, dk, dv, None, None, None, None, None 157 | 158 | # main class 159 | 160 | # flash attention for cosine sim attention 161 | # a bit less complicated, as no more need to worry about softmax numerical stability, and row sums are bounded 162 | 163 | class FlashAttention(nn.Module): 164 | def __init__( 165 | self, 166 | *, 167 | dim, 168 | scale = 16, 169 | heads = 8, 170 | dim_head = 64, 171 | causal = False, 172 | q_bucket_size = 512, 173 | k_bucket_size = 1024 174 | ): 175 | super().__init__() 176 | self.heads = heads 177 | 178 | self.scale = scale 179 | self.causal = causal 180 | 181 | inner_dim = heads * dim_head 182 | 183 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 184 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 185 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 186 | 187 | # memory efficient attention related parameters 188 | # can be overriden on forward 189 | self.q_bucket_size = q_bucket_size 190 | self.k_bucket_size = k_bucket_size 191 | 192 | def forward( 193 | self, 194 | x, 195 | context = None, 196 | mask = None, 197 | q_bucket_size = None, 198 | k_bucket_size = None, 199 | ): 200 | q_bucket_size = default(q_bucket_size, self.q_bucket_size) 201 | k_bucket_size = default(k_bucket_size, self.k_bucket_size) 202 | 203 | h = self.heads 204 | context = default(context, x) 205 | 206 | q = self.to_q(x) 207 | k, v = self.to_kv(context).chunk(2, dim = -1) 208 | 209 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 210 | 211 | q, k = map(l2norm, (q, k)) 212 | 213 | out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size) 214 | 215 | out = rearrange(out, 'b h n d -> b n (h d)') 216 | return self.to_out(out) 217 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/flash_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | from torch import nn, einsum 5 | from torch.autograd.function import Function 6 | 7 | from einops import rearrange 8 | 9 | # constants 10 | 11 | EPSILON = 1e-10 12 | 13 | # helper functions 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def default(val, d): 19 | return val if exists(val) else d 20 | 21 | # flash attention forwards and backwards 22 | 23 | # flash attention v1 - https://arxiv.org/abs/2205.14135 24 | # flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf 25 | 26 | class FlashAttentionFunction(Function): 27 | @staticmethod 28 | @torch.no_grad() 29 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): 30 | """ Algorithm 1 in the v2 paper """ 31 | 32 | device = q.device 33 | max_neg_value = -torch.finfo(q.dtype).max 34 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 35 | 36 | o = torch.zeros_like(q) 37 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device) 38 | all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device) 39 | 40 | scale = (q.shape[-1] ** -0.5) 41 | 42 | num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size) 43 | num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size) 44 | 45 | if exists(mask) and mask.ndim == 2: 46 | mask = rearrange(mask, 'b n -> b 1 1 n') 47 | 48 | if not exists(mask): 49 | col_masks = (None,) * num_col_tiles 50 | mask = (col_masks,) * 
num_row_tiles 51 | else: 52 | mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2) 53 | mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask) 54 | 55 | row_splits = zip( 56 | q.split(q_bucket_size, dim = -2), 57 | o.split(q_bucket_size, dim = -2), 58 | mask, 59 | all_row_sums.split(q_bucket_size, dim = -2), 60 | all_row_maxes.split(q_bucket_size, dim = -2), 61 | ) 62 | 63 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): 64 | q_start_index = ind * q_bucket_size - qk_len_diff 65 | 66 | col_splits = zip( 67 | k.split(k_bucket_size, dim = -2), 68 | v.split(k_bucket_size, dim = -2), 69 | row_mask 70 | ) 71 | 72 | for k_ind, (kc, vc, col_mask) in enumerate(col_splits): 73 | k_start_index = k_ind * k_bucket_size 74 | 75 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale 76 | 77 | if exists(col_mask): 78 | attn_weights.masked_fill_(~col_mask, max_neg_value) 79 | 80 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 81 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 82 | attn_weights.masked_fill_(causal_mask, max_neg_value) 83 | 84 | block_row_maxes = attn_weights.amax(dim = -1, keepdims = True) 85 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes) 86 | 87 | exp_weights = torch.exp(attn_weights - new_row_maxes) 88 | 89 | if exists(col_mask): 90 | exp_weights.masked_fill_(~col_mask, 0.) 91 | 92 | block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON) 93 | 94 | exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) 95 | 96 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) 97 | 98 | new_row_sums = exp_row_max_diff * row_sums + block_row_sums 99 | 100 | oc.mul_(exp_row_max_diff).add_(exp_values) 101 | 102 | row_maxes.copy_(new_row_maxes) 103 | row_sums.copy_(new_row_sums) 104 | 105 | oc.div_(row_sums) 106 | 107 | lse = all_row_sums.log() + all_row_maxes 108 | 109 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) 110 | ctx.save_for_backward(q, k, v, o, lse) 111 | 112 | return o 113 | 114 | @staticmethod 115 | @torch.no_grad() 116 | def backward(ctx, do): 117 | """ Algorithm 2 in the v2 paper """ 118 | 119 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args 120 | q, k, v, o, lse = ctx.saved_tensors 121 | 122 | device = q.device 123 | 124 | max_neg_value = -torch.finfo(q.dtype).max 125 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) 126 | 127 | dq = torch.zeros_like(q) 128 | dk = torch.zeros_like(k) 129 | dv = torch.zeros_like(v) 130 | 131 | row_splits = zip( 132 | q.split(q_bucket_size, dim = -2), 133 | o.split(q_bucket_size, dim = -2), 134 | do.split(q_bucket_size, dim = -2), 135 | mask, 136 | lse.split(q_bucket_size, dim = -2), 137 | dq.split(q_bucket_size, dim = -2) 138 | ) 139 | 140 | for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits): 141 | q_start_index = ind * q_bucket_size - qk_len_diff 142 | 143 | col_splits = zip( 144 | k.split(k_bucket_size, dim = -2), 145 | v.split(k_bucket_size, dim = -2), 146 | dk.split(k_bucket_size, dim = -2), 147 | dv.split(k_bucket_size, dim = -2), 148 | row_mask 149 | ) 150 | 151 | for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits): 152 | k_start_index = k_ind * k_bucket_size 153 | 154 | attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale 155 | 156 | if causal and q_start_index < (k_start_index + k_bucket_size - 1): 157 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 158 | attn_weights.masked_fill_(causal_mask, max_neg_value) 159 | 160 | p = torch.exp(attn_weights - lsec) 161 | 162 | if exists(col_mask): 163 | p.masked_fill_(~col_mask, 0.) 164 | 165 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) 166 | dp = einsum('... i d, ... j d -> ... i j', doc, vc) 167 | 168 | D = (doc * oc).sum(dim = -1, keepdims = True) 169 | ds = p * scale * (dp - D) 170 | 171 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) 172 | dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) 173 | 174 | dqc.add_(dq_chunk) 175 | dkc.add_(dk_chunk) 176 | dvc.add_(dv_chunk) 177 | 178 | return dq, dk, dv, None, None, None, None 179 | 180 | # main class 181 | 182 | # just flash attention in plain pytorch 183 | # it will be way slower than implementing it in CUDA 184 | # for tinkering and educational purposes 185 | 186 | class FlashAttention(nn.Module): 187 | def __init__( 188 | self, 189 | *, 190 | dim, 191 | heads = 8, 192 | dim_head = 64, 193 | causal = False, 194 | q_bucket_size = 512, 195 | k_bucket_size = 1024 196 | ): 197 | super().__init__() 198 | self.heads = heads 199 | self.causal = causal 200 | 201 | inner_dim = heads * dim_head 202 | 203 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 204 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 205 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 206 | 207 | # memory efficient attention related parameters 208 | # can be overriden on forward 209 | self.q_bucket_size = q_bucket_size 210 | self.k_bucket_size = k_bucket_size 211 | 212 | def forward( 213 | self, 214 | x, 215 | context = None, 216 | mask = None, 217 | q_bucket_size = None, 218 | k_bucket_size = None, 219 | ): 220 | q_bucket_size = default(q_bucket_size, self.q_bucket_size) 221 | k_bucket_size = default(k_bucket_size, self.k_bucket_size) 222 | 223 | h = self.heads 224 | context = default(context, x) 225 | 226 | q = self.to_q(x) 227 | k, v = self.to_kv(context).chunk(2, dim = -1) 228 | 229 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 230 | 231 | out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size) 232 | 233 | out = rearrange(out, 'b h n d -> b n (h d)') 234 | return self.to_out(out) 235 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/memory_efficient_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from torch import nn, einsum 4 | from torch.utils.checkpoint import checkpoint 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | # helper functions 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def default(val, d): 15 | return val if exists(val) else d 16 | 17 | # regular attention 18 | 19 | def attention( 20 | q, k, v, 21 | mask = None, 22 | causal = False, 23 | attn_bias = None, 24 | **kwargs 25 | ): 26 | scale = q.shape[-1] ** -0.5 27 | q = q * scale 28 | 29 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 30 | 31 | if exists(attn_bias): 32 | sim = sim + attn_bias 33 | 34 | mask_value = -torch.finfo(sim.dtype).max 35 | 36 | if exists(mask): 37 | if mask.ndim == 2: 38 | mask = rearrange(mask, 'b j -> b 1 
1 j') 39 | sim = sim.masked_fill(~mask, mask_value) 40 | 41 | if causal: 42 | i, j = sim.shape[-2:] 43 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1) 44 | sim = sim.masked_fill(mask, mask_value) 45 | 46 | sim = sim - sim.amax(dim = -1, keepdim = True).detach() 47 | attn = sim.softmax(dim = -1) 48 | 49 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 50 | return out 51 | 52 | # memory efficient attention 53 | 54 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout): 55 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device 56 | 57 | weight = einsum('b h i d, b h j d -> b h i j', q, k) 58 | 59 | if exists(attn_bias_chunk): 60 | weight = weight + attn_bias_chunk 61 | 62 | mask_value = -torch.finfo(weight.dtype).max 63 | 64 | if exists(mask): 65 | mask = rearrange(mask, 'b j -> b 1 1 j') 66 | weight = weight.masked_fill(~mask, mask_value) 67 | 68 | if causal and q_start_index < (k_start_index + k_chunk_size - 1): 69 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 70 | weight = weight.masked_fill(causal_mask, mask_value) 71 | 72 | weight_max = weight.amax(dim = -1, keepdim = True).detach() 73 | weight = weight - weight_max 74 | 75 | exp_weight = weight.exp() 76 | 77 | exp_weight = F.dropout(exp_weight, p = dropout) 78 | 79 | weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v) 80 | 81 | return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...') 82 | 83 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk) 84 | 85 | def memory_efficient_attention( 86 | q, k, v, 87 | mask = None, 88 | causal = False, 89 | attn_bias = None, 90 | q_bucket_size = 512, 91 | k_bucket_size = 1024, 92 | eps = 1e-8, 93 | dropout = 0., 94 | training = False 95 | ): 96 | scale = q.shape[-1] ** -0.5 97 | q = q * scale 98 | 99 | # function 100 | 101 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad 102 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk 103 | 104 | # chunk all the inputs 105 | 106 | q_chunks = q.split(q_bucket_size, dim = -2) 107 | k_chunks = k.split(k_bucket_size, dim = -2) 108 | v_chunks = v.split(k_bucket_size, dim = -2) 109 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks)) 110 | 111 | if exists(attn_bias): 112 | i, j = attn_bias.shape[-2:] 113 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2) 114 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks)) 115 | 116 | # loop through all chunks and accumulate 117 | 118 | out = [] 119 | for q_index, q_chunk in enumerate(q_chunks): 120 | exp_weights = [] 121 | weighted_values = [] 122 | weight_maxes = [] 123 | 124 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)): 125 | q_start_index = q_index * q_bucket_size 126 | k_start_index = k_index * k_bucket_size 127 | 128 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1): 129 | # if chunk is to be all masked out causally, skip 130 | continue 131 | 132 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None 133 | 134 | exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( 135 | q_chunk, 136 | k_chunk, 137 | v_chunk, 138 | mask_chunk, 139 | 
attn_bias_chunk, 140 | causal, 141 | (q_start_index, k_start_index), 142 | dropout if training else 0. 143 | ) 144 | 145 | exp_weights.append(exp_weight_chunk) 146 | weighted_values.append(weighted_value_chunk) 147 | weight_maxes.append(weight_max_chunk) 148 | 149 | weight_maxes = torch.stack(weight_maxes, dim = -1) 150 | 151 | weighted_values = torch.stack(weighted_values, dim = -1) 152 | exp_weights = torch.stack(exp_weights, dim = -1) 153 | 154 | global_max = weight_maxes.amax(dim = -1, keepdim = True) 155 | renorm_factor = (weight_maxes - global_max).exp().detach() 156 | 157 | exp_weights = exp_weights * renorm_factor 158 | weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c') 159 | 160 | all_values = weighted_values.sum(dim = -1) 161 | all_weights = exp_weights.sum(dim = -1) 162 | 163 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps) 164 | out.append(normalized_values) 165 | 166 | return torch.cat(out, dim = -2) 167 | 168 | # main class 169 | 170 | class Attention(nn.Module): 171 | def __init__( 172 | self, 173 | *, 174 | dim, 175 | heads = 8, 176 | dim_head = 64, 177 | dropout = 0., 178 | causal = False, 179 | memory_efficient = False, 180 | q_bucket_size = 512, 181 | k_bucket_size = 1024 182 | ): 183 | super().__init__() 184 | self.heads = heads 185 | self.causal = causal 186 | self.dropout = dropout 187 | inner_dim = heads * dim_head 188 | 189 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 190 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 191 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 192 | 193 | # memory efficient attention related parameters 194 | # can be overriden on forward 195 | self.memory_efficient = memory_efficient 196 | self.q_bucket_size = q_bucket_size 197 | self.k_bucket_size = k_bucket_size 198 | 199 | def forward( 200 | self, 201 | x, 202 | context = None, 203 | mask = None, 204 | attn_bias = None, 205 | memory_efficient = None, 206 | q_bucket_size = None, 207 | k_bucket_size = None, 208 | ): 209 | memory_efficient = default(memory_efficient, self.memory_efficient) 210 | q_bucket_size = default(q_bucket_size, self.q_bucket_size) 211 | k_bucket_size = default(k_bucket_size, self.k_bucket_size) 212 | 213 | h = self.heads 214 | context = default(context, x) 215 | 216 | q = self.to_q(x) 217 | k, v = self.to_kv(context).chunk(2, dim = -1) 218 | 219 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 220 | 221 | attn_fn = attention if not memory_efficient else memory_efficient_attention 222 | 223 | out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, 224 | k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training) 225 | 226 | out = rearrange(out, 'b h n d -> b n (h d)') 227 | return self.to_out(out) 228 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from torch import nn, einsum 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | from einops import rearrange 9 | 10 | # helper functions 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | def l2norm(t): 19 | return F.normalize(t, dim = -1) 20 
| 21 | # regular attention 22 | 23 | def attention( 24 | q, k, v, 25 | mask = None, 26 | causal = False, 27 | attn_bias = None, 28 | **kwargs 29 | ): 30 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 31 | 32 | if exists(attn_bias): 33 | sim = sim + attn_bias 34 | 35 | mask_value = -torch.finfo(sim.dtype).max 36 | 37 | if exists(mask): 38 | mask = rearrange(mask, 'b j -> b 1 1 j') 39 | sim = sim.masked_fill(~mask, mask_value) 40 | 41 | if causal: 42 | i, j = sim.shape[-2:] 43 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1) 44 | sim = sim.masked_fill(mask, mask_value) 45 | 46 | attn = sim.softmax(dim = -1) 47 | 48 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 49 | return out 50 | 51 | # memory efficient attention 52 | 53 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices): 54 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device 55 | 56 | weight = einsum('b h i d, b h j d -> b h i j', q, k) 57 | 58 | if exists(attn_bias_chunk): 59 | weight = weight + attn_bias_chunk 60 | 61 | mask_value = -torch.finfo(weight.dtype).max 62 | 63 | if exists(mask): 64 | mask = rearrange(mask, 'b j -> b 1 1 j') 65 | weight = weight.masked_fill(~mask, mask_value) 66 | 67 | if causal and q_start_index < (k_start_index + k_chunk_size - 1): 68 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1) 69 | weight = weight.masked_fill(causal_mask, mask_value) 70 | 71 | exp_weight = weight.exp() 72 | weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v) 73 | 74 | return exp_weight.sum(dim = -1), weighted_value 75 | 76 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk) 77 | 78 | def numerically_unstable_memory_efficient_attention( 79 | q, k, v, 80 | mask = None, 81 | causal = False, 82 | attn_bias = None, 83 | q_bucket_size = 512, 84 | k_bucket_size = 1024, 85 | eps = 1e-8 86 | ): 87 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad 88 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk 89 | 90 | # chunk all the inputs 91 | 92 | q_chunks = q.split(q_bucket_size, dim = -2) 93 | k_chunks = k.split(k_bucket_size, dim = -2) 94 | v_chunks = v.split(k_bucket_size, dim = -2) 95 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks)) 96 | 97 | if exists(attn_bias): 98 | i, j = attn_bias.shape[-2:] 99 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2) 100 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks)) 101 | 102 | # loop through all chunks and accumulate 103 | 104 | out = [] 105 | for q_index, q_chunk in enumerate(q_chunks): 106 | q_start_index = q_index * q_bucket_size 107 | exp_weights = [] 108 | weighted_values = [] 109 | 110 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)): 111 | k_start_index = k_index * k_bucket_size 112 | 113 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1): 114 | # if chunk is to be all masked out causally, skip 115 | continue 116 | 117 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None 118 | 119 | exp_weight_chunk, weighted_value_chunk = summarize_qkv_fn( 120 | q_chunk, 121 | k_chunk, 122 | v_chunk, 123 | mask_chunk, 124 | attn_bias_chunk, 125 | causal, 126 | 
(q_start_index, k_start_index) 127 | ) 128 | 129 | exp_weights.append(exp_weight_chunk) 130 | weighted_values.append(weighted_value_chunk) 131 | 132 | all_values = sum(weighted_values) 133 | all_weights = sum(exp_weights) 134 | 135 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps) 136 | out.append(normalized_values) 137 | 138 | return torch.cat(out, dim = -2) 139 | 140 | # main class 141 | 142 | class CosineSimAttention(nn.Module): 143 | def __init__( 144 | self, 145 | *, 146 | dim, 147 | seq_len, 148 | heads = 8, 149 | dim_head = 64, 150 | dropout = 0., 151 | causal = False, 152 | memory_efficient = False, 153 | q_bucket_size = 512, 154 | k_bucket_size = 1024 155 | ): 156 | super().__init__() 157 | self.heads = heads 158 | self.causal = causal 159 | 160 | inner_dim = heads * dim_head 161 | 162 | scale_init_value = -math.log(math.log2(seq_len ** 2 - seq_len)) 163 | self.scale = nn.Parameter(torch.full((1, heads, 1, 1), scale_init_value)) 164 | 165 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 166 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 167 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 168 | 169 | # memory efficient attention related parameters 170 | # can be overriden on forward 171 | self.memory_efficient = memory_efficient 172 | self.q_bucket_size = q_bucket_size 173 | self.k_bucket_size = k_bucket_size 174 | 175 | def forward( 176 | self, 177 | x, 178 | context = None, 179 | mask = None, 180 | attn_bias = None, 181 | memory_efficient = None, 182 | q_bucket_size = None, 183 | k_bucket_size = None, 184 | ): 185 | memory_efficient = default(memory_efficient, self.memory_efficient) 186 | q_bucket_size = default(q_bucket_size, self.q_bucket_size) 187 | k_bucket_size = default(k_bucket_size, self.k_bucket_size) 188 | 189 | h = self.heads 190 | context = default(context, x) 191 | 192 | q = self.to_q(x) 193 | k, v = self.to_kv(context).chunk(2, dim = -1) 194 | 195 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 196 | 197 | q, k = map(l2norm, (q, k)) 198 | 199 | q = q * self.scale.exp() 200 | 201 | attn_fn = attention if not memory_efficient else numerically_unstable_memory_efficient_attention 202 | 203 | out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size) 204 | 205 | out = rearrange(out, 'b h n d -> b n (h d)') 206 | return self.to_out(out) 207 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | # following example for saving and setting rng here 
https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 20 | class Deterministic(nn.Module): 21 | def __init__(self, net): 22 | super().__init__() 23 | self.net = net 24 | self.cpu_state = None 25 | self.cuda_in_fwd = None 26 | self.gpu_devices = None 27 | self.gpu_states = None 28 | 29 | def record_rng(self, *args): 30 | self.cpu_state = torch.get_rng_state() 31 | if torch.cuda._initialized: 32 | self.cuda_in_fwd = True 33 | self.gpu_devices, self.gpu_states = get_device_states(*args) 34 | 35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 36 | if record_rng: 37 | self.record_rng(*args) 38 | 39 | if not set_rng: 40 | return self.net(*args, **kwargs) 41 | 42 | rng_devices = [] 43 | if self.cuda_in_fwd: 44 | rng_devices = self.gpu_devices 45 | 46 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 47 | torch.set_rng_state(self.cpu_state) 48 | if self.cuda_in_fwd: 49 | set_device_states(self.gpu_devices, self.gpu_states) 50 | return self.net(*args, **kwargs) 51 | 52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 53 | # once multi-GPU is confirmed working, refactor and send PR back to source 54 | class ReversibleBlock(nn.Module): 55 | def __init__(self, f, g): 56 | super().__init__() 57 | self.f = Deterministic(f) 58 | self.g = Deterministic(g) 59 | 60 | def forward(self, x, f_args = {}, g_args = {}): 61 | x1, x2 = torch.chunk(x, 2, dim=2) 62 | y1, y2 = None, None 63 | 64 | with torch.no_grad(): 65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 67 | 68 | return torch.cat([y1, y2], dim=2) 69 | 70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 71 | y1, y2 = torch.chunk(y, 2, dim=2) 72 | del y 73 | 74 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 75 | del dy 76 | 77 | with torch.enable_grad(): 78 | y1.requires_grad = True 79 | gy1 = self.g(y1, set_rng=True, **g_args) 80 | torch.autograd.backward(gy1, dy2) 81 | 82 | with torch.no_grad(): 83 | x2 = y2 - gy1 84 | del y2, gy1 85 | 86 | dx1 = dy1 + y1.grad 87 | del dy1 88 | y1.grad = None 89 | 90 | with torch.enable_grad(): 91 | x2.requires_grad = True 92 | fx2 = self.f(x2, set_rng=True, **f_args) 93 | torch.autograd.backward(fx2, dx1, retain_graph=True) 94 | 95 | with torch.no_grad(): 96 | x1 = y1 - fx2 97 | del y1, fx2 98 | 99 | dx2 = dy2 + x2.grad 100 | del dy2 101 | x2.grad = None 102 | 103 | x = torch.cat([x1, x2.detach()], dim=2) 104 | dx = torch.cat([dx1, dx2], dim=2) 105 | 106 | return x, dx 107 | 108 | class _ReversibleFunction(Function): 109 | @staticmethod 110 | def forward(ctx, x, blocks, args): 111 | ctx.args = args 112 | for block, kwarg in zip(blocks, args): 113 | x = block(x, **kwarg) 114 | ctx.y = x.detach() 115 | ctx.blocks = blocks 116 | return x 117 | 118 | @staticmethod 119 | def backward(ctx, dy): 120 | y = ctx.y 121 | args = ctx.args 122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 123 | y, dy = block.backward_pass(y, dy, **kwargs) 124 | return dy, None, None 125 | 126 | class ReversibleSequence(nn.Module): 127 | def __init__(self, blocks, args_route = {}): 128 | super().__init__() 129 | self.args_route = args_route 130 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 131 | 132 | def forward(self, x, **kwargs): 133 | x = torch.cat([x, x], dim=-1) 134 | 135 | blocks = self.blocks 136 | args = route_args(self.args_route, kwargs, len(blocks)) 137 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, 
args)) 138 | 139 | layers_and_args = list(zip(blocks, args)) 140 | 141 | out = _ReversibleFunction.apply(x, blocks, args) 142 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0) 143 | -------------------------------------------------------------------------------- /memory_efficient_attention_pytorch/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from einops import rearrange 7 | from memory_efficient_attention_pytorch import FlashAttention, Attention 8 | from memory_efficient_attention_pytorch.reversible import ReversibleSequence 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | class PreNorm(nn.Module): 14 | def __init__(self, dim, fn): 15 | super().__init__() 16 | self.fn = fn 17 | self.norm = nn.LayerNorm(dim) 18 | 19 | def forward(self, x, **kwargs): 20 | x = self.norm(x) 21 | return self.fn(x, **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, mult = 4, chunks = 1): 25 | super().__init__() 26 | self.chunks = chunks 27 | 28 | self.net = nn.Sequential( 29 | nn.Linear(dim, dim * mult), 30 | nn.GELU(), 31 | nn.Linear(dim * mult, dim) 32 | ) 33 | 34 | def forward(self, x): 35 | if self.chunks <= 1: 36 | return self.net(x) 37 | 38 | chunks = x.chunk(self.chunks, dim = 1) 39 | out = [self.net(chunk) for chunk in chunks] 40 | return torch.cat(out, dim = 1) 41 | 42 | class Transformer(nn.Module): 43 | def __init__( 44 | self, 45 | *, 46 | num_tokens, 47 | max_seq_len, 48 | dim, 49 | depth, 50 | causal = False, 51 | dim_head = 64, 52 | heads = 8, 53 | ff_mult = 4, 54 | ff_chunks = 1, 55 | use_flash_attn = True, 56 | **kwargs 57 | ): 58 | super().__init__() 59 | self.max_seq_len = max_seq_len 60 | 61 | self.token_emb = nn.Embedding(num_tokens, dim) 62 | self.pos_emb = nn.Embedding(max_seq_len, dim) 63 | 64 | attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True) 65 | 66 | self.layers = nn.ModuleList([]) 67 | for _ in range(depth): 68 | self.layers.append(nn.ModuleList([ 69 | PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)), 70 | PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)), 71 | ])) 72 | 73 | self.net = ReversibleSequence(self.layers) 74 | 75 | self.to_logits = nn.Sequential( 76 | nn.LayerNorm(dim), 77 | nn.Linear(dim, num_tokens) 78 | ) 79 | 80 | def forward(self, x, labels = None): 81 | device = x.device 82 | x = self.token_emb(x) 83 | 84 | pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device)) 85 | x = x + pos_emb 86 | 87 | x = self.net(x) 88 | 89 | logits = self.to_logits(x) 90 | 91 | if not exists(labels): 92 | return logits 93 | 94 | return F.cross_entropy(rearrange(logits, 'b n d -> b d n'), labels) 95 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose 6 | python_files = tests/*.py 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'memory-efficient-attention-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.1.6', 7 | license='MIT', 
8 | description = 'Memory Efficient Attention - Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'attention-mechanism' 17 | ], 18 | install_requires=[ 19 | 'einops>=0.4.1', 20 | 'torch>=1.6' 21 | ], 22 | setup_requires=[ 23 | 'pytest-runner', 24 | ], 25 | tests_require=[ 26 | 'pytest' 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.8', 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from memory_efficient_attention_pytorch import Attention 3 | 4 | from memory_efficient_attention_pytorch.memory_efficient_attention import attention 5 | from memory_efficient_attention_pytorch.flash_attention import FlashAttention, FlashAttentionFunction 6 | 7 | # constants 8 | 9 | def isclose(a, b, atol = 1e-6): 10 | diff = (a - b).abs().amax() 11 | return diff < atol 12 | 13 | # test outputs are equal 14 | 15 | def test_output_equal(): 16 | attn = Attention( 17 | dim = 512, 18 | dim_head = 64, 19 | heads = 8, 20 | q_bucket_size = 64, 21 | k_bucket_size = 64, 22 | causal = True 23 | ) 24 | 25 | x = torch.randn(2, 2048, 512) 26 | mask = torch.ones(2, 2048).bool() 27 | 28 | out = attn(x, mask = mask) 29 | mem_efficient_out = attn(x, mask = mask, memory_efficient = True) 30 | 31 | assert isclose(mem_efficient_out, out, atol = 1e-6) 32 | 33 | # test gradients equal 34 | 35 | def test_gradients_equal(): 36 | attn = Attention( 37 | dim = 512, 38 | dim_head = 64, 39 | heads = 8, 40 | q_bucket_size = 64, 41 | k_bucket_size = 64, 42 | causal = True 43 | ) 44 | 45 | def loss_fn(inp, **kwargs): 46 | return attn(inp, **kwargs).sum() 47 | 48 | x = torch.randn(2, 2048, 512).requires_grad_() 49 | mask = torch.ones(2, 2048).bool() 50 | 51 | loss_fn(x, mask = mask).backward() 52 | out_grad = x.grad.clone() 53 | 54 | x.grad.zero_() 55 | loss_fn(x, mask = mask, memory_efficient = True).backward() 56 | mem_efficient_out_grad = x.grad.clone() 57 | 58 | assert isclose(out_grad, mem_efficient_out_grad, atol = 1e-5) 59 | 60 | # test flash attention 61 | 62 | def test_flash_attn_output_equal(): 63 | attn_kwargs = dict( 64 | dim = 512, 65 | dim_head = 64, 66 | heads = 8, 67 | q_bucket_size = 64, 68 | k_bucket_size = 64, 69 | causal = True 70 | ) 71 | 72 | attn = Attention(**attn_kwargs) 73 | flash_attn = FlashAttention(**attn_kwargs) 74 | 75 | flash_attn.to_q = attn.to_q 76 | flash_attn.to_kv = attn.to_kv 77 | flash_attn.to_out = attn.to_out 78 | 79 | x = torch.randn(2, 2048, 512) 80 | mask = torch.ones(2, 2048).bool() 81 | 82 | out = attn(x, mask = mask) 83 | mem_efficient_out = flash_attn(x, mask = mask) 84 | 85 | assert isclose(mem_efficient_out, out, atol = 1e-6) 86 | 87 | # test gradients equal 88 | 89 | def test_flash_attn_gradients_equal(): 90 | q = torch.randn(1, 8, 1024, 512).requires_grad_() 91 | k = torch.randn(1, 8, 1024, 512).requires_grad_() 92 | v = torch.randn(1, 8, 1024, 512).requires_grad_() 93 | 94 | mask = torch.ones(1, 1024).bool() 95 | 96 | o = attention(q, k, v, mask = mask, causal = True) 97 | 
o.sum().backward() 98 | 99 | dq_grad = q.grad.clone() 100 | dk_grad = k.grad.clone() 101 | dv_grad = v.grad.clone() 102 | 103 | q.grad.zero_() 104 | k.grad.zero_() 105 | v.grad.zero_() 106 | 107 | flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64) 108 | flash_o.sum().backward() 109 | 110 | flash_dq_grad = q.grad.clone() 111 | flash_dk_grad = k.grad.clone() 112 | flash_dv_grad = v.grad.clone() 113 | 114 | assert isclose(flash_dq_grad, dq_grad, atol = 1e-5) 115 | assert isclose(flash_dk_grad, dk_grad, atol = 1e-5) 116 | assert isclose(flash_dv_grad, dv_grad, atol = 1e-5) 117 | 118 | # test flash attention - full attention mask 119 | 120 | def test_flash_attn_full_attn_mask_output_equal(): 121 | attn_kwargs = dict( 122 | dim = 512, 123 | dim_head = 64, 124 | heads = 8, 125 | q_bucket_size = 64, 126 | k_bucket_size = 64, 127 | causal = True 128 | ) 129 | 130 | attn = Attention(**attn_kwargs) 131 | flash_attn = FlashAttention(**attn_kwargs) 132 | 133 | flash_attn.to_q = attn.to_q 134 | flash_attn.to_kv = attn.to_kv 135 | flash_attn.to_out = attn.to_out 136 | 137 | x = torch.randn(2, 2048, 512) 138 | mask = torch.ones(2, 1, 2048, 2048).bool() 139 | 140 | out = attn(x, mask = mask) 141 | mem_efficient_out = flash_attn(x, mask = mask) 142 | 143 | assert isclose(mem_efficient_out, out, atol = 1e-6) 144 | 145 | # test gradients equal - full attention mask 146 | 147 | def test_flash_attn_full_attn_mask_gradients_equal(): 148 | q = torch.randn(1, 8, 1024, 512).requires_grad_() 149 | k = torch.randn(1, 8, 1024, 512).requires_grad_() 150 | v = torch.randn(1, 8, 1024, 512).requires_grad_() 151 | 152 | mask = torch.ones(1, 1, 1024, 1024).bool() 153 | 154 | o = attention(q, k, v, mask = mask, causal = True) 155 | o.sum().backward() 156 | 157 | dq_grad = q.grad.clone() 158 | dk_grad = k.grad.clone() 159 | dv_grad = v.grad.clone() 160 | 161 | q.grad.zero_() 162 | k.grad.zero_() 163 | v.grad.zero_() 164 | 165 | flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64) 166 | flash_o.sum().backward() 167 | 168 | flash_dq_grad = q.grad.clone() 169 | flash_dk_grad = k.grad.clone() 170 | flash_dv_grad = v.grad.clone() 171 | 172 | assert isclose(flash_dq_grad, dq_grad, atol = 1e-5) 173 | assert isclose(flash_dk_grad, dk_grad, atol = 1e-5) 174 | assert isclose(flash_dv_grad, dv_grad, atol = 1e-5) 175 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from memory_efficient_attention_pytorch.transformer import Transformer 2 | from memory_efficient_attention_pytorch.autoregressive_wrapper import AutoregressiveWrapper 3 | 4 | import random 5 | import tqdm 6 | import gzip 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader, Dataset 12 | 13 | # constants 14 | 15 | NUM_BATCHES = int(1e5) 16 | BATCH_SIZE = 4 17 | GRADIENT_ACCUMULATE_EVERY = 4 18 | LEARNING_RATE = 2e-4 19 | VALIDATE_EVERY = 100 20 | GENERATE_EVERY = 500 21 | GENERATE_LENGTH = 1024 22 | SEQ_LEN = 4096 23 | 24 | # helpers 25 | 26 | def cycle(loader): 27 | while True: 28 | for data in loader: 29 | yield data 30 | 31 | def decode_token(token): 32 | return str(chr(max(32, token))) 33 | 34 | def decode_tokens(tokens): 35 | return ''.join(list(map(decode_token, tokens))) 36 | 37 | # instantiate GPT-like decoder model 38 | 39 | model = Transformer( 40 | num_tokens = 256, 41 | dim = 512, 42 | max_seq_len = 
SEQ_LEN, 43 | depth = 6, 44 | heads = 8, 45 | causal = True, 46 | q_bucket_size = 256, 47 | k_bucket_size = 256, 48 | ff_chunks = 5, 49 | use_flash_attn = True 50 | ) 51 | 52 | model = AutoregressiveWrapper(model) 53 | model.cuda() 54 | 55 | # prepare enwik8 data 56 | 57 | with gzip.open('./data/enwik8.gz') as file: 58 | X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 59 | trX, vaX = np.split(X, [int(90e6)]) 60 | data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX) 61 | 62 | class TextSamplerDataset(Dataset): 63 | def __init__(self, data, seq_len): 64 | super().__init__() 65 | self.data = data 66 | self.seq_len = seq_len 67 | 68 | def __getitem__(self, index): 69 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 70 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long() 71 | return full_seq.cuda() 72 | 73 | def __len__(self): 74 | return self.data.size(0) // self.seq_len 75 | 76 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 77 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 78 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE)) 79 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE)) 80 | 81 | # optimizer 82 | 83 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 84 | 85 | # training 86 | 87 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'): 88 | model.train() 89 | 90 | for __ in range(GRADIENT_ACCUMULATE_EVERY): 91 | loss = model(next(train_loader)) 92 | loss.backward() 93 | 94 | print(f'training loss: {loss.item()}') 95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 96 | optim.step() 97 | optim.zero_grad() 98 | 99 | if i % VALIDATE_EVERY == 0: 100 | model.eval() 101 | with torch.no_grad(): 102 | loss = model(next(val_loader)) 103 | print(f'validation loss: {loss.item()}') 104 | 105 | if i != 0 and i % GENERATE_EVERY == 0: 106 | model.eval() 107 | inp = random.choice(val_dataset)[:-1] 108 | prime = decode_tokens(inp) 109 | print('%s \n\n %s' % (prime, '*' * 100)) 110 | 111 | sample = model.generate(inp[None, ...], GENERATE_LENGTH) 112 | output_str = decode_tokens(sample[0]) 113 | print(output_str) 114 | --------------------------------------------------------------------------------
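The following is a standalone sanity-check sketch, not part of the repository: per the update note in the README, PyTorch 2.x ships built-in flash attention through `F.scaled_dot_product_attention`, and the educational `FlashAttentionFunction` in `memory_efficient_attention_pytorch/flash_attention.py` should agree with it closely for plain causal attention. The comparison below assumes `torch >= 2.0` is installed.

```python
import torch
import torch.nn.functional as F

from memory_efficient_attention_pytorch.flash_attention import FlashAttentionFunction

# (batch, heads, seq, dim_head) in float32
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# causal attention, no padding mask, query / key bucket sizes of 256
out_tiled = FlashAttentionFunction.apply(q, k, v, None, True, 256, 256)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal = True)

# the two should differ only by floating point error (roughly 1e-6 in float32)
print((out_tiled - out_builtin).abs().amax())
```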