├── .github
│   └── workflows
│       ├── python-package.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   └── enwik8.gz
├── memory_efficient_attention_pytorch
│   ├── __init__.py
│   ├── autoregressive_wrapper.py
│   ├── cosine_sim_flash_attention.py
│   ├── flash_attention.py
│   ├── memory_efficient_attention.py
│   ├── memory_efficient_cosine_sim_attention.py
│   ├── reversible.py
│   └── transformer.py
├── setup.cfg
├── setup.py
├── tests
│   └── test.py
└── train.py
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: [3.8, 3.9]
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install pytest
30 | python -m pip install torch==1.10.0
31 | - name: Test with pytest
32 | run: |
33 | python setup.py test
34 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Memory Efficient Attention Pytorch (obsolete)
2 |
3 | Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module will take care of masking, causal masking, as well as cross attention.
4 |
5 | This repository also contains a naive, non-CUDA implementation of the improvements introduced by Tri Dao in his Flash Attention 2 paper, for educational purposes; a short usage example is included below. It is a game changer for attention and for building long-context transformers.
6 |
7 | Update: from now on, you should just use the `F.scaled_dot_product_attention` function built into Pytorch 2.0 for Flash Attention v1 support, or use Flash Attention v2 from the official repository.
8 |
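A minimal sketch of that built-in path (assuming Pytorch >= 2.0; the tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# query / key / value in (batch, heads, seq_len, dim_head) layout
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# dispatches to a fused (flash) kernel when one is available for the given inputs
out = F.scaled_dot_product_attention(q, k, v, is_causal = True) # (1, 8, 1024, 64)
```
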
9 | ## Install
10 |
11 | ```bash
12 | $ pip install memory-efficient-attention-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | For autoregressive language model
18 |
19 | ```python
20 | import torch
21 | from memory_efficient_attention_pytorch import Attention
22 |
23 | attn = Attention(
24 | dim = 512,
25 | dim_head = 64, # dimension per head
26 | heads = 8, # number of attention heads
27 | causal = True, # autoregressive or not
28 | memory_efficient = True, # whether to use memory efficient attention (can be turned off to test against normal attention)
29 | q_bucket_size = 1024, # bucket size along queries dimension
30 | k_bucket_size = 2048 # bucket size along key / values dimension
31 | ).cuda()
32 |
33 | x = torch.randn(1, 65536, 512).cuda()
34 | out = attn(x) # (1, 65536, 512)
35 | ```
36 |
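The chunked attention can also be called as a plain function on query / key / value tensors already split into heads, through the exported `memory_efficient_attention` function (a minimal sketch; the bucket sizes shown are simply the defaults):

```python
import torch
from memory_efficient_attention_pytorch import memory_efficient_attention

# (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 2048, 64)
k = torch.randn(1, 8, 2048, 64)
v = torch.randn(1, 8, 2048, 64)

mask = torch.ones(1, 2048).bool()   # key padding mask, (batch, seq_len)

out = memory_efficient_attention(
    q, k, v,
    mask = mask,
    causal = True,
    q_bucket_size = 512,
    k_bucket_size = 1024
) # (1, 8, 2048, 64)
```
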
37 | Cross attention
38 |
39 | ```python
40 | import torch
41 | from memory_efficient_attention_pytorch import Attention
42 |
43 | cross_attn = Attention(
44 | dim = 512,
45 | dim_head = 64,
46 | heads = 8,
47 | memory_efficient = True,
48 | q_bucket_size = 1024,
49 | k_bucket_size = 2048
50 | ).cuda()
51 |
52 | x = torch.randn(1, 65536, 512).cuda()
53 | context = torch.randn(1, 65536, 512).cuda()
54 | mask = torch.ones(1, 65536).bool().cuda()
55 |
56 | out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
57 | ```
58 |
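Naive Flash Attention (non-CUDA, for educational purposes) mirrors the constructor arguments of `Attention` above (a minimal sketch; sequence length and bucket sizes are illustrative):

```python
import torch
from memory_efficient_attention_pytorch import FlashAttention

attn = FlashAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    q_bucket_size = 1024,    # bucket size along queries dimension
    k_bucket_size = 2048     # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 8192, 512).cuda()
out = attn(x) # (1, 8192, 512)
```
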
59 | ## Citations
60 |
61 | ```bibtex
62 | @misc{rabe2021selfattention,
63 | title = {Self-attention Does Not Need $O(n^2)$ Memory},
64 | author = {Markus N. Rabe and Charles Staats},
65 | year = {2021},
66 | eprint = {2112.05682},
67 | archivePrefix = {arXiv},
68 | primaryClass = {cs.LG}
69 | }
70 | ```
71 |
72 | ```bibtex
73 | @misc{liu2021swin,
74 | title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
75 | author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
76 | year = {2021},
77 | eprint = {2111.09883},
78 | archivePrefix = {arXiv},
79 | primaryClass = {cs.CV}
80 | }
81 | ```
82 |
83 | ```bibtex
84 | @article{Dao2022FlashAttentionFA,
85 | title = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
86 | author = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher Ré},
87 | journal = {ArXiv},
88 | year = {2022},
89 | volume = {abs/2205.14135}
90 | }
91 | ```
92 |
93 | ```bibtex
94 | @article{dao2023flashattention2,
95 | title = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
96 | author = {Dao, Tri},
97 | year = {2023}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/memory-efficient-attention-pytorch/d54f391370ecbf843a871f0e260425d076995550/data/enwik8.gz
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
2 | from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
3 | from memory_efficient_attention_pytorch.flash_attention import FlashAttention
4 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/autoregressive_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | # helper function
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 | def eval_decorator(fn):
11 | def inner(model, *args, **kwargs):
12 | was_training = model.training
13 | model.eval()
14 | out = fn(model, *args, **kwargs)
15 | model.train(was_training)
16 | return out
17 | return inner
18 |
19 | # top k filtering
20 |
21 | def top_k(logits, thres = 0.9):
22 | k = int((1 - thres) * logits.shape[-1])
23 | val, ind = torch.topk(logits, k)
24 | probs = torch.full_like(logits, float('-inf'))
25 | probs.scatter_(1, ind, val)
26 | return probs
27 |
28 | class AutoregressiveWrapper(nn.Module):
29 | def __init__(self, net, pad_value = 0):
30 | super().__init__()
31 | self.pad_value = pad_value
32 | self.net = net
33 | self.max_seq_len = net.max_seq_len
34 |
35 | @torch.no_grad()
36 | @eval_decorator
37 | def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_thres = 0.9, **kwargs):
38 | b, t, device = *start_tokens.shape, start_tokens.device
39 |
40 | out = start_tokens
41 |
42 | for _ in range(seq_len):
43 | x = out[:, -self.max_seq_len:]
44 |
45 | logits = self.net(x, **kwargs)[:, -1, :]
46 |
47 | filtered_logits = top_k(logits, thres = filter_thres)
48 | probs = F.softmax(filtered_logits / temperature, dim=-1)
49 |
50 | sample = torch.multinomial(probs, 1)
51 |
52 | out = torch.cat((out, sample), dim=-1)
53 |
54 | if exists(eos_token):
55 | is_eos_token = (out == eos_token)
56 |
57 | if is_eos_token.any(dim = -1).all():
58 | # mask out everything after the eos tokens
59 | shifted_is_eos_tokens = F.pad(is_eos_token, (1, -1))
60 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
61 | out = out.masked_fill(mask, self.pad_value)
62 | break
63 |
64 | out = out[:, t:]
65 | return out
66 |
67 | def forward(self, x, **kwargs):
68 | x_inp, x_labels = x[:, :-1], x[:, 1:]
69 | return self.net(x_inp, labels = x_labels, **kwargs)
70 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/cosine_sim_flash_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from functools import partial
4 | from torch import nn, einsum
5 | import torch.nn.functional as F
6 | from torch.autograd.function import Function
7 |
8 | from einops import rearrange
9 |
10 | # constants
11 |
12 | EPSILON = 1e-6
13 |
14 | # helper functions
15 |
16 | def exists(val):
17 | return val is not None
18 |
19 | def default(val, d):
20 | return val if exists(val) else d
21 |
22 | def l2norm(t):
23 | return F.normalize(t, dim = -1)
24 |
25 | # flash attention forwards and backwards
26 |
27 | class FlashAttentionFunction(Function):
28 | @staticmethod
29 | @torch.no_grad()
30 | def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
31 | device = q.device
32 | max_neg_value = -torch.finfo(q.dtype).max
33 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
34 |
35 | k_len = k.shape[-2] # in cosine sim attention, row sums are bounded by key / values sequence length
36 |
37 | o = torch.zeros_like(q)
38 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
39 |
40 | if not exists(mask):
41 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
42 | else:
43 | mask = rearrange(mask, 'b n -> b 1 1 n')
44 | mask = mask.split(q_bucket_size, dim = -1)
45 |
46 | row_splits = zip(
47 | q.split(q_bucket_size, dim = -2),
48 | o.split(q_bucket_size, dim = -2),
49 | mask,
50 | all_row_sums.split(q_bucket_size, dim = -2),
51 | )
52 |
53 | for ind, (qc, oc, row_mask, row_sums) in enumerate(row_splits):
54 | q_start_index = ind * q_bucket_size - qk_len_diff
55 |
56 | col_splits = zip(
57 | k.split(k_bucket_size, dim = -2),
58 | v.split(k_bucket_size, dim = -2),
59 | )
60 |
61 | for k_ind, (kc, vc) in enumerate(col_splits):
62 | k_start_index = k_ind * k_bucket_size
63 |
64 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
65 |
66 | if exists(row_mask):
67 | attn_weights.masked_fill_(~row_mask, max_neg_value)
68 |
69 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
70 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
71 | attn_weights.masked_fill_(causal_mask, max_neg_value)
72 |
73 | attn_weights -= scale
74 | exp_weights = torch.exp(attn_weights)
75 |
76 | if exists(row_mask):
77 | exp_weights.masked_fill_(~row_mask, 0.)
78 |
79 | block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
80 |
81 | exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
82 |
83 | oc.add_(exp_values / k_len)
84 | row_sums.add_(block_row_sums)
85 |
86 | ctx.args = (scale, causal, mask, q_bucket_size, k_bucket_size)
87 | ctx.save_for_backward(q, k, v, o, all_row_sums)
88 |
89 | o.mul_(k_len / all_row_sums)
90 |
91 | return o
92 |
93 | @staticmethod
94 | @torch.no_grad()
95 | def backward(ctx, do):
96 | scale, causal, mask, q_bucket_size, k_bucket_size = ctx.args
97 | q, k, v, o, l = ctx.saved_tensors
98 |
99 | device = q.device
100 |
101 | max_neg_value = -torch.finfo(q.dtype).max
102 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
103 |
104 | dq = torch.zeros_like(q)
105 | dk = torch.zeros_like(k)
106 | dv = torch.zeros_like(v)
107 |
108 | row_splits = zip(
109 | q.split(q_bucket_size, dim = -2),
110 | o.split(q_bucket_size, dim = -2),
111 | do.split(q_bucket_size, dim = -2),
112 | mask,
113 | l.split(q_bucket_size, dim = -2),
114 | dq.split(q_bucket_size, dim = -2)
115 | )
116 |
117 | for ind, (qc, oc, doc, row_mask, lc, dqc) in enumerate(row_splits):
118 | q_start_index = ind * q_bucket_size - qk_len_diff
119 |
120 | col_splits = zip(
121 | k.split(k_bucket_size, dim = -2),
122 | v.split(k_bucket_size, dim = -2),
123 | dk.split(k_bucket_size, dim = -2),
124 | dv.split(k_bucket_size, dim = -2),
125 | )
126 |
127 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
128 | k_start_index = k_ind * k_bucket_size
129 |
130 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
131 |
132 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
133 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
134 | attn_weights.masked_fill_(causal_mask, max_neg_value)
135 |
136 | exp_attn_weights = torch.exp(attn_weights - scale)
137 |
138 | if exists(row_mask):
139 | exp_attn_weights.masked_fill_(~row_mask, 0.)
140 |
141 | p = exp_attn_weights / lc
142 |
143 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
144 | dp = einsum('... i d, ... j d -> ... i j', doc, vc)
145 |
146 | D = (doc * oc).sum(dim = -1, keepdims = True)
147 | ds = p * scale * (dp - D)
148 |
149 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
150 | dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
151 |
152 | dqc.add_(dq_chunk)
153 | dkc.add_(dk_chunk)
154 | dvc.add_(dv_chunk)
155 |
156 | return dq, dk, dv, None, None, None, None, None
157 |
158 | # main class
159 |
160 | # flash attention for cosine sim attention
161 | # a bit less complicated, as no more need to worry about softmax numerical stability, and row sums are bounded
162 |
163 | class FlashAttention(nn.Module):
164 | def __init__(
165 | self,
166 | *,
167 | dim,
168 | scale = 16,
169 | heads = 8,
170 | dim_head = 64,
171 | causal = False,
172 | q_bucket_size = 512,
173 | k_bucket_size = 1024
174 | ):
175 | super().__init__()
176 | self.heads = heads
177 |
178 | self.scale = scale
179 | self.causal = causal
180 |
181 | inner_dim = heads * dim_head
182 |
183 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
184 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
185 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
186 |
187 | # memory efficient attention related parameters
188 | # can be overriden on forward
189 | self.q_bucket_size = q_bucket_size
190 | self.k_bucket_size = k_bucket_size
191 |
192 | def forward(
193 | self,
194 | x,
195 | context = None,
196 | mask = None,
197 | q_bucket_size = None,
198 | k_bucket_size = None,
199 | ):
200 | q_bucket_size = default(q_bucket_size, self.q_bucket_size)
201 | k_bucket_size = default(k_bucket_size, self.k_bucket_size)
202 |
203 | h = self.heads
204 | context = default(context, x)
205 |
206 | q = self.to_q(x)
207 | k, v = self.to_kv(context).chunk(2, dim = -1)
208 |
209 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
210 |
211 | q, k = map(l2norm, (q, k))
212 |
213 | out = FlashAttentionFunction.apply(q, k, v, mask, self.scale, self.causal, q_bucket_size, k_bucket_size)
214 |
215 | out = rearrange(out, 'b h n d -> b n (h d)')
216 | return self.to_out(out)
217 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/flash_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from functools import partial
4 | from torch import nn, einsum
5 | from torch.autograd.function import Function
6 |
7 | from einops import rearrange
8 |
9 | # constants
10 |
11 | EPSILON = 1e-10
12 |
13 | # helper functions
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def default(val, d):
19 | return val if exists(val) else d
20 |
21 | # flash attention forwards and backwards
22 |
23 | # flash attention v1 - https://arxiv.org/abs/2205.14135
24 | # flash attention v2 - https://tridao.me/publications/flash2/flash2.pdf
25 |
26 | class FlashAttentionFunction(Function):
27 | @staticmethod
28 | @torch.no_grad()
29 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
30 | """ Algorithm 1 in the v2 paper """
31 |
32 | device = q.device
33 | max_neg_value = -torch.finfo(q.dtype).max
34 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
35 |
36 | o = torch.zeros_like(q)
37 | all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
38 | all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)
39 |
40 | scale = (q.shape[-1] ** -0.5)
41 |
42 | num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
43 | num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)
44 |
45 | if exists(mask) and mask.ndim == 2:
46 | mask = rearrange(mask, 'b n -> b 1 1 n')
47 |
48 | if not exists(mask):
49 | col_masks = (None,) * num_col_tiles
50 | mask = (col_masks,) * num_row_tiles
51 | else:
52 | mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
53 | mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)
54 |
55 | row_splits = zip(
56 | q.split(q_bucket_size, dim = -2),
57 | o.split(q_bucket_size, dim = -2),
58 | mask,
59 | all_row_sums.split(q_bucket_size, dim = -2),
60 | all_row_maxes.split(q_bucket_size, dim = -2),
61 | )
62 |
63 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
64 | q_start_index = ind * q_bucket_size - qk_len_diff
65 |
66 | col_splits = zip(
67 | k.split(k_bucket_size, dim = -2),
68 | v.split(k_bucket_size, dim = -2),
69 | row_mask
70 | )
71 |
72 | for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
73 | k_start_index = k_ind * k_bucket_size
74 |
75 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
76 |
77 | if exists(col_mask):
78 | attn_weights.masked_fill_(~col_mask, max_neg_value)
79 |
80 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
81 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
82 | attn_weights.masked_fill_(causal_mask, max_neg_value)
83 |
84 | block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
85 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
86 |
87 | exp_weights = torch.exp(attn_weights - new_row_maxes)
88 |
89 | if exists(col_mask):
90 | exp_weights.masked_fill_(~col_mask, 0.)
91 |
92 | block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
93 |
94 | exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
95 |
96 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
97 |
98 | new_row_sums = exp_row_max_diff * row_sums + block_row_sums
99 |
100 | oc.mul_(exp_row_max_diff).add_(exp_values)
101 |
102 | row_maxes.copy_(new_row_maxes)
103 | row_sums.copy_(new_row_sums)
104 |
105 | oc.div_(row_sums)
106 |
107 | lse = all_row_sums.log() + all_row_maxes
108 |
109 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
110 | ctx.save_for_backward(q, k, v, o, lse)
111 |
112 | return o
113 |
114 | @staticmethod
115 | @torch.no_grad()
116 | def backward(ctx, do):
117 | """ Algorithm 2 in the v2 paper """
118 |
119 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
120 | q, k, v, o, lse = ctx.saved_tensors
121 |
122 | device = q.device
123 |
124 | max_neg_value = -torch.finfo(q.dtype).max
125 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
126 |
127 | dq = torch.zeros_like(q)
128 | dk = torch.zeros_like(k)
129 | dv = torch.zeros_like(v)
130 |
131 | row_splits = zip(
132 | q.split(q_bucket_size, dim = -2),
133 | o.split(q_bucket_size, dim = -2),
134 | do.split(q_bucket_size, dim = -2),
135 | mask,
136 | lse.split(q_bucket_size, dim = -2),
137 | dq.split(q_bucket_size, dim = -2)
138 | )
139 |
140 | for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
141 | q_start_index = ind * q_bucket_size - qk_len_diff
142 |
143 | col_splits = zip(
144 | k.split(k_bucket_size, dim = -2),
145 | v.split(k_bucket_size, dim = -2),
146 | dk.split(k_bucket_size, dim = -2),
147 | dv.split(k_bucket_size, dim = -2),
148 | row_mask
149 | )
150 |
151 | for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
152 | k_start_index = k_ind * k_bucket_size
153 |
154 | attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
155 |
156 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
157 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
158 | attn_weights.masked_fill_(causal_mask, max_neg_value)
159 |
160 | p = torch.exp(attn_weights - lsec)
161 |
162 | if exists(col_mask):
163 | p.masked_fill_(~col_mask, 0.)
164 |
165 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
166 | dp = einsum('... i d, ... j d -> ... i j', doc, vc)
167 |
168 | D = (doc * oc).sum(dim = -1, keepdims = True)
169 | ds = p * scale * (dp - D)
170 |
171 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
172 | dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
173 |
174 | dqc.add_(dq_chunk)
175 | dkc.add_(dk_chunk)
176 | dvc.add_(dv_chunk)
177 |
178 | return dq, dk, dv, None, None, None, None
179 |
180 | # main class
181 |
182 | # just flash attention in plain pytorch
183 | # it will be way slower than implementing it in CUDA
184 | # for tinkering and educational purposes
185 |
186 | class FlashAttention(nn.Module):
187 | def __init__(
188 | self,
189 | *,
190 | dim,
191 | heads = 8,
192 | dim_head = 64,
193 | causal = False,
194 | q_bucket_size = 512,
195 | k_bucket_size = 1024
196 | ):
197 | super().__init__()
198 | self.heads = heads
199 | self.causal = causal
200 |
201 | inner_dim = heads * dim_head
202 |
203 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
204 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
205 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
206 |
207 | # memory efficient attention related parameters
208 | # can be overriden on forward
209 | self.q_bucket_size = q_bucket_size
210 | self.k_bucket_size = k_bucket_size
211 |
212 | def forward(
213 | self,
214 | x,
215 | context = None,
216 | mask = None,
217 | q_bucket_size = None,
218 | k_bucket_size = None,
219 | ):
220 | q_bucket_size = default(q_bucket_size, self.q_bucket_size)
221 | k_bucket_size = default(k_bucket_size, self.k_bucket_size)
222 |
223 | h = self.heads
224 | context = default(context, x)
225 |
226 | q = self.to_q(x)
227 | k, v = self.to_kv(context).chunk(2, dim = -1)
228 |
229 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
230 |
231 | out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)
232 |
233 | out = rearrange(out, 'b h n d -> b n (h d)')
234 | return self.to_out(out)
235 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/memory_efficient_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 | from torch import nn, einsum
4 | from torch.utils.checkpoint import checkpoint
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 | # helper functions
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 | def default(val, d):
15 | return val if exists(val) else d
16 |
17 | # regular attention
18 |
19 | def attention(
20 | q, k, v,
21 | mask = None,
22 | causal = False,
23 | attn_bias = None,
24 | **kwargs
25 | ):
26 | scale = q.shape[-1] ** -0.5
27 | q = q * scale
28 |
29 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
30 |
31 | if exists(attn_bias):
32 | sim = sim + attn_bias
33 |
34 | mask_value = -torch.finfo(sim.dtype).max
35 |
36 | if exists(mask):
37 | if mask.ndim == 2:
38 | mask = rearrange(mask, 'b j -> b 1 1 j')
39 | sim = sim.masked_fill(~mask, mask_value)
40 |
41 | if causal:
42 | i, j = sim.shape[-2:]
43 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
44 | sim = sim.masked_fill(mask, mask_value)
45 |
46 | sim = sim - sim.amax(dim = -1, keepdim = True).detach()
47 | attn = sim.softmax(dim = -1)
48 |
49 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
50 | return out
51 |
52 | # memory efficient attention
53 |
54 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
55 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
56 |
57 | weight = einsum('b h i d, b h j d -> b h i j', q, k)
58 |
59 | if exists(attn_bias_chunk):
60 | weight = weight + attn_bias_chunk
61 |
62 | mask_value = -torch.finfo(weight.dtype).max
63 |
64 | if exists(mask):
65 | mask = rearrange(mask, 'b j -> b 1 1 j')
66 | weight = weight.masked_fill(~mask, mask_value)
67 |
68 | if causal and q_start_index < (k_start_index + k_chunk_size - 1):
69 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
70 | weight = weight.masked_fill(causal_mask, mask_value)
71 |
72 | weight_max = weight.amax(dim = -1, keepdim = True).detach()
73 | weight = weight - weight_max
74 |
75 | exp_weight = weight.exp()
76 |
77 | exp_weight = F.dropout(exp_weight, p = dropout)
78 |
79 | weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
80 |
81 | return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
82 |
83 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
84 |
85 | def memory_efficient_attention(
86 | q, k, v,
87 | mask = None,
88 | causal = False,
89 | attn_bias = None,
90 | q_bucket_size = 512,
91 | k_bucket_size = 1024,
92 | eps = 1e-8,
93 | dropout = 0.,
94 | training = False
95 | ):
96 | scale = q.shape[-1] ** -0.5
97 | q = q * scale
98 |
99 | # function
100 |
101 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
102 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
103 |
104 | # chunk all the inputs
105 |
106 | q_chunks = q.split(q_bucket_size, dim = -2)
107 | k_chunks = k.split(k_bucket_size, dim = -2)
108 | v_chunks = v.split(k_bucket_size, dim = -2)
109 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
110 |
111 | if exists(attn_bias):
112 | i, j = attn_bias.shape[-2:]
113 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
114 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))
115 |
116 | # loop through all chunks and accumulate
117 |
118 | out = []
119 | for q_index, q_chunk in enumerate(q_chunks):
120 | exp_weights = []
121 | weighted_values = []
122 | weight_maxes = []
123 |
124 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
125 | q_start_index = q_index * q_bucket_size
126 | k_start_index = k_index * k_bucket_size
127 |
128 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
129 | # if chunk is to be all masked out causally, skip
130 | continue
131 |
132 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
133 |
134 | exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
135 | q_chunk,
136 | k_chunk,
137 | v_chunk,
138 | mask_chunk,
139 | attn_bias_chunk,
140 | causal,
141 | (q_start_index, k_start_index),
142 | dropout if training else 0.
143 | )
144 |
145 | exp_weights.append(exp_weight_chunk)
146 | weighted_values.append(weighted_value_chunk)
147 | weight_maxes.append(weight_max_chunk)
148 |
149 | weight_maxes = torch.stack(weight_maxes, dim = -1)
150 |
151 | weighted_values = torch.stack(weighted_values, dim = -1)
152 | exp_weights = torch.stack(exp_weights, dim = -1)
153 |
154 | global_max = weight_maxes.amax(dim = -1, keepdim = True)
155 | renorm_factor = (weight_maxes - global_max).exp().detach()
156 |
157 | exp_weights = exp_weights * renorm_factor
158 | weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
159 |
160 | all_values = weighted_values.sum(dim = -1)
161 | all_weights = exp_weights.sum(dim = -1)
162 |
163 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
164 | out.append(normalized_values)
165 |
166 | return torch.cat(out, dim = -2)
167 |
168 | # main class
169 |
170 | class Attention(nn.Module):
171 | def __init__(
172 | self,
173 | *,
174 | dim,
175 | heads = 8,
176 | dim_head = 64,
177 | dropout = 0.,
178 | causal = False,
179 | memory_efficient = False,
180 | q_bucket_size = 512,
181 | k_bucket_size = 1024
182 | ):
183 | super().__init__()
184 | self.heads = heads
185 | self.causal = causal
186 | self.dropout = dropout
187 | inner_dim = heads * dim_head
188 |
189 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
190 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
191 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
192 |
193 | # memory efficient attention related parameters
194 | # can be overriden on forward
195 | self.memory_efficient = memory_efficient
196 | self.q_bucket_size = q_bucket_size
197 | self.k_bucket_size = k_bucket_size
198 |
199 | def forward(
200 | self,
201 | x,
202 | context = None,
203 | mask = None,
204 | attn_bias = None,
205 | memory_efficient = None,
206 | q_bucket_size = None,
207 | k_bucket_size = None,
208 | ):
209 | memory_efficient = default(memory_efficient, self.memory_efficient)
210 | q_bucket_size = default(q_bucket_size, self.q_bucket_size)
211 | k_bucket_size = default(k_bucket_size, self.k_bucket_size)
212 |
213 | h = self.heads
214 | context = default(context, x)
215 |
216 | q = self.to_q(x)
217 | k, v = self.to_kv(context).chunk(2, dim = -1)
218 |
219 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
220 |
221 | attn_fn = attention if not memory_efficient else memory_efficient_attention
222 |
223 | out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size,
224 | k_bucket_size = k_bucket_size, dropout = self.dropout, training = self.training)
225 |
226 | out = rearrange(out, 'b h n d -> b n (h d)')
227 | return self.to_out(out)
228 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/memory_efficient_cosine_sim_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from functools import partial
5 | from torch import nn, einsum
6 | from torch.utils.checkpoint import checkpoint
7 |
8 | from einops import rearrange
9 |
10 | # helper functions
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 | def default(val, d):
16 | return val if exists(val) else d
17 |
18 | def l2norm(t):
19 | return F.normalize(t, dim = -1)
20 |
21 | # regular attention
22 |
23 | def attention(
24 | q, k, v,
25 | mask = None,
26 | causal = False,
27 | attn_bias = None,
28 | **kwargs
29 | ):
30 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
31 |
32 | if exists(attn_bias):
33 | sim = sim + attn_bias
34 |
35 | mask_value = -torch.finfo(sim.dtype).max
36 |
37 | if exists(mask):
38 | mask = rearrange(mask, 'b j -> b 1 1 j')
39 | sim = sim.masked_fill(~mask, mask_value)
40 |
41 | if causal:
42 | i, j = sim.shape[-2:]
43 | mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
44 | sim = sim.masked_fill(mask, mask_value)
45 |
46 | attn = sim.softmax(dim = -1)
47 |
48 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
49 | return out
50 |
51 | # memory efficient attention
52 |
53 | def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices):
54 | q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device
55 |
56 | weight = einsum('b h i d, b h j d -> b h i j', q, k)
57 |
58 | if exists(attn_bias_chunk):
59 | weight = weight + attn_bias_chunk
60 |
61 | mask_value = -torch.finfo(weight.dtype).max
62 |
63 | if exists(mask):
64 | mask = rearrange(mask, 'b j -> b 1 1 j')
65 | weight = weight.masked_fill(~mask, mask_value)
66 |
67 | if causal and q_start_index < (k_start_index + k_chunk_size - 1):
68 | causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
69 | weight = weight.masked_fill(causal_mask, mask_value)
70 |
71 | exp_weight = weight.exp()
72 | weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
73 |
74 | return exp_weight.sum(dim = -1), weighted_value
75 |
76 | checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
77 |
78 | def numerically_unstable_memory_efficient_attention(
79 | q, k, v,
80 | mask = None,
81 | causal = False,
82 | attn_bias = None,
83 | q_bucket_size = 512,
84 | k_bucket_size = 1024,
85 | eps = 1e-8
86 | ):
87 | needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
88 | summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
89 |
90 | # chunk all the inputs
91 |
92 | q_chunks = q.split(q_bucket_size, dim = -2)
93 | k_chunks = k.split(k_bucket_size, dim = -2)
94 | v_chunks = v.split(k_bucket_size, dim = -2)
95 | mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
96 |
97 | if exists(attn_bias):
98 | i, j = attn_bias.shape[-2:]
99 | attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
100 | attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))
101 |
102 | # loop through all chunks and accumulate
103 |
104 | out = []
105 | for q_index, q_chunk in enumerate(q_chunks):
106 | q_start_index = q_index * q_bucket_size
107 | exp_weights = []
108 | weighted_values = []
109 |
110 | for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
111 | k_start_index = k_index * k_bucket_size
112 |
113 | if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
114 | # if chunk is to be all masked out causally, skip
115 | continue
116 |
117 | attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
118 |
119 | exp_weight_chunk, weighted_value_chunk = summarize_qkv_fn(
120 | q_chunk,
121 | k_chunk,
122 | v_chunk,
123 | mask_chunk,
124 | attn_bias_chunk,
125 | causal,
126 | (q_start_index, k_start_index)
127 | )
128 |
129 | exp_weights.append(exp_weight_chunk)
130 | weighted_values.append(weighted_value_chunk)
131 |
132 | all_values = sum(weighted_values)
133 | all_weights = sum(exp_weights)
134 |
135 | normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
136 | out.append(normalized_values)
137 |
138 | return torch.cat(out, dim = -2)
139 |
140 | # main class
141 |
142 | class CosineSimAttention(nn.Module):
143 | def __init__(
144 | self,
145 | *,
146 | dim,
147 | seq_len,
148 | heads = 8,
149 | dim_head = 64,
150 | dropout = 0.,
151 | causal = False,
152 | memory_efficient = False,
153 | q_bucket_size = 512,
154 | k_bucket_size = 1024
155 | ):
156 | super().__init__()
157 | self.heads = heads
158 | self.causal = causal
159 |
160 | inner_dim = heads * dim_head
161 |
162 | scale_init_value = -math.log(math.log2(seq_len ** 2 - seq_len))
163 | self.scale = nn.Parameter(torch.full((1, heads, 1, 1), scale_init_value))
164 |
165 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
166 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
167 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
168 |
169 | # memory efficient attention related parameters
170 | # can be overriden on forward
171 | self.memory_efficient = memory_efficient
172 | self.q_bucket_size = q_bucket_size
173 | self.k_bucket_size = k_bucket_size
174 |
175 | def forward(
176 | self,
177 | x,
178 | context = None,
179 | mask = None,
180 | attn_bias = None,
181 | memory_efficient = None,
182 | q_bucket_size = None,
183 | k_bucket_size = None,
184 | ):
185 | memory_efficient = default(memory_efficient, self.memory_efficient)
186 | q_bucket_size = default(q_bucket_size, self.q_bucket_size)
187 | k_bucket_size = default(k_bucket_size, self.k_bucket_size)
188 |
189 | h = self.heads
190 | context = default(context, x)
191 |
192 | q = self.to_q(x)
193 | k, v = self.to_kv(context).chunk(2, dim = -1)
194 |
195 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
196 |
197 | q, k = map(l2norm, (q, k))
198 |
199 | q = q * self.scale.exp()
200 |
201 | attn_fn = attention if not memory_efficient else numerically_unstable_memory_efficient_attention
202 |
203 | out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
204 |
205 | out = rearrange(out, 'b h n d -> b n (h d)')
206 | return self.to_out(out)
207 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/reversible.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from operator import itemgetter
4 | from torch.autograd.function import Function
5 | from torch.utils.checkpoint import get_device_states, set_device_states
6 |
7 | # for routing arguments into the functions of the reversible layer
8 | def route_args(router, args, depth):
9 | routed_args = [(dict(), dict()) for _ in range(depth)]
10 | matched_keys = [key for key in args.keys() if key in router]
11 |
12 | for key in matched_keys:
13 | val = args[key]
14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17 | return routed_args
18 |
19 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
20 | class Deterministic(nn.Module):
21 | def __init__(self, net):
22 | super().__init__()
23 | self.net = net
24 | self.cpu_state = None
25 | self.cuda_in_fwd = None
26 | self.gpu_devices = None
27 | self.gpu_states = None
28 |
29 | def record_rng(self, *args):
30 | self.cpu_state = torch.get_rng_state()
31 | if torch.cuda._initialized:
32 | self.cuda_in_fwd = True
33 | self.gpu_devices, self.gpu_states = get_device_states(*args)
34 |
35 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
36 | if record_rng:
37 | self.record_rng(*args)
38 |
39 | if not set_rng:
40 | return self.net(*args, **kwargs)
41 |
42 | rng_devices = []
43 | if self.cuda_in_fwd:
44 | rng_devices = self.gpu_devices
45 |
46 | with torch.random.fork_rng(devices=rng_devices, enabled=True):
47 | torch.set_rng_state(self.cpu_state)
48 | if self.cuda_in_fwd:
49 | set_device_states(self.gpu_devices, self.gpu_states)
50 | return self.net(*args, **kwargs)
51 |
52 | # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
53 | # once multi-GPU is confirmed working, refactor and send PR back to source
54 | class ReversibleBlock(nn.Module):
55 | def __init__(self, f, g):
56 | super().__init__()
57 | self.f = Deterministic(f)
58 | self.g = Deterministic(g)
59 |
60 | def forward(self, x, f_args = {}, g_args = {}):
61 | x1, x2 = torch.chunk(x, 2, dim=2)
62 | y1, y2 = None, None
63 |
64 | with torch.no_grad():
65 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
66 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
67 |
68 | return torch.cat([y1, y2], dim=2)
69 |
70 | def backward_pass(self, y, dy, f_args = {}, g_args = {}):
71 | y1, y2 = torch.chunk(y, 2, dim=2)
72 | del y
73 |
74 | dy1, dy2 = torch.chunk(dy, 2, dim=2)
75 | del dy
76 |
77 | with torch.enable_grad():
78 | y1.requires_grad = True
79 | gy1 = self.g(y1, set_rng=True, **g_args)
80 | torch.autograd.backward(gy1, dy2)
81 |
82 | with torch.no_grad():
83 | x2 = y2 - gy1
84 | del y2, gy1
85 |
86 | dx1 = dy1 + y1.grad
87 | del dy1
88 | y1.grad = None
89 |
90 | with torch.enable_grad():
91 | x2.requires_grad = True
92 | fx2 = self.f(x2, set_rng=True, **f_args)
93 | torch.autograd.backward(fx2, dx1, retain_graph=True)
94 |
95 | with torch.no_grad():
96 | x1 = y1 - fx2
97 | del y1, fx2
98 |
99 | dx2 = dy2 + x2.grad
100 | del dy2
101 | x2.grad = None
102 |
103 | x = torch.cat([x1, x2.detach()], dim=2)
104 | dx = torch.cat([dx1, dx2], dim=2)
105 |
106 | return x, dx
107 |
108 | class _ReversibleFunction(Function):
109 | @staticmethod
110 | def forward(ctx, x, blocks, args):
111 | ctx.args = args
112 | for block, kwarg in zip(blocks, args):
113 | x = block(x, **kwarg)
114 | ctx.y = x.detach()
115 | ctx.blocks = blocks
116 | return x
117 |
118 | @staticmethod
119 | def backward(ctx, dy):
120 | y = ctx.y
121 | args = ctx.args
122 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
123 | y, dy = block.backward_pass(y, dy, **kwargs)
124 | return dy, None, None
125 |
126 | class ReversibleSequence(nn.Module):
127 | def __init__(self, blocks, args_route = {}):
128 | super().__init__()
129 | self.args_route = args_route
130 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
131 |
132 | def forward(self, x, **kwargs):
133 | x = torch.cat([x, x], dim=-1)
134 |
135 | blocks = self.blocks
136 | args = route_args(self.args_route, kwargs, len(blocks))
137 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
138 |
139 | layers_and_args = list(zip(blocks, args))
140 |
141 | out = _ReversibleFunction.apply(x, blocks, args)
142 | return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
143 |
--------------------------------------------------------------------------------
/memory_efficient_attention_pytorch/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from einops import rearrange
7 | from memory_efficient_attention_pytorch import FlashAttention, Attention
8 | from memory_efficient_attention_pytorch.reversible import ReversibleSequence
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | class PreNorm(nn.Module):
14 | def __init__(self, dim, fn):
15 | super().__init__()
16 | self.fn = fn
17 | self.norm = nn.LayerNorm(dim)
18 |
19 | def forward(self, x, **kwargs):
20 | x = self.norm(x)
21 | return self.fn(x, **kwargs)
22 |
23 | class FeedForward(nn.Module):
24 | def __init__(self, dim, mult = 4, chunks = 1):
25 | super().__init__()
26 | self.chunks = chunks
27 |
28 | self.net = nn.Sequential(
29 | nn.Linear(dim, dim * mult),
30 | nn.GELU(),
31 | nn.Linear(dim * mult, dim)
32 | )
33 |
34 | def forward(self, x):
35 | if self.chunks <= 1:
36 | return self.net(x)
37 |
38 | chunks = x.chunk(self.chunks, dim = 1)
39 | out = [self.net(chunk) for chunk in chunks]
40 | return torch.cat(out, dim = 1)
41 |
42 | class Transformer(nn.Module):
43 | def __init__(
44 | self,
45 | *,
46 | num_tokens,
47 | max_seq_len,
48 | dim,
49 | depth,
50 | causal = False,
51 | dim_head = 64,
52 | heads = 8,
53 | ff_mult = 4,
54 | ff_chunks = 1,
55 | use_flash_attn = True,
56 | **kwargs
57 | ):
58 | super().__init__()
59 | self.max_seq_len = max_seq_len
60 |
61 | self.token_emb = nn.Embedding(num_tokens, dim)
62 | self.pos_emb = nn.Embedding(max_seq_len, dim)
63 |
64 | attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)
65 |
66 | self.layers = nn.ModuleList([])
67 | for _ in range(depth):
68 | self.layers.append(nn.ModuleList([
69 | PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
70 | PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
71 | ]))
72 |
73 | self.net = ReversibleSequence(self.layers)
74 |
75 | self.to_logits = nn.Sequential(
76 | nn.LayerNorm(dim),
77 | nn.Linear(dim, num_tokens)
78 | )
79 |
80 | def forward(self, x, labels = None):
81 | device = x.device
82 | x = self.token_emb(x)
83 |
84 | pos_emb = self.pos_emb(torch.arange(x.shape[-2], device = device))
85 | x = x + pos_emb
86 |
87 | x = self.net(x)
88 |
89 | logits = self.to_logits(x)
90 |
91 | if not exists(labels):
92 | return logits
93 |
94 | return F.cross_entropy(rearrange(logits, 'b n d -> b d n'), labels)
95 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --verbose
6 | python_files = tests/*.py
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'memory-efficient-attention-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.1.6',
7 | license='MIT',
8 | description = 'Memory Efficient Attention - Pytorch',
9 | long_description_content_type = 'text/markdown',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | url = 'https://github.com/lucidrains/memory-efficient-attention-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'attention-mechanism'
17 | ],
18 | install_requires=[
19 | 'einops>=0.4.1',
20 | 'torch>=1.6'
21 | ],
22 | setup_requires=[
23 | 'pytest-runner',
24 | ],
25 | tests_require=[
26 | 'pytest'
27 | ],
28 | classifiers=[
29 | 'Development Status :: 4 - Beta',
30 | 'Intended Audience :: Developers',
31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
32 | 'License :: OSI Approved :: MIT License',
33 | 'Programming Language :: Python :: 3.8',
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from memory_efficient_attention_pytorch import Attention
3 |
4 | from memory_efficient_attention_pytorch.memory_efficient_attention import attention
5 | from memory_efficient_attention_pytorch.flash_attention import FlashAttention, FlashAttentionFunction
6 |
7 | # helpers
8 |
9 | def isclose(a, b, atol = 1e-6):
10 | diff = (a - b).abs().amax()
11 | return diff < atol
12 |
13 | # test outputs are equal
14 |
15 | def test_output_equal():
16 | attn = Attention(
17 | dim = 512,
18 | dim_head = 64,
19 | heads = 8,
20 | q_bucket_size = 64,
21 | k_bucket_size = 64,
22 | causal = True
23 | )
24 |
25 | x = torch.randn(2, 2048, 512)
26 | mask = torch.ones(2, 2048).bool()
27 |
28 | out = attn(x, mask = mask)
29 | mem_efficient_out = attn(x, mask = mask, memory_efficient = True)
30 |
31 | assert isclose(mem_efficient_out, out, atol = 1e-6)
32 |
33 | # test gradients equal
34 |
35 | def test_gradients_equal():
36 | attn = Attention(
37 | dim = 512,
38 | dim_head = 64,
39 | heads = 8,
40 | q_bucket_size = 64,
41 | k_bucket_size = 64,
42 | causal = True
43 | )
44 |
45 | def loss_fn(inp, **kwargs):
46 | return attn(inp, **kwargs).sum()
47 |
48 | x = torch.randn(2, 2048, 512).requires_grad_()
49 | mask = torch.ones(2, 2048).bool()
50 |
51 | loss_fn(x, mask = mask).backward()
52 | out_grad = x.grad.clone()
53 |
54 | x.grad.zero_()
55 | loss_fn(x, mask = mask, memory_efficient = True).backward()
56 | mem_efficient_out_grad = x.grad.clone()
57 |
58 | assert isclose(out_grad, mem_efficient_out_grad, atol = 1e-5)
59 |
60 | # test flash attention
61 |
62 | def test_flash_attn_output_equal():
63 | attn_kwargs = dict(
64 | dim = 512,
65 | dim_head = 64,
66 | heads = 8,
67 | q_bucket_size = 64,
68 | k_bucket_size = 64,
69 | causal = True
70 | )
71 |
72 | attn = Attention(**attn_kwargs)
73 | flash_attn = FlashAttention(**attn_kwargs)
74 |
75 | flash_attn.to_q = attn.to_q
76 | flash_attn.to_kv = attn.to_kv
77 | flash_attn.to_out = attn.to_out
78 |
79 | x = torch.randn(2, 2048, 512)
80 | mask = torch.ones(2, 2048).bool()
81 |
82 | out = attn(x, mask = mask)
83 | mem_efficient_out = flash_attn(x, mask = mask)
84 |
85 | assert isclose(mem_efficient_out, out, atol = 1e-6)
86 |
87 | # test gradients equal
88 |
89 | def test_flash_attn_gradients_equal():
90 | q = torch.randn(1, 8, 1024, 512).requires_grad_()
91 | k = torch.randn(1, 8, 1024, 512).requires_grad_()
92 | v = torch.randn(1, 8, 1024, 512).requires_grad_()
93 |
94 | mask = torch.ones(1, 1024).bool()
95 |
96 | o = attention(q, k, v, mask = mask, causal = True)
97 | o.sum().backward()
98 |
99 | dq_grad = q.grad.clone()
100 | dk_grad = k.grad.clone()
101 | dv_grad = v.grad.clone()
102 |
103 | q.grad.zero_()
104 | k.grad.zero_()
105 | v.grad.zero_()
106 |
107 | flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
108 | flash_o.sum().backward()
109 |
110 | flash_dq_grad = q.grad.clone()
111 | flash_dk_grad = k.grad.clone()
112 | flash_dv_grad = v.grad.clone()
113 |
114 | assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
115 | assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
116 | assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
117 |
118 | # test flash attention - full attention mask
119 |
120 | def test_flash_attn_full_attn_mask_output_equal():
121 | attn_kwargs = dict(
122 | dim = 512,
123 | dim_head = 64,
124 | heads = 8,
125 | q_bucket_size = 64,
126 | k_bucket_size = 64,
127 | causal = True
128 | )
129 |
130 | attn = Attention(**attn_kwargs)
131 | flash_attn = FlashAttention(**attn_kwargs)
132 |
133 | flash_attn.to_q = attn.to_q
134 | flash_attn.to_kv = attn.to_kv
135 | flash_attn.to_out = attn.to_out
136 |
137 | x = torch.randn(2, 2048, 512)
138 | mask = torch.ones(2, 1, 2048, 2048).bool()
139 |
140 | out = attn(x, mask = mask)
141 | mem_efficient_out = flash_attn(x, mask = mask)
142 |
143 | assert isclose(mem_efficient_out, out, atol = 1e-6)
144 |
145 | # test gradients equal - full attention mask
146 |
147 | def test_flash_attn_full_attn_mask_gradients_equal():
148 | q = torch.randn(1, 8, 1024, 512).requires_grad_()
149 | k = torch.randn(1, 8, 1024, 512).requires_grad_()
150 | v = torch.randn(1, 8, 1024, 512).requires_grad_()
151 |
152 | mask = torch.ones(1, 1, 1024, 1024).bool()
153 |
154 | o = attention(q, k, v, mask = mask, causal = True)
155 | o.sum().backward()
156 |
157 | dq_grad = q.grad.clone()
158 | dk_grad = k.grad.clone()
159 | dv_grad = v.grad.clone()
160 |
161 | q.grad.zero_()
162 | k.grad.zero_()
163 | v.grad.zero_()
164 |
165 | flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
166 | flash_o.sum().backward()
167 |
168 | flash_dq_grad = q.grad.clone()
169 | flash_dk_grad = k.grad.clone()
170 | flash_dv_grad = v.grad.clone()
171 |
172 | assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
173 | assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
174 | assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
175 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from memory_efficient_attention_pytorch.transformer import Transformer
2 | from memory_efficient_attention_pytorch.autoregressive_wrapper import AutoregressiveWrapper
3 |
4 | import random
5 | import tqdm
6 | import gzip
7 | import numpy as np
8 | import torch
9 | import torch.optim as optim
10 | from torch.nn import functional as F
11 | from torch.utils.data import DataLoader, Dataset
12 |
13 | # constants
14 |
15 | NUM_BATCHES = int(1e5)
16 | BATCH_SIZE = 4
17 | GRADIENT_ACCUMULATE_EVERY = 4
18 | LEARNING_RATE = 2e-4
19 | VALIDATE_EVERY = 100
20 | GENERATE_EVERY = 500
21 | GENERATE_LENGTH = 1024
22 | SEQ_LEN = 4096
23 |
24 | # helpers
25 |
26 | def cycle(loader):
27 | while True:
28 | for data in loader:
29 | yield data
30 |
31 | def decode_token(token):
32 | return str(chr(max(32, token)))
33 |
34 | def decode_tokens(tokens):
35 | return ''.join(list(map(decode_token, tokens)))
36 |
37 | # instantiate GPT-like decoder model
38 |
39 | model = Transformer(
40 | num_tokens = 256,
41 | dim = 512,
42 | max_seq_len = SEQ_LEN,
43 | depth = 6,
44 | heads = 8,
45 | causal = True,
46 | q_bucket_size = 256,
47 | k_bucket_size = 256,
48 | ff_chunks = 5,
49 | use_flash_attn = True
50 | )
51 |
52 | model = AutoregressiveWrapper(model)
53 | model.cuda()
54 |
55 | # prepare enwik8 data
56 |
57 | with gzip.open('./data/enwik8.gz') as file:
58 | X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
59 | trX, vaX = np.split(X, [int(90e6)])
60 | data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
61 |
62 | class TextSamplerDataset(Dataset):
63 | def __init__(self, data, seq_len):
64 | super().__init__()
65 | self.data = data
66 | self.seq_len = seq_len
67 |
68 | def __getitem__(self, index):
69 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
70 | full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
71 | return full_seq.cuda()
72 |
73 | def __len__(self):
74 | return self.data.size(0) // self.seq_len
75 |
76 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
77 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
78 | train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
79 | val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
80 |
81 | # optimizer
82 |
83 | optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
84 |
85 | # training
86 |
87 | for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
88 | model.train()
89 |
90 | for __ in range(GRADIENT_ACCUMULATE_EVERY):
91 | loss = model(next(train_loader))
92 | loss.backward()
93 |
94 | print(f'training loss: {loss.item()}')
95 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
96 | optim.step()
97 | optim.zero_grad()
98 |
99 | if i % VALIDATE_EVERY == 0:
100 | model.eval()
101 | with torch.no_grad():
102 | loss = model(next(val_loader))
103 | print(f'validation loss: {loss.item()}')
104 |
105 | if i != 0 and i % GENERATE_EVERY == 0:
106 | model.eval()
107 | inp = random.choice(val_dataset)[:-1]
108 | prime = decode_tokens(inp)
109 | print('%s \n\n %s' % (prime, '*' * 100))
110 |
111 | sample = model.generate(inp[None, ...], GENERATE_LENGTH)
112 | output_str = decode_tokens(sample[0])
113 | print(output_str)
114 |
--------------------------------------------------------------------------------