├── .gitignore
├── README.md
├── assets
│   ├── logo.png
│   └── pipfreeze.txt
├── main_generate.py
├── main_prefill.py
├── nano_sparse_attn
│   ├── README.md
│   ├── __init__.py
│   ├── attention
│   │   ├── __init__.py
│   │   ├── abstract.py
│   │   ├── inference_handler.py
│   │   └── sparse_attention.py
│   └── utils
│       ├── __init__.py
│       ├── constants.py
│       ├── modelling.py
│       └── plotting.py
├── notebooks
│   └── tutorial.ipynb
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | results/
3 | *.pyc
4 | *.egg-info/
5 | *.egg
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # nanoSparseAttention
2 |
3 |
4 |
5 |
6 |
7 | ## Overview
8 |
9 | nanoSparseAttention provides clean, educational implementations of recent Sparse Attention mechanisms for both prefilling and generation stages of LLM inference. The repository prioritizes clarity and understanding over performance, making it ideal for learning and experimentation.
10 |
11 | We implemented a [Jupyter notebook](./notebooks/tutorial.ipynb) that provides:
12 | 1. Detailed explanation of Sparse Attention concepts
13 | 2. Step-by-step implementation walkthrough
14 | 3. Visualization of attention patterns
15 | 4. Performance comparisons between different methods
16 |
17 | The notebook was prepared for the [NeurIPS 2024 Dynamic Sparsity Workshop](https://dynamic-sparsity.github.io/) - check it out if you want to learn more about dynamic execution, not only in the context of self-attention!
18 |
19 | ### Key Features
20 |
21 | - **Pure PyTorch Implementation**: All attention mechanisms are implemented in pure PyTorch for maximum clarity and ease of understanding.
22 | - **Real-world Testing**: Uses Llama-3.2-1B-Instruct model and FiscalNote/billsum dataset for practical experiments.
23 | - **Comprehensive Tutorial**: Includes a detailed Jupyter notebook explaining core concepts and implementations.
24 | - **Extensible Design**: Easy to add new models, datasets, and attention patterns through modular architecture.
25 | - **Flexible Inference**: Supports both prefilling and generation stages, with the ability to use a different attention pattern for each stage within a single forward pass.
26 |
27 | ### Implemented Methods
28 |
29 | #### Prefilling Stage
30 | - **Local Window + Attention Sinks** [(Xiao et al, 2023)](https://arxiv.org/abs/2309.17453), [(Han et al, 2024)](https://arxiv.org/abs/2308.16137)
31 | - **Vertical-Slash Attention** [(Jiang et al, 2024)](https://arxiv.org/abs/2407.02490)
32 | - **Block-Sparse Attention** [(Jiang et al, 2024)](https://arxiv.org/abs/2407.02490)
33 |
34 | #### Generation Stage
35 | - **Local Window + Attention Sinks** [(Xiao et al, 2023)](https://arxiv.org/abs/2309.17453), [(Han et al, 2024)](https://arxiv.org/abs/2308.16137)
36 | - **SnapKV** [(Li et al, 2024)](https://arxiv.org/abs/2404.14469)
37 | - **TOVA** [(Oren et al, 2023)](https://arxiv.org/abs/2401.06104)
38 |
39 | ## Installation
40 |
41 | Assuming you want to use a Python venv, installation is as easy as:
42 |
43 | ```bash
44 | git clone https://github.com/PiotrNawrot/nano-sparse-attention
45 | cd nano-sparse-attention
46 | python3 -m venv .venv
47 | source .venv/bin/activate
48 | pip install --upgrade pip setuptools wheel psutil
49 | pip install -e ./
50 | ```
51 |
52 | ## Example Usage
53 |
54 | The repository provides two main scripts for experimenting with sparse attention mechanisms:
55 |
56 | ### Prefilling Stage
57 |
58 | ```python
59 | from nano_sparse_attn.attention import InferenceHandler, DenseAttention, LocalAndSinksAttention
60 | from nano_sparse_attn.utils import load_model_and_tokenizer, load_examples, update_attention, model_forward
61 |
62 | # Load model and prepare inputs
63 | model, tokenizer = load_model_and_tokenizer()
64 | model_inputs = load_examples(tokenizer, num_examples=1)
65 |
66 | # Create an inference handler with Local Window + Attention Sinks
67 | handler = InferenceHandler(
68 | prefill_attention=LocalAndSinksAttention(
69 | window_size=256,
70 | attention_sinks=16
71 | ),
72 | generation_attention=DenseAttention()
73 | )
74 |
75 | # Update model's attention mechanism and run forward pass
76 | update_attention(model, handler)
77 | loss = model_forward(model, model_inputs, handler)
78 |
79 | # Get information about the attention mechanism
80 | info = handler.info()
81 | print(f"Loss: {loss}")
82 | print(f"Sparsity: {info['prefill']['sparsity']}")
83 | ```
84 |
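The content-dependent prefilling patterns follow the same interface. For example, Vertical-Slash attention can be swapped in as shown below (mirroring the sweep in `main_prefill.py`; the hyperparameter values here are illustrative, and `BlockSparseAttention` is used analogously):

```python
# Assumes imports and model/inputs from the previous example
from nano_sparse_attn.attention import VerticalAndSlashAttention

# Vertical-Slash attention for prefilling, dense attention for generation
handler = InferenceHandler(
    prefill_attention=VerticalAndSlashAttention(
        attention_sinks=16,
        window_size=128,
        top_tokens=256,
        top_slashes=256,
        approximation_window=64,
    ),
    generation_attention=DenseAttention(),
)

update_attention(model, handler)
loss = model_forward(model, model_inputs, handler)
print(f"Sparsity: {handler.info()['prefill']['sparsity']}")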
85 | ### Generation Stage
86 |
87 | ```python
88 | # Assumes imports from the previous example
89 | from nano_sparse_attn.attention import SnapKVAttention
90 |
91 | # Create an inference handler with SnapKV for generation
92 | handler = InferenceHandler(
93 | prefill_attention=DenseAttention(),
94 | generation_attention=SnapKVAttention(
95 | approximation_window=64,
96 | token_capacity=256
97 | )
98 | )
99 |
100 | # Update model's attention mechanism and run forward pass
101 | update_attention(model, handler)
102 | loss = model_forward(model, model_inputs, handler)
103 |
104 | # Get information about the attention mechanism
105 | info = handler.info()
106 | print(f"Loss: {loss}")
107 | print(f"Sparsity: {info['generation']['sparsity']}")
108 | ```
109 |
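TOVA is configured analogously; only the generation attention module changes (the capacity value below is illustrative):

```python
# Assumes imports from the previous examples
from nano_sparse_attn.attention import TOVAAttention

handler = InferenceHandler(
    prefill_attention=DenseAttention(),
    generation_attention=TOVAAttention(token_capacity=256),
)
```
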
110 | For ready-to-use scripts check out [main_prefill.py](./main_prefill.py) and [main_generate.py](./main_generate.py).
111 | For a detailed walkthrough of the repository and information about extending it to new models, datasets, and attention patterns, refer to [this README](./nano_sparse_attn/README.md).
112 |
113 | ## Contributing
114 |
115 | Contributions are welcome! Our goal is to keep this repository up to date with the latest Sparse Attention methods. Feel free to submit a Pull Request if 1) you want a new method to be added or 2) [even better] you have an implementation of a new Sparse Attention method!
116 |
117 | ## Authors
118 |
119 | Piotr Nawrot - [Website](https://piotrnawrot.github.io/) - piotr@nawrot.org
120 |
121 | Edoardo Maria Ponti - [Website](https://ducdauge.github.io/) - eponti@ed.ac.uk
122 |
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/nano-sparse-attention/24e0b5844024a6d17dd899af8f17a0c4042f31d0/assets/logo.png
--------------------------------------------------------------------------------
/assets/pipfreeze.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.4.3
2 | aiohttp==3.10.10
3 | aiosignal==1.3.1
4 | async-timeout==4.0.3
5 | attrs==24.2.0
6 | certifi==2024.8.30
7 | charset-normalizer==3.4.0
8 | datasets==3.1.0
9 | dill==0.3.8
10 | filelock==3.16.1
11 | frozenlist==1.5.0
12 | fsspec==2024.9.0
13 | huggingface-hub==0.26.2
14 | idna==3.10
15 | Jinja2==3.1.4
16 | MarkupSafe==3.0.2
17 | mpmath==1.3.0
18 | multidict==6.1.0
19 | multiprocess==0.70.16
20 | networkx==3.4.2
21 | numpy==2.1.3
22 | nvidia-cublas-cu12==12.4.5.8
23 | nvidia-cuda-cupti-cu12==12.4.127
24 | nvidia-cuda-nvrtc-cu12==12.4.127
25 | nvidia-cuda-runtime-cu12==12.4.127
26 | nvidia-cudnn-cu12==9.1.0.70
27 | nvidia-cufft-cu12==11.2.1.3
28 | nvidia-curand-cu12==10.3.5.147
29 | nvidia-cusolver-cu12==11.6.1.9
30 | nvidia-cusparse-cu12==12.3.1.170
31 | nvidia-nccl-cu12==2.21.5
32 | nvidia-nvjitlink-cu12==12.4.127
33 | nvidia-nvtx-cu12==12.4.127
34 | packaging==24.1
35 | pandas==2.2.3
36 | propcache==0.2.0
37 | psutil==6.1.0
38 | pyarrow==18.0.0
39 | python-dateutil==2.9.0.post0
40 | pytz==2024.2
41 | PyYAML==6.0.2
42 | regex==2024.9.11
43 | requests==2.32.3
44 | safetensors==0.4.5
45 | six==1.16.0
46 | sympy==1.13.1
47 | tokenizers==0.20.1
48 | torch==2.5.1
49 | tqdm==4.66.6
50 | transformers==4.46.1
51 | triton==3.1.0
52 | typing_extensions==4.12.2
53 | tzdata==2024.2
54 | urllib3==2.2.3
55 | xxhash==3.5.0
56 | yarl==1.17.1
57 |
--------------------------------------------------------------------------------
/main_generate.py:
--------------------------------------------------------------------------------
1 | from nano_sparse_attn.utils import (
2 | load_model_and_tokenizer,
3 | load_examples,
4 | model_forward,
5 | update_attention,
6 | CONSTANTS,
7 | )
8 |
9 | from nano_sparse_attn.attention import (
10 | InferenceHandler,
11 | DenseAttention,
12 | LocalAndSinksAttention,
13 | SnapKVAttention,
14 | TOVAAttention,
15 | )
16 |
17 |
18 | model, tokenizer = load_model_and_tokenizer()
19 | model_inputs = load_examples(
20 | tokenizer,
21 | target_length_min=CONSTANTS['runtime_args']['target_length_min'],
22 | target_length_max=CONSTANTS['runtime_args']['target_length_max'],
23 | num_examples=CONSTANTS['runtime_args']['num_examples'],
24 | )
25 |
26 | inference_handlers = []
27 |
28 | # Dense baseline
29 | inference_handlers.append(
30 | InferenceHandler(
31 | prefill_attention=DenseAttention(),
32 | generation_attention=DenseAttention(),
33 | )
34 | )
35 |
36 | # Local and Sinks Generation
37 | for window_size in range(128, 1024+1, 128):
38 | inference_handlers.append(
39 | InferenceHandler(
40 | prefill_attention=DenseAttention(),
41 | generation_attention=LocalAndSinksAttention(
42 | window_size=window_size,
43 | attention_sinks=16,
44 | ),
45 | )
46 | )
47 |
48 | # SnapKV Generation
49 | for token_capacity in range(128, 1024+1, 128):
50 | inference_handlers.append(
51 | InferenceHandler(
52 | prefill_attention=DenseAttention(),
53 | generation_attention=SnapKVAttention(
54 | approximation_window=64,
55 | token_capacity=token_capacity,
56 | ),
57 | )
58 | )
59 |
60 | # TOVA Generation
61 | for token_capacity in range(128, 1024+1, 128):
62 | inference_handlers.append(
63 | InferenceHandler(
64 | prefill_attention=DenseAttention(),
65 | generation_attention=TOVAAttention(token_capacity=token_capacity),
66 | )
67 | )
68 |
69 | for idx, inference_handler in enumerate(inference_handlers):
70 | update_attention(model, inference_handler)
71 | loss = model_forward(model, model_inputs, inference_handler)
72 | info_dict = inference_handler.info()
73 | print(f"InferenceHandler {idx}: \n-- Name: {info_dict['generation']['name']}\n-- Loss: {loss}\n-- Sparsity: {info_dict['generation']['sparsity']}\n-- Params: {info_dict['generation']['params']}")
74 |
--------------------------------------------------------------------------------
/main_prefill.py:
--------------------------------------------------------------------------------
1 | from nano_sparse_attn.utils import (
2 | load_model_and_tokenizer,
3 | load_examples,
4 | model_forward,
5 | update_attention,
6 | CONSTANTS,
7 | )
8 |
9 | from nano_sparse_attn.attention import (
10 | InferenceHandler,
11 | DenseAttention,
12 | LocalAndSinksAttention,
13 | VerticalAndSlashAttention,
14 | BlockSparseAttention,
15 | )
16 |
17 |
18 | model, tokenizer = load_model_and_tokenizer()
19 | model_inputs = load_examples(
20 | tokenizer,
21 | target_length_min=CONSTANTS['runtime_args']['target_length_min'],
22 | target_length_max=CONSTANTS['runtime_args']['target_length_max'],
23 | num_examples=CONSTANTS['runtime_args']['num_examples'],
24 | )
25 |
26 | inference_handlers = []
27 |
28 | # Dense baseline
29 | inference_handlers.append(
30 | InferenceHandler(
31 | prefill_attention=DenseAttention(),
32 | generation_attention=DenseAttention(),
33 | )
34 | )
35 |
36 | for window_size in range(128, 1024+1, 128):
37 | inference_handlers.append(
38 | InferenceHandler(
39 | prefill_attention=LocalAndSinksAttention(
40 | window_size=window_size,
41 | attention_sinks=16,
42 | ),
43 | generation_attention=DenseAttention(),
44 | )
45 | )
46 |
47 | for top_k in range(128, 1024+1, 128):
48 | inference_handlers.append(
49 | InferenceHandler(
50 | prefill_attention=VerticalAndSlashAttention(
51 | top_tokens=top_k,
52 | top_slashes=top_k,
53 | window_size=128,
54 | approximation_window=64,
55 | attention_sinks=16,
56 | ),
57 | generation_attention=DenseAttention(),
58 | )
59 | )
60 |
61 | for top_chunks in range(2, 16+1, 2):
62 | inference_handlers.append(
63 | InferenceHandler(
64 | prefill_attention=BlockSparseAttention(
65 | chunk_size=64,
66 | top_chunks=top_chunks,
67 | ),
68 | generation_attention=DenseAttention(),
69 | )
70 | )
71 |
72 | for idx, inference_handler in enumerate(inference_handlers):
73 | update_attention(model, inference_handler)
74 | loss = model_forward(model, model_inputs, inference_handler)
75 | info_dict = inference_handler.info()
76 | print(f"InferenceHandler {idx}: \n-- Name: {info_dict['prefill']['name']}\n-- Loss: {loss}\n-- Sparsity: {info_dict['prefill']['sparsity']}\n-- Params: {info_dict['prefill']['params']}")
77 |
--------------------------------------------------------------------------------
/nano_sparse_attn/README.md:
--------------------------------------------------------------------------------
1 | ## Technical Details & Extension
2 |
3 | ### Inference Handler & Attention Mechanism
4 |
5 | The core of nanoSparseAttention is built around two key components:
6 |
7 | 1. **InferenceHandler** (`attention/inference_handler.py`): Manages the switching between prefilling and generation attention patterns. It allows you to:
8 | - Use different attention patterns for prefilling and generation stages
9 | - Track sparsity ratios and attention masks for visualization
10 |
11 | 2. **Base Attention** (`attention/abstract.py`): Provides the foundation for implementing sparse attention patterns with:
12 |    - Common utility functions (`get_causal_mask`, `get_local_mask`, etc.), combined in the sketch below
13 | - Sparsity calculation and mask tracking
14 | - Standard attention computation methods
15 |
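For instance, the mask behind `LocalAndSinksAttention` can be rebuilt directly from these helpers (a minimal, self-contained sketch; the sizes below are illustrative):

```python
from nano_sparse_attn.attention import Attention

# Local window + attention sinks mask, built from the base class helpers.
seq_len, window_size, sinks = 1024, 128, 16

mask = Attention.get_local_mask(seq_len, window_size, device='cpu')
mask[..., :, :sinks] = False  # every token may attend to the sink tokens
mask = mask | Attention.get_causal_mask(seq_len, device='cpu')

# Fraction of causal connections that are masked out (0 = dense, 1 = fully sparse).
print(Attention.calculate_sparsity_ratio(mask))
```
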
16 | ### Creating New Sparse Patterns
17 |
18 | To implement a new sparse attention pattern:
19 |
20 | 1. Inherit from the base `Attention` class
21 | 2. Implement either/both:
22 | - `forward()` for prefilling attention
23 | - `generation_forward()` for generation attention
24 | 3. Create attention masks that specify which connections to keep/prune
25 |
26 | Example skeleton:
27 | ```python
28 | class MyNewAttention(Attention):
29 |     def __init__(self, window_size):
30 |         super().__init__()
31 | self.name = 'MyNewAttention'
32 | self.params = {"window_size": window_size}
33 |
34 |     def forward(self, queries, keys, values, *args, **kwargs):
35 | # Create your attention mask
36 | attention_mask = ...
37 |
38 | # Track sparsity
39 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
40 |
41 | # Compute attention
42 | return self.attention(queries, keys, values, attention_mask)
43 | ```
44 |
45 | ### Supporting New Models
46 |
47 | The repository currently supports only Llama-style attention mechanisms. To add support for a new model:
48 |
49 | 1. Locate the attention implementation in HuggingFace's transformers library (e.g., `modeling_llama.py` for Llama)
50 | 2. Copy and adapt the attention's forward pass to use our inference handler
51 | 3. Ensure correct tensor shapes for queries, keys, and values
52 |
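As a sketch, the Llama-specific `update_attention` in `nano_sparse_attn/utils/modelling.py` generalizes along these lines (`update_attention_for` is a hypothetical helper, not part of the repository):

```python
def update_attention_for(model, attention_cls, adapted_forward, inference_handler):
    """Route every `attention_cls` module in `model` through `inference_handler`.

    `attention_cls` is the model's HuggingFace attention class (step 1) and
    `adapted_forward` is its forward pass rewritten to call
    `self.attn(queries, keys, values, layer_idx=...)` (steps 2-3), mirroring
    `llama_attention_forward` in utils/modelling.py.
    """
    def _patch(module):
        if isinstance(module, attention_cls):
            module.forward = adapted_forward.__get__(module, attention_cls)
            module.attn = inference_handler

    model.apply(_patch)
```
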
53 | ### Using Custom Datasets
54 |
55 | The repository can be adapted to work with any dataset. The key requirements are:
56 |
57 | 1. Implement a data loading function that returns examples in the format:
58 | ```python
59 | {
60 | 'input_ids': tensor, # Input token ids
61 | 'attention_mask': tensor, # Attention mask
62 | 'output_length': int, # Length of the output sequence
63 | }
64 | ```
65 |
66 | 2. If using instruction-tuned models (like we do), format your prompts appropriately. Our implementation uses chat templates for Llama-style models, but this is configurable.
67 |
68 | Example adaptation:
69 | ```python
70 | def load_my_dataset(tokenizer, **kwargs):
71 | # Load your data
72 | dataset = ...
73 | examples = []
74 |
75 | for item in dataset:
76 | # Format your prompt
77 | input_text = f"Your prompt: {item['input']}"
78 |
79 | # Tokenize (with or without chat template)
80 | tokens = tokenizer(
81 | input_text,
82 | return_tensors="pt",
83 | )
84 |
85 | # Get output length for loss calculation
86 | output_length = ...
87 | examples.append({
88 | 'input_ids': tokens['input_ids'],
89 | 'attention_mask': tokens['attention_mask'],
90 | 'output_length': output_length,
91 | })
92 |
93 | return examples
94 | ```
95 |
--------------------------------------------------------------------------------
/nano_sparse_attn/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention.abstract import Attention
2 | from .attention.sparse_attention import (
3 | DenseAttention,
4 | LocalAndSinksAttention,
5 | SnapKVAttention,
6 | TOVAAttention,
7 | )
8 | from .attention.inference_handler import InferenceHandler
9 |
10 | __all__ = [
11 | "Attention",
12 | "DenseAttention",
13 | "LocalAndSinksAttention",
14 | "SnapKVAttention",
15 | "TOVAAttention",
16 | "InferenceHandler",
17 | ]
18 |
--------------------------------------------------------------------------------
/nano_sparse_attn/attention/__init__.py:
--------------------------------------------------------------------------------
1 | from .abstract import Attention
2 | from .sparse_attention import (
3 | DenseAttention,
4 | LocalAndSinksAttention,
5 | VerticalAndSlashAttention,
6 | BlockSparseAttention,
7 | SnapKVAttention,
8 | TOVAAttention,
9 | )
10 | from .inference_handler import InferenceHandler
11 |
12 | __all__ = [
13 | "Attention",
14 | "DenseAttention",
15 | "LocalAndSinksAttention",
16 | "VerticalAndSlashAttention",
17 | "BlockSparseAttention",
18 | "SnapKVAttention",
19 | "TOVAAttention",
20 | "InferenceHandler",
21 | ]
22 |
--------------------------------------------------------------------------------
/nano_sparse_attn/attention/abstract.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class Attention(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.name = 'Attention'
10 | self.masks = []
11 | self.sparsity_ratios = []
12 | self.params = {}
13 | self.layer_counter = 0
14 |
15 | def info(self):
16 | return {
17 | "name": self.name,
18 | "params": self.params,
19 | "masks": self.masks,
20 | "sparsity": self.calculate_average_sparsity_ratio(self.sparsity_ratios)
21 | }
22 |
23 | def __getattr__(self, name):
24 | if name in self.params:
25 | return self.params[name]
26 | raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
27 |
28 | def maybe_save_mask(self, attention_mask, modulo_layers=8):
29 | if self.layer_counter % modulo_layers == 0:
30 | self.masks.append(attention_mask[0, 0].clone().cpu().numpy())
31 |
32 | self.layer_counter += 1
33 |
34 | def forward(self, queries, keys, values, *args, **kwargs):
35 | """Base forward method for prefilling attention patterns.
36 |
37 | Args:
38 | queries: (batch_size, num_heads, seq_len, head_dim)
39 | keys: (batch_size, num_heads, seq_len, head_dim)
40 | values: (batch_size, num_heads, seq_len, head_dim)
41 |
42 | Returns:
43 | attention_output: (batch_size, num_heads, seq_len, head_dim)
44 | """
45 | raise NotImplementedError
46 |
47 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values,
48 | generation_queries, generation_keys, generation_values, *args, **kwargs):
49 | """Base forward method for generation attention patterns.
50 |
51 | Args:
52 | prefilling_queries: Queries from prefilling stage
53 | prefilling_keys: Keys from prefilling stage
54 | prefilling_values: Values from prefilling stage
55 | generation_queries: New queries for generation
56 | generation_keys: New keys for generation
57 | generation_values: New values for generation
58 |
59 | Returns:
60 | attention_output: Output for generation tokens
61 | """
62 | raise NotImplementedError
63 |
64 | @staticmethod
65 | def attention(queries, keys, values, attention_mask, return_attention_scores=False):
66 | """Standard attention computation."""
67 | attention_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(queries.size(-1))
68 | attention_weights += attention_mask.to(queries.dtype) * torch.finfo(queries.dtype).min
69 | attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
70 | attention_output = torch.matmul(attention_weights, values)
71 | if return_attention_scores:
72 | return attention_output, attention_weights
73 | return attention_output
74 |
75 | @staticmethod
76 | def get_causal_mask(seq_len, device):
77 | """Creates a causal mask where future tokens cannot attend to past tokens.
78 |
79 | Args:
80 | seq_len: Length of the sequence
81 | device: Device to create the mask on
82 |
83 | Returns:
84 | mask: Boolean tensor of shape (1, 1, seq_len, seq_len) where True/1 indicates
85 | that position (i,j) should be masked (set to -inf before softmax)
86 | """
87 | mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
88 | mask = torch.triu(mask, diagonal=1)
89 | return mask.unsqueeze(0).unsqueeze(0)
90 |
91 | @staticmethod
92 | def get_local_mask(seq_len, window_size, device):
93 | """Creates a local attention mask where tokens can only attend to nearby tokens
94 | within a fixed window size, plus the causal constraint.
95 |
96 | Args:
97 | seq_len: Length of the sequence
98 | window_size: Size of the local attention window including current token
99 | device: Device to create the mask on
100 |
101 | Returns:
102 | mask: Boolean tensor of shape (1, 1, seq_len, seq_len) where True/1 indicates
103 | that position (i,j) should be masked (set to -inf before softmax)
104 | """
105 | mask = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
106 | mask = torch.triu(mask, diagonal=-(window_size-1))
107 | mask = torch.tril(mask, diagonal=0)
108 | mask = (~mask)
109 | return mask.unsqueeze(0).unsqueeze(0)
110 |
111 | @staticmethod
112 | def get_generation_mask(gen_len, prefill_len, device):
113 | """Creates a mask that allows generation tokens to:
114 | 1. Attend to all prefilling tokens
115 | 2. Attend causally to other generation tokens
116 |
117 | Args:
118 | gen_len: Number of tokens being generated
119 | prefill_len: Number of tokens in the prefill context
120 | device: Device to create the mask on
121 |
122 | Returns:
123 | mask: Boolean tensor of shape (1, 1, gen_len, prefill_len + gen_len)
124 | where True indicates positions that should be masked
125 | """
126 | return torch.cat([
127 | torch.zeros((1, 1, gen_len, prefill_len), dtype=torch.bool, device=device),
128 | Attention.get_causal_mask(gen_len, device)
129 | ], dim=-1)
130 |
131 | @staticmethod
132 | def calculate_sparsity_ratio(mask):
133 | """Calculates the sparsity ratio of an attention mask.
134 |
135 | This method computes what fraction of the possible attention connections are masked
136 | assuming that attention is causal, i.e., that tokens cannot attend to tokens before them.
137 |         A higher ratio means sparser attention. Assumes batch_size = 1.
138 |
139 | Args:
140 | mask: Boolean tensor of shape (batch_size, num_heads, queries_len, keys_len) where
141 | True/1 indicates masked (disabled) attention connections
142 |
143 | Returns:
144 | float: The sparsity ratio between 0 and 1, where:
145 | 0 means all possible connections are enabled (dense attention)
146 | 1 means all possible connections are masked (completely sparse)
147 | """
148 |
149 | _, _, queries_len, keys_len = mask.shape
150 |
151 | if queries_len != keys_len:
152 | prefill_length = keys_len - queries_len
153 | total_connections = queries_len * (queries_len + 1) // 2 + prefill_length * queries_len
154 | else:
155 | total_connections = queries_len * (queries_len + 1) // 2
156 |
157 | connections_per_head = (~mask).long().sum(dim=(-1, -2))
158 | non_masked_ratio = (connections_per_head.float() / total_connections).mean(dim=-1).item()
159 | sparsity_ratio = 1 - non_masked_ratio
160 |
161 | return sparsity_ratio
162 |
163 | @staticmethod
164 | def calculate_average_sparsity_ratio(sparsity_ratios):
165 | return sum(sparsity_ratios) / len(sparsity_ratios) if sparsity_ratios else 0
166 |
--------------------------------------------------------------------------------
/nano_sparse_attn/attention/inference_handler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class InferenceHandler(nn.Module):
5 | """Handles different attention patterns for prefilling and generation stages.
6 |
7 | This class manages the splitting of input sequences into prefilling and generation
8 | stages, and applies the appropriate attention pattern for each stage.
9 |
10 | Args:
11 | prefill_attention: Attention module to use during prefilling
12 | generation_attention: Attention module to use during generation
13 | """
14 | def __init__(self, prefill_attention, generation_attention):
15 | super().__init__()
16 | self.prefill_attention = prefill_attention
17 | self.generation_attention = generation_attention
18 | self.name = f"{prefill_attention.name}_to_{generation_attention.name}"
19 |
20 | def set_output_length(self, output_length):
21 | self.output_length = output_length
22 |
23 | def info(self):
24 | return {
25 | "prefill": self.prefill_attention.info(),
26 | "generation": self.generation_attention.info()
27 | }
28 |
29 | def forward(self, queries, keys, values, layer_idx=None):
30 | """Handles attention computation for both prefilling and generation stages.
31 |
32 | Args:
33 | queries: (batch_size, num_heads, seq_len, head_dim)
34 | keys: (batch_size, num_heads, seq_len, head_dim)
35 | values: (batch_size, num_heads, seq_len, head_dim)
36 | layer_idx: Optional layer index for some attention implementations
37 |
38 | Returns:
39 | attention_output: Combined output from both stages
40 | """
41 | # Split sequence into prefill and generation parts
42 | input_length = queries.size(-2) - self.output_length
43 | assert input_length > 0, "Input length must be > 0"
44 |
45 | # Prefilling stage
46 | prefill_output = self.prefill_attention.forward(
47 | queries=queries[..., :input_length, :],
48 | keys=keys[..., :input_length, :],
49 | values=values[..., :input_length, :],
50 | layer_idx=layer_idx
51 | )
52 |
53 | # Generation stage
54 | generation_output = self.generation_attention.generation_forward(
55 | prefilling_queries=queries[..., :input_length, :],
56 | prefilling_keys=keys[..., :input_length, :],
57 | prefilling_values=values[..., :input_length, :],
58 | generation_queries=queries[..., input_length:, :],
59 | generation_keys=keys[..., input_length:, :],
60 | generation_values=values[..., input_length:, :],
61 | layer_idx=layer_idx
62 | )
63 |
64 | # Combine outputs
65 | return torch.cat([prefill_output, generation_output], dim=-2)
66 |
--------------------------------------------------------------------------------
/nano_sparse_attn/attention/sparse_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from .abstract import Attention
6 |
7 |
8 | class DenseAttention(Attention):
9 | def __init__(self):
10 | super().__init__()
11 | self.name = 'DenseAttention'
12 |
13 | def forward(self, queries, keys, values, *args, **kwargs):
14 | attention_mask = self.get_causal_mask(queries.size(-2), queries.device)
15 | self.maybe_save_mask(attention_mask)
16 | return self.attention(queries, keys, values, attention_mask)
17 |
18 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values,
19 | generation_queries, generation_keys, generation_values, *args, **kwargs):
20 | keys = torch.cat([prefilling_keys, generation_keys], dim=-2)
21 | values = torch.cat([prefilling_values, generation_values], dim=-2)
22 |
23 | generation_mask = self.get_generation_mask(
24 | gen_len=generation_queries.size(-2),
25 | prefill_len=prefilling_keys.size(-2),
26 | device=generation_queries.device
27 | )
28 |
29 | return self.attention(generation_queries, keys, values, generation_mask)
30 |
31 |
32 | class LocalAndSinksAttention(Attention):
33 | """Implements a sparse attention pattern combining local windows with global attention sinks.
34 |
35 | This attention mechanism reduces computational complexity by:
36 | 1. Allowing each token to attend only to nearby tokens within a fixed window
37 | 2. Designating the first K tokens as "attention sinks" that can be attended to by all tokens
38 |
39 | This creates a sparse pattern where most tokens have local connectivity, but important
40 | context tokens (sinks) maintain global connectivity.
41 |
42 | Args:
43 | window_size (int): Size of the local attention window around each token
44 | attention_sinks (int): Number of initial tokens that serve as global attention sinks
45 |
46 | Reference Papers:
47 | - (Xiao et al, 2023) https://arxiv.org/abs/2309.17453
48 | - (Han et al, 2024) https://arxiv.org/abs/2308.16137
49 | """
50 | def __init__(self, window_size, attention_sinks):
51 | super().__init__()
52 | self.name = 'LocalAndSinksAttention'
53 | self.params = {
54 | "window_size": window_size,
55 | "attention_sinks": attention_sinks
56 | }
57 |
58 | def create_mask(self, seq_len, device):
59 | """Creates sparse attention mask with local windows and global sinks."""
60 |         assert self.window_size <= seq_len, "Window size must be less than or equal to sequence length"
61 |         assert self.attention_sinks <= seq_len, "Number of attention sinks must be less than or equal to sequence length"
62 |
63 | # Create base local attention window
64 | mask = self.get_local_mask(seq_len, self.window_size, device)
65 |
66 | # Allow attention to sink tokens
67 | mask[..., :, :self.attention_sinks] = 0
68 | mask = mask | self.get_causal_mask(seq_len, device)
69 |
70 | return mask
71 |
72 | def forward(self, queries, keys, values, *args, **kwargs):
73 | """Sparse attention for prefilling."""
74 | attention_mask = self.create_mask(queries.size(-2), queries.device)
75 |
76 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
77 | self.maybe_save_mask(attention_mask)
78 |
79 | return self.attention(queries, keys, values, attention_mask)
80 |
81 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values,
82 | generation_queries, generation_keys, generation_values, *args, **kwargs):
83 | assert self.attention_sinks <= prefilling_queries.size(-2)
84 |
85 | total_keys = torch.cat([
86 | prefilling_keys,
87 | generation_keys,
88 | ], dim=-2)
89 |
90 | total_values = torch.cat([
91 | prefilling_values,
92 | generation_values,
93 | ], dim=-2)
94 |
95 | attention_mask = self.create_mask(total_keys.size(-2), generation_queries.device)
96 | attention_mask = attention_mask[..., -generation_queries.size(-2):, :]
97 |
98 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
99 | self.maybe_save_mask(attention_mask)
100 |
101 | return self.attention(generation_queries, total_keys, total_values, attention_mask)
102 |
103 |
104 | class VerticalAndSlashAttention(Attention):
105 | """Implements a content-dependent sparse attention combining top tokens and diagonal patterns,
106 | known as Vertical-Slash, introduced by (Jiang et al, 2024) in the MInference 1.0 paper.
107 |
108 | This attention mechanism creates a sparse pattern by combining two types of connectivity:
109 | 1. Dynamic attention to top-K most relevant tokens (content-based)
110 | 2. Dynamic attention along top-K diagonal stripes (structure-based)
111 |
112 | The diagonal stripes (slashes) capture recurring patterns at fixed offsets, while
113 | top tokens capture semantic relevance. This combines structural and semantic sparsity.
114 |
115 | Args:
116 | attention_sinks (int): Number of initial tokens that serve as global attention sinks
117 | window_size (int): Size of the local attention window around each token
118 | top_tokens (int): Number of highest-scoring tokens to attend to globally
119 | top_slashes (int): Number of diagonal stripes to include in the pattern
120 | approximation_window (int): Size of the window used to approximate attention scores
121 |
122 | Reference Papers:
123 | - (Jiang et al, 2024) https://arxiv.org/abs/2407.02490
124 | """
125 | def __init__(self, attention_sinks, window_size, top_tokens, top_slashes, approximation_window):
126 | super().__init__()
127 | self.name = 'VerticalAndSlashAttention'
128 | self.params = {
129 | "attention_sinks": attention_sinks,
130 | "window_size": window_size,
131 | "top_tokens": top_tokens,
132 | "top_slashes": top_slashes,
133 | "approximation_window": approximation_window,
134 | }
135 |
136 | def fill_diagonals(self, L: int, indexes: torch.Tensor, device='cuda', chunk_size=1024):
137 | """
138 | Memory-efficient implementation to create a mask LxL with diagonals filled with False.
139 | Processes diagonals in chunks to reduce peak memory usage.
140 |
141 | Parameters:
142 | - L (int): The size of the square matrix.
143 | - indexes (torch.Tensor): Tensor of shape (B, H, M) with integer values which determine
144 | the diagonals to be filled in the output matrix for each batch and head.
145 | - device: The device to perform computations on ('cuda' or 'cpu').
146 | - chunk_size: Size of chunks to process at once to manage memory usage.
147 |
148 | Returns:
150 |         - mask (torch.Tensor): A boolean matrix of size (B, H, L, L) with the specified diagonals set to False (i.e., left unmasked).
150 | """
151 | assert (indexes <= 0).all().item(), "Indexes must be on or below diagonal."
152 |
153 | batch_size, num_heads, num_diagonals = indexes.shape
154 |
155 | # Create output tensor
156 | mask_dense = torch.ones((batch_size, num_heads, L, L), dtype=torch.bool, device=device)
157 |
158 | # Process the sequence length in chunks
159 | for chunk_start in range(0, L, chunk_size):
160 | chunk_end = min(chunk_start + chunk_size, L)
161 | chunk_len = chunk_end - chunk_start
162 |
163 | # Create row indices for this chunk
164 | row_indices = torch.arange(chunk_start, chunk_end, device=device, dtype=torch.int32)
165 | row_indices = row_indices.view(1, 1, 1, chunk_len)
166 | row_indices = row_indices.expand(batch_size, num_heads, num_diagonals, chunk_len)
167 |
168 | # Add the diagonal offsets to get column indices
169 | col_indices = row_indices + indexes.unsqueeze(-1).to(torch.int32)
170 |
171 | # Mask out indices that are out of bounds
172 | valid_mask = (col_indices >= 0) & (col_indices < L)
173 |
174 | if not valid_mask.any():
175 | continue
176 |
177 | # Create batch and head indices for valid positions only
178 | batch_idx = torch.arange(batch_size, device=device).view(-1, 1, 1, 1)
179 | batch_idx = batch_idx.expand(-1, num_heads, num_diagonals, chunk_len)
180 | head_idx = torch.arange(num_heads, device=device).view(1, -1, 1, 1)
181 | head_idx = head_idx.expand(batch_size, -1, num_diagonals, chunk_len)
182 |
183 | # Select only valid indices
184 | valid_batch_idx = batch_idx[valid_mask]
185 | valid_head_idx = head_idx[valid_mask]
186 | valid_row_idx = row_indices[valid_mask]
187 | valid_col_idx = col_indices[valid_mask]
188 |
189 | # Set the valid diagonal elements to False
190 | mask_dense[valid_batch_idx, valid_head_idx, valid_row_idx, valid_col_idx] = False
191 |
192 | # Free memory explicitly
193 | del row_indices, col_indices, valid_mask, batch_idx, head_idx
194 |
195 | return mask_dense
196 |
197 | def sum_over_diagonals(self, matrix):
198 | """Efficiently sum values along diagonals of the attention matrix.
199 |
200 | This method uses stride tricks to efficiently extract and sum diagonals:
201 | 1. Pad the matrix with zeros on both sides to handle all possible diagonals
202 | 2. Create a strided view that groups elements along diagonals
203 | 3. Sum along each diagonal to get their total attention scores
204 |
205 | The strided operation creates a view where each row contains elements from
206 | one diagonal, allowing for efficient parallel summation.
207 |
208 | Args:
209 | matrix: Attention scores tensor of shape (batch_size, num_heads, queries, keys)
210 |
211 | Returns:
212 | Tensor of shape (batch_size, num_heads, queries + keys - 1) containing
213 | summed attention scores for each diagonal
214 |
215 | This function is based on the implementation from:
216 | https://github.com/microsoft/MInference/blob/main/minference/modules/minference_forward.py#L101
217 | """
218 | batch_size, num_heads, queries, keys = matrix.shape
219 | zero_matrix = torch.zeros((batch_size, num_heads, queries, queries), device=matrix.device)
220 | matrix_padded = torch.cat((zero_matrix, matrix, zero_matrix), -1)
221 | matrix_strided = matrix_padded.as_strided(
222 | (
223 | batch_size,
224 | num_heads,
225 | queries,
226 | queries + keys
227 | ),
228 | (
229 | num_heads * queries * (2 * queries + keys),
230 | queries * (2 * queries + keys),
231 | 2 * queries + keys + 1,
232 | 1
233 | )
234 | )
235 | sum_diagonals = torch.sum(matrix_strided, 2)
236 | return sum_diagonals[:, :, 1:]
237 |
238 | def create_mask(self, queries, keys, seq_len, device):
239 | assert self.top_slashes <= seq_len
240 | assert self.top_tokens <= seq_len
241 | assert self.top_slashes >= self.window_size
242 | assert self.attention_sinks <= self.top_tokens
243 |
244 | # Approximate attention scores
245 | approx_queries = queries[..., -self.approximation_window:, :]
246 | attention_scores = torch.matmul(approx_queries, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1))
247 | attention_scores[..., -self.approximation_window:] += self.get_causal_mask(self.approximation_window, device) * torch.finfo(queries.dtype).min
248 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype)
249 |
250 | # Get top tokens
251 | summed_scores = attention_scores.sum(dim=-2)
252 | summed_scores[..., :self.attention_sinks] = torch.inf
253 | top_k_indices = torch.topk(summed_scores, self.top_tokens, dim=-1, sorted=False).indices
254 |
255 | # Get top_k slashes
256 | top_k_slash = self.sum_over_diagonals(attention_scores)[..., :-self.approximation_window + 1]
257 | top_k_slash[..., -self.window_size:] = torch.inf
258 | top_k_slash = torch.topk(top_k_slash, self.top_slashes, -1).indices - (seq_len - 1)
259 |
260 | # Get final mask
261 | mask = self.fill_diagonals(seq_len, top_k_slash, device)
262 | mask.scatter_(-1, top_k_indices.unsqueeze(-2).expand(-1, -1, seq_len, -1), 0)
263 | mask = mask | self.get_causal_mask(seq_len, device)
264 |
265 | return mask
266 |
267 | def forward(self, queries, keys, values, *args, **kwargs):
268 | """Sparse attention for prefilling using top tokens and slashes."""
269 | attention_mask = self.create_mask(queries, keys, queries.size(-2), queries.device)
270 |
271 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
272 | self.maybe_save_mask(attention_mask)
273 |
274 | return self.attention(queries, keys, values, attention_mask)
275 |
276 |
277 | class BlockSparseAttention(Attention):
278 | """Implements a Block-Sparse Attention pattern based on chunk-level relevance,
279 | introduced by (Jiang et al, 2024) in the MInference 1.0 paper.
280 |
281 | This attention mechanism reduces complexity by operating on chunks of tokens:
282 | 1. Divides the sequence into fixed-size chunks.
283 | 2. Computes chunk-level attention scores with averaged token representations.
284 | 3. Allows each chunk to attend to the top-K most relevant chunks.
285 |
286 | This creates a coarse-grained sparse pattern where entire blocks of tokens attend
287 | to each other, making it efficient for long sequences while preserving approximate
288 | attention patterns.
289 |
290 | Contrary to the original paper, in this implementation we additionally:
291 |     1. Include the local window of size chunk_size around each token, because we found
292 |        it crucial for performance and the block pattern does not reliably cover it
293 |        at high sparsity ratios.
294 |     2. Always pick the prefix chunk, because we also found it crucial for performance
295 |        but it is not always selected as the top-1 chunk.
296 |
297 | For this reason, top_chunks = 1 BlockSparseAttention is equivalent to
298 | LocalAndSinksAttention(window_size=chunk_size, attention_sinks=chunk_size).
299 |
300 | Args:
301 | chunk_size (int): Size of each token chunk/block
302 | top_chunks (int): Number of highest-scoring chunks each chunk can attend to
303 |
304 | Reference Papers:
305 | - (Jiang et al, 2024) https://arxiv.org/abs/2407.02490
306 | """
307 | def __init__(self, chunk_size, top_chunks):
308 | super().__init__()
309 | self.name = 'BlockSparseAttention'
310 | self.params = {
311 | "chunk_size": chunk_size,
312 | "top_chunks": top_chunks,
313 | }
314 |
315 | def create_mask(self, queries, keys, seq_len, device):
316 | assert self.chunk_size < seq_len, "Chunk size must be smaller than sequence length"
317 | assert self.top_chunks > 0, "Must select at least one top chunk"
318 | assert self.chunk_size >= 8, "Recommended chunk size is >= 8."
319 |
320 | # Calculate number of chunks and padding needed
321 | num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size
322 | padded_seq_len = num_chunks * self.chunk_size
323 | padding_len = padded_seq_len - seq_len
324 |
325 | assert self.top_chunks <= num_chunks, "Cannot select more top chunks than available chunks"
326 |
327 | # Pad queries and keys if needed
328 | if padding_len > 0:
329 | queries_padded = torch.nn.functional.pad(queries, (0, 0, 0, padding_len))
330 | keys_padded = torch.nn.functional.pad(keys, (0, 0, 0, padding_len))
331 | else:
332 | queries_padded = queries
333 | keys_padded = keys
334 |
335 | # Reshape padded queries and keys to chunks
336 | query_chunks = queries_padded.reshape(queries.size(0), queries.size(1), num_chunks, self.chunk_size, -1)
337 | key_chunks = keys_padded.reshape(keys.size(0), keys.size(1), num_chunks, self.chunk_size, -1)
338 |
339 | # Compute chunk representations by averaging
340 | query_chunk_representations = query_chunks.mean(dim=-2)
341 | key_chunk_representations = key_chunks.mean(dim=-2)
342 |
343 | # Compute attention scores between chunk representations
344 | chunk_attention_scores = torch.matmul(query_chunk_representations, key_chunk_representations.transpose(-2, -1))
345 |
346 | # Add causal masking of upper triangle
347 | chunk_attention_scores.masked_fill_(self.get_causal_mask(num_chunks, device), float('-inf'))
348 |
349 | # Always pick the prefix chunk
350 | chunk_attention_scores[..., 0] = float('inf')
351 |
352 | # Get top-k key chunks for each query chunk
353 | top_k_chunk_indices = torch.topk(chunk_attention_scores, self.top_chunks, dim=-1, sorted=False).indices
354 |
355 | # Create a mask for top-k interactions
356 | top_k_mask = torch.ones((num_chunks, num_chunks), dtype=torch.bool, device=device)
357 | top_k_mask = top_k_mask.unsqueeze(0).unsqueeze(0).repeat(queries.size(0), queries.size(1), 1, 1)
358 | top_k_mask.scatter_(-1, top_k_chunk_indices, 0)
359 |
360 | # Expand mask to padded sequence length
361 | mask = top_k_mask.repeat_interleave(self.chunk_size, dim=-2).repeat_interleave(self.chunk_size, dim=-1)
362 |
363 | # Include the local window of size chunk_size around each token
364 | mask = mask & self.get_local_mask(padded_seq_len, self.chunk_size, queries.device)
365 |
366 | # Include the causal mask
367 | mask = mask | self.get_causal_mask(padded_seq_len, queries.device)
368 |
369 | # Remove padding from mask if needed
370 | if padding_len > 0:
371 | mask = mask[..., :seq_len, :seq_len]
372 |
373 | return mask
374 |
375 | def forward(self, queries, keys, values, *args, **kwargs):
376 | """Block-sparse attention for prefilling."""
377 | attention_mask = self.create_mask(queries, keys, queries.size(-2), queries.device)
378 |
379 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
380 | self.maybe_save_mask(attention_mask)
381 |
382 | return self.attention(queries, keys, values, attention_mask)
383 |
384 |
385 | class SnapKVAttention(Attention):
386 | """Implements SnapKV's attention pattern for efficient text generation.
387 |
388 | This attention mechanism compresses the Pre-Filling KV Cache, therefore reducing
389 | memory usage during generation, by:
390 | 1. Approximating attention scores using a suffix window of queries
391 | 2. Using attention score pooling to identify important token clusters
392 | 3. Keeping only the most relevant tokens plus a recent window
393 |
394 | Args:
395 | token_capacity (int): Maximum number of tokens to keep in compressed history
396 | approximation_window (int): Number of suffix tokens used to approximate attention scores
397 | kernel_size (int): Size of the pooling kernel for score aggregation
398 |
399 | Reference:
400 | - (Li et al, 2024) https://arxiv.org/abs/2404.14469
401 | """
402 | def __init__(self, token_capacity, approximation_window=64, kernel_size=7):
403 | super().__init__()
404 | self.name = 'SnapKVAttention'
405 | self.params = {
406 | "token_capacity": token_capacity,
407 | "approximation_window": approximation_window,
408 | "kernel_size": kernel_size,
409 | }
410 |
411 | def create_mask_for_prefill(self, queries, keys, seq_len, device):
412 | """Create mask for generation with compressed KV cache.
413 |
414 | It uses prefilling queries and keys to estimate which tokens will be important
415 | during generation and keeps only the most important tokens plus a recent window.
416 |
417 | The output mask only concerns the prefilling tokens and has to be extended to
418 | account for the tokens from the generation phase.
419 | """
420 | assert self.token_capacity >= self.approximation_window
421 |
422 | # Approximate attention scores
423 | approx_queries = queries[..., -self.approximation_window:, :]
424 | attention_scores = torch.matmul(approx_queries, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1))
425 | attention_scores[..., -self.approximation_window:] += self.get_causal_mask(self.approximation_window, device) * torch.finfo(queries.dtype).min
426 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype)
427 |
428 | # Sum over queries to get per-key importance and apply average pooling
429 | key_importance = attention_scores.sum(dim=-2)
430 | key_importance = F.avg_pool1d(
431 | key_importance,
432 | kernel_size=self.kernel_size,
433 | padding=self.kernel_size // 2,
434 | stride=1
435 | )
436 |
437 | # Always keep the window tokens
438 | key_importance[..., -self.approximation_window:] = torch.inf
439 |
440 | # Keep only the top_k tokens
441 | mask = torch.ones((queries.size(0), queries.size(1), seq_len), dtype=torch.bool, device=device)
442 | top_indices = key_importance.topk(self.token_capacity, dim=-1).indices
443 | mask.scatter_(-1, top_indices, False)
444 |
445 | # Expand mask to proper attention shape
446 | mask = mask.unsqueeze(-2)
447 |
448 | return mask
449 |
450 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values,
451 | generation_queries, generation_keys, generation_values, *args, **kwargs):
452 | # Concatenate prefilling and generation KV states
453 | keys = torch.cat([prefilling_keys, generation_keys], dim=-2)
454 | values = torch.cat([prefilling_values, generation_values], dim=-2)
455 |
456 | # Create mask for prefilling tokens
457 | prefill_mask = self.create_mask_for_prefill(prefilling_queries, prefilling_keys, prefilling_keys.size(-2), generation_queries.device)
458 |
459 | # Create dense causal mask for generation tokens
460 | attention_mask = self.get_generation_mask(
461 | gen_len=generation_queries.size(-2),
462 | prefill_len=prefilling_keys.size(-2),
463 | device=generation_queries.device
464 | ).repeat(generation_queries.size(0), generation_queries.size(1), 1, 1)
465 |
466 | # Combine masks
467 | attention_mask[..., :prefilling_keys.size(-2)] |= prefill_mask
468 |
469 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
470 | self.maybe_save_mask(attention_mask)
471 |
472 | return self.attention(generation_queries, keys, values, attention_mask)
473 |
474 |
475 | class TOVAAttention(Attention):
476 | """Implements Token Omission Via Attention (TOVA) for efficient text generation.
477 |
478 | This attention mechanism dynamically prunes the KV cache during generation by:
479 | 1. Using the last prefilling token to identify initially important tokens
480 | 2. During generation, maintaining a fixed-size cache by removing the least attended token
481 | after processing each new token
482 |
483 | Args:
484 | token_capacity (int): Maximum number of tokens to keep in the pruned KV cache
485 |
486 | Reference:
487 | - (Oren et al, 2023) https://arxiv.org/abs/2401.06104
488 | """
489 | def __init__(self, token_capacity):
490 | super().__init__()
491 | self.name = 'TOVAAttention'
492 | self.params = {
493 | "token_capacity": token_capacity
494 | }
495 |
496 | def create_mask_for_prefill(self, queries, keys, seq_len, device):
497 | """Create initial mask based on attention scores from the last prefilling token."""
498 | if self.token_capacity >= seq_len:
499 | return torch.zeros((queries.size(0), queries.size(1), 1, seq_len), dtype=torch.bool, device=device)
500 |
501 | # Get attention scores for the last prefilling token
502 | last_query = queries[..., -1:, :]
503 | attention_scores = torch.matmul(last_query, keys.transpose(-2, -1)) / math.sqrt(queries.size(-1))
504 | attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(queries.dtype)
505 |
506 | # Average attention scores across heads
507 | mean_attention_scores = attention_scores.mean(dim=1)
508 |
509 | # Create mask keeping only top-k tokens
510 | mask = torch.ones((queries.size(0), queries.size(1), seq_len), dtype=torch.bool, device=device)
511 | top_indices = mean_attention_scores.topk(self.token_capacity, dim=-1).indices
512 | mask.scatter_(-1, top_indices.expand(-1, queries.size(1), -1), False)
513 |
514 | # Expand mask to proper attention shape
515 | mask = mask.unsqueeze(-2)
516 |
517 | return mask
518 |
519 | def generation_forward(self, prefilling_queries, prefilling_keys, prefilling_values,
520 | generation_queries, generation_keys, generation_values, *args, **kwargs):
521 | # Get initial mask for prefilling tokens
522 | current_mask = self.create_mask_for_prefill(
523 | prefilling_queries,
524 | prefilling_keys,
525 | prefilling_keys.size(-2),
526 | generation_queries.device
527 | )
528 |
529 | # Initialize lists to store outputs and updated keys/values
530 | outputs = []
531 | current_keys = prefilling_keys
532 | current_values = prefilling_values
533 |
534 |         # Initialise the final mask we'll output
535 | attention_mask = torch.ones((
536 | generation_queries.size(0),
537 | generation_queries.size(1),
538 | generation_queries.size(2),
539 | prefilling_keys.size(-2) + generation_keys.size(-2),
540 | ), dtype=torch.bool, device=generation_queries.device)
541 |
542 | # Process generation tokens one by one
543 | for idx in range(generation_queries.size(-2)):
544 | # Get current generation token
545 | current_query = generation_queries[..., idx:idx+1, :]
546 | current_gen_key = generation_keys[..., idx:idx+1, :]
547 | current_gen_value = generation_values[..., idx:idx+1, :]
548 |
549 | # Extend keys and values
550 | current_keys = torch.cat([current_keys, current_gen_key], dim=-2)
551 | current_values = torch.cat([current_values, current_gen_value], dim=-2)
552 |
553 | # Extend mask for the new token (always attended)
554 | current_mask = torch.cat([
555 | current_mask,
556 | torch.zeros((current_query.size(0), current_query.size(1), 1, 1), dtype=torch.bool, device=current_query.device)
557 | ], dim=-1)
558 |
559 | attention_mask[..., idx:idx+1, :current_keys.size(-2)] = current_mask
560 |
561 | # Compute attention with scores
562 | output, attention_scores = self.attention(
563 | current_query,
564 | current_keys,
565 | current_values,
566 | current_mask,
567 | return_attention_scores=True
568 | )
569 | outputs.append(output)
570 |
571 | # If we exceed capacity, mask the token with lowest attention score
572 | if current_keys.size(-2) > self.token_capacity:
573 | # Set scores to inf where tokens were already masked
574 | attention_scores = attention_scores.masked_fill(current_mask, float('inf'))
575 |
576 | # Average attention scores across heads
577 | mean_scores = attention_scores.mean(dim=1, keepdim=True)
578 |
579 | # Find token with lowest attention score
580 | min_indices = mean_scores.argmin(dim=-1, keepdim=True)
581 | min_indices = min_indices.expand(-1, current_query.size(1), -1, -1)
582 |
583 | # Update mask to exclude the lowest scoring token
584 | current_mask.scatter_(-1, min_indices, True)
585 |
586 | # Concatenate all outputs
587 | final_output = torch.cat(outputs, dim=-2)
588 |
589 | self.sparsity_ratios.append(self.calculate_sparsity_ratio(attention_mask))
590 | self.maybe_save_mask(attention_mask)
591 |
592 | return final_output
593 |
--------------------------------------------------------------------------------
/nano_sparse_attn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import CONSTANTS
2 | from .modelling import (
3 | load_model_and_tokenizer,
4 | load_examples,
5 | model_forward,
6 | update_attention,
7 | )
8 | from .plotting import (
9 | plot_sparse_attention_results,
10 | plot_prefill_masks,
11 | plot_generation_masks,
12 | )
13 |
14 | __all__ = [
15 | "CONSTANTS",
16 | "load_model_and_tokenizer",
17 | "load_examples",
18 | "model_forward",
19 | "update_attention",
20 | "plot_sparse_attention_results",
21 | "plot_prefill_masks",
22 | "plot_generation_masks",
23 | ]
24 |
--------------------------------------------------------------------------------
/nano_sparse_attn/utils/constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | CONSTANTS = {
5 | 'model_name': 'unsloth/Llama-3.2-1B-Instruct',
6 | 'dataset_name': 'FiscalNote/billsum',
7 | 'runtime_args': {
8 | 'dtype': torch.bfloat16,
9 | 'num_examples': 3,
10 | 'target_length_min': 4096,
11 | 'target_length_max': 4096 + 512,
12 | 'device': 'cuda' if torch.cuda.is_available() else 'cpu',
13 | },
14 | 'hf_kwargs': {
15 | 'trust_remote_code': True,
16 | },
17 | 'save_results': True,
18 | }
19 |
--------------------------------------------------------------------------------
/nano_sparse_attn/utils/modelling.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | from datasets import load_dataset
6 | from transformers import AutoTokenizer, AutoModelForCausalLM
7 | from transformers.models.llama.modeling_llama import LlamaAttention
8 | from transformers.models.llama.modeling_llama import Cache, repeat_kv, apply_rotary_pos_emb
9 |
10 | from .constants import CONSTANTS
11 |
12 |
13 | def llama_attention_forward(
14 | self,
15 | hidden_states: torch.Tensor,
16 | position_ids: Optional[torch.LongTensor] = None,
17 | past_key_value: Optional[Cache] = None,
18 | cache_position: Optional[torch.LongTensor] = None,
19 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
20 | **kwargs,
21 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
22 | batch_size, query_length, _ = hidden_states.size()
23 | num_heads, head_dim = self.num_heads, self.head_dim
24 | num_kv_heads, num_kv_groups = self.num_key_value_heads, self.num_key_value_groups
25 |
26 | queries = self.q_proj(hidden_states).view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
27 | keys = self.k_proj(hidden_states).view(batch_size, query_length, num_kv_heads, head_dim).transpose(1, 2)
28 | values = self.v_proj(hidden_states).view(batch_size, query_length, num_kv_heads, head_dim).transpose(1, 2)
29 |
30 | if position_embeddings is None:
31 | cos, sin = self.rotary_emb(values, position_ids)
32 | else:
33 | cos, sin = position_embeddings
34 |
35 | queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
36 |
37 | if past_key_value is not None:
38 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
39 | keys, values = past_key_value.update(keys, values, self.layer_idx, cache_kwargs)
40 |
41 | keys = repeat_kv(keys, num_kv_groups)
42 | values = repeat_kv(values, num_kv_groups)
43 |
44 | attention_output = self.attn(queries, keys, values, layer_idx=self.layer_idx)
45 |
46 | attention_output = attention_output.transpose(1, 2).reshape(batch_size, query_length, -1).contiguous()
47 | attention_output = self.o_proj(attention_output)
48 |
49 | return attention_output, None, past_key_value
50 |
51 |
52 | def update_attention(model, inference_handler):
53 | def update_attention_recursive(module):
54 | if isinstance(module, LlamaAttention):
55 | module.forward = llama_attention_forward.__get__(module, LlamaAttention)
56 | module.attn = inference_handler
57 |
58 | model.apply(update_attention_recursive)
59 |
60 |
61 | def load_model_and_tokenizer():
62 | tokenizer = AutoTokenizer.from_pretrained(
63 | CONSTANTS['model_name'],
64 | **CONSTANTS['hf_kwargs'],
65 | )
66 |
67 | model = AutoModelForCausalLM.from_pretrained(
68 | CONSTANTS['model_name'],
69 | torch_dtype=CONSTANTS['runtime_args']['dtype'],
70 | attn_implementation="eager",
71 | **CONSTANTS['hf_kwargs'],
72 | ).to(CONSTANTS['runtime_args']['device'])
73 |
74 |
75 | return model, tokenizer
76 |
77 |
78 | def load_examples(tokenizer, target_length_min=3500, target_length_max=4096, num_examples=1):
79 | dataset = load_dataset(
80 | CONSTANTS['dataset_name'],
81 | split="train",
82 | **CONSTANTS['hf_kwargs'],
83 | )
84 |
85 | examples = []
86 |     for example in dataset:
87 | # Separate input prompt and output
88 | input_prompt = f"""Below is a US Congressional and California state bill. Please provide a concise summary of the bill.
89 |
90 | Bill:
91 | {example['text']}"""
92 |
93 | output_text = example['summary']
94 |
95 | # Apply chat template to input
96 | templated_input = tokenizer.apply_chat_template(
97 | [{"role": "user", "content": input_prompt}],
98 | tokenize=False,
99 | add_generation_prompt=True
100 | )
101 |
102 | # Create full sequence with answer
103 | full_sequence = templated_input + output_text
104 |
105 | # Check total length
106 | tokens = tokenizer(
107 | full_sequence,
108 | return_tensors="pt",
109 | add_special_tokens=False,
110 | )
111 |
112 | total_length = tokens['input_ids'].shape[1]
113 | if target_length_min <= total_length <= target_length_max:
114 | # Get output length for loss calculation
115 | output_length = len(tokenizer(output_text, add_special_tokens=False)['input_ids'])
116 |
117 | # Create final tokens
118 | model_inputs = tokens.to(CONSTANTS['runtime_args']['device'])
119 |
120 | examples.append({
121 | 'input_ids': model_inputs['input_ids'],
122 | 'attention_mask': model_inputs['attention_mask'],
123 | 'output_length': output_length,
124 | })
125 |
126 | if len(examples) >= num_examples:
127 | break
128 |
129 | return examples
130 |
131 |
132 | def model_forward(model, model_inputs, inference_handler):
133 | total_loss = 0
134 |
135 | for example in model_inputs:
136 | # Set output length for current example
137 | inference_handler.set_output_length(example['output_length'])
138 |
139 | with torch.no_grad():
140 | outputs = model(
141 | input_ids=example['input_ids'],
142 | attention_mask=example['attention_mask'],
143 | )
144 |
145 | shift_logits = outputs.logits[..., :-1, :]
146 | shift_labels = example['input_ids'][..., 1:]
147 |
148 | # Calculate loss only over the output length
149 | output_length = example['output_length']
150 | loss_fct = nn.CrossEntropyLoss()
151 | loss = loss_fct(
152 | shift_logits.view(-1, shift_logits.size(-1))[-output_length:],
153 | shift_labels.view(-1)[-output_length:]
154 | )
155 | total_loss += loss.item()
156 |
157 | # Return average loss across examples
158 | return total_loss / len(model_inputs)
159 |
--------------------------------------------------------------------------------
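Taken together, these utilities cover a full prefilling experiment: load the model and tokenizer, filter billsum examples into the target length range, patch every LlamaAttention module, and report the average loss over the summary tokens. Below is a minimal sketch of that flow; the InferenceHandler construction is an assumption for illustration, not the documented constructor.

```python
from nano_sparse_attn.attention import DenseAttention, InferenceHandler
from nano_sparse_attn.utils import (
    load_examples,
    load_model_and_tokenizer,
    model_forward,
    update_attention,
)

model, tokenizer = load_model_and_tokenizer()
examples = load_examples(tokenizer, target_length_min=4096,
                         target_length_max=4096 + 512, num_examples=1)

# NOTE: illustrative only; see nano_sparse_attn/attention/inference_handler.py
# for the real constructor signature.
handler = InferenceHandler(attention=DenseAttention())

# Route every LlamaAttention module through the handler, then score the
# summary tokens with the cross-entropy computed in model_forward.
update_attention(model, handler)
print(f"average loss: {model_forward(model, examples, handler):.3f}")
```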
/nano_sparse_attn/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | # Increasing the DPI is important for seeing individual columns in a mask
5 | # plt.rcParams['figure.dpi']= 600
6 |
7 | def plot_prefill_masks(mask1, mask2, title):
8 | """
9 | Plot two prefill attention masks side by side with custom styling.
10 |
11 | Args:
12 | mask1: 2D boolean numpy array or tensor for left plot
13 | mask2: 2D boolean numpy array or tensor for right plot
14 |         title: String used as the base title for both subplots
15 |             (rendered as '{title} Mask #1' and '{title} Mask #2')
16 | """
17 | # Create figure with gray background
18 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
19 | fig.patch.set_facecolor('#3d3d3d')
20 |
21 | # Custom colormap (purple to yellow)
22 | custom_cmap = plt.cm.colors.ListedColormap(['#2E1A47', '#FFE135'])
23 |
24 | # Plot left mask
25 | ax1.imshow(~mask1, cmap=custom_cmap, aspect='equal')
26 | ax1.axis('off')
27 | ax1.set_title(f'{title} Mask #1', pad=5, color='white')
28 | ax1.text(-0.05, 0.5, 'Queries', rotation=90,
29 | transform=ax1.transAxes, va='center', color='white')
30 | ax1.text(0.5, -0.05, 'Keys', transform=ax1.transAxes, ha='center', color='white')
31 | ax1.set_facecolor('#3d3d3d')
32 |
33 | # Plot right mask
34 | ax2.imshow(~mask2, cmap=custom_cmap, aspect='equal')
35 | ax2.axis('off')
36 | ax2.set_title(f'{title} Mask #2', pad=5, color='white')
37 | ax2.text(-0.05, 0.5, 'Queries', rotation=90,
38 | transform=ax2.transAxes, va='center', color='white')
39 | ax2.text(0.5, -0.05, 'Keys', transform=ax2.transAxes, ha='center', color='white')
40 | ax2.set_facecolor('#3d3d3d')
41 |
42 | plt.tight_layout()
43 |
44 |
45 | def plot_generation_masks(mask1, mask2, title, mult):
46 | """
47 | Plot two generation attention masks stacked vertically with custom styling.
48 |
49 | Args:
50 | mask1: 2D boolean numpy array or tensor for top plot (N queries × M keys)
51 | mask2: 2D boolean numpy array or tensor for bottom plot (N queries × M keys)
52 | title: String to display as base title
53 | mult: Multiplier for height of the plot
54 | """
55 | # Create figure with gray background
56 | n_queries = mask1.shape[0]
57 |     height = min(n_queries * mult, 10)  # Scale height with the number of queries, capped at 10
58 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, height))
59 | fig.patch.set_facecolor('#3d3d3d')
60 |
61 |     # Custom colormap: purple where the mask is True, yellow where it is False (prompt keys), red where it is False (generation keys)
62 | colors = ['#2E1A47', '#FFE135', '#E74C3C'] # purple, yellow, red
63 | custom_cmap = plt.cm.colors.ListedColormap(colors)
64 |
65 | for mask, ax, mask_num in [(mask1, ax1, 1), (mask2, ax2, 2)]:
66 |         # Create visualization array (0: mask True, 1: mask False over prompt keys, 2: mask False over generation keys)
67 | n_queries, n_keys = mask.shape
68 | n_gen_keys = n_queries # Last N keys are generation keys
69 |
70 |         viz_array = (~mask).astype(int)  # Convert to int (mask False -> 1, mask True -> 0)
71 |
72 | # Mark generation area with different color (2) if False
73 | gen_area = viz_array[:, -n_gen_keys:]
74 | viz_array[:, -n_gen_keys:] = np.where(gen_area == 1, 2, 0)
75 |
76 | # Plot mask
77 | ax.imshow(viz_array, cmap=custom_cmap, aspect='auto')
78 | ax.axis('off')
79 | ax.set_title(f'{title} Mask #{mask_num}', pad=5, color='white')
80 | ax.text(-0.02, 0.5, 'Queries', rotation=90,
81 | transform=ax.transAxes, va='center', color='white')
82 | ax.text(0.5, -0.3, 'Keys', transform=ax.transAxes, ha='center', color='white')
83 | ax.set_facecolor('#3d3d3d')
84 |
85 | plt.tight_layout()
86 |
87 |
88 | def plot_sparse_attention_results(results):
89 | # Separate Dense results and get baseline
90 | dense_results = [r for r in results if 'Dense' in r['name']]
91 | sparse_results = [r for r in results if 'Dense' not in r['name']]
92 |
93 | # Calculate dense baseline
94 | dense_baseline = np.mean([r['loss'] for r in dense_results])
95 |
96 | # Group sparse results by method
97 | method_groups = {}
98 | for result in sparse_results:
99 | # Extract method name (remove 'Attention' suffix)
100 | method = result['name'].replace('Attention', '')
101 |
102 | if method not in method_groups:
103 | method_groups[method] = {'sparsity': [], 'loss': []}
104 |
105 | method_groups[method]['sparsity'].append(result['sparsity'])
106 | method_groups[method]['loss'].append(result['loss'])
107 |
108 | # Plot settings
109 | plt.figure(figsize=(10, 6))
110 | color_palette = ['#2ecc71', '#e74c3c', '#9b59b6', '#3498db', '#f1c40f']
111 | marker_palette = ['o', 's', 'D', '^', 'v']
112 |
113 | # Plot each method
114 | for i, (method, data) in enumerate(method_groups.items()):
115 | color_idx = i % len(color_palette)
116 | marker_idx = i % len(marker_palette)
117 |
118 | plt.plot(data['sparsity'], data['loss'],
119 | label=method,
120 | color=color_palette[color_idx],
121 | marker=marker_palette[marker_idx],
122 | linewidth=2,
123 | markersize=8)
124 |
125 | # Add dense baseline
126 | plt.axhline(y=dense_baseline, color='#95a5a6', linestyle='--', label='Dense')
127 |
128 | # Customize plot
129 | plt.xlabel('Sparsity Ratio', fontsize=12)
130 | plt.ylabel('Loss', fontsize=12)
131 | plt.title('Comparison of Sparse Attention Methods', fontsize=14, pad=15)
132 | plt.grid(True, linestyle='--', alpha=0.7)
133 | plt.legend(fontsize=10)
134 |
135 | # Set axis limits
136 | plt.xlim(-0.05, 1.05)
137 | losses = [l for data in method_groups.values() for l in data['loss']] + [dense_baseline]
138 | sparsity = [s for data in method_groups.values() for s in data['sparsity']]
139 |
140 | y_min, y_max = min(losses), max(losses)
141 | y_padding = (y_max - y_min) * 0.1
142 | plt.ylim(y_min - y_padding, y_max + y_padding)
143 |
144 | x_min, x_max = min(sparsity), max(sparsity)
145 | x_padding = (x_max - x_min) * 0.1
146 | plt.xlim(x_min - x_padding, x_max + x_padding)
147 |
148 |     # Finalize layout; saving or showing the figure is left to the caller
149 | plt.tight_layout()
150 |
--------------------------------------------------------------------------------
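For reference, a sketch of the `results` structure `plot_sparse_attention_results` expects, inferred from the field accesses above; the entries and numbers are placeholders, and saving or showing the figure is left to the caller.

```python
import matplotlib.pyplot as plt

from nano_sparse_attn.utils.plotting import plot_sparse_attention_results

# Placeholder results: each entry needs 'name', 'sparsity', and 'loss'.
# Entries whose name contains 'Dense' form the dashed baseline.
results = [
    {'name': 'DenseAttention', 'sparsity': 0.0, 'loss': 2.10},
    {'name': 'LocalAndSinksAttention', 'sparsity': 0.5, 'loss': 2.25},
    {'name': 'LocalAndSinksAttention', 'sparsity': 0.9, 'loss': 2.60},
]

plot_sparse_attention_results(results)
plt.savefig('sparse_attention_results.png')  # or plt.show()
```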
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="nano_sparse_attn",
5 | version="0.99",
6 | author="Piotr Nawrot",
7 | author_email="piotr@nawrot.org",
8 | description="nanoSparseAttention: PyTorch implementation of novel sparse attention mechanisms",
9 | url="https://github.com/PiotrNawrot/nanoSparseAttention",
10 | packages=find_packages(include=['nano_sparse_attn', 'nano_sparse_attn.*']),
11 | install_requires=[
12 | "torch",
13 | "datasets",
14 | "transformers",
15 | "matplotlib"
16 | ],
17 | python_requires=">=3.10",
18 | )
19 |
--------------------------------------------------------------------------------