├── .gitignore
├── README.md
├── assets
│   ├── photo.png
│   └── pipfreeze.txt
├── compile.py
├── setup.py
└── sparse_frontier
    ├── configs
    │   ├── attention
    │   │   ├── ada_snapkv.yaml
    │   │   ├── block_sparse.yaml
    │   │   ├── dense.yaml
    │   │   ├── flexprefill.yaml
    │   │   ├── quest.yaml
    │   │   ├── snapkv.yaml
    │   │   └── vertical_and_slash.yaml
    │   ├── default.yaml
    │   ├── model
    │   │   ├── qwen_14b.yaml
    │   │   ├── qwen_32b.yaml
    │   │   ├── qwen_72b.yaml
    │   │   └── qwen_7b.yaml
    │   └── task
    │       ├── qa_quality.yaml
    │       ├── qa_squad.yaml
    │       ├── qa_toefl.yaml
    │       ├── ruler_cwe.yaml
    │       ├── ruler_niah.yaml
    │       ├── ruler_vt.yaml
    │       ├── story_filtering.yaml
    │       ├── story_multihop.yaml
    │       └── story_retrieval.yaml
    ├── evaluation.py
    ├── main.py
    ├── modelling
    │   ├── __init__.py
    │   ├── attention
    │   │   ├── __init__.py
    │   │   ├── abstract_attention.py
    │   │   ├── efficient_decoding.py
    │   │   ├── efficient_prefilling.py
    │   │   ├── handler.py
    │   │   ├── kv_compression.py
    │   │   ├── minference
    │   │   │   ├── __init__.py
    │   │   │   ├── block.py
    │   │   │   ├── csrc
    │   │   │   │   ├── kernels.cpp
    │   │   │   │   └── vertical_slash_index.cu
    │   │   │   └── vertical_and_slash.py
    │   │   └── registry.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── abstract_model.py
    │   │   └── vllm_model.py
    │   └── tokenizer.py
    ├── prediction.py
    ├── preparation.py
    ├── tasks
    │   ├── __init__.py
    │   ├── abstract_prompt.py
    │   ├── abstract_sample.py
    │   ├── abstract_task.py
    │   ├── qa
    │   │   ├── data
    │   │   │   ├── quality.jsonl
    │   │   │   ├── squad.json
    │   │   │   └── toeflqa.jsonl
    │   │   ├── qa_data.py
    │   │   ├── qa_task.py
    │   │   └── qa_utils.py
    │   ├── registry.py
    │   ├── ruler
    │   │   ├── cwe.py
    │   │   ├── niah.py
    │   │   └── vt.py
    │   └── story
    │       ├── filtering.py
    │       ├── multihop.py
    │       ├── narrative.py
    │       ├── retrieval.py
    │       └── templates.py
    └── utils
        ├── __init__.py
        ├── checks.py
        ├── data.py
        ├── general.py
        └── globals.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | __pycache__/
3 | .vscode/
4 | .venv
5 | .eggs
6 | *.egg-info
7 | *.log
8 | *.so
9 | build
10 | .DS_Store
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## TL;DR
6 |
7 | This repository contains the official implementation for the paper "[The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs](https://arxiv.org/abs/2504.17768)". We perform a large-scale empirical evaluation of *training-free* sparse attention methods in LLMs (7B to 72B parameters) on long sequences (16K to 128K tokens) across diverse tasks.
8 |
9 | **Key Findings:**
10 | 1. **IsoFLOPS:** For very long sequences, larger, highly sparse models are preferable to smaller, dense ones under a fixed compute budget.
11 | 2. **Sparsity Limits:** Higher sparsity can be tolerated during decoding than prefilling while preserving accuracy (correlating with model size), but even moderate sparsity often degrades performance on at least one task.
12 | 3. **No Universal Best:** The optimal sparse attention strategy depends on the task and inference phase; no single method excels everywhere.
13 | 4. **Scaling Laws:** We introduce and validate scaling laws specific to sparse attention, suggesting our findings generalize.
14 |
15 | **Conclusion:** Sparse attention is a key tool for long-context LLMs but requires careful evaluation of accuracy-efficiency trade-offs for specific applications. This codebase enables reproducing our experiments and further research.
16 |
17 | ## Setup
18 |
19 | Follow these steps to set up the environment and prepare for running experiments:
20 |
21 | 1. **Create Virtual Environment and Install Dependencies:**
22 | Set up a dedicated Python environment and install the required packages, including compiling custom CUDA kernels.
23 |
24 | ```bash
25 | # Create a virtual environment using Python 3.11
26 | python3.11 -m venv .venv
27 |
28 | # Activate the virtual environment
29 | source .venv/bin/activate
30 |
31 | # Upgrade pip and install essential build/utility tools
32 | pip install --no-cache-dir --upgrade pip setuptools wheel psutil ninja
33 |
34 | # Install PyTorch
35 | pip install --no-cache-dir torch==2.5.1
36 |
37 | # Install the sparse_frontier project in editable mode
38 | pip install --no-cache-dir -e .
39 |
40 | # Compile custom CUDA kernels (for MInference attention)
41 | # Adjust MAX_JOBS based on your system core count for faster compilation
42 | MAX_JOBS=8 python compile.py build_ext --inplace --build-lib ./sparse_frontier/modelling/attention/minference
43 | ```
44 |
45 | 2. **Configure Paths:**
46 | Modify the default configuration file to specify where data, results, and checkpoints should be stored on your system.
47 |
48 | * Edit the `paths` section in `configs/default.yaml`.
49 |
50 | 3. **Download Model Checkpoints:**
51 | Obtain the pre-trained model weights you intend to evaluate from Hugging Face Hub.
52 |
53 |    * Ensure the final directory structure for the downloaded checkpoints matches the format expected by the corresponding model configuration file (e.g., as defined in `configs/model/qwen_7b.yaml`). The `model.path` variable in these configs should point to the directory containing the model files. An example download command is sketched below.
54 |
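For example, the Qwen2.5 checkpoints referenced by the model configs can be fetched with the Hugging Face CLI (one option among many; the target directory below assumes the default `paths.checkpoints` value from `configs/default.yaml`):

```bash
pip install -U "huggingface_hub[cli]"

# Download Qwen2.5-7B-Instruct into the directory expected by configs/model/qwen_7b.yaml
huggingface-cli download Qwen/Qwen2.5-7B-Instruct \
    --local-dir /writeable/checkpoints/Qwen2.5-7B-Instruct
```
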
55 | ## Where should I look if I want to:
56 |
57 | ### Reproduce your experiments
58 |
59 | Experiments are launched using the main script `sparse_frontier.main`. Configuration is managed via [Hydra](https://hydra.cc/), allowing parameters to be set in YAML files (primarily `configs/default.yaml`) and overridden directly from the command line.
60 |
61 | The execution pipeline typically involves three stages, controlled by the `mode` parameter (defaulting to `all`):
62 | 1. **Preparation (`preparation.py`):** Generates and saves task-specific data based on the selected `task` configuration. Tasks are defined in `sparse_frontier/tasks/` (inheriting from `AbstractTask` and `AbstractSample`) and registered in `sparse_frontier/tasks/registry.py`.
63 | 2. **Prediction (`prediction.py`):** Runs the specified `model` with the chosen `attention` mechanism on the prepared data, saving the model outputs. Attention mechanisms are defined in `sparse_frontier/modelling/attention/` and registered in `sparse_frontier/modelling/attention/registry.py`.
64 | 3. **Evaluation (`evaluation.py`):** Compares the predictions against the gold answers using the task's specific evaluation logic and saves the final metrics.
65 |
66 | For example, to run the `ruler_niah` task using the `qwen_7b` model configuration with standard `dense` attention on 4 samples using 2 GPUs:
67 |
68 | ```bash
69 | python -m sparse_frontier.main task=ruler_niah model=qwen_7b attention=dense samples=4 gpus=2
70 | ```
71 |
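The pipeline stages can also be run individually through the `mode` parameter (`prep`, `pred`, `eval`, or the default `all`; see `main.py`). For instance, to re-run only the evaluation stage once predictions already exist:

```bash
python -m sparse_frontier.main mode=eval task=ruler_niah model=qwen_7b attention=dense samples=4
```
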
72 | ### Understand how you modified vLLM to test my own Sparse Attention
73 |
74 | We integrate custom sparse attention mechanisms by intercepting and modifying vLLM's standard attention execution flow. Here's a breakdown of the key components involved:
75 |
76 | 1. **Patching vLLM's Attention:** We replace vLLM's default `FlashAttentionImpl.forward` method with our custom function, `vllm_patched_forward` (defined in `sparse_frontier/modelling/models/vllm_model.py`). This function serves as the entry point for our custom attention logic within the vLLM generation loop.
77 |
78 | 2. **Centralized Handling:** The `vllm_patched_forward` function delegates the core processing to an `AttentionHandler` instance (from `sparse_frontier/modelling/attention/handler.py`). This handler manages layer-specific state (like token counts per head) and differentiates between the prefill and decoding phases of generation.
79 |
80 | 3. **Abstract Attention Interface:** The actual attention computation logic for different patterns is encapsulated in classes that inherit from `AbstractAttention` (defined in `sparse_frontier/modelling/attention/abstract_attention.py`). The `AttentionHandler` retrieves the currently configured attention implementation using `get_attention()` (from `sparse_frontier/modelling/attention/registry.py`).
81 |
82 | 4. **Implementing a Custom Pattern:** To introduce a new sparse attention mechanism (a minimal sketch follows this list):
83 | * Create a new class inheriting from `AbstractAttention`.
84 | * Implement the necessary methods based on your pattern's requirements:
85 | * `__call__(self, queries, keys, values, layer_idx)`: Implement the attention computation logic for the prefill phase. The default implementation uses standard FlashAttention.
86 | * `decode(self, query, keys, values, k_cache, v_cache, cache_seqlens, output, layer_idx)`: Implement the attention computation for the single-token decoding phase, typically involving interaction with the KV cache. The default uses `flash_attn_with_kvcache`. Specific methods like Quest (`efficient_decoding.py`) implement custom logic (e.g., page selection).
87 | * `kv_compress(self, queries, keys, values)`: (Optional) Implement logic to compress or select keys and values *after* the prefill computation, before they are written to the KV cache by `update_kv_cache` in `handler.py`. See `SnapKVCompression` (`kv_compression.py`) for an example. It should return the processed keys, values, and the resulting sequence lengths per head.
88 |
89 | 5. **Registration:** Add your new class to the `ATTENTION_REGISTRY` in `sparse_frontier/modelling/attention/registry.py`. This allows selecting your custom attention mechanism through the experiment configuration files.
90 |
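As an illustration, here is a minimal sketch of such a pattern: dense prefill plus a simple "keep only the most recent tokens" KV compression. It follows the `AbstractAttention` interface described above, but the class name `RecentTokensAttention`, its `token_capacity` argument, and the registry entry name are hypothetical, so adapt them to the actual code in `abstract_attention.py` and `registry.py`.

```python
# Hypothetical sketch of a custom pattern -- not part of the repository.
import torch

from sparse_frontier.modelling.attention.abstract_attention import (
    AbstractAttention,
    AttentionUtils,
)


class RecentTokensAttention(AbstractAttention):
    """Dense prefill, then keep only the most recent tokens in the KV cache."""

    def __init__(self, token_capacity: int = 1024):
        super().__init__()
        self.token_capacity = token_capacity

    def __call__(self, queries, keys, values, layer_idx):
        # Prefill stays dense; we only record the sparsity that compression will
        # induce, mirroring how the built-in patterns report per-layer statistics.
        kept = min(self.token_capacity, keys.shape[2])
        sparsity = 1.0 - kept / keys.shape[2]
        self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device))
        return AttentionUtils.flash_attention(queries, keys, values)

    def kv_compress(self, queries, keys, values):
        # keys/values arrive as [num_tokens, num_kv_heads, head_size]; return
        # [num_kv_heads, kept_tokens, head_size] plus per-head sequence lengths.
        num_tokens, num_kv_heads, _ = keys.shape
        kept = min(self.token_capacity, num_tokens)
        keys_t = keys[-kept:].transpose(0, 1)
        values_t = values[-kept:].transpose(0, 1)
        seq_lens = torch.full((num_kv_heads,), kept, device=keys.device, dtype=torch.long)
        return keys_t, values_t, seq_lens
```

The class would then be added to `ATTENTION_REGISTRY` in `sparse_frontier/modelling/attention/registry.py` (e.g. under a name such as `recent_tokens`) and exposed through a new YAML file in `configs/attention/`, mirroring existing configs like `quest.yaml`.
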
91 | ### Understand or Create Experimental Datasets
92 |
93 | Experimental data generation is handled by task-specific modules located in `sparse_frontier/tasks/`. Each task implements `AbstractTask` and `AbstractSample` subclasses (defined in `sparse_frontier/tasks/abstract_*.py`) to define data loading, preprocessing, and the formatting of individual input prompts. Tasks are registered in `sparse_frontier/tasks/registry.py` and selected via configuration (e.g., `task=your_task_name`). The `preparation.py` script orchestrates the generation process based on the configuration, saving the formatted samples. See existing tasks like `QATask` (`qa_task.py`) or the Story task (`narrative.py`, `templates.py`) for implementation examples.
94 |
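As a rough orientation (not actual repository code), a new task ultimately provides two things: sample generation, via the `AbstractTask`/`AbstractSample` subclasses, and an `evaluate` entry point that turns merged data-plus-prediction records into a metrics dict, which `evaluation.py` invokes as `TASK_REGISTRY[cfg.task.name].evaluate(examples)`. The sketch below covers only the evaluation side; the field names `pred` and `gold_answer` are hypothetical placeholders, so consult `abstract_task.py` and an existing task such as `qa_task.py` for the real interface.

```python
# Hypothetical sketch -- the real base classes live in sparse_frontier/tasks/.
from typing import Dict, List


class MyToyTask:  # in practice this would inherit from AbstractTask
    @staticmethod
    def evaluate(examples: List[dict]) -> Dict[str, float]:
        # Each example is a merged record of task data and model prediction
        # (see merge_data_and_predictions in evaluation.py). The field names
        # below are placeholders for whatever the task's samples define.
        correct = sum(1 for ex in examples if ex["gold_answer"] in ex["pred"])
        return {"accuracy": correct / len(examples)}
```

The task would additionally be registered in `sparse_frontier/tasks/registry.py` (the `TASK_REGISTRY` used above) and selected with `task=<name>` from the command line.
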
95 | ## References
96 |
97 | ### Sparse Attention Patterns
98 |
99 | In this repository, we evaluate 6 sparse attention patterns:
100 |
101 | | Pattern | Source |
102 | |---------|--------|
103 | | **Vertical-Slash / Block-Sparse** | [Microsoft](https://github.com/microsoft/MInference) |
104 | | **FlexPrefill** | [ByteDance-Seed](https://github.com/ByteDance-Seed/FlexPrefill) |
105 | | **SnapKV** | [FasterDecoding](https://github.com/FasterDecoding/SnapKV) |
106 | | **Ada-SnapKV** | [FFY0](https://github.com/FFY0/AdaKV) |
107 | | **Quest** | [MIT-HAN-Lab](https://github.com/mit-han-lab/Quest) |
108 |
109 | We either re-implement these patterns based on the original code or borrow the implementations, including the kernels for Vertical-Slash and Block-Sparse, from MInference.
110 |
111 | ### Evaluation Tasks
112 |
113 | Our evaluation framework includes the following tasks:
114 |
115 | 1. **RULER Tasks**: Re-implementation of NIAH, VT, and CWE tasks from [NVIDIA/RULER](https://github.com/NVIDIA/RULER)
116 |
117 | 2. **QA Tasks**:
118 |    - TOEFL and QuALITY datasets from [LC-VS-RAG](https://github.com/lixinze777/LC_VS_RAG)
119 |    - SQuAD dataset from [NVIDIA/RULER](https://github.com/NVIDIA/RULER)
120 |
121 | 3. **Novel Story Tasks**: Narrative tasks developed specifically for this project.
122 |
123 | ## Cite
124 |
125 | If you find this repository useful, please consider citing the paper:
126 |
127 | ```
128 | @article{nawrot2025sparsefrontier,
129 | title={The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs},
130 | author={Piotr Nawrot and Robert Li and Renjie Huang and Sebastian Ruder and Kelly Marchisio and Edoardo M. Ponti},
131 | year={2025},
132 |       journal={arXiv:2504.17768},
133 | url={https://arxiv.org/abs/2504.17768},
134 | }
135 | ```
136 |
137 | ## Issues:
138 |
139 | If you have any questions, feel free to raise a GitHub issue or contact me directly at: piotr.nawrot@ed.ac.uk
140 |
--------------------------------------------------------------------------------
/assets/photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/assets/photo.png
--------------------------------------------------------------------------------
/assets/pipfreeze.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.3.0
2 | aiohappyeyeballs==2.6.1
3 | aiohttp==3.11.16
4 | aiohttp-cors==0.8.1
5 | aiosignal==1.3.2
6 | airportsdata==20250224
7 | annotated-types==0.7.0
8 | antlr4-python3-runtime==4.9.3
9 | anyio==4.9.0
10 | astor==0.8.1
11 | attrs==25.3.0
12 | blake3==1.0.4
13 | cachetools==5.5.2
14 | certifi==2025.1.31
15 | charset-normalizer==3.4.1
16 | click==8.1.8
17 | cloudpickle==3.1.1
18 | colorful==0.5.6
19 | compressed-tensors==0.8.1
20 | contourpy==1.3.2
21 | cycler==0.12.1
22 | datasets==3.2.0
23 | depyf==0.18.0
24 | dill==0.3.8
25 | diskcache==5.6.3
26 | distlib==0.3.9
27 | distro==1.9.0
28 | einops==0.8.1
29 | fastapi==0.115.12
30 | filelock==3.18.0
31 | flash-attn==2.7.3
32 | fonttools==4.57.0
33 | frozenlist==1.5.0
34 | fsspec==2024.9.0
35 | gguf==0.10.0
36 | gitdb==4.0.12
37 | GitPython==3.1.44
38 | google-api-core==2.24.2
39 | google-auth==2.39.0
40 | googleapis-common-protos==1.70.0
41 | grpcio==1.71.0
42 | h11==0.14.0
43 | httpcore==1.0.8
44 | httptools==0.6.4
45 | httpx==0.28.1
46 | huggingface-hub==0.30.2
47 | hydra-core==1.3.2
48 | idna==3.10
49 | importlib_metadata==8.6.1
50 | interegular==0.3.3
51 | Jinja2==3.1.6
52 | jiter==0.9.0
53 | jsonschema==4.23.0
54 | jsonschema-specifications==2024.10.1
55 | kiwisolver==1.4.8
56 | lark==1.2.2
57 | lm-format-enforcer==0.10.11
58 | MarkupSafe==3.0.2
59 | matplotlib==3.10.0
60 | mistral_common==1.5.4
61 | mpmath==1.3.0
62 | msgpack==1.1.0
63 | msgspec==0.19.0
64 | multidict==6.4.3
65 | multiprocess==0.70.16
66 | nest-asyncio==1.6.0
67 | networkx==3.4.2
68 | ninja==1.11.1.4
69 | numpy==1.26.4
70 | nvidia-cublas-cu12==12.4.5.8
71 | nvidia-cuda-cupti-cu12==12.4.127
72 | nvidia-cuda-nvrtc-cu12==12.4.127
73 | nvidia-cuda-runtime-cu12==12.4.127
74 | nvidia-cudnn-cu12==9.1.0.70
75 | nvidia-cufft-cu12==11.2.1.3
76 | nvidia-curand-cu12==10.3.5.147
77 | nvidia-cusolver-cu12==11.6.1.9
78 | nvidia-cusparse-cu12==12.3.1.170
79 | nvidia-ml-py==12.570.86
80 | nvidia-nccl-cu12==2.21.5
81 | nvidia-nvjitlink-cu12==12.4.127
82 | nvidia-nvtx-cu12==12.4.127
83 | omegaconf==2.3.0
84 | openai==1.74.0
85 | opencensus==0.11.4
86 | opencensus-context==0.1.3
87 | opencv-python-headless==4.11.0.86
88 | outlines==0.1.11
89 | outlines_core==0.1.26
90 | packaging==24.2
91 | pandas==2.2.3
92 | partial-json-parser==0.2.1.1.post5
93 | patsy==1.0.1
94 | pillow==11.2.1
95 | platformdirs==4.3.7
96 | prometheus-fastapi-instrumentator==7.1.0
97 | prometheus_client==0.21.1
98 | propcache==0.3.1
99 | proto-plus==1.26.1
100 | protobuf==6.30.2
101 | psutil==7.0.0
102 | py-cpuinfo==9.0.0
103 | py-spy==0.4.0
104 | pyarrow==19.0.1
105 | pyasn1==0.6.1
106 | pyasn1_modules==0.4.2
107 | pycountry==24.6.1
108 | pydantic==2.11.3
109 | pydantic_core==2.33.1
110 | pyparsing==3.2.3
111 | python-dateutil==2.9.0.post0
112 | python-dotenv==1.1.0
113 | pytz==2025.2
114 | PyYAML==6.0.2
115 | pyzmq==26.4.0
116 | ray==2.44.1
117 | referencing==0.36.2
118 | regex==2024.11.6
119 | requests==2.32.3
120 | rpds-py==0.24.0
121 | rsa==4.9.1
122 | safetensors==0.5.3
123 | scipy==1.15.2
124 | seaborn==0.13.2
125 | sentencepiece==0.2.0
126 | six==1.17.0
127 | smart-open==7.1.0
128 | smmap==5.0.2
129 | sniffio==1.3.1
130 | starlette==0.46.2
131 | statsmodels==0.14.4
132 | sympy==1.13.1
133 | tiktoken==0.9.0
134 | tokenizers==0.21.1
135 | torch==2.5.1
136 | torchvision==0.20.1
137 | tqdm==4.67.1
138 | transformers==4.48.0
139 | triton==3.1.0
140 | typing-inspection==0.4.0
141 | typing_extensions==4.13.2
142 | tzdata==2025.2
143 | urllib3==2.4.0
144 | uvicorn==0.34.1
145 | uvloop==0.21.0
146 | virtualenv==20.30.0
147 | vllm==0.6.6.post1
148 | watchfiles==1.0.5
149 | websockets==15.0.1
150 | wonderwords==2.2.0
151 | wrapt==1.17.2
152 | xformers==0.0.28.post3
153 | xgrammar==0.1.18
154 | xxhash==3.5.0
155 | yarl==1.19.0
156 | zipp==3.21.0
157 |
--------------------------------------------------------------------------------
/compile.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='minference',
6 | ext_modules=[
7 | CUDAExtension(
8 | name='minference',
9 | sources=['./sparse_frontier/modelling/attention/minference/csrc/kernels.cpp', './sparse_frontier/modelling/attention/minference/csrc/vertical_slash_index.cu'],
10 | ),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | },
15 | package_dir={'': 'sparse_frontier/modelling/attention/minference'}
16 | )
17 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="sparse_frontier",
5 | version="0.01",
6 | description="Official implementation of the Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs",
7 | url="https://github.com/PiotrNawrot/sparse-frontier",
8 | packages=find_packages(include=['sparse_frontier', 'sparse_frontier.*']),
9 | entry_points={
10 | 'vllm.general_plugins':[
11 | "swap_vllm_attention = sparse_frontier.modelling.models.vllm_model:swap_vllm_attention"
12 | ]
13 | },
14 | install_requires=[
15 | "transformers==4.48.0",
16 | "datasets==3.2.0",
17 | "flash-attn==2.7.3",
18 | "vllm==0.6.6.post1",
19 | "accelerate==1.3.0",
20 | "hydra-core==1.3.2",
21 | "omegaconf==2.3.0",
22 | "matplotlib==3.10.0",
23 | "wonderwords",
24 | "gitpython",
25 | "matplotlib",
26 | "pandas",
27 | "seaborn",
28 | "statsmodels",
29 | "pyyaml",
30 | "numpy",
31 | ],
32 | python_requires=">=3.10",
33 | )
34 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/ada_snapkv.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: ada_snapkv
5 | args:
6 | token_capacity: 1024
7 | approximation_window: 256
8 | kernel_size: 21
9 | local_window: 128
10 | prefix_tokens: 4
11 | min_head_capacity_ratio: 0.2
12 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/block_sparse.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: block_sparse
5 | args:
6 | chunk_size: 16
7 | top_chunks: 64
8 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/dense.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: dense
5 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/flexprefill.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: flexprefill
5 | args:
6 | alpha: 0.65
7 | approximation_size: 512
8 | min_budget: 512
9 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/quest.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: quest
5 | args:
6 | token_budget: 1024
7 | page_size: 16
8 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/snapkv.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: snapkv
5 | args:
6 | token_capacity: 1024
7 | approximation_window: 256
8 | kernel_size: 21
9 | local_window: 128
10 | prefix_tokens: 4
11 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/attention/vertical_and_slash.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | attention:
4 | name: vertical_and_slash
5 | args:
6 | vertical_size: 512
7 | slash_size: 512
8 | approximation_size: 256
9 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - model: qwen_7b
4 | - attention: dense
5 | - task: ruler_niah
6 |
7 | overwrite: false # overwrite current results
8 | debug: false # changes the directory we create data and results in - this makes overwrite safe to use
9 | mode: "all"
10 | gpus: 1
11 |
12 | samples: 100
13 | max_input_tokens: 8192
14 | max_output_tokens: 1024
15 | kv_cache_block_size: 256
16 | random_seed: 43
17 |
18 | print_eval_results: true
19 |
20 | paths:
21 | results: /writeable/evaluation_results
22 | predictions: /writeable/predictions
23 | data: /writeable/data
24 | debug: /writeable/debug
25 | checkpoints: /writeable/checkpoints
26 |
27 | hydra:
28 | run:
29 | dir: /writeable/evaluation_results/hydraoutputs/${now:%Y-%m-%d_%H-%M-%S}
30 | job:
31 | env_set:
32 | OMP_NUM_THREADS: 1
33 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/model/qwen_14b.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model:
4 | name: qwen_14b
5 | path: ${paths.checkpoints}/Qwen2.5-14B-Instruct
6 | num_q_heads: 40
7 | num_kv_heads: 8
8 | num_layers: 48
9 | hidden_dim: 5120
10 | intermediate_dim: 13824
11 | vocab_size: 152064
12 |
13 | tp: 1
14 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/model/qwen_32b.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model:
4 | name: qwen_32b
5 | path: ${paths.checkpoints}/Qwen2.5-32B-Instruct
6 | num_q_heads: 40
7 | num_kv_heads: 8
8 | num_layers: 64
9 | hidden_dim: 5120
10 | intermediate_dim: 27648
11 | vocab_size: 152064
12 |
13 | tp: 2
14 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/model/qwen_72b.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model:
4 | name: qwen_72b
5 | path: ${paths.checkpoints}/Qwen2.5-72B-Instruct
6 | num_q_heads: 64
7 | num_kv_heads: 8
8 | num_layers: 80
9 | hidden_dim: 8192
10 | intermediate_dim: 29568
11 | vocab_size: 152064
12 |
13 | tp: 4
14 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/model/qwen_7b.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model:
4 | name: qwen_7b
5 | path: ${paths.checkpoints}/Qwen2.5-7B-Instruct
6 | num_q_heads: 28
7 | num_kv_heads: 4
8 | num_layers: 28
9 | hidden_dim: 3584
10 | intermediate_dim: 18944
11 | vocab_size: 152064
12 |
13 | tp: 1
14 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/qa_quality.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: qa_quality
5 | args:
6 | dataset_name: quality
7 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/qa_squad.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: qa_squad
5 | args:
6 | dataset_name: squad
7 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/qa_toefl.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: qa_toefl
5 | args:
6 | dataset_name: toeflqa
7 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/ruler_cwe.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: ruler_cwe
5 | args:
6 | num_common_words: 10
7 | common_word_frequency: 30
8 | rare_word_frequency: 3
9 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/ruler_niah.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: ruler_niah
5 | args:
6 | num_queries: 4
7 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/ruler_vt.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: ruler_vt
5 | args:
6 | num_chains: 8
7 | num_hops: 4
8 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/story_filtering.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: story_filtering
5 | args:
6 | chapters_in_question: 3
7 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/story_multihop.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: story_multihop
5 |
--------------------------------------------------------------------------------
/sparse_frontier/configs/task/story_retrieval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task:
4 | name: story_retrieval
5 | args:
6 | num_queries: 16
7 |
--------------------------------------------------------------------------------
/sparse_frontier/evaluation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import json
3 | from sparse_frontier.utils.general import GlobalSettings
4 | from sparse_frontier.utils.data import get_pred_path, get_data_path, read_jsonl, get_results_path
5 | from sparse_frontier.tasks.registry import TASK_REGISTRY
6 |
7 |
8 | def merge_data_and_predictions(data: List[dict], predictions: List[dict]) -> List[dict]:
9 | """Merge data samples with model predictions based on index.
10 |
11 | Args:
12 | data: List of dictionaries containing input data and gold answers
13 | predictions: List of dictionaries containing model predictions and metrics
14 |
15 | Returns:
16 | List of merged dictionaries with all fields
17 |
18 | Raises:
19 | AssertionError: If indexes don't match between data and predictions
20 | """
21 | # Create index mappings
22 | data_by_index = {item['index']: item for item in data}
23 | pred_by_index = {item['index']: item for item in predictions}
24 |
25 | # Verify all indexes match
26 | data_indexes = set(data_by_index.keys())
27 | pred_indexes = set(pred_by_index.keys())
28 | assert data_indexes == pred_indexes, \
29 | f"Mismatch between data and prediction indexes. Missing from data: {pred_indexes - data_indexes}. " \
30 | f"Missing from predictions: {data_indexes - pred_indexes}"
31 |
32 | # Merge data and predictions
33 | merged = []
34 | for idx in data_indexes:
35 | sample = data_by_index[idx].copy()
36 | sample.update(pred_by_index[idx])
37 | merged.append(sample)
38 |
39 | assert len(merged) == len(data) == len(predictions), \
40 | f"Merged data length {len(merged)} doesn't match input lengths: data={len(data)}, predictions={len(predictions)}"
41 |
42 | return merged
43 |
44 |
45 | def evaluate_task() -> None:
46 | cfg = GlobalSettings.get('cfg')
47 |
48 | results_file = get_results_path()
49 |
50 | # Load and merge data and predictions
51 | data = [x for x in read_jsonl(get_data_path()) if x['index'] < cfg.samples]
52 | predictions = [x for x in read_jsonl(get_pred_path()) if x['index'] < cfg.samples]
53 | examples = merge_data_and_predictions(data, predictions)
54 |
55 | metrics = TASK_REGISTRY[cfg.task.name].evaluate(examples)
56 |
57 | # Add total number of samples evaluated
58 | metrics['total_samples'] = len(examples)
59 |
60 | # Calculate average sparsity
61 | total_sparsity = sum(example['sparsity'] for example in examples)
62 | metrics['average_attention_sparsity'] = total_sparsity / len(examples)
63 |
64 | # Calculate average and max output token length if available
65 | if any('output_tokens_len' in example for example in examples):
66 | total_output_tokens = sum(example['output_tokens_len'] for example in examples if 'output_tokens_len' in example)
67 | num_examples_with_len = sum(1 for example in examples if 'output_tokens_len' in example)
68 | metrics['average_output_tokens'] = total_output_tokens / num_examples_with_len
69 | metrics['max_output_tokens'] = max(example['output_tokens_len'] for example in examples if 'output_tokens_len' in example)
70 |
71 | with open(results_file, 'w', encoding='utf-8') as f:
72 | json.dump(metrics, f, indent=4)
73 |
74 | print(f'Evaluation results for task {cfg.task.name} saved to {results_file}')
75 |
76 | if cfg.print_eval_results:
77 | print(f'Evaluation results for task {cfg.task.name}:')
78 | print(json.dumps(metrics, indent=2))
79 |
--------------------------------------------------------------------------------
/sparse_frontier/main.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from omegaconf import DictConfig
3 |
4 |
5 | def setting_up(cfg):
6 | import os
7 | import random
8 | import numpy as np
9 |
10 | from sparse_frontier.utils import GlobalSettings
11 | from sparse_frontier.utils.data import get_data_path, get_pred_path, get_results_path
12 |
13 | random.seed(cfg.random_seed)
14 | np.random.seed(cfg.random_seed)
15 | # In case of torch seed it's set in vLLM Model class
16 |
17 | if cfg.overwrite:
18 | assert cfg.debug, "Overwrite is only allowed in debug mode"
19 |
20 | GlobalSettings.set('cfg', cfg)
21 |
22 | os.makedirs(os.path.dirname(get_data_path()), exist_ok=True)
23 | os.makedirs(os.path.dirname(get_pred_path()), exist_ok=True)
24 | os.makedirs(os.path.dirname(get_results_path()), exist_ok=True)
25 |
26 |
27 | def run(cfg):
28 | setting_up(cfg)
29 |
30 | from sparse_frontier.utils.checks import prepration_needed, prediction_needed, evaluation_needed
31 |
32 | if cfg.mode in ["prep", "all"]:
33 | if prepration_needed():
34 | from sparse_frontier.preparation import prepare_task
35 | prepare_task()
36 |
37 | if cfg.mode in ["pred", "all"]:
38 | if prediction_needed():
39 | from sparse_frontier.prediction import predict_task
40 | predict_task()
41 |
42 | if cfg.mode in ["eval", "all"]:
43 | if evaluation_needed():
44 | from sparse_frontier.evaluation import evaluate_task
45 | evaluate_task()
46 |
47 | if cfg.mode not in ["prep", "pred", "eval", "all"]:
48 | raise ValueError(f'Invalid mode: {cfg.mode}')
49 |
50 |
51 | @hydra.main(config_path="configs", config_name="default", version_base="1.3")
52 | def main(cfg: DictConfig):
53 | run(cfg)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/__init__.py
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/attention/__init__.py
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/abstract_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from abc import ABC
4 | from vllm.distributed import (
5 | get_tensor_model_parallel_world_size,
6 | tensor_model_parallel_all_gather,
7 | )
8 | from flash_attn import flash_attn_func
9 | from vllm.attention.backends.flash_attn import flash_attn_with_kvcache
10 |
11 |
12 | class AttentionUtils:
13 | @staticmethod
14 | def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
15 | """Compute attention using Flash Attention
16 |
17 | flash_attn_func expects BLHD format, so we need to convert the input tensors to this format.
18 |
19 | Args:
20 | q: Query tensor of shape (batch_size, num_heads, seq_len, head_dim)
21 | k: Key tensor of shape (batch_size, num_kv_heads, seq_len, head_dim)
22 | v: Value tensor of shape (batch_size, num_kv_heads, seq_len, head_dim)
23 |
24 | Returns:
25 | Attention output tensor of shape (batch_size, num_heads, seq_len, head_dim)
26 | """
27 | out = flash_attn_func(
28 | q.transpose(1, 2),
29 | k.transpose(1, 2),
30 | v.transpose(1, 2),
31 | causal=True,
32 | )
33 | return out.transpose(1, 2)
34 |
35 | @staticmethod
36 | def reshape_kv_cache(
37 | kv_cache: torch.Tensor,
38 | target_block_size: int,
39 | max_blocks: int,
40 | ) -> tuple[torch.Tensor, torch.Tensor]:
41 | """Retrieve keys and values from cache for attention computation.
42 |
43 | Args:
44 | kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size]
45 | target_block_size: Target block size for reshaping
46 |
47 | Returns:
48 | k_cache: [num_kv_heads, num_blocks, target_block_size, head_size]
49 | v_cache: [num_kv_heads, num_blocks, target_block_size, head_size]
50 | """
51 | num_blocks, block_size, num_kv_heads, head_size = kv_cache[0].shape
52 | final_num_blocks = min(max_blocks, (num_blocks * block_size) // target_block_size)
53 | left_original_num_blocks = (final_num_blocks * target_block_size) // block_size
54 |
55 | k_cache = kv_cache[0, :left_original_num_blocks, :, :].view(num_kv_heads, final_num_blocks, target_block_size, head_size)
56 | v_cache = kv_cache[1, :left_original_num_blocks, :, :].view(num_kv_heads, final_num_blocks, target_block_size, head_size)
57 |
58 | return k_cache, v_cache
59 |
60 |
61 | class AbstractAttention(ABC):
62 | """Base class for attention implementations (both prefilling and KV compression)"""
63 | def __init__(self):
64 | self.sparsity_statistics = []
65 | self.layer_sparsity_statistics = []
66 | self.block_table = None
67 | self.cache_batch_idx = None
68 |
69 | def reset_sparsity_statistics(self):
70 | """Reset the accumulated sparsity statistics."""
71 | self.sparsity_statistics = []
72 | self.layer_sparsity_statistics = []
73 |
74 | def sync_and_calc_layer_stats(self):
75 | # Ensure we have layer sparsity statistics to process
76 | if not self.layer_sparsity_statistics:
77 | raise AssertionError("Layer sparsity statistics list is empty. Make sure statistics are collected before syncing.")
78 |
79 | layer_sparsity = torch.stack(self.layer_sparsity_statistics).mean(dim=0, keepdim=True)
80 |
81 | if get_tensor_model_parallel_world_size() > 1:
82 | layer_sparsity = tensor_model_parallel_all_gather(layer_sparsity)
83 |
84 | self.sparsity_statistics.append(layer_sparsity.mean().item())
85 | self.layer_sparsity_statistics = []
86 |
87 | def calculate_sparsity(self) -> float:
88 | return sum(self.sparsity_statistics) / len(self.sparsity_statistics)
89 |
90 | def __call__(
91 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int
92 | ) -> torch.Tensor:
93 | """Compute attention with pattern-specific masking.
94 |
95 | Args:
96 | queries: Query tensor of shape (batch_size, num_heads, seq_len, head_dim)
97 | keys: Key tensor of shape (batch_size, num_kv_heads, seq_len, head_dim)
98 | values: Value tensor of shape (batch_size, num_kv_heads, seq_len, head_dim)
99 | layer_idx: Index of the current transformer layer
100 | Returns:
101 | Attention output tensor of shape (batch_size, num_heads, seq_len, head_dim)
102 | """
103 | return AttentionUtils.flash_attention(queries, keys, values)
104 |
105 | def decode(
106 | self,
107 | query: torch.Tensor, # [1, num_heads, head_dim]
108 | keys: torch.Tensor, # [1, num_kv_heads, head_dim]
109 | values: torch.Tensor, # [1, num_kv_heads, head_dim]
110 | k_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim]
111 | v_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim]
112 | cache_seqlens: torch.Tensor, # [num_heads]
113 | output: torch.Tensor, # [1, num_heads, head_dim]
114 | layer_idx: int,
115 | ) -> torch.Tensor:
116 | """Compute attention during decoding phase using flash_attn_with_kvcache.
117 |
118 | Args:
119 | query: Query tensor for a single token [1, num_heads, head_dim]
120 | keys: Key tensor for the current token [1, num_kv_heads, head_dim]
121 | values: Value tensor for the current token [1, num_kv_heads, head_dim]
122 | k_cache: Key cache tensor [num_kv_heads, num_blocks, block_size, head_dim]
123 | v_cache: Value cache tensor [num_kv_heads, num_blocks, block_size, head_dim]
124 | cache_seqlens: Tensor of sequence lengths per head [num_heads]
125 | output: Output tensor to store results [1, num_heads, head_dim]
126 | layer_idx: Index of the current transformer layer
127 | """
128 | _, num_q_heads, _ = query.shape
129 | num_kv_heads, num_blocks, block_size, head_size = k_cache.shape
130 |
131 | if self.block_table is None:
132 | block_indices = torch.arange(num_blocks * num_kv_heads, device=query.device, dtype=torch.int32).reshape(num_kv_heads, num_blocks)
133 | block_indices = block_indices.repeat(1, num_q_heads // num_kv_heads)
134 | self.block_table = block_indices.reshape(num_q_heads, num_blocks)
135 |
136 | flash_attn_with_kvcache(
137 | q=query.squeeze(0).unsqueeze(1).unsqueeze(1),
138 | k_cache=k_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size),
139 | v_cache=v_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size),
140 | block_table=self.block_table,
141 | cache_seqlens=cache_seqlens,
142 | causal=True,
143 | out=output.squeeze(0).unsqueeze(1).unsqueeze(1),
144 | )
145 |
146 | def kv_compress(
147 | self,
148 | queries: torch.Tensor, # [num_tokens, num_heads, head_dim]
149 | keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
150 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
151 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152 | """Compress KV cache after prefilling (default: no compression)
153 |
154 | Returns:
155 | tuple: (compressed_keys, compressed_values, seq_lens) where:
156 | - compressed_keys: [num_kv_heads, max_seq_len, head_size]
157 | - compressed_values: [num_kv_heads, max_seq_len, head_size]
158 | - seq_lens: [num_kv_heads] tensor with actual sequence length per head
159 | """
160 | # Default implementation: no compression, all tokens kept
161 | seq_lens = torch.full((keys.size(1),), keys.size(0), device=keys.device, dtype=torch.long)
162 | # Transpose keys and values to match the expected output shape
163 | keys_t = keys.transpose(0, 1) # [num_kv_heads, num_tokens, head_size]
164 | values_t = values.transpose(0, 1) # [num_kv_heads, num_tokens, head_size]
165 | return keys_t, values_t, seq_lens
166 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/efficient_decoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Tuple, List
3 | from .abstract_attention import AbstractAttention
4 | from .abstract_attention import AttentionUtils
5 | from vllm.attention.backends.flash_attn import flash_attn_with_kvcache
6 |
7 |
8 | def _update_last_page(
9 | page_reps: torch.Tensor,
10 | keys: torch.Tensor, # [1, num_heads, head_dim]
11 | cache_seqlens: int,
12 | page_size: int,
13 | ):
14 | """Update representations of the page containing the current token.
15 |
16 | Args:
17 | page_reps: Page representations tensor
18 | keys: Key tensor for the current token
19 | cache_seqlens: Cache sequence length as an integer
20 | page_size: Size of each page in the KV cache
21 | """
22 | current_page_idx = (cache_seqlens - 1) // page_size
23 |
24 | page_reps[current_page_idx, 0] = torch.minimum(
25 | page_reps[current_page_idx, 0],
26 | keys.squeeze(0)
27 | )
28 |
29 | page_reps[current_page_idx, 1] = torch.maximum(
30 | page_reps[current_page_idx, 1],
31 | keys.squeeze(0)
32 | )
33 |
34 |
35 | def _select_pages(
36 | page_reps: torch.Tensor,
37 | query: torch.Tensor,
38 | cache_seqlens: int,
39 | page_size: int,
40 | page_budget: int,
41 | offsets: torch.Tensor,
42 | ) -> Tuple[torch.Tensor, int]:
43 | """Select most relevant pages based on query-page similarity.
44 |
45 | Args:
46 | page_reps: Page representations tensor
47 | query: Query tensor for the current token
48 | cache_seqlens: Cache sequence length as an integer
49 | page_size: Size of each page in the KV cache
50 | page_budget: Maximum number of pages to select
51 | offsets: Offsets for each KV head
52 |
53 | Returns:
54 | Tuple of (selected page indices, new cache sequence length)
55 | """
56 | current_page_idx = (cache_seqlens - 1) // page_size
57 | assert current_page_idx > 0 and current_page_idx >= page_budget + 1
58 |
59 | # All models have GQA
60 | query_squeezed = query.squeeze(0)
61 | group_size = query_squeezed.size(0) // page_reps.size(2)
62 | page_reps = page_reps[:current_page_idx].repeat_interleave(group_size, dim=2)
63 |
64 | scores = torch.einsum(
65 | 'hd,prhd->hprd',
66 | query_squeezed,
67 | page_reps
68 | )
69 |
70 |         scores = scores.max(dim=2).values # [num_heads, num_pages, head_dim]
71 | scores = scores.sum(dim=-1)
72 |
73 | _, indices = torch.topk(
74 | scores,
75 | k=page_budget + 1,
76 | dim=1,
77 | sorted=True,
78 | )
79 |
80 | indices[:, -1] = current_page_idx
81 | new_cache_seqlens = cache_seqlens - (current_page_idx - page_budget) * page_size
82 | new_cache_seqlens = torch.full((query.shape[1],), new_cache_seqlens, device=query.device, dtype=torch.int32)
83 |
84 | active_pages = indices.int() + offsets
85 | return active_pages, new_cache_seqlens
86 |
87 |
88 | _select_pages = torch.compile(_select_pages)
89 | _update_last_page = torch.jit.script(_update_last_page)
90 |
91 |
92 | class QuestAttention(AbstractAttention):
93 | """Quest attention for efficient decoding with dynamic page selection.
94 |
95 | Quest maintains min and max representations for each page of KV cache and uses
96 | them to dynamically select the most relevant pages during decoding.
97 | """
98 |
99 | def __init__(
100 | self,
101 | token_budget: int,
102 | page_size: int,
103 | max_input_tokens: int,
104 | max_output_tokens: int,
105 | num_layers: int,
106 | ):
107 | """Initialize Quest attention.
108 |
109 | Args:
110 | token_budget: Maximum number of tokens to attend to
111 | page_size: Size of each page in the KV cache
112 | max_input_tokens: Maximum input token length (from config)
113 | max_output_tokens: Maximum output token length (from config)
114 | num_layers: Number of transformer layers (from model config)
115 | """
116 | super().__init__()
117 | self.token_budget = token_budget
118 | self.page_size = page_size
119 | self.page_budget = token_budget // page_size
120 | assert token_budget % page_size == 0, "Token budget must be divisible by page size"
121 |
122 | self.max_pages = ((max_input_tokens + max_output_tokens) + page_size - 1) // page_size
123 | self.num_layers = num_layers
124 |
125 | # Page representations per layer
126 | self.page_reps_per_layer: List[Optional[torch.Tensor]] = [None] * num_layers
127 | self.offsets = None
128 |
129 | def _init_page_reps(
130 | self,
131 | keys: torch.Tensor,
132 | layer_idx: int = 0,
133 | ):
134 | """Initialize page representations during prefilling for a specific layer."""
135 | _, num_heads, seq_len, head_dim = keys.shape
136 | keys = keys.squeeze(0).transpose(0, 1) # [seq_len, num_heads, head_dim]
137 |
138 | num_pages = (seq_len + self.page_size - 1) // self.page_size
139 |
140 | if self.page_reps_per_layer[layer_idx] is None:
141 | self.page_reps_per_layer[layer_idx] = torch.zeros(
142 | self.max_pages, 2, num_heads, head_dim,
143 | device=keys.device, dtype=keys.dtype
144 | )
145 |
146 | self.page_reps_per_layer[layer_idx][:, 0] = float('inf')
147 | self.page_reps_per_layer[layer_idx][:, 1] = float('-inf')
148 |
149 | if seq_len % self.page_size != 0:
150 | complete_pages = seq_len // self.page_size
151 | complete_keys = keys[:complete_pages * self.page_size].view(complete_pages, self.page_size, num_heads, head_dim)
152 | complete_min = complete_keys.amin(dim=1)
153 | complete_max = complete_keys.amax(dim=1)
154 | self.page_reps_per_layer[layer_idx][:complete_pages] = torch.stack([complete_min, complete_max], dim=1)
155 |
156 | remainder_keys = keys[complete_pages * self.page_size:]
157 |
158 | self.page_reps_per_layer[layer_idx][complete_pages] = torch.stack([
159 | remainder_keys.amin(dim=0),
160 | remainder_keys.amax(dim=0),
161 | ], dim=0)
162 | else:
163 | keys = keys.view(num_pages, self.page_size, num_heads, head_dim)
164 | self.page_reps_per_layer[layer_idx][:num_pages] = torch.stack([
165 | keys.amin(dim=1),
166 | keys.amax(dim=1),
167 | ], dim=1)
168 |
169 | def __call__(
170 | self,
171 | queries: torch.Tensor, # [batch_size, num_heads, seq_len, head_dim]
172 | keys: torch.Tensor, # [batch_size, num_kv_heads, seq_len, head_dim]
173 | values: torch.Tensor, # [batch_size, num_kv_heads, seq_len, head_dim]
174 | layer_idx: int = 0,
175 | ) -> torch.Tensor:
176 | """
177 | Args:
178 | queries: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
179 | keys: Key tensor of shape [batch_size, num_kv_heads, seq_len, head_dim]
180 | values: Value tensor of shape [batch_size, num_kv_heads, seq_len, head_dim]
181 | layer_idx: Index of the current transformer layer
182 |
183 | Returns:
184 | Attention output tensor of shape [batch_size, num_heads, seq_len, head_dim]
185 | """
186 | self._init_page_reps(keys, layer_idx)
187 |
188 | sparsity = 1.0 - self.token_budget / queries.shape[2]
189 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device))
190 |
191 | return AttentionUtils.flash_attention(queries, keys, values)
192 |
193 | def decode(
194 | self,
195 | query: torch.Tensor, # [1, num_heads, head_dim]
196 | keys: torch.Tensor, # [1, num_kv_heads, head_dim]
197 | values: torch.Tensor, # [1, num_kv_heads, head_dim]
198 | k_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim]
199 | v_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim]
200 | cache_seqlens: torch.Tensor, # [num_heads]
201 | output: torch.Tensor, # [1, num_heads, head_dim]
202 | layer_idx: int = 0,
203 | ) -> torch.Tensor:
204 | """Compute attention during decoding with dynamic page selection for a specific layer.
205 |
206 | Instead of retrieving the selected KV cache, we pass the entire KV cache
207 | and the selected page indices to flash_attn_with_kvcache.
208 |
209 | Args:
210 | query: Query tensor for a single token [1, num_heads, head_dim]
211 | keys: Key tensor for the current token [1, num_kv_heads, head_dim]
212 | values: Value tensor for the current token [1, num_kv_heads, head_dim]
213 | k_cache: Key cache tensor [num_kv_heads, num_blocks, block_size, head_dim]
214 | v_cache: Value cache tensor [num_kv_heads, num_blocks, block_size, head_dim]
215 | cache_seqlens: Tensor of sequence lengths per head [num_heads]
216 | output: Output tensor to store results [1, num_heads, head_dim]
217 | layer_idx: Index of the current transformer layer
218 |
219 | Returns:
220 | Attention output tensor of shape [1, num_heads, head_dim]
221 | """
222 | num_kv_heads, num_blocks, block_size, head_size = k_cache.shape
223 | _, num_q_heads, _ = query.shape
224 | cache_seqlens_int = cache_seqlens[0].item() # Convert to integer for page selection
225 |
226 | if self.offsets is None:
227 | offsets = torch.arange(num_kv_heads, device=query.device, dtype=torch.int32)
228 | offsets = offsets.repeat_interleave(num_q_heads // num_kv_heads)
229 | offsets = offsets.unsqueeze(1) * num_blocks
230 | self.offsets = offsets
231 |
232 | _update_last_page(
233 | page_reps=self.page_reps_per_layer[layer_idx],
234 | keys=keys,
235 | cache_seqlens=cache_seqlens_int,
236 | page_size=self.page_size
237 | )
238 |
239 | active_pages, new_cache_seqlens = _select_pages(
240 | page_reps=self.page_reps_per_layer[layer_idx],
241 | query=query,
242 | cache_seqlens=cache_seqlens_int,
243 | page_size=self.page_size,
244 | page_budget=self.page_budget,
245 | offsets=self.offsets,
246 | )
247 |
248 | flash_attn_with_kvcache(
249 | q=query.squeeze(0).unsqueeze(1).unsqueeze(1),
250 | k_cache=k_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size),
251 | v_cache=v_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size),
252 | block_table=active_pages,
253 | cache_seqlens=new_cache_seqlens,
254 | causal=True,
255 | out=output.squeeze(0).unsqueeze(1).unsqueeze(1),
256 | )
257 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/efficient_prefilling.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .abstract_attention import AbstractAttention
4 | from .minference import (
5 | block_sparse_attention,
6 | vertical_and_slash_kernel,
7 | vertical_slash_sparse_attention,
8 | sum_over_diagonals,
9 | )
10 | from .abstract_attention import AttentionUtils
11 |
12 |
13 | class DenseAttention(AbstractAttention):
14 | """Standard dense attention with causal masking."""
15 |
16 | def __init__(self):
17 | super().__init__()
18 |
19 | def __call__(
20 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int
21 | ) -> torch.Tensor:
22 | self.layer_sparsity_statistics.append(torch.tensor(0.0, device=queries.device))
23 | return AttentionUtils.flash_attention(queries, keys, values)
24 |
25 |
26 | class BlockSparseAttentionMInference(AbstractAttention):
27 | """Block-sparse attention with chunk-level relevance and local windows."""
28 |
29 | def __init__(self, chunk_size: int, top_chunks: int):
30 | super().__init__()
31 | self.chunk_size = chunk_size
32 | self.top_chunks = top_chunks
33 |
34 | assert chunk_size >= 8, "Recommended chunk size is >= 8"
35 |         assert top_chunks >= 3, "Must select at least three top chunks"
36 |
37 | @staticmethod
38 | def _calculate_sparsity(seq_len: int, chunk_size: int, top_chunks: int) -> float:
39 | """Calculate sparsity ratio based on selected blocks vs total possible blocks.
40 |
41 | For autoregressive attention:
42 | - First query block picks 1 key block
43 | - First top_chunks query blocks pick max possible key blocks
44 | - Remaining query blocks pick top_chunks key blocks each
45 |
46 | Args:
47 | seq_len: Length of input sequence
48 | chunk_size: Size of each attention block
49 | top_chunks: Number of chunks to select per query
50 |
51 | Returns:
52 | Sparsity ratio between 0 and 1
53 | """
54 | num_blocks = (seq_len + chunk_size - 1) // chunk_size
55 |
56 | total_blocks = num_blocks * (num_blocks + 1) // 2
57 |
58 | selected_blocks = top_chunks * (top_chunks + 1) // 2
59 | selected_blocks += (num_blocks - top_chunks) * top_chunks
60 |
61 | return 1.0 - (selected_blocks / total_blocks)
62 |
63 | @staticmethod
64 | def _get_blocks_for_sparsity(seq_len: int, chunk_size: int, target_sparsity: float) -> int:
65 | """Calculate number of blocks needed to achieve desired sparsity level.
66 |
67 | Uses binary search to find the number of blocks that gives sparsity closest
68 | to the target. The relationship between blocks and sparsity is monotonic.
69 |
70 | Args:
71 | seq_len: Length of input sequence
72 | chunk_size: Size of each attention block
73 | target_sparsity: Desired sparsity ratio between 0 and 1
74 |
75 | Returns:
76 | Number of blocks to select per query to achieve target sparsity
77 | """
78 | num_blocks = (seq_len + chunk_size - 1) // chunk_size
79 |
80 | # Binary search for number of blocks
81 |         left, right = 3, num_blocks # Minimum 3 blocks needed
82 | best_blocks = 3
83 | best_diff = float('inf')
84 |
85 | while left <= right:
86 | mid = (left + right) // 2
87 | sparsity = BlockSparseAttentionMInference._calculate_sparsity(seq_len, chunk_size, mid)
88 |
89 | diff = abs(sparsity - target_sparsity)
90 | if diff < best_diff:
91 | best_diff = diff
92 | best_blocks = mid
93 |
94 | if sparsity < target_sparsity:
95 | right = mid - 1
96 | else:
97 | left = mid + 1
98 |
99 | return best_blocks
100 |
101 | def __call__(
102 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int
103 | ) -> torch.Tensor:
104 | sparsity = torch.tensor(self._calculate_sparsity(queries.shape[2], self.chunk_size, self.top_chunks), device=queries.device)
105 | self.layer_sparsity_statistics.append(sparsity)
106 | return block_sparse_attention(queries, keys, values, self.top_chunks, self.chunk_size, self.chunk_size)
107 |
108 |
109 | class VerticalAndSlashAttentionMInference(AbstractAttention):
110 | """Combines vertical and diagonal patterns for efficient sparse attention.
111 |
112 | Implements the Vertical and Slash attention mechanism that selects important tokens based on:
113 | 1) Vertical patterns - Top-k tokens that receive high attention across all queries
114 | 2) Diagonal patterns - Diagonal stripes that capture local dependencies
115 | """
116 |
117 | def __init__(
118 | self,
119 | vertical_size: int = 64,
120 | slash_size: int = 128,
121 | approximation_size: int = 64,
122 | ):
123 | """Initialize Vertical and Slash attention.
124 |
125 | Args:
126 | vertical_size (int): Number of vertical tokens to select
127 | slash_size (int): Number of diagonal stripes to select
128 | """
129 | super().__init__()
130 | self.vertical_size = vertical_size
131 | self.slash_size = slash_size
132 | self.approximation_size = approximation_size
133 |
134 | def __call__(
135 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int
136 | ) -> torch.Tensor:
137 | attn_output, sparsity = vertical_and_slash_kernel(
138 | queries,
139 | keys,
140 | values,
141 | vertical_size=min(self.vertical_size, queries.shape[2]),
142 | slash_size=min(self.slash_size, queries.shape[2]),
143 | last_q=min(self.approximation_size, queries.shape[2]),
144 | )
145 | self.layer_sparsity_statistics.append(sparsity)
146 | return attn_output
147 |
148 |
149 | class FlexPrefill(AbstractAttention):
150 | def __init__(
151 | self,
152 | alpha: float = 0.9,
153 | approximation_size: int = 512,
154 | min_budget: int = 256,
155 | ):
156 | super().__init__()
157 | self.alpha = alpha
158 | self.approximation_size = approximation_size
159 | self.min_budget = min_budget
160 | self.causal_mask = None
161 |
162 | @staticmethod
163 | def score_cover_topk(x: torch.Tensor, score: float):
164 | cumsum_x = torch.cumsum(torch.sort(x, dim=-1, descending=True).values, dim=-1)
165 | topk = torch.sum(cumsum_x <= score, dim=-1) + 1
166 | return topk
167 |
168 | def get_active_blocks(
169 | self, q, k, v
170 | ):
171 | _, _, seq_len, head_dim = q.shape
172 |
173 | # Compute attention scores for last queries
174 | last_q_tokens = q[..., -self.approximation_size:, :] / math.sqrt(head_dim)
175 | qk = torch.einsum('bhik,bhjk->bhij', last_q_tokens, k)
176 |
177 | # Apply causal masking
178 | if self.causal_mask is None:
179 | self.causal_mask = torch.arange(0, self.approximation_size, device=last_q_tokens.device)
180 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :]
181 | self.causal_mask = self.causal_mask[None, None, ...]
182 |
183 | qk[..., -self.approximation_size:] = torch.where(
184 | self.causal_mask,
185 | qk[..., -self.approximation_size:],
186 | torch.tensor(float("-inf"), device=qk.device, dtype=qk.dtype)
187 | )
188 |
189 | # Get attention patterns
190 | scores = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32)
191 |
192 | # Compute vertical patterns
193 | vertical = scores.mean(dim=-2)
194 | vertical_size = max(self.min_budget, self.score_cover_topk(vertical, self.alpha).item())
195 | vertical[..., :4] = float("inf")
196 | vertical_topk = torch.topk(vertical, vertical_size, -1).indices
197 |
198 |         # Compute slash (diagonal) patterns; a fixed local window is always kept
199 | slashes = sum_over_diagonals(scores)[..., :-self.approximation_size + 1] / self.approximation_size
200 | slash_size = max(self.min_budget, self.score_cover_topk(slashes, self.alpha).item())
201 | slashes[..., -64:] = float("inf")
202 | slash = (seq_len - 1) - torch.topk(slashes, slash_size, -1).indices
203 |
204 | return vertical_topk, slash
205 |
206 | def __call__(
207 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int
208 | ) -> torch.Tensor:
209 | assert queries.shape[-1] == keys.shape[-1]
210 |
211 | # Get active blocks for sparse portion
212 | vertical_idx, slash_idx = self.get_active_blocks(queries, keys, values)
213 |
214 | # Calculate sparse attention for non-dense portion
215 | sparse_out, sparsity = vertical_slash_sparse_attention(
216 | queries,
217 | keys,
218 | values,
219 | vertical_idx,
220 | slash_idx,
221 | )
222 |
223 | self.layer_sparsity_statistics.append(sparsity)
224 |
225 | return sparse_out
226 |
--------------------------------------------------------------------------------
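A minimal sketch (not part of the repository) of the budget selection used by `FlexPrefill.score_cover_topk` above: the per-head budget is the smallest k whose top-k softmax scores cover a fraction `alpha` of the attention mass. Tensor shapes below are toy values chosen for illustration.

```python
import torch

def score_cover_topk(x: torch.Tensor, score: float) -> torch.Tensor:
    # Smallest k such that the sum of the k largest entries reaches `score`.
    cumsum_x = torch.cumsum(torch.sort(x, dim=-1, descending=True).values, dim=-1)
    return torch.sum(cumsum_x <= score, dim=-1) + 1

scores = torch.softmax(torch.randn(1, 2, 1024), dim=-1)  # [batch, heads, keys]
budget = score_cover_topk(scores, 0.9)                   # keys needed to cover 90% of the mass
print(budget)                                            # clamped to min_budget in the real code
```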
/sparse_frontier/modelling/attention/handler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .abstract_attention import AttentionUtils
3 |
4 |
5 | def update_kv_cache(
6 | key: torch.Tensor,
7 | value: torch.Tensor,
8 | k_cache: torch.Tensor,
9 | v_cache: torch.Tensor,
10 | is_prefilling: bool,
11 | tokens_per_head: torch.Tensor,
12 | q_heads_per_kv: int,
13 | head_indices: torch.Tensor = None,
14 | queries: torch.Tensor = None,
15 | ):
16 | """Update the KV cache with new key/value tensors.
17 |
18 | Args:
19 | key: [num_tokens, num_kv_heads, head_size]
20 | value: [num_tokens, num_kv_heads, head_size]
21 | k_cache: [num_kv_heads, num_blocks, block_size, head_size]
22 | v_cache: [num_kv_heads, num_blocks, block_size, head_size]
23 | is_prefilling: Whether we're in prefilling phase
24 | tokens_per_head: Tensor tracking token counts per head for current layer
25 | q_heads_per_kv: Number of query heads per key/value head
26 | head_indices: Precomputed indices for heads to avoid recreating tensor each time
27 | queries: [num_tokens, num_heads, head_size], required for compression
28 | """
29 | if is_prefilling:
30 | from sparse_frontier.modelling.attention.registry import get_attention
31 | key, value, seq_lens = get_attention().kv_compress(
32 | queries=queries,
33 | keys=key,
34 | values=value,
35 | )
36 |
37 | num_kv_heads, num_tokens, _ = key.shape
38 |
39 | # View the cache as a contiguous vector by merging block dimensions
40 | k_cache_flat = k_cache.view(num_kv_heads, -1, k_cache.shape[-1])
41 | v_cache_flat = v_cache.view(num_kv_heads, -1, v_cache.shape[-1])
42 |
43 | # Fill the prefix of the flattened cache
44 | k_cache_flat[:, :num_tokens] = key
45 | v_cache_flat[:, :num_tokens] = value
46 |
47 | tokens_per_head += seq_lens.repeat_interleave(q_heads_per_kv)
48 | else:
49 | num_tokens, num_kv_heads, _ = key.shape
50 |
51 | # Decoding case: single token update, use tokens_per_head to place correctly
52 | assert num_tokens == 1, "Decoding should only add one token at a time"
53 |
54 | # View the cache as a contiguous vector by merging block dimensions
55 | k_cache_flat = k_cache.view(num_kv_heads, -1, k_cache.shape[-1])
56 | v_cache_flat = v_cache.view(num_kv_heads, -1, v_cache.shape[-1])
57 |
58 | # Update the cache at the specific positions for all heads at once
59 | k_cache_flat[head_indices, tokens_per_head[::q_heads_per_kv]] = key[0]
60 | v_cache_flat[head_indices, tokens_per_head[::q_heads_per_kv]] = value[0]
61 |
62 | # Update token counts
63 | tokens_per_head += 1
64 |
65 |
66 | class AttentionHandler:
67 | def __init__(
68 | self,
69 | tp_size: int,
70 | model_q_heads: int,
71 | model_kv_heads: int,
72 | model_layers: int,
73 | max_input_tokens: int,
74 | max_output_tokens: int,
75 | block_size: int
76 | ):
77 | """Initialize the attention handler.
78 |
79 | Args:
80 | tp_size: Tensor parallelism size
81 | model_q_heads: Total number of query heads
82 | model_kv_heads: Total number of key/value heads
83 | model_layers: Number of transformer layers
84 | max_input_tokens: Maximum number of input tokens
85 | max_output_tokens: Maximum number of output tokens
86 | block_size: default size of each block in the KV cache
87 | """
88 | self.tp_size = tp_size
89 | self.model_q_heads = model_q_heads
90 | self.model_kv_heads = model_kv_heads
91 | self.q_heads_per_gpu = model_q_heads // tp_size
92 | self.kv_heads_per_gpu = model_kv_heads // tp_size
93 | self.q_heads_per_kv = model_q_heads // model_kv_heads
94 | self.model_layers = model_layers
95 | self.max_seq_len = max_input_tokens + max_output_tokens
96 | self.max_blocks = (self.max_seq_len + block_size - 1) // block_size
97 | self.block_size = block_size
98 |
99 | self.current_layer = 0
100 | self.tokens_per_layer_head = torch.zeros(
101 | (model_layers, self.q_heads_per_gpu),
102 | dtype=torch.int32,
103 | device="cpu"
104 | )
105 | self.head_indices = None
106 |
107 | def _compute_attention_head_by_head(
108 | self,
109 | queries: torch.Tensor,
110 | keys: torch.Tensor,
111 | values: torch.Tensor,
112 | output: torch.Tensor,
113 |     ) -> None:
114 | """Compute attention output by processing each head separately.
115 |
116 | Args:
117 | queries: [num_tokens, num_heads, head_size]
118 | keys: [num_tokens, num_kv_heads, head_size]
119 | values: [num_tokens, num_kv_heads, head_size]
120 | output: [num_tokens, num_heads, head_size]
121 | """
122 | from sparse_frontier.modelling.attention.registry import get_attention
123 |
124 | # Get the attention implementation
125 | attention = get_attention()
126 |
127 | efficient_attention_classes = [
128 | "BlockSparseAttentionMInference",
129 | "VerticalAndSlashAttentionMInference",
130 | "FlexPrefill"
131 | ]
132 |
133 | if attention.__class__.__name__ not in efficient_attention_classes:
134 | output[:] = attention(
135 | queries=queries.transpose(0, 1).unsqueeze(0),
136 | keys=keys.transpose(0, 1).unsqueeze(0),
137 | values=values.transpose(0, 1).unsqueeze(0),
138 | layer_idx=self.current_layer,
139 | ).squeeze(0).transpose(0, 1)
140 | return
141 |
142 | # Otherwise, process head by head
143 | head_size = queries.shape[-1]
144 | q_heads_per_kv = queries.shape[1] // keys.shape[1]
145 |
146 | for head in range(self.q_heads_per_gpu):
147 | kv_head = head // q_heads_per_kv
148 | output[:, head, :] = attention(
149 | queries=queries[:, head, :].view(1, 1, -1, head_size),
150 | keys=keys[:, kv_head, :].view(1, 1, -1, head_size),
151 | values=values[:, kv_head, :].view(1, 1, -1, head_size),
152 | layer_idx=self.current_layer,
153 | ).view(-1, head_size)
154 |
155 | def __call__(
156 | self,
157 | queries: torch.Tensor,
158 | keys: torch.Tensor,
159 | values: torch.Tensor,
160 | kv_cache: torch.Tensor,
161 | output: torch.Tensor,
162 |     ) -> None:
163 | """Process attention for either prefill or decode phase.
164 |
165 | Args:
166 | queries: [num_tokens, num_heads, head_size]
167 | keys: [num_tokens, num_kv_heads, head_size]
168 | values: [num_tokens, num_kv_heads, head_size]
169 | kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size]
170 | output: [num_tokens, num_heads, head_size]
171 |
172 |         Returns:
173 |             None. The attention output is written in place into `output`.
174 | """
175 | num_tokens = queries.shape[0]
176 | is_prefilling = num_tokens > 1
177 |
178 | if is_prefilling:
179 | self.tokens_per_layer_head[self.current_layer, :] = 0
180 |
181 | # Initialize head_indices if not already done
182 | if self.head_indices is None and keys.numel() > 0:
183 | num_kv_heads = keys.shape[1]
184 | self.head_indices = torch.arange(num_kv_heads, device=keys.device)
185 | self.tokens_per_layer_head = self.tokens_per_layer_head.to(keys.device)
186 |
187 | if kv_cache.numel() == 0:
188 | self._compute_attention_head_by_head(
189 | queries=queries,
190 | keys=keys,
191 | values=values,
192 | output=output,
193 | )
194 | self.current_layer = (self.current_layer + 1) % self.model_layers
195 | return
196 |
197 | if is_prefilling:
198 | self._compute_attention_head_by_head(
199 | queries=queries,
200 | keys=keys,
201 | values=values,
202 | output=output,
203 | )
204 |
205 | k_cache, v_cache = AttentionUtils.reshape_kv_cache(kv_cache, self.block_size, self.max_blocks)
206 |
207 | # Then update cache (with compression if prefilling)
208 | update_kv_cache(
209 | keys,
210 | values,
211 | k_cache,
212 | v_cache,
213 | is_prefilling=is_prefilling,
214 | tokens_per_head=self.tokens_per_layer_head[self.current_layer],
215 | q_heads_per_kv=self.q_heads_per_kv,
216 | head_indices=self.head_indices,
217 | queries=queries if is_prefilling else None,
218 | )
219 |
220 | if is_prefilling:
221 |             # By this point every method should have appended its per-layer sparsity statistics
222 | from sparse_frontier.modelling.attention.registry import get_attention
223 | get_attention().sync_and_calc_layer_stats()
224 |
225 | if not is_prefilling:
226 | from sparse_frontier.modelling.attention.registry import get_attention
227 |
228 | get_attention().decode(
229 | query=queries,
230 | keys=keys,
231 | values=values,
232 | k_cache=k_cache,
233 | v_cache=v_cache,
234 | cache_seqlens=self.tokens_per_layer_head[self.current_layer],
235 | output=output,
236 | layer_idx=self.current_layer,
237 | )
238 |
239 | self.current_layer = (self.current_layer + 1) % self.model_layers
240 |
--------------------------------------------------------------------------------
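A minimal sketch (not part of the repository) of the decode-time update in `update_kv_cache` above: the paged cache is viewed as one flat sequence per KV head, and because compression leaves each head with a different length, the new token is written at that head's own offset. Shapes and names here are toy values; the real code tracks lengths per query head and strides them with `[::q_heads_per_kv]`.

```python
import torch

num_kv_heads, num_blocks, block_size, head_size = 4, 8, 16, 64
k_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
k_cache_flat = k_cache.view(num_kv_heads, -1, head_size)   # [heads, blocks * block_size, dim]

tokens_per_head = torch.tensor([10, 7, 12, 9])              # per-head lengths after compressed prefill
head_indices = torch.arange(num_kv_heads)
new_key = torch.randn(1, num_kv_heads, head_size)           # single decoded token

# One write per KV head, each at its own position in the flattened cache.
k_cache_flat[head_indices, tokens_per_head] = new_key[0]
tokens_per_head += 1
```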
/sparse_frontier/modelling/attention/kv_compression.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from .abstract_attention import AbstractAttention
5 | from sparse_frontier.utils.globals import is_vllm_profiling_done
6 |
7 |
8 | class SnapKVCompression(AbstractAttention):
9 | """SnapKV compression for efficient decoding"""
10 | def __init__(
11 | self,
12 | token_capacity: int,
13 | approximation_window: int = 256,
14 | kernel_size: int = 7,
15 | local_window: int = 128,
16 | prefix_tokens: int = 4,
17 | ):
18 | super().__init__()
19 | self.token_capacity = token_capacity
20 | self.approximation_window = approximation_window
21 | self.kernel_size = kernel_size
22 | self.local_window = local_window
23 | self.prefix_tokens = prefix_tokens
24 | self.causal_mask = None
25 |
26 | def kv_compress(
27 | self,
28 | queries: torch.Tensor, # [num_tokens, num_q_heads, head_size]
29 | keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
30 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
31 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
32 | """Compress KV cache using SnapKV approach with GQA support"""
33 | # Use last approximation_window queries to estimate importance
34 | approx_queries = queries[-self.approximation_window:] / math.sqrt(queries.size(-1)) # [approx_window, num_q_heads, head_size]
35 |
36 | num_q_heads = queries.size(1)
37 | num_kv_heads = keys.size(1)
38 | group_size = num_q_heads // num_kv_heads
39 |
40 | # Repeat keys for each query head in the group
41 | # [num_tokens, num_kv_heads, head_size] -> [num_tokens, num_q_heads, head_size]
42 | repeated_keys = keys.repeat_interleave(group_size, dim=1)
43 |
44 | # Now we can calculate attention scores with matching head dimensions
45 | scores = torch.einsum(
46 | 'thd,nhd->htn',
47 | approx_queries,
48 | repeated_keys
49 | )
50 |
51 | # Initialize and cache causal mask if not already created
52 | if self.causal_mask is None:
53 | self.causal_mask = torch.arange(0, self.approximation_window, device=scores.device)
54 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :]
55 | self.causal_mask = self.causal_mask[None, ...] # Add head dimension
56 |
57 | # Apply causal masking and softmax
58 | scores[..., -self.approximation_window:] = torch.where(
59 | self.causal_mask,
60 | scores[..., -self.approximation_window:],
61 | torch.tensor(float("-inf"), device=scores.device, dtype=scores.dtype)
62 | )
63 |
64 | attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(keys.dtype)
65 |
66 | # Reshape attention weights to group query heads [num_kv_heads, group_size, approx_window, num_tokens]
67 | grouped_weights = attn_weights.view(num_kv_heads, group_size, -1, attn_weights.size(-1))
68 | # Average across group dimension [num_kv_heads, approx_window, num_tokens]
69 | token_importance = grouped_weights.mean(dim=1).sum(dim=1) # [num_kv_heads, num_tokens]
70 |
71 | # Apply pooling for smoother selection per head
72 | token_importance = F.avg_pool1d(
73 | token_importance.unsqueeze(1),
74 | kernel_size=self.kernel_size,
75 | padding=self.kernel_size // 2,
76 | stride=1
77 | ).squeeze(1)
78 |
79 | token_importance[..., -self.local_window:] = float('inf')
80 | token_importance[..., :self.prefix_tokens] = float('inf')
81 |
82 | # Select top-k tokens per head
83 | capacity = min(self.token_capacity, keys.size(0))
84 |
85 | assert capacity > self.local_window + self.prefix_tokens, f"Capacity {capacity} must be greater than local_window {self.local_window} + prefix_tokens {self.prefix_tokens}"
86 |
87 | _, indices = torch.topk(token_importance, k=capacity, dim=-1) # [num_kv_heads, capacity]
88 |
89 | # Expand indices for gathering
90 | # [num_kv_heads, capacity] -> [num_kv_heads, capacity, head_size]
91 | expanded_indices = indices.unsqueeze(-1).expand(-1, -1, keys.size(-1))
92 |
93 | compressed_keys = torch.gather(
94 | keys.transpose(0, 1), # [num_kv_heads, num_tokens, head_size]
95 | dim=1,
96 | index=expanded_indices
97 | )
98 |
99 | compressed_values = torch.gather(
100 | values.transpose(0, 1), # [num_kv_heads, num_tokens, head_size]
101 | dim=1,
102 | index=expanded_indices
103 | )
104 |
105 | # Track sparsity - based on fixed capacity ratio
106 | sparsity = 1.0 - (capacity / keys.size(0))
107 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device))
108 |
109 | # Create sequence length tensor (same for all heads)
110 | seq_lens = torch.full((num_kv_heads,), capacity, device=queries.device, dtype=torch.long)
111 |
112 | return compressed_keys, compressed_values, seq_lens
113 |
114 |
115 | class AdaSnapKVCompression(AbstractAttention):
116 | """Adaptive SnapKV compression with non-uniform token distribution across heads"""
117 | def __init__(
118 | self,
119 | token_capacity: int,
120 | approximation_window: int = 256,
121 | kernel_size: int = 7,
122 | local_window: int = 128,
123 | prefix_tokens: int = 4,
124 | min_head_capacity_ratio: float = 0.2,
125 | ):
126 | super().__init__()
127 | self.token_capacity = token_capacity
128 | self.approximation_window = approximation_window
129 | self.kernel_size = kernel_size
130 | self.local_window = local_window
131 | self.prefix_tokens = prefix_tokens
132 | self.min_head_capacity_ratio = min_head_capacity_ratio
133 | self.causal_mask = None
134 |
135 | def kv_compress(
136 | self,
137 | queries: torch.Tensor, # [num_tokens, num_q_heads, head_size]
138 | keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
139 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
140 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
141 | """Compress KV cache using AdaSnapKV approach with adaptive token distribution"""
142 | assert self.approximation_window < keys.size(0)
143 |
144 | # Use last approximation_window queries to estimate importance
145 | approx_queries = queries[-self.approximation_window:] / math.sqrt(queries.size(-1)) # [approx_window, num_q_heads, head_size]
146 |
147 | num_q_heads = queries.size(1)
148 | num_kv_heads = keys.size(1)
149 | group_size = num_q_heads // num_kv_heads
150 |
151 | # Repeat keys for each query head in the group
152 | # [num_tokens, num_kv_heads, head_size] -> [num_tokens, num_q_heads, head_size]
153 | repeated_keys = keys.repeat_interleave(group_size, dim=1)
154 |
155 | # Now we can calculate attention scores with matching head dimensions
156 | scores = torch.einsum(
157 | 'thd,nhd->htn',
158 | approx_queries,
159 | repeated_keys
160 | )
161 |
162 | # Initialize and cache causal mask if not already created
163 | if self.causal_mask is None:
164 | self.causal_mask = torch.arange(0, self.approximation_window, device=scores.device)
165 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :]
166 | self.causal_mask = self.causal_mask[None, ...] # Add head dimension
167 |
168 | # Apply causal masking and softmax
169 | scores[..., -self.approximation_window:] = torch.where(
170 | self.causal_mask,
171 | scores[..., -self.approximation_window:],
172 | torch.tensor(float("-inf"), device=scores.device, dtype=scores.dtype)
173 | )
174 |
175 | attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(keys.dtype)
176 |
177 | grouped_weights = attn_weights.view(num_kv_heads, group_size, -1, attn_weights.size(-1))
178 |
179 | token_importance = grouped_weights.max(dim=2)[0].max(dim=1)[0]
180 |
181 | # Apply pooling for smoother selection per head
182 | token_importance = F.avg_pool1d(
183 | token_importance.unsqueeze(1),
184 | kernel_size=self.kernel_size,
185 | padding=self.kernel_size // 2,
186 | stride=1
187 | ).squeeze(1)
188 |
189 | # Always keep recent tokens and prefix tokens
190 | token_importance[..., -self.local_window:] = float('inf')
191 | token_importance[..., :self.prefix_tokens] = float('inf')
192 |
193 | # Calculate total capacity and minimum capacity per head
194 | total_capacity = self.token_capacity * num_kv_heads
195 | assert total_capacity <= keys.size(0) * num_kv_heads
196 | min_capacity_per_head = int(self.token_capacity * self.min_head_capacity_ratio)
197 | min_capacity_per_head = max(min_capacity_per_head, self.local_window + self.prefix_tokens)
198 | assert self.token_capacity >= self.local_window + self.prefix_tokens
199 | assert min_capacity_per_head <= keys.size(0)
200 | assert total_capacity >= min_capacity_per_head * num_kv_heads
201 | remaining_capacity = total_capacity - (min_capacity_per_head * num_kv_heads)
202 |
203 | # Get the top-k tokens for minimum capacity per head
204 | _, min_indices = torch.topk(token_importance, k=min_capacity_per_head, dim=-1) # [num_kv_heads, min_capacity]
205 |
206 | # Vectorized mask creation using scatter
207 | selected_mask = torch.zeros_like(token_importance, dtype=torch.bool)
208 | selected_mask.scatter_(
209 | dim=1,
210 | index=min_indices,
211 | src=torch.ones_like(min_indices, dtype=torch.bool)
212 | )
213 |
214 | # Vectorized masking
215 | masked_importance = token_importance.masked_fill(selected_mask, float('-inf'))
216 | flat_importance = masked_importance.view(-1)
217 |
218 | # Global selection with vectorized index conversion
219 | _, flat_indices = torch.topk(flat_importance, k=remaining_capacity, dim=-1)
220 |
221 | # Flatten and update selected_mask
222 | flat_selected_mask = selected_mask.view(-1)
223 | flat_selected_mask.scatter_(0, flat_indices, True)
224 | selected_mask = flat_selected_mask.view(num_kv_heads, -1)
225 |
226 | seq_lens = selected_mask.sum(dim=1)
227 | max_seq_len = seq_lens.max().item()
228 | compressed_keys = torch.zeros(num_kv_heads, max_seq_len, keys.size(-1), device=keys.device, dtype=keys.dtype)
229 | compressed_values = torch.zeros(num_kv_heads, max_seq_len, values.size(-1), device=values.device, dtype=values.dtype)
230 |
231 | keys_t = keys.transpose(0, 1)
232 | values_t = values.transpose(0, 1)
233 |
234 | for head_idx in range(num_kv_heads):
235 | compressed_keys[head_idx, :seq_lens[head_idx]] = keys_t[head_idx, selected_mask[head_idx]]
236 | compressed_values[head_idx, :seq_lens[head_idx]] = values_t[head_idx, selected_mask[head_idx]]
237 |
238 | # Track sparsity - based on actual tokens kept
239 | total_tokens_kept = seq_lens.sum().item()
240 | sparsity = 1.0 - (total_tokens_kept / (keys.size(0) * num_kv_heads))
241 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device))
242 |
243 | return compressed_keys, compressed_values, seq_lens
244 |
--------------------------------------------------------------------------------
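A minimal sketch (not part of the repository) of the selection step behind `SnapKVCompression.kv_compress` above: keys are scored by the attention they receive from the last few queries, and only the top-`capacity` keys per KV head are gathered. The causal mask over the observation window, the average pooling, and the always-kept local/prefix tokens are omitted here for brevity; shapes are toy values.

```python
import torch

num_tokens, num_kv_heads, head_size, window, capacity = 1024, 2, 64, 32, 128
keys = torch.randn(num_tokens, num_kv_heads, head_size)
queries = torch.randn(num_tokens, num_kv_heads, head_size)   # GQA grouping omitted

approx_q = queries[-window:] / head_size ** 0.5
scores = torch.einsum('thd,nhd->htn', approx_q, keys)         # [heads, window, tokens]
importance = torch.softmax(scores, dim=-1).sum(dim=1)         # [heads, tokens]

_, idx = torch.topk(importance, k=capacity, dim=-1)           # per-head kept token indices
idx = idx.unsqueeze(-1).expand(-1, -1, head_size)
compressed_keys = torch.gather(keys.transpose(0, 1), dim=1, index=idx)
print(compressed_keys.shape)                                  # [num_kv_heads, capacity, head_size]
```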
/sparse_frontier/modelling/attention/minference/__init__.py:
--------------------------------------------------------------------------------
1 | from .block import block_sparse_attention
2 | from .vertical_and_slash import vertical_and_slash_kernel, vertical_slash_sparse_attention, sum_over_diagonals
3 |
4 |
5 | __all__ = [
6 | "block_sparse_attention",
7 | "vertical_and_slash_kernel",
8 | "vertical_slash_sparse_attention",
9 | "sum_over_diagonals",
10 | ]
11 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/minference/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 | # Code adopted from https://github.com/microsoft/MInference
4 |
5 | import torch
6 | import triton
7 | import triton.language as tl
8 |
9 |
10 | @triton.jit
11 | def _triton_block_sparse_attn_fwd_kernel(
12 | Q, K, V, seqlens, sm_scale,
13 | block_index,
14 | Out,
15 | stride_qz, stride_qh, stride_qm, stride_qk,
16 | stride_kz, stride_kh, stride_kn, stride_kk,
17 | stride_vz, stride_vh, stride_vn, stride_vk,
18 | stride_oz, stride_oh, stride_om, stride_ok,
19 | Z, H, N_CTX,
20 | NUM_ROWS, MAX_BLOCKS_PRE_ROW,
21 | BLOCK_M: tl.constexpr,
22 | BLOCK_N: tl.constexpr,
23 | BLOCK_DMODEL: tl.constexpr,
24 | dtype: tl.constexpr,
25 | ):
26 | start_m = tl.program_id(0)
27 | off_hz = tl.program_id(1)
28 |
29 | seqlen = tl.load(seqlens + off_hz // H)
30 | if start_m * BLOCK_M >= seqlen:
31 | return
32 |
33 | # initialize offsets
34 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
35 | offs_n = tl.arange(0, BLOCK_N)
36 | offs_d = tl.arange(0, BLOCK_DMODEL)
37 |
38 | qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh
39 | kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh
40 |
41 | q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
42 | k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk
43 | v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk
44 | o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
45 |
46 | blocks_ptr = block_index + (off_hz * NUM_ROWS + start_m) * MAX_BLOCKS_PRE_ROW
47 |
48 | # initialize pointer to m and l
49 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
50 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
51 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
52 | # scale sm_scale by log_2(e) and use
53 | # 2^x instead of exp in the loop because CSE and LICM
54 | # don't work as expected with `exp` in the loop
55 | qk_scale = sm_scale * 1.44269504
56 | # load q: it will stay in SRAM throughout
57 | q = tl.load(q_ptrs)
58 | q = (q * qk_scale).to(dtype)
59 |
60 | # loop over k, v and update accumulator
61 | m_mask = offs_m[:, None] < seqlen
62 | block_count = tl.minimum((start_m + 1) * BLOCK_M // BLOCK_N, MAX_BLOCKS_PRE_ROW)
63 |
64 | for sparse_block_idx in range(block_count):
65 | real_block_idx = tl.load(blocks_ptr + sparse_block_idx)
66 | start_n = real_block_idx * BLOCK_N
67 | cols = start_n + offs_n
68 | # -- load k, v --
69 | k = tl.load(k_ptrs + cols[None, :] * stride_kn)
70 | v = tl.load(v_ptrs + cols[:, None] * stride_vn)
71 | # -- compute qk --
72 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
73 | # if start_n + BLOCK_N < seqlen:
74 | # qk = tl.where(m_mask, qk, float("-inf"))
75 | # else:
76 | causal_mask = cols[None, :] <= offs_m[:, None]
77 | qk = tl.where(m_mask & causal_mask, qk, float("-inf"))
78 | qk += tl.dot(q, k)
79 | # -- compute scaling constant --
80 | m_i_new = tl.maximum(m_i, tl.max(qk, 1))
81 | alpha = tl.math.exp2(m_i - m_i_new)
82 | p = tl.math.exp2(qk - m_i_new[:, None])
83 | # -- scale and update acc --
84 | acc_scale = l_i * 0 + alpha # workaround some compiler bug
85 | acc *= acc_scale[:, None]
86 | acc += tl.dot(p.to(dtype), v)
87 | # -- update m_i and l_i --
88 | l_i = l_i * alpha + tl.sum(p, 1)
89 | m_i = m_i_new
90 |
91 | # write back O
92 | acc /= l_i[:, None]
93 | tl.store(o_ptrs, acc.to(dtype), mask=m_mask)
94 |
95 |
96 | def _triton_block_sparse_attention(
97 | q, # [BATCH, N_HEADS, N_CTX, D_HEAD]
98 | k, # [BATCH, N_HEADS, N_CTX, D_HEAD]
99 | v, # [BATCH, N_HEADS, N_CTX, D_HEAD]
100 | seqlens, # [BATCH, ]
101 | block_index, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), MAX_BLOCKS_PRE_ROW]
102 | sm_scale,
103 | block_size_M=64,
104 | block_size_N=64,
105 | ) -> torch.Tensor:
106 | # shape constraints
107 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
108 | assert Lq == Lk and Lk == Lv
109 | assert Lk in {16, 32, 64, 128}
110 | o = torch.zeros_like(q)
111 | grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1)
112 | dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16
113 | _triton_block_sparse_attn_fwd_kernel[grid](
114 | q, k, v, seqlens, sm_scale,
115 | block_index,
116 | o,
117 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
118 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
119 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
120 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
121 | q.shape[0], q.shape[1], q.shape[2],
122 | block_index.shape[-2], block_index.shape[-1],
123 | BLOCK_M=block_size_M, BLOCK_N=block_size_N,
124 | BLOCK_DMODEL=Lk,
125 | dtype=dtype,
126 | num_warps=4, num_stages=2,
127 | )
128 |
129 | return o
130 |
131 |
132 | def _build_block_index(
133 | query: torch.Tensor,
134 | key: torch.Tensor,
135 | top_k: int,
136 | block_size_M: int = 64,
137 | block_size_N: int = 64,
138 | ):
139 | """Build block index for sparse attention.
140 |
141 | This function enforces a fixed pattern for each query block:
142 | - The first block (attention sink)
143 | - The main diagonal block
144 | - The block one diagonal below the main diagonal
145 |
146 | These three blocks are forced by setting their scores to infinity before the topk operation.
147 | The remaining blocks are selected based on the attention scores between query and key blocks.
148 |
149 | Args:
150 | query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim]
151 | key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim]
152 | top_k: Number of blocks to select per query block
153 | block_size_M: Query block size
154 | block_size_N: Key block size
155 |
156 | Returns:
157 | Block indices tensor of shape [batch_size, num_heads, num_query_blocks, top_k]
158 | """
159 |
160 | batch_size, num_heads, context_size, head_dim = query.shape
161 | query_pool = query.reshape((batch_size, num_heads, -1, block_size_M, head_dim)).mean(dim=-2)
162 | key_pool = key.reshape((batch_size, num_heads, -1, block_size_N, head_dim)).mean(dim=-2)
163 | arange_M = torch.arange(query_pool.shape[-2], dtype=torch.int32, device=query.device) * block_size_M
164 | arange_N = torch.arange(key_pool.shape[-2], dtype=torch.int32, device=key.device) * block_size_N
165 |     p_pool = torch.einsum('bhmk, bhnk -> bhmn', query_pool, key_pool)
166 |
167 | # Replace the direct assignment with diagonal_scatter
168 | diag_size = min(p_pool.shape[-2], p_pool.shape[-1])
169 | inf_diag = torch.full((batch_size, num_heads, diag_size), float('inf'), device=p_pool.device)
170 |
171 | # Set main diagonal to inf
172 | p_pool = torch.diagonal_scatter(p_pool, inf_diag, dim1=-2, dim2=-1)
173 |
174 | # Set one diagonal below main to inf
175 | p_pool = torch.diagonal_scatter(p_pool, inf_diag[..., :-1], offset=-1, dim1=-2, dim2=-1)
176 |
177 | # Keep the causal mask
178 | p_pool = p_pool.where(arange_M[None, None, :, None] >= arange_N[None, None, None, :], -torch.inf)
179 |
180 | # Set the first block to inf - always pick attention sinks
181 | p_pool[..., 0] = torch.inf
182 |
183 | top_k = min(top_k, context_size // block_size_N)
184 | # assert top_k <= context_size // block_size_N, f"top_k ({top_k}) must be <= context_size/block_size_N ({context_size // block_size_N})"
185 |
186 | return torch.topk(p_pool, top_k, dim=-1).indices.to(torch.int32).sort(dim=-1).values
187 |
188 |
189 | def block_sparse_attention(
190 | query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
191 | key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
192 | value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
193 | top_k: int,
194 | block_size_M: int = 64,
195 | block_size_N: int = 64,
196 | ):
197 | _, _, context_size, head_dim = query.shape
198 | pad = block_size_M - (query.shape[2] & (block_size_M - 1))
199 | query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
200 | key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
201 | value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
202 | seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
203 | sm_scale = head_dim ** -0.5
204 |     block_index = _build_block_index(query, key, top_k, block_size_M, block_size_N)
205 | out = _triton_block_sparse_attention(query, key, value, seqlens, block_index, sm_scale, block_size_M, block_size_N)
206 | return out[..., :context_size, :]
207 |
--------------------------------------------------------------------------------
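A minimal sketch (not part of the repository) of the block selection performed by `_build_block_index` above: queries and keys are mean-pooled per block, block pairs are scored, the attention-sink and main-diagonal blocks are forced by setting their scores to infinity, and the top-k key blocks per query block are kept. The sub-diagonal forcing and the padding logic of the real code are omitted; shapes are toy values.

```python
import torch

batch, heads, seq_len, dim, block, top_k = 1, 2, 256, 64, 64, 3
q = torch.randn(batch, heads, seq_len, dim)
k = torch.randn(batch, heads, seq_len, dim)

q_pool = q.reshape(batch, heads, -1, block, dim).mean(dim=-2)  # [b, h, q_blocks, dim]
k_pool = k.reshape(batch, heads, -1, block, dim).mean(dim=-2)  # [b, h, k_blocks, dim]
scores = torch.einsum('bhmk,bhnk->bhmn', q_pool, k_pool)

n_blocks = scores.shape[-1]
diag_inf = torch.full(scores.shape[:-1], float('inf'))
scores = torch.diagonal_scatter(scores, diag_inf, dim1=-2, dim2=-1)  # force main diagonal
scores[..., 0] = float('inf')                                        # force attention sink
causal = torch.arange(n_blocks)[:, None] >= torch.arange(n_blocks)[None, :]
scores = scores.where(causal, -torch.inf)                            # block-level causal mask

block_index = torch.topk(scores, top_k, dim=-1).indices.sort(dim=-1).values
print(block_index.shape)  # [batch, heads, q_blocks, top_k]
```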
/sparse_frontier/modelling/attention/minference/csrc/kernels.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | // Code adopted from https://github.com/microsoft/MInference
4 |
5 | #include <vector>
6 | #include "torch/extension.h"
7 |
8 | std::vector<torch::Tensor> convert_vertical_slash_indexes(
9 | torch::Tensor seqlens, // [BATCH, ]
10 | torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
11 | torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
12 | int context_size,
13 | int block_size_M,
14 | int block_size_N
15 | );
16 |
17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
18 | m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes, "dynamic sparse index function");
19 | }
20 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/minference/csrc/vertical_slash_index.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT license.
3 | // Code adopted from https://github.com/microsoft/MInference
4 |
5 | #include <assert.h>
6 |
7 | #include <cuda.h>
8 | #include <cuda_fp16.h>
9 | #include <cuda_runtime.h>
10 | #include <cuda_runtime_api.h>
11 |
12 | #include <torch/extension.h>
13 |
14 | // __device__ int min(int x, int y) {
15 | // return x < y ? x : y;
16 | // }
17 |
18 | // __device__ int max(int x, int y) {
19 | // return x > y ? x : y;
20 | // }
21 |
22 | __device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
23 | for (int idx = range_start; idx < range_end; idx += block_size) {
24 | block_offset[block_count++] = idx;
25 | }
26 | }
27 |
28 | __global__ void convert_vertical_slash_indexes_kernel(
29 | const int* seqlens, // [BATCH, ]
30 | const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
31 | const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
32 | int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
33 | int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
34 | int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
35 | int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
36 | int N_HEADS,
37 | int N_ROWS,
38 | int BLOCK_SIZE_M,
39 | int BLOCK_SIZE_N,
40 | int NNZ_V,
41 | int NNZ_S
42 | ) {
43 | const int batch_idx = blockIdx.y;
44 | const int head_idx = blockIdx.x;
45 | const int group_idx = blockIdx.z;
46 |
47 | int seqlen = seqlens[batch_idx];
48 | int block_idx_m = group_idx * blockDim.x + threadIdx.x;
49 | int start_m = block_idx_m * BLOCK_SIZE_M;
50 | if (start_m >= seqlen) {
51 | return;
52 | }
53 | int end_m = start_m + BLOCK_SIZE_M;
54 | vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
55 | slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
56 | int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
57 | block_count += row_offset;
58 | block_offset += row_offset * NNZ_S;
59 | column_count += row_offset;
60 | column_index += row_offset * NNZ_V;
61 |
62 | int tmp_col_cnt = 0, tmp_blk_cnt = 0;
63 | int s = 0, v = 0;
64 | int v_idx = vertical_indexes[v++];
65 | int s_idx = slash_indexes[s++];
66 | while (s_idx >= end_m) {
67 | s_idx = slash_indexes[s++];
68 | }
69 | s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
70 | int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
71 | while (1) {
72 | if (v_idx < range_end) {
73 | if (v_idx < range_start) {
74 | column_index[tmp_col_cnt++] = v_idx;
75 | }
76 | if (v < NNZ_V) {
77 | v_idx = vertical_indexes[v++];
78 | } else {
79 | v_idx = end_m + BLOCK_SIZE_M;
80 | }
81 | } else {
82 | if (s < NNZ_S) {
83 | s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
84 | } else {
85 | save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
86 | break;
87 | }
88 | if (s_idx > range_end + BLOCK_SIZE_M) {
89 | save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
90 | range_start = s_idx - BLOCK_SIZE_M;
91 | range_end = s_idx;
92 | } else if (s_idx > range_end) {
93 | range_end += BLOCK_SIZE_M;
94 | }
95 | }
96 | }
97 |
98 | block_count[0] = tmp_blk_cnt;
99 | column_count[0] = tmp_col_cnt;
100 | }
101 |
102 | void convert_vertical_slash_indexes_64x64(
103 | const int* seqlens, // [BATCH, ]
104 | const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
105 | const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
106 | int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
107 | int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
108 | int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
109 | int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
110 | int BATCH_SIZE,
111 | int N_HEADS,
112 | int N_ROWS,
113 | int NNZ_V,
114 | int NNZ_S
115 | ) {
116 | const int BLOCK_SIZE_M = 64;
117 | const int BLOCK_SIZE_N = 64;
118 | const int N_THREADS = 64;
119 | const dim3 dimBlock(N_THREADS);
120 | const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
121 |     convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
122 | seqlens, vertical_indexes, slash_indexes,
123 | block_count, block_offset, column_count, column_index,
124 | N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
125 | );
126 | }
127 |
128 | std::vector<torch::Tensor> convert_vertical_slash_indexes(
129 | torch::Tensor seqlens, // [BATCH, ]
130 | torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
131 | torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
132 | int context_size,
133 | int block_size_M,
134 | int block_size_N
135 | ) {
136 | assert(block_size_M == 64);
137 | assert(block_size_N == 64);
138 |
139 | cudaSetDevice(seqlens.get_device());
140 |
141 | int batch_size = slash_indexes.size(0);
142 | int num_heads = slash_indexes.size(1);
143 | int nnz_slash = slash_indexes.size(2);
144 | int nnz_vertical = vertical_indexes.size(2);
145 | int num_rows = (context_size + block_size_M - 1) / block_size_M;
146 |
147 | torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
148 | torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
149 | torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
150 | torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());
151 |
152 | convert_vertical_slash_indexes_64x64(
153 |         seqlens.data_ptr<int>(),
154 |         vertical_indexes.data_ptr<int>(),
155 |         slash_indexes.data_ptr<int>(),
156 |         block_count.data_ptr<int>(),
157 |         block_offset.data_ptr<int>(),
158 |         column_count.data_ptr<int>(),
159 |         column_index.data_ptr<int>(),
160 | batch_size,
161 | num_heads,
162 | num_rows,
163 | nnz_vertical,
164 | nnz_slash
165 | );
166 |
167 | return { block_count, block_offset, column_count, column_index };
168 | }
169 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/attention/minference/vertical_and_slash.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Microsoft
2 | # Licensed under The MIT License [see LICENSE for details]
3 | # Code adopted from https://github.com/microsoft/MInference
4 |
5 | import math
6 | import torch
7 | import triton
8 | import triton.language as tl
9 |
10 | from sparse_frontier.modelling.attention.minference.minference import convert_vertical_slash_indexes
11 |
12 |
13 | @triton.jit
14 | def _triton_mixed_sparse_attn_fwd_kernel(
15 | Q, K, V, seqlens, sm_scale,
16 | block_count, block_offset, column_count, column_index,
17 | Out,
18 | stride_qz, stride_qh, stride_qm, stride_qk,
19 | stride_kz, stride_kh, stride_kn, stride_kk,
20 | stride_vz, stride_vh, stride_vn, stride_vk,
21 | stride_oz, stride_oh, stride_om, stride_ok,
22 | Z, H, N_CTX,
23 | NUM_ROWS, NNZ_S, NNZ_V,
24 | BLOCK_M: tl.constexpr,
25 | BLOCK_N: tl.constexpr,
26 | BLOCK_DMODEL: tl.constexpr,
27 | dtype: tl.constexpr,
28 | ):
29 | start_m = tl.program_id(0)
30 | off_hz = tl.program_id(1)
31 |
32 | seqlen = tl.load(seqlens + off_hz // H)
33 | if start_m * BLOCK_M >= seqlen:
34 | return
35 |
36 | # initialize offsets
37 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
38 | offs_n = tl.arange(0, BLOCK_N)
39 | offs_d = tl.arange(0, BLOCK_DMODEL)
40 |
41 | qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh
42 | kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh
43 |
44 | q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
45 | k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk
46 | v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk
47 | o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
48 |
49 | num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m)
50 | blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S
51 | num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m)
52 | cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V
53 |
54 | # initialize pointer to m and l
55 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
56 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
57 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
58 | # scale sm_scale by log_2(e) and use
59 | # 2^x instead of exp in the loop because CSE and LICM
60 | # don't work as expected with `exp` in the loop
61 | qk_scale = sm_scale * 1.44269504
62 | # load q: it will stay in SRAM throughout
63 | q = tl.load(q_ptrs)
64 | q = (q * qk_scale).to(dtype)
65 |
66 | # loop over k, v and update accumulator
67 | m_mask = offs_m[:, None] < seqlen
68 |
69 | for block_index in range(num_blks):
70 | start_n = tl.load(blks_ptr + block_index)
71 | cols = start_n + offs_n
72 | n_mask = cols < seqlen
73 | # -- load k, v --
74 | k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)
75 | v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)
76 | # -- compute qk --
77 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
78 | causal_mask = cols[None, :] <= offs_m[:, None]
79 | qk = tl.where(m_mask & causal_mask, qk, float("-inf"))
80 | qk += tl.dot(q, k)
81 | # -- compute scaling constant --
82 | m_i_new = tl.maximum(m_i, tl.max(qk, 1))
83 | alpha = tl.math.exp2(m_i - m_i_new)
84 | p = tl.math.exp2(qk - m_i_new[:, None])
85 | # -- scale and update acc --
86 | acc_scale = l_i * 0 + alpha # workaround some compiler bug
87 | acc *= acc_scale[:, None]
88 | acc += tl.dot(p.to(dtype), v)
89 | # -- update m_i and l_i --
90 | l_i = l_i * alpha + tl.sum(p, 1)
91 | m_i = m_i_new
92 |
93 | for start_n in range(0, num_cols, BLOCK_N):
94 | n_mask = start_n + offs_n < num_cols
95 | cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0)
96 | # -- load k, v --
97 | k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)
98 | v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)
99 | # -- compute qk --
100 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
101 | qk = tl.where(m_mask & n_mask, qk, float("-inf"))
102 | qk += tl.dot(q, k)
103 | # -- compute scaling constant --
104 | m_i_new = tl.maximum(m_i, tl.max(qk, 1))
105 | alpha = tl.math.exp2(m_i - m_i_new)
106 | p = tl.math.exp2(qk - m_i_new[:, None])
107 | # -- scale and update acc --
108 | acc_scale = l_i * 0 + alpha # workaround some compiler bug
109 | acc *= acc_scale[:, None]
110 | acc += tl.dot(p.to(dtype), v)
111 | # -- update m_i and l_i --
112 | l_i = l_i * alpha + tl.sum(p, 1)
113 | m_i = m_i_new
114 |
115 | # write back O
116 | acc /= l_i[:, None]
117 | # acc = tl.where(m_mask, acc / l_i[:, None], 0.0)
118 | tl.store(o_ptrs, acc.to(dtype), mask=m_mask)
119 |
120 |
121 | def _triton_mixed_sparse_attention(
122 | q: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
123 | k: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
124 | v: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
125 | seqlens: torch.Tensor, # [BATCH, ]
126 | block_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
127 | block_offset: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
128 | column_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
129 | column_index: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
130 | sm_scale: float,
131 | block_size_M: int = 64,
132 | block_size_N: int = 64,
133 | ) -> torch.Tensor:
134 | # shape constraints
135 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
136 | assert Lq == Lk and Lk == Lv
137 | assert Lk in {16, 32, 64, 128}
138 | o = torch.zeros_like(q)
139 | grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1)
140 | dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16
141 | _triton_mixed_sparse_attn_fwd_kernel[grid](
142 | q, k, v, seqlens, sm_scale,
143 | block_count, block_offset, column_count, column_index,
144 | o,
145 | q.stride(0), q.stride(1), q.stride(2), q.stride(3),
146 | k.stride(0), k.stride(1), k.stride(2), k.stride(3),
147 | v.stride(0), v.stride(1), v.stride(2), v.stride(3),
148 | o.stride(0), o.stride(1), o.stride(2), o.stride(3),
149 | q.shape[0], q.shape[1], q.shape[2],
150 | block_count.shape[-1], block_offset.shape[-1], column_index.shape[-1],
151 | BLOCK_M=block_size_M, BLOCK_N=block_size_N,
152 | BLOCK_DMODEL=Lk,
153 | dtype=dtype,
154 | num_warps=4, num_stages=1,
155 | )
156 |
157 | return o
158 |
159 |
160 | def vertical_slash_sparse_attention(
161 | query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
162 | key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
163 | value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
164 | v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
165 | s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
166 | block_size_M: int = 64,
167 | block_size_N: int = 64,
168 | ):
169 | batch_size, num_heads, context_size, head_dim = query.shape
170 | pad = block_size_M - (context_size & (block_size_M - 1))
171 | query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0])
172 | key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0])
173 | value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0])
174 |
175 | if head_dim not in [16, 32, 64, 128, 256, 512]:
176 | target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
177 | query = torch.nn.functional.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
178 | key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
179 | value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
180 |
181 | v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
182 | s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
183 | seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device)
184 | sm_scale = head_dim ** -0.5
185 | block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
186 | seqlens, v_idx, s_idx, context_size, block_size_M, block_size_N,
187 | )
188 | sparsity = calc_sparsity(block_count, column_count, context_size)
189 | out = _triton_mixed_sparse_attention(
190 | query, key, value, seqlens,
191 | block_count, block_offset, column_count, column_index,
192 | sm_scale, block_size_M, block_size_N,
193 | )
194 | return out[..., :context_size, :head_dim], sparsity
195 |
196 |
197 | def calc_sparsity(block_count, column_count, seq_len):
198 | block_cells = block_count.sum(dim=-1) * 64 * 64
199 | column_cells = column_count.sum(dim=-1) * 64
200 | total_cells = seq_len * (seq_len + 1) // 2
201 | return 1 - (block_cells + column_cells) / total_cells
202 |
203 |
204 | def sum_over_diagonals(matrix: torch.Tensor) -> torch.Tensor:
205 | """Efficiently sum values along diagonals of the attention matrix.
206 |
207 | This function computes the sum of values along each diagonal of a 4D attention matrix.
208 | It uses an efficient strided implementation to avoid explicit diagonal extraction.
209 |
210 | Args:
211 | matrix: Input attention matrix of shape (batch_size, num_heads, queries, keys)
212 | where queries and keys are sequence lengths
213 |
214 | Returns:
215 | Tensor of shape (batch_size, num_heads, queries + keys - 1) containing the
216 |         summed values for each diagonal. The diagonals are ordered from bottom-left
217 |         to top-right, with the main diagonal at index queries-1.
218 | """
219 | batch_size, num_heads, queries, keys = matrix.shape
220 | zero_matrix = torch.zeros((batch_size, num_heads, queries, queries), device=matrix.device)
221 | matrix_padded = torch.cat((zero_matrix, matrix, zero_matrix), -1)
222 |
223 | matrix_strided = matrix_padded.as_strided(
224 | (batch_size, num_heads, queries, queries + keys),
225 | (num_heads * queries * (2 * queries + keys),
226 | queries * (2 * queries + keys),
227 | 2 * queries + keys + 1, 1)
228 | )
229 | return torch.sum(matrix_strided, 2)[:, :, 1:]
230 |
231 |
232 | def vertical_and_slash_kernel(
233 | q: torch.Tensor,
234 | k: torch.Tensor,
235 | v: torch.Tensor,
236 | vertical_size: int,
237 | slash_size: int,
238 | last_q: int = 64,
239 | inf_value: float = float('inf'),
240 | topk_vertical_inf: int = 4,
241 | topk_slash_inf: int = 64,
242 | ):
243 | """
244 | Compute the vertical and slash kernel for sparse attention.
245 |
246 | Args:
247 | q: Query tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD]
248 | k: Key tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD]
249 | v: Value tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD]
250 | vertical_size: Size of the vertical attention
251 | slash_size: Size of the slash attention
252 | last_q: Number of last queries to consider (default: 64)
253 | inf_value: Value to use for infinity (default: float('inf'))
254 |         topk_vertical_inf: Number of top-k vertical elements to set to infinity (default: 4)
255 |         topk_slash_inf: Number of top-k slash elements to set to infinity (default: 64)
256 |
257 | Returns:
258 | Output tensor after applying vertical and slash sparse attention.
259 | """
260 | arange = torch.arange(last_q, device=q.device)
261 | LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :]
262 |
263 | _, _, q_len, d = q.shape
264 |
265 | # Compute scaled dot-product attention
266 | qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) / math.sqrt(d)
267 | qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[..., -last_q:, -last_q:], qk[:, :, :, -last_q:], -inf_value)
268 | qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)
269 |
270 | assert topk_vertical_inf <= vertical_size
271 | assert topk_slash_inf <= slash_size
272 |
273 | # Compute top verticals
274 | vertical = qk.sum(-2, keepdim=True)
275 | vertical[..., :topk_vertical_inf] = inf_value
276 | vertical_topk = torch.topk(vertical, vertical_size, -1).indices
277 |
278 |     # Compute top slashes
279 | slash = sum_over_diagonals(qk)[..., :-last_q + 1]
280 | slash[..., -topk_slash_inf:] = inf_value
281 | slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices
282 |
283 | return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash)
284 |
--------------------------------------------------------------------------------
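A small check (not part of the repository) that the strided trick in `sum_over_diagonals` above matches a naive per-diagonal loop; index 0 of the output is the bottom-left diagonal and the main diagonal sits at index `queries - 1`. The import assumes the package and its compiled kernels are installed as described in the setup instructions.

```python
import torch
from sparse_frontier.modelling.attention.minference import sum_over_diagonals

def naive_diagonal_sums(matrix: torch.Tensor) -> torch.Tensor:
    b, h, q, k = matrix.shape
    out = torch.zeros(b, h, q + k - 1)
    for m in range(q + k - 1):
        offset = m - (q - 1)  # index 0 -> bottom-left diagonal, last index -> top-right
        out[..., m] = torch.diagonal(matrix, offset=offset, dim1=-2, dim2=-1).sum(-1)
    return out

x = torch.randn(1, 2, 8, 16)
assert torch.allclose(naive_diagonal_sums(x), sum_over_diagonals(x), atol=1e-5)
```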
/sparse_frontier/modelling/attention/registry.py:
--------------------------------------------------------------------------------
1 | from .efficient_prefilling import (
2 | DenseAttention,
3 | VerticalAndSlashAttentionMInference,
4 | BlockSparseAttentionMInference,
5 | FlexPrefill,
6 | )
7 | from .efficient_decoding import QuestAttention
8 | from .kv_compression import SnapKVCompression, AdaSnapKVCompression
9 | from sparse_frontier.utils import GlobalSettings
10 | from .handler import AttentionHandler
11 |
12 |
13 | ATTENTION_REGISTRY = {
14 | 'dense': DenseAttention,
15 | 'vertical_and_slash': VerticalAndSlashAttentionMInference,
16 | 'block_sparse': BlockSparseAttentionMInference,
17 | 'snapkv': SnapKVCompression,
18 | 'ada_snapkv': AdaSnapKVCompression,
19 | 'quest': QuestAttention,
20 | 'flexprefill': FlexPrefill,
21 | }
22 |
23 |
24 | def get_attention():
25 | return GlobalSettings.get('ATTENTION', DenseAttention())
26 |
27 |
28 | def get_attention_handler() -> AttentionHandler:
29 | return GlobalSettings.get('ATTENTION_HANDLER')
30 |
31 |
32 | def configure_attention():
33 | cfg = GlobalSettings.get('cfg')
34 |
35 | attention_args = cfg.attention.get('args', {})
36 |
37 | if cfg.attention.name == 'quest':
38 | block_size = attention_args.page_size
39 | else:
40 | block_size = cfg.kv_cache_block_size
41 |
42 | # Configure attention handler
43 | attention_handler = AttentionHandler(
44 | tp_size=cfg.tp,
45 | model_q_heads=cfg.model.num_q_heads,
46 | model_kv_heads=cfg.model.num_kv_heads,
47 | model_layers=cfg.model.num_layers,
48 | max_input_tokens=cfg.max_input_tokens,
49 | max_output_tokens=cfg.max_output_tokens,
50 | block_size=block_size,
51 | )
52 | GlobalSettings.set('ATTENTION_HANDLER', attention_handler)
53 |
54 | if cfg.attention.name == 'quest':
55 | extra_args = {
56 | 'num_layers': cfg.model.num_layers,
57 | 'max_input_tokens': cfg.max_input_tokens,
58 | 'max_output_tokens': cfg.max_output_tokens,
59 | }
60 | else:
61 | extra_args = {}
62 |
63 | attention = ATTENTION_REGISTRY[cfg.attention.name](**attention_args, **extra_args)
64 | GlobalSettings.set('ATTENTION', attention)
65 |
--------------------------------------------------------------------------------
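A minimal usage sketch (not part of the repository) for the registry above: a method is looked up by the name used in the `configs/attention/*.yaml` files and instantiated with its constructor arguments. It assumes the package and its compiled kernels are installed.

```python
from sparse_frontier.modelling.attention.registry import ATTENTION_REGISTRY

# 'snapkv' maps to SnapKVCompression; token_capacity is its main argument.
attention = ATTENTION_REGISTRY['snapkv'](token_capacity=1024)
print(type(attention).__name__)  # SnapKVCompression
```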
/sparse_frontier/modelling/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/models/__init__.py
--------------------------------------------------------------------------------
/sparse_frontier/modelling/models/abstract_model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Any
3 | import torch
4 |
5 |
6 | class AbstractModel(ABC):
7 | @abstractmethod
8 | def __init__(
9 | self,
10 | model_path: str,
11 | max_input_tokens: int = 8192,
12 | max_output_tokens: int = 256,
13 | device: torch.device = None,
14 | dtype: torch.dtype = None,
15 | ) -> None:
16 | """Initialize the model.
17 |
18 | Args:
19 | model_path: Path or identifier for the model
20 | max_input_tokens: Maximum number of input tokens
21 | max_output_tokens: Maximum number of new tokens to generate
22 | device: Device to run the model on
23 | dtype: Data type for model parameters
24 | """
25 | pass
26 |
27 | @abstractmethod
28 | def _load_model(self, model_path: str) -> Any:
29 | """Load the underlying model.
30 |
31 | Args:
32 | model_path: Path or identifier for the model
33 |
34 | Returns:
35 | The loaded model
36 | """
37 | pass
38 |
39 | @abstractmethod
40 | def _greedy_config(self, max_output_tokens: int) -> Dict:
41 | """Get configuration for greedy generation.
42 |
43 | Args:
44 | max_output_tokens: Maximum number of new tokens to generate
45 |
46 | Returns:
47 | Dictionary of generation configuration parameters
48 | """
49 | pass
50 |
51 | @abstractmethod
52 | def generate(
53 | self,
54 | input_text: str,
55 | max_output_tokens: int = None,
56 | ) -> str:
57 | """Generate text from input.
58 |
59 | Args:
60 | input_text: The input text to generate from
61 | max_output_tokens: Maximum number of new tokens to generate
62 |
63 | Returns:
64 | Generated text string
65 | """
66 | pass
67 |
--------------------------------------------------------------------------------
/sparse_frontier/modelling/models/vllm_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import torch
4 | from vllm import LLM, SamplingParams
5 | from vllm.attention.backends.abstract import AttentionType
6 | from vllm.attention.backends.flash_attn import (
7 | FlashAttentionMetadata,
8 | get_num_prefill_decode_query_kv_tokens,
9 | )
10 |
11 | from sparse_frontier.modelling.attention.registry import get_attention, get_attention_handler
12 | from sparse_frontier.utils.globals import set_vllm_profiling_done, is_vllm_profiling_done
13 | from .abstract_model import AbstractModel
14 | from sparse_frontier.modelling.tokenizer import Tokenizer
15 |
16 |
17 | class VLLMModel(AbstractModel):
18 | def __init__(
19 | self,
20 | model_path: str,
21 | max_input_tokens: int = 8192,
22 | max_output_tokens: int = 256,
23 | dtype: torch.dtype = None,
24 | tensor_parallel_size: int = 1,
25 | seed: Optional[int] = 43,
26 | ):
27 | """
28 |         vLLM uses forking to run tensor parallelism, so CUDA must not be
29 |         initialised before the fork or it will trigger an error. That's why
30 |         we don't call get_device() etc. here.
31 | """
32 | self.model_path = model_path
33 | self.max_input_tokens = max_input_tokens
34 | self.max_output_tokens = max_output_tokens
35 | self.dtype = dtype or torch.bfloat16
36 | self.tensor_parallel_size = tensor_parallel_size
37 | self.seed = seed
38 |
39 | set_vllm_profiling_done(False)
40 |
41 |         assert not torch.cuda.is_initialized(), "CUDA must not be initialized before vLLM is loaded"
42 | self.model = self._load_model(self.model_path)
43 | self.tokenizer = Tokenizer(self.model_path)
44 |
45 | def _load_model(self, model_path: str) -> LLM:
46 | if 'Qwen' in model_path and self.max_input_tokens + self.max_output_tokens > 32768:
47 | factor = (self.max_input_tokens + self.max_output_tokens) / 32768
48 | hf_overrides = {
49 | "rope_scaling": {
50 | "factor": factor,
51 | "original_max_position_embeddings": 32768,
52 | "rope_type": "yarn"
53 | }
54 | }
55 | else:
56 | hf_overrides = {}
57 |
58 | model = LLM(
59 | model=model_path,
60 | skip_tokenizer_init=True,
61 | trust_remote_code=True,
62 | enforce_eager=True,
63 | seed=self.seed,
64 | gpu_memory_utilization=0.9,
65 | max_num_batched_tokens=self.max_input_tokens + self.max_output_tokens,
66 | max_model_len=self.max_input_tokens + self.max_output_tokens,
67 | enable_chunked_prefill=False,
68 | tensor_parallel_size=self.tensor_parallel_size,
69 | hf_overrides=hf_overrides,
70 | )
71 |
72 |         # Discard sparsity statistics accumulated during vLLM profiling
73 | get_attention().reset_sparsity_statistics()
74 | set_vllm_profiling_done(True)
75 |
76 | return model
77 |
78 | def _greedy_config(self, max_output_tokens: int) -> Dict:
79 | return {
80 | 'sampling_params': SamplingParams(
81 | max_tokens=max_output_tokens,
82 | temperature=0,
83 | ),
84 | 'use_tqdm': False,
85 | }
86 |
87 | @torch.no_grad()
88 | def generate(
89 | self,
90 | input_text: str,
91 | max_output_tokens: int = None,
92 |     ) -> Dict:
93 | max_output_tokens = max_output_tokens or self.max_output_tokens
94 | model_input = self.tokenizer.encode_for_generation(input_text, return_tensors=False)
95 |
96 | output = self.model.generate(
97 | prompt_token_ids=model_input['input_ids'],
98 | **self._greedy_config(max_output_tokens),
99 | )
100 |
101 | output_ids = output[0].__dict__['outputs'][0].token_ids
102 | decoded = self.tokenizer.decode(output_ids)
103 |
104 | return {
105 | 'text': decoded[0] if isinstance(decoded, list) else decoded,
106 | 'output_tokens_len': len(output_ids),
107 | }
108 |
109 |
110 |
111 | def vllm_patched_forward(
112 | self,
113 | query: torch.Tensor,
114 | key: torch.Tensor,
115 | value: torch.Tensor,
116 | kv_cache: torch.Tensor,
117 | attn_metadata: FlashAttentionMetadata,
118 | k_scale: float = 1.0,
119 | v_scale: float = 1.0,
120 | attn_type: AttentionType = AttentionType.DECODER,
121 | output: Optional[torch.Tensor] = None,
122 | ) -> torch.Tensor:
123 | """Forward pass with FlashAttention.
124 |
125 | Args:
126 | query: shape = [num_tokens, num_heads, head_size]
127 | key: shape = [num_tokens, num_kv_heads, head_size]
128 | value: shape = [num_tokens, num_kv_heads, head_size]
129 | output: shape = [num_tokens, num_heads, head_size]
130 | kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
131 | attn_metadata: Metadata for attention.
132 | Returns:
133 | shape = [num_tokens, num_heads * head_size]
134 | """
135 | (num_prefill_query_tokens, num_prefill_kv_tokens, _) = \
136 | get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
137 |
138 | if attn_metadata.prefill_metadata:
139 | get_attention_handler().__call__(
140 | queries=query[:num_prefill_query_tokens],
141 | keys=key[:num_prefill_kv_tokens],
142 | values=value[:num_prefill_kv_tokens],
143 | kv_cache=kv_cache,
144 | output=output[:num_prefill_query_tokens],
145 | )
146 | else:
147 | get_attention_handler().__call__(
148 | queries=query[num_prefill_query_tokens:],
149 | keys=key[num_prefill_query_tokens:],
150 | values=value[num_prefill_query_tokens:],
151 | kv_cache=kv_cache,
152 | output=output[num_prefill_query_tokens:],
153 | )
154 |
155 | return output
156 |
157 |
158 | def swap_vllm_attention():
159 | from vllm.attention.backends.flash_attn import FlashAttentionImpl
160 | FlashAttentionImpl.forward = vllm_patched_forward
161 |
--------------------------------------------------------------------------------
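A quick sketch of the YaRN override built in `_load_model` above, using an assumed budget of 131072 total tokens (the real limits come from the experiment config):

# Illustrative only: reproduces the rope_scaling override applied to Qwen models
# when the requested context exceeds the native 32768 tokens.
max_input_tokens, max_output_tokens = 130816, 256     # assumed example budget
required_len = max_input_tokens + max_output_tokens   # 131072
factor = required_len / 32768                         # 4.0
hf_overrides = {
    "rope_scaling": {
        "factor": factor,
        "original_max_position_embeddings": 32768,
        "rope_type": "yarn",
    }
}
print(hf_overrides)
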
/sparse_frontier/modelling/tokenizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from typing import List, Dict, Union
3 |
4 | import torch
5 |
6 | class Tokenizer:
7 | def __init__(
8 | self,
9 | model_path: str,
10 | device: torch.device = None,
11 | ) -> None:
12 | self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 | self.tokenizer = AutoTokenizer.from_pretrained(
15 | model_path,
16 | trust_remote_code=True,
17 | )
18 |
19 | if self.tokenizer.pad_token is None:
20 | self.tokenizer.pad_token = self.tokenizer.eos_token
21 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
22 |
23 | @property
24 | def pad_token_id(self) -> int:
25 | return self.tokenizer.pad_token_id
26 |
27 | @property
28 | def model_max_length(self) -> int:
29 | return self.tokenizer.model_max_length
30 |
31 | def text_to_tokens(self, input_text: str) -> List[int]:
32 | return self.tokenizer.encode(input_text, add_special_tokens=False)
33 |
34 | def encode_for_generation(self, input_text: str, return_tensors: bool = True) -> Dict:
35 | input_text_with_prompt = self.tokenizer.apply_chat_template(
36 | [{"role": "user", "content": input_text}],
37 | tokenize=False,
38 | add_generation_prompt=True
39 | )
40 |
41 | if return_tensors:
42 | encoded_input = self.tokenizer(
43 | input_text_with_prompt,
44 | return_tensors="pt",
45 | add_special_tokens=False,
46 | ).to(self.device)
47 |
48 | return {
49 | **encoded_input,
50 | 'input_length': encoded_input['input_ids'].size(1),
51 | }
52 | else:
53 | input_ids = self.text_to_tokens(input_text_with_prompt)
54 | return {
55 | 'input_ids': input_ids,
56 | 'input_length': len(input_ids),
57 | }
58 |
59 | def decode(self, outputs: Union[List[int], List[List[int]]], input_length: int = 0) -> Union[str, List[str]]:
60 | if isinstance(outputs[0], int):
61 | # Single list of tokens
62 | return self.tokenizer.decode(
63 | outputs[input_length:],
64 | skip_special_tokens=True
65 | )
66 | else:
67 | # List of lists of tokens
68 | return [
69 | self.tokenizer.decode(
70 | output[input_length:],
71 | skip_special_tokens=True
72 | )
73 | for output in outputs
74 | ]
75 |
--------------------------------------------------------------------------------
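A minimal usage sketch of the Tokenizer wrapper above. The model identifier is illustrative (any chat-tuned checkpoint with a chat template works), and the package plus its dependencies are assumed to be installed:

from sparse_frontier.modelling.tokenizer import Tokenizer

tok = Tokenizer("Qwen/Qwen2.5-7B-Instruct", device="cpu")  # illustrative model path
enc = tok.encode_for_generation("What is 2 + 2?", return_tensors=False)
print(enc["input_length"])           # number of prompt tokens, chat template included
print(tok.decode(enc["input_ids"]))  # templated prompt text, special tokens stripped
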
/sparse_frontier/prediction.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from concurrent.futures import ProcessPoolExecutor, as_completed
4 | from typing import Dict, Any
5 |
6 | from tqdm import tqdm
7 |
8 | from sparse_frontier.utils.data import load_data_without_predictions, get_pred_path
9 | from sparse_frontier.utils.general import get_free_ports, save_config
10 | from sparse_frontier.utils.globals import GlobalSettings
11 |
12 |
13 | model = None
14 |
15 |
16 | def process_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
17 | """Process a single sample through the model.
18 |
19 | Args:
20 | sample: Dictionary containing input text and metadata
21 |
22 | Returns:
23 | Sample dictionary augmented with model prediction
24 | """
25 | global model
26 | from sparse_frontier.modelling.attention.registry import get_attention
27 |
28 | get_attention().reset_sparsity_statistics()
29 |
30 | output = model.generate(sample['input_text'])
31 |
32 | output_dict = {
33 | 'pred': output['text'],
34 | 'output_tokens_len': output['output_tokens_len'],
35 | 'sparsity': get_attention().calculate_sparsity(),
36 | 'index': sample['index']
37 | }
38 |
39 | return output_dict
40 |
41 |
42 | def init_worker() -> None:
43 | """Initialize worker process with GPU configuration and VLLM model."""
44 | import torch.multiprocessing as mp
45 | from sparse_frontier.modelling.models.vllm_model import VLLMModel
46 |
47 | # Explicitly set tokenizers parallelism for each worker
48 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
49 |
50 | cfg = GlobalSettings.get('cfg')
51 |
52 | # Check if CUDA_VISIBLE_DEVICES is already set
53 | preset_gpus = os.environ.get('CUDA_VISIBLE_DEVICES')
54 | if preset_gpus:
55 | available_gpus = [int(gpu) for gpu in preset_gpus.split(',')]
56 | if len(available_gpus) < cfg.gpus:
57 | raise ValueError(f"Number of available GPUs ({len(available_gpus)}) is less than required ({cfg.gpus})")
58 | else:
59 | available_gpus = list(range(cfg.gpus))
60 |
61 | # Set GPU device visibility based on worker index
62 | if len(mp.current_process()._identity) > 0:
63 | worker_index = mp.current_process()._identity[0] - 1
64 | # Calculate GPU slice for this worker from available GPUs
65 | worker_gpus = available_gpus[worker_index * cfg.tp:(worker_index + 1) * cfg.tp]
66 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, worker_gpus))
67 | # Get pre-allocated port for this worker
68 | worker_port = GlobalSettings.get('worker_ports')[worker_index]
69 | else:
70 | # Single worker case - use all available GPUs and first port
71 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, available_gpus))
72 | worker_port = GlobalSettings.get('worker_ports')[0]
73 |
74 | # Configure VLLM environment
75 | os.environ['VLLM_HOST_IP'] = 'localhost'
76 | os.environ['VLLM_PORT'] = str(worker_port)
77 |
78 | global model
79 | model = VLLMModel(
80 | model_path=cfg.model.path,
81 | max_input_tokens=cfg.max_input_tokens,
82 | max_output_tokens=cfg.max_output_tokens,
83 | tensor_parallel_size=cfg.tp,
84 | seed=cfg.random_seed
85 | )
86 |
87 |
88 | def predict_task() -> None:
89 | cfg = GlobalSettings.get('cfg')
90 |
91 | from sparse_frontier.modelling.attention.registry import configure_attention
92 | configure_attention()
93 |
94 | data = load_data_without_predictions()
95 |
96 | pred_path = get_pred_path()
97 |
98 | num_workers = cfg.gpus // cfg.tp
99 |
100 | # Get free ports for all workers at the start
101 | free_ports = get_free_ports(num_workers)
102 | GlobalSettings.set('worker_ports', free_ports)
103 |
104 | if num_workers == 1:
105 | # Single worker case - process samples sequentially
106 | init_worker()
107 | with open(pred_path, 'at', encoding="utf-8", buffering=1) as fout:
108 | for sample in tqdm(data, total=len(data)):
109 | sample_results = process_sample(sample)
110 | fout.write(json.dumps(sample_results) + '\n')
111 | fout.flush()
112 |
113 | global model
114 | model = None
115 | else:
116 | # Multi-worker case - process samples in parallel
117 | with ProcessPoolExecutor(max_workers=num_workers, initializer=init_worker) as executor, \
118 | open(pred_path, 'at', encoding="utf-8", buffering=1) as fout:
119 | futures = {executor.submit(process_sample, sample): sample for sample in data}
120 | for future in tqdm(as_completed(futures), total=len(data)):
121 | sample_results = future.result()
122 | fout.write(json.dumps(sample_results) + '\n')
123 | fout.flush()
124 |
125 | save_config(os.path.dirname(pred_path))
126 | print(f'Prediction for task {cfg.task.name} is done. Output is saved to {pred_path}.')
127 |
--------------------------------------------------------------------------------
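The GPU slicing in init_worker assigns each worker a contiguous block of cfg.tp devices. A small sketch with assumed values of 8 visible GPUs and tensor parallelism 2:

# Assumed example: cfg.gpus = 8, cfg.tp = 2  ->  4 workers with 2 GPUs each.
gpus, tp = 8, 2
available_gpus = list(range(gpus))
num_workers = gpus // tp
for worker_index in range(num_workers):
    worker_gpus = available_gpus[worker_index * tp:(worker_index + 1) * tp]
    print(worker_index, ",".join(map(str, worker_gpus)))
# prints: 0 -> "0,1", 1 -> "2,3", 2 -> "4,5", 3 -> "6,7"
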
/sparse_frontier/preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from sparse_frontier.utils import GlobalSettings
4 | from sparse_frontier.utils.data import write_jsonl, get_data_path
5 | from sparse_frontier.utils.general import save_config
6 | from sparse_frontier.tasks.registry import TASK_REGISTRY
7 |
8 |
9 | def check_args():
10 | from transformers import AutoTokenizer, AutoConfig
11 |
12 | cfg = GlobalSettings.get("cfg")
13 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.path, trust_remote_code=True)
14 | config = AutoConfig.from_pretrained(cfg.model.path, trust_remote_code=True)
15 |
16 | if tokenizer.model_max_length < cfg.max_input_tokens + cfg.max_output_tokens:
17 | raise ValueError(f"Model maximum sequence length ({tokenizer.model_max_length}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})")
18 |
19 | max_pos_embeddings = getattr(config, "max_position_embeddings", 131072)
20 | if max_pos_embeddings < cfg.max_input_tokens + cfg.max_output_tokens and 'qwen' not in cfg.model.name:
21 | raise ValueError(f"Model maximum position embeddings ({max_pos_embeddings}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})")
22 |
23 | seq_length = getattr(config, "seq_length", 131072)
24 | if seq_length < cfg.max_input_tokens + cfg.max_output_tokens:
25 | raise ValueError(f"Model maximum sequence length ({seq_length}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})")
26 |
27 |
28 | def get_task_generator():
29 | from sparse_frontier.modelling.tokenizer import Tokenizer
30 | cfg = GlobalSettings.get("cfg")
31 | task_kwargs = {
32 | 'num_samples': cfg.samples,
33 | 'max_input_tokens': cfg.max_input_tokens,
34 | 'max_output_tokens': cfg.max_output_tokens,
35 | 'tokenizer': Tokenizer(cfg.model.path, device='cpu'),
36 | 'random_seed': cfg.random_seed,
37 | **cfg.task.get('args', {}),
38 | }
39 | return TASK_REGISTRY[cfg.task.name](**task_kwargs)
40 |
41 |
42 | def prepare_task():
43 | cfg = GlobalSettings.get("cfg")
44 |
45 | # Explicitly enable tokenizers parallelism during preparation
46 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
47 |
48 | check_args()
49 |
50 | data_path = get_data_path()
51 |
52 | print(f"Preparing {cfg.task.name} with {cfg.samples} samples")
53 | generator = get_task_generator()
54 | samples = generator.generate_samples()
55 |
56 | write_jsonl(data_path, samples)
57 | save_config(os.path.dirname(data_path))
58 | print(f"Saved {cfg.task.name} with {cfg.samples} samples to {data_path}")
59 |
60 | # Disable tokenizers parallelism after preparation
61 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
62 |
--------------------------------------------------------------------------------
/sparse_frontier/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import TASK_REGISTRY
2 |
3 | __all__ = [
4 | 'TASK_REGISTRY'
5 | ]
6 |
--------------------------------------------------------------------------------
/sparse_frontier/tasks/abstract_prompt.py:
--------------------------------------------------------------------------------
1 | # Base template with shared components
2 | _BASE_TEMPLATE = """{task_intro}
3 |
4 | {question_intro}
5 |
6 | {question_section}
7 |
8 | <context>
9 | {context}
10 | </context>
11 |
12 | {question_repeat_section}
13 |
14 | Instructions:
15 | 1. First, provide a brief explanation of your reasoning process. Explain how you identified
16 | the relevant information from the context and how you determined your answer.
17 | 2. Then, provide your final answer following this exact format:
18 | <answer>
19 | {answer_format}
20 | </answer>
21 |
22 | Your response must follow this structure exactly:
23 | <explanation>
24 | Your explanation here...
25 | </explanation>
26 | <answer>
27 | Your answer here...
28 | </answer>
29 |
30 | Important:
31 | {extra_instructions}
32 | - Keep your explanations clear, coherent, concise, and to the point.
33 | - Do not include any additional text, explanations, or reasoning in the answer section. Follow the answer format exactly.
34 | """
35 |
36 | # Template for single question tasks
37 | SINGLEQ_PROMPT_TEMPLATE = _BASE_TEMPLATE.format(
38 | task_intro="{task_intro}",
39 | context="{context}",
40 | question_intro="Below is your question. I will state it both before and after the context.",
41 |     question_section="<question>\n{question}\n</question>",
42 |     question_repeat_section="<question>\n{question}\n</question>",
43 | answer_format="{answer_format}",
44 | extra_instructions="{extra_instructions}"
45 | )
46 |
47 | # Template for multiple question tasks
48 | MULTIPLEQ_PROMPT_TEMPLATE = _BASE_TEMPLATE.format(
49 | task_intro="{task_intro}",
50 | context="{context}",
51 | question_intro="Below are your questions. I will state them both before and after the context.",
52 |     question_section="<question>\n{question}\n</question>",
53 |     question_repeat_section="<question>\n{question}\n</question>",
54 | answer_format="{answer_format}",
55 | extra_instructions="{extra_instructions}"
56 | )
57 |
--------------------------------------------------------------------------------
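The base template is formatted twice: once above to fix the question framing, and once per sample by the task code. A toy example of the second stage with made-up field values:

from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE

prompt = SINGLEQ_PROMPT_TEMPLATE.format(
    task_intro="You will be given a short document.",   # made-up task intro
    question="What colour is the sky?",
    context="The sky is blue. The grass is green.",
    answer_format="A single word.",
    extra_instructions="- Answer with one word only.",
)
print(prompt)
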
/sparse_frontier/tasks/abstract_sample.py:
--------------------------------------------------------------------------------
1 | import random
2 | from abc import ABC, abstractmethod
3 | from typing import Any, Dict
4 |
5 | from sparse_frontier.modelling.tokenizer import Tokenizer
6 |
7 |
8 | class AbstractSample(ABC):
9 | """Base class for task samples.
10 |
11 | Handles common sample functionality like length validation and conversion to dict format.
12 | Subclasses must implement _generate_sample().
13 | """
14 |
15 | def __init__(
16 | self,
17 | sample_id: int,
18 | random_seed: int,
19 | max_tokens: int,
20 | tokenizer: Tokenizer,
21 | task_params: Dict[str, Any]
22 | ) -> None:
23 | """Initialize sample parameters.
24 |
25 | Args:
26 | sample_id: Sample index number
27 | random_seed: Random seed for reproducibility
28 | max_tokens: Maximum input sequence length in tokens
29 | tokenizer: Tokenizer for text encoding/decoding
30 | task_params: Dictionary of task parameters
31 | """
32 | self.sample_id = sample_id
33 | self.max_tokens = max_tokens
34 | self.random_obj = random.Random(random_seed + sample_id)
35 | self.tokenizer = tokenizer
36 | self.task_params = task_params
37 | self.input_text, self.gold_answer, self.extra_data = self._generate_sample()
38 |
39 | @abstractmethod
40 | def _generate_sample(self) -> tuple[str, str, Dict[str, Any]]:
41 | """Generate the input text, gold answer and extra data for this sample.
42 |
43 | Returns:
44 | Tuple containing:
45 | - Input text string to be provided to the model
46 | - Expected gold answer string
47 | - Dictionary of extra data to include in to_dict output
48 | """
49 | pass
50 |
51 | def to_dict(self) -> Dict[str, Any]:
52 | """Convert sample to dictionary format.
53 |
54 | Returns:
55 | Dictionary containing sample data and any extra data from _generate_sample
56 | """
57 | # Handle both single string and list of strings for gold_answer
58 | reference_answer = self.gold_answer[0] if isinstance(self.gold_answer, list) else self.gold_answer
59 |
60 | total_sample_length = len(self.tokenizer.encode_for_generation(self.input_text, return_tensors=False)['input_ids']) + \
61 | len(self.tokenizer.text_to_tokens(reference_answer))
62 |
63 | base_dict = {
64 | "index": self.sample_id,
65 | "input_text": self.input_text,
66 | "gold_answer": self.gold_answer,
67 | "length": total_sample_length
68 | }
69 | base_dict.update(self.extra_data)
70 | return base_dict
71 |
--------------------------------------------------------------------------------
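A minimal sketch of the subclass contract, using a stub whitespace tokenizer so it runs without downloading a model; both the stub and the toy task are illustrative:

from typing import Any, Dict, Tuple

from sparse_frontier.tasks.abstract_sample import AbstractSample


class _StubTokenizer:
    """Whitespace 'tokenizer' standing in for the real Tokenizer class."""

    def text_to_tokens(self, text: str):
        return text.split()

    def encode_for_generation(self, text: str, return_tensors: bool = False) -> Dict:
        ids = self.text_to_tokens(text)
        return {"input_ids": ids, "input_length": len(ids)}


class EchoSample(AbstractSample):
    def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
        word = self.task_params["word"]
        return f"Repeat the word: {word}", word, {"word": word}


sample = EchoSample(sample_id=0, random_seed=43, max_tokens=64,
                    tokenizer=_StubTokenizer(), task_params={"word": "apple"})
print(sample.to_dict()["length"])  # prompt tokens plus gold-answer tokens
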
/sparse_frontier/tasks/abstract_task.py:
--------------------------------------------------------------------------------
1 | import random
2 | from abc import ABC, abstractmethod
3 | from typing import Any, Dict, List
4 |
5 | from sparse_frontier.modelling.tokenizer import Tokenizer
6 |
7 |
8 | class AbstractTask(ABC):
9 | """Base class for defining evaluation tasks.
10 |
11 | Handles common task functionality like sample generation and evaluation.
12 | Subclasses must implement check_inputs(), generate_samples() and evaluate().
13 | """
14 |
15 | def __init__(
16 | self,
17 | num_samples: int,
18 | max_input_tokens: int,
19 | max_output_tokens: int,
20 | tokenizer: Tokenizer,
21 | random_seed: int,
22 | template_tokens: int = 64,
23 | **kwargs
24 | ) -> None:
25 | """Initialize task parameters.
26 |
27 | Args:
28 | num_samples: Number of samples to generate
29 | max_input_tokens: Maximum input sequence length in tokens
30 | max_output_tokens: Maximum output sequence length in tokens
31 | tokenizer: Tokenizer for text encoding/decoding
32 | random_seed: Random seed for reproducibility
33 | template_tokens: Approximate number of tokens for model's template
34 | """
35 | self.num_samples = num_samples
36 | self.max_input_tokens = max_input_tokens
37 | self.max_output_tokens = max_output_tokens
38 | self.tokenizer = tokenizer
39 | self.random_seed = random_seed
40 | self.template_tokens = template_tokens
41 | self.task_params = {}
42 | self.random_obj = random.Random(self.random_seed)
43 |
44 |     def __getattr__(self, name: str) -> Any:
45 |         """Fall back to task parameters for attribute lookup."""
46 |         if name in self.__dict__.get('task_params', {}):
47 |             return self.__dict__['task_params'][name]
48 |         raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
49 |
50 | def check_sample_length(self, input_text: str, gold_answer: str | list[str]) -> None:
51 | """Validate input and output sequence lengths.
52 |
53 | Args:
54 | input_text: Model input text
55 | gold_answer: Expected model output, either a single string or list of possible answers
56 |
57 | Raises:
58 | AssertionError: If sequences exceed maximum allowed lengths
59 | """
60 | # Handle both single string and list of strings
61 | if isinstance(gold_answer, str):
62 | gold_answers = [gold_answer]
63 | else:
64 | gold_answers = gold_answer
65 |
66 | # Check length of each possible answer
67 | for answer in gold_answers:
68 | gold_length = len(self.tokenizer.text_to_tokens(answer))
69 | if gold_length > self.max_output_tokens:
70 | raise AssertionError(
71 | f"Gold answer too long: {gold_length} tokens, max {self.max_output_tokens}"
72 | )
73 |
74 | input_length = len(self.tokenizer.text_to_tokens(input_text))
75 | if input_length > self.max_input_tokens - self.template_tokens:
76 | raise AssertionError(
77 | f"Only input too long: {input_length} tokens, max {self.max_input_tokens - self.template_tokens} tokens."
78 | )
79 |
80 | final_input = self.tokenizer.encode_for_generation(input_text, return_tensors=False)
81 | final_input_length = len(final_input["input_ids"])
82 |
83 | if final_input_length > self.max_input_tokens:
84 | raise AssertionError(
85 | f"Input + Template too long: {final_input_length} tokens, max {self.max_input_tokens} tokens."
86 | )
87 |
88 | if final_input_length < 0.90 * self.max_input_tokens:
89 | raise AssertionError(
90 | f"Input + Template too short: {final_input_length} tokens, min {int(0.90 * self.max_input_tokens)} tokens."
91 | )
92 |
93 | @abstractmethod
94 | def check_params(self) -> None:
95 | """Validate task-specific parameters.
96 |
97 | Raises:
98 | ValueError: If parameters are missing or invalid
99 | AssertionError: If parameters fail validation checks
100 | """
101 | pass
102 |
103 | @property
104 | @abstractmethod
105 | def sample_class(self):
106 | """Return the sample class to use for this task.
107 |
108 | Returns:
109 | A class that inherits from AbstractSample
110 | """
111 | pass
112 |
113 | def generate_samples(self) -> List[Dict[str, Any]]:
114 | """Generate task evaluation samples.
115 |
116 | Returns:
117 | List of sample dicts containing:
118 | input_text: Model input text
119 | gold_answer: Expected model output
120 | index: Sample index
121 | length: Input sequence length
122 | + Additional task-specific fields
123 | """
124 | samples = []
125 | for i in range(self.num_samples):
126 | sample = self.sample_class(
127 | sample_id=i,
128 | random_seed=self.random_seed,
129 | max_tokens=self.max_input_tokens - self.template_tokens,
130 | tokenizer=self.tokenizer,
131 | task_params=self.task_params,
132 | )
133 | self.check_sample_length(sample.input_text, sample.gold_answer)
134 | samples.append(sample.to_dict())
135 |
136 | return samples
137 |
138 | @staticmethod
139 | @abstractmethod
140 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, Any]:
141 | """Evaluate model predictions against gold answers.
142 |
143 | Args:
144 | predictions: List of sample dicts with model predictions in 'pred' field
145 |
146 | Returns:
147 | Dict of evaluation metrics
148 | """
149 | pass
150 |
--------------------------------------------------------------------------------
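The length checks above pin every sample into a narrow window. With assumed settings of max_input_tokens=16384 and template_tokens=64, the bounds come out as follows:

# Assumed example values; the real ones come from the task and model configs.
max_input_tokens, template_tokens = 16384, 64

raw_input_budget = max_input_tokens - template_tokens  # 16320 tokens for the prompt body
upper_bound = max_input_tokens                         # prompt plus chat template must fit
lower_bound = int(0.90 * max_input_tokens)             # 14745: samples must nearly fill the window
print(raw_input_budget, lower_bound, upper_bound)
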
/sparse_frontier/tasks/qa/qa_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import List, Dict, Tuple
4 | from abc import ABC, abstractmethod
5 |
6 |
7 | def get_dataset(dataset_name: str):
8 | """Get initialized QA dataset based on dataset name.
9 |
10 | Args:
11 | dataset_name: Name of dataset to load ('squad', 'quality', or 'toeflqa')
12 |
13 | Returns:
14 | Initialized QA dataset with data loaded
15 | """
16 | # Get path to data directory relative to this file
17 | current_dir = Path(__file__).parent
18 | data_dir = current_dir / 'data'
19 |
20 | if dataset_name == "squad":
21 | data_path = data_dir / 'squad.json'
22 | return SQuADDataset(str(data_path))
23 | elif dataset_name == "quality":
24 | data_path = data_dir / 'quality.jsonl'
25 | return QualityDataset(str(data_path))
26 | elif dataset_name == "toeflqa":
27 | data_path = data_dir / 'toeflqa.jsonl'
28 | return TOEFLQADataset(str(data_path))
29 |     raise ValueError(f"Unknown QA dataset: {dataset_name}")
30 |
31 | class QADataset(ABC):
32 | """Abstract base class defining the interface for QA dataset parsers.
33 |
34 | All QA dataset parsers should inherit from this class and implement the read_data method.
35 |
36 | Output Format:
37 | - qa_samples: List[Dict] where each dict contains:
38 | - context_idx: int, index into unique_contexts list
39 | - question: str, the question text
40 | - answer: Union[List[str], str], either:
41 | - List[str] for open-ended QA (e.g. SQuAD)
42 | - str for multiple choice QA (e.g. Quality, TOEFL)
43 | - unique_contexts: List[str], list of all unique context passages
44 | """
45 |
46 | def __init__(self, data_path: str):
47 | """Initialize the QA dataset parser.
48 | """
49 | self.qa_samples, self.unique_contexts = self.read_data(data_path)
50 | self._remove_duplicate_questions()
51 |
52 | def _remove_duplicate_questions(self):
53 | """Remove samples with duplicate question-context pairs."""
54 | # Create a set to track seen question-context pairs
55 | seen_pairs = set()
56 | filtered_samples = []
57 |
58 | for sample in self.qa_samples:
59 | # Create a unique key from question and context
60 | key = (sample['question'].strip().lower(), sample['context_idx'])
61 |
62 | # Only keep samples with unique question-context pairs
63 | if key not in seen_pairs:
64 | seen_pairs.add(key)
65 | filtered_samples.append(sample)
66 |
67 | # Update qa_samples with deduplicated list
68 | self.qa_samples = filtered_samples
69 |
70 | @abstractmethod
71 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]:
72 | """Read and process a QA dataset file.
73 |
74 | Args:
75 | data_path: Path to dataset file
76 |
77 | Returns:
78 | Tuple containing:
79 | - List of processed QA samples with context indices
80 | - List of unique context documents
81 | """
82 | pass
83 |
84 | @property
85 | @abstractmethod
86 | def is_mcq(self) -> bool:
87 | """Whether this dataset contains multiple choice questions."""
88 | pass
89 |
90 |
91 | class SQuADDataset(QADataset):
92 | """Parser for SQuAD format datasets."""
93 |
94 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]:
95 | with open(data_path) as f:
96 | data = json.load(f)
97 |
98 | # Extract and deduplicate contexts
99 | all_contexts = [p['context'] for d in data['data'] for p in d['paragraphs']]
100 | unique_contexts = sorted(list(set(all_contexts)))
101 | context_to_idx = {context: idx for idx, context in enumerate(unique_contexts)}
102 |
103 | qa_samples = []
104 | for article in data['data']:
105 | for paragraph in article['paragraphs']:
106 | curr_context_idx = context_to_idx[paragraph['context']]
107 |
108 | for qa in paragraph['qas']:
109 | if not qa['is_impossible']:
110 | qa_samples.append({
111 | 'context_idx': curr_context_idx,
112 | 'question': qa['question'],
113 | 'answer': [a['text'] for a in qa['answers']],
114 | })
115 |
116 | return qa_samples, unique_contexts
117 |
118 | @property
119 | def is_mcq(self) -> bool:
120 | return False
121 |
122 |
123 | class JSONLinesQADataset(QADataset):
124 | """Base parser for JSONL format QA datasets (Quality and TOEFL)."""
125 |
126 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]:
127 | with open(data_path) as f:
128 | data = [json.loads(line) for line in f]
129 |
130 | # First pass: collect unique contexts
131 | unique_contexts = sorted(list({item['context'] for item in data}))
132 | context_to_idx = {context: idx for idx, context in enumerate(unique_contexts)}
133 |
134 | # Second pass: create QA samples
135 | qa_samples = []
136 | for item in data:
137 | context_idx = context_to_idx[item['context']]
138 | questions = item['questions']
139 | answers = item['answer']
140 |
141 | # Verify questions and answers match in length
142 | assert len(questions) == len(answers), \
143 | f"Mismatch in questions ({len(questions)}) and answers ({len(answers)}) length"
144 |
145 | # Create individual QA samples
146 | for q, a in zip(questions, answers):
147 | # For TOEFL QA, remove "Question: " prefix if present
148 | if isinstance(self, TOEFLQADataset) and q.startswith("Question: "):
149 | q = q[len("Question: "):]
150 |
151 | qa_samples.append({
152 | 'context_idx': context_idx,
153 | 'question': q,
154 | 'answer': a,
155 | })
156 |
157 | return qa_samples, unique_contexts
158 |
159 |
160 | class QualityDataset(JSONLinesQADataset):
161 | """Parser for Quality format datasets."""
162 |
163 | @property
164 | def is_mcq(self) -> bool:
165 | return True
166 |
167 |
168 | class TOEFLQADataset(JSONLinesQADataset):
169 | """Parser for TOEFL QA format datasets."""
170 |
171 | @property
172 | def is_mcq(self) -> bool:
173 | return True
174 |
--------------------------------------------------------------------------------
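A short sketch of the documented output format, assuming the code is run from a source checkout so the bundled files under tasks/qa/data are available:

from sparse_frontier.tasks.qa.qa_data import get_dataset

ds = get_dataset("squad")        # open-ended QA, answers are lists of strings
sample = ds.qa_samples[0]
print(sample["question"])
print(sample["answer"])          # list of acceptable answer strings for SQuAD
print(ds.unique_contexts[sample["context_idx"]][:80])  # passage this question refers to
print(ds.is_mcq)                 # False for SQuAD, True for Quality and TOEFL
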
/sparse_frontier/tasks/qa/qa_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from typing import Set
4 |
5 |
6 | def normalize_mcq_answer(answer: str) -> str:
7 | """Normalize multiple choice answer to extract just the letter choices.
8 |
9 | Args:
10 | answer: Raw answer string that may contain explanation tags and other text
11 |
12 | Returns:
13 | Normalized string containing only the answer letters (e.g., 'A' or 'AB')
14 | """
15 | # First try to extract answer from tags if present
16 |     answer_match = re.search(r'<answer>(.*?)</answer>', answer, re.DOTALL)
17 | if answer_match:
18 | answer = answer_match.group(1)
19 |
20 | # Remove any "Question X" prefix
21 | answer = re.sub(r'Question\s*\d+\.?\s*', '', answer)
22 |
23 | # Clean up answer text
24 | answer = answer.upper()
25 | answer = re.sub(r'[^A-D]', '', answer)
26 |
27 | # Sort multiple choices to ensure consistent ordering (e.g., 'BA' -> 'AB')
28 | answer = ''.join(sorted(answer))
29 |
30 | return answer
31 |
32 |
33 | def extract_tagged_response(text: str) -> tuple[str, str]:
34 | """Extract explanation and answer from tagged response.
35 |
36 | Args:
37 | text: Full response text containing tagged sections
38 |
39 | Returns:
40 | Tuple of (explanation text, answer text)
41 | Returns empty strings if tags are not found
42 | """
43 | explanation = ''
44 | answer = ''
45 |
46 |     explanation_match = re.search(r'<explanation>(.*?)</explanation>', text, re.DOTALL)
47 | if explanation_match:
48 | explanation = explanation_match.group(1).strip()
49 |
50 |     answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
51 | if answer_match:
52 | answer = answer_match.group(1).strip()
53 |
54 | return explanation, answer
55 |
56 |
57 | def normalize_answer(s: str) -> str:
58 | """Normalize open-ended answer for comparison.
59 |
60 | Args:
61 | s: Raw answer string
62 |
63 | Returns:
64 | Normalized answer with consistent formatting
65 | """
66 | def remove_articles(text):
67 | return re.sub(r"\b(a|an|the|and|or|about|to)\b", " ", text)
68 |
69 | def white_space_fix(text):
70 | return " ".join(text.split())
71 |
72 | def remove_punc(text):
73 | exclude = set(string.punctuation)
74 | return "".join(ch for ch in text if ch not in exclude)
75 |
76 | # Extract answer from tags if present
77 |     answer_match = re.search(r'<answer>(.*?)</answer>', s, re.DOTALL)
78 | if answer_match:
79 | s = answer_match.group(1)
80 |
81 | return white_space_fix(remove_articles(remove_punc(s.lower())))
82 |
83 |
84 | def get_token_overlap(pred: str, gold: str) -> tuple[Set[str], Set[str]]:
85 | """Get overlapping tokens between prediction and gold answer.
86 |
87 | Args:
88 | pred: Predicted answer text
89 | gold: Gold answer text
90 |
91 | Returns:
92 | Tuple of (prediction tokens, gold tokens)
93 | """
94 | pred_tokens = set(pred.split())
95 | gold_tokens = set(gold.split())
96 | return pred_tokens, gold_tokens
97 |
98 |
99 | def f1_score(prediction: str, gold: str) -> float:
100 | """Calculate F1 score between prediction and gold answer.
101 |
102 | Args:
103 | prediction: Predicted answer text
104 | gold: Gold answer text
105 |
106 | Returns:
107 | F1 score between 0 and 1
108 | """
109 | prediction_tokens, gold_tokens = get_token_overlap(prediction, gold)
110 |
111 | true_positives = len(prediction_tokens & gold_tokens)
112 | false_positives = len(prediction_tokens - gold_tokens)
113 | false_negatives = len(gold_tokens - prediction_tokens)
114 |
115 | precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
116 | recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
117 |
118 | if precision + recall == 0:
119 | return 0.0
120 |
121 | f1 = 2 * (precision * recall) / (precision + recall)
122 | return round(f1, 2)
123 |
--------------------------------------------------------------------------------
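A self-contained example of the scoring helpers; the strings are made up, but the numbers follow directly from the functions above:

from sparse_frontier.tasks.qa.qa_utils import normalize_mcq_answer, normalize_answer, f1_score

print(normalize_mcq_answer("<answer>B, A</answer>"))  # 'AB' (letters only, sorted)

pred = normalize_answer("<answer>The answer is Paris.</answer>")  # 'answer is paris'
gold = normalize_answer("Paris")                                  # 'paris'
print(f1_score(pred, gold))  # 0.5 (precision 1/3, recall 1.0)
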
/sparse_frontier/tasks/registry.py:
--------------------------------------------------------------------------------
1 | from .qa.qa_task import QATask
2 | from .ruler.cwe import CommonWordTask
3 | from .ruler.niah import NIAHTask
4 | from .ruler.vt import VariableTrackingTask
5 | from .story.multihop import MultiHopTask
6 | from .story.filtering import FilteringTask
7 | from .story.retrieval import RetrievalTask
8 |
9 |
10 | TASK_REGISTRY = {
11 | "qa_quality": QATask,
12 | "qa_squad": QATask,
13 | "qa_toefl": QATask,
14 | "ruler_cwe": CommonWordTask,
15 | "ruler_niah": NIAHTask,
16 | "ruler_vt": VariableTrackingTask,
17 | "story_multihop": MultiHopTask,
18 | "story_filtering": FilteringTask,
19 | "story_retrieval": RetrievalTask,
20 | }
21 |
--------------------------------------------------------------------------------
/sparse_frontier/tasks/ruler/cwe.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating common word frequency tracking capabilities of language models.
2 |
3 | This module implements a task where models need to identify frequently occurring
4 | words in a list containing both common and rare words. The task tests the model's ability to:
5 | 1. Track word frequencies in a long list
6 | 2. Identify words that appear more frequently than others
7 | 3. Ignore less frequent distractors
8 | """
9 |
10 | from typing import List, Dict, Any, Tuple
11 | import re
12 | import numpy as np
13 |
14 | from sparse_frontier.tasks.abstract_task import AbstractTask
15 | from sparse_frontier.tasks.abstract_sample import AbstractSample
16 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE
17 |
18 | TASK_INTRO = """You will be given a numbered list of words. Your task is to identify the most frequently occurring words. You should solve this task by carefully reading and analyzing the word list. Do not attempt to write code or use programming tools to count frequencies. This is a test of your ability to track word frequencies directly."""
19 |
20 | QUESTION_TEMPLATE = """The list contains exactly {num_common} words that appear {common_freq} times each. All other words appear {rare_freq} times each. The order of words in the list is randomized.
21 | Your task is to identify the {num_common} words that appear {common_freq} times each."""
22 |
23 | ANSWER_FORMAT = """1. word_one
24 | 2. word_two
25 | ...
26 | {num_common}. word_{num_common}
27 |
28 | Note: List exactly {num_common} words, one per line, numbered from 1 to {num_common}."""
29 |
30 |
31 | class CommonWordSample(AbstractSample):
32 | """Handles generation of individual common word frequency tracking samples."""
33 |
34 | def __init__(self, *args, **kwargs):
35 | super().__init__(*args, **kwargs)
36 |
37 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
38 | """Generate a single common word frequency tracking sample."""
39 |
40 | words = self.task_params['words'].copy()
41 | self.random_obj.shuffle(words)
42 |
43 | common_words = words[:self.task_params['num_common_words']]
44 | rare_words = words[self.task_params['num_common_words']:]
45 |
46 | # Create list of common words with repetitions
47 | common_word_list = common_words * self.task_params['common_word_frequency']
48 |
49 | # Binary search to find maximum number of rare words that fit
50 | left, right = 0, len(rare_words)
51 | while left < right:
52 | mid = (left + right + 1) // 2
53 |
54 | # Create test word list with current number of rare words
55 | test_words = (
56 | common_word_list +
57 | rare_words[:mid] * self.task_params['rare_word_frequency']
58 | )
59 | test_context = '\n'.join(f"{i+1}. {word}" for i, word in enumerate(test_words))
60 |
61 | # Format question with current parameters
62 | question = QUESTION_TEMPLATE.format(
63 | num_common=self.task_params['num_common_words'],
64 | common_freq=self.task_params['common_word_frequency'],
65 | rare_freq=self.task_params['rare_word_frequency']
66 | )
67 |
68 | # Calculate total tokens with full prompt template
69 | test_input = SINGLEQ_PROMPT_TEMPLATE.format(
70 | task_intro=TASK_INTRO,
71 | context=test_context,
72 | question=question,
73 | answer_format=ANSWER_FORMAT.format(
74 | num_common=self.task_params['num_common_words']
75 | ),
76 | extra_instructions=""
77 | )
78 | total_tokens = len(self.tokenizer.text_to_tokens(test_input))
79 |
80 | if total_tokens <= self.max_tokens:
81 | left = mid
82 | else:
83 | right = mid - 1
84 |
85 | num_rare_words = left
86 |
87 | # Ensure we have space for at least one rare word
88 | assert num_rare_words > 0, "No space for rare words after common words"
89 |
90 | # Create final word list with optimal number of rare words
91 | final_words = (
92 | common_word_list +
93 | rare_words[:num_rare_words] * self.task_params['rare_word_frequency']
94 | )
95 | self.random_obj.shuffle(final_words)
96 |
97 | # Format context and question
98 | context = '\n'.join(f"{i+1}. {word}" for i, word in enumerate(final_words))
99 | question = QUESTION_TEMPLATE.format(
100 | num_common=self.task_params['num_common_words'],
101 | common_freq=self.task_params['common_word_frequency'],
102 | rare_freq=self.task_params['rare_word_frequency']
103 | )
104 |
105 | input_text = SINGLEQ_PROMPT_TEMPLATE.format(
106 | task_intro=TASK_INTRO,
107 | context=context,
108 | question=question,
109 | answer_format=ANSWER_FORMAT.format(
110 | num_common=self.task_params['num_common_words']
111 | ),
112 | extra_instructions=""
113 | )
114 |
115 | gold_answer = '\n'.join(f"{i+1}. {word.lower()}" for i, word in enumerate(common_words))
116 |
117 | extra_data = {
118 | "num_total_words": len(final_words),
119 | "common_words": [w.lower() for w in common_words],
120 | "common_frequency": self.task_params['common_word_frequency'],
121 | "rare_frequency": self.task_params['rare_word_frequency']
122 | }
123 |
124 | return input_text, gold_answer, extra_data
125 |
126 |
127 | class CommonWordTask(AbstractTask):
128 | """Task for evaluating common word frequency tracking capabilities."""
129 |
130 | def __init__(
131 | self,
132 | common_word_frequency: int = 30,
133 | rare_word_frequency: int = 3,
134 | num_common_words: int = 10,
135 | **kwargs
136 | ) -> None:
137 | """Initialize common word frequency tracking task.
138 |
139 | Args:
140 | common_word_frequency: Number of times each common word appears
141 | rare_word_frequency: Number of times each rare word appears
142 | num_common_words: Number of common words to identify
143 | **kwargs: Additional arguments passed to parent class
144 | """
145 | super().__init__(**kwargs)
146 | self._create_word_list()
147 | self.task_params.update({
148 | 'common_word_frequency': common_word_frequency,
149 | 'rare_word_frequency': rare_word_frequency,
150 | 'num_common_words': num_common_words,
151 | 'words': self.words,
152 | })
153 | self.check_params()
154 |
155 | def _create_word_list(self) -> None:
156 | from wonderwords import random_word
157 | nouns = random_word._get_words_from_text_file("nounlist.txt")
158 | adjs = random_word._get_words_from_text_file("adjectivelist.txt")
159 | verbs = random_word._get_words_from_text_file("verblist.txt")
160 | words = nouns + adjs + verbs
161 | self.words = sorted(list(set(words)))
162 | self.random_obj.shuffle(self.words)
163 | self.words = [word for word in self.words if '-' not in word]
164 |
165 | def check_params(self) -> None:
166 | """Validate task parameters."""
167 | if not isinstance(self.task_params.get('common_word_frequency'), int):
168 | raise ValueError("common_word_frequency must be an integer")
169 | if not isinstance(self.task_params.get('rare_word_frequency'), int):
170 | raise ValueError("rare_word_frequency must be an integer")
171 | if not isinstance(self.task_params.get('num_common_words'), int):
172 | raise ValueError("num_common_words must be an integer")
173 |
174 | if self.task_params['common_word_frequency'] <= self.task_params['rare_word_frequency']:
175 | raise ValueError("common_word_frequency must be greater than rare_word_frequency")
176 | if self.task_params['num_common_words'] < 1:
177 | raise ValueError("num_common_words must be at least 1")
178 |
179 | @property
180 | def sample_class(self):
181 | return CommonWordSample
182 |
183 | @staticmethod
184 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, float]:
185 | """Evaluate model predictions against gold answers.
186 |
187 | For each prediction, calculates the intersection over union (IoU) between
188 | the predicted set of common words and the gold set.
189 |
190 | Returns mean IoU across all predictions and IoU variance.
191 | """
192 | sample_ious = []
193 |
194 | for pred in predictions:
195 | # Extract answer section
196 |             answer_match = re.search(r'<answer>(.*?)</answer>', pred['pred'], re.DOTALL | re.IGNORECASE)
197 | if not answer_match:
198 | sample_ious.append(0.0)
199 | continue
200 |
201 | answer_text = answer_match.group(1).strip()
202 |
203 | # Extract words from prediction as a set, normalizing and handling edge cases
204 | pred_words = set()
205 | for line in answer_text.split('\n'):
206 | line = line.strip()
207 | if not line:
208 | continue
209 |
210 | # Match numbered lines, handling various formats (1., 1), 1-, etc.)
211 | match = re.match(r'^\d+[.)\-]?\s*(.+)$', line)
212 | if match:
213 | word = match.group(1).strip().lower()
214 | if word:
215 | pred_words.add(word)
216 |
217 | gold_words = set(pred['common_words'])
218 |
219 | # Calculate intersection over union
220 | intersection = len(pred_words & gold_words)
221 | union = len(pred_words | gold_words)
222 | iou = intersection / union if union > 0 else 0.0
223 | sample_ious.append(iou)
224 |
225 | # Calculate mean and variance
226 | mean_iou = np.mean(sample_ious) if sample_ious else 0.0
227 | # Use ddof=1 for unbiased estimate of the variance
228 | variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
229 |
230 | return {
231 | 'iou': mean_iou,
232 | 'iou_variance': variance
233 | }
234 |
--------------------------------------------------------------------------------
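The evaluation is set-based, so word order does not matter and extra predictions only hurt through the union term. A toy example with made-up word lists, assuming the package and its dependencies are installed:

from sparse_frontier.tasks.ruler.cwe import CommonWordTask

predictions = [{
    "pred": "<explanation>Counted the repeats.</explanation>\n"
            "<answer>\n1. apple\n2. banana\n</answer>",
    "common_words": ["apple", "cherry"],
}]
print(CommonWordTask.evaluate(predictions))
# mean IoU 1/3: intersection {apple}, union {apple, banana, cherry}; variance 0.0
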
/sparse_frontier/tasks/ruler/niah.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating needle-in-a-haystack (NIAH) capabilities of language models.
2 |
3 | This module implements a task where models need to extract specific key-value pairs from
4 | a document containing both relevant pairs and distractors. The task tests the model's
5 | ability to:
6 | 1. Follow precise instructions
7 | 2. Extract relevant information while ignoring distractors
8 | 3. Format responses according to a specified template
9 | 4. Provide clear reasoning for answers
10 | """
11 |
12 | import string
13 | import re
14 | from typing import List, Dict, Any, Tuple
15 | import numpy as np
16 |
17 | from sparse_frontier.tasks.abstract_task import AbstractTask
18 | from sparse_frontier.tasks.abstract_sample import AbstractSample
19 | from sparse_frontier.tasks.abstract_prompt import MULTIPLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
20 |
21 | # Task introduction and instructions
22 | TASK_INTRO = """I will provide you with a document containing multiple key-value pairs. Your task is to extract specific values associated with given keys."""
23 |
24 | ANSWER_FORMAT = """1. The answer for <key_1> is <value_1>.
25 | 2. The answer for <key_2> is <value_2>.
26 | etc."""
27 |
28 | EXTRA_INSTRUCTIONS = """
29 | - Provide answers in the exact order of the requested keys
30 | - Each answer must follow the format: "<number>. The answer for <key> is <value>."
31 | - Ensure exact key matches - do not modify or paraphrase the keys
32 | - Values must match exactly as they appear in the document
33 | """.strip()
34 |
35 |
36 | class Needle:
37 | """Represents a key-value pair in the NIAH task."""
38 |
39 | def __init__(self, random_obj, used_keys: set, used_values: set) -> None:
40 | self.key = self.generate_unique(used_keys, random_obj)
41 | self.value = self.generate_unique(used_values, random_obj)
42 |
43 | @staticmethod
44 | def generate_kv_part(random_obj, length: int = 8) -> str:
45 | return ''.join(random_obj.choices(string.ascii_lowercase + string.digits, k=length))
46 |
47 | @staticmethod
48 | def generate_kv(random_obj) -> str:
49 | parts = [Needle.generate_kv_part(random_obj) for _ in range(4)]
50 | return '-'.join(parts)
51 |
52 | @staticmethod
53 | def generate_unique(used_kv: set, random_obj) -> str:
54 | while True:
55 | key = Needle.generate_kv(random_obj)
56 | if key not in used_kv:
57 | used_kv.add(key)
58 | return key
59 |
60 | def to_sentence(self) -> str:
61 | return f"The value for key {self.key} is: {self.value}."
62 |
63 |
64 | class NIAHSample(AbstractSample):
65 | """Represents a single NIAH task sample with queries and distractors."""
66 |
67 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
68 | """Generate the input text, gold answer and extra data for this sample."""
69 | num_queries = self.task_params['num_queries']
70 | used_keys = set()
71 | used_values = set()
72 | needles = [
73 | Needle(self.random_obj, used_keys, used_values)
74 | for _ in range(num_queries)
75 | ]
76 |
77 | # Generate query sentences and shuffle
78 | queries_sentences = [needle.to_sentence() for needle in needles]
79 | self.random_obj.shuffle(queries_sentences)
80 |
81 | def get_current_token_count(sentences: List[str], keys: List[str]) -> int:
82 | question = f"Extract the values for the following keys: {', '.join(keys)}"
83 | text = PROMPT_TEMPLATE.format(
84 | task_intro=TASK_INTRO,
85 | context=" ".join(sentences),
86 | question=question,
87 | answer_format=ANSWER_FORMAT,
88 | extra_instructions=EXTRA_INSTRUCTIONS
89 | )
90 | return len(self.tokenizer.text_to_tokens(text))
91 |
92 | current_token_count = get_current_token_count(
93 | queries_sentences,
94 | [needle.key for needle in needles]
95 | )
96 | tokens_needed = self.max_tokens - current_token_count
97 |
98 | if tokens_needed < 0:
99 | raise ValueError(
100 | f"Needles are too long. Current length: {current_token_count}, "
101 | f"maximum length: {self.max_tokens}"
102 | )
103 |
104 | # Generate distractor sentences
105 | distractors_sentences = []
106 | while True:
107 | new_distractor = Needle(self.random_obj, used_keys, used_values)
108 | distractor_sentence = new_distractor.to_sentence()
109 | distractor_token_length = len(self.tokenizer.text_to_tokens(distractor_sentence))
110 | if tokens_needed < distractor_token_length:
111 | break
112 | distractors_sentences.append(distractor_sentence)
113 | tokens_needed -= distractor_token_length
114 |
115 | # Combine and shuffle all sentences
116 | all_sentences = queries_sentences + distractors_sentences
117 | self.random_obj.shuffle(all_sentences)
118 |
119 | # Generate question
120 | keys = [needle.key for needle in needles]
121 | question = f"Extract the values for the following keys: {', '.join(keys)}"
122 |
123 | # Format input using template
124 | input_text = PROMPT_TEMPLATE.format(
125 | task_intro=TASK_INTRO,
126 | context=" ".join(all_sentences),
127 | question=question,
128 | answer_format=ANSWER_FORMAT,
129 | extra_instructions=EXTRA_INSTRUCTIONS
130 | )
131 |
132 | # Generate gold answer
133 | gold_answer = "\n".join(
134 | f"{i+1}. The answer for {needle.key} is {needle.value}."
135 | for i, needle in enumerate(needles)
136 | )
137 |
138 | extra_data = {
139 | "answers": [(needle.key, needle.value) for needle in needles]
140 | }
141 |
142 | return input_text, gold_answer, extra_data
143 |
144 |
145 | class NIAHTask(AbstractTask):
146 | """Main task class for the Needle-in-a-Haystack evaluation."""
147 |
148 | def __init__(
149 | self,
150 | num_queries: int,
151 | **kwargs
152 | ) -> None:
153 | super().__init__(**kwargs)
154 | self.task_params['num_queries'] = num_queries
155 | self.check_params()
156 |
157 | def check_params(self) -> None:
158 | if 'num_queries' not in self.task_params:
159 | raise ValueError("Missing required parameter 'num_queries'")
160 |
161 | if not isinstance(self.task_params['num_queries'], int):
162 | raise ValueError("Parameter 'num_queries' must be an integer")
163 |
164 | if self.task_params['num_queries'] < 1:
165 | raise AssertionError("Parameter 'num_queries' must be greater than or equal to 1")
166 |
167 | @property
168 | def sample_class(self):
169 | return NIAHSample
170 |
171 | @staticmethod
172 | def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]:
173 | """Evaluates model predictions against gold answers."""
174 | def normalize_answer(text: str) -> str:
175 | """Normalize answer text for comparison."""
176 | # Convert to lowercase and remove extra whitespace
177 | text = re.sub(r'\s+', ' ', text.lower().strip())
178 | # Remove optional colon after "is"
179 | text = re.sub(r'\bis:\s+', 'is ', text)
180 | return text
181 |
182 |         def extract_answers(text: str) -> Dict[int, Tuple[str, str]]:
183 |             """Extract answers from text, handling both formats."""
184 |             # First try to find the <answer> section
185 |             answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
186 |             if answer_match:
187 |                 text = answer_match.group(1)
188 |
189 |             # Extract numbered answers with key-value pairs
190 |             answers = {}
191 |             pattern = re.compile(
192 |                 r'(\d+)\.\s*The answer for\s+([\w-]+)\s+is:?\s+(.+?)(?:\.|$)',
193 |                 re.IGNORECASE | re.MULTILINE
194 |             )
195 |
196 |             for match in pattern.finditer(text):
197 |                 idx = int(match.group(1))
198 |                 key = match.group(2).strip()
199 |                 value = match.group(3).strip()
200 |                 if idx not in answers: # Take first occurrence if duplicates
201 |                     answers[idx] = (key, normalize_answer(value))
202 |
203 |             return answers
204 |
205 | sample_accuracies = []
206 |
207 | for example in examples:
208 | answers = example['answers']
209 | prediction = example['pred']
210 |
211 | # Get gold answer key-value pairs
212 | gold_pairs = {
213 | i + 1: (key, normalize_answer(value))
214 | for i, (key, value) in enumerate(answers)
215 | }
216 |
217 | # Extract predicted key-value pairs
218 | pred_pairs = extract_answers(prediction)
219 |
220 | # Compare predictions
221 | correct = 0
222 | total = len(gold_pairs)
223 |
224 | for idx, (gold_key, gold_value) in gold_pairs.items():
225 | if (idx in pred_pairs and
226 | pred_pairs[idx][0] == gold_key and
227 | pred_pairs[idx][1] == gold_value):
228 | correct += 1
229 |
230 | # Calculate accuracy for this sample
231 | sample_accuracies.append(correct / total if total > 0 else 0.0)
232 |
233 | # Calculate mean and variance
234 | mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0
235 | # Use ddof=1 for unbiased estimate of the variance
236 | variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0
237 |
238 | return {
239 | 'accuracy': mean_accuracy,
240 | 'accuracy_variance': variance
241 | }
242 |
--------------------------------------------------------------------------------
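A toy check of the answer parsing in evaluate(), using a hand-written key-value pair instead of the generated 4x8-character ones:

from sparse_frontier.tasks.ruler.niah import NIAHTask

examples = [{
    "answers": [("k1-aaaa", "v1-bbbb")],  # (key, value) pairs stored by the sample generator
    "pred": "<answer>\n1. The answer for k1-aaaa is v1-bbbb.\n</answer>",
}]
print(NIAHTask.evaluate(examples))  # accuracy 1.0, variance 0.0
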
/sparse_frontier/tasks/ruler/vt.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating variable tracking capabilities of language models.
2 |
3 | This module implements a task where models need to track chains of variable assignments
4 | in a document containing both relevant assignments and distractors. The task tests the model's
5 | ability to:
6 | 1. Track multiple variable assignments through chains
7 | 2. Identify all variables that eventually get assigned a specific value
8 | 3. Ignore irrelevant distractors and noise
9 | """
10 |
11 | from typing import List, Dict, Any, Tuple
12 | import string
13 | import re
14 | import numpy as np
15 |
16 | from sparse_frontier.tasks.abstract_task import AbstractTask
17 | from sparse_frontier.tasks.abstract_sample import AbstractSample
18 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE
19 |
20 | NOISE_SENTENCE = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
21 |
22 | # Task introduction and instructions
23 | TASK_INTRO = """I will provide you with a text containing variable assignments. The text contains two types of assignments:
24 | 1. Numeric assignments that set a variable to a number (e.g., "VAR ABC = 12345")
25 | 2. Copy assignments that set a variable equal to another variable (e.g., "VAR XYZ = VAR ABC")
26 | Variables are sequences of uppercase letters. The assignments can appear in any order in the text."""
27 |
28 | ANSWER_FORMAT = "VARIABLE_ONE VARIABLE_TWO etc."
29 |
30 | EXTRA_INSTRUCTIONS = """
31 | - List ONLY the variable names that resolve to the target value.
32 | - Variables can be listed in any order.
33 | - Do not include "VAR" prefix in your answer. Do not include punctuation.
34 | """.strip()
35 |
36 | QUESTION_TEMPLATE = "Which variables resolve to the value {target_value}? A variable resolves to {target_value} if it is either directly assigned {target_value}, or assigned to another variable that resolves to {target_value}."
37 |
38 |
39 | class VariableTrackingSample(AbstractSample):
40 | """Handles generation of individual variable tracking samples."""
41 |
42 | def _generate_random_var(self) -> str:
43 | """Generate a random 5-letter uppercase variable name."""
44 | return ''.join(self.random_obj.choices(string.ascii_uppercase, k=5))
45 |
46 | def _generate_unique_vars(self, num_vars: int) -> List[str]:
47 | """Generate a list of unique variable names.
48 |
49 | Args:
50 | num_vars: Number of unique variable names needed
51 |
52 | Returns:
53 | List of unique variable names
54 | """
55 | unique_vars = []
56 | while len(unique_vars) < num_vars:
57 | new_var = self._generate_random_var()
58 | if new_var not in unique_vars:
59 | unique_vars.append(new_var)
60 | return unique_vars
61 |
62 | def _create_chain(self, vars_for_chain: List[str], initial_value: int) -> List[str]:
63 | """Create a single chain of variable assignments.
64 |
65 | Args:
66 | vars_for_chain: List of variables to use in the chain
67 | initial_value: Starting numeric value for the chain
68 |
69 | Returns:
70 | List of assignment statements forming the chain
71 | """
72 | chain = [f"VAR {vars_for_chain[0]} = {initial_value}"]
73 | for i in range(len(vars_for_chain) - 1):
74 | chain.append(f"VAR {vars_for_chain[i+1]} = VAR {vars_for_chain[i]}")
75 | return chain
76 |
77 | def _generate_chains(self) -> Tuple[List[List[str]], List[str], str]:
78 | """Generate variable assignment chains.
79 |
80 | Returns:
81 | Tuple containing:
82 | - List of variable chains
83 | - List of variables that get the target value
84 | - Target value that propagates through first chain
85 | """
86 | num_chains = self.task_params['num_chains']
87 | num_hops = self.task_params['num_hops']
88 | vars_per_chain = num_hops + 1
89 | total_vars_needed = num_chains * vars_per_chain
90 |
91 | # Generate all unique variables needed
92 | all_vars = self._generate_unique_vars(total_vars_needed)
93 |
94 | # Generate unique integers for each chain
95 | unique_integers = self.random_obj.sample(range(10000, 99999), num_chains)
96 |
97 | # Create chains
98 | chains = []
99 | target_vars = None
100 | target_value = str(unique_integers[0]) # Value that propagates through first chain
101 |
102 | for i in range(num_chains):
103 | # Get variables for this chain
104 | chain_vars = all_vars[i*vars_per_chain:(i+1)*vars_per_chain]
105 | chain = self._create_chain(chain_vars, unique_integers[i])
106 | chains.append(chain)
107 |
108 | # Store variables from first chain as target
109 | if i == 0:
110 | target_vars = chain_vars
111 |
112 | return chains, target_vars, target_value
113 |
114 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
115 | """Generate a single variable tracking sample."""
116 | chains, target_vars, target_value = self._generate_chains()
117 |
118 | # Extract all variable assignment statements
119 | assignment_statements = []
120 | for chain in chains:
121 | assignment_statements.extend(chain)
122 |
123 | # Calculate tokens used by assignments and prompt
124 | assignment_text = " ".join(assignment_statements)
125 | prompt_tokens = len(self.tokenizer.text_to_tokens(
126 | SINGLEQ_PROMPT_TEMPLATE.format(
127 | task_intro=TASK_INTRO,
128 | context=assignment_text,
129 | question=QUESTION_TEMPLATE.format(target_value=target_value),
130 | answer_format=ANSWER_FORMAT,
131 | extra_instructions=EXTRA_INSTRUCTIONS
132 | )
133 | ))
134 |
135 | # Calculate how many noise sentences we can add
136 | noise_tokens = len(self.tokenizer.text_to_tokens(NOISE_SENTENCE))
137 | remaining_tokens = self.max_tokens - prompt_tokens
138 | num_noise_sentences = remaining_tokens // noise_tokens
139 | num_noise_sentences = max(num_noise_sentences - 5, 0) # Safety margin
140 |
141 | # Create final list of sentences and shuffle
142 | sentences = assignment_statements + [NOISE_SENTENCE] * num_noise_sentences
143 | self.random_obj.shuffle(sentences)
144 |
145 | # Format context and question
146 | context = " ".join(sentences)
147 | question = QUESTION_TEMPLATE.format(target_value=target_value)
148 |
149 | # Format input using template
150 | input_text = SINGLEQ_PROMPT_TEMPLATE.format(
151 | task_intro=TASK_INTRO,
152 | context=context,
153 | question=question,
154 | answer_format=ANSWER_FORMAT,
155 | extra_instructions=EXTRA_INSTRUCTIONS
156 | )
157 |
158 | # Format gold answer
159 | gold_answer = " ".join(target_vars)
160 |
161 | extra_data = {
162 | "target_value": target_value,
163 | "num_chains": len(chains),
164 | "num_hops": self.task_params['num_hops'],
165 | "target_vars": target_vars
166 | }
167 |
168 | return input_text, gold_answer, extra_data
169 |
170 |
171 | class VariableTrackingTask(AbstractTask):
172 | """Task for evaluating variable tracking capabilities."""
173 |
174 | def __init__(
175 | self,
176 | num_chains: int = 1,
177 | num_hops: int = 4,
178 | **kwargs
179 | ) -> None:
180 | """Initialize variable tracking task.
181 |
182 | Args:
183 | num_chains: Number of variable chains to include
184 | num_hops: Number of variable assignments in each chain
185 | **kwargs: Additional arguments passed to parent class
186 | """
187 | super().__init__(**kwargs)
188 | self.task_params.update({
189 | 'num_chains': num_chains,
190 | 'num_hops': num_hops
191 | })
192 | self.check_params()
193 |
194 | def check_params(self) -> None:
195 | """Validate task parameters."""
196 | if not isinstance(self.task_params.get('num_chains'), int):
197 | raise ValueError("num_chains must be an integer")
198 | if not isinstance(self.task_params.get('num_hops'), int):
199 | raise ValueError("num_hops must be an integer")
200 | if self.task_params['num_chains'] < 1:
201 | raise ValueError("num_chains must be at least 1")
202 | if self.task_params['num_hops'] < 1:
203 | raise ValueError("num_hops must be at least 1")
204 |
205 | @property
206 | def sample_class(self):
207 | return VariableTrackingSample
208 |
209 | @staticmethod
210 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, float]:
211 | """Evaluate model predictions against gold answers using intersection over union.
212 |
213 | For each prediction, calculates IoU between predicted and gold variable sets:
214 | IoU = |intersection| / |union|
215 |
216 | Returns mean IoU across all predictions and variance.
217 | """
218 | def normalize_answer(text: str) -> str:
219 | """Normalize answer text for comparison."""
220 | # Extract answer from tagged response if present
221 |             answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
222 | if answer_match:
223 | text = answer_match.group(1)
224 |
225 | # Convert to uppercase and remove extra whitespace
226 | text = re.sub(r'\s+', ' ', text.upper().strip())
227 | # Remove any "VAR" prefixes
228 | text = re.sub(r'VAR\s+', '', text)
229 | return text
230 |
231 | sample_ious = []
232 |
233 | for pred in predictions:
234 | # Convert predictions and gold answers to sets of variables
235 | pred_vars = set(normalize_answer(pred['pred']).split())
236 | gold_vars = set(normalize_answer(pred['gold_answer']).split())
237 |
238 | # Calculate intersection over union
239 | intersection = len(pred_vars & gold_vars)
240 | union = len(pred_vars | gold_vars)
241 |
242 | # Handle edge case where both sets are empty
243 | iou = 1.0 if union == 0 else intersection / union
244 | sample_ious.append(iou)
245 |
246 | # Calculate mean and variance of IoU scores
247 | mean_iou = np.mean(sample_ious) if sample_ious else 0.0
248 | # Use ddof=1 for unbiased estimate of the variance
249 | variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
250 |
251 | return {
252 | 'iou': mean_iou,
253 | 'iou_variance': variance
254 | }
255 |
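256 | # Minimal usage sketch of the IoU metric above. The prediction below carries no
257 | # answer tags, so normalize_answer falls back to the raw string; "VAR " prefixes
258 | # are stripped before comparison.
259 | if __name__ == "__main__":
260 |     toy_predictions = [
261 |         {"pred": "VAR X1 VAR X2", "gold_answer": "X1 X2 X3"},
262 |     ]
263 |     # Intersection {X1, X2}, union {X1, X2, X3} -> iou = 2/3, variance 0.0
264 |     print(VariableTrackingTask.evaluate(toy_predictions))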
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/filtering.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating model's ability to identify chapters without purchases.
2 |
3 | This module implements a task where models need to identify chapters in a narrative where
4 | the protagonist did not make any purchases. The task tests the model's ability to:
5 | 1. Follow precise instructions
6 | 2. Track item transactions across a narrative
7 | 3. Format responses according to a specified template
8 | """
9 |
10 | import re
11 | from typing import Any, Dict, List, Set, Tuple
12 | import copy
13 | import numpy as np
14 |
15 | from sparse_frontier.tasks.abstract_task import AbstractTask
16 | from sparse_frontier.tasks.abstract_sample import AbstractSample
17 | from sparse_frontier.tasks.story.narrative import Chapter, NarrativeGenerator
18 | from sparse_frontier.tasks.story.templates import TASK_INTRO
19 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
20 |
21 | QUESTION = """Identify all chapters where the protagonist did not buy any item.
22 | Note: There are exactly {num_chapters} chapters without any purchases."""
23 |
24 | ANSWER_FORMAT = "chapter_id_1, chapter_id_2, ..."
25 |
26 | EXTRA_INSTRUCTIONS = """
27 | - In the answer section, provide only the chapter IDs separated by commas.
28 | """.strip()
29 |
30 | PROMPT = PROMPT_TEMPLATE.format(
31 | task_intro=TASK_INTRO,
32 | question="{question}",
33 | context="{context}",
34 | answer_format=ANSWER_FORMAT,
35 | extra_instructions=EXTRA_INSTRUCTIONS
36 | )
37 |
38 |
39 | class FilteringSample(AbstractSample):
40 | """Represents a single filtering task sample."""
41 |
42 | @staticmethod
43 | def _remove_buying_transactions(chapter: Chapter, items_to_remove: Set[str]) -> None:
44 | """Remove buying transactions for specified items from a chapter."""
45 | new_bought = []
46 | new_buying_trans = []
47 |
48 | for item, trans in zip(chapter.bought_items, chapter.structure['buying_transactions']):
49 | if item not in items_to_remove:
50 | new_bought.append(item)
51 | new_buying_trans.append(trans)
52 |
53 | chapter.bought_items = new_bought
54 | chapter.structure['buying_transactions'] = new_buying_trans
55 |
56 | def _modify_chapters(self, chapters: List[Chapter], chapter_ids: List[int]) -> List[Chapter]:
57 |         """Modify chapters by removing all purchases from the specified chapters."""
58 | modified_chapters = copy.deepcopy(chapters)
59 | items_to_track = set()
60 |
61 | # Remove purchases from specified chapters and track items
62 | for chapter in modified_chapters:
63 | if chapter.chapter_id in chapter_ids:
64 | items_to_track.update(chapter.bought_items)
65 | self._remove_buying_transactions(chapter, set(chapter.bought_items))
66 |
67 | return modified_chapters
68 |
69 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
70 | """Generate the input text, gold answer and extra data for this sample."""
71 | # Calculate prompt tokens
72 | prompt_tokens = len(self.tokenizer.text_to_tokens(
73 | PROMPT.format(
74 | context="",
75 | question=QUESTION.format(num_chapters=1),
76 | )
77 | ))
78 |
79 | # Calculate remaining tokens for narrative
80 | narrative_tokens = self.max_tokens - prompt_tokens
81 |
82 | # Generate narrative with adjusted token limit
83 | narrative_gen = NarrativeGenerator(
84 | tokenizer=self.tokenizer,
85 | sequence_length=narrative_tokens,
86 | random_obj=self.random_obj,
87 | )
88 |
89 | # Randomly select chapters to remove purchases from
90 | num_chapters = len(narrative_gen.chapters)
91 | chapters_to_modify = sorted(
92 | self.random_obj.sample(range(1, num_chapters + 1), self.task_params['chapters_in_question'])
93 | )
94 |
95 | modified_chapters = self._modify_chapters(narrative_gen.chapters, chapters_to_modify)
96 |
97 | # Build prompt
98 | context = "\n\n".join(ch.compile_text() for ch in modified_chapters)
99 | input_text = PROMPT.format(
100 | context=context,
101 | question=QUESTION.format(num_chapters=self.task_params['chapters_in_question'])
102 | )
103 |
104 | # Generate expected output
105 | gold_answer = ", ".join(map(str, chapters_to_modify)) if chapters_to_modify else ""
106 |
107 | return input_text, gold_answer, {}
108 |
109 |
110 | class FilteringTask(AbstractTask):
111 | """Task class for evaluating chapter purchase identification capabilities."""
112 |
113 | def __init__(
114 | self,
115 | chapters_in_question: int,
116 | protagonist_name: str = "Arion",
117 | **kwargs
118 | ) -> None:
119 | """Initialize the filtering task.
120 |
121 | Args:
122 | chapters_in_question: Number of chapters to remove purchases from
123 | protagonist_name: Name of the story protagonist
124 | **kwargs: Additional arguments passed to AbstractTask
125 | """
126 | super().__init__(**kwargs)
127 | self.task_params['chapters_in_question'] = chapters_in_question
128 | self.task_params['protagonist_name'] = protagonist_name
129 | self.check_params()
130 |
131 | def check_params(self) -> None:
132 | """Validate task parameters."""
133 | if 'chapters_in_question' not in self.task_params:
134 | raise ValueError("Missing required parameter 'chapters_in_question'")
135 |
136 | if not isinstance(self.task_params['chapters_in_question'], int):
137 | raise ValueError("Parameter 'chapters_in_question' must be an integer")
138 |
139 | if self.task_params['chapters_in_question'] < 1:
140 | raise ValueError("Parameter 'chapters_in_question' must be at least 1")
141 |
142 | @property
143 | def sample_class(self):
144 | return FilteringSample
145 |
146 | @staticmethod
147 | def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]:
148 | """Evaluate model predictions against gold answers.
149 |
150 | Args:
151 | examples: List of dictionaries containing predictions and gold answers
152 |
153 | Returns:
154 | Dictionary containing IoU (Intersection over Union) metric and its variance
155 | """
156 | def normalize_answer(text: str) -> Set[int]:
157 | """Extract and normalize chapter IDs from answer text."""
158 | # Remove any non-numeric characters except commas and whitespace
159 | text = re.sub(r'[^0-9,\s]', '', text.strip())
160 | # Split on comma or whitespace
161 | parts = re.split(r'[,\s]+', text)
162 | # Convert to integers, ignoring any invalid parts
163 | try:
164 | return {int(p) for p in parts if p.strip()}
165 | except ValueError:
166 | return set()
167 |
168 | def extract_answer(text: str) -> Set[int]:
169 | """Extract answer from the formatted response."""
170 |             match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
171 | if not match:
172 | return set()
173 | answer_text = match.group(1)
174 |
175 | # Handle both formats: "1, 2, 3" and "Chapter 1, Chapter 2, Chapter 3"
176 | # First try to extract chapter numbers from "Chapter X" format
177 | chapter_matches = re.findall(r'Chapter\s+(\d+)', answer_text, re.IGNORECASE)
178 | if chapter_matches:
179 | return {int(num) for num in chapter_matches}
180 |
181 | # If no "Chapter X" format found, fall back to original number parsing
182 | return normalize_answer(answer_text)
183 |
184 | sample_ious = []
185 |
186 | for ex in examples:
187 | pred = ex.get("pred", "").strip()
188 | gold = ex.get("gold_answer", "").strip()
189 |
190 | if not pred or not gold:
191 | continue
192 |
193 | gold_set = normalize_answer(gold)
194 | pred_set = extract_answer(pred)
195 |
196 | if not gold_set and not pred_set:
197 | sample_ious.append(1.0)
198 | continue
199 |
200 | union = gold_set.union(pred_set)
201 | intersect = gold_set.intersection(pred_set)
202 | iou = len(intersect) / len(union) if union else 0.0
203 | sample_ious.append(iou)
204 |
205 | # Calculate mean and variance
206 | mean_iou = np.mean(sample_ious) if sample_ious else 0.0
207 | # Use ddof=1 for unbiased estimate of the variance
208 | variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
209 |
210 | return {
211 | "iou": mean_iou,
212 | "iou_variance": variance
213 | }
214 |
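215 | # Minimal usage sketch of the IoU metric above, assuming the model wraps its answer
216 | # in <answer>...</answer> tags as extract_answer expects (untagged predictions score
217 | # as empty sets).
218 | if __name__ == "__main__":
219 |     toy_examples = [
220 |         {"pred": "<answer>1, 3</answer>", "gold_answer": "1, 2, 3"},
221 |     ]
222 |     # Intersection {1, 3}, union {1, 2, 3} -> iou = 2/3, variance 0.0
223 |     print(FilteringTask.evaluate(toy_examples))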
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/multihop.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating multi-hop reasoning capabilities in narrative comprehension.
2 |
3 | This module implements a task where models need to track item acquisitions across chapters
4 | in a narrative. The task tests the model's ability to:
5 | 1. Follow precise instructions
6 | 2. Extract relevant item information from context
7 | 3. Connect information across multiple chapters
8 | 4. Format responses according to a specified template
9 | """
10 |
11 | import re
12 | from typing import List, Dict, Any, Tuple
13 |
14 | from sparse_frontier.tasks.abstract_task import AbstractTask
15 | from sparse_frontier.tasks.abstract_sample import AbstractSample
16 | from sparse_frontier.tasks.story.narrative import NarrativeGenerator
17 | from sparse_frontier.tasks.story.templates import TASK_INTRO
18 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
19 |
20 | QUESTION = "What was the last item that the protagonist acquired before acquiring {target_item}?"
21 |
22 | ANSWER_FORMAT = "ITEM_NAME"
23 |
24 | EXTRA_INSTRUCTIONS = """
25 | - Provide only the item name in the answer section.
26 | - Do not include articles like 'the' or 'a' in your answer.
27 | - The item name must be exactly as mentioned in the text.
28 | """.strip()
29 |
30 | PROMPT = PROMPT_TEMPLATE.format(
31 | task_intro=TASK_INTRO,
32 | question="{question}",
33 | context="{context}",
34 | answer_format=ANSWER_FORMAT,
35 | extra_instructions=EXTRA_INSTRUCTIONS
36 | )
37 |
38 |
39 | class MultiHopSample(AbstractSample):
40 | """Represents a single multi-hop reasoning sample with narrative and item queries."""
41 |
42 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
43 | """Generate the input text, gold answer and extra data for this sample."""
44 | # Calculate prompt tokens
45 | prompt_tokens = len(self.tokenizer.text_to_tokens(
46 | PROMPT.format(
47 | context="",
48 | question=QUESTION.format(target_item="item_component_1 item_component_2 item_component_3"),
49 | )
50 | ))
51 |
52 | # Calculate remaining tokens for narrative
53 | narrative_tokens = self.max_tokens - prompt_tokens
54 |
55 | # Generate narrative
56 | narrative_gen = NarrativeGenerator(
57 | tokenizer=self.tokenizer,
58 | sequence_length=narrative_tokens,
59 | random_obj=self.random_obj,
60 | protagonist_name=self.task_params.get('protagonist_name', "Arion")
61 | )
62 |
63 | # Find chapters where items were bought
64 | chapters_with_items = [
65 | ch for ch in narrative_gen.chapters
66 | if ch.bought_items
67 | ]
68 |
69 | assert len(chapters_with_items) >= 2, "Need at least two chapters with item purchases"
70 |
71 | # Select a random chapter with an item purchase (excluding first chapter with items)
72 | target_chapter = self.random_obj.choice(chapters_with_items[1:])
73 | target_item = target_chapter.bought_items[0]
74 |
75 | # Find the previous chapter where an item was bought
76 | prev_chapters = [ch for ch in chapters_with_items if ch.chapter_id < target_chapter.chapter_id]
77 | prev_chapter = prev_chapters[-1] # Get the most recent previous chapter
78 | prev_item = prev_chapter.bought_items[0]
79 |
80 | # Generate question
81 | question = QUESTION.format(target_item=target_item)
82 |
83 | input_text = PROMPT.format(
84 | context=narrative_gen.compile_narrative(),
85 | question=question
86 | )
87 |
88 | return input_text, prev_item, {}
89 |
90 |
91 | class MultiHopTask(AbstractTask):
92 | """Main task class for multi-hop reasoning evaluation in narratives."""
93 |
94 | def __init__(
95 | self,
96 | protagonist_name: str = "Arion",
97 | **kwargs
98 | ) -> None:
99 | """Initialize the multi-hop task.
100 |
101 | Args:
102 | protagonist_name: Name of the story protagonist
103 | **kwargs: Additional arguments passed to AbstractTask
104 | """
105 | super().__init__(**kwargs)
106 | self.task_params['protagonist_name'] = protagonist_name
107 | self.check_params()
108 |
109 | def check_params(self) -> None:
110 | """Validate task-specific parameters."""
111 | if not isinstance(self.task_params['protagonist_name'], str):
112 | raise ValueError("protagonist_name must be a string")
113 |
114 | @property
115 | def sample_class(self):
116 | """Return the sample class for this task."""
117 | return MultiHopSample
118 |
119 | @staticmethod
120 | def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, Any]:
121 | """Evaluate model predictions against gold answers."""
122 | def normalize_answer(text: str) -> str:
123 | """Normalize answer text for comparison."""
124 | # Convert to lowercase and remove extra whitespace
125 | text = re.sub(r'\s+', ' ', text.lower().strip())
126 | # Remove articles
127 | text = re.sub(r'^(the|a|an)\s+', '', text)
128 | return text
129 |
130 | def extract_answer(text: str) -> str:
131 | """Extract answer from the formatted response."""
132 | # Find content between tags
133 |             match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
134 | if not match:
135 | return ""
136 | return normalize_answer(match.group(1))
137 |
138 | import numpy as np
139 | sample_accuracies = []
140 |
141 | for example in examples:
142 | gold = normalize_answer(example['gold_answer'])
143 | pred = extract_answer(example['pred'])
144 |
145 | if gold and pred: # Only evaluate if both answers are non-empty
146 | # Binary accuracy (1 if correct, 0 if incorrect)
147 | accuracy = 1.0 if pred == gold else 0.0
148 | sample_accuracies.append(accuracy)
149 |
150 | # Calculate mean and variance
151 | mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0
152 | # Use ddof=1 for unbiased estimate of the variance
153 | variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0
154 |
155 | return {
156 | 'accuracy': mean_accuracy,
157 | 'accuracy_variance': variance
158 | }
159 |
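160 | # Minimal usage sketch of the exact-match metric above, assuming answers are wrapped
161 | # in <answer>...</answer> tags as extract_answer expects. Articles and case are
162 | # normalized away, so "the Silver Compass" matches "Silver Compass".
163 | if __name__ == "__main__":
164 |     toy_examples = [
165 |         {"pred": "<answer>the Silver Compass</answer>", "gold_answer": "Silver Compass"},
166 |         {"pred": "<answer>iron lantern</answer>", "gold_answer": "oak staff"},
167 |     ]
168 |     # Per-sample accuracies [1.0, 0.0] -> {'accuracy': 0.5, 'accuracy_variance': 0.5}
169 |     print(MultiHopTask.evaluate(toy_examples))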
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/narrative.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Set, Tuple
2 |
3 | from sparse_frontier.tasks.story.templates import NarrativeResources, ItemGenerator
4 |
5 |
6 | class Chapter:
7 | """Represents a single narrative chapter."""
8 | def __init__(self, chapter_id: int, location: str, character: str, event: str, random_obj):
9 | self.chapter_id = chapter_id
10 | self.location = location
11 | self.event = event
12 | self.character = character
13 | self.random_obj = random_obj
14 | self.structure = {
15 | 'scene_introduction': [],
16 | 'encounter': [],
17 | 'conversation_extension': [],
18 | 'buying_transactions': [],
19 | 'selling_transactions': [],
20 | 'farewell': [],
21 | 'scene_conclusions': [],
22 | 'extra': []
23 | }
24 | self.bought_items: List[str] = []
25 | self.sold_items: List[str] = []
26 | self.end_item_count: Optional[int] = None
27 |
28 | def compile_text(self) -> str:
29 | parts = (
30 | self.structure['scene_introduction']
31 | + self.structure['encounter']
32 | + self.structure['conversation_extension']
33 | + self.structure['buying_transactions']
34 | + self.structure['selling_transactions']
35 | + self.structure['farewell']
36 | + self.structure['scene_conclusions']
37 | + self.structure['extra']
38 | )
39 | return f"Chapter {self.chapter_id}:\n" + "".join(parts)
40 |
41 | def generate_chapter(self, protagonist_name: str, conversation_extensions: int, items_seen: Set[str], inventory: List[str]) -> None:
42 | """Generates a complete chapter with all narrative elements."""
43 | # Scene introduction
44 | self.structure['scene_introduction'].append(
45 | NarrativeResources.choose(NarrativeResources.SCENE_INTRO_TEMPLATES,
46 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj)
47 | )
48 | self.structure['scene_introduction'].append(
49 | NarrativeResources.choose(NarrativeResources.REASON_TEMPLATES,
50 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj)
51 | )
52 | self.structure['scene_introduction'].append(
53 | NarrativeResources.choose(NarrativeResources.EVENT_TEMPLATES,
54 | protagonist=protagonist_name, event=self.event, location=self.location, random_obj=self.random_obj)
55 | )
56 |
57 | # Encounter
58 | self.structure['encounter'].append(
59 | NarrativeResources.choose(NarrativeResources.CHAR_INTRO_TEMPLATES,
60 | protagonist=protagonist_name, character=self.character, random_obj=self.random_obj)
61 | )
62 |
63 | # Conversation extensions
64 | self.structure['conversation_extension'].extend(
65 | NarrativeResources.choose_multiple(
66 | NarrativeResources.CONVERSATION_EXTENSION_TEMPLATES,
67 | conversation_extensions,
68 | protagonist=protagonist_name,
69 | location=self.location,
70 | random_obj=self.random_obj
71 | )
72 | )
73 |
74 | # Handle buying and selling
75 | new_item = ItemGenerator._get_unique_item(items_seen, self.random_obj)
76 | inventory.append(new_item)
77 | self.bought_items.append(new_item)
78 | self.structure['buying_transactions'].append(
79 | NarrativeResources.choose(NarrativeResources.BUYING_TEMPLATES,
80 | protagonist=protagonist_name,
81 | character=self.character,
82 | item=new_item,
83 | random_obj=self.random_obj)
84 | )
85 |
86 | older_items = [it for it in inventory if it != new_item]
87 | if older_items and self.random_obj.random() < 0.5:
88 | sell_item = self.random_obj.choice(older_items)
89 | inventory.remove(sell_item)
90 | self.sold_items.append(sell_item)
91 | self.structure['selling_transactions'].append(
92 | NarrativeResources.choose(NarrativeResources.SELLING_TEMPLATES,
93 | protagonist=protagonist_name,
94 | character=self.character,
95 | item=sell_item,
96 | random_obj=self.random_obj)
97 | )
98 |
99 | self.end_item_count = len(inventory)
100 |
101 | # Farewell
102 | self.structure['farewell'].append(
103 | NarrativeResources.choose(NarrativeResources.FAREWELL_TEMPLATES,
104 | protagonist=protagonist_name, character=self.character, random_obj=self.random_obj)
105 | )
106 |
107 | # Scene conclusions
108 | self.structure['scene_conclusions'].append(
109 | NarrativeResources.choose(NarrativeResources.CONCLUSION_TEMPLATES,
110 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj)
111 | )
112 |
113 | # Add extra sentences
114 | self.structure['extra'].extend(
115 | NarrativeResources.choose_multiple(NarrativeResources.EXTRA_TEMPLATES, 1,
116 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj)
117 | )
118 |
119 |
120 | class NarrativeGenerator:
121 |     """Generates a base narrative whose chapters fit within the given token budget."""
122 | def __init__(
123 | self,
124 | tokenizer,
125 | sequence_length: int,
126 | random_obj,
127 | protagonist_name: str = "Arion",
128 | conversation_extensions: int = 3,
129 | ):
130 | self.tokenizer = tokenizer
131 | self.sequence_length = sequence_length
132 | self.protagonist_name = protagonist_name
133 | self.protagonist_inventory: List[str] = []
134 | self.conversation_extensions = conversation_extensions
135 | self._used_pairs: Set[Tuple[str, str]] = set()
136 | self._items_seen: Set[str] = set()
137 | self.random_obj = random_obj
138 |
139 | self.chapters = self._generate_base_narrative()
140 | assert len(self.tokenizer.text_to_tokens(self.compile_narrative())) <= self.sequence_length, "Narrative is too long."
141 |
142 | def compile_narrative(self) -> str:
143 | return "\n\n".join(chapter.compile_text() for chapter in self.chapters)
144 |
145 | def _generate_base_narrative(self) -> List[Chapter]:
146 | chapters = []
147 | total_tokens = 0
148 | chapter_id = 1
149 |
150 | while True:
151 |             assert chapter_id <= 2000, "Reached the 2000-chapter safety limit."
152 | for _ in range(100):
153 | location = self.random_obj.choice(NarrativeResources.LOCATIONS)
154 | character = self.random_obj.choice(NarrativeResources.CHARACTERS)
155 | if (location, character) not in self._used_pairs:
156 | self._used_pairs.add((location, character))
157 | break
158 | else:
159 | raise ValueError("Could not find a unique (location, character) pair.")
160 |
161 | event = self.random_obj.choice(NarrativeResources.EVENTS)
162 | chapter = Chapter(chapter_id, location, character, event, self.random_obj)
163 |
164 | chapter.generate_chapter(
165 | self.protagonist_name,
166 | self.conversation_extensions,
167 | self._items_seen,
168 | self.protagonist_inventory
169 | )
170 |
171 | chapter_text = chapter.compile_text()
172 | chapter_tokens = len(self.tokenizer.text_to_tokens(f'\n\n{chapter_text}'))
173 |
174 | if total_tokens + chapter_tokens > self.sequence_length:
175 | break
176 |
177 | chapters.append(chapter)
178 | total_tokens += chapter_tokens
179 | chapter_id += 1
180 |
181 | return chapters
182 |
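183 | # Minimal usage sketch: a whitespace "tokenizer" stands in for the project's real
184 | # tokenizer (only a text_to_tokens method is used here). Chapters are added until
185 | # the next one would exceed the token budget.
186 | if __name__ == "__main__":
187 |     import random
188 | 
189 |     class _WhitespaceTokenizer:
190 |         def text_to_tokens(self, text: str) -> List[str]:
191 |             return text.split()
192 | 
193 |     gen = NarrativeGenerator(
194 |         tokenizer=_WhitespaceTokenizer(),
195 |         sequence_length=2000,
196 |         random_obj=random.Random(0),
197 |     )
198 |     print(f"{len(gen.chapters)} chapters, "
199 |           f"{len(gen.tokenizer.text_to_tokens(gen.compile_narrative()))} whitespace tokens")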
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/retrieval.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating story comprehension and information retrieval capabilities.
2 |
3 | This module implements a task where models need to extract specific information from
4 | narrative chapters. The task tests the model's ability to:
5 | 1. Comprehend multi-chapter narratives
6 | 2. Extract relevant information about locations, characters, and items
7 | 3. Format responses according to a specified template
8 | 4. Provide clear reasoning for its answers
9 | 5. Answer questions about specific details from the text
10 | """
11 |
12 | import re
13 | from typing import List, Dict, Any, Tuple
14 |
15 | from sparse_frontier.tasks.abstract_task import AbstractTask
16 | from sparse_frontier.tasks.abstract_sample import AbstractSample
17 | from sparse_frontier.tasks.story.narrative import NarrativeGenerator
18 | from sparse_frontier.tasks.story.templates import TASK_INTRO
19 | from sparse_frontier.tasks.abstract_prompt import MULTIPLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
20 |
21 | ANSWER_FORMAT = """1. ANSWER_ONE
22 | 2. ANSWER_TWO
23 | etc."""
24 |
25 | EXTRA_INSTRUCTIONS = """
26 | - For answers, use one line per answer with the number prefix
27 | - Do not include articles like 'the' or 'a' in answers
28 | - Answers should be specific names/items/locations mentioned in the text
29 | """.strip()
30 |
31 | PROMPT = PROMPT_TEMPLATE.format(
32 | task_intro=TASK_INTRO,
33 | question="{questions}",
34 | context="{context}",
35 | answer_format=ANSWER_FORMAT,
36 | extra_instructions=EXTRA_INSTRUCTIONS
37 | )
38 |
39 |
40 | class RetrievalSample(AbstractSample):
41 | SAFETY_TOKENS = 10
42 |
43 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]:
44 | num_queries = self.task_params['num_queries']
45 |
46 | # Calculate tokens needed for prompt and instructions
47 | prompt_tokens = len(self.tokenizer.text_to_tokens(
48 | PROMPT.format(context="", questions="")
49 | ))
50 |
51 | # Estimate tokens for questions and explanation
52 | question_template = "XXX. In Chapter XXX, which specific location/character/item did the protagonist visit/meet/acquire?\n"
53 | question_tokens = len(self.tokenizer.text_to_tokens(question_template)) * num_queries * 2
54 |
55 | # Calculate remaining tokens for narrative
56 | narrative_tokens = self.max_tokens - (prompt_tokens + question_tokens + self.SAFETY_TOKENS)
57 |
58 | # Generate narrative with adjusted token limit
59 | narrative_gen = NarrativeGenerator(
60 | tokenizer=self.tokenizer,
61 | sequence_length=narrative_tokens,
62 | random_obj=self.random_obj,
63 | )
64 |
65 | assert len(narrative_gen.chapters) >= num_queries, "Not enough chapters for the requested complexity."
66 |
67 | # Select random chapters and generate questions
68 | selected_chapters = self.random_obj.sample(narrative_gen.chapters, num_queries)
69 | questions_and_answers = []
70 |
71 | if num_queries == 3:
72 | # Fixed query types for 3 questions with clearer phrasing
73 | query_types = ["location", "character", "item"]
74 | self.random_obj.shuffle(query_types)
75 |
76 | for i, (ch, query_type) in enumerate(zip(selected_chapters, query_types), start=1):
77 | if query_type == "location":
78 | q = f"{i}. In Chapter {ch.chapter_id}, which specific location did the protagonist visit?"
79 | a = ch.location
80 | elif query_type == "character":
81 | q = f"{i}. In Chapter {ch.chapter_id}, which character did the protagonist interact with?"
82 | a = ch.character
83 | else: # item
84 | q = f"{i}. In Chapter {ch.chapter_id}, which specific item was acquired by the protagonist?"
85 | a = ch.bought_items[0] if ch.bought_items else "None"
86 |
87 | questions_and_answers.append((i, q, a))
88 | else:
89 | # Enhanced question generation for other numbers of queries
90 | for i, ch in enumerate(selected_chapters, start=1):
91 | query_type = self.random_obj.choice(["location", "character", "item"])
92 |
93 | if query_type == "location":
94 | q = f"{i}. In Chapter {ch.chapter_id}, which specific location did the protagonist visit?"
95 | a = ch.location
96 | elif query_type == "character":
97 | q = f"{i}. In Chapter {ch.chapter_id}, which character did the protagonist interact with?"
98 | a = ch.character
99 | else:
100 | q = f"{i}. In Chapter {ch.chapter_id}, which specific item was acquired by the protagonist?"
101 | a = ch.bought_items[0] if ch.bought_items else "None"
102 |
103 | questions_and_answers.append((i, q, a))
104 |
105 | input_text = PROMPT.format(
106 | context=narrative_gen.compile_narrative(),
107 | questions="\n".join(q for (_, q, _) in questions_and_answers)
108 | )
109 |
110 | gold_answer = "\n".join(f"{i}. {a}" for (i, _, a) in questions_and_answers)
111 |
112 | return input_text, gold_answer, {}
113 |
114 |
115 | class RetrievalTask(AbstractTask):
116 | """Task class for evaluating story comprehension and information retrieval."""
117 |
118 | def __init__(
119 | self,
120 | num_queries: int,
121 | protagonist_name: str = "Arion",
122 | **kwargs
123 | ) -> None:
124 | super().__init__(**kwargs)
125 | self.task_params['num_queries'] = num_queries
126 | self.task_params['protagonist_name'] = protagonist_name
127 | self.check_params()
128 |
129 | def check_params(self) -> None:
130 | """Validate task-specific parameters."""
131 | if 'num_queries' not in self.task_params:
132 | raise ValueError("Missing required parameter 'num_queries'")
133 |
134 | if not isinstance(self.task_params['num_queries'], int):
135 | raise ValueError("Parameter 'num_queries' must be an integer")
136 |
137 | if self.task_params['num_queries'] < 1:
138 |             raise ValueError("Parameter 'num_queries' must be at least 1")
139 |
140 | @property
141 | def sample_class(self):
142 | return RetrievalSample
143 |
144 | @staticmethod
145 | def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]:
146 | """Evaluate model predictions against gold answers."""
147 | def normalize_answer(text: str) -> str:
148 | """Normalize answer text for comparison."""
149 | # Convert to lowercase and remove extra whitespace
150 | text = re.sub(r'\s+', ' ', text.lower().strip())
151 | # Remove articles
152 | text = re.sub(r'^(the|a|an)\s+', '', text)
153 | return text
154 |
155 | def extract_answers(text: str) -> Dict[int, str]:
156 | """Extract answers from text, handling both formats."""
157 |             # First try to find the <answer> section
158 |             answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
159 | if answer_match:
160 | text = answer_match.group(1)
161 |
162 | # Extract numbered answers
163 | answers = {}
164 | for line in text.split('\n'):
165 | match = re.match(r'^(\d+)[\.:\)]\s*(.+?)\s*$', line.strip())
166 | if match:
167 | idx = int(match.group(1))
168 | answer = normalize_answer(match.group(2))
169 | if idx not in answers: # Take first occurrence if duplicates
170 | answers[idx] = answer
171 | return answers
172 |
173 | import numpy as np
174 | sample_accuracies = []
175 |
176 | for example in examples:
177 | gold_answers = extract_answers(example['gold_answer'])
178 | pred_answers = extract_answers(example['pred'])
179 |
180 | if not gold_answers: # Skip examples without gold answers
181 | continue
182 |
183 | correct = 0
184 | total = len(gold_answers)
185 |
186 | for idx, gold in gold_answers.items():
187 | if idx in pred_answers and pred_answers[idx] == gold:
188 | correct += 1
189 |
190 | sample_accuracies.append(correct / total if total > 0 else 0.0)
191 |
192 | # Calculate mean and variance
193 | mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0
194 | # Use ddof=1 for unbiased estimate of the variance
195 | variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0
196 |
197 | return {
198 | 'accuracy': mean_accuracy,
199 | 'accuracy_variance': variance
200 | }
201 |
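202 | # Minimal usage sketch of the per-question accuracy metric above. The answers may be
203 | # tagged or untagged; numbered lines are parsed either way.
204 | if __name__ == "__main__":
205 |     toy_examples = [
206 |         {
207 |             "pred": "1. The Moonlit Harbor\n2. a compass",
208 |             "gold_answer": "1. Moonlit Harbor\n2. silver compass",
209 |         },
210 |     ]
211 |     # Question 1 correct, question 2 wrong -> accuracy 0.5, variance 0.0
212 |     print(RetrievalTask.evaluate(toy_examples))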
--------------------------------------------------------------------------------
/sparse_frontier/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .globals import GlobalSettings
2 |
3 | __all__ = [
4 | "GlobalSettings",
5 | ]
6 |
--------------------------------------------------------------------------------
/sparse_frontier/utils/checks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from sparse_frontier.utils.globals import GlobalSettings
4 | from sparse_frontier.utils.data import (
5 | read_jsonl,
6 | get_data_path,
7 | get_pred_path,
8 | get_results_path,
9 | )
10 |
11 |
12 | def prepration_needed():
13 | cfg = GlobalSettings.get("cfg")
14 | data_path = get_data_path()
15 |
16 | if not os.path.exists(data_path):
17 | return True
18 |
19 | if cfg.overwrite:
20 | os.remove(data_path)
21 | return True
22 |
23 | data = read_jsonl(data_path)
24 | return (len(data) < cfg.samples)
25 |
26 |
27 | def prediction_needed():
28 | cfg = GlobalSettings.get("cfg")
29 | pred_path = get_pred_path()
30 |
31 | if not os.path.exists(pred_path):
32 | return True
33 |
34 | if cfg.overwrite:
35 | os.remove(pred_path)
36 | return True
37 |
38 | data = read_jsonl(pred_path)
39 | return (len(data) < cfg.samples)
40 |
41 |
42 | def evaluation_needed():
43 | cfg = GlobalSettings.get("cfg")
44 | results_path = get_results_path()
45 |
46 | if not os.path.exists(results_path):
47 | return True
48 |
49 | if cfg.overwrite:
50 | os.remove(results_path)
51 | return True
52 |
53 | return False
54 |
--------------------------------------------------------------------------------
/sparse_frontier/utils/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from pathlib import Path
5 | from typing import List, Union
6 |
7 | from sparse_frontier.utils.globals import GlobalSettings
8 |
9 |
10 | def read_jsonl(manifest: Union[Path, str]) -> List[dict]:
11 | """Read and parse a JSONL file into a list of dictionaries.
12 |
13 | Args:
14 | manifest: Path to JSONL file to read
15 |
16 | Returns:
17 | List of dictionaries parsed from JSONL
18 |
19 | Raises:
20 | json.JSONDecodeError: If JSONL parsing fails
21 | Exception: If file cannot be read
22 | """
23 | try:
24 | with open(manifest, 'r', encoding='utf-8') as f:
25 | return [json.loads(line) for line in f if line.strip()]
26 | except json.JSONDecodeError as e:
27 | logging.error(f"Failed to parse line in manifest file {manifest}: {e}")
28 | raise
29 | except Exception as e:
30 | raise Exception(f"Could not read manifest file {manifest}") from e
31 |
32 |
33 | def write_jsonl(output_path: Union[Path, str], data: List[dict]) -> None:
34 | """Write a list of dictionaries to a JSONL file.
35 |
36 | Args:
37 | output_path: Path to output JSONL file
38 | data: List of dictionaries to serialize
39 | """
40 | with open(output_path, "w", encoding="utf-8") as f:
41 | for item in data:
42 | f.write(json.dumps(item) + '\n')
43 |
44 |
45 | def build_task_params_str(task_args, max_input_tokens) -> str:
46 | """Build descriptive string from task args and max_input_tokens.
47 |
48 | Args:
49 |         task_args: Iterable of (name, value) pairs of task arguments
50 | max_input_tokens: Maximum input tokens
51 |
52 | Returns:
53 | String with parameters in format: param1@value1+param2@value2
54 | """
55 | params = []
56 | for arg_name, arg_value in task_args:
57 | params.append(f"{arg_name}@{arg_value}")
58 | params.append(f"max_input_tokens@{max_input_tokens}")
59 | return "+".join(params)
60 |
61 |
62 | def build_attn_params_str(attn_args) -> str:
63 | """Build descriptive string for attention config.
64 |
65 | Args:
66 |         attn_args: Iterable of (name, value) pairs of attention arguments
67 |
68 | Returns:
69 | String with parameters in format: param1@value1+param2@value2 or 'default'
70 | """
71 | attn_params = []
72 | for arg_name, arg_value in attn_args:
73 | attn_params.append(f"{arg_name}@{arg_value}")
74 | return "+".join(attn_params) if attn_params else "default"
75 |
76 |
77 | def get_pred_dir() -> str:
78 | """Get the path to the predictions directory based on config settings.
79 | The path uses @ to separate parameter names from values and + to join parameters,
80 | which avoids ambiguities and potential issues with bash interpretation.
81 |
82 | Returns:
83 | Path string constructed from config parameters including task args
84 | """
85 | cfg = GlobalSettings.get('cfg')
86 |
87 | return os.path.join(
88 | cfg.paths.debug if cfg.debug else cfg.paths.predictions,
89 | cfg.task.name,
90 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens),
91 | cfg.model.name,
92 | cfg.attention.name,
93 | build_attn_params_str(cfg.attention.get('args', {}).items()),
94 | )
95 |
96 |
97 | def get_results_dir() -> str:
98 | """Get the path to the results directory based on config settings.
99 | The path uses @ to separate parameter names from values and + to join parameters,
100 | which avoids ambiguities and potential issues with bash interpretation.
101 |
102 | Returns:
103 | Path string constructed from config parameters including task args
104 | """
105 | cfg = GlobalSettings.get('cfg')
106 |
107 | return os.path.join(
108 | cfg.paths.debug if cfg.debug else cfg.paths.results,
109 | cfg.task.name,
110 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens),
111 | cfg.model.name,
112 | cfg.attention.name,
113 | build_attn_params_str(cfg.attention.get('args', {}).items()),
114 | )
115 |
116 |
117 | def get_data_dir() -> str:
118 | cfg = GlobalSettings.get('cfg')
119 |
120 | return os.path.join(
121 | cfg.paths.debug if cfg.debug else cfg.paths.data,
122 | cfg.task.name,
123 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens),
124 | cfg.model.name,
125 | )
126 |
127 |
128 | def get_data_path() -> str:
129 | """Get the path to the task's data file.
130 |
131 | Returns:
132 | Path to data.jsonl in the task directory
133 | """
134 | return os.path.join(get_data_dir(), 'data.jsonl')
135 |
136 |
137 | def get_pred_path() -> str:
138 | """Get the path to the task's predictions file.
139 |
140 | Returns:
141 | Path to pred.jsonl in the task directory
142 | """
143 | return os.path.join(get_pred_dir(), 'pred.jsonl')
144 |
145 |
146 | def get_results_path() -> str:
147 | """Get the path to the task's evaluation results file.
148 |
149 | Returns:
150 |         Path to evaluation_results_{cfg.samples}.json in the results directory
151 | """
152 |
153 | cfg = GlobalSettings.get('cfg')
154 | return os.path.join(get_results_dir(), f'evaluation_results_{cfg.samples}.json')
155 |
156 |
157 | def load_data_without_predictions() -> List[dict]:
158 | """Load task data excluding samples that already have predictions.
159 |     Only returns samples whose index is strictly less than cfg.samples.
160 |
161 | Returns:
162 |         List of data samples that haven't been predicted yet, limited to index < cfg.samples.
163 | """
164 | cfg = GlobalSettings.get('cfg')
165 | data_path = get_data_path()
166 | pred_path = get_pred_path()
167 |
168 | if os.path.exists(pred_path):
169 | pred_index = {sample['index'] for sample in read_jsonl(pred_path)}
170 | data = [sample for sample in read_jsonl(data_path)
171 | if sample['index'] not in pred_index and sample['index'] < cfg.samples]
172 | else:
173 | data = [sample for sample in read_jsonl(data_path)
174 | if sample['index'] < cfg.samples]
175 |
176 | return data
177 |
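178 | # Minimal usage sketch of the parameter-string helpers above, which define the on-disk
179 | # layout (the argument values shown are illustrative, not actual config defaults).
180 | if __name__ == "__main__":
181 |     print(build_task_params_str([("num_chains", 1), ("num_hops", 4)], 16384))
182 |     # -> num_chains@1+num_hops@4+max_input_tokens@16384
183 |     print(build_attn_params_str([]))
184 |     # -> default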
--------------------------------------------------------------------------------
/sparse_frontier/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import socket
4 |
5 | from sparse_frontier.utils.globals import GlobalSettings
6 |
7 |
8 | def get_free_ports(n: int) -> list[int]:
9 | """Find N free ports on the local machine."""
10 | free_ports = []
11 | sockets = []
12 |
13 | try:
14 | for _ in range(n):
15 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
16 | s.bind(('localhost', 0)) # Bind to an available port
17 | free_ports.append(s.getsockname()[1]) # Get the assigned port number
18 | sockets.append(s) # Keep the socket open to reserve the port
19 | finally:
20 | # Close all sockets to release the ports
21 | for s in sockets:
22 | s.close()
23 |
24 | return free_ports
25 |
26 |
27 | def get_latest_commit_id():
28 | try:
29 | import git
30 | repo = git.Repo(search_parent_directories=True)
31 | return repo.head.object.hexsha
32 | except Exception:
33 | return None
34 |
35 |
36 | def save_config(dir_path: str):
37 | from omegaconf import OmegaConf
38 | from datetime import datetime
39 |
40 | cfg = GlobalSettings.get('cfg')
41 |
42 | config_dict = OmegaConf.to_container(cfg, resolve=True)
43 | config_dict['commit_id'] = get_latest_commit_id()
44 |
45 | # Add timestamp
46 | config_dict['timestamp'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
47 |
48 | config_path = os.path.join(dir_path, "config.json")
49 | with open(config_path, "w") as f:
50 | json.dump(config_dict, f, indent=2)
51 |
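52 | # Minimal usage sketch: the sockets are bound only briefly and then closed, so the
53 | # returned ports are free at call time but may still be claimed by another process
54 | # before they are actually used.
55 | if __name__ == "__main__":
56 |     ports = get_free_ports(2)
57 |     print(ports)  # e.g. [54312, 54313] -- actual numbers are OS-assigned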
--------------------------------------------------------------------------------
/sparse_frontier/utils/globals.py:
--------------------------------------------------------------------------------
1 | class GlobalSettings:
2 | _settings = {}
3 |
4 | @classmethod
5 | def get(cls, key, default=None):
6 | return cls._settings.get(key, default)
7 |
8 | @classmethod
9 | def set(cls, key, value):
10 | cls._settings[key] = value
11 |
12 |
13 | def is_vllm_profiling_done():
14 | return GlobalSettings.get("vllm_profiling_done", False)
15 |
16 |
17 | def set_vllm_profiling_done(done: bool):
18 | GlobalSettings.set("vllm_profiling_done", done)
19 |
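20 | # Minimal usage sketch: GlobalSettings is a process-wide key-value store; elsewhere in
21 | # the codebase the run config is shared under the key "cfg".
22 | if __name__ == "__main__":
23 |     GlobalSettings.set("cfg", {"samples": 10})
24 |     print(GlobalSettings.get("cfg"))          # -> {'samples': 10}
25 |     print(GlobalSettings.get("missing", 0))   # -> 0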
--------------------------------------------------------------------------------