├── tests ├── __init__.py ├── test_noop.py ├── test_collator.py ├── test_config.py ├── test_model_deterministic.py ├── test_attention.py ├── test_model_overfitting.py ├── test_model_generation.py ├── test_batch_image_encoder.py ├── test_train.py ├── test_model.py └── test_processor.py ├── welt ├── __init__.py ├── vision │ ├── __init__.py │ ├── test_navit.py │ ├── navit.py │ ├── README.md │ ├── batch_image_encoder.py │ └── vision_utils.py ├── noop.py ├── config.py ├── collator.py ├── attention.py ├── model_utils.py └── processor.py ├── examples ├── __init__.py ├── benchmark_processor.py └── benchmark_image_encoder.py ├── training ├── __init__.py ├── experiments │ ├── easy-tasks │ │ ├── ocr.yaml │ │ ├── README.md │ │ └── string-repetition.yaml │ ├── bpe-pretokenizer │ │ ├── README.md │ │ └── welt-bpe-14m.yaml │ ├── machine-translation │ │ ├── machine-translation.yaml │ │ ├── machine-translation-signed-spoken.yaml │ │ └── README.md │ └── chat │ │ └── README.md ├── args_trainer.py ├── freeze_callback.py ├── sample.py ├── README.md ├── args_model.py ├── extendable_yaml.py ├── args_data.py └── train.py ├── assets ├── phenomena.png ├── architecture.png └── architecture.excalidraw.zip ├── .github ├── workflows │ ├── lint.yaml │ ├── test.yaml │ ├── docker-build.yaml │ └── publish-docker.yaml └── actions │ └── setup-environment │ └── action.yaml ├── Dockerfile ├── LICENSE ├── pyproject.toml ├── MOTIVATION.md ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /welt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /welt/vision/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/phenomena.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign/WeLT/main/assets/phenomena.png -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign/WeLT/main/assets/architecture.png -------------------------------------------------------------------------------- /assets/architecture.excalidraw.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign/WeLT/main/assets/architecture.excalidraw.zip -------------------------------------------------------------------------------- /training/experiments/easy-tasks/ocr.yaml: -------------------------------------------------------------------------------- 1 | $extends: ./string-repetition.yaml 2 | 3 | # Model Setup 4 | image_encoder_model_name_or_path: "NaViT-tiny" 5 | bytes_encoder_model_name_or_path: null 6 | 7 | output_dir: ./output/ocr-tiny 
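# Note: `$extends` is resolved by training/extendable_yaml.py; this config presumably inherits string-repetition.yaml, with the keys set here overriding the inherited values.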
-------------------------------------------------------------------------------- /training/experiments/bpe-pretokenizer/README.md: -------------------------------------------------------------------------------- 1 | # BPE Pretokenizer 2 | 3 | In the paper we discuss that the pretokenizer could be anything, really. 4 | 5 | 6 | ### Train Models 7 | 8 | ```bash 9 | python -m training.train training/experiments/bpe-pretokenizer/welt-bpe-14m.yaml 10 | ``` -------------------------------------------------------------------------------- /training/experiments/easy-tasks/README.md: -------------------------------------------------------------------------------- 1 | # Easy Tasks 2 | 3 | We define easy task to verify the model is able to perform basic computation. 4 | 5 | ### String Repetition 6 | 7 | ```bash 8 | export WANDB_PROJECT="string-repetition" 9 | python -m training.train training/experiments/easy-tasks/string-repetition.yaml 10 | ``` 11 | 12 | ### OCR 13 | 14 | ```bash 15 | export WANDB_PROJECT="ocr" 16 | python -m training.train training/experiments/easy-tasks/ocr.yaml 17 | ``` -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | # act --container-architecture linux/amd64 -j lint 2 | 3 | name: Lint 4 | 5 | on: 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v5 17 | 18 | - name: Setup environment 19 | uses: ./.github/actions/setup-environment 20 | 21 | - name: Lint code 22 | run: uv run ruff check . -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | # act --container-architecture linux/amd64 -j test 2 | 3 | name: Test 4 | 5 | on: 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | jobs: 12 | test: 13 | name: Test 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v5 18 | 19 | - name: Setup environment 20 | uses: ./.github/actions/setup-environment 21 | 22 | - name: Test code 23 | env: 24 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 25 | run: uv run pytest -n auto --dist loadscope 26 | 27 | -------------------------------------------------------------------------------- /tests/test_noop.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | from transformers import AutoImageProcessor 5 | 6 | from welt.noop import NoopImageProcessor 7 | 8 | 9 | def test_image_processor_save_load(): 10 | processor = NoopImageProcessor() 11 | assert isinstance(processor, NoopImageProcessor) 12 | 13 | with tempfile.TemporaryDirectory() as temp_dir: 14 | processor.save_pretrained(save_directory=temp_dir, push_to_hub=False) 15 | new_processor = AutoImageProcessor.from_pretrained(temp_dir) 16 | assert isinstance(new_processor, NoopImageProcessor) 17 | 18 | 19 | if __name__ == "__main__": 20 | pytest.main([__file__, "-v"]) 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:25.11-py3 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive \ 4 | PYTHONUNBUFFERED=1 \ 5 | PIP_NO_CACHE_DIR=1 6 | 7 | # System deps (git for installs; 
build-essential for compiling kernels; tidy apt cache) 8 | # Rendering system deps (pango, cairo...) 9 | RUN apt-get update && \ 10 | apt-get install -y --no-install-recommends build-essential pkg-config \ 11 | libgirepository-1.0-1 libcairo2 gir1.2-pango-1.0 libcairo2-dev libgirepository1.0-dev && \ 12 | rm -rf /var/lib/apt/lists/* 13 | 14 | # Install package dependencies 15 | RUN mkdir -p /app/welt/vision && \ 16 | touch /app/README.md 17 | WORKDIR /app 18 | COPY pyproject.toml /app/pyproject.toml 19 | RUN pip install ".[train]" 20 | 21 | COPY welt /app/welt 22 | COPY training /app/training 23 | -------------------------------------------------------------------------------- /tests/test_collator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from welt.collator import stack_pad_tensors_fast, stack_pad_tensors_slow 5 | 6 | 7 | def test_stack_pad_tensors_consistency(): 8 | """Test that stack_pad_tensors_fast and stack_pad_tensors_slow produce the same output.""" 9 | a = torch.randn(2, 3, 4) 10 | b = torch.randn(3, 5, 4) 11 | c = torch.randn(1, 2, 3) 12 | d = torch.randn(2, 3, 1) 13 | 14 | tensors = [a, b, c, d] 15 | pad_value = -1 16 | 17 | result_fast = stack_pad_tensors_fast(tensors, pad_value=pad_value) 18 | result_slow = stack_pad_tensors_slow(tensors, pad_value=pad_value) 19 | 20 | assert torch.equal(result_fast, result_slow), "Fast and slow implementations should produce identical results" 21 | assert result_fast.shape == result_slow.shape, "Output shapes should be identical" 22 | 23 | 24 | if __name__ == "__main__": 25 | pytest.main([__file__, "-v"]) 26 | -------------------------------------------------------------------------------- /.github/actions/setup-environment/action.yaml: -------------------------------------------------------------------------------- 1 | name: 'Setup Environment' 2 | description: 'Install rendering dependencies, setup uv, and install Python dependencies' 3 | 4 | inputs: 5 | python-version: 6 | description: 'Python version to use' 7 | required: false 8 | default: '3.12' 9 | extra-dependencies: 10 | description: 'Extra dependencies to install (e.g., "[dev]")' 11 | required: false 12 | default: '[dev]' 13 | 14 | runs: 15 | using: 'composite' 16 | steps: 17 | - name: Set up Pango and Cairo 18 | uses: sign/pixel-renderer/.github/actions/setup-pango-cairo@main 19 | 20 | - name: Setup uv 21 | uses: astral-sh/setup-uv@v7 22 | with: 23 | python-version: ${{ inputs.python-version }} 24 | enable-cache: true 25 | activate-environment: true 26 | 27 | - name: Install dependencies 28 | shell: bash 29 | run: uv pip install ".${{ inputs.extra-dependencies }}" 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 sign 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /welt/noop.py: -------------------------------------------------------------------------------- 1 | """ 2 | This model creates NoOp classes for frequently used Huggingface Transformers components, 3 | to allow for specifying we do not want to use a component. 4 | """ 5 | 6 | from transformers import ( 7 | AutoConfig, 8 | AutoImageProcessor, 9 | AutoModel, 10 | ImageProcessingMixin, 11 | PretrainedConfig, 12 | PreTrainedModel, 13 | ) 14 | 15 | 16 | class NoopImageProcessor(ImageProcessingMixin): 17 | name = "noop-image-processor" 18 | 19 | def __init__(self, **kwargs): 20 | super().__init__(**kwargs) 21 | 22 | def __call__(self, **unused_kwargs): 23 | raise NotImplementedError() 24 | 25 | 26 | class NoopConfig(PretrainedConfig): 27 | model_type = "noop_model" 28 | 29 | def __init__(self, **kwargs): 30 | super().__init__(**kwargs) 31 | 32 | self.hidden_size = 0 33 | 34 | 35 | class NoopModel(PreTrainedModel): 36 | config_class = NoopConfig 37 | 38 | def __init__(self, config: NoopConfig): 39 | super().__init__(config=config) 40 | 41 | 42 | AutoImageProcessor.register(NoopConfig, NoopImageProcessor) 43 | AutoConfig.register(NoopConfig.model_type, NoopConfig) 44 | AutoModel.register(NoopConfig, NoopModel) 45 | -------------------------------------------------------------------------------- /examples/benchmark_processor.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torch.utils.data import DataLoader 3 | from tqdm import tqdm 4 | from trl import pack_dataset 5 | 6 | from tests.test_model import setup_tiny_model 7 | 8 | 9 | def collate_fn(batch): 10 | """Custom collate function to handle batches with variable-sized tensors.""" 11 | # Don't collate, just return the list of examples 12 | return batch 13 | 14 | if __name__ == '__main__': 15 | dataset = load_dataset("Helsinki-NLP/opus-100", "en-he", split="train") 16 | 17 | model, processor, collator = setup_tiny_model() 18 | 19 | # Convert dataset to text 20 | dataset = dataset.map( 21 | lambda batch: {"text": f"\x0E{batch['translation']['en']}\x0F {batch['translation']['he']}"}, 22 | remove_columns=["translation"]) 23 | dataset = processor.pretokenize_dataset(dataset) 24 | 25 | dataset = pack_dataset(dataset, seq_length=128) 26 | dataset = dataset.with_transform(processor) 27 | 28 | dataloader = DataLoader(dataset, 29 | batch_size=2, 30 | num_workers=2, 31 | collate_fn=collator) 32 | 33 | for _ in tqdm(dataloader): 34 | pass 35 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import tempfile 3 | 4 | import pytest 5 | 6 | from tests.test_model import setup_tiny_model 7 | from welt.config import WordLatentTransformerConfig 8 | 9 | 10 | def test_config_save_and_load_equal(): 11 | model, processor, collator = setup_tiny_model() 12 | 
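    # setup_tiny_model() (from tests/test_model.py) builds a tiny WeLT model together with its processor and collator.
    # The block below round-trips the config through save_pretrained/from_pretrained and fails with a unified diff if the serialized JSON changes.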
13 | with tempfile.TemporaryDirectory() as temp_dir: 14 | model.config.save_pretrained(save_directory=temp_dir, push_to_hub=False) 15 | 16 | new_config = WordLatentTransformerConfig.from_pretrained(temp_dir) 17 | 18 | old_config_text = model.config.to_json_string() 19 | new_config_text = new_config.to_json_string() 20 | 21 | if old_config_text != new_config_text: 22 | diff = "\n".join( 23 | difflib.unified_diff( 24 | old_config_text.splitlines(), 25 | new_config_text.splitlines(), 26 | fromfile="model.config (before save/load)", 27 | tofile="new_config (after save/load)", 28 | lineterm="", 29 | ) 30 | ) 31 | pytest.fail(f"Config changed after save/load:\n{diff}") 32 | 33 | 34 | if __name__ == "__main__": 35 | pytest.main([__file__, "-v"]) 36 | -------------------------------------------------------------------------------- /training/args_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training arguments for WeLT Trainer. 3 | 4 | Extends Seq2SeqTrainingArguments with additional parameters for generation-based evaluation. 5 | """ 6 | from dataclasses import dataclass, field 7 | 8 | from transformers import Seq2SeqTrainingArguments 9 | 10 | 11 | @dataclass 12 | class WeLTTrainingArguments(Seq2SeqTrainingArguments): 13 | """ 14 | Training arguments for WeLT Trainer. 15 | 16 | Extends Seq2SeqTrainingArguments with parameters specific to WeLT's 17 | generation-based evaluation. 18 | """ 19 | 20 | eval_metrics: list[str] | None = field( 21 | default=None, 22 | metadata={ 23 | "help": ( 24 | "List of evaluation metrics to compute during generation-based evaluation. " 25 | "Examples: ['sacrebleu', 'chrf', 'bleu', 'rouge']. " 26 | "If None, only eval_loss, byte_accuracy, word_accuracy, and perplexity will be computed." 27 | ) 28 | }, 29 | ) 30 | 31 | log_samples: int = field( 32 | default=3, 33 | metadata={ 34 | "help": ( 35 | "Number of sample predictions to log during evaluation. " 36 | "Set to 0 to disable sample logging." 37 | ) 38 | }, 39 | ) 40 | -------------------------------------------------------------------------------- /.github/workflows/docker-build.yaml: -------------------------------------------------------------------------------- 1 | name: Dockerfile build test 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'Dockerfile' 7 | - 'pyproject.toml' 8 | - '.github/workflows/docker-build.yaml' 9 | pull_request: 10 | paths: 11 | - 'Dockerfile' 12 | 13 | jobs: 14 | docker-build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout source 18 | uses: actions/checkout@v5 19 | 20 | # Reclaim ~50–80 GB on the GitHub runner 21 | - name: Free disk space 22 | uses: jlumbroso/free-disk-space@main 23 | with: 24 | tool-cache: true 25 | android: true 26 | dotnet: true 27 | haskell: true 28 | large-packages: true 29 | docker-images: true 30 | swap-storage: true 31 | 32 | - name: Set up Docker Buildx 33 | uses: docker/setup-buildx-action@v3 34 | 35 | # Build with BuildKit, don’t load/push; keep only cache 36 | - name: Build (cache-only, no image load) 37 | uses: docker/build-push-action@v6 38 | with: 39 | context: . 
40 | file: Dockerfile 41 | pull: true 42 | push: false 43 | load: false 44 | cache-from: type=gha 45 | cache-to: type=gha,mode=max 46 | outputs: type=cacheonly 47 | -------------------------------------------------------------------------------- /training/experiments/machine-translation/machine-translation.yaml: -------------------------------------------------------------------------------- 1 | # Model Setup 2 | image_encoder_model_name_or_path: WinKawaks/vit-tiny-patch16-224 3 | bytes_encoder_model_name_or_path: prajjwal1/bert-tiny 4 | latent_transformer_model_name_or_path: EleutherAI/pythia-31m 5 | bytes_decoder_model_name_or_path: sbintuitions/tiny-lm 6 | 7 | # Dataset setup 8 | dataset_name: Helsinki-NLP/opus-100 9 | dataset_config_name: en-he 10 | dataset_text_template: "\x0E{translation[en]}\x0F {translation[he]}" 11 | max_sequence_length: 128 12 | max_word_length: 16 13 | 14 | # Data Loader 15 | dataloader_num_workers: 8 16 | dataloader_prefetch_factor: 4 17 | dataloader_pin_memory: true 18 | dataloader_persistent_workers: true 19 | 20 | # Training setup 21 | remove_unused_columns: false # Necessary 22 | per_device_train_batch_size: 2 23 | per_device_eval_batch_size: 2 24 | auto_find_batch_size: true 25 | output_dir: ./output/en-he 26 | do_train: true 27 | max_steps: 10000 28 | learning_rate: 3.0e-4 29 | optim: adamw_torch_fused 30 | 31 | # Evaluation 32 | do_eval: false 33 | eval_on_start: true 34 | eval_strategy: steps 35 | eval_steps: 500 36 | metric_for_best_model: accuracy 37 | max_eval_samples: 32 38 | 39 | # Logging 40 | logging_steps: 10 41 | logging_strategy: steps 42 | include_tokens_per_second: true 43 | include_num_input_tokens_seen: true 44 | report_to: wandb 45 | 46 | # Dtype 47 | bf16: true 48 | dtype: bfloat16 49 | -------------------------------------------------------------------------------- /training/experiments/machine-translation/machine-translation-signed-spoken.yaml: -------------------------------------------------------------------------------- 1 | # Model Setup 2 | image_encoder_model_name_or_path: null 3 | bytes_encoder_model_name_or_path: prajjwal1/bert-tiny 4 | latent_transformer_model_name_or_path: EleutherAI/pythia-31m 5 | bytes_decoder_model_name_or_path: sbintuitions/tiny-lm 6 | 7 | # Dataset setup 8 | dataset_name: sign/signbank-plus 9 | dataset_config_name: cleaned 10 | dataset_text_template: 11 | - "<{sign_language}>\x0E{sign_text}\x0F<{spoken_language}> " 12 | - "{spoken_text}" 13 | max_sequence_length: 64 14 | max_word_length: 128 # SignWriting has lots of bytes per sign 15 | 16 | # Data Loader 17 | dataloader_num_workers: 8 18 | dataloader_prefetch_factor: 4 19 | dataloader_pin_memory: true 20 | dataloader_persistent_workers: true 21 | 22 | # Training setup 23 | remove_unused_columns: false # Necessary 24 | per_device_train_batch_size: 96 25 | per_device_eval_batch_size: 96 26 | auto_find_batch_size: true 27 | output_dir: ./output/signed-to-spoken 28 | do_train: true 29 | max_steps: 10000 30 | learning_rate: 3.0e-4 31 | optim: adamw_torch_fused 32 | 33 | # Evaluation 34 | do_eval: false 35 | eval_on_start: true 36 | eval_strategy: steps 37 | eval_steps: 500 38 | metric_for_best_model: accuracy 39 | 40 | # Logging 41 | logging_steps: 10 42 | logging_strategy: steps 43 | include_num_input_tokens_seen: true 44 | report_to: wandb 45 | 46 | # Dtype 47 | bf16: true 48 | dtype: bfloat16 49 | -------------------------------------------------------------------------------- /training/experiments/bpe-pretokenizer/welt-bpe-14m.yaml: 
-------------------------------------------------------------------------------- 1 | # Model Setup 2 | image_encoder_model_name_or_path: null 3 | bytes_encoder_model_name_or_path: prajjwal1/bert-tiny 4 | latent_transformer_model_name_or_path: EleutherAI/pythia-14m 5 | bytes_decoder_model_name_or_path: sbintuitions/tiny-lm 6 | pretokenizer_name: EleutherAI/pythia-14m 7 | 8 | # Dataset setup 9 | dataset_name: Helsinki-NLP/opus-100 10 | dataset_config_name: en-he 11 | dataset_text_template: "\x0E{translation[en]}\x0F {translation[en]}" 12 | max_sequence_length: 128 13 | max_word_length: 16 14 | 15 | max_eval_samples: 32 16 | 17 | # Data Loader 18 | dataloader_num_workers: 8 19 | dataloader_prefetch_factor: 4 20 | dataloader_pin_memory: true 21 | dataloader_persistent_workers: true 22 | 23 | # Training setup 24 | remove_unused_columns: false # Necessary 25 | per_device_train_batch_size: 32 26 | per_device_eval_batch_size: 32 27 | auto_find_batch_size: true 28 | output_dir: ./output/welt-bpe-14m 29 | do_train: true 30 | max_steps: 1000 31 | learning_rate: 3.0e-4 32 | optim: adamw_torch_fused 33 | 34 | # Evaluation 35 | do_eval: false 36 | eval_on_start: true 37 | eval_strategy: steps 38 | eval_steps: 100 39 | metric_for_best_model: accuracy 40 | 41 | # Logging 42 | logging_steps: 10 43 | logging_strategy: steps 44 | include_tokens_per_second: true 45 | include_num_input_tokens_seen: true 46 | report_to: none 47 | 48 | # Dtype 49 | bf16: true 50 | dtype: bfloat16 -------------------------------------------------------------------------------- /welt/vision/test_navit.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | import torch 5 | 6 | from tests.test_model import num_model_params 7 | from welt.vision.navit import NaViTConfig, NaViTModel 8 | 9 | 10 | @pytest.fixture(scope="module") 11 | def config(): 12 | return NaViTConfig( 13 | image_size=256, 14 | patch_size=16, 15 | hidden_size=512, 16 | dim=256, 17 | depth=6, 18 | heads=8, 19 | mlp_dim=1024, 20 | dropout=0.0, 21 | emb_dropout=0.0, 22 | token_dropout_prob=0.1, 23 | ) 24 | 25 | 26 | @pytest.fixture(scope="module") 27 | def model(config): 28 | return NaViTModel(config) 29 | 30 | 31 | def test_model_forward(model): 32 | images = [ 33 | torch.randn(3, 256, 256), 34 | torch.randn(3, 128, 128), 35 | torch.randn(3, 128, 256), 36 | torch.randn(3, 256, 128), 37 | torch.randn(3, 64, 256) 38 | ] 39 | 40 | out = model(images=images) 41 | assert out.pooler_output.shape == (5, 512) 42 | 43 | 44 | def test_model_from_pretrained_works(model, config): 45 | with tempfile.TemporaryDirectory() as temp_dir: 46 | model.save_pretrained(temp_dir) 47 | config.save_pretrained(temp_dir) 48 | original_num_parameters = num_model_params(model) 49 | 50 | new_model = NaViTModel.from_pretrained(temp_dir) 51 | loaded_num_parameters = num_model_params(new_model) 52 | assert original_num_parameters == loaded_num_parameters, \ 53 | f"Number of parameters mismatch: {original_num_parameters:,} vs {loaded_num_parameters:,}" 54 | 55 | 56 | if __name__ == "__main__": 57 | pytest.main([__file__, "-v"]) 58 | -------------------------------------------------------------------------------- /training/experiments/easy-tasks/string-repetition.yaml: -------------------------------------------------------------------------------- 1 | # Model Setup 2 | image_encoder_model_name_or_path: null 3 | bytes_encoder_model_name_or_path: prajjwal1/bert-tiny 4 | latent_transformer_model_name_or_path: sbintuitions/tiny-lm 5 
| bytes_decoder_model_name_or_path: sign/utf8-lm-tiny 6 | load_pretrained: true 7 | 8 | # Dataset setup 9 | dataset_name: Helsinki-NLP/opus-100 10 | dataset_config_name: en-he 11 | dataset_text_template: 12 | - "\x0E{translation[en]}\x0F " 13 | - "{translation[en]}" 14 | max_sequence_length: 128 15 | max_word_length: 16 16 | 17 | max_eval_samples: 32 18 | 19 | # Data Loader 20 | dataloader_num_workers: 8 21 | dataloader_prefetch_factor: 4 22 | dataloader_pin_memory: true 23 | dataloader_persistent_workers: true 24 | 25 | # Training setup 26 | remove_unused_columns: false # Necessary 27 | per_device_train_batch_size: 32 28 | per_device_eval_batch_size: 32 29 | auto_find_batch_size: true 30 | output_dir: ./output/string-repetition-tiny 31 | overwrite_output_dir: true 32 | do_train: true 33 | max_steps: 10000 34 | learning_rate: 3.0e-4 35 | optim: adamw_torch_fused 36 | 37 | # Evaluation 38 | do_eval: true 39 | eval_on_start: true 40 | eval_strategy: steps 41 | eval_steps: 100 42 | metric_for_best_model: chrf # Using generation-based metric 43 | eval_metrics: [sacrebleu, chrf] # Generation-based evaluation metrics 44 | predict_with_generate: true 45 | generation_max_length: 50 # Max tokens/words to generate during evaluation 46 | log_samples: 5 # Number of sample predictions to log 47 | 48 | # Logging 49 | logging_steps: 10 50 | logging_strategy: steps 51 | include_tokens_per_second: true 52 | include_num_input_tokens_seen: true 53 | report_to: wandb 54 | 55 | # Dtype 56 | bf16: true 57 | dtype: bfloat16 -------------------------------------------------------------------------------- /training/freeze_callback.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments 2 | 3 | from welt.model import WordLatentTransformer 4 | 5 | 6 | class FreezeWarmupCallback(TrainerCallback): 7 | """ 8 | If steps==0: no-op. 9 | Else: on train begin -> model.freeze_pretrained_models() 10 | after `steps` -> model.unfreeze() (once). 11 | Safe with DDP/Deepspeed since toggling requires_grad is fine mid-training. 
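    Example (sketch, assuming the standard `callbacks=` hook of transformers.Trainer):
        trainer = Trainer(model=model, args=training_args, callbacks=[FreezeWarmupCallback(model, steps=100)])
    The training script exposes this via `--warmup_freeze_steps` (see training/args_model.py).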
12 | """ 13 | 14 | def __init__(self, model: WordLatentTransformer, steps: int = 0): 15 | self.model = model 16 | 17 | self.steps = steps 18 | self.enabled = self.steps > 0 19 | 20 | print("FreezeWarmupCallback initialized with steps =", self.steps) 21 | 22 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 23 | if not self.enabled: 24 | return control 25 | 26 | # If resuming past threshold, don't (re)freeze 27 | if state.global_step >= self.steps: 28 | self.enabled = False 29 | return control 30 | 31 | self.model.freeze_pretrained_models() 32 | print("✓ Freezing pretrained model parameters for first", self.steps, "steps.") 33 | 34 | control.should_log = True 35 | return control 36 | 37 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 38 | if not self.enabled: 39 | return control 40 | 41 | if state.global_step >= self.steps: 42 | self.model.unfreeze() 43 | print("✓ Unfreezing all model parameters after", state.global_step, "steps.") 44 | self.enabled = False 45 | 46 | # Force a log/save right after unfreeze for visibility/checkpoints 47 | control.should_log = True 48 | control.should_save = True 49 | return control 50 | -------------------------------------------------------------------------------- /training/experiments/chat/README.md: -------------------------------------------------------------------------------- 1 | # Chat 2 | 3 | We would like to create a "ChatGPT Clone" using our new architecture, following 4 | [nanochat](https://github.com/karpathy/nanochat). 5 | 6 | Pretraining dataset: `karpathy/fineweb-edu-100b-shuffle`. 7 | Midtraining dataset: `HuggingFaceTB/smoltalk2` (`Mid`). 8 | SFT dataset: `HuggingFaceTB/smoltalk2` (`SFT`). 9 | RL dataset: `HuggingFaceTB/smoltalk2` (`Preference`). 10 | 11 | ## Base model (pretraining) 12 | 13 | We train a model on `karpathy/fineweb-edu-100b-shuffle` - 14 | Chinchilla says #tokens = 20X #params, so we need 10B/20 = 500m parameters. 15 | 16 | ```shell 17 | torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN 18 | # evaluate the model on a larger chunk of train/val data and draw some samples 19 | torchrun --standalone --nproc_per_node=8 -m scripts.base_loss 20 | # evaluate the model on CORE tasks 21 | torchrun --standalone --nproc_per_node=8 -m scripts.base_eval 22 | ``` 23 | 24 | ## Midtraining 25 | 26 | Teach the model conversation special tokens, tool use, multiple choice. 
27 | 28 | ```shell 29 | torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN 30 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid 31 | ``` 32 | 33 | ## Supervised Finetuning 34 | 35 | Domain adaptation to each sequence all by itself per row (no packing) 36 | train sft and re-eval right away (should see a small bump) 37 | 38 | ```shell 39 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN 40 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft 41 | ``` 42 | 43 | ## Reinforcement Learning 44 | 45 | ```shell 46 | # run reinforcement learning 47 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN 48 | # eval the RL model only on GSM8K 49 | torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K 50 | ``` 51 | 52 | ## Chat Interface 53 | 54 | https://github.com/karpathy/nanochat/blob/master/scripts/chat_web.py -------------------------------------------------------------------------------- /welt/config.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import CONFIG_MAPPING, AutoConfig, PretrainedConfig 3 | 4 | from welt.noop import NoopConfig 5 | 6 | 7 | class WordLatentTransformerConfig(PretrainedConfig): 8 | model_type = "welt" 9 | 10 | sub_configs = { 11 | "image_encoder": AutoConfig, 12 | "bytes_encoder": AutoConfig, 13 | "latent_transformer": AutoConfig, 14 | "bytes_decoder": AutoConfig, 15 | } 16 | 17 | def __init__(self, 18 | image_encoder: AutoConfig | dict | None = None, 19 | bytes_encoder: AutoConfig | dict | None = None, 20 | latent_transformer: AutoConfig | dict = None, 21 | bytes_decoder: AutoConfig | dict = None, 22 | modality_dropout: float = 0.15, 23 | num_tokens: int = 256, 24 | **kwargs): 25 | # Configuration defaults 26 | kwargs["is_decoder"] = kwargs.get("is_decoder", True) 27 | super().__init__(**kwargs) 28 | 29 | self.image_encoder = image_encoder 30 | self.bytes_encoder = bytes_encoder 31 | self.latent_transformer = latent_transformer 32 | self.bytes_decoder = bytes_decoder 33 | 34 | for name in self.sub_configs.keys(): 35 | self.init_sub_config(name) 36 | 37 | self.modality_dropout = modality_dropout 38 | self.num_tokens = num_tokens 39 | 40 | def init_sub_config(self, name: str): 41 | config = getattr(self, name, None) 42 | if isinstance(config, dict): 43 | model_type = config.get("model_type", None) 44 | config_cls = CONFIG_MAPPING[model_type] if model_type else PretrainedConfig 45 | config = config_cls(**config) 46 | setattr(self, name, config) 47 | 48 | if config is None: 49 | # For optional encoders, use NoopConfig instead of PretrainedConfig 50 | if name in ["image_encoder", "bytes_encoder"]: 51 | setattr(self, name, NoopConfig()) 52 | else: 53 | setattr(self, name, PretrainedConfig()) 54 | -------------------------------------------------------------------------------- /training/sample.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from transformers import GenerationConfig 5 | from transformers.trainer_utils import get_last_checkpoint 6 | 7 | from welt.model import WordLatentTransformerForCausalLM 8 | from welt.processor import TextImageProcessor 9 | 10 | 11 | @torch.no_grad() 12 | @torch.autocast(device_type="cuda", dtype=torch.bfloat16) 13 | def sample(model_path: Path): 14 | last_checkpoint = get_last_checkpoint(model_path) 15 | model: 
WordLatentTransformerForCausalLM = \ 16 | WordLatentTransformerForCausalLM.from_pretrained(last_checkpoint, attn_implementation="flash_attention_2") 17 | processor = TextImageProcessor.from_pretrained(model_path) 18 | 19 | model.eval() 20 | 21 | texts = [ 22 | # Texts from validation set 23 | "\x0EWhat's wrong?\x0F ", 24 | "\x0EYOu dOn't know the half Of it.\x0F ", # look! mixed case 25 | "\x0E- No, just said that you were acting... aggressive.\x0F ", 26 | "\x0E-l'm a deputy.\x0F ", # look! "l" instead of "I", no space. 27 | "\x0EWell, the good news is Joe wasn't cheating on you.\x0F ", 28 | "\x0E- Mm-hmm. No wonder women won't flirt with me.\x0F ", 29 | "\x0EYou understand you'll be working with your father.\x0F ", 30 | "\x0EYou wanted freedom, you got it.\x0F ", 31 | # Not from validation set, just wanted to try a long named entity 32 | "\x0EAlexander Hamilton\x0F ", 33 | ] 34 | 35 | inputs = processor(texts, collated=True, packed=False) 36 | 37 | outputs = model.generate( 38 | **inputs, 39 | processor=processor, 40 | max_generated_words=32, 41 | bytes_generation_config=GenerationConfig(num_beams=2) # Sample with beam search, for example 42 | ) 43 | for text, output in zip(texts, outputs, strict=False): 44 | print(f"Generated for '{text}': {output}") 45 | 46 | 47 | if __name__ == "__main__": 48 | model_path = Path(__file__).parent.parent / "output" / "signed-to-spoken" 49 | sample(model_path) 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "WeLT" 3 | description = "WeLT: Word embedding Latent Transformer - a language model designed to operate on \"words\"." 4 | version = "0.0.1" 5 | authors = [ 6 | { name = "Amit Moryossef", email = "amit@sign.mt" }, 7 | ] 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | dependencies = [ 11 | "transformers[torch]", 12 | "trl", # For huggingface implementation of pack_dataset 13 | "datasets", 14 | "Pillow", # For AutoImageProcessor 15 | "torchvision", # For use_fast=True 16 | "cachetools", # For LRU cache in processor 17 | "vit-pytorch", # For NaViT model implementation 18 | # Our packages 19 | "words-segmentation", # For pre-tokenization into words 20 | "utf8-tokenizer[fast]", # For tokenization of bytes 21 | "pixel-renderer[pangocairo]", # For text rendering into images (thread safe) 22 | ] 23 | 24 | [project.optional-dependencies] 25 | dev = [ 26 | "ruff", 27 | "pytest", 28 | "pytest-xdist", # For parallel test execution 29 | ] 30 | 31 | train = [ 32 | "hf_transfer", # For HF_HUB_ENABLE_HF_TRANSFER 33 | "accelerate", # For distributed training 34 | "datasets", # For dataset loading and processing 35 | "evaluate", # For evaluation metrics 36 | "scikit-learn", # For "accuracy" metric in evaluate 37 | "sacrebleu", # For usual bleu/chrF metrics 38 | "wandb", # For experiment tracking 39 | ] 40 | 41 | [tool.setuptools] 42 | packages = [ 43 | "welt", 44 | "welt.vision", 45 | ] 46 | 47 | [tool.ruff] 48 | line-length = 120 49 | extend-exclude = [ 50 | "training/experiments/*", 51 | ] 52 | 53 | [tool.ruff.lint] 54 | select = [ 55 | "E", # pycodestyle errors 56 | "W", # pycodestyle warnings 57 | "F", # pyflakes 58 | "C90", # mccabe complexity 59 | "I", # isort 60 | "N", # pep8-naming 61 | "UP", # pyupgrade 62 | "B", # flake8-bugbear 63 | "PT", # flake8-pytest-style 64 | "W605", # invalid escape sequence 65 | "BLE", # flake8-blind-except 66 | "TRY", # tryceratops 67 | ] 68 | 69 | 
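# Tests live both under tests/ and inside the welt package itself (e.g. welt/vision/test_navit.py), so both paths are listed in testpaths below.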
[tool.pytest.ini_options] 70 | addopts = "-v" 71 | testpaths = [ 72 | "welt", 73 | "tests", 74 | ] 75 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | > [!CAUTION] 4 | > Read the [vision README](../welt/vision/README.md) to understand how to select 5 | > an image encoder model for fast training. 6 | 7 | Setup with: 8 | 9 | ```bash 10 | pip install ".[train]" 11 | ``` 12 | 13 | Run: 14 | [//]: # (TODO: Unclear why `remove_unused_columns=False` is needed, but it is required to avoid errors during training.) 15 | 16 | ```bash 17 | python -m training.train \ 18 | --image_encoder_model_name_or_path "WinKawaks/vit-tiny-patch16-224" \ 19 | --bytes_encoder_model_name_or_path "prajjwal1/bert-tiny" \ 20 | --latent_transformer_model_name_or_path "sbintuitions/tiny-lm" \ 21 | --bytes_decoder_model_name_or_path "sbintuitions/tiny-lm" \ 22 | --load_pretrained True \ 23 | --dataset_name Helsinki-NLP/opus-100 \ 24 | --dataset_config_name en-he \ 25 | --dataset_text_template " {translation[en]} {translation[he]}" \ 26 | --remove_unused_columns False \ 27 | --per_device_train_batch_size 1 \ 28 | --per_device_eval_batch_size 1 \ 29 | --do_train \ 30 | --save_steps 10 \ 31 | --output_dir output \ 32 | --overwrite_output_dir \ 33 | --logging_steps 1 \ 34 | --logging_strategy steps \ 35 | --max_steps 50 \ 36 | --max_sequence_length 32 \ 37 | --max_word_length 8 \ 38 | --dataloader_num_workers 4 \ 39 | --include_tokens_per_second True \ 40 | --include_num_input_tokens_seen True \ 41 | --max_train_samples 16 \ 42 | --warmup_freeze_steps 10 43 | ``` 44 | 45 | Use `warmup_freeze_steps=N` to freeze the pretrained modules for the first N steps 46 | ([#7](https://github.com/sign/WeLT/issues/7)). 47 | 48 | ### Training Quirks 49 | 50 | `num_input_tokens_seen` and `train_tokens_per_second` are calculated based on the number of bytes the model decodes. 51 | That means that in practice, if `max_word_length=32`, a rough estimate of 52 | the real number of **words** the model sees should be divided by 32. 53 | 54 | ### Performance Optimization 55 | 56 | To speed up the processor's image rendering and preprocessing, you can 57 | increase `processor.cache_size` to cache more preprocessed images in memory. 58 | (`cache_size=500_000` can take 25GB of RAM per process). 59 | Ideally, we make the renderer so fast it doesn't need caching at all. 
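A minimal sketch of raising the cache size (assumptions: `cache_size` can be assigned on an already-loaded processor, and the checkpoint path below is only an example):

```python
# Sketch only: the output path and the way cache_size is set are assumptions, not verified API.
from welt.processor import TextImageProcessor

processor = TextImageProcessor.from_pretrained("./output/en-he")
processor.cache_size = 500_000  # roughly 25 GB of RAM per worker process
```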
-------------------------------------------------------------------------------- /tests/test_model_deterministic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tests.test_model import predict_dataset, setup_tiny_model 4 | from tests.test_model_overfitting import train_model 5 | 6 | 7 | def test_model_setup_is_deterministic(): 8 | """Test that model setup is deterministic - creates identical models when called twice.""" 9 | model1, processor1, collator1 = setup_tiny_model() 10 | model2, processor2, collator2 = setup_tiny_model() 11 | 12 | # Compare all parameters between the two models 13 | model1_state_dict = model1.state_dict() 14 | model2_state_dict = model2.state_dict() 15 | 16 | # Check that both models have the same parameter names 17 | assert set(model1_state_dict.keys()) == set(model2_state_dict.keys()), \ 18 | "Models have different parameter names" 19 | 20 | # Check that all parameters are identical 21 | for param_name in model1_state_dict: 22 | param1 = model1_state_dict[param_name] 23 | param2 = model2_state_dict[param_name] 24 | 25 | # Check shapes match 26 | assert param1.shape == param2.shape, \ 27 | f"Parameter {param_name} has different shapes: {param1.shape} vs {param2.shape}" 28 | 29 | # Check values are identical 30 | assert torch.allclose(param1, param2, rtol=1e-7, atol=1e-7), \ 31 | f"Parameter {param_name} has different values. Max diff: {(param1 - param2).abs().max().item()}" 32 | 33 | print(f"✓ Model setup is deterministic: {len(model1_state_dict)} parameters verified") 34 | 35 | 36 | def test_train_model_is_deterministic(): 37 | print("Setting up models for training determinism test...") 38 | models = [train_model(setup_tiny_model, num_epochs=50) for _ in range(2)] 39 | 40 | print("Predicting losses for test texts using both models...") 41 | test_texts = ["a b", "b a", "a cat", "a dog"] 42 | losses = [predict_dataset(test_texts, model, processor, collator)[0] 43 | for model, processor, collator in models] 44 | 45 | # Compare losses - they should be identical 46 | tolerance = 1e-4 47 | for text in test_texts: 48 | assert abs(losses[0][text] - losses[1][text]) < tolerance, \ 49 | f"Loss mismatch for '{text}': {losses[0][text]:.6f} vs {losses[1][text]:.6f}" 50 | 51 | print("✅ Training determinism test passed - both models produced identical losses!") 52 | -------------------------------------------------------------------------------- /training/experiments/machine-translation/README.md: -------------------------------------------------------------------------------- 1 | # Machine Translation 2 | 3 | We compare HuggingFace's example training for a causal language model to our setup. 4 | 5 | We modified the `run_clm.py` script, in this directory. 
6 | See the modifications by running a diff with: 7 | https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py 8 | 9 | We train `EleutherAI/pythia-70m` from scratch, using: 10 | 11 | ```shell 12 | python run_clm.py \ 13 | --model_name_or_path EleutherAI/pythia-70m \ 14 | --per_device_train_batch_size 8 \ 15 | --per_device_eval_batch_size 8 \ 16 | --block_size 128 \ 17 | --output_dir ./output/clm \ 18 | ``` 19 | 20 | Compared to our model setup: 21 | ```shell 22 | python -m training.train \ 23 | --image_encoder_model_name_or_path WinKawaks/vit-tiny-patch16-224 \ 24 | --bytes_encoder_model_name_or_path prajjwal1/bert-tiny \ 25 | --latent_transformer_model_name_or_path EleutherAI/pythia-70m \ 26 | --bytes_decoder_model_name_or_path sbintuitions/tiny-lm \ 27 | --remove_unused_columns False \ 28 | --per_device_train_batch_size 128 \ 29 | --per_device_eval_batch_size 128 \ 30 | --max_sequence_length 128 \ 31 | --max_word_length 20 \ 32 | --output_dir ./output/welt \ 33 | --dtype bfloat16 \ 34 | ``` 35 | 36 | With the following shared arguments: 37 | 38 | ```shell 39 | --dataset_name Helsinki-NLP/opus-100 \ 40 | --dataset_config_name en-he \ 41 | --dataset_text_template " {translation[en]} {translation[he]}" \ 42 | --do_train True \ 43 | --do_eval True \ 44 | --metric_for_best_model accuracy \ 45 | --eval_on_start True \ 46 | --eval_strategy epoch \ 47 | --logging_steps 10 \ 48 | --logging_strategy steps \ 49 | --max_steps 100000 \ 50 | --dataloader_num_workers 8 \ 51 | --dataloader_prefetch_factor 4 \ 52 | --dataloader_pin_memory True \ 53 | --dataloader_persistent_workers True \ 54 | --auto_find_batch_size True \ 55 | --include_tokens_per_second True \ 56 | --include_num_input_tokens_seen True \ 57 | --learning_rate 3e-4 \ 58 | --optim adamw_torch_fused \ 59 | --bf16 True \ 60 | --report_to wandb 61 | ``` 62 | 63 | ### Using configs 64 | 65 | ```bash 66 | python -m training.train training/experiments/machine-translation/machine-translation.yaml 67 | python -m training.train training/experiments/machine-translation/machine-translation-signed-spoken.yaml 68 | ``` -------------------------------------------------------------------------------- /welt/vision/navit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, AutoModel 3 | from transformers.configuration_utils import PretrainedConfig 4 | from transformers.modeling_outputs import BaseModelOutputWithPooling 5 | from transformers.modeling_utils import PreTrainedModel 6 | from vit_pytorch.na_vit import NaViT 7 | 8 | 9 | class NaViTConfig(PretrainedConfig): 10 | model_type = "navit" 11 | 12 | def __init__( 13 | self, 14 | image_size: int = 256, # only used to set a default; NaViT handles var-size 15 | patch_size: int = 16, 16 | hidden_size: int = 512, 17 | dim: int = 512, 18 | depth: int = 6, 19 | heads: int = 8, 20 | mlp_dim: int = 1024, 21 | dropout: float = 0.0, 22 | emb_dropout: float = 0.0, 23 | token_dropout_prob: float = 0.1, 24 | **kwargs, 25 | ): 26 | super().__init__(**kwargs) 27 | self.image_size = image_size 28 | self.patch_size = patch_size 29 | self.hidden_size = hidden_size 30 | self.dim = dim 31 | self.depth = depth 32 | self.heads = heads 33 | self.mlp_dim = mlp_dim 34 | self.dropout = dropout 35 | self.emb_dropout = emb_dropout 36 | self.token_dropout_prob = token_dropout_prob 37 | 38 | 39 | class NaViTModel(PreTrainedModel): 40 | config_class = NaViTConfig 41 | 42 | def __init__(self, config: 
NaViTConfig): 43 | super().__init__(config) 44 | 45 | navit = NaViT( 46 | image_size=config.image_size, 47 | patch_size=config.patch_size, 48 | num_classes=config.hidden_size, 49 | dim=config.dim, 50 | depth=config.depth, 51 | heads=config.heads, 52 | mlp_dim=config.mlp_dim, 53 | dropout=config.dropout, 54 | emb_dropout=config.emb_dropout, 55 | token_dropout_prob=config.token_dropout_prob, 56 | ) 57 | # Explicitly register as a submodule to ensure device changes are tracked 58 | self.add_module('navit', navit) 59 | 60 | # Initialize weights via HF utility (won't disturb vit-pytorch-initialized weights unless uninitialized) 61 | self.post_init() 62 | 63 | def forward(self, images: list[torch.Tensor], **kwargs): 64 | # vit-pytorch NaViT returns shape (B, num_classes) 65 | logits = self.navit(images) 66 | 67 | return BaseModelOutputWithPooling( 68 | last_hidden_state=None, 69 | pooler_output=logits 70 | ) 71 | 72 | AutoConfig.register(NaViTConfig.model_type, NaViTConfig) 73 | AutoModel.register(NaViTConfig, NaViTModel) 74 | -------------------------------------------------------------------------------- /welt/vision/README.md: -------------------------------------------------------------------------------- 1 | # Vision 2 | 3 | Working with Image encoder models is less standardized than working with text models. 4 | 5 | In this repository, it is recommended to use, in order of preference: 6 | 7 | 1. [NaViT](https://github.com/lucidrains/vit-pytorch#navit) models, since they natively support variable-sized inputs. 8 | 2. Models that implement 9 | [ViTModel](https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py), 10 | since we have a patch to handle variable-sized inputs. 11 | 3. Any other image encoder model from `transformers`, in which we have utilities to group images by size and batch them, 12 | avoiding padding artifacts, with a lot of compute overhead. 13 | 14 | ### Vision Utils [./vision_utils.py](./vision_utils.py) 15 | 16 | This utility is under a PR [#40587](https://github.com/huggingface/transformers/pull/40587) to `transformers`. 17 | 18 | This utility addresses the lack of standardization across vision transformer models and image encoders. 19 | Unlike text models, where core attributes like `hidden_size` or `vocab_size` are consistently exposed, 20 | vision models vary widely in how they represent hidden dimensions, pooling strategies, and forward-pass arguments. 21 | Some models store the hidden size in `config.hidden_size`, others in `vision_config.hidden_size`, 22 | `neck_hidden_sizes`, or `hidden_sizes`. 23 | Similarly, handling positional encoding interpolation or pooling hidden states differs across architectures. 24 | Without a unified interface, downstream code must implement ad-hoc checks and case-by-case logic. 25 | This module centralizes those edge cases into robust, cached utilities that reliably determine encoder dimensions, 26 | construct forward arguments, and pool features, providing a consistent abstraction for working with 27 | heterogeneous vision backbones. 28 | 29 | ### Batch Image Encoder [./batch_image_encoder.py](./batch_image_encoder.py) 30 | 31 | This utility tackles the challenge that most vision transformer encoders can only process batches of 32 | equally sized images, which makes handling variable-sized inputs difficult and error-prone. 
33 | Padding smaller images to match the largest one in a batch leads to wasted computation and, worse, changes the 34 | encoder’s output in subtle ways since positional embeddings and convolutions can be affected by padded regions. 35 | 36 | To resolve this, the utility reorganizes inputs by cropping them to their true dimensions, 37 | grouping images of the same shape, and then running those groups through the encoder in efficient batches. 38 | After encoding, the results are reordered and reassembled back into the original nested batch structure, 39 | ensuring consistency and eliminating the artifacts introduced by padding. 40 | This provides a reliable and scalable way to extract embeddings from heterogeneous image sets while still 41 | leveraging batch parallelism where possible. 42 | -------------------------------------------------------------------------------- /training/args_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from transformers import ( 4 | HfArgumentParser, 5 | ) 6 | from transformers.models.auto.modeling_auto import ( 7 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, 8 | MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, 9 | MODEL_FOR_MASKED_LM_MAPPING_NAMES, 10 | ) 11 | 12 | 13 | def listed_model(model_names): 14 | models_list = ", ".join(model_names.keys()) 15 | return field( 16 | default=None, 17 | metadata={"help": "Any model implementing the architectures from the list: " + models_list}, 18 | ) 19 | 20 | 21 | @dataclass 22 | class ModelArguments: 23 | model_name_or_path: str | None = field( 24 | default=None, 25 | metadata={ 26 | "help": ( 27 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 28 | ) 29 | }, 30 | ) 31 | 32 | image_encoder_model_name_or_path: str | None = listed_model(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES) 33 | bytes_encoder_model_name_or_path: str | None = listed_model(MODEL_FOR_MASKED_LM_MAPPING_NAMES) 34 | latent_transformer_model_name_or_path: str | None = listed_model(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) 35 | bytes_decoder_model_name_or_path: str | None = listed_model(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) 36 | 37 | load_pretrained: bool = field(default=False, metadata={ 38 | "help": "Whether to load the pretrained weights of the models specified in *_model_name_or_path." 39 | }) 40 | 41 | warmup_freeze_steps: int = field(default=0, metadata={ 42 | "help": "Steps to keep most modules frozen at start." 43 | }) 44 | 45 | pretokenizer_name: str | None = field(default=None, metadata={ 46 | "help": "Pretokenizer to use, defaults to https://github.com/sign/words-segmentation." 47 | }) 48 | 49 | trust_remote_code: bool = field( 50 | default=False, 51 | metadata={ 52 | "help": ( 53 | "Whether to trust the execution of code from datasets/models defined on the Hub." 54 | " This option should only be set to `True` for repositories you trust and in which you have read the" 55 | " code, as it will execute code present on the Hub on your local machine." 56 | ) 57 | }, 58 | ) 59 | 60 | dtype: str | None = field( 61 | default=None, 62 | metadata={ 63 | "help": ( 64 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 65 | "dtype will be automatically derived from the model's weights." 
66 | ), 67 | "choices": ["auto", "bfloat16", "float16", "float32"], 68 | }, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = HfArgumentParser(ModelArguments) 74 | parser.parse_args_into_dataclasses() 75 | -------------------------------------------------------------------------------- /welt/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def stack_pad_tensors_fast(tensors: list[torch.Tensor], pad_value=0, dtype=None, device=None): 5 | # single fused op using NestedTensor (PyTorch ≥2.1) 6 | nt = torch.nested.nested_tensor(tensors, dtype=dtype, device=device) 7 | return nt.to_padded_tensor(pad_value) # shape: [B, *max_shape] 8 | 9 | 10 | def stack_pad_tensors_slow(tensors: list[torch.Tensor], pad_value=0): 11 | # Check if all tensors have the same shape 12 | shapes = [tensor.shape for tensor in tensors] 13 | if len(set(shapes)) == 1: 14 | # All shapes are the same, just stack 15 | return torch.stack(tensors) 16 | 17 | # Find maximum size for each dimension 18 | max_shape = [] 19 | ndim = max(len(shape) for shape in shapes) 20 | 21 | for dim in range(ndim): 22 | max_size = max(shape[dim] if dim < len(shape) else 1 23 | for shape in shapes) 24 | max_shape.append(max_size) 25 | 26 | # Pad each tensor to max_shape 27 | padded_tensors = [] 28 | for tensor in tensors: 29 | # Calculate padding needed for each dimension 30 | padding = [] 31 | for dim in reversed(range(ndim)): # padding goes from last dim to first 32 | if dim < len(tensor.shape): 33 | pad_size = max_shape[dim] - tensor.shape[dim] 34 | padding.extend([0, pad_size]) 35 | else: 36 | padding.extend([0, max_shape[dim]]) 37 | 38 | if padding: 39 | padded_tensor = torch.nn.functional.pad(tensor, padding, value=pad_value) 40 | else: 41 | padded_tensor = tensor 42 | 43 | # If tensor has fewer dimensions than max, add singleton dimensions 44 | while len(padded_tensor.shape) < ndim: 45 | padded_tensor = padded_tensor.unsqueeze(0) 46 | 47 | padded_tensors.append(padded_tensor) 48 | 49 | return torch.stack(padded_tensors) 50 | 51 | 52 | def stack_pad_tensors(tensors: list[torch.Tensor], pad_value=0): # noqa: C901 53 | """ 54 | Vibe coded: 55 | Generic collate function that automatically pads mismatched dimensions. 56 | For each tensor field in the batch, finds dimensions that don't match 57 | and pads them with zeros to the maximum size. 
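    For example, stacking tensors of shapes (2, 3) and (3, 2) yields a single tensor of shape (2, 3, 3), with `pad_value` filling the padded positions.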
58 | """ 59 | # Early return if empty 60 | if len(tensors) == 0: 61 | return torch.empty(0) 62 | 63 | # Early return if not tensors 64 | if not hasattr(tensors[0], 'shape'): 65 | return tensors 66 | 67 | # Early return if single tensor 68 | if len(tensors) == 1: 69 | return tensors[0].unsqueeze(0) 70 | 71 | device = tensors[0].device 72 | if device.type == 'mps': 73 | return stack_pad_tensors_slow(tensors, pad_value=pad_value) 74 | 75 | dtype = tensors[0].dtype 76 | return stack_pad_tensors_fast(tensors, pad_value=pad_value, dtype=dtype, device=device) 77 | 78 | 79 | def stack_pad_tensors_list(tensors_list: list[list[torch.Tensor]], pad_value=0): 80 | return stack_pad_tensors([stack_pad_tensors(tensors, pad_value=pad_value) for tensors in tensors_list]) 81 | 82 | def collate_fn(batch: list, pad_value=0): 83 | if not batch: 84 | return batch 85 | 86 | # Get all keys from the first item 87 | keys = batch[0].keys() 88 | collated = {} 89 | 90 | for key in keys: 91 | tensors = [item[key] for item in batch] 92 | collated[key] = stack_pad_tensors(tensors, pad_value=pad_value) 93 | 94 | return collated 95 | -------------------------------------------------------------------------------- /examples/benchmark_image_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from font_download import FontConfig 3 | from font_download.example_fonts.noto_sans import FONTS_NOTO_SANS 4 | from pixel_renderer import PixelRendererProcessor 5 | from tqdm import tqdm 6 | from transformers import AutoImageProcessor, ViTModel 7 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 8 | from words_segmentation.tokenizer import WordsSegmentationTokenizer 9 | 10 | from welt.processor import TextImageProcessor 11 | from welt.vision.batch_image_encoder import encode_images, encode_padded_images 12 | from welt.vision.navit import NaViTConfig, NaViTModel 13 | 14 | text = """ 15 | Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are 16 | newly initialized: ['pooler.dense.bias', 'pooler.dense.weight'] 17 | You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. 18 | [INFO|trainer.py:2523] 2025-09-08 20:09:21,135 >> ***** Running training ***** 19 | [INFO|trainer.py:2524] 2025-09-08 20:09:21,138 >> Num examples = 138,167 20 | [INFO|trainer.py:2525] 2025-09-08 20:09:21,138 >> Num Epochs = 39 21 | [INFO|trainer.py:2526] 2025-09-08 20:09:21,138 >> Instantaneous batch size per device = 176 22 | [INFO|trainer.py:2529] 2025-09-08 20:09:21,138 >> Total train batch size (w. 
parallel, 23 | distributed & accumulation) = 176 24 | [INFO|trainer.py:2530] 2025-09-08 20:09:21,138 >> Gradient Accumulation steps = 1 25 | [INFO|trainer.py:2531] 2025-09-08 20:09:21,138 >> Total optimization steps = 30,000 26 | [INFO|trainer.py:2532] 2025-09-08 20:09:21,139 >> Number of trainable parameters = 30,286,208""" 27 | 28 | words = text.strip().split() # 117~ words 29 | print(len(words)) 30 | 31 | font_config = FontConfig(sources=FONTS_NOTO_SANS) 32 | processor = TextImageProcessor( 33 | pretokenizer=WordsSegmentationTokenizer(), 34 | tokenizer=UTF8Tokenizer(), 35 | renderer=PixelRendererProcessor(font=font_config), 36 | image_processor=AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224", use_fast=True), 37 | ) 38 | 39 | renders, dimensions = processor.render_texts(words) 40 | 41 | vit_model = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224") 42 | 43 | # Similar model~ 44 | navit_model = NaViTModel(NaViTConfig( 45 | image_size=512, # Max size for pos embeddings 46 | patch_size=vit_model.config.patch_size, 47 | hidden_size=vit_model.config.hidden_size, 48 | dim=vit_model.config.intermediate_size, 49 | depth=vit_model.config.num_hidden_layers // 2, 50 | heads=vit_model.config.num_attention_heads, 51 | mlp_dim=128, 52 | )) 53 | 54 | num_steps = 10 55 | batch_size = 10 56 | batch = torch.stack([renders] * batch_size) 57 | dimensions_batch = torch.stack([dimensions] * batch_size) 58 | 59 | device = "cpu" 60 | if torch.cuda.is_available(): 61 | device = "cuda" 62 | # elif torch.backends.mps.is_available(): 63 | # device = "mps" 64 | 65 | vit_model = vit_model.to(device) 66 | navit_model = navit_model.to(device) 67 | batch = batch.to(device) 68 | dimensions_batch = dimensions_batch.to(device) 69 | 70 | for _ in tqdm(range(num_steps), desc="ViT Model grouped"): 71 | encode_images(image_encoder=vit_model, input_images=batch, input_images_dimensions=dimensions_batch) 72 | 73 | for _ in tqdm(range(num_steps), desc="ViT Model with padding"): 74 | encode_padded_images(image_encoder=vit_model, input_images=batch) 75 | 76 | for _ in tqdm(range(num_steps), desc="NaViT Model direct"): 77 | encode_images(image_encoder=navit_model, input_images=batch, input_images_dimensions=dimensions_batch) 78 | -------------------------------------------------------------------------------- /.github/workflows/publish-docker.yaml: -------------------------------------------------------------------------------- 1 | # Adapted from https://docs.github.com/en/actions/tutorials/publishing-packages/publishing-docker-images 2 | name: Publish a Docker image 3 | 4 | # Configures this workflow to run every time a new release is created in the repository. 5 | on: 6 | release: 7 | types: [ created ] 8 | 9 | # Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds. 10 | env: 11 | REGISTRY: ghcr.io 12 | IMAGE_NAME: ${{ github.repository }} 13 | 14 | # There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu. 15 | jobs: 16 | build-and-push-image: 17 | runs-on: ubuntu-latest 18 | # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. 
19 | permissions: 20 | contents: read 21 | packages: write 22 | attestations: write 23 | id-token: write 24 | 25 | steps: 26 | - name: Checkout repository 27 | uses: actions/checkout@v4 28 | 29 | # Reclaim ~50–80 GB on the GitHub runner 30 | - name: Free disk space 31 | uses: jlumbroso/free-disk-space@main 32 | with: 33 | tool-cache: true 34 | android: true 35 | dotnet: true 36 | haskell: true 37 | large-packages: true 38 | docker-images: true 39 | swap-storage: true 40 | 41 | # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. 42 | - name: Log in to the Container registry 43 | uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 44 | with: 45 | registry: ${{ env.REGISTRY }} 46 | username: ${{ github.actor }} 47 | password: ${{ secrets.GITHUB_TOKEN }} 48 | # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. 49 | - name: Extract metadata (tags, labels) for Docker 50 | id: meta 51 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 52 | with: 53 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 54 | # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. 55 | # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see [Usage](https://github.com/docker/build-push-action#usage) in the README of the `docker/build-push-action` repository. 56 | # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. 57 | - name: Build and push Docker image 58 | id: push 59 | uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 60 | with: 61 | context: . 62 | push: true 63 | tags: ${{ steps.meta.outputs.tags }} 64 | labels: ${{ steps.meta.outputs.labels }} 65 | 66 | # This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see [Using artifact attestations to establish provenance for builds](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds). 
67 | - name: Generate artifact attestation 68 | uses: actions/attest-build-provenance@v2 69 | with: 70 | subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} 71 | subject-digest: ${{ steps.push.outputs.digest }} 72 | push-to-registry: true 73 | 74 | -------------------------------------------------------------------------------- /welt/vision/batch_image_encoder.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import torch 4 | from transformers import ViTForImageClassification, ViTModel 5 | from transformers.image_transforms import group_images_by_shape, reorder_images 6 | 7 | from welt.collator import stack_pad_tensors 8 | from welt.vision.navit import NaViTModel 9 | from welt.vision.vision_utils import encode_images as utils_encode_images 10 | 11 | # Define a type alias for image encoders, for type inference and clarity 12 | # However, every image encoder should be supported 13 | ImageEncoder = ViTModel | ViTForImageClassification | NaViTModel 14 | 15 | 16 | def encode_padded_images(image_encoder: ImageEncoder, 17 | input_images: torch.Tensor) -> torch.Tensor: 18 | """Image encoder should accept variable size images and return consistent embeddings.""" 19 | B, L, *_ = input_images.shape # noqa: N806 20 | 21 | linear_images = input_images.view(B * L, *input_images.shape[2:]) 22 | embeds = encode_images_batch(image_encoder=image_encoder, images=linear_images) 23 | embeds = embeds.view(B, L, -1) 24 | return embeds 25 | 26 | 27 | def encode_images(image_encoder: ImageEncoder, 28 | input_images: torch.Tensor, 29 | input_images_dimensions: torch.Tensor) -> torch.Tensor: 30 | """Image encoder should accept variable size images and return consistent embeddings.""" 31 | B, L, *_ = input_images.shape # noqa: N806 32 | 33 | # Recreate as list of lists if input is a nested tensor, cropping padding from each image 34 | nested_images = [ 35 | [img[:, :h, :w] for img, (h, w) in zip(images, dims, strict=False) if h > 0 and w > 0] 36 | for images, dims in zip(input_images, input_images_dimensions, strict=False) 37 | ] 38 | 39 | # Flatten images 40 | all_images = list(chain.from_iterable(nested_images)) 41 | 42 | # Re-arrange the embeddings to match the original batch structure 43 | embeddings = encode_images_group(image_encoder=image_encoder, images=all_images) 44 | 45 | # Restructure the flat embeddings back into the nested format 46 | embeds = embeddings.new_zeros((B, L, embeddings.size(-1))) 47 | 48 | lengths = torch.tensor([len(inner) for inner in nested_images], device=embeddings.device) 49 | row_idx = torch.repeat_interleave(torch.arange(B, device=embeddings.device), lengths) 50 | col_idx = torch.cat([torch.arange(n, device=embeddings.device) for n in lengths]) 51 | 52 | # One fused write instead of a Python loop 53 | embeds.index_put_((row_idx, col_idx), embeddings, accumulate=False) 54 | 55 | return embeds 56 | 57 | 58 | def encode_images_batch(image_encoder: ImageEncoder, 59 | images: list[torch.Tensor] | torch.Tensor) -> torch.Tensor: 60 | if isinstance(images, list): 61 | images = stack_pad_tensors(images) 62 | 63 | # Encode images using the image encoder 64 | return utils_encode_images(image_encoder, images) 65 | 66 | 67 | def encode_images_sequentially(image_encoder: ImageEncoder, 68 | images: list[torch.Tensor]) -> torch.Tensor: 69 | encoded_images = [encode_images_batch(image_encoder, image.unsqueeze(0)) for image in images] 70 | return torch.cat(encoded_images, dim=0) 71 | 72 | 73 | def 
encode_images_group(image_encoder: ImageEncoder, 74 | images: list[torch.Tensor]) -> torch.Tensor: 75 | if isinstance(image_encoder, NaViTModel): 76 | # NaViT can handle variable size images natively! 77 | return utils_encode_images(image_encoder, images) 78 | 79 | grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=False) 80 | 81 | # Encode each group separately 82 | encoded_groups = {size: encode_images_batch(image_encoder=image_encoder, images=group) 83 | for size, group in grouped_images.items()} 84 | 85 | # Re-arrange the encoded images to match the original order 86 | rearranged_images = reorder_images(encoded_groups, grouped_images_index) 87 | 88 | return torch.stack(rearranged_images) 89 | -------------------------------------------------------------------------------- /welt/attention.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from utf8_tokenizer.control import ControlTokens 5 | 6 | # Module-level caches that grow as needed 7 | _tril_cache: torch.Tensor | None = None 8 | _arange_cache: torch.Tensor | None = None 9 | 10 | 11 | def _get_tril(size: int) -> torch.Tensor: 12 | """Get a lower triangular matrix of at least the given size, using cached version if possible.""" 13 | global _tril_cache 14 | if _tril_cache is None or len(_tril_cache) < size: 15 | _tril_cache = torch.tril(torch.ones((size, size), dtype=torch.bool)) 16 | return _tril_cache 17 | 18 | 19 | def _get_arange(size: int) -> torch.Tensor: 20 | """Get an arange tensor of at least the given size, using cached version if possible.""" 21 | global _arange_cache 22 | if _arange_cache is None or len(_arange_cache) < size: 23 | _arange_cache = torch.arange(size, dtype=torch.long) 24 | return _arange_cache 25 | 26 | 27 | def get_shift_blocks(words: list[str]): 28 | """ 29 | Find shift blocks in a sequence of words. 30 | 31 | Yields tuples (start, end) where start is the index of ShiftOut 32 | and end is the index of ShiftIn (inclusive). Handles warnings for invalid blocks. 33 | 34 | Args: 35 | words: List of word strings 36 | 37 | Yields: 38 | Tuples of (start_idx, end_idx) for each valid shift block 39 | """ 40 | shift_out_idx = None 41 | 42 | for i, word in enumerate(words): 43 | if word == ControlTokens.ShiftOut: 44 | if shift_out_idx is not None: 45 | warnings.warn( 46 | "Shift Out (SO) detected after another Shift Out (SO) without Shift In (SI). " 47 | "Nested shift blocks are not allowed.", 48 | stacklevel=2) 49 | shift_out_idx = i 50 | if word == ControlTokens.ShiftIn: 51 | if shift_out_idx is None: 52 | warnings.warn( 53 | "Shift In (SI) detected, without first seeing Shift Out (SO). " 54 | "Skipping self-attention block.", 55 | stacklevel=2) 56 | else: 57 | yield shift_out_idx, i 58 | shift_out_idx = None 59 | 60 | if shift_out_idx is not None: 61 | warnings.warn( 62 | "Unclosed Shift Out (SO) block detected at end of sequence. " 63 | "Missing corresponding Shift In (SI).", 64 | stacklevel=2) 65 | 66 | 67 | def add_self_attention_blocks(mask: torch.Tensor, words: list[str]) -> None: 68 | # Attention blocks (PrefixLM / MAS) are surrounded by and tokens (`\xOE` ... `\x0F`). 69 | for start, end in get_shift_blocks(words): 70 | mask[0, start:end + 1, start:end + 1] = 1 71 | 72 | 73 | def get_attention_mask_for_packed_sequence(seq_lengths: list[int], words: list[str] = None) -> torch.Tensor: 74 | """ 75 | Returns a 3D attention mask for a packed sequence. 
(1, seq_len, seq_len) 76 | The first dimension represents the head dimension, which is set to 1 for broadcasting. 77 | """ 78 | total_length = sum(seq_lengths) 79 | 80 | mask = torch.zeros((1, total_length, total_length), dtype=torch.bool) 81 | 82 | # Use module-level cached tril matrix 83 | max_len = max(seq_lengths) 84 | tril = _get_tril(max_len) 85 | 86 | start_position = 0 87 | for length in seq_lengths: 88 | end_position = start_position + length 89 | mask[0, start_position:end_position, start_position:end_position] = tril[:length, :length] 90 | start_position = end_position 91 | 92 | if words is not None: 93 | add_self_attention_blocks(mask, words) 94 | 95 | return mask 96 | 97 | 98 | def get_position_ids_for_packed_sequence(seq_lengths: list[int]) -> torch.Tensor: 99 | # Use module-level cached arange and slice 100 | max_len = max(seq_lengths) 101 | arange = _get_arange(max_len) 102 | return torch.cat([arange[:length] for length in seq_lengths]) 103 | -------------------------------------------------------------------------------- /MOTIVATION.md: -------------------------------------------------------------------------------- 1 | # Why Text-as-Images Changes Everything 2 | 3 | **What if the root cause of most LLM failures is how we tokenize text?** 4 | 5 | Visually encoding text mimics how humans read, and creates equivalence 6 | between what the human sees and how the computer processes the text. 7 | 8 | Pre-tokenization into meaningful units (e.g. words) allows the model to encode information across languages 9 | more equally, and reduces the impact of tokenization artifacts. 10 | 11 | ## Tokenization 12 | 13 | In his [lecture](https://www.youtube.com/watch?v=zduSFxRajkE), Andrej Karpathy discusses weird behaviors in models 14 | that trace back to tokenization. 15 | 16 | - Why can't LLM spell words? **Tokenization**. 17 | - Why can't LLM do super simple string processing tasks like reversing a string? **Tokenization**. 18 | - Why is LLM worse at non-English languages (e.g. Japanese)? **Tokenization**. 19 | - Why is LLM bad at simple arithmetic? **Tokenization**. 20 | - Why did GPT-2 have more than necessary trouble coding in Python? **Tokenization**. 21 | - Why did my LLM abruptly halt when it sees the string ""? **Tokenization**. 22 | - What is this weird warning I get about a "trailing whitespace"? **Tokenization**. 23 | - Why the LLM break if I ask it about "SolidGoldMagikarp"? **Tokenization**. 24 | - Why should I prefer to use YAML over JSON with LLMs? **Tokenization**. 25 | - Why is LLM not actually end-to-end language modeling? **Tokenization**. 26 | - What is the real root of suffering? **Tokenization**. 27 | 28 | What if we encoded text as images of pre-tokenized words (alongside bytes)? 29 | 30 | - ✅ LLMs should be able to spell words, they see the characters. 31 | - ✅ LLMs should be able to do string processing tasks, they see the characters. 32 | - ✅ LLMs should be equally good at all languages, with equitable pre-tokenization. 33 | - ❌ Unclear if LLMs will be better at arithmetic, but they should be able to see the numbers. 34 | - ❌ Unclear if LLMs will be better at coding. 35 | - ✅ LLMs would not abruptly halt on special tokens, as the special tokens are different images. 36 | - ✅ No more warnings about trailing whitespace, as whitespace is part of the token. 37 | - ✅ LLMs should not break on weird words, as the tokenizer is not trained separately. 38 | - ✅ JSON and YAML should be equally easy, as the quote-marks are part of the token. 
39 | - ☑️ LLMs should be more end-to-end language modeling *except for pre-tokenization*.
40 | - ❓ The real root of suffering is still unclear.
41 | 
42 | ## Robust Open Vocabulary Translation
43 | 
44 | [Salesky et al. (2021)](https://arxiv.org/pdf/2104.08211) claim that:
45 | 
46 | > Machine translation models have discrete vocabularies and commonly use subword segmentation techniques
47 | > to achieve an ‘open vocabulary.’ This approach relies on consistent and correct underlying unicode sequences,
48 | > and makes models susceptible to degradation from common types of noise and variation.
49 | 
50 | Examples of common behaviors that cause divergent representations for subword models include:
51 | 
52 | ### Diacritics
53 | 
54 | For Latin scripts, such as German, we may use diacritics such as umlauts (e.g. `ä`, `ö`, `ü`).
55 | We can write them down either as a single character (e.g. Unicode Normalization Form C `ü` = [2448]),
56 | or as a combination of two characters (e.g. Unicode Normalization Form D `u` + `¨` = [84, 136, 230]).
57 | 
58 | The paper gives the example of Arabic `كتاب` (with 3 tokens in GPT-4) which, fully vowelized, is `كِتَابٌ` (7 tokens).
59 | Another example is Hebrew `ספר` (with 5 tokens in GPT-4) which, diacritized, is `סֵפֶר` (9 tokens).
60 | 
61 | ### Misspelling
62 | 
63 | The paper gives an example of the words `language` and `langauge`, which are tokenized as 1 and 2 tokens
64 | respectively in GPT-4, giving very different representations to what is likely intended as the same meaning.
65 | The problems may only increase in non-Latin scripts.
66 | 
67 | ### Visually Similar / Identical Characters
68 | 
69 | People often obfuscate text using visually similar or identical characters,
70 | ranging from homograph attacks --- using characters from different scripts that look the same ---
71 | to LeetSpeak --- using characters that look similar.
72 | 
73 | For example, the Latin character `a` (U+0061) looks very similar to the Cyrillic character `а` (U+0430).
74 | Given the word `man` (1 token in GPT-4), if we replace the `a` with the Cyrillic `а`, we get `mаn` (3 tokens).
75 | 
76 | The paper gives the LeetSpeak example for `really` vs `rea11y` (1 token vs 3 tokens in GPT-4).
77 | -------------------------------------------------------------------------------- /training/extendable_yaml.py: --------------------------------------------------------------------------------
1 | """
2 | YAML file extension utility that supports $extends directive for configuration inheritance.
3 | 
4 | This module allows YAML configuration files to extend from other YAML files using
5 | the $extends directive, enabling configuration reuse and minimal config specifications.
6 | 
7 | Example:
8 |     base.yaml:
9 |         model: gpt-3
10 |         temperature: 0.7
11 |         max_tokens: 100
12 | 
13 |     extended.yaml:
14 |         $extends: ./base.yaml
15 |         temperature: 0.9  # Override temperature
16 |         # model and max_tokens are inherited
17 | """
18 | 
19 | import os
20 | import tempfile
21 | from pathlib import Path
22 | from typing import Any
23 | 
24 | import yaml
25 | 
26 | 
27 | def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
28 |     """
29 |     Deep merge two dictionaries, with override values taking precedence.
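    Example (illustrative):
        >>> deep_merge({"model": "a", "opts": {"lr": 0.1}}, {"opts": {"bs": 2}})
        {'model': 'a', 'opts': {'lr': 0.1, 'bs': 2}}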
30 | 31 | Args: 32 | base: The base dictionary 33 | override: The dictionary with override values 34 | 35 | Returns: 36 | A new merged dictionary 37 | """ 38 | result = base.copy() 39 | 40 | for key, value in override.items(): 41 | if key in result and isinstance(result[key], dict) and isinstance(value, dict): 42 | # Recursively merge nested dictionaries 43 | result[key] = deep_merge(result[key], value) 44 | else: 45 | # Override the value 46 | result[key] = value 47 | 48 | return result 49 | 50 | 51 | def load_yaml_with_extends(yaml_path: str | Path) -> dict[str, Any]: 52 | """ 53 | Load a YAML file and recursively resolve $extends directives. 54 | 55 | Args: 56 | yaml_path: Path to the YAML file 57 | 58 | Returns: 59 | The merged configuration dictionary 60 | 61 | Raises: 62 | FileNotFoundError: If the YAML file or any extended file is not found 63 | ValueError: If a circular dependency is detected 64 | """ 65 | yaml_path = Path(yaml_path).resolve() 66 | 67 | def _load_recursive(path: Path, visited: set[Path]) -> dict[str, Any]: 68 | """Recursively load and merge YAML files.""" 69 | if path in visited: 70 | raise ValueError(f"Circular dependency detected: {path}") 71 | 72 | if not path.exists(): 73 | raise FileNotFoundError(f"YAML file not found: {path}") 74 | 75 | visited.add(path) 76 | 77 | with open(path) as f: 78 | config = yaml.safe_load(f) or {} 79 | 80 | # Check if this file extends another 81 | if "$extends" in config: 82 | extends_path = config.pop("$extends") 83 | 84 | # Resolve the parent path relative to the current file 85 | if not os.path.isabs(extends_path): 86 | extends_path = (path.parent / extends_path).resolve() 87 | else: 88 | extends_path = Path(extends_path).resolve() 89 | 90 | # Load the parent configuration 91 | parent_config = _load_recursive(extends_path, visited.copy()) 92 | 93 | # Merge: parent as base, current config overrides 94 | config = deep_merge(parent_config, config) 95 | 96 | return config 97 | 98 | return _load_recursive(yaml_path, set()) 99 | 100 | 101 | def resolve_yaml_file(yaml_path: str | Path) -> str: 102 | """ 103 | Resolve a YAML file that may contain $extends directives. 104 | 105 | If the YAML file contains $extends, this function creates a temporary 106 | merged YAML file and returns its path. Otherwise, returns the original path. 107 | 108 | Args: 109 | yaml_path: Path to the YAML file 110 | 111 | Returns: 112 | Path to the resolved YAML file (either original or temporary merged file) 113 | """ 114 | yaml_path = Path(yaml_path) 115 | 116 | # Quick check: does the file contain $extends? 
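    # Note: only the first line is inspected, so the `$extends` directive must appear
    # on the very first line of the file for the merge to be applied.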
117 | with open(yaml_path) as f: 118 | first_line = f.readline().strip() 119 | if not first_line.startswith("$extends:"): 120 | # No extension, return original path 121 | return str(yaml_path.resolve()) 122 | 123 | # Load and merge configurations 124 | merged_config = load_yaml_with_extends(yaml_path) 125 | 126 | # Create a temporary YAML file with the merged configuration 127 | temp_file = tempfile.NamedTemporaryFile( 128 | mode='w', 129 | suffix='.yaml', 130 | prefix='merged_config_', 131 | delete=False 132 | ) 133 | 134 | try: 135 | yaml.dump(merged_config, temp_file, default_flow_style=False, sort_keys=False) 136 | temp_file.flush() 137 | return temp_file.name 138 | finally: 139 | temp_file.close() 140 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | .idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | 209 | 210 | .DS_Store 211 | .claude/ 212 | build 213 | output/ 214 | wandb/ -------------------------------------------------------------------------------- /training/args_data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from transformers.utils.versions import require_version 4 | 5 | 6 | @dataclass 7 | class DataTrainingArguments: 8 | """ 9 | Arguments pertaining to what data we are going to input our model for training and eval. 10 | """ 11 | 12 | dataset_name: str | None = field( 13 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 14 | ) 15 | dataset_config_name: str | None = field( 16 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 17 | ) 18 | dataset_text_template: str | None | list[str] = field( 19 | default=None, 20 | metadata={ 21 | "help": ( 22 | "Template to format dataset text using Python format strings with dataset column names. " 23 | "Single string: concatenated text for training/eval (e.g., 'Translate: {source} to {target}'). " 24 | "List of 2 strings: [prefix, completion] for generation-based evaluation. " 25 | " - During training: prefix + completion are concatenated. " 26 | " - During evaluation/testing: prefix used for generation, completion used as reference. " 27 | "Example: ['<{sign_language}> {sign_text} <{spoken_language}> ', '{spoken_text}']" 28 | ) 29 | }, 30 | ) 31 | train_file: str | None = field(default=None, metadata={"help": "The input training data file (a text file)."}) 32 | validation_file: str | None = field( 33 | default=None, 34 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 35 | ) 36 | max_train_samples: int | None = field( 37 | default=None, 38 | metadata={ 39 | "help": ( 40 | "For debugging purposes or quicker training, truncate the number of training examples to this " 41 | "value if set." 42 | ) 43 | }, 44 | ) 45 | max_eval_samples: int | None = field( 46 | default=None, 47 | metadata={ 48 | "help": ( 49 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 50 | "value if set." 51 | ) 52 | }, 53 | ) 54 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 55 | block_size: int | None = field( 56 | default=None, 57 | metadata={ 58 | "help": ( 59 | "Optional input sequence length after tokenization. " 60 | "The training dataset will be truncated in block of this size for training. " 61 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 62 | ) 63 | }, 64 | ) 65 | max_sequence_length: int | None = field( 66 | default=None, 67 | metadata={ 68 | "help": ( 69 | "Maximum sequence length for the model. " 70 | "Sequences will be truncated to this length if they are longer." 71 | ) 72 | }, 73 | ) 74 | max_word_length: int | None = field( 75 | default=128, 76 | metadata={ 77 | "help": ( 78 | "Maximum word length for the model. " 79 | "Words will be truncated to this length if they are longer." 
80 | ) 81 | }, 82 | ) 83 | overwrite_cache: bool = field( 84 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 85 | ) 86 | validation_split_percentage: int | None = field( 87 | default=5, 88 | metadata={ 89 | "help": "The percentage of the train set used as validation set in case there's no validation split" 90 | }, 91 | ) 92 | preprocessing_num_workers: int | None = field( 93 | default=None, 94 | metadata={"help": "The number of processes to use for the preprocessing."}, 95 | ) 96 | keep_linebreaks: bool = field( 97 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 98 | ) 99 | 100 | def __post_init__(self): 101 | if self.streaming: 102 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 103 | 104 | # Validate dataset_text_template 105 | if self.dataset_text_template is not None: 106 | if isinstance(self.dataset_text_template, list): 107 | if len(self.dataset_text_template) != 2: 108 | msg = ( 109 | f"dataset_text_template must be either a string or a list of size 2. " 110 | f"Got a list of size {len(self.dataset_text_template)}." 111 | ) 112 | raise ValueError(msg) # noqa: TRY003 113 | elif not isinstance(self.dataset_text_template, str): 114 | msg = ( 115 | f"dataset_text_template must be either a string or a list of size 2. " 116 | f"Got {type(self.dataset_text_template).__name__}." 117 | ) 118 | raise ValueError(msg) # noqa: TRY003 119 | 120 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 121 | raise ValueError("Need either a dataset name or a training/validation file.") # noqa: TRY003 122 | else: 123 | if self.train_file is not None: 124 | extension = self.train_file.split(".")[-1] 125 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 126 | if self.validation_file is not None: 127 | extension = self.validation_file.split(".")[-1] 128 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." 
129 | -------------------------------------------------------------------------------- /welt/model_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | from functools import partial 3 | 4 | import torch 5 | from font_download import FontConfig 6 | from font_download.example_fonts.noto_sans import FONTS_NOTO_SANS 7 | from pixel_renderer import PixelRendererProcessor 8 | from transformers import ( 9 | AutoConfig, 10 | AutoImageProcessor, 11 | AutoTokenizer, 12 | PretrainedConfig, 13 | enable_full_determinism, 14 | set_seed, 15 | ) 16 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 17 | from words_segmentation.tokenizer import WordsSegmentationTokenizer 18 | 19 | from welt.collator import collate_fn 20 | from welt.config import WordLatentTransformerConfig 21 | from welt.model import WordLatentTransformerForCausalLM, logger 22 | from welt.noop import NoopConfig, NoopImageProcessor 23 | from welt.processor import TextImageProcessor 24 | from welt.vision.navit import NaViTConfig 25 | 26 | 27 | def print_model_summary(name: str, model): 28 | """Print a summary of the model's architecture.""" 29 | if model is None: 30 | print(name, "is None") 31 | return 32 | total_params = sum(p.numel() for p in model.parameters()) 33 | print(name, f"Total parameters: {total_params:,}") 34 | 35 | 36 | def get_attn_implementation(): 37 | if importlib.util.find_spec("flash_attn") is None: 38 | logger.warning("Flash Attention not available, using default attention") 39 | return None 40 | return "flash_attention_2" 41 | 42 | 43 | CUSTOM_MODELS: dict[str, PretrainedConfig] = { 44 | "NaViT-tiny": NaViTConfig( 45 | patch_size=16, 46 | hidden_size=128, 47 | dim=64, 48 | depth=3, 49 | heads=4, 50 | mlp_dim=128, 51 | dropout=0.0, 52 | emb_dropout=0.0, 53 | token_dropout_prob=0.1, 54 | ), 55 | "NaViT-small": NaViTConfig( 56 | patch_size=16, 57 | hidden_size=512, 58 | dim=256, 59 | depth=6, 60 | heads=8, 61 | mlp_dim=1024, 62 | dropout=0.0, 63 | emb_dropout=0.0, 64 | token_dropout_prob=0.1, 65 | ) 66 | } 67 | CUSTOM_PROCESSORS_ALIAS: dict[str, str] = { 68 | "NaViT-tiny": "WinKawaks/vit-tiny-patch16-224", 69 | "NaViT-small": "WinKawaks/vit-tiny-patch16-224", 70 | } 71 | 72 | 73 | def get_model_config(model_name): 74 | if model_name is None: 75 | return NoopConfig() 76 | 77 | if model_name in CUSTOM_MODELS: 78 | return CUSTOM_MODELS[model_name] 79 | return AutoConfig.from_pretrained(model_name) 80 | 81 | 82 | def setup_model( 83 | image_encoder_name="WinKawaks/vit-tiny-patch16-224", 84 | bytes_encoder_name="prajjwal1/bert-tiny", 85 | latent_transformer_name="EleutherAI/pythia-70m", 86 | bytes_decoder_name="sign/utf8-lm-tiny", 87 | pretokenizer_name: str | None = None, 88 | trust_remote_code=False, 89 | modality_dropout=0.15, 90 | dtype=torch.float32, 91 | seed=42, 92 | load_pretrained=True, 93 | max_word_length=None, 94 | ): 95 | set_seed(seed, deterministic=True) 96 | enable_full_determinism(seed=seed, warn_only=True) 97 | 98 | if image_encoder_name is not None: 99 | image_processor_name = CUSTOM_PROCESSORS_ALIAS.get(image_encoder_name, image_encoder_name) 100 | image_processor = AutoImageProcessor.from_pretrained(image_processor_name, use_fast=True) 101 | else: 102 | image_processor = NoopImageProcessor() 103 | 104 | tokenizer = UTF8Tokenizer() 105 | 106 | config = WordLatentTransformerConfig( 107 | # All sub-configs are loaded from the respective model names 108 | image_encoder=get_model_config(image_encoder_name), 109 | 
bytes_encoder=get_model_config(bytes_encoder_name), 110 | latent_transformer=get_model_config(latent_transformer_name), 111 | bytes_decoder=get_model_config(bytes_decoder_name), 112 | # Other configuration parameters 113 | modality_dropout=modality_dropout, 114 | tokenizer_class=tokenizer.__class__.__name__, 115 | num_tokens=len(tokenizer), 116 | bos_token_id=tokenizer.bos_token_id, 117 | pad_token_id=tokenizer.pad_token_id, 118 | eos_token_id=tokenizer.eos_token_id, 119 | sep_token_id=tokenizer.sep_token_id, 120 | trust_remote_code=trust_remote_code, 121 | dtype=dtype, 122 | ) 123 | 124 | # Combine the models 125 | model = WordLatentTransformerForCausalLM(config, 126 | load_pretrained=load_pretrained, 127 | attn_implementation=get_attn_implementation()) 128 | print_model_summary("Image Encoder", model.image_encoder) 129 | print_model_summary("Bytes Encoder", model.bytes_encoder) 130 | print_model_summary("Latent Transformer", model.latent_transformer) 131 | print_model_summary("Bytes Decoder", model.bytes_decoder) 132 | print_model_summary("Final Model", model) 133 | 134 | max_seq_length = getattr(model.latent_transformer.config, "max_position_embeddings", 1024) 135 | if max_word_length is None: 136 | max_word_length = getattr(model.bytes_decoder.config, "max_position_embeddings", 128) 137 | 138 | max_bytes = max_word_length - 2 # Reserve space for BOS and EOS tokens 139 | if pretokenizer_name is not None: 140 | print(f"Using pretokenizer: {pretokenizer_name}") 141 | pretokenizer = AutoTokenizer.from_pretrained(pretokenizer_name, 142 | use_fast=True, 143 | trust_remote_code=trust_remote_code) 144 | else: 145 | print("Using pretokenizer: WordsSegmentationTokenizer") 146 | pretokenizer = WordsSegmentationTokenizer(max_bytes=max_bytes) 147 | 148 | font_config = FontConfig(sources=FONTS_NOTO_SANS) 149 | renderer = PixelRendererProcessor(font=font_config) 150 | 151 | processor = TextImageProcessor( 152 | pretokenizer=pretokenizer, 153 | tokenizer=tokenizer, 154 | renderer=renderer, 155 | image_processor=image_processor, 156 | max_seq_length=max_seq_length, 157 | max_word_length=max_word_length, 158 | ) 159 | 160 | collator = partial(collate_fn, pad_value=tokenizer.pad_token_type_id) 161 | 162 | return model, processor, collator 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌎 WeLT: Word Embedding Latent Transformer 2 | 3 | ![Python](https://img.shields.io/badge/python-3.12-blue) 4 | [![License](https://img.shields.io/badge/license-MIT-green)](./LICENSE) 5 | 6 | [See Motivational Examples](./MOTIVATION.md) 7 | 8 | We present a language model that replaces conventional subword tokenization with a dual representation of text as 9 | both bytes and rendered images of "words", allowing the model to directly process the visual and symbolic 10 | structure of language. It segments text into whitespace-delimited units, encodes each as images and byte sequences, 11 | and passes them through an encoder–decoder pipeline: 12 | a bytes encoder, an image encoder, a large latent transformer, and a bytes decoder. 13 | At inference, predicted bytes are rendered back into images, closing the loop for the next prediction. 
14 | This design could make learning and inference cheaper and faster for non-English languages, 15 | since the heavy latent transformer only predicts high-level token representations while the actual byte 16 | sequences are generated by a much smaller decoder. 17 | 18 | ![Model Architecture](./assets/architecture.png) 19 | 20 | ## Quick Start 21 | 22 | Clone and setup: 23 | 24 | ```shell 25 | git clone https://github.com/sign/WeLT.git 26 | cd WeLT 27 | ``` 28 | 29 | Install dependencies: 30 | 31 | ```shell 32 | conda create -n welt python=3.12 -y 33 | conda activate welt 34 | pip install ".[dev]" 35 | ``` 36 | 37 | Or using docker: 38 | 39 | ```shell 40 | docker build -t welt . 41 | 42 | # Run an interactive shell inside the container 43 | docker run -it --rm --gpus all \ 44 | --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ 45 | -v "$(pwd)/welt:/app/welt" \ 46 | -v "$(pwd)/training:/app/training" \ 47 | welt /bin/bash 48 | 49 | # Run a training job 50 | docker run -it --rm --gpus all \ 51 | --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ 52 | -v "$(pwd)/welt:/app/welt" \ 53 | -v "$(pwd)/training:/app/training" \ 54 | -v /shared/.cache/huggingface:/root/.cache/huggingface \ 55 | -v ~/.netrc:/root/.netrc:ro \ 56 | -e WANDB_PROJECT="ocr" \ 57 | welt python -m training.train training/experiments/easy-tasks/ocr.yaml 58 | ``` 59 | 60 | > [!TIP] 61 | > Run tests using `pytest` to ensure everything is working correctly. 62 | 63 | ## Model Setup 64 | 65 | - **Bytes Encoder** - You can use any language model as the bytes encoder (causal or masked). 66 | - **Image Encoder** - You can use any image encoder. 67 | - **Latent Transformer** - You can use any causal LM (recommended: large). 68 | - **Bytes Decoder** - You can use any causal LM (recommended: small). 69 | 70 | For language models, the parameter count is lower than reported, due to removing the embedding layers. 71 | 72 | Our implementation allows for any mix-and-match. Some example setups are: 73 | 74 | | Name | Bytes Encoder | Image Encoder | Latent Transformer | Bytes Decoder | Total Parameters | 75 | |--------|------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------|------------------| 76 | | tiny | [bert-tiny](https://huggingface.co/prajjwal1/bert-tiny) (0.5m) | [vit-tiny-patch16-224](https://huggingface.co/WinKawaks/vit-tiny-patch16-224) (5m) | [pythia-70m](https://huggingface.co/EleutherAI/pythia-70m) (19m) | [tiny-lm](sbintuitions/tiny-lm) (3m) | 28m | 77 | | small | [ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) (111m) | [swinv2-tiny-patch4-window16-256](https://huggingface.co/microsoft/swinv2-tiny-patch4-window16-256) (27m) | [gemma-3-270m](https://huggingface.co/google/gemma-3-270m) (100m) | [SmolLM2-135M](https://huggingface.co/HuggingFaceTB/SmolLM2-135M) (106m) | 346m | 78 | | medium | [deberta-v3-large](https://huggingface.co/microsoft/deberta-v3-large) (303m) | [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) (87m) | [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) (973m) | [gpt2-medium](https://huggingface.co/openai-community/gpt2-medium) (304m) | 1,674m | 79 | 80 | To turn off bytes encoding, set `bytes_encoder=False`, and similarly for images, set `image_encoder=False`. 
81 | You can also turn off a specific encoder after training has completed, for testing purposes. 82 | 83 | > [!WARNING] 84 | > In implementation of the bytes decoder, we concatenate the embeddings of the bytes of the current token with 85 | > all the embeddings of the previous tokens (on the word level). This is done since not all causal LMs support 86 | > cross-attention, and so we want to avoid using it, and rely on the self-attention mechanism instead. 87 | 88 | ## Training 89 | 90 | Training instructions are available in the [training/README.md](./training/README.md). 91 | There, you can select the model architectures you want to use for each component, and the dataset you want to train on. 92 | 93 | ## Inference 94 | 95 | Since we have two decoders, the autoregressive prediction logic is a bit more complex than the usual, 96 | and supporting decoding algorithms like beam-search is not trivial. 97 | 98 | Thus, on the latent-transformer level, 99 | [we only support greedy decoding](https://github.com/sign/WeLT/issues/5) for now. 100 | On the bytes decoder level, we support all classical decoding algorithms supported by HuggingFace Transformers. 101 | 102 | ## Contributing 103 | 104 | See [open issues](https://github.com/search?q=repo%3Asign%2FWeLT+%22%2Fissues%2F%22&type=code) 105 | and [TODOs](https://github.com/search?q=repo%3Asign%2FWeLT%20TODO&type=code) in the codebase. 106 | 107 | During the creation of this repository, we created several others to support it: 108 | - [`sign/words-segmentation`](https://github.com/sign/words-segmentation) as a universal word level pretokenizer. 109 | - [`sign/utf8-tokenizer`](https://github.com/sign/utf8-tokenizer) as a robust byte-level tokenizer. 110 | - [`sign/pixel-renderer`](https://github.com/sign/pixel-renderer) as a reproducible text-to-image renderer. 111 | 112 | > [!WARNING] 113 | > Training runs are experimental until core issues are resolved. 
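As a rough, illustrative sketch of the greedy decoding flow described in the Inference section (it reuses helper functions from the test suite, which may change):

```python
# Illustrative sketch, mirroring tests/test_model_generation.py; not a stable public API.
from tests.test_model import dataset_to_batch, make_dataset
from welt.model_utils import setup_model

model, processor, collator = setup_model()  # default tiny components
model.eval()

# Build a padded batch from raw texts, then greedily generate up to 5 words per input.
batch = dataset_to_batch(model, processor, collator, make_dataset(["a b"]))
outputs = model.generate(
    input_ids=batch["input_ids"],
    input_attention_mask=batch["input_attention_mask"],
    input_images=batch["input_images"],
    input_images_dimensions=batch["input_images_dimensions"],
    attention_mask=batch["attention_mask"],
    processor=processor,
    max_generated_words=5,
)
print(outputs)  # one generated string per input text
```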
114 | 115 | ## Cite 116 | 117 | If you use this code in your research, please consider citing the work: 118 | 119 | ```bibtex 120 | @misc{moryossef2025welt, 121 | title={{WeLT}: Word Embedding Latent Transformer for Equitable Modeling of the Languages of the World}, 122 | author={Moryossef, Amit}, 123 | howpublished={\url{https://github.com/sign/WeLT}}, 124 | year={2025} 125 | } 126 | ``` -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from utf8_tokenizer.control import ControlTokens 4 | 5 | from welt.attention import ( 6 | add_self_attention_blocks, 7 | get_attention_mask_for_packed_sequence, 8 | get_position_ids_for_packed_sequence, 9 | get_shift_blocks, 10 | ) 11 | 12 | 13 | def test_get_attention_mask_for_packed_sequence_single_sequence(): 14 | seq_lengths = [3] 15 | mask = get_attention_mask_for_packed_sequence(seq_lengths) 16 | 17 | expected = torch.tensor([[ 18 | [True, False, False], 19 | [True, True, False], 20 | [True, True, True] 21 | ]]) 22 | 23 | assert torch.equal(mask, expected) 24 | assert mask.shape == (1, 3, 3) 25 | 26 | 27 | def test_get_attention_mask_for_packed_sequence_two_sequences(): 28 | seq_lengths = [2, 2] 29 | mask = get_attention_mask_for_packed_sequence(seq_lengths) 30 | 31 | expected = torch.tensor([[ 32 | [True, False, False, False], 33 | [True, True, False, False], 34 | [False, False, True, False], 35 | [False, False, True, True] 36 | ]]) 37 | 38 | assert torch.equal(mask, expected) 39 | assert mask.shape == (1, 4, 4) 40 | 41 | 42 | def test_get_position_ids_for_packed_sequence_single_sequence(): 43 | seq_lengths = [3] 44 | position_ids = get_position_ids_for_packed_sequence(seq_lengths) 45 | 46 | expected = torch.tensor([0, 1, 2]) 47 | 48 | assert torch.equal(position_ids, expected) 49 | assert position_ids.shape == (3,) 50 | 51 | 52 | def test_get_position_ids_for_packed_sequence_two_sequences(): 53 | seq_lengths = [2, 2] 54 | position_ids = get_position_ids_for_packed_sequence(seq_lengths) 55 | 56 | expected = torch.tensor([0, 1, 0, 1]) 57 | 58 | assert torch.equal(position_ids, expected) 59 | assert position_ids.shape == (4,) 60 | 61 | 62 | def test_add_self_attention_blocks_basic_shift_block(): 63 | mask = torch.zeros((1, 5, 5), dtype=torch.bool) 64 | words = ["hello", ControlTokens.ShiftOut, "world", ControlTokens.ShiftIn, "end"] 65 | 66 | add_self_attention_blocks(mask, words) 67 | 68 | expected = torch.zeros((1, 5, 5), dtype=torch.bool) 69 | expected[0, 1:4, 1:4] = True 70 | 71 | assert torch.equal(mask, expected) 72 | 73 | 74 | def test_add_self_attention_blocks_multiple_shift_blocks(): 75 | mask = torch.zeros((1, 8, 8), dtype=torch.bool) 76 | words = [ 77 | "start", 78 | ControlTokens.ShiftOut, "first", "block", ControlTokens.ShiftIn, 79 | "middle", 80 | ControlTokens.ShiftOut, "second", ControlTokens.ShiftIn 81 | ] 82 | 83 | add_self_attention_blocks(mask, words) 84 | 85 | expected = torch.zeros((1, 8, 8), dtype=torch.bool) 86 | expected[0, 1:5, 1:5] = True 87 | expected[0, 6:8, 6:8] = True 88 | 89 | assert torch.equal(mask, expected) 90 | 91 | 92 | def test_add_self_attention_blocks_no_shift_tokens(): 93 | mask = torch.zeros((1, 3, 3), dtype=torch.bool) 94 | words = ["hello", "world", "test"] 95 | 96 | original_mask = mask.clone() 97 | add_self_attention_blocks(mask, words) 98 | 99 | assert torch.equal(mask, original_mask) 100 | 101 | 102 | def 
test_add_self_attention_blocks_only_shift_out(): 103 | mask = torch.zeros((1, 3, 3), dtype=torch.bool) 104 | words = ["hello", ControlTokens.ShiftOut, "world"] 105 | 106 | original_mask = mask.clone() 107 | with pytest.warns(UserWarning, match="Missing corresponding Shift In"): 108 | add_self_attention_blocks(mask, words) 109 | 110 | assert torch.equal(mask, original_mask) 111 | 112 | 113 | def test_add_self_attention_blocks_shift_in_without_out(): 114 | mask = torch.zeros((1, 3, 3), dtype=torch.bool) 115 | words = ["hello", ControlTokens.ShiftIn, "world"] 116 | 117 | original_mask = mask.clone() 118 | 119 | with pytest.warns(UserWarning, match="Skipping self-attention block."): 120 | add_self_attention_blocks(mask, words) 121 | 122 | assert torch.equal(mask, original_mask) 123 | 124 | 125 | def test_add_self_attention_blocks_single_token_block(): 126 | mask = torch.zeros((1, 3, 3), dtype=torch.bool) 127 | words = [ControlTokens.ShiftOut, ControlTokens.ShiftIn, "end"] 128 | 129 | add_self_attention_blocks(mask, words) 130 | 131 | expected = torch.zeros((1, 3, 3), dtype=torch.bool) 132 | expected[0, 0:2, 0:2] = True 133 | 134 | assert torch.equal(mask, expected) 135 | 136 | 137 | def test_add_self_attention_blocks_nested_shift_out(): 138 | mask = torch.zeros((1, 6, 6), dtype=torch.bool) 139 | words = [ 140 | "start", 141 | ControlTokens.ShiftOut, "outer", 142 | ControlTokens.ShiftOut, "inner", 143 | ControlTokens.ShiftIn 144 | ] 145 | 146 | with pytest.warns(UserWarning, match="Nested shift blocks are not allowed."): 147 | add_self_attention_blocks(mask, words) 148 | 149 | expected = torch.zeros((1, 6, 6), dtype=torch.bool) 150 | expected[0, 3:6, 3:6] = True 151 | 152 | assert torch.equal(mask, expected) 153 | 154 | 155 | def test_get_attention_mask_for_packed_sequence_with_shift_blocks(): 156 | seq_lengths = [7] 157 | words = [ 158 | "hello", 159 | ControlTokens.ShiftOut, "prefix", "block", ControlTokens.ShiftIn, 160 | "world", "end" 161 | ] 162 | 163 | mask = get_attention_mask_for_packed_sequence(seq_lengths, words) 164 | 165 | expected = torch.tensor([[ 166 | [True, False, False, False, False, False, False], 167 | [True, True, True, True, True, False, False], 168 | [True, True, True, True, True, False, False], 169 | [True, True, True, True, True, False, False], 170 | [True, True, True, True, True, False, False], 171 | [True, True, True, True, True, True, False], 172 | [True, True, True, True, True, True, True] 173 | ]]) 174 | 175 | assert torch.equal(mask, expected) 176 | 177 | 178 | def test_get_shift_blocks_single_block(): 179 | """Test get_shift_blocks returns correct indexes for a single shift block.""" 180 | words = ["hello", ControlTokens.ShiftOut, "world", "test", ControlTokens.ShiftIn, "end"] 181 | blocks = list(get_shift_blocks(words)) 182 | 183 | assert len(blocks) == 1 184 | assert blocks[0] == (1, 4) # ShiftOut at index 1, ShiftIn at index 4 185 | 186 | 187 | def test_get_shift_blocks_multiple_blocks(): 188 | """Test get_shift_blocks returns correct indexes for multiple shift blocks.""" 189 | words = [ 190 | "start", 191 | ControlTokens.ShiftOut, "first", ControlTokens.ShiftIn, 192 | "middle", 193 | ControlTokens.ShiftOut, "second", ControlTokens.ShiftIn, 194 | "end" 195 | ] 196 | blocks = list(get_shift_blocks(words)) 197 | 198 | assert len(blocks) == 2 199 | assert blocks[0] == (1, 3) # First block: ShiftOut at 1, ShiftIn at 3 200 | assert blocks[1] == (5, 7) # Second block: ShiftOut at 5, ShiftIn at 7 201 | 202 | 203 | def test_get_shift_blocks_no_blocks(): 204 | """Test 
get_shift_blocks returns empty when no shift blocks present.""" 205 | words = ["hello", "world", "test"] 206 | blocks = list(get_shift_blocks(words)) 207 | 208 | assert len(blocks) == 0 209 | 210 | 211 | if __name__ == "__main__": 212 | pytest.main([__file__, "-v"]) 213 | -------------------------------------------------------------------------------- /tests/test_model_overfitting.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | from transformers import Trainer, TrainingArguments 5 | from trl import pack_dataset 6 | 7 | from tests.test_model import make_dataset, predict_dataset, setup_tiny_model 8 | 9 | 10 | # TODO: this training is flaky due to https://github.com/huggingface/transformers/issues/40219 11 | def train_model(setup_function, 12 | num_epochs=10, 13 | train_texts=None, 14 | packing=False): 15 | model, processor, collator = setup_function() 16 | 17 | if train_texts is None: 18 | train_texts = ["a b", "b a", "a cat", "a dog"] 19 | 20 | train_dataset = make_dataset(train_texts) 21 | 22 | if packing: 23 | train_dataset = processor.pretokenize_dataset(train_dataset) 24 | train_dataset = pack_dataset(train_dataset, seq_length=7) 25 | 26 | train_dataset = train_dataset.with_transform(processor) 27 | 28 | # Setup training arguments with more epochs for overfitting 29 | training_args = TrainingArguments( 30 | output_dir=tempfile.mktemp(), 31 | num_train_epochs=num_epochs, 32 | per_device_train_batch_size=len(train_texts), 33 | logging_steps=1, 34 | logging_strategy="steps", 35 | save_strategy="no", # Disable saving to avoid shared tensor issues 36 | remove_unused_columns=False, 37 | dataloader_drop_last=False, 38 | warmup_steps=0, # No warmup for immediate learning 39 | weight_decay=0.0, # No regularization for overfitting 40 | learning_rate=1e-4, 41 | lr_scheduler_type="constant", # Keep learning rate constant 42 | ) 43 | 44 | # Initialize trainer 45 | trainer = Trainer( 46 | model=model, 47 | args=training_args, 48 | processing_class=processor, 49 | train_dataset=train_dataset, 50 | data_collator=collator, 51 | ) 52 | 53 | # Train the model 54 | print("Training model on texts:", train_texts) 55 | trainer.train() 56 | 57 | # Set to eval mode 58 | model.eval() 59 | 60 | return model, processor, collator 61 | 62 | 63 | @pytest.fixture(scope="module") 64 | def trained_models(): 65 | """Train the model once and reuse for all tests.""" 66 | num_epochs = 200 67 | return { 68 | "packed": train_model(setup_tiny_model, num_epochs=num_epochs, packing=True), 69 | "unpacked": train_model(setup_tiny_model, num_epochs=num_epochs, packing=False) 70 | } 71 | 72 | 73 | @pytest.fixture 74 | def model_configuration(request, trained_models): 75 | """Configure model based on the test parameter.""" 76 | model_type, config_name = request.param 77 | model, processor, collator = trained_models[model_type] 78 | 79 | # Store original encoders 80 | original_bytes_encoder = model.bytes_encoder 81 | original_image_encoder = model.image_encoder 82 | 83 | # Apply configuration 84 | if config_name == "full_model": 85 | print("\n[Configuration: Full model with all encoders]") 86 | # No changes needed 87 | elif config_name == "no_bytes_encoder": 88 | print("\n[Configuration: Model without bytes encoder]") 89 | model.bytes_encoder = None 90 | elif config_name == "no_image_encoder": 91 | print("\n[Configuration: Model without image encoder]") 92 | model.image_encoder = None 93 | 94 | # Yield the configured model 95 | yield model, processor, 
collator 96 | 97 | # Restore original encoders after test 98 | model.bytes_encoder = original_bytes_encoder 99 | model.image_encoder = original_image_encoder 100 | 101 | 102 | MODEL_CONFIGURATIONS = [ 103 | # Packed setups 104 | ("packed", "full_model"), 105 | ("packed", "no_bytes_encoder"), 106 | ("packed", "no_image_encoder"), 107 | 108 | # Unpacked setups 109 | ("unpacked", "full_model"), 110 | ("unpacked", "no_bytes_encoder"), 111 | ("unpacked", "no_image_encoder"), 112 | ] 113 | 114 | MODEL_IDs = [f"{config} / {model_type}" for config, model_type in MODEL_CONFIGURATIONS] 115 | parameterization = pytest.mark.parametrize("model_configuration", MODEL_CONFIGURATIONS, indirect=True, ids=MODEL_IDs) 116 | 117 | 118 | @parameterization 119 | def test_character_level_conditioning(model_configuration): 120 | """Test 1: Character-level conditioning (a b vs a a, b a vs b b)""" 121 | model, processor, collator = model_configuration 122 | 123 | print("\n=== Test 1: Character-level conditioning ===") 124 | 125 | test_texts_char = ["a b", "b a", "a a", "b b"] 126 | losses, predictions = predict_dataset(test_texts_char, model, processor, collator) 127 | 128 | # Check conditioning: trained sequences should have lower loss 129 | assert losses['a b'] < losses['a a'], \ 130 | f"'a b' should have lower loss than 'a a': {losses['a b']:.4f} vs {losses['a a']:.4f}" 131 | 132 | assert losses['b a'] < losses['b b'], \ 133 | f"'b a' should have lower loss than 'b b': {losses['b a']:.4f} vs {losses['b b']:.4f}" 134 | 135 | print("✅ Character-level conditioning test passed!") 136 | 137 | 138 | @parameterization 139 | def test_word_level_conditioning(model_configuration): 140 | """Test 2: Word-level conditioning (a cat vs a dat, a dog vs a cog)""" 141 | model, processor, collator = model_configuration 142 | 143 | print("\n=== Test 2: Word-level conditioning ===") 144 | 145 | test_texts_word = ["a cat", "a dog", "a dat", "a cog", "a bat", "a fog"] 146 | losses, predictions = predict_dataset(test_texts_word, model, processor, collator) 147 | 148 | # Check conditioning: trained sequences should have lower loss 149 | assert losses['a cat'] < losses['a dat'], \ 150 | f"'a cat' should have lower loss than 'a dat': {losses['a cat']:.4f} vs {losses['a dat']:.4f}" 151 | 152 | assert losses['a dog'] < losses['a cog'], \ 153 | f"'a dog' should have lower loss than 'a cog': {losses['a dog']:.4f} vs {losses['a cog']:.4f}" 154 | 155 | print("✅ Word-level conditioning test passed!") 156 | 157 | 158 | @parameterization 159 | def test_byte_level_conditioning(model_configuration): 160 | """Test 3: Byte-level conditioning within words""" 161 | model, processor, collator = model_configuration 162 | 163 | print("\n=== Test 3: Byte-level conditioning within words ===") 164 | 165 | # For "a cat" and "a dog", after seeing "a c" or "a d", the model should be confident about the rest 166 | # This tests that the byte decoder is properly conditioned on previous bytes 167 | 168 | # Create a special test to check conditional probabilities 169 | test_conditional = ["a cat", "a cog", "a dog", "a dat"] 170 | losses, predictions = predict_dataset(test_conditional, model, processor, collator) 171 | 172 | # After 'a c', 'cat' should be more likely than 'cog' 173 | assert losses['a cat'] < losses['a cog'], \ 174 | (f"'a cat' should have lower loss than 'a cog': " 175 | f"{losses['a cat']:.4f} vs {losses['a cog']:.4f}") 176 | 177 | # After 'a d', 'dog' should be more likely than 'dat' 178 | assert losses['a dog'] < losses['a dat'], \ 179 | (f"'a dog' 
should have lower loss than 'a dat': " 180 | f"{losses['a dog']:.4f} vs {losses['a dat']:.4f}") 181 | 182 | print("✅ Byte-level conditioning test passed!") 183 | 184 | 185 | if __name__ == "__main__": 186 | pytest.main([__file__, "-v"]) 187 | -------------------------------------------------------------------------------- /tests/test_model_generation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers import AutoModelForCausalLM 4 | 5 | from tests.test_model import dataset_to_batch, make_dataset, setup_tiny_model 6 | 7 | 8 | @pytest.fixture(scope="module") 9 | def generation_model_setup(): 10 | """Setup the generation model with trained weights.""" 11 | 12 | # Setup the base model 13 | model, processor, collator = setup_tiny_model() 14 | 15 | # Set to eval mode 16 | model.eval() 17 | 18 | return model, processor, collator 19 | 20 | 21 | def predict_texts(texts: list[str], model, processor, collator): 22 | """Helper function to predict texts using the generation model.""" 23 | 24 | print("-" * 30) 25 | dataset = make_dataset(texts) 26 | batch = dataset_to_batch(model, processor, collator, dataset) 27 | 28 | processor.max_word_length = 5 29 | 30 | with torch.no_grad(): 31 | outputs = model.generate( 32 | input_ids=batch["input_ids"], 33 | input_attention_mask=batch["input_attention_mask"], 34 | input_images=batch["input_images"], 35 | input_images_dimensions=batch["input_images_dimensions"], 36 | attention_mask=batch["attention_mask"], 37 | processor=processor, 38 | max_generated_words=5 39 | ) 40 | 41 | for text, output in zip(texts, outputs, strict=False): 42 | print(f"Generated for '{text}': {output}") 43 | return outputs 44 | 45 | 46 | def test_batch_interference(generation_model_setup): 47 | """Test that generation of a batch does not interfere between texts.""" 48 | model, processor, collator = generation_model_setup 49 | 50 | print("\n=== Testing batch interference ===") 51 | batches = [ 52 | ["a"], 53 | ["a", "two words", ""], 54 | ["a", "even three words"], 55 | ["a", "b", "a_long_word"], 56 | ["a", "a"] 57 | ] 58 | outputs = [predict_texts(batch, model, processor, collator) for batch in batches] 59 | 60 | single = outputs[0][0] # Single result for "a" 61 | print(f"Single result for 'a': '{outputs[0][0]}'") 62 | print(f"Batch 1 result for 'a': '{outputs[1][0]}'") 63 | print(f"Batch 2 result for 'a': '{outputs[2][0]}'") 64 | print(f"Batch 3 result for 'a': '{outputs[3][0]}'") 65 | print(f"Batch 4 result for 'a': '{outputs[4][0]}'") 66 | 67 | # All results for 'a' should be the same 68 | all_a_results = [outputs[0][0], outputs[1][0], outputs[2][0], outputs[3][0], outputs[4][0], outputs[4][1]] 69 | print(f"\nAll results for 'a': {all_a_results}") 70 | 71 | # Check that all occurrences of "a" generate the same output 72 | assert all(result == single for result in all_a_results), \ 73 | f"Not all 'a' generations are equal. Expected all to be '{single}', but got: {all_a_results}" 74 | 75 | # Check that different inputs produce different outputs (model responds to input) 76 | # From batch 3: ["a", "b", "a_long_word"] - "a" and "b" should differ 77 | result_a = outputs[3][0] 78 | result_b = outputs[3][1] 79 | assert result_a != result_b, \ 80 | f"Different inputs 'a' and 'b' produced identical outputs: '{result_a}'. Model may not be responding to input." 
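    # The invariant checked above, schematically (illustrative comment only, not executed;
    # `...` stands for the model/processor/collator arguments used throughout this test):
    #
    #   for batch in batches:                                   # any batch containing "a"
    #       assert predict_texts(batch, ...)[batch.index("a")] == predict_texts(["a"], ...)[0]
    #
    # i.e. the output for "a" must not depend on which other texts share its batch,
    # otherwise padding or cross-item attention is leaking between sequences.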
81 | 82 | print("✅ Generation test passed - no batch interference detected") 83 | 84 | 85 | def test_same_text_in_batch(generation_model_setup): 86 | """Test that same text in batch returns same output.""" 87 | model, processor, collator = generation_model_setup 88 | 89 | print("\n=== Testing same text in batch returns same output ===") 90 | 91 | # Test with 4 times "a" 92 | batch_a = predict_texts(["a", "a", "a", "a"], model, processor, collator) 93 | print(f"\nBatch with 4x 'a': {batch_a}") 94 | assert all( 95 | result == batch_a[0] for result in batch_a), f"Same text 'a' in batch produced different outputs: {batch_a}" 96 | print("✅ All 'a' inputs in batch produced same output") 97 | 98 | # Test with 4 times "b" 99 | batch_b = predict_texts(["b", "b", "b", "b"], model, processor, collator) 100 | print(f"\nBatch with 4x 'b': {batch_b}") 101 | assert all( 102 | result == batch_b[0] for result in batch_b), f"Same text 'b' in batch produced different outputs: {batch_b}" 103 | print("✅ All 'b' inputs in batch produced same output") 104 | 105 | # Check that different inputs produce different outputs (model responds to input) 106 | assert batch_a[0] != batch_b[0], \ 107 | (f"Different inputs 'a' and 'b' produced identical outputs: '{batch_a[0]}'. " 108 | f"Model may not be responding to input.") 109 | print("✅ 'a' and 'b' produce different outputs - model responds to input") 110 | 111 | print("✅ Test passed!") 112 | 113 | def test_batch_vs_individual_different_lengths(): 114 | """ 115 | BUG from: https://github.com/sign/WeLT/issues/49 116 | 117 | Test that batch generation matches individual generation for texts of different lengths. 118 | 119 | This test uses the "sign/WeLT-string-repetition" model from HuggingFace and tests 120 | three texts of different lengths: "A B C D" (4 words), "E F G" (3 words), "H I" (2 words). 121 | 122 | The test should fail if there's a bug where batch processing doesn't properly handle 123 | texts of different lengths. 
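    Schematically, the check implemented below boils down to the following (condensed
    from the code in this function; the generate() arguments are the ones used below):

        combined = dataset_to_batch(model, processor, collator, make_dataset(texts))
        batch_out = model.generate(**combined, processor=processor, max_generated_words=10)
        for text, out in zip(texts, batch_out, strict=False):
            single = dataset_to_batch(model, processor, collator, make_dataset([text]))
            assert out == model.generate(**single, processor=processor, max_generated_words=10)[0]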
124 | """ 125 | print("\n=== Testing batch vs individual generation for different lengths ===") 126 | 127 | # Load the trained model from HuggingFace 128 | print("Loading model from HuggingFace: sign/WeLT-string-repetition") 129 | model = AutoModelForCausalLM.from_pretrained("sign/WeLT-string-repetition", trust_remote_code=True) 130 | model.eval() 131 | 132 | # Get processor from setup_tiny_model with no image encoder 133 | print("Setting up processor") 134 | _, processor, collator = setup_tiny_model(image_encoder_name=None) 135 | 136 | # Test texts of different lengths using the string-repetition task format 137 | # Format: \x0E\x0F 138 | texts = [ 139 | "\x0EA B C D\x0F ", # 4 words (longest) 140 | "\x0EE F G\x0F ", # 3 words (medium) 141 | "\x0EH I\x0F ", # 2 words (shortest) 142 | ] 143 | 144 | print(f"\nTest texts: {texts}") 145 | 146 | def generate_batch(batch): 147 | """Helper function to generate outputs for a batch.""" 148 | with torch.no_grad(): 149 | return model.generate( 150 | **batch, 151 | processor=processor, 152 | max_generated_words=10, 153 | ) 154 | 155 | # Prepare batches ahead of time 156 | individual_batches = [dataset_to_batch(model, processor, collator, make_dataset([text])) for text in texts] 157 | combined_batch = dataset_to_batch(model, processor, collator, make_dataset(texts)) 158 | 159 | # Generate each text individually 160 | print("\n--- Individual generation ---") 161 | individual_outputs = [] 162 | for text, batch in zip(texts, individual_batches, strict=False): 163 | outputs = generate_batch(batch) 164 | individual_outputs.append(outputs[0]) 165 | print(f" '{text}' -> '{outputs[0]}'") 166 | 167 | # Generate all texts as a batch 168 | print("\n--- Batch generation ---") 169 | batch_outputs = generate_batch(combined_batch) 170 | for text, output in zip(texts, batch_outputs, strict=False): 171 | print(f" '{text}' -> '{output}'") 172 | 173 | # Check that batch outputs match individual outputs 174 | print("\n--- Checking consistency ---") 175 | all_match = True 176 | for i, (text, individual, batch) in enumerate(zip(texts, individual_outputs, batch_outputs, strict=False)): 177 | match = individual == batch 178 | status = "✅" if match else "❌" 179 | print(f"{status} Text {i} ('{text}'):") 180 | print(f" Individual: '{individual}'") 181 | print(f" Batch: '{batch}'") 182 | print(f" Match: {match}") 183 | 184 | if not match: 185 | all_match = False 186 | 187 | # This assertion should fail if there's a bug 188 | assert all_match, ( 189 | "Batch generation does not match individual generation for texts of different lengths!\n" 190 | "Individual outputs: " + str(individual_outputs) + "\n" 191 | "Batch outputs: " + str(batch_outputs) 192 | ) 193 | 194 | print("\n✅ All outputs match - batch and individual generation are consistent!") 195 | 196 | 197 | if __name__ == "__main__": 198 | pytest.main([__file__, "-v"]) 199 | -------------------------------------------------------------------------------- /tests/test_batch_image_encoder.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | import pytest 4 | import torch 5 | from transformers import AutoConfig, AutoModel 6 | 7 | from welt.collator import stack_pad_tensors_list 8 | from welt.vision.batch_image_encoder import ( 9 | encode_images, 10 | encode_images_batch, 11 | encode_images_group, 12 | encode_images_sequentially, 13 | ) 14 | from welt.vision.navit import NaViTConfig 15 | from welt.vision.vision_utils import image_encoder_size 16 | 17 | MODELS 
= { 18 | "custom-navit": 512, 19 | "WinKawaks/vit-tiny-patch16-224": 192, 20 | "microsoft/swinv2-tiny-patch4-window16-256": 768, 21 | "google/vit-base-patch16-224": 768, 22 | "microsoft/resnet-18": 512, 23 | "apple/mobilevit-xx-small": 320, 24 | "facebook/dinov3-vits16-pretrain-lvd1689m": 384, 25 | "facebook/dinov3-convnext-tiny-pretrain-lvd1689m": 768, 26 | } 27 | 28 | MODEL_NAMES = list(MODELS.keys()) 29 | 30 | 31 | def images_dimensions(images: list[list[torch.Tensor]]) -> torch.Tensor: 32 | tensors = [ 33 | [torch.tensor([img.shape[-2], img.shape[-1]], dtype=torch.long) for img in batch] 34 | for batch in images 35 | ] 36 | return stack_pad_tensors_list(tensors) 37 | 38 | 39 | @cache 40 | def image_encoder(model_name): 41 | if model_name == "custom-navit": 42 | config = NaViTConfig() 43 | else: 44 | config = AutoConfig.from_pretrained(model_name) 45 | 46 | model = AutoModel.from_config(config) 47 | model.eval() 48 | return model 49 | 50 | 51 | def create_random_image(height, width, channels=3): 52 | """Create a random image tensor.""" 53 | return torch.randn(channels, height, width) 54 | 55 | 56 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 57 | def test_image_encoder_size(model_name): 58 | """Test that image_encoder_size returns the expected hidden size for each model.""" 59 | model = image_encoder(model_name) 60 | expected_size = MODELS[model_name] 61 | actual_size = image_encoder_size(model) 62 | assert actual_size == expected_size, f"Expected {expected_size}, got {actual_size} for {model_name}" 63 | 64 | 65 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 66 | def test_encode_images_single_image(model_name): 67 | """Test encode_images with a single image.""" 68 | model = image_encoder(model_name) 69 | 70 | # Single batch with single image 71 | images = [[create_random_image(64, 64)]] 72 | dimensions = images_dimensions(images) 73 | images = stack_pad_tensors_list(images) 74 | 75 | embeddings = encode_images(model, images, dimensions) 76 | 77 | assert embeddings.shape == (1, 1, embeddings.shape[2]) 78 | 79 | 80 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 81 | def test_encode_images_deterministic(model_name): 82 | model = image_encoder(model_name) 83 | 84 | images = [[create_random_image(64, 64)]] 85 | dimensions = images_dimensions(images) 86 | images = stack_pad_tensors_list(images) 87 | 88 | embeddings1 = encode_images(model, images, dimensions) 89 | embeddings2 = encode_images(model, images, dimensions) 90 | 91 | # Results should be identical 92 | assert torch.equal(embeddings1, embeddings2) 93 | 94 | 95 | def test_encode_non_equal_lists(): 96 | model = image_encoder("WinKawaks/vit-tiny-patch16-224") 97 | 98 | img = create_random_image(32, 32) 99 | images = [ 100 | [img], 101 | [img, img] 102 | ] 103 | dimensions = images_dimensions(images) 104 | images = stack_pad_tensors_list(images) 105 | 106 | # Test encoding 107 | embedding = encode_images(model, images, dimensions) 108 | assert embedding.shape == (2, 2, 192) 109 | 110 | assert torch.equal(embedding[0][0], embedding[1][0]) 111 | assert torch.equal(embedding[1][0], embedding[1][1]) 112 | assert torch.equal(embedding[0][1], torch.zeros_like(embedding[0][1])) 113 | 114 | 115 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 116 | def test_encode_images_batched_or_sequential(model_name): 117 | """Make sure the batch implementation and the sequential implementation return the same result""" 118 | model = image_encoder(model_name) 119 | 120 | images = [create_random_image(64, 64), create_random_image(64, 
64)] 121 | 122 | embeddings1 = encode_images_batch(model, images) 123 | embeddings2 = encode_images_sequentially(model, images) 124 | 125 | assert embeddings1.shape == embeddings2.shape 126 | assert torch.allclose(embeddings1, embeddings2, atol=1e-5) 127 | 128 | 129 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 130 | def test_encode_images_group_or_sequential(model_name): 131 | """Make sure the batch implementation and the sequential implementation return the same result""" 132 | model = image_encoder(model_name) 133 | 134 | images = [ 135 | create_random_image(64, 64), 136 | create_random_image(64, 32), 137 | create_random_image(64, 64) 138 | ] 139 | 140 | embeddings1 = encode_images_sequentially(model, images) 141 | embeddings2 = encode_images_group(model, images) 142 | 143 | assert embeddings1.shape == embeddings2.shape 144 | assert torch.allclose(embeddings1, embeddings2, atol=1e-5) 145 | 146 | 147 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 148 | def test_encode_images_basic(model_name): 149 | """Test basic functionality of encode_images with different image sizes.""" 150 | model = image_encoder(model_name) 151 | 152 | # Create images of different sizes 153 | images = [ 154 | [create_random_image(32, 32), create_random_image(64, 64)], 155 | [create_random_image(32, 64), create_random_image(128, 32)] 156 | ] 157 | dimensions = images_dimensions(images) 158 | images = stack_pad_tensors_list(images) 159 | 160 | # Test encoding 161 | embeddings = encode_images(model, images, dimensions) 162 | 163 | # Check output shape 164 | assert embeddings.shape[0] == 2 # batch size 165 | assert embeddings.shape[1] == 2 # sequence length (number of images per batch) 166 | assert embeddings.shape[2] > 0 # embedding dimension 167 | 168 | 169 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 170 | @pytest.mark.parametrize("image_size", [(32, 32), (32, 64), (32, 128), (64, 128)]) 171 | def test_encode_images_different_models_and_sizes(model_name, image_size): 172 | """Test encode_images with different model types and image sizes.""" 173 | model = image_encoder(model_name) 174 | height, width = image_size 175 | 176 | # Create test image of specified size 177 | images = [[create_random_image(height, width)]] 178 | dimensions = images_dimensions(images) 179 | images = stack_pad_tensors_list(images) 180 | 181 | embeddings = encode_images(model, images, dimensions) 182 | 183 | assert embeddings.shape[0] == 1 # batch size 184 | assert embeddings.shape[1] == 1 # sequence length 185 | assert embeddings.shape[2] > 0 # embedding dimension 186 | 187 | 188 | @pytest.mark.parametrize("model_name", MODEL_NAMES) 189 | def test_encode_images_batch_vs_individual(model_name): 190 | """Test that batch processing gives same results as individual processing for different models.""" 191 | model = image_encoder(model_name) 192 | 193 | # Create test images of different sizes 194 | img1 = create_random_image(32, 32) 195 | img2 = create_random_image(64, 64) 196 | img3 = create_random_image(32, 128) 197 | img4 = create_random_image(64, 128) 198 | 199 | # Batch processing - all images at once 200 | batch_images = [[img1, img2, img3, img4]] 201 | dimensions = images_dimensions(batch_images) 202 | batch_images = stack_pad_tensors_list(batch_images) 203 | 204 | batch_embeddings = encode_images(model, batch_images, dimensions) 205 | 206 | # Individual processing - each image separately 207 | individual_embeddings = [] 208 | for img in [img1, img2, img3, img4]: 209 | sub_batch = [[img]] 210 | 211 | img_embeddings = 
encode_images(model, 212 | stack_pad_tensors_list(sub_batch), 213 | images_dimensions(sub_batch)) 214 | individual_embeddings.append(img_embeddings.squeeze(0)) 215 | 216 | pairs = zip(individual_embeddings, batch_embeddings[0], strict=False) 217 | for i, (individual_embedding, batch_embedding) in enumerate(pairs): 218 | print(individual_embedding.shape, batch_embedding.shape) 219 | # Compare results (allowing for small numerical differences) 220 | assert torch.allclose(batch_embedding, individual_embedding, atol=1e-3), \ 221 | f"Batch vs individual results ({i}) differ for model {model_name}" 222 | 223 | 224 | if __name__ == "__main__": 225 | pytest.main([__file__, "-v"]) 226 | -------------------------------------------------------------------------------- /welt/vision/vision_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vision utilities for transformers models. 3 | 4 | This module provides utilities for working with image encoders and vision models, 5 | including functions to determine encoder dimensions and handle configuration edge cases. 6 | """ 7 | import inspect 8 | from functools import cache 9 | 10 | import torch 11 | from transformers import AutoModelForImageClassification 12 | 13 | 14 | class UnknownImageEncoderError(ValueError): 15 | """ 16 | Exception raised when an image encoder's hidden size cannot be determined. 17 | 18 | This error is raised when the image encoder model doesn't have any of the 19 | expected configuration attributes for determining the hidden size 20 | """ 21 | 22 | def __init__(self): 23 | super().__init__("Image encoder does not have a known hidden size configuration.") 24 | 25 | 26 | @cache 27 | def image_encoder_size(image_encoder: AutoModelForImageClassification) -> int: 28 | """ 29 | Determine the hidden size of an image encoder model. 30 | 31 | This function extracts the hidden size dimension from various types of image encoder 32 | models by checking different configuration attributes in a prioritized order. 33 | 34 | Args: 35 | image_encoder: An AutoModelForImageClassification instance. 36 | 37 | Returns: 38 | int: The hidden size of the image encoder. 39 | 40 | Raises: 41 | UnknownImageEncoderError: If the image encoder doesn't have any of the 42 | expected configuration attributes for hidden size. 43 | 44 | Note: 45 | The function checks for configuration attributes in the following order: 46 | 1. config.vision_config.hidden_size (for CLIP-like models) 47 | 2. config.hidden_size (standard hidden size attribute) 48 | 3. config.neck_hidden_sizes (for MobileViT models, with expand_output handling) 49 | 4. 
config.hidden_sizes (fallback to last hidden size in the list) 50 | """ 51 | # Extract the model configuration, defaulting to empty dict if not found 52 | config = getattr(image_encoder, 'config', {}) 53 | 54 | # For multi-modal models like CLIP, the vision encoder config is nested 55 | if hasattr(config, 'vision_config'): 56 | config = config.vision_config 57 | 58 | # Most standard vision models have a direct hidden_size attribute 59 | if hasattr(config, 'hidden_size'): 60 | return config.hidden_size 61 | 62 | # Handle MobileViT models which use neck_hidden_sizes instead of hidden_size 63 | # Reference: https://huggingface.co/docs/transformers/model_doc/mobilevit#transformers.MobileViTModel 64 | if hasattr(config, 'neck_hidden_sizes'): 65 | # When expand_output is True, MobileViT applies an additional 1x1 convolution 66 | # to expand output channels from neck_hidden_sizes[5] to neck_hidden_sizes[6] 67 | if getattr(image_encoder, 'expand_output', False): 68 | return config.neck_hidden_sizes[-1] # Use the expanded output size 69 | return config.neck_hidden_sizes[-2] # Use the pre-expansion size 70 | 71 | # Fallback for models that store multiple layer sizes in a list (e.g., some ViT variants) 72 | if hasattr(config, 'hidden_sizes'): 73 | return config.hidden_sizes[-1] # Use the final layer's hidden size 74 | 75 | # No recognized hidden size configuration found 76 | raise UnknownImageEncoderError() 77 | 78 | 79 | @cache 80 | def model_args_dict(model: AutoModelForImageClassification) -> dict: 81 | """ 82 | Generate model arguments dictionary for image encoder forward pass. 83 | 84 | This function creates a dictionary of arguments optimized for feature extraction 85 | from image encoder models, including conditional parameters based on model capabilities. 86 | 87 | Args: 88 | model: An AutoModelForImageClassification instance to generate arguments for. 89 | 90 | Returns: 91 | dict: Dictionary of arguments to pass to the model's forward method. 92 | Always includes 'output_hidden_states': True. 93 | May include 'interpolate_pos_encoding': True if supported by the model. 94 | 95 | Note: 96 | The function is cached to avoid repeated signature inspection for the same model. 97 | Positional encoding interpolation is enabled for models that support it, 98 | allowing better handling of images with different sizes than training data. 99 | """ 100 | # Configure model arguments to output hidden states for feature extraction 101 | args = {"output_hidden_states": True} 102 | 103 | # Enable positional encoding interpolation if the model supports it 104 | # This is useful for handling images of different sizes than training 105 | if accepts(model.forward, 'interpolate_pos_encoding'): 106 | args['interpolate_pos_encoding'] = True 107 | 108 | return args 109 | 110 | 111 | @cache 112 | def accepts(func, param_name: str) -> bool: 113 | """ 114 | Check if a function accepts a specific parameter. 115 | 116 | This function inspects the signature of a given function to determine whether 117 | it accepts a specific parameter either as a named parameter or through **kwargs. 118 | 119 | Args: 120 | func: The function to inspect. 121 | param_name: The name of the parameter to check for. 122 | 123 | Returns: 124 | bool: True if the function accepts the parameter, False otherwise. 125 | 126 | Note: 127 | Returns True if either: 128 | 1. The parameter name is explicitly defined in the function signature 129 | 2. 
The function accepts **kwargs (VAR_KEYWORD parameters) 130 | """ 131 | sig = inspect.signature(func) 132 | return ( 133 | param_name in sig.parameters 134 | or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) 135 | ) 136 | 137 | 138 | def pool_hidden_dim(tensor: torch.Tensor, hidden_size: int) -> torch.Tensor: 139 | """ 140 | Pool a tensor across all dimensions except batch and hidden dimensions. 141 | 142 | This function performs mean pooling across spatial or patch dimensions while 143 | preserving the batch and hidden dimensions. It works with various tensor layouts 144 | from different vision model architectures. 145 | 146 | Args: 147 | tensor: Input tensor to pool. Can have various shapes depending on the model: 148 | - ViT-like: `(batch_size, num_patches, hidden_size)` 149 | - ConvNet-like: `(batch_size, height, width, channels)` or 150 | `(batch_size, channels, height, width)` 151 | hidden_size: The size of the hidden/feature dimension to preserve. 152 | 153 | Returns: 154 | torch.Tensor: Pooled tensor with shape `(batch_size, hidden_size)`. 155 | 156 | Raises: 157 | StopIteration: If no dimension matches the specified hidden_size (excluding batch dim). 158 | 159 | Note: 160 | The function identifies the hidden dimension by finding the dimension that 161 | matches hidden_size (excluding the batch dimension at index 0), then pools 162 | across all other non-batch, non-hidden dimensions. 163 | """ 164 | # Find the dimension index that matches our hidden size (skip batch dim at index 0) 165 | hidden_dim = next(i for i, s in enumerate(tensor.shape) if s == hidden_size and i != 0) 166 | 167 | # Identify all dimensions to pool over (everything except batch and hidden dims) 168 | non_hidden_dims = tuple(i for i in range(len(tensor.shape)) if i != hidden_dim and i != 0) 169 | 170 | # Perform mean pooling across spatial/patch dimensions 171 | return tensor.mean(dim=non_hidden_dims) 172 | 173 | 174 | def encode_images(image_encoder: AutoModelForImageClassification, 175 | images: list[torch.Tensor] | torch.Tensor) -> torch.Tensor: 176 | """ 177 | Encode a batch of images using the provided image encoder model. 178 | 179 | This function runs images through the encoder and extracts the final hidden states, 180 | with optional support for positional encoding interpolation when available. 181 | 182 | Args: 183 | image_encoder: An AutoModelForImageClassification instance used for encoding. 184 | images: A tensor of shape `(batch_size, channels, height, width)` containing 185 | the preprocessed images to encode. 186 | 187 | Returns: 188 | torch.Tensor: The encoded image features with shape `(batch_size, hidden_size)`. 189 | Features are pooled across spatial/patch dimensions. 190 | 191 | Note: 192 | The function automatically enables output_hidden_states to access intermediate 193 | representations and conditionally enables interpolate_pos_encoding for models 194 | that support dynamic positional encoding based on input image size. 
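    Example (an illustrative sketch; the checkpoint name, input sizes, and the
    192-dim hidden size are assumptions based on the ViT-tiny config used in the tests):

        >>> from transformers import AutoConfig, AutoModel
        >>> import torch
        >>> encoder = AutoModel.from_config(AutoConfig.from_pretrained("WinKawaks/vit-tiny-patch16-224"))
        >>> pixel_values = torch.randn(2, 3, 64, 64)  # (batch_size, channels, height, width)
        >>> features = encode_images(encoder, pixel_values)
        >>> tuple(features.shape)
        (2, 192)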
195 | """ 196 | # Configure model arguments to output hidden states for feature extraction 197 | model_args = model_args_dict(image_encoder) 198 | 199 | # Run the forward pass through the image encoder 200 | encoded_images = image_encoder(images, **model_args) 201 | 202 | # Default to using pooler_output if available (shape [batch_size, hidden_size]) 203 | if hasattr(encoded_images, "pooler_output"): 204 | pooled_output = encoded_images.pooler_output 205 | # ResNet outputs include extra dimensions (batch_size, hidden_size, 1, 1) 206 | pooled_output = pooled_output.squeeze() 207 | if pooled_output.dim() == 1: 208 | pooled_output = pooled_output.unsqueeze(0) 209 | return pooled_output 210 | 211 | # Extract the final layer's hidden states (shape varies by model architecture) 212 | if hasattr(encoded_images, "last_hidden_state"): 213 | last_hidden_states = encoded_images.last_hidden_state 214 | else: 215 | last_hidden_states = encoded_images.hidden_states[-1] 216 | 217 | # Get the hidden size dimension for this encoder model 218 | hidden_size = image_encoder_size(image_encoder) 219 | 220 | # Pool across spatial/patch dimensions to get [batch_size, hidden_size] output 221 | return pool_hidden_dim(last_hidden_states, hidden_size) 222 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | End-to-end test for the training script. 3 | 4 | This module tests the complete training pipeline by running a minimal training 5 | run that completes in ~10 seconds, verifying the entire train.py workflow. 6 | """ 7 | import shutil 8 | import tempfile 9 | from pathlib import Path 10 | 11 | import pytest 12 | import yaml 13 | 14 | from training.train import train 15 | 16 | 17 | @pytest.fixture 18 | def temp_output_dir(): 19 | """Create a temporary directory for training output.""" 20 | temp_dir = tempfile.mkdtemp(prefix="test_train_") 21 | yield temp_dir 22 | # Cleanup after test 23 | shutil.rmtree(temp_dir, ignore_errors=True) 24 | 25 | 26 | def test_basic_training_with_eval_chrf(temp_output_dir): 27 | """ 28 | Test that a basic training run works end-to-end and reports eval_chrf. 29 | 30 | This test: 31 | 1. Loads the known-good string-repetition config 32 | 2. Modifies it for fast testing (10 steps, small dataset) 33 | 3. Trains for 10 steps (very fast) 34 | 4. Runs evaluation 35 | 5. 
Verifies that eval_chrf metric is reported 36 | """ 37 | # Load the base config from the experiment 38 | base_config_path = Path(__file__).parent.parent / "training/experiments/easy-tasks/string-repetition.yaml" 39 | with open(base_config_path) as f: 40 | config = yaml.safe_load(f) 41 | 42 | # Modify config for fast testing 43 | config.update({ 44 | # Output to temp directory 45 | "output_dir": temp_output_dir, 46 | 47 | # Minimal training for speed 48 | "max_steps": 10, 49 | "max_train_samples": 10, 50 | "max_eval_samples": 5, 51 | 52 | # Disable reporting 53 | "report_to": [], 54 | 55 | # Disable checkpointing 56 | "save_strategy": "no", 57 | 58 | # Evaluate after training (not at start) 59 | "eval_on_start": False, 60 | "eval_steps": 10, 61 | 62 | # Test specifically with chrf metric 63 | "eval_metrics": ["chrf"], 64 | "metric_for_best_model": None, # Disable since we're not saving checkpoints 65 | 66 | # Minimal dataloader for speed 67 | "dataloader_num_workers": 0, 68 | "dataloader_prefetch_factor": None, 69 | "dataloader_pin_memory": False, 70 | "dataloader_persistent_workers": False, 71 | 72 | # Reduce logging 73 | "logging_steps": 5, 74 | "log_samples": 1, 75 | 76 | # Disable bf16 for compatibility 77 | "bf16": False, 78 | }) 79 | 80 | # Write modified config to temp file 81 | config_path = Path(temp_output_dir) / "test_config.yaml" 82 | with open(config_path, "w") as f: 83 | yaml.dump(config, f) 84 | 85 | # Run training 86 | train(args=str(config_path)) 87 | 88 | # Verify that training completed and metrics were saved 89 | output_dir = Path(temp_output_dir) 90 | 91 | # Check that eval metrics file was created 92 | eval_results_path = output_dir / "eval_results.json" 93 | assert eval_results_path.exists(), "eval_results.json should be created" 94 | 95 | # Load and verify metrics 96 | import json 97 | with open(eval_results_path) as f: 98 | eval_metrics = json.load(f) 99 | 100 | # Verify eval_chrf is present 101 | assert "eval_chrf" in eval_metrics, \ 102 | f"eval_chrf should be in metrics. 
Found: {list(eval_metrics.keys())}" 103 | 104 | # Verify eval_chrf is a valid number 105 | chrf_score = eval_metrics["eval_chrf"] 106 | assert isinstance(chrf_score, int | float), \ 107 | f"eval_chrf should be numeric, got {type(chrf_score)}" 108 | assert 0 <= chrf_score <= 100, \ 109 | f"eval_chrf should be between 0 and 100, got {chrf_score}" 110 | 111 | # Verify other expected metrics are present 112 | assert "eval_loss" in eval_metrics, "eval_loss should be present" 113 | assert "eval_samples" in eval_metrics, "eval_samples should be present" 114 | assert "perplexity" in eval_metrics, "perplexity should be present" 115 | 116 | print("\n✓ Training completed successfully!") 117 | print(f"✓ eval_chrf = {chrf_score:.2f}") 118 | print(f"✓ eval_loss = {eval_metrics['eval_loss']:.4f}") 119 | print(f"✓ eval_samples = {eval_metrics['eval_samples']}") 120 | print(f"✓ All metrics: {list(eval_metrics.keys())}") 121 | 122 | 123 | def test_training_without_generation_metrics(temp_output_dir): 124 | """Test that training works without generation-based metrics (backward compatibility).""" 125 | base_config_path = Path(__file__).parent.parent / "training/experiments/easy-tasks/string-repetition.yaml" 126 | with open(base_config_path) as f: 127 | config = yaml.safe_load(f) 128 | 129 | # Modify config to disable generation metrics 130 | config.update({ 131 | "output_dir": temp_output_dir, 132 | "max_steps": 5, 133 | "max_train_samples": 10, 134 | "max_eval_samples": 5, 135 | "report_to": [], 136 | "save_strategy": "no", 137 | "eval_on_start": False, 138 | "eval_steps": 5, 139 | "eval_metrics": None, # No generation metrics 140 | "metric_for_best_model": None, # Disable since no generation metrics 141 | "dataloader_num_workers": 0, 142 | "dataloader_prefetch_factor": None, 143 | "dataloader_pin_memory": False, 144 | "dataloader_persistent_workers": False, 145 | "logging_steps": 5, 146 | "log_samples": 0, 147 | "bf16": False, 148 | }) 149 | 150 | config_path = Path(temp_output_dir) / "test_config_no_metrics.yaml" 151 | with open(config_path, "w") as f: 152 | yaml.dump(config, f) 153 | 154 | # Run training 155 | train(args=str(config_path)) 156 | 157 | # Verify training completed 158 | output_dir = Path(temp_output_dir) 159 | eval_results_path = output_dir / "eval_results.json" 160 | assert eval_results_path.exists() 161 | 162 | import json 163 | with open(eval_results_path) as f: 164 | eval_metrics = json.load(f) 165 | 166 | # Should have loss and perplexity but no generation metrics 167 | assert "eval_loss" in eval_metrics 168 | assert "perplexity" in eval_metrics 169 | assert "eval_samples" in eval_metrics 170 | # Should not have generation metrics 171 | assert "eval_chrf" not in eval_metrics 172 | assert "eval_sacrebleu" not in eval_metrics 173 | 174 | 175 | def test_training_with_sacrebleu(temp_output_dir): 176 | """Test training with sacrebleu metric specifically.""" 177 | base_config_path = Path(__file__).parent.parent / "training/experiments/easy-tasks/string-repetition.yaml" 178 | with open(base_config_path) as f: 179 | config = yaml.safe_load(f) 180 | 181 | config.update({ 182 | "output_dir": temp_output_dir, 183 | "max_steps": 5, 184 | "max_train_samples": 10, 185 | "max_eval_samples": 5, 186 | "report_to": [], 187 | "save_strategy": "no", 188 | "eval_on_start": False, 189 | "eval_steps": 5, 190 | "eval_metrics": ["sacrebleu"], # Only sacrebleu 191 | "metric_for_best_model": None, 192 | "dataloader_num_workers": 0, 193 | "dataloader_prefetch_factor": None, 194 | "dataloader_pin_memory": False, 
195 | "dataloader_persistent_workers": False, 196 | "logging_steps": 5, 197 | "log_samples": 1, 198 | "bf16": False, 199 | }) 200 | 201 | config_path = Path(temp_output_dir) / "test_config_sacrebleu.yaml" 202 | with open(config_path, "w") as f: 203 | yaml.dump(config, f) 204 | 205 | # Run training 206 | train(args=str(config_path)) 207 | 208 | # Verify training completed 209 | output_dir = Path(temp_output_dir) 210 | eval_results_path = output_dir / "eval_results.json" 211 | assert eval_results_path.exists() 212 | 213 | import json 214 | with open(eval_results_path) as f: 215 | eval_metrics = json.load(f) 216 | 217 | # Should have sacrebleu 218 | assert "eval_sacrebleu" in eval_metrics 219 | assert 0 <= eval_metrics["eval_sacrebleu"] <= 100 220 | assert "eval_loss" in eval_metrics 221 | assert "perplexity" in eval_metrics 222 | 223 | 224 | def test_training_determinism(temp_output_dir): 225 | """Test that training with same seed produces similar results.""" 226 | base_config_path = Path(__file__).parent.parent / "training/experiments/easy-tasks/string-repetition.yaml" 227 | with open(base_config_path) as f: 228 | config = yaml.safe_load(f) 229 | 230 | config.update({ 231 | "output_dir": temp_output_dir, 232 | "max_steps": 3, 233 | "max_train_samples": 5, 234 | "max_eval_samples": 3, 235 | "report_to": [], 236 | "save_strategy": "no", 237 | "eval_on_start": False, 238 | "eval_steps": 3, 239 | "eval_metrics": ["bleu"], 240 | "metric_for_best_model": None, 241 | "dataloader_num_workers": 0, 242 | "dataloader_prefetch_factor": None, 243 | "dataloader_pin_memory": False, 244 | "dataloader_persistent_workers": False, 245 | "logging_steps": 3, 246 | "log_samples": 0, 247 | "bf16": False, 248 | "seed": 12345, # Fixed seed for determinism 249 | }) 250 | 251 | # Run 1 252 | config_path1 = Path(temp_output_dir) / "test_config_run1.yaml" 253 | with open(config_path1, "w") as f: 254 | yaml.dump(config, f) 255 | 256 | train(args=str(config_path1)) 257 | 258 | import json 259 | with open(Path(temp_output_dir) / "eval_results.json") as f: 260 | metrics1 = json.load(f) 261 | 262 | # Clean up for run 2 263 | shutil.rmtree(temp_output_dir, ignore_errors=True) 264 | Path(temp_output_dir).mkdir(parents=True, exist_ok=True) 265 | 266 | # Run 2 with same config 267 | config_path2 = Path(temp_output_dir) / "test_config_run2.yaml" 268 | with open(config_path2, "w") as f: 269 | yaml.dump(config, f) 270 | 271 | train(args=str(config_path2)) 272 | 273 | with open(Path(temp_output_dir) / "eval_results.json") as f: 274 | metrics2 = json.load(f) 275 | 276 | # Results should be similar (within tolerance for floating point differences) 277 | assert abs(metrics1["eval_loss"] - metrics2["eval_loss"]) < 0.5, \ 278 | f"Losses should be similar: {metrics1['eval_loss']} vs {metrics2['eval_loss']}" 279 | 280 | 281 | if __name__ == "__main__": 282 | pytest.main([__file__, "-v", "-s"]) 283 | -------------------------------------------------------------------------------- /welt/processor.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import defaultdict 3 | 4 | import torch 5 | from cachetools import LRUCache 6 | from datasets import Dataset 7 | from pixel_renderer import PixelRendererProcessor 8 | from transformers import ImageProcessingMixin, PreTrainedTokenizer, ProcessorMixin 9 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 10 | from words_segmentation.tokenizer import WordsSegmentationTokenizer # noqa: F401 - for registering AutoTokenizer 11 | 12 | from 
welt.attention import ( 13 | get_attention_mask_for_packed_sequence, 14 | get_position_ids_for_packed_sequence, 15 | get_shift_blocks, 16 | ) 17 | from welt.collator import collate_fn, stack_pad_tensors 18 | from welt.noop import NoopImageProcessor 19 | 20 | 21 | class TextImageProcessor(ProcessorMixin): 22 | name = "text-image-processor" 23 | 24 | attributes = [ 25 | "pretokenizer", 26 | "tokenizer", 27 | "renderer", 28 | "image_processor" 29 | ] 30 | pretokenizer_class = "AutoTokenizer" 31 | tokenizer_class = "AutoTokenizer" 32 | renderer_class = "PixelRendererProcessor" 33 | image_processor_class = "AutoImageProcessor" 34 | 35 | def __init__(self, 36 | pretokenizer: PreTrainedTokenizer, 37 | tokenizer: UTF8Tokenizer, 38 | renderer: PixelRendererProcessor, 39 | image_processor: ImageProcessingMixin, 40 | max_seq_length: int = 128, 41 | max_word_length: int = 32, 42 | cache_size: int = 10000): 43 | super().__init__(pretokenizer=pretokenizer, 44 | tokenizer=tokenizer, 45 | renderer=renderer, 46 | image_processor=image_processor) 47 | 48 | assert tokenizer.bos_token_id is not None, "Tokenizer must have a BOS token" 49 | assert tokenizer.eos_token_id is not None, "Tokenizer must have an EOS token" 50 | 51 | self.pretokenizer = pretokenizer 52 | self.tokenizer = tokenizer 53 | self.renderer = renderer 54 | self.image_processor = image_processor 55 | 56 | self.max_word_length = max_word_length 57 | self.max_seq_length = max_seq_length 58 | self.cache_size = cache_size 59 | 60 | self.images_cache = LRUCache(maxsize=self.cache_size) 61 | 62 | def render_texts(self, texts: list[str]) -> tuple[torch.Tensor, torch.Tensor]: 63 | if isinstance(self.image_processor, NoopImageProcessor): 64 | return torch.empty(1,), torch.empty(1,) 65 | 66 | images = [self.images_cache.get(text, None) for text in texts] 67 | 68 | # Render all missing texts and group by size for efficient batching 69 | render_groups = defaultdict(list) 70 | index_groups = defaultdict(list) 71 | for i, v in enumerate(images): 72 | if v is None: 73 | render = self.renderer.render_text(texts[i]) 74 | index_groups[render.shape].append(i) 75 | render_groups[render.shape].append(render) 76 | 77 | # Process each shape group and update cache 78 | for shape, renders in render_groups.items(): 79 | processed = self.image_processor(renders, return_tensors="pt", do_center_crop=False, do_resize=False) 80 | pixel_values = processed.pixel_values.to(torch.bfloat16) # TODO : make dtype configurable 81 | for i, pixel_value in zip(index_groups[shape], pixel_values, strict=True): 82 | self.images_cache[texts[i]] = pixel_value 83 | images[i] = pixel_value 84 | 85 | image_dimensions = torch.tensor([img.shape[-2:] for img in images], dtype=torch.long) 86 | return stack_pad_tensors(images), image_dimensions 87 | 88 | def pretokenize(self, text: str) -> list[str]: 89 | # Add BOS token at the start 90 | text = self.tokenizer.bos_token + text.strip() 91 | 92 | # TODO: Ensure all texts end with a space. 
this is a model quirk and needs to be handled generally 93 | # if the text does not end with a space, the model should continue generating the last word directly 94 | # https://github.com/sign/WeLT/issues/2 95 | text += " " 96 | 97 | return self.pretokenizer.tokenize(text) 98 | 99 | def pretokenize_dataset(self, dataset: Dataset) -> Dataset: 100 | """Pretokenize a dataset in place, adding a 'words' column.""" 101 | 102 | def tokenize_example(example): 103 | example["words"] = self.pretokenize(example["text"]) 104 | return example 105 | 106 | return dataset.map(tokenize_example, 107 | batched=False, 108 | remove_columns=["text"], 109 | desc="Pretokenizing texts into 'words'") 110 | 111 | def get_sequence_labels(self, words: list[str], seq_lengths: list[int] = None, pack=True) -> list[str]: 112 | """ 113 | Generate labels for word-level sequences. 114 | 115 | Tokens inside shift blocks (between ShiftOut and ShiftIn control tokens) are masked 116 | with empty labels to prevent training on "known" tokens that are already visible via 117 | self-attention. The ShiftIn token itself keeps its label to predict the next word. 118 | 119 | Args: 120 | words: List of word strings to generate labels for 121 | seq_lengths: Optional list of sequence lengths for packed sequences 122 | pack: If True, use packed mode (longer context labels), else unpacked (next word only) 123 | 124 | Returns: 125 | List of label strings corresponding to each word 126 | """ 127 | if seq_lengths is None: 128 | seq_lengths = [len(words)] 129 | 130 | labels = [] 131 | 132 | # Process each sequence separately, to support efficient packing 133 | offset = 0 134 | for length in seq_lengths: 135 | if pack: 136 | # Next several characters as label, last word has no label 137 | segment_words = words[offset:offset + length] 138 | text = "".join(segment_words) 139 | label_idx = 0 140 | 141 | for word in segment_words: 142 | label_idx += len(word) 143 | 144 | # For efficiency, we don't just use the next word as label, but a longer token string 145 | # max_word_length characters, not bytes, will be trimmed by the tokenizer later 146 | label = text[label_idx:label_idx + self.max_word_length] 147 | # TODO: remove once https://github.com/sign/WeLT/issues/2 is solved 148 | label = label.rstrip() # Remove trailing spaces to avoid generating them 149 | 150 | labels.append(label) 151 | else: 152 | # Next word as label, last word has no label 153 | raw_labels = words[offset + 1:offset + length] + [""] 154 | # Truncate labels to max_word_length (characters, not bytes) 155 | labels += [label[:self.max_word_length] for label in raw_labels] 156 | 157 | # TODO: remove once https://github.com/sign/WeLT/issues/2 is solved 158 | labels[-2] = labels[-2].rstrip() # Remove last trailing space to avoid generating it 159 | 160 | offset += length 161 | 162 | # Mask labels inside shift blocks (except for ShiftIn token) 163 | for start, end in get_shift_blocks(words): 164 | for i in range(start, end): # Excludes end (ShiftIn token) 165 | labels[i] = "" 166 | 167 | return labels 168 | 169 | def tokenize_words(self, words: list[str], device=None): 170 | return self.tokenizer.torch( 171 | words, 172 | padding=True, 173 | add_special_tokens=True, 174 | device=device, 175 | # Truncation happens in pre-tokenization. 
176 | # This is just for additional safety: 177 | max_length=self.max_word_length, 178 | truncation=True, 179 | ) 180 | 181 | def process_single_example(self, words: list[str], seq_lengths: list[int], pack=True): 182 | labels = self.get_sequence_labels(words, seq_lengths, pack=pack) 183 | 184 | # Tokenize words with BOS and EOS tokens 185 | tokenized = self.tokenize_words(words) # Tokenized inputs 186 | tokenized_labels = self.tokenize_words(labels) # Tokenized outputs 187 | 188 | # Render images 189 | input_images, input_images_dimensions = self.render_texts(words) 190 | 191 | return { 192 | "input_ids": tokenized.input_ids, 193 | "input_attention_mask": tokenized.attention_mask, # Attention within each word 194 | # Attention across words 195 | "attention_mask": get_attention_mask_for_packed_sequence(seq_lengths, words=words), 196 | "position_ids": get_position_ids_for_packed_sequence(seq_lengths), 197 | "input_images": input_images, 198 | "input_images_dimensions": input_images_dimensions, 199 | "labels_input": tokenized_labels.input_ids[:, :-1], # Remove EOS token from input labels 200 | "labels_attention_mask": tokenized_labels.attention_mask[:, :-1], # Remove EOS token from attention mask 201 | "labels_output": tokenized_labels.input_ids[:, 1:] # Remove BOS token from output labels 202 | } 203 | 204 | def __call__(self, 205 | batch: dict[str, list[str]] | str | list[str], 206 | collated=False, 207 | packed=False) -> dict[str, torch.Tensor]: 208 | if isinstance(batch, str): 209 | batch = {"text": [batch]} 210 | 211 | if isinstance(batch, list): 212 | batch = {"text": batch} 213 | 214 | # Copy batch before modifying to avoid mutating the input 215 | if "text" in batch and "words" not in batch: 216 | batch = batch.copy() 217 | words = [self.pretokenize(t) for t in batch["text"]] 218 | batch["words"] = words 219 | batch["seq_lengths"] = [[len(w)] for w in words] 220 | 221 | dicts = [self.process_single_example(words=words, seq_lengths=seq_lengths, pack=packed) 222 | for words, seq_lengths in zip(batch["words"], batch["seq_lengths"], strict=False)] 223 | 224 | if collated: 225 | return collate_fn(dicts) 226 | 227 | new_batch = {} 228 | for key in dicts[0].keys(): 229 | new_batch[key] = [d[key] for d in dicts] 230 | 231 | # Preserve extra fields from the original batch (e.g., "prefix", "completion") 232 | for key in batch: 233 | if key not in new_batch and key not in {"text", "words", "seq_lengths"}: 234 | new_batch[key] = batch[key] 235 | 236 | return new_batch 237 | 238 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | import torch 5 | from datasets import Dataset 6 | from safetensors.torch import load_model, save_model 7 | from transformers.modeling_outputs import CausalLMOutput 8 | 9 | from welt.model import WordLatentTransformer 10 | from welt.model_utils import setup_model 11 | 12 | 13 | def setup_tiny_model( 14 | image_encoder_name="WinKawaks/vit-tiny-patch16-224", 15 | bytes_encoder_name="prajjwal1/bert-tiny", 16 | latent_transformer_name="sbintuitions/tiny-lm", 17 | bytes_decoder_name="sign/utf8-lm-tiny", 18 | **kwargs): 19 | """Set up a tiny version of the WordLatentTransformer model for testing, the tinyer the better.""" 20 | return setup_model( 21 | image_encoder_name=image_encoder_name, 22 | bytes_encoder_name=bytes_encoder_name, 23 | latent_transformer_name=latent_transformer_name, 24 | 
bytes_decoder_name=bytes_decoder_name, 25 | load_pretrained=False, 26 | **kwargs 27 | ) 28 | 29 | 30 | def make_dataset(texts: list[str]): 31 | """Create a dataset from a list of texts.""" 32 | return Dataset.from_dict({"text": texts}) 33 | 34 | 35 | def dataset_to_batch(model, processor, collator, dataset): 36 | # Compute losses for each sequence - process entire batch at once 37 | device = next(model.parameters()).device 38 | dataset = dataset.with_transform(processor) 39 | batch = collator([dataset[i] for i in range(len(dataset))]) 40 | # Move batch to the same device as the model 41 | return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} 42 | 43 | 44 | def predict_dataset(texts: list[str], model, processor, collator): 45 | """Predict a dataset and return the logits.""" 46 | dataset = make_dataset(texts) 47 | batch = dataset_to_batch(model, processor, collator, dataset) 48 | 49 | with torch.no_grad(): 50 | outputs = model(**batch) 51 | 52 | logits = outputs.logits 53 | labels = batch['labels_output'] 54 | 55 | output_per_text = {} 56 | losses = {} 57 | for i, text in enumerate(texts): 58 | # Compute cross entropy loss for this item 59 | item_loss = torch.nn.functional.cross_entropy(input=logits[i].reshape(-1, logits.size(-1)), 60 | target=labels[i].reshape(-1), 61 | ignore_index=processor.tokenizer.pad_token_type_id, 62 | reduction='none') 63 | # Reduce loss based on ignore_index 64 | losses[text] = item_loss.sum().item() / (labels[i] != processor.tokenizer.pad_token_type_id).sum().item() 65 | print(f"Loss for '{text}': {losses[text]:.4f}") 66 | 67 | output_per_text[text] = CausalLMOutput( 68 | loss=item_loss, 69 | logits=logits[i], 70 | hidden_states=(outputs.hidden_states[0][i],), 71 | attentions=None 72 | ) 73 | 74 | return losses, output_per_text 75 | 76 | 77 | def test_attention_no_look_ahead(): 78 | """Test that attention does not look ahead - causal masking is working correctly.""" 79 | model, processor, collator = setup_model() 80 | model.eval() 81 | 82 | # Test sequences that share prefixes 83 | texts = ["a b c x y z", "a b d m"] 84 | 85 | # Force every word to predict a single byte (and EOS) 86 | # "a , b , c , " and "a , b , d , " 87 | processor.max_word_length = 1 88 | 89 | _, outputs = predict_dataset(texts, model, processor, collator) 90 | for text in texts: 91 | print(f"Loss for '{text}':", outputs[text].loss.cpu().numpy()) 92 | 93 | # Check that the first 4 tokens have identical losses 94 | for i in range(4): 95 | assert abs(outputs[texts[0]].loss[i] - outputs[texts[1]].loss[i]) < 1e-4, \ 96 | f"Loss at position {i} should be identical: {outputs[texts[0]].loss[i]} vs {outputs[texts[1]].loss[i]}" 97 | 98 | 99 | def test_attention_does_look_back(): 100 | """Test that attention does look back - model uses previous context.""" 101 | model, processor, collator = setup_model() 102 | model.eval() 103 | 104 | # Test sequences with shared suffix but different prefix 105 | texts = ["c b a", "d b a"] 106 | 107 | # Force every word to predict a single byte and special tokens 108 | processor.max_word_length = 3 # # (BOS + 1 byte + EOS) 109 | 110 | _, outputs = predict_dataset(texts, model, processor, collator) 111 | for text in texts: 112 | print(f"Loss for '{text}':", outputs[text].loss.cpu().numpy()) 113 | 114 | # Check that ALL positions have different losses due to different context 115 | for i in range(7): # Check all 7 positions (excluding padding) 116 | loss_diff = abs(outputs[texts[0]].loss[i] - outputs[texts[1]].loss[i]) 117 | assert 
loss_diff > 1e-4, \ 118 | (f"Loss at position {i} should be different due to context: " 119 | f"{outputs[texts[0]].loss[i]} vs {outputs[texts[1]].loss[i]} (diff: {loss_diff})") 120 | 121 | 122 | DEVICES = ["cpu"] 123 | if torch.cuda.is_available(): 124 | DEVICES.append("cuda") 125 | if torch.backends.mps.is_available(): 126 | DEVICES.append("mps") 127 | 128 | 129 | @pytest.mark.parametrize("device", DEVICES) 130 | def test_multiple_texts_batch_not_nan(device): 131 | """Test that attention does look back - model uses previous context.""" 132 | model, processor, collator = setup_model() 133 | # model.eval() 134 | 135 | # Move model to specified device 136 | model = model.to(torch.device(device)) 137 | 138 | # Test sequences with shared suffix but different prefix 139 | texts = ["1", "1 2 3"] 140 | 141 | dataset = make_dataset(texts) 142 | batch = dataset_to_batch(model, processor, collator, dataset) 143 | 144 | # TODO: this fails on mps device, because of the attention mask 145 | # ONLY when no_grad is used https://github.com/pytorch/pytorch/issues/167515 146 | with torch.no_grad(): 147 | outputs = model(**batch) 148 | assert not torch.isnan(outputs.loss).any(), "Loss contains NaN values" 149 | 150 | 151 | def test_loss_is_independent_of_batch(): 152 | """Test that loss at first position is identical regardless of other items in batch.""" 153 | model, processor, collator = setup_model() 154 | model.eval() 155 | 156 | batches = [ 157 | # Run first batch with just "a" 158 | ["a"], 159 | # Run second batch with "a" and additional text 160 | ["a", "2 w"], 161 | # Run third batch with "a" and additional longer text 162 | ["a", "two words"], 163 | ] 164 | outputs = [predict_dataset(batch, model, processor, collator)[1] for batch in batches] 165 | 166 | # Get the loss for "a" from both batches 167 | losses = [outputs[i]["a"].loss[0].item() for i in range(len(outputs))] 168 | max_loss = max(losses) 169 | losses = [loss / max_loss for loss in losses] # Normalize losses for comparison, across different models 170 | 171 | # Check that the loss at the first position (first token) is nearly identical 172 | # Note: losses[0] and losses[1] should be the same 173 | # since they're both predicting the same token with the same context 174 | # Small numerical differences are acceptable due to batching implementation details 175 | assert abs(losses[0] - losses[1]) < 1e-3, \ 176 | f"Loss at first position should be nearly identical: {losses[0]} vs {losses[1]}" 177 | 178 | assert abs(losses[0] - losses[2]) < 1e-3, \ 179 | f"Loss at first position should be nearly identical: {losses[0]} vs {losses[2]}" 180 | 181 | print(f"✓ Loss at first position is batch-independent: {losses[0]:.4f}") 182 | 183 | 184 | def num_model_params(model): 185 | return sum(p.numel() for p in model.parameters()) 186 | 187 | 188 | def test_model_save_and_load_works(): 189 | """Test that the model can be saved and loaded without issues.""" 190 | model, processor, collator = setup_tiny_model() 191 | 192 | with tempfile.NamedTemporaryFile(suffix=".safetensors") as temp_file: 193 | original_num_parameters = num_model_params(model) 194 | save_model(model, temp_file.name) 195 | load_model(model, temp_file.name) 196 | loaded_num_parameters = num_model_params(model) 197 | assert original_num_parameters == loaded_num_parameters, \ 198 | f"Number of parameters mismatch: {original_num_parameters:,} vs {loaded_num_parameters:,}" 199 | 200 | 201 | def test_model_from_pretrained_works(): 202 | """Test that the model can be saved and loaded without 
issues.""" 203 | model, processor, collator = setup_tiny_model() 204 | 205 | with tempfile.TemporaryDirectory() as temp_dir: 206 | original_num_parameters = num_model_params(model) 207 | model.save_pretrained(save_directory=temp_dir, push_to_hub=False) 208 | 209 | new_model = WordLatentTransformer.from_pretrained(temp_dir) 210 | loaded_num_parameters = num_model_params(new_model) 211 | 212 | assert original_num_parameters == loaded_num_parameters, \ 213 | f"Number of parameters mismatch: {original_num_parameters:,} vs {loaded_num_parameters:,}" 214 | 215 | 216 | def test_freeze_unfreeze_model_works(): 217 | """Test that freezing the model works correctly.""" 218 | model, processor, collator = setup_tiny_model() 219 | 220 | model.freeze_pretrained_models() 221 | 222 | for name, param in model.latent_transformer.named_parameters(): 223 | assert not param.requires_grad, f"Parameter {name} should be frozen but is unfrozen." 224 | 225 | for layer in [model.encoder_mapping, model.decoder_mapping]: 226 | for name, param in layer.named_parameters(): 227 | assert param.requires_grad, f"Parameter {name} should be unfrozen but is frozen." 228 | 229 | model.unfreeze() 230 | 231 | for name, param in model.named_parameters(): 232 | assert param.requires_grad, f"Parameter {name} should be unfrozen but is frozen." 233 | 234 | 235 | def test_model_from_pretrained_works_without_image_encoder(): 236 | """Test that the model can be saved and loaded without issues.""" 237 | model, processor, collator = setup_model(image_encoder_name=None) 238 | 239 | with tempfile.TemporaryDirectory() as temp_dir: 240 | original_num_parameters = num_model_params(model) 241 | model.save_pretrained(save_directory=temp_dir, push_to_hub=False) 242 | 243 | new_model = WordLatentTransformer.from_pretrained(temp_dir) 244 | loaded_num_parameters = num_model_params(new_model) 245 | 246 | assert original_num_parameters == loaded_num_parameters, \ 247 | f"Number of parameters mismatch: {original_num_parameters:,} vs {loaded_num_parameters:,}" 248 | 249 | 250 | def test_model_from_pretrained_works_without_bytes_encoder(): 251 | """Test that the model can be saved and loaded without bytes encoder.""" 252 | model, processor, collator = setup_model(bytes_encoder_name=None) 253 | 254 | with tempfile.TemporaryDirectory() as temp_dir: 255 | original_num_parameters = num_model_params(model) 256 | model.save_pretrained(save_directory=temp_dir, push_to_hub=False) 257 | 258 | new_model = WordLatentTransformer.from_pretrained(temp_dir) 259 | loaded_num_parameters = num_model_params(new_model) 260 | 261 | assert original_num_parameters == loaded_num_parameters, \ 262 | f"Number of parameters mismatch: {original_num_parameters:,} vs {loaded_num_parameters:,}" 263 | 264 | 265 | if __name__ == "__main__": 266 | pytest.main([__file__, "-v"]) 267 | -------------------------------------------------------------------------------- /tests/test_processor.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tempfile 3 | 4 | import pytest 5 | import torch 6 | from datasets import Dataset 7 | from font_download import FontConfig 8 | from font_download.example_fonts.noto_sans import FONTS_NOTO_SANS 9 | from pixel_renderer import PixelRendererProcessor 10 | from trl.data_utils import pack_dataset 11 | from utf8_tokenizer.control import ControlTokens 12 | from utf8_tokenizer.tokenizer import UTF8Tokenizer 13 | from words_segmentation.tokenizer import WordsSegmentationTokenizer 14 | 15 | from 
tests.test_model import setup_tiny_model 16 | from welt.noop import NoopImageProcessor 17 | from welt.processor import TextImageProcessor 18 | 19 | 20 | @pytest.fixture(scope="module") 21 | def processor(): 22 | model, processor, collator = setup_tiny_model() 23 | return processor 24 | 25 | 26 | @pytest.fixture(scope="module") 27 | def renderer(): 28 | font_config = FontConfig(sources=FONTS_NOTO_SANS) 29 | return PixelRendererProcessor(font=font_config) 30 | 31 | 32 | expected_tensor_keys = ["input_ids", "input_attention_mask", "attention_mask", "position_ids", 33 | "labels_input", "labels_attention_mask", "labels_output", 34 | "input_images", "input_images_dimensions"] 35 | expected_keys = expected_tensor_keys 36 | 37 | 38 | def test_processor_multiprocessing_pickle(processor): 39 | # Processor should be pickleable for multiprocessing 40 | pickle.dumps(processor) 41 | 42 | 43 | def test_processor_single_text_collated(processor): 44 | text = "example text for testing" 45 | inputs = processor(text, collated=True) 46 | 47 | assert all(key in inputs for key in expected_keys) 48 | assert all(isinstance(inputs[key], torch.Tensor) for key in expected_tensor_keys) 49 | 50 | 51 | def test_processor_single_text_not_collated(processor): 52 | text = "example text for testing" 53 | inputs = processor(text) 54 | assert all(key in inputs for key in expected_keys) 55 | assert all(isinstance(inputs[key], list) and len(inputs[key]) == 1 for key in expected_tensor_keys) 56 | 57 | 58 | def test_processor_single_text_value(processor): 59 | text = "a b" 60 | inputs = processor(text, packed=True) 61 | assert torch.equal(inputs["input_ids"][0], torch.tensor([[2, 2, 3, 0], [2, 97, 32, 3], [2, 98, 32, 3]])) 62 | assert inputs["input_attention_mask"][0].shape == (3, 4) 63 | assert inputs["attention_mask"][0].shape == (1, 3, 3) 64 | assert torch.equal(inputs["position_ids"][0], torch.tensor([0, 1, 2])) 65 | assert torch.equal(inputs["labels_input"][0], torch.tensor([[2, 97, 32, 98], [2, 98, 3, 0], [2, 3, 0, 0]])) 66 | assert torch.equal(inputs["labels_output"][0], torch.tensor([[97, 32, 98, 3], [98, 3, 0, 0], [3, 0, 0, 0]])) 67 | 68 | 69 | def test_processor_list_format_collated(processor): 70 | text = "example text for testing" 71 | inputs = processor([text], collated=True) 72 | assert all(key in inputs for key in expected_keys) 73 | assert all(isinstance(inputs[key], torch.Tensor) for key in expected_tensor_keys) 74 | 75 | 76 | def test_processor_object_format_collated(processor): 77 | text = "example text for testing" 78 | inputs = processor({"text": text}, collated=True) 79 | assert all(key in inputs for key in expected_keys) 80 | assert all(isinstance(inputs[key], torch.Tensor) for key in expected_tensor_keys) 81 | 82 | 83 | def test_processor_multiple_strings_collated_attention_mask(processor): 84 | texts = ["one", "two words", "three word test"] 85 | inputs = processor(texts, collated=True) 86 | assert all(key in inputs for key in expected_keys) 87 | assert all(isinstance(inputs[key], torch.Tensor) for key in expected_tensor_keys) 88 | 89 | assert inputs["attention_mask"].shape == (3, 1, 4, 4) 90 | 91 | expected = [ 92 | torch.tensor([ 93 | [True, False, False, False], 94 | [True, True, False, False], 95 | [False, False, False, False], 96 | [False, False, False, False] 97 | ]), 98 | torch.tensor([ 99 | [True, False, False, False], 100 | [True, True, False, False], 101 | [True, True, True, False], 102 | [False, False, False, False] 103 | ]), 104 | torch.tensor([ 105 | [True, False, False, False], 106 | [True, 
True, False, False], 107 | [True, True, True, False], 108 | [True, True, True, True] 109 | ]) 110 | ] 111 | 112 | for mask, expected_mask in zip(inputs["attention_mask"], expected, strict=False): 113 | assert torch.equal(mask[0], expected_mask) 114 | 115 | 116 | def test_processor_packed_vs_unpacked_labels(processor): 117 | text = "hello world test" 118 | 119 | # Test packed=True 120 | inputs_packed = processor(text, collated=True, packed=True) 121 | 122 | # Test packed=False 123 | inputs_unpacked = processor(text, collated=True, packed=False) 124 | 125 | # Both should have same structure 126 | assert all(key in inputs_packed for key in expected_keys) 127 | assert all(key in inputs_unpacked for key in expected_keys) 128 | 129 | # Input tokens should be the same 130 | assert torch.equal(inputs_packed["input_ids"], inputs_unpacked["input_ids"]) 131 | assert torch.equal(inputs_packed["attention_mask"], inputs_unpacked["attention_mask"]) 132 | 133 | # Labels should be different due to different packing strategies 134 | assert not torch.equal(inputs_packed["labels_input"], inputs_unpacked["labels_input"]) 135 | assert not torch.equal(inputs_packed["labels_output"], inputs_unpacked["labels_output"]) 136 | 137 | 138 | def test_processor_packed_false_default_behavior(processor): 139 | text = "example text for testing" 140 | 141 | # Omitting `packed` should behave like packed=False 142 | inputs_default = processor(text, collated=True) 143 | inputs_explicit_unpacked = processor(text, collated=True, packed=False) 144 | 145 | # Should be identical 146 | assert torch.equal(inputs_default["labels_input"], inputs_explicit_unpacked["labels_input"]) 147 | assert torch.equal(inputs_default["labels_output"], inputs_explicit_unpacked["labels_output"]) 148 | 149 | 150 | def test_get_words_and_labels_packed_vs_unpacked(processor): 151 | text = "hello world test" 152 | words = processor.pretokenize(text) 153 | 154 | # Test packed=True 155 | labels_packed = processor.get_sequence_labels(words, pack=True) 156 | 157 | # Test packed=False 158 | labels_unpacked = processor.get_sequence_labels(words, pack=False) 159 | 160 | # Labels should be different 161 | assert labels_packed != labels_unpacked 162 | 163 | assert labels_packed == ['hello world test', 'world test', 'test', ''] 164 | assert labels_unpacked == ['hello ', 'world ', 'test', ''] 165 | 166 | 167 | def test_render_images_shape(processor): 168 | texts = ["short", "a bit longer text"] 169 | renders, dimensions = processor.render_texts(texts) 170 | 171 | assert renders.shape == (2, 3, 16, 112) 172 | assert torch.equal(dimensions, torch.tensor([[16, 48], [16, 112]])) 173 | 174 | 175 | def test_pretokenize_splits_control_tokens(processor): 176 | text = (f"{ControlTokens.ShiftOut}test{ControlTokens.ShiftIn}" 177 | f"{ControlTokens.StartOfHeading}hello {ControlTokens.EndOfText}") 178 | words = processor.pretokenize(text) 179 | assert words == [ 180 | ControlTokens.StartOfText, # BOS is added by pretokenize 181 | ControlTokens.ShiftOut, 'test', ControlTokens.ShiftIn, 182 | ControlTokens.StartOfHeading, "hello ", ControlTokens.EndOfText, 183 | " " # Space is added by pretokenize 184 | ] 185 | 186 | 187 | def test_pretokenize_multiple_whitespace(processor): 188 | text = """ 189 | def foo(): 190 | return "bar" 191 | """.strip() 192 | words = processor.pretokenize(text) 193 | assert words == [ControlTokens.StartOfText, "def ", "foo():\n", " " * 8, 'return ', '"bar" '] 194 | 195 | 196 | def test_get_words_and_labels_packed_vs_unpacked_respect_max_word_length(processor, renderer): 197 
| text = "this is a long-test" 198 | words = processor.pretokenize(text) 199 | 200 | new_processor = TextImageProcessor( 201 | pretokenizer=WordsSegmentationTokenizer(), 202 | tokenizer=processor.tokenizer, 203 | renderer=renderer, 204 | image_processor=processor.image_processor, 205 | max_word_length=3 206 | ) 207 | 208 | # Test packed=True 209 | labels_packed = new_processor.get_sequence_labels(words, pack=True) 210 | 211 | # Test packed=False 212 | labels_unpacked = new_processor.get_sequence_labels(words, pack=False) 213 | 214 | assert labels_packed == ['thi', 'is', 'a l', 'lon', ''] 215 | assert labels_unpacked == ['thi', 'is ', 'a ', 'lon', ''] 216 | 217 | 218 | def test_pretokenize_dataset(processor): 219 | texts = [ 220 | "hi!", 221 | "hello world", 222 | ] 223 | dataset = Dataset.from_dict({"text": texts}) 224 | dataset = processor.pretokenize_dataset(dataset) 225 | 226 | assert dataset[:] == { 227 | 'words': [ 228 | [ControlTokens.StartOfText, 'hi! '], 229 | [ControlTokens.StartOfText, 'hello ', 'world '], 230 | ], 231 | } 232 | 233 | 234 | def test_packed_dataset(processor): 235 | texts = [ 236 | "hi!", 237 | "hello world", 238 | "yes.", 239 | "a b c" 240 | ] 241 | dataset = Dataset.from_dict({"text": texts}) 242 | dataset = processor.pretokenize_dataset(dataset) 243 | packed_dataset = pack_dataset(dataset, seq_length=7) 244 | 245 | assert packed_dataset[:] == { 246 | 'seq_lengths': [ 247 | [4, 3], 248 | [2, 2], 249 | ], 250 | 'words': [ 251 | [ 252 | ControlTokens.StartOfText, 'a ', 'b ', 'c ', 253 | ControlTokens.StartOfText, 'hello ', 'world ', 254 | ], 255 | [ 256 | ControlTokens.StartOfText, 'hi! ', 257 | ControlTokens.StartOfText, 'yes. ', 258 | ], 259 | ], 260 | } 261 | 262 | 263 | def test_packed_dataset_labels_independent(processor): 264 | texts = [ 265 | "a b", 266 | "c d", 267 | ] 268 | dataset = Dataset.from_dict({"text": texts}) 269 | dataset = processor.pretokenize_dataset(dataset) 270 | packed_dataset = pack_dataset(dataset, seq_length=8) 271 | 272 | datum = next(iter(packed_dataset)) 273 | labels = processor.get_sequence_labels(datum["words"], datum["seq_lengths"], pack=True) 274 | 275 | assert labels == [ 276 | 'a b', 'b', '', 277 | 'c d', 'd', '' 278 | ] 279 | 280 | 281 | def test_processor_works_on_packed_sequence(processor): 282 | texts = [ 283 | "hi!", 284 | "hello world", 285 | "yes.", 286 | "a b c" 287 | ] 288 | dataset = Dataset.from_dict({"text": texts}) 289 | dataset = processor.pretokenize_dataset(dataset) 290 | packed_dataset = pack_dataset(dataset, seq_length=8) 291 | 292 | transformed_dataset = packed_dataset.with_transform(processor) 293 | for inputs in transformed_dataset: 294 | assert all(key in inputs for key in expected_keys) 295 | assert all(isinstance(inputs[key], torch.Tensor) for key in expected_tensor_keys) 296 | 297 | 298 | def test_processor_save_and_load_works(processor): 299 | with tempfile.TemporaryDirectory() as temp_dir: 300 | processor.save_pretrained(save_directory=temp_dir, push_to_hub=False) 301 | new_processor = TextImageProcessor.from_pretrained(temp_dir) 302 | 303 | for attr in processor.attributes: 304 | assert getattr(new_processor, attr) is not None 305 | assert getattr(new_processor, attr).__class__.__name__ == getattr(processor, attr).__class__.__name__ 306 | 307 | 308 | def test_processor_save_and_load_works_without_image_processor(renderer): 309 | processor = TextImageProcessor( 310 | pretokenizer=WordsSegmentationTokenizer(), 311 | tokenizer=UTF8Tokenizer(), 312 | renderer=renderer, 313 | 
image_processor=NoopImageProcessor()) 314 | 315 | with tempfile.TemporaryDirectory() as temp_dir: 316 | 317 | processor.save_pretrained(save_directory=temp_dir, push_to_hub=False) 318 | new_processor = TextImageProcessor.from_pretrained(temp_dir) 319 | assert isinstance(new_processor.image_processor, NoopImageProcessor) 320 | 321 | 322 | def test_labels_masked_in_shift_blocks_packed(processor): 323 | """Test that labels are empty for tokens inside shift blocks (except ShiftIn itself).""" 324 | # Use f-string template and let processor segment into words 325 | text = f"{ControlTokens.ShiftOut}hello{ControlTokens.ShiftIn} שלום" 326 | words = processor.pretokenize(text) 327 | 328 | labels = processor.get_sequence_labels(words, pack=False) 329 | 330 | # Expected: BOS, "", SO, "hello", SI, " ", "שלום " 331 | # Labels should be empty for tokens inside shift blocks (SO and "hello") 332 | # ShiftIn keeps its label to predict next word 333 | 334 | # Check exact label content 335 | assert labels[0] == "" # BOS -> "" 336 | assert labels[1] == ControlTokens.ShiftOut # "" -> SO 337 | assert labels[2] == "" # SO -> "hello" (inside block, masked) 338 | assert labels[3] == "" # "hello" -> SI (inside block, masked) 339 | assert labels[4] == " " # SI -> " " (exits block, has label) 340 | assert labels[5] == "שלום" # " " -> "שלום " (rstripped) 341 | assert labels[6] == "" # Last token always empty 342 | 343 | 344 | def test_labels_masked_in_shift_blocks_unpacked(processor): 345 | """Test that labels are empty for tokens inside shift blocks in unpacked mode.""" 346 | words = [ 347 | ControlTokens.StartOfText, 348 | "", ControlTokens.ShiftOut, "hello", ControlTokens.ShiftIn, 349 | "", "שלום" 350 | ] 351 | 352 | labels = processor.get_sequence_labels(words, pack=False) 353 | 354 | # Expected behavior (unpacked mode): 355 | # - Each token predicts the next token 356 | # - Inside shift blocks, labels should be empty except for ShiftIn 357 | 358 | assert labels[0] # BOS -> "" 359 | assert labels[1] # "" -> ShiftOut 360 | assert labels[2] == "" # ShiftOut -> "hello" (inside block, no label) 361 | assert labels[3] == "" # "hello" -> ShiftIn (inside block, no label) 362 | assert labels[4] # ShiftIn -> "" (exits block, should have label) 363 | assert labels[5] # "" -> "שלום" 364 | assert labels[6] == "" # "שלום" -> (second-to-last, rstripped) 365 | 366 | 367 | def test_multiple_shift_blocks(processor): 368 | """Test handling of multiple shift blocks in a sequence.""" 369 | words = [ 370 | ControlTokens.StartOfText, 371 | ControlTokens.ShiftOut, "first", "block", ControlTokens.ShiftIn, 372 | "middle", "token", 373 | ControlTokens.ShiftOut, "second", "block", ControlTokens.ShiftIn, 374 | "end" 375 | ] 376 | 377 | labels = processor.get_sequence_labels(words, pack=True) 378 | 379 | # ShiftOut and content inside blocks should have empty labels 380 | # ShiftIn tokens should have labels (to predict next word) 381 | assert labels[0] # BOS 382 | assert labels[1] == "" # ShiftOut (first block) 383 | assert labels[2] == "" # "first" (inside block) 384 | assert labels[3] == "" # "block" (inside block) 385 | assert labels[4] # ShiftIn (exits first block) 386 | assert labels[5] # "middle" (normal token) 387 | assert labels[6] # "token" (normal token) 388 | assert labels[7] == "" # ShiftOut (second block) 389 | assert labels[8] == "" # "second" (inside block) 390 | assert labels[9] == "" # "block" (inside block) 391 | assert labels[10] # ShiftIn (exits second block) 392 | assert labels[11] == "" # "end" 
(second-to-last, rstripped) 393 | 394 | 395 | if __name__ == "__main__": 396 | pytest.main([__file__, "-v"]) 397 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | # Heavily adapted from 2 | # https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py 3 | import logging 4 | import math 5 | import os 6 | import sys 7 | 8 | import datasets 9 | import torch 10 | import transformers 11 | from datasets import IterableDataset, IterableDatasetDict, load_dataset 12 | from safetensors.torch import load_model 13 | from transformers import ( 14 | HfArgumentParser, 15 | TrainingArguments, 16 | set_seed, 17 | ) 18 | from transformers.trainer_utils import get_last_checkpoint 19 | from trl import pack_dataset 20 | 21 | from training.args_data import DataTrainingArguments 22 | from training.args_model import ModelArguments 23 | from training.args_trainer import WeLTTrainingArguments 24 | from training.extendable_yaml import resolve_yaml_file 25 | from training.freeze_callback import FreezeWarmupCallback 26 | from training.trainer import WeLTTrainer 27 | from welt.model_utils import setup_model 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | def enable_optimizations(): 33 | torch.backends.cudnn.benchmark = True 34 | torch.backends.cuda.enable_flash_sdp(True) 35 | torch.backends.cuda.enable_mem_efficient_sdp(True) 36 | torch.backends.cuda.enable_math_sdp(False) 37 | 38 | # For debugging purposes only: 39 | # torch.autograd.set_detect_anomaly(True) 40 | 41 | # TODO --use_cuda_graphs true ? 42 | 43 | # TODO pin_memory_device="cuda" (PyTorch ≥2.3) 44 | 45 | # TODO 46 | # torch_compile=True, 47 | # torch_compile_backend="inductor", 48 | # torch_compile_mode="default", 49 | 50 | # TODO use accelerate launch 51 | 52 | 53 | def split_streaming_dataset( 54 | full_streaming_dataset, 55 | validation_percentage: int = 5, 56 | ) -> IterableDatasetDict: 57 | """ 58 | Splits a streaming dataset into 59 | training and validation IterableDatasets; the resulting streams still support methods like 60 | .map(), .filter(), .take() and properties like .features. 61 | 62 | Args: 63 | full_streaming_dataset (IterableDataset): The streaming dataset to split (e.g., loaded from "HuggingFaceFW/fineweb"). 64 | validation_percentage (int): The percentage of examples routed to the validation split. 65 | 66 | Returns: 67 | IterableDatasetDict: An IterableDatasetDict with a "train" 68 | and a "validation" IterableDataset stream. 69 | """ 70 | if not (0 < validation_percentage < 100): 71 | raise ValueError( # noqa: TRY003 72 | f"validation_percentage must be between 0 and 100 (exclusive). 
Passed: {validation_percentage}" 73 | ) 74 | 75 | def split_generator(is_train: bool): 76 | for i, example in enumerate(full_streaming_dataset): 77 | if is_train: 78 | if i % 100 >= validation_percentage: 79 | yield example 80 | else: 81 | if i % 100 < validation_percentage: 82 | yield example 83 | 84 | features = full_streaming_dataset.features 85 | train_stream = IterableDataset.from_generator(split_generator, gen_kwargs={"is_train": True}, features=features) 86 | validation_stream = IterableDataset.from_generator( 87 | split_generator, gen_kwargs={"is_train": False}, features=features 88 | ) 89 | 90 | return IterableDatasetDict({"train": train_stream, "validation": validation_stream}) 91 | 92 | 93 | def parse_args_into_dataclasses(args: list[str] | None | str = None): 94 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, WeLTTrainingArguments)) 95 | # If we pass only one argument to the script and it's the path to a json or yaml file, 96 | # let's parse it to get our arguments. 97 | if isinstance(args, str): 98 | resolved_path = resolve_yaml_file(os.path.abspath(args)) 99 | return parser.parse_yaml_file(yaml_file=resolved_path) 100 | 101 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 102 | return parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 103 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 104 | resolved_path = resolve_yaml_file(os.path.abspath(sys.argv[1])) 105 | return parser.parse_yaml_file(yaml_file=resolved_path) 106 | else: 107 | return parser.parse_args_into_dataclasses(args=args) 108 | 109 | 110 | def init_logging(training_args: TrainingArguments): 111 | logging.basicConfig( 112 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 113 | datefmt="%m/%d/%Y %H:%M:%S", 114 | handlers=[logging.StreamHandler(sys.stdout)], 115 | ) 116 | 117 | if training_args.should_log: 118 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 119 | transformers.utils.logging.set_verbosity_info() 120 | 121 | log_level = training_args.get_process_log_level() 122 | logger.setLevel(log_level) 123 | datasets.utils.logging.set_verbosity(log_level) 124 | transformers.utils.logging.set_verbosity(log_level) 125 | transformers.utils.logging.enable_default_handler() 126 | transformers.utils.logging.enable_explicit_format() 127 | 128 | # Log a short summary on each process: 129 | logger.warning( 130 | f"Process rank: {training_args.local_rank}, " + 131 | f"device: {training_args.device}, " + 132 | f"n_gpu: {training_args.n_gpu}, " + 133 | f"distributed training: {training_args.parallel_mode.value == 'distributed'}, " + 134 | f"16-bits training: {training_args.fp16}" 135 | ) 136 | logger.info(f"Training/evaluation parameters {training_args}") 137 | 138 | 139 | def init_model(model_args: ModelArguments, data_args: DataTrainingArguments, seed: int): 140 | # Set seed before initializing model. 
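# (the seed is also passed to setup_model below so that component initialization is reproducible)
# setup_model assembles the model, processor and data collator from the component checkpoints
# configured in ModelArguments: image encoder, bytes encoder, latent transformer and bytes decoder.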
141 | set_seed(seed) 142 | 143 | # Initialize the model 144 | model, processor, collator = setup_model( 145 | image_encoder_name=model_args.image_encoder_model_name_or_path, 146 | bytes_encoder_name=model_args.bytes_encoder_model_name_or_path, 147 | latent_transformer_name=model_args.latent_transformer_model_name_or_path, 148 | bytes_decoder_name=model_args.bytes_decoder_model_name_or_path, 149 | trust_remote_code=model_args.trust_remote_code, 150 | dtype=model_args.dtype, 151 | seed=seed, 152 | load_pretrained=model_args.load_pretrained, 153 | max_word_length=data_args.max_word_length, 154 | pretokenizer_name=model_args.pretokenizer_name, 155 | ) 156 | 157 | # Load the model from a local path if provided 158 | if model_args.model_name_or_path: 159 | load_model(model, model_args.model_name_or_path) 160 | 161 | return model, processor, collator 162 | 163 | 164 | def detect_last_checkpoint(training_args: TrainingArguments): 165 | last_checkpoint = None 166 | if os.path.isdir(training_args.output_dir) and training_args.do_train: 167 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 168 | 169 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 170 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") 171 | 172 | return last_checkpoint 173 | 174 | 175 | def init_datasets(data_args: DataTrainingArguments, # noqa: C901 176 | trust_remote_code: bool, 177 | do_train: bool = True, 178 | cache_dir: str = None): 179 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 180 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 181 | # (the dataset will be downloaded automatically from the datasets Hub). 182 | # 183 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 184 | # 'text' is found. You can easily tweak this behavior (see below). 185 | # 186 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 187 | # download the dataset. 188 | 189 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 190 | # https://huggingface.co/docs/datasets/loading_datasets. 191 | 192 | if data_args.dataset_name is not None: 193 | # Downloading and loading a dataset from the hub. 
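# If the hub dataset ships no "validation" split, one is carved out of "train" further below:
# streaming datasets are split on the fly with split_streaming_dataset, non-streaming datasets
# via percentage slices (train[:N%] / train[N%:]) using validation_split_percentage.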
194 | raw_datasets = load_dataset( 195 | data_args.dataset_name, 196 | data_args.dataset_config_name, 197 | cache_dir=cache_dir, 198 | streaming=data_args.streaming, 199 | trust_remote_code=trust_remote_code, 200 | ) 201 | if "validation" not in raw_datasets: 202 | if data_args.streaming: 203 | dataset_stream = load_dataset( 204 | data_args.dataset_name, 205 | data_args.dataset_config_name, 206 | split="train", 207 | cache_dir=cache_dir, 208 | streaming=data_args.streaming, 209 | trust_remote_code=trust_remote_code, 210 | ) 211 | raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage) 212 | else: 213 | raw_datasets["validation"] = load_dataset( 214 | data_args.dataset_name, 215 | data_args.dataset_config_name, 216 | split=f"train[:{data_args.validation_split_percentage}%]", 217 | cache_dir=cache_dir, 218 | streaming=data_args.streaming, 219 | trust_remote_code=trust_remote_code, 220 | ) 221 | raw_datasets["train"] = load_dataset( 222 | data_args.dataset_name, 223 | data_args.dataset_config_name, 224 | split=f"train[{data_args.validation_split_percentage}%:]", 225 | cache_dir=cache_dir, 226 | streaming=data_args.streaming, 227 | trust_remote_code=trust_remote_code, 228 | ) 229 | else: 230 | data_files = {} 231 | dataset_args = {} 232 | if data_args.train_file is not None: 233 | data_files["train"] = data_args.train_file 234 | if data_args.validation_file is not None: 235 | data_files["validation"] = data_args.validation_file 236 | extension = ( 237 | data_args.train_file.split(".")[-1] 238 | if data_args.train_file is not None 239 | else data_args.validation_file.split(".")[-1] 240 | ) 241 | if extension == "txt": 242 | extension = "text" 243 | dataset_args["keep_linebreaks"] = data_args.keep_linebreaks 244 | raw_datasets = load_dataset( 245 | extension, 246 | data_files=data_files, 247 | cache_dir=cache_dir, 248 | **dataset_args, 249 | ) 250 | # If no validation data is there, validation_split_percentage will be used to divide the dataset. 
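# (same strategy as the hub branch above: split_streaming_dataset for streaming data,
# percentage slices of the "train" split otherwise)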
251 | if "validation" not in raw_datasets: 252 | if data_args.streaming: 253 | dataset_stream = load_dataset( 254 | extension, 255 | data_files=data_files, 256 | split="train", 257 | cache_dir=cache_dir, 258 | **dataset_args, 259 | ) 260 | raw_datasets = split_streaming_dataset(dataset_stream, data_args.validation_split_percentage) 261 | else: 262 | raw_datasets["validation"] = load_dataset( 263 | extension, 264 | data_files=data_files, 265 | split=f"train[:{data_args.validation_split_percentage}%]", 266 | cache_dir=cache_dir, 267 | **dataset_args, 268 | ) 269 | 270 | raw_datasets["train"] = load_dataset( 271 | extension, 272 | data_files=data_files, 273 | split=f"train[{data_args.validation_split_percentage}%:]", 274 | cache_dir=cache_dir, 275 | **dataset_args, 276 | ) 277 | 278 | if do_train: 279 | column_names = list(raw_datasets["train"].features) 280 | else: 281 | column_names = list(raw_datasets["validation"].features) 282 | text_column_name = "text" if "text" in column_names else column_names[0] 283 | 284 | def process_split(dataset, split_name: str): 285 | """Apply mapping and filtering to a dataset split.""" 286 | template = data_args.dataset_text_template 287 | if template is None: 288 | def mapping_fn(example): 289 | return {"text": example[text_column_name]} 290 | else: 291 | is_single_text_template = isinstance(template, str) 292 | single_text_template = template \ 293 | if is_single_text_template else "".join(template) 294 | 295 | def mapping_fn(example): 296 | if is_single_text_template or split_name == "train": 297 | return {"text": single_text_template.format(**example)} 298 | 299 | prefix = template[0].format(**example) 300 | completion = template[1].format(**example) 301 | return { 302 | "text": f"{prefix}{completion}", # Full text for training loss calculation 303 | "prefix": prefix, # For generation 304 | "completion": completion, # Reference for metrics 305 | } 306 | 307 | map_args = {} 308 | if not data_args.streaming: 309 | map_args = { 310 | "num_proc": data_args.preprocessing_num_workers, 311 | "load_from_cache_file": not data_args.overwrite_cache, 312 | } 313 | 314 | dataset = dataset.map( 315 | mapping_fn, 316 | remove_columns=column_names, 317 | desc=f"Formatting {split_name} split", 318 | **map_args 319 | ) 320 | dataset = dataset.filter( 321 | lambda x: len(x["text"]) > 0, 322 | desc=f"Filtering empty examples from {split_name}", 323 | **map_args 324 | ) 325 | return dataset 326 | 327 | return {split: process_split(raw_datasets[split], split) for split in raw_datasets} 328 | 329 | 330 | def limit_dataset_size(dataset, max_samples: int | None = None, streaming: bool = False): 331 | if max_samples is not None: 332 | if streaming: 333 | dataset = dataset.take(max_samples) 334 | elif max_samples < len(dataset): 335 | dataset = dataset.select(range(max_samples)) 336 | 337 | return dataset 338 | 339 | 340 | 341 | 342 | def train(args: list[str] | None | str = None): # noqa: C901 343 | cache_dir = None # Use the default cache directory / Environment variable 344 | 345 | enable_optimizations() 346 | 347 | model_args, data_args, training_args = parse_args_into_dataclasses(args) 348 | 349 | init_logging(training_args) 350 | 351 | # Detecting last checkpoint. 
352 | last_checkpoint = detect_last_checkpoint(training_args) 353 | 354 | # Initialize the model 355 | model, processor, collator = init_model(model_args, data_args, seed=training_args.seed) 356 | 357 | if data_args.max_sequence_length is not None: 358 | processor.max_seq_length = data_args.max_sequence_length 359 | 360 | # Save the processor to the output directory 361 | processor.save_pretrained(save_directory=training_args.output_dir, push_to_hub=False) 362 | 363 | # Load the datasets 364 | text_datasets = init_datasets(data_args, 365 | cache_dir=cache_dir, 366 | trust_remote_code=model_args.trust_remote_code, 367 | do_train=training_args.do_train) 368 | 369 | train_dataset = None 370 | if training_args.do_train: 371 | if "train" not in text_datasets: 372 | raise ValueError("--do_train requires a train dataset") # noqa: TRY003 373 | train_dataset = limit_dataset_size(text_datasets["train"], 374 | max_samples=data_args.max_train_samples, 375 | streaming=data_args.streaming) 376 | 377 | eval_dataset = None 378 | if training_args.do_eval: 379 | if "validation" not in text_datasets: 380 | raise ValueError("--do_eval requires a validation dataset") # noqa: TRY003 381 | eval_dataset = limit_dataset_size(text_datasets["validation"], 382 | max_samples=data_args.max_eval_samples, 383 | streaming=data_args.streaming) 384 | 385 | # Sequence packing 386 | if train_dataset: 387 | block_size = min(data_args.block_size or math.inf, processor.max_seq_length) 388 | train_dataset = processor.pretokenize_dataset(train_dataset) 389 | train_dataset = pack_dataset(train_dataset, seq_length=block_size) 390 | 391 | # Transform the datasets to the format expected by the model 392 | if train_dataset: 393 | train_dataset = train_dataset.with_transform(processor) 394 | if eval_dataset: 395 | eval_dataset = eval_dataset.with_transform(processor) 396 | 397 | # Initialize our Trainer 398 | # Note: WeLTTrainer computes accuracy and generation-based metrics internally 399 | trainer = WeLTTrainer( 400 | model=model, 401 | args=training_args, 402 | processor=processor, 403 | train_dataset=train_dataset, 404 | eval_dataset=eval_dataset, 405 | data_collator=collator, 406 | # Generation-based evaluation settings from training args 407 | eval_metrics=training_args.eval_metrics if training_args.do_eval else None, 408 | max_generated_words=training_args.generation_max_length or 50, 409 | log_samples=training_args.log_samples, 410 | ) 411 | 412 | # Freeze the pretrained models for some steps 413 | trainer.add_callback(FreezeWarmupCallback(steps=model_args.warmup_freeze_steps, model=model)) 414 | 415 | # Training 416 | if training_args.do_train: 417 | checkpoint = None 418 | if training_args.resume_from_checkpoint is not None: 419 | checkpoint = training_args.resume_from_checkpoint 420 | elif last_checkpoint is not None: 421 | checkpoint = last_checkpoint 422 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 423 | trainer.save_model() 424 | 425 | metrics = train_result.metrics 426 | 427 | max_train_samples = ( 428 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 429 | ) 430 | if data_args.streaming: 431 | metrics["train_samples"] = max_train_samples 432 | else: 433 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 434 | 435 | trainer.log_metrics("train", metrics) 436 | trainer.save_metrics("train", metrics) 437 | trainer.save_state() 438 | 439 | # Evaluation 440 | if training_args.do_eval: 441 | logger.info("*** Evaluate ***") 442 | 443 | 
metrics = trainer.evaluate() 444 | 445 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 446 | if data_args.streaming: 447 | metrics["eval_samples"] = max_eval_samples 448 | else: 449 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 450 | 451 | try: 452 | perplexity = math.exp(metrics["eval_loss"]) 453 | except OverflowError: 454 | perplexity = float("inf") 455 | metrics["perplexity"] = perplexity 456 | 457 | trainer.log_metrics("eval", metrics) 458 | trainer.save_metrics("eval", metrics) 459 | 460 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 461 | if data_args.dataset_name is not None: 462 | kwargs["dataset_tags"] = data_args.dataset_name 463 | if data_args.dataset_config_name is not None: 464 | kwargs["dataset_args"] = data_args.dataset_config_name 465 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 466 | else: 467 | kwargs["dataset"] = data_args.dataset_name 468 | 469 | if training_args.push_to_hub: 470 | trainer.push_to_hub(**kwargs) 471 | else: 472 | trainer.create_model_card(**kwargs) 473 | 474 | 475 | if __name__ == "__main__": 476 | train() 477 | --------------------------------------------------------------------------------