├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── Transparency_FAQ.md
├── attn_hub
│   ├── __init__.py
│   ├── flash_attn.py
│   └── retroinfer_attn.py
├── benchmark
│   ├── LongBench
│   │   ├── config
│   │   │   ├── dataset2maxlen.json
│   │   │   ├── dataset2prompt.json
│   │   │   ├── model2maxlen.json
│   │   │   └── model2path.json
│   │   ├── eval.py
│   │   ├── longbench_run.sh
│   │   ├── metrics.py
│   │   ├── pred.py
│   │   └── pred.sh
│   ├── config.py
│   └── ruler
│       ├── data
│       │   ├── prepare.py
│       │   ├── synthetic
│       │   │   ├── common_words_extraction.py
│       │   │   ├── constants.py
│       │   │   ├── freq_words_extraction.py
│       │   │   ├── json
│       │   │   │   ├── PaulGrahamEssays_URLs.txt
│       │   │   │   ├── download_paulgraham_essay.py
│       │   │   │   └── download_qa_dataset.sh
│       │   │   ├── niah.py
│       │   │   ├── qa.py
│       │   │   └── variable_tracking.py
│       │   ├── tokenizer.py
│       │   └── utils.py
│       ├── eval
│       │   ├── evaluate.py
│       │   ├── synthetic
│       │   │   └── constants.py
│       │   └── utils.py
│       ├── pred
│       │   ├── call_api.py
│       │   └── utils.py
│       ├── ruler_config_models.sh
│       ├── ruler_config_tasks.sh
│       ├── ruler_run.sh
│       └── synthetic.yaml
├── cache_hub
│   ├── __init__.py
│   ├── cache.py
│   ├── flash_attn_cache.py
│   ├── kmeans.py
│   └── retroinfer_cache.py
├── config
│   ├── Llama-3-8B-Instruct-Gradient-1048k.json
│   ├── Llama-3.1-8B-Instruct.json
│   ├── Qwen2.5-72B-Instruct.json
│   └── Qwen2.5-7B-Instruct.json
├── library
│   └── retroinfer
│       ├── retroinfer_kernels
│       │   ├── __init__.py
│       │   └── src
│       │       ├── batch_gemm_softmax.cu
│       │       ├── batch_gemm_softmax.h
│       │       ├── batch_gemm_with_epilogue_visitor.h
│       │       ├── copy_kernel.cuh
│       │       ├── gather_copy.cu
│       │       ├── thread_pool.hpp
│       │       └── wave_buffer_cpu.cpp
│       ├── setup.py
│       └── test
│           ├── test_batch_gemm_softmax.py
│           └── test_gather_copy.py
├── model_hub
│   ├── LLM.py
│   ├── __init__.py
│   ├── llama.py
│   └── qwen.py
├── requirements.txt
├── simple_test.py
├── simple_test_data.json
└── throughput_eval
    ├── run.sh
    ├── run_different_lengths.sh
    ├── run_different_models.sh
    ├── run_different_tasks.sh
    ├── test.py
    └── test_data
        ├── NIAH_1024000.json
        ├── NIAH_120000.json
        ├── NIAH_240000.json
        ├── NIAH_480000.json
        ├── NIAH_60000.json
        ├── fwe.json
        ├── qa1.json
        └── vt.json
/.gitignore: --------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) 
will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | !.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | 400 | .huggingface 401 | *.egg-info 402 | build/ 403 | dist/ 404 | library/cutlass 405 | benchmark/LongBench/results/ 406 | benchmark/ruler/data/synthetic/json/*.json 407 | benchmark/ruler/ruler_eval_result/ 408 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # RetroInfer
2 | 
3 | [RetroInfer](https://arxiv.org/pdf/2505.02922) is a novel system that **rethinks the KV cache as vector storage** within a GPU–CPU co-execution setup to accelerate long-context LLM inference. It exploits the inherent sparsity of the attention mechanism and introduces an **A**ttention-a**W**are **VE**ctor index (*wave index*) that enables efficient and accurate retrieval of critical tokens from the KV cache. Complementing this is the *wave buffer*, which coordinates KV cache placement and overlaps computation and data transfer across GPU and CPU to sustain high throughput.
4 | 
5 | ## Getting Started
6 | 
7 | ### Environment Setup
8 | The required dependency packages rely on `CUDA 12.4`; if your system does not have CUDA 12.4 installed, you can use the Docker image `nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04`.
9 | 
10 | The code was tested with `Python 3.10.16`; we recommend using `conda` to manage your Python environments:
11 | ```bash
12 | # first install miniconda if you don't have it, then create a new conda environment:
13 | conda create -n retroinfer python=3.10 -y
14 | conda activate retroinfer
15 | 
16 | # install conda packages
17 | conda install -y mkl
18 | conda install -c conda-forge libstdcxx-ng -y
19 | 
20 | # you may need to downgrade pip to <=25.0 to avoid a DEPRECATION warning when using `pip install .` to install the kernels
21 | python -m pip install pip==25.0
22 | 
23 | # install python packages
24 | pip install -r requirements.txt
25 | pip install flash-attn==2.6.3 --no-build-isolation
26 | pip install flashinfer-python==0.2.4 -i https://flashinfer.ai/whl/cu124/torch2.4/
27 | pip install git+https://github.com/Starmys/flash-attention.git@weighted
28 | ```
29 | 
30 | ### Install RetroInfer Kernels
31 | ```bash
32 | cd library/
33 | git clone https://github.com/NVIDIA/cutlass.git
34 | cd retroinfer && pip install . && cd ../../
35 | ```
36 | 
37 | ### Simple Test
38 | We provide a simple demo to verify that the environment is set up correctly. The demo runs on four different contexts from [RULER](https://github.com/NVIDIA/RULER), each containing approximately 120,000 tokens.
You can run the demo using the following command:
39 | ```bash
40 | python -u simple_test.py --batch_size 4
41 | ```
42 | Running this demo requires about 25GB of GPU memory and 67GB of CPU memory. If you encounter out-of-memory errors, consider reducing the batch size.
43 | 
44 | You can also customize the input contexts by providing a `json` file in the following format:
45 | ```
46 | [
47 |     {"input": str, "outputs": str},
48 |     {"input": str, "outputs": str},
49 |     ...
50 | ]
51 | ```
52 | Then, pass the file path using the `--data_path` argument:
53 | ```bash
54 | python -u simple_test.py --data_path <path_to_your_json_file>
55 | ```
56 | 
57 | ## Run Benchmark
58 | 
59 | You may need to set `CUDA_VISIBLE_DEVICES` before running the benchmark since our code will automatically split the model across all visible GPUs. For example, when evaluating on 80GB A100 GPUs, the 7B/8B models need only one GPU, while the 72B model needs at least 3 GPUs.
60 | 
61 | ### [RULER](https://github.com/NVIDIA/RULER)
62 | To evaluate model accuracy on the RULER benchmark, you first need to download the benchmark datasets:
63 | ```bash
64 | cd benchmark/ruler
65 | cd data/synthetic/json/ && python -u download_paulgraham_essay.py && bash download_qa_dataset.sh && cd ../../../
66 | ```
67 | Then, you can run [ruler_run.sh](benchmark/ruler/ruler_run.sh) to evaluate.
68 | For example, you can evaluate `RetroInfer` on the RULER variable tracking task `vt` with a context length of `128K` using the following command:
69 | ```bash
70 | bash ruler_run.sh llama-3-8b-1048k synthetic RetroInfer 131072 vt bf16 0.018 0.232
71 | ```
72 | The input parameters of the evaluation script are, in order:
73 | - `model name`: supported models include `llama-3.1-8b`, `llama-3-8b-1048k`, `qwen2.5-7b` and `qwen2.5-72b`;
74 | - `benchmark name`: set to `synthetic`;
75 | - `attention type`: `RetroInfer` or `Full_Flash_Attn`;
76 | - `input context length`: the input context length;
77 | - `evaluate task name`: supported tasks include `niah_single_1`, `niah_single_2`, `niah_single_3`, `niah_multikey_1`, `niah_multikey_2`, `niah_multikey_3`, `niah_multivalue`, `niah_multiquery`, `vt`, `cwe`, `fwe`, `qa_1` and `qa_2`;
78 | - `model data type`: supported data types include `bf16` and `fp16`;
79 | - `retrieval budget ratio`: the ratio of the number of tokens to be retrieved from the KV cache to the total number of tokens in the input context;
80 | - `attention estimate ratio`: the ratio of the number of clusters to be estimated in the attention mechanism to the total number of clusters.
81 | 
82 | ### [LongBench](https://github.com/THUDM/LongBench)
83 | You can use the following command to evaluate the model accuracy of `RetroInfer` on LongBench:
84 | ```bash
85 | cd benchmark/LongBench
86 | bash longbench_run.sh llama-3-8b-1048k RetroInfer 0.018 0.232 bf16
87 | ```
88 | The input parameters of the evaluation script are, in order:
89 | - `model name`: supported models include `llama-3.1-8b`, `llama-3-8b-1048k`, `qwen2.5-7b` and `qwen2.5-72b`;
90 | - `attention type`: `RetroInfer` or `Full_Flash_Attn`;
91 | - `retrieval budget ratio`: the ratio of the number of tokens to be retrieved from the KV cache to the total number of tokens in the input context;
92 | - `attention estimate ratio`: the ratio of the number of clusters to be estimated in the attention mechanism to the total number of clusters (see the sketch after this list for how both ratios map to concrete cluster counts);
93 | - `model data type`: supported data types include `bf16` and `fp16`.
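The two ratios are not consumed by the model directly; [benchmark/config.py](benchmark/config.py) converts them into concrete cluster counts (`n_centroids`, `nprobe`, `max_compute_cluster_num`). Below is a minimal, self-contained sketch of that conversion that mirrors the logic in `config.py`; the helper name and the printed example are illustrative only and not part of the repository:

```python
import math

def sketch_retroinfer_params(context_len, budget_ratio=0.018, estimate_ratio=0.232,
                             approx_cluster_size=16, n_segments=None):
    """Illustrative only: mirrors how benchmark/config.py derives cluster counts."""
    if n_segments is None:
        n_segments = max(1, context_len // 8192)
    # roughly one cluster per ~16 tokens, rounded to the nearest multiple of n_segments * 32
    n_clusters = math.ceil(context_len / approx_cluster_size)
    step = n_segments * 32
    lower = (n_clusters // step) * step
    upper = lower + step
    n_clusters = lower if abs(n_clusters - lower) <= abs(n_clusters - upper) else upper
    nprobe = max(1, int(n_clusters * budget_ratio))        # clusters retrieved exactly
    estimated = int(n_clusters * estimate_ratio) + nprobe  # clusters used for attention estimation
    return {"n_centroids": n_clusters, "nprobe": nprobe, "max_compute_cluster_num": estimated}

# For a 128K-token context this yields 8192 clusters, 147 of which are probed and 2047 estimated.
print(sketch_retroinfer_params(131072))
```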
94 | 95 | ## Reproduce Throughput Results 96 | We provide scripts to reproduce the throughput results reported in the paper. These experiments were conducted on an [Azure virtual machine](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/ndma100v4-series?tabs=sizebasic) featuring 4 NUMA nodes. Each NUMA node is equipped with 24 CPU cores, 475 GB of CPU memory, and two 80GB A100 GPUs. 97 | ```bash 98 | # Firstly, install the numactl package 99 | sudo apt install numactl -y 100 | 101 | # run scripts 102 | cd throughput_eval 103 | bash run.sh 104 | ``` 105 | 106 | ## Reference 107 | 108 | If you find this project helpful, please cite our paper: 109 | ```bibtex 110 | @misc{chen2025retroinfervectorstorageapproachscalable, 111 | title={RetroInfer: A Vector-Storage Approach for Scalable Long-Context LLM Inference}, 112 | author={Yaoqi Chen and Jinkai Zhang and Baotong Lu and Qianxi Zhang and Chengruidong Zhang and Jingjia Luo and Di Liu and Huiqiang Jiang and Qi Chen and Jing Liu and Bailu Ding and Xiao Yan and Jiawei Jiang and Chen Chen and Mingxing Zhang and Yuqing Yang and Fan Yang and Mao Yang}, 113 | year={2025}, 114 | eprint={2505.02922}, 115 | archivePrefix={arXiv}, 116 | primaryClass={cs.LG}, 117 | url={https://arxiv.org/abs/2505.02922}, 118 | } 119 | 120 | @misc{liu2024retrievalattentionacceleratinglongcontextllm, 121 | title={RetrievalAttention: Accelerating Long-Context LLM Inference via Vector Retrieval}, 122 | author={Di Liu and Meng Chen and Baotong Lu and Huiqiang Jiang and Zhenhua Han and Qianxi Zhang and Qi Chen and Chengruidong Zhang and Bailu Ding and Kai Zhang and Chen Chen and Fan Yang and Yuqing Yang and Lili Qiu}, 123 | year={2024}, 124 | eprint={2409.10516}, 125 | archivePrefix={arXiv}, 126 | primaryClass={cs.LG}, 127 | url={https://arxiv.org/abs/2409.10516}, 128 | } 129 | ``` 130 | 131 | ## Contributing 132 | 133 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 134 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 135 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 136 | 137 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 138 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 139 | provided by the bot. You will only need to do this once across all repos using our CLA. 140 | 141 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 142 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 143 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 144 | 145 | ## Trademarks 146 | 147 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 148 | trademarks or logos is subject to and must follow 149 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 150 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 151 | Any use of third-party trademarks or logos are subject to those third-party's policies. 
152 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | For help and questions about using this project, please refer the [document](./README.md). 
10 | 
11 | ## Microsoft Support Policy
12 | 
13 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
14 | 
-------------------------------------------------------------------------------- /Transparency_FAQ.md: --------------------------------------------------------------------------------
1 | # RetroInfer: Responsible AI FAQ
2 | 
3 | ## What is RetroInfer?
4 | - RetroInfer is a system that rethinks the KV cache as vector storage within a GPU–CPU co-execution setup to accelerate long-context Large Language Model (LLM) inference. It exploits the inherent sparsity of the attention mechanism and introduces an **A**ttention-a**W**are **VE**ctor index (*wave index*) that enables efficient and accurate retrieval of critical tokens from the KV cache. Complementing this is the *wave buffer*, which coordinates KV cache placement and overlaps computation and data transfer across GPU and CPU to sustain high throughput. It achieves 4.5x–10.5x improvements in decoding throughput over FlashAttention, without accuracy loss.
5 | 
6 | ## What can RetroInfer do?
7 | - RetroInfer can effectively improve the decoding throughput of LLM generation in long-context scenarios, with minimal impact on model accuracy.
8 | 
9 | ## What are RetroInfer’s intended use(s)?
10 | - RetroInfer is intended for LLM deployers and users who need to manage long-context scenarios efficiently.
11 | 
12 | ## How was RetroInfer evaluated? What metrics are used to measure performance?
13 | - We evaluated RetroInfer using state-of-the-art long-context benchmarks, including RULER, LongBench, and Needle in a Haystack, and their respective evaluation metrics.
14 | - Extensive testing was conducted across various scenarios, such as multi-needle, multi-hop tracing, multi-document QA, single-document QA, code completion, few-shot learning, etc. The results showed almost no change in accuracy.
15 | 
16 | ## What are the limitations of RetroInfer? How can users minimize the impact of RetroInfer’s limitations when using the system?
17 | - Potentially harmful, false, or biased responses generated by LLMs are likely unchanged with RetroInfer. As a result, using RetroInfer does not inherently mitigate or exacerbate these responsible AI concerns.
18 | - RetroInfer was developed for research and experimental purposes. Further testing and validation are needed before considering its application in real-world scenarios.
19 | 
20 | ## What operational factors and settings allow for effective and responsible use of RetroInfer?
21 | - Users can adjust parameters such as the retrieval budget ratio and attention estimate ratio when using RetroInfer. Once configured, RetroInfer can effectively enhance LLM response generation in long-context scenarios.
22 | -------------------------------------------------------------------------------- /attn_hub/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .flash_attn import prefill_full_flash_attn, decode_full_flash_attn 3 | from .retroinfer_attn import retroinfer_prefill_attn, retroinfer_decode_attn -------------------------------------------------------------------------------- /attn_hub/flash_attn.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_with_kvcache 2 | 3 | 4 | def prefill_full_flash_attn(query_states, key_states, value_states, causal): 5 | 6 | attn_out = flash_attn_with_kvcache( 7 | q=query_states, 8 | k_cache=key_states, 9 | v_cache=value_states, 10 | causal=causal 11 | ) 12 | 13 | return attn_out 14 | 15 | 16 | 17 | def decode_full_flash_attn(query_states, key_states, value_states, layer_idx, full_attn_cache): 18 | 19 | valid_len = full_attn_cache.valid_length + 1 if layer_idx < full_attn_cache.layer_num - 1 else full_attn_cache.valid_length 20 | 21 | attn_out = flash_attn_with_kvcache( 22 | q=query_states, 23 | k_cache=key_states, 24 | v_cache=value_states, 25 | cache_seqlens=valid_len, 26 | ) 27 | 28 | return attn_out 29 | -------------------------------------------------------------------------------- /attn_hub/retroinfer_attn.py: -------------------------------------------------------------------------------- 1 | from flash_attn import flash_attn_with_kvcache 2 | 3 | 4 | def retroinfer_prefill_attn(query_states, key_states, value_states, causal): 5 | 6 | attn_out = flash_attn_with_kvcache( 7 | q=query_states, 8 | k_cache=key_states, 9 | v_cache=value_states, 10 | causal=causal 11 | ) 12 | 13 | return attn_out 14 | 15 | 16 | 17 | def retroinfer_decode_attn(query_states, key_states, value_states, layer_idx, retroinfer_cache): 18 | 19 | attn_out = retroinfer_cache.compute( 20 | query_states.contiguous(), layer_idx 21 | ) 22 | 23 | return attn_out 24 | -------------------------------------------------------------------------------- /benchmark/LongBench/config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64 23 | } -------------------------------------------------------------------------------- /benchmark/LongBench/config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. 
If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. 
In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /benchmark/LongBench/config/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama-3-8b-1048k": 130000, 3 | "qwen2.5-7b": 130000, 4 | "llama-3.1-8b": 130000, 5 | "qwen2.5-72b": 130000 6 | } 7 | -------------------------------------------------------------------------------- /benchmark/LongBench/config/model2path.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama-3-8b-1048k": "gradientai/Llama-3-8B-Instruct-Gradient-1048k", 3 | "llama-3.1-8b": "meta-llama/Llama-3.1-8B-Instruct", 4 | "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct", 5 | "qwen2.5-72b": "Qwen/Qwen2.5-72B-Instruct" 6 | } 7 | -------------------------------------------------------------------------------- /benchmark/LongBench/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from metrics import ( 7 | qa_f1_score, 8 | rouge_zh_score, 9 | qa_f1_zh_score, 10 | rouge_score, 11 | classification_score, 12 | retrieval_score, 13 | retrieval_zh_score, 14 | count_score, 15 | code_sim_score, 16 | ) 17 | 18 | dataset2metric = { 19 | "narrativeqa": qa_f1_score, 20 | "qasper": qa_f1_score, 21 | "multifieldqa_en": qa_f1_score, 22 | "multifieldqa_zh": qa_f1_zh_score, 23 | "hotpotqa": qa_f1_score, 24 | "2wikimqa": qa_f1_score, 25 | "musique": qa_f1_score, 26 | "dureader": rouge_zh_score, 27 | "gov_report": rouge_score, 28 | "qmsum": rouge_score, 29 | "multi_news": rouge_score, 30 | "vcsum": rouge_zh_score, 31 | "trec": classification_score, 32 | "triviaqa": qa_f1_score, 33 | "samsum": rouge_score, 34 | "lsht": classification_score, 35 | "passage_retrieval_en": retrieval_score, 36 | "passage_count": count_score, 37 | "passage_retrieval_zh": retrieval_zh_score, 38 | "lcc": code_sim_score, 39 | "repobench-p": code_sim_score, 40 | } 41 | 42 | def parse_args(args=None): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--model', type=str, default=None, choices= 45 | ["llama-3-8b-1048k", "qwen2.5-7b", "llama-3.1-8b", "qwen2.5-72b"]) 46 | parser.add_argument("--attn_type", type=str, default="Full_Flash_Attn", \ 47 | choices=["Full_Flash_Attn", "RetroInfer"], \ 48 | help="Attention method") 49 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 50 | return parser.parse_args(args) 51 | 52 | def scorer_e(dataset, predictions, answers, lengths, 
all_classes): 53 | scores = {"0-4k": [], "4-8k": [], "8k+": []} 54 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths): 55 | score = 0. 56 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 57 | prediction = prediction.lstrip('\n').split('\n')[0] 58 | for ground_truth in ground_truths: 59 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 60 | if length < 4000: 61 | scores["0-4k"].append(score) 62 | elif length < 8000: 63 | scores["4-8k"].append(score) 64 | else: 65 | scores["8k+"].append(score) 66 | for key in scores.keys(): 67 | scores[key] = round(100 * np.mean(scores[key]), 2) 68 | return scores 69 | 70 | def scorer(dataset, predictions, answers, all_classes): 71 | total_score = 0. 72 | for (prediction, ground_truths) in zip(predictions, answers): 73 | score = 0. 74 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 75 | prediction = prediction.lstrip('\n').split('\n')[0] 76 | for ground_truth in ground_truths: 77 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 78 | total_score += score 79 | return round(100 * total_score / len(predictions), 2) 80 | 81 | 82 | if __name__ == '__main__': 83 | args = parse_args() 84 | attn_type = args.attn_type 85 | model_name = args.model # not hf model path 86 | 87 | model2path = json.load(open("config/model2path.json", "r")) 88 | model_path = model2path[model_name] 89 | 90 | if args.e: 91 | path = f"results/pred_e/{args.model}/{attn_type}/" 92 | else: 93 | path = f"results/pred/{args.model}/{attn_type}/" 94 | 95 | scores = dict() 96 | all_files = os.listdir(path) 97 | print("Evaluating on:", all_files) 98 | for filename in all_files: 99 | if not filename.endswith("jsonl"): 100 | continue 101 | predictions, answers, lengths = [], [], [] 102 | dataset = filename.split('.')[0] 103 | with open(f"{path}{filename}", "r", encoding="utf-8") as f: 104 | for line in f: 105 | data = json.loads(line) 106 | predictions.append(data["pred"]) 107 | answers.append(data["answers"]) 108 | all_classes = data["all_classes"] 109 | if "length" in data: 110 | lengths.append(data["length"]) 111 | if args.e: 112 | score = scorer_e(dataset, predictions, answers, lengths, all_classes) 113 | else: 114 | score = scorer(dataset, predictions, answers, all_classes) 115 | scores[dataset] = score 116 | 117 | 118 | out_path = f"{path}result.json" 119 | 120 | with open(out_path, "w") as f: 121 | json.dump(scores, f, ensure_ascii=False, indent=4) 122 | -------------------------------------------------------------------------------- /benchmark/LongBench/longbench_run.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | if [ $# -ne 5 ]; then 4 | echo "Usage: $0 $1 $2 $3 $4 " 5 | exit 1 6 | fi 7 | 8 | MODEL=${1} 9 | ATTN_TYPE=${2} 10 | BUDGET_RATIO=${3} 11 | ESTIMATE_RATIO=${4} 12 | DTYPE=${5} 13 | 14 | RESULT_DIR="./results/pred/${MODEL}/${ATTN_TYPE}" 15 | 16 | tasks=(qasper repobench-p lcc gov_report triviaqa) 17 | 18 | for task in "${tasks[@]}"; do 19 | echo "Parameters: ${MODEL} ${task} ${ATTN_TYPE} ${DTYPE} ${BUDGET_RATIO} ${ESTIMATE_RATIO}" 20 | bash pred.sh ${MODEL} ${task} ${ATTN_TYPE} ${DTYPE} ${BUDGET_RATIO} ${ESTIMATE_RATIO} 21 | done 22 | 23 | echo "Start to evaluate..." 
24 | python -u eval.py \ 25 | --attn_type ${ATTN_TYPE} \ 26 | --model ${MODEL} \ 27 | 28 | echo "Results:" 29 | cat "${RESULT_DIR}/result.json" 30 | -------------------------------------------------------------------------------- /benchmark/LongBench/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | 15 | def remove_articles(text): 16 | return re.sub(r"\b(a|an|the)\b", " ", text) 17 | 18 | def white_space_fix(text): 19 | return " ".join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def normalize_zh_answer(s): 32 | """Lower text and remove punctuation, extra whitespace.""" 33 | 34 | def white_space_fix(text): 35 | return "".join(text.split()) 36 | 37 | def remove_punc(text): 38 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 39 | all_punctuation = set(string.punctuation + cn_punctuation) 40 | return "".join(ch for ch in text if ch not in all_punctuation) 41 | 42 | def lower(text): 43 | return text.lower() 44 | 45 | return white_space_fix(remove_punc(lower(s))) 46 | 47 | def count_score(prediction, ground_truth, **kwargs): 48 | numbers = re.findall(r"\d+", prediction) 49 | right_num = 0 50 | for number in numbers: 51 | if str(number) == str(ground_truth): 52 | right_num += 1 53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 54 | return float(final_score) 55 | 56 | def retrieval_score(prediction, ground_truth, **kwargs): 57 | pattern = r'Paragraph (\d+)' 58 | matches = re.findall(pattern, ground_truth) 59 | ground_truth_id = matches[0] 60 | numbers = re.findall(r"\d+", prediction) 61 | right_num = 0 62 | for number in numbers: 63 | if str(number) == str(ground_truth_id): 64 | right_num += 1 65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 66 | return float(final_score) 67 | 68 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 69 | pattern = r'段落(\d+)' 70 | matches = re.findall(pattern, ground_truth) 71 | ground_truth_id = matches[0] 72 | numbers = re.findall(r"\d+", prediction) 73 | right_num = 0 74 | for number in numbers: 75 | if str(number) == str(ground_truth_id): 76 | right_num += 1 77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 78 | return float(final_score) 79 | 80 | def code_sim_score(prediction, ground_truth, **kwargs): 81 | all_lines = prediction.lstrip('\n').split('\n') 82 | prediction = "" 83 | for line in all_lines: 84 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 85 | prediction = line 86 | break 87 | return (fuzz.ratio(prediction, ground_truth) / 100) 88 | 89 | def classification_score(prediction, ground_truth, **kwargs): 90 | em_match_list = [] 91 | all_classes = kwargs["all_classes"] 92 | for class_name in all_classes: 93 | if class_name in prediction: 94 | em_match_list.append(class_name) 95 | for match_term in em_match_list: 96 | if match_term in ground_truth and match_term != ground_truth: 97 | em_match_list.remove(match_term) 
98 | if ground_truth in em_match_list: 99 | score = (1.0 / len(em_match_list)) 100 | else: 101 | score = 0.0 102 | return score 103 | 104 | def rouge_score(prediction, ground_truth, **kwargs): 105 | rouge = Rouge() 106 | try: 107 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 108 | except: 109 | return 0.0 110 | return scores["rouge-l"]["f"] 111 | 112 | def rouge_zh_score(prediction, ground_truth, **kwargs): 113 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 114 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 115 | score = rouge_score(prediction, ground_truth) 116 | return score 117 | 118 | def f1_score(prediction, ground_truth, **kwargs): 119 | common = Counter(prediction) & Counter(ground_truth) 120 | num_same = sum(common.values()) 121 | if num_same == 0: 122 | return 0 123 | precision = 1.0 * num_same / len(prediction) 124 | recall = 1.0 * num_same / len(ground_truth) 125 | f1 = (2 * precision * recall) / (precision + recall) 126 | return f1 127 | 128 | def qa_f1_score(prediction, ground_truth, **kwargs): 129 | normalized_prediction = normalize_answer(prediction) 130 | normalized_ground_truth = normalize_answer(ground_truth) 131 | 132 | prediction_tokens = normalized_prediction.split() 133 | ground_truth_tokens = normalized_ground_truth.split() 134 | return f1_score(prediction_tokens, ground_truth_tokens) 135 | 136 | 137 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 138 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 139 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 140 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 141 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 142 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 143 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 144 | return f1_score(prediction_tokens, ground_truth_tokens) 145 | -------------------------------------------------------------------------------- /benchmark/LongBench/pred.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datasets import load_dataset 4 | import torch 5 | import json 6 | from tqdm import tqdm 7 | import numpy as np 8 | import random 9 | import argparse 10 | import time 11 | 12 | PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) 13 | sys.path.append(PROJECT_ROOT) 14 | from model_hub import LlamaModel, QwenModel 15 | 16 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 17 | from config import generate_config, parse_attn_args 18 | 19 | model2path = json.load(open("config/model2path.json", "r")) 20 | model2maxlen = json.load(open("config/model2maxlen.json", "r")) 21 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 22 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) 23 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) 24 | 25 | def parse_args(args=None): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--attn_type", type=str, default="Full_Flash_Attn", \ 28 | choices=["Full_Flash_Attn", "RetroInfer"], \ 29 | help="Attention method") 30 | parser.add_argument('--model', type=str, default=None, choices= 31 | ["llama-3-8b-1048k", "qwen2.5-7b", "llama-3.1-8b", "qwen2.5-72b"]) 32 | 
parser.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16"], help="Dtype") 33 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 34 | parser.add_argument('--task', type=str, required=True, help="task name. work when --e is false") 35 | parser.add_argument("--device", type=str, default="auto", help="Device") 36 | parser.add_argument("--num_examples", type=int, default=-1, help="num of example to evaluate. -1 for all.") 37 | 38 | parser = parse_attn_args(parser) 39 | 40 | return parser.parse_args(args) 41 | 42 | 43 | def get_pred(llm, data, max_new_tokens, prompt_format, model_name, out_path, args): 44 | for json_obj in tqdm(data): 45 | prompt = prompt_format.format(**json_obj) 46 | 47 | inputs = llm.tokenizer([prompt], return_tensors="pt", padding=True) 48 | input_ids = inputs.input_ids 49 | attention_masks = inputs.attention_mask 50 | 51 | attn_config = generate_config( 52 | model2path[model_name], 53 | input_ids.shape[1], 54 | attn_type, 55 | budget_ratio=args.budget_ratio, 56 | estimate_ratio=args.estimate_ratio, 57 | ) 58 | 59 | out = llm.generate( 60 | attention_type=attn_type, 61 | inputs_ids = input_ids.to(llm.layers[0].device), 62 | attention_masks = attention_masks.to(llm.layers[0].device), 63 | max_new_length=max_new_tokens, 64 | attn_config=attn_config 65 | ) 66 | 67 | output = llm.tokenizer.batch_decode(out, skip_special_tokens=True) 68 | 69 | torch.cuda.empty_cache() 70 | 71 | print("Chunked generation:", output[0][:50]) 72 | 73 | pred = output[0] 74 | 75 | with open(out_path, "a", encoding="utf-8") as f: 76 | json.dump( 77 | { 78 | "pred": pred, 79 | "answers": json_obj["answers"], 80 | "all_classes": json_obj["all_classes"], 81 | "length": json_obj["length"] 82 | }, 83 | f, 84 | ensure_ascii=False 85 | ) 86 | f.write('\n') 87 | 88 | def seed_everything(seed): 89 | torch.manual_seed(seed) 90 | torch.cuda.manual_seed(seed) 91 | np.random.seed(seed) 92 | random.seed(seed) 93 | torch.backends.cudnn.benchmark = False 94 | torch.backends.cudnn.deterministic = True 95 | torch.cuda.manual_seed_all(seed) 96 | 97 | 98 | def load_model(model_path, max_len, dtype, device): 99 | if 'Llama' in model_path: 100 | llm = LlamaModel(model_path, 101 | max_length=max_len, 102 | dtype=dtype, 103 | device_map=device) 104 | elif 'Qwen' in model_path: 105 | llm = QwenModel(model_path, 106 | max_length=max_len, 107 | dtype=dtype, 108 | device_map=device) 109 | else: 110 | raise ValueError(f"Unsupported model: {model_path}") 111 | 112 | llm.tokenizer.pad_token = llm.tokenizer.eos_token 113 | llm.tokenizer.padding_side = "left" 114 | 115 | return llm 116 | 117 | 118 | if __name__ == '__main__': 119 | seed_everything(42) 120 | args = parse_args() 121 | 122 | num_examples = args.num_examples 123 | attn_type = args.attn_type 124 | model_name = args.model # not hf model path 125 | device = args.device 126 | dtype = torch.float16 if args.dtype == 'fp16' else torch.bfloat16 127 | 128 | max_length = model2maxlen[model_name] 129 | model_path = model2path[model_name] 130 | 131 | llm = load_model(model_path, max_length, dtype, device) 132 | 133 | if args.e: 134 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ 135 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 136 | else: 137 | datasets = [args.task] 138 | # datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ 139 | # "dureader", "gov_report", "qmsum", 
"multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ 140 | # "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] 141 | 142 | # predict on each dataset 143 | if not os.path.exists("results/pred"): 144 | os.makedirs("results/pred") 145 | if not os.path.exists("results/pred_e"): 146 | os.makedirs("results/pred_e") 147 | 148 | for dataset in datasets: 149 | if args.e: 150 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') 151 | 152 | prefix = f"results/pred_e/{model_name}/{attn_type}" 153 | if not os.path.exists(prefix): 154 | os.makedirs(prefix) 155 | out_path = f"{prefix}/{dataset}.jsonl" 156 | else: 157 | data = load_dataset('THUDM/LongBench', dataset, split='test') 158 | 159 | prefix = f"results/pred/{model_name}/{attn_type}" 160 | if not os.path.exists(prefix): 161 | os.makedirs(prefix) 162 | out_path = f"{prefix}/{dataset}.jsonl" 163 | 164 | prompt_format = dataset2prompt[dataset] 165 | max_new_tokens = dataset2maxlen[dataset] 166 | data_all = [data_sample for data_sample in data] 167 | data_all = data_all[:num_examples] if num_examples > 0 else data_all 168 | 169 | get_pred( 170 | llm, 171 | data_all, 172 | max_new_tokens, 173 | prompt_format, 174 | model_name, 175 | out_path, 176 | args, 177 | ) -------------------------------------------------------------------------------- /benchmark/LongBench/pred.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | if [ $# -ne 6 ]; then 4 | echo "Usage: $0 $1 $2 $3 $4 $5 " 5 | exit 1 6 | fi 7 | 8 | NUM_EXAMPLES=-1 9 | MODEL=${1} 10 | TASK=${2} 11 | ATTN_TYPE=${3} 12 | DTYPE=${4} 13 | BUDGET_RATIO=${5} 14 | ESTIMATE_RATIO=${6} 15 | 16 | RESULT_DIR="./results/pred/${MODEL}/${ATTN_TYPE}" 17 | RESULT_DIR_E="./results/pred_e/${MODEL}/${ATTN_TYPE}" 18 | 19 | echo "remove previous result file..." 20 | rm -f "${RESULT_DIR}/${TASK}.jsonl" 21 | rm -f "${RESULT_DIR_E}/${TASK}.jsonl" 22 | 23 | echo "Start to predict..." 
24 | python -u pred.py \ 25 | --task ${TASK} \ 26 | --attn_type ${ATTN_TYPE} \ 27 | --model ${MODEL} \ 28 | --dtype ${DTYPE} \ 29 | --device auto \ 30 | --budget_ratio ${BUDGET_RATIO} \ 31 | --estimate_ratio ${ESTIMATE_RATIO} \ 32 | --num_examples ${NUM_EXAMPLES} -------------------------------------------------------------------------------- /benchmark/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import math 5 | import argparse 6 | 7 | PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 8 | sys.path.append(PROJECT_ROOT) 9 | 10 | def parse_attn_args(parser: argparse.ArgumentParser): 11 | parser.add_argument("--budget_ratio", type=float, default=0.018, help="ratio of budget") 12 | parser.add_argument("--estimate_ratio", type=float, default=0.25, help="ratio of estimated clusters for RetriveInfer") 13 | 14 | return parser 15 | 16 | 17 | def generate_config( 18 | model_name, 19 | context_len, 20 | attn_type, 21 | budget_ratio=0.018, 22 | estimate_ratio=0.25, 23 | # default retrieve infer configs 24 | n_segments=None, 25 | ): 26 | aprox_cluster_size = 16 27 | 28 | CONFIG_DIR = os.path.join(PROJECT_ROOT, "config") 29 | MODEL_NAME = model_name.split("/")[-1]+'.json' 30 | CONFIG_FILE = os.path.join(CONFIG_DIR, MODEL_NAME) 31 | with open(CONFIG_FILE, "r") as f: 32 | original_config = json.load(f) 33 | 34 | if n_segments is None: 35 | n_segments = max(1, context_len // 8192) 36 | 37 | n_clusters = math.ceil(context_len/aprox_cluster_size) 38 | 39 | if attn_type == 'RetroInfer': 40 | # compute the nearest multiple of (n_segments*32) 41 | lower = (n_clusters // (n_segments*32)) * (n_segments*32) 42 | upper = lower + (n_segments*32) 43 | n_clusters = lower if abs(n_clusters - lower) <= abs(n_clusters - upper) else upper 44 | 45 | nprobe = max(1, int(n_clusters*budget_ratio)) 46 | print(f"context_len: {context_len}, n_clusters: {n_clusters}, nprobe: {nprobe}, n_segments: {n_segments}") 47 | 48 | if attn_type == 'RetroInfer': 49 | original_config[attn_type]['n_centroids'] = n_clusters 50 | original_config[attn_type]['n_segment'] = n_segments 51 | original_config[attn_type]['nprobe'] = nprobe 52 | original_config[attn_type]['cache_cluster_num'] = int(nprobe*3) 53 | original_config[attn_type]['max_compute_cluster_num'] = int(n_clusters*estimate_ratio) + nprobe 54 | 55 | print(original_config[attn_type]) 56 | 57 | return original_config -------------------------------------------------------------------------------- /benchmark/ruler/data/prepare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Prepare jsonl with field `input` and `outputs`. 
17 | { 18 | "index" int, 19 | "input": str, 20 | "outputs": [str], 21 | } 22 | 23 | python prepare.py \ 24 | --save_dir ./ \ 25 | --benchmark synthetic \ 26 | --task niah_single_1 \ 27 | --tokenizer_path tokenizer.model \ 28 | --tokenizer_type nemo \ 29 | --max_seq_length 4096 \ 30 | --model_template_type base \ 31 | --num_samples 10 \ 32 | """ 33 | import os 34 | import re 35 | import json 36 | import argparse 37 | import importlib 38 | import subprocess 39 | import time 40 | import math 41 | import yaml 42 | from pathlib import Path 43 | import nltk 44 | try: 45 | nltk.data.find('tokenizers/punkt') 46 | except LookupError: 47 | nltk.download('punkt') 48 | 49 | 50 | Templates = { 51 | 'base': "{task_template}", 52 | 53 | 'meta-chat': "[INST] {task_template} [/INST]", 54 | 55 | 'vicuna-chat': "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {task_template} ASSISTANT:", 56 | 57 | 'lwm-chat': "You are a helpful assistant. USER: {task_template} ASSISTANT: ", 58 | 59 | 'command-r-chat': "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{task_template}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", 60 | 61 | 'chatglm-chat': "[gMASK]sop<|user|> \n {task_template}<|assistant|> \n ", 62 | 63 | 'RWKV': "User: hi\n\nAssistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it\n\nUser: {task_template}\n\nAssistant:", 64 | } 65 | 66 | 67 | def main(args): 68 | start_time = time.time() 69 | curr_folder = os.path.dirname(os.path.abspath(__file__)) 70 | 71 | try: 72 | module = importlib.import_module(f"{args.benchmark}.constants") 73 | except ImportError: 74 | print(f"Module data.{args.benchmark}.constants not found.") 75 | 76 | tasks_base = module.TASKS 77 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f: 78 | tasks_customized = yaml.safe_load(f) 79 | 80 | if args.task not in tasks_customized: 81 | raise ValueError(f'{args.task} is not found in config_tasks.yaml') 82 | 83 | config = tasks_customized.get(args.task) 84 | config.update(tasks_base[config['task']]) 85 | 86 | # Add templates 87 | assert args.model_template_type in Templates, print(f'{args.model_template_type} is not found in {Templates.keys()}') 88 | model_template = Templates[args.model_template_type] 89 | task_template = config['template'] 90 | 91 | # Add answer prefix for all models 92 | answer_prefix = config['answer_prefix'] if 'answer_prefix' in config else '' 93 | config['template'] = model_template.format(task_template=task_template) + answer_prefix 94 | 95 | # Split task into multiple chunks 96 | chunks = [(args.num_samples // args.chunk_amount) + (1 if i < args.num_samples % args.chunk_amount else 0) for i in range(args.chunk_amount)] 97 | num_samples = chunks[args.chunk_idx] 98 | pre_samples = sum(chunks[:args.chunk_idx]) 99 | 100 | random_seed = 42 + args.chunk_idx 101 | 102 | 103 | try: 104 | script = os.path.join(curr_folder, args.benchmark, f"{config['task']}.py") 105 | additional_args = " ".join([f"--{k} {v}" for k, v in config['args'].items()]) 106 | command = f"""python -u {script} \ 107 | --save_dir {args.save_dir} \ 108 | --save_name {args.task} \ 109 | --subset {args.subset} \ 110 | --tokenizer_path {args.tokenizer_path} \ 111 | --tokenizer_type {args.tokenizer_type} \ 112 | --max_seq_length {args.max_seq_length} \ 113 | --tokens_to_generate {config['tokens_to_generate']} \ 114 | 
--num_samples {num_samples} \ 115 | --random_seed {random_seed} \ 116 | {additional_args} \ 117 | {f"--remove_newline_tab" if args.remove_newline_tab else ""} \ 118 | {f"--pre_samples {pre_samples}" if config['task'] == 'qa' else ""} \ 119 | --template "{config['template']}" 120 | """ 121 | print(command) 122 | result = subprocess.run(command, 123 | shell=True, 124 | check=True, 125 | stdout=subprocess.PIPE, 126 | stderr=subprocess.PIPE, 127 | text=True) 128 | 129 | if result.returncode == 0: 130 | print("Output:") 131 | print(result.stdout) 132 | else: 133 | print("Error:") 134 | print(result.stderr) 135 | except subprocess.CalledProcessError as e: 136 | print("Error output:", e.stderr) 137 | 138 | save_file = args.save_dir / args.task / f"{args.subset}.jsonl" 139 | print(f"Prepare {args.task} with lines: {args.num_samples} to {save_file}") 140 | print(f"Used time: {round((time.time() - start_time) / 60, 1)} minutes") 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 145 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]') 146 | parser.add_argument("--task", type=str, required=True, help='tasks in benchmark') 147 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 148 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 149 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 150 | parser.add_argument("--model_template_type", type=str, default='base', help='Options in `template.py`') 151 | parser.add_argument("--num_samples", type=int, default=500, help='maximum number of samples we want to test') 152 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 153 | parser.add_argument("--chunk_idx", type=int, default=0, help='index of current split chunk') 154 | parser.add_argument("--chunk_amount", type=int, default=1, help='size of split chunk') 155 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 156 | parser.add_argument("--random_seed", type=int, default=42) 157 | 158 | args = parser.parse_args() 159 | main(args) 160 | -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/common_words_extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for common words extraction. 
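For intuition, the generated context is a numbered word list such as "1. apple 2. river 3. apple ..." (words here are illustrative), in which the --num_cw common words are each repeated --freq_cw times and the remaining words --freq_ucw times; the expected outputs are those common words.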
17 | 18 | python common_words_extraction.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type nemo \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | -freq_cw 30 --freq_ucw 3 --num_cw 10 \ 28 | --template "[INST] Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list? [/INST] Answer: The top 10 words that appear most often in the list are:" 29 | """ 30 | 31 | import os 32 | import argparse 33 | from pathlib import Path 34 | from tqdm import tqdm 35 | import random 36 | import wonderwords 37 | import sys 38 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 39 | from utils import dump_jsonl 40 | from tokenizer import select_tokenizer 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 44 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 45 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 46 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 47 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 48 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 49 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 50 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 51 | parser.add_argument("--random_seed", type=int, default=42) 52 | parser.add_argument("--template", type=str, default='', help='prompt template') 53 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 54 | 55 | parser.add_argument("--freq_cw", type=int, default=30) 56 | parser.add_argument("--freq_ucw", type=int, default=3) 57 | parser.add_argument("--num_cw", type=int, default=10) 58 | 59 | args = parser.parse_args() 60 | random.seed(args.random_seed) 61 | 62 | # Load Tokenizer 63 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 64 | 65 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") 66 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") 67 | verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt") 68 | words = nouns + adjs + verbs 69 | words = sorted(list(set(words))) 70 | random.Random(args.random_seed).shuffle(words) 71 | 72 | def get_example(num_words, common_repeats=30, uncommon_repeats=3, common_nums=10): 73 | word_list_full = random.sample(words, num_words) 74 | common, uncommon = word_list_full[:common_nums], word_list_full[common_nums:] 75 | word_list = common * int(common_repeats) + uncommon * int(uncommon_repeats) 76 | random.Random(args.random_seed).shuffle(word_list) 77 | 78 | # Formatting the word list as "1. word1 2. word2 3. word3 ..." 79 | context = ' '.join([f"{i + 1}. 
{word}" for i, word in enumerate(word_list)]) 80 | 81 | return context, common 82 | 83 | def generate_input_output(num_words): 84 | if args.max_seq_length < 4096: 85 | context_example, answer_example = get_example(20, 3, 1, args.num_cw) 86 | context, answer = get_example(num_words, 6, 1, args.num_cw) 87 | else: 88 | context_example, answer_example = get_example(40, 10, 3, args.num_cw) 89 | context, answer = get_example(num_words, args.freq_cw, args.freq_ucw, args.num_cw) 90 | 91 | template = args.template 92 | 93 | input_example = template.format( 94 | context=context_example, 95 | query='', 96 | ) + ' '.join([f"{i + 1}. {word}" for i, word in enumerate(answer_example)]) 97 | 98 | input_text = template.format( 99 | context=context, 100 | query='', 101 | ) 102 | 103 | return input_example + "\n" + input_text, answer 104 | 105 | def sys_word_pair_random(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10): 106 | write_jsons = [] 107 | tokens_to_generate = args.tokens_to_generate 108 | 109 | # Find the perfect num_words 110 | num_words = incremental 111 | 112 | total_tokens = 0 113 | while total_tokens + tokens_to_generate < max_seq_length: 114 | 115 | input_text, answer = generate_input_output(num_words) 116 | # Calculate the number of tokens in the example 117 | total_tokens = len(TOKENIZER.text_to_tokens(input_text + ' ' + ' '.join([f"{i + 1}. {word}" for i, word in enumerate(answer)]))) 118 | print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}') 119 | if total_tokens + tokens_to_generate > max_seq_length: 120 | num_words -= incremental 121 | break 122 | 123 | num_words += incremental 124 | if num_words > len(words): 125 | num_words = len(words) 126 | break 127 | 128 | print('num_words:', num_words) 129 | 130 | # Generate samples 131 | for index in tqdm(range(num_samples)): 132 | used_words = num_words 133 | while(True): 134 | try: 135 | input_text, answer = generate_input_output(used_words) 136 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 137 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 138 | break 139 | except: 140 | if used_words > incremental: 141 | used_words -= incremental 142 | 143 | if args.remove_newline_tab: 144 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 145 | 146 | formatted_output = { 147 | 'index': index, 148 | "input": input_text, 149 | "outputs": answer, 150 | "length": length, 151 | } 152 | write_jsons.append(formatted_output) 153 | 154 | return write_jsons 155 | 156 | 157 | def main(): 158 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 159 | save_file.parent.mkdir(parents=True, exist_ok=True) 160 | 161 | write_jsons = sys_word_pair_random(num_samples=args.num_samples, max_seq_length=args.max_seq_length, save_dir=args.save_dir) 162 | 163 | dump_jsonl(save_file, write_jsons) 164 | 165 | if __name__=="__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Add a new task (required arguments): 17 | 18 | TASK_NAME: { 19 | 'tokens_to_generate': how many tokens we want to generate. 20 | 'template': the template with at least {context} and {query}. 21 | } 22 | """ 23 | 24 | TASKS = { 25 | 'niah': { 26 | 'tokens_to_generate': 128, 27 | 'template': """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?""", 28 | 'answer_prefix': """ The special magic {type_needle_v} for {query} mentioned in the provided text are""" 29 | }, 30 | 31 | 'variable_tracking': { 32 | 'tokens_to_generate': 30, 33 | 'template': """Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above.""", 34 | 'answer_prefix': """ Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: """ 35 | }, 36 | 37 | 'common_words_extraction': { 38 | 'tokens_to_generate': 120, 39 | 'template': """Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?""", 40 | 'answer_prefix': """ Answer: The top 10 words that appear most often in the list are:""" 41 | }, 42 | 43 | 'freq_words_extraction' : { 44 | 'tokens_to_generate': 50, 45 | 'template': """Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text?""", 46 | 'answer_prefix': """ Answer: According to the coded text above, the three most frequently appeared words are:""" 47 | }, 48 | 49 | 'qa': { 50 | 'tokens_to_generate': 32, 51 | 'template': """Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query}""", 52 | 'answer_prefix': """ Answer:""", 53 | }, 54 | } -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/freq_words_extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for frequent words extraction. 17 | 18 | python freq_words_extraction.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type nemo \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | --alpha 2.0 \ 28 | --template "[INST] Read the following coded text and track the frequency of each coded word. Find the three most frequently appeared coded words. {context}\nQuestion: Do not provide any explanation. Please ignore the dots '....'. What are the three most frequently appeared words in the above coded text? [/INST] Answer: According to the coded text above, the three most frequently appeared words are:" 29 | """ 30 | 31 | import os 32 | import argparse 33 | from pathlib import Path 34 | from tqdm import tqdm 35 | import random 36 | import string 37 | import numpy as np 38 | import sys 39 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 40 | from tokenizer import select_tokenizer 41 | from utils import dump_jsonl 42 | from scipy.special import zeta 43 | 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 46 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 47 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 48 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 49 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 50 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 51 | parser.add_argument("--tokens_to_generate", type=int, default=50, help='number of tokens to generate') 52 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 53 | parser.add_argument("--random_seed", type=int, default=42) 54 | parser.add_argument("--template", type=str, default='', help='prompt template') 55 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 56 | parser.add_argument("--coded_wordlen", type=int, default=6, help="length of synthetic word") 57 | parser.add_argument("--vocab_size", type=int, default=-1, help='synthetic vocab size to sample from') 58 | parser.add_argument("--alpha", type=float, default=2.0, help='zeta distribution alpha') 59 | parser.add_argument("--add_fewshot", action="store_true", default=False) 60 | 61 | args = parser.parse_args() 62 | random.seed(args.random_seed) 63 | np.random.seed(args.random_seed) 64 | 65 | # Load Tokenizer 66 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 67 | 68 | def generate_input_output(max_len, num_words=-1, coded_wordlen=6, vocab_size=2000, incremental=10, alpha=2.0): 69 | # generate vocab 70 | vocab = 
[''.join(random.choices(string.ascii_lowercase, k=coded_wordlen)) for _ in range(vocab_size)] 71 | while len(set(vocab)) < vocab_size: 72 | vocab.append(''.join(random.choices(string.ascii_lowercase, k=coded_wordlen))) 73 | vocab = sorted(list(set(vocab))) 74 | random.Random(args.random_seed).shuffle(vocab) 75 | vocab[0] = '...' # treat the top ranked as noise 76 | 77 | # sample words 78 | template = args.template 79 | def gen_text(num_words): 80 | k = np.arange(1, len(vocab)+1) 81 | sampled_cnt = num_words*(k**-alpha)/zeta(alpha) 82 | sampled_words = [[w] * zi for w, zi in zip(vocab, sampled_cnt.astype(int))] 83 | sampled_words = [x for wlst in sampled_words for x in wlst] 84 | random.Random(args.random_seed).shuffle(sampled_words) 85 | return template.format(context=' '.join(sampled_words), query=''), vocab[1:4] 86 | 87 | if num_words > 0: 88 | num_words = num_words 89 | text, answer = gen_text(num_words) 90 | while len(TOKENIZER.text_to_tokens(text)) > max_len: 91 | num_words -= incremental 92 | text, answer = gen_text(num_words) 93 | else: 94 | num_words = max_len // coded_wordlen # init 95 | text, answer = gen_text(num_words) 96 | while len(TOKENIZER.text_to_tokens(text)) < max_len: 97 | num_words += incremental 98 | text, answer = gen_text(num_words) 99 | num_words -= incremental 100 | text, answer = gen_text(num_words) 101 | return text, answer, num_words 102 | 103 | def sys_kwext(num_samples: int, max_seq_length: int, incremental: int = 10): 104 | write_jsons = [] 105 | tokens_to_generate = args.tokens_to_generate 106 | 107 | vocab_size = max_seq_length // 50 if args.vocab_size == -1 else args.vocab_size 108 | 109 | # get number of words 110 | input_max_len = max_seq_length 111 | _, _, num_example_words = generate_input_output(input_max_len, 112 | coded_wordlen=args.coded_wordlen, 113 | vocab_size=vocab_size, 114 | incremental=input_max_len//32, 115 | alpha=args.alpha) 116 | print('num_example_words:', num_example_words) 117 | # Generate samples 118 | for index in tqdm(range(num_samples)): 119 | 120 | # construct input 121 | input_max_len = max_seq_length 122 | input_text, answer, _ = generate_input_output(input_max_len, 123 | num_words=num_example_words, 124 | coded_wordlen=args.coded_wordlen, 125 | vocab_size=vocab_size, 126 | incremental=input_max_len//32, 127 | alpha=args.alpha) 128 | 129 | 130 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 131 | 132 | if args.remove_newline_tab: 133 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 134 | 135 | formatted_output = { 136 | 'index': index, 137 | "input": input_text, 138 | "outputs": answer, 139 | "length": length, 140 | } 141 | write_jsons.append(formatted_output) 142 | 143 | return write_jsons 144 | 145 | 146 | def main(): 147 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 148 | save_file.parent.mkdir(parents=True, exist_ok=True) 149 | write_jsons = sys_kwext(num_samples=args.num_samples, max_seq_length=args.max_seq_length, 150 | incremental=10) 151 | 152 | dump_jsonl(save_file, write_jsons) 153 | 154 | if __name__=="__main__": 155 | main() -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/json/download_paulgraham_essay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | import os 16 | import shutil 17 | import glob 18 | import json 19 | import urllib.request 20 | import html2text 21 | from bs4 import BeautifulSoup 22 | from tqdm import tqdm 23 | os.environ["PYTHONHASHSEED"] = str(42) 24 | 25 | temp_folder_repo = 'essay_repo' 26 | temp_folder_html = 'essay_html' 27 | os.makedirs(temp_folder_repo, exist_ok=True) 28 | os.makedirs(temp_folder_html, exist_ok=True) 29 | 30 | h = html2text.HTML2Text() 31 | h.ignore_images = True 32 | h.ignore_tables = True 33 | h.escape_all = True 34 | h.reference_links = False 35 | h.mark_code = False 36 | 37 | with open('PaulGrahamEssays_URLs.txt') as f: 38 | urls = [line.strip() for line in f] 39 | 40 | for url in tqdm(urls): 41 | if '.html' in url: 42 | filename = url.split('/')[-1].replace('.html', '.txt') 43 | try: 44 | with urllib.request.urlopen(url) as website: 45 | content = website.read().decode("unicode_escape", "utf-8") 46 | soup = BeautifulSoup(content, 'html.parser') 47 | specific_tag = soup.find('font') 48 | parsed = h.handle(str(specific_tag)) 49 | 50 | with open(os.path.join(temp_folder_html, filename), 'w') as file: 51 | file.write(parsed) 52 | 53 | except Exception as e: 54 | print(f"Fail download {filename}, ({e})") 55 | 56 | else: 57 | filename = url.split('/')[-1] 58 | try: 59 | with urllib.request.urlopen(url) as website: 60 | content = website.read().decode('utf-8') 61 | 62 | with open(os.path.join(temp_folder_repo, filename), 'w') as file: 63 | file.write(content) 64 | 65 | except Exception as e: 66 | print(f"Fail download {filename}, ({e})") 67 | 68 | files_repo = glob.glob(os.path.join(temp_folder_repo,'*.txt')) 69 | files_html = glob.glob(os.path.join(temp_folder_html,'*.txt')) 70 | print(f'Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`') 71 | print(f'Download {len(files_html)} essays from `http://www.paulgraham.com/`') 72 | 73 | text = "" 74 | for file in sorted(files_repo + files_html): # sort by filename to ensure text is the same 75 | with open(file, 'r') as f: 76 | text += f.read() 77 | 78 | with open('PaulGrahamEssays.json', 'w') as f: 79 | json.dump({"text": text}, f) 80 | 81 | 82 | shutil.rmtree(temp_folder_repo) 83 | shutil.rmtree(temp_folder_html) 84 | -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/json/download_qa_dataset.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O squad.json 16 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json -O hotpotqa.json 17 | -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/niah.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for needle in a haystack. 17 | 18 | python niah.py \ 19 | --save_dir=./ \ 20 | --save_name=niah_single \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type=nemo \ 23 | --max_seq_length=4096 \ 24 | --tokens_to_generate=128 \ 25 | --num_samples=10 \ 26 | --template="Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text? 
The special magic {type_needle_v} for {query} mentioned in the provided text are" 27 | """ 28 | import os 29 | import re 30 | import json 31 | import uuid 32 | import argparse 33 | import importlib 34 | import numpy as np 35 | from pathlib import Path 36 | from tqdm import tqdm 37 | import random 38 | import wonderwords 39 | import sys 40 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 41 | from utils import dump_jsonl 42 | from tokenizer import select_tokenizer 43 | import nltk 44 | nltk.download('punkt_tab') 45 | from nltk.tokenize import sent_tokenize 46 | 47 | 48 | parser = argparse.ArgumentParser() 49 | # Basic Configurations 50 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 51 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 52 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 53 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 54 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 55 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 56 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 57 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 58 | parser.add_argument("--random_seed", type=int, default=42) 59 | parser.add_argument("--template", type=str, default='', help='prompt template') 60 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 61 | 62 | # Complexity Configurations 63 | parser.add_argument("--num_needle_k", type=int, default=1) 64 | parser.add_argument("--num_needle_v", type=int, default=1) 65 | parser.add_argument("--num_needle_q", type=int, default=1) 66 | parser.add_argument("--type_haystack", type=str, default='essay', help='[Options] repeat, essay, needle.') 67 | parser.add_argument("--type_needle_k", type=str, default='words', help='[Options] numbers, words, uuids.') 68 | parser.add_argument("--type_needle_v", type=str, default='numbers', help='[Options] numbers, words, uuids.') 69 | 70 | args = parser.parse_args() 71 | random.seed(args.random_seed) 72 | np.random.seed(args.random_seed) 73 | args.num_needle_k = max(args.num_needle_k, args.num_needle_q) 74 | 75 | # Load Tokenizer 76 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 77 | 78 | # Define Needle/Haystack Format 79 | needle = "One of the special magic {type_needle_v} for {key} is: {value}." 80 | if args.type_haystack == 'essay': 81 | essay = os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/PaulGrahamEssays.json") 82 | essay = json.load(open(essay))['text'] 83 | haystack = re.sub(r'\s+', " ", essay).split(" ") 84 | elif args.type_haystack == 'repeat': 85 | haystack = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
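# 'needle' haystacks (next branch) build the filler text from distractor needle sentences with random
# keys/values; see generate_input_output below for how they are instantiated.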
86 | elif args.type_haystack == 'needle': 87 | haystack = needle 88 | else: 89 | raise NotImplementedError(f'{args.type_haystack} is not implemented.') 90 | 91 | 92 | # Words 93 | nouns = wonderwords.random_word._get_words_from_text_file("nounlist.txt") 94 | adjs = wonderwords.random_word._get_words_from_text_file("adjectivelist.txt") 95 | # verbs = wonderwords.random_word._get_words_from_text_file("verblist.txt") 96 | words = [f"{adj}-{noun}" for adj in adjs for noun in nouns] 97 | words = sorted(list(set(words))) 98 | 99 | 100 | # Positions 101 | DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int)) 102 | 103 | 104 | def generate_random_number(num_digits=7): 105 | lower_bound = 10**(num_digits - 1) 106 | upper_bound = 10**num_digits - 1 107 | return str(random.randint(lower_bound, upper_bound)) 108 | 109 | def generate_random_word(): 110 | word = random.choice(words) 111 | return word 112 | 113 | def generate_random_uuid(): 114 | return str(uuid.UUID(int=random.getrandbits(128), version=4)) 115 | 116 | def generate_random(type_needle: str): 117 | if type_needle == 'numbers': 118 | return generate_random_number() 119 | elif type_needle == 'words': 120 | return generate_random_word() 121 | elif type_needle == 'uuids': 122 | return generate_random_uuid() 123 | else: 124 | raise NotImplementedError(f'{args.type_needle} is not implemented.') 125 | 126 | def generate_input_output(num_haystack): 127 | keys, values, needles = [], [], [] 128 | for _ in range(args.num_needle_k): 129 | keys.append(generate_random(args.type_needle_k)) 130 | value = [] 131 | for _ in range(args.num_needle_v): 132 | value.append(generate_random(args.type_needle_v)) 133 | needles.append(needle.format( 134 | type_needle_v=args.type_needle_v, 135 | key=keys[-1], 136 | value=value[-1], 137 | )) 138 | values.append(value) 139 | 140 | random.Random(args.random_seed).shuffle(needles) 141 | 142 | # Context 143 | if args.type_haystack == 'essay': 144 | text = " ".join(haystack[:num_haystack]) 145 | document_sents = sent_tokenize(text.strip()) 146 | insertion_positions = [0] + \ 147 | sorted([int(len(document_sents) * (depth / 100)) for depth in random.sample(DEPTHS, len(needles))]) + \ 148 | [len(document_sents)] 149 | document_sents_list = [] 150 | for i in range(1,len(insertion_positions)): 151 | last_pos = insertion_positions[i-1] 152 | next_pos = insertion_positions[i] 153 | document_sents_list.append(" ".join(document_sents[last_pos:next_pos])) 154 | if i-1 < len(needles): 155 | document_sents_list.append(needles[i-1]) 156 | context = " ".join(document_sents_list) 157 | 158 | else: 159 | if args.type_haystack == 'repeat': 160 | sentences = [haystack] * num_haystack 161 | elif args.type_haystack == 'needle': 162 | sentences = [haystack.format( 163 | type_needle_v=args.type_needle_v, 164 | key=generate_random(args.type_needle_k), 165 | value=generate_random(args.type_needle_v), 166 | ) for _ in range(num_haystack)] 167 | 168 | 169 | indexes = sorted(random.sample(range(num_haystack), len(needles)), reverse=True) 170 | for index, element in zip(indexes, needles): 171 | sentences.insert(index, element) 172 | context = "\n".join(sentences) 173 | 174 | 175 | ## Query and Answer 176 | indices = random.sample(range(args.num_needle_k), args.num_needle_q) 177 | queries = [keys[i] for i in indices] 178 | answers = [a for i in indices for a in values[i]] 179 | query = ', '.join(queries[:-1]) + ', and ' + queries[-1] if len(queries) > 1 else queries[0] 180 | 181 | template = args.template 182 | type_needle_v 
= args.type_needle_v 183 | if args.num_needle_q * args.num_needle_v == 1: 184 | template = template.replace('Some', 'A') 185 | template = template.replace('are all', 'is') 186 | template = template.replace('are', 'is') 187 | template = template.replace('answers', 'answer') 188 | type_needle_v = type_needle_v[:-1] # remove "s" 189 | 190 | input_text = template.format( 191 | type_needle_v=type_needle_v, 192 | context=context, 193 | query=query, 194 | ) 195 | 196 | return input_text, answers 197 | 198 | 199 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 500): 200 | write_jsons = [] 201 | tokens_to_generate = args.tokens_to_generate 202 | 203 | if args.type_haystack == 'essay': 204 | incremental = 500 205 | elif args.type_haystack == 'repeat': 206 | incremental = 25 207 | elif args.type_haystack == 'needle': 208 | incremental = 25 209 | 210 | if args.type_haystack != 'essay' and args.max_seq_length < 4096: 211 | incremental = 5 212 | 213 | num_haystack = incremental 214 | 215 | total_tokens = 0 # Track the total tokens generated for the first example 216 | while total_tokens + tokens_to_generate < max_seq_length : 217 | input_text, answer = generate_input_output(num_haystack) 218 | # Calculate the number of tokens in the example 219 | total_tokens = len(TOKENIZER.text_to_tokens(input_text + ' '.join(answer))) 220 | print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Haystack: {num_haystack}') 221 | if total_tokens + tokens_to_generate > max_seq_length: 222 | num_haystack -= incremental 223 | break 224 | 225 | if args.type_haystack == 'essay' and num_haystack > len(haystack): 226 | num_haystack = len(haystack) 227 | break 228 | 229 | num_haystack += incremental 230 | 231 | print('Num haystack:', num_haystack) 232 | 233 | # Generate samples 234 | for index in tqdm(range(num_samples)): 235 | used_haystack = num_haystack 236 | while(True): 237 | try: 238 | input_text, answer = generate_input_output(used_haystack) 239 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 240 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 241 | break 242 | except: 243 | if used_haystack > incremental: 244 | used_haystack -= incremental 245 | 246 | if args.remove_newline_tab: 247 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 248 | 249 | formatted_output = { 250 | 'index': index, 251 | "input": input_text, 252 | "outputs": answer, 253 | "length": length, 254 | } 255 | write_jsons.append(formatted_output) 256 | 257 | return write_jsons 258 | 259 | 260 | def main(): 261 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 262 | save_file.parent.mkdir(parents=True, exist_ok=True) 263 | 264 | write_jsons = generate_samples( 265 | num_samples=args.num_samples, 266 | max_seq_length=args.max_seq_length, 267 | save_dir=args.save_dir 268 | ) 269 | 270 | dump_jsonl(save_file, write_jsons) 271 | 272 | if __name__ == "__main__": 273 | main() -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/qa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for QA task. 17 | 18 | python qa.py \ 19 | --save_dir=./ \ 20 | --save_name=niah_single \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type=nemo \ 23 | --max_seq_length=4096 \ 24 | --tokens_to_generate=128 \ 25 | --num_samples=10 \ 26 | --template="Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query} Answer:" 27 | """ 28 | import os 29 | import re 30 | import json 31 | import argparse 32 | from pathlib import Path 33 | from tqdm import tqdm 34 | import random 35 | import numpy as np 36 | import sys 37 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 38 | from utils import dump_jsonl 39 | from tokenizer import select_tokenizer 40 | 41 | 42 | parser = argparse.ArgumentParser() 43 | # Basic Configurations 44 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 45 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 46 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 47 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 48 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 49 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 50 | parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.') 51 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 52 | parser.add_argument("--pre_samples", type=int, default=0, help='number of samples are already generated') 53 | parser.add_argument("--random_seed", type=int, default=42) 54 | parser.add_argument("--template", type=str, required=True, help='prompt template') 55 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 56 | 57 | # Complexity Configurations 58 | parser.add_argument("--dataset", type=str, required=True, help='dataset file') 59 | 60 | args = parser.parse_args() 61 | random.seed(args.random_seed) 62 | np.random.seed(args.random_seed) 63 | 64 | # Load Tokenizer 65 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 66 | 67 | # Read SQuAD QA dataset 68 | def read_squad(file): 69 | with open(file) as f: 70 | data = json.load(f) 71 | 72 | total_docs = [p['context'] for d in data['data'] for p in d['paragraphs']] 73 | total_docs = sorted(list(set(total_docs))) 74 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)} 75 | 76 | total_qas = [] 77 | for d in data['data']: 78 | more_docs = [total_docs_dict[p['context']] for p in d['paragraphs']] 79 | for p 
in d['paragraphs']: 80 | for qas in p['qas']: 81 | if not qas['is_impossible']: 82 | total_qas.append({ 83 | 'query': qas['question'], 84 | 'outputs': [a['text'] for a in qas['answers']], 85 | 'context': [total_docs_dict[p['context']]], 86 | 'more_context': [idx for idx in more_docs if idx != total_docs_dict[p['context']]] 87 | }) 88 | 89 | return total_qas, total_docs 90 | 91 | # Read Hotpot QA dataset 92 | def read_hotpotqa(file): 93 | with open(file) as f: 94 | data = json.load(f) 95 | 96 | total_docs = [f"{t}\n{''.join(p)}" for d in data for t, p in d['context']] 97 | total_docs = sorted(list(set(total_docs))) 98 | total_docs_dict = {c: idx for idx, c in enumerate(total_docs)} 99 | 100 | total_qas = [] 101 | for d in data: 102 | total_qas.append({ 103 | 'query': d['question'], 104 | 'outputs': [d['answer']], 105 | 'context': [total_docs_dict[f"{t}\n{''.join(p)}"] for t, p in d['context']], 106 | }) 107 | 108 | return total_qas, total_docs 109 | 110 | 111 | DOCUMENT_PROMPT = "Document {i}:\n{document}" 112 | if args.dataset == 'squad': 113 | QAS, DOCS = read_squad(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/squad.json")) 114 | elif args.dataset == 'hotpotqa': 115 | QAS, DOCS = read_hotpotqa(os.path.join(os.path.dirname(os.path.abspath(__file__)), "json/hotpotqa.json")) 116 | else: 117 | raise NotImplementedError(f'{args.dataset} is not implemented.') 118 | 119 | 120 | def generate_input_output(index, num_docs): 121 | curr_q = QAS[index]['query'] 122 | curr_a = QAS[index]['outputs'] 123 | curr_docs = QAS[index]['context'] 124 | curr_more = QAS[index].get('more_context', []) 125 | if num_docs < len(DOCS): 126 | if (num_docs - len(curr_docs)) > len(curr_more): 127 | addition_docs = [i for i, d in enumerate(DOCS) if i not in curr_docs + curr_more] 128 | all_docs = curr_docs + curr_more + random.sample(addition_docs, max(0, num_docs - len(curr_docs) - len(curr_more))) 129 | else: 130 | all_docs = curr_docs + random.sample(curr_more, num_docs - len(curr_docs)) 131 | 132 | all_docs = [DOCS[idx] for idx in all_docs] 133 | else: 134 | all_docs = DOCS 135 | 136 | random.Random(args.random_seed).shuffle(all_docs) 137 | 138 | context = '\n\n'.join([DOCUMENT_PROMPT.format(i=i+1, document=d) for i, d in enumerate(all_docs)]) 139 | input_text = args.template.format( 140 | context=context, 141 | query=curr_q 142 | ) 143 | return input_text, curr_a 144 | 145 | 146 | def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10): 147 | 148 | write_jsons = [] 149 | tokens_to_generate = args.tokens_to_generate 150 | 151 | # Find the perfect num_docs 152 | num_docs = incremental 153 | 154 | total_tokens = 0 # Track the total tokens generated for this example 155 | while total_tokens + tokens_to_generate < max_seq_length : 156 | input_text, answer = generate_input_output(0, num_docs) 157 | # Calculate the number of tokens in the example 158 | total_tokens = len(TOKENIZER.text_to_tokens(input_text + f' {answer}')) 159 | print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Docs: {num_docs}') 160 | if total_tokens + tokens_to_generate > max_seq_length: 161 | num_docs -= incremental 162 | break 163 | 164 | num_docs += incremental 165 | if num_docs > len(DOCS): 166 | num_docs = len(DOCS) 167 | break 168 | print('Number of documents:', num_docs) 169 | 170 | # Generate samples 171 | for index in tqdm(range(num_samples)): 172 | used_docs = num_docs 173 | while(True): 174 | try: 175 | input_text, answer = 
generate_input_output(index + args.pre_samples, used_docs) 176 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate 177 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 178 | break 179 | except: 180 | if used_docs > incremental: 181 | used_docs -= incremental 182 | 183 | if args.remove_newline_tab: 184 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 185 | 186 | formatted_output = { 187 | "index": index, 188 | "input": input_text, 189 | "outputs": answer, 190 | "length": length 191 | } 192 | write_jsons.append(formatted_output) 193 | 194 | return write_jsons 195 | 196 | 197 | def main(): 198 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 199 | save_file.parent.mkdir(parents=True, exist_ok=True) 200 | write_jsons = generate_samples( 201 | num_samples=args.num_samples, 202 | max_seq_length=args.max_seq_length, 203 | save_dir=args.save_dir 204 | ) 205 | 206 | dump_jsonl(save_file, write_jsons) 207 | 208 | if __name__=="__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /benchmark/ruler/data/synthetic/variable_tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """ 16 | Create a dataset jsonl file for variable tracking. 17 | 18 | python variable_tracking.py \ 19 | --save_dir=./ \ 20 | --save_name=vt \ 21 | --tokenizer_path=tokenizer.model \ 22 | --tokenizer_type nemo \ 23 | --max_seq_length 4096 \ 24 | --tokens_to_generate 30 \ 25 | --num_samples 10 \ 26 | --random_seed 42 \ 27 | --num_chains 1 --num_hops 4 \ 28 | --template "[INST] Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n{context}\nQuestion: Find all variables that are assigned the value {query} in the text above. 
[/INST] Answer: According to the chain(s) of variable assignment in the text above, {num_v} variables are assgined the value {query}, they are: " 29 | """ 30 | import os 31 | import argparse 32 | from pathlib import Path 33 | from tqdm import tqdm 34 | import random 35 | import string 36 | from constants import TASKS 37 | import sys 38 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 39 | from tokenizer import select_tokenizer 40 | from utils import dump_jsonl 41 | import numpy as np 42 | 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset') 45 | parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file') 46 | parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test') 47 | parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model') 48 | parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.') 49 | parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.') 50 | parser.add_argument("--tokens_to_generate", type=int, default=120, help='number of tokens to generate') 51 | parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate') 52 | parser.add_argument("--random_seed", type=int, default=42) 53 | parser.add_argument("--template", type=str, default='', help='prompt template') 54 | parser.add_argument("--remove_newline_tab", action='store_true', help='remove `\n` and `\t` in all strings.') 55 | 56 | parser.add_argument("--num_chains", type=int, default=1, help='number of inserted variable chains') 57 | parser.add_argument("--num_hops", type=int, default=4, help='number of hops in each chain') 58 | parser.add_argument("--add_fewshot", action="store_true", default=False) 59 | 60 | args = parser.parse_args() 61 | random.seed(args.random_seed) 62 | np.random.seed(args.random_seed) 63 | 64 | # Load Tokenizer 65 | TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path) 66 | 67 | def generate_chains(num_chains, num_hops, is_icl=False): 68 | 69 | vars_all = [] 70 | k = 5 if not is_icl else 3 71 | num_hops = num_hops if not is_icl else min(10, num_hops) 72 | vars_all = [''.join(random.choices(string.ascii_uppercase, k=k)).upper() for _ in range((num_hops+1) * num_chains)] 73 | while len(set(vars_all)) < num_chains * (num_hops+1): 74 | vars_all.append(''.join(random.choices(string.ascii_uppercase, k=k)).upper()) 75 | 76 | vars_ret = [] 77 | chains_ret = [] 78 | for i in range(0, len(vars_all), num_hops+1): 79 | this_vars = vars_all[i:i+num_hops+1] 80 | vars_ret.append(this_vars) 81 | this_chain = [f"VAR {this_vars[0]} = {np.random.randint(10000, 99999)}"] 82 | for j in range(num_hops): 83 | this_chain.append(f"VAR {this_vars[j+1]} = VAR {this_vars[j]} ") 84 | chains_ret.append(this_chain) 85 | return vars_ret, chains_ret 86 | 87 | def generate_input_output(num_noises, num_chains, num_hops, is_icl=False): 88 | 89 | vars, chains = generate_chains(num_chains, num_hops, is_icl=is_icl) 90 | 91 | noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n" 92 | 93 | # Create a list of the repeated noise 94 | sentences = [noise] * num_noises 95 | if len(sentences) <= len(chains[0]): 96 | sentences = [n + '.' 
if len(n.strip()) > 0 else n for n in [x for noise in sentences for x in noise.split('.')] ] 97 | try: 98 | assert len(sentences) > len(chains[0]), "Noises too short, unable to generate data" 99 | except: 100 | print("reduces chain length for not enough noises") 101 | chains = [chain[:len(sentences)-1] for chain in chains] 102 | # sample random positions to insert variable assignment 103 | for chain_i in chains: 104 | # sample random positions (sorted) to insert variable assignment 105 | positions = list(sorted(random.sample(range(len(sentences)), len(chain_i)))) 106 | for insert_pi, j in zip(positions, range(len(chain_i))): 107 | sentences.insert(insert_pi+j, chain_i[j]) 108 | 109 | # Insert the passkey sentence at the random position 110 | context = " ".join(sentences) 111 | context = context.replace(". \n", ".\n") 112 | 113 | template = args.template 114 | if is_icl: 115 | # remove model template 116 | cutoff = template.index(TASKS['variable_tracking']['template'][:20]) 117 | cutoff_ans = template.index(TASKS['variable_tracking']['answer_prefix'][:10]) 118 | template = ' '.join(template[cutoff:cutoff_ans].split()[:-1]) + template[cutoff_ans:] 119 | 120 | value = chains[0][0].split("=")[-1].strip() 121 | input_text = template.format( 122 | context=context, 123 | query=value, 124 | num_v=num_hops+1 125 | ) 126 | 127 | return input_text, vars[0] 128 | 129 | 130 | def sys_vartrack_w_noise_random(num_samples: int, max_seq_length: int, incremental: int = 10, 131 | num_chains: int = 1, num_hops: int = 4, 132 | add_fewshot: bool = True, 133 | icl_example: str = None): 134 | write_jsons = [] 135 | tokens_to_generate = args.tokens_to_generate 136 | 137 | # Find the perfect num_noises 138 | num_noises = incremental 139 | 140 | total_tokens = 0 # Track the total tokens generated for this example 141 | example_tokens = 0 142 | if add_fewshot and (icl_example is not None): 143 | icl_example_out = ' '.join(icl_example['outputs']) 144 | icl_example = icl_example['input'] + " " + icl_example_out + '\n\n' 145 | example_tokens = len(TOKENIZER.text_to_tokens(icl_example)) 146 | 147 | while total_tokens + tokens_to_generate + example_tokens < max_seq_length : 148 | input_text, answer = generate_input_output(num_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)) 149 | # Calculate the number of tokens in the example 150 | total_tokens = len(TOKENIZER.text_to_tokens(input_text + f' {answer}')) 151 | print(f'Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate + example_tokens} | Noises: {num_noises}') 152 | if total_tokens + tokens_to_generate + example_tokens > max_seq_length: 153 | num_noises -= incremental 154 | break 155 | num_noises += incremental 156 | print('Num noises:', num_noises) 157 | 158 | # Generate samples 159 | for index in tqdm(range(num_samples)): 160 | used_noises = num_noises 161 | while(True): 162 | try: 163 | input_text, answer = generate_input_output(num_noises, num_chains, num_hops, is_icl=add_fewshot & (icl_example is None)) 164 | length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate + example_tokens 165 | assert length <= max_seq_length, f"{length} exceeds max_seq_length." 
166 | break 167 | except: 168 | if used_noises > incremental: 169 | used_noises -= incremental 170 | 171 | if add_fewshot and (icl_example is not None): 172 | # insert icl_example between model template and input 173 | cutoff = input_text.index(TASKS['variable_tracking']['template'][:20]) 174 | input_text = input_text[:cutoff] + ' ' + icl_example + '\n\n' + input_text[cutoff:] 175 | if args.remove_newline_tab: 176 | input_text = ' '.join(input_text.replace('\n', ' ').replace('\t', ' ').strip().split()) 177 | 178 | formatted_output = { 179 | 'index': index, 180 | "input": input_text, 181 | "outputs": answer, 182 | "length": length, 183 | } 184 | write_jsons.append(formatted_output) 185 | 186 | return write_jsons 187 | 188 | 189 | def main(): 190 | save_file = args.save_dir / f'{args.save_name}' / f'{args.subset}.jsonl' 191 | save_file.parent.mkdir(parents=True, exist_ok=True) 192 | 193 | icl_example = sys_vartrack_w_noise_random(num_samples=1, 194 | max_seq_length=500, 195 | incremental=5, 196 | num_chains=args.num_chains, 197 | num_hops=args.num_hops)[0] 198 | write_jsons = sys_vartrack_w_noise_random(num_samples=args.num_samples, 199 | max_seq_length=args.max_seq_length, 200 | num_chains=args.num_chains, 201 | num_hops=args.num_hops, 202 | icl_example=icl_example) 203 | 204 | dump_jsonl(save_file, write_jsons) 205 | 206 | if __name__=="__main__": 207 | main() -------------------------------------------------------------------------------- /benchmark/ruler/data/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | from typing import List 18 | from transformers import AutoTokenizer 19 | 20 | 21 | def select_tokenizer(tokenizer_type, tokenizer_path): 22 | if tokenizer_type == 'hf': 23 | return HFTokenizer(model_path=tokenizer_path) 24 | 25 | 26 | class HFTokenizer: 27 | """ 28 | Tokenizer from HF models 29 | """ 30 | def __init__(self, model_path) -> None: 31 | 32 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 33 | 34 | def text_to_tokens(self, text: str) -> List[str]: 35 | tokens = self.tokenizer.tokenize(text) 36 | return tokens 37 | 38 | def tokens_to_text(self, tokens: List[int]) -> str: 39 | text = self.tokenizer.convert_tokens_to_string(tokens) 40 | return text -------------------------------------------------------------------------------- /benchmark/ruler/data/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def dump_jsonl(fname, data): 5 | with open(fname, "w", encoding="utf8") as fout: 6 | for line in data: 7 | fout.write(json.dumps(line, ensure_ascii=False) + "\n") 8 | 9 | -------------------------------------------------------------------------------- /benchmark/ruler/eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Get summary.csv with score and null predictions amount. 
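The HFTokenizer wrapper above is what the data generators use to measure a sample's length against the sequence budget. A hedged usage sketch follows; the model path and the budget numbers are examples only, and it assumes the module is importable as `tokenizer` (e.g. when run from `benchmark/ruler/data`).

```
# Illustrative only: assumes benchmark/ruler/data is on sys.path.
from tokenizer import select_tokenizer

TOKENIZER = select_tokenizer('hf', 'meta-llama/Llama-3.1-8B-Instruct')  # example model

sample = "VAR ABCDE = 12345. The grass is green and the sky is blue."
tokens = TOKENIZER.text_to_tokens(sample)        # list of token strings
print(len(tokens), "tokens")

# The generators keep growing a sample until this kind of budget check fails:
max_seq_length = 4096       # assumed value for illustration
tokens_to_generate = 30     # assumed value for illustration
assert len(tokens) + tokens_to_generate <= max_seq_length
```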
17 | 18 | Running 19 | ``` 20 | python evaluate.py \ 21 | --data_dir /path/to/your/prediction_jsonl_folder \ 22 | --benchmark synthetic 23 | ``` 24 | """ 25 | 26 | import re 27 | import os 28 | import sys 29 | import argparse 30 | import nltk 31 | try: 32 | nltk.data.find('tokenizers/punkt') 33 | except LookupError: 34 | nltk.download('punkt') 35 | 36 | import pandas as pd 37 | import importlib 38 | import yaml 39 | from pathlib import Path 40 | from tqdm import tqdm 41 | from collections import defaultdict 42 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 43 | from utils import load_data, dump_jsonl 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--data_dir", type=str, required=True, help='path to the prediction jsonl files') 47 | parser.add_argument("--benchmark", type=str, default='synthetic', help='Options: [synthetic]') 48 | parser.add_argument("--verbose", type=int, default=0, help='how many lines you want to display.') 49 | args = parser.parse_args() 50 | 51 | 52 | def postprocess_pred(predict_str: str, task_config: dict): 53 | 54 | predict_str = predict_str.strip() 55 | 56 | # Remove all non-printable characters 57 | np_pattern = re.compile(r'[\x00-\x1f]') 58 | predict_str = np_pattern.sub('\n', predict_str).strip() 59 | 60 | return predict_str 61 | 62 | 63 | def get_pred_and_ref( 64 | predictions_file: str, 65 | task_config: dict, 66 | input_field: str = 'input', 67 | references_field: str = 'outputs', 68 | prediction_field: str = 'pred', 69 | metadata_field: str = 'others', 70 | ): 71 | lines = load_data(predictions_file) 72 | 73 | inputs = [] 74 | predicts = [] 75 | references = [] 76 | indices = [] 77 | 78 | for line in tqdm(lines): 79 | input = line[input_field] 80 | predict = line[prediction_field] 81 | predict = postprocess_pred(predict, task_config) 82 | reference = line.get(references_field, [line.get('output', '')]) 83 | index = line[metadata_field].get('id', line['index']) 84 | 85 | inputs.append(input) 86 | predicts.append(predict) 87 | references.append(reference) 88 | indices.append(index) 89 | 90 | return inputs, predicts, references, indices 91 | 92 | def run_evaluation_per_task(task_config: dict, predictions_file: str, verbose: int = 0): 93 | inputs, predicts, references, indices = get_pred_and_ref( 94 | predictions_file=predictions_file, 95 | task_config=task_config, 96 | ) 97 | 98 | task_nulls = f'{sum([len(x)==0 for x in predicts])}/{len(predicts)}' 99 | 100 | if len(references) > 0 and references[0][0] is not None: 101 | task_score = task_config['metric_fn'](predicts, references) 102 | else: 103 | task_score = 0.0 104 | 105 | if verbose != 0: 106 | print('=' * 40) 107 | for i, (input, reference, predict) in enumerate(zip(inputs, references, predicts)): 108 | print(f'Input : {input}') 109 | print(f'Reference : {reference}') 110 | print(f'Prediction: {predict}') 111 | print('=' * 40) 112 | if i > verbose: 113 | break 114 | 115 | return task_score, task_nulls, predicts, indices 116 | 117 | 118 | def write_evaluation(results: dict): 119 | tasks = list(results.keys()) 120 | score = [results[task]['score'] for task in tasks] 121 | nulls = [results[task]['nulls'] for task in tasks] 122 | dfs = [ 123 | ['Tasks'] + tasks, 124 | ['Score'] + score, 125 | ['Nulls'] + nulls, 126 | ] 127 | 128 | output_file = os.path.join(args.data_dir, 'summary.csv' if len(tasks) > 1 else f'summary-{tasks[0]}.csv') 129 | df = pd.DataFrame(dfs) 130 | df.to_csv(output_file, index=False) 131 | 
print('\n=============================================\n') 132 | print(df) 133 | print(f'\nSaved eval results to {output_file}') 134 | 135 | 136 | def write_submission(results: dict): 137 | COLUMNS = ["Task", "ID", "Prediction"] 138 | dfs = pd.DataFrame(columns=COLUMNS, data=[]) 139 | 140 | for task, result in results.items(): 141 | df = pd.DataFrame({ 142 | 'Task': task, 143 | 'ID': result['indices'], 144 | 'Prediction': result['predicts'] 145 | }) 146 | dfs = pd.concat((dfs, df[COLUMNS])) 147 | 148 | output_file = os.path.join(args.data_dir, 'submission.csv') 149 | dfs = dfs.reset_index(drop=True) 150 | dfs.to_csv(output_file, index=False) 151 | print(f'\nSaved submission results to {output_file}') 152 | 153 | 154 | def aggregate_chunk(folder): 155 | jsonl_files = [file for file in os.listdir(folder) if Path(file).suffix == '.jsonl' ] 156 | chunk_files = sorted([file for file in jsonl_files if re.match(r'.*[^_]+-\d+\.jsonl', file)]) 157 | chunk_files_dict = defaultdict(list) 158 | for file in chunk_files: 159 | task = '-'.join(file.split('-')[:-1]) 160 | chunk_files_dict[task].append(file) 161 | 162 | for task, files in chunk_files_dict.items(): 163 | lines = [] 164 | for file in sorted(files): 165 | file = os.path.join(folder, file) 166 | lines += load_data(file) 167 | os.remove(file) # Remove chunk files 168 | dump_jsonl(os.path.join(folder, f'{task}.jsonl'), lines) 169 | 170 | 171 | def main(): 172 | curr_folder = os.path.dirname(os.path.abspath(__file__)) 173 | 174 | try: 175 | module = importlib.import_module(f"{args.benchmark}.constants") 176 | except ImportError: 177 | print(f"Module eval.{args.benchmark}.constants not found.") 178 | 179 | tasks_base = module.TASKS 180 | with open(os.path.join(curr_folder, f"../{args.benchmark}.yaml"), "r") as f: 181 | tasks_customized = yaml.safe_load(f) 182 | 183 | 184 | TASKS = tasks_customized 185 | for _, config in TASKS.items(): 186 | config.update(tasks_base[config['task']]) 187 | 188 | print(f"Total tasks: {list(TASKS.keys())}") 189 | 190 | # Aggregate all prediction files 191 | aggregate_chunk(args.data_dir) 192 | 193 | # Get scores and nulls 194 | jsonl_files = [file for file in os.listdir(args.data_dir) if Path(file).suffix == '.jsonl'] 195 | eval_results = {} 196 | subm_results = {} 197 | 198 | 199 | for task, config in TASKS.items(): 200 | 201 | if f'{task}.jsonl' not in jsonl_files: 202 | print(f'Prediction file {task}.jsonl is not found.') 203 | continue 204 | 205 | print(f'Evaluate task {task}...') 206 | task_score, task_nulls, predicts, indices = run_evaluation_per_task( 207 | predictions_file=os.path.join(args.data_dir, f'{task}.jsonl'), 208 | task_config=config, 209 | ) 210 | eval_results[task] = { 211 | 'score': task_score, 212 | 'nulls': task_nulls, 213 | } 214 | subm_results[task] = { 215 | 'predicts': predicts, 216 | 'indices':indices, 217 | } 218 | 219 | # Write to csv 220 | write_evaluation(eval_results) 221 | write_submission(subm_results) 222 | 223 | if __name__ == '__main__': 224 | main() -------------------------------------------------------------------------------- /benchmark/ruler/eval/synthetic/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
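To clarify the record format that `get_pred_and_ref` and `run_evaluation_per_task` above expect, here is a sketch of a single prediction line and of how the null count is derived; the field values are invented, only the field names follow the code.

```
import json

# One line of a prediction .jsonl file, as read by get_pred_and_ref():
record = {
    "index": 0,
    "input": "... long context ... What is the magic number?",
    "outputs": ["12345"],                       # references_field
    "pred": " The magic number is 12345.",      # prediction_field
    "others": {"id": 0},                        # metadata_field
}
line = json.dumps(record, ensure_ascii=False)

# Null predictions are simply empty strings after post-processing:
predicts = [record["pred"].strip(), ""]
task_nulls = f"{sum(len(p) == 0 for p in predicts)}/{len(predicts)}"   # -> "1/2"

# The score itself comes from the task's metric_fn (see constants.py below),
# e.g. string_match_all for the needle and variable-tracking tasks.
```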
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Add a new task: 17 | 18 | TASK_NAME: { 19 | 'metric_fn': the metric function with input (predictions: [str], references: [[str]]) to compute score. 20 | } 21 | """ 22 | 23 | 24 | def string_match_part(preds, refs): 25 | score = sum([max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) for pred, ref in zip(preds, refs)]) / len(preds) * 100 26 | return round(score, 2) 27 | 28 | def string_match_all(preds, refs): 29 | score = sum([sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) for pred, ref in zip(preds, refs)]) / len(preds) * 100 30 | return round(score, 2) 31 | 32 | 33 | TASKS = { 34 | 'niah': { 35 | 'metric_fn': string_match_all, 36 | }, 37 | 'variable_tracking': { 38 | 'metric_fn': string_match_all, 39 | }, 40 | 'common_words_extraction': { 41 | 'metric_fn': string_match_all, 42 | }, 43 | 'freq_words_extraction': { 44 | 'metric_fn': string_match_all 45 | }, 46 | 'qa': { 47 | 'metric_fn': string_match_part, 48 | }, 49 | } 50 | -------------------------------------------------------------------------------- /benchmark/ruler/eval/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def dump_jsonl(fname, data): 5 | with open(fname, "w", encoding="utf8") as fout: 6 | for line in data: 7 | fout.write(json.dumps(line, ensure_ascii=False) + "\n") 8 | 9 | def iter_jsonl(fname, cnt=None): 10 | i = 0 11 | with open(fname, "r") as fin: 12 | for line in fin: 13 | if i == cnt: 14 | break 15 | yield json.loads(line) 16 | i += 1 17 | 18 | def load_data(fname): 19 | return list(iter_jsonl(fname)) -------------------------------------------------------------------------------- /benchmark/ruler/pred/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def iter_jsonl(fname, cnt=None): 4 | i = 0 5 | with open(fname, "r") as fin: 6 | for line in fin: 7 | if i == cnt: 8 | break 9 | yield json.loads(line) 10 | i += 1 11 | 12 | def load_data(fname): 13 | return list(iter_jsonl(fname)) 14 | -------------------------------------------------------------------------------- /benchmark/ruler/ruler_config_models.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
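A small worked example may help distinguish the two metrics defined above: `string_match_all` averages per-sample recall over all reference strings, while `string_match_part` only requires that at least one reference appear in the prediction. The toy predictions are invented.

```
# Assumes string_match_all / string_match_part from constants.py above are in scope.
preds = ["the answers are 12 and 34", "I do not know"]
refs  = [["12", "34"], ["56"]]

# string_match_all: sample 1 matches 2/2 refs, sample 2 matches 0/1
# -> (1.0 + 0.0) / 2 * 100 = 50.0
print(string_match_all(preds, refs))    # 50.0
print(string_match_part(preds, refs))   # 50.0 (same here)

# They differ when only some of a sample's references are found:
preds2 = ["the answer is 12"]
refs2  = [["12", "34"]]
print(string_match_all(preds2, refs2))   # 50.0  (1 of 2 references matched)
print(string_match_part(preds2, refs2))  # 100.0 (at least one reference matched)
```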
14 | 15 | MODEL_SELECT() { 16 | MODEL_NAME=$1 17 | 18 | case $MODEL_NAME in 19 | qwen2.5-7b) 20 | MODEL_PATH="Qwen/Qwen2.5-7B-Instruct" 21 | MODEL_TEMPLATE_TYPE="meta-chat" 22 | MODEL_FRAMEWORK="hf" 23 | ;; 24 | llama-3-8b-1048k) 25 | MODEL_PATH="gradientai/Llama-3-8B-Instruct-Gradient-1048k" 26 | MODEL_TEMPLATE_TYPE="meta-chat" 27 | MODEL_FRAMEWORK="hf" 28 | ;; 29 | llama-3.1-8b) 30 | MODEL_PATH="meta-llama/Llama-3.1-8B-Instruct" 31 | MODEL_TEMPLATE_TYPE="meta-chat" 32 | MODEL_FRAMEWORK="hf" 33 | ;; 34 | qwen2.5-72b) 35 | MODEL_PATH="Qwen/Qwen2.5-72B-Instruct" 36 | MODEL_TEMPLATE_TYPE="meta-chat" 37 | MODEL_FRAMEWORK="hf" 38 | ;; 39 | esac 40 | 41 | 42 | TOKENIZER_PATH=${MODEL_PATH} 43 | TOKENIZER_TYPE="hf" 44 | 45 | echo "$MODEL_PATH:$MODEL_TEMPLATE_TYPE:$MODEL_FRAMEWORK:$TOKENIZER_PATH:$TOKENIZER_TYPE" 46 | } -------------------------------------------------------------------------------- /benchmark/ruler/ruler_config_tasks.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | REMOVE_NEWLINE_TAB=false 16 | STOP_WORDS="" 17 | 18 | if [ -z "${STOP_WORDS}" ]; then 19 | STOP_WORDS="" 20 | else 21 | STOP_WORDS="--stop_words \"${STOP_WORDS}\"" 22 | fi 23 | 24 | if [ "${REMOVE_NEWLINE_TAB}" = false ]; then 25 | REMOVE_NEWLINE_TAB="" 26 | else 27 | REMOVE_NEWLINE_TAB="--remove_newline_tab" 28 | fi 29 | 30 | # task name in `synthetic.yaml` 31 | synthetic=( 32 | "niah_single_1" 33 | "niah_single_2" 34 | "niah_single_3" 35 | "niah_multikey_1" 36 | "niah_multikey_2" 37 | "niah_multikey_3" 38 | "niah_multivalue" 39 | "niah_multiquery" 40 | "vt" 41 | "cwe" 42 | "fwe" 43 | "qa_1" 44 | "qa_2" 45 | ) 46 | -------------------------------------------------------------------------------- /benchmark/ruler/ruler_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | if [ $# -ne 8 ]; then 18 | echo "Usage: $0 $1 $2 $3 $4 $5 $6 $7 " 19 | exit 1 20 | fi 21 | 22 | # Root Directories 23 | ROOT_DIR="./ruler_eval_result" # the path that stores generated task samples and model predictions. 
24 | 25 | NUM_SAMPLES=200 26 | MAX_SEQ_LENGTH=${4} 27 | ATTN_TYPE=${3} 28 | DEVICE=auto 29 | BUDGET_RATIO=${7} 30 | ESTIMATE_RATIO=${8} 31 | 32 | # Model and Tokenizer 33 | source ruler_config_models.sh 34 | MODEL_NAME=${1} 35 | MODEL_CONFIG=$(MODEL_SELECT ${MODEL_NAME}) 36 | IFS=":" read MODEL_NAME MODEL_TEMPLATE_TYPE MODEL_FRAMEWORK TOKENIZER_PATH TOKENIZER_TYPE <<< "$MODEL_CONFIG" 37 | if [ -z "${MODEL_NAME}" ]; then 38 | echo "Model: ${MODEL_NAME} is not supported" 39 | exit 1 40 | fi 41 | 42 | # Benchmark and Tasks 43 | source ruler_config_tasks.sh 44 | BENCHMARK=${2} 45 | declare -n TASKS=$BENCHMARK 46 | if [ -z "${TASKS}" ]; then 47 | echo "Benchmark: ${BENCHMARK} is not supported" 48 | exit 1 49 | fi 50 | 51 | # Start client (prepare data / call model API / obtain final metrics) 52 | 53 | RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME}/${BENCHMARK}/${MAX_SEQ_LENGTH}/${ATTN_TYPE}" 54 | DATA_DIR="${RESULTS_DIR}/data" 55 | PRED_DIR="${RESULTS_DIR}/pred" 56 | mkdir -p ${DATA_DIR} 57 | mkdir -p ${PRED_DIR} 58 | 59 | TASK=${5} 60 | python -u data/prepare.py \ 61 | --save_dir ${DATA_DIR} \ 62 | --benchmark ${BENCHMARK} \ 63 | --task ${TASK} \ 64 | --tokenizer_path ${TOKENIZER_PATH} \ 65 | --tokenizer_type ${TOKENIZER_TYPE} \ 66 | --max_seq_length ${MAX_SEQ_LENGTH} \ 67 | --model_template_type ${MODEL_TEMPLATE_TYPE} \ 68 | --num_samples ${NUM_SAMPLES} \ 69 | ${REMOVE_NEWLINE_TAB} 70 | 71 | DTYPE=${6} 72 | python -u pred/call_api.py \ 73 | --model_name ${MODEL_NAME} \ 74 | --attn_type ${ATTN_TYPE} \ 75 | --max_len ${MAX_SEQ_LENGTH} \ 76 | --batch_size 1 \ 77 | --data_dir ${DATA_DIR} \ 78 | --save_dir ${PRED_DIR} \ 79 | --benchmark ${BENCHMARK} \ 80 | --task ${TASK} \ 81 | --dtype ${DTYPE} \ 82 | --server_type ${MODEL_FRAMEWORK} \ 83 | --device ${DEVICE} \ 84 | --budget_ratio ${BUDGET_RATIO} \ 85 | --estimate_ratio ${ESTIMATE_RATIO} \ 86 | --synthetic_len ${MAX_SEQ_LENGTH} \ 87 | 88 | python -u eval/evaluate.py \ 89 | --data_dir ${PRED_DIR} \ 90 | --benchmark ${BENCHMARK} 91 | 92 | -------------------------------------------------------------------------------- /benchmark/ruler/synthetic.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
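For orientation, the output layout produced by `ruler_run.sh` above can be summarized with a short sketch; the concrete argument values are examples only, and the exact file names inside `data/` depend on the `--save_name`/`--subset` arguments that `prepare.py` receives.

```
from pathlib import Path

# Mirrors RESULTS_DIR / DATA_DIR / PRED_DIR in ruler_run.sh (example arguments).
root, model, benchmark, max_len, attn = "ruler_eval_result", "llama-3.1-8b", "synthetic", 122880, "RetroInfer"
results_dir = Path(root) / model / benchmark / str(max_len) / attn
data_dir, pred_dir = results_dir / "data", results_dir / "pred"

# prepare.py writes generated samples under data_dir, call_api.py writes
# predictions under pred_dir, and evaluate.py reads pred_dir and leaves
# summary.csv / submission.csv there.
print(data_dir)   # ruler_eval_result/llama-3.1-8b/synthetic/122880/RetroInfer/data
```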
14 | 15 | niah_single_1: 16 | task: niah 17 | args: 18 | type_haystack: repeat 19 | type_needle_k: words 20 | type_needle_v: numbers 21 | num_needle_k: 1 22 | num_needle_v: 1 23 | num_needle_q: 1 24 | 25 | niah_single_2: 26 | task: niah 27 | args: 28 | type_haystack: essay 29 | type_needle_k: words 30 | type_needle_v: numbers 31 | num_needle_k: 1 32 | num_needle_v: 1 33 | num_needle_q: 1 34 | 35 | niah_single_3: 36 | task: niah 37 | args: 38 | type_haystack: essay 39 | type_needle_k: words 40 | type_needle_v: uuids 41 | num_needle_k: 1 42 | num_needle_v: 1 43 | num_needle_q: 1 44 | 45 | niah_multikey_1: 46 | task: niah 47 | args: 48 | type_haystack: essay 49 | type_needle_k: words 50 | type_needle_v: numbers 51 | num_needle_k: 4 52 | num_needle_v: 1 53 | num_needle_q: 1 54 | 55 | niah_multikey_2: 56 | task: niah 57 | args: 58 | type_haystack: needle 59 | type_needle_k: words 60 | type_needle_v: numbers 61 | num_needle_k: 1 62 | num_needle_v: 1 63 | num_needle_q: 1 64 | 65 | niah_multikey_3: 66 | task: niah 67 | args: 68 | type_haystack: needle 69 | type_needle_k: uuids 70 | type_needle_v: uuids 71 | num_needle_k: 1 72 | num_needle_v: 1 73 | num_needle_q: 1 74 | 75 | niah_multivalue: 76 | task: niah 77 | args: 78 | type_haystack: essay 79 | type_needle_k: words 80 | type_needle_v: numbers 81 | num_needle_k: 1 82 | num_needle_v: 4 83 | num_needle_q: 1 84 | 85 | niah_multiquery: 86 | task: niah 87 | args: 88 | type_haystack: essay 89 | type_needle_k: words 90 | type_needle_v: numbers 91 | num_needle_k: 1 92 | num_needle_v: 1 93 | num_needle_q: 4 94 | 95 | vt: 96 | task: variable_tracking 97 | args: 98 | num_chains: 1 99 | num_hops: 4 100 | 101 | cwe: 102 | task: common_words_extraction 103 | args: 104 | freq_cw: 30 105 | freq_ucw: 3 106 | num_cw: 10 107 | 108 | fwe: 109 | task: freq_words_extraction 110 | args: 111 | alpha: 2.0 112 | 113 | qa_1: 114 | task: qa 115 | args: 116 | dataset: squad 117 | 118 | qa_2: 119 | task: qa 120 | args: 121 | dataset: hotpotqa -------------------------------------------------------------------------------- /cache_hub/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .flash_attn_cache import flash_attn_cache 3 | from .retroinfer_cache import retroinfer_cache -------------------------------------------------------------------------------- /cache_hub/cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KV_Cache: 5 | """ 6 | A class representing the KV Cache. 7 | """ 8 | 9 | def __init__( 10 | self, 11 | layer_num: int, 12 | batch_size: int, 13 | max_length: int, 14 | num_key_value_heads: int, 15 | num_heads: int, 16 | head_dim: int, 17 | dtype: torch.dtype, 18 | layer_mapping: dict, 19 | num_gpus: int, 20 | model_size: int 21 | ) -> None: 22 | """ Initializes the KV Cache. 
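Each entry in `synthetic.yaml` above names a base task plus its generation arguments; at evaluation time `evaluate.py` (shown earlier) merges the entry with the base task's `metric_fn`. A small sketch of that merge, assuming it is run from `benchmark/ruler/eval` so the same import works:

```
import yaml
from synthetic.constants import TASKS as tasks_base   # as evaluate.py loads it via importlib

tasks_customized = yaml.safe_load("""
vt:
  task: variable_tracking
  args:
    num_chains: 1
    num_hops: 4
""")

TASKS = tasks_customized
for _, config in TASKS.items():
    config.update(tasks_base[config["task"]])   # attach metric_fn from the base definition

# TASKS["vt"] now holds both the generation args and the metric_fn used for scoring.
```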
23 | Args: 24 | layer_num (int) 25 | batch_size (int) 26 | num_key_value_heads (int) 27 | num_heads (int) 28 | max_length (int) 29 | head_dim (int) 30 | dtype (torch.dtype) 31 | layer_mapping (dict) 32 | num_gpus (int) 33 | model_size (int) 34 | """ 35 | 36 | self.layer_num = layer_num 37 | self.batch_size = batch_size 38 | self.max_length = max_length 39 | self.kv_head = num_key_value_heads 40 | self.num_heads = num_heads 41 | self.head_dim = head_dim 42 | self.dtype = dtype 43 | self.layer_mapping = layer_mapping 44 | self.context = 0 45 | 46 | # estimate free GPU memory when prefilling 47 | self.num_gpus = num_gpus 48 | self.model_size = model_size 49 | # total gpu memory 50 | total_gpu_memory = self.num_gpus * torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024 51 | # model weight consumption 52 | model_weight_consumption = self.model_size * 2 53 | # prefill consumption for single GPU 54 | prefill_consumption = self.num_heads * self.max_length * self.head_dim * 2 / 1024 / 1024 / 1024 # hidden 55 | prefill_consumption += self.num_heads * self.max_length * self.head_dim * 2 / 1024 / 1024 / 1024 # residual 56 | prefill_consumption += (self.num_heads + 2*self.kv_head) * self.max_length * self.head_dim * 2 / 1024 / 1024 / 1024 # qkv 57 | prefill_consumption += 4 * self.num_heads * self.max_length * self.head_dim * 2 / 1024 / 1024 / 1024 # temp 58 | # free memory during prefill 59 | self.free_memory = total_gpu_memory - model_weight_consumption - prefill_consumption*self.num_gpus -------------------------------------------------------------------------------- /cache_hub/flash_attn_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .cache import KV_Cache 3 | 4 | 5 | class flash_attn_cache(KV_Cache): 6 | """ 7 | A class representing the KV Cache of Full flash-attn. 
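The `KV_Cache` constructor in cache.py above estimates how much GPU memory is left after model weights and prefill activations. Below is a worked version of that arithmetic with illustrative numbers; it assumes `model_size` is the parameter count in billions (so `model_size * 2` approximates the fp16/bf16 weight footprint in GB), and every concrete value is an assumption, not a repository constant.

```
# Illustrative numbers only.
num_gpus, total_per_gpu_gb = 1, 80.0            # e.g. a single 80 GB GPU
model_size = 8                                   # billions of parameters (assumed unit)
num_heads, kv_head, head_dim = 32, 8, 128
max_length = 120_000

total_gpu_memory = num_gpus * total_per_gpu_gb
model_weight_consumption = model_size * 2        # ~16 GB of fp16/bf16 weights

GiB = 1024 ** 3
hidden   = num_heads * max_length * head_dim * 2 / GiB
residual = num_heads * max_length * head_dim * 2 / GiB
qkv      = (num_heads + 2 * kv_head) * max_length * head_dim * 2 / GiB
temp     = 4 * num_heads * max_length * head_dim * 2 / GiB
prefill_consumption = hidden + residual + qkv + temp          # ~6.9 GiB here

free_memory = total_gpu_memory - model_weight_consumption - prefill_consumption * num_gpus
print(f"estimated free memory during prefill: {free_memory:.1f} GB")   # ~57 GB
```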
8 | """ 9 | 10 | def __init__( 11 | self, 12 | valid_start, 13 | layer_num: int, 14 | batch_size: int, 15 | max_length: int, 16 | num_key_value_heads: int, 17 | num_heads: int, 18 | head_dim: int, 19 | dtype: torch.dtype, 20 | layer_mapping: dict, 21 | num_gpus: int, 22 | model_size: int 23 | ) -> None: 24 | super().__init__(layer_num, batch_size, max_length, num_key_value_heads, num_heads, head_dim, dtype, layer_mapping, num_gpus, model_size) 25 | 26 | self.valid_start = valid_start # start index of valid tokens for each batch 27 | self.valid_length = None # valid seq length of each batch 28 | self.batch_indices = torch.arange(self.batch_size, dtype=torch.int32, device=self.layer_mapping[str(0)]) 29 | 30 | self.allocated = self.pre_allocate_decision() 31 | 32 | if self.allocated: 33 | self.key_cache = [ 34 | torch.empty( 35 | self.batch_size, 36 | self.max_length, 37 | self.kv_head, 38 | self.head_dim, 39 | device=self.layer_mapping[str(ldx)], 40 | dtype=self.dtype 41 | ) for ldx in range(self.layer_num) 42 | ] 43 | self.value_cache = [ 44 | torch.empty( 45 | self.batch_size, 46 | self.max_length, 47 | self.kv_head, 48 | self.head_dim, 49 | device=self.layer_mapping[str(ldx)], 50 | dtype=self.dtype 51 | ) for ldx in range(self.layer_num) 52 | ] 53 | else: 54 | self.key_cache = [ 55 | torch.empty( 56 | self.batch_size, 57 | self.max_length, 58 | self.kv_head, 59 | self.head_dim, 60 | device='cpu', 61 | pin_memory=True, 62 | dtype=self.dtype 63 | ) for _ in range(self.layer_num) 64 | ] 65 | self.value_cache = [ 66 | torch.empty( 67 | self.batch_size, 68 | self.max_length, 69 | self.kv_head, 70 | self.head_dim, 71 | device='cpu', 72 | pin_memory=True, 73 | dtype=self.dtype 74 | ) for _ in range(self.layer_num) 75 | ] 76 | 77 | self.copystream = torch.cuda.Stream() 78 | self.copyevents = {} 79 | self.KVreadyevents = {} 80 | device_list = sorted(set(self.layer_mapping.values()), key=lambda x: int(x.split(':')[-1])) 81 | for device_idx in device_list: 82 | with torch.cuda.device(device_idx): 83 | self.copyevents[device_idx] = torch.cuda.Event() 84 | self.KVreadyevents[device_idx] = torch.cuda.Event() 85 | 86 | # decide whether to pre-allocate GPU memory before prefilling 87 | def pre_allocate_decision(self): 88 | kv_consumption = (2 * self.layer_num * self.batch_size * self.max_length * self.kv_head * self.head_dim * 2) / 1024 / 1024 / 1024 89 | return self.free_memory > kv_consumption * 1.3 90 | 91 | def move_gpu(self): 92 | if not self.allocated: 93 | for ldx in range(self.layer_num): 94 | self.key_cache[ldx] = self.key_cache[ldx].to(self.layer_mapping[str(ldx)]) 95 | self.value_cache[ldx] = self.value_cache[ldx].to(self.layer_mapping[str(ldx)]) 96 | 97 | 98 | def prefill_update_kv_cache(self, query_states, key_states, value_states, layer_idx, start_bdx): 99 | """ 100 | update part of batches keys and values, start from start_bdx 101 | Args: 102 | query_states: (bsz, seq_len, num_heads, head_dim) 103 | key_states: (bsz, seq_len, kv_head, head_dim) 104 | value_states: (bsz, seq_len, kv_head, head_dim) 105 | layer_idx: the index of the layer 106 | start_bdx: the start index of the batch (=batch_idx) 107 | """ 108 | bsz, seq_len, _, _ = key_states.shape 109 | assert bsz == 1, f"Multi-batch prefilling only support prefill single batch one by one." 110 | assert seq_len <= self.max_length, f"Prefilling sequence length {seq_len} exceeds max length {self.max_length}." 
111 | 112 | self.KVreadyevents[self.layer_mapping[str(layer_idx)]].record() 113 | 114 | if self.valid_length is None: 115 | self.valid_length = torch.from_numpy(seq_len - self.valid_start).to(torch.int32).to(self.layer_mapping[str(0)]) 116 | 117 | _valid_start = self.valid_start[start_bdx] 118 | _valid_length = seq_len - _valid_start 119 | 120 | with torch.cuda.stream(self.copystream): 121 | self.KVreadyevents[self.layer_mapping[str(layer_idx)]].wait() # wait for KV ready 122 | self.key_cache[layer_idx][start_bdx:start_bdx+bsz, :_valid_length, :, :].copy_(key_states[:, _valid_start:, :, :], non_blocking=True) 123 | self.value_cache[layer_idx][start_bdx:start_bdx+bsz, :_valid_length, :, :].copy_(value_states[:, _valid_start:, :, :], non_blocking=True) 124 | self.copyevents[self.layer_mapping[str(layer_idx)]].record() 125 | 126 | if (layer_idx == self.layer_num - 1) and (start_bdx + bsz == self.batch_size): 127 | self.context += seq_len 128 | 129 | return key_states[:, _valid_start:, :, :], value_states[:, _valid_start:, :, :] 130 | 131 | def sync(self, layer_idx, start_bdx): 132 | self.copyevents[self.layer_mapping[str(layer_idx)]].wait() # wait for copy done 133 | 134 | 135 | def decode_update_kv_cache(self, key_states, value_states, layer_idx): 136 | """ 137 | update all batch of the key and value cache for decoding 138 | Args: 139 | key_states: (bsz, seq_len(=1), kv_head, head_dim) 140 | value_states: (bsz, seq_len(=1), kv_head, head_dim) 141 | layer_idx: the index of the layer 142 | """ 143 | 144 | self.key_cache[layer_idx][self.batch_indices, self.valid_length, :, :] = key_states[:, 0, :, :] 145 | self.value_cache[layer_idx][self.batch_indices, self.valid_length, :, :] = value_states[:, 0, :, :] 146 | 147 | if layer_idx == self.layer_num - 1: 148 | self.context += 1 149 | self.valid_length += 1 150 | 151 | return self.key_cache[layer_idx], self.value_cache[layer_idx] 152 | -------------------------------------------------------------------------------- /cache_hub/kmeans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | 5 | 6 | @triton.jit 7 | def _triton_assign_kernel( 8 | K, X, S, C, M, # data, centroids, data_sum, data_cnt, max_idx 9 | stride_kz, stride_kn, stride_kd, 10 | stride_xz, stride_xk, stride_xd, 11 | stride_sz, stride_sk, stride_sd, 12 | stride_cz, stride_ck, 13 | stride_mz, stride_mn, 14 | num_tokens, num_centroids, 15 | BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_D: tl.constexpr, 16 | ): 17 | start_n = tl.program_id(0) * BLOCK_N 18 | batch_idx = tl.program_id(1) 19 | 20 | if start_n >= num_tokens: 21 | return 22 | 23 | offs_n = start_n + tl.arange(0, BLOCK_N) 24 | offs_k = tl.arange(0, BLOCK_K) 25 | offs_d = tl.arange(0, BLOCK_D) 26 | n_mask = offs_n < num_tokens 27 | 28 | k_ptrs = K + batch_idx * stride_kz + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd 29 | x_ptrs = X + batch_idx * stride_xz + offs_k[None, :] * stride_xk + offs_d[:, None] * stride_xd 30 | s_ptrs = S + batch_idx * stride_sz + offs_d[None, :] * stride_sd 31 | c_ptrs = C + batch_idx * stride_cz 32 | m_ptrs = M + batch_idx * stride_mz + offs_n * stride_mn 33 | 34 | k = tl.load(k_ptrs, mask=n_mask[:, None], other=0.) 
# [BLOCK_N, BLOCK_D] 35 | max_val = tl.zeros([BLOCK_N], dtype=tl.float32) - float("inf") 36 | max_idx = tl.zeros([BLOCK_N], dtype=tl.int32) 37 | 38 | for start_k in tl.range(0, num_centroids, BLOCK_K): 39 | x = tl.load(x_ptrs) # [BLOCK_D, BLOCK_K] 40 | ip = tl.dot(k, x).to(tl.float32) # [BLOCK_N, BLOCK_K] 41 | tmp_max_val, tmp_max_idx = tl.max(ip, axis=1, return_indices=True) 42 | tmp_max_idx += start_k 43 | max_idx = tl.where(tmp_max_val > max_val, tmp_max_idx, max_idx) 44 | max_val = tl.maximum(tmp_max_val, max_val) 45 | x_ptrs += BLOCK_K * stride_xk 46 | 47 | tl.store(m_ptrs, max_idx, mask=n_mask) 48 | tl.atomic_add(s_ptrs + max_idx[:, None] * stride_sk, k.to(tl.float32), mask=n_mask[:, None], sem='relaxed') 49 | tl.atomic_add(c_ptrs + max_idx * stride_ck, tl.zeros_like(max_idx) + 1, mask=n_mask, sem='relaxed') 50 | 51 | 52 | @triton.jit 53 | def _triton_update_kernel( 54 | X, S, C, # centroids, data_sum, data_cnt 55 | stride_xz, stride_xk, stride_xd, 56 | stride_sz, stride_sk, stride_sd, 57 | stride_cz, stride_ck, 58 | num_centroids, 59 | BLOCK_K: tl.constexpr, BLOCK_D: tl.constexpr, 60 | NORMORLIZE: tl.constexpr, 61 | ): 62 | start_k = tl.program_id(0) * BLOCK_K 63 | batch_idx = tl.program_id(1) 64 | 65 | offs_k = start_k + tl.arange(0, BLOCK_K) 66 | offs_d = tl.arange(0, BLOCK_D) 67 | k_mask = offs_k < num_centroids 68 | 69 | x_ptrs = X + batch_idx * stride_xz + offs_k[:, None] * stride_xk + offs_d[None, :] * stride_xd 70 | s_ptrs = S + batch_idx * stride_sz + offs_k[:, None] * stride_sk + offs_d[None, :] * stride_sd 71 | c_ptrs = C + batch_idx * stride_cz + offs_k[:, None] * stride_ck 72 | 73 | s = tl.load(s_ptrs, mask=k_mask[:, None], other=0.) # [BLOCK_K, BLOCK_D] 74 | c = tl.load(c_ptrs, mask=k_mask[:, None], other=0) 75 | x_mask = c > 0 76 | x = s / c 77 | if NORMORLIZE: 78 | x /= tl.sqrt(tl.sum(x * x, axis=-1, keep_dims=True)) 79 | 80 | tl.store(x_ptrs, x.to(X.type.element_ty), mask=x_mask) 81 | 82 | 83 | def _triton_k_means_train( 84 | data: torch.Tensor, # [batch_size, num_tokens, dim] 85 | centroids: torch.Tensor, # [batch_size, num_centroids, dim] 86 | max_idx: torch.Tensor = None, # [batch_size, num_tokens] 87 | normalize_centroids: bool = True, 88 | return_indices: bool = False, 89 | ): 90 | batch_size, num_tokens, dim = data.shape 91 | num_centroids = centroids.shape[1] 92 | data_sum = torch.zeros_like(centroids, dtype=torch.float32) 93 | data_cnt = torch.zeros((batch_size, num_centroids), dtype=torch.int32, device=data.device) 94 | if max_idx is None: 95 | max_idx = torch.empty((batch_size, num_tokens), dtype=torch.int32, device=data.device) 96 | # assert max_idx.shape == (batch_size, num_tokens) 97 | block_N = 128 98 | block_K = 32 99 | assert num_centroids % block_K == 0 100 | # assert dim in [32, 64, 128] 101 | _triton_assign_kernel[(triton.cdiv(num_tokens, block_N), batch_size, 1)]( 102 | data, centroids, data_sum, data_cnt, max_idx, 103 | data.stride(0), data.stride(1), data.stride(2), 104 | centroids.stride(0), centroids.stride(1), centroids.stride(2), 105 | data_sum.stride(0), data_sum.stride(1), data_sum.stride(2), 106 | data_cnt.stride(0), data_cnt.stride(1), 107 | max_idx.stride(0), max_idx.stride(1), 108 | num_tokens, num_centroids, 109 | BLOCK_N=block_N, BLOCK_K=block_K, BLOCK_D=dim, 110 | num_warps=4, num_stages=2, 111 | ) 112 | block_K = 128 113 | _triton_update_kernel[(triton.cdiv(num_centroids, block_K), batch_size, 1)]( 114 | centroids, data_sum, data_cnt, 115 | centroids.stride(0), centroids.stride(1), centroids.stride(2), 116 | data_sum.stride(0), 
data_sum.stride(1), data_sum.stride(2), 117 | data_cnt.stride(0), data_cnt.stride(1), 118 | num_centroids, 119 | BLOCK_K=block_K, BLOCK_D=dim, 120 | NORMORLIZE=normalize_centroids, 121 | num_warps=4, num_stages=1, 122 | ) 123 | if return_indices: 124 | return centroids, max_idx, data_cnt.max().item() 125 | return centroids 126 | 127 | 128 | @triton.jit 129 | def _triton_reverse_index_kernel( 130 | M, I, C, # max_idx, clusters, cluster_size 131 | stride_mz, stride_mn, 132 | stride_iz, stride_ik, stride_in, 133 | stride_cz, stride_ck, 134 | num_tokens, 135 | BLOCK_N: tl.constexpr, 136 | ): 137 | start_n = tl.program_id(0) * BLOCK_N 138 | batch_idx = tl.program_id(1) 139 | 140 | if start_n >= num_tokens: 141 | return 142 | 143 | offs_n = start_n + tl.arange(0, BLOCK_N) 144 | n_mask = offs_n < num_tokens 145 | 146 | m_ptrs = M + batch_idx * stride_mz + offs_n * stride_mn 147 | i_ptrs = I + batch_idx * stride_iz 148 | c_ptrs = C + batch_idx * stride_cz 149 | 150 | max_idx = tl.load(m_ptrs, mask=n_mask, other=0) 151 | cnt = tl.atomic_add(c_ptrs + max_idx * stride_ck, tl.zeros_like(max_idx) + 1, mask=n_mask, sem='relaxed') 152 | tl.store(i_ptrs + max_idx * stride_ik + cnt * stride_in, offs_n, mask=n_mask) 153 | 154 | 155 | def triton_reverse_index( 156 | max_idx: torch.Tensor, # [batch_size, num_tokens] 157 | num_centroids: int, 158 | max_cluster_size: int, 159 | ): 160 | batch_size, num_tokens = max_idx.shape 161 | clusters = torch.zeros((batch_size, num_centroids, max_cluster_size), dtype=torch.int32, device=max_idx.device) 162 | cluster_size = torch.zeros((batch_size, num_centroids), dtype=torch.int32, device=max_idx.device) 163 | block_N = 128 164 | _triton_reverse_index_kernel[(triton.cdiv(num_tokens, block_N), batch_size, 1)]( 165 | max_idx, clusters, cluster_size, 166 | max_idx.stride(0), 167 | max_idx.stride(1), 168 | clusters.stride(0), clusters.stride(1), clusters.stride(2), 169 | cluster_size.stride(0), cluster_size.stride(1), 170 | num_tokens, BLOCK_N=block_N, 171 | num_warps=4, num_stages=1, 172 | ) 173 | return clusters, cluster_size 174 | 175 | 176 | @triton.jit 177 | def _triton_index_add_kernel( 178 | V, S, M, # data, sum, max_idx 179 | stride_vz, stride_vn, stride_vd, 180 | stride_sz, stride_sk, stride_sd, 181 | stride_mz, stride_mn, 182 | num_tokens, 183 | BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, 184 | ): 185 | start_n = tl.program_id(0) * BLOCK_N 186 | batch_idx = tl.program_id(1) 187 | 188 | if start_n >= num_tokens: 189 | return 190 | 191 | offs_n = start_n + tl.arange(0, BLOCK_N) 192 | offs_d = tl.arange(0, BLOCK_D) 193 | n_mask = offs_n < num_tokens 194 | 195 | v_ptrs = V + batch_idx * stride_vz + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd 196 | s_ptrs = S + batch_idx * stride_sz + offs_d[None, :] * stride_sd 197 | m_ptrs = M + batch_idx * stride_mz + offs_n * stride_mn 198 | 199 | v = tl.load(v_ptrs, mask=n_mask[:, None], other=0.) 
200 | max_idx = tl.load(m_ptrs, mask=n_mask, other=0) 201 | 202 | tl.atomic_add(s_ptrs + max_idx[:, None] * stride_sk, v.to(S.type.element_ty), mask=n_mask[:, None], sem='relaxed') 203 | 204 | 205 | def triton_index_add( 206 | value: torch.Tensor, # [batch_size, num_tokens, head_dim] 207 | max_idx: torch.Tensor, # [batch_size, num_tokens] 208 | num_centroids: int, 209 | ): 210 | batch_size, num_tokens, dim = value.shape 211 | value_sum = torch.zeros((batch_size, num_centroids, dim), dtype=torch.float32, device=value.device) 212 | block_N = 128 213 | _triton_index_add_kernel[(triton.cdiv(num_tokens, block_N), batch_size, 1)]( 214 | value, value_sum, max_idx, 215 | value.stride(0), value.stride(1), value.stride(2), 216 | value_sum.stride(0), value_sum.stride(1), value_sum.stride(2), 217 | max_idx.stride(0), max_idx.stride(1), 218 | num_tokens, 219 | BLOCK_N=block_N, BLOCK_D=dim, 220 | ) 221 | return value_sum.to(value.dtype) 222 | 223 | 224 | def segment_k_means( 225 | key: torch.Tensor, # [batch_size(=1)*num_heads, num_tokens, head_dim] 226 | value: torch.Tensor, # [batch_size(=1)*num_heads, num_tokens, head_dim] 227 | num_centroids: int, 228 | num_iters: int = 10, 229 | num_segments: int = 1 230 | ): 231 | num_groups, num_tokens, head_dim = key.shape 232 | 233 | # initialize centroids uniformly 234 | centroid_indices = torch.arange(num_centroids, dtype=torch.float32, device=key.device) * (num_tokens / num_centroids) 235 | centroid_indices += num_tokens / num_centroids / 2 236 | centroid_indices = centroid_indices.to(torch.int64) 237 | centroids = torch.index_select(key, dim=1, index=centroid_indices) 238 | 239 | assert num_centroids % num_segments == 0 240 | num_tokens_per_segment = num_tokens // num_segments 241 | num_centroids_per_segment = num_centroids // num_segments 242 | data = key[:, :num_tokens_per_segment * num_segments].reshape((-1, num_tokens_per_segment, head_dim)) 243 | centroids = centroids.reshape((-1, num_centroids_per_segment, head_dim)) 244 | max_idx = torch.empty((data.shape[0], data.shape[1]), dtype=torch.int32, device=data.device) 245 | for _ in range(num_iters - 1): 246 | centroids = _triton_k_means_train(data, centroids, max_idx=max_idx, normalize_centroids=True, return_indices=False) 247 | 248 | data = key.reshape((-1, num_tokens, head_dim)) 249 | centroids = centroids.reshape((-1, num_centroids, head_dim)) 250 | centroids, max_idx, max_cluster_size = _triton_k_means_train(data, centroids, normalize_centroids=False, return_indices=True) 251 | 252 | value_sum = triton_index_add(value.reshape((-1, num_tokens, head_dim)), max_idx, num_centroids) 253 | clusters, cluster_size = triton_reverse_index(max_idx, num_centroids, max_cluster_size) 254 | 255 | # centroids = centroids.reshape((batch_size*num_groups, num_centroids, head_dim)) 256 | # value_sum = value_sum.reshape((batch_size*num_groups, num_centroids, head_dim)) 257 | # clusters = clusters.reshape((batch_size*num_groups, num_centroids, max_cluster_size)) 258 | # cluster_size = cluster_size.reshape((batch_size*num_groups, num_centroids)) 259 | return centroids, value_sum, clusters, cluster_size 260 | -------------------------------------------------------------------------------- /config/Llama-3-8B-Instruct-Gradient-1048k.json: -------------------------------------------------------------------------------- 1 | { 2 | "RetroInfer": { 3 | "static_pattern_start": 4, 4 | "static_pattern_end": 64, 5 | "core": 22, 6 | "n_centroids": 8192, 7 | "n_segment": 16, 8 | "nprobe": 150, 9 | "max_compute_cluster_num": 2048, 10 | 
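To clarify what the fused Triton kernels in `kmeans.py` above compute, the following is a small pure-PyTorch reference for one assignment/update round plus the per-cluster value sums and membership lists. It is a readability aid written under the assumption of inner-product similarity, not a drop-in replacement for the kernels.

```
import torch

def kmeans_round_reference(key, value, centroids, normalize=True):
    # key/value: [B, N, D]; centroids: [B, C, D]. One assign + update round.
    B, N, D = key.shape
    C = centroids.shape[1]

    # 1) assignment: nearest centroid by inner product (as in _triton_assign_kernel)
    scores = torch.einsum("bnd,bcd->bnc", key.float(), centroids.float())
    max_idx = scores.argmax(dim=-1)                                   # [B, N]
    one_hot = torch.nn.functional.one_hot(max_idx, C).float()         # [B, N, C]

    # 2) update: mean of the assigned keys, optionally re-normalized
    #    (as in _triton_update_kernel; empty clusters keep their old centroid)
    cnt = one_hot.sum(dim=1)                                          # [B, C]
    key_sum = torch.einsum("bnc,bnd->bcd", one_hot, key.float())
    mean = key_sum / cnt.clamp(min=1).unsqueeze(-1)
    if normalize:
        mean = mean / mean.norm(dim=-1, keepdim=True).clamp(min=1e-6)
    new_centroids = torch.where(cnt.unsqueeze(-1) > 0, mean, centroids.float())

    # 3) per-cluster value sums and membership lists
    #    (what triton_index_add and triton_reverse_index produce)
    value_sum = torch.einsum("bnc,bnd->bcd", one_hot, value.float())
    clusters = [[(max_idx[b] == c).nonzero().flatten() for c in range(C)] for b in range(B)]
    return new_centroids, max_idx, value_sum, clusters
```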
"cache_unit_size": 8, 11 | "cache_cluster_num": 450 12 | } 13 | } -------------------------------------------------------------------------------- /config/Llama-3.1-8B-Instruct.json: -------------------------------------------------------------------------------- 1 | { 2 | "RetroInfer": { 3 | "static_pattern_start": 4, 4 | "static_pattern_end": 64, 5 | "core": 22, 6 | "n_centroids": 8192, 7 | "n_segment": 16, 8 | "nprobe": 150, 9 | "max_compute_cluster_num": 2048, 10 | "cache_unit_size": 8, 11 | "cache_cluster_num": 450 12 | } 13 | } -------------------------------------------------------------------------------- /config/Qwen2.5-72B-Instruct.json: -------------------------------------------------------------------------------- 1 | { 2 | "RetroInfer": { 3 | "static_pattern_start": 4, 4 | "static_pattern_end": 64, 5 | "core": 22, 6 | "n_centroids": 8192, 7 | "n_segment": 16, 8 | "nprobe": 150, 9 | "max_compute_cluster_num": 2048, 10 | "cache_unit_size": 8, 11 | "cache_cluster_num": 450 12 | } 13 | } -------------------------------------------------------------------------------- /config/Qwen2.5-7B-Instruct.json: -------------------------------------------------------------------------------- 1 | { 2 | "RetroInfer": { 3 | "static_pattern_start": 4, 4 | "static_pattern_end": 64, 5 | "core": 22, 6 | "n_centroids": 8192, 7 | "n_segment": 16, 8 | "nprobe": 150, 9 | "max_compute_cluster_num": 2048, 10 | "cache_unit_size": 8, 11 | "cache_cluster_num": 450 12 | } 13 | } -------------------------------------------------------------------------------- /library/retroinfer/retroinfer_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .WaveBuffer import * 2 | from .Copy import * 3 | from .gemm_softmax import * 4 | -------------------------------------------------------------------------------- /library/retroinfer/retroinfer_kernels/src/batch_gemm_softmax.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "cutlass/gemm/device/gemm_universal.h" 20 | #include "cutlass/util/reference/device/gemm.h" 21 | #include "helper.h" 22 | 23 | #include "cutlass/arch/memory.h" 24 | #include "cutlass/arch/memory_sm80.h" 25 | #include "cutlass/cutlass.h" 26 | #include "cutlass/gemm/device/gemm_complex.h" 27 | 28 | #include "cutlass/util/command_line.h" 29 | #include "cutlass/util/host_tensor.h" 30 | 31 | #include "cutlass/util/reference/device/gemm_complex.h" 32 | #include "cutlass/util/reference/device/tensor_fill.h" 33 | #include "cutlass/util/reference/host/error_metrics.h" 34 | #include "cutlass/util/reference/host/gemm_complex.h" 35 | #include "cutlass/util/reference/host/tensor_compare.h" 36 | #include "cutlass/util/reference/host/tensor_copy.h" 37 | #include "cutlass/util/reference/host/tensor_fill.h" 38 | #include "cutlass/util/reference/host/tensor_norm.h" 39 | #include "cutlass/util/reference/host/tensor_reduce.h" 40 | #include "cutlass/util/tensor_view_io.h" 41 | 42 | #include "cutlass/epilogue/thread/linear_combination.h" 43 | #include "cutlass/layout/matrix.h" 44 | 45 | #include "batch_gemm_softmax.h" 46 | 47 | 48 | template 49 | void batch_gemm_softmax_impl( 50 | torch::Tensor A, 51 | torch::Tensor B, 52 | torch::Tensor D, 53 | torch::Tensor Norm, 54 | torch::Tensor Sum, 55 | torch::Tensor 
Softmax, 56 | int batch_count, 57 | int m, 58 | int n, 59 | int k, 60 | float alpha = 1.0, 61 | float beta = 0.0 62 | ) { 63 | /// GEMM types 64 | using ElementA = T; 65 | using ElementB = T; 66 | using ElementC = T; 67 | using ElementCompute = float; 68 | using ElementD = ElementC; 69 | /// Softmax types 70 | using ElementSoftmax = ElementC; 71 | using ElementSoftmaxCompute = float; 72 | using ElementNorm = float; 73 | using ElementSum = float; 74 | 75 | using LayoutA = cutlass::layout::RowMajor; 76 | using LayoutB = cutlass::layout::ColumnMajor; 77 | 78 | static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; 79 | static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; 80 | static constexpr int AlignmentSoftmax = 128 / cutlass::sizeof_bits::value; 81 | 82 | using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 32>; 83 | using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; 84 | using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; 85 | 86 | using OperatorClass = cutlass::arch::OpClassTensorOp; 87 | using ArchTag = cutlass::arch::Sm80; 88 | 89 | // ApplyShape for final Softmax. 90 | using ApplyShape = cutlass::MatrixShape<1, 1024>; 91 | static int const kStages = 4; 92 | 93 | /// Linear scaling operator 94 | using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< 95 | ElementC, 96 | 128 / cutlass::sizeof_bits::value, 97 | ElementCompute, 98 | ElementCompute, 99 | cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling 100 | >; 101 | 102 | using BatchGemmSoftmax = cutlass::BatchGemmSoftmax< 103 | ElementA, LayoutA, 104 | ElementB, LayoutB, 105 | ElementC, 106 | ElementCompute, 107 | OperatorClass, 108 | ArchTag, 109 | ThreadblockShape, 110 | WarpShape, 111 | InstructionShape, 112 | EpilogueFunctorOp, 113 | kStages, 114 | ApplyShape, 115 | AlignmentA, 116 | AlignmentB, 117 | AlignmentSoftmax, 118 | ElementNorm, 119 | ElementSum, 120 | ElementSoftmax, 121 | ElementSoftmaxCompute 122 | >; 123 | 124 | using LayoutC = typename BatchGemmSoftmax::LayoutC; 125 | using LayoutN = typename BatchGemmSoftmax::LayoutN; 126 | using LayoutS = typename BatchGemmSoftmax::LayoutS; 127 | using MatrixCoord = typename LayoutC::TensorCoord; 128 | 129 | cutlass::gemm::GemmCoord problem = {m, n, k}; 130 | 131 | int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); 132 | int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); 133 | int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); 134 | 135 | // fixed rowmajor for norm and sum 136 | int64_t ldn = problem.m(); 137 | int64_t lds = problem.m(); 138 | 139 | int block_num = (problem.n() + BatchGemmSoftmax::ThreadblockShape::kN - 1) / BatchGemmSoftmax::ThreadblockShape::kN; 140 | 141 | typename BatchGemmSoftmax::Arguments args( 142 | problem, 143 | batch_count, 144 | {reinterpret_cast(A.data_ptr()), lda}, 145 | {reinterpret_cast(B.data_ptr()), ldb}, 146 | {nullptr, ldc}, 147 | {reinterpret_cast(D.data_ptr()), ldc}, 148 | { 149 | ElementCompute(alpha), 150 | ElementCompute(beta) 151 | }, 152 | {reinterpret_cast(Norm.data_ptr()), ldn}, 153 | {reinterpret_cast(Sum.data_ptr()), lds}, 154 | {reinterpret_cast(Softmax.data_ptr()), ldc}, 155 | problem.m() * problem.k(), 156 | problem.k() * problem.n(), 157 | problem.m() * problem.n(), 158 | problem.m() * problem.n(), 159 | block_num * problem.m(), 160 | block_num * problem.m(), 161 | problem.m() * problem.n() 162 | ); 163 | 164 | BatchGemmSoftmax batch_gemm_softmax; 165 | 166 | 
CUTLASS_CHECK(batch_gemm_softmax.initialize(args)); 167 | 168 | CUTLASS_CHECK(batch_gemm_softmax()); 169 | } 170 | 171 | void batch_gemm_softmax( 172 | torch::Tensor A, 173 | torch::Tensor B, 174 | torch::Tensor D, 175 | torch::Tensor Norm, 176 | torch::Tensor Sum, 177 | torch::Tensor Softmax, 178 | int batch_count, 179 | int m, 180 | int n, 181 | int k, 182 | float alpha = 1.0, 183 | float beta = 0.0 184 | ) { 185 | if (A.dtype() == torch::kBFloat16) { 186 | batch_gemm_softmax_impl( 187 | A, B, D, Norm, Sum, Softmax, 188 | batch_count, m, n, k, alpha, beta 189 | ); 190 | } else if (A.dtype() == torch::kFloat16) { 191 | batch_gemm_softmax_impl( 192 | A, B, D, Norm, Sum, Softmax, 193 | batch_count, m, n, k, alpha, beta 194 | ); 195 | } else { 196 | TORCH_CHECK(false, "Only BFloat16 and Float16 dtypes are supported"); 197 | } 198 | } 199 | 200 | 201 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 202 | m.def("batch_gemm_softmax", &batch_gemm_softmax, "Batch GEMM Softmax (CUDA)", 203 | py::arg("A"), py::arg("B"), py::arg("D"), 204 | py::arg("Norm"), py::arg("Sum"), py::arg("Softmax"), 205 | py::arg("batch_count"), py::arg("m"), py::arg("n"), py::arg("k"), 206 | py::arg("alpha") = 1.0f, 207 | py::arg("beta") = 0.0f 208 | ); 209 | } -------------------------------------------------------------------------------- /library/retroinfer/retroinfer_kernels/src/thread_pool.hpp: -------------------------------------------------------------------------------- 1 | #ifndef THREAD_POOL_HPP 2 | #define THREAD_POOL_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include // for memcpy 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | 18 | void set_affinity(uint32_t idx) { 19 | cpu_set_t my_set; 20 | CPU_ZERO(&my_set); 21 | CPU_SET(idx, &my_set); 22 | sched_setaffinity(0, sizeof(cpu_set_t), &my_set); 23 | } 24 | 25 | // Thread pool 26 | // Note that this thread pool should be reused across multiple layers 27 | class MyThreadPool { 28 | public: 29 | void Start(uint32_t num_threads = 0); 30 | void QueueJobWOLock(const std::function& job, void* para); 31 | void NotifyAll(); 32 | void NotifyOne(); 33 | void Stop(); 34 | void Wait(); 35 | void AddNumTask(int); 36 | void DisplayNumTask(); 37 | void LockQueue(); 38 | void UnlockQueue(); 39 | 40 | private: 41 | void ThreadLoop(uint32_t); // input is the thread id 42 | 43 | bool should_terminate = false; // Tells threads to stop looking for jobs 44 | std::mutex queue_mutex; // Prevents data races to the job queue 45 | std::mutex main_mutex; 46 | std::condition_variable mutex_condition; // Allows threads to wait on new jobs or termination 47 | std::condition_variable main_condition; // main thread uses this condition variable to wait 48 | std::vector threads; 49 | std::queue,void*>> jobs; 50 | std::atomic num_tasks; 51 | }; 52 | 53 | void MyThreadPool::Start(uint32_t num_threads) { 54 | if(num_threads == 0){ 55 | num_threads = std::thread::hardware_concurrency(); 56 | } 57 | 58 | for (uint32_t ii = 0; ii < num_threads; ++ii) { 59 | threads.emplace_back(std::thread(&MyThreadPool::ThreadLoop, this, ii)); 60 | } 61 | 62 | num_tasks = 0; 63 | } 64 | 65 | void MyThreadPool::ThreadLoop(uint32_t thread_idx) { 66 | set_affinity(thread_idx); 67 | while (true) { 68 | std::pair,void*> job; 69 | { 70 | std::unique_lock lock(queue_mutex); 71 | mutex_condition.wait(lock, [this] { 72 | return !jobs.empty() || should_terminate; 73 | }); 74 | if (should_terminate) { 75 | return; 76 | } 77 | job = jobs.front(); 78 | 
jobs.pop(); 79 | 80 | } 81 | job.first(job.second); 82 | auto cur_num = num_tasks.fetch_sub(1); 83 | if(cur_num == 1){ 84 | // std::cout << "finally notify main thread" << std::endl; 85 | std::unique_lock lock(main_mutex); 86 | main_condition.notify_one(); 87 | } 88 | } 89 | } 90 | 91 | void MyThreadPool::QueueJobWOLock(const std::function& job, void* para) { 92 | jobs.push(std::pair,void*>(job, para)); 93 | } 94 | 95 | void MyThreadPool::AddNumTask(int num){ 96 | num_tasks.fetch_add(num); 97 | // assert(num_tasks == jobs.size()); 98 | } 99 | 100 | void MyThreadPool::DisplayNumTask(){ 101 | std::cout << "Num tasks = " << num_tasks << std::endl; 102 | } 103 | 104 | void MyThreadPool::NotifyAll() { 105 | mutex_condition.notify_all(); 106 | } 107 | 108 | void MyThreadPool::NotifyOne() { 109 | mutex_condition.notify_one(); 110 | } 111 | 112 | void MyThreadPool::LockQueue() { 113 | queue_mutex.lock(); 114 | } 115 | 116 | void MyThreadPool::UnlockQueue() { 117 | queue_mutex.unlock(); 118 | } 119 | 120 | // Only use this when the system terminates 121 | void MyThreadPool::Stop() { 122 | { 123 | std::unique_lock lock(queue_mutex); 124 | should_terminate = true; 125 | } 126 | mutex_condition.notify_all(); 127 | for (std::thread& active_thread : threads) { 128 | active_thread.join(); 129 | } 130 | threads.clear(); 131 | } 132 | 133 | // Wait until all submitted tasks have been executed 134 | void MyThreadPool::Wait(){ 135 | { 136 | std::unique_lock lock(main_mutex); 137 | main_condition.wait(lock, [this] { 138 | return num_tasks == 0; 139 | }); 140 | } 141 | } 142 | 143 | #endif // THREAD_POOL_HPP -------------------------------------------------------------------------------- /library/retroinfer/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import CUDAExtension, CppExtension, BuildExtension 4 | 5 | src_dir = "retroinfer_kernels/src" 6 | cutlass_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"../cutlass") 7 | 8 | ext_modules = [ 9 | CppExtension( 10 | 'retroinfer_kernels.WaveBuffer', 11 | sources=[f'{src_dir}/wave_buffer_cpu.cpp'], 12 | include_dirs=['/usr/local/cuda-12/include'], 13 | library_dirs=['/usr/local/lib'], 14 | extra_compile_args=['-O3', '-fopenmp'], 15 | extra_link_args=['-fopenmp'], 16 | language='c++' 17 | ), 18 | CUDAExtension( 19 | 'retroinfer_kernels.Copy', 20 | sources=[f'{src_dir}/gather_copy.cu'], 21 | extra_compile_args={'cxx': ['-O3', '-std=c++17'], 22 | 'nvcc': ['-O3', '-std=c++17', '--expt-relaxed-constexpr']}, 23 | extra_link_args=['-lcuda', '-lcudart'], 24 | ), 25 | CUDAExtension( 26 | 'retroinfer_kernels.gemm_softmax', 27 | sources=[f'{src_dir}/batch_gemm_softmax.cu'], 28 | include_dirs=[ 29 | f"{cutlass_dir}/include", 30 | f"{cutlass_dir}/examples/common", 31 | f"{cutlass_dir}/tools/util/include" 32 | ], 33 | extra_compile_args={'cxx': ['-O3', '-std=c++17'], 34 | 'nvcc': ['-O3', '-std=c++17', '--expt-relaxed-constexpr']}, 35 | extra_link_args=['-lcuda', '-lcudart'], 36 | ), 37 | ] 38 | 39 | 40 | setup( 41 | name='retroinfer_kernels', 42 | version='0.1', 43 | packages=['retroinfer_kernels'], 44 | description='RetroInfer kernels and modules', 45 | long_description='A collection of CUDA and C++ extensions for RetroInfer.', 46 | ext_modules=ext_modules, 47 | cmdclass={'build_ext': BuildExtension}, 48 | install_requires=['pybind11', 'torch'], 49 | python_requires='>=3.10', 50 | ) 
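Once the extensions defined in `setup.py` above have been built (note it expects a CUTLASS checkout at `../cutlass` relative to `library/retroinfer`), a quick smoke test that the modules import and expose the entry points used by the tests below might look like this. It assumes an editable install and a CUDA-enabled PyTorch, and is not part of the repository.

```
# Hedged smoke test; the imported names follow retroinfer_kernels/__init__.py and the tests.
import torch
from retroinfer_kernels import batch_gemm_softmax, gather_copy_and_concat, gather_copy_and_scatter

assert torch.cuda.is_available(), "the CUDA extensions need a GPU at runtime"
for fn in (batch_gemm_softmax, gather_copy_and_concat, gather_copy_and_scatter):
    print(fn.__name__ if hasattr(fn, "__name__") else fn)
```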
-------------------------------------------------------------------------------- /library/retroinfer/test/test_batch_gemm_softmax.py: --------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from retroinfer_kernels import batch_gemm_softmax
4 | import time
5 | 
6 | DTYPE = torch.bfloat16
7 | 
8 | 
9 | def torch_batch_gemm_softmax(A, B):
10 |     """
11 |     A: query vector, shape: (batch_size*group_num, group_size, dim), gpu torch tensor
12 |     B: centroids vector, shape: (batch_size*group_num, n_clusters, dim), gpu torch tensor
13 |     """
14 |     # [batch_size*group_num, group_size, n_centroids]
15 |     dist = torch.bmm(A, B.transpose(1, 2))  # [batch_size*group_num, group_size, n_centroids]
16 |     dist = dist / math.sqrt(A.size(-1))
17 |     dist = torch.softmax(dist, dim=-1)  # [batch_size*group_num, group_size, n_centroids]
18 |     return dist
19 | 
20 | def test_batch_gemm_softmax():
21 |     batch_size = 16
22 |     group_num = 8
23 |     group_size = 4
24 |     n_clusters = 8200  # must be multiple of 8
25 |     dim = 128
26 | 
27 |     queries = torch.randn((batch_size*group_num, group_size, dim), device='cuda', dtype=DTYPE).contiguous()
28 |     centroids = torch.randn((batch_size*group_num, n_clusters, dim), device='cuda', dtype=DTYPE).contiguous()
29 | 
30 |     # buffer
31 |     gemm_o = torch.zeros((batch_size, group_num, group_size, n_clusters), device='cuda', dtype=DTYPE).contiguous()
32 |     softmax_o = torch.zeros((batch_size*group_num, group_size, n_clusters), device='cuda', dtype=DTYPE).contiguous()
33 |     n_clusters_256 = (n_clusters + 256 - 1) // 256
34 |     _norm = torch.zeros((batch_size*group_num, group_size, n_clusters_256), device='cuda', dtype=torch.float32).contiguous()
35 |     _sum = torch.zeros((batch_size*group_num, group_size, n_clusters_256), device='cuda', dtype=torch.float32).contiguous()
36 |     torch.cuda.synchronize()
37 | 
38 |     start = time.time()
39 |     batch_gemm_softmax(queries, centroids, gemm_o, _norm, _sum, softmax_o,
40 |                        batch_size*group_num, group_size, n_clusters, dim, 1/math.sqrt(dim), 0)
41 |     torch.cuda.synchronize()
42 |     print(f"cuda time {1000*(time.time() - start):.4f} ms")
43 | 
44 |     start = time.time()
45 |     softmax_ref = torch_batch_gemm_softmax(queries, centroids)
46 |     torch.cuda.synchronize()
47 |     print(f"torch time {1000*(time.time() - start):.4f} ms")
48 | 
49 |     assert softmax_o.shape == softmax_ref.shape, f"shape mismatch: {softmax_o.shape} vs {softmax_ref.shape}"
50 |     assert torch.allclose(softmax_o, softmax_ref, atol=1e-3), f"{torch.max(torch.abs(softmax_o - softmax_ref))}"
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     for _ in range(10):
55 |         test_batch_gemm_softmax()
56 |         print("pass")
-------------------------------------------------------------------------------- /library/retroinfer/test/test_gather_copy.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import time
4 | import math
5 | from retroinfer_kernels import gather_copy_and_concat, gather_copy_and_scatter, gather_copy_vectors
6 | import random
7 | 
8 | DTYPE = torch.bfloat16
9 | 
10 | 
11 | # generate two random indices with shape (rows, cols)
12 | def gen_two_indices(rows, cols, max_range1, max_range2, unit_size):
13 |     src_indices1 = np.random.randint(-1, 100, size=(rows, cols), dtype=np.int32)
14 |     src_copy_size1 = np.random.randint(1, unit_size+1, size=(rows, cols), dtype=np.int32)
15 |     dst_indices1 = np.random.randint(0, 100, size=(rows, cols), dtype=np.int32)
16 |     copy_chunks1 = np.random.randint(0, 10, size=(rows,), dtype=np.int32)
17 | 
18 |     src_indices2 = np.random.randint(-1, 100, size=(rows, cols), dtype=np.int32)
19 |     src_copy_size2 = np.random.randint(1, unit_size+1, size=(rows, cols), dtype=np.int32)
20 |     dst_indices2 = np.random.randint(0, 100, size=(rows, cols), dtype=np.int32)
21 |     copy_chunks2 = np.random.randint(0, 10, size=(rows,), dtype=np.int32)
22 | 
23 |     for i in range(rows):
24 |         num = np.random.randint(int(0.8*cols), cols)
25 |         if i == 1:
26 |             num1 = 0
27 |             num2 = num
28 |         elif i == 5:
29 |             num1 = num
30 |             num2 = 0
31 |         else:
32 |             num1 = np.random.randint(0, num)
33 |             num2 = num - num1
34 | 
35 |         src_indices1[i, :num1] = np.random.choice(max_range1, num1, replace=False)
36 |         cumsum = 0
37 |         for j in range(num1):
38 |             copy_size = np.random.randint(0, unit_size+1)  # [0, unit_size]
39 |             src_copy_size1[i, j] = copy_size
40 |             dst_indices1[i, j] = cumsum
41 |             cumsum += copy_size
42 |         # force one chunk to copy right up to the end of the source buffer (boundary case)
43 |         if num1 > 0:
44 |             x = np.random.randint(0, num1)
45 |             src_indices1[i, x] = max_range1+unit_size-src_copy_size1[i, x]
46 |         copy_chunks1[i] = num1
47 | 
48 |         src_indices2[i, :num2] = np.random.choice(max_range2, num2, replace=False)
49 |         cumsum = 0
50 |         for j in range(num2):
51 |             copy_size = np.random.randint(0, unit_size+1)  # [0, unit_size]
52 |             src_copy_size2[i, j] = copy_size
53 |             dst_indices2[i, j] = cumsum
54 |             cumsum += copy_size
55 |         copy_chunks2[i] = num2
56 | 
57 |     src_indices1 = torch.from_numpy(src_indices1).pin_memory()
58 |     src_copy_size1 = torch.from_numpy(src_copy_size1).pin_memory()
59 |     dst_indices1 = torch.from_numpy(dst_indices1).pin_memory()
60 |     copy_chunks1 = torch.from_numpy(copy_chunks1).pin_memory()
61 | 
62 |     src_indices2 = torch.from_numpy(src_indices2).pin_memory()
63 |     src_copy_size2 = torch.from_numpy(src_copy_size2).pin_memory()
64 |     dst_indices2 = torch.from_numpy(dst_indices2).pin_memory()
65 |     copy_chunks2 = torch.from_numpy(copy_chunks2).pin_memory()
66 |     return src_indices1, src_copy_size1, dst_indices1, copy_chunks1, src_indices2, src_copy_size2, dst_indices2, copy_chunks2
67 | 
68 | def test_concat_gather_copy():
69 |     groups = 8
70 |     src_vector_num1 = 1769
71 |     src_vector_num2 = 12397
72 |     src_unit_num3 = 1000
73 |     buffer_unit_num = 400
74 |     index_length = 400
75 |     unit_size = 8
76 |     dim = 128
77 |     copy_vector_num = 1602
78 |     buffer_vector_num = buffer_unit_num * unit_size + src_vector_num1
79 | 
80 |     key_src1 = torch.randn((groups, src_vector_num1, dim), device='cuda', dtype=DTYPE).contiguous()
81 |     key_src2 = torch.randn((groups, src_vector_num2, dim), pin_memory=True, dtype=DTYPE).contiguous()
82 |     key_src3 = torch.randn((groups, src_unit_num3, unit_size, dim), device='cuda', dtype=DTYPE).contiguous()
83 |     key_dst1 = torch.randn((groups, buffer_vector_num, dim), device='cuda', dtype=DTYPE).contiguous()
84 |     key_dst2 = key_dst1.clone()
85 | 
86 |     value_src1 = torch.randn((groups, src_vector_num1, dim), device='cuda', dtype=DTYPE).contiguous()
87 |     value_src2 = torch.randn((groups, src_vector_num2, dim), pin_memory=True, dtype=DTYPE).contiguous()
88 |     value_src3 = torch.randn((groups, src_unit_num3, unit_size, dim), device='cuda', dtype=DTYPE).contiguous()
89 |     value_dst1 = torch.randn((groups, buffer_vector_num, dim), device='cuda', dtype=DTYPE).contiguous()
90 |     value_dst2 = value_dst1.clone()
91 | 
92 |     valid_lengths = torch.empty((groups,), dtype=torch.int32, pin_memory=True)
93 | 
94 |     src_indices1, src_copy_size1, dst_indices1, copy_chunks1, src_indices2, src_copy_size2, dst_indices2, copy_chunks2 = gen_two_indices(groups, index_length,
src_vector_num2-unit_size, src_unit_num3, unit_size) 95 | torch.cuda.synchronize() 96 | 97 | t1 = time.time() 98 | gather_copy_and_concat(key_src1, key_src2, key_src3, key_dst1, 99 | value_src1, value_src2, value_src3, value_dst1, 100 | src_indices1, src_copy_size1, dst_indices1, copy_chunks1, 101 | src_indices2, src_copy_size2, dst_indices2, copy_chunks2, 102 | valid_lengths, groups, src_vector_num1, src_vector_num2, src_unit_num3, 103 | buffer_vector_num, index_length, copy_vector_num) 104 | 105 | torch.cuda.synchronize() 106 | print("cuda time: ", time.time()-t1) 107 | 108 | print("valid_lengths: ", valid_lengths) 109 | 110 | for i in range(groups): 111 | print(f"group{i}, {copy_chunks1[i]}, {copy_chunks2[i]}") 112 | 113 | key_dst2[i, :copy_vector_num, :] = key_src1[i, :copy_vector_num, :] 114 | value_dst2[i, :copy_vector_num, :] = value_src1[i, :copy_vector_num, :] 115 | copy_num = copy_vector_num 116 | 117 | for j in range(copy_chunks1[i]): 118 | key_dst2[i, copy_num:copy_num+src_copy_size1[i, j], :] = key_src2[i, src_indices1[i, j]:src_indices1[i, j]+src_copy_size1[i, j], :] 119 | value_dst2[i, copy_num:copy_num+src_copy_size1[i, j], :] = value_src2[i, src_indices1[i, j]:src_indices1[i, j]+src_copy_size1[i, j], :] 120 | copy_num += src_copy_size1[i, j] 121 | 122 | for j in range(copy_chunks2[i]): 123 | key_dst2[i, copy_num:copy_num+src_copy_size2[i, j], :] = key_src3[i, src_indices2[i, j], :src_copy_size2[i, j], :] 124 | value_dst2[i, copy_num:copy_num+src_copy_size2[i, j], :] = value_src3[i, src_indices2[i, j], :src_copy_size2[i, j], :] 125 | copy_num += src_copy_size2[i, j] 126 | 127 | assert copy_num == valid_lengths[i], f"{i, copy_num, valid_lengths[i]}" 128 | 129 | assert (key_dst1 == key_dst2).all() 130 | assert (value_dst1 == value_dst2).all() 131 | 132 | 133 | 134 | def gen_indices(rows, cols, max_range1, max_range2, unit_size): 135 | src_indices = np.random.randint(-1, 100, size=(rows, cols), dtype=np.int32) 136 | src_copy_size = np.random.randint(1, unit_size+1, size=(rows, cols), dtype=np.int32) 137 | dst_indices = np.random.randint(0, 100, size=(rows, cols), dtype=np.int32) 138 | copy_chunks = np.random.randint(0, 10, size=(rows,), dtype=np.int32) 139 | 140 | for i in range(rows): 141 | if i == 1: 142 | copy_chunks[i] = 0 143 | continue 144 | 145 | num = np.random.randint(int(0.2*cols), int(0.8*cols)) 146 | 147 | src_indices[i, :num] = np.random.choice(max_range1, num, replace=False) 148 | dst_indices[i, :num] = np.random.choice(max_range2, num, replace=False) 149 | for j in range(num): 150 | copy_size = np.random.randint(0, unit_size+1) # [0, unit_size] 151 | if src_indices[i, j] + copy_size > max_range1: # overflow 152 | copy_size = max_range1 - src_indices[i, j] 153 | src_copy_size[i, j] = copy_size 154 | copy_chunks[i] = num 155 | 156 | src_indices = torch.from_numpy(src_indices).pin_memory() 157 | src_copy_size = torch.from_numpy(src_copy_size).pin_memory() 158 | dst_indices = torch.from_numpy(dst_indices).pin_memory() 159 | copy_chunks = torch.from_numpy(copy_chunks).pin_memory() 160 | 161 | return src_indices, src_copy_size, dst_indices, copy_chunks 162 | 163 | def test_gather_copy_scatter(): 164 | groups = 8 165 | src_unit_num = 400 166 | dst_unit_num = 1000 167 | index_length = 400 168 | unit_size = 8 169 | dim = 128 170 | copy_start = 97 171 | 172 | key_src = torch.randn((groups, src_unit_num*unit_size, dim), device='cuda', dtype=DTYPE).contiguous() 173 | key_dst1 = torch.randn((groups, dst_unit_num, unit_size, dim), device='cuda', dtype=DTYPE).contiguous() 174 
| key_dst2 = key_dst1.clone() 175 | 176 | value_src = torch.randn((groups, src_unit_num*unit_size, dim), device='cuda', dtype=DTYPE).contiguous() 177 | value_dst1 = torch.randn((groups, dst_unit_num, unit_size, dim), device='cuda', dtype=DTYPE).contiguous() 178 | value_dst2 = value_dst1.clone() 179 | 180 | src_indices, src_copy_size, dst_indices, copy_chunks = gen_indices(groups, index_length, src_unit_num*unit_size-copy_start, dst_unit_num, unit_size) 181 | torch.cuda.synchronize() 182 | 183 | t1 = time.time() 184 | gather_copy_and_scatter(key_src, key_dst1, value_src, value_dst1, 185 | src_indices, src_copy_size, dst_indices, copy_chunks, 186 | groups, src_unit_num*unit_size, dst_unit_num, index_length, copy_start) 187 | torch.cuda.synchronize() 188 | print("cuda time: ", time.time()-t1) 189 | 190 | for i in range(groups): 191 | print(f"group{i}, {copy_chunks[i]}") 192 | for j in range(copy_chunks[i]): 193 | key_dst2[i, dst_indices[i, j], :src_copy_size[i, j], :] = key_src[i, copy_start+src_indices[i, j]:copy_start+src_indices[i, j]+src_copy_size[i, j], :] 194 | value_dst2[i, dst_indices[i, j], :src_copy_size[i, j], :] = value_src[i, copy_start+src_indices[i, j]:copy_start+src_indices[i, j]+src_copy_size[i, j], :] 195 | 196 | assert (key_dst1 == key_dst2).all() 197 | assert (value_dst1 == value_dst2).all() 198 | 199 | 200 | 201 | def test_gather_copy_vectors(): 202 | groups = 8 203 | src_vector_num = 8192 204 | dim = 128 205 | nprobe = 150 206 | index_size = 2048 207 | copy_vector_num = index_size - nprobe 208 | buffer_size = copy_vector_num 209 | 210 | key_src = torch.randn((groups, src_vector_num, dim), device='cuda', dtype=DTYPE).contiguous() 211 | key_dst1 = torch.randn((groups, buffer_size, dim), device='cuda', dtype=DTYPE).contiguous() 212 | key_dst2 = key_dst1.clone() 213 | 214 | value_src = torch.randn((groups, src_vector_num, dim), device='cuda', dtype=DTYPE).contiguous() 215 | value_dst1 = torch.randn((groups, buffer_size, dim), device='cuda', dtype=DTYPE).contiguous() 216 | value_dst2 = value_dst1.clone() 217 | 218 | src_metadata = torch.randn(size=(groups, src_vector_num), dtype=DTYPE, device='cuda').contiguous() 219 | dst_metadata1 = torch.empty((groups, buffer_size), dtype=DTYPE, device='cuda').contiguous() 220 | dst_metadata2 = dst_metadata1.clone() 221 | 222 | indices = torch.empty((groups, index_size), dtype=torch.int64, device='cuda') 223 | for i in range(groups): 224 | indices[i, :] = torch.randperm(src_vector_num)[:index_size].to(torch.int64).to("cuda") 225 | 226 | torch.cuda.synchronize() 227 | start = time.time() 228 | gather_copy_vectors(key_src, key_dst1, value_src, value_dst1, src_metadata, dst_metadata1, 229 | indices, groups, src_vector_num, buffer_size, index_size, nprobe, copy_vector_num) 230 | torch.cuda.synchronize() 231 | print("cuda time: ", time.time()-start) 232 | 233 | copy_indices = indices[:, nprobe:nprobe+copy_vector_num] 234 | for i in range(groups): 235 | key_dst2[i, :copy_vector_num, :] = key_src[i, copy_indices[i, :], :] 236 | value_dst2[i, :copy_vector_num, :] = value_src[i, copy_indices[i, :], :] 237 | dst_metadata2[i, :copy_vector_num] = src_metadata[i, copy_indices[i, :]] 238 | 239 | assert (key_dst1 == key_dst2).all() 240 | assert (value_dst1 == value_dst2).all() 241 | assert (dst_metadata1 == dst_metadata2).all() 242 | 243 | 244 | 245 | if __name__ == "__main__": 246 | for i in range(10): 247 | test_concat_gather_copy() 248 | test_gather_copy_scatter() 249 | test_gather_copy_vectors() 250 | print("pass") 251 | 
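The reference checks in the tests above double as documentation for what the fused kernels compute. As a reading aid, the sketch below restates the behaviour verified by `test_gather_copy_vectors` in plain PyTorch: for every group, the vectors addressed by `indices[:, nprobe:nprobe+copy_vector_num]` are gathered from the source buffers into contiguous destination buffers, together with their per-vector metadata. This is a minimal illustrative equivalent of the checked semantics (per-group Python loop, no pinned memory or stream handling), not the CUDA implementation in `src/gather_copy.cu`; the helper name `reference_gather_copy_vectors` is only used here for illustration.

import torch

def reference_gather_copy_vectors(key_src, value_src, src_metadata, indices, nprobe, copy_vector_num):
    # Illustrative helper: gather the vectors selected by indices[:, nprobe:nprobe+copy_vector_num]
    # into contiguous destination buffers, one group at a time. This mirrors the reference check
    # performed in test_gather_copy_vectors above.
    groups, _, dim = key_src.shape
    key_dst = torch.empty((groups, copy_vector_num, dim), dtype=key_src.dtype, device=key_src.device)
    value_dst = torch.empty_like(key_dst)
    dst_metadata = torch.empty((groups, copy_vector_num), dtype=src_metadata.dtype, device=src_metadata.device)
    copy_indices = indices[:, nprobe:nprobe + copy_vector_num]
    for i in range(groups):
        key_dst[i] = key_src[i, copy_indices[i]]
        value_dst[i] = value_src[i, copy_indices[i]]
        dst_metadata[i] = src_metadata[i, copy_indices[i]]
    return key_dst, value_dst, dst_metadata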
-------------------------------------------------------------------------------- /model_hub/LLM.py: --------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from termcolor import colored
4 | 
5 | 
6 | class LLM:
7 |     """
8 |     A class representing the LLM (currently supports Llama and Qwen).
9 |     """
10 | 
11 |     def __init__(
12 |         self,
13 |         model_name: str,
14 |         max_length: int,
15 |         dtype: torch.dtype,
16 |         device_map: str
17 |     ) -> None:
18 |         """ Initializes the LLM.
19 |         Args:
20 |             model_name (str): The name of the model.
21 |             max_length (int): The maximum length (prefill+decode) of sequences.
22 |             dtype (torch.dtype): The data type for model computations.
23 |             device_map (str): The device for the model; supports 'cuda:x' or 'auto' (automatically use all visible GPUs).
24 |         """
25 | 
26 |         self.model_name = model_name
27 |         self.max_length = max_length
28 |         self.dtype = dtype
29 |         self.device_map = device_map
30 | 
31 | 
32 |     def layer_prefill(self, layer_idx, start_bdx, hidden_states):
33 |         # print(f'Layer = {layer_idx}, start_bdx = {start_bdx}')
34 | 
35 |         bsz, seq_len, dim = hidden_states.shape
36 |         layer = self.layers[layer_idx]
37 | 
38 |         # original hidden_states used as residual, clone a new one to process
39 |         temp_hidden_states = hidden_states.clone()
40 | 
41 |         # chunk for lower memory consumption
42 |         for start_idx in range(0, seq_len, 8192//bsz):
43 |             end_idx = min(seq_len, start_idx + 8192//bsz)
44 |             temp_hidden_states[:, start_idx:end_idx, :] = self.layernorm(temp_hidden_states[:, start_idx:end_idx, :],
45 |                                                                          layer.input_layernorm_variance_epsilon,
46 |                                                                          layer.input_layernorm_weight)
47 | 
48 |         query_states, key_states, value_states = self.wqkv(temp_hidden_states, layer)
49 |         del temp_hidden_states
50 |         torch.cuda.empty_cache()
51 |         query_states, key_states = self.position_embedd(query_states, key_states)
52 | 
53 |         query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim)  # reshape [bs, seq_len, dim] => [bs, seq_len, head, head_dim]
54 |         key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
55 |         value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim)
56 | 
57 |         key_states, value_states = self.kv_cache.prefill_update_kv_cache(query_states, key_states, value_states, layer_idx, start_bdx)
58 |         torch.cuda.empty_cache()
59 | 
60 |         temp_attn_out = self.prefill_attention(query_states, key_states, value_states)
61 | 
62 |         self.kv_cache.sync(layer_idx, start_bdx)
63 | 
64 |         del query_states, key_states, value_states
65 |         torch.cuda.empty_cache()
66 | 
67 |         hidden_states += self.wo(temp_attn_out, layer, temp_attn_out.shape[0], seq_len, dim)
68 |         del temp_attn_out
69 |         torch.cuda.empty_cache()
70 | 
71 |         # post attention
72 |         residual = hidden_states.clone()
73 | 
74 |         # chunk for lower memory consumption
75 |         for start_idx in range(0, seq_len, 8192//bsz):
76 |             end_idx = min(seq_len, start_idx + 8192//bsz)
77 |             hidden_states[:, start_idx:end_idx, :] = self.layernorm(hidden_states[:, start_idx:end_idx, :],
78 |                                                                     layer.post_attention_layernorm_variance_epsilon,
79 |                                                                     layer.post_attention_layernorm_weight)
80 |             hidden_states[:, start_idx:end_idx, :] = self.mlp(hidden_states[:, start_idx:end_idx, :], layer)
81 | 
82 |         hidden_states += residual
83 | 
84 |         del residual
85 |         torch.cuda.empty_cache()
86 | 
87 |         return hidden_states
88 | 
89 | 
90 |     def layer_decode(self, layer_idx, hidden_states):
91 |         # print(f'Layer = {layer_idx}')
92 | 
93 |         residual = hidden_states
94 |         bsz, seq_len, dim
= hidden_states.shape 95 | layer = self.layers[layer_idx] 96 | 97 | hidden_states = self.layernorm(hidden_states, layer.input_layernorm_variance_epsilon, layer.input_layernorm_weight) 98 | 99 | query_states, key_states, value_states = self.wqkv(hidden_states, layer) 100 | query_states, key_states = self.position_embedd(query_states, key_states) 101 | 102 | query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) 103 | key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim) 104 | value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim) 105 | 106 | key_states, value_states = self.kv_cache.decode_update_kv_cache(key_states, value_states, layer_idx) 107 | attn_out = self.decode_attention(query_states, key_states, value_states, layer_idx) 108 | hidden_states = self.wo(attn_out, layer, bsz, seq_len, dim) 109 | hidden_states = residual + hidden_states 110 | 111 | residual = hidden_states 112 | hidden_states = self.layernorm(hidden_states, layer.post_attention_layernorm_variance_epsilon, layer.post_attention_layernorm_weight) 113 | hidden_states = self.mlp(hidden_states, layer) 114 | hidden_states = residual + hidden_states 115 | 116 | return hidden_states 117 | 118 | 119 | def prefill_forward(self, inputs_ids): 120 | bsz, seq_len = inputs_ids.shape 121 | device = inputs_ids.device 122 | 123 | last_hidden_states = torch.empty((bsz, 1, self.hidden_size), dtype=self.dtype, device=device) 124 | for start_bdx in range(0, bsz, 1): 125 | end_bdx = min(bsz, start_bdx + 1) 126 | hidden_states = self.word_embedding(inputs_ids[start_bdx:end_bdx]) # [1, seq_len, hidden_size] 127 | 128 | if self.num_gpus > 1: 129 | for ldx in range(self.num_layers): 130 | hidden_states = self.layer_prefill(ldx, start_bdx, hidden_states) 131 | hidden_states = self.parameter_move(hidden_states, ldx) 132 | torch.cuda.empty_cache() 133 | last_hidden_states[start_bdx:end_bdx] = hidden_states[:, -1:, :].to(self.layers[0].device) 134 | else: 135 | for ldx in range(self.num_layers): 136 | hidden_states = self.layer_prefill(ldx, start_bdx, hidden_states) 137 | torch.cuda.empty_cache() 138 | last_hidden_states[start_bdx:end_bdx] = hidden_states[:, -1:, :] 139 | 140 | last_hidden_states = self.layernorm(last_hidden_states.contiguous(), self.norm_variance_epsilon, self.norm_weight) 141 | logits = self.lm(last_hidden_states) 142 | 143 | return logits 144 | 145 | 146 | def decode_forward(self, inputs_ids): 147 | hidden_states = self.word_embedding(inputs_ids) 148 | 149 | if self.num_gpus > 1: 150 | for ldx in range(self.num_layers): 151 | hidden_states = self.layer_decode(ldx, hidden_states) 152 | hidden_states = self.parameter_move(hidden_states, ldx) 153 | hidden_states = hidden_states.to(self.layers[0].device) 154 | else: 155 | for ldx in range(self.num_layers): 156 | hidden_states = self.layer_decode(ldx, hidden_states) 157 | 158 | hidden_states = self.layernorm(hidden_states[:, -1:, :], self.norm_variance_epsilon, self.norm_weight) 159 | logits = self.lm(hidden_states) 160 | 161 | return logits 162 | 163 | 164 | def inference(self, inputs_ids): 165 | outputs_ids = [] # multi iteration, multi request 166 | output_ids = [] # single iteration, multi request 167 | 168 | print("Start prefilling ...") 169 | torch.cuda.synchronize() 170 | prefill_start = time.time() 171 | 172 | logits = self.prefill_forward(inputs_ids=inputs_ids) 173 | output_ids = logits.argmax(dim=-1) 174 | outputs_ids.append(output_ids) 175 | self.move() 176 | 177 | torch.cuda.synchronize() 178 | prefill_end 
= time.time()
179 |         print(colored(f"Prefilling latency: {round((prefill_end - prefill_start), 4)} s\n", 'green'))
180 | 
181 |         print("Start decoding ...")
182 |         decode_start = time.time()
183 | 
184 |         for _ in range(self.max_new_length-1):
185 |             logits = self.decode_forward(inputs_ids=output_ids)
186 |             output_ids = logits.argmax(dim=-1)
187 |             outputs_ids.append(output_ids)
188 | 
189 |         decode_end = time.time()
190 |         print(colored(
191 |             f"Decoding latency: {round((decode_end - decode_start) * 1000 / (self.max_new_length - 1), 2)} ms/step, "
192 |             f"Throughput: {round(self.batch_size * (self.max_new_length - 1) / (decode_end - decode_start), 2)} tokens/s\n",
193 |             'green'
194 |         ))
195 | 
196 |         outputs_ids = torch.cat(outputs_ids, dim=-1).tolist()
197 | 
198 |         return outputs_ids
199 | 
200 | 
201 |     def generate(self, attention_type, inputs_ids, attention_masks, max_new_length, attn_config=None):
202 |         """ LLM Inference.
203 |         Args:
204 |             attention_type (str): The attention method, e.g. 'Full_Flash_Attn' or 'RetroInfer'.
205 |             inputs_ids (torch.tensor): The input token ids of the LLM.
206 |             attention_masks (torch.tensor): The attention masks of the LLM.
207 |             max_new_length (int): The maximum length of generated sequences.
208 |         """
209 | 
210 |         bs, input_length = inputs_ids.shape
211 |         assert input_length + max_new_length <= self.max_length, \
212 |             f"Error: input_length({input_length}) + max_new_length({max_new_length}) exceeds max_length({self.max_length})"
213 | 
214 |         self.batch_size = bs
215 |         self.input_length = input_length
216 |         self.max_new_length = max_new_length
217 |         self.attention_type = attention_type
218 | 
219 |         valid_start = attention_masks.shape[1] - torch.sum(attention_masks, dim=-1).detach().cpu().numpy()
220 |         del attention_masks
221 |         torch.cuda.empty_cache()
222 | 
223 |         print("Allocate GPU buffers and CPU pin memory ...\n")
224 |         self.init_kv_cache(input_length, valid_start, attn_config)
225 | 
226 |         outputs = self.inference(inputs_ids)
227 | 
228 |         return outputs
229 | 
-------------------------------------------------------------------------------- /model_hub/__init__.py: --------------------------------------------------------------------------------
1 | 
2 | from .llama import LlamaModel
3 | from .qwen import QwenModel
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | tqdm
2 | ninja==1.11.1.4
3 | argparse
4 | termcolor
5 | packaging==25.0
6 | torch==2.4.0
7 | vllm==0.6.2
8 | transformers==4.49.0
9 | pybind11==2.12.0
10 | 
11 | evaluate==0.4.3
12 | rouge==1.0.1
13 | rouge_score==0.1.2
14 | 
15 | wonderwords==2.2.0
16 | html2text==2024.2.26
17 | bs4==0.0.2
18 | scipy==1.13.0
19 | 
20 | jieba==0.42.1
21 | fuzzywuzzy==0.18.0
-------------------------------------------------------------------------------- /simple_test.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import math
5 | import torch
6 | import argparse
7 | import random
8 | import numpy as np
9 | from termcolor import colored
10 | from transformers import AutoTokenizer
11 | PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
12 | sys.path.append(PROJECT_ROOT)
13 | from model_hub import LlamaModel, QwenModel
14 | 
15 | 
16 | def set_seed(seed):
17 |     random.seed(seed)
18 |     np.random.seed(seed)
19 |     torch.manual_seed(seed)
20 |     torch.cuda.manual_seed(seed)
21 |     torch.cuda.manual_seed_all(seed)
22 | 
23 | 
24 | def parse_args():
25 |     parser = argparse.ArgumentParser(description="Test
example") 26 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size") 27 | parser.add_argument("--gen_len", type=int, default=100, help="Generation length") 28 | parser.add_argument("--device", type=str, default="cuda:0", help="Device") 29 | parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], help="Dtype") 30 | parser.add_argument("--attn_type", type=str, default="RetroInfer", \ 31 | choices=["Full_Flash_Attn", "RetroInfer"], help="Attention method") 32 | parser.add_argument("--model_name", type=str, default="gradientai/Llama-3-8B-Instruct-Gradient-1048k", \ 33 | choices=["gradientai/Llama-3-8B-Instruct-Gradient-1048k", "Qwen/Qwen2.5-7B-Instruct", \ 34 | "Qwen/Qwen2.5-72B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"], help="huggingface model name") 35 | parser.add_argument("--data_path", type=str, default="", help="Input json file path") 36 | args = parser.parse_args() 37 | 38 | return args 39 | 40 | 41 | def load_model(model_name, max_len, dtype, device): 42 | if 'Llama' in model_name: 43 | llm = LlamaModel(model_name, 44 | max_length=max_len, 45 | dtype=dtype, 46 | device_map=device) 47 | elif 'Qwen' in model_name: 48 | llm = QwenModel(model_name, 49 | max_length=max_len, 50 | dtype=dtype, 51 | device_map=device) 52 | else: 53 | raise ValueError(f"Unsupported model: {model_name}") 54 | 55 | return llm 56 | 57 | 58 | def generate_config(model_name, context_len, attn_type): 59 | CONFIG_DIR = os.path.join(PROJECT_ROOT, "config") 60 | MODEL_NAME = model_name.split("/")[-1]+'.json' 61 | CONFIG_FILE = os.path.join(CONFIG_DIR, MODEL_NAME) 62 | with open(CONFIG_FILE, "r") as f: 63 | original_config = json.load(f) 64 | 65 | n_clusters = max(int(context_len/16), 1) 66 | n_segments = max(int(context_len/8192), 1) 67 | # compute the nearest multiple of (n_segments*32) 68 | lower = (n_clusters // (n_segments*32)) * (n_segments*32) 69 | upper = lower + (n_segments*32) 70 | n_clusters = lower if abs(n_clusters - lower) <= abs(n_clusters - upper) else upper 71 | nprobe = max(int(n_clusters*0.018), 1) 72 | 73 | if attn_type == 'RetroInfer': 74 | original_config[attn_type]['n_centroids'] = n_clusters 75 | original_config[attn_type]['n_segment'] = n_segments 76 | original_config[attn_type]['nprobe'] = nprobe 77 | original_config[attn_type]['cache_cluster_num'] = nprobe*3 78 | original_config[attn_type]['max_compute_cluster_num'] = max(int(n_clusters/4), nprobe) 79 | 80 | if attn_type != "Full_Flash_Attn": 81 | print(original_config[attn_type]) 82 | 83 | return original_config 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parse_args() 88 | set_seed(2025) 89 | 90 | model_name = args.model_name 91 | batch_size = args.batch_size 92 | attn_type = args.attn_type 93 | dtype = torch.float16 if args.dtype=='fp16' else torch.bfloat16 94 | device = args.device 95 | data_path = args.data_path 96 | 97 | # load input data 98 | if data_path == "": 99 | TEST_FILE = os.path.join(PROJECT_ROOT, "simple_test_data.json") 100 | else: 101 | TEST_FILE = os.path.join(PROJECT_ROOT, f"{data_path}") 102 | print(colored(f"Loading test data from {TEST_FILE}", 'yellow')) 103 | data = json.load(open(TEST_FILE)) # [{"input": str, "outputs": str}, ...] 
104 | prompt = [] 105 | groundtruth = [] 106 | for dd in data: 107 | prompt.append(dd['input']) 108 | groundtruth.append(dd['outputs']) 109 | 110 | # copy to fit batch size 111 | copy_round = math.ceil(batch_size/len(prompt)) 112 | prompts = [] 113 | groundtruths = [] 114 | for i in range(copy_round): 115 | prompts.extend(prompt) 116 | groundtruths.extend(groundtruth) 117 | prompts = prompts[:batch_size] 118 | groundtruths = groundtruths[:batch_size] 119 | 120 | # tokenize input data 121 | tokenizer = AutoTokenizer.from_pretrained(model_name) 122 | tokenizer.pad_token = tokenizer.eos_token 123 | tokenizer.padding_side = "left" 124 | inputs = tokenizer(prompts, return_tensors="pt", padding=True) 125 | input_ids = inputs.input_ids 126 | attention_masks = inputs.attention_mask 127 | 128 | input_len = input_ids.shape[1] 129 | gen_len = args.gen_len 130 | max_len = input_len + gen_len 131 | print(colored(f"Input length: {input_len}", 'yellow')) 132 | 133 | if data_path == "": 134 | attn_config = generate_config(model_name, 122880, attn_type) 135 | else: 136 | attn_config = generate_config(model_name, input_len, attn_type) 137 | 138 | llm = load_model(model_name, max_len, dtype, device) 139 | out = llm.generate(attention_type=attn_type, 140 | inputs_ids = input_ids.to(llm.layers[0].device), 141 | attention_masks = attention_masks.to(llm.layers[0].device), 142 | max_new_length=gen_len, attn_config=attn_config) 143 | 144 | result = tokenizer.batch_decode(out, skip_special_tokens=True) 145 | for gt, res in zip(groundtruths, result): 146 | print(colored(f"Answer: {gt}", 'yellow')) 147 | print(f"{[res]}") 148 | -------------------------------------------------------------------------------- /throughput_eval/run.sh: -------------------------------------------------------------------------------- 1 | echo "Running different context lengths ..." 2 | bash run_different_lengths.sh 3 | 4 | echo "Running different models ..." 5 | bash run_different_models.sh 6 | 7 | echo "Running different tasks ..." 
8 | bash run_different_tasks.sh 9 | 10 | echo "Done" 11 | -------------------------------------------------------------------------------- /throughput_eval/run_different_lengths.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | mkdir -p different_lengths_logs 4 | 5 | ################################ Full Attention ################################ 6 | for bsz in 1 2 4 8 7 | do 8 | for round in 1 9 | do 10 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 11 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 12 | --attn_type Full_Flash_Attn \ 13 | --context_len 60000 \ 14 | --task_name NIAH \ 15 | --batch_size $bsz > different_lengths_logs/full_attn_60k_bsz${bsz}_${round}.log 2>&1 16 | done 17 | done 18 | 19 | 20 | for bsz in 1 2 4 21 | do 22 | for round in 1 23 | do 24 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 25 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 26 | --attn_type Full_Flash_Attn \ 27 | --context_len 120000 \ 28 | --task_name NIAH \ 29 | --batch_size $bsz > different_lengths_logs/full_attn_120k_bsz${bsz}_${round}.log 2>&1 30 | done 31 | done 32 | 33 | 34 | for bsz in 1 2 35 | do 36 | for round in 1 37 | do 38 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 39 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 40 | --attn_type Full_Flash_Attn \ 41 | --context_len 240000 \ 42 | --task_name NIAH \ 43 | --batch_size $bsz > different_lengths_logs/full_attn_240k_bsz${bsz}_${round}.log 2>&1 44 | done 45 | done 46 | 47 | 48 | for bsz in 1 49 | do 50 | for round in 1 51 | do 52 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 53 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 54 | --attn_type Full_Flash_Attn \ 55 | --context_len 480000 \ 56 | --task_name NIAH \ 57 | --batch_size $bsz > different_lengths_logs/full_attn_480k_bsz${bsz}_${round}.log 2>&1 58 | done 59 | done 60 | 61 | 62 | ################################ RetroInfer ################################ 63 | # 60K 64 | for bsz in 1 2 4 8 16 32 65 | do 66 | for round in 1 67 | do 68 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 69 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 70 | --attn_type RetroInfer \ 71 | --context_len 60000 \ 72 | --task_name NIAH \ 73 | --batch_size $bsz > different_lengths_logs/retroinfer_60k_bsz${bsz}_${round}.log 2>&1 74 | done 75 | done 76 | 77 | for bsz in 64 78 | do 79 | for round in 1 80 | do 81 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 82 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 83 | --attn_type RetroInfer \ 84 | --context_len 60000 \ 85 | --task_name NIAH \ 86 | --batch_size $bsz > different_lengths_logs/retroinfer_60k_bsz${bsz}_${round}.log 2>&1 87 | done 88 | done 89 | 90 | # 120K 91 | for bsz in 1 2 4 8 16 92 | do 93 | for round in 1 94 | do 95 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 96 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 97 | --attn_type RetroInfer \ 98 | --context_len 120000 \ 99 | --task_name NIAH \ 100 | --batch_size $bsz > different_lengths_logs/retroinfer_120k_bsz${bsz}_${round}.log 2>&1 101 | done 102 | done 103 | 104 | for bsz in 32 105 | do 106 | for round in 1 107 | do 108 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 109 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 110 | --attn_type RetroInfer \ 111 | --context_len 120000 \ 112 | --task_name NIAH \ 113 | --batch_size 
$bsz > different_lengths_logs/retroinfer_120k_bsz${bsz}_${round}.log 2>&1 114 | done 115 | done 116 | 117 | # 240K 118 | for bsz in 1 2 4 8 119 | do 120 | for round in 1 121 | do 122 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 123 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 124 | --attn_type RetroInfer \ 125 | --context_len 240000 \ 126 | --task_name NIAH \ 127 | --batch_size $bsz > different_lengths_logs/retroinfer_240k_bsz${bsz}_${round}.log 2>&1 128 | done 129 | done 130 | 131 | for bsz in 16 132 | do 133 | for round in 1 134 | do 135 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 136 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 137 | --attn_type RetroInfer \ 138 | --context_len 240000 \ 139 | --task_name NIAH \ 140 | --batch_size $bsz > different_lengths_logs/retroinfer_240k_bsz${bsz}_${round}.log 2>&1 141 | done 142 | done 143 | 144 | # 480K 145 | for bsz in 1 2 4 146 | do 147 | for round in 1 148 | do 149 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 150 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 151 | --attn_type RetroInfer \ 152 | --context_len 480000 \ 153 | --task_name NIAH \ 154 | --batch_size $bsz > different_lengths_logs/retroinfer_480k_bsz${bsz}_${round}.log 2>&1 155 | done 156 | done 157 | 158 | for bsz in 8 159 | do 160 | for round in 1 161 | do 162 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 163 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 164 | --attn_type RetroInfer \ 165 | --context_len 480000 \ 166 | --task_name NIAH \ 167 | --batch_size $bsz > different_lengths_logs/retroinfer_480k_bsz${bsz}_${round}.log 2>&1 168 | done 169 | done 170 | 171 | # 1024K 172 | for bsz in 1 2 173 | do 174 | for round in 1 175 | do 176 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 177 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 178 | --attn_type RetroInfer \ 179 | --context_len 1024000 \ 180 | --task_name NIAH \ 181 | --batch_size $bsz > different_lengths_logs/retroinfer_1024k_bsz${bsz}_${round}.log 2>&1 182 | done 183 | done 184 | 185 | for bsz in 4 186 | do 187 | for round in 1 188 | do 189 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 190 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 191 | --attn_type RetroInfer \ 192 | --context_len 1024000 \ 193 | --task_name NIAH \ 194 | --batch_size $bsz > different_lengths_logs/retroinfer_1024k_bsz${bsz}_${round}.log 2>&1 195 | done 196 | done 197 | 198 | unset CUDA_VISIBLE_DEVICES 199 | -------------------------------------------------------------------------------- /throughput_eval/run_different_models.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir -p different_models_logs 3 | 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | ################################ Full Attention ################################ 7 | # Llama-3.1-8B 8 | for bsz in 1 4 9 | do 10 | for round in 1 11 | do 12 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 13 | --model_name meta-llama/Llama-3.1-8B-Instruct \ 14 | --attn_type Full_Flash_Attn \ 15 | --context_len 120000 \ 16 | --task_name NIAH \ 17 | --batch_size $bsz > different_models_logs/full_attn_llama31_bsz${bsz}_${round}.log 2>&1 18 | done 19 | done 20 | 21 | # Qwen-2.5-7B 22 | for bsz in 1 9 23 | do 24 | for round in 1 25 | do 26 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 27 | --model_name Qwen/Qwen2.5-7B-Instruct \ 28 | --attn_type Full_Flash_Attn \ 29 | 
--context_len 120000 \ 30 | --task_name NIAH \ 31 | --batch_size $bsz > different_models_logs/full_attn_qwen_bsz${bsz}_${round}.log 2>&1 32 | done 33 | done 34 | 35 | ################################ RetroInfer ################################ 36 | # Llama-3.1-8B 37 | for bsz in 1 38 | do 39 | for round in 1 40 | do 41 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 42 | --model_name meta-llama/Llama-3.1-8B-Instruct \ 43 | --attn_type RetroInfer \ 44 | --context_len 120000 \ 45 | --task_name NIAH \ 46 | --batch_size $bsz > different_models_logs/retroinfer_llama31_bsz${bsz}_${round}.log 2>&1 47 | done 48 | done 49 | 50 | for bsz in 32 51 | do 52 | for round in 1 53 | do 54 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 55 | --model_name meta-llama/Llama-3.1-8B-Instruct \ 56 | --attn_type RetroInfer \ 57 | --context_len 120000 \ 58 | --task_name NIAH \ 59 | --batch_size $bsz > different_models_logs/retroinfer_llama31_bsz${bsz}_${round}.log 2>&1 60 | done 61 | done 62 | 63 | # Qwen-2.5-7B 64 | for bsz in 1 65 | do 66 | for round in 1 67 | do 68 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 69 | --model_name Qwen/Qwen2.5-7B-Instruct \ 70 | --attn_type RetroInfer \ 71 | --context_len 120000 \ 72 | --task_name NIAH \ 73 | --batch_size $bsz > different_models_logs/retroinfer_qwen_bsz${bsz}_${round}.log 2>&1 74 | done 75 | done 76 | 77 | for bsz in 64 78 | do 79 | for round in 1 80 | do 81 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 82 | --model_name Qwen/Qwen2.5-7B-Instruct \ 83 | --attn_type RetroInfer \ 84 | --context_len 120000 \ 85 | --task_name NIAH \ 86 | --batch_size $bsz > different_models_logs/retroinfer_qwen_bsz${bsz}_${round}.log 2>&1 87 | done 88 | done 89 | 90 | for bsz in 72 91 | do 92 | for round in 1 93 | do 94 | numactl --cpunodebind=0 --membind=0,1,2 python -u test.py \ 95 | --model_name Qwen/Qwen2.5-7B-Instruct \ 96 | --attn_type RetroInfer \ 97 | --context_len 120000 \ 98 | --task_name NIAH \ 99 | --batch_size $bsz > different_models_logs/retroinfer_qwen_bsz${bsz}_${round}.log 2>&1 100 | done 101 | done 102 | unset CUDA_VISIBLE_DEVICES 103 | 104 | 105 | export CUDA_VISIBLE_DEVICES=0,1,2,3 106 | ################################ Qwen2.5-72B ################################ 107 | # Full Attention 108 | for bsz in 1 2 4 109 | do 110 | for round in 1 111 | do 112 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 113 | --model_name Qwen/Qwen2.5-72B-Instruct \ 114 | --attn_type Full_Flash_Attn \ 115 | --device auto \ 116 | --context_len 120000 \ 117 | --task_name NIAH \ 118 | --batch_size $bsz > different_models_logs/full_attn_qwen72b_bsz${bsz}_${round}.log 2>&1 119 | done 120 | done 121 | 122 | # RetroInfer 123 | for bsz in 1 2 4 8 124 | do 125 | for round in 1 126 | do 127 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 128 | --model_name Qwen/Qwen2.5-72B-Instruct \ 129 | --attn_type RetroInfer \ 130 | --device auto \ 131 | --context_len 120000 \ 132 | --task_name NIAH \ 133 | --batch_size $bsz > different_models_logs/retroinfer_qwen72b_bsz${bsz}_${round}.log 2>&1 134 | done 135 | done 136 | 137 | for bsz in 16 138 | do 139 | for round in 1 140 | do 141 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 142 | --model_name Qwen/Qwen2.5-72B-Instruct \ 143 | --attn_type RetroInfer \ 144 | --device auto \ 145 | --context_len 120000 \ 146 | --task_name NIAH \ 147 | --batch_size $bsz > different_models_logs/retroinfer_qwen72b_bsz${bsz}_${round}.log 2>&1 148 | done 149 | done 150 | 151 | for bsz in 32 152 | do 
153 | for round in 1 154 | do 155 | numactl --cpunodebind=0 --membind=0,1,2,3 python -u test.py \ 156 | --model_name Qwen/Qwen2.5-72B-Instruct \ 157 | --attn_type RetroInfer \ 158 | --device auto \ 159 | --context_len 120000 \ 160 | --task_name NIAH \ 161 | --batch_size $bsz > different_models_logs/retroinfer_qwen72b_bsz${bsz}_${round}.log 2>&1 162 | done 163 | done 164 | unset CUDA_VISIBLE_DEVICES 165 | -------------------------------------------------------------------------------- /throughput_eval/run_different_tasks.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | mkdir -p different_tasks_logs 4 | 5 | ################################ Full Attention ################################ 6 | for bsz in 1 4 7 | do 8 | for round in 1 9 | do 10 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 11 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 12 | --attn_type Full_Flash_Attn \ 13 | --context_len 120000 \ 14 | --task_name fwe \ 15 | --batch_size $bsz > different_tasks_logs/full_attn_fwe_bsz${bsz}_${round}.log 2>&1 16 | done 17 | done 18 | 19 | for bsz in 1 4 20 | do 21 | for round in 1 22 | do 23 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 24 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 25 | --attn_type Full_Flash_Attn \ 26 | --context_len 120000 \ 27 | --task_name vt \ 28 | --batch_size $bsz > different_tasks_logs/full_attn_vt_bsz${bsz}_${round}.log 2>&1 29 | done 30 | done 31 | 32 | for bsz in 1 4 33 | do 34 | for round in 1 35 | do 36 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 37 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 38 | --attn_type Full_Flash_Attn \ 39 | --context_len 120000 \ 40 | --task_name qa1 \ 41 | --batch_size $bsz > different_tasks_logs/full_attn_qa1_bsz${bsz}_${round}.log 2>&1 42 | done 43 | done 44 | 45 | 46 | ################################ RetroInfer ################################ 47 | # fwe 48 | for bsz in 1 49 | do 50 | for round in 1 51 | do 52 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 53 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 54 | --attn_type RetroInfer \ 55 | --context_len 120000 \ 56 | --task_name fwe \ 57 | --batch_size $bsz > different_tasks_logs/retroinfer_fwe_bsz${bsz}_${round}.log 2>&1 58 | done 59 | done 60 | 61 | for bsz in 32 62 | do 63 | for round in 1 64 | do 65 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 66 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 67 | --attn_type RetroInfer \ 68 | --context_len 120000 \ 69 | --task_name fwe \ 70 | --batch_size $bsz > different_tasks_logs/retroinfer_fwe_bsz${bsz}_${round}.log 2>&1 71 | done 72 | done 73 | 74 | # vt 75 | for bsz in 1 76 | do 77 | for round in 1 78 | do 79 | numactl --cpunodebind=0 --membind=0 python -u test.py \ 80 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 81 | --attn_type RetroInfer \ 82 | --context_len 120000 \ 83 | --task_name vt \ 84 | --batch_size $bsz > different_tasks_logs/retroinfer_vt_bsz${bsz}_${round}.log 2>&1 85 | done 86 | done 87 | 88 | for bsz in 32 89 | do 90 | for round in 1 91 | do 92 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \ 93 | --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \ 94 | --attn_type RetroInfer \ 95 | --context_len 120000 \ 96 | --task_name vt \ 97 | --batch_size $bsz > different_tasks_logs/retroinfer_vt_bsz${bsz}_${round}.log 2>&1 98 | done 99 | done 100 | 101 | # qa1 102 | for bsz 
in 1
103 | do
104 | for round in 1
105 | do
106 | numactl --cpunodebind=0 --membind=0 python -u test.py \
107 |     --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \
108 |     --attn_type RetroInfer \
109 |     --context_len 120000 \
110 |     --task_name qa1 \
111 |     --batch_size $bsz > different_tasks_logs/retroinfer_qa1_bsz${bsz}_${round}.log 2>&1
112 | done
113 | done
114 | 
115 | for bsz in 32
116 | do
117 | for round in 1
118 | do
119 | numactl --cpunodebind=0 --membind=0,1 python -u test.py \
120 |     --model_name gradientai/Llama-3-8B-Instruct-Gradient-1048k \
121 |     --attn_type RetroInfer \
122 |     --context_len 120000 \
123 |     --task_name qa1 \
124 |     --batch_size $bsz > different_tasks_logs/retroinfer_qa1_bsz${bsz}_${round}.log 2>&1
125 | done
126 | done
127 | 
128 | unset CUDA_VISIBLE_DEVICES
129 | 
-------------------------------------------------------------------------------- /throughput_eval/test.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import torch
5 | import argparse
6 | import random
7 | import numpy as np
8 | from termcolor import colored
9 | from transformers import AutoTokenizer
10 | PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
11 | sys.path.append(PROJECT_ROOT)
12 | from model_hub import LlamaModel, QwenModel
13 | 
14 | 
15 | def set_seed(seed):
16 |     random.seed(seed)
17 |     np.random.seed(seed)
18 |     torch.manual_seed(seed)
19 |     torch.cuda.manual_seed(seed)
20 |     torch.cuda.manual_seed_all(seed)
21 | 
22 | 
23 | def parse_args():
24 |     parser = argparse.ArgumentParser(description="Test example")
25 |     parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
26 |     parser.add_argument("--context_len", type=int, default=-1, help="Input context length")
27 |     parser.add_argument("--device", type=str, default="cuda:0", help="Device")
28 |     parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], help="Dtype")
29 |     parser.add_argument("--attn_type", type=str, default="Full_Flash_Attn", \
30 |         choices=["Full_Flash_Attn", "RetroInfer"], help="Attention method")
31 |     parser.add_argument("--model_name", type=str, default="gradientai/Llama-3-8B-Instruct-Gradient-1048k", \
32 |         choices=["gradientai/Llama-3-8B-Instruct-Gradient-1048k", "Qwen/Qwen2.5-7B-Instruct", \
33 |         "Qwen/Qwen2.5-72B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"], help="huggingface model name")
34 |     parser.add_argument("--task_name", type=str, default="NIAH", choices=["NIAH", "fwe", "vt", "qa1"], \
35 |         help="Test task name")
36 |     args = parser.parse_args()
37 | 
38 |     return args
39 | 
40 | 
41 | def load_model(model_name, max_len, dtype, device):
42 |     if 'Llama' in model_name:
43 |         llm = LlamaModel(model_name,
44 |                          max_length=max_len,
45 |                          dtype=dtype,
46 |                          device_map=device)
47 |     elif 'Qwen' in model_name:
48 |         llm = QwenModel(model_name,
49 |                         max_length=max_len,
50 |                         dtype=dtype,
51 |                         device_map=device)
52 |     else:
53 |         raise ValueError(f"Unsupported model: {model_name}")
54 | 
55 |     return llm
56 | 
57 | 
58 | def generate_config(model_name, context_len, attn_type):
59 |     CONFIG_DIR = os.path.join(PROJECT_ROOT, "config")
60 |     MODEL_NAME = model_name.split("/")[-1]+'.json'
61 |     CONFIG_FILE = os.path.join(CONFIG_DIR, MODEL_NAME)
62 |     with open(CONFIG_FILE, "r") as f:
63 |         original_config = json.load(f)
64 | 
65 |     n_clusters = max(int(context_len/16), 1)
66 |     n_segments = max(int(context_len/8000), 1)
67 |     # compute the nearest multiple of (n_segments*32)
68 |     lower =
(n_clusters // (n_segments*32)) * (n_segments*32) 69 | upper = lower + (n_segments*32) 70 | n_clusters = lower if abs(n_clusters - lower) <= abs(n_clusters - upper) else upper 71 | nprobe = int(n_clusters*0.018) 72 | 73 | if attn_type == 'RetroInfer': 74 | original_config[attn_type]['n_centroids'] = n_clusters 75 | original_config[attn_type]['n_segment'] = n_segments 76 | original_config[attn_type]['nprobe'] = nprobe 77 | original_config[attn_type]['cache_cluster_num'] = nprobe*3 78 | original_config[attn_type]['max_compute_cluster_num'] = int(n_clusters/4) 79 | 80 | if attn_type != "Full_Flash_Attn": 81 | print(original_config[attn_type]) 82 | 83 | return original_config 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parse_args() 88 | print(args) 89 | 90 | set_seed(2025) 91 | 92 | model_name = args.model_name 93 | batch_size = args.batch_size 94 | attn_type = args.attn_type 95 | dtype = torch.float16 if args.dtype=='fp16' else torch.bfloat16 96 | device = args.device 97 | task_name = args.task_name 98 | 99 | if task_name == "NIAH": 100 | TEST_DIR = os.path.join(PROJECT_ROOT, "throughput_eval") 101 | TEST_FILE = os.path.join(TEST_DIR, f"test_data/NIAH_{args.context_len}.json") 102 | data = json.load(open(TEST_FILE))[0] 103 | prompt = data['input'] 104 | groundtruth = data['answer'] 105 | attn_config = generate_config(model_name, args.context_len, attn_type) 106 | else: 107 | TEST_DIR = os.path.join(PROJECT_ROOT, "throughput_eval") 108 | TEST_FILE = os.path.join(TEST_DIR, f"test_data/{task_name}.json") 109 | data = json.load(open(TEST_FILE)) 110 | prompt = data['input'] 111 | groundtruth = data['outputs'] 112 | attn_config = generate_config(model_name, 120000, attn_type) 113 | 114 | prompts = [prompt for _ in range(batch_size)] 115 | tokenizer = AutoTokenizer.from_pretrained(model_name) 116 | tokenizer.pad_token = tokenizer.eos_token 117 | tokenizer.padding_side = "left" 118 | inputs = tokenizer(prompts, return_tensors="pt", padding=True) 119 | input_ids = inputs.input_ids 120 | attention_masks = inputs.attention_mask 121 | 122 | input_len = input_ids.shape[1] 123 | gen_len = 100 124 | max_len = input_len + gen_len 125 | print(colored(f"Input length: {input_len}", 'yellow')) 126 | 127 | llm = load_model(model_name, max_len, dtype, device) 128 | out = llm.generate(attention_type=attn_type, 129 | inputs_ids = input_ids.to(llm.layers[0].device), 130 | attention_masks = attention_masks.to(llm.layers[0].device), 131 | max_new_length=gen_len, attn_config=attn_config) 132 | 133 | result = tokenizer.batch_decode(out, skip_special_tokens=True) 134 | print(groundtruth) 135 | print(result) 136 | --------------------------------------------------------------------------------
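The RetroInfer index sizing in `generate_config` above is derived purely from the context length: one candidate centroid per 16 tokens, rounded to the nearest multiple of `n_segments*32`, with `nprobe` set to about 1.8% of the centroids. The trace below is an illustrative hand-computation of that same arithmetic for the 120K-token setting used throughout the throughput scripts; it introduces a temporary `multiple` variable for readability and is not additional configuration (note that simple_test.py uses a segment divisor of 8192 instead of 8000, so its numbers differ slightly).

# Worked trace of generate_config() above for context_len = 120000 (illustrative only).
context_len = 120000
n_clusters = max(int(context_len / 16), 1)       # 7500 candidate centroids
n_segments = max(int(context_len / 8000), 1)     # 15
multiple = n_segments * 32                       # 480
lower = (n_clusters // multiple) * multiple      # 7200
upper = lower + multiple                         # 7680
n_clusters = lower if abs(n_clusters - lower) <= abs(n_clusters - upper) else upper  # 7680 (upper is closer)
nprobe = int(n_clusters * 0.018)                 # 138
print(n_clusters, n_segments, nprobe, nprobe * 3, int(n_clusters / 4))
# -> 7680 15 138 414 1920
#    (n_centroids, n_segment, nprobe, cache_cluster_num, max_compute_cluster_num)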