├── .clang-format ├── .cmakelintrc ├── .flake8 ├── .github └── workflows │ ├── pull.yml │ └── trunk.yml ├── .gitignore ├── .gitmodules ├── .lintrunner.toml ├── .mypy.ini ├── CMakeLists.txt ├── CODE_OF_CONDUCT ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── TARGETS ├── Utils.cmake ├── examples └── tokenize_tool │ ├── CMakeLists.txt │ └── main.cpp ├── include └── pytorch │ └── tokenizers │ ├── base64.h │ ├── bpe_tokenizer_base.h │ ├── error.h │ ├── hf_tokenizer.h │ ├── llama2c_tokenizer.h │ ├── log.h │ ├── pcre2_regex.h │ ├── pre_tokenizer.h │ ├── re2_regex.h │ ├── regex.h │ ├── result.h │ ├── sentencepiece.h │ ├── std_regex.h │ ├── string_integer_map.h │ ├── tiktoken.h │ ├── token_decoder.h │ └── tokenizer.h ├── pyproject.toml ├── pytorch_tokenizers ├── TARGETS ├── __init__.py ├── constants.py ├── hf_tokenizer.py ├── llama2c.py ├── targets.bzl ├── tiktoken.py └── tools │ ├── __init__.py │ └── llama2c │ ├── TARGETS │ ├── __init__.py │ ├── convert.py │ └── targets.bzl ├── requirements-lintrunner.txt ├── setup.py ├── src ├── bpe_tokenizer_base.cpp ├── hf_tokenizer.cpp ├── llama2c_tokenizer.cpp ├── pcre2_regex.cpp ├── pre_tokenizer.cpp ├── re2_regex.cpp ├── regex.cpp ├── regex_lookahead.cpp ├── sentencepiece.cpp ├── std_regex.cpp ├── tiktoken.cpp └── token_decoder.cpp ├── targets.bzl ├── test ├── TARGETS ├── fb │ └── TARGETS ├── resources │ ├── test_bpe_tokenizer.bin │ ├── test_llama2c_tokenizer.bin │ ├── test_sentencepiece.model │ ├── test_tiktoken_invalid_base64.model │ ├── test_tiktoken_invalid_rank.model │ ├── test_tiktoken_no_space.model │ └── test_tiktoken_tokenizer.model ├── targets.bzl ├── test_base64.cpp ├── test_llama2c_tokenizer.cpp ├── test_pre_tokenizer.cpp ├── test_regex.cpp ├── test_sentencepiece.cpp ├── test_string_integer_map.cpp ├── test_tiktoken.cpp └── test_tiktoken.py └── third-party ├── TARGETS ├── llama.cpp-unicode ├── CMakeLists.txt ├── README.md ├── include │ ├── unicode-data.h │ └── unicode.h └── src │ ├── unicode-data.cpp │ └── unicode.cpp └── targets.bzl /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | AccessModifierOffset: -1 4 | AlignAfterOpenBracket: AlwaysBreak 5 | AlignArrayOfStructures: None 6 | AlignConsecutiveAssignments: 7 | Enabled: false 8 | AcrossEmptyLines: false 9 | AcrossComments: false 10 | AlignCompound: false 11 | AlignFunctionPointers: false 12 | PadOperators: true 13 | AlignConsecutiveBitFields: 14 | Enabled: false 15 | AcrossEmptyLines: false 16 | AcrossComments: false 17 | AlignCompound: false 18 | AlignFunctionPointers: false 19 | PadOperators: true 20 | AlignConsecutiveDeclarations: 21 | Enabled: false 22 | AcrossEmptyLines: false 23 | AcrossComments: false 24 | AlignCompound: false 25 | AlignFunctionPointers: false 26 | PadOperators: true 27 | AlignConsecutiveMacros: 28 | Enabled: false 29 | AcrossEmptyLines: false 30 | AcrossComments: false 31 | AlignCompound: false 32 | AlignFunctionPointers: false 33 | PadOperators: true 34 | AlignConsecutiveShortCaseStatements: 35 | Enabled: false 36 | AcrossEmptyLines: false 37 | AcrossComments: false 38 | AlignCaseColons: false 39 | AlignEscapedNewlines: Left 40 | AlignOperands: DontAlign 41 | AlignTrailingComments: 42 | Kind: Never 43 | OverEmptyLines: 0 44 | AllowAllArgumentsOnNextLine: true 45 | AllowAllParametersOfDeclarationOnNextLine: false 46 | AllowBreakBeforeNoexceptSpecifier: Never 47 | AllowShortBlocksOnASingleLine: Never 48 | AllowShortCaseLabelsOnASingleLine: false 49 | 
AllowShortCompoundRequirementOnASingleLine: true 50 | AllowShortEnumsOnASingleLine: true 51 | AllowShortFunctionsOnASingleLine: Empty 52 | AllowShortIfStatementsOnASingleLine: Never 53 | AllowShortLambdasOnASingleLine: All 54 | AllowShortLoopsOnASingleLine: false 55 | AlwaysBreakAfterDefinitionReturnType: None 56 | AlwaysBreakAfterReturnType: None 57 | AlwaysBreakBeforeMultilineStrings: true 58 | AlwaysBreakTemplateDeclarations: Yes 59 | AttributeMacros: 60 | - __capability 61 | BinPackArguments: false 62 | BinPackParameters: false 63 | BitFieldColonSpacing: Both 64 | BraceWrapping: 65 | AfterCaseLabel: false 66 | AfterClass: false 67 | AfterControlStatement: Never 68 | AfterEnum: false 69 | AfterExternBlock: false 70 | AfterFunction: false 71 | AfterNamespace: false 72 | AfterObjCDeclaration: false 73 | AfterStruct: false 74 | AfterUnion: false 75 | BeforeCatch: false 76 | BeforeElse: false 77 | BeforeLambdaBody: false 78 | BeforeWhile: false 79 | IndentBraces: false 80 | SplitEmptyFunction: true 81 | SplitEmptyRecord: true 82 | SplitEmptyNamespace: true 83 | BreakAdjacentStringLiterals: true 84 | BreakAfterAttributes: Leave 85 | BreakAfterJavaFieldAnnotations: false 86 | BreakArrays: true 87 | BreakBeforeBinaryOperators: None 88 | BreakBeforeConceptDeclarations: Always 89 | BreakBeforeBraces: Attach 90 | BreakBeforeInlineASMColon: OnlyMultiline 91 | BreakBeforeTernaryOperators: true 92 | BreakConstructorInitializers: BeforeColon 93 | BreakInheritanceList: BeforeColon 94 | BreakStringLiterals: false 95 | ColumnLimit: 80 96 | CommentPragmas: '^ IWYU pragma:' 97 | CompactNamespaces: false 98 | ConstructorInitializerIndentWidth: 4 99 | ContinuationIndentWidth: 4 100 | Cpp11BracedListStyle: true 101 | DerivePointerAlignment: false 102 | DisableFormat: false 103 | EmptyLineAfterAccessModifier: Never 104 | EmptyLineBeforeAccessModifier: LogicalBlock 105 | ExperimentalAutoDetectBinPacking: false 106 | FixNamespaceComments: true 107 | ForEachMacros: 108 | - FOR_EACH 109 | - FOR_EACH_R 110 | - FOR_EACH_RANGE 111 | IfMacros: 112 | - KJ_IF_MAYBE 113 | IncludeBlocks: Preserve 114 | IncludeCategories: 115 | - Regex: '^<.*\.h(pp)?>' 116 | Priority: 1 117 | SortPriority: 0 118 | CaseSensitive: false 119 | - Regex: '^<.*' 120 | Priority: 2 121 | SortPriority: 0 122 | CaseSensitive: false 123 | - Regex: '.*' 124 | Priority: 3 125 | SortPriority: 0 126 | CaseSensitive: false 127 | IncludeIsMainRegex: '(Test)?$' 128 | IncludeIsMainSourceRegex: '' 129 | IndentAccessModifiers: false 130 | IndentCaseBlocks: false 131 | IndentCaseLabels: true 132 | IndentExternBlock: AfterExternBlock 133 | IndentGotoLabels: true 134 | IndentPPDirectives: None 135 | IndentRequiresClause: true 136 | IndentWidth: 2 137 | IndentWrappedFunctionNames: false 138 | InsertBraces: false 139 | InsertNewlineAtEOF: false 140 | InsertTrailingCommas: None 141 | IntegerLiteralSeparator: 142 | Binary: 0 143 | BinaryMinDigits: 0 144 | Decimal: 0 145 | DecimalMinDigits: 0 146 | Hex: 0 147 | HexMinDigits: 0 148 | JavaScriptQuotes: Leave 149 | JavaScriptWrapImports: true 150 | KeepEmptyLinesAtTheStartOfBlocks: false 151 | KeepEmptyLinesAtEOF: false 152 | LambdaBodyIndentation: Signature 153 | LineEnding: DeriveLF 154 | MacroBlockBegin: '' 155 | MacroBlockEnd: '' 156 | MaxEmptyLinesToKeep: 1 157 | NamespaceIndentation: None 158 | ObjCBinPackProtocolList: Auto 159 | ObjCBlockIndentWidth: 2 160 | ObjCBreakBeforeNestedBlockParam: true 161 | ObjCSpaceAfterProperty: false 162 | ObjCSpaceBeforeProtocolList: false 163 | PackConstructorInitializers: 
NextLine 164 | PenaltyBreakAssignment: 2 165 | PenaltyBreakBeforeFirstCallParameter: 1 166 | PenaltyBreakComment: 300 167 | PenaltyBreakFirstLessLess: 120 168 | PenaltyBreakOpenParenthesis: 0 169 | PenaltyBreakScopeResolution: 500 170 | PenaltyBreakString: 1000 171 | PenaltyBreakTemplateDeclaration: 10 172 | PenaltyExcessCharacter: 1000000 173 | PenaltyIndentedWhitespace: 0 174 | PenaltyReturnTypeOnItsOwnLine: 200 175 | PointerAlignment: Left 176 | PPIndentWidth: -1 177 | QualifierAlignment: Leave 178 | ReferenceAlignment: Pointer 179 | ReflowComments: true 180 | RemoveBracesLLVM: false 181 | RemoveParentheses: Leave 182 | RemoveSemicolon: false 183 | RequiresClausePosition: OwnLine 184 | RequiresExpressionIndentation: OuterScope 185 | SeparateDefinitionBlocks: Leave 186 | ShortNamespaceLines: 1 187 | SkipMacroDefinitionBody: false 188 | SortIncludes: CaseSensitive 189 | SortJavaStaticImport: Before 190 | SortUsingDeclarations: LexicographicNumeric 191 | SpaceAfterCStyleCast: false 192 | SpaceAfterLogicalNot: false 193 | SpaceAfterTemplateKeyword: true 194 | SpaceAroundPointerQualifiers: Default 195 | SpaceBeforeAssignmentOperators: true 196 | SpaceBeforeCaseColon: false 197 | SpaceBeforeCpp11BracedList: false 198 | SpaceBeforeCtorInitializerColon: true 199 | SpaceBeforeInheritanceColon: true 200 | SpaceBeforeJsonColon: false 201 | SpaceBeforeParens: ControlStatements 202 | SpaceBeforeParensOptions: 203 | AfterControlStatements: true 204 | AfterForeachMacros: true 205 | AfterFunctionDefinitionName: false 206 | AfterFunctionDeclarationName: false 207 | AfterIfMacros: true 208 | AfterOverloadedOperator: false 209 | AfterPlacementOperator: true 210 | AfterRequiresInClause: false 211 | AfterRequiresInExpression: false 212 | BeforeNonEmptyParentheses: false 213 | SpaceBeforeRangeBasedForLoopColon: true 214 | SpaceBeforeSquareBrackets: false 215 | SpaceInEmptyBlock: false 216 | SpacesBeforeTrailingComments: 1 217 | SpacesInAngles: Never 218 | SpacesInContainerLiterals: true 219 | SpacesInLineCommentPrefix: 220 | Minimum: 1 221 | Maximum: -1 222 | SpacesInParens: Never 223 | SpacesInParensOptions: 224 | InCStyleCasts: false 225 | InConditionalStatements: false 226 | InEmptyParentheses: false 227 | Other: false 228 | SpacesInSquareBrackets: false 229 | Standard: Latest 230 | StatementAttributeLikeMacros: 231 | - Q_EMIT 232 | StatementMacros: 233 | - Q_UNUSED 234 | - QT_REQUIRE_VERSION 235 | TabWidth: 8 236 | UseTab: Never 237 | VerilogBreakBetweenInstancePorts: true 238 | WhitespaceSensitiveMacros: 239 | - BOOST_PP_STRINGIZE 240 | - CF_SWIFT_NAME 241 | - NS_SWIFT_NAME 242 | - PP_STRINGIZE 243 | - STRINGIZE 244 | ... 245 | -------------------------------------------------------------------------------- /.cmakelintrc: -------------------------------------------------------------------------------- 1 | filter=-convention/filename,-linelength,-package/consistency,-readability/logic,+readability/mixedcase,-readability/wonkycase,-syntax,-whitespace/eol,+whitespace/extra,-whitespace/indent,-whitespace/mismatch,-whitespace/newline,-whitespace/tabs 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,W,B9,TOR0,TOR1,TOR2 3 | max-line-length = 80 4 | ignore = 5 | # Black conflicts and overlaps. 
6 | B950, 7 | E111, 8 | E115, 9 | E117, 10 | E121, 11 | E122, 12 | E123, 13 | E124, 14 | E125, 15 | E126, 16 | E127, 17 | E128, 18 | E129, 19 | E131, 20 | E201, 21 | E202, 22 | E203, 23 | E221, 24 | E222, 25 | E225, 26 | E226, 27 | E227, 28 | E231, 29 | E241, 30 | E251, 31 | E252, 32 | E261, 33 | E262, 34 | E265, 35 | E271, 36 | E272, 37 | E301, 38 | E302, 39 | E303, 40 | E305, 41 | E306, 42 | E501, 43 | E502, 44 | E701, 45 | E702, 46 | E703, 47 | E704, 48 | W291, 49 | W292, 50 | W293, 51 | W391, 52 | W504, 53 | 54 | # Too opinionated. 55 | E265, 56 | E266, 57 | E402, 58 | E722, 59 | B001, 60 | P207, 61 | B003, 62 | P208, 63 | C403, 64 | W503, 65 | 66 | # Bugbear has opinions: https://github.com/PyCQA/flake8-bugbear#opinionated-warnings 67 | B904, 68 | B905, 69 | B906, 70 | B907, 71 | exclude = 72 | ./.git, 73 | ./backends/xnnpack/third-party, 74 | ./build, 75 | ./configurations, 76 | ./docs, 77 | ./third_party, 78 | *.pyi 79 | 80 | max-complexity = 12 81 | -------------------------------------------------------------------------------- /.github/workflows/pull.yml: -------------------------------------------------------------------------------- 1 | name: pull 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | unittest-linux: 16 | name: unittest-linux 17 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 18 | strategy: 19 | fail-fast: false 20 | with: 21 | runner: linux.2xlarge 22 | submodules: 'true' 23 | ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} 24 | timeout: 90 25 | script: | 26 | set -ex 27 | cmake -DTOKENIZERS_BUILD_TEST=ON -DCMAKE_BUILD_TYPE=Debug . -Bbuild 28 | cmake --build build -j9 --config Debug 29 | cd build && ctest 30 | -------------------------------------------------------------------------------- /.github/workflows/trunk.yml: -------------------------------------------------------------------------------- 1 | name: trunk 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - release/* 8 | tags: 9 | - ciflow/trunk/* 10 | pull_request: 11 | paths: 12 | - CMakeLists.txt 13 | - .github/workflows/trunk.yml 14 | workflow_dispatch: 15 | 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | unittest-macos: 22 | name: unittest-macos 23 | uses: pytorch/test-infra/.github/workflows/macos_job.yml@main 24 | strategy: 25 | fail-fast: false 26 | with: 27 | runner: macos-14-xlarge 28 | python-version: '3.11' 29 | submodules: 'true' 30 | ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} 31 | timeout: 90 32 | script: | 33 | set -ex 34 | cmake -DTOKENIZERS_BUILD_TEST=ON -DCMAKE_BUILD_TYPE=Debug . 
-Bbuild 35 | cmake --build build -j9 --config Debug 36 | cd build && ctest 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # System files 2 | .DS_Store 3 | 4 | # Python environment and cache 5 | .hypothesis 6 | .mypy_cache/ 7 | .venv/ 8 | __pycache__/ 9 | 10 | # Build and tool-generated files 11 | build/ 12 | cmake-out* 13 | dist/ 14 | pip-out/ 15 | *.egg-info 16 | 17 | # Editor temporaries 18 | *.swa 19 | *.swb 20 | *.swc 21 | *.swd 22 | *.swe 23 | *.swf 24 | *.swg 25 | *.swh 26 | *.swi 27 | *.swj 28 | *.swk 29 | *.swl 30 | *.swm 31 | *.swn 32 | *.swo 33 | *.swp 34 | *~ 35 | .~lock.* 36 | *.idea 37 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/sentencepiece"] 2 | path = third-party/sentencepiece 3 | url = https://github.com/google/sentencepiece.git 4 | [submodule "third-party/re2"] 5 | path = third-party/re2 6 | url = https://github.com/google/re2.git 7 | [submodule "third-party/abseil-cpp"] 8 | path = third-party/abseil-cpp 9 | url = https://github.com/abseil/abseil-cpp.git 10 | [submodule "third-party/json"] 11 | path = third-party/json 12 | url = https://github.com/nlohmann/json.git 13 | [submodule "third-party/pcre2"] 14 | path = third-party/pcre2 15 | url = https://github.com/PCRE2Project/pcre2.git 16 | -------------------------------------------------------------------------------- /.lintrunner.toml: -------------------------------------------------------------------------------- 1 | merge_base_with = "origin/main" 2 | 3 | [[linter]] 4 | code = 'FLAKE8' 5 | include_patterns = ['**/*.py'] 6 | exclude_patterns = [ 7 | 'third-party/**', 8 | ] 9 | command = [ 10 | 'python', 11 | '-m', 12 | 'lintrunner_adapters', 13 | 'run', 14 | 'flake8_linter', 15 | '--', 16 | '@{{PATHSFILE}}' 17 | ] 18 | init_command = [ 19 | 'python', 20 | '-m', 21 | 'lintrunner_adapters', 22 | 'run', 23 | 'pip_init', 24 | '--dry-run={{DRYRUN}}', 25 | '--requirement=requirements-lintrunner.txt', 26 | ] 27 | 28 | # Black + usort 29 | [[linter]] 30 | code = 'UFMT' 31 | include_patterns = [ 32 | '*.py', 33 | '*.pyi', 34 | ] 35 | exclude_patterns = [ 36 | 'third-party/**', 37 | ] 38 | command = [ 39 | 'python', 40 | '-m', 41 | 'lintrunner_adapters', 42 | 'run', 43 | 'ufmt_linter', 44 | '--', 45 | '@{{PATHSFILE}}' 46 | ] 47 | init_command = [ 48 | 'python', 49 | '-m', 50 | 'lintrunner_adapters', 51 | 'run', 52 | 'pip_init', 53 | '--dry-run={{DRYRUN}}', 54 | '--no-black-binary', 55 | '--requirement=requirements-lintrunner.txt', 56 | ] 57 | is_formatter = true 58 | 59 | #CLANGFORMAT 60 | [[linter]] 61 | code = 'CLANGFORMAT' 62 | include_patterns = [ 63 | '**/*.h', 64 | '**/*.cpp', 65 | ] 66 | exclude_patterns = [ 67 | 'third-party/**', 68 | ] 69 | command = [ 70 | 'python', 71 | '-m', 72 | 'lintrunner_adapters', 73 | 'run', 74 | 'clangformat_linter', 75 | '--binary=clang-format', 76 | '--fallback', 77 | '--', 78 | '@{{PATHSFILE}}' 79 | ] 80 | init_command = [ 81 | 'python', 82 | '-m', 83 | 'lintrunner_adapters', 84 | 'run', 85 | 'pip_init', 86 | '--dry-run={{DRYRUN}}', 87 | '--requirement=requirements-lintrunner.txt', 88 | ] 89 | is_formatter = true 90 | 91 | [[linter]] 92 | code = 'CMAKE' 93 | include_patterns = [ 94 | "**/*.cmake", 95 | "**/*.cmake.in", 96 | "**/CMakeLists.txt", 97 | ] 98 | exclude_patterns = [ 99 | 
'third-party/**', 100 | ] 101 | command = [ 102 | 'python', 103 | '-m', 104 | 'lintrunner_adapters', 105 | 'run', 106 | 'cmake_linter', 107 | '--config=.cmakelintrc', 108 | '--', 109 | '@{{PATHSFILE}}', 110 | ] 111 | init_command = [ 112 | 'python', 113 | '-m', 114 | 'lintrunner_adapters', 115 | 'run', 116 | 'pip_init', 117 | '--dry-run={{DRYRUN}}', 118 | '--requirement=requirements-lintrunner.txt', 119 | ] 120 | 121 | [[linter]] 122 | code = 'NEWLINE' 123 | include_patterns = ['**'] 124 | exclude_patterns = [ 125 | 'third-party/**', 126 | 'test/resources/*.model', 127 | ] 128 | command = [ 129 | 'python', 130 | '-m', 131 | 'lintrunner_adapters', 132 | 'run', 133 | 'newlines_linter', 134 | '--', 135 | '@{{PATHSFILE}}', 136 | ] 137 | is_formatter = true 138 | 139 | [[linter]] 140 | code = 'MYPY' 141 | include_patterns = [ 142 | '*.py', 143 | ] 144 | exclude_patterns = [ 145 | 'third-party/**', 146 | ] 147 | command = [ 148 | 'python', 149 | '-m', 150 | 'lintrunner_adapters', 151 | 'run', 152 | 'mypy_linter', 153 | '--config=.mypy.ini', 154 | '--show-disable', 155 | '--', 156 | '--explicit-package-bases', 157 | '@{{PATHSFILE}}' 158 | ] 159 | init_command = [ 160 | 'python', 161 | '-m', 162 | 'lintrunner_adapters', 163 | 'run', 164 | 'pip_init', 165 | '--dry-run={{DRYRUN}}', 166 | '--requirement=requirements-lintrunner.txt', 167 | ] 168 | -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | allow_redefinition = True 3 | warn_unused_configs = True 4 | warn_redundant_casts = True 5 | show_error_codes = True 6 | show_column_numbers = True 7 | disallow_untyped_decorators = True 8 | follow_imports = normal 9 | local_partial_types = True 10 | enable_error_code = possibly-undefined 11 | warn_unused_ignores = False 12 | 13 | mypy_path = pytorch_tokenizers 14 | 15 | [mypy-buck_util] 16 | ignore_missing_imports = True 17 | 18 | [mypy-docutils.*] 19 | ignore_missing_imports = True 20 | 21 | [mypy-pandas] 22 | ignore_missing_imports = True 23 | 24 | [mypy-ruamel] 25 | ignore_missing_imports = True 26 | 27 | [mypy-tomllib] 28 | ignore_missing_imports = True 29 | 30 | [mypy-yaml] 31 | ignore_missing_imports = True 32 | 33 | [mypy-zstd] 34 | ignore_missing_imports = True 35 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | # 4 | # Build tokenizers. 5 | # 6 | # ### Editing this file ### 7 | # 8 | # This file should be formatted with 9 | # ~~~ 10 | # cmake-format -i CMakeLists.txt 11 | # ~~~ 12 | # It should also be cmake-lint clean. 
13 | # 14 | cmake_minimum_required(VERSION 3.18) 15 | set(CMAKE_CXX_STANDARD 17) 16 | 17 | project(Tokenizers) 18 | 19 | option(TOKENIZERS_BUILD_TEST "Build tests" OFF) 20 | option(TOKENIZERS_BUILD_TOOLS "Build tools" OFF) 21 | option(SUPPORT_REGEX_LOOKAHEAD 22 | "Support regex lookahead patterns (requires PCRE2)" OFF) 23 | 24 | include(Utils.cmake) 25 | # Ignore weak attribute warning 26 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") 27 | 28 | set(ABSL_ENABLE_INSTALL ON) 29 | set(ABSL_PROPAGATE_CXX_STD ON) 30 | 31 | set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) 32 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 33 | 34 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp) 35 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2) 36 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece) 37 | 38 | set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) 39 | 40 | file(GLOB tokenizers_source_files ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) 41 | set(tokenizers_source_files 42 | ${CMAKE_CURRENT_SOURCE_DIR}/src/bpe_tokenizer_base.cpp 43 | ${CMAKE_CURRENT_SOURCE_DIR}/src/hf_tokenizer.cpp 44 | ${CMAKE_CURRENT_SOURCE_DIR}/src/llama2c_tokenizer.cpp 45 | ${CMAKE_CURRENT_SOURCE_DIR}/src/pre_tokenizer.cpp 46 | ${CMAKE_CURRENT_SOURCE_DIR}/src/re2_regex.cpp 47 | ${CMAKE_CURRENT_SOURCE_DIR}/src/regex.cpp 48 | ${CMAKE_CURRENT_SOURCE_DIR}/src/sentencepiece.cpp 49 | ${CMAKE_CURRENT_SOURCE_DIR}/src/tiktoken.cpp 50 | ${CMAKE_CURRENT_SOURCE_DIR}/src/token_decoder.cpp) 51 | 52 | file(GLOB unicode_source_files 53 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/src/*.cpp) 54 | add_library(tokenizers STATIC ${tokenizers_source_files} 55 | ${unicode_source_files}) 56 | 57 | # Using abseil from sentencepiece/third_party 58 | target_include_directories( 59 | tokenizers 60 | PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include 61 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece 62 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece/src 63 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 64 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include 65 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include) 66 | target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2) 67 | 68 | if(SUPPORT_REGEX_LOOKAHEAD OR TOKENIZERS_BUILD_TEST) 69 | set(PCRE2_BUILD_PCRE2_8 ON) 70 | set(PCRE2_BUILD_PCRE2_16 OFF) 71 | set(PCRE2_BUILD_PCRE2_32 OFF) 72 | set(PCRE2_BUILD_TESTS OFF) 73 | set(PCRE2_BUILD_PCRE2GREP OFF) 74 | set(PCRE2_BUILD_PCRE2TEST OFF) 75 | set(PCRE2_BUILD_PCRE2GPERF OFF) 76 | set(PCRE2_BUILD_DOCS OFF) 77 | set(PCRE2_BUILD_LIBPCRE2_PDB OFF) 78 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2) 79 | add_library( 80 | regex_lookahead STATIC 81 | ${CMAKE_CURRENT_SOURCE_DIR}/src/pcre2_regex.cpp 82 | ${CMAKE_CURRENT_SOURCE_DIR}/src/regex_lookahead.cpp 83 | ${CMAKE_CURRENT_SOURCE_DIR}/src/std_regex.cpp) 84 | target_link_libraries(regex_lookahead PUBLIC pcre2-8) 85 | target_include_directories( 86 | regex_lookahead PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include 87 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src) 88 | target_link_options_shared_lib(regex_lookahead) 89 | target_link_libraries(tokenizers PUBLIC regex_lookahead) 90 | endif() 91 | 92 | # Build test 93 | if(TOKENIZERS_BUILD_TEST) 94 | enable_testing() 95 | include(FetchContent) 96 | # CMAKE 97 | FetchContent_Declare( 98 | googletest 99 | # Specify the commit you depend on and update it regularly. 
100 | URL https://github.com/google/googletest/archive/5376968f6948923e2411081fd9372e71a59d8e77.zip 101 | ) 102 | set(gtest_force_shared_crt 103 | ON 104 | CACHE BOOL "" FORCE) 105 | FetchContent_MakeAvailable(googletest) 106 | 107 | file(GLOB test_source_files ${CMAKE_CURRENT_SOURCE_DIR}/test/test_*.cpp) 108 | 109 | set(test_env "RESOURCES_PATH=${CMAKE_CURRENT_SOURCE_DIR}/test/resources") 110 | foreach(test_source_file ${test_source_files}) 111 | get_filename_component(test_name ${test_source_file} NAME_WE) 112 | message(STATUS "Configuring unit test ${test_name}") 113 | add_executable(${test_name} ${test_source_file}) 114 | target_include_directories( 115 | ${test_name} 116 | PRIVATE GTEST_INCLUDE_PATH 117 | ${CMAKE_CURRENT_SOURCE_DIR}/include 118 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece 119 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 120 | ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include) 121 | target_link_libraries(${test_name} gtest_main GTest::gmock tokenizers) 122 | add_test(${test_name} "${test_name}") 123 | set_tests_properties(${test_name} PROPERTIES ENVIRONMENT ${test_env}) 124 | endforeach() 125 | endif() 126 | 127 | # Build tools 128 | if(TOKENIZERS_BUILD_TOOLS) 129 | add_subdirectory(examples/tokenize_tool) 130 | endif() 131 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to tokenizers 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to tokenizers, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
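## Building and testing locally

A quick sketch of one way to build and run the test suite locally, mirroring the commands CI runs in `.github/workflows/pull.yml` (the Debug configuration and `-j9` parallelism are simply what CI uses; adjust as needed):

```bash
cmake -DTOKENIZERS_BUILD_TEST=ON -DCMAKE_BUILD_TYPE=Debug . -Bbuild
cmake --build build -j9 --config Debug
cd build && ctest
```

For linting, the repo is configured for `lintrunner` (see `.lintrunner.toml` and `requirements-lintrunner.txt`); running `lintrunner init` once and then `lintrunner` before submitting should exercise the configured linters.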
32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024 Meta 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tokenizers 2 | C++ implementations for various tokenizers (sentencepiece, tiktoken etc). Useful for other PyTorch repos such as torchchat, ExecuTorch to build LLM runners using ExecuTorch stack or AOT Inductor stack. 3 | 4 | 5 | ## SentencePiece tokenizer 6 | Depend on https://github.com/google/sentencepiece from Google. 7 | 8 | ## Tiktoken tokenizer 9 | Adapted from https://github.com/sewenew/tokenizer. 10 | 11 | ## Huggingface tokenizer 12 | Compatible with https://github.com/huggingface/tokenizers/. 13 | 14 | ## Llama2.c tokenizer 15 | Adapted from https://github.com/karpathy/llama2.c. 16 | 17 | ## License 18 | 19 | tokenizers is released under the [BSD 3 license](LICENSE). (Additional 20 | code in this distribution is covered by the MIT and Apache Open Source 21 | licenses.) However you may have other legal obligations that govern 22 | your use of content, such as the terms of service for third-party 23 | models. 24 | -------------------------------------------------------------------------------- /TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain fbcode-only targets. 3 | 4 | load(":targets.bzl", "define_common_targets") 5 | 6 | oncall("executorch") 7 | 8 | define_common_targets() 9 | -------------------------------------------------------------------------------- /Utils.cmake: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. 
and affiliates. Confidential and proprietary. 2 | 3 | # 4 | # Build tokenizers. 5 | # 6 | # ### Editing this file ### 7 | # 8 | # This file should be formatted with 9 | # ~~~ 10 | # cmake-format -i CMakeLists.txt 11 | # ~~~ 12 | # It should also be cmake-lint clean. 13 | # 14 | 15 | # This is the funtion to use -Wl, --whole-archive to link static library NB: 16 | # target_link_options is broken for this case, it only append the interface link 17 | # options of the first library. 18 | function(kernel_link_options target_name) 19 | # target_link_options(${target_name} INTERFACE 20 | # "$") 21 | target_link_options( 22 | ${target_name} INTERFACE "SHELL:LINKER:--whole-archive \ 23 | $ \ 24 | LINKER:--no-whole-archive") 25 | endfunction() 26 | 27 | # Same as kernel_link_options but it's for MacOS linker 28 | function(macos_kernel_link_options target_name) 29 | target_link_options(${target_name} INTERFACE 30 | "SHELL:LINKER:-force_load,$") 31 | endfunction() 32 | 33 | # Same as kernel_link_options but it's for MSVC linker 34 | function(msvc_kernel_link_options target_name) 35 | target_link_options( 36 | ${target_name} INTERFACE 37 | "SHELL:LINKER:/WHOLEARCHIVE:$") 38 | endfunction() 39 | 40 | # Ensure that the load-time constructor functions run. By default, the linker 41 | # would remove them since there are no other references to them. 42 | function(target_link_options_shared_lib target_name) 43 | if(APPLE) 44 | macos_kernel_link_options(${target_name}) 45 | elseif(MSVC) 46 | msvc_kernel_link_options(${target_name}) 47 | else() 48 | kernel_link_options(${target_name}) 49 | endif() 50 | endfunction() 51 | -------------------------------------------------------------------------------- /examples/tokenize_tool/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD-style license found in the LICENSE 4 | # file in the root directory of this source tree. 5 | # @lint-ignore-every LICENSELINT 6 | 7 | file(GLOB source_files ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 8 | get_filename_component(tool_name ${CMAKE_CURRENT_SOURCE_DIR} NAME) 9 | add_executable(${tool_name} ${source_files}) 10 | target_link_libraries(${tool_name} PRIVATE tokenizers) 11 | target_include_directories(${tool_name} PRIVATE 12 | ${CMAKE_SOURCE_DIR}/include/pytorch/tokenizers 13 | ) 14 | -------------------------------------------------------------------------------- /examples/tokenize_tool/main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | /** 11 | * This is a simple tool to instantiate a tokenizer and run it over some text. 12 | * It can be used to evaluate the tokenization done by a given tokenizer model 13 | * relative to its native python library. 
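 *
 * Example invocation (the model path is illustrative; the accepted type
 * strings are "sentencepiece", "tiktoken" and "hf_tokenizer", as parsed in
 * main() below):
 *
 *   tokenize_tool tiktoken /path/to/tokenizer.model Hello world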
14 | */ 15 | 16 | // Standard 17 | #include 18 | #include 19 | #include 20 | 21 | // Local 22 | #include "hf_tokenizer.h" 23 | #include "sentencepiece.h" 24 | #include "tiktoken.h" 25 | 26 | using namespace tokenizers; 27 | 28 | std::string help(char* argv[]) { 29 | std::stringstream ss; 30 | ss << "Usage: " << argv[0] << " " 31 | << std::endl 32 | << std::endl; 33 | ss << "Types:\n" << std::endl; 34 | ss << "* sentencepiece: SPTokenizer" << std::endl; 35 | ss << "* tiktoken: Tiktoken" << std::endl; 36 | ss << "* hf_tokenizers: HFTokenizer" << std::endl; 37 | return ss.str(); 38 | } 39 | 40 | int main(int argc, char* argv[]) { 41 | // Check for the right number of CLI args 42 | if (argc < 4) { 43 | std::cerr << help(argv) << std::endl; 44 | return 1; 45 | } 46 | 47 | // Parse CLI args 48 | const std::string tokenizer_type(argv[1]); 49 | const std::string model_path(argv[2]); 50 | std::stringstream prompt_ss; 51 | for (auto i = 3; i < argc; ++i) { 52 | if (i > 3) { 53 | prompt_ss << " "; 54 | } 55 | prompt_ss << argv[i]; 56 | } 57 | const std::string prompt = prompt_ss.str(); 58 | 59 | // Instantiate the tokenizer 60 | std::unique_ptr tok_ptr; 61 | if (tokenizer_type == "sentencepiece") { 62 | tok_ptr.reset(new SPTokenizer()); 63 | } else if (tokenizer_type == "tiktoken") { 64 | tok_ptr.reset(new Tiktoken()); 65 | } else if (tokenizer_type == "hf_tokenizer") { 66 | tok_ptr.reset(new HFTokenizer()); 67 | } else { 68 | std::stringstream ss; 69 | ss << "ERROR: Invalid tokenizer type: " << tokenizer_type << std::endl 70 | << std::endl; 71 | ss << help(argv); 72 | std::cerr << ss.str() << std::endl; 73 | return 1; 74 | } 75 | 76 | // Load from the path 77 | tok_ptr->load(model_path); 78 | 79 | // Log out the IDs for the BOS/EOS tokens 80 | std::cout << "Vocab Size: " << tok_ptr->vocab_size() << std::endl; 81 | std::cout << "BOS: " << tok_ptr->bos_tok() << std::endl; 82 | std::cout << "EOS: " << tok_ptr->eos_tok() << std::endl << std::endl; 83 | 84 | // Encode 85 | std::cout << "PROMPT:" << std::endl << prompt << std::endl << std::endl; 86 | std::cout << "Encoding..." << std::endl; 87 | const auto encoded_result = tok_ptr->encode(prompt, 0, 0); 88 | const auto encoded = encoded_result.get(); 89 | std::cout << "["; 90 | for (const auto tok_id : encoded) { 91 | std::cout << " " << tok_id; 92 | } 93 | std::cout << " ]" << std::endl << std::endl; 94 | 95 | // Decode 96 | std::cout << "Decoding..." << std::endl; 97 | uint64_t prev = tok_ptr->bos_tok(); 98 | for (const auto& current : encoded) { 99 | const auto decoded_result = tok_ptr->decode(prev, current); 100 | std::cout << decoded_result.get(); 101 | prev = current; 102 | } 103 | std::cout << std::endl; 104 | 105 | return 0; 106 | } 107 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/base64.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | /************************************************************************** 10 | Copyright (c) 2023 sewenew 11 | 12 | Licensed under the Apache License, Version 2.0 (the "License"); 13 | you may not use this file except in compliance with the License. 
14 | You may obtain a copy of the License at 15 | 16 | http://www.apache.org/licenses/LICENSE-2.0 17 | 18 | Unless required by applicable law or agreed to in writing, software 19 | distributed under the License is distributed on an "AS IS" BASIS, 20 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 21 | See the License for the specific language governing permissions and 22 | limitations under the License. 23 | *************************************************************************/ 24 | 25 | #pragma once 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include "result.h" 33 | 34 | namespace base64 { 35 | 36 | using tokenizers::Error; 37 | using tokenizers::Result; 38 | 39 | Result decode(const std::string_view& input); 40 | 41 | namespace detail { 42 | 43 | constexpr uint32_t DECODE_TABLE[] = { 44 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 45 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 46 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 47 | 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 48 | 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 49 | 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 50 | 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, 51 | 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 52 | 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 53 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 54 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 55 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 56 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 57 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 58 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 59 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 60 | 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 61 | 255}; 62 | 63 | inline Error validate(uint32_t v) { 64 | if (v == 255) { 65 | fprintf(stderr, "invalid char"); 66 | return Error::Base64DecodeFailure; 67 | } 68 | return Error::Ok; 69 | } 70 | 71 | inline Error decode(const std::string_view& input, std::string& output) { 72 | TK_CHECK_OR_RETURN_ERROR( 73 | input.size() == 4, 74 | Base64DecodeFailure, 75 | "input length must be 4, got %zu", 76 | input.size()); 77 | 78 | uint32_t val = 0; 79 | 80 | uint8_t c = input[0]; 81 | auto v = DECODE_TABLE[c]; 82 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 83 | val = v; 84 | 85 | c = input[1]; 86 | v = DECODE_TABLE[c]; 87 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 88 | val = (val << 6) | v; 89 | 90 | c = input[2]; 91 | v = DECODE_TABLE[c]; 92 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 93 | val = (val << 6) | v; 94 | 95 | c = input[3]; 96 | v = DECODE_TABLE[c]; 97 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 98 | val = (val << 6) | v; 99 | 100 | output.push_back(static_cast((val >> 16) & 0xFF)); 101 | output.push_back(static_cast((val >> 8) & 0xFF)); 102 | output.push_back(static_cast(val & 0xFF)); 103 | return Error::Ok; 104 | } 105 | 106 | inline Error decode_1_padding( 107 | const std::string_view& input, 108 | std::string& output) { 109 | TK_CHECK_OR_RETURN_ERROR( 110 | input.size() == 3, 111 | Base64DecodeFailure, 112 | "input length must be 3, got %zu", 113 | input.size()); 114 | 115 | 
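  // Three base64 characters carry 18 bits of payload. The top 16 bits form the
  // two output bytes below; the low 2 bits are padding left over from encoding.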
uint32_t val = 0; 116 | 117 | uint8_t c = input[0]; 118 | auto v = DECODE_TABLE[c]; 119 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 120 | val = v; 121 | 122 | c = input[1]; 123 | v = DECODE_TABLE[c]; 124 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 125 | val = (val << 6) | v; 126 | 127 | c = input[2]; 128 | v = DECODE_TABLE[c]; 129 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 130 | val = (val << 6) | v; 131 | 132 | output.push_back(static_cast((val >> 10) & 0xFF)); 133 | output.push_back(static_cast((val >> 2) & 0xFF)); 134 | return Error::Ok; 135 | } 136 | 137 | inline Error decode_2_padding( 138 | const std::string_view& input, 139 | std::string& output) { 140 | TK_CHECK_OR_RETURN_ERROR( 141 | input.size() == 2, 142 | Base64DecodeFailure, 143 | "input length must be 2, got %zu", 144 | input.size()); 145 | 146 | uint32_t val = 0; 147 | 148 | uint8_t c = input[0]; 149 | auto v = DECODE_TABLE[c]; 150 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 151 | val = v; 152 | 153 | c = input[1]; 154 | v = DECODE_TABLE[c]; 155 | TK_CHECK_OK_OR_RETURN_ERROR(validate(v)); 156 | val = (val << 6) | v; 157 | 158 | output.push_back(static_cast((val >> 4) & 0xFF)); 159 | return Error::Ok; 160 | } 161 | 162 | } // namespace detail 163 | 164 | inline tokenizers::Result decode(const std::string_view& input) { 165 | TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input"); 166 | 167 | // Faster than `input.size() % 4`. 168 | TK_CHECK_OR_RETURN_ERROR( 169 | (input.size() & 3) == 0 && input.size() >= 4, 170 | Base64DecodeFailure, 171 | "input length must be larger than 4 and is multiple of 4, got %zu", 172 | input.size()); 173 | 174 | std::string output; 175 | output.reserve(input.size() / 4 * 3); 176 | auto idx = 0U; 177 | for (; idx < input.size() - 4; idx += 4) { 178 | TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output)); 179 | } 180 | 181 | // Last 4 bytes. Might contain paddings. 182 | if (input[idx + 3] == '=') { 183 | if (input[idx + 2] == '=') { 184 | // Tow paddings. 185 | TK_CHECK_OK_OR_RETURN_ERROR( 186 | detail::decode_2_padding(input.substr(idx, 2), output)); 187 | } else { 188 | // One padding. 189 | TK_CHECK_OK_OR_RETURN_ERROR( 190 | detail::decode_1_padding(input.substr(idx, 3), output)); 191 | } 192 | } else { 193 | // No padding. 194 | TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output)); 195 | } 196 | 197 | return output; 198 | } 199 | } // namespace base64 200 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/bpe_tokenizer_base.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | // Base class for all BPE tokenizer implementations 11 | #pragma once 12 | 13 | // Standard 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | // Local 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include "re2/re2.h" 29 | 30 | namespace tokenizers { 31 | namespace detail { 32 | 33 | using TokenMap = StringIntegerMap<>; 34 | 35 | template 36 | static Result buildTokenMap( 37 | std::vector> container) { 38 | static_assert( 39 | std::is_same_v || 40 | std::is_same_v, 41 | "TToken must be std::string or std::string_view"); 42 | static_assert( 43 | std::is_integral_v && std::is_unsigned_v, 44 | "TRank must be an unsigned integer"); 45 | 46 | std::sort( 47 | container.begin(), container.end(), [](const auto& a, const auto& b) { 48 | return a.first < b.first; 49 | }); 50 | 51 | auto duplicate_begin = std::unique( 52 | container.begin(), container.end(), [](const auto& a, const auto& b) { 53 | return a.first == b.first; 54 | }); 55 | 56 | TK_CHECK_OR_RETURN_ERROR( 57 | duplicate_begin == container.end(), 58 | ParseFailure, 59 | "duplicate token: %s rank: %llu", 60 | duplicate_begin->first.c_str(), 61 | static_cast(duplicate_begin->second)); 62 | 63 | std::sort( 64 | container.begin(), container.end(), [](const auto& a, const auto& b) { 65 | return a.second < b.second; 66 | }); 67 | 68 | duplicate_begin = std::unique( 69 | container.begin(), container.end(), [](const auto& a, const auto& b) { 70 | return a.second == b.second; 71 | }); 72 | 73 | TK_CHECK_OR_RETURN_ERROR( 74 | duplicate_begin == container.end(), 75 | ParseFailure, 76 | "duplicate rank: %llu" 77 | " token: %s", 78 | static_cast(duplicate_begin->second), 79 | duplicate_begin->first.c_str()); 80 | 81 | return TokenMap(container); 82 | }; 83 | 84 | template 85 | static Result buildTokenMap( 86 | const TContainer& container, 87 | TTokenAccessor token_accessor, 88 | TRankAccessor rank_accessor) { 89 | using TokenType = std::invoke_result_t; 90 | using RankType = std::invoke_result_t; 91 | 92 | static_assert( 93 | std::is_same_v || 94 | std::is_same_v, 95 | "TokenType must be std::string or std::string_view"); 96 | static_assert( 97 | std::is_integral_v && std::is_unsigned_v, 98 | "RankType must be an unsigned integer"); 99 | 100 | std::vector> pairs; 101 | pairs.reserve(container.size()); 102 | for (const auto& value : container) { 103 | pairs.emplace_back(token_accessor(value), rank_accessor(value)); 104 | } 105 | 106 | return buildTokenMap(std::move(pairs)); 107 | } 108 | 109 | inline Result> build_special_token_regex( 110 | const TokenMap& special_token_map) { 111 | std::string special_pattern; 112 | const std::size_t count = special_token_map.size(); 113 | 114 | for (std::size_t i = 0; i < count; ++i) { 115 | const auto& [token, _] = special_token_map.getElement(i); 116 | if (!special_pattern.empty()) { 117 | special_pattern += "|"; 118 | } 119 | special_pattern += re2::RE2::QuoteMeta(std::string(token)); 120 | } 121 | 122 | if (special_pattern.empty()) { 123 | return static_cast>(nullptr); 124 | } 125 | return create_regex(special_pattern); 126 | } 127 | 128 | class BPETokenizerBase : public Tokenizer { 129 | public: 130 | Result> 131 | encode(const std::string& input, int8_t bos, int8_t eos) const override; 132 | 133 | Result decode(uint64_t prev_token, uint64_t token) 134 | const override; 135 | 136 | protected: 137 | explicit BPETokenizerBase() {} 138 | virtual ~BPETokenizerBase() override {} 139 | 
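  // Helpers for special-token handling during encoding: locate the next
  // allowed special token in the input (if any) and split off the ordinary
  // text preceding it, so that plain text can be byte-pair encoded separately.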
140 | std::pair, std::string> 141 | split_with_allowed_special_token_( 142 | const std::string& input, 143 | const TokenMap& allowed_special) const; 144 | 145 | std::pair, std::string> 146 | split_with_allowed_special_token_( 147 | const std::string& input, 148 | size_t offset, 149 | const TokenMap& allowed_special) const; 150 | 151 | Result, uint64_t>> encode_with_special_token_( 152 | const std::string& text, 153 | const TokenMap& allowed_special) const; 154 | 155 | Result> byte_pair_encode_( 156 | const std::string& piece, 157 | const TokenMap& encoder) const; 158 | 159 | // Protected members that can be overloaded by other BPE tokenizers 160 | std::unique_ptr special_token_regex_; 161 | std::optional token_map_; 162 | std::optional special_token_map_; 163 | 164 | private: 165 | virtual Error _encode( 166 | const std::string& input, 167 | std::vector& ret, 168 | uint64_t& last_piece_token_len) const = 0; 169 | 170 | virtual void _decode(const std::string& input, std::string& ret) const = 0; 171 | }; 172 | 173 | } // namespace detail 174 | } // namespace tokenizers 175 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/error.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /** 10 | * @file 11 | * Tokenizers Error declarations. 12 | */ 13 | 14 | #pragma once 15 | 16 | #include 17 | #include 18 | 19 | namespace tokenizers { 20 | 21 | // Alias error code integral type to minimal platform width (32-bits for now). 22 | typedef uint32_t error_code_t; 23 | 24 | /** 25 | * ExecuTorch Error type. 26 | */ 27 | enum class Error : error_code_t { 28 | /* 29 | * System errors. 30 | */ 31 | 32 | /// Status indicating a successful operation. 33 | Ok = 0x00, 34 | 35 | /// An internal error occurred. 36 | Internal = 0x01, 37 | 38 | /// Tokenizer uninitialized. 39 | Uninitialized = 0x02, 40 | 41 | /// Token out of range. 42 | OutOfRange = 0x03, 43 | 44 | /// Tokenizer artifact load failure. 45 | LoadFailure = 0x04, 46 | 47 | /// Encode failure. 48 | EncodeFailure = 0x05, 49 | 50 | /// Base64 decode failure. 51 | Base64DecodeFailure = 0x06, 52 | 53 | /// Failed to parse tokenizer artifact. 54 | ParseFailure = 0x07, 55 | 56 | /// Decode failure. 57 | DecodeFailure = 0x08, 58 | 59 | /// No suitable regex implementation found. 60 | RegexFailure = 0x09, 61 | }; 62 | 63 | } // namespace tokenizers 64 | 65 | /** 66 | * If cond__ is false, return the specified Error 67 | * from the current function, which must be of return type 68 | * tokenizers::Error. 69 | * TODO: Add logging support 70 | * @param[in] cond__ The condition to be checked, asserted as true. 71 | * @param[in] error__ Error enum value to return without the `Error::` prefix, 72 | * like `Base64DecodeFailure`. 73 | * @param[in] message__ Format string for the log error message. 74 | * @param[in] ... Optional additional arguments for the format string. 75 | */ 76 | #define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) 
\ 77 | { \ 78 | if (!(cond__)) { \ 79 | TK_LOG(Error, message__, ##__VA_ARGS__); \ 80 | return ::tokenizers::Error::error__; \ 81 | } \ 82 | } 83 | 84 | /** 85 | * If error__ is not Error::Ok, return the specified Error 86 | * TODO: Add logging support 87 | * @param[in] error__ Error enum value to return without the `Error::` prefix, 88 | * like `Base64DecodeFailure`. 89 | * @param[in] ... Optional format string for the log error message and its 90 | * arguments. 91 | */ 92 | #define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \ 93 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__) 94 | 95 | // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 96 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ 97 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ 98 | __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ 99 | (__VA_ARGS__) 100 | 101 | /** 102 | * Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead. 103 | * This macro selects the correct version of 104 | * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed. 105 | * It uses a trick with the preprocessor to count the number of arguments and 106 | * then selects the appropriate macro. 107 | * 108 | * The macro expansion uses __VA_ARGS__ to accept any number of arguments and 109 | * then appends them to TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the 110 | * count of arguments. The count is determined by the macro 111 | * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and 112 | * passes them along with a sequence of numbers (2, 1). The preprocessor then 113 | * matches this sequence to the correct number of arguments provided. 114 | * 115 | * If two arguments are passed, TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is 116 | * selected, suitable for cases where an error code and a custom message are 117 | * provided. If only one argument is passed, 118 | * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is selected, which is used for cases 119 | * with just an error code. 120 | * 121 | * Usage: 122 | * TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1 123 | * TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2 124 | */ 125 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ 126 | _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \ 127 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N 128 | 129 | // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 130 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \ 131 | do { \ 132 | const auto et_error__ = (error__); \ 133 | if (et_error__ != ::tokenizers::Error::Ok) { \ 134 | return et_error__; \ 135 | } \ 136 | } while (0) 137 | 138 | // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 139 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \ 140 | do { \ 141 | const auto et_error__ = (error__); \ 142 | if (et_error__ != ::tokenizers::Error::Ok) { \ 143 | TK_LOG(Error, message__, ##__VA_ARGS__); \ 144 | return et_error__; \ 145 | } \ 146 | } while (0) 147 | 148 | // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 
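// Argument counts 3 through 10 (an error code plus a format string and up to
// eight format arguments) all dispatch to the two-argument implementation.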
149 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \ 150 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 151 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \ 152 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 153 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \ 154 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 155 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \ 156 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 157 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \ 158 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 159 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \ 160 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 161 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \ 162 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 163 | #define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \ 164 | TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 165 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/hf_tokenizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | // Used by many Huggingface models. Adapted from a combination of the original 10 | // rust implementation (https://github.com/huggingface/tokenizers/tree/main) 11 | // and the corresponding support in llama.cpp 12 | // (https://github.com/ggerganov/llama.cpp) 13 | #pragma once 14 | 15 | // Standard 16 | #include 17 | 18 | // Local 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | namespace tokenizers { 26 | class HFTokenizer : public detail::BPETokenizerBase { 27 | public: 28 | /*-- Public Interface --*/ 29 | 30 | /** 31 | * Default initialize with no loaded data 32 | */ 33 | explicit HFTokenizer() {} 34 | ~HFTokenizer() {} 35 | 36 | /** 37 | * Load the model data into the 38 | */ 39 | Error load(const std::string& tokenizer_path) override; 40 | 41 | private: 42 | Error _encode( 43 | const std::string& input, 44 | std::vector& ret, 45 | uint64_t& last_piece_token_len) const override; 46 | 47 | void _decode(const std::string& input, std::string& ret) const override; 48 | 49 | PreTokenizer::Ptr _pretokenizer; 50 | TokenDecoder::Ptr _decoder; 51 | }; 52 | 53 | } // namespace tokenizers 54 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/llama2c_tokenizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude 9 | #pragma once 10 | #include 11 | #include 12 | 13 | namespace tokenizers { 14 | 15 | struct TokenIndex { 16 | const char* str; 17 | int32_t id; 18 | }; 19 | 20 | // A simple Byte Pair Encoding (BPE) Tokenizer. Note that the current C++ code 21 | // won't work with this class, it needs to go through tokenizer.py first. 
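// The Python-side conversion utilities for producing that serialized artifact
// live under pytorch_tokenizers/tools/llama2c/ in this repository.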
22 | class Llama2cTokenizer : public Tokenizer { 23 | public: 24 | explicit Llama2cTokenizer(); 25 | ~Llama2cTokenizer() override; 26 | 27 | Error load(const std::string& tokenizer_path) override; 28 | 29 | Result> 30 | encode(const std::string& input, int8_t bos, int8_t eos) const override; 31 | 32 | Result decode(uint64_t prev_token, uint64_t token) 33 | const override; 34 | 35 | private: 36 | std::unique_ptr vocab_ = nullptr; 37 | std::unique_ptr vocab_scores_ = nullptr; 38 | std::unique_ptr sorted_vocab_ = nullptr; 39 | unsigned int max_token_length_ = 0; 40 | unsigned char byte_pieces_[512]; // stores all single-byte strings 41 | }; 42 | 43 | } // namespace tokenizers 44 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/pcre2_regex.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | // Define PCRE2 code unit width before including pcre2.h 15 | #define PCRE2_CODE_UNIT_WIDTH 8 16 | #include 17 | 18 | #include 19 | 20 | namespace tokenizers { 21 | 22 | /** 23 | * @brief PCRE2-based implementation of IRegex. 24 | */ 25 | class Pcre2Regex : public IRegex { 26 | public: 27 | /** 28 | * @brief Construct a PCRE2 regex. 29 | */ 30 | explicit Pcre2Regex(){}; 31 | 32 | /** 33 | * @brief Compile the given regex pattern. 34 | * @param pattern The regex pattern to compile. 35 | * @return An Error object indicating success or failure of the compilation. 36 | */ 37 | virtual Error compile(const std::string& pattern) override; 38 | 39 | /** 40 | * @brief Destructor to clean up PCRE2 resources. 41 | */ 42 | ~Pcre2Regex(); 43 | 44 | /** 45 | * @brief Return all non-overlapping matches found in the input string. 46 | */ 47 | virtual std::vector find_all(const std::string& text) const override; 48 | 49 | private: 50 | pcre2_code* regex_; 51 | pcre2_match_data* match_data_; 52 | }; 53 | 54 | } // namespace tokenizers 55 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/pre_tokenizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #pragma once 11 | 12 | // Standard 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | // Third Party 19 | #include 20 | #include 21 | 22 | // Local 23 | #include 24 | 25 | namespace tokenizers { 26 | 27 | // -- Base --------------------------------------------------------------------- 28 | 29 | /** 30 | * Base class for all pre-tokenizers with a single virtual method to split the 31 | * input string piece 32 | */ 33 | class PreTokenizer { 34 | public: 35 | /** Shared pointer type */ 36 | typedef std::shared_ptr Ptr; 37 | 38 | /** Split the input string piece into sub-pieces 39 | * 40 | * This pre-tokenization may result in sub-pieces that are not contained 41 | * within the original input, therefore the resulting pieces will be owned by 42 | * the caller. 
43 | * 44 | * NOTE: Pass by value per best practice 45 | * https://abseil.io/docs/cpp/guides/strings#string_view 46 | */ 47 | virtual std::vector pre_tokenize( 48 | const std::string& input) const = 0; 49 | 50 | virtual ~PreTokenizer() = default; 51 | }; // end class PreTokenizer 52 | 53 | // -- Factory ------------------------------------------------------------------ 54 | 55 | // Helper macro to standardize addition of config member fields 56 | #define CONFIG_MEMBER(type, name) \ 57 | std::optional name; \ 58 | PreTokenizerConfig& set_##name(type arg) { \ 59 | this->name = std::move(arg); \ 60 | return *this; \ 61 | } 62 | 63 | /** 64 | * Factory and config class for creating a new PreTokenizer 65 | * 66 | * This class is the central method for instantiating a PreTokenizer instance. 67 | * It contains the common construction logic and config parameter names for all 68 | * pre tokenizer constructor args. 69 | * 70 | * NOTE: When adding a new pre tokenizer, you must ensure its arguments are 71 | * added to this class and it's constructor is added in the implementation! 72 | * 73 | * Usage Example: 74 | * 75 | * const auto pre_tokenizer = PreTokenizerConfig("Sequence").set_pretokenizers( 76 | * {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")} 77 | * ); 78 | * const auto pre_tokenized = pre_tokenizer->pre_tokenize("Hello World!"); 79 | */ 80 | class PreTokenizerConfig { 81 | public: 82 | /*------------------------*/ 83 | /* Public mutable members */ 84 | /*------------------------*/ 85 | 86 | /** 87 | * The Type name string matching from tokenizers 88 | * https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/mod.rs#L73 89 | */ 90 | std::string type; 91 | 92 | /** 93 | * Used by: RegexPreTokenizer, ByteLevelPreTokenizer 94 | */ 95 | CONFIG_MEMBER(std::string, pattern) 96 | 97 | /** 98 | * Used by: DigitsPreTokenizer 99 | */ 100 | CONFIG_MEMBER(bool, individual_digits) 101 | 102 | /** 103 | * Used by: ByteLevelPreTokenizer 104 | */ 105 | CONFIG_MEMBER(bool, add_prefix_space) 106 | 107 | /** 108 | * Used by: SequencePreTokenizer 109 | */ 110 | CONFIG_MEMBER(std::vector, pretokenizers) 111 | 112 | /*----------------*/ 113 | /* Public methods */ 114 | /*----------------*/ 115 | 116 | /** 117 | * Construct with the type 118 | */ 119 | explicit PreTokenizerConfig(std::string type = ""); 120 | 121 | /** 122 | * Construct the pre tokenizer instance from the member data 123 | */ 124 | PreTokenizer::Ptr create() const; 125 | 126 | /** 127 | * Populate from a json config file 128 | */ 129 | PreTokenizerConfig& parse_json(const nlohmann::json& json_config); 130 | 131 | }; // end class PreTokenizerConfig 132 | 133 | // -- Regex -------------------------------------------------------------------- 134 | // Used for general-purpose single-regex pre tokenization 135 | // CITE: 136 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/split.rs 137 | // 138 | // TODO: Support for "behavior" and "invert" options 139 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/normalizer.rs#L82 140 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/pattern.rs#L128 141 | 142 | class RegexPreTokenizer : public PreTokenizer { 143 | public: 144 | explicit RegexPreTokenizer(const std::string& pattern) 145 | : regex_(RegexPreTokenizer::create_regex_(pattern)) {} 146 | 147 | /** Pre-tokenize with the stored regex */ 148 | std::vector pre_tokenize(const std::string& input) const; 149 | 150 | 
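The Sequence example above can also be assembled from a JSON config through parse_json; a sketch, assuming parse_json reads the "type" field and the HuggingFace tokenizer.json field name "pretokenizers":

// Sketch: building Sequence(Digits, ByteLevel) from a JSON config.
// The field names mirror HF tokenizer.json and are assumptions here.
nlohmann::json cfg = nlohmann::json::parse(R"({
  "type": "Sequence",
  "pretokenizers": [{"type": "Digits"}, {"type": "ByteLevel"}]
})");
const auto pre_tokenizer =
    tokenizers::PreTokenizerConfig().parse_json(cfg).create();
const auto pieces = pre_tokenizer->pre_tokenize("Hello 2025!");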
protected: 151 | static std::unique_ptr create_regex_(const std::string& pattern); 152 | 153 | std::unique_ptr regex_; 154 | 155 | }; // end class RegexPreTokenizer 156 | 157 | // -- Digits ------------------------------------------------------------------- 158 | // Used by tokenizers 159 | // CITE: 160 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/digits.rs 161 | 162 | class DigitsPreTokenizer : public RegexPreTokenizer { 163 | public: 164 | explicit DigitsPreTokenizer(bool individual_digits = false) 165 | : RegexPreTokenizer( 166 | individual_digits ? R"([^\p{N}]+|\p{N})" 167 | : R"([^\p{N}]+|[\p{N}]+)") {} 168 | }; // end class DigitsPreTokenizer 169 | 170 | // -- ByteLevel ---------------------------------------------------------------- 171 | // Used by tokenizers 172 | // CITE: 173 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs 174 | 175 | class ByteLevelPreTokenizer : public PreTokenizer { 176 | public: 177 | /** 178 | * @param add_prefix_space: Whether to add a leading space to the first word 179 | * @param pattern: A user-supplied regex to use for token splitting. If not 180 | * provided, it use the standard GPT2 pattern. 181 | */ 182 | ByteLevelPreTokenizer( 183 | bool add_prefix_space = true, 184 | const std::string& pattern = ""); 185 | explicit ByteLevelPreTokenizer(const std::string& pattern) 186 | : ByteLevelPreTokenizer(true, pattern) {} 187 | 188 | /** Perform pre-tokenization */ 189 | std::vector pre_tokenize( 190 | const std::string& input) const override; 191 | 192 | private: 193 | const std::string pattern_; 194 | const bool add_prefix_space_; 195 | 196 | }; // end class ByteLevelPreTokenizer 197 | 198 | // -- Sequence ----------------------------------------------------------------- 199 | // Used by tokenizers 200 | // CITE: 201 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/sequence.rs 202 | 203 | class SequencePreTokenizer : public PreTokenizer { 204 | public: 205 | /** 206 | * @param pre_tokenizers: The sequence of owned pre-tokenizer objects to use 207 | */ 208 | explicit SequencePreTokenizer(std::vector pre_tokenizers); 209 | 210 | /** Perform pre-tokenization */ 211 | std::vector pre_tokenize( 212 | const std::string& input) const override; 213 | 214 | private: 215 | const std::vector pre_tokenizers_; 216 | 217 | }; // end class ByteLevelPreTokenizer 218 | 219 | } // namespace tokenizers 220 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/re2_regex.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include 17 | 18 | namespace tokenizers { 19 | 20 | /** 21 | * @brief RE2-based implementation of IRegex. 22 | */ 23 | class Re2Regex : public IRegex { 24 | public: 25 | /** 26 | * @brief Construct a RE2 regex. 27 | */ 28 | explicit Re2Regex() {} 29 | 30 | /** 31 | * @brief compile the given regex pattern. 32 | * @param pattern The regex pattern to compile. 33 | * @return An Error object indicating success or failure of the compilation. 
34 | */ 35 | virtual Error compile(const std::string& pattern) override; 36 | 37 | /** 38 | * @brief Return all non-overlapping matches found in the input string. 39 | */ 40 | virtual std::vector find_all(const std::string& text) const override; 41 | 42 | private: 43 | std::unique_ptr regex_; 44 | }; 45 | 46 | } // namespace tokenizers 47 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/regex.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | namespace tokenizers { 18 | 19 | struct Match { 20 | size_t start; // starting index of the match 21 | size_t end; // ending index of the match (exclusive) 22 | }; 23 | 24 | /** 25 | * @brief Abstract interface for regex wrappers. 26 | */ 27 | class IRegex { 28 | public: 29 | virtual ~IRegex() = default; 30 | 31 | /** 32 | * @brief Compile the given regex pattern. 33 | * @param pattern The regex pattern to compile. 34 | * @return An Error object indicating success or failure of the compilation. 35 | */ 36 | virtual Error compile(const std::string& pattern) = 0; 37 | 38 | /** 39 | * @brief Find all non-overlapping matches in the input string. 40 | * 41 | * @param text The input string to search. 42 | * @return A vector of strings containing all matched substrings. 43 | */ 44 | virtual std::vector find_all(const std::string& text) const = 0; 45 | }; 46 | 47 | // Function pointer type for create_fallback_regex implementations 48 | using FallbackRegexFn = Result> (*)(const std::string&); 49 | 50 | /** 51 | * @brief Creates a regex instance. If no strong symbol defined, only 52 | * uses RE2. This is a weak symbol to allow other regex libraries to be 53 | * used. 54 | * 55 | * @param pattern The regex pattern to compile. 56 | * @return A unique pointer to an IRegex-compatible object. 57 | */ 58 | Result> create_regex(const std::string& pattern); 59 | 60 | bool register_override_fallback_regex(FallbackRegexFn fn); 61 | 62 | FallbackRegexFn get_fallback_regex(); 63 | 64 | } // namespace tokenizers 65 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/result.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /** 10 | * @file 11 | * Result type to be used in conjunction with Tokenizers Error type. 12 | */ 13 | 14 | #pragma once 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | namespace tokenizers { 22 | 23 | /** 24 | * Result type wrapping either a value of type T or an error. 25 | * 26 | * Example use case: 27 | * @code 28 | * Result decode(uint64_t token); 29 | * 30 | * void generate() 31 | * @endcode 32 | */ 33 | template 34 | class Result final { 35 | public: 36 | /// `value_type` member for generic programming. 37 | typedef T value_type; 38 | 39 | /** 40 | * Creates a Result object from an Error. 
41 | * 42 | * To preserve the invariant that `(result.error() == Error::Ok) == 43 | * result.ok()`, an `error` parameter value of `Error:Ok` will be converted to 44 | * a non-Ok value. 45 | */ 46 | /* implicit */ Result(Error error) 47 | : error_(error == Error::Ok ? Error::Internal : error), 48 | hasValue_(false) {} 49 | 50 | /// Value copy constructor. 51 | /* implicit */ Result(const T& val) : value_(val), hasValue_(true) {} 52 | 53 | /// Value move constructor. 54 | /* implicit */ Result(T&& val) : value_(std::move(val)), hasValue_(true) {} 55 | 56 | /// Result move constructor. 57 | /* implicit */ Result(Result&& rhs) noexcept : hasValue_(rhs.hasValue_) { 58 | if (hasValue_) { 59 | // Use the value type's move constructor. 60 | new (&value_) T(std::move(rhs.value_)); 61 | } else { 62 | error_ = rhs.error_; 63 | } 64 | } 65 | 66 | ~Result() { 67 | if (hasValue_) { 68 | // Manual value destruction. 69 | // Result "owns" the memory, so `delete` would segfault. 70 | value_.~T(); 71 | } 72 | } 73 | 74 | /** 75 | * Returns true if this Result has a value. 76 | * 77 | * If true, it is guaranteed that `error()` will return `Error::Ok`. 78 | * If false, it is guaranteed that `error()` will not return `Error::Ok`. 79 | */ 80 | bool ok() const { 81 | return hasValue_; 82 | } 83 | 84 | /** 85 | * Returns the error code of this Result. 86 | * 87 | * If this returns `Error::Ok`, it is guaranteed that `ok()` will return true. 88 | * If this does not return `Error:Ok`, it is guaranteed that `ok()` will 89 | * return false. 90 | */ 91 | Error error() const { 92 | if (hasValue_) { 93 | return Error::Ok; 94 | } else { 95 | return error_; 96 | } 97 | } 98 | 99 | /** 100 | * Returns a reference to the Result's value; longhand for operator*(). 101 | * 102 | * Only legal to call if `ok()` returns true. 103 | */ 104 | T& get() { 105 | CheckOk(); 106 | return value_; 107 | } 108 | 109 | /** 110 | * Returns a reference to the Result's value; longhand for operator*(). 111 | * 112 | * Only legal to call if `ok()` returns true. 113 | */ 114 | const T& get() const { 115 | CheckOk(); 116 | return value_; 117 | } 118 | 119 | /* 120 | * Returns a reference to the Result's value; shorthand for get(). 121 | * 122 | * Only legal to call if `ok()` returns true. 123 | */ 124 | const T& operator*() const&; 125 | T& operator*() &; 126 | 127 | /* 128 | * Returns a pointer to the Result's value. 129 | * 130 | * Only legal to call if `ok()` returns true. 131 | */ 132 | const T* operator->() const; 133 | T* operator->(); 134 | 135 | private: 136 | /** 137 | * Delete default constructor since all Results should contain a value or 138 | * error. 139 | */ 140 | Result() = delete; 141 | /// Delete copy constructor since T may not be copyable. 142 | Result(const Result&) = delete; 143 | /// Delete copy assignment since T may not be copyable. 144 | Result& operator=(const Result&) = delete; 145 | /// Delete move assignment since it's not a supported pattern to reuse Result. 146 | Result& operator=(Result&& rhs) = delete; 147 | 148 | // Panics if ok() would return false; 149 | void CheckOk() const { 150 | assert(hasValue_ && "Result must be ok to access value."); 151 | } 152 | 153 | union { 154 | T value_; // Used if hasValue_ is true. 155 | Error error_; // Used if hasValue_ is false. 156 | }; 157 | 158 | /// True if the Result contains a value. 
159 | const bool hasValue_; 160 | }; 161 | 162 | template 163 | const T& Result::operator*() const& { 164 | CheckOk(); 165 | return value_; 166 | } 167 | 168 | template 169 | T& Result::operator*() & { 170 | CheckOk(); 171 | return value_; 172 | } 173 | 174 | template 175 | const T* Result::operator->() const { 176 | CheckOk(); 177 | return &value_; 178 | } 179 | 180 | template 181 | T* Result::operator->() { 182 | CheckOk(); 183 | return &value_; 184 | } 185 | 186 | } // namespace tokenizers 187 | 188 | /** 189 | * Unwraps a Result value, throwing a runtime_error if the result contains an 190 | * error. 191 | * 192 | * @param[in] result__ The Result to unwrap 193 | */ 194 | #define TK_UNWRAP_THROW(result__) \ 195 | ({ \ 196 | auto unwrap_result__ = (result__); \ 197 | if (!unwrap_result__.ok()) { \ 198 | throw std::runtime_error( \ 199 | "Error: " + \ 200 | std::to_string(static_cast(unwrap_result__.error()))); \ 201 | } \ 202 | std::move(unwrap_result__.get()); \ 203 | }) 204 | 205 | /** 206 | * Unwrap a Result to obtain its value. If the Result contains an error, 207 | * propogate the error via trivial function return. 208 | * 209 | * Note: A function using TK_UNWRAP should itself return a Result or Error. 210 | * 211 | * @param[in] result__ Expression yielding the result to unwrap. 212 | * @param[in] ... Optional format string for the log error message and its 213 | * arguments. 214 | */ 215 | #define TK_UNWRAP(result__, ...) TK_INTERNAL_UNWRAP(result__, ##__VA_ARGS__) 216 | 217 | // Internal only: Use TK_UNWRAP() instead. 218 | #define TK_INTERNAL_UNWRAP(...) \ 219 | TK_INTERNAL_UNWRAP_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ 220 | (__VA_ARGS__) 221 | 222 | // Internal only: Use TK_UNWRAP() instead. 223 | #define TK_INTERNAL_UNWRAP_SELECT( \ 224 | _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \ 225 | TK_INTERNAL_UNWRAP_##N 226 | 227 | // Internal only: Use TK_UNWRAP() instead. 228 | #define TK_INTERNAL_UNWRAP_1(result__) \ 229 | ({ \ 230 | auto et_result__ = (result__); \ 231 | if (!et_result__.ok()) { \ 232 | return et_result__.error(); \ 233 | } \ 234 | std::move(*et_result__); \ 235 | }) 236 | 237 | // Internal only: Use TK_UNWRAP() instead. 238 | #define TK_INTERNAL_UNWRAP_2(result__, message__, ...) \ 239 | ({ \ 240 | auto et_result__ = (result__); \ 241 | if (!et_result__.ok()) { \ 242 | TK_LOG(Error, message__, ##__VA_ARGS__); \ 243 | return et_result__.error(); \ 244 | } \ 245 | std::move(*et_result__); \ 246 | }) 247 | 248 | // Internal only: Use TK_UNWRAP() instead. 249 | #define TK_INTERNAL_UNWRAP_3 TK_INTERNAL_UNWRAP_2 250 | #define TK_INTERNAL_UNWRAP_4 TK_INTERNAL_UNWRAP_2 251 | #define TK_INTERNAL_UNWRAP_5 TK_INTERNAL_UNWRAP_2 252 | #define TK_INTERNAL_UNWRAP_6 TK_INTERNAL_UNWRAP_2 253 | #define TK_INTERNAL_UNWRAP_7 TK_INTERNAL_UNWRAP_2 254 | #define TK_INTERNAL_UNWRAP_8 TK_INTERNAL_UNWRAP_2 255 | #define TK_INTERNAL_UNWRAP_9 TK_INTERNAL_UNWRAP_2 256 | #define TK_INTERNAL_UNWRAP_10 TK_INTERNAL_UNWRAP_2 257 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/sentencepiece.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | // A tokenizer that works with sentencepiece. Used by Llama2. 
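Putting the Result type and the TK_UNWRAP macro above together, a small error-propagation sketch, where decode_piece is a hypothetical caller:

// decode_piece must itself return a Result or Error for TK_UNWRAP to work.
tokenizers::Result<std::string> decode_piece(
    const tokenizers::Tokenizer& tok, uint64_t prev, uint64_t cur) {
  // On failure this returns tok.decode()'s error to the caller; on success it
  // unwraps the decoded string.
  std::string piece = TK_UNWRAP(tok.decode(prev, cur));
  return piece;  // implicit conversion to Result<std::string>
}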
11 | #pragma once 12 | 13 | #include 14 | #include 15 | #include 16 | #include "sentencepiece_processor.h" 17 | namespace tokenizers { 18 | 19 | struct TokenIndex { 20 | const char* str; 21 | int32_t id; 22 | }; 23 | 24 | class SPTokenizer : public Tokenizer { 25 | public: 26 | explicit SPTokenizer(); 27 | ~SPTokenizer() override; 28 | 29 | Error load(const std::string& tokenizer_path) override; 30 | 31 | Result> 32 | encode(const std::string& input, int8_t bos, int8_t eos) const override; 33 | 34 | Result decode(uint64_t prev_token, uint64_t token) 35 | const override; 36 | 37 | private: 38 | std::unique_ptr _processor; 39 | }; 40 | 41 | } // namespace tokenizers 42 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/std_regex.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include "regex.h" 15 | 16 | namespace tokenizers { 17 | 18 | /** 19 | * @brief std::regex-based implementation of IRegex. 20 | */ 21 | class StdRegex : public IRegex { 22 | public: 23 | /** 24 | * @brief Construct a std::regex wrapper. 25 | */ 26 | explicit StdRegex() {} 27 | 28 | /** 29 | * @brief Compile the given regex pattern. 30 | * @param pattern The regex pattern to compile. 31 | * @return An Error object indicating success or failure of the compilation. 32 | */ 33 | virtual Error compile(const std::string& pattern) override; 34 | 35 | /** 36 | * @brief Find all non-overlapping matches in the input string. 37 | */ 38 | virtual std::vector find_all(const std::string& text) const override; 39 | 40 | private: 41 | std::regex regex_; 42 | }; 43 | 44 | } // namespace tokenizers 45 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/tiktoken.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | // Tiktoken header 11 | // Used by OpenAI, adapted from https://github.com/sewenew/tokenizer 12 | #pragma once 13 | 14 | // Standard 15 | #include 16 | 17 | // Third Party 18 | #include 19 | 20 | // Local 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | namespace tokenizers { 27 | 28 | static constexpr int32_t kSpecialTokensSize = 256; 29 | static constexpr size_t kBOSTokenIndex = 0; 30 | static constexpr size_t kEOSTokenIndex = 1; 31 | 32 | class Tiktoken : public detail::BPETokenizerBase { 33 | public: 34 | explicit Tiktoken( 35 | std::string pattern, 36 | std::unique_ptr> special_tokens, 37 | size_t bos_token_index, 38 | size_t eos_token_index) 39 | : _pattern(std::move(pattern)), 40 | _special_tokens(std::move(special_tokens)), 41 | _bos_token_index(bos_token_index), 42 | _eos_token_index(eos_token_index) { 43 | if (_bos_token_index >= _special_tokens->size() || 44 | _eos_token_index >= _special_tokens->size()) { 45 | abort(); 46 | } 47 | } 48 | 49 | explicit Tiktoken( 50 | std::unique_ptr> special_tokens, 51 | size_t bos_token_index, 52 | size_t eos_token_index) 53 | : Tiktoken( 54 | _get_default_patern(), 55 | std::move(special_tokens), 56 | bos_token_index, 57 | eos_token_index) {} 58 | 59 | explicit Tiktoken() 60 | : _pattern(_get_default_patern()), 61 | _special_tokens(_get_default_special_tokens()), 62 | _bos_token_index(kBOSTokenIndex), 63 | _eos_token_index(kEOSTokenIndex){}; 64 | 65 | Error load(const std::string& tokenizer_path) override; 66 | 67 | private: 68 | static inline std::unique_ptr> 69 | _get_default_special_tokens() { 70 | auto special_tokens = 71 | std::make_unique>(std::vector{ 72 | "<|begin_of_text|>", 73 | "<|end_of_text|>", 74 | "<|reserved_special_token_0|>", 75 | "<|reserved_special_token_1|>", 76 | "<|finetune_right_pad_id|>", 77 | "<|step_id|>", 78 | "<|start_header_id|>", 79 | "<|end_header_id|>", 80 | "<|eom_id|>", 81 | "<|eot_id|>", 82 | "<|python_tag|>"}); 83 | // pad the rest of the special tokens with reserved tokens 84 | ssize_t reserved_special_token_num = 2; 85 | while (special_tokens->size() < kSpecialTokensSize) { 86 | special_tokens->emplace_back( 87 | "<|reserved_special_token_" + 88 | std::to_string(reserved_special_token_num++) + "|>"); 89 | } 90 | return special_tokens; 91 | } 92 | 93 | static inline std::string _get_default_patern() { 94 | // Removed negative lookahead \s+(?!\S) since it's not supported by RE2. 95 | return R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; 96 | } 97 | 98 | Error _encode( 99 | const std::string& input, 100 | std::vector& ret, 101 | uint64_t& last_piece_token_len) const override; 102 | 103 | void _decode(const std::string& input, std::string& ret) const override; 104 | 105 | detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const; 106 | 107 | std::string _pattern; 108 | std::unique_ptr> _special_tokens; 109 | size_t _bos_token_index; 110 | size_t _eos_token_index; 111 | 112 | std::unique_ptr _regex; 113 | }; 114 | 115 | } // namespace tokenizers 116 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/token_decoder.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #pragma once 11 | 12 | // Standard 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Third Party 20 | #include 21 | #include 22 | 23 | namespace tokenizers { 24 | 25 | // -- Base --------------------------------------------------------------------- 26 | 27 | /** 28 | * Base class for all token decoders 29 | */ 30 | class TokenDecoder { 31 | public: 32 | /* -- Types -- */ 33 | 34 | /** Shared pointer type */ 35 | typedef std::shared_ptr Ptr; 36 | 37 | /* -- Virtual Methods -- */ 38 | 39 | /** Decode a sequence of tokens into another sequence of tokens 40 | * 41 | * This is the primary virtual method that all decoders must implement. It may 42 | * change the size/layout of tokens between the input and output vectors. 43 | * 44 | * @param token: The pre-decoding token string 45 | * 46 | * @returns decoded: The decoded token string 47 | */ 48 | virtual std::string decode(const std::string& token) const = 0; 49 | 50 | // virtual destructor 51 | virtual ~TokenDecoder() = default; 52 | 53 | }; // end class TokenDecoder 54 | 55 | // -- Factory ------------------------------------------------------------------ 56 | 57 | /** 58 | * Factory and config class for creating a new TokenDecoder 59 | */ 60 | class TokenDecoderConfig { 61 | public: 62 | /** 63 | * The Type name string matching from decoders 64 | * https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/mod.rs#L55 65 | */ 66 | std::string type; 67 | 68 | /*----------------*/ 69 | /* Public methods */ 70 | /*----------------*/ 71 | 72 | /** 73 | * Construct with the type 74 | */ 75 | explicit TokenDecoderConfig(std::string type = ""); 76 | 77 | /** 78 | * Construct the pre tokenizer instance from the member data 79 | */ 80 | TokenDecoder::Ptr create() const; 81 | 82 | /** 83 | * Populate from a json config file 84 | */ 85 | TokenDecoderConfig& parse_json(const nlohmann::json& json_config); 86 | }; // end class TokenDecoderConfig 87 | 88 | // -- ByteLevel ---------------------------------------------------------------- 89 | // Used by tokenizers 90 | // CITE: 91 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs 92 | 93 | class ByteLevelTokenDecoder : public TokenDecoder { 94 | public: 95 | std::string decode(const std::string& token) const override; 96 | 97 | }; // end class ByteLevelTokenDecoder 98 | 99 | } // namespace tokenizers 100 | -------------------------------------------------------------------------------- /include/pytorch/tokenizers/tokenizer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | /** 10 | * @file 11 | * Tokenizer interface declaration. 
12 | */ 13 | 14 | #pragma once 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | namespace tokenizers { 22 | 23 | class Tokenizer { 24 | public: 25 | explicit Tokenizer() {} 26 | virtual ~Tokenizer() {} 27 | 28 | virtual Error load(const std::string& tokenizer_path) = 0; 29 | 30 | virtual Result> 31 | encode(const std::string& input, int8_t bos, int8_t eos) const = 0; 32 | 33 | Error decode_verify(uint64_t token) const { 34 | if (!initialized_) { 35 | return Error::Uninitialized; 36 | } 37 | if (token >= vocab_size_) { 38 | return Error::OutOfRange; 39 | } 40 | return Error::Ok; 41 | } 42 | 43 | virtual Result decode(uint64_t prev_token, uint64_t token) 44 | const = 0; 45 | 46 | // getters 47 | int32_t vocab_size() const { 48 | return vocab_size_; 49 | } 50 | 51 | uint64_t bos_tok() const { 52 | return bos_tok_; 53 | } 54 | 55 | uint64_t eos_tok() const { 56 | return eos_tok_; 57 | } 58 | 59 | virtual bool is_loaded() const { 60 | return initialized_; 61 | } 62 | 63 | protected: 64 | bool initialized_ = false; 65 | int32_t vocab_size_ = 0; 66 | uint64_t bos_tok_ = 0, eos_tok_ = 0; 67 | }; 68 | 69 | } // namespace tokenizers 70 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "cmake", # For building binary targets in the wheel. 4 | "pip>=23", # For building the pip package. 5 | "setuptools>=63", # For building the pip package contents. 6 | "wheel", # For building the pip package archive. 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | 10 | [project] 11 | name = "pytorch_tokenizers" 12 | dynamic = [ 13 | # setup.py will set the version. 14 | 'version', 15 | 'readme', 16 | ] 17 | description = "A package with common tokenizers in Python and C++" 18 | authors = [ 19 | {name="PyTorch Team", email="packages@pytorch.org"}, 20 | ] 21 | license = {file = "LICENSE"} 22 | keywords = ["pytorch", "machine learning", "llm"] 23 | # PyPI package information. 24 | classifiers = [ 25 | # How mature is this project? Common values are 26 | # 3 - Alpha 27 | # 4 - Beta 28 | # 5 - Production/Stable 29 | "Development Status :: 4 - Beta", 30 | "Intended Audience :: Developers", 31 | "Intended Audience :: Education", 32 | "Intended Audience :: Science/Research", 33 | "License :: OSI Approved :: BSD License", 34 | "Topic :: Scientific/Engineering", 35 | "Topic :: Scientific/Engineering :: Mathematics", 36 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 37 | "Topic :: Software Development", 38 | "Topic :: Software Development :: Libraries", 39 | "Topic :: Software Development :: Libraries :: Python Modules", 40 | "Programming Language :: C++", 41 | "Programming Language :: Python :: 3", 42 | # Update this as we support more versions of python. 43 | "Programming Language :: Python :: 3.10", 44 | "Programming Language :: Python :: 3.11", 45 | "Programming Language :: Python :: 3.12", 46 | ] 47 | 48 | # Python dependencies required for use. 49 | requires-python = ">=3.10" 50 | dependencies=[ 51 | "tiktoken", 52 | "tokenizers", 53 | "sentencepiece", 54 | ] 55 | 56 | [project.urls] 57 | # The keys are arbitrary but will be visible on PyPI. 
58 | Homepage = "https://pytorch.org/executorch/" 59 | Repository = "https://github.com/pytorch/executorch" 60 | Issues = "https://github.com/pytorch/executorch/issues" 61 | Changelog = "https://github.com/pytorch/executorch/releases" 62 | 63 | 64 | [tool.setuptools.exclude-package-data] 65 | "*" = ["*.pyc"] 66 | 67 | [tool.usort] 68 | # Do not try to put "first-party" imports in their own section. 69 | first_party_detection = false 70 | 71 | [tool.black] 72 | # Emit syntax compatible with older versions of python instead of only the range 73 | # specified by `requires-python`. TODO: Remove this once we support these older 74 | # versions of python and can expand the `requires-python` range. 75 | target-version = ["py38", "py39", "py310", "py311", "py312"] 76 | -------------------------------------------------------------------------------- /pytorch_tokenizers/TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain xplat-only targets. 3 | 4 | load("@fbcode_macros//build_defs:python_library.bzl", "python_library") 5 | load(":targets.bzl", "define_common_targets") 6 | 7 | oncall("executorch") 8 | 9 | define_common_targets() 10 | 11 | python_library( 12 | name = "hf_tokenizer", 13 | srcs = ["hf_tokenizer.py"], 14 | labels = ["autodeps2_generated"], 15 | deps = [ 16 | "fbsource//third-party/pypi/tokenizers:tokenizers", 17 | ], 18 | ) 19 | -------------------------------------------------------------------------------- /pytorch_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # @lint-ignore-every LICENSELINT 7 | 8 | 9 | from typing import Optional 10 | 11 | from .hf_tokenizer import HuggingFaceTokenizer 12 | from .llama2c import Llama2cTokenizer 13 | from .tiktoken import TiktokenTokenizer 14 | 15 | __all__ = ["TiktokenTokenizer", "Llama2cTokenizer", "HuggingFaceTokenizer"] 16 | 17 | 18 | def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None): 19 | if tokenizer_path.endswith(".json"): 20 | tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path) 21 | else: 22 | try: 23 | tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path)) 24 | except Exception: 25 | print("Using Tiktokenizer") 26 | tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path)) 27 | return tokenizer 28 | -------------------------------------------------------------------------------- /pytorch_tokenizers/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # @lint-ignore-every LICENSELINT 7 | 8 | CL100K_PAT_STR = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 9 | 10 | LLAMA_BASIC_SPECIAL_TOKENS = [ 11 | "<|begin_of_text|>", 12 | "<|end_of_text|>", 13 | "<|reserved_special_token_0|>", 14 | "<|reserved_special_token_1|>", 15 | "<|finetune_right_pad_id|>", 16 | "<|step_id|>", 17 | "<|start_header_id|>", 18 | "<|end_header_id|>", 19 | "<|eom_id|>", # end of message 20 | "<|eot_id|>", # end of turn 21 | "<|python_tag|>", 22 | "<|image|>", 23 | ] 24 | 25 | LLAMA_NUM_RESERVED_SPECIAL_TOKENS = 256 26 | LLAMA_RESERVED_SPECIAL_TOKENS = [ 27 | f"<|reserved_special_token_{2 + i}|>" 28 | for i in range(LLAMA_NUM_RESERVED_SPECIAL_TOKENS - len(LLAMA_BASIC_SPECIAL_TOKENS)) 29 | ] 30 | 31 | LLAMA_SPECIAL_TOKENS = LLAMA_BASIC_SPECIAL_TOKENS + LLAMA_RESERVED_SPECIAL_TOKENS 32 | -------------------------------------------------------------------------------- /pytorch_tokenizers/hf_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # @lint-ignore-every LICENSELINT 7 | 8 | import json 9 | import os 10 | from typing import List, Optional 11 | 12 | from tokenizers import Tokenizer 13 | 14 | 15 | class HuggingFaceTokenizer: 16 | """ 17 | Tokenizing and encoding/decoding text using the Hugging face tokenizer. 18 | """ 19 | 20 | def __init__(self, model_path: str, config_path: Optional[str] = None): 21 | """ 22 | Initializes the Tokenizer with a tokenizer.json from HuggingFace. 23 | 24 | Args: 25 | model_path (str): The path to the Tiktoken model file. 26 | """ 27 | assert os.path.isfile(model_path), model_path 28 | 29 | self.model = tokenizer = Tokenizer.from_file(model_path) 30 | 31 | self.n_words: int = tokenizer.get_vocab_size() 32 | if config_path: 33 | with open(config_path) as f: 34 | tokenizer_config = json.load(f) 35 | self.bos_id = ( 36 | self.model.token_to_id(tokenizer_config["bos_token"]) 37 | if tokenizer_config["bos_token"] 38 | else None 39 | ) 40 | self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"]) 41 | else: # Fallback guess. 42 | self.bos_id = self.model.token_to_id("<|begin_of_text|>") 43 | self.eos_id = self.model.token_to_id("<|endoftext|>") 44 | 45 | self.stop_tokens = [ 46 | self.eos_id, 47 | ] 48 | 49 | def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]: 50 | assert type(s) is str 51 | return self.model.encode(s).ids 52 | 53 | def decode(self, t: List[int]) -> str: 54 | return self.model.decode(t) 55 | 56 | def decode_token(self, t: int) -> str: 57 | return self.model.decode([t]) 58 | -------------------------------------------------------------------------------- /pytorch_tokenizers/llama2c.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # @lint-ignore-every LICENSELINT 7 | 8 | import logging 9 | import os 10 | import struct 11 | from typing import List 12 | 13 | from sentencepiece import SentencePieceProcessor as SentencePieceProcessor 14 | 15 | 16 | class Llama2cTokenizer: 17 | def __init__(self, model_path: str): 18 | assert os.path.isfile( 19 | model_path 20 | ), f"Need a valid tokenizer model path but got {model_path}" 21 | # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`. 22 | self.sp_model = SentencePieceProcessor(model_file=model_path) 23 | self.model_path = model_path 24 | 25 | # BOS / EOS token IDs 26 | self.n_words: int = self.sp_model.vocab_size() 27 | self.bos_id: int = self.sp_model.bos_id() 28 | self.eos_id: int = self.sp_model.eos_id() 29 | logging.info( 30 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 31 | ) 32 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`. 33 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 34 | 35 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 36 | assert type(s) is str 37 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. 38 | t = self.sp_model.encode(s) 39 | if bos: 40 | t = [self.bos_id] + t 41 | if eos: 42 | t = t + [self.eos_id] 43 | return t 44 | 45 | def decode(self, t: List[int]) -> str: 46 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. 47 | return self.sp_model.decode(t) 48 | 49 | def decode_token(self, t: int) -> str: 50 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`. 51 | return self.sp_model.decode(t) 52 | 53 | def export(self, output_path: str, *, prepend_padding: bool = False) -> None: 54 | """ 55 | Export tokenizer.model to another serialization format. Here we did some lightweight 56 | processing such as supporting prepend padding token, prepend max token length and 57 | replace '_' back to empty space. 58 | 59 | The binary format is: 60 | 1. vocab size: int32 61 | 2. bos token id: int32 62 | 3. eos token id: int32 63 | 4. max token length: int32 64 | 5. score: float32, len of bytes: int32, token bytes: [byte] for each token 65 | 66 | :param output_path: output path of the new binary. 67 | :param prepend_padding: a boolean to control if we want to prepend a padding token. 68 | 69 | :return: None 70 | """ 71 | 72 | # get all the tokens (postprocessed) and their scores as floats 73 | tokens, scores = [], [] 74 | 75 | if prepend_padding: 76 | # Here we use the default padding token and its score. 77 | tokens.append("".encode("utf-8")) 78 | scores.append(-1) 79 | 80 | for i in range(self.n_words): 81 | # decode the token and light postprocessing 82 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`. 83 | t = self.sp_model.id_to_piece(i) 84 | # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`. 
85 | s = self.sp_model.get_score(i) 86 | # sentencepiece use '' as BOS and '' for EOS 87 | if i == self.bos_id: 88 | t = "" 89 | elif i == self.eos_id: 90 | t = "" 91 | t = t.replace("▁", " ") # sentencepiece uses this character as whitespace 92 | b = t.encode("utf-8") # bytes of this token, utf-8 encoded 93 | 94 | tokens.append(b) 95 | scores.append(s) 96 | 97 | # record the max token length 98 | max_token_length = 0 if not tokens else max(len(t) for t in tokens) 99 | 100 | # write to a binary file 101 | with open(output_path, "wb") as f: 102 | # write the vocab size, bos/eos ids and max token length 103 | f.write( 104 | struct.pack( 105 | "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length 106 | ) 107 | ) 108 | for bytes, score in zip(tokens, scores): 109 | f.write(struct.pack("fI", score, len(bytes))) 110 | f.write(bytes) 111 | logging.info(f"Wrote tokenizer to {output_path}") 112 | -------------------------------------------------------------------------------- /pytorch_tokenizers/targets.bzl: -------------------------------------------------------------------------------- 1 | load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") 2 | 3 | def define_common_targets(): 4 | """Defines targets that should be shared between fbcode and xplat. 5 | 6 | The directory containing this targets.bzl file should also contain both 7 | TARGETS and BUCK files that call this function. 8 | """ 9 | runtime.python_library( 10 | name = "tokenizers", 11 | srcs = [ 12 | "__init__.py", 13 | "constants.py", 14 | "llama2c.py", 15 | "tiktoken.py", 16 | "hf_tokenizer.py", 17 | ], 18 | base_module = "pytorch_tokenizers", 19 | visibility = ["PUBLIC"], 20 | _is_external_target = True, 21 | external_deps = [ 22 | "sentencepiece-py", 23 | ], 24 | deps = [ 25 | "fbsource//third-party/pypi/blobfile:blobfile", 26 | "fbsource//third-party/pypi/tiktoken:tiktoken", 27 | "fbsource//third-party/pypi/tokenizers:tokenizers", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /pytorch_tokenizers/tiktoken.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # @lint-ignore-every LICENSELINT 7 | 8 | import os 9 | from logging import getLogger 10 | from pathlib import Path 11 | from typing import ( 12 | AbstractSet, 13 | cast, 14 | Collection, 15 | Dict, 16 | Iterator, 17 | List, 18 | Literal, 19 | Optional, 20 | Sequence, 21 | Union, 22 | ) 23 | 24 | import tiktoken 25 | 26 | from tiktoken.load import load_tiktoken_bpe 27 | 28 | from .constants import CL100K_PAT_STR, LLAMA_SPECIAL_TOKENS 29 | 30 | logger = getLogger(__name__) 31 | 32 | 33 | # The tiktoken tokenizer can handle <=400k chars without 34 | # pyo3_runtime.PanicException. 35 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 36 | 37 | # https://github.com/openai/tiktoken/issues/195 38 | # Here we iterate over subsequences and split if we exceed the limit 39 | # of max consecutive non-whitespace or whitespace characters. 40 | MAX_NO_WHITESPACES_CHARS = 25_000 41 | 42 | 43 | _INSTANCE = None 44 | 45 | 46 | class TiktokenTokenizer: 47 | """ 48 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 49 | WARNING: The regex and special tokens are hardcoded from Llama 3+. 
50 | """ 51 | 52 | @classmethod 53 | def get_instance(cls): 54 | global _INSTANCE 55 | 56 | if _INSTANCE is None: 57 | _INSTANCE = TiktokenTokenizer( 58 | os.path.join(os.path.dirname(__file__), "tokenizer.model") 59 | ) 60 | return _INSTANCE 61 | 62 | def __init__( 63 | self, 64 | model_path: str, 65 | pat_str: str = CL100K_PAT_STR, 66 | special_tokens: List[str] = LLAMA_SPECIAL_TOKENS, 67 | ): 68 | """ 69 | Initializes the Tokenizer with a Tiktoken model. 70 | 71 | Args: 72 | model_path (str): The path to the Tiktoken model file. 73 | """ 74 | assert os.path.isfile(model_path), model_path 75 | 76 | mergeable_ranks = load_tiktoken_bpe(model_path) 77 | num_base_tokens = len(mergeable_ranks) 78 | 79 | self.special_tokens = { 80 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 81 | } 82 | self.model = tiktoken.Encoding( 83 | name=Path(model_path).name, 84 | pat_str=pat_str, 85 | mergeable_ranks=mergeable_ranks, 86 | special_tokens=self.special_tokens, 87 | ) 88 | 89 | self.n_words: int = num_base_tokens + len(special_tokens) 90 | # BOS / EOS token IDs 91 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 92 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 93 | 94 | def encode( 95 | self, 96 | s: str, 97 | *, 98 | bos: bool, 99 | eos: bool, 100 | allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, 101 | disallowed_special: Union[Literal["all"], Collection[str]] = (), 102 | ) -> List[int]: 103 | """ 104 | Encodes a string into a list of token IDs. 105 | 106 | Args: 107 | s (str): The input string to be encoded. 108 | bos (bool): Whether to prepend the beginning-of-sequence token. 109 | eos (bool): Whether to append the end-of-sequence token. 110 | allowed_special ("all"|set[str]): allowed special tokens in string 111 | disallowed_special ("all"|set[str]): special tokens that raise an error when in string 112 | 113 | Returns: 114 | list[int]: A list of token IDs. 115 | 116 | By default, setting disallowed_special=() encodes a string by ignoring 117 | special tokens. Specifically: 118 | - Setting `disallowed_special` to () will cause all text corresponding 119 | to special tokens to be encoded as natural text (insteading of raising 120 | an error). 121 | - Setting `allowed_special` to "all" will treat all text corresponding 122 | to special tokens to be encoded as special tokens. 123 | """ 124 | if allowed_special is None: 125 | allowed_special = set() 126 | assert type(s) is str 127 | 128 | substrs = ( 129 | substr 130 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) 131 | for substr in self._split_whitespaces_or_nonwhitespaces( 132 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS 133 | ) 134 | ) 135 | t: List[int] = [] 136 | for substr in substrs: 137 | t.extend( 138 | self.model.encode( 139 | substr, 140 | allowed_special=allowed_special, 141 | disallowed_special=disallowed_special, 142 | ) 143 | ) 144 | if bos: 145 | t.insert(0, self.bos_id) 146 | if eos: 147 | t.append(self.eos_id) 148 | return t 149 | 150 | def decode(self, t: Sequence[int]) -> str: 151 | """ 152 | Decodes a list of token IDs into a string. 153 | 154 | Args: 155 | t (List[int]): The list of token IDs to be decoded. 156 | 157 | Returns: 158 | str: The decoded string. 159 | """ 160 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 161 | return self.model.decode(cast(List[int], t)) 162 | 163 | def decode_token(self, t: int) -> str: 164 | """ 165 | Decodes a single token ID into a string. 
166 | 167 | Args: 168 | t (int): The token ID to be decoded. 169 | 170 | Returns: 171 | str: The decoded string. 172 | """ 173 | return self.model.decode_single_token_bytes(t).decode("utf-8") 174 | 175 | @staticmethod 176 | def _split_whitespaces_or_nonwhitespaces( 177 | s: str, max_consecutive_slice_len: int 178 | ) -> Iterator[str]: 179 | """ 180 | Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` 181 | consecutive whitespaces or consecutive non-whitespaces. 182 | """ 183 | current_slice_len = 0 184 | current_slice_is_space = s[0].isspace() if len(s) > 0 else False 185 | slice_start = 0 186 | 187 | for i in range(len(s)): 188 | is_now_space = s[i].isspace() 189 | 190 | if current_slice_is_space ^ is_now_space: 191 | current_slice_len = 1 192 | current_slice_is_space = is_now_space 193 | else: 194 | current_slice_len += 1 195 | if current_slice_len > max_consecutive_slice_len: 196 | yield s[slice_start:i] 197 | slice_start = i 198 | current_slice_len = 1 199 | yield s[slice_start:] 200 | -------------------------------------------------------------------------------- /pytorch_tokenizers/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/tokenizers/fc32028858020c4fcafe37aaaeaf5d1b480336a2/pytorch_tokenizers/tools/__init__.py -------------------------------------------------------------------------------- /pytorch_tokenizers/tools/llama2c/TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain xplat-only targets. 3 | 4 | load(":targets.bzl", "define_common_targets") 5 | 6 | oncall("executorch") 7 | 8 | define_common_targets() 9 | -------------------------------------------------------------------------------- /pytorch_tokenizers/tools/llama2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/tokenizers/fc32028858020c4fcafe37aaaeaf5d1b480336a2/pytorch_tokenizers/tools/llama2c/__init__.py -------------------------------------------------------------------------------- /pytorch_tokenizers/tools/llama2c/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # @lint-ignore-every LICENSELINT 7 | 8 | 9 | # Script to rewrite tokenizer model given by sentencepiece to llama2.c format, with lightweight 10 | # postprocessing logic. The output can be consumed by llama2c_tokenizer.cpp. 
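The binary produced by Llama2cTokenizer.export() above (four 32-bit header fields, then a float score, a 32-bit length and the raw bytes for each token) is what llama2c_tokenizer.cpp consumes; the header can be read back with plain stdio, as in the illustrative sketch below, which is not the actual llama2c_tokenizer.cpp code and assumes the file was written with the machine's native byte order:

#include <cstdint>
#include <cstdio>

// Illustrative reader for the llama2.c-format header described above.
bool read_llama2c_header(
    const char* path,
    uint32_t& vocab_size,
    uint32_t& bos_id,
    uint32_t& eos_id,
    uint32_t& max_token_length) {
  std::FILE* f = std::fopen(path, "rb");
  if (f == nullptr) {
    return false;
  }
  uint32_t fields[4] = {0, 0, 0, 0};
  const bool ok = std::fread(fields, sizeof(uint32_t), 4, f) == 4;
  std::fclose(f);
  if (ok) {
    vocab_size = fields[0];
    bos_id = fields[1];
    eos_id = fields[2];
    max_token_length = fields[3];
  }
  return ok;
}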
11 | 12 | import argparse 13 | 14 | from pytorch_tokenizers.llama2c import Llama2cTokenizer 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "-t", 21 | "--tokenizer-model", 22 | type=str, 23 | default="tokenizer.model", 24 | help="path to tokenizer model, given by sentencepiece", 25 | ) 26 | parser.add_argument( 27 | "-o", 28 | "--output-path", 29 | type=str, 30 | default=None, 31 | help="output path of postprocessed tokenizer model", 32 | ) 33 | parser.add_argument( 34 | "-p", 35 | "--prepend-padding", 36 | action="store_true", 37 | help="whether to prepend a padding token to the beginning of the tokenizer", 38 | ) 39 | 40 | args = parser.parse_args() 41 | 42 | t = Llama2cTokenizer(args.tokenizer_model) 43 | 44 | output_path = ( 45 | args.output_path 46 | if args.output_path 47 | else args.tokenizer_model.replace(".model", ".bin") 48 | ) 49 | t.export(output_path, prepend_padding=args.prepend_padding) 50 | -------------------------------------------------------------------------------- /pytorch_tokenizers/tools/llama2c/targets.bzl: -------------------------------------------------------------------------------- 1 | load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") 2 | 3 | def define_common_targets(): 4 | """Defines targets that should be shared between fbcode and xplat. 5 | 6 | The directory containing this targets.bzl file should also contain both 7 | TARGETS and BUCK files that call this function. 8 | """ 9 | runtime.python_library( 10 | name = "convert_lib", 11 | srcs = [ 12 | "__init__.py", 13 | "convert.py", 14 | ], 15 | base_module = "pytorch_tokenizers.tools.llama2c", 16 | visibility = [ 17 | "//executorch/examples/...", 18 | "//executorch/extension/llm/export/...", 19 | "//bento/...", 20 | "//bento_kernels/...", 21 | "@EXECUTORCH_CLIENTS", 22 | ], 23 | _is_external_target = True, 24 | external_deps = [ 25 | "sentencepiece-py", 26 | ], 27 | deps = [ 28 | "//pytorch/tokenizers/pytorch_tokenizers:tokenizers", 29 | ], 30 | ) 31 | 32 | runtime.python_binary( 33 | name = "convert", 34 | main_module = "pytorch_tokenizers.tools.llama2c.convert", 35 | visibility = [ 36 | "//executorch/examples/...", 37 | "fbsource//xplat/executorch/examples/...", 38 | ], 39 | _is_external_target = True, 40 | deps = [ 41 | ":convert_lib", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /requirements-lintrunner.txt: -------------------------------------------------------------------------------- 1 | # Lintrunner itself 2 | lintrunner==0.11.0 3 | lintrunner-adapters==0.11.0 4 | 5 | # Flake 8 and its dependencies 6 | flake8==6.1.0 7 | flake8-breakpoint==1.1.0 8 | flake8-bugbear==23.9.16 9 | flake8-comprehensions==3.14.0 10 | flake8-pyi==23.5.0 11 | mccabe==0.7.0 12 | pycodestyle==2.11.1 13 | torchfix==0.5.0 14 | 15 | # UFMT 16 | black==24.4.2 17 | ufmt==2.6.0 18 | usort==1.0.5 19 | 20 | # Other linters 21 | clang-format==18.1.3 22 | cmakelint==1.4.1 23 | 24 | # MyPy 25 | mypy==1.14.1 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # @lint-ignore-every LICENSELINT 7 | # type: ignore[syntax] 8 | from setuptools import find_packages, setup 9 | 10 | with open("README.md", "r") as f: 11 | long_description = f.read() 12 | 13 | setup( 14 | version="0.1.0", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | packages=find_packages(), 18 | ) 19 | -------------------------------------------------------------------------------- /src/bpe_tokenizer_base.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #include 11 | 12 | // Standard 13 | #include 14 | #include 15 | 16 | namespace tokenizers { 17 | namespace detail { 18 | 19 | // ---- Helper utils start ----------------------------------------------------- 20 | namespace { 21 | 22 | static uint64_t _max_size() { 23 | return std::numeric_limits::max(); 24 | } 25 | 26 | static std::vector _byte_pair_merge( 27 | const std::string& piece, 28 | const TokenMap& ranks, 29 | std::function func) { 30 | // This is a vector of (start, rank). 31 | // The rank is of the byte pair starting at position start. 32 | // The rank of the last item in the vector is not a valid value. 33 | std::vector> parts; 34 | parts.reserve(piece.size() + 1); 35 | for (auto idx = 0U; idx < piece.size() + 1; ++idx) { 36 | parts.emplace_back(idx, _max_size()); 37 | } 38 | 39 | auto get_rank = [&piece, &ranks]( 40 | const std::vector>& parts, 41 | uint64_t start_idx, 42 | uint64_t skip) -> std::optional { 43 | if (start_idx + skip + 2 < parts.size()) { 44 | auto s = parts[start_idx].first; 45 | auto e = parts[start_idx + skip + 2].first; 46 | auto key = piece.substr(s, e - s); 47 | return ranks.tryGetInteger(key); 48 | } 49 | return std::nullopt; 50 | }; 51 | 52 | // We look up the ranks once in the beginning and iteratively update 53 | // them during each merge, which reduces the number of rank lookups. 54 | for (auto i = 0U; i < parts.size() - 2; ++i) { 55 | auto rank = get_rank(parts, i, 0); 56 | if (rank) { 57 | // usize::MAX is a sentinel value and cannot be a valid rank 58 | if (*rank == _max_size()) { 59 | TK_LOG(Error, "at %" PRIu32 " rank is too large\n", i); 60 | } 61 | parts[i].second = *rank; 62 | } 63 | } 64 | 65 | // If you have n parts and m merges, this does O(mn) work. 66 | // We could do something with a heap and do O(m log n) work. 67 | // It is important to consider that n is often small (<100), and as such 68 | // the cache-locality benefits outweigh the algorithmic complexity downsides 69 | // of the `parts` vector data structure above. 70 | 71 | // Note that we hash bytes, not token pairs. As long as we train BPE the way 72 | // we currently do, this is equivalent. An easy way to break this would be 73 | // to decouple merge priority from token index or to prevent specific token 74 | // merges. 
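// Worked illustration of the merge loop below, assuming piece == "abcd" and
// ranks == {"ab": 0, "cd": 1}: `parts` starts as
// {(0, MAX), (1, MAX), (2, MAX), (3, MAX), (4, MAX)}; the pass above assigns
// rank 0 at index 0 ("ab") and rank 1 at index 2 ("cd"). The loop then merges
// the lowest-ranked pair first (erasing the entry at index 1 and refreshing
// the adjacent ranks), next merges "cd", and finally emits the byte ranges
// [0, 2) and [2, 4) through `func`.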
75 | while (true) { 76 | if (parts.size() == 1) { 77 | break; 78 | } 79 | 80 | // usize::MAX is a sentinel rank value allowing us to 81 | // take the min more quickly 82 | auto min_rank = std::make_pair(_max_size(), 0); 83 | for (auto i = 0U; i < parts.size() - 1; ++i) { 84 | auto rank = parts[i].second; 85 | if (rank < min_rank.first) { 86 | min_rank.first = rank; 87 | min_rank.second = i; 88 | } 89 | } 90 | 91 | if (min_rank.first != _max_size()) { 92 | auto i = min_rank.second; 93 | 94 | // NOTE: We are about to remove parts[i + 1]. We do not do it 95 | // yet because there are cache-locality benefits to updating 96 | // parts[i] and parts[i-1] before removing, which could thrash 97 | // the cache. Thus, we update the rank calculation by skipping over 98 | // parts[i + 1], by invoking `get_rank!` with `skip = 1`. 99 | auto rank = get_rank(parts, i, 1); 100 | if (rank) { 101 | parts[i].second = *rank; 102 | } else { 103 | parts[i].second = _max_size(); 104 | } 105 | if (i > 0) { 106 | rank = get_rank(parts, i - 1, 1); 107 | if (rank) { 108 | parts[i - 1].second = *rank; 109 | } else { 110 | parts[i - 1].second = _max_size(); 111 | } 112 | } 113 | 114 | parts.erase(parts.begin() + (i + 1)); 115 | } else { 116 | break; 117 | } 118 | } 119 | std::vector out; 120 | out.reserve(parts.size() - 1); 121 | for (auto i = 0U; i < parts.size() - 1; ++i) { 122 | auto s = parts[i].first; 123 | auto e = parts[i + 1].first; 124 | out.push_back(func(s, e)); 125 | } 126 | return out; 127 | } 128 | 129 | } // namespace 130 | // ---- Helper utils end ------------------------------------------------------- 131 | // ---- protected start -------------------------------------------------------- 132 | 133 | std::pair, std::string> 134 | BPETokenizerBase::split_with_allowed_special_token_( 135 | const std::string& input, 136 | size_t offset, 137 | const TokenMap& allowed_special) const { 138 | if (!special_token_regex_) { 139 | return std::make_pair(std::nullopt, input.substr(offset)); 140 | } 141 | 142 | auto matches = special_token_regex_->find_all(input.substr(offset)); 143 | 144 | for (const auto& m : matches) { 145 | std::string matched_text = input.substr(offset + m.start, m.end - m.start); 146 | if (allowed_special.tryGetInteger(matched_text).has_value()) { 147 | return {matched_text, input.substr(offset, m.start)}; 148 | } 149 | } 150 | 151 | return {std::nullopt, input.substr(offset)}; 152 | } 153 | 154 | Result, uint64_t>> 155 | BPETokenizerBase::encode_with_special_token_( 156 | const std::string& text, 157 | const TokenMap& allowed_special) const { 158 | std::vector tokens; 159 | uint64_t last_piece_token_len = 0; 160 | size_t offset = 0; 161 | 162 | while (offset < text.size()) { 163 | auto [special, sub_input] = 164 | split_with_allowed_special_token_(text, offset, allowed_special); 165 | 166 | TK_CHECK_OK_OR_RETURN_ERROR( 167 | _encode(sub_input, tokens, last_piece_token_len)); 168 | offset += sub_input.size(); 169 | 170 | if (special) { 171 | const auto result = special_token_map_->tryGetInteger(*special); 172 | if (!result) { 173 | TK_LOG(Error, "unknown special token: %s\n", special->c_str()); 174 | return Error::EncodeFailure; 175 | } 176 | 177 | tokens.push_back(*result); 178 | last_piece_token_len = 0; 179 | offset += special->size(); // advance past the matched token 180 | } else { 181 | break; 182 | } 183 | } 184 | 185 | return std::make_pair(tokens, last_piece_token_len); 186 | } 187 | 188 | Result> BPETokenizerBase::byte_pair_encode_( 189 | const std::string& piece, 190 | const TokenMap& 
token_map) const { 191 | if (piece.size() == 1) { 192 | const auto result = token_map.tryGetInteger(piece); 193 | if (result) { 194 | return std::vector(*result); 195 | } else { 196 | // TODO: is it possible? 197 | return Error::EncodeFailure; 198 | } 199 | } 200 | 201 | return _byte_pair_merge( 202 | piece, token_map, [&piece, &token_map](uint64_t start, uint64_t stop) { 203 | std::string key = piece.substr(start, stop - start); 204 | const auto result = token_map.tryGetInteger(key); 205 | if (result) { 206 | return *result; 207 | } else { 208 | // TODO: what if key does not exist? Should we 209 | // return `unknown`? assert(false); // ?? 210 | return uint64_t(0); 211 | } 212 | }); 213 | } 214 | 215 | // ---- protected end ---------------------------------------------------------- 216 | // ---- public start ----------------------------------------------------------- 217 | 218 | Result> BPETokenizerBase::encode( 219 | const std::string& text, 220 | int8_t bos, 221 | int8_t eos) const { 222 | if (!initialized_) { 223 | return Error::Uninitialized; 224 | } 225 | auto res = 226 | TK_UNWRAP(encode_with_special_token_(text, *special_token_map_)).first; 227 | for (auto i = 0; i < bos; ++i) { 228 | res.insert(res.begin(), bos_tok_); 229 | } 230 | for (auto i = 0; i < eos; ++i) { 231 | res.push_back(eos_tok_); 232 | } 233 | return Result>(std::move(res)); 234 | } 235 | 236 | Result BPETokenizerBase::decode(uint64_t prev, uint64_t cur) 237 | const { 238 | (void)prev; 239 | if (!initialized_) { 240 | return Error::Uninitialized; 241 | } 242 | std::string ret; 243 | 244 | std::string_view token_bytes; 245 | auto result = token_map_->tryGetString(cur); 246 | if (!result) { 247 | result = special_token_map_->tryGetString(cur); 248 | if (!result) { 249 | TK_LOG(Error, "unknown token: %" PRIu64 "\n", cur); 250 | return Error::DecodeFailure; 251 | } else { 252 | token_bytes = *result; 253 | } 254 | } else { 255 | token_bytes = *result; 256 | } 257 | _decode(std::string(token_bytes), ret); 258 | 259 | return ret; 260 | } 261 | 262 | // ---- public end ------------------------------------------------------------- 263 | 264 | } // namespace detail 265 | } // namespace tokenizers 266 | -------------------------------------------------------------------------------- /src/pcre2_regex.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | namespace tokenizers { 15 | 16 | Error Pcre2Regex::compile(const std::string& pattern) { 17 | int error_code; 18 | PCRE2_SIZE error_offset; 19 | 20 | // Compile the pattern 21 | regex_ = pcre2_compile( 22 | reinterpret_cast(pattern.c_str()), 23 | pattern.length(), 24 | PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode 25 | &error_code, 26 | &error_offset, 27 | nullptr); 28 | 29 | if (regex_ == nullptr) { 30 | PCRE2_UCHAR error_buffer[256]; 31 | pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer)); 32 | TK_LOG( 33 | Error, 34 | "PCRE2 compilation failed at offset %" PRId64 ": %s", 35 | static_cast(error_offset), 36 | error_buffer); 37 | return Error::RegexFailure; 38 | } 39 | 40 | // Create match data 41 | match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr); 42 | if (match_data_ == nullptr) { 43 | pcre2_code_free(regex_); 44 | regex_ = nullptr; 45 | TK_LOG(Error, "Failed to create PCRE2 match data"); 46 | return Error::RegexFailure; 47 | } 48 | 49 | return Error::Ok; 50 | } 51 | 52 | Pcre2Regex::~Pcre2Regex() { 53 | if (match_data_) { 54 | pcre2_match_data_free(match_data_); 55 | } 56 | if (regex_) { 57 | pcre2_code_free(regex_); 58 | } 59 | } 60 | 61 | std::vector Pcre2Regex::find_all(const std::string& text) const { 62 | std::vector result; 63 | 64 | if (!regex_ || !match_data_) { 65 | TK_LOG(Error, "Regex is not compiled or invalid, run compile() first"); 66 | return result; 67 | } 68 | 69 | PCRE2_SIZE* ovector; 70 | PCRE2_SPTR subject = reinterpret_cast(text.c_str()); 71 | PCRE2_SIZE subject_length = text.length(); 72 | PCRE2_SIZE offset = 0; 73 | 74 | while (offset < subject_length) { 75 | int rc = pcre2_match( 76 | regex_, 77 | subject, 78 | subject_length, 79 | offset, 80 | 0, // Default options 81 | match_data_, 82 | nullptr); 83 | 84 | if (rc < 0) { 85 | if (rc == PCRE2_ERROR_NOMATCH) { 86 | break; // No more matches 87 | } else { 88 | // Error occurred 89 | PCRE2_UCHAR error_buffer[256]; 90 | pcre2_get_error_message(rc, error_buffer, sizeof(error_buffer)); 91 | std::cerr << "PCRE2 matching error: " << error_buffer << std::endl; 92 | break; 93 | } 94 | } 95 | 96 | ovector = pcre2_get_ovector_pointer(match_data_); 97 | 98 | // Add the match to the result 99 | result.push_back({ovector[0], ovector[1]}); 100 | 101 | // Move to the next position after the match 102 | offset = ovector[1]; 103 | 104 | // If the match was empty, move forward by one character to avoid infinite 105 | // loop 106 | if (ovector[0] == ovector[1]) { 107 | offset++; 108 | } 109 | } 110 | 111 | return result; 112 | } 113 | 114 | } // namespace tokenizers 115 | -------------------------------------------------------------------------------- /src/pre_tokenizer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | // Local 11 | #include 12 | #include 13 | 14 | // Standard 15 | #include 16 | #include 17 | #include 18 | 19 | // Third Party 20 | #include 21 | 22 | using json = nlohmann::json; 23 | 24 | namespace tokenizers { 25 | 26 | // PreTokenizerConfig ////////////////////////////////////////////////////////// 27 | 28 | PreTokenizerConfig::PreTokenizerConfig(std::string type) 29 | : type(std::move(type)) {} 30 | 31 | PreTokenizer::Ptr PreTokenizerConfig::create() const { 32 | // NOTE: These types must line up with the type strings found in the 33 | // tokenizers library 34 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/mod.rs#L73 35 | if (type == "Split") { 36 | if (!pattern) { 37 | throw std::runtime_error( 38 | "Missing pattern for PreTokenizer of type Split"); 39 | } 40 | return PreTokenizer::Ptr(new RegexPreTokenizer(*pattern)); 41 | } 42 | if (type == "Digits") { 43 | if (individual_digits) { 44 | return PreTokenizer::Ptr(new DigitsPreTokenizer(*individual_digits)); 45 | } 46 | return PreTokenizer::Ptr(new DigitsPreTokenizer()); 47 | } 48 | if (type == "ByteLevel") { 49 | if (add_prefix_space && pattern) { 50 | return PreTokenizer::Ptr( 51 | new ByteLevelPreTokenizer(*add_prefix_space, *pattern)); 52 | } 53 | if (add_prefix_space) { 54 | return PreTokenizer::Ptr(new ByteLevelPreTokenizer(*add_prefix_space)); 55 | } 56 | if (pattern) { 57 | return PreTokenizer::Ptr(new ByteLevelPreTokenizer(*pattern)); 58 | } 59 | return PreTokenizer::Ptr(new ByteLevelPreTokenizer()); 60 | } 61 | if (type == "Sequence") { 62 | if (!pretokenizers or pretokenizers->empty()) { 63 | throw std::runtime_error( 64 | "Missing pretokenizers for PreTokenizer of type Sequence"); 65 | } 66 | std::vector pretoks; 67 | std::transform( 68 | pretokenizers->begin(), 69 | pretokenizers->end(), 70 | std::back_inserter(pretoks), 71 | [](const PreTokenizerConfig& cfg) { return cfg.create(); }); 72 | return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks)); 73 | } 74 | throw std::runtime_error("Unsupported PreTokenizer type: " + type); 75 | } 76 | 77 | PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) { 78 | type = json_config.at("type"); 79 | if (type == "Split") { 80 | try { 81 | pattern = json_config.at("pattern").at("Regex"); 82 | } catch (json::out_of_range&) { 83 | } 84 | } else if (type == "Digits") { 85 | try { 86 | individual_digits = json_config.at("individual_digits"); 87 | } catch (json::out_of_range&) { 88 | } 89 | } else if (type == "ByteLevel") { 90 | try { 91 | add_prefix_space = json_config.at("add_prefix_space"); 92 | } catch (json::out_of_range&) { 93 | } 94 | // TODO: trim_offsets, use_regex 95 | } else if (type == "Sequence") { 96 | pretokenizers = std::vector(); 97 | for (const auto& entry : json_config.at("pretokenizers")) { 98 | pretokenizers->push_back(PreTokenizerConfig().parse_json(entry)); 99 | } 100 | } else { 101 | throw std::runtime_error("Unsupported PreTokenizer type: " + type); 102 | } 103 | return *this; 104 | } 105 | 106 | // RegexPreTokenizer /////////////////////////////////////////////////////////// 107 | 108 | std::unique_ptr RegexPreTokenizer::create_regex_( 109 | const std::string& pattern) { 110 | assert(!pattern.empty()); 111 | return TK_UNWRAP_THROW(create_regex(pattern)); 112 | } 113 | 114 | std::vector RegexPreTokenizer::pre_tokenize( 115 | const std::string& input) const { 116 | if (!regex_) 117 | return {}; 118 | std::vector results; 119 | for (const auto& match : 
regex_->find_all(input)) { 120 | results.push_back(input.substr(match.start, match.end - match.start)); 121 | } 122 | return results; 123 | } 124 | 125 | // ByteLevelPreTokenizer /////////////////////////////////////////////////////// 126 | 127 | ////////////////// 128 | // Impl Details // 129 | ////////////////// 130 | namespace { 131 | 132 | // Standard GPT2 regex 133 | // https://github.com/openai/gpt-2/blob/master/src/encoder.py#L53 134 | constexpr char GPT2_EXPR[] = 135 | R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)"; 136 | 137 | } // namespace 138 | 139 | ////////////////// 140 | // Construction // 141 | ////////////////// 142 | 143 | ByteLevelPreTokenizer::ByteLevelPreTokenizer( 144 | bool add_prefix_space, 145 | const std::string& pattern) 146 | : pattern_(pattern.empty() ? GPT2_EXPR : pattern), 147 | add_prefix_space_(add_prefix_space) {} 148 | 149 | std::vector ByteLevelPreTokenizer::pre_tokenize( 150 | const std::string& input) const { 151 | // Add the prefix space if configured to do so. 152 | std::string formatted_input = input; 153 | if (add_prefix_space_ && !formatted_input.empty() && 154 | formatted_input[0] != ' ') { 155 | formatted_input.insert(formatted_input.begin(), ' '); 156 | } 157 | 158 | return unicode_regex_split(formatted_input, {pattern_}); 159 | } 160 | 161 | // SequencePreTokenizer //////////////////////////////////////////////////////// 162 | 163 | SequencePreTokenizer::SequencePreTokenizer( 164 | std::vector pre_tokenizers) 165 | : pre_tokenizers_(std::move(pre_tokenizers)) {} 166 | 167 | std::vector SequencePreTokenizer::pre_tokenize( 168 | const std::string& input) const { 169 | std::vector pieces{std::string(input)}; 170 | for (const auto& pre_tokenizer : pre_tokenizers_) { 171 | std::vector new_pieces; 172 | for (const auto& piece : pieces) { 173 | for (const auto& subpiece : pre_tokenizer->pre_tokenize(piece)) { 174 | new_pieces.push_back(subpiece); 175 | } 176 | } 177 | pieces = std::move(new_pieces); 178 | } 179 | return pieces; 180 | } 181 | 182 | } // namespace tokenizers 183 | -------------------------------------------------------------------------------- /src/re2_regex.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | 11 | namespace tokenizers { 12 | 13 | Error Re2Regex::compile(const std::string& pattern) { 14 | regex_ = std::make_unique(pattern); 15 | // Warmup re2 as it is slow on the first run, void the return value as it's 16 | // not needed Refer to 17 | // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 18 | (void)regex_->ReverseProgramSize(); 19 | if (regex_->ok()) { 20 | return Error::Ok; 21 | } else { 22 | TK_LOG( 23 | Error, 24 | "Failed to compile regex: %s, error: %s", 25 | pattern.c_str(), 26 | regex_->error().c_str()); 27 | return Error::RegexFailure; 28 | } 29 | } 30 | 31 | std::vector Re2Regex::find_all(const std::string& text) const { 32 | if (!regex_ || !regex_->ok()) { 33 | TK_LOG(Error, "Regex is not compiled or invalid, run compile() first"); 34 | return std::vector{}; 35 | } 36 | std::vector result; 37 | re2::StringPiece input(text); 38 | re2::StringPiece piece; 39 | 40 | const char* base = input.data(); 41 | 42 | while (RE2::FindAndConsume(&input, *regex_, &piece)) { 43 | size_t start = piece.data() - base; 44 | result.push_back({start, start + piece.size()}); 45 | } 46 | 47 | return result; 48 | } 49 | 50 | } // namespace tokenizers 51 | -------------------------------------------------------------------------------- /src/regex.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // Default implementation for create_regex, only using RE2 regex library. 9 | // regex_lookahead.cpp has the implementation of create_regex with lookahead 10 | // support, backed by PCRE2 and std::regex. 11 | 12 | #include 13 | #include 14 | 15 | namespace tokenizers { 16 | 17 | // Default implementation that returns failure 18 | static Result> default_create_fallback_regex( 19 | const std::string& pattern) { 20 | (void)pattern; 21 | return tokenizers::Error::RegexFailure; 22 | } 23 | 24 | FallbackRegexFn fallback_regex = default_create_fallback_regex; 25 | 26 | bool register_override_fallback_regex(FallbackRegexFn fn) { 27 | TK_LOG(Info, "Registering override fallback regex"); 28 | fallback_regex = fn; 29 | return true; 30 | } 31 | 32 | FallbackRegexFn get_fallback_regex() { 33 | return fallback_regex; 34 | } 35 | 36 | Result> create_regex(const std::string& pattern) { 37 | // Try RE2 first 38 | auto re2 = std::make_unique(); 39 | auto err = re2->compile("(" + pattern + ")"); 40 | 41 | if (err == Error::Ok) { 42 | return static_cast>(std::move(re2)); 43 | } 44 | 45 | auto res = get_fallback_regex()(pattern); 46 | if (!res.ok()) { 47 | TK_LOG( 48 | Error, 49 | "RE2 doesn't support lookahead patterns. Link with `regex_lookahead` to enable support."); 50 | } else { 51 | return res; 52 | } 53 | 54 | return tokenizers::Error::RegexFailure; 55 | } 56 | } // namespace tokenizers 57 | -------------------------------------------------------------------------------- /src/regex_lookahead.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | // This file contains the implementation of create_regex with lookahead support 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | namespace tokenizers { 19 | 20 | /** 21 | * @brief Implementation of the fallback regex function with lookahead support. 22 | * Falls back to PCRE2 if RE2 rejects the pattern due to lookahead. 23 | * Falls back to std::regex if PCRE2 also fails. 24 | */ 25 | Result> create_fallback_regex( 26 | const std::string& pattern) { 27 | TK_LOG(Info, "Creating PCRE2 regex"); 28 | auto pcre2 = std::make_unique(); 29 | auto err = pcre2->compile(pattern); 30 | 31 | if (err == Error::Ok) { 32 | return static_cast>(std::move(pcre2)); 33 | } 34 | 35 | // If PCRE2 also fails, fall back to std::regex 36 | auto std_regex = std::make_unique(); 37 | err = std_regex->compile(pattern); 38 | if (err == Error::Ok) { 39 | TK_LOG( 40 | Info, "PCRE2 failed to compile pattern, falling back to std::regex."); 41 | return static_cast>(std::move(std_regex)); 42 | } 43 | 44 | return tokenizers::Error::RegexFailure; 45 | } 46 | 47 | static bool registered = 48 | register_override_fallback_regex(create_fallback_regex); 49 | 50 | } // namespace tokenizers 51 | -------------------------------------------------------------------------------- /src/sentencepiece.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | // A tokenizer that works with sentencepiece. 10 | 11 | #include 12 | #include 13 | #include 14 | #include "third_party/absl/strings/str_replace.h" 15 | namespace tokenizers { 16 | const char kSpaceSymbol[] = "\xe2\x96\x81"; 17 | 18 | SPTokenizer::SPTokenizer() 19 | : Tokenizer(), 20 | _processor(std::make_unique()) {} 21 | 22 | /** 23 | * @brief Load the tokenizer from a file. The tokenizer file contains the 24 | * vocabulary and scores. The format is: the first integer is the maximum 25 | * token length, followed by a list of (word_len, word) pairs. Here we 26 | * are reading all the vocabulary into memory and keep it sorted for fast 27 | * lookup. 28 | * 29 | * @param tokenizer_path The path to the tokenizer file. 30 | * @return Error 31 | */ 32 | Error SPTokenizer::load(const std::string& tokenizer_path) { 33 | if (initialized_) { 34 | fprintf(stderr, "Tokenizer already initialized.\n"); 35 | return Error::Ok; 36 | } 37 | // read in the file 38 | const auto status = _processor->Load(tokenizer_path); 39 | if (!status.ok()) { 40 | fprintf( 41 | stderr, 42 | "couldn't load %s. \nError message: \n%s\n" 43 | "It is likely that the tokenizer artifact is " 44 | "broken or of a different format.", 45 | tokenizer_path.c_str(), 46 | status.error_message()); 47 | return Error::LoadFailure; 48 | } 49 | // load vocab_size, bos_tok, eos_tok 50 | vocab_size_ = _processor->GetPieceSize(); 51 | bos_tok_ = _processor->bos_id(); 52 | eos_tok_ = _processor->eos_id(); 53 | initialized_ = true; 54 | return Error::Ok; 55 | } 56 | 57 | SPTokenizer::~SPTokenizer() {} 58 | 59 | /** 60 | * @brief Decode a token into string. 61 | * 62 | * @param prev_token The previous token. 63 | * @param token The current token. 64 | * @return Result The string representation of the 65 | * token. 
66 | */ 67 | Result SPTokenizer::decode(uint64_t prev_token, uint64_t token) 68 | const { 69 | if (!initialized_) { 70 | fprintf(stderr, "Tokenizer not initialized\n"); 71 | return Error::Uninitialized; 72 | } 73 | // get rid of the control ids and 74 | if (_processor->IsControl(token)) { 75 | // NB: returning empty string doesn't work for some reason. It causes 76 | // free(): invalid pointer error. 77 | return std::string(" "); 78 | } 79 | 80 | std::string result = 81 | absl::StrReplaceAll(_processor->IdToPiece(token), {{kSpaceSymbol, " "}}); 82 | 83 | // following BOS token, sentencepiece decoder strips any leading 84 | // whitespace 85 | if (prev_token == bos_tok_ && result[0] == ' ') { 86 | result = result.substr(1); 87 | } 88 | 89 | // handle <0x0A> 90 | result = absl::StrReplaceAll(result, {{"<0x0A>", "\n"}}); 91 | 92 | return result; 93 | } 94 | 95 | /** 96 | * @brief Encode a string into a sequence of tokens. 97 | * 98 | * @param text The string to be encoded. 99 | * @param bos The number of BOS to prepend to the token list. 100 | * @param eos The number of EOS to append to the token list. 101 | * @return Result> 102 | */ 103 | Result> 104 | SPTokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const { 105 | if (!initialized_) { 106 | fprintf(stderr, "Tokenizer not initialized\n"); 107 | return Error::Uninitialized; 108 | } 109 | // workaround a weird issue that text doesn't have correct size() 110 | std::string input(text.c_str()); 111 | // should we reserve memory? 112 | std::vector res; 113 | auto status = _processor->Encode(input, &res); 114 | if (!status.ok()) { 115 | fprintf(stderr, "couldn't encode %s\n", text.c_str()); 116 | return Error::EncodeFailure; 117 | } 118 | 119 | std::vector tokens; 120 | for (auto i = 0; i < bos; ++i) { 121 | tokens.push_back(bos_tok_); 122 | } 123 | 124 | for (auto i = 0; i < res.size(); ++i) { 125 | tokens.push_back(res[i]); 126 | } 127 | 128 | for (auto i = 0; i < eos; ++i) { 129 | tokens.push_back(eos_tok_); 130 | } 131 | return tokens; 132 | } 133 | } // namespace tokenizers 134 | -------------------------------------------------------------------------------- /src/std_regex.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | * @lint-ignore-every LICENSELINT 9 | * @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful 10 | */ 11 | 12 | #include 13 | #include 14 | 15 | namespace tokenizers { 16 | 17 | Error StdRegex::compile(const std::string& pattern) { 18 | try { 19 | regex_ = std::regex(pattern); 20 | return Error::Ok; 21 | } catch (std::regex_error) { 22 | TK_LOG(Error, "Failed to compile regex: %s", pattern.c_str()); 23 | return Error::RegexFailure; 24 | } 25 | } 26 | 27 | std::vector StdRegex::find_all(const std::string& text) const { 28 | std::vector result; 29 | std::sregex_iterator iter(text.begin(), text.end(), regex_); 30 | std::sregex_iterator end; 31 | 32 | for (; iter != end; ++iter) { 33 | const auto& match = *iter; 34 | size_t start = match.position(1); 35 | result.push_back({start, start + match[1].length()}); 36 | } 37 | 38 | return result; 39 | } 40 | 41 | } // namespace tokenizers 42 | -------------------------------------------------------------------------------- /src/tiktoken.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | // Adopted from https://github.com/sewenew/tokenizer 10 | 11 | // @lint-ignore-every LICENSELINT 12 | /************************************************************************** 13 | Copyright (c) 2023 sewenew 14 | 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software 22 | distributed under the License is distributed on an "AS IS" BASIS, 23 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 24 | See the License for the specific language governing permissions and 25 | limitations under the License. 
26 | *************************************************************************/ 27 | 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | namespace tokenizers { 37 | 38 | using namespace detail; 39 | 40 | // ------------------------------Util start------------------------------------ 41 | namespace { 42 | 43 | static Result> _create_regex( 44 | const std::string& pattern) { 45 | assert(!pattern.empty()); 46 | return create_regex(pattern); 47 | } 48 | 49 | static Result> _parse( 50 | const std::string& line) { 51 | // Tiktoken format 52 | // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 54 | auto pos = line.find(" "); 55 | TK_CHECK_OR_RETURN_ERROR( 56 | pos != std::string::npos, 57 | ParseFailure, 58 | "invalid tiktoken line: %s", 59 | line.c_str()); 60 | 61 | auto token = TK_UNWRAP(base64::decode({line.data(), pos})); 62 | uint64_t rank = 0; 63 | try { 64 | rank = std::stoul(line.substr(pos + 1)); 65 | } catch (const std::exception&) { 66 | TK_CHECK_OR_RETURN_ERROR( 67 | false, EncodeFailure, "invalid encoder rank: %s", line.c_str()); 68 | } 69 | 70 | return std::pair{std::move(token), rank}; 71 | } 72 | 73 | static Result _load_token_map(const std::string& path) { 74 | std::ifstream file(path); 75 | TK_CHECK_OR_RETURN_ERROR( 76 | file, LoadFailure, "failed to open encoder file: %s", path.c_str()); 77 | 78 | // Instead of generating couple of large unordered_maps here to only process 79 | // them linearly in the TokenMap, just place them in a vector of pairs and 80 | // sort them twice, looking for duplicates. It's still O(n log n) but avoids 81 | // the overhead of the unordered_maps. 82 | 83 | std::vector> pairs; 84 | std::string line; 85 | while (std::getline(file, line)) { 86 | auto [token, rank] = TK_UNWRAP(_parse(line)); 87 | pairs.emplace_back(std::move(token), rank); 88 | } 89 | 90 | return buildTokenMap(pairs); 91 | } 92 | 93 | } // namespace 94 | 95 | // ------------------------------Util end------------------------------------ 96 | // -------------------------private method start------------------------------- 97 | 98 | Error Tiktoken::_encode( 99 | const std::string& input, 100 | std::vector& ret, 101 | uint64_t& last_piece_token_len) const { 102 | std::string piece; 103 | assert(_regex); 104 | for (const auto& match : _regex->find_all(input)) { 105 | std::string matched_text = 106 | input.substr(match.start, match.end - match.start); 107 | const auto result = token_map_->tryGetInteger(matched_text); 108 | if (result) { 109 | last_piece_token_len = 1; 110 | ret.push_back(*result); 111 | continue; 112 | } 113 | auto tokens = TK_UNWRAP(byte_pair_encode_(matched_text, *token_map_)); 114 | last_piece_token_len = tokens.size(); 115 | ret.insert(ret.end(), tokens.begin(), tokens.end()); 116 | } 117 | return Error::Ok; 118 | } 119 | 120 | void Tiktoken::_decode(const std::string& input, std::string& ret) const { 121 | ret += input; 122 | } 123 | 124 | // -------------------------private method end------------------------------- 125 | // -------------------------public method start------------------------------- 126 | 127 | Error Tiktoken::load(const std::string& path) { 128 | token_map_.emplace(TK_UNWRAP(_load_token_map(path))); 129 | 130 | std::vector> special_token_map; 131 | for (std::size_t i = 0; i < _special_tokens->size(); ++i) { 132 | special_token_map.emplace_back( 133 | _special_tokens->at(i), token_map_->size() + i); 134 | } 135 | 136 | special_token_map_.emplace(TokenMap(special_token_map)); 137 | 
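  // Sketch of what has been consumed up to this point (file contents are
  // illustrative, not from a real model): each line of the tiktoken file is a
  // base64-encoded token, a single space, and its integer rank, e.g.
  // "SGVsbG8= 15043" for the token "Hello". Special tokens never appear in the
  // file; they were appended just above with ids starting at
  // token_map_->size(), so they always occupy the tail of the id space.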
138 | _regex = TK_UNWRAP(_create_regex(_pattern)); 139 | special_token_regex_ = 140 | TK_UNWRAP(detail::build_special_token_regex(TokenMap(special_token_map))); 141 | 142 | // initialize vocab_size, bos_tok, eos_tok 143 | vocab_size_ = token_map_->size() + special_token_map_->size(); 144 | bos_tok_ = 145 | *special_token_map_->tryGetInteger(_special_tokens->at(_bos_token_index)); 146 | eos_tok_ = 147 | *special_token_map_->tryGetInteger(_special_tokens->at(_eos_token_index)); 148 | 149 | initialized_ = true; 150 | return Error::Ok; 151 | } 152 | 153 | // -------------------------public method end------------------------------- 154 | 155 | } // namespace tokenizers 156 | -------------------------------------------------------------------------------- /src/token_decoder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #include 11 | 12 | // Standard 13 | #include 14 | 15 | // Third Party 16 | #include 17 | 18 | // Local 19 | #include 20 | 21 | using json = nlohmann::json; 22 | 23 | namespace tokenizers { 24 | 25 | // TokenDecoderConfig ////////////////////////////////////////////////////////// 26 | 27 | TokenDecoderConfig::TokenDecoderConfig(std::string type) 28 | : type(std::move(type)) {} 29 | 30 | TokenDecoder::Ptr TokenDecoderConfig::create() const { 31 | // NOTE: These types must line up with the type strings found in the 32 | // tokenizers library 33 | // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/mod.rs#L55 34 | if (type == "ByteLevel") { 35 | return TokenDecoder::Ptr(new ByteLevelTokenDecoder()); 36 | } 37 | throw std::runtime_error("Unsupported TokenDecoder type: " + type); 38 | } 39 | 40 | TokenDecoderConfig& TokenDecoderConfig::parse_json(const json& json_config) { 41 | type = json_config.at("type"); 42 | if (type == "ByteLevel") { 43 | // No parameters to parse 44 | } else { 45 | throw std::runtime_error("Unsupported TokenDecoder type: " + type); 46 | } 47 | return *this; 48 | } 49 | 50 | // ByteLevel /////////////////////////////////////////////////////////////////// 51 | 52 | namespace { 53 | 54 | // Copied from llama.cpp 55 | // CITE: 56 | // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L20 57 | static std::string format(const char* fmt, ...) { 58 | va_list ap; 59 | va_list ap2; 60 | va_start(ap, fmt); 61 | va_copy(ap2, ap); 62 | int size = vsnprintf(NULL, 0, fmt, ap); 63 | // GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT 64 | std::vector buf(size + 1); 65 | // int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); 66 | // GGML_ASSERT(size2 == size); 67 | va_end(ap2); 68 | va_end(ap); 69 | return std::string(buf.data(), size); 70 | } 71 | 72 | } // namespace 73 | 74 | std::string ByteLevelTokenDecoder::decode(const std::string& token) const { 75 | // This is borrowed and lightly tweaked from llama.cpp 76 | // CITE: 77 | // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L1755 78 | std::string decoded_text; 79 | // TODO: This could be more efficient since what we really need is a string 80 | // const ref. 
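  // Concrete example of the mapping performed below (GPT-2 style byte-level
  // vocabulary): the stored token "ĠHello" decodes to " Hello", because the
  // code point U+0120 ("Ġ") is the printable stand-in for the raw byte 0x20
  // (space). Code points with no byte mapping fall into the catch branch and
  // are rendered as an "[UNK_BYTE_0x..]" marker instead of being dropped.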
81 | const auto cpts = unicode_cpts_from_utf8(token); 82 | for (const auto cpt : cpts) { 83 | const auto utf8 = unicode_cpt_to_utf8(cpt); 84 | try { 85 | decoded_text += unicode_utf8_to_byte(utf8); 86 | } catch (const std::out_of_range& /*e*/) { 87 | decoded_text += "[UNK_BYTE_0x"; 88 | for (const auto c : utf8) { 89 | decoded_text += format("%02x", (uint8_t)c); 90 | } 91 | decoded_text += token + "]"; 92 | } 93 | } 94 | 95 | return decoded_text; 96 | } 97 | 98 | } // end namespace tokenizers 99 | -------------------------------------------------------------------------------- /targets.bzl: -------------------------------------------------------------------------------- 1 | load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_executorch_supported_platforms", "runtime") 2 | load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob") 3 | 4 | PLATFORMS = get_executorch_supported_platforms() 5 | 6 | def define_common_targets(): 7 | """Defines targets that should be shared between fbcode and xplat. 8 | 9 | The directory containing this targets.bzl file should also contain both 10 | TARGETS and BUCK files that call this function. 11 | """ 12 | 13 | runtime.cxx_library( 14 | name = "headers", 15 | exported_headers = subdir_glob([ 16 | ("include", "pytorch/tokenizers/*.h"), 17 | ]), 18 | visibility = [ 19 | "@EXECUTORCH_CLIENTS", 20 | "//pytorch/tokenizers/...", 21 | ], 22 | header_namespace = "", 23 | platforms = PLATFORMS, 24 | ) 25 | 26 | runtime.cxx_library( 27 | name = "regex", 28 | srcs = [ 29 | "src/re2_regex.cpp", 30 | "src/regex.cpp", 31 | ], 32 | exported_deps = [ 33 | ":headers", 34 | ], 35 | exported_external_deps = [ 36 | "re2", 37 | ], 38 | visibility = ["//pytorch/tokenizers/..."], 39 | header_namespace = "", 40 | platforms = PLATFORMS, 41 | ) 42 | 43 | runtime.cxx_library( 44 | name = "regex_lookahead", 45 | srcs = [ 46 | "src/pcre2_regex.cpp", 47 | "src/regex_lookahead.cpp", 48 | "src/std_regex.cpp", 49 | ], 50 | exported_deps = [ 51 | ":regex", 52 | ":headers", 53 | ], 54 | compiler_flags = [ 55 | "-Wno-global-constructors", 56 | "-Wno-missing-prototypes", 57 | ], 58 | exported_external_deps = [ 59 | "pcre2", 60 | ], 61 | # Making sure this library is not being stripped by linker. 
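        # Rationale (see src/regex_lookahead.cpp): the PCRE2/std::regex fallback
        # is wired up by a static initializer that calls
        # register_override_fallback_regex(); nothing else references that
        # object file, so without link_whole the linker could drop it and
        # lookahead patterns would fail with RegexFailure at runtime.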
62 | # @lint-ignore BUCKLINT: Avoid link_whole=True 63 | link_whole = True, 64 | visibility = [ 65 | "@EXECUTORCH_CLIENTS", 66 | "//pytorch/tokenizers/...", 67 | ], 68 | header_namespace = "", 69 | platforms = PLATFORMS, 70 | ) 71 | 72 | runtime.cxx_library( 73 | name = "bpe_tokenizer_base", 74 | srcs = [ 75 | "src/bpe_tokenizer_base.cpp", 76 | ], 77 | exported_deps = [ 78 | ":headers", 79 | ], 80 | exported_external_deps = [ 81 | "re2", 82 | ], 83 | visibility = [ 84 | "//pytorch/tokenizers/...", 85 | ], 86 | platforms = PLATFORMS, 87 | ) 88 | 89 | runtime.cxx_library( 90 | name = "sentencepiece", 91 | srcs = [ 92 | "src/sentencepiece.cpp", 93 | ], 94 | deps = [ 95 | ":regex", 96 | ], 97 | exported_deps = [ 98 | ":headers", 99 | ], 100 | visibility = [ 101 | "@EXECUTORCH_CLIENTS", 102 | "//pytorch/tokenizers/...", 103 | ], 104 | external_deps = [ 105 | "sentencepiece", 106 | "abseil-cpp", 107 | ], 108 | platforms = PLATFORMS, 109 | ) 110 | 111 | runtime.cxx_library( 112 | name = "tiktoken", 113 | srcs = [ 114 | "src/tiktoken.cpp", 115 | ], 116 | deps = [ 117 | ":regex", 118 | ], 119 | exported_deps = [ 120 | ":bpe_tokenizer_base", 121 | ":headers", 122 | ], 123 | exported_external_deps = [ 124 | "re2", 125 | ], 126 | visibility = [ 127 | "@EXECUTORCH_CLIENTS", 128 | "//pytorch/tokenizers/...", 129 | ], 130 | platforms = PLATFORMS, 131 | ) 132 | 133 | runtime.cxx_library( 134 | name = "hf_tokenizer", 135 | srcs = [ 136 | "src/hf_tokenizer.cpp", 137 | "src/pre_tokenizer.cpp", 138 | "src/token_decoder.cpp", 139 | ], 140 | deps = [ 141 | ":regex", 142 | ], 143 | exported_deps = [ 144 | ":bpe_tokenizer_base", 145 | ":headers", 146 | "//pytorch/tokenizers/third-party:unicode", 147 | ], 148 | visibility = [ 149 | "@EXECUTORCH_CLIENTS", 150 | "//pytorch/tokenizers/...", 151 | ], 152 | exported_external_deps = [ 153 | "re2", 154 | "nlohmann_json", 155 | ], 156 | platforms = PLATFORMS, 157 | ) 158 | 159 | runtime.cxx_library( 160 | name = "llama2c_tokenizer", 161 | srcs = [ 162 | "src/llama2c_tokenizer.cpp", 163 | ], 164 | exported_deps = [ 165 | ":headers", 166 | ], 167 | visibility = [ 168 | "@EXECUTORCH_CLIENTS", 169 | "//pytorch/tokenizers/...", 170 | ], 171 | platforms = PLATFORMS, 172 | ) 173 | -------------------------------------------------------------------------------- /test/TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain fbcode-only targets. 3 | 4 | load(":targets.bzl", "define_common_targets") 5 | 6 | oncall("executorch") 7 | 8 | define_common_targets() 9 | -------------------------------------------------------------------------------- /test/fb/TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain fbcode-only targets. 
3 | 4 | load("@fbcode_macros//build_defs:cpp_benchmark.bzl", "cpp_benchmark") 5 | load("@fbsource//arvr/tools/build_defs:embed_resources.bzl", "embed_resources") 6 | 7 | oncall("executorch") 8 | 9 | embed_resources( 10 | name = "string_integer_map_bench_resources", 11 | generated_resources = { 12 | "//pytorch/tokenizers/test:test_tiktoken_tokenizer_model": "test_tiktoken_tokenizer_model", 13 | }, 14 | header_path = "pytorch/tokenizers/test/fb", 15 | namespace = "string_integer_map_bench", 16 | visibility = ["PUBLIC"], 17 | ) 18 | 19 | # @nolint 20 | cpp_benchmark( 21 | # @autodeps-skip 22 | name = "string_integer_map_bench", 23 | srcs = ["string_integer_map_bench.cpp"], 24 | preprocessor_flags = ["-DFBCODEBUILD=1"], 25 | deps = [ 26 | "fbsource//third-party/benchmark:benchmark", 27 | "//pytorch/tokenizers/test/fb:string_integer_map_bench_resources", 28 | "//pytorch/tokenizers:headers", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /test/resources/test_bpe_tokenizer.bin: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/resources/test_llama2c_tokenizer.bin: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/resources/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch-labs/tokenizers/fc32028858020c4fcafe37aaaeaf5d1b480336a2/test/resources/test_sentencepiece.model -------------------------------------------------------------------------------- /test/resources/test_tiktoken_invalid_base64.model: -------------------------------------------------------------------------------- 1 | tet 0 2 | -------------------------------------------------------------------------------- /test/resources/test_tiktoken_invalid_rank.model: -------------------------------------------------------------------------------- 1 | ICAgICAgIA== 18446744073709551616 2 | -------------------------------------------------------------------------------- /test/resources/test_tiktoken_no_space.model: -------------------------------------------------------------------------------- 1 | ICAgICAgIA==10 2 | -------------------------------------------------------------------------------- /test/targets.bzl: -------------------------------------------------------------------------------- 1 | load( 2 | "@fbsource//tools/build_defs:default_platform_defs.bzl", 3 | "ANDROID", 4 | "CXX", 5 | ) 6 | load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") 7 | 8 | def define_common_targets(): 9 | """Defines targets that should be shared between fbcode and xplat. 10 | 11 | The directory containing this targets.bzl file should also contain both 12 | TARGETS and BUCK files that call this function. 13 | """ 14 | runtime.cxx_test( 15 | name = "test_base64", 16 | srcs = [ 17 | "test_base64.cpp", 18 | ], 19 | deps = [ 20 | "//pytorch/tokenizers:headers", 21 | ], 22 | ) 23 | 24 | runtime.cxx_test( 25 | name = "test_llama2c_tokenizer", 26 | srcs = [ 27 | "test_llama2c_tokenizer.cpp", 28 | ], 29 | deps = [ 30 | "//pytorch/tokenizers:llama2c_tokenizer", 31 | ], 32 | env = { 33 | "RESOURCES_PATH": "$(location :resources)/resources", 34 | }, 35 | platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. 
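        # Note: "RESOURCES_PATH" above expands to the on-disk path of the
        # :resources filegroup; the C++ tests resolve individual fixtures from
        # it via std::getenv("RESOURCES_PATH") (see _get_resource_path in the
        # test sources).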
36 | ) 37 | 38 | runtime.cxx_test( 39 | name = "test_pre_tokenizer", 40 | srcs = [ 41 | "test_pre_tokenizer.cpp", 42 | ], 43 | deps = [ 44 | "//pytorch/tokenizers:headers", 45 | "//pytorch/tokenizers:hf_tokenizer", 46 | ], 47 | ) 48 | 49 | runtime.cxx_test( 50 | name = "test_sentencepiece", 51 | srcs = [ 52 | "test_sentencepiece.cpp", 53 | ], 54 | deps = ["//pytorch/tokenizers:sentencepiece"], 55 | external_deps = [ 56 | "sentencepiece", 57 | "abseil-cpp", 58 | ], 59 | env = { 60 | "RESOURCES_PATH": "$(location :resources)/resources", 61 | }, 62 | platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. 63 | ) 64 | 65 | runtime.cxx_test( 66 | name = "test_string_integer_map", 67 | srcs = [ 68 | "test_string_integer_map.cpp", 69 | ], 70 | deps = [ 71 | "//pytorch/tokenizers:headers", 72 | ], 73 | env = { 74 | "RESOURCES_PATH": "$(location :resources)/resources", 75 | }, 76 | platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. 77 | external_deps = [ 78 | "re2", 79 | ], 80 | ) 81 | 82 | runtime.cxx_test( 83 | name = "test_tiktoken", 84 | srcs = [ 85 | "test_tiktoken.cpp", 86 | ], 87 | deps = [ 88 | "//pytorch/tokenizers:tiktoken", 89 | ], 90 | env = { 91 | "RESOURCES_PATH": "$(location :resources)/resources", 92 | }, 93 | platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. 94 | external_deps = [ 95 | "re2", 96 | ], 97 | ) 98 | 99 | runtime.cxx_test( 100 | name = "test_regex", 101 | srcs = [ 102 | "test_regex.cpp", 103 | ], 104 | deps = [ 105 | "//pytorch/tokenizers:regex_lookahead", 106 | "//pytorch/tokenizers:headers", 107 | ], 108 | ) 109 | 110 | runtime.filegroup( 111 | name = "resources", 112 | srcs = native.glob([ 113 | "resources/**", 114 | ]), 115 | ) 116 | 117 | runtime.export_file( 118 | name = "test_tiktoken_tokenizer_model", 119 | src = "resources/test_tiktoken_tokenizer.model", 120 | visibility = ["@EXECUTORCH_CLIENTS", "//pytorch/tokenizers/..."], 121 | ) 122 | 123 | runtime.python_test( 124 | name = "test_tiktoken_py", 125 | srcs = [ 126 | "test_tiktoken.py", 127 | ], 128 | deps = [ 129 | "//pytorch/tokenizers/pytorch_tokenizers:tokenizers", 130 | ], 131 | resources = { 132 | ":test_tiktoken_tokenizer_model": "test_tiktoken_tokenizer.model", 133 | }, 134 | ) 135 | -------------------------------------------------------------------------------- /test/test_base64.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | #include "gtest/gtest.h" 11 | 12 | namespace tokenizers { 13 | 14 | TEST(Base64Test, TestDecodeSmoke) { 15 | std::string text = "bGxhbWE="; 16 | auto result = base64::decode(text); 17 | EXPECT_TRUE(result.ok()); 18 | EXPECT_EQ(result.get(), "llama"); 19 | } 20 | 21 | TEST(Base64Test, TestDecodeEmptyStringReturnsError) { 22 | std::string text; 23 | auto result = base64::decode(text); 24 | EXPECT_FALSE(result.ok()); 25 | EXPECT_EQ(result.error(), Error::Base64DecodeFailure); 26 | } 27 | 28 | TEST(Base64Test, TestInvalidStringDecodeReturnsError) { 29 | std::string text = "llama"; 30 | auto result = base64::decode(text); 31 | EXPECT_FALSE(result.ok()); 32 | EXPECT_EQ(result.error(), Error::Base64DecodeFailure); 33 | } 34 | 35 | } // namespace tokenizers 36 | -------------------------------------------------------------------------------- /test/test_llama2c_tokenizer.cpp: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | #ifdef TOKENIZERS_FB_BUCK 4 | #include 5 | #endif 6 | #include 7 | #include 8 | 9 | using namespace ::testing; 10 | 11 | namespace tokenizers { 12 | 13 | namespace { 14 | // Test case based on llama2.c tokenizer 15 | static inline std::string _get_resource_path(const std::string& name) { 16 | #ifdef TOKENIZERS_FB_BUCK 17 | return facebook::xplat::testing::getPathForTestResource( 18 | "test/resources/" + name); 19 | #else 20 | return std::getenv("RESOURCES_PATH") + std::string("/") + name; 21 | #endif 22 | } 23 | 24 | } // namespace 25 | 26 | class Llama2cTokenizerTest : public Test { 27 | public: 28 | void SetUp() override { 29 | tokenizer_ = std::make_unique(); 30 | modelPath_ = _get_resource_path("test_llama2c_tokenizer.bin"); 31 | } 32 | 33 | std::unique_ptr tokenizer_; 34 | std::string modelPath_; 35 | }; 36 | 37 | TEST_F(Llama2cTokenizerTest, EncodeWithoutLoadFails) { 38 | Result> res = tokenizer_->encode("hello world", 0, 0); 39 | EXPECT_EQ(res.error(), Error::Uninitialized); 40 | } 41 | 42 | TEST_F(Llama2cTokenizerTest, DecodeWithoutLoadFails) { 43 | auto result = tokenizer_->decode(0, 0); 44 | EXPECT_EQ(result.error(), Error::Uninitialized); 45 | } 46 | 47 | TEST_F(Llama2cTokenizerTest, DecodeOutOfRangeFails) { 48 | Error res = tokenizer_->load(modelPath_.c_str()); 49 | EXPECT_EQ(res, Error::Ok); 50 | auto result = tokenizer_->decode(0, 64000); 51 | // The vocab size is 32000, and token 64000 is out of vocab range. 52 | EXPECT_EQ(result.error(), Error::OutOfRange); 53 | } 54 | 55 | TEST_F(Llama2cTokenizerTest, TokenizerMetadataIsExpected) { 56 | Error res = tokenizer_->load(modelPath_.c_str()); 57 | EXPECT_EQ(res, Error::Ok); 58 | // test_bpe_tokenizer.bin has vocab_size 0, bos_id 0, eos_id 0 recorded. 59 | EXPECT_EQ(tokenizer_->vocab_size(), 0); 60 | EXPECT_EQ(tokenizer_->bos_tok(), 0); 61 | EXPECT_EQ(tokenizer_->eos_tok(), 0); 62 | } 63 | 64 | TEST_F(Llama2cTokenizerTest, SafeToDestruct) { 65 | // Safe to destruct initialized tokenizer. 66 | tokenizer_->load(modelPath_); 67 | tokenizer_.reset(); 68 | 69 | // Safe to destruct uninitialized tokenizer. 70 | tokenizer_ = std::make_unique(); 71 | tokenizer_.reset(); 72 | } 73 | 74 | } // namespace tokenizers 75 | -------------------------------------------------------------------------------- /test/test_pre_tokenizer.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | // Third Party 10 | #include 11 | #include 12 | #include 13 | 14 | // Local 15 | #include 16 | 17 | using json = nlohmann::json; 18 | using namespace tokenizers; 19 | 20 | // Helpers ///////////////////////////////////////////////////////////////////// 21 | 22 | static void assert_split_match( 23 | const PreTokenizer& ptok, 24 | const std::string& prompt, 25 | const std::vector& expected) { 26 | const auto& got = ptok.pre_tokenize(prompt); 27 | EXPECT_EQ(expected.size(), got.size()); 28 | for (auto i = 0; i < got.size(); ++i) { 29 | EXPECT_EQ(expected[i], got[i]); 30 | } 31 | } 32 | 33 | // RegexPreTokenizer /////////////////////////////////////////////////////////// 34 | class RegexPreTokenizerTest : public ::testing::Test {}; 35 | 36 | // Test the basic construction 37 | TEST_F(RegexPreTokenizerTest, Construct) { 38 | RegexPreTokenizer ptok("[0-9]+"); 39 | } 40 | 41 | // Test basic splitting using the expression for Tiktoken 42 | TEST_F(RegexPreTokenizerTest, TiktokenExpr) { 43 | RegexPreTokenizer ptok( 44 | R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"); 45 | assert_split_match( 46 | ptok, "How are you doing?", {"How", " are", " you", " doing", "?"}); 47 | } 48 | 49 | // DigitsPreTokenizer ////////////////////////////////////////////////////////// 50 | class DigitsPreTokenizerTest : public ::testing::Test {}; 51 | 52 | // Test digit splitting with individual digits 53 | TEST_F(DigitsPreTokenizerTest, IndividualDigits) { 54 | DigitsPreTokenizer ptok(true); 55 | assert_split_match( 56 | ptok, 57 | "The number 1 then 234 then 5.", 58 | {"The number ", "1", " then ", "2", "3", "4", " then ", "5", "."}); 59 | } 60 | 61 | // Test digit splitting with contiguous digits 62 | TEST_F(DigitsPreTokenizerTest, ContiguousDigits) { 63 | DigitsPreTokenizer ptok(false); 64 | assert_split_match( 65 | ptok, 66 | "The number 1 then 234 then 5.", 67 | {"The number ", "1", " then ", "234", " then ", "5", "."}); 68 | } 69 | 70 | // ByteLevelPreTokenizer /////////////////////////////////////////////////////// 71 | class ByteLevelPreTokenizerTest : public ::testing::Test {}; 72 | 73 | TEST_F(ByteLevelPreTokenizerTest, PreTokenizeDefault) { 74 | ByteLevelPreTokenizer ptok; 75 | assert_split_match(ptok, "Hello World", {"ĠHello", "ĠWorld"}); 76 | assert_split_match( 77 | ptok, 78 | "The number 1 then 234 then 5.", 79 | {"ĠThe", "Ġnumber", "Ġ1", "Ġthen", "Ġ234", "Ġthen", "Ġ5", "."}); 80 | } 81 | 82 | TEST_F(ByteLevelPreTokenizerTest, PreTokenizeNoPrefix) { 83 | ByteLevelPreTokenizer ptok(false); 84 | assert_split_match(ptok, "Hello World", {"Hello", "ĠWorld"}); 85 | } 86 | 87 | TEST_F(ByteLevelPreTokenizerTest, PreTokenizeCustomRegex) { 88 | ByteLevelPreTokenizer ptok(false, R"(o)"); 89 | assert_split_match(ptok, "Hello World", {"Hell", "o", "ĠW", "o", "rld"}); 90 | } 91 | 92 | // SequencePreTokenizer //////////////////////////////////////////////////////// 93 | class SequencePreTokenizerTest : public ::testing::Test {}; 94 | 95 | TEST_F(SequencePreTokenizerTest, PreTokenizeDigitAndByteLevel) { 96 | PreTokenizer::Ptr dptok(new DigitsPreTokenizer(true)); 97 | PreTokenizer::Ptr bptok(new ByteLevelPreTokenizer(false)); 98 | SequencePreTokenizer ptok({dptok, bptok}); 99 | assert_split_match( 100 | ptok, 101 | "The number 1 then 234 then 5.", 102 | {"The", 103 | "Ġnumber", 
104 | "Ġ", 105 | "1", 106 | "Ġthen", 107 | "Ġ", 108 | "2", 109 | "3", 110 | "4", 111 | "Ġthen", 112 | "Ġ", 113 | "5", 114 | "."}); 115 | } 116 | 117 | // PreTokenizerConfig ////////////////////////////////////////////////////////// 118 | // 119 | // NOTE: When adding a new pre-tokenizer or changing arguments, add it to these 120 | // tests! 121 | class PreTokenizerConfigTest : public ::testing::Test {}; 122 | 123 | TEST_F(PreTokenizerConfigTest, AllTypesSuccess) { 124 | // Regex 125 | PreTokenizerConfig("Split").set_pattern(R"(o)").create(); 126 | 127 | // Digits 128 | PreTokenizerConfig("Digits").create(); 129 | PreTokenizerConfig("Digits").set_individual_digits(true).create(); 130 | PreTokenizerConfig("Digits").set_individual_digits(false).create(); 131 | 132 | // ByteLevel 133 | PreTokenizerConfig("ByteLevel").create(); 134 | PreTokenizerConfig("ByteLevel").set_pattern(R"(o)").create(); 135 | PreTokenizerConfig("ByteLevel").set_add_prefix_space(true).create(); 136 | PreTokenizerConfig("ByteLevel") 137 | .set_add_prefix_space(false) 138 | .set_pattern(R"(o)") 139 | .create(); 140 | 141 | // Sequence 142 | PreTokenizerConfig("Sequence") 143 | .set_pretokenizers( 144 | {PreTokenizerConfig("Digits"), PreTokenizerConfig("ByteLevel")}) 145 | .create(); 146 | } 147 | 148 | TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) { 149 | // Regex 150 | EXPECT_THROW(PreTokenizerConfig("Split").create(), std::runtime_error); 151 | 152 | // Sequence 153 | EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error); 154 | EXPECT_THROW( 155 | PreTokenizerConfig("Sequence").set_pretokenizers({}).create(), 156 | std::runtime_error); 157 | EXPECT_THROW( 158 | PreTokenizerConfig("Sequence") 159 | .set_pretokenizers({PreTokenizerConfig("Split")}) 160 | .create(), 161 | std::runtime_error); 162 | 163 | // Unsupported 164 | EXPECT_THROW(PreTokenizerConfig("Unsupported").create(), std::runtime_error); 165 | } 166 | 167 | TEST_F(PreTokenizerConfigTest, ParseJson) { 168 | PreTokenizerConfig config; 169 | const auto ptok = config 170 | .parse_json(json{ 171 | {"type", "Sequence"}, 172 | {"pretokenizers", 173 | json{ 174 | json{ 175 | {"type", "Digits"}, 176 | {"individual_digits", true}, 177 | }, 178 | json{ 179 | {"type", "ByteLevel"}, 180 | {"add_prefix_space", false}, 181 | }, 182 | }}, 183 | }) 184 | .create(); 185 | assert_split_match( 186 | *ptok, 187 | "The number 1 then 234 then 5.", 188 | {"The", 189 | "Ġnumber", 190 | "Ġ", 191 | "1", 192 | "Ġthen", 193 | "Ġ", 194 | "2", 195 | "3", 196 | "4", 197 | "Ġthen", 198 | "Ġ", 199 | "5", 200 | "."}); 201 | } 202 | 203 | TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) { 204 | PreTokenizerConfig config; 205 | const auto ptok = config 206 | .parse_json(json{ 207 | {"type", "Digits"}, 208 | }) 209 | .create(); 210 | assert_split_match( 211 | *ptok, 212 | "The number 1 then 234 then 5.", 213 | {"The number ", "1", " then ", "234", " then ", "5", "."}); 214 | } 215 | 216 | TEST_F(PreTokenizerConfigTest, Split) { 217 | PreTokenizerConfig config; 218 | const auto ptok = 219 | config 220 | .parse_json(json{ 221 | {"type", "Split"}, 222 | {"pattern", 223 | {{"Regex", 224 | R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"}}}, 225 | }) 226 | .create(); 227 | assert_split_match(*ptok, "Hello World", {"Hello", " World"}); 228 | } 229 | -------------------------------------------------------------------------------- /test/test_regex.cpp: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include 10 | 11 | #include "pytorch/tokenizers/pcre2_regex.h" 12 | #include "pytorch/tokenizers/re2_regex.h" 13 | #include "pytorch/tokenizers/regex.h" 14 | 15 | using namespace tokenizers; 16 | 17 | class RegexTest : public ::testing::Test {}; 18 | 19 | // Test basic functionality 20 | TEST_F(RegexTest, BasicMatching) { 21 | auto regex = TK_UNWRAP_THROW(create_regex("\\w+")); 22 | 23 | std::string text = "Hello world"; 24 | auto matches = regex->find_all(text); 25 | ASSERT_EQ(matches.size(), 2); 26 | EXPECT_EQ(matches[0].start, 0); 27 | EXPECT_EQ(matches[0].end, 5); 28 | EXPECT_EQ( 29 | text.substr(matches[0].start, matches[0].end - matches[0].start), 30 | "Hello"); 31 | EXPECT_EQ(matches[1].start, 6); 32 | EXPECT_EQ(matches[1].end, 11); 33 | EXPECT_EQ( 34 | text.substr(matches[1].start, matches[1].end - matches[1].start), 35 | "world"); 36 | } 37 | 38 | // Test pattern that only PCRE2 supports (lookbehind) 39 | TEST_F(RegexTest, Pcre2Specific) { 40 | const std::string pattern = "(?<=@)\\w+"; 41 | 42 | // Verify that the factory function fallsback on a PCRE2 regex 43 | auto regex = TK_UNWRAP_THROW(create_regex(pattern)); 44 | EXPECT_NE(dynamic_cast(regex.get()), nullptr); 45 | 46 | std::string text = "user@example.com"; 47 | auto matches = regex->find_all(text); 48 | ASSERT_EQ(matches.size(), 1); 49 | EXPECT_EQ(matches[0].start, 5); 50 | EXPECT_EQ(matches[0].end, 12); 51 | EXPECT_EQ( 52 | text.substr(matches[0].start, matches[0].end - matches[0].start), 53 | "example"); 54 | } 55 | 56 | // Test complex pattern with negative lookahead that should fall back to PCRE2. 57 | // This specific pattern is from the Qwen2.5 1.5B pretokenizer. 58 | // https://huggingface.co/Qwen/Qwen2.5-1.5B/raw/main/tokenizer.json 59 | TEST_F(RegexTest, ComplexPatternWithNegativeLookahead) { 60 | const std::string complex_pattern = 61 | "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; 62 | 63 | // Now verify that the factory function fallsback on a PCRE2 regex 64 | auto regex = TK_UNWRAP_THROW(create_regex(complex_pattern)); 65 | EXPECT_NE(dynamic_cast(regex.get()), nullptr); 66 | 67 | // Test the pattern with some sample text 68 | std::string text = "Hello's world\n test"; 69 | auto matches = regex->find_all(text); 70 | 71 | // We expect to match: 72 | // 1. "Hello" (word) 73 | // 2. "'s" (contraction) 74 | // 3. " world" (word with leading space) 75 | // 4. "\n" (newline) 76 | // 5. " " (whitespace) 77 | // 6. 
" test" (word with leading space) 78 | ASSERT_EQ(matches.size(), 6); 79 | 80 | EXPECT_EQ(matches[0].start, 0); 81 | EXPECT_EQ(matches[0].end, 5); 82 | EXPECT_EQ( 83 | text.substr(matches[0].start, matches[0].end - matches[0].start), 84 | "Hello"); 85 | EXPECT_EQ(matches[1].start, 5); 86 | EXPECT_EQ(matches[1].end, 7); 87 | EXPECT_EQ( 88 | text.substr(matches[1].start, matches[1].end - matches[1].start), "'s"); 89 | EXPECT_EQ(matches[2].start, 7); 90 | EXPECT_EQ(matches[2].end, 13); 91 | EXPECT_EQ( 92 | text.substr(matches[2].start, matches[2].end - matches[2].start), 93 | " world"); 94 | EXPECT_EQ(matches[3].start, 13); 95 | EXPECT_EQ(matches[3].end, 14); 96 | EXPECT_EQ( 97 | text.substr(matches[3].start, matches[3].end - matches[3].start), "\n"); 98 | EXPECT_EQ(matches[4].start, 14); 99 | EXPECT_EQ(matches[4].end, 15); 100 | EXPECT_EQ( 101 | text.substr(matches[4].start, matches[4].end - matches[4].start), " "); 102 | EXPECT_EQ(matches[5].start, 15); 103 | EXPECT_EQ(matches[5].end, 20); 104 | EXPECT_EQ( 105 | text.substr(matches[5].start, matches[5].end - matches[5].start), 106 | " test"); 107 | } 108 | -------------------------------------------------------------------------------- /test/test_sentencepiece.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #include 11 | #include 12 | 13 | namespace tokenizers { 14 | 15 | namespace { 16 | static inline std::string _get_resource_path(const std::string& name) { 17 | return std::getenv("RESOURCES_PATH") + std::string("/") + name; 18 | } 19 | } // namespace 20 | 21 | TEST(SPTokenizerTest, TestEncodeWithoutLoad) { 22 | SPTokenizer tokenizer; 23 | std::string text = "Hello world!"; 24 | auto result = tokenizer.encode(text, /*bos*/ 0, /*eos*/ 1); 25 | EXPECT_EQ(result.error(), Error::Uninitialized); 26 | } 27 | 28 | TEST(SPTokenizerTest, TestDecodeWithoutLoad) { 29 | SPTokenizer tokenizer; 30 | auto result = tokenizer.decode(0, 0); 31 | EXPECT_EQ(result.error(), Error::Uninitialized); 32 | } 33 | 34 | TEST(SPTokenizerTest, TestLoad) { 35 | SPTokenizer tokenizer; 36 | auto path = _get_resource_path("test_sentencepiece.model"); 37 | auto error = tokenizer.load(path); 38 | EXPECT_EQ(error, Error::Ok); 39 | } 40 | 41 | TEST(SPTokenizerTest, TestLoadInvalidPath) { 42 | SPTokenizer tokenizer; 43 | auto error = tokenizer.load("invalid_path"); 44 | EXPECT_EQ(error, Error::LoadFailure); 45 | } 46 | 47 | TEST(SPTokenizerTest, TestEncode) { 48 | SPTokenizer tokenizer; 49 | auto path = _get_resource_path("test_sentencepiece.model"); 50 | auto error = tokenizer.load(path); 51 | EXPECT_EQ(error, Error::Ok); 52 | std::string text = "Hello world!"; 53 | auto result = tokenizer.encode(text, /*bos*/ 1, /*eos*/ 0); 54 | EXPECT_TRUE(result.ok()); 55 | EXPECT_EQ(result.get().size(), 4); 56 | EXPECT_EQ(result.get()[0], 1); 57 | EXPECT_EQ(result.get()[1], 15043); 58 | EXPECT_EQ(result.get()[2], 3186); 59 | EXPECT_EQ(result.get()[3], 29991); 60 | } 61 | 62 | TEST(SPTokenizerTest, TestDecode) { 63 | SPTokenizer tokenizer; 64 | auto path = _get_resource_path("test_sentencepiece.model"); 65 | auto error = tokenizer.load(path); 66 | EXPECT_EQ(error, Error::Ok); 67 | std::vector tokens = {1, 15043, 3186, 29991}; 68 | std::vector expected = {"", "Hello", " 
world", "!"}; 69 | for (auto i = 0; i < 3; ++i) { 70 | auto result = tokenizer.decode(tokens[i], tokens[i + 1]); 71 | EXPECT_TRUE(result.ok()); 72 | EXPECT_EQ(result.get(), expected[i + 1]); 73 | } 74 | } 75 | 76 | } // namespace tokenizers 77 | -------------------------------------------------------------------------------- /test/test_string_integer_map.cpp: -------------------------------------------------------------------------------- 1 | // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #if defined(__APPLE__) || defined(WIN32) || defined(__linux__) 14 | #define TEST_MEMORY_COMPARISON 1 15 | 16 | #if defined(__APPLE__) 17 | #include 18 | #else 19 | #include 20 | #endif 21 | #endif 22 | 23 | namespace { 24 | 25 | using namespace ::testing; 26 | using ::base64::decode; 27 | using ::tokenizers::Error; 28 | using ::tokenizers::Result; 29 | using ::tokenizers::detail::StringIntegerMap; 30 | using ::tokenizers::detail::StringIntegerMapTypeBuilder; 31 | using TokenizerMap = std::unordered_map; 32 | 33 | static inline std::string _get_resource_path(const std::string& name) { 34 | return std::getenv("RESOURCES_PATH") + std::string("/") + name; 35 | } 36 | 37 | } // namespace 38 | 39 | class StringIntegerMapTest : public Test { 40 | public: 41 | void SetUp() override { 42 | modelPath_ = std::getenv("RESOURCES_PATH") + 43 | std::string("/test_tiktoken_tokenizer.model"); 44 | } 45 | 46 | Result loadModel() { 47 | std::ifstream file(modelPath_); 48 | TK_CHECK_OR_RETURN_ERROR( 49 | file, 50 | ParseFailure, 51 | "failed to open encoder file: %s", 52 | modelPath_.c_str()); 53 | 54 | TokenizerMap model; 55 | for (std::string line; std::getline(file, line);) { 56 | if (line.empty()) { 57 | continue; 58 | } 59 | 60 | auto pos = line.find(' '); 61 | auto token = TK_UNWRAP(decode({line.data(), pos})); 62 | uint64_t rank = 0; 63 | try { 64 | rank = std::stoul(line.substr(pos + 1)); 65 | } catch (const std::exception&) { 66 | TK_CHECK_OR_RETURN_ERROR( 67 | false, ParseFailure, "invalid encoder rank: %s", line.c_str()); 68 | } 69 | model[token] = rank; 70 | } 71 | 72 | return model; 73 | } 74 | 75 | std::string modelPath_; 76 | }; 77 | 78 | #if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON 79 | 80 | class TrackingAllocatorBase { 81 | public: 82 | static void reset(); 83 | static std::size_t getSize(); 84 | 85 | protected: 86 | static void* allocate(std::size_t size); 87 | static void deallocate(void* ptr); 88 | 89 | static std::size_t size_; 90 | }; 91 | 92 | void TrackingAllocatorBase::reset() { 93 | size_ = 0; 94 | } 95 | 96 | std::size_t TrackingAllocatorBase::getSize() { 97 | return size_; 98 | } 99 | 100 | void* TrackingAllocatorBase::allocate(std::size_t size) { 101 | void* ptr = malloc(size); 102 | if (!ptr) { 103 | return nullptr; 104 | } 105 | 106 | #if defined(WIN32) 107 | size_ += _msize(ptr); 108 | #elif defined(__APPLE__) 109 | size_ += malloc_size(const_cast(ptr)); 110 | #else 111 | size_ += malloc_usable_size(ptr); 112 | #endif 113 | 114 | return ptr; 115 | } 116 | 117 | void TrackingAllocatorBase::deallocate(void* ptr) { 118 | if (!ptr) { 119 | return; 120 | } 121 | 122 | #if defined(WIN32) 123 | size_ -= _msize(ptr); 124 | #elif defined(__APPLE__) 125 | size_ -= malloc_size(const_cast(ptr)); 126 | #else 127 | size_ -= malloc_usable_size(ptr); 128 | #endif 129 | 130 | free(ptr); 131 | } 132 | 133 | std::size_t 
TrackingAllocatorBase::size_ = 0; 134 | 135 | template 136 | class TrackingAllocator : public TrackingAllocatorBase { 137 | public: 138 | using value_type = T; 139 | TrackingAllocator() noexcept = default; 140 | template 141 | explicit TrackingAllocator(TrackingAllocator const&) noexcept {} 142 | 143 | value_type* allocate(std::size_t count) { 144 | return static_cast( 145 | TrackingAllocatorBase::allocate(count * sizeof(value_type))); // NOLINT 146 | } 147 | 148 | void deallocate(value_type* ptr, std::size_t /*count*/) noexcept { 149 | TrackingAllocatorBase::deallocate(ptr); 150 | } 151 | }; 152 | 153 | template 154 | bool operator==( 155 | TrackingAllocator const&, 156 | TrackingAllocator const&) noexcept { 157 | return true; 158 | } 159 | 160 | template 161 | bool operator!=( 162 | TrackingAllocator const& lhs, 163 | TrackingAllocator const& rhs) noexcept { 164 | return !(lhs == rhs); 165 | } 166 | 167 | #endif 168 | 169 | TEST_F(StringIntegerMapTest, CreateFromModel) { 170 | const auto res = loadModel(); 171 | ASSERT_EQ(res.ok(), true); 172 | const auto& model = res.get(); 173 | StringIntegerMap map(model); 174 | 175 | for (const auto& [model_key, model_value] : model) { 176 | EXPECT_THAT(map.tryGetInteger(model_key), testing::Optional(model_value)) 177 | << model_key; 178 | EXPECT_THAT(map.tryGetString(model_value), testing::Optional(model_key)) 179 | << model_value; 180 | } 181 | 182 | EXPECT_FALSE(map.tryGetInteger("Ich weiß nicht")); 183 | EXPECT_FALSE(map.tryGetString(999999999)); 184 | 185 | EXPECT_EQ(map.size(), model.size()); 186 | std::unordered_set walked_strings; 187 | std::unordered_set walked_integers; 188 | 189 | for (std::size_t index = 0; index < map.size(); ++index) { 190 | const auto [str, integer] = map.getElement(index); 191 | EXPECT_TRUE(walked_strings.insert(str).second) << "str: " << str; 192 | EXPECT_TRUE(walked_integers.insert(integer).second) 193 | << "integer: " << integer; 194 | } 195 | } 196 | 197 | #if defined(TEST_MEMORY_COMPARISON) && TEST_MEMORY_COMPARISON 198 | 199 | TEST_F(StringIntegerMapTest, MemoryConsumptionComparison) { 200 | TrackingAllocatorBase::reset(); 201 | EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); 202 | 203 | const auto res = loadModel(); 204 | ASSERT_EQ(res.ok(), true); 205 | const auto& model = res.get(); 206 | 207 | std::size_t string_integer_map_size = 0; 208 | { 209 | typename StringIntegerMapTypeBuilder<>::WithAllocator< 210 | TrackingAllocator>::Map map(model); 211 | string_integer_map_size = TrackingAllocatorBase::getSize(); 212 | } 213 | 214 | EXPECT_EQ(TrackingAllocatorBase::getSize(), 0); 215 | 216 | std::size_t unordered_map_size = 0; 217 | { 218 | std::unordered_map< 219 | std::string, 220 | std::uint64_t, 221 | std::hash, 222 | std::equal_to, 223 | TrackingAllocator>> 224 | strings_to_ints; 225 | std::unordered_map< 226 | std::uint64_t, 227 | std::string, 228 | std::hash, 229 | std::equal_to, 230 | TrackingAllocator>> 231 | ints_to_strings; 232 | for (const auto& [k, v] : model) { 233 | strings_to_ints.emplace(k, v); 234 | ints_to_strings.emplace(v, k); 235 | } 236 | 237 | unordered_map_size = TrackingAllocatorBase::getSize(); 238 | } 239 | 240 | EXPECT_LT(string_integer_map_size, unordered_map_size); 241 | 242 | #if 1 243 | std::cout << "string integer map size = " << string_integer_map_size 244 | << std::endl; 245 | std::cout << "unordered map size = " << unordered_map_size << std::endl; 246 | #endif 247 | } 248 | 249 | #endif 250 | 251 | template 252 | struct FixedHash { 253 | std::size_t operator()(const 
std::string_view& str) const { 254 | if (str.empty()) { 255 | return hash_offset; 256 | } 257 | 258 | return str.size() - 1 + hash_offset; 259 | } 260 | 261 | std::size_t operator()(std::uint64_t value) const { 262 | if (value == 0) { 263 | return hash_offset; 264 | } 265 | 266 | return static_cast(std::log10(value)) + hash_offset; 267 | } 268 | }; 269 | 270 | template 271 | class StringIntegerMapHashTest : public Test { 272 | public: 273 | using Container = typename StringIntegerMapTypeBuilder<>::WithIntegerHash< 274 | THash>::template WithStringHash::Map; 275 | }; 276 | 277 | using StringIntegerMapHashTestTypes = 278 | ::testing::Types, FixedHash<1>, FixedHash<2>, FixedHash<3>>; 279 | TYPED_TEST_SUITE(StringIntegerMapHashTest, StringIntegerMapHashTestTypes); 280 | 281 | TYPED_TEST(StringIntegerMapHashTest, HashCollisions) { 282 | std::unordered_map source = { 283 | {"a", 0}, 284 | {"b", 1}, 285 | {"c", 2}, 286 | {"d", 3}, 287 | }; 288 | 289 | typename TestFixture::Container map(source); 290 | 291 | // 292 | // Check that the strings exist in the map. 293 | // 294 | 295 | EXPECT_THAT(map.tryGetInteger("a"), Optional(0ull)); 296 | EXPECT_THAT(map.tryGetInteger("b"), Optional(1ull)); 297 | EXPECT_THAT(map.tryGetInteger("c"), Optional(2ull)); 298 | EXPECT_THAT(map.tryGetInteger("d"), Optional(3ull)); 299 | 300 | EXPECT_FALSE(map.tryGetInteger("e")); 301 | 302 | // 303 | // Check that the integers exist in the map. 304 | // 305 | 306 | EXPECT_THAT(map.tryGetString(0), Optional(std::string_view("a"))); 307 | EXPECT_THAT(map.tryGetString(1), Optional(std::string_view("b"))); 308 | EXPECT_THAT(map.tryGetString(2), Optional(std::string_view("c"))); 309 | EXPECT_THAT(map.tryGetString(3), Optional(std::string_view("d"))); 310 | 311 | EXPECT_FALSE(map.tryGetString(4)); 312 | 313 | // 314 | // Test a lookup into the next bucket (which should be empty). 315 | // 316 | 317 | EXPECT_FALSE(map.tryGetInteger("aa")); 318 | EXPECT_FALSE(map.tryGetInteger("aaa")); 319 | EXPECT_FALSE(map.tryGetInteger("aaaa")); 320 | 321 | EXPECT_FALSE(map.tryGetString(10)); 322 | EXPECT_FALSE(map.tryGetString(100)); 323 | EXPECT_FALSE(map.tryGetString(1000)); 324 | } 325 | -------------------------------------------------------------------------------- /test/test_tiktoken.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | // @lint-ignore-every LICENSELINT 9 | 10 | #include 11 | #include 12 | 13 | using namespace ::testing; 14 | 15 | namespace tokenizers { 16 | 17 | namespace { 18 | // Test case based on Llama 2 19 | const std::string kPattern = 20 | R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; 21 | static constexpr int32_t kSpecialTokensSize = 256; 22 | static inline std::unique_ptr> _get_special_tokens() { 23 | auto special_tokens = 24 | std::make_unique>(std::vector{ 25 | "<|begin_of_text|>", 26 | "<|end_of_text|>", 27 | "<|reserved_special_token_0|>", 28 | "<|reserved_special_token_1|>", 29 | "<|reserved_special_token_2|>", 30 | "<|reserved_special_token_3|>", 31 | "<|start_header_id|>", 32 | "<|end_header_id|>", 33 | "<|reserved_special_token_4|>", 34 | "<|eot_id|>"}); 35 | 36 | // pad the rest of the special tokens with reserved tokens 37 | ssize_t reserved_special_token_num = 5; 38 | while (special_tokens->size() < kSpecialTokensSize) { 39 | special_tokens->emplace_back( 40 | "<|reserved_special_token_" + 41 | std::to_string(reserved_special_token_num++) + "|>"); 42 | } 43 | return special_tokens; 44 | } 45 | 46 | static inline std::string _get_resource_path(const std::string& name) { 47 | return std::getenv("RESOURCES_PATH") + std::string("/") + name; 48 | } 49 | 50 | } // namespace 51 | 52 | class TiktokenTest : public Test { 53 | public: 54 | void SetUp() override { 55 | tokenizer_ = 56 | std::make_unique(kPattern, _get_special_tokens(), 0, 1); 57 | modelPath_ = _get_resource_path("test_tiktoken_tokenizer.model"); 58 | } 59 | 60 | std::unique_ptr tokenizer_; 61 | std::string modelPath_; 62 | }; 63 | 64 | TEST_F(TiktokenTest, TestEncodeWithoutLoad) { 65 | Tiktoken tokenizer; 66 | std::string text = "Hello world!"; 67 | auto result = tokenizer.encode(text, /*bos*/ 0, /*eos*/ 1); 68 | EXPECT_EQ(result.error(), Error::Uninitialized); 69 | } 70 | 71 | TEST_F(TiktokenTest, TestDecodeWithoutLoad) { 72 | Tiktoken tokenizer; 73 | auto result = tokenizer.decode(0, 0); 74 | EXPECT_EQ(result.error(), Error::Uninitialized); 75 | } 76 | 77 | TEST_F(TiktokenTest, TokenizerVocabSizeIsExpected) { 78 | Error res = tokenizer_->load(modelPath_.c_str()); 79 | EXPECT_EQ(res, Error::Ok); 80 | EXPECT_EQ(tokenizer_->vocab_size(), 128256); 81 | EXPECT_EQ(tokenizer_->bos_tok(), 128000); 82 | EXPECT_EQ(tokenizer_->eos_tok(), 128001); 83 | } 84 | 85 | TEST_F(TiktokenTest, TestEncode) { 86 | Error res = tokenizer_->load(modelPath_.c_str()); 87 | EXPECT_EQ(res, Error::Ok); 88 | Result> out = tokenizer_->encode("hello world", 1, 0); 89 | EXPECT_EQ(out.error(), Error::Ok); 90 | EXPECT_EQ(out.get().size(), 3); 91 | EXPECT_EQ(out.get()[0], 128000); 92 | EXPECT_EQ(out.get()[1], 15339); 93 | EXPECT_EQ(out.get()[2], 1917); 94 | } 95 | 96 | TEST_F(TiktokenTest, TestDecode) { 97 | Error res = tokenizer_->load(modelPath_.c_str()); 98 | EXPECT_EQ(res, Error::Ok); 99 | std::vector expected = {"<|begin_of_text|>", "hello", " world"}; 100 | std::vector tokens = {128000, 15339, 1917}; 101 | for (size_t i = 0; i < tokens.size(); i++) { 102 | Result out = tokenizer_->decode(0, tokens[i]); 103 | EXPECT_EQ(out.error(), Error::Ok); 104 | EXPECT_EQ(out.get(), expected[i]); 105 | } 106 | } 107 | 108 | TEST_F(TiktokenTest, TokenizerDecodeOutOfRangeFails) { 109 | Error res = tokenizer_->load(modelPath_.c_str()); 110 | EXPECT_EQ(res, Error::Ok); 111 | // The vocab size is 128256, addes 256 just so the token is out of vocab 112 | // range. 
113 | Result out = tokenizer_->decode(0, 128256 + 256); 114 | EXPECT_EQ(out.error(), Error::DecodeFailure); 115 | } 116 | 117 | TEST_F(TiktokenTest, ConstructionWithInvalidBOSIndex) { 118 | // gtest death test doesn't work on iOS: 119 | // https://github.com/google/googletest/issues/2834 120 | #if !GTEST_OS_IOS 121 | EXPECT_EXIT( 122 | std::make_unique( 123 | std::make_unique>( 124 | std::vector{"<|end_of_text|>"}), 125 | 1, 126 | 0), 127 | ::testing::KilledBySignal(SIGABRT), 128 | ""); 129 | #endif 130 | } 131 | 132 | TEST_F(TiktokenTest, ConstructionWithInvalidEOSIndex) { 133 | // gtest death test doesn't work on iOS: 134 | // https://github.com/google/googletest/issues/2834 135 | #if !GTEST_OS_IOS 136 | EXPECT_EXIT( 137 | std::make_unique( 138 | std::make_unique>( 139 | std::vector{"<|begin_of_text|>"}), 140 | 0, 141 | 1), 142 | ::testing::KilledBySignal(SIGABRT), 143 | ""); 144 | #endif 145 | } 146 | 147 | TEST_F(TiktokenTest, TestLoadInvalidPath) { 148 | Tiktoken tokenizer; 149 | auto error = tokenizer.load("invalid_path"); 150 | EXPECT_EQ(error, Error::LoadFailure); 151 | } 152 | 153 | TEST_F(TiktokenTest, LoadTiktokenFileWithInvalidRank) { 154 | auto invalidModelPath = 155 | _get_resource_path("test_tiktoken_invalid_rank.model"); 156 | Error res = tokenizer_->load(invalidModelPath.c_str()); 157 | 158 | EXPECT_EQ(res, Error::EncodeFailure); 159 | } 160 | 161 | TEST_F(TiktokenTest, LoadTiktokenFileWithInvalidBase64) { 162 | auto invalidModelPath = 163 | _get_resource_path("test_tiktoken_invalid_base64.model"); 164 | Error res = tokenizer_->load(invalidModelPath.c_str()); 165 | 166 | EXPECT_EQ(res, Error::Base64DecodeFailure); 167 | } 168 | 169 | TEST_F(TiktokenTest, LoadTiktokenFileWithNoSpace) { 170 | auto invalidModelPath = _get_resource_path("test_tiktoken_no_space.model"); 171 | Error res = tokenizer_->load(invalidModelPath.c_str()); 172 | 173 | EXPECT_EQ(res, Error::ParseFailure); 174 | } 175 | 176 | TEST_F(TiktokenTest, LoadTiktokenFileWithBPEFile) { 177 | auto invalidModelPath = _get_resource_path("test_bpe_tokenizer.bin"); 178 | Error res = tokenizer_->load(invalidModelPath.c_str()); 179 | 180 | EXPECT_EQ(res, Error::ParseFailure); 181 | } 182 | } // namespace tokenizers 183 | -------------------------------------------------------------------------------- /test/test_tiktoken.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # @lint-ignore-every LICENSELINT 7 | 8 | import unittest 9 | 10 | import pkg_resources 11 | 12 | from pytorch_tokenizers.tiktoken import TiktokenTokenizer 13 | 14 | 15 | class TestTiktokenTokenizer(unittest.TestCase): 16 | def test_default(self): 17 | model_path = pkg_resources.resource_filename( 18 | "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model" 19 | ) 20 | tiktoken = TiktokenTokenizer(model_path) 21 | s = "<|begin_of_text|> hellow world." 
22 | self.assertEqual(s, tiktoken.decode(tiktoken.encode(s, bos=False, eos=False))) 23 | 24 | def test_custom_pattern_and_special_tokens(self): 25 | o220k_pattern = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 26 | model_path = pkg_resources.resource_filename( 27 | "pytorch.tokenizers.test", "test_tiktoken_tokenizer.model" 28 | ) 29 | tiktoken = TiktokenTokenizer( 30 | model_path, 31 | pat_str=o220k_pattern, 32 | special_tokens=[ 33 | "<|begin_of_text|>", 34 | "<|end_of_text|>", 35 | "<|custom_token|>", 36 | ], 37 | ) 38 | custom_token_id = tiktoken.special_tokens["<|custom_token|>"] 39 | 40 | s = "<|begin_of_text|> hellow world, this is a custom token: <|custom_token|>." 41 | encoding = tiktoken.encode( 42 | s, 43 | bos=False, 44 | eos=False, 45 | allowed_special="all", 46 | ) 47 | self.assertTrue(custom_token_id in encoding) 48 | self.assertEqual(s, tiktoken.decode(encoding)) 49 | -------------------------------------------------------------------------------- /third-party/TARGETS: -------------------------------------------------------------------------------- 1 | # Any targets that should be shared between fbcode and xplat must be defined in 2 | # targets.bzl. This file can contain fbcode-only targets. 3 | 4 | load(":targets.bzl", "define_common_targets") 5 | 6 | oncall("executorch") 7 | 8 | define_common_targets() 9 | -------------------------------------------------------------------------------- /third-party/llama.cpp-unicode/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | 3 | project(unicode LANGUAGES CXX) 4 | 5 | add_library(unicode STATIC 6 | src/unicode.cpp 7 | src/unicode-data.cpp 8 | ) 9 | 10 | target_include_directories(unicode PUBLIC 11 | ${CMAKE_CURRENT_SOURCE_DIR}/include 12 | ) 13 | 14 | target_compile_features(unicode PUBLIC cxx_std_17) 15 | 16 | install(TARGETS unicode 17 | ARCHIVE DESTINATION lib 18 | ) 19 | 20 | install(DIRECTORY include/ DESTINATION include) 21 | -------------------------------------------------------------------------------- /third-party/llama.cpp-unicode/README.md: -------------------------------------------------------------------------------- 1 | # llama.cpp Unicode 2 | 3 | This is a vendored copy of the `unicode.h` and `unicode-data.h` modules from [llama.cpp](https://github.com/ggerganov/llama.cpp), along with their corresponding source files. The modules are held as vendored source rather than submodules since they are a small subset of the overall `llama.cpp` project. 
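As a quick orientation, here is a minimal sketch of how the vendored helpers can be called. The signatures come from include/unicode.h (reproduced further below); the sample regex is only illustrative, and the suggestion that this is the fallback path used when RE2/PCRE2 are unavailable is an assumption about how the wrappers wire it up, not something documented here.

#include <cstdint>
#include <string>
#include <vector>
#include "unicode.h"

int main() {
  // Decode UTF-8 into Unicode codepoints and inspect their categories.
  std::vector<uint32_t> cpts = unicode_cpts_from_utf8("héllo 123");
  for (uint32_t cp : cpts) {
    codepoint_flags flags = unicode_cpt_flags(cp);
    if (flags.is_number) {
      // cp matches \p{N}
    }
  }

  // Split text with a GPT-2-style pre-tokenization regex (illustrative pattern).
  std::vector<std::string> pieces = unicode_regex_split(
      "Hello world 123",
      {"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+"});
  return pieces.empty() ? 1 : 0;
}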
4 | 5 | ## Latest Update 6 | 7 | llama.cpp - commit 54ef9cfc 8 | https://github.com/ggerganov/llama.cpp -------------------------------------------------------------------------------- /third-party/llama.cpp-unicode/include/unicode-data.h: -------------------------------------------------------------------------------- 1 | /* 2 | llama.cpp - commit 54ef9cfc 3 | https://github.com/ggerganov/llama.cpp 4 | 5 | MIT License 6 | 7 | Copyright (c) 2023-2024 The ggml authors 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | */ 27 | 28 | #pragma once 29 | 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | 36 | struct range_nfd { 37 | uint32_t first; 38 | uint32_t last; 39 | uint32_t nfd; 40 | }; 41 | 42 | static const uint32_t MAX_CODEPOINTS = 0x110000; 43 | 44 | extern const std::initializer_list> 45 | unicode_ranges_flags; 46 | 47 | constexpr std::array unicode_set_whitespace = { 48 | 0x000009, 0x00000A, 0x00000B, 0x00000C, 0x00000D, 0x000020, 0x000085, 49 | 0x0000A0, 0x001680, 0x002000, 0x002001, 0x002002, 0x002003, 0x002004, 50 | 0x002005, 0x002006, 0x002007, 0x002008, 0x002009, 0x00200A, 0x002028, 51 | 0x002029, 0x00202F, 0x00205F, 0x003000, 52 | }; 53 | 54 | extern const std::initializer_list> 55 | unicode_map_lowercase; 56 | extern const std::initializer_list> 57 | unicode_map_uppercase; 58 | extern const std::initializer_list unicode_ranges_nfd; 59 | -------------------------------------------------------------------------------- /third-party/llama.cpp-unicode/include/unicode.h: -------------------------------------------------------------------------------- 1 | /* 2 | llama.cpp - commit 54ef9cfc 3 | https://github.com/ggerganov/llama.cpp 4 | 5 | MIT License 6 | 7 | Copyright (c) 2023-2024 The ggml authors 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 
18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | */ 27 | 28 | #pragma once 29 | 30 | #include 31 | #include 32 | #include 33 | 34 | // TODO: prefix all symbols with "llama_" 35 | 36 | struct codepoint_flags { 37 | enum { 38 | UNDEFINED = 0x0001, 39 | NUMBER = 0x0002, // regex: \p{N} 40 | LETTER = 0x0004, // regex: \p{L} 41 | SEPARATOR = 0x0008, // regex: \p{Z} 42 | ACCENT_MARK = 0x0010, // regex: \p{M} 43 | PUNCTUATION = 0x0020, // regex: \p{P} 44 | SYMBOL = 0x0040, // regex: \p{S} 45 | CONTROL = 0x0080, // regex: \p{C} 46 | MASK_CATEGORIES = 0x00FF, 47 | }; 48 | 49 | // codepoint type 50 | uint16_t is_undefined : 1; 51 | uint16_t is_number : 1; // regex: \p{N} 52 | uint16_t is_letter : 1; // regex: \p{L} 53 | uint16_t is_separator : 1; // regex: \p{Z} 54 | uint16_t is_accent_mark : 1; // regex: \p{M} 55 | uint16_t is_punctuation : 1; // regex: \p{P} 56 | uint16_t is_symbol : 1; // regex: \p{S} 57 | uint16_t is_control : 1; // regex: \p{C} 58 | // helper flags 59 | uint16_t is_whitespace : 1; // regex: \s 60 | uint16_t is_lowercase : 1; 61 | uint16_t is_uppercase : 1; 62 | uint16_t is_nfd : 1; 63 | 64 | // decode from uint16 65 | inline codepoint_flags(const uint16_t flags = 0) { 66 | *reinterpret_cast(this) = flags; 67 | } 68 | 69 | inline uint16_t as_uint() const { 70 | return *reinterpret_cast(this); 71 | } 72 | 73 | inline uint16_t category_flag() const { 74 | return this->as_uint() & MASK_CATEGORIES; 75 | } 76 | }; 77 | 78 | size_t unicode_len_utf8(char src); 79 | 80 | std::string unicode_cpt_to_utf8(uint32_t cp); 81 | uint32_t unicode_cpt_from_utf8(const std::string &utf8, size_t &offset); 82 | std::vector unicode_cpts_from_utf8(const std::string &utf8); 83 | 84 | std::vector 85 | unicode_cpts_normalize_nfd(const std::vector &cpts); 86 | 87 | codepoint_flags unicode_cpt_flags(const uint32_t cp); 88 | codepoint_flags unicode_cpt_flags(const std::string &utf8); 89 | 90 | std::string unicode_byte_to_utf8(uint8_t byte); 91 | uint8_t unicode_utf8_to_byte(const std::string &utf8); 92 | 93 | uint32_t unicode_tolower(uint32_t cp); 94 | 95 | std::vector 96 | unicode_regex_split(const std::string &text, 97 | const std::vector ®ex_exprs); 98 | -------------------------------------------------------------------------------- /third-party/targets.bzl: -------------------------------------------------------------------------------- 1 | load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") 2 | load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob") 3 | 4 | def define_common_targets(): 5 | runtime.cxx_library( 6 | name = "unicode", 7 | srcs = [ 8 | "llama.cpp-unicode/src/unicode.cpp", 9 | "llama.cpp-unicode/src/unicode-data.cpp", 10 | ], 11 | header_namespace = "", 12 | exported_headers = subdir_glob([ 13 | ("llama.cpp-unicode/include", "*.h"), 14 | ]), 15 | visibility = ["@EXECUTORCH_CLIENTS", "//pytorch/tokenizers/..."], 16 | ) 17 | 18 | if runtime.is_oss: 19 | runtime.cxx_library( 20 | name = "abseil", 21 | srcs = glob( 22 | ["abseil-cpp/absl/**/*.cc"], 23 | exclude = [ 24 | "abseil-cpp/absl/**/*test*.cc", 25 
| "abseil-cpp/absl/**/*mock*.cc", 26 | "abseil-cpp/absl/**/*matchers*.cc", 27 | "abseil-cpp/absl/**/*benchmark*.cc", 28 | ], 29 | ), 30 | _is_external_target = True, 31 | exported_linker_flags = select( 32 | { 33 | "DEFAULT": [], 34 | "ovr_config//os:macos": ["-Wl,-framework,CoreFoundation"], 35 | }, 36 | ), 37 | public_include_directories = ["abseil-cpp"], 38 | visibility = ["PUBLIC"], 39 | ) 40 | 41 | runtime.cxx_library( 42 | name = "re2", 43 | srcs = glob( 44 | [ 45 | "re2/re2/**/*.cc", 46 | "re2/util/**/*.cc", 47 | ], 48 | exclude = [ 49 | "re2/re2/**/*test*.cc", 50 | "re2/re2/testing/*.cc", 51 | "re2/re2/fuzzing/*.cc", 52 | "re2/re2/**/*benchmark*.cc", 53 | ], 54 | ), 55 | _is_external_target = True, 56 | public_include_directories = ["re2"], 57 | visibility = ["PUBLIC"], 58 | exported_deps = [ 59 | ":abseil", 60 | ], 61 | ) 62 | 63 | runtime.genrule( 64 | name = "config_h_generic", 65 | srcs = ["pcre2/src/config.h.generic"], 66 | cmd = "cp $SRCS $OUT", 67 | out = "pcre2/src/config.h", 68 | ) 69 | runtime.genrule( 70 | name = "pcre2_h_generic", 71 | srcs = ["pcre2/src/pcre2.h.generic"], 72 | cmd = "cp $SRCS $OUT", 73 | out = "pcre2/src/pcre2.h", 74 | ) 75 | runtime.genrule( 76 | name = "pcre2_chartables_c", 77 | srcs = ["pcre2/src/pcre2_chartables.c.dist"], 78 | cmd = "cp $SRCS $OUT", 79 | out = "pcre2/src/pcre2_chartables.c", 80 | ) 81 | runtime.cxx_library( 82 | name = "pcre2", 83 | srcs = [ 84 | "pcre2/src/pcre2_auto_possess.c", 85 | "pcre2/src/pcre2_chkdint.c", 86 | "pcre2/src/pcre2_compile.c", 87 | "pcre2/src/pcre2_compile_cgroup.c", 88 | "pcre2/src/pcre2_compile_class.c", 89 | "pcre2/src/pcre2_config.c", 90 | "pcre2/src/pcre2_context.c", 91 | "pcre2/src/pcre2_convert.c", 92 | "pcre2/src/pcre2_dfa_match.c", 93 | "pcre2/src/pcre2_error.c", 94 | "pcre2/src/pcre2_extuni.c", 95 | "pcre2/src/pcre2_find_bracket.c", 96 | "pcre2/src/pcre2_jit_compile.c", 97 | "pcre2/src/pcre2_maketables.c", 98 | "pcre2/src/pcre2_match.c", 99 | "pcre2/src/pcre2_match_data.c", 100 | "pcre2/src/pcre2_match_next.c", 101 | "pcre2/src/pcre2_newline.c", 102 | "pcre2/src/pcre2_ord2utf.c", 103 | "pcre2/src/pcre2_pattern_info.c", 104 | "pcre2/src/pcre2_script_run.c", 105 | "pcre2/src/pcre2_serialize.c", 106 | "pcre2/src/pcre2_string_utils.c", 107 | "pcre2/src/pcre2_study.c", 108 | "pcre2/src/pcre2_substitute.c", 109 | "pcre2/src/pcre2_substring.c", 110 | "pcre2/src/pcre2_tables.c", 111 | "pcre2/src/pcre2_ucd.c", 112 | "pcre2/src/pcre2_valid_utf.c", 113 | "pcre2/src/pcre2_xclass.c", 114 | ":pcre2_chartables_c", 115 | ], 116 | exported_headers = {"pcre2.h": ":pcre2_h_generic"}, 117 | headers = {"config.h": ":config_h_generic"}, 118 | # Preprocessor flags from https://github.com/PCRE2Project/pcre2/blob/2e03e323339ab692640626f02f8d8d6f95bff9c6/BUILD.bazel#L23. 119 | preprocessor_flags = [ 120 | "-DHAVE_CONFIG_H", 121 | "-DHAVE_MEMMOVE", 122 | "-DHAVE_STRERROR", 123 | "-DPCRE2_CODE_UNIT_WIDTH=8", 124 | "-DPCRE2_STATIC", 125 | "-DSUPPORT_PCRE2_8", 126 | "-DSUPPORT_UNICODE", 127 | ] + select({ 128 | "DEFAULT": ["-DHAVE_UNISTD_H"], 129 | "ovr_config//os:windows": [], 130 | }), 131 | header_namespace = "", 132 | _is_external_target = True, 133 | include_directories = ["pcre2/src"], 134 | visibility = ["PUBLIC"], 135 | ) 136 | --------------------------------------------------------------------------------