├── tests ├── __init__.py ├── languages │ ├── __init__.py │ └── test_chinese.py ├── tokenizers │ ├── __init__.py │ ├── test_bne.py │ └── test_bpe.py ├── test_pretokenizer.py └── test_graph.py ├── complex_tokenization ├── __init__.py ├── examples │ ├── __init__.py │ ├── utils.py │ ├── bpe.py │ └── bne.py ├── graphs │ ├── __init__.py │ ├── settings.py │ ├── units.py │ └── words.py ├── languages │ ├── __init__.py │ └── chinese │ │ ├── __init__.py │ │ ├── .gitignore │ │ ├── create_dictionary.py │ │ ├── ideographic_description_sequences.py │ │ └── frequency.py ├── draw.py ├── trainer.py └── graph.py ├── .gitignore ├── .vscode └── settings.json ├── .github └── workflows │ ├── lint.yaml │ ├── test.yaml │ └── release.yaml ├── pyproject.toml ├── LICENSE └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/languages/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/languages/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/languages/chinese/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /complex_tokenization/languages/chinese/.gitignore: -------------------------------------------------------------------------------- 1 | cjkvi-ids/ 2 | hfhchan-ids/ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .claude/ 3 | *.egg-info 4 | build/ 5 | dist/ 6 | .env 7 | .DS_Store 8 | **/__pycache__/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "." 
4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true 7 | } -------------------------------------------------------------------------------- /complex_tokenization/graphs/settings.py: -------------------------------------------------------------------------------- 1 | class GraphSettings: 2 | USE_SINGLETONS = False # speeds up computation but hurts visualization 3 | MAX_MERGE_SIZE = 2 4 | ONLY_MINIMAL_MERGES = True 5 | -------------------------------------------------------------------------------- /complex_tokenization/examples/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | 4 | def text_dataset(max_samples=None, 5 | dataset="Salesforce/wikitext", 6 | dataset_config="wikitext-2-raw-v1"): 7 | dataset = load_dataset(dataset, dataset_config, streaming=True, split="train") 8 | if max_samples is not None: 9 | dataset = dataset.take(max_samples) 10 | return (sample["text"] for sample in dataset) 11 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | 11 | jobs: 12 | test: 13 | name: Lint 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - name: Setup uv 20 | uses: astral-sh/setup-uv@v6.5.0 21 | with: 22 | python-version: "3.12" 23 | enable-cache: true 24 | activate-environment: true 25 | 26 | - name: Install dependencies 27 | run: uv pip install ".[dev]" 28 | 29 | - name: Lint code 30 | run: uv run ruff check . -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | 11 | jobs: 12 | test: 13 | name: Test 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - name: Setup uv 20 | uses: astral-sh/setup-uv@v6.5.0 21 | with: 22 | python-version: "3.12" 23 | enable-cache: true 24 | activate-environment: true 25 | 26 | - name: Install dependencies 27 | run: uv pip install ".[dev]" 28 | 29 | - name: Test Code 30 | run: uv run pytest -n auto --dist loadscope 31 | -------------------------------------------------------------------------------- /tests/tokenizers/test_bne.py: -------------------------------------------------------------------------------- 1 | from complex_tokenization.examples.bne import train_bne_tokenizer 2 | from complex_tokenization.examples.utils import text_dataset 3 | 4 | 5 | class TestBNE: 6 | def test_large_train_bne_tokenizer(self): 7 | """Test training BNE tokenizer with n=4 and expected merges""" 8 | texts = list(text_dataset(max_samples=10)) 9 | merges = train_bne_tokenizer(texts, n=4, num_merges=10) 10 | 11 | expected = [ 12 | (' ', 't', 'h', 'e'), 13 | (' ', 'a'), 14 | ('i', 'o', 'n'), 15 | ('e', 'r'), 16 | (' ', 'g', 'a', 'm'), 17 | (' ', 'V', 'a', 'l'), 18 | ('e', 's'), 19 | ('i', 'n'), 20 | (' Val', 'k', 'y', 'r'), 21 | ('e', 'd') 22 | ] 23 | 24 | assert merges == expected 25 | -------------------------------------------------------------------------------- /complex_tokenization/graphs/units.py: -------------------------------------------------------------------------------- 1 | 
import regex
2 | 
3 | from complex_tokenization.graph import GraphVertex, Node, NodesSequence
4 | 
5 | 
6 | def characters(s: str) -> GraphVertex:
7 |     nodes = [Node(c) for c in s]
8 | 
9 |     if len(nodes) == 1:
10 |         return nodes[0]
11 |     return NodesSequence(nodes=tuple(nodes))
12 | 
13 | 
14 | def utf8(s: str) -> GraphVertex:
15 |     bytes_array = s.encode("utf-8")
16 |     nodes = [Node(bytes([b])) for b in bytes_array]
17 |     if len(nodes) == 1:
18 |         return nodes[0]
19 |     return NodesSequence(nodes=tuple(nodes))
20 | 
21 | 
22 | def utf8_clusters(s: str) -> GraphVertex:
23 |     # Split string into grapheme clusters using regex
24 |     # \X matches extended grapheme clusters
25 |     clusters = regex.findall(r'\X', s)
26 |     nodes = [utf8(cluster) for cluster in clusters]
27 | 
28 |     if len(nodes) == 1:
29 |         return nodes[0]
30 |     return NodesSequence(nodes=tuple(nodes))
31 | 
--------------------------------------------------------------------------------
/complex_tokenization/draw.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | 
3 | from graphviz import Source
4 | from PIL import Image
5 | 
6 | 
7 | def draw_dot_content(dot_content: str) -> Image.Image:
8 |     dot = """
9 |     digraph G {
10 |         graph [compound=true, rankdir=LR, fontsize=16, nodesep=0.6];
11 |         node [shape=circle, fontsize=16];
12 |         edge [fontsize=12, arrowhead=none]; // default: no arrowheads
13 |     """ + dot_content + "\n}"
14 |     src = Source(dot)
15 | 
16 |     png_bytes = src.pipe(format="png")
17 | 
18 |     return Image.open(BytesIO(png_bytes))
19 | 
20 | 
21 | def create_gif(frames: list[Image.Image], save=None) -> Image.Image:
22 |     target = save if save is not None else BytesIO()
23 |     frames[0].save(
24 |         target,
25 |         format="GIF",
26 |         save_all=True,
27 |         append_images=frames[1:],  # skip the first one (it's already saved)
28 |         duration=500,
29 |         loop=0,
30 |         disposal=2,  # <-- clear previous frame
31 |     )
32 |     if isinstance(target, BytesIO):
33 |         target.seek(0)
34 |     return Image.open(target)
35 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "complex-tokenization"
3 | description = "Tokenization for Complex Scripts"
4 | version = "0.0.1"
5 | authors = [
6 |     { name = "Amit Moryossef", email = "amitmoryossef@gmail.com" },
7 | ]
8 | readme = "README.md"
9 | dependencies = [
10 |     "regex",
11 | ]
12 | 
13 | [project.optional-dependencies]
14 | dev = [
15 |     "ruff",
16 |     "pytest",
17 |     "pytest-xdist",  # For parallel test execution
18 | ]
19 | 
20 | 
21 | [tool.setuptools.packages.find]
22 | where = ["."]
23 | include = ["complex_tokenization*"]
24 | 
25 | [tool.ruff]
26 | line-length = 120
27 | 
28 | [tool.ruff.lint]
29 | select = [
30 |     "E",  # pycodestyle errors
31 |     "W",  # pycodestyle warnings
32 |     "F",  # pyflakes
33 |     "C90",  # mccabe complexity
34 |     "I",  # isort
35 |     "N",  # pep8-naming
36 |     "UP",  # pyupgrade
37 |     "B",  # flake8-bugbear
38 |     "PT",  # flake8-pytest-style
39 |     "W605",  # invalid escape sequence
40 |     "BLE",  # flake8-blind-except
41 | ]
42 | 
43 | [tool.pytest.ini_options]
44 | addopts = "-v"
45 | testpaths = ["complex_tokenization", "tests"]
46 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 sign
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this
software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | on: 3 | release: 4 | types: [ created ] 5 | 6 | jobs: 7 | pypi-publish: 8 | name: Upload release to PyPI 9 | runs-on: ubuntu-latest 10 | environment: 11 | name: pypi 12 | url: https://pypi.org/p/complex-tokenization 13 | permissions: 14 | id-token: write 15 | steps: 16 | - uses: actions/checkout@v5 17 | 18 | - uses: actions/setup-python@v6 19 | with: 20 | python-version: "3.12" 21 | 22 | - name: Extract release version 23 | id: get_version 24 | run: echo "version=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV 25 | 26 | - name: Update version in pyproject.toml 27 | run: | 28 | sed -i 's/^version = .*/version = "${{ env.version }}"/' pyproject.toml 29 | 30 | - name: Install build dependencies 31 | run: pip install build 32 | 33 | - name: Build a binary wheel dist 34 | run: | 35 | rm -rf dist 36 | python -m build 37 | 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /complex_tokenization/graphs/words.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import regex as re 4 | 5 | from complex_tokenization.graph import GraphVertex, NodesSequence, UnconnectedGraphs 6 | from complex_tokenization.graphs.units import utf8_clusters 7 | 8 | # From openai/gpt-oss-20b 9 | pattern = ( 10 | "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?" 11 | "|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?" 
12 | "|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" 13 | ) 14 | 15 | 16 | def pretokenize(text: str) -> Iterable[str]: 17 | return [match.group(0) for match in re.finditer(pattern, text)] 18 | 19 | 20 | def words(text: str, connected=True, units=utf8_clusters) -> GraphVertex: 21 | tokens = pretokenize(text) 22 | nodes = [units(word) for word in tokens] 23 | if len(nodes) == 1: 24 | return nodes[0] 25 | if connected: 26 | return NodesSequence(nodes=tuple(nodes)) 27 | return UnconnectedGraphs(subgraphs=tuple(nodes)) 28 | -------------------------------------------------------------------------------- /complex_tokenization/examples/bpe.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tokenizers import Tokenizer 4 | 5 | from complex_tokenization.examples.bne import train_bne_tokenizer 6 | from complex_tokenization.examples.utils import text_dataset 7 | 8 | 9 | def get_tokenizer_merges(tokenizer: Tokenizer): 10 | backend = tokenizer.backend_tokenizer 11 | data = json.loads(backend.to_str()) 12 | return [tuple(m) for m in data["model"]["merges"]] 13 | 14 | 15 | def train_huggingface_tokenizer(texts: list[str], num_merges: int = 10): 16 | from transformers import AutoTokenizer 17 | 18 | tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b") 19 | 20 | new_tokenizer = tokenizer.train_new_from_iterator(texts, 256 + 21 + num_merges) 21 | return get_tokenizer_merges(new_tokenizer) 22 | 23 | 24 | def train_bpe_tokenizer(texts: list[str], num_merges: int = 10): 25 | # BPE can only merge 2 tokens at a time 26 | return train_bne_tokenizer(texts, n=2, num_merges=num_merges) 27 | 28 | 29 | if __name__ == "__main__": 30 | texts = list(text_dataset(max_samples=10)) 31 | print(train_bpe_tokenizer(texts)) 32 | print(train_huggingface_tokenizer(texts)) 33 | -------------------------------------------------------------------------------- /complex_tokenization/examples/bne.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from complex_tokenization.examples.utils import text_dataset 4 | from complex_tokenization.graphs.settings import GraphSettings 5 | from complex_tokenization.graphs.units import utf8_clusters 6 | from complex_tokenization.graphs.words import words 7 | 8 | 9 | def train_bne_tokenizer(texts: list[str], 10 | n=2, 11 | connected=False, 12 | units=utf8_clusters, 13 | num_merges: int = 10): 14 | from complex_tokenization.trainer import Trainer 15 | 16 | GraphSettings.ONLY_MINIMAL_MERGES = True # BNE only merges adjacent tokens 17 | GraphSettings.MAX_MERGE_SIZE = n # Maximum number of tokens to merge at a time 18 | GraphSettings.USE_SINGLETONS = False # for performance 19 | 20 | graphs = tuple([words(text, connected=connected, units=units) for text in texts]) 21 | 22 | trainer = Trainer(graphs=graphs) 23 | trainer.train(num_merges=num_merges) 24 | return trainer.get_merges() 25 | 26 | 27 | if __name__ == "__main__": 28 | texts = list(text_dataset(max_samples=10)) 29 | print(train_bne_tokenizer(texts, n=4)) 30 | -------------------------------------------------------------------------------- /tests/tokenizers/test_bpe.py: -------------------------------------------------------------------------------- 1 | from complex_tokenization.examples.bpe import train_bpe_tokenizer, train_huggingface_tokenizer 2 | from complex_tokenization.examples.utils import text_dataset 3 | 4 | 5 | class TestBPE: 6 | def test_basic_train_huggingface_tokenizer(self): 7 | """Test 
training HuggingFace tokenizer with expected merges""" 8 | texts = ["the teacher teaches the thick thing"] 9 | # Only 2 merges, to avoid needing a tie-breaker 10 | merges = train_huggingface_tokenizer(texts, num_merges=2) 11 | 12 | expected = [ 13 | ('Ġ', 't'), 14 | ('h', 'e'), 15 | ] 16 | 17 | assert merges == expected 18 | 19 | def test_basic_train_complex_tokenizer(self): 20 | """Test training complex tokenizer with expected merges""" 21 | texts = ["the teacher teaches the thick thing"] 22 | # Only 2 merges, to avoid needing a tie-breaker 23 | merges = train_bpe_tokenizer(texts, num_merges=2) 24 | 25 | expected = [ 26 | (' ', 't'), 27 | ('h', 'e'), 28 | ] 29 | 30 | assert merges == expected 31 | 32 | def test_large_train_huggingface_tokenizer(self): 33 | """Test training HuggingFace tokenizer with expected merges""" 34 | texts = list(text_dataset(max_samples=10)) 35 | merges = train_huggingface_tokenizer(texts, num_merges=10) 36 | 37 | expected = [ 38 | ("Ġ", "t"), 39 | ("Ġ", "a"), 40 | ("o", "n"), 41 | ("h", "e"), 42 | ("e", "s"), 43 | ("e", "r"), 44 | ("i", "n"), 45 | ("Ġt", "he"), 46 | ("e", "d"), 47 | ("a", "l"), 48 | ] 49 | 50 | assert merges == expected 51 | 52 | def test_large_train_complex_tokenizer(self): 53 | """Test training complex tokenizer with expected merges""" 54 | texts = list(text_dataset(max_samples=10)) 55 | merges = train_bpe_tokenizer(texts, num_merges=10) 56 | 57 | expected = [ 58 | (" ", "t"), 59 | (" ", "a"), 60 | ("o", "n"), 61 | ("h", "e"), 62 | ("e", "s"), 63 | ("e", "r"), 64 | ("i", "n"), 65 | (" t", "he"), 66 | ("e", "d"), 67 | ("a", "l"), 68 | ] 69 | 70 | assert merges == expected 71 | -------------------------------------------------------------------------------- /tests/test_pretokenizer.py: -------------------------------------------------------------------------------- 1 | from complex_tokenization.graphs.words import pretokenize 2 | 3 | 4 | class TestPretokenizer: 5 | def test_simple_english_text(self): 6 | """Test pretokenization of simple English text""" 7 | text = "hello world" 8 | result = pretokenize(text) 9 | expected = ["hello", " world"] 10 | assert result == expected 11 | 12 | def test_text_with_punctuation(self): 13 | """Test pretokenization with punctuation""" 14 | text = "hello world!" 15 | result = pretokenize(text) 16 | expected = ["hello", " world", "!"] 17 | assert result == expected 18 | 19 | def test_text_with_numbers(self): 20 | """Test pretokenization with numbers""" 21 | text = "I have 3 apples and 42 oranges" 22 | result = pretokenize(text) 23 | expected = ["I", " have", " ", "3", " apples", " and", " ", "42", " oranges"] 24 | assert result == expected 25 | 26 | def test_text_with_contractions(self): 27 | """Test pretokenization with contractions""" 28 | text = "I'm happy you're here" 29 | result = pretokenize(text) 30 | # Contractions should be preserved 31 | assert any("I'm" in token for token in result) 32 | assert any("you're" in token for token in result) 33 | 34 | def test_empty_string(self): 35 | """Test pretokenization of empty string""" 36 | text = "" 37 | result = pretokenize(text) 38 | assert result == [] 39 | 40 | def test_whitespace_only(self): 41 | """Test pretokenization of whitespace-only text""" 42 | text = " \n\n " 43 | result = pretokenize(text) 44 | # Should tokenize whitespace 45 | assert len(result) > 0 46 | assert all(not token.strip() for token in result) 47 | 48 | def test_mixed_content(self): 49 | """Test pretokenization with mixed content: letters, numbers, punctuation""" 50 | text = "Hello123 world! 
Test 456." 51 | result = pretokenize(text) 52 | expected = ["Hello", "123", " world", "!", " Test", " ", "456", "."] 53 | assert result == expected 54 | 55 | def test_sentence_with_multiple_punctuation(self): 56 | """Test pretokenization of a sentence with various punctuation marks""" 57 | text = "Hello, world! How are you? I'm fine." 58 | result = pretokenize(text) 59 | expected = ["Hello", ",", " world", "!", " How", " are", " you", "?", " I'm", " fine", "."] 60 | assert result == expected 61 | 62 | def test_capitalized_words(self): 63 | """Test pretokenization with capitalized words""" 64 | text = "Hello World TEST" 65 | result = pretokenize(text) 66 | expected = ["Hello", " World", " TEST"] 67 | assert result == expected 68 | 69 | -------------------------------------------------------------------------------- /complex_tokenization/trainer.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from functools import reduce 3 | 4 | from complex_tokenization.draw import create_gif, draw_dot_content 5 | from complex_tokenization.graph import GraphVertex, Node, Tree, UnconnectedGraphs 6 | from complex_tokenization.graphs.settings import GraphSettings 7 | from complex_tokenization.graphs.units import utf8 8 | 9 | 10 | class Trainer: 11 | def __init__(self, graph: GraphVertex | None = None, graphs: tuple[GraphVertex, ...] | None = None): 12 | if graphs is None and graph is None: 13 | raise ValueError("Must provide either graph or graphs") 14 | if graphs is not None and graph is not None: 15 | raise ValueError("Must provide either graph or graphs, not both") 16 | 17 | if graphs is not None: 18 | graph = UnconnectedGraphs(graphs) 19 | 20 | self.graph = graph 21 | self.merges = [] 22 | 23 | def train(self, num_merges: int = 100, draw=False, verbose=False): 24 | frames = [] 25 | 26 | while True: 27 | if len(self.merges) >= num_merges: 28 | break 29 | 30 | if draw: 31 | dot_content = "\n".join(self.graph.dot()) 32 | image = draw_dot_content(dot_content) 33 | frames.append(image) 34 | 35 | all_merges = self.graph.get_merges() 36 | if GraphSettings.ONLY_MINIMAL_MERGES: 37 | all_merges = (m for m in all_merges if all(isinstance(n, Node) for n in m)) 38 | merges = Counter(all_merges) 39 | merges_compression = Counter({k: (len(k) - 1) * v for k, v in merges.items()}) 40 | 41 | if verbose: 42 | print(merges_compression.most_common(5)) 43 | 44 | if len(merges) == 0: 45 | break 46 | nodes = merges_compression.most_common(1)[0][0] 47 | token = reduce(lambda x, y: x + y, nodes) 48 | 49 | if verbose: 50 | print("Merging", token, "=", nodes) 51 | 52 | self.graph = self.graph.merge(token, nodes) 53 | self.merges.append((token, nodes)) 54 | 55 | if draw: 56 | gif = create_gif(frames, save="example.gif") 57 | gif.show() 58 | 59 | def get_merges(self): 60 | return [tuple(str(node) for node in nodes) for _, nodes in self.merges] 61 | 62 | 63 | if __name__ == "__main__": 64 | example_graph = Tree(root=utf8("⿱"), children=( 65 | utf8("十"), 66 | Tree(root=utf8("⿱"), children=( 67 | utf8("乛"), 68 | utf8("头"), 69 | )), 70 | )) 71 | # example_sentence = "the teacher teaches the thick." 
72 |     example_sentence = "test test"
73 |     # example_graph = sentence_to_graph(example_sentence)
74 | 
75 |     # other_graph = words(example_sentence)
76 |     # example_graph = NodesSequence((example_graph, utf8(" "), other_graph))
77 | 
78 |     trainer = Trainer(graph=example_graph)
79 |     trainer.train(num_merges=10, draw=True)
80 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tokenization for Complex Scripts
2 | 
3 | This repository proposes a generic merge-based tokenization scheme that covers both concatenative and
4 | non-concatenative language structures.
5 | It therefore allows better-fitting tokenization for complex scripts (such as SignWriting and Chinese)
6 | by decomposing words into smaller units and representing them in various graph structures.
7 | 
8 | ## Usage
9 | 
10 | Install:
11 | 
12 | ```bash
13 | git clone https://github.com/sign-language-processing/complex-tokenization.git
14 | cd complex-tokenization
15 | pip install ".[dev]"
16 | ```
17 | 
18 | Pretokenize text using a HuggingFace Tokenizer implementation:
19 | 
20 | ```python
21 | from complex_tokenization.tokenizer import WordsSegmentationTokenizer
22 | 
23 | pretokenizer = WordsSegmentationTokenizer(max_bytes=16)
24 | tokens = pretokenizer.tokenize("hello world! 我爱北京天安门 👩‍👩‍👧‍👦")
25 | # ['hello ', 'world! ', '我', '爱', '北京', '天安门', ' ', '👩‍👩‍👧‍👦‍']
26 | ```
27 | 
28 | ## Pretokenization
29 | 
30 | Our tokenizers run on a graph structure, which we can manipulate via pre-tokenization functions.
31 | 
32 | ## Units
33 | 
34 | Units are the basic blocks we operate on, such as characters or bytes.
35 | We implement three basic blocks:
36 | 
37 | ```python
38 | from complex_tokenization.graphs.units import characters, utf8, utf8_clusters
39 | 
40 | text = "שלום"
41 | 
42 | # Characters Split assigns a single node per character (4 characters)
43 | assert characters(text) == NodesSequence((Node("ש"), Node("ל"), Node("ו"), Node("ם")))
44 | 
45 | # UTF-8 Split assigns a single node per byte (8 bytes)
46 | assert utf8(text) == NodesSequence((Node(value=b'\xd7'), Node(value=b'\xa9'),
47 |                                     Node(value=b'\xd7'), Node(value=b'\x9c'),
48 |                                     Node(value=b'\xd7'), Node(value=b'\x95'),
49 |                                     Node(value=b'\xd7'), Node(value=b'\x9d')))
50 | 
51 | # UTF-8 Clusters Split assigns a single node sequence per cluster, a single node per byte (4 clusters, 2 bytes each)
52 | assert utf8_clusters(text) == NodesSequence((
53 |     NodesSequence((Node(value=b'\xd7'), Node(value=b'\xa9'))),
54 |     NodesSequence((Node(value=b'\xd7'), Node(value=b'\x9c'))),
55 |     NodesSequence((Node(value=b'\xd7'), Node(value=b'\x95'))),
56 |     NodesSequence((Node(value=b'\xd7'), Node(value=b'\x9d')))))
57 | ```
58 | 
59 | ## Words
60 | 
61 | A long text that includes multiple words can be treated as a single text (without boundaries),
62 | or each word can be considered a single cluster.
63 | 
64 | Words can be "connected" to each other, to allow merging across words,
65 | or "disconnected" to disallow merging across word boundaries.
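For instance, a minimal sketch of the disconnected case (using the `words` helper and the `UnconnectedGraphs` container defined in this repository); the whole-text and connected cases are shown in the block below:

```python
from complex_tokenization.graph import UnconnectedGraphs
from complex_tokenization.graphs.words import words

# Disconnected words: each pretokenized word becomes its own subgraph,
# so no merge can ever span a word boundary.
graph = words("a few words", connected=False)
assert isinstance(graph, UnconnectedGraphs)
```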
66 | 67 | ```python 68 | from complex_tokenization.graphs.units import utf8_clusters 69 | from complex_tokenization.graphs.words import words 70 | 71 | text = "a few words" 72 | 73 | # Train tokenization on the entire text 74 | graph = utf8_clusters(text) 75 | 76 | # Treat each word as a cluster, and words are connected 77 | 78 | graph = words(text, units=utf8_clusters, connected=True) 79 | ``` 80 | 81 | ## Tokenizers Implementation 82 | 83 | ### BNE (Byte-Ngram Encoding) 84 | 85 | Byte-Ngram Encoding creates a merge over a sequence of units up to a certain size `N`. 86 | It treats words as disconnected units, and does not allow merges over unmerged clusters. 87 | 88 | ```python 89 | from complex_tokenization.graphs.settings import GraphSettings 90 | from complex_tokenization.graphs.units import utf8_clusters 91 | from complex_tokenization.graphs.words import words 92 | 93 | GraphSettings.ONLY_MINIMAL_MERGES = True # BNE only merges adjacent tokens 94 | GraphSettings.MAX_MERGE_SIZE = N # Maximum number of tokens to merge at a time 95 | 96 | text = "a large text corpus..." 97 | 98 | graph = words(text, units=utf8_clusters, connected=False) 99 | ``` 100 | 101 | ### BPE (Byte-Pair Encoding) 102 | 103 | Same as `BNE`, with a maximum of two tokens merged at a time `GraphSettings.MAX_MERGE_SIZE = 2`. 104 | 105 | ### BoundlessBPE 106 | 107 | ## Cite 108 | 109 | If you use this code in your research, please consider citing the work: 110 | 111 | ```bibtex 112 | @misc{moryossef2025complex, 113 | title={Tokenization for Complex Scripts}, 114 | author={Moryossef, Amit}, 115 | howpublished={\url{https://github.com/sign-language-processing/complex-tokenization}}, 116 | year={2025} 117 | } 118 | ``` -------------------------------------------------------------------------------- /complex_tokenization/languages/chinese/create_dictionary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | from pathlib import Path 4 | 5 | 6 | def clone_repo_if_needed(repo_url, repo_name): 7 | """Clone repository if it doesn't exist.""" 8 | repo_dir = Path(__file__).parent / repo_name 9 | if not repo_dir.exists(): 10 | print(f"Cloning {repo_name} repository to {repo_dir}...") 11 | subprocess.run( 12 | ["git", "clone", repo_url, str(repo_dir)], 13 | check=True 14 | ) 15 | return repo_dir 16 | 17 | 18 | def extract_ids(files): 19 | """Extract ids.txt and ids-ext-cdef.txt from cjkvi-ids.""" 20 | dictionary = {} 21 | 22 | for file_path in files: 23 | if not file_path.exists(): 24 | print(f"Warning: {file_path} not found, skipping...") 25 | continue 26 | 27 | print(f"Processing {file_path.name}...") 28 | with open(file_path, encoding='utf-8') as f: 29 | for line in f: 30 | line = line.strip() 31 | if not line or line.startswith('#'): 32 | continue 33 | parts = line.split('\t') 34 | if len(parts) >= 2: 35 | key = parts[1].strip() 36 | value = parts[2] 37 | if "[" in value: 38 | value = value[:value.index("[")] 39 | value = value.strip() 40 | if key != value: 41 | dictionary[key] = value 42 | 43 | return dictionary 44 | 45 | 46 | def load_canonicalization_rules(canonicalize_path): 47 | """Load canonicalization rules from canonicalize.txt.""" 48 | rules = {} 49 | 50 | if not canonicalize_path.exists(): 51 | print(f"Warning: {canonicalize_path} not found, skipping canonicalization...") 52 | return rules 53 | 54 | with open(canonicalize_path, encoding='utf-8') as f: 55 | for line in f: 56 | line = line.strip() 57 | if not line or line.startswith('#'): 58 | 
continue 59 | 60 | parts = line.split('\t') 61 | if len(parts) < 3: 62 | continue 63 | 64 | rule_type = parts[0] 65 | source = parts[1] 66 | target = parts[2].replace("~", "") # Remove any '~' characters 67 | 68 | # Only apply certain rule types for normalization 69 | if rule_type in ['identical', 'variant', 'print', 'preferred']: 70 | rules[source] = target 71 | 72 | return rules 73 | 74 | 75 | def canonicalize_dictionary(dictionary, canonicalization_rules): 76 | """Canonicalize dictionary values using canonicalization rules.""" 77 | if not canonicalization_rules: 78 | return dictionary 79 | 80 | print("Canonicalizing dictionary values...") 81 | canonicalized = {} 82 | 83 | for key, value in dictionary.items(): 84 | # Apply canonicalization rules to each character in the value 85 | canonical_value = "" 86 | for char in value: 87 | canonical_value += canonicalization_rules.get(char, char) 88 | if key != canonical_value: 89 | canonicalized[key] = canonical_value 90 | 91 | return canonicalized 92 | 93 | 94 | def expand_ids(value, dictionary): 95 | expanded_value = "" 96 | for char in value: 97 | if char in dictionary: 98 | dictionary[char] = expand_ids(dictionary[char], dictionary) 99 | expanded_value += dictionary[char] 100 | else: 101 | expanded_value += char 102 | return expanded_value 103 | 104 | 105 | def expand_dictionary(dictionary): 106 | """Expand dictionary values by replacing characters that exist as keys.""" 107 | print("Expanding dictionary values...") 108 | expanded = {} 109 | 110 | for key, value in dictionary.items(): 111 | if "{" in value: 112 | continue 113 | 114 | expanded[key] = expand_ids(value, dictionary) 115 | 116 | return expanded 117 | 118 | 119 | def save_dictionary(dictionary): 120 | """Save the dictionary to dictionary.json.""" 121 | output_path = Path(__file__).parent / "dictionary.json" 122 | print(f"Saving dictionary to {output_path}...") 123 | with open(output_path, 'w', encoding='utf-8') as f: 124 | json.dump(dictionary, f, ensure_ascii=False, indent=2) 125 | print(f"Saved {len(dictionary)} entries to {output_path}") 126 | 127 | 128 | def main(): 129 | # Clone both repositories 130 | hfhchan_repo = clone_repo_if_needed("https://github.com/hfhchan/ids.git", "hfhchan-ids") 131 | cjkvi_repo = clone_repo_if_needed("https://github.com/cjkvi/cjkvi-ids.git", "cjkvi-ids") 132 | 133 | # Load canonicalization rules 134 | canonicalization_rules = load_canonicalization_rules(hfhchan_repo / "canonicalize.txt") 135 | 136 | # Extract from hfhchan first, then cjkvi (cjkvi overwrites) 137 | dictionary = extract_ids([ 138 | hfhchan_repo / "release" / "ids-20240112.txt", 139 | cjkvi_repo / "ids.txt", 140 | cjkvi_repo / "ids-ext-cdef.txt" 141 | ]) 142 | 143 | # Canonicalize, then expand 144 | dictionary = canonicalize_dictionary(dictionary, canonicalization_rules) 145 | dictionary = expand_dictionary(dictionary) 146 | save_dictionary(dictionary) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /tests/test_graph.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from complex_tokenization.graph import GraphVertex, Node, NodesSequence 4 | from complex_tokenization.graphs.settings import GraphSettings 5 | from complex_tokenization.graphs.units import characters, utf8, utf8_clusters 6 | from complex_tokenization.graphs.words import words 7 | 8 | 9 | def readable_merges(graph: GraphVertex): 10 | counter = 
Counter(graph.get_merges()) 11 | byte_merges = {} 12 | for nodes, v in counter.items(): 13 | k = b''.join(bytes(node) for node in nodes) 14 | byte_merges[k] = v 15 | return byte_merges 16 | 17 | 18 | class TestUnitsWord: 19 | def test_characters_split(self): 20 | assert characters("שלום") == NodesSequence((Node("ש"), Node("ל"), Node("ו"), Node("ם"))) 21 | 22 | def test_utf8_split(self): 23 | assert utf8("שלום") == NodesSequence((Node(value=b'\xd7'), Node(value=b'\xa9'), 24 | Node(value=b'\xd7'), Node(value=b'\x9c'), 25 | Node(value=b'\xd7'), Node(value=b'\x95'), 26 | Node(value=b'\xd7'), Node(value=b'\x9d'))) 27 | 28 | def test_utf8_clusters_split(self): 29 | assert utf8_clusters("שלום") == NodesSequence(( 30 | NodesSequence((Node(value=b'\xd7'), Node(value=b'\xa9'))), 31 | NodesSequence((Node(value=b'\xd7'), Node(value=b'\x9c'))), 32 | NodesSequence((Node(value=b'\xd7'), Node(value=b'\x95'))), 33 | NodesSequence((Node(value=b'\xd7'), Node(value=b'\x9d'))))) 34 | 35 | def test_utf8_ascii_same_as_cluster(self): 36 | assert utf8_clusters('word') == utf8('word') 37 | 38 | def test_utf8_cluster_is_split(self): 39 | graph = utf8_clusters('שלום') 40 | assert isinstance(graph, NodesSequence) 41 | assert len(graph.nodes) == 4 42 | for node in graph.nodes: 43 | assert isinstance(node, NodesSequence) 44 | assert len(node.nodes) == 2 45 | assert isinstance(node.nodes[0], Node) 46 | 47 | def test_utf8_ascii_2_merges(self): 48 | GraphSettings.MAX_MERGE_SIZE = 2 49 | 50 | graph = utf8('lalaland') 51 | merges = readable_merges(graph) 52 | assert len(merges) == 4 53 | 54 | assert merges[b'la'] == 3 55 | assert merges[b'al'] == 2 56 | assert merges[b'an'] == 1 57 | assert merges[b'nd'] == 1 58 | 59 | def test_utf8_ascii_3_merges(self): 60 | GraphSettings.MAX_MERGE_SIZE = 3 61 | 62 | graph = utf8('lalaland') 63 | merges = readable_merges(graph) 64 | assert len(merges) == 8 65 | 66 | assert merges[b'la'] == 3 67 | assert merges[b'lal'] == 2 68 | assert merges[b'al'] == 2 69 | assert merges[b'ala'] == 2 70 | assert merges[b'lan'] == 1 71 | assert merges[b'an'] == 1 72 | assert merges[b'and'] == 1 73 | assert merges[b'nd'] == 1 74 | 75 | def test_utf8_cluster_minimal_merges(self): 76 | GraphSettings.MAX_MERGE_SIZE = 100 77 | GraphSettings.ONLY_MINIMAL_MERGES = True 78 | graph = utf8_clusters('שלום') 79 | merges = readable_merges(graph) 80 | 81 | print(merges) 82 | # Only character sequences should be valid 83 | assert merges['ש'.encode()] == 1 84 | assert merges['ל'.encode()] == 1 85 | assert merges['ו'.encode()] == 1 86 | assert merges['ם'.encode()] == 1 87 | 88 | def test_utf8_cluster_non_minimal_merges(self): 89 | GraphSettings.MAX_MERGE_SIZE = 100 90 | GraphSettings.ONLY_MINIMAL_MERGES = False 91 | graph = utf8_clusters('שלום') 92 | merges = readable_merges(graph) 93 | 94 | # Basically, every subsequence is valid 95 | assert merges['ש'.encode()] == 1 96 | assert merges['ל'.encode()] == 1 97 | assert merges['ו'.encode()] == 1 98 | assert merges['ם'.encode()] == 1 99 | assert merges['של'.encode()] == 1 100 | assert merges['שלו'.encode()] == 1 101 | assert merges['שלום'.encode()] == 1 102 | assert merges['לו'.encode()] == 1 103 | assert merges['לום'.encode()] == 1 104 | assert merges['ום'.encode()] == 1 105 | 106 | 107 | class TestWords: 108 | def test_single_word_same_as_utf8_clusters(self): 109 | # Single word should be identical to utf8_clusters 110 | assert words('word') == utf8_clusters('word') 111 | assert words('שלום') == utf8_clusters('שלום') 112 | 113 | def test_multiple_words_count(self): 114 | # 
Test that multiple words are properly split and counted
115 |         graph = words('hello world test')
116 |         assert isinstance(graph, NodesSequence)
117 |         assert len(graph.nodes) == 3
118 | 
119 |     def test_two_words_minimal_merges(self):
120 |         GraphSettings.MAX_MERGE_SIZE = 10
121 |         GraphSettings.ONLY_MINIMAL_MERGES = True
122 | 
123 |         graph = words('hi bye')
124 |         merges = readable_merges(graph)
125 |         assert len(merges) == 7
126 | 
127 |         assert merges[b'hi'] == 1
128 |         assert merges[b' b'] == 1
129 |         assert merges[b' by'] == 1
130 |         assert merges[b' bye'] == 1
131 |         assert merges[b'by'] == 1
132 |         assert merges[b'bye'] == 1
133 |         assert merges[b'ye'] == 1
134 | 
135 |     def test_two_words_non_minimal_merge(self):
136 |         GraphSettings.MAX_MERGE_SIZE = 10
137 |         GraphSettings.ONLY_MINIMAL_MERGES = False
138 | 
139 |         graph = words('hi bye')
140 |         merges = readable_merges(graph)
141 |         assert len(merges) == 8
142 | 
143 |         assert merges[b'hi bye'] == 1
144 | 
--------------------------------------------------------------------------------
/complex_tokenization/languages/chinese/ideographic_description_sequences.py:
--------------------------------------------------------------------------------
1 | """
2 | Parser for Ideographic Description Sequences (IDS).
3 | 
4 | IDS uses Ideographic Description Characters (IDC) to describe the structure of Chinese characters.
5 | IDCs range from U+2FF0 (⿰) to U+2FFB (⿻).
6 | 
7 | Binary IDCs (take 2 components):
8 | ⿰ (U+2FF0) - left to right
9 | ⿱ (U+2FF1) - above to below
10 | ⿴ (U+2FF4) - full surround
11 | ⿵ (U+2FF5) - surround from above
12 | ⿶ (U+2FF6) - surround from below
13 | ⿷ (U+2FF7) - surround from left
14 | ⿸ (U+2FF8) - surround from upper left
15 | ⿹ (U+2FF9) - surround from upper right
16 | ⿺ (U+2FFA) - surround from lower left
17 | ⿻ (U+2FFB) - overlaid
18 | 
19 | Ternary IDCs (take 3 components):
20 | ⿲ (U+2FF2) - left to middle to right
21 | ⿳ (U+2FF3) - above to middle to below
22 | 
23 | """
24 | 
25 | import json
26 | from dataclasses import dataclass
27 | from functools import cache
28 | from pathlib import Path
29 | 
30 | # IDCs that take 2 components
31 | BINARY_IDCS = set('⿰⿱⿴⿵⿶⿷⿸⿹⿺⿻')
32 | 
33 | # IDCs that take 3 components
34 | TERNARY_IDCS = set('⿲⿳')
35 | 
36 | # All IDCs
37 | ALL_IDCS = BINARY_IDCS | TERNARY_IDCS
38 | 
39 | 
40 | @dataclass
41 | class IDSNode:
42 |     """Represents a node in the IDS tree."""
43 |     value: str
44 |     children: list['IDSNode'] = None
45 | 
46 |     def __post_init__(self):
47 |         if self.children is None:
48 |             self.children = []
49 | 
50 |     def is_leaf(self) -> bool:
51 |         """Check if this node is a leaf (radical)."""
52 |         return len(self.children) == 0
53 | 
54 |     def is_template(self) -> bool:
55 |         """Check if this node is a template (IDC)."""
56 |         return self.value in ALL_IDCS
57 | 
58 |     def to_dict(self):
59 |         """Convert to dictionary representation."""
60 |         if self.is_leaf():
61 |             return {"type": "radical", "value": self.value}
62 |         else:
63 |             return {
64 |                 "type": "template",
65 |                 "value": self.value,
66 |                 "children": [child.to_dict() for child in self.children]
67 |             }
68 | 
69 | 
70 | def parse_ideographic_description_sequences(ids: str) -> IDSNode:
71 |     """
72 |     Parse an Ideographic Description Sequence into a tree structure.
73 | 74 | Args: 75 | ids: The IDS string to parse (e.g., "⿱⿳𠂊田一⿰⿳𠂊田一⿳𠂊田一") 76 | 77 | Returns: 78 | IDSNode: Root node of the parsed tree 79 | 80 | Example: 81 | >>> tree = parse_ideographic_description_sequences("⿱⿳𠂊田一⿰⿳𠂊田一⿳𠂊田一") 82 | >>> tree.value 83 | '⿱' 84 | >>> len(tree.children) 85 | 2 86 | """ 87 | if not ids: 88 | raise ValueError("Empty IDS string") 89 | 90 | index = [0] # Use list to maintain reference in nested function 91 | 92 | def parse_node() -> IDSNode: 93 | if index[0] >= len(ids): 94 | raise ValueError(f"Unexpected end of IDS string at position {index[0]}") 95 | 96 | char = ids[index[0]] 97 | index[0] += 1 98 | 99 | if char in TERNARY_IDCS: 100 | # Parse 3 children 101 | node = IDSNode(value=char) 102 | node.children = [parse_node() for _ in range(3)] 103 | return node 104 | elif char in BINARY_IDCS: 105 | # Parse 2 children 106 | node = IDSNode(value=char) 107 | node.children = [parse_node() for _ in range(2)] 108 | return node 109 | else: 110 | # Leaf node (radical) 111 | return IDSNode(value=char) 112 | 113 | root = parse_node() 114 | 115 | # Verify we consumed the entire string 116 | if index[0] < len(ids): 117 | raise ValueError(f"Extra characters after parsing: {ids[index[0]:]}") 118 | 119 | return root 120 | 121 | 122 | def ids_tree_to_string(node: IDSNode, prefix: str = "", is_last: bool = True) -> str: 123 | """ 124 | Convert an IDS tree to a readable ASCII tree representation. 125 | 126 | Args: 127 | node: The root node of the tree 128 | prefix: Current line prefix for drawing branches 129 | is_last: Whether this is the last child of its parent 130 | 131 | Returns: 132 | String representation of the tree 133 | """ 134 | # Current node connector 135 | connector = "└── " if is_last else "├── " 136 | 137 | # Node label 138 | if node.is_leaf(): 139 | label = f"Radical: {node.value}" 140 | else: 141 | label = f"Template: {node.value}" 142 | 143 | result = prefix + connector + label + "\n" 144 | 145 | # Process children 146 | if not node.is_leaf(): 147 | # Extension for children's prefix 148 | extension = " " if is_last else "│ " 149 | new_prefix = prefix + extension 150 | 151 | for i, child in enumerate(node.children): 152 | is_last_child = (i == len(node.children) - 1) 153 | result += ids_tree_to_string(child, new_prefix, is_last_child) 154 | 155 | return result 156 | 157 | @cache 158 | def load_characters_dictionary(): 159 | """Load the IDS dictionary from the JSON file.""" 160 | dictionary_path = Path(__file__).parent / "dictionary.json" 161 | with open(dictionary_path, encoding='utf-8') as f: 162 | dictionary = json.load(f) 163 | return dictionary 164 | 165 | @cache 166 | def reversed_characters_dictionary(): 167 | """Load the reversed IDS dictionary from the JSON file.""" 168 | dictionary = load_characters_dictionary() 169 | reversed_dict = {v: k for k, v in dictionary.items()} 170 | return reversed_dict 171 | 172 | 173 | def get_ids_for_character(char: str) -> str: 174 | """Get the IDS for a given character from the dictionary.""" 175 | dictionary = load_characters_dictionary() 176 | return dictionary.get(char) 177 | 178 | def get_character_for_ids(ids: str) -> str | None: 179 | """Get the character for a given IDS from the reversed dictionary.""" 180 | reversed_dict = reversed_characters_dictionary() 181 | return reversed_dict.get(ids, None) 182 | -------------------------------------------------------------------------------- /tests/languages/test_chinese.py: -------------------------------------------------------------------------------- 1 | import json 2 | 
from pathlib import Path 3 | 4 | import pytest 5 | 6 | from complex_tokenization.languages.chinese.ideographic_description_sequences import ( 7 | ids_tree_to_string, 8 | parse_ideographic_description_sequences, 9 | ) 10 | 11 | 12 | class TestDictionary: 13 | def test_all_values_unique(self): 14 | """Test that all values in dictionary.json are unique""" 15 | dictionary_path = Path(__file__).parent.parent / "complex_tokenization" / "chinese" / "dictionary.json" 16 | 17 | if not dictionary_path.exists(): 18 | pytest.skip("dictionary.json not found") 19 | 20 | with open(dictionary_path) as f: 21 | dictionary = json.load(f) 22 | 23 | values = list(dictionary.values()) 24 | unique_values = set(values) 25 | 26 | if len(values) != len(unique_values): 27 | # Find duplicates 28 | from collections import Counter 29 | value_counts = Counter(values) 30 | duplicates = {v: count for v, count in value_counts.items() if count > 1} 31 | # Sort by count descending, then by value for stable ordering 32 | top_duplicates = sorted(duplicates.items(), key=lambda x: (-x[1], x[0]))[:20] 33 | top_duplicates_str = [(f"{v}: {count} occurrences " 34 | f"({'/'.join([k for k, val in dictionary.items() if val == v])})") 35 | for v, count in top_duplicates] 36 | pytest.fail(f"Found {len(duplicates)} ({len(values) - len(unique_values)}) duplicate values:\n" 37 | + "\n".join(top_duplicates_str)) 38 | 39 | 40 | class TestIdeographicDescriptionSequences: 41 | def test_parse_complex_nested(self): 42 | """Test parsing: ⿱⿳𠂊田一⿰⿳𠂊田一⿳𠂊田一""" 43 | ids = "⿱⿳𠂊田一⿰⿳𠂊田一⿳𠂊田一" 44 | tree = parse_ideographic_description_sequences(ids) 45 | 46 | # Check root 47 | assert tree.value == "⿱" 48 | assert len(tree.children) == 2 49 | 50 | # Check first child (⿳𠂊田一) 51 | first_child = tree.children[0] 52 | assert first_child.value == "⿳" 53 | assert len(first_child.children) == 3 54 | assert first_child.children[0].value == "𠂊" 55 | assert first_child.children[1].value == "田" 56 | assert first_child.children[2].value == "一" 57 | 58 | # Check second child (⿰⿳𠂊田一⿳𠂊田一) 59 | second_child = tree.children[1] 60 | assert second_child.value == "⿰" 61 | assert len(second_child.children) == 2 62 | 63 | expected_tree = """└── Template: ⿱ 64 | ├── Template: ⿳ 65 | │ ├── Radical: 𠂊 66 | │ ├── Radical: 田 67 | │ └── Radical: 一 68 | └── Template: ⿰ 69 | ├── Template: ⿳ 70 | │ ├── Radical: 𠂊 71 | │ ├── Radical: 田 72 | │ └── Radical: 一 73 | └── Template: ⿳ 74 | ├── Radical: 𠂊 75 | ├── Radical: 田 76 | └── Radical: 一 77 | """ 78 | assert ids_tree_to_string(tree) == expected_tree 79 | 80 | def test_parse_with_surround(self): 81 | """Test parsing: ⿰⿳爫龴⿵冂⿱厶又乚""" 82 | ids = "⿰⿳爫龴⿵冂⿱厶又乚" 83 | tree = parse_ideographic_description_sequences(ids) 84 | 85 | # Check root 86 | assert tree.value == "⿰" 87 | assert len(tree.children) == 2 88 | 89 | expected_tree = """└── Template: ⿰ 90 | ├── Template: ⿳ 91 | │ ├── Radical: 爫 92 | │ ├── Radical: 龴 93 | │ └── Template: ⿵ 94 | │ ├── Radical: 冂 95 | │ └── Template: ⿱ 96 | │ ├── Radical: 厶 97 | │ └── Radical: 又 98 | └── Radical: 乚 99 | """ 100 | assert ids_tree_to_string(tree) == expected_tree 101 | 102 | def test_parse_with_overlay(self): 103 | """Test parsing: ⿱⿻⿻コ一丨一""" 104 | ids = "⿱⿻⿻コ一丨一" 105 | tree = parse_ideographic_description_sequences(ids) 106 | 107 | # Check root 108 | assert tree.value == "⿱" 109 | assert len(tree.children) == 2 110 | 111 | expected_tree = """└── Template: ⿱ 112 | ├── Template: ⿻ 113 | │ ├── Template: ⿻ 114 | │ │ ├── Radical: コ 115 | │ │ └── Radical: 一 116 | │ └── Radical: 丨 117 | └── Radical: 一 118 | """ 
119 | assert ids_tree_to_string(tree) == expected_tree 120 | 121 | def test_simple_binary(self): 122 | """Test simple binary IDS: ⿰木木""" 123 | ids = "⿰木木" 124 | tree = parse_ideographic_description_sequences(ids) 125 | 126 | assert tree.value == "⿰" 127 | assert len(tree.children) == 2 128 | assert tree.children[0].value == "木" 129 | assert tree.children[1].value == "木" 130 | 131 | expected_tree = """└── Template: ⿰ 132 | ├── Radical: 木 133 | └── Radical: 木 134 | """ 135 | assert ids_tree_to_string(tree) == expected_tree 136 | 137 | def test_empty_string_raises_error(self): 138 | """Test that empty string raises ValueError""" 139 | with pytest.raises(ValueError, match="Empty IDS string"): 140 | parse_ideographic_description_sequences("") 141 | 142 | def test_to_dict(self): 143 | """Test tree to dictionary conversion""" 144 | ids = "⿰木木" 145 | tree = parse_ideographic_description_sequences(ids) 146 | result = tree.to_dict() 147 | 148 | assert result["type"] == "template" 149 | assert result["value"] == "⿰" 150 | assert len(result["children"]) == 2 151 | assert result["children"][0]["type"] == "radical" 152 | assert result["children"][0]["value"] == "木" 153 | 154 | def test_all_dictionary_items_parseable(self): 155 | """Test that all items in dictionary.json are parseable""" 156 | dictionary_path = Path(__file__).parent.parent / "complex_tokenization" / "chinese" / "dictionary.json" 157 | 158 | if not dictionary_path.exists(): 159 | pytest.skip("dictionary.json not found") 160 | 161 | with open(dictionary_path) as f: 162 | dictionary = json.load(f) 163 | 164 | templates = list(dictionary.values()) 165 | 166 | for template in templates: 167 | parse_ideographic_description_sequences(template) 168 | -------------------------------------------------------------------------------- /complex_tokenization/languages/chinese/frequency.py: -------------------------------------------------------------------------------- 1 | """ 2 | Analyze frequency of near-leaf node patterns in Chinese characters from Wikipedia data. 3 | 4 | This script: 5 | 1. Downloads Chinese Wikipedia dataset from HuggingFace 6 | 2. Extracts and counts all Chinese characters 7 | 3. Decomposes characters using Ideographic Description Sequences (IDS) 8 | 4. Identifies near-leaf nodes (nodes whose children are all leaves) 9 | 5. 
Counts frequency of near-leaf patterns weighted by character frequency 10 | """ 11 | 12 | from collections import Counter 13 | 14 | from datasets import load_dataset 15 | from tqdm import tqdm 16 | 17 | from complex_tokenization.languages.chinese.ideographic_description_sequences import ( 18 | IDSNode, 19 | get_character_for_ids, 20 | get_ids_for_character, 21 | parse_ideographic_description_sequences, 22 | ) 23 | 24 | 25 | def is_chinese_character(char: str) -> bool: 26 | """Check if a character is a Chinese character (CJK Unified Ideographs).""" 27 | code = ord(char) 28 | return ( 29 | 0x4E00 <= code <= 0x9FFF or # CJK Unified Ideographs 30 | 0x3400 <= code <= 0x4DBF or # CJK Unified Ideographs Extension A 31 | 0x20000 <= code <= 0x2A6DF or # CJK Unified Ideographs Extension B 32 | 0x2A700 <= code <= 0x2B73F or # CJK Unified Ideographs Extension C 33 | 0x2B740 <= code <= 0x2B81F or # CJK Unified Ideographs Extension D 34 | 0x2B820 <= code <= 0x2CEAF or # CJK Unified Ideographs Extension E 35 | 0x2CEB0 <= code <= 0x2EBEF or # CJK Unified Ideographs Extension F 36 | 0x30000 <= code <= 0x3134F # CJK Unified Ideographs Extension G 37 | ) 38 | 39 | 40 | def extract_chinese_characters(text: str) -> list[str]: 41 | """Extract all Chinese characters from text.""" 42 | return [char for char in text if is_chinese_character(char)] 43 | 44 | 45 | def linearize_preorder(node: IDSNode) -> tuple[str, ...]: 46 | """ 47 | Linearize a subtree in preorder (node, then children). 48 | Returns a tuple of values. 49 | """ 50 | if node.is_leaf(): 51 | return (node.value,) 52 | 53 | result = [node.value] 54 | for child in node.children: 55 | result.extend(linearize_preorder(child)) 56 | return tuple(result) 57 | 58 | 59 | def find_all_subtree_patterns(node: IDSNode) -> list[tuple[str, ...]]: 60 | """ 61 | Find all non-leaf subtrees in the tree and linearize them in preorder. 62 | Returns a list of tuples representing each subtree's preorder traversal. 
63 | """ 64 | patterns = [] 65 | 66 | def traverse(node: IDSNode): 67 | if node.is_leaf(): 68 | return 69 | 70 | # Linearize this subtree 71 | if all(child.is_leaf() for child in node.children): 72 | pattern = linearize_preorder(node) 73 | patterns.append(pattern) 74 | 75 | # Continue traversing to find all subtrees 76 | for child in node.children: 77 | traverse(child) 78 | 79 | traverse(node) 80 | return patterns 81 | 82 | 83 | def main(): 84 | print("Loading Chinese Wikipedia dataset from HuggingFace...") 85 | dataset = load_dataset("Jax-dan/zhwiki-latest", split="train", streaming=True) 86 | dataset = dataset.take(1000) # Limit for testing 87 | 88 | print("Extracting and counting Chinese characters...") 89 | character_counter = Counter() 90 | 91 | # Process dataset 92 | for item in tqdm(dataset, desc="Processing articles"): 93 | text = item.get('text', '') 94 | characters = extract_chinese_characters(text) 95 | character_counter.update(characters) 96 | 97 | print(f"\nTotal unique characters found: {len(character_counter)}") 98 | print(f"Total character occurrences: {sum(character_counter.values())}") 99 | 100 | print("\nDecomposing characters and analyzing all subtree patterns...") 101 | pattern_counter = Counter() 102 | characters_processed = 0 103 | characters_with_ids = 0 104 | 105 | for char, freq in tqdm(character_counter.items(), desc="Analyzing characters"): 106 | characters_processed += 1 107 | 108 | # Get IDS for character 109 | ids = get_ids_for_character(char) 110 | if ids is None: 111 | continue 112 | 113 | characters_with_ids += 1 114 | 115 | try: 116 | # Parse IDS into tree 117 | tree = parse_ideographic_description_sequences(ids) 118 | 119 | # Find all subtree patterns 120 | patterns = find_all_subtree_patterns(tree) 121 | 122 | # Count patterns weighted by character frequency 123 | for pattern in patterns: 124 | pattern_counter[pattern] += freq 125 | except Exception: # noqa: BLE001 126 | # Skip characters that fail to parse (various parsing errors possible from IDS data) 127 | pass 128 | 129 | print(f"\nCharacters processed: {characters_processed}") 130 | print(f"Characters with IDS: {characters_with_ids}") 131 | print(f"Unique subtree patterns found: {len(pattern_counter)}") 132 | 133 | # Print most common patterns 134 | print("\n" + "="*80) 135 | print("MOST COMMON SUBTREE PATTERNS (sorted by compression = price * frequency)") 136 | print("="*80) 137 | print(f"{'Rank':<6} {'Compression':<15} {'Frequency':<15} {'Pattern':<8} {'Character (if Exists)'}") 138 | print("-"*80) 139 | 140 | pricing = Counter({pattern: len(pattern) * count for pattern, count in pattern_counter.items()}) 141 | 142 | for rank, (pattern, price) in enumerate(pricing.most_common(50), 1): 143 | freq = pattern_counter[pattern] 144 | pattern_str = "".join(pattern) 145 | print(f"{rank:<6} {price:<15,} {freq:<15,} {pattern_str:<8} {get_character_for_ids(pattern_str)}") 146 | 147 | # Calculate coverage 148 | total_pattern_frequency = sum(pattern_counter.values()) 149 | total_char_frequency = sum(character_counter.values()) 150 | coverage = (total_pattern_frequency / total_char_frequency) * 100 if total_char_frequency > 0 else 0 151 | 152 | print("\n" + "="*80) 153 | print(f"Total pattern occurrences: {total_pattern_frequency:,}") 154 | print(f"Coverage of all characters: {coverage:.2f}%") 155 | print("="*80) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /complex_tokenization/graph.py: 
-------------------------------------------------------------------------------- 1 | from collections.abc import Iterable, Iterator 2 | from dataclasses import dataclass 3 | 4 | from complex_tokenization.graphs.settings import GraphSettings 5 | from complex_tokenization.languages.chinese.ideographic_description_sequences import get_character_for_ids 6 | 7 | 8 | def dot_escape(s: str) -> str: 9 | return s \ 10 | .replace("\\", "\\\\") \ 11 | .replace('"', '\\"') \ 12 | .replace("\n", "\\n") 13 | 14 | 15 | 16 | 17 | class GraphVertex: 18 | _instances = {} # Singleton pattern 19 | 20 | def __new__(cls, *args, **kwargs): 21 | if not GraphSettings.USE_SINGLETONS: 22 | return super().__new__(cls) 23 | 24 | key = (args, tuple(sorted(kwargs.items()))) 25 | if key not in cls._instances: 26 | cls._instances[key] = super().__new__(cls) 27 | cls._instances[key].__init__(*args, **kwargs) # optional, if side effects desired 28 | return cls._instances[key] 29 | 30 | def __bytes__(self): 31 | raise NotImplementedError 32 | 33 | def __str__(self): 34 | self_str = bytes(self).decode("utf-8", errors="replace") 35 | token_replacement = get_character_for_ids(self_str) 36 | if token_replacement is not None: 37 | return token_replacement 38 | return self_str 39 | 40 | def __eq__(self, other): 41 | return self is other 42 | 43 | def __hash__(self): 44 | return id(self) 45 | 46 | def dot(self, level=0) -> Iterable[str]: 47 | raise NotImplementedError 48 | 49 | @property 50 | def oid(self) -> str: # object pointer id for Graphviz node id 51 | return f"o{id(self):x}" 52 | 53 | def get_merges(self) -> list[str] | Iterator[tuple[str, ...]]: 54 | return [] 55 | 56 | def merge(self, token, merge) -> "GraphVertex": 57 | raise NotImplementedError 58 | 59 | 60 | @dataclass(frozen=True, slots=True) 61 | class Node(GraphVertex): 62 | value: bytes 63 | 64 | def __bytes__(self): 65 | return self.value 66 | 67 | def dot(self, level=0) -> Iterable[str]: 68 | yield "\t" * level + f'{self.oid} [label="{dot_escape(str(self))}"];' 69 | 70 | def merge(self, token: "Node", merge: tuple): 71 | return self 72 | 73 | def __eq__(self, other): 74 | if not isinstance(other, Node): 75 | return False 76 | return self.value == other.value 77 | 78 | def __add__(self, other): 79 | if isinstance(other, NodesSequence): 80 | return NodesSequence(tuple([self]) + other.nodes) 81 | return Node(value=self.value + other.value) 82 | 83 | def __len__(self): 84 | return len(self.value) 85 | 86 | 87 | @dataclass(frozen=True, slots=True) 88 | class NodesSequence(GraphVertex): 89 | nodes: tuple[GraphVertex, ...] 
90 | 91 | def __bytes__(self): 92 | buffer = bytearray() 93 | for node in self.nodes: 94 | buffer += bytes(node) 95 | return bytes(buffer) 96 | 97 | @property 98 | def oid(self) -> str: # object pointer id for Graphviz node id 99 | return self.nodes[0].oid 100 | 101 | def get_merges(self): 102 | num_nodes = len(self.nodes) 103 | for i, node in enumerate(self.nodes): 104 | yield from node.get_merges() 105 | 106 | if GraphSettings.ONLY_MINIMAL_MERGES and not isinstance(node, Node): 107 | continue 108 | 109 | for j in range(i + 2, min(i + GraphSettings.MAX_MERGE_SIZE + 1, num_nodes + 1)): 110 | if GraphSettings.ONLY_MINIMAL_MERGES and j < num_nodes and not isinstance(self.nodes[j], Node): 111 | break 112 | yield tuple(self.nodes[i:j]) 113 | 114 | def merge(self, token: Node, merge: tuple[Node, ...]): 115 | m = len(merge) 116 | i = 0 117 | out: list[Node] = [] 118 | nodes = self.nodes # local alias 119 | 120 | while i <= len(nodes) - m: 121 | if tuple(nodes[i:i + m]) == merge: 122 | out.append(Node(value=token.value)) 123 | i += m # skip the matched span 124 | else: 125 | out.append(nodes[i]) 126 | i += 1 127 | 128 | # append any remaining tail 129 | out.extend(nodes[i:]) 130 | 131 | if len(out) == 1: 132 | return out[0] 133 | 134 | merged_nodes = tuple([n.merge(token, merge) for n in out]) 135 | return NodesSequence(merged_nodes) 136 | 137 | def dot(self, level=0) -> Iterable[str]: 138 | color = "grey" if level % 2 == 1 else "lightgrey" 139 | 140 | # create a subgraph to group nodes 141 | yield f"subgraph cluster_{id(self)} {{" 142 | yield f'\tlabel="{str(self)}";' 143 | yield f'\tstyle=filled; color="{color}";' 144 | yield '\tnode [style=filled, color=white];' 145 | yield '' 146 | yield '\tedge [arrowhead=none];' 147 | yield '' 148 | last_node = None 149 | for node in self.nodes: 150 | yield f'\t{"".join(node.dot(level + 1))}' 151 | if last_node is not None: 152 | yield f'\t{last_node.oid} -> {node.oid};' 153 | last_node = node 154 | yield '' 155 | yield "}" 156 | 157 | def __add__(self, other): 158 | if isinstance(other, NodesSequence): 159 | return NodesSequence(self.nodes + other.nodes) 160 | if isinstance(other, Node): 161 | print("other", type(other)) 162 | print(self.nodes, type(self.nodes)) 163 | return NodesSequence(self.nodes + tuple([other])) 164 | 165 | 166 | @dataclass(frozen=True, slots=True) 167 | class Tree(GraphVertex): 168 | root: GraphVertex 169 | children: tuple[GraphVertex, ...] 
170 | 171 | def dot(self, level=0) -> Iterable[str]: 172 | color = "#cce5ff" if level % 2 == 1 else "lightblue" 173 | 174 | yield "subgraph cluster_" + self.oid + " {" 175 | yield f'\tlabel="{dot_escape(str(self))}";' 176 | yield f'\tstyle=filled; color="{color}";' 177 | yield '\tnode [style=filled, color=white];' 178 | yield '\tedge [arrowhead=normal];' 179 | yield '' 180 | yield from self.root.dot(level + 1) 181 | for i, child in enumerate(self.children, start=1): 182 | yield from ("\t" + line for line in child.dot(level + 1)) 183 | yield f'\t{self.oid} -> {child.oid} [label="{i}"];' 184 | yield '' 185 | yield '\tedge [arrowhead=none];' 186 | yield "}" 187 | 188 | @property 189 | def oid(self, level=0) -> str: 190 | return self.root.oid 191 | 192 | def get_merges(self) -> Iterator[tuple]: 193 | yield from self.root.get_merges() 194 | yield (self.root,) + self.children 195 | for child in self.children: 196 | yield from child.get_merges() 197 | 198 | def merge(self, token: Node, nodes: tuple): 199 | if nodes[0] == self.root: 200 | if len(nodes) == len(self.children) + 1: 201 | if all(nodes[i + 1] == child for i, child in enumerate(self.children)): 202 | return Node(value=token.value) 203 | 204 | root = self.root.merge(token, nodes) 205 | children = tuple(child.merge(token, nodes) for child in self.children) 206 | return Tree(root=root, children=children) 207 | 208 | def __bytes__(self): 209 | self_bytes = bytes(self.root) 210 | for child in self.children: 211 | self_bytes += bytes(child) 212 | return self_bytes 213 | 214 | 215 | @dataclass(frozen=True, slots=True) 216 | class UnconnectedGraphs(GraphVertex): 217 | subgraphs: tuple[GraphVertex, ...] 218 | 219 | def __bytes__(self): 220 | raise Exception("Cannot convert UnconnectedGraphs to bytes") 221 | 222 | def merge(self, token: Node, merge: tuple): 223 | subgraphs = tuple(subgraph.merge(token, merge) for subgraph in self.subgraphs) 224 | return UnconnectedGraphs(subgraphs=subgraphs) 225 | 226 | def get_merges(self) -> Iterator[tuple]: 227 | for subgraph in self.subgraphs: 228 | yield from subgraph.get_merges() 229 | 230 | def dot(self, level=0) -> Iterable[str]: 231 | for subgraph in self.subgraphs: 232 | yield from subgraph.dot(level) 233 | 234 | 235 | 236 | 237 | 238 | --------------------------------------------------------------------------------
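A minimal end-to-end sketch of how the pieces above fit together (not a file in the repository; it only relies on the `GraphSettings`, `utf8`, `get_merges`, and `merge` APIs shown above, and mirrors a single merge step of `Trainer.train`):

```python
from collections import Counter
from functools import reduce

from complex_tokenization.graphs.settings import GraphSettings
from complex_tokenization.graphs.units import utf8

GraphSettings.MAX_MERGE_SIZE = 2          # classic pair merges, as in BPE
GraphSettings.ONLY_MINIMAL_MERGES = True  # only merge not-yet-merged units

graph = utf8("lalaland")                  # NodesSequence of single-byte Nodes
counts = Counter(graph.get_merges())      # e.g. (Node(b'l'), Node(b'a')) appears 3 times
best = counts.most_common(1)[0][0]        # most frequent adjacent pair
token = reduce(lambda x, y: x + y, best)  # concatenate the pair into Node(b'la')
graph = graph.merge(token, best)          # rewrite the graph with the merged token
```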