├── .github └── workflows │ ├── CI.yml │ └── test.yml ├── .gitignore ├── CITATION.cff ├── Cargo.toml ├── LICENSE ├── README.md ├── benchmarks ├── README.md ├── data │ └── c4.jsonl └── train.py ├── bpeasy ├── __init__.py ├── bpeasy.pyi ├── convert.py └── tokenizer.py ├── pyproject.toml ├── requirements.txt ├── src └── lib.rs └── tests ├── test_convert.py ├── test_tokenizer.py ├── test_train_bpe.py └── test_utils_bpe.py /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | test: 12 | name: Pre-test for release 13 | strategy: 14 | matrix: 15 | os: ["ubuntu"] 16 | python-version: ["3.10", "3.13"] 17 | runs-on: ${{ matrix.os }}-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | cache: "pip" 24 | - uses: dtolnay/rust-toolchain@stable 25 | - name: Setup virtual environment 26 | run: | 27 | python -m venv venv 28 | source venv/bin/activate 29 | pip install -r requirements.txt 30 | - name: Run tests 31 | run: | 32 | source venv/bin/activate 33 | cargo test 34 | maturin develop --release 35 | pytest tests 36 | 37 | linux: 38 | needs: [test] 39 | runs-on: ubuntu-latest 40 | strategy: 41 | matrix: 42 | target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] 43 | python-version: ["3.10", "3.13"] 44 | steps: 45 | - uses: actions/checkout@v4 46 | - uses: actions/setup-python@v5 47 | with: 48 | python-version: ${{ matrix.python-version }} 49 | - name: Build wheels 50 | uses: PyO3/maturin-action@v1 51 | with: 52 | target: ${{ matrix.target }} 53 | args: --release --out dist --find-interpreter 54 | sccache: "true" 55 | manylinux: auto 56 | - name: Upload wheels 57 | uses: actions/upload-artifact@v4 58 | with: 59 | name: dist-linux-${{ matrix.target }}-py${{ matrix.python-version }} 60 | path: dist 61 | 62 | windows: 63 | needs: [test] 64 | runs-on: windows-latest 65 | strategy: 66 | matrix: 67 | target: [x64, x86] 68 | python-version: ["3.10", "3.13"] 69 | steps: 70 | - uses: actions/checkout@v4 71 | - uses: actions/setup-python@v5 72 | with: 73 | python-version: ${{ matrix.python-version }} 74 | architecture: ${{ matrix.target }} 75 | - name: Build wheels 76 | uses: PyO3/maturin-action@v1 77 | with: 78 | target: ${{ matrix.target }} 79 | args: --release --out dist --find-interpreter 80 | sccache: "true" 81 | - name: Upload wheels 82 | uses: actions/upload-artifact@v4 83 | with: 84 | name: dist-windows-${{ matrix.target }}-py${{ matrix.python-version }} 85 | path: dist 86 | 87 | macos: 88 | needs: [test] 89 | runs-on: macos-latest 90 | strategy: 91 | matrix: 92 | target: [x86_64, aarch64] 93 | python-version: ["3.10", "3.13"] 94 | steps: 95 | - uses: actions/checkout@v4 96 | - uses: actions/setup-python@v5 97 | with: 98 | python-version: ${{ matrix.python-version }} 99 | - name: Build wheels 100 | uses: PyO3/maturin-action@v1 101 | with: 102 | target: ${{ matrix.target }} 103 | args: --release --out dist --find-interpreter 104 | sccache: "true" 105 | - name: Upload wheels 106 | uses: actions/upload-artifact@v4 107 | with: 108 | name: dist-macos-${{ matrix.target }}-py${{ matrix.python-version }} 109 | path: dist 110 | 111 | sdist: 112 | needs: [test] 113 | runs-on: ubuntu-latest 114 | steps: 115 | - uses: actions/checkout@v4 116 | - name: Build sdist 117 | uses: PyO3/maturin-action@v1 118 | with: 119 | command: sdist 120 | args: --out 
dist 121 | - name: Upload sdist 122 | uses: actions/upload-artifact@v4 123 | with: 124 | name: dist-sdist 125 | path: dist 126 | 127 | release: 128 | name: Release 129 | runs-on: ubuntu-latest 130 | needs: [linux, windows, macos, sdist] 131 | if: "startsWith(github.ref, 'refs/tags/')" 132 | permissions: 133 | id-token: write 134 | steps: 135 | - uses: actions/download-artifact@v4 136 | with: 137 | pattern: dist-* 138 | merge-multiple: true 139 | path: dist 140 | - name: Publish to PyPI 141 | uses: PyO3/maturin-action@v1 142 | with: 143 | command: upload 144 | args: --non-interactive --skip-existing dist/* 145 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | coverage: 15 | name: Coverage for ${{ matrix.os }} with Python ${{ matrix.python-version }} 16 | strategy: 17 | matrix: 18 | os: ["ubuntu"] 19 | python-version: ["3.10", "3.13"] 20 | runs-on: ${{ matrix.os }}-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | cache: "pip" 27 | - uses: dtolnay/rust-toolchain@stable 28 | - name: Install cargo-llvm-cov 29 | uses: taiki-e/install-action@cargo-llvm-cov 30 | - uses: Swatinem/rust-cache@v2.7.0 31 | with: 32 | key: coverage-cargo-${{ matrix.os }}-py${{ matrix.python-version }} 33 | continue-on-error: true 34 | - name: Setup virtual environment 35 | run: | 36 | python -m venv venv 37 | source venv/bin/activate 38 | pip install -r requirements.txt 39 | - name: Run coverage 40 | run: | 41 | source venv/bin/activate 42 | source <(cargo llvm-cov show-env --export-prefix) 43 | export CARGO_TARGET_DIR=$CARGO_LLVM_COV_TARGET_DIR 44 | export CARGO_INCREMENTAL=1 45 | cargo llvm-cov clean --workspace 46 | cargo test 47 | maturin develop 48 | pytest tests --cov=bpeasy --cov-report xml 49 | cargo llvm-cov report --lcov --output-path coverage.lcov 50 | - uses: codecov/codecov-action@v3 51 | env: 52 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 53 | with: 54 | files: coverage.lcov,coverage.xml 55 | name: ${{ matrix.os }}-py${{ matrix.python-version }} 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | out/ 2 | .nox/ 3 | .benchmarks/ 4 | 5 | # Generated by Cargo 6 | # will have compiled files and executables 7 | debug/ 8 | target/ 9 | 10 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 11 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 12 | Cargo.lock 13 | 14 | # These are backup files generated by rustfmt 15 | **/*.rs.bk 16 | 17 | # MSVC Windows builds of rustc generate these, which store debugging information 18 | *.pdb 19 | 20 | # Added by cargo 21 | /target 22 | 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | .pytest_cache/ 27 | *.py[cod] 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | .venv/ 35 | env/ 36 | bin/ 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | include/ 47 | man/ 48 | venv/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | 53 | # Installer 
logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | pip-selfcheck.json 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .coverage 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | 66 | # Translations 67 | *.mo 68 | 69 | # Mr Developer 70 | .mr.developer.cfg 71 | .project 72 | .pydevproject 73 | 74 | # Rope 75 | .ropeproject 76 | 77 | # Django stuff: 78 | *.log 79 | *.pot 80 | 81 | .DS_Store 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyCharm 87 | .idea/ 88 | 89 | # VSCode 90 | .vscode/ 91 | 92 | # Pyenv 93 | .python-version 94 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: bpeasy 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 6 | type: software 7 | authors: 8 | - given-names: Gautier 9 | family-names: Dagan 10 | email: gautier.dagan@ed.ac.uk 11 | affiliation: University of Edinburgh 12 | orcid: 'https://orcid.org/0000-0002-1867-4201' 13 | repository-code: 'https://github.com/gautierdag/bpeasy' 14 | url: 'https://github.com/gautierdag/bpeasy' 15 | license: MIT 16 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bpeasy" 3 | version = "0.1.5" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "bpeasy" 8 | crate-type = ["cdylib"] 9 | 10 | [profile.release] 11 | opt-level = 3 12 | lto = "fat" 13 | codegen-units = 1 14 | 15 | [dependencies] 16 | fancy-regex = "0.12.0" 17 | fxhash = "0.2.1" 18 | pyo3 = { version = "0.19.0", features = ["extension-module"] } 19 | rayon = "1.8.0" 20 | regex = "1.5.4" 21 | serde_json = "1.0.108" 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gautier Dagan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# bpeasy

[![codecov](https://codecov.io/gh/gautierdag/bpeasy/branch/main/graph/badge.svg?token=NWHDJ22L8I)](https://codecov.io/gh/gautierdag/bpeasy) [![tests](https://github.com/gautierdag/bpeasy/actions/workflows/test.yml/badge.svg)](https://github.com/gautierdag/bpeasy/actions/workflows/test.yml) [![image](https://img.shields.io/pypi/l/bpeasy.svg)](https://pypi.python.org/pypi/bpeasy) [![image](https://img.shields.io/pypi/pyversions/bpeasy.svg)](https://pypi.python.org/pypi/bpeasy) [![PyPI version](https://badge.fury.io/py/bpeasy.svg)](https://badge.fury.io/py/bpeasy)

## Overview

`bpeasy` is a Python package that provides a tokenizer trainer, implementing an efficient version of Byte Pair Encoding (BPE) in roughly 400 lines of Rust. The implementation largely follows the Huggingface `tokenizers` library, but makes opinionated decisions to simplify tokenizer training, specifically to:

1. Treat text data at the byte-level first --- all text is converted to bytes before training rather than using a character-level approach (like in Huggingface).
2. Always use a regex-based split pre-tokenizer. This is a customisable regex that is applied to the text before training. This regex decides where to split the text and limits what kind of tokens are possible. This is technically possible in Huggingface but is not well documented. We also use the `fancy-regex` crate, which supports a richer set of regex features than the `regex` crate used in Huggingface.
3. Use `int64` types for counting to allow for training on much larger datasets without the risk of overflow.

**You can think of `bpeasy` as the `tiktoken` training code that never was.**

See the [benchmarks](/benchmarks/README.md) section for a comparison with the Huggingface library.

## Installation

Simply install the package using pip:

```bash
pip install bpeasy
```

## Training

The training function is designed to be bare-bones and returns the trained tokenizer vocab as a dictionary of bytes to integers. This allows for maximum flexibility in how you want to use the tokenizer. For example, you can then port the trained vocabulary to `tiktoken` or Huggingface tokenizers (see below).
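
The `iterator` argument can be any Python iterator or generator that yields raw `str` chunks. As a minimal sketch (the file name and the `"text"` JSON field here are illustrative assumptions, mirroring the helper used in `benchmarks/train.py`), such an iterator could look like this:

```python
import json


def jsonl_text_iterator(path: str = "data.jsonl"):
    # yield the "text" field of each line in a JSONL file as a plain str
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)["text"]
```

The example below assumes a comparable helper, `jsonl_content_iterator(args)`, that yields strings from the dataset configured in `args`.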

```python
# should be an iterator over str
iterator = jsonl_content_iterator(args)
# example regex from GPT-4
regex_pattern = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# returns the vocab (dict[bytes, int])
vocab = bpeasy.train_bpe(
    iterator,
    regex_pattern,
    args.max_sentencepiece_length,  # max length of tokens
    args.vocab_size,  # max size of vocab
)
```

Alternatively, you can train using the basic tokenizer class provided:

```python
from bpeasy.tokenizer import BPEasyTokenizer

tokenizer = BPEasyTokenizer.train(
    iterator,  # iterator over str
    vocab_size=vocab_size,
    max_token_length=max_token_length,
    regex_pattern=regex_pattern,
    special_tokens=["", "", ""],
    fill_to_nearest_multiple_of_eight=True,
    name="bpeasy",
)
```

### Encoding/Decoding

To test your tokenizer, you can use the `BPEasyTokenizer` class, which is a wrapper around the `tiktoken.Encoding` class, simplifying the handling of vocabularies, special tokens, and regex patterns for tokenization.

```python
from bpeasy.tokenizer import BPEasyTokenizer

your_special_tokens = ["", "", ""]

tokenizer = BPEasyTokenizer(
    vocab=vocab,
    regex_pattern=regex_pattern,
    special_tokens=your_special_tokens,
    fill_to_nearest_multiple_of_eight=True,  # pad vocab to multiple of 8
    name="bpeasy"  # optional name for the tokenizer
)

test = "hello_world"

# encode and decode use the tiktoken functions
encoded = tokenizer.encode(test)
decoded = tokenizer.decode(encoded)
# > "hello_world"
```

You can also use `tiktoken` directly, but you would need to handle the special tokens and regex pattern yourself:

```python
import bpeasy
import tiktoken

vocab = bpeasy.train_bpe(...)
special_tokens = ["", "", ""]

# Sort the vocab by rank
sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1])

# add special tokens
special_token_ranks = {}
for special_token in special_tokens:
    special_token_ranks[special_token] = len(sorted_vocab)
    sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab)))

full_vocab = dict(sorted_vocab)

encoder = tiktoken.Encoding(
    name="bpeasy",
    pat_str=regex_pattern,
    mergeable_ranks=full_vocab,
    special_tokens=special_token_ranks,
)
```

### Save/Load tokenizer from file

We provide basic utility functions to save and load the tokenizer from a JSON file.

```python
tokenizer.save("path_to_file.json")

tokenizer = BPEasyTokenizer.from_file("path_to_file.json")
```

### Export to HuggingFace format

We also support exporting the tokenizer to the HuggingFace format, which can then be used directly with the HuggingFace `transformers` library.

```python
from bpeasy.tokenizer import BPEasyTokenizer

tokenizer = BPEasyTokenizer(
    ...
)

tokenizer.export_to_huggingface_format("hf_tokenizer.json")

from transformers import PreTrainedTokenizerFast

hf_tokenizer = PreTrainedTokenizerFast(tokenizer_file="hf_tokenizer.json")
```

### Export vocab to `tiktoken` txt format

```python
from bpeasy import train_bpe, save_vocab_to_tiktoken

vocab = train_bpe(...)

# saves the vocab to a tiktoken txt file format
save_vocab_to_tiktoken(vocab, "vocab.txt", special_tokens=["", "", ""])
```

If you want to use the `tiktoken` txt format, you will still need to handle the regex and special tokens yourself, as shown above.

## Contributing

Contributions are welcome! Please open an issue if you have any suggestions or improvements.

## License

This project is licensed under the MIT License.

## Citation

If you use `bpeasy` in your research, please cite it as follows:

```bibtex
@software{bpeasy,
  author = {Gautier Dagan},
  title = {bpeasy},
  year = {2024},
  url = {https://github.com/gautierdag/bpeasy},
  repository = {https://github.com/gautierdag/bpeasy},
  author-email = {gautier.dagan@ed.ac.uk},
  affiliation = {University of Edinburgh},
  orcid = {https://orcid.org/0000-0002-1867-4201}
}
```

--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
# Benchmarks on the c4 dataset

Using varying vocab sizes from 5k to 100k:

| Library/Operation | Time (seconds) | Standard Deviation |
|----------------------------|---------------------------------|--------------------------------|
| HuggingFace Train | 0.8165 | ±0.62 |
| `bpeasy` Train | 0.68815 | ±0.41 |
| HuggingFace Encode | 0.6247 | ±0.051 |
| `bpeasy` Encode (uses `tiktoken`) | 0.2679 | ±0.035 |

| | Bytes per Token (normalised against HF) | Standard Deviation |
|----------------------------|---------------------------------|--------------------------------|
| `bpeasy` | 1.0008992687171223 | ±5.542696043278318e-05 |

We can see that `bpeasy` is faster than HuggingFace for both training and encoding. The training gap is not massive and will depend heavily on the dataset and compute, but the two are comparable.

We also gain a tiny bit of compression (more bytes per token) because `bpeasy` works at the byte level and is slightly more efficient in its allocation of basic tokens.
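
For clarity, the bytes-per-token numbers above are computed by encoding the same corpus with each tokenizer and dividing the total UTF-8 byte count by the total number of tokens produced (higher means better compression). A minimal sketch of that metric, assuming a `tokenizer` object with an `encode(text) -> list[int]` method and an iterable of strings (mirroring the `encode` helper in `benchmarks/train.py`):

```python
from typing import Iterable


def bytes_per_token(tokenizer, texts: Iterable[str]) -> float:
    # total UTF-8 bytes divided by total tokens produced over the corpus
    num_bytes = 0
    num_tokens = 0
    for text in texts:
        num_bytes += len(text.encode("utf-8"))
        num_tokens += len(tokenizer.encode(text))
    return num_bytes / num_tokens
```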
19 | 20 | ## Reproducing the benchmarks 21 | 22 | ```bash 23 | pip install tokenizers 24 | pip install bpeasy 25 | 26 | python benchmarks/train.py 27 | ``` 28 | -------------------------------------------------------------------------------- /benchmarks/train.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import glob 3 | import json 4 | import logging 5 | import sys 6 | import time 7 | from pathlib import Path 8 | 9 | import tokenizers 10 | from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers 11 | from tokenizers.models import BPE 12 | from tokenizers.trainers import BpeTrainer 13 | from tqdm import tqdm 14 | 15 | import bpeasy 16 | from bpeasy.tokenizer import BPEasyTokenizer 17 | 18 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 19 | 20 | 21 | @dataclasses.dataclass 22 | class TrainBPETokenizerArgs: 23 | dataset: str = "./benchmarks/data" 24 | vocab_size: int = 32_000 25 | max_sentencepiece_length: int = 128 26 | regex_pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 27 | 28 | def __post_init__(self): 29 | checkpoint_dir = Path(self.dataset) 30 | assert checkpoint_dir.is_dir(), checkpoint_dir 31 | 32 | 33 | def jsonl_content_iterator( 34 | args: TrainBPETokenizerArgs, 35 | ): 36 | """ 37 | Iterates over a jsonl file and yields the content of each line 38 | Tracks the number of characters yielded and stops when the limit is reached 39 | This is ripe for optimisation if you want to mess with more fine-grained 40 | character limits (eg. more Python than Java) 41 | """ 42 | file_path = args.dataset 43 | chunk_num, character_count = 0, 0 44 | chunks = glob.glob(f"{file_path}/*.jsonl") 45 | 46 | while chunk_num < len(chunks): 47 | file_name = chunks[chunk_num] 48 | with open(file_name, "r", encoding="utf-8") as f: 49 | for line in f: 50 | obj = json.loads(line) 51 | text = obj["text"] 52 | text_character_count = len(text) 53 | character_count += text_character_count 54 | yield text 55 | chunk_num += 1 56 | 57 | 58 | def train_huggingface(args: TrainBPETokenizerArgs): 59 | # should be at least 0.14.0 to train with char limit 60 | assert tokenizers.__version__ >= "0.14.0" 61 | tokenizer = Tokenizer(BPE(byte_fallback=True)) 62 | trainer = BpeTrainer( 63 | vocab_size=args.vocab_size, 64 | special_tokens=[f"<0x{i:02X}>" for i in range(256)], # seed sm vocab 65 | max_token_length=args.max_sentencepiece_length, 66 | show_progress=False, 67 | ) 68 | gpt_regex = Regex(args.regex_pattern) 69 | 70 | split_pre_tokenizer = pre_tokenizers.Split( 71 | gpt_regex, behavior="isolated", invert=False 72 | ) 73 | byte_pre_tokenizer = pre_tokenizers.ByteLevel( 74 | add_prefix_space=False, use_regex=False 75 | ) 76 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence( 77 | [split_pre_tokenizer, byte_pre_tokenizer] 78 | ) 79 | # Use ByteLevel Decoder 80 | tokenizer.decoder = decoders.Sequence( 81 | [decoders.ByteLevel(), decoders.ByteFallback()] 82 | ) 83 | iterator = jsonl_content_iterator(args) 84 | # training the tokenizer 85 | tokenizer.train_from_iterator(iterator, trainer) 86 | 87 | return tokenizer 88 | 89 | 90 | def train_bpeasy(args: TrainBPETokenizerArgs): 91 | # Use ByteLevel Decoder 92 | iterator = jsonl_content_iterator(args) 93 | # training the tokenizer 94 | vocab = bpeasy.train_bpe( 95 | iterator, 96 | args.regex_pattern, 97 | args.max_sentencepiece_length, 98 | args.vocab_size, 99 | ) 100 | 101 | return BPEasyTokenizer( 
        vocab,
        args.regex_pattern,
        special_tokens=[],
        fill_to_nearest_multiple_of_eight=False,
    )


def encode(tokenizer, args) -> float:
    iterator = jsonl_content_iterator(args)
    lengths = []
    num_bytes = 0
    for text in iterator:
        num_bytes += len(text.encode("utf-8"))
        encoded = tokenizer.encode(text)
        lengths.append(len(encoded))
    return num_bytes / sum(lengths)


def get_mean_std_dev(times: list[float]) -> tuple[float, float]:
    avg_time = sum(times) / len(times)
    # population standard deviation (not just the sum of squared deviations)
    std_dev = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5
    return avg_time, std_dev


if __name__ == "__main__":
    args = TrainBPETokenizerArgs()

    times_train_huggingface = []
    times_encode_huggingface = []
    times_train_bpeasy = []
    times_encode_bpeasy = []
    byte_per_token_bpeasy = []

    for v in tqdm(range(5000, 100_000, 5000)):
        args.vocab_size = v

        time_now = time.time()
        tokenizer = train_huggingface(args)
        times_train_huggingface.append(time.time() - time_now)

        time_now = time.time()
        byte_per_token_hf = encode(tokenizer, args)
        times_encode_huggingface.append(time.time() - time_now)

        time_now = time.time()
        tokenizer = train_bpeasy(args)
        times_train_bpeasy.append(time.time() - time_now)

        time_now = time.time()
        byte_per_token_bpeasy.append(encode(tokenizer, args) / byte_per_token_hf)
        times_encode_bpeasy.append(time.time() - time_now)

    m_hf, std_hf = get_mean_std_dev(times_train_huggingface)
    m_bpeasy, std_bpeasy = get_mean_std_dev(times_train_bpeasy)

    print(f"huggingface train time {m_hf} +/- {std_hf}")
    print(f"bpeasy train time {m_bpeasy} +/- {std_bpeasy}")

    m_hf, std_hf = get_mean_std_dev(times_encode_huggingface)
    m_bpeasy, std_bpeasy = get_mean_std_dev(times_encode_bpeasy)

    print(f"huggingface encode time {m_hf} +/- {std_hf}")
    print(f"bpeasy encode time {m_bpeasy} +/- {std_bpeasy}")

    m_bpeasy, std_bpeasy = get_mean_std_dev(byte_per_token_bpeasy)
    print(f"bpeasy bytes/token vs hf: {m_bpeasy} +/- {std_bpeasy}")

--------------------------------------------------------------------------------
/bpeasy/__init__.py:
--------------------------------------------------------------------------------
from importlib.metadata import version

from .bpeasy import train_bpe

__version__ = version("bpeasy")


__all__ = [
    "save_vocab_to_tiktoken",
    "train_bpe",
    "__version__",
]


def save_vocab_to_tiktoken(
    vocab: dict[bytes, int],
    out_path: str,
    special_tokens: list[str] = [],
    fill_to_nearest_multiple_of_eight: bool = False,
) -> None:
    """
    Export vocab to tiktoken txt format - use this if you want to use tiktoken library directly
    Note: you will need to handle special tokens and regex yourself
    """
    import base64

    sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1])
    for special_token in special_tokens:
        sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab)))

    if fill_to_nearest_multiple_of_eight:
        while len(sorted_vocab) % 8 != 0:
            sorted_vocab.append(
                (f"<|special-{len(sorted_vocab)}|>".encode("utf-8"), len(sorted_vocab))
            )

    with open(out_path, "wb") as f:
        for token, rank in sorted_vocab:
            # encode token to
base64 and write to file with rank separated by a space 40 | f.write(base64.b64encode(token) + b" " + str(rank).encode("utf-8") + b"\n") 41 | -------------------------------------------------------------------------------- /bpeasy/bpeasy.pyi: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | def train_bpe( 4 | iterator: Iterator[str], 5 | python_regex: str, 6 | max_token_length: int, 7 | vocab_size: int, 8 | ) -> dict[bytes, int]: ... 9 | -------------------------------------------------------------------------------- /bpeasy/convert.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | 3 | from typing import Optional 4 | from functools import lru_cache 5 | import json 6 | 7 | 8 | # Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960 9 | def bpe( 10 | mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None 11 | ) -> list[bytes]: 12 | parts = [bytes([b]) for b in token] 13 | while True: 14 | min_idx = None 15 | min_rank = None 16 | for i, pair in enumerate(zip(parts[:-1], parts[1:])): 17 | rank = mergeable_ranks.get(pair[0] + pair[1]) 18 | if rank is not None and (min_rank is None or rank < min_rank): 19 | min_idx = i 20 | min_rank = rank 21 | if min_rank is None or (max_rank is not None and min_rank >= max_rank): 22 | break 23 | assert min_idx is not None 24 | parts = ( 25 | parts[:min_idx] 26 | + [parts[min_idx] + parts[min_idx + 1]] 27 | + parts[min_idx + 2 :] 28 | ) 29 | return parts 30 | 31 | 32 | # Source taken from https://github.com/huggingface/transformers/blob/73de5108e172112bc620cfc0ceebfd27730dba11/src/transformers/models/gpt2/tokenization_gpt2.py#L63 33 | @lru_cache() 34 | def bytes_to_unicode(): 35 | """ 36 | Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control 37 | characters the bpe code barfs on. 38 | 39 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab 40 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for 41 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup 42 | tables between utf-8 bytes and unicode strings. 
43 | """ 44 | bs = ( 45 | list(range(ord("!"), ord("~") + 1)) 46 | + list(range(ord("¡"), ord("¬") + 1)) 47 | + list(range(ord("®"), ord("ÿ") + 1)) 48 | ) 49 | cs = bs[:] 50 | n = 0 51 | for b in range(2**8): 52 | if b not in bs: 53 | bs.append(b) 54 | cs.append(2**8 + n) 55 | n += 1 56 | cs = [chr(n) for n in cs] 57 | return dict(zip(bs, cs)) 58 | 59 | 60 | def convert_tiktoken_to_huggingface( 61 | encoder: tiktoken.Encoding, 62 | out_path: str, 63 | regex_pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+""", 64 | ): 65 | byte_encoder = bytes_to_unicode() 66 | 67 | def token_bytes_to_string(b): 68 | return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) 69 | 70 | def generate_vocab_and_merges(encoder): 71 | mergeable_ranks = encoder._mergeable_ranks 72 | 73 | merges = [] 74 | vocab = {} 75 | i = 0 76 | for token, rank in mergeable_ranks.items(): 77 | # Skip special tokens as they are added separately 78 | if token_bytes_to_string(token) in encoder._special_tokens: 79 | continue 80 | 81 | vocab[token_bytes_to_string(token)] = rank 82 | 83 | i += 1 84 | if len(token) == 1: 85 | continue 86 | merged = tuple(bpe(mergeable_ranks, token, max_rank=rank)) 87 | assert len(merged) == 2 88 | merges.append(" ".join(map(token_bytes_to_string, merged))) 89 | 90 | # Also add special tokens 91 | vocab.update(encoder._special_tokens) 92 | 93 | return vocab, merges 94 | 95 | vocab, merges = generate_vocab_and_merges(encoder) 96 | 97 | added_tokens = [ 98 | { 99 | "id": id, 100 | "content": content, 101 | "single_word": False, 102 | "lstrip": False, 103 | "rstrip": False, 104 | "normalized": False, 105 | "special": True, 106 | } 107 | for content, id in encoder._special_tokens.items() 108 | ] 109 | 110 | tokenizer_template = { 111 | "version": "1.0", 112 | "truncation": None, 113 | "padding": None, 114 | "added_tokens": added_tokens, 115 | "normalizer": None, 116 | "pre_tokenizer": { 117 | "type": "Sequence", 118 | "pretokenizers": [ 119 | { 120 | "type": "Split", 121 | "pattern": {"Regex": regex_pattern}, 122 | "behavior": "Removed", 123 | "invert": True, 124 | }, 125 | { 126 | "type": "ByteLevel", 127 | "add_prefix_space": False, 128 | "trim_offsets": True, 129 | "use_regex": False, 130 | }, 131 | ], 132 | }, 133 | "post_processor": None, 134 | "decoder": { 135 | "type": "ByteLevel", 136 | "add_prefix_space": True, 137 | "trim_offsets": True, 138 | "use_regex": True, 139 | }, 140 | "model": { 141 | "type": "BPE", 142 | "dropout": None, 143 | "unk_token": None, 144 | "continuing_subword_prefix": "", 145 | "end_of_word_suffix": "", 146 | "fuse_unk": False, 147 | "byte_fallback": False, 148 | "vocab": vocab, 149 | "merges": merges, 150 | }, 151 | } 152 | 153 | with open( 154 | out_path, 155 | "w", 156 | encoding="utf-8", 157 | ) as fp: 158 | json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False) 159 | -------------------------------------------------------------------------------- /bpeasy/tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import base64 3 | from typing import Iterator 4 | 5 | import tiktoken 6 | 7 | from .bpeasy import train_bpe 8 | from .convert import convert_tiktoken_to_huggingface 9 | 10 | 11 | _DEFAULT_REGEX_PATTERN = r"""[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 12 | 13 | 14 | class BPEasyTokenizer: 15 | def __init__( 16 | self, 17 | vocab: 
dict[bytes, int], 18 | regex_pattern: str = _DEFAULT_REGEX_PATTERN, 19 | special_tokens: list[str] = [], 20 | fill_to_nearest_multiple_of_eight=False, 21 | name="bpeasy", 22 | ): 23 | """ 24 | Wrapper around tiktoken.Encoding 25 | Handles the loading/saving of vocab/special_tokens/regex 26 | """ 27 | 28 | self.name = name 29 | self.regex_pattern = regex_pattern 30 | self.special_tokens = special_tokens 31 | self.vocab = vocab 32 | 33 | # Sort the vocab by rank 34 | sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1]) 35 | 36 | # add special tokens 37 | special_token_ranks = {} 38 | for special_token in special_tokens: 39 | special_token_ranks[special_token] = len(sorted_vocab) 40 | sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab))) 41 | 42 | full_vocab = dict(sorted_vocab) 43 | 44 | # fill to nearest multiple of 8 45 | if fill_to_nearest_multiple_of_eight: 46 | while len(sorted_vocab) % 8 != 0: 47 | sorted_vocab.append( 48 | ( 49 | f"<|special-{len(sorted_vocab)}|>".encode("utf-8"), 50 | len(sorted_vocab), 51 | ) 52 | ) 53 | 54 | self._encoder = tiktoken.Encoding( 55 | name=name, 56 | pat_str=self.regex_pattern, 57 | mergeable_ranks=full_vocab, 58 | special_tokens=special_token_ranks, 59 | ) 60 | 61 | def encode(self, text: str, **kwargs) -> list[int]: 62 | return self._encoder.encode(text, **kwargs) 63 | 64 | def decode(self, tokens: list[int], **kwargs) -> str: 65 | return self._encoder.decode(tokens, **kwargs) 66 | 67 | @classmethod 68 | def from_file(cls, file_path: str) -> "BPEasyTokenizer": 69 | with open(file_path, "r") as file: 70 | data = json.load(file) 71 | bytes_vocab = { 72 | base64.b64decode(key): value for key, value in data["vocab"].items() 73 | } 74 | instance = cls( 75 | name=data["name"], 76 | vocab=bytes_vocab, 77 | regex_pattern=data["regex_pattern"], 78 | special_tokens=data["special_tokens"], 79 | ) 80 | return instance 81 | 82 | def save(self, file_path: str) -> None: 83 | with open(file_path, "w") as file: 84 | json.dump( 85 | { 86 | "name": self.name, 87 | "regex_pattern": self.regex_pattern, 88 | "special_tokens": self.special_tokens, 89 | "vocab": { 90 | base64.b64encode(key).decode("utf-8"): value 91 | for key, value in self.vocab.items() 92 | }, 93 | }, 94 | file, 95 | ) 96 | 97 | def export_to_huggingface_format(self, out_path: str) -> None: 98 | convert_tiktoken_to_huggingface(self._encoder, out_path, self.regex_pattern) 99 | 100 | def __len__(self) -> int: 101 | return len(self.vocab) 102 | 103 | @classmethod 104 | def train( 105 | cls, 106 | iterator: Iterator[str], 107 | vocab_size: int = 32_000, 108 | max_token_length=128, 109 | regex_pattern: str = _DEFAULT_REGEX_PATTERN, 110 | special_tokens: list[str] = [], 111 | fill_to_nearest_multiple_of_eight=False, 112 | name="bpeasy", 113 | ) -> "BPEasyTokenizer": 114 | bytes_vocab = train_bpe(iterator, regex_pattern, max_token_length, vocab_size) 115 | return cls( 116 | name=name, 117 | vocab=bytes_vocab, 118 | regex_pattern=regex_pattern, 119 | special_tokens=special_tokens, 120 | fill_to_nearest_multiple_of_eight=fill_to_nearest_multiple_of_eight, 121 | ) 122 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.3,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "bpeasy" 7 | requires-python = ">=3.8" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python 
:: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | "Programming Language :: Python :: 3.8", 13 | "Programming Language :: Python :: 3.9", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: 3.11", 16 | "Programming Language :: Python :: 3.12", 17 | "Programming Language :: Python :: 3.13", 18 | "License :: OSI Approved :: MIT License", 19 | ] 20 | dynamic = ["version"] 21 | description = "Fast bare-bones BPE for modern tokenizer training" 22 | authors = [{name = "Gautier Dagan", email = ""}] 23 | license = "MIT" 24 | readme = "README.md" 25 | homepage = "https://github.com/gautierdag/bpeasy" 26 | repository = "https://github.com/gautierdag/bpeasy" 27 | include = [ 28 | "LICENSE", 29 | ] 30 | keywords = ["tokenizer", "tokenization", "bpe"] 31 | dependencies = [ 32 | "tiktoken>=0.4.0", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | dev = ["pytest", "pytest-cov", "black", "tokenizers", "tqdm"] 37 | 38 | [tool.maturin] 39 | features = ["pyo3/extension-module"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest>=7.1.2 2 | pytest-cov>=3.0.0 3 | maturin>=0.12.14 4 | tiktoken -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use fancy_regex::Regex; 2 | use fxhash::FxHashMap as HashMap; 3 | use fxhash::FxHashSet as HashSet; 4 | use pyo3::exceptions; 5 | use pyo3::prelude::*; 6 | use pyo3::types::{PyBytes, PyDict, PyIterator, PyString}; 7 | use rayon::prelude::*; 8 | use std::cmp::Ordering; 9 | use std::collections::BinaryHeap; 10 | 11 | type Pair = (u32, u32); 12 | 13 | #[derive(Debug, Eq)] 14 | struct Merge { 15 | pair: Pair, 16 | count: i64, 17 | pos: HashSet, 18 | } 19 | impl PartialEq for Merge { 20 | fn eq(&self, other: &Self) -> bool { 21 | self.count == other.count && self.pair == other.pair 22 | } 23 | } 24 | impl PartialOrd for Merge { 25 | fn partial_cmp(&self, other: &Self) -> Option { 26 | Some(self.cmp(other)) 27 | } 28 | } 29 | impl Ord for Merge { 30 | fn cmp(&self, other: &Self) -> Ordering { 31 | if self.count != other.count { 32 | self.count.cmp(&other.count) 33 | } else { 34 | // Here we want ascending order 35 | other.pair.cmp(&self.pair) 36 | } 37 | } 38 | } 39 | 40 | #[derive(Debug, Clone, Copy)] 41 | struct Symbol { 42 | c: u32, 43 | prev: isize, 44 | next: isize, 45 | len: usize, 46 | } 47 | 48 | #[derive(Debug)] 49 | struct Sentence { 50 | symbols: Vec, 51 | } 52 | 53 | impl Sentence { 54 | fn new() -> Self { 55 | Sentence { symbols: vec![] } 56 | } 57 | 58 | fn add(&mut self, c: u32, byte_len: usize) { 59 | let (prev, next) = { 60 | let len: isize = self.symbols.len() as isize; 61 | if let Some(last) = self.symbols.last_mut() { 62 | // Update `next` on the previous one 63 | last.next = len; 64 | (len - 1, -1) 65 | } else { 66 | (-1, -1) 67 | } 68 | }; 69 | self.symbols.push(Symbol { 70 | c, 71 | prev, 72 | next, 73 | len: byte_len, 74 | }); 75 | } 76 | 77 | fn merge(&mut self, c1: u32, c2: u32, replacement: u32, max_length: usize) -> Vec<(Pair, i64)> { 78 | let mut changes: Vec<(Pair, i64)> = vec![]; 79 | let mut i = 0; 80 | loop { 81 | if i >= self.symbols.len() { 82 | break; 83 | } 84 | 85 | // Found a pair 86 | if self.symbols[i].c == c1 && i + 1 < self.symbols.len() && self.symbols[i + 1].c == c2 87 | { 88 | let 
first = self.symbols[i]; 89 | let second = self.symbols[i + 1]; 90 | 91 | // Remove in place 92 | let new_s = Symbol { 93 | c: replacement, 94 | prev: first.prev, 95 | next: second.next, 96 | len: first.len + second.len, 97 | }; 98 | 99 | // If there are other characters before the pair 100 | if i > 0 { 101 | changes.push(((self.symbols[i - 1].c, first.c), -1)); 102 | if self.symbols[i - 1].len + new_s.len < max_length { 103 | changes.push(((self.symbols[i - 1].c, replacement), 1)); 104 | } 105 | } 106 | 107 | self.symbols.insert(i, new_s); // Insert replacement before first char of pair 108 | self.symbols.remove(i + 1); // Remove first char of pair 109 | self.symbols.remove(i + 1); // And then the second 110 | 111 | // If there are other characters after the pair 112 | if i < self.symbols.len() - 1 { 113 | changes.push(((second.c, self.symbols[i + 1].c), -1)); 114 | if self.symbols[i + 1].len + new_s.len < max_length { 115 | changes.push(((replacement, self.symbols[i + 1].c), 1)); 116 | } 117 | } 118 | } 119 | i += 1; 120 | } 121 | changes 122 | } 123 | 124 | fn get_symbols(&self) -> Vec { 125 | self.symbols.iter().map(|s| s.c).collect() 126 | } 127 | 128 | fn from_str(s: &str) -> Self { 129 | let mut sentence = Sentence::new(); 130 | for byte in s.bytes() { 131 | sentence.add(byte as u32, 1); 132 | } 133 | sentence 134 | } 135 | } 136 | 137 | fn pretokenize<'a>(text: &'a str, regex: &Regex) -> Vec<&'a str> { 138 | regex 139 | .find_iter(text) 140 | .filter_map(|mat| match mat { 141 | Ok(m) => Some(m.as_str()), 142 | Err(_) => None, 143 | }) 144 | .collect() 145 | } 146 | 147 | fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec, Vec) { 148 | let regex: Regex = Regex::new(pattern).expect("Invalid regex pattern"); 149 | // Tokenize strings in parallel 150 | let (tokens, counts): (Vec<&str>, Vec) = strings 151 | .par_iter() 152 | .flat_map(|&text| pretokenize(text, ®ex)) 153 | .fold( 154 | || HashMap::<&str, u64>::default(), 155 | |mut acc, token| { 156 | *acc.entry(token).or_insert(0) += 1; 157 | acc 158 | }, 159 | ) 160 | .reduce( 161 | || HashMap::<&str, u64>::default(), 162 | |mut a, b| { 163 | for (token, count) in b { 164 | *a.entry(token).or_insert(0) += count; 165 | } 166 | a 167 | }, 168 | ) 169 | .into_iter() 170 | .unzip(); 171 | 172 | // Convert tokens to sentences and filter sentences and counts to remove single byte sentences 173 | let (filtered_sentences, filtered_counts): (Vec, Vec) = tokens 174 | .into_iter() 175 | .map(Sentence::from_str) 176 | .zip(counts.into_iter()) 177 | .filter(|(sentence, _)| sentence.symbols.len() > 1) 178 | .unzip(); 179 | 180 | (filtered_sentences, filtered_counts) 181 | } 182 | 183 | fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap, u32>, Vec>) { 184 | let mut word_to_id: HashMap, u32> = HashMap::default(); 185 | let mut id_to_word: Vec> = Vec::with_capacity(vocab_size); 186 | for i in 0..=255 { 187 | word_to_id.insert(vec![i], i as u32); 188 | id_to_word.push(vec![i]); 189 | } 190 | return (word_to_id, id_to_word); 191 | } 192 | 193 | fn get_most_frequent_pair( 194 | tokenized_sentences: &[Sentence], 195 | base_counts: &[u64], 196 | ) -> (HashMap, HashMap>) { 197 | // Calculate frequencies for each pair of bytes in all sentences and words 198 | tokenized_sentences 199 | .par_iter() 200 | .enumerate() 201 | .map(|(i, sentence)| { 202 | let mut local_pair_counts = HashMap::::default(); 203 | let mut local_pair_positions: HashMap> = HashMap::default(); 204 | 205 | for window in sentence.get_symbols().windows(2) { 206 | let 
current_pair: Pair = (window[0], window[1]); 207 | // First update counts 208 | local_pair_counts 209 | .entry(current_pair) 210 | .and_modify(|c| *c += base_counts[i] as i64) 211 | .or_insert(base_counts[i] as i64); 212 | 213 | // Then update position 214 | local_pair_positions 215 | .entry(current_pair) 216 | .and_modify(|h: &mut HashSet| { 217 | h.insert(i); 218 | }) 219 | .or_insert_with(|| { 220 | let mut h = HashSet::::default(); 221 | h.insert(i); 222 | h 223 | }); 224 | } 225 | (local_pair_counts, local_pair_positions) 226 | }) 227 | .reduce( 228 | || { 229 | ( 230 | HashMap::::default(), 231 | HashMap::>::default(), 232 | ) 233 | }, 234 | |(mut global_pair_counts, mut global_pair_positions), (pc, wtu)| { 235 | // Merge the pair counts and positions from all sentences 236 | for (k, v) in pc { 237 | global_pair_counts 238 | .entry(k) 239 | .and_modify(|c| *c += v) 240 | .or_insert(v); 241 | } 242 | for (k, v) in wtu { 243 | global_pair_positions 244 | .entry(k) 245 | .and_modify(|set| *set = set.union(&v).copied().collect()) 246 | .or_insert(v); 247 | } 248 | (global_pair_counts, global_pair_positions) 249 | }, 250 | ) 251 | } 252 | 253 | // Build vocab from most frequent pairs 254 | fn build_bpe_vocab( 255 | tokenized_sentences: Vec, 256 | base_counts: &[u64], 257 | max_token_length: usize, 258 | vocab_size: usize, 259 | ) -> HashMap, u32> { 260 | let (mut word_to_id, mut id_to_word) = initialize_vocab_bytes(vocab_size); 261 | 262 | // get most frequent pair 263 | let (mut global_pair_counts, mut global_pair_positions) = 264 | get_most_frequent_pair(&tokenized_sentences, &base_counts); 265 | 266 | // build Priority Queue from counts and positions 267 | let mut queue: BinaryHeap = BinaryHeap::new(); 268 | global_pair_positions.drain().for_each(|(pair, pos)| { 269 | let count: i64 = global_pair_counts[&pair]; 270 | if count > 0 { 271 | queue.push(Merge { pair, count, pos }); 272 | } 273 | }); 274 | 275 | while word_to_id.len() < vocab_size { 276 | // check if queue is empty 277 | if queue.is_empty() { 278 | break; 279 | } 280 | 281 | let mut top = queue.pop().unwrap(); 282 | // check if count has changed 283 | if top.count != global_pair_counts[&top.pair] { 284 | top.count = global_pair_counts[&top.pair]; 285 | queue.push(top); 286 | continue; 287 | } 288 | 289 | // exit count is 0 290 | if top.count < 1 { 291 | break; 292 | } 293 | 294 | // add to vocab 295 | let (left, right) = top.pair; 296 | let merged_id = word_to_id.len() as u32; 297 | 298 | let mut word = id_to_word[left as usize].clone(); 299 | let right_word = id_to_word[right as usize].clone(); 300 | word.extend(right_word.iter()); 301 | word_to_id.insert(word.clone(), merged_id); 302 | id_to_word.push(word); 303 | 304 | // update counts and positions for each sentence 305 | let changes = top 306 | .pos 307 | .par_iter() 308 | .flat_map(|&i| { 309 | let sentence = &tokenized_sentences[i] as *const _ as *mut Sentence; 310 | // We can merge each of these sentences in parallel here because each position 311 | // can be there only once (HashSet). So this is safe. 
312 | unsafe { 313 | (*sentence) 314 | .merge(top.pair.0, top.pair.1, merged_id, max_token_length) 315 | .into_iter() 316 | .map(|c| (c, i)) 317 | .collect::>() 318 | } 319 | }) 320 | .collect::>(); 321 | 322 | for ((pair, change), iw) in changes { 323 | // adjust count to reflect sentence level count 324 | let count = change * base_counts[iw] as i64; 325 | global_pair_counts 326 | .entry(pair) 327 | .and_modify(|c| *c += count) 328 | .or_insert(count); 329 | if count > 0 { 330 | global_pair_positions 331 | .entry(pair) 332 | .and_modify(|h| { 333 | h.insert(iw); 334 | }) 335 | .or_insert_with(|| { 336 | let mut h = HashSet::::default(); 337 | h.insert(iw); 338 | h 339 | }); 340 | } 341 | } 342 | 343 | // update queue 344 | global_pair_positions.drain().for_each(|(pair, pos)| { 345 | let count = global_pair_counts[&pair]; 346 | if count > 0 { 347 | queue.push(Merge { pair, count, pos }); 348 | } 349 | }); 350 | } 351 | word_to_id 352 | } 353 | 354 | // Train BPE from Iterator 355 | #[pyfunction] 356 | fn train_bpe( 357 | py: Python, 358 | iterator: &PyIterator, 359 | python_regex: &PyString, 360 | max_token_length: usize, 361 | vocab_size: usize, 362 | ) -> PyResult { 363 | let regex = python_regex.to_str()?; 364 | 365 | // validate inputs 366 | if max_token_length < 2 { 367 | return Err(exceptions::PyValueError::new_err( 368 | "max_token_length must be greater than 1", 369 | )); 370 | } 371 | if vocab_size < 256 { 372 | return Err(exceptions::PyValueError::new_err( 373 | "vocab_size must be greater than 256", 374 | )); 375 | } 376 | if regex.is_empty() { 377 | return Err(exceptions::PyValueError::new_err("regex cannot be empty")); 378 | } 379 | 380 | // Extract strings from Python iterator and store them in a Rust Vec for parallel processing 381 | let strings: Vec<&str> = iterator 382 | .filter_map(|item_result| { 383 | item_result.ok().and_then(|item| { 384 | item.extract::<&PyString>() 385 | .ok() 386 | .and_then(|py_string| py_string.to_str().ok()) 387 | }) 388 | }) 389 | .filter(|text| !text.is_empty()) 390 | .collect(); 391 | 392 | let (pretokenized_sentences, counts): (Vec, Vec) = 393 | pretokenize_strings(strings, regex); 394 | 395 | let bpe_vocab = build_bpe_vocab( 396 | pretokenized_sentences, 397 | &counts, 398 | max_token_length, 399 | vocab_size, 400 | ); 401 | let python_dict_out = PyDict::new(py); 402 | // convert bpe_vocab to python dict 403 | for (key, value) in bpe_vocab { 404 | let py_key = PyBytes::new(py, &key); 405 | python_dict_out.set_item(py_key, value)?; 406 | } 407 | Ok(python_dict_out.into()) 408 | } 409 | 410 | /// bpeasy is a bare-bones implementation of byte-pair encoding (BPE) in Rust. 411 | /// It is designed to be used as a Python module and returns a byte-pair vocabulary 412 | /// as a Python dictionary. 
413 | #[pymodule] 414 | fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> { 415 | m.add_function(wrap_pyfunction!(train_bpe, m)?)?; 416 | Ok(()) 417 | } 418 | 419 | #[cfg(test)] 420 | mod tests { 421 | #[test] 422 | fn test_all() { 423 | let text: &str = "\tYou hear a £ £ £ here"; 424 | let pattern = r"([^\s]+)|(\s+)"; 425 | let compiled_regex: fancy_regex::Regex = 426 | fancy_regex::Regex::new(pattern).expect("Invalid regex pattern"); 427 | let pretokenized_sentences = crate::pretokenize(text, &compiled_regex); 428 | assert_eq!( 429 | pretokenized_sentences, 430 | vec!["\t", "You", " ", "hear", " ", "a", " ", "£", " ", "£", " ", "£", " ", "here"] 431 | ); 432 | 433 | let text_2: &str = "You hear £ £ £ here"; 434 | 435 | let (pretokenized_sentences, _counts) = 436 | crate::pretokenize_strings(vec![text, text_2], pattern); 437 | 438 | let vocab_size = 300; 439 | let max_token_length = 128; 440 | crate::build_bpe_vocab( 441 | pretokenized_sentences, 442 | &_counts, 443 | max_token_length, 444 | vocab_size, 445 | ); 446 | } 447 | 448 | #[test] 449 | fn test_initialize_vocab_bytes() { 450 | let vocab = crate::initialize_vocab_bytes(400); 451 | assert_eq!(vocab.0.len(), 256); 452 | } 453 | } 454 | -------------------------------------------------------------------------------- /tests/test_convert.py: -------------------------------------------------------------------------------- 1 | from bpeasy.convert import bpe 2 | 3 | 4 | def test_bpe_function(): 5 | mergeable_ranks = {b"ab": 0, b"bc": 1, b"cd": 2} 6 | token = b"abcd" 7 | result = bpe(mergeable_ranks, token) 8 | assert result == [ 9 | b"ab", 10 | b"cd", 11 | ], "The bpe function did not split the token correctly" 12 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from unittest import mock 4 | from bpeasy.tokenizer import BPEasyTokenizer 5 | 6 | 7 | def test_initialization(): 8 | vocab = {b"hello": 1, b"world": 2} 9 | tokenizer = BPEasyTokenizer(vocab=vocab) 10 | assert tokenizer.vocab == vocab 11 | assert tokenizer.name == "bpeasy" 12 | assert len(tokenizer.special_tokens) == 0 13 | assert len(tokenizer) == 2 14 | 15 | 16 | def test_encode_decode(): 17 | vocab = {b"hello": 1, b" world": 2} 18 | tokenizer = BPEasyTokenizer(vocab=vocab) 19 | encoded = tokenizer.encode("hello world", allowed_special="all") 20 | assert encoded == [1, 2] 21 | decoded = tokenizer.decode(encoded) 22 | assert decoded == "hello world" 23 | 24 | 25 | def test_save_and_load(): 26 | vocab = {b"hello": 1, b" world": 2} 27 | tokenizer = BPEasyTokenizer(vocab=vocab) 28 | 29 | # Test saving 30 | with mock.patch("builtins.open", mock.mock_open()) as mock_file: 31 | tokenizer.save("dummy_path.json") 32 | mock_file.assert_called_once_with("dummy_path.json", "w") 33 | 34 | # Prepare dummy file content for loading 35 | dummy_file_content = json.dumps( 36 | { 37 | "name": "bpeasy", 38 | "vocab": { 39 | base64.b64encode(key).decode("utf-8"): value 40 | for key, value in vocab.items() 41 | }, 42 | "regex_pattern": tokenizer.regex_pattern, 43 | "special_tokens": tokenizer.special_tokens, 44 | } 45 | ) 46 | 47 | # Test loading 48 | with mock.patch( 49 | "builtins.open", mock.mock_open(read_data=dummy_file_content) 50 | ) as mock_file: 51 | loaded_tokenizer = BPEasyTokenizer.from_file("dummy_path.json") 52 | assert loaded_tokenizer.vocab == vocab 53 | 54 | 55 | @mock.patch("builtins.open", 
new_callable=mock.mock_open) 56 | @mock.patch("json.dump") 57 | def test_conversion_to_huggingface(mock_json_dump, mock_open): 58 | vocab = { 59 | b"h": 0, 60 | b"e": 1, 61 | b"l": 2, 62 | b"o": 3, 63 | b" ": 4, 64 | b"w": 5, 65 | b"r": 6, 66 | b"d": 7, 67 | b"he": 8, 68 | b"ll": 9, 69 | b"llo": 10, 70 | b"hello": 11, 71 | b"wo": 12, 72 | b"wor": 13, 73 | b"ld": 14, 74 | b"world": 15, 75 | b" world": 16, 76 | } 77 | tokenizer = BPEasyTokenizer(vocab=vocab) 78 | tokenizer.export_to_huggingface_format("dummy_path.json") 79 | mock_open.assert_called_once_with("dummy_path.json", "w", encoding="utf-8") 80 | mock_json_dump.assert_called_once() 81 | args, _ = mock_json_dump.call_args 82 | assert args[0]["model"]["type"] == "BPE" 83 | 84 | 85 | @mock.patch("builtins.open", new_callable=mock.mock_open) 86 | @mock.patch("json.dump") 87 | def test_conversion_to_huggingface_with_special_tokens(mock_json_dump, mock_open): 88 | vocab = { 89 | b"h": 0, 90 | b"e": 1, 91 | b"l": 2, 92 | b"o": 3, 93 | b" ": 4, 94 | b"w": 5, 95 | b"r": 6, 96 | b"d": 7, 97 | b"he": 8, 98 | b"ll": 9, 99 | b"llo": 10, 100 | b"hello": 11, 101 | b"wo": 12, 102 | b"wor": 13, 103 | b"ld": 14, 104 | b"world": 15, 105 | b" world": 16, 106 | } 107 | tokenizer = BPEasyTokenizer(vocab=vocab, special_tokens=["<|special-0|>", ""]) 108 | tokenizer.export_to_huggingface_format("dummy_path.json") 109 | mock_open.assert_called_once_with("dummy_path.json", "w", encoding="utf-8") 110 | mock_json_dump.assert_called_once() 111 | args, _ = mock_json_dump.call_args 112 | assert args[0]["model"]["type"] == "BPE" 113 | -------------------------------------------------------------------------------- /tests/test_train_bpe.py: -------------------------------------------------------------------------------- 1 | from bpeasy import train_bpe 2 | 3 | 4 | def test_train_bpe_vocab_size(): 5 | vocab_size = 300 6 | max_token_length = 4 7 | regex = r"([^\s]+)|(\s+)" 8 | vocab = train_bpe( 9 | iter(["This is a test", "this is another test", "good tests"]), 10 | regex, 11 | max_token_length, 12 | vocab_size, 13 | ) 14 | assert len(vocab) == 267 15 | 16 | 17 | def test_train_bpe_max_token_length(): 18 | vocab_size = 300 19 | max_token_length = 2 20 | regex = r"([^\s]+)|(\s+)" 21 | vocab = train_bpe( 22 | iter(["This is a test", "this is another test", "good tests"]), 23 | regex, 24 | max_token_length, 25 | vocab_size, 26 | ) 27 | for token in vocab: 28 | assert len(token) <= max_token_length 29 | max_token_length = 3 30 | vocab = train_bpe( 31 | iter(["This is a test", "this is another test", "good tests"]), 32 | regex, 33 | max_token_length, 34 | vocab_size, 35 | ) 36 | for token in vocab: 37 | assert len(token) <= max_token_length 38 | 39 | 40 | def test_train_bpe_gpt_regex(): 41 | vocab_size = 300 42 | max_token_length = 128 43 | regex = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 44 | vocab = train_bpe( 45 | iter(["We've got a test", "We've got good test", "this is a good tests"]), 46 | regex, 47 | max_token_length, 48 | vocab_size, 49 | ) 50 | for token in vocab: 51 | assert len(token) <= max_token_length 52 | 53 | print(vocab) 54 | assert b" go" in vocab.keys() 55 | assert b"'ve" in vocab.keys() 56 | -------------------------------------------------------------------------------- /tests/test_utils_bpe.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from unittest.mock import mock_open, patch, call 3 | 4 | from bpeasy import 
save_vocab_to_tiktoken 5 | 6 | 7 | def test_basic_functionality(): 8 | vocab = {b"hello": 1, b"world": 2} 9 | with patch("builtins.open", mock_open()) as mock_file: 10 | save_vocab_to_tiktoken(vocab, "path/to/file.txt") 11 | mock_file.assert_called_once_with("path/to/file.txt", "wb") 12 | # Check if the sorted vocab is written correctly 13 | expected_content = [ 14 | base64.b64encode(b"hello") + b" 1\n", 15 | base64.b64encode(b"world") + b" 2\n", 16 | ] 17 | mock_file().write.assert_has_calls( 18 | [call(content) for content in expected_content] 19 | ) 20 | 21 | 22 | def test_special_tokens_addition(): 23 | vocab = {b"token": 0} 24 | special_tokens = ["special1", "special2"] 25 | with patch("builtins.open", mock_open()) as mock_file: 26 | save_vocab_to_tiktoken(vocab, "path/to/file.txt", special_tokens) 27 | # Check if special tokens are added correctly 28 | expected_content = [ 29 | base64.b64encode(b"token") + b" 0\n", 30 | base64.b64encode("special1".encode("utf-8")) + b" 1\n", 31 | base64.b64encode("special2".encode("utf-8")) + b" 2\n", 32 | ] 33 | mock_file().write.assert_has_calls( 34 | [call(content) for content in expected_content] 35 | ) 36 | 37 | 38 | def test_fill_to_nearest_multiple_of_eight(): 39 | vocab = {b"token": 0} 40 | with patch("builtins.open", mock_open()) as mock_file: 41 | save_vocab_to_tiktoken( 42 | vocab, "path/to/file.txt", fill_to_nearest_multiple_of_eight=True 43 | ) 44 | # Verify that additional tokens are added to make the count a multiple of eight 45 | mock_file().write.assert_called() 46 | # Check the exact content based on your logic of filling to nearest multiple of eight 47 | assert mock_file().write.call_count == 8 48 | --------------------------------------------------------------------------------