├── .coveragerc ├── .github ├── FUNDING.yml └── workflows │ ├── publish.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── doc ├── attention.png ├── benchmark_attention.png ├── benchmark_t5.png └── benchmark_t5_original.png ├── grouped_query_attention_pytorch ├── __init__.py ├── attention.py ├── t5.py ├── transformer.py └── utils │ ├── __init__.py │ └── benchmark.py ├── pyproject.toml ├── scripts ├── README.md ├── benchmark_attention.py └── benchmark_t5.py └── tests ├── __init__.py ├── conftest.py ├── test_attention.py ├── test_t5.py └── test_transformer.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = grouped_query_attention_pytorch/ 3 | omit = grouped_query_attention_pytorch/utils/benchmark.py 4 | 5 | [report] 6 | exclude_lines = 7 | # Have to re-enable the standard pragma 8 | pragma: no cover 9 | 10 | # Don't complain about missing debug-only code: 11 | def __repr__ 12 | if self\.debug 13 | 14 | # Don't complain if tests don't hit defensive assertion code: 15 | raise AssertionError 16 | raise NotImplementedError 17 | 18 | # Don't complain if non-importable code isn't run: 19 | if __name__ == .__main__.: -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [fkodom] 4 | custom: [fkodom.substack.com] 5 | # patreon: # Replace with a single Patreon username 6 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_dispatch: {} 5 | release: 6 | types: 7 | - created 8 | 9 | env: 10 | PYTHON_VERSION: 3.9 11 | 12 | jobs: 13 | publish: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Setup Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ env.PYTHON_VERSION }} 22 | 23 | - name: Install Dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install build 27 | 28 | - name: Build Package 29 | env: 30 | GROUPED_QUERY_ATTENTION_PYTORCH_VERSION: ${{ github.event.release.tag_name }} 31 | run: python -m build 32 | 33 | - name: Publish to PyPI 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | user: __token__ 37 | password: ${{ secrets.PYPI_API_TOKEN }} 38 | verify-metadata: false 39 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_dispatch: {} 5 | push: {} 6 | 7 | jobs: 8 | test: 9 | name: Test 10 | runs-on: ubuntu-latest 11 | continue-on-error: true 12 | 13 | strategy: 14 | matrix: 15 | python: ["3.8", "3.9", "3.10"] 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | 21 | - name: Setup Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python }} 25 | 26 | - name: Templatize 27 | # Templatize the repo before attempting to run tests. (Tests will fail 28 | # otherwise, due to syntax errors.) 29 | # NOTE: Check if the 'templatize' script exists, so that this doesn't 30 | # immediately fail for repos that have already run the script. 
31 | run: | 32 | if test -f "templatize"; then 33 | ./templatize 34 | fi 35 | 36 | - name: Install Package 37 | run: | 38 | pip install -e .[test,t5] 39 | 40 | - name: Test 41 | run: | 42 | ruff . 43 | pytest --cov --cov-report term-missing --cov-fail-under 80 tests/ 44 | mypy 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: black 5 | name: black 6 | stages: [commit] 7 | language: system 8 | entry: black 9 | types: [python] 10 | 11 | - id: ruff 12 | name: ruff 13 | stages: [commit] 14 | language: system 15 | entry: ruff 16 | types: [python] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Frank Odom 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # grouped-query-attention-pytorch 2 | 3 | (Unofficial) PyTorch implementation of grouped-query attention (GQA) from [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/pdf/2305.13245.pdf) 4 | 5 | compare-attention-mechanisms 6 | 7 | ### Includes: 8 | - [x] scaled dot-product attention with GQA support. (See: [scaled_dot_product_gqa usage](#scaled_dot_product_gqa)) 9 | - [x] GQA multi-head attention layer. (See: [MultiheadGQA usage](#multiheadgqa)) 10 | - [x] Code to convert pretrained T5 model to use GQA. 
(See: [T5 usage](#t5)) 11 | - [x] Prototype (untrained) GQA encoder-decoder models: `GQATransformer`, `GQATransformerLM` (See: [GQATransformer usage](#gqatransformer)) 12 | - [x] Reproduce runtime benchmarks from [GQA paper](https://arxiv.org/pdf/2305.13245.pdf), Figure 6 (See: [scripts/README.md](scripts/README.md)) 13 | 14 | ### To do: 15 | - [ ] Fine-tuning code for T5 GQA models 16 | - [ ] Reproduce fine-tuning results from [GQA paper](https://arxiv.org/pdf/2305.13245.pdf), Figures 3 and 5 17 | 18 | ## Install 19 | 20 | PyPI: (NOT YET AVAILABLE) 21 | ```bash 22 | pip install grouped-query-attention-pytorch 23 | ``` 24 | 25 | From source: 26 | ```bash 27 | pip install "grouped-query-attention-pytorch @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git" 28 | ``` 29 | 30 | For contributors: 31 | ```bash 32 | # Install all dev dependencies (tests, T5 support, etc.) 33 | pip install "grouped-query-attention-pytorch[test,t5] @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git" 34 | # Set up pre-commit hooks 35 | pre-commit install 36 | ``` 37 | 38 | 39 | ## Benchmark 40 | 41 | I attempt to reproduce the runtime benchmarks from the [GQA paper](https://arxiv.org/pdf/2305.13245.pdf) (Figure 6). Unfortunately, I don't have access to the same hardware, so the comparison isn't perfect. (They use multiple high-end GPUs, and I use a single 2080 Ti.) Even with different hardware, though, it is clear that runtime scales similarly with the number of GQA groups. 42 | 43 | For more details, see [scripts/README.md](scripts/README.md#benchmark_t5). 44 | 45 | > Left: This repo
Right: Original paper 46 |

47 | drawing 48 | drawing 49 |

50 | 51 | 52 | ## Usage 53 | 54 | ### `scaled_dot_product_gqa` 55 | 56 | See: [attention.py](grouped_query_attention_pytorch/attention.py) 57 | 58 | Intended to be a drop-in replacement for `F.scaled_dot_product_attention` with support for GQA. 59 | 60 | > **NOTE**: The built-in `F.scaled_dot_product_attention` will be *much* faster when you're **not** using grouped queries -- especially for `torch>=2.0`, which uses [flash attention](https://github.com/Dao-AILab/flash-attention) under the hood. However, [this benchmark](./scripts/README.md#benchmark_attention) shows that naie `scaled_dot_product_gqa` is faster than flash attention when the number of GQA groups is small. 🔥 61 | 62 | ```python 63 | import torch 64 | 65 | from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa 66 | 67 | # shapes: (batch_size, seq_len, num_heads, head_dim) 68 | query = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16) 69 | key = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16) 70 | value = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16) 71 | 72 | out, attn_weights = scaled_dot_product_gqa( 73 | query, 74 | key, 75 | value, 76 | is_causal=True, # default: False 77 | need_weights=True, # default: False, which returns 'attn_weights=None' 78 | ) 79 | print(out.shape) # (batch_size, q_seq_len, kv_heads, embed_dim) 80 | # torch.Size([1, 256, 2, 64]) 81 | print(attn_weights.shape) # (batch_size, q_seq_len, kv_seq_len, kv_heads) 82 | # torch.Size([1, 256, 128, 2]) 83 | ``` 84 | 85 | 86 | ### `MultiheadGQA` 87 | 88 | See: [attention.py](grouped_query_attention_pytorch/attention.py) 89 | 90 | Intended to be a drop-in replacement for `nn.MultiheadAttention` with support for GQA. 91 | 92 | > **NOTE**: The same performance advice from [scaled_dot_product_gqa](#scaled_dot_product_gqa) (above) applies here as well. 93 | 94 | ```python 95 | from grouped_query_attention_pytorch.attention import MultiheadGQA 96 | 97 | mha = MultiheadGQA( 98 | embed_dim=512, query_heads=8, kv_heads=2, device="cuda", dtype=torch.float16 99 | ) 100 | 101 | # shapes: (batch_size, seq_len, embed_dim) 102 | query = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16) 103 | key = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16) 104 | value = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16) 105 | 106 | out, attn_weights = mha( 107 | query, 108 | key, 109 | value, 110 | is_causal=True, # default: False 111 | need_weights=True, # default: False, which returns 'attn_weights=None' 112 | ) 113 | print(out.shape) # (batch_size, q_seq_len, embed_dim) 114 | # torch.Size([1, 256, 512]) 115 | print(attn_weights.shape) # (batch_size, q_seq_len, kv_seq_len, kv_heads) 116 | # torch.Size([1, 256, 128, 2]) 117 | ``` 118 | 119 | 120 | ### T5 121 | 122 | See: [t5.py](grouped_query_attention_pytorch/t5.py) 123 | 124 | Convert a pretrained T5 model from [huggingface/transformers](https://github.com/huggingface/transformers) to use GQA. The resulting model can be used and trained with the Huggingface Transformers library, just like an ordinary T5 model. 
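Under the hood, `convert_t5_to_gqa` swaps every `T5Attention` module for a `T5GQA` module (see [t5.py](grouped_query_attention_pytorch/t5.py)) that reuses the pretrained projection weights verbatim; the grouping happens at attention time, where the projected key/value heads (and the relative position bias) are mean-pooled down to `kv_heads` heads. A minimal sketch of that pooling step, with illustrative tensor names and shapes that are not part of the library:

```python
import torch
from einops import rearrange

n_heads, kv_heads = 8, 2
# Projected key states: (batch, n_heads, seq_len, head_dim)
key_states = torch.randn(1, n_heads, 128, 64)

# Split the head axis into (groups, kv_heads) and average over the groups,
# mirroring the einops pattern used in 'T5GQA.forward'.
grouped_keys = rearrange(key_states, "b (g h) s d -> b g h s d", h=kv_heads).mean(dim=1)
print(grouped_keys.shape)  # torch.Size([1, 2, 128, 64])
```

The example below shows the end-to-end conversion and generation workflow.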
125 | 126 | ```python 127 | from transformers import T5ForConditionalGeneration, T5Tokenizer 128 | 129 | from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa 130 | 131 | # Initialize a pre-trained T5 model 132 | t5 = T5ForConditionalGeneration.from_pretrained("t5-small") 133 | tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False) 134 | # Convert attention layers to GQA 135 | t5_gqa = convert_t5_to_gqa(t5, kv_heads=2, inplace=False) # default: inplace=False 136 | 137 | # Generate some text with the converted model 138 | input_ids = tokenizer( 139 | "translate English to German: The house is wonderful.", return_tensors="pt" 140 | ).input_ids 141 | outputs = t5_gqa.generate(input_ids, max_new_tokens=25) 142 | text = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) 143 | print(text) 144 | # The correct answer is: ['', 'Das', 'Haus', 'ist', 'wunderbar', '.', ''] 145 | # NOTE: The original T5 model produces this answer, and so does GQA when we use the 146 | # maximum number of KV heads (kv_heads=8 in this example), which effectively makes 147 | # GQA equivalent to the original T5 model with MHA. The text quickly degrades as 148 | # we reduce the number of heads. 149 | ``` 150 | 151 | ### GQATransformer 152 | 153 | I also provide a prototype implementation of an (untrained) encoder-decoder Transformer model, which uses GQA instead of MHA. This is mostly for reference/educational purposes, but in principle it could be used as a drop-in replacement for `nn.Transformer`. 154 | 155 | See: [transformer.py](grouped_query_attention_pytorch/transformer.py) 156 | 157 | ```python 158 | from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM 159 | 160 | device = torch.device("cuda") 161 | dtype = torch.float16 162 | 163 | net = GQATransformer( 164 | d_model=512, # required 165 | nhead=8, # required 166 | kv_heads=2, # required 167 | num_encoder_layers=6, 168 | num_decoder_layers=6, 169 | dim_feedforward=2048, 170 | dropout=0.1, 171 | activation="relu", 172 | layer_norm_eps=1e-5, 173 | device=device, 174 | dtype=dtype, 175 | ) 176 | # shape: (batch_size, seq_len, d_model) 177 | x = torch.randn(1, 256, 512, device=device, dtype=dtype) 178 | with torch.no_grad(): 179 | y = net.forward(x, is_causal=True) # default: is_causal=True 180 | print(y.shape) 181 | # torch.Size([1, 256, 512]) 182 | 183 | num_tokens = 10000 # usually obtained from the tokenizer 184 | lm = GQATransformerLM( 185 | num_tokens=num_tokens, # required 186 | d_model=512, # required 187 | nhead=8, # required 188 | kv_heads=2, # required 189 | num_encoder_layers=6, 190 | num_decoder_layers=6, 191 | dim_feedforward=2048, 192 | dropout=0.1, 193 | activation="relu", 194 | layer_norm_eps=1e-5, 195 | device=device, 196 | dtype=dtype, 197 | ) 198 | # shape: (batch_size, seq_len) 199 | x = torch.randint(0, num_tokens, (1, 256), device=device, dtype=torch.long) 200 | with torch.no_grad(): 201 | y = lm.forward(x, is_causal=True) # default: is_causal=True 202 | print(y.shape) 203 | # torch.Size([1, 256, num_tokens]) 204 | ``` 205 | -------------------------------------------------------------------------------- /doc/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/attention.png -------------------------------------------------------------------------------- /doc/benchmark_attention.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_attention.png -------------------------------------------------------------------------------- /doc/benchmark_t5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_t5.png -------------------------------------------------------------------------------- /doc/benchmark_t5_original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/doc/benchmark_t5_original.png -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from subprocess import getoutput 3 | 4 | 5 | def get_version_tag() -> str: 6 | try: 7 | env_key = "GROUPED_QUERY_ATTENTION_PYTORCH_VERSION".upper() 8 | version = os.environ[env_key] 9 | except KeyError: 10 | version = getoutput("git describe --tags --abbrev=0") 11 | 12 | if version.lower().startswith("fatal"): 13 | version = "0.0.0" 14 | 15 | return version 16 | 17 | 18 | VERSION = get_version_tag() 19 | -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import einsum, rearrange 6 | from torch import Tensor, nn 7 | 8 | 9 | def scaled_dot_product_gqa( 10 | query: Tensor, 11 | key: Tensor, 12 | value: Tensor, 13 | dropout: float = 0.0, 14 | scale: Optional[float] = None, 15 | mask: Optional[Tensor] = None, 16 | is_causal: Optional[bool] = None, 17 | need_weights: bool = False, 18 | average_attn_weights: bool = False, 19 | force_grouped: bool = False, 20 | ): 21 | """Scaled dot product attention with support for grouped queries. 22 | 23 | Einstein notation: 24 | - b: batch size 25 | - n / s: sequence length 26 | - h: number of heads 27 | - g: number of groups 28 | - d: dimension of query/key/value 29 | 30 | Args: 31 | query: Query tensor of shape (b, n, h, d) 32 | key: Key tensor of shape (b, s, h, d) 33 | value: Value tensor of shape (b, s, h, d) 34 | dropout: Dropout probability (default: 0.0) 35 | scale: Scale factor for query (default: d_query ** 0.5) 36 | mask: Mask tensor of shape (b, n, s) or (b, s). If 'ndim == 2', the mask is 37 | applied to all 'n' rows of the attention matrix. (default: None) 38 | force_grouped: If True, apply grouped-query attention even if the number of 39 | heads is equal for query, key, and value. (default: False) 40 | 41 | Returns: 42 | 2-tuple of: 43 | - Attention output with shape (b, n, h, d) 44 | - (Optional) Attention weights with shape (b, h, n, s). Only returned if 45 | 'need_weights' is True. 46 | """ 47 | if (mask is not None) and (is_causal is not None): 48 | raise ValueError( 49 | "Only one of 'mask' and 'is_causal' should be provided, but got both." 
50 | ) 51 | elif not query.ndim == key.ndim == value.ndim == 4: 52 | raise ValueError( 53 | f"Expected query, key, and value to be 4-dimensional, but got shapes " 54 | f"{query.shape}, {key.shape}, and {value.shape}." 55 | ) 56 | 57 | # Move sequence length dimension to axis 2. 58 | # This makes the attention operations below *much* faster. 59 | query = rearrange(query, "b n h d -> b h n d") 60 | key = rearrange(key, "b s h d -> b h s d") 61 | value = rearrange(value, "b s h d -> b h s d") 62 | 63 | bq, hq, nq, dq = query.shape 64 | bk, hk, nk, dk = key.shape 65 | bv, hv, nv, dv = value.shape 66 | if not (bq == bk == bv and dq == dk == dv): 67 | raise ValueError( 68 | "Expected query, key, and value to have the same batch size (dim=0) and " 69 | f"embedding dimension (dim=3), but got query: {query.shape}, " 70 | f"key: {key.shape}, and value: {value.shape}." 71 | ) 72 | elif (hk != hv) or (nk != nv): 73 | raise ValueError( 74 | "Expected key and value to have the same size in dimensions 1 and 2, but " 75 | f"got key: {key.shape} and value: {value.shape}." 76 | ) 77 | elif hq % hk != 0: 78 | raise ValueError( 79 | "Expected query heads to be a multiple of key/value heads, but got " 80 | f"query: {query.shape} and key/value: {key.shape}." 81 | ) 82 | 83 | if scale is None: 84 | scale = query.size(-1) ** 0.5 85 | query = query / scale 86 | 87 | num_head_groups = hq // hk 88 | query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) 89 | similarity = einsum(query, key, "b g h n d, b h s d -> b g h n s") 90 | 91 | if is_causal: 92 | # Mask out the upper triangular portion of the attention matrix. This prevents 93 | # the model from attending to tokens in the future. 94 | mask = torch.ones((bq, nq, nk), device=query.device, dtype=torch.bool).tril_() 95 | 96 | if mask is not None: 97 | # Expand mask to match the shape of the attention matrix. 98 | # If mask is 2D, assume that it is applied to the key/value sequence dimension. 99 | # Else if mask is 3D, assume that it is applied to the query/key/value sequence 100 | # dimension for all attention heads. 101 | # 102 | # Users could also provide a 4D mask, which is applied to the query/key/value 103 | # sequence dimension for each attention head (though I don't have a particular 104 | # use case in mind for that). 105 | if mask.ndim == 2: 106 | mask = rearrange(mask, "b s -> b () () () s") 107 | elif mask.ndim == 3: 108 | mask = rearrange(mask, "b n s -> b () () n s") 109 | # Mask similarity values by setting them to negative infinity. This guarantees 110 | # that they will not contribute to the softmax computation below. 111 | similarity.masked_fill_(~mask, torch.finfo(similarity.dtype).min) 112 | 113 | attention = F.softmax(similarity, dim=-1) 114 | if dropout > 0.0: 115 | attention = F.dropout(attention, p=dropout) 116 | 117 | # Apply attention matrix to the value Tensor. 118 | out = einsum(attention, value, "b g h n s, b h s d -> b g h n d") 119 | # Move head dimension back to axis 2 120 | out = rearrange(out, "b g h n d -> b n (h g) d") 121 | 122 | attn_weights: Optional[Tensor] = None 123 | if need_weights: 124 | # Move the sequence dimensions back to positions 1, 2. Move the head dimension 125 | # to position 3. This more closely matches the return shape of the attention 126 | # output: (b, n, h, d). 
127 | attn_weights = rearrange(attention, "b g h n s -> b n s (h g)") 128 | if average_attn_weights: 129 | attn_weights = attn_weights.mean(dim=1) 130 | 131 | return out, attn_weights 132 | 133 | 134 | class MultiheadGQA(nn.Module): 135 | """Multi-head grouped query attention (GQA) layer. 136 | 137 | Reference: 138 | "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" 139 | https://arxiv.org/pdf/2305.13245v1.pdf 140 | 141 | GQA is a variant of multihead attention (MHA) that uses fewer write heads 142 | (key / value) than query heads. GQA can be viewed as a generalization of 143 | multi-query attention (MQA), which uses a single write head. GQA and MQA give 144 | significant speedups over standard MHA in decoder layers, with minimal loss in 145 | accuracy. In the paper, GQA is shown to be more accurate than MQA, while still 146 | having a significant speedup over MHA. 147 | 148 | NOTE: The original authors only benchmark GQA by adapting the T5 (XL or XXL) model 149 | from MHA to GQA. As a result, they do not mention parameter initialization or 150 | layer normalization strategies. I follow the best practices laid out in the 151 | MAGNETO paper, which improves Transformer performance through better parameter 152 | initialization and layer norm placement. See: 153 | https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 154 | """ 155 | 156 | def __init__( 157 | self, 158 | embed_dim: int, 159 | query_heads: int, 160 | kv_heads: int, 161 | dropout: float = 0.0, 162 | bias: bool = True, 163 | layer_norm: bool = True, 164 | layer_norm_eps: float = 1e-5, 165 | gamma_init: float = 1.0, 166 | device: Optional[Union[torch.device, str]] = None, 167 | dtype: Optional[torch.dtype] = None, 168 | ): 169 | super().__init__() 170 | self.query_heads = query_heads 171 | self.kv_heads = kv_heads 172 | self.dropout = dropout 173 | self.layer_norm = layer_norm 174 | self.gamma_init = gamma_init 175 | 176 | if self.query_heads % self.kv_heads != 0: 177 | raise ValueError( 178 | f"query_heads ({query_heads}) must be divisible by " 179 | f"kv_heads ({kv_heads})" 180 | ) 181 | elif (embed_dim % self.query_heads != 0) or (embed_dim % self.kv_heads != 0): 182 | raise ValueError( 183 | f"embed_dim ({embed_dim}) must be divisible by " 184 | f"query_heads ({query_heads}) and kv_heads ({kv_heads})" 185 | ) 186 | 187 | head_dim = embed_dim // query_heads 188 | if not head_dim % 8 == 0: 189 | raise ValueError( 190 | f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8" 191 | ) 192 | if not head_dim <= 128: 193 | raise ValueError( 194 | f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128" 195 | ) 196 | 197 | # Query projection layer is the same as in vanilla MHA. 198 | self.q_proj = nn.Linear( 199 | embed_dim, embed_dim, bias=bias, device=device, dtype=dtype 200 | ) 201 | # Key/value projection layers have a smaller output dimension, so that 202 | # the we have fewer key/value attention heads after reshaping. 203 | kv_embed_dim = embed_dim // query_heads * kv_heads 204 | self.k_proj = nn.Linear( 205 | embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype 206 | ) 207 | self.v_proj = nn.Linear( 208 | embed_dim, kv_embed_dim, bias=bias, device=device, dtype=dtype 209 | ) 210 | self.norm: Optional[nn.LayerNorm] = None 211 | if layer_norm: 212 | self.norm = nn.LayerNorm( 213 | embed_dim, eps=layer_norm_eps, device=device, dtype=dtype 214 | ) 215 | # Grouped attention output will have the same embedding dimension as the 216 | # key/value Tensors. 
So the output projection layer needs to accept the 217 | # same dimension (kv_embed_dim). 218 | self.out_proj = nn.Linear( 219 | embed_dim, embed_dim, bias=bias, device=device, dtype=dtype 220 | ) 221 | 222 | self._reset_parameters() 223 | 224 | def _reset_parameters(self): 225 | nn.init.xavier_normal_(self.q_proj.weight) 226 | if self.q_proj.bias is not None: 227 | nn.init.constant_(self.q_proj.bias, 0) 228 | nn.init.xavier_normal_(self.k_proj.weight) 229 | if self.k_proj.bias is not None: 230 | nn.init.constant_(self.k_proj.bias, 0) 231 | 232 | # NOTE: We follow the initialization strategy from MAGNETO. See: 233 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 234 | # Gain (self.gamma_init) should be provided as a keyword argument when 235 | # initializing the larger Transformer model, since it requires knowledge 236 | # of the number of encoder/decoder layers in the model. 237 | 238 | nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init) 239 | if self.v_proj.bias is not None: 240 | nn.init.constant_(self.v_proj.bias, 0) 241 | nn.init.xavier_normal_(self.out_proj.weight, gain=self.gamma_init) 242 | if self.out_proj.bias is not None: 243 | nn.init.constant_(self.out_proj.bias, 0) 244 | 245 | def forward( 246 | self, 247 | query: Tensor, 248 | key: Tensor, 249 | value: Tensor, 250 | need_weights: bool = False, 251 | # TODO 252 | # attn_mask: Optional[Tensor] = None, 253 | is_causal: bool = False, 254 | average_attn_weights: bool = False, 255 | ) -> Tuple[Tensor, Optional[Tensor]]: 256 | # Notation: 257 | # b - batch size 258 | # n - sequence length 259 | # h - number of heads 260 | # d - embedding dimension 261 | # 262 | # Input shape: (b, n, d) 263 | q: Tensor = self.q_proj(query) 264 | k: Tensor = self.k_proj(key) 265 | v: Tensor = self.v_proj(value) 266 | 267 | # Unfold 'd' dimension into 'h' separate attention heads. 268 | q = rearrange(q, "b n (h d) -> b n h d", h=self.query_heads) 269 | k = rearrange(k, "b n (h d) -> b n h d", h=self.kv_heads) 270 | v = rearrange(v, "b n (h d) -> b n h d", h=self.kv_heads) 271 | # Apply attention, then fold 'h' attention heads back into 'd'. 272 | x, attn = scaled_dot_product_gqa( 273 | query=q, 274 | key=k, 275 | value=v, 276 | # TODO 277 | # mask=attn_mask, 278 | is_causal=is_causal, 279 | need_weights=need_weights, 280 | average_attn_weights=average_attn_weights, 281 | force_grouped=False, 282 | ) 283 | x = rearrange(x, "b n h d -> b n (h d)") 284 | 285 | # NOTE: This is different from 'nn.MultiheadAttention'! We follow the MAGNETO 286 | # architecture (https://arxiv.org/pdf/2210.06423.pdf), which applies an extra 287 | # layer norm before the linear output projection. The cross-attention layer in 288 | # the MAGNETO decoder does not include this layer norm, so users have the 289 | # option to disable it (layer_norm=False). 290 | if self.layer_norm: 291 | assert self.norm is not None 292 | x = self.norm(x) 293 | # Linear projection on attention outputs. 
294 | x = self.out_proj(x) 295 | 296 | return x, attn 297 | -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/t5.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import deepcopy 4 | from typing import TypeVar, overload 5 | 6 | import torch 7 | from einops import einsum, rearrange, repeat 8 | from torch import nn 9 | from transformers.models.t5.modeling_t5 import T5Attention 10 | 11 | 12 | class T5GQA(nn.Module): 13 | def __init__( 14 | self, 15 | is_decoder: bool, 16 | d_model: int, 17 | key_value_proj_dim: int, 18 | n_heads: int, 19 | kv_heads: int, 20 | dropout: float, 21 | has_relative_attention_bias: bool, 22 | relative_attention_num_buckets: int, 23 | relative_attention_max_distance: int, 24 | ): 25 | super().__init__() 26 | if n_heads % kv_heads != 0: 27 | raise ValueError( 28 | f"n_heads ({n_heads}) must be divisible by kv_heads ({kv_heads})" 29 | ) 30 | 31 | self.is_decoder = is_decoder 32 | self.d_model = d_model 33 | self.key_value_proj_dim = key_value_proj_dim 34 | self.n_heads = n_heads 35 | # TODO: Check if we need to store 'kv_heads' and 'inner_dim' as a properties 36 | self.kv_heads = kv_heads 37 | self.dropout = dropout 38 | # NOTE: Relative attention bias typically only used in the first layer 39 | # of a `T5Stack` module. 40 | self.has_relative_attention_bias = has_relative_attention_bias 41 | self.relative_attention_num_buckets = relative_attention_num_buckets 42 | self.relative_attention_max_distance = relative_attention_max_distance 43 | 44 | self.inner_dim = self.n_heads * self.key_value_proj_dim 45 | self.kv_dim = self.kv_heads * self.key_value_proj_dim 46 | 47 | # Mesh TensorFlow initialization to avoid scaling before softmax 48 | # self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 49 | # self.k = nn.Linear(self.d_model, self.kv_dim, bias=False) 50 | # self.v = nn.Linear(self.d_model, self.kv_dim, bias=False) 51 | # self.o = nn.Linear(self.kv_dim, self.d_model, bias=False) 52 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 53 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 54 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 55 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 56 | 57 | if self.has_relative_attention_bias: 58 | self.relative_attention_bias = nn.Embedding( 59 | self.relative_attention_num_buckets, self.n_heads 60 | ) 61 | self.pruned_heads = set() # type: ignore 62 | self.gradient_checkpointing = False 63 | 64 | self._relative_position_bucket = T5Attention._relative_position_bucket 65 | 66 | @classmethod 67 | def from_t5_attention(cls, t5: T5Attention, kv_heads: int) -> T5GQA: 68 | t5_gqa = T5GQA( 69 | is_decoder=t5.is_decoder, 70 | d_model=t5.d_model, 71 | key_value_proj_dim=t5.key_value_proj_dim, 72 | n_heads=t5.n_heads, 73 | kv_heads=kv_heads, 74 | dropout=t5.dropout, 75 | has_relative_attention_bias=t5.has_relative_attention_bias, 76 | relative_attention_num_buckets=t5.relative_attention_num_buckets, 77 | relative_attention_max_distance=t5.relative_attention_max_distance, 78 | ) 79 | 80 | # Copy all of the weights verbatim from the original T5Attention module. 81 | # NOTE: In the T5 GQA implementation, all of the attention head aggregations 82 | # happen in the 'forward' method. The weights themselves are not modified. 
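# (Head grouping happens later, in 'forward': the projected key/value states and
# the relative position bias are mean-pooled down to 'kv_heads' heads, so these
# weight matrices can be copied as-is.)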
83 | t5_gqa.q.weight.data = t5.q.weight.data 84 | t5_gqa.k.weight.data = t5.k.weight.data 85 | t5_gqa.v.weight.data = t5.v.weight.data 86 | t5_gqa.o.weight.data = t5.o.weight.data 87 | if t5.has_relative_attention_bias: 88 | t5_gqa.relative_attention_bias.weight.data = ( 89 | t5.relative_attention_bias.weight.data 90 | ) 91 | 92 | return t5_gqa 93 | 94 | def forward( # noqa: C901 95 | self, 96 | hidden_states, 97 | mask=None, 98 | key_value_states=None, 99 | position_bias=None, 100 | past_key_value=None, 101 | layer_head_mask=None, 102 | query_length=None, 103 | use_cache=False, 104 | output_attentions=False, 105 | ): 106 | """ 107 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 108 | """ 109 | # Input is (batch_size, seq_length, dim) 110 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 111 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 112 | batch_size, seq_length = hidden_states.shape[:2] 113 | 114 | real_seq_length = seq_length 115 | 116 | if past_key_value is not None: 117 | if len(past_key_value) != 2: 118 | raise ValueError( 119 | f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" 120 | ) 121 | real_seq_length += ( 122 | past_key_value[0].shape[2] if query_length is None else query_length 123 | ) 124 | 125 | key_length = ( 126 | real_seq_length if key_value_states is None else key_value_states.shape[1] 127 | ) 128 | 129 | def shape(states): 130 | """projection""" 131 | # NOTE: Changed from the original definition in T5Attention. 132 | sequence_length = states.shape[1] 133 | return states.view( 134 | batch_size, sequence_length, -1, self.key_value_proj_dim 135 | ).transpose(1, 2) 136 | 137 | def unshape(states): 138 | """reshape""" 139 | # NOTE: Changed from the original definition in T5Attention. 
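# (batch_size, heads, seq_length, dim_per_head) -> (batch_size, seq_length, heads * dim_per_head)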
140 | sequence_length = states.shape[2] 141 | return ( 142 | states.transpose(1, 2) 143 | .contiguous() 144 | .view(batch_size, sequence_length, -1) 145 | ) 146 | 147 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 148 | """projects hidden states correctly to key/query states""" 149 | if key_value_states is None: 150 | # self-attn 151 | # (batch_size, n_heads, seq_length, dim_per_head) 152 | hidden_states = shape(proj_layer(hidden_states)) 153 | elif past_key_value is None: 154 | # cross-attn 155 | # (batch_size, n_heads, seq_length, dim_per_head) 156 | hidden_states = shape(proj_layer(key_value_states)) 157 | 158 | if past_key_value is not None: 159 | if key_value_states is None: 160 | # self-attn 161 | # (batch_size, n_heads, key_length, dim_per_head) 162 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 163 | elif past_key_value.shape[2] != key_value_states.shape[1]: 164 | # checking that the `sequence_length` of the `past_key_value` is the same as 165 | # the provided `key_value_states` to support prefix tuning 166 | # cross-attn 167 | # (batch_size, n_heads, seq_length, dim_per_head) 168 | hidden_states = shape(proj_layer(key_value_states)) 169 | else: 170 | # cross-attn 171 | hidden_states = past_key_value 172 | return hidden_states 173 | 174 | # get query states: (batch_size, n_heads, seq_length, dim_per_head) 175 | grouped_queries = shape(self.q(hidden_states)) 176 | # get key/value states 177 | key_states = project( 178 | hidden_states, 179 | self.k, 180 | key_value_states, 181 | past_key_value[0] if past_key_value is not None else None, 182 | ) 183 | value_states = project( 184 | hidden_states, 185 | self.v, 186 | key_value_states, 187 | past_key_value[1] if past_key_value is not None else None, 188 | ) 189 | 190 | # # compute scores 191 | # scores = torch.matmul( 192 | # query_states, key_states.transpose(3, 2) 193 | # ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 194 | grouped_queries = rearrange( 195 | grouped_queries, "b (g h) n d -> b g h n d", h=self.kv_heads 196 | ) 197 | grouped_keys = rearrange( 198 | key_states, "b (g h) s d -> b g h s d", h=self.kv_heads 199 | ).mean(dim=1) 200 | scores = einsum(grouped_queries, grouped_keys, "b g h n d, b h s d -> b h n s") 201 | 202 | if position_bias is None: 203 | if not self.has_relative_attention_bias: 204 | position_bias = torch.zeros( 205 | # NOTE: This is different from the original in T5Attention! 
206 | # (1, self.n_heads, real_seq_length, key_length), 207 | (1, self.kv_heads, real_seq_length, key_length), 208 | device=scores.device, 209 | dtype=scores.dtype, 210 | ) 211 | if self.gradient_checkpointing and self.training: 212 | position_bias.requires_grad = True 213 | else: 214 | position_bias = T5Attention.compute_bias( 215 | self, real_seq_length, key_length, device=scores.device 216 | ) 217 | 218 | # if key and values are already calculated 219 | # we want only the last query position bias 220 | if past_key_value is not None: 221 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :] 222 | 223 | if mask is not None: 224 | # (batch_size, n_heads, seq_length, key_length) 225 | position_bias = position_bias + mask 226 | 227 | if self.pruned_heads: 228 | mask = torch.ones(position_bias.shape[1]) 229 | mask[list(self.pruned_heads)] = 0 230 | position_bias_masked = position_bias[:, mask.bool()] 231 | else: 232 | position_bias_masked = position_bias 233 | 234 | # NOTE: This is different from the original in T5Attention! 235 | grouped_position_bias = rearrange( 236 | position_bias_masked, "b (g h) n s -> b g h n s", h=self.kv_heads 237 | ).mean(dim=1) 238 | 239 | scores += grouped_position_bias 240 | # attn_weights: (batch_size, kv_heads, seq_length, key_length) 241 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) 242 | attn_weights = nn.functional.dropout( 243 | attn_weights, p=self.dropout, training=self.training 244 | ) 245 | 246 | # Mask heads if we want to 247 | if layer_head_mask is not None: 248 | attn_weights = attn_weights * layer_head_mask 249 | 250 | # NOTE: This is different from the original in T5Attention! 251 | # attn_output = unshape(torch.matmul(attn_weights, value_states)) 252 | grouped_values = rearrange( 253 | value_states, "b (g h) s d -> b g h s d", h=self.kv_heads 254 | ).mean(dim=1) 255 | attn_output = unshape(torch.matmul(attn_weights, grouped_values)) 256 | attn_output = repeat( 257 | attn_output, "b s d -> b s (g d)", g=(self.n_heads // self.kv_heads) 258 | ) 259 | attn_output = self.o(attn_output) 260 | 261 | present_key_value_state = ( 262 | (key_states, value_states) if (self.is_decoder and use_cache) else None 263 | ) 264 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 265 | 266 | if output_attentions: 267 | outputs = outputs + (attn_weights,) # type: ignore 268 | return outputs 269 | 270 | 271 | ModuleType = TypeVar("ModuleType", bound=nn.Module) 272 | 273 | 274 | @overload 275 | def convert_t5_to_gqa( 276 | module: ModuleType, kv_heads: int, inplace: bool = False 277 | ) -> ModuleType: 278 | ... 279 | 280 | 281 | @overload 282 | def convert_t5_to_gqa( 283 | module: T5Attention, kv_heads: int, inplace: bool = False 284 | ) -> T5GQA: 285 | ... 286 | 287 | 288 | def convert_t5_to_gqa(module, kv_heads: int, inplace: bool = False): 289 | if isinstance(module, T5Attention): 290 | return T5GQA.from_t5_attention(module, kv_heads=kv_heads) 291 | 292 | out = module if inplace else deepcopy(module) 293 | for name, child in out.named_children(): 294 | out._modules[name] = convert_t5_to_gqa(child, kv_heads=kv_heads, inplace=True) 295 | return out 296 | 297 | 298 | if __name__ == "__main__": 299 | from transformers import T5ForConditionalGeneration, T5Tokenizer 300 | 301 | # NOTE: The original paper uses T5 v1.1 XL and XXL models. When I load those 302 | # models through 'transformers' without applying GQA, I get nonsense outputs. 303 | # TODO: Figure out why this is happening. 
304 | # tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-large", legacy=False) 305 | # model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-large") 306 | # 307 | # In the meantime, we can use the non-Google T5 models, which seem to work fine. 308 | # NOTE: Since the the original number of heads (n_heads) must be divisible by 309 | # 'kv_heads', there are only certain values of 'kv_heads' that we can use. 310 | # To the best of my knowledge, the following values of 'kv_heads' are valid: 311 | # - t5-small: 1, 2, 4, 8 312 | # - t5-base: 1, 2, 3, 4, 6, 12 313 | # - t5-large: 1, 2, 4, 8, 16 314 | # - t5-3b: 1, 2, 4, 8, 16, 32 315 | # - t5-11b: 1, 2, 4, 8, 16, 32, 64 TODO: Check 11b values specifically 316 | 317 | tokenizer = T5Tokenizer.from_pretrained( 318 | "t5-base", legacy=False, model_max_length=512 319 | ) 320 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained( 321 | "t5-base" 322 | ) 323 | gqa = convert_t5_to_gqa(t5, kv_heads=6) 324 | 325 | input_ids = tokenizer( 326 | "translate English to German: The house is wonderful.", return_tensors="pt" 327 | ).input_ids 328 | y2 = gqa.generate(input_ids, max_new_tokens=25) 329 | text = tokenizer.batch_decode(y2[0], skip_special_tokens=True) 330 | print(text) 331 | # The correct answer is: ['', 'Das', 'Haus', 'ist', 'wunderbar', '.', ''] 332 | # NOTE: The original T5 model produces this answer, and so does GQA when we use 333 | # the maximum number of heads -- effectively equivalent to the original T5 model 334 | # with MHA. The text quickly degrades as we reduce the number of heads. 335 | 336 | labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids 337 | loss = gqa(input_ids=input_ids, labels=labels).loss 338 | print(f"Loss: {loss}") 339 | # NOTE: As above, the loss quickly degrades (increases) as we reduce the number 340 | # of GQA heads. 341 | -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/transformer.py: -------------------------------------------------------------------------------- 1 | from math import log 2 | from typing import Callable, Optional, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor, nn 7 | from torch.nn.modules.transformer import _get_activation_fn 8 | from torchscale.component.xpos_relative_position import XPOS 9 | 10 | from grouped_query_attention_pytorch.attention import MultiheadGQA 11 | 12 | 13 | class GQATransformerEncoderLayer(nn.Module): 14 | # NOTE: Mostly pulled from 'nn.TransformerEncoderLayer', but with changes: 15 | # - use sub-LayerNorm like in MAGNETO. See: https://arxiv.org/abs/2210.06423 16 | # - use MultiheadGQA instead of MultiheadAttention 17 | 18 | def __init__( 19 | self, 20 | d_model: int, 21 | nhead: int, 22 | kv_heads: int, 23 | dim_feedforward: int = 2048, 24 | dropout: float = 0.1, 25 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 26 | layer_norm_eps: float = 1e-5, 27 | gamma_init: float = 1.0, 28 | device: Optional[Union[torch.device, str]] = None, 29 | dtype: Optional[torch.dtype] = None, 30 | ) -> None: 31 | super().__init__() 32 | # Legacy string support for activation function. 
33 | if isinstance(activation, str): 34 | activation = _get_activation_fn(activation) 35 | 36 | self.activation = activation 37 | self.gamma_init = gamma_init 38 | 39 | self.dropout = nn.Dropout(dropout) 40 | # Self-attention block 41 | self.norm1 = nn.LayerNorm( 42 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 43 | ) 44 | self.self_attn = MultiheadGQA( # type: ignore 45 | embed_dim=d_model, 46 | query_heads=nhead, 47 | kv_heads=kv_heads, 48 | dropout=dropout, 49 | layer_norm=True, 50 | layer_norm_eps=layer_norm_eps, 51 | gamma_init=gamma_init, 52 | device=device, 53 | dtype=dtype, 54 | ) 55 | # Feedforward block 56 | self.norm2 = nn.LayerNorm( 57 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 58 | ) 59 | self.linear1 = nn.Linear(d_model, dim_feedforward, device=device, dtype=dtype) 60 | self.norm3 = nn.LayerNorm( 61 | dim_feedforward, eps=layer_norm_eps, device=device, dtype=dtype 62 | ) 63 | self.linear2 = nn.Linear(dim_feedforward, d_model, device=device, dtype=dtype) 64 | 65 | self._reset_parameters() 66 | 67 | def _reset_parameters(self): 68 | # NOTE: We follow the initialization strategy from MAGNETO. See: 69 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 70 | # The 'MultiheadGQA' module uses ths same initialization, 71 | # so we just need to worry about the 'Linear' modules here. 72 | nn.init.xavier_normal_(self.linear1.weight, gain=self.gamma_init) 73 | nn.init.constant_(self.linear1.bias, 0) 74 | nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init) 75 | nn.init.constant_(self.linear2.bias, 0) 76 | 77 | def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor: 78 | x = self.norm1(x) 79 | x, _ = self.self_attn(x, x, x, is_causal=is_causal) 80 | x = self.dropout(x) 81 | return x 82 | 83 | def _feedforward_block(self, x: Tensor) -> Tensor: 84 | x = self.norm2(x) 85 | x = self.activation(self.linear1(x)) 86 | x = self.dropout(x) 87 | x = self.norm3(x) 88 | x = self.linear2(x) 89 | x = self.dropout(x) 90 | return x 91 | 92 | def forward(self, src: Tensor, is_causal: bool = False) -> Tensor: 93 | x = src 94 | x = x + self._self_attention_block(x, is_causal=is_causal) 95 | x = x + self._feedforward_block(x) 96 | return x 97 | 98 | 99 | class GQATransformerDecoderLayer(nn.Module): 100 | # NOTE: Mostly pulled from 'nn.TransformerDecoderLayer', but with changes: 101 | # - use sub-LayerNorm like in MAGNETO. See: https://arxiv.org/abs/2210.06423 102 | # - use MultiheadGQA instead of MultiheadAttention 103 | 104 | def __init__( 105 | self, 106 | d_model: int, 107 | nhead: int, 108 | kv_heads: int, 109 | dim_feedforward: int = 2048, 110 | dropout: float = 0.1, 111 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 112 | layer_norm_eps: float = 1e-5, 113 | gamma_init: float = 1.0, 114 | device: Optional[Union[torch.device, str]] = None, 115 | dtype: Optional[torch.dtype] = None, 116 | ) -> None: 117 | super().__init__() 118 | # Legacy string support for activation function. 
119 | if isinstance(activation, str): 120 | activation = _get_activation_fn(activation) 121 | 122 | self.activation = activation 123 | self.gamma_init = gamma_init 124 | 125 | self.dropout = nn.Dropout(dropout) 126 | # Self-attention block 127 | self.norm1 = nn.LayerNorm( 128 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 129 | ) 130 | self.self_attn = MultiheadGQA( # type: ignore 131 | embed_dim=d_model, 132 | query_heads=nhead, 133 | kv_heads=kv_heads, 134 | dropout=dropout, 135 | layer_norm=False, 136 | gamma_init=gamma_init, 137 | device=device, 138 | dtype=dtype, 139 | ) 140 | # Multi-head attention block 141 | self.norm2 = nn.LayerNorm( 142 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 143 | ) 144 | self.multihead_attn = MultiheadGQA( # type: ignore 145 | embed_dim=d_model, 146 | query_heads=nhead, 147 | kv_heads=kv_heads, 148 | dropout=dropout, 149 | layer_norm_eps=layer_norm_eps, 150 | gamma_init=gamma_init, 151 | device=device, 152 | dtype=dtype, 153 | ) 154 | # Feedforward block 155 | self.norm3 = nn.LayerNorm( 156 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 157 | ) 158 | self.linear1 = nn.Linear(d_model, dim_feedforward, device=device, dtype=dtype) 159 | self.norm4 = nn.LayerNorm( 160 | dim_feedforward, eps=layer_norm_eps, device=device, dtype=dtype 161 | ) 162 | self.linear2 = nn.Linear(dim_feedforward, d_model, device=device, dtype=dtype) 163 | 164 | self._reset_parameters() 165 | 166 | def _reset_parameters(self): 167 | # NOTE: We follow the initialization strategy from MAGNETO. See: 168 | # https://arxiv.org/pdf/2210.06423.pdf, Fig. 2 169 | # The 'MultiheadGQA' module uses ths same initialization, 170 | # so we just need to worry about the 'Linear' modules here. 171 | nn.init.xavier_normal_(self.linear1.weight, gain=self.gamma_init) 172 | nn.init.constant_(self.linear1.bias, 0) 173 | nn.init.xavier_normal_(self.linear2.weight, gain=self.gamma_init) 174 | nn.init.constant_(self.linear2.bias, 0) 175 | 176 | def _self_attention_block(self, x: Tensor, is_causal: bool = False) -> Tensor: 177 | x = self.norm1(x) 178 | x, _ = self.self_attn(x, x, x, is_causal=is_causal) 179 | x = self.dropout(x) 180 | return x 181 | 182 | def _multihead_attention_block( 183 | self, x: Tensor, memory: Tensor, is_causal: bool = False 184 | ) -> Tensor: 185 | x = self.norm2(x) 186 | x, _ = self.multihead_attn(x, memory, memory, is_causal=is_causal) 187 | x = self.dropout(x) 188 | return x 189 | 190 | def _feedforward_block(self, x: Tensor) -> Tensor: 191 | x = self.norm3(x) 192 | x = self.activation(self.linear1(x)) 193 | x = self.dropout(x) 194 | x = self.norm4(x) 195 | x = self.linear2(x) 196 | x = self.dropout(x) 197 | return x 198 | 199 | def forward( 200 | self, 201 | tgt: Tensor, 202 | memory: Tensor, 203 | tgt_is_causal: bool = False, 204 | memory_is_causal: bool = False, 205 | ) -> Tensor: 206 | x = tgt 207 | x = x + self._self_attention_block(x, is_causal=tgt_is_causal) 208 | x = x + self._multihead_attention_block(x, memory, is_causal=memory_is_causal) 209 | x = x + self._feedforward_block(x) 210 | return x 211 | 212 | 213 | class GQATransformer(nn.Module): 214 | def __init__( 215 | self, 216 | d_model: int = 512, 217 | nhead: int = 8, 218 | kv_heads: int = 4, 219 | num_encoder_layers: int = 6, 220 | num_decoder_layers: int = 6, 221 | dim_feedforward: int = 2048, 222 | dropout: float = 0.1, 223 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 224 | layer_norm_eps: float = 1e-5, 225 | device: Optional[torch.device] = None, 226 | dtype: 
Optional[torch.dtype] = None, 227 | ): 228 | super().__init__() 229 | # The 'gamma_init' parameters are different for the encoder and decoder, 230 | # and depend on the number of encoder/decoder layers. See MAGNETO paper: 231 | # https://arxiv.org/pdf/2210.06423.pdf, Figure 2 232 | encoder_gamma_init = ( 233 | log(3 * num_decoder_layers) * log(2 * num_encoder_layers) / 3 234 | ) ** 0.5 235 | decoder_gamma_init = log(3 * num_decoder_layers) ** 0.5 236 | 237 | self.encoder = nn.TransformerEncoder( 238 | encoder_layer=GQATransformerEncoderLayer( 239 | d_model=d_model, 240 | nhead=nhead, 241 | kv_heads=kv_heads, 242 | dim_feedforward=dim_feedforward, 243 | dropout=dropout, 244 | activation=activation, 245 | layer_norm_eps=layer_norm_eps, 246 | gamma_init=encoder_gamma_init, 247 | device=device, 248 | dtype=dtype, 249 | ), 250 | num_layers=num_encoder_layers, 251 | mask_check=False, 252 | enable_nested_tensor=False, 253 | ) 254 | self.decoder = nn.TransformerDecoder( 255 | decoder_layer=GQATransformerDecoderLayer( 256 | d_model=d_model, 257 | nhead=nhead, 258 | kv_heads=kv_heads, 259 | dim_feedforward=dim_feedforward, 260 | dropout=dropout, 261 | activation=activation, 262 | layer_norm_eps=layer_norm_eps, 263 | gamma_init=decoder_gamma_init, 264 | device=device, 265 | dtype=dtype, 266 | ), 267 | num_layers=num_decoder_layers, 268 | ) 269 | 270 | def forward(self, x: Tensor, is_causal: bool = True) -> Tensor: 271 | """ 272 | Input shape: (batch_size, seq_len, d_model) 273 | Output shape: (batch_size, seq_len, d_model) 274 | 275 | NOTE: Assume that 'is_causal' applies to both the encoder and decoder. 276 | This is the case for language modeling, but maybe not for other tasks. 277 | """ 278 | tgt = x 279 | for layer in self.encoder.layers: 280 | x = layer(x, is_causal=is_causal) 281 | if self.encoder.norm is not None: 282 | x = self.encoder.norm(x) 283 | 284 | mem = x 285 | for layer in self.decoder.layers: 286 | tgt = layer(tgt, mem, memory_is_causal=is_causal, tgt_is_causal=is_causal) 287 | if self.decoder.norm is not None: 288 | tgt = self.decoder.norm(tgt) 289 | 290 | return tgt 291 | 292 | 293 | class GQATransformerLM(nn.Module): 294 | def __init__( 295 | self, 296 | num_tokens: int, # (required) usually obtained from the tokenizer 297 | d_model: int = 512, 298 | nhead: int = 8, 299 | kv_heads: int = 4, 300 | num_encoder_layers: int = 6, 301 | num_decoder_layers: int = 6, 302 | dim_feedforward: int = 2048, 303 | dropout: float = 0.1, 304 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 305 | layer_norm_eps: float = 1e-5, 306 | device: Optional[torch.device] = None, 307 | dtype: Optional[torch.dtype] = None, 308 | ): 309 | super().__init__() 310 | self.token_embedding = nn.Embedding( 311 | num_tokens, d_model, device=device, dtype=dtype 312 | ) 313 | # TODO: Add support for other positional encodings? I use XPOS, which is the 314 | # "latest and greatest" at the time of writing. In principle, we could swap 315 | # it out for any other encoding, and remove the 'torchscale' dependency for this 316 | # repo, which is only used for XPOS. 
317 | self.pos_embedding = XPOS(d_model).to(device=device, dtype=dtype) 318 | self.transformer = GQATransformer( 319 | d_model=d_model, 320 | nhead=nhead, 321 | kv_heads=kv_heads, 322 | num_encoder_layers=num_encoder_layers, 323 | num_decoder_layers=num_decoder_layers, 324 | dim_feedforward=dim_feedforward, 325 | dropout=dropout, 326 | activation=activation, 327 | layer_norm_eps=layer_norm_eps, 328 | device=device, 329 | dtype=dtype, 330 | ) 331 | self.norm = nn.LayerNorm( 332 | d_model, eps=layer_norm_eps, device=device, dtype=dtype 333 | ) 334 | self.out = nn.Linear(d_model, num_tokens, device=device, dtype=dtype) 335 | 336 | def _reset_parameters(self): 337 | nn.init.kaiming_normal_(self.out.weight) 338 | nn.init.constant_(self.out.bias, 0) 339 | 340 | def forward(self, x: Tensor, is_causal: bool = True) -> Tensor: 341 | x = self.token_embedding(x) 342 | x = x + self.pos_embedding(x) 343 | x = self.transformer(x, is_causal=is_causal) 344 | x = self.norm(x) 345 | return self.out(x) 346 | 347 | 348 | if __name__ == "__main__": 349 | num_tokens = 2048 350 | device = torch.device("cuda") 351 | dtype = torch.float16 352 | 353 | x = torch.randint(0, num_tokens - 1, size=(2, 512), device=device) 354 | model = GQATransformerLM(num_tokens=num_tokens, device=device, dtype=dtype) 355 | 356 | with torch.no_grad(): 357 | out = model(x) 358 | print(out.shape) 359 | -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/grouped_query_attention_pytorch/utils/__init__.py -------------------------------------------------------------------------------- /grouped_query_attention_pytorch/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from timeit import Timer 3 | from typing import Callable, List, NamedTuple 4 | 5 | import torch 6 | 7 | 8 | class BenchmarkResult(NamedTuple): 9 | mean: float 10 | std: float 11 | 12 | def __repr__(self): 13 | return f"BenchmarkResult(mean: {self.mean:.3e}, std: {self.std:.3e})" 14 | 15 | def __str__(self): 16 | return f"({self.mean:.3e} \u00B1 {self.std:.3e}) s" 17 | 18 | 19 | @torch.no_grad() 20 | def benchmark( 21 | fn: Callable, 22 | *args, 23 | min_total_seconds: float = 1.0, 24 | min_iterations: int = 10, 25 | **kwargs, 26 | ) -> BenchmarkResult: 27 | # Benchmark the runtime of a function and dynamically determine the number of 28 | # iterations to run. Continue running the function until *total* runtime 29 | # exceeds 'min_total_seconds' and 'min_iterations'. 
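# NOTE: The Timer statement below calls torch.cuda.synchronize() after each call,
# so this helper assumes a CUDA device and measures end-to-end (GPU-inclusive) time.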
30 | if min_iterations < 2: 31 | raise ValueError("min_iterations must be >= 2") 32 | 33 | timer = Timer( 34 | "fn(*args, **kwargs); synchronize()", 35 | globals={ 36 | "fn": fn, 37 | "args": args, 38 | "kwargs": kwargs, 39 | "synchronize": torch.cuda.synchronize, 40 | }, 41 | ) 42 | # Run the function 5 times to warm up 43 | _ = timer.repeat(number=1, repeat=5) 44 | 45 | times: List[float] = [] 46 | total_time = 0.0 47 | num_iterations = min_iterations 48 | 49 | while total_time < min_total_seconds: 50 | _times = timer.repeat(number=1, repeat=num_iterations) 51 | times.extend(_times) 52 | 53 | times_tensor = torch.as_tensor(times) 54 | total_time = times_tensor.sum().item() 55 | avg_time = times_tensor.mean().item() 56 | num_iterations = ceil((min_total_seconds - total_time) / avg_time) 57 | 58 | times_tensor = torch.as_tensor(times) 59 | return BenchmarkResult( 60 | mean=times_tensor.mean().item(), 61 | std=times_tensor.std().item(), 62 | ) 63 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | 4 | [project] 5 | name = "grouped_query_attention_pytorch" 6 | authors = [ 7 | {name = "Frank Odom", email = "frank.odom.iii@gmail.com"}, 8 | ] 9 | description = "grouped-query-attention-pytorch" 10 | license = {text = "MIT"} 11 | dynamic = [ 12 | "version", 13 | "readme", 14 | ] # NOTE: Must be in sync with [tool.setuptools.dynamic] below 15 | dependencies = [ 16 | # TODO: Check the full range of supported versions 17 | "einops~=0.6.0", 18 | "torch>=1.8.0", 19 | "torchscale~=0.2.0", 20 | ] 21 | requires-python = ">=3.8" 22 | classifiers = ["Programming Language :: Python :: 3"] 23 | 24 | [tool.setuptools.dynamic] 25 | # NOTE: Must be in sync with 'project.dynamic' above 26 | version = {attr = "grouped_query_attention_pytorch.VERSION"} 27 | readme = {file = ["README.md"], content-type = "text/markdown"} 28 | 29 | [tool.setuptools.packages.find] 30 | exclude = ["tests"] 31 | 32 | # --- extra packages --- 33 | [project.optional-dependencies] 34 | test = [ 35 | "black", 36 | "kaleido", 37 | "mypy", 38 | "plotly", 39 | "pre-commit", 40 | "pytest", 41 | "pytest-cov", 42 | "ruff", 43 | "xformers==0.0.20", 44 | ] 45 | t5 = [ 46 | "sentencepiece", 47 | "transformers>=4.5.0,<4.32", 48 | ] 49 | 50 | [project.scripts] 51 | # Entrypoint scripts 52 | 53 | 54 | # ----- Linting, Formatting, and Typing ----- 55 | 56 | [tool.black] 57 | line-length = 88 58 | 59 | [tool.mypy] 60 | files = "grouped_query_attention_pytorch/" 61 | check_untyped_defs = "true" 62 | ignore_missing_imports = "true" 63 | 64 | [tool.pytest.ini_options] 65 | testpaths = ["tests"] 66 | addopts = "--cov --cov-report term-missing --cov-fail-under 80" 67 | filterwarnings = "ignore:.*.:DeprecationWarning" 68 | 69 | [tool.ruff] 70 | line-length = 88 71 | ignore = ["B905", "E501"] 72 | select = ["B", "C", "E", "F", "I", "W"] 73 | # Exclude a variety of commonly ignored directories. 
74 | exclude = [
75 |     ".bzr",
76 |     ".direnv",
77 |     ".eggs",
78 |     ".git",
79 |     ".hg",
80 |     ".mypy_cache",
81 |     ".nox",
82 |     ".pants.d",
83 |     ".ruff_cache",
84 |     ".svn",
85 |     ".tox",
86 |     ".venv",
87 |     "__pypackages__",
88 |     "_build",
89 |     "buck-out",
90 |     "build",
91 |     "dist",
92 |     "node_modules",
93 |     "venv",
94 | ]
95 | 
96 | [tool.ruff.mccabe]
97 | max-complexity = 18
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # scripts
2 | 
3 | Scripts should be launched from the root of this repository. Ex:
4 | 
5 | ```bash
6 | python -m scripts.benchmark_attention
7 | ```
8 | 
9 | 
10 | ## benchmark_attention
11 | 
12 | Gather runtime benchmarks for scaled dot product attention with grouped queries. In particular, we want to see how the runtime scales with the number of GQA groups. We compare vanilla (naive) attention to grouped attention. We also compare against `xformers.ops.memory_efficient_attention` as a strong baseline.
13 | 
14 | ### Reference
15 | The original paper benchmarks end-to-end runtime of the T5 model (https://arxiv.org/pdf/2305.13245v1.pdf, Figure 6). We do that in a separate benchmark (see `scripts/benchmark_t5.py`). Here, we focus on the attention layer itself.
16 | 
17 | ### Results
18 | 
19 | Runtime clearly scales with the number of GQA groups in much the same way as in the original paper.
20 | 
21 | Even though `xformers` is much faster than the naive implementation, GQA is still faster when the number of groups is small. Hopefully, someone will write an efficient CUDA implementation for GQA, so we can get the best of both worlds. Unfortunately, I likely don't have the CUDA experience to do it myself. :(
22 | 
23 | #### This repo (attention layer only)
24 | 
25 | drawing
26 | 
27 | #### Original paper (end-to-end T5)
28 | 
29 | drawing
30 | 
31 | 
32 | ## benchmark_t5
33 | 
34 | Similar to [benchmark_attention](#benchmark_attention) above, but we benchmark the end-to-end T5 model. We compare the original T5 implementation (which uses MHA) to the same model converted to GQA.
35 | 
36 | The same hardware differences apply as above. The original paper benchmarked T5-XXL (11B params), which does not fit in my GPU memory. Instead, I benchmark T5-3B, which is the largest T5 variant that will fit in memory. T5-3B only has 32 attention heads, so my benchmarks only go up to 32 GQA groups. (The original benchmarks with T5-XXL go up to 64 groups.) I use an input sequence length of 512 (similar to training) and a batch size of 8 (original uses 32).
37 | 
38 | ### Reference
39 | https://arxiv.org/pdf/2305.13245v1.pdf, Figure 6
40 | 
41 | ### Results
42 | 
43 | Again, it's clear that runtime scales with the number of GQA groups much as it does in the original paper.
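For reference, here is a minimal sketch of the conversion step that this benchmark times. It uses `t5-small` only to keep the example light (the benchmark itself uses `t5-3b`), and random token ids in place of a real tokenized batch:

```python
import torch
from transformers import T5ForConditionalGeneration

from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa

# t5-small has 8 attention heads, so kv_heads must be one of {1, 2, 4, 8}.
t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
gqa = convert_t5_to_gqa(t5, kv_heads=4, inplace=True)

input_ids = torch.randint(0, t5.config.vocab_size, (1, 16))
with torch.no_grad():
    loss = gqa(input_ids=input_ids, labels=input_ids).loss
print(loss)  # scalar Tensor
```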
44 | 45 | #### This repo 46 | 47 | drawing 48 | 49 | #### Original paper 50 | 51 | drawing 52 | -------------------------------------------------------------------------------- /scripts/benchmark_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Sequence 3 | 4 | import plotly.graph_objects as go 5 | import torch 6 | import xformers.ops as xops 7 | 8 | from grouped_query_attention_pytorch.attention import ( 9 | scaled_dot_product_gqa, 10 | ) 11 | from grouped_query_attention_pytorch.utils.benchmark import BenchmarkResult, benchmark 12 | 13 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 15 | 16 | # Benchmarking parameters 17 | NUM_GROUPS = [1, 4, 8, 16, 32, 64] 18 | TOTAL_TOKENS = 8192 19 | NUM_HEADS = 64 20 | EMBED_DIM = 8 21 | SEQ_LENGTH = 2048 22 | SAVE_PATH = os.path.join("doc", "benchmark_attention.png") 23 | 24 | BENCHMARK_SETTINGS_TEMPLATE = """ 25 | Benchmark settings: 26 | device: {device} 27 | dtype: {dtype} 28 | total_tokens: {total_tokens} 29 | seq_length: {seq_length} 30 | num_heads: {num_heads} 31 | embed_dim: {embed_dim} 32 | """ 33 | 34 | 35 | def main( 36 | num_groups: Sequence[int] = NUM_GROUPS, 37 | total_tokens: int = TOTAL_TOKENS, 38 | seq_length: int = SEQ_LENGTH, 39 | num_heads: int = NUM_HEADS, 40 | embed_dim: int = EMBED_DIM, 41 | device: torch.device = DEVICE, 42 | dtype: torch.dtype = DTYPE, 43 | save_path: str = SAVE_PATH, 44 | ): 45 | batch_size = total_tokens // seq_length 46 | q = torch.randn( 47 | batch_size, seq_length, num_heads, embed_dim, device=device, dtype=dtype 48 | ) 49 | kv = torch.randn( 50 | batch_size, seq_length, num_heads, embed_dim, device=device, dtype=dtype 51 | ) 52 | 53 | _ = scaled_dot_product_gqa(q, kv, kv) 54 | vanilla_result = benchmark(scaled_dot_product_gqa, q, kv, kv) 55 | print(f"Vanilla: {vanilla_result}") 56 | xformers_result = benchmark(xops.memory_efficient_attention, q, kv, kv) 57 | print(f"Flash Attn: {xformers_result}") 58 | 59 | grouped_times: List[BenchmarkResult] = [] 60 | for g in num_groups: 61 | kv = torch.randn( 62 | batch_size, seq_length, g, embed_dim, dtype=dtype, device=device 63 | ) 64 | grouped_result = benchmark( 65 | scaled_dot_product_gqa, q, kv, kv, force_grouped=True 66 | ) 67 | grouped_times.append(grouped_result) 68 | print(f"Grouped (g={g}): {grouped_result}") 69 | 70 | fig = go.Figure() 71 | fig.add_trace( 72 | go.Scatter( 73 | x=num_groups, 74 | y=[vanilla_result.mean * 1000] * len(num_groups), 75 | mode="lines", 76 | line={"dash": "dash"}, 77 | name="Vanilla MHA", 78 | ) 79 | ) 80 | fig.add_trace( 81 | go.Scatter( 82 | x=num_groups, 83 | y=[r.mean * 1000 for r in grouped_times], 84 | mode="lines", 85 | name="GQA", 86 | ) 87 | ) 88 | fig.add_trace( 89 | go.Scatter( 90 | x=num_groups, 91 | y=[grouped_times[0].mean * 1000] * len(num_groups), 92 | mode="lines", 93 | line={"dash": "dash"}, 94 | name="MQA", 95 | ) 96 | ) 97 | fig.add_trace( 98 | go.Scatter( 99 | x=num_groups, 100 | y=[xformers_result.mean * 1000] * len(num_groups), 101 | mode="lines", 102 | line={"dash": "dash"}, 103 | name="Flash Attn (v1 - xformers)", 104 | ) 105 | ) 106 | fig.update_layout( 107 | title="Attention Benchmarks", 108 | xaxis_title="GQA Groups", 109 | yaxis_title="Runtime (ms)", 110 | # use log-scale for x-axis 111 | xaxis={"tickmode": "array", "tickvals": num_groups, "type": "log"}, 112 | # place legend at center-left 113 | legend={"x": 0.1, 
"y": 0.5}, 114 | ) 115 | fig.write_image(save_path) 116 | 117 | 118 | if __name__ == "__main__": 119 | import argparse 120 | 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument( 123 | "--num-groups", 124 | type=int, 125 | nargs="+", 126 | default=NUM_GROUPS, 127 | help="Sequence of GQA group sizes for benchmarking (default: [1, 4, 8, 16, 32, 64])", 128 | ) 129 | parser.add_argument( 130 | "--total-tokens", 131 | type=int, 132 | default=TOTAL_TOKENS, 133 | help="Total number of tokens in the batch (default: 8192)", 134 | ) 135 | parser.add_argument( 136 | "--seq-length", 137 | type=int, 138 | default=SEQ_LENGTH, 139 | help="Sequence length of the input (default: 2048)", 140 | ) 141 | parser.add_argument( 142 | "--num-heads", 143 | type=int, 144 | default=NUM_HEADS, 145 | help="Number of attention heads (default: 64)", 146 | ) 147 | parser.add_argument( 148 | "--embed-dim", 149 | type=int, 150 | default=EMBED_DIM, 151 | help="Embedding dimension of the input (default: 8)", 152 | ) 153 | parser.add_argument( 154 | "--device", 155 | type=str, 156 | default=DEVICE, 157 | help="Device to run the benchmark on (default: 'cuda' if available else 'cpu')", 158 | ) 159 | parser.add_argument( 160 | "--dtype", 161 | type=str, 162 | default=DTYPE, 163 | help="Data type to run the benchmark on (default: 'float16' if cuda is available else 'float32')", 164 | ) 165 | parser.add_argument( 166 | "--save-path", 167 | type=str, 168 | default=SAVE_PATH, 169 | help="Path to save the benchmark plot (default: 'doc/benchmark_attention.png')", 170 | ) 171 | args = parser.parse_args() 172 | 173 | print( 174 | BENCHMARK_SETTINGS_TEMPLATE.format( 175 | device=args.device, 176 | dtype=args.dtype, 177 | total_tokens=args.total_tokens, 178 | seq_length=args.seq_length, 179 | num_heads=args.num_heads, 180 | embed_dim=args.embed_dim, 181 | ) 182 | ) 183 | 184 | main(**vars(args)) 185 | -------------------------------------------------------------------------------- /scripts/benchmark_t5.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Sequence 3 | 4 | import plotly.graph_objects as go 5 | import torch 6 | from torch import Tensor 7 | from transformers import T5ForConditionalGeneration, T5Tokenizer 8 | 9 | from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa 10 | from grouped_query_attention_pytorch.utils.benchmark import BenchmarkResult, benchmark 11 | 12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 14 | 15 | # Benchmarking parameters 16 | MODEL_NAME = "t5-3b" 17 | NUM_GROUPS = [1, 2, 4, 8, 16, 32] 18 | TOTAL_TOKENS = 4096 19 | MODEL_MAX_LENGTH = 512 20 | SAVE_PATH = os.path.join("doc", "benchmark_t5.png") 21 | 22 | BENCHMARK_SETTINGS_TEMPLATE = """ 23 | Benchmark settings: 24 | device: {device} 25 | dtype: {dtype} 26 | model_name: {model_name} 27 | total_tokens: {total_tokens} 28 | model_max_length: {model_max_length} 29 | """ 30 | # NOTE: Text sample taken from the CNN/Daily Mail training set 31 | INPUT_TEXT = """LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains 32 | access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but 33 | he insists the money won't cast a spell on him. 
Daniel Radcliffe as Harry Potter in 34 | "Harry Potter and the Order of the Phoenix" To the disappointment of gossip 35 | columnists around the world, the young actor says he has no plans to fritter his 36 | cash away on fast cars, drink and celebrity parties. "I don't plan to be one of 37 | those people who, as soon as they turn 18, suddenly buy themselves a massive sports 38 | car collection or something similar," he told an Australian interviewer earlier this 39 | month. "I don't think I'll be particularly extravagant. "The things I like buying 40 | are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe 41 | will be able to gamble in a casino, buy a drink in a pub or see the horror film 42 | "Hostel: Part II," currently six places below his number one movie on the UK box 43 | office chart. Details of how he'll mark his landmark birthday are under wraps. 44 | His agent and publicist had no comment on his plans. "I'll definitely have some 45 | sort of party," he said in an interview. "Hopefully none of you will be reading 46 | about it." Radcliffe's earnings from the first five Potter films have been held 47 | in a trust fund which he has not been able to touch. Despite his growing fame and 48 | riches, the actor says he is keeping his feet firmly on the ground. "People are 49 | always looking to say 'kid star goes off the rails,'" he told reporters last month. 50 | "But I try very hard not to go that way because it would be too easy for them." His 51 | latest outing as the boy wizard in "Harry Potter and the Order of the Phoenix" is 52 | breaking records on both sides of the Atlantic and he will reprise the role in the 53 | last two films. Watch I-Reporter give her review of Potter's latest » . There is life 54 | beyond Potter, however. The Londoner has filmed a TV movie called "My Boy Jack," about 55 | author Rudyard Kipling and his son, due for release later this year. He will also 56 | appear in "December Boys," an Australian film about four boys who escape an orphanage. 57 | Earlier this year, he made his stage debut playing a tortured teenager in Peter 58 | Shaffer's "Equus." Meanwhile, he is braced for even closer media scrutiny now that 59 | he's legally an adult: "I just think I'm going to be more sort of fair game," he told 60 | Reuters. E-mail to a friend . Copyright 2007 Reuters. All rights reserved.This material 61 | may not be published, broadcast, rewritten, or redistributed.""" 62 | TARGET_TEXT = """Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 63 | 18 Monday. Young actor says he has no plans to fritter his cash away. 
Radcliffe's 64 | earnings from first five Potter films have been held in trust fund.""" 65 | 66 | 67 | @torch.no_grad() 68 | def forward_fn(model: T5ForConditionalGeneration, input_ids: Tensor, labels: Tensor): 69 | return model(input_ids=input_ids, labels=labels).loss 70 | 71 | 72 | def main( 73 | model_name: str = MODEL_NAME, 74 | num_groups: Sequence[int] = NUM_GROUPS, 75 | total_tokens: int = TOTAL_TOKENS, 76 | model_max_length: int = MODEL_MAX_LENGTH, 77 | device: torch.device = DEVICE, 78 | dtype: torch.dtype = DTYPE, 79 | save_path: str = SAVE_PATH, 80 | ): 81 | print(f"Loading model and tokenizer for {model_name}...") 82 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained( 83 | model_name 84 | ).to(device=device, dtype=dtype) 85 | tokenizer = T5Tokenizer.from_pretrained(model_name, legacy=False) 86 | 87 | batch_size = total_tokens // model_max_length 88 | input_ids = tokenizer( 89 | [INPUT_TEXT] * batch_size, 90 | return_tensors="pt", 91 | max_length=model_max_length, 92 | truncation=True, 93 | ).input_ids.to(device=device) 94 | labels = tokenizer( 95 | [TARGET_TEXT] * batch_size, 96 | return_tensors="pt", 97 | max_length=model_max_length, 98 | truncation=True, 99 | ).input_ids.to(device=device) 100 | 101 | mha_result = benchmark(forward_fn, t5, input_ids, labels) 102 | print(f"MHA: {mha_result}") 103 | del t5 104 | torch.cuda.empty_cache() 105 | 106 | grouped_times: List[BenchmarkResult] = [] 107 | for g in num_groups: 108 | print("Reloading model and converting to GQA...") 109 | t5 = T5ForConditionalGeneration.from_pretrained(model_name).to( 110 | device=device, dtype=dtype 111 | ) 112 | gqa = convert_t5_to_gqa(t5, kv_heads=g, inplace=True) 113 | grouped_result = benchmark(forward_fn, gqa, input_ids, labels) 114 | grouped_times.append(grouped_result) 115 | print(f"Grouped (g={g}): {grouped_result}") 116 | 117 | del t5, gqa 118 | torch.cuda.empty_cache() 119 | 120 | fig = go.Figure() 121 | fig.add_trace( 122 | go.Scatter( 123 | x=num_groups, 124 | y=[mha_result.mean] * len(num_groups), 125 | mode="lines", 126 | line={"dash": "dash"}, 127 | name="MHA", 128 | ) 129 | ) 130 | fig.add_trace( 131 | go.Scatter( 132 | x=num_groups, 133 | y=[r.mean for r in grouped_times], 134 | mode="lines", 135 | name="GQA", 136 | ) 137 | ) 138 | fig.add_trace( 139 | go.Scatter( 140 | x=num_groups, 141 | y=[grouped_times[0].mean] * len(num_groups), 142 | mode="lines", 143 | line={"dash": "dash"}, 144 | name="MQA", 145 | ) 146 | ) 147 | fig.update_layout( 148 | title="T5 Benchmarks", 149 | xaxis_title="GQA Groups", 150 | yaxis_title="Time per sample (s)", 151 | # use log-scale for x-axis 152 | xaxis={"tickmode": "array", "tickvals": num_groups, "type": "log"}, 153 | # place legend at center-left 154 | legend={"x": 0.1, "y": 0.5}, 155 | ) 156 | fig.write_image(save_path) 157 | 158 | 159 | if __name__ == "__main__": 160 | import argparse 161 | 162 | # NOTE: The original paper uses T5 v1.1 XL and XXL models. When I load those 163 | # models through 'transformers' without applying GQA, I get nonsense outputs. 164 | # TODO: Figure out why this is happening. 165 | # tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-large", legacy=False) 166 | # model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-large") 167 | # 168 | # In the meantime, we can use the non-Google T5 models, which seem to work fine. 169 | # NOTE: Since the the original number of heads (n_heads) must be divisible by 170 | # 'kv_heads', there are only certain values of 'kv_heads' that we can use. 
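    # For example, to benchmark a smaller checkpoint from the repo root (the flag
    # values here are illustrative; pick '--num-groups' from the list just below):
    #   python -m scripts.benchmark_t5 --model-name t5-base --num-groups 1 2 3 4 6 12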
171 | # The following values of 'kv_heads' should be valid: 172 | # - t5-small: 1, 2, 4, 8 173 | # - t5-base: 1, 2, 3, 4, 6, 12 174 | # - t5-large: 1, 2, 4, 8, 16 175 | # - t5-3b: 1, 2, 4, 8, 16, 32 (DEFAULT) 176 | # - t5-11b: 1, 2, 4, 8, 16, 32, 64 TODO: Check 11b values specifically 177 | 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument( 180 | "--model-name", 181 | type=str, 182 | default=MODEL_NAME, 183 | help=f"Name of the T5 model to benchmark (default: '{MODEL_NAME}')", 184 | ) 185 | parser.add_argument( 186 | "--num-groups", 187 | type=int, 188 | nargs="+", 189 | default=NUM_GROUPS, 190 | help=f"Sequence of GQA group sizes for benchmarking (default: {NUM_GROUPS})", 191 | ) 192 | parser.add_argument( 193 | "--total-tokens", 194 | type=int, 195 | default=TOTAL_TOKENS, 196 | help=f"Total number of tokens in the batch (default: {TOTAL_TOKENS})", 197 | ) 198 | parser.add_argument( 199 | "--model-max-length", 200 | type=int, 201 | default=MODEL_MAX_LENGTH, 202 | help=f"Sequence length of the input (default: {MODEL_MAX_LENGTH})", 203 | ) 204 | parser.add_argument( 205 | "--device", 206 | type=str, 207 | default=DEVICE, 208 | help=f"Device to run the benchmark on (default: {DEVICE})", 209 | ) 210 | parser.add_argument( 211 | "--dtype", 212 | type=str, 213 | default=DTYPE, 214 | help=f"Data type to run the benchmark on (default: {DTYPE})", 215 | ) 216 | parser.add_argument( 217 | "--save-path", 218 | type=str, 219 | default=SAVE_PATH, 220 | help=f"Path to save the benchmark plot (default: '{SAVE_PATH}')", 221 | ) 222 | args = parser.parse_args() 223 | 224 | print( 225 | BENCHMARK_SETTINGS_TEMPLATE.format( 226 | device=args.device, 227 | dtype=args.dtype, 228 | model_name=args.model_name, 229 | total_tokens=args.total_tokens, 230 | model_max_length=args.model_max_length, 231 | ) 232 | ) 233 | 234 | main(**vars(args)) 235 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fkodom/grouped-query-attention-pytorch/5267b1a6f39742e19a0e08ae6ab5854762cc0382/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def pytest_addoption(parser): 5 | parser.addoption("--slow", action="store_true") 6 | 7 | 8 | def pytest_configure(config): 9 | config.addinivalue_line("markers", "slow: slow to run") 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | run_slow = config.getoption("--slow") 14 | skip_fast = pytest.mark.skip(reason="remove --slow option to run") 15 | skip_slow = pytest.mark.skip(reason="need --slow option to run") 16 | 17 | for item in items: 18 | if ("slow" in item.keywords) and (not run_slow): 19 | item.add_marker(skip_slow) 20 | if ("slow" not in item.keywords) and (run_slow): 21 | item.add_marker(skip_fast) 22 | -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from grouped_query_attention_pytorch.attention import ( 6 | MultiheadGQA, 7 | scaled_dot_product_gqa, 8 | ) 9 | 10 | torch.backends.cudnn.deterministic = True 11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | DTYPE = 
torch.float64 13 | SEQ_LEN = 16 14 | 15 | 16 | @pytest.mark.parametrize("embed_dim", [64]) 17 | @pytest.mark.parametrize("num_heads", [4, 8]) 18 | @pytest.mark.parametrize("kv_heads", [4, 8]) 19 | @pytest.mark.parametrize("is_causal", [True, False]) 20 | def test_grouped_scaled_dot_product_attention( 21 | embed_dim: int, 22 | num_heads: int, 23 | kv_heads: int, 24 | is_causal: bool, 25 | ): 26 | x = torch.randn(1, SEQ_LEN, num_heads, embed_dim, device=DEVICE, dtype=DTYPE) 27 | kv = torch.randn(1, SEQ_LEN, kv_heads, embed_dim, device=DEVICE, dtype=DTYPE) 28 | 29 | if kv_heads > num_heads: 30 | with pytest.raises(ValueError): 31 | scaled_dot_product_gqa(x, kv, kv, is_causal=is_causal) 32 | return 33 | 34 | out, attn_weights = scaled_dot_product_gqa( 35 | x, kv, kv, is_causal=is_causal, need_weights=True 36 | ) 37 | assert out.size(0) == 1 38 | assert out.size(1) == SEQ_LEN 39 | assert out.size(2) == num_heads 40 | assert out.size(3) == embed_dim 41 | assert attn_weights.size(0) == 1 42 | assert attn_weights.size(1) == SEQ_LEN 43 | assert attn_weights.size(2) == SEQ_LEN 44 | assert attn_weights.size(3) == num_heads 45 | 46 | # Test that grouped SDPA is equivalent to SDPA if we duplicate the KV heads. 47 | kv = kv.repeat_interleave(num_heads // kv_heads, dim=2) 48 | kv = kv.permute(0, 2, 1, 3) 49 | x = x.permute(0, 2, 1, 3) 50 | out_vanilla = F.scaled_dot_product_attention(x, kv, kv, is_causal=is_causal) 51 | out_vanilla = out_vanilla.permute(0, 2, 1, 3) 52 | torch.testing.assert_close(out, out_vanilla) 53 | 54 | 55 | @torch.no_grad() 56 | @pytest.mark.parametrize("embed_dim", [64, 128]) 57 | @pytest.mark.parametrize("num_heads", [4, 8]) 58 | @pytest.mark.parametrize("kv_heads", [4, 8]) 59 | @pytest.mark.parametrize("is_causal", [True, False]) 60 | def test_multihead_gqa( 61 | embed_dim: int, 62 | num_heads: int, 63 | kv_heads: int, 64 | is_causal: bool, 65 | ): 66 | if kv_heads > num_heads: 67 | with pytest.raises(ValueError): 68 | MultiheadGQA(embed_dim, num_heads, kv_heads) 69 | return 70 | 71 | mhda = MultiheadGQA(embed_dim, num_heads, kv_heads, device=DEVICE, dtype=DTYPE) 72 | x = torch.randn(1, SEQ_LEN, embed_dim, device=DEVICE, dtype=DTYPE) 73 | 74 | out, _ = mhda(x, x, x, is_causal=is_causal) # default: causal=False 75 | assert out.size(0) == 1 76 | assert out.size(1) == SEQ_LEN 77 | assert out.size(2) == embed_dim 78 | -------------------------------------------------------------------------------- /tests/test_t5.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers import T5ForConditionalGeneration, T5Tokenizer 4 | from transformers.models.t5.modeling_t5 import T5Attention, T5LayerFF 5 | 6 | from grouped_query_attention_pytorch.t5 import T5GQA, convert_t5_to_gqa 7 | 8 | MODEL_NAME = "t5-small" 9 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 11 | SEQ_LEN = 16 12 | 13 | 14 | # Test all of the valid kv_heads values for 't5-small 15 | @pytest.mark.parametrize("kv_heads", [1, 2, 4, 8]) 16 | def test_convert_t5_to_gqa(kv_heads: int): 17 | t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained( 18 | MODEL_NAME 19 | ).to(device=DEVICE, dtype=DTYPE) 20 | 21 | # Just to establish that the T5 model is as expected. Check that all of the 22 | # known attention layers are of type 'T5Attention'. 
23 |     for block in t5.encoder.block:
24 |         for layer in block.layer:
25 |             if hasattr(layer, "SelfAttention"):
26 |                 assert isinstance(layer.SelfAttention, T5Attention)
27 |             else:
28 |                 assert isinstance(layer, T5LayerFF)
29 |     for block in t5.decoder.block:
30 |         for layer in block.layer:
31 |             if hasattr(layer, "SelfAttention"):
32 |                 assert isinstance(layer.SelfAttention, T5Attention)
33 |             elif hasattr(layer, "EncDecAttention"):
34 |                 assert isinstance(layer.EncDecAttention, T5Attention)
35 |             else:
36 |                 assert isinstance(layer, T5LayerFF)
37 | 
38 |     gqa = convert_t5_to_gqa(t5, kv_heads=kv_heads, inplace=True)
39 |     # After conversion, check that all of the attention layers above have been
40 |     # replaced with 'T5GQA' layers.
41 |     for block in gqa.encoder.block:
42 |         for layer in block.layer:
43 |             if hasattr(layer, "SelfAttention"):
44 |                 assert isinstance(layer.SelfAttention, T5GQA)
45 |             else:
46 |                 assert isinstance(layer, T5LayerFF)
47 |     for block in t5.decoder.block:
48 |         for layer in block.layer:
49 |             if hasattr(layer, "SelfAttention"):
50 |                 assert isinstance(layer.SelfAttention, T5GQA)
51 |             elif hasattr(layer, "EncDecAttention"):
52 |                 assert isinstance(layer.EncDecAttention, T5GQA)
53 |             else:
54 |                 assert isinstance(layer, T5LayerFF)
55 | 
56 |     # Check that we can pass inputs/targets through the modified model, and that
57 |     # it returns a loss that is a scalar Tensor.
58 |     tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False)
59 |     inputs = tokenizer(
60 |         "translate English to German: The house is wonderful.", return_tensors="pt"
61 |     )
62 |     targets = tokenizer("Das Haus ist wunderbar.", return_tensors="pt")
63 |     with torch.no_grad():
64 |         loss = gqa(
65 |             inputs.input_ids.to(DEVICE),
66 |             labels=targets.input_ids.to(DEVICE),
67 |         ).loss
68 |     assert isinstance(loss, torch.Tensor)
69 |     assert loss.size() == ()
70 | 
71 |     # Check that we can generate outputs, using the usual 'generate' method.
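    # (If desired, the output could also be decoded for a quick sanity check, e.g.
    # tokenizer.decode(out[0], skip_special_tokens=True) -- note that output quality
    # is not asserted here, only that generation runs and returns a Tensor.)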
72 | out = gqa.generate(inputs.input_ids.to(DEVICE), max_new_tokens=10) 73 | assert isinstance(out, torch.Tensor) 74 | -------------------------------------------------------------------------------- /tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | import pytest 4 | import torch 5 | 6 | from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM 7 | 8 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32 10 | SEQ_LEN = 16 11 | 12 | 13 | @pytest.mark.parametrize("d_model", [128]) 14 | @pytest.mark.parametrize("nhead", [4, 8]) 15 | @pytest.mark.parametrize("kv_heads", [2, 4]) 16 | @pytest.mark.parametrize("num_layers", [1, 2]) 17 | @pytest.mark.parametrize("dim_feedforward", [64]) 18 | @pytest.mark.parametrize("dropout", [0.0]) 19 | @pytest.mark.parametrize("activation", ["relu", "gelu"]) 20 | @pytest.mark.parametrize("is_causal", [True, False]) 21 | def test_gqa_transformer( 22 | d_model: int, 23 | nhead: int, 24 | kv_heads: int, 25 | num_layers: int, 26 | dim_feedforward: int, 27 | dropout: float, 28 | activation: Union[str, Callable], 29 | is_causal: bool, 30 | ): 31 | net = GQATransformer( 32 | d_model=d_model, 33 | nhead=nhead, 34 | kv_heads=kv_heads, 35 | num_encoder_layers=num_layers, 36 | num_decoder_layers=num_layers, 37 | dim_feedforward=dim_feedforward, 38 | dropout=dropout, 39 | activation=activation, 40 | device=DEVICE, 41 | dtype=DTYPE, 42 | ) 43 | x = torch.randn(1, SEQ_LEN, d_model, device=DEVICE, dtype=DTYPE) 44 | with torch.no_grad(): 45 | y = net.forward(x, is_causal=is_causal) 46 | assert y.size(0) == 1 47 | assert y.size(1) == SEQ_LEN 48 | assert y.size(2) == d_model 49 | 50 | 51 | @pytest.mark.parametrize("num_tokens", [100]) 52 | @pytest.mark.parametrize("d_model", [128]) 53 | @pytest.mark.parametrize("nhead", [4, 8]) 54 | @pytest.mark.parametrize("kv_heads", [2, 4]) 55 | @pytest.mark.parametrize("num_layers", [1, 2]) 56 | @pytest.mark.parametrize("dim_feedforward", [64]) 57 | @pytest.mark.parametrize("dropout", [0.0]) 58 | @pytest.mark.parametrize("activation", ["relu", "gelu"]) 59 | @pytest.mark.parametrize("is_causal", [True, False]) 60 | def test_gqa_transformer_lm( 61 | num_tokens: int, 62 | d_model: int, 63 | nhead: int, 64 | kv_heads: int, 65 | num_layers: int, 66 | dim_feedforward: int, 67 | dropout: float, 68 | activation: Union[str, Callable], 69 | is_causal: bool, 70 | ): 71 | net = GQATransformerLM( 72 | num_tokens=num_tokens, 73 | d_model=d_model, 74 | nhead=nhead, 75 | kv_heads=kv_heads, 76 | num_encoder_layers=num_layers, 77 | num_decoder_layers=num_layers, 78 | dim_feedforward=dim_feedforward, 79 | dropout=dropout, 80 | activation=activation, 81 | device=DEVICE, 82 | dtype=DTYPE, 83 | ) 84 | x = torch.randint(0, num_tokens, (1, SEQ_LEN), device=DEVICE, dtype=torch.long) 85 | with torch.no_grad(): 86 | y = net.forward(x, is_causal=is_causal) 87 | assert y.size(0) == 1 88 | assert y.size(1) == SEQ_LEN 89 | assert y.size(2) == num_tokens 90 | --------------------------------------------------------------------------------
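For quick reference, here is a minimal usage sketch of the attention primitives exercised by the tests above. Shapes and argument names follow `tests/test_attention.py`; treat this as an illustrative sketch rather than canonical documentation:

```python
import torch

from grouped_query_attention_pytorch.attention import MultiheadGQA, scaled_dot_product_gqa

# Functional form: tensors are (batch, seq_len, heads, head_dim), and the key/value
# tensors may carry fewer heads than the query tensor.
q = torch.randn(1, 16, 8, 64)
kv = torch.randn(1, 16, 4, 64)
out, attn_weights = scaled_dot_product_gqa(q, kv, kv, is_causal=True, need_weights=True)
print(out.shape)  # torch.Size([1, 16, 8, 64])

# Module form: embedding dim, number of query heads, number of key/value heads.
mhga = MultiheadGQA(64, 8, 4)
x = torch.randn(1, 16, 64)
y, _ = mhga(x, x, x, is_causal=True)
print(y.shape)  # torch.Size([1, 16, 64])
```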