├── .gitignore ├── .rustfmt.toml ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── docs ├── group_size.md ├── quantization.md ├── temperature.md ├── tokenizer.md └── transformer.md ├── qwen3-cli ├── Cargo.toml └── src │ └── main.rs ├── qwen3-export ├── Cargo.toml ├── src │ ├── chat_template_exporter.rs │ ├── config_loader.rs │ ├── lib.rs │ ├── lora_merger.rs │ ├── model_exporter.rs │ ├── models │ │ ├── mod.rs │ │ └── qwen3.rs │ ├── tensor_reader.rs │ ├── tokenizer_exporter.rs │ └── utils.rs └── tests │ └── unit │ ├── config_loader_test.rs │ ├── model_exporter_test.rs │ └── tokenizer_exporter_test.rs └── qwen3-inference ├── Cargo.toml └── src ├── configuration.rs ├── generation.rs ├── layers.rs ├── lib.rs ├── models ├── mod.rs └── qwen3.rs ├── sampler.rs ├── tensor.rs ├── tokenizer.rs └── utils.rs /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | target/ 3 | DeepSeek* 4 | Qwen3-* 5 | .models 6 | 7 | notes.log 8 | 9 | *.py 10 | *.c -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width=120 2 | use_small_heuristics="Max" -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 4 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "1.1.3" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "anstream" 16 | version = "0.6.19" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" 19 | dependencies = [ 20 | "anstyle", 21 | "anstyle-parse", 22 | "anstyle-query", 23 | "anstyle-wincon", 24 | "colorchoice", 25 | "is_terminal_polyfill", 26 | "utf8parse", 27 | ] 28 | 29 | [[package]] 30 | name = "anstyle" 31 | version = "1.0.11" 32 | source = "registry+https://github.com/rust-lang/crates.io-index" 33 | checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" 34 | 35 | [[package]] 36 | name = "anstyle-parse" 37 | version = "0.2.7" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" 40 | dependencies = [ 41 | "utf8parse", 42 | ] 43 | 44 | [[package]] 45 | name = "anstyle-query" 46 | version = "1.1.3" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" 49 | dependencies = [ 50 | "windows-sys", 51 | ] 52 | 53 | [[package]] 54 | name = "anstyle-wincon" 55 | version = "3.0.9" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" 58 | dependencies = [ 59 | "anstyle", 60 | "once_cell_polyfill", 61 | "windows-sys", 62 | ] 63 | 64 | [[package]] 65 | name = "anyhow" 66 | version = "1.0.98" 67 | source = "registry+https://github.com/rust-lang/crates.io-index" 68 | checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" 69 | 70 | [[package]] 71 | name = "bitflags" 72 | 
version = "2.9.1" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" 75 | 76 | [[package]] 77 | name = "byteorder" 78 | version = "1.5.0" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 81 | 82 | [[package]] 83 | name = "cfg-if" 84 | version = "1.0.1" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" 87 | 88 | [[package]] 89 | name = "clap" 90 | version = "4.5.40" 91 | source = "registry+https://github.com/rust-lang/crates.io-index" 92 | checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" 93 | dependencies = [ 94 | "clap_builder", 95 | "clap_derive", 96 | ] 97 | 98 | [[package]] 99 | name = "clap_builder" 100 | version = "4.5.40" 101 | source = "registry+https://github.com/rust-lang/crates.io-index" 102 | checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" 103 | dependencies = [ 104 | "anstream", 105 | "anstyle", 106 | "clap_lex", 107 | "strsim", 108 | ] 109 | 110 | [[package]] 111 | name = "clap_derive" 112 | version = "4.5.40" 113 | source = "registry+https://github.com/rust-lang/crates.io-index" 114 | checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" 115 | dependencies = [ 116 | "heck", 117 | "proc-macro2", 118 | "quote", 119 | "syn", 120 | ] 121 | 122 | [[package]] 123 | name = "clap_lex" 124 | version = "0.7.5" 125 | source = "registry+https://github.com/rust-lang/crates.io-index" 126 | checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" 127 | 128 | [[package]] 129 | name = "colorchoice" 130 | version = "1.0.4" 131 | source = "registry+https://github.com/rust-lang/crates.io-index" 132 | checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" 133 | 134 | [[package]] 135 | name = "crossbeam-deque" 136 | version = "0.8.6" 137 | source = "registry+https://github.com/rust-lang/crates.io-index" 138 | checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" 139 | dependencies = [ 140 | "crossbeam-epoch", 141 | "crossbeam-utils", 142 | ] 143 | 144 | [[package]] 145 | name = "crossbeam-epoch" 146 | version = "0.9.18" 147 | source = "registry+https://github.com/rust-lang/crates.io-index" 148 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 149 | dependencies = [ 150 | "crossbeam-utils", 151 | ] 152 | 153 | [[package]] 154 | name = "crossbeam-utils" 155 | version = "0.8.21" 156 | source = "registry+https://github.com/rust-lang/crates.io-index" 157 | checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" 158 | 159 | [[package]] 160 | name = "either" 161 | version = "1.15.0" 162 | source = "registry+https://github.com/rust-lang/crates.io-index" 163 | checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" 164 | 165 | [[package]] 166 | name = "env_logger" 167 | version = "0.10.2" 168 | source = "registry+https://github.com/rust-lang/crates.io-index" 169 | checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" 170 | dependencies = [ 171 | "humantime", 172 | "is-terminal", 173 | "log", 174 | "regex", 175 | "termcolor", 176 | ] 177 | 178 | [[package]] 179 | name = "errno" 180 | version = "0.3.13" 181 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 182 | checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" 183 | dependencies = [ 184 | "libc", 185 | "windows-sys", 186 | ] 187 | 188 | [[package]] 189 | name = "fastrand" 190 | version = "2.3.0" 191 | source = "registry+https://github.com/rust-lang/crates.io-index" 192 | checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" 193 | 194 | [[package]] 195 | name = "getrandom" 196 | version = "0.3.3" 197 | source = "registry+https://github.com/rust-lang/crates.io-index" 198 | checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" 199 | dependencies = [ 200 | "cfg-if", 201 | "libc", 202 | "r-efi", 203 | "wasi", 204 | ] 205 | 206 | [[package]] 207 | name = "heck" 208 | version = "0.5.0" 209 | source = "registry+https://github.com/rust-lang/crates.io-index" 210 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" 211 | 212 | [[package]] 213 | name = "hermit-abi" 214 | version = "0.5.2" 215 | source = "registry+https://github.com/rust-lang/crates.io-index" 216 | checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" 217 | 218 | [[package]] 219 | name = "humantime" 220 | version = "2.2.0" 221 | source = "registry+https://github.com/rust-lang/crates.io-index" 222 | checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" 223 | 224 | [[package]] 225 | name = "is-terminal" 226 | version = "0.4.16" 227 | source = "registry+https://github.com/rust-lang/crates.io-index" 228 | checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" 229 | dependencies = [ 230 | "hermit-abi", 231 | "libc", 232 | "windows-sys", 233 | ] 234 | 235 | [[package]] 236 | name = "is_terminal_polyfill" 237 | version = "1.70.1" 238 | source = "registry+https://github.com/rust-lang/crates.io-index" 239 | checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" 240 | 241 | [[package]] 242 | name = "itoa" 243 | version = "1.0.15" 244 | source = "registry+https://github.com/rust-lang/crates.io-index" 245 | checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" 246 | 247 | [[package]] 248 | name = "libc" 249 | version = "0.2.174" 250 | source = "registry+https://github.com/rust-lang/crates.io-index" 251 | checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" 252 | 253 | [[package]] 254 | name = "linux-raw-sys" 255 | version = "0.9.4" 256 | source = "registry+https://github.com/rust-lang/crates.io-index" 257 | checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" 258 | 259 | [[package]] 260 | name = "log" 261 | version = "0.4.27" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" 264 | 265 | [[package]] 266 | name = "memchr" 267 | version = "2.7.5" 268 | source = "registry+https://github.com/rust-lang/crates.io-index" 269 | checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" 270 | 271 | [[package]] 272 | name = "memmap2" 273 | version = "0.9.5" 274 | source = "registry+https://github.com/rust-lang/crates.io-index" 275 | checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" 276 | dependencies = [ 277 | "libc", 278 | ] 279 | 280 | [[package]] 281 | name = "once_cell" 282 | version = "1.21.3" 283 | source = "registry+https://github.com/rust-lang/crates.io-index" 284 | checksum = 
"42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" 285 | 286 | [[package]] 287 | name = "once_cell_polyfill" 288 | version = "1.70.1" 289 | source = "registry+https://github.com/rust-lang/crates.io-index" 290 | checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" 291 | 292 | [[package]] 293 | name = "proc-macro2" 294 | version = "1.0.95" 295 | source = "registry+https://github.com/rust-lang/crates.io-index" 296 | checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" 297 | dependencies = [ 298 | "unicode-ident", 299 | ] 300 | 301 | [[package]] 302 | name = "quote" 303 | version = "1.0.40" 304 | source = "registry+https://github.com/rust-lang/crates.io-index" 305 | checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" 306 | dependencies = [ 307 | "proc-macro2", 308 | ] 309 | 310 | [[package]] 311 | name = "qwen3-cli" 312 | version = "0.1.0" 313 | dependencies = [ 314 | "anyhow", 315 | "clap", 316 | "env_logger", 317 | "log", 318 | "qwen3-export", 319 | "qwen3-inference", 320 | ] 321 | 322 | [[package]] 323 | name = "qwen3-export" 324 | version = "0.1.0" 325 | dependencies = [ 326 | "anyhow", 327 | "byteorder", 328 | "log", 329 | "memmap2", 330 | "rayon", 331 | "safetensors", 332 | "serde", 333 | "serde_json", 334 | "tempfile", 335 | ] 336 | 337 | [[package]] 338 | name = "qwen3-inference" 339 | version = "0.1.0" 340 | dependencies = [ 341 | "anyhow", 342 | "byteorder", 343 | "log", 344 | "memmap2", 345 | "rayon", 346 | ] 347 | 348 | [[package]] 349 | name = "r-efi" 350 | version = "5.3.0" 351 | source = "registry+https://github.com/rust-lang/crates.io-index" 352 | checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" 353 | 354 | [[package]] 355 | name = "rayon" 356 | version = "1.10.0" 357 | source = "registry+https://github.com/rust-lang/crates.io-index" 358 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" 359 | dependencies = [ 360 | "either", 361 | "rayon-core", 362 | ] 363 | 364 | [[package]] 365 | name = "rayon-core" 366 | version = "1.12.1" 367 | source = "registry+https://github.com/rust-lang/crates.io-index" 368 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 369 | dependencies = [ 370 | "crossbeam-deque", 371 | "crossbeam-utils", 372 | ] 373 | 374 | [[package]] 375 | name = "regex" 376 | version = "1.11.1" 377 | source = "registry+https://github.com/rust-lang/crates.io-index" 378 | checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" 379 | dependencies = [ 380 | "aho-corasick", 381 | "memchr", 382 | "regex-automata", 383 | "regex-syntax", 384 | ] 385 | 386 | [[package]] 387 | name = "regex-automata" 388 | version = "0.4.9" 389 | source = "registry+https://github.com/rust-lang/crates.io-index" 390 | checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" 391 | dependencies = [ 392 | "aho-corasick", 393 | "memchr", 394 | "regex-syntax", 395 | ] 396 | 397 | [[package]] 398 | name = "regex-syntax" 399 | version = "0.8.5" 400 | source = "registry+https://github.com/rust-lang/crates.io-index" 401 | checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" 402 | 403 | [[package]] 404 | name = "rustix" 405 | version = "1.0.7" 406 | source = "registry+https://github.com/rust-lang/crates.io-index" 407 | checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" 408 | dependencies = [ 409 | "bitflags", 410 | "errno", 411 | "libc", 412 | 
"linux-raw-sys", 413 | "windows-sys", 414 | ] 415 | 416 | [[package]] 417 | name = "ryu" 418 | version = "1.0.20" 419 | source = "registry+https://github.com/rust-lang/crates.io-index" 420 | checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" 421 | 422 | [[package]] 423 | name = "safetensors" 424 | version = "0.4.5" 425 | source = "registry+https://github.com/rust-lang/crates.io-index" 426 | checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" 427 | dependencies = [ 428 | "serde", 429 | "serde_json", 430 | ] 431 | 432 | [[package]] 433 | name = "serde" 434 | version = "1.0.219" 435 | source = "registry+https://github.com/rust-lang/crates.io-index" 436 | checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" 437 | dependencies = [ 438 | "serde_derive", 439 | ] 440 | 441 | [[package]] 442 | name = "serde_derive" 443 | version = "1.0.219" 444 | source = "registry+https://github.com/rust-lang/crates.io-index" 445 | checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" 446 | dependencies = [ 447 | "proc-macro2", 448 | "quote", 449 | "syn", 450 | ] 451 | 452 | [[package]] 453 | name = "serde_json" 454 | version = "1.0.140" 455 | source = "registry+https://github.com/rust-lang/crates.io-index" 456 | checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" 457 | dependencies = [ 458 | "itoa", 459 | "memchr", 460 | "ryu", 461 | "serde", 462 | ] 463 | 464 | [[package]] 465 | name = "strsim" 466 | version = "0.11.1" 467 | source = "registry+https://github.com/rust-lang/crates.io-index" 468 | checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" 469 | 470 | [[package]] 471 | name = "syn" 472 | version = "2.0.104" 473 | source = "registry+https://github.com/rust-lang/crates.io-index" 474 | checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" 475 | dependencies = [ 476 | "proc-macro2", 477 | "quote", 478 | "unicode-ident", 479 | ] 480 | 481 | [[package]] 482 | name = "tempfile" 483 | version = "3.20.0" 484 | source = "registry+https://github.com/rust-lang/crates.io-index" 485 | checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" 486 | dependencies = [ 487 | "fastrand", 488 | "getrandom", 489 | "once_cell", 490 | "rustix", 491 | "windows-sys", 492 | ] 493 | 494 | [[package]] 495 | name = "termcolor" 496 | version = "1.4.1" 497 | source = "registry+https://github.com/rust-lang/crates.io-index" 498 | checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" 499 | dependencies = [ 500 | "winapi-util", 501 | ] 502 | 503 | [[package]] 504 | name = "unicode-ident" 505 | version = "1.0.18" 506 | source = "registry+https://github.com/rust-lang/crates.io-index" 507 | checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" 508 | 509 | [[package]] 510 | name = "utf8parse" 511 | version = "0.2.2" 512 | source = "registry+https://github.com/rust-lang/crates.io-index" 513 | checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" 514 | 515 | [[package]] 516 | name = "wasi" 517 | version = "0.14.2+wasi-0.2.4" 518 | source = "registry+https://github.com/rust-lang/crates.io-index" 519 | checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" 520 | dependencies = [ 521 | "wit-bindgen-rt", 522 | ] 523 | 524 | [[package]] 525 | name = "winapi-util" 526 | version = "0.1.9" 527 | source = "registry+https://github.com/rust-lang/crates.io-index" 528 | 
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" 529 | dependencies = [ 530 | "windows-sys", 531 | ] 532 | 533 | [[package]] 534 | name = "windows-sys" 535 | version = "0.59.0" 536 | source = "registry+https://github.com/rust-lang/crates.io-index" 537 | checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" 538 | dependencies = [ 539 | "windows-targets", 540 | ] 541 | 542 | [[package]] 543 | name = "windows-targets" 544 | version = "0.52.6" 545 | source = "registry+https://github.com/rust-lang/crates.io-index" 546 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 547 | dependencies = [ 548 | "windows_aarch64_gnullvm", 549 | "windows_aarch64_msvc", 550 | "windows_i686_gnu", 551 | "windows_i686_gnullvm", 552 | "windows_i686_msvc", 553 | "windows_x86_64_gnu", 554 | "windows_x86_64_gnullvm", 555 | "windows_x86_64_msvc", 556 | ] 557 | 558 | [[package]] 559 | name = "windows_aarch64_gnullvm" 560 | version = "0.52.6" 561 | source = "registry+https://github.com/rust-lang/crates.io-index" 562 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 563 | 564 | [[package]] 565 | name = "windows_aarch64_msvc" 566 | version = "0.52.6" 567 | source = "registry+https://github.com/rust-lang/crates.io-index" 568 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 569 | 570 | [[package]] 571 | name = "windows_i686_gnu" 572 | version = "0.52.6" 573 | source = "registry+https://github.com/rust-lang/crates.io-index" 574 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 575 | 576 | [[package]] 577 | name = "windows_i686_gnullvm" 578 | version = "0.52.6" 579 | source = "registry+https://github.com/rust-lang/crates.io-index" 580 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 581 | 582 | [[package]] 583 | name = "windows_i686_msvc" 584 | version = "0.52.6" 585 | source = "registry+https://github.com/rust-lang/crates.io-index" 586 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 587 | 588 | [[package]] 589 | name = "windows_x86_64_gnu" 590 | version = "0.52.6" 591 | source = "registry+https://github.com/rust-lang/crates.io-index" 592 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 593 | 594 | [[package]] 595 | name = "windows_x86_64_gnullvm" 596 | version = "0.52.6" 597 | source = "registry+https://github.com/rust-lang/crates.io-index" 598 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 599 | 600 | [[package]] 601 | name = "windows_x86_64_msvc" 602 | version = "0.52.6" 603 | source = "registry+https://github.com/rust-lang/crates.io-index" 604 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 605 | 606 | [[package]] 607 | name = "wit-bindgen-rt" 608 | version = "0.39.0" 609 | source = "registry+https://github.com/rust-lang/crates.io-index" 610 | checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" 611 | dependencies = [ 612 | "bitflags", 613 | ] 614 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = [ 4 | "qwen3-cli", 5 | "qwen3-export", 6 | "qwen3-inference" 7 | ] 8 | 9 | [workspace.package] 10 | version = "0.1.0" 11 | authors = ["Ilya Builuk "] 12 | repository = 
"https://github.com/reinterpretcat/qwen3-rs" 13 | license = "Apache-2.0" 14 | keywords = ["LLM", "qwen3"] 15 | categories = ["LLM"] 16 | edition = "2024" 17 | 18 | [workspace.dependencies] 19 | anyhow = "1.0" 20 | byteorder = "1.5" 21 | clap = { version = "4.0", features = ["derive"] } 22 | rayon = "1.8" 23 | serde_json = "1.0" 24 | serde = { version = "1.0", features = ["derive"] } 25 | safetensors = "0.4" 26 | memmap2 = "0.9" 27 | log = "0.4" 28 | env_logger = "0.10" 29 | 30 | qwen3-export = { path = "qwen3-export" } 31 | qwen3-inference = { path = "qwen3-inference" } 32 | 33 | [profile.release] 34 | opt-level = 3 35 | lto = true 36 | codegen-units = 1 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2025 Ilya Builuk 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | **qwen3-rs** is an educational Rust project for exploring and running Qwen3 language family models. It is designed to be clear, modular, and approachable for learners, with minimal dependencies and many core algorithms reimplemented from scratch for transparency. 4 | 5 | > **Note:** Parts of this codebase, including documentation and core algorithms, were generated or assisted by large language models (LLMs) to accelerate development and improve educational clarity. As a starting reference, the project [qwen3.c](https://github.com/adriancable/qwen3.c) was used for understanding model internals and file formats. 6 | 7 | 8 | ## Project Goals 9 | 10 | - **Educational:** Learn how transformer architectures, quantization, and efficient inference work in Rust. 11 | - **Minimal Dependencies:** Most algorithms (tokenization, quantization, sampling, etc.) are implemented from scratch—no heavy ML or Python bindings. 
12 | - **Modular:** Core library logic is separated from CLI tools for clarity and maintainability. 13 | - **Efficiency:** Uses memory mapping and zero-copy techniques for handling large model files. 14 | 15 | ## Workspace Structure 16 | 17 | ``` 18 | qwen3-rs/ 19 | ├── docs # LLM generated docs for key components 20 | ├── Cargo.toml # Workspace configuration 21 | ├── qwen3-cli/ # Command-line interface crate 22 | ├── qwen3-export/ # Model export crate 23 | ├── qwen3-inference/ # LLM inference crate 24 | ``` 25 | 26 | ## How to Use 27 | 28 | ### 1. Get a HuggingFace Qwen3 model 29 | 30 | Currently, the `Qwen3ForCausalLM` architecture is supported. 31 | 32 | ```bash 33 | git clone https://huggingface.co/Qwen/Qwen3-0.6B 34 | # Or try larger/alternative models: 35 | # git clone https://huggingface.co/Qwen/Qwen3-4B 36 | # git clone https://huggingface.co/Qwen/Qwen3-8B 37 | # git clone https://huggingface.co/deepseek-ai/DeepSeek-R1-0528-Qwen3-8B 38 | ``` 39 | 40 | **NOTE**: `Low Rank Adaptation (LoRA)` is supported: copy `adapter_config.json` and the adapter safetensors files to the same folder 41 | where the base model is located, and it will be automatically detected (tested with `unsloth` output). 42 | 43 | 44 | ### 2. Build and run the exporter 45 | 46 | ```bash 47 | cargo build --release -p qwen3-cli 48 | 49 | # Export a HuggingFace model to quantized checkpoint format 50 | cargo run --release -p qwen3-cli -- export /path/to/model /path/to/output.bin --group-size 64 51 | ``` 52 | 53 | ### 3. Run inference 54 | 55 | In chat mode with default parameters: 56 | 57 | ```bash 58 | cargo run --release -p qwen3-cli -- inference /path/to/output.bin -m chat 59 | ``` 60 | 61 | ## CLI Commands and Options 62 | 63 | ### `export` 64 | Exports a HuggingFace Qwen3 model to a custom binary format for efficient Rust inference. 65 | 66 | **Usage:** 67 | ```bash 68 | qwen3 export <MODEL_PATH> <OUTPUT_PATH> [--group-size <SIZE>] 69 | ``` 70 | - `MODEL_PATH`: Path to HuggingFace model directory (must contain config.json, *.safetensors, tokenizer.json) 71 | - `OUTPUT_PATH`: Output path for the binary model file 72 | - `--group-size`, `-g`: Quantization group size (default: 64) 73 | 74 | ### `inference` 75 | Runs inference on a binary Qwen3 model. 76 | 77 | **Usage:** 78 | ```bash 79 | qwen3 inference <MODEL_PATH> [options] 80 | ``` 81 | **Options:** 82 | - `--temperature`, `-t`: Sampling temperature (default: 1.0) 83 | - `--topp`, `-p`: Top-p nucleus sampling (default: 0.9) 84 | - `--seed`, `-s`: Random seed 85 | - `--context`, `-c`: Context window size (default: max_seq_len) 86 | - `--mode`, `-m`: Mode: `generate` or `chat` (default: chat) 87 | - `--input`, `-i`: Input prompt 88 | - `--system`, `-y`: System prompt (for chat mode) 89 | - `--reasoning`, `-r`: Reasoning mode: 0=no thinking, 1=thinking (default: 0) 90 | 91 | -------------------------------------------------------------------------------- /docs/group_size.md: -------------------------------------------------------------------------------- 1 | Group size is a critical hyperparameter in quantization that affects both accuracy and performance. Let's explain how it's typically determined. 2 | 3 | ## **Common Group Size Values** 4 | 5 | ```rust 6 | // Typical group sizes in practice: 7 | let common_group_sizes = [32, 64, 128, 256, 512, 1024]; 8 | 9 | // Most popular choices: 10 | let group_size = 128; // Good balance for most models 11 | let group_size = 64; // This project's CLI default (`--group-size`) 12 | ``` 13 | 14 | --- 15 | 16 | ## **Factors That Determine Group Size** 17 | 18 | ### **1. 
Model Architecture Constraints** 19 | 20 | In the code we have this consideration: 21 | 22 | ```rust 23 | pub fn new(config: ModelConfig, mut group_size: usize) -> Self { 24 | // Adjust group size to fit hidden_dim 25 | while config.dim % group_size as i32 != 0 { 26 | group_size /= 2; // Keep halving until it divides evenly 27 | } 28 | Self { config, group_size } 29 | } 30 | ``` 31 | 32 | **Why this matters:** 33 | ```rust 34 | // Example: LLaMA model 35 | let hidden_dim = 4096; // Must divide evenly 36 | 37 | // Good group sizes: 32, 64, 128, 256, 512, 1024, 2048, 4096 38 | // Bad group sizes: 100, 300, 500 (don't divide evenly) 39 | 40 | // The code automatically fixes this: 41 | let mut group_size = 300; // Bad choice 42 | // After adjustment: 300 → 150 → 75 → 37 → 18 → 9 → 4 43 | // Final group_size = 4 (algorithm can be improved) 44 | ``` 45 | 46 | ### **2. Accuracy vs Compression Trade-off** 47 | 48 | ```rust 49 | // Smaller groups = Better accuracy, More overhead 50 | let small_groups = 32; // High precision, more scales to store 51 | 52 | // Larger groups = Less accuracy, Less overhead 53 | let large_groups = 1024; // Lower precision, fewer scales 54 | ``` 55 | 56 | **Example with real numbers:** 57 | 58 | ```rust 59 | // Tensor: [4096] weights 60 | // Group size 32: 4096/32 = 128 groups → 128 scales (512 bytes overhead) 61 | // Group size 256: 4096/256 = 16 groups → 16 scales (64 bytes overhead) 62 | // Group size 1024: 4096/1024 = 4 groups → 4 scales (16 bytes overhead) 63 | ``` 64 | 65 | ### **3. Hardware Optimization** 66 | 67 | ```rust 68 | // Modern CPUs prefer certain sizes for vectorization 69 | let cpu_friendly = [32, 64, 128, 256]; // Align with SIMD instructions 70 | 71 | // GPU memory coalescing (if using GPU) 72 | let gpu_friendly = [128, 256, 512]; // Align with warp/wavefront sizes 73 | ``` 74 | 75 | --- 76 | 77 | ## **How Group Size Affects Quality** 78 | 79 | Let's show with a concrete example: 80 | 81 | ```rust 82 | // Example weight tensor with mixed scales 83 | let weights = [ 84 | // Large values section 85 | 5.0, -4.2, 3.8, -3.1, 4.5, -2.9, 3.2, -4.1, 86 | // Small values section 87 | 0.01, -0.008, 0.012, -0.015, 0.009, -0.011, 0.007, -0.013 88 | ]; 89 | ``` 90 | 91 | ### **Large Group Size (group_size = 16, one group):** 92 | ```rust 93 | let max_abs = 5.0; // Dominated by large values 94 | let scale = 5.0 / 127.0 = 0.0394; 95 | 96 | // Large values quantize well: 97 | 5.0 / 0.0394 = 127 ✓ 98 | 99 | // Small values lose precision: 100 | 0.01 / 0.0394 = 0.25 → 0 // Becomes zero! ❌ 101 | ``` 102 | 103 | ### **Small Group Size (group_size = 8, two groups):** 104 | ```rust 105 | // Group 1: [5.0, -4.2, 3.8, -3.1, 4.5, -2.9, 3.2, -4.1] 106 | let scale1 = 5.0 / 127.0 = 0.0394; 107 | 108 | // Group 2: [0.01, -0.008, 0.012, -0.015, 0.009, -0.011, 0.007, -0.013] 109 | let scale2 = 0.015 / 127.0 = 0.000118; 110 | 111 | // Now small values preserve precision: 112 | 0.01 / 0.000118 = 85 ✓ // Good quantization! 
113 | ``` 114 | 115 | --- 116 | 117 | ## **Different Model Sizes** 118 | ```rust 119 | // Small models (7B parameters) 120 | let group_size = 64; // Can afford smaller groups 121 | 122 | // Medium models (13B-30B parameters) 123 | let group_size = 128; // Standard choice 124 | 125 | // Large models (70B+ parameters) 126 | let group_size = 256; // Reduce memory overhead 127 | ``` 128 | -------------------------------------------------------------------------------- /docs/quantization.md: -------------------------------------------------------------------------------- 1 | Let's explain quantization with simple, concrete examples that show exactly what's happening. 2 | 3 | ## **What is Quantization in Simple Terms?** 4 | 5 | Quantization is like **rounding numbers to save space**, but doing it smartly to preserve accuracy. 6 | 7 | ```rust 8 | // BEFORE: High precision, lots of space 9 | let weights = [0.123456, -0.567890, 0.234567, -0.345678]; // 4 × 4 bytes = 16 bytes 10 | 11 | // AFTER: Lower precision, less space 12 | let weights = [25, -115, 47, -70]; // 4 × 1 byte = 4 bytes (4x smaller!) 13 | let scale = 0.005; // One shared number to "uncompress" 14 | ``` 15 | 16 | --- 17 | 18 | ## **Simple Example: Without Groups (Bad Approach)** 19 | 20 | Let's say we have these weights: 21 | ```rust 22 | let weights = [0.1, -0.8, 0.05, -0.02]; 23 | ``` 24 | 25 | ### **Step 1: Find overall maximum** 26 | ```rust 27 | let max_abs = 0.8; // Largest absolute value 28 | ``` 29 | 30 | ### **Step 2: Calculate scale** 31 | ```rust 32 | let scale = max_abs / 127.0; // 0.8 / 127 = 0.0063 33 | ``` 34 | 35 | ### **Step 3: Quantize each weight** 36 | ```rust 37 | // Formula: quantized = round(weight / scale) 38 | let weight_0 = 0.1 / 0.0063 = 15.87 → 16 39 | let weight_1 = -0.8 / 0.0063 = -126.98 → -127 40 | let weight_2 = 0.05 / 0.0063 = 7.94 → 8 41 | let weight_3 = -0.02 / 0.0063 = -3.17 → -3 42 | 43 | // Result: [16, -127, 8, -3] (i8 values) 44 | ``` 45 | 46 | ### **Step 4: Check accuracy (dequantization)** 47 | ```rust 48 | // To get back original: quantized * scale 49 | let recovered_0 = 16 * 0.0063 = 0.101 (original: 0.1) ✓ Good 50 | let recovered_1 = -127 * 0.0063 = -0.800 (original: -0.8) ✓ Good 51 | let recovered_2 = 8 * 0.0063 = 0.0504 (original: 0.05) ✓ Good 52 | let recovered_3 = -3 * 0.0063 = -0.0189 (original: -0.02) ✓ Good 53 | ``` 54 | 55 | **This works okay, but...** 56 | 57 | --- 58 | 59 | ## **The Problem: Mixed Scales** 60 | 61 | What if we have weights with very different ranges? 62 | 63 | ```rust 64 | let weights = [ 65 | // Group 1: Large values 66 | 10.5, -8.2, 9.1, -7.8, 67 | 68 | // Group 2: Small values 69 | 0.001, -0.002, 0.0015, -0.0008 70 | ]; 71 | ``` 72 | 73 | ### **Using single scale (bad):** 74 | ```rust 75 | let max_abs = 10.5; // Dominated by large values 76 | let scale = 10.5 / 127.0 = 0.0827; 77 | 78 | // Large values quantize well: 79 | 10.5 / 0.0827 = 127 ✓ 80 | 81 | // Small values lose ALL precision: 82 | 0.001 / 0.0827 = 0.012 → 0 // Becomes zero! ❌ 83 | 0.002 / 0.0827 = 0.024 → 0 // Becomes zero! ❌ 84 | ``` 85 | 86 | **Result**: Small weights disappear completely! 
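Before moving on to the fix, here is a minimal, self-contained sketch of the single-scale scheme just described (the function name is illustrative and not taken from the project's exporter code); running it shows the large weights surviving while every small weight collapses to zero:

```rust
/// Quantize a whole slice with ONE shared scale (the naive scheme above).
/// Returns (int8 values, scale). Illustration only.
fn quantize_single_scale(weights: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = weights.iter().fold(0.0f32, |m, &w| m.max(w.abs()));
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let q = weights.iter().map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8).collect();
    (q, scale)
}

fn main() {
    let weights = [10.5f32, -8.2, 9.1, -7.8, 0.001, -0.002, 0.0015, -0.0008];
    let (q, scale) = quantize_single_scale(&weights);
    println!("scale = {scale:.5}, quantized = {q:?}");
    // scale = 0.08268, quantized = [127, -99, 110, -94, 0, 0, 0, 0]
    // -> the small second half of the tensor is gone
}
```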
87 | 88 | --- 89 | 90 | ## **Solution: Groups with Separate Scales** 91 | 92 | Instead of one scale for everything, use **different scales for different groups**: 93 | 94 | ```rust 95 | let weights = [ 96 | // Group 1: Large values [indices 0-3] 97 | 10.5, -8.2, 9.1, -7.8, 98 | 99 | // Group 2: Small values [indices 4-7] 100 | 0.001, -0.002, 0.0015, -0.0008 101 | ]; 102 | 103 | let group_size = 4; // Process 4 weights at a time 104 | ``` 105 | 106 | ### **Group 1 processing:** 107 | ```rust 108 | let group1 = [10.5, -8.2, 9.1, -7.8]; 109 | let group1_max = 10.5; 110 | let scale1 = 10.5 / 127.0 = 0.0827; 111 | 112 | // Quantize group 1: 113 | let q1 = [127, -99, 110, -94]; // Good precision! 114 | ``` 115 | 116 | ### **Group 2 processing:** 117 | ```rust 118 | let group2 = [0.001, -0.002, 0.0015, -0.0008]; 119 | let group2_max = 0.002; 120 | let scale2 = 0.002 / 127.0 = 0.0000157; // Much smaller scale! 121 | 122 | // Quantize group 2: 123 | let q2 = [64, -127, 95, -51]; // Good precision preserved! 124 | ``` 125 | 126 | ### **Verify accuracy:** 127 | ```rust 128 | // Group 1 recovery: 129 | 127 * 0.0827 = 10.51 (original: 10.5) ✓ 130 | -99 * 0.0827 = -8.19 (original: -8.2) ✓ 131 | 132 | // Group 2 recovery: 133 | 64 * 0.0000157 = 0.001 (original: 0.001) ✓ 134 | -127 * 0.0000157 = -0.002 (original: -0.002) ✓ 135 | ``` 136 | 137 | **Much better!** Both large and small values preserve precision. 138 | 139 | --- 140 | 141 | ## **`model_exporter.rs` Code Step by Step** 142 | 143 | Let's trace through the actual code with a concrete example: 144 | 145 | ```rust 146 | // Example input tensor 147 | let weights = [2.0, -1.5, 0.8, -0.3, 0.01, -0.02, 0.005, -0.001]; 148 | let group_size = 4; 149 | ``` 150 | 151 | ### **Step 1: Split into groups** 152 | ```rust 153 | // Your code: (0..num_groups).into_par_iter() 154 | let num_groups = weights.len() / group_size; // 8 / 4 = 2 groups 155 | 156 | // Group 0: indices 0-3 → [2.0, -1.5, 0.8, -0.3] 157 | // Group 1: indices 4-7 → [0.01, -0.02, 0.005, -0.001] 158 | ``` 159 | 160 | ### **Step 2: Process each group in parallel** 161 | ```rust 162 | // Group 0 processing: 163 | let group = [2.0, -1.5, 0.8, -0.3]; 164 | 165 | // Find max absolute value 166 | let group_max = group.iter().map(|&x| x.abs()).fold(0.0f32, f32::max); 167 | // group_max = max(2.0, 1.5, 0.8, 0.3) = 2.0 168 | 169 | // Calculate scale 170 | let scale = if group_max > 0.0 { 171 | group_max / 127.0 // 2.0 / 127 = 0.0157 172 | } else { 173 | 1.0 174 | }; 175 | ``` 176 | 177 | ### **Step 3: Quantize each weight in group** 178 | ```rust 179 | let mut group_int8 = Vec::with_capacity(4); 180 | 181 | for &weight in group.iter() { // [2.0, -1.5, 0.8, -0.3] 182 | let quantized = (weight / scale).round().clamp(-127.0, 127.0) as i8; 183 | group_int8.push(quantized); 184 | } 185 | 186 | // Calculations: 187 | // 2.0 / 0.0157 = 127.4 → 127 188 | // -1.5 / 0.0157 = -95.5 → -96 189 | // 0.8 / 0.0157 = 51.0 → 51 190 | // -0.3 / 0.0157 = -19.1 → -19 191 | 192 | // group_int8 = [127, -96, 51, -19] 193 | ``` 194 | 195 | ### **Step 4: Calculate error** 196 | ```rust 197 | let mut group_error = 0.0f32; 198 | 199 | for (quantized, original) in group_int8.iter().zip(group.iter()) { 200 | let dequantized = *quantized as f32 * scale; 201 | let error = (dequantized - original).abs(); 202 | group_error = group_error.max(error); 203 | } 204 | 205 | // Check errors: 206 | // 127 * 0.0157 = 1.994 vs 2.0 → error = 0.006 207 | // -96 * 0.0157 = -1.507 vs -1.5 → error = 0.007 ← max error 208 | ``` 209 | 210 | ### **Step 5: Same 
process for Group 1** 211 | ```rust 212 | // Group 1: [0.01, -0.02, 0.005, -0.001] 213 | // group_max = 0.02 214 | // scale = 0.02 / 127 = 0.000157 215 | // quantized = [64, -127, 32, -6] 216 | ``` 217 | 218 | ### **Step 6: Combine results** 219 | ```rust 220 | // Final result: 221 | let int8_data = [127, -96, 51, -19, 64, -127, 32, -6]; // 8 bytes 222 | let scales = [0.0157, 0.000157]; // 2 scales (8 bytes) 223 | // Total: 16 bytes vs original 32 bytes = 50% compression 224 | ``` 225 | 226 | --- 227 | 228 | ## **Why This Works So Well** 229 | 230 | ### **Memory Savings:** 231 | ```rust 232 | // Original: 8 weights × 4 bytes = 32 bytes 233 | let original = [2.0_f32, -1.5, 0.8, -0.3, 0.01, -0.02, 0.005, -0.001]; 234 | 235 | // Quantized: 8 weights × 1 byte + 2 scales × 4 bytes = 16 bytes 236 | let quantized = [127_i8, -96, 51, -19, 64, -127, 32, -6]; // 8 bytes 237 | let scales = [0.0157_f32, 0.000157]; // 8 bytes 238 | // Total: 50% size reduction! 239 | ``` 240 | 241 | ### **Precision Preservation:** 242 | ```rust 243 | // Without groups: Small values → 0 (lost!) 244 | // With groups: Small values → [64, -127, 32, -6] (preserved!) 245 | ``` 246 | 247 | --- 248 | 249 | ## **Real LLM Example** 250 | 251 | For a real transformer layer weight matrix: 252 | 253 | ```rust 254 | // Attention weight matrix: [4096, 4096] = 16M parameters 255 | // group_size = 256 256 | // num_groups = 16M / 256 = 65,536 groups 257 | 258 | // Each group gets its own scale → better precision across the huge matrix 259 | // Parallel processing → uses all CPU cores 260 | // Memory efficient → process one group at a time 261 | ``` 262 | 263 | **Result**: 70B parameter models compress from 280GB → 70GB with minimal accuracy loss! 264 | 265 | The magic is that **different parts of neural networks have different value ranges**, and group-wise quantization adapts to preserve precision everywhere. 🎯 -------------------------------------------------------------------------------- /docs/temperature.md: -------------------------------------------------------------------------------- 1 | # Temperature sampling 2 | 3 | The **theoretical range for temperature is [0, ∞)**, but in practice, most useful values fall within a much smaller range. 4 | 5 | Here's what's happening with different temperature values: 6 | 7 | ## Temperature Ranges and Effects 8 | 9 | **Temperature = 0**: Greedy sampling (deterministic, always picks highest probability token) 10 | 11 | **Temperature ∈ (0, 1)**: 12 | - Makes the distribution more "sharp" (concentrates probability on high-scoring tokens) 13 | - Values like 0.1-0.8 are common for more focused/conservative generation 14 | 15 | **Temperature = 1**: 16 | - No scaling applied (uses raw softmax probabilities) 17 | - This is often considered the "neutral" baseline 18 | 19 | **Temperature > 1**: 20 | - Makes distribution more "flat" (spreads probability more evenly) 21 | - Values like 1.2-2.0 can work for more creative/diverse generation 22 | - Very high values (>3.0) tend to produce mostly random text 23 | 24 | ## Why Very High Temperatures Cause Issues 25 | 26 | When temperature gets very large (say, >5.0), the logits become very small after division, and the softmax essentially becomes uniform. This means you're sampling almost randomly from your vocabulary, which produces nonsensical text. 
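To make the scaling concrete, here is a minimal, self-contained Rust sketch of what the text above describes: divide the logits by the temperature, then apply softmax. The function name is illustrative and not copied from the project's `sampler.rs`:

```rust
/// Softmax over logits scaled by 1/temperature (temperature > 0).
/// Lower temperature sharpens the distribution, higher temperature flattens it.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    // Subtract the max before exponentiating for numerical stability.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let logits = [2.0f32, 1.0, 0.5, -1.0];
    for t in [0.2f32, 1.0, 5.0] {
        println!("T = {t}: {:?}", softmax_with_temperature(&logits, t));
    }
    // T = 0.2 puts almost all probability on the top token (near-greedy),
    // T = 1.0 keeps the raw softmax, T = 5.0 is close to uniform (near-random).
}
```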
27 | 28 | ## Practical Recommendations 29 | 30 | Most practitioners use temperature values in these ranges: 31 | - **Conservative/Focused**: 0.1 - 0.7 32 | - **Balanced**: 0.7 - 1.2 33 | - **Creative/Diverse**: 1.2 - 2.0 34 | - **Experimental**: 2.0 - 3.0 35 | -------------------------------------------------------------------------------- /docs/tokenizer.md: -------------------------------------------------------------------------------- 1 | # Understanding Tokenization in Large Language Models 2 | 3 | Let's explain tokenization from the LLM theory perspective and why the `tokenizer_exporter.rs` code is essential for inference. 4 | 5 | --- 6 | 7 | ## **What Are Tokens in LLMs?** 8 | 9 | **Tokens are the fundamental units** that LLMs understand. They're like "words" in the model's vocabulary, but more flexible: 10 | 11 | ```rust 12 | // Human text: "Hello, world!" 13 | // LLM sees: [15496, 11, 1917, 0] // Token IDs 14 | 15 | // Each number represents a learned piece of language: 16 | // 15496 -> "Hello" 17 | // 11 -> "," 18 | // 1917 -> " world" 19 | // 0 -> "!" 20 | ``` 21 | 22 | **Why not just use characters or words?** 23 | - **Characters**: Too granular (millions of combinations) 24 | - **Words**: Too rigid (can't handle "unhappiness" if only learned "happy") 25 | - **Tokens**: Perfect middle ground (subword pieces) 26 | 27 | --- 28 | 29 | ## **Byte Pair Encoding (BPE): The Core Algorithm** 30 | 31 | Modern LLMs use **BPE** to learn optimal token boundaries from training data. 32 | 33 | ### **How BPE Training Works:** 34 | 35 | ```python 36 | # 1. Start with character-level vocabulary 37 | vocab = {"h", "e", "l", "o", "w", "r", "d"} 38 | 39 | # 2. Count character pair frequencies in training data 40 | pairs = { 41 | "he": 1000, # "he" appears 1000 times 42 | "ll": 800, # "ll" appears 800 times 43 | "lo": 600, # "lo" appears 600 times 44 | } 45 | 46 | # 3. Merge most frequent pair 47 | vocab.add("he") # Now "hello" becomes "he" + "l" + "l" + "o" 48 | 49 | # 4. Repeat until desired vocabulary size (e.g., 50,000 tokens) 50 | ``` 51 | 52 | ### **Why This Is Brilliant:** 53 | 54 | ```rust 55 | // Common words become single tokens: 56 | "the" -> [1965] // High frequency = single token 57 | 58 | // Rare words get decomposed: 59 | "antidisestablishmentarianism" -> [4523, 1234, 8901, 2345] // Multiple tokens 60 | 61 | // Unknown words still work: 62 | "supercalifragilisticexpialidocious" -> [many_tokens] // Never fails! 63 | ``` 64 | 65 | This is exactly what `extract_merge_ranks()` extracts from the HuggingFace tokenizer - those learned merge rules. 66 | 67 | --- 68 | 69 | ## **The Mathematical Foundation** 70 | 71 | ### **Token Embeddings:** 72 | Each token ID maps to a high-dimensional vector: 73 | 74 | ```rust 75 | // Token ID 1000 ("hello") -> [0.1, -0.3, 0.8, ..., 0.2] // 4096 dimensions 76 | // This vector captures semantic meaning learned during training 77 | ``` 78 | 79 | ### **Attention Mechanism:** 80 | ```rust 81 | // LLM processes: "Hello world" 82 | // Token IDs: [1000, 1001] 83 | // Embeddings: [[0.1, -0.3, ...], [0.5, 0.2, ...]] 84 | // 85 | // Attention computes relationships: 86 | // How much should "world" pay attention to "Hello"? 87 | // Result: Rich contextual understanding 88 | ``` 89 | 90 | ### **Text Generation:** 91 | ```rust 92 | // Given context: "The cat sat on the" 93 | // Model outputs probability distribution over ALL tokens: 94 | // P(token_0) = 0.001 // "!" 95 | // P(token_1) = 0.002 // "a" 96 | // ... 
97 | // P(token_5431) = 0.847 // "mat" <- highest probability 98 | // 99 | // Sample from distribution -> generate "mat" 100 | // Result: "The cat sat on the mat" 101 | ``` 102 | 103 | --- 104 | 105 | ## **Why Custom Binary Format Matters** 106 | 107 | ### **The Inference Challenge:** 108 | 109 | During text generation, the model performs **millions of token lookups**: 110 | 111 | ```rust 112 | // For each generated token: 113 | // 1. Model outputs: token_id = 1000 114 | // 2. Need fast lookup: 1000 -> "hello" 115 | // 3. Append to output: "The cat says hello" 116 | // 4. Repeat... 117 | 118 | // For a 100-word response = ~150 tokens = 150 lookups 119 | // For real-time chat = need sub-millisecond lookups! 120 | ``` 121 | 122 | ### **JSON vs Binary Performance:** 123 | 124 | ```rust 125 | // HuggingFace tokenizer.json approach: 126 | let token = tokenizer.decode(1000)?; 127 | // 1. Parse JSON structure 128 | // 2. Navigate nested objects 129 | // 3. Hash table lookup 130 | // 4. String allocation 131 | // Time: ~50 microseconds per lookup 132 | 133 | // Custom binary approach (our export_tokenizer): 134 | let token = binary_tokenizer.decode(1000)?; 135 | // 1. Direct memory access: base_ptr + (1000 * token_size) 136 | // 2. Read token data directly 137 | // Time: ~0.5 microseconds per lookup (100x faster!) 138 | ``` 139 | 140 | --- 141 | 142 | ## **Binary Format Design Decisions** 143 | 144 | ### **Why Sort Tokens by ID:** 145 | 146 | ```rust 147 | // tokens_by_id.sort_by_key(|&(id, _)| id); 148 | 149 | // This enables O(1) lookup during inference: 150 | // To find token 1000: seek to position (1000 * RECORD_SIZE) 151 | // No searching, no hash tables - just arithmetic! 152 | ``` 153 | 154 | ### **Why Store Token Scores:** 155 | 156 | The scoring system in `write_tokenizer_binary()` implements **BPE priority**: 157 | 158 | ```rust 159 | let score = if let Some(&rank) = merge_ranks.get(token) { 160 | -((rank + 1) as f32).ln() // Lower merge rank = higher priority 161 | } else { 162 | Self::DEFAULT_SCORE // Base tokens get low priority 163 | }; 164 | ``` 165 | 166 | **During tokenization:** 167 | ```rust 168 | // Input: "hello" 169 | // Possible tokenizations: 170 | // Option 1: ["h", "e", "l", "l", "o"] // Score: 5 * (-10) = -50 171 | // Option 2: ["he", "ll", "o"] // Score: (-2) + (-3) + (-10) = -15 172 | // Option 3: ["hello"] // Score: (-1) = -1 ✓ BEST 173 | // 174 | // Choose highest total score = most efficient tokenization 175 | ``` 176 | 177 | ### **Why Unicode Mapping (`create_unicode_to_byte_map`):** 178 | 179 | ```rust 180 | // Problem: Tokens can contain ANY Unicode character 181 | let problematic_token = "café🤖"; 182 | 183 | // Solution: Convert to consistent byte representation 184 | // 'c' -> 99 (ASCII) 185 | // 'a' -> 97 (ASCII) 186 | // 'f' -> 102 (ASCII) 187 | // 'é' -> 233 (mapped using GPT-2 scheme) 188 | // '🤖' -> [240, 159, 164, 150] (UTF-8 bytes) 189 | 190 | // Now we can store ANY token as bytes in binary file 191 | ``` 192 | 193 | --- 194 | 195 | ## **LLM Training vs Inference Perspective** 196 | 197 | ### **Training Time (One-time):** 198 | ```rust 199 | // 1. Learn BPE merges from massive text corpus 200 | // 2. Build vocabulary of ~50,000 tokens 201 | // 3. Train transformer weights 202 | // 4. Save as HuggingFace format (human-readable) 203 | ``` 204 | 205 | ### **Inference Time (Every user interaction):** 206 | ```rust 207 | // 1. FAST tokenization: text -> token_ids 208 | // 2. Model forward pass: token_ids -> probabilities 209 | // 3. 
FAST detokenization: sampled_token_id -> text 210 | // 4. Repeat for each generated token 211 | // 212 | // Steps 1 & 3 must be BLAZING fast! 213 | ``` 214 | 215 | **This is why `export_tokenizer()` exists** - to optimize the bottlenecks! 216 | 217 | --- 218 | 219 | ## **Real-World Impact** 220 | 221 | ### **Before Optimization (JSON tokenizer):** 222 | ```rust 223 | // Generate "Hello, how are you today?" 224 | // ~8 tokens × 50μs lookup = 400μs tokenization overhead 225 | // For streaming chat, this causes noticeable lag 226 | ``` 227 | 228 | ### **After Optimization (Binary tokenizer):** 229 | ```rust 230 | // Same generation: 8 tokens × 0.5μs = 4μs overhead 231 | // 100x speedup = imperceptible to users 232 | // Enables real-time streaming generation ⚡ 233 | ``` 234 | 235 | ### **Memory Efficiency:** 236 | ```rust 237 | // Qwen3-1.7B tokenizer: 238 | // tokenizer.json: 5MB (nested JSON, metadata) 239 | // .bin.tokenizer: 2MB (pure token data) 240 | // 241 | // 2.5x space savings + structured for fast access 242 | ``` 243 | 244 | --- 245 | 246 | ## **The Complete LLM Pipeline** 247 | 248 | ```rust 249 | // User input: "What is the capital of France?" 250 | 251 | // 1. TOKENIZATION (our optimized binary tokenizer) 252 | let tokens = tokenizer.encode("What is the capital of France?")?; 253 | // -> [3923, 374, 279, 6864, 315, 9822, 30] 254 | 255 | // 2. MODEL INFERENCE (quantized weights from model_exporter) 256 | let logits = model.forward(&tokens)?; 257 | // -> [0.001, 0.002, ..., 0.847, ...] // 50k probabilities 258 | 259 | // 3. SAMPLING 260 | let next_token_id = sample_from_distribution(&logits)?; 261 | // -> 3842 // "Paris" 262 | 263 | // 4. DETOKENIZATION (our optimized binary tokenizer) 264 | let token_text = tokenizer.decode(next_token_id)?; 265 | // -> "Paris" 266 | 267 | // 5. REPEAT until EOS token 268 | // Final: "What is the capital of France? Paris is the capital of France." 269 | ``` 270 | 271 | **Every step must be optimized** for real-time inference - that's why both `model_exporter.rs` (quantized weights) and `tokenizer_exporter.rs` (binary tokens) exist! 272 | 273 | --- 274 | 275 | ## **Key Insights** 276 | 277 | 1. **🧠 Tokens are the LLM's "words"** - learned subword pieces that balance flexibility with efficiency 278 | 279 | 2. **⚡ Inference speed matters** - millions of token lookups per conversation require microsecond performance 280 | 281 | 3. **🗜️ Custom formats win** - HuggingFace formats optimize for compatibility, our formats optimize for speed 282 | 283 | 4. **📊 Data structure = algorithm** - sorting tokens by ID enables O(1) lookup instead of O(log n) search 284 | 285 | 5. **🎯 Every microsecond counts** - in real-time AI, tokenization overhead is the difference between smooth and laggy user experience 286 | 287 | The `tokenizer_exporter.rs` code transforms a general-purpose tokenizer into a speed-optimized inference engine component! 🚀 -------------------------------------------------------------------------------- /docs/transformer.md: -------------------------------------------------------------------------------- 1 | ## 🔬 Model Architecture Summary 2 | 3 | The project uses a **decoder-only Transformer** model with optimizations such as **Grouped Query Attention (GQA)**, **Rotary Position Embedding (RoPE)**, **INT8 Quantization**, and **RMSNorm**. The architecture closely follows Qwen3 and is designed for efficient autoregressive language modeling. 4 | 5 | --- 6 | 7 | ### 🧠 Layer Order in Forward Pass 8 | 9 | 1. 
**Token Embedding** 10 | - Converts discrete token IDs into continuous vector representations. 11 | - Shape: `[vocab_size, dim]` 12 | - Shared weights with `lm_head` if enabled. 13 | 14 | 2. **Transformer Blocks (N times)** 15 | Each block contains: 16 | - **Pre-Attention RMSNorm**: Normalizes input before attention. 17 | $$ 18 | \text{RMSNorm}(x) = x \cdot \frac{\gamma}{\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}} 19 | $$ 20 | - **Multi-Head Self-Attention with GQA and RoPE** 21 | $$ 22 | \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 23 | $$ 24 | - Queries are projected using `wq`, keys with `wk`, values with `wv`. 25 | - Keys/Queries are normalized via `QK-RMSNorm` and rotated via RoPE. 26 | - Grouped heads reduce memory usage. 27 | - **Residual Connection** after attention. 28 | - **Pre-FFN RMSNorm**: Same formula as above. 29 | - **SwiGLU Feed-Forward Network** 30 | $$ 31 | \text{SwiGLU}(x) = \sigma(xW_1) \odot (xW_3) W_2 32 | $$ 33 | where: 34 | - $ W_1 $: gate projection 35 | - $ W_3 $: up-projection 36 | - $ W_2 $: down-projection 37 | - $ \sigma $: sigmoid function 38 | - **Residual Connection** after FFN. 39 | 40 | 3. **Final RMSNorm** 41 | - Applied to the final hidden state before output. 42 | $$ 43 | \text{RMSNorm}(x) = x \cdot \frac{\gamma}{\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}} 44 | $$ 45 | 46 | 4. **Language Model Head (`lm_head`)** 47 | - Linear projection from hidden dimension to vocabulary size. 48 | - Optionally shares weights with `token_embedding`. 49 | 50 | --- 51 | 52 | ### 📐 Key Components & Formulas 53 | 54 | #### 1. **RMSNorm (Root Mean Square Layer Normalization)** 55 | $$ 56 | \text{RMSNorm}(x) = x \cdot \frac{\gamma}{\sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}} 57 | $$ 58 | - No mean subtraction, only variance normalization. 59 | - Used before attention and FFN blocks. 60 | 61 | #### 2. **Rotary Position Embedding (RoPE)** 62 | - Encodes position information directly into query and key vectors. 63 | - For each head and dimension pair: 64 | $$ 65 | \text{freq} = \theta^{-2i/d}, \quad i \in [0, d/2) 66 | $$ 67 | $$ 68 | \text{angle} = pos \cdot freq 69 | $$ 70 | $$ 71 | \begin{bmatrix} 72 | x' \\ 73 | y' 74 | \end{bmatrix} 75 | = 76 | \begin{bmatrix} 77 | \cos(\text{angle}) & -\sin(\text{angle}) \\ 78 | \sin(\text{angle}) & \cos(\text{angle}) 79 | \end{bmatrix} 80 | \begin{bmatrix} 81 | x \\ 82 | y 83 | \end{bmatrix} 84 | $$ 85 | 86 | #### 3. **Multi-Head Attention (MHA)** 87 | $$ 88 | \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 89 | $$ 90 | - With **Grouped Query Attention (GQA)**: 91 | $$ 92 | n_{kv\_heads} < n_{heads} 93 | $$ 94 | Multiple query heads share the same key/value heads, reducing KV cache size. 95 | 96 | #### 4. **SwiGLU Activation Function** 97 | $$ 98 | \text{SwiGLU}(x) = \sigma(xW_1) \odot (xW_3) W_2 99 | $$ 100 | - Combines gated activation with linear transformation. 101 | - Enhances model expressiveness. 102 | 103 | #### 5. **Quantized Linear Layers** 104 | - Weights stored in INT8 format: 105 | $$ 106 | w_{quantized} = \text{round}(w / s), \quad w = s \cdot w_{quantized} 107 | $$ 108 | - Dynamic dequantization during forward pass improves efficiency without much accuracy loss. 
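
To make the per-group formula above concrete, here is a minimal, self-contained Rust sketch of symmetric INT8 (Q8_0-style) quantization and dequantization. The group of four values is made up for illustration; the real exporter in `model_exporter.rs` additionally uses banker's rounding and tracks the maximum reconstruction error.

```rust
// Symmetric per-group INT8 quantization: w_q = round(w / s), with s = max|w| / 127,
// and dequantization w ≈ s * w_q.
fn quantize_group(group: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = group.iter().fold(0.0_f32, |m, &x| m.max(x.abs()));
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let quantized = group.iter().map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8).collect();
    (quantized, scale)
}

fn dequantize_group(quantized: &[i8], scale: f32) -> Vec<f32> {
    quantized.iter().map(|&q| f32::from(q) * scale).collect()
}

fn main() {
    let group = [0.02_f32, -0.51, 0.37, 0.09]; // one quantization group (illustrative values)
    let (q, s) = quantize_group(&group);
    println!("int8 = {q:?}, scale = {s:.5}, reconstructed = {:?}", dequantize_group(&q, s));
}
```

In the exported binary, each quantized tensor is laid out exactly like this: the int8 values followed by one `f32` scale per group.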
109 | 110 | --- 111 | 112 | ### 📦 PyTorch-like Module Hierarchy 113 | 114 | ```python 115 | class Transformer(nn.Module): 116 | def __init__(self): 117 | self.token_embedding = TokenEmbedding(vocab_size, dim) 118 | self.blocks = nn.ModuleList([ 119 | TransformerBlock(dim, n_heads, head_dim, hidden_dim, group_size) 120 | for _ in range(n_layers) 121 | ]) 122 | self.final_norm = RMSNorm(dim) 123 | self.lm_head = Linear(dim, vocab_size) 124 | 125 | def forward(self, token_id, pos): 126 | x = self.token_embedding(token_id) 127 | for block in self.blocks: 128 | x = block(x, pos) 129 | x = self.final_norm(x) 130 | logits = self.lm_head(x) 131 | return logits 132 | ``` 133 | 134 | Each `TransformerBlock` includes: 135 | 136 | ```python 137 | class TransformerBlock(nn.Module): 138 | def __init__(): 139 | self.attn_norm = RMSNorm(dim) 140 | self.attention = MultiHeadAttention(...) 141 | self.ffn_norm = RMSNorm(dim) 142 | self.feed_forward = FeedForward(...) # SwiGLU-based 143 | ``` 144 | 145 | --- 146 | 147 | ### ⚙️ Inference Flow (Mathematically) 148 | 149 | ```plaintext 150 | Input Token ID → TokenEmbedding → N × TransformerBlock → Final RMSNorm → lm_head → Logits 151 | ``` 152 | 153 | Where each `TransformerBlock` does: 154 | 155 | 1. **Attention Path** 156 | - $ x_{norm} = \text{RMSNorm}(x) $ 157 | - $ Q, K, V = x_{norm}W_q, x_{norm}W_k, x_{norm}W_v $ 158 | - Apply RoPE to $ Q, K $ 159 | - Compute $ A = \text{softmax}(QK^T/\sqrt{d_k})V $ 160 | - $ x = x + A $ 161 | 162 | 2. **Feed-Forward Path** 163 | - $ x_{norm} = \text{RMSNorm}(x) $ 164 | - $ G = \sigma(x_{norm}W_1) $ 165 | - $ U = x_{norm}W_3 $ 166 | - $ F = (G \odot U)W_2 $ 167 | - $ x = x + F $ 168 | 169 | --- 170 | 171 | ### ✅ Summary of Optimizations 172 | 173 | | Feature | Description | Benefit | 174 | |--------|-------------|---------| 175 | | **Grouped Query Attention (GQA)** | Reduces number of KV heads | Lowers memory usage | 176 | | **Rotary Position Embedding (RoPE)** | Relative positional encoding | Better extrapolation | 177 | | **RMSNorm** | Simplified normalization | Faster and more stable | 178 | | **SwiGLU** | Gated non-linearity | Increased model capacity | 179 | | **INT8 Quantization** | Stores weights in 8-bit integers | Saves memory, faster inference | 180 | 181 | --- 182 | 183 | ### Educational Insights 184 | 185 | 1. **Why RMSNorm?** 186 | - Removes mean-centering for faster computation 187 | - Works well when combined with residual connections 188 | - Original paper: https://arxiv.org/abs/1910.07467 189 | 190 | 2. **Why Rotary Embeddings?** 191 | - Relative positions handled naturally via rotation 192 | - No position embedding learned parameters 193 | - Original paper: https://arxiv.org/abs/2104.09864 194 | 195 | 3. **GQA Tradeoffs** 196 | - Memory: Reduces KV cache by `n_heads/n_kv_heads` 197 | - Quality: Minimal impact when ratio ≤ 8:1 198 | - Paper: https://arxiv.org/abs/2305.13245 199 | 200 | 4. **SwiGLU Benefits** 201 | - More parameters than standard FFN (W1,W3 vs single W1) 202 | - Better modeling of complex interactions 203 | - From PaLM paper: https://arxiv.org/abs/2204.02311 204 | 205 | This architecture represents modern best practices for efficient LLM design, combining memory optimizations (GQA, quantization) with high-performance components (RoPE, SwiGLU). 
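
As a worked example of the GQA memory saving listed above, the sketch below estimates the KV-cache footprint for a hypothetical configuration (28 layers, 16 query heads, 8 KV heads, head_dim 128, fp32 cache; the numbers are illustrative, not read from a particular checkpoint). Only the key/value heads are cached, so the cache shrinks by the `n_heads / n_kv_heads` ratio.

```rust
// KV cache bytes = 2 (K and V) * n_layers * heads_cached * head_dim * seq_len * bytes_per_value
fn kv_cache_bytes(n_layers: u64, heads_cached: u64, head_dim: u64, seq_len: u64, bytes_per_value: u64) -> u64 {
    2 * n_layers * heads_cached * head_dim * seq_len * bytes_per_value
}

fn main() {
    // Hypothetical config: 28 layers, 16 query heads, 8 KV heads, head_dim 128, 4096-token context
    let (n_layers, n_heads, n_kv_heads, head_dim, seq_len) = (28_u64, 16_u64, 8_u64, 128_u64, 4096_u64);
    let without_gqa = kv_cache_bytes(n_layers, n_heads, head_dim, seq_len, 4); // caching every query head
    let with_gqa = kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, 4); // caching only KV heads
    println!("cache without GQA: {:.0} MiB", without_gqa as f64 / (1024.0 * 1024.0));
    println!("cache with GQA:    {:.0} MiB ({}x smaller)", with_gqa as f64 / (1024.0 * 1024.0), n_heads / n_kv_heads);
}
```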
-------------------------------------------------------------------------------- /qwen3-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "qwen3-cli" 3 | description = "Command-line interface for qwen3 library" 4 | version.workspace = true 5 | edition.workspace = true 6 | authors.workspace = true 7 | license.workspace = true 8 | keywords.workspace = true 9 | categories.workspace = true 10 | repository.workspace = true 11 | 12 | [[bin]] 13 | name = "qwen3" 14 | path = "src/main.rs" 15 | 16 | [dependencies] 17 | qwen3-export = { workspace = true } 18 | qwen3-inference = { workspace = true } 19 | 20 | anyhow = { workspace = true } 21 | clap = { workspace = true } 22 | log = { workspace = true } 23 | env_logger = { workspace = true } 24 | -------------------------------------------------------------------------------- /qwen3-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use anyhow::Result; 4 | use clap::{Arg, ArgMatches, Command}; 5 | use log::{error, info}; 6 | use qwen3_export::export_model; 7 | use qwen3_inference::{InferenceConfigBuilder, TransformerBuilder, run_inference}; 8 | 9 | /// Define the export subcommand. 10 | fn export_subcommand() -> Command { 11 | Command::new("export") 12 | .about("Export Qwen3 model from HuggingFace format to custom binary format") 13 | .arg( 14 | Arg::new("MODEL_PATH") 15 | .help("Path to the HuggingFace model directory (containing config.json, *.safetensors, tokenizer.json)") 16 | .required(true) 17 | .index(1), 18 | ) 19 | .arg( 20 | Arg::new("OUTPUT_PATH") 21 | .help("Output path for the binary model file (without extension)") 22 | .required(true) 23 | .index(2), 24 | ) 25 | .arg( 26 | Arg::new("group-size") 27 | .long("group-size") 28 | .short('g') 29 | .help("Quantization group size") 30 | .value_name("SIZE") 31 | .default_value("64"), 32 | ) 33 | } 34 | 35 | /// Define the inference subcommand. 
36 | fn inference_subcommand() -> Command { 37 | Command::new("inference") 38 | .about("Qwen3 inference in Rust") 39 | .arg(Arg::new("checkpoint").help("Model checkpoint file").required(true).index(1)) 40 | .arg( 41 | Arg::new("temperature") 42 | .short('t') 43 | .long("temperature") 44 | .value_name("FLOAT") 45 | .help("Temperature for sampling in [0, inf], default 1.0") 46 | .default_value("1.0") 47 | .value_parser(clap::value_parser!(f32)), 48 | ) 49 | .arg( 50 | Arg::new("topp") 51 | .short('p') 52 | .long("topp") 53 | .value_name("FLOAT") 54 | .help("Top-p for nucleus sampling in [0,1], default 0.9") 55 | .default_value("0.9") 56 | .value_parser(clap::value_parser!(f32)), 57 | ) 58 | .arg( 59 | Arg::new("seed") 60 | .short('s') 61 | .long("seed") 62 | .value_name("INT") 63 | .help("Random seed") 64 | .value_parser(clap::value_parser!(u64)), 65 | ) 66 | .arg( 67 | Arg::new("context") 68 | .short('c') 69 | .long("context") 70 | .value_name("INT") 71 | .help("Context window size, (default) = max_seq_len") 72 | .value_parser(clap::value_parser!(u32)), 73 | ) 74 | .arg( 75 | Arg::new("mode") 76 | .short('m') 77 | .long("mode") 78 | .value_name("STRING") 79 | .help("Mode: generate|chat [default: chat]") 80 | .default_value("chat"), 81 | ) 82 | .arg(Arg::new("input").short('i').long("input").value_name("STRING").help("Input prompt")) 83 | .arg(Arg::new("system").short('y').long("system").value_name("STRING").help("System prompt in chat mode")) 84 | .arg( 85 | Arg::new("reasoning") 86 | .short('r') 87 | .long("reasoning") 88 | .value_name("INT") 89 | .help("Reasoning mode: 0=no thinking, 1=thinking [default: 0]") 90 | .default_value("0") 91 | .value_parser(clap::value_parser!(i32)), 92 | ) 93 | } 94 | 95 | /// Run the export command with the provided arguments 96 | fn run_export_command(matches: &ArgMatches) -> Result<()> { 97 | let model_path = matches.get_one::("MODEL_PATH").unwrap(); 98 | let output_path = matches.get_one::("OUTPUT_PATH").unwrap(); 99 | let group_size: usize = 100 | matches.get_one::("group-size").unwrap().parse().map_err(|_| anyhow::anyhow!("Invalid group size"))?; 101 | 102 | // Validate input path 103 | let model_dir = Path::new(model_path); 104 | if !model_dir.exists() { 105 | anyhow::bail!("Model directory does not exist: {model_path}"); 106 | } 107 | 108 | let config_path = model_dir.join("config.json"); 109 | let adapter_config_path = model_dir.join("adapter_config.json"); 110 | 111 | if !config_path.exists() && !adapter_config_path.exists() { 112 | anyhow::bail!("Neither config.json nor adapter_config.json found in model directory") 113 | } 114 | 115 | let tokenizer_path = model_dir.join("tokenizer.json"); 116 | if !tokenizer_path.exists() { 117 | anyhow::bail!("tokenizer.json not found in model directory"); 118 | } 119 | 120 | // Check for safetensors files 121 | let has_safetensors = std::fs::read_dir(model_dir)?.any(|entry| { 122 | if let Ok(entry) = entry { 123 | entry.path().extension().and_then(|ext| ext.to_str()).map(|ext| ext == "safetensors").unwrap_or(false) 124 | } else { 125 | false 126 | } 127 | }); 128 | 129 | if !has_safetensors { 130 | anyhow::bail!("No .safetensors files found in model directory"); 131 | } 132 | 133 | info!(""); 134 | info!("🚀 Qwen3 Model Exporter"); 135 | info!("📁 Model path: {model_path}"); 136 | info!("💾 Output path: {output_path}"); 137 | info!("🔢 Group size: {group_size}\n"); 138 | 139 | export_model(model_path, output_path, group_size)?; 140 | 141 | Ok(()) 142 | } 143 | 144 | /// Run the inference command with the provided 
arguments 145 | fn run_inference_command(matches: &ArgMatches) -> Result<()> { 146 | let config = InferenceConfigBuilder::default() 147 | .checkpoint_path(matches.get_one::("checkpoint")) 148 | .temperature(matches.get_one::("temperature").copied()) 149 | .topp(matches.get_one::("topp").copied()) 150 | .ctx_length(matches.get_one::("context").copied()) 151 | .mode(matches.get_one::("mode")) 152 | .prompt(matches.get_one::("input")) 153 | .system_prompt(matches.get_one::("system")) 154 | .enable_thinking(matches.get_one::("reasoning").map(|v| *v != 0)) 155 | .seed(matches.get_one::("seed").copied()) 156 | .build() 157 | .map_err(|e| anyhow::anyhow!(e))?; 158 | 159 | let transformer = TransformerBuilder::new(&config.checkpoint_path).with_ctx_length(config.ctx_length).build()?; 160 | 161 | run_inference(transformer, config).map_err(|e| anyhow::anyhow!("Inference failed: {e}"))?; 162 | 163 | Ok(()) 164 | } 165 | 166 | fn execute_commands() -> Result<()> { 167 | // Initialize logger with clean format (no timestamp/module prefix) and use info level by default 168 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) 169 | .format(|buf, record| { 170 | use std::io::Write; 171 | writeln!(buf, "{}", record.args()) 172 | }) 173 | .init(); 174 | 175 | let matches = Command::new("qwen3") 176 | .about("Qwen3 CLI: an educational tool for exporting and running Qwen3 models") 177 | .subcommand(export_subcommand()) 178 | .subcommand(inference_subcommand()) 179 | .get_matches(); 180 | 181 | match matches.subcommand() { 182 | Some(("export", matches)) => run_export_command(matches), 183 | Some(("inference", matches)) => run_inference_command(matches), 184 | _ => anyhow::bail!("No subcommand specified. Use -h to print help information."), 185 | } 186 | } 187 | 188 | fn main() { 189 | if let Err(e) = execute_commands() { 190 | error!("Error: {e}"); 191 | std::process::exit(1); 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /qwen3-export/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "qwen3-export" 3 | description = "A Rust library for exporting Qwen3 models to binary format" 4 | version.workspace = true 5 | edition.workspace = true 6 | authors.workspace = true 7 | license.workspace = true 8 | keywords.workspace = true 9 | categories.workspace = true 10 | repository.workspace = true 11 | 12 | [dependencies] 13 | anyhow = { workspace = true } 14 | byteorder = { workspace = true } 15 | rayon = { workspace = true } 16 | safetensors = { workspace = true } 17 | serde_json = { workspace = true } 18 | serde = { workspace = true } 19 | memmap2 = { workspace = true } 20 | log = { workspace = true } 21 | 22 | [dev-dependencies] 23 | tempfile = "3.0" 24 | 25 | [features] 26 | default = [] 27 | -------------------------------------------------------------------------------- /qwen3-export/src/chat_template_exporter.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{Context, Result}; 2 | use log::info; 3 | use serde_json::Value; 4 | use std::path::Path; 5 | 6 | /// Chat template exporter for model chat templates 7 | /// 8 | /// This module creates chat templates based on the model's Jinja2 template 9 | /// from tokenizer_config.json. It generates templates that match the exact 10 | /// output of the Python reference implementation. 
11 | /// 12 | /// Supported Templates: 13 | /// - Qwen3: Uses `<|im_start|>role\ncontent<|im_end|>\n` format with thinking support 14 | /// - DeepSeek-R1: Uses `<|User|>content<|Assistant|>` format without thinking 15 | /// 16 | /// Generated Templates: 17 | /// - `.template`: Single user message with thinking disabled 18 | /// - `.template.with-thinking`: Single user message with thinking enabled (if supported) 19 | /// - `.template.with-system`: System + user messages with thinking disabled 20 | /// - `.template.with-system-and-thinking`: System + user messages with thinking enabled (if supported) 21 | #[derive(Debug)] 22 | pub struct ChatTemplateExporter; 23 | 24 | /// Template configuration for different chat formats 25 | #[derive(Debug, Clone)] 26 | pub struct TemplateConfig { 27 | pub suffix: &'static str, 28 | pub description: &'static str, 29 | pub enable_thinking: bool, 30 | pub has_system: bool, 31 | } 32 | 33 | /// Template capabilities detected from the Jinja2 template 34 | #[derive(Debug)] 35 | struct TemplateCapabilities { 36 | supports_thinking: bool, 37 | supports_system: bool, 38 | template_type: TemplateType, 39 | } 40 | 41 | /// Different template types we support 42 | #[derive(Debug, Clone)] 43 | enum TemplateType { 44 | Qwen3, 45 | DeepSeek, 46 | Unknown, 47 | } 48 | 49 | impl ChatTemplateExporter { 50 | // Template suffixes 51 | const BASIC_SUFFIX: &'static str = ".template"; 52 | const WITH_THINKING_SUFFIX: &'static str = ".template.with-thinking"; 53 | const WITH_SYSTEM_SUFFIX: &'static str = ".template.with-system"; 54 | const WITH_SYSTEM_THINKING_SUFFIX: &'static str = ".template.with-system-and-thinking"; 55 | 56 | // Template constants 57 | const TEMPLATE_EXTENSION: &'static str = ".template"; 58 | 59 | /// Create a new ChatTemplateExporter 60 | pub fn new() -> Self { 61 | Self 62 | } 63 | 64 | /// Export chat templates to the specified output path 65 | pub fn export_templates(&self, model_path: &Path, output_path: &Path) -> Result<()> { 66 | // Load chat template from tokenizer config - this is now required 67 | let chat_template = self 68 | .load_chat_template_from_model(model_path) 69 | .with_context(|| format!("Failed to load chat template from model at {}", model_path.display()))? 
70 | .ok_or_else(|| { 71 | anyhow::anyhow!("No chat template found in tokenizer_config.json at {}", model_path.display()) 72 | })?; 73 | 74 | // Analyze template capabilities 75 | let capabilities = self.analyze_template_capabilities(&chat_template); 76 | info!("Template type: {:?}", capabilities.template_type); 77 | info!("Template capabilities:"); 78 | info!(" Supports thinking: {}", capabilities.supports_thinking); 79 | info!(" Supports system: {}", capabilities.supports_system); 80 | 81 | self.export_dynamic_templates(output_path, &chat_template, &capabilities) 82 | .with_context(|| format!("Failed to export templates for model at {}", model_path.display())) 83 | } 84 | 85 | /// Analyze template to determine what capabilities it supports 86 | /// A very basic analysis to detect template type and capabilities 87 | fn analyze_template_capabilities(&self, template: &str) -> TemplateCapabilities { 88 | let template_type = if template.contains("<|im_start|>") && template.contains("<|im_end|>") { 89 | TemplateType::Qwen3 90 | } else if template.contains("<|User|>") && template.contains("<|Assistant|>") { 91 | TemplateType::DeepSeek 92 | } else { 93 | TemplateType::Unknown 94 | }; 95 | 96 | let (supports_thinking, supports_system) = match template_type { 97 | TemplateType::Qwen3 => ( 98 | template.contains("enable_thinking"), 99 | template.contains("system") && template.contains("messages[0].role"), 100 | ), 101 | TemplateType::DeepSeek => (template.contains("think"), template.contains("system_prompt")), 102 | TemplateType::Unknown => (false, false), 103 | }; 104 | 105 | TemplateCapabilities { supports_thinking, supports_system, template_type } 106 | } 107 | 108 | /// Get template configurations based on detected capabilities 109 | fn get_template_configs(&self, capabilities: &TemplateCapabilities) -> Vec { 110 | // Maximum possible templates: basic + thinking + system + system_thinking = 4 111 | let mut configs = Vec::with_capacity(4); 112 | 113 | // Always generate basic user template 114 | configs.push(TemplateConfig { 115 | suffix: Self::BASIC_SUFFIX, 116 | description: "basic", 117 | enable_thinking: false, 118 | has_system: false, 119 | }); 120 | 121 | // Add thinking variant if supported 122 | if capabilities.supports_thinking { 123 | configs.push(TemplateConfig { 124 | suffix: Self::WITH_THINKING_SUFFIX, 125 | description: "with thinking", 126 | enable_thinking: true, 127 | has_system: false, 128 | }); 129 | } 130 | 131 | // Add system variants if supported 132 | if capabilities.supports_system { 133 | configs.push(TemplateConfig { 134 | suffix: Self::WITH_SYSTEM_SUFFIX, 135 | description: "with system", 136 | enable_thinking: false, 137 | has_system: true, 138 | }); 139 | 140 | // Only add system + thinking if both are supported 141 | if capabilities.supports_thinking { 142 | configs.push(TemplateConfig { 143 | suffix: Self::WITH_SYSTEM_THINKING_SUFFIX, 144 | description: "with system and thinking", 145 | enable_thinking: true, 146 | has_system: true, 147 | }); 148 | } 149 | } 150 | 151 | configs 152 | } 153 | 154 | /// Load chat template from model's tokenizer_config.json 155 | fn load_chat_template_from_model(&self, model_path: &Path) -> Result> { 156 | let tokenizer_config_path = model_path.join("tokenizer_config.json"); 157 | 158 | if !tokenizer_config_path.exists() { 159 | return Ok(None); 160 | } 161 | 162 | let config_content = std::fs::read_to_string(&tokenizer_config_path) 163 | .with_context(|| format!("Failed to read tokenizer config from {}", 
tokenizer_config_path.display()))?; 164 | 165 | let config: Value = 166 | serde_json::from_str(&config_content).with_context(|| "Failed to parse tokenizer config JSON")?; 167 | 168 | // Extract chat_template if it exists 169 | Ok(config.get("chat_template").and_then(|v| v.as_str()).map(|s| s.to_string())) 170 | } 171 | 172 | /// Export templates using the dynamic chat template from the model 173 | fn export_dynamic_templates( 174 | &self, 175 | output_path: &Path, 176 | chat_template: &str, 177 | capabilities: &TemplateCapabilities, 178 | ) -> Result<()> { 179 | info!("Using model's chat template for template generation"); 180 | 181 | // Create template variants based on detected capabilities 182 | let template_configs = self.get_template_configs(capabilities); 183 | 184 | info!("Generating {} template variants:", template_configs.len()); 185 | for config in &template_configs { 186 | info!(" - {}", config.description); 187 | } 188 | 189 | template_configs.iter().try_for_each(|config| { 190 | let template_content = self.render_chat_template(chat_template, config, capabilities)?; 191 | let template_path = format!("{}{}", output_path.display(), config.suffix); 192 | 193 | std::fs::write(&template_path, template_content) 194 | .with_context(|| format!("Failed to write template to {template_path}"))?; 195 | 196 | info!("📝 Written {} template: {template_path}", config.description); 197 | Ok::<(), anyhow::Error>(()) 198 | })?; 199 | 200 | info!("💬 All prompt templates written to {}{}.*", output_path.display(), Self::TEMPLATE_EXTENSION); 201 | Ok(()) 202 | } 203 | 204 | /// Render chat template for specific configuration 205 | /// This is a simplified Jinja2 template renderer for different template types 206 | fn render_chat_template( 207 | &self, 208 | _template: &str, 209 | config: &TemplateConfig, 210 | capabilities: &TemplateCapabilities, 211 | ) -> Result { 212 | match capabilities.template_type { 213 | TemplateType::Qwen3 => { 214 | if config.has_system { 215 | self.render_qwen3_system_message_template(config.enable_thinking) 216 | } else { 217 | self.render_qwen3_single_message_template(config.enable_thinking) 218 | } 219 | } 220 | TemplateType::DeepSeek => { 221 | if config.has_system { 222 | self.render_deepseek_system_message_template(config.enable_thinking) 223 | } else { 224 | self.render_deepseek_single_message_template(config.enable_thinking) 225 | } 226 | } 227 | TemplateType::Unknown => Err(anyhow::anyhow!("Unknown template type, cannot render templates")), 228 | } 229 | } 230 | 231 | /// Render Qwen3 template for single user message 232 | fn render_qwen3_single_message_template(&self, enable_thinking: bool) -> Result { 233 | if enable_thinking { 234 | Ok("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n".to_string()) 235 | } else { 236 | Ok("<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n\n\n\n\n".to_string()) 237 | } 238 | } 239 | 240 | /// Render Qwen3 template for system + user messages 241 | fn render_qwen3_system_message_template(&self, enable_thinking: bool) -> Result { 242 | if enable_thinking { 243 | Ok("<|im_start|>system\n%s<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n".to_string()) 244 | } else { 245 | Ok("<|im_start|>system\n%s<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n\n\n\n\n".to_string()) 246 | } 247 | } 248 | 249 | /// Render DeepSeek template for single user message 250 | fn render_deepseek_single_message_template(&self, enable_thinking: bool) -> Result { 251 | if enable_thinking { 252 | 
Ok("<|User|>%s<|Assistant|>".to_string()) 253 | } else { 254 | Ok("<|User|>%s<|Assistant|>\n".to_string()) 255 | } 256 | } 257 | 258 | /// Render DeepSeek template for system + user messages 259 | fn render_deepseek_system_message_template(&self, enable_thinking: bool) -> Result { 260 | if enable_thinking { 261 | Ok("%s<|User|>%s<|Assistant|>".to_string()) 262 | } else { 263 | Ok("%s<|User|>%s<|Assistant|>\n".to_string()) 264 | } 265 | } 266 | } 267 | 268 | impl Default for ChatTemplateExporter { 269 | fn default() -> Self { 270 | Self::new() 271 | } 272 | } 273 | -------------------------------------------------------------------------------- /qwen3-export/src/config_loader.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | #[path = "../tests/unit/config_loader_test.rs"] 3 | mod config_loader_test; 4 | 5 | use anyhow::{Context, Result}; 6 | use log::info; 7 | use serde::{Deserialize, Serialize}; 8 | use std::{fs::File, io::Read, path::Path}; 9 | 10 | use crate::models::ArchitectureId; 11 | 12 | /// Model type detection with embedded LoRA configuration 13 | #[derive(Debug, Clone, PartialEq)] 14 | pub enum ModelType { 15 | Base, // Standard base model 16 | LoRA(LoRAConfig), // LoRA fine-tuned model with full config 17 | } 18 | 19 | /// Enhanced model information that includes type and configs 20 | #[derive(Debug, Clone)] 21 | pub struct ModelInfo { 22 | pub model_type: ModelType, 23 | pub config: ModelConfig, 24 | } 25 | 26 | /// Configuration structure matching the Python ModelArgs 27 | #[derive(Debug, Clone)] 28 | pub struct ModelConfig { 29 | pub dim: u32, 30 | pub hidden_dim: u32, 31 | pub n_layers: u32, 32 | pub n_heads: u32, 33 | pub n_kv_heads: u32, 34 | pub vocab_size: u32, 35 | pub max_seq_len: u32, 36 | pub head_dim: u32, 37 | pub norm_eps: f32, 38 | pub bos_token_id: u32, 39 | pub eos_token_id: u32, 40 | pub architecture: ArchitectureId, 41 | } 42 | 43 | /// LoRA configuration from adapter_config.json 44 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] 45 | pub struct LoRAConfig { 46 | pub lora_alpha: f32, 47 | pub r: usize, 48 | pub target_modules: Vec, 49 | pub base_model_name_or_path: Option, 50 | } 51 | 52 | /// Auto-detect model type and load appropriate configuration 53 | /// This is the main entry point that replaces load_hf_config 54 | pub fn load_model_info(model_path: &str) -> Result { 55 | let model_path = Path::new(model_path); 56 | 57 | // Detect model type based on config files 58 | let model_type = detect_model_type(model_path)?; 59 | 60 | let (config, _) = match &model_type { 61 | ModelType::Base => { 62 | info!("Detected Base model type (no LoRA)"); 63 | let config = load_base_model_config(model_path)?; 64 | (config, ()) 65 | } 66 | ModelType::LoRA(lora_config) => { 67 | let config = load_base_model_config(model_path)?; 68 | info!("Detected Base model type with LoRA configuration:"); 69 | info!(" • Alpha: {}", lora_config.lora_alpha); 70 | info!(" • Rank (r): {}", lora_config.r); 71 | info!(" • Target modules: {:?}", lora_config.target_modules); 72 | if let Some(ref base_model) = lora_config.base_model_name_or_path { 73 | info!(" • Base model: {}", base_model); 74 | } 75 | info!(""); 76 | (config, ()) 77 | } 78 | }; 79 | 80 | Ok(ModelInfo { model_type, config }) 81 | } 82 | 83 | /// Detect model type based on presence of config files. 84 | /// For LoRA models, loads and embeds the LoRA configuration. 
85 | fn detect_model_type(model_path: &Path) -> Result { 86 | let has_adapter_config = model_path.join("adapter_config.json").exists(); 87 | let has_base_config = model_path.join("config.json").exists(); 88 | 89 | match (has_base_config, has_adapter_config) { 90 | (true, true) => { 91 | // LoRA model - load adapter config and embed the full config 92 | let lora_config = load_lora_config(model_path)?; 93 | Ok(ModelType::LoRA(lora_config)) 94 | } 95 | (true, false) => Ok(ModelType::Base), 96 | (false, true) => anyhow::bail!( 97 | "Only LoRA config is found in {}. Make sure to have base model files in the same directory", 98 | model_path.display() 99 | ), 100 | _ => anyhow::bail!("No valid configuration files found in {}", model_path.display()), 101 | } 102 | } 103 | 104 | /// Load base model configuration - handles both direct config.json and LoRA case 105 | fn load_base_model_config(model_path: &Path) -> Result { 106 | let config_path = model_path.join("config.json"); 107 | 108 | if config_path.exists() { 109 | // Direct config.json exists 110 | load_hf_config(&config_path) 111 | } else { 112 | // For LoRA models, we might need to look elsewhere or use defaults 113 | // For now, return an error to let user know they need base model config 114 | anyhow::bail!( 115 | "Base model config.json not found in {}. For LoRA models, ensure the base model config is available.", 116 | model_path.display() 117 | ) 118 | } 119 | } 120 | 121 | /// Load model configuration from HuggingFace format. 122 | fn load_hf_config(config_path: &Path) -> Result { 123 | let mut file = 124 | File::open(&config_path).with_context(|| format!("Failed to open config.json at {config_path:?}"))?; 125 | let mut contents = String::new(); 126 | file.read_to_string(&mut contents)?; 127 | 128 | #[derive(Debug, Deserialize)] 129 | struct HFConfig { 130 | hidden_size: u32, 131 | intermediate_size: u32, 132 | num_hidden_layers: u32, 133 | num_attention_heads: u32, 134 | num_key_value_heads: u32, 135 | vocab_size: u32, 136 | max_position_embeddings: u32, 137 | rms_norm_eps: f32, 138 | #[serde(default)] 139 | head_dim: Option, 140 | #[serde(default)] 141 | bos_token_id: Option, 142 | #[serde(default)] 143 | eos_token_id: Option, 144 | #[serde(default)] 145 | architectures: Option>, 146 | } 147 | 148 | let hf_config: HFConfig = 149 | serde_json::from_str(&contents).map_err(|err| anyhow::anyhow!("Failed to parse config.json: {}", err))?; 150 | 151 | let head_dim = hf_config.head_dim.unwrap_or(hf_config.hidden_size / hf_config.num_attention_heads); 152 | 153 | // Try to determine architecture 154 | let architectures = hf_config.architectures.as_ref(); 155 | let architecture = match (architectures, architectures.and_then(|a| a.first())) { 156 | (Some(architectures), Some(first)) if architectures.len() == 1 => ArchitectureId::try_from(first.as_str())?, 157 | (Some(architectures), _) => { 158 | anyhow::bail!("Multiple architectures are not supported: {architectures:?}") 159 | } 160 | _ => anyhow::bail!("Cannot determine architecture"), 161 | }; 162 | 163 | let config = ModelConfig { 164 | dim: hf_config.hidden_size, 165 | hidden_dim: hf_config.intermediate_size, 166 | n_layers: hf_config.num_hidden_layers, 167 | n_heads: hf_config.num_attention_heads, 168 | n_kv_heads: hf_config.num_key_value_heads, 169 | vocab_size: hf_config.vocab_size, 170 | max_seq_len: hf_config.max_position_embeddings, 171 | norm_eps: hf_config.rms_norm_eps, 172 | head_dim, 173 | bos_token_id: hf_config.bos_token_id.unwrap_or(0), 174 | eos_token_id: 
hf_config.eos_token_id.unwrap_or(0), 175 | architecture, 176 | }; 177 | 178 | info!("Model configuration loaded:"); 179 | info!(" • Architecture: {:?}", config.architecture); 180 | info!(" • Dimensions: {}", config.dim); 181 | info!(" • Layers: {}", config.n_layers); 182 | info!(" • Attention heads: {}", config.n_heads); 183 | info!(" • KV heads: {}", config.n_kv_heads); 184 | info!(" • Vocabulary size: {}", config.vocab_size); 185 | info!(" • Max sequence length: {}", config.max_seq_len); 186 | info!(" • Head dimension: {}", config.head_dim); 187 | info!(""); 188 | 189 | Ok(config) 190 | } 191 | 192 | /// Load LoRA configuration from adapter_config.json 193 | fn load_lora_config(model_path: &Path) -> Result { 194 | let config_path = model_path.join("adapter_config.json"); 195 | let mut file = File::open(&config_path) 196 | .with_context(|| format!("Failed to open adapter_config.json at {}", config_path.display()))?; 197 | let mut contents = String::new(); 198 | file.read_to_string(&mut contents)?; 199 | 200 | let config: LoRAConfig = serde_json::from_str(&contents) 201 | .map_err(|err| anyhow::anyhow!("Failed to parse adapter_config.json: {}", err))?; 202 | 203 | Ok(config) 204 | } 205 | -------------------------------------------------------------------------------- /qwen3-export/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # qwen3-export 2 | //! 3 | //! A Rust library for exporting Qwen3 models from HuggingFace format to binary format. 4 | //! 5 | //! ## Examples 6 | //! 7 | //! ### Exporting a model 8 | //! 9 | //! ```rust,no_run 10 | //! use qwen3_export::export_model; 11 | //! 12 | //! # fn main() -> anyhow::Result<()> { 13 | //! let model_path = "path/to/huggingface/model"; 14 | //! let output_path = "output/model"; 15 | //! 16 | //! 17 | //! // Export the model 18 | //! export_model(model_path, output_path, 32)?; 19 | //! # Ok(()) 20 | //! # } 21 | //! ``` 22 | 23 | // Public modules and re-exports from the former export module 24 | pub mod chat_template_exporter; 25 | pub mod config_loader; 26 | pub mod lora_merger; 27 | pub mod model_exporter; 28 | pub mod models; 29 | pub mod tensor_reader; 30 | pub mod tokenizer_exporter; 31 | mod utils; 32 | 33 | // Re-export main types for easy access 34 | pub use chat_template_exporter::ChatTemplateExporter; 35 | pub use config_loader::{LoRAConfig, ModelConfig, ModelInfo, ModelType, load_model_info}; 36 | pub use model_exporter::{BinaryModelExporter, QuantizedWeight}; 37 | pub use tokenizer_exporter::TokenizerExporter; 38 | 39 | use anyhow::Result; 40 | use log::info; 41 | use std::path::Path; 42 | 43 | /// Export the model weights in Q8_0 into .bin file to be used by later within inference implementation. 44 | /// That is: 45 | /// - quantize all weights to symmetric int8, in range [-127, 127] 46 | /// - all other tensors (the rmsnorm params) are kept and exported in fp32 47 | /// - quantization is done in groups of group_size to reduce the effects of any outliers 48 | /// 49 | /// Automatically detects model type and loads appropriate configs. 
50 | pub fn export_model(model_path: &str, output_path: &str, group_size: usize) -> Result<()> { 51 | info!("🚀 Starting automatic model export with type detection..."); 52 | info!(""); 53 | 54 | // Auto-detect model type and load configs 55 | let model_info = load_model_info(model_path)?; 56 | 57 | let model_path = Path::new(model_path); 58 | let output_path = Path::new(output_path); 59 | 60 | info!("🧮 Exporting quantized binary model..."); 61 | BinaryModelExporter::new_from_model_info(&model_info, group_size).export_binary_model( 62 | model_path, 63 | output_path, 64 | &model_info, 65 | )?; 66 | info!(""); 67 | 68 | info!("🔤 Exporting tokenizer..."); 69 | TokenizerExporter::new().export_tokenizer( 70 | model_path, 71 | output_path, 72 | model_info.config.bos_token_id, 73 | model_info.config.eos_token_id, 74 | )?; 75 | info!(""); 76 | 77 | info!("💬 Exporting chat templates..."); 78 | ChatTemplateExporter::new().export_templates(model_path, output_path)?; 79 | 80 | info!(""); 81 | info!("✅ Complete export finished successfully!"); 82 | Ok(()) 83 | } 84 | -------------------------------------------------------------------------------- /qwen3-export/src/lora_merger.rs: -------------------------------------------------------------------------------- 1 | use crate::tensor_reader::TensorReader; 2 | use anyhow::Result; 3 | use log::{debug, warn}; 4 | use rayon::prelude::*; 5 | 6 | /// LoraMerger applies standard LoRA merge logic on tensors: W = W_base + α / r * (B @ A) 7 | /// 8 | /// Assumptions: 9 | /// - LoRA uses a low-rank update to fine-tune a frozen base model weight. 10 | /// - A and B are learned matrices with shapes: 11 | /// - A: (r, in_features) 12 | /// - B: (out_features, r) 13 | /// - The base weight matrix W_base has shape (out_features, in_features) and is stored in row-major 1D layout. 14 | /// - The rank `r` is a small integer (e.g., 4, 8, 16) << min(in_features, out_features). 15 | /// - `alpha` is a scalar hyperparameter; `scaling = alpha / r`. 16 | /// 17 | /// Merge formula: 18 | /// W = W_base + scaling * (B @ A) 19 | /// Where: 20 | /// - `B @ A` produces a matrix of shape (out_features, in_features) 21 | /// - `scaling` modulates the update magnitude 22 | /// - The update is applied elementwise to the flattened base tensor 23 | /// 24 | /// Notes: 25 | /// - This code assumes tensors are flattened 1D f32 buffers in row-major order. 26 | /// - The caller is responsible for ensuring tensor shapes are consistent. 27 | pub(crate) struct LoraMerger<'a> { 28 | tensor_reader: &'a TensorReader, 29 | scaling: f32, 30 | rank: usize, 31 | } 32 | 33 | impl<'a> LoraMerger<'a> { 34 | pub fn new(tensor_reader: &'a TensorReader, alpha: f32, rank: usize) -> Result { 35 | let scaling = alpha / rank as f32; 36 | 37 | if !scaling.is_finite() || scaling.is_nan() { 38 | anyhow::bail!("Invalid scaling factor: {scaling} (must be finite). 
Alpha: {alpha}, Rank: {rank}"); 39 | } 40 | 41 | Ok(Self { tensor_reader, scaling, rank }) 42 | } 43 | 44 | /// Try to merge LoRA adapters with base weights 45 | pub fn try_merge_lora_adapters( 46 | &self, 47 | base_weights: &[f32], 48 | component: &str, 49 | layer_idx: u32, 50 | ) -> Result>> { 51 | // component is already clean (e.g., "self_attn.k_proj", "mlp.gate_proj") 52 | let (lora_a, lora_b) = self.discover_and_load_lora_pairs(component, layer_idx)?; 53 | 54 | if let (Some(a), Some(b)) = (lora_a, lora_b) { 55 | debug!("Merging LoRA adapters for {component} layer {layer_idx} with scaling {}", self.scaling); 56 | 57 | // Merge LoRA: W = W_base + scaling * (B @ A) 58 | let merged = self.merge_lora_weights(base_weights, &a, &b)?; 59 | Ok(Some(merged)) 60 | } else { 61 | Ok(None) 62 | } 63 | } 64 | 65 | /// Dynamically discover and load LoRA adapter pairs 66 | fn discover_and_load_lora_pairs( 67 | &self, 68 | component: &str, 69 | layer_idx: u32, 70 | ) -> Result<(Option>, Option>)> { 71 | // TODO: consider supporting different patterns? 72 | 73 | // Based on the actual tensor naming pattern: 74 | // Base: model.layers.{layer}.{component}.weight 75 | 76 | let lora_a_name = format!("base_model.model.model.layers.{layer_idx}.{component}.lora_A.weight"); 77 | let lora_b_name = format!("base_model.model.model.layers.{layer_idx}.{component}.lora_B.weight"); 78 | 79 | debug!("Looking for LoRA A: '{lora_a_name}'"); 80 | let lora_a = self.tensor_reader.load_tensor(&lora_a_name)?; 81 | 82 | debug!("Looking for LoRA B: '{lora_b_name}'"); 83 | let lora_b = self.tensor_reader.load_tensor(&lora_b_name)?; 84 | 85 | if lora_a.is_none() || lora_b.is_none() { 86 | debug!( 87 | "Could not find LoRA pair for layers.{layer_idx}.{component}, A found: {}, B found: {}", 88 | lora_a.is_some(), 89 | lora_b.is_some() 90 | ); 91 | } 92 | 93 | Ok((lora_a, lora_b)) 94 | } 95 | 96 | /// Merge LoRA weights: W = W_base + scaling * (B @ A) 97 | fn merge_lora_weights(&self, base: &[f32], lora_a: &[f32], lora_b: &[f32]) -> Result> { 98 | // Input validation 99 | if base.is_empty() || lora_a.is_empty() || lora_b.is_empty() { 100 | return Err(anyhow::anyhow!( 101 | "Empty tensors not allowed: base={}, A={}, B={}", 102 | base.len(), 103 | lora_a.len(), 104 | lora_b.len() 105 | )); 106 | } 107 | 108 | let rank = self.rank; 109 | let base_len = base.len(); 110 | let a_len = lora_a.len(); 111 | let b_len = lora_b.len(); 112 | 113 | // Calculate dimensions using known rank 114 | let (in_features, out_features) = self.calculate_lora_dimensions(base_len, a_len, b_len)?; 115 | 116 | debug!( 117 | "Merging LoRA: base {out_features}×{in_features} ({base_len}), A {rank}×{in_features} ({a_len}), B {out_features}×{rank} ({b_len}), rank={rank}, scaling={:.6}", 118 | self.scaling 119 | ); 120 | 121 | if self.scaling.abs() > 1e3 { 122 | warn!("Large scaling factor detected: {:.6}", self.scaling); 123 | } 124 | 125 | // apply merge: W = W_base + scaling * (B @ A) 126 | let mut result = base.to_vec(); 127 | result.par_iter_mut().enumerate().for_each(|(idx, base_val)| { 128 | let out_idx = idx / in_features; 129 | let in_idx = idx % in_features; 130 | 131 | let mut delta_val = 0.0f32; 132 | for r in 0..self.rank { 133 | let b_val = lora_b[out_idx * self.rank + r]; 134 | let a_val = lora_a[r * in_features + in_idx]; 135 | delta_val += b_val * a_val; 136 | } 137 | 138 | *base_val += self.scaling * delta_val; 139 | }); 140 | 141 | let (max_abs_delta, avg_abs_delta, max_abs_base, avg_abs_base) = 142 | self.compute_merge_statistics(base, 
&result)?; 143 | 144 | debug!( 145 | "LoRA merge complete: max_delta={max_abs_delta:.6}, avg_delta={avg_abs_delta:.6}, max_base={max_abs_base:.6}, avg_base={avg_abs_base:.6}, relative_change={:.3}%", 146 | if avg_abs_base > 1e-12 { (avg_abs_delta / avg_abs_base) * 100.0 } else { 0.0 } 147 | ); 148 | 149 | Ok(result) 150 | } 151 | 152 | /// Calculate LoRA dimensions using the known rank from config 153 | /// Returns (in_features, out_features) or error if dimensions don't match 154 | fn calculate_lora_dimensions(&self, base_len: usize, a_len: usize, b_len: usize) -> Result<(usize, usize)> { 155 | // With known rank, we can directly calculate dimensions 156 | // LoRA format: A: (rank, in_features), B: (out_features, rank) 157 | 158 | if a_len % self.rank != 0 { 159 | anyhow::bail!("LoRA A tensor size ({}) is not divisible by rank ({})", a_len, self.rank); 160 | } 161 | 162 | if b_len % self.rank != 0 { 163 | anyhow::bail!("LoRA B tensor size ({}) is not divisible by rank ({})", b_len, self.rank); 164 | } 165 | 166 | let in_features = a_len / self.rank; 167 | let out_features = b_len / self.rank; 168 | 169 | // Verify that dimensions are consistent with base weight 170 | if in_features * out_features != base_len { 171 | anyhow::bail!( 172 | "Dimension mismatch: base tensor size ({base_len}) doesn't match calculated dimensions ({out_features}×{in_features} = {})", 173 | in_features * out_features 174 | ); 175 | } 176 | 177 | if in_features == 0 || out_features == 0 { 178 | anyhow::bail!("Invalid dimensions: in_features={in_features}, out_features={out_features}",); 179 | } 180 | 181 | debug!( 182 | "Calculated LoRA dimensions: rank={}, in_features={in_features}, out_features={out_features}", 183 | self.rank, 184 | ); 185 | 186 | Ok((in_features, out_features)) 187 | } 188 | 189 | /// Compute statistics for LoRA merge validation and logging using parallel processing 190 | fn compute_merge_statistics(&self, base: &[f32], result: &[f32]) -> Result<(f32, f32, f32, f32)> { 191 | if base.len() != result.len() { 192 | anyhow::bail!("Base and result tensor lengths don't match: {} vs {}", base.len(), result.len()); 193 | } 194 | 195 | // Parallel computation of statistics with overflow checking 196 | let stats_result: Result<(f32, f64, f32, f64)> = base 197 | .par_iter() 198 | .zip(result.par_iter()) 199 | .enumerate() 200 | .map(|(idx, (&base_val, &result_val))| { 201 | // Check for overflow/NaN in result during statistics computation 202 | if !result_val.is_finite() { 203 | anyhow::bail!("Non-finite value detected in result at index {idx}: {result_val}"); 204 | } 205 | 206 | let delta = (result_val - base_val).abs(); 207 | let abs_base = base_val.abs(); 208 | Ok((delta, delta as f64, abs_base, abs_base as f64)) 209 | }) 210 | .try_reduce( 211 | || (0.0f32, 0.0f64, 0.0f32, 0.0f64), 212 | |acc, curr| { 213 | let (max_delta_acc, sum_delta_acc, max_base_acc, sum_base_acc) = acc; 214 | let (delta, delta_f64, abs_base, abs_base_f64) = curr; 215 | Ok(( 216 | max_delta_acc.max(delta), 217 | sum_delta_acc + delta_f64, 218 | max_base_acc.max(abs_base), 219 | sum_base_acc + abs_base_f64, 220 | )) 221 | }, 222 | ); 223 | 224 | let (max_abs_delta, sum_abs_delta, max_abs_base, sum_abs_base) = stats_result?; 225 | 226 | let len = base.len() as f64; 227 | let avg_abs_delta = (sum_abs_delta / len) as f32; 228 | let avg_abs_base = (sum_abs_base / len) as f32; 229 | 230 | Ok((max_abs_delta, avg_abs_delta, max_abs_base, avg_abs_base)) 231 | } 232 | } 233 | 
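
For readers who want to see the merge formula from the module documentation above on concrete numbers, here is a tiny standalone sketch of `W = W_base + (alpha / r) * (B @ A)` with made-up 2×2 values and rank 1. It mirrors the row-major indexing used by `merge_lora_weights`, without the rayon parallelism or validation.

```rust
fn main() {
    let (alpha, rank) = (2.0_f32, 1_usize);
    let scaling = alpha / rank as f32; // = 2.0
    let base = [1.0_f32, 0.0, 0.0, 1.0]; // W_base: 2x2 identity, row-major
    let lora_a = [0.5_f32, -0.5];        // A: (rank, in_features) = (1, 2)
    let lora_b = [1.0_f32, 2.0];         // B: (out_features, rank) = (2, 1)
    let in_features = 2;

    // W = W_base + scaling * (B @ A), computed element-wise over the flattened base tensor
    let merged: Vec<f32> = base
        .iter()
        .enumerate()
        .map(|(idx, &w)| {
            let (out_idx, in_idx) = (idx / in_features, idx % in_features);
            let delta: f32 =
                (0..rank).map(|r| lora_b[out_idx * rank + r] * lora_a[r * in_features + in_idx]).sum();
            w + scaling * delta
        })
        .collect();

    assert_eq!(merged, vec![2.0_f32, -1.0, 2.0, -1.0]);
}
```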
-------------------------------------------------------------------------------- /qwen3-export/src/model_exporter.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | #[path = "../tests/unit/model_exporter_test.rs"] 3 | mod model_exporter_test; 4 | 5 | use anyhow::Result; 6 | use byteorder::{LittleEndian, WriteBytesExt}; 7 | use log::{info, warn}; 8 | use rayon::prelude::*; 9 | use std::fs::File; 10 | use std::io::{BufWriter, Write}; 11 | use std::path::Path; 12 | 13 | use crate::lora_merger::LoraMerger; 14 | use crate::models::{Architecture, HeaderInfo, NormWeightLayer, WeightLayer, create_architecture}; 15 | use crate::tensor_reader::TensorReader; 16 | use crate::utils::ProgressTracker; 17 | use crate::{ModelConfig, ModelInfo, ModelType}; 18 | 19 | // Quantization result 20 | #[derive(Debug)] 21 | pub struct QuantizedWeight { 22 | pub int8_data: Vec, 23 | pub scales: Vec, 24 | pub max_error: f32, 25 | } 26 | 27 | /// Binary model exporter for quantized model weights 28 | pub struct BinaryModelExporter { 29 | config: ModelConfig, 30 | group_size: usize, 31 | } 32 | 33 | impl BinaryModelExporter { 34 | const MAGIC_NUMBER: u32 = 0x616A6331; // "ajc1" in ASCII 35 | const VERSION: i32 = 1; 36 | const HEADER_SIZE: usize = 256; 37 | const MIN_GROUP_SIZE: usize = 4; 38 | 39 | pub fn new(config: ModelConfig, group_size: usize) -> Self { 40 | let optimal_group_size = Self::find_optimal_group_size(config.dim as usize, group_size); 41 | if optimal_group_size != group_size { 42 | info!("Adjusted group size from {} to {} to fit hidden_dim {}", group_size, optimal_group_size, config.dim); 43 | } 44 | Self { config, group_size: optimal_group_size } 45 | } 46 | 47 | /// Find optimal group size that divides hidden_dim and is reasonable 48 | fn find_optimal_group_size(hidden_dim: usize, requested_size: usize) -> usize { 49 | let mut size = requested_size.min(hidden_dim); 50 | 51 | // Find largest size that divides hidden_dim 52 | while size >= Self::MIN_GROUP_SIZE && hidden_dim % size != 0 { 53 | size /= 2; 54 | } 55 | 56 | size.max(Self::MIN_GROUP_SIZE) 57 | } 58 | 59 | /// Create exporter from ModelInfo (recommended for new code) 60 | pub fn new_from_model_info(model_info: &ModelInfo, group_size: usize) -> Self { 61 | Self::new(model_info.config.clone(), group_size) 62 | } 63 | 64 | /// Export binary model with quantized weights using streaming to minimize memory usage 65 | pub fn export_binary_model(&self, model_path: &Path, output_path: &Path, model_info: &ModelInfo) -> Result<()> { 66 | let tensor_reader = TensorReader::new(model_path)?; 67 | 68 | #[cfg(debug_assertions)] 69 | log::debug!("Tensor names: {:?}", tensor_reader.list_tensor_names()?); 70 | 71 | let file = File::create(output_path)?; 72 | let mut writer = BufWriter::new(file); 73 | 74 | let architecture = create_architecture(model_info, &tensor_reader); 75 | let header_info = architecture.header()?; 76 | 77 | // Write header (256 bytes) 78 | self.write_header(&mut writer, &header_info)?; 79 | 80 | // Write normalization weights (fp32) - these are small 81 | self.write_norm_weights(architecture.as_ref(), &mut writer, &tensor_reader)?; 82 | 83 | // Stream and quantize weights one by one 84 | self.stream_and_quantize_weights( 85 | architecture.as_ref(), 86 | &mut writer, 87 | &tensor_reader, 88 | header_info.shared_classifier, 89 | &model_info.model_type, 90 | )?; 91 | 92 | writer.flush()?; 93 | info!("💾 Written model checkpoint to {}", output_path.display()); 94 | 95 | // Clear cache to 
free memory 96 | if let Err(e) = tensor_reader.clear_cache() { 97 | warn!("Failed to clear cache: {e}"); 98 | } 99 | 100 | Ok(()) 101 | } 102 | 103 | /// Quantize weights to Q8_0 format (symmetric int8, range [-127, 127]) 104 | pub fn quantize_q80(&self, weights: &[f32]) -> Result { 105 | if weights.len() % self.group_size != 0 { 106 | return Err(anyhow::anyhow!("Weight length is not a multiple of group_size")); 107 | } 108 | 109 | let num_groups = weights.len() / self.group_size; 110 | 111 | // Process groups in parallel 112 | let group_results: Vec<_> = (0..num_groups) 113 | .into_par_iter() 114 | .map(|group_idx| { 115 | let start_idx = group_idx * self.group_size; 116 | let end_idx = start_idx + self.group_size; 117 | let group = &weights[start_idx..end_idx]; 118 | 119 | // Find the maximum absolute value in this group 120 | let group_max = group.iter().map(|&x| x.abs()).fold(0.0f32, f32::max); 121 | 122 | // Calculate scaling factor 123 | let scale = if group_max > 0.0 { group_max / 127.0 } else { 1.0 }; 124 | 125 | // Quantize the group 126 | let mut group_int8 = Vec::with_capacity(self.group_size); 127 | let mut group_error = 0.0f32; 128 | 129 | for &weight in group { 130 | let quantized = if scale > 0.0 { 131 | // Use banker's rounding to match PyTorch exactly 132 | let scaled = weight / scale; 133 | round_half_to_even(scaled).clamp(-127.0, 127.0) as i8 134 | } else { 135 | 0i8 136 | }; 137 | group_int8.push(quantized); 138 | 139 | // Calculate reconstruction error for this value 140 | let dequantized = f32::from(quantized) * scale; 141 | let error = (dequantized - weight).abs(); 142 | group_error = group_error.max(error); 143 | } 144 | 145 | (group_int8, scale, group_error) 146 | }) 147 | .collect(); 148 | 149 | // Reconstruct results in order 150 | let mut int8_data = Vec::with_capacity(weights.len()); 151 | let mut scales = Vec::with_capacity(num_groups); 152 | let mut max_error = 0.0f32; 153 | 154 | for (group_int8, scale, group_error) in group_results { 155 | int8_data.extend(group_int8); 156 | scales.push(scale); 157 | max_error = max_error.max(group_error); 158 | } 159 | 160 | Ok(QuantizedWeight { int8_data, scales, max_error }) 161 | } 162 | 163 | /// Write binary header 164 | fn write_header(&self, writer: &mut W, header_info: &HeaderInfo) -> Result<()> { 165 | // Magic number "ajc1" in ASCII 166 | writer.write_u32::(Self::MAGIC_NUMBER)?; 167 | 168 | // Version 169 | writer.write_i32::(Self::VERSION)?; 170 | 171 | // Model parameters (10 int32 values) 172 | writer.write_u32::(header_info.architecture_id as u32)?; 173 | writer.write_u32::(self.config.dim)?; 174 | writer.write_u32::(self.config.hidden_dim)?; 175 | writer.write_u32::(self.config.n_layers)?; 176 | writer.write_u32::(self.config.n_heads)?; 177 | writer.write_u32::(self.config.n_kv_heads)?; 178 | writer.write_u32::(self.config.vocab_size)?; 179 | writer.write_u32::(self.config.max_seq_len)?; 180 | writer.write_u32::(self.config.head_dim)?; 181 | writer.write_u32::(header_info.shared_classifier as u32)?; 182 | writer.write_u32::(self.group_size as u32)?; 183 | 184 | // Pad to header size 185 | let current_pos = 4 + 4 + 4 + 10 * 4; // magic + version + architecture_id + 10 params 186 | let padding = Self::HEADER_SIZE - current_pos; 187 | let zeros = vec![0u8; padding]; 188 | writer.write_all(&zeros)?; 189 | 190 | Ok(()) 191 | } 192 | 193 | /// Write normalization weights (fp32). 
194 | fn write_norm_weights( 195 | &self, 196 | architecture: &dyn Architecture, 197 | writer: &mut W, 198 | tensor_reader: &TensorReader, 199 | ) -> Result<()> { 200 | info!("Writing normalization weights..."); 201 | 202 | let mut write_fn = |tensor_name: &str, is_required| -> Result<()> { 203 | match (tensor_reader.load_tensor(tensor_name)?, is_required) { 204 | (Some(attn_norm), _) => { 205 | for &value in &attn_norm { 206 | writer.write_f32::(value)?; 207 | } 208 | } 209 | (None, false) => { 210 | for _ in 0..self.config.head_dim as usize { 211 | writer.write_f32::(1.0)?; 212 | } 213 | } 214 | (None, true) => anyhow::bail!("Missing weight for tensor_name: '{tensor_name}'"), 215 | } 216 | 217 | Ok(()) 218 | }; 219 | 220 | architecture.norm_weight_layers().iter().try_for_each(|&NormWeightLayer { name, layered, is_required }| { 221 | if layered { 222 | for layer_idx in 0..self.config.n_layers { 223 | let layer_name = name.replace("{}", &layer_idx.to_string()); 224 | write_fn(&layer_name, is_required)?; 225 | } 226 | } else { 227 | write_fn(name, is_required)?; 228 | } 229 | 230 | Ok(()) 231 | }) 232 | } 233 | 234 | /// Stream and quantize weights one by one to minimize memory usage (LoRA-aware) 235 | fn stream_and_quantize_weights( 236 | &self, 237 | architecture: &dyn Architecture, 238 | writer: &mut W, 239 | tensor_reader: &TensorReader, 240 | shared_classifier: bool, 241 | model_type: &ModelType, 242 | ) -> Result<()> { 243 | let estimated_capacity = 1 // embed_tokens 244 | + architecture.weight_layers().len() // layer weights 245 | + usize::from(!shared_classifier); // classifier if not shared 246 | 247 | let mut weight_tensors = Vec::with_capacity(estimated_capacity); 248 | 249 | // First: embedding tokens 250 | weight_tensors.push((architecture.embed_tokens_layer().to_string(), None, None)); 251 | 252 | // Then: layer weights 253 | for WeightLayer { tensor_name, component, layer_idx } in architecture.weight_layers() { 254 | weight_tensors.push((tensor_name.clone(), Some(component.to_string()), Some(*layer_idx))); 255 | } 256 | 257 | // Then Classifier if not shared 258 | if !shared_classifier { 259 | weight_tensors.push((architecture.lm_head_layer().to_string(), None, None)); 260 | } 261 | 262 | let lora_merger = if let ModelType::LoRA(lora_config) = model_type { 263 | Some(LoraMerger::new(tensor_reader, lora_config.lora_alpha, lora_config.r)?) 264 | } else { 265 | None 266 | }; 267 | 268 | let progress = ProgressTracker::new(weight_tensors.len(), "Quantizing"); 269 | 270 | // Process each weight tensor individually 271 | let max_errors: Result> = weight_tensors 272 | .iter() 273 | .enumerate() 274 | .map(|(i, (tensor_name, tensor_type, layer_idx))| { 275 | progress.set_current(i + 1, Some(tensor_name)); 276 | 277 | // Load the base tensor (same for both base and LoRA models) 278 | let mut weight_tensor = tensor_reader 279 | .load_tensor(tensor_name)? 280 | .ok_or_else(|| anyhow::anyhow!("Missing weight tensor: {}", tensor_name))?; 281 | 282 | // If LoRA is used, try to merge adapters 283 | if let (Some(lora_merger), Some(layer_idx), Some(component)) = 284 | (lora_merger.as_ref(), layer_idx, tensor_type.as_ref()) 285 | { 286 | if let Some(merged_weights) = 287 | lora_merger.try_merge_lora_adapters(&weight_tensor, component, *layer_idx)? 
288 | { 289 | weight_tensor = merged_weights; 290 | } 291 | } 292 | 293 | if weight_tensor.is_empty() { 294 | warn!("Empty weight tensor: {}", tensor_name); 295 | return Ok(0.0); 296 | } 297 | 298 | // Quantize this tensor 299 | let quantized = self.quantize_q80(&weight_tensor)?; 300 | 301 | // Write quantized data using iterators 302 | quantized.int8_data.iter().try_for_each(|&value| writer.write_i8(value))?; 303 | quantized.scales.iter().try_for_each(|&scale| writer.write_f32::(scale))?; 304 | 305 | Ok(quantized.max_error) 306 | }) 307 | .collect(); 308 | 309 | let max_errors = max_errors?; 310 | 311 | // Print overall max error 312 | let overall_max_error = max_errors.iter().fold(0.0f32, |acc, &x| acc.max(x)); 313 | info!("Quantized {} weight tensors to Q8_0 with max error: {overall_max_error:.8}", weight_tensors.len()); 314 | 315 | Ok(()) 316 | } 317 | } 318 | 319 | /// Round half to even (banker's rounding) to match PyTorch's torch.round() behavior 320 | #[inline] 321 | fn round_half_to_even(x: f32) -> f32 { 322 | // For non-half values, use standard rounding 323 | let rounded = x.round(); 324 | let diff = (x - rounded).abs(); 325 | 326 | // If not exactly halfway, return standard rounding 327 | if diff != 0.5 { 328 | return rounded; 329 | } 330 | 331 | // For exactly halfway cases, round to nearest even 332 | if rounded as i32 % 2 == 0 { 333 | rounded // Already even 334 | } else { 335 | // Make even by rounding toward zero 336 | if x >= 0.0 { rounded - 1.0 } else { rounded + 1.0 } 337 | } 338 | } 339 | -------------------------------------------------------------------------------- /qwen3-export/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | use crate::{ModelInfo, models::qwen3::Qwen3, tensor_reader::TensorReader}; 4 | 5 | mod qwen3; 6 | 7 | /// Architecture ID for binary format identification 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | #[repr(u32)] 10 | pub enum ArchitectureId { 11 | Qwen3ForCausalLM = 1, 12 | LlamaForCausalLM = 2, 13 | } 14 | 15 | impl TryFrom<&str> for ArchitectureId { 16 | type Error = anyhow::Error; 17 | 18 | fn try_from(value: &str) -> Result { 19 | match value { 20 | "Qwen3ForCausalLM" => Ok(Self::Qwen3ForCausalLM), 21 | "LlamaForCausalLM" => Ok(Self::LlamaForCausalLM), 22 | _ => anyhow::bail!("Unknown ArchitectureId: {value}"), 23 | } 24 | } 25 | } 26 | 27 | impl TryFrom for ArchitectureId { 28 | type Error = anyhow::Error; 29 | 30 | fn try_from(value: u32) -> Result { 31 | match value { 32 | 1 => Ok(Self::Qwen3ForCausalLM), 33 | 2 => Ok(Self::LlamaForCausalLM), 34 | _ => anyhow::bail!("Unknown ArchitectureId: {value}"), 35 | } 36 | } 37 | } 38 | 39 | /// Header information structure (lightweight) 40 | #[derive(Debug, Clone)] 41 | pub struct HeaderInfo { 42 | pub architecture_id: u32, 43 | pub shared_classifier: bool, 44 | } 45 | 46 | /// Represents normalization layer. 47 | pub struct NormWeightLayer<'a> { 48 | /// Name of the layer 49 | pub name: &'a str, 50 | /// If set to true, name is a pattern parametrized with layer index 51 | pub layered: bool, 52 | /// If true, error will be returned if the layer not found 53 | /// Otherwise, default(1.0) value will be set. 
54 | pub is_required: bool, 55 | } 56 | 57 | impl<'a> NormWeightLayer<'a> { 58 | pub const fn new(pattern: &'a str, layered: bool, is_required: bool) -> Self { 59 | Self { name: pattern, layered, is_required } 60 | } 61 | } 62 | 63 | pub struct WeightLayer<'a> { 64 | pub tensor_name: String, 65 | pub component: &'a str, 66 | pub layer_idx: u32, 67 | } 68 | 69 | impl<'a> WeightLayer<'a> { 70 | pub fn new(tensor_name: String, component: &'a str, layer_idx: u32) -> Self { 71 | Self { tensor_name, component, layer_idx } 72 | } 73 | } 74 | 75 | pub trait Architecture { 76 | fn id(&self) -> ArchitectureId; 77 | 78 | fn name(&self) -> &'static str; 79 | 80 | fn header(&self) -> Result; 81 | 82 | fn norm_weight_layers(&self) -> &[NormWeightLayer<'_>]; 83 | 84 | fn embed_tokens_layer(&self) -> &'static str; 85 | 86 | fn lm_head_layer(&self) -> &'static str; 87 | 88 | fn weight_layers(&self) -> &[WeightLayer<'_>]; 89 | } 90 | 91 | pub fn create_architecture<'a>(model_info: &ModelInfo, tensor_reader: &'a TensorReader) -> Box { 92 | match model_info.config.architecture { 93 | ArchitectureId::Qwen3ForCausalLM => Box::new(Qwen3::new(model_info, tensor_reader)), 94 | ArchitectureId::LlamaForCausalLM => todo!("LlamaForCausalLM not yet implemented"), 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /qwen3-export/src/models/qwen3.rs: -------------------------------------------------------------------------------- 1 | use crate::tensor_reader::TensorReader; 2 | 3 | use super::*; 4 | 5 | pub struct Qwen3<'a> { 6 | weight_layers: Vec>, 7 | tensor_reader: &'a TensorReader, 8 | } 9 | 10 | impl<'a> Qwen3<'a> { 11 | const ARCH_NAME: &'static str = "Qwen3ForCausalLM"; 12 | const EMBED_TOKENS_KEY: &'static str = "model.embed_tokens.weight"; 13 | const LM_HEAD_KEY: &'static str = "lm_head.weight"; 14 | 15 | #[rustfmt::skip] 16 | const NORM_WEIGHTS_LAYERS: &'static [NormWeightLayer<'static>] = &[ 17 | NormWeightLayer::new("model.layers.{}.input_layernorm.weight", true, true), 18 | NormWeightLayer::new("model.layers.{}.post_attention_layernorm.weight", true, true), 19 | NormWeightLayer::new("model.norm.weight", false, true), 20 | NormWeightLayer::new("model.layers.{}.self_attn.q_norm.weight", true, false), 21 | NormWeightLayer::new("model.layers.{}.self_attn.k_norm.weight", true, false), 22 | ]; 23 | 24 | // Qwen3 model layer weight component names (without .weight suffix) 25 | const QWEN3_LAYER_COMPONENTS: &'static [&'static str] = &[ 26 | "self_attn.q_proj", 27 | "self_attn.k_proj", 28 | "self_attn.v_proj", 29 | "self_attn.o_proj", 30 | "mlp.gate_proj", 31 | "mlp.down_proj", 32 | "mlp.up_proj", 33 | ]; 34 | 35 | pub fn new(model_info: &ModelInfo, tensor_reader: &'a TensorReader) -> Self { 36 | let weight_layers = Self::QWEN3_LAYER_COMPONENTS 37 | .iter() 38 | .flat_map(|&component| { 39 | (0..model_info.config.n_layers).map(move |layer_idx| { 40 | let tensor_name = format!("model.layers.{}.{}.weight", layer_idx, component); 41 | WeightLayer::new(tensor_name, component, layer_idx) 42 | }) 43 | }) 44 | .collect(); 45 | 46 | Self { weight_layers, tensor_reader } 47 | } 48 | } 49 | 50 | impl<'a> Architecture for Qwen3<'a> { 51 | fn id(&self) -> ArchitectureId { 52 | ArchitectureId::Qwen3ForCausalLM 53 | } 54 | 55 | fn name(&self) -> &'static str { 56 | Self::ARCH_NAME 57 | } 58 | 59 | fn header(&self) -> Result { 60 | let shared_classifier = match ( 61 | self.tensor_reader.load_tensor(Self::LM_HEAD_KEY)?, 62 | self.tensor_reader.load_tensor(Self::EMBED_TOKENS_KEY)?, 63 | 
) { 64 | (Some(lm_head_weights), Some(embed_weights)) => { 65 | // Compare tensor values to determine if they're identical 66 | lm_head_weights.len() == embed_weights.len() 67 | && lm_head_weights.iter().zip(embed_weights.iter()).all(|(a, b)| (a - b).abs() < 1e-6) 68 | } 69 | (None, Some(_)) => true, // No lm_head means shared 70 | _ => false, // Missing embed_tokens is an error, but we'll handle it later 71 | }; 72 | 73 | Ok(HeaderInfo { architecture_id: self.id() as u32, shared_classifier }) 74 | } 75 | 76 | fn norm_weight_layers(&self) -> &[NormWeightLayer<'_>] { 77 | &Self::NORM_WEIGHTS_LAYERS 78 | } 79 | 80 | fn embed_tokens_layer(&self) -> &'static str { 81 | Self::EMBED_TOKENS_KEY 82 | } 83 | 84 | fn lm_head_layer(&self) -> &'static str { 85 | Self::LM_HEAD_KEY 86 | } 87 | 88 | fn weight_layers(&self) -> &[WeightLayer<'_>] { 89 | &self.weight_layers 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /qwen3-export/src/tensor_reader.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{Context, Result}; 2 | use log::info; 3 | use memmap2::Mmap; 4 | use safetensors::SafeTensors; 5 | use std::{ 6 | collections::{HashMap, VecDeque}, 7 | fs::File, 8 | mem, 9 | path::{Path, PathBuf}, 10 | sync::{Arc, Mutex}, 11 | }; 12 | 13 | /// Memory-efficient tensor reader from SafeTensors files 14 | #[derive(Debug)] 15 | pub struct TensorReader { 16 | safetensors_files: Vec, // Just store file paths, not data 17 | mmap_cache: Arc>, // LRU cache for memory mappings 18 | } 19 | 20 | impl TensorReader { 21 | pub fn new(model_path: &Path) -> Result { 22 | let safetensors_files = std::fs::read_dir(model_path) 23 | .with_context(|| format!("Failed to read directory: {}", model_path.display()))? 24 | .filter_map(|entry| { 25 | let entry = entry.ok()?; 26 | let path = entry.path(); 27 | 28 | // Check if it's a .safetensors file 29 | matches!(path.extension(), Some(ext) if ext == "safetensors").then_some(path) 30 | }) 31 | .collect::>(); 32 | 33 | if safetensors_files.is_empty() { 34 | anyhow::bail!("No SafeTensors files found in {}", model_path.display()); 35 | } 36 | 37 | info!("Found {} safetensor files", safetensors_files.len()); 38 | 39 | Ok(TensorReader { 40 | safetensors_files, 41 | mmap_cache: Arc::new(Mutex::new(MmapCache::new(10))), // Max 10 cached files 42 | }) 43 | } 44 | 45 | /// Load a specific tensor by name, converting from BF16/F32 to F32 46 | pub fn load_tensor(&self, tensor_name: &str) -> Result>> { 47 | self.safetensors_files 48 | .iter() 49 | .find_map(|filename| { 50 | // Use cached memory mapping 51 | // TODO we deserialize here each time, might be not so efficient 52 | let mmap = self.get_mmap(filename).ok()?; 53 | let safetensors = SafeTensors::deserialize(&mmap) 54 | .with_context(|| format!("Failed to deserialize {}", filename.display())) 55 | .ok()?; 56 | 57 | // Try to find the tensor in this file 58 | safetensors 59 | .tensor(tensor_name) 60 | .ok() 61 | .and_then(|tensor_view| Self::convert_tensor_to_f32(&tensor_view, tensor_name).ok()) 62 | }) 63 | .map_or(Ok(None), |data| Ok(Some(data))) 64 | } 65 | 66 | /// Read all tensor files and lists all available tensor names in the model. 
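TensorReader scans the model directory once for *.safetensors files and memory-maps them lazily, so looking up a tensor only needs the directory and the tensor name. A minimal usage sketch (the directory name is hypothetical; the tensor name matches EMBED_TOKENS_KEY above):

use std::path::Path;

fn print_embedding_stats() -> anyhow::Result<()> {
    // "Qwen3-0.6B" is a hypothetical model directory containing *.safetensors files.
    let reader = TensorReader::new(Path::new("Qwen3-0.6B"))?;
    if let Some(values) = reader.load_tensor("model.embed_tokens.weight")? {
        let max_abs = values.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
        println!("{} values, max |w| = {max_abs}", values.len());
    }
    Ok(())
}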
67 | #[cfg(debug_assertions)] 68 | pub fn list_tensor_names(&self) -> Result> { 69 | let mut all_tensor_names = HashMap::new(); 70 | 71 | for filename in &self.safetensors_files { 72 | let mmap = self.get_mmap(filename)?; 73 | let safetensors = SafeTensors::deserialize(&mmap) 74 | .with_context(|| format!("Failed to deserialize {}", filename.display()))?; 75 | 76 | all_tensor_names.insert( 77 | filename.to_string_lossy().into_owned(), 78 | safetensors.names().iter().map(|&name| name.clone()).collect(), 79 | ); 80 | } 81 | 82 | Ok(all_tensor_names) 83 | } 84 | 85 | /// Convert tensor data to f32 based on its data type 86 | fn convert_tensor_to_f32(tensor_view: &safetensors::tensor::TensorView, tensor_name: &str) -> Result> { 87 | let tensor_data = tensor_view.data(); 88 | let shape = tensor_view.shape(); 89 | let dtype = tensor_view.dtype(); 90 | let expected_elements = shape.iter().product::(); 91 | 92 | match dtype { 93 | safetensors::Dtype::F32 => { 94 | Self::validate_tensor_size( 95 | tensor_data.len(), 96 | expected_elements * mem::size_of::(), 97 | tensor_name, 98 | "F32", 99 | )?; 100 | Ok(Self::convert_f32_data(tensor_data)) 101 | } 102 | safetensors::Dtype::BF16 => { 103 | Self::validate_tensor_size(tensor_data.len(), expected_elements * 2, tensor_name, "BF16")?; 104 | Ok(Self::convert_bf16_data(tensor_data)) 105 | } 106 | _ => anyhow::bail!("Unsupported tensor dtype {:?} for {}", dtype, tensor_name), 107 | } 108 | } 109 | 110 | /// Validate tensor data size matches expected size 111 | fn validate_tensor_size( 112 | actual_bytes: usize, 113 | expected_bytes: usize, 114 | tensor_name: &str, 115 | dtype_name: &str, 116 | ) -> Result<()> { 117 | if actual_bytes != expected_bytes { 118 | anyhow::bail!( 119 | "{} tensor {} size mismatch. Expected {} bytes, got {}", 120 | dtype_name, 121 | tensor_name, 122 | expected_bytes, 123 | actual_bytes 124 | ); 125 | } 126 | Ok(()) 127 | } 128 | 129 | /// Convert F32 tensor data 130 | fn convert_f32_data(data: &[u8]) -> Vec { 131 | data.chunks_exact(mem::size_of::()) 132 | .map(|chunk| { 133 | let bytes: [u8; 4] = chunk.try_into().expect("chunk size is guaranteed to be 4"); 134 | f32::from_le_bytes(bytes) 135 | }) 136 | .collect() 137 | } 138 | 139 | /// Convert BF16 tensor data to F32 140 | fn convert_bf16_data(data: &[u8]) -> Vec { 141 | data.chunks_exact(2) 142 | .map(|chunk| { 143 | let [low, high] = chunk else { unreachable!("chunks_exact(2) guarantees 2 bytes") }; 144 | // BF16 to F32: BF16 is the upper 16 bits of F32 145 | let bf16_bits = u16::from_le_bytes([*low, *high]); 146 | let f32_bits = (bf16_bits as u32) << 16; 147 | f32::from_bits(f32_bits) 148 | }) 149 | .collect() 150 | } 151 | 152 | /// Get or create a cached memory mapping for a file 153 | fn get_mmap(&self, path: &Path) -> Result> { 154 | let mut cache = self.mmap_cache.lock().map_err(|_| anyhow::anyhow!("Failed to acquire cache lock"))?; 155 | 156 | if let Some(cached_mmap) = cache.get(path) { 157 | return Ok(cached_mmap); 158 | } 159 | 160 | // Create new mapping 161 | let file = File::open(path).with_context(|| format!("Failed to open {}", path.display()))?; 162 | 163 | // SAFETY: All file-backed memory map constructors are marked `unsafe` because of the potential for 164 | // *Undefined Behavior* (UB) using the map if the underlying file is subsequently modified, in or 165 | // out of process. 
166 | let mmap = 167 | Arc::new(unsafe { Mmap::map(&file) }.with_context(|| format!("Failed to memory map {}", path.display()))?); 168 | 169 | // Cache it with LRU eviction 170 | cache.insert(path.to_path_buf(), Arc::clone(&mmap)); 171 | Ok(mmap) 172 | } 173 | 174 | /// Clear the memory mapping cache to free memory 175 | #[allow(dead_code)] 176 | pub fn clear_cache(&self) -> Result<()> { 177 | let mut cache = self.mmap_cache.lock().map_err(|_| anyhow::anyhow!("Failed to acquire cache lock"))?; 178 | cache.clear(); 179 | Ok(()) 180 | } 181 | } 182 | 183 | /// Memory-efficient tensor reader with LRU cache 184 | #[derive(Debug)] 185 | struct MmapCache { 186 | cache: HashMap>, 187 | access_order: VecDeque, 188 | max_size: usize, 189 | } 190 | 191 | impl MmapCache { 192 | fn new(max_size: usize) -> Self { 193 | Self { cache: HashMap::new(), access_order: VecDeque::new(), max_size } 194 | } 195 | 196 | fn get(&mut self, path: &Path) -> Option> { 197 | if let Some(mmap) = self.cache.get(path) { 198 | // Move to front (most recently used) 199 | if let Some(pos) = self.access_order.iter().position(|p| p == path) { 200 | self.access_order.remove(pos); 201 | } 202 | self.access_order.push_front(path.to_path_buf()); 203 | Some(Arc::clone(mmap)) 204 | } else { 205 | None 206 | } 207 | } 208 | 209 | fn insert(&mut self, path: PathBuf, mmap: Arc) { 210 | // Remove if already exists 211 | if self.cache.contains_key(&path) { 212 | if let Some(pos) = self.access_order.iter().position(|p| p == &path) { 213 | self.access_order.remove(pos); 214 | } 215 | } 216 | 217 | // Evict least recently used if cache is full 218 | while self.cache.len() >= self.max_size { 219 | if let Some(lru_path) = self.access_order.pop_back() { 220 | self.cache.remove(&lru_path); 221 | } else { 222 | break; 223 | } 224 | } 225 | 226 | // Insert new mapping 227 | self.cache.insert(path.clone(), mmap); 228 | self.access_order.push_front(path); 229 | } 230 | 231 | fn clear(&mut self) { 232 | self.cache.clear(); 233 | self.access_order.clear(); 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /qwen3-export/src/tokenizer_exporter.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | #[path = "../tests/unit/tokenizer_exporter_test.rs"] 3 | mod tests; 4 | 5 | use anyhow::{Context, Result}; 6 | use byteorder::{LittleEndian, WriteBytesExt}; 7 | use log::info; 8 | use serde_json::Value; 9 | use std::{ 10 | collections::HashMap, 11 | fs::File, 12 | io::{BufWriter, Read, Write}, 13 | path::Path, 14 | }; 15 | 16 | /// Tokenizer exporter for converting HuggingFace tokenizers to binary format 17 | #[derive(Debug)] 18 | pub struct TokenizerExporter; 19 | 20 | #[derive(Debug)] 21 | struct TokenData { 22 | vocab: HashMap, 23 | merge_ranks: HashMap, 24 | max_token_length: u32, 25 | } 26 | 27 | /// GPT-2 style Unicode to byte mapping 28 | #[derive(Debug)] 29 | struct UnicodeToByteMap { 30 | mapping: HashMap, 31 | } 32 | 33 | impl UnicodeToByteMap { 34 | const PRINTABLE_ASCII_START: u8 = 33; 35 | const PRINTABLE_ASCII_END: u8 = 126; 36 | const EXTENDED_ASCII_START1: u8 = 161; 37 | const EXTENDED_ASCII_END1: u8 = 172; 38 | const EXTENDED_ASCII_START2: u8 = 174; 39 | const EXTENDED_ASCII_END2: u8 = 255; 40 | const UNICODE_OFFSET: u32 = 256; 41 | 42 | fn new() -> Self { 43 | let mut mapping = HashMap::new(); 44 | 45 | // Printable ASCII characters 46 | for b in Self::PRINTABLE_ASCII_START..=Self::PRINTABLE_ASCII_END { 47 | mapping.insert(b as char, b); 
48 | } 49 | 50 | // Extended ASCII characters 51 | for b in Self::EXTENDED_ASCII_START1..=Self::EXTENDED_ASCII_END1 { 52 | mapping.insert(b as char, b); 53 | } 54 | 55 | for b in Self::EXTENDED_ASCII_START2..=Self::EXTENDED_ASCII_END2 { 56 | mapping.insert(b as char, b); 57 | } 58 | 59 | // Special mappings for unprintable characters 60 | let mut n = 0u8; 61 | for b in 0..=255u8 { 62 | if !mapping.values().any(|&v| v == b) { 63 | mapping.insert(char::from_u32(Self::UNICODE_OFFSET + n as u32).unwrap(), b); 64 | n += 1; 65 | } 66 | } 67 | 68 | Self { mapping } 69 | } 70 | 71 | /// Convert token string to bytes using GPT-2 style mapping 72 | fn token_to_bytes(&self, token_str: &str) -> Vec { 73 | token_str 74 | .chars() 75 | .flat_map(|ch| { 76 | self.mapping.get(&ch).map(|&b| vec![b]).unwrap_or_else(|| ch.to_string().as_bytes().to_vec()) 77 | }) 78 | .collect() 79 | } 80 | } 81 | 82 | impl TokenizerExporter { 83 | const TOKENIZER_FILE_NAME: &'static str = "tokenizer.json"; 84 | const DEFAULT_SCORE: f32 = -1e6; 85 | 86 | /// Create a new TokenizerExporter 87 | pub const fn new() -> Self { 88 | Self 89 | } 90 | 91 | /// Export tokenizer to binary format 92 | pub fn export_tokenizer( 93 | &self, 94 | model_path: &Path, 95 | output_path: &Path, 96 | bos_token_id: u32, 97 | eos_token_id: u32, 98 | ) -> Result<()> { 99 | let token_data = self.load_token_data(model_path)?; 100 | let tokens_by_id = self.create_ordered_tokens(&token_data.vocab); 101 | let u2b_map = UnicodeToByteMap::new(); 102 | 103 | self.write_tokenizer_file(output_path, &token_data, &tokens_by_id, &u2b_map, bos_token_id, eos_token_id) 104 | } 105 | 106 | /// Load and process all token data 107 | fn load_token_data(&self, model_path: &Path) -> Result { 108 | let tokenizer_data = self.load_tokenizer_json(model_path)?; 109 | let vocab = self.extract_vocabulary(&tokenizer_data)?; 110 | 111 | let merge_ranks = self.extract_merge_ranks(&tokenizer_data); 112 | let max_token_length = vocab.keys().map(|token| token.len()).max().unwrap_or(0) as u32; 113 | 114 | info!("📊 Found {} tokens in vocabulary", vocab.len()); 115 | 116 | Ok(TokenData { vocab, merge_ranks, max_token_length }) 117 | } 118 | 119 | /// Load tokenizer.json file 120 | fn load_tokenizer_json(&self, model_path: &Path) -> Result { 121 | let tokenizer_path = model_path.join(Self::TOKENIZER_FILE_NAME); 122 | 123 | if !tokenizer_path.exists() { 124 | anyhow::bail!("tokenizer.json not found in model directory: {}", model_path.display()); 125 | } 126 | 127 | let mut file = File::open(&tokenizer_path)?; 128 | let mut contents = String::new(); 129 | file.read_to_string(&mut contents)?; 130 | 131 | serde_json::from_str(&contents) 132 | .with_context(|| format!("Failed to parse tokenizer.json from {}", tokenizer_path.display())) 133 | } 134 | 135 | /// Create ordered list of tokens by ID 136 | fn create_ordered_tokens(&self, vocab: &HashMap) -> Vec<(u32, String)> { 137 | let mut tokens_by_id: Vec<(u32, String)> = vocab.iter().map(|(token, &id)| (id, token.clone())).collect(); 138 | tokens_by_id.sort_by_key(|&(id, _)| id); 139 | tokens_by_id 140 | } 141 | 142 | /// Write tokenizer binary file 143 | fn write_tokenizer_file( 144 | &self, 145 | output_path: &Path, 146 | token_data: &TokenData, 147 | tokens_by_id: &[(u32, String)], 148 | u2b_map: &UnicodeToByteMap, 149 | bos_token_id: u32, 150 | eos_token_id: u32, 151 | ) -> Result<()> { 152 | let tokenizer_output = format!("{}.tokenizer", output_path.display()); 153 | let file = File::create(&tokenizer_output)?; 154 | let mut writer = 
BufWriter::new(file); 155 | 156 | // Write header 157 | writer.write_u32::(token_data.max_token_length)?; 158 | writer.write_u32::(bos_token_id)?; 159 | writer.write_u32::(eos_token_id)?; 160 | 161 | // Write tokens 162 | for (_, token) in tokens_by_id { 163 | self.write_token(&mut writer, token, &token_data.merge_ranks, u2b_map)?; 164 | } 165 | 166 | writer.flush()?; 167 | info!("💾 Written tokenizer model to {tokenizer_output}"); 168 | Ok(()) 169 | } 170 | 171 | /// Write a single token to the binary file 172 | fn write_token( 173 | &self, 174 | writer: &mut W, 175 | token: &str, 176 | merge_ranks: &HashMap, 177 | u2b_map: &UnicodeToByteMap, 178 | ) -> Result<()> { 179 | // Calculate pseudo-score 180 | let score = merge_ranks.get(token).map(|&rank| -((rank + 1) as f32).ln()).unwrap_or(Self::DEFAULT_SCORE); 181 | 182 | writer.write_f32::(score)?; 183 | 184 | // Convert token to bytes using GPT-2 style mapping 185 | let token_bytes = u2b_map.token_to_bytes(token); 186 | writer.write_u32::(token_bytes.len() as u32)?; 187 | writer.write_all(&token_bytes)?; 188 | 189 | Ok(()) 190 | } 191 | 192 | /// Extract vocabulary from tokenizer data 193 | fn extract_vocabulary(&self, tokenizer_data: &Value) -> Result> { 194 | // Extract vocabulary from model/vocab 195 | let mut vocab: HashMap = 196 | if let Some(vocab_obj) = tokenizer_data.pointer("/model/vocab").and_then(|v| v.as_object()) { 197 | vocab_obj.iter().filter_map(|(token, id)| id.as_u64().map(|id| (token.clone(), id as u32))).collect() 198 | } else { 199 | anyhow::bail!("Could not find vocabulary in tokenizer.json") 200 | }; 201 | 202 | info!("📚 Found {} tokens in model/vocab", vocab.len()); 203 | 204 | // Add tokens from added_tokens array 205 | if let Some(added_tokens) = tokenizer_data.pointer("/added_tokens").and_then(|v| v.as_array()) { 206 | for token_obj in added_tokens { 207 | if let (Some(content), Some(id)) = ( 208 | token_obj.pointer("/content").and_then(|v| v.as_str()), 209 | token_obj.pointer("/id").and_then(|v| v.as_u64()), 210 | ) { 211 | vocab.insert(content.to_string(), id as u32); 212 | } 213 | } 214 | 215 | info!("📝 Added {} tokens from added_tokens", added_tokens.len()); 216 | } 217 | 218 | info!("📖 Total vocabulary size: {}", vocab.len()); 219 | 220 | Ok(vocab) 221 | } 222 | 223 | /// Extract merge ranks from tokenizer data 224 | fn extract_merge_ranks(&self, tokenizer_data: &Value) -> HashMap { 225 | tokenizer_data 226 | .pointer("/model/merges") 227 | .and_then(|m| m.as_array()) 228 | .map(|merges| { 229 | merges 230 | .iter() 231 | .enumerate() 232 | .filter_map(|(rank, merge)| merge.as_str().map(|merge_str| (merge_str.to_string(), rank))) 233 | .collect() 234 | }) 235 | .unwrap_or_default() 236 | } 237 | } 238 | 239 | impl Default for TokenizerExporter { 240 | fn default() -> Self { 241 | Self::new() 242 | } 243 | } 244 | -------------------------------------------------------------------------------- /qwen3-export/src/utils.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io::{self, Write}, 3 | sync::atomic::{AtomicUsize, Ordering}, 4 | }; 5 | 6 | /// Progress tracker for showing progress bars 7 | #[derive(Debug)] 8 | pub(crate) struct ProgressTracker { 9 | total: usize, 10 | completed: AtomicUsize, 11 | last_displayed: AtomicUsize, 12 | label: String, 13 | } 14 | 15 | impl ProgressTracker { 16 | pub fn new(total: usize, label: &str) -> Self { 17 | Self { total, completed: AtomicUsize::new(0), last_displayed: AtomicUsize::new(0), label: label.to_string() } 18 | } 
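The tokenizer file written above is a flat little-endian stream: a 12-byte header (max_token_length, bos_token_id, eos_token_id as u32), then one record per token in id order, each record holding an f32 pseudo-score (-ln(rank + 1) for merged tokens, -1e6 otherwise), a u32 byte length, and the raw token bytes. A minimal reader for that layout, assuming records simply run to end-of-file (illustrative only):

use byteorder::{LittleEndian, ReadBytesExt};
use std::io::Read;

struct TokenRecord {
    score: f32,
    bytes: Vec<u8>,
}

fn read_exported_tokenizer<R: Read>(mut r: R) -> anyhow::Result<(u32, u32, u32, Vec<TokenRecord>)> {
    let max_token_length = r.read_u32::<LittleEndian>()?;
    let bos_token_id = r.read_u32::<LittleEndian>()?;
    let eos_token_id = r.read_u32::<LittleEndian>()?;

    let mut tokens = Vec::new();
    // Records run until end-of-file; stop on the first failed score read.
    while let Ok(score) = r.read_f32::<LittleEndian>() {
        let len = r.read_u32::<LittleEndian>()? as usize;
        let mut bytes = vec![0u8; len];
        r.read_exact(&mut bytes)?;
        tokens.push(TokenRecord { score, bytes });
    }
    Ok((max_token_length, bos_token_id, eos_token_id, tokens))
}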
19 | 20 | pub fn set_current(&self, current: usize, description: Option<&str>) { 21 | self.completed.store(current, Ordering::Relaxed); 22 | let percent = (current * 100) / self.total; 23 | let last_displayed = self.last_displayed.load(Ordering::Relaxed); 24 | let last_percent = (last_displayed * 100) / self.total; 25 | 26 | // Update display every 1% or on key milestones 27 | if current == 0 || percent > last_percent || current >= self.total || current - last_displayed >= 10 { 28 | self.last_displayed.store(current, Ordering::Relaxed); 29 | let bar_width = 30; 30 | let filled = (current * bar_width) / self.total; 31 | let bar = "█".repeat(filled) + &"░".repeat(bar_width - filled); 32 | 33 | print!( 34 | "\r{}: [{bar}] {current}/{} ({percent}%): {}", 35 | self.label, 36 | self.total, 37 | fixed_len(description.unwrap_or_default(), 42) 38 | ); 39 | io::stdout().flush().unwrap_or(()); 40 | 41 | if current >= self.total { 42 | println!(); // New line when complete 43 | } 44 | } 45 | } 46 | } 47 | 48 | fn fixed_len(description: &str, width: usize) -> String { 49 | let mut desc = description.to_string(); 50 | if desc.len() > width { 51 | // Cut and add ".." 52 | desc.truncate(width.saturating_sub(2)); 53 | desc.push_str(".."); 54 | } else if desc.len() < width { 55 | // Pad with spaces 56 | desc = format!("{:width$}", desc, width = width); 57 | } 58 | desc 59 | } 60 | 61 | #[cfg(test)] 62 | mod tests { 63 | 64 | use rayon::prelude::*; 65 | 66 | #[test] 67 | fn test_par_chunks_zip() { 68 | let mut data1 = vec![1, 2, 3, 4, 5, 6, 7, 8]; 69 | let mut data2 = vec![10, 20, 30, 40, 50, 60, 70, 80]; 70 | let n_heads = 4; 71 | 72 | data1.par_chunks_mut(2).zip(data2.par_chunks_mut(2)).zip((0..n_heads).into_par_iter()).for_each( 73 | |((chunk1, chunk2), head_idx)| { 74 | for (a, b) in chunk1.iter_mut().zip(chunk2.iter_mut()) { 75 | *a += head_idx; 76 | *b += head_idx * 10; 77 | } 78 | }, 79 | ); 80 | 81 | assert_eq!(data1, vec![1, 2, 4, 5, 7, 8, 10, 11]); 82 | assert_eq!(data2, vec![10, 20, 40, 50, 70, 80, 100, 110]); 83 | } 84 | 85 | #[test] 86 | fn test_attention_pattern() { 87 | struct MockState { 88 | att: Vec, 89 | xb: Vec, 90 | } 91 | 92 | let seq_len = 4; 93 | let head_dim = 8; 94 | let n_heads = 4; 95 | 96 | let mut state = MockState { att: vec![0.0; n_heads * seq_len], xb: vec![0.0; n_heads * head_dim] }; 97 | 98 | state 99 | .att 100 | .par_chunks_mut(seq_len) 101 | .zip(state.xb.par_chunks_mut(head_dim)) 102 | .zip((0..n_heads).into_par_iter()) 103 | .for_each(|((att_slice, xb_slice), head_idx)| { 104 | for val in att_slice.iter_mut() { 105 | *val = head_idx as f32; 106 | } 107 | for val in xb_slice.iter_mut() { 108 | *val = (head_idx * 10) as f32; 109 | } 110 | }); 111 | 112 | for head in 0..n_heads { 113 | for i in 0..seq_len { 114 | assert_eq!(state.att[head * seq_len + i], head as f32); 115 | } 116 | for i in 0..head_dim { 117 | assert_eq!(state.xb[head * head_dim + i], (head * 10) as f32); 118 | } 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /qwen3-export/tests/unit/config_loader_test.rs: -------------------------------------------------------------------------------- 1 | //! 
Integration tests for config loader functionality 2 | 3 | use super::*; 4 | use anyhow::Result; 5 | use std::{fs, path::PathBuf}; 6 | use tempfile::TempDir; 7 | 8 | /// Helper to create a minimal config.json for testing 9 | fn create_test_config_json(temp_dir: &TempDir) -> Result { 10 | let config_content = r#"{ 11 | "architectures": ["Qwen3ForCausalLM"], 12 | "hidden_size": 256, 13 | "intermediate_size": 1024, 14 | "num_hidden_layers": 4, 15 | "num_attention_heads": 8, 16 | "num_key_value_heads": 8, 17 | "vocab_size": 1000, 18 | "max_position_embeddings": 512, 19 | "rms_norm_eps": 1e-6, 20 | "head_dim": 32, 21 | "bos_token_id": 1, 22 | "eos_token_id": 2 23 | }"#; 24 | let config_path = temp_dir.path().join("config.json"); 25 | fs::write(config_path.clone(), config_content)?; 26 | 27 | Ok(config_path) 28 | } 29 | 30 | #[test] 31 | fn test_load_hf_config_valid() -> Result<()> { 32 | let temp_dir = TempDir::new()?; 33 | let path = create_test_config_json(&temp_dir)?; 34 | 35 | let config = load_hf_config(&path)?; 36 | 37 | // Verify all fields are loaded correctly 38 | assert_eq!(config.dim, 256); 39 | assert_eq!(config.hidden_dim, 1024); 40 | assert_eq!(config.n_layers, 4); 41 | assert_eq!(config.n_heads, 8); 42 | assert_eq!(config.n_kv_heads, 8); 43 | assert_eq!(config.vocab_size, 1000); 44 | assert_eq!(config.max_seq_len, 512); 45 | assert_eq!(config.head_dim, 32); 46 | assert!((config.norm_eps - 1e-6).abs() < 1e-9); 47 | assert_eq!(config.bos_token_id, 1); 48 | assert_eq!(config.eos_token_id, 2); 49 | 50 | Ok(()) 51 | } 52 | 53 | #[test] 54 | fn test_load_hf_config_invalid_json() -> Result<()> { 55 | let temp_dir = TempDir::new()?; 56 | let config_path = temp_dir.path().join("config.json"); 57 | fs::write(config_path.clone(), "invalid json")?; 58 | 59 | let result = load_hf_config(&config_path); 60 | assert!(result.is_err()); 61 | assert_eq!(result.unwrap_err().to_string(), "Failed to parse config.json: expected value at line 1 column 1"); 62 | 63 | Ok(()) 64 | } 65 | 66 | #[test] 67 | fn test_load_hf_config_missing_required_field() -> Result<()> { 68 | let temp_dir = TempDir::new()?; 69 | 70 | // Config missing required "hidden_size" field 71 | let config_content = r#"{ 72 | "intermediate_size": 1024, 73 | "num_hidden_layers": 4 74 | }"#; 75 | 76 | let config_path = temp_dir.path().join("config.json"); 77 | fs::write(config_path.clone(), config_content)?; 78 | 79 | let result = load_hf_config(&config_path); 80 | assert!(result.is_err()); 81 | assert_eq!( 82 | result.unwrap_err().to_string(), 83 | "Failed to parse config.json: missing field `hidden_size` at line 4 column 5" 84 | ); 85 | 86 | Ok(()) 87 | } 88 | 89 | #[test] 90 | fn test_load_hf_config_with_defaults() -> Result<()> { 91 | let temp_dir = TempDir::new()?; 92 | 93 | // Config without optional fields (bos_token_id, eos_token_id, head_dim) 94 | let config_content = r#"{ 95 | "architectures": ["Qwen3ForCausalLM"], 96 | "hidden_size": 256, 97 | "intermediate_size": 1024, 98 | "num_hidden_layers": 4, 99 | "num_attention_heads": 8, 100 | "num_key_value_heads": 8, 101 | "vocab_size": 1000, 102 | "max_position_embeddings": 512, 103 | "rms_norm_eps": 1e-6 104 | }"#; 105 | 106 | let config_path = temp_dir.path().join("config.json"); 107 | fs::write(config_path.clone(), config_content)?; 108 | 109 | let config = load_hf_config(&config_path)?; 110 | 111 | // Check defaults are applied 112 | assert_eq!(config.bos_token_id, 0); // default 113 | assert_eq!(config.eos_token_id, 0); // default 114 | assert_eq!(config.head_dim, 256 / 8); // 
calculated: dim / n_heads 115 | 116 | Ok(()) 117 | } 118 | -------------------------------------------------------------------------------- /qwen3-inference/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "qwen3-inference" 3 | description = "Qwen3 inference code." 4 | version.workspace = true 5 | edition.workspace = true 6 | authors.workspace = true 7 | license.workspace = true 8 | keywords.workspace = true 9 | categories.workspace = true 10 | repository.workspace = true 11 | 12 | [dependencies] 13 | anyhow = { workspace = true } 14 | byteorder = { workspace = true } 15 | rayon = { workspace = true } 16 | memmap2 = { workspace = true } 17 | log = { workspace = true } -------------------------------------------------------------------------------- /qwen3-inference/src/configuration.rs: -------------------------------------------------------------------------------- 1 | use std::io::Cursor; 2 | 3 | use crate::utils::MemoryMapper; 4 | use anyhow::{Context, Error, Result}; 5 | use byteorder::{LittleEndian, ReadBytesExt}; 6 | 7 | /// Magic number for validating checkpoint files 8 | const CHECKPOINT_MAGIC: i32 = 0x616a6331; 9 | /// Expected checkpoint version 10 | const CHECKPOINT_VERSION: i32 = 1; 11 | /// Size of the checkpoint header in bytes 12 | const HEADER_SIZE: usize = 256; 13 | /// Size of config structure in bytes. 14 | const CONFIG_SIZE: usize = std::mem::size_of::<Config>(); 15 | 16 | /// Configuration struct for transformer models. 17 | #[derive(Debug, Clone)] 18 | pub struct ModelConfig { 19 | pub architecture_id: usize, 20 | pub dim: usize, 21 | pub hidden_dim: usize, 22 | pub n_layers: usize, 23 | pub n_heads: usize, 24 | pub n_kv_heads: usize, 25 | pub head_dim: usize, 26 | pub seq_len: usize, 27 | pub vocab_size: usize, 28 | pub group_size: usize, 29 | pub shared_classifier: bool, 30 | } 31 | 32 | /// Configuration struct for reading model parameters from checkpoint files. 33 | #[derive(Debug, Clone, Copy)] 34 | #[repr(C)] 35 | struct Config { 36 | pub magic_number: i32, 37 | pub version: i32, 38 | pub architecture_id: i32, 39 | pub dim: i32, 40 | pub hidden_dim: i32, 41 | pub n_layers: i32, 42 | pub n_heads: i32, 43 | pub n_kv_heads: i32, 44 | pub vocab_size: i32, 45 | pub seq_len: i32, 46 | pub head_dim: i32, 47 | pub shared_classifier: i32, 48 | pub group_size: i32, 49 | } 50 | 51 | impl TryInto<ModelConfig> for Config { 52 | type Error = Error; 53 | 54 | fn try_into(self) -> Result<ModelConfig> { 55 | validate_config(&self).with_context(|| "Invalid model configuration")?; 56 | 57 | Ok(ModelConfig { 58 | architecture_id: self.architecture_id as usize, 59 | dim: self.dim as usize, 60 | hidden_dim: self.hidden_dim as usize, 61 | n_layers: self.n_layers as usize, 62 | n_heads: self.n_heads as usize, 63 | n_kv_heads: self.n_kv_heads as usize, 64 | head_dim: self.head_dim as usize, 65 | seq_len: self.seq_len as usize, 66 | vocab_size: self.vocab_size as usize, 67 | group_size: self.group_size as usize, 68 | shared_classifier: self.shared_classifier != 0, 69 | }) 70 | } 71 | } 72 | 73 | /// Reads and validates the model configuration from checkpoint data (mapper). 74 | /// 75 | /// The configuration is stored as 13 consecutive i32 values (including the magic number and version) in little-endian format. 76 | /// This function performs bounds checking and validates the magic number and version.
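These constants line up with the exporter's write_header: Config has 13 i32 fields, so CONFIG_SIZE is 52 bytes, and read_config (below) skips the remaining HEADER_SIZE - CONFIG_SIZE = 204 bytes of zero padding before the weights begin. A small compile-time sanity check of that arithmetic (illustrative only):

const HEADER_SIZE: usize = 256;
const N_CONFIG_FIELDS: usize = 13; // magic, version, architecture_id + 10 model parameters
const CONFIG_SIZE: usize = N_CONFIG_FIELDS * std::mem::size_of::<i32>();
const PADDING: usize = HEADER_SIZE - CONFIG_SIZE;

const _: () = assert!(CONFIG_SIZE == 52);
const _: () = assert!(PADDING == 204); // bytes skipped by `mapper.skip(...)` in read_config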
77 | pub fn read_config(mapper: &mut MemoryMapper) -> Result { 78 | let data = mapper.get_bytes(CONFIG_SIZE)?; 79 | 80 | if data.len() != CONFIG_SIZE { 81 | anyhow::bail!("Insufficient data for config: need {} bytes, got {}", CONFIG_SIZE, data.len()); 82 | } 83 | 84 | let mut cursor = Cursor::new(data); 85 | 86 | // Use a macro to reduce repetitive error handling 87 | macro_rules! read_i32 { 88 | ($field:literal) => { 89 | cursor.read_i32::().with_context(|| format!("Failed to read {}", $field))? 90 | }; 91 | } 92 | 93 | let config = Config { 94 | magic_number: read_i32!("magic number"), 95 | version: read_i32!("version"), 96 | architecture_id: read_i32!("architecture id"), 97 | dim: read_i32!("dimension"), 98 | hidden_dim: read_i32!("hidden dimension"), 99 | n_layers: read_i32!("number of layers"), 100 | n_heads: read_i32!("number of heads"), 101 | n_kv_heads: read_i32!("number of KV heads"), 102 | vocab_size: read_i32!("vocabulary size"), 103 | seq_len: read_i32!("sequence length"), 104 | head_dim: read_i32!("head dimension"), 105 | shared_classifier: read_i32!("shared classifier flag"), 106 | group_size: read_i32!("group size"), 107 | }; 108 | 109 | // prepare to load model weights (skip header). 110 | mapper.skip(HEADER_SIZE - CONFIG_SIZE)?; 111 | 112 | config.try_into() 113 | } 114 | 115 | /// Validates the model configuration to ensure it's supported. 116 | fn validate_config(config: &Config) -> Result<()> { 117 | match config.magic_number { 118 | CHECKPOINT_MAGIC => {} 119 | actual => anyhow::bail!("Invalid checkpoint magic number: expected {:#x}, got {:#x}", CHECKPOINT_MAGIC, actual), 120 | } 121 | 122 | match config.version { 123 | CHECKPOINT_VERSION => {} 124 | actual => anyhow::bail!("Unsupported checkpoint version: expected {}, got {}", CHECKPOINT_VERSION, actual), 125 | } 126 | 127 | // Validate positive dimensions 128 | let dimensions = [ 129 | ("architecture_id", config.architecture_id), 130 | ("dim", config.dim), 131 | ("n_layers", config.n_layers), 132 | ("n_heads", config.n_heads), 133 | ("n_kv_heads", config.n_kv_heads), 134 | ("vocab_size", config.vocab_size), 135 | ("seq_len", config.seq_len), 136 | ("head_dim", config.head_dim), 137 | ]; 138 | 139 | for (name, value) in dimensions { 140 | if value <= 0 { 141 | anyhow::bail!("Invalid {}: must be positive, got {}", name, value); 142 | } 143 | } 144 | 145 | Ok(()) 146 | } 147 | -------------------------------------------------------------------------------- /qwen3-inference/src/generation.rs: -------------------------------------------------------------------------------- 1 | use crate::models::Transformer; 2 | use crate::sampler::Sampler; 3 | use crate::tokenizer::Tokenizer; 4 | use anyhow::Result; 5 | use log::info; 6 | use std::io::{self, Write}; 7 | use std::time::Instant; 8 | 9 | pub fn generate( 10 | transformer: &mut T, 11 | tokenizer: &Tokenizer, 12 | sampler: &mut Sampler, 13 | prompt: Option<&str>, 14 | ) -> Result<()> { 15 | let prompt = prompt.unwrap_or(""); 16 | let prompt_tokens = tokenizer.encode(prompt); 17 | 18 | if prompt_tokens.is_empty() { 19 | anyhow::bail!("Please provide a prompt"); 20 | } 21 | 22 | let seq_len = transformer.get_config().seq_len; 23 | let mut state = GenerationState::new(prompt_tokens[0]); 24 | 25 | while state.pos < seq_len { 26 | let next_token = if state.pos < prompt_tokens.len() - 1 { 27 | // Still processing prompt tokens 28 | prompt_tokens[state.pos + 1] 29 | } else { 30 | // Generate new tokens 31 | state.metrics.start_generation(); 32 | let next = 
generate_next_token(transformer, sampler, state.token, state.pos)?; 33 | state.metrics.increment_token(); 34 | 35 | if is_termination_token(next, tokenizer) { 36 | break; 37 | } 38 | next 39 | }; 40 | 41 | output_token(tokenizer, state.token)?; 42 | state.advance(next_token); 43 | } 44 | 45 | state.metrics.report_and_reset(); 46 | println!(); 47 | Ok(()) 48 | } 49 | 50 | pub fn chat( 51 | transformer: &mut T, 52 | tokenizer: &Tokenizer, 53 | sampler: &mut Sampler, 54 | cli_user_prompt: Option<&str>, 55 | system_prompt: Option<&str>, 56 | ) -> Result<()> { 57 | let stdin = io::stdin(); 58 | let seq_len = transformer.get_config().seq_len; 59 | let mut state = GenerationState::new(0); 60 | let mut user_turn = true; 61 | let mut next_token = 0; 62 | 63 | loop { 64 | // Reset context if window exceeded 65 | if state.pos >= seq_len { 66 | state.reset(0); 67 | user_turn = true; 68 | println!(); 69 | } 70 | 71 | if user_turn { 72 | state.metrics.report_and_reset(); 73 | 74 | if !handle_user_turn( 75 | &stdin, 76 | transformer, 77 | tokenizer, 78 | sampler, 79 | &mut state, 80 | &mut next_token, 81 | cli_user_prompt, 82 | system_prompt, 83 | )? { 84 | break; 85 | } 86 | user_turn = false; 87 | } else if handle_assistant_turn(transformer, tokenizer, sampler, &mut state, &mut next_token, &mut user_turn)? { 88 | continue; // Turn ended, continue to next iteration 89 | } 90 | } 91 | 92 | Ok(()) 93 | } 94 | 95 | fn handle_user_turn( 96 | stdin: &io::Stdin, 97 | transformer: &mut T, 98 | tokenizer: &Tokenizer, 99 | sampler: &mut Sampler, 100 | state: &mut GenerationState, 101 | next_token: &mut usize, 102 | cli_user_prompt: Option<&str>, 103 | system_prompt: Option<&str>, 104 | ) -> Result { 105 | let user_prompt = get_user_input(stdin, state.pos, cli_user_prompt)?; 106 | 107 | // Check if we should exit 108 | if user_prompt.is_empty() && !(state.pos == 0 && cli_user_prompt.is_some()) { 109 | return Ok(false); 110 | } 111 | 112 | let rendered_prompt = render_prompt(state.pos, system_prompt, &user_prompt, tokenizer); 113 | let prompt_tokens = tokenizer.encode(&rendered_prompt); 114 | 115 | // Process prompt tokens 116 | for &token in &prompt_tokens { 117 | if state.pos >= transformer.get_config().seq_len { 118 | break; 119 | } 120 | 121 | *next_token = generate_next_token(transformer, sampler, token, state.pos)?; 122 | state.advance(token); 123 | } 124 | 125 | Ok(true) 126 | } 127 | 128 | fn handle_assistant_turn( 129 | transformer: &mut T, 130 | tokenizer: &Tokenizer, 131 | sampler: &mut Sampler, 132 | state: &mut GenerationState, 133 | next_token: &mut usize, 134 | user_turn: &mut bool, 135 | ) -> Result { 136 | if is_termination_token(*next_token, tokenizer) { 137 | state.metrics.report_and_reset(); 138 | println!(); 139 | *user_turn = true; 140 | return Ok(true); 141 | } 142 | 143 | state.metrics.start_generation(); 144 | output_token(tokenizer, *next_token)?; 145 | 146 | *next_token = generate_next_token(transformer, sampler, *next_token, state.pos)?; 147 | state.metrics.increment_token(); 148 | state.advance(*next_token); 149 | 150 | Ok(false) 151 | } 152 | 153 | fn generate_next_token( 154 | transformer: &mut T, 155 | sampler: &mut Sampler, 156 | token: usize, 157 | pos: usize, 158 | ) -> Result { 159 | let logits = transformer.forward(token, pos); 160 | let mut logits_copy = logits.to_vec(); 161 | Ok(sampler.sample(&mut logits_copy)) 162 | } 163 | 164 | fn output_token(tokenizer: &Tokenizer, token: usize) -> Result<()> { 165 | print!("{}", tokenizer.decode(token)); 166 | io::stdout().flush()?; 167 
| Ok(()) 168 | } 169 | 170 | fn is_termination_token(token: usize, tokenizer: &Tokenizer) -> bool { 171 | token == tokenizer.bos_token_id as usize || token == tokenizer.eos_token_id as usize 172 | } 173 | 174 | fn get_user_input(stdin: &io::Stdin, pos: usize, cli_user_prompt: Option<&str>) -> Result<String> { 175 | match (pos, cli_user_prompt) { 176 | (0, Some(prompt)) => Ok(prompt.to_string()), 177 | (_, Some(_)) => Ok(String::new()), // Signal to break 178 | _ => { 179 | print!("> "); 180 | io::stdout().flush()?; 181 | let mut input = String::new(); 182 | stdin.read_line(&mut input)?; 183 | Ok(input.trim().to_string()) 184 | } 185 | } 186 | } 187 | 188 | fn render_prompt(pos: usize, system_prompt: Option<&str>, user_prompt: &str, tokenizer: &Tokenizer) -> String { 189 | match (pos, system_prompt) { 190 | (0, Some(sys_prompt)) => { 191 | tokenizer.system_prompt_template.replace("%s", &format!("{sys_prompt}\n{user_prompt}")) 192 | } 193 | _ => tokenizer.prompt_template.replace("%s", user_prompt), 194 | } 195 | } 196 | 197 | /// Tracks token generation performance metrics 198 | struct TokenMetrics { 199 | start_time: Option<Instant>, 200 | generated_count: usize, 201 | } 202 | 203 | impl TokenMetrics { 204 | fn new() -> Self { 205 | Self { start_time: None, generated_count: 0 } 206 | } 207 | 208 | fn start_generation(&mut self) { 209 | if self.start_time.is_none() { 210 | self.start_time = Some(Instant::now()); 211 | } 212 | } 213 | 214 | fn increment_token(&mut self) { 215 | self.generated_count += 1; 216 | } 217 | 218 | fn report_and_reset(&mut self) { 219 | if let Some(start_time) = self.start_time.take() { 220 | let duration = start_time.elapsed(); 221 | if self.generated_count > 0 && duration.as_secs_f64() > 0.0 { 222 | let tps = self.generated_count as f64 / duration.as_secs_f64(); 223 | info!( 224 | "\n[Generated {} tokens in {:.2}s - {:.2} tokens/sec]", 225 | self.generated_count, 226 | duration.as_secs_f64(), 227 | tps 228 | ); 229 | } 230 | } 231 | self.generated_count = 0; 232 | } 233 | } 234 | 235 | /// Represents the current generation state 236 | struct GenerationState { 237 | pos: usize, 238 | token: usize, 239 | metrics: TokenMetrics, 240 | } 241 | 242 | impl GenerationState { 243 | fn new(initial_token: usize) -> Self { 244 | Self { pos: 0, token: initial_token, metrics: TokenMetrics::new() } 245 | } 246 | 247 | fn reset(&mut self, initial_token: usize) { 248 | self.metrics.report_and_reset(); 249 | self.pos = 0; 250 | self.token = initial_token; 251 | } 252 | 253 | fn advance(&mut self, next_token: usize) { 254 | self.token = next_token; 255 | self.pos += 1; 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /qwen3-inference/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Inference code for Qwen3 models. 2 | //! 3 | //! Provides checkpoint loading, tokenization, sampling, and text generation for models exported to the binary checkpoint format.
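render_prompt above only performs %s substitution; the chat framing itself lives in the template strings carried by the tokenizer (on the first turn with a system prompt, the system and user text are joined with a newline and substituted into the system template's single %s). A toy illustration with a hypothetical ChatML-style template; the real strings are read from the exported model, so treat this literal as a stand-in:

fn demo_render() {
    // Hypothetical template literal (stand-in for tokenizer.prompt_template).
    let prompt_template = "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";
    let rendered = prompt_template.replace("%s", "What does Q8_0 mean?");
    assert!(rendered.starts_with("<|im_start|>user"));
    println!("{rendered}");
}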
4 | 5 | mod configuration; 6 | mod generation; 7 | mod layers; 8 | mod models; 9 | mod sampler; 10 | mod tensor; 11 | mod tokenizer; 12 | mod utils; 13 | 14 | use anyhow::Result; 15 | use log::debug; 16 | use std::time::{SystemTime, UNIX_EPOCH}; 17 | 18 | pub use crate::models::{Transformer, TransformerBuilder}; 19 | 20 | use crate::generation::{chat, generate}; 21 | use crate::sampler::Sampler; 22 | use crate::tokenizer::Tokenizer; 23 | 24 | #[derive(Debug, Clone)] 25 | pub struct InferenceConfig { 26 | pub checkpoint_path: String, 27 | pub temperature: f32, 28 | pub topp: f32, 29 | pub ctx_length: Option, 30 | pub mode: String, 31 | pub prompt: Option, 32 | pub system_prompt: Option, 33 | pub enable_thinking: bool, 34 | pub seed: u64, 35 | } 36 | 37 | impl InferenceConfig { 38 | pub fn builder() -> InferenceConfigBuilder { 39 | InferenceConfigBuilder::default() 40 | } 41 | } 42 | 43 | #[derive(Debug, Default)] 44 | pub struct InferenceConfigBuilder { 45 | checkpoint_path: Option, 46 | temperature: Option, 47 | topp: Option, 48 | ctx_length: Option, 49 | mode: Option, 50 | prompt: Option, 51 | system_prompt: Option, 52 | enable_thinking: Option, 53 | seed: Option, 54 | } 55 | 56 | impl InferenceConfigBuilder { 57 | pub fn checkpoint_path(mut self, path: Option<&String>) -> Self { 58 | self.checkpoint_path = path.cloned(); 59 | self 60 | } 61 | pub fn temperature(mut self, temperature: Option) -> Self { 62 | self.temperature = temperature; 63 | self 64 | } 65 | pub fn topp(mut self, topp: Option) -> Self { 66 | self.topp = topp; 67 | self 68 | } 69 | pub fn ctx_length(mut self, ctx_length: Option) -> Self { 70 | self.ctx_length = ctx_length; 71 | self 72 | } 73 | pub fn mode(mut self, mode: Option<&String>) -> Self { 74 | self.mode = mode.cloned(); 75 | self 76 | } 77 | pub fn prompt(mut self, prompt: Option<&String>) -> Self { 78 | self.prompt = prompt.cloned(); 79 | self 80 | } 81 | pub fn system_prompt(mut self, system_prompt: Option<&String>) -> Self { 82 | self.system_prompt = system_prompt.cloned(); 83 | self 84 | } 85 | pub fn enable_thinking(mut self, enable: Option) -> Self { 86 | self.enable_thinking = enable; 87 | self 88 | } 89 | pub fn seed(mut self, seed: Option) -> Self { 90 | self.seed = seed; 91 | self 92 | } 93 | pub fn build(self) -> Result { 94 | Ok(InferenceConfig { 95 | checkpoint_path: self.checkpoint_path.ok_or("checkpoint_path is required")?, 96 | temperature: self.temperature.unwrap_or(1.0), 97 | topp: self.topp.unwrap_or(0.9), 98 | ctx_length: self.ctx_length, 99 | mode: self.mode.unwrap_or_else(|| "chat".to_string()), 100 | prompt: self.prompt, 101 | system_prompt: self.system_prompt, 102 | enable_thinking: self.enable_thinking.unwrap_or(false), 103 | seed: self.seed.unwrap_or_else(|| SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs()), 104 | }) 105 | } 106 | } 107 | 108 | /// Runs inference. 
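InferenceConfig::builder lets callers set only the knobs they care about; everything else falls back to the defaults in build (temperature 1.0, top-p 0.9, chat mode, thinking disabled, time-based seed). A usage sketch with a hypothetical checkpoint path:

fn example_config() -> anyhow::Result<InferenceConfig> {
    // Hypothetical path to a checkpoint produced by qwen3-export.
    let checkpoint = String::from("qwen3-0.6b.bin");

    InferenceConfig::builder()
        .checkpoint_path(Some(&checkpoint))
        .temperature(Some(0.7))
        .topp(Some(0.95))
        .mode(Some(&"generate".to_string()))
        .prompt(Some(&"Write a haiku about Rust".to_string()))
        .build()
}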
109 | pub fn run_inference(mut transformer: T, inference_config: InferenceConfig) -> Result<()> { 110 | debug!("{inference_config:#?}"); 111 | 112 | let transformer_config = transformer.get_config(); 113 | 114 | let tokenizer = Tokenizer::new( 115 | &inference_config.checkpoint_path, 116 | transformer_config.vocab_size, 117 | inference_config.enable_thinking, 118 | )?; 119 | 120 | debug!("{tokenizer:#?}"); 121 | 122 | let mut sampler = Sampler::new( 123 | transformer_config.vocab_size, 124 | inference_config.temperature, 125 | inference_config.topp, 126 | inference_config.seed, 127 | ); 128 | 129 | let prompt = inference_config.prompt.as_deref(); 130 | let system_prompt = inference_config.system_prompt.as_deref(); 131 | 132 | // Run 133 | match inference_config.mode.as_str() { 134 | "generate" => generate(&mut transformer, &tokenizer, &mut sampler, prompt), 135 | "chat" => chat(&mut transformer, &tokenizer, &mut sampler, prompt, system_prompt), 136 | _ => anyhow::bail!("Unknown mode: {inference_config:?}"), 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /qwen3-inference/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | 3 | use crate::{ 4 | configuration::{ModelConfig, read_config}, 5 | tensor::QuantizedTensor, 6 | utils::MemoryMapper, 7 | }; 8 | use anyhow::{Context, Result}; 9 | 10 | mod qwen3; 11 | 12 | /// Contains the main inference logic for the Transformer model. 13 | pub trait Transformer { 14 | /// Runs forward pass for the Transformer model. 15 | fn forward(&mut self, token: usize, pos: usize) -> &[f32]; 16 | 17 | fn get_config(&self) -> &ModelConfig; 18 | } 19 | 20 | #[non_exhaustive] 21 | pub enum Transformers { 22 | Qwen3(qwen3::Qwen3Transformer), 23 | } 24 | 25 | impl Transformer for Transformers { 26 | fn forward(&mut self, token: usize, pos: usize) -> &[f32] { 27 | match self { 28 | Transformers::Qwen3(model) => model.forward(token, pos), 29 | } 30 | } 31 | 32 | fn get_config(&self) -> &ModelConfig { 33 | match self { 34 | Transformers::Qwen3(model) => model.get_config(), 35 | } 36 | } 37 | } 38 | 39 | /// Builder pattern for creating transformer models 40 | pub struct TransformerBuilder { 41 | checkpoint_path: String, 42 | ctx_length: Option, 43 | } 44 | 45 | impl TransformerBuilder { 46 | pub fn new(checkpoint_path: &str) -> Self { 47 | Self { checkpoint_path: checkpoint_path.to_string(), ctx_length: None } 48 | } 49 | 50 | pub fn with_ctx_length(mut self, ctx_length: Option) -> Self { 51 | self.ctx_length = ctx_length; 52 | self 53 | } 54 | 55 | pub fn build(self) -> Result { 56 | let file = File::open(&self.checkpoint_path) 57 | .with_context(|| format!("Failed to open checkpoint: {}", self.checkpoint_path))?; 58 | 59 | let mut mapper = MemoryMapper::new(file)?; 60 | 61 | // Read config from the first part of the file 62 | let mut config = read_config(&mut mapper)?; 63 | 64 | // Apply context length override if provided 65 | if let Some(ctx_len) = self.ctx_length { 66 | config.seq_len = ctx_len.min(config.seq_len); 67 | } 68 | 69 | match config.architecture_id { 70 | 1 => Ok(Transformers::Qwen3(qwen3::Qwen3Transformer::new(config, mapper)?)), 71 | x => anyhow::bail!("Unknown architecture_id: {x}"), 72 | } 73 | } 74 | } 75 | 76 | /// Reads multiple quantized tensors from memory mapper. 77 | /// 78 | /// Each quantized tensor consists of: 79 | /// 1. Quantized weights (i8 values) 80 | /// 2. 
Scale factors (f32 values) 81 | /// 82 | /// The scale factors are grouped according to the quantization group size. 83 | pub(crate) fn create_quantized_tensors( 84 | mapper: &mut MemoryMapper, 85 | n_tensors: usize, 86 | size_each: usize, 87 | group_size: usize, 88 | ) -> Result> { 89 | (0..n_tensors) 90 | .map(|i| { 91 | // Read quantized values 92 | let q_bytes = 93 | mapper.get_bytes(size_each).with_context(|| format!("Failed to read quantized tensor {i} data"))?; 94 | 95 | // Convert bytes to i8 (avoiding copy by using unsafe) 96 | let q_slice = unsafe { std::slice::from_raw_parts(q_bytes.as_ptr() as *const i8, size_each) }; 97 | 98 | // Calculate and read scale factors 99 | let s_len = size_each / group_size; 100 | let s_slice = 101 | mapper.get_f32_slice(s_len).with_context(|| format!("Failed to read scale factors for tensor {i}"))?; 102 | 103 | // Convert to 'static lifetime using unsafe transmute 104 | let q_static = unsafe { std::mem::transmute::<&[i8], &'static [i8]>(q_slice) }; 105 | let s_static = unsafe { std::mem::transmute::<&[f32], &'static [f32]>(s_slice) }; 106 | 107 | Ok(QuantizedTensor::from_slices(q_static, s_static)) 108 | }) 109 | .collect() 110 | } 111 | -------------------------------------------------------------------------------- /qwen3-inference/src/models/qwen3.rs: -------------------------------------------------------------------------------- 1 | use crate::{configuration::ModelConfig, layers::*, models::create_quantized_tensors, tensor::*, utils::MemoryMapper}; 2 | use anyhow::{Context, Result}; 3 | 4 | /// Main Transformer model implementing a decoder-only architecture. 5 | pub struct Qwen3Transformer { 6 | config: ModelConfig, 7 | token_embedding: TokenEmbedding, 8 | blocks: Vec, 9 | final_norm: RMSNorm, 10 | lm_head: Linear, 11 | buffers: TransformerBlockBuffers, 12 | logits: Vec, 13 | _mapper: MemoryMapper, 14 | } 15 | 16 | impl Qwen3Transformer { 17 | pub(crate) fn new(config: ModelConfig, mut mapper: MemoryMapper) -> Result { 18 | let weights = load_weights(&mut mapper, &config)?; 19 | 20 | // Initialize block buffers 21 | let buffers = TransformerBlockBuffers::new(&config)?; 22 | 23 | // Output buffer 24 | let logits = vec![0.0; config.vocab_size]; 25 | 26 | // Create transformer blocks 27 | let mut blocks = Vec::new(); 28 | for layer_idx in 0..config.n_layers { 29 | let block = create_transformer_block(&config, layer_idx, &weights)?; 30 | blocks.push(block); 31 | } 32 | 33 | // Create final normalization 34 | let final_norm = RMSNorm::new(weights.rms_final_weight[..config.dim].to_vec()); 35 | 36 | // Create language model head 37 | let lm_head = Linear::new(weights.wcls, config.dim, config.vocab_size, config.group_size); 38 | 39 | // Create token embedding 40 | let token_embedding = TokenEmbedding::new(weights.token_embedding_table, config.dim); 41 | 42 | Ok(Self { 43 | config, 44 | token_embedding, 45 | blocks, 46 | final_norm, 47 | lm_head, 48 | buffers, 49 | logits, 50 | _mapper: mapper, // Keep the mapper alive for the lifetime of the transformer 51 | }) 52 | } 53 | 54 | /// Forward pass through the transformer for autoregressive generation 55 | /// 56 | /// **Arguments:** 57 | /// - `token`: Current input token ID 58 | /// - `pos`: Current position in sequence (for RoPE and KV cache indexing) 59 | /// 60 | /// **Returns:** 61 | /// - Probability distribution over vocabulary (logits) for next token prediction 62 | pub fn forward(&mut self, token: usize, pos: usize) -> &[f32] { 63 | // Token embedding 64 | self.token_embedding.forward(token, 
&mut self.buffers.x); 65 | 66 | // Process through transformer blocks 67 | for block in &self.blocks { 68 | block.forward(pos, &mut self.buffers); 69 | } 70 | 71 | // Final normalization 72 | self.final_norm.forward_inplace(&mut self.buffers.x); 73 | 74 | // Classification head 75 | quantize(&mut self.buffers.xq, &self.buffers.x, self.buffers.x.len(), self.lm_head.group_size); 76 | self.lm_head.forward(&mut self.logits, &self.buffers.xq); 77 | 78 | &self.logits 79 | } 80 | 81 | pub fn get_config(&self) -> &ModelConfig { 82 | &self.config 83 | } 84 | } 85 | 86 | impl std::fmt::Debug for Qwen3Transformer { 87 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 88 | struct BlocksSummary<'a, T>(&'a [T]); 89 | 90 | impl<'a, T: std::fmt::Debug> std::fmt::Debug for BlocksSummary<'a, T> { 91 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 92 | f.debug_list() 93 | .entries(self.0.iter().take(1)) 94 | .entry(&format_args!("... and {} more", self.0.len().saturating_sub(1))) 95 | .finish() 96 | } 97 | } 98 | 99 | f.debug_struct("Qwen3Transformer") 100 | .field("config", &self.config) 101 | .field("token_embedding", &self.token_embedding) 102 | .field("blocks", &BlocksSummary(&self.blocks)) 103 | .field("final_norm", &self.final_norm) 104 | .field("lm_head", &self.lm_head) 105 | .finish() 106 | } 107 | } 108 | 109 | /// Transformer Block - Core decoder layer combining self-attention and feed-forward 110 | pub struct TransformerBlock { 111 | pub attn_norm: RMSNorm, 112 | pub attention: MultiHeadAttention, 113 | pub ffn_norm: RMSNorm, 114 | pub feed_forward: FeedForward, 115 | pub layer_idx: usize, 116 | pub residual_conn: ResidualConnection, 117 | } 118 | 119 | impl TransformerBlock { 120 | pub fn new( 121 | attn_norm: RMSNorm, 122 | attention: MultiHeadAttention, 123 | ffn_norm: RMSNorm, 124 | feed_forward: FeedForward, 125 | residual_conn: ResidualConnection, 126 | layer_idx: usize, 127 | ) -> Self { 128 | Self { attn_norm, attention, ffn_norm, feed_forward, layer_idx, residual_conn } 129 | } 130 | 131 | fn forward(&self, pos: usize, buffers: &mut TransformerBlockBuffers) { 132 | // Attention block with residual connection 133 | let dim = buffers.x.len(); 134 | self.attn_norm.forward(&mut buffers.xb[..dim], &buffers.x); 135 | 136 | quantize(&mut buffers.xq, &buffers.xb[..dim], dim, self.attention.wq.group_size); 137 | 138 | self.attention.forward( 139 | pos, 140 | self.layer_idx, 141 | AttentionBuffers { 142 | xq: &buffers.xq, 143 | q: &mut buffers.q, 144 | xb: &mut buffers.xb, 145 | att: &mut buffers.att, 146 | temp: &mut buffers.temp, 147 | key_cache: &mut buffers.key_cache, 148 | value_cache: &mut buffers.value_cache, 149 | }, 150 | ); 151 | 152 | quantize(&mut buffers.xq, &buffers.xb, buffers.xb.len(), self.attention.wo.group_size); 153 | self.attention.wo.forward(&mut buffers.xb2, &buffers.xq); 154 | 155 | // Residual connection 156 | self.residual_conn.forward(&mut buffers.x, &buffers.xb2); 157 | 158 | // Feed-forward block with residual connection 159 | self.ffn_norm.forward(&mut buffers.xb[..dim], &buffers.x); 160 | 161 | quantize(&mut buffers.xq, &buffers.xb[..dim], dim, self.feed_forward.w1.group_size); 162 | 163 | // Create feed-forward buffer context 164 | let ffn_buffers = FeedForwardBuffers { 165 | xq: &buffers.xq, 166 | hb: &mut buffers.hb, 167 | hb2: &mut buffers.hb2, 168 | hq: &mut buffers.hq, 169 | xb: &mut buffers.xb, 170 | }; 171 | 172 | self.feed_forward.forward(ffn_buffers); 173 | 174 | // Residual connection 175 | 
self.residual_conn.forward(&mut buffers.x, &buffers.xb[..dim]);
176 |     }
177 | }
178 |
179 | impl std::fmt::Debug for TransformerBlock {
180 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 |         f.debug_struct("TransformerBlock")
182 |             .field("layer_idx", &self.layer_idx)
183 |             .field("attn_norm", &self.attn_norm)
184 |             .field("attention", &self.attention)
185 |             .field("ffn_norm", &self.ffn_norm)
186 |             .field("feed_forward", &self.feed_forward)
187 |             .finish()
188 |     }
189 | }
190 |
191 | /// Loads all model weights from the checkpoint data.
192 | ///
193 | /// This function reads weights in the order they appear in the checkpoint:
194 | /// 1. Normalization weights (f32)
195 | /// 2. Token embeddings (quantized)
196 | /// 3. Attention weights (quantized)
197 | /// 4. Feed-forward weights (quantized)
198 | /// 5. Classification weights (quantized, may be shared)
199 | fn load_weights(mapper: &mut MemoryMapper, config: &ModelConfig) -> Result<TransformerWeights> {
200 |     let ModelConfig {
201 |         group_size,
202 |         dim,
203 |         n_layers,
204 |         head_dim,
205 |         vocab_size,
206 |         hidden_dim,
207 |         n_heads,
208 |         n_kv_heads,
209 |         shared_classifier,
210 |         ..
211 |     } = *config;
212 |
213 |     let all_heads_dim = n_heads * head_dim;
214 |     let kv_dim = n_kv_heads * head_dim;
215 |
216 |     // Helper macro for reading f32 arrays with context
217 |     macro_rules! read_f32_weights {
218 |         ($size:expr, $field:literal) => {
219 |             // SAFETY: we keep the mmap alive for the lifetime of the transformer
220 |             unsafe {
221 |                 std::mem::transmute::<&[f32], &[f32]>(
222 |                     mapper.get_f32_slice($size).with_context(|| format!("Failed to read {}", $field))?,
223 |                 )
224 |             }
225 |         };
226 |     }
227 |
228 |     let rms_att_weight = read_f32_weights!(n_layers * dim, "attention normalization weights");
229 |     let rms_ffn_weight = read_f32_weights!(n_layers * dim, "FFN normalization weights");
230 |     let rms_final_weight = read_f32_weights!(dim, "final normalization weights");
231 |     let q_ln_weights = read_f32_weights!(n_layers * head_dim, "query layer norm weights");
232 |     let k_ln_weights = read_f32_weights!(n_layers * head_dim, "key layer norm weights");
233 |
234 |     // Read quantized tensors
235 |     let q_tokens = create_quantized_tensors(mapper, 1, vocab_size * dim, group_size)?
236 |         .into_iter()
237 |         .next()
238 |         .expect("Expected exactly one token embedding tensor");
239 |
240 |     // Dequantize token embeddings (we need owned data for this)
241 |     let mut token_embedding_table = vec![0.0; vocab_size * dim];
242 |     dequantize(&q_tokens, &mut token_embedding_table, group_size);
243 |
244 |     let wq = create_quantized_tensors(mapper, n_layers, dim * all_heads_dim, group_size)?;
245 |     let wk = create_quantized_tensors(mapper, n_layers, dim * kv_dim, group_size)?;
246 |     let wv = create_quantized_tensors(mapper, n_layers, dim * kv_dim, group_size)?;
247 |     let wo = create_quantized_tensors(mapper, n_layers, all_heads_dim * dim, group_size)?;
248 |     let w1 = create_quantized_tensors(mapper, n_layers, dim * hidden_dim, group_size)?;
249 |     let w2 = create_quantized_tensors(mapper, n_layers, hidden_dim * dim, group_size)?;
250 |     let w3 = create_quantized_tensors(mapper, n_layers, dim * hidden_dim, group_size)?;
251 |
252 |     let wcls = if shared_classifier {
253 |         q_tokens.clone()
254 |     } else {
255 |         create_quantized_tensors(mapper, 1, dim * vocab_size, group_size)?
256 |             .into_iter()
257 |             .next()
258 |             .expect("Expected exactly one classification tensor")
259 |     };
260 |
261 |     Ok(TransformerWeights {
262 |         token_embedding_table,
263 |         rms_att_weight,
264 |         rms_ffn_weight,
265 |         wq,
266 |         wk,
267 |         wv,
268 |         wo,
269 |         q_ln_weights,
270 |         k_ln_weights,
271 |         w1,
272 |         w2,
273 |         w3,
274 |         rms_final_weight,
275 |         wcls,
276 |     })
277 | }
278 |
279 | fn create_transformer_block(
280 |     model_config: &ModelConfig,
281 |     layer_idx: usize,
282 |     weights: &TransformerWeights,
283 | ) -> Result<TransformerBlock> {
284 |     let dim = model_config.dim;
285 |     let head_dim = model_config.head_dim;
286 |     let all_heads_dim = model_config.n_heads * head_dim;
287 |     let kv_dim = model_config.n_kv_heads * head_dim;
288 |     let hidden_dim = model_config.hidden_dim;
289 |     let group_size = model_config.group_size;
290 |
291 |     // Attention normalization
292 |     let attn_norm_start = layer_idx * dim;
293 |     let attn_norm = RMSNorm::new(weights.rms_att_weight[attn_norm_start..attn_norm_start + dim].to_vec());
294 |
295 |     // Query/Key normalization
296 |     let qk_norm_start = layer_idx * head_dim;
297 |     let q_norm = RMSNorm::new(weights.q_ln_weights[qk_norm_start..qk_norm_start + head_dim].to_vec());
298 |     let k_norm = RMSNorm::new(weights.k_ln_weights[qk_norm_start..qk_norm_start + head_dim].to_vec());
299 |
300 |     // Attention projections
301 |     let wq = Linear::new(weights.wq[layer_idx].clone(), dim, all_heads_dim, group_size);
302 |     let wk = Linear::new(weights.wk[layer_idx].clone(), dim, kv_dim, group_size);
303 |     let wv = Linear::new(weights.wv[layer_idx].clone(), dim, kv_dim, group_size);
304 |     let wo = Linear::new(weights.wo[layer_idx].clone(), all_heads_dim, dim, group_size);
305 |
306 |     let attention = MultiHeadAttention::new(wq, wk, wv, wo, q_norm, k_norm, model_config);
307 |
308 |     // FFN normalization
309 |     let ffn_norm_start = layer_idx * dim;
310 |     let ffn_norm = RMSNorm::new(weights.rms_ffn_weight[ffn_norm_start..ffn_norm_start + dim].to_vec());
311 |
312 |     // Feed-forward projections
313 |     let w1 = Linear::new(weights.w1[layer_idx].clone(), dim, hidden_dim, group_size);
314 |     let w2 = Linear::new(weights.w2[layer_idx].clone(), hidden_dim, dim, group_size);
315 |     let w3 = Linear::new(weights.w3[layer_idx].clone(), dim, hidden_dim, group_size);
316 |
317 |     let feed_forward = FeedForward::new(w1, w2, w3);
318 |
319 |     let residual_conn = ResidualConnection::new();
320 |
321 |     Ok(TransformerBlock::new(attn_norm, attention, ffn_norm, feed_forward, residual_conn, layer_idx))
322 | }
323 |
324 | /// Contains all the learned parameters for the transformer model.
325 | ///
326 | /// This structure holds both quantized weights (for memory efficiency) and
327 | /// pre-computed values like the dequantized token embedding table.
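/// The `&'static [f32]` fields borrow directly from the memory-mapped checkpoint; the
/// borrow is kept valid by holding the `MemoryMapper` alive in `Qwen3Transformer::_mapper`
/// (see the SAFETY note in `load_weights`), so these normalization weights are never copied.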
328 | #[derive(Debug)]
329 | struct TransformerWeights {
330 |     /// Pre-dequantized token embedding table for fast lookup during inference
331 |     /// Shape: [vocab_size, dim]
332 |     pub token_embedding_table: Vec<f32>,
333 |
334 |     /// RMS normalization weights for attention layers
335 |     /// Shape: [n_layers, dim] (flattened)
336 |     pub rms_att_weight: &'static [f32],
337 |
338 |     /// RMS normalization weights for feed-forward layers
339 |     /// Shape: [n_layers, dim] (flattened)
340 |     pub rms_ffn_weight: &'static [f32],
341 |
342 |     /// Attention projection weights (quantized for memory efficiency)
343 |     /// Query projections: [n_layers] × [dim, n_heads * head_dim]
344 |     pub wq: Vec<QuantizedTensor>,
345 |     /// Key projections: [n_layers] × [dim, n_kv_heads * head_dim]
346 |     pub wk: Vec<QuantizedTensor>,
347 |     /// Value projections: [n_layers] × [dim, n_kv_heads * head_dim]
348 |     pub wv: Vec<QuantizedTensor>,
349 |     /// Output projections: [n_layers] × [n_heads * head_dim, dim]
350 |     pub wo: Vec<QuantizedTensor>,
351 |
352 |     /// QK-RMSNorm weights for Qwen3 architecture
353 |     /// Query layer norm: [n_layers, head_dim] (flattened)
354 |     pub q_ln_weights: &'static [f32],
355 |     /// Key layer norm: [n_layers, head_dim] (flattened)
356 |     pub k_ln_weights: &'static [f32],
357 |
358 |     /// Feed-forward network weights (quantized)
359 |     /// Gate projection: [n_layers] × [dim, hidden_dim]
360 |     pub w1: Vec<QuantizedTensor>,
361 |     /// Down projection: [n_layers] × [hidden_dim, dim]
362 |     pub w2: Vec<QuantizedTensor>,
363 |     /// Up projection: [n_layers] × [dim, hidden_dim]
364 |     pub w3: Vec<QuantizedTensor>,
365 |
366 |     /// Final RMS normalization weight before classification
367 |     /// Shape: [dim]
368 |     pub rms_final_weight: &'static [f32],
369 |
370 |     /// Classification head weights (may be shared with token embeddings)
371 |     /// Shape: [dim, vocab_size]
372 |     pub wcls: QuantizedTensor,
373 | }
374 |
375 | /// Buffer context for transformer block operations (owned buffers)
376 | pub struct TransformerBlockBuffers {
377 |     /// Primary activation buffer (input/output) [dim]
378 |     pub x: Vec<f32>,
379 |
380 |     /// Multi-purpose working buffer [all_heads_dim]
381 |     pub xb: Vec<f32>,
382 |
383 |     /// Secondary buffer for residual connections [dim]
384 |     pub xb2: Vec<f32>,
385 |
386 |     /// Quantized buffer for attention/FFN operations
387 |     pub xq: QuantizedTensor,
388 |
389 |     /// Query buffer for attention computation [all_heads_dim]
390 |     pub q: Vec<f32>,
391 |
392 |     /// Attention weights buffer [n_heads * seq_len]
393 |     pub att: Vec<f32>,
394 |
395 |     /// FFN gate projection buffer [hidden_dim]
396 |     pub hb: Vec<f32>,
397 |
398 |     /// FFN up projection buffer [hidden_dim]
399 |     pub hb2: Vec<f32>,
400 |
401 |     /// Quantized FFN buffer [hidden_dim]
402 |     pub hq: QuantizedTensor,
403 |
404 |     /// Key-Value cache (full cache, shared across all layers)
405 |     pub key_cache: Vec<f32>,
406 |     pub value_cache: Vec<f32>,
407 |
408 |     /// Temporary workspace for intermediate computations [head_dim]
409 |     pub temp: Vec<f32>,
410 | }
411 |
412 | impl TransformerBlockBuffers {
413 |     /// Creates a new buffer state with pre-allocated buffers based on model configuration.
414 |     pub fn new(config: &ModelConfig) -> Result<Self> {
415 |         let ModelConfig { group_size, n_heads, head_dim, n_kv_heads, dim, hidden_dim, seq_len, n_layers, .. } = *config;
416 |
417 |         let all_heads_dim = n_heads * head_dim;
418 |         let kv_dim = n_kv_heads * head_dim;
419 |
420 |         Ok(Self {
421 |             // Core activation buffers
422 |             x: vec![0.0; dim],
423 |             xb: vec![0.0; all_heads_dim],
424 |             xb2: vec![0.0; dim],
425 |
426 |             // Quantized buffers for efficient computation
427 |             xq: QuantizedTensor::new(all_heads_dim, group_size),
428 |
429 |             // Attention-specific buffers
430 |             q: vec![0.0; all_heads_dim],
431 |             att: vec![0.0; n_heads * seq_len],
432 |
433 |             // FFN buffers
434 |             hb: vec![0.0; hidden_dim],
435 |             hb2: vec![0.0; hidden_dim],
436 |             hq: QuantizedTensor::new(hidden_dim, group_size),
437 |
438 |             // KV cache for autoregressive generation
439 |             key_cache: vec![0.0; n_layers * seq_len * kv_dim],
440 |             value_cache: vec![0.0; n_layers * seq_len * kv_dim],
441 |
442 |             // Temporary workspace for computations
443 |             temp: vec![0.0; head_dim],
444 |         })
445 |     }
446 | }
447 |
--------------------------------------------------------------------------------
/qwen3-inference/src/sampler.rs:
--------------------------------------------------------------------------------
1 | use crate::layers::softmax;
2 |
3 | /// Stores a probability and its associated index (token id).
4 | #[derive(Clone, Debug)]
5 | pub struct ProbIndex {
6 |     pub prob: f32,
7 |     pub index: usize,
8 | }
9 |
10 | /// Top-p/temperature sampler for language model logits.
11 | ///
12 | /// This struct implements temperature scaling, top-p (nucleus) sampling,
13 | /// and multinomial sampling, using a simple xorshift RNG for reproducibility.
14 | #[derive(Debug)]
15 | pub struct Sampler {
16 |     pub probindex: Vec<ProbIndex>,
17 |     pub temperature: f32,
18 |     pub topp: f32,
19 |     pub rng_state: u64,
20 | }
21 |
22 | impl Sampler {
23 |     /// Creates a new sampler with the given vocabulary size, temperature, top-p, and RNG seed.
24 |     ///
25 |     /// # Arguments
26 |     /// * `vocab_size` - Size of the vocabulary
27 |     /// * `temperature` - Temperature for sampling (typical range: 0.1-2.0, 0.0 for greedy)
28 |     /// * `topp` - Top-p threshold (0.0-1.0, 1.0 disables top-p)
29 |     /// * `rng_seed` - Random seed for reproducibility
30 |     pub fn new(vocab_size: usize, temperature: f32, topp: f32, rng_seed: u64) -> Self {
31 |         assert!(vocab_size > 0, "Vocab size must be positive");
32 |         assert!(temperature >= 0.0, "Temperature must be non-negative");
33 |         assert!((0.0..=1.0).contains(&topp), "Top-p must be between 0.0 and 1.0");
34 |
35 |         Self {
36 |             probindex: vec![ProbIndex { prob: 0.0, index: 0 }; vocab_size],
37 |             temperature,
38 |             topp: topp.clamp(0.0, 1.0),
39 |             rng_state: rng_seed,
40 |         }
41 |     }
42 |
43 |     /// Xorshift-based random number generator.
44 |     fn random_u32(&mut self) -> u32 {
45 |         self.rng_state ^= self.rng_state >> 12;
46 |         self.rng_state ^= self.rng_state << 25;
47 |         self.rng_state ^= self.rng_state >> 27;
48 |         ((self.rng_state.wrapping_mul(0x2545F4914F6CDD1D)) >> 32) as u32
49 |     }
50 |
51 |     /// Returns a random float in [0, 1).
52 |     fn random_f32(&mut self) -> f32 {
53 |         (self.random_u32() >> 8) as f32 / 16777216.0
54 |     }
55 |
56 |     /// Returns the index of the maximum logit (greedy decoding).
57 |     fn sample_argmax(logits: &[f32]) -> usize {
58 |         logits.iter().enumerate().max_by(|(_, a), (_, b)| a.total_cmp(b)).map(|(i, _)| i).unwrap_or_default()
59 |     }
60 |
61 |     /// Multinomial sampling from a probability distribution.
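    /// Walks the cumulative distribution until it exceeds `coin`. Illustrative example
    /// (made-up numbers): with probabilities [0.1, 0.6, 0.3] and `coin = 0.65`, the running
    /// sums are 0.1, 0.7, 1.0, so the walk stops at index 1.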
62 | fn sample_mult(logits: &[f32], coin: f32) -> usize { 63 | let mut cdf = 0.0; 64 | for (i, &prob) in logits.iter().enumerate() { 65 | cdf += prob; 66 | if coin < cdf { 67 | return i; 68 | } 69 | } 70 | logits.len().saturating_sub(1) 71 | } 72 | 73 | /// Top-p (nucleus) sampling: sample from the smallest set of tokens whose cumulative probability exceeds `topp`. 74 | fn sample_topp(&mut self, logits: &[f32], coin: f32) -> usize { 75 | let cutoff = (1.0 - self.topp) / (logits.len().saturating_sub(1).max(1)) as f32; 76 | let mut n0 = 0; 77 | 78 | // Collect candidates above cutoff 79 | for (i, &prob) in logits.iter().enumerate() { 80 | if prob >= cutoff { 81 | self.probindex[n0] = ProbIndex { prob, index: i }; 82 | n0 += 1; 83 | } 84 | } 85 | 86 | // Sort by probability (descending) 87 | self.probindex[..n0].sort_unstable_by(|a, b| b.prob.total_cmp(&a.prob)); 88 | 89 | // Find truncation point 90 | let mut cumulative_prob = 0.0; 91 | let mut last_idx = n0.saturating_sub(1); 92 | for i in 0..n0 { 93 | cumulative_prob += self.probindex[i].prob; 94 | if cumulative_prob > self.topp { 95 | last_idx = i; 96 | break; 97 | } 98 | } 99 | 100 | // Sample from truncated list 101 | let r = coin * cumulative_prob; 102 | let mut cdf = 0.0; 103 | for i in 0..=last_idx { 104 | cdf += self.probindex[i].prob; 105 | if r < cdf { 106 | return self.probindex[i].index; 107 | } 108 | } 109 | self.probindex[last_idx].index 110 | } 111 | 112 | /// Samples a token index from logits using temperature and top-p. 113 | /// 114 | /// - If temperature is 0, returns the argmax (greedy). 115 | /// - Otherwise, applies temperature scaling, softmax, and top-p or multinomial sampling. 116 | pub fn sample(&mut self, logits: &mut [f32]) -> usize { 117 | if self.temperature == 0.0 { 118 | Self::sample_argmax(logits) 119 | } else { 120 | // Apply temperature 121 | for logit in logits.iter_mut() { 122 | *logit /= self.temperature; 123 | } 124 | 125 | // Apply softmax 126 | softmax(logits); 127 | 128 | let coin = self.random_f32(); 129 | 130 | if self.topp <= 0.0 || self.topp >= 1.0 { 131 | Self::sample_mult(logits, coin) 132 | } else { 133 | self.sample_topp(logits, coin) 134 | } 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /qwen3-inference/src/tensor.rs: -------------------------------------------------------------------------------- 1 | use rayon::prelude::*; 2 | use std::borrow::Cow; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct QuantizedTensor { 6 | pub q: Cow<'static, [i8]>, 7 | pub s: Cow<'static, [f32]>, 8 | } 9 | 10 | impl QuantizedTensor { 11 | // Create with owned data (for temporary/working tensors) 12 | pub fn new(size: usize, group_size: usize) -> Self { 13 | let scale_size = size / group_size; 14 | Self { q: Cow::Owned(vec![0; size]), s: Cow::Owned(vec![0.0; scale_size]) } 15 | } 16 | 17 | // Create from borrowed slices (for memory-mapped data) 18 | pub fn from_slices(q: &'static [i8], s: &'static [f32]) -> Self { 19 | Self { q: Cow::Borrowed(q), s: Cow::Borrowed(s) } 20 | } 21 | } 22 | 23 | pub fn matmul(xout: &mut [f32], x: &QuantizedTensor, w: &QuantizedTensor, n: usize, d: usize, group_size: usize) { 24 | assert!(xout.len() >= d, "Output slice length must be at least d parameter: {} >= {}", xout.len(), d); 25 | 26 | xout.par_iter_mut().enumerate().take(d).for_each(|(i, out_val)| { 27 | compute_matmul_row(out_val, x, w, i, n, group_size); 28 | }); 29 | } 30 | 31 | #[inline] 32 | fn compute_matmul_row( 33 | out_val: &mut f32, 34 | x: 
&QuantizedTensor, 35 | w: &QuantizedTensor, 36 | row_idx: usize, 37 | n: usize, 38 | group_size: usize, 39 | ) { 40 | debug_assert_eq!(n % group_size, 0, "n must be divisible by group_size"); 41 | 42 | let weight_row_offset = row_idx * n; 43 | let num_groups = n / group_size; 44 | 45 | *out_val = (0..num_groups) 46 | .map(|group_idx| { 47 | let group_start = group_idx * group_size; 48 | let weight_group_offset = weight_row_offset + group_start; 49 | 50 | let quantized_dot_product: i32 = x.q[group_start..group_start + group_size] 51 | .iter() 52 | .zip(&w.q[weight_group_offset..weight_group_offset + group_size]) 53 | .map(|(&x_quant, &w_quant)| x_quant as i32 * w_quant as i32) 54 | .sum(); 55 | 56 | let weight_scale = w.s[weight_group_offset / group_size]; 57 | let input_scale = x.s[group_idx]; 58 | 59 | quantized_dot_product as f32 * weight_scale * input_scale 60 | }) 61 | .sum(); 62 | } 63 | 64 | /// Dequantizes a quantized tensor into a float buffer. 65 | /// 66 | /// For each group of quantized values, multiplies by the corresponding scale factor. 67 | /// 68 | /// # Arguments 69 | /// * `qx` - The quantized tensor (with quantized values and scale factors) 70 | /// * `x` - Output buffer for dequantized values (must be at least as large as `qx.q`) 71 | /// * `group_size` - Number of elements per quantization group 72 | pub fn dequantize(qx: &QuantizedTensor, x: &mut [f32], group_size: usize) { 73 | debug_assert_eq!(x.len(), qx.q.len(), "Output buffer size must match quantized tensor size"); 74 | debug_assert_eq!(qx.s.len(), x.len() / group_size); 75 | 76 | for (i, &q_val) in qx.q.iter().enumerate() { 77 | let scale = qx.s[i / group_size]; 78 | x[i] = q_val as f32 * scale; 79 | } 80 | } 81 | 82 | /// Quantizes a float buffer into a quantized tensor using per-group scaling. 83 | /// 84 | /// For each group, finds the max absolute value, computes a scale, and quantizes values to i8. 
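/// Illustrative example (made-up numbers): a group [0.5, -1.0, 0.25, 2.0] has wmax = 2.0,
/// so the scale is 2.0 / 127 ≈ 0.01575 and the values quantize to [32, -64, 16, 127].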
85 | ///
86 | /// # Arguments
87 | /// * `qx` - The quantized tensor to write into (must have preallocated `q` and `s`)
88 | /// * `x` - Input float buffer to quantize
89 | /// * `size` - Number of elements to quantize (must equal `x.len()`)
90 | /// * `group_size` - Number of elements per quantization group
91 | pub fn quantize(qx: &mut QuantizedTensor, x: &[f32], size: usize, group_size: usize) {
92 |     debug_assert_eq!(x.len(), size);
93 |     debug_assert!(qx.q.len() >= size, "Quantized buffer too small: {} < {}", qx.q.len(), size);
94 |     debug_assert!(qx.s.len() >= size / group_size, "Scale buffer too small: {} < {}", qx.s.len(), size / group_size);
95 |
96 |     const Q_MAX: f32 = 127.0;
97 |     let num_groups = size / group_size;
98 |
99 |     // Get separate mutable references to avoid borrowing conflicts
100 |     let q_data = qx.q.to_mut();
101 |     let s_data = qx.s.to_mut();
102 |
103 |     for group in 0..num_groups {
104 |         let group_start = group * group_size;
105 |         let group_end = group_start + group_size;
106 |
107 |         // Find the maximum absolute value in the group
108 |         let wmax = x[group_start..group_end].iter().fold(0.0f32, |acc, &val| acc.max(val.abs()));
109 |
110 |         let scale = wmax / Q_MAX;
111 |         s_data[group] = scale;
112 |
113 |         // Quantize the group
114 |         for (i, &val) in x[group_start..group_end].iter().enumerate() {
115 |             let quant_value = if scale != 0.0 { val / scale } else { 0.0 };
116 |             q_data[group_start + i] = quant_value.round() as i8;
117 |         }
118 |     }
119 | }
120 |
--------------------------------------------------------------------------------
/qwen3-inference/src/tokenizer.rs:
--------------------------------------------------------------------------------
1 | //! Tokenizer for BPE-based language models.
2 | //!
3 | //! This module provides a simple byte-level BPE tokenizer that matches the behavior of C reference implementations.
4 | //!
5 | //! - Loads vocabulary and merge scores from a binary file.
6 | //! - Encodes text into token IDs using special token and character lookup, then applies BPE merges.
7 | //! - Decodes token IDs back to strings, handling both valid and invalid UTF-8.
8 |
9 | use anyhow::Result;
10 | use byteorder::{LittleEndian, ReadBytesExt};
11 | use std::borrow::Cow;
12 | use std::fs::File;
13 | use std::io::Read;
14 |
15 | /// Tokenizer for byte-level BPE models.
16 | ///
17 | /// Holds the vocabulary, merge scores, and prompt templates.
18 | /// Provides encode/decode methods for text and token IDs.
19 | pub struct Tokenizer {
20 |     /// Vocabulary: each token is a byte sequence (not necessarily valid UTF-8)
21 |     pub vocab: Vec<Vec<u8>>, // Store raw bytes instead of Strings
22 |     /// Merge scores for BPE merges (higher is better)
23 |     pub merge_scores: Vec<f32>,
24 |     /// Number of tokens in the vocabulary
25 |     pub vocab_size: usize,
26 |     /// Maximum token length (in bytes)
27 |     pub max_token_length: u32,
28 |     /// Beginning-of-sequence token ID
29 |     pub bos_token_id: u32,
30 |     /// End-of-sequence token ID
31 |     pub eos_token_id: u32,
32 |     /// Prompt template for user prompts
33 |     pub prompt_template: String,
34 |     /// Prompt template for system prompts
35 |     pub system_prompt_template: String,
36 | }
37 |
38 | impl Tokenizer {
39 |     /// Loads a tokenizer from a checkpoint path and vocabulary size.
40 |     ///
41 |     /// Reads the vocabulary, merge scores, and prompt templates from disk.
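    ///
    /// As read below, the `.tokenizer` file consists of three little-endian u32 header
    /// fields (max token length, BOS id, EOS id) followed by `vocab_size` records of
    /// (f32 merge score, u32 byte length, raw token bytes).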
42 |     pub fn new(checkpoint_path: &str, vocab_size: usize, enable_thinking: bool) -> Result<Self> {
43 |         let tokenizer_path = format!("{checkpoint_path}.tokenizer");
44 |         let file = File::open(&tokenizer_path)?;
45 |         let mut reader = std::io::BufReader::new(file);
46 |
47 |         // Read header: max token length, BOS/EOS token IDs
48 |         let max_token_length = reader.read_u32::<LittleEndian>()?;
49 |         let bos_token_id = reader.read_u32::<LittleEndian>()?;
50 |         let eos_token_id = reader.read_u32::<LittleEndian>()?;
51 |
52 |         let mut vocab = Vec::with_capacity(vocab_size);
53 |         let mut merge_scores = Vec::with_capacity(vocab_size);
54 |
55 |         // Read vocabulary: (score, length, bytes) for each token
56 |         for _i in 0..vocab_size {
57 |             // Read score
58 |             let score = match reader.read_f32::<LittleEndian>() {
59 |                 Ok(s) => s,
60 |                 Err(_) => {
61 |                     // If reading fails, push empty token and zero score
62 |                     vocab.push(Vec::new());
63 |                     merge_scores.push(0.0);
64 |                     continue;
65 |                 }
66 |             };
67 |             merge_scores.push(score);
68 |
69 |             // Read token length
70 |             let len = match reader.read_u32::<LittleEndian>() {
71 |                 Ok(l) => l as usize,
72 |                 Err(_) => {
73 |                     vocab.push(Vec::new());
74 |                     continue;
75 |                 }
76 |             };
77 |
78 |             // Read token bytes
79 |             let mut token_bytes = vec![0u8; len];
80 |             match reader.read_exact(&mut token_bytes) {
81 |                 Ok(_) => vocab.push(token_bytes),
82 |                 Err(_) => vocab.push(Vec::new()),
83 |             }
84 |         }
85 |
86 |         // Load prompt templates (for chat/instruction mode)
87 |         let prompt_template = Self::load_prompt_template(checkpoint_path, false, enable_thinking)?;
88 |         let system_prompt_template = Self::load_prompt_template(checkpoint_path, true, enable_thinking)?;
89 |
90 |         Ok(Self {
91 |             vocab,
92 |             merge_scores,
93 |             vocab_size,
94 |             max_token_length,
95 |             bos_token_id,
96 |             eos_token_id,
97 |             prompt_template,
98 |             system_prompt_template,
99 |         })
100 |     }
101 |
102 |     /// Loads a prompt template from disk, with support for system and "thinking" variants.
103 |     fn load_prompt_template(checkpoint_path: &str, with_system: bool, enable_thinking: bool) -> Result<String> {
104 |         let suffix = match (with_system, enable_thinking) {
105 |             (true, true) => ".template.with-system-and-thinking",
106 |             (true, false) => ".template.with-system",
107 |             (false, true) => ".template.with-thinking",
108 |             (false, false) => ".template",
109 |         };
110 |
111 |         let template_path = format!("{checkpoint_path}{suffix}");
112 |
113 |         match std::fs::read_to_string(&template_path) {
114 |             Ok(content) => Ok(content),
115 |             Err(_) => {
116 |                 eprintln!("Warning: Could not load prompt template {template_path}");
117 |                 Ok(String::new())
118 |             }
119 |         }
120 |     }
121 |
122 |     /// Decodes a token ID to a string (may be invalid UTF-8).
123 |     ///
124 |     /// Returns a borrowed str if valid UTF-8, otherwise an owned String.
125 |     pub fn decode(&self, token: usize) -> Cow<'_, str> {
126 |         if token < self.vocab.len() {
127 |             // Try to interpret as valid UTF-8 first (no allocation needed)
128 |             match std::str::from_utf8(&self.vocab[token]) {
129 |                 Ok(valid_str) => Cow::Borrowed(valid_str),
130 |                 Err(_) => {
131 |                     // SAFETY: For incomplete UTF-8 sequences (like partial emoji bytes),
132 |                     // we need to preserve the exact bytes. Use unsafe since we know
133 |                     // these bytes come from a trusted tokenizer file and will be
134 |                     // combined with other tokens to form valid UTF-8 during generation.
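                    // A safe alternative would be String::from_utf8_lossy, but that would
                    // replace the partial bytes with U+FFFD and break byte-exact reassembly
                    // of multi-token UTF-8 sequences.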
135 |                     let string = unsafe { String::from_utf8_unchecked(self.vocab[token].clone()) };
136 |                     Cow::Owned(string)
137 |                 }
138 |             }
139 |         } else {
140 |             Cow::Borrowed("")
141 |         }
142 |     }
143 |
144 |     /// Looks up a string in the vocabulary and returns its token ID, if present.
145 |     fn str_lookup(&self, s: &str) -> Option<usize> {
146 |         // Validate vocab_size matches actual vocab length (safety check)
147 |         debug_assert_eq!(self.vocab.len(), self.vocab_size, "Vocab size mismatch");
148 |         // Convert string to bytes and compare with vocab bytes
149 |         let s_bytes = s.as_bytes();
150 |         self.vocab.iter().position(|token| token.as_slice() == s_bytes)
151 |     }
152 |
153 |     /// Encodes a string into a sequence of token IDs using BPE.
154 |     ///
155 |     /// 1. Looks up special tokens (sequences delimited by `<` and `>`) and single characters.
156 |     /// 2. Applies BPE merges: repeatedly merges the pair of tokens with the highest merge score,
157 |     ///    replacing them with the merged token, until no more merges are possible.
158 |     /// 3. Returns the resulting token ID sequence.
159 |     ///
160 |     /// # Arguments
161 |     /// * `text` - The input string to encode.
162 |     ///
163 |     /// # Returns
164 |     /// A vector of token IDs.
165 |     pub fn encode(&self, text: &str) -> Vec<usize> {
166 |         let mut tokens = Vec::new();
167 |         let chars: Vec<char> = text.chars().collect();
168 |         let mut i = 0;
169 |
170 |         while i < chars.len() {
171 |             let mut found_special = false;
172 |
173 |             // Check for special tokens (use max_token_length for buffer bounds)
174 |             if chars[i] == '<' {
175 |                 let mut end_pos = None;
176 |                 let search_limit = chars.len().min(i + self.max_token_length as usize);
177 |                 for j in i + 1..search_limit {
178 |                     if chars[j] == '>' {
179 |                         end_pos = Some(j);
180 |                         break;
181 |                     }
182 |                 }
183 |
184 |                 if let Some(end) = end_pos {
185 |                     let special_token: String = chars[i..=end].iter().collect();
186 |                     if let Some(token_id) = self.str_lookup(&special_token) {
187 |                         tokens.push(token_id);
188 |                         i = end + 1;
189 |                         found_special = true;
190 |                     }
191 |                 }
192 |             }
193 |
194 |             if !found_special {
195 |                 let char_str = chars[i].to_string();
196 |                 if let Some(token_id) = self.str_lookup(&char_str) {
197 |                     tokens.push(token_id);
198 |                 } else {
199 |                     // Print a warning for unknown characters (not present in vocab)
200 |                     println!("Warning: unknown character '{}' in input, skipping.", chars[i]);
201 |                 }
202 |                 i += 1;
203 |             }
204 |         }
205 |
206 |         // Merge tokens using BPE (Byte Pair Encoding)
207 |         // Repeatedly merge the pair with the highest merge score until no merges remain.
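        // Illustrative example: if the vocab contains "a", "b", and "ab" (but not "bb" or
        // "abb"), the sequence ["a", "b", "b"] merges to ["ab", "b"] on the first pass and
        // then stops, since no adjacent pair concatenates to a vocab entry.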
208 |         loop {
209 |             let mut best_score = -1e10;
210 |             let mut best_id = None;
211 |             let mut best_idx = None;
212 |
213 |             for i in 0..tokens.len().saturating_sub(1) {
214 |                 // Concatenate the raw bytes of the two tokens
215 |                 let mut merged_bytes = self.vocab[tokens[i]].clone();
216 |                 merged_bytes.extend_from_slice(&self.vocab[tokens[i + 1]]);
217 |
218 |                 if let Some(id) = self.vocab.iter().position(|token| token.as_slice() == merged_bytes.as_slice()) {
219 |                     if self.merge_scores[id] > best_score {
220 |                         best_score = self.merge_scores[id];
221 |                         best_id = Some(id);
222 |                         best_idx = Some(i);
223 |                     }
224 |                 }
225 |             }
226 |
227 |             let (id, idx) = match (best_id, best_idx) {
228 |                 (Some(id), Some(idx)) => (id, idx),
229 |                 _ => break,
230 |             };
231 |
232 |             tokens[idx] = id;
233 |             tokens.remove(idx + 1);
234 |         }
235 |
236 |         tokens
237 |     }
238 | }
239 |
240 | impl std::fmt::Debug for Tokenizer {
241 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 |         let bos_token = (self.bos_token_id, self.decode(self.bos_token_id as usize));
243 |         let eos_token = (self.eos_token_id, self.decode(self.eos_token_id as usize));
244 |
245 |         f.debug_struct("Tokenizer")
246 |             .field("vocab_size", &self.vocab_size)
247 |             .field("max_token_length", &self.max_token_length)
248 |             .field("bos_token_id", &bos_token)
249 |             .field("eos_token_id", &eos_token)
250 |             .field("prompt_template", &self.prompt_template)
251 |             .field("system_prompt_template", &self.system_prompt_template)
252 |             .finish_non_exhaustive()
253 |     }
254 | }
255 |
--------------------------------------------------------------------------------
/qwen3-inference/src/utils.rs:
--------------------------------------------------------------------------------
1 | use anyhow::{Context, Result};
2 | use memmap2::Mmap;
3 | use std::fs::File;
4 | use std::slice;
5 |
6 | #[derive(Debug)]
7 | pub(crate) struct MemoryMapper {
8 |     mmap: Mmap,
9 |     offset: usize,
10 | }
11 |
12 | impl MemoryMapper {
13 |     pub fn new(file: File) -> Result<Self> {
14 |         let mmap = unsafe { memmap2::MmapOptions::new().map(&file).context("Failed to create memory mapping")? };
15 |         Ok(Self { mmap, offset: 0 })
16 |     }
17 |
18 |     pub fn get_f32_slice(&mut self, count: usize) -> Result<&[f32]> {
19 |         let bytes_needed = count * std::mem::size_of::<f32>();
20 |
21 |         if self.offset + bytes_needed > self.mmap.len() {
22 |             anyhow::bail!(
23 |                 "Insufficient data: need {} bytes, have {} remaining",
24 |                 bytes_needed,
25 |                 self.mmap.len() - self.offset
26 |             );
27 |         }
28 |
29 |         let byte_slice = &self.mmap[self.offset..self.offset + bytes_needed];
30 |         self.offset += bytes_needed;
31 |
32 |         // SAFETY: We're casting from &[u8] to &[f32]
33 |         // This is safe because:
34 |         // 1. We've verified the slice has the correct length
35 |         // 2. f32 has less strict alignment requirements than most types
36 |         // 3.
The checkpoint file is assumed to be correctly formatted 37 | let f32_slice = unsafe { slice::from_raw_parts(byte_slice.as_ptr() as *const f32, count) }; 38 | 39 | Ok(f32_slice) 40 | } 41 | 42 | pub fn get_bytes(&mut self, count: usize) -> Result<&[u8]> { 43 | if self.offset + count > self.mmap.len() { 44 | anyhow::bail!("Insufficient data: need {} bytes, have {} remaining", count, self.mmap.len() - self.offset); 45 | } 46 | 47 | let result = &self.mmap[self.offset..self.offset + count]; 48 | self.offset += count; 49 | Ok(result) 50 | } 51 | 52 | pub fn skip(&mut self, bytes: usize) -> Result<()> { 53 | if self.offset + bytes > self.mmap.len() { 54 | anyhow::bail!("Cannot skip {} bytes: insufficient data", bytes); 55 | } 56 | self.offset += bytes; 57 | Ok(()) 58 | } 59 | } 60 | --------------------------------------------------------------------------------