├── .gitignore ├── DraftRetriever ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── build.rs ├── draftretriever │ ├── __init__.py │ ├── draftretriever.pyi │ └── py.typed ├── pyproject.toml ├── src │ ├── lib.rs │ └── libsais │ │ ├── libsais.c │ │ └── libsais.h └── wheels │ └── draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl ├── LICENSE ├── README.md ├── assets ├── rest_overview.png └── rest_results.png ├── datastore ├── get_datastore_chat.py └── get_datastore_code.py ├── human_eval ├── HumanEval.jsonl.gz ├── baseline_test.py ├── dataset.py ├── rest_test.py └── results │ ├── baseline_test.txt │ └── rest_test.txt ├── llm_judge ├── data │ ├── judge_prompts.jsonl │ └── mt_bench │ │ ├── model_answer │ │ ├── baseline-vicuna-7b-v1.5-temperature-0.0-top_p-0.jsonl │ │ └── rest-vicuna-7b-v1.5-temperature-0.0-top_p-0.jsonl │ │ └── question.jsonl ├── gen_model_answer_baseline.py ├── gen_model_answer_rest.py ├── run_baseline.sh └── run_rest.sh ├── requirements.txt └── rest ├── __init__.py ├── inference ├── __init__.py └── cli.py └── model ├── __init__.py ├── kv_cache.py ├── modeling_llama_kv.py ├── rest_model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated files 2 | *.idx 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /DraftRetriever/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "ahash" 7 | version = "0.7.7" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "5a824f2aa7e75a0c98c5a504fceb80649e9c35265d44525b5f94de4771a395cd" 10 | dependencies = [ 11 | "getrandom", 12 | "once_cell", 13 | "version_check", 14 | ] 15 | 16 | [[package]] 17 | name = "autocfg" 18 | version = "1.1.0" 19 | source = "registry+https://github.com/rust-lang/crates.io-index" 20 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 21 | 22 | [[package]] 23 | name = "bitflags" 24 | version = "1.3.2" 25 | source = "registry+https://github.com/rust-lang/crates.io-index" 26 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 27 | 28 | [[package]] 29 | name = "bstr" 30 | version = "0.2.17" 31 | source = "registry+https://github.com/rust-lang/crates.io-index" 32 | checksum = "ba3569f383e8f1598449f1a423e72e99569137b47740b1da11ef19af3d5c3223" 33 | dependencies = [ 34 | "lazy_static", 35 | "memchr", 36 | "regex-automata", 37 | ] 38 | 39 | [[package]] 40 | name = "byteorder" 41 | version = "1.5.0" 42 | source = "registry+https://github.com/rust-lang/crates.io-index" 43 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 44 | 45 | [[package]] 46 | name = "cc" 47 | version = "1.0.83" 48 | source = "registry+https://github.com/rust-lang/crates.io-index" 49 | checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" 50 | dependencies = [ 51 | "jobserver", 52 | "libc", 53 | ] 54 | 55 | [[package]] 56 | name = "cfg-if" 57 | version = "1.0.0" 58 | source = "registry+https://github.com/rust-lang/crates.io-index" 59 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 60 | 61 | [[package]] 62 | name = "crossbeam-deque" 63 | version = "0.8.3" 64 | source = "registry+https://github.com/rust-lang/crates.io-index" 65 | checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" 66 | dependencies = [ 67 | "cfg-if", 68 | "crossbeam-epoch", 69 | "crossbeam-utils", 70 | ] 71 | 72 | [[package]] 73 | name = "crossbeam-epoch" 74 | version = "0.9.15" 75 | source = "registry+https://github.com/rust-lang/crates.io-index" 76 | checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" 77 | dependencies = [ 78 | "autocfg", 79 | "cfg-if", 80 | "crossbeam-utils", 81 | "memoffset 0.9.0", 82 | "scopeguard", 83 | ] 84 | 85 | [[package]] 86 | name = "crossbeam-utils" 87 | version = "0.8.16" 88 | source = "registry+https://github.com/rust-lang/crates.io-index" 89 | checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" 90 | dependencies = [ 91 | "cfg-if", 92 | ] 93 | 94 | [[package]] 95 | name = "draftretriever" 96 | version = "0.1.0" 97 | dependencies = [ 98 | "ahash", 99 | "bstr", 100 | "byteorder", 101 | "cc", 102 | "memchr", 103 | "parking_lot", 104 | "pyo3", 105 | "rayon", 106 | ] 107 | 108 | [[package]] 109 | name = "either" 110 | version = "1.9.0" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" 113 | 114 | [[package]] 115 | name = "getrandom" 116 | version = "0.2.11" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" 119 | dependencies = [ 120 | "cfg-if", 121 | "libc", 122 | "wasi", 123 | ] 124 | 125 | [[package]] 126 | name = "indoc" 127 | version = "1.0.9" 
128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" 130 | 131 | [[package]] 132 | name = "jobserver" 133 | version = "0.1.27" 134 | source = "registry+https://github.com/rust-lang/crates.io-index" 135 | checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" 136 | dependencies = [ 137 | "libc", 138 | ] 139 | 140 | [[package]] 141 | name = "lazy_static" 142 | version = "1.4.0" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 145 | 146 | [[package]] 147 | name = "libc" 148 | version = "0.2.150" 149 | source = "registry+https://github.com/rust-lang/crates.io-index" 150 | checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" 151 | 152 | [[package]] 153 | name = "lock_api" 154 | version = "0.4.11" 155 | source = "registry+https://github.com/rust-lang/crates.io-index" 156 | checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" 157 | dependencies = [ 158 | "autocfg", 159 | "scopeguard", 160 | ] 161 | 162 | [[package]] 163 | name = "memchr" 164 | version = "2.6.4" 165 | source = "registry+https://github.com/rust-lang/crates.io-index" 166 | checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" 167 | 168 | [[package]] 169 | name = "memoffset" 170 | version = "0.6.5" 171 | source = "registry+https://github.com/rust-lang/crates.io-index" 172 | checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" 173 | dependencies = [ 174 | "autocfg", 175 | ] 176 | 177 | [[package]] 178 | name = "memoffset" 179 | version = "0.9.0" 180 | source = "registry+https://github.com/rust-lang/crates.io-index" 181 | checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" 182 | dependencies = [ 183 | "autocfg", 184 | ] 185 | 186 | [[package]] 187 | name = "once_cell" 188 | version = "1.18.0" 189 | source = "registry+https://github.com/rust-lang/crates.io-index" 190 | checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" 191 | 192 | [[package]] 193 | name = "parking_lot" 194 | version = "0.12.1" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" 197 | dependencies = [ 198 | "lock_api", 199 | "parking_lot_core", 200 | ] 201 | 202 | [[package]] 203 | name = "parking_lot_core" 204 | version = "0.9.9" 205 | source = "registry+https://github.com/rust-lang/crates.io-index" 206 | checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" 207 | dependencies = [ 208 | "cfg-if", 209 | "libc", 210 | "redox_syscall", 211 | "smallvec", 212 | "windows-targets", 213 | ] 214 | 215 | [[package]] 216 | name = "proc-macro2" 217 | version = "1.0.69" 218 | source = "registry+https://github.com/rust-lang/crates.io-index" 219 | checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" 220 | dependencies = [ 221 | "unicode-ident", 222 | ] 223 | 224 | [[package]] 225 | name = "pyo3" 226 | version = "0.17.3" 227 | source = "registry+https://github.com/rust-lang/crates.io-index" 228 | checksum = "268be0c73583c183f2b14052337465768c07726936a260f480f0857cb95ba543" 229 | dependencies = [ 230 | "cfg-if", 231 | "indoc", 232 | "libc", 233 | "memoffset 0.6.5", 234 | "parking_lot", 235 | "pyo3-build-config", 236 | "pyo3-ffi", 237 | "pyo3-macros", 238 | 
"unindent", 239 | ] 240 | 241 | [[package]] 242 | name = "pyo3-build-config" 243 | version = "0.17.3" 244 | source = "registry+https://github.com/rust-lang/crates.io-index" 245 | checksum = "28fcd1e73f06ec85bf3280c48c67e731d8290ad3d730f8be9dc07946923005c8" 246 | dependencies = [ 247 | "once_cell", 248 | "target-lexicon", 249 | ] 250 | 251 | [[package]] 252 | name = "pyo3-ffi" 253 | version = "0.17.3" 254 | source = "registry+https://github.com/rust-lang/crates.io-index" 255 | checksum = "0f6cb136e222e49115b3c51c32792886defbfb0adead26a688142b346a0b9ffc" 256 | dependencies = [ 257 | "libc", 258 | "pyo3-build-config", 259 | ] 260 | 261 | [[package]] 262 | name = "pyo3-macros" 263 | version = "0.17.3" 264 | source = "registry+https://github.com/rust-lang/crates.io-index" 265 | checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" 266 | dependencies = [ 267 | "proc-macro2", 268 | "pyo3-macros-backend", 269 | "quote", 270 | "syn", 271 | ] 272 | 273 | [[package]] 274 | name = "pyo3-macros-backend" 275 | version = "0.17.3" 276 | source = "registry+https://github.com/rust-lang/crates.io-index" 277 | checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" 278 | dependencies = [ 279 | "proc-macro2", 280 | "quote", 281 | "syn", 282 | ] 283 | 284 | [[package]] 285 | name = "quote" 286 | version = "1.0.33" 287 | source = "registry+https://github.com/rust-lang/crates.io-index" 288 | checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" 289 | dependencies = [ 290 | "proc-macro2", 291 | ] 292 | 293 | [[package]] 294 | name = "rayon" 295 | version = "1.8.0" 296 | source = "registry+https://github.com/rust-lang/crates.io-index" 297 | checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" 298 | dependencies = [ 299 | "either", 300 | "rayon-core", 301 | ] 302 | 303 | [[package]] 304 | name = "rayon-core" 305 | version = "1.12.0" 306 | source = "registry+https://github.com/rust-lang/crates.io-index" 307 | checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" 308 | dependencies = [ 309 | "crossbeam-deque", 310 | "crossbeam-utils", 311 | ] 312 | 313 | [[package]] 314 | name = "redox_syscall" 315 | version = "0.4.1" 316 | source = "registry+https://github.com/rust-lang/crates.io-index" 317 | checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" 318 | dependencies = [ 319 | "bitflags", 320 | ] 321 | 322 | [[package]] 323 | name = "regex-automata" 324 | version = "0.1.10" 325 | source = "registry+https://github.com/rust-lang/crates.io-index" 326 | checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" 327 | 328 | [[package]] 329 | name = "scopeguard" 330 | version = "1.2.0" 331 | source = "registry+https://github.com/rust-lang/crates.io-index" 332 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 333 | 334 | [[package]] 335 | name = "smallvec" 336 | version = "1.11.2" 337 | source = "registry+https://github.com/rust-lang/crates.io-index" 338 | checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" 339 | 340 | [[package]] 341 | name = "syn" 342 | version = "1.0.109" 343 | source = "registry+https://github.com/rust-lang/crates.io-index" 344 | checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" 345 | dependencies = [ 346 | "proc-macro2", 347 | "quote", 348 | "unicode-ident", 349 | ] 350 | 351 | [[package]] 352 | name = "target-lexicon" 353 | version = "0.12.12" 354 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 355 | checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" 356 | 357 | [[package]] 358 | name = "unicode-ident" 359 | version = "1.0.12" 360 | source = "registry+https://github.com/rust-lang/crates.io-index" 361 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 362 | 363 | [[package]] 364 | name = "unindent" 365 | version = "0.1.11" 366 | source = "registry+https://github.com/rust-lang/crates.io-index" 367 | checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c" 368 | 369 | [[package]] 370 | name = "version_check" 371 | version = "0.9.4" 372 | source = "registry+https://github.com/rust-lang/crates.io-index" 373 | checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 374 | 375 | [[package]] 376 | name = "wasi" 377 | version = "0.11.0+wasi-snapshot-preview1" 378 | source = "registry+https://github.com/rust-lang/crates.io-index" 379 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 380 | 381 | [[package]] 382 | name = "windows-targets" 383 | version = "0.48.5" 384 | source = "registry+https://github.com/rust-lang/crates.io-index" 385 | checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" 386 | dependencies = [ 387 | "windows_aarch64_gnullvm", 388 | "windows_aarch64_msvc", 389 | "windows_i686_gnu", 390 | "windows_i686_msvc", 391 | "windows_x86_64_gnu", 392 | "windows_x86_64_gnullvm", 393 | "windows_x86_64_msvc", 394 | ] 395 | 396 | [[package]] 397 | name = "windows_aarch64_gnullvm" 398 | version = "0.48.5" 399 | source = "registry+https://github.com/rust-lang/crates.io-index" 400 | checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" 401 | 402 | [[package]] 403 | name = "windows_aarch64_msvc" 404 | version = "0.48.5" 405 | source = "registry+https://github.com/rust-lang/crates.io-index" 406 | checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" 407 | 408 | [[package]] 409 | name = "windows_i686_gnu" 410 | version = "0.48.5" 411 | source = "registry+https://github.com/rust-lang/crates.io-index" 412 | checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" 413 | 414 | [[package]] 415 | name = "windows_i686_msvc" 416 | version = "0.48.5" 417 | source = "registry+https://github.com/rust-lang/crates.io-index" 418 | checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" 419 | 420 | [[package]] 421 | name = "windows_x86_64_gnu" 422 | version = "0.48.5" 423 | source = "registry+https://github.com/rust-lang/crates.io-index" 424 | checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" 425 | 426 | [[package]] 427 | name = "windows_x86_64_gnullvm" 428 | version = "0.48.5" 429 | source = "registry+https://github.com/rust-lang/crates.io-index" 430 | checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" 431 | 432 | [[package]] 433 | name = "windows_x86_64_msvc" 434 | version = "0.48.5" 435 | source = "registry+https://github.com/rust-lang/crates.io-index" 436 | checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" 437 | -------------------------------------------------------------------------------- /DraftRetriever/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "draftretriever" 3 | version = "0.1.0" 4 | authors = ["Zhenyu He ", "Zexuan Zhong ", "Tianle Cai "] 5 | 
edition = "2021"
6 | description = "Retriever for searching draft tokens for speculative decoding"
7 | readme = "README.md"
8 | repository = "https://github.com/zhenyuhe00/DraftRetriever"
9 | homepage = "https://github.com/zhenyuhe00/DraftRetriever"
10 | license = "MIT"
11 | keywords = [
12 | "substring",
13 | "pattern",
14 | "search",
15 | "suffix",
16 | "array",
17 | "rust",
18 | "pyo3"
19 | ]
20 |
21 |
22 | [lib]
23 | name = "draftretriever"
24 | crate-type = ["cdylib"]
25 |
26 | [dependencies]
27 | ahash = "0.7"
28 | bstr = "0.2"
29 | byteorder = "1"
30 | memchr = "2"
31 | parking_lot = "0.12"
32 | rayon = "1"
33 |
34 | [dependencies.pyo3]
35 | version = "0.17.0"
36 | features = ["extension-module"]
37 |
38 | [build-dependencies]
39 | cc = { version = "1.0", features = ["parallel"] }
40 |
41 | [profile.release]
42 | lto = true
43 | panic = "abort"
44 |
-------------------------------------------------------------------------------- /DraftRetriever/LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Gal Ben David
4 | Copyright (c) 2023 Zhenyu He, Zexuan Zhong, Tianle Cai
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
-------------------------------------------------------------------------------- /DraftRetriever/README.md: --------------------------------------------------------------------------------
1 | Retriever for searching draft tokens for speculative decoding
2 |
3 |
4 | ## Table of Contents
5 |
6 | - [Table of Contents](#table-of-contents)
7 | - [About The Project](#about-the-project)
8 | - [Built With](#built-with)
9 | - [Installation](#installation)
10 | - [Usage](#usage)
11 | - [License](#license)
12 | - [Acknowledgement](#acknowledgement)
13 | - [Contact](#contact)
14 |
15 |
16 | ## About The Project
17 |
18 | DraftRetriever is a library designed to search for draft tokens for speculative decoding. To achieve speed and efficiency, the library is written in Rust. For string indexing, the library uses the [libsais](https://github.com/IlyaGrebnov/libsais) suffix array construction library. The datastore it creates consists of the original 16-bit tokens and a 32-bit suffix array.
19 |
20 | The module implements a single method for searching.
21 | - `search` - Find multiple candidate continuations given the preceding tokens, and return the most probable draft tokens by constructing a Trie. It also returns the draft buffers (attention mask, tree indices, position IDs, and retrieve indices).
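For reference, each call to `Writer.dump_data` appends one chunk to the `.idx` file: a little-endian `u32` byte count, the 16-bit tokens, another `u32` byte count, and the 32-bit suffix array entries (see `src/lib.rs`). The snippet below is a minimal read-side sketch of that layout, assuming the 0.1.0 format; the `read_chunks` helper is purely illustrative and not part of the package (`Reader` parses the file for you).

```python
import struct

def read_chunks(path):
    # Illustrative only. Each chunk, as written by Writer.dump_data:
    # [u32 token_bytes][u16 tokens...][u32 sa_bytes][i32 suffix-array entries...]
    chunks = []
    with open(path, 'rb') as f:
        while True:
            header = f.read(4)
            if len(header) < 4:
                break
            (token_bytes,) = struct.unpack('<I', header)
            tokens = struct.unpack('<%dH' % (token_bytes // 2), f.read(token_bytes))
            (sa_bytes,) = struct.unpack('<I', f.read(4))
            suffixes = struct.unpack('<%di' % (sa_bytes // 4), f.read(sa_bytes))
            chunks.append((list(tokens), list(suffixes)))
    return chunks
```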
22 |
23 |
24 | ### Built With
25 |
26 | * [libsais](https://github.com/IlyaGrebnov/libsais)
27 |
28 |
29 | ### Installation
30 |
31 | **Use pre-compiled wheels**
32 | ```sh
33 | pip3 install wheels/draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl
34 | ```
35 |
36 | **Build from source**
37 | ```sh
38 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
39 | maturin build --release --strip -i python3.9 # will produce a .whl file
40 | pip3 install [.whl]
41 | ```
42 |
43 |
44 | ## Usage
45 |
46 | Create a datastore
47 | ```python
48 | import draftretriever
49 |
50 | # creating a new datastore
51 | # if a file with this name already exists, it will be overwritten
52 | writer = draftretriever.Writer(
53 | index_file_path='output.idx',
54 | # vocab_size=tokenizer.vocab_size
55 | )
56 |
57 | # adding entries to the new datastore
58 | writer.add_entry([1, 2, 3, 4]) # a list of token ids
59 | writer.add_entry([1, 2, 3, 4])
60 | writer.add_entry([2, 3, 5, 6])
61 |
62 | # making sure the data is dumped to the file
63 | writer.finalize()
64 | ```
65 |
66 | Search draft tokens
67 | ```python
68 | import draftretriever
69 |
70 | # opening a datastore for searching
71 | reader = draftretriever.Reader(
72 | index_file_path='output.idx',
73 | )
74 |
75 | # search for draft tokens
76 | preceding = [2, 3]
77 | # "choices" is the maximum number of draft tokens. The implementation is not very strict and has some randomness.
78 | retrieved_token_list, _draft_attn_mask, _tree_indices, _draft_position_ids, _retrieve_indices = reader.search(preceding, choices=2)
79 | print(retrieved_token_list)
80 | >>> [[4]] or [[4], [5]]
81 | # retrieved_token_list is a list of the selected paths (sequences) in the Trie. Each sequence is padded with -2 to the maximum length of these sequences.
82 | ```
83 |
84 | ## License
85 |
86 | Distributed under the MIT License. See `LICENSE` for more information.
87 |
88 | ## Acknowledgement
89 | The main framework is adapted from [PySubstringSearch](https://github.com/Intsights/PySubstringSearch).
90 |
91 |
92 |
-------------------------------------------------------------------------------- /DraftRetriever/build.rs: --------------------------------------------------------------------------------
1 | fn main() {
2 | println!("cargo:rerun-if-changed=src/libsais/libsais.c");
3 |
4 | let src = [
5 | "src/libsais/libsais.c",
6 | ];
7 | let mut builder = cc::Build::new();
8 | let build = builder
9 | .files(src.iter());
10 | build.compile("libsais");
11 | }
12 |
-------------------------------------------------------------------------------- /DraftRetriever/draftretriever/__init__.py: --------------------------------------------------------------------------------
1 | import typing
2 |
3 | from .
import draftretriever 4 | 5 | 6 | class Writer: 7 | def __init__( 8 | self, 9 | index_file_path: str, 10 | max_chunk_len: typing.Optional[int] = None, 11 | vocab_size: typing.Optional[int] = None, 12 | ) -> None: 13 | self.writer = draftretriever.Writer( 14 | index_file_path=index_file_path, 15 | max_chunk_len=max_chunk_len, 16 | vocab_size=vocab_size, 17 | ) 18 | 19 | def add_entry( 20 | self, 21 | py_text: typing.List, 22 | ) -> None: 23 | self.writer.add_entry( 24 | py_text=py_text, 25 | ) 26 | 27 | def dump_data( 28 | self, 29 | ) -> None: 30 | self.writer.dump_data() 31 | 32 | def finalize( 33 | self, 34 | ) -> None: 35 | self.writer.finalize() 36 | 37 | 38 | class Reader: 39 | def __init__( 40 | self, 41 | index_file_path: str, 42 | ) -> None: 43 | self.reader = draftretriever.Reader( 44 | index_file_path=index_file_path, 45 | ) 46 | 47 | def search( 48 | self, 49 | py_substring: typing.List, 50 | k: typing.Optional[int] = None, 51 | choices: typing.Optional[int] = None, 52 | long: typing.Optional[int] = None, 53 | ): 54 | return self.reader.search( 55 | py_substring=py_substring, 56 | k=k, 57 | choices=choices, 58 | long=long, 59 | ) 60 | 61 | -------------------------------------------------------------------------------- /DraftRetriever/draftretriever/draftretriever.pyi: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | 4 | class Writer: 5 | def __init__( 6 | self, 7 | index_file_path: str, 8 | max_chunk_len: typing.Optional[int] = None, 9 | vocab_size: typing.Optional[int] = None, 10 | ) -> None: ... 11 | 12 | 13 | def add_entry( 14 | self, 15 | py_text: typing.List, 16 | ) -> None: ... 17 | 18 | def dump_data( 19 | self, 20 | ) -> None: ... 21 | 22 | def finalize( 23 | self, 24 | ) -> None: ... 25 | 26 | 27 | class Reader: 28 | def __init__( 29 | self, 30 | index_file_path: str, 31 | ) -> None: ... 32 | 33 | def search( 34 | self, 35 | py_substring: typing.List, 36 | k: typing.Optional[int] = None, 37 | choices: typing.Optional[int] = None, 38 | long: typing.Optional[int] = None, 39 | ): ... 
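        # Return value (mirroring Reader.search in src/lib.rs): a tuple of
        # (retrieved_token_list, draft_attn_mask, tree_indices, draft_position_ids, retrieve_indices),
        # i.e. (List[List[int]], List[List[int]], List[int], List[int], List[List[int]]).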
40 | -------------------------------------------------------------------------------- /DraftRetriever/draftretriever/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/DraftRetriever/draftretriever/py.typed -------------------------------------------------------------------------------- /DraftRetriever/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=0.12,<0.13"] 3 | build-backend = "maturin" 4 | 5 | [tool.maturin] 6 | sdist-include = [ 7 | "src/*", 8 | "Cargo.toml", 9 | "draftretriever/*.py", 10 | "draftretriever/*.pyi" 11 | ] 12 | 13 | [tool.poetry] 14 | name = "draftretriever" 15 | version = "0.1.0" 16 | authors = ["Zhenyu He ", "Zexuan Zhong ", "Tianle Cai "] 17 | description = "Retriver for searching draft tokens for speculative decoding" 18 | readme = "README.md" 19 | repository = "https://github.com/zhenyuhe00/DraftRetriever" 20 | homepage = "https://github.com/zhenyuhe00/DraftRetriever" 21 | license = "MIT" 22 | keywords = [ 23 | "substring", 24 | "pattern", 25 | "search", 26 | "suffix", 27 | "array", 28 | "rust", 29 | "pyo3" 30 | ] 31 | classifiers = [ 32 | "License :: OSI Approved :: MIT License", 33 | "Operating System :: MacOS", 34 | "Operating System :: Microsoft", 35 | "Operating System :: POSIX :: Linux", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | "Programming Language :: Python :: 3.9", 39 | "Programming Language :: Python :: 3.10", 40 | "Programming Language :: Python :: 3.11", 41 | "Programming Language :: Rust", 42 | ] 43 | 44 | [tool.poetry.dependencies] 45 | python = "^3.7" 46 | 47 | [tool.poetry.dev-dependencies] 48 | pytest = "*" 49 | gitpython = "*" 50 | wheel = "*" 51 | pytest-runner = "*" 52 | maturin = "*" 53 | 54 | [tool.pytest.ini_options] 55 | minversion = "6.0" 56 | addopts = [ 57 | "--tb=native", 58 | "--pythonwarnings=all", 59 | ] 60 | testpaths = [ 61 | "tests", 62 | ] 63 | -------------------------------------------------------------------------------- /DraftRetriever/src/lib.rs: -------------------------------------------------------------------------------- 1 | // The code for retrival is adapted from https://github.com/Intsights/PySubstringSearch; 2 | // The code for drafft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124 3 | use ahash::AHashSet; 4 | use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian}; 5 | use parking_lot::Mutex; 6 | use pyo3::exceptions; 7 | use pyo3::prelude::*; 8 | use rayon::prelude::*; 9 | use std::fs::File; 10 | use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write}; 11 | use std::str; 12 | use std::sync::Arc; 13 | use std::collections::HashMap; 14 | use std::cmp::Reverse; 15 | use std::collections::HashSet; 16 | use std::cmp; 17 | use std::cmp::Ordering; 18 | use pyo3::types::PyList; 19 | use std::collections::BinaryHeap; 20 | use std::fs; 21 | use std::io::Cursor; 22 | 23 | extern "C" { 24 | pub fn libsais_int( 25 | data: *const i32, 26 | suffix_array: *mut i32, 27 | data_len: i32, 28 | suffix_array_extra_space: i32, 29 | symbol_frequency_table: i32, 30 | ) -> i32; 31 | } 32 | 33 | fn construct_suffix_array( 34 | buffer: &[i32], 35 | vocab_size: i32, 36 | ) -> Vec { 37 | let mut suffix_array = vec![0; buffer.len()]; 38 | 39 | unsafe { 40 | libsais_int( 41 | 
buffer.as_ptr(), 42 | suffix_array.as_mut_ptr(), 43 | buffer.len() as i32, 44 | vocab_size, 45 | 0, 46 | ); 47 | } 48 | 49 | suffix_array 50 | } 51 | 52 | #[pyclass] 53 | struct Writer { 54 | index_file: BufWriter, 55 | buffer: Vec, 56 | vocab_size: i32, 57 | } 58 | 59 | #[pymethods] 60 | impl Writer { 61 | #[new] 62 | fn new( 63 | index_file_path: &str, 64 | max_chunk_len: Option, 65 | vocab_size: Option, 66 | ) -> PyResult { 67 | let index_file = File::create(index_file_path)?; 68 | let index_file = BufWriter::new(index_file); 69 | 70 | let max_chunk_len = max_chunk_len.unwrap_or(512 * 1024 * 1024); 71 | let vocab_size = vocab_size.unwrap_or(35000); 72 | 73 | Ok( 74 | Writer { 75 | index_file, 76 | buffer: Vec::with_capacity(max_chunk_len), 77 | vocab_size, 78 | } 79 | ) 80 | } 81 | 82 | fn add_entry( 83 | &mut self, 84 | py_text: &PyList, 85 | ) -> PyResult<()> { 86 | 87 | let mut text = Vec::new(); 88 | for item in py_text.iter() { 89 | let num: i32 = item.extract()?; 90 | text.push(num); 91 | } 92 | 93 | if text.len() > self.buffer.capacity() { 94 | return Err(exceptions::PyValueError::new_err("entry is too big")); 95 | } 96 | 97 | if self.buffer.len() + text.len() > self.buffer.capacity() { 98 | self.dump_data()?; 99 | } 100 | self.buffer.extend_from_slice(&text); 101 | 102 | // self.buffer.push(34999); 103 | 104 | Ok(()) 105 | } 106 | 107 | fn dump_data( 108 | &mut self, 109 | ) -> PyResult<()> { 110 | if self.buffer.is_empty() { 111 | return Ok(()); 112 | } 113 | 114 | self.index_file.write_u32::((self.buffer.len() * 2) as u32)?; 115 | 116 | for &item in &self.buffer { 117 | self.index_file.write_u16::(item as u16)?; 118 | } 119 | 120 | let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size); 121 | self.index_file.write_u32::((suffix_array.len() * 4) as u32)?; 122 | for suffix in suffix_array { 123 | self.index_file.write_i32::(suffix)?; 124 | } 125 | self.buffer.clear(); 126 | 127 | Ok(()) 128 | } 129 | 130 | fn finalize( 131 | &mut self, 132 | ) -> PyResult<()> { 133 | if !self.buffer.is_empty() { 134 | self.dump_data()?; 135 | } 136 | self.index_file.flush()?; 137 | 138 | Ok(()) 139 | } 140 | } 141 | 142 | impl Drop for Writer { 143 | fn drop( 144 | &mut self, 145 | ) { 146 | self.finalize().unwrap(); 147 | } 148 | } 149 | 150 | struct SubIndex { 151 | data: Vec, 152 | index_file: Cursor>, // BufReader, // Cursor>, 153 | suffixes_file_start: usize, 154 | suffixes_file_end: usize, 155 | } 156 | 157 | #[pyclass] 158 | struct Reader { 159 | sub_indexes: Vec, 160 | } 161 | 162 | #[pymethods] 163 | impl Reader { 164 | #[new] 165 | fn new( 166 | index_file_path: &str, 167 | ) -> PyResult { 168 | let index_file = File::open(index_file_path)?; 169 | let mut index_file = BufReader::new(index_file); 170 | let index_file_metadata = std::fs::metadata(index_file_path)?; 171 | let index_file_len = index_file_metadata.len(); 172 | let mut bytes_read = 0; 173 | 174 | let mut sub_indexes = Vec::new(); 175 | 176 | while bytes_read < index_file_len { 177 | let data_file_len = index_file.read_u32::()?; 178 | let mut data_u8 = vec![0; data_file_len as usize]; 179 | index_file.read_exact(&mut data_u8)?; 180 | 181 | let suffixes_file_len = index_file.read_u32::()? as usize; 182 | let suffixes_file_start = index_file.seek(SeekFrom::Current(0))? 
as usize; 183 | let suffixes_file_end = suffixes_file_start + suffixes_file_len; 184 | index_file.seek(SeekFrom::Current(suffixes_file_len as i64))?; 185 | 186 | bytes_read += 4 + 4 + data_file_len as u64 + suffixes_file_len as u64; 187 | 188 | 189 | let mut data: Vec = Vec::new(); 190 | 191 | for i in (0..data_u8.len()).step_by(2) { 192 | let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32; 193 | data.push(int); 194 | } 195 | 196 | sub_indexes.push( 197 | SubIndex { 198 | data, 199 | index_file: Cursor::new(fs::read(index_file_path).unwrap()), // BufReader::new(File::open(index_file_path)?), // Cursor::new(fs::read(index_file_path).unwrap()), 200 | suffixes_file_start, 201 | suffixes_file_end, 202 | } 203 | ); 204 | } 205 | 206 | Ok(Reader { sub_indexes }) 207 | } 208 | 209 | fn search( 210 | &mut self, 211 | py_substring: &PyList, 212 | k: Option, 213 | choices: Option, 214 | long: Option, 215 | ) -> PyResult<(Vec>, Vec>, Vec, Vec, Vec>)> { 216 | 217 | // substring_i32 is just a rust version of py_substring 218 | let mut substring_i32 = Vec::new(); 219 | for item in py_substring.iter() { 220 | let num: i32 = item.extract()?; 221 | substring_i32.push(num); 222 | } 223 | 224 | let results = Arc::new(Mutex::new(Vec::new())); 225 | 226 | // each sub index is a buffer/suffix pair 227 | self.sub_indexes.par_iter_mut().for_each( 228 | |sub_index| { 229 | let mut start_of_indices = None; 230 | let mut end_of_indices = None; 231 | 232 | // since suffix arrays have the suffixes in sorted order, we do a binary search 233 | // over the suffix array 234 | // this binary search finds the start of the matching suffixes 235 | let mut left_anchor = sub_index.suffixes_file_start; 236 | let mut right_anchor = sub_index.suffixes_file_end - 4; 237 | while left_anchor <= right_anchor { 238 | let middle_anchor = left_anchor + ((right_anchor - left_anchor) / 4 / 2 * 4); 239 | sub_index.index_file.seek(SeekFrom::Start(middle_anchor as u64)).unwrap(); 240 | // data_index is the value at middle_anchor in the suffix array 241 | let data_index = sub_index.index_file.read_i32::().unwrap(); 242 | // line is the actual suffix 243 | let line = &sub_index.data[(data_index) as usize..]; 244 | 245 | // we don't use the entire suffix. we look for suffixes that start with the substring we're looking for 246 | // the suffix array sorts suffixes based on the start of the suffix, so this technique is sound 247 | // the "match length" is defined by the length of substring_i32. 
the suffix array doesn't need to worry about "match length" 248 | if line.starts_with(&substring_i32) { 249 | start_of_indices = Some(middle_anchor); 250 | right_anchor = middle_anchor - 4; 251 | } else { 252 | match line.cmp(&substring_i32) { 253 | std::cmp::Ordering::Less => left_anchor = middle_anchor + 4, 254 | std::cmp::Ordering::Greater => right_anchor = middle_anchor - 4, 255 | std::cmp::Ordering::Equal => {}, 256 | }; 257 | } 258 | } 259 | if start_of_indices.is_none() { 260 | return; 261 | } 262 | 263 | // this binary search finds the end of the matching suffixes 264 | let mut right_anchor = sub_index.suffixes_file_end - 4; 265 | while left_anchor <= right_anchor { 266 | let middle_anchor = left_anchor + ((right_anchor - left_anchor) / 4 / 2 * 4); 267 | sub_index.index_file.seek(SeekFrom::Start(middle_anchor as u64)).unwrap(); 268 | let data_index = sub_index.index_file.read_i32::().unwrap(); 269 | let line = &sub_index.data[(data_index) as usize..]; 270 | if line.starts_with(&substring_i32) { 271 | end_of_indices = Some(middle_anchor); 272 | left_anchor = middle_anchor + 4; 273 | } else { 274 | match line.cmp(&substring_i32) { 275 | std::cmp::Ordering::Less => left_anchor = middle_anchor + 4, 276 | std::cmp::Ordering::Greater => right_anchor = middle_anchor - 4, 277 | std::cmp::Ordering::Equal => {}, 278 | }; 279 | } 280 | } 281 | 282 | let start_of_indices = start_of_indices.unwrap(); 283 | let end_of_indices = end_of_indices.unwrap(); 284 | 285 | let mut suffixes = vec![0; end_of_indices - start_of_indices + 4]; 286 | 287 | sub_index.index_file.seek(SeekFrom::Start(start_of_indices as u64)).unwrap(); 288 | sub_index.index_file.read_exact(&mut suffixes).unwrap(); 289 | 290 | let mut matches_ranges = AHashSet::new(); 291 | 292 | let mut cnt = 0; 293 | let k = k.unwrap_or(5000); 294 | let long = long.unwrap_or(10); 295 | let indices_size = (end_of_indices - start_of_indices + 4) / 4; 296 | let initial_capacity = std::cmp::min(indices_size, k as usize); 297 | let mut local_results = Vec::with_capacity(initial_capacity); 298 | 299 | for suffix in suffixes.chunks_mut(4) { 300 | let data_index = LittleEndian::read_i32(suffix); 301 | if matches_ranges.insert(data_index) { 302 | let sub_string_plus = &sub_index.data[data_index as usize + substring_i32.len() ..std::cmp::min(data_index as usize + substring_i32.len() + long as usize, sub_index.data.len())]; 303 | 304 | local_results.push(sub_string_plus.to_vec()); 305 | cnt += 1; 306 | if cnt >= k as usize { 307 | break; 308 | } 309 | 310 | } 311 | } 312 | 313 | results.lock().extend(local_results); 314 | } 315 | ); 316 | 317 | let results = results.lock(); 318 | 319 | if results.is_empty() { 320 | return Ok((Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new())); 321 | } 322 | 323 | let mut cnt = HashMap::new(); 324 | for retrieved_token in &*results { 325 | for j in 0..retrieved_token.len() { 326 | let tmp_token = &retrieved_token[0..=j]; 327 | let counter = cnt.entry(tmp_token).or_insert(0); 328 | *counter += 1; 329 | } 330 | } 331 | 332 | let choices = choices.unwrap_or(64); 333 | // The items in the heap must be a Trie. 
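        // `cnt` counts every prefix of every retrieved continuation; together these prefixes
        // form the paths of a Trie. The min-heap below (a BinaryHeap over `Reverse(count)`)
        // keeps roughly the `choices` most frequent prefixes: the least frequent kept prefix
        // sits at the top, so any new candidate with a higher count evicts it.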
334 | let mut heap = BinaryHeap::new(); 335 | for (k, v) in &cnt { 336 | if heap.len() < (choices as usize) { 337 | heap.push((Reverse(*v), k)); 338 | } else if let Some(&(Reverse(top_v), _)) = heap.peek() { 339 | if *v > top_v { 340 | heap.pop(); 341 | heap.push((Reverse(*v), k)); 342 | } 343 | } 344 | } 345 | let verified: Vec<_> = heap.into_iter().map(|(_, k)| k.to_vec()).collect(); 346 | // Convert into a HashSet to remove duplicates 347 | let verified: std::collections::HashSet<_> = verified.into_iter().collect(); 348 | let verified: Vec<_> = verified.into_iter().collect(); 349 | 350 | // Because multiple nodes in the Trie may have same weights around the threshold, the number of draft tokens may exceed choices 351 | // We roughly cut nodes to be less than choices in most cases. 352 | let paths = cut_to_choices(verified, choices); 353 | 354 | let (draft_choices, max_branch) = get_draft_choices(paths.clone()); 355 | 356 | if draft_choices.len() > choices as usize { 357 | // It might not be cut enough because cut_to_choices() is best effort, as mentioned in the comment above 358 | return Err(exceptions::PyValueError::new_err("draft_choices was not cut enough")); 359 | } 360 | 361 | let (draft_attn_mask, tree_indices, draft_position_ids, retrieve_indices) = generate_draft_buffers(draft_choices.clone(), max_branch); 362 | 363 | let max_length = paths.iter().map(|path| path.len()).max().unwrap_or(0); 364 | 365 | Ok((paths.into_iter().map(|path| pad_path(path, max_length, -2)).collect::>>(), draft_attn_mask, tree_indices, draft_position_ids, retrieve_indices)) 366 | } 367 | } 368 | 369 | 370 | 371 | fn cut_to_choices(paths: Vec>, choices: i32) -> Vec> { 372 | let mut count: Vec<(usize, usize)> = paths.iter() 373 | .map(|p| (p.iter().collect::>().len(), paths.iter().position(|x| x == p).unwrap())) 374 | .collect(); 375 | count.sort_by(|a, b| b.0.cmp(&a.0)); 376 | 377 | let mut total_unique = count.iter().map(|(x, _)| x).sum::(); 378 | let mut to_remove = Vec::new(); 379 | 380 | for (c, i) in count { 381 | if total_unique > choices as usize { 382 | total_unique -= c; 383 | to_remove.push(i); 384 | } else { 385 | break; 386 | } 387 | } 388 | 389 | paths.into_iter().enumerate().filter(|(i, _)| !to_remove.contains(i)).map(|(_, p)| p).collect() 390 | } 391 | 392 | 393 | fn get_draft_choices(paths: Vec>) -> (Vec>, i32) { 394 | let mut path_dict: HashMap> = HashMap::new(); 395 | let mut cnt_dict: HashMap = HashMap::new(); 396 | let max_depth = paths.iter().map(|path| path.len() as i32).max().unwrap(); 397 | 398 | for depth in 0..max_depth { 399 | cnt_dict.insert(depth, 0); 400 | } 401 | 402 | for path in &paths { 403 | for (depth, item) in path.iter().enumerate() { 404 | let depth = depth as i32; 405 | if !path_dict.contains_key(&depth) { 406 | path_dict.insert(depth, HashMap::new()); 407 | } 408 | 409 | let current_path_dict = path_dict.get_mut(&depth).unwrap(); 410 | if !current_path_dict.contains_key(item) { 411 | let current_cnt = cnt_dict.get(&depth).unwrap().clone(); 412 | current_path_dict.insert(*item, current_cnt); 413 | *cnt_dict.get_mut(&depth).unwrap() += 1; 414 | } 415 | } 416 | } 417 | 418 | let max_branch = path_dict.values().map(|v| v.len() as i32).max().unwrap(); 419 | 420 | let mut draft_choices: HashSet> = HashSet::new(); 421 | for path in paths { 422 | for (depth, _) in path.iter().enumerate() { 423 | let depth = depth as i32; 424 | let draft_choice: Vec = (0..=depth) 425 | .map(|prev_depth| { 426 | let prev_item = *path.get(prev_depth as usize).unwrap(); 427 | 
*path_dict.get(&prev_depth).unwrap().get(&prev_item).unwrap() 428 | }) 429 | .collect(); 430 | draft_choices.insert(draft_choice); 431 | } 432 | } 433 | 434 | let draft_choices: Vec> = draft_choices.into_iter().collect(); 435 | (draft_choices, max_branch) 436 | } 437 | 438 | 439 | 440 | fn pad_path(path: Vec, length: usize, pad_value: i32) -> Vec { 441 | let mut path = path; 442 | while path.len() < length { 443 | path.push(pad_value); 444 | } 445 | path 446 | } 447 | 448 | 449 | fn generate_draft_buffers(draft_choices: Vec>, topk: i32) -> (Vec>, Vec, Vec, Vec>) { 450 | 451 | // Sort the draft_choices based on their lengths and then their values 452 | let mut sorted_draft_choices = draft_choices; 453 | sorted_draft_choices.sort_by(|a, b| match a.len().cmp(&b.len()) { 454 | Ordering::Equal => a.cmp(b), 455 | other => other, 456 | }); 457 | 458 | let draft_len = sorted_draft_choices.len() + 1; 459 | assert! (draft_len <= 65, "draft_len should not exceed 65"); 460 | // Initialize depth_counts to keep track of how many choices have a particular depth 461 | let mut depth_counts:Vec = vec![0; draft_len]; 462 | let mut prev_depth = 0; 463 | for path in &sorted_draft_choices { 464 | let depth = path.len(); 465 | if depth != prev_depth { 466 | depth_counts[depth - 1] = 0; 467 | } 468 | depth_counts[depth - 1] += 1; 469 | prev_depth = depth; 470 | } 471 | // Create the attention mask for draft 472 | let mut draft_attn_mask:Vec> = vec![vec![0; draft_len]; draft_len]; 473 | for i in 0..draft_len { 474 | draft_attn_mask[i][0] = 1; 475 | draft_attn_mask[i][i] = 1; 476 | } 477 | 478 | let mut start = 0; 479 | for i in 0..depth_counts.len() { 480 | for j in 0..depth_counts[i] { 481 | let cur_draft_choice: Vec = sorted_draft_choices[(start + j) as usize].clone(); 482 | if cur_draft_choice.len() == 1 { 483 | continue; 484 | } 485 | 486 | let mut ancestor_idx = vec![]; 487 | for c in 0..(cur_draft_choice.len() - 1) { 488 | let index = sorted_draft_choices.iter().position(|x| x[..=cmp::min(c, x.len() - 1)] == cur_draft_choice[..=cmp::min(c, cur_draft_choice.len() - 1)]).unwrap() + 1; 489 | ancestor_idx.push(index); 490 | } 491 | 492 | for idx in ancestor_idx { 493 | draft_attn_mask[(j + start + 1) as usize][idx] = 1; 494 | } 495 | } 496 | start += depth_counts[i]; 497 | } 498 | 499 | // Generate tree indices for the draft structure 500 | let mut draft_tree_indices: Vec = vec![0; draft_len]; 501 | let mut start = 0; 502 | for i in 0..depth_counts.len() { 503 | for j in 0..depth_counts[i] { 504 | let cur_draft_choice = &sorted_draft_choices[(start + j) as usize]; 505 | draft_tree_indices[(start + j + 1) as usize] = cur_draft_choice.last().unwrap() + topk * (i as i32) + 1; 506 | } 507 | start += depth_counts[i]; 508 | } 509 | 510 | // Generate position IDs for the draft structure 511 | let mut draft_position_ids: Vec = vec![0; draft_len]; 512 | start = 0; 513 | for i in 0..depth_counts.len() { 514 | for j in start + 1..start + depth_counts[i] + 1 { 515 | draft_position_ids[j as usize] = (i as i32) + 1; 516 | } 517 | start += depth_counts[i]; 518 | } 519 | 520 | // Generate retrieval indices for draft structure verification 521 | let mut retrieve_indices_nest = Vec::new(); 522 | let mut retrieve_paths = Vec::new(); 523 | for i in 0..sorted_draft_choices.len() { 524 | let cur_draft_choice = sorted_draft_choices[sorted_draft_choices.len() - 1 - i].clone(); 525 | let mut retrieve_indice = Vec::new(); 526 | if retrieve_paths.contains(&cur_draft_choice) { 527 | continue; 528 | } else { 529 | for c in 
0..cur_draft_choice.len() { 530 | let index = sorted_draft_choices.iter().position(|x| *x == cur_draft_choice[0..=c]).unwrap(); 531 | retrieve_indice.push(index as i32); 532 | retrieve_paths.push(cur_draft_choice[0..=c].to_vec()); 533 | } 534 | } 535 | retrieve_indices_nest.push(retrieve_indice); 536 | } 537 | let max_length = retrieve_indices_nest.iter().map(|x| x.len()).max().unwrap(); 538 | let mut retrieve_indices: Vec> = retrieve_indices_nest.iter().map(|x| pad_path(x.clone(), max_length, -2)).collect(); 539 | 540 | for i in 0..retrieve_indices.len() { 541 | for j in 0..retrieve_indices[i].len() { 542 | retrieve_indices[i][j] += 1; 543 | } 544 | } 545 | 546 | for i in 0..retrieve_indices.len() { 547 | retrieve_indices[i].insert(0, 0); 548 | } 549 | 550 | 551 | (draft_attn_mask, draft_tree_indices, draft_position_ids, retrieve_indices) 552 | } 553 | 554 | 555 | #[pymodule] 556 | fn draftretriever( 557 | _py: Python, 558 | m: &PyModule, 559 | ) -> PyResult<()> { 560 | m.add_class::()?; 561 | m.add_class::()?; 562 | 563 | Ok(()) 564 | } 565 | 566 | -------------------------------------------------------------------------------- /DraftRetriever/src/libsais/libsais.h: -------------------------------------------------------------------------------- 1 | /*-- 2 | 3 | This file is a part of libsais, a library for linear time 4 | suffix array and burrows wheeler transform construction. 5 | 6 | Copyright (c) 2021-2022 Ilya Grebnov 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | Please see the file LICENSE for full copyright information. 21 | 22 | --*/ 23 | 24 | #ifndef LIBSAIS_H 25 | #define LIBSAIS_H 1 26 | 27 | #ifdef __cplusplus 28 | extern "C" { 29 | #endif 30 | 31 | #include 32 | 33 | /** 34 | * Creates the libsais context that allows reusing allocated memory with each libsais operation. 35 | * In multi-threaded environments, use one context per thread for parallel executions. 36 | * @return the libsais context, NULL otherwise. 37 | */ 38 | void * libsais_create_ctx(void); 39 | 40 | #if defined(_OPENMP) 41 | /** 42 | * Creates the libsais context that allows reusing allocated memory with each parallel libsais operation using OpenMP. 43 | * In multi-threaded environments, use one context per thread for parallel executions. 44 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 45 | * @return the libsais context, NULL otherwise. 46 | */ 47 | void * libsais_create_ctx_omp(int32_t threads); 48 | #endif 49 | 50 | /** 51 | * Destroys the libsass context and free previusly allocated memory. 52 | * @param ctx The libsais context (can be NULL). 53 | */ 54 | void libsais_free_ctx(void * ctx); 55 | 56 | /** 57 | * Constructs the suffix array of a given string. 58 | * @param T [0..n-1] The input string. 59 | * @param SA [0..n-1+fs] The output array of suffixes. 60 | * @param n The length of the given string. 61 | * @param fs The extra space available at the end of SA array (0 should be enough for most cases). 
62 | * @param freq [0..255] The output symbol frequency table (can be NULL). 63 | * @return 0 if no error occurred, -1 or -2 otherwise. 64 | */ 65 | int32_t libsais(const uint8_t * T, int32_t * SA, int32_t n, int32_t fs, int32_t * freq); 66 | 67 | /** 68 | * Constructs the suffix array of a given integer array. 69 | * Note, during construction input array will be modified, but restored at the end if no errors occurred. 70 | * @param T [0..n-1] The input integer array. 71 | * @param SA [0..n-1+fs] The output array of suffixes. 72 | * @param n The length of the integer array. 73 | * @param k The alphabet size of the input integer array. 74 | * @param fs Extra space available at the end of SA array (can be 0, but 4k or better 6k is recommended for optimal performance). 75 | * @return 0 if no error occurred, -1 or -2 otherwise. 76 | */ 77 | int32_t libsais_int(int32_t * T, int32_t * SA, int32_t n, int32_t k, int32_t fs); 78 | 79 | /** 80 | * Constructs the suffix array of a given string using libsais context. 81 | * @param ctx The libsais context. 82 | * @param T [0..n-1] The input string. 83 | * @param SA [0..n-1+fs] The output array of suffixes. 84 | * @param n The length of the given string. 85 | * @param fs The extra space available at the end of SA array (0 should be enough for most cases). 86 | * @param freq [0..255] The output symbol frequency table (can be NULL). 87 | * @return 0 if no error occurred, -1 or -2 otherwise. 88 | */ 89 | int32_t libsais_ctx(const void * ctx, const uint8_t * T, int32_t * SA, int32_t n, int32_t fs, int32_t * freq); 90 | 91 | #if defined(_OPENMP) 92 | /** 93 | * Constructs the suffix array of a given string in parallel using OpenMP. 94 | * @param T [0..n-1] The input string. 95 | * @param SA [0..n-1+fs] The output array of suffixes. 96 | * @param n The length of the given string. 97 | * @param fs The extra space available at the end of SA array (0 should be enough for most cases). 98 | * @param freq [0..255] The output symbol frequency table (can be NULL). 99 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 100 | * @return 0 if no error occurred, -1 or -2 otherwise. 101 | */ 102 | int32_t libsais_omp(const uint8_t * T, int32_t * SA, int32_t n, int32_t fs, int32_t * freq, int32_t threads); 103 | 104 | /** 105 | * Constructs the suffix array of a given integer array in parallel using OpenMP. 106 | * Note, during construction input array will be modified, but restored at the end if no errors occurred. 107 | * @param T [0..n-1] The input integer array. 108 | * @param SA [0..n-1+fs] The output array of suffixes. 109 | * @param n The length of the integer array. 110 | * @param k The alphabet size of the input integer array. 111 | * @param fs Extra space available at the end of SA array (can be 0, but 4k or better 6k is recommended for optimal performance). 112 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 113 | * @return 0 if no error occurred, -1 or -2 otherwise. 114 | */ 115 | int32_t libsais_int_omp(int32_t * T, int32_t * SA, int32_t n, int32_t k, int32_t fs, int32_t threads); 116 | #endif 117 | 118 | /** 119 | * Constructs the burrows-wheeler transformed string of a given string. 120 | * @param T [0..n-1] The input string. 121 | * @param U [0..n-1] The output string (can be T). 122 | * @param A [0..n-1+fs] The temporary array. 123 | * @param n The length of the given string. 124 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 
125 | * @param freq [0..255] The output symbol frequency table (can be NULL). 126 | * @return The primary index if no error occurred, -1 or -2 otherwise. 127 | */ 128 | int32_t libsais_bwt(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq); 129 | 130 | /** 131 | * Constructs the burrows-wheeler transformed string of a given string with auxiliary indexes. 132 | * @param T [0..n-1] The input string. 133 | * @param U [0..n-1] The output string (can be T). 134 | * @param A [0..n-1+fs] The temporary array. 135 | * @param n The length of the given string. 136 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 137 | * @param freq [0..255] The output symbol frequency table (can be NULL). 138 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 139 | * @param I [0..(n-1)/r] The output auxiliary indexes. 140 | * @return 0 if no error occurred, -1 or -2 otherwise. 141 | */ 142 | int32_t libsais_bwt_aux(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq, int32_t r, int32_t * I); 143 | 144 | /** 145 | * Constructs the burrows-wheeler transformed string of a given string using libsais context. 146 | * @param ctx The libsais context. 147 | * @param T [0..n-1] The input string. 148 | * @param U [0..n-1] The output string (can be T). 149 | * @param A [0..n-1+fs] The temporary array. 150 | * @param n The length of the given string. 151 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 152 | * @param freq [0..255] The output symbol frequency table (can be NULL). 153 | * @return The primary index if no error occurred, -1 or -2 otherwise. 154 | */ 155 | int32_t libsais_bwt_ctx(const void * ctx, const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq); 156 | 157 | /** 158 | * Constructs the burrows-wheeler transformed string of a given string with auxiliary indexes using libsais context. 159 | * @param ctx The libsais context. 160 | * @param T [0..n-1] The input string. 161 | * @param U [0..n-1] The output string (can be T). 162 | * @param A [0..n-1+fs] The temporary array. 163 | * @param n The length of the given string. 164 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 165 | * @param freq [0..255] The output symbol frequency table (can be NULL). 166 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 167 | * @param I [0..(n-1)/r] The output auxiliary indexes. 168 | * @return 0 if no error occurred, -1 or -2 otherwise. 169 | */ 170 | int32_t libsais_bwt_aux_ctx(const void * ctx, const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq, int32_t r, int32_t * I); 171 | 172 | #if defined(_OPENMP) 173 | /** 174 | * Constructs the burrows-wheeler transformed string of a given string in parallel using OpenMP. 175 | * @param T [0..n-1] The input string. 176 | * @param U [0..n-1] The output string (can be T). 177 | * @param A [0..n-1+fs] The temporary array. 178 | * @param n The length of the given string. 179 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 180 | * @param freq [0..255] The output symbol frequency table (can be NULL). 181 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 182 | * @return The primary index if no error occurred, -1 or -2 otherwise. 
183 | */ 184 | int32_t libsais_bwt_omp(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq, int32_t threads); 185 | 186 | /** 187 | * Constructs the burrows-wheeler transformed string of a given string with auxiliary indexes in parallel using OpenMP. 188 | * @param T [0..n-1] The input string. 189 | * @param U [0..n-1] The output string (can be T). 190 | * @param A [0..n-1+fs] The temporary array. 191 | * @param n The length of the given string. 192 | * @param fs The extra space available at the end of A array (0 should be enough for most cases). 193 | * @param freq [0..255] The output symbol frequency table (can be NULL). 194 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 195 | * @param I [0..(n-1)/r] The output auxiliary indexes. 196 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 197 | * @return 0 if no error occurred, -1 or -2 otherwise. 198 | */ 199 | int32_t libsais_bwt_aux_omp(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, int32_t fs, int32_t * freq, int32_t r, int32_t * I, int32_t threads); 200 | #endif 201 | 202 | /** 203 | * Creates the libsais reverse BWT context that allows reusing allocated memory with each libsais_unbwt_* operation. 204 | * In multi-threaded environments, use one context per thread for parallel executions. 205 | * @return the libsais context, NULL otherwise. 206 | */ 207 | void * libsais_unbwt_create_ctx(void); 208 | 209 | #if defined(_OPENMP) 210 | /** 211 | * Creates the libsais reverse BWT context that allows reusing allocated memory with each parallel libsais_unbwt_* operation using OpenMP. 212 | * In multi-threaded environments, use one context per thread for parallel executions. 213 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 214 | * @return the libsais context, NULL otherwise. 215 | */ 216 | void * libsais_unbwt_create_ctx_omp(int32_t threads); 217 | #endif 218 | 219 | /** 220 | * Destroys the libsais reverse BWT context and frees previously allocated memory. 221 | * @param ctx The libsais context (can be NULL). 222 | */ 223 | void libsais_unbwt_free_ctx(void * ctx); 224 | 225 | /** 226 | * Constructs the original string from a given burrows-wheeler transformed string with primary index. 227 | * @param T [0..n-1] The input string. 228 | * @param U [0..n-1] The output string (can be T). 229 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 230 | * @param n The length of the given string. 231 | * @param freq [0..255] The input symbol frequency table (can be NULL). 232 | * @param i The primary index. 233 | * @return 0 if no error occurred, -1 or -2 otherwise. 234 | */ 235 | int32_t libsais_unbwt(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t i); 236 | 237 | /** 238 | * Constructs the original string from a given burrows-wheeler transformed string with primary index using libsais reverse BWT context. 239 | * @param ctx The libsais reverse BWT context. 240 | * @param T [0..n-1] The input string. 241 | * @param U [0..n-1] The output string (can be T). 242 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 243 | * @param n The length of the given string. 244 | * @param freq [0..255] The input symbol frequency table (can be NULL). 245 | * @param i The primary index. 246 | * @return 0 if no error occurred, -1 or -2 otherwise.
247 | */ 248 | int32_t libsais_unbwt_ctx(const void * ctx, const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t i); 249 | 250 | /** 251 | * Constructs the original string from a given burrows-wheeler transformed string with auxiliary indexes. 252 | * @param T [0..n-1] The input string. 253 | * @param U [0..n-1] The output string (can be T). 254 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 255 | * @param n The length of the given string. 256 | * @param freq [0..255] The input symbol frequency table (can be NULL). 257 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 258 | * @param I [0..(n-1)/r] The input auxiliary indexes. 259 | * @return 0 if no error occurred, -1 or -2 otherwise. 260 | */ 261 | int32_t libsais_unbwt_aux(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t r, const int32_t * I); 262 | 263 | /** 264 | * Constructs the original string from a given burrows-wheeler transformed string with auxiliary indexes using libsais reverse BWT context. 265 | * @param ctx The libsais reverse BWT context. 266 | * @param T [0..n-1] The input string. 267 | * @param U [0..n-1] The output string (can be T). 268 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 269 | * @param n The length of the given string. 270 | * @param freq [0..255] The input symbol frequency table (can be NULL). 271 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 272 | * @param I [0..(n-1)/r] The input auxiliary indexes. 273 | * @return 0 if no error occurred, -1 or -2 otherwise. 274 | */ 275 | int32_t libsais_unbwt_aux_ctx(const void * ctx, const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t r, const int32_t * I); 276 | 277 | #if defined(_OPENMP) 278 | /** 279 | * Constructs the original string from a given burrows-wheeler transformed string with primary index in parallel using OpenMP. 280 | * @param T [0..n-1] The input string. 281 | * @param U [0..n-1] The output string (can be T). 282 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 283 | * @param n The length of the given string. 284 | * @param freq [0..255] The input symbol frequency table (can be NULL). 285 | * @param i The primary index. 286 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 287 | * @return 0 if no error occurred, -1 or -2 otherwise. 288 | */ 289 | int32_t libsais_unbwt_omp(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t i, int32_t threads); 290 | 291 | /** 292 | * Constructs the original string from a given burrows-wheeler transformed string with auxiliary indexes in parallel using OpenMP. 293 | * @param T [0..n-1] The input string. 294 | * @param U [0..n-1] The output string (can be T). 295 | * @param A [0..n] The temporary array (NOTE, temporary array must be n + 1 size). 296 | * @param n The length of the given string. 297 | * @param freq [0..255] The input symbol frequency table (can be NULL). 298 | * @param r The sampling rate for auxiliary indexes (must be power of 2). 299 | * @param I [0..(n-1)/r] The input auxiliary indexes. 300 | * @param threads The number of OpenMP threads to use (can be 0 for OpenMP default). 301 | * @return 0 if no error occurred, -1 or -2 otherwise. 
302 | */ 303 | int32_t libsais_unbwt_aux_omp(const uint8_t * T, uint8_t * U, int32_t * A, int32_t n, const int32_t * freq, int32_t r, const int32_t * I, int32_t threads); 304 | #endif 305 | 306 | #ifdef __cplusplus 307 | } 308 | #endif 309 | 310 | #endif 311 | -------------------------------------------------------------------------------- /DraftRetriever/wheels/draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/DraftRetriever/wheels/draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REST: Retrieval-Based Speculative Decoding 2 | 3 | ***If training's got you in a stew, take a REST and speed right through.*** 4 | 5 | [[Paper](https://arxiv.org/abs/2311.08252)] [[Blog](https://sites.google.com/view/rest-llm/)] 6 | 7 | ## News 8 | 🎉 2024-3-14: REST is accepted to **NAACL 2024**! 
9 | 10 | ## Introduction 11 | 12 | REST is a retrieval-based speculative decoding method designed to boost the generation speed of LLMs. Instead of relying on a draft language model as in standard speculative decoding, REST utilizes a datastore to retrieve and employ draft tokens. Moreover, REST differs from blockwise parallel decoding and Medusa in that it doesn't require extra training steps. It functions as a plug-and-play solution capable of **accelerating any pre-existing language model**. 13 | 14 |
15 | 16 | [figure: assets/rest_overview.png] 17 | 18 |
19 |
20 | Overview of REST. During inference, the input context is utilized as the query to retrieve docs from the datastore that match the longest suffix of the input. A Trie is constructed using the continuations from the retrieved docs and low-frequency branches are pruned. Candidates from the pruned subtree will be further fed into the LLM with a tree attention mask for verification. All correct tokens from the start will be accepted, and the draft tokens after the first mistake will be rejected. 21 |
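To make the retrieve-then-verify flow described in the caption concrete, below is a minimal, self-contained sketch of the idea. It is an illustration only: the helper names (`retrieve_continuations`, `draft_and_verify`, `model_next_token`) are invented for this sketch, the datastore is a plain Python list of token ids rather than the suffix-array index built by DraftRetriever, and verification runs token by token instead of scoring the whole pruned Trie in a single forward pass with a tree attention mask.

```python
from collections import Counter

def retrieve_continuations(datastore_tokens, context, max_suffix=16, span=8):
    """Return continuations that follow the longest matching suffix of `context`."""
    for s in range(min(max_suffix, len(context)), 0, -1):
        suffix = list(context[-s:])
        found = []
        for i in range(len(datastore_tokens) - s + 1):
            if list(datastore_tokens[i:i + s]) == suffix:
                cont = tuple(datastore_tokens[i + s:i + s + span])
                if cont:
                    found.append(cont)
        if found:  # the longest suffix with at least one match wins
            return found
    return []

def draft_and_verify(model_next_token, context, datastore_tokens, num_drafts=4):
    """Accept the longest retrieved draft prefix that the model itself would generate."""
    drafts = retrieve_continuations(datastore_tokens, context)
    # Keep only the most frequent continuations (a stand-in for Trie pruning).
    candidates = [list(c) for c, _ in Counter(drafts).most_common(num_drafts)]
    accepted = []
    for draft in candidates:
        ok, ctx = [], list(context)
        for tok in draft:
            if model_next_token(ctx) != tok:  # token-by-token greedy verification
                break
            ok.append(tok)
            ctx.append(tok)
        if len(ok) > len(accepted):
            accepted = ok
    # Always emit at least one token produced by the model itself.
    return accepted + [model_next_token(list(context) + accepted)]
```

In the repository itself this logic is batched: the retrieved continuations are merged into a Trie, low-frequency branches are pruned, and all surviving candidates are verified in one LLM forward pass through a draft attention mask (see the calls to `generate_candidates_and_draft_buffer`, `tree_decoding`, and `evaluate_posterior` in `human_eval/rest_test.py`).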
22 |
23 |
24 | 25 |
26 | 27 | 28 | 29 |
30 |
31 | Speed on HumanEval and MT-Bench with standard autoregressive generation and REST. The temperature is set to 0.8 and the top-p to 0.95 for nucleus sampling in HumanEval. For MT-Bench, the settings are 0.7 for temperature and 0.8 for top-p. All the experiments are conducted on a single NVIDIA A6000 GPU and 96 CPU cores with a batch size of 1. 32 |
33 |
34 |
35 | 36 | 41 | 42 | ## Contents 43 | - [Introduction](#introduction) 44 | - [Contents](#contents) 45 | - [Installation](#installation) 46 | - [Build datastores](#Build-datastore) 47 | - [Build a small one](#Build-a-small-one) 48 | - [Build a large one](#Build-a-large-one) 49 | - [Inference](#Inference) 50 | - [Inference on MT-Bench](#Inference-on-MT-Bench) 51 | - [Inference on HumanEval](#Inference-on-HumanEval) 52 | - [Free Chat](#Free-Chat) 53 | - [Citation](#citation) 54 | - [Other Models and Datastore](#other-models-and-datastore) 55 | - [Acknowledgements](#acknowledgements) 56 | 57 | ## Installation 58 | ```bash 59 | conda create -n rest python=3.9 60 | conda activate rest 61 | pip3 install -r requirements.txt # pay attention to Pytorch CUDA version 62 | pip3 install DraftRetriever/wheels/draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl 63 | ``` 64 | 65 | ## Build datastore 66 | 67 | ### Build a small one 68 | Build a chat datastore using data from [ShareGPT](https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered) within 10 minutes (requires 465MB disk storage) 69 | ```bash 70 | cd datastore 71 | python3 get_datastore_chat.py --model-path lmsys/vicuna-7b-v1.5 # get datastore_chat_small.idx in this folder 72 | ``` 73 | Build a Python code generation datastore from [The Stack](https://huggingface.co/datasets/bigcode/the-stack) within 20 minutes (requires 924MB disk storage) 74 | ```bash 75 | cd datastore 76 | python3 get_datastore_code.py --model-path codellama/CodeLlama-7b-instruct-hf # get datastore_stack_small.idx in this folder 77 | ``` 78 | 79 | ### Build a large one 80 | (optionally) Build a chat datastore using data from [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) (requires 12GB disk storage) 81 | ```bash 82 | cd datastore 83 | python3 get_datastore_chat.py --model-path lmsys/vicuna-7b-v1.5 --large-datastore True # get datastore_chat_large.idx in this folder 84 | ``` 85 | (optionally) Build a Python code generation datastore from [The Stack](https://huggingface.co/datasets/bigcode/the-stack) (requires 27GB disk storage) 86 | ```bash 87 | cd datastore 88 | python3 get_datastore_code.py --model-path codellama/CodeLlama-7b-instruct-hf --large-datastore True # get datastore_stack_large.idx in this folder 89 | ``` 90 | 91 | ## Inference 92 | 93 | ### Inference on MT-Bench 94 | ```bash 95 | cd llm_judge 96 | RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 gen_model_answer_rest.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path ../datastore/datastore_chat_small.idx 97 | ``` 98 | 99 | ### Inference on HumanEval 100 | ```bash 101 | cd human_eval 102 | RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 rest_test.py --model-path codellama/CodeLlama-7b-instruct-hf --datastore-path ../datastore/datastore_stack_small.idx 103 | ``` 104 | 105 | ### Free Chat 106 | ```bash 107 | RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 -m rest.inference.cli --datastore-path datastore/datastore_chat_small.idx --base-model lmsys/vicuna-7b-v1.5 108 | ``` 109 | 110 | Note that the RAYON_NUM_THREADS environment variable control the maximum number of threads for retrieval. You can adjust it based on your machine. 111 | 112 | 113 | ## Other Models and Datastore 114 | In the examples above, we default to use Vicuna and CodeLlama. But actually you can use any LLaMA-based models you like by simply changing the "--model-path" argument. You can also build the datastore from any data you like. 
If you want to use architectures other than LLaMA, you can also modify the file model/modeling_llama_kv.py to match the corresponding model. 115 | 116 | Note: For models with a vocab size larger than 65535 (range of u16), you may change [this line in writer](https://github.com/FasterDecoding/REST/blob/main/DraftRetriever/src/lib.rs#L117) from `self.index_file.write_u16::(item as u16)?;` to `self.index_file.write_u32::(item as u32)?;` 117 | Besides, change [these two lines in Reader](https://github.com/FasterDecoding/REST/blob/main/DraftRetriever/src/lib.rs#L191-L192) from `for i in (0..data_u8.len()).step_by(2) { let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;` to `for i in (0..data_u8.len()).step_by(4) { let int = LittleEndian::read_u32(&data_u8[i..i+4]) as i32;` (Fixed by [scandukuri](https://github.com/FasterDecoding/REST/pull/23)) 118 | 119 | ## Citation 120 | ``` 121 | @misc{he2023rest, 122 | title={REST: Retrieval-Based Speculative Decoding}, 123 | author={Zhenyu He and Zexuan Zhong and Tianle Cai and Jason D Lee and Di He}, 124 | year={2023}, 125 | eprint={2311.08252}, 126 | archivePrefix={arXiv}, 127 | primaryClass={cs.CL} 128 | } 129 | ``` 130 | 131 | ## Acknowledgements 132 | The codebase is from [Medusa](https://github.com/FasterDecoding/Medusa) and influenced by remarkable projects from the LLM community, including [FastChat](https://github.com/lm-sys/FastChat), [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/), [vllm](https://github.com/vllm-project/vllm) and many others. 133 | 134 | -------------------------------------------------------------------------------- /assets/rest_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/assets/rest_overview.png -------------------------------------------------------------------------------- /assets/rest_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/assets/rest_results.png -------------------------------------------------------------------------------- /datastore/get_datastore_chat.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoTokenizer 3 | import draftretriever 4 | from tqdm import tqdm 5 | import json 6 | 7 | import argparse 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument( 11 | "--model-path", 12 | type=str, 13 | default="lmsys/vicuna-7b-v1.5", 14 | help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", 15 | ) 16 | parser.add_argument( 17 | "--large-datastore", 18 | type=bool, 19 | default=False, 20 | help="Whether to use a large datastore", 21 | ) 22 | args = parser.parse_args() 23 | print(args) 24 | 25 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 26 | 27 | 28 | datastore_path = './datastore_chat_large.idx' if args.large_datastore else './datastore_chat_small.idx' 29 | writer = draftretriever.Writer( 30 | index_file_path=datastore_path, 31 | max_chunk_len=512*1024*1024, 32 | vocab_size=tokenizer.vocab_size + len(tokenizer.get_added_vocab()), 33 | ) 34 | if args.large_datastore: 35 | dataset = load_dataset('stingning/ultrachat', split='train') 36 | total_length = len(dataset) 37 | print("number of samples: ", total_length) 38 | for conversations in tqdm(dataset, total=total_length): 39 | for sample in conversations['data']: 40 | token_list = tokenizer.encode(sample) 41 | writer.add_entry(token_list) 42 | else: 43 | 44 | dataset_path = None 45 | assert dataset_path is not None, "please download the dataset from https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered" 46 | dataset = json.load(open(dataset_path)) 47 | total_length = len(dataset) 48 | print("number of samples: ", total_length) 49 | for conversations in tqdm(dataset, total=total_length): 50 | for sample in conversations['conversations']: 51 | token_list = tokenizer.encode(sample['value']) 52 | writer.add_entry(token_list) 53 | 54 | writer.finalize() 55 | -------------------------------------------------------------------------------- /datastore/get_datastore_code.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import AutoTokenizer 3 | import draftretriever 4 | from tqdm import tqdm 5 | import argparse 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | "--model-path", 10 | type=str, 11 | default="codellama/CodeLlama-7b-instruct-hf", 12 | help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", 13 | ) 14 | parser.add_argument( 15 | "--large-datastore", 16 | type=bool, 17 | default=False, 18 | help="Whether to use a large datastore", 19 | ) 20 | args = parser.parse_args() 21 | print(args) 22 | 23 | 24 | tokenizer = AutoTokenizer.from_pretrained(args.model_path) 25 | segment = 30 if args.large_datastore else 1 # Maximum number of segment: 144 26 | data_files = [] 27 | for i in range(segment): 28 | if i>=100: 29 | data_files.append(f"data-00{i}-of-00144.parquet") 30 | elif i >=10: 31 | data_files.append(f"data-000{i}-of-00144.parquet") 32 | else: 33 | data_files.append(f"data-0000{i}-of-00144.parquet") 34 | print("data_files:", data_files) 35 | 36 | dataset = load_dataset('bigcode/the-stack-dedup', \ 37 | data_dir='data/python', split='train', data_files=data_files) 38 | 39 | 40 | datastore_path = './datastore_stack_large.idx' if args.large_datastore else './datastore_stack_small.idx' 41 | writer = draftretriever.Writer( 42 | index_file_path=datastore_path, 43 | max_chunk_len=512 * 1024 * 1024, 44 | vocab_size=tokenizer.vocab_size + len(tokenizer.get_added_vocab()), 45 | ) 46 | 47 | total_length = len(dataset) 48 | print("number of samples: ", total_length) 49 | 50 | for sample in tqdm(dataset, total=len(dataset)): 51 | token_list = tokenizer.encode(sample['content']) 52 | writer.add_entry(token_list) 53 | 54 | writer.finalize() 55 | -------------------------------------------------------------------------------- /human_eval/HumanEval.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/human_eval/HumanEval.jsonl.gz -------------------------------------------------------------------------------- /human_eval/baseline_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import torch 5 | from contextlib import contextmanager 6 | import numpy as np 7 | from rest.model.rest_model import RestModel 8 | from rest.model.kv_cache import * 9 | from rest.model.utils import * 10 | 11 | from tqdm import tqdm 12 | import time 13 | import argparse 14 | 15 | from dataset import HumanEvalDataset 16 | 17 | def run_eval(model, tokenizer,temperature, top_p, max_new_token): 18 | avg_time_per_token_list = [] 19 | avg_time_per_token_list_micro = [] 20 | 21 | for sample in tqdm(dataset, total=len(dataset)): 22 | prompt = sample['prompt'] 23 | with torch.inference_mode(): 24 | past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model) 25 | model.past_key_values = past_key_values 26 | model.past_key_values_data = past_key_values_data 27 | model.current_length_data = current_length_data 28 | 29 | model.current_length_data.zero_() # this is for rerun 30 | 31 | 32 | new_token = 0 33 | input_ids = tokenizer([prompt]).input_ids 34 | input_len = len(input_ids[0]) 35 | input_ids = torch.as_tensor(input_ids).cuda() 36 | model.base_model.model.draft_mask = None 37 | outputs = model.base_model(input_ids, past_key_values = past_key_values, use_cache=True) 38 | new_token = 0 39 | # logits = initialize_logits( 40 | # input_ids, model, past_key_values 41 | # ) 42 | # cur_length = input_len + 1 43 | # accept_lengths_tree.append(1) 44 | 45 | torch.cuda.synchronize() 46 | start_time = time.time() 47 | for i in range(2000): 48 | # candidates, tree_candidates, draft_buffers = 
generate_candidates_and_draft_buffer( 49 | # logits, 50 | # input_ids, 51 | # datastore, 52 | # token_spans, 53 | # top_p, 54 | # temperature, 55 | # max_num_draft=num_draft, 56 | # device=model.base_model.device 57 | # ) 58 | 59 | # model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 60 | 61 | # logits, outputs = tree_decoding( 62 | # model, 63 | # tree_candidates, 64 | # past_key_values, 65 | # draft_buffers["draft_position_ids"], 66 | # input_ids, 67 | # draft_buffers["retrieve_indices"], 68 | # ) 69 | 70 | # best_candidate, accept_length = evaluate_posterior( 71 | # logits, candidates, temperature = temperature, top_p=top_p 72 | # ) 73 | # input_ids, logits, new_token = update_inference_inputs( 74 | # input_ids, 75 | # candidates, 76 | # best_candidate, 77 | # accept_length, 78 | # draft_buffers["retrieve_indices"], 79 | # outputs, 80 | # logits, 81 | # new_token, 82 | # past_key_values_data, 83 | # current_length_data, 84 | # ) 85 | if top_p > 0: 86 | assert top_p < 1, "top_p should between 0.0 and 1" 87 | next_token_logits = outputs.logits[:, -1, :] 88 | next_token_logits = next_token_logits / (temperature if temperature > 0 else 1.) 89 | filtered_logits = top_p_filtering(next_token_logits, top_p=top_p) 90 | input_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 91 | input_id = input_id.view(input_id.shape[0], 1) 92 | else: 93 | input_id = outputs.logits[:, -1:].argmax(dim=-1) 94 | 95 | outputs = model.base_model(input_id, use_cache=True, past_key_values = past_key_values) 96 | input_ids = torch.cat([input_ids, input_id], dim=-1) 97 | new_token += 1 98 | if model.tokenizer.eos_token_id in input_ids[0, input_len:] or new_token > max_new_token: 99 | break 100 | 101 | torch.cuda.synchronize() 102 | total_time = time.time() - start_time 103 | avg_time_per_token = total_time / new_token 104 | avg_time_per_token_list.append(avg_time_per_token) 105 | avg_time_per_token_list_micro.append((total_time, new_token)) 106 | 107 | 108 | print("avg_time_per_token: ", np.mean(avg_time_per_token_list)) 109 | print("avg_time_per_token_micro: ", np.sum([item[0] for item in avg_time_per_token_list_micro]) / np.sum([item[1] for item in avg_time_per_token_list_micro])) 110 | print("*"*30) 111 | print() 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument( 117 | "--model-path", 118 | type=str, 119 | default="codellama/CodeLlama-7b-instruct-hf", 120 | help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", 121 | ) 122 | parser.add_argument( 123 | "--dataset-path", 124 | type=str, 125 | default="./HumanEval.jsonl.gz", 126 | help="The path to the HumanEval dataset", 127 | ) 128 | parser.add_argument( 129 | "--max-new-token", 130 | type=int, 131 | default=512, 132 | help="The maximum number of new generated tokens.", 133 | ) 134 | parser.add_argument( 135 | "--temperature", 136 | type=float, 137 | default=0.0, 138 | help="The temperature for sampling.", 139 | ) 140 | 141 | parser.add_argument( 142 | "--top-p", 143 | type=float, 144 | default=0.0, 145 | help="The threshold for nucleus sampling.", 146 | ) 147 | 148 | args = parser.parse_args() 149 | 150 | if args.temperature == 0: 151 | args.top_p = 0 152 | 153 | print(args) 154 | 155 | model = RestModel.from_pretrained( 156 | args.model_path, 157 | torch_dtype=torch.float16, 158 | low_cpu_mem_usage=True, 159 | device_map="auto" 160 | ) 161 | 162 | tokenizer = model.get_tokenizer() 163 | 164 | dataset = HumanEvalDataset(args.dataset_path) 165 | 166 | 167 | run_eval( 168 | model, 169 | tokenizer, 170 | args.temperature, 171 | args.top_p, 172 | args.max_new_token 173 | ) -------------------------------------------------------------------------------- /human_eval/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import gzip 4 | import json 5 | 6 | class HumanEvalDataset(Dataset): 7 | def __init__(self, file_name): 8 | self.file_name = file_name 9 | self.data = [] 10 | self.load_data(file_name) 11 | 12 | def load_data(self, filename): 13 | if filename.endswith(".gz"): 14 | with open(filename, "rb") as gzfp: 15 | with gzip.open(gzfp, 'rt') as fp: 16 | for line in fp: 17 | if any(not x.isspace() for x in line): 18 | self.data.append(json.loads(line)) 19 | else: 20 | with open(filename, "r") as fp: 21 | for line in fp: 22 | if any(not x.isspace() for x in line): 23 | self.data.append(json.loads(line)) 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, idx): 29 | return self.data[idx] -------------------------------------------------------------------------------- /human_eval/rest_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("../") 4 | import torch 5 | from contextlib import contextmanager 6 | import numpy as np 7 | from rest.model.rest_model import RestModel 8 | from rest.model.kv_cache import * 9 | from rest.model.utils import * 10 | import draftretriever 11 | 12 | from tqdm import tqdm 13 | import time 14 | import argparse 15 | 16 | from dataset import HumanEvalDataset 17 | 18 | def run_eval(model, tokenizer, datastore, max_token_span, num_draft, temperature, top_p, max_new_token): 19 | accept_lengths_tree_average = [] 20 | avg_time_per_token_list = [] 21 | 22 | accept_lengths_tree_average_micro = [] 23 | avg_time_per_token_list_micro = [] 24 | token_spans = list(range(2, max_token_span + 1))[::-1] 25 | print("token_spans: ", token_spans) 26 | 27 | for sample in tqdm(dataset, total=len(dataset)): 28 | prompt = sample['prompt'] 29 | 30 | accept_lengths_tree = [] 31 | with torch.inference_mode(): 32 | 33 | # Initialize the past key and value states 34 | if hasattr(model, "past_key_values"): 35 | past_key_values = model.past_key_values 36 | past_key_values_data = model.past_key_values_data 37 | current_length_data = model.current_length_data 38 | # Reset the past key 
and value states 39 | current_length_data.zero_() 40 | else: 41 | ( 42 | past_key_values, 43 | past_key_values_data, 44 | current_length_data, 45 | ) = initialize_past_key_values(model.base_model) 46 | model.past_key_values = past_key_values 47 | model.past_key_values_data = past_key_values_data 48 | model.current_length_data = current_length_data 49 | 50 | 51 | new_token = 0 52 | input_ids = tokenizer([prompt]).input_ids 53 | input_len = len(input_ids[0]) 54 | input_ids = torch.as_tensor(input_ids).cuda() 55 | model.base_model.model.draft_mask = None 56 | logits = initialize_logits( 57 | input_ids, model, past_key_values 58 | ) 59 | cur_length = input_len + 1 60 | accept_lengths_tree.append(1) 61 | 62 | torch.cuda.synchronize() 63 | start_time = time.time() 64 | for i in range(2000): 65 | candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer( 66 | logits, 67 | input_ids, 68 | datastore, 69 | token_spans, 70 | top_p, 71 | temperature, 72 | max_num_draft=num_draft, 73 | device=model.base_model.device 74 | ) 75 | 76 | model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 77 | 78 | logits, outputs = tree_decoding( 79 | model, 80 | tree_candidates, 81 | past_key_values, 82 | draft_buffers["draft_position_ids"], 83 | input_ids, 84 | draft_buffers["retrieve_indices"], 85 | ) 86 | 87 | best_candidate, accept_length = evaluate_posterior( 88 | logits, candidates, temperature = temperature, top_p=top_p 89 | ) 90 | input_ids, logits, new_token = update_inference_inputs( 91 | input_ids, 92 | candidates, 93 | best_candidate, 94 | accept_length, 95 | draft_buffers["retrieve_indices"], 96 | outputs, 97 | logits, 98 | new_token, 99 | past_key_values_data, 100 | current_length_data, 101 | ) 102 | 103 | accept_length_tree = input_ids.shape[1] - cur_length 104 | cur_length = accept_length_tree + cur_length 105 | accept_lengths_tree.append(accept_length_tree) 106 | if model.tokenizer.eos_token_id in input_ids[0, input_len:] or new_token > max_new_token: 107 | break 108 | 109 | torch.cuda.synchronize() 110 | total_time = time.time() - start_time 111 | avg_time_per_token = total_time / (new_token.cpu()) 112 | avg_time_per_token_list.append(avg_time_per_token) 113 | avg_time_per_token_list_micro.append((total_time, new_token.cpu())) 114 | 115 | accept_lengths_tree_average.append(np.mean(accept_lengths_tree)) 116 | accept_lengths_tree_average_micro.extend(accept_lengths_tree) 117 | 118 | print("accept_lengths_tree_average: ", np.mean(accept_lengths_tree_average)) 119 | print("accept_lengths_tree_average_micro: ", np.mean(accept_lengths_tree_average_micro)) 120 | print("avg_time_per_token: ", np.mean(avg_time_per_token_list)) 121 | print("avg_time_per_token_micro: ", np.sum([item[0] for item in avg_time_per_token_list_micro]) / np.sum([item[1] for item in avg_time_per_token_list_micro])) 122 | print("*"*30) 123 | print() 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument( 129 | "--model-path", 130 | type=str, 131 | default="codellama/CodeLlama-7b-instruct-hf", 132 | help="The path to the weights. 
This can be a local folder or a Hugging Face repo ID.", 133 | ) 134 | parser.add_argument( 135 | "--dataset-path", 136 | type=str, 137 | default="./HumanEval.jsonl.gz", 138 | help="The path to the HumanEval dataset", 139 | ) 140 | parser.add_argument( 141 | "--max-new-token", 142 | type=int, 143 | default=512, 144 | help="The maximum number of new generated tokens.", 145 | ) 146 | parser.add_argument( 147 | "--temperature", 148 | type=float, 149 | default=0.0, 150 | help="The temperature for sampling.", 151 | ) 152 | 153 | parser.add_argument( 154 | "--top-p", 155 | type=float, 156 | default=0.0, 157 | help="The threshold for nucleus sampling.", 158 | ) 159 | 160 | # REST's hyperparameters 161 | parser.add_argument( 162 | "--datastore-path", 163 | type=str, 164 | required=True, 165 | help="The path of the datastore for retrival.", 166 | ) 167 | 168 | parser.add_argument( 169 | "--num-draft", 170 | type=int, 171 | default=64, 172 | help="The number of draft tokens.", 173 | ) 174 | parser.add_argument( 175 | "--max-token-span", 176 | type=int, 177 | default=16, 178 | help="The maximum length of suffix for retrieval.", 179 | ) 180 | 181 | args = parser.parse_args() 182 | 183 | if args.temperature == 0: 184 | args.top_p = 0 185 | 186 | print(args) 187 | 188 | model = RestModel.from_pretrained( 189 | args.model_path, 190 | torch_dtype=torch.float16, 191 | low_cpu_mem_usage=True, 192 | device_map="auto" 193 | ) 194 | 195 | tokenizer = model.get_tokenizer() 196 | 197 | dataset = HumanEvalDataset(args.dataset_path) 198 | 199 | print("loading the datastore ...") 200 | datastore = draftretriever.Reader( 201 | index_file_path=args.datastore_path, 202 | ) 203 | print("datastore loaded!") 204 | 205 | run_eval( 206 | model, 207 | tokenizer, 208 | datastore, 209 | args.max_token_span, 210 | args.num_draft, 211 | args.temperature, 212 | args.top_p, 213 | args.max_new_token 214 | ) -------------------------------------------------------------------------------- /human_eval/results/baseline_test.txt: -------------------------------------------------------------------------------- 1 | avg_time_per_token: 0.027981323390069406 2 | avg_time_per_token_micro: 0.027969173306964402 3 | ****************************** -------------------------------------------------------------------------------- /human_eval/results/rest_test.txt: -------------------------------------------------------------------------------- 1 | loading the datastore ... 2 | datastore loaded! 3 | token_spans: [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2] 4 | accept_lengths_tree_average: 2.978881842847655 5 | accept_lengths_tree_average_micro: 2.6481571486431754 6 | avg_time_per_token: 0.011493277 7 | avg_time_per_token_micro: 0.01133284168328882 -------------------------------------------------------------------------------- /llm_judge/data/judge_prompts.jsonl: -------------------------------------------------------------------------------- 1 | {"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. 
Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} 2 | {"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} 3 | {"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. 
After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"} 4 | {"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} 5 | {"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} 6 | {"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} 7 | {"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} 8 | {"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} 9 | -------------------------------------------------------------------------------- /llm_judge/gen_model_answer_baseline.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import random 10 | import time 11 | import shortuuid 12 | import torch 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | from fastchat.llm_judge.common import load_questions 17 | from fastchat.model import load_model, get_conversation_template 18 | 19 | # Rest imports 20 | import transformers 21 | 22 | import sys 23 | sys.path.append("../") 24 | 25 | from rest.model.utils import * 26 | from rest.model.rest_model import RestModel 27 | from rest.model.kv_cache import initialize_past_key_values 28 | 29 | 30 | def baseline_forward(input_ids, model, tokenizer, max_new_token, temperature, top_p, max_steps=1024): 31 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
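# baseline_forward runs plain autoregressive decoding and is the reference that REST
# is compared against: it reuses the pre-allocated KV cache (past_key_values) across
# steps, feeds exactly one token per forward pass after the prompt, and stops on EOS
# or once max_new_token is exceeded. The large commented-out block inside the loop
# below is the REST retrieval/verification path, deliberately disabled here.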
32 | # Avoid modifying the input_ids in-place 33 | input_ids = input_ids.clone() 34 | accept_length_list = [] 35 | 36 | # Initialize the past key and value states 37 | if hasattr(model, "past_key_values"): 38 | past_key_values = model.past_key_values 39 | past_key_values_data = model.past_key_values_data 40 | current_length_data = model.current_length_data 41 | # Reset the past key and value states 42 | current_length_data.zero_() 43 | else: 44 | ( 45 | past_key_values, 46 | past_key_values_data, 47 | current_length_data, 48 | ) = initialize_past_key_values(model.base_model) 49 | model.past_key_values = past_key_values 50 | model.past_key_values_data = past_key_values_data 51 | model.current_length_data = current_length_data 52 | 53 | input_len = input_ids.shape[1] 54 | model.base_model.model.draft_mask = None 55 | outputs = model.base_model(input_ids, past_key_values = past_key_values, use_cache=True) 56 | new_token = 0 57 | 58 | torch.cuda.synchronize() 59 | start_time = time.time() 60 | for idx in range(max_steps): 61 | # candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer( 62 | # logits, 63 | # input_ids, 64 | # datastore, 65 | # token_spans, 66 | # top_p, 67 | # temperature, 68 | # max_num_draft=num_draft, 69 | # device=model.base_model.device 70 | # ) 71 | # model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 72 | # logits, outputs = tree_decoding( 73 | # model, 74 | # tree_candidates, 75 | # past_key_values, 76 | # draft_buffers["draft_position_ids"], 77 | # input_ids, 78 | # draft_buffers["retrieve_indices"], 79 | # ) 80 | # best_candidate, accept_length = evaluate_posterior( 81 | # logits, candidates, temperature, top_p 82 | # ) 83 | # input_ids, logits, new_token = update_inference_inputs( 84 | # input_ids, 85 | # candidates, 86 | # best_candidate, 87 | # accept_length, 88 | # draft_buffers["retrieve_indices"], 89 | # outputs, 90 | # logits, 91 | # new_token, 92 | # past_key_values_data, 93 | # current_length_data, 94 | # ) 95 | # accept_length_tree = input_ids.shape[1] - cur_length 96 | if top_p > 0: 97 | assert top_p < 1, "top_p should between 0.0 and 1" 98 | next_token_logits = outputs.logits[:, -1, :] 99 | next_token_logits = next_token_logits / (temperature if temperature > 0 else 1.) 
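# Nucleus (top-p) sampling: top_p_filtering masks all but the smallest set of
# top-probability tokens whose cumulative probability reaches top_p; the next token
# is then drawn from the renormalized distribution via torch.multinomial below.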
100 | filtered_logits = top_p_filtering(next_token_logits, top_p=top_p) 101 | input_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 102 | input_id = input_id.view(input_id.shape[0], 1) 103 | else: 104 | input_id = outputs.logits[:, -1:].argmax(dim=-1) 105 | outputs = model.base_model(input_id, use_cache=True, past_key_values = past_key_values) 106 | input_ids = torch.cat([input_ids, input_id], dim=-1) 107 | new_token += 1 108 | accept_length_list.append(1) 109 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): 110 | break 111 | if new_token > max_new_token: 112 | break 113 | return input_ids, new_token, idx, accept_length_list, start_time 114 | 115 | def run_eval( 116 | model_path, 117 | model_id, 118 | question_file, 119 | question_begin, 120 | question_end, 121 | answer_file, 122 | max_new_token, 123 | num_choices, 124 | num_gpus_per_model, 125 | num_gpus_total, 126 | max_gpu_memory, 127 | temperature, 128 | top_p, 129 | ): 130 | questions = load_questions(question_file, question_begin, question_end) 131 | # random shuffle the questions to balance the loading 132 | # random.shuffle(questions) 133 | shuffled_ids = [q["question_id"] for q in questions] 134 | # with open(f"data/{args.bench_name}/model_ids/{args.model_id}.shuffled_ids", "w") as fout: 135 | # json.dump(shuffled_ids, fout) 136 | 137 | # Split the question file into `num_gpus` files 138 | assert num_gpus_total % num_gpus_per_model == 0 139 | use_ray = num_gpus_total // num_gpus_per_model > 1 140 | 141 | if use_ray: 142 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( 143 | get_model_answers 144 | ).remote 145 | else: 146 | get_answers_func = get_model_answers 147 | 148 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) # // 2 149 | ans_handles = [] 150 | for i in range(0, len(questions), chunk_size): 151 | ans_handles.append( 152 | get_answers_func( 153 | model_path, 154 | model_id, 155 | questions[i : i + chunk_size], 156 | answer_file, 157 | max_new_token, 158 | num_choices, 159 | num_gpus_per_model, 160 | max_gpu_memory, 161 | temperature, 162 | top_p, 163 | ) 164 | ) 165 | 166 | if use_ray: 167 | ray.get(ans_handles) 168 | 169 | 170 | @torch.inference_mode() 171 | def get_model_answers( 172 | model_path, 173 | model_id, 174 | questions, 175 | answer_file, 176 | max_new_token, 177 | num_choices, 178 | num_gpus_per_model, 179 | max_gpu_memory, 180 | temperature, 181 | top_p, 182 | ): 183 | 184 | model = RestModel.from_pretrained( 185 | model_path, 186 | torch_dtype=torch.float16, 187 | low_cpu_mem_usage=True, 188 | device_map="auto" 189 | ) 190 | 191 | tokenizer = model.get_tokenizer() 192 | 193 | model.eval() 194 | print('Check model training state:',model.training) 195 | 196 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') 197 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices) 198 | 199 | question = questions[0] 200 | 201 | # warmup 202 | for _ in range(3): 203 | torch.manual_seed(0) 204 | conv = get_conversation_template(model_id) 205 | turns = [] 206 | idxs = [] 207 | new_tokens = [] 208 | wall_time = [] 209 | for j in range(len(question["turns"])): 210 | qs = question["turns"][j] 211 | conv.append_message(conv.roles[0], qs) 212 | conv.append_message(conv.roles[1], None) 213 | prompt = conv.get_prompt() 214 | input_ids = tokenizer([prompt]).input_ids 215 | 216 | # if temperature < 1e-4: 217 | # do_sample = False 218 | # else: 219 | # do_sample = True 220 | 221 | # some models may error out when generating long outputs 222 | try: 
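                # Timing note: baseline_forward starts its clock only after the
                # prompt has been prefilled, so total_time below measures the
                # decoding loop rather than prompt processing.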
223 | output_ids, new_token, idx, _, start_time = baseline_forward( 224 | torch.as_tensor(input_ids).cuda(), 225 | model, 226 | tokenizer, 227 | max_new_token, 228 | temperature, 229 | top_p, 230 | ) 231 | torch.cuda.synchronize() 232 | total_time = time.time() - start_time 233 | output_ids = output_ids[0][len(input_ids[0]) :] 234 | # be consistent with the template's stop_token_ids 235 | if conv.stop_token_ids: 236 | stop_token_ids_index = [ 237 | i 238 | for i, id in enumerate(output_ids) 239 | if id in conv.stop_token_ids 240 | ] 241 | if len(stop_token_ids_index) > 0: 242 | output_ids = output_ids[: stop_token_ids_index[0]] 243 | 244 | output = tokenizer.decode( 245 | output_ids, 246 | spaces_between_special_tokens=False, 247 | ) 248 | if conv.stop_str and output.find(conv.stop_str) > 0: 249 | output = output[: output.find(conv.stop_str)] 250 | 251 | if conv.name == "xgen" and output.startswith("Assistant:"): 252 | output = output.replace("Assistant:", "", 1).strip() 253 | except RuntimeError as e: 254 | print("ERROR question ID: ", question["question_id"]) 255 | output = "ERROR" 256 | 257 | turns.append(output) 258 | idxs.append(int(idx)) 259 | new_tokens.append(int(new_token)) 260 | wall_time.append(total_time) 261 | conv.messages[-1][-1] = output 262 | print('Warmup done') 263 | 264 | accept_lengths_tree = [] 265 | for question in tqdm(questions): 266 | # if question["category"] in temperature_config: 267 | # temperature = temperature_config[question["category"]] 268 | # else: 269 | # temperature = 0.7 270 | choices = [] 271 | for i in range(num_choices): 272 | accept_lengths_tree_this = [] 273 | torch.manual_seed(i) 274 | conv = get_conversation_template(model_id) 275 | turns = [] 276 | idxs = [] 277 | new_tokens = [] 278 | wall_time = [] 279 | for j in range(len(question["turns"])): 280 | qs = question["turns"][j] 281 | conv.append_message(conv.roles[0], qs) 282 | conv.append_message(conv.roles[1], None) 283 | prompt = conv.get_prompt() 284 | input_ids = tokenizer([prompt]).input_ids 285 | 286 | # if temperature < 1e-4: 287 | # do_sample = False 288 | # else: 289 | # do_sample = True 290 | 291 | # some models may error out when generating long outputs 292 | try: 293 | 294 | output_ids, new_token, idx, accept_length_tree, start_time = baseline_forward( 295 | torch.as_tensor(input_ids).cuda(), 296 | model, 297 | tokenizer, 298 | max_new_token, 299 | temperature, 300 | top_p, 301 | ) 302 | torch.cuda.synchronize() 303 | total_time = time.time() - start_time 304 | accept_lengths_tree.extend(accept_length_tree) 305 | # if model.config.is_encoder_decoder: 306 | # output_ids = output_ids[0] 307 | # else: 308 | output_ids = output_ids[0][len(input_ids[0]) :] 309 | 310 | # be consistent with the template's stop_token_ids 311 | if conv.stop_token_ids: 312 | stop_token_ids_index = [ 313 | i 314 | for i, id in enumerate(output_ids) 315 | if id in conv.stop_token_ids 316 | ] 317 | if len(stop_token_ids_index) > 0: 318 | output_ids = output_ids[: stop_token_ids_index[0]] 319 | 320 | output = tokenizer.decode( 321 | output_ids, 322 | spaces_between_special_tokens=False, 323 | ) 324 | if conv.stop_str and output.find(conv.stop_str) > 0: 325 | output = output[: output.find(conv.stop_str)] 326 | # for special_token in tokenizer.special_tokens_map.values(): 327 | # if isinstance(special_token, list): 328 | # for special_tok in special_token: 329 | # output = output.replace(special_tok, "") 330 | # else: 331 | # output = output.replace(special_token, "") 332 | # output = output.strip() 333 | 334 | if 
conv.name == "xgen" and output.startswith("Assistant:"): 335 | output = output.replace("Assistant:", "", 1).strip() 336 | except RuntimeError as e: 337 | print("ERROR question ID: ", question["question_id"]) 338 | output = "ERROR" 339 | 340 | turns.append(output) 341 | idxs.append(int(idx)) 342 | new_tokens.append(int(new_token)) 343 | wall_time.append(total_time) 344 | accept_lengths_tree_this.extend(accept_length_tree) 345 | conv.messages[-1][-1] = output 346 | # torch.cuda.empty_cache() 347 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time, "accept_lengths:": accept_lengths_tree_this}) 348 | 349 | # Dump answers 350 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 351 | with open(os.path.expanduser(answer_file), "a") as fout: 352 | ans_json = { 353 | "category": question["category"], 354 | "question_id": question["question_id"], 355 | "answer_id": shortuuid.uuid(), 356 | "model_id": model_id, 357 | "choices": choices, 358 | "tstamp": time.time(), 359 | } 360 | fout.write(json.dumps(ans_json) + "\n") 361 | print("accept_lengths_tree: ", np.mean(accept_lengths_tree)) 362 | 363 | 364 | def reorg_answer_file(answer_file): 365 | """Sort by question id and de-duplication""" 366 | answers = {} 367 | with open(answer_file, "r") as fin: 368 | for l in fin: 369 | qid = json.loads(l)["question_id"] 370 | answers[qid] = l 371 | 372 | qids = sorted(list(answers.keys())) 373 | with open(answer_file, "w") as fout: 374 | for qid in qids: 375 | fout.write(answers[qid]) 376 | 377 | 378 | if __name__ == "__main__": 379 | parser = argparse.ArgumentParser() 380 | parser.add_argument( 381 | "--model-path", 382 | type=str, 383 | required=True, 384 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 385 | ) 386 | parser.add_argument("--model-id", type=str, required=True) 387 | parser.add_argument( 388 | "--bench-name", 389 | type=str, 390 | default="mt_bench", 391 | help="The name of the benchmark question set.", 392 | ) 393 | parser.add_argument( 394 | "--question-begin", 395 | type=int, 396 | help="A debug option. The begin index of questions.", 397 | ) 398 | parser.add_argument( 399 | "--question-end", type=int, help="A debug option. The end index of questions." 400 | ) 401 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 402 | parser.add_argument( 403 | "--max-new-token", 404 | type=int, 405 | default=1024, 406 | help="The maximum number of new generated tokens.", 407 | ) 408 | parser.add_argument( 409 | "--num-choices", 410 | type=int, 411 | default=1, 412 | help="How many completion choices to generate.", 413 | ) 414 | parser.add_argument( 415 | "--num-gpus-per-model", 416 | type=int, 417 | default=1, 418 | help="The number of GPUs per model.", 419 | ) 420 | parser.add_argument( 421 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 
422 |     )
423 |     parser.add_argument(
424 |         "--max-gpu-memory",
425 |         type=str,
426 |         help="Maximum GPU memory used for model weights per GPU.",
427 |     )
428 | 
429 |     parser.add_argument(
430 |         "--temperature",
431 |         type=float,
432 |         default=0.0,
433 |         help="The temperature for sampling.",
434 |     )
435 | 
436 |     parser.add_argument(
437 |         "--top-p",
438 |         type=float,
439 |         default=0.0,
440 |         help="The threshold for nucleus sampling.",
441 |     )
442 | 
443 |     args = parser.parse_args()
444 | 
445 |     if args.temperature == 0:
446 |         args.top_p = 0
447 | 
448 | 
449 |     args.model_id = "baseline-" + args.model_id+"-temperature-"+str(args.temperature)+"-top_p-"+str(args.top_p)
450 |     if args.num_gpus_total // args.num_gpus_per_model > 1:
451 |         import ray
452 |         ray.init()
453 | 
454 |     question_file = f"data/{args.bench_name}/question.jsonl"
455 |     if args.answer_file:
456 |         answer_file = args.answer_file
457 |     else:
458 |         answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
459 | 
460 |     print(f"Output to {answer_file}")
461 | 
462 |     run_eval(
463 |         args.model_path,
464 |         args.model_id,
465 |         question_file,
466 |         args.question_begin,
467 |         args.question_end,
468 |         answer_file,
469 |         args.max_new_token,
470 |         args.num_choices,
471 |         args.num_gpus_per_model,
472 |         args.num_gpus_total,
473 |         args.max_gpu_memory,
474 |         args.temperature,
475 |         args.top_p,
476 |     )
477 | 
478 |     reorg_answer_file(answer_file)
--------------------------------------------------------------------------------
/llm_judge/gen_model_answer_rest.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 | 
3 | Usage:
4 | python3 gen_model_answer_rest.py --temperature 0.0 --top-p 0.8 --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path <>
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import random
10 | import time
11 | import shortuuid
12 | import torch
13 | import numpy as np
14 | from tqdm import tqdm
15 | 
16 | from fastchat.llm_judge.common import load_questions
17 | from fastchat.model import load_model, get_conversation_template
18 | 
19 | # Rest imports
20 | import transformers
21 | 
22 | import sys
23 | sys.path.append("../")
24 | 
25 | from rest.model.utils import *
26 | from rest.model.rest_model import RestModel
27 | from rest.model.kv_cache import initialize_past_key_values
28 | import draftretriever
29 | 
30 | 
31 | def rest_forward(input_ids, model, tokenizer, max_new_token, temperature, top_p, datastore, num_draft, token_spans, max_steps=1024):
32 |     assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
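    # Each REST step below: (1) query the draftretriever datastore with the last
    # few generated tokens (token_spans lists the suffix lengths to try), (2) lay
    # the retrieved continuations out as a token tree and verify every branch in
    # one forward pass using the draft attention mask, (3) keep the longest
    # branch prefix that matches what the model itself would have produced, so
    # several tokens can be accepted per iteration.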
33 | # Avoid modifying the input_ids in-place 34 | input_ids = input_ids.clone() 35 | accept_length_list = [] 36 | 37 | # Initialize the past key and value states 38 | if hasattr(model, "past_key_values"): 39 | past_key_values = model.past_key_values 40 | past_key_values_data = model.past_key_values_data 41 | current_length_data = model.current_length_data 42 | # Reset the past key and value states 43 | current_length_data.zero_() 44 | else: 45 | ( 46 | past_key_values, 47 | past_key_values_data, 48 | current_length_data, 49 | ) = initialize_past_key_values(model.base_model) 50 | model.past_key_values = past_key_values 51 | model.past_key_values_data = past_key_values_data 52 | model.current_length_data = current_length_data 53 | 54 | input_len = input_ids.shape[1] 55 | cur_length = input_len + 1 56 | model.base_model.model.draft_mask = None 57 | logits = initialize_logits( 58 | input_ids, model, past_key_values 59 | ) 60 | new_token = 0 61 | 62 | torch.cuda.synchronize() 63 | start_time = time.time() 64 | for idx in range(max_steps): 65 | candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer( 66 | logits, 67 | input_ids, 68 | datastore, 69 | token_spans, 70 | top_p, 71 | temperature, 72 | max_num_draft=num_draft, 73 | device=model.base_model.device 74 | ) 75 | model.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 76 | logits, outputs = tree_decoding( 77 | model, 78 | tree_candidates, 79 | past_key_values, 80 | draft_buffers["draft_position_ids"], 81 | input_ids, 82 | draft_buffers["retrieve_indices"], 83 | ) 84 | best_candidate, accept_length = evaluate_posterior( 85 | logits, candidates, temperature, top_p 86 | ) 87 | input_ids, logits, new_token = update_inference_inputs( 88 | input_ids, 89 | candidates, 90 | best_candidate, 91 | accept_length, 92 | draft_buffers["retrieve_indices"], 93 | outputs, 94 | logits, 95 | new_token, 96 | past_key_values_data, 97 | current_length_data, 98 | ) 99 | accept_length_tree = input_ids.shape[1] - cur_length 100 | cur_length = accept_length_tree + cur_length 101 | accept_length_list.append(accept_length_tree) 102 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): 103 | break 104 | if new_token > max_new_token: 105 | break 106 | return input_ids, new_token, idx, accept_length_list, start_time 107 | 108 | def run_eval( 109 | model_path, 110 | model_id, 111 | question_file, 112 | question_begin, 113 | question_end, 114 | answer_file, 115 | max_new_token, 116 | num_choices, 117 | num_gpus_per_model, 118 | num_gpus_total, 119 | max_gpu_memory, 120 | temperature, 121 | top_p, 122 | datastore_path, 123 | num_draft, 124 | max_token_span, 125 | ): 126 | questions = load_questions(question_file, question_begin, question_end) 127 | # random shuffle the questions to balance the loading 128 | # random.shuffle(questions) 129 | shuffled_ids = [q["question_id"] for q in questions] 130 | # with open(f"data/{args.bench_name}/model_ids/{args.model_id}.shuffled_ids", "w") as fout: 131 | # json.dump(shuffled_ids, fout) 132 | 133 | token_spans = list(range(2, max_token_span+1))[::-1] 134 | print("loading the datastore ...") 135 | datastore = draftretriever.Reader( 136 | index_file_path=datastore_path, 137 | ) 138 | print("datastore loaded!") 139 | # Split the question file into `num_gpus` files 140 | assert num_gpus_total % num_gpus_per_model == 0 141 | use_ray = num_gpus_total // num_gpus_per_model > 1 142 | 143 | if use_ray: 144 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( 145 | get_model_answers 
146 | ).remote 147 | else: 148 | get_answers_func = get_model_answers 149 | 150 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) # // 2 151 | ans_handles = [] 152 | for i in range(0, len(questions), chunk_size): 153 | ans_handles.append( 154 | get_answers_func( 155 | model_path, 156 | model_id, 157 | questions[i : i + chunk_size], 158 | answer_file, 159 | max_new_token, 160 | num_choices, 161 | num_gpus_per_model, 162 | max_gpu_memory, 163 | temperature, 164 | top_p, 165 | datastore, 166 | num_draft, 167 | token_spans, 168 | ) 169 | ) 170 | 171 | if use_ray: 172 | ray.get(ans_handles) 173 | 174 | 175 | @torch.inference_mode() 176 | def get_model_answers( 177 | model_path, 178 | model_id, 179 | questions, 180 | answer_file, 181 | max_new_token, 182 | num_choices, 183 | num_gpus_per_model, 184 | max_gpu_memory, 185 | temperature, 186 | top_p, 187 | datastore, 188 | num_draft, 189 | token_spans, 190 | ): 191 | 192 | model = RestModel.from_pretrained( 193 | model_path, 194 | torch_dtype=torch.float16, 195 | low_cpu_mem_usage=True, 196 | device_map="auto" 197 | ) 198 | 199 | tokenizer = model.get_tokenizer() 200 | 201 | model.eval() 202 | print('Check model training state:',model.training) 203 | 204 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') 205 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices) 206 | 207 | question = questions[0] 208 | 209 | # warmup 210 | for _ in range(3): 211 | torch.manual_seed(0) 212 | conv = get_conversation_template(model_id) 213 | turns = [] 214 | idxs = [] 215 | new_tokens = [] 216 | wall_time = [] 217 | for j in range(len(question["turns"])): 218 | qs = question["turns"][j] 219 | conv.append_message(conv.roles[0], qs) 220 | conv.append_message(conv.roles[1], None) 221 | prompt = conv.get_prompt() 222 | input_ids = tokenizer([prompt]).input_ids 223 | 224 | # if temperature < 1e-4: 225 | # do_sample = False 226 | # else: 227 | # do_sample = True 228 | 229 | # some models may error out when generating long outputs 230 | try: 231 | output_ids, new_token, idx, _, start_time = rest_forward( 232 | torch.as_tensor(input_ids).cuda(), 233 | model, 234 | tokenizer, 235 | max_new_token, 236 | temperature, 237 | top_p, 238 | datastore, 239 | num_draft, 240 | token_spans, 241 | ) 242 | torch.cuda.synchronize() 243 | total_time = time.time() - start_time 244 | output_ids = output_ids[0][len(input_ids[0]) :] 245 | # be consistent with the template's stop_token_ids 246 | if conv.stop_token_ids: 247 | stop_token_ids_index = [ 248 | i 249 | for i, id in enumerate(output_ids) 250 | if id in conv.stop_token_ids 251 | ] 252 | if len(stop_token_ids_index) > 0: 253 | output_ids = output_ids[: stop_token_ids_index[0]] 254 | 255 | output = tokenizer.decode( 256 | output_ids, 257 | spaces_between_special_tokens=False, 258 | ) 259 | if conv.stop_str and output.find(conv.stop_str) > 0: 260 | output = output[: output.find(conv.stop_str)] 261 | # for special_token in tokenizer.special_tokens_map.values(): 262 | # if isinstance(special_token, list): 263 | # for special_tok in special_token: 264 | # output = output.replace(special_tok, "") 265 | # else: 266 | # output = output.replace(special_token, "") 267 | # output = output.strip() 268 | 269 | if conv.name == "xgen" and output.startswith("Assistant:"): 270 | output = output.replace("Assistant:", "", 1).strip() 271 | except RuntimeError as e: 272 | print(f"question ID {question['question_id']} errored out with {e}") 273 | output = "ERROR" 274 | 275 | turns.append(output) 276 | idxs.append(int(idx)) 
277 | new_tokens.append(int(new_token)) 278 | wall_time.append(total_time) 279 | conv.messages[-1][-1] = output 280 | print('Warmup done') 281 | 282 | accept_lengths_tree = [] 283 | for question in tqdm(questions): 284 | # if question["category"] in temperature_config: 285 | # temperature = temperature_config[question["category"]] 286 | # else: 287 | # temperature = 0.7 288 | choices = [] 289 | for i in range(num_choices): 290 | accept_lengths_tree_this = [] 291 | torch.manual_seed(i) 292 | conv = get_conversation_template(model_id) 293 | turns = [] 294 | idxs = [] 295 | new_tokens = [] 296 | wall_time = [] 297 | for j in range(len(question["turns"])): 298 | qs = question["turns"][j] 299 | conv.append_message(conv.roles[0], qs) 300 | conv.append_message(conv.roles[1], None) 301 | prompt = conv.get_prompt() 302 | input_ids = tokenizer([prompt]).input_ids 303 | 304 | # if temperature < 1e-4: 305 | # do_sample = False 306 | # else: 307 | # do_sample = True 308 | 309 | # some models may error out when generating long outputs 310 | try: 311 | 312 | output_ids, new_token, idx, accept_length_tree, start_time = rest_forward( 313 | torch.as_tensor(input_ids).cuda(), 314 | model, 315 | tokenizer, 316 | max_new_token, 317 | temperature, 318 | top_p, 319 | datastore, 320 | num_draft, 321 | token_spans, 322 | ) 323 | torch.cuda.synchronize() 324 | total_time = time.time() - start_time 325 | accept_lengths_tree.extend(accept_length_tree) 326 | # if model.config.is_encoder_decoder: 327 | # output_ids = output_ids[0] 328 | # else: 329 | output_ids = output_ids[0][len(input_ids[0]) :] 330 | 331 | # be consistent with the template's stop_token_ids 332 | if conv.stop_token_ids: 333 | stop_token_ids_index = [ 334 | i 335 | for i, id in enumerate(output_ids) 336 | if id in conv.stop_token_ids 337 | ] 338 | if len(stop_token_ids_index) > 0: 339 | output_ids = output_ids[: stop_token_ids_index[0]] 340 | 341 | output = tokenizer.decode( 342 | output_ids, 343 | spaces_between_special_tokens=False, 344 | ) 345 | if conv.stop_str and output.find(conv.stop_str) > 0: 346 | output = output[: output.find(conv.stop_str)] 347 | 348 | if conv.name == "xgen" and output.startswith("Assistant:"): 349 | output = output.replace("Assistant:", "", 1).strip() 350 | except RuntimeError as e: 351 | print("ERROR question ID: ", question["question_id"]) 352 | output = "ERROR" 353 | 354 | turns.append(output) 355 | idxs.append(int(idx)) 356 | new_tokens.append(int(new_token)) 357 | wall_time.append(total_time) 358 | accept_lengths_tree_this.extend(accept_length_tree) 359 | conv.messages[-1][-1] = output 360 | # torch.cuda.empty_cache() 361 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time, "accept_lengths:": accept_lengths_tree_this}) 362 | 363 | # Dump answers 364 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 365 | with open(os.path.expanduser(answer_file), "a") as fout: 366 | ans_json = { 367 | "category": question["category"], 368 | "question_id": question["question_id"], 369 | "answer_id": shortuuid.uuid(), 370 | "model_id": model_id, 371 | "choices": choices, 372 | "tstamp": time.time(), 373 | } 374 | fout.write(json.dumps(ans_json) + "\n") 375 | print("accept_lengths_tree: ", np.mean(accept_lengths_tree)) 376 | 377 | 378 | def reorg_answer_file(answer_file): 379 | """Sort by question id and de-duplication""" 380 | answers = {} 381 | with open(answer_file, "r") as fin: 382 | for l in fin: 383 | qid = json.loads(l)["question_id"] 384 | answers[qid] = l 385 | 
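    # The answer file is opened in append mode by get_model_answers, so repeated
    # runs can leave duplicates; keeping the last line seen per question_id
    # retains the most recent answer.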
386 | qids = sorted(list(answers.keys())) 387 | with open(answer_file, "w") as fout: 388 | for qid in qids: 389 | fout.write(answers[qid]) 390 | 391 | 392 | if __name__ == "__main__": 393 | parser = argparse.ArgumentParser() 394 | parser.add_argument( 395 | "--model-path", 396 | type=str, 397 | required=True, 398 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 399 | ) 400 | parser.add_argument("--model-id", type=str, required=True) 401 | parser.add_argument( 402 | "--bench-name", 403 | type=str, 404 | default="mt_bench", 405 | help="The name of the benchmark question set.", 406 | ) 407 | parser.add_argument( 408 | "--question-begin", 409 | type=int, 410 | help="A debug option. The begin index of questions.", 411 | ) 412 | parser.add_argument( 413 | "--question-end", type=int, help="A debug option. The end index of questions." 414 | ) 415 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 416 | parser.add_argument( 417 | "--max-new-token", 418 | type=int, 419 | default=1024, 420 | help="The maximum number of new generated tokens.", 421 | ) 422 | parser.add_argument( 423 | "--num-choices", 424 | type=int, 425 | default=1, 426 | help="How many completion choices to generate.", 427 | ) 428 | parser.add_argument( 429 | "--num-gpus-per-model", 430 | type=int, 431 | default=1, 432 | help="The number of GPUs per model.", 433 | ) 434 | parser.add_argument( 435 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 436 | ) 437 | parser.add_argument( 438 | "--max-gpu-memory", 439 | type=str, 440 | help="Maxmum GPU memory used for model weights per GPU.", 441 | ) 442 | 443 | parser.add_argument( 444 | "--temperature", 445 | type=float, 446 | default=0.0, 447 | help="The temperature for sampling.", 448 | ) 449 | 450 | parser.add_argument( 451 | "--top-p", 452 | type=float, 453 | default=0.0, 454 | help="The threshold for nucleus sampling.", 455 | ) 456 | 457 | # REST's hyperparameters 458 | parser.add_argument( 459 | "--datastore-path", 460 | type=str, 461 | required=True, 462 | help="The path of the datastore for retrival.", 463 | ) 464 | 465 | parser.add_argument( 466 | "--num-draft", 467 | type=int, 468 | default=64, 469 | help="The maximum number of draft tokens.", 470 | ) 471 | parser.add_argument( 472 | "--max-token-span", 473 | type=int, 474 | default=16, 475 | help="The maximum length of suffix for retrieval.", 476 | ) 477 | 478 | args = parser.parse_args() 479 | 480 | if args.temperature == 0: 481 | args.top_p = 0 482 | 483 | 484 | args.model_id = "rest-" + args.model_id+"-temperature-"+str(args.temperature)+"-top_p-"+str(args.top_p) 485 | if args.num_gpus_total // args.num_gpus_per_model > 1: 486 | import ray 487 | ray.init() 488 | 489 | question_file = f"data/{args.bench_name}/question.jsonl" 490 | if args.answer_file: 491 | answer_file = args.answer_file 492 | else: 493 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 494 | 495 | print(f"Output to {answer_file}") 496 | 497 | run_eval( 498 | args.model_path, 499 | args.model_id, 500 | question_file, 501 | args.question_begin, 502 | args.question_end, 503 | answer_file, 504 | args.max_new_token, 505 | args.num_choices, 506 | args.num_gpus_per_model, 507 | args.num_gpus_total, 508 | args.max_gpu_memory, 509 | args.temperature, 510 | args.top_p, 511 | args.datastore_path, 512 | args.num_draft, 513 | args.max_token_span, 514 | ) 515 | 516 | reorg_answer_file(answer_file) 
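# Note on the printed summary: np.mean(accept_lengths_tree) is the average number
# of tokens appended to the sequence per verification step. The baseline script
# always appends one token per step, so this mean is a rough proxy for REST's
# decoding speedup, ignoring retrieval and tree-verification overhead.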
-------------------------------------------------------------------------------- /llm_judge/run_baseline.sh: -------------------------------------------------------------------------------- 1 | RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=4 python gen_model_answer_baseline.py --temperature 0.0 --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /llm_judge/run_rest.sh: -------------------------------------------------------------------------------- 1 | RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=4 python gen_model_answer_rest.py --temperature 0.0 --top-p 0.8 --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path <> 2 | 3 | 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | fschat[model_worker,webui] 3 | maturin==0.12 4 | numpy==1.26.1 5 | tqdm==4.66.1 6 | transformers 7 | accelerate==0.24.1 8 | datasets 9 | openai 10 | anthropic 11 | sentencepiece 12 | protobuf 13 | shortuuid 14 | -------------------------------------------------------------------------------- /rest/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/rest/__init__.py -------------------------------------------------------------------------------- /rest/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/rest/inference/__init__.py -------------------------------------------------------------------------------- /rest/inference/cli.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py 2 | # Adapted from: https://github.com/FasterDecoding/Medusa/blob/main/medusa/inference/cli.py 3 | """ 4 | Chat with a model with command line interface. 5 | 6 | Usage: 7 | python3 -m rest.inference.cli --datastore-path <> --base-model <> 8 | Other commands: 9 | - Type "!!exit" or an empty line to exit. 10 | - Type "!!reset" to start a new conversation. 11 | - Type "!!remove" to remove the last prompt. 12 | - Type "!!regen" to regenerate the last message. 13 | - Type "!!save " to save the conversation history to a json file. 14 | - Type "!!load " to load a conversation history from a json file. 
15 | """ 16 | import argparse 17 | import os 18 | import re 19 | import sys 20 | import torch 21 | from fastchat.serve.cli import SimpleChatIO, RichChatIO, ProgrammaticChatIO 22 | from fastchat.model.model_adapter import get_conversation_template 23 | from fastchat.conversation import get_conv_template 24 | import json 25 | from rest.model.rest_model import RestModel 26 | 27 | import draftretriever 28 | 29 | 30 | def main(args): 31 | if args.style == "simple": 32 | chatio = SimpleChatIO(args.multiline) 33 | elif args.style == "rich": 34 | chatio = RichChatIO(args.multiline, args.mouse) 35 | elif args.style == "programmatic": 36 | chatio = ProgrammaticChatIO() 37 | else: 38 | raise ValueError(f"Invalid style for console: {args.style}") 39 | try: 40 | model = RestModel.from_pretrained( 41 | args.base_model, 42 | torch_dtype=torch.float16, 43 | low_cpu_mem_usage=True, 44 | device_map="auto", 45 | load_in_8bit=args.load_in_8bit, 46 | load_in_4bit=args.load_in_4bit, 47 | ) 48 | tokenizer = model.get_tokenizer() 49 | if not args.baseline: 50 | datastore = draftretriever.Reader( 51 | index_file_path=args.datastore_path, # "/data/tianle2/hzy/ret_decode/datastore/vicuna_ultrachat.idx", 52 | ) 53 | conv = None 54 | 55 | def new_chat(): 56 | return get_conversation_template("vicuna") 57 | 58 | def reload_conv(conv): 59 | """ 60 | Reprints the conversation from the start. 61 | """ 62 | for message in conv.messages[conv.offset :]: 63 | chatio.prompt_for_output(message[0]) 64 | chatio.print_output(message[1]) 65 | 66 | while True: 67 | if not conv: 68 | conv = new_chat() 69 | 70 | try: 71 | inp = chatio.prompt_for_input(conv.roles[0]) 72 | except EOFError: 73 | inp = "" 74 | 75 | if inp == "!!exit" or not inp: 76 | print("exit...") 77 | break 78 | elif inp == "!!reset": 79 | print("resetting...") 80 | conv = new_chat() 81 | continue 82 | elif inp == "!!remove": 83 | print("removing last message...") 84 | if len(conv.messages) > conv.offset: 85 | # Assistant 86 | if conv.messages[-1][0] == conv.roles[1]: 87 | conv.messages.pop() 88 | # User 89 | if conv.messages[-1][0] == conv.roles[0]: 90 | conv.messages.pop() 91 | reload_conv(conv) 92 | else: 93 | print("No messages to remove.") 94 | continue 95 | elif inp == "!!regen": 96 | print("regenerating last message...") 97 | if len(conv.messages) > conv.offset: 98 | # Assistant 99 | if conv.messages[-1][0] == conv.roles[1]: 100 | conv.messages.pop() 101 | # User 102 | if conv.messages[-1][0] == conv.roles[0]: 103 | reload_conv(conv) 104 | # Set inp to previous message 105 | inp = conv.messages.pop()[1] 106 | else: 107 | # Shouldn't happen in normal circumstances 108 | print("No user message to regenerate from.") 109 | continue 110 | else: 111 | print("No messages to regenerate.") 112 | continue 113 | elif inp.startswith("!!save"): 114 | args = inp.split(" ", 1) 115 | 116 | if len(args) != 2: 117 | print("usage: !!save ") 118 | continue 119 | else: 120 | filename = args[1] 121 | 122 | # Add .json if extension not present 123 | if not "." 
in filename: 124 | filename += ".json" 125 | 126 | print("saving...", filename) 127 | with open(filename, "w") as outfile: 128 | json.dump(conv.dict(), outfile) 129 | continue 130 | elif inp.startswith("!!load"): 131 | args = inp.split(" ", 1) 132 | 133 | if len(args) != 2: 134 | print("usage: !!load ") 135 | continue 136 | else: 137 | filename = args[1] 138 | 139 | # Check if file exists and add .json if needed 140 | if not os.path.exists(filename): 141 | if (not filename.endswith(".json")) and os.path.exists( 142 | filename + ".json" 143 | ): 144 | filename += ".json" 145 | else: 146 | print("file not found:", filename) 147 | continue 148 | 149 | print("loading...", filename) 150 | with open(filename, "r") as infile: 151 | new_conv = json.load(infile) 152 | 153 | conv = get_conv_template(new_conv["template_name"]) 154 | conv.set_system_message(new_conv["system_message"]) 155 | conv.messages = new_conv["messages"] 156 | reload_conv(conv) 157 | continue 158 | 159 | conv.append_message(conv.roles[0], inp) 160 | conv.append_message(conv.roles[1], None) 161 | prompt = conv.get_prompt() 162 | 163 | try: 164 | chatio.prompt_for_output(conv.roles[1]) 165 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to( 166 | model.base_model.device 167 | ) 168 | outputs = chatio.stream_output( 169 | model.rest_generate( 170 | input_ids, 171 | datastore, 172 | temperature=args.temperature, 173 | top_p=args.top_p, 174 | max_steps=args.max_steps, 175 | ) if not args.baseline else \ 176 | model.baseline_generate( 177 | input_ids, 178 | temperature=args.temperature, 179 | top_p=args.top_p, 180 | max_steps=args.max_steps, 181 | ) 182 | ) 183 | conv.update_last_message(outputs.strip()) 184 | 185 | except KeyboardInterrupt: 186 | print("stopped generation.") 187 | # If generation didn't finish 188 | if conv.messages[-1][1] is None: 189 | conv.messages.pop() 190 | # Remove last user message, so there isn't a double up 191 | if conv.messages[-1][0] == conv.roles[0]: 192 | conv.messages.pop() 193 | 194 | reload_conv(conv) 195 | 196 | except KeyboardInterrupt: 197 | print("exit...") 198 | 199 | 200 | if __name__ == "__main__": 201 | parser = argparse.ArgumentParser() 202 | parser.add_argument("--datastore-path", type=str, required=True, help="datastore path.") 203 | parser.add_argument("--base-model", type=str, default="lmsys/vicuna-7b-v1.5", help="LLM name or path.") 204 | parser.add_argument( 205 | "--load-in-8bit", action="store_true", help="Use 8-bit quantization" 206 | ) 207 | parser.add_argument( 208 | "--load-in-4bit", action="store_true", help="Use 4-bit quantization" 209 | ) 210 | parser.add_argument( 211 | "--conv-template", type=str, default=None, help="Conversation prompt template." 212 | ) 213 | parser.add_argument( 214 | "--conv-system-msg", type=str, default=None, help="Conversation system message." 215 | ) 216 | parser.add_argument("--temperature", type=float, default=0.7) 217 | parser.add_argument("--top_p", type=float, default=0.8) 218 | parser.add_argument("--max-steps", type=int, default=512) 219 | parser.add_argument("--baseline", action="store_true", help="Use standard autoregressive generation.") 220 | parser.add_argument("--no-history", action="store_true") 221 | parser.add_argument( 222 | "--style", 223 | type=str, 224 | default="simple", 225 | choices=["simple", "rich", "programmatic"], 226 | help="Display style.", 227 | ) 228 | parser.add_argument( 229 | "--multiline", 230 | action="store_true", 231 | help="Enable multiline input. 
Use ESC+Enter for newline.", 232 | ) 233 | parser.add_argument( 234 | "--mouse", 235 | action="store_true", 236 | help="[Rich Style]: Enable mouse support for cursor positioning.", 237 | ) 238 | parser.add_argument( 239 | "--debug", 240 | action="store_true", 241 | help="Print useful debug information (e.g., prompts)", 242 | ) 243 | args = parser.parse_args() 244 | main(args) 245 | -------------------------------------------------------------------------------- /rest/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/REST/50a5fc197382ed8df5b3e946dad2f8337511b541/rest/model/__init__.py -------------------------------------------------------------------------------- /rest/model/kv_cache.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/kv_cache.py 2 | 3 | import torch 4 | 5 | 6 | class KVCache: 7 | """ 8 | A key-value cache for the model. 9 | 10 | This class provides a mechanism to maintain a growing cache of keys and values, 11 | particularly useful for models that benefit from caching previous states, 12 | like transformers during autoregressive decoding. 13 | 14 | Attributes: 15 | data (torch.Tensor): The tensor storing keys and values. 16 | current_length (int): Current length of the data being stored. 17 | """ 18 | 19 | def __init__(self, data, current_length): 20 | """ 21 | Initialize the KVCache. 22 | 23 | Args: 24 | data (torch.Tensor): Initial tensor to store the keys and values. 25 | current_length (int): Initial length of the data. 26 | """ 27 | self.data = data 28 | self.current_length = current_length 29 | 30 | @property 31 | def shape(self): 32 | """Return the shape of the data tensor with updated length.""" 33 | return ( 34 | self.data.shape[0], 35 | self.data.shape[1], 36 | self.current_length.item(), 37 | self.data.shape[3], 38 | ) 39 | 40 | def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2): 41 | """ 42 | Copy values from the current data at specified indices to a new location. 43 | 44 | Args: 45 | indices (torch.Tensor): Indices of the data tensor to be copied. 46 | prev_length (int): Previous length before adding new data. 47 | dim (int, optional): Dimension along which copying should be performed. Default is 2. 48 | """ 49 | tgt = self.data.index_select(dim, indices) 50 | dst = self.data.narrow(dim, prev_length, tgt.shape[dim]) 51 | dst.copy_(tgt, non_blocking=True) 52 | self.current_length.fill_(prev_length + tgt.shape[dim]) 53 | 54 | def cat(self, tensor: torch.Tensor, dim: int = 2): 55 | """ 56 | Concatenate the given tensor with the current data. 57 | 58 | Args: 59 | tensor (torch.Tensor): The tensor to be concatenated. 60 | dim (int, optional): The dimension along which concatenation should be done. Default is 2. 61 | 62 | Returns: 63 | torch.Tensor: The data tensor after concatenation up to the current length. 64 | """ 65 | dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) 66 | dst.copy_(tensor) 67 | self.current_length.add_(tensor.shape[dim]) 68 | return torch.narrow(self.data, 2, 0, self.current_length) 69 | 70 | 71 | def initialize_past_key_values(model): 72 | """ 73 | Initialize past key and value states for a given transformer model. 
74 | 75 | This function prepares key-value cache structures for the model, allowing it to store and reuse 76 | past key and value states during autoregressive decoding, which can improve efficiency. 77 | 78 | Args: 79 | model (nn.Module): The transformer model for which past key-value states need to be initialized. 80 | 81 | Returns: 82 | tuple: 83 | - past_key_values (list): A list of KVCache objects for each layer in the model. 84 | - past_key_values_data (torch.Tensor): The tensor that will store all keys and values. 85 | - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache. 86 | """ 87 | # Extracting configuration from the model 88 | config = model.config 89 | # Initializing the batch size to 1, this can be modified if different batch sizes are required 90 | batch_size = 1 91 | # Initializing a tensor to store past keys and values for all layers 92 | past_key_values_data = torch.zeros( 93 | config.num_hidden_layers * 2, 94 | batch_size, 95 | config.num_attention_heads, 96 | config.max_position_embeddings, 97 | config.hidden_size // config.num_attention_heads, 98 | device=model.device, 99 | dtype=model.dtype, 100 | ) 101 | # Initialize tensor to store the current length of the cached data for all layers. 102 | # [IMPORTANT] It needs to be kept on CPU for quick access and updates. 103 | current_length_data = torch.zeros( 104 | config.num_hidden_layers * 2, dtype=torch.long, device="cpu" 105 | ) 106 | # Creating a KVCache for each pair of key and value in all layers 107 | past_key_values = [] * config.num_hidden_layers 108 | for i in range(config.num_hidden_layers): 109 | past_key_values.append( 110 | [ 111 | KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j]) 112 | for j in range(2) 113 | ] 114 | ) 115 | return past_key_values, past_key_values_data, current_length_data 116 | -------------------------------------------------------------------------------- /rest/model/rest_model.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/medusa_model.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import PreTrainedModel, PretrainedConfig 6 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM 7 | from .utils import * 8 | from .kv_cache import initialize_past_key_values 9 | from transformers import AutoTokenizer 10 | import os 11 | import draftretriever 12 | 13 | 14 | 15 | class RestModel(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | base_model, 20 | base_model_name_or_path, 21 | token_spans=[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2], 22 | ): 23 | """ 24 | Args: 25 | base_model (nn.Module): The LLM to be used. 26 | """ 27 | super().__init__() 28 | self.base_model = base_model 29 | self.config = base_model.config 30 | self.hidden_size = base_model.lm_head.weight.shape[-1] 31 | self.vocab_size = base_model.lm_head.weight.shape[0] 32 | self.base_model_name_or_path = base_model_name_or_path 33 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) 34 | self.token_spans = token_spans 35 | 36 | def get_tokenizer(self): 37 | 38 | """Get the tokenizer of the base model. 39 | 40 | Returns: 41 | Tokenizer: The tokenizer of the base model. 
42 | """ 43 | return self.tokenizer 44 | 45 | @classmethod 46 | def from_pretrained( 47 | cls, 48 | base_model_path="codellama/CodeLlama-7b-instruct-hf", 49 | **kwargs, 50 | ): 51 | """ 52 | Args: 53 | base_model_path (str): Name or path of the LLM to load. 54 | 55 | Returns: 56 | RestModel 57 | """ 58 | 59 | base_model = KVLlamaForCausalLM.from_pretrained( 60 | base_model_path, **kwargs 61 | ) 62 | 63 | model = cls( 64 | base_model, 65 | base_model_path, 66 | ) 67 | 68 | return model 69 | 70 | def forward( 71 | self, 72 | input_ids=None, 73 | attention_mask=None, 74 | past_key_values=None, 75 | output_orig=False, 76 | position_ids=None, 77 | ): 78 | """Forward pass of the LLM. 79 | 80 | Args: 81 | input_ids (torch.Tensor, optional): Input token IDs. 82 | attention_mask (torch.Tensor, optional): Attention mask. 83 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 84 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 85 | position_ids (torch.Tensor, optional): Position IDs. 86 | 87 | Returns: 88 | torch.Tensor: A tensor containing predictions from the LM head. 89 | """ 90 | with torch.inference_mode(): 91 | # Pass input through the base model 92 | outputs = self.base_model.model( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | past_key_values=past_key_values, 96 | position_ids=position_ids, 97 | ) 98 | if output_orig: 99 | orig = self.base_model.lm_head(outputs[0]) 100 | 101 | if output_orig: 102 | return outputs, orig 103 | raise NotImplementedError 104 | 105 | def rest_generate( 106 | self, 107 | input_ids, 108 | datastore, 109 | temperature=0.0, 110 | top_p=0.8, 111 | max_steps=512, 112 | ): 113 | """ 114 | Args: 115 | input_ids (torch.Tensor, optional): Input token IDs. 116 | attention_mask (torch.Tensor, optional): Attention mask. 117 | temperature (float, optional): Temperature for typical acceptance. 118 | 119 | Returns: 120 | torch.Tensor: Output token IDs. 121 | 122 | Warning: Only support batch size 1 for now!! 123 | """ 124 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
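        # rest_generate is a generator: after every accepted chunk it yields the
        # text decoded so far, which is what lets rest/inference/cli.py stream
        # partial output. Generation stops at EOS or after max_steps iterations.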
125 | # Avoid modifying the input_ids in-place 126 | input_ids = input_ids.clone() 127 | 128 | # Initialize the past key and value states 129 | if hasattr(self, "past_key_values"): 130 | past_key_values = self.past_key_values 131 | past_key_values_data = self.past_key_values_data 132 | current_length_data = self.current_length_data 133 | # Reset the past key and value states 134 | current_length_data.zero_() 135 | else: 136 | ( 137 | past_key_values, 138 | past_key_values_data, 139 | current_length_data, 140 | ) = initialize_past_key_values(self.base_model) 141 | self.past_key_values = past_key_values 142 | self.past_key_values_data = past_key_values_data 143 | self.current_length_data = current_length_data 144 | 145 | input_len = input_ids.shape[1] 146 | 147 | self.base_model.model.draft_mask = None 148 | 149 | # Initialize tree attention mask and process prefill tokens 150 | logits = initialize_logits( 151 | input_ids, self, past_key_values 152 | ) 153 | 154 | new_token = 0 155 | last_round_token = 0 156 | 157 | for idx in range(max_steps): 158 | # Retrievd candidates (draft tokens) from the datastore 159 | candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer( 160 | logits, 161 | input_ids, 162 | datastore, 163 | self.token_spans, 164 | device=self.base_model.device 165 | ) 166 | self.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 167 | # Use tree attention to verify the candidates and get predictions 168 | logits, outputs = tree_decoding( 169 | self, 170 | tree_candidates, 171 | past_key_values, 172 | draft_buffers["draft_position_ids"], 173 | input_ids, 174 | draft_buffers["retrieve_indices"], 175 | ) 176 | 177 | # Evaluate the posterior of the candidates to select the accepted candidate prefix 178 | best_candidate, accept_length = evaluate_posterior( 179 | logits, candidates, temperature, top_p 180 | ) 181 | 182 | # Update the input_ids and logits 183 | input_ids, logits, new_token = update_inference_inputs( 184 | input_ids, 185 | candidates, 186 | best_candidate, 187 | accept_length, 188 | draft_buffers["retrieve_indices"], 189 | outputs, 190 | logits, 191 | new_token, 192 | past_key_values_data, 193 | current_length_data, 194 | ) 195 | 196 | yield { 197 | "text": self.tokenizer.decode( 198 | input_ids[0, input_len:], 199 | skip_special_tokens=True, 200 | spaces_between_special_tokens=False, 201 | clean_up_tokenization_spaces=True, 202 | ) 203 | } 204 | 205 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]: 206 | break 207 | 208 | 209 | def baseline_generate( 210 | self, 211 | input_ids, 212 | temperature=0.0, 213 | top_p=0.8, 214 | max_steps=512, 215 | ): 216 | """ 217 | Args: 218 | input_ids (torch.Tensor, optional): Input token IDs. 219 | attention_mask (torch.Tensor, optional): Attention mask. 220 | temperature (float, optional): Temperature for typical acceptance. 221 | 222 | Returns: 223 | torch.Tensor: Output token IDs. 224 | 225 | Warning: Only support batch size 1 for now!! 226 | """ 227 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
228 | # Avoid modifying the input_ids in-place 229 | input_ids = input_ids.clone() 230 | 231 | # Initialize the past key and value states 232 | if hasattr(self, "past_key_values"): 233 | past_key_values = self.past_key_values 234 | past_key_values_data = self.past_key_values_data 235 | current_length_data = self.current_length_data 236 | # Reset the past key and value states 237 | current_length_data.zero_() 238 | else: 239 | ( 240 | past_key_values, 241 | past_key_values_data, 242 | current_length_data, 243 | ) = initialize_past_key_values(self.base_model) 244 | self.past_key_values = past_key_values 245 | self.past_key_values_data = past_key_values_data 246 | self.current_length_data = current_length_data 247 | 248 | input_len = input_ids.shape[1] 249 | 250 | self.base_model.model.draft_mask = None 251 | outputs = self.base_model(input_ids, past_key_values = past_key_values, use_cache=True) 252 | new_token = 0 253 | last_round_token = 0 254 | 255 | for idx in range(max_steps): 256 | # # Retrievd candidates (draft tokens) from the datastore 257 | # candidates, tree_candidates, draft_buffers = generate_candidates_and_draft_buffer( 258 | # logits, 259 | # input_ids, 260 | # datastore, 261 | # self.token_spans, 262 | # device=self.base_model.device 263 | # ) 264 | # self.base_model.model.draft_mask = draft_buffers["draft_attn_mask"] 265 | # # Use tree attention to verify the candidates and get predictions 266 | # logits, outputs = tree_decoding( 267 | # self, 268 | # tree_candidates, 269 | # past_key_values, 270 | # draft_buffers["draft_position_ids"], 271 | # input_ids, 272 | # draft_buffers["retrieve_indices"], 273 | # ) 274 | 275 | # # Evaluate the posterior of the candidates to select the accepted candidate prefix 276 | # best_candidate, accept_length = evaluate_posterior( 277 | # logits, candidates, temperature 278 | # ) 279 | 280 | # # Update the input_ids and logits 281 | # input_ids, logits, new_token = update_inference_inputs( 282 | # input_ids, 283 | # candidates, 284 | # best_candidate, 285 | # accept_length, 286 | # draft_buffers["retrieve_indices"], 287 | # outputs, 288 | # logits, 289 | # new_token, 290 | # past_key_values_data, 291 | # current_length_data, 292 | # ) 293 | 294 | if top_p > 0: 295 | assert top_p < 1, "top_p should between 0.0 and 1" 296 | next_token_logits = outputs.logits[:, -1, :] 297 | next_token_logits = next_token_logits / (temperature if temperature > 0 else 1.) 
298 | filtered_logits = top_p_filtering(next_token_logits, top_p=top_p) 299 | input_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 300 | input_id = input_id.view(input_id.shape[0], 1) 301 | else: 302 | input_id = outputs.logits[:, -1:].argmax(dim=-1) 303 | outputs = self.base_model(input_id, use_cache=True, past_key_values = past_key_values) 304 | input_ids = torch.cat([input_ids, input_id], dim=-1) 305 | 306 | yield { 307 | "text": self.tokenizer.decode( 308 | input_ids[0, input_len:], 309 | skip_special_tokens=True, 310 | spaces_between_special_tokens=False, 311 | clean_up_tokenization_spaces=True, 312 | ) 313 | } 314 | 315 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]: 316 | break -------------------------------------------------------------------------------- /rest/model/utils.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import draftretriever 7 | 8 | def pad_path(path, length, pad_value=-2): 9 | """ 10 | Pad the given path list with a specific value up to a specified length. 11 | 12 | Parameters: 13 | - path (list): The original list that needs padding. 14 | - length (int): The desired length of the padded list. 15 | - pad_value (optional, default=-2): The value to use for padding. 16 | 17 | Returns: 18 | - list: A new list based on the original path but padded to the desired length. 19 | 20 | Example: 21 | >>> pad_path([1,2,3], 5) 22 | [1, 2, 3, -2, -2] 23 | 24 | Note: 25 | If the given path is already longer than the specified length, 26 | then no padding occurs, and the original path is returned. 27 | """ 28 | 29 | # Calculate the number of padding values needed by subtracting the length 30 | # of the path from the desired length. 31 | # Append the padding values to the original path and return the new list. 32 | return path + [pad_value] * (length - len(path)) 33 | 34 | 35 | 36 | def initialize_logits(input_ids, model, past_key_values): 37 | """ 38 | Forward pass through the model to obtain the model outputs, and logits. 39 | 40 | 41 | Args: 42 | - input_ids (torch.Tensor): The input tensor containing token ids. 43 | - model: The LLM for generation. 44 | - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values. 45 | 46 | Returns: 47 | - logits (torch.Tensor): logits from the LLM. 48 | """ 49 | outputs, logits = model( 50 | input_ids, past_key_values=past_key_values, output_orig=True 51 | ) 52 | return logits 53 | 54 | 55 | def reset_past_key_values(passed_key_values): 56 | """ 57 | Resets the current lengths in the passed key-values to zero. 58 | 59 | This function is designed to be used during the evaluation of a baseline model. 60 | It iterates through each layer's key-values and sets their current lengths to zero, 61 | effectively resetting their state. 62 | 63 | Args: 64 | - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer. 65 | 66 | Returns: 67 | - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. 
68 |     """
69 |     for i in range(len(passed_key_values)):
70 |         for j in range(2):
71 |             passed_key_values[i][j].current_length.fill_(0)
72 |     return passed_key_values
73 | 
74 | 
75 | def generate_candidates_and_draft_buffer(logits, input_ids, datastore, token_spans, top_p=0., temperature=1., max_num_draft=64, device="cuda"):
76 |     """
77 |     Generate draft candidates by retrieving likely continuations of the current context from the datastore.
78 | 
79 |     Parameters:
80 |     - logits (torch.Tensor): Logits at the last verified position, used to select the next "root" token.
81 |     - input_ids (torch.Tensor): Current input sequence; its last `token_span` tokens form the retrieval query.
82 |     - datastore / token_spans / top_p / temperature / max_num_draft: Retrieval source, query suffix lengths, sampling parameters for the root token, and the draft-token budget.
83 | 
84 |     Returns:
85 |     - tuple: Returns cartesian candidates, tree candidates, and the draft buffers.
86 |     """
87 | 
88 |     # Greedy decoding: Select the most probable candidate from the original logits.
89 |     if top_p == 0:
90 |         candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
91 |     else:
92 |         assert top_p < 1, "top_p should be between 0.0 and 1"
93 |         next_token_logits = logits[:, -1, :]
94 |         next_token_logits = next_token_logits / (temperature if temperature > 0 else 1.)
95 |         filtered_logits = top_p_filtering(next_token_logits, top_p=top_p)
96 |         candidates_logit = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1).squeeze(0)
97 | 
98 |     input_ids_extend = torch.cat([input_ids.squeeze(0), candidates_logit], dim=-1)
99 | 
100 |     retrieved_token_list = []
101 |     _draft_attn_mask, _tree_indices, _draft_position_ids, _retrieve_indices = [], [], [], []
102 |     for span_id, token_span in enumerate(token_spans):
103 |         this_token = input_ids_extend.squeeze(0)[-token_span:].to("cpu").tolist()
104 |         # Retrieve draft tokens from the datastore, and get draft buffer
105 |         retrieved_token_list, _draft_attn_mask, _tree_indices, _draft_position_ids, _retrieve_indices = datastore.search(this_token, choices=max_num_draft)
106 | 
107 |         # No retrieved sequences
108 |         if len(retrieved_token_list) == 0:
109 |             continue
110 |         # Break because this span got a hit
111 |         else:
112 |             break
113 |     # TODO: just continue to the next retrieval process
114 |     if len(retrieved_token_list) == 0:
115 |         # Just randomly guess one token
116 |         random_index = 100
117 |         retrieved_position_token_list = [[random_index]]
118 |         _draft_attn_mask = [[1., 0.], [1., 1.]]
119 |         _tree_indices = [0, 1]
120 |         _draft_position_ids = [0, 1]
121 |         _retrieve_indices = [[0, 1]]
122 |     else:
123 |         retrieved_position_token_list = [list(row) for row in zip(*retrieved_token_list)]
124 |         retrieved_position_token_list = [[x for i, x in enumerate(sublist) if sublist.index(x) == i and x != -2] for sublist in retrieved_position_token_list]
125 |         TOPK = max(len(retrieved_position_token) for retrieved_position_token in retrieved_position_token_list)
126 |         retrieved_position_token_list = [pad_path(retrieved_position_token, TOPK) for retrieved_position_token in retrieved_position_token_list]
127 | 
128 |     # Aggregate the generated buffers into a dictionary and move the tensors in the dictionary to the specified device
129 |     draft_buffers = {
130 |         "draft_attn_mask": torch.tensor(_draft_attn_mask, device=device).unsqueeze(0).unsqueeze(0),
131 |         "tree_indices": torch.tensor(_tree_indices, device=device),
132 |         "draft_position_ids": torch.tensor(_draft_position_ids, device=device),
133 |         "retrieve_indices": torch.tensor(_retrieve_indices, device=device),
134 |     }
135 | 
136 |     candidates_draft_logits = torch.tensor(retrieved_position_token_list, dtype=torch.long, device=candidates_logit.device).contiguous()
137 | 
138 |     # Combine the selected candidate from 
    candidates = torch.cat([candidates_logit, candidates_draft_logits.view(-1)], dim=-1)

    # Map the combined candidates to the tree indices to get tree candidates.
    tree_candidates = candidates[draft_buffers["tree_indices"]]

    # Extend the tree candidates by appending a zero.
    tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0)

    # Retrieve the cartesian candidates using the retrieve indices.
    cart_candidates = tree_candidates_ext[draft_buffers["retrieve_indices"]]

    # Unsqueeze the tree candidates for dimension consistency.
    tree_candidates = tree_candidates.unsqueeze(0)

    return cart_candidates, tree_candidates, draft_buffers


def tree_decoding(
    model,
    tree_candidates,
    past_key_values,
    draft_position_ids,
    input_ids,
    retrieve_indices,
):
    """
    Decode the tree candidates using the provided model and reorganize the logits.

    Parameters:
    - model (nn.Module): Model to be used for decoding the tree candidates.
    - tree_candidates (torch.Tensor): Input candidates based on a tree structure.
    - past_key_values (torch.Tensor): Past states, such as key and value pairs, used in attention layers.
    - draft_position_ids (torch.Tensor): Positional IDs (layer IDs in the Trie) of each draft token.
    - input_ids (torch.Tensor): Input sequence IDs.
    - retrieve_indices (list or torch.Tensor): Indices for reordering the logits.

    Returns:
    - tuple: Returns logits and other outputs from the model.
    """

    # Compute new position IDs by adding the draft position IDs to the length of the input sequence.
    position_ids = draft_position_ids + input_ids.shape[1]

    # Use the model to decode the tree candidates.
    # The model is expected to return each draft token's logits, and possibly other outputs.
    outputs, tree_logits = model(
        tree_candidates,
        output_orig=True,
        past_key_values=past_key_values,
        position_ids=position_ids,
    )

    # Reorder the obtained logits based on the retrieve_indices to ensure consistency with some reference ordering.
    logits = tree_logits[0, retrieve_indices]

    return logits, outputs


def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):

    # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79

    # Apply temperature
    logits = logits[:, :-1] / temperature

    n_samples, n_tokens = logits.shape[0], logits.shape[1]
    logits = logits.view(n_samples*n_tokens, -1)

    # Convert to probabilities (softmax)
    probs = F.softmax(logits, dim=-1)
    # Sort the probabilities in descending order
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    # Compute cumulative probabilities
    cum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Create mask for the top-p nucleus
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)

    # Remove low-probability tokens
    logits[indices_to_remove] = float('-inf')

    # Sample from the remaining tokens
    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
    # Create a mask marking draft tokens that match the sampled tokens
    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()

    return posterior_mask


def evaluate_posterior(
    logits, candidates, temperature, top_p=0.8
):
    """
    Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.

    Depending on the temperature value, the function either uses greedy decoding or evaluates posterior
    probabilities to select the best candidate.

    Args:
    - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
    - candidates (torch.Tensor): Candidate token sequences.
    - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
    - top_p (float): Cumulative probability threshold for nucleus sampling, used when temperature is non-zero.

    Returns:
    - best_candidate (torch.Tensor): Index of the chosen best candidate.
    - accept_length (int): Length of the accepted candidate sequence.
    """
    # Greedy decoding based on temperature value
    if temperature == 0:
        # Find the tokens that match the maximum logits for each position in the sequence
        posterior_mask = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    elif top_p > 0:
        assert top_p < 1.0, "top_p should be between 0 and 1"
        posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    else:
        raise NotImplementedError


def update_inference_inputs(
    input_ids,
    candidates,
    best_candidate,
    accept_length,
    retrieve_indices,
    outputs,
    logits,
    new_token,
    past_key_values_data,
    current_length_data,
):
    """
    Update the input sequences and relevant tensors based on the selected best candidate from the inference results.

    Args:
    - input_ids (torch.Tensor): Current input token sequences.
    - candidates (torch.Tensor): Candidate token sequences generated in the current step.
    - best_candidate (int): Index of the chosen best candidate.
    - accept_length (int): Length of the accepted candidate sequence.
    - retrieve_indices (torch.Tensor): Indices to map tree positions to a cartesian product of candidates.
    - outputs, logits (torch.Tensor): Model's outputs from the previous inference step.
    - new_token (int): Counter for the new tokens added during inference.
    - past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model.
    - current_length_data (torch.Tensor): Tensor containing the current length of sequences in the batch.

    Returns:
    - input_ids (torch.Tensor): Updated input token sequences.
    - logits (torch.Tensor): Updated logits.
    - new_token (int): Updated counter for the new tokens added.
    """
    # Calculate the starting position for new tokens based on the previous input length
    prev_input_len = input_ids.shape[1]
    # Map the best candidate indices to the original indices in the sequence
    select_indices = (
        retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
    )
    # Append the tokens from the best candidate to the input sequence
    input_ids = torch.cat(
        [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1
    )
    # Update the past key values based on the selected tokens
    # Source tensor that contains relevant past information based on the selected candidate
    tgt = past_key_values_data[..., select_indices, :]
    # Destination tensor where the relevant past information will be stored
    dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :]
    # Copy relevant past information from the source to the destination
    dst.copy_(tgt, non_blocking=True)

    # Update the current length tensor (currently only batch size 1 is supported)
    current_length_data.fill_(prev_input_len + tgt.shape[-2])

    # Extract logits for the accepted tokens
    logits = logits[None, best_candidate, accept_length : accept_length + 1]

    # Update the new token counter
    new_token += accept_length + 1

    return input_ids, logits, new_token


def top_p_filtering(logits, top_p=0.0, filter_value=float('-inf')):
    # from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Scatter the sorted mask back to the original token indexing
    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
    logits[indices_to_remove] = filter_value
    return logits
--------------------------------------------------------------------------------
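
Editor's note: a minimal usage sketch for top_p_filtering above, for illustration only. The toy logits values are made up; it assumes the repo root is on PYTHONPATH (so rest.model.utils is importable) and that the draftretriever package is installed, since utils.py imports it at module load.

import torch
import torch.nn.functional as F

from rest.model.utils import top_p_filtering

# Toy batch of logits over a 4-token vocabulary (values are illustrative).
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# Keep the smallest set of tokens whose cumulative probability exceeds top_p;
# everything else is set to -inf. Note the function modifies its input in place,
# hence the clone().
filtered = top_p_filtering(logits.clone(), top_p=0.8)

# Sample the next token from the surviving nucleus, as the generation loop does.
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
print(filtered, next_token)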
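
Editor's note: to make the greedy acceptance rule in evaluate_posterior concrete, here is a self-contained toy walk-through of the mask/cumprod arithmetic with made-up token ids (no model involved).

import torch

# Two retrieved candidates; column 0 is the token already sampled from the base model,
# columns 1+ are draft tokens (toy ids).
candidates = torch.tensor([[5, 11, 42, 7],
                           [5, 11, 13, 7]])

# Argmax of the verifier logits at the positions that check each draft token (toy ids).
greedy_tokens = torch.tensor([[11, 42, 9],
                              [11, 42, 9]])

# A draft token counts only if it matches the argmax prediction...
posterior_mask = (candidates[:, 1:] == greedy_tokens).int()                  # [[1, 1, 0], [1, 0, 0]]
# ...and every earlier draft token in the same candidate was also accepted.
candidates_accept_length = torch.cumprod(posterior_mask, dim=1).sum(dim=1)   # [2, 1]
accept_length = candidates_accept_length.max()                               # 2
best_candidate = torch.argmax(candidates_accept_length)                      # 0
print(accept_length.item(), best_candidate.item())                           # 2 0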
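
Editor's note: the KV-cache compaction in update_inference_inputs can also be illustrated in isolation. The sketch below uses a toy cache tensor whose second-to-last dimension is the sequence axis, mirroring how past_key_values_data is indexed above; the shapes, positions, and values are assumptions made only for illustration.

import torch

# Toy "KV cache" with sequence length 12 along dim -2 (heads=2, head_dim=8).
past_kv = torch.randn(2, 12, 8)
prev_input_len = 4                       # length of the already-verified prefix

# Tree positions of each candidate inside the speculative block (toy indices).
retrieve_indices = torch.tensor([[0, 3, 5],
                                 [0, 1, 6]])
best_candidate, accept_length = 0, 1     # accept the first two tokens of candidate 0

# Absolute cache positions of the accepted tokens.
select_indices = retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len

# Copy their cache entries into a contiguous block right after the verified prefix.
tgt = past_kv[..., select_indices, :]
dst = past_kv[..., prev_input_len : prev_input_len + tgt.shape[-2], :]
dst.copy_(tgt)

# The cache now holds prev_input_len + accept_length + 1 valid positions.
print(prev_input_len + tgt.shape[-2])    # 6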