├── .cargo └── config.toml ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── assets └── cl100k_base.tiktoken ├── deny.toml ├── examples ├── basic_example.model ├── gpt4_encode.rs ├── test_gpt4.rs └── train.rs ├── src ├── base.rs ├── basic.rs ├── gpt4.rs ├── lib.rs ├── regex.rs └── test_common.rs └── tests ├── gen_test_case.py ├── taylorswift.txt ├── tiktoken_compat.proptest-regressions ├── tiktoken_compat.rs ├── tokenizer_tests.rs └── tokenizer_tests_gpt4.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [alias] 2 | bc = "build --all-targets --all-features" 3 | bcv = "bc --verbose" 4 | tc = "test --all-features" 5 | tcv = "tc --verbose" 6 | cc = "clippy --all-targets --all-features" 7 | fc = "fmt --all --check" 8 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Build 20 | run: cargo bcv 21 | - name: Run tests 22 | run: cargo tcv 23 | - name: Clippy 24 | run: cargo cc 25 | - name: Check Format 26 | run: cargo fc 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /flamegraph.svg 2 | /models 3 | /target 4 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "1.1.3" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "anyhow" 16 | version = "1.0.81" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" 19 | 20 | [[package]] 21 | name = "autocfg" 22 | version = "1.2.0" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" 25 | 26 | [[package]] 27 | name = "base64" 28 | version = "0.21.7" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" 31 | 32 | [[package]] 33 | name = "bit-set" 34 | version = "0.5.3" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" 37 | dependencies = [ 38 | "bit-vec", 39 | ] 40 | 41 | [[package]] 42 | name = "bit-vec" 43 | version = "0.6.3" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" 46 | 47 | [[package]] 48 | name = "bitflags" 49 | version = "1.3.2" 50 | source = "registry+https://github.com/rust-lang/crates.io-index" 51 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 52 | 53 | [[package]] 54 | name = "bitflags" 55 | version = "2.5.0" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" 58 | 59 | [[package]] 60 | name = "bstr" 61 | version = "1.9.1" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" 64 | dependencies = [ 65 | "memchr", 66 | "regex-automata", 67 | "serde", 68 | ] 69 | 70 | [[package]] 71 | name = "cfg-if" 72 | version = "1.0.0" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 75 | 76 | [[package]] 77 | name = "equivalent" 78 | version = "1.0.1" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" 81 | 82 | [[package]] 83 | name = "errno" 84 | version = "0.3.8" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" 87 | dependencies = [ 88 | "libc", 89 | "windows-sys", 90 | ] 91 | 92 | [[package]] 93 | name = "fancy-regex" 94 | version = "0.12.0" 95 | source = "registry+https://github.com/rust-lang/crates.io-index" 96 | checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" 97 | dependencies = [ 98 | "bit-set", 99 | "regex", 100 | ] 101 | 102 | [[package]] 103 | name = "fancy-regex" 104 | version = "0.13.0" 105 | source = "registry+https://github.com/rust-lang/crates.io-index" 106 | checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" 107 | dependencies = [ 108 | "bit-set", 109 | "regex-automata", 110 | "regex-syntax", 111 | ] 112 | 113 | [[package]] 114 | name = "fastrand" 115 | 
version = "2.0.2" 116 | source = "registry+https://github.com/rust-lang/crates.io-index" 117 | checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" 118 | 119 | [[package]] 120 | name = "fnv" 121 | version = "1.0.7" 122 | source = "registry+https://github.com/rust-lang/crates.io-index" 123 | checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" 124 | 125 | [[package]] 126 | name = "getrandom" 127 | version = "0.2.13" 128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | checksum = "a06fddc2749e0528d2813f95e050e87e52c8cbbae56223b9babf73b3e53b0cc6" 130 | dependencies = [ 131 | "cfg-if", 132 | "libc", 133 | "wasi", 134 | ] 135 | 136 | [[package]] 137 | name = "hashbrown" 138 | version = "0.14.3" 139 | source = "registry+https://github.com/rust-lang/crates.io-index" 140 | checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" 141 | 142 | [[package]] 143 | name = "indexmap" 144 | version = "2.2.6" 145 | source = "registry+https://github.com/rust-lang/crates.io-index" 146 | checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" 147 | dependencies = [ 148 | "equivalent", 149 | "hashbrown", 150 | ] 151 | 152 | [[package]] 153 | name = "lazy_static" 154 | version = "1.4.0" 155 | source = "registry+https://github.com/rust-lang/crates.io-index" 156 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 157 | 158 | [[package]] 159 | name = "libc" 160 | version = "0.2.153" 161 | source = "registry+https://github.com/rust-lang/crates.io-index" 162 | checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" 163 | 164 | [[package]] 165 | name = "libm" 166 | version = "0.2.8" 167 | source = "registry+https://github.com/rust-lang/crates.io-index" 168 | checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" 169 | 170 | [[package]] 171 | name = "linux-raw-sys" 172 | version = "0.4.13" 173 | source = "registry+https://github.com/rust-lang/crates.io-index" 174 | checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" 175 | 176 | [[package]] 177 | name = "lock_api" 178 | version = "0.4.11" 179 | source = "registry+https://github.com/rust-lang/crates.io-index" 180 | checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" 181 | dependencies = [ 182 | "autocfg", 183 | "scopeguard", 184 | ] 185 | 186 | [[package]] 187 | name = "memchr" 188 | version = "2.7.1" 189 | source = "registry+https://github.com/rust-lang/crates.io-index" 190 | checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" 191 | 192 | [[package]] 193 | name = "minbpe" 194 | version = "0.1.0" 195 | dependencies = [ 196 | "base64", 197 | "fancy-regex 0.13.0", 198 | "indexmap", 199 | "lazy_static", 200 | "proptest", 201 | "regex", 202 | "tempfile", 203 | "tiktoken-rs", 204 | ] 205 | 206 | [[package]] 207 | name = "num-traits" 208 | version = "0.2.18" 209 | source = "registry+https://github.com/rust-lang/crates.io-index" 210 | checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" 211 | dependencies = [ 212 | "autocfg", 213 | "libm", 214 | ] 215 | 216 | [[package]] 217 | name = "parking_lot" 218 | version = "0.12.1" 219 | source = "registry+https://github.com/rust-lang/crates.io-index" 220 | checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" 221 | dependencies = [ 222 | "lock_api", 223 | "parking_lot_core", 224 | ] 225 | 226 | [[package]] 227 | name = 
"parking_lot_core" 228 | version = "0.9.9" 229 | source = "registry+https://github.com/rust-lang/crates.io-index" 230 | checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" 231 | dependencies = [ 232 | "cfg-if", 233 | "libc", 234 | "redox_syscall", 235 | "smallvec", 236 | "windows-targets 0.48.5", 237 | ] 238 | 239 | [[package]] 240 | name = "ppv-lite86" 241 | version = "0.2.17" 242 | source = "registry+https://github.com/rust-lang/crates.io-index" 243 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 244 | 245 | [[package]] 246 | name = "proc-macro2" 247 | version = "1.0.79" 248 | source = "registry+https://github.com/rust-lang/crates.io-index" 249 | checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" 250 | dependencies = [ 251 | "unicode-ident", 252 | ] 253 | 254 | [[package]] 255 | name = "proptest" 256 | version = "1.4.0" 257 | source = "registry+https://github.com/rust-lang/crates.io-index" 258 | checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" 259 | dependencies = [ 260 | "bit-set", 261 | "bit-vec", 262 | "bitflags 2.5.0", 263 | "lazy_static", 264 | "num-traits", 265 | "rand", 266 | "rand_chacha", 267 | "rand_xorshift", 268 | "regex-syntax", 269 | "rusty-fork", 270 | "tempfile", 271 | "unarray", 272 | ] 273 | 274 | [[package]] 275 | name = "quick-error" 276 | version = "1.2.3" 277 | source = "registry+https://github.com/rust-lang/crates.io-index" 278 | checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" 279 | 280 | [[package]] 281 | name = "quote" 282 | version = "1.0.35" 283 | source = "registry+https://github.com/rust-lang/crates.io-index" 284 | checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" 285 | dependencies = [ 286 | "proc-macro2", 287 | ] 288 | 289 | [[package]] 290 | name = "rand" 291 | version = "0.8.5" 292 | source = "registry+https://github.com/rust-lang/crates.io-index" 293 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 294 | dependencies = [ 295 | "libc", 296 | "rand_chacha", 297 | "rand_core", 298 | ] 299 | 300 | [[package]] 301 | name = "rand_chacha" 302 | version = "0.3.1" 303 | source = "registry+https://github.com/rust-lang/crates.io-index" 304 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 305 | dependencies = [ 306 | "ppv-lite86", 307 | "rand_core", 308 | ] 309 | 310 | [[package]] 311 | name = "rand_core" 312 | version = "0.6.4" 313 | source = "registry+https://github.com/rust-lang/crates.io-index" 314 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 315 | dependencies = [ 316 | "getrandom", 317 | ] 318 | 319 | [[package]] 320 | name = "rand_xorshift" 321 | version = "0.3.0" 322 | source = "registry+https://github.com/rust-lang/crates.io-index" 323 | checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" 324 | dependencies = [ 325 | "rand_core", 326 | ] 327 | 328 | [[package]] 329 | name = "redox_syscall" 330 | version = "0.4.1" 331 | source = "registry+https://github.com/rust-lang/crates.io-index" 332 | checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" 333 | dependencies = [ 334 | "bitflags 1.3.2", 335 | ] 336 | 337 | [[package]] 338 | name = "regex" 339 | version = "1.10.4" 340 | source = "registry+https://github.com/rust-lang/crates.io-index" 341 | checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" 342 | dependencies = [ 343 | 
"aho-corasick", 344 | "memchr", 345 | "regex-automata", 346 | "regex-syntax", 347 | ] 348 | 349 | [[package]] 350 | name = "regex-automata" 351 | version = "0.4.6" 352 | source = "registry+https://github.com/rust-lang/crates.io-index" 353 | checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" 354 | dependencies = [ 355 | "aho-corasick", 356 | "memchr", 357 | "regex-syntax", 358 | ] 359 | 360 | [[package]] 361 | name = "regex-syntax" 362 | version = "0.8.2" 363 | source = "registry+https://github.com/rust-lang/crates.io-index" 364 | checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" 365 | 366 | [[package]] 367 | name = "rustc-hash" 368 | version = "1.1.0" 369 | source = "registry+https://github.com/rust-lang/crates.io-index" 370 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 371 | 372 | [[package]] 373 | name = "rustix" 374 | version = "0.38.32" 375 | source = "registry+https://github.com/rust-lang/crates.io-index" 376 | checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" 377 | dependencies = [ 378 | "bitflags 2.5.0", 379 | "errno", 380 | "libc", 381 | "linux-raw-sys", 382 | "windows-sys", 383 | ] 384 | 385 | [[package]] 386 | name = "rusty-fork" 387 | version = "0.3.0" 388 | source = "registry+https://github.com/rust-lang/crates.io-index" 389 | checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" 390 | dependencies = [ 391 | "fnv", 392 | "quick-error", 393 | "tempfile", 394 | "wait-timeout", 395 | ] 396 | 397 | [[package]] 398 | name = "scopeguard" 399 | version = "1.2.0" 400 | source = "registry+https://github.com/rust-lang/crates.io-index" 401 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 402 | 403 | [[package]] 404 | name = "serde" 405 | version = "1.0.197" 406 | source = "registry+https://github.com/rust-lang/crates.io-index" 407 | checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" 408 | dependencies = [ 409 | "serde_derive", 410 | ] 411 | 412 | [[package]] 413 | name = "serde_derive" 414 | version = "1.0.197" 415 | source = "registry+https://github.com/rust-lang/crates.io-index" 416 | checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" 417 | dependencies = [ 418 | "proc-macro2", 419 | "quote", 420 | "syn", 421 | ] 422 | 423 | [[package]] 424 | name = "smallvec" 425 | version = "1.13.2" 426 | source = "registry+https://github.com/rust-lang/crates.io-index" 427 | checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" 428 | 429 | [[package]] 430 | name = "syn" 431 | version = "2.0.55" 432 | source = "registry+https://github.com/rust-lang/crates.io-index" 433 | checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" 434 | dependencies = [ 435 | "proc-macro2", 436 | "quote", 437 | "unicode-ident", 438 | ] 439 | 440 | [[package]] 441 | name = "tempfile" 442 | version = "3.10.1" 443 | source = "registry+https://github.com/rust-lang/crates.io-index" 444 | checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" 445 | dependencies = [ 446 | "cfg-if", 447 | "fastrand", 448 | "rustix", 449 | "windows-sys", 450 | ] 451 | 452 | [[package]] 453 | name = "tiktoken-rs" 454 | version = "0.5.8" 455 | source = "registry+https://github.com/rust-lang/crates.io-index" 456 | checksum = "40894b788eb28bbb7e36bdc8b7b1b1488b9c93fa3730f315ab965330c94c0842" 457 | dependencies = [ 458 | "anyhow", 459 | "base64", 460 | "bstr", 461 
| "fancy-regex 0.12.0", 462 | "lazy_static", 463 | "parking_lot", 464 | "rustc-hash", 465 | ] 466 | 467 | [[package]] 468 | name = "unarray" 469 | version = "0.1.4" 470 | source = "registry+https://github.com/rust-lang/crates.io-index" 471 | checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" 472 | 473 | [[package]] 474 | name = "unicode-ident" 475 | version = "1.0.12" 476 | source = "registry+https://github.com/rust-lang/crates.io-index" 477 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 478 | 479 | [[package]] 480 | name = "wait-timeout" 481 | version = "0.2.0" 482 | source = "registry+https://github.com/rust-lang/crates.io-index" 483 | checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" 484 | dependencies = [ 485 | "libc", 486 | ] 487 | 488 | [[package]] 489 | name = "wasi" 490 | version = "0.11.0+wasi-snapshot-preview1" 491 | source = "registry+https://github.com/rust-lang/crates.io-index" 492 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 493 | 494 | [[package]] 495 | name = "windows-sys" 496 | version = "0.52.0" 497 | source = "registry+https://github.com/rust-lang/crates.io-index" 498 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 499 | dependencies = [ 500 | "windows-targets 0.52.4", 501 | ] 502 | 503 | [[package]] 504 | name = "windows-targets" 505 | version = "0.48.5" 506 | source = "registry+https://github.com/rust-lang/crates.io-index" 507 | checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" 508 | dependencies = [ 509 | "windows_aarch64_gnullvm 0.48.5", 510 | "windows_aarch64_msvc 0.48.5", 511 | "windows_i686_gnu 0.48.5", 512 | "windows_i686_msvc 0.48.5", 513 | "windows_x86_64_gnu 0.48.5", 514 | "windows_x86_64_gnullvm 0.48.5", 515 | "windows_x86_64_msvc 0.48.5", 516 | ] 517 | 518 | [[package]] 519 | name = "windows-targets" 520 | version = "0.52.4" 521 | source = "registry+https://github.com/rust-lang/crates.io-index" 522 | checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" 523 | dependencies = [ 524 | "windows_aarch64_gnullvm 0.52.4", 525 | "windows_aarch64_msvc 0.52.4", 526 | "windows_i686_gnu 0.52.4", 527 | "windows_i686_msvc 0.52.4", 528 | "windows_x86_64_gnu 0.52.4", 529 | "windows_x86_64_gnullvm 0.52.4", 530 | "windows_x86_64_msvc 0.52.4", 531 | ] 532 | 533 | [[package]] 534 | name = "windows_aarch64_gnullvm" 535 | version = "0.48.5" 536 | source = "registry+https://github.com/rust-lang/crates.io-index" 537 | checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" 538 | 539 | [[package]] 540 | name = "windows_aarch64_gnullvm" 541 | version = "0.52.4" 542 | source = "registry+https://github.com/rust-lang/crates.io-index" 543 | checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" 544 | 545 | [[package]] 546 | name = "windows_aarch64_msvc" 547 | version = "0.48.5" 548 | source = "registry+https://github.com/rust-lang/crates.io-index" 549 | checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" 550 | 551 | [[package]] 552 | name = "windows_aarch64_msvc" 553 | version = "0.52.4" 554 | source = "registry+https://github.com/rust-lang/crates.io-index" 555 | checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" 556 | 557 | [[package]] 558 | name = "windows_i686_gnu" 559 | version = "0.48.5" 560 | source = "registry+https://github.com/rust-lang/crates.io-index" 561 | checksum = 
"a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" 562 | 563 | [[package]] 564 | name = "windows_i686_gnu" 565 | version = "0.52.4" 566 | source = "registry+https://github.com/rust-lang/crates.io-index" 567 | checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" 568 | 569 | [[package]] 570 | name = "windows_i686_msvc" 571 | version = "0.48.5" 572 | source = "registry+https://github.com/rust-lang/crates.io-index" 573 | checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" 574 | 575 | [[package]] 576 | name = "windows_i686_msvc" 577 | version = "0.52.4" 578 | source = "registry+https://github.com/rust-lang/crates.io-index" 579 | checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" 580 | 581 | [[package]] 582 | name = "windows_x86_64_gnu" 583 | version = "0.48.5" 584 | source = "registry+https://github.com/rust-lang/crates.io-index" 585 | checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" 586 | 587 | [[package]] 588 | name = "windows_x86_64_gnu" 589 | version = "0.52.4" 590 | source = "registry+https://github.com/rust-lang/crates.io-index" 591 | checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" 592 | 593 | [[package]] 594 | name = "windows_x86_64_gnullvm" 595 | version = "0.48.5" 596 | source = "registry+https://github.com/rust-lang/crates.io-index" 597 | checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" 598 | 599 | [[package]] 600 | name = "windows_x86_64_gnullvm" 601 | version = "0.52.4" 602 | source = "registry+https://github.com/rust-lang/crates.io-index" 603 | checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" 604 | 605 | [[package]] 606 | name = "windows_x86_64_msvc" 607 | version = "0.48.5" 608 | source = "registry+https://github.com/rust-lang/crates.io-index" 609 | checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" 610 | 611 | [[package]] 612 | name = "windows_x86_64_msvc" 613 | version = "0.52.4" 614 | source = "registry+https://github.com/rust-lang/crates.io-index" 615 | checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" 616 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "minbpe" 3 | version = "0.1.0" 4 | description = "Port of Andrej Karpathy's minbpe to Rust" 5 | authors = ["Gregor Purdy "] 6 | license = "MIT OR Apache-2.0" 7 | edition = "2021" 8 | keywords = ["language-model", "codec", "gpt", "ai"] 9 | readme = "README.md" 10 | repository = "https://github.com/gnp/minbpe-rs.git" 11 | include = ["assets/**/*", "examples/**/*", "src/**/*", "tests/**/*"] 12 | 13 | [features] 14 | default = ["basic", "regex"] 15 | basic = [] 16 | regex = [] 17 | gpt4 = ["regex"] 18 | tiktoken_tests = ["gpt4", "tiktoken-rs"] 19 | 20 | [lib] 21 | path = "src/lib.rs" 22 | 23 | [dependencies] 24 | regex = "1.10" 25 | fancy-regex = "0.13" 26 | indexmap = "2.2" 27 | lazy_static = "1.4.0" 28 | base64 = "0.21.5" 29 | tiktoken-rs = { version = "0.5.8", optional = true } 30 | 31 | [dev-dependencies] 32 | tempfile = "3.10" 33 | proptest = "1.4.0" 34 | 35 | [profile.release] 36 | debug = true 37 | 38 | [[example]] 39 | name = "gpt4_encode" 40 | required-features = ["gpt4"] 41 | 42 | [[example]] 43 | name = "test_gpt4" 44 | required-features = ["gpt4"] 45 | 
-------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `minbpe-rs` 2 | 3 | > Port of Andrej Karpathy's [minbpe](https://github.com/karpathy/minbpe) to Rust. 4 | 5 | [![minbpe-rs crate](https://img.shields.io/crates/v/minbpe.svg)](https://crates.io/crates/minbpe) 6 | [![minbpe-rs documentation](https://docs.rs/minbpe/badge.svg)](https://docs.rs/minbpe) 7 | 8 | 9 | ## Quick Start 10 | 11 | Create a Rust application crate with `cargo`, 12 | 13 | ``` 14 | $> cargo new minbpe-test 15 | ``` 16 | 17 | In the resulting project, add `minbpe` to `Cargo.toml`, 18 | 19 | ```toml 20 | [dependencies] 21 | minbpe = "0.1.0" 22 | ``` 23 | 24 | Refer [`crates.io`](https://crates.io/crates/minbpe) for selecting the latest version. 
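The crate also exposes optional Cargo features (see the `[features]` table in this repository's `Cargo.toml`): `basic` and `regex` are enabled by default, while `gpt4` (and `tiktoken_tests`, used only for the tiktoken compatibility tests) are opt-in. As a minimal sketch, assuming you also want the GPT-4 tokenizer, which this repository gates behind the `gpt4` feature, you would enable that feature on the dependency:

```toml
[dependencies]
# "0.1.0" is the version shown above; adjust to the latest release on crates.io
minbpe = { version = "0.1.0", features = ["gpt4"] }
```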
Next in `src/main.rs`, 25 | 26 | ```rust 27 | use std::path::Path; 28 | use minbpe::{BasicTokenizer, Saveable, Tokenizer, Trainable}; 29 | 30 | fn main() { 31 | let text = "aaabdaaabac" ; 32 | let mut tokenizer = BasicTokenizer::new() ; 33 | tokenizer.train( text , 256 + 3 , false ) ; 34 | println!( "{:?}" , tokenizer.encode(text) ) ; 35 | println!( "{:?}" , tokenizer.decode( &[258, 100, 258, 97, 99] ) ) ; 36 | tokenizer.save( Path::new( "./" ) , "toy" ) ; 37 | } 38 | ``` 39 | 40 | Execute the binary with `cargo run`, 41 | 42 | ``` 43 | $> cargo run 44 | 45 | ... 46 | Compiling minbpe-test v0.1.0 (~/minbpe-test) 47 | Finished dev [unoptimized + debuginfo] target(s) in 15.71s 48 | Running `target/debug/minbpe-test` 49 | [258, 100, 258, 97, 99] 50 | "aaabdaaabac" 51 | 52 | ``` 53 | 54 | 55 | ## License 56 | 57 | Licensed under either of 58 | 59 | * Apache License, Version 2.0 60 | ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 61 | * MIT license 62 | ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 63 | 64 | at your option. 65 | 66 | 67 | ## Contribution 68 | 69 | Unless you explicitly state otherwise, any contribution intentionally submitted 70 | for inclusion in the work by you, as defined in the Apache-2.0 license, shall be 71 | dual licensed as above, without any additional terms or conditions. 72 | -------------------------------------------------------------------------------- /deny.toml: -------------------------------------------------------------------------------- 1 | # This template contains all of the possible sections and their default values 2 | 3 | # Note that all fields that take a lint level have these possible values: 4 | # * deny - An error will be produced and the check will fail 5 | # * warn - A warning will be produced, but the check will not fail 6 | # * allow - No warning or error will be produced, though in some cases a note 7 | # will be 8 | 9 | # The values provided in this template are the default values that will be used 10 | # when any section or field is not specified in your own configuration 11 | 12 | # Root options 13 | 14 | # The graph table configures how the dependency graph is constructed and thus 15 | # which crates the checks are performed against 16 | [graph] 17 | # If 1 or more target triples (and optionally, target_features) are specified, 18 | # only the specified targets will be checked when running `cargo deny check`. 19 | # This means, if a particular package is only ever used as a target specific 20 | # dependency, such as, for example, the `nix` crate only being used via the 21 | # `target_family = "unix"` configuration, that only having windows targets in 22 | # this list would mean the nix crate, as well as any of its exclusive 23 | # dependencies not shared by any other crates, would be ignored, as the target 24 | # list here is effectively saying which targets you are building for. 25 | targets = [ 26 | # The triple can be any string, but only the target triples built in to 27 | # rustc (as of 1.40) can be checked against actual config expressions 28 | #"x86_64-unknown-linux-musl", 29 | # You can also specify which target_features you promise are enabled for a 30 | # particular target. target_features are currently not validated against 31 | # the actual valid features supported by the target architecture. 
32 | #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, 33 | ] 34 | # When creating the dependency graph used as the source of truth when checks are 35 | # executed, this field can be used to prune crates from the graph, removing them 36 | # from the view of cargo-deny. This is an extremely heavy hammer, as if a crate 37 | # is pruned from the graph, all of its dependencies will also be pruned unless 38 | # they are connected to another crate in the graph that hasn't been pruned, 39 | # so it should be used with care. The identifiers are [Package ID Specifications] 40 | # (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html) 41 | #exclude = [] 42 | # If true, metadata will be collected with `--all-features`. Note that this can't 43 | # be toggled off if true, if you want to conditionally enable `--all-features` it 44 | # is recommended to pass `--all-features` on the cmd line instead 45 | all-features = false 46 | # If true, metadata will be collected with `--no-default-features`. The same 47 | # caveat with `all-features` applies 48 | no-default-features = false 49 | # If set, these feature will be enabled when collecting metadata. If `--features` 50 | # is specified on the cmd line they will take precedence over this option. 51 | #features = [] 52 | 53 | # The output table provides options for how/if diagnostics are outputted 54 | [output] 55 | # When outputting inclusion graphs in diagnostics that include features, this 56 | # option can be used to specify the depth at which feature edges will be added. 57 | # This option is included since the graphs can be quite large and the addition 58 | # of features from the crate(s) to all of the graph roots can be far too verbose. 59 | # This option can be overridden via `--feature-depth` on the cmd line 60 | feature-depth = 1 61 | 62 | # This section is considered when running `cargo deny check advisories` 63 | # More documentation for the advisories section can be found here: 64 | # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html 65 | [advisories] 66 | # The path where the advisory databases are cloned/fetched into 67 | #db-path = "$CARGO_HOME/advisory-dbs" 68 | # The url(s) of the advisory databases to use 69 | #db-urls = ["https://github.com/rustsec/advisory-db"] 70 | # A list of advisory IDs to ignore. Note that ignored advisories will still 71 | # output a note when they are encountered. 72 | ignore = [ 73 | #"RUSTSEC-0000-0000", 74 | #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, 75 | #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish 76 | #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, 77 | ] 78 | # If this is true, then cargo deny will use the git executable to fetch advisory database. 79 | # If this is false, then it uses a built-in git library. 80 | # Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support. 81 | # See Git Authentication for more information about setting up git authentication. 
82 | #git-fetch-with-cli = true 83 | 84 | # This section is considered when running `cargo deny check licenses` 85 | # More documentation for the licenses section can be found here: 86 | # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html 87 | [licenses] 88 | # List of explicitly allowed licenses 89 | # See https://spdx.org/licenses/ for list of possible licenses 90 | # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. 91 | allow = [ 92 | "MIT", 93 | "Apache-2.0", 94 | #"Apache-2.0 WITH LLVM-exception", 95 | ] 96 | # The confidence threshold for detecting a license from license text. 97 | # The higher the value, the more closely the license text must be to the 98 | # canonical license text of a valid SPDX license file. 99 | # [possible values: any between 0.0 and 1.0]. 100 | confidence-threshold = 0.8 101 | # Allow 1 or more licenses on a per-crate basis, so that particular licenses 102 | # aren't accepted for every possible crate as with the normal allow list 103 | exceptions = [ 104 | # Each entry is the crate and version constraint, and its specific allow 105 | # list 106 | #{ allow = ["Zlib"], crate = "adler32" }, 107 | ] 108 | 109 | # Some crates don't have (easily) machine readable licensing information, 110 | # adding a clarification entry for it allows you to manually specify the 111 | # licensing information 112 | #[[licenses.clarify]] 113 | # The package spec the clarification applies to 114 | #crate = "ring" 115 | # The SPDX expression for the license requirements of the crate 116 | #expression = "MIT AND ISC AND OpenSSL" 117 | # One or more files in the crate's source used as the "source of truth" for 118 | # the license expression. If the contents match, the clarification will be used 119 | # when running the license check, otherwise the clarification will be ignored 120 | # and the crate will be checked normally, which may produce warnings or errors 121 | # depending on the rest of your configuration 122 | #license-files = [ 123 | # Each entry is a crate relative path, and the (opaque) hash of its contents 124 | #{ path = "LICENSE", hash = 0xbd0eed23 } 125 | #] 126 | 127 | [licenses.private] 128 | # If true, ignores workspace crates that aren't published, or are only 129 | # published to private registries. 130 | # To see how to mark a crate as unpublished (to the official registry), 131 | # visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field. 132 | ignore = false 133 | # One or more private registries that you might publish crates to, if a crate 134 | # is only published to private registries, and ignore is true, the crate will 135 | # not have its license(s) checked 136 | registries = [ 137 | #"https://sekretz.com/registry 138 | ] 139 | 140 | # This section is considered when running `cargo deny check bans`. 
141 | # More documentation about the 'bans' section can be found here: 142 | # https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html 143 | [bans] 144 | # Lint level for when multiple versions of the same crate are detected 145 | multiple-versions = "warn" 146 | # Lint level for when a crate version requirement is `*` 147 | wildcards = "allow" 148 | # The graph highlighting used when creating dotgraphs for crates 149 | # with multiple versions 150 | # * lowest-version - The path to the lowest versioned duplicate is highlighted 151 | # * simplest-path - The path to the version with the fewest edges is highlighted 152 | # * all - Both lowest-version and simplest-path are used 153 | highlight = "all" 154 | # The default lint level for `default` features for crates that are members of 155 | # the workspace that is being checked. This can be overridden by allowing/denying 156 | # `default` on a crate-by-crate basis if desired. 157 | workspace-default-features = "allow" 158 | # The default lint level for `default` features for external crates that are not 159 | # members of the workspace. This can be overridden by allowing/denying `default` 160 | # on a crate-by-crate basis if desired. 161 | external-default-features = "allow" 162 | # List of crates that are allowed. Use with care! 163 | allow = [ 164 | #"ansi_term@0.11.0", 165 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, 166 | ] 167 | # List of crates to deny 168 | deny = [ 169 | #"ansi_term@0.11.0", 170 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" }, 171 | # Wrapper crates can optionally be specified to allow the crate when it 172 | # is a direct dependency of the otherwise banned crate 173 | #{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] }, 174 | ] 175 | 176 | # List of features to allow/deny 177 | # Each entry the name of a crate and a version range. If version is 178 | # not specified, all versions will be matched. 179 | #[[bans.features]] 180 | #crate = "reqwest" 181 | # Features to not allow 182 | #deny = ["json"] 183 | # Features to allow 184 | #allow = [ 185 | # "rustls", 186 | # "__rustls", 187 | # "__tls", 188 | # "hyper-rustls", 189 | # "rustls", 190 | # "rustls-pemfile", 191 | # "rustls-tls-webpki-roots", 192 | # "tokio-rustls", 193 | # "webpki-roots", 194 | #] 195 | # If true, the allowed features must exactly match the enabled feature set. If 196 | # this is set there is no point setting `deny` 197 | #exact = true 198 | 199 | # Certain crates/versions that will be skipped when doing duplicate detection. 200 | skip = [ 201 | #"ansi_term@0.11.0", 202 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" }, 203 | ] 204 | # Similarly to `skip` allows you to skip certain crates during duplicate 205 | # detection. Unlike skip, it also includes the entire tree of transitive 206 | # dependencies starting at the specified crate, up to a certain depth, which is 207 | # by default infinite. 208 | skip-tree = [ 209 | #"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies 210 | #{ crate = "ansi_term@0.11.0", depth = 20 }, 211 | ] 212 | 213 | # This section is considered when running `cargo deny check sources`. 
214 | # More documentation about the 'sources' section can be found here: 215 | # https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html 216 | [sources] 217 | # Lint level for what to happen when a crate from a crate registry that is not 218 | # in the allow list is encountered 219 | unknown-registry = "warn" 220 | # Lint level for what to happen when a crate from a git repository that is not 221 | # in the allow list is encountered 222 | unknown-git = "warn" 223 | # List of URLs for allowed crate registries. Defaults to the crates.io index 224 | # if not specified. If it is specified but empty, no registries are allowed. 225 | allow-registry = ["https://github.com/rust-lang/crates.io-index"] 226 | # List of URLs for allowed Git repositories 227 | allow-git = [] 228 | 229 | [sources.allow-org] 230 | # 1 or more github.com organizations to allow git sources for 231 | #github = [""] 232 | # 1 or more gitlab.com organizations to allow git sources for 233 | #gitlab = [""] 234 | # 1 or more bitbucket.org organizations to allow git sources for 235 | #bitbucket = [""] 236 | -------------------------------------------------------------------------------- /examples/basic_example.model: -------------------------------------------------------------------------------- 1 | minbpe v1 2 | 3 | 0 4 | 101 32 -------------------------------------------------------------------------------- /examples/gpt4_encode.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | use std::path::PathBuf; 3 | 4 | use minbpe::GPT4Tokenizer; 5 | use minbpe::RegexTokenizerTrait; 6 | 7 | fn main() -> std::io::Result<()> { 8 | let file_path = PathBuf::from("tests/taylorswift.txt"); 9 | 10 | // Pre-initialize the tokenizer 11 | println!("Pre-initializing the tokenizer..."); 12 | let start = std::time::Instant::now(); 13 | GPT4Tokenizer::initialize(); 14 | let duration = start.elapsed(); 15 | println!( 16 | "GPT4Tokenizer static initialization completed in: {:?}", 17 | duration 18 | ); 19 | 20 | // Get default instance of the tokenizer 21 | println!("Getting a default instance of GPT4Tokenizer..."); 22 | let start = std::time::Instant::now(); 23 | let tokenizer = GPT4Tokenizer::default(); 24 | let duration = start.elapsed(); 25 | println!( 26 | "GPT4Tokenizer default instance construction completed in: {:?}", 27 | duration 28 | ); 29 | 30 | // Read the input file 31 | println!("Reading file: {:?}...", file_path); 32 | let start = std::time::Instant::now(); 33 | let text = fs::read_to_string(file_path)?; 34 | let duration = start.elapsed(); 35 | println!( 36 | "Reading {} characters completed in: {:?}", 37 | text.len(), 38 | duration 39 | ); 40 | 41 | // Timing the encoding process, optional. 
42 | let start = std::time::Instant::now(); 43 | let tokens = tokenizer.encode(&text); 44 | let duration = start.elapsed(); 45 | 46 | println!("Encoding completed in: {:?}", duration); 47 | println!("Produced {} encoded tokens", tokens.len()); 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /examples/test_gpt4.rs: -------------------------------------------------------------------------------- 1 | use minbpe::GPT4Tokenizer; 2 | use minbpe::RegexTokenizerTrait; 3 | 4 | fn main() { 5 | let text = "\u{1e01b}%SΣ"; 6 | 7 | // Pre-initialize the tokenizer 8 | let start = std::time::Instant::now(); 9 | GPT4Tokenizer::initialize(); 10 | let duration = start.elapsed(); 11 | println!( 12 | "GPT4Tokenizer static initialization completed in: {:?}", 13 | duration 14 | ); 15 | 16 | // Initialize the tokenizer 17 | let start = std::time::Instant::now(); 18 | let tokenizer = GPT4Tokenizer::default(); 19 | let duration = start.elapsed(); 20 | println!( 21 | "GPT4Tokenizer default instance construction completed in: {:?}", 22 | duration 23 | ); 24 | 25 | // Encode the string 26 | let start = std::time::Instant::now(); 27 | let tokens = tokenizer.encode(text); 28 | let duration = start.elapsed(); 29 | println!( 30 | "GPT4Tokenizer encoding of {} character string completed in: {:?}", 31 | text.len(), 32 | duration 33 | ); 34 | 35 | // Print the resulting tokens 36 | println!("{:?}", tokens); 37 | } 38 | -------------------------------------------------------------------------------- /examples/train.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | use std::path::Path; 3 | use std::time::Instant; 4 | 5 | use minbpe::BasicTokenizer; 6 | use minbpe::RegexTokenizerStruct; 7 | use minbpe::Saveable; 8 | use minbpe::Tokenizer; 9 | use minbpe::Trainable; 10 | 11 | fn main() { 12 | let text = fs::read_to_string("tests/taylorswift.txt").expect("Unable to read file"); 13 | 14 | fs::create_dir_all("models").expect("Unable to create models directory"); 15 | 16 | let basic = BasicTokenizer::new(); 17 | let regex = RegexTokenizerStruct::default(); 18 | 19 | fn doit<T: Trainable + Saveable>(tokenizer: T, name: &str, text: &str) { 20 | let mut tokenizer = tokenizer; 21 | tokenizer.train(text, 512, true); 22 | 23 | let dir = Path::new("models").to_path_buf(); 24 | tokenizer.save(&dir, name); 25 | } 26 | 27 | let t0 = Instant::now(); 28 | doit(basic, "basic", &text); 29 | doit(regex, "regex", &text); 30 | let t1 = Instant::now(); 31 | 32 | let duration = t1.duration_since(t0); 33 | println!("Training took {:.2} seconds", duration.as_secs_f64()); 34 | } 35 | -------------------------------------------------------------------------------- /src/base.rs: -------------------------------------------------------------------------------- 1 | //! Contains the base Tokenizer struct and a few common helper functions. 2 | //! The base struct also contains the (common) save/load functionality. 3 | //! It would be possible to be a lot more strict about the interface and 4 | //! e.g. isolating all regex/pattern parts to the RegexTokenizer, but 5 | //! some concessions are made for simplicity. 6 | 7 | use std::io::Write; 8 | use std::path::Path; 9 | use std::{ 10 | fs::File, 11 | io::{BufRead, BufReader}, 12 | }; 13 | 14 | use indexmap::IndexMap; 15 | 16 | /// Token type to support up to 2^31 distinct tokens. It is signed in case a Tokenizer 17 | /// needs to use negative values for special tokens.
18 | pub type Token = i32; 19 | 20 | /// Count type to support up to 2^64 occurrences of any token pair. 21 | pub type Count = u64; 22 | 23 | /// Base trait for Tokenizers to implement. 24 | pub trait Tokenizer { 25 | fn special_tokens(&self) -> &IndexMap<String, Token>; 26 | 27 | fn merges(&self) -> &IndexMap<(Token, Token), Token>; 28 | 29 | fn vocab(&self) -> &IndexMap<Token, Vec<u8>>; 30 | 31 | /// A Tokenizer can encode a string into a list of integers. 32 | fn encode(&self, text: &str) -> Vec<Token>; 33 | 34 | /// A Tokenizer can decode a list of integers into a string. 35 | fn decode(&self, ids: &[Token]) -> String; 36 | } 37 | 38 | /// A Tokenizer that can be trained. 39 | pub trait Trainable: Tokenizer { 40 | /// Train a vocabulary of size `vocab_size` in distinct Tokens from `text`. 41 | fn train(&mut self, text: &str, vocab_size: Token, verbose: bool); 42 | } 43 | 44 | pub trait Saveable: Tokenizer { 45 | fn pattern(&self) -> &str; 46 | 47 | /// Saves the tokenizer's model and vocabulary to two files: 48 | /// - `file_prefix.model`: The model file used for loading the tokenizer. 49 | /// - `file_prefix.vocab`: A human-readable version of the vocabulary for inspection. 50 | /// 51 | /// This is inspired by (but not equivalent to) SentencePiece's model saving. 52 | /// 53 | /// # Arguments 54 | /// 55 | /// * `dir` - The path to the output directory. 56 | /// * `prefix` - The prefix for the output file name. 57 | /// 58 | /// # Examples 59 | /// 60 | /// ``` 61 | /// # use tempfile::tempdir; 62 | /// use minbpe::Saveable; 63 | /// use minbpe::Tokenizer; 64 | /// use minbpe::BasicTokenizer; 65 | /// let tokenizer = BasicTokenizer::new(); 66 | /// let dir = tempdir().unwrap(); 67 | /// let path = dir.path(); 68 | /// tokenizer.save(&path, "prefix"); 69 | /// ``` 70 | fn save(&self, dir: &Path, prefix: &str) { 71 | // let dir = dir.as_ref(); 72 | 73 | // Write the model file (used for loading the tokenizer later) 74 | let model_file_path = dir.join(format!("{}.model", prefix)); 75 | let mut model_file = File::create(model_file_path).expect("Unable to create model file"); 76 | 77 | // Write the version, pattern, and merges 78 | writeln!(model_file, "minbpe v1").expect("Unable to write to model file"); 79 | writeln!(model_file, "{}", self.pattern()).expect("Unable to write to model file"); 80 | 81 | // Write the special tokens (first the number, then each token and its index) 82 | writeln!(model_file, "{}", self.special_tokens().len()) 83 | .expect("Unable to write to model file"); 84 | for (special, idx) in self.special_tokens() { 85 | writeln!(model_file, "{} {}", special, idx).expect("Unable to write to model file"); 86 | } 87 | 88 | let mut merges: Vec<(&(Token, Token), &Token)> = self.merges().iter().collect(); 89 | merges.sort_by_key(|&k| k.1); 90 | 91 | // Write the merges dictionary 92 | for (token_pair, _new_token) in merges { 93 | writeln!(model_file, "{} {}", token_pair.0, token_pair.1) 94 | .expect("Unable to write to model file"); 95 | } 96 | 97 | // Write the vocabulary file (for human inspection) 98 | let vocab_file_path = dir.join(format!("{}.vocab", prefix)); 99 | let mut vocab_file = File::create(vocab_file_path).expect("Unable to create vocab file"); 100 | 101 | // Invert the merges dictionary for easier lookup 102 | let inverted_merges: IndexMap<Token, (Token, Token)> = self 103 | .merges() 104 | .iter() 105 | .map(|((idx1, idx2), idx)| (*idx, (*idx1, *idx2))) 106 | .collect(); 107 | 108 | let vocab = self.vocab(); 109 | 110 | for (idx, token) in vocab { 111 | // Render the token, replacing invalid UTF-8 sequences with the
replacement character 112 | let s = render_token(token); 113 | 114 | if let Some((idx0, idx1)) = inverted_merges.get(idx) { 115 | // If the token has children, render it as a merge 116 | let s0 = render_token(&vocab[idx0]); 117 | let s1 = render_token(&vocab[idx1]); 118 | writeln!(vocab_file, "[{}][{}] -> [{}] {}", s0, s1, s, idx) 119 | .expect("Unable to write to vocab file"); 120 | } else { 121 | // Otherwise, it's a leaf token (one of the first 256 bytes) 122 | writeln!(vocab_file, "[{}] {}", s, idx).expect("Unable to write to vocab file"); 123 | } 124 | } 125 | } 126 | } 127 | 128 | pub trait Loadable: Tokenizer { 129 | fn set_pattern(&mut self, pattern: &str); 130 | fn set_special_tokens(&mut self, special_tokens: IndexMap); 131 | fn set_merges(&mut self, merges: IndexMap<(Token, Token), Token>); 132 | fn set_vocab(&mut self, vocab: IndexMap>); 133 | 134 | /// Loads the tokenizer's model from a file. 135 | /// 136 | /// This is the inverse of `save` but only for the model file. 137 | /// 138 | /// # Arguments 139 | /// 140 | /// * `model_file` - The path to the model file. 141 | /// 142 | /// # Panics 143 | /// 144 | /// Panics if the model file does not have a ".model" extension or if the file format is invalid. 145 | /// 146 | /// # Examples 147 | /// 148 | /// ``` 149 | /// use std::path::PathBuf; 150 | /// use minbpe::Loadable; 151 | /// use minbpe::Tokenizer; 152 | /// use minbpe::BasicTokenizer; 153 | /// let mut tokenizer = BasicTokenizer::new(); 154 | /// let model_path = PathBuf::from("examples/basic_example.model"); 155 | /// tokenizer.load(&model_path); 156 | /// ``` 157 | fn load(&mut self, model_file: &Path) { 158 | // FIXME: Return a Result instead of panicking 159 | // let model_file = model_file.as_ref(); 160 | assert!( 161 | model_file.extension().map_or(false, |ext| ext == "model"), 162 | "Model file must have a .model extension" 163 | ); 164 | 165 | let mut merges: IndexMap<(Token, Token), Token> = IndexMap::new(); 166 | let mut special_tokens: IndexMap = IndexMap::new(); 167 | let mut idx: Token = 256; 168 | 169 | let file = File::open(model_file).expect("Unable to open model file"); 170 | let reader = BufReader::new(file); 171 | 172 | let lines: Vec = reader 173 | .lines() 174 | .map(|line| line.expect("Unable to read line from model file")) 175 | .collect(); 176 | 177 | let mut line_iter = lines.iter(); 178 | 179 | if let Some(version) = line_iter.next() { 180 | assert_eq!(version, "minbpe v1", "Invalid model file version"); 181 | } else { 182 | panic!("Missing version line in model file"); 183 | } 184 | 185 | // FIXME: Check whether Tokenizer supports a Pattern at all. 186 | 187 | if let Some(pattern) = line_iter.next() { 188 | self.set_pattern(pattern); 189 | } else { 190 | panic!("Missing pattern line in model file"); 191 | } 192 | 193 | if let Some(num_special_str) = line_iter.next() { 194 | let num_special = num_special_str 195 | .parse::() 196 | .expect("Invalid number of special tokens"); 197 | 198 | // FIXME: Check whether Tokenizer supports Special Tokens at all. 199 | // FIXME: Ensure it is >= 0 because Token type is signed. 200 | // FIXME: Enforce some reasonable maximum less than 2^31. 
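// A possible shape for the validation flagged in the FIXMEs above (sketch only;
// `MAX_SPECIAL_TOKENS` is a hypothetical bound, not a constant this crate defines):
//   assert!(num_special >= 0, "Number of special tokens must be non-negative");
//   assert!(num_special <= MAX_SPECIAL_TOKENS, "Number of special tokens is unreasonably large");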
201 | 202 | for _ in 0..num_special { 203 | if let Some(special_line) = line_iter.next() { 204 | let mut parts = special_line.split_whitespace(); 205 | let special = parts.next().expect("Missing special token").to_string(); 206 | let special_idx = parts 207 | .next() 208 | .expect("Missing special token index") 209 | .parse::() 210 | .expect("Invalid special token index"); 211 | special_tokens.insert(special, special_idx); 212 | } else { 213 | panic!("Missing special token line in model file"); 214 | } 215 | } 216 | } else { 217 | panic!("Missing number of special tokens line in model file"); 218 | } 219 | 220 | for merge_line in line_iter { 221 | let mut parts = merge_line.split_whitespace(); 222 | let idx1 = parts 223 | .next() 224 | .expect("Missing first index") 225 | .parse::() 226 | .expect("Invalid first index"); 227 | let idx2 = parts 228 | .next() 229 | .expect("Missing second index") 230 | .parse::() 231 | .expect("Invalid second index"); 232 | merges.insert((idx1, idx2), idx); 233 | idx += 1; 234 | } 235 | 236 | let vocab = build_vocab(&special_tokens, &merges); 237 | 238 | self.set_special_tokens(special_tokens); 239 | self.set_merges(merges); 240 | self.set_vocab(vocab); 241 | } 242 | } 243 | 244 | /// Additional operations for Tokenizers. 245 | /// Given a slice of integers, returns a new `IndexMap` containing the counts of consecutive pairs. 246 | /// 247 | /// Example: 248 | /// ``` 249 | /// # use indexmap::IndexMap; 250 | /// # use minbpe::get_stats; 251 | /// let ids = vec![1, 2, 3, 1, 2]; 252 | /// let counts = get_stats(&ids); 253 | /// assert_eq!(counts, IndexMap::from([((1, 2), 2), ((2, 3), 1), ((3, 1), 1)])); 254 | /// ``` 255 | pub fn get_stats(ids: &[Token]) -> IndexMap<(Token, Token), Count> { 256 | let mut counts = IndexMap::new(); 257 | update_stats(ids, &mut counts); 258 | counts 259 | } 260 | 261 | /// Updates an existing `IndexMap` with the counts of consecutive pairs from the given slice of integers. 262 | /// 263 | /// Example: 264 | /// ``` 265 | /// # use indexmap::IndexMap; 266 | /// # use minbpe::update_stats; 267 | /// let ids = vec![1, 2, 3, 1, 2]; 268 | /// let mut existing_counts = IndexMap::from([((1, 2), 1), ((2, 3), 1)]); 269 | /// update_stats(&ids, &mut existing_counts); 270 | /// assert_eq!(existing_counts, IndexMap::from([((1, 2), 3), ((2, 3), 2), ((3, 1), 1)])); 271 | /// ``` 272 | pub fn update_stats(ids: &[Token], counts: &mut IndexMap<(Token, Token), Count>) { 273 | for pair in ids.windows(2) { 274 | let pair = (pair[0], pair[1]); 275 | *counts.entry(pair).or_insert(0) += 1; 276 | } 277 | } 278 | 279 | /// Given an `IndexMap` of consecutive pair counts, returns the pair with the highest count. This 280 | /// technique preserves the insertion order of the pairs that IndexMap maintains, returning the 281 | /// first-inserted pair with the highest count. 282 | pub fn get_max_entry(stats: &IndexMap<(Token, Token), Count>) -> Option<(&(Token, Token), &Count)> { 283 | let mut max_entry = None; 284 | 285 | for entry in stats.iter() { 286 | match max_entry { 287 | None => max_entry = Some(entry), 288 | Some((_, max_count)) => { 289 | let (_, count) = entry; 290 | if count > max_count { 291 | max_entry = Some(entry); 292 | } 293 | } 294 | } 295 | } 296 | 297 | max_entry 298 | } 299 | 300 | /// Merges consecutive occurrences of a pair of integers in the given slice, 301 | /// replacing them with a new integer. 302 | /// 303 | /// Arguments: 304 | /// - `ids`: The slice of Tokens to merge. 
305 | /// - `pair`: The pair of consecutive integers to replace. 306 | /// - `new_id`: The new integer to replace the consecutive pairs with. 307 | /// 308 | /// Returns: 309 | /// A new `Vec` with the merged Tokens. 310 | /// 311 | /// Example: 312 | /// ``` 313 | /// # use minbpe::merge; 314 | /// let ids = vec![1, 2, 3, 1, 2]; 315 | /// let pair = (1, 2); 316 | /// let new_id = 4; 317 | /// let merged = merge(&ids, pair, new_id); 318 | /// assert_eq!(merged, vec![4, 3, 4]); 319 | /// ``` 320 | pub fn merge(ids: &[Token], pair: (Token, Token), new_id: Token) -> Vec { 321 | let mut new_ids = Vec::with_capacity(ids.len()); 322 | let mut i = 0; 323 | 324 | while i < ids.len() { 325 | if i < ids.len() - 1 && ids[i] == pair.0 && ids[i + 1] == pair.1 { 326 | new_ids.push(new_id); 327 | i += 2; 328 | } else { 329 | new_ids.push(ids[i]); 330 | i += 1; 331 | } 332 | } 333 | 334 | new_ids 335 | } 336 | 337 | /// vocab is simply and deterministically derived from merges 338 | pub fn build_vocab( 339 | special_tokens: &IndexMap, 340 | merges: &IndexMap<(Token, Token), Token>, 341 | ) -> IndexMap> { 342 | let mut vocab: IndexMap> = (0..256).map(|idx| (idx, vec![idx as u8])).collect(); 343 | 344 | for ((p0, p1), idx) in merges { 345 | let mut token = vocab[p0].clone(); 346 | token.extend_from_slice(&vocab[p1]); 347 | vocab.insert(*idx, token); 348 | } 349 | 350 | for (special, idx) in special_tokens { 351 | vocab.insert(*idx, special.as_bytes().to_vec()); 352 | } 353 | 354 | vocab 355 | } 356 | 357 | /// Replaces control characters in the given string with their Unicode escape sequences. 358 | /// 359 | /// Control characters are characters that distort the output, such as newline ('\n') or 360 | /// other characters that fall under the Unicode category "C" (Other). 361 | /// 362 | /// References: 363 | /// - https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 364 | /// - http://www.unicode.org/reports/tr44/#GC_Values_Table 365 | /// 366 | /// Arguments: 367 | /// - `s`: The string to process. 368 | /// 369 | /// Returns: 370 | /// A new `String` with control characters replaced by their Unicode escape sequences. 371 | /// 372 | /// Example: 373 | /// ```ignore 374 | /// # use minbpe::tokenizer::replace_control_characters; 375 | /// let s = "Hello\nWorld\u{7}!"; 376 | /// let result = replace_control_characters(s); 377 | /// assert_eq!(result, "Hello\\u000aWorld\\u0007!"); 378 | /// ``` 379 | fn replace_control_characters(s: &str) -> String { 380 | let mut chars = String::with_capacity(s.len()); 381 | 382 | for ch in s.chars() { 383 | if ch.is_control() { 384 | let escaped = format!("\\u{:04x}", ch as u32); 385 | chars.push_str(&escaped); 386 | } else { 387 | chars.push(ch); 388 | } 389 | } 390 | 391 | chars 392 | } 393 | 394 | /// Pretty-prints a token by decoding it as UTF-8 and escaping control characters. 395 | /// 396 | /// Arguments: 397 | /// - `token`: The token as a byte slice. 398 | /// 399 | /// Returns: 400 | /// A `String` representation of the token with control characters escaped. 
401 | /// 402 | /// Example: 403 | /// ```ignore 404 | /// # use minbpe::tokenizer::render_token; 405 | /// let token = b"Hello\nWorld\x07!"; 406 | /// let result = render_token(token); 407 | /// assert_eq!(result, "Hello\\u000aWorld\\u0007!"); 408 | /// ``` 409 | fn render_token(token: &[u8]) -> String { 410 | let s = String::from_utf8_lossy(token); 411 | replace_control_characters(&s) 412 | } 413 | 414 | #[cfg(test)] 415 | mod tests { 416 | use super::*; 417 | 418 | #[test] 419 | fn test_replace_control_characters() { 420 | let s = "Hello\nWorld\u{7}!"; 421 | let result = replace_control_characters(s); 422 | assert_eq!(result, "Hello\\u000aWorld\\u0007!"); 423 | } 424 | 425 | #[test] 426 | fn test_render_token() { 427 | let token = b"Hello\nWorld\x07!"; 428 | let result = render_token(token); 429 | assert_eq!(result, "Hello\\u000aWorld\\u0007!"); 430 | } 431 | 432 | #[test] 433 | fn test_indexmap_order() { 434 | let input_data: Vec<((Token, Token), Count)> = vec![ 435 | ((0, 0), 2), 436 | ((1, 1), 12), 437 | ((2, 2), 18), 438 | ((3, 3), 11), 439 | ((4, 4), 1), 440 | ((5, 5), 9), 441 | ((6, 6), 99), 442 | ((7, 7), 7), 443 | ((8, 8), 20), 444 | ((9, 9), 99), 445 | ((10, 10), 99), 446 | ((11, 11), 99), 447 | ((12, 12), 4), 448 | ((13, 13), 99), 449 | ((14, 14), 19), 450 | ((15, 15), 99), 451 | ((16, 16), 5), 452 | ((17, 17), 99), 453 | ((18, 18), 99), 454 | ((19, 19), 7), 455 | ]; 456 | 457 | let expected_max_key: (Token, Token) = (6, 6); 458 | 459 | let stats: IndexMap<(Token, Token), Count> = IndexMap::from_iter(input_data.clone()); 460 | 461 | let keys: Vec<_> = stats.keys().collect(); 462 | let input_keys: Vec<_> = input_data.iter().map(|(k, _)| k).collect(); 463 | 464 | assert_eq!(keys, input_keys, "Keys are not in insertion order"); 465 | 466 | let entries: Vec<_> = stats.iter().map(|(k, v)| (*k, *v)).collect(); 467 | assert_eq!( 468 | entries, 469 | input_data.as_slice(), 470 | "Entries are not in insertion order" 471 | ); 472 | 473 | let max_entry = get_max_entry(&stats); 474 | 475 | let pair = max_entry.expect("Stats is empty"); 476 | 477 | assert_eq!(*pair.0, expected_max_key); 478 | } 479 | } 480 | -------------------------------------------------------------------------------- /src/basic.rs: -------------------------------------------------------------------------------- 1 | use indexmap::IndexMap; 2 | 3 | use crate::base::{ 4 | get_max_entry, get_stats, merge, Loadable, Saveable, Token, Tokenizer, Trainable, 5 | }; 6 | 7 | /// Minimal (byte-level) Byte Pair Encoding tokenizer. 8 | /// 9 | /// Algorithmically follows along the GPT tokenizer: 10 | /// https://github.com/openai/gpt-2/blob/master/src/encoder.py 11 | /// 12 | /// But: 13 | /// - Does not handle the regular expression splitting pattern. 14 | /// - Does not handle any special tokens. 
15 | /// 16 | /// # Examples 17 | /// 18 | /// ``` 19 | /// use minbpe::BasicTokenizer; 20 | /// use minbpe::Tokenizer; 21 | /// use minbpe::Trainable; 22 | /// 23 | /// let mut tokenizer = BasicTokenizer::new(); 24 | /// let text = "Hello, world!"; 25 | /// let vocab_size = 256; 26 | /// let verbose = true; 27 | /// 28 | /// tokenizer.train(text, vocab_size, verbose); 29 | /// let encoded = tokenizer.encode(text); 30 | /// let decoded = tokenizer.decode(&encoded); 31 | /// 32 | /// assert_eq!(text, decoded); 33 | /// ``` 34 | pub struct BasicTokenizer { 35 | special_tokens: IndexMap, 36 | merges: IndexMap<(Token, Token), Token>, 37 | vocab: IndexMap>, 38 | } 39 | 40 | impl BasicTokenizer { 41 | pub fn new() -> Self { 42 | BasicTokenizer { 43 | special_tokens: IndexMap::new(), 44 | merges: IndexMap::new(), 45 | vocab: IndexMap::new(), 46 | } 47 | } 48 | } 49 | 50 | impl Default for BasicTokenizer { 51 | fn default() -> Self { 52 | Self::new() 53 | } 54 | } 55 | 56 | impl Tokenizer for BasicTokenizer { 57 | fn special_tokens(&self) -> &IndexMap { 58 | &self.special_tokens 59 | } 60 | 61 | fn merges(&self) -> &IndexMap<(Token, Token), Token> { 62 | &self.merges 63 | } 64 | 65 | fn vocab(&self) -> &IndexMap> { 66 | &self.vocab 67 | } 68 | 69 | fn decode(&self, ids: &[Token]) -> String { 70 | // Given ids (list of integers), return Rust string 71 | let text_bytes: Vec = ids 72 | .iter() 73 | .flat_map(|&idx| self.vocab[&idx].clone()) 74 | .collect(); 75 | String::from_utf8_lossy(&text_bytes).into_owned() 76 | } 77 | 78 | fn encode(&self, text: &str) -> Vec { 79 | // Given a string text, return the token ids 80 | let text_bytes = text.as_bytes(); 81 | let mut ids: Vec = text_bytes.iter().map(|&b| b as Token).collect(); 82 | while ids.len() >= 2 { 83 | // Find the pair with the lowest merge index 84 | let stats = get_stats(&ids); 85 | 86 | let pair_opt = stats 87 | .keys() 88 | .filter_map(|&pair| self.merges.get(&pair).map(|_| pair)) 89 | .min_by_key(|&pair| self.merges[&pair]); 90 | 91 | match pair_opt { 92 | None => break, // If there are no more merges available, break 93 | Some(pair) => { 94 | // Otherwise, merge the best pair (lowest merge index) 95 | let idx = self.merges[&pair]; 96 | ids = merge(&ids, pair, idx); 97 | } 98 | }; 99 | } 100 | ids 101 | } 102 | } 103 | 104 | impl Trainable for BasicTokenizer { 105 | fn train(&mut self, text: &str, vocab_size: Token, verbose: bool) { 106 | assert!(vocab_size >= 256, "Vocab size must be at least 256"); 107 | let num_merges = vocab_size - 256; 108 | 109 | // Input text preprocessing 110 | let text_bytes = text.as_bytes(); 111 | let mut ids: Vec = text_bytes.iter().map(|&b| b as Token).collect(); 112 | 113 | // Iteratively merge the most common pairs to create new tokens 114 | let mut merges: IndexMap<(Token, Token), Token> = IndexMap::new(); 115 | let mut vocab: IndexMap> = 116 | (0..256).map(|idx| (idx, vec![idx as u8])).collect(); 117 | for i in 0..num_merges { 118 | // Count up the number of times every consecutive pair appears 119 | let stats = get_stats(&ids); 120 | // Find the pair with the highest count 121 | let pair = get_max_entry(&stats).unwrap().0; 122 | // Mint a new token: assign it the next available id 123 | let idx = 256 + i; 124 | // Replace all occurrences of pair in ids with idx 125 | ids = merge(&ids, *pair, idx); 126 | // Save the merge 127 | merges.insert(*pair, idx); 128 | vocab.insert( 129 | idx, 130 | [vocab[&pair.0].clone(), vocab[&pair.1].clone()].concat(), 131 | ); 132 | // Prints 133 | if verbose { 134 | println!( 
135 | "merge {}/{}: {:?} -> {} ({:?}) had {} occurrences", 136 | i + 1, 137 | num_merges, 138 | pair, 139 | idx, 140 | vocab[&idx], 141 | stats[pair] 142 | ); 143 | } 144 | } 145 | 146 | // Save instance variables 147 | self.merges = merges; 148 | self.vocab = vocab; // FIXME: vs. build_vocab(&self.special_tokens, &self.merges); 149 | } 150 | } 151 | 152 | impl Saveable for BasicTokenizer { 153 | fn pattern(&self) -> &str { 154 | "" 155 | } 156 | } 157 | 158 | impl Loadable for BasicTokenizer { 159 | fn set_pattern(&mut self, pattern: &str) { 160 | let temp = pattern.trim(); 161 | 162 | if !temp.is_empty() { 163 | panic!("Cannot set a non-empty pattern!") 164 | } 165 | } 166 | 167 | fn set_special_tokens(&mut self, special_tokens: IndexMap) { 168 | self.special_tokens = special_tokens; 169 | } 170 | 171 | fn set_merges(&mut self, merges: IndexMap<(Token, Token), Token>) { 172 | self.merges = merges; 173 | } 174 | 175 | fn set_vocab(&mut self, vocab: IndexMap>) { 176 | self.vocab = vocab; 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /src/gpt4.rs: -------------------------------------------------------------------------------- 1 | use base64::{engine::general_purpose, Engine as _}; 2 | use core::panic; 3 | use fancy_regex::Regex; 4 | use indexmap::IndexMap; 5 | use lazy_static::lazy_static; 6 | 7 | use crate::{RegexTokenizerTrait, Token, Tokenizer}; 8 | 9 | const GPT4_SPLIT_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; 10 | 11 | lazy_static! { 12 | static ref GPT4_SPLIT_COMPILED_PATTERN: Regex = Regex::new(GPT4_SPLIT_PATTERN).unwrap(); 13 | } 14 | 15 | lazy_static! { 16 | static ref GPT4_SPECIAL_TOKENS: IndexMap<&'static str, Token> = { 17 | let mut map = IndexMap::new(); 18 | map.insert("<|endoftext|>", 100257); 19 | map.insert("<|fim_prefix|>", 100258); 20 | map.insert("<|fim_middle|>", 100259); 21 | map.insert("<|fim_suffix|>", 100260); 22 | map.insert("<|endofprompt|>", 100276); 23 | map 24 | }; 25 | } 26 | 27 | // We need this because tiktoken-rs does not expose the encoder and we need to recover the merges. If it did, we would 28 | // use tiktoken_rs::cl100k_base() and get the encoder from there. 29 | lazy_static! 
{ 30 | static ref GPT4_MERGEABLE_RANKS: IndexMap, Token> = { 31 | // https://github.com/zurawiki/tiktoken-rs/blob/main/tiktoken-rs/assets/cl100k_base.tiktoken 32 | let cl100k_base: &str = include_str!("../assets/cl100k_base.tiktoken"); 33 | 34 | // Also from tiktoken-rs's constructor 35 | let mut encoder = IndexMap::default(); 36 | for line in cl100k_base.lines() { 37 | let mut parts = line.split(' '); 38 | let raw = parts.next().unwrap(); 39 | let token = &general_purpose::STANDARD.decode(raw).unwrap(); 40 | let rank: Token = parts.next().unwrap().parse().unwrap(); 41 | if rank < 0 { 42 | panic!("Rank {} for token {:?} is negative", rank, token); 43 | } 44 | encoder.insert(token.clone(), rank); 45 | } 46 | encoder 47 | }; 48 | } 49 | 50 | fn bpe( 51 | mergeable_ranks: &IndexMap, Token>, 52 | token: &[u8], 53 | max_rank: Option, 54 | ) -> Vec> { 55 | let mut parts: Vec> = token.iter().map(|&b| vec![b]).collect(); 56 | loop { 57 | let mut min_idx = None; 58 | let mut min_rank = None; 59 | for (i, pair) in parts.windows(2).enumerate() { 60 | let rank = mergeable_ranks.get(&[pair[0].clone(), pair[1].clone()].concat()); 61 | if let Some(rank) = rank { 62 | if min_rank.is_none() || rank < min_rank.unwrap() { 63 | min_idx = Some(i); 64 | min_rank = Some(rank); 65 | } 66 | } 67 | } 68 | if min_rank.is_none() || (max_rank.is_some() && *min_rank.unwrap() >= max_rank.unwrap()) { 69 | break; 70 | } 71 | let min_idx = min_idx.unwrap(); 72 | parts[min_idx] = [parts[min_idx].clone(), parts[min_idx + 1].clone()].concat(); 73 | parts.remove(min_idx + 1); 74 | } 75 | parts 76 | } 77 | 78 | fn recover_merges(mergeable_ranks: &IndexMap, Token>) -> IndexMap<(Token, Token), Token> { 79 | let mut merges = IndexMap::new(); 80 | for (token, &rank) in mergeable_ranks { 81 | if token.len() == 1 { 82 | continue; 83 | } 84 | let pair = bpe(mergeable_ranks, token, Some(rank)); 85 | assert_eq!(pair.len(), 2); 86 | let ix0 = mergeable_ranks[&pair[0]]; 87 | let ix1 = mergeable_ranks[&pair[1]]; 88 | merges.insert((ix0, ix1), rank); 89 | } 90 | merges 91 | } 92 | 93 | /// Does not implement Tokenizer trait because it cannot be trained, loaded or saved. 94 | pub struct GPT4Tokenizer { 95 | special_tokens: IndexMap, 96 | inverse_special_tokens: IndexMap, 97 | merges: IndexMap<(Token, Token), Token>, 98 | vocab: IndexMap>, 99 | 100 | byte_shuffle: IndexMap, 101 | inverse_byte_shuffle: IndexMap, 102 | } 103 | 104 | impl Default for GPT4Tokenizer { 105 | fn default() -> Self { 106 | Self::new() 107 | } 108 | } 109 | 110 | impl GPT4Tokenizer { 111 | /// This method may be called before any other method in this module, in case you want to ensure all the 112 | /// lazy static initializations are done before any other operation. 
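///
/// A usage sketch (illustrative only; calling `initialize` is optional, since the
/// lazy statics are also built on first use):
///
/// ```ignore
/// use minbpe::GPT4Tokenizer;
///
/// // Force the split-pattern regex and the cl100k_base rank table to be built now.
/// GPT4Tokenizer::initialize();
/// let tokenizer = GPT4Tokenizer::default();
/// ```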
113 | pub fn initialize() { 114 | let _ = &*GPT4_SPLIT_COMPILED_PATTERN; 115 | let _ = &*GPT4_MERGEABLE_RANKS; 116 | } 117 | 118 | pub fn new() -> Self { 119 | // let enc = cl100k_base().unwrap(); 120 | let mergeable_ranks = &GPT4_MERGEABLE_RANKS; 121 | let merges = recover_merges(mergeable_ranks); 122 | let mut vocab: IndexMap> = 123 | (0..=255).map(|i| (i as Token, vec![i])).collect(); 124 | for (&(p0, p1), &idx) in &merges { 125 | let mut token = vocab[&p0].clone(); 126 | token.extend(vocab[&p1].clone()); 127 | vocab.insert(idx, token); 128 | } 129 | let byte_shuffle: IndexMap = (0..=255) 130 | .map(|i| { 131 | let value = mergeable_ranks[&vec![i]]; 132 | if value < 0 || value > u8::MAX as Token { 133 | panic!( 134 | "Value {} for key {} in mergeable_ranks does not fit in u8", 135 | value, i 136 | ); 137 | } 138 | (i, value as u8) 139 | }) 140 | .collect(); 141 | let inverse_byte_shuffle: IndexMap = 142 | byte_shuffle.iter().map(|(&k, &v)| (v, k)).collect(); 143 | let special_tokens = GPT4_SPECIAL_TOKENS 144 | .iter() 145 | .map(|(&k, &v)| (k.to_string(), v)) 146 | .collect::>(); 147 | 148 | let inverse_special_tokens = special_tokens 149 | .iter() 150 | .map(|(k, v)| (*v, k.clone())) 151 | .collect(); 152 | 153 | GPT4Tokenizer { 154 | special_tokens, 155 | inverse_special_tokens, 156 | merges, 157 | vocab, 158 | 159 | byte_shuffle, 160 | inverse_byte_shuffle, 161 | } 162 | } 163 | 164 | pub fn decode(&self, ids: &[Token]) -> String { 165 | let text_bytes: Vec = ids 166 | .iter() 167 | .flat_map(|&idx| self.vocab[&idx].clone()) 168 | .collect(); 169 | let text_bytes: Vec = text_bytes 170 | .into_iter() 171 | .map(|b| self.inverse_byte_shuffle[&b]) 172 | .collect(); 173 | String::from_utf8_lossy(&text_bytes).to_string() 174 | } 175 | 176 | pub fn register_special_tokens_x(&mut self, tokens: &IndexMap) { 177 | self.special_tokens 178 | .extend(tokens.iter().map(|(k, &v)| (k.clone(), v))); 179 | 180 | self.inverse_special_tokens = self 181 | .special_tokens 182 | .iter() 183 | .map(|(k, v)| (*v, k.clone())) 184 | .collect(); 185 | } 186 | } 187 | 188 | impl Tokenizer for GPT4Tokenizer { 189 | fn special_tokens(&self) -> &IndexMap { 190 | &self.special_tokens 191 | } 192 | 193 | fn merges(&self) -> &IndexMap<(Token, Token), Token> { 194 | &self.merges 195 | } 196 | 197 | fn vocab(&self) -> &IndexMap> { 198 | &self.vocab 199 | } 200 | 201 | fn decode(&self, ids: &[Token]) -> String { 202 | let mut text = String::new(); 203 | for &id in ids { 204 | if let Some(token) = self.vocab.get(&id) { 205 | text.push_str(std::str::from_utf8(token).expect("Invalid UTF-8 sequence")); 206 | } else if let Some(token) = self.inverse_special_tokens.get(&id) { 207 | text.push_str(token); 208 | } 209 | } 210 | text 211 | } 212 | 213 | fn encode(&self, text: &str) -> Vec { 214 | RegexTokenizerTrait::encode(self, text) 215 | } 216 | } 217 | 218 | impl RegexTokenizerTrait for GPT4Tokenizer { 219 | fn encode_chunk(&self, text_bytes: &[u8]) -> Vec { 220 | let text_bytes: Vec = text_bytes.iter().map(|&b| self.byte_shuffle[&b]).collect(); 221 | ::encode_chunk_inner(self, &text_bytes) 222 | } 223 | 224 | fn compiled_pattern(&self) -> &Regex { 225 | &GPT4_SPLIT_COMPILED_PATTERN 226 | } 227 | 228 | fn inverse_special_tokens(&self) -> &IndexMap { 229 | &self.inverse_special_tokens 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod base; 2 | #[cfg(feature = "basic")] 
3 | pub mod basic; 4 | #[cfg(feature = "gpt4")] 5 | pub mod gpt4; 6 | #[cfg(feature = "regex")] 7 | pub mod regex; 8 | 9 | pub mod test_common; 10 | 11 | pub use base::*; 12 | 13 | #[cfg(feature = "basic")] 14 | pub use basic::BasicTokenizer; 15 | 16 | #[cfg(feature = "regex")] 17 | pub use regex::{AllowedSpecial, RegexTokenizerStruct, RegexTokenizerTrait}; 18 | 19 | #[cfg(feature = "gpt4")] 20 | pub use gpt4::GPT4Tokenizer; 21 | -------------------------------------------------------------------------------- /src/regex.rs: -------------------------------------------------------------------------------- 1 | use fancy_regex::Regex; 2 | use indexmap::IndexMap; 3 | use std::collections::HashSet; 4 | 5 | use crate::{get_max_entry, Loadable, Saveable, Trainable}; 6 | use crate::{get_stats, merge, update_stats, Token, Tokenizer}; 7 | 8 | /// The main GPT text split patterns, see 9 | /// https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py 10 | pub const GPT2_SPLIT_PATTERN: &str = 11 | r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; 12 | 13 | pub const GPT4_SPLIT_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; 14 | 15 | /// Specifies how to handle special tokens during encoding. 16 | /// 17 | /// This enum is used to control the behavior of the `encode_special` function 18 | /// when encountering special tokens in the text. 19 | /// 20 | /// # Variants 21 | /// 22 | /// - `All`: Allow all special tokens during encoding. 23 | /// Special tokens will be encoded according to their corresponding token IDs. 24 | /// 25 | /// - `None`: Ignore all special tokens during encoding. 26 | /// Special tokens will be treated as regular text and encoded using the standard encoding process. 27 | /// 28 | /// - `NoneRaise`: Raise an error if any special token is encountered in the text during encoding. 29 | /// This is the default behavior of the `tiktoken` library. 30 | /// 31 | /// - `Set(HashSet)`: Allow only the special tokens specified in the provided `HashSet`. 32 | /// Special tokens not included in the set will be treated as regular text and encoded using the standard encoding process. 
33 | /// 34 | /// # Examples 35 | /// 36 | /// ``` 37 | /// use minbpe::AllowedSpecial; 38 | /// use std::collections::HashSet; 39 | /// 40 | /// // Allow all special tokens 41 | /// let allowed_all = AllowedSpecial::All; 42 | /// 43 | /// // Ignore all special tokens 44 | /// let allowed_none = AllowedSpecial::None; 45 | /// 46 | /// // Raise an error if any special token is encountered 47 | /// let allowed_none_raise = AllowedSpecial::NoneRaise; 48 | /// 49 | /// // Allow only specific special tokens 50 | /// let custom_set = HashSet::from(["<|endoftext|>".to_string(), "<|startoftext|>".to_string()]); 51 | /// let allowed_custom = AllowedSpecial::Set(custom_set); 52 | /// ``` 53 | pub enum AllowedSpecial { 54 | All, 55 | None, 56 | NoneRaise, 57 | Set(HashSet), 58 | } 59 | 60 | pub trait RegexTokenizerTrait: Tokenizer { 61 | fn encode_chunk_inner(&self, text_bytes: &[u8]) -> Vec { 62 | let merges = self.merges(); 63 | let mut ids: Vec = text_bytes.iter().map(|&b| b as Token).collect(); 64 | while ids.len() >= 2 { 65 | // Find the pair with the lowest merge index 66 | let stats = get_stats(&ids); 67 | 68 | let pair_opt = stats 69 | .keys() 70 | .filter_map(|&pair| merges.get(&pair).map(|_| pair)) 71 | .min_by_key(|&pair| merges[&pair]); 72 | 73 | match pair_opt { 74 | None => break, // If there are no more merges available, break 75 | Some(pair) => { 76 | // Otherwise, merge the best pair (lowest merge index) 77 | let idx = merges[&pair]; 78 | ids = merge(&ids, pair, idx); 79 | } 80 | }; 81 | } 82 | ids 83 | } 84 | 85 | fn encode_chunk(&self, text_bytes: &[u8]) -> Vec { 86 | self.encode_chunk_inner(text_bytes) 87 | } 88 | 89 | // fn pattern(&self) -> &str; 90 | // fn set_pattern(&mut self, pattern: &str); 91 | 92 | fn compiled_pattern(&self) -> &Regex; 93 | 94 | // fn special_tokens(&self) -> &IndexMap; 95 | // fn set_special_tokens(&mut self, special_tokens: IndexMap); 96 | 97 | fn inverse_special_tokens(&self) -> &IndexMap; 98 | 99 | // fn merges(&self) -> &IndexMap<(Token, Token), Token>; 100 | // fn set_merges(&mut self, merges: IndexMap<(Token, Token), Token>); 101 | 102 | // fn vocab(&self) -> &IndexMap>; 103 | // fn set_vocab(&mut self, vocab: IndexMap>); 104 | 105 | // fn train(&mut self, text: &str, vocab_size: Token, verbose: bool); 106 | // fn decode(&self, ids: &[Token]) -> String; 107 | // fn encode(&self, text: &str) -> Vec; 108 | 109 | fn decode(&self, ids: &[Token]) -> String { 110 | let mut part_bytes = Vec::new(); 111 | for &idx in ids { 112 | if let Some(bytes) = self.vocab().get(&idx) { 113 | part_bytes.extend_from_slice(bytes); 114 | } else if let Some(special_token) = self.inverse_special_tokens().get(&idx) { 115 | part_bytes.extend_from_slice(special_token.as_bytes()); 116 | } else { 117 | panic!("Invalid token id: {}", idx); 118 | } 119 | } 120 | String::from_utf8_lossy(&part_bytes).into_owned() 121 | } 122 | 123 | fn encode(&self, text: &str) -> Vec { 124 | self.encode_special(text, AllowedSpecial::NoneRaise) 125 | } 126 | 127 | /// Encoding that ignores any special tokens. 
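///
/// A minimal sketch (assuming `tokenizer` is a trained implementor of this trait,
/// e.g. a `RegexTokenizerStruct`):
///
/// ```ignore
/// // "<|endoftext|>" is split by the regex pattern and byte-pair encoded like any
/// // other text; no special token id is emitted by this method.
/// let ids = tokenizer.encode_ordinary("Hello <|endoftext|>");
/// ```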
128 | fn encode_ordinary(&self, text: &str) -> Vec { 129 | let text_chunks: Vec<&str> = self 130 | .compiled_pattern() 131 | .find_iter(text) 132 | .map(|m| { 133 | let matched = m.unwrap(); 134 | &text[matched.start()..matched.end()] 135 | }) 136 | .collect(); 137 | let mut ids = Vec::new(); 138 | for chunk in text_chunks { 139 | let chunk_bytes = chunk.as_bytes(); 140 | let chunk_ids = self.encode_chunk(chunk_bytes); 141 | ids.extend(chunk_ids); 142 | } 143 | ids 144 | } 145 | 146 | /// Encodes the given text into token IDs, handling special tokens. 147 | /// 148 | /// Unlike `encode_ordinary`, this function handles special tokens based on the `allowed_special` parameter. 149 | /// 150 | /// # Arguments 151 | /// 152 | /// * `text` - The text to encode. 153 | /// * `allowed_special` - Specifies how to handle special tokens. It can be one of the following: 154 | /// - `AllowedSpecial::All`: Allow all special tokens. 155 | /// - `AllowedSpecial::None`: Ignore all special tokens. 156 | /// - `AllowedSpecial::NoneRaise`: Raise an error if any special token is encountered in the text. 157 | /// This is the default behavior of the `tiktoken` library. 158 | /// - `AllowedSpecial::Set(HashSet)`: A custom set of allowed special tokens. 159 | /// 160 | /// # Panics 161 | /// 162 | /// Panics if `allowed_special` is set to `AllowedSpecial::NoneRaise` and any special token is encountered in the text. 163 | fn encode_special(&self, text: &str, allowed_special: AllowedSpecial) -> Vec { 164 | let special = match allowed_special { 165 | AllowedSpecial::All => self.special_tokens().clone(), 166 | AllowedSpecial::None => IndexMap::new(), 167 | AllowedSpecial::NoneRaise => { 168 | assert!( 169 | self.special_tokens() 170 | .keys() 171 | .all(|token| !text.contains(token)), 172 | "Special token found in text" 173 | ); 174 | IndexMap::new() 175 | } 176 | AllowedSpecial::Set(special_tokens) => { 177 | let mut special = IndexMap::new(); 178 | for token in special_tokens { 179 | if let Some(&idx) = self.special_tokens().get(&token) { 180 | special.insert(token, idx); 181 | } 182 | } 183 | special 184 | } 185 | }; 186 | 187 | if special.is_empty() { 188 | return self.encode_ordinary(text); 189 | } 190 | 191 | let special_pattern = "(".to_string() 192 | + &special 193 | .keys() 194 | .map(|k| regex::escape(k)) 195 | .collect::>() 196 | .join("|") 197 | + ")"; 198 | 199 | let re = fancy_regex::Regex::new(&special_pattern).unwrap(); 200 | let mut last_end = 0; 201 | let mut special_chunks = Vec::new(); 202 | for m in re.find_iter(text) { 203 | let m = m.unwrap(); 204 | // Push the text between matches 205 | special_chunks.push(&text[last_end..m.start()]); 206 | // Push the matched text 207 | special_chunks.push(&text[m.start()..m.end()]); 208 | last_end = m.end(); 209 | } 210 | let remaining = &text[last_end..]; 211 | if !remaining.is_empty() { 212 | special_chunks.push(remaining); 213 | } 214 | 215 | let mut ids = Vec::new(); 216 | for part in special_chunks { 217 | if let Some(&idx) = special.get(part) { 218 | ids.push(idx); 219 | } else { 220 | ids.extend(self.encode_ordinary(part)); 221 | } 222 | } 223 | ids 224 | } 225 | } 226 | 227 | /// Minimal (byte-level) Byte Pair Encoding tokenizer. 228 | /// 229 | /// Algorithmically follows along the GPT tokenizer: 230 | /// https://github.com/openai/gpt-2/blob/master/src/encoder.py 231 | /// 232 | /// Unlike `BasicTokenizer`: 233 | /// - `RegexTokenizer` handles an optional regex splitting pattern. 234 | /// - `RegexTokenizer` handles optional special tokens. 
235 | /// 236 | /// # Examples 237 | /// 238 | /// ``` 239 | /// use fancy_regex::Regex; 240 | /// use minbpe::base::Loadable; 241 | /// use minbpe::base::Tokenizer; 242 | /// use minbpe::base::Trainable; 243 | /// use minbpe::RegexTokenizerStruct; 244 | /// use minbpe::RegexTokenizerTrait; 245 | /// use minbpe::AllowedSpecial; 246 | /// use indexmap::IndexMap; 247 | /// 248 | /// let pattern = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; 249 | /// let mut tokenizer = RegexTokenizerStruct::new(pattern.to_string()); 250 | /// let special_tokens = IndexMap::from([("<|endoftext|>".to_string(), 100257)]); 251 | /// tokenizer.set_special_tokens(special_tokens); 252 | /// 253 | /// let text = "Hello, world! This is a test."; 254 | /// let vocab_size = 256 + 10; 255 | /// let verbose = true; 256 | /// 257 | /// tokenizer.train(text, vocab_size, verbose); 258 | /// 259 | /// let encoded = tokenizer.encode_special(text, AllowedSpecial::NoneRaise); 260 | /// let decoded = RegexTokenizerTrait::decode(&tokenizer, &encoded); 261 | /// 262 | /// assert_eq!(text, decoded); 263 | /// ``` 264 | pub struct RegexTokenizerStruct { 265 | pattern: String, 266 | compiled_pattern: Regex, 267 | special_tokens: IndexMap, 268 | inverse_special_tokens: IndexMap, 269 | merges: IndexMap<(Token, Token), Token>, 270 | vocab: IndexMap>, 271 | } 272 | 273 | impl Default for RegexTokenizerStruct { 274 | fn default() -> Self { 275 | Self::new(GPT4_SPLIT_PATTERN.to_string()) 276 | } 277 | } 278 | 279 | impl RegexTokenizerStruct { 280 | fn make(pattern: String) -> Self { 281 | let compiled_pattern = Regex::new(&pattern).unwrap(); 282 | 283 | RegexTokenizerStruct { 284 | pattern, 285 | compiled_pattern, 286 | special_tokens: IndexMap::new(), 287 | inverse_special_tokens: IndexMap::new(), 288 | merges: IndexMap::new(), 289 | vocab: IndexMap::new(), 290 | } 291 | } 292 | 293 | pub fn new(pattern: String) -> Self { 294 | Self::make(pattern) 295 | } 296 | } 297 | 298 | impl Tokenizer for RegexTokenizerStruct { 299 | fn special_tokens(&self) -> &IndexMap { 300 | &self.special_tokens 301 | } 302 | 303 | fn merges(&self) -> &IndexMap<(Token, Token), Token> { 304 | &self.merges 305 | } 306 | 307 | fn vocab(&self) -> &IndexMap> { 308 | &self.vocab 309 | } 310 | 311 | fn decode(&self, ids: &[Token]) -> String { 312 | // Forwarding to the default implementation provided by RegexTokenizerTrait 313 | ::decode(self, ids) 314 | } 315 | 316 | fn encode(&self, text: &str) -> Vec { 317 | // Forwarding to the default implementation provided by RegexTokenizerTrait 318 | ::encode(self, text) 319 | } 320 | } 321 | 322 | impl Trainable for RegexTokenizerStruct { 323 | fn train(&mut self, text: &str, vocab_size: Token, verbose: bool) { 324 | assert!(vocab_size >= 256, "Vocab size must be at least 256"); 325 | let num_merges = vocab_size - 256; 326 | 327 | // Split the text into chunks 328 | let text_chunks: Vec<&str> = self 329 | .compiled_pattern() 330 | .find_iter(text) 331 | .map(|m| { 332 | let matched = m.unwrap(); 333 | &text[matched.start()..matched.end()] 334 | }) 335 | .collect(); 336 | 337 | // Input text preprocessing 338 | let mut ids: Vec> = text_chunks 339 | .iter() 340 | .map(|chunk| chunk.as_bytes().iter().map(|b| *b as Token).collect()) 341 | .collect(); 342 | 343 | // Iteratively merge the most common pairs to create new tokens 344 | let mut merges: IndexMap<(Token, Token), Token> = IndexMap::new(); 345 | let mut vocab: IndexMap> = 346 | (0..256).map(|idx| (idx, vec![idx as u8])).collect(); 347 | 348 | 
for i in 0..num_merges { 349 | // Count the number of times every consecutive pair appears 350 | let mut stats = IndexMap::new(); 351 | for chunk_ids in &ids { 352 | update_stats(chunk_ids, &mut stats); 353 | } 354 | 355 | // Find the pair with the highest count 356 | let pair = get_max_entry(&stats).unwrap().0; 357 | 358 | // Mint a new token: assign it the next available id 359 | let idx = 256 + i; 360 | 361 | // Replace all occurrences of pair in ids with idx 362 | ids = ids 363 | .iter() 364 | .map(|chunk_ids| merge(chunk_ids, *pair, idx)) 365 | .collect(); 366 | 367 | // Save the merge 368 | merges.insert(*pair, idx); 369 | vocab.insert( 370 | idx, 371 | [vocab[&pair.0].clone(), vocab[&pair.1].clone()].concat(), 372 | ); 373 | 374 | // Prints 375 | if verbose { 376 | println!( 377 | "merge {}/{}: {:?} -> {} ({:?}) had {} occurrences", 378 | i + 1, 379 | num_merges, 380 | pair, 381 | idx, 382 | vocab[&idx], 383 | stats[pair] 384 | ); 385 | } 386 | } 387 | 388 | // Save instance variables 389 | self.merges = merges; 390 | self.vocab = vocab; // FIXME: vs. build_vocab(&self.special_tokens, &self.merges); 391 | } 392 | } 393 | 394 | impl Saveable for RegexTokenizerStruct { 395 | fn pattern(&self) -> &str { 396 | &self.pattern 397 | } 398 | } 399 | 400 | impl Loadable for RegexTokenizerStruct { 401 | fn set_pattern(&mut self, pattern: &str) { 402 | self.pattern = pattern.to_string(); 403 | self.compiled_pattern = Regex::new(pattern).unwrap(); 404 | } 405 | 406 | fn set_special_tokens(&mut self, special_tokens: IndexMap) { 407 | self.special_tokens = special_tokens.clone(); 408 | self.inverse_special_tokens = special_tokens 409 | .iter() 410 | .map(|(k, v)| (*v, k.clone())) 411 | .collect(); 412 | } 413 | 414 | fn set_merges(&mut self, merges: IndexMap<(Token, Token), Token>) { 415 | self.merges = merges; 416 | } 417 | 418 | fn set_vocab(&mut self, vocab: IndexMap>) { 419 | self.vocab = vocab; 420 | } 421 | } 422 | 423 | impl RegexTokenizerTrait for RegexTokenizerStruct { 424 | fn compiled_pattern(&self) -> &Regex { 425 | &self.compiled_pattern 426 | } 427 | 428 | fn inverse_special_tokens(&self) -> &IndexMap { 429 | &self.inverse_special_tokens 430 | } 431 | } 432 | 433 | #[cfg(test)] 434 | mod tests { 435 | use super::*; 436 | use indexmap::IndexMap; 437 | use std::collections::HashSet; 438 | 439 | #[test] 440 | fn test_pattern_matching() { 441 | let text = "Hello, world! <|endoftext|>"; 442 | 443 | let pattern = "(<\\|endoftext\\|>)"; 444 | let re = fancy_regex::Regex::new(pattern).unwrap(); 445 | 446 | let mut last_end = 0; 447 | let mut special_chunks = Vec::new(); 448 | for m in re.find_iter(text) { 449 | let m = m.unwrap(); 450 | // Push the text between matches 451 | special_chunks.push(&text[last_end..m.start()]); 452 | // Push the matched text 453 | special_chunks.push(&text[m.start()..m.end()]); 454 | last_end = m.end(); 455 | } 456 | let remaining = &text[last_end..]; 457 | if !remaining.is_empty() { 458 | special_chunks.push(remaining); 459 | } 460 | } 461 | 462 | #[test] 463 | fn test_encode_special() { 464 | let mut tokenizer = RegexTokenizerStruct::default(); 465 | tokenizer.train("Hello, world! Goodbye, world!, So long...", 256 + 10, true); 466 | 467 | let text = "Hello, world! 
<|endoftext|>"; 468 | 469 | let special_tokens = IndexMap::from([("<|endoftext|>".to_string(), 100257)]); 470 | tokenizer.set_special_tokens(special_tokens); 471 | 472 | let encoded_all = tokenizer.encode_special(text, AllowedSpecial::All); 473 | let encoded_none = tokenizer.encode_special(text, AllowedSpecial::None); 474 | 475 | let custom_set = HashSet::from(["<|endoftext|>".to_string()]); 476 | let encoded_custom = tokenizer.encode_special(text, AllowedSpecial::Set(custom_set)); 477 | 478 | assert!(encoded_all.contains(&100257)); 479 | assert!(!encoded_none.contains(&100257)); 480 | assert!(encoded_custom.contains(&100257)); 481 | } 482 | 483 | #[test] 484 | #[should_panic] 485 | fn test_encode_special_panic() { 486 | let mut tokenizer = RegexTokenizerStruct::default(); 487 | let text = "Hello, world! <|endofext|>"; 488 | 489 | let special_tokens = IndexMap::from([("<|endofext|>".to_string(), 100257)]); 490 | tokenizer.set_special_tokens(special_tokens); 491 | 492 | // This should panic 493 | let _ = tokenizer.encode_special(text, AllowedSpecial::NoneRaise); 494 | } 495 | } 496 | -------------------------------------------------------------------------------- /src/test_common.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | use std::path::PathBuf; 3 | 4 | use crate::Token; 5 | 6 | use indexmap::indexmap; 7 | use indexmap::IndexMap; 8 | use lazy_static::lazy_static; 9 | 10 | lazy_static! { 11 | pub static ref SPECIAL_TOKENS: IndexMap = indexmap! { 12 | "<|endoftext|>".to_string() => 100257, 13 | "<|fim_prefix|>".to_string()=> 100258, 14 | "<|fim_middle|>".to_string()=> 100259, 15 | "<|fim_suffix|>".to_string()=> 100260, 16 | "<|endofprompt|>".to_string()=> 100276 17 | }; 18 | } 19 | 20 | pub const LLAMA_TEXT: &str = r###"<|endoftext|>The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era. 21 | Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4] 22 | The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5] 23 | <|fim_prefix|>In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology,<|fim_suffix|> where they come from at the end of time.[6]<|fim_middle|> llamas will return to the water springs and ponds<|endofprompt|>"###; 24 | 25 | // a few strings to test the tokenizers on 26 | pub const TEST_STRINGS: [&str; 4] = [ 27 | "", // empty string 28 | "?", // single character 29 | "hello world!!!? (안녕하세요!) 
lol123 😉", // fun small string 30 | "FILE:../tests/taylorswift.txt", // FILE: is handled as a special string in unpack() 31 | ]; 32 | 33 | pub fn test_strings() -> [&'static str; 4] { 34 | TEST_STRINGS 35 | } 36 | 37 | pub fn unpack(text: &str) -> std::io::Result { 38 | if let Some(filename) = text.strip_prefix("FILE:") { 39 | let dirname = PathBuf::from(file!()).parent().unwrap().to_path_buf(); 40 | let file_path = dirname.join(filename); 41 | println!("Reading file: {:?}...", file_path); 42 | fs::read_to_string(file_path) 43 | } else { 44 | Ok(text.to_string()) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tests/gen_test_case.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | def generate_test_case(): 5 | """ 6 | Generate a test case for checking Rust IndexMap preserves insertion order like Python dict. 7 | 8 | Note: The test case in question is in tokenizer.rs. 9 | """ 10 | stats = {} 11 | for i in range(20): 12 | if random.random() < 0.25: 13 | value = 99 14 | else: 15 | value = random.randint(0, 20) 16 | stats[(i, i)] = value 17 | print(f"(({i},{i}), {value})") 18 | 19 | pair = max(stats, key=stats.get) 20 | print(pair) 21 | 22 | 23 | generate_test_case() 24 | -------------------------------------------------------------------------------- /tests/tiktoken_compat.proptest-regressions: -------------------------------------------------------------------------------- 1 | # Seeds for failure cases proptest has generated in the past. It is 2 | # automatically read and these particular cases re-run before any 3 | # novel cases are generated. 4 | # 5 | # It is recommended to check this file in to source control so that 6 | # everyone who runs the test benefits from these saved cases. 7 | cc 4a95a5006a1366458c24f6b29e4d255ce9e7cacb790f546d44284b6700898251 # shrinks to s = "\u{1e01b}%SΣ" 8 | -------------------------------------------------------------------------------- /tests/tiktoken_compat.rs: -------------------------------------------------------------------------------- 1 | #[cfg(all(test, feature = "tiktoken_tests"))] 2 | mod tests { 3 | use std::collections::HashSet; 4 | 5 | use lazy_static::lazy_static; 6 | use minbpe::{GPT4Tokenizer, RegexTokenizerTrait, Token}; 7 | use proptest::prelude::*; 8 | use tiktoken_rs::{cl100k_base, CoreBPE}; 9 | 10 | lazy_static! { 11 | static ref SPECIAL_TOKENS: HashSet<&'static str> = HashSet::new(); 12 | static ref TIKTOKEN_ENC: CoreBPE = cl100k_base().unwrap(); 13 | static ref GPT4_TOKENIZER: GPT4Tokenizer = GPT4Tokenizer::default(); 14 | } 15 | 16 | fn test_one(s: &str) { 17 | let special_tokens = HashSet::new(); 18 | 19 | let tiktoken_ids = TIKTOKEN_ENC.encode(s, special_tokens); 20 | let tiktoken_tokens: Vec = tiktoken_ids.iter().map(|&id| id as Token).collect(); 21 | 22 | let gpt4_tokenizer_tokens = GPT4_TOKENIZER.encode(s); 23 | 24 | assert_eq!(tiktoken_tokens, gpt4_tokenizer_tokens); 25 | } 26 | 27 | #[test] 28 | fn test_high_char() { 29 | test_one("\u{1e01b}%SΣ"); 30 | } 31 | 32 | proptest! 
{ 33 | #[test] 34 | #[allow(unused_must_use)] 35 | fn gpt4_tokenizer_matches_tiktoken(s in "\\PC*") { 36 | test_one(&s); 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /tests/tokenizer_tests.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests { 3 | use minbpe::test_common::{LLAMA_TEXT, SPECIAL_TOKENS}; 4 | use minbpe::AllowedSpecial; 5 | use minbpe::BasicTokenizer; 6 | use minbpe::Loadable; 7 | use minbpe::RegexTokenizerStruct; 8 | use minbpe::RegexTokenizerTrait; 9 | use minbpe::Saveable; 10 | use minbpe::Token; 11 | use minbpe::Trainable; 12 | 13 | use indexmap::IndexMap; 14 | use tempfile::tempdir; 15 | 16 | // Quick unit test, following along the Wikipedia example: 17 | // https://en.wikipedia.org/wiki/Byte_pair_encoding 18 | // 19 | // According to Wikipedia, running bpe on the input string: 20 | // "aaabdaaabac" 21 | // 22 | // for 3 merges will result in string: 23 | // "XdXac" 24 | // 25 | // where: 26 | // X=ZY 27 | // Y=ab 28 | // Z=aa 29 | // 30 | // Keep in mind that for us a=97, b=98, c=99, d=100 (ASCII values) 31 | // so Z will be 256, Y will be 257, X will be 258. 32 | // 33 | // So we expect the output list of ids to be [258, 100, 258, 97, 99] 34 | fn test_wikipedia_example_inner(tokenizer: &mut Box) { 35 | let text = "aaabdaaabac"; 36 | tokenizer.train(text, 256 + 3, false); 37 | let ids = tokenizer.encode(text); 38 | assert_eq!(ids, [258, 100, 258, 97, 99]); 39 | let encoded = tokenizer.encode(text); 40 | let decoded = tokenizer.decode(&encoded); 41 | assert_eq!(decoded, text); 42 | } 43 | 44 | #[test] 45 | fn test_wikipedia_example() { 46 | let tokenizers: Vec> = vec![ 47 | Box::new(BasicTokenizer::new()), 48 | Box::::default(), 49 | ]; 50 | 51 | for mut tokenizer in tokenizers { 52 | test_wikipedia_example_inner(&mut tokenizer); 53 | } 54 | } 55 | 56 | fn test_save_load_inner(special_tokens: &IndexMap) { 57 | // take a bit more complex piece of text and train the tokenizer 58 | let text = LLAMA_TEXT; 59 | // create a Tokenizer and do 64 merges 60 | let mut tokenizer = RegexTokenizerStruct::default(); 61 | tokenizer.train(text, 256 + 64, false); 62 | tokenizer.set_special_tokens(special_tokens.clone()); // Feels weird to do this after training, not part of setup 63 | 64 | // verify that decode(encode(x)) == x 65 | let encoded = tokenizer.encode_special(text, AllowedSpecial::All); 66 | let decoded = tokenizer.decode(&encoded); 67 | assert_eq!(decoded, text); 68 | 69 | // verify that save/load work as expected; save the tokenizer 70 | let dir = tempdir().unwrap(); 71 | tokenizer.save(dir.path(), "test_tokenizer_tmp"); 72 | 73 | // re-load the tokenizer 74 | let mut tokenizer = RegexTokenizerStruct::default(); 75 | let model_file = dir.path().join("test_tokenizer_tmp.model"); 76 | tokenizer.load(&model_file); 77 | 78 | // verify that decode(encode(x)) == x 79 | assert_eq!(tokenizer.decode(&encoded), text); 80 | assert_eq!( 81 | tokenizer.decode(&tokenizer.encode_special(text, AllowedSpecial::All)), 82 | text 83 | ); 84 | assert_eq!(tokenizer.encode_special(text, AllowedSpecial::All), encoded); 85 | } 86 | 87 | #[test] 88 | fn test_save_load() { 89 | let special_tokens = IndexMap::new(); 90 | test_save_load_inner(&special_tokens); 91 | let special_tokens = &SPECIAL_TOKENS; 92 | test_save_load_inner(special_tokens); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /tests/tokenizer_tests_gpt4.rs: 
-------------------------------------------------------------------------------- 1 | #[cfg(all(test, feature = "gpt4"))] 2 | mod tests { 3 | use std::collections::HashSet; 4 | use tiktoken_rs::cl100k_base; 5 | 6 | use minbpe::GPT4Tokenizer; 7 | use minbpe::RegexTokenizerTrait; 8 | use minbpe::Token; 9 | 10 | use minbpe::test_common::{unpack, TEST_STRINGS}; 11 | 12 | // test that our tokenizer matches the official GPT-4 tokenizer 13 | fn test_gpt4_tiktoken_equality_inner(text: String) { 14 | let special_tokens: HashSet<&str> = HashSet::new(); 15 | 16 | let text = unpack(&text).unwrap(); 17 | println!( 18 | "test_gpt4_tiktoken_equality_inner: text length is: {:?}", 19 | text.len() 20 | ); 21 | use std::time::Instant; 22 | 23 | let enc = cl100k_base().unwrap(); 24 | 25 | let tiktoken_start = Instant::now(); 26 | let tiktoken_ids = enc.encode(&text, special_tokens); 27 | let tiktoken_tokens: Vec = tiktoken_ids.iter().map(|&id| id as Token).collect(); 28 | let tiktoken_duration = tiktoken_start.elapsed(); 29 | println!("TikToken encoding took: {:?}", tiktoken_duration); 30 | 31 | let tokenizer = GPT4Tokenizer::new(); 32 | 33 | let gpt4_start = Instant::now(); 34 | let gpt4_tokenizer_tokens = tokenizer.encode(&text); 35 | let gpt4_duration = gpt4_start.elapsed(); 36 | println!("GPT4 encoding took: {:?}", gpt4_duration); 37 | 38 | assert_eq!( 39 | tiktoken_tokens.len(), 40 | gpt4_tokenizer_tokens.len(), 41 | "Token vectors are of different lengths: {} expected, but found {}", 42 | tiktoken_tokens.len(), 43 | gpt4_tokenizer_tokens.len() 44 | ); 45 | assert_eq!( 46 | tiktoken_tokens, gpt4_tokenizer_tokens, 47 | "Token vectors do not match" 48 | ); 49 | } 50 | 51 | #[test] 52 | fn test_gpt4_tiktoken_equality() { 53 | GPT4Tokenizer::initialize(); // pre-initialize the tokenizer static data 54 | 55 | for text in TEST_STRINGS.iter() { 56 | println!("test_gpt4_tiktoken_equality: testing with text: {:?}", text); 57 | let text = unpack(text).unwrap(); 58 | test_gpt4_tiktoken_equality_inner(text); 59 | } 60 | } 61 | } 62 | --------------------------------------------------------------------------------