├── .gitignore
├── .cargo
│   └── config
├── Cargo.toml
├── tiktoken_core-0.2.5-1.rockspec
├── test
│   ├── test.lua
│   └── json.lua
├── LICENSE
├── .github
│   └── workflows
│       ├── build.yml
│       └── release.yml
├── Cargo.lock
└── src
    └── lib.rs
/.gitignore:
--------------------------------------------------------------------------------
1 | target
2 | 
--------------------------------------------------------------------------------
/.cargo/config:
--------------------------------------------------------------------------------
1 | [target.x86_64-apple-darwin]
2 | rustflags = [
3 |   "-C", "link-arg=-undefined",
4 |   "-C", "link-arg=dynamic_lookup",
5 | ]
6 | 
7 | [target.aarch64-apple-darwin]
8 | rustflags = [
9 |   "-C", "link-arg=-undefined",
10 |   "-C", "link-arg=dynamic_lookup",
11 | ]
12 | 
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "tiktoken_core"
3 | version = "0.4.0"
4 | edition = "2021"
5 | rust-version = "1.57.0"
6 | 
7 | [lib]
8 | name = "tiktoken_core"
9 | crate-type = ["cdylib"]
10 | 
11 | [dependencies]
12 | mlua = { version = "0.10.5", features = ["serialize", "module"] }
13 | # tiktoken dependencies
14 | fancy-regex = "0.11.0"
15 | regex = "1.8.3"
16 | rustc-hash = "1.1.0"
17 | bstr = "1.5.0"
18 | base64 = "0.21.7"
19 | 
20 | [features]
21 | lua54 = ["mlua/lua54"]
22 | lua53 = ["mlua/lua53"]
23 | lua52 = ["mlua/lua52"]
24 | lua51 = ["mlua/lua51"]
25 | luajit = ["mlua/luajit"]
26 | 
--------------------------------------------------------------------------------
/tiktoken_core-0.2.5-1.rockspec:
--------------------------------------------------------------------------------
1 | package = "tiktoken_core"
2 | version = "0.2.5-1"
3 | 
4 | source = {
5 |    url = "git+https://github.com/gptlang/lua-tiktoken",
6 |    tag = "v0.2.5",
7 | }
8 | 
9 | description = {
10 |    summary = "An experimental port of OpenAI's tokenizer to Lua",
11 |    detailed = [[
12 |       A Lua module, written in Rust, that provides tiktoken support for Lua.
13 | ]], 14 | homepage = "https://github.com/gptlang/lua-tiktoken", 15 | license = "MIT", 16 | } 17 | 18 | dependencies = { 19 | "lua >= 5.1", 20 | "luarocks-build-rust-mlua", 21 | } 22 | 23 | build = { 24 | type = "rust-mlua", 25 | modules = { 26 | "tiktoken_core", 27 | }, 28 | } 29 | -------------------------------------------------------------------------------- /test/test.lua: -------------------------------------------------------------------------------- 1 | local tiktoken_core = require('tiktoken_core') 2 | local dkjson = require('json') 3 | 4 | local special_tokens = {} 5 | special_tokens['<|endoftext|>'] = 100257 6 | special_tokens['<|fim_prefix|>'] = 100258 7 | special_tokens['<|fim_middle|>'] = 100259 8 | special_tokens['<|fim_suffix|>'] = 100260 9 | special_tokens['<|endofprompt|>'] = 100276 10 | local pat_str = 11 | "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" 12 | tiktoken_core.new("/tmp/cl100k_base.tiktoken", special_tokens, pat_str) 13 | 14 | local result = tiktoken_core.encode('Hello, world!') 15 | print(dkjson.encode(result)) 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 gptlang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Upload Artifacts 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest, macos-latest, windows-latest] 14 | feature: [lua54, lua53, lua52, lua51, luajit] 15 | runs-on: ${{ matrix.os }} 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Install Rust 21 | uses: dtolnay/rust-toolchain@stable 22 | 23 | - name: Build 24 | run: cargo build --release --features ${{ matrix.feature }} 25 | 26 | - name: Prepare artifact 27 | shell: bash 28 | run: | 29 | if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then 30 | OS="linux" 31 | EXT="so" 32 | elif [ "${{ matrix.os }}" == "macos-latest" ]; then 33 | OS="macOS" 34 | EXT="dylib" 35 | else 36 | OS="windows" 37 | EXT="dll" 38 | fi 39 | mkdir -p artifacts 40 | if [ "${{ matrix.os }}" == "windows-latest" ]; then 41 | cp target/release/tiktoken_core.$EXT artifacts/tiktoken_core-$OS-${{ matrix.feature }}.$EXT 42 | else 43 | cp target/release/libtiktoken_core.$EXT artifacts/tiktoken_core-$OS-${{ matrix.feature }}.$EXT 44 | fi 45 | 46 | - name: Upload artifact 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: tiktoken_core-${{ matrix.os }}-${{ matrix.feature }} 50 | path: artifacts/tiktoken_core-*.${{ matrix.os == 'ubuntu-latest' && 'so' || matrix.os == 'macos-latest' && 'dylib' || 'dll' }} 51 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release and Upload Artifacts 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | build_and_release: 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, macos-latest, windows-latest] 13 | feature: [lua54, lua53, lua52, lua51, luajit] 14 | arch: [x86_64, aarch64] 15 | exclude: 16 | - os: windows-latest 17 | arch: aarch64 18 | 19 | runs-on: ${{ matrix.os }} 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - name: Install Rust 25 | uses: dtolnay/rust-toolchain@stable 26 | with: 27 | targets: ${{ matrix.arch }}-${{ matrix.os == 'ubuntu-latest' && 'unknown-linux-gnu' || matrix.os == 'macos-latest' && 'apple-darwin' || 'pc-windows-msvc' }} 28 | 29 | - name: Install cross-compilation tools (Linux ARM64) 30 | if: matrix.os == 'ubuntu-latest' && matrix.arch == 'aarch64' 31 | run: | 32 | sudo apt-get update 33 | sudo apt-get install -y gcc-aarch64-linux-gnu 34 | 35 | - name: Build 36 | shell: bash 37 | run: | 38 | if [ "${{ matrix.arch }}" == "aarch64" ]; then 39 | if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then 40 | RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc" cargo build --release --features ${{ matrix.feature }} --target aarch64-unknown-linux-gnu 41 | else 42 | cargo build --release --features ${{ matrix.feature }} --target aarch64-apple-darwin 43 | fi 44 | else 45 | cargo build --release --features ${{ matrix.feature }} 46 | fi 47 | 48 | - name: Prepare artifact 49 | shell: bash 50 | run: | 51 | if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then 52 | OS="linux" 53 | EXT="so" 54 | elif [ "${{ matrix.os }}" == "macos-latest" ]; then 55 | OS="macOS" 56 | EXT="dylib" 57 | else 58 | OS="windows" 59 | EXT="dll" 60 | fi 61 | ARCH="${{ matrix.arch == 'x86_64' && 'x86_64' || 'arm64' }}" 62 | mkdir -p artifacts 63 
| if [ "${{ matrix.arch }}" == "aarch64" ]; then 64 | if [ "${{ matrix.os }}" == "ubuntu-latest" ]; then 65 | cp target/aarch64-unknown-linux-gnu/release/libtiktoken_core.$EXT artifacts/tiktoken_core-$OS-$ARCH-${{ matrix.feature }}.$EXT 66 | else 67 | cp target/aarch64-apple-darwin/release/libtiktoken_core.$EXT artifacts/tiktoken_core-$OS-$ARCH-${{ matrix.feature }}.$EXT 68 | fi 69 | else 70 | if [ "${{ matrix.os }}" == "windows-latest" ]; then 71 | cp target/release/tiktoken_core.$EXT artifacts/tiktoken_core-$OS-$ARCH-${{ matrix.feature }}.$EXT 72 | else 73 | cp target/release/libtiktoken_core.$EXT artifacts/tiktoken_core-$OS-$ARCH-${{ matrix.feature }}.$EXT 74 | fi 75 | fi 76 | 77 | - name: Upload artifacts 78 | uses: actions/upload-artifact@v4 79 | with: 80 | name: tiktoken_core-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.feature }} 81 | path: artifacts/* 82 | 83 | release: 84 | needs: build_and_release 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@v4 88 | 89 | - name: Download all artifacts 90 | uses: actions/download-artifact@v4 91 | with: 92 | path: artifacts 93 | 94 | - name: Create Release 95 | uses: softprops/action-gh-release@v2 96 | with: 97 | files: artifacts/**/* 98 | draft: false 99 | prerelease: false 100 | env: 101 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 102 | -------------------------------------------------------------------------------- /test/json.lua: -------------------------------------------------------------------------------- 1 | -- 2 | -- json.lua 3 | -- 4 | -- Copyright (c) 2020 rxi 5 | -- 6 | -- Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | -- this software and associated documentation files (the "Software"), to deal in 8 | -- the Software without restriction, including without limitation the rights to 9 | -- use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | -- of the Software, and to permit persons to whom the Software is furnished to do 11 | -- so, subject to the following conditions: 12 | -- 13 | -- The above copyright notice and this permission notice shall be included in all 14 | -- copies or substantial portions of the Software. 15 | -- 16 | -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | -- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | -- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | -- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | -- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | -- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | -- SOFTWARE. 23 | -- 24 | 25 | local json = { _version = "0.1.2" } 26 | 27 | ------------------------------------------------------------------------------- 28 | -- Encode 29 | ------------------------------------------------------------------------------- 30 | 31 | local encode 32 | 33 | local escape_char_map = { 34 | [ "\\" ] = "\\", 35 | [ "\"" ] = "\"", 36 | [ "\b" ] = "b", 37 | [ "\f" ] = "f", 38 | [ "\n" ] = "n", 39 | [ "\r" ] = "r", 40 | [ "\t" ] = "t", 41 | } 42 | 43 | local escape_char_map_inv = { [ "/" ] = "/" } 44 | for k, v in pairs(escape_char_map) do 45 | escape_char_map_inv[v] = k 46 | end 47 | 48 | 49 | local function escape_char(c) 50 | return "\\" .. 
(escape_char_map[c] or string.format("u%04x", c:byte())) 51 | end 52 | 53 | 54 | local function encode_nil(val) 55 | return "null" 56 | end 57 | 58 | 59 | local function encode_table(val, stack) 60 | local res = {} 61 | stack = stack or {} 62 | 63 | -- Circular reference? 64 | if stack[val] then error("circular reference") end 65 | 66 | stack[val] = true 67 | 68 | if rawget(val, 1) ~= nil or next(val) == nil then 69 | -- Treat as array -- check keys are valid and it is not sparse 70 | local n = 0 71 | for k in pairs(val) do 72 | if type(k) ~= "number" then 73 | error("invalid table: mixed or invalid key types") 74 | end 75 | n = n + 1 76 | end 77 | if n ~= #val then 78 | error("invalid table: sparse array") 79 | end 80 | -- Encode 81 | for i, v in ipairs(val) do 82 | table.insert(res, encode(v, stack)) 83 | end 84 | stack[val] = nil 85 | return "[" .. table.concat(res, ",") .. "]" 86 | 87 | else 88 | -- Treat as an object 89 | for k, v in pairs(val) do 90 | if type(k) ~= "string" then 91 | error("invalid table: mixed or invalid key types") 92 | end 93 | table.insert(res, encode(k, stack) .. ":" .. encode(v, stack)) 94 | end 95 | stack[val] = nil 96 | return "{" .. table.concat(res, ",") .. "}" 97 | end 98 | end 99 | 100 | 101 | local function encode_string(val) 102 | return '"' .. val:gsub('[%z\1-\31\\"]', escape_char) .. '"' 103 | end 104 | 105 | 106 | local function encode_number(val) 107 | -- Check for NaN, -inf and inf 108 | if val ~= val or val <= -math.huge or val >= math.huge then 109 | error("unexpected number value '" .. tostring(val) .. "'") 110 | end 111 | return string.format("%.14g", val) 112 | end 113 | 114 | 115 | local type_func_map = { 116 | [ "nil" ] = encode_nil, 117 | [ "table" ] = encode_table, 118 | [ "string" ] = encode_string, 119 | [ "number" ] = encode_number, 120 | [ "boolean" ] = tostring, 121 | } 122 | 123 | 124 | encode = function(val, stack) 125 | local t = type(val) 126 | local f = type_func_map[t] 127 | if f then 128 | return f(val, stack) 129 | end 130 | error("unexpected type '" .. t .. "'") 131 | end 132 | 133 | 134 | function json.encode(val) 135 | return ( encode(val) ) 136 | end 137 | 138 | 139 | ------------------------------------------------------------------------------- 140 | -- Decode 141 | ------------------------------------------------------------------------------- 142 | 143 | local parse 144 | 145 | local function create_set(...) 146 | local res = {} 147 | for i = 1, select("#", ...) do 148 | res[ select(i, ...) 
] = true 149 | end 150 | return res 151 | end 152 | 153 | local space_chars = create_set(" ", "\t", "\r", "\n") 154 | local delim_chars = create_set(" ", "\t", "\r", "\n", "]", "}", ",") 155 | local escape_chars = create_set("\\", "/", '"', "b", "f", "n", "r", "t", "u") 156 | local literals = create_set("true", "false", "null") 157 | 158 | local literal_map = { 159 | [ "true" ] = true, 160 | [ "false" ] = false, 161 | [ "null" ] = nil, 162 | } 163 | 164 | 165 | local function next_char(str, idx, set, negate) 166 | for i = idx, #str do 167 | if set[str:sub(i, i)] ~= negate then 168 | return i 169 | end 170 | end 171 | return #str + 1 172 | end 173 | 174 | 175 | local function decode_error(str, idx, msg) 176 | local line_count = 1 177 | local col_count = 1 178 | for i = 1, idx - 1 do 179 | col_count = col_count + 1 180 | if str:sub(i, i) == "\n" then 181 | line_count = line_count + 1 182 | col_count = 1 183 | end 184 | end 185 | error( string.format("%s at line %d col %d", msg, line_count, col_count) ) 186 | end 187 | 188 | 189 | local function codepoint_to_utf8(n) 190 | -- http://scripts.sil.org/cms/scripts/page.php?site_id=nrsi&id=iws-appendixa 191 | local f = math.floor 192 | if n <= 0x7f then 193 | return string.char(n) 194 | elseif n <= 0x7ff then 195 | return string.char(f(n / 64) + 192, n % 64 + 128) 196 | elseif n <= 0xffff then 197 | return string.char(f(n / 4096) + 224, f(n % 4096 / 64) + 128, n % 64 + 128) 198 | elseif n <= 0x10ffff then 199 | return string.char(f(n / 262144) + 240, f(n % 262144 / 4096) + 128, 200 | f(n % 4096 / 64) + 128, n % 64 + 128) 201 | end 202 | error( string.format("invalid unicode codepoint '%x'", n) ) 203 | end 204 | 205 | 206 | local function parse_unicode_escape(s) 207 | local n1 = tonumber( s:sub(1, 4), 16 ) 208 | local n2 = tonumber( s:sub(7, 10), 16 ) 209 | -- Surrogate pair? 210 | if n2 then 211 | return codepoint_to_utf8((n1 - 0xd800) * 0x400 + (n2 - 0xdc00) + 0x10000) 212 | else 213 | return codepoint_to_utf8(n1) 214 | end 215 | end 216 | 217 | 218 | local function parse_string(str, i) 219 | local res = "" 220 | local j = i + 1 221 | local k = j 222 | 223 | while j <= #str do 224 | local x = str:byte(j) 225 | 226 | if x < 32 then 227 | decode_error(str, j, "control character in string") 228 | 229 | elseif x == 92 then -- `\`: Escape 230 | res = res .. str:sub(k, j - 1) 231 | j = j + 1 232 | local c = str:sub(j, j) 233 | if c == "u" then 234 | local hex = str:match("^[dD][89aAbB]%x%x\\u%x%x%x%x", j + 1) 235 | or str:match("^%x%x%x%x", j + 1) 236 | or decode_error(str, j - 1, "invalid unicode escape in string") 237 | res = res .. parse_unicode_escape(hex) 238 | j = j + #hex 239 | else 240 | if not escape_chars[c] then 241 | decode_error(str, j - 1, "invalid escape char '" .. c .. "' in string") 242 | end 243 | res = res .. escape_char_map_inv[c] 244 | end 245 | k = j + 1 246 | 247 | elseif x == 34 then -- `"`: End of string 248 | res = res .. str:sub(k, j - 1) 249 | return res, j + 1 250 | end 251 | 252 | j = j + 1 253 | end 254 | 255 | decode_error(str, i, "expected closing quote for string") 256 | end 257 | 258 | 259 | local function parse_number(str, i) 260 | local x = next_char(str, i, delim_chars) 261 | local s = str:sub(i, x - 1) 262 | local n = tonumber(s) 263 | if not n then 264 | decode_error(str, i, "invalid number '" .. s .. 
"'") 265 | end 266 | return n, x 267 | end 268 | 269 | 270 | local function parse_literal(str, i) 271 | local x = next_char(str, i, delim_chars) 272 | local word = str:sub(i, x - 1) 273 | if not literals[word] then 274 | decode_error(str, i, "invalid literal '" .. word .. "'") 275 | end 276 | return literal_map[word], x 277 | end 278 | 279 | 280 | local function parse_array(str, i) 281 | local res = {} 282 | local n = 1 283 | i = i + 1 284 | while 1 do 285 | local x 286 | i = next_char(str, i, space_chars, true) 287 | -- Empty / end of array? 288 | if str:sub(i, i) == "]" then 289 | i = i + 1 290 | break 291 | end 292 | -- Read token 293 | x, i = parse(str, i) 294 | res[n] = x 295 | n = n + 1 296 | -- Next token 297 | i = next_char(str, i, space_chars, true) 298 | local chr = str:sub(i, i) 299 | i = i + 1 300 | if chr == "]" then break end 301 | if chr ~= "," then decode_error(str, i, "expected ']' or ','") end 302 | end 303 | return res, i 304 | end 305 | 306 | 307 | local function parse_object(str, i) 308 | local res = {} 309 | i = i + 1 310 | while 1 do 311 | local key, val 312 | i = next_char(str, i, space_chars, true) 313 | -- Empty / end of object? 314 | if str:sub(i, i) == "}" then 315 | i = i + 1 316 | break 317 | end 318 | -- Read key 319 | if str:sub(i, i) ~= '"' then 320 | decode_error(str, i, "expected string for key") 321 | end 322 | key, i = parse(str, i) 323 | -- Read ':' delimiter 324 | i = next_char(str, i, space_chars, true) 325 | if str:sub(i, i) ~= ":" then 326 | decode_error(str, i, "expected ':' after key") 327 | end 328 | i = next_char(str, i + 1, space_chars, true) 329 | -- Read value 330 | val, i = parse(str, i) 331 | -- Set 332 | res[key] = val 333 | -- Next token 334 | i = next_char(str, i, space_chars, true) 335 | local chr = str:sub(i, i) 336 | i = i + 1 337 | if chr == "}" then break end 338 | if chr ~= "," then decode_error(str, i, "expected '}' or ','") end 339 | end 340 | return res, i 341 | end 342 | 343 | 344 | local char_func_map = { 345 | [ '"' ] = parse_string, 346 | [ "0" ] = parse_number, 347 | [ "1" ] = parse_number, 348 | [ "2" ] = parse_number, 349 | [ "3" ] = parse_number, 350 | [ "4" ] = parse_number, 351 | [ "5" ] = parse_number, 352 | [ "6" ] = parse_number, 353 | [ "7" ] = parse_number, 354 | [ "8" ] = parse_number, 355 | [ "9" ] = parse_number, 356 | [ "-" ] = parse_number, 357 | [ "t" ] = parse_literal, 358 | [ "f" ] = parse_literal, 359 | [ "n" ] = parse_literal, 360 | [ "[" ] = parse_array, 361 | [ "{" ] = parse_object, 362 | } 363 | 364 | 365 | parse = function(str, idx) 366 | local chr = str:sub(idx, idx) 367 | local f = char_func_map[chr] 368 | if f then 369 | return f(str, idx) 370 | end 371 | decode_error(str, idx, "unexpected character '" .. chr .. "'") 372 | end 373 | 374 | 375 | function json.decode(str) 376 | if type(str) ~= "string" then 377 | error("expected argument of type string, got " .. type(str)) 378 | end 379 | local res, idx = parse(str, next_char(str, 1, space_chars, true)) 380 | idx = next_char(str, idx, space_chars, true) 381 | if idx <= #str then 382 | decode_error(str, idx, "trailing garbage") 383 | end 384 | return res 385 | end 386 | 387 | 388 | return json 389 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "1.1.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "autocfg" 16 | version = "1.1.0" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 19 | 20 | [[package]] 21 | name = "base64" 22 | version = "0.21.7" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" 25 | 26 | [[package]] 27 | name = "bit-set" 28 | version = "0.5.3" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" 31 | dependencies = [ 32 | "bit-vec", 33 | ] 34 | 35 | [[package]] 36 | name = "bit-vec" 37 | version = "0.6.3" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" 40 | 41 | [[package]] 42 | name = "bitflags" 43 | version = "2.9.1" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" 46 | 47 | [[package]] 48 | name = "bstr" 49 | version = "1.9.1" 50 | source = "registry+https://github.com/rust-lang/crates.io-index" 51 | checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" 52 | dependencies = [ 53 | "memchr", 54 | "regex-automata", 55 | "serde", 56 | ] 57 | 58 | [[package]] 59 | name = "cc" 60 | version = "1.0.88" 61 | source = "registry+https://github.com/rust-lang/crates.io-index" 62 | checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc" 63 | 64 | [[package]] 65 | name = "cfg-if" 66 | version = "1.0.0" 67 | source = "registry+https://github.com/rust-lang/crates.io-index" 68 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 69 | 70 | [[package]] 71 | name = "either" 72 | version = "1.15.0" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" 75 | 76 | [[package]] 77 | name = "erased-serde" 78 | version = "0.4.3" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "388979d208a049ffdfb22fa33b9c81942215b940910bccfe258caeb25d125cb3" 81 | dependencies = [ 82 | "serde", 83 | ] 84 | 85 | [[package]] 86 | name = "fancy-regex" 87 | version = "0.11.0" 88 | source = "registry+https://github.com/rust-lang/crates.io-index" 89 | checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" 90 | dependencies = [ 91 | "bit-set", 92 | "regex", 93 | ] 94 | 95 | [[package]] 96 | name = "libc" 97 | version = "0.2.172" 98 | source = "registry+https://github.com/rust-lang/crates.io-index" 99 | checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" 100 | 101 | [[package]] 102 | name = "lock_api" 103 | version = "0.4.13" 104 | source = "registry+https://github.com/rust-lang/crates.io-index" 105 | checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" 106 | dependencies = [ 107 | "autocfg", 108 | "scopeguard", 109 | ] 110 | 111 | [[package]] 112 | name = "memchr" 113 | version = "2.7.1" 114 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 115 | checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" 116 | 117 | [[package]] 118 | name = "mlua" 119 | version = "0.10.5" 120 | source = "registry+https://github.com/rust-lang/crates.io-index" 121 | checksum = "c1f5f8fbebc7db5f671671134b9321c4b9aa9adeafccfd9a8c020ae45c6a35d0" 122 | dependencies = [ 123 | "bstr", 124 | "either", 125 | "erased-serde", 126 | "mlua-sys", 127 | "mlua_derive", 128 | "num-traits", 129 | "parking_lot", 130 | "rustc-hash 2.0.0", 131 | "rustversion", 132 | "serde", 133 | "serde-value", 134 | ] 135 | 136 | [[package]] 137 | name = "mlua-sys" 138 | version = "0.6.8" 139 | source = "registry+https://github.com/rust-lang/crates.io-index" 140 | checksum = "380c1f7e2099cafcf40e51d3a9f20a346977587aa4d012eae1f043149a728a93" 141 | dependencies = [ 142 | "cc", 143 | "cfg-if", 144 | "pkg-config", 145 | ] 146 | 147 | [[package]] 148 | name = "mlua_derive" 149 | version = "0.10.1" 150 | source = "registry+https://github.com/rust-lang/crates.io-index" 151 | checksum = "870d71c172fcf491c6b5fb4c04160619a2ee3e5a42a1402269c66bcbf1dd4deb" 152 | dependencies = [ 153 | "proc-macro2", 154 | "quote", 155 | "syn", 156 | ] 157 | 158 | [[package]] 159 | name = "num-traits" 160 | version = "0.2.18" 161 | source = "registry+https://github.com/rust-lang/crates.io-index" 162 | checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" 163 | dependencies = [ 164 | "autocfg", 165 | ] 166 | 167 | [[package]] 168 | name = "ordered-float" 169 | version = "2.10.1" 170 | source = "registry+https://github.com/rust-lang/crates.io-index" 171 | checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" 172 | dependencies = [ 173 | "num-traits", 174 | ] 175 | 176 | [[package]] 177 | name = "parking_lot" 178 | version = "0.12.4" 179 | source = "registry+https://github.com/rust-lang/crates.io-index" 180 | checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" 181 | dependencies = [ 182 | "lock_api", 183 | "parking_lot_core", 184 | ] 185 | 186 | [[package]] 187 | name = "parking_lot_core" 188 | version = "0.9.11" 189 | source = "registry+https://github.com/rust-lang/crates.io-index" 190 | checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" 191 | dependencies = [ 192 | "cfg-if", 193 | "libc", 194 | "redox_syscall", 195 | "smallvec", 196 | "windows-targets", 197 | ] 198 | 199 | [[package]] 200 | name = "pkg-config" 201 | version = "0.3.30" 202 | source = "registry+https://github.com/rust-lang/crates.io-index" 203 | checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" 204 | 205 | [[package]] 206 | name = "proc-macro2" 207 | version = "1.0.78" 208 | source = "registry+https://github.com/rust-lang/crates.io-index" 209 | checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" 210 | dependencies = [ 211 | "unicode-ident", 212 | ] 213 | 214 | [[package]] 215 | name = "quote" 216 | version = "1.0.35" 217 | source = "registry+https://github.com/rust-lang/crates.io-index" 218 | checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" 219 | dependencies = [ 220 | "proc-macro2", 221 | ] 222 | 223 | [[package]] 224 | name = "redox_syscall" 225 | version = "0.5.12" 226 | source = "registry+https://github.com/rust-lang/crates.io-index" 227 | checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" 228 | dependencies = [ 229 | "bitflags", 230 | ] 231 | 232 | 
[[package]] 233 | name = "regex" 234 | version = "1.10.3" 235 | source = "registry+https://github.com/rust-lang/crates.io-index" 236 | checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" 237 | dependencies = [ 238 | "aho-corasick", 239 | "memchr", 240 | "regex-automata", 241 | "regex-syntax", 242 | ] 243 | 244 | [[package]] 245 | name = "regex-automata" 246 | version = "0.4.5" 247 | source = "registry+https://github.com/rust-lang/crates.io-index" 248 | checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" 249 | dependencies = [ 250 | "aho-corasick", 251 | "memchr", 252 | "regex-syntax", 253 | ] 254 | 255 | [[package]] 256 | name = "regex-syntax" 257 | version = "0.8.2" 258 | source = "registry+https://github.com/rust-lang/crates.io-index" 259 | checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" 260 | 261 | [[package]] 262 | name = "rustc-hash" 263 | version = "1.1.0" 264 | source = "registry+https://github.com/rust-lang/crates.io-index" 265 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 266 | 267 | [[package]] 268 | name = "rustc-hash" 269 | version = "2.0.0" 270 | source = "registry+https://github.com/rust-lang/crates.io-index" 271 | checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" 272 | 273 | [[package]] 274 | name = "rustversion" 275 | version = "1.0.21" 276 | source = "registry+https://github.com/rust-lang/crates.io-index" 277 | checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" 278 | 279 | [[package]] 280 | name = "scopeguard" 281 | version = "1.2.0" 282 | source = "registry+https://github.com/rust-lang/crates.io-index" 283 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 284 | 285 | [[package]] 286 | name = "serde" 287 | version = "1.0.197" 288 | source = "registry+https://github.com/rust-lang/crates.io-index" 289 | checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" 290 | dependencies = [ 291 | "serde_derive", 292 | ] 293 | 294 | [[package]] 295 | name = "serde-value" 296 | version = "0.7.0" 297 | source = "registry+https://github.com/rust-lang/crates.io-index" 298 | checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" 299 | dependencies = [ 300 | "ordered-float", 301 | "serde", 302 | ] 303 | 304 | [[package]] 305 | name = "serde_derive" 306 | version = "1.0.197" 307 | source = "registry+https://github.com/rust-lang/crates.io-index" 308 | checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" 309 | dependencies = [ 310 | "proc-macro2", 311 | "quote", 312 | "syn", 313 | ] 314 | 315 | [[package]] 316 | name = "smallvec" 317 | version = "1.15.0" 318 | source = "registry+https://github.com/rust-lang/crates.io-index" 319 | checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" 320 | 321 | [[package]] 322 | name = "syn" 323 | version = "2.0.51" 324 | source = "registry+https://github.com/rust-lang/crates.io-index" 325 | checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c" 326 | dependencies = [ 327 | "proc-macro2", 328 | "quote", 329 | "unicode-ident", 330 | ] 331 | 332 | [[package]] 333 | name = "tiktoken_core" 334 | version = "0.4.0" 335 | dependencies = [ 336 | "base64", 337 | "bstr", 338 | "fancy-regex", 339 | "mlua", 340 | "regex", 341 | "rustc-hash 1.1.0", 342 | ] 343 | 344 | [[package]] 345 | name = "unicode-ident" 346 | version = "1.0.12" 347 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 348 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 349 | 350 | [[package]] 351 | name = "windows-targets" 352 | version = "0.52.6" 353 | source = "registry+https://github.com/rust-lang/crates.io-index" 354 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 355 | dependencies = [ 356 | "windows_aarch64_gnullvm", 357 | "windows_aarch64_msvc", 358 | "windows_i686_gnu", 359 | "windows_i686_gnullvm", 360 | "windows_i686_msvc", 361 | "windows_x86_64_gnu", 362 | "windows_x86_64_gnullvm", 363 | "windows_x86_64_msvc", 364 | ] 365 | 366 | [[package]] 367 | name = "windows_aarch64_gnullvm" 368 | version = "0.52.6" 369 | source = "registry+https://github.com/rust-lang/crates.io-index" 370 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 371 | 372 | [[package]] 373 | name = "windows_aarch64_msvc" 374 | version = "0.52.6" 375 | source = "registry+https://github.com/rust-lang/crates.io-index" 376 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 377 | 378 | [[package]] 379 | name = "windows_i686_gnu" 380 | version = "0.52.6" 381 | source = "registry+https://github.com/rust-lang/crates.io-index" 382 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 383 | 384 | [[package]] 385 | name = "windows_i686_gnullvm" 386 | version = "0.52.6" 387 | source = "registry+https://github.com/rust-lang/crates.io-index" 388 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 389 | 390 | [[package]] 391 | name = "windows_i686_msvc" 392 | version = "0.52.6" 393 | source = "registry+https://github.com/rust-lang/crates.io-index" 394 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 395 | 396 | [[package]] 397 | name = "windows_x86_64_gnu" 398 | version = "0.52.6" 399 | source = "registry+https://github.com/rust-lang/crates.io-index" 400 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 401 | 402 | [[package]] 403 | name = "windows_x86_64_gnullvm" 404 | version = "0.52.6" 405 | source = "registry+https://github.com/rust-lang/crates.io-index" 406 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 407 | 408 | [[package]] 409 | name = "windows_x86_64_msvc" 410 | version = "0.52.6" 411 | source = "registry+https://github.com/rust-lang/crates.io-index" 412 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 413 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use base64::{prelude::BASE64_STANDARD, Engine as _}; 2 | use fancy_regex::Regex; 3 | use mlua::prelude::*; 4 | use rustc_hash::FxHashMap as HashMap; 5 | use std::collections::HashSet; 6 | use std::fs::File; 7 | use std::io::{BufRead, BufReader}; 8 | use std::sync::{Arc, Mutex}; 9 | use std::thread; 10 | 11 | const MAX_NUM_THREADS: usize = 1; 12 | 13 | fn _byte_pair_merge( 14 | piece: &[u8], 15 | ranks: &HashMap, usize>, 16 | f: impl Fn(std::ops::Range) -> T, 17 | ) -> Vec { 18 | // This is a vector of (start, rank). 19 | // The rank is of the byte pair starting at position start. 20 | // The rank of the last item in the vector is not a valid value. 
21 |     let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
22 | 
23 |     // NOTE: using a macro here because a closure fails to get inlined
24 |     // according to optimization remarks.
25 |     // A closure also cannot capture a reference to `piece` without
26 |     // the borrow checker complaining about the mutable borrows during
27 |     // the assignments later in this code.
28 |     macro_rules! get_rank {
29 |         ($start_idx:expr, $skip:expr) => {{
30 |             let start_idx: usize = $start_idx;
31 |             let skip: usize = $skip;
32 |             if (start_idx + skip + 2) < parts.len() {
33 |                 ranks
34 |                     .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
35 |                     .map(|r| *r)
36 |             } else {
37 |                 None
38 |             }
39 |         }};
40 |         ($idx:expr) => {{
41 |             get_rank!($idx, 0)
42 |         }};
43 |     }
44 | 
45 |     // We look up the ranks once in the beginning and iteratively update
46 |     // them during each merge, which reduces the number of rank lookups.
47 |     for i in 0..parts.len() - 2 {
48 |         match get_rank!(i) {
49 |             Some(rank) => {
50 |                 // usize::MAX is a sentinel value and cannot be a valid rank
51 |                 debug_assert!(rank != usize::MAX);
52 |                 parts[i].1 = rank;
53 |             }
54 |             None => {
55 |                 continue;
56 |             }
57 |         };
58 |     }
59 | 
60 |     // If you have n parts and m merges, this does O(mn) work.
61 |     // We could do something with a heap and do O(m log n) work.
62 |     // It is important to consider that n is often small (<100), and as such
63 |     // the cache-locality benefits outweigh the algorithmic complexity downsides
64 |     // of the `parts` vector data structure above.
65 | 
66 |     // Note that we hash bytes, not token pairs. As long as we train BPE the way we
67 |     // currently do, this is equivalent. An easy way to break this would be to decouple
68 |     // merge priority from token index or to prevent specific token merges.
69 |     loop {
70 |         if parts.len() == 1 {
71 |             break;
72 |         }
73 | 
74 |         // usize::MAX is a sentinel rank value allowing us to
75 |         // take the min more quickly
76 |         let mut min_rank: (usize, usize) = (usize::MAX, 0);
77 |         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
78 |             if rank < min_rank.0 {
79 |                 min_rank = (rank, i);
80 |             }
81 |         }
82 | 
83 |         if min_rank.0 != usize::MAX {
84 |             let i = min_rank.1;
85 | 
86 |             // NOTE: We are about to remove parts[i + 1]. We do not do it
87 |             // yet because there are cache-locality benefits to updating
88 |             // parts[i] and parts[i-1] before removing, which could thrash
89 |             // the cache. Thus, we update the rank calculation by skipping over
90 |             // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
91 |             parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX);
92 |             if i > 0 {
93 |                 parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX);
94 |             }
95 | 
96 |             parts.remove(i + 1);
97 |         } else {
98 |             break;
99 |         }
100 |     }
101 |     let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
102 |     for i in 0..parts.len() - 1 {
103 |         out.push(f(parts[i].0..parts[i + 1].0));
104 |     }
105 |     out
106 | }
107 | 
108 | pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
109 |     if piece.len() == 1 {
110 |         return vec![ranks[piece]];
111 |     }
112 |     _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
113 | }
114 | 
115 | pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
116 |     if piece.len() == 1 {
117 |         return vec![piece];
118 |     }
119 |     _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
120 | }
121 | 
122 | // Various performance notes:
123 | //
124 | // Regex
125 | // =====
126 | // Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
127 | // regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
128 | // the usual regex we use.
129 | //
130 | // However, given that we're using a regex parse-able by `regex`, there isn't much difference
131 | // between using the `regex` crate and using the `fancy_regex` crate.
132 | //
133 | // There is an important interaction between threading, `regex` and `fancy_regex`.
134 | // When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
135 | // some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
136 | // old `regex`, we don't hit this, because `find_iter` has a different code path.
137 | // Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
138 | // Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
139 | // each thread.
140 | //
141 | // Threading
142 | // =========
143 | // I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
144 | // So goodbye `rayon`! Let thread count etc be in control of our Python users.
145 | //
146 | // Caching
147 | // =======
148 | // The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
149 | // Originally, we had one too! Without it, we were only vaguely faster than Python.
150 | // I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
151 | // noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
152 | // multi-threaded performance even when I only had readers (maybe I messed something up?).
153 | // Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
154 | // These are exactly the set of merges that are likely to be hot. And now we don't have to think
155 | // about interior mutability, memory use, or cloning.
156 | //
157 | // Hashing
158 | // =======
159 | // We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
160 | // The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
161 | // to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
162 | 
163 | use std::num::NonZeroU64;
164 | pub struct FakeThreadId(NonZeroU64);
165 | 
166 | fn hash_current_thread() -> usize {
167 |     // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
168 |     // that works great for our use case of avoiding collisions in our array. Unfortunately,
169 |     // it's private. However, there are only so many ways you can layout a u64, so just transmute
170 |     // https://github.com/rust-lang/rust/issues/67939
171 |     const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()];
172 |     const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
173 |     let x = unsafe {
174 |         std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0
175 |     };
176 |     u64::from(x) as usize
177 | }
178 | 
179 | struct State {
180 |     core_bpe: Mutex<Option<CoreBPENative>>,
181 | }
182 | 
183 | #[mlua::lua_module]
184 | pub fn tiktoken_core(lua: &mlua::Lua) -> LuaResult<LuaTable> {
185 |     let core_bpe = State {
186 |         core_bpe: Mutex::new(None),
187 |     };
188 |     let state = Arc::new(core_bpe);
189 |     let state2 = Arc::clone(&state);
190 | 
191 |     let _new = lua.create_function(
192 |         move |_,
193 |               (encoder_path, special_tokens_encoder, pattern): (
194 |             String,
195 |             HashMap<String, usize>,
196 |             String,
197 |         )| {
198 |             new(&*state, encoder_path, special_tokens_encoder, pattern);
199 |             Ok(())
200 |         },
201 |     )?;
202 |     let _encode = lua.create_function(move |_, text: mlua::String| encode(&*state2, text))?;
203 | 
204 |     let exports = lua.create_table()?;
205 |     exports.set("new", _new)?;
206 |     exports.set("encode", _encode)?;
207 |     Ok(exports)
208 | }
209 | 
210 | fn new(
211 |     state: &State,
212 |     encoder_path: String,
213 |     special_tokens_encoder: HashMap<String, usize>,
214 |     pattern: String,
215 | ) {
216 |     let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
217 |     // Read the encoder file: each line is a base64-encoded token and its rank, separated by a space
218 |     let file = File::open(&encoder_path)
219 |         .map_err(|e| format!("Failed to open encoder file '{}': {}", encoder_path, e))
220 |         .unwrap();
221 |     let reader = BufReader::new(file);
222 |     for line in reader.lines() {
223 |         let line = line
224 |             .map_err(|e| format!("Failed to read line from encoder file: {}", e))
225 |             .unwrap();
226 |         let mut parts = line.split_whitespace();
227 |         let token_b64 = parts.next()
228 |             .ok_or_else(|| format!("Invalid encoder file format: missing token in line '{}'", line))
229 |             .unwrap();
230 |         let token = BASE64_STANDARD
231 |             .decode(token_b64.as_bytes())
232 |             .map_err(|e| format!("Failed to decode base64 token '{}': {}", token_b64, e))
233 |             .unwrap();
234 |         let rank_str = parts.next()
235 |             .ok_or_else(|| format!("Invalid encoder file format: missing rank in line '{}'", line))
236 |             .unwrap();
237 |         let rank = rank_str.parse()
238 |             .map_err(|e| format!("Failed to parse rank '{}': {}", rank_str, e))
239 |             .unwrap();
240 |         encoder.insert(token, rank);
241 |     }
242 |     let regex = Regex::new(&pattern)
243 |         .map_err(|e| format!("Failed to compile main regex pattern '{}': {}", pattern, e))
244 |         .unwrap();
245 |     let special_regex = {
246 |         let _parts = special_tokens_encoder
247 |             .keys()
248 |             .map(|s| fancy_regex::escape(s))
249 |             .collect::<Vec<_>>();
250 |         Regex::new(&_parts.join("|"))
251 |             .map_err(|e| format!("Failed to compile special tokens regex: {}", e))
252 |             .unwrap()
253 |     };
254 |     let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
255 |         .iter()
256 |         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
257 |         .collect();
258 |     let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
259 |     sorted_token_bytes.sort();
260 |     let mut core_bpe_lock = state.core_bpe.lock()
261 |         .map_err(|e| format!("Failed to acquire lock on core_bpe: {}", e))
262 |         .unwrap();
263 |     *core_bpe_lock = Some(CoreBPENative {
264 |         encoder,
265 |         special_tokens_encoder,
266 |         // empty decoder
267 |         decoder: HashMap::default(),
268 |         special_tokens_decoder,
269 |         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
270 |         special_regex_tls: (0..MAX_NUM_THREADS)
271 |             .map(|_| special_regex.clone())
272 |             .collect(),
273 |         sorted_token_bytes,
274 |     });
275 | }
276 | 
277 | fn encode(state: &State, text: mlua::String) -> LuaResult<(Vec<usize>, usize, usize)> {
278 |     let text_bytes = text.as_bytes();
279 |     let encoded_str = String::from_utf8_lossy(&text_bytes);
280 |     let allowed_special = HashSet::new();
281 |     let max_tokens = None;
282 |     Ok(state
283 |         .core_bpe
284 |         .lock()
285 |         .map_err(|e| mlua::Error::external(format!("Failed to acquire lock on core_bpe: {}", e)))?
286 |         .as_ref()
287 |         .ok_or_else(|| mlua::Error::external("Core BPE not initialized"))?
288 |         ._encode_native(&encoded_str, &allowed_special, max_tokens))
289 | }
290 | 
291 | pub struct CoreBPENative {
292 |     encoder: HashMap<Vec<u8>, usize>,
293 |     special_tokens_encoder: HashMap<String, usize>,
294 |     decoder: HashMap<usize, Vec<u8>>,
295 |     special_tokens_decoder: HashMap<usize, Vec<u8>>,
296 |     regex_tls: Vec<Regex>,
297 |     special_regex_tls: Vec<Regex>,
298 |     sorted_token_bytes: Vec<Vec<u8>>,
299 | }
300 | 
301 | impl CoreBPENative {
302 |     fn _get_tl_regex(&self) -> &Regex {
303 |         // See performance notes above for what this is about
304 |         // It's also a little janky, please make a better version of it!
305 |         // However, it's nice that this doesn't leak memory to short-lived threads
306 |         &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
307 |     }
308 | 
309 |     fn _get_tl_special_regex(&self) -> &Regex {
310 |         &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
311 |     }
312 | 
313 |     pub fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
314 |         let mut ret = Vec::with_capacity(tokens.len() * 2);
315 |         for token in tokens {
316 |             let token_bytes = self
317 |                 .decoder
318 |                 .get(token)
319 |                 .unwrap_or_else(|| &self.special_tokens_decoder[token]);
320 |             ret.extend(token_bytes);
321 |         }
322 |         ret
323 |     }
324 | 
325 |     pub fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
326 |         // This is the core of the encoding logic; the other functions in here
327 |         // just make things complicated :-)
328 |         let regex = self._get_tl_regex();
329 |         let mut ret = vec![];
330 |         for mat in regex.find_iter(text) {
331 |             let piece = mat
332 |                 .map_err(|e| format!("Regex matching failed: {}", e))
333 |                 .unwrap()
334 |                 .as_str().as_bytes();
335 |             if let Some(token) = self.encoder.get(piece) {
336 |                 ret.push(*token);
337 |                 continue;
338 |             }
339 |             ret.extend(&byte_pair_encode(piece, &self.encoder));
340 |         }
341 |         ret
342 |     }
343 | 
344 |     pub fn _encode_native(
345 |         &self,
346 |         text: &str,
347 |         allowed_special: &HashSet<&str>,
348 |         max_tokens: Option<usize>,
349 |     ) -> (Vec<usize>, usize, usize) {
350 |         let max_tokens = max_tokens.unwrap_or(usize::MAX);
351 |         let special_regex = self._get_tl_special_regex();
352 |         let regex = self._get_tl_regex();
353 |         let mut ret = vec![];
354 | 
355 |         let mut start = 0;
356 |         let mut last_piece_token_len = 0;
357 |         loop {
358 |             let mut next_special;
359 |             let mut start_find = start;
360 |             loop {
361 |                 // Find the next allowed special token, if any
362 |                 next_special = special_regex.find_from_pos(text, start_find)
363 |                     .map_err(|e| format!("Special regex matching failed at position {}: {}", start_find, e))
364 |                     .unwrap();
365 |                 match next_special {
366 |                     Some(m) => {
367 |                         if allowed_special.contains(&text[m.start()..m.end()]) {
368 |                             break;
369 |                         }
370 |                         start_find = m.start() + 1;
371 |                     }
372 |                     None => break,
373 |                 }
374 |             }
375 |             let end = next_special.map_or(text.len(), |m| m.start());
376 | 
377 |             // Okay, here we go, compare this logic to _encode_ordinary_native
378 |             for mat in regex.find_iter(&text[start..end]) {
379 |                 let piece = mat
380 |                     .map_err(|e| format!("Regex matching failed in text slice: {}", e))
381 |                     .unwrap()
382 |                     .as_str().as_bytes();
383 |                 if let Some(token) = self.encoder.get(piece) {
384 |                     last_piece_token_len = 1;
385 |                     ret.push(*token);
386 | 
387 |                     if ret.len() >= max_tokens {
388 |                         return (ret, last_piece_token_len, start);
389 |                     }
390 |                     continue;
391 |                 }
392 |                 let tokens = byte_pair_encode(piece, &self.encoder);
393 |                 last_piece_token_len = tokens.len();
394 |                 for token in tokens {
395 |                     ret.push(token);
396 |                     if ret.len() >= max_tokens {
397 |                         return (ret, last_piece_token_len, start);
398 |                     }
399 |                 }
400 |             }
401 | 
402 |             match next_special {
403 |                 // And here we push the special token
404 |                 Some(m) => {
405 |                     let piece = m.as_str();
406 |                     let token = self.special_tokens_encoder[piece];
407 |                     ret.push(token);
408 | 
409 |                     start = m.end();
410 |                     last_piece_token_len = 0;
411 |                     if ret.len() >= max_tokens {
412 |                         return (ret, last_piece_token_len, start);
413 |                     }
414 |                 }
415 |                 None => break,
416 |             }
417 |         }
418 | 
419 |         // last_piece_token_len is how many tokens came from the last regex split. This is used
420 |         // for determining unstable tokens, since you can't merge across (stable) regex splits
421 |         (ret, last_piece_token_len, start)
422 |     }
423 | 
424 |     pub fn _encode_bytes(&self, bytes: &[u8]) -> Vec<usize> {
425 |         match std::str::from_utf8(bytes) {
426 |             Ok(text) => self._encode_ordinary_native(text),
427 |             Err(e) => {
428 |                 let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) };
429 |                 let (tokens, last_piece_token_len, _) =
430 |                     self._encode_native(text, &HashSet::new(), None);
431 |                 let (mut tokens, last_piece_token_len) =
432 |                     self._increase_last_piece_token_len(tokens, last_piece_token_len);
433 |                 if !tokens.is_empty() && last_piece_token_len > 0 {
434 |                     // Lop off the tokens from the last piece and run BPE on the remaining bytes
435 |                     // Somewhat niche, but this may not be correct if we'd have had a regex
436 |                     // split between the valid UTF-8 and the invalid bytes, which is why this
437 |                     // method is private
438 |                     let mut unstable_bytes =
439 |                         self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
440 |                     unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]);
441 | 
442 |                     tokens.truncate(tokens.len() - last_piece_token_len);
443 |                     tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder));
444 |                 }
445 |                 tokens
446 |             }
447 |         }
448 |     }
449 | 
450 |     fn _increase_last_piece_token_len(
451 |         &self,
452 |         tokens: Vec<usize>,
453 |         mut last_piece_token_len: usize,
454 |     ) -> (Vec<usize>, usize) {
455 |         // Unfortunately, the locations where our regex splits can be unstable.
456 |         // For the purposes of determining unstable tokens, unstable regex splitting
457 |         // is only a problem if a split that was present disappears, since this can
458 |         // lead to merging of tokens otherwise thought to be stable.
459 |         // cl100k_base makes our life hard by including the \s*[\r\n]+
460 |         // pattern. This can e.g. cause "\n" + " " to become "\n \n".
461 |         // Here is a quick and dirty fix:
462 |         {
463 |             let token_is_all_space = |token| {
464 |                 self.decoder
465 |                     .get(token)
466 |                     .map(|token_bytes| {
467 |                         token_bytes
468 |                             .iter()
469 |                             .rev()
470 |                             .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
471 |                     })
472 |                     .unwrap_or(false)
473 |             };
474 |             if last_piece_token_len > 0
475 |                 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
476 |             {
477 |                 while (last_piece_token_len < tokens.len())
478 |                     && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
479 |                 {
480 |                     last_piece_token_len += 1;
481 |                 }
482 |             }
483 |         }
484 |         debug_assert!(last_piece_token_len <= tokens.len());
485 | 
486 |         (tokens, last_piece_token_len)
487 |     }
488 | 
489 |     pub fn _encode_unstable_native(
490 |         &self,
491 |         text: &str,
492 |         allowed_special: &HashSet<&str>,
493 |     ) -> (Vec<usize>, HashSet<Vec<usize>>) {
494 |         let (tokens, last_piece_token_len, _) = self._encode_native(text, allowed_special, None);
495 |         if last_piece_token_len == 0 {
496 |             // If last_piece_token_len is zero, the last token was a special token and we have
497 |             // no unstable bytes
498 |             return (tokens, HashSet::new());
499 |         }
500 |         let (mut tokens, last_piece_token_len) =
501 |             self._increase_last_piece_token_len(tokens, last_piece_token_len);
502 | 
503 |         let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
504 |         tokens.truncate(tokens.len() - last_piece_token_len);
505 | 
506 |         // TODO: we should try harder to find additional stable tokens
507 |         // This would reduce the amount of retokenising when determining completions
508 |         // Refer to the logic in an older version of this file
509 | 
510 |         let mut completions = HashSet::new();
511 |         if unstable_bytes.is_empty() {
512 |             return (tokens, completions);
513 |         }
514 | 
515 |         // This is the easy bit. Just find all single tokens that start with unstable_bytes
516 |         // (including tokens that exactly match unstable_bytes)
517 |         // Separating this from the loop below helps with performance in a common case.
518 |         let mut point = self
519 |             .sorted_token_bytes
520 |             .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
521 |         while point < self.sorted_token_bytes.len()
522 |             && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
523 |         {
524 |             completions.insert(vec![
525 |                 self.encoder[self.sorted_token_bytes[point].as_slice()],
526 |             ]);
527 |             point += 1;
528 |         }
529 | 
530 |         // Now apply even more brute force. At every (other) possible position for the straddling
531 |         // token, concatenate additional bytes from that token (if any) to unstable_bytes,
532 |         // and retokenise the whole thing and see what we get.
533 |         for i in 1..unstable_bytes.len() {
534 |             let prefix = &unstable_bytes[..i];
535 |             let suffix = &unstable_bytes[i..];
536 |             let mut point = self
537 |                 .sorted_token_bytes
538 |                 .partition_point(|x| x.as_slice() < suffix);
539 |             // TODO: Perf optimisation if suffix starts with " "?
540 |             while point < self.sorted_token_bytes.len()
541 |                 && self.sorted_token_bytes[point].starts_with(suffix)
542 |             {
543 |                 let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
544 |                 let encoded = match std::str::from_utf8(&possibility) {
545 |                     // Morally, this is byte_pair_encode(&possibility, &self.encoder)
546 |                     // But we might have introduced a regex split which would prevent merges.
547 |                     // (particularly possible in the presence of unstable regex splits)
548 |                     // So convert to UTF-8 and do regex splitting.
549 |                     // E.g. with cl100k_base " !" gets split to " " + " !",
550 |                     // but byte_pair_encode(" !") != byte_pair_encode(" ")
551 |                     Ok(s) => self._encode_ordinary_native(s),
552 | 
553 |                     // Technically, whether or not this arm is correct depends on whether there
554 |                     // would be a regex split before the UTF-8 truncation point.
555 |                     // Probably niche enough that no one will ever notice (after all, people didn't
556 |                     // notice all the big holes in the previous unstable token implementation)
557 |                     Err(_) => byte_pair_encode(&possibility, &self.encoder),
558 |                     // Something like the following is intriguing but incorrect:
559 |                     // Err(e) => self._encode_ordinary_native(unsafe {
560 |                     //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
561 |                     // }),
562 |                 };
563 |                 let mut seq = Vec::new();
564 |                 let mut seq_len = 0;
565 |                 for token in encoded {
566 |                     seq.push(token);
567 |                     seq_len += self.decoder[&token].len();
568 |                     if seq_len >= unstable_bytes.len() {
569 |                         break;
570 |                     }
571 |                 }
572 |                 completions.insert(seq);
573 |                 point += 1;
574 |             }
575 |         }
576 | 
577 |         // This is also not straightforward. While we generally assume that regex splits are stable,
578 |         // unfortunately, they are not. That is, if adding bytes were to make a split appear in
579 |         // unstable_bytes, this could make tokens possible which our logic would otherwise think
580 |         // would be merged.
581 |         // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
582 |         // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
583 |         // Here is a quick and dirty fix:
584 |         // This isn't right if we ever remove \s+(?!\S)
585 |         if unstable_bytes.len() > 1 {
586 |             let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
587 |             if unstable_bytes.len() - last_decoded.1 > 0
588 |                 && last_decoded.0.map_or(false, |c| c.is_whitespace())
589 |             {
590 |                 let mut reencoded = byte_pair_encode(
591 |                     &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
592 |                     &self.encoder,
593 |                 );
594 |                 reencoded.extend(byte_pair_encode(
595 |                     &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
596 |                     &self.encoder,
597 |                 ));
598 |                 completions.insert(reencoded);
599 |             }
600 |         }
601 | 
602 |         (tokens, completions)
603 |     }
604 | 
605 |     pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize, Vec<u8>> {
606 |         if let Some(token) = self.encoder.get(piece).copied() {
607 |             return Ok(token);
608 |         }
609 |         if let Ok(piece_str) = std::str::from_utf8(piece) {
610 |             if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() {
611 |                 return Ok(token);
612 |             }
613 |         }
614 |         Err(piece.to_owned())
615 |     }
616 | 
617 |     // ====================
618 |     // Decoding
619 |     // ====================
620 | 
621 |     pub fn decode_single_token_bytes(&self, token: usize) -> Result<&[u8], String> {
622 |         if let Some(bytes) = self.decoder.get(&token) {
623 |             return Ok(bytes);
624 |         }
625 |         if let Some(bytes) = self.special_tokens_decoder.get(&token) {
626 |             return Ok(bytes);
627 |         }
628 |         Err(token.to_string())
629 |     }
630 | 
631 |     // ====================
632 |     // Miscellaneous
633 |     // ====================
634 | 
635 |     pub fn token_byte_values(&self) -> &Vec<Vec<u8>> {
636 |         &self.sorted_token_bytes
637 |     }
638 | 
639 |     pub fn new(
640 |         encoder: HashMap<Vec<u8>, usize>,
641 |         special_tokens_encoder: HashMap<String, usize>,
642 |         pattern: &str,
643 |     ) -> Result<Self, fancy_regex::Error> {
644 |         let regex = Regex::new(pattern)?;
645 |         // .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?;
646 | 
647 |         let special_regex = {
648 |             let _parts = special_tokens_encoder
649 |                 .keys()
650 |                 .map(|s| fancy_regex::escape(s))
651 |                 .collect::<Vec<_>>();
652 |             Regex::new(&_parts.join("|"))?
653 | 
654 |             // .map_err(|e| PyErr::new::<exceptions::PyValueError, _>(e.to_string()))?
655 |         };
656 | 
657 |         let decoder: HashMap<usize, Vec<u8>> =
658 |             encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
659 | 
660 |         assert!(
661 |             encoder.len() == decoder.len(),
662 |             "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
663 |         );
664 | 
665 |         let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
666 |             .iter()
667 |             .map(|(k, v)| (*v, k.as_bytes().to_vec()))
668 |             .collect();
669 | 
670 |         // Clone because I don't know how to tell Rust I'm not going to change the map
671 |         let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
672 |         sorted_token_bytes.sort();
673 | 
674 |         Ok(CoreBPENative {
675 |             encoder,
676 |             special_tokens_encoder,
677 |             decoder,
678 |             special_tokens_decoder,
679 |             regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
680 |             special_regex_tls: (0..MAX_NUM_THREADS)
681 |                 .map(|_| special_regex.clone())
682 |                 .collect(),
683 |             sorted_token_bytes,
684 |         })
685 |     }
686 | }
687 | 
688 | #[cfg(test)]
689 | mod tests {
690 |     use rustc_hash::FxHashMap as HashMap;
691 | 
692 |     use crate::byte_pair_split;
693 | 
694 |     #[test]
695 |     fn very_simple_test() {
696 |         let mut ranks = HashMap::default();
697 |         ranks.insert(b"ab".to_vec(), 1);
698 |         ranks.insert(b"cd".to_vec(), 2);
699 | 
700 |         let res = byte_pair_split(b"abcd", &ranks);
701 |         assert_eq!(res, vec![b"ab", b"cd"]);
702 |     }
703 | }
704 | 
--------------------------------------------------------------------------------
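A minimal usage sketch for the compiled module, mirroring test/test.lua; it assumes the cl100k_base rank file has already been downloaded to /tmp/cl100k_base.tiktoken:

local tiktoken_core = require('tiktoken_core')

-- cl100k_base special tokens and split pattern (same values as test/test.lua)
local special_tokens = {
  ['<|endoftext|>'] = 100257,
  ['<|fim_prefix|>'] = 100258,
  ['<|fim_middle|>'] = 100259,
  ['<|fim_suffix|>'] = 100260,
  ['<|endofprompt|>'] = 100276,
}
local pat_str =
  "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"

-- new() loads the BPE ranks once; encode() can then be called repeatedly.
tiktoken_core.new('/tmp/cl100k_base.tiktoken', special_tokens, pat_str)

-- encode() returns the token ids as a list; the Rust side also returns the
-- last-piece token length and start offset, which are ignored here.
local tokens = tiktoken_core.encode('Hello, world!')
print('token count: ' .. #tokens)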