├── tests
│   ├── __init__.py
│   ├── attention.py
│   └── backward.py
├── src
│   └── flashdeberta
│       ├── ops
│       │   ├── __init__.py
│       │   ├── flash_attention_varlen.py
│       │   └── flash_attention.py
│       ├── __init__.py
│       ├── utils.py
│       ├── padding.py
│       └── model.py
├── images
│   └── benchmarking.png
├── pyproject.toml
├── README.md
├── .gitignore
└── LICENSE

/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/src/flashdeberta/ops/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/images/benchmarking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Knowledgator/FlashDeBERTa/HEAD/images/benchmarking.png

--------------------------------------------------------------------------------
/src/flashdeberta/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import (
2 |     FlashDisentangledSelfAttention,
3 |     FlashDebertaV2Model,
4 |     FlashDebertaV2ForMaskedLM,
5 |     FlashDebertaV2ForSequenceClassification,
6 |     FlashDebertaV2ForTokenClassification,
7 |     FlashDebertaV2ForQuestionAnswering,
8 |     FlashDebertaV2ForMultipleChoice
9 | )

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "flashDeBERTa"
3 | version = "0.0.5"
4 | description = "Memory and compute efficient DeBERTa models."
5 | authors = ["knowledgator.com"]
6 | license = "Apache-2.0"
7 | readme = "README.md"
8 | 
9 | [tool.poetry.dependencies]
10 | python = "^3.10"
11 | transformers = ">=4.30.0"
12 | torch = ">=2.1.0"
13 | triton = ">=3.0.0"
14 | einops = ">=0.7.0"
15 | 
16 | [build-system]
17 | requires = ["poetry-core"]
18 | build-backend = "poetry.core.masonry.api"

--------------------------------------------------------------------------------
/src/flashdeberta/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Mesh TensorFlow authors, Knowledgator, T5 Authors and HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from typing import Any, Tuple, Union, List
17 | import importlib.util
18 | import importlib.metadata
19 | import torch
20 | from torch import nn
21 | import torch.nn.functional as F
22 | 
23 | # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
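# Illustrative behaviour of the helper below (exact values depend on the local environment):
#   >>> _is_package_available("torch")
#   True
#   >>> _is_package_available("torch", return_version=True)
#   (True, "2.1.0")   # hypothetical version string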
24 | def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: 25 | # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version 26 | package_exists = importlib.util.find_spec(pkg_name) is not None 27 | package_version = "N/A" 28 | if package_exists: 29 | try: 30 | package_version = importlib.metadata.version(pkg_name) 31 | package_exists = True 32 | except importlib.metadata.PackageNotFoundError: 33 | package_exists = False 34 | if return_version: 35 | return package_exists, package_version 36 | else: 37 | return package_exists 38 | 39 | 40 | _flash_attn_available = _is_package_available("flash_attn") 41 | 42 | def is_flash_attn_available(): 43 | # Let's add an extra check to see if cuda is available 44 | import torch 45 | 46 | return _flash_attn_available and torch.cuda.is_available() 47 | 48 | def _get_unpad_data(padding_mask): 49 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) 50 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() 51 | max_seqlen_in_batch = seqlens_in_batch.max().item() 52 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 53 | return ( 54 | indices, 55 | cu_seqlens, 56 | max_seqlen_in_batch, 57 | ) 58 | 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlashDeBERTa 🦾 – Boost inference speed by 3-5x ⚡ and run DeBERTa models on long sequences 📚. 2 | 3 | **FlashDeBERTa** is an optimized version of the DeBERTa model leveraging flash attention to implement a disentangled attention mechanism. It significantly reduces memory usage and latency, especially with long sequences. The project enables loading and running original DeBERTa models on tens of thousands of tokens without retraining, maintaining original accuracy. 4 | 5 | ### Use Cases 6 | 7 | DeBERTa remains one of the top-performing models for the following tasks: 8 | 9 | - **Named Entity Recognition:** It serves as the main backbone for models such as [GLiNER](https://github.com/urchade/GLiNER), an efficient architecture for zero-shot information extraction. 10 | - **Text Classification:** DeBERTa is highly effective for supervised and zero-shot classification tasks, such as [GLiClass](https://github.com/Knowledgator/GLiClass). 11 | - **Reranking:** The model offers competitive performance compared to other reranking models, making it a valuable component in many RAG systems. 12 | 13 | > [!warning] 14 | > This project is under active development and may contain bugs. Please create an issue if you encounter bugs or have suggestions for improvements. 15 | 16 | ### Installation 17 | 18 | First, install the package: 19 | 20 | ```bash 21 | pip install flashdeberta -U 22 | ``` 23 | 24 | Then import the appropriate model heads for your use case and initialize the model from pretrained checkpoints: 25 | 26 | ```python 27 | from flashdeberta import FlashDebertaV2Model # FlashDebertaV2ForSequenceClassification, FlashDebertaV2ForTokenClassification, etc. 28 | from transformers import AutoTokenizer 29 | import torch 30 | 31 | # Load tokenizer and model 32 | tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base") 33 | model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base").to('cuda') 34 | 35 | # Tokenize input text 36 | input_text = "Hello world!" 
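# Illustrative: for padded batches of several texts, also pass the attention mask, e.g.
#   enc = tokenizer(texts, padding=True, return_tensors="pt").to('cuda')
#   outputs = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)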
37 | input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to('cuda') 38 | 39 | # Model inference 40 | outputs = model(input_ids) 41 | ``` 42 | 43 | In order to switch to eager attention implementation, initialise a model in the following way: 44 | ```python 45 | model = FlashDebertaV2Model.from_pretrained("microsoft/deberta-v3-base", _attn_implementation='eager').to('cuda') 46 | ``` 47 | 48 | ### Benchmarks 49 | 50 | While context-to-position and position-to-context biases still require quadratic memory, our flash attention implementation reduces overall memory requirements to nearly linear. This efficiency is particularly impactful for longer sequences. Starting from 512 tokens, FlashDeBERTa achieves more than a 50% performance improvement, and at 4k tokens, it's over 5 times faster than naive implementations. 51 | 52 | ![benchmarking](images/benchmarking.png) 53 | 54 | ### Future Work 55 | 56 | - Implement backward kernels. 57 | - Train DeBERTa models on 8,192-token sequences using high-quality data. 58 | - Integrate FlashDeBERTa into GLiNER and GLiClass. 59 | - Train multi-modal DeBERTa models. 60 | 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | src.zip 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ -------------------------------------------------------------------------------- /tests/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from flashdeberta.model import FlashDisentangledSelfAttention 4 | from transformers.models.deberta_v2.modeling_deberta_v2 import DisentangledSelfAttention 5 | 6 | class DummyConfig: 7 | def __init__(self, hidden_size, num_attention_heads, position_buckets, max_relative_positions, pos_att_type=[], max_position_embeddings=512): 8 | self.hidden_size = hidden_size 9 | self.num_attention_heads = num_attention_heads 10 | self.attention_probs_dropout_prob = 0.1 11 | self.hidden_dropout_prob = 0.1 12 | # For testing, use both position attention types. 13 | self.pos_att_type = pos_att_type 14 | self.relative_attention = True 15 | self.position_buckets = position_buckets 16 | self.max_relative_positions = max_relative_positions 17 | self.max_position_embeddings = max_position_embeddings 18 | self.share_att_key = False 19 | 20 | def compare_flash_and_deberta(B, L, hidden_size, causal=False, sm_scale=None, 21 | position_buckets=32, max_relative_positions=64, pos_att_type=[]): 22 | """ 23 | Compares outputs between the flash implementation and the original Deberta module, 24 | enforcing the same weights. Returns the mean absolute difference between outputs. 25 | """ 26 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 27 | 28 | # Create dummy hidden states. 29 | hidden_states = torch.randn(B, L, hidden_size, device=device) 30 | 31 | attention_mask = torch.ones(B, L, device=device) 32 | 33 | # Extended attention mask for the original Deberta model. 
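    # The (B, L) mask is expanded to (B, 1, 1, L) and multiplied by its (B, 1, L, 1)
    # counterpart, producing a (B, 1, L, L) pairwise mask in which entry (i, j) is 1
    # only when both token i and token j are valid.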
34 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 35 | extended_attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) 36 | 37 | # Create random relative positional embeddings. 38 | rel_embeddings = torch.randn(max_relative_positions, hidden_size, device=device) 39 | 40 | # For testing, choose a fixed number of attention heads. 41 | num_attention_heads = 8 42 | assert hidden_size % num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads" 43 | 44 | # Instantiate a dummy configuration. 45 | config = DummyConfig(hidden_size=hidden_size, 46 | num_attention_heads=num_attention_heads, 47 | position_buckets=position_buckets, 48 | max_relative_positions=max_relative_positions, 49 | pos_att_type=pos_att_type) 50 | 51 | # Instantiate the original and flash models. 52 | deberta_model = DisentangledSelfAttention(config).to(device) 53 | flash_model = FlashDisentangledSelfAttention(config).to(device) 54 | 55 | # Set both models to evaluation mode to disable dropout. 56 | deberta_model.eval() 57 | flash_model.eval() 58 | 59 | # Copy weights from the original model to the flash version. 60 | flash_model.load_state_dict(deberta_model.state_dict()) 61 | 62 | # Run a forward pass with the original Deberta model. 63 | t_start = time.time() 64 | output_deberta = deberta_model(hidden_states, extended_attention_mask, rel_embeddings=rel_embeddings) 65 | t_deberta = 1000 * (time.time() - t_start) 66 | print('Deberta forward pass time (ms):', t_deberta) 67 | print('Deberta output sample:', output_deberta[0]) # Slice printed for brevity 68 | 69 | # Run a forward pass with the Flash model. 70 | t_start = time.time() 71 | output_flash = flash_model(hidden_states, attention_mask, rel_embeddings=rel_embeddings) 72 | t_flash = 1000 * (time.time() - t_start) 73 | print('Flash forward pass time (ms):', t_flash) 74 | print('Flash output sample:', output_flash[0]) # Slice printed for brevity 75 | 76 | # Compute the mean absolute difference between outputs. 77 | diff = (output_deberta[0] - output_flash[0]).abs().mean().item() 78 | print("Mean absolute difference between outputs:", diff) 79 | return diff 80 | 81 | def test_flash_vs_deberta(): 82 | """ 83 | This test compares the outputs of the flash and original Deberta attention modules. 84 | The test passes if the mean absolute difference is below the tolerance threshold. 85 | """ 86 | # Test parameters. 87 | B = 1 # Batch size 88 | L = 2048 # Sequence length 89 | hidden_size = 1024 # hidden_size should be divisible by num_attention_heads (8 here) 90 | causal = False 91 | sm_scale = None 92 | position_buckets = 256 93 | max_relative_positions = 512 94 | pos_att_type = ['c2p', 'p2c'] # Example position attention types 95 | 96 | # Compute the difference between outputs. 97 | diff = compare_flash_and_deberta( 98 | B, L, hidden_size, causal, sm_scale, position_buckets, max_relative_positions, pos_att_type 99 | ) 100 | 101 | # Define a tolerance threshold. 102 | tolerance = 1e-4 103 | # The test will fail if the outputs differ by more than the tolerated threshold. 
104 | assert diff < tolerance, f"Difference {diff} exceeds tolerance {tolerance}" 105 | -------------------------------------------------------------------------------- /src/flashdeberta/padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | 5 | class IndexPutFirstAxis(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, values, indices, first_axis_dim): 8 | ctx.save_for_backward(indices) 9 | assert indices.ndim == 1 10 | assert values.ndim >= 2 11 | output = torch.zeros( 12 | first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype 13 | ) 14 | # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 15 | output[indices] = values 16 | # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | (indices,) = ctx.saved_tensors 22 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 23 | grad_values = grad_output[indices] 24 | # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) 25 | return grad_values, None, None 26 | 27 | 28 | index_put_first_axis = IndexPutFirstAxis.apply 29 | 30 | def pad_input(hidden_states, indices, batch, seqlen): 31 | """ 32 | Arguments: 33 | hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. 34 | indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. 35 | batch: int, batch size for the padded sequence. 36 | seqlen: int, maximum sequence length for the padded sequence. 37 | Return: 38 | hidden_states: (batch, seqlen, ...) 39 | """ 40 | dim = hidden_states.shape[-1] 41 | # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) 42 | # output[indices] = hidden_states 43 | output = index_put_first_axis(hidden_states, indices, batch * seqlen) 44 | return rearrange(output, "(b s) ... -> b s ...", b=batch) 45 | 46 | class IndexFirstAxis(torch.autograd.Function): 47 | @staticmethod 48 | def forward(ctx, input, indices): 49 | ctx.save_for_backward(indices) 50 | assert input.ndim >= 2 51 | ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] 52 | second_dim = other_shape.numel() 53 | # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 54 | # return input[indices] 55 | return torch.gather( 56 | rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) 57 | ).reshape(-1, *other_shape) 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | (indices,) = ctx.saved_tensors 62 | assert grad_output.ndim >= 2 63 | other_shape = grad_output.shape[1:] 64 | grad_output = rearrange(grad_output, "b ... -> b (...)") 65 | grad_input = torch.zeros( 66 | [ctx.first_axis_dim, grad_output.shape[1]], 67 | device=grad_output.device, 68 | dtype=grad_output.dtype, 69 | ) 70 | # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
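        # Scatter the incoming gradient rows back to their original positions in the
        # padded layout; rows belonging to dropped (padded) tokens remain zero.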
71 |         # grad_input[indices] = grad_output
72 |         grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
73 |         return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
74 | 
75 | 
76 | index_first_axis = IndexFirstAxis.apply
77 | 
78 | 
79 | def _get_unpad_data(attention_mask):
80 |     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
81 |     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
82 |     max_seqlen_in_batch = seqlens_in_batch.max().item()
83 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
84 |     return (
85 |         indices,
86 |         cu_seqlens,
87 |         max_seqlen_in_batch,
88 |     )
89 | 
90 | def unpad_input(hidden_states, attention_mask):
91 |     """
92 |     Arguments:
93 |         hidden_states: (batch, seqlen, ...)
94 |         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
95 |     Return:
96 |         hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
97 |         indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
98 |         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
99 |         max_seqlen_in_batch: int
100 |     """
101 |     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
102 |     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
103 |     max_seqlen_in_batch = seqlens_in_batch.max().item()
104 |     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
105 | 
106 |     return (
107 |         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
108 |         indices,
109 |         cu_seqlens,
110 |         max_seqlen_in_batch,
111 |     )
112 | 
113 | def _upad_input(query_layer, key_layer, value_layer, pos_key, pos_query, attention_mask, query_length, NH):
114 |     indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
115 |     batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
116 | 
117 |     key_layer = index_first_axis(
118 |         key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
119 |     )
120 | 
121 |     value_layer = index_first_axis(
122 |         value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
123 |     )
124 |     if pos_query is not None:
125 |         max_distance = pos_query.shape[-1]
126 |         pos_query = index_first_axis(
127 |             pos_query.reshape(batch_size * kv_seq_len, num_key_value_heads, max_distance), indices_k
128 |         )
129 | 
130 |     if query_length == kv_seq_len:
131 | 
132 |         query_layer = index_first_axis(
133 |             query_layer.reshape(batch_size * kv_seq_len, NH, head_dim), indices_k
134 |         )
135 |         if pos_key is not None:
136 |             max_distance = pos_key.shape[-1]
137 |             pos_key = index_first_axis(
138 |                 pos_key.reshape(batch_size * kv_seq_len, num_key_value_heads, max_distance), indices_k
139 |             )
140 |         cu_seqlens_q = cu_seqlens_k
141 |         max_seqlen_in_batch_q = max_seqlen_in_batch_k
142 |         indices_q = indices_k
143 |     elif query_length == 1:
144 |         max_seqlen_in_batch_q = 1
145 |         cu_seqlens_q = torch.arange(
146 |             batch_size + 1, dtype=torch.int32, device=query_layer.device
147 |         ) # There is a memcpy here, that is very bad.
148 |         indices_q = cu_seqlens_q[:-1]
149 |         query_layer = query_layer.squeeze(1)
150 |         if pos_key is not None:
151 |             pos_key = pos_key.squeeze(1)
152 |     else:
153 |         # The -q_len: slice assumes left padding.
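        # When 1 < query_length < kv_seq_len, only the trailing query_length positions
        # of the mask correspond to query tokens, so the queries are unpadded using the
        # last query_length mask columns (hence the left-padding assumption above).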
154 | attention_mask = attention_mask[:, -query_length:] 155 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 156 | 157 | return ( 158 | query_layer, 159 | key_layer, 160 | value_layer, 161 | pos_key, 162 | pos_query, 163 | indices_q, 164 | (cu_seqlens_q, cu_seqlens_k), 165 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 166 | ) -------------------------------------------------------------------------------- /tests/backward.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from flashdeberta.model import FlashDisentangledSelfAttention 4 | from transformers.models.deberta_v2.modeling_deberta_v2 import DisentangledSelfAttention 5 | 6 | class DummyConfig: 7 | def __init__(self, hidden_size, num_attention_heads, position_buckets, max_relative_positions, pos_att_type=[], max_position_embeddings=512): 8 | self.hidden_size = hidden_size 9 | self.num_attention_heads = num_attention_heads 10 | self.attention_probs_dropout_prob = 0. 11 | self.hidden_dropout_prob = 0. 12 | self.pos_att_type = pos_att_type 13 | self.relative_attention = True 14 | self.position_buckets = position_buckets 15 | self.max_relative_positions = max_relative_positions 16 | self.max_position_embeddings = max_position_embeddings 17 | self.share_att_key = False 18 | 19 | def _make_varlen_mask(B, L, device, min_len=1): 20 | lengths = torch.randint(low=min_len, high=L + 1, size=(B,), device=device) 21 | mask = torch.zeros(B, L, device=device, dtype=torch.bool) 22 | for b, t in enumerate(lengths.tolist()): 23 | mask[b, :t] = True 24 | return mask 25 | 26 | @torch.no_grad() 27 | def _extended_mask(attention_mask): 28 | m = attention_mask.float() 29 | m1 = m.unsqueeze(1).unsqueeze(2) # (B, 1, 1, L) 30 | m2 = m1.squeeze(-2).unsqueeze(-1) # (B, 1, L, 1) 31 | return m1 * m2 # (B, 1, L, L) 32 | 33 | def _tensor_stats(x): 34 | if x is None: 35 | return float("nan"), float("nan"), float("nan") 36 | a = x.abs() 37 | max_abs = a.max().item() 38 | mean_abs = a.mean().item() 39 | l2 = x.pow(2).sum().sqrt().item() 40 | return max_abs, mean_abs, l2 41 | 42 | def _grad_diff_record(name, ga, gb, atol, rtol, eps=1e-12): 43 | # ga/gb: gradients for parameter `name` (same shape) 44 | d = (ga - gb).detach() 45 | max_abs, mean_abs, l2 = _tensor_stats(d) 46 | ga_l2 = ga.detach().pow(2).sum().sqrt().item() 47 | gb_l2 = gb.detach().pow(2).sum().sqrt().item() 48 | rel_l2 = l2 / (max(ga_l2, gb_l2, eps)) # normalized by larger ref magnitude 49 | 50 | # Per-param pass criterion using absolute/relative on max element 51 | ref_max = max(ga.detach().abs().max().item(), gb.detach().abs().max().item(), eps) 52 | passed = (max_abs <= atol) or (max_abs <= rtol * ref_max) 53 | 54 | return { 55 | "name": name, 56 | "shape": tuple(ga.shape), 57 | "max_abs": max_abs, 58 | "mean_abs": mean_abs, 59 | "l2": l2, 60 | "ga_l2": ga_l2, 61 | "gb_l2": gb_l2, 62 | "rel_l2": rel_l2, 63 | "passed": passed, 64 | } 65 | 66 | def compare_flash_and_deberta_backward( 67 | B, L, hidden_size, 68 | causal=False, sm_scale=None, 69 | position_buckets=32, 70 | max_relative_positions=64, 71 | pos_att_type=[], 72 | varlen=False, 73 | dtype=None, 74 | atol=None, 75 | rtol=None, 76 | seed=0, 77 | verbose=True, 78 | print_top_k=20, # <── show top-K params by max_abs diff 79 | ): 80 | """ 81 | Compare backward pass gradients parameter-by-parameter. 
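    Both modules share identical weights (copied via load_state_dict) and receive the
    same random upstream gradient, so any mismatch isolates the two attention
    implementations rather than initialization or data differences.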
82 | Returns: 83 | { 84 | "passed": bool, 85 | "input_grad_max_abs_diff": float, 86 | "input_grad_mean_abs_diff": float, 87 | "rel_grad_max_abs_diff": float or nan, 88 | "rel_grad_mean_abs_diff": float or nan, 89 | "param_grad_max_abs_diff": float, 90 | "param_grad_mean_abs_diff": float, 91 | "per_param": {name: metrics_dict, ...}, # <── NEW 92 | "failed_params": [names...], # <── NEW 93 | } 94 | """ 95 | torch.manual_seed(seed) 96 | device = "cuda" if torch.cuda.is_available() else "cpu" 97 | if dtype is None: 98 | dtype = torch.float16 if device == "cuda" else torch.float32 99 | if atol is None: 100 | atol = 5e-3 if dtype == torch.float16 else 1e-5 101 | if rtol is None: 102 | rtol = 5e-3 if dtype == torch.float16 else 1e-5 103 | mode = "varlen" if varlen else "fixed-len" 104 | 105 | # Inputs 106 | hidden_states_a = torch.randn(B, L, hidden_size, device=device, dtype=dtype, requires_grad=True) 107 | rel_embeddings_a = torch.randn(max_relative_positions, hidden_size, device=device, dtype=dtype, requires_grad=True) 108 | 109 | attention_mask = _make_varlen_mask(B, L, device) if varlen else torch.ones(B, L, device=device, dtype=torch.bool) 110 | extended_attention_mask = _extended_mask(attention_mask) 111 | 112 | # Clone for flash path 113 | hidden_states_b = hidden_states_a.detach().clone().requires_grad_(True) 114 | rel_embeddings_b = rel_embeddings_a.detach().clone().requires_grad_(True) 115 | 116 | # Config + models 117 | num_attention_heads = 8 118 | assert hidden_size % num_attention_heads == 0 119 | config = DummyConfig( 120 | hidden_size=hidden_size, 121 | num_attention_heads=num_attention_heads, 122 | position_buckets=position_buckets, 123 | max_relative_positions=max_relative_positions, 124 | pos_att_type=pos_att_type, 125 | ) 126 | deberta_model = DisentangledSelfAttention(config).to(device).to(dtype) 127 | flash_model = FlashDisentangledSelfAttention(config).to(device).to(dtype) 128 | flash_model.load_state_dict(deberta_model.state_dict()) 129 | 130 | deberta_model.eval() 131 | flash_model.eval() 132 | 133 | # Forward 134 | out_a, _ = deberta_model(hidden_states_a, extended_attention_mask, rel_embeddings=rel_embeddings_a) 135 | out_b, _ = flash_model(hidden_states_b, attention_mask, rel_embeddings=rel_embeddings_b) 136 | 137 | diff = (out_a - out_b).abs().mean().item() 138 | if verbose: 139 | print(f"[Forward test — {mode}] dtype={dtype}, atol={atol}, rtol={rtol}") 140 | print("Mean absolute difference between outputs:", diff) 141 | passed = True if diff < rtol else False 142 | print(f"Forward passed: {passed}.") 143 | # Identical upstream grad 144 | grad_out = torch.randn_like(out_a) 145 | 146 | # Zero param grads 147 | for p in deberta_model.parameters(): 148 | if p.grad is not None: p.grad = None 149 | for p in flash_model.parameters(): 150 | if p.grad is not None: p.grad = None 151 | 152 | # Backward 153 | (out_a * grad_out).sum().backward() 154 | (out_b * grad_out.detach().clone()).sum().backward() 155 | 156 | # ---- Inputs ---- 157 | in_diff = (hidden_states_a.grad - hidden_states_b.grad).detach() 158 | in_max, in_mean, _ = _tensor_stats(in_diff) 159 | 160 | # ---- Relative embeddings (if used) ---- 161 | if any(t in pos_att_type for t in ("c2p", "p2c")): 162 | rel_diff = (rel_embeddings_a.grad - rel_embeddings_b.grad).detach() 163 | rel_max, rel_mean, _ = _tensor_stats(rel_diff) 164 | else: 165 | rel_max, rel_mean = float("nan"), float("nan") 166 | 167 | # ---- Per-param diffs ---- 168 | per_param = {} 169 | diffs_max = [] 170 | grads_a = {n: p.grad for n, p in 
deberta_model.named_parameters()} 171 | grads_b = {n: p.grad for n, p in flash_model.named_parameters()} 172 | failed_params = [] 173 | 174 | for n, ga in grads_a.items(): 175 | gb = grads_b.get(n, None) 176 | if ga is None or gb is None: 177 | continue 178 | rec = _grad_diff_record(n, ga, gb, atol, rtol) 179 | per_param[n] = rec 180 | diffs_max.append(rec["max_abs"]) 181 | if not rec["passed"]: 182 | failed_params.append(n) 183 | 184 | param_max = max(diffs_max) if diffs_max else float("nan") 185 | param_mean = (sum(diffs_max) / len(diffs_max)) if diffs_max else float("nan") 186 | 187 | # Overall pass (all params + inputs + rel-emb if present) 188 | passed = True 189 | if not math.isnan(param_max): 190 | # overall: require all individual params passed instead of only global max 191 | passed &= (len(failed_params) == 0) 192 | if not math.isnan(in_max): 193 | passed &= (in_max <= atol) or (in_max <= rtol * max(1e-8, in_max)) 194 | if not math.isnan(rel_max): 195 | passed &= (rel_max <= atol) or (rel_max <= rtol * max(1e-8, rel_max)) 196 | 197 | if verbose: 198 | print(f"[Backward test — {mode}] dtype={dtype}, atol={atol}, rtol={rtol}") 199 | print(f" Input grad diff: max={in_max:.3e}, mean={in_mean:.3e}") 200 | if not math.isnan(rel_max): 201 | print(f" Rel-emb grad diff: max={rel_max:.3e}, mean={rel_mean:.3e}") 202 | print(f" Param grad diff: max={param_max:.3e}, mean(max_abs)={param_mean:.3e}") 203 | # Show top-K by max_abs diff 204 | if per_param: 205 | top = sorted(per_param.values(), key=lambda r: r["max_abs"], reverse=True)[:print_top_k] 206 | print(f" Top {len(top)} params by max_abs diff:") 207 | for r in top: 208 | flag = "" if r["passed"] else " <-- FAIL" 209 | print(f" {r['name']:55s} {str(r['shape']):>18s} max={r['max_abs']:.3e} " 210 | f"mean={r['mean_abs']:.3e} relL2={r['rel_l2']:.3e}{flag}") 211 | if failed_params: 212 | print(f" FAILED PARAMS ({len(failed_params)}): {failed_params}") 213 | print(" PASSED" if passed and not failed_params else " FAILED") 214 | 215 | return { 216 | "passed": passed and not failed_params, 217 | "input_grad_max_abs_diff": in_max, 218 | "input_grad_mean_abs_diff": in_mean, 219 | "rel_grad_max_abs_diff": rel_max, 220 | "rel_grad_mean_abs_diff": rel_mean, 221 | "param_grad_max_abs_diff": param_max, 222 | "param_grad_mean_abs_diff": param_mean, 223 | "per_param": per_param, # name -> metrics dict 224 | "failed_params": failed_params, # list of names 225 | } 226 | 227 | # Example 228 | if __name__ == "__main__": 229 | _ = compare_flash_and_deberta_backward( 230 | B=2, L=128, hidden_size=256, 231 | pos_att_type=["p2c", "c2p"], 232 | varlen=False, verbose=True, print_top_k=30 233 | ) 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/flashdeberta/ops/flash_attention_varlen.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import warnings 5 | import triton.language as tl 6 | 7 | def calculate_shared_memory_usage_varlen(BLOCK_M, BLOCK_N, BLOCK_DMODEL, num_stages, dtype, 8 | has_c2p=False, has_p2c=False, ATT_SPAN=0): 9 | """ 10 | Calculate the shared memory requirements for Flash Attention with disentangled attention 11 | for variable-length sequences. 12 | 13 | Args: 14 | BLOCK_M: Block size for query sequence dimension 15 | BLOCK_N: Block size for key sequence dimension 16 | BLOCK_DMODEL: Head dimension size 17 | num_stages: Number of pipeline stages 18 | dtype: Data type (torch.float16, torch.float32, etc.) 19 | has_c2p: Whether content-to-position bias is used 20 | has_p2c: Whether position-to-content bias is used 21 | ATT_SPAN: Attention span for relative position 22 | 23 | Returns: 24 | The estimated shared memory usage in bytes 25 | """ 26 | # Determine byte size based on data type 27 | if dtype == torch.float16: 28 | dtype_size = 2 29 | elif dtype == torch.float32: 30 | dtype_size = 4 31 | else: 32 | dtype_size = 2 # Default to float16 size for other types 33 | 34 | # Core tensors that are always used 35 | q_size = BLOCK_M * BLOCK_DMODEL * dtype_size 36 | k_size = BLOCK_N * BLOCK_DMODEL * dtype_size 37 | v_size = BLOCK_N * BLOCK_DMODEL * dtype_size 38 | 39 | # Memory for attention scores and accumulator 40 | attn_matrix_size = BLOCK_M * BLOCK_N * dtype_size 41 | accumulator_size = BLOCK_M * BLOCK_DMODEL * dtype_size 42 | 43 | # Position embedding memory if needed 44 | pos_memory = 0 45 | if has_c2p: 46 | pos_memory += BLOCK_M * 2 * ATT_SPAN * dtype_size 47 | if has_p2c: 48 | pos_memory += BLOCK_N * 2 * ATT_SPAN * dtype_size 49 | 50 | # Additional buffers for intermediate calculations 51 | # This includes arrays for relative positions, bucket indices, etc. 
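    # (Assumed reading) the factor of 4 below budgets one 4-byte (int32/float32)
    # element per (query, key) pair for these index and intermediate buffers.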
52 | additional_buffers = BLOCK_M * BLOCK_N * 4 # For relative position indices and calculations 53 | 54 | # For variable length, we need additional bookkeeping arrays 55 | varlen_buffers = (BLOCK_M + BLOCK_N) * 4 # For sequence boundary tracking 56 | 57 | # Mid batch and mid start arrays (for batch mapping) 58 | mid_batch_memory = BLOCK_M * 4 # Int32 array 59 | 60 | # Total memory per stage including variable length overhead 61 | memory_per_stage = q_size + k_size + v_size + attn_matrix_size + pos_memory + additional_buffers + varlen_buffers 62 | 63 | # Total shared memory including all pipeline stages and bookkeeping 64 | total_shared_memory = num_stages * memory_per_stage + accumulator_size + mid_batch_memory 65 | 66 | return total_shared_memory // 2 67 | 68 | def cdiv(a, b): 69 | return (a + b - 1) // b 70 | 71 | def get_mid(cu_seqlens_q, B, BLOCK_M): 72 | mid_batch = [] 73 | mid_start = [] 74 | MN = 0 75 | for batch in range(B): 76 | q_start = cu_seqlens_q[batch] 77 | q_end = cu_seqlens_q[batch+1] 78 | n_batch_blocks = (q_end-q_start+BLOCK_M-1).item()//BLOCK_M 79 | MN+=n_batch_blocks 80 | for block in range(n_batch_blocks): 81 | mid_start.append(q_start+(block)*BLOCK_M) 82 | mid_batch.append(batch) 83 | return (mid_batch, mid_start, MN) 84 | 85 | @triton.jit 86 | def _fwd_kernel_deberta_disentangled_attention( 87 | Q, K, V, 88 | K_POS, Q_POS, 89 | L, O, 90 | sm_scale, 91 | cu_seqlens_q, cu_seqlens_k, 92 | mid_batch, mid_start, 93 | stride_qz, stride_qh, stride_qk, 94 | stride_kz, stride_kh, stride_kk, 95 | stride_vz, stride_vh, stride_vk, 96 | stride_oz, stride_oh, stride_ok, 97 | stride_pk0, stride_pk1, stride_pk2, 98 | stride_pq0, stride_pq1, stride_pq2, 99 | B, H, M, N, 100 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 101 | IS_CAUSAL: tl.constexpr, 102 | HAS_C2P: tl.constexpr, HAS_P2C: tl.constexpr, 103 | ATT_SPAN: tl.constexpr, 104 | NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr 105 | ): 106 | input_dtype = Q.dtype.element_ty 107 | 108 | start_z = tl.program_id(0) 109 | off_h = tl.program_id(1) 110 | off_b = tl.load(mid_batch + start_z) 111 | off_m = tl.load(mid_start + start_z) 112 | 113 | q_start = tl.load(cu_seqlens_q + off_b) 114 | q_end = tl.load(cu_seqlens_q + off_b + 1) 115 | k_start = tl.load(cu_seqlens_k + off_b) 116 | k_end = tl.load(cu_seqlens_k + off_b + 1) 117 | 118 | lM = q_end - q_start 119 | lN = k_end - k_start 120 | P_SEQ = lM - lN 121 | 122 | log2e: tl.constexpr = 1.4426950408889634 123 | 124 | L += off_m * H + off_h 125 | 126 | offs_m_base = tl.arange(0, BLOCK_M) 127 | offs_m = offs_m_base + off_m 128 | offs_m_relative = offs_m - q_start 129 | offs_n_base = tl.arange(0, BLOCK_N) 130 | offs_k = tl.arange(0, BLOCK_DMODEL) 131 | 132 | q_ptrs = Q + (offs_m[:, None] * stride_qz + off_h * stride_qh + offs_k[None, :] * stride_qk) 133 | o_ptrs = O + (offs_m[:, None] * stride_oz + off_h * stride_oh + offs_k[None, :] * stride_ok) 134 | l_ptrs = L + offs_m_base * H 135 | 136 | mask_m = offs_m < q_end 137 | q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") 138 | 139 | if BLOCK_DMODEL < 128: 140 | I = tl.where(offs_k[:, None] == offs_k, 141 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype), 142 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype)) 143 | q = tl.dot(q, I).to(input_dtype) 144 | 145 | if IS_CAUSAL: 146 | hi = tl.minimum(lN, P_SEQ + (off_m + 1) * BLOCK_M) 147 | if lM > lN: 148 | hi = tl.maximum(0, hi) 149 | else: 150 | hi = lN 151 | 152 | m_i = tl.full([BLOCK_M], value=-float("inf"), 
dtype=tl.float32) 153 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 154 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 155 | 156 | offs_n_init = k_start + offs_n_base 157 | k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vz + off_h * stride_kh) 158 | v_ptrs = V + (offs_n_init[:, None] * stride_kz + offs_k[None, :] * stride_kk + off_h * stride_vh) 159 | 160 | if HAS_C2P: 161 | k_pos_ptrs = K_POS + (offs_m[:, None] * stride_pk0 + off_h * stride_pk1) 162 | 163 | for start_n in range(0, hi, BLOCK_N): 164 | start_n = tl.multiple_of(start_n, BLOCK_N) 165 | offs_n = start_n + offs_n_base 166 | 167 | mask_n = offs_n < lN 168 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") 169 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") 170 | 171 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=input_dtype) 172 | s += tl.dot(q, k) * sm_scale 173 | 174 | relative_positions = offs_n[None, :] - offs_m_relative[:, None] 175 | sign = tl.where(relative_positions > 0.0, 1.0, tl.where(relative_positions < 0.0, -1.0, 0.0)) 176 | mid_val = NUM_BUCKETS // 2 177 | abs_relative = tl.abs(relative_positions) 178 | condition = (relative_positions < mid_val) & (relative_positions > -mid_val) 179 | abs_pos = tl.where(condition, mid_val - 1.0, abs_relative) 180 | log_numer = tl.log(abs_pos / mid_val) 181 | log_denom = tl.log((MAX_DISTANCE - 1) / mid_val) 182 | log_scaled = log_numer / log_denom * (mid_val - 1.0) 183 | log_pos = tl.ceil(log_scaled) + mid_val 184 | bucket_pos = tl.where(abs_pos <= mid_val, relative_positions, log_pos * sign) 185 | 186 | if HAS_C2P: 187 | c2p_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32) 188 | k_pos_ptrs_ = k_pos_ptrs + c2p_index * stride_pk2 189 | c2p_bias = tl.load(k_pos_ptrs_, mask=mask_m[:, None] & (c2p_index < 2 * ATT_SPAN), other=0.0) 190 | s += c2p_bias * sm_scale 191 | 192 | if HAS_P2C: 193 | current_q_pos_ptrs = Q_POS + (offs_n[None, :] * stride_pq0 + off_h * stride_pq1) 194 | p2c_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32).trans(1, 0) 195 | q_pos_ptrs_ = current_q_pos_ptrs + p2c_index * stride_pq2 196 | p2c_bias = tl.load(q_pos_ptrs_, mask=mask_n[:, None] & (p2c_index < 2 * ATT_SPAN), other=0.0).trans(1, 0) 197 | s += p2c_bias * sm_scale 198 | 199 | s = tl.where(mask_n[None, :], s, float("-inf")) 200 | 201 | if IS_CAUSAL: 202 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] 203 | s = tl.where(causal_mask, s, float("-inf")) 204 | 205 | m_i_new = tl.maximum(m_i, tl.max(s, 1)) 206 | alpha = tl.math.exp2((m_i - m_i_new) * log2e) 207 | p = tl.math.exp2((s - m_i_new[:, None]) * log2e) 208 | acc *= alpha[:, None] 209 | acc += tl.dot(p.to(q.dtype), v) 210 | l_i = l_i * alpha + tl.sum(p, 1) 211 | m_i = m_i_new 212 | 213 | k_ptrs += BLOCK_N * stride_kz 214 | v_ptrs += BLOCK_N * stride_vz 215 | 216 | if IS_CAUSAL and lM > lN: 217 | is_empty_line = (offs_m_relative + P_SEQ) < 0 218 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) 219 | l_val = tl.where(is_empty_line, float("-inf"), m_i + tl.log(l_i)) 220 | else: 221 | acc = acc * (1.0 / l_i[:, None]) 222 | l_val = m_i + tl.log(l_i) 223 | 224 | tl.store(l_ptrs, l_val, mask=mask_m, cache_modifier=".cg") 225 | tl.store(o_ptrs, acc.to(input_dtype), mask=mask_m[:, None], cache_modifier=".cg") 226 | 227 | 228 | def get_fwd_config(total_tokens, max_seqlen_q, max_seqlen_k, D, causal, disentangled=False, att_span=256): 229 | """ 230 | Determine optimal kernel configuration 
parameters for variable-length sequences. 231 | 232 | Args: 233 | total_tokens: Total number of tokens across all batches 234 | max_seqlen_q: Maximum query sequence length 235 | max_seqlen_k: Maximum key sequence length 236 | D: Per-head dimension 237 | causal: Whether causal masking is applied 238 | disentangled: Whether to use DeBERTa-style disentangled attention 239 | att_span: Size of the attention span for relative positions 240 | 241 | Returns: 242 | Tuple (BLOCK_M, BLOCK_N, num_stages, num_warps) 243 | """ 244 | # See more details on the mapping at: https://forums.developer.nvidia.com/t/dynamic-shared-memory-calculated-by-ncu-larger-than-max-shared-memory-per-block/265589 245 | 246 | capability_map = { 247 | (7,0): 96000, 248 | (7,2): 96000, 249 | (7,5): 64000, 250 | (8,0): 163000, 251 | (8,6): 99000, 252 | (8,7): 163000, 253 | (8,9): 99000, 254 | (9,0): 227000, 255 | } 256 | 257 | capability = torch.cuda.get_device_capability() 258 | device_property = torch.cuda.get_device_properties() 259 | if hasattr(device_property,"shared_memory_per_block_optin"): 260 | shared_mem_per_block = device_property.shared_memory_per_block_optin 261 | elif capability in list(capability_map.keys()): 262 | shared_mem_per_block = capability_map[capability] 263 | elif capability[0] >= 8: 264 | shared_mem_per_block = 99000 265 | else: 266 | shared_mem_per_block = 48000 267 | 268 | max_shared_memory = shared_mem_per_block - 2000 # remove 2kb for ops overhead 269 | 270 | # Start with an aggressive configuration 271 | if capability[0] >= 8: 272 | if not causal: 273 | if D <= 64: 274 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 275 | else: 276 | if max_seqlen_q <= 1024: 277 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 278 | else: 279 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 8 280 | else: # causal 281 | if D <= 64: 282 | if disentangled: 283 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 284 | else: 285 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 4, 4 286 | else: 287 | if max_seqlen_q <= 1024: 288 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 289 | else: 290 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 8 291 | elif capability[0] == 7: 292 | if not causal: 293 | if D <= 64: 294 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 295 | else: 296 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 297 | else: # causal 298 | if D <= 64: 299 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 300 | else: 301 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 302 | else: 303 | BLOCK_M, BLOCK_N, num_stages, num_warps = 16, 16, 1, 4 304 | 305 | # Additional adjustments for variable-length sequences 306 | 307 | # For very sparse batches (many short sequences), reduce block size 308 | avg_seq_len = total_tokens / max(1, torch.cuda.device_count()) 309 | if avg_seq_len < 256: 310 | BLOCK_M = min(BLOCK_M, 64) 311 | BLOCK_N = min(BLOCK_N, 32) 312 | num_stages = max(1, num_stages - 1) # Reduce stages to save memory 313 | 314 | # Calculate shared memory usage with current config 315 | has_pos = disentangled 316 | ATT_SPAN = att_span if has_pos else 0 317 | 318 | dtype = torch.float16 # Assuming float16 is used as in original code 319 | 320 | shared_mem_usage = calculate_shared_memory_usage_varlen( 321 | BLOCK_M, BLOCK_N, D, num_stages, dtype, 322 | has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN 323 | ) 324 | 325 | # If shared memory usage exceeds available, adjust parameters 326 | # We prioritize reducing num_stages first, 
then block sizes 327 | while shared_mem_usage > max_shared_memory and (BLOCK_M > 16 or BLOCK_N > 16 or num_stages > 1): 328 | # First try reducing num_stages 329 | if num_stages > 1: 330 | num_stages -= 1 331 | # Then try reducing block sizes 332 | else: 333 | BLOCK_M //= 2 334 | BLOCK_N //= 2 335 | 336 | # Recalculate with new parameters 337 | shared_mem_usage = calculate_shared_memory_usage_varlen( 338 | BLOCK_M, BLOCK_N, D, num_stages, dtype, 339 | has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN 340 | ) 341 | 342 | warnings.warn(f"INFO: Variable-length forward config is {BLOCK_M}, {BLOCK_N}, {num_stages}, {num_warps} for BLOCK_M, BLOCK_N stages and warps, respectively.\n" 343 | "INFO: If you want to change it, feel free to check ops/flash_attention_varlen") 344 | 345 | return (BLOCK_M, BLOCK_N, num_stages, num_warps) 346 | 347 | 348 | 349 | def flash_attn_v2_fwd_dise(q, k, v, pos_key, pos_query, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, sm_scale, BLOCK_M, BLOCK_N, 350 | position_buckets, max_relative_distance, num_warps, num_stages, ATT_SPAN): 351 | """ 352 | Performs the forward pass of FlashAttention with DeBERTa-style disentangled relative attention. 353 | 354 | This function computes the attention output `o` and log-normalizer `L` for the input query (q), 355 | key (k), and value (v) tensors. It supports disentangled relative attention using optional 356 | positional projection matrices for content-to-position (C2P) and position-to-content (P2C) biases. 357 | 358 | Args: 359 | q (Tensor): Query tensor of shape (B, H, M, D) where 360 | B = batch size, H = number of heads, M = query sequence length, D = head dimension. 361 | k (Tensor): Key tensor of shape (B, H, N, D) where 362 | N = key sequence length. 363 | v (Tensor): Value tensor of shape (B, H, N, D). 364 | pos_key (Tensor or None): Relative position embedding tensor for C2P bias with shape (2 * max_distance, D), 365 | or None to disable content-to-position bias. 366 | pos_query (Tensor or None): Relative position embedding tensor for P2C bias with shape (2 * max_distance, D), 367 | or None to disable position-to-content bias. 368 | causal (bool): If True, applies causal (autoregressive) masking to the attention weights. 369 | sm_scale (float): Scaling factor applied to the dot-product attention scores. 370 | BLOCK_M (int): Block size for splitting the query sequence dimension. 371 | BLOCK_N (int): Block size for splitting the key sequence dimension. 372 | position_buckets (int): Number of relative position buckets. If > 0, bucketing is applied. 373 | max_relative_distance (int): Maximum relative distance used in bucketing or span window size. 374 | num_warps (int): Number of warps used in the Triton kernel (hardware-specific parallelism). 375 | num_stages (int): Number of pipeline stages in the Triton kernel. 376 | 377 | Returns: 378 | o (Tensor): Output attention tensor of shape (B, H, M, D), same shape as `q`. 379 | L (Tensor): Log-sum-exp normalizer tensor of shape (B, H, M), used for numerically stable softmax. 380 | 381 | Notes: 382 | - This function utilizes a custom Triton kernel to efficiently compute block-sparse FlashAttention 383 | with optional relative position biasing (both C2P and P2C). 384 | - The relative attention mechanism supports DeBERTa's disentangled attention formulation, where 385 | the attention bias is computed separately for position-query and key-position interactions. 
386 | - The number of relative position buckets and max distance determines the size and behavior 387 | of the relative bias. 388 | """ 389 | M = max_seqlen_q 390 | N = max_seqlen_k 391 | B = len(cu_seqlens_q)-1 392 | Z, H, D = q.shape 393 | 394 | mid_batch, mid_start, MN = get_mid(cu_seqlens_q, B, BLOCK_M) 395 | 396 | mid_batch = torch.LongTensor(mid_batch).to(q.device) 397 | mid_start = torch.LongTensor(mid_start).to(q.device) 398 | 399 | # Determine if each bias term is present. 400 | has_c2p = pos_key is not None 401 | has_p2c = pos_query is not None 402 | 403 | grid = (MN, H) 404 | o = torch.empty_like(q) 405 | L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) 406 | 407 | if has_c2p: 408 | stride_pk0, stride_pk1, stride_pk2 = pos_key.stride() 409 | else: 410 | stride_pk0 = stride_pk1 = stride_pk2 = 0 411 | if has_p2c: 412 | stride_pq0, stride_pq1, stride_pq2 = pos_query.stride() 413 | else: 414 | stride_pq0 = stride_pq1 = stride_pq2 = 0 415 | 416 | with torch.cuda.device(q.device.index): 417 | _fwd_kernel_deberta_disentangled_attention[grid]( 418 | q, k, v, 419 | pos_key, pos_query, 420 | L, o, 421 | sm_scale, 422 | cu_seqlens_q, cu_seqlens_k, 423 | mid_batch, mid_start, 424 | q.stride(0), q.stride(1), q.stride(2), 425 | k.stride(0), k.stride(1), k.stride(2), 426 | v.stride(0), v.stride(1), v.stride(2), 427 | o.stride(0), o.stride(1), o.stride(2), 428 | stride_pk0, stride_pk1, stride_pk2, 429 | stride_pq0, stride_pq1, stride_pq2, 430 | B, H, M, N, 431 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 432 | IS_CAUSAL=causal, 433 | HAS_C2P=has_c2p, HAS_P2C=has_p2c, 434 | ATT_SPAN=ATT_SPAN, 435 | NUM_BUCKETS=position_buckets, 436 | MAX_DISTANCE=max_relative_distance, 437 | num_warps=num_warps, num_stages=num_stages, 438 | ) 439 | 440 | return o, L 441 | 442 | 443 | class FlashAttentionDisentangled(torch.autograd.Function): 444 | @staticmethod 445 | def forward(ctx, q, k, v, q_pos, k_pos, cu_seqlens_q, cu_seqlens_k, 446 | max_seqlen_q, max_seqlen_k, causal, 447 | sm_scale, position_buckets, max_relative_distance): 448 | 449 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] 450 | 451 | assert Dq == Dk == Dv 452 | 453 | BM, H, D = q.shape 454 | 455 | # Determine ATT_SPAN from pos_key: assume shape is (2*ATT_SPAN, D) 456 | if position_buckets>0: 457 | ATT_SPAN = position_buckets 458 | else: 459 | ATT_SPAN = max_relative_distance 460 | 461 | if sm_scale is None: 462 | sm_scale = 1. / math.sqrt(D) 463 | 464 | config = get_fwd_config(total_tokens=BM, 465 | max_seqlen_q=max_seqlen_q, 466 | max_seqlen_k=max_seqlen_k, 467 | D=D, 468 | causal=causal, 469 | disentangled=True, 470 | att_span=ATT_SPAN) 471 | 472 | BLOCK_M, BLOCK_N, num_stages, num_warps = config 473 | 474 | o, L = flash_attn_v2_fwd_dise(q, k, v, q_pos, k_pos, cu_seqlens_q, cu_seqlens_k, 475 | max_seqlen_q, max_seqlen_k, causal, sm_scale, 476 | BLOCK_M, BLOCK_N, position_buckets, 477 | max_relative_distance, num_warps, num_stages, ATT_SPAN=ATT_SPAN) 478 | return o 479 | 480 | @staticmethod 481 | def backward(ctx, grad_output): 482 | # Exclude backward capabilities by raising an error. 483 | raise RuntimeError("Backward pass is not implemented for FlashAttentionDisentangled") 484 | 485 | def flash_attention_with_disentangled_varlen(q, k, v, q_pos, k_pos, cu_seqlens_q, cu_seqlens_k, 486 | max_seqlen_q, max_seqlen_k, causal=False, sm_scale=None, 487 | position_buckets=0, max_relative_distance=0): 488 | """ 489 | An implementation of FlashAttention v2 with DeBERTa-style disentangled relative attention. 
490 | This version does not support backward propagation. 491 | 492 | Args: 493 | q (Tensor): Queries of shape (B, H, M, D). 494 | k (Tensor): Keys of shape (B, H, N, D). 495 | v (Tensor): Values of shape (B, H, N, D). 496 | q_pos (Tensor): Relative projection tensor for content→position bias. 497 | k_pos (Tensor): Relative projection tensor for position→content bias. 498 | causal (bool): Whether to apply causal masking. 499 | sm_scale (float): Scaling factor for softmax (if None, uses 1/sqrt(D)). 500 | position_buckets (int): Number of position buckets. 501 | max_relative_distance (int): Maximum relative distance. 502 | 503 | Returns: 504 | out (Tensor): Output tensor of shape (B, H, M, D). 505 | 506 | Note: 507 | The backward pass is not implemented, so this function only supports forward propagation. 508 | """ 509 | return FlashAttentionDisentangled.apply(q, k, v, q_pos, k_pos, cu_seqlens_q, cu_seqlens_k, 510 | max_seqlen_q, max_seqlen_k, causal, sm_scale, 511 | position_buckets, max_relative_distance) 512 | 513 | -------------------------------------------------------------------------------- /src/flashdeberta/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple, Dict 2 | import logging 3 | import math 4 | import torch 5 | from torch import nn 6 | from torch.nn import (LayerNorm, CrossEntropyLoss, MSELoss, BCEWithLogitsLoss) 7 | from transformers import PreTrainedModel, PretrainedConfig 8 | from transformers.modeling_outputs import (BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, 9 | QuestionAnsweringModelOutput, MultipleChoiceModelOutput, 10 | TokenClassifierOutput) 11 | from transformers.models.deberta_v2.modeling_deberta_v2 import (DisentangledSelfAttention, 12 | DebertaV2Attention, 13 | DebertaV2SelfOutput, 14 | DebertaV2Intermediate, 15 | DebertaV2Output, 16 | DebertaV2Layer, 17 | ConvLayer, 18 | DebertaV2Embeddings, 19 | DebertaV2Encoder, 20 | DebertaV2Config, 21 | LegacyDebertaV2OnlyMLMHead, 22 | DebertaV2OnlyMLMHead, 23 | ContextPooler, 24 | LegacyDebertaV2LMPredictionHead, 25 | DebertaV2LMPredictionHead 26 | ) 27 | from transformers.utils import logging 28 | 29 | from .padding import _upad_input, pad_input 30 | from .ops.flash_attention import flash_attention_with_disentangled 31 | from .ops.flash_attention_varlen import flash_attention_with_disentangled_varlen 32 | 33 | 34 | logger = logging.get_logger(__name__) 35 | 36 | class DebertaV2Config(PretrainedConfig): 37 | model_type = "deberta-v2" 38 | 39 | def __init__( 40 | self, 41 | vocab_size=128100, 42 | hidden_size=1536, 43 | num_hidden_layers=24, 44 | num_attention_heads=24, 45 | intermediate_size=6144, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=0, 51 | initializer_range=0.02, 52 | layer_norm_eps=1e-7, 53 | relative_attention=False, 54 | max_relative_positions=-1, 55 | pad_token_id=0, 56 | position_biased_input=True, 57 | pos_att_type=None, 58 | pooler_dropout=0, 59 | pooler_hidden_act="gelu", 60 | legacy=True, 61 | _attn_implementation_autoset = True, 62 | _attn_implementation='flash_attention_2', 63 | **kwargs, 64 | ): 65 | super().__init__(**kwargs) 66 | 67 | self.hidden_size = hidden_size 68 | self.num_hidden_layers = num_hidden_layers 69 | self.num_attention_heads = num_attention_heads 70 | self.intermediate_size = intermediate_size 71 | self.hidden_act = hidden_act 72 | self.hidden_dropout_prob = hidden_dropout_prob 73 | 
self.attention_probs_dropout_prob = attention_probs_dropout_prob
74 |         self.max_position_embeddings = max_position_embeddings
75 |         self.type_vocab_size = type_vocab_size
76 |         self.initializer_range = initializer_range
77 |         self.relative_attention = relative_attention
78 |         self.max_relative_positions = max_relative_positions
79 |         self.pad_token_id = pad_token_id
80 |         self.position_biased_input = position_biased_input
81 |         self._attn_implementation_autoset = _attn_implementation_autoset
82 |         self._attn_implementation = _attn_implementation
83 | 
84 |         # Backwards compatibility
85 |         if isinstance(pos_att_type, str):
86 |             pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
87 | 
88 |         self.pos_att_type = pos_att_type
89 |         self.vocab_size = vocab_size
90 |         self.layer_norm_eps = layer_norm_eps
91 | 
92 |         self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
93 |         self.pooler_dropout = pooler_dropout
94 |         self.pooler_hidden_act = pooler_hidden_act
95 |         self.legacy = legacy
96 | 
97 | class FlashDisentangledSelfAttention(DisentangledSelfAttention):
98 |     def __init__(self, *args, **kwargs):
99 |         super().__init__(*args, **kwargs)
100 | 
101 |     def forward(self, hidden_states,
102 |                 attention_mask,
103 |                 output_attentions=False,
104 |                 query_states=None,
105 |                 relative_pos=None,
106 |                 rel_embeddings=None):
107 |         """
108 |         Performs the flash attention forward pass with disentangled relative attention.
109 | 
110 |         Args:
111 |             hidden_states (Tensor): Input tensor of shape (B, L, hidden_size).
112 |             attention_mask (Tensor): Padding mask of shape (B, L); it also determines whether the dense or the variable-length kernel is used.
113 |             output_attentions (bool): Kept for API compatibility; attention probabilities are never materialized by the flash kernels.
114 |             query_states (Tensor, optional): If provided, used as Q instead of `hidden_states`.
115 |             relative_pos (Tensor, optional): Accepted for API compatibility; the flash path derives relative positions inside the kernel.
116 |             rel_embeddings (Tensor, optional): Relative position embedding table used to build the C2P/P2C biases.
117 | 
118 |         Returns:
119 |             Tuple[Tensor, None]: A tuple where the first element is the output tensor of shape (B, L, hidden_size)
120 |                 and the second element is always None.
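        Note:
            Which Triton kernel runs is decided from ``attention_mask`` alone. A minimal
            sketch of that check (illustrative only, mirroring the code in the method body):

            ```python
            import torch

            attention_mask = torch.tensor([[1, 1, 1, 1],
                                           [1, 1, 0, 0]])
            # Any padded position switches from the dense kernel to the
            # variable-length (unpadded) kernel.
            varlen = attention_mask.sum() != attention_mask.numel()   # True here
            ```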
121 | """ 122 | if attention_mask is not None: 123 | total_length = torch.sum(attention_mask) 124 | max_length = torch.prod(torch.tensor(attention_mask.shape).to(attention_mask.device)) 125 | if max_length==total_length: 126 | varlen = False 127 | else: 128 | varlen = True 129 | else: 130 | varlen = False 131 | 132 | if query_states is None: 133 | query_states = hidden_states 134 | 135 | B, L, _ = hidden_states.shape 136 | 137 | def transform(x, attention_heads): 138 | new_x_shape = x.size()[:-1] + (attention_heads, -1) 139 | x = x.view(new_x_shape).permute(0, 2, 1, 3).contiguous() 140 | return x 141 | 142 | def get_heads(x, attention_heads): 143 | new_x_shape = x.size()[:-1] + (attention_heads, -1) 144 | x = x.view(new_x_shape).contiguous() 145 | return x 146 | 147 | query_layer = self.query_proj(query_states) 148 | key_layer = self.key_proj(hidden_states) 149 | value_layer = self.value_proj(hidden_states) 150 | 151 | scale_factor = 1 152 | if "c2p" in self.pos_att_type: 153 | scale_factor += 1 154 | if "p2c" in self.pos_att_type: 155 | scale_factor += 1 156 | 157 | sm_scale = 1/math.sqrt(self.attention_head_size*scale_factor) 158 | 159 | if self.relative_attention: 160 | rel_embeddings = self.pos_dropout(rel_embeddings) 161 | if self.share_att_key: 162 | pos_key_layer = transform( 163 | self.key_proj(rel_embeddings.unsqueeze(0)), self.num_attention_heads 164 | ) # (1, NH, MD, head_dim) 165 | pos_query_layer = transform( 166 | self.query_proj(rel_embeddings.unsqueeze(0)), self.num_attention_heads 167 | ) 168 | else: 169 | if "c2p" in self.pos_att_type: 170 | pos_key_layer = transform( 171 | self.pos_key_proj(rel_embeddings.unsqueeze(0)), self.num_attention_heads 172 | ) 173 | 174 | if "p2c" in self.pos_att_type: 175 | pos_query_layer = transform( 176 | self.pos_query_proj(rel_embeddings.unsqueeze(0)), self.num_attention_heads 177 | ) 178 | pos_key = None 179 | pos_query = None 180 | else: 181 | pos_key, pos_query = None, None 182 | 183 | causal = False 184 | if not varlen: 185 | query_layer = transform(query_layer, self.num_attention_heads) # (B, NH, L, head_dim) 186 | key_layer = transform(key_layer, self.num_attention_heads) 187 | value_layer = transform(value_layer, self.num_attention_heads) 188 | 189 | if "c2p" in self.pos_att_type: 190 | pos_key = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) 191 | if "p2c" in self.pos_att_type: 192 | pos_query = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) 193 | 194 | out = flash_attention_with_disentangled( 195 | query_layer, 196 | key_layer, 197 | value_layer, 198 | pos_key, 199 | pos_query, 200 | causal, 201 | sm_scale, 202 | self.position_buckets, 203 | self.max_relative_positions, 204 | ) 205 | else: 206 | query_layer = get_heads(query_layer, self.num_attention_heads) # (B, L, NH, head_dim) 207 | key_layer = get_heads(key_layer, self.num_attention_heads) 208 | value_layer = get_heads(value_layer, self.num_attention_heads) 209 | 210 | if "c2p" in self.pos_att_type: 211 | # query_layer = (1, NH, L, head_dim) 212 | pos_key = torch.einsum("bqhd,zhmd->bqhm", query_layer, pos_key_layer) 213 | if "p2c" in self.pos_att_type: 214 | pos_query = torch.einsum("bqhd,zhmd->bqhm", key_layer, pos_query_layer) 215 | 216 | (query_layer, 217 | key_layer, 218 | value_layer, 219 | pos_key, 220 | pos_query, 221 | indices_q, 222 | cu_seq_lens, 223 | max_seq_lens) = _upad_input( 224 | query_layer, 225 | key_layer, 226 | value_layer, 227 | pos_key, 228 | pos_query, 229 | attention_mask, 230 | L, 231 | self.num_attention_heads) 232 | 233 | 
cu_seqlens_q, cu_seqlens_k = cu_seq_lens 234 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 235 | 236 | out_unpad = flash_attention_with_disentangled_varlen( 237 | query_layer, 238 | key_layer, 239 | value_layer, 240 | pos_key, 241 | pos_query, 242 | cu_seqlens_q, cu_seqlens_k, 243 | max_seqlen_in_batch_q, max_seqlen_in_batch_k, 244 | causal, 245 | sm_scale, 246 | self.position_buckets, 247 | self.max_relative_positions, 248 | ) 249 | out = pad_input(out_unpad, indices_q, B, L).transpose(1, 2) 250 | 251 | out = out.view(B, self.num_attention_heads, L, self.attention_head_size).transpose(1, 2).reshape(B, L, self.all_head_size) 252 | return (out, None) 253 | 254 | DEBERTA_SELF_ATTENTION_CLASSES = { 255 | "eager": DisentangledSelfAttention, 256 | "flash_attention_2": FlashDisentangledSelfAttention, 257 | } 258 | 259 | class FlashDebertaV2Attention(DebertaV2Attention): 260 | def __init__(self, config): 261 | super().__init__(config) 262 | self.self = DEBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](config) 263 | self.output = DebertaV2SelfOutput(config) 264 | self.config = config 265 | 266 | 267 | class FlashDebertaV2Layer(DebertaV2Layer): 268 | def __init__(self, config): 269 | super().__init__(config) 270 | self.attention = FlashDebertaV2Attention(config) 271 | self.intermediate = DebertaV2Intermediate(config) 272 | self.output = DebertaV2Output(config) 273 | 274 | 275 | class FlashDebertaV2Encoder(DebertaV2Encoder): 276 | def __init__(self, config): 277 | super().__init__(config) 278 | 279 | self.layer = nn.ModuleList([FlashDebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) 280 | self.relative_attention = getattr(config, "relative_attention", False) 281 | 282 | if self.relative_attention: 283 | self.max_relative_positions = getattr(config, "max_relative_positions", -1) 284 | if self.max_relative_positions < 1: 285 | self.max_relative_positions = config.max_position_embeddings 286 | 287 | self.position_buckets = getattr(config, "position_buckets", -1) 288 | pos_ebd_size = self.max_relative_positions * 2 289 | 290 | if self.position_buckets > 0: 291 | pos_ebd_size = self.position_buckets * 2 292 | 293 | self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) 294 | 295 | self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] 296 | 297 | if "layer_norm" in self.norm_rel_ebd: 298 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) 299 | 300 | self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None 301 | self.gradient_checkpointing = False 302 | 303 | def get_attention_mask(self, attention_mask): 304 | if attention_mask.dim() <= 2: 305 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 306 | attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) 307 | elif attention_mask.dim() == 3: 308 | attention_mask = attention_mask.unsqueeze(1) 309 | 310 | return attention_mask 311 | 312 | class FlashDebertaV2PreTrainedModel(PreTrainedModel): 313 | """ 314 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 315 | models. 
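    The attention backend is resolved by `_autoset_attn_implementation` below. A minimal
    sketch of overriding it explicitly at load time (assumes the public
    `microsoft/deberta-v3-base` checkpoint; `attn_implementation` is a standard
    `from_pretrained` argument in recent transformers releases):

    ```python
    from flashdeberta import FlashDebertaV2Model

    # An explicit value takes priority over the automatic selection.
    model = FlashDebertaV2Model.from_pretrained(
        "microsoft/deberta-v3-base",
        attn_implementation="eager",   # or "flash_attention_2" on CUDA
    )
    ```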
316 | """ 317 | config_class = DebertaV2Config 318 | base_model_prefix = "deberta" 319 | _keys_to_ignore_on_load_unexpected = ["position_embeddings"] 320 | supports_gradient_checkpointing = True 321 | _supports_flash_attn_2 = True 322 | 323 | def _init_weights(self, module): 324 | """Initialize the weights.""" 325 | if isinstance(module, nn.Linear): 326 | # Slightly different from the TF version which uses truncated_normal for initialization 327 | # cf https://github.com/pytorch/pytorch/pull/5617 328 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 329 | if module.bias is not None: 330 | module.bias.data.zero_() 331 | elif isinstance(module, nn.Embedding): 332 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 333 | if module.padding_idx is not None: 334 | module.weight.data[module.padding_idx].zero_() 335 | elif isinstance(module, nn.LayerNorm): 336 | module.weight.data.fill_(1.0) 337 | module.bias.data.zero_() 338 | elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)): 339 | module.bias.data.zero_() 340 | 341 | @classmethod 342 | def _autoset_attn_implementation( 343 | cls, 344 | config, 345 | use_flash_attention_2: bool = False, 346 | torch_dtype: Optional[torch.dtype] = None, 347 | device_map: Optional[Union[str, Dict[str, int]]] = None, 348 | check_device_map: bool = True, 349 | ): 350 | """ 351 | Decide which attention backend to use. 352 | Priority 353 | -------- 354 | 1. Respect an explicit value already sitting in `config._attn_implementation` 355 | (e.g. user passed `attn_implementation="sdpa"` to `from_pretrained`). 356 | 2. If `use_flash_attention_2=True` **and** a compatible GPU, dtype, and 357 | flash-attn-2 kernels are available → choose `"flash_attention_2"`. 358 | 3. If PyTorch’s scaled-dot-product attention is available → `"sdpa"`. 359 | 4. Otherwise → `"eager"`. 360 | """ 361 | if getattr(config, "_attn_implementation", None) and not getattr( 362 | config, "_attn_implementation_autoset", False 363 | ): 364 | return config 365 | 366 | torch_dtype = torch_dtype or getattr(config, "torch_dtype", None) 367 | 368 | on_cuda = ( 369 | torch.cuda.is_available() 370 | and (device_map is None or (isinstance(device_map, str) and device_map != "cpu")) 371 | ) 372 | 373 | if ( 374 | use_flash_attention_2 375 | and cls._supports_flash_attn_2 376 | and on_cuda 377 | and torch_dtype in (None, torch.float16, torch.bfloat16) 378 | ): 379 | config._attn_implementation = "flash_attention_2" 380 | config._flash_attn_2_enabled = True 381 | config._attn_implementation_autoset = True 382 | return config 383 | 384 | if check_device_map and not on_cuda and torch.cuda.is_available(): 385 | logger.warning_once( 386 | "FlashDeBERTa is being initialised on CPU. Move the model to GPU " 387 | "with `model.to('cuda')` to benefit from Flash-Attention/SDPA." 
388 | ) 389 | 390 | config._attn_implementation = "eager" 391 | config._attn_implementation_autoset = True 392 | return config 393 | 394 | class FlashDebertaV2Model(FlashDebertaV2PreTrainedModel): 395 | def __init__(self, config): 396 | super().__init__(config) 397 | 398 | self.embeddings = DebertaV2Embeddings(config) 399 | self.encoder = FlashDebertaV2Encoder(config) 400 | self.z_steps = 0 401 | self.config = config 402 | # Initialize weights and apply final processing 403 | self.post_init() 404 | 405 | def get_input_embeddings(self): 406 | return self.embeddings.word_embeddings 407 | 408 | def set_input_embeddings(self, new_embeddings): 409 | self.embeddings.word_embeddings = new_embeddings 410 | 411 | def _prune_heads(self, heads_to_prune): 412 | """ 413 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 414 | class PreTrainedModel 415 | """ 416 | raise NotImplementedError("The prune function is not implemented in DeBERTa model.") 417 | 418 | def forward( 419 | self, 420 | input_ids: Optional[torch.Tensor] = None, 421 | attention_mask: Optional[torch.Tensor] = None, 422 | token_type_ids: Optional[torch.Tensor] = None, 423 | position_ids: Optional[torch.Tensor] = None, 424 | inputs_embeds: Optional[torch.Tensor] = None, 425 | output_attentions: Optional[bool] = None, 426 | output_hidden_states: Optional[bool] = None, 427 | return_dict: Optional[bool] = None, 428 | ) -> Union[Tuple, BaseModelOutput]: 429 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 430 | output_hidden_states = ( 431 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 432 | ) 433 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 434 | 435 | if input_ids is not None and inputs_embeds is not None: 436 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 437 | elif input_ids is not None: 438 | self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) 439 | input_shape = input_ids.size() 440 | elif inputs_embeds is not None: 441 | input_shape = inputs_embeds.size()[:-1] 442 | else: 443 | raise ValueError("You have to specify either input_ids or inputs_embeds") 444 | 445 | device = input_ids.device if input_ids is not None else inputs_embeds.device 446 | 447 | if attention_mask is None: 448 | attention_mask = torch.ones(input_shape, device=device) 449 | if token_type_ids is None: 450 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 451 | 452 | embedding_output = self.embeddings( 453 | input_ids=input_ids, 454 | token_type_ids=token_type_ids, 455 | position_ids=position_ids, 456 | mask=attention_mask, 457 | inputs_embeds=inputs_embeds, 458 | ) 459 | 460 | encoder_outputs = self.encoder( 461 | embedding_output, 462 | attention_mask, 463 | output_hidden_states=True, 464 | output_attentions=output_attentions, 465 | return_dict=return_dict, 466 | ) 467 | encoded_layers = encoder_outputs[1] 468 | 469 | if self.z_steps > 1: 470 | hidden_states = encoded_layers[-2] 471 | layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] 472 | query_states = encoded_layers[-1] 473 | rel_embeddings = self.encoder.get_rel_embedding() 474 | attention_mask = self.encoder.get_attention_mask(attention_mask) 475 | rel_pos = self.encoder.get_rel_pos(embedding_output) 476 | for layer in layers[1:]: 477 | query_states = layer( 478 | hidden_states, 479 | attention_mask, 
480 | output_attentions=False, 481 | query_states=query_states, 482 | relative_pos=rel_pos, 483 | rel_embeddings=rel_embeddings, 484 | ) 485 | encoded_layers.append(query_states) 486 | 487 | sequence_output = encoded_layers[-1] 488 | 489 | if not return_dict: 490 | return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] 491 | 492 | return BaseModelOutput( 493 | last_hidden_state=sequence_output, 494 | hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, 495 | attentions=encoder_outputs.attentions, 496 | ) 497 | 498 | class FlashDebertaV2ForMaskedLM(FlashDebertaV2PreTrainedModel): 499 | _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] 500 | _keys_to_ignore_on_load_unexpected = r"mask_predictions.*" 501 | 502 | def __init__(self, config): 503 | super().__init__(config) 504 | self.legacy = config.legacy 505 | self.deberta = FlashDebertaV2Model(config) 506 | if self.legacy: 507 | self.cls = LegacyDebertaV2OnlyMLMHead(config) 508 | else: 509 | self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"] 510 | self.lm_predictions = DebertaV2OnlyMLMHead(config) 511 | # Initialize weights and apply final processing 512 | self.post_init() 513 | 514 | def get_output_embeddings(self): 515 | if self.legacy: 516 | return self.cls.predictions.decoder 517 | else: 518 | return self.lm_predictions.lm_head.dense 519 | 520 | def set_output_embeddings(self, new_embeddings): 521 | if self.legacy: 522 | self.cls.predictions.decoder = new_embeddings 523 | self.cls.predictions.bias = new_embeddings.bias 524 | else: 525 | self.lm_predictions.lm_head.dense = new_embeddings 526 | self.lm_predictions.lm_head.bias = new_embeddings.bias 527 | 528 | # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2 529 | def forward( 530 | self, 531 | input_ids: Optional[torch.Tensor] = None, 532 | attention_mask: Optional[torch.Tensor] = None, 533 | token_type_ids: Optional[torch.Tensor] = None, 534 | position_ids: Optional[torch.Tensor] = None, 535 | inputs_embeds: Optional[torch.Tensor] = None, 536 | labels: Optional[torch.Tensor] = None, 537 | output_attentions: Optional[bool] = None, 538 | output_hidden_states: Optional[bool] = None, 539 | return_dict: Optional[bool] = None, 540 | ) -> Union[Tuple, MaskedLMOutput]: 541 | r""" 542 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 543 | Labels for computing the masked language modeling loss. 
Indices should be in `[-100, 0, ..., 544 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 545 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 546 | """ 547 | 548 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 549 | 550 | outputs = self.deberta( 551 | input_ids, 552 | attention_mask=attention_mask, 553 | token_type_ids=token_type_ids, 554 | position_ids=position_ids, 555 | inputs_embeds=inputs_embeds, 556 | output_attentions=output_attentions, 557 | output_hidden_states=output_hidden_states, 558 | return_dict=return_dict, 559 | ) 560 | 561 | sequence_output = outputs[0] 562 | if self.legacy: 563 | prediction_scores = self.cls(sequence_output) 564 | else: 565 | prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings) 566 | 567 | masked_lm_loss = None 568 | if labels is not None: 569 | loss_fct = CrossEntropyLoss() # -100 index = padding token 570 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 571 | 572 | if not return_dict: 573 | output = (prediction_scores,) + outputs[1:] 574 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 575 | 576 | return MaskedLMOutput( 577 | loss=masked_lm_loss, 578 | logits=prediction_scores, 579 | hidden_states=outputs.hidden_states, 580 | attentions=outputs.attentions, 581 | ) 582 | 583 | 584 | class FlashDebertaV2ForSequenceClassification(FlashDebertaV2PreTrainedModel): 585 | def __init__(self, config): 586 | super().__init__(config) 587 | 588 | num_labels = getattr(config, "num_labels", 2) 589 | self.num_labels = num_labels 590 | 591 | self.deberta = FlashDebertaV2Model(config) 592 | self.pooler = ContextPooler(config) 593 | output_dim = self.pooler.output_dim 594 | 595 | self.classifier = nn.Linear(output_dim, num_labels) 596 | drop_out = getattr(config, "cls_dropout", None) 597 | drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out 598 | self.dropout = nn.Dropout(drop_out) 599 | 600 | # Initialize weights and apply final processing 601 | self.post_init() 602 | 603 | def get_input_embeddings(self): 604 | return self.deberta.get_input_embeddings() 605 | 606 | def set_input_embeddings(self, new_embeddings): 607 | self.deberta.set_input_embeddings(new_embeddings) 608 | 609 | # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2 610 | def forward( 611 | self, 612 | input_ids: Optional[torch.Tensor] = None, 613 | attention_mask: Optional[torch.Tensor] = None, 614 | token_type_ids: Optional[torch.Tensor] = None, 615 | position_ids: Optional[torch.Tensor] = None, 616 | inputs_embeds: Optional[torch.Tensor] = None, 617 | labels: Optional[torch.Tensor] = None, 618 | output_attentions: Optional[bool] = None, 619 | output_hidden_states: Optional[bool] = None, 620 | return_dict: Optional[bool] = None, 621 | ) -> Union[Tuple, SequenceClassifierOutput]: 622 | r""" 623 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 624 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 625 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 626 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
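        Example (illustrative sketch, not a tested snippet; assumes the public
        `microsoft/deberta-v3-base` checkpoint and an arbitrary two-class task):

        ```python
        import torch
        from transformers import AutoTokenizer
        from flashdeberta import FlashDebertaV2ForSequenceClassification

        tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
        model = FlashDebertaV2ForSequenceClassification.from_pretrained(
            "microsoft/deberta-v3-base", num_labels=2
        )

        inputs = tokenizer("FlashDeBERTa handles long inputs efficiently.", return_tensors="pt")
        outputs = model(**inputs, labels=torch.tensor([1]))
        loss, logits = outputs.loss, outputs.logits   # logits: (batch_size, num_labels)
        ```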
627 | """ 628 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 629 | 630 | outputs = self.deberta( 631 | input_ids, 632 | token_type_ids=token_type_ids, 633 | attention_mask=attention_mask, 634 | position_ids=position_ids, 635 | inputs_embeds=inputs_embeds, 636 | output_attentions=output_attentions, 637 | output_hidden_states=output_hidden_states, 638 | return_dict=return_dict, 639 | ) 640 | 641 | encoder_layer = outputs[0] 642 | pooled_output = self.pooler(encoder_layer) 643 | pooled_output = self.dropout(pooled_output) 644 | logits = self.classifier(pooled_output) 645 | 646 | loss = None 647 | if labels is not None: 648 | if self.config.problem_type is None: 649 | if self.num_labels == 1: 650 | # regression task 651 | loss_fn = nn.MSELoss() 652 | logits = logits.view(-1).to(labels.dtype) 653 | loss = loss_fn(logits, labels.view(-1)) 654 | elif labels.dim() == 1 or labels.size(-1) == 1: 655 | label_index = (labels >= 0).nonzero() 656 | labels = labels.long() 657 | if label_index.size(0) > 0: 658 | labeled_logits = torch.gather( 659 | logits, 0, label_index.expand(label_index.size(0), logits.size(1)) 660 | ) 661 | labels = torch.gather(labels, 0, label_index.view(-1)) 662 | loss_fct = CrossEntropyLoss() 663 | loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) 664 | else: 665 | loss = torch.tensor(0).to(logits) 666 | else: 667 | log_softmax = nn.LogSoftmax(-1) 668 | loss = -((log_softmax(logits) * labels).sum(-1)).mean() 669 | elif self.config.problem_type == "regression": 670 | loss_fct = MSELoss() 671 | if self.num_labels == 1: 672 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 673 | else: 674 | loss = loss_fct(logits, labels) 675 | elif self.config.problem_type == "single_label_classification": 676 | loss_fct = CrossEntropyLoss() 677 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 678 | elif self.config.problem_type == "multi_label_classification": 679 | loss_fct = BCEWithLogitsLoss() 680 | loss = loss_fct(logits, labels) 681 | if not return_dict: 682 | output = (logits,) + outputs[1:] 683 | return ((loss,) + output) if loss is not None else output 684 | 685 | return SequenceClassifierOutput( 686 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions 687 | ) 688 | 689 | 690 | class FlashDebertaV2ForTokenClassification(FlashDebertaV2PreTrainedModel): 691 | def __init__(self, config): 692 | super().__init__(config) 693 | self.num_labels = config.num_labels 694 | 695 | self.deberta = FlashDebertaV2Model(config) 696 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 697 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 698 | 699 | # Initialize weights and apply final processing 700 | self.post_init() 701 | 702 | def forward( 703 | self, 704 | input_ids: Optional[torch.Tensor] = None, 705 | attention_mask: Optional[torch.Tensor] = None, 706 | token_type_ids: Optional[torch.Tensor] = None, 707 | position_ids: Optional[torch.Tensor] = None, 708 | inputs_embeds: Optional[torch.Tensor] = None, 709 | labels: Optional[torch.Tensor] = None, 710 | output_attentions: Optional[bool] = None, 711 | output_hidden_states: Optional[bool] = None, 712 | return_dict: Optional[bool] = None, 713 | ) -> Union[Tuple, TokenClassifierOutput]: 714 | r""" 715 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 716 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
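        Example (illustrative sketch along the same lines as the sequence-classification
        head; the checkpoint name and label count are assumptions):

        ```python
        import torch
        from transformers import AutoTokenizer
        from flashdeberta import FlashDebertaV2ForTokenClassification

        tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
        model = FlashDebertaV2ForTokenClassification.from_pretrained(
            "microsoft/deberta-v3-base", num_labels=5
        )

        inputs = tokenizer("Kyiv is the capital of Ukraine.", return_tensors="pt")
        labels = torch.zeros_like(inputs["input_ids"])   # one label id per token
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs.loss, outputs.logits      # logits: (batch_size, seq_len, num_labels)
        ```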
717 | """ 718 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 719 | 720 | outputs = self.deberta( 721 | input_ids, 722 | attention_mask=attention_mask, 723 | token_type_ids=token_type_ids, 724 | position_ids=position_ids, 725 | inputs_embeds=inputs_embeds, 726 | output_attentions=output_attentions, 727 | output_hidden_states=output_hidden_states, 728 | return_dict=return_dict, 729 | ) 730 | 731 | sequence_output = outputs[0] 732 | 733 | sequence_output = self.dropout(sequence_output) 734 | logits = self.classifier(sequence_output) 735 | 736 | loss = None 737 | if labels is not None: 738 | loss_fct = CrossEntropyLoss() 739 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 740 | 741 | if not return_dict: 742 | output = (logits,) + outputs[1:] 743 | return ((loss,) + output) if loss is not None else output 744 | 745 | return TokenClassifierOutput( 746 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions 747 | ) 748 | 749 | class FlashDebertaV2ForQuestionAnswering(FlashDebertaV2PreTrainedModel): 750 | def __init__(self, config): 751 | super().__init__(config) 752 | self.num_labels = config.num_labels 753 | 754 | self.deberta = FlashDebertaV2Model(config) 755 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 756 | 757 | # Initialize weights and apply final processing 758 | self.post_init() 759 | 760 | # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2 761 | def forward( 762 | self, 763 | input_ids: Optional[torch.Tensor] = None, 764 | attention_mask: Optional[torch.Tensor] = None, 765 | token_type_ids: Optional[torch.Tensor] = None, 766 | position_ids: Optional[torch.Tensor] = None, 767 | inputs_embeds: Optional[torch.Tensor] = None, 768 | start_positions: Optional[torch.Tensor] = None, 769 | end_positions: Optional[torch.Tensor] = None, 770 | output_attentions: Optional[bool] = None, 771 | output_hidden_states: Optional[bool] = None, 772 | return_dict: Optional[bool] = None, 773 | ) -> Union[Tuple, QuestionAnsweringModelOutput]: 774 | r""" 775 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 776 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 777 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 778 | are not taken into account for computing the loss. 779 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 780 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 781 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 782 | are not taken into account for computing the loss. 
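        Example (illustrative sketch; the checkpoint name and the span indices are
        assumptions chosen only to show the call signature):

        ```python
        import torch
        from transformers import AutoTokenizer
        from flashdeberta import FlashDebertaV2ForQuestionAnswering

        tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
        model = FlashDebertaV2ForQuestionAnswering.from_pretrained("microsoft/deberta-v3-base")

        question = "What does FlashDeBERTa optimize?"
        context = "FlashDeBERTa optimizes disentangled attention with Triton kernels."
        inputs = tokenizer(question, context, return_tensors="pt")

        outputs = model(
            **inputs,
            start_positions=torch.tensor([1]),
            end_positions=torch.tensor([3]),
        )
        loss = outputs.loss   # mean of the start and end cross-entropy terms
        start_logits, end_logits = outputs.start_logits, outputs.end_logits
        ```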
783 | """ 784 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 785 | 786 | outputs = self.deberta( 787 | input_ids, 788 | attention_mask=attention_mask, 789 | token_type_ids=token_type_ids, 790 | position_ids=position_ids, 791 | inputs_embeds=inputs_embeds, 792 | output_attentions=output_attentions, 793 | output_hidden_states=output_hidden_states, 794 | return_dict=return_dict, 795 | ) 796 | 797 | sequence_output = outputs[0] 798 | 799 | logits = self.qa_outputs(sequence_output) 800 | start_logits, end_logits = logits.split(1, dim=-1) 801 | start_logits = start_logits.squeeze(-1).contiguous() 802 | end_logits = end_logits.squeeze(-1).contiguous() 803 | 804 | total_loss = None 805 | if start_positions is not None and end_positions is not None: 806 | # If we are on multi-GPU, split add a dimension 807 | if len(start_positions.size()) > 1: 808 | start_positions = start_positions.squeeze(-1) 809 | if len(end_positions.size()) > 1: 810 | end_positions = end_positions.squeeze(-1) 811 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 812 | ignored_index = start_logits.size(1) 813 | start_positions = start_positions.clamp(0, ignored_index) 814 | end_positions = end_positions.clamp(0, ignored_index) 815 | 816 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 817 | start_loss = loss_fct(start_logits, start_positions) 818 | end_loss = loss_fct(end_logits, end_positions) 819 | total_loss = (start_loss + end_loss) / 2 820 | 821 | if not return_dict: 822 | output = (start_logits, end_logits) + outputs[1:] 823 | return ((total_loss,) + output) if total_loss is not None else output 824 | 825 | return QuestionAnsweringModelOutput( 826 | loss=total_loss, 827 | start_logits=start_logits, 828 | end_logits=end_logits, 829 | hidden_states=outputs.hidden_states, 830 | attentions=outputs.attentions, 831 | ) 832 | 833 | 834 | class FlashDebertaV2ForMultipleChoice(FlashDebertaV2PreTrainedModel): 835 | def __init__(self, config): 836 | super().__init__(config) 837 | 838 | num_labels = getattr(config, "num_labels", 2) 839 | self.num_labels = num_labels 840 | 841 | self.deberta = FlashDebertaV2Model(config) 842 | self.pooler = ContextPooler(config) 843 | output_dim = self.pooler.output_dim 844 | 845 | self.classifier = nn.Linear(output_dim, 1) 846 | drop_out = getattr(config, "cls_dropout", None) 847 | drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out 848 | self.dropout = nn.Dropout(drop_out) 849 | 850 | self.init_weights() 851 | 852 | def get_input_embeddings(self): 853 | return self.deberta.get_input_embeddings() 854 | 855 | def set_input_embeddings(self, new_embeddings): 856 | self.deberta.set_input_embeddings(new_embeddings) 857 | 858 | def forward( 859 | self, 860 | input_ids: Optional[torch.Tensor] = None, 861 | attention_mask: Optional[torch.Tensor] = None, 862 | token_type_ids: Optional[torch.Tensor] = None, 863 | position_ids: Optional[torch.Tensor] = None, 864 | inputs_embeds: Optional[torch.Tensor] = None, 865 | labels: Optional[torch.Tensor] = None, 866 | output_attentions: Optional[bool] = None, 867 | output_hidden_states: Optional[bool] = None, 868 | return_dict: Optional[bool] = None, 869 | ) -> Union[Tuple, MultipleChoiceModelOutput]: 870 | r""" 871 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 872 | Labels for computing the multiple choice classification loss. 
Indices should be in `[0, ..., 873 | num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See 874 | `input_ids` above) 875 | """ 876 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 877 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 878 | 879 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 880 | flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 881 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 882 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 883 | flat_inputs_embeds = ( 884 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 885 | if inputs_embeds is not None 886 | else None 887 | ) 888 | 889 | outputs = self.deberta( 890 | flat_input_ids, 891 | position_ids=flat_position_ids, 892 | token_type_ids=flat_token_type_ids, 893 | attention_mask=flat_attention_mask, 894 | inputs_embeds=flat_inputs_embeds, 895 | output_attentions=output_attentions, 896 | output_hidden_states=output_hidden_states, 897 | return_dict=return_dict, 898 | ) 899 | 900 | encoder_layer = outputs[0] 901 | pooled_output = self.pooler(encoder_layer) 902 | pooled_output = self.dropout(pooled_output) 903 | logits = self.classifier(pooled_output) 904 | reshaped_logits = logits.view(-1, num_choices) 905 | 906 | loss = None 907 | if labels is not None: 908 | loss_fct = CrossEntropyLoss() 909 | loss = loss_fct(reshaped_logits, labels) 910 | 911 | if not return_dict: 912 | output = (reshaped_logits,) + outputs[1:] 913 | return ((loss,) + output) if loss is not None else output 914 | 915 | return MultipleChoiceModelOutput( 916 | loss=loss, 917 | logits=reshaped_logits, 918 | hidden_states=outputs.hidden_states, 919 | attentions=outputs.attentions, 920 | ) 921 | -------------------------------------------------------------------------------- /src/flashdeberta/ops/flash_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import triton 4 | import warnings 5 | import triton.language as tl 6 | 7 | def calculate_shared_memory_usage(BLOCK_M, BLOCK_N, BLOCK_DMODEL, num_stages, dtype, 8 | has_c2p=False, has_p2c=False, ATT_SPAN=0): 9 | """ 10 | Calculate the shared memory requirements for Flash Attention with disentangled attention. 11 | 12 | Args: 13 | BLOCK_M: Block size for query sequence dimension 14 | BLOCK_N: Block size for key sequence dimension 15 | BLOCK_DMODEL: Head dimension size 16 | num_stages: Number of pipeline stages 17 | dtype: Data type (torch.float16, torch.float32, etc.) 
18 | has_c2p: Whether content-to-position bias is used 19 | has_p2c: Whether position-to-content bias is used 20 | ATT_SPAN: Attention span for relative position 21 | 22 | Returns: 23 | The estimated shared memory usage in bytes 24 | """ 25 | # Determine byte size based on data type 26 | if dtype == torch.float16: 27 | dtype_size = 2 28 | elif dtype == torch.float32: 29 | dtype_size = 4 30 | else: 31 | dtype_size = 2 # Default to float16 size for other types 32 | 33 | # Core tensors that are always used 34 | q_size = BLOCK_M * BLOCK_DMODEL * dtype_size 35 | k_size = BLOCK_N * BLOCK_DMODEL * dtype_size 36 | v_size = BLOCK_N * BLOCK_DMODEL * dtype_size 37 | 38 | # Memory for attention scores and accumulator 39 | attn_matrix_size = BLOCK_M * BLOCK_N * dtype_size 40 | accumulator_size = BLOCK_M * BLOCK_DMODEL * dtype_size 41 | 42 | # Position embedding memory if needed 43 | pos_memory = 0 44 | if has_c2p: 45 | pos_memory += BLOCK_M * 2 * ATT_SPAN * dtype_size 46 | if has_p2c: 47 | pos_memory += BLOCK_N * 2 * ATT_SPAN * dtype_size 48 | 49 | # Additional buffers for intermediate calculations 50 | # This includes arrays for relative positions, bucket indices, etc. 51 | additional_buffers = BLOCK_M * BLOCK_N * 4 # For relative position indices, floating-point calculations 52 | 53 | # Total memory per stage 54 | memory_per_stage = q_size + k_size + v_size + attn_matrix_size + pos_memory + additional_buffers 55 | 56 | # Total shared memory including all pipeline stages 57 | total_shared_memory = num_stages * memory_per_stage + accumulator_size 58 | 59 | return total_shared_memory // 2 60 | 61 | 62 | def calculate_shared_memory_usage_bwd( 63 | BLOCK_M, 64 | BLOCK_N, 65 | BLOCK_DMODEL, 66 | num_stages, 67 | dtype=torch.float16, 68 | *, 69 | has_c2p=False, 70 | has_p2c=False, 71 | ATT_SPAN=0, 72 | store_lse=True, 73 | recompute_probs=True, 74 | accum_dq=True, 75 | accum_dkv=True, 76 | ): 77 | """ 78 | Rough shared-memory estimator for FlashAttention v2 backward with optional 79 | DeBERTa-style disentangled biases. 80 | 81 | We count per-stage tiles: 82 | - q, k, v tiles 83 | - o, do tiles (needed to form dP) 84 | - probs/dP tile (M x N) if recompute_probs 85 | - small per-row LSE/m buffers 86 | - partial accumulators for dq, dk, dv 87 | - positional tiles for c2p / p2c (approx., span-limited) 88 | 89 | Notes: 90 | - This is an upper-bound-ish estimate; kernels keep some values in regs. 91 | - LSE often stored in fp32 on load; counted in shared for safety. 
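    Example (illustrative; the returned figure is a heuristic byte estimate, not a
    measured value, and the tile sizes below are arbitrary):

    ```python
    import torch

    est = calculate_shared_memory_usage_bwd(
        BLOCK_M=64, BLOCK_N=64, BLOCK_DMODEL=64, num_stages=2,
        dtype=torch.float16,
        has_c2p=True, has_p2c=True, ATT_SPAN=256,
    )
    # Compare the estimate against the device's shared_memory_per_block_optin
    # before committing to a tile configuration.
    print(est)
    ```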
92 | """ 93 | # dtype sizes 94 | if dtype == torch.float16: 95 | t_sz = 2 96 | elif dtype == torch.bfloat16: 97 | t_sz = 2 98 | elif dtype == torch.float32: 99 | t_sz = 4 100 | else: 101 | t_sz = 2 102 | 103 | # Core tiles 104 | q_size = BLOCK_M * BLOCK_DMODEL * t_sz 105 | k_size = BLOCK_N * BLOCK_DMODEL * t_sz 106 | v_size = BLOCK_N * BLOCK_DMODEL * t_sz 107 | o_size = BLOCK_M * BLOCK_DMODEL * t_sz 108 | do_size = BLOCK_M * BLOCK_DMODEL * t_sz 109 | 110 | # Recomputed probabilities / dP tile 111 | probs_size = BLOCK_M * BLOCK_N * (t_sz if recompute_probs else 0) 112 | 113 | # Per-row softmax bookkeeping (m, lse), count as fp32 for safety 114 | lse_size = BLOCK_M * 2 * 4 if store_lse else BLOCK_M * 2 * 4 # still bring rows 115 | 116 | # Partial accumulators 117 | dq_acc = BLOCK_M * BLOCK_DMODEL * t_sz if accum_dq else 0 118 | dk_acc = BLOCK_N * BLOCK_DMODEL * t_sz if accum_dkv else 0 119 | dv_acc = BLOCK_N * BLOCK_DMODEL * t_sz if accum_dkv else 0 120 | 121 | # Positional memory (heuristic upper bound) 122 | pos_mem = 0 123 | if has_c2p: 124 | pos_mem += BLOCK_M * 2 * ATT_SPAN * t_sz 125 | if has_p2c: 126 | pos_mem += BLOCK_N * 2 * ATT_SPAN * t_sz 127 | 128 | # Misc scratch (indices, masks, scales); keep small constant per tile 129 | misc = 8 * 1024 # 8KB safety pad 130 | 131 | per_stage = (q_size + k_size + v_size + 132 | o_size + do_size + 133 | probs_size + lse_size + 134 | dq_acc + dk_acc + dv_acc + 135 | pos_mem + misc) 136 | 137 | # Double buffering across stages (FA kernels pipeline tiles) 138 | total = num_stages * per_stage 139 | 140 | # Many kernels reuse a half-buffer scheme; keep headroom 141 | return total//2 142 | 143 | def cdiv(a, b): 144 | return (a + b - 1) // b 145 | 146 | @triton.jit 147 | def _fwd_kernel_deberta_disentangled_attention( 148 | Q, K, V, 149 | K_POS, Q_POS, 150 | L, O, 151 | sm_scale, 152 | stride_qz, stride_qh, stride_qm, stride_qk, 153 | stride_kz, stride_kh, stride_kn, stride_kk, 154 | stride_vz, stride_vh, stride_vn, stride_vk, 155 | stride_oz, stride_oh, stride_om, stride_ok, 156 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 157 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 158 | Z, H, M, N, P_SEQ, 159 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 160 | IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, 161 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 162 | HAS_C2P: tl.constexpr, HAS_P2C: tl.constexpr, 163 | ATT_SPAN: tl.constexpr, 164 | NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr 165 | ): 166 | input_dtype = Q.dtype.element_ty 167 | 168 | start_m = tl.program_id(0) 169 | off_h = tl.program_id(1) 170 | off_z = tl.program_id(2) 171 | 172 | log2e: tl.constexpr = 1.4426950408889634 173 | 174 | Q += off_z * stride_qz + off_h * stride_qh 175 | K += off_z * stride_kz + off_h * stride_kh 176 | V += off_z * stride_vz + off_h * stride_vh 177 | O += off_z * stride_oz + off_h * stride_oh 178 | L += (off_z * H + off_h) * M # L is of shape (B*H, M) 179 | 180 | if HAS_C2P: 181 | K_POS += off_z*stride_pk0 + off_h*stride_pk1 182 | if HAS_P2C: 183 | Q_POS += off_z*stride_pq0 + off_h*stride_pq1 184 | 185 | offs_m_base = tl.arange(0, BLOCK_M) 186 | offs_m = start_m * BLOCK_M + offs_m_base 187 | offs_n_base = tl.arange(0, BLOCK_N) 188 | offs_k = tl.arange(0, BLOCK_DMODEL) 189 | 190 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) 191 | o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) 192 | l_ptrs = L + offs_m 193 | 
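    # The query tile is loaded once below and reused against every K/V tile.
    # The main loop then runs the standard FlashAttention online softmax:
    # m_i tracks the running row-wise maximum of the scores, l_i the running
    # sum of exponentials, and acc the unnormalized weighted sum of value rows;
    # each new K/V tile rescales the previous partial results by exp(m_old - m_new).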
194 | mask_m = offs_m < M 195 | if DIVISIBLE_M: 196 | q = tl.load(q_ptrs, cache_modifier=".cg") 197 | else: 198 | q = tl.load(q_ptrs, mask=mask_m[:, None], cache_modifier=".cg") 199 | 200 | if BLOCK_DMODEL < 128: 201 | I = tl.where(offs_k[:, None] == offs_k, 202 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=q.dtype), 203 | tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=q.dtype)) 204 | q = tl.dot(q, I).to(q.dtype) 205 | 206 | m_i = tl.full([BLOCK_M], value=-float("inf"), dtype=tl.float32) 207 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 208 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 209 | 210 | offs_n_init = offs_n_base 211 | k_ptrs = K + (offs_k[:, None] * stride_kk + offs_n_init[None, :] * stride_kn) # (BLOCK_DMODEL, BLOCK_N) 212 | v_ptrs = V + (offs_n_init[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (BLOCK_N, BLOCK_DMODEL) 213 | 214 | if IS_CAUSAL: 215 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 216 | if LARGER_M: 217 | hi = tl.maximum(0, hi) 218 | else: 219 | hi = N 220 | 221 | for start_n in range(0, hi, BLOCK_N): 222 | start_n = tl.multiple_of(start_n, BLOCK_N) 223 | offs_n = start_n + offs_n_base 224 | 225 | mask_n = offs_n < N 226 | if DIVISIBLE_N: 227 | k = tl.load(k_ptrs, cache_modifier=".cg") 228 | v = tl.load(v_ptrs, cache_modifier=".cg") 229 | else: 230 | k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") 231 | v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") 232 | 233 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=input_dtype) 234 | s += tl.dot(q, k) * sm_scale 235 | 236 | relative_positions = offs_m[:, None]-offs_n[None, :] # shape: (BLOCK_M, BLOCK_N) 237 | 238 | sign = tl.where(relative_positions > 0.0, 1.0, tl.where(relative_positions < 0.0, -1.0, 0.0)) 239 | 240 | mid_val = NUM_BUCKETS // 2 241 | 242 | abs_relative = tl.abs(relative_positions) 243 | condition = (relative_positions < mid_val) & (relative_positions > -mid_val) 244 | abs_pos = tl.where(condition, mid_val - 1.0, abs_relative) 245 | 246 | log_numer = tl.log(abs_pos / mid_val) 247 | log_denom = tl.log((MAX_DISTANCE - 1) / mid_val) 248 | log_scaled = log_numer / log_denom * (mid_val - 1.0) 249 | log_pos = tl.ceil(log_scaled) + mid_val 250 | 251 | bucket_pos = tl.where(abs_pos <= mid_val, relative_positions, log_pos * sign) 252 | 253 | if HAS_C2P: 254 | c2p_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32) 255 | 256 | k_pos_ptrs = K_POS+offs_m[:, None]*stride_pk2 + c2p_index*stride_pk3 257 | 258 | c2p_bias = tl.load(k_pos_ptrs, mask=mask_m[:, None] & (c2p_index < 2*ATT_SPAN), other=0.0) 259 | 260 | s += c2p_bias * sm_scale 261 | 262 | if HAS_P2C: 263 | p2c_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32).trans(1, 0) 264 | 265 | q_pos_ptrs = Q_POS + (offs_n[:, None] * stride_pq2 + p2c_index * stride_pq3) 266 | 267 | p2c_bias = tl.load(q_pos_ptrs, mask=mask_n[:, None] & (p2c_index < 2*ATT_SPAN), other=0.0).trans(1, 0) 268 | s += p2c_bias * sm_scale 269 | 270 | if not DIVISIBLE_N: 271 | s = tl.where(mask_n[None, :], s, float("-inf")) 272 | if IS_CAUSAL: 273 | causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] 274 | s = tl.where(causal_mask, s, float("-inf")) 275 | 276 | m_i_new = tl.maximum(m_i, tl.max(s, 1)) 277 | alpha = tl.math.exp2((m_i - m_i_new) * log2e) 278 | p = tl.math.exp2((s - m_i_new[:, None]) * log2e) 279 | acc *= alpha[:, None] 280 | acc += tl.dot(p.to(q.dtype), v) 281 | l_i = l_i * alpha + tl.sum(p, 1) 282 | m_i = m_i_new 283 | 284 | k_ptrs += BLOCK_N * 
stride_kn 285 | v_ptrs += BLOCK_N * stride_vn 286 | 287 | if IS_CAUSAL and LARGER_M: 288 | is_empty_line = (offs_m + P_SEQ) < 0 289 | acc = tl.where(is_empty_line[:, None], 0.0, acc * (1.0 / l_i[:, None])) 290 | l = tl.where(is_empty_line, float("-inf"), m_i + tl.log(l_i)) 291 | else: 292 | acc = acc * (1.0 / l_i[:, None]) 293 | l = m_i + tl.log(l_i) 294 | 295 | if DIVISIBLE_M: 296 | tl.store(l_ptrs, l, cache_modifier=".cg") 297 | tl.store(o_ptrs, acc.to(q.dtype), cache_modifier=".cg") 298 | else: 299 | tl.store(l_ptrs, l, mask=mask_m, cache_modifier=".cg") 300 | tl.store(o_ptrs, acc.to(q.dtype), mask=mask_m[:, None], cache_modifier=".cg") 301 | 302 | def get_fwd_config(B, H, M, N, D, causal, disentangled=False, max_shared_memory=None, att_span=256): 303 | """ 304 | Determine optimal kernel configuration parameters. 305 | 306 | Args: 307 | B, H, M, N, D: Batch, head, query length, key length, per-head dimension. 308 | causal (bool): Whether causal masking is applied. 309 | disentangled (bool): Whether to use the DeBERTa-style disentangled attention kernel. 310 | max_shared_memory (int, optional): Maximum available shared memory in bytes. 311 | If None, it will be queried from the device. 312 | 313 | Returns: 314 | Tuple (BLOCK_M, BLOCK_N, num_stages, num_warps) 315 | """ 316 | # See more details on the mapping at: https://forums.developer.nvidia.com/t/dynamic-shared-memory-calculated-by-ncu-larger-than-max-shared-memory-per-block/265589 317 | 318 | capability_map = { 319 | (7,0): 96000, 320 | (7,2): 96000, 321 | (7,5): 64000, 322 | (8,0): 163000, 323 | (8,6): 99000, 324 | (8,7): 163000, 325 | (8,9): 99000, 326 | (9,0): 227000, 327 | } 328 | 329 | capability = torch.cuda.get_device_capability() 330 | device_property = torch.cuda.get_device_properties() 331 | if hasattr(device_property,"shared_memory_per_block_optin"): 332 | shared_mem_per_block = device_property.shared_memory_per_block_optin 333 | elif capability in list(capability_map.keys()): 334 | shared_mem_per_block = capability_map[capability] 335 | elif capability[0] >= 8: 336 | shared_mem_per_block = 99000 337 | else: 338 | shared_mem_per_block = 48000 339 | 340 | max_shared_memory = shared_mem_per_block - 2000 # remove 2kb for ops overhead 341 | 342 | # Start with an aggressive configuration 343 | if capability[0] >= 8 : 344 | if not causal: 345 | if D <= 64: 346 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 347 | else: 348 | if M <= 1024: 349 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4 350 | else: 351 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8 352 | else: # causal 353 | if D <= 64: 354 | if disentangled: 355 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 356 | else: 357 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 4, 4 358 | else: 359 | if M <= 1024: 360 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 361 | else: 362 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8 363 | elif capability[0] == 8: 364 | if not causal: 365 | if D <= 64: 366 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 367 | else: 368 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 369 | else: # causal 370 | if D <= 64: 371 | if disentangled: 372 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 373 | else: 374 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 3, 4 375 | else: 376 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 2, 4 377 | else: 378 | BLOCK_M, BLOCK_N, num_stages, num_warps = 16, 16, 10, 4 379 | 380 | # Calculate shared memory usage 
380 |     # Calculate shared memory usage with current config
381 |     has_pos = disentangled
382 |     ATT_SPAN = att_span if has_pos else 0
383 | 
384 |     dtype = torch.float16
385 | 
386 |     shared_mem_usage = calculate_shared_memory_usage(
387 |         BLOCK_M, BLOCK_N, D, num_stages, dtype,
388 |         has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN
389 |     )
390 | 
391 |     # If shared memory usage exceeds the available budget, adjust parameters.
392 |     # We prioritize reducing num_stages first, then block sizes.
393 |     while shared_mem_usage > max_shared_memory and (BLOCK_M > 16 or BLOCK_N > 16 or num_stages > 1):
394 |         # First try reducing num_stages
395 |         if num_stages > 1:
396 |             num_stages -= 1
397 |         # Then try reducing block sizes
398 |         elif BLOCK_M > 32 and BLOCK_N > 32:
399 |             BLOCK_M //= 2
400 |             BLOCK_N //= 2
401 |         elif BLOCK_M > 32:
402 |             BLOCK_M //= 2
403 |         elif BLOCK_N > 32:
404 |             BLOCK_N //= 2
405 |         elif BLOCK_M > 16:
406 |             BLOCK_M //= 2
407 |         elif BLOCK_N > 16:
408 |             BLOCK_N //= 2
409 | 
410 |         # Recalculate with new parameters
411 |         shared_mem_usage = calculate_shared_memory_usage(
412 |             BLOCK_M, BLOCK_N, D, num_stages, dtype,
413 |             has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN
414 |         )
415 | 
416 |     warnings.warn(f"INFO: Forward config is {BLOCK_M}, {BLOCK_N}, {num_stages}, {num_warps} for BLOCK_M, BLOCK_N, stages and warps, respectively.\n"
417 |                   "INFO: If you want to change it, adjust get_fwd_config in flashdeberta/ops.")
418 |     return (BLOCK_M, BLOCK_N, num_stages, num_warps)
419 | 
420 | 
421 | def flash_attn_v2_fwd_dise(q, k, v, pos_key, pos_query, causal, sm_scale, BLOCK_M, BLOCK_N,
422 |                            position_buckets, max_relative_distance, num_warps, num_stages, ATT_SPAN):
423 |     """
424 |     Performs the forward pass of FlashAttention with DeBERTa-style disentangled relative attention.
425 | 
426 |     Args:
427 |         q, k, v: Query, key and value tensors of shape (B, H, M, D), (B, H, N, D) and (B, H, N, D).
428 |         pos_key, pos_query: Projected relative-position embeddings for the c2p and p2c bias terms (None disables a term).
429 |         causal, sm_scale, BLOCK_M, BLOCK_N, position_buckets, max_relative_distance, num_warps, num_stages, ATT_SPAN: masking, scaling, tiling and relative-position parameters forwarded to the Triton kernel.
430 |     """
431 |     B, H, M, D = q.shape
432 |     N = k.shape[2]
433 |     P_SEQ = N - M
434 | 
435 |     if sm_scale is None:
436 |         sm_scale = 1. / math.sqrt(D)
437 | 
438 |     # Determine if each bias term is present.
439 |     has_c2p = pos_key is not None
440 |     has_p2c = pos_query is not None
441 | 
442 |     larger_m = M > N
443 | 
444 |     divisible_m = (M % BLOCK_M) == 0
445 |     divisible_n = (N % BLOCK_N) == 0
446 | 
447 |     # Determine if each bias term is present.
448 | has_c2p = pos_key is not None 449 | has_p2c = pos_query is not None 450 | 451 | # Setup grid: use a 3D grid (query blocks, heads, batch) 452 | grid = (cdiv(M, BLOCK_M), H, B) 453 | o = torch.empty_like(q) 454 | L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) 455 | 456 | if has_c2p: 457 | stride_pk0, stride_pk1, stride_pk2, stride_pk3 = pos_key.stride() 458 | else: 459 | stride_pk0 = stride_pk1 = stride_pk2 = stride_pk3 = 0 460 | if has_p2c: 461 | stride_pq0, stride_pq1, stride_pq2, stride_pq3 = pos_query.stride() 462 | else: 463 | stride_pq0 = stride_pq1 = stride_pq2 = stride_pq3 = 0 464 | 465 | with torch.cuda.device(q.device.index): 466 | _fwd_kernel_deberta_disentangled_attention[grid]( 467 | q, k, v, 468 | pos_key, pos_query, 469 | L, o, 470 | sm_scale, 471 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 472 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 473 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 474 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 475 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 476 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 477 | B, H, M, N, P_SEQ, 478 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 479 | IS_CAUSAL=causal, LARGER_M=larger_m, 480 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 481 | HAS_C2P=has_c2p, HAS_P2C=has_p2c, 482 | ATT_SPAN=ATT_SPAN, 483 | NUM_BUCKETS=position_buckets, 484 | MAX_DISTANCE=max_relative_distance, 485 | num_warps=num_warps, num_stages=num_stages, 486 | ) 487 | 488 | return o, L 489 | 490 | def get_bwd_config( 491 | B, H, M, N, D, causal, 492 | *, 493 | disentangled: bool = False, 494 | att_span: int = 256, 495 | dtype = torch.float16, 496 | max_shared_memory: int | None = None, 497 | ): 498 | """ 499 | Heuristic selector for backward kernel tiling. 500 | Returns (BLOCK_M, BLOCK_N, num_stages, num_warps). 
501 | """ 502 | capability_map = { 503 | (7,0): 96000, (7,2): 96000, (7,5): 64000, 504 | (8,0): 163000, (8,6): 99000, (8,7): 163000, (8,9): 99000, 505 | (9,0): 227000, 506 | } 507 | cap = torch.cuda.get_device_capability() 508 | prop = torch.cuda.get_device_properties(0) 509 | 510 | if max_shared_memory is None: 511 | if hasattr(prop, "shared_memory_per_block_optin"): 512 | shared_mem_per_block = prop.shared_memory_per_block_optin 513 | elif cap in capability_map: 514 | shared_mem_per_block = capability_map[cap] 515 | elif cap[0] >= 8: 516 | shared_mem_per_block = 99000 517 | else: 518 | shared_mem_per_block = 48000 519 | max_shared_memory = max(0, shared_mem_per_block - 2048) 520 | 521 | if cap[0] >= 9: 522 | if D <= 64: 523 | BLOCK_M, BLOCK_N, num_stages, num_warps = (128, 64, 3, 4) if not causal else (128, 64, 3, 4) 524 | else: 525 | BLOCK_M, BLOCK_N, num_stages, num_warps = (128, 64, 2, 8) if not causal else (128, 64, 2, 8) 526 | elif cap[0] >= 8: 527 | if D <= 64: 528 | if causal: 529 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 530 | else: 531 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4 532 | else: 533 | BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 2, 8 534 | else: 535 | BLOCK_M, BLOCK_N, num_stages, num_warps = 64, 64, 2, 4 536 | 537 | if N >= 2 * M and BLOCK_N < 128 and cap[0] >= 8: 538 | BLOCK_N = 128 539 | num_warps = max(num_warps, 8) 540 | 541 | has_pos = bool(disentangled) 542 | ATT_SPAN = att_span if has_pos else 0 543 | 544 | shm = calculate_shared_memory_usage_bwd( 545 | BLOCK_M, BLOCK_N, D, num_stages, dtype, 546 | has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN, 547 | store_lse=True, recompute_probs=True, accum_dq=True, accum_dkv=True, 548 | ) 549 | 550 | def halve_pow2(x): return max(16, (x // 2)) if x > 16 else 16 551 | 552 | while shm > max_shared_memory and (num_stages > 1 or BLOCK_M > 16 or BLOCK_N > 16): 553 | if num_stages > 1: 554 | num_stages -= 1 555 | elif BLOCK_N >= BLOCK_M and BLOCK_N > 16: 556 | BLOCK_N = halve_pow2(BLOCK_N) 557 | elif BLOCK_M > 16: 558 | BLOCK_M = halve_pow2(BLOCK_M) 559 | else: 560 | break # nothing else to shrink 561 | 562 | shm = calculate_shared_memory_usage_bwd( 563 | BLOCK_M, BLOCK_N, D, num_stages, dtype, 564 | has_c2p=has_pos, has_p2c=has_pos, ATT_SPAN=ATT_SPAN, 565 | store_lse=True, recompute_probs=True, accum_dq=True, accum_dkv=True, 566 | ) 567 | 568 | if D <= 64 and BLOCK_M * BLOCK_N <= 128 * 64: 569 | num_warps = min(num_warps, 4) 570 | else: 571 | num_warps = max(num_warps, 8 if cap[0] >= 8 else 4) 572 | 573 | warnings.warn( 574 | f"INFO: Varlen backward config -> " 575 | f"BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}, stages={num_stages}, warps={num_warps}." 576 | "\nINFO: Adjust att_span/disentangled or set max_shared_memory to tune." 
577 | ) 578 | 579 | return (BLOCK_M, BLOCK_N, num_stages, num_warps) 580 | 581 | @triton.jit 582 | def _bwd_preprocess( 583 | Out, DO, 584 | Delta, 585 | stride_oz, stride_oh, stride_om, stride_ok, 586 | stride_doz, stride_doh, stride_dom, stride_dok, 587 | stride_dz, stride_dh, stride_dm, 588 | M, 589 | BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, 590 | DIVISIBLE_M: tl.constexpr, 591 | ): 592 | off_h = tl.program_id(1) 593 | off_z = tl.program_id(2) 594 | Out += off_z * stride_oz + off_h * stride_oh 595 | DO += off_z * stride_doz + off_h * stride_doh 596 | Delta += off_z * stride_dz + off_h * stride_dh 597 | 598 | off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 599 | off_k = tl.arange(0, D_HEAD) 600 | 601 | o_ptrs = Out + off_m[:, None] * stride_om + off_k[None, :] * stride_ok 602 | do_ptrs = DO + off_m[:, None] * stride_dom + off_k[None, :] * stride_dok 603 | 604 | if DIVISIBLE_M: 605 | o = tl.load(o_ptrs).to(tl.float32) 606 | do = tl.load(do_ptrs).to(tl.float32) 607 | delta = tl.sum(o * do, axis=1) 608 | tl.store(Delta + off_m * stride_dm, delta) 609 | else: 610 | mask_m = off_m < M 611 | o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) 612 | do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) 613 | delta = tl.sum(o * do, axis=1) 614 | tl.store(Delta + off_m * stride_dm, delta, mask=mask_m) 615 | 616 | @triton.jit 617 | def _bwd_kv_dise_kernel( 618 | Q, K, V, K_POS, Q_POS, sm_scale, DO, 619 | DK, DV, DKPOS, DQPOS, 620 | L, Delta, 621 | stride_qz, stride_qh, stride_qm, stride_qk, 622 | stride_kz, stride_kh, stride_kn, stride_kk, 623 | stride_vz, stride_vh, stride_vn, stride_vk, 624 | stride_doz, stride_doh, stride_dom, stride_dok, 625 | stride_dkz, stride_dkh, stride_dkn, stride_dkk, 626 | stride_dvz, stride_dvh, stride_dvn, stride_dvk, 627 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 628 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 629 | Z, H, M, N, P_SEQ, 630 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 631 | CAUSAL: tl.constexpr, 632 | HAS_C2P: tl.constexpr, HAS_P2C: tl.constexpr, 633 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 634 | ATT_SPAN: tl.constexpr, NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr, 635 | ): 636 | input_dtype = Q.dtype.element_ty 637 | log2e: tl.constexpr = 1.4426950408889634 638 | 639 | start_n = tl.program_id(0) 640 | off_h = tl.program_id(1) 641 | off_z = tl.program_id(2) 642 | 643 | # Offset tensors 644 | Q += off_z*stride_qz + off_h*stride_qh 645 | K += off_z*stride_kz + off_h*stride_kh 646 | V += off_z*stride_vz + off_h*stride_vh 647 | DO += off_z*stride_doz + off_h*stride_doh 648 | 649 | DK += off_z*stride_dkz + off_h*stride_dkh 650 | DV += off_z*stride_dvz + off_h*stride_dvh 651 | 652 | if HAS_C2P: 653 | K_POS += off_z*stride_pk0 + off_h*stride_pk1 654 | DKPOS += off_z*stride_pk0 + off_h*stride_pk1 655 | if HAS_P2C: 656 | Q_POS += off_z*stride_pq0 + off_h*stride_pq1 657 | DQPOS += off_z*stride_pq0 + off_h*stride_pq1 658 | 659 | L += (off_z*H + off_h) * M 660 | Delta += (off_z*H + off_h) * M 661 | 662 | # Bounds in m for this block of n 663 | if CAUSAL: 664 | lo = tl.maximum(start_n * BLOCK_N - P_SEQ, 0) 665 | lo = (lo // BLOCK_M) * BLOCK_M 666 | else: 667 | lo = 0 668 | 669 | offs_m_init = lo + tl.arange(0, BLOCK_M) 670 | offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) 671 | offs_m_base = tl.arange(0, BLOCK_M) 672 | offs_k = tl.arange(0, BLOCK_DMODEL) 673 | 674 | # Pointers 675 | q_ptrs = Q + (offs_m_init[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (M, D) 
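    # Added note: this kernel fixes one key/value tile (offs_n) and loops over every query tile along M,
    # so the q/do pointers advance each iteration while k and v are loaded only once per key tile;
    # dK/dV are accumulated locally and written once at the end, whereas the positional-bias gradients
    # are scattered with tl.atomic_add because different key tiles (and different (m, n) pairs within
    # a tile) can update the same relative-position bucket entry.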
676 | k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (N, D) 677 | v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) # (N, D) 678 | do_ptrs = DO + (offs_m_init[:, None] * stride_dom + offs_k[None, :] * stride_dok) # (M, D) 679 | 680 | dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :] * stride_dvk) 681 | dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk) 682 | 683 | # Load K, V once per n-tile 684 | mask_n = offs_n < N 685 | if DIVISIBLE_N: 686 | v = tl.load(v_ptrs) 687 | k = tl.load(k_ptrs) 688 | else: 689 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 690 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 691 | 692 | dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 693 | dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) 694 | 695 | for start_m in range(lo, M, BLOCK_M): 696 | start_m = tl.multiple_of(start_m, BLOCK_M) 697 | offs_m = start_m + offs_m_base 698 | mask_m = offs_m < M 699 | 700 | if DIVISIBLE_M: 701 | q = tl.load(q_ptrs) 702 | do = tl.load(do_ptrs) 703 | l = tl.load(L + offs_m) 704 | delta = tl.load(Delta + offs_m) 705 | else: 706 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 707 | do = tl.load(do_ptrs, mask=mask_m[:, None]) 708 | l = tl.load(L + offs_m, mask=mask_m) 709 | delta = tl.load(Delta + offs_m, mask=mask_m) 710 | 711 | # Recompute scores s = qk^T * sm_scale + biases 712 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 713 | s += tl.dot(q, tl.trans(k)) * sm_scale 714 | 715 | relative_positions = offs_m[:, None] - offs_n[None, :] # (M, N) 716 | sign = tl.where(relative_positions > 0.0, 1.0, tl.where(relative_positions < 0.0, -1.0, 0.0)) 717 | mid_val = NUM_BUCKETS // 2 718 | abs_relative = tl.abs(relative_positions) 719 | condition = (relative_positions < mid_val) & (relative_positions > -mid_val) 720 | abs_pos = tl.where(condition, mid_val - 1.0, abs_relative) 721 | 722 | log_numer = tl.log(abs_pos / mid_val) 723 | log_denom = tl.log((MAX_DISTANCE - 1) / mid_val) 724 | log_scaled = log_numer / log_denom * (mid_val - 1.0) 725 | log_pos = tl.ceil(log_scaled) + mid_val 726 | bucket_pos = tl.where(abs_pos <= mid_val, relative_positions, log_pos * sign) # signed 727 | 728 | if HAS_C2P: 729 | c2p_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32) 730 | k_pos_ptrs = K_POS + (offs_m[:, None] * stride_pk2 + c2p_index * stride_pk3) # (M,N) logical 731 | c2p_bias = tl.load(k_pos_ptrs, mask=mask_m[:, None] & (c2p_index < 2*ATT_SPAN), other=0.0) 732 | s += c2p_bias * sm_scale 733 | 734 | if HAS_P2C: 735 | p2c_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2 * ATT_SPAN - 1).to(tl.int32).trans(1, 0) 736 | q_pos_ptrs = Q_POS + (offs_n[:, None] * stride_pq2 + p2c_index * stride_pq3) # (N,M) 737 | p2c_bias = tl.load(q_pos_ptrs, mask=mask_n[:, None] & (p2c_index < 2*ATT_SPAN), other=0.0).trans(1, 0) 738 | s += p2c_bias * sm_scale 739 | 740 | # Causal mask 741 | if CAUSAL: 742 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) 743 | # Re-materialize p using saved log-normalizer l 744 | p = tl.math.exp2((s - l[:, None]) * log2e) 745 | if not DIVISIBLE_N: 746 | p = tl.where(mask_n[None, :], p, 0.0) 747 | if CAUSAL: 748 | p = tl.where(causal_mask, p, 0.0) 749 | 750 | # dv = p^T @ do 751 | dv += tl.dot(tl.trans(p.to(do.dtype)), do) 752 | 753 | # dp = do @ v^T 754 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 755 | dp += tl.dot(do.to(input_dtype), tl.trans(v)) 756 | 757 | # ds = p * (dp - delta[:, None]) 758 | ds = p * (dp - delta[:, 
None]) 759 | if not DIVISIBLE_N: 760 | ds = tl.where(mask_n[None, :], ds, 0.0) 761 | if CAUSAL: 762 | ds = tl.where(causal_mask, ds, 0.0) 763 | 764 | ds_scaled = (ds * sm_scale).to(input_dtype) 765 | 766 | # dk += ds^T @ q 767 | dk += tl.dot(tl.trans(ds_scaled), q) 768 | 769 | # Positional grads via atomic adds 770 | if HAS_C2P: 771 | # DKPOS[m, bucket] += sum_n ds(m,n)*sm_scale where bucket = c2p_index(m,n) 772 | kpos_grad_ptrs = DKPOS + (offs_m[:, None] * stride_pk2 + c2p_index * stride_pk3) # (M,N) 773 | tl.atomic_add( 774 | kpos_grad_ptrs, 775 | ds_scaled, 776 | mask=mask_m[:, None] & mask_n[None, :] & (c2p_index < 2*ATT_SPAN), 777 | ) 778 | 779 | if HAS_P2C: 780 | # DQPOS[n, bucket] += sum_m ds(m,n)*sm_scale where bucket = p2c_index(n,m) 781 | qpos_grad_ptrs = DQPOS + (offs_n[:, None] * stride_pq2 + p2c_index * stride_pq3) # (N,M) 782 | tl.atomic_add( 783 | qpos_grad_ptrs, 784 | ds_scaled.trans(1, 0), 785 | mask=mask_n[:, None] & mask_m[None, :] & (p2c_index < 2*ATT_SPAN), 786 | ) 787 | 788 | # advance pointers 789 | q_ptrs += BLOCK_M * stride_qm 790 | do_ptrs += BLOCK_M * stride_dom 791 | 792 | if DIVISIBLE_N: 793 | tl.store(dk_ptrs, dk.to(input_dtype)) 794 | tl.store(dv_ptrs, dv.to(input_dtype)) 795 | else: 796 | tl.store(dk_ptrs, dk.to(input_dtype), mask=mask_n[:, None]) 797 | tl.store(dv_ptrs, dv.to(input_dtype), mask=mask_n[:, None]) 798 | 799 | @triton.jit 800 | def _bwd_q_dise_kernel( 801 | Q, K, V, K_POS, Q_POS, sm_scale, DO, 802 | DQ, 803 | L, Delta, 804 | stride_qz, stride_qh, stride_qm, stride_qk, 805 | stride_kz, stride_kh, stride_kn, stride_kk, 806 | stride_vz, stride_vh, stride_vn, stride_vk, 807 | stride_doz, stride_doh, stride_dom, stride_dok, 808 | stride_dqz, stride_dqh, stride_dqm, stride_dqk, 809 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 810 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 811 | Z, H, M, N, P_SEQ, 812 | BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, 813 | CAUSAL: tl.constexpr, HAS_C2P: tl.constexpr, HAS_P2C: tl.constexpr, LARGER_M: tl.constexpr, 814 | DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, 815 | ATT_SPAN: tl.constexpr, NUM_BUCKETS: tl.constexpr, MAX_DISTANCE: tl.constexpr, 816 | ): 817 | input_dtype = Q.dtype.element_ty 818 | log2e: tl.constexpr = 1.4426950408889634 819 | 820 | start_m = tl.program_id(0) 821 | off_h = tl.program_id(1) 822 | off_z = tl.program_id(2) 823 | 824 | Q += off_z*stride_qz + off_h*stride_qh 825 | K += off_z*stride_kz + off_h*stride_kh 826 | V += off_z*stride_vz + off_h*stride_vh 827 | DO += off_z*stride_doz + off_h*stride_doh 828 | DQ += off_z*stride_dqz + off_h*stride_dqh 829 | 830 | if HAS_C2P: 831 | K_POS += off_z*stride_pk0 + off_h*stride_pk1 832 | if HAS_P2C: 833 | Q_POS += off_z*stride_pq0 + off_h*stride_pq1 834 | 835 | L += (off_z*H + off_h) * M 836 | Delta += (off_z*H + off_h) * M 837 | 838 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 839 | offs_n_base = tl.arange(0, BLOCK_N) 840 | offs_k = tl.arange(0, BLOCK_DMODEL) 841 | 842 | q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) 843 | dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk) 844 | do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok) 845 | 846 | mask_m = offs_m < M 847 | if DIVISIBLE_M: 848 | q = tl.load(q_ptrs) 849 | do = tl.load(do_ptrs) 850 | delta = tl.load(Delta + offs_m) 851 | l = tl.load(L + offs_m) 852 | else: 853 | q = tl.load(q_ptrs, mask=mask_m[:, None]) 854 | do = tl.load(do_ptrs, mask=mask_m[:, None]) 855 | delta = 
tl.load(Delta + offs_m, mask=mask_m) 856 | l = tl.load(L + offs_m, mask=mask_m) 857 | 858 | dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 859 | 860 | # Upper bound for N this row touches 861 | if CAUSAL: 862 | hi = tl.minimum(N, P_SEQ + (start_m + 1) * BLOCK_M) 863 | if LARGER_M: 864 | hi = tl.maximum(0, hi) 865 | else: 866 | hi = N 867 | 868 | k_ptrs = K + (offs_n_base[:, None] * stride_kn + offs_k[None, :] * stride_kk) 869 | v_ptrs = V + (offs_n_base[:, None] * stride_vn + offs_k[None, :] * stride_vk) 870 | 871 | for start_n in range(0, hi, BLOCK_N): 872 | start_n = tl.multiple_of(start_n, BLOCK_N) 873 | offs_n = start_n + offs_n_base 874 | 875 | mask_n = offs_n < N 876 | if DIVISIBLE_N: 877 | k = tl.load(k_ptrs) 878 | v = tl.load(v_ptrs) 879 | else: 880 | k = tl.load(k_ptrs, mask=mask_n[:, None]) 881 | v = tl.load(v_ptrs, mask=mask_n[:, None]) 882 | 883 | # Recompute s and p 884 | s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 885 | s += tl.dot(q, tl.trans(k)) * sm_scale 886 | 887 | # (same bucketization as fwd) 888 | relative_positions = offs_m[:, None] - offs_n[None, :] 889 | sign = tl.where(relative_positions > 0.0, 1.0, tl.where(relative_positions < 0.0, -1.0, 0.0)) 890 | mid_val = NUM_BUCKETS // 2 891 | abs_relative = tl.abs(relative_positions) 892 | condition = (relative_positions < mid_val) & (relative_positions > -mid_val) 893 | abs_pos = tl.where(condition, mid_val - 1.0, abs_relative) 894 | 895 | log_numer = tl.log(abs_pos / mid_val) 896 | log_denom = tl.log((MAX_DISTANCE - 1) / mid_val) 897 | log_scaled = log_numer / log_denom * (mid_val - 1.0) 898 | log_pos = tl.ceil(log_scaled) + mid_val 899 | bucket_pos = tl.where(abs_pos <= mid_val, relative_positions, log_pos * sign) 900 | 901 | if HAS_C2P: 902 | c2p_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2*ATT_SPAN - 1).to(tl.int32) 903 | k_pos_ptrs = K_POS + (offs_m[:, None] * stride_pk2 + c2p_index * stride_pk3) 904 | c2p_bias = tl.load(k_pos_ptrs, mask=mask_m[:, None] & (c2p_index < 2*ATT_SPAN), other=0.0) 905 | s += c2p_bias * sm_scale 906 | 907 | if HAS_P2C: 908 | p2c_index = tl.minimum(tl.maximum(bucket_pos + ATT_SPAN, 0), 2*ATT_SPAN - 1).to(tl.int32).trans(1, 0) 909 | q_pos_ptrs = Q_POS + (offs_n[:, None] * stride_pq2 + p2c_index * stride_pq3) 910 | p2c_bias = tl.load(q_pos_ptrs, mask=mask_n[:, None] & (p2c_index < 2*ATT_SPAN), other=0.0).trans(1, 0) 911 | s += p2c_bias * sm_scale 912 | 913 | if CAUSAL: 914 | causal_mask = (P_SEQ + offs_m[:, None]) >= (offs_n[None, :]) 915 | 916 | p = tl.math.exp2((s - l[:, None]) * log2e) 917 | if not DIVISIBLE_N: 918 | p = tl.where(mask_n[None, :], p, 0.0) 919 | if CAUSAL: 920 | p = tl.where(causal_mask, p, 0.0) 921 | 922 | dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 923 | dp += tl.dot(do.to(input_dtype), tl.trans(v)) 924 | 925 | ds = p * (dp - delta[:, None]) 926 | if not DIVISIBLE_N: 927 | ds = tl.where(mask_n[None, :], ds, 0.0) 928 | if CAUSAL: 929 | ds = tl.where(causal_mask, ds, 0.0) 930 | 931 | dq += tl.dot((ds * sm_scale).to(input_dtype), k) 932 | 933 | k_ptrs += BLOCK_N * stride_kn 934 | v_ptrs += BLOCK_N * stride_vn 935 | 936 | dq = dq.to(input_dtype) 937 | if DIVISIBLE_M: 938 | tl.store(dq_ptrs, dq) 939 | else: 940 | tl.store(dq_ptrs, dq, mask=mask_m[:, None]) 941 | 942 | def flash_attn_v2_bwd_dise(o, do, q, k, v, k_pos, q_pos, L, causal, sm_scale, 943 | BLOCK_M, BLOCK_N, position_buckets, max_relative_distance, 944 | num_warps, num_stages, ATT_SPAN): 945 | B, H, M, D = q.shape 946 | N = k.shape[2] 947 | P_SEQ = N - M 948 | larger_m = M > N 949 
| divisible_m = (M % BLOCK_M) == 0 950 | divisible_n = (N % BLOCK_N) == 0 951 | 952 | has_c2p = (k_pos is not None) 953 | has_p2c = (q_pos is not None) 954 | 955 | # Preprocess: Delta = sum(o * do, dim=-1) 956 | delta = torch.empty_like(L) 957 | grid = (cdiv(M, BLOCK_M), H, B) 958 | with torch.cuda.device(q.device.index): 959 | _bwd_preprocess[grid]( 960 | o, do, delta, 961 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 962 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 963 | delta.stride(0), delta.stride(1), delta.stride(2), 964 | M, 965 | BLOCK_M=BLOCK_M, D_HEAD=D, DIVISIBLE_M=divisible_m, 966 | ) 967 | 968 | dk = torch.empty_like(k) 969 | dv = torch.empty_like(v) 970 | dk_pos = torch.zeros_like(k_pos) if has_c2p else None 971 | dq_pos = torch.zeros_like(q_pos) if has_p2c else None 972 | 973 | if has_c2p: 974 | stride_pk0, stride_pk1, stride_pk2, stride_pk3 = k_pos.stride() 975 | else: 976 | stride_pk0 = stride_pk1 = stride_pk2 = stride_pk3 = 0 977 | 978 | if has_p2c: 979 | stride_pq0, stride_pq1, stride_pq2, stride_pq3 = q_pos.stride() 980 | else: 981 | stride_pq0 = stride_pq1 = stride_pq2 = stride_pq3 = 0 982 | 983 | grid_kv = (cdiv(N, BLOCK_N), H, B) 984 | with torch.cuda.device(q.device.index): 985 | _bwd_kv_dise_kernel[grid_kv]( 986 | q, k, v, k_pos, q_pos, sm_scale, do, 987 | dk, dv, dk_pos, dq_pos, 988 | L, delta, 989 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 990 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 991 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 992 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 993 | dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), 994 | dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), 995 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 996 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 997 | B, H, M, N, P_SEQ, 998 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 999 | CAUSAL=causal, 1000 | HAS_C2P=has_c2p, HAS_P2C=has_p2c, 1001 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 1002 | ATT_SPAN=ATT_SPAN, 1003 | NUM_BUCKETS=position_buckets, 1004 | MAX_DISTANCE=max_relative_distance, 1005 | num_warps=num_warps, num_stages=num_stages, 1006 | ) 1007 | 1008 | dq = torch.empty_like(q) 1009 | grid_q = (cdiv(M, BLOCK_M), H, B) 1010 | with torch.cuda.device(q.device.index): 1011 | _bwd_q_dise_kernel[grid_q]( 1012 | q, k, v, k_pos, q_pos, sm_scale, do, 1013 | dq, 1014 | L, delta, 1015 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 1016 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 1017 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 1018 | do.stride(0), do.stride(1), do.stride(2), do.stride(3), 1019 | dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), 1020 | stride_pk0, stride_pk1, stride_pk2, stride_pk3, 1021 | stride_pq0, stride_pq1, stride_pq2, stride_pq3, 1022 | B, H, M, N, P_SEQ, 1023 | BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, 1024 | CAUSAL=causal, HAS_C2P=has_c2p, HAS_P2C=has_p2c, LARGER_M=(M > N), 1025 | DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, 1026 | ATT_SPAN=ATT_SPAN, NUM_BUCKETS=position_buckets, MAX_DISTANCE=max_relative_distance, 1027 | num_warps=num_warps, num_stages=num_stages, 1028 | ) 1029 | 1030 | return dq, dk, dv, dk_pos, dq_pos 1031 | 1032 | class FlashAttentionDisentangled(torch.autograd.Function): 1033 | @staticmethod 1034 | def forward(ctx, q, k, v, k_pos, q_pos, causal, 1035 | sm_scale, position_buckets, max_relative_distance): 1036 | 1037 | Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] 1038 | assert Dq == Dk 
== Dv, "Query, key, and value must have the same head dimension"
1039 | 
1040 |         B, H, M, D = q.shape
1041 |         N = k.shape[2]
1042 |         if sm_scale is None:
1043 |             sm_scale = 1. / math.sqrt(D)
1044 | 
1045 |         ATT_SPAN = position_buckets if position_buckets > 0 else max_relative_distance
1046 |         BLOCK_M, BLOCK_N, num_stages, num_warps = get_fwd_config(
1047 |             B, H, M, N, D, causal, disentangled=True, att_span=ATT_SPAN
1048 |         )
1049 | 
1050 |         o, L = flash_attn_v2_fwd_dise(
1051 |             q, k, v, k_pos, q_pos, causal, sm_scale,
1052 |             BLOCK_M, BLOCK_N, position_buckets,
1053 |             max_relative_distance, num_warps, num_stages, ATT_SPAN
1054 |         )
1055 | 
1056 |         # Save for backward
1057 |         ctx.save_for_backward(q, k, v, k_pos, q_pos, o, L)
1058 |         ctx.sm_scale = sm_scale
1059 |         ctx.causal = causal
1060 |         ctx.position_buckets = position_buckets
1061 |         ctx.max_relative_distance = max_relative_distance
1062 |         ctx.ATT_SPAN = ATT_SPAN
1063 |         ctx.config = (BLOCK_M, BLOCK_N, num_stages, num_warps)
1064 |         return o
1065 | 
1066 |     @staticmethod
1067 |     def backward(ctx, do):
1068 |         q, k, v, k_pos, q_pos, o, L = ctx.saved_tensors
1069 |         sm_scale = ctx.sm_scale
1070 |         causal = ctx.causal
1071 |         position_buckets = ctx.position_buckets
1072 |         max_relative_distance = ctx.max_relative_distance
1073 |         ATT_SPAN = ctx.ATT_SPAN
1074 |         B, H, M, D = q.shape
1075 |         N = k.shape[2]
1076 | 
1077 |         BLOCK_M, BLOCK_N, num_stages, num_warps = get_bwd_config(B, H, M, N, D, causal, disentangled=True, att_span=ATT_SPAN)
1078 | 
1079 |         dq, dk, dv, dk_pos, dq_pos = flash_attn_v2_bwd_dise(
1080 |             o, do, q, k, v, k_pos, q_pos, L, causal, sm_scale,
1081 |             BLOCK_M, BLOCK_N, position_buckets, max_relative_distance,
1082 |             num_warps, num_stages, ATT_SPAN
1083 |         )
1084 | 
1085 |         # Gradients must follow the forward input order: (q, k, v, k_pos, q_pos, causal, sm_scale, position_buckets, max_relative_distance)
1086 |         return dq, dk, dv, dk_pos, dq_pos, None, None, None, None
1087 | 
1088 | def flash_attention_with_disentangled(q, k, v, k_pos, q_pos, causal=False, sm_scale=None,
1089 |                                       position_buckets=0, max_relative_distance=0):
1090 |     return FlashAttentionDisentangled.apply(q, k, v, k_pos, q_pos, causal, sm_scale,
1091 |                                             position_buckets, max_relative_distance)
--------------------------------------------------------------------------------
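The kernels above all share the same DeBERTa-style relative-position bucketing. As a reading aid, here is a minimal eager PyTorch sketch that mirrors that arithmetic; it is not a file from the repository, and `log_bucket_position` is a hypothetical helper name used only for illustration.

```python
import math
import torch

def log_bucket_position(rel_pos: torch.Tensor, num_buckets: int, max_distance: int) -> torch.Tensor:
    """Eager mirror of the bucket computation performed inside the Triton kernels."""
    rel_pos = rel_pos.float()
    mid = num_buckets // 2
    sign = torch.sign(rel_pos)
    # Offsets inside (-mid, mid) keep their exact value; larger offsets fall into log-spaced buckets.
    abs_pos = torch.where((rel_pos < mid) & (rel_pos > -mid),
                          torch.full_like(rel_pos, mid - 1.0), rel_pos.abs())
    log_pos = torch.ceil(
        torch.log(abs_pos / mid) / math.log((max_distance - 1) / mid) * (mid - 1)
    ) + mid
    return torch.where(abs_pos <= mid, rel_pos, log_pos * sign)

# Example: bucket the relative positions of an 8-token query against an 8-token key sequence.
rel = torch.arange(8)[:, None] - torch.arange(8)[None, :]
print(log_bucket_position(rel, num_buckets=256, max_distance=512))
```

A minimal end-to-end smoke test of the fused op is sketched below. It assumes a CUDA device with Triton available, that this module is importable as `flashdeberta.ops.flash_attention`, and that the bias tensors follow the layout implied by the kernel strides above: `k_pos` of shape `(B, H, M, 2 * ATT_SPAN)` and `q_pos` of shape `(B, H, N, 2 * ATT_SPAN)`; adjust the import path and shapes to your setup.

```python
import torch
from flashdeberta.ops.flash_attention import flash_attention_with_disentangled  # import path assumed

B, H, L, D = 2, 12, 1024, 64            # batch, heads, sequence length, head dim
buckets, max_dist = 256, 512            # DeBERTa-v3-style bucketing
att_span = buckets                      # ATT_SPAN = position_buckets when position_buckets > 0

def rand(*shape):
    return torch.randn(*shape, device="cuda", dtype=torch.float16, requires_grad=True)

q, k, v = rand(B, H, L, D), rand(B, H, L, D), rand(B, H, L, D)
k_pos = rand(B, H, L, 2 * att_span)     # content-to-position bias table (assumed layout)
q_pos = rand(B, H, L, 2 * att_span)     # position-to-content bias table (assumed layout)

out = flash_attention_with_disentangled(
    q, k, v, k_pos, q_pos,
    causal=False, sm_scale=None,        # None falls back to 1/sqrt(D) inside the op
    position_buckets=buckets, max_relative_distance=max_dist,
)
out.float().sum().backward()            # exercises the backward kernels defined above
print(out.shape, q.grad.shape, k_pos.grad.shape)
```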