├── .coveragerc
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── publish.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── doc
│   ├── attention.png
│   ├── benchmark_attention.png
│   ├── benchmark_t5.png
│   └── benchmark_t5_original.png
├── grouped_query_attention_pytorch
│   ├── __init__.py
│   ├── attention.py
│   ├── t5.py
│   ├── transformer.py
│   └── utils
│       ├── __init__.py
│       └── benchmark.py
├── pyproject.toml
├── scripts
│   ├── README.md
│   ├── benchmark_attention.py
│   └── benchmark_t5.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_attention.py
    ├── test_t5.py
    └── test_transformer.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = grouped_query_attention_pytorch/
3 | omit = grouped_query_attention_pytorch/utils/benchmark.py
4 |
5 | [report]
6 | exclude_lines =
7 | # Have to re-enable the standard pragma
8 | pragma: no cover
9 |
10 | # Don't complain about missing debug-only code:
11 | def __repr__
12 | if self\.debug
13 |
14 | # Don't complain if tests don't hit defensive assertion code:
15 | raise AssertionError
16 | raise NotImplementedError
17 |
18 | # Don't complain if non-importable code isn't run:
19 | if __name__ == .__main__.:
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [fkodom]
4 | custom: [fkodom.substack.com]
5 | # patreon: # Replace with a single Patreon username
6 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | workflow_dispatch: {}
5 | release:
6 | types:
7 | - created
8 |
9 | env:
10 | PYTHON_VERSION: 3.9
11 |
12 | jobs:
13 | publish:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v2
17 |
18 | - name: Setup Python
19 | uses: actions/setup-python@v2
20 | with:
21 | python-version: ${{ env.PYTHON_VERSION }}
22 |
23 | - name: Install Dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install build
27 |
28 | - name: Build Package
29 | env:
30 | GROUPED_QUERY_ATTENTION_PYTORCH_VERSION: ${{ github.event.release.tag_name }}
31 | run: python -m build
32 |
33 | - name: Publish to PyPI
34 | uses: pypa/gh-action-pypi-publish@release/v1
35 | with:
36 | user: __token__
37 | password: ${{ secrets.PYPI_API_TOKEN }}
38 | verify-metadata: false
39 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | workflow_dispatch: {}
5 | push: {}
6 |
7 | jobs:
8 | test:
9 | name: Test
10 | runs-on: ubuntu-latest
11 | continue-on-error: true
12 |
13 | strategy:
14 | matrix:
15 | python: ["3.8", "3.9", "3.10"]
16 |
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@v2
20 |
21 | - name: Setup Python
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python }}
25 |
26 | - name: Templatize
27 | # Templatize the repo before attempting to run tests. (Tests will fail
28 | # otherwise, due to syntax errors.)
29 | # NOTE: Check if the 'templatize' script exists, so that this doesn't
30 | # immediately fail for repos that have already run the script.
31 | run: |
32 | if test -f "templatize"; then
33 | ./templatize
34 | fi
35 |
36 | - name: Install Package
37 | run: |
38 | pip install -e .[test,t5]
39 |
40 | - name: Test
41 | run: |
42 | ruff .
43 | pytest --cov --cov-report term-missing --cov-fail-under 80 tests/
44 | mypy
45 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: black
5 | name: black
6 | stages: [commit]
7 | language: system
8 | entry: black
9 | types: [python]
10 |
11 | - id: ruff
12 | name: ruff
13 | stages: [commit]
14 | language: system
15 | entry: ruff
16 | types: [python]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Frank Odom
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # grouped-query-attention-pytorch
2 |
3 | (Unofficial) PyTorch implementation of grouped-query attention (GQA) from [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/pdf/2305.13245.pdf)
4 |
5 | <img src="doc/attention.png" alt="Grouped-query attention diagram" width="600px"/>
6 |
7 | ### Includes:
8 | - [x] scaled dot-product attention with GQA support. (See: [scaled_dot_product_gqa usage](#scaled_dot_product_gqa))
9 | - [x] GQA multi-head attention layer. (See: [MultiheadGQA usage](#multiheadgqa))
10 | - [x] Code to convert pretrained T5 model to use GQA. (See: [T5 usage](#t5) )
11 | - [x] Prototype (untrained) GQA encoder-decoder models: `GQATransformer`, `GQATransformerLM` (See: [GQATransformer usage](#gqatransformer))
12 | - [x] Reproduce runtime benchmarks from [GQA paper](https://arxiv.org/pdf/2305.13245.pdf), figure 6 (See: [scripts/README.md](scripts/README.md))
13 |
14 | ### To do:
15 | - [ ] Fine-tuning code for T5 GQA models
16 | - [ ] Reproduce fine-tuning results from [GQA paper](https://arxiv.org/pdf/2305.13245.pdf), figures 3,5
17 |
18 | ## Install
19 |
20 | PyPI: (NOT YET AVAILABLE)
21 | ```bash
22 | pip install grouped-query-attention-pytorch
23 | ```
24 |
25 | From source:
26 | ```bash
27 | pip install "grouped-query-attention-pytorch @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
28 | ```
29 |
30 | For contributors:
31 | ```bash
32 | # Install all dev dependencies (tests, T5 support, etc.)
33 | pip install "grouped-query-attention-pytorch[test,t5] @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
34 | # Setup pre-commit hooks
35 | pre-commit install
36 | ```
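
To run the same checks as CI before committing (these mirror the commands in [.github/workflows/test.yaml](.github/workflows/test.yaml)):
```bash
ruff .
pytest --cov --cov-report term-missing --cov-fail-under 80 tests/
mypy
```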
37 |
38 |
39 | ## Benchmark
40 |
41 | I attempt to reproduce the runtime benchmarks from the [GQA paper](https://arxiv.org/pdf/2305.13245.pdf) (Figure 6). Unfortunately, I don't have access to the same hardware, so the comparison isn't perfect. (They use multiple high-end GPUs, and I use a single 2080 Ti.) Even with different hardware, though, it is clear that runtime scales similarly with the number of GQA groups.
42 |
43 | For more details, see [scripts/README.md](scripts/README.md#benchmark_t5)
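
To regenerate the plots yourself, run the benchmark scripts (a minimal sketch -- the supported CLI flags, if any, are documented in [scripts/README.md](scripts/README.md)):
```bash
# Hypothetical invocation with default settings; see scripts/README.md for options.
python scripts/benchmark_attention.py
python scripts/benchmark_t5.py
```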
44 |
45 | > Left: This repo &nbsp;&nbsp;&nbsp;&nbsp; Right: Original paper
46 |
47 |
48 | <img src="doc/benchmark_t5.png" alt="T5 benchmark (this repo)" width="45%"/>
49 | <img src="doc/benchmark_t5_original.png" alt="T5 benchmark (original paper)" width="45%"/>
50 |
51 |
52 | ## Usage
53 |
54 | ### `scaled_dot_product_gqa`
55 |
56 | See: [attention.py](grouped_query_attention_pytorch/attention.py)
57 |
58 | Intended to be a drop-in replacement for `F.scaled_dot_product_attention` with support for GQA.
59 |
60 | > **NOTE**: The built-in `F.scaled_dot_product_attention` will be *much* faster when you're **not** using grouped queries -- especially for `torch>=2.0`, which uses [flash attention](https://github.com/Dao-AILab/flash-attention) under the hood. However, [this benchmark](./scripts/README.md#benchmark_attention) shows that the naive `scaled_dot_product_gqa` is faster than flash attention when the number of GQA groups is small. 🔥
61 |
62 | ```python
63 | import torch
64 |
65 | from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa
66 |
67 | # shapes: (batch_size, seq_len, num_heads, head_dim)
68 | query = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16)
69 | key = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
70 | value = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
71 |
72 | out, attn_weights = scaled_dot_product_gqa(
73 | query,
74 | key,
75 | value,
76 | is_causal=True, # default: False
77 | need_weights=True, # default: False, which returns 'attn_weights=None'
78 | )
79 | print(out.shape)  # (batch_size, q_seq_len, query_heads, head_dim)
80 | # torch.Size([1, 256, 8, 64])
81 | print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, query_heads)
82 | # torch.Size([1, 256, 128, 8])
83 | ```
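
A padding mask can be passed instead of `is_causal` (the two are mutually exclusive). Below is a minimal sketch based on the `mask` argument described in [attention.py](grouped_query_attention_pytorch/attention.py): a 2D boolean mask of shape `(batch_size, kv_seq_len)` is broadcast over all query positions and heads, and positions marked `False` are excluded from attention.

```python
# Reuses 'query', 'key', and 'value' from the example above.
key_padding_mask = torch.ones(1, 128, device="cuda", dtype=torch.bool)
key_padding_mask[:, 64:] = False  # treat the second half of the keys as padding
out, _ = scaled_dot_product_gqa(query, key, value, mask=key_padding_mask)
print(out.shape)
# torch.Size([1, 256, 8, 64])
```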
84 |
85 |
86 | ### `MultiheadGQA`
87 |
88 | See: [attention.py](grouped_query_attention_pytorch/attention.py)
89 |
90 | Intended to be a drop-in replacement for `nn.MultiheadAttention` with support for GQA.
91 |
92 | > **NOTE**: The same performance advice from [scaled_dot_product_gqa](#scaled_dot_product_gqa) (above) applies here as well.
93 |
94 | ```python
95 | import torch
96 | from grouped_query_attention_pytorch.attention import MultiheadGQA
97 | mha = MultiheadGQA(
98 | embed_dim=512, query_heads=8, kv_heads=2, device="cuda", dtype=torch.float16
99 | )
100 |
101 | # shapes: (batch_size, seq_len, embed_dim)
102 | query = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16)
103 | key = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
104 | value = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
105 |
106 | out, attn_weights = mha(
107 | query,
108 | key,
109 | value,
110 | is_causal=True, # default: False
111 | need_weights=True, # default: False, which returns 'attn_weights=None'
112 | )
113 | print(out.shape) # (batch_size, q_seq_len, embed_dim)
114 | # torch.Size([1, 256, 512])
115 | print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, query_heads)
116 | # torch.Size([1, 256, 128, 8])
117 | ```
118 |
119 |
120 | ### T5
121 |
122 | See: [t5.py](grouped_query_attention_pytorch/t5.py)
123 |
124 | Convert a pretrained T5 model from [huggingface/transformers](https://github.com/huggingface/transformers) to use GQA. The resulting model can be used and trained with the Huggingface Transformers library, just like an ordinary T5 model.
125 |
126 | ```python
127 | from transformers import T5ForConditionalGeneration, T5Tokenizer
128 |
129 | from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa
130 |
131 | # Initialize a pre-trained T5 model
132 | t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
133 | tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
134 | # Convert attention layers to GQA
135 | t5_gqa = convert_t5_to_gqa(t5, kv_heads=2, inplace=False) # default: inplace=False
136 |
137 | # Generate some text with the converted model
138 | input_ids = tokenizer(
139 | "translate English to German: The house is wonderful.", return_tensors="pt"
140 | ).input_ids
141 | outputs = t5_gqa.generate(input_ids, max_new_tokens=25)
142 | text = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
143 | print(text)
144 | # The correct answer is: ['', 'Das', 'Haus', 'ist', 'wunderbar', '.', '']
145 | # NOTE: The original T5 model produces this answer, and so does GQA when we use the
146 | # maximum number of KV heads (kv_heads=8 for t5-small), which effectively makes
147 | # GQA equivalent to the original T5 model with MHA. The text quickly degrades as
148 | # we reduce the number of heads.
149 | ```
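
Fine-tuning code is still on the to-do list, but the converted model keeps the usual `T5ForConditionalGeneration` training interface, so the standard seq2seq loss works out of the box. This sketch mirrors the example at the bottom of [t5.py](grouped_query_attention_pytorch/t5.py):

```python
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
loss = t5_gqa(input_ids=input_ids, labels=labels).loss
loss.backward()
# NOTE: As with generation, the loss degrades (increases) as 'kv_heads' is reduced.
```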
150 |
151 | ### GQATransformer
152 |
153 | I also provide a prototype implementation of an (untrained) encoder-decoder Transformer model, which uses GQA instead of MHA. This is mostly for reference/educational purposes, but in principle it could be used as a drop-in replacement for `nn.Transformer`.
154 |
155 | See: [transformer.py](grouped_query_attention_pytorch/transformer.py)
156 |
157 | ```python
158 | import torch
159 | from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM
160 | device = torch.device("cuda")
161 | dtype = torch.float16
162 |
163 | net = GQATransformer(
164 | d_model=512, # required
165 | nhead=8, # required
166 | kv_heads=2, # required
167 | num_encoder_layers=6,
168 | num_decoder_layers=6,
169 | dim_feedforward=2048,
170 | dropout=0.1,
171 | activation="relu",
172 | layer_norm_eps=1e-5,
173 | device=device,
174 | dtype=dtype,
175 | )
176 | # shape: (batch_size, seq_len, d_model)
177 | x = torch.randn(1, 256, 512, device=device, dtype=dtype)
178 | with torch.no_grad():
179 | y = net.forward(x, is_causal=True) # default: is_causal=True
180 | print(y.shape)
181 | # torch.Size([1, 256, 512])
182 |
183 | num_tokens = 10000 # usually obtained from the tokenizer
184 | lm = GQATransformerLM(
185 | num_tokens=num_tokens, # required
186 | d_model=512, # required
187 | nhead=8, # required
188 | kv_heads=2, # required
189 | num_encoder_layers=6,
190 | num_decoder_layers=6,
191 | dim_feedforward=2048,
192 | dropout=0.1,
193 | activation="relu",
194 | layer_norm_eps=1e-5,
195 | device=device,
196 | dtype=dtype,
197 | )
198 | # shape: (batch_size, seq_len)
199 | x = torch.randint(0, num_tokens, (1, 256), device=device, dtype=torch.long)
200 | with torch.no_grad():
201 | y = lm.forward(x, is_causal=True) # default: is_causal=True
202 | print(y.shape)
203 | # torch.Size([1, 256, num_tokens])
204 | ```
205 |
--------------------------------------------------------------------------------
/doc/attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/attention.png
--------------------------------------------------------------------------------
/doc/benchmark_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_attention.png
--------------------------------------------------------------------------------
/doc/benchmark_t5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_t5.png
--------------------------------------------------------------------------------
/doc/benchmark_t5_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_t5_original.png
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from subprocess import getoutput
3 |
4 |
5 | def get_version_tag() -> str:
6 | try:
7 | env_key = "GROUPED_QUERY_ATTENTION_PYTORCH_VERSION".upper()
8 | version = os.environ[env_key]
9 | except KeyError:
10 | version = getoutput("git describe --tags --abbrev=0")
11 |
12 | if version.lower().startswith("fatal"):
13 | version = "0.0.0"
14 |
15 | return version
16 |
17 |
18 | VERSION = get_version_tag()
19 |
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/attention.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import einsum, rearrange
6 | from torch import Tensor, nn
7 |
8 |
9 | def scaled_dot_product_gqa(
10 | query: Tensor,
11 | key: Tensor,
12 | value: Tensor,
13 | dropout: float = 0.0,
14 | scale: Optional[float] = None,
15 | mask: Optional[Tensor] = None,
16 | is_causal: Optional[bool] = None,
17 | need_weights: bool = False,
18 | average_attn_weights: bool = False,
19 | force_grouped: bool = False,
20 | ):
21 | """Scaled dot product attention with support for grouped queries.
22 |
23 | Einstein notation:
24 | - b: batch size
25 | - n / s: sequence length
26 | - h: number of heads
27 | - g: number of groups
28 | - d: dimension of query/key/value
29 |
30 | Args:
31 | query: Query tensor of shape (b, n, h, d)
32 | key: Key tensor of shape (b, s, h, d)
33 | value: Value tensor of shape (b, s, h, d)
34 | dropout: Dropout probability (default: 0.0)
35 | scale: Scale factor for query (default: d_query ** 0.5)
36 | mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is
37 | applied to all 'n' rows of the attention matrix. (default: None)
38 | force_grouped: If True, apply grouped-query attention even if the number of
39 | heads is equal for query, key, and value. (default: False)
40 |
41 | Returns:
42 | 2-tuple of:
43 | - Attention output with shape (b, n, h, d)
44 |             - (Optional) Attention weights with shape (b, n, s, h). Only returned
45 |               if 'need_weights' is True.
46 | """
47 | if (mask is not None) and (is_causal is not None):
48 | raise ValueError(
49 | "Only one of 'mask' and 'is_causal' should be provided, but got both."
50 | )
51 | elif not query.ndim == key.ndim == value.ndim == 4:
52 | raise ValueError(
53 | f"Expected query, key, and value to be 4-dimensional, but got shapes "
54 | f"{query.shape}, {key.shape}, and {value.shape}."
55 | )
56 |
57 | # Move sequence length dimension to axis 2.
58 | # This makes the attention operations below *much* faster.
59 | query = rearrange(query, "b n h d -> b h n d")
60 | key = rearrange(key, "b s h d -> b h s d")
61 | value = rearrange(value, "b s h d -> b h s d")
62 |
63 | bq, hq, nq, dq = query.shape
64 | bk, hk, nk, dk = key.shape
65 | bv, hv, nv, dv = value.shape
66 | if not (bq == bk == bv and dq == dk == dv):
67 | raise ValueError(
68 | "Expected query, key, and value to have the same batch size (dim=0) and "
69 | f"embedding dimension (dim=3), but got query: {query.shape}, "
70 | f"key: {key.shape}, and value: {value.shape}."
71 | )
72 | elif (hk != hv) or (nk != nv):
73 | raise ValueError(
74 | "Expected key and value to have the same size in dimensions 1 and 2, but "
75 | f"got key: {key.shape} and value: {value.shape}."
76 | )
77 | elif hq % hk != 0:
78 | raise ValueError(
79 | "Expected query heads to be a multiple of key/value heads, but got "
80 | f"query: {query.shape} and key/value: {key.shape}."
81 | )
82 |
83 | if scale is None:
84 | scale = query.size(-1) ** 0.5
85 | query = query / scale
86 |
87 | num_head_groups = hq // hk
88 | query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
89 | similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s")
90 |
91 | if is_causal:
92 | # Mask out the upper triangular portion of the attention matrix. This prevents
93 | # the model from attending to tokens in the future.
94 | mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_()
95 |
96 | if mask is not None:
97 | # Expand mask to match the shape of the attention matrix.
98 | # If mask is 2D, assume that it is applied to the key/value sequence dimension.
99 | # Else if mask is 3D, assume that it is applied to the query/key/value sequence
100 | # dimension for all attention heads.
101 | #
102 | # Users could also provide a 4D mask, which is applied to the query/key/value
103 | # sequence dimension for each attention head (though I don't have a particular
104 | # use case in mind for that).
105 | if mask.ndim == 2:
106 | mask = rearrange(mask, "b s -> b () () () s")
107 | elif mask.ndim == 3:
108 | mask = rearrange(mask, "b n s -> b () () n s")
109 | # Mask similarity values by setting them to negative infinity. This guarantees
110 | # that they will not contribute to the softmax computation below.
111 | similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min)
112 |
113 | attention = F.softmax(similarity, dim=-1)
114 | if dropout > 0.0:
115 | attention = F.dropout(attention, p=dropout)
116 |
117 | # Apply attention matrix to the value Tensor.
118 | out = einsum(attention, value, "b g h n s, b h s d -> b g h n d")
119 | # Move head dimension back to axis 2
120 | out = rearrange(out, "b g h n d -> b n (h g) d")
121 |
122 | attn_weights: Optional[Tensor] = None
123 | if need_weights:
124 | # Move the sequence dimensions back to positions 1, 2. Move the head dimension
125 | # to position 3. This more closely matches the return shape of the attention
126 | # output: (b, n, h, d).
127 | attn_weights = rearrange(attention, "b g h n s -> b n s (h g)")
128 | if average_attn_weights:
129 | attn_weights = attn_weights.mean(dim=1)
130 |
131 | return out, attn_weights
132 |
133 |
134 | class MultiheadGQA(nn.Module):
135 | """Multi-head grouped query attention (GQA) layer.
136 |
137 | Reference:
138 | "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
139 | https://arxiv.org/pdf/2305.13245v1.pdf
140 |
141 | GQA is a variant of multihead attention (MHA) that uses fewer write heads
142 | (key / value) than query heads. GQA can be viewed as a generalization of
143 | multi-query attention (MQA), which uses a single write head. GQA and MQA give
144 | significant speedups over standard MHA in decoder layers, with minimal loss in
145 | accuracy. In the paper, GQA is shown to be more accurate than MQA, while still
146 | having a significant speedup over MHA.
147 |
148 | NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model
149 | from MHA to GQA. As a result, they do not mention parameter initialization or
150 | layer normalization strategies. I follow the best practices laid out in the
151 | MAGNETO paper, which improves Transformer performance through better parameter
152 | initialization and layer norm placement. See:
153 | https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
154 | """
155 |
156 | def __init__(
157 | self,
158 | embed_dim: int,
159 | query_heads: int,
160 | kv_heads: int,
161 | dropout: float = 0.0,
162 | bias: bool = True,
163 | layer_norm: bool = True,
164 | layer_norm_eps: float = 1e-5,
165 | gamma_init: float = 1.0,
166 | device: Optional[Union[torch.device, str]] = None,
167 | dtype: Optional[torch.dtype] = None,
168 | ):
169 | super().__init__()
170 | self.query_heads = query_heads
171 | self.kv_heads = kv_heads
172 | self.dropout = dropout
173 | self.layer_norm = layer_norm
174 | self.gamma_init = gamma_init
175 |
176 | if self.query_heads % self.kv_heads != 0:
177 | raise ValueError(
178 | f"query_heads ({query_heads}) must be divisible by "
179 | f"kv_heads ({kv_heads})"
180 | )
181 | elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0):
182 | raise ValueError(
183 | f"embed_dim ({embed_dim}) must be divisible by "
184 | f"query_heads ({query_heads}) and kv_heads ({kv_heads})"
185 | )
186 |
187 | head_dim = embed_dim // query_heads
188 | if not head_dim % 8 == 0:
189 | raise ValueError(
190 | f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
191 | )
192 | if not head_dim <= 128:
193 | raise ValueError(
194 | f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
195 | )
196 |
197 | # Query projection layer is the same as in vanilla MHA.
198 | self.q_proj = nn.Linear(
199 | embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
200 | )
201 |         # Key/value projection layers have a smaller output dimension, so that
202 |         # we have fewer key/value attention heads after reshaping.
203 | kv_embed_dim = embed_dim // query_heads * kv_heads
204 | self.k_proj = nn.Linear(
205 | embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
206 | )
207 | self.v_proj = nn.Linear(
208 | embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype
209 | )
210 | self.norm: Optional[nn.LayerNorm] = None
211 | if layer_norm:
212 | self.norm = nn.LayerNorm(
213 | embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
214 | )
215 |         # The attention output is folded back to the full embedding dimension
216 |         # (query_heads * head_dim == embed_dim), so the output projection maps
217 |         # embed_dim -> embed_dim, just as in vanilla MHA.
218 | self.out_proj = nn.Linear(
219 | embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
220 | )
221 |
222 | self._reset_parameters()
223 |
224 | def _reset_parameters(self):
225 | nn.init.xavier_normal_(self.q_proj.weight)
226 | if self.q_proj.bias is not None:
227 | nn.init.constant_(self.q_proj.bias, 0)
228 | nn.init.xavier_normal_(self.k_proj.weight)
229 | if self.k_proj.bias is not None:
230 | nn.init.constant_(self.k_proj.bias, 0)
231 |
232 | # NOTE: We follow the initialization strategy from MAGNETO. See:
233 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
234 | # Gain (self.gamma_init) should be provided as a keyword argument when
235 | # initializing the larger Transformer model, since it requires knowledge
236 | # of the number of encoder/decoder layers in the model.
237 |
238 | nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
239 | if self.v_proj.bias is not None:
240 | nn.init.constant_(self.v_proj.bias, 0)
241 | nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init)
242 | if self.out_proj.bias is not None:
243 | nn.init.constant_(self.out_proj.bias, 0)
244 |
245 | def forward(
246 | self,
247 | query: Tensor,
248 | key: Tensor,
249 | value: Tensor,
250 | need_weights: bool = False,
251 | # TODO
252 | # attn_mask: Optional[Tensor] = None,
253 | is_causal: bool = False,
254 | average_attn_weights: bool = False,
255 | ) -> Tuple[Tensor, Optional[Tensor]]:
256 | # Notation:
257 | # b - batch size
258 | # n - sequence length
259 | # h - number of heads
260 | # d - embedding dimension
261 | #
262 | # Input shape: (b, n, d)
263 | q: Tensor = self.q_proj(query)
264 | k: Tensor = self.k_proj(key)
265 | v: Tensor = self.v_proj(value)
266 |
267 | # Unfold 'd' dimension into 'h' separate attention heads.
268 | q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads)
269 | k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads)
270 | v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads)
271 | # Apply attention, then fold 'h' attention heads back into 'd'.
272 | x, attn = scaled_dot_product_gqa(
273 | query=q,
274 | key=k,
275 | value=v,
276 | # TODO
277 | # mask=attn_mask,
278 | is_causal=is_causal,
279 | need_weights=need_weights,
280 | average_attn_weights=average_attn_weights,
281 | force_grouped=False,
282 | )
283 | x = rearrange(x, "b n h d -> b n (h d)")
284 |
285 | # NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO
286 | # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra
287 | # layer norm before the linear output projection. The cross-attention layer in
288 | # the MAGNETO decoder does not include this layer norm, so users have the
289 | # option to disable it (layer_norm=False).
290 | if self.layer_norm:
291 | assert self.norm is not None
292 | x = self.norm(x)
293 | # Linear projection on attention outputs.
294 | x = self.out_proj(x)
295 |
296 | return x, attn
297 |
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from copy import deepcopy
4 | from typing import TypeVar, overload
5 |
6 | import torch
7 | from einops import einsum, rearrange, repeat
8 | from torch import nn
9 | from transformers.models.t5.modeling_t5 import T5Attention
10 |
11 |
12 | class T5GQA(nn.Module):
13 | def __init__(
14 | self,
15 | is_decoder: bool,
16 | d_model: int,
17 | key_value_proj_dim: int,
18 | n_heads: int,
19 | kv_heads: int,
20 | dropout: float,
21 | has_relative_attention_bias: bool,
22 | relative_attention_num_buckets: int,
23 | relative_attention_max_distance: int,
24 | ):
25 | super().__init__()
26 | if n_heads % kv_heads != 0:
27 | raise ValueError(
28 | f"n_heads ({n_heads}) must be divisible by kv_heads ({kv_heads})"
29 | )
30 |
31 | self.is_decoder = is_decoder
32 | self.d_model = d_model
33 | self.key_value_proj_dim = key_value_proj_dim
34 | self.n_heads = n_heads
35 |         # TODO: Check if we need to store 'kv_heads' and 'inner_dim' as properties
36 | self.kv_heads = kv_heads
37 | self.dropout = dropout
38 | # NOTE: Relative attention bias typically only used in the first layer
39 | # of a `T5Stack` module.
40 | self.has_relative_attention_bias = has_relative_attention_bias
41 | self.relative_attention_num_buckets = relative_attention_num_buckets
42 | self.relative_attention_max_distance = relative_attention_max_distance
43 |
44 | self.inner_dim = self.n_heads * self.key_value_proj_dim
45 | self.kv_dim = self.kv_heads * self.key_value_proj_dim
46 |
47 | # Mesh TensorFlow initialization to avoid scaling before softmax
48 | # self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
49 | # self.k = nn.Linear(self.d_model, self.kv_dim, bias=False)
50 | # self.v = nn.Linear(self.d_model, self.kv_dim, bias=False)
51 | # self.o = nn.Linear(self.kv_dim, self.d_model, bias=False)
52 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
53 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
54 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
55 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
56 |
57 | if self.has_relative_attention_bias:
58 | self.relative_attention_bias = nn.Embedding(
59 | self.relative_attention_num_buckets, self.n_heads
60 | )
61 | self.pruned_heads = set() # type: ignore
62 | self.gradient_checkpointing = False
63 |
64 | self._relative_position_bucket = T5Attention._relative_position_bucket
65 |
66 | @classmethod
67 | def from_t5_attention(cls, t5: T5Attention, kv_heads: int) -> T5GQA:
68 | t5_gqa = T5GQA(
69 | is_decoder=t5.is_decoder,
70 | d_model=t5.d_model,
71 | key_value_proj_dim=t5.key_value_proj_dim,
72 | n_heads=t5.n_heads,
73 | kv_heads=kv_heads,
74 | dropout=t5.dropout,
75 | has_relative_attention_bias=t5.has_relative_attention_bias,
76 | relative_attention_num_buckets=t5.relative_attention_num_buckets,
77 | relative_attention_max_distance=t5.relative_attention_max_distance,
78 | )
79 |
80 | # Copy all of the weights verbatim from the original T5Attention module.
81 | # NOTE: In the T5 GQA implementation, all of the attention head aggregations
82 | # happen in the 'forward' method. The weights themselves are not modified.
83 | t5_gqa.q.weight.data = t5.q.weight.data
84 | t5_gqa.k.weight.data = t5.k.weight.data
85 | t5_gqa.v.weight.data = t5.v.weight.data
86 | t5_gqa.o.weight.data = t5.o.weight.data
87 | if t5.has_relative_attention_bias:
88 | t5_gqa.relative_attention_bias.weight.data = (
89 | t5.relative_attention_bias.weight.data
90 | )
91 |
92 | return t5_gqa
93 |
94 | def forward( # noqa: C901
95 | self,
96 | hidden_states,
97 | mask=None,
98 | key_value_states=None,
99 | position_bias=None,
100 | past_key_value=None,
101 | layer_head_mask=None,
102 | query_length=None,
103 | use_cache=False,
104 | output_attentions=False,
105 | ):
106 | """
107 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
108 | """
109 | # Input is (batch_size, seq_length, dim)
110 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
111 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
112 | batch_size, seq_length = hidden_states.shape[:2]
113 |
114 | real_seq_length = seq_length
115 |
116 | if past_key_value is not None:
117 | if len(past_key_value) != 2:
118 | raise ValueError(
119 | f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
120 | )
121 | real_seq_length += (
122 | past_key_value[0].shape[2] if query_length is None else query_length
123 | )
124 |
125 | key_length = (
126 | real_seq_length if key_value_states is None else key_value_states.shape[1]
127 | )
128 |
129 | def shape(states):
130 | """projection"""
131 | # NOTE: Changed from the original definition in T5Attention.
132 | sequence_length = states.shape[1]
133 | return states.view(
134 | batch_size, sequence_length, -1, self.key_value_proj_dim
135 | ).transpose(1, 2)
136 |
137 | def unshape(states):
138 | """reshape"""
139 | # NOTE: Changed from the original definition in T5Attention.
140 | sequence_length = states.shape[2]
141 | return (
142 | states.transpose(1, 2)
143 | .contiguous()
144 | .view(batch_size, sequence_length, -1)
145 | )
146 |
147 | def project(hidden_states, proj_layer, key_value_states, past_key_value):
148 | """projects hidden states correctly to key/query states"""
149 | if key_value_states is None:
150 | # self-attn
151 | # (batch_size, n_heads, seq_length, dim_per_head)
152 | hidden_states = shape(proj_layer(hidden_states))
153 | elif past_key_value is None:
154 | # cross-attn
155 | # (batch_size, n_heads, seq_length, dim_per_head)
156 | hidden_states = shape(proj_layer(key_value_states))
157 |
158 | if past_key_value is not None:
159 | if key_value_states is None:
160 | # self-attn
161 | # (batch_size, n_heads, key_length, dim_per_head)
162 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
163 | elif past_key_value.shape[2] != key_value_states.shape[1]:
164 | # checking that the `sequence_length` of the `past_key_value` is the same as
165 | # the provided `key_value_states` to support prefix tuning
166 | # cross-attn
167 | # (batch_size, n_heads, seq_length, dim_per_head)
168 | hidden_states = shape(proj_layer(key_value_states))
169 | else:
170 | # cross-attn
171 | hidden_states = past_key_value
172 | return hidden_states
173 |
174 | # get query states: (batch_size, n_heads, seq_length, dim_per_head)
175 | grouped_queries = shape(self.q(hidden_states))
176 | # get key/value states
177 | key_states = project(
178 | hidden_states,
179 | self.k,
180 | key_value_states,
181 | past_key_value[0] if past_key_value is not None else None,
182 | )
183 | value_states = project(
184 | hidden_states,
185 | self.v,
186 | key_value_states,
187 | past_key_value[1] if past_key_value is not None else None,
188 | )
189 |
190 | # # compute scores
191 | # scores = torch.matmul(
192 | # query_states, key_states.transpose(3, 2)
193 | # ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
194 | grouped_queries = rearrange(
195 | grouped_queries, "b (g h) n d -> b g h n d", h=self.kv_heads
196 | )
197 | grouped_keys = rearrange(
198 | key_states, "b (g h) s d -> b g h s d", h=self.kv_heads
199 | ).mean(dim=1)
200 | scores = einsum(grouped_queries, grouped_keys, "b g h n d, b h s d -> b h n s")
201 |
202 | if position_bias is None:
203 | if not self.has_relative_attention_bias:
204 | position_bias = torch.zeros(
205 | # NOTE: This is different from the original in T5Attention!
206 | # (1, self.n_heads, real_seq_length, key_length),
207 | (1, self.kv_heads, real_seq_length, key_length),
208 | device=scores.device,
209 | dtype=scores.dtype,
210 | )
211 | if self.gradient_checkpointing and self.training:
212 | position_bias.requires_grad = True
213 | else:
214 | position_bias = T5Attention.compute_bias(
215 | self, real_seq_length, key_length, device=scores.device
216 | )
217 |
218 | # if key and values are already calculated
219 | # we want only the last query position bias
220 | if past_key_value is not None:
221 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
222 |
223 | if mask is not None:
224 | # (batch_size, n_heads, seq_length, key_length)
225 | position_bias = position_bias + mask
226 |
227 | if self.pruned_heads:
228 | mask = torch.ones(position_bias.shape[1])
229 | mask[list(self.pruned_heads)] = 0
230 | position_bias_masked = position_bias[:, mask.bool()]
231 | else:
232 | position_bias_masked = position_bias
233 |
234 | # NOTE: This is different from the original in T5Attention!
235 | grouped_position_bias = rearrange(
236 | position_bias_masked, "b (g h) n s -> b g h n s", h=self.kv_heads
237 | ).mean(dim=1)
238 |
239 | scores += grouped_position_bias
240 | # attn_weights: (batch_size, kv_heads, seq_length, key_length)
241 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
242 | attn_weights = nn.functional.dropout(
243 | attn_weights, p=self.dropout, training=self.training
244 | )
245 |
246 | # Mask heads if we want to
247 | if layer_head_mask is not None:
248 | attn_weights = attn_weights * layer_head_mask
249 |
250 | # NOTE: This is different from the original in T5Attention!
251 | # attn_output = unshape(torch.matmul(attn_weights, value_states))
252 | grouped_values = rearrange(
253 | value_states, "b (g h) s d -> b g h s d", h=self.kv_heads
254 | ).mean(dim=1)
255 | attn_output = unshape(torch.matmul(attn_weights, grouped_values))
256 | attn_output = repeat(
257 | attn_output, "b s d -> b s (g d)", g=(self.n_heads // self.kv_heads)
258 | )
259 | attn_output = self.o(attn_output)
260 |
261 | present_key_value_state = (
262 | (key_states, value_states) if (self.is_decoder and use_cache) else None
263 | )
264 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
265 |
266 | if output_attentions:
267 | outputs = outputs + (attn_weights,) # type: ignore
268 | return outputs
269 |
270 |
271 | ModuleType = TypeVar("ModuleType", bound=nn.Module)
272 |
273 |
274 | @overload
275 | def convert_t5_to_gqa(
276 | module: ModuleType, kv_heads: int, inplace: bool = False
277 | ) -> ModuleType:
278 | ...
279 |
280 |
281 | @overload
282 | def convert_t5_to_gqa(
283 | module: T5Attention, kv_heads: int, inplace: bool = False
284 | ) -> T5GQA:
285 | ...
286 |
287 |
288 | def convert_t5_to_gqa(module, kv_heads: int, inplace: bool = False):
289 | if isinstance(module, T5Attention):
290 | return T5GQA.from_t5_attention(module, kv_heads=kv_heads)
291 |
292 | out = module if inplace else deepcopy(module)
293 | for name, child in out.named_children():
294 | out._modules[name] = convert_t5_to_gqa(child, kv_heads=kv_heads, inplace=True)
295 | return out
296 |
297 |
298 | if __name__ == "__main__":
299 | from transformers import T5ForConditionalGeneration, T5Tokenizer
300 |
301 | # NOTE: The original paper uses T5 v1.1 XL and XXL models. When I load those
302 | # models through 'transformers' without applying GQA, I get nonsense outputs.
303 | # TODO: Figure out why this is happening.
304 | # tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-large", legacy=False)
305 | # model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-large")
306 | #
307 | # In the meantime, we can use the non-Google T5 models, which seem to work fine.
308 |     # NOTE: Since the original number of heads (n_heads) must be divisible by
309 | # 'kv_heads', there are only certain values of 'kv_heads' that we can use.
310 | # To the best of my knowledge, the following values of 'kv_heads' are valid:
311 | # - t5-small: 1, 2, 4, 8
312 | # - t5-base: 1, 2, 3, 4, 6, 12
313 | # - t5-large: 1, 2, 4, 8, 16
314 | # - t5-3b: 1, 2, 4, 8, 16, 32
315 | # - t5-11b: 1, 2, 4, 8, 16, 32, 64 TODO: Check 11b values specifically
316 |
317 | tokenizer = T5Tokenizer.from_pretrained(
318 | "t5-base", legacy=False, model_max_length=512
319 | )
320 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
321 | "t5-base"
322 | )
323 | gqa = convert_t5_to_gqa(t5, kv_heads=6)
324 |
325 | input_ids = tokenizer(
326 | "translate English to German: The house is wonderful.", return_tensors="pt"
327 | ).input_ids
328 | y2 = gqa.generate(input_ids, max_new_tokens=25)
329 | text = tokenizer.batch_decode(y2[0], skip_special_tokens=True)
330 | print(text)
331 | # The correct answer is: ['', 'Das', 'Haus', 'ist', 'wunderbar', '.', '']
332 | # NOTE: The original T5 model produces this answer, and so does GQA when we use
333 | # the maximum number of heads -- effectively equivalent to the original T5 model
334 | # with MHA. The text quickly degrades as we reduce the number of heads.
335 |
336 | labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
337 | loss = gqa(input_ids=input_ids, labels=labels).loss
338 | print(f"Loss: {loss}")
339 | # NOTE: As above, the loss quickly degrades (increases) as we reduce the number
340 | # of GQA heads.
341 |
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/transformer.py:
--------------------------------------------------------------------------------
1 | from math import log
2 | from typing import Callable, Optional, Union
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import Tensor, nn
7 | from torch.nn.modules.transformer import _get_activation_fn
8 | from torchscale.component.xpos_relative_position import XPOS
9 |
10 | from grouped_query_attention_pytorch.attention import MultiheadGQA
11 |
12 |
13 | class GQATransformerEncoderLayer(nn.Module):
14 | # NOTE: Mostly pulled from 'nn.TransformerEncoderLayer', but with changes:
15 | # - use sub-LayerNorm like in MAGNETO. See: https://arxiv.org/abs/2210.06423
16 | # - use MultiheadGQA instead of MultiheadAttention
17 |
18 | def __init__(
19 | self,
20 | d_model: int,
21 | nhead: int,
22 | kv_heads: int,
23 | dim_feedforward: int = 2048,
24 | dropout: float = 0.1,
25 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
26 | layer_norm_eps: float = 1e-5,
27 | gamma_init: float = 1.0,
28 | device: Optional[Union[torch.device, str]] = None,
29 | dtype: Optional[torch.dtype] = None,
30 | ) -> None:
31 | super().__init__()
32 | # Legacy string support for activation function.
33 | if isinstance(activation, str):
34 | activation = _get_activation_fn(activation)
35 |
36 | self.activation = activation
37 | self.gamma_init = gamma_init
38 |
39 | self.dropout = nn.Dropout(dropout)
40 | # Self-attention block
41 | self.norm1 = nn.LayerNorm(
42 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
43 | )
44 | self.self_attn = MultiheadGQA( # type: ignore
45 | embed_dim=d_model,
46 | query_heads=nhead,
47 | kv_heads=kv_heads,
48 | dropout=dropout,
49 | layer_norm=True,
50 | layer_norm_eps=layer_norm_eps,
51 | gamma_init=gamma_init,
52 | device=device,
53 | dtype=dtype,
54 | )
55 | # Feedforward block
56 | self.norm2 = nn.LayerNorm(
57 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
58 | )
59 | self.linear1 = nn.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
60 | self.norm3 = nn.LayerNorm(
61 | dim_feedforward, eps=layer_norm_eps, device=device, dtype=dtype
62 | )
63 | self.linear2 = nn.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
64 |
65 | self._reset_parameters()
66 |
67 | def _reset_parameters(self):
68 | # NOTE: We follow the initialization strategy from MAGNETO. See:
69 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
70 |         # The 'MultiheadGQA' module uses the same initialization,
71 | # so we just need to worry about the 'Linear' modules here.
72 | nn.init.xavier_normal_(self.linear1.weight, gain=self.gamma_init)
73 | nn.init.constant_(self.linear1.bias, 0)
74 | nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init)
75 | nn.init.constant_(self.linear2.bias, 0)
76 |
77 | def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor:
78 | x = self.norm1(x)
79 | x, _ = self.self_attn(x, x, x, is_causal=is_causal)
80 | x = self.dropout(x)
81 | return x
82 |
83 | def _feedforward_block(self, x: Tensor) -> Tensor:
84 | x = self.norm2(x)
85 | x = self.activation(self.linear1(x))
86 | x = self.dropout(x)
87 | x = self.norm3(x)
88 | x = self.linear2(x)
89 | x = self.dropout(x)
90 | return x
91 |
92 | def forward(self, src: Tensor, is_causal: bool = False) -> Tensor:
93 | x = src
94 | x = x + self._self_attention_block(x, is_causal=is_causal)
95 | x = x + self._feedforward_block(x)
96 | return x
97 |
98 |
99 | class GQATransformerDecoderLayer(nn.Module):
100 | # NOTE: Mostly pulled from 'nn.TransformerDecoderLayer', but with changes:
101 | # - use sub-LayerNorm like in MAGNETO. See: https://arxiv.org/abs/2210.06423
102 | # - use MultiheadGQA instead of MultiheadAttention
103 |
104 | def __init__(
105 | self,
106 | d_model: int,
107 | nhead: int,
108 | kv_heads: int,
109 | dim_feedforward: int = 2048,
110 | dropout: float = 0.1,
111 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
112 | layer_norm_eps: float = 1e-5,
113 | gamma_init: float = 1.0,
114 | device: Optional[Union[torch.device, str]] = None,
115 | dtype: Optional[torch.dtype] = None,
116 | ) -> None:
117 | super().__init__()
118 | # Legacy string support for activation function.
119 | if isinstance(activation, str):
120 | activation = _get_activation_fn(activation)
121 |
122 | self.activation = activation
123 | self.gamma_init = gamma_init
124 |
125 | self.dropout = nn.Dropout(dropout)
126 | # Self-attention block
127 | self.norm1 = nn.LayerNorm(
128 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
129 | )
130 | self.self_attn = MultiheadGQA( # type: ignore
131 | embed_dim=d_model,
132 | query_heads=nhead,
133 | kv_heads=kv_heads,
134 | dropout=dropout,
135 | layer_norm=False,
136 | gamma_init=gamma_init,
137 | device=device,
138 | dtype=dtype,
139 | )
140 | # Multi-head attention block
141 | self.norm2 = nn.LayerNorm(
142 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
143 | )
144 | self.multihead_attn = MultiheadGQA( # type: ignore
145 | embed_dim=d_model,
146 | query_heads=nhead,
147 | kv_heads=kv_heads,
148 | dropout=dropout,
149 | layer_norm_eps=layer_norm_eps,
150 | gamma_init=gamma_init,
151 | device=device,
152 | dtype=dtype,
153 | )
154 | # Feedforward block
155 | self.norm3 = nn.LayerNorm(
156 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
157 | )
158 | self.linear1 = nn.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
159 | self.norm4 = nn.LayerNorm(
160 | dim_feedforward, eps=layer_norm_eps, device=device, dtype=dtype
161 | )
162 | self.linear2 = nn.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
163 |
164 | self._reset_parameters()
165 |
166 | def _reset_parameters(self):
167 | # NOTE: We follow the initialization strategy from MAGNETO. See:
168 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2
169 |         # The 'MultiheadGQA' module uses the same initialization,
170 | # so we just need to worry about the 'Linear' modules here.
171 | nn.init.xavier_normal_(self.linear1.weight, gain=self.gamma_init)
172 | nn.init.constant_(self.linear1.bias, 0)
173 | nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init)
174 | nn.init.constant_(self.linear2.bias, 0)
175 |
176 | def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor:
177 | x = self.norm1(x)
178 | x, _ = self.self_attn(x, x, x, is_causal=is_causal)
179 | x = self.dropout(x)
180 | return x
181 |
182 | def _multihead_attention_block(
183 | self, x: Tensor, memory: Tensor, is_causal: bool = False
184 | ) -> Tensor:
185 | x = self.norm2(x)
186 | x, _ = self.multihead_attn(x, memory, memory, is_causal=is_causal)
187 | x = self.dropout(x)
188 | return x
189 |
190 | def _feedforward_block(self, x: Tensor) -> Tensor:
191 | x = self.norm3(x)
192 | x = self.activation(self.linear1(x))
193 | x = self.dropout(x)
194 | x = self.norm4(x)
195 | x = self.linear2(x)
196 | x = self.dropout(x)
197 | return x
198 |
199 | def forward(
200 | self,
201 | tgt: Tensor,
202 | memory: Tensor,
203 | tgt_is_causal: bool = False,
204 | memory_is_causal: bool = False,
205 | ) -> Tensor:
206 | x = tgt
207 | x = x + self._self_attention_block(x, is_causal=tgt_is_causal)
208 | x = x + self._multihead_attention_block(x, memory, is_causal=memory_is_causal)
209 | x = x + self._feedforward_block(x)
210 | return x
211 |
212 |
213 | class GQATransformer(nn.Module):
214 | def __init__(
215 | self,
216 | d_model: int = 512,
217 | nhead: int = 8,
218 | kv_heads: int = 4,
219 | num_encoder_layers: int = 6,
220 | num_decoder_layers: int = 6,
221 | dim_feedforward: int = 2048,
222 | dropout: float = 0.1,
223 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
224 | layer_norm_eps: float = 1e-5,
225 | device: Optional[torch.device] = None,
226 | dtype: Optional[torch.dtype] = None,
227 | ):
228 | super().__init__()
229 | # The 'gamma_init' parameters are different for the encoder and decoder,
230 | # and depend on the number of encoder/decoder layers. See MAGNETO paper:
231 | # https://arxiv.org/pdf/2210.06423.pdf, Figure 2
232 | encoder_gamma_init = (
233 | log(3 * num_decoder_layers) * log(2 * num_encoder_layers) / 3
234 | ) ** 0.5
235 | decoder_gamma_init = log(3 * num_decoder_layers) ** 0.5
236 |
237 | self.encoder = nn.TransformerEncoder(
238 | encoder_layer=GQATransformerEncoderLayer(
239 | d_model=d_model,
240 | nhead=nhead,
241 | kv_heads=kv_heads,
242 | dim_feedforward=dim_feedforward,
243 | dropout=dropout,
244 | activation=activation,
245 | layer_norm_eps=layer_norm_eps,
246 | gamma_init=encoder_gamma_init,
247 | device=device,
248 | dtype=dtype,
249 | ),
250 | num_layers=num_encoder_layers,
251 | mask_check=False,
252 | enable_nested_tensor=False,
253 | )
254 | self.decoder = nn.TransformerDecoder(
255 | decoder_layer=GQATransformerDecoderLayer(
256 | d_model=d_model,
257 | nhead=nhead,
258 | kv_heads=kv_heads,
259 | dim_feedforward=dim_feedforward,
260 | dropout=dropout,
261 | activation=activation,
262 | layer_norm_eps=layer_norm_eps,
263 | gamma_init=decoder_gamma_init,
264 | device=device,
265 | dtype=dtype,
266 | ),
267 | num_layers=num_decoder_layers,
268 | )
269 |
270 | def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
271 | """
272 | Input shape: (batch_size, seq_len, d_model)
273 | Output shape: (batch_size, seq_len, d_model)
274 |
275 | NOTE: Assume that 'is_causal' applies to both the encoder and decoder.
276 | This is the case for language modeling, but maybe not for other tasks.
277 | """
278 | tgt = x
279 | for layer in self.encoder.layers:
280 | x = layer(x, is_causal=is_causal)
281 | if self.encoder.norm is not None:
282 | x = self.encoder.norm(x)
283 |
284 | mem = x
285 | for layer in self.decoder.layers:
286 | tgt = layer(tgt, mem, memory_is_causal=is_causal, tgt_is_causal=is_causal)
287 | if self.decoder.norm is not None:
288 | tgt = self.decoder.norm(tgt)
289 |
290 | return tgt
291 |
292 |
293 | class GQATransformerLM(nn.Module):
294 | def __init__(
295 | self,
296 | num_tokens: int, # (required) usually obtained from the tokenizer
297 | d_model: int = 512,
298 | nhead: int = 8,
299 | kv_heads: int = 4,
300 | num_encoder_layers: int = 6,
301 | num_decoder_layers: int = 6,
302 | dim_feedforward: int = 2048,
303 | dropout: float = 0.1,
304 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
305 | layer_norm_eps: float = 1e-5,
306 | device: Optional[torch.device] = None,
307 | dtype: Optional[torch.dtype] = None,
308 | ):
309 | super().__init__()
310 | self.token_embedding = nn.Embedding(
311 | num_tokens, d_model, device=device, dtype=dtype
312 | )
313 | # TODO: Add support for other positional encodings? I use XPOS, which is the
314 | # "latest and greatest" at the time of writing. In principle, we could swap
315 | # it out for any other encoding, and remove the 'torchscale' dependency for this
316 | # repo, which is only used for XPOS.
317 | self.pos_embedding = XPOS(d_model).to(device=device, dtype=dtype)
318 | self.transformer = GQATransformer(
319 | d_model=d_model,
320 | nhead=nhead,
321 | kv_heads=kv_heads,
322 | num_encoder_layers=num_encoder_layers,
323 | num_decoder_layers=num_decoder_layers,
324 | dim_feedforward=dim_feedforward,
325 | dropout=dropout,
326 | activation=activation,
327 | layer_norm_eps=layer_norm_eps,
328 | device=device,
329 | dtype=dtype,
330 | )
331 | self.norm = nn.LayerNorm(
332 | d_model, eps=layer_norm_eps, device=device, dtype=dtype
333 | )
334 | self.out = nn.Linear(d_model, num_tokens, device=device, dtype=dtype)
335 |
336 | def _reset_parameters(self):
337 | nn.init.kaiming_normal_(self.out.weight)
338 | nn.init.constant_(self.out.bias, 0)
339 |
340 | def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
341 | x = self.token_embedding(x)
342 | x = x + self.pos_embedding(x)
343 | x = self.transformer(x, is_causal=is_causal)
344 | x = self.norm(x)
345 | return self.out(x)
346 |
347 |
348 | if __name__ == "__main__":
349 | num_tokens = 2048
350 | device = torch.device("cuda")
351 | dtype = torch.float16
352 |
353 | x = torch.randint(0, num_tokens - 1, size=(2, 512), device=device)
354 | model = GQATransformerLM(num_tokens=num_tokens, device=device, dtype=dtype)
355 |
356 | with torch.no_grad():
357 | out = model(x)
358 | print(out.shape)
359 |
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/grouped_query_attention_pytorch/utils/__init__.py
--------------------------------------------------------------------------------
/grouped_query_attention_pytorch/utils/benchmark.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from timeit import Timer
3 | from typing import Callable, List, NamedTuple
4 |
5 | import torch
6 |
7 |
8 | class BenchmarkResult(NamedTuple):
9 | mean: float
10 | std: float
11 |
12 | def __repr__(self):
13 | return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})"
14 |
15 | def __str__(self):
16 | return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s"
17 |
18 |
19 | @torch.no_grad()
20 | def benchmark(
21 | fn: Callable,
22 | *args,
23 | min_total_seconds: float = 1.0,
24 | min_iterations: int = 10,
25 | **kwargs,
26 | ) -> BenchmarkResult:
27 |     # Benchmark the runtime of a function, dynamically determining the number of
28 |     # iterations to run. Keep calling the function until the *total* runtime
29 |     # exceeds 'min_total_seconds', using at least 'min_iterations' calls.
30 | if min_iterations < 2:
31 | raise ValueError("min_iterations must be >= 2")
32 |
33 | timer = Timer(
34 | "fn(*args, **kwargs); synchronize()",
35 | globals={
36 | "fn": fn,
37 | "args": args,
38 | "kwargs": kwargs,
39 | "synchronize": torch.cuda.synchronize,
40 | },
41 | )
42 | # Run the function 5 times to warm up
43 | _ = timer.repeat(number=1, repeat=5)
44 |
45 | times: List[float] = []
46 | total_time = 0.0
47 | num_iterations = min_iterations
48 |
49 | while total_time < min_total_seconds:
50 | _times = timer.repeat(number=1, repeat=num_iterations)
51 | times.extend(_times)
52 |
53 | times_tensor = torch.as_tensor(times)
54 | total_time = times_tensor.sum().item()
55 | avg_time = times_tensor.mean().item()
56 | num_iterations = ceil((min_total_seconds - total_time) / avg_time)
57 |
58 | times_tensor = torch.as_tensor(times)
59 | return BenchmarkResult(
60 | mean=times_tensor.mean().item(),
61 | std=times_tensor.std().item(),
62 | )
63 |
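# Usage sketch (an illustration only, not shipped as part of this module): time a
# single matmul. Note that 'benchmark' calls torch.cuda.synchronize() on every
# iteration, so it assumes a CUDA device is available.
if __name__ == "__main__":
    x = torch.randn(1024, 1024, device="cuda")
    result = benchmark(torch.matmul, x, x, min_total_seconds=1.0)
    print(f"matmul: {result}")  # e.g. "(1.234e-04 ± 5.678e-06) s"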
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "setuptools-scm"]
3 |
4 | [project]
5 | name = "grouped_query_attention_pytorch"
6 | authors = [
7 | {name = "Frank Odom", email = "frank.odom.iii@gmail.com"},
8 | ]
9 | description = "Grouped-query attention (GQA) in PyTorch"
10 | license = {text = "MIT"}
11 | dynamic = [
12 | "version",
13 | "readme",
14 | ] # NOTE: Must be in sync with [tool.setuptools.dynamic] below
15 | dependencies = [
16 | # TODO: Check the full range of supported versions
17 | "einops~=0.6.0",
18 | "torch>=1.8.0",
19 | "torchscale~=0.2.0",
20 | ]
21 | requires-python = ">=3.8"
22 | classifiers = ["Programming Language :: Python :: 3"]
23 |
24 | [tool.setuptools.dynamic]
25 | # NOTE: Must be in sync with 'project.dynamic' above
26 | version = {attr = "grouped_query_attention_pytorch.VERSION"}
27 | readme = {file = ["README.md"], content-type = "text/markdown"}
28 |
29 | [tool.setuptools.packages.find]
30 | exclude = ["tests"]
31 |
32 | # --- extra packages ---
33 | [project.optional-dependencies]
34 | test = [
35 | "black",
36 | "kaleido",
37 | "mypy",
38 | "plotly",
39 | "pre-commit",
40 | "pytest",
41 | "pytest-cov",
42 | "ruff",
43 | "xformers==0.0.20",
44 | ]
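# The 't5' extra below is needed for the T5 conversion utilities and for
# 'scripts/benchmark_t5.py' / 'tests/test_t5.py'; e.g. 'pip install ".[t5]"'
# from a local checkout.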
45 | t5 = [
46 | "sentencepiece",
47 | "transformers>=4.5.0,<4.32",
48 | ]
49 |
50 | [project.scripts]
51 | # Entrypoint scripts
52 |
53 |
54 | # ----- Linting, Formatting, and Typing -----
55 |
56 | [tool.black]
57 | line-length = 88
58 |
59 | [tool.mypy]
60 | files = "grouped_query_attention_pytorch/"
61 | check_untyped_defs = "true"
62 | ignore_missing_imports = "true"
63 |
64 | [tool.pytest.ini_options]
65 | testpaths = ["tests"]
66 | addopts = "--cov --cov-report term-missing --cov-fail-under 80"
67 | filterwarnings = "ignore:.*.:DeprecationWarning"
68 |
69 | [tool.ruff]
70 | line-length = 88
71 | ignore = ["B905", "E501"]
72 | select = ["B", "C", "E", "F", "I", "W"]
73 | # Exclude a variety of commonly ignored directories.
74 | exclude = [
75 | ".bzr",
76 | ".direnv",
77 | ".eggs",
78 | ".git",
79 | ".hg",
80 | ".mypy_cache",
81 | ".nox",
82 | ".pants.d",
83 | ".ruff_cache",
84 | ".svn",
85 | ".tox",
86 | ".venv",
87 | "__pypackages__",
88 | "_build",
89 | "buck-out",
90 | "build",
91 | "dist",
92 | "node_modules",
93 | "venv",
94 | ]
95 |
96 | [tool.ruff.mccabe]
97 | max-complexity = 18
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # scripts
2 |
3 | Scripts should be launched from the root of this repository. Ex:
4 |
5 | ```bash
6 | python -m scripts.benchmark_attention
7 | ```
8 |
9 |
10 | ## benchmark_attention
11 |
12 | Gather runtime benchmarks for scaled dot product attention with grouped queries. In particular, we want to see how the runtime scales with the number of GQA groups. We compare vanilla (naive) attention to grouped attention, and we also compare against `xformers.ops.memory_efficient_attention` as a strong baseline.
13 |
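One possible invocation, spelling out the script's own defaults (taken from `scripts/benchmark_attention.py`):

```bash
python -m scripts.benchmark_attention \
    --num-groups 1 4 8 16 32 64 \
    --total-tokens 8192 \
    --seq-length 2048 \
    --num-heads 64 \
    --embed-dim 8 \
    --save-path doc/benchmark_attention.png
```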
14 | ### Reference
15 | The original paper benchmarks end-to-end runtime of the T5 model (https://arxiv.org/pdf/2305.13245v1.pdf, Figure 6). We do that in a separate benchmark (see `scripts/benchmark_t5.py`). Here, we focus on the attention layer itself.
16 |
17 | ### Results
18 |
19 | Clearly, runtime scales with the number of GQA groups in much the same way as in the original paper (compare the two figures below).
20 |
21 | Even though `xformers` is much faster than the naive implementation, GQA is still faster when the number of groups is small. Hopefully, someone will write an efficient CUDA implementation for GQA, so we can get the best of both worlds. Unfortunately, I likely don't have the CUDA experience to do it myself. :(
22 |
23 | #### This repo (attention layer only)
24 |
25 |
26 |
27 | #### Original paper (end-to-end T5)
28 |
29 |
30 |
31 |
32 | ## benchmark_t5
33 |
34 | Similar to [benchmark_attention](#benchmark_attention) above, but we benchmark the end-to-end T5 model. We compare the original T5 implementation with MHA to T5 with converted GQA.
35 |
36 | The same hardware differences as above apply here. The original paper benchmarked T5-XXL (11B params), which does not fit in my GPU memory. Instead, I benchmark T5-3B, the largest T5 variant that fits in memory. T5-3B only has 32 attention heads, so my benchmarks only go up to 32 GQA groups. (The original benchmarks with T5-XXL go up to 64 groups.) I use an input sequence length of 512 (similar to training) and a batch size of 8 (the original uses 32). An example invocation is shown below.
37 |
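The invocation below mirrors those settings via the script's flags (all values are the defaults from `scripts/benchmark_t5.py`):

```bash
python -m scripts.benchmark_t5 \
    --model-name t5-3b \
    --num-groups 1 2 4 8 16 32 \
    --total-tokens 4096 \
    --model-max-length 512
```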
38 | ### Reference
39 | https://arxiv.org/pdf/2305.13245v1.pdf, Figure 6
40 |
41 | ### Results
42 |
43 | Again, runtime clearly scales with the number of GQA groups in much the same way as in the original paper.
44 |
45 | #### This repo
46 |
47 |
48 |
49 | #### Original paper
50 |
51 |
52 |
--------------------------------------------------------------------------------
/scripts/benchmark_attention.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Sequence
3 |
4 | import plotly.graph_objects as go
5 | import torch
6 | import xformers.ops as xops
7 |
8 | from grouped_query_attention_pytorch.attention import (
9 | scaled_dot_product_gqa,
10 | )
11 | from grouped_query_attention_pytorch.utils.benchmark import BenchmarkResult, benchmark
12 |
13 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
15 |
16 | # Benchmarking parameters
17 | NUM_GROUPS = [1, 4, 8, 16, 32, 64]
18 | TOTAL_TOKENS = 8192
19 | NUM_HEADS = 64
20 | EMBED_DIM = 8
21 | SEQ_LENGTH = 2048
22 | SAVE_PATH = os.path.join("doc", "benchmark_attention.png")
23 |
24 | BENCHMARK_SETTINGS_TEMPLATE = """
25 | Benchmark settings:
26 | device: {device}
27 | dtype: {dtype}
28 | total_tokens: {total_tokens}
29 | seq_length: {seq_length}
30 | num_heads: {num_heads}
31 | embed_dim: {embed_dim}
32 | """
33 |
34 |
35 | def main(
36 | num_groups: Sequence[int] = NUM_GROUPS,
37 | total_tokens: int = TOTAL_TOKENS,
38 | seq_length: int = SEQ_LENGTH,
39 | num_heads: int = NUM_HEADS,
40 | embed_dim: int = EMBED_DIM,
41 | device: torch.device = DEVICE,
42 | dtype: torch.dtype = DTYPE,
43 | save_path: str = SAVE_PATH,
44 | ):
45 | batch_size = total_tokens // seq_length
46 | q = torch.randn(
47 | batch_size, seq_length, num_heads, embed_dim, device=device, dtype=dtype
48 | )
49 | kv = torch.randn(
50 | batch_size, seq_length, num_heads, embed_dim, device=device, dtype=dtype
51 | )
52 |
53 | _ = scaled_dot_product_gqa(q, kv, kv)
54 | vanilla_result = benchmark(scaled_dot_product_gqa, q, kv, kv)
55 | print(f"Vanilla: {vanilla_result}")
56 | xformers_result = benchmark(xops.memory_efficient_attention, q, kv, kv)
57 | print(f"Flash Attn: {xformers_result}")
58 |
59 | grouped_times: List[BenchmarkResult] = []
60 | for g in num_groups:
61 | kv = torch.randn(
62 | batch_size, seq_length, g, embed_dim, dtype=dtype, device=device
63 | )
64 | grouped_result = benchmark(
65 | scaled_dot_product_gqa, q, kv, kv, force_grouped=True
66 | )
67 | grouped_times.append(grouped_result)
68 | print(f"Grouped (g={g}): {grouped_result}")
69 |
70 | fig = go.Figure()
71 | fig.add_trace(
72 | go.Scatter(
73 | x=num_groups,
74 | y=[vanilla_result.mean * 1000] * len(num_groups),
75 | mode="lines",
76 | line={"dash": "dash"},
77 | name="Vanilla MHA",
78 | )
79 | )
80 | fig.add_trace(
81 | go.Scatter(
82 | x=num_groups,
83 | y=[r.mean * 1000 for r in grouped_times],
84 | mode="lines",
85 | name="GQA",
86 | )
87 | )
88 | fig.add_trace(
89 | go.Scatter(
90 | x=num_groups,
91 | y=[grouped_times[0].mean * 1000] * len(num_groups),
92 | mode="lines",
93 | line={"dash": "dash"},
94 | name="MQA",
95 | )
96 | )
97 | fig.add_trace(
98 | go.Scatter(
99 | x=num_groups,
100 | y=[xformers_result.mean * 1000] * len(num_groups),
101 | mode="lines",
102 | line={"dash": "dash"},
103 | name="Flash Attn (v1 - xformers)",
104 | )
105 | )
106 | fig.update_layout(
107 | title="Attention Benchmarks",
108 | xaxis_title="GQA Groups",
109 | yaxis_title="Runtime (ms)",
110 | # use log-scale for x-axis
111 | xaxis={"tickmode": "array", "tickvals": num_groups, "type": "log"},
112 | # place legend at center-left
113 | legend={"x": 0.1, "y": 0.5},
114 | )
115 | fig.write_image(save_path)
116 |
117 |
118 | if __name__ == "__main__":
119 | import argparse
120 |
121 | parser = argparse.ArgumentParser()
122 | parser.add_argument(
123 | "--num-groups",
124 | type=int,
125 | nargs="+",
126 | default=NUM_GROUPS,
127 | help="Sequence of GQA group sizes for benchmarking (default: [1, 4, 8, 16, 32, 64])",
128 | )
129 | parser.add_argument(
130 | "--total-tokens",
131 | type=int,
132 | default=TOTAL_TOKENS,
133 | help="Total number of tokens in the batch (default: 8192)",
134 | )
135 | parser.add_argument(
136 | "--seq-length",
137 | type=int,
138 | default=SEQ_LENGTH,
139 | help="Sequence length of the input (default: 2048)",
140 | )
141 | parser.add_argument(
142 | "--num-heads",
143 | type=int,
144 | default=NUM_HEADS,
145 | help="Number of attention heads (default: 64)",
146 | )
147 | parser.add_argument(
148 | "--embed-dim",
149 | type=int,
150 | default=EMBED_DIM,
151 | help="Embedding dimension of the input (default: 8)",
152 | )
153 | parser.add_argument(
154 | "--device",
155 | type=str,
156 | default=DEVICE,
157 | help="Device to run the benchmark on (default: 'cuda' if available else 'cpu')",
158 | )
159 | parser.add_argument(
160 | "--dtype",
161 | type=str,
162 | default=DTYPE,
163 | help="Data type to run the benchmark on (default: 'float16' if cuda is available else 'float32')",
164 | )
165 | parser.add_argument(
166 | "--save-path",
167 | type=str,
168 | default=SAVE_PATH,
169 | help="Path to save the benchmark plot (default: 'doc/benchmark_attention.png')",
170 | )
171 | args = parser.parse_args()
172 |
173 | print(
174 | BENCHMARK_SETTINGS_TEMPLATE.format(
175 | device=args.device,
176 | dtype=args.dtype,
177 | total_tokens=args.total_tokens,
178 | seq_length=args.seq_length,
179 | num_heads=args.num_heads,
180 | embed_dim=args.embed_dim,
181 | )
182 | )
183 |
184 | main(**vars(args))
185 |
--------------------------------------------------------------------------------
/scripts/benchmark_t5.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Sequence
3 |
4 | import plotly.graph_objects as go
5 | import torch
6 | from torch import Tensor
7 | from transformers import T5ForConditionalGeneration, T5Tokenizer
8 |
9 | from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa
10 | from grouped_query_attention_pytorch.utils.benchmark import BenchmarkResult, benchmark
11 |
12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
14 |
15 | # Benchmarking parameters
16 | MODEL_NAME = "t5-3b"
17 | NUM_GROUPS = [1, 2, 4, 8, 16, 32]
18 | TOTAL_TOKENS = 4096
19 | MODEL_MAX_LENGTH = 512
20 | SAVE_PATH = os.path.join("doc", "benchmark_t5.png")
21 |
22 | BENCHMARK_SETTINGS_TEMPLATE = """
23 | Benchmark settings:
24 | device: {device}
25 | dtype: {dtype}
26 | model_name: {model_name}
27 | total_tokens: {total_tokens}
28 | model_max_length: {model_max_length}
29 | """
30 | # NOTE: Text sample taken from the CNN/Daily Mail training set
31 | INPUT_TEXT = """LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains
32 | access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but
33 | he insists the money won't cast a spell on him. Daniel Radcliffe as Harry Potter in
34 | "Harry Potter and the Order of the Phoenix" To the disappointment of gossip
35 | columnists around the world, the young actor says he has no plans to fritter his
36 | cash away on fast cars, drink and celebrity parties. "I don't plan to be one of
37 | those people who, as soon as they turn 18, suddenly buy themselves a massive sports
38 | car collection or something similar," he told an Australian interviewer earlier this
39 | month. "I don't think I'll be particularly extravagant. "The things I like buying
40 | are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe
41 | will be able to gamble in a casino, buy a drink in a pub or see the horror film
42 | "Hostel: Part II," currently six places below his number one movie on the UK box
43 | office chart. Details of how he'll mark his landmark birthday are under wraps.
44 | His agent and publicist had no comment on his plans. "I'll definitely have some
45 | sort of party," he said in an interview. "Hopefully none of you will be reading
46 | about it." Radcliffe's earnings from the first five Potter films have been held
47 | in a trust fund which he has not been able to touch. Despite his growing fame and
48 | riches, the actor says he is keeping his feet firmly on the ground. "People are
49 | always looking to say 'kid star goes off the rails,'" he told reporters last month.
50 | "But I try very hard not to go that way because it would be too easy for them." His
51 | latest outing as the boy wizard in "Harry Potter and the Order of the Phoenix" is
52 | breaking records on both sides of the Atlantic and he will reprise the role in the
53 | last two films. Watch I-Reporter give her review of Potter's latest » . There is life
54 | beyond Potter, however. The Londoner has filmed a TV movie called "My Boy Jack," about
55 | author Rudyard Kipling and his son, due for release later this year. He will also
56 | appear in "December Boys," an Australian film about four boys who escape an orphanage.
57 | Earlier this year, he made his stage debut playing a tortured teenager in Peter
58 | Shaffer's "Equus." Meanwhile, he is braced for even closer media scrutiny now that
59 | he's legally an adult: "I just think I'm going to be more sort of fair game," he told
60 | Reuters. E-mail to a friend . Copyright 2007 Reuters. All rights reserved.This material
61 | may not be published, broadcast, rewritten, or redistributed."""
62 | TARGET_TEXT = """Harry Potter star Daniel Radcliffe gets £20M fortune as he turns
63 | 18 Monday. Young actor says he has no plans to fritter his cash away. Radcliffe's
64 | earnings from first five Potter films have been held in trust fund."""
65 |
66 |
67 | @torch.no_grad()
68 | def forward_fn(model: T5ForConditionalGeneration, input_ids: Tensor, labels: Tensor):
69 | return model(input_ids=input_ids, labels=labels).loss
70 |
71 |
72 | def main(
73 | model_name: str = MODEL_NAME,
74 | num_groups: Sequence[int] = NUM_GROUPS,
75 | total_tokens: int = TOTAL_TOKENS,
76 | model_max_length: int = MODEL_MAX_LENGTH,
77 | device: torch.device = DEVICE,
78 | dtype: torch.dtype = DTYPE,
79 | save_path: str = SAVE_PATH,
80 | ):
81 | print(f"Loading model and tokenizer for {model_name}...")
82 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
83 | model_name
84 | ).to(device=device, dtype=dtype)
85 | tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False)
86 |
87 | batch_size = total_tokens // model_max_length
88 | input_ids = tokenizer(
89 | [INPUT_TEXT] * batch_size,
90 | return_tensors="pt",
91 | max_length=model_max_length,
92 | truncation=True,
93 | ).input_ids.to(device=device)
94 | labels = tokenizer(
95 | [TARGET_TEXT] * batch_size,
96 | return_tensors="pt",
97 | max_length=model_max_length,
98 | truncation=True,
99 | ).input_ids.to(device=device)
100 |
101 | mha_result = benchmark(forward_fn, t5, input_ids, labels)
102 | print(f"MHA: {mha_result}")
103 | del t5
104 | torch.cuda.empty_cache()
105 |
106 | grouped_times: List[BenchmarkResult] = []
107 | for g in num_groups:
108 | print("Reloading model and converting to GQA...")
109 | t5 = T5ForConditionalGeneration.from_pretrained(model_name).to(
110 | device=device, dtype=dtype
111 | )
112 | gqa = convert_t5_to_gqa(t5, kv_heads=g, inplace=True)
113 | grouped_result = benchmark(forward_fn, gqa, input_ids, labels)
114 | grouped_times.append(grouped_result)
115 | print(f"Grouped (g={g}): {grouped_result}")
116 |
117 | del t5, gqa
118 | torch.cuda.empty_cache()
119 |
120 | fig = go.Figure()
121 | fig.add_trace(
122 | go.Scatter(
123 | x=num_groups,
124 | y=[mha_result.mean] * len(num_groups),
125 | mode="lines",
126 | line={"dash": "dash"},
127 | name="MHA",
128 | )
129 | )
130 | fig.add_trace(
131 | go.Scatter(
132 | x=num_groups,
133 | y=[r.mean for r in grouped_times],
134 | mode="lines",
135 | name="GQA",
136 | )
137 | )
138 | fig.add_trace(
139 | go.Scatter(
140 | x=num_groups,
141 | y=[grouped_times[0].mean] * len(num_groups),
142 | mode="lines",
143 | line={"dash": "dash"},
144 | name="MQA",
145 | )
146 | )
147 | fig.update_layout(
148 | title="T5 Benchmarks",
149 | xaxis_title="GQA Groups",
150 | yaxis_title="Time per sample (s)",
151 | # use log-scale for x-axis
152 | xaxis={"tickmode": "array", "tickvals": num_groups, "type": "log"},
153 | # place legend at center-left
154 | legend={"x": 0.1, "y": 0.5},
155 | )
156 | fig.write_image(save_path)
157 |
158 |
159 | if __name__ == "__main__":
160 | import argparse
161 |
162 | # NOTE: The original paper uses T5 v1.1 XL and XXL models. When I load those
163 | # models through 'transformers' without applying GQA, I get nonsense outputs.
164 | # TODO: Figure out why this is happening.
165 | # tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-large", legacy=False)
166 | # model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-large")
167 | #
168 | # In the meantime, we can use the non-Google T5 models, which seem to work fine.
169 |     # NOTE: Since the original number of heads (n_heads) must be divisible by
170 | # 'kv_heads', there are only certain values of 'kv_heads' that we can use.
171 | # The following values of 'kv_heads' should be valid:
172 | # - t5-small: 1, 2, 4, 8
173 | # - t5-base: 1, 2, 3, 4, 6, 12
174 | # - t5-large: 1, 2, 4, 8, 16
175 | # - t5-3b: 1, 2, 4, 8, 16, 32 (DEFAULT)
176 | # - t5-11b: 1, 2, 4, 8, 16, 32, 64 TODO: Check 11b values specifically
177 |
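    # The valid 'kv_heads' values listed above are just the divisors of each
    # model's head count. A small helper to compute them (an illustration only,
    # not used by the benchmark below):
    def valid_kv_heads(n_heads: int) -> List[int]:
        return [g for g in range(1, n_heads + 1) if n_heads % g == 0]

    assert valid_kv_heads(32) == [1, 2, 4, 8, 16, 32]  # t5-3b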
178 | parser = argparse.ArgumentParser()
179 | parser.add_argument(
180 | "--model-name",
181 | type=str,
182 | default=MODEL_NAME,
183 | help=f"Name of the T5 model to benchmark (default: '{MODEL_NAME}')",
184 | )
185 | parser.add_argument(
186 | "--num-groups",
187 | type=int,
188 | nargs="+",
189 | default=NUM_GROUPS,
190 | help=f"Sequence of GQA group sizes for benchmarking (default: {NUM_GROUPS})",
191 | )
192 | parser.add_argument(
193 | "--total-tokens",
194 | type=int,
195 | default=TOTAL_TOKENS,
196 | help=f"Total number of tokens in the batch (default: {TOTAL_TOKENS})",
197 | )
198 | parser.add_argument(
199 | "--model-max-length",
200 | type=int,
201 | default=MODEL_MAX_LENGTH,
202 | help=f"Sequence length of the input (default: {MODEL_MAX_LENGTH})",
203 | )
204 | parser.add_argument(
205 | "--device",
206 | type=str,
207 | default=DEVICE,
208 | help=f"Device to run the benchmark on (default: {DEVICE})",
209 | )
210 | parser.add_argument(
211 | "--dtype",
212 | type=str,
213 | default=DTYPE,
214 | help=f"Data type to run the benchmark on (default: {DTYPE})",
215 | )
216 | parser.add_argument(
217 | "--save-path",
218 | type=str,
219 | default=SAVE_PATH,
220 | help=f"Path to save the benchmark plot (default: '{SAVE_PATH}')",
221 | )
222 | args = parser.parse_args()
223 |
224 | print(
225 | BENCHMARK_SETTINGS_TEMPLATE.format(
226 | device=args.device,
227 | dtype=args.dtype,
228 | model_name=args.model_name,
229 | total_tokens=args.total_tokens,
230 | model_max_length=args.model_max_length,
231 | )
232 | )
233 |
234 | main(**vars(args))
235 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
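# Running plain 'pytest' skips any test marked 'slow'; running 'pytest --slow'
# runs *only* the slow tests and skips the fast ones. See
# 'pytest_collection_modifyitems' below.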
4 | def pytest_addoption(parser):
5 | parser.addoption("--slow", action="store_true")
6 |
7 |
8 | def pytest_configure(config):
9 | config.addinivalue_line("markers", "slow: slow to run")
10 |
11 |
12 | def pytest_collection_modifyitems(config, items):
13 | run_slow = config.getoption("--slow")
14 | skip_fast = pytest.mark.skip(reason="remove --slow option to run")
15 | skip_slow = pytest.mark.skip(reason="need --slow option to run")
16 |
17 | for item in items:
18 | if ("slow" in item.keywords) and (not run_slow):
19 | item.add_marker(skip_slow)
20 | if ("slow" not in item.keywords) and (run_slow):
21 | item.add_marker(skip_fast)
22 |
--------------------------------------------------------------------------------
/tests/test_attention.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from grouped_query_attention_pytorch.attention import (
6 | MultiheadGQA,
7 | scaled_dot_product_gqa,
8 | )
9 |
10 | torch.backends.cudnn.deterministic = True
11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 | DTYPE = torch.float64
13 | SEQ_LEN = 16
14 |
15 |
16 | @pytest.mark.parametrize("embed_dim", [64])
17 | @pytest.mark.parametrize("num_heads", [4, 8])
18 | @pytest.mark.parametrize("kv_heads", [4, 8])
19 | @pytest.mark.parametrize("is_causal", [True, False])
20 | def test_grouped_scaled_dot_product_attention(
21 | embed_dim: int,
22 | num_heads: int,
23 | kv_heads: int,
24 | is_causal: bool,
25 | ):
26 | x = torch.randn(1, SEQ_LEN, num_heads, embed_dim, device=DEVICE, dtype=DTYPE)
27 | kv = torch.randn(1, SEQ_LEN, kv_heads, embed_dim, device=DEVICE, dtype=DTYPE)
28 |
29 | if kv_heads > num_heads:
30 | with pytest.raises(ValueError):
31 | scaled_dot_product_gqa(x, kv, kv, is_causal=is_causal)
32 | return
33 |
34 | out, attn_weights = scaled_dot_product_gqa(
35 | x, kv, kv, is_causal=is_causal, need_weights=True
36 | )
37 | assert out.size(0) == 1
38 | assert out.size(1) == SEQ_LEN
39 | assert out.size(2) == num_heads
40 | assert out.size(3) == embed_dim
41 | assert attn_weights.size(0) == 1
42 | assert attn_weights.size(1) == SEQ_LEN
43 | assert attn_weights.size(2) == SEQ_LEN
44 | assert attn_weights.size(3) == num_heads
45 |
46 | # Test that grouped SDPA is equivalent to SDPA if we duplicate the KV heads.
47 | kv = kv.repeat_interleave(num_heads // kv_heads, dim=2)
48 | kv = kv.permute(0, 2, 1, 3)
49 | x = x.permute(0, 2, 1, 3)
50 | out_vanilla = F.scaled_dot_product_attention(x, kv, kv, is_causal=is_causal)
51 | out_vanilla = out_vanilla.permute(0, 2, 1, 3)
52 | torch.testing.assert_close(out, out_vanilla)
53 |
54 |
55 | @torch.no_grad()
56 | @pytest.mark.parametrize("embed_dim", [64, 128])
57 | @pytest.mark.parametrize("num_heads", [4, 8])
58 | @pytest.mark.parametrize("kv_heads", [4, 8])
59 | @pytest.mark.parametrize("is_causal", [True, False])
60 | def test_multihead_gqa(
61 | embed_dim: int,
62 | num_heads: int,
63 | kv_heads: int,
64 | is_causal: bool,
65 | ):
66 | if kv_heads > num_heads:
67 | with pytest.raises(ValueError):
68 | MultiheadGQA(embed_dim, num_heads, kv_heads)
69 | return
70 |
71 | mhda = MultiheadGQA(embed_dim, num_heads, kv_heads, device=DEVICE, dtype=DTYPE)
72 | x = torch.randn(1, SEQ_LEN, embed_dim, device=DEVICE, dtype=DTYPE)
73 |
74 | out, _ = mhda(x, x, x, is_causal=is_causal) # default: causal=False
75 | assert out.size(0) == 1
76 | assert out.size(1) == SEQ_LEN
77 | assert out.size(2) == embed_dim
78 |
--------------------------------------------------------------------------------
/tests/test_t5.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from transformers import T5ForConditionalGeneration, T5Tokenizer
4 | from transformers.models.t5.modeling_t5 import T5Attention, T5LayerFF
5 |
6 | from grouped_query_attention_pytorch.t5 import T5GQA, convert_t5_to_gqa
7 |
8 | MODEL_NAME = "t5-small"
9 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
11 | SEQ_LEN = 16
12 |
13 |
14 | # Test all of the valid kv_heads values for 't5-small'.
15 | @pytest.mark.parametrize("kv_heads", [1, 2, 4, 8])
16 | def test_convert_t5_to_gqa(kv_heads: int):
17 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
18 | MODEL_NAME
19 | ).to(device=DEVICE, dtype=DTYPE)
20 |
21 | # Just to establish that the T5 model is as expected. Check that all of the
22 | # known attention layers are of type 'T5Attention'.
23 | for block in t5.encoder.block:
24 | for layer in block.layer:
25 | if hasattr(layer, "SelfAttention"):
26 | assert isinstance(layer.SelfAttention, T5Attention)
27 | else:
28 | assert isinstance(layer, T5LayerFF)
29 | for block in t5.decoder.block:
30 | for layer in block.layer:
31 | if hasattr(layer, "SelfAttention"):
32 | assert isinstance(layer.SelfAttention, T5Attention)
33 | elif hasattr(layer, "EncDecAttention"):
34 | assert isinstance(layer.EncDecAttention, T5Attention)
35 | else:
36 | assert isinstance(layer, T5LayerFF)
37 |
38 | gqa = convert_t5_to_gqa(t5, kv_heads=kv_heads, inplace=True)
39 | # After conversion, check that all of the attention layers above have been
40 | # replaced with 'T5GQA' layers.
41 | for block in gqa.encoder.block:
42 | for layer in block.layer:
43 | if hasattr(layer, "SelfAttention"):
44 | assert isinstance(layer.SelfAttention, T5GQA)
45 | else:
46 | assert isinstance(layer, T5LayerFF)
47 |     for block in gqa.decoder.block:
48 | for layer in block.layer:
49 | if hasattr(layer, "SelfAttention"):
50 | assert isinstance(layer.SelfAttention, T5GQA)
51 | elif hasattr(layer, "EncDecAttention"):
52 | assert isinstance(layer.EncDecAttention, T5GQA)
53 | else:
54 | assert isinstance(layer, T5LayerFF)
55 |
56 | # Check that we can pass inputs/targets through the modified model, and that
57 |     # it returns a loss that is a scalar Tensor.
58 | tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False)
59 | inputs = tokenizer(
60 | "translate English to German: The house is wonderful.", return_tensors="pt"
61 | )
62 | targets = tokenizer("Das Haus ist wunderbar.", return_tensors="pt")
63 | with torch.no_grad():
64 | loss = gqa(
65 | inputs.input_ids.to(DEVICE),
66 | labels=targets.input_ids.to(DEVICE),
67 | ).loss
68 | assert isinstance(loss, torch.Tensor)
69 | assert loss.size() == ()
70 |
71 | # Check that we can generate outputs, using the usual 'generate' method.
72 | out = gqa.generate(inputs.input_ids.to(DEVICE), max_new_tokens=10)
73 | assert isinstance(out, torch.Tensor)
74 |
--------------------------------------------------------------------------------
/tests/test_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union
2 |
3 | import pytest
4 | import torch
5 |
6 | from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM
7 |
8 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
10 | SEQ_LEN = 16
11 |
12 |
13 | @pytest.mark.parametrize("d_model", [128])
14 | @pytest.mark.parametrize("nhead", [4, 8])
15 | @pytest.mark.parametrize("kv_heads", [2, 4])
16 | @pytest.mark.parametrize("num_layers", [1, 2])
17 | @pytest.mark.parametrize("dim_feedforward", [64])
18 | @pytest.mark.parametrize("dropout", [0.0])
19 | @pytest.mark.parametrize("activation", ["relu", "gelu"])
20 | @pytest.mark.parametrize("is_causal", [True, False])
21 | def test_gqa_transformer(
22 | d_model: int,
23 | nhead: int,
24 | kv_heads: int,
25 | num_layers: int,
26 | dim_feedforward: int,
27 | dropout: float,
28 | activation: Union[str, Callable],
29 | is_causal: bool,
30 | ):
31 | net = GQATransformer(
32 | d_model=d_model,
33 | nhead=nhead,
34 | kv_heads=kv_heads,
35 | num_encoder_layers=num_layers,
36 | num_decoder_layers=num_layers,
37 | dim_feedforward=dim_feedforward,
38 | dropout=dropout,
39 | activation=activation,
40 | device=DEVICE,
41 | dtype=DTYPE,
42 | )
43 | x = torch.randn(1, SEQ_LEN, d_model, device=DEVICE, dtype=DTYPE)
44 | with torch.no_grad():
45 | y = net.forward(x, is_causal=is_causal)
46 | assert y.size(0) == 1
47 | assert y.size(1) == SEQ_LEN
48 | assert y.size(2) == d_model
49 |
50 |
51 | @pytest.mark.parametrize("num_tokens", [100])
52 | @pytest.mark.parametrize("d_model", [128])
53 | @pytest.mark.parametrize("nhead", [4, 8])
54 | @pytest.mark.parametrize("kv_heads", [2, 4])
55 | @pytest.mark.parametrize("num_layers", [1, 2])
56 | @pytest.mark.parametrize("dim_feedforward", [64])
57 | @pytest.mark.parametrize("dropout", [0.0])
58 | @pytest.mark.parametrize("activation", ["relu", "gelu"])
59 | @pytest.mark.parametrize("is_causal", [True, False])
60 | def test_gqa_transformer_lm(
61 | num_tokens: int,
62 | d_model: int,
63 | nhead: int,
64 | kv_heads: int,
65 | num_layers: int,
66 | dim_feedforward: int,
67 | dropout: float,
68 | activation: Union[str, Callable],
69 | is_causal: bool,
70 | ):
71 | net = GQATransformerLM(
72 | num_tokens=num_tokens,
73 | d_model=d_model,
74 | nhead=nhead,
75 | kv_heads=kv_heads,
76 | num_encoder_layers=num_layers,
77 | num_decoder_layers=num_layers,
78 | dim_feedforward=dim_feedforward,
79 | dropout=dropout,
80 | activation=activation,
81 | device=DEVICE,
82 | dtype=DTYPE,
83 | )
84 | x = torch.randint(0, num_tokens, (1, SEQ_LEN), device=DEVICE, dtype=torch.long)
85 | with torch.no_grad():
86 | y = net.forward(x, is_causal=is_causal)
87 | assert y.size(0) == 1
88 | assert y.size(1) == SEQ_LEN
89 | assert y.size(2) == num_tokens
90 |
--------------------------------------------------------------------------------