├── .python-version
├── CLAUDE.md
├── scripts
│   ├── __init__.py
│   ├── hf_utils
│   │   ├── update_standalone.py
│   │   └── hf_model_process_check.py
│   ├── utils
│   │   └── sync_output_modeling.py
│   ├── context-relevance-datasets
│   │   ├── upload_context_relevance_to_hf.py
│   │   ├── frequency_filter_ds.py
│   │   └── add_reranker_teacher_scores.py
│   └── eval_mldr
│       └── ignored_questions.yaml
├── .env.sample
├── open_provence
│   ├── utils
│   │   ├── __init__.py
│   │   ├── modeling_export.py
│   │   └── model_architecture.py
│   ├── models
│   │   ├── __init__.py
│   │   └── open_provence_head.py
│   ├── trainer_cli.py
│   ├── modeling_open_provence_transformers.py
│   ├── __init__.py
│   ├── data_structures.py
│   └── losses.py
├── configs
│   ├── eval_datasets
│   │   ├── en.yaml
│   │   ├── en_nano.yaml
│   │   ├── ja.yaml
│   │   └── ja_nano.yaml
│   ├── open-provence-reranker-v1-gte-modernbert-base.yaml
│   ├── toy-open-provence-reranker-v1-gte-modernbert-base.yaml
│   ├── toy-open-provence-reranker-v1.yaml
│   ├── open-provence-reranker-large-v1.yaml
│   ├── open-provence-reranker-v1.yaml
│   └── open-provence-reranker-xsmall-v1.yaml
├── tox.ini
├── LICENSE
├── .github
│   └── workflows
│       └── ci.yaml
├── tests
│   ├── utils
│   │   ├── test_modeling_export.py
│   │   └── test_model_architecture.py
│   ├── test_modeling_default_dtype.py
│   ├── test_checkpoint_resolution.py
│   ├── scripts
│   │   ├── test_generate_ds_from_sentense_transformer.py
│   │   └── test_sync_output_modeling.py
│   ├── test_data_structures.py
│   ├── test_items_sampling.py
│   ├── test_sequential_fragmentize.py
│   ├── test_tokenizer_special_tokens.py
│   ├── test_eval_mldr_official.py
│   └── test_trainer_sampling.py
├── .gitignore
├── pyproject.toml
├── docs
│   ├── eval_dataset.md
│   ├── eval_mldr.md
│   └── train.md
└── AGENTS.md

/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 | 
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | @AGENTS.md
2 | 
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """Helper package for CLI utilities."""
2 | 
3 | from __future__ import annotations
4 | 
--------------------------------------------------------------------------------
/.env.sample:
--------------------------------------------------------------------------------
1 | # Environment variables for Open Provence tooling
2 | # Copy to .env and replace values with your own secrets.
3 | 
4 | # Required when using MLDR LLM-based evaluation or the Streamlit WebUI "LLM judge" features.
5 | OPENAI_API_KEY=sk-xxxxxxxxxx
6 | 
--------------------------------------------------------------------------------
/open_provence/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility modules for OpenProvence.
3 | """
4 | 
5 | from __future__ import annotations
6 | 
7 | from .model_architecture import ModelArchitectureUtils
8 | 
9 | __all__ = [
10 |     "ModelArchitectureUtils",
11 | ]
12 | 
--------------------------------------------------------------------------------
/open_provence/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Models for OpenProvence. 
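This subpackage re-exports ``OpenProvenceHead`` and ``OpenProvenceHeadConfig`` from ``open_provence_head`` (see the import below).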
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from .open_provence_head import OpenProvenceHead, OpenProvenceHeadConfig 8 | 9 | __all__ = ["OpenProvenceHead", "OpenProvenceHeadConfig"] 10 | -------------------------------------------------------------------------------- /open_provence/trainer_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Entry point for training OpenProvence models. 4 | 5 | This script delegates to the main runner module. 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from open_provence.runner import main 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /configs/eval_datasets/en.yaml: -------------------------------------------------------------------------------- 1 | # Evaluation datasets for English freq reranker checkpoints. 2 | split: test 3 | datasets: 4 | - dataset_name: "hotchpotch/msmarco-context-relevance" 5 | subset: "freq2" 6 | - dataset_name: "hotchpotch/natural-questions-context-relevance" 7 | subset: "nodup_freq2" 8 | - dataset_name: "hotchpotch/gooaq-context-relevance-130k" 9 | subset: "default" 10 | -------------------------------------------------------------------------------- /configs/eval_datasets/en_nano.yaml: -------------------------------------------------------------------------------- 1 | split: test 2 | datasets: 3 | - dataset_name: "hotchpotch/msmarco-context-relevance" 4 | subset: "freq2" 5 | n_samples: 100 6 | - dataset_name: "hotchpotch/natural-questions-context-relevance" 7 | subset: "nodup_freq2" 8 | n_samples: 100 9 | - dataset_name: "hotchpotch/gooaq-context-relevance-130k" 10 | subset: "default" 11 | n_samples: 100 12 | -------------------------------------------------------------------------------- /open_provence/utils/modeling_export.py: -------------------------------------------------------------------------------- 1 | """Helpers for exporting modeling_open_provence_standalone scripts.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | 7 | 8 | def write_modeling_open_provence( 9 | source: Path, 10 | destination: Path, 11 | ) -> None: 12 | """Copy modeling_open_provence_standalone.py without mutating its contents.""" 13 | 14 | destination.write_text(source.read_text(encoding="utf-8"), encoding="utf-8") 15 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | requires = tox-uv>=1.11.1 3 | envlist = pytests, lint, format-check, typecheck 4 | isolated_build = false 5 | skip_missing_interpreters = true 6 | 7 | [testenv] 8 | runner = uv-venv-lock-runner 9 | dependency_groups = 10 | dev 11 | cpu 12 | no_default_groups = true 13 | skip_install = true 14 | commands = python -c "raise SystemExit('Specify a concrete environment, e.g. 
`tox -e lint`')" # guard 15 | 16 | [testenv:pytests] 17 | commands = 18 | pytest --maxfail=1 --durations=5 -n auto --maxprocesses=4 --dist loadscope 19 | 20 | [testenv:lint] 21 | commands = 22 | ruff check open_provence tests scripts 23 | 24 | [testenv:format-check] 25 | commands = 26 | ruff format --check --diff open_provence tests scripts 27 | 28 | [testenv:typecheck] 29 | commands = 30 | pyright --threads 4 open_provence tests scripts 31 | -------------------------------------------------------------------------------- /open_provence/modeling_open_provence_transformers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compatibility shim: legacy imports for OpenProvence Hugging Face helpers. 3 | 4 | All functionality now resides in ``modeling_open_provence_standalone``. This module keeps the 5 | old import path working for downstream tooling that still references 6 | ``open_provence.modeling_open_provence_transformers``. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | from .modeling_open_provence_standalone import ( 12 | OpenProvenceConfig, 13 | OpenProvenceEncoderConfig, 14 | OpenProvenceEncoderForSequenceClassification, 15 | OpenProvenceEncoderForTokenClassification, 16 | OpenProvenceForSequenceClassification, 17 | OpenProvenceForTokenClassification, 18 | ) 19 | 20 | __all__ = [ 21 | "OpenProvenceConfig", 22 | "OpenProvenceForSequenceClassification", 23 | "OpenProvenceForTokenClassification", 24 | "OpenProvenceEncoderConfig", 25 | "OpenProvenceEncoderForSequenceClassification", 26 | "OpenProvenceEncoderForTokenClassification", 27 | ] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yuichi Tateno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /open_provence/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Query-dependent text pruning and reranking for efficient RAG pipelines. 3 | 4 | This module provides functionality for pruning irrelevant content from documents 5 | based on queries, with optional reranking capabilities. 
6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from .data_collator import OpenProvenceDataCollator 11 | from .data_structures import ( 12 | OpenProvenceConfig, 13 | OpenProvenceOnlyOutput, 14 | OpenProvenceOutput, 15 | RerankingOpenProvenceOutput, 16 | ) 17 | from .encoder import OpenProvenceEncoder 18 | from .losses import OpenProvenceLoss 19 | from .trainer import OpenProvenceTrainer 20 | 21 | # Import runner module at the end to avoid circular imports 22 | # It will be imported after other modules are initialized 23 | 24 | __all__ = [ 25 | "OpenProvenceConfig", 26 | "RerankingOpenProvenceOutput", 27 | "OpenProvenceOutput", 28 | "OpenProvenceOnlyOutput", 29 | "OpenProvenceEncoder", 30 | "OpenProvenceTrainer", 31 | "OpenProvenceLoss", 32 | "OpenProvenceDataCollator", 33 | "runner", 34 | ] 35 | 36 | # Import runner after other modules are initialized 37 | from . import runner 38 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | ci: 15 | runs-on: ubuntu-latest 16 | env: 17 | UV_PYTHON: "3.11" 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Set up uv 23 | id: setup-uv 24 | uses: astral-sh/setup-uv@v6 25 | with: 26 | enable-cache: true 27 | cache-suffix: linux-py311-tox 28 | 29 | - name: Install Python 3.11 30 | run: uv python install 3.11 31 | 32 | - name: Sync dependencies 33 | run: uv sync --locked --no-default-groups --group dev --group cpu 34 | 35 | - name: Download NLTK resources 36 | run: | 37 | # Use the virtualenv interpreter directly so CI never pulls NVIDIA/CUDA extras via `uv run`. 38 | ./.venv/bin/python -c "import nltk; nltk.download('punkt', quiet=True); nltk.download('punkt_tab', quiet=True)" 39 | 40 | - name: Run tox 41 | run: | 42 | # Invoke tox from the synced virtualenv to reuse locked deps; keep `run-parallel` for CI throughput. 
43 | ./.venv/bin/tox run-parallel 44 | -------------------------------------------------------------------------------- /tests/utils/test_modeling_export.py: -------------------------------------------------------------------------------- 1 | """Tests for ``open_provence.utils.modeling_export``.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | 7 | from open_provence.utils.modeling_export import write_modeling_open_provence 8 | 9 | 10 | def _make_source(tmp_path: Path, content: str) -> Path: 11 | source = tmp_path / "modeling_open_provence_standalone.py" 12 | source.write_text(content, encoding="utf-8") 13 | return source 14 | 15 | 16 | def test_write_modeling_open_provence_copies_source(tmp_path: Path) -> None: 17 | content = "DEFAULT_SPLITTER_LANGUAGE = \"auto\"\n" 18 | source = _make_source(tmp_path, content) 19 | destination = tmp_path / "out.py" 20 | 21 | write_modeling_open_provence(source, destination) 22 | 23 | assert destination.read_text(encoding="utf-8") == content 24 | 25 | 26 | def test_write_modeling_open_provence_overwrites_existing(tmp_path: Path) -> None: 27 | content = "# latest\nDEFAULT_SPLITTER_LANGUAGE = \"auto\"\n" 28 | source = _make_source(tmp_path, content) 29 | destination = tmp_path / "out.py" 30 | destination.write_text("legacy\n", encoding="utf-8") 31 | 32 | write_modeling_open_provence(source, destination) 33 | 34 | assert destination.read_text(encoding="utf-8") == content 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Distribution / packaging 2 | .Python 3 | build/ 4 | develop-eggs/ 5 | dist/ 6 | downloads/ 7 | eggs/ 8 | .eggs/ 9 | lib/ 10 | lib64/ 11 | parts/ 12 | sdist/ 13 | var/ 14 | wheels/ 15 | share/python-wheels/ 16 | *.egg-info/ 17 | .installed.cfg 18 | *.egg 19 | MANIFEST 20 | 21 | # Docs 22 | /docs/_build/ 23 | /docs/make.bat 24 | 25 | # Editors 26 | .idea 27 | .vscode 28 | 29 | # Coverage 30 | htmlcov 31 | 32 | # Training outputs and temporary files 33 | output/ 34 | outputs/ 35 | tmp/ 36 | *.bin 37 | *.safetensors 38 | *.pt 39 | *.pth 40 | .coverage* 41 | coverage.xml 42 | 43 | # Examples 44 | /examples/**/output/* 45 | /examples/datasets/ 46 | /examples/embeddings/ 47 | /examples/sentence_transformer/training/quora_duplicate_questions/quora-IR-dataset/ 48 | examples/datasets/*/ 49 | 50 | 51 | # Specific files and folders 52 | /pretrained-models/ 53 | /cheatsheet.txt 54 | /testsuite.txt 55 | /TODO.txt 56 | 57 | # Virtual environments 58 | .env 59 | .venv 60 | env/ 61 | venv/ 62 | 63 | # Database 64 | /qdrant_storage 65 | /elastic-start-local 66 | 67 | # Others 68 | *.pyc 69 | *.gz 70 | *.tsv 71 | 72 | 73 | tmp_*.py 74 | nr_*/ 75 | wandb 76 | checkpoints 77 | tmp 78 | .DS_Store 79 | /runs 80 | /output/ 81 | /results/ 82 | /log/ 83 | /logs/ 84 | tmp/ 85 | tmp* 86 | log/ 87 | logs/ 88 | cache/ 89 | 90 | # Log directories 91 | logs/ 92 | scripts/log/ 93 | 94 | .cckiro/ 95 | -------------------------------------------------------------------------------- /configs/eval_datasets/ja.yaml: -------------------------------------------------------------------------------- 1 | # Evaluation datasets for Japanese freq reranker checkpoints. 
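# The list below mixes the English freq datasets (msmarco, natural-questions, gooaq) with the hotchpotch/japanese-context-relevance subsets.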
2 | split: test 3 | datasets: 4 | - dataset_name: "hotchpotch/msmarco-context-relevance" 5 | subset: "freq2" 6 | - dataset_name: "hotchpotch/natural-questions-context-relevance" 7 | subset: "nodup_freq2" 8 | - dataset_name: "hotchpotch/gooaq-context-relevance-130k" 9 | subset: "default" 10 | - dataset_name: "hotchpotch/japanese-context-relevance" 11 | subset: "msmarco-ja-freq2" 12 | - dataset_name: "hotchpotch/japanese-context-relevance" 13 | subset: "auto-wiki-qa-nemotron" 14 | - dataset_name: "hotchpotch/japanese-context-relevance" 15 | subset: "jaquad-freq2" 16 | - dataset_name: "hotchpotch/japanese-context-relevance" 17 | subset: "jqara" 18 | - dataset_name: "hotchpotch/japanese-context-relevance" 19 | subset: "jsquad-freq2" 20 | - dataset_name: "hotchpotch/japanese-context-relevance" 21 | subset: "miracl" 22 | - dataset_name: "hotchpotch/japanese-context-relevance" 23 | subset: "mkqa" 24 | - dataset_name: "hotchpotch/japanese-context-relevance" 25 | subset: "mr-tydi" 26 | - dataset_name: "hotchpotch/japanese-context-relevance" 27 | subset: "quiz-no-mori" 28 | - dataset_name: "hotchpotch/japanese-context-relevance" 29 | subset: "quiz-works" 30 | - dataset_name: "hotchpotch/japanese-context-relevance" 31 | subset: "JFWIR" 32 | -------------------------------------------------------------------------------- /tests/test_modeling_default_dtype.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import platform 4 | 5 | import pytest 6 | import torch 7 | 8 | try: 9 | from open_provence.modeling_open_provence_standalone import _select_default_torch_dtype 10 | except ImportError: # datasets などが未インストールの場合はスキップ 11 | pytest.skip( 12 | "modeling_open_provence_standalone requires optional dependencies", 13 | allow_module_level=True, 14 | ) 15 | 16 | 17 | def test_select_default_dtype_cuda_prefers_bf16(monkeypatch): 18 | monkeypatch.setattr(torch.cuda, "is_available", lambda: True) 19 | monkeypatch.setattr(torch.cuda, "is_bf16_supported", lambda: True) 20 | assert _select_default_torch_dtype("cuda") == torch.bfloat16 21 | 22 | 23 | def test_select_default_dtype_cuda_fallback_float16(monkeypatch): 24 | monkeypatch.setattr(torch.cuda, "is_available", lambda: True) 25 | monkeypatch.setattr(torch.cuda, "is_bf16_supported", lambda: False) 26 | assert _select_default_torch_dtype("cuda") == torch.float16 27 | 28 | 29 | def test_select_default_dtype_cpu_apple(monkeypatch): 30 | monkeypatch.setattr(platform, "system", lambda: "Darwin") 31 | monkeypatch.setattr(platform, "machine", lambda: "arm64") 32 | assert _select_default_torch_dtype("cpu") == "auto" 33 | 34 | 35 | def test_select_default_dtype_mps(monkeypatch): 36 | assert _select_default_torch_dtype("mps") == "auto" 37 | 38 | 39 | def test_select_default_dtype_unknown_device(monkeypatch): 40 | monkeypatch.setattr(platform, "system", lambda: "Linux") 41 | monkeypatch.setattr(platform, "machine", lambda: "x86_64") 42 | assert _select_default_torch_dtype("cpu") is None 43 | -------------------------------------------------------------------------------- /configs/eval_datasets/ja_nano.yaml: -------------------------------------------------------------------------------- 1 | # Nano evaluation slice aligned with freq datasets (first 100 examples per dataset). 
2 | split: test 3 | datasets: 4 | - dataset_name: "hotchpotch/msmarco-context-relevance" 5 | subset: "freq2" 6 | n_samples: 100 7 | - dataset_name: "hotchpotch/natural-questions-context-relevance" 8 | subset: "nodup_freq2" 9 | n_samples: 100 10 | - dataset_name: "hotchpotch/gooaq-context-relevance-130k" 11 | subset: "default" 12 | n_samples: 100 13 | - dataset_name: "hotchpotch/japanese-context-relevance" 14 | subset: "msmarco-ja-freq2" 15 | n_samples: 100 16 | - dataset_name: "hotchpotch/japanese-context-relevance" 17 | subset: "auto-wiki-qa-nemotron" 18 | n_samples: 100 19 | - dataset_name: "hotchpotch/japanese-context-relevance" 20 | subset: "jaquad-freq2" 21 | n_samples: 100 22 | - dataset_name: "hotchpotch/japanese-context-relevance" 23 | subset: "jqara" 24 | n_samples: 100 25 | - dataset_name: "hotchpotch/japanese-context-relevance" 26 | subset: "jsquad-freq2" 27 | n_samples: 100 28 | - dataset_name: "hotchpotch/japanese-context-relevance" 29 | subset: "miracl" 30 | n_samples: 100 31 | - dataset_name: "hotchpotch/japanese-context-relevance" 32 | subset: "mkqa" 33 | n_samples: 100 34 | - dataset_name: "hotchpotch/japanese-context-relevance" 35 | subset: "mr-tydi" 36 | n_samples: 100 37 | - dataset_name: "hotchpotch/japanese-context-relevance" 38 | subset: "quiz-no-mori" 39 | n_samples: 100 40 | - dataset_name: "hotchpotch/japanese-context-relevance" 41 | subset: "quiz-works" 42 | n_samples: 100 43 | - dataset_name: "hotchpotch/japanese-context-relevance" 44 | subset: "JFWIR" 45 | n_samples: 100 46 | -------------------------------------------------------------------------------- /configs/open-provence-reranker-v1-gte-modernbert-base.yaml: -------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "Alibaba-NLP/gte-reranker-modernbert-base" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 11 | - 12 | dataset_name: "hotchpotch/natural-questions-context-relevance" 13 | subset: "nodup_freq2" 14 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 15 | items: 6 16 | - 17 | dataset_name: "hotchpotch/gooaq-context-relevance-130k" 18 | subset: "default" 19 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 20 | items: 6 21 | 22 | 23 | training_args: 24 | overwrite_output_dir: true 25 | optimizer: "adafactor" 26 | 27 | # Training parameters 28 | learning_rate: 5.0e-5 29 | per_device_train_batch_size: 4 # If GPU memory is not enough, try reducing this value. 
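# Effective batch size = per_device_train_batch_size x gradient_accumulation_steps = 4 x 64 = 256 per device; if you lower the per-device value, consider raising gradient_accumulation_steps to keep the product unchanged.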
30 | gradient_accumulation_steps: 64 31 | max_grad_norm: 1.0 32 | 33 | # Optimizer and scheduler 34 | weight_decay: 0.01 35 | lr_scheduler_type: "cosine" 36 | warmup_ratio: 0.1 37 | 38 | # Logging and saving 39 | logging_steps: 100 40 | save_steps: 500 41 | save_total_limit: 5 42 | 43 | # Mixed precision 44 | fp16: false 45 | bf16: true 46 | 47 | # Other settings 48 | dataloader_num_workers: 8 49 | load_best_model_at_end: true 50 | num_train_epochs: 1 51 | 52 | # eval 53 | per_device_eval_batch_size: 16 54 | eval_steps: 500 55 | 56 | # Reporting 57 | report_to: ["wandb"] 58 | 59 | eval_datasets: 60 | config: configs/eval_datasets/en.yaml 61 | threshold: 0.1 62 | batch_size: 32 63 | -------------------------------------------------------------------------------- /configs/toy-open-provence-reranker-v1-gte-modernbert-base.yaml: -------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "Alibaba-NLP/gte-reranker-modernbert-base" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 11 | n_samples: 4000 12 | - 13 | dataset_name: "hotchpotch/natural-questions-context-relevance" 14 | subset: "nodup_freq2" 15 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 16 | items: 6 17 | n_samples: 4000 18 | - 19 | dataset_name: "hotchpotch/gooaq-context-relevance-130k" 20 | subset: "default" 21 | teacher_column: "teacher_scores.gte-reranker-modernbert-base" 22 | items: 6 23 | n_samples: 4000 24 | 25 | 26 | training_args: 27 | overwrite_output_dir: true 28 | optimizer: "adafactor" 29 | 30 | # Training parameters 31 | learning_rate: 5.0e-5 32 | per_device_train_batch_size: 4 # If GPU memory is not enough, try reducing this value. 
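# Toy run: the effective batch size here is only 4 x 16 = 64, smaller than the 256 used by the full-size configs.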
33 | gradient_accumulation_steps: 16 34 | max_grad_norm: 1.0 35 | 36 | # Optimizer and scheduler 37 | weight_decay: 0.01 38 | lr_scheduler_type: "cosine" 39 | warmup_ratio: 0.1 40 | 41 | # Logging and saving 42 | logging_steps: 100 43 | save_steps: 500 44 | save_total_limit: 5 45 | 46 | # Mixed precision 47 | fp16: false 48 | bf16: true 49 | 50 | # Other settings 51 | dataloader_num_workers: 8 52 | load_best_model_at_end: true 53 | num_train_epochs: 1 54 | 55 | # eval 56 | per_device_eval_batch_size: 16 57 | eval_steps: 500 58 | 59 | # Reporting 60 | report_to: ["wandb"] 61 | 62 | eval_datasets: 63 | config: configs/eval_datasets/en_nano.yaml 64 | threshold: 0.1 65 | batch_size: 32 66 | -------------------------------------------------------------------------------- /tests/test_checkpoint_resolution.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | from open_provence.trainer import ResolvedCheckpoint, resolve_resume_checkpoint_path 7 | 8 | 9 | def _make_checkpoint(dir_path: Path) -> None: 10 | dir_path.mkdir(parents=True) 11 | (dir_path / "trainer_state.json").write_text("{}", encoding="utf-8") 12 | 13 | 14 | def test_resolve_explicit_checkpoint_returns_parent(tmp_path: Path) -> None: 15 | checkpoint_dir = tmp_path / "checkpoint-0500" 16 | _make_checkpoint(checkpoint_dir) 17 | 18 | resolved = resolve_resume_checkpoint_path(checkpoint_dir) 19 | 20 | assert isinstance(resolved, ResolvedCheckpoint) 21 | assert resolved.checkpoint_dir == checkpoint_dir.resolve() 22 | assert resolved.run_dir == tmp_path.resolve() 23 | assert resolved.steps == 500 24 | 25 | 26 | def test_resolve_parent_directory_picks_latest_checkpoint(tmp_path: Path) -> None: 27 | run_dir = tmp_path / "run" 28 | older = run_dir / "checkpoint-0100" 29 | newest = run_dir / "checkpoint-0500" 30 | _make_checkpoint(older) 31 | _make_checkpoint(newest) 32 | 33 | resolved = resolve_resume_checkpoint_path(run_dir) 34 | 35 | assert resolved.checkpoint_dir == newest.resolve() 36 | assert resolved.run_dir == run_dir.resolve() 37 | assert resolved.steps == 500 38 | 39 | 40 | def test_resolve_parent_directory_without_checkpoints_errors(tmp_path: Path) -> None: 41 | run_dir = tmp_path / "run" 42 | run_dir.mkdir() 43 | 44 | with pytest.raises(ValueError): 45 | resolve_resume_checkpoint_path(run_dir) 46 | 47 | 48 | def test_resolve_missing_path_errors(tmp_path: Path) -> None: 49 | missing = tmp_path / "missing" 50 | 51 | with pytest.raises(FileNotFoundError): 52 | resolve_resume_checkpoint_path(missing) 53 | -------------------------------------------------------------------------------- /configs/toy-open-provence-reranker-v1.yaml: -------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "hotchpotch/japanese-reranker-base-v2" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 11 | n_samples: 4000 12 | - 13 | dataset_name: "hotchpotch/japanese-context-relevance" 14 | subset: "msmarco-ja-freq2" 15 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 16 | n_samples: 4000 17 | - 18 | dataset_name: "hotchpotch/japanese-context-relevance" 19 | subset: "auto-wiki-qa-nemotron" 20 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 21 | n_samples: 4000 22 | 23 | training_args: 24 | 
overwrite_output_dir: true 25 | optimizer: "adafactor" 26 | 27 | # Training parameters 28 | learning_rate: 5.0e-5 29 | # The Japanese model produces stable and well-balanced scores with a batch size of 256. 30 | per_device_train_batch_size: 4 # If GPU memory is not enough, try reducing this value. 31 | gradient_accumulation_steps: 16 32 | max_grad_norm: 1.0 33 | 34 | # Optimizer and scheduler 35 | weight_decay: 0.01 36 | lr_scheduler_type: "cosine" 37 | warmup_ratio: 0.1 38 | 39 | # Logging and saving 40 | logging_steps: 100 41 | save_steps: 500 42 | save_total_limit: 5 43 | 44 | # Mixed precision 45 | fp16: false 46 | bf16: true 47 | 48 | # Other settings 49 | dataloader_num_workers: 8 50 | load_best_model_at_end: true 51 | num_train_epochs: 1 52 | 53 | # eval 54 | per_device_eval_batch_size: 16 55 | eval_steps: 500 56 | 57 | # Reporting 58 | report_to: ["wandb"] 59 | 60 | eval_datasets: 61 | config: configs/eval_datasets/ja_nano.yaml 62 | threshold: 0.1 63 | batch_size: 32 64 | -------------------------------------------------------------------------------- /tests/scripts/test_generate_ds_from_sentense_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import subprocess 4 | import sys 5 | from pathlib import Path 6 | 7 | from datasets import Dataset, DatasetDict, load_from_disk 8 | 9 | SCRIPT_PATH = ( 10 | Path(__file__).resolve().parents[2] 11 | / "scripts" 12 | / "context-relevance-datasets" 13 | / "generate_ds_from_sentense_transformer.py" 14 | ) 15 | 16 | 17 | def build_source_dataset(root: Path) -> Path: 18 | rows = 20 19 | data = { 20 | "question": [f"question {i}" for i in range(rows)], 21 | "answer": [f"answer {i}" for i in range(rows)], 22 | "neg1": [f"neg1 {i}" for i in range(rows)], 23 | "neg2": [f"neg2 {i}" for i in range(rows)], 24 | } 25 | dataset = Dataset.from_dict(data) 26 | dataset_dict = DatasetDict({"train": dataset}) 27 | source_path = root / "source_ds" 28 | dataset_dict.save_to_disk(source_path) 29 | return source_path 30 | 31 | 32 | def test_generate_from_local_dataset(tmp_path): 33 | source_path = build_source_dataset(tmp_path) 34 | output_root = tmp_path / "converted" 35 | 36 | cmd = [ 37 | sys.executable, 38 | str(SCRIPT_PATH), 39 | "--dataset", 40 | str(source_path), 41 | "--lang", 42 | "en", 43 | "--output-root", 44 | str(output_root), 45 | "--overwrite", 46 | ] 47 | subprocess.run(cmd, check=True, cwd=Path(__file__).resolve().parents[2]) 48 | 49 | output_dirs = list(output_root.iterdir()) 50 | assert len(output_dirs) == 1 51 | converted = load_from_disk(output_dirs[0]) 52 | assert isinstance(converted, DatasetDict) 53 | assert set(converted.keys()) == {"train", "validation", "test"} 54 | first = converted["train"][0] 55 | assert first["query"].startswith("question") 56 | assert first["texts"][0].startswith("answer") 57 | assert first["labels"][0] == 1 58 | assert all(label in {0, 1} for label in first["labels"]) # sanity check 59 | -------------------------------------------------------------------------------- /tests/test_data_structures.py: -------------------------------------------------------------------------------- 1 | """Tests for ``open_provence.data_structures`` helpers.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from open_provence.data_structures import ( 8 | OpenProvenceOnlyOutput, 9 | OpenProvenceOutput, 10 | RerankingOpenProvenceOutput, 11 | ) 12 | 13 | 14 | def 
test_open_provence_output_to_dict_serializes_numpy() -> None: 15 | output = OpenProvenceOutput( 16 | ranking_scores=np.array([0.1, 0.2]), 17 | chunk_predictions=np.array([[1, 0], [0, 1]]), 18 | chunk_positions=[[(0, 1)]], 19 | compression_ratio=0.5, 20 | ) 21 | 22 | result = output.to_dict() 23 | 24 | assert result["ranking_scores"] == [0.1, 0.2] 25 | assert result["chunk_predictions"] == [[1, 0], [0, 1]] 26 | assert result["chunk_positions"] == [[(0, 1)]] 27 | assert result["compression_ratio"] == 0.5 28 | assert "token_scores" not in result 29 | 30 | 31 | def test_open_provence_only_output_to_dict_handles_torch() -> None: 32 | logits = torch.tensor([[[0.2, 0.8], [0.7, 0.3]]]) 33 | output = OpenProvenceOnlyOutput( 34 | pruning_logits=logits, 35 | pruning_masks=np.array([[1, 0]]), 36 | num_pruned_tokens=5, 37 | ) 38 | 39 | result = output.to_dict() 40 | 41 | np.testing.assert_allclose( 42 | result["pruning_logits"], 43 | [[[0.2, 0.8], [0.7, 0.3]]], 44 | ) 45 | assert result["pruning_masks"] == [[1, 0]] 46 | assert result["num_pruned_tokens"] == 5 47 | assert "pruning_probs" not in result 48 | 49 | 50 | def test_reranking_output_repr_includes_shapes() -> None: 51 | output = RerankingOpenProvenceOutput( 52 | ranking_scores=np.ones(2), 53 | pruning_masks=np.ones((1, 2)), 54 | compression_ratio=0.75, 55 | pruning_logits=torch.zeros(1, 2, 2), 56 | ) 57 | 58 | result = output.to_dict() 59 | assert result["pruning_logits"] == [[[0.0, 0.0], [0.0, 0.0]]] 60 | 61 | representation = repr(output) 62 | assert "ranking_scores=(2,)" in representation 63 | assert "pruning_masks=(1, 2)" in representation 64 | assert "compression_ratio=0.75" in representation 65 | -------------------------------------------------------------------------------- /tests/test_items_sampling.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datasets import Dataset 4 | from open_provence.trainer import sample_items_by_label_priority 5 | 6 | 7 | def test_items_sampling_keeps_positive_and_samples_negatives(): 8 | dataset = Dataset.from_dict( 9 | { 10 | "labels": [[1, 0, 0, 0]], 11 | "texts": [["pos", "neg-a", "neg-b", "neg-c"]], 12 | "teacher_scores": [[0.9, 0.2, 0.1, 0.05]], 13 | } 14 | ) 15 | 16 | filtered = sample_items_by_label_priority(dataset, 3, seed=123, num_proc=1) 17 | 18 | assert len(filtered) == 1 19 | row = filtered[0] 20 | assert len(row["labels"]) == 3 21 | assert row["labels"][0] == 1 # positive entry is retained 22 | assert row["texts"][0] == "pos" 23 | # The remaining items originate from the original negatives 24 | assert set(row["texts"][1:]).issubset({"neg-a", "neg-b", "neg-c"}) 25 | assert len(row["teacher_scores"]) == 3 26 | 27 | 28 | def test_items_sampling_drops_queries_with_too_few_items(): 29 | dataset = Dataset.from_dict( 30 | { 31 | "id": ["short", "long"], 32 | "labels": [[1, 0], [1, 0, 0]], 33 | "texts": [["p", "n"], ["p", "n1", "n2"]], 34 | } 35 | ) 36 | 37 | filtered = sample_items_by_label_priority(dataset, 3, seed=42, num_proc=1) 38 | 39 | assert len(filtered) == 1 40 | assert filtered[0]["id"] == "long" 41 | assert len(filtered[0]["labels"]) == 3 42 | 43 | 44 | def test_items_sampling_handles_rows_without_positive_labels(): 45 | dataset = Dataset.from_dict( 46 | { 47 | "labels": [[0, 0, 0, 0]], 48 | "texts": [["a", "b", "c", "d"]], 49 | } 50 | ) 51 | 52 | filtered = sample_items_by_label_priority(dataset, 2, seed=7, num_proc=1) 53 | 54 | assert len(filtered) == 1 55 | row = filtered[0] 56 | assert 
len(row["labels"]) == 2 57 | assert set(row["texts"]).issubset({"a", "b", "c", "d"}) 58 | 59 | 60 | def test_items_sampling_prefers_positive_items_when_exceeding_limit(): 61 | dataset = Dataset.from_dict( 62 | { 63 | "labels": [[1, 1, 0, 0]], 64 | "texts": [["p1", "p2", "n1", "n2"]], 65 | } 66 | ) 67 | 68 | filtered = sample_items_by_label_priority(dataset, 2, seed=5, num_proc=1) 69 | 70 | assert len(filtered) == 1 71 | assert filtered[0]["labels"] == [1, 1] 72 | assert filtered[0]["texts"] == ["p1", "p2"] 73 | -------------------------------------------------------------------------------- /tests/test_sequential_fragmentize.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from open_provence.modeling_open_provence_standalone import ( 6 | OpenProvenceModel, 7 | SentenceSplitter, 8 | ) 9 | 10 | 11 | class _StubTokenizer: 12 | """Minimal tokenizer stub that operates on Unicode codepoints.""" 13 | 14 | def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: 15 | return [ord(ch) for ch in text] 16 | 17 | def __call__( 18 | self, 19 | sentences: list[str], 20 | *, 21 | add_special_tokens: bool = False, 22 | return_attention_mask: bool = False, 23 | ) -> dict[str, Any]: 24 | return {"input_ids": [[ord(ch) for ch in sentence] for sentence in sentences]} 25 | 26 | def batch_decode( 27 | self, 28 | sequences: list[list[int]], 29 | *, 30 | skip_special_tokens: bool = True, 31 | clean_up_tokenization_spaces: bool = False, 32 | ) -> list[str]: 33 | return ["".join(chr(ch) for ch in seq) for seq in sequences] 34 | 35 | def decode( 36 | self, 37 | sequence: list[int], 38 | *, 39 | skip_special_tokens: bool = True, 40 | clean_up_tokenization_spaces: bool = False, 41 | ) -> str: 42 | return "".join(chr(ch) for ch in sequence) 43 | 44 | 45 | def _split_sentences(text: str) -> list[str]: 46 | return [segment for segment in text.split("。") if segment] or [text] 47 | 48 | 49 | def test_run_sequential_fragmentize_produces_fragments() -> None: 50 | model = OpenProvenceModel.__new__(OpenProvenceModel) 51 | model.tokenizer = _StubTokenizer() 52 | 53 | job = { 54 | "query_idx": 0, 55 | "context_idx": 0, 56 | "context_text": "吾輩は猫である。名前はまだない。", 57 | "prefix_sentences": [], 58 | "manual_sentences": None, 59 | "cached_sentences": None, 60 | "cached_token_lists": None, 61 | } 62 | 63 | splitter: SentenceSplitter = _split_sentences 64 | 65 | results = model._run_sequential_fragmentize( 66 | [job], 67 | max_fragment_tokens=16, 68 | splitter=splitter, 69 | show_progress=False, 70 | strip_sentences=True, 71 | respect_sentence_boundaries=False, 72 | ) 73 | 74 | assert len(results) == 1 75 | entry = results[0] 76 | 77 | assert entry["sentences"] == ["吾輩は猫である", "名前はまだない"] 78 | assert entry["fragment_texts"] == ["吾輩は猫である", "名前はまだない"] 79 | assert entry["fragment_token_ids"] 80 | assert entry["timing_sentence_collect"] >= 0.0 81 | assert entry["timing_fragment_decode"] >= 0.0 82 | -------------------------------------------------------------------------------- /tests/test_tokenizer_special_tokens.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from transformers import AutoTokenizer 5 | 6 | ENGLISH_MODEL_NAME = "Alibaba-NLP/gte-reranker-modernbert-base" 7 | JAPANESE_MODEL_NAME = "hotchpotch/japanese-reranker-base-v2" 8 | 9 | 10 | @pytest.mark.parametrize( 11 | ("model_name", "query", "document"), 
12 | [ 13 | ( 14 | ENGLISH_MODEL_NAME, 15 | "What is artificial intelligence?", 16 | "Artificial intelligence studies intelligent behaviour in machines.", 17 | ), 18 | ( 19 | JAPANESE_MODEL_NAME, 20 | "AIとは何ですか?", 21 | "AIは人工知能の略称で、人間の知能を機械で再現することを指します。", 22 | ), 23 | ], 24 | ) 25 | def test_encode_plus_inserts_special_tokens(model_name: str, query: str, document: str) -> None: 26 | """Ensure encode_plus inserts special tokens for both English and Japanese checkpoints.""" 27 | 28 | tokenizer = AutoTokenizer.from_pretrained(model_name) 29 | 30 | encoding = tokenizer.encode_plus( 31 | query, 32 | document, 33 | add_special_tokens=True, 34 | return_token_type_ids=True, 35 | ) 36 | 37 | input_ids = encoding["input_ids"] 38 | assert input_ids, "Tokenizer returned empty input ids." 39 | 40 | start_candidates = [ 41 | tokenizer.cls_token_id, 42 | tokenizer.bos_token_id, 43 | tokenizer.special_tokens_map.get("cls_token_id"), 44 | tokenizer.special_tokens_map.get("bos_token_id"), 45 | ] 46 | start_candidates = [tok_id for tok_id in start_candidates if isinstance(tok_id, int)] 47 | assert start_candidates, "Tokenizer has no CLS/BOS token id defined." 48 | assert input_ids[0] in start_candidates, ( 49 | f"Expected one of {start_candidates} at start, but got {input_ids[0]}." 50 | ) 51 | 52 | boundary_candidates = [ 53 | tokenizer.sep_token_id, 54 | tokenizer.eos_token_id, 55 | tokenizer.special_tokens_map.get("sep_token_id"), 56 | tokenizer.special_tokens_map.get("eos_token_id"), 57 | ] 58 | boundary_candidates = [tok_id for tok_id in boundary_candidates if isinstance(tok_id, int)] 59 | assert boundary_candidates, "Tokenizer has no SEP/EOS token id defined." 60 | 61 | boundary_indices = [ 62 | idx for idx, tok in enumerate(input_ids[1:], start=1) if tok in boundary_candidates 63 | ] 64 | assert boundary_indices, ( 65 | "No boundary token found between query and document " 66 | f"(candidates={boundary_candidates}, tokens={input_ids})." 67 | ) 68 | assert boundary_indices[0] < len(input_ids) - 1, ( 69 | "Boundary token should not be the final token." 70 | ) 71 | 72 | # Confirm that removing special tokens changes the sequence start. 73 | encoding_no_special = tokenizer.encode_plus( 74 | query, 75 | document, 76 | add_special_tokens=False, 77 | return_token_type_ids=True, 78 | ) 79 | assert encoding_no_special["input_ids"], "encode_plus without specials returned no tokens." 80 | assert encoding_no_special["input_ids"][0] not in start_candidates, ( 81 | "encode_plus(add_special_tokens=False) unexpectedly kept the start special token; " 82 | "this would invalidate the special-token check." 
83 | ) 84 | -------------------------------------------------------------------------------- /tests/utils/test_model_architecture.py: -------------------------------------------------------------------------------- 1 | """Tests for ``open_provence.utils.model_architecture``.""" 2 | 3 | from __future__ import annotations 4 | 5 | from open_provence.utils.model_architecture import ModelArchitectureUtils 6 | 7 | 8 | def test_detect_architecture_modernbert() -> None: 9 | keys = [ 10 | "tok_embeddings.weight", 11 | "layers.0.attn.Wqkv.weight", 12 | "layers.0.mlp_norm.weight", 13 | ] 14 | assert ModelArchitectureUtils.detect_architecture(keys) == "modernbert" 15 | 16 | 17 | def test_detect_architecture_prefers_known_prefixes() -> None: 18 | keys = [ 19 | "bert.embeddings.word_embeddings.weight", 20 | "bert.encoder.layer.0.attention.self.query.weight", 21 | "bert.pooler.dense.weight", 22 | ] 23 | assert ModelArchitectureUtils.detect_architecture(keys) == "bert" 24 | 25 | 26 | def test_detect_architecture_unknown_when_no_patterns() -> None: 27 | keys = ["linear.weight", "classifier.bias"] 28 | assert ModelArchitectureUtils.detect_architecture(keys) == "unknown" 29 | 30 | 31 | def test_needs_prefix_conversion_identifies_flat_modernbert_keys() -> None: 32 | keys = [ 33 | "embeddings.word_embeddings.weight", 34 | "layers.0.attn.Wqkv.weight", 35 | ] 36 | needs_conversion, prefix = ModelArchitectureUtils.needs_prefix_conversion(keys, "modernbert") 37 | assert needs_conversion is True 38 | assert prefix == "model." 39 | 40 | 41 | def test_needs_prefix_conversion_no_action_when_prefixed() -> None: 42 | keys = [ 43 | "model.embeddings.word_embeddings.weight", 44 | "model.layers.0.attn.Wqkv.weight", 45 | ] 46 | needs_conversion, prefix = ModelArchitectureUtils.needs_prefix_conversion(keys, "modernbert") 47 | assert needs_conversion is False 48 | assert prefix is None 49 | 50 | 51 | def test_convert_state_dict_keys_adds_and_skips() -> None: 52 | state_dict = { 53 | "embeddings.word_embeddings.weight": "weights", 54 | "layers.0.attn.Wqkv.weight": "attn", 55 | "pruning_head.linear.weight": "head", 56 | } 57 | 58 | converted = ModelArchitectureUtils.convert_state_dict_keys( 59 | state_dict, 60 | add_prefix="model.", 61 | skip_keys=["pruning_head"], 62 | ) 63 | 64 | assert converted["model.embeddings.word_embeddings.weight"] == "weights" 65 | assert converted["model.layers.0.attn.Wqkv.weight"] == "attn" 66 | assert converted["pruning_head.linear.weight"] == "head" 67 | 68 | 69 | def test_auto_fix_state_dict_adds_model_prefix_for_modernbert() -> None: 70 | state_dict = { 71 | "embeddings.word_embeddings.weight": "weights", 72 | "layers.0.attn.Wqkv.weight": "attn", 73 | } 74 | 75 | fixed = ModelArchitectureUtils.auto_fix_state_dict(state_dict, list(state_dict.keys()), "modernbert") 76 | 77 | assert "model.embeddings.word_embeddings.weight" in fixed 78 | assert "model.layers.0.attn.Wqkv.weight" in fixed 79 | 80 | 81 | def test_normalize_state_dict_for_saving_removes_model_prefix() -> None: 82 | state_dict = { 83 | "model.embeddings.word_embeddings.weight": "weights", 84 | "model.layers.0.attn.Wqkv.weight": "attn", 85 | } 86 | 87 | normalized = ModelArchitectureUtils.normalize_state_dict_for_saving(state_dict, "modernbert") 88 | 89 | assert "embeddings.word_embeddings.weight" in normalized 90 | assert "layers.0.attn.Wqkv.weight" in normalized 91 | -------------------------------------------------------------------------------- /configs/open-provence-reranker-large-v1.yaml: 
-------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "cl-nagoya/ruri-v3-reranker-310m" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 11 | - 12 | items: 6 13 | dataset_name: "hotchpotch/natural-questions-context-relevance" 14 | subset: "nodup_freq2" 15 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 16 | - 17 | items: 6 18 | dataset_name: "hotchpotch/gooaq-context-relevance-130k" 19 | subset: "default" 20 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 21 | - 22 | dataset_name: "hotchpotch/japanese-context-relevance" 23 | subset: "msmarco-ja-freq2" 24 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 25 | - 26 | dataset_name: "hotchpotch/japanese-context-relevance" 27 | subset: "auto-wiki-qa-nemotron" 28 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 29 | - 30 | dataset_name: "hotchpotch/japanese-context-relevance" 31 | subset: "jaquad-freq2" 32 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 33 | - 34 | dataset_name: "hotchpotch/japanese-context-relevance" 35 | subset: "jqara" 36 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 37 | upsample_factor: 4.0 38 | - 39 | dataset_name: "hotchpotch/japanese-context-relevance" 40 | subset: "jsquad-freq2" 41 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 42 | - 43 | dataset_name: "hotchpotch/japanese-context-relevance" 44 | subset: "miracl" 45 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 46 | upsample_factor: 2.0 47 | - 48 | dataset_name: "hotchpotch/japanese-context-relevance" 49 | subset: "mkqa" 50 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 51 | upsample_factor: 2.0 52 | - 53 | dataset_name: "hotchpotch/japanese-context-relevance" 54 | subset: "mr-tydi" 55 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 56 | upsample_factor: 2.0 57 | - 58 | dataset_name: "hotchpotch/japanese-context-relevance" 59 | subset: "quiz-no-mori" 60 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 61 | - 62 | dataset_name: "hotchpotch/japanese-context-relevance" 63 | subset: "quiz-works" 64 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 65 | - 66 | dataset_name: "hotchpotch/japanese-context-relevance" 67 | subset: "JFWIR" 68 | teacher_column: "teacher_scores.ruri-v3-reranker-310m" 69 | 70 | 71 | training_args: 72 | overwrite_output_dir: true 73 | optimizer: "adafactor" 74 | 75 | # Training parameters 76 | learning_rate: 5.0e-5 77 | # The Japanese model produces stable and well-balanced scores with a batch size of 256. 78 | per_device_train_batch_size: 2 # If GPU memory is not enough, try reducing this value. 
79 | gradient_accumulation_steps: 128 80 | max_grad_norm: 1.0 81 | 82 | # Optimizer and scheduler 83 | weight_decay: 0.01 84 | lr_scheduler_type: "cosine" 85 | warmup_ratio: 0.1 86 | 87 | # Logging and saving 88 | logging_steps: 100 89 | save_steps: 500 90 | save_total_limit: 5 91 | 92 | # Mixed precision 93 | fp16: false 94 | bf16: true 95 | 96 | # Other settings 97 | dataloader_num_workers: 8 98 | load_best_model_at_end: true 99 | num_train_epochs: 1 100 | 101 | # eval 102 | per_device_eval_batch_size: 16 103 | eval_steps: 500 104 | 105 | # Reporting 106 | report_to: ["wandb"] 107 | 108 | eval_datasets: 109 | config: configs/eval_datasets/ja.yaml 110 | threshold: 0.1 111 | batch_size: 16 112 | -------------------------------------------------------------------------------- /configs/open-provence-reranker-v1.yaml: -------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "hotchpotch/japanese-reranker-base-v2" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 11 | - 12 | items: 6 13 | dataset_name: "hotchpotch/natural-questions-context-relevance" 14 | subset: "nodup_freq2" 15 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 16 | - 17 | items: 6 18 | dataset_name: "hotchpotch/gooaq-context-relevance-130k" 19 | subset: "default" 20 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 21 | - 22 | dataset_name: "hotchpotch/japanese-context-relevance" 23 | subset: "msmarco-ja-freq2" 24 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 25 | - 26 | dataset_name: "hotchpotch/japanese-context-relevance" 27 | subset: "auto-wiki-qa-nemotron" 28 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 29 | - 30 | dataset_name: "hotchpotch/japanese-context-relevance" 31 | subset: "jaquad-freq2" 32 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 33 | - 34 | dataset_name: "hotchpotch/japanese-context-relevance" 35 | subset: "jqara" 36 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 37 | upsample_factor: 4.0 38 | - 39 | dataset_name: "hotchpotch/japanese-context-relevance" 40 | subset: "jsquad-freq2" 41 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 42 | - 43 | dataset_name: "hotchpotch/japanese-context-relevance" 44 | subset: "miracl" 45 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 46 | upsample_factor: 2.0 47 | - 48 | dataset_name: "hotchpotch/japanese-context-relevance" 49 | subset: "mkqa" 50 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 51 | upsample_factor: 2.0 52 | - 53 | dataset_name: "hotchpotch/japanese-context-relevance" 54 | subset: "mr-tydi" 55 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 56 | upsample_factor: 2.0 57 | - 58 | dataset_name: "hotchpotch/japanese-context-relevance" 59 | subset: "quiz-no-mori" 60 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 61 | - 62 | dataset_name: "hotchpotch/japanese-context-relevance" 63 | subset: "quiz-works" 64 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 65 | - 66 | dataset_name: "hotchpotch/japanese-context-relevance" 67 | subset: "JFWIR" 68 | teacher_column: "teacher_scores.japanese-reranker-base-v2" 69 | 70 | 71 | training_args: 72 | overwrite_output_dir: true 73 | optimizer: "adafactor" 74 | 75 | # Training parameters 76 | learning_rate: 5.0e-5 77 | # The Japanese model produces stable 
and well-balanced scores with a batch size of 256. 78 | per_device_train_batch_size: 4 # If GPU memory is not enough, try reducing this value. 79 | gradient_accumulation_steps: 64 80 | max_grad_norm: 1.0 81 | 82 | # Optimizer and scheduler 83 | weight_decay: 0.01 84 | lr_scheduler_type: "cosine" 85 | warmup_ratio: 0.1 86 | 87 | # Logging and saving 88 | logging_steps: 100 89 | save_steps: 500 90 | save_total_limit: 5 91 | 92 | # Mixed precision 93 | fp16: false 94 | bf16: true 95 | 96 | # Other settings 97 | dataloader_num_workers: 8 98 | load_best_model_at_end: true 99 | num_train_epochs: 1 100 | 101 | # eval 102 | per_device_eval_batch_size: 16 103 | eval_steps: 500 104 | 105 | # Reporting 106 | report_to: ["wandb"] 107 | 108 | eval_datasets: 109 | config: configs/eval_datasets/ja.yaml 110 | threshold: 0.1 111 | batch_size: 32 112 | -------------------------------------------------------------------------------- /configs/open-provence-reranker-xsmall-v1.yaml: -------------------------------------------------------------------------------- 1 | model_args: 2 | model_name_or_path: "hotchpotch/japanese-reranker-xsmall-v2" 3 | classifier_dropout: 0.0 4 | 5 | data_args: 6 | datasets: 7 | - 8 | dataset_name: "hotchpotch/msmarco-context-relevance" 9 | subset: "freq2" 10 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 11 | - 12 | items: 6 13 | dataset_name: "hotchpotch/natural-questions-context-relevance" 14 | subset: "nodup_freq2" 15 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 16 | - 17 | items: 6 18 | dataset_name: "hotchpotch/gooaq-context-relevance-130k" 19 | subset: "default" 20 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 21 | - 22 | dataset_name: "hotchpotch/japanese-context-relevance" 23 | subset: "msmarco-ja-freq2" 24 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 25 | - 26 | dataset_name: "hotchpotch/japanese-context-relevance" 27 | subset: "auto-wiki-qa-nemotron" 28 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 29 | - 30 | dataset_name: "hotchpotch/japanese-context-relevance" 31 | subset: "jaquad-freq2" 32 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 33 | - 34 | dataset_name: "hotchpotch/japanese-context-relevance" 35 | subset: "jqara" 36 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 37 | upsample_factor: 4.0 38 | - 39 | dataset_name: "hotchpotch/japanese-context-relevance" 40 | subset: "jsquad-freq2" 41 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 42 | - 43 | dataset_name: "hotchpotch/japanese-context-relevance" 44 | subset: "miracl" 45 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 46 | upsample_factor: 2.0 47 | - 48 | dataset_name: "hotchpotch/japanese-context-relevance" 49 | subset: "mkqa" 50 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 51 | upsample_factor: 2.0 52 | - 53 | dataset_name: "hotchpotch/japanese-context-relevance" 54 | subset: "mr-tydi" 55 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 56 | upsample_factor: 2.0 57 | - 58 | dataset_name: "hotchpotch/japanese-context-relevance" 59 | subset: "quiz-no-mori" 60 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 61 | - 62 | dataset_name: "hotchpotch/japanese-context-relevance" 63 | subset: "quiz-works" 64 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 65 | - 66 | dataset_name: "hotchpotch/japanese-context-relevance" 67 | subset: "JFWIR" 68 | teacher_column: "teacher_scores.japanese-reranker-xsmall-v2" 69 | 70 | 
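# Note on the dataset entries above: teacher_column names the column holding the precomputed scores from that teacher reranker, and upsample_factor appears to weight a dataset more heavily (roughly by that factor) when the training mix is built (here jqara 4x; miracl, mkqa and mr-tydi 2x).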
71 | training_args: 72 | overwrite_output_dir: true 73 | optimizer: "adafactor" 74 | 75 | # Training parameters 76 | learning_rate: 5.0e-5 77 | # The Japanese model produces stable and well-balanced scores with a batch size of 256. 78 | per_device_train_batch_size: 4 # If GPU memory is not enough, try reducing this value. 79 | gradient_accumulation_steps: 64 80 | max_grad_norm: 1.0 81 | 82 | # Optimizer and scheduler 83 | weight_decay: 0.01 84 | lr_scheduler_type: "cosine" 85 | warmup_ratio: 0.1 86 | 87 | # Logging and saving 88 | logging_steps: 100 89 | save_steps: 500 90 | save_total_limit: 5 91 | 92 | # Mixed precision 93 | fp16: false 94 | bf16: true 95 | 96 | # Other settings 97 | dataloader_num_workers: 8 98 | load_best_model_at_end: true 99 | num_train_epochs: 1 100 | 101 | # eval 102 | per_device_eval_batch_size: 16 103 | eval_steps: 500 104 | 105 | # Reporting 106 | report_to: ["wandb"] 107 | 108 | eval_datasets: 109 | config: configs/eval_datasets/ja.yaml 110 | threshold: 0.1 111 | batch_size: 64 112 | -------------------------------------------------------------------------------- /tests/scripts/test_sync_output_modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib.util 4 | import io 5 | import json 6 | import sys 7 | from pathlib import Path 8 | 9 | 10 | def _repo_root() -> Path: 11 | return Path(__file__).resolve().parents[2] 12 | 13 | 14 | def _load_sync_module(): 15 | module_path = _repo_root() / "scripts" / "utils" / "sync_output_modeling.py" 16 | spec = importlib.util.spec_from_file_location("sync_output_modeling", module_path) 17 | if spec is None or spec.loader is None: 18 | raise RuntimeError("Failed to load sync_output_modeling module") 19 | module = importlib.util.module_from_spec(spec) 20 | sys.modules[spec.name] = module 21 | spec.loader.exec_module(module) # type: ignore[assignment] 22 | return module 23 | 24 | 25 | def test_sync_updates_modeling_and_config(tmp_path: Path) -> None: 26 | repo_root = _repo_root() 27 | base_file = repo_root / "open_provence" / "modeling_open_provence_standalone.py" 28 | sync = _load_sync_module() 29 | 30 | output_dir = tmp_path / "output" 31 | run_dir = output_dir / "toy-open-provence-reranker-japanese-test" 32 | run_dir.mkdir(parents=True, exist_ok=True) 33 | 34 | # Create outdated modeling file 35 | (run_dir / "modeling_open_provence_standalone.py").write_text( 36 | "# legacy content\n", 37 | encoding="utf-8", 38 | ) 39 | 40 | # Config with wrong language and missing legacy field 41 | config_path = run_dir / "config.json" 42 | config_path.write_text( 43 | json.dumps( 44 | { 45 | "model_type": "open_provence", 46 | "splitter_default_language": "en", 47 | "standalone_process_default_language": "en", 48 | "modeling_open_provence_default_language": "en", 49 | }, 50 | indent=2, 51 | ensure_ascii=False, 52 | ) 53 | + "\n", 54 | encoding="utf-8", 55 | ) 56 | 57 | states = sync.plan_sync(base_file, output_dir) 58 | assert len(states) == 1 59 | state = states[0] 60 | assert state.modeling_needs_update is True 61 | assert state.config_needs_update is True 62 | assert set(state.removed_keys) == { 63 | "splitter_default_language", 64 | "standalone_process_default_language", 65 | "modeling_open_provence_default_language", 66 | } 67 | 68 | stream = io.StringIO() 69 | sync.sync_targets(base_file, output_dir, overwrite=True, stream=stream) 70 | 71 | # modeling file should now match base file 72 | assert (run_dir / 
"modeling_open_provence_standalone.py").read_text( 73 | encoding="utf-8" 74 | ) == base_file.read_text(encoding="utf-8") 75 | 76 | updated_config = json.loads(config_path.read_text(encoding="utf-8")) 77 | for key in ( 78 | "splitter_default_language", 79 | "standalone_process_default_language", 80 | "modeling_open_provence_default_language", 81 | ): 82 | assert key not in updated_config 83 | 84 | output = stream.getvalue() 85 | assert "copied modeling_open_provence_standalone.py" in output 86 | assert "removed deprecated config keys" in output 87 | 88 | 89 | def test_sync_skip_when_up_to_date(tmp_path: Path) -> None: 90 | repo_root = _repo_root() 91 | base_file = repo_root / "open_provence" / "modeling_open_provence_standalone.py" 92 | sync = _load_sync_module() 93 | 94 | output_dir = tmp_path / "output" 95 | run_dir = output_dir / "toy-open-provence-reranker-test" 96 | run_dir.mkdir(parents=True, exist_ok=True) 97 | 98 | # Up-to-date modeling file 99 | run_dir.joinpath("modeling_open_provence_standalone.py").write_text( 100 | base_file.read_text(encoding="utf-8"), 101 | encoding="utf-8", 102 | ) 103 | 104 | config_path = run_dir / "config.json" 105 | config_path.write_text( 106 | json.dumps( 107 | { 108 | "model_type": "open_provence", 109 | "some_other_field": "value", 110 | }, 111 | indent=2, 112 | ensure_ascii=False, 113 | ) 114 | + "\n", 115 | encoding="utf-8", 116 | ) 117 | 118 | stream = io.StringIO() 119 | sync.sync_targets(base_file, output_dir, overwrite=False, stream=stream) 120 | 121 | assert "SKIP (already up to date)" in stream.getvalue() 122 | -------------------------------------------------------------------------------- /tests/test_eval_mldr_official.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | from datasets import Dataset 8 | 9 | ROOT = Path(__file__).resolve().parents[1] 10 | if str(ROOT) not in sys.path: 11 | sys.path.append(str(ROOT)) 12 | 13 | from scripts.eval_mldr import ( # noqa: E402 14 | _should_use_naver_provence_model, 15 | build_records, 16 | parse_args, 17 | ) 18 | 19 | 20 | def _build_dummy_dataset() -> Dataset: 21 | return Dataset.from_list( 22 | [ 23 | { 24 | "query_id": "q1", 25 | "query": "dummy question", 26 | "positive_passages": [ 27 | {"text": "positive text", "docid": "doc1", "title": "Title 1"}, 28 | {"text": "another positive", "docid": "doc2", "title": None}, 29 | ], 30 | "negative_passages": [], 31 | } 32 | ] 33 | ) 34 | 35 | 36 | def test_official_detector_handles_remote_ids() -> None: 37 | assert _should_use_naver_provence_model( 38 | "naver/provence-reranker-debertav3-v1", 39 | is_local=False, 40 | ) 41 | assert _should_use_naver_provence_model( 42 | "NAVER/Provence-multilingual", 43 | is_local=False, 44 | ) 45 | assert _should_use_naver_provence_model( 46 | "naver/xprovence-reranker-bgem3-v1", 47 | is_local=False, 48 | ) 49 | assert not _should_use_naver_provence_model( 50 | "naver/other-model", 51 | is_local=False, 52 | ) 53 | assert not _should_use_naver_provence_model( 54 | "./local/provence", 55 | is_local=True, 56 | ) 57 | 58 | 59 | def test_parse_args_auto_adjusts_for_official(monkeypatch, tmp_path: Path) -> None: 60 | monkeypatch.setattr("scripts.eval_mldr.torch.cuda.is_available", lambda: True) 61 | argv = [ 62 | "prog", 63 | "--model", 64 | "naver/xprovence-reranker-bgem3-v1", 65 | "--lang", 66 | "en", 67 | "--output-dir", 68 | str(tmp_path / "out"), 69 | "--no-eval", 70 | 
] 71 | monkeypatch.setattr(sys, "argv", argv) 72 | 73 | args = parse_args() 74 | 75 | assert args.device == "cuda" 76 | assert args.torch_dtype == "bfloat16" 77 | assert args.auto_device_cuda 78 | assert args.auto_torch_dtype 79 | 80 | 81 | def test_build_records_fills_missing_fields_for_official_results() -> None: 82 | dataset = _build_dummy_dataset() 83 | 84 | def process_fn(**_: Any) -> dict[str, Any]: 85 | return { 86 | "pruned_context": [["positive text", "another positive"]], 87 | "reranking_score": [[0.8, 0.6]], 88 | "compression_rate": [[20.0, 30.0]], 89 | } 90 | 91 | records, stats, num_queries = build_records( 92 | process_fn, 93 | dataset, 94 | threshold=0.1, 95 | batch_size=2, 96 | log_timing=False, 97 | use_best_reranker_score=True, 98 | show_progress=False, 99 | ) 100 | 101 | assert num_queries == 1 102 | assert len(records) == 2 103 | for record in records: 104 | assert record["kept_sentences"] == [] 105 | assert record["removed_sentences"] == [] 106 | assert stats["pos_scores"] == [0.8, 0.6] 107 | 108 | 109 | def test_build_records_accepts_scalar_outputs() -> None: 110 | dataset = Dataset.from_list( 111 | [ 112 | { 113 | "query_id": "q1", 114 | "query": "dummy question", 115 | "positive_passages": [ 116 | {"text": "positive text", "docid": "doc1", "title": None}, 117 | ], 118 | "negative_passages": [], 119 | } 120 | ] 121 | ) 122 | 123 | def process_fn(**_: Any) -> dict[str, Any]: 124 | return { 125 | "pruned_context": "positive text", 126 | "reranking_score": 0.9, 127 | "compression_rate": 25.0, 128 | } 129 | 130 | records, stats, num_queries = build_records( 131 | process_fn, 132 | dataset, 133 | threshold=0.1, 134 | batch_size=1, 135 | log_timing=False, 136 | use_best_reranker_score=True, 137 | show_progress=False, 138 | ) 139 | 140 | assert num_queries == 1 141 | assert len(records) == 1 142 | assert records[0]["pruned_text"] == "positive text" 143 | assert records[0]["kept_sentences"] == [] 144 | assert records[0]["removed_sentences"] == [] 145 | assert stats["pos_scores"] == [0.9] 146 | -------------------------------------------------------------------------------- /scripts/hf_utils/update_standalone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Release helper: copy the local `open_provence/modeling_open_provence_standalone.py` 3 | into the four published HF model repos (README list) without touching git-lfs 4 | artifacts, then commit and push. 5 | 6 | Runbook (Nov 22, 2025): 7 | 1) Update the standalone file locally as needed. 8 | 2) Execute `python scripts/hf_utils/update_standalone.py`. 9 | - Clones / pulls into `tmp/release_models/` with 10 | `GIT_LFS_SKIP_SMUDGE=1` to avoid LFS downloads. 11 | - Copies the standalone file, commits with a dated message, and pushes. 12 | 3) Verify: `git -C tmp/release_models/<model> log -1 --oneline` 13 | should show `chore: update standalone file (<date>)`. 14 | 4) Optional: run `python scripts/hf_utils/hf_model_process_check.py` 15 | to smoke-test the pushed code via AutoModel. 16 | """ 17 | 18 | from __future__ import annotations 19 | 20 | import argparse 21 | import os 22 | import shutil 23 | import subprocess 24 | from collections.abc import Iterable 25 | from datetime import datetime 26 | from pathlib import Path 27 | 28 | DEFAULT_MODELS: tuple[str, ...]
= ( 29 | "hotchpotch/open-provence-reranker-v1", 30 | "hotchpotch/open-provence-reranker-xsmall-v1", 31 | "hotchpotch/open-provence-reranker-large-v1", 32 | "hotchpotch/open-provence-reranker-v1-gte-modernbert-base", 33 | ) 34 | 35 | REPO_ROOT = Path(__file__).resolve().parents[2] 36 | STANDALONE_SRC = REPO_ROOT / "open_provence" / "modeling_open_provence_standalone.py" 37 | 38 | 39 | def run(cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None) -> None: 40 | merged_env = os.environ.copy() 41 | merged_env.update(env or {}) 42 | print(f"[cmd] {' '.join(cmd)} (cwd={cwd})") 43 | subprocess.run(cmd, cwd=cwd, env=merged_env, check=True) 44 | 45 | 46 | def ensure_repo(repo_id: str, base_dir: Path, env: dict[str, str]) -> Path: 47 | target_dir = base_dir / repo_id.split("/", maxsplit=1)[1] 48 | if not target_dir.exists(): 49 | base_dir.mkdir(parents=True, exist_ok=True) 50 | run( 51 | ["git", "clone", f"https://huggingface.co/{repo_id}", str(target_dir)], 52 | env=env, 53 | ) 54 | else: 55 | run(["git", "-C", str(target_dir), "pull", "--rebase"], env=env) 56 | return target_dir 57 | 58 | 59 | def copy_standalone(dest_repo: Path) -> None: 60 | dest = dest_repo / "modeling_open_provence_standalone.py" 61 | shutil.copy2(STANDALONE_SRC, dest) 62 | print(f"[copy] {STANDALONE_SRC} -> {dest}") 63 | 64 | 65 | def has_changes(repo_dir: Path) -> bool: 66 | result = subprocess.run( 67 | ["git", "-C", str(repo_dir), "status", "--porcelain"], 68 | check=True, 69 | capture_output=True, 70 | text=True, 71 | ) 72 | return result.stdout.strip() != "" 73 | 74 | 75 | def commit_and_push(repo_dir: Path, message: str, env: dict[str, str]) -> None: 76 | run(["git", "-C", str(repo_dir), "add", "modeling_open_provence_standalone.py"], env=env) 77 | if not has_changes(repo_dir): 78 | print("[skip] No changes to commit.") 79 | return 80 | run(["git", "-C", str(repo_dir), "commit", "-m", message], env=env) 81 | run(["git", "-C", str(repo_dir), "push"], env=env) 82 | 83 | 84 | def update_models(models: Iterable[str], base_dir: Path, commit_message: str) -> None: 85 | git_env = {"GIT_LFS_SKIP_SMUDGE": "1"} 86 | for repo_id in models: 87 | print(f"\n=== Updating {repo_id} ===") 88 | repo_dir = ensure_repo(repo_id, base_dir, git_env) 89 | copy_standalone(repo_dir) 90 | commit_and_push(repo_dir, commit_message, git_env) 91 | 92 | 93 | def parse_args() -> argparse.Namespace: 94 | parser = argparse.ArgumentParser( 95 | description="Sync modeling_open_provence_standalone.py into HF model repos without git-lfs.", 96 | ) 97 | parser.add_argument( 98 | "--models", 99 | nargs="*", 100 | default=DEFAULT_MODELS, 101 | help="Hugging Face model IDs to update (defaults to the four models in README.md).", 102 | ) 103 | parser.add_argument( 104 | "--base-dir", 105 | type=Path, 106 | default=Path("tmp/release_models"), 107 | help="Local directory for cloning HF model repos.", 108 | ) 109 | parser.add_argument( 110 | "--message", 111 | default=f"chore: update standalone file ({datetime.now().date().isoformat()})", 112 | help="Git commit message to use for pushes.", 113 | ) 114 | return parser.parse_args() 115 | 116 | 117 | def main() -> None: 118 | args = parse_args() 119 | update_models(models=args.models, base_dir=args.base_dir, commit_message=args.message) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | 
name = "open-provence" 3 | version = "0.1.0" 4 | description = "OpenProvence: efficient and robust context pruning for retrieval-augmented generation" 5 | license = { text = "MIT" } 6 | readme = "README.md" 7 | authors = [ 8 | { name = "OpenProvence Contributors", email = "hotchpotch@gmail.com" } 9 | ] 10 | maintainers = [ 11 | { name = "OpenProvence Contributors", email = "hotchpotch@gmail.com" } 12 | ] 13 | requires-python = ">=3.11" 14 | keywords = [ 15 | "Query-dependent pruning", 16 | "Text pruning", 17 | "RAG", 18 | "Retrieval-Augmented Generation", 19 | "Transformer Networks", 20 | "PyTorch", 21 | "NLP", 22 | "deep learning", 23 | ] 24 | classifiers = [ 25 | "Development Status :: 5 - Production/Stable", 26 | "Intended Audience :: Science/Research", 27 | "License :: OSI Approved :: MIT License", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | "Programming Language :: Python :: 3.13", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | ] 33 | dependencies = [ 34 | "transformers>=4.57.1", 35 | "tqdm", 36 | "torch>=2.8.0,<2.9", 37 | "scikit-learn", 38 | "scipy", 39 | "huggingface-hub>=0.20.0", 40 | "Pillow", 41 | "typing_extensions>=4.5.0", 42 | "datasets==2.20.0", 43 | "sentencepiece>=0.2.0", 44 | "einops>=0.8.1", 45 | "protobuf>=6.31.1", 46 | "bunkai>=1.5.7", 47 | "langdetect>=1.0.9", 48 | "accelerate>=0.26.0", 49 | "wandb>=0.21.0", 50 | "matplotlib>=3.9.4", 51 | "nltk>=3.9.1", 52 | "fast-bunkai>=0.1.0", 53 | ] 54 | 55 | [project.urls] 56 | Homepage = "https://github.com/hotchpotch/open_provence" 57 | Repository = "https://github.com/hotchpotch/open_provence" 58 | Documentation = "https://github.com/hotchpotch/open_provence/tree/main/docs" 59 | Issues = "https://github.com/hotchpotch/open_provence/issues" 60 | 61 | 62 | [project.optional-dependencies] 63 | train = ["datasets", "accelerate>=0.20.3"] 64 | dev = [ 65 | "datasets", 66 | "accelerate>=0.20.3", 67 | "pre-commit", 68 | "pytest", 69 | "pytest-cov", 70 | "pytest-xdist>=3.6.1", 71 | ] 72 | flash-attn = ["flash-attn>=2.7.4.post1"] 73 | 74 | [project.scripts] 75 | open_provence_trainer = "open_provence.trainer_cli:main" 76 | 77 | [build-system] 78 | requires = ["setuptools>=42", "wheel"] 79 | build-backend = "setuptools.build_meta" 80 | 81 | [tool.setuptools.packages.find] 82 | include = ["open_provence*"] 83 | namespaces = false 84 | 85 | [tool.ruff] 86 | target-version = "py311" 87 | line-length = 99 88 | fix = true 89 | src = ["open_provence", "tests", "scripts"] 90 | extend-exclude = [ 91 | "configs", 92 | "debug_output", 93 | "docs", 94 | "htmlcov", 95 | "log", 96 | "output", 97 | "results", 98 | "open_provence.egg-info", 99 | "tmp", 100 | "utils", 101 | "wandb", 102 | "**/.mypy_cache", 103 | "**/.pytest_cache", 104 | ".tox", 105 | "venv", 106 | ".venv", 107 | ] 108 | include = ["**/*.py"] 109 | 110 | [tool.ruff.lint] 111 | select = [ 112 | "E", 113 | "F", 114 | "W", 115 | "I", 116 | "UP", 117 | ] 118 | ignore = [ 119 | "E203", # Whitespace before ':' 120 | "E501", # Line too long (82 > 79 characters) 121 | "D105", # undocumented-magic-method 122 | "D107", # undocumented-public-init 123 | "D205", # blank-line-after-summary 124 | "D415", # ends-in-punctuation 125 | # DoNotAssignLambda 126 | "E731" 127 | ] 128 | 129 | [tool.ruff.lint.per-file-ignores] 130 | "examples/**" = [ 131 | # Ignore `E402` (import violations) in all examples 132 | "E402", 133 | # Ignore missing required imports 134 | "I002" 135 | ] 136 | "docs/**" = [ 137 | # Ignore missing required 
imports 138 | "I002" 139 | ] 140 | 141 | [tool.ruff.lint.isort] 142 | known-third-party = ["datasets"] 143 | required-imports = ["from __future__ import annotations"] 144 | 145 | [tool.ruff.lint.pydocstyle] 146 | convention = "google" 147 | 148 | [tool.ruff.format] 149 | quote-style = "double" 150 | 151 | [tool.pytest.ini_options] 152 | testpaths = [ 153 | "tests" 154 | ] 155 | addopts = "--strict-markers -m 'not slow and not custom'" 156 | markers = [ 157 | "slow: marks tests as slow", 158 | "custom: marks tests for third-party models with custom modules" 159 | ] 160 | 161 | [tool.pyright] 162 | pythonVersion = "3.11" 163 | pythonPlatform = "Linux" 164 | typeCheckingMode = "standard" 165 | reportMissingImports = "none" # External dependencies may not have stubs 166 | reportUnusedImport = "warning" # Allow unused imports that are re-exported 167 | reportUnusedClass = true 168 | reportUnusedFunction = true 169 | reportUnusedVariable = "warning" # Common in unpacking 170 | reportDuplicateImport = true 171 | reportOptionalSubscript = false 172 | reportOptionalMemberAccess = false 173 | reportOptionalCall = false 174 | reportOptionalIterable = false 175 | reportOptionalContextManager = false 176 | reportOptionalOperand = false 177 | exclude = [ 178 | "configs", 179 | "debug_output", 180 | "docs", 181 | "log", 182 | "output", 183 | "results", 184 | "open_provence.egg-info", 185 | "tmp", 186 | "utils", 187 | "wandb", 188 | "**/__pycache__", 189 | ".venv", 190 | "venv", 191 | ".tox", 192 | ] 193 | include = [ 194 | "open_provence", 195 | "tests", 196 | "scripts", 197 | ] 198 | 199 | [[tool.uv.index]] 200 | name = "torch-cpu" 201 | url = "https://download.pytorch.org/whl/cpu" 202 | explicit = true 203 | 204 | [[tool.uv.index]] 205 | name = "torch-cu128" 206 | url = "https://download.pytorch.org/whl/cu128" 207 | explicit = true 208 | 209 | [tool.uv.sources] 210 | torch = [ 211 | { index = "torch-cpu", group = "cpu" }, 212 | { index = "torch-cu128", group = "cuda", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" }, 213 | ] 214 | 215 | [tool.uv] 216 | default-groups = ["dev", "cuda"] 217 | conflicts = [ 218 | [ 219 | { group = "cpu" }, 220 | { group = "cuda" }, 221 | ], 222 | ] 223 | 224 | [dependency-groups] 225 | cpu = [ 226 | "torch>=2.8.0,<2.9", 227 | ] 228 | 229 | cuda = [ 230 | "torch>=2.8.0,<2.9", 231 | ] 232 | 233 | dev = [ 234 | "litellm>=1.77.7", 235 | "openai>=2.3.0", 236 | "pyright>=1.1.406", 237 | "pytest>=8.4.1", 238 | "pytest-xdist>=3.6.1", 239 | "sentence-transformers>=5.1.1", 240 | "ruff>=0.6.9", 241 | "tox-uv>=1.29.0", 242 | "wandb>=0.21.0", 243 | # "vllm>=0.9.0.1", 244 | "trafilatura>=2.0.0", 245 | "spacy>=3.8.7", 246 | ] 247 | 248 | flash-attn = [ 249 | "flash-attn>=2.7.4.post1", 250 | ] 251 | -------------------------------------------------------------------------------- /tests/test_trainer_sampling.py: -------------------------------------------------------------------------------- 1 | """Tests for dataset sampling logic in ``open_provence.trainer``.""" 2 | 3 | from __future__ import annotations 4 | 5 | import random 6 | from pathlib import Path 7 | from typing import Any, cast 8 | 9 | import pytest 10 | from datasets import Dataset, DatasetDict 11 | from open_provence.trainer import ( 12 | DataArguments, 13 | _sample_dataset_randomly, 14 | prepare_dataset, 15 | sample_items_by_label_priority, 16 | ) 17 | 18 | 19 | def test_sample_dataset_randomly_is_deterministic() -> None: 20 | dataset = Dataset.from_dict({"value": list(range(10))}) 21 | 22 | rnd_first = 
random.Random(42) 23 | rnd_second = random.Random(42) 24 | 25 | sampled_first = _sample_dataset_randomly(dataset, 3, rnd_first, "test") 26 | sampled_second = _sample_dataset_randomly(dataset, 3, rnd_second, "test") 27 | 28 | assert sampled_first["value"] == sampled_second["value"] 29 | assert len(sampled_first) == 3 30 | 31 | 32 | def test_sample_dataset_randomly_returns_original_if_large_request() -> None: 33 | dataset = Dataset.from_dict({"value": [1, 2, 3]}) 34 | rnd = random.Random(42) 35 | 36 | same_dataset = _sample_dataset_randomly(dataset, 5, rnd, "test") 37 | assert same_dataset is dataset 38 | 39 | 40 | def test_sample_dataset_randomly_rejects_non_positive_sample_size() -> None: 41 | dataset = Dataset.from_dict({"value": [1, 2, 3]}) 42 | rnd = random.Random(42) 43 | 44 | with pytest.raises(ValueError): 45 | _sample_dataset_randomly(dataset, 0, rnd, "test") 46 | 47 | 48 | def _build_dataset(size: int = 10, validation_size: int = 6) -> DatasetDict: 49 | data = { 50 | "query": [f"q{i}" for i in range(size)], 51 | "positive": [f"pos{i}" for i in range(size)], 52 | "negative": [f"neg{i}" for i in range(size)], 53 | "teacher_score": [float(i) for i in range(size)], 54 | } 55 | validation = { 56 | "query": [f"vq{i}" for i in range(validation_size)], 57 | "positive": [f"vpos{i}" for i in range(validation_size)], 58 | "negative": [f"vneg{i}" for i in range(validation_size)], 59 | "teacher_score": [float(i) for i in range(validation_size)], 60 | } 61 | return DatasetDict( 62 | { 63 | "train": Dataset.from_dict(data), 64 | "validation": Dataset.from_dict(validation), 65 | } 66 | ) 67 | 68 | 69 | def test_prepare_dataset_supports_local_paths(tmp_path: Path) -> None: 70 | dataset = _build_dataset() 71 | dataset_path = tmp_path / "local_ds" 72 | dataset.save_to_disk(dataset_path) 73 | 74 | data_args = DataArguments( 75 | dataset_name="unused", 76 | subset="default", 77 | teacher_column="teacher_score", 78 | datasets=[ 79 | { 80 | "dataset_name": str(dataset_path), 81 | "teacher_column": "teacher_score", 82 | } 83 | ], 84 | ) 85 | 86 | train_dataset, eval_dataset = prepare_dataset(data_args, seed=13) 87 | 88 | assert len(train_dataset) == len(dataset["train"]) 89 | assert len(eval_dataset) == len(dataset["validation"]) 90 | 91 | 92 | def test_prepare_dataset_applies_n_samples(monkeypatch: pytest.MonkeyPatch) -> None: 93 | def fake_load_dataset(name: str, subset: str | None = None) -> DatasetDict: 94 | return _build_dataset() 95 | 96 | monkeypatch.setattr("open_provence.trainer.load_dataset", fake_load_dataset) 97 | 98 | data_args = DataArguments( 99 | dataset_name="dummy", 100 | subset="default", 101 | teacher_column="teacher_score", 102 | datasets=[ 103 | { 104 | "dataset_name": "dummy", 105 | "subset": "default", 106 | "teacher_column": "teacher_score", 107 | "n_samples": 5, 108 | } 109 | ], 110 | ) 111 | 112 | train_dataset, eval_dataset = prepare_dataset(data_args, seed=42) 113 | 114 | assert len(train_dataset) == 5 115 | assert len(eval_dataset) == 3 116 | 117 | # Deterministic sampling: rerunning with the same seed yields identical results 118 | train_dataset_again, eval_dataset_again = prepare_dataset(data_args, seed=42) 119 | assert train_dataset_again["query"] == train_dataset["query"] 120 | assert eval_dataset_again["query"] == eval_dataset["query"] 121 | 122 | 123 | def test_prepare_dataset_accepts_fractional_n_samples(monkeypatch: pytest.MonkeyPatch) -> None: 124 | def fake_load_dataset(name: str, subset: str | None = None) -> DatasetDict: 125 | return _build_dataset() 126 | 127 | 
monkeypatch.setattr("open_provence.trainer.load_dataset", fake_load_dataset) 128 | 129 | data_args = DataArguments( 130 | dataset_name="dummy", 131 | subset="default", 132 | teacher_column="teacher_score", 133 | datasets=[ 134 | { 135 | "dataset_name": "dummy", 136 | "subset": "default", 137 | "teacher_column": "teacher_score", 138 | "n_samples": 0.2, 139 | } 140 | ], 141 | ) 142 | 143 | train_dataset, eval_dataset = prepare_dataset(data_args, seed=42) 144 | 145 | assert len(train_dataset) == 2 # ceil(10 * 0.2) 146 | assert len(eval_dataset) == 2 # ceil(6 * 0.2) 147 | 148 | train_dataset_again, eval_dataset_again = prepare_dataset(data_args, seed=42) 149 | assert train_dataset_again["query"] == train_dataset["query"] 150 | assert eval_dataset_again["query"] == eval_dataset["query"] 151 | 152 | 153 | def test_sample_items_handles_missing_labels() -> None: 154 | dataset = Dataset.from_dict( 155 | { 156 | "texts": [ 157 | ["doc0", "doc1", "doc2", "doc3"], 158 | ["doc4", "doc5", "doc6"], 159 | ], 160 | "teacher_score": [ 161 | [0.9, 0.1, 0.2, 0.3], 162 | [0.8, 0.6, 0.2], 163 | ], 164 | "extra": [ 165 | ["meta0", "meta1", "meta2", "meta3"], 166 | ["meta4", "meta5", "meta6"], 167 | ], 168 | } 169 | ) 170 | 171 | sampled = sample_items_by_label_priority(dataset, max_items=2, seed=7) 172 | 173 | sampled_rows = [cast(dict[str, Any], row) for row in sampled] 174 | 175 | for row in sampled_rows: 176 | assert len(cast(list[Any], row["texts"])) == 2 177 | assert len(cast(list[Any], row["teacher_score"])) == 2 178 | assert len(cast(list[Any], row["extra"])) == 2 179 | 180 | # Deterministic across runs with same seed 181 | sampled_again = [ 182 | cast(dict[str, Any], row) 183 | for row in sample_items_by_label_priority(dataset, max_items=2, seed=7) 184 | ] 185 | assert sampled_again == sampled_rows 186 | -------------------------------------------------------------------------------- /scripts/utils/sync_output_modeling.py: -------------------------------------------------------------------------------- 1 | """Synchronise modeling_open_provence_standalone.py files in output directories.""" 2 | 3 | from __future__ import annotations 4 | 5 | import argparse 6 | import json 7 | import sys 8 | from dataclasses import dataclass 9 | from pathlib import Path 10 | from typing import Iterable, TextIO 11 | 12 | 13 | _DEPRECATED_CONFIG_KEYS: tuple[str, ...] = ( 14 | "splitter_default_language", 15 | "standalone_process_default_language", 16 | "modeling_open_provence_default_language", 17 | ) 18 | 19 | 20 | @dataclass 21 | class TargetState: 22 | modeling_path: Path 23 | config_path: Path | None 24 | modeling_needs_update: bool 25 | config_needs_update: bool 26 | removed_keys: tuple[str, ...] 
27 | 28 | def requires_action(self) -> bool: 29 | return self.modeling_needs_update or self.config_needs_update 30 | 31 | 32 | def _load_base_content(base_file: Path) -> str: 33 | if not base_file.exists(): 34 | raise FileNotFoundError(f"Base modeling file not found: {base_file}") 35 | return base_file.read_text(encoding="utf-8") 36 | 37 | 38 | def _evaluate_config(modeling_path: Path) -> tuple[Path | None, bool, tuple[str, ...]]: 39 | config_path = modeling_path.with_name("config.json") 40 | if not config_path.exists(): 41 | return None, False, () 42 | 43 | try: 44 | config = json.loads(config_path.read_text(encoding="utf-8")) 45 | except json.JSONDecodeError: 46 | return config_path, False, () 47 | 48 | if config.get("model_type") != "open_provence": 49 | return config_path, False, () 50 | 51 | removed_keys = tuple(key for key in _DEPRECATED_CONFIG_KEYS if key in config) 52 | return config_path, bool(removed_keys), removed_keys 53 | 54 | 55 | def _gather_target_states(base_content: str, output_dir: Path) -> list[TargetState]: 56 | if not output_dir.exists(): 57 | return [] 58 | 59 | states: list[TargetState] = [] 60 | for modeling_path in sorted(output_dir.rglob("modeling_open_provence_standalone.py")): 61 | current_content = modeling_path.read_text(encoding="utf-8") 62 | modeling_needs_update = current_content != base_content 63 | config_path, config_needs_update, removed_keys = _evaluate_config(modeling_path) 64 | states.append( 65 | TargetState( 66 | modeling_path=modeling_path, 67 | config_path=config_path, 68 | modeling_needs_update=modeling_needs_update, 69 | config_needs_update=config_needs_update, 70 | removed_keys=removed_keys, 71 | ) 72 | ) 73 | return states 74 | 75 | 76 | def parse_args() -> argparse.Namespace: 77 | parser = argparse.ArgumentParser( 78 | description="Copy the latest modeling_open_provence_standalone.py into every output run (dry run by default)." 
79 | ) 80 | parser.add_argument( 81 | "--overwrite", 82 | action="store_true", 83 | help="Apply changes; without this flag the script reports pending updates (dry run).", 84 | ) 85 | parser.add_argument( 86 | "--output-dir", 87 | type=Path, 88 | default=Path("output"), 89 | help="Root directory that contains run outputs (default: ./output).", 90 | ) 91 | return parser.parse_args() 92 | 93 | 94 | def plan_sync(base_file: Path, output_dir: Path) -> list[TargetState]: 95 | base_content = _load_base_content(base_file) 96 | return _gather_target_states(base_content, output_dir) 97 | 98 | 99 | def _format_removed_keys(keys: Iterable[str]) -> str: 100 | formatted = ", ".join(sorted(keys)) 101 | return formatted if formatted else "" 102 | 103 | 104 | def sync_targets( 105 | base_file: Path, 106 | output_dir: Path, 107 | overwrite: bool, 108 | *, 109 | stream: TextIO = sys.stdout, 110 | ) -> None: 111 | if not output_dir.exists(): 112 | print(f"No output directory found at {output_dir}", file=stream) 113 | return 114 | 115 | base_content = _load_base_content(base_file) 116 | targets = _gather_target_states(base_content, output_dir) 117 | if not targets: 118 | print("No matching modeling_open_provence_standalone.py files found.", file=stream) 119 | return 120 | 121 | mode = "Applying updates" if overwrite else "Planned updates" 122 | print(f"{mode} for {len(targets)} target(s):", file=stream) 123 | 124 | any_pending = False 125 | 126 | for state in targets: 127 | header = f"- {state.modeling_path}" 128 | if not overwrite: 129 | if not state.requires_action(): 130 | print(f"{header} → SKIP (already up to date)", file=stream) 131 | continue 132 | 133 | any_pending = True 134 | if state.modeling_needs_update: 135 | print( 136 | f"{header} → would copy latest modeling_open_provence_standalone.py", 137 | file=stream, 138 | ) 139 | if state.config_needs_update: 140 | removed = _format_removed_keys(state.removed_keys) 141 | print(f"{header} → would remove deprecated config keys: {removed}", file=stream) 142 | continue 143 | 144 | # overwrite 145 | if state.modeling_needs_update: 146 | state.modeling_path.write_text(base_content, encoding="utf-8") 147 | print(f"{header} → copied modeling_open_provence_standalone.py", file=stream) 148 | else: 149 | print(f"{header} → SKIP (already up to date)", file=stream) 150 | 151 | if state.config_needs_update and state.config_path is not None: 152 | config_path = state.config_path 153 | config = json.loads(config_path.read_text(encoding="utf-8")) 154 | for key in state.removed_keys: 155 | config.pop(key, None) 156 | config_path.write_text( 157 | json.dumps(config, ensure_ascii=False, indent=2) + "\n", 158 | encoding="utf-8", 159 | ) 160 | removed = _format_removed_keys(state.removed_keys) 161 | print(f"{header} → removed deprecated config keys: {removed}", file=stream) 162 | 163 | if not overwrite and any_pending: 164 | print("Re-run with --overwrite to apply these updates.", file=stream) 165 | 166 | 167 | def main() -> None: 168 | args = parse_args() 169 | repo_root = Path(__file__).resolve().parents[2] 170 | base_file = repo_root / "open_provence" / "modeling_open_provence_standalone.py" 171 | output_dir = args.output_dir 172 | if not output_dir.is_absolute(): 173 | output_dir = repo_root / output_dir 174 | 175 | sync_targets(base_file, output_dir, overwrite=args.overwrite) 176 | 177 | 178 | if __name__ == "__main__": 179 | main() 180 | -------------------------------------------------------------------------------- /scripts/hf_utils/hf_model_process_check.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import argparse 4 | from collections.abc import Iterable, Sequence 5 | from dataclasses import dataclass 6 | 7 | from transformers import AutoModel 8 | 9 | DEFAULT_MODELS: tuple[str, ...] = ( 10 | "hotchpotch/open-provence-reranker-v1", 11 | "hotchpotch/open-provence-reranker-xsmall-v1", 12 | "hotchpotch/open-provence-reranker-large-v1", 13 | "hotchpotch/open-provence-reranker-v1-gte-modernbert-base", 14 | ) 15 | 16 | question: str = "What's your favorite Japanese food?" 17 | context: str = """ 18 | Work deadlines piled up today, and I kept rambling about budget spreadsheets to my roommate. 19 | Next spring I'm planning a trip to Japan so I can wander Kyoto's markets and taste every regional dish I find. 20 | Sushi is honestly my favourite—I want to grab a counter seat and let the chef serve endless nigiri until I'm smiling through soy sauce. 21 | Later I remembered to water the plants and pay the electricity bill before finally getting some sleep. 22 | """ 23 | 24 | 25 | @dataclass 26 | class Case: 27 | name: str 28 | question: str | Sequence[str] 29 | # allow up to 3-level nesting: queries -> docs -> sentences 30 | context: str | Sequence[str] | Sequence[Sequence[str]] | Sequence[Sequence[Sequence[str]]] 31 | 32 | 33 | @dataclass 34 | class SampleResult: 35 | case: str 36 | sample: str 37 | score: float | None 38 | compression: float 39 | pruned: str | None 40 | 41 | 42 | def build_cases() -> list[Case]: 43 | questions = [question, question] 44 | contexts = [context, context] 45 | 46 | context_sentences = [line for line in context.splitlines(True) if line.strip()] 47 | context_sentences_wrapped = [context_sentences] 48 | contexts_nested = [context_sentences_wrapped, context_sentences_wrapped] 49 | 50 | return [ 51 | Case("q=str, c=str", question, context), 52 | Case("q=list[str], c=list[str]", questions, contexts), 53 | Case("q=str, c=list[str] (split sentences)", question, context_sentences), 54 | Case( 55 | "q=str, c=list[list[str]] (split sentences, single doc)", 56 | question, 57 | context_sentences_wrapped, 58 | ), 59 | Case( 60 | "q=list[str], c=list[list[str]] (split sentences per query)", 61 | questions, 62 | contexts_nested, 63 | ), 64 | ] 65 | 66 | 67 | def _iter_samples( 68 | pruned_context, rerank_score, compression_rate 69 | ) -> Iterable[tuple[str, str | None, float | None, float]]: 70 | if not isinstance(pruned_context, list): 71 | yield "", pruned_context, rerank_score, compression_rate 72 | return 73 | 74 | for idx, text in enumerate(pruned_context): 75 | text_str = "\n".join(text) if isinstance(text, list) else text 76 | 77 | score = rerank_score[idx] if isinstance(rerank_score, list) else rerank_score 78 | compression = ( 79 | compression_rate[idx] if isinstance(compression_rate, list) else compression_rate 80 | ) 81 | 82 | if isinstance(score, list): 83 | score = score[0] if score else None 84 | if isinstance(compression, list): 85 | compression = compression[0] if compression else 0.0 86 | 87 | yield f"#{idx}", text_str, score, float(compression) 88 | 89 | 90 | def run_cases(model, threshold: float, verbose: bool) -> list[SampleResult]: 91 | results: list[SampleResult] = [] 92 | for case in build_cases(): 93 | result = model.process( 94 | question=case.question, 95 | context=case.context, 96 | threshold=threshold, 97 | show_progress=verbose, 98 | ) 99 | for sample_tag, pruned, score, compression in _iter_samples( 100 | 
result["pruned_context"], 101 | result["reranking_score"], 102 | result["compression_rate"], 103 | ): 104 | results.append( 105 | SampleResult( 106 | case=case.name, 107 | sample=sample_tag, 108 | score=None if score is None else float(score), 109 | compression=float(compression), 110 | pruned=pruned if verbose else None, 111 | ) 112 | ) 113 | return results 114 | 115 | 116 | def _format_table(rows: list[SampleResult]) -> str: 117 | headers = ["Case", "Sample", "Rerank score", "Compression"] 118 | data: list[list[str]] = [] 119 | for row in rows: 120 | sample = row.sample or "-" 121 | score = "-" if row.score is None else f"{row.score:.4f}" 122 | compression = f"{row.compression:.2f}" 123 | data.append([row.case, sample, score, compression]) 124 | 125 | col_widths = [max(len(item[i]) for item in ([headers] + data)) for i in range(len(headers))] 126 | 127 | def fmt_row(items: Sequence[str]) -> str: 128 | return " | ".join(item.ljust(col_widths[idx]) for idx, item in enumerate(items)) 129 | 130 | divider = "-+-".join("-" * width for width in col_widths) 131 | lines = [fmt_row(headers), divider] 132 | lines.extend(fmt_row(row) for row in data) 133 | return "\n".join(lines) 134 | 135 | 136 | def parse_args() -> argparse.Namespace: 137 | parser = argparse.ArgumentParser( 138 | description="Smoke-test the four HF models using the run.py sample inputs.", 139 | ) 140 | parser.add_argument( 141 | "--models", 142 | nargs="*", 143 | default=DEFAULT_MODELS, 144 | help="Hugging Face model IDs to load (default: README models).", 145 | ) 146 | parser.add_argument( 147 | "--threshold", 148 | type=float, 149 | default=0.1, 150 | help="Pruning threshold passed to model.process.", 151 | ) 152 | parser.add_argument( 153 | "--verbose", 154 | action="store_true", 155 | help="Print pruned text for each sample in addition to the summary table.", 156 | ) 157 | return parser.parse_args() 158 | 159 | 160 | def main() -> None: 161 | args = parse_args() 162 | for model_id in args.models: 163 | print(f"\n=== {model_id} ===") 164 | model = AutoModel.from_pretrained(model_id, trust_remote_code=True) 165 | model.eval() 166 | rows = run_cases(model, threshold=args.threshold, verbose=args.verbose) 167 | 168 | if args.verbose: 169 | for row in rows: 170 | if row.pruned is None: 171 | continue 172 | print(f"\n-- {row.case} {row.sample or ''}".strip()) 173 | print("Pruned context:\n" + row.pruned) 174 | print(f"Rerank score: {row.score}") 175 | print(f"Compression: {row.compression:.2f}") 176 | 177 | print("\n" + _format_table(rows)) 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /docs/eval_dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset Evaluation Guide 2 | 3 | `scripts/eval_datasets.py` measures how many annotated evidence spans survive pruning across a configuration of context-relevance datasets. This document explains how to run the CLI, what each configuration does, and how to interpret the generated artefacts. It intentionally omits score tables so the instructions remain evergreen. 4 | 5 | ## 1. What the script checks 6 | 7 | For each dataset in a config file, the script: 8 | 9 | 1. Loads the dataset from Hugging Face (e.g., `hotchpotch/msmarco-context-relevance`). 10 | 2. Runs `model.process()` to prune each passage. 11 | 3. Compares the pruned spans against the labelled evidence annotations. 12 | 4. 
Computes span-level precision, recall, and an F2 score (Fβ with β = 2, i.e. recall-weighted), plus mean compression. 13 | 14 | Dropping relevant spans (false negatives) is more damaging than keeping surplus context, so F2 = 5 · P · R / (4 · P + R) is the headline metric. 15 | 16 | ## 2. Known gotchas in the datasets 17 | 18 | - Some datasets contain very long passages (>60 k characters). If you hit memory errors, temporarily limit evaluation via `--limit`, use the nano configs, or regenerate the dataset with shorter spans. 19 | - A small number of queries in the multilingual sets are malformed or language-mismatched. The published configs already omit the worst offenders. If you uncover new issues, send a PR to update the source dataset rather than editing the evaluation script. 20 | - Compression percentages are per-dataset averages; for heterogeneous corpora (e.g., GooAQ vs. JA-focused Wikipedia), expect different baseline compression even at the same threshold. 21 | 22 | ## 3. Config files 23 | 24 | All configs live under `configs/eval_datasets/`: 25 | 26 | | File | Purpose | 27 | | --- | --- | 28 | | `ja.yaml`, `en.yaml` | Full evaluation suites (all datasets, full sample counts). | 29 | | `ja_nano.yaml`, `en_nano.yaml` | “Nano” subsets with per-dataset `n_samples` overrides for quick smoke tests. Use these when iterating on code or verifying regressions; they run 10–20× faster. | 30 | 31 | Each entry in a config looks like: 32 | 33 | ```yaml 34 | - dataset_name: hotchpotch/msmarco-context-relevance 35 | subset: default 36 | n_samples: 100 # only in *_nano.yaml 37 | ``` 38 | 39 | The script reads each row sequentially. You can clone a file and add/remove datasets for ad‑hoc scenarios. Optional keys: 40 | 41 | - `split`: override the global split from the YAML header. 42 | - `n_samples`: cap the number of records loaded from that dataset (only present in `*_nano.yaml`). 43 | 44 | ## 4. Core command template 45 | 46 | ```bash 47 | uv run python scripts/eval_datasets.py \ 48 | --config CONFIG_PATH \ 49 | --model MODEL_DIR \ 50 | --threshold 0.1 \ 51 | --batch-size 256 \ 52 | --timing-details \ 53 | --output-json tmp/eval_