├── .github
│   └── workflows
│       ├── CI.yml
│       └── test.yml
├── .gitignore
├── CITATION.cff
├── Cargo.toml
├── LICENSE
├── README.md
├── benchmarks
│   ├── README.md
│   ├── data
│   │   └── c4.jsonl
│   └── train.py
├── bpeasy
│   ├── __init__.py
│   ├── bpeasy.pyi
│   ├── convert.py
│   └── tokenizer.py
├── pyproject.toml
├── requirements.txt
├── src
│   └── lib.rs
└── tests
    ├── test_convert.py
    ├── test_tokenizer.py
    ├── test_train_bpe.py
    └── test_utils_bpe.py
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | test:
12 | name: Pre-test for release
13 | strategy:
14 | matrix:
15 | os: ["ubuntu"]
16 | python-version: ["3.10", "3.13"]
17 | runs-on: ${{ matrix.os }}-latest
18 | steps:
19 | - uses: actions/checkout@v4
20 | - uses: actions/setup-python@v5
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | cache: "pip"
24 | - uses: dtolnay/rust-toolchain@stable
25 | - name: Setup virtual environment
26 | run: |
27 | python -m venv venv
28 | source venv/bin/activate
29 | pip install -r requirements.txt
30 | - name: Run tests
31 | run: |
32 | source venv/bin/activate
33 | cargo test
34 | maturin develop --release
35 | pytest tests
36 |
37 | linux:
38 | needs: [test]
39 | runs-on: ubuntu-latest
40 | strategy:
41 | matrix:
42 | target: [x86_64, x86, aarch64, armv7, s390x, ppc64le]
43 | python-version: ["3.10", "3.13"]
44 | steps:
45 | - uses: actions/checkout@v4
46 | - uses: actions/setup-python@v5
47 | with:
48 | python-version: ${{ matrix.python-version }}
49 | - name: Build wheels
50 | uses: PyO3/maturin-action@v1
51 | with:
52 | target: ${{ matrix.target }}
53 | args: --release --out dist --find-interpreter
54 | sccache: "true"
55 | manylinux: auto
56 | - name: Upload wheels
57 | uses: actions/upload-artifact@v4
58 | with:
59 | name: dist-linux-${{ matrix.target }}-py${{ matrix.python-version }}
60 | path: dist
61 |
62 | windows:
63 | needs: [test]
64 | runs-on: windows-latest
65 | strategy:
66 | matrix:
67 | target: [x64, x86]
68 | python-version: ["3.10", "3.13"]
69 | steps:
70 | - uses: actions/checkout@v4
71 | - uses: actions/setup-python@v5
72 | with:
73 | python-version: ${{ matrix.python-version }}
74 | architecture: ${{ matrix.target }}
75 | - name: Build wheels
76 | uses: PyO3/maturin-action@v1
77 | with:
78 | target: ${{ matrix.target }}
79 | args: --release --out dist --find-interpreter
80 | sccache: "true"
81 | - name: Upload wheels
82 | uses: actions/upload-artifact@v4
83 | with:
84 | name: dist-windows-${{ matrix.target }}-py${{ matrix.python-version }}
85 | path: dist
86 |
87 | macos:
88 | needs: [test]
89 | runs-on: macos-latest
90 | strategy:
91 | matrix:
92 | target: [x86_64, aarch64]
93 | python-version: ["3.10", "3.13"]
94 | steps:
95 | - uses: actions/checkout@v4
96 | - uses: actions/setup-python@v5
97 | with:
98 | python-version: ${{ matrix.python-version }}
99 | - name: Build wheels
100 | uses: PyO3/maturin-action@v1
101 | with:
102 | target: ${{ matrix.target }}
103 | args: --release --out dist --find-interpreter
104 | sccache: "true"
105 | - name: Upload wheels
106 | uses: actions/upload-artifact@v4
107 | with:
108 | name: dist-macos-${{ matrix.target }}-py${{ matrix.python-version }}
109 | path: dist
110 |
111 | sdist:
112 | needs: [test]
113 | runs-on: ubuntu-latest
114 | steps:
115 | - uses: actions/checkout@v4
116 | - name: Build sdist
117 | uses: PyO3/maturin-action@v1
118 | with:
119 | command: sdist
120 | args: --out dist
121 | - name: Upload sdist
122 | uses: actions/upload-artifact@v4
123 | with:
124 | name: dist-sdist
125 | path: dist
126 |
127 | release:
128 | name: Release
129 | runs-on: ubuntu-latest
130 | needs: [linux, windows, macos, sdist]
131 | if: "startsWith(github.ref, 'refs/tags/')"
132 | permissions:
133 | id-token: write
134 | steps:
135 | - uses: actions/download-artifact@v4
136 | with:
137 | pattern: dist-*
138 | merge-multiple: true
139 | path: dist
140 | - name: Publish to PyPI
141 | uses: PyO3/maturin-action@v1
142 | with:
143 | command: upload
144 | args: --non-interactive --skip-existing dist/*
145 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | workflow_dispatch:
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | coverage:
15 | name: Coverage for ${{ matrix.os }} with Python ${{ matrix.python-version }}
16 | strategy:
17 | matrix:
18 | os: ["ubuntu"]
19 | python-version: ["3.10", "3.13"]
20 | runs-on: ${{ matrix.os }}-latest
21 | steps:
22 | - uses: actions/checkout@v4
23 | - uses: actions/setup-python@v5
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | cache: "pip"
27 | - uses: dtolnay/rust-toolchain@stable
28 | - name: Install cargo-llvm-cov
29 | uses: taiki-e/install-action@cargo-llvm-cov
30 | - uses: Swatinem/rust-cache@v2.7.0
31 | with:
32 | key: coverage-cargo-${{ matrix.os }}-py${{ matrix.python-version }}
33 | continue-on-error: true
34 | - name: Setup virtual environment
35 | run: |
36 | python -m venv venv
37 | source venv/bin/activate
38 | pip install -r requirements.txt
39 | - name: Run coverage
40 | run: |
41 | source venv/bin/activate
42 | source <(cargo llvm-cov show-env --export-prefix)
43 | export CARGO_TARGET_DIR=$CARGO_LLVM_COV_TARGET_DIR
44 | export CARGO_INCREMENTAL=1
45 | cargo llvm-cov clean --workspace
46 | cargo test
47 | maturin develop
48 | pytest tests --cov=bpeasy --cov-report xml
49 | cargo llvm-cov report --lcov --output-path coverage.lcov
50 | - uses: codecov/codecov-action@v3
51 | env:
52 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
53 | with:
54 | files: coverage.lcov,coverage.xml
55 | name: ${{ matrix.os }}-py${{ matrix.python-version }}
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | out/
2 | .nox/
3 | .benchmarks/
4 |
5 | # Generated by Cargo
6 | # will have compiled files and executables
7 | debug/
8 | target/
9 |
10 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
11 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
12 | Cargo.lock
13 |
14 | # These are backup files generated by rustfmt
15 | **/*.rs.bk
16 |
17 | # MSVC Windows builds of rustc generate these, which store debugging information
18 | *.pdb
19 |
20 | # Added by cargo
21 | /target
22 |
23 |
24 | # Byte-compiled / optimized / DLL files
25 | __pycache__/
26 | .pytest_cache/
27 | *.py[cod]
28 |
29 | # C extensions
30 | *.so
31 |
32 | # Distribution / packaging
33 | .Python
34 | .venv/
35 | env/
36 | bin/
37 | build/
38 | develop-eggs/
39 | dist/
40 | eggs/
41 | lib/
42 | lib64/
43 | parts/
44 | sdist/
45 | var/
46 | include/
47 | man/
48 | venv/
49 | *.egg-info/
50 | .installed.cfg
51 | *.egg
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 | pip-selfcheck.json
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .coverage
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 |
66 | # Translations
67 | *.mo
68 |
69 | # Mr Developer
70 | .mr.developer.cfg
71 | .project
72 | .pydevproject
73 |
74 | # Rope
75 | .ropeproject
76 |
77 | # Django stuff:
78 | *.log
79 | *.pot
80 |
81 | .DS_Store
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyCharm
87 | .idea/
88 |
89 | # VSCode
90 | .vscode/
91 |
92 | # Pyenv
93 | .python-version
94 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: bpeasy
3 | message: >-
4 | If you use this software, please cite it using the
5 | metadata from this file.
6 | type: software
7 | authors:
8 | - given-names: Gautier
9 | family-names: Dagan
10 | email: gautier.dagan@ed.ac.uk
11 | affiliation: University of Edinburgh
12 | orcid: 'https://orcid.org/0000-0002-1867-4201'
13 | repository-code: 'https://github.com/gautierdag/bpeasy'
14 | url: 'https://github.com/gautierdag/bpeasy'
15 | license: MIT
16 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "bpeasy"
3 | version = "0.1.5"
4 | edition = "2021"
5 |
6 | [lib]
7 | name = "bpeasy"
8 | crate-type = ["cdylib"]
9 |
10 | [profile.release]
11 | opt-level = 3
12 | lto = "fat"
13 | codegen-units = 1
14 |
15 | [dependencies]
16 | fancy-regex = "0.12.0"
17 | fxhash = "0.2.1"
18 | pyo3 = { version = "0.19.0", features = ["extension-module"] }
19 | rayon = "1.8.0"
20 | regex = "1.5.4"
21 | serde_json = "1.0.108"
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Gautier Dagan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bpeasy
2 |
3 | [](https://codecov.io/gh/gautierdag/bpeasy) [](https://github.com/gautierdag/bpeasy/actions/workflows/test.yml) [](https://pypi.python.org/pypi/bpeasy) [](https://pypi.python.org/pypi/bpeasy) [](https://badge.fury.io/py/bpeasy)
4 |
5 | ## Overview
6 |
7 | `bpeasy` is a Python package that provides a tokenizer trainer, implementing an efficient version of Byte Pair Encoding (BPE) in roughly 400 lines of Rust. The implementation largely follows the Huggingface `tokenizers` library, but makes opinionated decisions to simplify tokenizer training, specifically to:
8 |
9 | 1. Treat text data at the byte-level first --- all text is converted to bytes before training rather than using a character-level approach (like in Huggingface).
10 | 2. Always use a regex-based split pre-tokenizer. This is a customisable regex that is applied to the text before training. This regex decides where to split the text and limits what kind of tokens are possible. This is technically possible in Huggingface but is not well documented. We also use the `fancy-regex` crate which supports a richer set of regex features than the `regex` crate used in Huggingface.
11 | 3. Use `int64` types for counting to allow for training on much larger datasets without the risk of overflow.
12 |
13 | **You can think of `bpeasy` as the `tiktoken` training code that never was.**
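
As a quick illustration of point 2 above, the split pattern bounds the chunks that learned tokens can ever span. The sketch below is not part of the library; it uses the third-party `regex` module on the Python side purely for demonstration, whereas `bpeasy` applies the pattern in Rust via `fancy-regex`.

```python
import regex  # pip install regex (supports \p{L}, \p{N} and lookaheads)

# GPT-4 style split pattern (same one used in the training example below)
gpt4_pattern = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

print(regex.findall(gpt4_pattern, "We've got 123 tests!"))
# ['We', "'ve", ' got', ' ', '123', ' tests', '!']
```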
14 |
15 | See the [benchmarks](/benchmarks/README.md) section for a comparison with the Huggingface library.
16 |
17 | ## Installation
18 |
19 | Simply install the package using pip:
20 |
21 | ```bash
22 | pip install bpeasy
23 | ```
24 |
25 | ## Training
26 |
27 | The training function is designed to be bare-bones and returns the trained tokenizer vocab as a dictionary of bytes to integers. This allows for maximum flexibility in how you want to use the tokenizer. For example, you can then port the vocab to `tiktoken` or Huggingface tokenizers (see below).
28 |
29 | ```python
30 | # should be an iterator over str
31 | iterator = jsonl_content_iterator(args)
32 | # example regex from GPT-4
33 | regex_pattern = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
34 |
35 | # returns the vocab (dict[bytes, int])
36 | vocab = bpeasy.train_bpe(
37 | iterator,
38 | regex_pattern,
39 | args.max_sentencepiece_length, # max length of tokens
40 | args.vocab_size, # max size of vocab
41 | )
42 | ```
43 |
44 | Alternatively, you can also train using the basic tokenizer class provided:
45 |
46 | ```python
47 | from bpeasy.tokenizer import BPEasyTokenizer
48 |
49 | tokenizer = BPEasyTokenizer.train(
50 | iterator, # iterator over str
51 | vocab_size=vocab_size,
52 | max_token_length=max_token_length,
53 | regex_pattern=regex_pattern,
54 | special_tokens=["<s>", "<pad>", "</s>"],
55 | fill_to_nearest_multiple_of_eight=True,
56 | name="bpeasy",
57 | )
58 | ```
59 |
60 | ### Encoding/Decoding
61 |
62 | To test your tokenizer you can use the `BPEasyTokenizer` class, which is a wrapper around `tiktoken.Encoding` that simplifies the handling of vocabularies, special tokens, and regex patterns for tokenization.
63 |
64 | ```python
65 | from bpeasy.tokenizer import BPEasyTokenizer
66 |
67 | your_special_tokens = ["<s>", "<pad>", "</s>"]
68 |
69 | tokenizer = BPEasyTokenizer(
70 | vocab=vocab,
71 | regex_pattern=regex_pattern,
72 | special_tokens=your_special_tokens,
73 | fill_to_nearest_multiple_of_eight=True, # pad vocab to multiple of 8
74 | name="bpeasy" # optional name for the tokenizer
75 | )
76 |
77 | test = "hello_world"
78 |
79 | # encode and decode uses the tiktoken functions
80 | encoded = tokenizer.encode(test)
81 | decoded = tokenizer.decode(encoded)
82 | # decoded == "hello_world"
83 | ```
84 |
85 | You can also use `tiktoken` directly, but you would need to handle the special tokens and regex pattern yourself:
86 |
87 | ```python
88 | import tiktoken
89 |
90 | vocab = bpeasy.train_bpe(...)
91 | special_tokens = ["<s>", "<pad>", "</s>"]
92 |
93 | # Sort the vocab by rank
94 | sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1])
95 |
96 | # add special tokens
97 | special_token_ranks = {}
98 | for special_token in special_tokens:
99 | special_token_ranks[special_token] = len(sorted_vocab)
100 | sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab)))
101 |
102 | full_vocab = dict(sorted_vocab)
103 |
104 | encoder = tiktoken.Encoding(
105 | name=name,
106 | pat_str=regex_pattern,
107 | mergeable_ranks=full_vocab,
108 | special_tokens=special_token_ranks,
109 | )
110 | ```
111 |
112 | ### Save/Load tokenizer from file
113 |
114 | We provide basic utility functions to save and load the tokenizer from a json file.
115 |
116 | ```python
117 | tokenizer.save("path_to_file.json")
118 |
119 | tokenizer = BPEasyTokenizer.from_file("path_to_file.json")
120 | ```
121 |
122 | ### Export to HuggingFace format
123 |
124 | We also support exporting the tokenizer to the HuggingFace format, which can then be used directly with the HuggingFace `transformers` library.
125 |
126 | ```python
127 | from bpeasy.tokenizer import BPEasyTokenizer
128 |
129 | tokenizer = BPEasyTokenizer(
130 | ...
131 | )
132 |
133 | tokenizer.export_to_huggingface_format("hf_tokenizer.json")
134 |
135 | from transformers import PreTrainedTokenizerFast
136 |
137 | hf_tokenizer = PreTrainedTokenizerFast(tokenizer_file="hf_tokenizer.json")
138 | ```
139 |
140 | ### Export vocab to `tiktoken` txt format
141 |
142 | ```python
143 | from bpeasy import save_vocab_to_tiktoken
144 | vocab = bpeasy.train_bpe(...)
145 |
146 | # saves the vocab to a tiktoken txt file format
147 | save_vocab_to_tiktoken(vocab, "vocab.txt", special_tokens=["<s>", "<pad>", "</s>"])
148 |
149 | ```
150 |
151 | If you want to use the `tiktoken` txt format, you will still need to handle the regex and special tokens yourself, as shown above.
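
For reference, each line in the txt file written by `save_vocab_to_tiktoken` is a base64-encoded token followed by its rank, so it can be read back into a `dict[bytes, int]` with a few lines. The helper below is only a sketch (the function name is ours); combine its output with the regex and special-token handling shown above to build a `tiktoken.Encoding`.

```python
import base64


def load_tiktoken_vocab(path: str) -> dict[bytes, int]:
    """Parse `<base64 token> <rank>` lines back into a rank dictionary."""
    ranks: dict[bytes, int] = {}
    with open(path, "rb") as f:
        for line in f:
            token_b64, rank = line.split()
            ranks[base64.b64decode(token_b64)] = int(rank)
    return ranks


mergeable_ranks = load_tiktoken_vocab("vocab.txt")
```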
152 |
153 | ## Contributing
154 |
155 | Contributions are welcome! Please open an issue if you have any suggestions or improvements.
156 |
157 | ## License
158 |
159 | This project is licensed under the MIT License.
160 |
161 | ## Citation
162 |
163 | If you use `bpeasy` in your research, please cite it as follows:
164 |
165 | ```bibtex
166 | @software{bpeasy,
167 | author = {Gautier Dagan},
168 | title = {bpeasy},
169 | year = {2024},
170 | url = {https://github.com/gautierdag/bpeasy},
171 | repository = {https://github.com/gautierdag/bpeasy},
172 | author-email = {gautier.dagan@ed.ac.uk},
173 | affiliation = {University of Edinburgh},
174 | orcid = {https://orcid.org/0000-0002-1867-4201}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarks on the c4 dataset
2 |
3 | Using varying vocab sizes from 5k to 100k (in increments of 5k).
4 |
5 | | Library/Operation | Time (seconds) | Standard Deviation |
6 | |----------------------------|---------------------------------|--------------------------------|
7 | | HuggingFace Train | 0.8165 | ±0.62 |
8 | | `bpeasy` Train | 0.68815 | ±0.41 |
9 | | HuggingFace Encode | 0.6247 | ±0.051 |
10 | | `bpeasy` Encode (uses `tiktoken`) | 0.2679 | ±0.035 |
11 |
12 | | Library | Bytes per Token (normalised against HF) | Standard Deviation |
13 | |----------------------------|---------------------------------|--------------------------------|
14 | | `bpeasy` | 1.0008992687171223 | ±5.542696043278318e-05 |
15 |
16 | We can see that `bpeasy` is faster than HuggingFace for both training and encoding. The training gap is not massive and will depend heavily on the dataset and compute, but `bpeasy` is at least comparable.
17 |
18 | We also gain a tiny bit of compression (more bytes per token) because `bpeasy` works at the byte level and is slightly more efficient in its allocation of basic tokens.
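
For clarity, the bytes-per-token figure is just the total number of UTF-8 bytes in the corpus divided by the total number of tokens produced, normalised by the HuggingFace value. The sketch below mirrors `encode()` in `benchmarks/train.py`; the tokenizer and text names are placeholders.

```python
def bytes_per_token(tokenizer, texts) -> float:
    """Average number of UTF-8 bytes covered by each produced token."""
    total_bytes = sum(len(t.encode("utf-8")) for t in texts)
    total_tokens = sum(len(tokenizer.encode(t)) for t in texts)
    return total_bytes / total_tokens


# > 1.0 means bpeasy compresses slightly better than the HF tokenizer
relative = bytes_per_token(bpeasy_tokenizer, texts) / bytes_per_token(hf_tokenizer, texts)
```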
19 |
20 | ## Reproducing the benchmarks
21 |
22 | ```bash
23 | pip install tokenizers tqdm
24 | pip install bpeasy
25 |
26 | python benchmarks/train.py
27 | ```
28 |
--------------------------------------------------------------------------------
/benchmarks/train.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import glob
3 | import json
4 | import logging
5 | import sys
6 | import time
7 | from pathlib import Path
8 |
9 | import tokenizers
10 | from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers
11 | from tokenizers.models import BPE
12 | from tokenizers.trainers import BpeTrainer
13 | from tqdm import tqdm
14 |
15 | import bpeasy
16 | from bpeasy.tokenizer import BPEasyTokenizer
17 |
18 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
19 |
20 |
21 | @dataclasses.dataclass
22 | class TrainBPETokenizerArgs:
23 | dataset: str = "./benchmarks/data"
24 | vocab_size: int = 32_000
25 | max_sentencepiece_length: int = 128
26 | regex_pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
27 |
28 | def __post_init__(self):
29 | checkpoint_dir = Path(self.dataset)
30 | assert checkpoint_dir.is_dir(), checkpoint_dir
31 |
32 |
33 | def jsonl_content_iterator(
34 | args: TrainBPETokenizerArgs,
35 | ):
36 | """
37 | Iterates over a jsonl file and yields the content of each line
38 | Tracks the number of characters yielded and stops when the limit is reached
39 | This is ripe for optimisation if you want to mess with more fine-grained
40 | character limits (eg. more Python than Java)
41 | """
42 | file_path = args.dataset
43 | chunk_num, character_count = 0, 0
44 | chunks = glob.glob(f"{file_path}/*.jsonl")
45 |
46 | while chunk_num < len(chunks):
47 | file_name = chunks[chunk_num]
48 | with open(file_name, "r", encoding="utf-8") as f:
49 | for line in f:
50 | obj = json.loads(line)
51 | text = obj["text"]
52 | text_character_count = len(text)
53 | character_count += text_character_count
54 | yield text
55 | chunk_num += 1
56 |
57 |
58 | def train_huggingface(args: TrainBPETokenizerArgs):
59 | # should be at least 0.14.0 to train with char limit
60 | assert tokenizers.__version__ >= "0.14.0"
61 | tokenizer = Tokenizer(BPE(byte_fallback=True))
62 | trainer = BpeTrainer(
63 | vocab_size=args.vocab_size,
64 | special_tokens=[f"<0x{i:02X}>" for i in range(256)], # seed sm vocab
65 | max_token_length=args.max_sentencepiece_length,
66 | show_progress=False,
67 | )
68 | gpt_regex = Regex(args.regex_pattern)
69 |
70 | split_pre_tokenizer = pre_tokenizers.Split(
71 | gpt_regex, behavior="isolated", invert=False
72 | )
73 | byte_pre_tokenizer = pre_tokenizers.ByteLevel(
74 | add_prefix_space=False, use_regex=False
75 | )
76 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
77 | [split_pre_tokenizer, byte_pre_tokenizer]
78 | )
79 | # Use ByteLevel Decoder
80 | tokenizer.decoder = decoders.Sequence(
81 | [decoders.ByteLevel(), decoders.ByteFallback()]
82 | )
83 | iterator = jsonl_content_iterator(args)
84 | # training the tokenizer
85 | tokenizer.train_from_iterator(iterator, trainer)
86 |
87 | return tokenizer
88 |
89 |
90 | def train_bpeasy(args: TrainBPETokenizerArgs):
91 | # Use ByteLevel Decoder
92 | iterator = jsonl_content_iterator(args)
93 | # training the tokenizer
94 | vocab = bpeasy.train_bpe(
95 | iterator,
96 | args.regex_pattern,
97 | args.max_sentencepiece_length,
98 | args.vocab_size,
99 | )
100 |
101 | return BPEasyTokenizer(
102 | vocab,
103 | args.regex_pattern,
104 | special_tokens=[],
105 | fill_to_nearest_multiple_of_eight=False,
106 | )
107 |
108 |
109 | def encode(tokenizer, args) -> float:
110 | iterator = jsonl_content_iterator(args)
111 | lengths = []
112 | num_bytes = 0
113 | for text in iterator:
114 | num_bytes += len(text.encode("utf-8"))
115 | encoded = tokenizer.encode(text)
116 | lengths.append(len(encoded))
117 | return num_bytes / sum(lengths)
118 |
119 |
120 | def get_mean_std_dev(times: list[float]) -> tuple[float, float]:
121 | avg_time = sum(times) / len(times)
122 | std_dev = (sum((t - avg_time) ** 2 for t in times) / len(times)) ** 0.5
123 | return avg_time, std_dev
124 |
125 |
126 | if __name__ == "__main__":
127 | args = TrainBPETokenizerArgs()
128 |
129 | times_train_huggingface = []
130 | times_encode_huggingface = []
131 | times_train_bpeasy = []
132 | times_encode_bpeasy = []
133 | byte_per_token_bpeasy = []
134 |
135 | for v in tqdm(range(5000, 100_000, 5000)):
136 | args.vocab_size = v
137 |
138 | time_now = time.time()
139 | tokenizer = train_huggingface(args)
140 | times_train_huggingface.append(time.time() - time_now)
141 |
142 | time_now = time.time()
143 | byte_per_token_hf = encode(tokenizer, args)
144 | times_encode_huggingface.append(time.time() - time_now)
145 |
146 | time_now = time.time()
147 | tokenizer = train_bpeasy(args)
148 | times_train_bpeasy.append(time.time() - time_now)
149 |
150 | time_now = time.time()
151 | byte_per_token_bpeasy.append(encode(tokenizer, args) / byte_per_token_hf)
152 | times_encode_bpeasy.append(time.time() - time_now)
153 |
154 | m_hf, std_hf = get_mean_std_dev(times_train_huggingface)
155 | m_bpeasy, std_bpeasy = get_mean_std_dev(times_train_bpeasy)
156 |
157 | print(f"huggingface train time {m_hf} +/- {std_hf}")
158 | print(f"bpeasy train time {m_bpeasy} +/- {std_bpeasy}")
159 |
160 | m_hf, std_hf = get_mean_std_dev(times_encode_huggingface)
161 | m_bpeasy, std_bpeasy = get_mean_std_dev(times_encode_bpeasy)
162 |
163 | print(f"huggingface encode time {m_hf} +/- {std_hf}")
164 | print(f"bpeasy encode time {m_bpeasy} +/- {std_bpeasy}")
165 |
166 | m_bpeasy, std_bpeasy = get_mean_std_dev(byte_per_token_bpeasy)
167 | print(f"bpeasy bytes/token vs hf: {m_bpeasy} +/- {std_bpeasy}")
168 |
--------------------------------------------------------------------------------
/bpeasy/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 |
3 | from .bpeasy import train_bpe
4 |
5 | __version__ = version("bpeasy")
6 |
7 |
8 | __all__ = [
9 | "save_vocab_to_tiktoken",
10 | "train_bpe",
11 | "__version__",
12 | ]
13 |
14 |
15 | def save_vocab_to_tiktoken(
16 | vocab: dict[bytes, int],
17 | out_path: str,
18 | special_tokens: list[str] = [],
19 | fill_to_nearest_multiple_of_eight: bool = False,
20 | ) -> None:
21 | """
22 | Export vocab to tiktoken txt format - use this if you want to use tiktoken library directly
23 | Note: you will need to handle special tokens and regex yourself
24 | """
25 | import base64
26 |
27 | sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1])
28 | for special_token in special_tokens:
29 | sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab)))
30 |
31 | if fill_to_nearest_multiple_of_eight:
32 | while len(sorted_vocab) % 8 != 0:
33 | sorted_vocab.append(
34 | (f"<|special-{len(sorted_vocab)}|>".encode("utf-8"), len(sorted_vocab))
35 | )
36 |
37 | with open(out_path, "wb") as f:
38 | for token, rank in sorted_vocab:
39 | # encode token to base64 and write to file with rank separated by a space
40 | f.write(base64.b64encode(token) + b" " + str(rank).encode("utf-8") + b"\n")
41 |
--------------------------------------------------------------------------------
/bpeasy/bpeasy.pyi:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | def train_bpe(
4 | iterator: Iterator[str],
5 | python_regex: str,
6 | max_token_length: int,
7 | vocab_size: int,
8 | ) -> dict[bytes, int]: ...
9 |
--------------------------------------------------------------------------------
/bpeasy/convert.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 |
3 | from typing import Optional
4 | from functools import lru_cache
5 | import json
6 |
7 |
8 | # Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
9 | def bpe(
10 | mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None
11 | ) -> list[bytes]:
12 | parts = [bytes([b]) for b in token]
13 | while True:
14 | min_idx = None
15 | min_rank = None
16 | for i, pair in enumerate(zip(parts[:-1], parts[1:])):
17 | rank = mergeable_ranks.get(pair[0] + pair[1])
18 | if rank is not None and (min_rank is None or rank < min_rank):
19 | min_idx = i
20 | min_rank = rank
21 | if min_rank is None or (max_rank is not None and min_rank >= max_rank):
22 | break
23 | assert min_idx is not None
24 | parts = (
25 | parts[:min_idx]
26 | + [parts[min_idx] + parts[min_idx + 1]]
27 | + parts[min_idx + 2 :]
28 | )
29 | return parts
30 |
31 |
32 | # Source taken from https://github.com/huggingface/transformers/blob/73de5108e172112bc620cfc0ceebfd27730dba11/src/transformers/models/gpt2/tokenization_gpt2.py#L63
33 | @lru_cache()
34 | def bytes_to_unicode():
35 | """
36 | Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
37 | characters that the bpe code barfs on.
38 |
39 | The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
40 | if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
41 | decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
42 | tables between utf-8 bytes and unicode strings.
43 | """
44 | bs = (
45 | list(range(ord("!"), ord("~") + 1))
46 | + list(range(ord("¡"), ord("¬") + 1))
47 | + list(range(ord("®"), ord("ÿ") + 1))
48 | )
49 | cs = bs[:]
50 | n = 0
51 | for b in range(2**8):
52 | if b not in bs:
53 | bs.append(b)
54 | cs.append(2**8 + n)
55 | n += 1
56 | cs = [chr(n) for n in cs]
57 | return dict(zip(bs, cs))
58 |
59 |
60 | def convert_tiktoken_to_huggingface(
61 | encoder: tiktoken.Encoding,
62 | out_path: str,
63 | regex_pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
64 | ):
65 | byte_encoder = bytes_to_unicode()
66 |
67 | def token_bytes_to_string(b):
68 | return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
69 |
70 | def generate_vocab_and_merges(encoder):
71 | mergeable_ranks = encoder._mergeable_ranks
72 |
73 | merges = []
74 | vocab = {}
75 | i = 0
76 | for token, rank in mergeable_ranks.items():
77 | # Skip special tokens as they are added separately
78 | if token_bytes_to_string(token) in encoder._special_tokens:
79 | continue
80 |
81 | vocab[token_bytes_to_string(token)] = rank
82 |
83 | i += 1
84 | if len(token) == 1:
85 | continue
86 | merged = tuple(bpe(mergeable_ranks, token, max_rank=rank))
87 | assert len(merged) == 2
88 | merges.append(" ".join(map(token_bytes_to_string, merged)))
89 |
90 | # Also add special tokens
91 | vocab.update(encoder._special_tokens)
92 |
93 | return vocab, merges
94 |
95 | vocab, merges = generate_vocab_and_merges(encoder)
96 |
97 | added_tokens = [
98 | {
99 | "id": id,
100 | "content": content,
101 | "single_word": False,
102 | "lstrip": False,
103 | "rstrip": False,
104 | "normalized": False,
105 | "special": True,
106 | }
107 | for content, id in encoder._special_tokens.items()
108 | ]
109 |
110 | tokenizer_template = {
111 | "version": "1.0",
112 | "truncation": None,
113 | "padding": None,
114 | "added_tokens": added_tokens,
115 | "normalizer": None,
116 | "pre_tokenizer": {
117 | "type": "Sequence",
118 | "pretokenizers": [
119 | {
120 | "type": "Split",
121 | "pattern": {"Regex": regex_pattern},
122 | "behavior": "Removed",
123 | "invert": True,
124 | },
125 | {
126 | "type": "ByteLevel",
127 | "add_prefix_space": False,
128 | "trim_offsets": True,
129 | "use_regex": False,
130 | },
131 | ],
132 | },
133 | "post_processor": None,
134 | "decoder": {
135 | "type": "ByteLevel",
136 | "add_prefix_space": True,
137 | "trim_offsets": True,
138 | "use_regex": True,
139 | },
140 | "model": {
141 | "type": "BPE",
142 | "dropout": None,
143 | "unk_token": None,
144 | "continuing_subword_prefix": "",
145 | "end_of_word_suffix": "",
146 | "fuse_unk": False,
147 | "byte_fallback": False,
148 | "vocab": vocab,
149 | "merges": merges,
150 | },
151 | }
152 |
153 | with open(
154 | out_path,
155 | "w",
156 | encoding="utf-8",
157 | ) as fp:
158 | json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False)
159 |
--------------------------------------------------------------------------------
/bpeasy/tokenizer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import base64
3 | from typing import Iterator
4 |
5 | import tiktoken
6 |
7 | from .bpeasy import train_bpe
8 | from .convert import convert_tiktoken_to_huggingface
9 |
10 |
11 | _DEFAULT_REGEX_PATTERN = r"""[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
12 |
13 |
14 | class BPEasyTokenizer:
15 | def __init__(
16 | self,
17 | vocab: dict[bytes, int],
18 | regex_pattern: str = _DEFAULT_REGEX_PATTERN,
19 | special_tokens: list[str] = [],
20 | fill_to_nearest_multiple_of_eight=False,
21 | name="bpeasy",
22 | ):
23 | """
24 | Wrapper around tiktoken.Encoding
25 | Handles the loading/saving of vocab/special_tokens/regex
26 | """
27 |
28 | self.name = name
29 | self.regex_pattern = regex_pattern
30 | self.special_tokens = special_tokens
31 | self.vocab = vocab
32 |
33 | # Sort the vocab by rank
34 | sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1])
35 |
36 | # add special tokens
37 | special_token_ranks = {}
38 | for special_token in special_tokens:
39 | special_token_ranks[special_token] = len(sorted_vocab)
40 | sorted_vocab.append((special_token.encode("utf-8"), len(sorted_vocab)))
41 |
42 | # fill to nearest multiple of 8 (padding tokens are included in the encoder vocab)
43 | if fill_to_nearest_multiple_of_eight:
44 | while len(sorted_vocab) % 8 != 0:
45 | sorted_vocab.append(
46 | (
47 | f"<|special-{len(sorted_vocab)}|>".encode("utf-8"),
48 | len(sorted_vocab),
49 | )
50 | )
51 |
52 | full_vocab = dict(sorted_vocab)
53 |
54 | self._encoder = tiktoken.Encoding(
55 | name=name,
56 | pat_str=self.regex_pattern,
57 | mergeable_ranks=full_vocab,
58 | special_tokens=special_token_ranks,
59 | )
60 |
61 | def encode(self, text: str, **kwargs) -> list[int]:
62 | return self._encoder.encode(text, **kwargs)
63 |
64 | def decode(self, tokens: list[int], **kwargs) -> str:
65 | return self._encoder.decode(tokens, **kwargs)
66 |
67 | @classmethod
68 | def from_file(cls, file_path: str) -> "BPEasyTokenizer":
69 | with open(file_path, "r") as file:
70 | data = json.load(file)
71 | bytes_vocab = {
72 | base64.b64decode(key): value for key, value in data["vocab"].items()
73 | }
74 | instance = cls(
75 | name=data["name"],
76 | vocab=bytes_vocab,
77 | regex_pattern=data["regex_pattern"],
78 | special_tokens=data["special_tokens"],
79 | )
80 | return instance
81 |
82 | def save(self, file_path: str) -> None:
83 | with open(file_path, "w") as file:
84 | json.dump(
85 | {
86 | "name": self.name,
87 | "regex_pattern": self.regex_pattern,
88 | "special_tokens": self.special_tokens,
89 | "vocab": {
90 | base64.b64encode(key).decode("utf-8"): value
91 | for key, value in self.vocab.items()
92 | },
93 | },
94 | file,
95 | )
96 |
97 | def export_to_huggingface_format(self, out_path: str) -> None:
98 | convert_tiktoken_to_huggingface(self._encoder, out_path, self.regex_pattern)
99 |
100 | def __len__(self) -> int:
101 | return len(self.vocab)
102 |
103 | @classmethod
104 | def train(
105 | cls,
106 | iterator: Iterator[str],
107 | vocab_size: int = 32_000,
108 | max_token_length=128,
109 | regex_pattern: str = _DEFAULT_REGEX_PATTERN,
110 | special_tokens: list[str] = [],
111 | fill_to_nearest_multiple_of_eight=False,
112 | name="bpeasy",
113 | ) -> "BPEasyTokenizer":
114 | bytes_vocab = train_bpe(iterator, regex_pattern, max_token_length, vocab_size)
115 | return cls(
116 | name=name,
117 | vocab=bytes_vocab,
118 | regex_pattern=regex_pattern,
119 | special_tokens=special_tokens,
120 | fill_to_nearest_multiple_of_eight=fill_to_nearest_multiple_of_eight,
121 | )
122 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["maturin>=1.3,<2.0"]
3 | build-backend = "maturin"
4 |
5 | [project]
6 | name = "bpeasy"
7 | requires-python = ">=3.8"
8 | classifiers = [
9 | "Programming Language :: Rust",
10 | "Programming Language :: Python :: Implementation :: CPython",
11 | "Programming Language :: Python :: Implementation :: PyPy",
12 | "Programming Language :: Python :: 3.8",
13 | "Programming Language :: Python :: 3.9",
14 | "Programming Language :: Python :: 3.10",
15 | "Programming Language :: Python :: 3.11",
16 | "Programming Language :: Python :: 3.12",
17 | "Programming Language :: Python :: 3.13",
18 | "License :: OSI Approved :: MIT License",
19 | ]
20 | dynamic = ["version"]
21 | description = "Fast bare-bones BPE for modern tokenizer training"
22 | authors = [{name = "Gautier Dagan", email = ""}]
23 | license = "MIT"
24 | readme = "README.md"
25 | homepage = "https://github.com/gautierdag/bpeasy"
26 | repository = "https://github.com/gautierdag/bpeasy"
27 | include = [
28 | "LICENSE",
29 | ]
30 | keywords = ["tokenizer", "tokenization", "bpe"]
31 | dependencies = [
32 | "tiktoken>=0.4.0",
33 | ]
34 |
35 | [project.optional-dependencies]
36 | dev = ["pytest", "pytest-cov", "black", "tokenizers", "tqdm"]
37 |
38 | [tool.maturin]
39 | features = ["pyo3/extension-module"]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest>=7.1.2
2 | pytest-cov>=3.0.0
3 | maturin>=0.12.14
4 | tiktoken
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | use fancy_regex::Regex;
2 | use fxhash::FxHashMap as HashMap;
3 | use fxhash::FxHashSet as HashSet;
4 | use pyo3::exceptions;
5 | use pyo3::prelude::*;
6 | use pyo3::types::{PyBytes, PyDict, PyIterator, PyString};
7 | use rayon::prelude::*;
8 | use std::cmp::Ordering;
9 | use std::collections::BinaryHeap;
10 |
11 | type Pair = (u32, u32);
12 |
13 | #[derive(Debug, Eq)]
14 | struct Merge {
15 | pair: Pair,
16 | count: i64,
17 | pos: HashSet<usize>,
18 | }
19 | impl PartialEq for Merge {
20 | fn eq(&self, other: &Self) -> bool {
21 | self.count == other.count && self.pair == other.pair
22 | }
23 | }
24 | impl PartialOrd for Merge {
25 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
26 | Some(self.cmp(other))
27 | }
28 | }
29 | impl Ord for Merge {
30 | fn cmp(&self, other: &Self) -> Ordering {
31 | if self.count != other.count {
32 | self.count.cmp(&other.count)
33 | } else {
34 | // Here we want ascending order
35 | other.pair.cmp(&self.pair)
36 | }
37 | }
38 | }
39 |
40 | #[derive(Debug, Clone, Copy)]
41 | struct Symbol {
42 | c: u32,
43 | prev: isize,
44 | next: isize,
45 | len: usize,
46 | }
47 |
48 | #[derive(Debug)]
49 | struct Sentence {
50 | symbols: Vec<Symbol>,
51 | }
52 |
53 | impl Sentence {
54 | fn new() -> Self {
55 | Sentence { symbols: vec![] }
56 | }
57 |
58 | fn add(&mut self, c: u32, byte_len: usize) {
59 | let (prev, next) = {
60 | let len: isize = self.symbols.len() as isize;
61 | if let Some(last) = self.symbols.last_mut() {
62 | // Update `next` on the previous one
63 | last.next = len;
64 | (len - 1, -1)
65 | } else {
66 | (-1, -1)
67 | }
68 | };
69 | self.symbols.push(Symbol {
70 | c,
71 | prev,
72 | next,
73 | len: byte_len,
74 | });
75 | }
76 |
77 | fn merge(&mut self, c1: u32, c2: u32, replacement: u32, max_length: usize) -> Vec<(Pair, i64)> {
78 | let mut changes: Vec<(Pair, i64)> = vec![];
79 | let mut i = 0;
80 | loop {
81 | if i >= self.symbols.len() {
82 | break;
83 | }
84 |
85 | // Found a pair
86 | if self.symbols[i].c == c1 && i + 1 < self.symbols.len() && self.symbols[i + 1].c == c2
87 | {
88 | let first = self.symbols[i];
89 | let second = self.symbols[i + 1];
90 |
91 | // Remove in place
92 | let new_s = Symbol {
93 | c: replacement,
94 | prev: first.prev,
95 | next: second.next,
96 | len: first.len + second.len,
97 | };
98 |
99 | // If there are other characters before the pair
100 | if i > 0 {
101 | changes.push(((self.symbols[i - 1].c, first.c), -1));
102 | if self.symbols[i - 1].len + new_s.len < max_length {
103 | changes.push(((self.symbols[i - 1].c, replacement), 1));
104 | }
105 | }
106 |
107 | self.symbols.insert(i, new_s); // Insert replacement before first char of pair
108 | self.symbols.remove(i + 1); // Remove first char of pair
109 | self.symbols.remove(i + 1); // And then the second
110 |
111 | // If there are other characters after the pair
112 | if i < self.symbols.len() - 1 {
113 | changes.push(((second.c, self.symbols[i + 1].c), -1));
114 | if self.symbols[i + 1].len + new_s.len < max_length {
115 | changes.push(((replacement, self.symbols[i + 1].c), 1));
116 | }
117 | }
118 | }
119 | i += 1;
120 | }
121 | changes
122 | }
123 |
124 | fn get_symbols(&self) -> Vec<u32> {
125 | self.symbols.iter().map(|s| s.c).collect()
126 | }
127 |
128 | fn from_str(s: &str) -> Self {
129 | let mut sentence = Sentence::new();
130 | for byte in s.bytes() {
131 | sentence.add(byte as u32, 1);
132 | }
133 | sentence
134 | }
135 | }
136 |
137 | fn pretokenize<'a>(text: &'a str, regex: &Regex) -> Vec<&'a str> {
138 | regex
139 | .find_iter(text)
140 | .filter_map(|mat| match mat {
141 | Ok(m) => Some(m.as_str()),
142 | Err(_) => None,
143 | })
144 | .collect()
145 | }
146 |
147 | fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec<u64>) {
148 | let regex: Regex = Regex::new(pattern).expect("Invalid regex pattern");
149 | // Tokenize strings in parallel
150 | let (tokens, counts): (Vec<&str>, Vec<u64>) = strings
151 | .par_iter()
152 | .flat_map(|&text| pretokenize(text, ®ex))
153 | .fold(
154 | || HashMap::<&str, u64>::default(),
155 | |mut acc, token| {
156 | *acc.entry(token).or_insert(0) += 1;
157 | acc
158 | },
159 | )
160 | .reduce(
161 | || HashMap::<&str, u64>::default(),
162 | |mut a, b| {
163 | for (token, count) in b {
164 | *a.entry(token).or_insert(0) += count;
165 | }
166 | a
167 | },
168 | )
169 | .into_iter()
170 | .unzip();
171 |
172 | // Convert tokens to sentences and filter sentences and counts to remove single byte sentences
173 | let (filtered_sentences, filtered_counts): (Vec<Sentence>, Vec<u64>) = tokens
174 | .into_iter()
175 | .map(Sentence::from_str)
176 | .zip(counts.into_iter())
177 | .filter(|(sentence, _)| sentence.symbols.len() > 1)
178 | .unzip();
179 |
180 | (filtered_sentences, filtered_counts)
181 | }
182 |
183 | fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<u8>>) {
184 | let mut word_to_id: HashMap<Vec<u8>, u32> = HashMap::default();
185 | let mut id_to_word: Vec<Vec<u8>> = Vec::with_capacity(vocab_size);
186 | for i in 0..=255 {
187 | word_to_id.insert(vec![i], i as u32);
188 | id_to_word.push(vec![i]);
189 | }
190 | return (word_to_id, id_to_word);
191 | }
192 |
193 | fn get_most_frequent_pair(
194 | tokenized_sentences: &[Sentence],
195 | base_counts: &[u64],
196 | ) -> (HashMap<Pair, i64>, HashMap<Pair, HashSet<usize>>) {
197 | // Calculate frequencies for each pair of bytes in all sentences and words
198 | tokenized_sentences
199 | .par_iter()
200 | .enumerate()
201 | .map(|(i, sentence)| {
202 | let mut local_pair_counts = HashMap::<Pair, i64>::default();
203 | let mut local_pair_positions: HashMap<Pair, HashSet<usize>> = HashMap::default();
204 |
205 | for window in sentence.get_symbols().windows(2) {
206 | let current_pair: Pair = (window[0], window[1]);
207 | // First update counts
208 | local_pair_counts
209 | .entry(current_pair)
210 | .and_modify(|c| *c += base_counts[i] as i64)
211 | .or_insert(base_counts[i] as i64);
212 |
213 | // Then update position
214 | local_pair_positions
215 | .entry(current_pair)
216 | .and_modify(|h: &mut HashSet<usize>| {
217 | h.insert(i);
218 | })
219 | .or_insert_with(|| {
220 | let mut h = HashSet::<usize>::default();
221 | h.insert(i);
222 | h
223 | });
224 | }
225 | (local_pair_counts, local_pair_positions)
226 | })
227 | .reduce(
228 | || {
229 | (
230 | HashMap::<Pair, i64>::default(),
231 | HashMap::<Pair, HashSet<usize>>::default(),
232 | )
233 | },
234 | |(mut global_pair_counts, mut global_pair_positions), (pc, wtu)| {
235 | // Merge the pair counts and positions from all sentences
236 | for (k, v) in pc {
237 | global_pair_counts
238 | .entry(k)
239 | .and_modify(|c| *c += v)
240 | .or_insert(v);
241 | }
242 | for (k, v) in wtu {
243 | global_pair_positions
244 | .entry(k)
245 | .and_modify(|set| *set = set.union(&v).copied().collect())
246 | .or_insert(v);
247 | }
248 | (global_pair_counts, global_pair_positions)
249 | },
250 | )
251 | }
252 |
253 | // Build vocab from most frequent pairs
254 | fn build_bpe_vocab(
255 | tokenized_sentences: Vec<Sentence>,
256 | base_counts: &[u64],
257 | max_token_length: usize,
258 | vocab_size: usize,
259 | ) -> HashMap<Vec<u8>, u32> {
260 | let (mut word_to_id, mut id_to_word) = initialize_vocab_bytes(vocab_size);
261 |
262 | // get most frequent pair
263 | let (mut global_pair_counts, mut global_pair_positions) =
264 | get_most_frequent_pair(&tokenized_sentences, &base_counts);
265 |
266 | // build Priority Queue from counts and positions
267 | let mut queue: BinaryHeap<Merge> = BinaryHeap::new();
268 | global_pair_positions.drain().for_each(|(pair, pos)| {
269 | let count: i64 = global_pair_counts[&pair];
270 | if count > 0 {
271 | queue.push(Merge { pair, count, pos });
272 | }
273 | });
274 |
275 | while word_to_id.len() < vocab_size {
276 | // check if queue is empty
277 | if queue.is_empty() {
278 | break;
279 | }
280 |
281 | let mut top = queue.pop().unwrap();
282 | // check if count has changed
283 | if top.count != global_pair_counts[&top.pair] {
284 | top.count = global_pair_counts[&top.pair];
285 | queue.push(top);
286 | continue;
287 | }
288 |
289 | // exit if the best remaining pair count has dropped to 0
290 | if top.count < 1 {
291 | break;
292 | }
293 |
294 | // add to vocab
295 | let (left, right) = top.pair;
296 | let merged_id = word_to_id.len() as u32;
297 |
298 | let mut word = id_to_word[left as usize].clone();
299 | let right_word = id_to_word[right as usize].clone();
300 | word.extend(right_word.iter());
301 | word_to_id.insert(word.clone(), merged_id);
302 | id_to_word.push(word);
303 |
304 | // update counts and positions for each sentence
305 | let changes = top
306 | .pos
307 | .par_iter()
308 | .flat_map(|&i| {
309 | let sentence = &tokenized_sentences[i] as *const _ as *mut Sentence;
310 | // We can merge each of these sentences in parallel here because each position
311 | // can be there only once (HashSet). So this is safe.
312 | unsafe {
313 | (*sentence)
314 | .merge(top.pair.0, top.pair.1, merged_id, max_token_length)
315 | .into_iter()
316 | .map(|c| (c, i))
317 | .collect::<Vec<_>>()
318 | }
319 | })
320 | .collect::<Vec<_>>();
321 |
322 | for ((pair, change), iw) in changes {
323 | // adjust count to reflect sentence level count
324 | let count = change * base_counts[iw] as i64;
325 | global_pair_counts
326 | .entry(pair)
327 | .and_modify(|c| *c += count)
328 | .or_insert(count);
329 | if count > 0 {
330 | global_pair_positions
331 | .entry(pair)
332 | .and_modify(|h| {
333 | h.insert(iw);
334 | })
335 | .or_insert_with(|| {
336 | let mut h = HashSet::<usize>::default();
337 | h.insert(iw);
338 | h
339 | });
340 | }
341 | }
342 |
343 | // update queue
344 | global_pair_positions.drain().for_each(|(pair, pos)| {
345 | let count = global_pair_counts[&pair];
346 | if count > 0 {
347 | queue.push(Merge { pair, count, pos });
348 | }
349 | });
350 | }
351 | word_to_id
352 | }
353 |
354 | // Train BPE from Iterator
355 | #[pyfunction]
356 | fn train_bpe(
357 | py: Python,
358 | iterator: &PyIterator,
359 | python_regex: &PyString,
360 | max_token_length: usize,
361 | vocab_size: usize,
362 | ) -> PyResult<PyObject> {
363 | let regex = python_regex.to_str()?;
364 |
365 | // validate inputs
366 | if max_token_length < 2 {
367 | return Err(exceptions::PyValueError::new_err(
368 | "max_token_length must be greater than 1",
369 | ));
370 | }
371 | if vocab_size < 256 {
372 | return Err(exceptions::PyValueError::new_err(
373 | "vocab_size must be greater than 256",
374 | ));
375 | }
376 | if regex.is_empty() {
377 | return Err(exceptions::PyValueError::new_err("regex cannot be empty"));
378 | }
379 |
380 | // Extract strings from Python iterator and store them in a Rust Vec for parallel processing
381 | let strings: Vec<&str> = iterator
382 | .filter_map(|item_result| {
383 | item_result.ok().and_then(|item| {
384 | item.extract::<&PyString>()
385 | .ok()
386 | .and_then(|py_string| py_string.to_str().ok())
387 | })
388 | })
389 | .filter(|text| !text.is_empty())
390 | .collect();
391 |
392 | let (pretokenized_sentences, counts): (Vec<Sentence>, Vec<u64>) =
393 | pretokenize_strings(strings, regex);
394 |
395 | let bpe_vocab = build_bpe_vocab(
396 | pretokenized_sentences,
397 | &counts,
398 | max_token_length,
399 | vocab_size,
400 | );
401 | let python_dict_out = PyDict::new(py);
402 | // convert bpe_vocab to python dict
403 | for (key, value) in bpe_vocab {
404 | let py_key = PyBytes::new(py, &key);
405 | python_dict_out.set_item(py_key, value)?;
406 | }
407 | Ok(python_dict_out.into())
408 | }
409 |
410 | /// bpeasy is a bare-bones implementation of byte-pair encoding (BPE) in Rust.
411 | /// It is designed to be used as a Python module and returns a byte-pair vocabulary
412 | /// as a Python dictionary.
413 | #[pymodule]
414 | fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
415 | m.add_function(wrap_pyfunction!(train_bpe, m)?)?;
416 | Ok(())
417 | }
418 |
419 | #[cfg(test)]
420 | mod tests {
421 | #[test]
422 | fn test_all() {
423 | let text: &str = "\tYou hear a £ £ £ here";
424 | let pattern = r"([^\s]+)|(\s+)";
425 | let compiled_regex: fancy_regex::Regex =
426 | fancy_regex::Regex::new(pattern).expect("Invalid regex pattern");
427 | let pretokenized_sentences = crate::pretokenize(text, &compiled_regex);
428 | assert_eq!(
429 | pretokenized_sentences,
430 | vec!["\t", "You", " ", "hear", " ", "a", " ", "£", " ", "£", " ", "£", " ", "here"]
431 | );
432 |
433 | let text_2: &str = "You hear £ £ £ here";
434 |
435 | let (pretokenized_sentences, _counts) =
436 | crate::pretokenize_strings(vec![text, text_2], pattern);
437 |
438 | let vocab_size = 300;
439 | let max_token_length = 128;
440 | crate::build_bpe_vocab(
441 | pretokenized_sentences,
442 | &_counts,
443 | max_token_length,
444 | vocab_size,
445 | );
446 | }
447 |
448 | #[test]
449 | fn test_initialize_vocab_bytes() {
450 | let vocab = crate::initialize_vocab_bytes(400);
451 | assert_eq!(vocab.0.len(), 256);
452 | }
453 | }
454 |
--------------------------------------------------------------------------------
/tests/test_convert.py:
--------------------------------------------------------------------------------
1 | from bpeasy.convert import bpe
2 |
3 |
4 | def test_bpe_function():
5 | mergeable_ranks = {b"ab": 0, b"bc": 1, b"cd": 2}
6 | token = b"abcd"
7 | result = bpe(mergeable_ranks, token)
8 | assert result == [
9 | b"ab",
10 | b"cd",
11 | ], "The bpe function did not split the token correctly"
12 |
--------------------------------------------------------------------------------
/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | from unittest import mock
4 | from bpeasy.tokenizer import BPEasyTokenizer
5 |
6 |
7 | def test_initialization():
8 | vocab = {b"hello": 1, b"world": 2}
9 | tokenizer = BPEasyTokenizer(vocab=vocab)
10 | assert tokenizer.vocab == vocab
11 | assert tokenizer.name == "bpeasy"
12 | assert len(tokenizer.special_tokens) == 0
13 | assert len(tokenizer) == 2
14 |
15 |
16 | def test_encode_decode():
17 | vocab = {b"hello": 1, b" world": 2}
18 | tokenizer = BPEasyTokenizer(vocab=vocab)
19 | encoded = tokenizer.encode("hello world", allowed_special="all")
20 | assert encoded == [1, 2]
21 | decoded = tokenizer.decode(encoded)
22 | assert decoded == "hello world"
23 |
24 |
25 | def test_save_and_load():
26 | vocab = {b"hello": 1, b" world": 2}
27 | tokenizer = BPEasyTokenizer(vocab=vocab)
28 |
29 | # Test saving
30 | with mock.patch("builtins.open", mock.mock_open()) as mock_file:
31 | tokenizer.save("dummy_path.json")
32 | mock_file.assert_called_once_with("dummy_path.json", "w")
33 |
34 | # Prepare dummy file content for loading
35 | dummy_file_content = json.dumps(
36 | {
37 | "name": "bpeasy",
38 | "vocab": {
39 | base64.b64encode(key).decode("utf-8"): value
40 | for key, value in vocab.items()
41 | },
42 | "regex_pattern": tokenizer.regex_pattern,
43 | "special_tokens": tokenizer.special_tokens,
44 | }
45 | )
46 |
47 | # Test loading
48 | with mock.patch(
49 | "builtins.open", mock.mock_open(read_data=dummy_file_content)
50 | ) as mock_file:
51 | loaded_tokenizer = BPEasyTokenizer.from_file("dummy_path.json")
52 | assert loaded_tokenizer.vocab == vocab
53 |
54 |
55 | @mock.patch("builtins.open", new_callable=mock.mock_open)
56 | @mock.patch("json.dump")
57 | def test_conversion_to_huggingface(mock_json_dump, mock_open):
58 | vocab = {
59 | b"h": 0,
60 | b"e": 1,
61 | b"l": 2,
62 | b"o": 3,
63 | b" ": 4,
64 | b"w": 5,
65 | b"r": 6,
66 | b"d": 7,
67 | b"he": 8,
68 | b"ll": 9,
69 | b"llo": 10,
70 | b"hello": 11,
71 | b"wo": 12,
72 | b"wor": 13,
73 | b"ld": 14,
74 | b"world": 15,
75 | b" world": 16,
76 | }
77 | tokenizer = BPEasyTokenizer(vocab=vocab)
78 | tokenizer.export_to_huggingface_format("dummy_path.json")
79 | mock_open.assert_called_once_with("dummy_path.json", "w", encoding="utf-8")
80 | mock_json_dump.assert_called_once()
81 | args, _ = mock_json_dump.call_args
82 | assert args[0]["model"]["type"] == "BPE"
83 |
84 |
85 | @mock.patch("builtins.open", new_callable=mock.mock_open)
86 | @mock.patch("json.dump")
87 | def test_conversion_to_huggingface_with_special_tokens(mock_json_dump, mock_open):
88 | vocab = {
89 | b"h": 0,
90 | b"e": 1,
91 | b"l": 2,
92 | b"o": 3,
93 | b" ": 4,
94 | b"w": 5,
95 | b"r": 6,
96 | b"d": 7,
97 | b"he": 8,
98 | b"ll": 9,
99 | b"llo": 10,
100 | b"hello": 11,
101 | b"wo": 12,
102 | b"wor": 13,
103 | b"ld": 14,
104 | b"world": 15,
105 | b" world": 16,
106 | }
107 | tokenizer = BPEasyTokenizer(vocab=vocab, special_tokens=["<|special-0|>", "<s>"])
108 | tokenizer.export_to_huggingface_format("dummy_path.json")
109 | mock_open.assert_called_once_with("dummy_path.json", "w", encoding="utf-8")
110 | mock_json_dump.assert_called_once()
111 | args, _ = mock_json_dump.call_args
112 | assert args[0]["model"]["type"] == "BPE"
113 |
--------------------------------------------------------------------------------
/tests/test_train_bpe.py:
--------------------------------------------------------------------------------
1 | from bpeasy import train_bpe
2 |
3 |
4 | def test_train_bpe_vocab_size():
5 | vocab_size = 300
6 | max_token_length = 4
7 | regex = r"([^\s]+)|(\s+)"
8 | vocab = train_bpe(
9 | iter(["This is a test", "this is another test", "good tests"]),
10 | regex,
11 | max_token_length,
12 | vocab_size,
13 | )
14 | assert len(vocab) == 267
15 |
16 |
17 | def test_train_bpe_max_token_length():
18 | vocab_size = 300
19 | max_token_length = 2
20 | regex = r"([^\s]+)|(\s+)"
21 | vocab = train_bpe(
22 | iter(["This is a test", "this is another test", "good tests"]),
23 | regex,
24 | max_token_length,
25 | vocab_size,
26 | )
27 | for token in vocab:
28 | assert len(token) <= max_token_length
29 | max_token_length = 3
30 | vocab = train_bpe(
31 | iter(["This is a test", "this is another test", "good tests"]),
32 | regex,
33 | max_token_length,
34 | vocab_size,
35 | )
36 | for token in vocab:
37 | assert len(token) <= max_token_length
38 |
39 |
40 | def test_train_bpe_gpt_regex():
41 | vocab_size = 300
42 | max_token_length = 128
43 | regex = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44 | vocab = train_bpe(
45 | iter(["We've got a test", "We've got good test", "this is a good tests"]),
46 | regex,
47 | max_token_length,
48 | vocab_size,
49 | )
50 | for token in vocab:
51 | assert len(token) <= max_token_length
52 |
53 | print(vocab)
54 | assert b" go" in vocab.keys()
55 | assert b"'ve" in vocab.keys()
56 |
--------------------------------------------------------------------------------
/tests/test_utils_bpe.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from unittest.mock import mock_open, patch, call
3 |
4 | from bpeasy import save_vocab_to_tiktoken
5 |
6 |
7 | def test_basic_functionality():
8 | vocab = {b"hello": 1, b"world": 2}
9 | with patch("builtins.open", mock_open()) as mock_file:
10 | save_vocab_to_tiktoken(vocab, "path/to/file.txt")
11 | mock_file.assert_called_once_with("path/to/file.txt", "wb")
12 | # Check if the sorted vocab is written correctly
13 | expected_content = [
14 | base64.b64encode(b"hello") + b" 1\n",
15 | base64.b64encode(b"world") + b" 2\n",
16 | ]
17 | mock_file().write.assert_has_calls(
18 | [call(content) for content in expected_content]
19 | )
20 |
21 |
22 | def test_special_tokens_addition():
23 | vocab = {b"token": 0}
24 | special_tokens = ["special1", "special2"]
25 | with patch("builtins.open", mock_open()) as mock_file:
26 | save_vocab_to_tiktoken(vocab, "path/to/file.txt", special_tokens)
27 | # Check if special tokens are added correctly
28 | expected_content = [
29 | base64.b64encode(b"token") + b" 0\n",
30 | base64.b64encode("special1".encode("utf-8")) + b" 1\n",
31 | base64.b64encode("special2".encode("utf-8")) + b" 2\n",
32 | ]
33 | mock_file().write.assert_has_calls(
34 | [call(content) for content in expected_content]
35 | )
36 |
37 |
38 | def test_fill_to_nearest_multiple_of_eight():
39 | vocab = {b"token": 0}
40 | with patch("builtins.open", mock_open()) as mock_file:
41 | save_vocab_to_tiktoken(
42 | vocab, "path/to/file.txt", fill_to_nearest_multiple_of_eight=True
43 | )
44 | # Verify that additional tokens are added to make the count a multiple of eight
45 | mock_file().write.assert_called()
46 | # Check the exact content based on your logic of filling to nearest multiple of eight
47 | assert mock_file().write.call_count == 8
48 |
--------------------------------------------------------------------------------