├── .gitignore ├── README.md ├── assets ├── logo.png └── pipfreeze.txt ├── main_generate.py ├── main_prefill.py ├── nano_sparse_attn ├── README.md ├── __init__.py ├── attention │ ├── __init__.py │ ├── abstract.py │ ├── inference_handler.py │ └── sparse_attention.py └── utils │ ├── __init__.py │ ├── constants.py │ ├── modelling.py │ └── plotting.py ├── notebooks └── tutorial.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .venv/ 2 | results/ 3 | *.pyc 4 | *.egg-info/ 5 | *.egg 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanoSparseAttention 2 | 3 |

4 | logo 5 |

6 | 7 | ## Overview 8 | 9 | nanoSparseAttention provides clean, educational implementations of recent Sparse Attention mechanisms for both prefilling and generation stages of LLM inference. The repository prioritizes clarity and understanding over performance, making it ideal for learning and experimentation. 10 | 11 | We implemented a [Jupyter notebook](./notebooks/tutorial.ipynb) that provides: 12 | 1. Detailed explanation of Sparse Attention concepts 13 | 2. Step-by-step implementation walkthrough 14 | 3. Visualization of attention patterns 15 | 4. Performance comparisons between different methods 16 | 17 | The notebook has been prepared for the purpose of [NeurIPS 2024 Dynamic Sparsity Workshop](https://dynamic-sparsity.github.io/) - check it out if you want to learn more about dynamic execution, not only in the context of self-attention! 18 | 19 | ### Key Features 20 | 21 | - **Pure PyTorch Implementation**: All attention mechanisms are implemented in pure PyTorch for maximum clarity and ease of understanding. 22 | - **Real-world Testing**: Uses Llama-3.2-1B-Instruct model and FiscalNote/billsum dataset for practical experiments. 23 | - **Comprehensive Tutorial**: Includes a detailed Jupyter notebook explaining core concepts and implementations. 24 | - **Extensible Design**: Easy to add new models, datasets, and attention patterns through modular architecture. 25 | - **Flexible Inference**: Supports both prefilling and generation stages with ability to mix both at once. 
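The **Flexible Inference** feature above relies on one idea. A single sequence is split into a prefill segment and a generation segment, each segment is handled by its own attention pattern, and the two outputs are concatenated. The sketch below illustrates that split in plain PyTorch, under some assumptions. `causal_attention` and `split_prefill_generation` are hypothetical helpers, not part of the package. Both stages use dense attention. Masking uses `masked_fill` with `-inf` rather than the additive large-negative mask in `Attention.attention`. With dense attention in both stages, the result should match ordinary causal attention over the whole sequence.

```python
import math
import torch


def causal_attention(q, k, v, mask):
    """Scaled dot-product attention; `mask` is True where a connection is pruned."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


def split_prefill_generation(q, k, v, output_length):
    """Hypothetical helper: dense prefill + dense generation, outputs concatenated.

    q, k, v: (batch, heads, seq_len, head_dim); the last `output_length`
    positions are treated as the generation segment.
    """
    seq_len = q.size(-2)
    input_length = seq_len - output_length
    assert input_length > 0, "Input length must be > 0"
    dev = q.device

    # Prefill: causal attention restricted to the first `input_length` tokens.
    prefill_mask = torch.ones(input_length, input_length, dtype=torch.bool, device=dev).triu(1)
    prefill_out = causal_attention(
        q[..., :input_length, :], k[..., :input_length, :], v[..., :input_length, :], prefill_mask
    )

    # Generation: each new query sees every prefill key plus earlier generation keys.
    gen_mask = torch.cat([
        torch.zeros(output_length, input_length, dtype=torch.bool, device=dev),
        torch.ones(output_length, output_length, dtype=torch.bool, device=dev).triu(1),
    ], dim=-1)
    gen_out = causal_attention(q[..., input_length:, :], k, v, gen_mask)

    return torch.cat([prefill_out, gen_out], dim=-2)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 10, 8) for _ in range(3))
out = split_prefill_generation(q, k, v, output_length=3)
full_mask = torch.ones(10, 10, dtype=torch.bool).triu(1)
print(torch.allclose(out, causal_attention(q, k, v, full_mask), atol=1e-5))  # True
```

If you swap either stage's dense mask for a sparse one, such as a local window, you get the kind of mixing that `InferenceHandler` enables through its `prefill_attention` and `generation_attention` modules.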
26 | 27 | ### Implemented Methods 28 | 29 | #### Prefilling Stage 30 | - **Local Window + Attention Sinks** [(Xiao et al, 2023)](https://arxiv.org/abs/2309.17453), [(Han et al, 2024)](https://arxiv.org/abs/2308.16137) 31 | - **Vertical-Slash Attention** [(Jiang et al, 2024)](https://arxiv.org/abs/2407.02490) 32 | - **Block-Sparse Attention** [(Jiang et al, 2024)](https://arxiv.org/abs/2407.02490) 33 | 34 | #### Generation Stage 35 | - **Local Window + Attention Sinks** [(Xiao et al, 2023)](https://arxiv.org/abs/2309.17453), [(Han et al, 2024)](https://arxiv.org/abs/2308.16137) 36 | - **SnapKV** [(Li et al, 2024)](https://arxiv.org/abs/2404.14469) 37 | - **TOVA** [(Oren et al, 2023)](https://arxiv.org/abs/2401.06104) 38 | 39 | ## Installation 40 | 41 | Assuming you use a Python venv, installation is as simple as: 42 | 43 | ``` 44 | git clone https://github.com/PiotrNawrot/nano-sparse-attention 45 | cd nano-sparse-attention 46 | python3 -m venv .venv 47 | source .venv/bin/activate 48 | pip install --upgrade pip setuptools wheel psutil 49 | pip install -e ./ 50 | ``` 51 | 52 | ## Example Usage 53 | 54 | The following snippets show how to experiment with sparse attention in each inference stage: 55 | 56 | ### Prefilling Stage 57 | 58 | ```python 59 | from nano_sparse_attn.attention import InferenceHandler, DenseAttention, LocalAndSinksAttention 60 | from nano_sparse_attn.utils import load_model_and_tokenizer, load_examples, update_attention, model_forward 61 | 62 | # Load model and prepare inputs 63 | model, tokenizer = load_model_and_tokenizer() 64 | model_inputs = load_examples(tokenizer, num_examples=1) 65 | 66 | # Create an inference handler with Local Window + Attention Sinks 67 | handler = InferenceHandler( 68 | prefill_attention=LocalAndSinksAttention( 69 | window_size=256, 70 | attention_sinks=16 71 | ), 72 | generation_attention=DenseAttention() 73 | ) 74 | 75 | # Update model's attention mechanism and run forward pass 76 | update_attention(model, handler) 77 |
loss = model_forward(model, model_inputs, handler) 78 | 79 | # Get information about the attention mechanism 80 | info = handler.info() 81 | print(f"Loss: {loss}") 82 | print(f"Sparsity: {info['prefill']['sparsity']}") 83 | ``` 84 | 85 | ### Generation Stage 86 | 87 | ```python 88 | # Assumes imports from the previous example 89 | from nano_sparse_attn.attention import SnapKVAttention 90 | 91 | # Create an inference handler with SnapKV for generation 92 | handler = InferenceHandler( 93 | prefill_attention=DenseAttention(), 94 | generation_attention=SnapKVAttention( 95 | approximation_window=64, 96 | token_capacity=256 97 | ) 98 | ) 99 | 100 | # Update model's attention mechanism and run forward pass 101 | update_attention(model, handler) 102 | loss = model_forward(model, model_inputs, handler) 103 | 104 | # Get information about the attention mechanism 105 | info = handler.info() 106 | print(f"Loss: {loss}") 107 | print(f"Sparsity: {info['generation']['sparsity']}") 108 | ``` 109 | 110 | For ready-to-use scripts check out [main_prefill.py](./main_prefill.py) and [main_generate.py](./main_generate.py). 111 | For a detailed walkthrough of the repository and information about extending it to new models, datasets, and attention patterns, refer to [this README](./nano_sparse_attn/README.md). 112 | 113 | ## Contributing 114 | 115 | Contributions are welcome! Our goal is to keep this repository up-to-date with the latest Sparse Attention methods, by consistently adding new methods. Feel free to submit a Pull Request if 1) you want a new method to be added or 2) [even better] you have an implementation of a new Sparse Attention method! 
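Before you submit a new pattern, a quick sanity check is to build its mask in isolation and measure how sparse it is. The sketch below does this for the Local Window + Attention Sinks pattern. The helpers `local_and_sinks_mask` and `sparsity_ratio` are hypothetical and not part of the package. They assume the same conventions as `Attention` in `abstract.py`: `True` marks a pruned connection, and sparsity is measured against the `seq_len * (seq_len + 1) / 2` connections that a causal mask allows.

```python
import torch


def local_and_sinks_mask(seq_len, window_size, attention_sinks):
    """(seq_len, seq_len) boolean mask; True marks a pruned connection."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)
    causal = j <= i                          # never look ahead
    local = (i - j) < window_size            # window includes the current token
    sink = j < attention_sinks               # first tokens stay visible to every query
    keep = causal & (local | sink)
    return ~keep


def sparsity_ratio(mask):
    """Fraction of the causal connections that the mask prunes (0 = dense)."""
    seq_len = mask.size(-1)
    total = seq_len * (seq_len + 1) // 2
    kept = (~mask).sum().item()
    return 1 - kept / total


mask = local_and_sinks_mask(seq_len=8, window_size=2, attention_sinks=1)
print(mask.int())
print(round(sparsity_ratio(mask), 4))  # 0.4167 (21 of 36 causal connections kept)
```

Setting `window_size` equal to `seq_len` recovers dense causal attention with a ratio of 0. This boundary case is worth testing for any new pattern.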
116 | 117 | ## Authors 118 | 119 | Piotr Nawrot - [Website](https://piotrnawrot.github.io/) - piotr@nawrot.org 120 | 121 | Edoardo Maria Ponti - [Website](https://ducdauge.github.io/) - eponti@ed.ac.uk 122 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/nano-sparse-attention/24e0b5844024a6d17dd899af8f17a0c4042f31d0/assets/logo.png -------------------------------------------------------------------------------- /assets/pipfreeze.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.10.10 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==24.2.0 6 | certifi==2024.8.30 7 | charset-normalizer==3.4.0 8 | datasets==3.1.0 9 | dill==0.3.8 10 | filelock==3.16.1 11 | frozenlist==1.5.0 12 | fsspec==2024.9.0 13 | huggingface-hub==0.26.2 14 | idna==3.10 15 | Jinja2==3.1.4 16 | MarkupSafe==3.0.2 17 | mpmath==1.3.0 18 | multidict==6.1.0 19 | multiprocess==0.70.16 20 | networkx==3.4.2 21 | numpy==2.1.3 22 | nvidia-cublas-cu12==12.4.5.8 23 | nvidia-cuda-cupti-cu12==12.4.127 24 | nvidia-cuda-nvrtc-cu12==12.4.127 25 | nvidia-cuda-runtime-cu12==12.4.127 26 | nvidia-cudnn-cu12==9.1.0.70 27 | nvidia-cufft-cu12==11.2.1.3 28 | nvidia-curand-cu12==10.3.5.147 29 | nvidia-cusolver-cu12==11.6.1.9 30 | nvidia-cusparse-cu12==12.3.1.170 31 | nvidia-nccl-cu12==2.21.5 32 | nvidia-nvjitlink-cu12==12.4.127 33 | nvidia-nvtx-cu12==12.4.127 34 | packaging==24.1 35 | pandas==2.2.3 36 | propcache==0.2.0 37 | psutil==6.1.0 38 | pyarrow==18.0.0 39 | python-dateutil==2.9.0.post0 40 | pytz==2024.2 41 | PyYAML==6.0.2 42 | regex==2024.9.11 43 | requests==2.32.3 44 | safetensors==0.4.5 45 | six==1.16.0 46 | sympy==1.13.1 47 | tokenizers==0.20.1 48 | torch==2.5.1 49 | tqdm==4.66.6 50 | transformers==4.46.1 51 | triton==3.1.0 52 | typing_extensions==4.12.2 
53 | tzdata==2024.2 54 | urllib3==2.2.3 55 | xxhash==3.5.0 56 | yarl==1.17.1 57 | -------------------------------------------------------------------------------- /main_generate.py: -------------------------------------------------------------------------------- 1 | from nano_sparse_attn.utils import ( 2 | load_model_and_tokenizer, 3 | load_examples, 4 | model_forward, 5 | update_attention, 6 | CONSTANTS, 7 | ) 8 | 9 | from nano_sparse_attn.attention import ( 10 | InferenceHandler, 11 | DenseAttention, 12 | LocalAndSinksAttention, 13 | SnapKVAttention, 14 | TOVAAttention, 15 | ) 16 | 17 | 18 | model, tokenizer = load_model_and_tokenizer() 19 | model_inputs = load_examples( 20 | tokenizer, 21 | target_length_min=CONSTANTS['runtime_args']['target_length_min'], 22 | target_length_max=CONSTANTS['runtime_args']['target_length_max'], 23 | num_examples=CONSTANTS['runtime_args']['num_examples'], 24 | ) 25 | 26 | inference_handlers = [] 27 | 28 | # Dense baseline 29 | inference_handlers.append( 30 | InferenceHandler( 31 | prefill_attention=DenseAttention(), 32 | generation_attention=DenseAttention(), 33 | ) 34 | ) 35 | 36 | # Local and Sinks Generation 37 | for window_size in range(128, 1024+1, 128): 38 | inference_handlers.append( 39 | InferenceHandler( 40 | prefill_attention=DenseAttention(), 41 | generation_attention=LocalAndSinksAttention( 42 | window_size=window_size, 43 | attention_sinks=16, 44 | ), 45 | ) 46 | ) 47 | 48 | # SnapKV Generation 49 | for token_capacity in range(128, 1024+1, 128): 50 | inference_handlers.append( 51 | InferenceHandler( 52 | prefill_attention=DenseAttention(), 53 | generation_attention=SnapKVAttention( 54 | approximation_window=64, 55 | token_capacity=token_capacity, 56 | ), 57 | ) 58 | ) 59 | 60 | # TOVA Generation 61 | for token_capacity in range(128, 1024+1, 128): 62 | inference_handlers.append( 63 | InferenceHandler( 64 | prefill_attention=DenseAttention(), 65 | generation_attention=TOVAAttention(token_capacity=token_capacity), 66 | ) 
67 | ) 68 | 69 | for idx, inference_handler in enumerate(inference_handlers): 70 | update_attention(model, inference_handler) 71 | loss = model_forward(model, model_inputs, inference_handler) 72 | info_dict = inference_handler.info() 73 | print(f"InferenceHandler {idx}: \n-- Name: {info_dict['generation']['name']}\n-- Loss: {loss}\n-- Sparsity: {info_dict['generation']['sparsity']}\n-- Params: {info_dict['generation']['params']}") 74 | -------------------------------------------------------------------------------- /main_prefill.py: -------------------------------------------------------------------------------- 1 | from nano_sparse_attn.utils import ( 2 | load_model_and_tokenizer, 3 | load_examples, 4 | model_forward, 5 | update_attention, 6 | CONSTANTS, 7 | ) 8 | 9 | from nano_sparse_attn.attention import ( 10 | InferenceHandler, 11 | DenseAttention, 12 | LocalAndSinksAttention, 13 | VerticalAndSlashAttention, 14 | BlockSparseAttention, 15 | ) 16 | 17 | 18 | model, tokenizer = load_model_and_tokenizer() 19 | model_inputs = load_examples( 20 | tokenizer, 21 | target_length_min=CONSTANTS['runtime_args']['target_length_min'], 22 | target_length_max=CONSTANTS['runtime_args']['target_length_max'], 23 | num_examples=CONSTANTS['runtime_args']['num_examples'], 24 | ) 25 | 26 | inference_handlers = [] 27 | 28 | # Dense baseline 29 | inference_handlers.append( 30 | InferenceHandler( 31 | prefill_attention=DenseAttention(), 32 | generation_attention=DenseAttention(), 33 | ) 34 | ) 35 | 36 | for window_size in range(128, 1024+1, 128): 37 | inference_handlers.append( 38 | InferenceHandler( 39 | prefill_attention=LocalAndSinksAttention( 40 | window_size=window_size, 41 | attention_sinks=16, 42 | ), 43 | generation_attention=DenseAttention(), 44 | ) 45 | ) 46 | 47 | for top_k in range(128, 1024+1, 128): 48 | inference_handlers.append( 49 | InferenceHandler( 50 | prefill_attention=VerticalAndSlashAttention( 51 | top_tokens=top_k, 52 | top_slashes=top_k, 53 | window_size=128, 54 
| approximation_window=64, 55 | attention_sinks=16, 56 | ), 57 | generation_attention=DenseAttention(), 58 | ) 59 | ) 60 | 61 | for top_chunks in range(2, 16+1, 2): 62 | inference_handlers.append( 63 | InferenceHandler( 64 | prefill_attention=BlockSparseAttention( 65 | chunk_size=64, 66 | top_chunks=top_chunks, 67 | ), 68 | generation_attention=DenseAttention(), 69 | ) 70 | ) 71 | 72 | for idx, inference_handler in enumerate(inference_handlers): 73 | update_attention(model, inference_handler) 74 | loss = model_forward(model, model_inputs, inference_handler) 75 | info_dict = inference_handler.info() 76 | print(f"InferenceHandler {idx}: \n-- Name: {info_dict['prefill']['name']}\n-- Loss: {loss}\n-- Sparsity: {info_dict['prefill']['sparsity']}\n-- Params: {info_dict['prefill']['params']}") 77 | -------------------------------------------------------------------------------- /nano_sparse_attn/README.md: -------------------------------------------------------------------------------- 1 | ## Technical Details & Extension 2 | 3 | ### Inference Handler & Attention Mechanism 4 | 5 | The core of nanoSparseAttention is built around two key components: 6 | 7 | 1. **InferenceHandler** (`attention/inference_handler.py`): Manages the switching between prefilling and generation attention patterns. It allows you to: 8 | - Use different attention patterns for prefilling and generation stages 9 | - Track sparsity ratios and attention masks for visualization 10 | 11 | 2. **Base Attention** (`attention/abstract.py`): Provides the foundation for implementing sparse attention patterns with: 12 | - Common utility functions (`get_causal_mask`, `get_local_mask`, etc.) 13 | - Sparsity calculation and mask tracking 14 | - Standard attention computation methods 15 | 16 | ### Creating New Sparse Patterns 17 | 18 | To implement a new sparse attention pattern: 19 | 20 | 1. Inherit from the base `Attention` class 21 | 2. 
Implement one or both of: 22 | - `forward()` for prefilling attention 23 | - `generation_forward()` for generation attention 24 | 3. Create attention masks that specify which connections to keep/prune 25 | 26 | Example skeleton: 27 | ```python 28 | class MyNewAttention(Attention): 29 | def __init__(self, window_size): 30 | super().__init__() 31 | self.name = 'MyNewAttention' 32 | self.params = {"window_size": window_size} 33 | 34 | def forward(self, queries, keys, values, *args, **kwargs): 35 | # Create your attention mask (True marks connections to prune) 36 | attention_mask = ... 37 | 38 | # Track sparsity 39 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 40 | 41 | # Compute attention 42 | return self.attention(queries, keys, values, attention_mask) 43 | ``` 44 | 45 | ### Supporting New Models 46 | 47 | The repository only supports Llama-style attention mechanisms. To add support for a new model: 48 | 49 | 1. Locate the attention implementation in HuggingFace's transformers library (e.g., `modeling_llama.py` for Llama) 50 | 2. Copy and adapt the attention's forward pass to use our inference handler 51 | 3. Ensure correct tensor shapes for queries, keys, and values 52 | 53 | ### Using Custom Datasets 54 | 55 | The repository can be adapted to work with any dataset. The key requirements are: 56 | 57 | 1. Implement a data loading function that returns examples in the format: 58 | ```python 59 | { 60 | 'input_ids': tensor, # Input token ids 61 | 'attention_mask': tensor, # Attention mask 62 | 'output_length': int, # Length of the output sequence 63 | } 64 | ``` 65 | 66 | 2. If using instruction-tuned models (like we do), format your prompts appropriately. Our implementation uses chat templates for Llama-style models, but this is configurable. 67 | 68 | Example adaptation: 69 | ```python 70 | def load_my_dataset(tokenizer, **kwargs): 71 | # Load your data 72 | dataset = ...
73 | examples = [] 74 | 75 | for item in dataset: 76 | # Format your prompt 77 | input_text = f"Your prompt: {item['input']}" 78 | 79 | # Tokenize (with or without chat template) 80 | tokens = tokenizer( 81 | input_text, 82 | return_tensors="pt", 83 | ) 84 | 85 | # Get output length for loss calculation 86 | output_length = ... 87 | examples.append({ 88 | 'input_ids': tokens['input_ids'], 89 | 'attention_mask': tokens['attention_mask'], 90 | 'output_length': output_length, 91 | }) 92 | 93 | return examples 94 | ``` 95 | -------------------------------------------------------------------------------- /nano_sparse_attn/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention.abstract import Attention 2 | from .attention.sparse_attention import ( 3 | DenseAttention, 4 | LocalAndSinksAttention, 5 | SnapKVAttention, 6 | TOVAAttention, 7 | ) 8 | from .attention.inference_handler import InferenceHandler 9 | 10 | __all__ = [ 11 | "Attention", 12 | "DenseAttention", 13 | "LocalAndSinksAttention", 14 | "SnapKVAttention", 15 | "TOVAAttention", 16 | "InferenceHandler", 17 | ] 18 | -------------------------------------------------------------------------------- /nano_sparse_attn/attention/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract import Attention 2 | from .sparse_attention import ( 3 | DenseAttention, 4 | LocalAndSinksAttention, 5 | VerticalAndSlashAttention, 6 | BlockSparseAttention, 7 | SnapKVAttention, 8 | TOVAAttention, 9 | ) 10 | from .inference_handler import InferenceHandler 11 | 12 | __all__ = [ 13 | "Attention", 14 | "DenseAttention", 15 | "LocalAndSinksAttention", 16 | "VerticalAndSlashAttention", 17 | "BlockSparseAttention", 18 | "SnapKVAttention", 19 | "TOVAAttention", 20 | "InferenceHandler", 21 | ] 22 | -------------------------------------------------------------------------------- /nano_sparse_attn/attention/abstract.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.name = 'Attention' 10 | self.masks = [] 11 | self.sparsity_ratios = [] 12 | self.params = {} 13 | self.layer_counter = 0 14 | 15 | def info(self): 16 | return { 17 | "name": self.name, 18 | "params": self.params, 19 | "masks": self.masks, 20 | "sparsity": self.calculate_average_sparsity_ratio(self.sparsity_ratios) 21 | } 22 | 23 | def __getattr__(self, name): 24 | if name in self.params: 25 | return self.params[name] 26 | raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") 27 | 28 | def maybe_save_mask(self, attention_mask, modulo_layers=8): 29 | if self.layer_counter % modulo_layers == 0: 30 | self.masks.append(attention_mask[0, 0].clone().cpu().numpy()) 31 | 32 | self.layer_counter += 1 33 | 34 | def forward(self, queries, keys, values, *args, **kwargs): 35 | """Base forward method for prefilling attention patterns. 36 | 37 | Args: 38 | queries: (batch_size, num_heads, seq_len, head_dim) 39 | keys: (batch_size, num_heads, seq_len, head_dim) 40 | values: (batch_size, num_heads, seq_len, head_dim) 41 | 42 | Returns: 43 | attention_output: (batch_size, num_heads, seq_len, head_dim) 44 | """ 45 | raise NotImplementedError 46 | 47 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values, 48 | generation_queries, generation_keys, generation_values, *args, **kwargs): 49 | """Base forward method for generation attention patterns. 
50 | 51 | Args: 52 | prefilling_queries: Queries from prefilling stage 53 | prefilling_keys: Keys from prefilling stage 54 | prefilling_values: Values from prefilling stage 55 | generation_queries: New queries for generation 56 | generation_keys: New keys for generation 57 | generation_values: New values for generation 58 | 59 | Returns: 60 | attention_output: Output for generation tokens 61 | """ 62 | raise NotImplementedError 63 | 64 | @staticmethod 65 | def attention(queries, keys, values, attention_mask, return_attention_scores=False): 66 | """Standard attention computation.""" 67 | attention_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(queries.size(-1)) 68 | attention_weights += attention_mask.to(queries.dtype) * torch.finfo(queries.dtype).min 69 | attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(queries.dtype) 70 | attention_output = torch.matmul(attention_weights, values) 71 | if return_attention_scores: 72 | return attention_output, attention_weights 73 | return attention_output 74 | 75 | @staticmethod 76 | def get_causal_mask(seq_len, device): 77 | """Creates a causal mask that prevents each token from attending to future tokens. 78 | 79 | Args: 80 | seq_len: Length of the sequence 81 | device: Device to create the mask on 82 | 83 | Returns: 84 | mask: Boolean tensor of shape (1, 1, seq_len, seq_len) where True/1 indicates 85 | that position (i,j) should be masked (set to -inf before softmax) 86 | """ 87 | mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device) 88 | mask = torch.triu(mask, diagonal=1) 89 | return mask.unsqueeze(0).unsqueeze(0) 90 | 91 | @staticmethod 92 | def get_local_mask(seq_len, window_size, device): 93 | """Creates a local attention mask where tokens can only attend to nearby tokens 94 | within a fixed window size, plus the causal constraint.
95 | 96 | Args: 97 | seq_len: Length of the sequence 98 | window_size: Size of the local attention window including current token 99 | device: Device to create the mask on 100 | 101 | Returns: 102 | mask: Boolean tensor of shape (1, 1, seq_len, seq_len) where True/1 indicates 103 | that position (i,j) should be masked (set to -inf before softmax) 104 | """ 105 | mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device) 106 | mask = torch.triu(mask, diagonal=-(window_size-1)) 107 | mask = torch.tril(mask, diagonal=0) 108 | mask = (~mask) 109 | return mask.unsqueeze(0).unsqueeze(0) 110 | 111 | @staticmethod 112 | def get_generation_mask(gen_len, prefill_len, device): 113 | """Creates a mask that allows generation tokens to: 114 | 1. Attend to all prefilling tokens 115 | 2. Attend causally to other generation tokens 116 | 117 | Args: 118 | gen_len: Number of tokens being generated 119 | prefill_len: Number of tokens in the prefill context 120 | device: Device to create the mask on 121 | 122 | Returns: 123 | mask: Boolean tensor of shape (1, 1, gen_len, prefill_len + gen_len) 124 | where True indicates positions that should be masked 125 | """ 126 | return torch.cat([ 127 | torch.zeros((1, 1, gen_len, prefill_len), dtype=torch.bool, device=device), 128 | Attention.get_causal_mask(gen_len, device) 129 | ], dim=-1) 130 | 131 | @staticmethod 132 | def calculate_sparsity_ratio(mask): 133 | """Calculates the sparsity ratio of an attention mask. 134 | 135 | This method computes what fraction of the possible attention connections are masked, 136 | assuming that attention is causal, i.e., that tokens cannot attend to tokens after them. 137 | A higher ratio means sparser attention. Assumes batch_size = 1.
138 | 139 | Args: 140 | mask: Boolean tensor of shape (batch_size, num_heads, queries_len, keys_len) where 141 | True/1 indicates masked (disabled) attention connections 142 | 143 | Returns: 144 | float: The sparsity ratio between 0 and 1, where: 145 | 0 means all possible connections are enabled (dense attention) 146 | 1 means all possible connections are masked (completely sparse) 147 | """ 148 | 149 | _, _, queries_len, keys_len = mask.shape 150 | 151 | if queries_len != keys_len: 152 | prefill_length = keys_len - queries_len 153 | total_connections = queries_len * (queries_len + 1) // 2 + prefill_length * queries_len 154 | else: 155 | total_connections = queries_len * (queries_len + 1) // 2 156 | 157 | connections_per_head = (~mask).long().sum(dim=(-1, -2)) 158 | non_masked_ratio = (connections_per_head.float() / total_connections).mean(dim=-1).item() 159 | sparsity_ratio = 1 - non_masked_ratio 160 | 161 | return sparsity_ratio 162 | 163 | @staticmethod 164 | def calculate_average_sparsity_ratio(sparsity_ratios): 165 | return sum(sparsity_ratios) / len(sparsity_ratios) if sparsity_ratios else 0 166 | -------------------------------------------------------------------------------- /nano_sparse_attn/attention/inference_handler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class InferenceHandler(nn.Module): 5 | """Handles different attention patterns for prefilling and generation stages. 6 | 7 | This class manages the splitting of input sequences into prefilling and generation 8 | stages, and applies the appropriate attention pattern for each stage. 
9 | 10 | Args: 11 | prefill_attention: Attention module to use during prefilling 12 | generation_attention: Attention module to use during generation 13 | """ 14 | def __init__(self, prefill_attention, generation_attention): 15 | super().__init__() 16 | self.prefill_attention = prefill_attention 17 | self.generation_attention = generation_attention 18 | self.name = f"{prefill_attention.name}_to_{generation_attention.name}" 19 | 20 | def set_output_length(self, output_length): 21 | self.output_length = output_length 22 | 23 | def info(self): 24 | return { 25 | "prefill": self.prefill_attention.info(), 26 | "generation": self.generation_attention.info() 27 | } 28 | 29 | def forward(self, queries, keys, values, layer_idx=None): 30 | """Handles attention computation for both prefilling and generation stages. 31 | 32 | Args: 33 | queries: (batch_size, num_heads, seq_len, head_dim) 34 | keys: (batch_size, num_heads, seq_len, head_dim) 35 | values: (batch_size, num_heads, seq_len, head_dim) 36 | layer_idx: Optional layer index for some attention implementations 37 | 38 | Returns: 39 | attention_output: Combined output from both stages 40 | """ 41 | # Split sequence into prefill and generation parts 42 | input_length = queries.size(-2) - self.output_length 43 | assert input_length > 0, "Input length must be > 0" 44 | 45 | # Prefilling stage 46 | prefill_output = self.prefill_attention.forward( 47 | queries=queries[..., :input_length, :], 48 | keys=keys[..., :input_length, :], 49 | values=values[..., :input_length, :], 50 | layer_idx=layer_idx 51 | ) 52 | 53 | # Generation stage 54 | generation_output = self.generation_attention.generation_forward( 55 | prefilling_queries=queries[..., :input_length, :], 56 | prefilling_keys=keys[..., :input_length, :], 57 | prefilling_values=values[..., :input_length, :], 58 | generation_queries=queries[..., input_length:, :], 59 | generation_keys=keys[..., input_length:, :], 60 | generation_values=values[..., input_length:, :], 61 | 
layer_idx=layer_idx 62 | ) 63 | 64 | # Combine outputs 65 | return torch.cat([prefill_output, generation_output], dim=-2) 66 | -------------------------------------------------------------------------------- /nano_sparse_attn/attention/sparse_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from .abstract import Attention 6 | 7 | 8 | class DenseAttention(Attention): 9 | def __init__(self): 10 | super().__init__() 11 | self.name = 'DenseAttention' 12 | 13 | def forward(self, queries, keys, values, *args, **kwargs): 14 | attention_mask = self.get_causal_mask(queries.size(-2), queries.device) 15 | self.maybe_save_mask(attention_mask) 16 | return self.attention(queries, keys, values, attention_mask) 17 | 18 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values, 19 | generation_queries, generation_keys, generation_values, *args, **kwargs): 20 | keys = torch.cat([prefilling_keys, generation_keys], dim=-2) 21 | values = torch.cat([prefilling_values, generation_values], dim=-2) 22 | 23 | generation_mask = self.get_generation_mask( 24 | gen_len=generation_queries.size(-2), 25 | prefill_len=prefilling_keys.size(-2), 26 | device=generation_queries.device 27 | ) 28 | 29 | return self.attention(generation_queries, keys, values, generation_mask) 30 | 31 | 32 | class LocalAndSinksAttention(Attention): 33 | """Implements a sparse attention pattern combining local windows with global attention sinks. 34 | 35 | This attention mechanism reduces computational complexity by: 36 | 1. Allowing each token to attend only to nearby tokens within a fixed window 37 | 2. Designating the first K tokens as "attention sinks" that can be attended to by all tokens 38 | 39 | This creates a sparse pattern where most tokens have local connectivity, but important 40 | context tokens (sinks) maintain global connectivity. 
41 | 42 | Args: 43 | window_size (int): Number of most recent tokens each token attends to locally, including itself 44 | attention_sinks (int): Number of initial tokens that serve as global attention sinks 45 | 46 | Reference Papers: 47 | - (Xiao et al, 2023) https://arxiv.org/abs/2309.17453 48 | - (Han et al, 2024) https://arxiv.org/abs/2308.16137 49 | """ 50 | def __init__(self, window_size, attention_sinks): 51 | super().__init__() 52 | self.name = 'LocalAndSinksAttention' 53 | self.params = { 54 | "window_size": window_size, 55 | "attention_sinks": attention_sinks 56 | } 57 | 58 | def create_mask(self, seq_len, device): 59 | """Creates sparse attention mask with local windows and global sinks.""" 60 | assert self.window_size <= seq_len, "Window size must be less than or equal to the sequence length" 61 | assert self.attention_sinks <= seq_len, "Number of attention sinks must be less than or equal to the sequence length" 62 | 63 | # Create base local attention window 64 | mask = self.get_local_mask(seq_len, self.window_size, device) 65 | 66 | # Allow attention to sink tokens 67 | mask[..., :, :self.attention_sinks] = 0 68 | mask = mask | self.get_causal_mask(seq_len, device) 69 | 70 | return mask 71 | 72 | def forward(self, queries, keys, values, *args, **kwargs): 73 | """Sparse attention for prefilling.""" 74 | attention_mask = self.create_mask(queries.size(-2), queries.device) 75 | 76 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 77 | self.maybe_save_mask(attention_mask) 78 | 79 | return self.attention(queries, keys, values, attention_mask) 80 | 81 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values, 82 | generation_queries, generation_keys, generation_values, *args, **kwargs): 83 | assert self.attention_sinks <= prefilling_queries.size(-2) 84 | 85 | total_keys = torch.cat([ 86 | prefilling_keys, 87 | generation_keys, 88 | ], dim=-2) 89 | 90 | total_values = torch.cat([ 91 | prefilling_values, 92 | generation_values, 93 | ], dim=-2) 94
| 95 | attention_mask = self.create_mask(total_keys.size(-2), generation_queries.device) 96 | attention_mask = attention_mask[..., -generation_queries.size(-2):, :] 97 | 98 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 99 | self.maybe_save_mask(attention_mask) 100 | 101 | return self.attention(generation_queries, total_keys, total_values, attention_mask) 102 | 103 | 104 | class VerticalAndSlashAttention(Attention): 105 | """Implements a content-dependent sparse attention combining top tokens and diagonal patterns, 106 | known as Vertical-Slash, introduced by (Jiang et al, 2024) in the MInference 1.0 paper. 107 | 108 | This attention mechanism creates a sparse pattern by combining two types of connectivity: 109 | 1. Dynamic attention to top-K most relevant tokens (content-based) 110 | 2. Dynamic attention along top-K diagonal stripes (structure-based) 111 | 112 | The diagonal stripes (slashes) capture recurring patterns at fixed offsets, while 113 | top tokens capture semantic relevance. This combines structural and semantic sparsity. 
114 | 115 | Args: 116 | attention_sinks (int): Number of initial tokens that serve as global attention sinks 117 | window_size (int): Size of the local attention window around each token 118 | top_tokens (int): Number of highest-scoring tokens to attend to globally 119 | top_slashes (int): Number of diagonal stripes to include in the pattern 120 | approximation_window (int): Size of the window used to approximate attention scores 121 | 122 | Reference Papers: 123 | - (Jiang et al, 2024) https://arxiv.org/abs/2407.02490 124 | """ 125 | def __init__(self, attention_sinks, window_size, top_tokens, top_slashes, approximation_window): 126 | super().__init__() 127 | self.name = 'VerticalAndSlashAttention' 128 | self.params = { 129 | "attention_sinks": attention_sinks, 130 | "window_size": window_size, 131 | "top_tokens": top_tokens, 132 | "top_slashes": top_slashes, 133 | "approximation_window": approximation_window, 134 | } 135 | 136 | def fill_diagonals(self, L: int, indexes: torch.Tensor, device='cuda', chunk_size=1024): 137 | """ 138 | Memory-efficient implementation that creates an L x L mask with the selected diagonals filled with False. 139 | Processes rows in chunks to reduce peak memory usage. 140 | 141 | Parameters: 142 | - L (int): The size of the square matrix. 143 | - indexes (torch.Tensor): Tensor of shape (B, H, M) with integer values which determine 144 | the diagonals to be filled in the output matrix for each batch and head. 145 | - device: The device to perform computations on ('cuda' or 'cpu'). 146 | - chunk_size: Number of rows to process at once to manage memory usage. 147 | 148 | Returns: 149 | - mask (torch.Tensor): A boolean matrix of size (B, H, L, L) with the specified diagonals filled with False (unmasked) and all other entries True. 150 | """ 151 | assert (indexes <= 0).all().item(), "Indexes must be on or below diagonal."
152 | 153 | batch_size, num_heads, num_diagonals = indexes.shape 154 | 155 | # Create output tensor 156 | mask_dense = torch.ones((batch_size, num_heads, L, L), dtype=torch.bool, device=device) 157 | 158 | # Process the sequence length in chunks 159 | for chunk_start in range(0, L, chunk_size): 160 | chunk_end = min(chunk_start + chunk_size, L) 161 | chunk_len = chunk_end - chunk_start 162 | 163 | # Create row indices for this chunk 164 | row_indices = torch.arange(chunk_start, chunk_end, device=device, dtype=torch.int32) 165 | row_indices = row_indices.view(1, 1, 1, chunk_len) 166 | row_indices = row_indices.expand(batch_size, num_heads, num_diagonals, chunk_len) 167 | 168 | # Add the diagonal offsets to get column indices 169 | col_indices = row_indices + indexes.unsqueeze(-1).to(torch.int32) 170 | 171 | # Mask out indices that are out of bounds 172 | valid_mask = (col_indices >= 0) & (col_indices < L) 173 | 174 | if not valid_mask.any(): 175 | continue 176 | 177 | # Create batch and head indices for valid positions only 178 | batch_idx = torch.arange(batch_size, device=device).view(-1, 1, 1, 1) 179 | batch_idx = batch_idx.expand(-1, num_heads, num_diagonals, chunk_len) 180 | head_idx = torch.arange(num_heads, device=device).view(1, -1, 1, 1) 181 | head_idx = head_idx.expand(batch_size, -1, num_diagonals, chunk_len) 182 | 183 | # Select only valid indices 184 | valid_batch_idx = batch_idx[valid_mask] 185 | valid_head_idx = head_idx[valid_mask] 186 | valid_row_idx = row_indices[valid_mask] 187 | valid_col_idx = col_indices[valid_mask] 188 | 189 | # Set the valid diagonal elements to False 190 | mask_dense[valid_batch_idx, valid_head_idx, valid_row_idx, valid_col_idx] = False 191 | 192 | # Free memory explicitly 193 | del row_indices, col_indices, valid_mask, batch_idx, head_idx 194 | 195 | return mask_dense 196 | 197 | def sum_over_diagonals(self, matrix): 198 | """Efficiently sum values along diagonals of the attention matrix. 
199 | 200 | This method uses stride tricks to efficiently extract and sum diagonals: 201 | 1. Pad the matrix with zeros on both sides to handle all possible diagonals 202 | 2. Create a strided view that groups elements along diagonals 203 | 3. Sum along each diagonal to get their total attention scores 204 | 205 | The strided operation creates a view where each row contains elements from 206 | one diagonal, allowing for efficient parallel summation. 207 | 208 | Args: 209 | matrix: Attention scores tensor of shape (batch_size, num_heads, queries, keys) 210 | 211 | Returns: 212 | Tensor of shape (batch_size, num_heads, queries + keys - 1) containing 213 | summed attention scores for each diagonal 214 | 215 | This function is based on the implementation from: 216 | https://github.com/microsoft/MInference/blob/main/minference/modules/minference_forward.py#L101 217 | """ 218 | batch_size, num_heads, queries, keys = matrix.shape 219 | zero_matrix = torch.zeros((batch_size, num_heads, queries, queries), device=matrix.device) 220 | matrix_padded = torch.cat((zero_matrix, matrix, zero_matrix), -1) 221 | matrix_strided = matrix_padded.as_strided( 222 | ( 223 | batch_size, 224 | num_heads, 225 | queries, 226 | queries + keys 227 | ), 228 | ( 229 | num_heads * queries * (2 * queries + keys), 230 | queries * (2 * queries + keys), 231 | 2 * queries + keys + 1, 232 | 1 233 | ) 234 | ) 235 | sum_diagonals = torch.sum(matrix_strided, 2) 236 | return sum_diagonals[:, :, 1:] 237 | 238 | def create_mask(self, queries, keys, seq_len, device): 239 | assert self.top_slashes <= seq_len 240 | assert self.top_tokens <= seq_len 241 | assert self.top_slashes >= self.window_size 242 | assert self.attention_sinks <= self.top_tokens 243 | 244 | # Approximate attention scores 245 | approx_queries = queries[..., -self.approximation_window:, :] 246 | attention_scores = torch.matmul(approx_queries, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1)) 247 | attention_scores[..., 
-self.approximation_window:] += self.get_causal_mask(self.approximation_window, device) * torch.finfo(queries.dtype).min 248 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype) 249 | 250 | # Get top tokens 251 | summed_scores = attention_scores.sum(dim=-2) 252 | summed_scores[..., :self.attention_sinks] = torch.inf 253 | top_k_indices = torch.topk(summed_scores, self.top_tokens, dim=-1, sorted=False).indices 254 | 255 | # Get top_k slashes 256 | top_k_slash = self.sum_over_diagonals(attention_scores)[..., :-self.approximation_window + 1] 257 | top_k_slash[..., -self.window_size:] = torch.inf 258 | top_k_slash = torch.topk(top_k_slash, self.top_slashes, -1).indices - (seq_len - 1) 259 | 260 | # Get final mask 261 | mask = self.fill_diagonals(seq_len, top_k_slash, device) 262 | mask.scatter_(-1, top_k_indices.unsqueeze(-2).expand(-1, -1, seq_len, -1), 0) 263 | mask = mask | self.get_causal_mask(seq_len, device) 264 | 265 | return mask 266 | 267 | def forward(self, queries, keys, values, *args, **kwargs): 268 | """Sparse attention for prefilling using top tokens and slashes.""" 269 | attention_mask = self.create_mask(queries, keys, queries.size(-2), queries.device) 270 | 271 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 272 | self.maybe_save_mask(attention_mask) 273 | 274 | return self.attention(queries, keys, values, attention_mask) 275 | 276 | 277 | class BlockSparseAttention(Attention): 278 | """Implements a Block-Sparse Attention pattern based on chunk-level relevance, 279 | introduced by (Jiang et al, 2024) in the MInference 1.0 paper. 280 | 281 | This attention mechanism reduces complexity by operating on chunks of tokens: 282 | 1. Divides the sequence into fixed-size chunks. 283 | 2. Computes chunk-level attention scores with averaged token representations. 284 | 3. Allows each chunk to attend to the top-K most relevant chunks. 
285 | 286 | This creates a coarse-grained sparse pattern where entire blocks of tokens attend 287 | to each other, making it efficient for long sequences while preserving approximate 288 | attention patterns. 289 | 290 | Contrary to the original paper, in this implementation we additionally: 291 | 1. Include the local window of size chunk_size around each token because we found 292 | it crucial for performance and it is non trivial to be executed by this pattern 293 | for high sparsity ratios. 294 | 2. Always pick the prefix chunk because we also found it crucial for performance but 295 | it is not always picked as top-1 chunk. 296 | 297 | For this reason, top_chunks = 1 BlockSparseAttention is equivalent to 298 | LocalAndSinksAttention(window_size=chunk_size, attention_sinks=chunk_size). 299 | 300 | Args: 301 | chunk_size (int): Size of each token chunk/block 302 | top_chunks (int): Number of highest-scoring chunks each chunk can attend to 303 | 304 | Reference Papers: 305 | - (Jiang et al, 2024) https://arxiv.org/abs/2407.02490 306 | """ 307 | def __init__(self, chunk_size, top_chunks): 308 | super().__init__() 309 | self.name = 'BlockSparseAttention' 310 | self.params = { 311 | "chunk_size": chunk_size, 312 | "top_chunks": top_chunks, 313 | } 314 | 315 | def create_mask(self, queries, keys, seq_len, device): 316 | assert self.chunk_size < seq_len, "Chunk size must be smaller than sequence length" 317 | assert self.top_chunks > 0, "Must select at least one top chunk" 318 | assert self.chunk_size >= 8, "Recommended chunk size is >= 8." 
319 | 320 | # Calculate number of chunks and padding needed 321 | num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size 322 | padded_seq_len = num_chunks * self.chunk_size 323 | padding_len = padded_seq_len - seq_len 324 | 325 | assert self.top_chunks <= num_chunks, "Cannot select more top chunks than available chunks" 326 | 327 | # Pad queries and keys if needed 328 | if padding_len > 0: 329 | queries_padded = torch.nn.functional.pad(queries, (0, 0, 0, padding_len)) 330 | keys_padded = torch.nn.functional.pad(keys, (0, 0, 0, padding_len)) 331 | else: 332 | queries_padded = queries 333 | keys_padded = keys 334 | 335 | # Reshape padded queries and keys to chunks 336 | query_chunks = queries_padded.reshape(queries.size(0), queries.size(1), num_chunks, self.chunk_size, -1) 337 | key_chunks = keys_padded.reshape(keys.size(0), keys.size(1), num_chunks, self.chunk_size, -1) 338 | 339 | # Compute chunk representations by averaging 340 | query_chunk_representations = query_chunks.mean(dim=-2) 341 | key_chunk_representations = key_chunks.mean(dim=-2) 342 | 343 | # Compute attention scores between chunk representations 344 | chunk_attention_scores = torch.matmul(query_chunk_representations, key_chunk_representations.transpose(-2, -1)) 345 | 346 | # Add causal masking of upper triangle 347 | chunk_attention_scores.masked_fill_(self.get_causal_mask(num_chunks, device), float('-inf')) 348 | 349 | # Always pick the prefix chunk 350 | chunk_attention_scores[..., 0] = float('inf') 351 | 352 | # Get top-k key chunks for each query chunk 353 | top_k_chunk_indices = torch.topk(chunk_attention_scores, self.top_chunks, dim=-1, sorted=False).indices 354 | 355 | # Create a mask for top-k interactions 356 | top_k_mask = torch.ones((num_chunks, num_chunks), dtype=torch.bool, device=device) 357 | top_k_mask = top_k_mask.unsqueeze(0).unsqueeze(0).repeat(queries.size(0), queries.size(1), 1, 1) 358 | top_k_mask.scatter_(-1, top_k_chunk_indices, 0) 359 | 360 | # Expand mask to padded 
sequence length 361 | mask = top_k_mask.repeat_interleave(self.chunk_size, dim=-2).repeat_interleave(self.chunk_size, dim=-1) 362 | 363 | # Include the local window of size chunk_size around each token 364 | mask = mask & self.get_local_mask(padded_seq_len, self.chunk_size, queries.device) 365 | 366 | # Include the causal mask 367 | mask = mask | self.get_causal_mask(padded_seq_len, queries.device) 368 | 369 | # Remove padding from mask if needed 370 | if padding_len > 0: 371 | mask = mask[..., :seq_len, :seq_len] 372 | 373 | return mask 374 | 375 | def forward(self, queries, keys, values, *args, **kwargs): 376 | """Block-sparse attention for prefilling.""" 377 | attention_mask = self.create_mask(queries, keys, queries.size(-2), queries.device) 378 | 379 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 380 | self.maybe_save_mask(attention_mask) 381 | 382 | return self.attention(queries, keys, values, attention_mask) 383 | 384 | 385 | class SnapKVAttention(Attention): 386 | """Implements SnapKV's attention pattern for efficient text generation. 387 | 388 | This attention mechanism compresses the Pre-Filling KV Cache, therefore reducing 389 | memory usage during generation, by: 390 | 1. Approximating attention scores using a suffix window of queries 391 | 2. Using attention score pooling to identify important token clusters 392 | 3. 
Keeping only the most relevant tokens plus a recent window 393 | 394 | Args: 395 | token_capacity (int): Maximum number of tokens to keep in compressed history 396 | approximation_window (int): Number of suffix tokens used to approximate attention scores 397 | kernel_size (int): Size of the pooling kernel for score aggregation 398 | 399 | Reference: 400 | - (Li et al, 2024) https://arxiv.org/abs/2404.14469 401 | """ 402 | def __init__(self, token_capacity, approximation_window=64, kernel_size=7): 403 | super().__init__() 404 | self.name = 'SnapKVAttention' 405 | self.params = { 406 | "token_capacity": token_capacity, 407 | "approximation_window": approximation_window, 408 | "kernel_size": kernel_size, 409 | } 410 | 411 | def create_mask_for_prefill(self, queries, keys, seq_len, device): 412 | """Create mask for generation with compressed KV cache. 413 | 414 | It uses prefilling queries and keys to estimate which tokens will be important 415 | during generation and keeps only the most important tokens plus a recent window. 416 | 417 | The output mask only concerns the prefilling tokens and has to be extended to 418 | account for the tokens from the generation phase. 
419 | """ 420 | assert self.token_capacity >= self.approximation_window 421 | 422 | # Approximate attention scores 423 | approx_queries = queries[..., -self.approximation_window:, :] 424 | attention_scores = torch.matmul(approx_queries, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1)) 425 | attention_scores[..., -self.approximation_window:] += self.get_causal_mask(self.approximation_window, device) * torch.finfo(queries.dtype).min 426 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype) 427 | 428 | # Sum over queries to get per-key importance and apply average pooling 429 | key_importance = attention_scores.sum(dim=-2) 430 | key_importance = F.avg_pool1d( 431 | key_importance, 432 | kernel_size=self.kernel_size, 433 | padding=self.kernel_size // 2, 434 | stride=1 435 | ) 436 | 437 | # Always keep the window tokens 438 | key_importance[..., -self.approximation_window:] = torch.inf 439 | 440 | # Keep only the top_k tokens 441 | mask = torch.ones((queries.size(0), queries.size(1), seq_len), dtype=torch.bool, device=device) 442 | top_indices = key_importance.topk(self.token_capacity, dim=-1).indices 443 | mask.scatter_(-1, top_indices, False) 444 | 445 | # Expand mask to proper attention shape 446 | mask = mask.unsqueeze(-2) 447 | 448 | return mask 449 | 450 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values, 451 | generation_queries, generation_keys, generation_values, *args, **kwargs): 452 | # Concatenate prefilling and generation KV states 453 | keys = torch.cat([prefilling_keys, generation_keys], dim=-2) 454 | values = torch.cat([prefilling_values, generation_values], dim=-2) 455 | 456 | # Create mask for prefilling tokens 457 | prefill_mask = self.create_mask_for_prefill(prefilling_queries, prefilling_keys, prefilling_keys.size(-2), generation_queries.device) 458 | 459 | # Create dense causal mask for generation tokens 460 | attention_mask = 
self.get_generation_mask( 461 | gen_len=generation_queries.size(-2), 462 | prefill_len=prefilling_keys.size(-2), 463 | device=generation_queries.device 464 | ).repeat(generation_queries.size(0), generation_queries.size(1), 1, 1) 465 | 466 | # Combine masks 467 | attention_mask[..., :prefilling_keys.size(-2)] |= prefill_mask 468 | 469 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask)) 470 | self.maybe_save_mask(attention_mask) 471 | 472 | return self.attention(generation_queries, keys, values, attention_mask) 473 | 474 | 475 | class TOVAAttention(Attention): 476 | """Implements Token Omission Via Attention (TOVA) for efficient text generation. 477 | 478 | This attention mechanism dynamically prunes the KV cache during generation by: 479 | 1. Using the last prefilling token to identify initially important tokens 480 | 2. During generation, maintaining a fixed-size cache by removing the least attended token 481 | after processing each new token 482 | 483 | Args: 484 | token_capacity (int): Maximum number of tokens to keep in the pruned KV cache 485 | 486 | Reference: 487 | - (Oren et al, 2023) https://arxiv.org/abs/2401.06104 488 | """ 489 | def __init__(self, token_capacity): 490 | super().__init__() 491 | self.name = 'TOVAAttention' 492 | self.params = { 493 | "token_capacity": token_capacity 494 | } 495 | 496 | def create_mask_for_prefill(self, queries, keys, seq_len, device): 497 | """Create initial mask based on attention scores from the last prefilling token.""" 498 | if self.token_capacity >= seq_len: 499 | return torch.zeros((queries.size(0), queries.size(1), 1, seq_len), dtype=torch.bool, device=device) 500 | 501 | # Get attention scores for the last prefilling token 502 | last_query = queries[..., -1:, :] 503 | attention_scores = torch.matmul(last_query, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1)) 504 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype) 505 | 506 
| # Average attention scores across heads 507 | mean_attention_scores = attention_scores.mean(dim=1) 508 | 509 | # Create mask keeping only top-k tokens 510 | mask = torch.ones((queries.size(0), queries.size(1), seq_len), dtype=torch.bool, device=device) 511 | top_indices = mean_attention_scores.topk(self.token_capacity, dim=-1).indices 512 | mask.scatter_(-1, top_indices.expand(-1, queries.size(1), -1), False) 513 | 514 | # Expand mask to proper attention shape 515 | mask = mask.unsqueeze(-2) 516 | 517 | return mask 518 | 519 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values, 520 | generation_queries, generation_keys, generation_values, *args, **kwargs): 521 | # Get initial mask for prefilling tokens 522 | current_mask = self.create_mask_for_prefill( 523 | prefilling_queries, 524 | prefilling_keys, 525 | prefilling_keys.size(-2), 526 | generation_queries.device 527 | ) 528 | 529 | # Initialize lists to store outputs and updated keys/values 530 | outputs = [] 531 | current_keys = prefilling_keys 532 | current_values = prefilling_values 533 | 534 | # Intialise final mask we'll output 535 | attention_mask = torch.ones(( 536 | generation_queries.size(0), 537 | generation_queries.size(1), 538 | generation_queries.size(2), 539 | prefilling_keys.size(-2) + generation_keys.size(-2), 540 | ), dtype=torch.bool, device=generation_queries.device) 541 | 542 | # Process generation tokens one by one 543 | for idx in range(generation_queries.size(-2)): 544 | # Get current generation token 545 | current_query = generation_queries[..., idx:idx+1, :] 546 | current_gen_key = generation_keys[..., idx:idx+1, :] 547 | current_gen_value = generation_values[..., idx:idx+1, :] 548 | 549 | # Extend keys and values 550 | current_keys = torch.cat([current_keys, current_gen_key], dim=-2) 551 | current_values = torch.cat([current_values, current_gen_value], dim=-2) 552 | 553 | # Extend mask for the new token (always attended) 554 | current_mask = torch.cat([ 
                current_mask,
                torch.zeros((current_query.size(0), current_query.size(1), 1, 1), dtype=torch.bool, device=current_query.device)
            ], dim=-1)

            attention_mask[..., idx:idx+1, :current_keys.size(-2)] = current_mask

            # Compute attention with scores
            output, attention_scores = self.attention(
                current_query,
                current_keys,
                current_values,
                current_mask,
                return_attention_scores=True
            )
            outputs.append(output)

            # Once the cache holds more than token_capacity keys, evict (mask) the one with the
            # lowest attention score, so at most token_capacity keys stay unmasked
            if current_keys.size(-2) > self.token_capacity:
                # Set scores to inf where tokens were already masked
                attention_scores = attention_scores.masked_fill(current_mask, float('inf'))

                # Average attention scores across heads
                mean_scores = attention_scores.mean(dim=1, keepdim=True)

                # Find token with lowest attention score
                min_indices = mean_scores.argmin(dim=-1, keepdim=True)
                min_indices = min_indices.expand(-1, current_query.size(1), -1, -1)

                # Update mask to exclude the lowest scoring token
                current_mask.scatter_(-1, min_indices, True)

        # Concatenate all outputs
        final_output = torch.cat(outputs, dim=-2)

        self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
        self.maybe_save_mask(attention_mask)

        return final_output

--------------------------------------------------------------------------------
/nano_sparse_attn/utils/__init__.py:
--------------------------------------------------------------------------------
from .constants import CONSTANTS
from .modelling import (
    load_model_and_tokenizer,
    load_examples,
    model_forward,
    update_attention,
)
from .plotting import (
    plot_sparse_attention_results,
    plot_prefill_masks,
    plot_generation_masks,
)

__all__ = [
    "CONSTANTS",
    "load_model_and_tokenizer",
    "load_examples",
    "model_forward",
    "update_attention",
    "plot_sparse_attention_results",
    "plot_prefill_masks",
    "plot_generation_masks",
]

--------------------------------------------------------------------------------
/nano_sparse_attn/utils/constants.py:
--------------------------------------------------------------------------------
import torch


CONSTANTS = {
    'model_name': 'unsloth/Llama-3.2-1B-Instruct',
    'dataset_name': 'FiscalNote/billsum',
    'runtime_args': {
        'dtype': torch.bfloat16,
        'num_examples': 3,
        'target_length_min': 4096,
        'target_length_max': 4096 + 512,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    },
    'hf_kwargs': {
        'trust_remote_code': True,
    },
    'save_results': True,
}

--------------------------------------------------------------------------------
/nano_sparse_attn/utils/modelling.py:
--------------------------------------------------------------------------------
from typing import Optional, Tuple

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import Cache, repeat_kv, apply_rotary_pos_emb

from .constants import CONSTANTS


def llama_attention_forward(
    self,
    hidden_states: torch.Tensor,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    batch_size, query_length, _ = hidden_states.size()
    num_heads, head_dim = self.num_heads, self.head_dim
    num_kv_heads, num_kv_groups = self.num_key_value_heads, self.num_key_value_groups

    queries = self.q_proj(hidden_states).view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
    keys = self.k_proj(hidden_states).view(batch_size, query_length, num_kv_heads, head_dim).transpose(1, 2)
    values = self.v_proj(hidden_states).view(batch_size, query_length, num_kv_heads, head_dim).transpose(1, 2)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(values, position_ids)
    else:
        cos, sin = position_embeddings

    queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        keys, values = past_key_value.update(keys, values, self.layer_idx, cache_kwargs)

    keys = repeat_kv(keys, num_kv_groups)
    values = repeat_kv(values, num_kv_groups)

    attention_output = self.attn(queries, keys, values, layer_idx=self.layer_idx)

    attention_output = attention_output.transpose(1, 2).reshape(batch_size, query_length, -1).contiguous()
    attention_output = self.o_proj(attention_output)

    return attention_output, None, past_key_value


def update_attention(model, inference_handler):
    def update_attention_recursive(module):
        if isinstance(module, LlamaAttention):
            module.forward = llama_attention_forward.__get__(module, LlamaAttention)
            module.attn = inference_handler

    model.apply(update_attention_recursive)


def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        CONSTANTS['model_name'],
        **CONSTANTS['hf_kwargs'],
    )

    model = AutoModelForCausalLM.from_pretrained(
        CONSTANTS['model_name'],
        torch_dtype=CONSTANTS['runtime_args']['dtype'],
        attn_implementation="eager",
        **CONSTANTS['hf_kwargs'],
    ).to(CONSTANTS['runtime_args']['device'])

    return model, tokenizer


def load_examples(tokenizer, target_length_min=3500, target_length_max=4096, num_examples=1):
    dataset = load_dataset(
        CONSTANTS['dataset_name'],
        split="train",
        **CONSTANTS['hf_kwargs'],
    )

    examples = []
    for example in dataset:
        # Separate input prompt and output
        input_prompt = f"""Below is a US Congressional and California state bill. Please provide a concise summary of the bill.

Bill:
{example['text']}"""

        output_text = example['summary']

        # Apply chat template to input
        templated_input = tokenizer.apply_chat_template(
            [{"role": "user", "content": input_prompt}],
            tokenize=False,
            add_generation_prompt=True
        )

        # Create full sequence with answer
        full_sequence = templated_input + output_text

        # Check total length
        tokens = tokenizer(
            full_sequence,
            return_tensors="pt",
            add_special_tokens=False,
        )

        total_length = tokens['input_ids'].shape[1]
        if target_length_min <= total_length <= target_length_max:
            # Get output length for loss calculation
            output_length = len(tokenizer(output_text, add_special_tokens=False)['input_ids'])

            # Move the tokenized example to the target device
            model_inputs = tokens.to(CONSTANTS['runtime_args']['device'])

            examples.append({
                'input_ids': model_inputs['input_ids'],
                'attention_mask': model_inputs['attention_mask'],
                'output_length': output_length,
            })

        if len(examples) >= num_examples:
            break

    return examples


def model_forward(model, model_inputs, inference_handler):
    total_loss = 0

    for example in model_inputs:
        # Set output length for current example
        inference_handler.set_output_length(example['output_length'])

        with torch.no_grad():
            outputs = model(
                input_ids=example['input_ids'],
                attention_mask=example['attention_mask'],
            )

        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = example['input_ids'][..., 1:]

        # Calculate loss only over the output length
        output_length = example['output_length']
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1))[-output_length:],
            shift_labels.view(-1)[-output_length:]
        )
        total_loss += loss.item()

    # Return average loss across examples
    return total_loss / len(model_inputs)

--------------------------------------------------------------------------------
/nano_sparse_attn/utils/plotting.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np

# Uncomment to raise the DPI, which is important to see individual columns in a mask
# plt.rcParams['figure.dpi'] = 600


def plot_prefill_masks(mask1, mask2, title):
    """
    Plot two prefill attention masks side by side with custom styling.

    Args:
        mask1: 2D boolean numpy array or tensor for left plot
        mask2: 2D boolean numpy array or tensor for right plot
        title: String used as the base title for both plots
    """
    # Create figure with gray background
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.patch.set_facecolor('#3d3d3d')

    # Custom colormap (purple: masked, yellow: attended)
    custom_cmap = plt.cm.colors.ListedColormap(['#2E1A47', '#FFE135'])

    # Plot left mask
    ax1.imshow(~mask1, cmap=custom_cmap, aspect='equal')
    ax1.axis('off')
    ax1.set_title(f'{title} Mask #1', pad=5, color='white')
    ax1.text(-0.05, 0.5, 'Queries', rotation=90,
             transform=ax1.transAxes, va='center', color='white')
    ax1.text(0.5, -0.05, 'Keys', transform=ax1.transAxes, ha='center', color='white')
    ax1.set_facecolor('#3d3d3d')

    # Plot right mask
    ax2.imshow(~mask2, cmap=custom_cmap, aspect='equal')
    ax2.axis('off')
    ax2.set_title(f'{title} Mask #2', pad=5, color='white')
    ax2.text(-0.05, 0.5, 'Queries', rotation=90,
             transform=ax2.transAxes, va='center', color='white')
    ax2.text(0.5, -0.05, 'Keys', transform=ax2.transAxes, ha='center', color='white')
    ax2.set_facecolor('#3d3d3d')

    plt.tight_layout()


def plot_generation_masks(mask1, mask2, title, mult):
    """
    Plot two generation attention masks stacked vertically with custom styling.

    Args:
        mask1: 2D boolean numpy array or CPU tensor for top plot (N queries × M keys)
        mask2: 2D boolean numpy array or CPU tensor for bottom plot (N queries × M keys)
        title: String to display as base title
        mult: Figure height per query row (in inches), capped at 10 in total
    """
    # Create figure with gray background
    n_queries = mask1.shape[0]
    height = min(n_queries * mult, 10)  # Scale height with the number of queries, capped at 10
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, height))
    fig.patch.set_facecolor('#3d3d3d')

    # Custom colormap (purple: masked, yellow: attended prefill key, red: attended generation key)
    colors = ['#2E1A47', '#FFE135', '#E74C3C']  # purple, yellow, red
    custom_cmap = plt.cm.colors.ListedColormap(colors)

    for mask, ax, mask_num in [(mask1, ax1, 1), (mask2, ax2, 2)]:
        # Accept CPU tensors as well as numpy arrays (tensors have no .astype)
        mask = np.asarray(mask, dtype=bool)

        # Create visualization array (0: masked, 1: attended prefill key, 2: attended generation key)
        n_queries, n_keys = mask.shape
        n_gen_keys = n_queries  # Last N keys are generation keys

        viz_array = (~mask).astype(int)  # Convert to int (False -> 1, True -> 0)

        # Mark attended generation keys with a different color (2)
        gen_area = viz_array[:, -n_gen_keys:]
        viz_array[:, -n_gen_keys:] = np.where(gen_area == 1, 2, 0)

        # Plot mask
        ax.imshow(viz_array, cmap=custom_cmap, aspect='auto')
        ax.axis('off')
        ax.set_title(f'{title} Mask #{mask_num}', pad=5, color='white')
        ax.text(-0.02, 0.5, 'Queries', rotation=90,
                transform=ax.transAxes, va='center', color='white')
        ax.text(0.5, -0.3, 'Keys', transform=ax.transAxes, ha='center', color='white')
        ax.set_facecolor('#3d3d3d')

    plt.tight_layout()


def plot_sparse_attention_results(results):
    # Separate Dense results and get baseline
    dense_results = [r for r in results if 'Dense' in r['name']]
    sparse_results = [r for r in results if 'Dense' not in r['name']]

    # Calculate dense baseline
    dense_baseline = np.mean([r['loss'] for r in dense_results])

    # Group sparse results by method
    method_groups = {}
    for result in sparse_results:
        # Extract method name (remove 'Attention' suffix)
        method = result['name'].replace('Attention', '')

        if method not in method_groups:
            method_groups[method] = {'sparsity': [], 'loss': []}

        method_groups[method]['sparsity'].append(result['sparsity'])
        method_groups[method]['loss'].append(result['loss'])

    # Plot settings
    plt.figure(figsize=(10, 6))
    color_palette = ['#2ecc71', '#e74c3c', '#9b59b6', '#3498db', '#f1c40f']
    marker_palette = ['o', 's', 'D', '^', 'v']

    # Plot each method
    for i, (method, data) in enumerate(method_groups.items()):
        color_idx = i % len(color_palette)
        marker_idx = i % len(marker_palette)

        plt.plot(data['sparsity'], data['loss'],
                 label=method,
                 color=color_palette[color_idx],
                 marker=marker_palette[marker_idx],
                 linewidth=2,
                 markersize=8)

    # Add dense baseline
    plt.axhline(y=dense_baseline, color='#95a5a6', linestyle='--', label='Dense')

    # Customize plot
    plt.xlabel('Sparsity Ratio', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title('Comparison of Sparse Attention Methods', fontsize=14, pad=15)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=10)

    # Set axis limits from the data with 10% padding
    losses = [loss for data in method_groups.values() for loss in data['loss']] + [dense_baseline]
    sparsity = [s for data in method_groups.values() for s in data['sparsity']]

    y_min, y_max = min(losses), max(losses)
    y_padding = (y_max - y_min) * 0.1
    plt.ylim(y_min - y_padding, y_max + y_padding)

    x_min, x_max = min(sparsity), max(sparsity)
    x_padding = (x_max - x_min) * 0.1
    plt.xlim(x_min - x_padding, x_max + x_padding)

    # Finalize layout (saving the figure is left to the caller)
    plt.tight_layout()

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name="nano_sparse_attn",
    version="0.99",
    author="Piotr Nawrot",
    author_email="piotr@nawrot.org",
    description="nanoSparseAttention: PyTorch implementation of novel sparse attention mechanisms",
    url="https://github.com/PiotrNawrot/nanoSparseAttention",
    packages=find_packages(include=['nano_sparse_attn', 'nano_sparse_attn.*']),
    install_requires=[
        "torch",
        "datasets",
        "transformers",
        "matplotlib"
    ],
    python_requires=">=3.10",
)
--------------------------------------------------------------------------------
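The index bookkeeping in `VerticalAndSlashAttention.create_mask` is easy to get wrong. It covers the diagonal sums from `sum_over_diagonals`, the `- (seq_len - 1)` shift that turns indices into offsets, and the vertical-line scatter. The pure-Python sketch below re-derives the same selection on a toy input so it can be checked by hand.

It assumes the conventions used in the code above:
- `True` in a mask means masked.
- A slash offset is key index minus query index, so it is always ≤ 0.
- `scores` holds already-normalised rows for the last `approximation_window` queries.

The helper names (`top_k_indices`, `diagonal_sums`, `vertical_slash_reference`) are hypothetical and not part of `nano_sparse_attn`. This is a slow reference for reading, not a replacement for the tensor implementation.

```python
import math
from typing import List, Sequence, Tuple


def top_k_indices(values: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest values; ties keep the lower index first."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


def diagonal_sums(scores: Sequence[Sequence[float]], seq_len: int) -> List[float]:
    """Sum scores along every causal diagonal.

    scores[i][k] is the score of the i-th of the last len(scores) queries against
    key k. Entry m of the result belongs to offset m - (seq_len - 1), where
    offset = key index - query index, so entry seq_len - 1 is the main diagonal.
    """
    window = len(scores)
    sums = [0.0] * seq_len
    for i, row in enumerate(scores):
        query = seq_len - window + i  # global position of this query
        for key in range(query + 1):  # causal keys only
            sums[key - query + seq_len - 1] += row[key]
    return sums


def vertical_slash_reference(
    scores: Sequence[Sequence[float]],
    seq_len: int,
    top_tokens: int,
    top_slashes: int,
    window_size: int,
    attention_sinks: int,
) -> Tuple[List[int], List[int], List[List[bool]]]:
    """Return (vertical keys, slash offsets, mask) where mask[q][k] is True when masked."""
    # Vertical lines: total score per key, with the sink tokens always selected.
    column_totals = [sum(row[k] for row in scores) for k in range(seq_len)]
    for k in range(attention_sinks):
        column_totals[k] = math.inf
    verticals = top_k_indices(column_totals, top_tokens)

    # Slashes: total score per diagonal, with the local window always selected.
    diagonals = diagonal_sums(scores, seq_len)
    for m in range(seq_len - window_size, seq_len):
        diagonals[m] = math.inf
    offsets = [m - (seq_len - 1) for m in top_k_indices(diagonals, top_slashes)]

    # Start fully masked, then unmask slash and vertical entries on or below the diagonal.
    mask = [[True] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        for offset in offsets:
            if q + offset >= 0:
                mask[q][q + offset] = False
        for k in verticals:
            if k <= q:
                mask[q][k] = False
    return verticals, offsets, mask


# Toy example: seq_len = 6, scores for the last two queries (positions 4 and 5).
scores = [
    [0.5, 0.0, 0.25, 0.0, 0.25, 0.0],   # query 4
    [0.25, 0.0, 0.0, 0.5, 0.0, 0.25],   # query 5
]
verticals, offsets, mask = vertical_slash_reference(
    scores, seq_len=6, top_tokens=2, top_slashes=2, window_size=1, attention_sinks=1
)
print(sorted(verticals), sorted(offsets))  # prints [0, 3] [-2, 0]
```

Ties in `top_k_indices` are broken towards the lower index, which may differ from `torch.topk`. The toy scores are chosen so that no tie sits at a selection boundary. In the result, key 0 is selected as a sink and key 3 as a vertical line. Offset 0 is forced in as the local window, and offset -2 is the strongest diagonal (0.25 + 0.5).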