├── tests ├── __init__.py ├── test_control.py ├── test_chat_template.py ├── test_embeddings.py ├── test_utils.py ├── test_tokenizer.py └── test_logits_processor.py ├── experiments ├── language-modelling │ ├── .gitignore │ ├── clm-byte.sh │ ├── clm-byte-bit.sh │ ├── clm-byte-t5.sh │ ├── clm-bpe-gpt2.sh │ ├── analyze.py │ ├── models │ │ └── utf8-lm-tiny.sh │ ├── README.md │ └── run_clm.py └── benchmark.py ├── .gitignore ├── utf8_tokenizer ├── __init__.py ├── byt5_comparison.py ├── control.py ├── chat_template.jinja ├── utils.py ├── embeddings.py ├── tokenizer.py └── logits_processor.py ├── .github └── workflows │ ├── lint.yaml │ ├── test.yaml │ └── release.yaml ├── LICENSE ├── pyproject.toml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/language-modelling/.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | output-clm-*/ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .claude/ 3 | utf8_tokenizer.egg-info 4 | build/ 5 | dist/ 6 | .env 7 | __pycache__/ 8 | *.pyc 9 | -------------------------------------------------------------------------------- /utf8_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | """UTF-8 Tokenizer package.""" 2 | 3 | from utf8_tokenizer.logits_processor import UTF8ValidationLogitsProcessor 4 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 5 | 6 | __all__ = ["UTF8Tokenizer", "UTF8ValidationLogitsProcessor"] 7 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | 11 | jobs: 12 | lint: 13 | name: Lint 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v6 18 | 19 | - name: Setup uv 20 | uses: astral-sh/setup-uv@v7 21 | with: 22 | python-version: "3.12" 23 | enable-cache: true 24 | activate-environment: true 25 | 26 | - name: Install dependencies 27 | run: uv pip install ".[dev]" 28 | 29 | - name: Lint code 30 | run: uv run ruff check . 
-------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | 11 | jobs: 12 | test: 13 | name: Test 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v6 18 | 19 | - name: Setup uv 20 | uses: astral-sh/setup-uv@v7 21 | with: 22 | python-version: "3.12" 23 | enable-cache: true 24 | activate-environment: true 25 | 26 | - name: Install dependencies 27 | run: uv pip install ".[dev]" 28 | 29 | - name: Test Code 30 | run: uv run pytest -n auto --dist loadscope 31 | -------------------------------------------------------------------------------- /experiments/language-modelling/clm-byte.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="clm-bit-embeddings" 2 | export WANDB_NAME="clm-byte" 3 | export WANDB_RUN_GROUP="byte" 4 | 5 | python run_clm.py \ 6 | --output_dir ./output-clm-byte \ 7 | --dataset_name wikitext \ 8 | --dataset_config_name wikitext-2-raw-v1 \ 9 | --do_train True \ 10 | --do_eval True \ 11 | --eval_strategy epoch \ 12 | --save_strategy epoch \ 13 | --save_total_limit 1 \ 14 | --logging_steps 100 \ 15 | --logging_strategy steps \ 16 | --num_train_epochs 1 \ 17 | --model_name_or_path sbintuitions/tiny-lm \ 18 | --per_device_train_batch_size 8 \ 19 | --per_device_eval_batch_size 8 \ 20 | --block_size 512 \ 21 | --optim adamw_torch_fused \ 22 | --bf16 True \ 23 | --seed 42 \ 24 | --report_to wandb \ 25 | --include_num_input_tokens_seen True 26 | -------------------------------------------------------------------------------- /experiments/language-modelling/clm-byte-bit.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="clm-bit-embeddings" 2 | export WANDB_NAME="clm-byte-bit" 3 | export WANDB_RUN_GROUP="byte-bit" 4 | 5 | python run_clm.py \ 6 | --use_bit_embeddings True \ 7 | --output_dir ./output-clm-byte-bit \ 8 | --dataset_name wikitext \ 9 | --dataset_config_name wikitext-2-raw-v1 \ 10 | --do_train True \ 11 | --do_eval True \ 12 | --eval_strategy epoch \ 13 | --save_strategy epoch \ 14 | --save_total_limit 1 \ 15 | --logging_steps 100 \ 16 | --logging_strategy steps \ 17 | --num_train_epochs 1 \ 18 | --model_name_or_path sbintuitions/tiny-lm \ 19 | --per_device_train_batch_size 8 \ 20 | --per_device_eval_batch_size 8 \ 21 | --block_size 512 \ 22 | --optim adamw_torch_fused \ 23 | --bf16 True \ 24 | --seed 42 \ 25 | --report_to wandb \ 26 | --include_num_input_tokens_seen True 27 | -------------------------------------------------------------------------------- /experiments/language-modelling/clm-byte-t5.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="clm-bit-embeddings" 2 | export WANDB_NAME="clm-byte-t5" 3 | export WANDB_RUN_GROUP="byte-t5" 4 | 5 | python run_clm.py \ 6 | --output_dir ./output-clm-byte-t5 \ 7 | --tokenizer_name "google/byt5-small" \ 8 | --dataset_name wikitext \ 9 | --dataset_config_name wikitext-2-raw-v1 \ 10 | --do_train True \ 11 | --do_eval True \ 12 | --eval_strategy epoch \ 13 | --save_strategy epoch \ 14 | --save_total_limit 1 \ 15 | --logging_steps 100 \ 16 | --logging_strategy steps \ 17 | --num_train_epochs 1 \ 18 | --model_name_or_path sbintuitions/tiny-lm \ 19 | --per_device_train_batch_size 8 \ 20 | 
--per_device_eval_batch_size 8 \ 21 | --block_size 512 \ 22 | --optim adamw_torch_fused \ 23 | --bf16 True \ 24 | --seed 42 \ 25 | --report_to wandb \ 26 | --include_num_input_tokens_seen True 27 | -------------------------------------------------------------------------------- /experiments/language-modelling/clm-bpe-gpt2.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="clm-bit-embeddings" 2 | export WANDB_NAME="clm-bpe-gpt2" 3 | export WANDB_RUN_GROUP="bpe-gpt2" 4 | 5 | python run_clm.py \ 6 | --output_dir ./output-clm-bpe-gpt2 \ 7 | --tokenizer_name "openai-community/gpt2" \ 8 | --dataset_name wikitext \ 9 | --dataset_config_name wikitext-2-raw-v1 \ 10 | --do_train True \ 11 | --do_eval True \ 12 | --eval_strategy epoch \ 13 | --save_strategy epoch \ 14 | --save_total_limit 1 \ 15 | --logging_steps 100 \ 16 | --logging_strategy steps \ 17 | --num_train_epochs 1 \ 18 | --model_name_or_path sbintuitions/tiny-lm \ 19 | --per_device_train_batch_size 8 \ 20 | --per_device_eval_batch_size 8 \ 21 | --block_size 512 \ 22 | --optim adamw_torch_fused \ 23 | --bf16 True \ 24 | --seed 42 \ 25 | --report_to wandb \ 26 | --include_num_input_tokens_seen True 27 | -------------------------------------------------------------------------------- /experiments/language-modelling/analyze.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import load_state_dict_from_file 2 | from transformers import AutoModelForCausalLM, AutoConfig 3 | import torch 4 | from safetensors.torch import save_file 5 | 6 | from utf8_tokenizer.embeddings import patch_embedding_layers, join_embedding_layers 7 | 8 | MODEL_CHECKPOINT = "./output-tiny-lm-fineweb" 9 | 10 | 11 | # Load model 12 | config = AutoConfig.from_pretrained(MODEL_CHECKPOINT) 13 | model = AutoModelForCausalLM.from_config(config) 14 | patch_embedding_layers(model) 15 | 16 | state_dict = load_state_dict_from_file(f"{MODEL_CHECKPOINT}/model.safetensors") 17 | model.load_state_dict(state_dict) 18 | 19 | # Inspect bit projection weights 20 | embeddings = model.get_input_embeddings() 21 | print(embeddings.bit_proj_w.data) 22 | 23 | # Save weights to file 24 | torch.save(embeddings.bit_proj_w.data, f"{MODEL_CHECKPOINT}/bit_projection_weights.pt") 25 | 26 | # Join embedding layers back 27 | join_embedding_layers(model) 28 | save_file(model.state_dict(), f"{MODEL_CHECKPOINT}/model.safetensors") 29 | 30 | model = AutoModelForCausalLM.from_pretrained(MODEL_CHECKPOINT) 31 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | on: 3 | release: 4 | types: [ created ] 5 | 6 | jobs: 7 | pypi-publish: 8 | name: Upload release to PyPI 9 | runs-on: ubuntu-latest 10 | environment: 11 | name: pypi 12 | url: https://pypi.org/p/utf8-tokenizer 13 | permissions: 14 | id-token: write 15 | steps: 16 | - uses: actions/checkout@v5 17 | 18 | - uses: actions/setup-python@v6 19 | with: 20 | python-version: "3.12" 21 | 22 | - name: Extract release version 23 | id: get_version 24 | run: echo "version=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV 25 | 26 | - name: Update version in pyproject.toml 27 | run: | 28 | sed -i 's/^version = .*/version = "${{ env.version }}"/' pyproject.toml 29 | 30 | - name: Install build dependencies 31 | run: pip install build 32 | 33 | - name: Build a binary wheel dist 34 | run: | 
35 | rm -rf dist 36 | python -m build 37 | 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 sign 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utf8_tokenizer/byt5_comparison.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from transformers import ByT5Tokenizer 4 | 5 | 6 | class ByT5ComparableTokenizer(ByT5Tokenizer): 7 | def __init__(self, *args, **kwargs): 8 | kwargs["unk_token"] = kwargs.get("unk_token", "") 9 | kwargs["bos_token"] = kwargs.get("bos_token", kwargs["unk_token"]) 10 | # Aim for 256 bytes + 3 special tokens + 5 extra ids 11 | kwargs["extra_ids"] = kwargs.get("extra_ids", 5) 12 | super().__init__(*args, **kwargs) 13 | 14 | def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]: 15 | token_ids = super()._add_eos_if_not_present(token_ids) 16 | 17 | # ByT5Tokenizer does not add BOS token by default, so we add it here 18 | if len(token_ids) > 0 and token_ids[0] == self.bos_token_id: 19 | warnings.warn( 20 | f"This sequence already has {self.bos_token_id}. In future versions this behavior may lead to " 21 | f"duplicated bos tokens being added.", 22 | stacklevel=2, 23 | ) 24 | return token_ids 25 | 26 | return [self.bos_token_id] + token_ids 27 | -------------------------------------------------------------------------------- /tests/test_control.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from utf8_tokenizer.control import ControlTokens, visualize_control_tokens 4 | 5 | message = f""" 6 | {ControlTokens.StartOfText}{ControlTokens.StartOfHeading}system 7 | {ControlTokens.ShiftOut}You are a helpful assistant{ControlTokens.ShiftIn}{ControlTokens.EndOfTransmissionBlock} 8 | {ControlTokens.StartOfHeading}user 9 | {ControlTokens.ShiftOut}How much is 1+2?{ControlTokens.ShiftIn}{ControlTokens.EndOfTransmissionBlock} 10 | {ControlTokens.StartOfHeading}assistant 11 | First I'll think about it. 
12 | {ControlTokens.Enquiry} The user wants me to calculate, I should call the calculator 13 | {ControlTokens.Substitute}{{``type'': ``calculator'', ``expression'': ``1+2''}} 14 | {ControlTokens.Escape}3{ControlTokens.Acknowledge} 15 | 1 + 2 = 3{ControlTokens.EndOfTransmissionBlock}{ControlTokens.EndOfText} 16 | {ControlTokens.Null}{ControlTokens.Null}{ControlTokens.Null}{ControlTokens.Null} 17 | """.strip() 18 | 19 | visualized_message = """ 20 | ␂␁system 21 | ␎You are a helpful assistant␏␗ 22 | ␁user 23 | ␎How much is 1+2?␏␗ 24 | ␁assistant 25 | First I'll think about it. 26 | ␅ The user wants me to calculate, I should call the calculator 27 | ␚{``type'': ``calculator'', ``expression'': ``1+2''} 28 | ␛3␆ 29 | 1 + 2 = 3␗␃ 30 | ␀␀␀␀ 31 | """.strip() 32 | 33 | 34 | def test_control_visualization(): 35 | result = visualize_control_tokens(message) 36 | print("result", result) 37 | 38 | assert result == visualized_message 39 | 40 | 41 | if __name__ == "__main__": 42 | pytest.main([__file__, "-v"]) 43 | -------------------------------------------------------------------------------- /experiments/language-modelling/models/utf8-lm-tiny.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="utf8" 2 | export WANDB_NAME="tiny-lm-fineweb" 3 | 4 | # tiny-lm using utf8 contains 3m parameters 5 | # Chinchilla scaling law optimal number of tokens is 20x number of parameters 6 | # So for 3.2m parameters, we want to train on 64m tokens 7 | # But! this was calculated for models with standard embeddings, so assuming 16m parameters, 8 | # instead, we train on (128 batch size * 256 block size * 10,000 steps) = 327m tokens 9 | 10 | python run_clm.py \ 11 | --use_bit_embeddings True \ 12 | --output_dir ./output-tiny-lm-fineweb \ 13 | --dataset_name HuggingFaceFW/fineweb \ 14 | --streaming True \ 15 | --dataloader_num_workers 1 \ 16 | --dataloader_prefetch_factor 4 \ 17 | --dataloader_pin_memory True \ 18 | --dataloader_persistent_workers True \ 19 | --do_train True \ 20 | --save_strategy steps \ 21 | --max_steps 20000 \ 22 | --save_steps 1000 \ 23 | --save_total_limit 2 \ 24 | --logging_steps 100 \ 25 | --logging_strategy steps \ 26 | --model_name_or_path sbintuitions/tiny-lm \ 27 | --per_device_train_batch_size 128 \ 28 | --block_size 256 \ 29 | --optim adamw_torch_fused \ 30 | --learning_rate 3e-4 \ 31 | --lr_scheduler_type cosine \ 32 | --warmup_ratio 0.01 \ 33 | --weight_decay 0.1 \ 34 | --adam_beta1 0.9 \ 35 | --adam_beta2 0.95 \ 36 | --max_grad_norm 1.0 \ 37 | --gradient_checkpointing True \ 38 | --bf16 True \ 39 | --seed 42 \ 40 | --report_to wandb \ 41 | --include_num_input_tokens_seen True 42 | 43 | # hf upload sign/utf8-lm-tiny . . 
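# Arithmetic check (sketch of the budget described above): 128 (batch size) * 256 (block size) = 32,768 tokens per step,
# so 10,000 steps come to roughly 327.7M tokens; the configured --max_steps of 20,000 corresponds to about double that (~655M tokens).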
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "utf8-tokenizer" 3 | description = "True UTF-8 tokenizer for byte level models" 4 | version = "0.0.1" 5 | authors = [ 6 | { name = "Amit Moryossef", email = "amit@sign.mt" }, 7 | ] 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | dependencies = [ 11 | "transformers[torch]", 12 | "torch" 13 | ] 14 | 15 | [project.optional-dependencies] 16 | dev = [ 17 | "ruff", 18 | "pytest", 19 | "pytest-xdist", # For parallel test execution 20 | ] 21 | 22 | fast = [ 23 | "numba", # For JIT-compiled padding (faster tokenization) 24 | ] 25 | 26 | train = [ 27 | "datasets", # For dataset loading and processing 28 | "evaluate", # For evaluation metrics 29 | "scikit-learn", # For "accuracy" metric in evaluate 30 | "wandb", # For experiment tracking 31 | ] 32 | 33 | [tool.setuptools] 34 | packages = [ 35 | "utf8_tokenizer", 36 | ] 37 | 38 | [tool.setuptools.package-data] 39 | utf8_tokenizer = ["*.jinja"] 40 | 41 | [tool.ruff] 42 | line-length = 120 43 | extend-exclude = [ 44 | "experiments/*", 45 | ] 46 | 47 | [tool.ruff.lint] 48 | select = [ 49 | "E", # pycodestyle errors 50 | "W", # pycodestyle warnings 51 | "F", # pyflakes 52 | "C90", # mccabe complexity 53 | "I", # isort 54 | "N", # pep8-naming 55 | "UP", # pyupgrade 56 | "B", # flake8-bugbear 57 | "PT", # flake8-pytest-style 58 | "W605", # invalid escape sequence 59 | "BLE", # flake8-blind-except 60 | "TRY", # tryceratops 61 | ] 62 | 63 | [tool.pytest.ini_options] 64 | addopts = "-v" 65 | testpaths = ["utf8_tokenizer", "tests"] 66 | -------------------------------------------------------------------------------- /experiments/language-modelling/README.md: -------------------------------------------------------------------------------- 1 | # Bit-Biased Byte-Level Language Modelling 2 | 3 | > [!TIP] 4 | > This trains on a macOS system with an M4 chip in about 10 minutes. 5 | 6 | 7 | We modified the `run_clm.py` script in this directory. 8 | See the modifications by diffing it against: 9 | https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py 10 | 11 | - We force the use of `UTF8Tokenizer`. 12 | - We add a `--use_bit_embeddings` argument to enable `patch_embedding_layers`. 13 | 14 | 15 | ## Setup 16 | 17 | ```shell 18 | # install the library and training dependencies.
19 | pip install ".[train]" 20 | 21 | # Login to Weights & Biases 22 | wandb login 23 | export WANDB_PROJECT="clm-bit-embeddings" 24 | ``` 25 | 26 | ## Training 27 | 28 | We train `sbintuitions/tiny-lm` using: 29 | 30 | ```shell 31 | python run_clm.py \ 32 | --use_bit_embeddings False \ 33 | --output_dir ./output-clm-byte \ 34 | ``` 35 | 36 | Compared to our model setup: 37 | ```shell 38 | python run_clm.py \ 39 | --use_bit_embeddings True \ 40 | --output_dir ./output-clm-byte-bit \ 41 | ``` 42 | 43 | With the following shared arguments: 44 | 45 | ```shell 46 | --dataset_name wikitext \ 47 | --dataset_config_name wikitext-2-raw-v1 \ 48 | --do_train True \ 49 | --do_eval True \ 50 | --eval_strategy epoch \ 51 | --save_strategy epoch \ 52 | --save_total_limit 1 \ 53 | --logging_steps 100 \ 54 | --logging_strategy steps \ 55 | --num_train_epochs 1 \ 56 | --model_name_or_path sbintuitions/tiny-lm \ 57 | --per_device_train_batch_size 8 \ 58 | --per_device_eval_batch_size 8 \ 59 | --block_size 512 \ 60 | --optim adamw_torch_fused \ 61 | --bf16 True \ 62 | --seed 42 \ 63 | --report_to wandb \ 64 | --include_num_input_tokens_seen True 65 | ``` 66 | 67 | If you want to use a different tokenizer, you can specify it with: 68 | ```shell 69 | --tokenizer_name "google/byt5-small" 70 | ``` -------------------------------------------------------------------------------- /experiments/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | from utf8_tokenizer.byt5_comparison import ByT5ComparableTokenizer 5 | from utf8_tokenizer.tokenizer import UTF8Tokenizer, tokenize_ids 6 | 7 | if __name__ == "__main__": 8 | tokenizer = UTF8Tokenizer() 9 | texts = [ 10 | "Hello", "World", "the", "a", "is", "of", "and", "to", "in", "that", 11 | "שלום", "עולם", "את", "של", "על", "עם", "לא", "הוא", "היא", "זה", 12 | "I", "you", "we", "they", "it", "be", "have", "do", "say", "get", 13 | "make", "go", "know", "take", "see", "come", "think", "look", "want", "give", 14 | "use", "find", "tell", "ask", "work", "seem", "feel", "try", "leave", "call", 15 | "", "", "\x0E", "\x0F", # Special tokens 16 | ".", ",", "!", "?", ":", ";", "-", "(", ")", '"', 17 | "hello!", "world.", "test,", "foo-bar", "(test)", '"quoted"', 18 | "אבגדהוזחטיכלמנסעפצקרשת", # Hebrew alphabet 19 | "ABCDEFGHIJKLMNOPQRSTUVWXYZ", # English alphabet 20 | "0123456789", # Numbers 21 | " ", # Space 22 | "emoji🤗", 23 | ] 24 | 25 | num = 100000 26 | 27 | for _ in tqdm(range(num), desc="just tokenizing to ints"): 28 | for text in texts: 29 | tokenize_ids(text) 30 | 31 | for _ in tqdm(range(num), desc="Calling the new tokenizer.torch"): 32 | tokenizer.torch(texts, add_special_tokens=True, padding=True) 33 | 34 | # for _ in tqdm(range(num), desc="Calling the new tokenizer"): 35 | # tokenizer(texts, add_special_tokens=True, padding=True, return_tensors="pt") 36 | # 37 | # for _ in tqdm(range(num), desc="Calling the new tokenizer._original_call"): 38 | # tokenizer._original_call(texts, add_special_tokens=True, padding=True, return_tensors="pt") 39 | # 40 | # tokenizer = ByT5ComparableTokenizer() 41 | # for _ in tqdm(range(num), "Calling the old tokenizer"): 42 | # tokenizer(texts, add_special_tokens=True, padding=True, return_tensors="pt") 43 | -------------------------------------------------------------------------------- /tests/test_chat_template.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from utf8_tokenizer.control 
import visualize_control_tokens 4 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 5 | 6 | 7 | def get_current_temperature(location: str, unit: str): 8 | """ 9 | Get the current temperature at a location. 10 | 11 | Args: 12 | location: The location to get the temperature for, in the format "City, Country" 13 | unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"]) 14 | """ 15 | return 22. # A real function should probably actually get the temperature! 16 | 17 | 18 | messages = [ 19 | {"role": "system", 20 | "content": "You are a bot that responds to weather queries. " 21 | "You should reply with the unit used in the queried location."}, 22 | {"role": "user", "content": "Hey, what's the temperature in Paris right now?"}, 23 | {"role": "assistant", "tool_calls": [ 24 | {"type": "function", "function": 25 | {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}}]}, 26 | {"role": "tool", "content": "22"} 27 | ] 28 | 29 | 30 | def test_chat_template(): 31 | tokenizer = UTF8Tokenizer() 32 | 33 | text = tokenizer.apply_chat_template( 34 | messages, 35 | tools=[get_current_temperature], 36 | add_generation_prompt=True, 37 | tokenize=False 38 | ) 39 | 40 | expected = """␑␎get_current_temperature 41 | location: The location to get the temperature for, in the format "City, Country" 42 | unit: The unit to return the temperature in. 43 | ␏␁system 44 | ␎You are a bot that responds to weather queries. You should reply with the unit used in the queried location.␏␗ 45 | ␁user 46 | ␎Hey, what's the temperature in Paris right now?␏␗ 47 | ␁assistant 48 | ␚function get_current_temperature({"location": "Paris, France", "unit": "celsius"})␛␗ 49 | ␁tool 50 | ␎22␏␗ 51 | ␁assistant 52 | """ 53 | 54 | assert expected == visualize_control_tokens(text) 55 | 56 | 57 | if __name__ == "__main__": 58 | pytest.main([__file__, "-v"]) 59 | -------------------------------------------------------------------------------- /utf8_tokenizer/control.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class ControlTokens: 5 | Null = "\x00" 6 | StartOfHeading = "\x01" 7 | StartOfText = "\x02" 8 | EndOfText = "\x03" 9 | EndOfTransmission = "\x04" 10 | Enquiry = "\x05" 11 | Acknowledge = "\x06" 12 | Alert = "\x07" 13 | Backspace = "\x08" 14 | HorizontalTab = "\x09" # Whitespace 15 | LineFeed = "\x0a" # Whitespace 16 | VerticalTab = "\x0b" # Whitespace 17 | FormFeed = "\x0c" # Whitespace 18 | CarriageReturn = "\x0d" # Whitespace 19 | ShiftOut = "\x0e" 20 | ShiftIn = "\x0f" 21 | DataLinkEscape = "\x10" 22 | DeviceControl1 = "\x11" 23 | DeviceControl2 = "\x12" 24 | DeviceControl3 = "\x13" 25 | DeviceControl4 = "\x14" 26 | NegativeAcknowledge = "\x15" 27 | SynchronousIdle = "\x16" 28 | EndOfTransmissionBlock = "\x17" 29 | Cancel = "\x18" 30 | EndOfMedium = "\x19" 31 | Substitute = "\x1a" 32 | Escape = "\x1b" 33 | FileSeparator = "\x1c" 34 | GroupSeparator = "\x1d" 35 | RecordSeparator = "\x1e" 36 | UnitSeparator = "\x1f" 37 | Space = "\x20" # Whitespace 38 | Delete = "\x7f" 39 | 40 | 41 | CONTROl_TOKENS_PATTERN = "\x01-\x08\x0e-\x1f\x7f" 42 | 43 | CONTROL_WHITESPACES = { 44 | ControlTokens.HorizontalTab, 45 | ControlTokens.LineFeed, 46 | ControlTokens.VerticalTab, 47 | ControlTokens.FormFeed, 48 | ControlTokens.CarriageReturn, 49 | ControlTokens.Space, 50 | } 51 | 52 | 53 | def visualize_control_tokens(text: str, include_whitespace=False) -> str: 54 | # Special visual handling for control characters using 
Control Pictures Unicode block 55 | # Based on https://unicode.org/charts/nameslist/n_2400.html 56 | def control_char_to_symbol(match): 57 | char = match.group(0) 58 | code = ord(char) 59 | if not include_whitespace and char in CONTROL_WHITESPACES: 60 | return char 61 | 62 | if code <= 0x1F: # Control characters 0x00-0x1F map to 0x2400-0x241F 63 | return chr(0x2400 + code) 64 | elif code == 0x7F: # DELETE character maps to 0x2421 65 | return chr(0x2421) 66 | return char 67 | 68 | return re.sub(r"[\x00-\x1F\x7F]", control_char_to_symbol, text) 69 | -------------------------------------------------------------------------------- /utf8_tokenizer/chat_template.jinja: -------------------------------------------------------------------------------- 1 | {# ------------------------------- 2 | Optional: Define available tools 3 | ------------------------------- #} 4 | {%- if tools %} 5 | {%- for tool in tools %} 6 | {{- "\x11" -}} {# Start Of Tool Definition #} 7 | {{- "\x0E" -}} {# Start Of Attention Block #} 8 | {{- tool['function']['name'] + '\n' -}} 9 | {# Loop through all argument names and their descriptions #} 10 | {%- for argument in tool['function']['parameters']['properties'] %} 11 | {{- argument + ': ' + tool['function']['parameters']['properties'][argument]['description'] + '\n' -}} 12 | {%- endfor %} 13 | {{- '\x0F' -}} {# End Of Attention Block #} 14 | {%- endfor %} 15 | {%- endif %} 16 | 17 | {# ------------------------------- 18 | Main conversation message loop 19 | ------------------------------- #} 20 | {% for message in messages %} 21 | {{- "\x01" -}} {# Start Of Text Block #} 22 | 23 | {# Print the role tag, e.g. user or assistant #} 24 | {{- message.role + "\n" -}} 25 | 26 | {# If not the "assistant", we wrap with an attention block #} 27 | {% if message.role != "assistant" %} 28 | {{- "\x0E" -}} 29 | {% endif %} 30 | 31 | {# If the message contains normal content, print it #} 32 | {% if message.content %} 33 | {{- message.content -}} 34 | {% endif %} 35 | 36 | {# If the assistant called any tools, print those tool calls #} 37 | {% if message.tool_calls %} 38 | {% for call in message.tool_calls %} 39 | {{- "\x1A" -}} {# Start Of Tool Call #} 40 | {{- call.type + " " + call.function.name 41 | + "(" + call.function.arguments | tojson + ")" -}} 42 | {{- "\x1B" -}} {# End Of Tool Call #} 43 | {% endfor %} 44 | {% endif %} 45 | 46 | {# If not the "assistant", we wrap with an attention block #} 47 | {% if message.role != "assistant" %} 48 | {{- "\x0F" -}} 49 | {% endif %} 50 | 51 | {{- "\x17" -}} {# End Of Text Block #} 52 | {{- "\n" -}} {# Newline after each message #} 53 | {% endfor %} 54 | 55 | {# ------------------------------- 56 | Add a final assistant prompt marker to continue generation 57 | ------------------------------- #} 58 | {%- if add_generation_prompt %} 59 | {{- "\x01assistant\n" -}} 60 | {%- endif %} -------------------------------------------------------------------------------- /utf8_tokenizer/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | try: 5 | from numba import jit 6 | 7 | @jit(nopython=True) 8 | def _fill_padded(output: np.ndarray, all_values: np.ndarray, lengths: np.ndarray, pad_value: int) -> None: 9 | """Numba JIT-compiled loop to fill values and padding in single pass.""" 10 | batch_size, max_len = output.shape 11 | offset = 0 12 | for i in range(batch_size): 13 | length = lengths[i] 14 | # Fill data 15 | for j in range(length): 16 | output[i, j] = 
all_values[offset + j] 17 | # Fill padding 18 | for j in range(length, max_len): 19 | output[i, j] = pad_value 20 | offset += length 21 | 22 | except ImportError: 23 | def _fill_padded(output: np.ndarray, all_values: np.ndarray, lengths: np.ndarray, pad_value: int) -> None: 24 | """Fill padded output array using vectorized numpy ops.""" 25 | output.fill(pad_value) 26 | if len(all_values) == 0: 27 | return 28 | batch_size = len(lengths) 29 | row_indices = np.repeat(np.arange(batch_size), lengths) 30 | cumsum = lengths.cumsum() 31 | positions = np.arange(len(all_values)) 32 | groups = np.searchsorted(cumsum, positions, side='right') 33 | prev_cumsum = np.empty(batch_size, dtype=np.uint32) 34 | prev_cumsum[0] = 0 35 | prev_cumsum[1:] = cumsum[:-1] 36 | col_indices = positions - prev_cumsum[groups] 37 | output[row_indices, col_indices] = all_values 38 | 39 | 40 | def pad_bytearrays_to_tensor_loop(bytearrays: list[bytearray], padding_value: int = 0) -> torch.Tensor: 41 | """ 42 | Pad a list of bytearrays into a single tensor using a simple loop. 43 | 44 | This is the reference implementation for testing. 45 | """ 46 | max_len = max(len(b) for b in bytearrays) 47 | output = np.full((len(bytearrays), max_len), padding_value, dtype=np.uint8) 48 | 49 | for i, b in enumerate(bytearrays): 50 | output[i, :len(b)] = np.frombuffer(b, dtype=np.uint8) 51 | 52 | return torch.from_numpy(output) 53 | 54 | 55 | def pad_bytearrays_to_tensor(bytearrays: list[bytes], padding_value: int = 0) -> torch.Tensor: 56 | """ 57 | Pad a list of byte sequences into a single tensor. 58 | 59 | Uses Numba JIT if available, otherwise falls back to vectorized numpy. 60 | """ 61 | lengths = np.fromiter(map(len, bytearrays), dtype=np.uint32, count=len(bytearrays)) 62 | output = np.empty((len(bytearrays), lengths.max()), dtype=np.uint8) 63 | 64 | all_bytes = b''.join(bytearrays) 65 | all_values = np.frombuffer(all_bytes, dtype=np.uint8) 66 | 67 | _fill_padded(output, all_values, lengths, padding_value) 68 | 69 | return torch.from_numpy(output) 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Back to Bytes: Revisiting Tokenization Through `UTF-8` 2 | 3 | ![Python](https://img.shields.io/badge/python-3.10-blue) 4 | [![License](https://img.shields.io/badge/license-MIT-green)](./LICENSE) 5 | [![arXiv](https://img.shields.io/badge/arXiv-2510.16987-b31b1b.svg)](https://arxiv.org/abs/2510.16987) 6 | 7 | Full writeup can be found in our paper. 8 | 9 | This module includes a **real** byte level tokenizer for text, which encodes text into a sequence of bytes (0-255). 10 | Unlike `ByT5Tokenizer` for example, `UTF8Tokenizer` is implemented from scratch, and is much more efficient. 11 | 12 | Other "Byte Level" tokenizers usually include various additional "special tokens" (e.g., ``, ``, etc.), 13 | making the encoding and decoding logic more complex, and the token ids larger than 255. 14 | 15 | Instead, we rely on C0 Control characters (0-31) as special tokens, which are not used in normal text. 
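As a minimal, self-contained illustration of the idea (plain Python, no library code; the example string is arbitrary):

```python
# UTF-8 bytes already form a complete 0-255 vocabulary;
# C0 control characters (here 0x02/0x03 as BOS/EOS) double as special tokens.
ids = list("\x02héllo\x03".encode("utf-8"))
print(ids)  # [2, 104, 195, 169, 108, 108, 111, 3] -- every id fits in a single byte
```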
16 | 17 | ## Usage 18 | 19 | ```shell 20 | pip install utf8-tokenizer 21 | ``` 22 | 23 | Tokenization: 24 | 25 | ```python 26 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 27 | 28 | tokenizer = UTF8Tokenizer() 29 | 30 | texts = ["word", "or multiple"] 31 | print(tokenizer(texts)) 32 | ``` 33 | 34 | Chat Template: 35 | 36 | ```py 37 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 38 | from utf8_tokenizer.control import visualize_control_tokens 39 | 40 | messages = [ 41 | {"role": "system", "content": "You are a helpful assistant."}, 42 | {"role": "user", "content": "Hey, what's 1+1?"}, 43 | {"role": "assistant", "content": "1+1 is 2."}, 44 | ] 45 | 46 | tokenizer = UTF8Tokenizer() 47 | text = tokenizer.apply_chat_template(messages, tokenize=False) 48 | 49 | # Visualize the text with special tokens 50 | print(visualize_control_tokens(text)) 51 | ``` 52 | 53 | Bit-biased byte embeddings: 54 | 55 | ```py 56 | from transformers import AutoModelForCausalLM 57 | 58 | # Load example model 59 | model = AutoModelForCausalLM.from_pretrained("sbintuitions/tiny-lm") 60 | model.resize_token_embeddings(256) 61 | 62 | from utf8_tokenizer.embeddings import patch_embedding_layers, join_embedding_layers 63 | 64 | patch_embedding_layers(model)  # Apply bit-bias for training 65 | 66 | # 67 | # Train your model... 68 | # 69 | 70 | join_embedding_layers(model)  # Fold to a single embedding layer for inference 71 | ``` 72 | 73 | UTF-8 Validation during Generation: 74 | 75 | ```py 76 | from transformers import AutoModelForCausalLM 77 | from utf8_tokenizer import UTF8Tokenizer, UTF8ValidationLogitsProcessor 78 | 79 | # Load your byte-level model 80 | model = AutoModelForCausalLM.from_pretrained("your-model") 81 | tokenizer = UTF8Tokenizer() 82 | 83 | # Create the UTF-8 validation processor 84 | utf8_processor = UTF8ValidationLogitsProcessor() 85 | 86 | # Generate text with UTF-8 validation 87 | input_text = "Hello" 88 | input_ids = tokenizer(input_text, return_tensors="pt").input_ids 89 | 90 | outputs = model.generate( 91 | input_ids, 92 | logits_processor=[utf8_processor],  # Ensures valid UTF-8 sequences 93 | max_new_tokens=100 94 | ) 95 | 96 | # Decode the output 97 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 98 | print(generated_text) 99 | ``` 100 | 101 | The `UTF8ValidationLogitsProcessor` prevents byte-level tokenizers from generating malformed UTF-8 sequences by masking invalid byte continuations during generation. This addresses the issue discussed in [Firestone et al. 2024](https://openreview.net/pdf?id=8ExXncFpf6) where byte-level tokenizers can generate ill-formed UTF-8. 102 | 103 | ## Benchmark 104 | 105 | ### Tokenization Speed 106 | 107 | ```shell 108 | python experiments/benchmark.py 109 | ``` 110 | 111 | On a MacBook Pro with an Apple M4 Pro chip, just converting a batch of 75 words in different languages to bytes, 112 | without wrapping them in tensors, creating attention masks, or padding, runs at 109.9k/sec. 113 | 114 | Calling the ByT5 tokenizer runs at 0.4k/sec. 115 | Calling our new tokenizer through the standard `__call__` path reaches 0.5k/sec, which is slightly faster. 116 | 117 | Our optimized zero-copy path runs at 66k/sec; the remaining gap compared to raw ints comes from 118 | padding the input ids into a properly padded tensor. **This is a 164x speedup over the original tokenizer.** 119 | 120 | ### Bit-Biased Byte Embedding 121 | 122 | We [train a small language model](experiments/language-modelling/README.md) with and without bit-bias.
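The two runs differ only in a single flag; here is a sketch of the launch commands, where `...` stands for the shared arguments listed in that README (the full scripts are `clm-byte.sh` and `clm-byte-bit.sh`):

```shell
# Baseline: plain byte embeddings
python run_clm.py --use_bit_embeddings False --output_dir ./output-clm-byte ...
# Bit-biased byte embeddings (patched embedding layer)
python run_clm.py --use_bit_embeddings True --output_dir ./output-clm-byte-bit ...
```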
123 | 124 | Our results reveal that bit-bias improves both loss and accuracy, while increasing training time by about 1%. 125 | We hope that our bit-level embeddings module can be further optimized, to minimize the training overhead. 126 | 127 | ## Cite 128 | 129 | If you use this code in your research, please consider citing the work: 130 | 131 | ```bibtex 132 | @misc{moryossef2025utf8, 133 | title = {Back to Bytes: Revisiting Tokenization Through {UTF-8}}, 134 | author = {Amit Moryossef and Clara Meister and Pavel Stepachev and Desmond Elliott}, 135 | howpublished = {\url{https://github.com/sign/utf8-tokenizer}}, 136 | eprint = {2510.16987}, 137 | archivePrefix = {arXiv}, 138 | primaryClass = {cs.CL}, 139 | url = {https://arxiv.org/abs/2510.16987}, 140 | year = {2025} 141 | } 142 | ``` -------------------------------------------------------------------------------- /utf8_tokenizer/embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Embedding 4 | 5 | 6 | def unpack_bits(x: torch.Tensor) -> torch.Tensor: 7 | assert x.dtype == torch.uint8, "Expected bytes tensor input (torch.uint8)" 8 | 9 | # Create shifts by [7, 6, 5, 4, 3, 2, 1, 0] 10 | shifts = torch.arange(7, -1, -1, device=x.device, dtype=torch.uint8) 11 | # Shift the integers (tensor is still integers) 12 | shifted = x.unsqueeze(-1) >> shifts # (B, L, 8) 13 | # Masks off all but the least significant bit 14 | return shifted & 1 15 | 16 | 17 | class PatchedBitEmbeddings(nn.Module): 18 | """ 19 | nn.Embedding + additive 8->D bit projection. 20 | Exposes a *Parameter* `.weight` for HF tying, but learns via: 21 | - base table: self.embeddings.weight (256, D) 22 | - bit proj : self.bit_proj_w (D, 8) 23 | 24 | The façade `self.weight` is refreshed in-place with: 25 | W = base + bits(256,8) @ bit_proj_w.T(8,D) 26 | and all grads arriving on the façade are re-routed to the real params. 27 | """ 28 | 29 | def __init__(self, embeddings: nn.Embedding): 30 | super().__init__() 31 | assert isinstance(embeddings, nn.Embedding) 32 | assert embeddings.weight.shape[0] == 256, "Expected byte-level embedding layer (256 rows)." 
33 | 34 | self.embeddings = embeddings 35 | D = embeddings.embedding_dim # noqa: N806 36 | dtype = self.embeddings.weight.dtype 37 | 38 | # Tiny bit projection; use a bare Parameter to avoid Module overhead & extra .t() 39 | self.bit_proj_w = nn.Parameter(torch.zeros(D, 8, dtype=dtype)) # init=0 ⇒ starts identical to base table 40 | 41 | # Tieable façade parameter (what HF ties to the LM head) 42 | self.weight = nn.Parameter(embeddings.weight.detach().clone(), requires_grad=True) 43 | 44 | # Bits table buffer (float32 initially); device/dtype-adjusted copies are cached lazily 45 | all_bytes = torch.arange(256, dtype=torch.uint8) 46 | self.register_buffer("_bits256_base", unpack_bits(all_bytes), persistent=False) 47 | self._bits_cached = None # (256, 8) on current device/dtype 48 | self._bits_device = torch.device("meta") # sentinel (forces first refresh) 49 | self._bits_dtype = dtype 50 | 51 | # Rebuild only when needed (params changed or device/dtype changed) 52 | self._last_base_v = -1 53 | self._last_bit_v = -1 54 | self._last_device = None 55 | self._last_dtype = None 56 | 57 | # Route grads from façade into the true params; block façade updates 58 | def _route_grad(grad: torch.Tensor): 59 | # grad = dL/dW, shape (256, D) 60 | g = grad.detach() 61 | 62 | # base grad: += grad 63 | ew = self.embeddings.weight 64 | if ew.grad is None: 65 | ew.grad = g.clone() 66 | else: 67 | ew.grad.add_(g) 68 | 69 | # bit proj grad: dL/d(D,8) = (dL/d(256,D))^T @ bits(256,8) → (D,8) 70 | # use cached device/dtype-adjusted bits 71 | self._ensure_bits_cached() # cheap no-op after first call 72 | gb = g.t().mm(self._bits_cached) # (D,8) 73 | if self.bit_proj_w.grad is None: 74 | self.bit_proj_w.grad = gb.clone() 75 | else: 76 | self.bit_proj_w.grad.add_(gb) 77 | 78 | # façade param should not be optimized directly 79 | return torch.zeros_like(grad) 80 | 81 | self.weight.register_hook(_route_grad) 82 | 83 | # Keep façade fresh even if call order is quirky (runs before forward) 84 | self.register_forward_pre_hook(lambda m, inp: m._maybe_refresh_weight_()) 85 | 86 | # ---- internals ---- 87 | 88 | def _ensure_bits_cached(self): 89 | """Cache bits in current device/dtype (no work in the hot path after first time).""" 90 | w = self.embeddings.weight 91 | if (self._bits_device is not w.device) or (self._bits_dtype != w.dtype): 92 | self._bits_cached = self._bits256_base.to(device=w.device, dtype=w.dtype, non_blocking=True).contiguous() 93 | self._bits_device, self._bits_dtype = w.device, w.dtype 94 | 95 | def _needs_refresh(self) -> bool: 96 | w = self.embeddings.weight 97 | bw = self.bit_proj_w 98 | return ( 99 | w._version != self._last_base_v 100 | or bw._version != self._last_bit_v 101 | or w.device is not self._last_device 102 | or w.dtype != self._last_dtype 103 | ) 104 | 105 | def _mark_refreshed(self): 106 | w = self.embeddings.weight 107 | self._last_base_v = w._version 108 | self._last_bit_v = self.bit_proj_w._version 109 | self._last_device = w.device 110 | self._last_dtype = w.dtype 111 | 112 | @torch.no_grad() 113 | def _refresh_weight_(self): 114 | """façade = base + bits @ bit_proj^T (fused with addmm_ for speed).""" 115 | self._ensure_bits_cached() 116 | # Copy base 117 | self.weight.data.copy_(self.embeddings.weight) 118 | # Fused GEMM + add: (256,8) @ (8,D) → (256,D) 119 | self.weight.data.addmm_(self._bits_cached, self.bit_proj_w.t()) 120 | 121 | @torch.no_grad() 122 | def _maybe_refresh_weight_(self): 123 | if self._needs_refresh(): 124 | self._refresh_weight_() 125 | self._mark_refreshed() 126 
| 127 | # ---- public API ---- 128 | 129 | def forward(self, input_ids: torch.Tensor) -> torch.Tensor: 130 | # Façade already refreshed by pre-hook; forward is just a normal embedding lookup 131 | input_ids = input_ids.to(dtype=torch.long) 132 | # TODO: ideally, we should use 133 | # return nn.functional.embedding(input_ids, self.weight) 134 | # https://github.com/pytorch/pytorch/issues/162918 135 | return self.weight[input_ids] 136 | 137 | 138 | def patch_embedding_layers(model): 139 | embeddings: Embedding = model.get_input_embeddings() 140 | assert isinstance(embeddings, Embedding), "Expected nn.Embedding layer" 141 | assert len(embeddings.weight) == 256, "Expected byte-level embedding layer" 142 | patched_embeddings = PatchedBitEmbeddings(embeddings) 143 | 144 | model.set_input_embeddings(patched_embeddings) 145 | model.tie_embeddings_and_encoder_decoder() 146 | 147 | 148 | def join_embedding_layers(model): 149 | embeddings: PatchedBitEmbeddings = model.get_input_embeddings() 150 | assert isinstance(embeddings, PatchedBitEmbeddings), "Expected patched embedding layer" 151 | 152 | # Reuse the original embedding to preserve weight tying 153 | original_embedding = embeddings.embeddings 154 | original_embedding.weight.data = embeddings.weight.data 155 | 156 | model.set_input_embeddings(original_embedding) 157 | model.tie_embeddings_and_encoder_decoder() 158 | -------------------------------------------------------------------------------- /tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from transformers import AutoModelForMaskedLM 5 | 6 | from utf8_tokenizer.embeddings import ( 7 | PatchedBitEmbeddings, 8 | join_embedding_layers, 9 | patch_embedding_layers, 10 | unpack_bits, 11 | ) 12 | 13 | 14 | class TestEmbeddings: 15 | @pytest.fixture 16 | def model(self): 17 | model = AutoModelForMaskedLM.from_pretrained("prajjwal1/bert-tiny") 18 | model.resize_token_embeddings(256) 19 | return model 20 | 21 | @pytest.fixture 22 | def sample_input(self): 23 | B, L = 2, 16 # noqa: N806 24 | return torch.randint(0, 256, (B, L), dtype=torch.uint8) 25 | 26 | @pytest.fixture 27 | def attention_mask(self): 28 | B, L = 2, 16 # noqa: N806 29 | return torch.ones(B, L, dtype=torch.long) 30 | 31 | def test_unpack_bits(self): 32 | x = torch.tensor([255, 0, 128], dtype=torch.uint8) 33 | bits = unpack_bits(x) 34 | 35 | assert bits.shape == (3, 8) 36 | assert torch.allclose(bits[0], torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)) 37 | assert torch.allclose(bits[1], torch.tensor([0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)) 38 | assert torch.allclose(bits[2], torch.tensor([1, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)) 39 | 40 | def test_embedding_setup(self, model): 41 | original_embeddings = model.get_input_embeddings() 42 | assert isinstance(original_embeddings, nn.Embedding) 43 | assert original_embeddings.num_embeddings == 256 44 | 45 | patch_embedding_layers(model) 46 | patched_embeddings = model.get_input_embeddings() 47 | assert isinstance(patched_embeddings, PatchedBitEmbeddings) 48 | assert isinstance(patched_embeddings.embeddings, nn.Embedding) 49 | assert patched_embeddings.embeddings.num_embeddings == 256 50 | 51 | def test_model_callable_after_patching(self, model, sample_input, attention_mask): 52 | patch_embedding_layers(model) 53 | 54 | with torch.inference_mode(): 55 | output = model(input_ids=sample_input, attention_mask=attention_mask) 56 | 57 | assert 
hasattr(output, "logits") 58 | assert output.logits.shape == (2, 16, 256) 59 | 60 | def test_parameter_count_preservation(self, model): 61 | original_param_count = sum(p.numel() for p in model.parameters()) 62 | original_embedding_params = model.get_input_embeddings().weight.numel() 63 | 64 | patch_embedding_layers(model) 65 | patched_embeddings = model.get_input_embeddings() 66 | patched_param_count = sum(p.numel() for p in model.parameters()) 67 | 68 | join_embedding_layers(model) 69 | final_param_count = sum(p.numel() for p in model.parameters()) 70 | final_embedding_params = model.get_input_embeddings().weight.numel() 71 | 72 | # The embedding layer parameter count should be preserved 73 | assert original_embedding_params == final_embedding_params 74 | # The patched version should have additional parameters (bit projection) 75 | assert patched_param_count > original_param_count 76 | # The total parameter difference includes bit projection + combined weight 77 | embedding_dim = patched_embeddings.embeddings.embedding_dim 78 | expected_extra_params = 8 * embedding_dim + 256 * embedding_dim # bit_proj + combined weight 79 | assert patched_param_count - original_param_count == expected_extra_params 80 | # After joining, we should be back to the original count 81 | assert final_param_count == original_param_count 82 | 83 | def test_embedding_weight_preservation(self, model, sample_input): 84 | sample_input_int = sample_input.to(dtype=torch.int) 85 | 86 | original_embeddings = model.get_input_embeddings() 87 | original_weight = original_embeddings.weight.clone() 88 | 89 | with torch.inference_mode(): 90 | original_output = original_embeddings(sample_input_int) 91 | 92 | patch_embedding_layers(model) 93 | join_embedding_layers(model) 94 | 95 | final_embeddings = model.get_input_embeddings() 96 | final_weight = final_embeddings.weight 97 | 98 | with torch.inference_mode(): 99 | final_output = final_embeddings(sample_input_int) 100 | 101 | assert torch.allclose(original_weight, final_weight, atol=1e-6) 102 | assert torch.allclose(original_output, final_output, atol=1e-6) 103 | 104 | def test_patched_bit_embeddings_properties(self, model): 105 | patch_embedding_layers(model) 106 | patched_embeddings = model.get_input_embeddings() 107 | 108 | assert hasattr(patched_embeddings, "weight") 109 | assert hasattr(patched_embeddings, "embeddings") 110 | assert hasattr(patched_embeddings, "bit_proj_w") 111 | 112 | weight = patched_embeddings.weight 113 | assert weight.shape == (256, patched_embeddings.embeddings.embedding_dim) 114 | 115 | def test_bit_projection_initialization(self, model): 116 | patch_embedding_layers(model) 117 | patched_embeddings = model.get_input_embeddings() 118 | 119 | weight = patched_embeddings.bit_proj_w.data 120 | assert torch.allclose(weight, torch.zeros_like(weight)) 121 | 122 | def test_forward_pass_consistency(self, model, sample_input): 123 | sample_input_int = sample_input.to(dtype=torch.int) 124 | 125 | original_embeddings = model.get_input_embeddings() 126 | 127 | with torch.inference_mode(): 128 | original_embedded = original_embeddings(sample_input_int) 129 | 130 | patch_embedding_layers(model) 131 | patched_embeddings = model.get_input_embeddings() 132 | 133 | with torch.inference_mode(): 134 | patched_embedded = patched_embeddings(sample_input) 135 | 136 | assert torch.allclose(original_embedded, patched_embedded, atol=1e-6) 137 | 138 | def test_embedding_weight_consistency_after_patch(self, model): 139 | original_embeddings = model.get_input_embeddings() 140 | 
original_weight = original_embeddings.weight.clone() 141 | 142 | patch_embedding_layers(model) 143 | patched_embeddings = model.get_input_embeddings() 144 | 145 | assert torch.allclose(original_weight, patched_embeddings.weight, atol=1e-6) 146 | 147 | def test_embedding_dtype_is_casted(self, model, sample_input): 148 | patch_embedding_layers(model) 149 | model = model.to(torch.float16) 150 | 151 | patched_embeddings = model.get_input_embeddings() 152 | assert patched_embeddings(sample_input).dtype == torch.float16 153 | 154 | def test_bit_projection_training(self, model, sample_input): 155 | patch_embedding_layers(model) 156 | patched_embeddings = model.get_input_embeddings() 157 | 158 | initial_embedding_weight = patched_embeddings.embeddings.weight.data.clone() 159 | 160 | # Verify bit_proj.weight starts as zeros 161 | initial_bit_weight = patched_embeddings.bit_proj_w.data.clone() 162 | assert torch.allclose(initial_bit_weight, torch.zeros_like(initial_bit_weight)) 163 | 164 | # Set up a simple training objective: make embeddings closer to 1 165 | optimizer = torch.optim.SGD(patched_embeddings.parameters(), lr=0.1) 166 | target = torch.ones_like(patched_embeddings(sample_input)) 167 | 168 | # One training step 169 | optimizer.zero_grad() 170 | embedded = patched_embeddings(sample_input) 171 | loss = torch.nn.functional.mse_loss(embedded, target) 172 | loss.backward() 173 | optimizer.step() 174 | 175 | # Verify bit_proj.weight is no longer zeros 176 | final_bit_weight = patched_embeddings.bit_proj_w.data 177 | assert not torch.allclose(final_bit_weight, torch.zeros_like(final_bit_weight), atol=1e-6) 178 | 179 | # Verify base embeddings have changed too 180 | final_embedding_weight = patched_embeddings.embeddings.weight.data 181 | assert not torch.allclose(final_embedding_weight, initial_embedding_weight, atol=1e-6) 182 | -------------------------------------------------------------------------------- /utf8_tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | from pathlib import Path 4 | 5 | import torch 6 | from transformers import AutoTokenizer, PreTrainedTokenizer 7 | from transformers.tokenization_utils_base import BatchEncoding, TextInput 8 | 9 | from utf8_tokenizer.control import ControlTokens 10 | from utf8_tokenizer.utils import pad_bytearrays_to_tensor 11 | 12 | 13 | def tokenize_ids(text: str, errors="strict"): 14 | return list(text.encode("utf-8", errors=errors)) 15 | 16 | 17 | TokenizerResult = namedtuple("TokenizerResult", ["input_ids", "attention_mask"]) 18 | 19 | PAD_TOKEN = ControlTokens.Null 20 | PAD_TOKEN_ID = ord(PAD_TOKEN) 21 | 22 | BOS_TOKEN = ControlTokens.StartOfText 23 | BOS_TOKEN_ID = ord(BOS_TOKEN) 24 | 25 | EOS_TOKEN = ControlTokens.EndOfText 26 | EOS_TOKEN_ID = ord(EOS_TOKEN) 27 | EOS_BYTE = bytes([EOS_TOKEN_ID]) 28 | 29 | 30 | class UTF8Tokenizer(PreTrainedTokenizer): 31 | """ 32 | Custom UTF8 Byte Level Tokenizer implementation, 33 | extending PreTrainedTokenizer for basic Hugging Face ecosystem support. 34 | 35 | This tokenizer only supports exactly 256 tokens, with no support for "special tokens". 36 | See README.md to learn more how this works with control tokens instead. 37 | 38 | Additionally, exposes a `.torch` method, which fuses and skips unnecessary ops, 39 | to achieve a ~8x speedup over `__call__` for training purposes. 
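Example (a minimal usage sketch):
    >>> tokenizer = UTF8Tokenizer()
    >>> batch = tokenizer.torch(["hi", "hello"], padding=True)
    >>> batch.input_ids.dtype  # padded batch of raw UTF-8 bytes
    torch.uint8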
40 | """ 41 | 42 | def __init__(self, **kwargs): 43 | kwargs["pad_token"] = PAD_TOKEN 44 | kwargs["pad_token_id"] = PAD_TOKEN_ID 45 | kwargs["bos_token"] = BOS_TOKEN 46 | kwargs["bos_token_id"] = BOS_TOKEN_ID 47 | kwargs["eos_token"] = EOS_TOKEN 48 | kwargs["eos_token_id"] = EOS_TOKEN_ID 49 | super().__init__(**kwargs) 50 | 51 | # Chat template for instruction-following models 52 | with open(Path(__file__).parent / "chat_template.jinja") as f: 53 | self.chat_template = "" 54 | for line in f: 55 | self.chat_template += line.strip() 56 | 57 | @property 58 | def vocab_size(self) -> int: 59 | return 2 ** 8 60 | 61 | def add_tokens(self, *args, **kwargs): 62 | raise NotImplementedError("UTF8Tokenizer does not support adding tokens") 63 | 64 | def get_vocab(self): 65 | return {chr(i): i for i in range(self.vocab_size)} 66 | 67 | def _tokenize(self, text: TextInput, errors="strict", **kwargs): 68 | return [chr(c) for c in tokenize_ids(text, errors=errors)] 69 | 70 | def tokenize(self, text: TextInput, errors="strict", **kwargs): 71 | return self._tokenize(text, errors=errors, **kwargs) 72 | 73 | def _encode_plus(self, text: TextInput, **kwargs): 74 | return self.prepare_for_model(tokenize_ids(text), **kwargs) 75 | 76 | def _convert_token_to_id(self, token: str): 77 | return ord(token) 78 | 79 | def _convert_id_to_token(self, index: int): 80 | return chr(index) 81 | 82 | def convert_tokens_to_string(self, tokens: list[str]): 83 | """Converts a sequence of tokens (string) in a single string.""" 84 | _bytes = bytes(ord(token) for token in tokens) 85 | return _bytes.decode("utf-8", errors="ignore") 86 | 87 | def build_inputs_with_special_tokens( 88 | self, token_ids_0: list[int] | bytearray, token_ids_1: list[int] | None = None 89 | ) -> list[int] | bytearray: 90 | assert token_ids_1 is None, "UTF8Tokenizer only supports single sequence" 91 | 92 | # Experimentally, the fastest way to add BOS/EOS 93 | token_ids_0.append(EOS_TOKEN_ID) # EOS 94 | token_ids_0.insert(0, BOS_TOKEN_ID) # BOS 95 | return token_ids_0 96 | 97 | def _encode(self, texts: list[TextInput], add_special_tokens: bool) -> list[bytes]: 98 | """Fast path: encode texts with optional special tokens using string concatenation.""" 99 | if add_special_tokens: 100 | texts = [BOS_TOKEN + text + EOS_TOKEN for text in texts] 101 | return [text.encode("utf-8") for text in texts] 102 | 103 | def _encode_and_truncate( 104 | self, texts: list[TextInput], max_length: int, add_special_tokens: bool 105 | ) -> list[bytes]: 106 | """Encode and truncate texts. 
Uses bytes slicing + concat (faster than bytearray).""" 107 | if add_special_tokens: 108 | # Prepend BOS via string concat, truncate, concat EOS byte 109 | return [(BOS_TOKEN + text).encode("utf-8")[:max_length + 1] + EOS_BYTE for text in texts] 110 | return [text.encode("utf-8")[:max_length] for text in texts] 111 | 112 | def _original_call(self, *args, **kwargs) -> BatchEncoding: 113 | return super().__call__(*args, **kwargs) 114 | 115 | def __call__(self, *args, **kwargs) -> BatchEncoding: 116 | return_tensors = kwargs.pop("return_tensors", "pt") 117 | if return_tensors != "pt": 118 | return self._original_call(*args, return_tensors=return_tensors, **kwargs) 119 | result = self.torch(*args, **kwargs) 120 | return result._asdict() 121 | 122 | def torch( 123 | self, 124 | texts: list[TextInput], 125 | add_special_tokens: bool = True, 126 | padding: bool = False, 127 | truncation: bool = False, 128 | max_length: int | None = None, 129 | device: torch.device | None = None, 130 | ) -> TokenizerResult: 131 | if truncation: 132 | if max_length is None: 133 | warnings.warn( 134 | "Asking to truncate to max_length but no maximum length is provided and the model has " 135 | "no predefined maximum length. Default to no truncation.", 136 | stacklevel=2, 137 | ) 138 | truncation = False 139 | else: 140 | max_length = max_length - 2 if add_special_tokens else max_length 141 | if max_length < 0: 142 | warnings.warn( 143 | "We need to remove more tokens than exist. Default to no truncation.", 144 | stacklevel=2) 145 | truncation = False 146 | 147 | # Encode 148 | if truncation: 149 | input_bytes = self._encode_and_truncate(texts, max_length, add_special_tokens) 150 | else: 151 | input_bytes = self._encode(texts, add_special_tokens) 152 | 153 | if padding: 154 | # Fast path: pre-allocate and fill directly 155 | input_ids = pad_bytearrays_to_tensor(input_bytes, padding_value=PAD_TOKEN_ID) 156 | attention_mask = input_ids.ne(PAD_TOKEN_ID) 157 | else: 158 | # Slow path - no padding means we need to return a list of tensors 159 | # bytearray() needed because bytes are immutable -> read-only tensor warning 160 | input_ids = [torch.frombuffer(bytearray(ids), dtype=torch.uint8) for ids in input_bytes] 161 | attention_mask = [torch.ones(len(ids), dtype=torch.bool) for ids in input_ids] 162 | 163 | # # IDs should be long tensors, to prevent issues with some models 164 | # if isinstance(input_ids, list): 165 | # input_ids = [ids.long() for ids in input_ids] 166 | # else: 167 | # input_ids = input_ids.long() 168 | 169 | if device is not None: 170 | if isinstance(input_ids, list): 171 | input_ids = [ids.to(device, non_blocking=True) for ids in input_ids] 172 | attention_mask = [mask.to(device, non_blocking=True) for mask in attention_mask] 173 | else: 174 | input_ids = input_ids.to(device, non_blocking=True) 175 | attention_mask = attention_mask.to(device, non_blocking=True) 176 | 177 | return TokenizerResult(input_ids=input_ids, attention_mask=attention_mask) 178 | 179 | def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None): 180 | return () 181 | 182 | def to_dict(self): 183 | return {} 184 | 185 | 186 | AutoTokenizer.register(UTF8Tokenizer, slow_tokenizer_class=UTF8Tokenizer) 187 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | from utf8_tokenizer.utils import 
pad_bytearrays_to_tensor, pad_bytearrays_to_tensor_loop 6 | 7 | 8 | def reference_pad_sequence(bytearrays: list[bytearray], padding_value: int = 0) -> torch.Tensor: 9 | """Reference implementation using torch.nn.utils.rnn.pad_sequence.""" 10 | tensors = [torch.frombuffer(b, dtype=torch.uint8) for b in bytearrays] 11 | return pad_sequence(tensors, batch_first=True, padding_value=padding_value) 12 | 13 | 14 | class TestPadBytearraysToTensor: 15 | """Tests for pad_bytearrays_to_tensor implementations.""" 16 | 17 | @pytest.fixture 18 | def sample_bytearrays(self): 19 | return [ 20 | bytearray(b"hello"), 21 | bytearray(b"hi"), 22 | bytearray(b"world!"), 23 | ] 24 | 25 | @pytest.fixture 26 | def unicode_bytearrays(self): 27 | return [ 28 | bytearray("שלום".encode()), 29 | bytearray(b"hello"), 30 | bytearray("世界".encode()), 31 | ] 32 | 33 | @pytest.fixture 34 | def single_bytearray(self): 35 | return [bytearray(b"test")] 36 | 37 | @pytest.fixture 38 | def empty_bytearrays(self): 39 | return [bytearray(b""), bytearray(b"a"), bytearray(b"")] 40 | 41 | # Test vectorized implementation matches reference 42 | @pytest.mark.parametrize("padding_value", [0, 255, 42]) 43 | def test_vectorized_matches_reference(self, sample_bytearrays, padding_value): 44 | expected = reference_pad_sequence(sample_bytearrays, padding_value) 45 | result = pad_bytearrays_to_tensor(sample_bytearrays, padding_value) 46 | assert torch.equal(result, expected) 47 | 48 | def test_vectorized_matches_reference_unicode(self, unicode_bytearrays): 49 | expected = reference_pad_sequence(unicode_bytearrays) 50 | result = pad_bytearrays_to_tensor(unicode_bytearrays) 51 | assert torch.equal(result, expected) 52 | 53 | def test_vectorized_matches_reference_single(self, single_bytearray): 54 | expected = reference_pad_sequence(single_bytearray) 55 | result = pad_bytearrays_to_tensor(single_bytearray) 56 | assert torch.equal(result, expected) 57 | 58 | def test_vectorized_matches_loop_with_empty(self, empty_bytearrays): 59 | # Note: torch.frombuffer can't handle empty buffers, so we compare implementations 60 | loop_result = pad_bytearrays_to_tensor_loop(empty_bytearrays) 61 | vectorized_result = pad_bytearrays_to_tensor(empty_bytearrays) 62 | assert torch.equal(loop_result, vectorized_result) 63 | # Check shape is correct 64 | assert vectorized_result.shape == (3, 1) # 3 items, max_len=1 65 | 66 | # Test loop implementation matches reference 67 | @pytest.mark.parametrize("padding_value", [0, 255, 42]) 68 | def test_loop_matches_reference(self, sample_bytearrays, padding_value): 69 | expected = reference_pad_sequence(sample_bytearrays, padding_value) 70 | result = pad_bytearrays_to_tensor_loop(sample_bytearrays, padding_value) 71 | assert torch.equal(result, expected) 72 | 73 | def test_loop_matches_reference_unicode(self, unicode_bytearrays): 74 | expected = reference_pad_sequence(unicode_bytearrays) 75 | result = pad_bytearrays_to_tensor_loop(unicode_bytearrays) 76 | assert torch.equal(result, expected) 77 | 78 | # Test vectorized matches loop implementation 79 | def test_vectorized_matches_loop(self, sample_bytearrays): 80 | loop_result = pad_bytearrays_to_tensor_loop(sample_bytearrays) 81 | vectorized_result = pad_bytearrays_to_tensor(sample_bytearrays) 82 | assert torch.equal(loop_result, vectorized_result) 83 | 84 | def test_vectorized_matches_loop_unicode(self, unicode_bytearrays): 85 | loop_result = pad_bytearrays_to_tensor_loop(unicode_bytearrays) 86 | vectorized_result = pad_bytearrays_to_tensor(unicode_bytearrays) 87 | assert 
torch.equal(loop_result, vectorized_result) 88 | 89 | # Test output properties 90 | def test_output_shape(self, sample_bytearrays): 91 | result = pad_bytearrays_to_tensor(sample_bytearrays) 92 | max_len = max(len(b) for b in sample_bytearrays) 93 | assert result.shape == (len(sample_bytearrays), max_len) 94 | 95 | def test_output_dtype(self, sample_bytearrays): 96 | result = pad_bytearrays_to_tensor(sample_bytearrays) 97 | assert result.dtype == torch.uint8 98 | 99 | def test_padding_applied_correctly(self): 100 | bytearrays = [bytearray(b"ab"), bytearray(b"abcd")] 101 | result = pad_bytearrays_to_tensor(bytearrays, padding_value=0) 102 | 103 | # First row: "ab" + padding 104 | assert result[0].tolist() == [ord("a"), ord("b"), 0, 0] 105 | # Second row: "abcd" no padding 106 | assert result[1].tolist() == [ord("a"), ord("b"), ord("c"), ord("d")] 107 | 108 | def test_content_preserved(self, sample_bytearrays): 109 | result = pad_bytearrays_to_tensor(sample_bytearrays) 110 | 111 | for i, b in enumerate(sample_bytearrays): 112 | # Check that non-padded content matches original 113 | assert result[i, : len(b)].tolist() == list(b) 114 | 115 | # Edge cases 116 | def test_all_same_length(self): 117 | bytearrays = [bytearray(b"abc"), bytearray(b"def"), bytearray(b"ghi")] 118 | expected = reference_pad_sequence(bytearrays) 119 | result = pad_bytearrays_to_tensor(bytearrays) 120 | assert torch.equal(result, expected) 121 | assert result.shape == (3, 3) # No padding needed 122 | 123 | def test_large_batch(self): 124 | bytearrays = [bytearray(f"text{i}".encode()) for i in range(1000)] 125 | expected = reference_pad_sequence(bytearrays) 126 | result = pad_bytearrays_to_tensor(bytearrays) 127 | assert torch.equal(result, expected) 128 | 129 | def test_varying_lengths(self): 130 | bytearrays = [ 131 | bytearray(b"a"), 132 | bytearray(b"ab"), 133 | bytearray(b"abc"), 134 | bytearray(b"abcd"), 135 | bytearray(b"abcde"), 136 | ] 137 | expected = reference_pad_sequence(bytearrays) 138 | result = pad_bytearrays_to_tensor(bytearrays) 139 | assert torch.equal(result, expected) 140 | 141 | def test_all_empty_bytearrays(self): 142 | """Test with all empty bytearrays - edge case that produces (n, 0) tensor.""" 143 | bytearrays = [bytearray(b""), bytearray(b""), bytearray(b"")] 144 | loop_result = pad_bytearrays_to_tensor_loop(bytearrays) 145 | vectorized_result = pad_bytearrays_to_tensor(bytearrays) 146 | assert torch.equal(loop_result, vectorized_result) 147 | assert vectorized_result.shape == (3, 0) # All empty 148 | 149 | def test_single_empty_bytearray(self): 150 | """Test with a single empty bytearray.""" 151 | bytearrays = [bytearray(b"")] 152 | loop_result = pad_bytearrays_to_tensor_loop(bytearrays) 153 | vectorized_result = pad_bytearrays_to_tensor(bytearrays) 154 | assert torch.equal(loop_result, vectorized_result) 155 | assert vectorized_result.shape == (1, 0) 156 | 157 | def test_bytes_input(self): 158 | """Test that bytes (not bytearray) work correctly - matches type annotation.""" 159 | # The function signature says list[bytes], so test with bytes 160 | byte_list = [b"hello", b"hi", b"world!"] 161 | result = pad_bytearrays_to_tensor(byte_list) 162 | # Compare with bytearray version 163 | bytearray_list = [bytearray(b) for b in byte_list] 164 | expected = pad_bytearrays_to_tensor(bytearray_list) 165 | assert torch.equal(result, expected) 166 | 167 | def test_bytes_input_with_unicode(self): 168 | """Test bytes input with unicode content.""" 169 | byte_list = ["שלום".encode(), b"hello", "世界".encode()] 170 
| result = pad_bytearrays_to_tensor(byte_list) 171 | bytearray_list = [bytearray(b) for b in byte_list] 172 | expected = pad_bytearrays_to_tensor(bytearray_list) 173 | assert torch.equal(result, expected) 174 | 175 | def test_mixed_empty_positions(self): 176 | """Test empty bytearrays at different positions.""" 177 | bytearrays = [bytearray(b""), bytearray(b"abc"), bytearray(b""), bytearray(b"de")] 178 | loop_result = pad_bytearrays_to_tensor_loop(bytearrays) 179 | vectorized_result = pad_bytearrays_to_tensor(bytearrays) 180 | assert torch.equal(loop_result, vectorized_result) 181 | assert vectorized_result.shape == (4, 3) 182 | 183 | def test_very_long_sequences(self): 184 | """Test with very long byte sequences.""" 185 | bytearrays = [ 186 | bytearray(b"a" * 10000), 187 | bytearray(b"b" * 5000), 188 | bytearray(b"c" * 15000), 189 | ] 190 | loop_result = pad_bytearrays_to_tensor_loop(bytearrays) 191 | vectorized_result = pad_bytearrays_to_tensor(bytearrays) 192 | assert torch.equal(loop_result, vectorized_result) 193 | assert vectorized_result.shape == (3, 15000) 194 | 195 | def test_null_bytes(self): 196 | """Test bytearrays containing null bytes.""" 197 | bytearrays = [bytearray(b"\x00\x00\x00"), bytearray(b"a\x00b"), bytearray(b"\x00")] 198 | loop_result = pad_bytearrays_to_tensor_loop(bytearrays) 199 | vectorized_result = pad_bytearrays_to_tensor(bytearrays) 200 | assert torch.equal(loop_result, vectorized_result) 201 | # Verify null bytes are preserved 202 | assert vectorized_result[0, 0].item() == 0 203 | assert vectorized_result[1, 1].item() == 0 204 | 205 | 206 | if __name__ == "__main__": 207 | pytest.main([__file__, "-v"]) 208 | -------------------------------------------------------------------------------- /utf8_tokenizer/logits_processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | UTF-8 Validation Logits Processor for HuggingFace transformers. 3 | 4 | This module provides a LogitsProcessor that ensures generated sequences 5 | form valid UTF-8 byte sequences by masking out invalid continuations. 6 | """ 7 | 8 | import torch 9 | from transformers import LogitsProcessor 10 | 11 | 12 | class UTF8ValidationLogitsProcessor(LogitsProcessor): 13 | """ 14 | LogitsProcessor that enforces valid UTF-8 byte sequences during generation. 15 | 16 | This processor examines the previously generated bytes and masks out invalid 17 | next bytes according to UTF-8 encoding rules. This prevents the model from 18 | generating malformed UTF-8 sequences. 
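It assumes a byte-level vocabulary in which token id i corresponds to byte value i (as produced by UTF8Tokenizer, whose vocabulary is exactly the 256 byte values), so each masked logit position maps directly to one byte.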
19 | 20 | UTF-8 encoding rules (from Unicode Technical Committee, 2025, §3.9.3): 21 | - 1-byte: 00-7F (ASCII) 22 | - 2-byte: C2-DF followed by 80-BF 23 | - 3-byte: 24 | - E0 followed by A0-BF, then 80-BF 25 | - E1-EC followed by 80-BF, then 80-BF 26 | - ED followed by 80-9F, then 80-BF 27 | - EE-EF followed by 80-BF, then 80-BF 28 | - 4-byte: 29 | - F0 followed by 90-BF, then 80-BF, then 80-BF 30 | - F1-F3 followed by 80-BF, then 80-BF, then 80-BF 31 | - F4 followed by 80-8F, then 80-BF, then 80-BF 32 | 33 | Usage: 34 | from transformers import AutoModelForCausalLM 35 | from utf8_tokenizer.logits_processor import UTF8ValidationLogitsProcessor 36 | 37 | model = AutoModelForCausalLM.from_pretrained("your-model") 38 | processor = UTF8ValidationLogitsProcessor() 39 | 40 | outputs = model.generate( 41 | input_ids, 42 | logits_processor=[processor], 43 | max_new_tokens=100 44 | ) 45 | """ 46 | 47 | def __init__(self): 48 | """Initialize the UTF-8 validation logits processor.""" 49 | super().__init__() 50 | 51 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 52 | """ 53 | Process logits to enforce valid UTF-8 sequences. 54 | 55 | Args: 56 | input_ids: Previously generated token IDs of shape (batch_size, seq_len) 57 | scores: Logits for next token of shape (batch_size, vocab_size) 58 | 59 | Returns: 60 | Modified scores with invalid UTF-8 continuations masked to -inf 61 | """ 62 | batch_size = input_ids.shape[0] 63 | vocab_size = scores.shape[1] 64 | 65 | # Create a mask initialized to False (all tokens allowed by default) 66 | mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool, device=scores.device) 67 | 68 | for batch_idx in range(batch_size): 69 | # Get the UTF-8 state by examining the last few bytes 70 | seq = input_ids[batch_idx] 71 | allowed_bytes = self._get_allowed_next_bytes(seq) 72 | 73 | # Mask out all bytes that are not allowed 74 | for byte_val in range(256): 75 | if byte_val not in allowed_bytes: 76 | mask[batch_idx, byte_val] = True 77 | 78 | # Set masked positions to -inf 79 | scores = scores.masked_fill(mask, float("-inf")) 80 | 81 | return scores 82 | 83 | def _get_allowed_next_bytes(self, sequence: torch.Tensor) -> set[int]: 84 | """ 85 | Determine which bytes are valid as the next byte based on the sequence. 
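For example, after an incomplete 2-byte lead in 0xC2-0xDF, only continuation bytes 0x80-0xBF are valid.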
86 | 87 | Args: 88 | sequence: Tensor of previously generated byte IDs 89 | 90 | Returns: 91 | Set of allowed byte values (0-255) 92 | """ 93 | if len(sequence) == 0: 94 | # At the start, we can generate any valid UTF-8 start byte 95 | return self._valid_start_bytes() 96 | 97 | # Convert to list for easier manipulation 98 | seq_list = sequence.tolist() 99 | 100 | # Find the start of the current incomplete UTF-8 sequence 101 | # by looking backwards from the end 102 | utf8_state = self._get_utf8_state(seq_list) 103 | 104 | if utf8_state["complete"]: 105 | # Current sequence is complete, can start a new character 106 | return self._valid_start_bytes() 107 | else: 108 | # In the middle of a multi-byte sequence 109 | return self._valid_continuation_bytes(utf8_state) 110 | 111 | def _valid_start_bytes(self) -> set[int]: 112 | """Return all valid UTF-8 starting bytes.""" 113 | valid = set() 114 | 115 | # 1-byte sequences: 00-7F (ASCII) 116 | valid.update(range(0x00, 0x80)) 117 | 118 | # 2-byte sequences: C2-DF 119 | valid.update(range(0xC2, 0xE0)) 120 | 121 | # 3-byte sequences: E0-EF 122 | valid.update(range(0xE0, 0xF0)) 123 | 124 | # 4-byte sequences: F0-F4 125 | valid.update(range(0xF0, 0xF5)) 126 | 127 | return valid 128 | 129 | def _get_utf8_state(self, seq_list: list[int]) -> dict: # noqa: C901 130 | """ 131 | Analyze the sequence to determine the current UTF-8 state. 132 | 133 | Args: 134 | seq_list: List of byte values 135 | 136 | Returns: 137 | Dictionary with: 138 | - complete: bool indicating if current character is complete 139 | - first_byte: the first byte of the current sequence (if incomplete) 140 | - position: position within the multi-byte sequence (1, 2, or 3 for continuation) 141 | 142 | Note: Complexity is inherent to UTF-8 specification with multiple byte sequence types. 
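Example (illustrative): [0x48, 0xE2, 0x82] yields {"complete": False, "first_byte": 0xE2, "position": 2}, since 0xE2 opens a three-byte character and one continuation byte has been seen so far.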
143 | """ 144 | if not seq_list: 145 | return {"complete": True} 146 | 147 | # Look backwards to find the start of the current UTF-8 character 148 | for i in range(len(seq_list) - 1, -1, -1): 149 | byte_val = seq_list[i] 150 | 151 | # Check if this is a UTF-8 start byte 152 | if byte_val < 0x80: 153 | # ASCII - complete 154 | if i == len(seq_list) - 1: 155 | return {"complete": True} 156 | else: 157 | # We found an ASCII byte before the end, so the last bytes must be invalid 158 | # This shouldn't happen with proper validation, but treat as complete 159 | return {"complete": True} 160 | 161 | elif 0xC2 <= byte_val < 0xE0: 162 | # 2-byte sequence start 163 | bytes_after = len(seq_list) - i - 1 164 | if bytes_after >= 1: 165 | return {"complete": True} 166 | else: 167 | return {"complete": False, "first_byte": byte_val, "position": 1} 168 | 169 | elif 0xE0 <= byte_val < 0xF0: 170 | # 3-byte sequence start 171 | bytes_after = len(seq_list) - i - 1 172 | if bytes_after >= 2: 173 | return {"complete": True} 174 | else: 175 | return {"complete": False, "first_byte": byte_val, "position": bytes_after + 1} 176 | 177 | elif 0xF0 <= byte_val < 0xF5: 178 | # 4-byte sequence start 179 | bytes_after = len(seq_list) - i - 1 180 | if bytes_after >= 3: 181 | return {"complete": True} 182 | else: 183 | return {"complete": False, "first_byte": byte_val, "position": bytes_after + 1} 184 | 185 | elif 0x80 <= byte_val < 0xC0: 186 | # Continuation byte - keep looking backwards 187 | continue 188 | 189 | else: 190 | # Invalid UTF-8 byte (C0, C1, F5-FF) 191 | # Treat as complete to allow recovery 192 | if i == len(seq_list) - 1: 193 | return {"complete": True} 194 | 195 | # If we've gone through the whole sequence and only found continuation bytes, 196 | # treat as complete (malformed, but allow recovery) 197 | return {"complete": True} 198 | 199 | def _valid_continuation_bytes(self, state: dict) -> set[int]: # noqa: C901 200 | """ 201 | Get valid continuation bytes based on the current UTF-8 state. 202 | 203 | Args: 204 | state: UTF-8 state dictionary from _get_utf8_state 205 | 206 | Returns: 207 | Set of allowed byte values 208 | 209 | Note: Complexity is inherent to UTF-8 specification with multiple byte sequence types. 
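Example (illustrative): {"first_byte": 0xE0, "position": 1} yields the range 0xA0-0xBF, matching the E0 rule in the class docstring.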
210 | """ 211 | first_byte = state["first_byte"] 212 | position = state["position"] 213 | 214 | # 2-byte sequences: C2-DF 215 | if 0xC2 <= first_byte < 0xE0: 216 | if position == 1: 217 | # Second byte must be 80-BF 218 | return set(range(0x80, 0xC0)) 219 | 220 | # 3-byte sequences 221 | elif 0xE0 <= first_byte < 0xF0: 222 | if first_byte == 0xE0: 223 | # E0: second byte A0-BF, third byte 80-BF 224 | if position == 1: 225 | return set(range(0xA0, 0xC0)) 226 | elif position == 2: 227 | return set(range(0x80, 0xC0)) 228 | 229 | elif first_byte == 0xED: 230 | # ED: second byte 80-9F, third byte 80-BF 231 | if position == 1: 232 | return set(range(0x80, 0xA0)) 233 | elif position == 2: 234 | return set(range(0x80, 0xC0)) 235 | 236 | else: 237 | # E1-EC, EE-EF: second and third bytes 80-BF 238 | if position in (1, 2): 239 | return set(range(0x80, 0xC0)) 240 | 241 | # 4-byte sequences 242 | elif 0xF0 <= first_byte < 0xF5: 243 | if first_byte == 0xF0: 244 | # F0: second byte 90-BF, third and fourth bytes 80-BF 245 | if position == 1: 246 | return set(range(0x90, 0xC0)) 247 | elif position in (2, 3): 248 | return set(range(0x80, 0xC0)) 249 | 250 | elif first_byte == 0xF4: 251 | # F4: second byte 80-8F, third and fourth bytes 80-BF 252 | if position == 1: 253 | return set(range(0x80, 0x90)) 254 | elif position in (2, 3): 255 | return set(range(0x80, 0xC0)) 256 | 257 | else: 258 | # F1-F3: second, third, and fourth bytes 80-BF 259 | if position in (1, 2, 3): 260 | return set(range(0x80, 0xC0)) 261 | 262 | # If we get here, something went wrong - allow all start bytes to recover 263 | return self._valid_start_bytes() 264 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | import torch 5 | 6 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 7 | 8 | 9 | @pytest.fixture 10 | def tokenizer(): 11 | return UTF8Tokenizer() 12 | 13 | 14 | def test_basic_tokenization(tokenizer): 15 | """Test basic tokenization without special tokens.""" 16 | text = "hello" 17 | result = tokenizer._original_call(text, add_special_tokens=False) 18 | 19 | # Expected: [104, 101, 108, 108, 111] (UTF-8 bytes for "hello") 20 | expected_ids = [104, 101, 108, 108, 111] 21 | assert result.input_ids == expected_ids 22 | assert result.attention_mask == [1] * len(expected_ids) 23 | 24 | 25 | def test_unicode_string(tokenizer): 26 | """Test tokenization of a unicode string.""" 27 | text = "héllo עמית" 28 | result = tokenizer._original_call(text, add_special_tokens=False) 29 | 30 | decoded = tokenizer.decode(result.input_ids) 31 | assert decoded == text 32 | 33 | 34 | def test_tokenization_with_special_tokens(tokenizer): 35 | """Test tokenization with BOS and EOS tokens.""" 36 | text = "hello" 37 | result = tokenizer._original_call(text, add_special_tokens=True) 38 | 39 | # Expected: [2, 104, 101, 108, 108, 111, 3] (BOS + "hello" + EOS) 40 | expected_ids = [2, 104, 101, 108, 108, 111, 3] 41 | assert result.input_ids == expected_ids 42 | assert result.attention_mask == [1] * len(expected_ids) 43 | 44 | 45 | def test_torch_method_basic(tokenizer): 46 | """Test torch method with basic functionality.""" 47 | texts = ["hello", "world"] 48 | result = tokenizer.torch(texts, add_special_tokens=False) 49 | 50 | # Check shapes and types 51 | assert isinstance(result.input_ids, list) 52 | assert isinstance(result.attention_mask, list) 53 | assert len(result.input_ids) 
== 2 54 | assert len(result.attention_mask) == 2 55 | 56 | # Check individual tensors 57 | for i, text in enumerate(texts): 58 | expected_ids = [ord(c) for c in text] 59 | assert result.input_ids[i].tolist() == expected_ids 60 | assert result.input_ids[i].dtype == torch.uint8 61 | assert result.attention_mask[i].dtype == torch.bool 62 | assert result.attention_mask[i].tolist() == [1] * len(expected_ids) 63 | 64 | 65 | def test_torch_method_with_special_tokens(tokenizer): 66 | """Test torch method with special tokens.""" 67 | texts = ["hello", "world"] 68 | result = tokenizer.torch(texts, add_special_tokens=True) 69 | 70 | # Check that BOS and EOS are added 71 | for i, text in enumerate(texts): 72 | expected_ids = [2] + [ord(c) for c in text] + [3] # BOS + text + EOS 73 | assert result.input_ids[i].tolist() == expected_ids 74 | assert result.attention_mask[i].tolist() == [1] * len(expected_ids) 75 | 76 | 77 | def test_torch_method_with_padding(tokenizer): 78 | """Test torch method with padding enabled.""" 79 | texts = ["hi", "hello world"] # Different lengths 80 | result = tokenizer.torch(texts, add_special_tokens=False, padding=True) 81 | 82 | # Check that tensors are padded to same length 83 | assert isinstance(result.input_ids, torch.Tensor) 84 | assert isinstance(result.attention_mask, torch.Tensor) 85 | 86 | max_len = max(len(text) for text in texts) 87 | assert result.input_ids.shape == (2, max_len) 88 | assert result.attention_mask.shape == (2, max_len) 89 | 90 | # Check padding values 91 | assert result.input_ids[0, -1] == 0 # Pad token 92 | assert result.attention_mask[0, -1] == 0 # Padded attention 93 | 94 | 95 | def test_torch_method_with_truncation(tokenizer): 96 | """Test torch method with truncation.""" 97 | texts = ["hello world this is a long text"] 98 | max_length = 5 99 | result = tokenizer.torch(texts, add_special_tokens=False, truncation=True, max_length=max_length) 100 | 101 | # Check truncation 102 | assert len(result.input_ids[0]) == max_length 103 | assert len(result.attention_mask[0]) == max_length 104 | 105 | # Verify truncated content matches first 5 characters 106 | expected_ids = [ord(c) for c in texts[0][:max_length]] 107 | assert result.input_ids[0].tolist() == expected_ids 108 | 109 | 110 | def test_torch_method_special_tokens_with_padding(tokenizer): 111 | """Test torch method with both special tokens and padding.""" 112 | texts = ["hi", "hello"] 113 | result = tokenizer.torch(texts, add_special_tokens=True, padding=True) 114 | 115 | # Check dimensions 116 | assert isinstance(result.input_ids, torch.Tensor) 117 | assert isinstance(result.attention_mask, torch.Tensor) 118 | 119 | # Both sequences should have BOS + text + EOS, padded to same length 120 | max_len = max(len(text) + 2 for text in texts) # +2 for BOS/EOS 121 | assert result.input_ids.shape == (2, max_len) 122 | assert result.attention_mask.shape == (2, max_len) 123 | 124 | # Check first sequence: BOS + "hi" + EOS + padding 125 | expected_first = [2, ord("h"), ord("i"), 3, 0, 0, 0] 126 | assert result.input_ids[0].tolist() == expected_first 127 | assert result.attention_mask[0].tolist() == [1, 1, 1, 1, 0, 0, 0] 128 | 129 | 130 | @pytest.mark.parametrize("add_special_tokens", [True, False]) 131 | def test_comparison_tokenizer_vs_torch_method(add_special_tokens, tokenizer): 132 | """Test that tokenizer._original_call() and torch() methods produce compatible results.""" 133 | texts = ["test string"] 134 | 135 | result1 = tokenizer._original_call(texts, add_special_tokens=add_special_tokens, 
return_tensors="pt") 136 | result2 = tokenizer.torch(texts, add_special_tokens=add_special_tokens, padding=False) 137 | 138 | assert torch.equal(result1.input_ids[0], result2.input_ids[0]) 139 | assert torch.equal(result1.attention_mask[0], result2.attention_mask[0]) 140 | 141 | 142 | @pytest.mark.parametrize("add_special_tokens", [True, False]) 143 | def test_comparison_tokenizer_vs_torch_method_max_length(add_special_tokens, tokenizer): 144 | texts = ["test string"] 145 | 146 | params = dict(add_special_tokens=add_special_tokens, max_length=1, truncation=True) 147 | result1 = tokenizer._original_call(texts, return_tensors="pt", **params) 148 | result2 = tokenizer.torch(texts, **params) 149 | 150 | assert torch.equal(result1.input_ids[0], result2.input_ids[0]) 151 | assert torch.equal(result1.attention_mask[0], result2.attention_mask[0]) 152 | 153 | 154 | def test_comparison_tokenizer_vs_torch_method_multiple_strings(tokenizer): 155 | texts = ["test string", "shorter"] 156 | 157 | result1 = tokenizer._original_call(texts, padding=True, return_tensors="pt") 158 | result2 = tokenizer.torch(texts, padding=True) 159 | 160 | assert torch.equal(result1.input_ids[0], result2.input_ids[0]) 161 | assert torch.equal(result1.attention_mask[0], result2.attention_mask[0]) 162 | 163 | 164 | def test_empty_string(tokenizer): 165 | """Test tokenization of empty string.""" 166 | result = tokenizer._original_call("", add_special_tokens=False) 167 | assert result.input_ids == [] 168 | assert result.attention_mask == [] 169 | 170 | result_with_special = tokenizer._original_call("", add_special_tokens=True) 171 | assert result_with_special.input_ids == [2, 3] # BOS + EOS 172 | assert result_with_special.attention_mask == [1, 1] 173 | 174 | 175 | def test_special_characters(tokenizer): 176 | """Test tokenization with special characters and unicode.""" 177 | text = "hello\nworld\t!" 
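# "\n" (0x0A) and "\t" (0x09) encode to single ASCII bytes, so every character in this text maps to one byte id equal to ord(c).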
178 | result = tokenizer._original_call(text, add_special_tokens=False) 179 | 180 | expected_ids = [ord(c) for c in text] 181 | assert result.input_ids == expected_ids 182 | assert result.attention_mask == [1] * len(expected_ids) 183 | 184 | 185 | def test_dtype_consistency(tokenizer): 186 | """Test that dtypes are consistent across different methods.""" 187 | texts = ["hello", "world"] 188 | result = tokenizer.torch(texts, padding=True) 189 | 190 | assert result.input_ids.dtype == torch.uint8 191 | assert result.attention_mask.dtype == torch.bool 192 | 193 | 194 | def test_batch_consistency(tokenizer): 195 | """Test that batch processing produces consistent results.""" 196 | texts = ["hello", "world", "test"] 197 | 198 | # Process as batch 199 | batch_result = tokenizer.torch(texts, add_special_tokens=True, padding=True) 200 | 201 | # Process individually 202 | individual_results = [] 203 | for text in texts: 204 | result = tokenizer.torch([text], add_special_tokens=True, padding=False) 205 | individual_results.append(result.input_ids[0]) 206 | 207 | # Check that unpadded portions match 208 | for i, individual_ids in enumerate(individual_results): 209 | batch_ids = batch_result.input_ids[i][: len(individual_ids)] 210 | assert batch_ids.tolist() == individual_ids.tolist() 211 | 212 | 213 | def test_vocab_properties(tokenizer): 214 | """Test tokenizer vocabulary properties.""" 215 | assert tokenizer.vocab_size == 256 # 2^8 216 | assert tokenizer.pad_token_id == 0 217 | assert tokenizer.bos_token_id == 2 218 | assert tokenizer.eos_token_id == 3 219 | 220 | vocab = tokenizer.get_vocab() 221 | assert len(vocab) == 256 222 | assert vocab[chr(0)] == 0 223 | assert vocab[chr(2)] == 2 224 | assert vocab[chr(3)] == 3 225 | 226 | 227 | def test_tokenizer_save_and_load(): 228 | """Test saving and loading the tokenizer.""" 229 | original_tokenizer = UTF8Tokenizer() 230 | 231 | with tempfile.TemporaryDirectory() as temp_dir: 232 | # Save the tokenizer 233 | original_tokenizer.save_pretrained(temp_dir) 234 | 235 | # Load the tokenizer 236 | UTF8Tokenizer.from_pretrained(temp_dir) 237 | 238 | 239 | class TestTorchMethodEdgeCases: 240 | """Additional edge case tests for the .torch() method.""" 241 | 242 | @pytest.fixture 243 | def tokenizer(self): 244 | return UTF8Tokenizer() 245 | 246 | def test_single_element(self, tokenizer): 247 | """Test torch method with a single element.""" 248 | result = tokenizer.torch(["hello"], padding=True) 249 | assert isinstance(result.input_ids, torch.Tensor) 250 | assert result.input_ids.shape == (1, 7) # BOS + hello + EOS 251 | 252 | def test_all_empty_strings_no_special_tokens(self, tokenizer): 253 | """Test with all empty strings and no special tokens - extreme edge case.""" 254 | result = tokenizer.torch(["", ""], padding=True, add_special_tokens=False) 255 | assert isinstance(result.input_ids, torch.Tensor) 256 | assert result.input_ids.shape == (2, 0) # Empty tensor 257 | assert result.attention_mask.shape == (2, 0) 258 | 259 | def test_all_empty_strings_with_special_tokens(self, tokenizer): 260 | """Test with all empty strings but with special tokens.""" 261 | result = tokenizer.torch(["", ""], padding=True, add_special_tokens=True) 262 | assert result.input_ids.shape == (2, 2) # BOS + EOS only 263 | assert result.input_ids[0].tolist() == [2, 3] 264 | assert result.input_ids[1].tolist() == [2, 3] 265 | 266 | def test_device_parameter_cpu(self, tokenizer): 267 | """Test torch method with device parameter (CPU).""" 268 | result = tokenizer.torch(["hello", "hi"], 
padding=True, device="cpu") 269 | assert result.input_ids.device == torch.device("cpu") 270 | assert result.attention_mask.device == torch.device("cpu") 271 | 272 | def test_device_parameter_no_padding(self, tokenizer): 273 | """Test torch method with device parameter without padding.""" 274 | result = tokenizer.torch(["hello", "hi"], padding=False, device="cpu") 275 | assert isinstance(result.input_ids, list) 276 | assert result.input_ids[0].device == torch.device("cpu") 277 | assert result.attention_mask[0].device == torch.device("cpu") 278 | 279 | def test_unicode_strings_with_padding(self, tokenizer): 280 | """Test unicode strings with padding.""" 281 | texts = ["שלום", "hello", "世界"] # Hebrew, English, Chinese 282 | result = tokenizer.torch(texts, padding=True, add_special_tokens=True) 283 | assert isinstance(result.input_ids, torch.Tensor) 284 | # Hebrew שלום = 8 bytes, Chinese 世界 = 6 bytes, so Hebrew + BOS/EOS = 10 285 | max_len = max(len(t.encode("utf-8")) for t in texts) + 2 286 | assert result.input_ids.shape == (3, max_len) 287 | 288 | def test_truncation_with_padding(self, tokenizer): 289 | """Test truncation combined with padding.""" 290 | texts = ["hello world this is long", "hi"] 291 | result = tokenizer.torch(texts, padding=True, truncation=True, max_length=10, add_special_tokens=True) 292 | assert result.input_ids.shape == (2, 10) # Both truncated/padded to max_length 293 | # First should be truncated, second should be padded 294 | assert result.attention_mask[0].sum().item() == 10 # All active (truncated) 295 | assert result.attention_mask[1].sum().item() == 4 # BOS + hi + EOS = 4 296 | 297 | def test_truncation_without_special_tokens(self, tokenizer): 298 | """Test truncation without special tokens.""" 299 | texts = ["hello world"] 300 | result = tokenizer.torch(texts, truncation=True, max_length=5, add_special_tokens=False) 301 | assert len(result.input_ids[0]) == 5 302 | assert result.input_ids[0].tolist() == [ord(c) for c in "hello"] 303 | 304 | def test_string_with_null_bytes(self, tokenizer): 305 | """Test string containing null bytes (which is pad token).""" 306 | text = "a\x00b" 307 | result = tokenizer.torch([text], padding=False, add_special_tokens=False) 308 | # Null byte should be preserved 309 | assert result.input_ids[0].tolist() == [ord("a"), 0, ord("b")] 310 | 311 | def test_string_with_null_bytes_padding(self, tokenizer): 312 | """Test strings with null bytes and padding - null bytes ARE padding by design.""" 313 | texts = ["a\x00b", "xy"] 314 | result = tokenizer.torch(texts, padding=True, add_special_tokens=False) 315 | # First string: a, \x00, b (length 3) 316 | # Second string: x, y, pad (length 3) 317 | assert result.input_ids.shape == (2, 3) 318 | # Null bytes (0x00) are the pad token, so they're marked as not attended 319 | assert result.attention_mask[0].tolist() == [True, False, True] # a, pad, b 320 | assert result.attention_mask[1].tolist() == [True, True, False] # xy + pad 321 | 322 | def test_very_long_strings(self, tokenizer): 323 | """Test with very long strings.""" 324 | long_text = "a" * 10000 325 | result = tokenizer.torch([long_text], padding=False, add_special_tokens=True) 326 | assert len(result.input_ids[0]) == 10002 # BOS + 10000 + EOS 327 | 328 | def test_mixed_lengths_attention_mask(self, tokenizer): 329 | """Test that attention mask correctly identifies padded positions.""" 330 | texts = ["abc", "a", "abcde"] 331 | result = tokenizer.torch(texts, padding=True, add_special_tokens=False) 332 | # Lengths: 3, 1, 5 -> padded to 5 333 
| assert result.attention_mask[0].tolist() == [True, True, True, False, False] 334 | assert result.attention_mask[1].tolist() == [True, False, False, False, False] 335 | assert result.attention_mask[2].tolist() == [True, True, True, True, True] 336 | 337 | def test_return_tensors_pt_via_call(self, tokenizer): 338 | """Test that __call__ with return_tensors='pt' uses torch method.""" 339 | texts = ["hello", "world"] 340 | result = tokenizer(texts, padding=True, return_tensors="pt") 341 | assert "input_ids" in result 342 | assert "attention_mask" in result 343 | assert isinstance(result["input_ids"], torch.Tensor) 344 | 345 | def test_return_tensors_other_via_call(self, tokenizer): 346 | """Test that __call__ with other return_tensors falls back to original.""" 347 | texts = ["hello"] 348 | result = tokenizer(texts, return_tensors=None, add_special_tokens=False) 349 | # Without return_tensors, should use original call 350 | assert result.input_ids == [[104, 101, 108, 108, 111]] 351 | 352 | def test_truncation_warning_no_max_length(self, tokenizer): 353 | """Test that truncation without max_length issues warning.""" 354 | import warnings 355 | texts = ["hello world"] 356 | with warnings.catch_warnings(record=True) as w: 357 | warnings.simplefilter("always") 358 | result = tokenizer.torch(texts, truncation=True, max_length=None, add_special_tokens=False) 359 | assert len(w) == 1 360 | assert "no maximum length is provided" in str(w[0].message) 361 | # Should return untouched since truncation was disabled 362 | assert len(result.input_ids[0]) == 11 363 | 364 | 365 | if __name__ == "__main__": 366 | pytest.main([__file__, "-v"]) 367 | -------------------------------------------------------------------------------- /tests/test_logits_processor.py: -------------------------------------------------------------------------------- 1 | """Tests for UTF8ValidationLogitsProcessor.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from utf8_tokenizer.logits_processor import UTF8ValidationLogitsProcessor 7 | 8 | 9 | @pytest.fixture 10 | def processor(): 11 | """Create a UTF8ValidationLogitsProcessor instance.""" 12 | return UTF8ValidationLogitsProcessor() 13 | 14 | 15 | class TestUTF8State: 16 | """Test the internal UTF-8 state detection.""" 17 | 18 | def test_empty_sequence(self, processor): 19 | """Test that empty sequence is considered complete.""" 20 | state = processor._get_utf8_state([]) 21 | assert state["complete"] is True 22 | 23 | def test_ascii_complete(self, processor): 24 | """Test that ASCII characters are complete.""" 25 | # Single ASCII character 26 | state = processor._get_utf8_state([0x41]) # 'A' 27 | assert state["complete"] is True 28 | 29 | # Multiple ASCII characters 30 | state = processor._get_utf8_state([0x48, 0x65, 0x6C, 0x6C, 0x6F]) # "Hello" 31 | assert state["complete"] is True 32 | 33 | def test_two_byte_incomplete(self, processor): 34 | """Test incomplete 2-byte sequence.""" 35 | state = processor._get_utf8_state([0xC2]) # Start of 2-byte sequence 36 | assert state["complete"] is False 37 | assert state["first_byte"] == 0xC2 38 | assert state["position"] == 1 39 | 40 | def test_two_byte_complete(self, processor): 41 | """Test complete 2-byte sequence.""" 42 | state = processor._get_utf8_state([0xC2, 0xA9]) # © symbol 43 | assert state["complete"] is True 44 | 45 | def test_three_byte_incomplete_position_1(self, processor): 46 | """Test incomplete 3-byte sequence at position 1.""" 47 | state = processor._get_utf8_state([0xE2]) # Start of 3-byte sequence 48 | assert 
state["complete"] is False 49 | assert state["first_byte"] == 0xE2 50 | assert state["position"] == 1 51 | 52 | def test_three_byte_incomplete_position_2(self, processor): 53 | """Test incomplete 3-byte sequence at position 2.""" 54 | state = processor._get_utf8_state([0xE2, 0x82]) # Partial 3-byte sequence 55 | assert state["complete"] is False 56 | assert state["first_byte"] == 0xE2 57 | assert state["position"] == 2 58 | 59 | def test_three_byte_complete(self, processor): 60 | """Test complete 3-byte sequence.""" 61 | state = processor._get_utf8_state([0xE2, 0x82, 0xAC]) # € symbol 62 | assert state["complete"] is True 63 | 64 | def test_four_byte_incomplete_position_1(self, processor): 65 | """Test incomplete 4-byte sequence at position 1.""" 66 | state = processor._get_utf8_state([0xF0]) # Start of 4-byte sequence 67 | assert state["complete"] is False 68 | assert state["first_byte"] == 0xF0 69 | assert state["position"] == 1 70 | 71 | def test_four_byte_incomplete_position_2(self, processor): 72 | """Test incomplete 4-byte sequence at position 2.""" 73 | state = processor._get_utf8_state([0xF0, 0x9F]) 74 | assert state["complete"] is False 75 | assert state["first_byte"] == 0xF0 76 | assert state["position"] == 2 77 | 78 | def test_four_byte_incomplete_position_3(self, processor): 79 | """Test incomplete 4-byte sequence at position 3.""" 80 | state = processor._get_utf8_state([0xF0, 0x9F, 0x98]) 81 | assert state["complete"] is False 82 | assert state["first_byte"] == 0xF0 83 | assert state["position"] == 3 84 | 85 | def test_four_byte_complete(self, processor): 86 | """Test complete 4-byte sequence.""" 87 | state = processor._get_utf8_state([0xF0, 0x9F, 0x98, 0x80]) # 😀 emoji 88 | assert state["complete"] is True 89 | 90 | 91 | class TestValidStartBytes: 92 | """Test valid UTF-8 start bytes.""" 93 | 94 | def test_valid_start_bytes(self, processor): 95 | """Test that all valid start bytes are included.""" 96 | valid = processor._valid_start_bytes() 97 | 98 | # ASCII: 0x00-0x7F 99 | for i in range(0x00, 0x80): 100 | assert i in valid 101 | 102 | # 2-byte: 0xC2-0xDF 103 | for i in range(0xC2, 0xE0): 104 | assert i in valid 105 | 106 | # 3-byte: 0xE0-0xEF 107 | for i in range(0xE0, 0xF0): 108 | assert i in valid 109 | 110 | # 4-byte: 0xF0-0xF4 111 | for i in range(0xF0, 0xF5): 112 | assert i in valid 113 | 114 | def test_invalid_start_bytes_excluded(self, processor): 115 | """Test that invalid start bytes are not included.""" 116 | valid = processor._valid_start_bytes() 117 | 118 | # Continuation bytes: 0x80-0xBF 119 | for i in range(0x80, 0xC0): 120 | assert i not in valid 121 | 122 | # Invalid: 0xC0-0xC1 123 | assert 0xC0 not in valid 124 | assert 0xC1 not in valid 125 | 126 | # Invalid: 0xF5-0xFF 127 | for i in range(0xF5, 0x100): 128 | assert i not in valid 129 | 130 | 131 | class TestTwoByteSequences: 132 | """Test 2-byte UTF-8 sequences (C2-DF).""" 133 | 134 | def test_two_byte_continuation(self, processor): 135 | """Test that 2-byte sequences require 80-BF as second byte.""" 136 | # Test with C2 137 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xC2])) 138 | expected = set(range(0x80, 0xC0)) 139 | assert allowed == expected 140 | 141 | # Test with DF 142 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xDF])) 143 | expected = set(range(0x80, 0xC0)) 144 | assert allowed == expected 145 | 146 | 147 | class TestThreeByteSequences: 148 | """Test 3-byte UTF-8 sequences (E0-EF).""" 149 | 150 | def test_e0_second_byte(self, processor): 151 | """Test E0 requires 
A0-BF as second byte.""" 152 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xE0])) 153 | expected = set(range(0xA0, 0xC0)) 154 | assert allowed == expected 155 | 156 | def test_e0_third_byte(self, processor): 157 | """Test E0 requires 80-BF as third byte.""" 158 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xE0, 0xA0])) 159 | expected = set(range(0x80, 0xC0)) 160 | assert allowed == expected 161 | 162 | def test_ed_second_byte(self, processor): 163 | """Test ED requires 80-9F as second byte.""" 164 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xED])) 165 | expected = set(range(0x80, 0xA0)) 166 | assert allowed == expected 167 | 168 | def test_ed_third_byte(self, processor): 169 | """Test ED requires 80-BF as third byte.""" 170 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xED, 0x80])) 171 | expected = set(range(0x80, 0xC0)) 172 | assert allowed == expected 173 | 174 | def test_e1_ec_second_byte(self, processor): 175 | """Test E1-EC require 80-BF as second byte.""" 176 | for first_byte in [0xE1, 0xE5, 0xEC]: 177 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte])) 178 | expected = set(range(0x80, 0xC0)) 179 | assert allowed == expected 180 | 181 | def test_e1_ec_third_byte(self, processor): 182 | """Test E1-EC require 80-BF as third byte.""" 183 | for first_byte in [0xE1, 0xE5, 0xEC]: 184 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte, 0x80])) 185 | expected = set(range(0x80, 0xC0)) 186 | assert allowed == expected 187 | 188 | def test_ee_ef_second_byte(self, processor): 189 | """Test EE-EF require 80-BF as second byte.""" 190 | for first_byte in [0xEE, 0xEF]: 191 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte])) 192 | expected = set(range(0x80, 0xC0)) 193 | assert allowed == expected 194 | 195 | def test_ee_ef_third_byte(self, processor): 196 | """Test EE-EF require 80-BF as third byte.""" 197 | for first_byte in [0xEE, 0xEF]: 198 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte, 0x80])) 199 | expected = set(range(0x80, 0xC0)) 200 | assert allowed == expected 201 | 202 | 203 | class TestFourByteSequences: 204 | """Test 4-byte UTF-8 sequences (F0-F4).""" 205 | 206 | def test_f0_second_byte(self, processor): 207 | """Test F0 requires 90-BF as second byte.""" 208 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF0])) 209 | expected = set(range(0x90, 0xC0)) 210 | assert allowed == expected 211 | 212 | def test_f0_third_byte(self, processor): 213 | """Test F0 requires 80-BF as third byte.""" 214 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF0, 0x90])) 215 | expected = set(range(0x80, 0xC0)) 216 | assert allowed == expected 217 | 218 | def test_f0_fourth_byte(self, processor): 219 | """Test F0 requires 80-BF as fourth byte.""" 220 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF0, 0x90, 0x80])) 221 | expected = set(range(0x80, 0xC0)) 222 | assert allowed == expected 223 | 224 | def test_f4_second_byte(self, processor): 225 | """Test F4 requires 80-8F as second byte.""" 226 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF4])) 227 | expected = set(range(0x80, 0x90)) 228 | assert allowed == expected 229 | 230 | def test_f4_third_byte(self, processor): 231 | """Test F4 requires 80-BF as third byte.""" 232 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF4, 0x80])) 233 | expected = set(range(0x80, 0xC0)) 234 | assert allowed == expected 235 | 236 | def test_f4_fourth_byte(self, 
processor): 237 | """Test F4 requires 80-BF as fourth byte.""" 238 | allowed = processor._get_allowed_next_bytes(torch.tensor([0xF4, 0x80, 0x80])) 239 | expected = set(range(0x80, 0xC0)) 240 | assert allowed == expected 241 | 242 | def test_f1_f3_second_byte(self, processor): 243 | """Test F1-F3 require 80-BF as second byte.""" 244 | for first_byte in [0xF1, 0xF2, 0xF3]: 245 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte])) 246 | expected = set(range(0x80, 0xC0)) 247 | assert allowed == expected 248 | 249 | def test_f1_f3_third_byte(self, processor): 250 | """Test F1-F3 require 80-BF as third byte.""" 251 | for first_byte in [0xF1, 0xF2, 0xF3]: 252 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte, 0x80])) 253 | expected = set(range(0x80, 0xC0)) 254 | assert allowed == expected 255 | 256 | def test_f1_f3_fourth_byte(self, processor): 257 | """Test F1-F3 require 80-BF as fourth byte.""" 258 | for first_byte in [0xF1, 0xF2, 0xF3]: 259 | allowed = processor._get_allowed_next_bytes(torch.tensor([first_byte, 0x80, 0x80])) 260 | expected = set(range(0x80, 0xC0)) 261 | assert allowed == expected 262 | 263 | 264 | class TestLogitsProcessing: 265 | """Test the main logits processing functionality.""" 266 | 267 | def test_call_masks_invalid_bytes(self, processor): 268 | """Test that __call__ properly masks invalid bytes.""" 269 | # Start with a 2-byte sequence starter (C2) 270 | input_ids = torch.tensor([[0xC2]]) 271 | scores = torch.zeros((1, 256)) 272 | 273 | processed_scores = processor(input_ids, scores) 274 | 275 | # Only bytes 80-BF should be allowed (not -inf) 276 | for i in range(256): 277 | if 0x80 <= i < 0xC0: 278 | assert processed_scores[0, i] != float("-inf") 279 | else: 280 | assert processed_scores[0, i] == float("-inf") 281 | 282 | def test_call_batch_processing(self, processor): 283 | """Test batch processing with different sequences.""" 284 | # Batch of 3: one ASCII, one 2-byte start, one 3-byte start 285 | input_ids = torch.tensor([[0x41], [0xC2], [0xE2]]) 286 | scores = torch.zeros((3, 256)) 287 | 288 | processed_scores = processor(input_ids, scores) 289 | 290 | # First sequence (ASCII complete) - should allow all start bytes 291 | valid_start = processor._valid_start_bytes() 292 | for i in range(256): 293 | if i in valid_start: 294 | assert processed_scores[0, i] != float("-inf") 295 | else: 296 | assert processed_scores[0, i] == float("-inf") 297 | 298 | # Second sequence (2-byte incomplete) - should allow 80-BF 299 | for i in range(256): 300 | if 0x80 <= i < 0xC0: 301 | assert processed_scores[1, i] != float("-inf") 302 | else: 303 | assert processed_scores[1, i] == float("-inf") 304 | 305 | # Third sequence (3-byte incomplete) - should allow 80-BF 306 | for i in range(256): 307 | if 0x80 <= i < 0xC0: 308 | assert processed_scores[2, i] != float("-inf") 309 | else: 310 | assert processed_scores[2, i] == float("-inf") 311 | 312 | def test_empty_sequence_allows_all_start_bytes(self, processor): 313 | """Test that empty sequence allows all valid start bytes.""" 314 | input_ids = torch.tensor([[]]) 315 | scores = torch.zeros((1, 256)) 316 | 317 | processed_scores = processor(input_ids, scores) 318 | 319 | valid_start = processor._valid_start_bytes() 320 | for i in range(256): 321 | if i in valid_start: 322 | assert processed_scores[0, i] != float("-inf") 323 | else: 324 | assert processed_scores[0, i] == float("-inf") 325 | 326 | 327 | class TestRealWorldSequences: 328 | """Test with real-world UTF-8 sequences.""" 329 | 330 | def 
test_copyright_symbol(self, processor): 331 | """Test © symbol (U+00A9): C2 A9.""" 332 | # After C2, should allow 80-BF 333 | input_ids = torch.tensor([[0xC2]]) 334 | scores = torch.zeros((1, 256)) 335 | processed = processor(input_ids, scores) 336 | 337 | # A9 should be allowed 338 | assert processed[0, 0xA9] != float("-inf") 339 | # 7F should not be allowed (not in 80-BF range) 340 | assert processed[0, 0x7F] == float("-inf") 341 | 342 | def test_euro_symbol(self, processor): 343 | """Test € symbol (U+20AC): E2 82 AC.""" 344 | # After E2, should allow 80-BF 345 | input_ids = torch.tensor([[0xE2]]) 346 | scores = torch.zeros((1, 256)) 347 | processed = processor(input_ids, scores) 348 | assert processed[0, 0x82] != float("-inf") 349 | 350 | # After E2 82, should allow 80-BF 351 | input_ids = torch.tensor([[0xE2, 0x82]]) 352 | scores = torch.zeros((1, 256)) 353 | processed = processor(input_ids, scores) 354 | assert processed[0, 0xAC] != float("-inf") 355 | 356 | def test_emoji(self, processor): 357 | """Test 😀 emoji (U+1F600): F0 9F 98 80.""" 358 | # After F0, should allow 90-BF 359 | input_ids = torch.tensor([[0xF0]]) 360 | scores = torch.zeros((1, 256)) 361 | processed = processor(input_ids, scores) 362 | assert processed[0, 0x9F] != float("-inf") 363 | assert processed[0, 0x8F] == float("-inf") # Below 90 364 | 365 | # After F0 9F, should allow 80-BF 366 | input_ids = torch.tensor([[0xF0, 0x9F]]) 367 | scores = torch.zeros((1, 256)) 368 | processed = processor(input_ids, scores) 369 | assert processed[0, 0x98] != float("-inf") 370 | 371 | # After F0 9F 98, should allow 80-BF 372 | input_ids = torch.tensor([[0xF0, 0x9F, 0x98]]) 373 | scores = torch.zeros((1, 256)) 374 | processed = processor(input_ids, scores) 375 | assert processed[0, 0x80] != float("-inf") 376 | 377 | def test_mixed_sequence(self, processor): 378 | """Test a mixed sequence: Hello© (48 65 6C 6C 6F C2 A9).""" 379 | # After complete "Hello©", should allow all start bytes 380 | input_ids = torch.tensor([[0x48, 0x65, 0x6C, 0x6C, 0x6F, 0xC2, 0xA9]]) 381 | scores = torch.zeros((1, 256)) 382 | processed = processor(input_ids, scores) 383 | 384 | valid_start = processor._valid_start_bytes() 385 | for i in range(256): 386 | if i in valid_start: 387 | assert processed[0, i] != float("-inf") 388 | else: 389 | assert processed[0, i] == float("-inf") 390 | 391 | 392 | class TestWithActualModel: 393 | """Test UTF8ValidationLogitsProcessor with an actual language model.""" 394 | 395 | @pytest.mark.parametrize( 396 | ("prefix_bytes", "expected_range_start", "expected_range_end"), 397 | [ 398 | # 2-byte sequences 399 | ([0xC2], 0x80, 0xBF), # C2 -> 80-BF 400 | ([0xDF], 0x80, 0xBF), # DF -> 80-BF 401 | # 3-byte sequences 402 | ([0xE0], 0xA0, 0xBF), # E0 -> A0-BF (special) 403 | ([0xE1], 0x80, 0xBF), # E1 -> 80-BF 404 | ([0xED], 0x80, 0x9F), # ED -> 80-9F (special) 405 | ([0xEE], 0x80, 0xBF), # EE -> 80-BF 406 | # 4-byte sequences 407 | ([0xF0], 0x90, 0xBF), # F0 -> 90-BF (special) 408 | ([0xF1], 0x80, 0xBF), # F1 -> 80-BF 409 | ([0xF4], 0x80, 0x8F), # F4 -> 80-8F (special) 410 | # Continuation bytes 411 | ([0xE2, 0x82], 0x80, 0xBF), # E2 82 -> 80-BF 412 | ([0xF0, 0x9F], 0x80, 0xBF), # F0 9F -> 80-BF 413 | ([0xF0, 0x9F, 0x98], 0x80, 0xBF), # F0 9F 98 -> 80-BF 414 | ], 415 | ) 416 | def test_model_generation_with_processor(self, prefix_bytes, expected_range_start, expected_range_end): 417 | """ 418 | Test that the processor correctly constrains generation with an actual language model. 
419 | 420 | This test loads a small language model, resizes it to 256 tokens, and verifies that 421 | after generating with specific UTF-8 prefixes, the next byte falls within the expected range. 422 | """ 423 | from transformers import AutoModelForCausalLM 424 | 425 | # Load the model 426 | model = AutoModelForCausalLM.from_pretrained("sbintuitions/tiny-lm", torch_dtype="auto") 427 | 428 | # Resize model to 256 tokens for byte-level generation 429 | model.resize_token_embeddings(256) 430 | model.eval() 431 | 432 | # Create processor 433 | processor = UTF8ValidationLogitsProcessor() 434 | 435 | # Create input with the prefix bytes 436 | input_ids = torch.tensor([prefix_bytes]) 437 | 438 | # Generate a single token with the processor 439 | with torch.no_grad(): 440 | outputs = model.generate( 441 | input_ids, 442 | max_new_tokens=1, 443 | do_sample=False, # Use greedy decoding 444 | logits_processor=[processor], 445 | pad_token_id=0, 446 | ) 447 | 448 | # Get the generated byte (last token in the output) 449 | generated_byte = outputs[0, -1].item() 450 | 451 | # Verify the generated byte is in the expected range 452 | assert expected_range_start <= generated_byte <= expected_range_end, ( 453 | f"Generated byte {generated_byte:02X} not in expected range " 454 | f"{expected_range_start:02X}-{expected_range_end:02X}" 455 | ) 456 | 457 | 458 | if __name__ == "__main__": 459 | pytest.main([__file__, "-v"]) 460 | -------------------------------------------------------------------------------- /experiments/language-modelling/run_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # /// script 17 | # dependencies = [ 18 | # "transformers @ git+https://github.com/huggingface/transformers.git", 19 | # "albumentations >= 1.4.16", 20 | # "accelerate >= 0.12.0", 21 | # "torch >= 1.3", 22 | # "datasets >= 2.14.0", 23 | # "sentencepiece != 0.1.92", 24 | # "protobuf", 25 | # "evaluate", 26 | # "scikit-learn", 27 | # ] 28 | # /// 29 | 30 | """ 31 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 32 | 33 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 34 | https://huggingface.co/models?filter=text-generation 35 | """ 36 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
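# Note: the `# /// script` block above is PEP 723 inline script metadata; a compatible runner (for example, `uv run run_clm.py --help`) can resolve these dependencies on the fly, while plain `python run_clm.py` assumes they are already installed.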
37 | 38 | import logging 39 | import math 40 | import os 41 | import sys 42 | from dataclasses import dataclass, field 43 | from itertools import chain 44 | from typing import Optional 45 | 46 | import datasets 47 | import evaluate 48 | import torch 49 | from datasets import IterableDataset, IterableDatasetDict, load_dataset 50 | 51 | import transformers 52 | from transformers import ( 53 | CONFIG_MAPPING, 54 | MODEL_FOR_CAUSAL_LM_MAPPING, 55 | AutoConfig, 56 | AutoModelForCausalLM, 57 | AutoTokenizer, 58 | HfArgumentParser, 59 | Trainer, 60 | TrainingArguments, 61 | default_data_collator, 62 | is_torch_xla_available, 63 | set_seed, 64 | ) 65 | from transformers.testing_utils import CaptureLogger 66 | from transformers.utils.versions import require_version 67 | 68 | from utf8_tokenizer.byt5_comparison import ByT5ComparableTokenizer 69 | from utf8_tokenizer.embeddings import patch_embedding_layers 70 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 71 | 72 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 73 | # check_min_version("4.57.0.dev0") 74 | 75 | require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 76 | 77 | logger = logging.getLogger(__name__) 78 | 79 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 80 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 81 | 82 | 83 | @dataclass 84 | class ModelArguments: 85 | """ 86 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 87 | """ 88 | 89 | model_name_or_path: Optional[str] = field( 90 | default=None, 91 | metadata={ 92 | "help": ( 93 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 94 | ) 95 | }, 96 | ) 97 | model_type: Optional[str] = field( 98 | default=None, 99 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 100 | ) 101 | config_overrides: Optional[str] = field( 102 | default=None, 103 | metadata={ 104 | "help": ( 105 | "Override some existing default config settings when a model is trained from scratch. Example: " 106 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 107 | ) 108 | }, 109 | ) 110 | config_name: Optional[str] = field( 111 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 112 | ) 113 | tokenizer_name: Optional[str] = field( 114 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 115 | ) 116 | cache_dir: Optional[str] = field( 117 | default=None, 118 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 119 | ) 120 | use_fast_tokenizer: bool = field( 121 | default=True, 122 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 123 | ) 124 | use_bit_embeddings: bool = field( 125 | default=False, 126 | metadata={"help": "Whether to use bit embeddings."}, 127 | ) 128 | model_revision: str = field( 129 | default="main", 130 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 131 | ) 132 | token: str = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "The token to use as HTTP bearer authorization for remote files. 
If not specified, will use the token " 137 | "generated when running `hf auth login` (stored in `~/.huggingface`)." 138 | ) 139 | }, 140 | ) 141 | trust_remote_code: bool = field( 142 | default=False, 143 | metadata={ 144 | "help": ( 145 | "Whether to trust the execution of code from datasets/models defined on the Hub." 146 | " This option should only be set to `True` for repositories you trust and in which you have read the" 147 | " code, as it will execute code present on the Hub on your local machine." 148 | ) 149 | }, 150 | ) 151 | dtype: Optional[str] = field( 152 | default=None, 153 | metadata={ 154 | "help": ( 155 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 156 | "dtype will be automatically derived from the model's weights." 157 | ), 158 | "choices": ["auto", "bfloat16", "float16", "float32"], 159 | }, 160 | ) 161 | 162 | def __post_init__(self): 163 | if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): 164 | raise ValueError( 165 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 166 | ) 167 | 168 | 169 | @dataclass 170 | class DataTrainingArguments: 171 | """ 172 | Arguments pertaining to what data we are going to input our model for training and eval. 173 | """ 174 | 175 | dataset_name: Optional[str] = field( 176 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 177 | ) 178 | dataset_config_name: Optional[str] = field( 179 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 180 | ) 181 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 182 | validation_file: Optional[str] = field( 183 | default=None, 184 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 185 | ) 186 | max_train_samples: Optional[int] = field( 187 | default=None, 188 | metadata={ 189 | "help": ( 190 | "For debugging purposes or quicker training, truncate the number of training examples to this " 191 | "value if set." 192 | ) 193 | }, 194 | ) 195 | max_eval_samples: Optional[int] = field( 196 | default=None, 197 | metadata={ 198 | "help": ( 199 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 200 | "value if set." 201 | ) 202 | }, 203 | ) 204 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 205 | block_size: Optional[int] = field( 206 | default=None, 207 | metadata={ 208 | "help": ( 209 | "Optional input sequence length after tokenization. " 210 | "The training dataset will be truncated in block of this size for training. " 211 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 
212 |             )
213 |         },
214 |     )
215 |     overwrite_cache: bool = field(
216 |         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
217 |     )
218 |     validation_split_percentage: Optional[int] = field(
219 |         default=5,
220 |         metadata={
221 |             "help": "The percentage of the train set used as validation set in case there's no validation split"
222 |         },
223 |     )
224 |     preprocessing_num_workers: Optional[int] = field(
225 |         default=None,
226 |         metadata={"help": "The number of processes to use for the preprocessing."},
227 |     )
228 |     keep_linebreaks: bool = field(
229 |         default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
230 |     )
231 | 
232 |     def __post_init__(self):
233 |         if self.streaming:
234 |             require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
235 | 
236 |         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
237 |             raise ValueError("Need either a dataset name or a training/validation file.")
238 |         else:
239 |             if self.train_file is not None:
240 |                 extension = self.train_file.split(".")[-1]
241 |                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
242 |             if self.validation_file is not None:
243 |                 extension = self.validation_file.split(".")[-1]
244 |                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
245 | 
246 | 
247 | def split_streaming_dataset(
248 |     full_streaming_dataset,
249 |     validation_percentage: int = 5,
250 | ) -> IterableDatasetDict:
251 |     """
252 |     Splits a streaming dataset into training and validation IterableDatasets. The resulting
253 |     streams still support methods like .map(), .filter(), .take() and properties like
254 |     .features.
255 | 
256 |     Args:
257 |         full_streaming_dataset (IterableDataset): The streaming dataset to split (e.g., one loaded with `streaming=True`).
258 |         validation_percentage (int): The percentage of examples to route to the validation split.
259 | 
260 |     Returns:
261 |         IterableDatasetDict: An IterableDatasetDict containing two IterableDataset objects: (train_stream, validation_stream).
262 |     """
263 |     if not (0 < validation_percentage < 100):
264 |         raise ValueError(
265 |             f"validation_percentage must be between 0 and 100 (exclusive). Passed: {validation_percentage}"
266 |         )
267 | 
268 |     def split_generator(is_train: bool):
269 |         for i, example in enumerate(full_streaming_dataset):
270 |             if is_train:
271 |                 if i % 100 >= validation_percentage:  # complement of the validation condition below, so no example is dropped
272 |                     yield example
273 |             else:
274 |                 if i % 100 < validation_percentage:
275 |                     yield example
276 | 
277 |     features = full_streaming_dataset.features
278 |     train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features)
279 |     validation_stream = IterableDataset.from_generator(
280 |         split_generator, gen_kwargs={"is_train": False}, features=features
281 |     )
282 | 
283 |     return IterableDatasetDict({"train": train_stream, "validation": validation_stream})
284 | 
285 | 
286 | def main():
287 |     # See all possible arguments in src/transformers/training_args.py
288 |     # or by passing the --help flag to this script.
289 |     # We now keep distinct sets of args, for a cleaner separation of concerns.
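    # A minimal usage sketch for the `split_streaming_dataset` helper defined above
    # (illustrative only; the dataset name and the 5% figure are assumptions, not values
    # hard-coded by this script):
    #
    #     stream = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", streaming=True)
    #     splits = split_streaming_dataset(stream, validation_percentage=5)
    #     train_stream, validation_stream = splits["train"], splits["validation"]
    #
    # Examples are routed by position: indices with i % 100 < 5 become validation data and
    # the rest training data, so the two streams never overlap.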
290 | 291 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 292 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 293 | # If we pass only one argument to the script and it's the path to a json file, 294 | # let's parse it to get our arguments. 295 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 296 | else: 297 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 298 | 299 | # Setup logging 300 | logging.basicConfig( 301 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 302 | datefmt="%m/%d/%Y %H:%M:%S", 303 | handlers=[logging.StreamHandler(sys.stdout)], 304 | ) 305 | 306 | if training_args.should_log: 307 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 308 | transformers.utils.logging.set_verbosity_info() 309 | 310 | log_level = training_args.get_process_log_level() 311 | logger.setLevel(log_level) 312 | datasets.utils.logging.set_verbosity(log_level) 313 | transformers.utils.logging.set_verbosity(log_level) 314 | transformers.utils.logging.enable_default_handler() 315 | transformers.utils.logging.enable_explicit_format() 316 | 317 | # Log on each process the small summary: 318 | logger.warning( 319 | f"Process rank: {training_args.local_process_index}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 320 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 321 | ) 322 | logger.info(f"Training/evaluation parameters {training_args}") 323 | 324 | # Set seed before initializing model. 325 | set_seed(training_args.seed) 326 | 327 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 328 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 329 | # (the dataset will be downloaded automatically from the datasets Hub). 330 | # 331 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 332 | # 'text' is found. You can easily tweak this behavior (see below). 333 | # 334 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 335 | # download the dataset. 336 | if data_args.dataset_name is not None: 337 | # Downloading and loading a dataset from the hub. 
338 | raw_datasets = load_dataset( 339 | data_args.dataset_name, 340 | data_args.dataset_config_name, 341 | cache_dir=model_args.cache_dir, 342 | token=model_args.token, 343 | streaming=data_args.streaming, 344 | trust_remote_code=model_args.trust_remote_code, 345 | ) 346 | if "validation" not in raw_datasets: 347 | if data_args.streaming: 348 | dataset_stream = load_dataset( 349 | data_args.dataset_name, 350 | data_args.dataset_config_name, 351 | split="train", 352 | cache_dir=model_args.cache_dir, 353 | token=model_args.token, 354 | streaming=data_args.streaming, 355 | trust_remote_code=model_args.trust_remote_code, 356 | ) 357 | raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage) 358 | else: 359 | raw_datasets["validation"] = load_dataset( 360 | data_args.dataset_name, 361 | data_args.dataset_config_name, 362 | split=f"train[:{data_args.validation_split_percentage}%]", 363 | cache_dir=model_args.cache_dir, 364 | token=model_args.token, 365 | streaming=data_args.streaming, 366 | trust_remote_code=model_args.trust_remote_code, 367 | ) 368 | raw_datasets["train"] = load_dataset( 369 | data_args.dataset_name, 370 | data_args.dataset_config_name, 371 | split=f"train[{data_args.validation_split_percentage}%:]", 372 | cache_dir=model_args.cache_dir, 373 | token=model_args.token, 374 | streaming=data_args.streaming, 375 | trust_remote_code=model_args.trust_remote_code, 376 | ) 377 | else: 378 | data_files = {} 379 | dataset_args = {} 380 | if data_args.train_file is not None: 381 | data_files["train"] = data_args.train_file 382 | if data_args.validation_file is not None: 383 | data_files["validation"] = data_args.validation_file 384 | extension = ( 385 | data_args.train_file.split(".")[-1] 386 | if data_args.train_file is not None 387 | else data_args.validation_file.split(".")[-1] 388 | ) 389 | if extension == "txt": 390 | extension = "text" 391 | dataset_args["keep_linebreaks"] = data_args.keep_linebreaks 392 | raw_datasets = load_dataset( 393 | extension, 394 | data_files=data_files, 395 | cache_dir=model_args.cache_dir, 396 | token=model_args.token, 397 | **dataset_args, 398 | ) 399 | # If no validation data is there, validation_split_percentage will be used to divide the dataset. 400 | if "validation" not in raw_datasets: 401 | if data_args.streaming: 402 | dataset_stream = load_dataset( 403 | extension, 404 | data_files=data_files, 405 | split="train", 406 | cache_dir=model_args.cache_dir, 407 | token=model_args.token, 408 | **dataset_args, 409 | ) 410 | raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage) 411 | else: 412 | raw_datasets["validation"] = load_dataset( 413 | extension, 414 | data_files=data_files, 415 | split=f"train[:{data_args.validation_split_percentage}%]", 416 | cache_dir=model_args.cache_dir, 417 | token=model_args.token, 418 | **dataset_args, 419 | ) 420 | 421 | raw_datasets["train"] = load_dataset( 422 | extension, 423 | data_files=data_files, 424 | split=f"train[{data_args.validation_split_percentage}%:]", 425 | cache_dir=model_args.cache_dir, 426 | token=model_args.token, 427 | **dataset_args, 428 | ) 429 | 430 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 431 | # https://huggingface.co/docs/datasets/loading_datasets. 
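    # For reference, the non-streaming validation fallback above relies on the `datasets`
    # percent-slicing syntax for splits. A standalone sketch (dataset name chosen purely for
    # illustration):
    #
    #     val = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")
    #     train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[5%:]")
    #
    # i.e. the first `validation_split_percentage` percent of the train split is held out for
    # evaluation and the remainder is used for training.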
432 | 
433 |     # Load pretrained model and tokenizer
434 |     #
435 |     # Distributed training:
436 |     # The .from_pretrained methods guarantee that only one local process can concurrently
437 |     # download model & vocab.
438 | 
439 |     config_kwargs = {
440 |         "cache_dir": model_args.cache_dir,
441 |         "revision": model_args.model_revision,
442 |         "token": model_args.token,
443 |         "trust_remote_code": model_args.trust_remote_code,
444 |     }
445 |     if model_args.config_name:
446 |         config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
447 |     elif model_args.model_name_or_path:
448 |         config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
449 |     else:
450 |         config = CONFIG_MAPPING[model_args.model_type]()
451 |         logger.warning("You are instantiating a new config instance from scratch.")
452 |         if model_args.config_overrides is not None:
453 |             logger.info(f"Overriding config: {model_args.config_overrides}")
454 |             config.update_from_string(model_args.config_overrides)
455 |             logger.info(f"New config: {config}")
456 | 
457 |     if model_args.tokenizer_name:
458 |         if model_args.tokenizer_name.startswith("google/byt5"):
459 |             tokenizer = ByT5ComparableTokenizer()
460 |         else:
461 |             tokenizer_kwargs = {
462 |                 "cache_dir": model_args.cache_dir,
463 |                 "use_fast": model_args.use_fast_tokenizer,
464 |                 "revision": model_args.model_revision,
465 |                 "token": model_args.token,
466 |                 "trust_remote_code": model_args.trust_remote_code,
467 |             }
468 |             # `tokenizer_name` is set on this branch, so load it directly.
469 |             tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
477 |     else:
478 |         tokenizer = UTF8Tokenizer()
479 | 
480 |     if model_args.model_name_or_path:
481 |         dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
482 |         model = AutoModelForCausalLM.from_pretrained(
483 |             model_args.model_name_or_path,
484 |             from_tf=bool(".ckpt" in model_args.model_name_or_path),
485 |             config=config,
486 |             cache_dir=model_args.cache_dir,
487 |             revision=model_args.model_revision,
488 |             token=model_args.token,
489 |             trust_remote_code=model_args.trust_remote_code,
490 |             dtype=dtype,
491 |         )
492 |     else:
493 |         model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
494 |         n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
495 |         logger.info(f"Training new model from scratch - Total size={n_params / 2 ** 20:.2f}M params")
496 | 
497 |     # The embedding matrix is always resized to match the tokenizer's (byte-level) vocabulary size;
498 |     # the original "resize only when the tokenizer outgrows the embeddings" check is kept below for reference.
499 |     embedding_size = model.get_input_embeddings().weight.shape[0]
500 |     # if len(tokenizer) > embedding_size:
501 |     model.resize_token_embeddings(len(tokenizer))
502 | 
503 |     if model_args.use_bit_embeddings:
504 |         patch_embedding_layers(model)
505 | 
506 |     # Preprocessing the datasets.
507 |     # First we tokenize all the texts.
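    # With the default UTF8Tokenizer, "tokenizing" is essentially a byte-level mapping: each
    # UTF-8 byte of the text becomes one token id. A rough sketch (exact ids and any special
    # tokens depend on the tokenizer implementation, so treat the numbers as illustrative):
    #
    #     ids = tokenizer("héllo")["input_ids"]  # "héllo" is 6 UTF-8 bytes ("é" takes two),
    #                                            # so roughly 6 ids come back
    #
    # A practical consequence is that `block_size` below counts bytes rather than words or
    # subword tokens.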
508 | if training_args.do_train: 509 | column_names = list(raw_datasets["train"].features) 510 | else: 511 | column_names = list(raw_datasets["validation"].features) 512 | text_column_name = "text" if "text" in column_names else column_names[0] 513 | 514 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function 515 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 516 | 517 | def tokenize_function(examples): 518 | with CaptureLogger(tok_logger) as cl: 519 | output = tokenizer(examples[text_column_name]) 520 | # clm input could be much much longer than block_size 521 | if "Token indices sequence length is longer than the" in cl.out: 522 | tok_logger.warning( 523 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" 524 | " before being passed to the model." 525 | ) 526 | return output 527 | 528 | with training_args.main_process_first(desc="dataset map tokenization"): 529 | if not data_args.streaming: 530 | tokenized_datasets = raw_datasets.map( 531 | tokenize_function, 532 | batched=True, 533 | num_proc=data_args.preprocessing_num_workers, 534 | remove_columns=column_names, 535 | load_from_cache_file=not data_args.overwrite_cache, 536 | desc="Running tokenizer on dataset", 537 | ) 538 | else: 539 | tokenized_datasets = raw_datasets.map( 540 | tokenize_function, 541 | batched=True, 542 | remove_columns=column_names, 543 | ) 544 | if hasattr(config, "max_position_embeddings"): 545 | max_pos_embeddings = config.max_position_embeddings 546 | else: 547 | # Define a default value if the attribute is missing in the config. 548 | max_pos_embeddings = 1024 549 | 550 | if data_args.block_size is None: 551 | block_size = tokenizer.model_max_length 552 | if block_size > max_pos_embeddings: 553 | logger.warning( 554 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 555 | f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx." 556 | ) 557 | if max_pos_embeddings > 0: 558 | block_size = min(1024, max_pos_embeddings) 559 | else: 560 | block_size = 1024 561 | else: 562 | if data_args.block_size > tokenizer.model_max_length: 563 | logger.warning( 564 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model " 565 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 566 | ) 567 | block_size = min(data_args.block_size, tokenizer.model_max_length) 568 | 569 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 570 | def group_texts(examples): 571 | # Concatenate all texts. 572 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples} 573 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 574 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 575 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. 576 | total_length = (total_length // block_size) * block_size 577 | # Split by chunks of max_len. 
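            # Worked example (illustrative): with block_size = 4 and 10 concatenated ids,
            # total_length is rounded down to 8 and the ids are regrouped as
            #     [t0, t1, t2, t3], [t4, t5, t6, t7]
            # with t8 and t9 dropped as the remainder.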
578 | result = { 579 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)] 580 | for k, t in concatenated_examples.items() 581 | } 582 | result["labels"] = result["input_ids"].copy() 583 | return result 584 | 585 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder 586 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower 587 | # to preprocess. 588 | # 589 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 590 | # https://huggingface.co/docs/datasets/process#map 591 | 592 | with training_args.main_process_first(desc="grouping texts together"): 593 | if not data_args.streaming: 594 | lm_datasets = tokenized_datasets.map( 595 | group_texts, 596 | batched=True, 597 | num_proc=data_args.preprocessing_num_workers, 598 | load_from_cache_file=not data_args.overwrite_cache, 599 | desc=f"Grouping texts in chunks of {block_size}", 600 | ) 601 | else: 602 | lm_datasets = tokenized_datasets.map( 603 | group_texts, 604 | batched=True, 605 | ) 606 | 607 | if training_args.do_train: 608 | if "train" not in tokenized_datasets: 609 | raise ValueError("--do_train requires a train dataset") 610 | train_dataset = lm_datasets["train"] 611 | if data_args.max_train_samples is not None: 612 | if data_args.streaming: 613 | train_dataset = train_dataset.take(data_args.max_train_samples) 614 | else: 615 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 616 | train_dataset = train_dataset.select(range(max_train_samples)) 617 | 618 | if training_args.do_eval: 619 | if "validation" not in tokenized_datasets: 620 | raise ValueError("--do_eval requires a validation dataset") 621 | eval_dataset = lm_datasets["validation"] 622 | if data_args.max_eval_samples is not None: 623 | if data_args.streaming: 624 | eval_dataset = eval_dataset.take(data_args.max_eval_samples) 625 | else: 626 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 627 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 628 | 629 | def preprocess_logits_for_metrics(logits, labels): 630 | if isinstance(logits, tuple): 631 | # Depending on the model and config, logits may contain extra tensors, 632 | # like past_key_values, but logits always come first 633 | logits = logits[0] 634 | return logits.argmax(dim=-1) 635 | 636 | metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir) 637 | 638 | def compute_metrics(eval_preds): 639 | preds, labels = eval_preds 640 | # preds have the same shape as the labels, after the argmax(-1) has been calculated 641 | # by preprocess_logits_for_metrics but we need to shift the labels 642 | labels = labels[:, 1:].reshape(-1) 643 | preds = preds[:, :-1].reshape(-1) 644 | return metric.compute(predictions=preds, references=labels) 645 | 646 | # Initialize our Trainer 647 | trainer = Trainer( 648 | model=model, 649 | args=training_args, 650 | train_dataset=train_dataset if training_args.do_train else None, 651 | eval_dataset=eval_dataset if training_args.do_eval else None, 652 | processing_class=tokenizer, 653 | # Data collator will default to DataCollatorWithPadding, so we change it. 
654 | data_collator=default_data_collator, 655 | compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None, 656 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 657 | if training_args.do_eval and not is_torch_xla_available() 658 | else None, 659 | ) 660 | 661 | # Training 662 | if training_args.do_train: 663 | checkpoint = None 664 | if training_args.resume_from_checkpoint is not None: 665 | checkpoint = training_args.resume_from_checkpoint 666 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 667 | trainer.save_model() # Saves the tokenizer too for easy upload 668 | 669 | metrics = train_result.metrics 670 | 671 | if data_args.max_train_samples is not None: 672 | max_train_samples = data_args.max_train_samples 673 | elif data_args.streaming: 674 | max_train_samples = 0 # TODO: figure out a better way to get the length of streaming dataset 675 | else: 676 | max_train_samples = len(train_dataset) 677 | 678 | if data_args.streaming: 679 | metrics["train_samples"] = max_train_samples 680 | else: 681 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 682 | 683 | trainer.log_metrics("train", metrics) 684 | trainer.save_metrics("train", metrics) 685 | trainer.save_state() 686 | 687 | # Evaluation 688 | if training_args.do_eval: 689 | logger.info("*** Evaluate ***") 690 | 691 | metrics = trainer.evaluate() 692 | 693 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 694 | if data_args.streaming: 695 | metrics["eval_samples"] = max_eval_samples 696 | else: 697 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 698 | 699 | try: 700 | perplexity = math.exp(metrics["eval_loss"]) 701 | except OverflowError: 702 | perplexity = float("inf") 703 | metrics["perplexity"] = perplexity 704 | 705 | trainer.log_metrics("eval", metrics) 706 | trainer.save_metrics("eval", metrics) 707 | 708 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 709 | if data_args.dataset_name is not None: 710 | kwargs["dataset_tags"] = data_args.dataset_name 711 | if data_args.dataset_config_name is not None: 712 | kwargs["dataset_args"] = data_args.dataset_config_name 713 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 714 | else: 715 | kwargs["dataset"] = data_args.dataset_name 716 | 717 | if training_args.push_to_hub: 718 | trainer.push_to_hub(**kwargs) 719 | else: 720 | trainer.create_model_card(**kwargs) 721 | 722 | 723 | def _mp_fn(index): 724 | # For xla_spawn (TPUs) 725 | main() 726 | 727 | 728 | if __name__ == "__main__": 729 | main() 730 | --------------------------------------------------------------------------------