├── .gitignore ├── README.md ├── assets ├── photo.png └── pipfreeze.txt ├── compile.py ├── setup.py └── sparse_frontier ├── configs ├── attention │ ├── ada_snapkv.yaml │ ├── block_sparse.yaml │ ├── dense.yaml │ ├── flexprefill.yaml │ ├── quest.yaml │ ├── snapkv.yaml │ └── vertical_and_slash.yaml ├── default.yaml ├── model │ ├── qwen_14b.yaml │ ├── qwen_32b.yaml │ ├── qwen_72b.yaml │ └── qwen_7b.yaml └── task │ ├── qa_quality.yaml │ ├── qa_squad.yaml │ ├── qa_toefl.yaml │ ├── ruler_cwe.yaml │ ├── ruler_niah.yaml │ ├── ruler_vt.yaml │ ├── story_filtering.yaml │ ├── story_multihop.yaml │ └── story_retrieval.yaml ├── evaluation.py ├── main.py ├── modelling ├── __init__.py ├── attention │ ├── __init__.py │ ├── abstract_attention.py │ ├── efficient_decoding.py │ ├── efficient_prefilling.py │ ├── handler.py │ ├── kv_compression.py │ ├── minference │ │ ├── __init__.py │ │ ├── block.py │ │ ├── csrc │ │ │ ├── kernels.cpp │ │ │ └── vertical_slash_index.cu │ │ └── vertical_and_slash.py │ └── registry.py ├── models │ ├── __init__.py │ ├── abstract_model.py │ └── vllm_model.py └── tokenizer.py ├── prediction.py ├── preparation.py ├── tasks ├── __init__.py ├── abstract_prompt.py ├── abstract_sample.py ├── abstract_task.py ├── qa │ ├── data │ │ ├── quality.jsonl │ │ ├── squad.json │ │ └── toeflqa.jsonl │ ├── qa_data.py │ ├── qa_task.py │ └── qa_utils.py ├── registry.py ├── ruler │ ├── cwe.py │ ├── niah.py │ └── vt.py └── story │ ├── filtering.py │ ├── multihop.py │ ├── narrative.py │ ├── retrieval.py │ └── templates.py └── utils ├── __init__.py ├── checks.py ├── data.py ├── general.py └── globals.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | .vscode/ 4 | .venv 5 | .eggs 6 | *.egg-info 7 | *.log 8 | *.so 9 | build 10 | .DS_Store 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | ![logo](assets/photo.png) 3 |

4 | 5 | ## TL;DR 6 | 7 | This repository contains the official implementation for the paper "[The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs](https://arxiv.org/abs/2504.17768)". We perform a large-scale empirical evaluation of *training-free* sparse attention methods in LLMs (7B to 72B parameters) on long sequences (16K to 128K tokens) across diverse tasks. 8 | 9 | **Key Findings:** 10 | 1. **IsoFLOPS:** For very long sequences, larger, highly sparse models are preferable to smaller, dense ones under a fixed compute budget. 11 | 2. **Sparsity Limits:** Higher sparsity can be tolerated during decoding than prefilling while preserving accuracy (correlating with model size), but even moderate sparsity often degrades performance on at least one task. 12 | 3. **No Universal Best:** The optimal sparse attention strategy depends on the task and inference phase; no single method excels everywhere. 13 | 4. **Scaling Laws:** We introduce and validate scaling laws specific to sparse attention, suggesting our findings generalize. 14 | 15 | **Conclusion:** Sparse attention is a key tool for long-context LLMs but requires careful evaluation of accuracy-efficiency trade-offs for specific applications. This codebase enables reproducing our experiments and further research. 16 | 17 | ## Setup 18 | 19 | Follow these steps to set up the environment and prepare for running experiments: 20 | 21 | 1. **Create Virtual Environment and Install Dependencies:** 22 | Set up a dedicated Python environment and install the required packages, including compiling custom CUDA kernels. 23 | 24 | ```bash 25 | # Create a virtual environment using Python 3.11 26 | python3.11 -m venv .venv 27 | 28 | # Activate the virtual environment 29 | source .venv/bin/activate 30 | 31 | # Upgrade pip and install essential build/utility tools 32 | pip install --no-cache-dir --upgrade pip setuptools wheel psutil ninja 33 | 34 | # Install PyTorch 35 | pip install --no-cache-dir torch==2.5.1 36 | 37 | # Install the sparse_frontier project in editable mode 38 | pip install --no-cache-dir -e . 39 | 40 | # Compile custom CUDA kernels (for MInference attention) 41 | # Adjust MAX_JOBS based on your system core count for faster compilation 42 | MAX_JOBS=8 python compile.py build_ext --inplace --build-lib ./sparse_frontier/modelling/attention/minference 43 | ``` 44 | 45 | 2. **Configure Paths:** 46 | Modify the default configuration file to specify where data, results, and checkpoints should be stored on your system. 47 | 48 | * Edit the `paths` section in `sparse_frontier/configs/default.yaml`. 49 | 50 | 3. **Download Model Checkpoints:** 51 | Obtain the pre-trained model weights you intend to evaluate from the Hugging Face Hub. 52 | 53 | * Ensure the final directory structure for the downloaded checkpoints matches the format expected by the corresponding model configuration file (e.g., as defined in `sparse_frontier/configs/model/qwen_7b.yaml`). The `model.path` variable in these configs should point to the directory containing the model files. 54 | 55 | ## Where should I look if I want to: 56 | 57 | ### Reproduce your experiments 58 | 59 | Experiments are launched using the main script `sparse_frontier.main`. Configuration is managed via [Hydra](https://hydra.cc/), allowing parameters to be set in YAML files (primarily `sparse_frontier/configs/default.yaml`) and overridden directly from the command line. 60 | 61 | The execution pipeline typically involves three stages, controlled by the `mode` parameter (defaulting to `all`): 62 | 1.
**Preparation (`preparation.py`):** Generates and saves task-specific data based on the selected `task` configuration. Tasks are defined in `sparse_frontier/tasks/` (inheriting from `AbstractTask` and `AbstractSample`) and registered in `sparse_frontier/tasks/registry.py`. 63 | 2. **Prediction (`prediction.py`):** Runs the specified `model` with the chosen `attention` mechanism on the prepared data, saving the model outputs. Attention mechanisms are defined in `sparse_frontier/modelling/attention/` and registered in `sparse_frontier/modelling/attention/registry.py`. 64 | 3. **Evaluation (`evaluation.py`):** Compares the predictions against the gold answers using the task's specific evaluation logic and saves the final metrics. 65 | 66 | For example, to run the `ruler_niah` task using the `qwen_7b` model configuration with standard `dense` attention on 4 samples using 2 GPUs: 67 | 68 | ```bash 69 | python -m sparse_frontier.main task=ruler_niah model=qwen_7b attention=dense samples=4 gpus=2 70 | ``` 71 | 72 | ### Understand how you modified vLLM to test my own Sparse Attention 73 | 74 | We integrate custom sparse attention mechanisms by intercepting and modifying vLLM's standard attention execution flow. Here's a breakdown of the key components involved: 75 | 76 | 1. **Patching vLLM's Attention:** We replace vLLM's default `FlashAttentionImpl.forward` method with our custom function, `vllm_patched_forward` (defined in `sparse_frontier/modelling/models/vllm_model.py`). This function serves as the entry point for our custom attention logic within the vLLM generation loop. 77 | 78 | 2. **Centralized Handling:** The `vllm_patched_forward` function delegates the core processing to an `AttentionHandler` instance (from `sparse_frontier/modelling/attention/handler.py`). This handler manages layer-specific state (like token counts per head) and differentiates between the prefill and decoding phases of generation. 79 | 80 | 3. **Abstract Attention Interface:** The actual attention computation logic for different patterns is encapsulated in classes that inherit from `AbstractAttention` (defined in `sparse_frontier/modelling/attention/abstract_attention.py`). The `AttentionHandler` retrieves the currently configured attention implementation using `get_attention()` (from `sparse_frontier/modelling/attention/registry.py`). 81 | 82 | 4. **Implementing a Custom Pattern:** To introduce a new sparse attention mechanism: 83 | * Create a new class inheriting from `AbstractAttention`. 84 | * Implement the necessary methods based on your pattern's requirements: 85 | * `__call__(self, queries, keys, values, layer_idx)`: Implement the attention computation logic for the prefill phase. The default implementation uses standard FlashAttention. 86 | * `decode(self, query, keys, values, k_cache, v_cache, cache_seqlens, output, layer_idx)`: Implement the attention computation for the single-token decoding phase, typically involving interaction with the KV cache. The default uses `flash_attn_with_kvcache`. Specific methods like Quest (`efficient_decoding.py`) implement custom logic (e.g., page selection). 87 | * `kv_compress(self, queries, keys, values)`: (Optional) Implement logic to compress or select keys and values *after* the prefill computation, before they are written to the KV cache by `update_kv_cache` in `handler.py`. See `SnapKVCompression` (`kv_compression.py`) for an example. It should return the processed keys, values, and the resulting sequence lengths per head. 88 | 89 | 5. **Registration:** Add your new class to the `ATTENTION_REGISTRY` in `sparse_frontier/modelling/attention/registry.py`. This allows selecting your custom attention mechanism through the experiment configuration files. A minimal end-to-end sketch of these steps is shown below.
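The following is a minimal, hypothetical sketch of such a custom pattern, not part of the repository: the class name `RecentTokensCompression`, the `token_capacity` argument, and the `recent_tokens` config name are made up, and it assumes that `ATTENTION_REGISTRY` is a plain name-to-class mapping whose entries are constructed from the `args` of the corresponding `configs/attention/*.yaml` file. The pattern runs dense FlashAttention during prefill and then keeps only the most recent `token_capacity` tokens per KV head in the cache.

```python
# Hypothetical file: sparse_frontier/modelling/attention/recent_tokens.py
import torch

from .abstract_attention import AbstractAttention, AttentionUtils


class RecentTokensCompression(AbstractAttention):
    """Toy pattern: dense prefill, then keep only the newest tokens in the KV cache."""

    def __init__(self, token_capacity: int = 1024):
        super().__init__()
        self.token_capacity = token_capacity

    def __call__(self, queries, keys, values, layer_idx):
        # Prefill stays dense; only the cache written afterwards is reduced.
        # Record the effective sparsity so it feeds the reported statistics.
        seq_len = queries.shape[2]
        sparsity = max(0.0, 1.0 - self.token_capacity / seq_len)
        self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device))
        return AttentionUtils.flash_attention(queries, keys, values)

    def kv_compress(self, queries, keys, values):
        # keys / values arrive as [num_tokens, num_kv_heads, head_size].
        keep = min(self.token_capacity, keys.shape[0])
        keys_t = keys[-keep:].transpose(0, 1)      # [num_kv_heads, keep, head_size]
        values_t = values[-keep:].transpose(0, 1)  # [num_kv_heads, keep, head_size]
        seq_lens = torch.full((keys.shape[1],), keep, device=keys.device, dtype=torch.long)
        return keys_t, values_t, seq_lens


# Registration sketch, assuming ATTENTION_REGISTRY is a plain name-to-class dict:
# ATTENTION_REGISTRY["recent_tokens"] = RecentTokensCompression
```

Because `decode` is inherited from `AbstractAttention`, decoding automatically runs `flash_attn_with_kvcache` over whatever the compressed cache contains; with a matching (hypothetical) `configs/attention/recent_tokens.yaml`, the pattern could then be selected via `attention=recent_tokens`.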
90 | 91 | ### Understand or Create Experimental Datasets 92 | 93 | Experimental data generation is handled by task-specific modules located in `sparse_frontier/tasks/`. Each task implements `AbstractTask` and `AbstractSample` subclasses (defined in `sparse_frontier/tasks/abstract_*.py`) to define data loading, preprocessing, and the formatting of individual input prompts. Tasks are registered in `sparse_frontier/tasks/registry.py` and selected via configuration (e.g., `task=your_task_name`). The `preparation.py` script orchestrates the generation process based on the configuration, saving the formatted samples. See existing tasks like `QATask` (`qa_task.py`) or the Story task (`narrative.py`, `templates.py`) for implementation examples. 94 | 95 | ## References 96 | 97 | ### Sparse Attention Patterns 98 | 99 | In this repository, we evaluate 6 sparse attention patterns: 100 | 101 | | Pattern | Source | 102 | |---------|--------| 103 | | **Vertical-Slash / Block-Sparse** | [Microsoft](https://github.com/microsoft/MInference) | 104 | | **FlexPrefill** | [ByteDance-Seed](https://github.com/ByteDance-Seed/FlexPrefill) | 105 | | **SnapKV** | [FasterDecoding](https://github.com/FasterDecoding/SnapKV) | 106 | | **Ada-SnapKV** | [FFY0](https://github.com/FFY0/AdaKV) | 107 | | **Quest** | [MIT-HAN-Lab](https://github.com/mit-han-lab/Quest) | 108 | 109 | We either re-implement these patterns based on the original code or borrow implementations, including the kernels for Vertical-Slash and Block-Sparse, from MInference. 110 | 111 | ### Evaluation Tasks 112 | 113 | Our evaluation framework includes the following tasks: 114 | 115 | 1. **RULER Tasks**: Re-implementation of NIAH, VT, and CWE tasks from [NVIDIA/RULER](https://github.com/NVIDIA/RULER) 116 | 117 | 2. **QA Tasks**: 118 | - TOEFL and QuALITY datasets from [LC-VS-RAG](https://github.com/lixinze777/LC_VS_RAG) 119 | - SQuAD dataset from [NVIDIA/RULER](https://github.com/NVIDIA/RULER) 120 | 121 | 3. **Novel Story Tasks**: Narrative tasks developed specifically for this project. 122 | 123 | ## Cite 124 | 125 | If you find this repository useful, please consider citing the paper. 126 | 127 | ``` 128 | @article{nawrot2025sparsefrontier, 129 | title={The Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs}, 130 | author={Piotr Nawrot and Robert Li and Renjie Huang and Sebastian Ruder and Kelly Marchisio and Edoardo M.
Ponti}, 131 | year={2025}, 132 | journal={arXiv:2504.17768} 133 | url={https://arxiv.org/abs/2504.17768}, 134 | } 135 | ``` 136 | 137 | ## Issues: 138 | 139 | If you have any questions, feel free to raise a Github issue or contact me directly at: piotr.nawrot@ed.ac.uk 140 | -------------------------------------------------------------------------------- /assets/photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/assets/photo.png -------------------------------------------------------------------------------- /assets/pipfreeze.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.3.0 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.11.16 4 | aiohttp-cors==0.8.1 5 | aiosignal==1.3.2 6 | airportsdata==20250224 7 | annotated-types==0.7.0 8 | antlr4-python3-runtime==4.9.3 9 | anyio==4.9.0 10 | astor==0.8.1 11 | attrs==25.3.0 12 | blake3==1.0.4 13 | cachetools==5.5.2 14 | certifi==2025.1.31 15 | charset-normalizer==3.4.1 16 | click==8.1.8 17 | cloudpickle==3.1.1 18 | colorful==0.5.6 19 | compressed-tensors==0.8.1 20 | contourpy==1.3.2 21 | cycler==0.12.1 22 | datasets==3.2.0 23 | depyf==0.18.0 24 | dill==0.3.8 25 | diskcache==5.6.3 26 | distlib==0.3.9 27 | distro==1.9.0 28 | einops==0.8.1 29 | fastapi==0.115.12 30 | filelock==3.18.0 31 | flash-attn==2.7.3 32 | fonttools==4.57.0 33 | frozenlist==1.5.0 34 | fsspec==2024.9.0 35 | gguf==0.10.0 36 | gitdb==4.0.12 37 | GitPython==3.1.44 38 | google-api-core==2.24.2 39 | google-auth==2.39.0 40 | googleapis-common-protos==1.70.0 41 | grpcio==1.71.0 42 | h11==0.14.0 43 | httpcore==1.0.8 44 | httptools==0.6.4 45 | httpx==0.28.1 46 | huggingface-hub==0.30.2 47 | hydra-core==1.3.2 48 | idna==3.10 49 | importlib_metadata==8.6.1 50 | interegular==0.3.3 51 | Jinja2==3.1.6 52 | jiter==0.9.0 53 | jsonschema==4.23.0 54 | jsonschema-specifications==2024.10.1 55 | kiwisolver==1.4.8 56 | lark==1.2.2 57 | lm-format-enforcer==0.10.11 58 | MarkupSafe==3.0.2 59 | matplotlib==3.10.0 60 | mistral_common==1.5.4 61 | mpmath==1.3.0 62 | msgpack==1.1.0 63 | msgspec==0.19.0 64 | multidict==6.4.3 65 | multiprocess==0.70.16 66 | nest-asyncio==1.6.0 67 | networkx==3.4.2 68 | ninja==1.11.1.4 69 | numpy==1.26.4 70 | nvidia-cublas-cu12==12.4.5.8 71 | nvidia-cuda-cupti-cu12==12.4.127 72 | nvidia-cuda-nvrtc-cu12==12.4.127 73 | nvidia-cuda-runtime-cu12==12.4.127 74 | nvidia-cudnn-cu12==9.1.0.70 75 | nvidia-cufft-cu12==11.2.1.3 76 | nvidia-curand-cu12==10.3.5.147 77 | nvidia-cusolver-cu12==11.6.1.9 78 | nvidia-cusparse-cu12==12.3.1.170 79 | nvidia-ml-py==12.570.86 80 | nvidia-nccl-cu12==2.21.5 81 | nvidia-nvjitlink-cu12==12.4.127 82 | nvidia-nvtx-cu12==12.4.127 83 | omegaconf==2.3.0 84 | openai==1.74.0 85 | opencensus==0.11.4 86 | opencensus-context==0.1.3 87 | opencv-python-headless==4.11.0.86 88 | outlines==0.1.11 89 | outlines_core==0.1.26 90 | packaging==24.2 91 | pandas==2.2.3 92 | partial-json-parser==0.2.1.1.post5 93 | patsy==1.0.1 94 | pillow==11.2.1 95 | platformdirs==4.3.7 96 | prometheus-fastapi-instrumentator==7.1.0 97 | prometheus_client==0.21.1 98 | propcache==0.3.1 99 | proto-plus==1.26.1 100 | protobuf==6.30.2 101 | psutil==7.0.0 102 | py-cpuinfo==9.0.0 103 | py-spy==0.4.0 104 | pyarrow==19.0.1 105 | pyasn1==0.6.1 106 | pyasn1_modules==0.4.2 107 | pycountry==24.6.1 108 | pydantic==2.11.3 109 | pydantic_core==2.33.1 110 | pyparsing==3.2.3 111 | python-dateutil==2.9.0.post0 112 | 
python-dotenv==1.1.0 113 | pytz==2025.2 114 | PyYAML==6.0.2 115 | pyzmq==26.4.0 116 | ray==2.44.1 117 | referencing==0.36.2 118 | regex==2024.11.6 119 | requests==2.32.3 120 | rpds-py==0.24.0 121 | rsa==4.9.1 122 | safetensors==0.5.3 123 | scipy==1.15.2 124 | seaborn==0.13.2 125 | sentencepiece==0.2.0 126 | six==1.17.0 127 | smart-open==7.1.0 128 | smmap==5.0.2 129 | sniffio==1.3.1 130 | starlette==0.46.2 131 | statsmodels==0.14.4 132 | sympy==1.13.1 133 | tiktoken==0.9.0 134 | tokenizers==0.21.1 135 | torch==2.5.1 136 | torchvision==0.20.1 137 | tqdm==4.67.1 138 | transformers==4.48.0 139 | triton==3.1.0 140 | typing-inspection==0.4.0 141 | typing_extensions==4.13.2 142 | tzdata==2025.2 143 | urllib3==2.4.0 144 | uvicorn==0.34.1 145 | uvloop==0.21.0 146 | virtualenv==20.30.0 147 | vllm==0.6.6.post1 148 | watchfiles==1.0.5 149 | websockets==15.0.1 150 | wonderwords==2.2.0 151 | wrapt==1.17.2 152 | xformers==0.0.28.post3 153 | xgrammar==0.1.18 154 | xxhash==3.5.0 155 | yarl==1.19.0 156 | zipp==3.21.0 157 | -------------------------------------------------------------------------------- /compile.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='minference', 6 | ext_modules=[ 7 | CUDAExtension( 8 | name='minference', 9 | sources=['./sparse_frontier/modelling/attention/minference/csrc/kernels.cpp', './sparse_frontier/modelling/attention/minference/csrc/vertical_slash_index.cu'], 10 | ), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }, 15 | package_dir={'': 'sparse_frontier/modelling/attention/minference'} 16 | ) 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="sparse_frontier", 5 | version="0.01", 6 | description="Official implementation of the Sparse Frontier: Sparse Attention Trade-offs in Transformer LLMs", 7 | url="https://github.com/PiotrNawrot/sparse-frontier", 8 | packages=find_packages(include=['sparse_frontier', 'sparse_frontier.*']), 9 | entry_points={ 10 | 'vllm.general_plugins':[ 11 | "swap_vllm_attention = sparse_frontier.modelling.models.vllm_model:swap_vllm_attention" 12 | ] 13 | }, 14 | install_requires=[ 15 | "transformers==4.48.0", 16 | "datasets==3.2.0", 17 | "flash-attn==2.7.3", 18 | "vllm==0.6.6.post1", 19 | "accelerate==1.3.0", 20 | "hydra-core==1.3.2", 21 | "omegaconf==2.3.0", 22 | "matplotlib==3.10.0", 23 | "wonderwords", 24 | "gitpython", 25 | "matplotlib", 26 | "pandas", 27 | "seaborn", 28 | "statsmodels", 29 | "pyyaml", 30 | "numpy", 31 | ], 32 | python_requires=">=3.10", 33 | ) 34 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/ada_snapkv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: ada_snapkv 5 | args: 6 | token_capacity: 1024 7 | approximation_window: 256 8 | kernel_size: 21 9 | local_window: 128 10 | prefix_tokens: 4 11 | min_head_capacity_ratio: 0.2 12 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/block_sparse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: 
block_sparse 5 | args: 6 | chunk_size: 16 7 | top_chunks: 64 8 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/dense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: dense 5 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/flexprefill.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: flexprefill 5 | args: 6 | alpha: 0.65 7 | approximation_size: 512 8 | min_budget: 512 9 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/quest.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: quest 5 | args: 6 | token_budget: 1024 7 | page_size: 16 8 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/snapkv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: snapkv 5 | args: 6 | token_capacity: 1024 7 | approximation_window: 256 8 | kernel_size: 21 9 | local_window: 128 10 | prefix_tokens: 4 11 | -------------------------------------------------------------------------------- /sparse_frontier/configs/attention/vertical_and_slash.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | attention: 4 | name: vertical_and_slash 5 | args: 6 | vertical_size: 512 7 | slash_size: 512 8 | approximation_size: 256 9 | -------------------------------------------------------------------------------- /sparse_frontier/configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: qwen_7b 4 | - attention: dense 5 | - task: ruler_niah 6 | 7 | overwrite: false # overwrite current results 8 | debug: false # changes the directory we create data and results in - this makes overwrite safe to use 9 | mode: "all" 10 | gpus: 1 11 | 12 | samples: 100 13 | max_input_tokens: 8192 14 | max_output_tokens: 1024 15 | kv_cache_block_size: 256 16 | random_seed: 43 17 | 18 | print_eval_results: true 19 | 20 | paths: 21 | results: /writeable/evaluation_results 22 | predictions: /writeable/predictions 23 | data: /writeable/data 24 | debug: /writeable/debug 25 | checkpoints: /writeable/checkpoints 26 | 27 | hydra: 28 | run: 29 | dir: /writeable/evaluation_results/hydraoutputs/${now:%Y-%m-%d_%H-%M-%S} 30 | job: 31 | env_set: 32 | OMP_NUM_THREADS: 1 33 | -------------------------------------------------------------------------------- /sparse_frontier/configs/model/qwen_14b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | name: qwen_14b 5 | path: ${paths.checkpoints}/Qwen2.5-14B-Instruct 6 | num_q_heads: 40 7 | num_kv_heads: 8 8 | num_layers: 48 9 | hidden_dim: 5120 10 | intermediate_dim: 13824 11 | vocab_size: 152064 12 | 13 | tp: 1 14 | -------------------------------------------------------------------------------- /sparse_frontier/configs/model/qwen_32b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | name: qwen_32b 5 | path: 
${paths.checkpoints}/Qwen2.5-32B-Instruct 6 | num_q_heads: 40 7 | num_kv_heads: 8 8 | num_layers: 64 9 | hidden_dim: 5120 10 | intermediate_dim: 27648 11 | vocab_size: 152064 12 | 13 | tp: 2 14 | -------------------------------------------------------------------------------- /sparse_frontier/configs/model/qwen_72b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | name: qwen_72b 5 | path: ${paths.checkpoints}/Qwen2.5-72B-Instruct 6 | num_q_heads: 64 7 | num_kv_heads: 8 8 | num_layers: 80 9 | hidden_dim: 8192 10 | intermediate_dim: 29568 11 | vocab_size: 152064 12 | 13 | tp: 4 14 | -------------------------------------------------------------------------------- /sparse_frontier/configs/model/qwen_7b.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | name: qwen_7b 5 | path: ${paths.checkpoints}/Qwen2.5-7B-Instruct 6 | num_q_heads: 28 7 | num_kv_heads: 4 8 | num_layers: 28 9 | hidden_dim: 3584 10 | intermediate_dim: 18944 11 | vocab_size: 152064 12 | 13 | tp: 1 14 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/qa_quality.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: qa_quality 5 | args: 6 | dataset_name: quality 7 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/qa_squad.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: qa_squad 5 | args: 6 | dataset_name: squad 7 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/qa_toefl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: qa_toefl 5 | args: 6 | dataset_name: toeflqa 7 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/ruler_cwe.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: ruler_cwe 5 | args: 6 | num_common_words: 10 7 | common_word_frequency: 30 8 | rare_word_frequency: 3 9 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/ruler_niah.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: ruler_niah 5 | args: 6 | num_queries: 4 7 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/ruler_vt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: ruler_vt 5 | args: 6 | num_chains: 8 7 | num_hops: 4 8 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/story_filtering.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: story_filtering 5 | args: 6 | chapters_in_question: 3 7 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/story_multihop.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: story_multihop 5 | -------------------------------------------------------------------------------- /sparse_frontier/configs/task/story_retrieval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task: 4 | name: story_retrieval 5 | args: 6 | num_queries: 16 7 | -------------------------------------------------------------------------------- /sparse_frontier/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import json 3 | from sparse_frontier.utils.general import GlobalSettings 4 | from sparse_frontier.utils.data import get_pred_path, get_data_path, read_jsonl, get_results_path 5 | from sparse_frontier.tasks.registry import TASK_REGISTRY 6 | 7 | 8 | def merge_data_and_predictions(data: List[dict], predictions: List[dict]) -> List[dict]: 9 | """Merge data samples with model predictions based on index. 10 | 11 | Args: 12 | data: List of dictionaries containing input data and gold answers 13 | predictions: List of dictionaries containing model predictions and metrics 14 | 15 | Returns: 16 | List of merged dictionaries with all fields 17 | 18 | Raises: 19 | AssertionError: If indexes don't match between data and predictions 20 | """ 21 | # Create index mappings 22 | data_by_index = {item['index']: item for item in data} 23 | pred_by_index = {item['index']: item for item in predictions} 24 | 25 | # Verify all indexes match 26 | data_indexes = set(data_by_index.keys()) 27 | pred_indexes = set(pred_by_index.keys()) 28 | assert data_indexes == pred_indexes, \ 29 | f"Mismatch between data and prediction indexes. Missing from data: {pred_indexes - data_indexes}. 
" \ 30 | f"Missing from predictions: {data_indexes - pred_indexes}" 31 | 32 | # Merge data and predictions 33 | merged = [] 34 | for idx in data_indexes: 35 | sample = data_by_index[idx].copy() 36 | sample.update(pred_by_index[idx]) 37 | merged.append(sample) 38 | 39 | assert len(merged) == len(data) == len(predictions), \ 40 | f"Merged data length {len(merged)} doesn't match input lengths: data={len(data)}, predictions={len(predictions)}" 41 | 42 | return merged 43 | 44 | 45 | def evaluate_task() -> None: 46 | cfg = GlobalSettings.get('cfg') 47 | 48 | results_file = get_results_path() 49 | 50 | # Load and merge data and predictions 51 | data = [x for x in read_jsonl(get_data_path()) if x['index'] < cfg.samples] 52 | predictions = [x for x in read_jsonl(get_pred_path()) if x['index'] < cfg.samples] 53 | examples = merge_data_and_predictions(data, predictions) 54 | 55 | metrics = TASK_REGISTRY[cfg.task.name].evaluate(examples) 56 | 57 | # Add total number of samples evaluated 58 | metrics['total_samples'] = len(examples) 59 | 60 | # Calculate average sparsity 61 | total_sparsity = sum(example['sparsity'] for example in examples) 62 | metrics['average_attention_sparsity'] = total_sparsity / len(examples) 63 | 64 | # Calculate average and max output token length if available 65 | if any('output_tokens_len' in example for example in examples): 66 | total_output_tokens = sum(example['output_tokens_len'] for example in examples if 'output_tokens_len' in example) 67 | num_examples_with_len = sum(1 for example in examples if 'output_tokens_len' in example) 68 | metrics['average_output_tokens'] = total_output_tokens / num_examples_with_len 69 | metrics['max_output_tokens'] = max(example['output_tokens_len'] for example in examples if 'output_tokens_len' in example) 70 | 71 | with open(results_file, 'w', encoding='utf-8') as f: 72 | json.dump(metrics, f, indent=4) 73 | 74 | print(f'Evaluation results for task {cfg.task.name} saved to {results_file}') 75 | 76 | if cfg.print_eval_results: 77 | print(f'Evaluation results for task {cfg.task.name}:') 78 | print(json.dumps(metrics, indent=2)) 79 | -------------------------------------------------------------------------------- /sparse_frontier/main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | 5 | def setting_up(cfg): 6 | import os 7 | import random 8 | import numpy as np 9 | 10 | from sparse_frontier.utils import GlobalSettings 11 | from sparse_frontier.utils.data import get_data_path, get_pred_path, get_results_path 12 | 13 | random.seed(cfg.random_seed) 14 | np.random.seed(cfg.random_seed) 15 | # In case of torch seed it's set in vLLM Model class 16 | 17 | if cfg.overwrite: 18 | assert cfg.debug, "Overwrite is only allowed in debug mode" 19 | 20 | GlobalSettings.set('cfg', cfg) 21 | 22 | os.makedirs(os.path.dirname(get_data_path()), exist_ok=True) 23 | os.makedirs(os.path.dirname(get_pred_path()), exist_ok=True) 24 | os.makedirs(os.path.dirname(get_results_path()), exist_ok=True) 25 | 26 | 27 | def run(cfg): 28 | setting_up(cfg) 29 | 30 | from sparse_frontier.utils.checks import prepration_needed, prediction_needed, evaluation_needed 31 | 32 | if cfg.mode in ["prep", "all"]: 33 | if prepration_needed(): 34 | from sparse_frontier.preparation import prepare_task 35 | prepare_task() 36 | 37 | if cfg.mode in ["pred", "all"]: 38 | if prediction_needed(): 39 | from sparse_frontier.prediction import predict_task 40 | predict_task() 41 | 42 | if cfg.mode in 
["eval", "all"]: 43 | if evaluation_needed(): 44 | from sparse_frontier.evaluation import evaluate_task 45 | evaluate_task() 46 | 47 | if cfg.mode not in ["prep", "pred", "eval", "all"]: 48 | raise ValueError(f'Invalid mode: {cfg.mode}') 49 | 50 | 51 | @hydra.main(config_path="configs", config_name="default", version_base="1.3") 52 | def main(cfg: DictConfig): 53 | run(cfg) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/__init__.py -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/attention/__init__.py -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/abstract_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from abc import ABC 4 | from vllm.distributed import ( 5 | get_tensor_model_parallel_world_size, 6 | tensor_model_parallel_all_gather, 7 | ) 8 | from flash_attn import flash_attn_func 9 | from vllm.attention.backends.flash_attn import flash_attn_with_kvcache 10 | 11 | 12 | class AttentionUtils: 13 | @staticmethod 14 | def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 15 | """Compute attention using Flash Attention 16 | 17 | flash_attn_func expects BLHD format, so we need to convert the input tensors to this format. 18 | 19 | Args: 20 | q: Query tensor of shape (batch_size, num_heads, seq_len, head_dim) 21 | k: Key tensor of shape (batch_size, num_kv_heads, seq_len, head_dim) 22 | v: Value tensor of shape (batch_size, num_kv_heads, seq_len, head_dim) 23 | 24 | Returns: 25 | Attention output tensor of shape (batch_size, num_heads, seq_len, head_dim) 26 | """ 27 | out = flash_attn_func( 28 | q.transpose(1, 2), 29 | k.transpose(1, 2), 30 | v.transpose(1, 2), 31 | causal=True, 32 | ) 33 | return out.transpose(1, 2) 34 | 35 | @staticmethod 36 | def reshape_kv_cache( 37 | kv_cache: torch.Tensor, 38 | target_block_size: int, 39 | max_blocks: int, 40 | ) -> tuple[torch.Tensor, torch.Tensor]: 41 | """Retrieve keys and values from cache for attention computation. 
42 | 43 | Args: 44 | kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size] 45 | target_block_size: Target block size for reshaping 46 | 47 | Returns: 48 | k_cache: [num_kv_heads, num_blocks, target_block_size, head_size] 49 | v_cache: [num_kv_heads, num_blocks, target_block_size, head_size] 50 | """ 51 | num_blocks, block_size, num_kv_heads, head_size = kv_cache[0].shape 52 | final_num_blocks = min(max_blocks, (num_blocks * block_size) // target_block_size) 53 | left_original_num_blocks = (final_num_blocks * target_block_size) // block_size 54 | 55 | k_cache = kv_cache[0, :left_original_num_blocks, :, :].view(num_kv_heads, final_num_blocks, target_block_size, head_size) 56 | v_cache = kv_cache[1, :left_original_num_blocks, :, :].view(num_kv_heads, final_num_blocks, target_block_size, head_size) 57 | 58 | return k_cache, v_cache 59 | 60 | 61 | class AbstractAttention(ABC): 62 | """Base class for attention implementations (both prefilling and KV compression)""" 63 | def __init__(self): 64 | self.sparsity_statistics = [] 65 | self.layer_sparsity_statistics = [] 66 | self.block_table = None 67 | self.cache_batch_idx = None 68 | 69 | def reset_sparsity_statistics(self): 70 | """Reset the accumulated sparsity statistics.""" 71 | self.sparsity_statistics = [] 72 | self.layer_sparsity_statistics = [] 73 | 74 | def sync_and_calc_layer_stats(self): 75 | # Ensure we have layer sparsity statistics to process 76 | if not self.layer_sparsity_statistics: 77 | raise AssertionError("Layer sparsity statistics list is empty. Make sure statistics are collected before syncing.") 78 | 79 | layer_sparsity = torch.stack(self.layer_sparsity_statistics).mean(dim=0, keepdim=True) 80 | 81 | if get_tensor_model_parallel_world_size() > 1: 82 | layer_sparsity = tensor_model_parallel_all_gather(layer_sparsity) 83 | 84 | self.sparsity_statistics.append(layer_sparsity.mean().item()) 85 | self.layer_sparsity_statistics = [] 86 | 87 | def calculate_sparsity(self) -> float: 88 | return sum(self.sparsity_statistics) / len(self.sparsity_statistics) 89 | 90 | def __call__( 91 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int 92 | ) -> torch.Tensor: 93 | """Compute attention with pattern-specific masking. 94 | 95 | Args: 96 | queries: Query tensor of shape (batch_size, num_heads, seq_len, head_dim) 97 | keys: Key tensor of shape (batch_size, num_kv_heads, seq_len, head_dim) 98 | values: Value tensor of shape (batch_size, num_kv_heads, seq_len, head_dim) 99 | layer_idx: Index of the current transformer layer 100 | Returns: 101 | Attention output tensor of shape (batch_size, num_heads, seq_len, head_dim) 102 | """ 103 | return AttentionUtils.flash_attention(queries, keys, values) 104 | 105 | def decode( 106 | self, 107 | query: torch.Tensor, # [1, num_heads, head_dim] 108 | keys: torch.Tensor, # [1, num_kv_heads, head_dim] 109 | values: torch.Tensor, # [1, num_kv_heads, head_dim] 110 | k_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim] 111 | v_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim] 112 | cache_seqlens: torch.Tensor, # [num_heads] 113 | output: torch.Tensor, # [1, num_heads, head_dim] 114 | layer_idx: int, 115 | ) -> torch.Tensor: 116 | """Compute attention during decoding phase using flash_attn_with_kvcache. 
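The KV cache is addressed through a per-query-head block table, and `cache_seqlens` is tracked per head, so heads whose caches were compressed to different lengths can still be served by a single `flash_attn_with_kvcache` call.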
117 | 118 | Args: 119 | query: Query tensor for a single token [1, num_heads, head_dim] 120 | keys: Key tensor for the current token [1, num_kv_heads, head_dim] 121 | values: Value tensor for the current token [1, num_kv_heads, head_dim] 122 | k_cache: Key cache tensor [num_kv_heads, num_blocks, block_size, head_dim] 123 | v_cache: Value cache tensor [num_kv_heads, num_blocks, block_size, head_dim] 124 | cache_seqlens: Tensor of sequence lengths per head [num_heads] 125 | output: Output tensor to store results [1, num_heads, head_dim] 126 | layer_idx: Index of the current transformer layer 127 | """ 128 | _, num_q_heads, _ = query.shape 129 | num_kv_heads, num_blocks, block_size, head_size = k_cache.shape 130 | 131 | if self.block_table is None: 132 | block_indices = torch.arange(num_blocks * num_kv_heads, device=query.device, dtype=torch.int32).reshape(num_kv_heads, num_blocks) 133 | block_indices = block_indices.repeat(1, num_q_heads // num_kv_heads) 134 | self.block_table = block_indices.reshape(num_q_heads, num_blocks) 135 | 136 | flash_attn_with_kvcache( 137 | q=query.squeeze(0).unsqueeze(1).unsqueeze(1), 138 | k_cache=k_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size), 139 | v_cache=v_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size), 140 | block_table=self.block_table, 141 | cache_seqlens=cache_seqlens, 142 | causal=True, 143 | out=output.squeeze(0).unsqueeze(1).unsqueeze(1), 144 | ) 145 | 146 | def kv_compress( 147 | self, 148 | queries: torch.Tensor, # [num_tokens, num_heads, head_dim] 149 | keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 150 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 151 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 152 | """Compress KV cache after prefilling (default: no compression) 153 | 154 | Returns: 155 | tuple: (compressed_keys, compressed_values, seq_lens) where: 156 | - compressed_keys: [num_kv_heads, max_seq_len, head_size] 157 | - compressed_values: [num_kv_heads, max_seq_len, head_size] 158 | - seq_lens: [num_kv_heads] tensor with actual sequence length per head 159 | """ 160 | # Default implementation: no compression, all tokens kept 161 | seq_lens = torch.full((keys.size(1),), keys.size(0), device=keys.device, dtype=torch.long) 162 | # Transpose keys and values to match the expected output shape 163 | keys_t = keys.transpose(0, 1) # [num_kv_heads, num_tokens, head_size] 164 | values_t = values.transpose(0, 1) # [num_kv_heads, num_tokens, head_size] 165 | return keys_t, values_t, seq_lens 166 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/efficient_decoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple, List 3 | from .abstract_attention import AbstractAttention 4 | from .abstract_attention import AttentionUtils 5 | from vllm.attention.backends.flash_attn import flash_attn_with_kvcache 6 | 7 | 8 | def _update_last_page( 9 | page_reps: torch.Tensor, 10 | keys: torch.Tensor, # [1, num_heads, head_dim] 11 | cache_seqlens: int, 12 | page_size: int, 13 | ): 14 | """Update representations of the page containing the current token. 
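Quest keeps an elementwise min/max summary of the keys in every page, so at decode time only the page holding the newest token needs its summary refreshed.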
15 | 16 | Args: 17 | page_reps: Page representations tensor 18 | keys: Key tensor for the current token 19 | cache_seqlens: Cache sequence length as an integer 20 | page_size: Size of each page in the KV cache 21 | """ 22 | current_page_idx = (cache_seqlens - 1) // page_size 23 | 24 | page_reps[current_page_idx, 0] = torch.minimum( 25 | page_reps[current_page_idx, 0], 26 | keys.squeeze(0) 27 | ) 28 | 29 | page_reps[current_page_idx, 1] = torch.maximum( 30 | page_reps[current_page_idx, 1], 31 | keys.squeeze(0) 32 | ) 33 | 34 | 35 | def _select_pages( 36 | page_reps: torch.Tensor, 37 | query: torch.Tensor, 38 | cache_seqlens: int, 39 | page_size: int, 40 | page_budget: int, 41 | offsets: torch.Tensor, 42 | ) -> Tuple[torch.Tensor, int]: 43 | """Select most relevant pages based on query-page similarity. 44 | 45 | Args: 46 | page_reps: Page representations tensor 47 | query: Query tensor for the current token 48 | cache_seqlens: Cache sequence length as an integer 49 | page_size: Size of each page in the KV cache 50 | page_budget: Maximum number of pages to select 51 | offsets: Offsets for each KV head 52 | 53 | Returns: 54 | Tuple of (selected page indices, new cache sequence length) 55 | """ 56 | current_page_idx = (cache_seqlens - 1) // page_size 57 | assert current_page_idx > 0 and current_page_idx >= page_budget + 1 58 | 59 | # All models have GQA 60 | query_squeezed = query.squeeze(0) 61 | group_size = query_squeezed.size(0) // page_reps.size(2) 62 | page_reps = page_reps[:current_page_idx].repeat_interleave(group_size, dim=2) 63 | 64 | scores = torch.einsum( 65 | 'hd,prhd->hprd', 66 | query_squeezed, 67 | page_reps 68 | ) 69 | 70 | scores = scores.max(dim=2).values # [num_pages, num_heads, head_dim] 71 | scores = scores.sum(dim=-1) 72 | 73 | _, indices = torch.topk( 74 | scores, 75 | k=page_budget + 1, 76 | dim=1, 77 | sorted=True, 78 | ) 79 | 80 | indices[:, -1] = current_page_idx 81 | new_cache_seqlens = cache_seqlens - (current_page_idx - page_budget) * page_size 82 | new_cache_seqlens = torch.full((query.shape[1],), new_cache_seqlens, device=query.device, dtype=torch.int32) 83 | 84 | active_pages = indices.int() + offsets 85 | return active_pages, new_cache_seqlens 86 | 87 | 88 | _select_pages = torch.compile(_select_pages) 89 | _update_last_page = torch.jit.script(_update_last_page) 90 | 91 | 92 | class QuestAttention(AbstractAttention): 93 | """Quest attention for efficient decoding with dynamic page selection. 94 | 95 | Quest maintains min and max representations for each page of KV cache and uses 96 | them to dynamically select the most relevant pages during decoding. 97 | """ 98 | 99 | def __init__( 100 | self, 101 | token_budget: int, 102 | page_size: int, 103 | max_input_tokens: int, 104 | max_output_tokens: int, 105 | num_layers: int, 106 | ): 107 | """Initialize Quest attention. 
108 | 109 | Args: 110 | token_budget: Maximum number of tokens to attend to 111 | page_size: Size of each page in the KV cache 112 | max_input_tokens: Maximum input token length (from config) 113 | max_output_tokens: Maximum output token length (from config) 114 | num_layers: Number of transformer layers (from model config) 115 | """ 116 | super().__init__() 117 | self.token_budget = token_budget 118 | self.page_size = page_size 119 | self.page_budget = token_budget // page_size 120 | assert token_budget % page_size == 0, "Token budget must be divisible by page size" 121 | 122 | self.max_pages = ((max_input_tokens + max_output_tokens) + page_size - 1) // page_size 123 | self.num_layers = num_layers 124 | 125 | # Page representations per layer 126 | self.page_reps_per_layer: List[Optional[torch.Tensor]] = [None] * num_layers 127 | self.offsets = None 128 | 129 | def _init_page_reps( 130 | self, 131 | keys: torch.Tensor, 132 | layer_idx: int = 0, 133 | ): 134 | """Initialize page representations during prefilling for a specific layer.""" 135 | _, num_heads, seq_len, head_dim = keys.shape 136 | keys = keys.squeeze(0).transpose(0, 1) # [seq_len, num_heads, head_dim] 137 | 138 | num_pages = (seq_len + self.page_size - 1) // self.page_size 139 | 140 | if self.page_reps_per_layer[layer_idx] is None: 141 | self.page_reps_per_layer[layer_idx] = torch.zeros( 142 | self.max_pages, 2, num_heads, head_dim, 143 | device=keys.device, dtype=keys.dtype 144 | ) 145 | 146 | self.page_reps_per_layer[layer_idx][:, 0] = float('inf') 147 | self.page_reps_per_layer[layer_idx][:, 1] = float('-inf') 148 | 149 | if seq_len % self.page_size != 0: 150 | complete_pages = seq_len // self.page_size 151 | complete_keys = keys[:complete_pages * self.page_size].view(complete_pages, self.page_size, num_heads, head_dim) 152 | complete_min = complete_keys.amin(dim=1) 153 | complete_max = complete_keys.amax(dim=1) 154 | self.page_reps_per_layer[layer_idx][:complete_pages] = torch.stack([complete_min, complete_max], dim=1) 155 | 156 | remainder_keys = keys[complete_pages * self.page_size:] 157 | 158 | self.page_reps_per_layer[layer_idx][complete_pages] = torch.stack([ 159 | remainder_keys.amin(dim=0), 160 | remainder_keys.amax(dim=0), 161 | ], dim=0) 162 | else: 163 | keys = keys.view(num_pages, self.page_size, num_heads, head_dim) 164 | self.page_reps_per_layer[layer_idx][:num_pages] = torch.stack([ 165 | keys.amin(dim=1), 166 | keys.amax(dim=1), 167 | ], dim=1) 168 | 169 | def __call__( 170 | self, 171 | queries: torch.Tensor, # [batch_size, num_heads, seq_len, head_dim] 172 | keys: torch.Tensor, # [batch_size, num_kv_heads, seq_len, head_dim] 173 | values: torch.Tensor, # [batch_size, num_kv_heads, seq_len, head_dim] 174 | layer_idx: int = 0, 175 | ) -> torch.Tensor: 176 | """ 177 | Args: 178 | queries: Query tensor of shape [batch_size, num_heads, seq_len, head_dim] 179 | keys: Key tensor of shape [batch_size, num_kv_heads, seq_len, head_dim] 180 | values: Value tensor of shape [batch_size, num_kv_heads, seq_len, head_dim] 181 | layer_idx: Index of the current transformer layer 182 | 183 | Returns: 184 | Attention output tensor of shape [batch_size, num_heads, seq_len, head_dim] 185 | """ 186 | self._init_page_reps(keys, layer_idx) 187 | 188 | sparsity = 1.0 - self.token_budget / queries.shape[2] 189 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device)) 190 | 191 | return AttentionUtils.flash_attention(queries, keys, values) 192 | 193 | def decode( 194 | self, 195 | query: torch.Tensor, # [1, 
num_heads, head_dim] 196 | keys: torch.Tensor, # [1, num_kv_heads, head_dim] 197 | values: torch.Tensor, # [1, num_kv_heads, head_dim] 198 | k_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim] 199 | v_cache: torch.Tensor, # [num_kv_heads, num_blocks, block_size, head_dim] 200 | cache_seqlens: torch.Tensor, # [num_heads] 201 | output: torch.Tensor, # [1, num_heads, head_dim] 202 | layer_idx: int = 0, 203 | ) -> torch.Tensor: 204 | """Compute attention during decoding with dynamic page selection for a specific layer. 205 | 206 | Instead of retrieving the selected KV cache, we pass the entire KV cache 207 | and the selected page indices to flash_attn_with_kvcache. 208 | 209 | Args: 210 | query: Query tensor for a single token [1, num_heads, head_dim] 211 | keys: Key tensor for the current token [1, num_kv_heads, head_dim] 212 | values: Value tensor for the current token [1, num_kv_heads, head_dim] 213 | k_cache: Key cache tensor [num_kv_heads, num_blocks, block_size, head_dim] 214 | v_cache: Value cache tensor [num_kv_heads, num_blocks, block_size, head_dim] 215 | cache_seqlens: Tensor of sequence lengths per head [num_heads] 216 | output: Output tensor to store results [1, num_heads, head_dim] 217 | layer_idx: Index of the current transformer layer 218 | 219 | Returns: 220 | Attention output tensor of shape [1, num_heads, head_dim] 221 | """ 222 | num_kv_heads, num_blocks, block_size, head_size = k_cache.shape 223 | _, num_q_heads, _ = query.shape 224 | cache_seqlens_int = cache_seqlens[0].item() # Convert to integer for page selection 225 | 226 | if self.offsets is None: 227 | offsets = torch.arange(num_kv_heads, device=query.device, dtype=torch.int32) 228 | offsets = offsets.repeat_interleave(num_q_heads // num_kv_heads) 229 | offsets = offsets.unsqueeze(1) * num_blocks 230 | self.offsets = offsets 231 | 232 | _update_last_page( 233 | page_reps=self.page_reps_per_layer[layer_idx], 234 | keys=keys, 235 | cache_seqlens=cache_seqlens_int, 236 | page_size=self.page_size 237 | ) 238 | 239 | active_pages, new_cache_seqlens = _select_pages( 240 | page_reps=self.page_reps_per_layer[layer_idx], 241 | query=query, 242 | cache_seqlens=cache_seqlens_int, 243 | page_size=self.page_size, 244 | page_budget=self.page_budget, 245 | offsets=self.offsets, 246 | ) 247 | 248 | flash_attn_with_kvcache( 249 | q=query.squeeze(0).unsqueeze(1).unsqueeze(1), 250 | k_cache=k_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size), 251 | v_cache=v_cache.view(num_kv_heads * num_blocks, block_size, 1, head_size), 252 | block_table=active_pages, 253 | cache_seqlens=new_cache_seqlens, 254 | causal=True, 255 | out=output.squeeze(0).unsqueeze(1).unsqueeze(1), 256 | ) 257 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/efficient_prefilling.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .abstract_attention import AbstractAttention 4 | from .minference import ( 5 | block_sparse_attention, 6 | vertical_and_slash_kernel, 7 | vertical_slash_sparse_attention, 8 | sum_over_diagonals, 9 | ) 10 | from .abstract_attention import AttentionUtils 11 | 12 | 13 | class DenseAttention(AbstractAttention): 14 | """Standard dense attention with causal masking.""" 15 | 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def __call__( 20 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int 21 | ) -> torch.Tensor: 22 | 
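# Dense attention attends to every previous token, so a sparsity of 0.0 is recorded
# for the statistics before delegating to the standard FlashAttention kernel.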
self.layer_sparsity_statistics.append(torch.tensor(0.0, device=queries.device)) 23 | return AttentionUtils.flash_attention(queries, keys, values) 24 | 25 | 26 | class BlockSparseAttentionMInference(AbstractAttention): 27 | """Block-sparse attention with chunk-level relevance and local windows.""" 28 | 29 | def __init__(self, chunk_size: int, top_chunks: int): 30 | super().__init__() 31 | self.chunk_size = chunk_size 32 | self.top_chunks = top_chunks 33 | 34 | assert chunk_size >= 8, "Recommended chunk size is >= 8" 35 | assert top_chunks >= 3, "Must select at least one top chunk" 36 | 37 | @staticmethod 38 | def _calculate_sparsity(seq_len: int, chunk_size: int, top_chunks: int) -> float: 39 | """Calculate sparsity ratio based on selected blocks vs total possible blocks. 40 | 41 | For autoregressive attention: 42 | - First query block picks 1 key block 43 | - First top_chunks query blocks pick max possible key blocks 44 | - Remaining query blocks pick top_chunks key blocks each 45 | 46 | Args: 47 | seq_len: Length of input sequence 48 | chunk_size: Size of each attention block 49 | top_chunks: Number of chunks to select per query 50 | 51 | Returns: 52 | Sparsity ratio between 0 and 1 53 | """ 54 | num_blocks = (seq_len + chunk_size - 1) // chunk_size 55 | 56 | total_blocks = num_blocks * (num_blocks + 1) // 2 57 | 58 | selected_blocks = top_chunks * (top_chunks + 1) // 2 59 | selected_blocks += (num_blocks - top_chunks) * top_chunks 60 | 61 | return 1.0 - (selected_blocks / total_blocks) 62 | 63 | @staticmethod 64 | def _get_blocks_for_sparsity(seq_len: int, chunk_size: int, target_sparsity: float) -> int: 65 | """Calculate number of blocks needed to achieve desired sparsity level. 66 | 67 | Uses binary search to find the number of blocks that gives sparsity closest 68 | to the target. The relationship between blocks and sparsity is monotonic. 69 | 70 | Args: 71 | seq_len: Length of input sequence 72 | chunk_size: Size of each attention block 73 | target_sparsity: Desired sparsity ratio between 0 and 1 74 | 75 | Returns: 76 | Number of blocks to select per query to achieve target sparsity 77 | """ 78 | num_blocks = (seq_len + chunk_size - 1) // chunk_size 79 | 80 | # Binary search for number of blocks 81 | left, right = 3, num_blocks # Minimum 4 blocks needed 82 | best_blocks = 3 83 | best_diff = float('inf') 84 | 85 | while left <= right: 86 | mid = (left + right) // 2 87 | sparsity = BlockSparseAttentionMInference._calculate_sparsity(seq_len, chunk_size, mid) 88 | 89 | diff = abs(sparsity - target_sparsity) 90 | if diff < best_diff: 91 | best_diff = diff 92 | best_blocks = mid 93 | 94 | if sparsity < target_sparsity: 95 | right = mid - 1 96 | else: 97 | left = mid + 1 98 | 99 | return best_blocks 100 | 101 | def __call__( 102 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int 103 | ) -> torch.Tensor: 104 | sparsity = torch.tensor(self._calculate_sparsity(queries.shape[2], self.chunk_size, self.top_chunks), device=queries.device) 105 | self.layer_sparsity_statistics.append(sparsity) 106 | return block_sparse_attention(queries, keys, values, self.top_chunks, self.chunk_size, self.chunk_size) 107 | 108 | 109 | class VerticalAndSlashAttentionMInference(AbstractAttention): 110 | """Combines vertical and diagonal patterns for efficient sparse attention. 
111 | 112 | Implements the Vertical and Slash attention mechanism that selects important tokens based on: 113 | 1) Vertical patterns - Top-k tokens that receive high attention across all queries 114 | 2) Diagonal patterns - Diagonal stripes that capture local dependencies 115 | """ 116 | 117 | def __init__( 118 | self, 119 | vertical_size: int = 64, 120 | slash_size: int = 128, 121 | approximation_size: int = 64, 122 | ): 123 | """Initialize Vertical and Slash attention. 124 | 125 | Args: 126 | vertical_size (int): Number of vertical tokens to select 127 | slash_size (int): Number of diagonal stripes to select 128 | """ 129 | super().__init__() 130 | self.vertical_size = vertical_size 131 | self.slash_size = slash_size 132 | self.approximation_size = approximation_size 133 | 134 | def __call__( 135 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int 136 | ) -> torch.Tensor: 137 | attn_output, sparsity = vertical_and_slash_kernel( 138 | queries, 139 | keys, 140 | values, 141 | vertical_size=min(self.vertical_size, queries.shape[2]), 142 | slash_size=min(self.slash_size, queries.shape[2]), 143 | last_q=min(self.approximation_size, queries.shape[2]), 144 | ) 145 | self.layer_sparsity_statistics.append(sparsity) 146 | return attn_output 147 | 148 | 149 | class FlexPrefill(AbstractAttention): 150 | def __init__( 151 | self, 152 | alpha: float = 0.9, 153 | approximation_size: int = 512, 154 | min_budget: int = 256, 155 | ): 156 | super().__init__() 157 | self.alpha = alpha 158 | self.approximation_size = approximation_size 159 | self.min_budget = min_budget 160 | self.causal_mask = None 161 | 162 | @staticmethod 163 | def score_cover_topk(x: torch.Tensor, score: float): 164 | cumsum_x = torch.cumsum(torch.sort(x, dim=-1, descending=True).values, dim=-1) 165 | topk = torch.sum(cumsum_x <= score, dim=-1) + 1 166 | return topk 167 | 168 | def get_active_blocks( 169 | self, q, k, v 170 | ): 171 | _, _, seq_len, head_dim = q.shape 172 | 173 | # Compute attention scores for last queries 174 | last_q_tokens = q[..., -self.approximation_size:, :] / math.sqrt(head_dim) 175 | qk = torch.einsum('bhik,bhjk->bhij', last_q_tokens, k) 176 | 177 | # Apply causal masking 178 | if self.causal_mask is None: 179 | self.causal_mask = torch.arange(0, self.approximation_size, device=last_q_tokens.device) 180 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :] 181 | self.causal_mask = self.causal_mask[None, None, ...] 
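# The mask above is lower-triangular over the trailing `approximation_size` positions:
# each of the last queries may only see keys at or before its own position.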
182 | 183 | qk[..., -self.approximation_size:] = torch.where( 184 | self.causal_mask, 185 | qk[..., -self.approximation_size:], 186 | torch.tensor(float("-inf"), device=qk.device, dtype=qk.dtype) 187 | ) 188 | 189 | # Get attention patterns 190 | scores = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) 191 | 192 | # Compute vertical patterns 193 | vertical = scores.mean(dim=-2) 194 | vertical_size = max(self.min_budget, self.score_cover_topk(vertical, self.alpha).item()) 195 | vertical[..., :4] = float("inf") 196 | vertical_topk = torch.topk(vertical, vertical_size, -1).indices 197 | 198 | # Fixed local window for slash patterns 199 | slashes = sum_over_diagonals(scores)[..., :-self.approximation_size + 1] / self.approximation_size 200 | slash_size = max(self.min_budget, self.score_cover_topk(slashes, self.alpha).item()) 201 | slashes[..., -64:] = float("inf") 202 | slash = (seq_len - 1) - torch.topk(slashes, slash_size, -1).indices 203 | 204 | return vertical_topk, slash 205 | 206 | def __call__( 207 | self, queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, layer_idx: int 208 | ) -> torch.Tensor: 209 | assert queries.shape[-1] == keys.shape[-1] 210 | 211 | # Get active blocks for sparse portion 212 | vertical_idx, slash_idx = self.get_active_blocks(queries, keys, values) 213 | 214 | # Calculate sparse attention for non-dense portion 215 | sparse_out, sparsity = vertical_slash_sparse_attention( 216 | queries, 217 | keys, 218 | values, 219 | vertical_idx, 220 | slash_idx, 221 | ) 222 | 223 | self.layer_sparsity_statistics.append(sparsity) 224 | 225 | return sparse_out 226 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/handler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .abstract_attention import AttentionUtils 3 | 4 | 5 | def update_kv_cache( 6 | key: torch.Tensor, 7 | value: torch.Tensor, 8 | k_cache: torch.Tensor, 9 | v_cache: torch.Tensor, 10 | is_prefilling: bool, 11 | tokens_per_head: torch.Tensor, 12 | q_heads_per_kv: int, 13 | head_indices: torch.Tensor = None, 14 | queries: torch.Tensor = None, 15 | ): 16 | """Update the KV cache with new key/value tensors. 
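During prefill the (optionally compressed) keys and values are written from position zero of each head's cache; during decoding the single new token is appended at the position given by `tokens_per_head`.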
17 | 18 | Args: 19 | key: [num_tokens, num_kv_heads, head_size] 20 | value: [num_tokens, num_kv_heads, head_size] 21 | k_cache: [num_kv_heads, num_blocks, block_size, head_size] 22 | v_cache: [num_kv_heads, num_blocks, block_size, head_size] 23 | is_prefilling: Whether we're in prefilling phase 24 | tokens_per_head: Tensor tracking token counts per head for current layer 25 | q_heads_per_kv: Number of query heads per key/value head 26 | head_indices: Precomputed indices for heads to avoid recreating tensor each time 27 | queries: [num_tokens, num_heads, head_size], required for compression 28 | """ 29 | if is_prefilling: 30 | from sparse_frontier.modelling.attention.registry import get_attention 31 | key, value, seq_lens = get_attention().kv_compress( 32 | queries=queries, 33 | keys=key, 34 | values=value, 35 | ) 36 | 37 | num_kv_heads, num_tokens, _ = key.shape 38 | 39 | # View the cache as a contiguous vector by merging block dimensions 40 | k_cache_flat = k_cache.view(num_kv_heads, -1, k_cache.shape[-1]) 41 | v_cache_flat = v_cache.view(num_kv_heads, -1, v_cache.shape[-1]) 42 | 43 | # Fill the prefix of the flattened cache 44 | k_cache_flat[:, :num_tokens] = key 45 | v_cache_flat[:, :num_tokens] = value 46 | 47 | tokens_per_head += seq_lens.repeat_interleave(q_heads_per_kv) 48 | else: 49 | num_tokens, num_kv_heads, _ = key.shape 50 | 51 | # Decoding case: single token update, use tokens_per_head to place correctly 52 | assert num_tokens == 1, "Decoding should only add one token at a time" 53 | 54 | # View the cache as a contiguous vector by merging block dimensions 55 | k_cache_flat = k_cache.view(num_kv_heads, -1, k_cache.shape[-1]) 56 | v_cache_flat = v_cache.view(num_kv_heads, -1, v_cache.shape[-1]) 57 | 58 | # Update the cache at the specific positions for all heads at once 59 | k_cache_flat[head_indices, tokens_per_head[::q_heads_per_kv]] = key[0] 60 | v_cache_flat[head_indices, tokens_per_head[::q_heads_per_kv]] = value[0] 61 | 62 | # Update token counts 63 | tokens_per_head += 1 64 | 65 | 66 | class AttentionHandler: 67 | def __init__( 68 | self, 69 | tp_size: int, 70 | model_q_heads: int, 71 | model_kv_heads: int, 72 | model_layers: int, 73 | max_input_tokens: int, 74 | max_output_tokens: int, 75 | block_size: int 76 | ): 77 | """Initialize the attention handler. 
78 | 79 | Args: 80 | tp_size: Tensor parallelism size 81 | model_q_heads: Total number of query heads 82 | model_kv_heads: Total number of key/value heads 83 | model_layers: Number of transformer layers 84 | max_input_tokens: Maximum number of input tokens 85 | max_output_tokens: Maximum number of output tokens 86 | block_size: default size of each block in the KV cache 87 | """ 88 | self.tp_size = tp_size 89 | self.model_q_heads = model_q_heads 90 | self.model_kv_heads = model_kv_heads 91 | self.q_heads_per_gpu = model_q_heads // tp_size 92 | self.kv_heads_per_gpu = model_kv_heads // tp_size 93 | self.q_heads_per_kv = model_q_heads // model_kv_heads 94 | self.model_layers = model_layers 95 | self.max_seq_len = max_input_tokens + max_output_tokens 96 | self.max_blocks = (self.max_seq_len + block_size - 1) // block_size 97 | self.block_size = block_size 98 | 99 | self.current_layer = 0 100 | self.tokens_per_layer_head = torch.zeros( 101 | (model_layers, self.q_heads_per_gpu), 102 | dtype=torch.int32, 103 | device="cpu" 104 | ) 105 | self.head_indices = None 106 | 107 | def _compute_attention_head_by_head( 108 | self, 109 | queries: torch.Tensor, 110 | keys: torch.Tensor, 111 | values: torch.Tensor, 112 | output: torch.Tensor, 113 | ) -> torch.Tensor: 114 | """Compute attention output by processing each head separately. 115 | 116 | Args: 117 | queries: [num_tokens, num_heads, head_size] 118 | keys: [num_tokens, num_kv_heads, head_size] 119 | values: [num_tokens, num_kv_heads, head_size] 120 | output: [num_tokens, num_heads, head_size] 121 | """ 122 | from sparse_frontier.modelling.attention.registry import get_attention 123 | 124 | # Get the attention implementation 125 | attention = get_attention() 126 | 127 | efficient_attention_classes = [ 128 | "BlockSparseAttentionMInference", 129 | "VerticalAndSlashAttentionMInference", 130 | "FlexPrefill" 131 | ] 132 | 133 | if attention.__class__.__name__ not in efficient_attention_classes: 134 | output[:] = attention( 135 | queries=queries.transpose(0, 1).unsqueeze(0), 136 | keys=keys.transpose(0, 1).unsqueeze(0), 137 | values=values.transpose(0, 1).unsqueeze(0), 138 | layer_idx=self.current_layer, 139 | ).squeeze(0).transpose(0, 1) 140 | return 141 | 142 | # Otherwise, process head by head 143 | head_size = queries.shape[-1] 144 | q_heads_per_kv = queries.shape[1] // keys.shape[1] 145 | 146 | for head in range(self.q_heads_per_gpu): 147 | kv_head = head // q_heads_per_kv 148 | output[:, head, :] = attention( 149 | queries=queries[:, head, :].view(1, 1, -1, head_size), 150 | keys=keys[:, kv_head, :].view(1, 1, -1, head_size), 151 | values=values[:, kv_head, :].view(1, 1, -1, head_size), 152 | layer_idx=self.current_layer, 153 | ).view(-1, head_size) 154 | 155 | def __call__( 156 | self, 157 | queries: torch.Tensor, 158 | keys: torch.Tensor, 159 | values: torch.Tensor, 160 | kv_cache: torch.Tensor, 161 | output: torch.Tensor, 162 | ) -> torch.Tensor: 163 | """Process attention for either prefill or decode phase. 
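Prefilling (more than one input token) runs the configured sparse attention head by head and compresses keys/values before caching them; decoding (a single token) updates the cache and calls the attention's decode() against the per-head compressed cache.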
164 | 165 | Args: 166 | queries: [num_tokens, num_heads, head_size] 167 | keys: [num_tokens, num_kv_heads, head_size] 168 | values: [num_tokens, num_kv_heads, head_size] 169 | kv_cache: [2, num_blocks, block_size, num_kv_heads, head_size] 170 | output: [num_tokens, num_heads, head_size] 171 | 172 | Returns: 173 | attention_output: [num_tokens, num_heads, head_size] 174 | """ 175 | num_tokens = queries.shape[0] 176 | is_prefilling = num_tokens > 1 177 | 178 | if is_prefilling: 179 | self.tokens_per_layer_head[self.current_layer, :] = 0 180 | 181 | # Initialize head_indices if not already done 182 | if self.head_indices is None and keys.numel() > 0: 183 | num_kv_heads = keys.shape[1] 184 | self.head_indices = torch.arange(num_kv_heads, device=keys.device) 185 | self.tokens_per_layer_head = self.tokens_per_layer_head.to(keys.device) 186 | 187 | if kv_cache.numel() == 0: 188 | self._compute_attention_head_by_head( 189 | queries=queries, 190 | keys=keys, 191 | values=values, 192 | output=output, 193 | ) 194 | self.current_layer = (self.current_layer + 1) % self.model_layers 195 | return 196 | 197 | if is_prefilling: 198 | self._compute_attention_head_by_head( 199 | queries=queries, 200 | keys=keys, 201 | values=values, 202 | output=output, 203 | ) 204 | 205 | k_cache, v_cache = AttentionUtils.reshape_kv_cache(kv_cache, self.block_size, self.max_blocks) 206 | 207 | # Then update cache (with compression if prefilling) 208 | update_kv_cache( 209 | keys, 210 | values, 211 | k_cache, 212 | v_cache, 213 | is_prefilling=is_prefilling, 214 | tokens_per_head=self.tokens_per_layer_head[self.current_layer], 215 | q_heads_per_kv=self.q_heads_per_kv, 216 | head_indices=self.head_indices, 217 | queries=queries if is_prefilling else None, 218 | ) 219 | 220 | if is_prefilling: 221 | # By this point we should have put all the layer sparsity statistics for all methods 222 | from sparse_frontier.modelling.attention.registry import get_attention 223 | get_attention().sync_and_calc_layer_stats() 224 | 225 | if not is_prefilling: 226 | from sparse_frontier.modelling.attention.registry import get_attention 227 | 228 | get_attention().decode( 229 | query=queries, 230 | keys=keys, 231 | values=values, 232 | k_cache=k_cache, 233 | v_cache=v_cache, 234 | cache_seqlens=self.tokens_per_layer_head[self.current_layer], 235 | output=output, 236 | layer_idx=self.current_layer, 237 | ) 238 | 239 | self.current_layer = (self.current_layer + 1) % self.model_layers 240 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/kv_compression.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from .abstract_attention import AbstractAttention 5 | from sparse_frontier.utils.globals import is_vllm_profiling_done 6 | 7 | 8 | class SnapKVCompression(AbstractAttention): 9 | """SnapKV compression for efficient decoding""" 10 | def __init__( 11 | self, 12 | token_capacity: int, 13 | approximation_window: int = 256, 14 | kernel_size: int = 7, 15 | local_window: int = 128, 16 | prefix_tokens: int = 4, 17 | ): 18 | super().__init__() 19 | self.token_capacity = token_capacity 20 | self.approximation_window = approximation_window 21 | self.kernel_size = kernel_size 22 | self.local_window = local_window 23 | self.prefix_tokens = prefix_tokens 24 | self.causal_mask = None 25 | 26 | def kv_compress( 27 | self, 28 | queries: torch.Tensor, # [num_tokens, num_q_heads, head_size] 29 | 
keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 30 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 31 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 32 | """Compress KV cache using SnapKV approach with GQA support""" 33 | # Use last approximation_window queries to estimate importance 34 | approx_queries = queries[-self.approximation_window:] / math.sqrt(queries.size(-1)) # [approx_window, num_q_heads, head_size] 35 | 36 | num_q_heads = queries.size(1) 37 | num_kv_heads = keys.size(1) 38 | group_size = num_q_heads // num_kv_heads 39 | 40 | # Repeat keys for each query head in the group 41 | # [num_tokens, num_kv_heads, head_size] -> [num_tokens, num_q_heads, head_size] 42 | repeated_keys = keys.repeat_interleave(group_size, dim=1) 43 | 44 | # Now we can calculate attention scores with matching head dimensions 45 | scores = torch.einsum( 46 | 'thd,nhd->htn', 47 | approx_queries, 48 | repeated_keys 49 | ) 50 | 51 | # Initialize and cache causal mask if not already created 52 | if self.causal_mask is None: 53 | self.causal_mask = torch.arange(0, self.approximation_window, device=scores.device) 54 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :] 55 | self.causal_mask = self.causal_mask[None, ...] # Add head dimension 56 | 57 | # Apply causal masking and softmax 58 | scores[..., -self.approximation_window:] = torch.where( 59 | self.causal_mask, 60 | scores[..., -self.approximation_window:], 61 | torch.tensor(float("-inf"), device=scores.device, dtype=scores.dtype) 62 | ) 63 | 64 | attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(keys.dtype) 65 | 66 | # Reshape attention weights to group query heads [num_kv_heads, group_size, approx_window, num_tokens] 67 | grouped_weights = attn_weights.view(num_kv_heads, group_size, -1, attn_weights.size(-1)) 68 | # Average across group dimension [num_kv_heads, approx_window, num_tokens] 69 | token_importance = grouped_weights.mean(dim=1).sum(dim=1) # [num_kv_heads, num_tokens] 70 | 71 | # Apply pooling for smoother selection per head 72 | token_importance = F.avg_pool1d( 73 | token_importance.unsqueeze(1), 74 | kernel_size=self.kernel_size, 75 | padding=self.kernel_size // 2, 76 | stride=1 77 | ).squeeze(1) 78 | 79 | token_importance[..., -self.local_window:] = float('inf') 80 | token_importance[..., :self.prefix_tokens] = float('inf') 81 | 82 | # Select top-k tokens per head 83 | capacity = min(self.token_capacity, keys.size(0)) 84 | 85 | assert capacity > self.local_window + self.prefix_tokens, f"Capacity {capacity} must be greater than local_window {self.local_window} + prefix_tokens {self.prefix_tokens}" 86 | 87 | _, indices = torch.topk(token_importance, k=capacity, dim=-1) # [num_kv_heads, capacity] 88 | 89 | # Expand indices for gathering 90 | # [num_kv_heads, capacity] -> [num_kv_heads, capacity, head_size] 91 | expanded_indices = indices.unsqueeze(-1).expand(-1, -1, keys.size(-1)) 92 | 93 | compressed_keys = torch.gather( 94 | keys.transpose(0, 1), # [num_kv_heads, num_tokens, head_size] 95 | dim=1, 96 | index=expanded_indices 97 | ) 98 | 99 | compressed_values = torch.gather( 100 | values.transpose(0, 1), # [num_kv_heads, num_tokens, head_size] 101 | dim=1, 102 | index=expanded_indices 103 | ) 104 | 105 | # Track sparsity - based on fixed capacity ratio 106 | sparsity = 1.0 - (capacity / keys.size(0)) 107 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device)) 108 | 109 | # Create sequence length tensor (same for all heads) 110 | 
seq_lens = torch.full((num_kv_heads,), capacity, device=queries.device, dtype=torch.long) 111 | 112 | return compressed_keys, compressed_values, seq_lens 113 | 114 | 115 | class AdaSnapKVCompression(AbstractAttention): 116 | """Adaptive SnapKV compression with non-uniform token distribution across heads""" 117 | def __init__( 118 | self, 119 | token_capacity: int, 120 | approximation_window: int = 256, 121 | kernel_size: int = 7, 122 | local_window: int = 128, 123 | prefix_tokens: int = 4, 124 | min_head_capacity_ratio: float = 0.2, 125 | ): 126 | super().__init__() 127 | self.token_capacity = token_capacity 128 | self.approximation_window = approximation_window 129 | self.kernel_size = kernel_size 130 | self.local_window = local_window 131 | self.prefix_tokens = prefix_tokens 132 | self.min_head_capacity_ratio = min_head_capacity_ratio 133 | self.causal_mask = None 134 | 135 | def kv_compress( 136 | self, 137 | queries: torch.Tensor, # [num_tokens, num_q_heads, head_size] 138 | keys: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 139 | values: torch.Tensor, # [num_tokens, num_kv_heads, head_size] 140 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 141 | """Compress KV cache using AdaSnapKV approach with adaptive token distribution""" 142 | assert self.approximation_window < keys.size(0) 143 | 144 | # Use last approximation_window queries to estimate importance 145 | approx_queries = queries[-self.approximation_window:] / math.sqrt(queries.size(-1)) # [approx_window, num_q_heads, head_size] 146 | 147 | num_q_heads = queries.size(1) 148 | num_kv_heads = keys.size(1) 149 | group_size = num_q_heads // num_kv_heads 150 | 151 | # Repeat keys for each query head in the group 152 | # [num_tokens, num_kv_heads, head_size] -> [num_tokens, num_q_heads, head_size] 153 | repeated_keys = keys.repeat_interleave(group_size, dim=1) 154 | 155 | # Now we can calculate attention scores with matching head dimensions 156 | scores = torch.einsum( 157 | 'thd,nhd->htn', 158 | approx_queries, 159 | repeated_keys 160 | ) 161 | 162 | # Initialize and cache causal mask if not already created 163 | if self.causal_mask is None: 164 | self.causal_mask = torch.arange(0, self.approximation_window, device=scores.device) 165 | self.causal_mask = self.causal_mask[:, None] >= self.causal_mask[None, :] 166 | self.causal_mask = self.causal_mask[None, ...] 
# Add head dimension 167 | 168 | # Apply causal masking and softmax 169 | scores[..., -self.approximation_window:] = torch.where( 170 | self.causal_mask, 171 | scores[..., -self.approximation_window:], 172 | torch.tensor(float("-inf"), device=scores.device, dtype=scores.dtype) 173 | ) 174 | 175 | attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(keys.dtype) 176 | 177 | grouped_weights = attn_weights.view(num_kv_heads, group_size, -1, attn_weights.size(-1)) 178 | 179 | token_importance = grouped_weights.max(dim=2)[0].max(dim=1)[0] 180 | 181 | # Apply pooling for smoother selection per head 182 | token_importance = F.avg_pool1d( 183 | token_importance.unsqueeze(1), 184 | kernel_size=self.kernel_size, 185 | padding=self.kernel_size // 2, 186 | stride=1 187 | ).squeeze(1) 188 | 189 | # Always keep recent tokens and prefix tokens 190 | token_importance[..., -self.local_window:] = float('inf') 191 | token_importance[..., :self.prefix_tokens] = float('inf') 192 | 193 | # Calculate total capacity and minimum capacity per head 194 | total_capacity = self.token_capacity * num_kv_heads 195 | assert total_capacity <= keys.size(0) * num_kv_heads 196 | min_capacity_per_head = int(self.token_capacity * self.min_head_capacity_ratio) 197 | min_capacity_per_head = max(min_capacity_per_head, self.local_window + self.prefix_tokens) 198 | assert self.token_capacity >= self.local_window + self.prefix_tokens 199 | assert min_capacity_per_head <= keys.size(0) 200 | assert total_capacity >= min_capacity_per_head * num_kv_heads 201 | remaining_capacity = total_capacity - (min_capacity_per_head * num_kv_heads) 202 | 203 | # Get the top-k tokens for minimum capacity per head 204 | _, min_indices = torch.topk(token_importance, k=min_capacity_per_head, dim=-1) # [num_kv_heads, min_capacity] 205 | 206 | # Vectorized mask creation using scatter 207 | selected_mask = torch.zeros_like(token_importance, dtype=torch.bool) 208 | selected_mask.scatter_( 209 | dim=1, 210 | index=min_indices, 211 | src=torch.ones_like(min_indices, dtype=torch.bool) 212 | ) 213 | 214 | # Vectorized masking 215 | masked_importance = token_importance.masked_fill(selected_mask, float('-inf')) 216 | flat_importance = masked_importance.view(-1) 217 | 218 | # Global selection with vectorized index conversion 219 | _, flat_indices = torch.topk(flat_importance, k=remaining_capacity, dim=-1) 220 | 221 | # Flatten and update selected_mask 222 | flat_selected_mask = selected_mask.view(-1) 223 | flat_selected_mask.scatter_(0, flat_indices, True) 224 | selected_mask = flat_selected_mask.view(num_kv_heads, -1) 225 | 226 | seq_lens = selected_mask.sum(dim=1) 227 | max_seq_len = seq_lens.max().item() 228 | compressed_keys = torch.zeros(num_kv_heads, max_seq_len, keys.size(-1), device=keys.device, dtype=keys.dtype) 229 | compressed_values = torch.zeros(num_kv_heads, max_seq_len, values.size(-1), device=values.device, dtype=values.dtype) 230 | 231 | keys_t = keys.transpose(0, 1) 232 | values_t = values.transpose(0, 1) 233 | 234 | for head_idx in range(num_kv_heads): 235 | compressed_keys[head_idx, :seq_lens[head_idx]] = keys_t[head_idx, selected_mask[head_idx]] 236 | compressed_values[head_idx, :seq_lens[head_idx]] = values_t[head_idx, selected_mask[head_idx]] 237 | 238 | # Track sparsity - based on actual tokens kept 239 | total_tokens_kept = seq_lens.sum().item() 240 | sparsity = 1.0 - (total_tokens_kept / (keys.size(0) * num_kv_heads)) 241 | self.layer_sparsity_statistics.append(torch.tensor(sparsity, device=queries.device)) 242 | 243 | 
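# Note: unlike SnapKVCompression above, each KV head may keep a different number of tokens here.
# The outputs are therefore padded to the longest per-head selection, and `seq_lens` records how
# many entries are valid for each head, which the attention handler uses as the per-head cache length.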
return compressed_keys, compressed_values, seq_lens 244 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/minference/__init__.py: -------------------------------------------------------------------------------- 1 | from .block import block_sparse_attention 2 | from .vertical_and_slash import vertical_and_slash_kernel, vertical_slash_sparse_attention, sum_over_diagonals 3 | 4 | 5 | __all__ = [ 6 | "block_sparse_attention", 7 | "vertical_and_slash_kernel", 8 | "vertical_slash_sparse_attention", 9 | "sum_over_diagonals", 10 | ] 11 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/minference/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | # Code adopted from https://github.com/microsoft/MInference 4 | 5 | import torch 6 | import triton 7 | import triton.language as tl 8 | 9 | 10 | @triton.jit 11 | def _triton_block_sparse_attn_fwd_kernel( 12 | Q, K, V, seqlens, sm_scale, 13 | block_index, 14 | Out, 15 | stride_qz, stride_qh, stride_qm, stride_qk, 16 | stride_kz, stride_kh, stride_kn, stride_kk, 17 | stride_vz, stride_vh, stride_vn, stride_vk, 18 | stride_oz, stride_oh, stride_om, stride_ok, 19 | Z, H, N_CTX, 20 | NUM_ROWS, MAX_BLOCKS_PRE_ROW, 21 | BLOCK_M: tl.constexpr, 22 | BLOCK_N: tl.constexpr, 23 | BLOCK_DMODEL: tl.constexpr, 24 | dtype: tl.constexpr, 25 | ): 26 | start_m = tl.program_id(0) 27 | off_hz = tl.program_id(1) 28 | 29 | seqlen = tl.load(seqlens + off_hz // H) 30 | if start_m * BLOCK_M >= seqlen: 31 | return 32 | 33 | # initialize offsets 34 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 35 | offs_n = tl.arange(0, BLOCK_N) 36 | offs_d = tl.arange(0, BLOCK_DMODEL) 37 | 38 | qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh 39 | kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh 40 | 41 | q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk 42 | k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk 43 | v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk 44 | o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok 45 | 46 | blocks_ptr = block_index + (off_hz * NUM_ROWS + start_m) * MAX_BLOCKS_PRE_ROW 47 | 48 | # initialize pointer to m and l 49 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 50 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 51 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 52 | # scale sm_scale by log_2(e) and use 53 | # 2^x instead of exp in the loop because CSE and LICM 54 | # don't work as expected with `exp` in the loop 55 | qk_scale = sm_scale * 1.44269504 56 | # load q: it will stay in SRAM throughout 57 | q = tl.load(q_ptrs) 58 | q = (q * qk_scale).to(dtype) 59 | 60 | # loop over k, v and update accumulator 61 | m_mask = offs_m[:, None] < seqlen 62 | block_count = tl.minimum((start_m + 1) * BLOCK_M // BLOCK_N, MAX_BLOCKS_PRE_ROW) 63 | 64 | for sparse_block_idx in range(block_count): 65 | real_block_idx = tl.load(blocks_ptr + sparse_block_idx) 66 | start_n = real_block_idx * BLOCK_N 67 | cols = start_n + offs_n 68 | # -- load k, v -- 69 | k = tl.load(k_ptrs + cols[None, :] * stride_kn) 70 | v = tl.load(v_ptrs + cols[:, None] * stride_vn) 71 | # -- compute qk -- 72 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 73 | # if start_n + BLOCK_N < seqlen: 74 | # 
qk = tl.where(m_mask, qk, float("-inf")) 75 | # else: 76 | causal_mask = cols[None, :] <= offs_m[:, None] 77 | qk = tl.where(m_mask & causal_mask, qk, float("-inf")) 78 | qk += tl.dot(q, k) 79 | # -- compute scaling constant -- 80 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 81 | alpha = tl.math.exp2(m_i - m_i_new) 82 | p = tl.math.exp2(qk - m_i_new[:, None]) 83 | # -- scale and update acc -- 84 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 85 | acc *= acc_scale[:, None] 86 | acc += tl.dot(p.to(dtype), v) 87 | # -- update m_i and l_i -- 88 | l_i = l_i * alpha + tl.sum(p, 1) 89 | m_i = m_i_new 90 | 91 | # write back O 92 | acc /= l_i[:, None] 93 | tl.store(o_ptrs, acc.to(dtype), mask=m_mask) 94 | 95 | 96 | def _triton_block_sparse_attention( 97 | q, # [BATCH, N_HEADS, N_CTX, D_HEAD] 98 | k, # [BATCH, N_HEADS, N_CTX, D_HEAD] 99 | v, # [BATCH, N_HEADS, N_CTX, D_HEAD] 100 | seqlens, # [BATCH, ] 101 | block_index, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), MAX_BLOCKS_PRE_ROW] 102 | sm_scale, 103 | block_size_M=64, 104 | block_size_N=64, 105 | ) -> torch.Tensor: 106 | # shape constraints 107 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 108 | assert Lq == Lk and Lk == Lv 109 | assert Lk in {16, 32, 64, 128} 110 | o = torch.zeros_like(q) 111 | grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1) 112 | dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16 113 | _triton_block_sparse_attn_fwd_kernel[grid]( 114 | q, k, v, seqlens, sm_scale, 115 | block_index, 116 | o, 117 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 118 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 119 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 120 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 121 | q.shape[0], q.shape[1], q.shape[2], 122 | block_index.shape[-2], block_index.shape[-1], 123 | BLOCK_M=block_size_M, BLOCK_N=block_size_N, 124 | BLOCK_DMODEL=Lk, 125 | dtype=dtype, 126 | num_warps=4, num_stages=2, 127 | ) 128 | 129 | return o 130 | 131 | 132 | def _build_block_index( 133 | query: torch.Tensor, 134 | key: torch.Tensor, 135 | top_k: int, 136 | block_size_M: int = 64, 137 | block_size_N: int = 64, 138 | ): 139 | """Build block index for sparse attention. 140 | 141 | This function enforces a fixed pattern for each query block: 142 | - The first block (attention sink) 143 | - The main diagonal block 144 | - The block one diagonal below the main diagonal 145 | 146 | These three blocks are forced by setting their scores to infinity before the topk operation. 147 | The remaining blocks are selected based on the attention scores between query and key blocks. 
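Scores are computed between mean-pooled query and key blocks, so block selection costs O((N / block_size)^2) rather than O(N^2) for the full attention map.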
148 | 149 | Args: 150 | query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim] 151 | key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim] 152 | top_k: Number of blocks to select per query block 153 | block_size_M: Query block size 154 | block_size_N: Key block size 155 | 156 | Returns: 157 | Block indices tensor of shape [batch_size, num_heads, num_query_blocks, top_k] 158 | """ 159 | 160 | batch_size, num_heads, context_size, head_dim = query.shape 161 | query_pool = query.reshape((batch_size, num_heads, -1, block_size_M, head_dim)).mean(dim=-2) 162 | key_pool = key.reshape((batch_size, num_heads, -1, block_size_N, head_dim)).mean(dim=-2) 163 | arange_M = torch.arange(query_pool.shape[-2], dtype=torch.int32, device=query.device) * block_size_M 164 | arange_N = torch.arange(key_pool.shape[-2], dtype=torch.int32, device=key.device) * block_size_N 165 | p_pool = torch.einsum(f'bhmk, bhnk -> bhmn', query_pool, key_pool) 166 | 167 | # Replace the direct assignment with diagonal_scatter 168 | diag_size = min(p_pool.shape[-2], p_pool.shape[-1]) 169 | inf_diag = torch.full((batch_size, num_heads, diag_size), float('inf'), device=p_pool.device) 170 | 171 | # Set main diagonal to inf 172 | p_pool = torch.diagonal_scatter(p_pool, inf_diag, dim1=-2, dim2=-1) 173 | 174 | # Set one diagonal below main to inf 175 | p_pool = torch.diagonal_scatter(p_pool, inf_diag[..., :-1], offset=-1, dim1=-2, dim2=-1) 176 | 177 | # Keep the causal mask 178 | p_pool = p_pool.where(arange_M[None, None, :, None] >= arange_N[None, None, None, :], -torch.inf) 179 | 180 | # Set the first block to inf - always pick attention sinks 181 | p_pool[..., 0] = torch.inf 182 | 183 | top_k = min(top_k, context_size // block_size_N) 184 | # assert top_k <= context_size // block_size_N, f"top_k ({top_k}) must be <= context_size/block_size_N ({context_size // block_size_N})" 185 | 186 | return torch.topk(p_pool, top_k, dim=-1).indices.to(torch.int32).sort(dim=-1).values 187 | 188 | 189 | def block_sparse_attention( 190 | query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 191 | key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 192 | value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 193 | top_k: int, 194 | block_size_M: int = 64, 195 | block_size_N: int = 64, 196 | ): 197 | _, _, context_size, head_dim = query.shape 198 | pad = block_size_M - (query.shape[2] & (block_size_M - 1)) 199 | query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) 200 | key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) 201 | value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) 202 | seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device) 203 | sm_scale = head_dim ** -0.5 204 | block_index = _build_block_index(query, key, top_k, block_size_N, block_size_N) 205 | out = _triton_block_sparse_attention(query, key, value, seqlens, block_index, sm_scale, block_size_M, block_size_N) 206 | return out[..., :context_size, :] 207 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/minference/csrc/kernels.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 
3 | // Code adopted from https://github.com/microsoft/MInference 4 | 5 | #include <vector> 6 | #include "torch/extension.h" 7 | 8 | std::vector<at::Tensor> convert_vertical_slash_indexes( 9 | torch::Tensor seqlens, // [BATCH, ] 10 | torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] 11 | torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] 12 | int context_size, 13 | int block_size_M, 14 | int block_size_N 15 | ); 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("convert_vertical_slash_indexes", &convert_vertical_slash_indexes, "dynamic sparse index function"); 19 | } 20 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/minference/csrc/vertical_slash_index.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT license. 3 | // Code adopted from https://github.com/microsoft/MInference 4 | 5 | #include <assert.h> 6 | 7 | #include <cuda.h> 8 | #include <cuda_bf16.h> 9 | #include <cuda_fp16.h> 10 | #include <cuda_runtime.h> 11 | 12 | #include <torch/extension.h> 13 | 14 | // __device__ int min(int x, int y) { 15 | // return x < y ? x : y; 16 | // } 17 | 18 | // __device__ int max(int x, int y) { 19 | // return x > y ? x : y; 20 | // } 21 | 22 | __device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { 23 | for (int idx = range_start; idx < range_end; idx += block_size) { 24 | block_offset[block_count++] = idx; 25 | } 26 | } 27 | 28 | __global__ void convert_vertical_slash_indexes_kernel( 29 | const int* seqlens, // [BATCH, ] 30 | const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] 31 | const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] 32 | int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 33 | int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] 34 | int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 35 | int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] 36 | int N_HEADS, 37 | int N_ROWS, 38 | int BLOCK_SIZE_M, 39 | int BLOCK_SIZE_N, 40 | int NNZ_V, 41 | int NNZ_S 42 | ) { 43 | const int batch_idx = blockIdx.y; 44 | const int head_idx = blockIdx.x; 45 | const int group_idx = blockIdx.z; 46 | 47 | int seqlen = seqlens[batch_idx]; 48 | int block_idx_m = group_idx * blockDim.x + threadIdx.x; 49 | int start_m = block_idx_m * BLOCK_SIZE_M; 50 | if (start_m >= seqlen) { 51 | return; 52 | } 53 | int end_m = start_m + BLOCK_SIZE_M; 54 | vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; 55 | slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; 56 | int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; 57 | block_count += row_offset; 58 | block_offset += row_offset * NNZ_S; 59 | column_count += row_offset; 60 | column_index += row_offset * NNZ_V; 61 | 62 | int tmp_col_cnt = 0, tmp_blk_cnt = 0; 63 | int s = 0, v = 0; 64 | int v_idx = vertical_indexes[v++]; 65 | int s_idx = slash_indexes[s++]; 66 | while (s_idx >= end_m) { 67 | s_idx = slash_indexes[s++]; 68 | } 69 | s_idx = max(end_m - s_idx, BLOCK_SIZE_M); 70 | int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; 71 | while (1) { 72 | if (v_idx < range_end) { 73 | if (v_idx < range_start) { 74 | column_index[tmp_col_cnt++] = v_idx; 75 | } 76 | if (v < NNZ_V) { 77 | v_idx = vertical_indexes[v++]; 78 | } else { 79 | v_idx = end_m + BLOCK_SIZE_M; 80 | } 81 | } else { 82 | if (s < NNZ_S) { 83 | s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); 84 | } else { 85 |
save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); 86 | break; 87 | } 88 | if (s_idx > range_end + BLOCK_SIZE_M) { 89 | save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); 90 | range_start = s_idx - BLOCK_SIZE_M; 91 | range_end = s_idx; 92 | } else if (s_idx > range_end) { 93 | range_end += BLOCK_SIZE_M; 94 | } 95 | } 96 | } 97 | 98 | block_count[0] = tmp_blk_cnt; 99 | column_count[0] = tmp_col_cnt; 100 | } 101 | 102 | void convert_vertical_slash_indexes_64x64( 103 | const int* seqlens, // [BATCH, ] 104 | const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] 105 | const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] 106 | int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 107 | int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] 108 | int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 109 | int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] 110 | int BATCH_SIZE, 111 | int N_HEADS, 112 | int N_ROWS, 113 | int NNZ_V, 114 | int NNZ_S 115 | ) { 116 | const int BLOCK_SIZE_M = 64; 117 | const int BLOCK_SIZE_N = 64; 118 | const int N_THREADS = 64; 119 | const dim3 dimBlock(N_THREADS); 120 | const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); 121 | convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>( 122 | seqlens, vertical_indexes, slash_indexes, 123 | block_count, block_offset, column_count, column_index, 124 | N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S 125 | ); 126 | } 127 | 128 | std::vector<at::Tensor> convert_vertical_slash_indexes( 129 | torch::Tensor seqlens, // [BATCH, ] 130 | torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] 131 | torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] 132 | int context_size, 133 | int block_size_M, 134 | int block_size_N 135 | ) { 136 | assert(block_size_M == 64); 137 | assert(block_size_N == 64); 138 | 139 | cudaSetDevice(seqlens.get_device()); 140 | 141 | int batch_size = slash_indexes.size(0); 142 | int num_heads = slash_indexes.size(1); 143 | int nnz_slash = slash_indexes.size(2); 144 | int nnz_vertical = vertical_indexes.size(2); 145 | int num_rows = (context_size + block_size_M - 1) / block_size_M; 146 | 147 | torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); 148 | torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); 149 | torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); 150 | torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); 151 | 152 | convert_vertical_slash_indexes_64x64( 153 | seqlens.data_ptr<int>(), 154 | vertical_indexes.data_ptr<int>(), 155 | slash_indexes.data_ptr<int>(), 156 | block_count.data_ptr<int>(), 157 | block_offset.data_ptr<int>(), 158 | column_count.data_ptr<int>(), 159 | column_index.data_ptr<int>(), 160 | batch_size, 161 | num_heads, 162 | num_rows, 163 | nnz_vertical, 164 | nnz_slash 165 | ); 166 | 167 | return { block_count, block_offset, column_count, column_index }; 168 | } 169 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/minference/vertical_and_slash.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | # Code adopted from https://github.com/microsoft/MInference 4 | 5 | import math 6 | import
torch 7 | import triton 8 | import triton.language as tl 9 | 10 | from sparse_frontier.modelling.attention.minference.minference import convert_vertical_slash_indexes 11 | 12 | 13 | @triton.jit 14 | def _triton_mixed_sparse_attn_fwd_kernel( 15 | Q, K, V, seqlens, sm_scale, 16 | block_count, block_offset, column_count, column_index, 17 | Out, 18 | stride_qz, stride_qh, stride_qm, stride_qk, 19 | stride_kz, stride_kh, stride_kn, stride_kk, 20 | stride_vz, stride_vh, stride_vn, stride_vk, 21 | stride_oz, stride_oh, stride_om, stride_ok, 22 | Z, H, N_CTX, 23 | NUM_ROWS, NNZ_S, NNZ_V, 24 | BLOCK_M: tl.constexpr, 25 | BLOCK_N: tl.constexpr, 26 | BLOCK_DMODEL: tl.constexpr, 27 | dtype: tl.constexpr, 28 | ): 29 | start_m = tl.program_id(0) 30 | off_hz = tl.program_id(1) 31 | 32 | seqlen = tl.load(seqlens + off_hz // H) 33 | if start_m * BLOCK_M >= seqlen: 34 | return 35 | 36 | # initialize offsets 37 | offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) 38 | offs_n = tl.arange(0, BLOCK_N) 39 | offs_d = tl.arange(0, BLOCK_DMODEL) 40 | 41 | qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh 42 | kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh 43 | 44 | q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk 45 | k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk 46 | v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk 47 | o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok 48 | 49 | num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m) 50 | blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S 51 | num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m) 52 | cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V 53 | 54 | # initialize pointer to m and l 55 | m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") 56 | l_i = tl.zeros([BLOCK_M], dtype=tl.float32) 57 | acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) 58 | # scale sm_scale by log_2(e) and use 59 | # 2^x instead of exp in the loop because CSE and LICM 60 | # don't work as expected with `exp` in the loop 61 | qk_scale = sm_scale * 1.44269504 62 | # load q: it will stay in SRAM throughout 63 | q = tl.load(q_ptrs) 64 | q = (q * qk_scale).to(dtype) 65 | 66 | # loop over k, v and update accumulator 67 | m_mask = offs_m[:, None] < seqlen 68 | 69 | for block_index in range(num_blks): 70 | start_n = tl.load(blks_ptr + block_index) 71 | cols = start_n + offs_n 72 | n_mask = cols < seqlen 73 | # -- load k, v -- 74 | k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) 75 | v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) 76 | # -- compute qk -- 77 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 78 | causal_mask = cols[None, :] <= offs_m[:, None] 79 | qk = tl.where(m_mask & causal_mask, qk, float("-inf")) 80 | qk += tl.dot(q, k) 81 | # -- compute scaling constant -- 82 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 83 | alpha = tl.math.exp2(m_i - m_i_new) 84 | p = tl.math.exp2(qk - m_i_new[:, None]) 85 | # -- scale and update acc -- 86 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 87 | acc *= acc_scale[:, None] 88 | acc += tl.dot(p.to(dtype), v) 89 | # -- update m_i and l_i -- 90 | l_i = l_i * alpha + tl.sum(p, 1) 91 | m_i = m_i_new 92 | 93 | for start_n in range(0, num_cols, BLOCK_N): 94 | n_mask = start_n + offs_n < num_cols 95 | cols = tl.load(cols_ptr + start_n + offs_n, mask=n_mask, other=0) 96 | # -- load k, v -- 97 | k = 
tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0) 98 | v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0) 99 | # -- compute qk -- 100 | qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) 101 | qk = tl.where(m_mask & n_mask, qk, float("-inf")) 102 | qk += tl.dot(q, k) 103 | # -- compute scaling constant -- 104 | m_i_new = tl.maximum(m_i, tl.max(qk, 1)) 105 | alpha = tl.math.exp2(m_i - m_i_new) 106 | p = tl.math.exp2(qk - m_i_new[:, None]) 107 | # -- scale and update acc -- 108 | acc_scale = l_i * 0 + alpha # workaround some compiler bug 109 | acc *= acc_scale[:, None] 110 | acc += tl.dot(p.to(dtype), v) 111 | # -- update m_i and l_i -- 112 | l_i = l_i * alpha + tl.sum(p, 1) 113 | m_i = m_i_new 114 | 115 | # write back O 116 | acc /= l_i[:, None] 117 | # acc = tl.where(m_mask, acc / l_i[:, None], 0.0) 118 | tl.store(o_ptrs, acc.to(dtype), mask=m_mask) 119 | 120 | 121 | def _triton_mixed_sparse_attention( 122 | q: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 123 | k: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 124 | v: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 125 | seqlens: torch.Tensor, # [BATCH, ] 126 | block_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 127 | block_offset: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] 128 | column_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] 129 | column_index: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] 130 | sm_scale: float, 131 | block_size_M: int = 64, 132 | block_size_N: int = 64, 133 | ) -> torch.Tensor: 134 | # shape constraints 135 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] 136 | assert Lq == Lk and Lk == Lv 137 | assert Lk in {16, 32, 64, 128} 138 | o = torch.zeros_like(q) 139 | grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1) 140 | dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16 141 | _triton_mixed_sparse_attn_fwd_kernel[grid]( 142 | q, k, v, seqlens, sm_scale, 143 | block_count, block_offset, column_count, column_index, 144 | o, 145 | q.stride(0), q.stride(1), q.stride(2), q.stride(3), 146 | k.stride(0), k.stride(1), k.stride(2), k.stride(3), 147 | v.stride(0), v.stride(1), v.stride(2), v.stride(3), 148 | o.stride(0), o.stride(1), o.stride(2), o.stride(3), 149 | q.shape[0], q.shape[1], q.shape[2], 150 | block_count.shape[-1], block_offset.shape[-1], column_index.shape[-1], 151 | BLOCK_M=block_size_M, BLOCK_N=block_size_N, 152 | BLOCK_DMODEL=Lk, 153 | dtype=dtype, 154 | num_warps=4, num_stages=1, 155 | ) 156 | 157 | return o 158 | 159 | 160 | def vertical_slash_sparse_attention( 161 | query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 162 | key: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 163 | value: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] 164 | v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] 165 | s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] 166 | block_size_M: int = 64, 167 | block_size_N: int = 64, 168 | ): 169 | batch_size, num_heads, context_size, head_dim = query.shape 170 | pad = block_size_M - (context_size & (block_size_M - 1)) 171 | query = torch.nn.functional.pad(query, [0, 0, 0, pad, 0, 0, 0, 0]) 172 | key = torch.nn.functional.pad(key, [0, 0, 0, pad, 0, 0, 0, 0]) 173 | value = torch.nn.functional.pad(value, [0, 0, 0, pad, 0, 0, 0, 0]) 174 | 175 | if head_dim not in [16, 32, 64, 128, 256, 512]: 176 | target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim 177 | query = torch.nn.functional.pad(query, [0, 
target_dim, 0, 0, 0, 0, 0, 0]) 178 | key = torch.nn.functional.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) 179 | value = torch.nn.functional.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) 180 | 181 | v_idx = v_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] 182 | s_idx = s_idx.to(torch.int32).reshape((batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] 183 | seqlens = torch.tensor([context_size], dtype=torch.int32, device=query.device) 184 | sm_scale = head_dim ** -0.5 185 | block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes( 186 | seqlens, v_idx, s_idx, context_size, block_size_M, block_size_N, 187 | ) 188 | sparsity = calc_sparsity(block_count, column_count, context_size) 189 | out = _triton_mixed_sparse_attention( 190 | query, key, value, seqlens, 191 | block_count, block_offset, column_count, column_index, 192 | sm_scale, block_size_M, block_size_N, 193 | ) 194 | return out[..., :context_size, :head_dim], sparsity 195 | 196 | 197 | def calc_sparsity(block_count, column_count, seq_len): 198 | block_cells = block_count.sum(dim=-1) * 64 * 64 199 | column_cells = column_count.sum(dim=-1) * 64 200 | total_cells = seq_len * (seq_len + 1) // 2 201 | return 1 - (block_cells + column_cells) / total_cells 202 | 203 | 204 | def sum_over_diagonals(matrix: torch.Tensor) -> torch.Tensor: 205 | """Efficiently sum values along diagonals of the attention matrix. 206 | 207 | This function computes the sum of values along each diagonal of a 4D attention matrix. 208 | It uses an efficient strided implementation to avoid explicit diagonal extraction. 209 | 210 | Args: 211 | matrix: Input attention matrix of shape (batch_size, num_heads, queries, keys) 212 | where queries and keys are sequence lengths 213 | 214 | Returns: 215 | Tensor of shape (batch_size, num_heads, queries + keys - 1) containing the 216 | summed values for each diagonal. The diagonals are ordered from top-right 217 | to bottom-left, with the main diagonal at index queries-1. 218 | """ 219 | batch_size, num_heads, queries, keys = matrix.shape 220 | zero_matrix = torch.zeros((batch_size, num_heads, queries, queries), device=matrix.device) 221 | matrix_padded = torch.cat((zero_matrix, matrix, zero_matrix), -1) 222 | 223 | matrix_strided = matrix_padded.as_strided( 224 | (batch_size, num_heads, queries, queries + keys), 225 | (num_heads * queries * (2 * queries + keys), 226 | queries * (2 * queries + keys), 227 | 2 * queries + keys + 1, 1) 228 | ) 229 | return torch.sum(matrix_strided, 2)[:, :, 1:] 230 | 231 | 232 | def vertical_and_slash_kernel( 233 | q: torch.Tensor, 234 | k: torch.Tensor, 235 | v: torch.Tensor, 236 | vertical_size: int, 237 | slash_size: int, 238 | last_q: int = 64, 239 | inf_value: float = float('inf'), 240 | topk_vertical_inf: int = 4, 241 | topk_slash_inf: int = 64, 242 | ): 243 | """ 244 | Compute the vertical and slash kernel for sparse attention. 
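Attention is approximated from the last `last_q` queries; the highest-scoring `vertical_size` columns and `slash_size` diagonals are then selected and passed to vertical_slash_sparse_attention.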
245 | 246 | Args: 247 | q: Query tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD] 248 | k: Key tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD] 249 | v: Value tensor of shape [BATCH, N_HEADS, N_CTX, D_HEAD] 250 | vertical_size: Size of the vertical attention 251 | slash_size: Size of the slash attention 252 | last_q: Number of last queries to consider (default: 64) 253 | inf_value: Value to use for infinity (default: float('inf')) 254 | topk_vertical_inf: Number of top-k vertical elements to set to infinity (default: 30) 255 | topk_slash_inf: Number of top-k slash elements to set to infinity (default: 100) 256 | 257 | Returns: 258 | Output tensor after applying vertical and slash sparse attention. 259 | """ 260 | arange = torch.arange(last_q, device=q.device) 261 | LAST_Q_MASK = arange[None, None, :, None] >= arange[None, None, None, :] 262 | 263 | _, _, q_len, d = q.shape 264 | 265 | # Compute scaled dot-product attention 266 | qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) / math.sqrt(d) 267 | qk[:, :, :, -last_q:] = torch.where(LAST_Q_MASK[..., -last_q:, -last_q:], qk[:, :, :, -last_q:], -inf_value) 268 | qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype) 269 | 270 | assert topk_vertical_inf <= vertical_size 271 | assert topk_slash_inf <= slash_size 272 | 273 | # Compute top verticals 274 | vertical = qk.sum(-2, keepdim=True) 275 | vertical[..., :topk_vertical_inf] = inf_value 276 | vertical_topk = torch.topk(vertical, vertical_size, -1).indices 277 | 278 | # # Compute top slashes 279 | slash = sum_over_diagonals(qk)[..., :-last_q + 1] 280 | slash[..., -topk_slash_inf:] = inf_value 281 | slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices 282 | 283 | return vertical_slash_sparse_attention(q, k, v, vertical_topk, slash) 284 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/attention/registry.py: -------------------------------------------------------------------------------- 1 | from .efficient_prefilling import ( 2 | DenseAttention, 3 | VerticalAndSlashAttentionMInference, 4 | BlockSparseAttentionMInference, 5 | FlexPrefill, 6 | ) 7 | from .efficient_decoding import QuestAttention 8 | from .kv_compression import SnapKVCompression, AdaSnapKVCompression 9 | from sparse_frontier.utils import GlobalSettings 10 | from .handler import AttentionHandler 11 | 12 | 13 | ATTENTION_REGISTRY = { 14 | 'dense': DenseAttention, 15 | 'vertical_and_slash': VerticalAndSlashAttentionMInference, 16 | 'block_sparse': BlockSparseAttentionMInference, 17 | 'snapkv': SnapKVCompression, 18 | 'ada_snapkv': AdaSnapKVCompression, 19 | 'quest': QuestAttention, 20 | 'flexprefill': FlexPrefill, 21 | } 22 | 23 | 24 | def get_attention(): 25 | return GlobalSettings.get('ATTENTION', DenseAttention()) 26 | 27 | 28 | def get_attention_handler() -> AttentionHandler: 29 | return GlobalSettings.get('ATTENTION_HANDLER') 30 | 31 | 32 | def configure_attention(): 33 | cfg = GlobalSettings.get('cfg') 34 | 35 | attention_args = cfg.attention.get('args', {}) 36 | 37 | if cfg.attention.name == 'quest': 38 | block_size = attention_args.page_size 39 | else: 40 | block_size = cfg.kv_cache_block_size 41 | 42 | # Configure attention handler 43 | attention_handler = AttentionHandler( 44 | tp_size=cfg.tp, 45 | model_q_heads=cfg.model.num_q_heads, 46 | model_kv_heads=cfg.model.num_kv_heads, 47 | model_layers=cfg.model.num_layers, 48 | max_input_tokens=cfg.max_input_tokens, 49 | max_output_tokens=cfg.max_output_tokens, 50 | 
block_size=block_size, 51 | ) 52 | GlobalSettings.set('ATTENTION_HANDLER', attention_handler) 53 | 54 | if cfg.attention.name == 'quest': 55 | extra_args = { 56 | 'num_layers': cfg.model.num_layers, 57 | 'max_input_tokens': cfg.max_input_tokens, 58 | 'max_output_tokens': cfg.max_output_tokens, 59 | } 60 | else: 61 | extra_args = {} 62 | 63 | attention = ATTENTION_REGISTRY[cfg.attention.name](**attention_args, **extra_args) 64 | GlobalSettings.set('ATTENTION', attention) 65 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PiotrNawrot/sparse-frontier/12b0c5cde3750893c6676cf3a60a81cf1c704fcb/sparse_frontier/modelling/models/__init__.py -------------------------------------------------------------------------------- /sparse_frontier/modelling/models/abstract_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any 3 | import torch 4 | 5 | 6 | class AbstractModel(ABC): 7 | @abstractmethod 8 | def __init__( 9 | self, 10 | model_path: str, 11 | max_input_tokens: int = 8192, 12 | max_output_tokens: int = 256, 13 | device: torch.device = None, 14 | dtype: torch.dtype = None, 15 | ) -> None: 16 | """Initialize the model. 17 | 18 | Args: 19 | model_path: Path or identifier for the model 20 | max_input_tokens: Maximum number of input tokens 21 | max_output_tokens: Maximum number of new tokens to generate 22 | device: Device to run the model on 23 | dtype: Data type for model parameters 24 | """ 25 | pass 26 | 27 | @abstractmethod 28 | def _load_model(self, model_path: str) -> Any: 29 | """Load the underlying model. 30 | 31 | Args: 32 | model_path: Path or identifier for the model 33 | 34 | Returns: 35 | The loaded model 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def _greedy_config(self, max_output_tokens: int) -> Dict: 41 | """Get configuration for greedy generation. 42 | 43 | Args: 44 | max_output_tokens: Maximum number of new tokens to generate 45 | 46 | Returns: 47 | Dictionary of generation configuration parameters 48 | """ 49 | pass 50 | 51 | @abstractmethod 52 | def generate( 53 | self, 54 | input_text: str, 55 | max_output_tokens: int = None, 56 | ) -> str: 57 | """Generate text from input. 
58 | 59 | Args: 60 | input_text: The input text to generate from 61 | max_output_tokens: Maximum number of new tokens to generate 62 | 63 | Returns: 64 | Generated text string 65 | """ 66 | pass 67 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/models/vllm_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | from vllm import LLM, SamplingParams 5 | from vllm.attention.backends.abstract import AttentionType 6 | from vllm.attention.backends.flash_attn import ( 7 | FlashAttentionMetadata, 8 | get_num_prefill_decode_query_kv_tokens, 9 | ) 10 | 11 | from sparse_frontier.modelling.attention.registry import get_attention, get_attention_handler 12 | from sparse_frontier.utils.globals import set_vllm_profiling_done, is_vllm_profiling_done 13 | from .abstract_model import AbstractModel 14 | from sparse_frontier.modelling.tokenizer import Tokenizer 15 | 16 | 17 | class VLLMModel(AbstractModel): 18 | def __init__( 19 | self, 20 | model_path: str, 21 | max_input_tokens: int = 8192, 22 | max_output_tokens: int = 256, 23 | dtype: torch.dtype = None, 24 | tensor_parallel_size: int = 1, 25 | seed: Optional[int] = 43, 26 | ): 27 | """ 28 | vLLM uses forking to run the TP. Therefore we can't initialise CUDA, 29 | before the fork, as it will trigger an error. That's why we don't 30 | call get_device() etc. 31 | """ 32 | self.model_path = model_path 33 | self.max_input_tokens = max_input_tokens 34 | self.max_output_tokens = max_output_tokens 35 | self.dtype = dtype or torch.bfloat16 36 | self.tensor_parallel_size = tensor_parallel_size 37 | self.seed = seed 38 | 39 | set_vllm_profiling_done(False) 40 | 41 | assert not torch.cuda.is_initialized(), "CUDA is not initialized" 42 | self.model = self._load_model(self.model_path) 43 | self.tokenizer = Tokenizer(self.model_path) 44 | 45 | def _load_model(self, model_path: str) -> LLM: 46 | if 'Qwen' in model_path and self.max_input_tokens + self.max_output_tokens > 32768: 47 | factor = (self.max_input_tokens + self.max_output_tokens) / 32768 48 | hf_overrides = { 49 | "rope_scaling": { 50 | "factor": factor, 51 | "original_max_position_embeddings": 32768, 52 | "rope_type": "yarn" 53 | } 54 | } 55 | else: 56 | hf_overrides = {} 57 | 58 | model = LLM( 59 | model=model_path, 60 | skip_tokenizer_init=True, 61 | trust_remote_code=True, 62 | enforce_eager=True, 63 | seed=self.seed, 64 | gpu_memory_utilization=0.9, 65 | max_num_batched_tokens=self.max_input_tokens + self.max_output_tokens, 66 | max_model_len=self.max_input_tokens + self.max_output_tokens, 67 | enable_chunked_prefill=False, 68 | tensor_parallel_size=self.tensor_parallel_size, 69 | hf_overrides=hf_overrides, 70 | ) 71 | 72 | # Statistics has been accumulated during vLLM profiling 73 | get_attention().reset_sparsity_statistics() 74 | set_vllm_profiling_done(True) 75 | 76 | return model 77 | 78 | def _greedy_config(self, max_output_tokens: int) -> Dict: 79 | return { 80 | 'sampling_params': SamplingParams( 81 | max_tokens=max_output_tokens, 82 | temperature=0, 83 | ), 84 | 'use_tqdm': False, 85 | } 86 | 87 | @torch.no_grad() 88 | def generate( 89 | self, 90 | input_text: str, 91 | max_output_tokens: int = None, 92 | ) -> str: 93 | max_output_tokens = max_output_tokens or self.max_output_tokens 94 | model_input = self.tokenizer.encode_for_generation(input_text, return_tensors=False) 95 | 96 | output = self.model.generate( 97 | 
prompt_token_ids=model_input['input_ids'], 98 | **self._greedy_config(max_output_tokens), 99 | ) 100 | 101 | output_ids = output[0].__dict__['outputs'][0].token_ids 102 | decoded = self.tokenizer.decode(output_ids) 103 | 104 | return { 105 | 'text': decoded[0] if isinstance(decoded, list) else decoded, 106 | 'output_tokens_len': len(output_ids), 107 | } 108 | 109 | 110 | 111 | def vllm_patched_forward( 112 | self, 113 | query: torch.Tensor, 114 | key: torch.Tensor, 115 | value: torch.Tensor, 116 | kv_cache: torch.Tensor, 117 | attn_metadata: FlashAttentionMetadata, 118 | k_scale: float = 1.0, 119 | v_scale: float = 1.0, 120 | attn_type: AttentionType = AttentionType.DECODER, 121 | output: Optional[torch.Tensor] = None, 122 | ) -> torch.Tensor: 123 | """Forward pass with FlashAttention. 124 | 125 | Args: 126 | query: shape = [num_tokens, num_heads, head_size] 127 | key: shape = [num_tokens, num_kv_heads, head_size] 128 | value: shape = [num_tokens, num_kv_heads, head_size] 129 | output: shape = [num_tokens, num_heads, head_size] 130 | kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] 131 | attn_metadata: Metadata for attention. 132 | Returns: 133 | shape = [num_tokens, num_heads * head_size] 134 | """ 135 | (num_prefill_query_tokens, num_prefill_kv_tokens, _) = \ 136 | get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) 137 | 138 | if attn_metadata.prefill_metadata: 139 | get_attention_handler().__call__( 140 | queries=query[:num_prefill_query_tokens], 141 | keys=key[:num_prefill_kv_tokens], 142 | values=value[:num_prefill_kv_tokens], 143 | kv_cache=kv_cache, 144 | output=output[:num_prefill_query_tokens], 145 | ) 146 | else: 147 | get_attention_handler().__call__( 148 | queries=query[num_prefill_query_tokens:], 149 | keys=key[num_prefill_query_tokens:], 150 | values=value[num_prefill_query_tokens:], 151 | kv_cache=kv_cache, 152 | output=output[num_prefill_query_tokens:], 153 | ) 154 | 155 | return output 156 | 157 | 158 | def swap_vllm_attention(): 159 | from vllm.attention.backends.flash_attn import FlashAttentionImpl 160 | FlashAttentionImpl.forward = vllm_patched_forward 161 | -------------------------------------------------------------------------------- /sparse_frontier/modelling/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from typing import List, Dict, Union 3 | 4 | import torch 5 | 6 | class Tokenizer: 7 | def __init__( 8 | self, 9 | model_path: str, 10 | device: torch.device = None, 11 | ) -> None: 12 | self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | self.tokenizer = AutoTokenizer.from_pretrained( 15 | model_path, 16 | trust_remote_code=True, 17 | ) 18 | 19 | if self.tokenizer.pad_token is None: 20 | self.tokenizer.pad_token = self.tokenizer.eos_token 21 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 22 | 23 | @property 24 | def pad_token_id(self) -> int: 25 | return self.tokenizer.pad_token_id 26 | 27 | @property 28 | def model_max_length(self) -> int: 29 | return self.tokenizer.model_max_length 30 | 31 | def text_to_tokens(self, input_text: str) -> List[int]: 32 | return self.tokenizer.encode(input_text, add_special_tokens=False) 33 | 34 | def encode_for_generation(self, input_text: str, return_tensors: bool = True) -> Dict: 35 | input_text_with_prompt = self.tokenizer.apply_chat_template( 36 | [{"role": "user", "content": input_text}], 37 | tokenize=False, 38 | add_generation_prompt=True 39 | ) 
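# Note: apply_chat_template wraps the raw task input in the model's chat format and, with
# add_generation_prompt=True, appends the assistant header so decoding starts inside the assistant
# turn. For Qwen-style templates the result looks roughly like
# "<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n".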
40 | 41 | if return_tensors: 42 | encoded_input = self.tokenizer( 43 | input_text_with_prompt, 44 | return_tensors="pt", 45 | add_special_tokens=False, 46 | ).to(self.device) 47 | 48 | return { 49 | **encoded_input, 50 | 'input_length': encoded_input['input_ids'].size(1), 51 | } 52 | else: 53 | input_ids = self.text_to_tokens(input_text_with_prompt) 54 | return { 55 | 'input_ids': input_ids, 56 | 'input_length': len(input_ids), 57 | } 58 | 59 | def decode(self, outputs: Union[List[int], List[List[int]]], input_length: int = 0) -> Union[str, List[str]]: 60 | if isinstance(outputs[0], int): 61 | # Single list of tokens 62 | return self.tokenizer.decode( 63 | outputs[input_length:], 64 | skip_special_tokens=True 65 | ) 66 | else: 67 | # List of lists of tokens 68 | return [ 69 | self.tokenizer.decode( 70 | output[input_length:], 71 | skip_special_tokens=True 72 | ) 73 | for output in outputs 74 | ] 75 | -------------------------------------------------------------------------------- /sparse_frontier/prediction.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from concurrent.futures import ProcessPoolExecutor, as_completed 4 | from typing import Dict, Any 5 | 6 | from tqdm import tqdm 7 | 8 | from sparse_frontier.utils.data import load_data_without_predictions, get_pred_path 9 | from sparse_frontier.utils.general import get_free_ports, save_config 10 | from sparse_frontier.utils.globals import GlobalSettings 11 | 12 | 13 | model = None 14 | 15 | 16 | def process_sample(sample: Dict[str, Any]) -> Dict[str, Any]: 17 | """Process a single sample through the model. 18 | 19 | Args: 20 | sample: Dictionary containing input text and metadata 21 | 22 | Returns: 23 | Sample dictionary augmented with model prediction 24 | """ 25 | global model 26 | from sparse_frontier.modelling.attention.registry import get_attention 27 | 28 | get_attention().reset_sparsity_statistics() 29 | 30 | output = model.generate(sample['input_text']) 31 | 32 | output_dict = { 33 | 'pred': output['text'], 34 | 'output_tokens_len': output['output_tokens_len'], 35 | 'sparsity': get_attention().calculate_sparsity(), 36 | 'index': sample['index'] 37 | } 38 | 39 | return output_dict 40 | 41 | 42 | def init_worker() -> None: 43 | """Initialize worker process with GPU configuration and VLLM model.""" 44 | import torch.multiprocessing as mp 45 | from sparse_frontier.modelling.models.vllm_model import VLLMModel 46 | 47 | # Explicitly set tokenizers parallelism for each worker 48 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 49 | 50 | cfg = GlobalSettings.get('cfg') 51 | 52 | # Check if CUDA_VISIBLE_DEVICES is already set 53 | preset_gpus = os.environ.get('CUDA_VISIBLE_DEVICES') 54 | if preset_gpus: 55 | available_gpus = [int(gpu) for gpu in preset_gpus.split(',')] 56 | if len(available_gpus) < cfg.gpus: 57 | raise ValueError(f"Number of available GPUs ({len(available_gpus)}) is less than required ({cfg.gpus})") 58 | else: 59 | available_gpus = list(range(cfg.gpus)) 60 | 61 | # Set GPU device visibility based on worker index 62 | if len(mp.current_process()._identity) > 0: 63 | worker_index = mp.current_process()._identity[0] - 1 64 | # Calculate GPU slice for this worker from available GPUs 65 | worker_gpus = available_gpus[worker_index * cfg.tp:(worker_index + 1) * cfg.tp] 66 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, worker_gpus)) 67 | # Get pre-allocated port for this worker 68 | worker_port = GlobalSettings.get('worker_ports')[worker_index] 69 | 
else: 70 | # Single worker case - use all available GPUs and first port 71 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, available_gpus)) 72 | worker_port = GlobalSettings.get('worker_ports')[0] 73 | 74 | # Configure VLLM environment 75 | os.environ['VLLM_HOST_IP'] = 'localhost' 76 | os.environ['VLLM_PORT'] = str(worker_port) 77 | 78 | global model 79 | model = VLLMModel( 80 | model_path=cfg.model.path, 81 | max_input_tokens=cfg.max_input_tokens, 82 | max_output_tokens=cfg.max_output_tokens, 83 | tensor_parallel_size=cfg.tp, 84 | seed=cfg.random_seed 85 | ) 86 | 87 | 88 | def predict_task() -> None: 89 | cfg = GlobalSettings.get('cfg') 90 | 91 | from sparse_frontier.modelling.attention.registry import configure_attention 92 | configure_attention() 93 | 94 | data = load_data_without_predictions() 95 | 96 | pred_path = get_pred_path() 97 | 98 | num_workers = cfg.gpus // cfg.tp 99 | 100 | # Get free ports for all workers at the start 101 | free_ports = get_free_ports(num_workers) 102 | GlobalSettings.set('worker_ports', free_ports) 103 | 104 | if num_workers == 1: 105 | # Single worker case - process samples sequentially 106 | init_worker() 107 | with open(pred_path, 'at', encoding="utf-8", buffering=1) as fout: 108 | for sample in tqdm(data, total=len(data)): 109 | sample_results = process_sample(sample) 110 | fout.write(json.dumps(sample_results) + '\n') 111 | fout.flush() 112 | 113 | global model 114 | model = None 115 | else: 116 | # Multi-worker case - process samples in parallel 117 | with ProcessPoolExecutor(max_workers=num_workers, initializer=init_worker) as executor, \ 118 | open(pred_path, 'at', encoding="utf-8", buffering=1) as fout: 119 | futures = {executor.submit(process_sample, sample): sample for sample in data} 120 | for future in tqdm(as_completed(futures), total=len(data)): 121 | sample_results = future.result() 122 | fout.write(json.dumps(sample_results) + '\n') 123 | fout.flush() 124 | 125 | save_config(os.path.dirname(pred_path)) 126 | print(f'Prediction for task {cfg.task.name} is done. 
Output is saved to {pred_path}.') 127 | -------------------------------------------------------------------------------- /sparse_frontier/preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sparse_frontier.utils import GlobalSettings 4 | from sparse_frontier.utils.data import write_jsonl, get_data_path 5 | from sparse_frontier.utils.general import save_config 6 | from sparse_frontier.tasks.registry import TASK_REGISTRY 7 | 8 | 9 | def check_args(): 10 | from transformers import AutoTokenizer, AutoConfig 11 | 12 | cfg = GlobalSettings.get("cfg") 13 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.path, trust_remote_code=True) 14 | config = AutoConfig.from_pretrained(cfg.model.path, trust_remote_code=True) 15 | 16 | if tokenizer.model_max_length < cfg.max_input_tokens + cfg.max_output_tokens: 17 | raise ValueError(f"Model maximum sequence length ({tokenizer.model_max_length}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})") 18 | 19 | max_pos_embeddings = getattr(config, "max_position_embeddings", 131072) 20 | if max_pos_embeddings < cfg.max_input_tokens + cfg.max_output_tokens and 'qwen' not in cfg.model.name: 21 | raise ValueError(f"Model maximum position embeddings ({max_pos_embeddings}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})") 22 | 23 | seq_length = getattr(config, "seq_length", 131072) 24 | if seq_length < cfg.max_input_tokens + cfg.max_output_tokens: 25 | raise ValueError(f"Model maximum sequence length ({seq_length}) is less than required length ({cfg.max_input_tokens + cfg.max_output_tokens})") 26 | 27 | 28 | def get_task_generator(): 29 | from sparse_frontier.modelling.tokenizer import Tokenizer 30 | cfg = GlobalSettings.get("cfg") 31 | task_kwargs = { 32 | 'num_samples': cfg.samples, 33 | 'max_input_tokens': cfg.max_input_tokens, 34 | 'max_output_tokens': cfg.max_output_tokens, 35 | 'tokenizer': Tokenizer(cfg.model.path, device='cpu'), 36 | 'random_seed': cfg.random_seed, 37 | **cfg.task.get('args', {}), 38 | } 39 | return TASK_REGISTRY[cfg.task.name](**task_kwargs) 40 | 41 | 42 | def prepare_task(): 43 | cfg = GlobalSettings.get("cfg") 44 | 45 | # Explicitly enable tokenizers parallelism during preparation 46 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 47 | 48 | check_args() 49 | 50 | data_path = get_data_path() 51 | 52 | print(f"Preparing {cfg.task.name} with {cfg.samples} samples") 53 | generator = get_task_generator() 54 | samples = generator.generate_samples() 55 | 56 | write_jsonl(data_path, samples) 57 | save_config(os.path.dirname(data_path)) 58 | print(f"Saved {cfg.task.name} with {cfg.samples} samples to {data_path}") 59 | 60 | # Disable tokenizers parallelism after preparation 61 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 62 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import TASK_REGISTRY 2 | 3 | __all__ = [ 4 | 'TASK_REGISTRY' 5 | ] 6 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/abstract_prompt.py: -------------------------------------------------------------------------------- 1 | # Base template with shared components 2 | _BASE_TEMPLATE = """{task_intro} 3 | 4 | {question_intro} 5 | 6 | {question_section} 7 | 8 | 9 | {context} 10 | 11 | 12 | {question_repeat_section} 13 | 14 | 
Instructions: 15 | 1. First, provide a brief explanation of your reasoning process. Explain how you identified 16 | the relevant information from the context and how you determined your answer. 17 | 2. Then, provide your final answer following this exact format: 18 | 19 | {answer_format} 20 | 21 | 22 | Your response must follow this structure exactly: 23 | 24 | Your explanation here... 25 | 26 | 27 | Your answer here... 28 | 29 | 30 | Important: 31 | {extra_instructions} 32 | - Keep your explanations clear, coherent, concise, and to the point. 33 | - Do not include any additional text, explanations, or reasoning in the answer section. Follow the answer format exactly. 34 | """ 35 | 36 | # Template for single question tasks 37 | SINGLEQ_PROMPT_TEMPLATE = _BASE_TEMPLATE.format( 38 | task_intro="{task_intro}", 39 | context="{context}", 40 | question_intro="Below is your question. I will state it both before and after the context.", 41 | question_section="\n{question}\n", 42 | question_repeat_section="\n{question}\n", 43 | answer_format="{answer_format}", 44 | extra_instructions="{extra_instructions}" 45 | ) 46 | 47 | # Template for multiple question tasks 48 | MULTIPLEQ_PROMPT_TEMPLATE = _BASE_TEMPLATE.format( 49 | task_intro="{task_intro}", 50 | context="{context}", 51 | question_intro="Below are your questions. I will state them both before and after the context.", 52 | question_section="\n{question}\n", 53 | question_repeat_section="\n{question}\n", 54 | answer_format="{answer_format}", 55 | extra_instructions="{extra_instructions}" 56 | ) 57 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/abstract_sample.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC, abstractmethod 3 | from typing import Any, Dict 4 | 5 | from sparse_frontier.modelling.tokenizer import Tokenizer 6 | 7 | 8 | class AbstractSample(ABC): 9 | """Base class for task samples. 10 | 11 | Handles common sample functionality like length validation and conversion to dict format. 12 | Subclasses must implement _generate_sample(). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | sample_id: int, 18 | random_seed: int, 19 | max_tokens: int, 20 | tokenizer: Tokenizer, 21 | task_params: Dict[str, Any] 22 | ) -> None: 23 | """Initialize sample parameters. 24 | 25 | Args: 26 | sample_id: Sample index number 27 | random_seed: Random seed for reproducibility 28 | max_tokens: Maximum input sequence length in tokens 29 | tokenizer: Tokenizer for text encoding/decoding 30 | task_params: Dictionary of task parameters 31 | """ 32 | self.sample_id = sample_id 33 | self.max_tokens = max_tokens 34 | self.random_obj = random.Random(random_seed + sample_id) 35 | self.tokenizer = tokenizer 36 | self.task_params = task_params 37 | self.input_text, self.gold_answer, self.extra_data = self._generate_sample() 38 | 39 | @abstractmethod 40 | def _generate_sample(self) -> tuple[str, str, Dict[str, Any]]: 41 | """Generate the input text, gold answer and extra data for this sample. 42 | 43 | Returns: 44 | Tuple containing: 45 | - Input text string to be provided to the model 46 | - Expected gold answer string 47 | - Dictionary of extra data to include in to_dict output 48 | """ 49 | pass 50 | 51 | def to_dict(self) -> Dict[str, Any]: 52 | """Convert sample to dictionary format. 
53 | 54 | Returns: 55 | Dictionary containing sample data and any extra data from _generate_sample 56 | """ 57 | # Handle both single string and list of strings for gold_answer 58 | reference_answer = self.gold_answer[0] if isinstance(self.gold_answer, list) else self.gold_answer 59 | 60 | total_sample_length = len(self.tokenizer.encode_for_generation(self.input_text, return_tensors=False)['input_ids']) + \ 61 | len(self.tokenizer.text_to_tokens(reference_answer)) 62 | 63 | base_dict = { 64 | "index": self.sample_id, 65 | "input_text": self.input_text, 66 | "gold_answer": self.gold_answer, 67 | "length": total_sample_length 68 | } 69 | base_dict.update(self.extra_data) 70 | return base_dict 71 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/abstract_task.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import ABC, abstractmethod 3 | from typing import Any, Dict, List 4 | 5 | from sparse_frontier.modelling.tokenizer import Tokenizer 6 | 7 | 8 | class AbstractTask(ABC): 9 | """Base class for defining evaluation tasks. 10 | 11 | Handles common task functionality like sample generation and evaluation. 12 | Subclasses must implement check_inputs(), generate_samples() and evaluate(). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | num_samples: int, 18 | max_input_tokens: int, 19 | max_output_tokens: int, 20 | tokenizer: Tokenizer, 21 | random_seed: int, 22 | template_tokens: int = 64, 23 | **kwargs 24 | ) -> None: 25 | """Initialize task parameters. 26 | 27 | Args: 28 | num_samples: Number of samples to generate 29 | max_input_tokens: Maximum input sequence length in tokens 30 | max_output_tokens: Maximum output sequence length in tokens 31 | tokenizer: Tokenizer for text encoding/decoding 32 | random_seed: Random seed for reproducibility 33 | template_tokens: Approximate number of tokens for model's template 34 | """ 35 | self.num_samples = num_samples 36 | self.max_input_tokens = max_input_tokens 37 | self.max_output_tokens = max_output_tokens 38 | self.tokenizer = tokenizer 39 | self.random_seed = random_seed 40 | self.template_tokens = template_tokens 41 | self.task_params = {} 42 | self.random_obj = random.Random(self.random_seed) 43 | 44 | def __getattr__(self, name: str) -> Any: 45 | """Get task parameter by name.""" 46 | if name in self.params: 47 | return self.params[name] 48 | raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") 49 | 50 | def check_sample_length(self, input_text: str, gold_answer: str | list[str]) -> None: 51 | """Validate input and output sequence lengths. 
52 | 53 | Args: 54 | input_text: Model input text 55 | gold_answer: Expected model output, either a single string or list of possible answers 56 | 57 | Raises: 58 | AssertionError: If sequences exceed maximum allowed lengths 59 | """ 60 | # Handle both single string and list of strings 61 | if isinstance(gold_answer, str): 62 | gold_answers = [gold_answer] 63 | else: 64 | gold_answers = gold_answer 65 | 66 | # Check length of each possible answer 67 | for answer in gold_answers: 68 | gold_length = len(self.tokenizer.text_to_tokens(answer)) 69 | if gold_length > self.max_output_tokens: 70 | raise AssertionError( 71 | f"Gold answer too long: {gold_length} tokens, max {self.max_output_tokens}" 72 | ) 73 | 74 | input_length = len(self.tokenizer.text_to_tokens(input_text)) 75 | if input_length > self.max_input_tokens - self.template_tokens: 76 | raise AssertionError( 77 | f"Only input too long: {input_length} tokens, max {self.max_input_tokens - self.template_tokens} tokens." 78 | ) 79 | 80 | final_input = self.tokenizer.encode_for_generation(input_text, return_tensors=False) 81 | final_input_length = len(final_input["input_ids"]) 82 | 83 | if final_input_length > self.max_input_tokens: 84 | raise AssertionError( 85 | f"Input + Template too long: {final_input_length} tokens, max {self.max_input_tokens} tokens." 86 | ) 87 | 88 | if final_input_length < 0.90 * self.max_input_tokens: 89 | raise AssertionError( 90 | f"Input + Template too short: {final_input_length} tokens, min {int(0.90 * self.max_input_tokens)} tokens." 91 | ) 92 | 93 | @abstractmethod 94 | def check_params(self) -> None: 95 | """Validate task-specific parameters. 96 | 97 | Raises: 98 | ValueError: If parameters are missing or invalid 99 | AssertionError: If parameters fail validation checks 100 | """ 101 | pass 102 | 103 | @property 104 | @abstractmethod 105 | def sample_class(self): 106 | """Return the sample class to use for this task. 107 | 108 | Returns: 109 | A class that inherits from AbstractSample 110 | """ 111 | pass 112 | 113 | def generate_samples(self) -> List[Dict[str, Any]]: 114 | """Generate task evaluation samples. 115 | 116 | Returns: 117 | List of sample dicts containing: 118 | input_text: Model input text 119 | gold_answer: Expected model output 120 | index: Sample index 121 | length: Input sequence length 122 | + Additional task-specific fields 123 | """ 124 | samples = [] 125 | for i in range(self.num_samples): 126 | sample = self.sample_class( 127 | sample_id=i, 128 | random_seed=self.random_seed, 129 | max_tokens=self.max_input_tokens - self.template_tokens, 130 | tokenizer=self.tokenizer, 131 | task_params=self.task_params, 132 | ) 133 | self.check_sample_length(sample.input_text, sample.gold_answer) 134 | samples.append(sample.to_dict()) 135 | 136 | return samples 137 | 138 | @staticmethod 139 | @abstractmethod 140 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, Any]: 141 | """Evaluate model predictions against gold answers. 
142 | 143 | Args: 144 | predictions: List of sample dicts with model predictions in 'pred' field 145 | 146 | Returns: 147 | Dict of evaluation metrics 148 | """ 149 | pass 150 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/qa/qa_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Dict, Tuple 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | def get_dataset(dataset_name: str): 8 | """Get initialized QA dataset based on dataset name. 9 | 10 | Args: 11 | dataset_name: Name of dataset to load ('squad', 'quality', or 'toeflqa') 12 | 13 | Returns: 14 | Initialized QA dataset with data loaded 15 | """ 16 | # Get path to data directory relative to this file 17 | current_dir = Path(__file__).parent 18 | data_dir = current_dir / 'data' 19 | 20 | if dataset_name == "squad": 21 | data_path = data_dir / 'squad.json' 22 | return SQuADDataset(str(data_path)) 23 | elif dataset_name == "quality": 24 | data_path = data_dir / 'quality.jsonl' 25 | return QualityDataset(str(data_path)) 26 | elif dataset_name == "toeflqa": 27 | data_path = data_dir / 'toeflqa.jsonl' 28 | return TOEFLQADataset(str(data_path)) 29 | 30 | 31 | class QADataset(ABC): 32 | """Abstract base class defining the interface for QA dataset parsers. 33 | 34 | All QA dataset parsers should inherit from this class and implement the read_data method. 35 | 36 | Output Format: 37 | - qa_samples: List[Dict] where each dict contains: 38 | - context_idx: int, index into unique_contexts list 39 | - question: str, the question text 40 | - answer: Union[List[str], str], either: 41 | - List[str] for open-ended QA (e.g. SQuAD) 42 | - str for multiple choice QA (e.g. Quality, TOEFL) 43 | - unique_contexts: List[str], list of all unique context passages 44 | """ 45 | 46 | def __init__(self, data_path: str): 47 | """Initialize the QA dataset parser. 48 | """ 49 | self.qa_samples, self.unique_contexts = self.read_data(data_path) 50 | self._remove_duplicate_questions() 51 | 52 | def _remove_duplicate_questions(self): 53 | """Remove samples with duplicate question-context pairs.""" 54 | # Create a set to track seen question-context pairs 55 | seen_pairs = set() 56 | filtered_samples = [] 57 | 58 | for sample in self.qa_samples: 59 | # Create a unique key from question and context 60 | key = (sample['question'].strip().lower(), sample['context_idx']) 61 | 62 | # Only keep samples with unique question-context pairs 63 | if key not in seen_pairs: 64 | seen_pairs.add(key) 65 | filtered_samples.append(sample) 66 | 67 | # Update qa_samples with deduplicated list 68 | self.qa_samples = filtered_samples 69 | 70 | @abstractmethod 71 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]: 72 | """Read and process a QA dataset file. 
73 | 74 | Args: 75 | data_path: Path to dataset file 76 | 77 | Returns: 78 | Tuple containing: 79 | - List of processed QA samples with context indices 80 | - List of unique context documents 81 | """ 82 | pass 83 | 84 | @property 85 | @abstractmethod 86 | def is_mcq(self) -> bool: 87 | """Whether this dataset contains multiple choice questions.""" 88 | pass 89 | 90 | 91 | class SQuADDataset(QADataset): 92 | """Parser for SQuAD format datasets.""" 93 | 94 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]: 95 | with open(data_path) as f: 96 | data = json.load(f) 97 | 98 | # Extract and deduplicate contexts 99 | all_contexts = [p['context'] for d in data['data'] for p in d['paragraphs']] 100 | unique_contexts = sorted(list(set(all_contexts))) 101 | context_to_idx = {context: idx for idx, context in enumerate(unique_contexts)} 102 | 103 | qa_samples = [] 104 | for article in data['data']: 105 | for paragraph in article['paragraphs']: 106 | curr_context_idx = context_to_idx[paragraph['context']] 107 | 108 | for qa in paragraph['qas']: 109 | if not qa['is_impossible']: 110 | qa_samples.append({ 111 | 'context_idx': curr_context_idx, 112 | 'question': qa['question'], 113 | 'answer': [a['text'] for a in qa['answers']], 114 | }) 115 | 116 | return qa_samples, unique_contexts 117 | 118 | @property 119 | def is_mcq(self) -> bool: 120 | return False 121 | 122 | 123 | class JSONLinesQADataset(QADataset): 124 | """Base parser for JSONL format QA datasets (Quality and TOEFL).""" 125 | 126 | def read_data(self, data_path: str) -> Tuple[List[Dict], List[str]]: 127 | with open(data_path) as f: 128 | data = [json.loads(line) for line in f] 129 | 130 | # First pass: collect unique contexts 131 | unique_contexts = sorted(list({item['context'] for item in data})) 132 | context_to_idx = {context: idx for idx, context in enumerate(unique_contexts)} 133 | 134 | # Second pass: create QA samples 135 | qa_samples = [] 136 | for item in data: 137 | context_idx = context_to_idx[item['context']] 138 | questions = item['questions'] 139 | answers = item['answer'] 140 | 141 | # Verify questions and answers match in length 142 | assert len(questions) == len(answers), \ 143 | f"Mismatch in questions ({len(questions)}) and answers ({len(answers)}) length" 144 | 145 | # Create individual QA samples 146 | for q, a in zip(questions, answers): 147 | # For TOEFL QA, remove "Question: " prefix if present 148 | if isinstance(self, TOEFLQADataset) and q.startswith("Question: "): 149 | q = q[len("Question: "):] 150 | 151 | qa_samples.append({ 152 | 'context_idx': context_idx, 153 | 'question': q, 154 | 'answer': a, 155 | }) 156 | 157 | return qa_samples, unique_contexts 158 | 159 | 160 | class QualityDataset(JSONLinesQADataset): 161 | """Parser for Quality format datasets.""" 162 | 163 | @property 164 | def is_mcq(self) -> bool: 165 | return True 166 | 167 | 168 | class TOEFLQADataset(JSONLinesQADataset): 169 | """Parser for TOEFL QA format datasets.""" 170 | 171 | @property 172 | def is_mcq(self) -> bool: 173 | return True 174 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/qa/qa_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from typing import Set 4 | 5 | 6 | def normalize_mcq_answer(answer: str) -> str: 7 | """Normalize multiple choice answer to extract just the letter choices. 
8 | 
9 |     Args:
10 |         answer: Raw answer string that may contain explanation tags and other text
11 | 
12 |     Returns:
13 |         Normalized string containing only the answer letters (e.g., 'A' or 'AB')
14 |     """
15 |     # First try to extract the answer from <answer> tags if present
16 |     answer_match = re.search(r'<answer>(.*?)</answer>', answer, re.DOTALL)
17 |     if answer_match:
18 |         answer = answer_match.group(1)
19 | 
20 |     # Remove any "Question X" prefix
21 |     answer = re.sub(r'Question\s*\d+\.?\s*', '', answer)
22 | 
23 |     # Clean up answer text
24 |     answer = answer.upper()
25 |     answer = re.sub(r'[^A-D]', '', answer)
26 | 
27 |     # Sort multiple choices to ensure consistent ordering (e.g., 'BA' -> 'AB')
28 |     answer = ''.join(sorted(answer))
29 | 
30 |     return answer
31 | 
32 | 
33 | def extract_tagged_response(text: str) -> tuple[str, str]:
34 |     """Extract explanation and answer from tagged response.
35 | 
36 |     Args:
37 |         text: Full response text containing tagged sections
38 | 
39 |     Returns:
40 |         Tuple of (explanation text, answer text)
41 |         Returns empty strings if tags are not found
42 |     """
43 |     explanation = ''
44 |     answer = ''
45 | 
46 |     explanation_match = re.search(r'<explanation>(.*?)</explanation>', text, re.DOTALL)
47 |     if explanation_match:
48 |         explanation = explanation_match.group(1).strip()
49 | 
50 |     answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
51 |     if answer_match:
52 |         answer = answer_match.group(1).strip()
53 | 
54 |     return explanation, answer
55 | 
56 | 
57 | def normalize_answer(s: str) -> str:
58 |     """Normalize open-ended answer for comparison.
59 | 
60 |     Args:
61 |         s: Raw answer string
62 | 
63 |     Returns:
64 |         Normalized answer with consistent formatting
65 |     """
66 |     def remove_articles(text):
67 |         return re.sub(r"\b(a|an|the|and|or|about|to)\b", " ", text)
68 | 
69 |     def white_space_fix(text):
70 |         return " ".join(text.split())
71 | 
72 |     def remove_punc(text):
73 |         exclude = set(string.punctuation)
74 |         return "".join(ch for ch in text if ch not in exclude)
75 | 
76 |     # Extract answer from <answer> tags if present
77 |     answer_match = re.search(r'<answer>(.*?)</answer>', s, re.DOTALL)
78 |     if answer_match:
79 |         s = answer_match.group(1)
80 | 
81 |     return white_space_fix(remove_articles(remove_punc(s.lower())))
82 | 
83 | 
84 | def get_token_overlap(pred: str, gold: str) -> tuple[Set[str], Set[str]]:
85 |     """Get overlapping tokens between prediction and gold answer.
86 | 
87 |     Args:
88 |         pred: Predicted answer text
89 |         gold: Gold answer text
90 | 
91 |     Returns:
92 |         Tuple of (prediction tokens, gold tokens)
93 |     """
94 |     pred_tokens = set(pred.split())
95 |     gold_tokens = set(gold.split())
96 |     return pred_tokens, gold_tokens
97 | 
98 | 
99 | def f1_score(prediction: str, gold: str) -> float:
100 |     """Calculate F1 score between prediction and gold answer.
101 | 102 | Args: 103 | prediction: Predicted answer text 104 | gold: Gold answer text 105 | 106 | Returns: 107 | F1 score between 0 and 1 108 | """ 109 | prediction_tokens, gold_tokens = get_token_overlap(prediction, gold) 110 | 111 | true_positives = len(prediction_tokens & gold_tokens) 112 | false_positives = len(prediction_tokens - gold_tokens) 113 | false_negatives = len(gold_tokens - prediction_tokens) 114 | 115 | precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 116 | recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 117 | 118 | if precision + recall == 0: 119 | return 0.0 120 | 121 | f1 = 2 * (precision * recall) / (precision + recall) 122 | return round(f1, 2) 123 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/registry.py: -------------------------------------------------------------------------------- 1 | from .qa.qa_task import QATask 2 | from .ruler.cwe import CommonWordTask 3 | from .ruler.niah import NIAHTask 4 | from .ruler.vt import VariableTrackingTask 5 | from .story.multihop import MultiHopTask 6 | from .story.filtering import FilteringTask 7 | from .story.retrieval import RetrievalTask 8 | 9 | 10 | TASK_REGISTRY = { 11 | "qa_quality": QATask, 12 | "qa_squad": QATask, 13 | "qa_toefl": QATask, 14 | "ruler_cwe": CommonWordTask, 15 | "ruler_niah": NIAHTask, 16 | "ruler_vt": VariableTrackingTask, 17 | "story_multihop": MultiHopTask, 18 | "story_filtering": FilteringTask, 19 | "story_retrieval": RetrievalTask, 20 | } 21 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/ruler/cwe.py: -------------------------------------------------------------------------------- 1 | """Task for evaluating common word frequency tracking capabilities of language models. 2 | 3 | This module implements a task where models need to identify frequently occurring 4 | words in a list containing both common and rare words. The task tests the model's ability to: 5 | 1. Track word frequencies in a long list 6 | 2. Identify words that appear more frequently than others 7 | 3. Ignore less frequent distractors 8 | """ 9 | 10 | from typing import List, Dict, Any, Tuple 11 | import re 12 | import numpy as np 13 | 14 | from sparse_frontier.tasks.abstract_task import AbstractTask 15 | from sparse_frontier.tasks.abstract_sample import AbstractSample 16 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE 17 | 18 | TASK_INTRO = """You will be given a numbered list of words. Your task is to identify the most frequently occurring words. You should solve this task by carefully reading and analyzing the word list. Do not attempt to write code or use programming tools to count frequencies. This is a test of your ability to track word frequencies directly.""" 19 | 20 | QUESTION_TEMPLATE = """The list contains exactly {num_common} words that appear {common_freq} times each. All other words appear {rare_freq} times each. The order of words in the list is randomized. 21 | Your task is to identify the {num_common} words that appear {common_freq} times each.""" 22 | 23 | ANSWER_FORMAT = """1. word_one 24 | 2. word_two 25 | ... 26 | {num_common}. 
word_{num_common} 27 | 28 | Note: List exactly {num_common} words, one per line, numbered from 1 to {num_common}.""" 29 | 30 | 31 | class CommonWordSample(AbstractSample): 32 | """Handles generation of individual common word frequency tracking samples.""" 33 | 34 | def __init__(self, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | 37 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 38 | """Generate a single common word frequency tracking sample.""" 39 | 40 | words = self.task_params['words'].copy() 41 | self.random_obj.shuffle(words) 42 | 43 | common_words = words[:self.task_params['num_common_words']] 44 | rare_words = words[self.task_params['num_common_words']:] 45 | 46 | # Create list of common words with repetitions 47 | common_word_list = common_words * self.task_params['common_word_frequency'] 48 | 49 | # Binary search to find maximum number of rare words that fit 50 | left, right = 0, len(rare_words) 51 | while left < right: 52 | mid = (left + right + 1) // 2 53 | 54 | # Create test word list with current number of rare words 55 | test_words = ( 56 | common_word_list + 57 | rare_words[:mid] * self.task_params['rare_word_frequency'] 58 | ) 59 | test_context = '\n'.join(f"{i+1}. {word}" for i, word in enumerate(test_words)) 60 | 61 | # Format question with current parameters 62 | question = QUESTION_TEMPLATE.format( 63 | num_common=self.task_params['num_common_words'], 64 | common_freq=self.task_params['common_word_frequency'], 65 | rare_freq=self.task_params['rare_word_frequency'] 66 | ) 67 | 68 | # Calculate total tokens with full prompt template 69 | test_input = SINGLEQ_PROMPT_TEMPLATE.format( 70 | task_intro=TASK_INTRO, 71 | context=test_context, 72 | question=question, 73 | answer_format=ANSWER_FORMAT.format( 74 | num_common=self.task_params['num_common_words'] 75 | ), 76 | extra_instructions="" 77 | ) 78 | total_tokens = len(self.tokenizer.text_to_tokens(test_input)) 79 | 80 | if total_tokens <= self.max_tokens: 81 | left = mid 82 | else: 83 | right = mid - 1 84 | 85 | num_rare_words = left 86 | 87 | # Ensure we have space for at least one rare word 88 | assert num_rare_words > 0, "No space for rare words after common words" 89 | 90 | # Create final word list with optimal number of rare words 91 | final_words = ( 92 | common_word_list + 93 | rare_words[:num_rare_words] * self.task_params['rare_word_frequency'] 94 | ) 95 | self.random_obj.shuffle(final_words) 96 | 97 | # Format context and question 98 | context = '\n'.join(f"{i+1}. {word}" for i, word in enumerate(final_words)) 99 | question = QUESTION_TEMPLATE.format( 100 | num_common=self.task_params['num_common_words'], 101 | common_freq=self.task_params['common_word_frequency'], 102 | rare_freq=self.task_params['rare_word_frequency'] 103 | ) 104 | 105 | input_text = SINGLEQ_PROMPT_TEMPLATE.format( 106 | task_intro=TASK_INTRO, 107 | context=context, 108 | question=question, 109 | answer_format=ANSWER_FORMAT.format( 110 | num_common=self.task_params['num_common_words'] 111 | ), 112 | extra_instructions="" 113 | ) 114 | 115 | gold_answer = '\n'.join(f"{i+1}. 
{word.lower()}" for i, word in enumerate(common_words)) 116 | 117 | extra_data = { 118 | "num_total_words": len(final_words), 119 | "common_words": [w.lower() for w in common_words], 120 | "common_frequency": self.task_params['common_word_frequency'], 121 | "rare_frequency": self.task_params['rare_word_frequency'] 122 | } 123 | 124 | return input_text, gold_answer, extra_data 125 | 126 | 127 | class CommonWordTask(AbstractTask): 128 | """Task for evaluating common word frequency tracking capabilities.""" 129 | 130 | def __init__( 131 | self, 132 | common_word_frequency: int = 30, 133 | rare_word_frequency: int = 3, 134 | num_common_words: int = 10, 135 | **kwargs 136 | ) -> None: 137 | """Initialize common word frequency tracking task. 138 | 139 | Args: 140 | common_word_frequency: Number of times each common word appears 141 | rare_word_frequency: Number of times each rare word appears 142 | num_common_words: Number of common words to identify 143 | **kwargs: Additional arguments passed to parent class 144 | """ 145 | super().__init__(**kwargs) 146 | self._create_word_list() 147 | self.task_params.update({ 148 | 'common_word_frequency': common_word_frequency, 149 | 'rare_word_frequency': rare_word_frequency, 150 | 'num_common_words': num_common_words, 151 | 'words': self.words, 152 | }) 153 | self.check_params() 154 | 155 | def _create_word_list(self) -> None: 156 | from wonderwords import random_word 157 | nouns = random_word._get_words_from_text_file("nounlist.txt") 158 | adjs = random_word._get_words_from_text_file("adjectivelist.txt") 159 | verbs = random_word._get_words_from_text_file("verblist.txt") 160 | words = nouns + adjs + verbs 161 | self.words = sorted(list(set(words))) 162 | self.random_obj.shuffle(self.words) 163 | self.words = [word for word in self.words if '-' not in word] 164 | 165 | def check_params(self) -> None: 166 | """Validate task parameters.""" 167 | if not isinstance(self.task_params.get('common_word_frequency'), int): 168 | raise ValueError("common_word_frequency must be an integer") 169 | if not isinstance(self.task_params.get('rare_word_frequency'), int): 170 | raise ValueError("rare_word_frequency must be an integer") 171 | if not isinstance(self.task_params.get('num_common_words'), int): 172 | raise ValueError("num_common_words must be an integer") 173 | 174 | if self.task_params['common_word_frequency'] <= self.task_params['rare_word_frequency']: 175 | raise ValueError("common_word_frequency must be greater than rare_word_frequency") 176 | if self.task_params['num_common_words'] < 1: 177 | raise ValueError("num_common_words must be at least 1") 178 | 179 | @property 180 | def sample_class(self): 181 | return CommonWordSample 182 | 183 | @staticmethod 184 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, float]: 185 | """Evaluate model predictions against gold answers. 186 | 187 | For each prediction, calculates the intersection over union (IoU) between 188 | the predicted set of common words and the gold set. 189 | 190 | Returns mean IoU across all predictions and IoU variance. 
191 |         """
192 |         sample_ious = []
193 | 
194 |         for pred in predictions:
195 |             # Extract answer section
196 |             answer_match = re.search(r'<answer>(.*?)</answer>', pred['pred'], re.DOTALL | re.IGNORECASE)
197 |             if not answer_match:
198 |                 sample_ious.append(0.0)
199 |                 continue
200 | 
201 |             answer_text = answer_match.group(1).strip()
202 | 
203 |             # Extract words from prediction as a set, normalizing and handling edge cases
204 |             pred_words = set()
205 |             for line in answer_text.split('\n'):
206 |                 line = line.strip()
207 |                 if not line:
208 |                     continue
209 | 
210 |                 # Match numbered lines, handling various formats (1., 1), 1-, etc.)
211 |                 match = re.match(r'^\d+[.)\-]?\s*(.+)$', line)
212 |                 if match:
213 |                     word = match.group(1).strip().lower()
214 |                     if word:
215 |                         pred_words.add(word)
216 | 
217 |             gold_words = set(pred['common_words'])
218 | 
219 |             # Calculate intersection over union
220 |             intersection = len(pred_words & gold_words)
221 |             union = len(pred_words | gold_words)
222 |             iou = intersection / union if union > 0 else 0.0
223 |             sample_ious.append(iou)
224 | 
225 |         # Calculate mean and variance
226 |         mean_iou = np.mean(sample_ious) if sample_ious else 0.0
227 |         # Use ddof=1 for unbiased estimate of the variance
228 |         variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
229 | 
230 |         return {
231 |             'iou': mean_iou,
232 |             'iou_variance': variance
233 |         }
234 | 
--------------------------------------------------------------------------------
/sparse_frontier/tasks/ruler/niah.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating needle-in-a-haystack (NIAH) capabilities of language models.
2 | 
3 | This module implements a task where models need to extract specific key-value pairs from
4 | a document containing both relevant pairs and distractors. The task tests the model's
5 | ability to:
6 | 1. Follow precise instructions
7 | 2. Extract relevant information while ignoring distractors
8 | 3. Format responses according to a specified template
9 | 4. Provide clear reasoning for answers
10 | """
11 | 
12 | import string
13 | import re
14 | from typing import List, Dict, Any, Tuple
15 | import numpy as np
16 | 
17 | from sparse_frontier.tasks.abstract_task import AbstractTask
18 | from sparse_frontier.tasks.abstract_sample import AbstractSample
19 | from sparse_frontier.tasks.abstract_prompt import MULTIPLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
20 | 
21 | # Task introduction and instructions
22 | TASK_INTRO = """I will provide you with a document containing multiple key-value pairs. Your task is to extract specific values associated with given keys."""
23 | 
24 | ANSWER_FORMAT = """1. The answer for <key_1> is <value_1>.
25 | 2. The answer for <key_2> is <value_2>.
26 | etc."""
27 | 
28 | EXTRA_INSTRUCTIONS = """
29 | - Provide answers in the exact order of the requested keys
30 | - Each answer must follow the format: "<question_number>. The answer for <key> is <value>."
31 | - Ensure exact key matches - do not modify or paraphrase the keys 32 | - Values must match exactly as they appear in the document 33 | """.strip() 34 | 35 | 36 | class Needle: 37 | """Represents a key-value pair in the NIAH task.""" 38 | 39 | def __init__(self, random_obj, used_keys: set, used_values: set) -> None: 40 | self.key = self.generate_unique(used_keys, random_obj) 41 | self.value = self.generate_unique(used_values, random_obj) 42 | 43 | @staticmethod 44 | def generate_kv_part(random_obj, length: int = 8) -> str: 45 | return ''.join(random_obj.choices(string.ascii_lowercase + string.digits, k=length)) 46 | 47 | @staticmethod 48 | def generate_kv(random_obj) -> str: 49 | parts = [Needle.generate_kv_part(random_obj) for _ in range(4)] 50 | return '-'.join(parts) 51 | 52 | @staticmethod 53 | def generate_unique(used_kv: set, random_obj) -> str: 54 | while True: 55 | key = Needle.generate_kv(random_obj) 56 | if key not in used_kv: 57 | used_kv.add(key) 58 | return key 59 | 60 | def to_sentence(self) -> str: 61 | return f"The value for key {self.key} is: {self.value}." 62 | 63 | 64 | class NIAHSample(AbstractSample): 65 | """Represents a single NIAH task sample with queries and distractors.""" 66 | 67 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 68 | """Generate the input text, gold answer and extra data for this sample.""" 69 | num_queries = self.task_params['num_queries'] 70 | used_keys = set() 71 | used_values = set() 72 | needles = [ 73 | Needle(self.random_obj, used_keys, used_values) 74 | for _ in range(num_queries) 75 | ] 76 | 77 | # Generate query sentences and shuffle 78 | queries_sentences = [needle.to_sentence() for needle in needles] 79 | self.random_obj.shuffle(queries_sentences) 80 | 81 | def get_current_token_count(sentences: List[str], keys: List[str]) -> int: 82 | question = f"Extract the values for the following keys: {', '.join(keys)}" 83 | text = PROMPT_TEMPLATE.format( 84 | task_intro=TASK_INTRO, 85 | context=" ".join(sentences), 86 | question=question, 87 | answer_format=ANSWER_FORMAT, 88 | extra_instructions=EXTRA_INSTRUCTIONS 89 | ) 90 | return len(self.tokenizer.text_to_tokens(text)) 91 | 92 | current_token_count = get_current_token_count( 93 | queries_sentences, 94 | [needle.key for needle in needles] 95 | ) 96 | tokens_needed = self.max_tokens - current_token_count 97 | 98 | if tokens_needed < 0: 99 | raise ValueError( 100 | f"Needles are too long. 
Current length: {current_token_count}, " 101 | f"maximum length: {self.max_tokens}" 102 | ) 103 | 104 | # Generate distractor sentences 105 | distractors_sentences = [] 106 | while True: 107 | new_distractor = Needle(self.random_obj, used_keys, used_values) 108 | distractor_sentence = new_distractor.to_sentence() 109 | distractor_token_length = len(self.tokenizer.text_to_tokens(distractor_sentence)) 110 | if tokens_needed < distractor_token_length: 111 | break 112 | distractors_sentences.append(distractor_sentence) 113 | tokens_needed -= distractor_token_length 114 | 115 | # Combine and shuffle all sentences 116 | all_sentences = queries_sentences + distractors_sentences 117 | self.random_obj.shuffle(all_sentences) 118 | 119 | # Generate question 120 | keys = [needle.key for needle in needles] 121 | question = f"Extract the values for the following keys: {', '.join(keys)}" 122 | 123 | # Format input using template 124 | input_text = PROMPT_TEMPLATE.format( 125 | task_intro=TASK_INTRO, 126 | context=" ".join(all_sentences), 127 | question=question, 128 | answer_format=ANSWER_FORMAT, 129 | extra_instructions=EXTRA_INSTRUCTIONS 130 | ) 131 | 132 | # Generate gold answer 133 | gold_answer = "\n".join( 134 | f"{i+1}. The answer for {needle.key} is {needle.value}." 135 | for i, needle in enumerate(needles) 136 | ) 137 | 138 | extra_data = { 139 | "answers": [(needle.key, needle.value) for needle in needles] 140 | } 141 | 142 | return input_text, gold_answer, extra_data 143 | 144 | 145 | class NIAHTask(AbstractTask): 146 | """Main task class for the Needle-in-a-Haystack evaluation.""" 147 | 148 | def __init__( 149 | self, 150 | num_queries: int, 151 | **kwargs 152 | ) -> None: 153 | super().__init__(**kwargs) 154 | self.task_params['num_queries'] = num_queries 155 | self.check_params() 156 | 157 | def check_params(self) -> None: 158 | if 'num_queries' not in self.task_params: 159 | raise ValueError("Missing required parameter 'num_queries'") 160 | 161 | if not isinstance(self.task_params['num_queries'], int): 162 | raise ValueError("Parameter 'num_queries' must be an integer") 163 | 164 | if self.task_params['num_queries'] < 1: 165 | raise AssertionError("Parameter 'num_queries' must be greater than or equal to 1") 166 | 167 | @property 168 | def sample_class(self): 169 | return NIAHSample 170 | 171 | @staticmethod 172 | def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]: 173 | """Evaluates model predictions against gold answers.""" 174 | def normalize_answer(text: str) -> str: 175 | """Normalize answer text for comparison.""" 176 | # Convert to lowercase and remove extra whitespace 177 | text = re.sub(r'\s+', ' ', text.lower().strip()) 178 | # Remove optional colon after "is" 179 | text = re.sub(r'\bis:\s+', 'is ', text) 180 | return text 181 | 182 | def extract_answers(text: str) -> Dict[int, Tuple[str, str]]: 183 | """Extract answers from text, handling both formats.""" 184 | # First try to find the section 185 | answer_match = re.search(r'(.*?)', text, re.DOTALL | re.IGNORECASE) 186 | if answer_match: 187 | text = answer_match.group(1) 188 | 189 | # Extract numbered answers with key-value pairs 190 | answers = {} 191 | pattern = re.compile( 192 | r'(\d+)\.\s*The answer for\s+([\w-]+)\s+is:?\s+(.+?)(?:\.|$)', 193 | re.IGNORECASE | re.MULTILINE 194 | ) 195 | 196 | for match in pattern.finditer(text): 197 | idx = int(match.group(1)) 198 | key = match.group(2).strip() 199 | value = match.group(3).strip() 200 | if idx not in answers: # Take first occurrence if duplicates 201 | 
answers[idx] = (key, normalize_answer(value)) 202 | 203 | return answers 204 | 205 | sample_accuracies = [] 206 | 207 | for example in examples: 208 | answers = example['answers'] 209 | prediction = example['pred'] 210 | 211 | # Get gold answer key-value pairs 212 | gold_pairs = { 213 | i + 1: (key, normalize_answer(value)) 214 | for i, (key, value) in enumerate(answers) 215 | } 216 | 217 | # Extract predicted key-value pairs 218 | pred_pairs = extract_answers(prediction) 219 | 220 | # Compare predictions 221 | correct = 0 222 | total = len(gold_pairs) 223 | 224 | for idx, (gold_key, gold_value) in gold_pairs.items(): 225 | if (idx in pred_pairs and 226 | pred_pairs[idx][0] == gold_key and 227 | pred_pairs[idx][1] == gold_value): 228 | correct += 1 229 | 230 | # Calculate accuracy for this sample 231 | sample_accuracies.append(correct / total if total > 0 else 0.0) 232 | 233 | # Calculate mean and variance 234 | mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0 235 | # Use ddof=1 for unbiased estimate of the variance 236 | variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0 237 | 238 | return { 239 | 'accuracy': mean_accuracy, 240 | 'accuracy_variance': variance 241 | } 242 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/ruler/vt.py: -------------------------------------------------------------------------------- 1 | """Task for evaluating variable tracking capabilities of language models. 2 | 3 | This module implements a task where models need to track chains of variable assignments 4 | in a document containing both relevant assignments and distractors. The task tests the model's 5 | ability to: 6 | 1. Track multiple variable assignments through chains 7 | 2. Identify all variables that eventually get assigned a specific value 8 | 3. Ignore irrelevant distractors and noise 9 | """ 10 | 11 | from typing import List, Dict, Any, Tuple 12 | import string 13 | import re 14 | import numpy as np 15 | 16 | from sparse_frontier.tasks.abstract_task import AbstractTask 17 | from sparse_frontier.tasks.abstract_sample import AbstractSample 18 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE 19 | 20 | NOISE_SENTENCE = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 21 | 22 | # Task introduction and instructions 23 | TASK_INTRO = """I will provide you with a text containing variable assignments. The text contains two types of assignments: 24 | 1. Numeric assignments that set a variable to a number (e.g., "VAR ABC = 12345") 25 | 2. Copy assignments that set a variable equal to another variable (e.g., "VAR XYZ = VAR ABC") 26 | Variables are sequences of uppercase letters. The assignments can appear in any order in the text.""" 27 | 28 | ANSWER_FORMAT = "VARIABLE_ONE VARIABLE_TWO etc." 29 | 30 | EXTRA_INSTRUCTIONS = """ 31 | - List ONLY the variable names that resolve to the target value. 32 | - Variables can be listed in any order. 33 | - Do not include "VAR" prefix in your answer. Do not include punctuation. 34 | """.strip() 35 | 36 | QUESTION_TEMPLATE = "Which variables resolve to the value {target_value}? A variable resolves to {target_value} if it is either directly assigned {target_value}, or assigned to another variable that resolves to {target_value}." 
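# Illustrative example of the data this task builds (values are made up): with
# num_chains=2 and num_hops=2, each chain uses num_hops + 1 variables, e.g.
#   chain 1: VAR ABCDE = 12345, VAR FGHIJ = VAR ABCDE, VAR KLMNO = VAR FGHIJ
#   chain 2: VAR PQRST = 67890, VAR UVWXY = VAR PQRST, VAR ZABCD = VAR UVWXY
# The target value is chain 1's number (12345), so the expected answer is
# "ABCDE FGHIJ KLMNO". All assignments are shuffled together with noise sentences.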
37 | 38 | 39 | class VariableTrackingSample(AbstractSample): 40 | """Handles generation of individual variable tracking samples.""" 41 | 42 | def _generate_random_var(self) -> str: 43 | """Generate a random 5-letter uppercase variable name.""" 44 | return ''.join(self.random_obj.choices(string.ascii_uppercase, k=5)) 45 | 46 | def _generate_unique_vars(self, num_vars: int) -> List[str]: 47 | """Generate a list of unique variable names. 48 | 49 | Args: 50 | num_vars: Number of unique variable names needed 51 | 52 | Returns: 53 | List of unique variable names 54 | """ 55 | unique_vars = [] 56 | while len(unique_vars) < num_vars: 57 | new_var = self._generate_random_var() 58 | if new_var not in unique_vars: 59 | unique_vars.append(new_var) 60 | return unique_vars 61 | 62 | def _create_chain(self, vars_for_chain: List[str], initial_value: int) -> List[str]: 63 | """Create a single chain of variable assignments. 64 | 65 | Args: 66 | vars_for_chain: List of variables to use in the chain 67 | initial_value: Starting numeric value for the chain 68 | 69 | Returns: 70 | List of assignment statements forming the chain 71 | """ 72 | chain = [f"VAR {vars_for_chain[0]} = {initial_value}"] 73 | for i in range(len(vars_for_chain) - 1): 74 | chain.append(f"VAR {vars_for_chain[i+1]} = VAR {vars_for_chain[i]}") 75 | return chain 76 | 77 | def _generate_chains(self) -> Tuple[List[List[str]], List[str], str]: 78 | """Generate variable assignment chains. 79 | 80 | Returns: 81 | Tuple containing: 82 | - List of variable chains 83 | - List of variables that get the target value 84 | - Target value that propagates through first chain 85 | """ 86 | num_chains = self.task_params['num_chains'] 87 | num_hops = self.task_params['num_hops'] 88 | vars_per_chain = num_hops + 1 89 | total_vars_needed = num_chains * vars_per_chain 90 | 91 | # Generate all unique variables needed 92 | all_vars = self._generate_unique_vars(total_vars_needed) 93 | 94 | # Generate unique integers for each chain 95 | unique_integers = self.random_obj.sample(range(10000, 99999), num_chains) 96 | 97 | # Create chains 98 | chains = [] 99 | target_vars = None 100 | target_value = str(unique_integers[0]) # Value that propagates through first chain 101 | 102 | for i in range(num_chains): 103 | # Get variables for this chain 104 | chain_vars = all_vars[i*vars_per_chain:(i+1)*vars_per_chain] 105 | chain = self._create_chain(chain_vars, unique_integers[i]) 106 | chains.append(chain) 107 | 108 | # Store variables from first chain as target 109 | if i == 0: 110 | target_vars = chain_vars 111 | 112 | return chains, target_vars, target_value 113 | 114 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 115 | """Generate a single variable tracking sample.""" 116 | chains, target_vars, target_value = self._generate_chains() 117 | 118 | # Extract all variable assignment statements 119 | assignment_statements = [] 120 | for chain in chains: 121 | assignment_statements.extend(chain) 122 | 123 | # Calculate tokens used by assignments and prompt 124 | assignment_text = " ".join(assignment_statements) 125 | prompt_tokens = len(self.tokenizer.text_to_tokens( 126 | SINGLEQ_PROMPT_TEMPLATE.format( 127 | task_intro=TASK_INTRO, 128 | context=assignment_text, 129 | question=QUESTION_TEMPLATE.format(target_value=target_value), 130 | answer_format=ANSWER_FORMAT, 131 | extra_instructions=EXTRA_INSTRUCTIONS 132 | ) 133 | )) 134 | 135 | # Calculate how many noise sentences we can add 136 | noise_tokens = len(self.tokenizer.text_to_tokens(NOISE_SENTENCE)) 137 | 
remaining_tokens = self.max_tokens - prompt_tokens 138 | num_noise_sentences = remaining_tokens // noise_tokens 139 | num_noise_sentences = max(num_noise_sentences - 5, 0) # Safety margin 140 | 141 | # Create final list of sentences and shuffle 142 | sentences = assignment_statements + [NOISE_SENTENCE] * num_noise_sentences 143 | self.random_obj.shuffle(sentences) 144 | 145 | # Format context and question 146 | context = " ".join(sentences) 147 | question = QUESTION_TEMPLATE.format(target_value=target_value) 148 | 149 | # Format input using template 150 | input_text = SINGLEQ_PROMPT_TEMPLATE.format( 151 | task_intro=TASK_INTRO, 152 | context=context, 153 | question=question, 154 | answer_format=ANSWER_FORMAT, 155 | extra_instructions=EXTRA_INSTRUCTIONS 156 | ) 157 | 158 | # Format gold answer 159 | gold_answer = " ".join(target_vars) 160 | 161 | extra_data = { 162 | "target_value": target_value, 163 | "num_chains": len(chains), 164 | "num_hops": self.task_params['num_hops'], 165 | "target_vars": target_vars 166 | } 167 | 168 | return input_text, gold_answer, extra_data 169 | 170 | 171 | class VariableTrackingTask(AbstractTask): 172 | """Task for evaluating variable tracking capabilities.""" 173 | 174 | def __init__( 175 | self, 176 | num_chains: int = 1, 177 | num_hops: int = 4, 178 | **kwargs 179 | ) -> None: 180 | """Initialize variable tracking task. 181 | 182 | Args: 183 | num_chains: Number of variable chains to include 184 | num_hops: Number of variable assignments in each chain 185 | **kwargs: Additional arguments passed to parent class 186 | """ 187 | super().__init__(**kwargs) 188 | self.task_params.update({ 189 | 'num_chains': num_chains, 190 | 'num_hops': num_hops 191 | }) 192 | self.check_params() 193 | 194 | def check_params(self) -> None: 195 | """Validate task parameters.""" 196 | if not isinstance(self.task_params.get('num_chains'), int): 197 | raise ValueError("num_chains must be an integer") 198 | if not isinstance(self.task_params.get('num_hops'), int): 199 | raise ValueError("num_hops must be an integer") 200 | if self.task_params['num_chains'] < 1: 201 | raise ValueError("num_chains must be at least 1") 202 | if self.task_params['num_hops'] < 1: 203 | raise ValueError("num_hops must be at least 1") 204 | 205 | @property 206 | def sample_class(self): 207 | return VariableTrackingSample 208 | 209 | @staticmethod 210 | def evaluate(predictions: List[Dict[str, Any]]) -> Dict[str, float]: 211 | """Evaluate model predictions against gold answers using intersection over union. 212 | 213 | For each prediction, calculates IoU between predicted and gold variable sets: 214 | IoU = |intersection| / |union| 215 | 216 | Returns mean IoU across all predictions and variance. 
217 |         """
218 |         def normalize_answer(text: str) -> str:
219 |             """Normalize answer text for comparison."""
220 |             # Extract answer from <answer> tags if present
221 |             answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
222 |             if answer_match:
223 |                 text = answer_match.group(1)
224 | 
225 |             # Convert to uppercase and remove extra whitespace
226 |             text = re.sub(r'\s+', ' ', text.upper().strip())
227 |             # Remove any "VAR" prefixes
228 |             text = re.sub(r'VAR\s+', '', text)
229 |             return text
230 | 
231 |         sample_ious = []
232 | 
233 |         for pred in predictions:
234 |             # Convert predictions and gold answers to sets of variables
235 |             pred_vars = set(normalize_answer(pred['pred']).split())
236 |             gold_vars = set(normalize_answer(pred['gold_answer']).split())
237 | 
238 |             # Calculate intersection over union
239 |             intersection = len(pred_vars & gold_vars)
240 |             union = len(pred_vars | gold_vars)
241 | 
242 |             # Handle edge case where both sets are empty
243 |             iou = 1.0 if union == 0 else intersection / union
244 |             sample_ious.append(iou)
245 | 
246 |         # Calculate mean and variance of IoU scores
247 |         mean_iou = np.mean(sample_ious) if sample_ious else 0.0
248 |         # Use ddof=1 for unbiased estimate of the variance
249 |         variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
250 | 
251 |         return {
252 |             'iou': mean_iou,
253 |             'iou_variance': variance
254 |         }
255 | 
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/filtering.py:
--------------------------------------------------------------------------------
1 | """Task for evaluating model's ability to identify chapters without purchases.
2 | 
3 | This module implements a task where models need to identify chapters in a narrative where
4 | the protagonist did not make any purchases. The task tests the model's ability to:
5 | 1. Follow precise instructions
6 | 2. Track item transactions across a narrative
7 | 3. Format responses according to a specified template
8 | """
9 | 
10 | import re
11 | from typing import Any, Dict, List, Set, Tuple
12 | import copy
13 | import numpy as np
14 | 
15 | from sparse_frontier.tasks.abstract_task import AbstractTask
16 | from sparse_frontier.tasks.abstract_sample import AbstractSample
17 | from sparse_frontier.tasks.story.narrative import Chapter, NarrativeGenerator
18 | from sparse_frontier.tasks.story.templates import TASK_INTRO
19 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE
20 | 
21 | QUESTION = """Identify all chapters where the protagonist did not buy any item.
22 | Note: There are exactly {num_chapters} chapters without any purchases."""
23 | 
24 | ANSWER_FORMAT = "chapter_id_1, chapter_id_2, ..."
25 | 
26 | EXTRA_INSTRUCTIONS = """
27 | - In the answer section, provide only the chapter IDs separated by commas.
28 | """.strip() 29 | 30 | PROMPT = PROMPT_TEMPLATE.format( 31 | task_intro=TASK_INTRO, 32 | question="{question}", 33 | context="{context}", 34 | answer_format=ANSWER_FORMAT, 35 | extra_instructions=EXTRA_INSTRUCTIONS 36 | ) 37 | 38 | 39 | class FilteringSample(AbstractSample): 40 | """Represents a single filtering task sample.""" 41 | 42 | @staticmethod 43 | def _remove_buying_transactions(chapter: Chapter, items_to_remove: Set[str]) -> None: 44 | """Remove buying transactions for specified items from a chapter.""" 45 | new_bought = [] 46 | new_buying_trans = [] 47 | 48 | for item, trans in zip(chapter.bought_items, chapter.structure['buying_transactions']): 49 | if item not in items_to_remove: 50 | new_bought.append(item) 51 | new_buying_trans.append(trans) 52 | 53 | chapter.bought_items = new_bought 54 | chapter.structure['buying_transactions'] = new_buying_trans 55 | 56 | def _modify_chapters(self, chapters: List[Chapter], chapter_ids: List[int]) -> List[Chapter]: 57 | """Modify chapters by removing purchases from specified chapters and updating subsequent chapters.""" 58 | modified_chapters = copy.deepcopy(chapters) 59 | items_to_track = set() 60 | 61 | # Remove purchases from specified chapters and track items 62 | for chapter in modified_chapters: 63 | if chapter.chapter_id in chapter_ids: 64 | items_to_track.update(chapter.bought_items) 65 | self._remove_buying_transactions(chapter, set(chapter.bought_items)) 66 | 67 | return modified_chapters 68 | 69 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 70 | """Generate the input text, gold answer and extra data for this sample.""" 71 | # Calculate prompt tokens 72 | prompt_tokens = len(self.tokenizer.text_to_tokens( 73 | PROMPT.format( 74 | context="", 75 | question=QUESTION.format(num_chapters=1), 76 | ) 77 | )) 78 | 79 | # Calculate remaining tokens for narrative 80 | narrative_tokens = self.max_tokens - prompt_tokens 81 | 82 | # Generate narrative with adjusted token limit 83 | narrative_gen = NarrativeGenerator( 84 | tokenizer=self.tokenizer, 85 | sequence_length=narrative_tokens, 86 | random_obj=self.random_obj, 87 | ) 88 | 89 | # Randomly select chapters to remove purchases from 90 | num_chapters = len(narrative_gen.chapters) 91 | chapters_to_modify = sorted( 92 | self.random_obj.sample(range(1, num_chapters + 1), self.task_params['chapters_in_question']) 93 | ) 94 | 95 | modified_chapters = self._modify_chapters(narrative_gen.chapters, chapters_to_modify) 96 | 97 | # Build prompt 98 | context = "\n\n".join(ch.compile_text() for ch in modified_chapters) 99 | input_text = PROMPT.format( 100 | context=context, 101 | question=QUESTION.format(num_chapters=self.task_params['chapters_in_question']) 102 | ) 103 | 104 | # Generate expected output 105 | gold_answer = ", ".join(map(str, chapters_to_modify)) if chapters_to_modify else "" 106 | 107 | return input_text, gold_answer, {} 108 | 109 | 110 | class FilteringTask(AbstractTask): 111 | """Task class for evaluating chapter purchase identification capabilities.""" 112 | 113 | def __init__( 114 | self, 115 | chapters_in_question: int, 116 | protagonist_name: str = "Arion", 117 | **kwargs 118 | ) -> None: 119 | """Initialize the filtering task. 
120 | 
121 |         Args:
122 |             chapters_in_question: Number of chapters to remove purchases from
123 |             protagonist_name: Name of the story protagonist
124 |             **kwargs: Additional arguments passed to AbstractTask
125 |         """
126 |         super().__init__(**kwargs)
127 |         self.task_params['chapters_in_question'] = chapters_in_question
128 |         self.task_params['protagonist_name'] = protagonist_name
129 |         self.check_params()
130 | 
131 |     def check_params(self) -> None:
132 |         """Validate task parameters."""
133 |         if 'chapters_in_question' not in self.task_params:
134 |             raise ValueError("Missing required parameter 'chapters_in_question'")
135 | 
136 |         if not isinstance(self.task_params['chapters_in_question'], int):
137 |             raise ValueError("Parameter 'chapters_in_question' must be an integer")
138 | 
139 |         if self.task_params['chapters_in_question'] < 1:
140 |             raise ValueError("Parameter 'chapters_in_question' must be at least 1")
141 | 
142 |     @property
143 |     def sample_class(self):
144 |         return FilteringSample
145 | 
146 |     @staticmethod
147 |     def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]:
148 |         """Evaluate model predictions against gold answers.
149 | 
150 |         Args:
151 |             examples: List of dictionaries containing predictions and gold answers
152 | 
153 |         Returns:
154 |             Dictionary containing IoU (Intersection over Union) metric and its variance
155 |         """
156 |         def normalize_answer(text: str) -> Set[int]:
157 |             """Extract and normalize chapter IDs from answer text."""
158 |             # Remove any non-numeric characters except commas and whitespace
159 |             text = re.sub(r'[^0-9,\s]', '', text.strip())
160 |             # Split on comma or whitespace
161 |             parts = re.split(r'[,\s]+', text)
162 |             # Convert to integers, ignoring any invalid parts
163 |             try:
164 |                 return {int(p) for p in parts if p.strip()}
165 |             except ValueError:
166 |                 return set()
167 | 
168 |         def extract_answer(text: str) -> Set[int]:
169 |             """Extract answer from the formatted response."""
170 |             match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
171 |             if not match:
172 |                 return set()
173 |             answer_text = match.group(1)
174 | 
175 |             # Handle both formats: "1, 2, 3" and "Chapter 1, Chapter 2, Chapter 3"
176 |             # First try to extract chapter numbers from "Chapter X" format
177 |             chapter_matches = re.findall(r'Chapter\s+(\d+)', answer_text, re.IGNORECASE)
178 |             if chapter_matches:
179 |                 return {int(num) for num in chapter_matches}
180 | 
181 |             # If no "Chapter X" format found, fall back to original number parsing
182 |             return normalize_answer(answer_text)
183 | 
184 |         sample_ious = []
185 | 
186 |         for ex in examples:
187 |             pred = ex.get("pred", "").strip()
188 |             gold = ex.get("gold_answer", "").strip()
189 | 
190 |             if not pred or not gold:
191 |                 continue
192 | 
193 |             gold_set = normalize_answer(gold)
194 |             pred_set = extract_answer(pred)
195 | 
196 |             if not gold_set and not pred_set:
197 |                 sample_ious.append(1.0)
198 |                 continue
199 | 
200 |             union = gold_set.union(pred_set)
201 |             intersect = gold_set.intersection(pred_set)
202 |             iou = len(intersect) / len(union) if union else 0.0
203 |             sample_ious.append(iou)
204 | 
205 |         # Calculate mean and variance
206 |         mean_iou = np.mean(sample_ious) if sample_ious else 0.0
207 |         # Use ddof=1 for unbiased estimate of the variance
208 |         variance = np.var(sample_ious, ddof=1) if len(sample_ious) > 1 else 0.0
209 | 
210 |         return {
211 |             "iou": mean_iou,
212 |             "iou_variance": variance
213 |         }
214 | 
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/multihop.py:
-------------------------------------------------------------------------------- 1 | """Task for evaluating multi-hop reasoning capabilities in narrative comprehension. 2 | 3 | This module implements a task where models need to track item acquisitions across chapters 4 | in a narrative. The task tests the model's ability to: 5 | 1. Follow precise instructions 6 | 2. Extract relevant item information from context 7 | 3. Connect information across multiple chapters 8 | 4. Format responses according to a specified template 9 | """ 10 | 11 | import re 12 | from typing import List, Dict, Any, Tuple 13 | 14 | from sparse_frontier.tasks.abstract_task import AbstractTask 15 | from sparse_frontier.tasks.abstract_sample import AbstractSample 16 | from sparse_frontier.tasks.story.narrative import NarrativeGenerator 17 | from sparse_frontier.tasks.story.templates import TASK_INTRO 18 | from sparse_frontier.tasks.abstract_prompt import SINGLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE 19 | 20 | QUESTION = "What was the last item that the protagonist acquired before acquiring {target_item}?" 21 | 22 | ANSWER_FORMAT = "ITEM_NAME" 23 | 24 | EXTRA_INSTRUCTIONS = """ 25 | - Provide only the item name in the answer section. 26 | - Do not include articles like 'the' or 'a' in your answer. 27 | - The item name must be exactly as mentioned in the text. 28 | """.strip() 29 | 30 | PROMPT = PROMPT_TEMPLATE.format( 31 | task_intro=TASK_INTRO, 32 | question="{question}", 33 | context="{context}", 34 | answer_format=ANSWER_FORMAT, 35 | extra_instructions=EXTRA_INSTRUCTIONS 36 | ) 37 | 38 | 39 | class MultiHopSample(AbstractSample): 40 | """Represents a single multi-hop reasoning sample with narrative and item queries.""" 41 | 42 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 43 | """Generate the input text, gold answer and extra data for this sample.""" 44 | # Calculate prompt tokens 45 | prompt_tokens = len(self.tokenizer.text_to_tokens( 46 | PROMPT.format( 47 | context="", 48 | question=QUESTION.format(target_item="item_component_1 item_component_2 item_component_3"), 49 | ) 50 | )) 51 | 52 | # Calculate remaining tokens for narrative 53 | narrative_tokens = self.max_tokens - prompt_tokens 54 | 55 | # Generate narrative 56 | narrative_gen = NarrativeGenerator( 57 | tokenizer=self.tokenizer, 58 | sequence_length=narrative_tokens, 59 | random_obj=self.random_obj, 60 | protagonist_name=self.task_params.get('protagonist_name', "Arion") 61 | ) 62 | 63 | # Find chapters where items were bought 64 | chapters_with_items = [ 65 | ch for ch in narrative_gen.chapters 66 | if ch.bought_items 67 | ] 68 | 69 | assert len(chapters_with_items) >= 2, "Need at least two chapters with item purchases" 70 | 71 | # Select a random chapter with an item purchase (excluding first chapter with items) 72 | target_chapter = self.random_obj.choice(chapters_with_items[1:]) 73 | target_item = target_chapter.bought_items[0] 74 | 75 | # Find the previous chapter where an item was bought 76 | prev_chapters = [ch for ch in chapters_with_items if ch.chapter_id < target_chapter.chapter_id] 77 | prev_chapter = prev_chapters[-1] # Get the most recent previous chapter 78 | prev_item = prev_chapter.bought_items[0] 79 | 80 | # Generate question 81 | question = QUESTION.format(target_item=target_item) 82 | 83 | input_text = PROMPT.format( 84 | context=narrative_gen.compile_narrative(), 85 | question=question 86 | ) 87 | 88 | return input_text, prev_item, {} 89 | 90 | 91 | class MultiHopTask(AbstractTask): 92 | """Main task class for multi-hop 
reasoning evaluation in narratives."""
93 | 
94 |     def __init__(
95 |         self,
96 |         protagonist_name: str = "Arion",
97 |         **kwargs
98 |     ) -> None:
99 |         """Initialize the multi-hop task.
100 | 
101 |         Args:
102 |             protagonist_name: Name of the story protagonist
103 |             **kwargs: Additional arguments passed to AbstractTask
104 |         """
105 |         super().__init__(**kwargs)
106 |         self.task_params['protagonist_name'] = protagonist_name
107 |         self.check_params()
108 | 
109 |     def check_params(self) -> None:
110 |         """Validate task-specific parameters."""
111 |         if not isinstance(self.task_params['protagonist_name'], str):
112 |             raise ValueError("protagonist_name must be a string")
113 | 
114 |     @property
115 |     def sample_class(self):
116 |         """Return the sample class for this task."""
117 |         return MultiHopSample
118 | 
119 |     @staticmethod
120 |     def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, Any]:
121 |         """Evaluate model predictions against gold answers."""
122 |         def normalize_answer(text: str) -> str:
123 |             """Normalize answer text for comparison."""
124 |             # Convert to lowercase and remove extra whitespace
125 |             text = re.sub(r'\s+', ' ', text.lower().strip())
126 |             # Remove articles
127 |             text = re.sub(r'^(the|a|an)\s+', '', text)
128 |             return text
129 | 
130 |         def extract_answer(text: str) -> str:
131 |             """Extract answer from the formatted response."""
132 |             # Find content between tags
133 |             match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
134 |             if not match:
135 |                 return ""
136 |             return normalize_answer(match.group(1))
137 | 
138 |         import numpy as np
139 |         sample_accuracies = []
140 | 
141 |         for example in examples:
142 |             gold = normalize_answer(example['gold_answer'])
143 |             pred = extract_answer(example['pred'])
144 | 
145 |             if gold and pred: # Only evaluate if both answers are non-empty
146 |                 # Binary accuracy (1 if correct, 0 if incorrect)
147 |                 accuracy = 1.0 if pred == gold else 0.0
148 |                 sample_accuracies.append(accuracy)
149 | 
150 |         # Calculate mean and variance
151 |         mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0
152 |         # Use ddof=1 for unbiased estimate of the variance
153 |         variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0
154 | 
155 |         return {
156 |             'accuracy': mean_accuracy,
157 |             'accuracy_variance': variance
158 |         }
159 | 
--------------------------------------------------------------------------------
/sparse_frontier/tasks/story/narrative.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Set, Tuple
2 | 
3 | from sparse_frontier.tasks.story.templates import NarrativeResources, ItemGenerator
4 | 
5 | 
6 | class Chapter:
7 |     """Represents a single narrative chapter."""
8 |     def __init__(self, chapter_id: int, location: str, character: str, event: str, random_obj):
9 |         self.chapter_id = chapter_id
10 |         self.location = location
11 |         self.event = event
12 |         self.character = character
13 |         self.random_obj = random_obj
14 |         self.structure = {
15 |             'scene_introduction': [],
16 |             'encounter': [],
17 |             'conversation_extension': [],
18 |             'buying_transactions': [],
19 |             'selling_transactions': [],
20 |             'farewell': [],
21 |             'scene_conclusions': [],
22 |             'extra': []
23 |         }
24 |         self.bought_items: List[str] = []
25 |         self.sold_items: List[str] = []
26 |         self.end_item_count: Optional[int] = None
27 | 
28 |     def compile_text(self) -> str:
29 |         parts = (
30 |             self.structure['scene_introduction']
31 |             + self.structure['encounter']
32 |             + self.structure['conversation_extension']
33 | + self.structure['buying_transactions'] 34 | + self.structure['selling_transactions'] 35 | + self.structure['farewell'] 36 | + self.structure['scene_conclusions'] 37 | + self.structure['extra'] 38 | ) 39 | return f"Chapter {self.chapter_id}:\n" + "".join(parts) 40 | 41 | def generate_chapter(self, protagonist_name: str, conversation_extensions: int, items_seen: Set[str], inventory: List[str]) -> None: 42 | """Generates a complete chapter with all narrative elements.""" 43 | # Scene introduction 44 | self.structure['scene_introduction'].append( 45 | NarrativeResources.choose(NarrativeResources.SCENE_INTRO_TEMPLATES, 46 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj) 47 | ) 48 | self.structure['scene_introduction'].append( 49 | NarrativeResources.choose(NarrativeResources.REASON_TEMPLATES, 50 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj) 51 | ) 52 | self.structure['scene_introduction'].append( 53 | NarrativeResources.choose(NarrativeResources.EVENT_TEMPLATES, 54 | protagonist=protagonist_name, event=self.event, location=self.location, random_obj=self.random_obj) 55 | ) 56 | 57 | # Encounter 58 | self.structure['encounter'].append( 59 | NarrativeResources.choose(NarrativeResources.CHAR_INTRO_TEMPLATES, 60 | protagonist=protagonist_name, character=self.character, random_obj=self.random_obj) 61 | ) 62 | 63 | # Conversation extensions 64 | self.structure['conversation_extension'].extend( 65 | NarrativeResources.choose_multiple( 66 | NarrativeResources.CONVERSATION_EXTENSION_TEMPLATES, 67 | conversation_extensions, 68 | protagonist=protagonist_name, 69 | location=self.location, 70 | random_obj=self.random_obj 71 | ) 72 | ) 73 | 74 | # Handle buying and selling 75 | new_item = ItemGenerator._get_unique_item(items_seen, self.random_obj) 76 | inventory.append(new_item) 77 | self.bought_items.append(new_item) 78 | self.structure['buying_transactions'].append( 79 | NarrativeResources.choose(NarrativeResources.BUYING_TEMPLATES, 80 | protagonist=protagonist_name, 81 | character=self.character, 82 | item=new_item, 83 | random_obj=self.random_obj) 84 | ) 85 | 86 | older_items = [it for it in inventory if it != new_item] 87 | if older_items and self.random_obj.random() < 0.5: 88 | sell_item = self.random_obj.choice(older_items) 89 | inventory.remove(sell_item) 90 | self.sold_items.append(sell_item) 91 | self.structure['selling_transactions'].append( 92 | NarrativeResources.choose(NarrativeResources.SELLING_TEMPLATES, 93 | protagonist=protagonist_name, 94 | character=self.character, 95 | item=sell_item, 96 | random_obj=self.random_obj) 97 | ) 98 | 99 | self.end_item_count = len(inventory) 100 | 101 | # Farewell 102 | self.structure['farewell'].append( 103 | NarrativeResources.choose(NarrativeResources.FAREWELL_TEMPLATES, 104 | protagonist=protagonist_name, character=self.character, random_obj=self.random_obj) 105 | ) 106 | 107 | # Scene conclusions 108 | self.structure['scene_conclusions'].append( 109 | NarrativeResources.choose(NarrativeResources.CONCLUSION_TEMPLATES, 110 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj) 111 | ) 112 | 113 | # Add extra sentences 114 | self.structure['extra'].extend( 115 | NarrativeResources.choose_multiple(NarrativeResources.EXTRA_TEMPLATES, 1, 116 | protagonist=protagonist_name, location=self.location, random_obj=self.random_obj) 117 | ) 118 | 119 | 120 | class NarrativeGenerator: 121 | """Generates a base narrative with a certain number of chapters.""" 122 | 
def __init__( 123 | self, 124 | tokenizer, 125 | sequence_length: int, 126 | random_obj, 127 | protagonist_name: str = "Arion", 128 | conversation_extensions: int = 3, 129 | ): 130 | self.tokenizer = tokenizer 131 | self.sequence_length = sequence_length 132 | self.protagonist_name = protagonist_name 133 | self.protagonist_inventory: List[str] = [] 134 | self.conversation_extensions = conversation_extensions 135 | self._used_pairs: Set[Tuple[str, str]] = set() 136 | self._items_seen: Set[str] = set() 137 | self.random_obj = random_obj 138 | 139 | self.chapters = self._generate_base_narrative() 140 | assert len(self.tokenizer.text_to_tokens(self.compile_narrative())) <= self.sequence_length, "Narrative is too long." 141 | 142 | def compile_narrative(self) -> str: 143 | return "\n\n".join(chapter.compile_text() for chapter in self.chapters) 144 | 145 | def _generate_base_narrative(self) -> List[Chapter]: 146 | chapters = [] 147 | total_tokens = 0 148 | chapter_id = 1 149 | 150 | while True: 151 | assert chapter_id <= 2000, "Limit reached." 152 | for _ in range(100): 153 | location = self.random_obj.choice(NarrativeResources.LOCATIONS) 154 | character = self.random_obj.choice(NarrativeResources.CHARACTERS) 155 | if (location, character) not in self._used_pairs: 156 | self._used_pairs.add((location, character)) 157 | break 158 | else: 159 | raise ValueError("Could not find a unique (location, character) pair.") 160 | 161 | event = self.random_obj.choice(NarrativeResources.EVENTS) 162 | chapter = Chapter(chapter_id, location, character, event, self.random_obj) 163 | 164 | chapter.generate_chapter( 165 | self.protagonist_name, 166 | self.conversation_extensions, 167 | self._items_seen, 168 | self.protagonist_inventory 169 | ) 170 | 171 | chapter_text = chapter.compile_text() 172 | chapter_tokens = len(self.tokenizer.text_to_tokens(f'\n\n{chapter_text}')) 173 | 174 | if total_tokens + chapter_tokens > self.sequence_length: 175 | break 176 | 177 | chapters.append(chapter) 178 | total_tokens += chapter_tokens 179 | chapter_id += 1 180 | 181 | return chapters 182 | -------------------------------------------------------------------------------- /sparse_frontier/tasks/story/retrieval.py: -------------------------------------------------------------------------------- 1 | """Task for evaluating story comprehension and information retrieval capabilities. 2 | 3 | This module implements a task where models need to extract specific information from 4 | narrative chapters. The task tests the model's ability to: 5 | 1. Comprehend multi-chapter narratives 6 | 2. Extract relevant information about locations, characters, and items 7 | 3. Format responses according to a specified template 8 | 4. Provide clear reasoning for its answers 9 | 5. Answer questions about specific details from the text 10 | """ 11 | 12 | import re 13 | from typing import List, Dict, Any, Tuple 14 | 15 | from sparse_frontier.tasks.abstract_task import AbstractTask 16 | from sparse_frontier.tasks.abstract_sample import AbstractSample 17 | from sparse_frontier.tasks.story.narrative import NarrativeGenerator 18 | from sparse_frontier.tasks.story.templates import TASK_INTRO 19 | from sparse_frontier.tasks.abstract_prompt import MULTIPLEQ_PROMPT_TEMPLATE as PROMPT_TEMPLATE 20 | 21 | ANSWER_FORMAT = """1. ANSWER_ONE 22 | 2. 
ANSWER_TWO 23 | etc.""" 24 | 25 | EXTRA_INSTRUCTIONS = """ 26 | - For answers, use one line per answer with the number prefix 27 | - Do not include articles like 'the' or 'a' in answers 28 | - Answers should be specific names/items/locations mentioned in the text 29 | """.strip() 30 | 31 | PROMPT = PROMPT_TEMPLATE.format( 32 | task_intro=TASK_INTRO, 33 | question="{questions}", 34 | context="{context}", 35 | answer_format=ANSWER_FORMAT, 36 | extra_instructions=EXTRA_INSTRUCTIONS 37 | ) 38 | 39 | 40 | class RetrievalSample(AbstractSample): 41 | SAFETY_TOKENS = 10 42 | 43 | def _generate_sample(self) -> Tuple[str, str, Dict[str, Any]]: 44 | num_queries = self.task_params['num_queries'] 45 | 46 | # Calculate tokens needed for prompt and instructions 47 | prompt_tokens = len(self.tokenizer.text_to_tokens( 48 | PROMPT.format(context="", questions="") 49 | )) 50 | 51 | # Estimate tokens for questions and explanation 52 | question_template = "XXX. In Chapter XXX, which specific location/character/item did the protagonist visit/meet/acquire?\n" 53 | question_tokens = len(self.tokenizer.text_to_tokens(question_template)) * num_queries * 2 54 | 55 | # Calculate remaining tokens for narrative 56 | narrative_tokens = self.max_tokens - (prompt_tokens + question_tokens + self.SAFETY_TOKENS) 57 | 58 | # Generate narrative with adjusted token limit 59 | narrative_gen = NarrativeGenerator( 60 | tokenizer=self.tokenizer, 61 | sequence_length=narrative_tokens, 62 | random_obj=self.random_obj, 63 | ) 64 | 65 | assert len(narrative_gen.chapters) >= num_queries, "Not enough chapters for the requested complexity." 66 | 67 | # Select random chapters and generate questions 68 | selected_chapters = self.random_obj.sample(narrative_gen.chapters, num_queries) 69 | questions_and_answers = [] 70 | 71 | if num_queries == 3: 72 | # Fixed query types for 3 questions with clearer phrasing 73 | query_types = ["location", "character", "item"] 74 | self.random_obj.shuffle(query_types) 75 | 76 | for i, (ch, query_type) in enumerate(zip(selected_chapters, query_types), start=1): 77 | if query_type == "location": 78 | q = f"{i}. In Chapter {ch.chapter_id}, which specific location did the protagonist visit?" 79 | a = ch.location 80 | elif query_type == "character": 81 | q = f"{i}. In Chapter {ch.chapter_id}, which character did the protagonist interact with?" 82 | a = ch.character 83 | else: # item 84 | q = f"{i}. In Chapter {ch.chapter_id}, which specific item was acquired by the protagonist?" 85 | a = ch.bought_items[0] if ch.bought_items else "None" 86 | 87 | questions_and_answers.append((i, q, a)) 88 | else: 89 | # Enhanced question generation for other numbers of queries 90 | for i, ch in enumerate(selected_chapters, start=1): 91 | query_type = self.random_obj.choice(["location", "character", "item"]) 92 | 93 | if query_type == "location": 94 | q = f"{i}. In Chapter {ch.chapter_id}, which specific location did the protagonist visit?" 95 | a = ch.location 96 | elif query_type == "character": 97 | q = f"{i}. In Chapter {ch.chapter_id}, which character did the protagonist interact with?" 98 | a = ch.character 99 | else: 100 | q = f"{i}. In Chapter {ch.chapter_id}, which specific item was acquired by the protagonist?" 
101 |                     a = ch.bought_items[0] if ch.bought_items else "None"
102 | 
103 |                 questions_and_answers.append((i, q, a))
104 | 
105 |         input_text = PROMPT.format(
106 |             context=narrative_gen.compile_narrative(),
107 |             questions="\n".join(q for (_, q, _) in questions_and_answers)
108 |         )
109 | 
110 |         gold_answer = "\n".join(f"{i}. {a}" for (i, _, a) in questions_and_answers)
111 | 
112 |         return input_text, gold_answer, {}
113 | 
114 | 
115 | class RetrievalTask(AbstractTask):
116 |     """Task class for evaluating story comprehension and information retrieval."""
117 | 
118 |     def __init__(
119 |         self,
120 |         num_queries: int,
121 |         protagonist_name: str = "Arion",
122 |         **kwargs
123 |     ) -> None:
124 |         super().__init__(**kwargs)
125 |         self.task_params['num_queries'] = num_queries
126 |         self.task_params['protagonist_name'] = protagonist_name
127 |         self.check_params()
128 | 
129 |     def check_params(self) -> None:
130 |         """Validate task-specific parameters."""
131 |         if 'num_queries' not in self.task_params:
132 |             raise ValueError("Missing required parameter 'num_queries'")
133 | 
134 |         if not isinstance(self.task_params['num_queries'], int):
135 |             raise ValueError("Parameter 'num_queries' must be an integer")
136 | 
137 |         if self.task_params['num_queries'] < 1:
138 |             raise ValueError("Parameter 'num_queries' must be greater than 0")
139 | 
140 |     @property
141 |     def sample_class(self):
142 |         return RetrievalSample
143 | 
144 |     @staticmethod
145 |     def evaluate(examples: List[Dict[str, Any]]) -> Dict[str, float]:
146 |         """Evaluate model predictions against gold answers."""
147 |         def normalize_answer(text: str) -> str:
148 |             """Normalize answer text for comparison."""
149 |             # Convert to lowercase and remove extra whitespace
150 |             text = re.sub(r'\s+', ' ', text.lower().strip())
151 |             # Remove articles
152 |             text = re.sub(r'^(the|a|an)\s+', '', text)
153 |             return text
154 | 
155 |         def extract_answers(text: str) -> Dict[int, str]:
156 |             """Extract answers from text, handling both formats."""
157 |             # First try to find the <answer> section
158 |             answer_match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL | re.IGNORECASE)
159 |             if answer_match:
160 |                 text = answer_match.group(1)
161 | 
162 |             # Extract numbered answers
163 |             answers = {}
164 |             for line in text.split('\n'):
165 |                 match = re.match(r'^(\d+)[\.:\)]\s*(.+?)\s*$', line.strip())
166 |                 if match:
167 |                     idx = int(match.group(1))
168 |                     answer = normalize_answer(match.group(2))
169 |                     if idx not in answers: # Take first occurrence if duplicates
170 |                         answers[idx] = answer
171 |             return answers
172 | 
173 |         import numpy as np
174 |         sample_accuracies = []
175 | 
176 |         for example in examples:
177 |             gold_answers = extract_answers(example['gold_answer'])
178 |             pred_answers = extract_answers(example['pred'])
179 | 
180 |             if not gold_answers: # Skip examples without gold answers
181 |                 continue
182 | 
183 |             correct = 0
184 |             total = len(gold_answers)
185 | 
186 |             for idx, gold in gold_answers.items():
187 |                 if idx in pred_answers and pred_answers[idx] == gold:
188 |                     correct += 1
189 | 
190 |             sample_accuracies.append(correct / total if total > 0 else 0.0)
191 | 
192 |         # Calculate mean and variance
193 |         mean_accuracy = np.mean(sample_accuracies) if sample_accuracies else 0.0
194 |         # Use ddof=1 for unbiased estimate of the variance
195 |         variance = np.var(sample_accuracies, ddof=1) if len(sample_accuracies) > 1 else 0.0
196 | 
197 |         return {
198 |             'accuracy': mean_accuracy,
199 |             'accuracy_variance': variance
200 |         }
201 | 
--------------------------------------------------------------------------------
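The story tasks above share the same evaluation contract: a static `evaluate` that consumes a list of `{'pred', 'gold_answer'}` dicts and returns a mean score together with its sample variance. A minimal sketch of calling one of them directly — illustrative only, not part of the repository, and assuming the `<answer>` tag format used by the extraction regexes above:

```python
from sparse_frontier.tasks.story.filtering import FilteringTask

# Two toy predictions: {2, 5, 9} vs {2, 5, 7} gives IoU 2/4 = 0.5,
# and "Chapter 3" vs "3" gives IoU 1.0, so the mean IoU is 0.75.
toy_predictions = [
    {"pred": "<answer>2, 5, 9</answer>", "gold_answer": "2, 5, 7"},
    {"pred": "<answer>Chapter 3</answer>", "gold_answer": "3"},
]
print(FilteringTask.evaluate(toy_predictions))  # mean IoU 0.75, variance 0.125
```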
/sparse_frontier/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .globals import GlobalSettings 2 | 3 | __all__ = [ 4 | "GlobalSettings", 5 | ] 6 | -------------------------------------------------------------------------------- /sparse_frontier/utils/checks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sparse_frontier.utils.globals import GlobalSettings 4 | from sparse_frontier.utils.data import ( 5 | read_jsonl, 6 | get_data_path, 7 | get_pred_path, 8 | get_results_path, 9 | ) 10 | 11 | 12 | def prepration_needed(): 13 | cfg = GlobalSettings.get("cfg") 14 | data_path = get_data_path() 15 | 16 | if not os.path.exists(data_path): 17 | return True 18 | 19 | if cfg.overwrite: 20 | os.remove(data_path) 21 | return True 22 | 23 | data = read_jsonl(data_path) 24 | return (len(data) < cfg.samples) 25 | 26 | 27 | def prediction_needed(): 28 | cfg = GlobalSettings.get("cfg") 29 | pred_path = get_pred_path() 30 | 31 | if not os.path.exists(pred_path): 32 | return True 33 | 34 | if cfg.overwrite: 35 | os.remove(pred_path) 36 | return True 37 | 38 | data = read_jsonl(pred_path) 39 | return (len(data) < cfg.samples) 40 | 41 | 42 | def evaluation_needed(): 43 | cfg = GlobalSettings.get("cfg") 44 | results_path = get_results_path() 45 | 46 | if not os.path.exists(results_path): 47 | return True 48 | 49 | if cfg.overwrite: 50 | os.remove(results_path) 51 | return True 52 | 53 | return False 54 | -------------------------------------------------------------------------------- /sparse_frontier/utils/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import List, Union 6 | 7 | from sparse_frontier.utils.globals import GlobalSettings 8 | 9 | 10 | def read_jsonl(manifest: Union[Path, str]) -> List[dict]: 11 | """Read and parse a JSONL file into a list of dictionaries. 12 | 13 | Args: 14 | manifest: Path to JSONL file to read 15 | 16 | Returns: 17 | List of dictionaries parsed from JSONL 18 | 19 | Raises: 20 | json.JSONDecodeError: If JSONL parsing fails 21 | Exception: If file cannot be read 22 | """ 23 | try: 24 | with open(manifest, 'r', encoding='utf-8') as f: 25 | return [json.loads(line) for line in f if line.strip()] 26 | except json.JSONDecodeError as e: 27 | logging.error(f"Failed to parse line in manifest file {manifest}: {e}") 28 | raise 29 | except Exception as e: 30 | raise Exception(f"Could not read manifest file {manifest}") from e 31 | 32 | 33 | def write_jsonl(output_path: Union[Path, str], data: List[dict]) -> None: 34 | """Write a list of dictionaries to a JSONL file. 35 | 36 | Args: 37 | output_path: Path to output JSONL file 38 | data: List of dictionaries to serialize 39 | """ 40 | with open(output_path, "w", encoding="utf-8") as f: 41 | for item in data: 42 | f.write(json.dumps(item) + '\n') 43 | 44 | 45 | def build_task_params_str(task_args, max_input_tokens) -> str: 46 | """Build descriptive string from task args and max_input_tokens. 
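    For example, task args [('num_chains', 1), ('num_hops', 4)] with max_input_tokens=16384
    yield 'num_chains@1+num_hops@4+max_input_tokens@16384'.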
47 | 48 | Args: 49 | task_args: List of tuples containing task args 50 | max_input_tokens: Maximum input tokens 51 | 52 | Returns: 53 | String with parameters in format: param1@value1+param2@value2 54 | """ 55 | params = [] 56 | for arg_name, arg_value in task_args: 57 | params.append(f"{arg_name}@{arg_value}") 58 | params.append(f"max_input_tokens@{max_input_tokens}") 59 | return "+".join(params) 60 | 61 | 62 | def build_attn_params_str(attn_args) -> str: 63 | """Build descriptive string for attention config. 64 | 65 | Args: 66 | attn_args: List of tuples containing attention args 67 | 68 | Returns: 69 | String with parameters in format: param1@value1+param2@value2 or 'default' 70 | """ 71 | attn_params = [] 72 | for arg_name, arg_value in attn_args: 73 | attn_params.append(f"{arg_name}@{arg_value}") 74 | return "+".join(attn_params) if attn_params else "default" 75 | 76 | 77 | def get_pred_dir() -> str: 78 | """Get the path to the predictions directory based on config settings. 79 | The path uses @ to separate parameter names from values and + to join parameters, 80 | which avoids ambiguities and potential issues with bash interpretation. 81 | 82 | Returns: 83 | Path string constructed from config parameters including task args 84 | """ 85 | cfg = GlobalSettings.get('cfg') 86 | 87 | return os.path.join( 88 | cfg.paths.debug if cfg.debug else cfg.paths.predictions, 89 | cfg.task.name, 90 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens), 91 | cfg.model.name, 92 | cfg.attention.name, 93 | build_attn_params_str(cfg.attention.get('args', {}).items()), 94 | ) 95 | 96 | 97 | def get_results_dir() -> str: 98 | """Get the path to the results directory based on config settings. 99 | The path uses @ to separate parameter names from values and + to join parameters, 100 | which avoids ambiguities and potential issues with bash interpretation. 101 | 102 | Returns: 103 | Path string constructed from config parameters including task args 104 | """ 105 | cfg = GlobalSettings.get('cfg') 106 | 107 | return os.path.join( 108 | cfg.paths.debug if cfg.debug else cfg.paths.results, 109 | cfg.task.name, 110 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens), 111 | cfg.model.name, 112 | cfg.attention.name, 113 | build_attn_params_str(cfg.attention.get('args', {}).items()), 114 | ) 115 | 116 | 117 | def get_data_dir() -> str: 118 | cfg = GlobalSettings.get('cfg') 119 | 120 | return os.path.join( 121 | cfg.paths.debug if cfg.debug else cfg.paths.data, 122 | cfg.task.name, 123 | build_task_params_str(cfg.task.get('args', {}).items(), cfg.max_input_tokens), 124 | cfg.model.name, 125 | ) 126 | 127 | 128 | def get_data_path() -> str: 129 | """Get the path to the task's data file. 130 | 131 | Returns: 132 | Path to data.jsonl in the task directory 133 | """ 134 | return os.path.join(get_data_dir(), 'data.jsonl') 135 | 136 | 137 | def get_pred_path() -> str: 138 | """Get the path to the task's predictions file. 139 | 140 | Returns: 141 | Path to pred.jsonl in the task directory 142 | """ 143 | return os.path.join(get_pred_dir(), 'pred.jsonl') 144 | 145 | 146 | def get_results_path() -> str: 147 | """Get the path to the task's evaluation results file. 
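    The filename is suffixed with the configured sample count, e.g. evaluation_results_100.json
    when cfg.samples == 100.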
148 | 149 | Returns: 150 | Path to evaluation_results.json in the task directory 151 | """ 152 | 153 | cfg = GlobalSettings.get('cfg') 154 | return os.path.join(get_results_dir(), f'evaluation_results_{cfg.samples}.json') 155 | 156 | 157 | def load_data_without_predictions() -> List[dict]: 158 | """Load task data excluding samples that already have predictions. 159 | Only returns samples with indexes up to cfg.samples. 160 | 161 | Returns: 162 | List of data samples that haven't been predicted yet, limited by index <= cfg.samples. 163 | """ 164 | cfg = GlobalSettings.get('cfg') 165 | data_path = get_data_path() 166 | pred_path = get_pred_path() 167 | 168 | if os.path.exists(pred_path): 169 | pred_index = {sample['index'] for sample in read_jsonl(pred_path)} 170 | data = [sample for sample in read_jsonl(data_path) 171 | if sample['index'] not in pred_index and sample['index'] < cfg.samples] 172 | else: 173 | data = [sample for sample in read_jsonl(data_path) 174 | if sample['index'] < cfg.samples] 175 | 176 | return data 177 | -------------------------------------------------------------------------------- /sparse_frontier/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import socket 4 | 5 | from sparse_frontier.utils.globals import GlobalSettings 6 | 7 | 8 | def get_free_ports(n: int) -> list[int]: 9 | """Find N free ports on the local machine.""" 10 | free_ports = [] 11 | sockets = [] 12 | 13 | try: 14 | for _ in range(n): 15 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 16 | s.bind(('localhost', 0)) # Bind to an available port 17 | free_ports.append(s.getsockname()[1]) # Get the assigned port number 18 | sockets.append(s) # Keep the socket open to reserve the port 19 | finally: 20 | # Close all sockets to release the ports 21 | for s in sockets: 22 | s.close() 23 | 24 | return free_ports 25 | 26 | 27 | def get_latest_commit_id(): 28 | try: 29 | import git 30 | repo = git.Repo(search_parent_directories=True) 31 | return repo.head.object.hexsha 32 | except Exception: 33 | return None 34 | 35 | 36 | def save_config(dir_path: str): 37 | from omegaconf import OmegaConf 38 | from datetime import datetime 39 | 40 | cfg = GlobalSettings.get('cfg') 41 | 42 | config_dict = OmegaConf.to_container(cfg, resolve=True) 43 | config_dict['commit_id'] = get_latest_commit_id() 44 | 45 | # Add timestamp 46 | config_dict['timestamp'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 47 | 48 | config_path = os.path.join(dir_path, "config.json") 49 | with open(config_path, "w") as f: 50 | json.dump(config_dict, f, indent=2) 51 | -------------------------------------------------------------------------------- /sparse_frontier/utils/globals.py: -------------------------------------------------------------------------------- 1 | class GlobalSettings: 2 | _settings = {} 3 | 4 | @classmethod 5 | def get(cls, key, default=None): 6 | return cls._settings.get(key, default) 7 | 8 | @classmethod 9 | def set(cls, key, value): 10 | cls._settings[key] = value 11 | 12 | 13 | def is_vllm_profiling_done(): 14 | return GlobalSettings.get("vllm_profiling_done", False) 15 | 16 | 17 | def set_vllm_profiling_done(done: bool): 18 | GlobalSettings.set("vllm_profiling_done", done) 19 | --------------------------------------------------------------------------------
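A closing note on the utilities: `GlobalSettings` is a minimal process-wide key-value store; modules such as `checks.py` and `data.py` read the experiment config back via `GlobalSettings.get("cfg")` instead of threading it through every call. A small illustrative sketch (not part of the repository; the dict below is only a stand-in for the real Hydra config object):

```python
from sparse_frontier.utils import GlobalSettings

GlobalSettings.set("cfg", {"samples": 100})           # stand-in for the Hydra config
assert GlobalSettings.get("cfg") == {"samples": 100}  # read it back anywhere in the process
assert GlobalSettings.get("missing_key", 42) == 42    # .get falls back to the provided default
```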