├── .formatter.exs ├── .github └── workflows │ ├── ci.yml │ ├── gh-pages.yml │ ├── release.yml │ └── rust-ci.yml ├── .gitignore ├── CHANGELOG.md ├── CODEOWNERS.md ├── LICENSE ├── README.md ├── RELEASE.md ├── flake.lock ├── flake.nix ├── lib ├── tokenizers.ex └── tokenizers │ ├── added_token.ex │ ├── decode_stream.ex │ ├── decoder.ex │ ├── encoding.ex │ ├── encoding │ └── transformation.ex │ ├── http_client.ex │ ├── model.ex │ ├── model │ ├── bpe.ex │ ├── unigram.ex │ ├── wordlevel.ex │ └── wordpiece.ex │ ├── native.ex │ ├── normalizer.ex │ ├── post_processor.ex │ ├── pre_tokenizer.ex │ ├── shared.ex │ ├── tokenizer.ex │ └── trainer.ex ├── mix.exs ├── mix.lock ├── native └── ex_tokenizers │ ├── .cargo │ └── config.toml │ ├── .gitignore │ ├── Cargo.lock │ ├── Cargo.toml │ ├── Cross.toml │ ├── README.md │ └── src │ ├── added_token.rs │ ├── decode_stream.rs │ ├── decoders.rs │ ├── encoding.rs │ ├── error.rs │ ├── lib.rs │ ├── models.rs │ ├── normalizers.rs │ ├── post_processors.rs │ ├── pre_tokenizers.rs │ ├── tokenizer.rs │ ├── trainers.rs │ └── util.rs ├── notebooks ├── pretrained.livemd └── training.livemd └── test ├── fixtures ├── bert-base-cased.json ├── merges.txt ├── vocab.json └── vocab.txt ├── test_helper.exs └── tokenizers ├── added_token_test.exs ├── decode_stream_test.exs ├── decoder_test.exs ├── model ├── bpe_test.exs ├── unigram.exs ├── wordlevel_test.exs └── wordpiece_test.exs ├── model_test.exs ├── normalizer_test.exs ├── post_processor_test.exs ├── pre_tokenizer_test.exs ├── tokenizer_test.exs └── trainer_test.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | workflow_dispatch: 8 | env: 9 | MIX_ENV: test 10 | TOKENIZERS_BUILD: "true" 11 | jobs: 12 | main: 13 | runs-on: ubuntu-latest 14 | name: "Test (${{ matrix.elixir_version }}, ${{ matrix.otp_version }})" 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | include: 19 | - elixir_version: 1.18.3 20 | otp_version: 27.3.3 21 | lint: true 22 | - elixir_version: 1.13.4 23 | otp_version: 24.3.4 24 | steps: 25 | - uses: actions/checkout@v4 26 | - uses: actions/cache@v4 27 | with: 28 | path: | 29 | deps 30 | _build 31 | key: ${{ runner.os }}-mix-${{ matrix.elixir_version }}-${{matrix.otp_version}}-${{ hashFiles('**/mix.lock') }} 32 | restore-keys: | 33 | ${{ runner.os }}-mix-${{ matrix.elixir_version }}-${{matrix.otp_version}}- 34 | - uses: actions/cache@v4 35 | with: 36 | path: | 37 | ~/.cargo/bin/ 38 | ~/.cargo/registry/index/ 39 | ~/.cargo/registry/cache/ 40 | ~/.cargo/git/db/ 41 | native/ex_tokenizers/target/ 42 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 43 | 44 | - name: Install Rust toolchain 45 | uses: dtolnay/rust-toolchain@stable 46 | 47 | - uses: erlef/setup-beam@v1 48 | with: 49 | otp-version: ${{matrix.otp_version}} 50 | elixir-version: ${{matrix.elixir_version}} 51 | - run: mix deps.get 52 | - run: mix format --check-formatted 53 | if: ${{ matrix.lint }} 54 | - run: mix deps.unlock --check-unused 55 | if: ${{ matrix.lint }} 56 | - run: mix deps.compile 57 | - run: mix compile --warnings-as-errors 58 | if: ${{ matrix.lint }} 59 | - run: mix test 60 | 
-------------------------------------------------------------------------------- /.github/workflows/gh-pages.yml: -------------------------------------------------------------------------------- 1 | name: Docs on GitHub pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | concurrency: 9 | group: ${{ github.ref }}-docs 10 | cancel-in-progress: true 11 | 12 | env: 13 | ELIXIR_VERSION: 1.18.3 14 | OTP_VERSION: 27.3.3 15 | TOKENIZERS_BUILD: "true" 16 | 17 | jobs: 18 | deploy: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - uses: actions/cache@v4 24 | with: 25 | path: | 26 | deps 27 | _build 28 | key: ${{ runner.os }}-mix-${{ hashFiles('**/mix.lock') }} 29 | restore-keys: | 30 | ${{ runner.os }}-mix- 31 | 32 | - uses: actions/cache@v4 33 | with: 34 | path: | 35 | ~/.cargo/bin/ 36 | ~/.cargo/registry/index/ 37 | ~/.cargo/registry/cache/ 38 | ~/.cargo/git/db/ 39 | native/ex_tokenizers/target/ 40 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 41 | 42 | - name: Install Rust toolchain 43 | uses: dtolnay/rust-toolchain@stable 44 | 45 | - uses: erlef/setup-beam@v1 46 | with: 47 | otp-version: "${{ env.OTP_VERSION }}" 48 | elixir-version: "${{ env.ELIXIR_VERSION }}" 49 | - run: mix deps.get 50 | - run: mix docs 51 | - name: Deploy 🚀 52 | uses: JamesIves/github-pages-deploy-action@4.1.1 53 | with: 54 | branch: gh-pages 55 | folder: doc 56 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build precompiled NIFs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | # Just run on main branch if "native" path changed. Tags will always run. 9 | - "native/**" 10 | # Also run in case there is any changes to this file. 
11 | - ".github/workflows/release.yml" 12 | tags: 13 | - "*" 14 | 15 | jobs: 16 | build_release: 17 | name: NIF ${{ matrix.nif }} - ${{ matrix.job.target }} (${{ matrix.job.os }}) 18 | runs-on: ${{ matrix.job.os }} 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | nif: ["2.15", "2.16"] 23 | job: 24 | - { target: aarch64-apple-darwin, os: macos-13 } 25 | - { target: aarch64-unknown-linux-gnu, os: ubuntu-22.04, use-cross: true } 26 | - { target: aarch64-unknown-linux-musl, os: ubuntu-22.04, use-cross: true } 27 | - { target: arm-unknown-linux-gnueabihf, os: ubuntu-22.04, use-cross: true } 28 | - { target: riscv64gc-unknown-linux-gnu, os: ubuntu-22.04, use-cross: true } 29 | - { target: x86_64-apple-darwin, os: macos-13 } 30 | - { target: x86_64-pc-windows-gnu, os: windows-2022 } 31 | - { target: x86_64-pc-windows-msvc, os: windows-2022 } 32 | - { target: x86_64-unknown-linux-gnu, os: ubuntu-22.04 } 33 | - { target: x86_64-unknown-linux-musl, os: ubuntu-22.04, use-cross: true } 34 | 35 | steps: 36 | - name: Checkout source code 37 | uses: actions/checkout@v3 38 | 39 | - name: Extract crate information 40 | shell: bash 41 | run: | 42 | # Get the project version from mix.exs 43 | echo "PROJECT_VERSION=$(sed -n 's/^ @version "\(.*\)"/\1/p' mix.exs | head -n1)" >> $GITHUB_ENV 44 | 45 | - name: Install Rust toolchain 46 | uses: dtolnay/rust-toolchain@stable 47 | with: 48 | target: ${{ matrix.job.target }} 49 | 50 | - name: Build the project 51 | id: build-crate 52 | uses: philss/rustler-precompiled-action@v1.0.1 53 | with: 54 | nif-version: ${{ matrix.nif }} 55 | project-dir: "native/ex_tokenizers" 56 | project-name: ex_tokenizers 57 | project-version: ${{ env.PROJECT_VERSION }} 58 | target: ${{ matrix.job.target }} 59 | use-cross: ${{ matrix.job.use-cross }} 60 | 61 | - name: Artifact upload 62 | uses: actions/upload-artifact@v4 63 | with: 64 | name: ${{ steps.build-crate.outputs.file-name }} 65 | path: ${{ steps.build-crate.outputs.file-path }} 66 | 67 | - name: Publish archives and packages 68 | uses: softprops/action-gh-release@v1 69 | with: 70 | files: | 71 | ${{ steps.build-crate.outputs.file-path }} 72 | if: startsWith(github.ref, 'refs/tags/') 73 | -------------------------------------------------------------------------------- /.github/workflows/rust-ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | paths: 6 | - 'native/**' 7 | pull_request: 8 | paths: 9 | - 'native/**' 10 | workflow_dispatch: 11 | 12 | jobs: 13 | lint-rust: 14 | name: Lint Rust 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | manifest: 19 | - native/ex_tokenizers/Cargo.toml 20 | 21 | steps: 22 | - uses: actions/checkout@v4 23 | 24 | - uses: actions/cache@v4 25 | with: 26 | path: | 27 | ~/.cargo/bin/ 28 | ~/.cargo/registry/index/ 29 | ~/.cargo/registry/cache/ 30 | ~/.cargo/git/db/ 31 | native/ex_tokenizers/target/ 32 | priv/native/ 33 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 34 | restore-keys: | 35 | ${{ runner.os }}-cargo- 36 | 37 | - name: Install Rust toolchain 38 | uses: dtolnay/rust-toolchain@stable 39 | with: 40 | components: rustfmt, clippy 41 | 42 | - name: run rustfmt 43 | run: cargo fmt --manifest-path=${{ matrix.manifest }} --all -- --check 44 | 45 | - name: run clippy 46 | run: cargo clippy --manifest-path=${{ matrix.manifest }} -- -Dwarnings 47 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | tokenizers-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | 28 | # Ignore ElixirLS temp files. 29 | .elixir_ls/ 30 | 31 | # Shared objects built by Rust. 32 | *.so 33 | 34 | .nix-* 35 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [v0.5.1] - 2024-10-02 9 | 10 | ### Added 11 | 12 | - Add new ByteLevel normalizer (`Tokenizers.Normalizer.byte_level/0`). 13 | 14 | ### Changed 15 | 16 | - Reduce memory copies when encoding. 17 | - Bump Rust tokenizers to v0.20.0. 18 | 19 | ## [v0.5.0] - 2024-04-24 20 | 21 | ### Added 22 | 23 | - Support for regular expressions in the replace normalizer. See 24 | `Tokenizers.Normalizer.replace_regex/2`. 25 | - Support for regular expressions in the split pre-tokenizer. See 26 | `Tokenizers.PreTokenizer.split_regex/3`. 27 | 28 | ### Removed 29 | 30 | - **(Breaking)** Removed `:add_prefix_space` option in favour of `:prepend_scheme` 31 | for metaspace decoder and pre-tokenizer 32 | 33 | ## [v0.4.0] - 2023-08-09 34 | 35 | ### Added 36 | 37 | - Support for training a tokenizer from scratch. See `Tokenizers.Tokenizer.train_from_files/3` 38 | and `Tokenizers.Model` for available models. 39 | 40 | - Support for changing tokenizer configuration, such as `Tokenizers.Tokenizer.set_padding/2` 41 | and `Tokenizers.Tokenizer.set_truncation/2`. See the "Configuration" functions group in 42 | `Tokenizers.Tokenizer`. 43 | 44 | - Support for applying multiple encoding transformations without additional data copies, 45 | see `Tokenizers.Encoding.Transformation`. Transformations can be passed to 46 | `Tokenizers.Tokenizer.encode/3` via `:encoding_transformations` or applied via 47 | `Tokenizers.Encoding.transform/2`. 48 | 49 | ### Changed 50 | 51 | - **(Breaking)** `Tokenizers.Tokenizer.encode/3` no longer accepts a batch of inputs, 52 | to encode a batch use `Tokenizers.Tokenizer.encode_batch/3` instead 53 | 54 | - **(Breaking)** `Tokenizers.Tokenizer.decode/3` no longer accepts a batch of inputs, 55 | to decode a batch use `Tokenizers.Tokenizer.decode_batch/3` instead 56 | 57 | ## [v0.3.2] - 2023-04-19 58 | 59 | ### Changed 60 | 61 | - Bump [tokenizers](https://crates.io/crates/tokenizers) to v0.13.3 in the 62 | crate's dependencies. 63 | 64 | ## [v0.3.1] - 2023-04-06 65 | 66 | ### Added 67 | 68 | - Add binary variants for accessing encoding data.
This way we can convert encoding 69 | data to tensors without additional allocations. The following functions were added: 70 | 71 | - `get_u32_ids/1` 72 | - `get_u32_attention_mask/1` 73 | - `get_u32_type_ids/1` 74 | - `get_u32_special_tokens_mask/1` 75 | 76 | ## [v0.3.0] - 2023-03-04 77 | 78 | ### Added 79 | 80 | - Add option to use cache when downloading pretrained files. We check the ETAG of 81 | the file before trying to download it. This introduces the `:use_cache` and `:cache_dir` 82 | options to the `Tokenizers.from_pretrained/2` function. 83 | 84 | - Support adding special tokens when creating a tokenizer. This allows a pretrained 85 | tokenizer to be loaded with additional special tokens. 86 | 87 | This change adds the `:additional_special_tokens` option to the `Tokenizers.from_pretrained/2` 88 | function. 89 | 90 | - Add support for the `riscv64gc-unknown-linux-gnu` target, which is useful for Nerves 91 | projects running on 64 bits RISC-V computers. 92 | This means that we are precompiling the project to run on those machines. 93 | 94 | ### Changed 95 | 96 | - Change minimum required version of Rustler Precompiled to `~> 0.6`. With this, we have 97 | the `aarch64-unknown-linux-musl` and `riscv64gc-unknown-linux-gnu` as default targets. 98 | But we also drop support for the NIF version 2.14. 99 | 100 | ## [v0.2.0] - 2022-12-07 101 | 102 | ### Added 103 | 104 | - Add a minimal http server to avoid problems with openssl 105 | - Expose `Encoding.get_special_tokens_mask/1` and `Encoding.get_offsets/1` for NER 106 | 107 | ## [v0.1.0] - 2022-08-25 108 | 109 | First release. 110 | 111 | [v0.5.1]: https://github.com/elixir-nx/tokenizers/compare/v0.5.0...v0.5.1 112 | [v0.5.0]: https://github.com/elixir-nx/tokenizers/compare/v0.4.0...v0.5.0 113 | [v0.4.0]: https://github.com/elixir-nx/tokenizers/compare/v0.3.2...v0.4.0 114 | [v0.3.2]: https://github.com/elixir-nx/tokenizers/compare/v0.3.1...v0.3.2 115 | [v0.3.1]: https://github.com/elixir-nx/tokenizers/compare/v0.3.0...v0.3.1 116 | [v0.3.0]: https://github.com/elixir-nx/tokenizers/compare/v0.2.0...v0.3.0 117 | [v0.2.0]: https://github.com/elixir-nx/tokenizers/compare/v0.1.0...v0.2.0 118 | [v0.1.0]: https://github.com/elixir-nx/tokenizers/releases/tag/v0.1.0 119 | -------------------------------------------------------------------------------- /CODEOWNERS.md: -------------------------------------------------------------------------------- 1 | - @cigrainger 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tokenizers 2 | 3 | ![CI](https://github.com/elixir-nx/explorer/actions/workflows/ci.yml/badge.svg) 4 | 5 | Elixir bindings for [Hugging Face Tokenizers](https://github.com/huggingface/tokenizers). 
6 | 7 | ## Installation 8 | 9 | You can add `:tokenizers` as a dependency in your `mix.exs`: 10 | 11 | ```elixir 12 | def deps do 13 | [ 14 | {:tokenizers, "~> 0.3.0"}, 15 | ] 16 | end 17 | ``` 18 | 19 | If you are using Livebook or IEx, you can instead run: 20 | 21 | ```elixir 22 | Mix.install([ 23 | {:tokenizers, "~> 0.3.0"}, 24 | ]) 25 | ``` 26 | 27 | ## Example 28 | 29 | You can use any pre-trained tokenizer from any model repo on Hugging Face Hub, such as [bert-base-cased](https://huggingface.co/bert-base-cased). 30 | 31 | ```elixir 32 | {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased") 33 | {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello there!") 34 | Tokenizers.Encoding.get_tokens(encoding) 35 | #=> ["Hello", "there", "!"] 36 | Tokenizers.Encoding.get_ids(encoding) 37 | #=> [8667, 1175, 106] 38 | ``` 39 | 40 | The [notebooks](./notebooks) directory has [an introductory Livebook](./notebooks/pretrained.livemd) to give you a feel for the API. 41 | 42 | ## Contributing 43 | 44 | Tokenizers uses Rust to call functionality from the Hugging Face Tokenizers library. While 45 | Rust is not necessary to use Tokenizers as a package, you need Rust tooling installed on 46 | your machine if you want to compile from source, which is the case when contributing to 47 | Tokenizers. In particular, you will need Rust Stable, which can be installed with 48 | [Rustup](https://rust-lang.github.io/rustup/installation/index.html). 49 | 50 | ## License 51 | 52 | Copyright (c) 2022 Christopher Grainger 53 | 54 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 55 | 56 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 57 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # How to release 2 | 3 | Because we use 4 | [`RustlerPrecompiled`](https://hexdocs.pm/rustler_precompiled/RustlerPrecompiled.html), releasing 5 | is a bit more involved than it would be otherwise. 6 | 7 | 1. Open a PR with any changes needed for the release. 8 | 9 | - This must include at least updating the `version` in `mix.exs` and any other files that 10 | reference it, like `README.md`. It must also include updating `CHANGELOG.md` to reflect the 11 | release. 12 | 13 | 2. Once the PR is merged, cut a GitHub release with information from the changelog and tag the 14 | commit with the version number. 15 | 3. This will kick off the "Build precompiled NIFs" GitHub Action. Wait for this to complete. It 16 | usually takes around 40-60 minutes. 17 | 4. While the NIFs are compiling, ensure you have the latest version of `main` and don't have any 18 | intermediate builds by running `rm -rf native/ex_tokenizers/target`. 19 | 5. Once the NIFs are built, use `mix rustler_precompiled.download Tokenizers.Native --all --print` to download the precompiled NIFs and generate the checksum file. 20 | 6. Run `mix hex.publish`. 21 | 7. Bump the version in `mix.exs` and append the `-dev` suffix to it.
22 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "fenix": { 4 | "inputs": { 5 | "nixpkgs": [ 6 | "nixpkgs" 7 | ], 8 | "rust-analyzer-src": "rust-analyzer-src" 9 | }, 10 | "locked": { 11 | "lastModified": 1702448575, 12 | "narHash": "sha256-Gm8lI5vumDEryeUI+bT8w0AIvbolZIGh0F/E0mQSLcw=", 13 | "owner": "nix-community", 14 | "repo": "fenix", 15 | "rev": "dcf3ca909bd069e6a5737461b64c8d894c6dee85", 16 | "type": "github" 17 | }, 18 | "original": { 19 | "owner": "nix-community", 20 | "repo": "fenix", 21 | "type": "github" 22 | } 23 | }, 24 | "flake-utils": { 25 | "inputs": { 26 | "systems": "systems" 27 | }, 28 | "locked": { 29 | "lastModified": 1701680307, 30 | "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", 31 | "owner": "numtide", 32 | "repo": "flake-utils", 33 | "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", 34 | "type": "github" 35 | }, 36 | "original": { 37 | "owner": "numtide", 38 | "repo": "flake-utils", 39 | "type": "github" 40 | } 41 | }, 42 | "nixpkgs": { 43 | "locked": { 44 | "lastModified": 1702312524, 45 | "narHash": "sha256-gkZJRDBUCpTPBvQk25G0B7vfbpEYM5s5OZqghkjZsnE=", 46 | "owner": "NixOS", 47 | "repo": "nixpkgs", 48 | "rev": "a9bf124c46ef298113270b1f84a164865987a91c", 49 | "type": "github" 50 | }, 51 | "original": { 52 | "id": "nixpkgs", 53 | "ref": "nixos-unstable", 54 | "type": "indirect" 55 | } 56 | }, 57 | "root": { 58 | "inputs": { 59 | "fenix": "fenix", 60 | "flake-utils": "flake-utils", 61 | "nixpkgs": "nixpkgs" 62 | } 63 | }, 64 | "rust-analyzer-src": { 65 | "flake": false, 66 | "locked": { 67 | "lastModified": 1702418101, 68 | "narHash": "sha256-XyrXFAiMS5r9Kl4lPpmkTTclPKGwJBxln6enERe5nvk=", 69 | "owner": "rust-lang", 70 | "repo": "rust-analyzer", 71 | "rev": "b3af1916ccfb85233571ce9ecb45a3a7c74ba0fb", 72 | "type": "github" 73 | }, 74 | "original": { 75 | "owner": "rust-lang", 76 | "ref": "nightly", 77 | "repo": "rust-analyzer", 78 | "type": "github" 79 | } 80 | }, 81 | "systems": { 82 | "locked": { 83 | "lastModified": 1681028828, 84 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 85 | "owner": "nix-systems", 86 | "repo": "default", 87 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 88 | "type": "github" 89 | }, 90 | "original": { 91 | "owner": "nix-systems", 92 | "repo": "default", 93 | "type": "github" 94 | } 95 | } 96 | }, 97 | "root": "root", 98 | "version": 7 99 | } 100 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Tokenizers"; 3 | 4 | inputs = { 5 | fenix = { 6 | url = "github:nix-community/fenix"; 7 | inputs.nixpkgs.follows = "nixpkgs"; 8 | }; 9 | nixpkgs.url = "nixpkgs/nixos-unstable"; 10 | flake-utils.url = "github:numtide/flake-utils"; 11 | }; 12 | 13 | outputs = { 14 | self, 15 | nixpkgs, 16 | flake-utils, 17 | fenix, 18 | }: 19 | flake-utils.lib.eachSystem [ 20 | flake-utils.lib.system.x86_64-linux 21 | flake-utils.lib.system.x86_64-darwin 22 | flake-utils.lib.system.aarch64-darwin 23 | flake-utils.lib.system.aarch64-linux 24 | ] 25 | (system: let 26 | pkgs = import nixpkgs {inherit system;}; 27 | in { 28 | devShell = pkgs.mkShell { 29 | buildInputs = with pkgs; 30 | [ 31 | act 32 | binutils 33 | clang 34 | elixir_1_15 35 | (fenix.packages."${system}".complete.withComponents [ 36 | "cargo" 37 | "clippy" 38 
| "rust-src" 39 | "rustc" 40 | "rustfmt" 41 | ]) 42 | gcc 43 | libiconv 44 | openssl 45 | pkg-config 46 | ] 47 | ++ lib.optionals stdenv.isDarwin [ 48 | darwin.apple_sdk.frameworks.Foundation 49 | darwin.apple_sdk.frameworks.Carbon 50 | darwin.apple_sdk.frameworks.AppKit 51 | ]; 52 | shellHook = '' 53 | mkdir -p .nix-mix 54 | mkdir -p .nix-hex 55 | export MIX_HOME=$PWD/.nix-mix 56 | export HEX_HOME=$PWD/.nix-hex 57 | export PATH=$MIX_HOME/bin:$PATH 58 | export PATH=$HEX_HOME/bin:$PATH 59 | export PATH=$MIX_HOME/escripts:$PATH 60 | export ERL_AFLAGS="-kernel shell_history enabled" 61 | ''; 62 | }; 63 | }); 64 | } 65 | -------------------------------------------------------------------------------- /lib/tokenizers.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers do 2 | @moduledoc """ 3 | Elixir bindings to [Hugging Face Tokenizers](https://github.com/huggingface/tokenizers). 4 | 5 | Hugging Face describes the Tokenizers library as: 6 | 7 | > Fast State-of-the-art tokenizers, optimized for both research and 8 | > production 9 | > 10 | > 🤗 Tokenizers provides an implementation of today’s most used 11 | > tokenizers, with a focus on performance and versatility. These 12 | > tokenizers are also used in 🤗 Transformers. 13 | 14 | A tokenizer is effectively a pipeline of transformations that take 15 | a text input and return an encoded version of that text (`t:Tokenizers.Encoding.t/0`). 16 | 17 | The main entrypoint to this library is the `Tokenizers.Tokenizer` 18 | module, which defines the `t:Tokenizers.Tokenizer.t/0` struct, a 19 | container holding the constituent parts of the pipeline. Most 20 | functionality is in that module. 21 | """ 22 | end 23 | -------------------------------------------------------------------------------- /lib/tokenizers/added_token.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.AddedToken do 2 | @moduledoc """ 3 | This struct represents a token added to tokenizer vocabulary. 4 | """ 5 | 6 | @type t() :: %__MODULE__{resource: reference()} 7 | defstruct [:resource] 8 | 9 | @doc """ 10 | Builds a new added token. 11 | 12 | ## Options 13 | 14 | * `:special` - defines whether this token is a special token. 15 | Defaults to `false` 16 | 17 | * `:single_word` - defines whether this token should only match 18 | single words. If `true`, this token will never match inside of a 19 | word. For example the token `ing` would match on `tokenizing` if 20 | this option is `false`. The notion of ”inside of a word” is 21 | defined by the word boundaries pattern in regular expressions 22 | (i.e. the token should start and end with word boundaries). 23 | Defaults to `false` 24 | 25 | * `:lstrip` - defines whether this token should strip all potential 26 | whitespace on its left side. If `true`, this token will greedily 27 | match any whitespace on its left. For example if we try to match 28 | the token `[MASK]` with `lstrip=true`, in the text `"I saw a [MASK]"`, 29 | we would match on `" [MASK]"`. (Note the space on the left). 30 | Defaults to `false` 31 | 32 | * `:rstrip` - defines whether this token should strip all potential 33 | whitespaces on its right side. If `true`, this token will greedily 34 | match any whitespace on its right. It works just like `:lstrip`, 35 | but on the right. Defaults to `false` 36 | 37 | * `:normalized` - defines whether this token should match against 38 | the normalized version of the input text. 
For example, with the 39 | added token `"yesterday"`, and a normalizer in charge of 40 | lowercasing the text, the token could be extracted from the input 41 | `"I saw a lion Yesterday"`. If `true`, the token will be extracted 42 | from the normalized input `"i saw a lion yesterday"`. If `false`, 43 | the token will be extracted from the original input 44 | `"I saw a lion Yesterday"`. Defaults to `false` for special tokens 45 | and `true` otherwise 46 | 47 | """ 48 | @spec new(token :: String.t(), keyword()) :: t() 49 | defdelegate new(token, opts \\ []), to: Tokenizers.Native, as: :added_token_new 50 | 51 | @doc """ 52 | Retrieves information about the added token. 53 | """ 54 | @spec info(added_token :: t()) :: map() 55 | defdelegate info(model), to: Tokenizers.Native, as: :added_token_info 56 | end 57 | 58 | defimpl Inspect, for: Tokenizers.AddedToken do 59 | import Inspect.Algebra 60 | 61 | @spec inspect(Tokenizers.AddedToken.t(), Inspect.Opts.t()) :: Inspect.Algebra.t() 62 | def inspect(decoder, opts) do 63 | attrs = 64 | decoder 65 | |> Tokenizers.Native.added_token_info() 66 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 67 | 68 | concat(["#Tokenizers.AddedToken<", to_doc(attrs, opts), ">"]) 69 | end 70 | end 71 | -------------------------------------------------------------------------------- /lib/tokenizers/decode_stream.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.DecodeStream do 2 | @moduledoc """ 3 | Implements streaming decoding functionality for tokenizers. 4 | """ 5 | 6 | @enforce_keys [:resource] 7 | defstruct [:resource] 8 | 9 | @type t :: %__MODULE__{ 10 | resource: reference() 11 | } 12 | 13 | @doc """ 14 | Creates a new decode stream. 15 | 16 | ## Options 17 | 18 | * `:skip_special_tokens` - determines whether special tokens should be 19 | skipped during decoding. By default, it is set to `false`. 20 | 21 | """ 22 | @spec new(keyword()) :: t() 23 | def new(opts \\ []) when is_list(opts) do 24 | opts = Keyword.validate!(opts, skip_special_tokens: false) 25 | Tokenizers.Native.decoder_stream_new(opts[:skip_special_tokens]) 26 | end 27 | 28 | @doc """ 29 | Steps through the decode stream with the given tokenizer and token ID. 30 | 31 | Returns `{:ok, String.t()}` if there's a decoded string, or `{:ok, :out_of_range}` if the token ID is out of range. 32 | Returns `{:error, reason}` if an error occurs during decoding. 33 | """ 34 | def step(%__MODULE__{} = decode_stream, tokenizer, id) when is_integer(id) do 35 | case Tokenizers.Native.decoder_stream_step(decode_stream, tokenizer, id) do 36 | {:ok, decoded} when is_binary(decoded) -> 37 | {:ok, decoded} 38 | 39 | {:ok, nil} -> 40 | {:ok, :out_of_range} 41 | 42 | {:error, reason} -> 43 | {:error, reason} 44 | end 45 | end 46 | 47 | @doc """ 48 | Returns information about the decode stream state. 49 | """ 50 | defdelegate info(decode_stream), to: Tokenizers.Native, as: :decoder_stream_info 51 | 52 | defimpl Inspect do 53 | import Inspect.Algebra 54 | alias Tokenizers.DecodeStream 55 | 56 | def inspect(decode_stream, opts) do 57 | "#Tokenizers.DecodeStream<#{to_doc(DecodeStream.info(decode_stream), opts)}>" 58 | end 59 | end 60 | end 61 | -------------------------------------------------------------------------------- /lib/tokenizers/decoder.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Decoder do 2 | @moduledoc """ 3 | Decoders and decoding functions.
4 | 5 | Decoder transforms a sequence of token ids back to a readable piece 6 | of text. 7 | 8 | Some normalizers and pre-tokenizers use special characters or 9 | identifiers that need special logic to be reverted. 10 | """ 11 | 12 | defstruct [:resource] 13 | 14 | @type t() :: %__MODULE__{resource: reference()} 15 | 16 | @doc """ 17 | Decodes tokens into string with provided decoder. 18 | """ 19 | @spec decode(t(), [String.t()]) :: {:ok, String.t()} | {:error, any()} 20 | defdelegate decode(decoder, tokens), to: Tokenizers.Native, as: :decoders_decode 21 | 22 | @doc """ 23 | Creates a BPE decoder. 24 | 25 | ## Options 26 | 27 | * `:suffix` - the suffix to add to the end of each word. Defaults 28 | to `` 29 | 30 | """ 31 | @spec bpe(keyword()) :: t() 32 | defdelegate bpe(opts \\ []), to: Tokenizers.Native, as: :decoders_bpe 33 | 34 | @doc """ 35 | Creates a ByteFallback decoder. 36 | """ 37 | @spec byte_fallback() :: t() 38 | defdelegate byte_fallback(), to: Tokenizers.Native, as: :decoders_byte_fallback 39 | 40 | @doc """ 41 | Creates a ByteLevel decoder. 42 | """ 43 | @spec byte_level() :: t() 44 | defdelegate byte_level(), to: Tokenizers.Native, as: :decoders_byte_level 45 | 46 | @doc """ 47 | Creates a CTC decoder. 48 | 49 | ## Options 50 | 51 | * `:pad_token` - the token used for padding. Defaults to `` 52 | 53 | * `:word_delimiter_token` - the token used for word delimiter. 54 | Defaults to `|` 55 | 56 | * `:cleanup` - whether to cleanup tokenization artifacts, defaults 57 | to `true` 58 | 59 | """ 60 | @spec ctc(keyword()) :: t() 61 | defdelegate ctc(opts \\ []), to: Tokenizers.Native, as: :decoders_ctc 62 | 63 | @doc """ 64 | Creates a Fuse decoder. 65 | """ 66 | @spec fuse :: t() 67 | defdelegate fuse(), to: Tokenizers.Native, as: :decoders_fuse 68 | 69 | @doc """ 70 | Creates a Metaspace decoder. 71 | 72 | ## Options 73 | 74 | * `:replacement` - the replacement character. Defaults to `▁` 75 | (as char) 76 | 77 | * `:prepend_scheme` - whether to add a space to the first word if there 78 | isn't already one. This lets us treat "hello" exactly like "say hello". 79 | Either of `:always`, `:never`, `:first`. `:first` means the space is 80 | only added on the first token (relevant when special tokens are used 81 | or other pre_tokenizer are used). Defaults to `:always` 82 | 83 | """ 84 | @spec metaspace(keyword()) :: t() 85 | defdelegate metaspace(opts \\ []), 86 | to: Tokenizers.Native, 87 | as: :decoders_metaspace 88 | 89 | @doc """ 90 | Creates a Replace decoder. 91 | """ 92 | @spec replace(String.t(), String.t()) :: t() 93 | defdelegate replace(pattern, content), to: Tokenizers.Native, as: :decoders_replace 94 | 95 | @doc """ 96 | Combines a list of decoders into a single sequential decoder. 97 | """ 98 | @spec sequence(decoders :: [t()]) :: t() 99 | defdelegate sequence(decoders), to: Tokenizers.Native, as: :decoders_sequence 100 | 101 | @doc """ 102 | Creates a Strip decoder. 103 | 104 | It expects a character and the number of times to strip the 105 | character on `left` and `right` sides. 106 | """ 107 | @spec strip(char(), non_neg_integer(), non_neg_integer()) :: t() 108 | defdelegate strip(content, left, right), to: Tokenizers.Native, as: :decoders_strip 109 | 110 | @doc """ 111 | Creates a WordPiece decoder. 112 | 113 | ## Options 114 | 115 | * `:prefix` - The prefix to use for subwords. Defaults to `##` 116 | 117 | * `:cleanup` - Whether to cleanup tokenization artifacts. 
Defaults 118 | to `true` 119 | 120 | """ 121 | @spec word_piece(keyword()) :: t() 122 | defdelegate word_piece(opts \\ []), 123 | to: Tokenizers.Native, 124 | as: :decoders_wordpiece 125 | end 126 | 127 | defimpl Inspect, for: Tokenizers.Decoder do 128 | import Inspect.Algebra 129 | 130 | @spec inspect(Tokenizers.Decoder.t(), Inspect.Opts.t()) :: Inspect.Algebra.t() 131 | def inspect(decoder, opts) do 132 | attrs = 133 | decoder 134 | |> Tokenizers.Native.decoders_info() 135 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 136 | 137 | concat(["#Tokenizers.Decoder<", to_doc(attrs, opts), ">"]) 138 | end 139 | end 140 | -------------------------------------------------------------------------------- /lib/tokenizers/encoding.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Encoding do 2 | @moduledoc """ 3 | Encoding is the result of passing a text through the tokenization pipeline. 4 | 5 | This module defines a struct and a number of functions to retrieve 6 | information about the encoded text. 7 | 8 | For further machine learning processing, you most likely want to 9 | access the encoded token ids via `get_ids/1`. If you want to convert 10 | the ids to a tensor, use `get_u32_ids/1` to get a zero-copy binary. 11 | """ 12 | 13 | defstruct resource: nil 14 | 15 | @type t :: %__MODULE__{resource: reference()} 16 | 17 | @doc """ 18 | Returns the number of tokens in `encoding`. 19 | """ 20 | @spec get_length(t()) :: non_neg_integer() 21 | defdelegate get_length(encoding), to: Tokenizers.Native, as: :encoding_get_length 22 | 23 | @doc """ 24 | Returns the number of sequences combined in `encoding`. 25 | """ 26 | @spec get_n_sequences(t()) :: non_neg_integer() 27 | defdelegate get_n_sequences(encoding), to: Tokenizers.Native, as: :encoding_get_n_sequences 28 | 29 | @doc """ 30 | Sets the given sequence id for all tokens contained in `encoding`. 31 | """ 32 | @spec set_sequence_id(t(), non_neg_integer()) :: t() 33 | defdelegate set_sequence_id(encoding, id), to: Tokenizers.Native, as: :encoding_set_sequence_id 34 | 35 | @doc """ 36 | Returns the ids from `encoding`. 37 | """ 38 | @spec get_ids(t()) :: [integer()] 39 | defdelegate get_ids(encoding), to: Tokenizers.Native, as: :encoding_get_ids 40 | 41 | @doc """ 42 | Same as `get_ids/1`, but returns binary with u32 values. 43 | """ 44 | @spec get_u32_ids(t()) :: binary() 45 | defdelegate get_u32_ids(encoding), to: Tokenizers.Native, as: :encoding_get_u32_ids 46 | 47 | @doc """ 48 | Returns token type ids from `encoding`. 49 | """ 50 | @spec get_type_ids(t()) :: [integer()] 51 | defdelegate get_type_ids(encoding), to: Tokenizers.Native, as: :encoding_get_type_ids 52 | 53 | @doc """ 54 | Same as `get_type_ids/1`, but returns binary with u32 values. 55 | """ 56 | @spec get_u32_type_ids(t()) :: binary() 57 | defdelegate get_u32_type_ids(encoding), to: Tokenizers.Native, as: :encoding_get_u32_type_ids 58 | 59 | @doc """ 60 | Returns the attention mask from `encoding`. 61 | """ 62 | @spec get_attention_mask(t()) :: [integer()] 63 | defdelegate get_attention_mask(encoding), 64 | to: Tokenizers.Native, 65 | as: :encoding_get_attention_mask 66 | 67 | @doc """ 68 | Same as `get_attention_mask/1`, but returns binary with u32 values. 69 | """ 70 | @spec get_u32_attention_mask(t()) :: binary() 71 | defdelegate get_u32_attention_mask(encoding), 72 | to: Tokenizers.Native, 73 | as: :encoding_get_u32_attention_mask 74 | 75 | @doc """ 76 | Returns the special tokens mask from `encoding`.
77 | """ 78 | @spec get_special_tokens_mask(t()) :: [integer()] 79 | defdelegate get_special_tokens_mask(encoding), 80 | to: Tokenizers.Native, 81 | as: :encoding_get_special_tokens_mask 82 | 83 | @doc """ 84 | Same as `get_special_tokens_mask/1`, but returns binary with u32 values. 85 | """ 86 | @spec get_u32_special_tokens_mask(t()) :: binary() 87 | defdelegate get_u32_special_tokens_mask(encoding), 88 | to: Tokenizers.Native, 89 | as: :encoding_get_u32_special_tokens_mask 90 | 91 | @doc """ 92 | Returns the tokens from `encoding`. 93 | """ 94 | @spec get_tokens(t()) :: [binary()] 95 | defdelegate get_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_tokens 96 | 97 | @doc """ 98 | Returns word ids from `encoding`. 99 | """ 100 | @spec get_word_ids(t()) :: [non_neg_integer() | nil] 101 | defdelegate get_word_ids(encoding), to: Tokenizers.Native, as: :encoding_get_word_ids 102 | 103 | @doc """ 104 | Returns sequence ids from `encoding`. 105 | """ 106 | @spec get_sequence_ids(t()) :: [non_neg_integer() | nil] 107 | defdelegate get_sequence_ids(encoding), to: Tokenizers.Native, as: :encoding_get_sequence_ids 108 | 109 | @doc """ 110 | Returns offsets from `encoding`. 111 | 112 | The offsets are expressed in terms of UTF-8 bytes. 113 | """ 114 | @spec get_offsets(t()) :: [{integer(), integer()}] 115 | defdelegate get_offsets(encoding), to: Tokenizers.Native, as: :encoding_get_offsets 116 | 117 | @doc """ 118 | Returns the overflow from `encoding`. 119 | """ 120 | @spec get_overflowing(t()) :: [t()] 121 | defdelegate get_overflowing(encoding), to: Tokenizers.Native, as: :encoding_get_overflowing 122 | 123 | @doc """ 124 | Returns the encoded tokens corresponding to the word at the given 125 | index in the input sequence, with the form `{start_token, end_token + 1}`. 126 | """ 127 | @spec word_to_tokens(t(), non_neg_integer(), non_neg_integer()) :: 128 | {non_neg_integer(), non_neg_integer()} | nil 129 | defdelegate word_to_tokens(encoding, word, seq_id), 130 | to: Tokenizers.Native, 131 | as: :encoding_word_to_tokens 132 | 133 | @doc """ 134 | Returns the offsets of the word at the given index in the input 135 | sequence. 136 | """ 137 | @spec word_to_chars(t(), non_neg_integer(), non_neg_integer()) :: 138 | {non_neg_integer(), non_neg_integer()} | nil 139 | defdelegate word_to_chars(encoding, word, seq_id), 140 | to: Tokenizers.Native, 141 | as: :encoding_word_to_chars 142 | 143 | @doc """ 144 | Returns the index of the sequence containing the given token. 145 | """ 146 | @spec token_to_sequence(t(), non_neg_integer()) :: non_neg_integer() | nil 147 | defdelegate token_to_sequence(encoding, token), 148 | to: Tokenizers.Native, 149 | as: :encoding_token_to_sequence 150 | 151 | @doc """ 152 | Returns the offsets of the token at the given index. 153 | """ 154 | @spec token_to_chars(t(), non_neg_integer()) :: 155 | {non_neg_integer(), {non_neg_integer(), non_neg_integer()}} | nil 156 | defdelegate token_to_chars(encoding, token), to: Tokenizers.Native, as: :encoding_token_to_chars 157 | 158 | @doc """ 159 | Returns the word that contains the token at the given index. 160 | """ 161 | @spec token_to_word(t(), non_neg_integer()) :: 162 | {non_neg_integer(), non_neg_integer()} | nil 163 | defdelegate token_to_word(encoding, token), to: Tokenizers.Native, as: :encoding_token_to_word 164 | 165 | @doc """ 166 | Returns the token that contains the given char. 
167 | """ 168 | @spec char_to_token(t(), non_neg_integer(), non_neg_integer()) :: 169 | non_neg_integer() | nil 170 | defdelegate char_to_token(encoding, position, seq_id), 171 | to: Tokenizers.Native, 172 | as: :encoding_char_to_token 173 | 174 | @doc """ 175 | Returns the word that contains the given char. 176 | """ 177 | @spec char_to_word(t(), non_neg_integer(), non_neg_integer()) :: 178 | non_neg_integer() | nil 179 | defdelegate char_to_word(encoding, position, seq_id), 180 | to: Tokenizers.Native, 181 | as: :encoding_char_to_word 182 | 183 | @typedoc """ 184 | Padding configuration. 185 | 186 | * `:direction` - the padding direction. Defaults to `:right` 187 | 188 | * `:pad_id` - the id corresponding to the padding token. Defaults 189 | to `0` 190 | 191 | * `:pad_type_id` - the type ID corresponding to the padding token. 192 | Defaults to `0` 193 | 194 | * `:pad_token` - the padding token to use. Defaults to `"[PAD]"` 195 | 196 | """ 197 | @type padding_opts :: [ 198 | pad_id: non_neg_integer(), 199 | pad_type_id: non_neg_integer(), 200 | pad_token: String.t(), 201 | direction: :left | :right 202 | ] 203 | 204 | @doc """ 205 | Pad the encoding to the given length. 206 | 207 | For available options see `t:padding_opts/0`. 208 | """ 209 | @spec pad(t(), non_neg_integer(), opts :: padding_opts()) :: t() 210 | defdelegate pad(encoding, target_length, opts \\ []), 211 | to: Tokenizers.Native, 212 | as: :encoding_pad 213 | 214 | @typedoc """ 215 | Truncation configuration. 216 | 217 | * `:stride` - the length of previous content to be included in each 218 | overflowing piece. Defaults to `0` 219 | 220 | * `:direction` - the truncation direction. Defaults to `:right` 221 | 222 | """ 223 | @type truncation_opts :: [stride: non_neg_integer(), direction: :left | :right] 224 | 225 | @doc """ 226 | Truncate the encoding to the given length. 227 | 228 | For available options see `t:truncation_opts/0`. 229 | """ 230 | @spec truncate(t(), non_neg_integer(), opts :: truncation_opts()) :: t() 231 | defdelegate truncate(encoding, max_length, opts \\ []), 232 | to: Tokenizers.Native, 233 | as: :encoding_truncate 234 | 235 | @doc """ 236 | Returns the number of tokens in `encoding`. 237 | """ 238 | @spec n_tokens(encoding :: t()) :: non_neg_integer() 239 | defdelegate n_tokens(encoding), to: Tokenizers.Native, as: :encoding_get_length 240 | 241 | @doc """ 242 | Performs set of transformations to given encoding, creating a new one. 243 | Transformations are applied in order they are given. 244 | 245 | While all these transformations can be done one by one, this function 246 | is more efficient as it avoids multiple allocations and Garbage Collection 247 | for intermediate encodings. 248 | 249 | Check the module `Tokenizers.Encoding.Transformation` for handy functions, 250 | that can be used to build the transformations list. 251 | Also, you can build this list manually, as long as it follows the format. 
252 | """ 253 | defdelegate transform(encoding, transformations), to: Tokenizers.Native, as: :encoding_transform 254 | end 255 | 256 | defimpl Inspect, for: Tokenizers.Encoding do 257 | import Inspect.Algebra 258 | 259 | alias Tokenizers.Encoding 260 | 261 | def inspect(encoding, opts) do 262 | attrs = [ 263 | length: Encoding.get_length(encoding), 264 | ids: Encoding.get_ids(encoding) 265 | ] 266 | 267 | concat(["#Tokenizers.Encoding<", to_doc(attrs, opts), ">"]) 268 | end 269 | end 270 | -------------------------------------------------------------------------------- /lib/tokenizers/encoding/transformation.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Encoding.Transformation do 2 | @moduledoc """ 3 | Module containing handy functions to build the transformations list. 4 | 5 | This list is applied to an encoding using `Tokenizers.Encoding.transform/2`. 6 | """ 7 | 8 | @type t :: [ 9 | {:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}}, 10 | {:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}}, 11 | {:set_sequence_id, non_neg_integer()} 12 | ] 13 | 14 | @doc """ 15 | Generates the padding transformation. 16 | 17 | Check `Tokenizers.Encoding.pad/3` for more information. 18 | """ 19 | @spec pad(non_neg_integer(), Tokenizers.Encoding.padding_opts()) :: 20 | {:pad, {non_neg_integer(), Tokenizers.Encoding.padding_opts()}} 21 | def pad(target_length, opts \\ []) do 22 | {:pad, {target_length, opts}} 23 | end 24 | 25 | @doc """ 26 | Generates the truncation transformation. 27 | 28 | Check `Tokenizers.Encoding.truncate/3` for more information. 29 | """ 30 | @spec truncate(non_neg_integer(), Tokenizers.Encoding.truncation_opts()) :: 31 | {:truncate, {non_neg_integer(), Tokenizers.Encoding.truncation_opts()}} 32 | def truncate(max_length, opts \\ []) do 33 | {:truncate, {max_length, opts}} 34 | end 35 | 36 | @doc """ 37 | Generates the set_sequence_id transformation. 38 | 39 | Check `Tokenizers.Encoding.set_sequence_id/2` for more information. 40 | """ 41 | @spec set_sequence_id(non_neg_integer()) :: 42 | {:set_sequence_id, non_neg_integer()} 43 | def set_sequence_id(id) do 44 | {:set_sequence_id, id} 45 | end 46 | end 47 | -------------------------------------------------------------------------------- /lib/tokenizers/http_client.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.HTTPClient do 2 | @moduledoc """ 3 | A simple implementation of an HTTP client. 4 | 5 | This is using the built-in `:httpc` module, configured to use SSL. 6 | The `request/1` function is similar to `Req.request/1`. 7 | """ 8 | 9 | @base_url "https://huggingface.io" 10 | 11 | @doc """ 12 | Make an HTTP(s) requests. 13 | 14 | ## Options 15 | 16 | * `:method` - An HTTP method. By default it uses the `:get` method. 17 | 18 | * `:base_url` - The base URL to make requests. By default is #{inspect(@base_url)}. 19 | 20 | * `:url` - A path to a resource. By default is "". 21 | 22 | * `:headers` - A list of tuples representing HTTP headers. By default it's empty. 
23 | 24 | """ 25 | def request(opts) when is_list(opts) do 26 | opts = Keyword.validate!(opts, base_url: @base_url, headers: [], method: :get, url: "") 27 | 28 | url = Path.join([opts[:base_url], opts[:url]]) |> String.to_charlist() 29 | headers = Enum.map(opts[:headers], fn {key, value} -> {String.to_charlist(key), value} end) 30 | 31 | {:ok, _} = Application.ensure_all_started(:inets) 32 | {:ok, _} = Application.ensure_all_started(:ssl) 33 | 34 | if proxy = System.get_env("HTTP_PROXY") || System.get_env("http_proxy") do 35 | %{host: host, port: port} = URI.parse(proxy) 36 | 37 | :httpc.set_options([{:proxy, {{String.to_charlist(host), port}, []}}]) 38 | end 39 | 40 | proxy = System.get_env("HTTPS_PROXY") || System.get_env("https_proxy") 41 | 42 | with true <- is_binary(proxy), 43 | %{host: host, port: port} when is_binary(host) and is_integer(port) <- URI.parse(proxy) do 44 | :httpc.set_options([{:https_proxy, {{String.to_charlist(host), port}, []}}]) 45 | end 46 | 47 | # https://erlef.github.io/security-wg/secure_coding_and_deployment_hardening/inets 48 | cacertfile = CAStore.file_path() |> String.to_charlist() 49 | 50 | http_options = [ 51 | ssl: [ 52 | verify: :verify_peer, 53 | cacertfile: cacertfile, 54 | depth: 3, 55 | customize_hostname_check: [ 56 | match_fun: :public_key.pkix_verify_hostname_match_fun(:https) 57 | ] 58 | ] 59 | ] 60 | 61 | options = [body_format: :binary] 62 | 63 | case :httpc.request(opts[:method], {url, headers}, http_options, options) do 64 | {:ok, {{_, status, _}, headers, body}} -> 65 | {:ok, %{status: status, headers: normalize_headers(headers), body: body}} 66 | 67 | {:ok, {status, body}} -> 68 | {:ok, %{status: status, body: body, headers: []}} 69 | 70 | {:error, reason} -> 71 | {:error, "could not make request #{url}: #{inspect(reason)}"} 72 | end 73 | end 74 | 75 | defp normalize_headers(headers) do 76 | for {key, value} <- headers do 77 | {List.to_string(key), List.to_string(value)} 78 | end 79 | end 80 | end 81 | -------------------------------------------------------------------------------- /lib/tokenizers/model.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model do 2 | @moduledoc """ 3 | The struct and associated functions for the tokenizer model. 4 | """ 5 | 6 | defstruct [:resource] 7 | 8 | @typedoc """ 9 | Represents different kind of models that can be used across the library. 10 | """ 11 | @type t() :: %__MODULE__{resource: reference()} 12 | 13 | @doc """ 14 | Retrieves information about the model. 15 | 16 | Information retrieved differs per model but all include `model_type`. 17 | """ 18 | @spec info(t()) :: map() 19 | defdelegate info(model), to: Tokenizers.Native, as: :models_info 20 | 21 | @doc """ 22 | Saves the given model in the given directory. 23 | 24 | This function generates a couple files with predefined names, you 25 | can specify `:prefix` to scope them. Existing files with the same 26 | names in this directory will be overridden. 27 | 28 | ## Options 29 | 30 | * `:prefix` - the prefix to use for all the files that will get 31 | created. 
Defaults to `""` 32 | 33 | """ 34 | @spec save(t(), String.t(), keyword()) :: {:ok, file_paths :: [String.t()]} | {:error, any()} 35 | defdelegate save(model, directory, opts \\ []), to: Tokenizers.Native, as: :models_save 36 | end 37 | 38 | defimpl Inspect, for: Tokenizers.Model do 39 | import Inspect.Algebra 40 | 41 | alias Tokenizers.Model 42 | 43 | @spec inspect(Tokenizers.Model.t(), Inspect.Opts.t()) :: Inspect.Algebra.t() 44 | def inspect(model, opts) do 45 | attrs = 46 | model 47 | |> Model.info() 48 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 49 | 50 | concat(["#Tokenizers.Model<", to_doc(attrs, opts), ">"]) 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /lib/tokenizers/model/bpe.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.BPE do 2 | @typedoc """ 3 | Options for model initialisation. 4 | 5 | * `:byte_fallback`- whether to use the byte fallback trick 6 | 7 | * `:cache_capacity` - the number of words that the BPE cache can 8 | contain. The cache allows to speed-up the process by keeping 9 | the result of the merge operations for a number of words. 10 | Defaults to `10_000` 11 | 12 | * `:dropout` - The BPE dropout to use. Must be a float between 13 | 0 and 1 14 | 15 | * `:unk_token` - The unknown token to be used by the model 16 | 17 | * `:continuing_subword_prefix` - The prefix to attach to subword 18 | units that don't represent a beginning of word 19 | 20 | * `:end_of_word_suffix` - The suffix to attach to subword units 21 | that represent an end of word 22 | 23 | """ 24 | @type options() :: [ 25 | cache_capacity: number(), 26 | dropout: float(), 27 | unk_token: String.t(), 28 | continuing_subword_prefix: String.t(), 29 | end_of_word_suffix: String.t(), 30 | fuse_unk: boolean(), 31 | byte_fallback: boolean() 32 | ] 33 | 34 | @doc """ 35 | Instantiate a BPE model from the given vocab and merges. 36 | """ 37 | @spec init( 38 | %{String.t() => integer()}, 39 | [{String.t(), String.t()}], 40 | options() 41 | ) :: {:ok, Tokenizers.Model.t()} 42 | defdelegate init(vocab, merges, options \\ []), to: Tokenizers.Native, as: :models_bpe_init 43 | 44 | @doc """ 45 | Instantiate an empty BPE model. 46 | """ 47 | @spec empty() :: {:ok, Tokenizers.Model.t()} 48 | defdelegate empty(), to: Tokenizers.Native, as: :models_bpe_empty 49 | 50 | @doc """ 51 | Instantiate a BPE model from the given vocab and merges files. 52 | """ 53 | @spec from_file(String.t(), String.t(), options()) :: {:ok, Tokenizers.Model.t()} 54 | defdelegate from_file(vocab_path, merges_path, options \\ []), 55 | to: Tokenizers.Native, 56 | as: :models_bpe_from_file 57 | end 58 | -------------------------------------------------------------------------------- /lib/tokenizers/model/unigram.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.Unigram do 2 | @typedoc """ 3 | Options for model initialisation. 4 | 5 | * `:byte_fallback`- whether to use the byte fallback trick 6 | * `:unk_id`- the unknown token id to be used by the model 7 | 8 | """ 9 | @type options() :: [ 10 | byte_fallback: boolean(), 11 | unk_id: integer() 12 | ] 13 | 14 | @doc """ 15 | Instantiate a Unigram model from the given vocab. 
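> Editor's example — a small sketch of building a BPE model in memory with `Tokenizers.Model.BPE.init/3` and inspecting it with `Tokenizers.Model.info/1`. The vocab and merges are toy values.

```elixir
vocab = %{"a" => 0, "b" => 1, "ab" => 2, "[UNK]" => 3}
merges = [{"a", "b"}]

{:ok, model} = Tokenizers.Model.BPE.init(vocab, merges, unk_token: "[UNK]")

Tokenizers.Model.info(model)
# Returns a map that always includes "model_type"; other keys vary per model.
```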
16 | """ 17 | @spec init([{String.t(), number()}], options()) :: {:ok, Tokenizers.Model.t()} 18 | defdelegate init(vocab, options \\ []), 19 | to: Tokenizers.Native, 20 | as: :models_unigram_init 21 | 22 | @doc """ 23 | Instantiate an empty Unigram model 24 | """ 25 | @spec empty() :: {:ok, Tokenizers.Model.t()} 26 | defdelegate empty(), to: Tokenizers.Native, as: :models_unigram_empty 27 | end 28 | -------------------------------------------------------------------------------- /lib/tokenizers/model/wordlevel.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.WordLevel do 2 | @typedoc """ 3 | Options for model initialisation. 4 | 5 | * `:unk_token` - the unknown token to be used by the model. Defaults 6 | to "[UNK]" 7 | 8 | """ 9 | @type options() :: [ 10 | unk_token: String.t() 11 | ] 12 | 13 | @doc """ 14 | Instantiate a WordLevel model from the given vocab. 15 | """ 16 | @spec init( 17 | vocab :: %{String.t() => integer()}, 18 | options :: options() 19 | ) :: {:ok, Tokenizers.Model.t()} 20 | defdelegate init(vocab, options \\ []), 21 | to: Tokenizers.Native, 22 | as: :models_wordlevel_init 23 | 24 | @doc """ 25 | Instantiate an empty WordLevel model. 26 | """ 27 | @spec empty() :: {:ok, Tokenizers.Model.t()} 28 | defdelegate empty(), to: Tokenizers.Native, as: :models_wordlevel_empty 29 | 30 | @doc """ 31 | Instantiate a WordLevel model from the given vocab file. 32 | """ 33 | @spec from_file(String.t(), options()) :: {:ok, Tokenizers.Model.t()} 34 | defdelegate from_file(vocab_path, options \\ []), 35 | to: Tokenizers.Native, 36 | as: :models_wordlevel_from_file 37 | end 38 | -------------------------------------------------------------------------------- /lib/tokenizers/model/wordpiece.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.WordPiece do 2 | @typedoc """ 3 | Options for model initialisation. 4 | 5 | * `:unk_token` - the unknown token to be used by the model. 6 | Defaults to `"[UNK]"` 7 | 8 | * `:max_input_chars_per_word` - the maximum number of characters 9 | to allow in a single word. Defaults to `100` 10 | 11 | * `:continuing_subword_prefix` - the prefix to attach to subword 12 | units that don't represent a beginning of word. Defaults to `"##"`. 13 | 14 | """ 15 | @type options() :: [ 16 | unk_token: String.t(), 17 | max_input_chars_per_word: number(), 18 | continuing_subword_prefix: String.t() 19 | ] 20 | 21 | @doc """ 22 | Instantiate a WordPiece model from the given vocab. 23 | """ 24 | @spec init(%{String.t() => integer()}, options()) :: {:ok, Tokenizers.Model.t()} 25 | defdelegate init(vocab, options \\ []), 26 | to: Tokenizers.Native, 27 | as: :models_wordpiece_init 28 | 29 | @doc """ 30 | Instantiate an empty WordPiece model. 31 | """ 32 | @spec empty() :: {:ok, Tokenizers.Model.t()} 33 | defdelegate empty(), to: Tokenizers.Native, as: :models_wordpiece_empty 34 | 35 | @doc """ 36 | Instantiate a WordPiece model from the given vocab file. 
37 | """ 38 | @spec from_file(String.t(), options()) :: {:ok, Tokenizers.Model.t()} 39 | defdelegate from_file(vocab_path, options \\ []), 40 | to: Tokenizers.Native, 41 | as: :models_wordpiece_from_file 42 | end 43 | -------------------------------------------------------------------------------- /lib/tokenizers/native.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Native do 2 | @moduledoc false 3 | 4 | mix_config = Mix.Project.config() 5 | version = mix_config[:version] 6 | github_url = mix_config[:package][:links]["GitHub"] 7 | 8 | use RustlerPrecompiled, 9 | otp_app: :tokenizers, 10 | crate: "ex_tokenizers", 11 | version: version, 12 | base_url: "#{github_url}/releases/download/v#{version}", 13 | force_build: System.get_env("TOKENIZERS_BUILD") in ["1", "true"] 14 | 15 | # Added tokens 16 | def added_token_new(_token, _opts), do: err() 17 | # 18 | def added_token_info(_added_token), do: err() 19 | 20 | # Decoders 21 | def decoders_decode(_decoder, _tokens), do: err() 22 | # 23 | def decoders_info(_decoder), do: err() 24 | # 25 | def decoders_byte_level(), do: err() 26 | def decoders_replace(_pattern, _content), do: err() 27 | def decoders_wordpiece(_options), do: err() 28 | def decoders_byte_fallback(), do: err() 29 | def decoders_fuse(), do: err() 30 | def decoders_strip(_content, _left, _right), do: err() 31 | def decoders_metaspace(_options), do: err() 32 | def decoders_bpe(_options), do: err() 33 | def decoders_ctc(_options), do: err() 34 | def decoders_sequence(_decoders), do: err() 35 | 36 | # DecoderStream 37 | def decoder_stream_step(_decoder_stream, _tokenizer, _id), do: err() 38 | # 39 | def decoder_stream_info(_decoder_stream), do: err() 40 | # 41 | def decoder_stream_new(_skip_special_tokens), do: err() 42 | 43 | # Encoding 44 | def encoding_get_length(_encoding), do: err() 45 | def encoding_get_n_sequences(_encoding), do: err() 46 | def encoding_set_sequence_id(_encoding, _seq_id), do: err() 47 | def encoding_get_ids(_encoding), do: err() 48 | def encoding_get_u32_ids(_encoding), do: err() 49 | def encoding_get_type_ids(_encoding), do: err() 50 | def encoding_get_u32_type_ids(_encoding), do: err() 51 | def encoding_get_attention_mask(_encoding), do: err() 52 | def encoding_get_u32_attention_mask(_encoding), do: err() 53 | def encoding_get_special_tokens_mask(_encoding), do: err() 54 | def encoding_get_u32_special_tokens_mask(_encoding), do: err() 55 | def encoding_get_tokens(_encoding), do: err() 56 | def encoding_get_word_ids(_encoding), do: err() 57 | def encoding_get_sequence_ids(_encoding), do: err() 58 | def encoding_get_offsets(_encoding), do: err() 59 | def encoding_get_overflowing(_encoding), do: err() 60 | def encoding_word_to_tokens(_encoding, _word, _seq_id), do: err() 61 | def encoding_word_to_chars(_encoding, _word, _seq_id), do: err() 62 | def encoding_token_to_sequence(_encoding, _token), do: err() 63 | def encoding_token_to_chars(_encoding, _token), do: err() 64 | def encoding_token_to_word(_encoding, _token), do: err() 65 | def encoding_char_to_token(_encoding, _position, _seq_id), do: err() 66 | def encoding_char_to_word(_encoding, _position, _seq_id), do: err() 67 | def encoding_pad(_encoding, _target_length, _opts), do: err() 68 | def encoding_truncate(_encoding, _max_length, _opts), do: err() 69 | # 70 | def encoding_transform(_encoding, _transformers), do: err() 71 | 72 | # Models 73 | def models_save(_model, _folder, _opts), do: err() 74 | # 75 | def models_info(_model), do: err() 76 | # 
77 | def models_bpe_init(_vocab, _merges, _options), do: err() 78 | def models_bpe_empty(), do: err() 79 | def models_bpe_from_file(_vocab, _merges, _options), do: err() 80 | # 81 | def models_wordpiece_init(_vocab, _options), do: err() 82 | def models_wordpiece_empty(), do: err() 83 | def models_wordpiece_from_file(_vocab, _options), do: err() 84 | # 85 | def models_wordlevel_init(_vocab, _options), do: err() 86 | def models_wordlevel_empty(), do: err() 87 | def models_wordlevel_from_file(_vocab, _options), do: err() 88 | # 89 | def models_unigram_init(_vocab, _options), do: err() 90 | def models_unigram_empty(), do: err() 91 | 92 | # Normalizers 93 | def normalizers_normalize(_normalizer, _input), do: err() 94 | # 95 | def normalizers_info(_normalizer), do: err() 96 | # 97 | def normalizers_bert_normalizer(_opts), do: err() 98 | def normalizers_nfd(), do: err() 99 | def normalizers_nfkd(), do: err() 100 | def normalizers_nfc(), do: err() 101 | def normalizers_nfkc(), do: err() 102 | def normalizers_strip(_opts), do: err() 103 | def normalizers_prepend(_prepend), do: err() 104 | def normalizers_strip_accents(), do: err() 105 | def normalizers_sequence(_normalizers), do: err() 106 | def normalizers_lowercase(), do: err() 107 | def normalizers_replace(_pattern, _content), do: err() 108 | def normalizers_nmt(), do: err() 109 | def normalizers_precompiled(_data), do: err() 110 | def normalizers_byte_level(), do: err() 111 | def normalizers_byte_level_alphabet(), do: err() 112 | 113 | # PreTokenizers 114 | def pre_tokenizers_pre_tokenize(_pre_tokenizer, _input), do: err() 115 | # 116 | def pre_tokenizers_info(_pre_tokenizer), do: err() 117 | # 118 | def pre_tokenizers_byte_level(_opts), do: err() 119 | def pre_tokenizers_byte_level_alphabet(), do: err() 120 | def pre_tokenizers_whitespace(), do: err() 121 | def pre_tokenizers_whitespace_split(), do: err() 122 | def pre_tokenizers_bert(), do: err() 123 | def pre_tokenizers_metaspace(_opts), do: err() 124 | def pre_tokenizers_char_delimiter_split(_delimiter), do: err() 125 | def pre_tokenizers_split(_pattern, _behavior, _options), do: err() 126 | def pre_tokenizers_punctuation(_behavior), do: err() 127 | def pre_tokenizers_sequence(_pre_tokenizers), do: err() 128 | def pre_tokenizers_digits(_options), do: err() 129 | 130 | # PostProcessors 131 | def post_processors_info(_post_processor), do: err() 132 | # 133 | def post_processors_bert(_sep, _cls), do: err() 134 | def post_processors_roberta(_sep, _cls, _opts), do: err() 135 | def post_processors_byte_level(_opts), do: err() 136 | def post_processors_template(_opts), do: err() 137 | def post_processors_sequence(_post_processors), do: err() 138 | 139 | # Trainers 140 | def trainers_info(_trainer), do: err() 141 | # 142 | def trainers_bpe_trainer(_options), do: err() 143 | def trainers_wordpiece_trainer(_options), do: err() 144 | def trainers_wordlevel_trainer(_options), do: err() 145 | def trainers_unigram_trainer(_options), do: err() 146 | 147 | # Tokenizer 148 | def tokenizer_init(_model), do: err() 149 | def tokenizer_from_file(_path, _options), do: err() 150 | def tokenizer_from_buffer(_buffer, _options), do: err() 151 | def tokenizer_save(_tokenizer, _folder, _options), do: err() 152 | # 153 | def tokenizer_get_model(_tokenizer), do: err() 154 | def tokenizer_set_model(_tokenizer, _model), do: err() 155 | def tokenizer_get_normalizer(_tokenizer), do: err() 156 | def tokenizer_set_normalizer(_tokenizer, _normalizer), do: err() 157 | def tokenizer_get_pre_tokenizer(_tokenizer), do: err() 158 | 
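> Editor's example — a sketch of the round trip that the `tokenizer_*` stubs above ultimately back. `Tokenizers.Tokenizer.from_pretrained/1`, `encode/2` and `decode/2` are assumed public wrappers here (that module is not part of this excerpt); `"bert-base-cased"` matches the test fixture shipped with the repo.

```elixir
{:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-cased")
{:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, "Hello there!")

ids = Tokenizers.Encoding.get_ids(encoding)
{:ok, text} = Tokenizers.Tokenizer.decode(tokenizer, ids)
```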
def tokenizer_set_pre_tokenizer(_tokenizer, _pre_tokenizer), do: err() 159 | def tokenizer_get_post_processor(_tokenizer), do: err() 160 | def tokenizer_set_post_processor(_tokenizer, _post_processor), do: err() 161 | def tokenizer_get_decoder(_tokenizer), do: err() 162 | def tokenizer_set_decoder(_tokenizer, _decoder), do: err() 163 | def tokenizer_get_vocab(_tokenizer, _with_added_tokens), do: err() 164 | def tokenizer_get_vocab_size(_tokenizer, _with_added_tokens), do: err() 165 | def tokenizer_add_tokens(_tokenizer, _tokens), do: err() 166 | def tokenizer_add_special_tokens(_tokenizer, _tokens), do: err() 167 | def tokenizer_set_truncation(_tokenizer, _opts), do: err() 168 | def tokenizer_disable_truncation(_tokenizer), do: err() 169 | def tokenizer_set_padding(_tokenizer, _opts), do: err() 170 | def tokenizer_disable_padding(_tokenizer), do: err() 171 | # 172 | def tokenizer_encode(_tokenizer, _input, _options), do: err() 173 | def tokenizer_encode_batch(_tokenizer, _inputs, _options), do: err() 174 | def tokenizer_decode(_tokenizer, _ids, _options), do: err() 175 | def tokenizer_decode_batch(_tokenizer, _ids, _options), do: err() 176 | def tokenizer_token_to_id(_tokenizer, _token), do: err() 177 | def tokenizer_id_to_token(_tokenizer, _id), do: err() 178 | def tokenizer_post_processing(_tokenizer, _encoding, _pair, _add_special_tokens), do: err() 179 | # 180 | def tokenizer_train_from_files(_tokenizer, _files, _trainer), do: err() 181 | 182 | defp err(), do: :erlang.nif_error(:nif_not_loaded) 183 | end 184 | -------------------------------------------------------------------------------- /lib/tokenizers/normalizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Normalizer do 2 | @moduledoc """ 3 | Normalizers and normalization functions. 4 | 5 | A normalizer is in charge of pre-processing the input string in 6 | order to normalize it as relevant for the given use case. 7 | 8 | Some common examples of normalization are the Unicode normalization 9 | algorithms (NFD, NFKD, NFC & NFKC) or lowercasing. The specificity 10 | of tokenizers is that we keep track of the alignment while 11 | normalizing. This is essential to allow mapping from the generated 12 | tokens back to the input text. 13 | """ 14 | 15 | defstruct [:resource] 16 | 17 | @type t() :: %__MODULE__{resource: reference()} 18 | 19 | @doc """ 20 | Normalizes the given text input. 21 | """ 22 | @spec normalize(t(), String.t()) :: {:ok, String.t()} 23 | defdelegate normalize(normalizer, input), to: Tokenizers.Native, as: :normalizers_normalize 24 | 25 | # Normalizer entities. Following the order in https://docs.rs/tokenizers/0.20.0/src/tokenizers/normalizers/mod.rs.html#24 26 | 27 | @doc """ 28 | Takes care of normalizing raw text before giving it to a BERT model. 29 | 30 | This includes cleaning the text, handling accents, Chinese chars and 31 | lowercasing. 32 | 33 | ## Options 34 | 35 | * `:clean_text` - whether to clean the text, by removing any 36 | control characters and replacing all whitespaces by the classic 37 | one. Defaults to `true` 38 | 39 | * `:handle_chinese_chars` - whether to handle chinese chars by 40 | putting spaces around them. Default `true` 41 | 42 | * `:strip_accents` - whether to strip all accents. If this option 43 | is not specified, then it will be determined by the value for 44 | lowercase (as in the original Bert) 45 | 46 | * `:lowercase` - whether to lowercase. 
Default `true` 47 | 48 | """ 49 | @spec bert_normalizer(keyword()) :: t() 50 | defdelegate bert_normalizer(opts \\ []), 51 | to: Tokenizers.Native, 52 | as: :normalizers_bert_normalizer 53 | 54 | @doc """ 55 | Creates a Strip normalizer. 56 | 57 | Removes all whitespace characters on the specified sides (left, 58 | right or both) of the input 59 | 60 | ## Options 61 | 62 | * `:left` - whether to strip left side. Defaults to `true` 63 | 64 | * `:right` - whether to strip right side. Defaults to `true` 65 | 66 | """ 67 | @spec strip(keyword()) :: t() 68 | defdelegate strip(opts \\ []), to: Tokenizers.Native, as: :normalizers_strip 69 | 70 | @doc """ 71 | Creates a Strip Accent normalizer. 72 | 73 | Removes all accent symbols in unicode (to be used with NFD for 74 | consistency). 75 | """ 76 | @spec strip_accents :: t() 77 | defdelegate strip_accents(), to: Tokenizers.Native, as: :normalizers_strip_accents 78 | 79 | @doc """ 80 | Creates a NFC Unicode normalizer. 81 | """ 82 | @spec nfc :: t() 83 | defdelegate nfc(), to: Tokenizers.Native, as: :normalizers_nfc 84 | 85 | @doc """ 86 | Creates a NFD Unicode normalizer. 87 | """ 88 | @spec nfd :: t() 89 | defdelegate nfd(), to: Tokenizers.Native, as: :normalizers_nfd 90 | 91 | @doc """ 92 | Creates a NFKC Unicode normalizer. 93 | """ 94 | @spec nfkc :: t() 95 | defdelegate nfkc(), to: Tokenizers.Native, as: :normalizers_nfkc 96 | 97 | @doc """ 98 | Creates a NFKD Unicode normalizer. 99 | """ 100 | @spec nfkd :: t() 101 | defdelegate nfkd(), to: Tokenizers.Native, as: :normalizers_nfkd 102 | 103 | @doc """ 104 | Composes multiple normalizers that will run in the provided order. 105 | """ 106 | @spec sequence([t()]) :: t() 107 | defdelegate sequence(normalizers), to: Tokenizers.Native, as: :normalizers_sequence 108 | 109 | @doc """ 110 | Replaces all uppercase to lowercase 111 | """ 112 | @spec lowercase :: t() 113 | defdelegate lowercase(), to: Tokenizers.Native, as: :normalizers_lowercase 114 | 115 | @doc """ 116 | Creates a Nmt normalizer. 117 | """ 118 | @spec nmt :: t() 119 | defdelegate nmt(), to: Tokenizers.Native, as: :normalizers_nmt 120 | 121 | @doc """ 122 | Precompiled normalizer. 123 | 124 | Don’t use manually it is used for compatibility with SentencePiece. 125 | """ 126 | @spec precompiled(binary()) :: {:ok, t()} | {:error, any()} 127 | defdelegate precompiled(data), to: Tokenizers.Native, as: :normalizers_precompiled 128 | 129 | @doc """ 130 | Replaces a custom `search` string with the given `content`. 131 | """ 132 | @spec replace(String.t(), String.t()) :: t() 133 | def replace(search, content) do 134 | Tokenizers.Native.normalizers_replace({:string, search}, content) 135 | end 136 | 137 | @doc """ 138 | Replaces occurrences of a custom regexp `pattern` with the given `content`. 139 | 140 | The `pattern` should be a string representing a regular expression 141 | according to the [Oniguruma Regex Engine](https://github.com/kkos/oniguruma). 142 | """ 143 | @spec replace_regex(String.t(), String.t()) :: t() 144 | def replace_regex(pattern, content) do 145 | Tokenizers.Native.normalizers_replace({:regex, pattern}, content) 146 | end 147 | 148 | @doc """ 149 | Creates a Prepend normalizer. 150 | """ 151 | @spec prepend(prepend :: String.t()) :: t() 152 | defdelegate prepend(prepend), to: Tokenizers.Native, as: :normalizers_prepend 153 | 154 | @doc """ 155 | Created ByteLevel normalizer. 
156 | """ 157 | @spec byte_level :: t() 158 | defdelegate byte_level(), to: Tokenizers.Native, as: :normalizers_byte_level 159 | 160 | @doc """ 161 | Gets ByteLevel normalizer's alphabet. 162 | """ 163 | defdelegate byte_level_alphabet(), to: Tokenizers.Native, as: :normalizers_byte_level_alphabet 164 | end 165 | 166 | defimpl Inspect, for: Tokenizers.Normalizer do 167 | import Inspect.Algebra 168 | 169 | def inspect(decoder, opts) do 170 | attrs = 171 | decoder 172 | |> Tokenizers.Native.normalizers_info() 173 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 174 | 175 | concat(["#Tokenizers.Normalizer<", to_doc(attrs, opts), ">"]) 176 | end 177 | end 178 | -------------------------------------------------------------------------------- /lib/tokenizers/post_processor.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.PostProcessor do 2 | @moduledoc """ 3 | Post-processors. 4 | 5 | After the whole pipeline, we sometimes want to insert some special 6 | tokens before we feed the encoded text into a model like 7 | ”[CLS] My horse is amazing [SEP]”, we can do that with a post-processor. 8 | """ 9 | 10 | defstruct [:resource] 11 | 12 | @type t() :: %__MODULE__{resource: reference()} 13 | 14 | @doc """ 15 | Creates a Bert post-processor with the given tokens. 16 | """ 17 | @spec bert({String.t(), integer()}, {String.t(), integer()}) :: t() 18 | defdelegate bert(sep, cls), to: Tokenizers.Native, as: :post_processors_bert 19 | 20 | @doc """ 21 | Creates a Roberta post-processor. 22 | 23 | ## Options 24 | 25 | * `:trim_offsets` - whether to trim the whitespaces in the produced 26 | offsets. Defaults to `true` 27 | 28 | * `:add_prefix_space` - whether add_prefix_space was ON during the 29 | pre-tokenization. Defaults to `true` 30 | 31 | """ 32 | @spec roberta({String.t(), integer()}, {String.t(), integer()}, keyword()) :: t() 33 | defdelegate roberta(sep, cls, opts \\ []), to: Tokenizers.Native, as: :post_processors_roberta 34 | 35 | @doc """ 36 | Creates a ByteLevel post-processor. 37 | 38 | ## Options 39 | 40 | * `:trim_offsets` - whether to trim the whitespaces in the produced 41 | offsets. Defaults to `true` 42 | 43 | """ 44 | @spec byte_level(keyword()) :: t() 45 | defdelegate byte_level(opts \\ []), to: Tokenizers.Native, as: :post_processors_byte_level 46 | 47 | @doc """ 48 | Creates a Template post-processor. 49 | 50 | Lets you easily template the post processing, adding special tokens 51 | and specifying the type id for each sequence/special token. The 52 | template is given two strings representing the single sequence and 53 | the pair of sequences, as well as a set of special tokens to use. 54 | 55 | For example, when specifying a template with these values: 56 | 57 | * single: `"[CLS] $A [SEP]"` 58 | * pair: `"[CLS] $A [SEP] $B [SEP]"` 59 | * special tokens: 60 | * `"[CLS]"` 61 | * `"[SEP]"` 62 | 63 | > Input: `("I like this", "but not this")` 64 | > Output: `"[CLS] I like this [SEP] but not this [SEP]"` 65 | 66 | ## Options 67 | 68 | * `:single` - a string describing the template for a single 69 | sequence 70 | 71 | * `:pair` - a string describing the template for a pair of 72 | sequences 73 | 74 | * `:special_tokens` - a list of special tokens to use in the 75 | template. 
Must be a list of `{token, token_id}` tuples 76 | 77 | """ 78 | @spec template(keyword()) :: t() 79 | defdelegate template(opts \\ []), to: Tokenizers.Native, as: :post_processors_template 80 | 81 | @doc """ 82 | Instantiate a new Sequence post-processor 83 | """ 84 | @spec sequence(post_processors :: [t()]) :: t() 85 | defdelegate sequence(post_processors), to: Tokenizers.Native, as: :post_processors_sequence 86 | end 87 | 88 | defimpl Inspect, for: Tokenizers.PostProcessor do 89 | import Inspect.Algebra 90 | 91 | def inspect(decoder, opts) do 92 | attrs = 93 | decoder 94 | |> Tokenizers.Native.post_processors_info() 95 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 96 | 97 | concat(["#Tokenizers.PostProcessor<", to_doc(attrs, opts), ">"]) 98 | end 99 | end 100 | -------------------------------------------------------------------------------- /lib/tokenizers/pre_tokenizer.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.PreTokenizer do 2 | @moduledoc """ 3 | Pre-tokenizers. 4 | 5 | A pre-tokenizer takes care of splitting the input according to a set 6 | of rules. This pre-processing lets you ensure that the underlying 7 | model does not build tokens across multiple “splits”. For example 8 | if you don’t want to have whitespaces inside a token, then you can 9 | have a pre-tokenizer that splits on these whitespaces. 10 | 11 | You can easily combine multiple pre-tokenizers together using 12 | `sequence/1`. 13 | 14 | A pre-tokenizer is also allowed to modify the string, just like a 15 | normalizer does. This is necessary to allow some complicated 16 | algorithms that require to split before normalizing (e.g. ByteLevel). 17 | """ 18 | 19 | defstruct [:resource] 20 | 21 | @type t() :: %__MODULE__{resource: reference()} 22 | 23 | @doc """ 24 | Converts a string into a sequence of pre-tokens. 25 | """ 26 | @spec pre_tokenize(t(), String.t()) :: {:ok, [{String.t(), {integer(), integer()}}]} 27 | defdelegate pre_tokenize(pre_tokenizer, input), 28 | to: Tokenizers.Native, 29 | as: :pre_tokenizers_pre_tokenize 30 | 31 | @doc """ 32 | Creates a ByteLevel pre-tokenizer. 33 | 34 | Splits on whitespaces while remapping all the bytes to a set of 35 | visible characters. This technique has been introduced by OpenAI 36 | with GPT-2 and has some more or less nice properties: 37 | 38 | * Since it maps on bytes, a tokenizer using this only requires 39 | 256 characters as initial alphabet (the number of values a byte 40 | can have), as opposed to the 130,000+ Unicode characters. 41 | 42 | * A consequence of the previous point is that it is absolutely 43 | unnecessary to have an unknown token using this since we can 44 | represent anything with 256 tokens (Youhou!! 🎉🎉) 45 | 46 | * For non ascii characters, it gets completely unreadable, but it 47 | works nonetheless! 48 | 49 | ## Options 50 | 51 | * `:add_prefix_space` - whether to add a space to the first word 52 | if there isn’t already one. This lets us treat hello exactly 53 | like say hello. Defaults to `true` 54 | 55 | * `:use_regex` - set this to `false` to prevent this pre-tokenizer 56 | from using the GPT2 specific regexp for splitting on whitespace. 57 | Defaults to `true` 58 | 59 | """ 60 | @spec byte_level(keyword()) :: t() 61 | defdelegate byte_level(opts \\ []), to: Tokenizers.Native, as: :pre_tokenizers_byte_level 62 | 63 | @doc """ 64 | Gets ByteLevel pre-tokenizer's alphabet. 
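> Editor's example — a sketch tying the two pipeline stages above together: a normalizer sequence followed by ByteLevel pre-tokenization. The printed results are illustrative; the exact offsets and the `Ġ` byte remapping come from the ByteLevel scheme.

```elixir
normalizer =
  Tokenizers.Normalizer.sequence([
    Tokenizers.Normalizer.nfd(),
    Tokenizers.Normalizer.strip_accents(),
    Tokenizers.Normalizer.lowercase()
  ])

{:ok, normalized} = Tokenizers.Normalizer.normalize(normalizer, "Héllo World")
#=> {:ok, "hello world"}

pre_tokenizer = Tokenizers.PreTokenizer.byte_level(add_prefix_space: false)
Tokenizers.PreTokenizer.pre_tokenize(pre_tokenizer, normalized)
#=> {:ok, [{"hello", {0, 5}}, {"Ġworld", {5, 11}}]}
```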
65 | """ 66 | @spec byte_level_alphabet() :: charlist() 67 | defdelegate byte_level_alphabet(), 68 | to: Tokenizers.Native, 69 | as: :pre_tokenizers_byte_level_alphabet 70 | 71 | @doc """ 72 | Creates a Whitespace pre-tokenizer. 73 | 74 | Splits on word boundaries. Uses the following regular expression: 75 | `\w+|[^\w\s]+`. 76 | """ 77 | @spec whitespace() :: t() 78 | defdelegate whitespace(), to: Tokenizers.Native, as: :pre_tokenizers_whitespace 79 | 80 | @doc """ 81 | Creates a WhitespaceSplit pre-tokenizer. 82 | 83 | Splits on any whitespace character. 84 | """ 85 | @spec whitespace_split() :: t() 86 | defdelegate whitespace_split(), to: Tokenizers.Native, as: :pre_tokenizers_whitespace_split 87 | 88 | @doc """ 89 | Creates a BertPreTokenizer pre-tokenizer. 90 | 91 | Splits for use in BERT models. 92 | """ 93 | @spec bert_pre_tokenizer() :: t() 94 | defdelegate bert_pre_tokenizer(), to: Tokenizers.Native, as: :pre_tokenizers_bert 95 | 96 | @doc """ 97 | Creates Metaspace pre-tokenizer. 98 | 99 | Splits on whitespaces and replaces them with a special char “▁” 100 | (U+2581). 101 | 102 | ## Options 103 | 104 | * `:replacement` - the replacement character to use. Defaults to `"▁"` 105 | 106 | * `:prepend_scheme` - whether to add a space to the first word if there 107 | isn't already one. This lets us treat "hello" exactly like "say hello". 108 | Either of `:always`, `:never`, `:first`. `:first` means the space is 109 | only added on the first token (relevant when special tokens are used 110 | or other pre_tokenizer are used). Defaults to `:always` 111 | 112 | """ 113 | @spec metaspace(keyword()) :: t() 114 | defdelegate metaspace(opts \\ []), to: Tokenizers.Native, as: :pre_tokenizers_metaspace 115 | 116 | @doc """ 117 | Creates a CharDelimiterSplit pre-tokenizer. 118 | 119 | This pre-tokenizer simply splits on the provided delimiter. Works 120 | almost like simple split function, except that it accounts for 121 | multiple consecutive spaces. 122 | """ 123 | @spec char_delimiter_split(char()) :: t() 124 | defdelegate char_delimiter_split(delimiter), 125 | to: Tokenizers.Native, 126 | as: :pre_tokenizers_char_delimiter_split 127 | 128 | @typedoc """ 129 | Specifies how delimiter should behave for several pretokenizers. 130 | """ 131 | @type split_delimiter_behaviour() :: 132 | :removed 133 | | :isolated 134 | | :merged_with_previous 135 | | :merged_with_next 136 | | :contiguous 137 | 138 | @doc """ 139 | Creates a Split pre-tokenizer using a string as split pattern. 140 | 141 | Versatile pre-tokenizer that splits on provided pattern and according 142 | to provided behavior. 143 | 144 | ## Options 145 | 146 | * `:invert` - whether to invert the split or not. Defaults to `false` 147 | 148 | """ 149 | @spec split(String.t(), split_delimiter_behaviour(), keyword()) :: t() 150 | def split(pattern, behavior, opts \\ []) when is_binary(pattern) do 151 | Tokenizers.Native.pre_tokenizers_split({:string, pattern}, behavior, opts) 152 | end 153 | 154 | @doc ~S""" 155 | Creates a Split pre-tokenizer using a regular expression as split pattern. 156 | 157 | Versatile pre-tokenizer that splits on provided regex pattern and according 158 | to provided behavior. 159 | 160 | The `pattern` should be a string representing a regular expression 161 | according to the [Oniguruma Regex Engine](https://github.com/kkos/oniguruma). 162 | 163 | ## Options 164 | 165 | * `:invert` - whether to invert the split or not. 
Defaults to `false` 166 | 167 | ## Example 168 | 169 | iex> Tokenizers.PreTokenizer.split_regex(~S(\?\d{2}\?), :removed) 170 | #Tokenizers.PreTokenizer<[pre_tokenizer_type: "Split"]> 171 | 172 | """ 173 | @spec split_regex(String.t(), split_delimiter_behaviour(), keyword()) :: t() 174 | def split_regex(pattern, behavior, opts \\ []) when is_binary(pattern) do 175 | Tokenizers.Native.pre_tokenizers_split({:regex, pattern}, behavior, opts) 176 | end 177 | 178 | @doc """ 179 | Creates a Punctuation pre-tokenizer. 180 | 181 | Will isolate all punctuation characters. 182 | """ 183 | @spec punctuation(split_delimiter_behaviour()) :: t() 184 | defdelegate punctuation(behaviour), to: Tokenizers.Native, as: :pre_tokenizers_punctuation 185 | 186 | @doc """ 187 | Creates a Sequence pre-tokenizer. 188 | 189 | Lets you compose multiple pre-tokenizers that will be run in the 190 | given order. 191 | """ 192 | @spec sequence([t()]) :: t() 193 | defdelegate sequence(pre_tokenizers), to: Tokenizers.Native, as: :pre_tokenizers_sequence 194 | 195 | @doc """ 196 | Creates a Digits pre-tokenizer. 197 | 198 | Splits the numbers from any other characters. 199 | 200 | ## Options 201 | 202 | * `:individual_digits` - whether to split individual digits or not. 203 | Defaults to `false` 204 | 205 | """ 206 | @spec digits(keyword()) :: t() 207 | defdelegate digits(opts \\ []), 208 | to: Tokenizers.Native, 209 | as: :pre_tokenizers_digits 210 | end 211 | 212 | defimpl Inspect, for: Tokenizers.PreTokenizer do 213 | import Inspect.Algebra 214 | 215 | def inspect(decoder, opts) do 216 | attrs = 217 | decoder 218 | |> Tokenizers.Native.pre_tokenizers_info() 219 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 220 | 221 | concat(["#Tokenizers.PreTokenizer<", to_doc(attrs, opts), ">"]) 222 | end 223 | end 224 | -------------------------------------------------------------------------------- /lib/tokenizers/shared.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Shared do 2 | @moduledoc false 3 | 4 | def unwrap({:ok, value}), do: value 5 | def unwrap({:error, reason}), do: raise(reason) 6 | end 7 | -------------------------------------------------------------------------------- /lib/tokenizers/trainer.ex: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Trainer do 2 | @moduledoc """ 3 | A Trainer has the responsibility to train a model. 4 | We feed it with lines/sentences and then it can train the given Model. 5 | """ 6 | defstruct [:resource] 7 | @type t() :: %__MODULE__{resource: reference()} 8 | 9 | @doc """ 10 | Get trainer info 11 | """ 12 | @spec info(t()) :: map() 13 | defdelegate info(trainer), to: Tokenizers.Native, as: :trainers_info 14 | 15 | @typedoc """ 16 | Options for BPE trainer initialisation. All options can be ommited. 17 | """ 18 | @type bpe_options() :: [ 19 | vocab_size: non_neg_integer(), 20 | min_frequency: non_neg_integer(), 21 | special_tokens: [String.t()], 22 | limit_alphabet: non_neg_integer(), 23 | initial_alphabet: [char()], 24 | show_progress: boolean(), 25 | continuing_subword_prefix: String.t(), 26 | end_of_word_suffix: String.t() 27 | ] 28 | 29 | @doc """ 30 | Creates a new BPE Trainer. 31 | """ 32 | @spec bpe(bpe_options()) :: {:ok, t()} | {:error, any()} 33 | defdelegate bpe(options \\ []), to: Tokenizers.Native, as: :trainers_bpe_trainer 34 | 35 | @typedoc """ 36 | Options for WordPiece trainer initialisation. All options can be ommited. 
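> Editor's example — a sketch of configuring a BPE trainer with a few of the options listed above; every key comes from `bpe_options()` and the values are arbitrary.

```elixir
{:ok, trainer} =
  Tokenizers.Trainer.bpe(
    vocab_size: 30_000,
    min_frequency: 2,
    special_tokens: ["[UNK]", "[PAD]", "[CLS]", "[SEP]"]
  )

Tokenizers.Trainer.info(trainer)
# Returns a map describing the trainer configuration.
```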
37 | """ 38 | @type wordpiece_options() :: [ 39 | vocab_size: non_neg_integer(), 40 | min_frequency: non_neg_integer(), 41 | special_tokens: [String.t()], 42 | limit_alphabet: non_neg_integer(), 43 | initial_alphabet: [char()], 44 | show_progress: boolean(), 45 | continuing_subword_prefix: String.t(), 46 | end_of_word_suffix: String.t() 47 | ] 48 | 49 | @doc """ 50 | Creates a new WordPiece Trainer. 51 | """ 52 | @spec wordpiece(wordpiece_options()) :: {:ok, t()} | {:error, any()} 53 | defdelegate wordpiece(options \\ []), to: Tokenizers.Native, as: :trainers_wordpiece_trainer 54 | 55 | @typedoc """ 56 | Options for WordLevel trainer initialisation. All options can be ommited. 57 | """ 58 | @type wordlevel_options() :: [ 59 | vocab_size: non_neg_integer(), 60 | min_frequency: non_neg_integer(), 61 | special_tokens: [String.t()], 62 | show_progress: boolean() 63 | ] 64 | 65 | @doc """ 66 | Creates a new WordLevel Trainer. 67 | """ 68 | @spec wordlevel(wordlevel_options()) :: {:ok, t()} | {:error, any()} 69 | defdelegate wordlevel(options \\ []), to: Tokenizers.Native, as: :trainers_wordlevel_trainer 70 | 71 | @typedoc """ 72 | Options for Unigram trainer initialisation. All options can be ommited. 73 | """ 74 | @type unigram_options() :: [ 75 | vocab_size: non_neg_integer(), 76 | n_sub_iterations: non_neg_integer(), 77 | shrinking_factor: float(), 78 | special_tokens: [String.t()], 79 | initial_alphabet: [char()], 80 | uni_token: String.t(), 81 | max_piece_length: non_neg_integer(), 82 | seed_size: non_neg_integer(), 83 | show_progress: boolean() 84 | ] 85 | 86 | @doc """ 87 | Creates a new Unigram Trainer. 88 | """ 89 | @spec unigram(unigram_options()) :: {:ok, t()} | {:error, any()} 90 | defdelegate unigram(options \\ []), to: Tokenizers.Native, as: :trainers_unigram_trainer 91 | end 92 | 93 | defimpl Inspect, for: Tokenizers.Trainer do 94 | import Inspect.Algebra 95 | 96 | @spec inspect(Tokenizers.Trainer.t(), Inspect.Opts.t()) :: Inspect.Algebra.t() 97 | def inspect(trainer, opts) do 98 | attrs = 99 | trainer 100 | |> Tokenizers.Trainer.info() 101 | |> Keyword.new(fn {k, v} -> {String.to_atom(k), v} end) 102 | 103 | concat(["#Tokenizers.Trainer<", to_doc(attrs, opts), ">"]) 104 | end 105 | end 106 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.MixProject do 2 | use Mix.Project 3 | 4 | @source_url "https://github.com/elixir-nx/tokenizers" 5 | @version "0.6.0-dev" 6 | 7 | def project do 8 | [ 9 | app: :tokenizers, 10 | name: "Tokenizers", 11 | description: "Bindings to Hugging Face Tokenizers for Elixir", 12 | version: @version, 13 | elixir: "~> 1.13", 14 | package: package(), 15 | deps: deps(), 16 | docs: docs(), 17 | preferred_cli_env: [ 18 | docs: :docs, 19 | "hex.publish": :docs 20 | ] 21 | ] 22 | end 23 | 24 | def application do 25 | [ 26 | extra_applications: [:logger, :inets, :public_key] 27 | ] 28 | end 29 | 30 | defp deps do 31 | [ 32 | {:castore, "~> 0.1 or ~> 1.0"}, 33 | {:ex_doc, "~> 0.28", only: :docs, runtime: false}, 34 | {:rustler, ">= 0.0.0", optional: true}, 35 | {:rustler_precompiled, "~> 0.6"} 36 | ] 37 | end 38 | 39 | defp docs do 40 | [ 41 | main: "Tokenizers", 42 | source_ref: "v#{@version}", 43 | source_url: @source_url, 44 | extras: ["notebooks/pretrained.livemd", "notebooks/training.livemd", "LICENSE"], 45 | groups_for_modules: [ 46 | Tokenization: [ 47 | Tokenizers.Tokenizer, 48 | Tokenizers.Encoding, 49 | 
Tokenizers.Encoding.Transformation, 50 | Tokenizers.Decoder 51 | ], 52 | Pipeline: [ 53 | Tokenizers.Normalizer, 54 | Tokenizers.PreTokenizer, 55 | Tokenizers.PostProcessor 56 | ], 57 | Training: [ 58 | Tokenizers.Model, 59 | Tokenizers.Model.BPE, 60 | Tokenizers.Model.Unigram, 61 | Tokenizers.Model.WordLevel, 62 | Tokenizers.Model.WordPiece, 63 | Tokenizers.Trainer, 64 | Tokenizers.AddedToken 65 | ], 66 | Other: [ 67 | Tokenizers.HTTPClient 68 | ] 69 | ], 70 | groups_for_functions: [ 71 | # Tokenizers.Tokenizer 72 | Loading: &(&1[:type] == :loading), 73 | Inference: &(&1[:type] == :inference), 74 | Configuration: &(&1[:type] == :configuration), 75 | Training: &(&1[:type] == :training) 76 | ] 77 | ] 78 | end 79 | 80 | defp package do 81 | [ 82 | files: [ 83 | "lib", 84 | "native", 85 | "checksum-*.exs", 86 | "mix.exs", 87 | "LICENSE" 88 | ], 89 | licenses: ["Apache-2.0"], 90 | links: %{"GitHub" => @source_url}, 91 | maintainers: ["Christopher Grainger"] 92 | ] 93 | end 94 | end 95 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "castore": {:hex, :castore, "1.0.9", "5cc77474afadf02c7c017823f460a17daa7908e991b0cc917febc90e466a375c", [:mix], [], "hexpm", "5ea956504f1ba6f2b4eb707061d8e17870de2bee95fb59d512872c2ef06925e7"}, 3 | "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"}, 4 | "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"}, 5 | "finch": {:hex, :finch, "0.19.0", "c644641491ea854fc5c1bbaef36bfc764e3f08e7185e1f084e35e0672241b76d", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.6.2 or ~> 1.7", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.1", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "fc5324ce209125d1e2fa0fcd2634601c52a787aff1cd33ee833664a5af4ea2b6"}, 6 | "hpax": {:hex, :hpax, "1.0.0", "28dcf54509fe2152a3d040e4e3df5b265dcb6cb532029ecbacf4ce52caea3fd2", [:mix], [], "hexpm", "7f1314731d711e2ca5fdc7fd361296593fc2542570b3105595bb0bc6d0fad601"}, 7 | "jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"}, 8 | "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", 
"cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, 9 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, 10 | "makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"}, 11 | "mime": {:hex, :mime, "2.0.6", "8f18486773d9b15f95f4f4f1e39b710045fa1de891fada4516559967276e4dc2", [:mix], [], "hexpm", "c9945363a6b26d747389aac3643f8e0e09d30499a138ad64fe8fd1d13d9b153e"}, 12 | "mint": {:hex, :mint, "1.6.2", "af6d97a4051eee4f05b5500671d47c3a67dac7386045d87a904126fd4bbcea2e", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1 or ~> 0.2.0 or ~> 1.0", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "5ee441dffc1892f1ae59127f74afe8fd82fda6587794278d924e4d90ea3d63f9"}, 13 | "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, 14 | "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, 15 | "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, 16 | "req": {:hex, :req, "0.5.6", "8fe1eead4a085510fe3d51ad854ca8f20a622aae46e97b302f499dfb84f726ac", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.17", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 2.0.6 or ~> 2.1", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "cfaa8e720945d46654853de39d368f40362c2641c4b2153c886418914b372185"}, 17 | "rustler": {:hex, :rustler, "0.34.0", "e9a73ee419fc296a10e49b415a2eb87a88c9217aa0275ec9f383d37eed290c1c", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:req, "~> 0.5", [hex: :req, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "1d0c7449482b459513003230c0e2422b0252245776fe6fd6e41cb2b11bd8e628"}, 18 | "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.2", "5f25cbe220a8fac3e7ad62e6f950fcdca5a5a5f8501835d2823e8c74bf4268d5", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "63d1bd5f8e23096d1ff851839923162096364bac8656a4a3c00d1fff8e83ee0a"}, 19 | "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", 
"7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, 20 | "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, 21 | } 22 | -------------------------------------------------------------------------------- /native/ex_tokenizers/.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.x86_64-apple-darwin] 2 | rustflags = [ 3 | "-C", "link-arg=-undefined", 4 | "-C", "link-arg=dynamic_lookup", 5 | ] 6 | 7 | [target.aarch64-apple-darwin] 8 | rustflags = [ 9 | "-C", "link-arg=-undefined", 10 | "-C", "link-arg=dynamic_lookup", 11 | ] 12 | 13 | [target.x86_64-unknown-linux-musl] 14 | rustflags = [ 15 | "-C", "target-feature=-crt-static" 16 | ] 17 | 18 | [target.aarch64-unknown-linux-musl] 19 | rustflags = [ 20 | "-C", "target-feature=-crt-static" 21 | ] -------------------------------------------------------------------------------- /native/ex_tokenizers/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /native/ex_tokenizers/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ex_tokenizers" 3 | version = "0.1.0" 4 | authors = [] 5 | edition = "2021" 6 | 7 | [lib] 8 | name = "ex_tokenizers" 9 | path = "src/lib.rs" 10 | crate-type = ["cdylib"] 11 | 12 | [dependencies] 13 | anyhow = "1" 14 | rustler = "0.36.1" 15 | thiserror = "2" 16 | tokenizers = { version = "0.21.1", default-features = false, features = ["onig", "esaxx_fast"]} 17 | serde = { version = "1.0", features = [ "rc", "derive" ] } 18 | -------------------------------------------------------------------------------- /native/ex_tokenizers/Cross.toml: -------------------------------------------------------------------------------- 1 | [build.env] 2 | passthrough = [ 3 | "RUSTLER_NIF_VERSION" 4 | ] 5 | -------------------------------------------------------------------------------- /native/ex_tokenizers/README.md: -------------------------------------------------------------------------------- 1 | # NIF for Tokenizers.Native 2 | 3 | ## To build the NIF module: 4 | 5 | - Your NIF will now build along with your project. 6 | 7 | ## To load the NIF: 8 | 9 | ```elixir 10 | defmodule Native do 11 | use Rustler, otp_app: :tokenizers, crate: "ex_tokenizers" 12 | 13 | # When your NIF is loaded, it will override this function. 14 | def add(_a, _b), do: :erlang.nif_error(:nif_not_loaded) 15 | end 16 | ``` 17 | 18 | ## Examples 19 | 20 | [This](https://github.com/hansihe/NifIo) is a complete example of a NIF written in Rust. 
21 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/added_token.rs: -------------------------------------------------------------------------------- 1 | use crate::{new_info, util::Info}; 2 | use rustler::{NifTaggedEnum, NifUntaggedEnum, Resource}; 3 | use serde::{Deserialize, Serialize}; 4 | use tokenizers::AddedToken; 5 | 6 | pub struct ExTokenizersAddedTokenRef(pub AddedToken); 7 | 8 | #[rustler::resource_impl] 9 | impl Resource for ExTokenizersAddedTokenRef {} 10 | 11 | #[derive(rustler::NifStruct)] 12 | #[module = "Tokenizers.AddedToken"] 13 | pub struct ExTokenizersAddedToken { 14 | pub resource: rustler::ResourceArc, 15 | } 16 | 17 | impl Serialize for ExTokenizersAddedToken { 18 | fn serialize(&self, serializer: S) -> Result 19 | where 20 | S: serde::Serializer, 21 | { 22 | self.resource.0.serialize(serializer) 23 | } 24 | } 25 | 26 | impl<'de> Deserialize<'de> for ExTokenizersAddedToken { 27 | fn deserialize(deserializer: D) -> Result 28 | where 29 | D: serde::Deserializer<'de>, 30 | { 31 | Ok(ExTokenizersAddedToken::new(AddedToken::deserialize( 32 | deserializer, 33 | )?)) 34 | } 35 | } 36 | 37 | impl ExTokenizersAddedTokenRef { 38 | pub fn new(data: T) -> Self 39 | where 40 | T: Into, 41 | { 42 | Self(data.into()) 43 | } 44 | } 45 | 46 | impl ExTokenizersAddedToken { 47 | pub fn new(data: T) -> Self 48 | where 49 | T: Into, 50 | { 51 | Self { 52 | resource: rustler::ResourceArc::new(ExTokenizersAddedTokenRef::new(data)), 53 | } 54 | } 55 | } 56 | 57 | #[derive(NifUntaggedEnum)] 58 | pub enum AddedTokenInput { 59 | AddedToken(ExTokenizersAddedToken), 60 | String(String), 61 | } 62 | 63 | #[derive(NifUntaggedEnum)] 64 | pub enum AddedSpecialTokenInput { 65 | AddedToken(ExTokenizersAddedToken), 66 | String(String), 67 | } 68 | 69 | impl From<&AddedTokenInput> for AddedToken { 70 | fn from(input: &AddedTokenInput) -> Self { 71 | match input { 72 | AddedTokenInput::AddedToken(token) => token.resource.0.clone(), 73 | AddedTokenInput::String(string) => AddedToken::from(string, false), 74 | } 75 | } 76 | } 77 | 78 | impl From<&AddedSpecialTokenInput> for AddedToken { 79 | fn from(input: &AddedSpecialTokenInput) -> Self { 80 | match input { 81 | AddedSpecialTokenInput::AddedToken(token) => token.resource.0.clone(), 82 | AddedSpecialTokenInput::String(string) => AddedToken::from(string, true), 83 | } 84 | } 85 | } 86 | 87 | /////////////////////////////////////////////////////////////////////////////// 88 | /// Inspection 89 | /////////////////////////////////////////////////////////////////////////////// 90 | 91 | #[rustler::nif] 92 | fn added_token_info(added_token: ExTokenizersAddedToken) -> Info { 93 | let added_token: &AddedToken = &added_token.resource.0; 94 | new_info!( 95 | content: added_token.content.clone(), 96 | single_word: added_token.single_word, 97 | lstrip: added_token.lstrip, 98 | rstrip: added_token.rstrip, 99 | normalized: added_token.normalized, 100 | special: added_token.special 101 | ) 102 | } 103 | 104 | #[derive(NifTaggedEnum)] 105 | pub enum AddedTokenOption { 106 | Special(bool), 107 | SingleWord(bool), 108 | Lstrip(bool), 109 | Rstrip(bool), 110 | Normalized(bool), 111 | } 112 | 113 | #[rustler::nif] 114 | fn added_token_new(token: String, options: Vec) -> ExTokenizersAddedToken { 115 | struct Opts { 116 | special: bool, 117 | single_word: bool, 118 | lstrip: bool, 119 | rstrip: bool, 120 | normalized: Option, 121 | } 122 | let mut opts = Opts { 123 | special: false, 124 | single_word: false, 125 | 
lstrip: false, 126 | rstrip: false, 127 | normalized: None, 128 | }; 129 | 130 | for option in options { 131 | match option { 132 | AddedTokenOption::Special(value) => opts.special = value, 133 | AddedTokenOption::SingleWord(value) => opts.single_word = value, 134 | AddedTokenOption::Lstrip(value) => opts.lstrip = value, 135 | AddedTokenOption::Rstrip(value) => opts.rstrip = value, 136 | AddedTokenOption::Normalized(value) => opts.normalized = Some(value), 137 | } 138 | } 139 | 140 | let mut token = AddedToken::from(token, opts.special); 141 | token = token.single_word(opts.single_word); 142 | token = token.lstrip(opts.lstrip); 143 | token = token.rstrip(opts.rstrip); 144 | if let Some(normalized) = opts.normalized { 145 | token = token.normalized(normalized); 146 | } 147 | 148 | ExTokenizersAddedToken::new(token) 149 | } 150 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/decode_stream.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | use crate::{new_info, tokenizer::ExTokenizersTokenizer, util::Info, ExTokenizersError}; 4 | 5 | #[derive(Serialize, Deserialize, Clone, Debug)] 6 | pub struct ExTokenizersDecodeStreamRef { 7 | skip_special_tokens: bool, 8 | ids: Vec, 9 | prefix: String, 10 | prefix_index: usize, 11 | read_index: usize, 12 | } 13 | 14 | impl ExTokenizersDecodeStreamRef { 15 | pub fn step( 16 | &mut self, 17 | tokenizer: ExTokenizersTokenizer, 18 | id: u32, 19 | ) -> tokenizers::tokenizer::Result> { 20 | tokenizers::step_decode_stream( 21 | &tokenizer.resource.0, 22 | id, 23 | self.skip_special_tokens, 24 | &mut self.ids, 25 | &mut self.prefix, 26 | &mut self.prefix_index, 27 | ) 28 | } 29 | } 30 | 31 | pub struct ExTokenizerDecodeStreamLock { 32 | pub inner: std::sync::RwLock, 33 | } 34 | 35 | #[rustler::resource_impl] 36 | impl rustler::Resource for ExTokenizerDecodeStreamLock {} 37 | 38 | #[derive(rustler::NifStruct)] 39 | #[module = "Tokenizers.DecodeStream"] 40 | pub struct ExTokenizersDecodeStream { 41 | pub resource: rustler::ResourceArc, 42 | } 43 | 44 | impl Serialize for ExTokenizersDecodeStream { 45 | fn serialize(&self, serializer: S) -> Result 46 | where 47 | S: serde::Serializer, 48 | { 49 | self.resource.inner.serialize(serializer) 50 | } 51 | } 52 | 53 | impl<'de> Deserialize<'de> for ExTokenizersDecodeStream { 54 | fn deserialize(deserializer: D) -> Result 55 | where 56 | D: serde::Deserializer<'de>, 57 | { 58 | Ok(ExTokenizersDecodeStream::new( 59 | ExTokenizersDecodeStreamRef::deserialize(deserializer)?, 60 | )) 61 | } 62 | } 63 | 64 | impl Clone for ExTokenizersDecodeStream { 65 | fn clone(&self) -> Self { 66 | Self { 67 | resource: rustler::ResourceArc::new(ExTokenizerDecodeStreamLock { 68 | inner: std::sync::RwLock::new(self.resource.inner.read().unwrap().clone()), 69 | }), 70 | } 71 | } 72 | } 73 | 74 | impl ExTokenizersDecodeStream { 75 | pub fn new(data: ExTokenizersDecodeStreamRef) -> Self { 76 | Self { 77 | resource: rustler::ResourceArc::new(ExTokenizerDecodeStreamLock { 78 | inner: std::sync::RwLock::new(data), 79 | }), 80 | } 81 | } 82 | } 83 | 84 | #[rustler::nif(schedule = "DirtyCpu")] 85 | fn decoder_stream_step( 86 | decode_stream: ExTokenizersDecodeStream, 87 | tokenizer: ExTokenizersTokenizer, 88 | id: u32, 89 | ) -> Result, ExTokenizersError> { 90 | decode_stream 91 | .resource 92 | .inner 93 | .write() 94 | .unwrap() 95 | .step(tokenizer, id) 96 | .map_err(ExTokenizersError::Tokenizer) 97 | } 98 | 
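> Editor's example — a rough Elixir-side sketch of driving the streaming decode implemented above. For illustration it calls the private `Tokenizers.Native` stubs directly (the public `Tokenizers.DecodeStream` module wraps them); the `{:ok, chunk | nil}` return shape and the token ids are assumptions.

```elixir
# Assumes `tokenizer` is an already-loaded tokenizer struct.
stream = Tokenizers.Native.decoder_stream_new(false)

for id <- [7592, 2088] do
  case Tokenizers.Native.decoder_stream_step(stream, tokenizer, id) do
    # No printable text yet (the id only extended the internal prefix).
    {:ok, nil} -> :ok
    # A decoded chunk became available; emit it incrementally.
    {:ok, chunk} -> IO.write(chunk)
  end
end
```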
99 | #[rustler::nif] 100 | fn decoder_stream_new(skip_special_tokens: bool) -> ExTokenizersDecodeStream { 101 | let ds = ExTokenizersDecodeStreamRef { 102 | skip_special_tokens, 103 | ids: vec![], 104 | prefix: "".to_string(), 105 | prefix_index: 0, 106 | read_index: 0, 107 | }; 108 | 109 | ExTokenizersDecodeStream::new(ds) 110 | } 111 | 112 | /////////////////////////////////////////////////////////////////////////////// 113 | /// Inspection 114 | /////////////////////////////////////////////////////////////////////////////// 115 | 116 | #[rustler::nif] 117 | fn decoder_stream_info(decode_stream: ExTokenizersDecodeStream) -> Info { 118 | let ds = decode_stream.resource.inner.read().unwrap(); 119 | 120 | new_info! { 121 | skip_special_tokens: ds.skip_special_tokens 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/decoders.rs: -------------------------------------------------------------------------------- 1 | use rustler::NifTaggedEnum; 2 | use serde::{Deserialize, Serialize}; 3 | use tokenizers::{Decoder, DecoderWrapper}; 4 | 5 | use crate::{new_info, util::Info, ExTokenizersError}; 6 | 7 | pub struct ExTokenizersDecoderRef(pub DecoderWrapper); 8 | 9 | #[rustler::resource_impl] 10 | impl rustler::Resource for ExTokenizersDecoderRef {} 11 | 12 | #[derive(rustler::NifStruct)] 13 | #[module = "Tokenizers.Decoder"] 14 | pub struct ExTokenizersDecoder { 15 | pub resource: rustler::ResourceArc, 16 | } 17 | 18 | impl Serialize for ExTokenizersDecoder { 19 | fn serialize(&self, serializer: S) -> Result 20 | where 21 | S: serde::Serializer, 22 | { 23 | self.resource.0.serialize(serializer) 24 | } 25 | } 26 | 27 | impl<'de> Deserialize<'de> for ExTokenizersDecoder { 28 | fn deserialize(deserializer: D) -> Result 29 | where 30 | D: serde::Deserializer<'de>, 31 | { 32 | Ok(ExTokenizersDecoder::new(DecoderWrapper::deserialize( 33 | deserializer, 34 | )?)) 35 | } 36 | } 37 | 38 | impl Clone for ExTokenizersDecoder { 39 | fn clone(&self) -> Self { 40 | Self { 41 | resource: self.resource.clone(), 42 | } 43 | } 44 | } 45 | 46 | impl ExTokenizersDecoderRef { 47 | pub fn new(data: T) -> Self 48 | where 49 | T: Into, 50 | { 51 | Self(data.into()) 52 | } 53 | } 54 | 55 | impl ExTokenizersDecoder { 56 | pub fn new(data: T) -> Self 57 | where 58 | T: Into, 59 | { 60 | Self { 61 | resource: rustler::ResourceArc::new(ExTokenizersDecoderRef::new(data)), 62 | } 63 | } 64 | } 65 | 66 | impl tokenizers::Decoder for ExTokenizersDecoder { 67 | fn decode_chain(&self, tokens: Vec) -> tokenizers::Result> { 68 | self.resource.0.decode_chain(tokens) 69 | } 70 | } 71 | 72 | #[rustler::nif(schedule = "DirtyCpu")] 73 | fn decoders_decode( 74 | decoder: ExTokenizersDecoder, 75 | tokens: Vec, 76 | ) -> Result { 77 | decoder 78 | .resource 79 | .0 80 | .decode(tokens) 81 | .map_err(ExTokenizersError::Tokenizer) 82 | } 83 | 84 | /////////////////////////////////////////////////////////////////////////////// 85 | /// Inspection 86 | /////////////////////////////////////////////////////////////////////////////// 87 | 88 | #[rustler::nif] 89 | fn decoders_info(decoder: ExTokenizersDecoder) -> Info { 90 | match &decoder.resource.0 { 91 | tokenizers::DecoderWrapper::BPE(decoder) => new_info! { 92 | decoder_type: "BPE", 93 | suffix: decoder.suffix.clone() 94 | }, 95 | tokenizers::DecoderWrapper::ByteLevel(decoder) => new_info! 
{ 96 | decoder_type: "ByteLevel", 97 | add_prefix_space: decoder.add_prefix_space, 98 | trim_offsets: decoder.trim_offsets, 99 | use_regex: decoder.use_regex 100 | }, 101 | tokenizers::DecoderWrapper::WordPiece(decoder) => new_info! { 102 | decoder_type: "WordPiece", 103 | prefix: decoder.prefix.clone(), 104 | cleanup: decoder.cleanup 105 | }, 106 | tokenizers::DecoderWrapper::Metaspace(decoder) => new_info! { 107 | decoder_type: "Metaspace", 108 | prepend_scheme: match decoder.prepend_scheme { 109 | tokenizers::pre_tokenizers::metaspace::PrependScheme::First => "first", 110 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Never => "never", 111 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Always => "always", 112 | } 113 | }, 114 | tokenizers::DecoderWrapper::CTC(decoder) => new_info! { 115 | decoder_type: "CTC", 116 | pad_token: decoder.pad_token.clone(), 117 | word_delimiter_token: decoder.word_delimiter_token.clone(), 118 | cleanup: decoder.cleanup 119 | }, 120 | tokenizers::DecoderWrapper::Sequence(_) => new_info! { 121 | decoder_type: "Sequence" 122 | }, 123 | DecoderWrapper::Replace(_) => new_info! { 124 | decoder_type: "Replace" 125 | }, 126 | DecoderWrapper::Fuse(_) => new_info! { 127 | decoder_type: "Fuse" 128 | }, 129 | DecoderWrapper::Strip(decoder) => new_info! { 130 | decoder_type: "Strip", 131 | content: decoder.content as u32, 132 | start: decoder.start, 133 | stop: decoder.stop 134 | }, 135 | DecoderWrapper::ByteFallback(_) => new_info! { 136 | decoder_type: "ByteFallback" 137 | }, 138 | } 139 | } 140 | 141 | /////////////////////////////////////////////////////////////////////////////// 142 | /// Builders 143 | /////////////////////////////////////////////////////////////////////////////// 144 | 145 | #[rustler::nif] 146 | fn decoders_byte_level() -> ExTokenizersDecoder { 147 | ExTokenizersDecoder::new(tokenizers::decoders::byte_level::ByteLevel::default()) 148 | } 149 | 150 | #[rustler::nif] 151 | fn decoders_replace( 152 | pattern: String, 153 | content: String, 154 | ) -> Result { 155 | Ok(ExTokenizersDecoder::new( 156 | tokenizers::normalizers::Replace::new(pattern, content) 157 | .map_err(|_| rustler::Error::BadArg)?, 158 | )) 159 | } 160 | 161 | #[derive(NifTaggedEnum)] 162 | pub enum WordpieceOption { 163 | Prefix(String), 164 | Cleanup(bool), 165 | } 166 | 167 | #[rustler::nif] 168 | fn decoders_wordpiece(options: Vec) -> ExTokenizersDecoder { 169 | struct Opts { 170 | prefix: String, 171 | cleanup: bool, 172 | } 173 | let mut opts = Opts { 174 | prefix: "##".into(), 175 | cleanup: true, 176 | }; 177 | for opt in options { 178 | match opt { 179 | WordpieceOption::Prefix(prefix) => opts.prefix = prefix, 180 | WordpieceOption::Cleanup(cleanup) => opts.cleanup = cleanup, 181 | } 182 | } 183 | ExTokenizersDecoder::new(tokenizers::decoders::wordpiece::WordPiece::new( 184 | opts.prefix, 185 | opts.cleanup, 186 | )) 187 | } 188 | 189 | #[rustler::nif] 190 | fn decoders_byte_fallback() -> ExTokenizersDecoder { 191 | ExTokenizersDecoder::new(tokenizers::decoders::byte_fallback::ByteFallback::new()) 192 | } 193 | 194 | #[rustler::nif] 195 | fn decoders_fuse() -> ExTokenizersDecoder { 196 | ExTokenizersDecoder::new(tokenizers::decoders::fuse::Fuse::new()) 197 | } 198 | 199 | #[rustler::nif] 200 | fn decoders_strip( 201 | content: u32, 202 | left: usize, 203 | right: usize, 204 | ) -> Result { 205 | let content = std::char::from_u32(content).ok_or(rustler::Error::BadArg)?; 206 | Ok(ExTokenizersDecoder::new( 207 | 
tokenizers::decoders::strip::Strip::new(content, left, right), 208 | )) 209 | } 210 | 211 | #[derive(NifTaggedEnum)] 212 | pub enum MetaspaceOption { 213 | Replacement(u32), 214 | PrependScheme(PrependScheme), 215 | } 216 | 217 | #[derive(NifTaggedEnum)] 218 | pub enum PrependScheme { 219 | First, 220 | Never, 221 | Always, 222 | } 223 | 224 | #[rustler::nif] 225 | fn decoders_metaspace( 226 | options: Vec, 227 | ) -> Result { 228 | struct Opts { 229 | replacement: char, 230 | prepend_scheme: tokenizers::decoders::metaspace::PrependScheme, 231 | } 232 | let mut opts = Opts { 233 | replacement: '▁', 234 | prepend_scheme: tokenizers::decoders::metaspace::PrependScheme::Always, 235 | }; 236 | for opt in options { 237 | match opt { 238 | MetaspaceOption::Replacement(replacement) => { 239 | opts.replacement = std::char::from_u32(replacement).ok_or(rustler::Error::BadArg)? 240 | } 241 | MetaspaceOption::PrependScheme(prepend_scheme) => { 242 | opts.prepend_scheme = match prepend_scheme { 243 | PrependScheme::First => { 244 | tokenizers::pre_tokenizers::metaspace::PrependScheme::First 245 | } 246 | PrependScheme::Never => { 247 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Never 248 | } 249 | PrependScheme::Always => { 250 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Always 251 | } 252 | } 253 | } 254 | } 255 | } 256 | Ok(ExTokenizersDecoder::new( 257 | tokenizers::decoders::metaspace::Metaspace::new( 258 | opts.replacement, 259 | opts.prepend_scheme, 260 | true, 261 | ), 262 | )) 263 | } 264 | 265 | #[derive(NifTaggedEnum)] 266 | pub enum BpeOption { 267 | Suffix(String), 268 | } 269 | 270 | #[rustler::nif] 271 | fn decoders_bpe(options: Vec) -> ExTokenizersDecoder { 272 | struct Opts { 273 | suffix: String, 274 | } 275 | let mut opts = Opts { 276 | suffix: "".into(), 277 | }; 278 | for opt in options { 279 | match opt { 280 | BpeOption::Suffix(suffix) => opts.suffix = suffix, 281 | } 282 | } 283 | 284 | ExTokenizersDecoder::new(tokenizers::decoders::bpe::BPEDecoder::new(opts.suffix)) 285 | } 286 | 287 | #[derive(NifTaggedEnum)] 288 | pub enum CTCOption { 289 | PadToken(String), 290 | WordDelimiterToken(String), 291 | Cleanup(bool), 292 | } 293 | 294 | #[rustler::nif] 295 | fn decoders_ctc(options: Vec) -> ExTokenizersDecoder { 296 | struct Opts { 297 | pad_token: String, 298 | word_delimiter_token: String, 299 | cleanup: bool, 300 | } 301 | let mut opts = Opts { 302 | pad_token: "".into(), 303 | word_delimiter_token: "|".into(), 304 | cleanup: true, 305 | }; 306 | 307 | for opt in options { 308 | match opt { 309 | CTCOption::PadToken(pad_token) => opts.pad_token = pad_token, 310 | CTCOption::WordDelimiterToken(word_delimiter_token) => { 311 | opts.word_delimiter_token = word_delimiter_token 312 | } 313 | CTCOption::Cleanup(cleanup) => opts.cleanup = cleanup, 314 | } 315 | } 316 | 317 | ExTokenizersDecoder::new(tokenizers::decoders::ctc::CTC::new( 318 | opts.pad_token, 319 | opts.word_delimiter_token, 320 | opts.cleanup, 321 | )) 322 | } 323 | 324 | #[rustler::nif] 325 | fn decoders_sequence(decoders: Vec) -> ExTokenizersDecoder { 326 | let sequence = decoders 327 | .iter() 328 | .map(|decoder| decoder.resource.clone()) 329 | .fold(Vec::with_capacity(decoders.len()), |mut acc, decoder| { 330 | acc.push(decoder.0.clone()); 331 | acc 332 | }); 333 | 334 | ExTokenizersDecoder::new(tokenizers::decoders::sequence::Sequence::new(sequence)) 335 | } 336 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/encoding.rs: 
-------------------------------------------------------------------------------- 1 | use rustler::{Binary, Env, NifTaggedEnum, ResourceArc}; 2 | use tokenizers::Encoding; 3 | 4 | use crate::util::Direction; 5 | 6 | pub struct ExTokenizersEncodingRef(pub Encoding); 7 | 8 | #[rustler::resource_impl] 9 | impl rustler::Resource for ExTokenizersEncodingRef {} 10 | 11 | #[derive(rustler::NifStruct)] 12 | #[module = "Tokenizers.Encoding"] 13 | pub struct ExTokenizersEncoding { 14 | pub resource: ResourceArc, 15 | } 16 | 17 | impl From for ExTokenizersEncoding { 18 | fn from(encoding: Encoding) -> Self { 19 | Self { 20 | resource: ResourceArc::new(ExTokenizersEncodingRef(encoding)), 21 | } 22 | } 23 | } 24 | 25 | /////////////////////////////////////////////////////////////////////////////// 26 | /// Implementation 27 | /////////////////////////////////////////////////////////////////////////////// 28 | 29 | #[rustler::nif] 30 | pub fn encoding_get_length(encoding: ExTokenizersEncoding) -> usize { 31 | encoding.resource.0.len() 32 | } 33 | 34 | #[rustler::nif] 35 | pub fn encoding_get_n_sequences(encoding: ExTokenizersEncoding) -> usize { 36 | encoding.resource.0.n_sequences() 37 | } 38 | 39 | #[rustler::nif] 40 | pub fn encoding_set_sequence_id( 41 | encoding: ExTokenizersEncoding, 42 | seq_id: usize, 43 | ) -> ExTokenizersEncoding { 44 | let mut encoding = encoding.resource.0.clone(); 45 | encoding.set_sequence_id(seq_id); 46 | encoding.into() 47 | } 48 | 49 | #[rustler::nif] 50 | pub fn encoding_get_ids(encoding: ExTokenizersEncoding) -> Vec { 51 | encoding.resource.0.get_ids().to_vec() 52 | } 53 | 54 | #[rustler::nif] 55 | pub fn encoding_get_u32_ids(env: Env, encoding: ExTokenizersEncoding) -> Binary { 56 | encoding 57 | .resource 58 | .make_binary(env, |r| slice_u32_to_u8(r.0.get_ids())) 59 | } 60 | 61 | #[rustler::nif] 62 | pub fn encoding_get_type_ids(encoding: ExTokenizersEncoding) -> Vec { 63 | encoding.resource.0.get_type_ids().to_vec() 64 | } 65 | 66 | #[rustler::nif] 67 | pub fn encoding_get_u32_type_ids(env: Env, encoding: ExTokenizersEncoding) -> Binary { 68 | encoding 69 | .resource 70 | .make_binary(env, |r| slice_u32_to_u8(r.0.get_type_ids())) 71 | } 72 | 73 | #[rustler::nif] 74 | pub fn encoding_get_attention_mask(encoding: ExTokenizersEncoding) -> Vec { 75 | encoding.resource.0.get_attention_mask().to_vec() 76 | } 77 | 78 | #[rustler::nif] 79 | pub fn encoding_get_u32_attention_mask(env: Env, encoding: ExTokenizersEncoding) -> Binary { 80 | encoding 81 | .resource 82 | .make_binary(env, |r| slice_u32_to_u8(r.0.get_attention_mask())) 83 | } 84 | 85 | #[rustler::nif] 86 | pub fn encoding_get_special_tokens_mask(encoding: ExTokenizersEncoding) -> Vec { 87 | encoding.resource.0.get_special_tokens_mask().to_vec() 88 | } 89 | 90 | #[rustler::nif] 91 | pub fn encoding_get_u32_special_tokens_mask(env: Env, encoding: ExTokenizersEncoding) -> Binary { 92 | encoding 93 | .resource 94 | .make_binary(env, |r| slice_u32_to_u8(r.0.get_special_tokens_mask())) 95 | } 96 | 97 | #[rustler::nif] 98 | pub fn encoding_get_tokens(encoding: ExTokenizersEncoding) -> Vec { 99 | encoding.resource.0.get_tokens().to_vec() 100 | } 101 | 102 | #[rustler::nif] 103 | pub fn encoding_get_word_ids(encoding: ExTokenizersEncoding) -> Vec> { 104 | encoding.resource.0.get_word_ids().to_vec() 105 | } 106 | 107 | #[rustler::nif] 108 | pub fn encoding_get_sequence_ids(encoding: ExTokenizersEncoding) -> Vec> { 109 | encoding.resource.0.get_sequence_ids().to_vec() 110 | } 111 | 112 | #[rustler::nif] 113 | pub fn 
encoding_get_offsets(encoding: ExTokenizersEncoding) -> Vec<(usize, usize)> { 114 | encoding.resource.0.get_offsets().to_vec() 115 | } 116 | 117 | #[rustler::nif] 118 | pub fn encoding_get_overflowing(encoding: ExTokenizersEncoding) -> Vec { 119 | encoding 120 | .resource 121 | .0 122 | .get_overflowing() 123 | .iter() 124 | .map(|encoding| encoding.clone().into()) 125 | .collect::>() 126 | } 127 | 128 | #[rustler::nif] 129 | pub fn encoding_word_to_tokens( 130 | encoding: ExTokenizersEncoding, 131 | word: u32, 132 | seq_id: usize, 133 | ) -> Option<(usize, usize)> { 134 | encoding.resource.0.word_to_tokens(word, seq_id) 135 | } 136 | 137 | #[rustler::nif] 138 | pub fn encoding_word_to_chars( 139 | encoding: ExTokenizersEncoding, 140 | word: u32, 141 | seq_id: usize, 142 | ) -> Option<(usize, usize)> { 143 | encoding.resource.0.word_to_chars(word, seq_id) 144 | } 145 | 146 | #[rustler::nif] 147 | pub fn encoding_token_to_sequence(encoding: ExTokenizersEncoding, token: usize) -> Option { 148 | encoding.resource.0.token_to_sequence(token) 149 | } 150 | 151 | #[rustler::nif] 152 | pub fn encoding_token_to_chars( 153 | encoding: ExTokenizersEncoding, 154 | token: usize, 155 | ) -> Option<(usize, (usize, usize))> { 156 | encoding.resource.0.token_to_chars(token) 157 | } 158 | 159 | #[rustler::nif] 160 | pub fn encoding_token_to_word( 161 | encoding: ExTokenizersEncoding, 162 | token: usize, 163 | ) -> Option<(usize, u32)> { 164 | encoding.resource.0.token_to_word(token) 165 | } 166 | 167 | #[rustler::nif] 168 | pub fn encoding_char_to_token( 169 | encoding: ExTokenizersEncoding, 170 | position: usize, 171 | seq_id: usize, 172 | ) -> Option { 173 | encoding.resource.0.char_to_token(position, seq_id) 174 | } 175 | 176 | #[rustler::nif] 177 | pub fn encoding_char_to_word( 178 | encoding: ExTokenizersEncoding, 179 | position: usize, 180 | seq_id: usize, 181 | ) -> Option { 182 | encoding.resource.0.char_to_word(position, seq_id) 183 | } 184 | 185 | #[derive(NifTaggedEnum)] 186 | pub enum PadOption { 187 | PadId(u32), 188 | PadTypeId(u32), 189 | PadToken(String), 190 | Direction(Direction), 191 | } 192 | 193 | struct Padding { 194 | pad_id: u32, 195 | pad_type_id: u32, 196 | pad_token: String, 197 | direction: Direction, 198 | } 199 | 200 | fn parse_pad_options(opts: &Vec) -> Padding { 201 | let mut default = Padding { 202 | pad_id: 0, 203 | pad_type_id: 0, 204 | pad_token: "[PAD]".to_string(), 205 | direction: Direction::Right, 206 | }; 207 | for opt in opts { 208 | match opt { 209 | PadOption::PadId(id) => default.pad_id = *id, 210 | PadOption::PadTypeId(id) => default.pad_type_id = *id, 211 | PadOption::PadToken(token) => default.pad_token = token.clone(), 212 | PadOption::Direction(direction) => default.direction = direction.clone(), 213 | } 214 | } 215 | default 216 | } 217 | 218 | #[rustler::nif] 219 | pub fn encoding_pad( 220 | encoding: ExTokenizersEncoding, 221 | target_length: usize, 222 | opts: Vec, 223 | ) -> ExTokenizersEncoding { 224 | let default = parse_pad_options(&opts); 225 | 226 | let mut encoding = encoding.resource.0.clone(); 227 | encoding.pad( 228 | target_length, 229 | default.pad_id, 230 | default.pad_type_id, 231 | &default.pad_token, 232 | default.direction.into(), 233 | ); 234 | encoding.into() 235 | } 236 | 237 | #[derive(NifTaggedEnum)] 238 | pub enum TruncationOption { 239 | Stride(usize), 240 | Direction(Direction), 241 | } 242 | 243 | struct Truncation { 244 | stride: usize, 245 | direction: Direction, 246 | } 247 | 248 | fn parse_truncation_options(opts: &Vec) -> 
Truncation { 249 | let mut default = Truncation { 250 | stride: 0, 251 | direction: Direction::Right, 252 | }; 253 | 254 | for opt in opts { 255 | match opt { 256 | TruncationOption::Stride(stride) => default.stride = *stride, 257 | TruncationOption::Direction(direction) => default.direction = direction.clone(), 258 | } 259 | } 260 | default 261 | } 262 | 263 | #[rustler::nif] 264 | pub fn encoding_truncate( 265 | encoding: ExTokenizersEncoding, 266 | max_len: usize, 267 | opts: Vec, 268 | ) -> ExTokenizersEncoding { 269 | let default = parse_truncation_options(&opts); 270 | 271 | let mut encoding = encoding.resource.0.clone(); 272 | 273 | encoding.truncate(max_len, default.stride, default.direction.into()); 274 | encoding.into() 275 | } 276 | 277 | fn slice_u32_to_u8(slice: &[u32]) -> &[u8] { 278 | unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len() * 4) } 279 | } 280 | 281 | /////////////////////////////////////////////////////////////////////////////// 282 | /// Encoding transformations 283 | /////////////////////////////////////////////////////////////////////////////// 284 | 285 | #[derive(NifTaggedEnum)] 286 | pub enum TransformationElement { 287 | Pad((usize, Vec)), // {:pad, {target_length, opts}} 288 | Truncate((usize, Vec)), // {:truncate, {max_len, opts}} 289 | SetSequenceId(usize), // {:set_sequence_id, seq_id} 290 | } 291 | 292 | #[rustler::nif] 293 | pub fn encoding_transform( 294 | encoding: ExTokenizersEncoding, 295 | transformations: Vec, 296 | ) -> ExTokenizersEncoding { 297 | let mut encoding = encoding.resource.0.clone(); 298 | apply_transformations(&mut encoding, &transformations); 299 | encoding.into() 300 | } 301 | 302 | pub fn apply_transformations( 303 | encoding: &mut Encoding, 304 | transformations: &Vec, 305 | ) { 306 | for transformation in transformations { 307 | match transformation { 308 | TransformationElement::Pad((target_length, opts)) => { 309 | let default = parse_pad_options(opts); 310 | 311 | encoding.pad( 312 | *target_length, 313 | default.pad_id, 314 | default.pad_type_id, 315 | &default.pad_token, 316 | default.direction.into(), 317 | ) 318 | } 319 | TransformationElement::Truncate((max_len, opts)) => { 320 | let default = parse_truncation_options(opts); 321 | encoding.truncate(*max_len, default.stride, default.direction.into()) 322 | } 323 | TransformationElement::SetSequenceId(seq_id) => encoding.set_sequence_id(*seq_id), 324 | } 325 | } 326 | } 327 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/error.rs: -------------------------------------------------------------------------------- 1 | use rustler::{Encoder, Env, Term}; 2 | use std::{io, panic::RefUnwindSafe}; 3 | use thiserror::Error; 4 | 5 | rustler::atoms! 
{ 6 | ok, 7 | error 8 | } 9 | 10 | #[derive(Error, Debug)] 11 | pub enum ExTokenizersError { 12 | #[error("Invalid Char")] 13 | InvalidChar, 14 | #[error("Tokenizer Error")] 15 | Tokenizer(#[from] tokenizers::Error), 16 | #[error("IO Error")] 17 | Io(#[from] io::Error), 18 | #[error("Internal Error: {0}")] 19 | Internal(String), 20 | #[error("Other error: {0}")] 21 | Other(String), 22 | #[error(transparent)] 23 | Unknown(#[from] anyhow::Error), 24 | } 25 | 26 | impl Encoder for ExTokenizersError { 27 | fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { 28 | format!("{self:?}").encode(env) 29 | } 30 | } 31 | 32 | impl RefUnwindSafe for ExTokenizersError {} 33 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod added_token; 2 | mod decode_stream; 3 | mod decoders; 4 | mod encoding; 5 | mod error; 6 | mod models; 7 | mod normalizers; 8 | mod post_processors; 9 | mod pre_tokenizers; 10 | mod tokenizer; 11 | mod trainers; 12 | mod util; 13 | 14 | use rustler::{Env, Term}; 15 | 16 | pub use error::ExTokenizersError; 17 | 18 | fn on_load(_env: Env, _info: Term) -> bool { 19 | true 20 | } 21 | 22 | rustler::init!("Elixir.Tokenizers.Native", load = on_load); 23 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/normalizers.rs: -------------------------------------------------------------------------------- 1 | use crate::{new_info, util::Info, ExTokenizersError}; 2 | use rustler::NifTaggedEnum; 3 | use serde::{Deserialize, Serialize}; 4 | use tokenizers::{ 5 | normalizers::{replace::ReplacePattern, ByteLevel}, 6 | NormalizedString, Normalizer, NormalizerWrapper, 7 | }; 8 | 9 | pub struct ExTokenizersNormalizerRef(pub NormalizerWrapper); 10 | 11 | #[rustler::resource_impl] 12 | impl rustler::Resource for ExTokenizersNormalizerRef {} 13 | 14 | #[derive(rustler::NifStruct)] 15 | #[module = "Tokenizers.Normalizer"] 16 | pub struct ExTokenizersNormalizer { 17 | pub resource: rustler::ResourceArc, 18 | } 19 | 20 | impl Serialize for ExTokenizersNormalizer { 21 | fn serialize(&self, serializer: S) -> Result 22 | where 23 | S: serde::Serializer, 24 | { 25 | self.resource.0.serialize(serializer) 26 | } 27 | } 28 | 29 | impl<'de> Deserialize<'de> for ExTokenizersNormalizer { 30 | fn deserialize(deserializer: D) -> Result 31 | where 32 | D: serde::Deserializer<'de>, 33 | { 34 | Ok(ExTokenizersNormalizer::new(NormalizerWrapper::deserialize( 35 | deserializer, 36 | )?)) 37 | } 38 | } 39 | 40 | impl ExTokenizersNormalizerRef { 41 | pub fn new(data: T) -> Self 42 | where 43 | T: Into, 44 | { 45 | Self(data.into()) 46 | } 47 | } 48 | 49 | impl Clone for ExTokenizersNormalizer { 50 | fn clone(&self) -> Self { 51 | Self { 52 | resource: self.resource.clone(), 53 | } 54 | } 55 | } 56 | 57 | impl ExTokenizersNormalizer { 58 | pub fn new(data: T) -> Self 59 | where 60 | T: Into, 61 | { 62 | Self { 63 | resource: rustler::ResourceArc::new(ExTokenizersNormalizerRef::new(data)), 64 | } 65 | } 66 | } 67 | 68 | impl tokenizers::Normalizer for ExTokenizersNormalizer { 69 | fn normalize(&self, normalized: &mut NormalizedString) -> tokenizers::Result<()> { 70 | self.resource.0.normalize(normalized) 71 | } 72 | } 73 | 74 | #[rustler::nif(schedule = "DirtyCpu")] 75 | fn normalizers_normalize( 76 | normalizer: ExTokenizersNormalizer, 77 | input: String, 78 | ) -> Result { 79 | let mut normalized = 
NormalizedString::from(input); 80 | normalizer.resource.0.normalize(&mut normalized)?; 81 | Ok(normalized.get().to_owned()) 82 | } 83 | 84 | // ///////////////////////////////////////////////////////////////////////////// 85 | // / Inspection 86 | // ///////////////////////////////////////////////////////////////////////////// 87 | 88 | #[rustler::nif] 89 | fn normalizers_info(normalizer: ExTokenizersNormalizer) -> Info { 90 | match normalizer.resource.0 { 91 | NormalizerWrapper::BertNormalizer(_) => new_info!( 92 | normalizer_type: "BertNormalizer" 93 | ), 94 | NormalizerWrapper::StripNormalizer(_) => new_info!( 95 | normalizer_type: "StripNormalizer" 96 | ), 97 | NormalizerWrapper::StripAccents(_) => new_info!( 98 | normalizer_type: "StripAccents" 99 | ), 100 | NormalizerWrapper::NFC(_) => new_info!( 101 | normalizer_type: "NFC" 102 | ), 103 | NormalizerWrapper::NFD(_) => new_info!( 104 | normalizer_type: "NFD" 105 | ), 106 | NormalizerWrapper::NFKC(_) => new_info!( 107 | normalizer_type: "NFKC" 108 | ), 109 | NormalizerWrapper::NFKD(_) => new_info!( 110 | normalizer_type: "NFKD" 111 | ), 112 | NormalizerWrapper::Sequence(_) => new_info!( 113 | normalizer_type: "Sequence" 114 | ), 115 | NormalizerWrapper::Lowercase(_) => new_info!( 116 | normalizer_type: "Lowercase" 117 | ), 118 | NormalizerWrapper::Nmt(_) => new_info!( 119 | normalizer_type: "Nmt" 120 | ), 121 | NormalizerWrapper::Precompiled(_) => new_info!( 122 | normalizer_type: "Precompiled" 123 | ), 124 | NormalizerWrapper::Replace(_) => new_info!( 125 | normalizer_type: "Replace" 126 | ), 127 | NormalizerWrapper::Prepend(_) => new_info!( 128 | normalizer_type: "Prepend" 129 | ), 130 | NormalizerWrapper::ByteLevel(_) => new_info!( 131 | normalizer_type: "ByteLevel" 132 | ), 133 | } 134 | } 135 | 136 | // ///////////////////////////////////////////////////////////////////////////// 137 | // / Implementation 138 | // ///////////////////////////////////////////////////////////////////////////// 139 | 140 | #[derive(NifTaggedEnum)] 141 | pub enum BertOption { 142 | CleanText(bool), 143 | HandleChineseChars(bool), 144 | StripAccents(bool), 145 | Lowercase(bool), 146 | } 147 | 148 | #[rustler::nif] 149 | pub fn normalizers_bert_normalizer(options: Vec) -> ExTokenizersNormalizer { 150 | struct Opts { 151 | clean_text: bool, 152 | handle_chinese_chars: bool, 153 | strip_accents: Option, 154 | lowercase: bool, 155 | } 156 | 157 | // Default values 158 | let mut opts = Opts { 159 | clean_text: true, 160 | handle_chinese_chars: true, 161 | strip_accents: None, 162 | lowercase: true, 163 | }; 164 | options.iter().for_each(|option| match option { 165 | BertOption::CleanText(val) => opts.clean_text = *val, 166 | BertOption::HandleChineseChars(val) => opts.handle_chinese_chars = *val, 167 | BertOption::StripAccents(val) => opts.strip_accents = Some(*val), 168 | BertOption::Lowercase(val) => opts.lowercase = *val, 169 | }); 170 | 171 | ExTokenizersNormalizer::new(tokenizers::normalizers::BertNormalizer::new( 172 | opts.clean_text, 173 | opts.handle_chinese_chars, 174 | opts.strip_accents, 175 | opts.lowercase, 176 | )) 177 | } 178 | 179 | #[rustler::nif] 180 | pub fn normalizers_nfd() -> ExTokenizersNormalizer { 181 | ExTokenizersNormalizer::new(tokenizers::normalizers::unicode::NFD) 182 | } 183 | 184 | #[rustler::nif] 185 | pub fn normalizers_nfkd() -> ExTokenizersNormalizer { 186 | ExTokenizersNormalizer::new(tokenizers::normalizers::unicode::NFKD) 187 | } 188 | 189 | #[rustler::nif] 190 | pub fn normalizers_nfc() -> ExTokenizersNormalizer { 191 
| ExTokenizersNormalizer::new(tokenizers::normalizers::unicode::NFC) 192 | } 193 | 194 | #[rustler::nif] 195 | pub fn normalizers_nfkc() -> ExTokenizersNormalizer { 196 | ExTokenizersNormalizer::new(tokenizers::normalizers::unicode::NFKC) 197 | } 198 | 199 | #[derive(NifTaggedEnum)] 200 | pub enum StripOption { 201 | Left(bool), 202 | Right(bool), 203 | } 204 | 205 | #[rustler::nif] 206 | pub fn normalizers_strip(options: Vec) -> ExTokenizersNormalizer { 207 | struct Opts { 208 | left: bool, 209 | right: bool, 210 | } 211 | 212 | // Default values 213 | let mut opts = Opts { 214 | left: true, 215 | right: true, 216 | }; 217 | options.iter().for_each(|option| match option { 218 | StripOption::Left(val) => opts.left = *val, 219 | StripOption::Right(val) => opts.right = *val, 220 | }); 221 | 222 | ExTokenizersNormalizer::new(tokenizers::normalizers::strip::Strip::new( 223 | opts.left, opts.right, 224 | )) 225 | } 226 | 227 | #[rustler::nif] 228 | pub fn normalizers_prepend(prepend: String) -> ExTokenizersNormalizer { 229 | ExTokenizersNormalizer::new(tokenizers::normalizers::prepend::Prepend::new(prepend)) 230 | } 231 | 232 | #[rustler::nif] 233 | pub fn normalizers_strip_accents() -> ExTokenizersNormalizer { 234 | ExTokenizersNormalizer::new(tokenizers::normalizers::strip::StripAccents) 235 | } 236 | 237 | #[rustler::nif] 238 | pub fn normalizers_sequence(normalizers: Vec) -> ExTokenizersNormalizer { 239 | // Fairly saying, normalizer is immutable, but we are still using `arc` 240 | // to point already created normalizer instead of clonning it. 241 | let seq: Vec = normalizers 242 | .iter() 243 | .map(|normalizer| normalizer.resource.0.clone()) 244 | .collect(); 245 | ExTokenizersNormalizer::new(tokenizers::normalizers::Sequence::new(seq)) 246 | } 247 | 248 | #[rustler::nif] 249 | pub fn normalizers_lowercase() -> ExTokenizersNormalizer { 250 | ExTokenizersNormalizer::new(tokenizers::normalizers::utils::Lowercase) 251 | } 252 | 253 | #[derive(NifTaggedEnum)] 254 | pub enum LocalReplacePattern { 255 | String(String), 256 | Regex(String), 257 | } 258 | 259 | #[rustler::nif] 260 | pub fn normalizers_replace( 261 | pattern: LocalReplacePattern, 262 | content: String, 263 | ) -> Result { 264 | let final_pattern = match pattern { 265 | LocalReplacePattern::String(pattern) => ReplacePattern::String(pattern), 266 | LocalReplacePattern::Regex(pattern) => ReplacePattern::Regex(pattern), 267 | }; 268 | 269 | Ok(ExTokenizersNormalizer::new( 270 | tokenizers::normalizers::replace::Replace::new(final_pattern, content) 271 | .map_err(|_| rustler::Error::BadArg)?, 272 | )) 273 | } 274 | 275 | #[rustler::nif] 276 | pub fn normalizers_nmt() -> ExTokenizersNormalizer { 277 | ExTokenizersNormalizer::new(tokenizers::normalizers::unicode::Nmt) 278 | } 279 | 280 | #[rustler::nif] 281 | pub fn normalizers_precompiled(data: Vec) -> Result { 282 | Ok(ExTokenizersNormalizer::new( 283 | tokenizers::normalizers::precompiled::Precompiled::from(&data) 284 | .map_err(anyhow::Error::from)?, 285 | )) 286 | } 287 | 288 | // ByteLevel part 289 | 290 | #[rustler::nif] 291 | pub fn normalizers_byte_level() -> ExTokenizersNormalizer { 292 | ExTokenizersNormalizer::new(tokenizers::normalizers::byte_level::ByteLevel) 293 | } 294 | 295 | #[rustler::nif] 296 | pub fn normalizers_byte_level_alphabet() -> Vec { 297 | ByteLevel::alphabet() 298 | .iter() 299 | .map(|c| String::from(*c)) 300 | .collect() 301 | } 302 | -------------------------------------------------------------------------------- 
/native/ex_tokenizers/src/post_processors.rs: -------------------------------------------------------------------------------- 1 | use rustler::NifTaggedEnum; 2 | use serde::{Deserialize, Serialize}; 3 | use tokenizers::{Encoding, PostProcessorWrapper}; 4 | 5 | use crate::{new_info, util::Info}; 6 | 7 | pub struct ExTokenizersPostProcessorRef(pub PostProcessorWrapper); 8 | 9 | #[rustler::resource_impl] 10 | impl rustler::Resource for ExTokenizersPostProcessorRef {} 11 | 12 | #[derive(rustler::NifStruct)] 13 | #[module = "Tokenizers.PostProcessor"] 14 | pub struct ExTokenizersPostProcessor { 15 | pub resource: rustler::ResourceArc, 16 | } 17 | 18 | impl ExTokenizersPostProcessorRef { 19 | pub fn new(data: T) -> Self 20 | where 21 | T: Into, 22 | { 23 | Self(data.into()) 24 | } 25 | } 26 | 27 | impl ExTokenizersPostProcessor { 28 | pub fn new(data: T) -> Self 29 | where 30 | T: Into, 31 | { 32 | Self { 33 | resource: rustler::ResourceArc::new(ExTokenizersPostProcessorRef::new(data)), 34 | } 35 | } 36 | } 37 | 38 | impl tokenizers::PostProcessor for ExTokenizersPostProcessor { 39 | fn added_tokens(&self, is_pair: bool) -> usize { 40 | self.resource.0.added_tokens(is_pair) 41 | } 42 | 43 | fn process_encodings( 44 | &self, 45 | encodings: Vec, 46 | add_special_tokens: bool, 47 | ) -> tokenizers::Result> { 48 | self.resource 49 | .0 50 | .process_encodings(encodings, add_special_tokens) 51 | } 52 | } 53 | 54 | impl Serialize for ExTokenizersPostProcessor { 55 | fn serialize(&self, serializer: S) -> Result 56 | where 57 | S: serde::Serializer, 58 | { 59 | self.resource.0.serialize(serializer) 60 | } 61 | } 62 | 63 | impl<'de> Deserialize<'de> for ExTokenizersPostProcessor { 64 | fn deserialize(deserializer: D) -> Result 65 | where 66 | D: serde::Deserializer<'de>, 67 | { 68 | Ok(ExTokenizersPostProcessor::new( 69 | PostProcessorWrapper::deserialize(deserializer)?, 70 | )) 71 | } 72 | } 73 | 74 | impl Clone for ExTokenizersPostProcessor { 75 | fn clone(&self) -> Self { 76 | Self { 77 | resource: self.resource.clone(), 78 | } 79 | } 80 | } 81 | 82 | type ProcessorPair = (String, u32); 83 | 84 | // ///////////////////////////////////////////////////////////////////////////// 85 | // / Inspection 86 | // ///////////////////////////////////////////////////////////////////////////// 87 | #[rustler::nif] 88 | fn post_processors_info(post_processor: ExTokenizersPostProcessor) -> Info { 89 | match &post_processor.resource.0 { 90 | PostProcessorWrapper::Roberta(_) => new_info![post_processor_type: "roberta"], 91 | PostProcessorWrapper::Bert(_) => new_info![post_processor_type: "bert"], 92 | PostProcessorWrapper::ByteLevel(_) => new_info![post_processor_type: "byte_level"], 93 | PostProcessorWrapper::Template(_) => new_info![post_processor_type: "template"], 94 | PostProcessorWrapper::Sequence(_) => new_info![post_processor_type: "sequence"], 95 | } 96 | } 97 | 98 | // ///////////////////////////////////////////////////////////////////////////// 99 | // / Implementation 100 | // ///////////////////////////////////////////////////////////////////////////// 101 | #[rustler::nif] 102 | pub fn post_processors_bert(sep: ProcessorPair, cls: ProcessorPair) -> ExTokenizersPostProcessor { 103 | ExTokenizersPostProcessor::new(tokenizers::processors::bert::BertProcessing::new(sep, cls)) 104 | } 105 | 106 | #[derive(NifTaggedEnum)] 107 | pub enum RobertaOption { 108 | TrimOffsets(bool), 109 | AddPrefixSpace(bool), 110 | } 111 | 112 | #[rustler::nif] 113 | pub fn post_processors_roberta( 114 | sep: ProcessorPair, 115 | 
cls: ProcessorPair, 116 | opts: Vec, 117 | ) -> ExTokenizersPostProcessor { 118 | let mut proc = tokenizers::processors::roberta::RobertaProcessing::new(sep, cls); 119 | for opt in opts { 120 | match opt { 121 | RobertaOption::TrimOffsets(v) => proc = proc.trim_offsets(v), 122 | RobertaOption::AddPrefixSpace(v) => proc = proc.add_prefix_space(v), 123 | } 124 | } 125 | ExTokenizersPostProcessor::new(proc) 126 | } 127 | 128 | #[derive(NifTaggedEnum)] 129 | pub enum ByteLevelOption { 130 | TrimOffsets(bool), 131 | } 132 | 133 | #[rustler::nif] 134 | pub fn post_processors_byte_level(opts: Vec) -> ExTokenizersPostProcessor { 135 | let mut proc = tokenizers::processors::byte_level::ByteLevel::default(); 136 | for opt in opts { 137 | match opt { 138 | ByteLevelOption::TrimOffsets(v) => proc = proc.trim_offsets(v), 139 | } 140 | } 141 | ExTokenizersPostProcessor::new(proc) 142 | } 143 | 144 | #[derive(NifTaggedEnum)] 145 | pub enum TemplateOption { 146 | Single(String), 147 | Pair(String), 148 | SpecialTokens(Vec<(String, u32)>), 149 | } 150 | 151 | #[rustler::nif] 152 | pub fn post_processors_template( 153 | opts: Vec, 154 | ) -> Result { 155 | let mut builder = tokenizers::processors::template::TemplateProcessing::builder(); 156 | for opt in opts { 157 | match opt { 158 | TemplateOption::Single(v) => { 159 | builder.try_single(v).map_err(|_| rustler::Error::BadArg)? 160 | } 161 | TemplateOption::Pair(v) => builder.try_pair(v).map_err(|_| rustler::Error::BadArg)?, 162 | TemplateOption::SpecialTokens(v) => builder.special_tokens(v), 163 | }; 164 | } 165 | Ok(ExTokenizersPostProcessor::new( 166 | builder.build().map_err(|_| rustler::Error::BadArg)?, 167 | )) 168 | } 169 | 170 | #[rustler::nif] 171 | pub fn post_processors_sequence( 172 | post_processors: Vec, 173 | ) -> ExTokenizersPostProcessor { 174 | ExTokenizersPostProcessor::new(tokenizers::processors::sequence::Sequence::new( 175 | post_processors 176 | .iter() 177 | .map(|pp| pp.resource.0.clone()) 178 | .collect(), 179 | )) 180 | } 181 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/pre_tokenizers.rs: -------------------------------------------------------------------------------- 1 | use crate::util::Info; 2 | use crate::{new_info, ExTokenizersError}; 3 | use rustler::NifTaggedEnum; 4 | use serde::{Deserialize, Serialize}; 5 | use tokenizers::pre_tokenizers::split::SplitPattern; 6 | use tokenizers::PreTokenizer; 7 | use tokenizers::{processors::byte_level::ByteLevel, PreTokenizedString, PreTokenizerWrapper}; 8 | 9 | pub struct ExTokenizersPreTokenizerRef(pub PreTokenizerWrapper); 10 | 11 | #[rustler::resource_impl] 12 | impl rustler::Resource for ExTokenizersPreTokenizerRef {} 13 | 14 | #[derive(rustler::NifStruct)] 15 | #[module = "Tokenizers.PreTokenizer"] 16 | pub struct ExTokenizersPreTokenizer { 17 | pub resource: rustler::ResourceArc, 18 | } 19 | 20 | impl Serialize for ExTokenizersPreTokenizer { 21 | fn serialize(&self, serializer: S) -> Result 22 | where 23 | S: serde::Serializer, 24 | { 25 | self.resource.0.serialize(serializer) 26 | } 27 | } 28 | 29 | impl<'de> Deserialize<'de> for ExTokenizersPreTokenizer { 30 | fn deserialize(deserializer: D) -> Result 31 | where 32 | D: serde::Deserializer<'de>, 33 | { 34 | Ok(ExTokenizersPreTokenizer::new( 35 | PreTokenizerWrapper::deserialize(deserializer)?, 36 | )) 37 | } 38 | } 39 | 40 | impl tokenizers::PreTokenizer for ExTokenizersPreTokenizer { 41 | fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> 
tokenizers::Result<()> { 42 | self.resource.0.pre_tokenize(pretokenized) 43 | } 44 | } 45 | 46 | impl Clone for ExTokenizersPreTokenizer { 47 | fn clone(&self) -> Self { 48 | Self { 49 | resource: self.resource.clone(), 50 | } 51 | } 52 | } 53 | 54 | impl ExTokenizersPreTokenizerRef { 55 | pub fn new(data: T) -> Self 56 | where 57 | T: Into, 58 | { 59 | Self(data.into()) 60 | } 61 | } 62 | 63 | impl ExTokenizersPreTokenizer { 64 | pub fn new(data: T) -> Self 65 | where 66 | T: Into, 67 | { 68 | Self { 69 | resource: rustler::ResourceArc::new(ExTokenizersPreTokenizerRef::new(data)), 70 | } 71 | } 72 | } 73 | 74 | #[rustler::nif] 75 | #[allow(clippy::type_complexity)] 76 | pub fn pre_tokenizers_pre_tokenize( 77 | pre_tokenizer: ExTokenizersPreTokenizer, 78 | sequence: String, 79 | ) -> Result, ExTokenizersError> { 80 | let mut pretokenized = PreTokenizedString::from(sequence); 81 | 82 | pre_tokenizer.pre_tokenize(&mut pretokenized)?; 83 | let splits: Vec<(String, (usize, usize))> = pretokenized 84 | .get_splits( 85 | tokenizers::OffsetReferential::Original, 86 | tokenizers::OffsetType::Char, 87 | ) 88 | .into_iter() 89 | .map(|(s, o, _)| (s.to_owned(), o)) 90 | .collect(); 91 | Ok(splits) 92 | } 93 | 94 | // ///////////////////////////////////////////////////////////////////////////// 95 | // / Inspection 96 | // ///////////////////////////////////////////////////////////////////////////// 97 | 98 | #[rustler::nif] 99 | fn pre_tokenizers_info(pre_tokenizer: ExTokenizersPreTokenizer) -> Info { 100 | match pre_tokenizer.resource.0 { 101 | PreTokenizerWrapper::BertPreTokenizer(_) => new_info!( 102 | pre_tokenizer_type: "BertPreTokenizer" 103 | ), 104 | PreTokenizerWrapper::ByteLevel(_) => new_info!( 105 | pre_tokenizer_type: "ByteLevel" 106 | ), 107 | PreTokenizerWrapper::Delimiter(_) => new_info!( 108 | pre_tokenizer_type: "Delimiter" 109 | ), 110 | PreTokenizerWrapper::Metaspace(_) => new_info!( 111 | pre_tokenizer_type: "Metaspace" 112 | ), 113 | PreTokenizerWrapper::Whitespace(_) => new_info!( 114 | pre_tokenizer_type: "Whitespace" 115 | ), 116 | PreTokenizerWrapper::Sequence(_) => new_info!( 117 | pre_tokenizer_type: "Sequence" 118 | ), 119 | PreTokenizerWrapper::Split(_) => new_info!( 120 | pre_tokenizer_type: "Split" 121 | ), 122 | PreTokenizerWrapper::Punctuation(_) => new_info!( 123 | pre_tokenizer_type: "Punctuation" 124 | ), 125 | PreTokenizerWrapper::WhitespaceSplit(_) => new_info!( 126 | pre_tokenizer_type: "WhitespaceSplit" 127 | ), 128 | PreTokenizerWrapper::Digits(_) => new_info!( 129 | pre_tokenizer_type: "Digits" 130 | ), 131 | PreTokenizerWrapper::UnicodeScripts(_) => new_info!( 132 | pre_tokenizer_type: "UnicodeScripts" 133 | ), 134 | } 135 | } 136 | 137 | // ///////////////////////////////////////////////////////////////////////////// 138 | // / Implementation 139 | // ///////////////////////////////////////////////////////////////////////////// 140 | 141 | #[derive(NifTaggedEnum)] 142 | pub enum ByteLevelOption { 143 | AddPrefixSpace(bool), 144 | UseRegex(bool), 145 | } 146 | 147 | #[rustler::nif] 148 | pub fn pre_tokenizers_byte_level(options: Vec) -> ExTokenizersPreTokenizer { 149 | let mut byte_level: ByteLevel = tokenizers::pre_tokenizers::byte_level::ByteLevel::default(); 150 | for option in options { 151 | match option { 152 | ByteLevelOption::AddPrefixSpace(add_prefix_space) => { 153 | byte_level = byte_level.add_prefix_space(add_prefix_space) 154 | } 155 | ByteLevelOption::UseRegex(use_regex) => byte_level = byte_level.use_regex(use_regex), 156 | }; 157 | } 158 | 
ExTokenizersPreTokenizer::new(byte_level) 159 | } 160 | 161 | #[rustler::nif] 162 | pub fn pre_tokenizers_byte_level_alphabet() -> Vec { 163 | tokenizers::pre_tokenizers::byte_level::ByteLevel::alphabet() 164 | .into_iter() 165 | .map(|c| c as u32) 166 | .collect::>() 167 | } 168 | 169 | #[rustler::nif] 170 | pub fn pre_tokenizers_whitespace() -> ExTokenizersPreTokenizer { 171 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::whitespace::Whitespace) 172 | } 173 | 174 | #[rustler::nif] 175 | pub fn pre_tokenizers_whitespace_split() -> ExTokenizersPreTokenizer { 176 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::whitespace::WhitespaceSplit) 177 | } 178 | 179 | #[rustler::nif] 180 | pub fn pre_tokenizers_bert() -> ExTokenizersPreTokenizer { 181 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::bert::BertPreTokenizer) 182 | } 183 | 184 | #[derive(NifTaggedEnum)] 185 | pub enum MetaspaceOption { 186 | Replacement(u32), 187 | PrependScheme(PrependScheme), 188 | } 189 | 190 | #[derive(NifTaggedEnum)] 191 | pub enum PrependScheme { 192 | First, 193 | Never, 194 | Always, 195 | } 196 | 197 | #[rustler::nif] 198 | pub fn pre_tokenizers_metaspace( 199 | options: Vec, 200 | ) -> Result { 201 | let mut metaspace = tokenizers::pre_tokenizers::metaspace::Metaspace::default(); 202 | for option in options { 203 | match option { 204 | MetaspaceOption::Replacement(replacement) => metaspace 205 | .set_replacement(std::char::from_u32(replacement).ok_or(rustler::Error::BadArg)?), 206 | MetaspaceOption::PrependScheme(prepend_scheme) => { 207 | metaspace.prepend_scheme = match prepend_scheme { 208 | PrependScheme::First => { 209 | tokenizers::pre_tokenizers::metaspace::PrependScheme::First 210 | } 211 | PrependScheme::Never => { 212 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Never 213 | } 214 | PrependScheme::Always => { 215 | tokenizers::pre_tokenizers::metaspace::PrependScheme::Always 216 | } 217 | } 218 | } 219 | } 220 | } 221 | Ok(ExTokenizersPreTokenizer::new(metaspace)) 222 | } 223 | 224 | #[rustler::nif] 225 | pub fn pre_tokenizers_char_delimiter_split( 226 | delimiter: u32, 227 | ) -> Result { 228 | Ok(ExTokenizersPreTokenizer::new( 229 | tokenizers::pre_tokenizers::delimiter::CharDelimiterSplit::new( 230 | std::char::from_u32(delimiter).ok_or(rustler::Error::BadArg)?, 231 | ), 232 | )) 233 | } 234 | 235 | #[derive(rustler::NifUnitEnum)] 236 | pub enum SplitDelimiterBehavior { 237 | Removed, 238 | Isolated, 239 | MergedWithPrevious, 240 | MergedWithNext, 241 | Contiguous, 242 | } 243 | 244 | impl From for tokenizers::SplitDelimiterBehavior { 245 | fn from(value: SplitDelimiterBehavior) -> Self { 246 | match value { 247 | SplitDelimiterBehavior::Removed => tokenizers::SplitDelimiterBehavior::Removed, 248 | SplitDelimiterBehavior::Isolated => tokenizers::SplitDelimiterBehavior::Isolated, 249 | SplitDelimiterBehavior::MergedWithPrevious => { 250 | tokenizers::SplitDelimiterBehavior::MergedWithPrevious 251 | } 252 | SplitDelimiterBehavior::MergedWithNext => { 253 | tokenizers::SplitDelimiterBehavior::MergedWithNext 254 | } 255 | SplitDelimiterBehavior::Contiguous => tokenizers::SplitDelimiterBehavior::Contiguous, 256 | } 257 | } 258 | } 259 | 260 | #[derive(NifTaggedEnum)] 261 | pub enum SplitOption { 262 | Invert(bool), 263 | } 264 | 265 | #[derive(NifTaggedEnum)] 266 | pub enum LocalSplitPattern { 267 | String(String), 268 | Regex(String), 269 | } 270 | 271 | #[rustler::nif] 272 | pub fn pre_tokenizers_split( 273 | pattern: LocalSplitPattern, 274 | behavior: 
SplitDelimiterBehavior, 275 | options: Vec, 276 | ) -> Result { 277 | struct Opts { 278 | invert: bool, 279 | } 280 | let mut opts = Opts { invert: false }; 281 | let final_pattern = match pattern { 282 | LocalSplitPattern::String(pattern) => SplitPattern::String(pattern), 283 | LocalSplitPattern::Regex(pattern) => SplitPattern::Regex(pattern), 284 | }; 285 | 286 | for option in options { 287 | match option { 288 | SplitOption::Invert(invert) => opts.invert = invert, 289 | } 290 | } 291 | 292 | Ok(ExTokenizersPreTokenizer::new( 293 | tokenizers::pre_tokenizers::split::Split::new(final_pattern, behavior.into(), opts.invert) 294 | .map_err(|_| rustler::Error::BadArg)?, 295 | )) 296 | } 297 | 298 | #[rustler::nif] 299 | pub fn pre_tokenizers_punctuation(behavior: SplitDelimiterBehavior) -> ExTokenizersPreTokenizer { 300 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::punctuation::Punctuation::new( 301 | behavior.into(), 302 | )) 303 | } 304 | 305 | #[rustler::nif] 306 | pub fn pre_tokenizers_sequence( 307 | pretokenizers: Vec, 308 | ) -> ExTokenizersPreTokenizer { 309 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::sequence::Sequence::new( 310 | pretokenizers 311 | .iter() 312 | .map(|pretokenizer| pretokenizer.resource.0.clone()) 313 | .collect(), 314 | )) 315 | } 316 | 317 | #[derive(NifTaggedEnum)] 318 | pub enum DigitsOption { 319 | IndividualDigits(bool), 320 | } 321 | 322 | #[rustler::nif] 323 | pub fn pre_tokenizers_digits(options: Vec) -> ExTokenizersPreTokenizer { 324 | struct Opts { 325 | individual_digits: bool, 326 | } 327 | let mut opts = Opts { 328 | individual_digits: false, 329 | }; 330 | 331 | for option in options { 332 | match option { 333 | DigitsOption::IndividualDigits(individual_digits) => { 334 | opts.individual_digits = individual_digits 335 | } 336 | }; 337 | } 338 | 339 | ExTokenizersPreTokenizer::new(tokenizers::pre_tokenizers::digits::Digits::new( 340 | opts.individual_digits, 341 | )) 342 | } 343 | -------------------------------------------------------------------------------- /native/ex_tokenizers/src/util.rs: -------------------------------------------------------------------------------- 1 | use std::panic::RefUnwindSafe; 2 | 3 | use rustler::Encoder; 4 | use tokenizers::{PaddingDirection, TruncationDirection}; 5 | 6 | #[macro_export] 7 | macro_rules! 
new_info { 8 | [$($a:ident : $b:expr),*] => {{ 9 | let vec: Vec<(Box, Box)> = vec![$((Box::new(stringify!($a)), Box::new($b)),)*]; 10 | Info(vec) 11 | }} 12 | } 13 | 14 | pub struct Info(pub Vec<(Box, Box)>); 15 | impl RefUnwindSafe for Info {} 16 | 17 | impl rustler::Encoder for Info { 18 | fn encode<'a>(&self, env: rustler::Env<'a>) -> rustler::Term<'a> { 19 | rustler::Term::map_from_pairs( 20 | env, 21 | &self 22 | .0 23 | .iter() 24 | .map(|(k, v)| (k.encode(env), v.encode(env))) 25 | .collect::>(), 26 | ) 27 | .unwrap() 28 | } 29 | } 30 | 31 | #[derive(rustler::NifUnitEnum, Clone)] 32 | pub enum Direction { 33 | Left, 34 | Right, 35 | } 36 | 37 | impl From for PaddingDirection { 38 | fn from(val: Direction) -> Self { 39 | match val { 40 | Direction::Left => PaddingDirection::Left, 41 | Direction::Right => PaddingDirection::Right, 42 | } 43 | } 44 | } 45 | 46 | impl From<&Direction> for PaddingDirection { 47 | fn from(val: &Direction) -> Self { 48 | match val { 49 | Direction::Left => PaddingDirection::Left, 50 | Direction::Right => PaddingDirection::Right, 51 | } 52 | } 53 | } 54 | 55 | impl From for TruncationDirection { 56 | fn from(val: Direction) -> Self { 57 | match val { 58 | Direction::Left => TruncationDirection::Left, 59 | Direction::Right => TruncationDirection::Right, 60 | } 61 | } 62 | } 63 | 64 | impl From<&Direction> for TruncationDirection { 65 | fn from(val: &Direction) -> Self { 66 | match val { 67 | Direction::Left => TruncationDirection::Left, 68 | Direction::Right => TruncationDirection::Right, 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /notebooks/pretrained.livemd: -------------------------------------------------------------------------------- 1 | # Pretrained tokenizers 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:kino, "~> 0.10.0"}, 6 | {:scidata, "~> 0.1.5"}, 7 | {:tokenizers, "~> 0.4.0"}, 8 | {:nx, "~> 0.5"} 9 | ]) 10 | ``` 11 | 12 | ## Setup 13 | 14 | This Livebook will demonstrate how to use `Tokenizers` with pretrained tokenizers available on the [Hugging Face Hub](https://huggingface.co/models). 15 | 16 | We'll install `Kino` for user input and `SciData` for real data to tokenize. 17 | 18 | Check **_Notebook dependencies and setup_** section at the beginning of this notebook 19 | 20 | 21 | 22 | We'll alias modules in `Tokenizers` for readability. For now, the two main entry points into `Tokenizers` are the `Tokenizer` and `Encoding` modules. 23 | 24 | ```elixir 25 | alias Tokenizers.Tokenizer 26 | alias Tokenizers.Encoding 27 | ``` 28 | 29 | ## Get a tokenizer 30 | 31 | The first thing to do is get a tokenizer from the hub. I've chosen `bert-base-cased` here as it's commonly used in Hugging Face examples. This call will download the tokenizer from the hub and load it into memory. 32 | 33 | ```elixir 34 | {:ok, tokenizer} = Tokenizer.from_pretrained("bert-base-cased") 35 | ``` 36 | 37 | 38 | 39 | ## Save and load 40 | 41 | You can save and load models. That means you can load in tokenizers you may have trained locally! 42 | 43 | You can choose the path with the Kino input below. 44 | 45 | ```elixir 46 | input = Kino.Input.text("Path") 47 | ``` 48 | 49 | ```elixir 50 | path = Kino.Input.read(input) 51 | Tokenizer.save(tokenizer, path) 52 | ``` 53 | 54 | ```elixir 55 | {:ok, tokenizer} = Tokenizer.from_file(path) 56 | ``` 57 | 58 | ## Check the tokenizer 59 | 60 | Let's see what we can do with the tokenizer. First, let's have a look at the vocab. It's represented as a map of tokens to ids. 
61 | 62 | ```elixir 63 | vocab = Tokenizer.get_vocab(tokenizer) 64 | ``` 65 | 66 | We can access an id using the vocab, but we don't need to extract the vocab. `Tokenizer.token_to_id/2` does the job for us. 67 | 68 | ```elixir 69 | vocab["Jaguar"] 70 | ``` 71 | 72 | ```elixir 73 | Tokenizer.token_to_id(tokenizer, "Jaguar") 74 | ``` 75 | 76 | And if we want to go back the other way... 77 | 78 | ```elixir 79 | Tokenizer.id_to_token(tokenizer, 21694) 80 | ``` 81 | 82 | We can also see the vocab size. 83 | 84 | ```elixir 85 | Tokenizer.get_vocab_size(tokenizer) 86 | ``` 87 | 88 | ## Encode and decode 89 | 90 | When you tokenize some text you get an encoding. This is represented as `Tokenizers.Encoding.t()`. Because `Tokenizers` relies on Rust bindings, the encoding itself appears opaque. 91 | 92 | ```elixir 93 | {:ok, encoding} = Tokenizer.encode(tokenizer, "Hello there!") 94 | ``` 95 | 96 | However, we can get the ids for the encoding as an Elixir list. 97 | 98 | ```elixir 99 | ids = Encoding.get_ids(encoding) 100 | ``` 101 | 102 | And we can decode those back into tokens. 103 | 104 | ```elixir 105 | Tokenizer.decode(tokenizer, ids) 106 | ``` 107 | 108 | Passing a batch of text as a list of strings returns a batch of encodings. 109 | 110 | ```elixir 111 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, ["Hello there!", "This is a test."]) 112 | ``` 113 | 114 | And we can see the list of ids and decode them again. 115 | 116 | ```elixir 117 | list_of_ids = Enum.map(encodings, &Encoding.get_ids/1) 118 | ``` 119 | 120 | ```elixir 121 | Tokenizer.decode_batch(tokenizer, list_of_ids) 122 | ``` 123 | 124 | ## Get a tensor 125 | 126 | Typically the reason we're tokenizing text is to use it as an input in a machine learning model. For that, we'll need tensors. 127 | 128 | In order to get a tensor, we need sequences that are all of the same length. We'll get some data from `Scidata` and use `Tokenizers.Encoding.pad/3` and `Tokenizers.Encoding.truncate/3` to yield a tensor. 129 | 130 | ```elixir 131 | %{review: reviews} = Scidata.YelpPolarityReviews.download_test() 132 | ``` 133 | 134 | ```elixir 135 | {:ok, encoding_batch} = 136 | reviews 137 | |> Enum.take(10) 138 | |> then(&Tokenizer.encode_batch(tokenizer, &1)) 139 | 140 | tensor = 141 | encoding_batch 142 | |> Enum.map(fn encoding -> 143 | encoding 144 | |> Encoding.pad(200) 145 | |> Encoding.truncate(200) 146 | |> Encoding.get_ids() 147 | end) 148 | |> Nx.tensor() 149 | ``` 150 | 151 | And we can reverse the operation to see our data. Note the `[PAD]` tokens. 152 | 153 | ```elixir 154 | tensor 155 | |> Nx.to_batched(1) 156 | |> Enum.map(&Nx.to_flat_list/1) 157 | |> then(&Tokenizer.decode_batch(tokenizer, &1)) 158 | ``` 159 | -------------------------------------------------------------------------------- /notebooks/training.livemd: -------------------------------------------------------------------------------- 1 | # Training custom tokenizer 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:tokenizers, "~> 0.4.0"}, 6 | {:req, "~> 0.3.8"} 7 | ]) 8 | ``` 9 | 10 | ## Intro 11 | 12 | Let’s have a quick look at the 🤗 Tokenizers library features. The library provides an implementation of today’s most used tokenizers that is both easy to use and blazing fast. 13 | 14 | 15 | 16 | ## Downloading the data 17 | 18 | To illustrate how fast the 🤗 Tokenizers library is, let’s train a new tokenizer on wikitext-103 (516M of text) in just a few seconds. 
First things first, you will need to download this dataset and unzip it with: 19 | 20 | ```bash 21 | wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip 22 | unzip wikitext-103-raw-v1.zip 23 | ``` 24 | 25 | 26 | 27 | Alternatively you can run this code: 28 | 29 | ```elixir 30 | Req.get!("https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip").body 31 | |> Enum.each(fn {filename, data} -> 32 | filename = to_string(filename) 33 | path = Path.join(__DIR__, filename) 34 | IO.puts("Writing #{filename} to path #{path}") 35 | 36 | :ok = File.mkdir_p!(Path.dirname(path)) 37 | File.write!(path, data, [:write]) 38 | end) 39 | ``` 40 | 41 | ## Training the tokenizer from scratch 42 | 43 | ```elixir 44 | alias Tokenizers.Tokenizer 45 | alias Tokenizers.Trainer 46 | alias Tokenizers.PostProcessor 47 | alias Tokenizers.PreTokenizer 48 | alias Tokenizers.Model 49 | alias Tokenizers.Encoding 50 | ``` 51 | 52 | In this tour, we will build and train a Byte-Pair Encoding (BPE) tokenizer. For more information about the different type of tokenizers, check out this guide in the 🤗 Transformers documentation. Here, training the tokenizer means it will learn merge rules by: 53 | 54 | * Start with all the characters present in the training corpus as tokens. 55 | * Identify the most common pair of tokens and merge it into one token. 56 | * Repeat until the vocabulary (e.g., the number of tokens) has reached the size we want. 57 | 58 | The main API of the library is the class Tokenizer, here is how we instantiate one with a BPE model: 59 | 60 | ```elixir 61 | {:ok, model} = Model.BPE.init(%{}, [], unk_token: "[UNK]") 62 | {:ok, tokenizer} = Tokenizer.init(model) 63 | ``` 64 | 65 | To train our tokenizer on the wikitext files, we will need to instantiate a **trainer**, in this case a BPE trainer: 66 | 67 | ```elixir 68 | {:ok, trainer} = Trainer.bpe(special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]) 69 | ``` 70 | 71 | We can set the training arguments like `vocab_size` or `min_frequency` (here left at their default values of `30,000` and `0`), but the most important part is to give the `special_tokens` we plan to use later on (they are not used at all during training) so that they get inserted in the vocabulary. 72 | 73 | > The order in which you write the special tokens list matters: here `"[UNK]"` will get the ID `0`, `"[CLS]"` will get the ID `1` and so forth. 74 | 75 | We could train our tokenizer right now, but it wouldn't be optimal. Without a pre-tokenizer that will split our inputs into words, we might get tokens that overlap several words: for instance we could get an "it is" token since those two words often appear next to each other. Using a pre-tokenizer will ensure no token is bigger than a word returned by the pre-tokenizer. Here we want to train a subword BPE tokenizer, and we will use the easiest pre-tokenizer possible by splitting on whitespace. 
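Before attaching it, it can help to see what the whitespace pre-tokenizer does on its own. The sketch below assumes a `PreTokenizer.pre_tokenize/2` helper wrapping the corresponding NIF, and the result shown in the comment is indicative: each piece is a single word together with its character offsets, so no learned token can ever span two words. With that check in mind, the next cell attaches the pre-tokenizer to our tokenizer.

```elixir
# Run the whitespace pre-tokenizer standalone on a sample sentence.
# Indicative output: {:ok, [{"it", {0, 2}}, {"is", {3, 5}}, {"raining", {6, 13}}]}
PreTokenizer.whitespace()
|> PreTokenizer.pre_tokenize("it is raining")
```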
76 | 77 | ```elixir 78 | tokenizer = Tokenizer.set_pre_tokenizer(tokenizer, PreTokenizer.whitespace()) 79 | ``` 80 | 81 | Now, we can just call the `Tokenizer.train_from_files/3` function with the list of files we want to train on: 82 | 83 | ```elixir 84 | {:ok, tokenizer} = 85 | [ 86 | "wikitext-103-raw/wiki.test.raw", 87 | "wikitext-103-raw/wiki.train.raw", 88 | "wikitext-103-raw/wiki.valid.raw" 89 | ] 90 | |> Enum.map(&Path.join(__DIR__, &1)) 91 | |> then(&Tokenizer.train_from_files(tokenizer, &1, trainer: trainer)) 92 | ``` 93 | 94 | This should only take a few seconds to train our tokenizer on the full wikitext dataset! To save the tokenizer in one file that contains all its configuration and vocabulary, just use the `Tokenizer.save/2` function: 95 | 96 | ```elixir 97 | Tokenizer.save(tokenizer, Path.join(__DIR__, "tokenizer-wiki.json")) 98 | ``` 99 | 100 | and you can reload your tokenizer from that file with the `Tokenizer.from_file/1` function: 101 | 102 | ```elixir 103 | {:ok, tokenizer} = Tokenizer.from_file(Path.join(__DIR__, "tokenizer-wiki.json")) 104 | ``` 105 | 106 | ## Using the tokenizer 107 | 108 | Now that we have trained a tokenizer, we can use it on any text we want with the `Tokenizer.encode/1` function: 109 | 110 | ```elixir 111 | {:ok, encoding} = Tokenizer.encode(tokenizer, "Hello, y'all! How are you 😁 ?") 112 | ``` 113 | 114 | This applied the full pipeline of the tokenizer on the text, returning an `encoding`. To learn more about this pipeline, and how to apply (or customize) parts of it, check out [this page](https://huggingface.co/docs/tokenizers/pipeline). 115 | 116 | This `encoding` then has all the attributes you need for your deep learning model (or other). The tokens attribute contains the segmentation of your text in tokens: 117 | 118 | ```elixir 119 | Encoding.get_tokens(encoding) 120 | ``` 121 | 122 | Similarly, the ids attribute will contain the index of each of those tokens in the tokenizer’s vocabulary: 123 | 124 | ```elixir 125 | Encoding.get_ids(encoding) 126 | ``` 127 | 128 | An important feature of the 🤗 Tokenizers library is that it comes with full alignment tracking, meaning you can always get the part of your original sentence that corresponds to a given token. Those are stored in the offsets attribute of our Encoding object. For instance, let’s assume we would want to find back what caused the "[UNK]" token to appear, which is the token at index 9 in the list, we can just ask for the offset at the index: 129 | 130 | ```elixir 131 | {emoji_offset_start, emoji_offset_end} = Encoding.get_offsets(encoding) |> Enum.at(9) 132 | ``` 133 | 134 | and those are the indices that correspond to the emoji in the original sentence: 135 | 136 | ```elixir 137 | :binary.part( 138 | "Hello, y'all! How are you 😁 ?", 139 | emoji_offset_start, 140 | # Length 141 | emoji_offset_end - emoji_offset_start 142 | ) 143 | ``` 144 | 145 | ## Post-processing 146 | 147 | We might want our tokenizer to automatically add special tokens, like `[CLS]` or `[SEP]`. To do this, we use a post-processor. Template post-processing is the most commonly used, you just have to specify a template for the processing of single sentences and pairs of sentences, along with the special tokens and their IDs. 148 | 149 | When we built our tokenizer, we set `[CLS]` and `[SEP]` in positions 1 and 2 of our list of special tokens, so this should be their IDs. 
To double-check, we can use the `Tokenizer.token_to_id/2` function: 150 | 151 | ```elixir 152 | Tokenizer.token_to_id(tokenizer, "[SEP]") 153 | ``` 154 | 155 | Here is how we can set the post-processing to give us the traditional BERT inputs: 156 | 157 | ```elixir 158 | tokenizer = 159 | Tokenizer.set_post_processor( 160 | tokenizer, 161 | PostProcessor.template( 162 | single: "[CLS] $A [SEP]", 163 | pair: "[CLS] $A [SEP] $B:1 [SEP]:1", 164 | special_tokens: [ 165 | {"[CLS]", Tokenizer.token_to_id(tokenizer, "[CLS]")}, 166 | {"[SEP]", Tokenizer.token_to_id(tokenizer, "[SEP]")} 167 | ] 168 | ) 169 | ) 170 | ``` 171 | 172 | Let's go over this snippet of code in more detail. First we specify the template for single sentences: those should have the form `"[CLS] $A [SEP]"`, where `$A` represents our sentence. 173 | 174 | Then, we specify the template for sentence pairs, which should have the form `"[CLS] $A [SEP] $B [SEP]"`, where `$A` represents the first sentence and `$B` the second one. The `:1` markers added in the template represent the type IDs we want for each part of our input: they default to `0` for everything (which is why we don't have `$A:0`), and here we set them to `1` for the tokens of the second sentence and the last `"[SEP]"` token. 175 | 176 | Lastly, we specify the special tokens we used and their IDs in our tokenizer's vocabulary. 177 | 178 | To check that this worked properly, let's try to encode the same sentence as before: 179 | 180 | ```elixir 181 | {:ok, encoding} = Tokenizer.encode(tokenizer, "Hello, y'all! How are you 😁 ?") 182 | Encoding.get_tokens(encoding) 183 | ``` 184 | 185 | To check the results on a pair of sentences, we just pass the two sentences to `Tokenizer.encode/2`: 186 | 187 | ```elixir 188 | {:ok, encoding} = Tokenizer.encode(tokenizer, {"Hello, y'all!", "How are you 😁 ?"}) 189 | Encoding.get_tokens(encoding) 190 | ``` 191 | 192 | You can then check that the type IDs attributed to each token are correct with: 193 | 194 | ```elixir 195 | Encoding.get_type_ids(encoding) 196 | ``` 197 | 198 | If you save your tokenizer with `Tokenizer.save/2`, the post-processor will be saved along with it. 199 | 200 | ## Encoding multiple sentences in a batch 201 | 202 | To get the full speed of the 🤗 Tokenizers library, it's best to process your texts in batches using the `Tokenizer.encode_batch/2` function: 203 | 204 | ```elixir 205 | {:ok, encoding} = Tokenizer.encode_batch(tokenizer, ["Hello, y'all!", "How are you 😁 ?"]) 206 | ``` 207 | 208 | The output is then a list of `encoding`s like the ones we saw before. You can process together as many texts as you like, as long as they fit in memory. 209 | 210 | To process a batch of sentence pairs, pass a list of tuples to the `Tokenizer.encode_batch/2` function: 211 | 212 | ```elixir 213 | {:ok, encoding} = 214 | Tokenizer.encode_batch(tokenizer, [ 215 | {"Hello, y'all!", "How are you 😁 ?"}, 216 | { 217 | "Hello to you too!", 218 | "I'm fine, thank you!"
219 | } 220 | ]) 221 | ``` 222 | 223 | When encoding multiple sentences, you can automatically pad the outputs to the longest sentence present by using `Tokenizer.set_padding/2` with the `pad_token` and its ID (which we can double-check with `Tokenizer.token_to_id/2` like before): 224 | 225 | ```elixir 226 | tokenizer = Tokenizer.set_padding(tokenizer, pad_id: 3, pad_token: "[PAD]") 227 | ``` 228 | 229 | We can also set the direction of the padding (it defaults to the right) or a fixed length if we want to pad every sample to that specific length (here we leave it unset, so outputs are padded to the size of the longest text). 230 | 231 | ```elixir 232 | {:ok, encoding} = Tokenizer.encode_batch(tokenizer, ["Hello, y'all!", "How are you 😁 ?"]) 233 | 234 | encoding 235 | |> Enum.at(1) 236 | |> Encoding.get_tokens() 237 | ``` 238 | 239 | In this case, the attention mask generated by the tokenizer takes the padding into account: 240 | 241 | ```elixir 242 | encoding 243 | |> Enum.at(1) 244 | |> Encoding.get_attention_mask() 245 | ``` 246 | -------------------------------------------------------------------------------- /test/fixtures/merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elixir-nx/tokenizers/d3081c4973bb1f868fcdaf71c8e655d188dfa836/test/fixtures/merges.txt -------------------------------------------------------------------------------- /test/fixtures/vocab.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /test/fixtures/vocab.txt: -------------------------------------------------------------------------------- 1 | my 2 | name 3 | is 4 | jo 5 | ##hn 6 | what 7 | yours 8 | pair 9 | [UNK] 10 | -------------------------------------------------------------------------------- /test/test_helper.exs: -------------------------------------------------------------------------------- 1 | ExUnit.start() 2 | -------------------------------------------------------------------------------- /test/tokenizers/added_token_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.AddedTokenTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.AddedToken 4 | 5 | describe "Added token" do 6 | test "successfully initializes with empty params" do 7 | assert token = %Tokenizers.AddedToken{} = Tokenizers.AddedToken.new("[MASK]") 8 | 9 | assert %{ 10 | "content" => "[MASK]", 11 | "lstrip" => false, 12 | "normalized" => true, 13 | "rstrip" => false, 14 | "single_word" => false, 15 | "special" => false 16 | } = Tokenizers.AddedToken.info(token) 17 | end 18 | 19 | test "successfully initializes with params" do 20 | assert token = 21 | %Tokenizers.AddedToken{} = 22 | Tokenizers.AddedToken.new( 23 | "[MASK]", 24 | lstrip: true, 25 | rstrip: true, 26 | single_word: true, 27 | normalized: false, 28 | special: true 29 | ) 30 | 31 | assert %{ 32 | "content" => "[MASK]", 33 | "lstrip" => true, 34 | "normalized" => false, 35 | "rstrip" => true, 36 | "single_word" => true, 37 | "special" => true 38 | } = Tokenizers.AddedToken.info(token) 39 | end 40 | end 41 | end 42 | -------------------------------------------------------------------------------- /test/tokenizers/decode_stream_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.DecodeStreamTest do 2 | use ExUnit.Case, async: true 3 |
doctest Tokenizers.Decoder 4 | 5 | describe "Minimal tokenizer" do 6 | test "Decodes with stream" do 7 | {:ok, bpe} = Tokenizers.Model.BPE.empty() 8 | {:ok, tk} = Tokenizers.Tokenizer.init(bpe) 9 | 10 | tk = 11 | tk 12 | |> Tokenizers.Tokenizer.add_tokens(["my", "name", "is", "john", "pair"]) 13 | 14 | ds = Tokenizers.DecodeStream.new() 15 | 16 | {:ok, "my"} = Tokenizers.DecodeStream.step(ds, tk, 0) 17 | {:ok, " name"} = Tokenizers.DecodeStream.step(ds, tk, 1) 18 | {:ok, " is"} = Tokenizers.DecodeStream.step(ds, tk, 2) 19 | {:ok, " john"} = Tokenizers.DecodeStream.step(ds, tk, 3) 20 | {:ok, " pair"} = Tokenizers.DecodeStream.step(ds, tk, 4) 21 | end 22 | end 23 | 24 | describe "Byte fallback decode stream" do 25 | test "handles byte fallback decoding" do 26 | vocab = [ 27 | {"", 0.0}, 28 | {"<0x20>", -0.1}, 29 | {"<0xC3>", -0.2}, 30 | {"<0xA9>", -0.3} 31 | ] 32 | 33 | {:ok, model} = Tokenizers.Model.Unigram.init(vocab, byte_fallback: true, unk_id: 0) 34 | 35 | {:ok, tk} = Tokenizers.Tokenizer.init(model) 36 | 37 | tk = 38 | tk 39 | |> Tokenizers.Tokenizer.set_decoder(Tokenizers.Decoder.byte_fallback()) 40 | 41 | ds = Tokenizers.DecodeStream.new() 42 | 43 | {:ok, " "} = Tokenizers.DecodeStream.step(ds, tk, 1) 44 | {:ok, :out_of_range} = Tokenizers.DecodeStream.step(ds, tk, 2) 45 | {:ok, "é"} = Tokenizers.DecodeStream.step(ds, tk, 3) 46 | end 47 | 48 | test "handles metaspace decoding" do 49 | vocab = [ 50 | {"", 0.0}, 51 | {"▁This", -0.1} 52 | ] 53 | 54 | {:ok, model} = Tokenizers.Model.Unigram.init(vocab, byte_fallback: false, unk_id: 0) 55 | {:ok, tk} = Tokenizers.Tokenizer.init(model) 56 | 57 | tk = 58 | tk 59 | |> Tokenizers.Tokenizer.set_decoder(Tokenizers.Decoder.metaspace()) 60 | 61 | ds = Tokenizers.DecodeStream.new() 62 | 63 | {:ok, "This"} = Tokenizers.DecodeStream.step(ds, tk, 1) 64 | {:ok, " This"} = Tokenizers.DecodeStream.step(ds, tk, 1) 65 | end 66 | end 67 | 68 | describe "DecodeStream info" do 69 | test "skip_special_tokens false" do 70 | assert Tokenizers.DecodeStream.info(Tokenizers.DecodeStream.new()) == %{ 71 | "skip_special_tokens" => false 72 | } 73 | end 74 | 75 | test "skip_special_tokens true" do 76 | assert Tokenizers.DecodeStream.info(Tokenizers.DecodeStream.new(skip_special_tokens: true)) == 77 | %{ 78 | "skip_special_tokens" => true 79 | } 80 | end 81 | 82 | test "default DecodeStream" do 83 | assert Tokenizers.DecodeStream.info(Tokenizers.DecodeStream.new()) == %{ 84 | "skip_special_tokens" => false 85 | } 86 | end 87 | end 88 | end 89 | -------------------------------------------------------------------------------- /test/tokenizers/decoder_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.DecoderTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Decoder 4 | 5 | describe "WordPiece Decoder" do 6 | test "accepts no parameters" do 7 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.word_piece() 8 | end 9 | 10 | test "accepts all params" do 11 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.word_piece(prefix: "test", cleanup: false) 12 | end 13 | 14 | test "can decode array of strings" do 15 | assert Tokenizers.Decoder.word_piece() 16 | |> Tokenizers.Decoder.decode(["Hel", "##lo", "there", "my", "fr", "##iend"]) == 17 | {:ok, "Hello there my friend"} 18 | end 19 | end 20 | 21 | describe "ByteFallback Decoder" do 22 | test "accepts no parameters" do 23 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.byte_fallback() 24 | end 25 | 26 | test "can decode array of strings" do 27 | 
[ 28 | {["Hel", "lo"], "Hello"}, 29 | {["<0x61>"], "a"}, 30 | {["<0x61>"], "a"}, 31 | {["My", " na", "me"], "My name"}, 32 | {["<0x61>"], "a"}, 33 | {["<0xE5>"], "�"}, 34 | {["<0xE5>", "<0x8f>"], "��"}, 35 | {["<0xE5>", "<0x8f>", "<0xab>"], "叫"}, 36 | {["<0xE5>", "<0x8f>", "a"], "��a"}, 37 | {["<0xE5>", "<0x8f>", "<0xab>", "a"], "叫a"} 38 | ] 39 | |> Enum.each(fn {tokens, result} -> 40 | assert Tokenizers.Decoder.decode(Tokenizers.Decoder.byte_fallback(), tokens) == 41 | {:ok, result} 42 | end) 43 | end 44 | end 45 | 46 | describe "Replace Decoder" do 47 | test "can decode array of strings" do 48 | assert Tokenizers.Decoder.decode(Tokenizers.Decoder.replace("_", " "), ["Hello", "_Hello"]) == 49 | {:ok, "Hello Hello"} 50 | end 51 | end 52 | 53 | describe "Fuse Decoder" do 54 | test "accepts no parameters" do 55 | %Tokenizers.Decoder{} = Tokenizers.Decoder.fuse() 56 | end 57 | 58 | test "can decode array of strings" do 59 | assert Tokenizers.Decoder.fuse() 60 | |> Tokenizers.Decoder.decode(["Hel", "lo"]) == 61 | {:ok, "Hello"} 62 | end 63 | end 64 | 65 | describe "Strip Decoder" do 66 | test "can be initialized" do 67 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.strip(?_, 0, 0) 68 | end 69 | 70 | test "can't be initialized with invalid char" do 71 | assert_raise ArgumentError, fn -> 72 | Tokenizers.Decoder.strip(61_126_999, 0, 0) 73 | end 74 | end 75 | 76 | test "can decode array of strings" do 77 | assert Tokenizers.Decoder.strip(?_, 1, 0) 78 | |> Tokenizers.Decoder.decode(["_Hel", "lo", "__there"]) == 79 | {:ok, "Hello_there"} 80 | end 81 | end 82 | 83 | describe "Metaspace Decoder" do 84 | test "accepts no parameters" do 85 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.metaspace() 86 | end 87 | 88 | test "accepts all params" do 89 | assert %Tokenizers.Decoder{} = 90 | Tokenizers.Decoder.metaspace(replacement: ?t, prepend_scheme: :always) 91 | end 92 | end 93 | 94 | describe "BPE Decoder" do 95 | test "accepts no parameters" do 96 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.bpe() 97 | end 98 | end 99 | 100 | describe "CTC Decoder" do 101 | test "accepts no parameters" do 102 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.ctc() 103 | end 104 | 105 | test "accepts all parameters" do 106 | assert %Tokenizers.Decoder{} = 107 | Tokenizers.Decoder.ctc( 108 | pad_token: "", 109 | word_delimiter_token: "!!", 110 | cleanup: false 111 | ) 112 | end 113 | 114 | test "can decode array of strings" do 115 | assert Tokenizers.Decoder.ctc() 116 | |> Tokenizers.Decoder.decode([ 117 | "", 118 | "h", 119 | "h", 120 | "e", 121 | "e", 122 | "l", 123 | "l", 124 | "", 125 | "l", 126 | "l", 127 | "o" 128 | ]) == 129 | {:ok, "hello"} 130 | end 131 | end 132 | 133 | describe "Sequence Decoder" do 134 | test "accepts empty list as parameter" do 135 | assert %Tokenizers.Decoder{} = Tokenizers.Decoder.sequence([]) 136 | end 137 | 138 | test "can decode array of strings correctly" do 139 | assert Tokenizers.Decoder.sequence([ 140 | Tokenizers.Decoder.ctc(), 141 | Tokenizers.Decoder.metaspace() 142 | ]) 143 | |> Tokenizers.Decoder.decode(["▁", "▁", "H", "H", "i", "i", "▁", "y", "o", "u"]) == 144 | {:ok, "Hi you"} 145 | end 146 | end 147 | end 148 | -------------------------------------------------------------------------------- /test/tokenizers/model/bpe_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.BPETest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Model.BPE 4 | 5 | describe "initialized from memory" do 
6 | test "returns loaded model" do 7 | assert {:ok, %Tokenizers.Model{}} = 8 | Tokenizers.Model.BPE.init(%{"a" => 0, "b" => 1, "ab" => 2}, [{"a", "b"}]) 9 | end 10 | 11 | test "accepts keyword params" do 12 | assert {:ok, %Tokenizers.Model{}} = 13 | Tokenizers.Model.BPE.init(%{"a" => 0, "b" => 1, "ab" => 2}, [{"a", "b"}], 14 | dropout: 0.3 15 | ) 16 | end 17 | 18 | test "rejects bad keyword params" do 19 | assert_raise ErlangError, fn -> 20 | Tokenizers.Model.BPE.init(%{"a" => 0, "b" => 1, "ab" => 2}, [{"a", "b"}], 21 | weird_value: :something 22 | ) 23 | end 24 | end 25 | end 26 | 27 | describe "loaded from file" do 28 | test "Good initialization with valid paths" do 29 | assert {:ok, %Tokenizers.Model{}} = 30 | Tokenizers.Model.BPE.from_file( 31 | "test/fixtures/vocab.json", 32 | "test/fixtures/merges.txt" 33 | ) 34 | end 35 | 36 | test "bad initialization with invalid paths" do 37 | assert {:error, _} = 38 | Tokenizers.Model.BPE.from_file( 39 | "test/fixtures/not_found_vocab.json", 40 | "test/fixtures/merges.txt" 41 | ) 42 | end 43 | 44 | test "bad initialization with good paths but invalid data" do 45 | assert {:error, _} = 46 | Tokenizers.Model.BPE.from_file( 47 | "test/fixtures/vocab.txt", 48 | "test/fixtures/merges.txt" 49 | ) 50 | end 51 | end 52 | end 53 | -------------------------------------------------------------------------------- /test/tokenizers/model/unigram.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.UnigramTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Model.Unigram 4 | 5 | describe "initialized from memory" do 6 | test "returns loaded model" do 7 | assert {:ok, %Tokenizers.Model{}} = 8 | Tokenizers.Model.Unigram.init([{"", 0}, {"Hello", -1}, {"there", -2}], 9 | unk_id: 0 10 | ) 11 | end 12 | end 13 | end 14 | -------------------------------------------------------------------------------- /test/tokenizers/model/wordlevel_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.WordLevelTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Model.WordLevel 4 | 5 | describe "initialized from memory" do 6 | test "returns loaded model" do 7 | assert {:ok, %Tokenizers.Model{}} = 8 | Tokenizers.Model.WordLevel.init(%{"a" => 0, "b" => 1, "ab" => 2}) 9 | end 10 | 11 | test "accepts keyword params" do 12 | assert {:ok, %Tokenizers.Model{}} = 13 | Tokenizers.Model.WordLevel.init(%{"a" => 0, "b" => 1, "ab" => 2}, 14 | unk_token: "asdf" 15 | ) 16 | end 17 | 18 | test "rejects bad keyword params" do 19 | assert_raise ErlangError, fn -> 20 | Tokenizers.Model.WordLevel.init(%{"a" => 0, "b" => 1, "ab" => 2}, 21 | weird_value: :something 22 | ) 23 | end 24 | end 25 | end 26 | 27 | describe "loaded from file" do 28 | test "good initialization with valid paths" do 29 | assert {:ok, %Tokenizers.Model{}} = 30 | Tokenizers.Model.WordLevel.from_file("test/fixtures/vocab.json") 31 | end 32 | 33 | test "bad initialization with invalid paths" do 34 | assert {:error, _} = 35 | Tokenizers.Model.WordLevel.from_file("test/fixtures/not_found_vocab.json") 36 | end 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /test/tokenizers/model/wordpiece_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.Model.WordPieceTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Model.WordPiece 4 | 5 | describe "initialized from 
memory" do 6 | test "returns loaded model" do 7 | assert {:ok, %Tokenizers.Model{}} = 8 | Tokenizers.Model.WordPiece.init(%{"a" => 0, "b" => 1, "ab" => 2}) 9 | end 10 | 11 | test "accepts keyword params" do 12 | assert {:ok, %Tokenizers.Model{}} = 13 | Tokenizers.Model.WordPiece.init(%{"a" => 0, "b" => 1, "ab" => 2}, 14 | max_input_chars_per_word: 50 15 | ) 16 | end 17 | 18 | test "rejects bad keyword params" do 19 | assert_raise ErlangError, fn -> 20 | Tokenizers.Model.WordPiece.init(%{"a" => 0, "b" => 1, "ab" => 2}, 21 | weird_value: :something 22 | ) 23 | end 24 | end 25 | end 26 | 27 | describe "loaded from file" do 28 | test "good initialization with valid paths" do 29 | assert {:ok, %Tokenizers.Model{}} = 30 | Tokenizers.Model.WordPiece.from_file("test/fixtures/vocab.txt") 31 | end 32 | 33 | test "bad initialization with invalid paths" do 34 | assert {:error, _} = 35 | Tokenizers.Model.WordPiece.from_file("test/fixtures/not_found_vocab.json") 36 | end 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /test/tokenizers/model_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.ModelTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Model 4 | 5 | describe "model common functionality" do 6 | test "successfully saves the model" do 7 | {:ok, original_model} = Tokenizers.Model.BPE.empty() 8 | assert {:ok, [vocab, merges]} = Tokenizers.Model.save(original_model, System.tmp_dir!()) 9 | assert File.exists?(vocab) 10 | assert File.exists?(merges) 11 | 12 | assert {:ok, loaded_model} = Tokenizers.Model.BPE.from_file(vocab, merges) 13 | assert Tokenizers.Model.info(original_model) == Tokenizers.Model.info(loaded_model) 14 | end 15 | 16 | test "successfully saves the model to a directory with prefix" do 17 | {:ok, original_model} = Tokenizers.Model.BPE.empty() 18 | 19 | assert {:ok, [vocab, merges]} = 20 | Tokenizers.Model.save(original_model, System.tmp_dir!(), prefix: "MODEL_PREFIX") 21 | 22 | assert File.exists?(vocab) 23 | assert File.exists?(merges) 24 | 25 | assert String.contains?(vocab, "MODEL_PREFIX") 26 | assert String.contains?(merges, "MODEL_PREFIX") 27 | 28 | assert {:ok, loaded_model} = Tokenizers.Model.BPE.from_file(vocab, merges) 29 | assert Tokenizers.Model.info(original_model) == Tokenizers.Model.info(loaded_model) 30 | end 31 | end 32 | end 33 | -------------------------------------------------------------------------------- /test/tokenizers/normalizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.NormalizerTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Normalizer 4 | 5 | describe "Bert" do 6 | test "accepts no parameters" do 7 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.bert_normalizer() 8 | end 9 | 10 | test "accepts options" do 11 | assert %Tokenizers.Normalizer{} = 12 | Tokenizers.Normalizer.bert_normalizer( 13 | clean_text: true, 14 | handle_chinese_chars: true, 15 | strip_accents: true, 16 | lowercase: true 17 | ) 18 | end 19 | 20 | test "works well with strip accents" do 21 | assert Tokenizers.Normalizer.bert_normalizer(strip_accents: true, lowercase: false) 22 | |> Tokenizers.Normalizer.normalize("Héllò") == 23 | {:ok, "Hello"} 24 | end 25 | 26 | test "handles chinese chars well" do 27 | assert Tokenizers.Normalizer.bert_normalizer(handle_chinese_chars: true) 28 | |> Tokenizers.Normalizer.normalize("你好") == 29 | {:ok, " 你 好 "} 30 | end 31 | 32 
| test "handles clean text well" do 33 | assert Tokenizers.Normalizer.bert_normalizer(clean_text: true, lowercase: false) 34 | |> Tokenizers.Normalizer.normalize("\ufeffHello") == 35 | {:ok, "Hello"} 36 | end 37 | 38 | test "handle lowercase well" do 39 | assert Tokenizers.Normalizer.bert_normalizer(lowercase: true) 40 | |> Tokenizers.Normalizer.normalize("Hello") == 41 | {:ok, "hello"} 42 | end 43 | end 44 | 45 | describe "Sequence" do 46 | test "can be instantiated" do 47 | assert Tokenizers.Normalizer.sequence([ 48 | Tokenizers.Normalizer.lowercase(), 49 | Tokenizers.Normalizer.strip() 50 | ]) 51 | |> Tokenizers.Normalizer.normalize("HELLO ") == {:ok, "hello"} 52 | end 53 | end 54 | 55 | describe "Lowercase" do 56 | test "accepts no parameters" do 57 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.lowercase() 58 | end 59 | 60 | test "can normalize strings" do 61 | assert Tokenizers.Normalizer.lowercase() 62 | |> Tokenizers.Normalizer.normalize("HELLO") == {:ok, "hello"} 63 | end 64 | end 65 | 66 | describe "Strip" do 67 | test "accepts no parameters" do 68 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.strip() 69 | end 70 | 71 | test "accepts options" do 72 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.strip(left: true, right: true) 73 | end 74 | 75 | test "can normalizer strings" do 76 | assert Tokenizers.Normalizer.strip() 77 | |> Tokenizers.Normalizer.normalize(" Hello there ") == 78 | {:ok, "Hello there"} 79 | end 80 | end 81 | 82 | describe "Prepend" do 83 | test "can be initialized" do 84 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.prepend("▁") 85 | end 86 | 87 | test "can normalize strings" do 88 | assert Tokenizers.Normalizer.prepend("▁") 89 | |> Tokenizers.Normalizer.normalize("Hello") == 90 | {:ok, "▁Hello"} 91 | end 92 | end 93 | 94 | describe "Replace" do 95 | test "can be initialized" do 96 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.replace("find", "replace") 97 | end 98 | 99 | test "can normalize strings" do 100 | assert Tokenizers.Normalizer.replace("Hello", "World") 101 | |> Tokenizers.Normalizer.normalize("Hello") == 102 | {:ok, "World"} 103 | end 104 | end 105 | 106 | describe "Replace Regex" do 107 | test "can be initialized" do 108 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.replace_regex("\\d*", "") 109 | end 110 | 111 | test "can normalize strings" do 112 | assert Tokenizers.Normalizer.replace_regex("\\d*", "") 113 | |> Tokenizers.Normalizer.normalize("1Hel2lo3") == 114 | {:ok, "Hello"} 115 | end 116 | end 117 | 118 | describe "ByteLevel" do 119 | test "can be initialized" do 120 | assert %Tokenizers.Normalizer{} = Tokenizers.Normalizer.byte_level() 121 | end 122 | 123 | test "can normalize strings" do 124 | # Test is taken directly from original Rust implementation 125 | assert Tokenizers.Normalizer.byte_level() 126 | |> Tokenizers.Normalizer.normalize("Hello 我今天能为你做什么") == 127 | {:ok, "HelloĠæĪijä»Ĭ天èĥ½ä¸ºä½łåģļä»Ģä¹Ī"} 128 | end 129 | 130 | test "returns alphabet" do 131 | assert length(Tokenizers.Normalizer.byte_level_alphabet()) != 0 132 | end 133 | end 134 | end 135 | -------------------------------------------------------------------------------- /test/tokenizers/post_processor_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.PostProcessorTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.PostProcessor 4 | 5 | describe "bertProcessing" do 6 | test "instantiates correctly with only two parameters" do 7 | 
assert %Tokenizers.PostProcessor{} = 8 | Tokenizers.PostProcessor.bert({"[SEP]", 0}, {"[CLS]", 1}) 9 | end 10 | 11 | test "successfully processes data" do 12 | {:ok, tokenizer} = Tokenizers.Tokenizer.init(Tokenizers.Model.BPE.empty() |> elem(1)) 13 | 14 | tokenizer = 15 | tokenizer 16 | |> Tokenizers.Tokenizer.add_special_tokens(["[SEP]", "[CLS]"]) 17 | |> Tokenizers.Tokenizer.add_tokens(["my", "name", "is", "john", "pair"]) 18 | |> Tokenizers.Tokenizer.set_post_processor( 19 | Tokenizers.PostProcessor.bert({"[SEP]", 0}, {"[CLS]", 1}) 20 | ) 21 | 22 | {:ok, output} = Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"}) 23 | 24 | assert Tokenizers.Encoding.get_tokens(output) == [ 25 | "[CLS]", 26 | "my", 27 | "name", 28 | "[SEP]", 29 | "pair", 30 | "[SEP]" 31 | ] 32 | 33 | assert Tokenizers.Encoding.get_ids(output) == [1, 2, 3, 0, 6, 0] 34 | end 35 | end 36 | 37 | describe "robertaProcessing" do 38 | test "instantiates correctly with only two parameters" do 39 | assert %Tokenizers.PostProcessor{} = 40 | Tokenizers.PostProcessor.roberta({"", 0}, {"", 1}) 41 | end 42 | 43 | test "successfully processes data" do 44 | {:ok, tokenizer} = Tokenizers.Tokenizer.init(Tokenizers.Model.BPE.empty() |> elem(1)) 45 | 46 | tokenizer = 47 | tokenizer 48 | |> Tokenizers.Tokenizer.add_special_tokens(["", ""]) 49 | |> Tokenizers.Tokenizer.add_tokens(["my", "name", "is", "john", "pair"]) 50 | |> Tokenizers.Tokenizer.set_post_processor( 51 | Tokenizers.PostProcessor.roberta({"", 1}, {"", 0}) 52 | ) 53 | 54 | {:ok, output} = Tokenizers.Tokenizer.encode(tokenizer, {"my name", "pair"}) 55 | 56 | assert Tokenizers.Encoding.get_tokens(output) == [ 57 | "", 58 | "my", 59 | "name", 60 | "", 61 | "", 62 | "pair", 63 | "" 64 | ] 65 | 66 | assert Tokenizers.Encoding.get_ids(output) == [0, 2, 3, 1, 1, 6, 1] 67 | end 68 | end 69 | 70 | describe "byteLevelProcessing" do 71 | test "instantiates correctly with only two parameters" do 72 | assert %Tokenizers.PostProcessor{} = 73 | Tokenizers.PostProcessor.byte_level(trim_offsets: false) 74 | 75 | assert %Tokenizers.PostProcessor{} = Tokenizers.PostProcessor.byte_level() 76 | end 77 | end 78 | end 79 | -------------------------------------------------------------------------------- /test/tokenizers/pre_tokenizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.PreTokenizerTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.PreTokenizer 4 | 5 | describe "Byte Level pretokenizer" do 6 | test "accepts no parameters" do 7 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.byte_level() 8 | end 9 | 10 | test "accepts options" do 11 | assert %Tokenizers.PreTokenizer{} = 12 | Tokenizers.PreTokenizer.byte_level(add_prefix_space: false) 13 | end 14 | end 15 | 16 | describe "Split pretokenizer" do 17 | test "accepts no parameters" do 18 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.split(" ", :removed) 19 | end 20 | 21 | test "accepts options" do 22 | assert %Tokenizers.PreTokenizer{} = 23 | Tokenizers.PreTokenizer.split(" ", :removed, invert: true) 24 | end 25 | end 26 | 27 | describe "Regex split pretokenizer" do 28 | test "accepts regular expressions" do 29 | assert %Tokenizers.PreTokenizer{} = 30 | Tokenizers.PreTokenizer.split_regex(".*", :removed) 31 | end 32 | 33 | test "accepts options" do 34 | assert %Tokenizers.PreTokenizer{} = 35 | Tokenizers.PreTokenizer.split_regex(".*", :removed, invert: true) 36 | end 37 | end 38 | 39 | describe "WhitespaceSplit pretokenizer" do 
40 | test "accepts no parameters" do 41 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.whitespace_split() 42 | end 43 | end 44 | 45 | describe "BertPreTokenizer pretokenizer" do 46 | test "accepts no parameters" do 47 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.bert_pre_tokenizer() 48 | end 49 | end 50 | 51 | describe "Metaspace pretokenizer" do 52 | test "accepts no parameters" do 53 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.metaspace() 54 | end 55 | 56 | test "accepts options" do 57 | assert %Tokenizers.PreTokenizer{} = 58 | Tokenizers.PreTokenizer.metaspace(replacement: ?_, prepend_scheme: :never) 59 | end 60 | end 61 | 62 | describe "CharDelimiterSplit pretokenizer" do 63 | test "accepts no parameters" do 64 | assert %Tokenizers.PreTokenizer{} = Tokenizers.PreTokenizer.char_delimiter_split(?_) 65 | end 66 | end 67 | 68 | describe "Sequence pretokenizer" do 69 | test "accepts no parameters but chain of tokenizers" do 70 | assert %Tokenizers.PreTokenizer{} = 71 | Tokenizers.PreTokenizer.sequence([ 72 | Tokenizers.PreTokenizer.whitespace_split(), 73 | Tokenizers.PreTokenizer.bert_pre_tokenizer() 74 | ]) 75 | end 76 | end 77 | end 78 | -------------------------------------------------------------------------------- /test/tokenizers/tokenizer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.TokenizerTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Tokenizer 4 | 5 | alias Tokenizers.Encoding 6 | alias Tokenizers.Tokenizer 7 | 8 | setup do 9 | {:ok, tokenizer} = Tokenizer.from_file("test/fixtures/bert-base-cased.json") 10 | {:ok, tokenizer: tokenizer} 11 | end 12 | 13 | describe "IO" do 14 | test "can read from file" do 15 | {:ok, tokenizer} = Tokenizer.from_file("test/fixtures/bert-base-cased.json") 16 | assert Tokenizer.get_vocab_size(tokenizer) == 28996 17 | end 18 | 19 | @tag :tmp_dir 20 | test "can write to file", config do 21 | {:ok, tokenizer} = Tokenizer.from_file("test/fixtures/bert-base-cased.json") 22 | {:ok, path} = Tokenizer.save(tokenizer, config.tmp_dir <> "test.json") 23 | {:ok, tokenizer} = Tokenizer.from_file(path) 24 | assert Tokenizer.get_vocab_size(tokenizer) == 28996 25 | end 26 | end 27 | 28 | describe "modify tokenizer" do 29 | test "can add special tokens" do 30 | special_tokens = ["<|test|>"] 31 | 32 | {:ok, tokenizer} = Tokenizer.from_file("test/fixtures/bert-base-cased.json") 33 | tokenizer = Tokenizer.add_special_tokens(tokenizer, special_tokens) 34 | 35 | assert Tokenizer.get_vocab_size(tokenizer) == 28997 36 | end 37 | 38 | test "can decode special tokens" do 39 | text = ["This <|test|>is a test<|also|>", "<|test|>And so<|also|> is this<|test|>"] 40 | special_tokens = ["<|test|>", "<|also|>"] 41 | 42 | {:ok, tokenizer} = Tokenizer.from_file("test/fixtures/bert-base-cased.json") 43 | tokenizer = Tokenizer.add_special_tokens(tokenizer, special_tokens) 44 | 45 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 46 | 47 | {:ok, decodings} = 48 | Tokenizer.decode_batch(tokenizer, Enum.map(encodings, &Encoding.get_ids/1), 49 | skip_special_tokens: true 50 | ) 51 | 52 | assert ["This is a test", "And so is this"] == decodings 53 | end 54 | end 55 | 56 | describe "from_pretrained/2" do 57 | defmodule SuccessHTTPClient do 58 | def request(opts) do 59 | send(self(), {:request, opts}) 60 | 61 | body = 62 | case opts[:method] do 63 | :get -> 64 | File.read!("test/fixtures/bert-base-cased.json") 65 | 66 | :head -> 67 | "" 68 | end 69 
| 70 | {:ok, 71 | %{ 72 | body: body, 73 | headers: [{"etag", "test-etag"}], 74 | status: opts[:test_status] 75 | }} 76 | end 77 | end 78 | 79 | defmodule ErrorHTTPClient do 80 | def request(opts) do 81 | send(self(), {:request, opts}) 82 | {:error, "internal error"} 83 | end 84 | end 85 | 86 | @tag :tmp_dir 87 | test "load from pretrained successfully", %{tmp_dir: tmp_dir} do 88 | {:ok, tokenizer} = 89 | Tokenizer.from_pretrained("bert-base-cased", 90 | use_cache: false, 91 | cache_dir: tmp_dir, 92 | http_client: {SuccessHTTPClient, [test_status: 200, headers: [{"test-header", "42"}]]} 93 | ) 94 | 95 | assert Tokenizer.get_vocab_size(tokenizer) == 28996 96 | 97 | assert_received {:request, opts} 98 | 99 | assert opts[:method] == :get 100 | assert opts[:base_url] == "https://huggingface.co" 101 | assert opts[:url] == "/bert-base-cased/resolve/main/tokenizer.json" 102 | 103 | assert [{"test-header", "42"}, {"user-agent", "tokenizers-elixir/" <> _app_version}] = 104 | opts[:headers] 105 | 106 | {:ok, tokenizer} = 107 | Tokenizer.from_pretrained("bert-base-cased", 108 | use_cache: true, 109 | cache_dir: tmp_dir, 110 | http_client: {SuccessHTTPClient, [test_status: 200]} 111 | ) 112 | 113 | assert Tokenizer.get_vocab_size(tokenizer) == 28996 114 | 115 | assert_received {:request, opts} 116 | assert opts[:method] == :head 117 | end 118 | 119 | @tag :tmp_dir 120 | test "returns error when status is not found", %{tmp_dir: tmp_dir} do 121 | assert {:error, :not_found} = 122 | Tokenizer.from_pretrained("bert-base-cased", 123 | use_cache: false, 124 | cache_dir: tmp_dir, 125 | http_client: {SuccessHTTPClient, [test_status: 404]} 126 | ) 127 | end 128 | 129 | @tag :tmp_dir 130 | test "returns error when request is not successful", %{tmp_dir: tmp_dir} do 131 | assert {:error, error} = 132 | Tokenizer.from_pretrained("bert-base-cased", 133 | use_cache: false, 134 | cache_dir: tmp_dir, 135 | http_client: {ErrorHTTPClient, []} 136 | ) 137 | 138 | assert error == "internal error" 139 | end 140 | end 141 | 142 | describe "encode/decode" do 143 | test "can encode a single string", %{tokenizer: tokenizer} do 144 | assert {:ok, %Tokenizers.Encoding{}} = Tokenizer.encode(tokenizer, "This is a test") 145 | end 146 | 147 | test "can apply transformations to encoding", %{tokenizer: tokenizer} do 148 | assert {:ok, %Tokenizers.Encoding{}} = 149 | Tokenizer.encode(tokenizer, "This is a test", 150 | encoding_transformations: [ 151 | Encoding.Transformation.pad(2), 152 | Encoding.Transformation.truncate(4), 153 | Encoding.Transformation.set_sequence_id(1234) 154 | ] 155 | ) 156 | end 157 | 158 | test "can encode a single string with special characters", %{tokenizer: tokenizer} do 159 | seq = "This is a test" 160 | {:ok, encoding_clean} = Tokenizer.encode(tokenizer, seq, add_special_tokens: false) 161 | {:ok, encoding_special} = Tokenizer.encode(tokenizer, seq) 162 | 163 | refute Encoding.get_length(encoding_clean) == Encoding.get_length(encoding_special) 164 | end 165 | 166 | test "can encode a pair of strings", %{tokenizer: tokenizer} do 167 | assert {:ok, %Tokenizers.Encoding{}} = Tokenizer.encode(tokenizer, {"Question?", "Answer"}) 168 | end 169 | 170 | test "can encode a batch of strings", %{tokenizer: tokenizer} do 171 | assert {:ok, [%Tokenizers.Encoding{}, %Tokenizers.Encoding{}]} = 172 | Tokenizer.encode_batch(tokenizer, ["This is a test", "And so is this"]) 173 | end 174 | 175 | test "can encode a batch of strings and pairs", %{tokenizer: tokenizer} do 176 | assert {:ok, [%Tokenizers.Encoding{}, 
%Tokenizers.Encoding{}]} = 177 | Tokenizer.encode_batch(tokenizer, ["This is a test", {"Question?", "Answer"}]) 178 | end 179 | 180 | test "can apply transformations to batch of encodings", %{tokenizer: tokenizer} do 181 | assert {:ok, [%Tokenizers.Encoding{}, %Tokenizers.Encoding{}]} = 182 | Tokenizer.encode_batch(tokenizer, ["This is a test", "And so is this"], 183 | encoding_transformations: [ 184 | Encoding.Transformation.pad(2), 185 | Encoding.Transformation.truncate(4), 186 | Encoding.Transformation.set_sequence_id(1234) 187 | ] 188 | ) 189 | end 190 | 191 | test "can decode a single encoding", %{tokenizer: tokenizer} do 192 | text = "This is a test" 193 | {:ok, encoding} = Tokenizer.encode(tokenizer, text) 194 | ids = Encoding.get_ids(encoding) 195 | {:ok, decoded} = Tokenizer.decode(tokenizer, ids) 196 | assert decoded == text 197 | end 198 | 199 | test "can decode a single encoding skipping special characters", %{tokenizer: tokenizer} do 200 | seq = "This is a test" 201 | {:ok, encoding} = Tokenizer.encode(tokenizer, seq) 202 | ids = Encoding.get_ids(encoding) 203 | 204 | {:ok, seq_clean} = Tokenizer.decode(tokenizer, ids) 205 | {:ok, seq_special} = Tokenizer.decode(tokenizer, ids, skip_special_tokens: false) 206 | 207 | refute seq_special == seq 208 | assert seq_clean == seq 209 | end 210 | 211 | test "can decode a batch of encodings", %{tokenizer: tokenizer} do 212 | text = ["This is a test", "And so is this"] 213 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 214 | ids = Enum.map(encodings, &Encoding.get_ids/1) 215 | {:ok, decoded} = Tokenizer.decode_batch(tokenizer, ids) 216 | assert decoded == text 217 | 218 | assert Enum.map(ids, &list_to_u32/1) == Enum.map(encodings, &Encoding.get_u32_ids/1) 219 | end 220 | end 221 | 222 | describe "encode metadata" do 223 | test "can return attention mask", %{tokenizer: tokenizer} do 224 | text = ["Hello world", "Yes sir hello indeed"] 225 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 226 | 227 | attention_mask = Enum.map(encodings, &Encoding.get_attention_mask/1) 228 | assert [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1]] == attention_mask 229 | 230 | assert Enum.map(attention_mask, &list_to_u32/1) == 231 | Enum.map(encodings, &Encoding.get_u32_attention_mask/1) 232 | end 233 | 234 | test "can return type ids", %{tokenizer: tokenizer} do 235 | text = [{"Hello", "world"}, {"Yes sir", "hello indeed"}] 236 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 237 | 238 | type_ids = Enum.map(encodings, &Encoding.get_type_ids/1) 239 | assert [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 1, 1]] == type_ids 240 | 241 | assert Enum.map(type_ids, &list_to_u32/1) == 242 | Enum.map(encodings, &Encoding.get_u32_type_ids/1) 243 | end 244 | 245 | test "can return special tokens mask", %{tokenizer: tokenizer} do 246 | text = ["This is a test", "And so is this"] 247 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 248 | 249 | special_tokens_mask = Enum.map(encodings, &Encoding.get_special_tokens_mask/1) 250 | assert [[1, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 1]] == special_tokens_mask 251 | 252 | assert Enum.map(special_tokens_mask, &list_to_u32/1) == 253 | Enum.map(encodings, &Encoding.get_u32_special_tokens_mask/1) 254 | end 255 | 256 | test "can return offsets", %{tokenizer: tokenizer} do 257 | text = ["This is a test", "And so is this"] 258 | {:ok, encodings} = Tokenizer.encode_batch(tokenizer, text) 259 | offsets = Enum.map(encodings, &Encoding.get_offsets/1) 260 | 261 | assert [ 262 | [{0, 0}, {0, 4}, {5, 7}, {8, 9}, {10, 14}, {0, 
0}], 263 | [{0, 0}, {0, 3}, {4, 6}, {7, 9}, {10, 14}, {0, 0}] 264 | ] == offsets 265 | end 266 | end 267 | 268 | defp list_to_u32(list) do 269 | for x <- list, into: <<>>, do: <> 270 | end 271 | end 272 | -------------------------------------------------------------------------------- /test/tokenizers/trainer_test.exs: -------------------------------------------------------------------------------- 1 | defmodule Tokenizers.TrainerTest do 2 | use ExUnit.Case, async: true 3 | doctest Tokenizers.Trainer 4 | 5 | describe "BPE trainer" do 6 | test "successfully initializes with empty params" do 7 | assert {:ok, %Tokenizers.Trainer{}} = Tokenizers.Trainer.bpe() 8 | end 9 | 10 | test "successfully initializes with params" do 11 | assert {:ok, %Tokenizers.Trainer{} = trainer} = 12 | Tokenizers.Trainer.bpe( 13 | vocab_size: 1000, 14 | min_frequency: 2, 15 | special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 16 | limit_alphabet: 1000, 17 | initial_alphabet: [?a, ?b, ?c], 18 | show_progress: true, 19 | continuing_subword_prefix: "##", 20 | end_of_word_suffix: "##" 21 | ) 22 | 23 | assert %{ 24 | "continuing_subword_prefix" => "##", 25 | "end_of_word_suffix" => "##", 26 | "initial_alphabet" => 3, 27 | "limit_alphabet" => 1000, 28 | "min_frequency" => 2, 29 | "show_progress" => true, 30 | "special_tokens" => 5, 31 | "trainer_type" => "bpe", 32 | "vocab_size" => 1000 33 | } == Tokenizers.Trainer.info(trainer) 34 | end 35 | 36 | test "fails to initialize with invalid params" do 37 | assert {:error, _} = 38 | Tokenizers.Trainer.bpe( 39 | vocab_size: 1000, 40 | min_frequency: 2, 41 | special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 42 | limit_alphabet: 1000, 43 | initial_alphabet: [1_234_123_451, ?b, ?c], 44 | show_progress: true, 45 | continuing_subword_prefix: "##", 46 | end_of_word_suffix: "##" 47 | ) 48 | end 49 | 50 | test "accepts added tokens as special tokens" do 51 | assert {:ok, %Tokenizers.Trainer{}} = 52 | Tokenizers.Trainer.bpe( 53 | special_tokens: [ 54 | Tokenizers.AddedToken.new("[UNK]", special: true), 55 | Tokenizers.AddedToken.new("[CLS]", special: true), 56 | Tokenizers.AddedToken.new("[SEP]", special: true), 57 | Tokenizers.AddedToken.new("[PAD]", special: true), 58 | Tokenizers.AddedToken.new("[MASK]", special: true) 59 | ] 60 | ) 61 | end 62 | end 63 | 64 | describe "WordPiece trainer" do 65 | test "successfully initializes with empty params" do 66 | assert {:ok, %Tokenizers.Trainer{}} = Tokenizers.Trainer.wordpiece() 67 | end 68 | 69 | test "successfully initializes with params" do 70 | assert {:ok, %Tokenizers.Trainer{}} = 71 | Tokenizers.Trainer.wordpiece( 72 | vocab_size: 1000, 73 | min_frequency: 2, 74 | special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 75 | limit_alphabet: 1000, 76 | initial_alphabet: [?a, ?b, ?c], 77 | show_progress: true, 78 | continuing_subword_prefix: "##", 79 | end_of_word_suffix: "##" 80 | ) 81 | end 82 | end 83 | 84 | describe "WordLevel trainer" do 85 | test "successfully initializes with empty params" do 86 | assert {:ok, %Tokenizers.Trainer{}} = Tokenizers.Trainer.wordlevel() 87 | end 88 | 89 | test "successfully initializes with params" do 90 | assert {:ok, %Tokenizers.Trainer{}} = 91 | Tokenizers.Trainer.wordlevel( 92 | vocab_size: 1000, 93 | min_frequency: 2, 94 | special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 95 | show_progress: true 96 | ) 97 | end 98 | end 99 | 100 | describe "Unigram trainer" do 101 | test "successfully initializes with empty params" do 102 | assert {:ok, 
%Tokenizers.Trainer{}} = Tokenizers.Trainer.unigram() 103 | end 104 | 105 | test "successfully initializes with params" do 106 | assert {:ok, %Tokenizers.Trainer{}} = 107 | Tokenizers.Trainer.unigram( 108 | vocab_size: 1000, 109 | n_sub_iterations: 2, 110 | shrinking_factor: 0.75, 111 | special_tokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"], 112 | initial_alphabet: [?a, ?b, ?c], 113 | uni_token: "##", 114 | max_piece_length: 4, 115 | seed_size: 100, 116 | show_progress: true 117 | ) 118 | end 119 | end 120 | end 121 | --------------------------------------------------------------------------------