├── tests └── conftest.py ├── yast ├── modeling │ ├── __init__.py │ ├── splade.py │ └── splade_subword.py ├── utils.py ├── log_metrics.py ├── custom_dataset │ ├── mmarco.py │ ├── japanese_splade_hn_v1.py │ └── hpprc_emb_scores.py ├── run.py ├── arguments.py ├── trainer.py ├── regularizers.py ├── losses.py └── data.py ├── examples ├── japanese-splade │ ├── README.md │ ├── toy.yaml │ ├── japanese-splade-base-v1-mmarco-only.yaml │ ├── japanese-splade-base-v1.yaml │ ├── japanese-splade-base-v1-with-toy.yaml │ ├── japanese_splade_base_v2.yaml │ └── japanese_splade_base_v1_5.yaml └── toy_datasets │ └── japanese │ └── toy_dataset_japanese.jsonl ├── LICENSE ├── pyproject.toml ├── CLAUDE.md ├── README.md ├── .gitignore └── utils └── JMTEB_L0.py /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.insert(0, str(Path(__file__).resolve().parents[1])) 5 | -------------------------------------------------------------------------------- /yast/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .splade import Splade 2 | from .splade_subword import SpladeSubword 3 | 4 | __all__ = ["Splade", "SpladeSubword"] 5 | -------------------------------------------------------------------------------- /yast/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def seed_everything(seed: int = 42): 9 | random.seed(seed) 10 | os.environ["PYTHONHASHSEED"] = str(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | torch.backends.cudnn.deterministic = True  # type: ignore 15 | -------------------------------------------------------------------------------- /examples/japanese-splade/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 日本語 SPLADE モデルの学習方法 4 | 5 | [japanese-splade-base-v1](https://huggingface.co/hotchpotch/japanese-splade-base-v1) 等の学習方法です。他のバージョン(v2など)の学習用 yaml も同梱されています。 6 | 7 | ## ライブラリのセットアップ 8 | 9 | yast のルートディレクトリで実行します。 10 | 11 | ``` 12 | uv sync --extra dev 13 | ``` 14 | 15 | ## toy データセットの学習 16 | 17 | サンプルデータセットの学習です。データセットが小さすぎて、きちんとしたモデルは作れませんが、データのサンプルとして。 18 | 19 | ``` 20 | uv run python -m yast.run ./examples/japanese-splade/toy.yaml 21 | ``` 22 | 23 | ## japanese-splade-base-v1 の学習 24 | 25 | ``` 26 | uv run python -m yast.run ./examples/japanese-splade/japanese-splade-base-v1.yaml 27 | ``` 28 | 29 | ## japanese-splade-base-v1-mmarco-only の学習 30 | 31 | ``` 32 | uv run python -m yast.run ./examples/japanese-splade/japanese-splade-base-v1-mmarco-only.yaml 33 | ``` 34 | 35 | ## japanese-splade-base-v1 と toy データセットの学習 36 | 37 | データセットを作った場合、データセットを混ぜての学習が可能です。 38 | 39 | ``` 40 | uv run python -m yast.run ./examples/japanese-splade/japanese-splade-base-v1-with-toy.yaml 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yuichi Tateno (@hotchpotch) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "yast" 3 | version = "0.1.0" 4 | description = "YAST - Yet Another SPLADE or Sparse Trainer" 5 | authors = [ 6 | {name = "Yuichi Tateno", email = "hotchpotch@gmail.com"} 7 | ] 8 | license = {text = "MIT"} 9 | readme = "README.md" 10 | requires-python = ">=3.11" 11 | dependencies = [ 12 | "transformers>=4.45.0", 13 | "datasets>=3.0.0", 14 | "torch>=2.7.0", 15 | "torchvision", 16 | "joblib>=1.1.0", 17 | "wandb>=0.16.0", 18 | "accelerate>=1.0.0", 19 | "einops>=0.8.1", 20 | ] 21 | 22 | [project.optional-dependencies] 23 | dev = [ 24 | "ruff>=0.7.0", 25 | "yasem>=0.3.1", 26 | "fugashi>=1.3.2", 27 | "unidic-lite>=1.0.8", 28 | ] 29 | 30 | [build-system] 31 | requires = ["hatchling"] 32 | build-backend = "hatchling.build" 33 | 34 | [tool.hatch.build.targets.wheel] 35 | packages = ["yast"] 36 | 37 | # PyTorch CUDA 12.8 configuration for uv 38 | [[tool.uv.index]] 39 | name = "pytorch-cu128" 40 | url = "https://download.pytorch.org/whl/cu128" 41 | explicit = true 42 | 43 | [tool.uv.sources] 44 | torch = [ 45 | { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 46 | ] 47 | torchvision = [ 48 | { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 49 | ] 50 | -------------------------------------------------------------------------------- /yast/log_metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class LogMetrics: 8 | @staticmethod 9 | def L0(batch_tensor: torch.Tensor) -> torch.Tensor: 10 | return torch.count_nonzero(batch_tensor, dim=-1).float().mean() 11 | 12 | def __init__(self): 13 | self._init() 14 | 15 | def _init(self): 16 | self.metrics = defaultdict(list) 17 | 18 | def clear(self): 19 | self._init() 20 | 21 | def add(self, key: str, value: float | torch.Tensor): 22 | if isinstance(value, torch.Tensor): 23 | value = value.cpu().detach().item() 24 | self.metrics[key].append(float(value)) 25 | 26 | def add_dict(self, metrics: dict[str, float | torch.Tensor]): 27 | for key, value in metrics.items(): 28 | self.add(key, value) 29 | 30 | def _process(self, np_func): 31 | return {key: float(np_func(values)) for key, values in self.metrics.items()} 32 | 33 | def mean(self): 34 | return self._process(np.mean) 35 | 36 | def max(self): 37 | return self._process(np.max) 38 | 39 | def min(self): 40 | return self._process(np.min) 41 | 42 | def std(self): 43 | 
return self._process(np.std) 44 | 45 | def median(self): 46 | return self._process(np.median) 47 | -------------------------------------------------------------------------------- /examples/japanese-splade/toy.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | dataloader_drop_last: true 3 | dataloader_num_workers: 12 4 | gradient_accumulation_steps: 1 5 | per_device_train_batch_size: 2 6 | learning_rate: 2.0e-05 7 | logging_steps: 1 8 | lr_scheduler_type: cosine 9 | max_grad_norm: 1.0 10 | max_length: 512 11 | model_name_or_path: tohoku-nlp/bert-base-japanese-v3 12 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? @ [ \ ] ^ _ ` { | } ~ 13 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 14 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 15 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 16 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 17 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 18 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 19 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 20 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' 21 | num_train_epochs: 12 22 | output_dir: ./output/toy_japanese 23 | overwrite_output_dir: true 24 | regularizer_doc: L1 25 | regularizer_query: L1 26 | remove_checkpoints: true 27 | run_name: toy_japanese 28 | save_steps: 5000 29 | save_total_limit: 2 30 | seed: 42 31 | sparsity_warmup_steps_doc: 0.1 32 | sparsity_warmup_steps_query: 0.1 33 | sparsity_weight_doc: 0.01 34 | sparsity_weight_query: 0.025 35 | train_data: 36 | - ./examples/toy_datasets/japanese/ 37 | train_group_size: 8 38 | training_losses: cross_entropy 39 | warmup_ratio: 0.05 40 | weight_decay: 0 41 | -------------------------------------------------------------------------------- /examples/japanese-splade/japanese-splade-base-v1-mmarco-only.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | dataloader_drop_last: true 3 | dataloader_num_workers: 12 4 | gradient_accumulation_steps: 32 5 | per_device_train_batch_size: 4 6 | learning_rate: 2.0e-05 7 | logging_steps: 200 8 | lr_scheduler_type: cosine 9 | max_grad_norm: 1.0 10 | max_length: 512 11 | model_name_or_path: tohoku-nlp/bert-base-japanese-v3 12 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? 
@ [ \ ] ^ _ ` { | } ~ 13 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 14 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 15 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 16 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 17 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 18 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 19 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 20 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' 21 | num_train_epochs: 12 22 | output_dir: ./output/japanese-splade-base-v1-mmarco-only 23 | overwrite_output_dir: true 24 | regularizer_doc: L1 25 | regularizer_query: L1 26 | remove_checkpoints: true 27 | run_name: japanese-splade-base-v1-mmarco-only 28 | save_steps: 5000 29 | save_total_limit: 2 30 | seed: 42 31 | sparsity_warmup_steps_doc: 0.1 32 | sparsity_warmup_steps_query: 0.1 33 | sparsity_weight_doc: 0.01 34 | sparsity_weight_query: 0.025 35 | train_data: 36 | - dataset_class: yast.custom_dataset.mmarco.MMarcoHardNegatives 37 | train_data: 38 | lang: japanese 39 | reranker: bge-reranker-v2-m3 40 | train_group_size: 8 41 | training_losses: cross_entropy 42 | warmup_ratio: 0.05 43 | weight_decay: 0 44 | -------------------------------------------------------------------------------- /examples/japanese-splade/japanese-splade-base-v1.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | dataloader_drop_last: true 3 | dataloader_num_workers: 12 4 | gradient_accumulation_steps: 8 5 | per_device_train_batch_size: 4 6 | learning_rate: 2.0e-05 7 | logging_steps: 200 8 | lr_scheduler_type: cosine 9 | max_grad_norm: 1.0 10 | max_length: 512 11 | model_name_or_path: tohoku-nlp/bert-base-japanese-v3 12 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? 
@ [ \ ] ^ _ ` { | } ~ 13 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 14 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 15 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 16 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 17 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 18 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 19 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 20 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' 21 | num_train_epochs: 2 22 | output_dir: ./output/japanese-splade-base-v1 23 | overwrite_output_dir: true 24 | regularizer_doc: L1 25 | regularizer_query: L1 26 | remove_checkpoints: true 27 | run_name: japanese-splade-base-v1 28 | save_steps: 5000 29 | save_total_limit: 2 30 | seed: 42 31 | sparsity_warmup_steps_doc: 0.1 32 | sparsity_warmup_steps_query: 0.1 33 | sparsity_weight_doc: 0.001 34 | sparsity_weight_query: 0.0025 35 | train_data: 36 | - dataset_class: yast.custom_dataset.mmarco.MMarcoHardNegatives 37 | train_data: 38 | lang: english 39 | reranker: bge-reranker-v2-m3 40 | - dataset_class: yast.custom_dataset.hpprc_emb_scores.HpprcEmbScoresDataset 41 | train_data: 42 | - subset: auto-wiki-qa 43 | - subset: mmarco 44 | - subset: jsquad 45 | - subset: jaquad 46 | - subset: auto-wiki-qa-nemotron 47 | - subset: quiz-works 48 | - subset: quiz-no-mori 49 | - aug_factor: 5 50 | subset: miracl 51 | - aug_factor: 8 52 | subset: jqara 53 | - aug_factor: 5 54 | subset: mr-tydi 55 | - aug_factor: 3 56 | subset: baobab-wiki-retrieval 57 | - subset: mkqa 58 | train_group_size: 8 59 | training_losses: cross_entropy 60 | warmup_ratio: 0.05 61 | weight_decay: 0 62 | -------------------------------------------------------------------------------- /examples/japanese-splade/japanese-splade-base-v1-with-toy.yaml: -------------------------------------------------------------------------------- 1 | bf16: true 2 | dataloader_drop_last: true 3 | dataloader_num_workers: 12 4 | gradient_accumulation_steps: 8 5 | per_device_train_batch_size: 4 6 | learning_rate: 2.0e-05 7 | logging_steps: 200 8 | lr_scheduler_type: cosine 9 | max_grad_norm: 1.0 10 | max_length: 512 11 | model_name_or_path: tohoku-nlp/bert-base-japanese-v3 12 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? 
@ [ \ ] ^ _ ` { | } ~ 13 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 14 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 15 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 16 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 17 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 18 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 19 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 20 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' 21 | num_train_epochs: 2 22 | output_dir: ./output/japanese-splade-base-v1 23 | overwrite_output_dir: true 24 | regularizer_doc: L1 25 | regularizer_query: L1 26 | remove_checkpoints: true 27 | run_name: japanese-splade-base-v1 28 | save_steps: 5000 29 | save_total_limit: 2 30 | seed: 42 31 | sparsity_warmup_steps_doc: 0.1 32 | sparsity_warmup_steps_query: 0.1 33 | sparsity_weight_doc: 0.001 34 | sparsity_weight_query: 0.0025 35 | train_data: 36 | - ./examples/toy_datasets/japanese/ # toy dataset 37 | - dataset_class: yast.custom_dataset.mmarco.MMarcoHardNegatives 38 | train_data: 39 | lang: english 40 | reranker: bge-reranker-v2-m3 41 | - dataset_class: yast.custom_dataset.hpprc_emb_scores.HpprcEmbScoresDataset 42 | train_data: 43 | - subset: auto-wiki-qa 44 | - subset: mmarco 45 | - subset: jsquad 46 | - subset: jaquad 47 | - subset: auto-wiki-qa-nemotron 48 | - subset: quiz-works 49 | - subset: quiz-no-mori 50 | - aug_factor: 5 51 | subset: miracl 52 | - aug_factor: 8 53 | subset: jqara 54 | - aug_factor: 5 55 | subset: mr-tydi 56 | - aug_factor: 3 57 | subset: baobab-wiki-retrieval 58 | - subset: mkqa 59 | train_group_size: 8 60 | training_losses: cross_entropy 61 | warmup_ratio: 0.05 62 | weight_decay: 0 63 | -------------------------------------------------------------------------------- /examples/japanese-splade/japanese_splade_base_v2.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: hotchpotch/japanese-splade-base-v1_5 2 | train_data: 3 | - 4 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 5 | dataset_options: 6 | dataset_name: msmarco-ja 7 | hard_positives: true 8 | target_model_name: "japanese-splade-base-v1_5" 9 | - 10 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 11 | dataset_options: 12 | hard_positives: true 13 | target_model_name: "japanese-splade-base-v1_5" 14 | - 15 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 16 | dataset_options: 17 | dataset_name: mqa 18 | hard_positives: false 19 | target_model_name: "japanese-splade-base-v1_5" 20 | - 21 | dataset_class: yast.custom_dataset.mmarco.MMarcoHardNegatives 22 | train_data: 23 | reranker: "bge-reranker-v2-m3" 24 | lang: "english" 25 | max_length: 512 26 | output_dir: AUTO 27 | # learning_rate: 4.0e-5 28 | optim: "adafactor" 29 | num_train_epochs: 3 30 | per_device_train_batch_size: 4 31 | gradient_accumulation_steps: 32 32 | warmup_ratio: 0.05 33 | lr_scheduler_type: cosine 34 | bf16: true 35 | dataloader_drop_last: true 36 | logging_steps: 25 37 | max_grad_norm: 1.0 38 | dataloader_num_workers: 12 # 12 39 | overwrite_output_dir: true 40 | save_total_limit: 2 41 | save_steps: 5000 42 | training_losses: 43 | cross_entropy: 44 | weight: 1.0 45 | kl_div: 46 | loss_kwargs: 47 | temperature: 0.5 48 | weight: 3.5 49 | 
weight_decay: 0 50 | train_group_size: 8 51 | sparsity_weight_doc: 0.35 52 | sparsity_weight_query: 0.15 53 | sparsity_warmup_steps_doc: 0.1 54 | sparsity_warmup_steps_query: 0.1 55 | regularizer_doc: flops 56 | regularizer_query: L1 57 | seed: 42 58 | remove_checkpoints: true 59 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? @ [ \ ] ^ _ ` { | } ~ 60 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 61 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 62 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 63 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 64 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 65 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 66 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 67 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' -------------------------------------------------------------------------------- /examples/japanese-splade/japanese_splade_base_v1_5.yaml: -------------------------------------------------------------------------------- 1 | model_name_or_path: hotchpotch/ruri-pt-base-retromae 2 | train_data: 3 | - 4 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 5 | dataset_options: 6 | dataset_name: msmarco-ja 7 | hard_positives: true 8 | - 9 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 10 | dataset_options: 11 | hard_positives: true 12 | - 13 | dataset_class: yast.custom_dataset.japanese_splade_hn_v1.JapaneseSpladeHardNegativesV1 14 | dataset_options: 15 | dataset_name: mqa 16 | hard_positives: true 17 | - 18 | dataset_class: yast.custom_dataset.mmarco.MMarcoHardNegatives 19 | train_data: 20 | reranker: "bge-reranker-v2-m3" 21 | lang: "english" 22 | - 23 | dataset_class: yast.custom_dataset.hpprc_emb_scores.HpprcEmbScoresDataset 24 | train_data: 25 | - 26 | subset: auto-wiki-qa 27 | n: 10000 28 | target_score_keys: ["ruri-reranker-large"] 29 | - subset: jsquad 30 | n: 10000 31 | target_score_keys: ["ruri-reranker-large"] 32 | - subset: jaquad 33 | n: 10000 34 | target_score_keys: ["ruri-reranker-large"] 35 | - subset: auto-wiki-qa-nemotron 36 | n: 40000 37 | target_score_keys: ["ruri-reranker-large"] 38 | - subset: quiz-works 39 | target_score_keys: ["ruri-reranker-large"] 40 | - subset: quiz-no-mori 41 | target_score_keys: ["ruri-reranker-large"] 42 | - subset: baobab-wiki-retrieval 43 | aug_factor: 3 44 | target_score_keys: ["ruri-reranker-large"] 45 | - subset: mkqa 46 | target_score_keys: ["ruri-reranker-large"] 47 | max_length: 512 48 | output_dir: AUTO 49 | # learning_rate: 4.0e-5 50 | optim: "adafactor" 51 | num_train_epochs: 2 52 | per_device_train_batch_size: 4 53 | gradient_accumulation_steps: 32 54 | warmup_ratio: 0.05 55 | lr_scheduler_type: cosine 56 | bf16: true 57 | dataloader_drop_last: true 58 | logging_steps: 25 59 | max_grad_norm: 1.0 60 | dataloader_num_workers: 12 # 12 61 | overwrite_output_dir: true 62 | save_total_limit: 2 63 | save_steps: 1000 64 | training_losses: cross_entropy 65 | weight_decay: 0 66 | train_group_size: 8 67 | sparsity_weight_doc: 0.1 68 | sparsity_weight_query: 0.03 69 | sparsity_warmup_steps_doc: 0.1 70 | sparsity_warmup_steps_query: 0.1 71 | regularizer_doc: flops 72 | regularizer_query: L1 73 | seed: 42 74 | remove_checkpoints: true 75 | noise_tokens: '" 〠 ! # $ % & '' ( ) * + , - . / : ; < = > ? 
@ [ \ ] ^ _ ` { | } ~ 76 | ¡ ¢ £ ¤ ¥ ¦ § © « ¬ ® ° ± ¶ · » ¿ Å × ÷ ħ Щ щ ъ א ิ ლ ‐ – — ― ‖ † ‡ • ′ ※ 77 | ‿ ⁂ ⁑ € ℧ ← ↑ → ↓ ↔ ↖ ↗ ↘ ↙ ⇄ ⇒ ⇔ ⇦ ⇧ ⇨ ⇩ ∀ ∂ ∃ ∅ ∇ ∈ ∉ ∋ − ∓ √ ∝ ∞ ∟ ∠ ∥ ∦ ∧ ∨ 78 | ∩ ∪ ∫ ∮ ∴ ∵ ∽ ≃ ≅ ≈ ≒ ≠ ≡ ≢ ≦ ≧ ≪ ≫ ≶ ≷ ⊂ ⊃ ⊄ ⊅ ⊆ ⊇ ⊊ ⊋ ⊕ ⊖ ⊗ ⊥ ⊿ ⋚ ⋛ ⌅ ⌆ ⌒ ⌘ ⎾ 79 | ⎿ ⏀ ⏁ ⏂ ⏃ ⏄ ⏅ ⏆ ⏇ ⏈ ⏉ ⏊ ⏋ ⏌ ⏎ ⓫ ⓬ ⓭ ⓮ ⓯ ⓰ ⓱ ⓲ ⓳ ⓴ ⓵ ⓶ ⓷ ⓸ ⓹ ⓺ ⓻ ⓼ ⓽ ⓾ ─ ━ ┌ ┐ ┘ 80 | ├ ╹ ■ □ ▱ ▲ △ ▶ ▷ ▼ ▽ ◀ ◁ ◆ ◇ ◉ ○ ◎ ● ◐ ◑ ◒ ◓ ◡ ◦ ◯ ☀ ☁ ☂ ☃ ★ ☆ ☎ ☖ ☗ ☞ ♀ ♂ ♠ ♡ 81 | ♢ ♣ ♤ ♥ ♦ ♧ ♨ ♩ ♪ ♫ ♬ ♭ ♮ ♯ ✓ ❖ ❶ ❷ ❸ ❹ ❺ ❻ ❼ ❽ ❾ ❿ ⤴ ⤵ ⦅ ⦆ ⦿ ⧺ ⧻ 、 。 〃 々 〇 〈 〉 82 | 《 》 「 」 『 』 【 】 〒 〓 〔 〕 〖 〗 〘 〙 〜 〝 〟 〠 〳 〴 〵 〻 〽 ぁ ぃ ぅ ぇ ぉ っ ゝ ゞ ゠ ァ ゥ ェ ォ ッ ・ 83 | ー ヽ ヾ 丿 仝 屮 彡 ﹅ ﹆ ]、' -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | YAST (Yet Another SPLADE or Sparse Trainer) is an experimental implementation of a SPLADE trainer that works with Huggingface's Trainer API. The project focuses on training sparse neural information retrieval models, particularly for Japanese language applications. 8 | 9 | **Important**: This is an experimental repository with frequent breaking changes. Code should be treated accordingly. 10 | 11 | ## Development Environment 12 | 13 | This project uses **uv** for dependency management and requires **Python 3.11**. The project was migrated from Poetry to uv for faster dependency resolution and improved development experience. 14 | 15 | ## Architecture 16 | 17 | ### Core Components 18 | 19 | - **Training Entry Point**: `yast/run.py` - Main training script that accepts YAML/JSON configuration files 20 | - **Model Architecture**: `yast/modeling/` - Contains SPLADE model implementations: 21 | - `splade.py` - Base SPLADE model 22 | - `splade_subword.py` - Subword-aware SPLADE variant 23 | - **Training Logic**: `yast/trainer.py` - Custom trainer extending HuggingFace Trainer 24 | - **Data Pipeline**: `yast/data.py` - Dataset creation and collation logic 25 | - **Custom Datasets**: `yast/custom_dataset/` - Domain-specific dataset implementations 26 | 27 | ### Configuration System 28 | 29 | Training is driven by YAML configuration files that specify: 30 | - Model parameters (ModelArguments) 31 | - Data parameters (DataArguments) 32 | - Training parameters (SpladeTrainingArguments extending HuggingFace TrainingArguments) 33 | - Run parameters (RunArguments) 34 | 35 | Examples are in `examples/japanese-splade/` directory. 36 | 37 | ## Common Commands 38 | 39 | ### Environment Setup 40 | ```bash 41 | # Initial setup 42 | uv venv --python 3.11 .venv # Create virtual environment 43 | uv sync --extra dev # Install dependencies with dev extras 44 | 45 | # Daily development 46 | source .venv/bin/activate # Activate virtual environment (optional) 47 | # OR use uv run for direct command execution without activation 48 | ``` 49 | 50 | ### Training 51 | ```bash 52 | # Train with YAML config 53 | uv run python -m yast.run path/to/config.yaml 54 | 55 | # Train with JSON config 56 | uv run python -m yast.run path/to/config.json 57 | ``` 58 | 59 | ### Code Quality 60 | ```bash 61 | uv run ruff check . # Lint code 62 | uv run ruff format . 
# Format code 63 | ``` 64 | 65 | ### Package Management 66 | ```bash 67 | uv add package_name # Add new dependency 68 | uv add --dev package_name # Add development dependency 69 | uv remove package_name # Remove dependency 70 | uv sync # Sync dependencies with lockfile 71 | uv lock # Update lockfile 72 | ``` 73 | 74 | ### Development Dependencies 75 | - `yasem>=0.3.1` - Related SPLADE embedder project 76 | - `fugashi` + `unidic-lite` - Japanese text processing 77 | 78 | ## Key Configuration Parameters 79 | 80 | - `sparsity_weight_doc/query`: Controls sparsity regularization 81 | - `regularizer_doc/query`: Type of regularization (L1, L2, flops, etc.) 82 | - `training_losses`: Can be single loss or dict with multiple losses and weights 83 | - `subword_pooling`: Enables subword-aware model variant 84 | - `trust_remote_code`: Required for some model loading scenarios 85 | 86 | ## Testing 87 | 88 | No formal test suite is currently implemented. The project uses example configurations for validation. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YAST - Yet Another SPLADE or Sparse Trainer 🚀 2 | 3 | Welcome to YAST! This open-source project provides a powerful and flexible SPLADE (Sparse Lexical and Expansion) trainer. Built to integrate seamlessly with Huggingface's Trainer API, YAST allows you to leverage cutting-edge sparse retrieval techniques based on various SPLADE-related research papers. Our goal is to offer an accessible tool for training these models. YAST is licensed under the permissive MIT License. 4 | 5 | ## ⚠️ Important Notice 6 | 7 | Please note that YAST is currently an **experimental** project. This means you might encounter **breaking changes** introduced from time to time. To ensure a stable experience, we highly recommend **forking** this repository and working with a specific **revision (commit hash)**. 8 | 9 | ## Development Setup 10 | 11 | This project uses [uv](https://docs.astral.sh/uv/) for dependency management and requires Python 3.11. 12 | 13 | ### Prerequisites 14 | 15 | - Python 3.11+ 16 | - [uv](https://docs.astral.sh/uv/getting-started/installation/) package manager 17 | 18 | ### Quick Start 19 | 20 | ```bash 21 | # Clone the repository 22 | git clone https://github.com/hotchpotch/yast.git 23 | cd yast 24 | 25 | # Create virtual environment and install dependencies 26 | uv venv --python 3.11 .venv 27 | uv sync --extra dev 28 | 29 | # Activate virtual environment (optional - you can use uv run instead) 30 | source .venv/bin/activate 31 | 32 | # Run training example 33 | uv run python -m yast.run examples/japanese-splade/toy.yaml 34 | ``` 35 | 36 | ### Optional: Flash Attention 2 for Performance 37 | 38 | For improved training speed, install Flash Attention 2: 39 | 40 | ```bash 41 | uv pip install --no-deps flash-attn --no-build-isolation 42 | uv pip install einops 43 | ``` 44 | 45 | **Note**: Requires a compatible CUDA GPU and may take time to compile. 46 | 47 | ## Training a Japanese SPLADE Model 48 | 49 | For details on training a Japanese SPLADE model, please see the [Japanese SPLADE example](./examples/japanese-splade/README.md). This document is written in Japanese (日本語で書かれています). If you don't read Japanese, online translation tools can be helpful for understanding the content. 
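### Sanity-Checking a Trained Model (Optional)

After a run completes, a quick way to confirm that the checkpoint produces sensibly sparse vectors is to load it with [YASEM](https://github.com/hotchpotch/yasem) (included in the dev extras) and measure the average L0, mirroring what `utils/JMTEB_L0.py` does for JMTEB. This is only a minimal sketch, assuming `SpladeEmbedder` can load the local `output_dir` from your config (here, the toy config's `./output/toy_japanese`):

```python
import numpy as np
from yasem import SpladeEmbedder

# Assumption: the toy config wrote its checkpoint to ./output/toy_japanese
embedder = SpladeEmbedder("./output/toy_japanese")

texts = [
    "東京の人気観光スポットを教えて",  # query taken from the toy dataset
    "東京スカイツリーは東京を代表する観光名所で、展望台からは都市の絶景を楽しめます",
]

# encode() returns a CSR matrix when convert_to_csr_matrix=True,
# so the per-row non-zero counts give the L0 of each sparse vector.
vectors = embedder.encode(texts, convert_to_csr_matrix=True)
print("mean L0:", np.mean(np.diff(vectors.indptr)))
```

If the mean L0 is close to zero, the sparsity regularization was likely too strong; if it approaches the vocabulary size, it was too weak.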
50 | 51 | ### Related Blog Posts (Content in Japanese) 52 | 53 | Here are some blog posts related to this project, written in Japanese: 54 | - [高性能な日本語SPLADE(スパース検索)モデルを公開しました](https://secon.dev/entry/2024/10/07/100000/) 55 | - [SPLADE モデルの作り方・日本語SPLADEテクニカルレポート](https://secon.dev/entry/2024/10/23/080000-japanese-splade-tech-report/) 56 | - [情報検索モデルで最高性能(512トークン以下)・日本語版SPLADE v2をリリース](https://secon.dev/entry/2024/12/19/100000-japanese-splade-v2-release/) 57 | 58 | 59 | ## 💡 Related Work 60 | 61 | Another project, [YASEM (Yet Another Splade | Sparse Embedder)](https://github.com/hotchpotch/yasem), offers a more user-friendly implementation for working with SPLADE models. 62 | 63 | ## 🙏 Acknowledgments 64 | 65 | We thank the researchers behind the original SPLADE papers for their outstanding contributions to this field. 66 | 67 | ## References 68 | 69 | - [SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720) 70 | - [SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval](https://arxiv.org/abs/2109.10086) 71 | - [From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective](http://arxiv.org/abs/2205.04733) 72 | - [An Efficiency Study for SPLADE Models](https://dl.acm.org/doi/10.1145/3477495.3531833) 73 | - [A Static Pruning Study on Sparse Neural Retrievers](https://arxiv.org/abs/2304.12702) 74 | - [SPLADE-v3: New baselines for SPLADE](https://arxiv.org/abs/2403.06789) 75 | - [Minimizing FLOPs to Learn Efficient Sparse Representations](https://arxiv.org/abs/2004.05665) 76 | 77 | ## License 78 | 79 | This project is licensed under the MIT License. See the LICENSE file for full license details. 80 | Copyright (c) 2024 Yuichi Tateno (@hotchpotch) 81 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | tmp/ 165 | outputs/ 166 | output/ 167 | wandb/ 168 | -------------------------------------------------------------------------------- /yast/modeling/splade.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import AutoModelForMaskedLM, PreTrainedModel 7 | 8 | from ..arguments import ModelArguments 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class SpladeMaxPooling(nn.Module): 14 | """ 15 | SPLADE Max pooling implementation based on: 16 | 17 | "SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval" 18 | Formal et al., 2021 (https://arxiv.org/abs/2109.10086) 19 | 20 | This implements the max pooling variant introduced in SPLADE v2, which replaced 21 | the original sum pooling with max pooling over the sequence length dimension: 22 | w_j = max_{i in t} log(1 + ReLU(w_ij)) 23 | 24 | The pooling operation consists of: 25 | 1. Applying ReLU activation 26 | 2. Adding 1 and taking log: log(1 + ReLU(x)) 27 | 3. Max pooling over sequence length dimension 28 | """ 29 | 30 | def __init__(self): 31 | super().__init__() 32 | self.relu = nn.ReLU() 33 | 34 | def forward(self, output, attention_mask): 35 | """ 36 | Forward pass of SPLADE Max. 37 | 38 | Args: 39 | output (torch.Tensor): Input tensor of shape (batch_size, sequence_length, vocab_size) 40 | attention_mask (torch.Tensor): Attention mask of shape (batch_size, sequence_length) 41 | 42 | Returns: 43 | torch.Tensor: Output tensor of shape (batch_size, vocab_size) 44 | """ 45 | if output.dim() != 3 or attention_mask.dim() != 2: 46 | raise ValueError("Invalid input dimensions") 47 | 48 | if output.size(0) != attention_mask.size(0) or output.size( 49 | 1 50 | ) != attention_mask.size(1): 51 | raise ValueError("Mismatched batch size or sequence length") 52 | 53 | activated = torch.log(1 + self.relu(output)) 54 | masked = activated * attention_mask.unsqueeze(-1) 55 | values, _ = torch.max(masked, dim=1) 56 | 57 | return values 58 | 59 | 60 | class Splade(nn.Module): 61 | def __init__( 62 | self, 63 | hf_model: PreTrainedModel, 64 | model_args: ModelArguments, 65 | ): 66 | super().__init__() 67 | self.hf_model = hf_model 68 | self.model_args = model_args 69 | self.splade_max = SpladeMaxPooling() 70 | 71 | self.config = self.hf_model.config 72 | self._keys_to_ignore_on_save = getattr( 73 | self.hf_model, "_keys_to_ignore_on_save", None 74 | ) 75 | 76 | def gradient_checkpointing_enable(self, **kwargs): 77 | self.hf_model.gradient_checkpointing_enable(**kwargs) 78 | 79 | def _logit_to_query_docs(self, logits: torch.Tensor, batch_size: int): 80 | logits_shape = logits.shape 81 | assert len(logits_shape) == 2 82 | query_with_docs_size = int((logits_shape[0] / batch_size)) 83 | vocab_size = logits.shape[-1] 84 | 85 | state = logits.view( 86 | batch_size, 87 | query_with_docs_size, 88 | vocab_size, 89 | ) 90 | queries = state[:, :1, :] 91 | docs = state[:, 1:, :] 92 | return queries, docs 93 | 94 | def forward(self, batch_inputs: dict, batch_size: int): 95 | output = self.hf_model(**batch_inputs, return_dict=True).logits 96 | attention_mask = batch_inputs["attention_mask"] 97 | logits = self.splade_max(output, attention_mask) 98 | queries, docs = self._logit_to_query_docs(logits, batch_size) 99 | 100 | return queries, docs 101 | 102 | @classmethod 103 | def from_pretrained(cls, model_args: ModelArguments, *args, **kwargs): 104 | hf_model = 
AutoModelForMaskedLM.from_pretrained(*args, **kwargs) 105 | # for resume training 106 | model_args_path = os.path.join( 107 | kwargs.get("pretrained_model_name_or_path", ""), "model_args.bin" 108 | ) 109 | if os.path.exists(model_args_path): 110 | model_args = torch.load(model_args_path) 111 | splade = cls(hf_model, model_args) 112 | return splade 113 | 114 | def save_pretrained(self, output_dir: str): 115 | state_dict = self.hf_model.state_dict() 116 | state_dict = type(state_dict)( 117 | {k: v.clone().cpu() for k, v in state_dict.items()} 118 | ) 119 | self.hf_model.save_pretrained(output_dir, state_dict=state_dict) 120 | # for resume training 121 | torch.save(self.model_args, os.path.join(output_dir, "model_args.bin")) 122 | -------------------------------------------------------------------------------- /examples/toy_datasets/japanese/toy_dataset_japanese.jsonl: -------------------------------------------------------------------------------- 1 | {"query": "最近の携帯電話の機能について教えてください", "positives": ["スマートフォンには通話機能の他、カメラ、GPS、インターネット接続など多彩な機能が搭載されています"], "negatives": ["携帯電話の料金プランには、通話やデータ使用量に応じて多種多様な選択肢が提供されています", "携帯電話の歴史は、1980年代の大きな端末から始まり、現在のスマートフォンのような小型で高性能な機器へと進化しました", "スマートフォンの充電はUSB-Cなどの規格に統一されつつありますが、急速充電の有無も重要な要素です", "携帯電話のリサイクルは、端末の分解と部品の再利用を通じて、電子ゴミの削減に貢献しています", "携帯電話の電波は、基地局との間で通信を行い、都市部では特に多くの基地局が配置されています", "スマートフォンの選び方は、用途や予算に応じた機能やブランドの比較がポイントとなります", "携帯電話の画面には、液晶ディスプレイや有機EL(OLED)など、多様な種類があり、用途に応じて使い分けが可能です"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 2 | {"query": "東京の人気観光スポットを教えて", "positives": ["東京スカイツリーは東京を代表する観光名所で、展望台からは都市の絶景を楽しめます"], "negatives": ["東京の気候は四季折々に変化し、春には桜、秋には紅葉が見どころです", "東京の人口は約1400万人で、世界でも有数の大都市です", "東京の不動産市場は価格が高く、特に中心部では物件価格が急騰しています", "東京の交通機関は電車やバスが網羅されており、移動の利便性が非常に高いです", "東京は23の区と多くの市で構成され、各地域ごとに独自の特色を持っています", "東京の教育制度は公立・私立の学校が共存し、国際的な学校も多く存在しています", "東京の産業は金融、IT、観光、製造業など、多岐にわたる経済活動が展開されています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 3 | {"query": "効果的な英語学習方法について", "positives": ["英語学習では、毎日少しずつ継続的に学習することが上達の鍵となります"], "negatives": ["英語の歴史は中世の英語から始まり、現在の国際共通語としての地位を確立するまでに多くの変遷を経ました", "英語圏の文化には、アメリカ、イギリス、オーストラリアなどの地域ごとに異なる特徴が見られます", "英語の試験制度にはTOEIC、IELTS、TOEFLなど、目的に応じた多様な種類があります", "英語教師になるための資格には、TEFLやTESOLなど、さまざまな認定プログラムがあります", "英語の発音記号を学ぶことで、正しい発音を理解しやすくなります", "英語の文法書はレベルや用途に応じて多岐にわたる選択肢があります", "英語辞書にはオンライン辞書や単語学習アプリなど、多くの種類が提供されています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 4 | {"query": "健康的な食生活のアドバイス", "positives": ["バランスの良い食事には、主食、主菜、副菜をバランスよく取り入れることが大切です"], "negatives": ["食品の保存方法は冷蔵、冷凍、乾燥など、食品の種類に応じて適切な手段を選ぶ必要があります", "食品添加物は保存料や着色料など、食品の品質を保つために使用されますが、過剰摂取は避けるべきです", "食品の値段は季節や輸入状況に応じて変動することが多いです", "食品表示を正しく読むことで、栄養素やアレルギー情報を把握することができます", "食品衛生法は、安全で安心な食品の供給を確保するための法律です", "食品アレルギーは特定の食品に対して免疫系が過剰反応する状態を指します", "食品の賞味期限と消費期限を理解し、食品ロスを減らすことが重要です"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 5 | {"query": "プログラミング入門について", "positives": ["プログラミングを始めるには、まず基本的な概念と簡単な命令の使い方を理解することが重要です"], "negatives": ["プログラミング言語の歴史は、初期のアセンブリ言語から始まり、現在の高水準言語に至るまで多様な進化を遂げました", "プログラマーの給与は業界やスキルに応じて大きく異なり、フリーランスの需要も高まっています", "プログラミングの資格は、特定の言語や分野に特化した認定を受けることで、キャリアアップにつながります", "プログラミングスクールでは、短期間で実践的なスキルを学べるカリキュラムが提供されています", "プログラミング関連の求人は、IT企業から金融業界まで幅広い分野で募集されています", "プログラミングコンテストでは、参加者が問題解決能力を競い合い、優秀者は企業から注目されます", "プログラミング関連の書籍には、初心者向けから上級者向けまで、多種多様な内容が含まれています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 6 | {"query": "効果的な運動方法について", "positives": 
["有酸素運動と筋力トレーニングを組み合わせることで、より効果的な運動効果が得られます"], "negatives": ["運動靴は用途やスポーツに応じて適切なタイプを選ぶ必要があります", "運動器具の価格は、家庭用とジム用で大きな違いが見られます", "運動施設の場所は、アクセスの良さや設備の充実度が選択のポイントとなります", "運動会の企画は、参加者の年齢層や目的に合わせたプログラムが重要です", "運動療法は、特定の病状に対するリハビリとして医師や専門家の指導のもと行われます", "運動部の歴史は、学校や地域の発展とともに進化してきました", "運動用品メーカーは、品質と価格のバランスを取った製品を提供しています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 7 | {"query": "睡眠の質を改善する方法", "positives": ["快適な睡眠のためには、就寝時間を規則正しく保ち、寝室の環境を整えることが重要です"], "negatives": ["睡眠薬は医師の処方が必要であり、適切な使用が重要です", "睡眠時間の統計データは、国や年齢によって平均値が異なります", "睡眠障害の診断には、専門医の評価と睡眠検査が必要です", "睡眠に関する研究論文では、最新の科学的知見が紹介されています", "睡眠グッズは、枕やマットレスなど、快適な睡眠をサポートするためのアイテムが多岐にわたります", "睡眠クリニックは、慢性的な睡眠障害を診断・治療する専門施設です", "睡眠に関する迷信は、科学的根拠に基づかない俗説が多く存在します"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 8 | {"query": "効率的な時間管理の方法", "positives": ["時間管理の基本は優先順位をつけ、重要なタスクから計画的に取り組むことです"], "negatives": ["時間管理アプリにはタスク管理や時間計測ができるものが多く、用途に合わせて選ぶことが重要です", "時間管理の歴史は産業革命時代のタイムマネジメント手法にさかのぼることができます", "時間管理セミナーでは、ビジネスや個人生活における効果的な方法を学べます", "時間管理に関する研究データは、仕事のパフォーマンス向上に関連する要因として注目されています", "時間管理ツールの価格は、無料から高額なものまで幅広く、機能によって異なります", "時間管理の専門家は、コンサルタントやビジネスコーチとして活躍しています", "時間管理の書籍には、個人の体験談や具体的な戦略を紹介したものが多く出版されています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 9 | {"query": "家庭菜園の始め方", "positives": ["家庭菜園を始めるには、日当たりの良い場所を選び、季節に合った野菜を選ぶことが大切です"], "negatives": ["家庭菜園に使用する道具は、シャベル、ジョウロ、プランターなど基本的なアイテムが必要です", "家庭菜園の害虫対策には、無農薬の方法や生物農薬の使用が推奨されます", "家庭菜園の肥料には、化学肥料と有機肥料があり、育てる植物に応じて選ぶことが重要です", "家庭菜園に関する書籍は、初心者向けからプロ向けまで幅広く出版されています", "家庭菜園での失敗例は、土壌の選択ミスや過剰な水やりなど、学ぶべきポイントが多くあります", "家庭菜園のブログでは、実体験を通じた栽培のコツや失敗談がシェアされています", "家庭菜園に関するコンテストでは、独創的なアイデアやデザインが評価されます"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} 10 | {"query": "写真撮影のコツについて", "positives": ["良い写真を撮るためには、構図、光の使い方、シャッタースピードの基本を理解することが重要です"], "negatives": ["カメラの価格は、初心者向けからプロ用機材まで幅広く、性能に応じて異なります", "写真展では、さまざまなテーマの作品が展示され、プロとアマチュアの作品を比較する機会が得られます", "写真用品を購入する際は、実店舗だけでなくオンラインショップも利用できます", "写真家の経歴は、個人の作品や活動によって異なり、独自のスタイルを持つことが重要です", "写真の歴史は、ダゲレオタイプの発明から始まり、デジタル写真の普及へと続いています", "写真コンテストでは、技術だけでなく、創造性やテーマの解釈が評価されます", "写真の印刷サービスには、高品質なプリントやフォトブックの作成が含まれています"], "positives_score": [1.0], "negatives_score": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]} -------------------------------------------------------------------------------- /utils/JMTEB_L0.py: -------------------------------------------------------------------------------- 1 | """ 2 | JMTEB retrieval データセットから、スパース性を測るためクエリとドキュメントのL0を出力 3 | """ 4 | 5 | import argparse 6 | from typing import Dict, List, cast 7 | 8 | import datasets 9 | import numpy as np 10 | from yasem import SpladeEmbedder 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser( 15 | description="Measure L0 sparsity for multiple models on JMTEB dataset." 
16 | ) 17 | parser.add_argument( 18 | "-m", 19 | "--model_names", 20 | type=str, 21 | nargs="+", 22 | default=["hotchpotch/japanese-splade-base-v1"], 23 | help="List of model names to evaluate.", 24 | ) 25 | return parser.parse_args() 26 | 27 | 28 | def map_corpus_text(example): 29 | title = example.get("title", "") 30 | text = example.get("text", "") 31 | if title: 32 | text = title + " " + text 33 | return {"text": text} 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | model_names: List[str] = args.model_names 39 | 40 | JMTEB_TARGETS = [ 41 | "jaqket", 42 | "mrtydi", 43 | "jagovfaqs_22k", 44 | "nlp_journal_title_abs", 45 | "nlp_journal_title_intro", 46 | "nlp_journal_abs_intro", 47 | ] 48 | JMTEB_QUERY_SPLIT_TARGET = "test" 49 | QUERY_MAX_SAMPLE_SIZE = 1000 50 | CORPUS_MAX_SAMPLE_SIZE = 1000 51 | 52 | # Initialize result dictionary with keys as 'target-query' and 'target-docs' 53 | result_dict: Dict[str, Dict[str, float]] = {} 54 | for target in JMTEB_TARGETS: 55 | result_dict[f"{target}-query"] = {} 56 | result_dict[f"{target}-docs"] = {} 57 | 58 | for model_name in model_names: 59 | print(f"Processing Model: {model_name}") 60 | embedder = SpladeEmbedder(model_name) 61 | 62 | for target in JMTEB_TARGETS: 63 | print(f" Processing Target: {target}") 64 | # Load query dataset 65 | target_query_ds = datasets.load_dataset( 66 | "sbintuitions/JMTEB", 67 | name=f"{target}-query", 68 | split=JMTEB_QUERY_SPLIT_TARGET, 69 | trust_remote_code=True, 70 | ) 71 | target_query_ds = cast(datasets.Dataset, target_query_ds) 72 | 73 | target_query_ds = target_query_ds.select( 74 | range(min(QUERY_MAX_SAMPLE_SIZE, len(target_query_ds))) 75 | ) 76 | 77 | # Load corpus dataset 78 | target_corpus_ds = datasets.load_dataset( 79 | "sbintuitions/JMTEB", 80 | name=f"{target}-corpus", 81 | split="corpus", 82 | trust_remote_code=True, 83 | ) 84 | target_corpus_ds = cast(datasets.Dataset, target_corpus_ds) 85 | target_corpus_ds = target_corpus_ds.select( 86 | range(min(CORPUS_MAX_SAMPLE_SIZE, len(target_corpus_ds))) 87 | ) 88 | target_corpus_ds = target_corpus_ds.map(map_corpus_text, num_proc=4) 89 | 90 | print(f" Query size: {len(target_query_ds)}") 91 | print(f" Docs size: {len(target_corpus_ds)}") 92 | 93 | # Encode queries and documents 94 | query_vectors = embedder.encode( 95 | target_query_ds["query"], 96 | convert_to_csr_matrix=True, 97 | show_progress_bar=True, 98 | ) 99 | corpus_vectors = embedder.encode( 100 | target_corpus_ds["text"], 101 | convert_to_csr_matrix=True, 102 | show_progress_bar=True, 103 | ) 104 | 105 | # Calculate L0 106 | query_L0 = np.mean(np.diff(query_vectors.indptr)) # type:ignore 107 | docs_L0 = np.mean(np.diff(corpus_vectors.indptr)) # type:ignore 108 | 109 | print(f" {target} Queries L0: {query_L0}") 110 | print(f" {target} Docs L0: {docs_L0}") 111 | 112 | # Store results 113 | result_dict[f"{target}-query"][model_name] = query_L0 # type:ignore 114 | result_dict[f"{target}-docs"][model_name] = docs_L0 # type:ignore 115 | 116 | # Generate Markdown Table 117 | print("\n## L0 Sparsity\n") 118 | header = ["Target"] + model_names 119 | print("| " + " | ".join(header) + " |") 120 | print("| " + " | ".join(["---"] * len(header)) + " |") 121 | for target_key in sorted(result_dict.keys()): 122 | row = [target_key] 123 | for model in model_names: 124 | l0_value = result_dict[target_key].get(model, 0.0) 125 | row.append(f"{l0_value:.1f}") 126 | print("| " + " | ".join(row) + " |") 127 | 128 | # Generate CSV Output 129 | print("\n## CSV Output\n") 130 | print("Target," + 
",".join(model_names)) 131 | for target_key in sorted(result_dict.keys()): 132 | row = [target_key] + [ 133 | f"{result_dict[target_key].get(model, 0.0):.1f}" for model in model_names 134 | ] 135 | print(",".join(row)) 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /yast/custom_dataset/mmarco.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import joblib 4 | from datasets import load_dataset 5 | from huggingface_hub import hf_hub_download 6 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 7 | 8 | from ..arguments import DataArguments 9 | from ..data import DatasetForSpladeTraining 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | MMARCO_DATASET = "unicamp-dl/mmarco" 14 | HADR_NEGATIVE_SCORE_DS = "hotchpotch/mmarco-hard-negatives-reranker-score" 15 | 16 | NEG_SCORE_TH = 0.3 17 | POS_SCORE_TH = 0.7 18 | NEG_FILTER_COUNT = 8 19 | 20 | 21 | def _map_filter_score(example, neg_score_th: float, pos_score_th: float): 22 | neg_score: list[float] = example["neg.score"] 23 | neg_score_filtered_index = [ 24 | i for i, score in enumerate(neg_score) if score < neg_score_th 25 | ] 26 | # same pos_score 27 | pos_score = example["pos.score"] 28 | pos_score_filtered_index = [ 29 | i for i, score in enumerate(pos_score) if score > pos_score_th 30 | ] 31 | return { 32 | **example, 33 | "neg.score": [neg_score[i] for i in neg_score_filtered_index], 34 | "neg": [example["neg"][i] for i in neg_score_filtered_index], 35 | "pos.score": [pos_score[i] for i in pos_score_filtered_index], 36 | "pos": [example["pos"][i] for i in pos_score_filtered_index], 37 | } 38 | 39 | 40 | def _filter_score(example, net_filter_count: int): 41 | # neg のカウントがN以上で、pos のカウントが1以上のものを返す 42 | return len(example["neg"]) >= net_filter_count and len(example["pos"]) >= 1 43 | 44 | 45 | class MMarcoHardNegatives(DatasetForSpladeTraining): 46 | def __init__( 47 | self, 48 | args: DataArguments, 49 | tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, 50 | ): 51 | logger.info("Initializing MMarcoHardNegatives dataset") 52 | 53 | self.query_max_len = args.dataset_options.get("query_max_len", 256) 54 | self.doc_max_len = args.dataset_options.get("doc_max_len", 1024) 55 | train_data = args.train_data 56 | 57 | dataset_options = args.dataset_options 58 | self.binarize_label: bool = dataset_options.get("binarize_label", False) 59 | 60 | lang = train_data["lang"] 61 | reranker_name = train_data["reranker"] 62 | neg_score_th = train_data.get("neg_score_th", NEG_SCORE_TH) 63 | pos_score_th = train_data.get("pos_score_th", POS_SCORE_TH) 64 | net_filter_count = train_data.get("net_filter_count", NEG_FILTER_COUNT) 65 | subset = f"{lang}_{reranker_name}" 66 | 67 | mapping = f"mappings/{lang}_joblib.pkl.gz" 68 | 69 | logger.info(f"Downloading mapping file from Hugging Face Hub: {mapping}") 70 | mapping_file = hf_hub_download( 71 | repo_type="dataset", repo_id=HADR_NEGATIVE_SCORE_DS, filename=mapping 72 | ) 73 | 74 | logger.info(f"Loading mapping file: {mapping_file}") 75 | index_mapping_dict = joblib.load(mapping_file) 76 | 77 | self.query_id_dict = index_mapping_dict["query_id_dict"] 78 | self.collection_id_dict = index_mapping_dict["collection_id_dict"] 79 | 80 | logger.info(f"Loading queries dataset for language: {lang}") 81 | self.queries_ds = load_dataset( 82 | MMARCO_DATASET, 83 | "queries-" + lang, 84 | split="train", 85 | trust_remote_code=True, 86 | ) 87 | 
logger.info(f"Loading collection dataset for language: {lang}") 88 | self.collection_ds = load_dataset( 89 | MMARCO_DATASET, 90 | "collection-" + lang, 91 | split="collection", 92 | trust_remote_code=True, 93 | ) 94 | logger.info(f"Loading hard negatives dataset subset: {subset}") 95 | ds = load_dataset(HADR_NEGATIVE_SCORE_DS, subset, split="train") 96 | ds = ds.map( 97 | _map_filter_score, 98 | num_proc=11, # type: ignore 99 | fn_kwargs={"neg_score_th": neg_score_th, "pos_score_th": pos_score_th}, 100 | ) # type: ignore 101 | ds = ds.filter( 102 | _filter_score, num_proc=11, fn_kwargs={"net_filter_count": net_filter_count} 103 | ) # type: ignore 104 | logger.info(f"Filtered dataset size: {len(ds)}") 105 | 106 | super().__init__(args, tokenizer, ds) # type: ignore 107 | 108 | def get_query_text(self, query_id: int) -> str: 109 | idx = self.query_id_dict[query_id] 110 | return self.queries_ds[idx]["text"][0 : self.query_max_len] # type: ignore 111 | 112 | def get_collection_text(self, doc_id: int) -> str: 113 | idx = self.collection_id_dict[doc_id] 114 | return self.collection_ds[idx]["text"][0 : self.doc_max_len] # type: ignore 115 | 116 | def __getitem__(self, item) -> list[dict]: 117 | qid = self.dataset[item]["qid"] 118 | query = self.get_query_text(qid) 119 | pos_ids = self.dataset[item]["pos"] 120 | pos_ids_score = self.dataset[item]["pos.score"] 121 | neg_ids = self.dataset[item]["neg"] 122 | neg_ids_score = self.dataset[item]["neg.score"] 123 | 124 | pos_texts = [self.get_collection_text(pos_id) for pos_id in pos_ids] 125 | neg_texts = [self.get_collection_text(neg_id) for neg_id in neg_ids] 126 | 127 | if self.binarize_label: 128 | pos_ids_score = [1.0] * len(pos_ids_score) 129 | neg_ids_score = [0.0] * len(neg_ids_score) 130 | 131 | return self.create_batch_inputs( 132 | query, 133 | pos_texts, 134 | neg_texts, 135 | pos_ids_score, 136 | neg_ids_score, 137 | ) 138 | -------------------------------------------------------------------------------- /yast/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from: https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/reranker/run.py 3 | License: MIT License 4 | """ 5 | 6 | import logging 7 | import os 8 | import shutil 9 | import sys 10 | from pathlib import Path 11 | 12 | from transformers import ( 13 | AutoTokenizer, 14 | HfArgumentParser, 15 | set_seed, 16 | ) 17 | from transformers.trainer_utils import get_last_checkpoint 18 | 19 | from yast.modeling.splade_subword import SpladeSubword 20 | 21 | from .arguments import ( 22 | DataArguments, 23 | ModelArguments, 24 | RunArguments, 25 | SpladeTrainingArguments, 26 | ) 27 | from .data import ( 28 | GroupCollator, 29 | create_dateset_from_args, 30 | ) 31 | from .modeling import Splade 32 | from .trainer import SpladeTrainer 33 | from .utils import seed_everything 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def _setup_wandb(): 39 | if "WANDB_PROJECT" not in os.environ: 40 | os.environ["WANDB_PROJECT"] = "splade" 41 | 42 | 43 | def splade_model_factory(model_args: ModelArguments): 44 | if model_args.subword_pooling: 45 | logger.info(f"Use subword splade model: {model_args.subword_pooling}") 46 | model = SpladeSubword.from_pretrained( 47 | model_args, 48 | model_args.model_name_or_path, 49 | trust_remote_code=model_args.trust_remote_code, 50 | ) 51 | else: 52 | model = Splade.from_pretrained( 53 | model_args, 54 | model_args.model_name_or_path, 55 | trust_remote_code=model_args.trust_remote_code, 
56 | ) 57 | return model 58 | 59 | 60 | def main(): 61 | _setup_wandb() 62 | 63 | parser = HfArgumentParser( 64 | (ModelArguments, DataArguments, SpladeTrainingArguments, RunArguments) # type: ignore 65 | ) 66 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 67 | model_args, data_args, training_args, run_args = parser.parse_json_file( 68 | json_file=os.path.abspath(sys.argv[1]) 69 | ) 70 | elif len(sys.argv) == 2 and ( 71 | sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml") 72 | ): 73 | model_args, data_args, training_args, run_args = parser.parse_yaml_file( 74 | yaml_file=os.path.abspath(sys.argv[1]) 75 | ) 76 | else: 77 | model_args, data_args, training_args, run_args = ( 78 | parser.parse_args_into_dataclasses() 79 | ) 80 | 81 | seed_everything(training_args.seed) 82 | 83 | training_args.remove_unused_columns = False # override 84 | 85 | if ( 86 | os.path.exists(training_args.output_dir) 87 | and os.listdir(training_args.output_dir) 88 | and training_args.do_train 89 | and not training_args.overwrite_output_dir 90 | ): 91 | raise ValueError( 92 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 93 | ) 94 | 95 | # Setup logging 96 | logging.basicConfig( 97 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 98 | datefmt="%m/%d/%Y %H:%M:%S", 99 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 100 | ) 101 | 102 | logger.warning( 103 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 104 | training_args.local_rank, 105 | training_args.device, 106 | training_args.n_gpu, 107 | bool(training_args.local_rank != -1), 108 | training_args.fp16, 109 | ) 110 | 111 | # override 112 | if model_args.subword_pooling and not data_args.create_subword_indices: 113 | logging.info("[override] Set create_subword_indices to True") 114 | data_args.create_subword_indices = True 115 | if ( 116 | training_args.noise_tokens is not None 117 | and data_args.noise_tokens_for_subword is None 118 | ): 119 | logger.info("[override] Set noise_tokens_for_subwords") 120 | data_args.noise_tokens_for_subword = training_args.noise_tokens 121 | logger.info("Training/evaluation parameters %s", training_args) 122 | logger.info("Model parameters %s", model_args) 123 | logger.info("Data parameters %s", data_args) 124 | 125 | set_seed(training_args.seed) 126 | 127 | tokenizer = AutoTokenizer.from_pretrained( 128 | ( 129 | model_args.tokenizer_name 130 | if model_args.tokenizer_name 131 | else model_args.model_name_or_path 132 | ), 133 | use_fast=False, 134 | ) 135 | 136 | model = splade_model_factory(model_args) 137 | 138 | train_dataset = create_dateset_from_args(data_args, tokenizer) 139 | trainer = SpladeTrainer( 140 | args=training_args, 141 | model=model, 142 | train_dataset=train_dataset, 143 | data_collator=GroupCollator(tokenizer), 144 | tokenizer=tokenizer, 145 | ) 146 | 147 | Path(training_args.output_dir).mkdir(parents=True, exist_ok=True) 148 | last_checkpoint = None 149 | if training_args.resume_from_checkpoint or os.environ.get( 150 | "RESUME_FROM_CHECKPOINT", False 151 | ): 152 | training_args.resume_from_checkpoint = True 153 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 154 | logger.info("[RESUME] last_checkpoint: %s", last_checkpoint) 155 | 156 | trainer.train(resume_from_checkpoint=last_checkpoint) 157 | trainer.save_model() 158 | 159 | if run_args.remove_checkpoints: 160 | logger.info("Remove checkpoints") 161 | # 
remove checkpoints 162 | for dir in Path(training_args.output_dir).glob("checkpoint-*"): 163 | if dir.is_dir(): 164 | shutil.rmtree(dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /yast/arguments.py: -------------------------------------------------------------------------------- 1 | """ 2 | code base from 3 | - https://github.com/FlagOpen/FlagEmbedding/ (MIT license) 4 | """ 5 | 6 | from dataclasses import dataclass, field 7 | from typing import Any, Literal, Optional 8 | 9 | from transformers import TrainingArguments 10 | 11 | 12 | @dataclass 13 | class ModelArguments: 14 | """ 15 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 16 | """ 17 | 18 | model_name_or_path: str = field( 19 | default=None, # type: ignore 20 | metadata={ 21 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 22 | }, 23 | ) 24 | config_name: Optional[str] = field( 25 | default=None, 26 | metadata={ 27 | "help": "Pretrained config name or path if not the same as model_name" 28 | }, 29 | ) 30 | tokenizer_name: Optional[str] = field( 31 | default=None, 32 | metadata={ 33 | "help": "Pretrained tokenizer name or path if not the same as model_name" 34 | }, 35 | ) 36 | subword_pooling: Optional[str] = field( 37 | default=None, 38 | metadata={"help": "Pooling type for subword, max or mean(default: None)"}, 39 | ) 40 | trust_remote_code: Optional[bool] = field( 41 | default=False, 42 | metadata={"help": "Trust remote code for model loading"}, 43 | ) 44 | 45 | 46 | @dataclass 47 | class SpladeTrainingArguments(TrainingArguments): 48 | sparsity_weight_doc: float = field( 49 | default=5e-3, 50 | metadata={ 51 | "help": "Regularization coefficient for document representations. " 52 | "Controls the sparsity and FLOPS of document embeddings. " 53 | "Higher values lead to sparser representations." 54 | }, 55 | ) 56 | sparsity_weight_query: float = field( 57 | default=5e-3, 58 | metadata={ 59 | "help": "Regularization coefficient for query representations. " 60 | "Controls the sparsity and FLOPS of query embeddings. " 61 | "Higher values lead to sparser representations." 62 | }, 63 | ) 64 | 65 | sparsity_warmup_steps_query: float = field( 66 | default=0.1, 67 | metadata={ 68 | "help": "Query lambda warmup steps. If 0 < value < 1, treated as ratio of total steps. Otherwise, absolute step count." 69 | }, 70 | ) 71 | sparsity_warmup_steps_doc: float = field( 72 | default=0.1, 73 | metadata={ 74 | "help": "Document lambda warmup steps. If 0 < value < 1, treated as ratio of total steps. Otherwise, absolute step count." 
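            # Editorial note: worked example of the ratio-vs-absolute rule above,
            # assuming 20,000 total training steps: 0.1 -> 2,000 warmup steps,
            # while 500 is used as-is (500 steps); see
            # SpladeTrainer._calculate_warmup_steps in yast/trainer.py.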
75 | }, 76 | ) 77 | regularizer_query: Literal[ 78 | "mean_squared", 79 | "flops", 80 | "L1", 81 | "L2", 82 | "flops_l1_weighted", 83 | "dynamic_sparsity", 84 | "magnitude_threshold", 85 | "entropy_balanced", 86 | "dynamic_sparsity", 87 | "grouped_magnitude", 88 | "topk_entropy", 89 | "adaptive_threshold", 90 | ] = field( 91 | default="L1", 92 | metadata={}, 93 | ) 94 | regularizer_doc: Literal[ 95 | "mean_squared", 96 | "flops", 97 | "L1", 98 | "L2", 99 | "flops_l1_weighted", 100 | "dynamic_sparsity", 101 | "magnitude_threshold", 102 | "entropy_balanced", 103 | "dynamic_sparsity", 104 | "grouped_magnitude", 105 | "topk_entropy", 106 | "adaptive_threshold", 107 | ] = field( 108 | default="flops", 109 | metadata={}, 110 | ) 111 | 112 | training_losses: Any = field( 113 | default="cross_entropy", 114 | metadata={ 115 | "help": """Specify single or multiple training losses with optional weights and parameters. 116 | 117 | Supported formats: 118 | 1. Single loss (string): 119 | "cross_entropy" 120 | 121 | 2. Multiple losses with weights and parameters (dict): 122 | { 123 | "cross_entropy": { 124 | "weight": 1.0, 125 | "loss_kwargs": {} 126 | }, 127 | "mse": { 128 | "weight": 1.0, 129 | "loss_kwargs": {} 130 | } 131 | } 132 | 133 | Available loss functions: [cross_entropy, mse, contrastive, ...] 134 | 135 | Example configurations: 136 | - "cross_entropy" # Single loss 137 | - {"cross_entropy": {"weight": 1.0}, "mse": {"weight": 0.5}} # Multiple losses with weights 138 | - {"cross_entropy": {"weight": 1.0, "loss_kwargs": {"reduction": "mean"}}} # With parameters 139 | """ 140 | }, 141 | ) 142 | noise_tokens: None | str | list[str] = field( 143 | default=None, 144 | metadata={"help": "Noise tokens for training"}, 145 | ) 146 | noise_tokens_weight: float = field( 147 | default=1.0, 148 | metadata={"help": "Noise tokens loss weight"}, 149 | ) 150 | use_subword: bool = field( 151 | default=False, 152 | metadata={"help": "Use subword for splade training"}, 153 | ) 154 | 155 | 156 | @dataclass 157 | class DataArguments: 158 | train_data: Any = field( 159 | default=None, metadata={"help": "Path or hf dataset to corpus"} 160 | ) # type: ignore 161 | train_group_size: int = field(default=8) 162 | train_max_positive_size: int = field(default=1) 163 | max_length: int = field( 164 | default=512, 165 | metadata={ 166 | "help": "The maximum total input document length after tokenization for input text. " 167 | }, 168 | ) 169 | max_query_length: int = field( 170 | default=64, 171 | metadata={ 172 | "help": "The maximum total input query length after tokenization for input text." 
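            # Editorial note: this is a token-length limit applied when the query is
            # tokenized, whereas the dataset_options keys query_max_len / doc_max_len
            # used by the custom datasets truncate raw text by characters before
            # tokenization (see get_query_text / get_collection_text).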
173 | }, 174 | ) 175 | dataset_options: dict = field( 176 | default_factory=dict, metadata={"help": "Additional options for the dataset"} 177 | ) 178 | create_subword_indices: bool = field( 179 | default=False, 180 | metadata={"help": "Create subword indices for splade model"}, 181 | ) 182 | noise_tokens_for_subword: None | str | list[str] = field( 183 | default=None, 184 | metadata={"help": "Noise tokens for subwords"}, 185 | ) 186 | 187 | def __post_init__(self): 188 | # validation 189 | pass 190 | 191 | 192 | @dataclass 193 | class RunArguments: 194 | remove_checkpoints: bool = field( 195 | default=False, metadata={"help": "Remove checkpoints after training"} 196 | ) 197 | -------------------------------------------------------------------------------- /yast/modeling/splade_subword.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from transformers import PreTrainedModel 5 | 6 | from ..arguments import ModelArguments 7 | from .splade import Splade 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class SpladeSubword(Splade): 13 | POOLING_TYPES = ["max", "mean"] # Available pooling types 14 | SUBWORD_MASK_ID = -100 # ID to ignore subword positions 15 | 16 | def __init__(self, hf_model: PreTrainedModel, model_args: ModelArguments): 17 | super().__init__(hf_model, model_args) 18 | subword_pooling = model_args.subword_pooling 19 | if subword_pooling is not None and subword_pooling not in self.POOLING_TYPES: 20 | raise ValueError( 21 | f"Invalid pooling type: {subword_pooling}. Please choose from {self.POOLING_TYPES}" 22 | ) 23 | self.pooling_type = subword_pooling 24 | 25 | def _aggregate_subwords( 26 | self, 27 | logits: torch.Tensor, 28 | input_ids: torch.Tensor, 29 | subword_indices: torch.Tensor, 30 | ) -> torch.Tensor: 31 | """ 32 | Optimized aggregation of subword logits using PyTorch's scatter operations. 33 | 34 | Args: 35 | logits (torch.Tensor): [batch_size, vocab_size] logits tensor after splade_max. 36 | input_ids (torch.Tensor): [batch_size, sequence_length] input token IDs. 37 | subword_indices (torch.Tensor): [batch_size, sequence_length] subword group indices. 38 | 39 | Returns: 40 | torch.Tensor: Aggregated logits tensor. 
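        Example (editorial sketch with hypothetical values):
            With pooling_type="max", input_ids [[5, 6, 7]] and subword_indices
            [[0, 0, -100]], tokens 5 and 6 form one subword group, so both of
            their vocabulary slots receive the group's maximum activation, while
            token 7 (masked with SUBWORD_MASK_ID) keeps its original logit.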
41 | """ 42 | batch_size, vocab_size = logits.size() 43 | device = logits.device 44 | dtype = logits.dtype # Ensure consistent dtype 45 | 46 | # Create mask to identify valid subword positions 47 | mask = subword_indices != self.SUBWORD_MASK_ID # [B, S] 48 | 49 | # Get indices of valid positions 50 | valid_positions = mask.nonzero( 51 | as_tuple=False 52 | ) # [N, 2] each row is (batch_idx, seq_idx) 53 | if valid_positions.numel() == 0: 54 | return logits # No valid subwords to aggregate 55 | 56 | batch_indices = valid_positions[:, 0] # [N] 57 | seq_indices = valid_positions[:, 1] # [N] 58 | 59 | # Extract corresponding group indices and token IDs 60 | group_indices = subword_indices[batch_indices, seq_indices] # [N] 61 | token_ids = input_ids[batch_indices, seq_indices] # [N] 62 | 63 | # Get logits corresponding to selected token IDs 64 | selected_logits = logits[batch_indices, token_ids] # [N] 65 | 66 | # Calculate unique group identifiers by offsetting group IDs for each batch 67 | max_group_tensor = group_indices.max() 68 | max_group = ( 69 | max_group_tensor.item() 70 | if torch.is_tensor(max_group_tensor) 71 | else max_group_tensor 72 | ) 73 | if max_group < 0: 74 | max_group = 0 # Handle case where all group_indices are -100, though unlikely due to mask 75 | 76 | # Add batch offset to make group identifiers unique 77 | group_offset = batch_indices * (max_group + 1) # [N] 78 | unique_group_ids = group_offset + group_indices # [N] 79 | 80 | # Calculate number of unique groups 81 | num_unique_groups = batch_size * (max_group + 1) # Removed .item() 82 | 83 | if self.pooling_type == "max": 84 | # Initialize with -inf and set the same dtype as logits 85 | pooled_values = torch.full( 86 | (num_unique_groups,), # type: ignore 87 | -float("inf"), 88 | device=device, 89 | dtype=dtype, # type: ignore 90 | ) 91 | # Calculate maximum value for each group using scatter_reduce 92 | pooled_values = pooled_values.scatter_reduce( 93 | dim=0, 94 | index=unique_group_ids, 95 | src=selected_logits, 96 | reduce="amax", 97 | include_self=True, 98 | ) 99 | elif self.pooling_type == "mean": 100 | # Initialize sum and count tensors with the same dtype as logits 101 | sum_pooled = torch.zeros(num_unique_groups, device=device, dtype=dtype) # type: ignore 102 | count_pooled = torch.zeros(num_unique_groups, device=device, dtype=dtype) # type: ignore 103 | # Calculate sum for each group 104 | sum_pooled = sum_pooled.scatter_add(0, unique_group_ids, selected_logits) 105 | # Calculate count for each group 106 | count_pooled = count_pooled.scatter_add( 107 | 0, unique_group_ids, torch.ones_like(selected_logits) 108 | ) 109 | # Calculate average (prevent division by zero) 110 | pooled_values = sum_pooled / torch.clamp(count_pooled, min=1) 111 | 112 | # Map pooled values to each token 113 | pooled_values_per_token = pooled_values[unique_group_ids] # [N] # type: ignore 114 | 115 | # Calculate global token IDs to handle multiple tokens in batch 116 | global_token_ids = batch_indices * vocab_size + token_ids # [N] 117 | 118 | # Initialize tensor to hold maximum pooled values for each token (initialize with -inf) 119 | final_pooled = torch.full( 120 | (batch_size * vocab_size,), -float("inf"), device=device, dtype=dtype 121 | ) 122 | 123 | if self.pooling_type == "max": 124 | # For MaxPooling, use maximum values 125 | final_pooled = final_pooled.scatter_reduce( 126 | dim=0, 127 | index=global_token_ids, 128 | src=pooled_values_per_token, 129 | reduce="amax", 130 | include_self=True, 131 | ) 132 | final_pooled = 
final_pooled.view(batch_size, vocab_size) 133 | # Take maximum between original logits and pooled values 134 | new_logits = torch.maximum(logits, final_pooled) 135 | else: # mean 136 | # For MeanPooling, use average values directly 137 | temp_sum = torch.zeros_like(final_pooled) 138 | temp_count = torch.zeros_like(final_pooled) 139 | 140 | # Aggregate sum of values and count 141 | temp_sum.scatter_add_(0, global_token_ids, pooled_values_per_token) 142 | temp_count.scatter_add_( 143 | 0, global_token_ids, torch.ones_like(pooled_values_per_token) 144 | ) 145 | 146 | # Calculate final average 147 | final_pooled = (temp_sum / temp_count.clamp(min=1.0)).view( 148 | batch_size, vocab_size 149 | ) 150 | 151 | # Update only positions with subwords using average values 152 | subword_mask = (temp_count > 0).view(batch_size, vocab_size) 153 | new_logits = torch.where(subword_mask, final_pooled, logits) 154 | 155 | return new_logits 156 | 157 | def forward(self, batch_inputs: dict, batch_size: int): 158 | """ 159 | Forward pass of the model. 160 | 161 | Args: 162 | batch_inputs: Input batch containing 'input_ids', 'attention_mask', and 'subword_indices'. 163 | batch_size: Size of the batch. 164 | 165 | Returns: 166 | Tuple of query and document representations. 167 | """ 168 | subword_indices = batch_inputs.pop("subword_indices") 169 | input_ids = batch_inputs["input_ids"] 170 | 171 | output = self.hf_model(**batch_inputs, return_dict=True).logits 172 | attention_mask = batch_inputs["attention_mask"] 173 | 174 | logits = self.splade_max(output, attention_mask) 175 | 176 | if self.pooling_type is not None: 177 | logits = self._aggregate_subwords(logits, input_ids, subword_indices) 178 | 179 | # Get query and document representations 180 | query, docs = self._logit_to_query_docs(logits, batch_size) 181 | return query, docs 182 | -------------------------------------------------------------------------------- /yast/custom_dataset/japanese_splade_hn_v1.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | from datasets import concatenate_datasets, load_dataset 5 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 6 | 7 | from ..arguments import DataArguments 8 | from ..data import DatasetForSpladeTraining 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | HADR_NEGATIVE_SCORE_DS = "hotchpotch/japanese-splade-v1-hard-negatives" 13 | DS_SPIT = "train" 14 | # QUERY_DS = "mmarco-dataset" 15 | # COLLECTION_DS = "mmarco-collection" 16 | 17 | NEG_SCORE_TH = 0.3 18 | POS_SCORE_TH = 0.7 19 | NEG_FILTER_COUNT = 7 20 | NEG_POS_SCORE_TH = 0.95 21 | 22 | TOP_100_SAMPLING_COUNT = 4 # top100 から hard negative としてサンプリングする数 23 | 24 | 25 | def _map_score_with_hard_positives( 26 | example, 27 | neg_score_th: float, 28 | pos_score_th: float, 29 | neg_pos_score_th: float, 30 | hard_positives: bool, 31 | target_model_name: str = "japanese-splade-base-v1-mmarco-only", 32 | ): 33 | neg_score_top100 = example[ 34 | f"score.bge-reranker-v2-m3.neg_ids.{target_model_name}.top100" 35 | ] 36 | neg_score_other100 = example[ 37 | f"score.bge-reranker-v2-m3.neg_ids.{target_model_name}.other100" 38 | ] 39 | pos_score = example["score.bge-reranker-v2-m3.pos_ids"] 40 | neg_score_top100_filtered_index = [ 41 | i for i, score in enumerate(neg_score_top100) if score < neg_score_th 42 | ] 43 | neg_score_other100_filtered_index = [ 44 | i for i, score in enumerate(neg_score_other100) if score < neg_score_th 45 | ] 46 | pos_score_filtered_index = [ 
47 |         i for i, score in enumerate(pos_score) if score > pos_score_th
48 |     ]
49 | 
50 |     # hard positives are taken from neg.other100 first
51 |     hard_positives_ids = example[f"neg_ids.{target_model_name}.other100"]
52 |     hard_positives_scores = neg_score_other100
53 |     hard_positives_score_filtered_index = [
54 |         i for i, score in enumerate(hard_positives_scores) if score > neg_pos_score_th
55 |     ]
56 | 
57 |     # top100 entries are too weak to serve as hard_positives, so they are not used
58 |     # if len(hard_positives_score_filtered_index) == 0:
59 |     #     # if neg.other100 has no score that qualifies as a hard positive, take one from top100
60 |     #     hard_positives_ids = example[
61 |     #         f"neg_ids.{target_model_name}.top100"
62 |     #     ]
63 |     #     hard_positives_scores = neg_score_top100
64 |     #     hard_positives_score_filtered_index = [
65 |     #         i
66 |     #         for i, score in enumerate(hard_positives_scores)
67 |     #         if score > neg_pos_score_th
68 |     #     ]
69 | 
70 |     data = {
71 |         **example,
72 |         "neg.score.top100": [
73 |             neg_score_top100[i] for i in neg_score_top100_filtered_index
74 |         ],
75 |         "neg.top100": [
76 |             example[f"neg_ids.{target_model_name}.top100"][i]
77 |             for i in neg_score_top100_filtered_index
78 |         ],
79 |         "neg.score.other100": [
80 |             neg_score_other100[i] for i in neg_score_other100_filtered_index
81 |         ],
82 |         "neg.other100": [
83 |             example[f"neg_ids.{target_model_name}.other100"][i]
84 |             for i in neg_score_other100_filtered_index
85 |         ],
86 |         "pos.score": [pos_score[i] for i in pos_score_filtered_index],
87 |         "pos": [example["pos_ids"][i] for i in pos_score_filtered_index],
88 |     }
89 |     if hard_positives and len(hard_positives_score_filtered_index) > 0:
90 |         # when the hard_positives flag is set
91 |         # and a score qualifies as a hard positive, replace pos.score / pos with those negatives
92 |         data["pos.score"] = [
93 |             hard_positives_scores[i] for i in hard_positives_score_filtered_index
94 |         ]
95 |         data["pos"] = [
96 |             hard_positives_ids[i] for i in hard_positives_score_filtered_index
97 |         ]
98 |     elif len(pos_score_filtered_index) == 0:
99 |         # take the maximum of neg_score_top100 and its index
100 |         max_score = max(neg_score_top100)
101 |         max_score_index = neg_score_top100.index(max_score)
102 |         if max_score >= neg_pos_score_th:
103 |             # if no pos passes the threshold but a neg has a sufficiently high score, use it as pos
104 |             data["pos"] = [
105 |                 example[f"neg_ids.{target_model_name}.top100"][max_score_index]
106 |             ]
107 |             data["pos.score"] = [max_score]
108 |         elif len(hard_positives_score_filtered_index) > 0:
109 |             # if neg_score_top100 has no qualifying score either,
110 |             # add one pos picked at random from hard_positives_score_filtered_index
111 |             hard_positive_index = random.choice(hard_positives_score_filtered_index)
112 |             max_score = hard_positives_scores[hard_positive_index]
113 |             data["pos.score"] = [max_score]
114 |             data["pos"] = [hard_positives_ids[hard_positive_index]]
115 |     return data
116 | 
117 | 
118 | def _filter_score(example, net_filter_count: int):
119 |     # keep rows with at least N negatives and at least one positive
120 |     return (
121 |         len(example["neg.other100"] + example["neg.top100"]) >= net_filter_count
122 |         and len(example["pos"]) >= 1
123 |     )
124 | 
125 | 
126 | class JapaneseSpladeHardNegativesV1(DatasetForSpladeTraining):
127 |     def __init__(
128 |         self,
129 |         args: DataArguments,
130 |         tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
131 |     ):
132 |         self.query_max_len = args.dataset_options.get("query_max_len", 256)
133 |         self.doc_max_len = args.dataset_options.get("doc_max_len", 1024)
134 | 
135 |         dataset_options = args.dataset_options
136 |         self.binarize_label: bool = dataset_options.get("binarize_label", False)
137 |         self.hard_positives: bool = 
dataset_options.get("hard_positives", False) 138 | self.target_model_name: str = dataset_options.get( 139 | "target_model_name", "japanese-splade-base-v1-mmarco-only" 140 | ) 141 | self.query_column_name: str = dataset_options.get("query_column_name", "anc") 142 | self.doc_column_name: str = dataset_options.get("doc_column_name", "text") 143 | dataset_name = dataset_options.get("dataset_name", "mmarco") 144 | logger.info(f"Initializing {dataset_name} hard_negative dataset") 145 | logger.info(f"binarize_label: {self.binarize_label}") 146 | logger.info(f"hard_positives: {self.hard_positives}") 147 | logger.info(f"target_model_name: {self.target_model_name}") 148 | logger.info(f"query_column_name: {self.query_column_name}") 149 | logger.info(f"doc_column_name: {self.doc_column_name}") 150 | 151 | query_ds_name = f"{dataset_name}-dataset" 152 | collection_ds_name = f"{dataset_name}-collection" 153 | 154 | neg_score_th = dataset_options.get("neg_score_th", NEG_SCORE_TH) 155 | pos_score_th = dataset_options.get("pos_score_th", POS_SCORE_TH) 156 | neg_pos_thcore_th = dataset_options.get("neg_pos_thcore_th", NEG_POS_SCORE_TH) 157 | net_filter_count = dataset_options.get("net_filter_count", NEG_FILTER_COUNT) 158 | 159 | self.top_100_sampling_count = dataset_options.get( 160 | "top_100_sampling_count", TOP_100_SAMPLING_COUNT 161 | ) 162 | 163 | ds = load_dataset(HADR_NEGATIVE_SCORE_DS, query_ds_name, split=DS_SPIT) 164 | self.collection_ds = load_dataset( 165 | HADR_NEGATIVE_SCORE_DS, collection_ds_name, split=DS_SPIT 166 | ) 167 | ds = ds.map( 168 | _map_score_with_hard_positives, 169 | num_proc=11, # type: ignore 170 | fn_kwargs={ 171 | "neg_score_th": neg_score_th, 172 | "pos_score_th": pos_score_th, 173 | "neg_pos_score_th": neg_pos_thcore_th, 174 | "hard_positives": self.hard_positives, 175 | "target_model_name": self.target_model_name, 176 | }, 177 | ) # type: ignore 178 | ds = ds.filter( 179 | _filter_score, num_proc=11, fn_kwargs={"net_filter_count": net_filter_count} 180 | ) # type: ignore 181 | logger.info(f"Filtered dataset size: {len(ds)}") 182 | 183 | aug_factor = dataset_options.get("aug_factor", 1.0) 184 | n = int(dataset_options.get("n", 0)) 185 | if aug_factor != 1.0: 186 | n = int(len(ds) * (aug_factor)) 187 | logging.info( 188 | f"Augmenting dataset with factor aug_factor={aug_factor}, n={n}" 189 | ) 190 | if n > len(ds): 191 | logger.info(f"Expanding dataset from {len(ds)} to {n}") 192 | ds_expand = [] 193 | c = n // len(ds) 194 | r = n % len(ds) 195 | for _ in range(c): 196 | ds_expand.append(ds.shuffle(seed=42)) 197 | ds_expand.append(ds.shuffle(seed=42).select(range(r))) # type: ignore 198 | ds = concatenate_datasets(ds_expand) 199 | assert len(ds) == n 200 | elif n > 0: 201 | logger.info(f"Shuffling and selecting first {n} samples from dataset") 202 | ds = ds.shuffle(seed=42).select(range(n)) # type: ignore 203 | 204 | super().__init__(args, tokenizer, ds) # type: ignore 205 | 206 | def get_collection_text(self, doc_id: int) -> str: 207 | text = self.collection_ds[doc_id][self.doc_column_name] # type: ignore 208 | return text # type: ignore 209 | 210 | def __getitem__(self, item) -> list[dict]: 211 | group_size = self.args.train_group_size 212 | query = self.dataset[item][self.query_column_name] 213 | 214 | pos_ids = self.dataset[item]["pos"] 215 | pos_ids_score = self.dataset[item]["pos.score"] 216 | 217 | neg_ids_top100 = self.dataset[item]["neg.top100"] 218 | neg_ids_score_top100 = self.dataset[item]["neg.score.top100"] 219 | 220 | # N をneg_ids_top100からrandom sampling 221 
| top_100_count = self.top_100_sampling_count 222 | if len(neg_ids_top100) < top_100_count: 223 | top_100_count = len(neg_ids_top100) 224 | neg_ids_top100_sampled = random.sample(neg_ids_top100, top_100_count) 225 | neg_ids_score_top100_sampled = [ 226 | neg_ids_score_top100[neg_ids_top100.index(id_)] 227 | for id_ in neg_ids_top100_sampled 228 | ] 229 | 230 | other_100_count = group_size - top_100_count - 1 231 | neg_ids_other100 = self.dataset[item]["neg.other100"] 232 | neg_ids_score_other100 = self.dataset[item]["neg.score.other100"] 233 | if len(neg_ids_other100) < other_100_count: 234 | other_100_count = len(neg_ids_other100) 235 | neg_ids_other100_sampled = random.sample(neg_ids_other100, other_100_count) 236 | neg_ids_score_other100_sampled = [ 237 | neg_ids_score_other100[neg_ids_other100.index(id_)] 238 | for id_ in neg_ids_other100_sampled 239 | ] 240 | 241 | neg_ids = neg_ids_top100_sampled + neg_ids_other100_sampled 242 | neg_ids_score = neg_ids_score_top100_sampled + neg_ids_score_other100_sampled 243 | 244 | pos_texts = [self.get_collection_text(pos_id) for pos_id in pos_ids] 245 | neg_texts = [self.get_collection_text(neg_id) for neg_id in neg_ids] 246 | 247 | if self.binarize_label: 248 | pos_ids_score = [1.0] * len(pos_ids_score) 249 | neg_ids_score = [0.0] * len(neg_ids_score) 250 | 251 | return self.create_batch_inputs( 252 | query, 253 | pos_texts, 254 | neg_texts, 255 | pos_ids_score, 256 | neg_ids_score, 257 | ) 258 | -------------------------------------------------------------------------------- /yast/custom_dataset/hpprc_emb_scores.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import cache 3 | from typing import cast 4 | 5 | import numpy as np 6 | from datasets import Dataset, concatenate_datasets, load_dataset 7 | from huggingface_hub import HfApi 8 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 9 | 10 | from ..arguments import DataArguments 11 | from ..data import DatasetForSpladeTraining 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | HPPRC_EMB_DS = "hpprc/emb" 16 | SCORE_DS = "hotchpotch/hpprc_emb-scores" 17 | 18 | 19 | @cache 20 | def get_dataset_subsets(dataset_name): 21 | logger.info(f"Fetching dataset subsets for dataset: {dataset_name}") 22 | api = HfApi() 23 | dataset_info = api.dataset_info(dataset_name) 24 | 25 | if card_data := dataset_info.card_data: 26 | if "configs" in card_data: 27 | subsets = [config["config_name"] for config in card_data["configs"]] 28 | logger.info(f"Found configs: {subsets}") 29 | return subsets 30 | elif "dataset_info" in card_data: 31 | subsets = [info["config_name"] for info in card_data["dataset_info"]] 32 | logger.info(f"Found dataset_info configs: {subsets}") 33 | return subsets 34 | else: 35 | logger.warning(f"No subsets found for dataset: {dataset_name}") 36 | return [] 37 | 38 | 39 | def get_datasets(target_name: str): 40 | logger.info(f"Retrieving datasets for target name: {target_name}") 41 | subsets = get_dataset_subsets(SCORE_DS) 42 | target_subsets = [subset for subset in subsets if subset.startswith(target_name)] # type: ignore 43 | if not target_subsets: 44 | logger.error(f"Subset not found: {target_name}") 45 | raise ValueError(f"Subset not found: {target_name}") 46 | target_subset = target_subsets[0] 47 | target_base_name, revision = target_subset.rsplit("-dataset__", 1) 48 | logger.info( 49 | f"Loading score dataset: {SCORE_DS}, subset: {target_subset}, revision: {revision}" 50 | ) 51 | score_ds = 
load_dataset( 52 | SCORE_DS, 53 | name=target_subset, 54 | split="train", 55 | ) 56 | if target_name.startswith("quiz-") or target_name.startswith("mkqa"): 57 | collection_name = "qa-collection" 58 | else: 59 | collection_name = f"{target_base_name}-collection" 60 | logger.info( 61 | f"Loading embedding dataset: {HPPRC_EMB_DS}, collection: {collection_name}, revision: {revision}" 62 | ) 63 | hpprc_emb_ds = load_dataset( 64 | HPPRC_EMB_DS, 65 | name=collection_name, 66 | split="train", 67 | revision=revision, 68 | ) 69 | return hpprc_emb_ds, score_ds 70 | 71 | 72 | TARGET_POS_ID_KEYS = [ 73 | "pos_ids", 74 | "pos_ids.original", 75 | "pos_ids.me5-large", 76 | "pos_ids.bm25", 77 | ] 78 | TARGET_NEG_ID_KEYS = [ 79 | "neg_ids", 80 | "neg_ids.original", 81 | "neg_ids.me5-large", 82 | "neg_ids.bm25", 83 | ] 84 | 85 | 86 | def map_data( 87 | example, 88 | target_score_keys: list[str] = ["ruri-reranker-large", "bge-reranker-v2-m3"], 89 | pos_id_score_threshold: float = 0.7, 90 | neg_id_score_threshold: float = 0.3, 91 | ): 92 | target_id_keys = TARGET_POS_ID_KEYS + TARGET_NEG_ID_KEYS 93 | # 対象の target_score の平均値を取得 94 | target_score_dict = {} 95 | for id_key in target_id_keys: 96 | scores_key_values = [] 97 | for score_key in target_score_keys: 98 | full_score_key = f"score.{score_key}.{id_key}" 99 | scores = example.get(full_score_key, []) 100 | if len(scores) > 0: 101 | scores_key_values.append(np.array(scores)) 102 | if len(scores_key_values) > 0: 103 | if len(scores_key_values) != len(target_score_keys): 104 | logger.error( 105 | f"len(scores_key_values) != len(target_score_keys): {len(scores_key_values)} != {len(target_score_keys)}" 106 | ) 107 | raise ValueError( 108 | f"len(scores_key_values) != len(target_score_keys): {len(scores_key_values)} != {len(target_score_keys)}" 109 | ) 110 | mean_scores = np.array(scores_key_values).T.mean(axis=1) 111 | target_score_dict[id_key] = mean_scores.tolist() 112 | 113 | filtered_target_ids_dict = {} 114 | # 閾値でフィルタリング 115 | for id_key in target_score_dict.keys(): 116 | target_score_ids = example[id_key] 117 | target_scores = target_score_dict[id_key] 118 | if "pos_ids" in id_key: 119 | filtered_target_scores_indexes = [ 120 | i 121 | for i, score in enumerate(target_scores) 122 | if score >= pos_id_score_threshold 123 | ] 124 | else: 125 | filtered_target_scores_indexes = [ 126 | i 127 | for i, score in enumerate(target_scores) 128 | if score <= neg_id_score_threshold 129 | ] 130 | filtered_target_ids = [ 131 | target_score_ids[i] for i in filtered_target_scores_indexes 132 | ] 133 | filtered_target_ids_dict[id_key] = filtered_target_ids 134 | target_score_dict[id_key] = [ 135 | target_scores[i] for i in filtered_target_scores_indexes 136 | ] 137 | result_pos_ids = [] 138 | result_pos_ids_score = [] 139 | result_neg_ids = [] 140 | result_neg_ids_score = [] 141 | for id_key in target_score_dict.keys(): 142 | # pos_ids, neg_ids ともに重複IDがあるので、その場合は追加しない 143 | if "pos_ids" in id_key and id_key not in result_pos_ids: 144 | result_pos_ids += filtered_target_ids_dict[id_key] 145 | result_pos_ids_score += target_score_dict[id_key] 146 | elif "neg_ids" in id_key and id_key not in result_neg_ids and id_key: 147 | result_neg_ids += filtered_target_ids_dict[id_key] 148 | result_neg_ids_score += target_score_dict[id_key] 149 | # ログメッセージに置き換え 150 | logger.debug(f"pos_ids: {result_pos_ids}") 151 | logger.debug(f"neg_ids: {result_neg_ids}") 152 | logger.debug(f"result_pos_ids_score: {result_pos_ids_score}") 153 | logger.debug(f"result_neg_ids_score: 
{result_neg_ids_score}") 154 | return { 155 | "anc": example["anc"], 156 | "pos_ids": result_pos_ids, 157 | "pos_ids.score": result_pos_ids_score, 158 | "neg_ids": result_neg_ids, 159 | "neg_ids.score": result_neg_ids_score, 160 | } 161 | 162 | 163 | def filter_data(example): 164 | neg_ids = example["neg_ids"] 165 | if len(neg_ids) < 8: 166 | return False 167 | pos_ids = example["pos_ids"] 168 | if len(pos_ids) == 0: 169 | return False 170 | return True 171 | 172 | 173 | class HpprcEmbScoresDataset(DatasetForSpladeTraining): 174 | def __init__( 175 | self, 176 | args: DataArguments, 177 | tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, 178 | seed: int = 42, 179 | ): 180 | logger.info("Initializing HpprcEmbScoresDataset") 181 | train_data = args.train_data 182 | # train data は list 183 | if not isinstance(train_data, list): 184 | logger.error("train_data must be a list") 185 | raise ValueError("train_data must be a list") 186 | dataset_options = args.dataset_options 187 | self.binarize_label: bool = dataset_options.get("binarize_label", False) 188 | all_ds = [] 189 | target_emb_ds = {} 190 | for target in train_data: 191 | if not isinstance(target, dict): 192 | logger.error("train_data must be a list of dictionaries") 193 | raise ValueError("train_data must be a list of dictionaries") 194 | subset = target["subset"] 195 | logger.info(f"Processing subset: {subset}") 196 | n = target.get("n", None) 197 | emb_ds, score_ds = get_datasets(subset) 198 | score_ds = cast(Dataset, score_ds) 199 | logger.info(f"Mapping data for subset: {subset}") 200 | score_ds = score_ds.map( 201 | map_data, 202 | num_proc=11, 203 | remove_columns=score_ds.column_names, 204 | fn_kwargs={ 205 | "target_score_keys": target.get( 206 | "target_score_keys", 207 | ["ruri-reranker-large", "bge-reranker-v2-m3"], 208 | ) 209 | }, 210 | ) # type: ignore 211 | target_emb_ds[subset] = emb_ds 212 | aug_factor = target.get("aug_factor", 1.0) 213 | if aug_factor != 1.0: 214 | if n is not None: 215 | logger.warning( 216 | f"aug_factor is ignored because n is specified, skipping aug_factor args for subset: {subset}" 217 | ) 218 | else: 219 | n = int(len(score_ds) * aug_factor) 220 | logger.info( 221 | f"Augmenting dataset: {subset} with aug_factor: {aug_factor}" 222 | ) 223 | if n is not None: 224 | if n > len(score_ds): 225 | logger.info( 226 | f"Expanding dataset: {subset} from {len(score_ds)} to {n}" 227 | ) 228 | score_ds_expand = [] 229 | c = n // len(score_ds) 230 | r = n % len(score_ds) 231 | for _ in range(c): 232 | score_ds_expand.append(score_ds.shuffle(seed=seed)) 233 | score_ds_expand.append(score_ds.shuffle(seed=seed).select(range(r))) 234 | score_ds = concatenate_datasets(score_ds_expand) 235 | assert len(score_ds) == n 236 | else: 237 | logger.info( 238 | f"Shuffling and selecting first {n} samples from dataset: {subset}" 239 | ) 240 | score_ds = score_ds.shuffle(seed=seed).select(range(n)) # type: ignore 241 | before_filter_len = len(score_ds) 242 | logger.info( 243 | f"Filtering dataset: {subset}, original size: {before_filter_len}" 244 | ) 245 | score_ds = score_ds.filter(filter_data, num_proc=11) 246 | after_filter_len = len(score_ds) 247 | logger.info( 248 | f"Filtered dataset size: {subset}, before: {before_filter_len}, after: {after_filter_len}, ratio: {after_filter_len / before_filter_len:.2f}" 249 | ) 250 | subsets_column = [subset] * len(score_ds) 251 | score_ds = score_ds.add_column("subset", subsets_column) # type: ignore 252 | all_ds.append(score_ds) 253 | logger.info(f"Loaded subset: {subset}, 
size: {len(score_ds)}") 254 | self.target_emb_ds = target_emb_ds 255 | ds = concatenate_datasets(all_ds) 256 | logger.info(f"Total concatenated dataset size: {len(ds)}") 257 | super().__init__(args, tokenizer, ds) 258 | 259 | def get_text_by_subset(self, subset: str, idx: int, max_len: int = 1024) -> str: 260 | emb_ds = self.target_emb_ds[subset] 261 | row = emb_ds[idx] 262 | text = row["text"] 263 | title = row.get("title", None) 264 | if title: 265 | text = title + " " + text 266 | return text[0:max_len] 267 | 268 | def __getitem__(self, item) -> list[dict]: 269 | subset = self.dataset[item]["subset"] 270 | query = self.dataset[item]["anc"] 271 | pos_ids = self.dataset[item]["pos_ids"] 272 | pos_ids_score = self.dataset[item]["pos_ids.score"] 273 | neg_ids = self.dataset[item]["neg_ids"] 274 | neg_ids_score = self.dataset[item]["neg_ids.score"] 275 | pos_texts = [self.get_text_by_subset(subset, pos_id) for pos_id in pos_ids] 276 | neg_texts = [self.get_text_by_subset(subset, neg_id) for neg_id in neg_ids] 277 | 278 | if self.binarize_label: 279 | pos_ids_score = [1.0] * len(pos_ids_score) 280 | neg_ids_score = [0.0] * len(neg_ids_score) 281 | 282 | return self.create_batch_inputs( 283 | query, 284 | pos_texts, 285 | neg_texts, 286 | pos_ids_score, 287 | neg_ids_score, 288 | ) 289 | -------------------------------------------------------------------------------- /yast/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from typing import Any, Union 5 | 6 | import torch 7 | from transformers import Trainer 8 | 9 | from .arguments import SpladeTrainingArguments 10 | from .log_metrics import LogMetrics 11 | from .losses import LossWithWeight, losses 12 | from .modeling import Splade 13 | from .regularizers import regularizers 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SpladeTrainer(Trainer): 19 | def __init__(self, args: SpladeTrainingArguments, **kwargs: Any): 20 | super().__init__(args=args, **kwargs) 21 | self.args: SpladeTrainingArguments = args 22 | self.batch_size: int = self.args.per_device_train_batch_size 23 | 24 | self.total_steps: int = int( 25 | self.get_train_dataloader().__len__() * args.num_train_epochs 26 | ) 27 | 28 | self.warmup_steps_doc: int = self._calculate_warmup_steps( 29 | args.sparsity_warmup_steps_doc 30 | ) 31 | self.warmup_steps_query: int = self._calculate_warmup_steps( 32 | args.sparsity_warmup_steps_query 33 | ) 34 | 35 | self._set_training_losses(args.training_losses) 36 | self._set_noise_token_ids(args.noise_tokens) 37 | 38 | self.regularizer_query_fn = regularizers[args.regularizer_query] 39 | self.regularizer_doc_fn = regularizers[args.regularizer_doc] 40 | 41 | self.log_metrics: LogMetrics = LogMetrics() 42 | 43 | def _set_noise_token_ids(self, noise_tokens: None | str | list[str]) -> None: 44 | if isinstance(noise_tokens, str): 45 | noise_tokens = re.split(r"\s+", noise_tokens) 46 | elif noise_tokens is None: 47 | noise_tokens = [] 48 | if len(noise_tokens) == 0: 49 | self.noise_token_ids = [] 50 | else: 51 | noise_tokens = list(set(noise_tokens)) 52 | tokenizer = self.tokenizer 53 | token_ids: list[int] = tokenizer.convert_tokens_to_ids(noise_tokens) # type: ignore 54 | if len(token_ids) != len(noise_tokens): 55 | missing_tokens = set(noise_tokens) - set( 56 | tokenizer.convert_ids_to_tokens(token_ids) # type: ignore 57 | ) 58 | raise ValueError( 59 | f"Token(s) {missing_tokens} are not in the tokenizer's vocabulary." 
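                # Editorial note: noise_tokens may be passed either as one
                # whitespace-separated string or as a list of tokens (see the
                # re.split call above); every token must already exist in the
                # tokenizer vocabulary, and the CLS/SEP/UNK ids are also appended
                # below.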
60 |                 )
61 |             logger.info(
62 |                 f"target noise tokens ({len(token_ids)}): {' '.join(noise_tokens)}"
63 |             )
64 |             self.noise_token_ids = token_ids
65 | 
66 |         if self.tokenizer:
67 |             # CLS, SEP, UNK tokens are always treated as noise tokens
68 |             if self.tokenizer.cls_token_id is not None:
69 |                 self.noise_token_ids.append(self.tokenizer.cls_token_id)
70 |             if self.tokenizer.sep_token_id is not None:
71 |                 self.noise_token_ids.append(self.tokenizer.sep_token_id)
72 |             if self.tokenizer.unk_token_id is not None:
73 |                 self.noise_token_ids.append(self.tokenizer.unk_token_id)
74 | 
75 |     def _set_training_losses(self, training_loss: Any) -> None:
76 |         """
77 |         Accepted formats for ``training_loss``:
78 |         1) str only
79 |             ex) 'cross_entropy'
80 |         2) dict mapping loss name to options
81 |             ex) {"cross_entropy": {"weight": 1.0, "loss_kwargs": {}},
82 |                  "mse": {"weight": 1.0, "loss_kwargs": {}}}
83 |         """
84 | 
85 |         training_losses: dict[str, LossWithWeight] = {}
86 |         if isinstance(training_loss, str):
87 |             if loss_klass := losses.get(training_loss):
88 |                 loss_fn = loss_klass()
89 |                 loss_with_args: LossWithWeight = {
90 |                     "loss_fn": loss_fn,
91 |                     "weight": 1.0,
92 |                 }
93 |                 training_losses[training_loss] = loss_with_args
94 |             else:
95 |                 raise ValueError(
96 |                     f"Training loss type {training_loss} is not supported. Choose from {list(losses.keys())}"
97 |                 )
98 |         elif isinstance(training_loss, dict):
99 |             for loss_name, loss_values in training_loss.items():
100 |                 if loss_klass := losses.get(loss_name):
101 |                     loss_kwargs = loss_values.get("loss_kwargs", {})
102 |                     if len(loss_kwargs) > 0:
103 |                         logger.info(f"loss_kwargs for {loss_name}: {loss_kwargs}")
104 |                     loss_fn = loss_klass(**loss_kwargs)
105 |                     loss_with_args: LossWithWeight = {
106 |                         "loss_fn": loss_fn,
107 |                         "weight": loss_values.get("weight", 1.0),
108 |                     }
109 |                     training_losses[loss_name] = loss_with_args
110 |                 else:
111 |                     raise ValueError(
112 |                         f"Training loss type {loss_name} is not supported. Choose from {list(losses.keys())}"
113 |                     )
114 |         else:
115 |             raise ValueError(
116 |                 f"Training loss type {training_loss} is not supported. 
Choose from {list(losses.keys())}" 117 | ) 118 | self.training_losses = training_losses 119 | self.training_loss_is_contrastive = ( 120 | "contrastive" in training_losses and len(training_losses) == 1 121 | ) 122 | 123 | def _calculate_warmup_steps(self, steps: Union[float, int]) -> int: 124 | if 0.0 < steps < 1.0: 125 | return int(self.total_steps * steps) 126 | return int(steps) 127 | 128 | def _calculate_current_warmup_weight( 129 | self, max_weight: float, warmup_steps: int 130 | ) -> float: 131 | step = self.state.global_step 132 | current_weight = max_weight * ((step) / (warmup_steps + 1)) ** 2 133 | return min(max_weight, current_weight) 134 | 135 | @property 136 | def current_sparsity_weight_doc(self) -> float: 137 | return self._calculate_current_warmup_weight( 138 | self.args.sparsity_weight_doc, self.warmup_steps_doc 139 | ) 140 | 141 | @property 142 | def current_sparsity_weight_query(self) -> float: 143 | return self._calculate_current_warmup_weight( 144 | self.args.sparsity_weight_query, self.warmup_steps_query 145 | ) 146 | 147 | def compute_loss( # type: ignore[override] 148 | self, 149 | model: Splade, 150 | inputs: dict[str, torch.Tensor], 151 | return_outputs: bool = False, 152 | **kwargs, # for transformers v4.46.0 153 | ) -> torch.Tensor | tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: 154 | if "labels" in inputs: 155 | labels = inputs.pop("labels") 156 | labels = labels[~torch.isnan(labels)] # remove nan values 157 | labels = labels.view(self.args.per_device_train_batch_size, -1) 158 | else: 159 | raise ValueError("labels is not in batch_inputs") 160 | 161 | queries, docs = model(inputs, self.batch_size) 162 | vocab_size = docs.size(2) 163 | 164 | scores = self.compute_scores(queries, docs) 165 | losses = {} 166 | for loss_name, loss_with_args in self.training_losses.items(): 167 | loss_fn = loss_with_args["loss_fn"] 168 | weight = loss_with_args["weight"] 169 | loss = loss_fn(scores, labels) 170 | losses[loss_name + "_loss"] = weight * loss 171 | 172 | loss = sum(losses.values()) 173 | 174 | docs_matrix = docs.reshape(-1, vocab_size) 175 | queries_matrix = queries.reshape(-1, vocab_size) 176 | 177 | regularizer_doc_loss = ( 178 | self.regularizer_doc_fn(docs_matrix) * self.current_sparsity_weight_doc 179 | ) 180 | regularizer_query_loss = ( 181 | self.regularizer_query_fn(queries_matrix) 182 | * self.current_sparsity_weight_query 183 | ) 184 | regularizer_loss = regularizer_doc_loss + regularizer_query_loss 185 | noise_token_loss = self.compute_noise_token_loss(queries_matrix, docs_matrix) 186 | 187 | losses: dict[str, float | torch.Tensor] = { 188 | **losses, 189 | "L0_doc": LogMetrics.L0(docs_matrix), 190 | "L0_query": LogMetrics.L0(queries_matrix), 191 | "regularizer_loss": regularizer_loss, 192 | "regularizer_doc_loss": regularizer_doc_loss, 193 | "regularizer_query_loss": regularizer_query_loss, 194 | "noise_token_loss": noise_token_loss, 195 | } 196 | self.log_metrics.add_dict(losses) 197 | 198 | total_loss = loss + regularizer_loss + noise_token_loss 199 | 200 | if not return_outputs: 201 | return total_loss 202 | return total_loss, [(queries, docs)] 203 | 204 | def compute_noise_token_loss( 205 | self, queries_matrix: torch.Tensor, docs_matrix: torch.Tensor 206 | ) -> torch.Tensor: 207 | if self.noise_token_ids and self.args.noise_tokens_weight > 0: 208 | # XXX: 毎回GPUに載せているが、本来一回で良い 209 | noise_token_ids_tensor = torch.tensor( 210 | self.noise_token_ids, device=queries_matrix.device 211 | ) 212 | noise_token_ids_tensor = 
noise_token_ids_tensor.view(-1) 213 | 214 | noise_scores_queries = queries_matrix[:, noise_token_ids_tensor] 215 | noise_scores_docs = docs_matrix[:, noise_token_ids_tensor] 216 | 217 | noise_loss_queries = noise_scores_queries.sum() 218 | noise_loss_docs = noise_scores_docs.sum() 219 | noise_token_loss = noise_loss_queries + noise_loss_docs 220 | 221 | # warmup 222 | current_weight = self._calculate_current_warmup_weight( 223 | self.args.noise_tokens_weight, self.warmup_steps_query 224 | ) 225 | noise_token_loss *= current_weight 226 | else: 227 | noise_token_loss = torch.tensor(0.0, device=queries_matrix.device) 228 | return noise_token_loss 229 | 230 | def compute_scores(self, queries, docs) -> torch.Tensor: 231 | if self.training_loss_is_contrastive: 232 | scores = self.compute_contrastive_score(queries, docs) 233 | else: 234 | scores = self.compute_similarity_scores(queries, docs) 235 | return scores 236 | 237 | def compute_contrastive_score( 238 | self, queries: torch.Tensor, docs: torch.Tensor 239 | ) -> torch.Tensor: 240 | """ 241 | for contrastive loss, 242 | return shape is (batch_size, 1 + neg_size) 243 | """ 244 | scores = torch.bmm(queries, torch.permute(docs, [0, 2, 1])).squeeze(1) 245 | scores_positive = scores[:, :1] 246 | negatives = docs[:, 1:, :].reshape(-1, docs.size(2)).T 247 | scores_negative = torch.matmul(queries.squeeze(1), negatives) 248 | return torch.cat([scores_positive, scores_negative], dim=1) 249 | 250 | def compute_similarity_scores( 251 | self, queries: torch.Tensor, docs: torch.Tensor 252 | ) -> torch.Tensor: 253 | return torch.bmm(queries, docs.transpose(1, 2)).squeeze(1) 254 | 255 | def log(self, logs: dict[str, float], start_time=None, **kwargs) -> None: 256 | logs["step"] = self.state.global_step 257 | if self.state.epoch is not None: 258 | logs["epoch"] = round(self.state.epoch, 2) 259 | 260 | current_metrics = self.log_metrics.mean() 261 | self.log_metrics.clear() 262 | logs.update(current_metrics) 263 | 264 | output = {**logs, "step": self.state.global_step} 265 | self.state.log_history.append(output) 266 | self.control = self.callback_handler.on_log( 267 | self.args, self.state, self.control, logs 268 | ) 269 | 270 | def _save( 271 | self, output_dir: str | None = None, state_dict: dict[str, Any] | None = None 272 | ) -> None: 273 | output_dir = output_dir or self.args.output_dir 274 | os.makedirs(output_dir, exist_ok=True) 275 | logger.info("Saving model checkpoint to %s", output_dir) 276 | 277 | if not hasattr(self.model, "save_pretrained"): 278 | raise NotImplementedError( 279 | f"MODEL {self.model.__class__.__name__} does not support save_pretrained interface" 280 | ) 281 | 282 | self.model.save_pretrained(output_dir) 283 | if self.tokenizer is not None and self.is_world_process_zero(): 284 | self.tokenizer.save_pretrained(output_dir) 285 | 286 | torch.save(self.args, os.path.join(output_dir, "training_args.bin")) 287 | -------------------------------------------------------------------------------- /yast/regularizers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Concatenate, Literal 2 | 3 | import torch 4 | 5 | 6 | def regularize_flops(batch_tensor: torch.Tensor) -> torch.Tensor: 7 | """ 8 | FLOPs regularization as described in "Minimizing FLOPs to Learn Efficient Sparse Representations". 
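    In this implementation the term is sum_j (mean_i |a_ij|)^2: absolute activations
    are averaged over the batch for every vocabulary dimension, squared, and summed,
    as proposed in the paper: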
9 | https://arxiv.org/abs/2004.05665 10 | 11 | Merits for SPLADE: 12 | - Directly optimizes for search efficiency by minimizing FLOPs 13 | - Promotes even distribution of non-zero elements across dimensions 14 | - Theoretically grounded approach for sparse representation learning 15 | 16 | Demerits for SPLADE: 17 | - May require careful tuning of regularization strength 18 | - Might lead to over-sparsification if not balanced with the main loss 19 | 20 | Args: 21 | batch_tensor (torch.Tensor): Input tensor of shape (batch_size, dim) 22 | 23 | Returns: 24 | torch.Tensor: FLOPs regularization term 25 | """ 26 | mean_abs = torch.abs(batch_tensor).mean(dim=0) 27 | flops_reg = torch.sum(torch.square(mean_abs)) 28 | return flops_reg 29 | 30 | 31 | def regularize_mean_squared(batch_tensor: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Regularization that computes the mean of absolute values, squares them, and sums the result. 34 | 35 | Merits for SPLADE: 36 | - Similar effect to FLOPs regularization, promoting sparsity 37 | - May be more numerically stable in some cases 38 | 39 | Demerits for SPLADE: 40 | - Less direct theoretical connection to FLOPs minimization 41 | - Might not distribute non-zero elements as evenly as FLOPs regularization 42 | 43 | Args: 44 | batch_tensor (torch.Tensor): Input tensor of shape (batch_size, dim) 45 | 46 | Returns: 47 | torch.Tensor: Mean squared regularization term 48 | """ 49 | return torch.pow(torch.abs(batch_tensor).mean(dim=0), 2).sum() 50 | 51 | 52 | def regularize_L1(batch_tensor: torch.Tensor) -> torch.Tensor: 53 | """ 54 | L1 regularization: computes the mean L1 norm across the last dimension. 55 | 56 | Merits for SPLADE: 57 | - Promotes general sparsity in the embeddings 58 | - Well-understood and widely used in machine learning 59 | 60 | Demerits for SPLADE: 61 | - Doesn't specifically optimize for search efficiency 62 | - May not distribute non-zero elements evenly across dimensions 63 | 64 | Args: 65 | batch_tensor (torch.Tensor): Input tensor of shape (batch_size, dim) 66 | 67 | Returns: 68 | torch.Tensor: L1 regularization term 69 | """ 70 | if batch_tensor.size(1) == 1: 71 | batch_tensor = batch_tensor.squeeze(1) 72 | return torch.norm(batch_tensor, p=1, dim=-1).mean() 73 | 74 | 75 | def regularize_L2(batch_tensor: torch.Tensor) -> torch.Tensor: 76 | """ 77 | L2 regularization: computes the mean L2 norm across the last dimension. 78 | 79 | Merits for SPLADE: 80 | - Prevents embeddings from growing too large 81 | - Can improve generalization in some cases 82 | 83 | Demerits for SPLADE: 84 | - Doesn't promote sparsity, which is crucial for SPLADE 85 | - May not be suitable as the primary regularization for sparse retrieval models 86 | 87 | Args: 88 | batch_tensor (torch.Tensor): Input tensor of shape (batch_size, dim) 89 | 90 | Returns: 91 | torch.Tensor: L2 regularization term 92 | """ 93 | if batch_tensor.size(1) == 1: 94 | batch_tensor = batch_tensor.squeeze(1) 95 | l2_norm = torch.norm(batch_tensor, p=2, dim=-1) 96 | return l2_norm.mean() 97 | 98 | 99 | def regularize_flops_l1_weighted( 100 | batch_tensor: torch.Tensor, flops_weight: float = 0.7, l1_weight: float = 0.3 101 | ) -> torch.Tensor: 102 | """ 103 | Combines FLOPs and L1 regularization with adjustable weights. 104 | Balances between FLOPs' distribution properties and L1's sparsification. 
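    Computed as flops_weight * regularize_flops(x) + l1_weight * regularize_L1(x).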
105 | 106 | Merits: 107 | - Balances between FLOPs' distribution properties and L1's sparsification 108 | - Adjustable trade-off via flops_weight and l1_weight parameters 109 | - More flexible than single regularization approaches 110 | 111 | Demerits: 112 | - Requires tuning of two hyperparameters 113 | - May be more computationally expensive 114 | 115 | Args: 116 | batch_tensor: Input tensor of shape (batch_size, dim) 117 | flops_weight: Weight for FLOPs regularization 118 | l1_weight: Weight for L1 regularization 119 | 120 | Returns: 121 | torch.Tensor: Combined regularization term 122 | """ 123 | flops_term = regularize_flops(batch_tensor) 124 | l1_term = regularize_L1(batch_tensor) 125 | return flops_weight * flops_term + l1_weight * l1_term 126 | 127 | 128 | def regularize_dynamic_sparsity( 129 | batch_tensor: torch.Tensor, target_sparsity: float = 0.95, smoothing: float = 0.01 130 | ) -> torch.Tensor: 131 | """ 132 | Dynamically adjusts regularization strength based on current sparsity level. 133 | 134 | Merits: 135 | - Automatically adjusts to maintain desired sparsity level 136 | - Prevents over-sparsification 137 | - More stable training dynamics 138 | 139 | Demerits: 140 | - Additional computational overhead 141 | - May take longer to converge 142 | 143 | Args: 144 | batch_tensor: Input tensor of shape (batch_size, dim) 145 | target_sparsity: Desired sparsity level (0 to 1) 146 | smoothing: Smoothing factor for sparsity calculation 147 | 148 | Returns: 149 | torch.Tensor: Adaptive regularization term 150 | """ 151 | current_sparsity = (batch_tensor.abs() < 1e-6).float().mean() 152 | sparsity_error = target_sparsity - current_sparsity 153 | 154 | # Adjust regularization strength based on sparsity error 155 | strength = torch.sigmoid(sparsity_error / smoothing) 156 | 157 | # Use FLOPs regularization with adaptive strength 158 | return strength * regularize_flops(batch_tensor) 159 | 160 | 161 | def regularize_magnitude_threshold( 162 | batch_tensor: torch.Tensor, threshold: float = 0.1, power: float = 2.0 163 | ) -> torch.Tensor: 164 | """ 165 | Applies progressive penalty based on value magnitudes relative to threshold. 166 | 167 | Merits: 168 | - More granular control over sparsification 169 | - Can preserve important large values while suppressing small ones 170 | - Helps achieve desired sparsity pattern 171 | 172 | Demerits: 173 | - Sensitive to threshold parameter 174 | - May require careful tuning 175 | 176 | Args: 177 | batch_tensor: Input tensor of shape (batch_size, dim) 178 | threshold: Threshold value for penalty application 179 | power: Power factor for penalty scaling 180 | 181 | Returns: 182 | torch.Tensor: Threshold regularization term 183 | """ 184 | abs_values = torch.abs(batch_tensor) 185 | penalty = torch.where( 186 | abs_values < threshold, 187 | torch.pow(abs_values / threshold, power), 188 | torch.ones_like(abs_values), 189 | ) 190 | return penalty.mean() 191 | 192 | 193 | def regularize_entropy_balanced( 194 | batch_tensor: torch.Tensor, target_entropy: float = 2.0 195 | ) -> torch.Tensor: 196 | """ 197 | Promotes balanced distribution of non-zero elements using entropy. 
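    Computed as |H(p) - target_entropy|, where p_j = |x_j| / sum_k |x_k| per row and
    H is the Shannon entropy (natural log) averaged over the batch.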
198 | 199 | Merits: 200 | - Encourages more balanced distribution of non-zero elements 201 | - Helps prevent concentration of activations 202 | - Can maintain semantic diversity 203 | 204 | Demerits: 205 | - More complex computation 206 | - May not always converge to optimal sparsity level 207 | 208 | Args: 209 | batch_tensor: Input tensor of shape (batch_size, dim) 210 | target_entropy: Target entropy value for distribution 211 | 212 | Returns: 213 | torch.Tensor: Distributional regularization term 214 | """ 215 | # Calculate normalized magnitude distribution 216 | magnitudes = torch.abs(batch_tensor) 217 | probs = magnitudes / (magnitudes.sum(dim=-1, keepdim=True) + 1e-6) 218 | 219 | # Calculate entropy of the distribution 220 | entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=-1).mean() 221 | 222 | # Penalize deviation from target entropy 223 | return torch.abs(entropy - target_entropy) 224 | 225 | 226 | def regularize_grouped_magnitude( 227 | batch_tensor: torch.Tensor, group_size: int = 8, threshold: float = 0.1 228 | ) -> torch.Tensor: 229 | """ 230 | Group-wise magnitude regularization that promotes structured sparsity. 231 | Similar to structured pruning in neural networks. 232 | 233 | Merits: 234 | - Promotes structured sparsity within groups 235 | - More efficient for actual computation 236 | - Better preserves semantic relationships 237 | 238 | Demerits: 239 | - Group size needs to be tuned 240 | - May not work well with very small dimensions 241 | 242 | Args: 243 | batch_tensor: Input tensor of shape (batch_size, dim) 244 | group_size: Size of groups for structured sparsity 245 | threshold: Threshold for magnitude comparison 246 | 247 | Returns: 248 | torch.Tensor: Group magnitude regularization term 249 | """ 250 | batch_size, dim = batch_tensor.shape 251 | num_groups = dim // group_size 252 | 253 | # Reshape into groups 254 | grouped_tensor = batch_tensor.view(batch_size, num_groups, group_size) 255 | 256 | # Calculate group-wise magnitudes 257 | group_magnitudes = torch.norm(grouped_tensor, p=2, dim=2) 258 | 259 | # Apply soft thresholding to groups 260 | threshold_penalty = torch.relu(group_magnitudes - threshold) 261 | 262 | return threshold_penalty.mean() 263 | 264 | 265 | def regularize_topk_entropy( 266 | batch_tensor: torch.Tensor, k: int = 256, temperature: float = 1.0 267 | ) -> torch.Tensor: 268 | """ 269 | Top-k sparse entropy regularization that maintains semantic diversity. 270 | Combines benefits of top-k sparsity with entropy-based distribution control. 
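    Computed as the negative mean entropy of softmax(top_k(|x|) / temperature), so
    minimizing it pushes mass to spread more evenly across the k largest activations.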
271 | 272 | Merits: 273 | - Controls exact number of non-zero elements 274 | - Maintains semantic diversity through entropy 275 | - More predictable sparsity patterns 276 | 277 | Demerits: 278 | - k needs to be chosen carefully 279 | - Computationally more expensive due to sorting 280 | 281 | Args: 282 | batch_tensor: Input tensor of shape (batch_size, dim) 283 | k: Number of top elements to consider 284 | temperature: Temperature for softmax 285 | 286 | Returns: 287 | torch.Tensor: Top-k entropy regularization term 288 | """ 289 | magnitudes = torch.abs(batch_tensor) 290 | 291 | # Get top-k values and compute soft distribution 292 | top_k_values, _ = torch.topk(magnitudes, k=k, dim=1) 293 | soft_distribution = torch.softmax(top_k_values / temperature, dim=1) 294 | 295 | # Compute entropy of top-k distribution 296 | entropy = -(soft_distribution * torch.log(soft_distribution + 1e-10)).sum(dim=1) 297 | 298 | return -entropy.mean() # Minimize negative entropy 299 | 300 | 301 | def regularize_adaptive_threshold( 302 | batch_tensor: torch.Tensor, 303 | init_threshold: float = 0.1, 304 | target_density: float = 0.05, 305 | momentum: float = 0.9, 306 | ) -> torch.Tensor: 307 | """ 308 | Adaptive threshold regularization that maintains target density. 309 | Automatically adjusts threshold to maintain desired sparsity level. 310 | 311 | Merits: 312 | - Automatically maintains target sparsity 313 | - Smooth adaptation of threshold 314 | - More stable than fixed threshold approaches 315 | 316 | Demerits: 317 | - Requires careful tuning of momentum 318 | - May take time to stabilize 319 | 320 | Args: 321 | batch_tensor: Input tensor of shape (batch_size, dim) 322 | init_threshold: Initial threshold value 323 | target_density: Target density (1 - sparsity) 324 | momentum: Momentum for threshold adaptation 325 | 326 | Returns: 327 | torch.Tensor: Adaptive threshold regularization term 328 | """ 329 | magnitudes = torch.abs(batch_tensor) 330 | current_density = (magnitudes > init_threshold).float().mean() 331 | 332 | # Compute threshold adjustment 333 | density_error = current_density - target_density 334 | threshold_adjustment = torch.sign(density_error) * torch.abs(density_error).sqrt() 335 | 336 | # Apply soft thresholding with current threshold 337 | penalty = torch.relu(magnitudes - init_threshold) 338 | 339 | return penalty.mean() * (1.0 + threshold_adjustment) 340 | 341 | 342 | regularizers: dict[ 343 | Literal[ 344 | "mean_squared", 345 | "flops", 346 | "L1", 347 | "L2", 348 | "flops_l1_weighted", 349 | "dynamic_sparsity", 350 | "magnitude_threshold", 351 | "entropy_balanced", 352 | "dynamic_sparsity", 353 | "grouped_magnitude", 354 | "topk_entropy", 355 | "adaptive_threshold", 356 | ], 357 | Callable[Concatenate[torch.Tensor, ...], torch.Tensor], 358 | ] = { 359 | "mean_squared": regularize_mean_squared, 360 | "flops": regularize_flops, 361 | "L1": regularize_L1, 362 | "L2": regularize_L2, 363 | "flops_l1_weighted": regularize_flops_l1_weighted, 364 | "dynamic_sparsity": regularize_dynamic_sparsity, 365 | "magnitude_threshold": regularize_magnitude_threshold, 366 | "entropy_balanced": regularize_entropy_balanced, 367 | "grouped_magnitude": regularize_grouped_magnitude, 368 | "topk_entropy": regularize_topk_entropy, 369 | "adaptive_threshold": regularize_adaptive_threshold, 370 | } 371 | -------------------------------------------------------------------------------- /yast/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, 
Type, TypedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | """ 8 | Loss implementations for SPLADE training 9 | """ 10 | 11 | 12 | class KLDivLoss(nn.Module): 13 | def __init__( 14 | self, 15 | reduction: Literal["batchmean", "sum", "none"] = "batchmean", 16 | temperature: float = 1.0, 17 | ) -> None: 18 | super().__init__() 19 | 20 | if temperature <= 0: 21 | raise ValueError(f"Temperature must be positive, got {temperature}") 22 | 23 | self.temperature = temperature 24 | self.kl_div = nn.KLDivLoss(reduction=reduction, log_target=False) 25 | 26 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 27 | if scores.shape != labels.shape: 28 | raise ValueError( 29 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 30 | ) 31 | 32 | if not torch.isfinite(scores).all(): 33 | raise ValueError("scores contains inf or nan") 34 | 35 | if not torch.isfinite(labels).all(): 36 | raise ValueError("labels contains inf or nan") 37 | 38 | log_probs = F.log_softmax(scores / self.temperature, dim=1) 39 | loss = self.kl_div(log_probs, labels) * (self.temperature**2) 40 | 41 | return loss 42 | 43 | 44 | class WeightedBCELoss(nn.Module): 45 | def __init__( 46 | self, 47 | reduction: Literal["mean", "sum", "none"] = "mean", 48 | temperature: float = 1.0, 49 | scaling_factor: float = 25.0, 50 | pos_weight: float = 8.0, 51 | ) -> None: 52 | super().__init__() 53 | 54 | if temperature <= 0: 55 | raise ValueError(f"Temperature must be positive, got {temperature}") 56 | if scaling_factor <= 0: 57 | raise ValueError(f"Scaling factor must be positive, got {scaling_factor}") 58 | 59 | self.temperature = temperature 60 | self.scaling_factor = scaling_factor 61 | # Initialize BCEWithLogitsLoss with pos_weight 62 | self.bce = nn.BCEWithLogitsLoss( 63 | reduction=reduction, pos_weight=torch.tensor([pos_weight]) 64 | ) 65 | 66 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 67 | if scores.shape != labels.shape: 68 | raise ValueError( 69 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 70 | ) 71 | 72 | if not torch.isfinite(scores).all(): 73 | raise ValueError("scores contains inf or nan") 74 | 75 | if not torch.isfinite(labels).all(): 76 | raise ValueError("labels contains inf or nan") 77 | 78 | scaled_scores = (scores / self.scaling_factor) / self.temperature 79 | loss = self.bce(scaled_scores, labels) * (self.temperature**2) 80 | 81 | return loss 82 | 83 | 84 | class MarginMSELoss(nn.Module): 85 | def __init__(self, margin: float = 0.05): 86 | super(MarginMSELoss, self).__init__() 87 | self.margin = margin 88 | 89 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 90 | if scores.shape != labels.shape: 91 | raise ValueError( 92 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 93 | ) 94 | 95 | mse_loss = F.mse_loss(scores, labels, reduction="none") 96 | margin_loss = F.relu(mse_loss - self.margin) 97 | 98 | loss = margin_loss.mean(dim=1).mean() 99 | 100 | return loss 101 | 102 | 103 | class MarginCrossEntropyLoss(nn.Module): 104 | def __init__(self, margin: float = 0.05): 105 | super(MarginCrossEntropyLoss, self).__init__() 106 | self.margin = margin 107 | 108 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 109 | if scores.shape != labels.shape: 110 | raise ValueError( 111 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 112 | ) 113 | 114 | log_probs = F.log_softmax(scores, dim=1) 115 
| soft_ce_loss = -(labels * log_probs).sum(dim=1).mean() 116 | 117 | mse_loss = F.mse_loss(scores, labels, reduction="none").mean(dim=1) 118 | margin_loss = F.relu(mse_loss - self.margin).mean() 119 | 120 | loss = soft_ce_loss + margin_loss 121 | return loss 122 | 123 | 124 | class TeacherGuidedMarginLoss(nn.Module): 125 | def __init__( 126 | self, 127 | temperature: float = 0.35, 128 | margin: float = 2.0, 129 | soft_ce_weight: float = 1.0, 130 | margin_weight: float = 1.0, 131 | ): 132 | """ 133 | Args: 134 | temperature (float): Temperature parameter controlling score scaling 135 | margin (float): Minimum margin between positive/negative pairs 136 | soft_ce_weight (float): Weight for soft cross entropy loss 137 | margin_weight (float): Weight for margin loss 138 | """ 139 | super(TeacherGuidedMarginLoss, self).__init__() 140 | self.temperature = temperature 141 | self.margin = margin 142 | self.soft_ce_weight = soft_ce_weight 143 | self.margin_weight = margin_weight 144 | 145 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 146 | """ 147 | Args: 148 | scores (torch.Tensor): Inner product values (batch_size, num_candidates) 149 | labels (torch.Tensor): Teacher scores (batch_size, num_candidates) 150 | Returns: 151 | torch.Tensor: Scalar loss value 152 | """ 153 | if scores.shape != labels.shape: 154 | raise ValueError( 155 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 156 | ) 157 | 158 | # Scale scores by temperature 159 | scaled_scores = scores / self.temperature 160 | 161 | # Soft Cross Entropy Loss 162 | log_probs = F.log_softmax(scaled_scores, dim=1) 163 | teacher_probs = F.softmax(labels / self.temperature, dim=1) 164 | soft_ce_loss = -(teacher_probs * log_probs).sum(dim=1).mean() 165 | 166 | # Weighted Margin Loss 167 | positives = scores[:, 0].unsqueeze(1) # (batch_size, 1) 168 | negatives = scores[:, 1:] # (batch_size, num_negatives) 169 | teacher_weights = labels[:, 1:] # Teacher scores for negatives 170 | weighted_margin = self.margin * ( 171 | 1.0 - teacher_weights 172 | ) # (batch_size, num_negatives) 173 | margin_diffs = ( 174 | negatives - positives + weighted_margin 175 | ) # (batch_size, num_negatives) 176 | margin_loss = F.relu(margin_diffs).mean() 177 | 178 | # Final loss (weighted sum) 179 | loss = self.soft_ce_weight * soft_ce_loss + self.margin_weight * margin_loss 180 | # TODO: Enable returning as dict 181 | return loss 182 | 183 | 184 | class SoftCrossEntropyLoss(nn.Module): 185 | def __init__(self, temperature: float = 1.5): 186 | """ 187 | Args: 188 | temperature (float): Temperature parameter controlling score scaling 189 | """ 190 | super(SoftCrossEntropyLoss, self).__init__() 191 | self.temperature = temperature 192 | 193 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 194 | """ 195 | Args: 196 | scores (torch.Tensor): Inner product values (batch_size, num_candidates) 197 | labels (torch.Tensor): Teacher scores (batch_size, num_candidates) 198 | Returns: 199 | torch.Tensor: Scalar loss value 200 | """ 201 | if scores.shape != labels.shape: 202 | raise ValueError( 203 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 204 | ) 205 | 206 | scaled_scores = scores / self.temperature 207 | log_probs = F.log_softmax(scaled_scores, dim=1) 208 | teacher_probs = F.softmax(labels / self.temperature, dim=1) 209 | loss = -(teacher_probs * log_probs).sum(dim=1).mean() 210 | return loss 211 | 212 | 213 | class WeightedMarginLoss(nn.Module): 214 | def __init__( 215 | self, 
216 | margin: float = 4.0, 217 | max_negative_teacher: float = 0.3, 218 | min_positive_teacher: float = 0.7, 219 | ): 220 | """ 221 | Args: 222 | margin (float): Base margin between positive/negative pairs 223 | max_negative_teacher (float): Maximum expected value for negative teacher scores 224 | min_positive_teacher (float): Minimum expected value for positive teacher scores 225 | """ 226 | super(WeightedMarginLoss, self).__init__() 227 | self.base_margin = margin 228 | self.max_negative_teacher = max_negative_teacher 229 | self.min_positive_teacher = min_positive_teacher 230 | 231 | def get_margin(self, teacher_scores: torch.Tensor) -> torch.Tensor: 232 | """ 233 | Calculate margin based on negative teacher scores 234 | Lower scoring examples (easy negatives) get larger margins 235 | 236 | Args: 237 | teacher_scores: Teacher scores for negative examples (batch_size, num_negatives) 238 | Returns: 239 | torch.Tensor: Margin values 240 | """ 241 | # Clip teacher scores to [0, max_negative_teacher] range 242 | clipped_scores = torch.clamp(teacher_scores, 0.0, self.max_negative_teacher) 243 | 244 | # Lower scores get larger margins (range 0.9-1.1) 245 | margin_scale = 1.1 - (clipped_scores / self.max_negative_teacher) * 0.2 246 | return self.base_margin * margin_scale 247 | 248 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 249 | """ 250 | Args: 251 | scores (torch.Tensor): Inner product values (batch_size, num_candidates) 252 | labels (torch.Tensor): Teacher scores (batch_size, num_candidates) 253 | First column is positive, rest are negative 254 | Returns: 255 | torch.Tensor: Scalar loss value 256 | """ 257 | if scores.shape != labels.shape: 258 | raise ValueError( 259 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 260 | ) 261 | 262 | # Check teacher score ranges (warning only) 263 | negatives_mask = labels[:, 1:] > self.max_negative_teacher 264 | if torch.any(negatives_mask): 265 | count = torch.sum(negatives_mask).item() 266 | print( 267 | f"Warning: {count} negative teacher scores exceed expected maximum value of {self.max_negative_teacher}" 268 | ) 269 | 270 | positives_mask = labels[:, 0] < self.min_positive_teacher 271 | if torch.any(positives_mask): 272 | count = torch.sum(positives_mask).item() 273 | print( 274 | f"Warning: {count} positive teacher scores are below minimum value of {self.min_positive_teacher}" 275 | ) 276 | 277 | positives = scores[:, 0].unsqueeze(1) # (batch_size, 1) 278 | negatives = scores[:, 1:] # (batch_size, num_negatives) 279 | teacher_weights = labels[:, 1:] # Teacher scores for negatives 280 | 281 | # Calculate margins 282 | weighted_margin = self.get_margin(teacher_weights) 283 | 284 | # Calculate margin loss 285 | margin_diffs = negatives - positives + weighted_margin 286 | loss = F.relu(margin_diffs).mean() 287 | return loss 288 | 289 | 290 | class WeightedMarginLossWithLog(nn.Module): 291 | def __init__( 292 | self, 293 | margin: float = 0.5, 294 | # max_negative_teacher: float = 0.3, 295 | min_positive_teacher: float = 0.7,  # required: get_margin() below uses self.min_positive_teacher 296 | eps: float = 1e-6, 297 | ): 298 | super().__init__() 299 | """ 300 | margin = 0.5 # This is appropriate for a log-space difference of 0.35 301 | Statistics for positive inner products: 302 | mean: 10.03 303 | log(10.03) = 2.31 304 | std: 1.566 305 | 306 | Statistics for negative inner products: 307 | mean: 7.11 308 | log(7.11) = 1.96 309 | std: 1.26 310 | """ 311 | self.base_margin = margin 312 | # self.max_negative_teacher = max_negative_teacher 313 | 
self.min_positive_teacher = min_positive_teacher 314 | self.eps = eps 315 | 316 | def get_margin( 317 | self, teacher_pos_scores: torch.Tensor, teacher_neg_scores: torch.Tensor 318 | ) -> torch.Tensor: 319 | """ 320 | Adjust margins for each batch element (row) based on both positive and negative teacher scores 321 | 322 | Args: 323 | teacher_pos_scores: Teacher scores for positive examples (batch_size, 1) 324 | teacher_neg_scores: Teacher scores for negative examples (batch_size, num_negatives) 325 | Returns: 326 | torch.Tensor: Adjusted margins (batch_size, num_negatives) 327 | """ 328 | # Calculate statistics for each row 329 | neg_mean_per_row = teacher_neg_scores.mean( 330 | dim=1, keepdim=True 331 | ) # (batch_size, 1) 332 | neg_std_per_row = teacher_neg_scores.std(dim=1, keepdim=True) # (batch_size, 1) 333 | 334 | # Calculate relative position of negative scores in each row 335 | # How far from the mean (in standard deviations) 336 | neg_relative_scores = (teacher_neg_scores - neg_mean_per_row) / ( 337 | neg_std_per_row + self.eps 338 | ) 339 | 340 | # Relative strength of positive scores per row 341 | # Confidence in positive examples for that query 342 | pos_relative_strength = (teacher_pos_scores - self.min_positive_teacher) / ( 343 | 1.0 - self.min_positive_teacher 344 | ) # (batch_size, 1) 345 | 346 | # Higher positive confidence leads to larger margins (range 1.0-1.2) 347 | pos_scale = 1.0 + (pos_relative_strength * 0.2) # (batch_size, 1) 348 | 349 | # Lower negative scores relative to mean get larger margins 350 | # Map -2~2σ range to 0-1 using sigmoid 351 | neg_confidence = torch.sigmoid(neg_relative_scores) 352 | neg_scale = 1.1 - (neg_confidence * 0.2) # (batch_size, num_negatives) 353 | 354 | # Combine both scales 355 | margin_scale = pos_scale * neg_scale # (batch_size, num_negatives) 356 | 357 | return self.base_margin * margin_scale 358 | 359 | def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 360 | if scores.shape != labels.shape: 361 | raise ValueError( 362 | f"Shape mismatch: scores {scores.shape} != labels {labels.shape}" 363 | ) 364 | 365 | log_scores = torch.log(scores + self.eps) 366 | 367 | positives = log_scores[:, 0].unsqueeze(1) 368 | negatives = log_scores[:, 1:] 369 | 370 | # Separate teacher scores 371 | teacher_pos = labels[:, 0].unsqueeze(1) # Teacher scores for positives 372 | teacher_neg = labels[:, 1:] # Teacher scores for negatives 373 | 374 | # Calculate margins using both teacher scores 375 | weighted_margin = self.get_margin(teacher_pos, teacher_neg) 376 | 377 | margin_diffs = negatives - positives + weighted_margin 378 | loss = F.relu(margin_diffs).mean() 379 | 380 | return loss 381 | 382 | 383 | class LossWithWeight(TypedDict): 384 | loss_fn: nn.Module 385 | weight: float 386 | 387 | 388 | losses: dict[str, Type[nn.Module]] = { 389 | "cross_entropy": nn.CrossEntropyLoss, 390 | "mse": nn.MSELoss, 391 | "kl_div": KLDivLoss, 392 | "margin_mse": MarginMSELoss, 393 | "margin_ce": MarginCrossEntropyLoss, 394 | "teacher_guided_margin": TeacherGuidedMarginLoss, 395 | "soft_ce": SoftCrossEntropyLoss, 396 | "weighted_margin": WeightedMarginLoss, 397 | "weighted_margin_log": WeightedMarginLossWithLog, 398 | "weighted_bce": WeightedBCELoss, 399 | } 400 | -------------------------------------------------------------------------------- /yast/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from: 
https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/reranker/data.py 3 | License: MIT License 4 | """ 5 | 6 | import functools 7 | import importlib 8 | import logging 9 | import os 10 | import random 11 | import re 12 | from copy import deepcopy 13 | from dataclasses import dataclass 14 | from typing import List, Type, cast 15 | 16 | import torch 17 | import torch.utils.data 18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk 19 | from transformers import ( 20 | BatchEncoding, 21 | DataCollatorWithPadding, 22 | PreTrainedTokenizer, 23 | PreTrainedTokenizerFast, 24 | ) 25 | 26 | from .arguments import DataArguments 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | POSITIVE_KEYS = ["pos", "positive", "positives"] 32 | NEGATIVE_KEYS = ["neg", "negative", "negatives"] 33 | QUERY_KEYS = ["query", "qry", "question", "q"] 34 | POSITIVE_SCORE_KEYS = [ 35 | "pos_score", 36 | "positive_score", 37 | "positive_scores", 38 | "positives_score", 39 | ] 40 | NEGATIVE_SCORE_KEYS = [ 41 | "neg_score", 42 | "negative_score", 43 | "negative_scores", 44 | "negatives_score", 45 | ] 46 | 47 | 48 | class DatasetForSpladeTraining(torch.utils.data.Dataset): 49 | def __init__( 50 | self, 51 | args: DataArguments, 52 | tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, 53 | dataset: Dataset | None = None, 54 | ): 55 | if not dataset: 56 | train_data = args.train_data # list or str 57 | if isinstance(train_data, list): 58 | datasets = [] 59 | for target in train_data: 60 | logger.info(f"Loading {target}") 61 | datasets.append(self.load_dataset(target)) 62 | self.dataset = concatenate_datasets(datasets) 63 | else: 64 | logger.info(f"Loading {train_data}") 65 | self.dataset = self.load_dataset(train_data) 66 | else: 67 | self.dataset = dataset 68 | 69 | self.dataset: Dataset = cast(Dataset, self.dataset) 70 | 71 | self.tokenizer = tokenizer 72 | self.args = args 73 | self.total_len = len(self.dataset) 74 | 75 | self.subword_token_ids = set() 76 | 77 | if args.create_subword_indices: 78 | subword_prefix = "##" # bert subword prefix 79 | for token in tokenizer.get_vocab(): 80 | if token.startswith(subword_prefix): 81 | self.subword_token_ids.add(tokenizer.convert_tokens_to_ids(token)) 82 | logger.info(f"Subword token count: {len(self.subword_token_ids)}") 83 | self._set_noise_token_ids(args.noise_tokens_for_subword) 84 | 85 | def load_dataset(self, target_name: str) -> Dataset: 86 | if target_name.endswith(".jsonl") or target_name.endswith(".json"): 87 | logger.info(f"Loading JSON dataset from {target_name}") 88 | return load_dataset("json", data_files=target_name)["train"] # type: ignore 89 | elif os.path.isdir(target_name): 90 | datasets = [] 91 | target_files = os.listdir(target_name) 92 | if any([f.endswith(".arrow") for f in target_files]): 93 | # has arrow files 94 | logger.info(f"Loading dataset from directory {target_name}") 95 | target_ds = load_from_disk(target_name) 96 | logger.info(f"Loaded {target_name}: {len(target_ds)} examples") 97 | datasets.append(target_ds) 98 | else: 99 | for target in target_files: 100 | full_path = os.path.join(target_name, target) 101 | logger.info(f"Loading {full_path}") 102 | target_ds = self.load_dataset(full_path) 103 | logger.info(f"Loaded {full_path}: {len(target_ds)} examples") 104 | datasets.append(target_ds) 105 | return concatenate_datasets(datasets) 106 | else: 107 | logger.info(f"Loading dataset {target_name} with split='train'") 108 | return load_dataset(target_name, split="train") # type: ignore 109 | 
110 | @property 111 | @functools.lru_cache(maxsize=None) 112 | def query_key(self): 113 | for key in QUERY_KEYS: 114 | if key in self.dataset.column_names: 115 | return key 116 | raise ValueError("Query key not found") 117 | 118 | @property 119 | @functools.lru_cache(maxsize=None) 120 | def positive_key(self): 121 | for key in POSITIVE_KEYS: 122 | if key in self.dataset.column_names: 123 | return key 124 | raise ValueError("Positive key not found") 125 | 126 | @property 127 | @functools.lru_cache(maxsize=None) 128 | def negative_key(self): 129 | for key in NEGATIVE_KEYS: 130 | if key in self.dataset.column_names: 131 | return key 132 | raise ValueError("Negative key not found") 133 | 134 | @property 135 | @functools.lru_cache(maxsize=None) 136 | def positive_score_key(self): 137 | for key in POSITIVE_SCORE_KEYS: 138 | if key in self.dataset.column_names: 139 | return key 140 | return None 141 | 142 | @property 143 | @functools.lru_cache(maxsize=None) 144 | def negative_score_key(self): 145 | for key in NEGATIVE_SCORE_KEYS: 146 | if key in self.dataset.column_names: 147 | return key 148 | return None 149 | 150 | def create_one_example( 151 | self, encoding: str, max_length: int | None = None 152 | ) -> BatchEncoding: 153 | if not max_length: 154 | max_length = self.args.max_length 155 | item = self.tokenizer.encode_plus( 156 | encoding, 157 | truncation=True, 158 | max_length=max_length, 159 | padding=False, 160 | ) 161 | if len(self.subword_token_ids) > 0: 162 | item["subword_indices"] = create_subword_indices( 163 | torch.tensor(item["input_ids"]).unsqueeze(0), 164 | self.subword_token_ids, 165 | self.noise_token_ids, 166 | self.tokenizer, 167 | ).squeeze(0) 168 | 169 | return item 170 | 171 | def create_batch_inputs( 172 | self, 173 | query: str, 174 | pos_texts: list[str], 175 | neg_texts: list[str], 176 | pos_ids_score: list[float], 177 | neg_ids_score: list[float], 178 | ): 179 | pos_size = min(self.args.train_max_positive_size, len(pos_texts)) 180 | neg_size = self.args.train_group_size - pos_size 181 | 182 | pos_targets = list(zip(pos_texts, pos_ids_score)) 183 | if len(pos_targets) >= pos_size: 184 | pos_positions = random.sample(range(len(pos_targets)), pos_size) 185 | else: 186 | pos_positions = random.choices(range(len(pos_targets)), k=pos_size) 187 | pos_texts = [pos_targets[i][0] for i in pos_positions] 188 | pos_ids_score = [pos_targets[i][1] for i in pos_positions] 189 | 190 | neg_targets = list(zip(neg_texts, neg_ids_score)) 191 | if len(neg_targets) >= neg_size: 192 | neg_positions = random.sample(range(len(neg_targets)), neg_size) 193 | else: 194 | neg_positions = random.choices(range(len(neg_targets)), k=neg_size) 195 | 196 | neg_texts = [neg_targets[i][0] for i in neg_positions] 197 | neg_ids_score = [neg_targets[i][1] for i in neg_positions] 198 | 199 | labels = [float("nan")] + pos_ids_score + neg_ids_score 200 | 201 | batch_inputs = [] 202 | batch_inputs.append(self.create_one_example(query, self.args.max_query_length)) 203 | for text in pos_texts + neg_texts: 204 | batch_inputs.append(self.create_one_example(text, self.args.max_length)) 205 | for label, batch_input in zip(labels, batch_inputs): 206 | batch_input["label"] = label 207 | return batch_inputs 208 | 209 | def __len__(self): 210 | return self.total_len 211 | 212 | def __getitem__(self, item) -> List[dict]: 213 | query = self.dataset[item][self.query_key] 214 | pos_texts = self.dataset[item][self.positive_key] 215 | if not isinstance(pos_texts, list): 216 | pos_texts = [pos_texts] 217 | pos_scores = 
self.dataset[item].get(self.positive_score_key, []) 218 | if not isinstance(pos_scores, list): 219 | pos_scores = [pos_scores] 220 | neg_texts = self.dataset[item][self.negative_key] 221 | if not isinstance(neg_texts, list): 222 | neg_texts = [neg_texts] 223 | neg_scores = self.dataset[item].get(self.negative_score_key, []) 224 | if not isinstance(neg_scores, list): 225 | neg_scores = [neg_scores] 226 | 227 | return self.create_batch_inputs( 228 | query, pos_texts, neg_texts, pos_scores, neg_scores 229 | ) 230 | 231 | def _set_noise_token_ids(self, noise_tokens: None | str | list[str]) -> None: 232 | # FIXME: this logic is copied from the trainer, which is not ideal... 233 | if isinstance(noise_tokens, str): 234 | noise_tokens = re.split(r"\s+", noise_tokens) 235 | elif noise_tokens is None: 236 | noise_tokens = [] 237 | if len(noise_tokens) > 0: 238 | noise_tokens = list(set(noise_tokens)) 239 | tokenizer = self.tokenizer 240 | token_ids: list[int] = tokenizer.convert_tokens_to_ids(noise_tokens) # type: ignore 241 | if len(token_ids) != len(noise_tokens): 242 | missing_tokens = set(noise_tokens) - set( 243 | tokenizer.convert_ids_to_tokens(token_ids) # type: ignore 244 | ) 245 | raise ValueError( 246 | f"Token(s) {missing_tokens} are not in the tokenizer's vocabulary." 247 | ) 248 | logger.info( 249 | f"target noise tokens ({len(token_ids)}): {' '.join(noise_tokens)}" 250 | ) 251 | self.noise_token_ids = set(token_ids) 252 | else: 253 | self.noise_token_ids = set() 254 | 255 | 256 | @dataclass 257 | class GroupCollator(DataCollatorWithPadding): 258 | def __call__(self, features): 259 | if isinstance(features[0], list): 260 | features = sum(features, []) # type: ignore 261 | 262 | # If subword_indices are present, 263 | # pad the subword indices as well 264 | if "subword_indices" in features[0]: 265 | max_length = max(len(f["input_ids"]) for f in features) 266 | for feature in features: 267 | padding_length = max_length - len(feature["input_ids"]) 268 | if padding_length > 0: 269 | feature["subword_indices"] = ( 270 | feature["subword_indices"].tolist() + [-100] * padding_length 271 | ) 272 | 273 | return super().__call__(features) 274 | 275 | 276 | def detect_dataset_klass(dataset_path: str) -> Type[DatasetForSpladeTraining]: 277 | module_path, class_name = dataset_path.rsplit(".", 1) 278 | module = importlib.import_module(module_path) 279 | dataset_class = getattr(module, class_name) 280 | return dataset_class 281 | 282 | 283 | def create_dateset_from_args( 284 | args: DataArguments, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast 285 | ) -> DatasetForSpladeTraining: 286 | train_data = args.train_data 287 | if isinstance(train_data, str): 288 | target_ds = DatasetForSpladeTraining(args, tokenizer) 289 | elif isinstance(train_data, list): 290 | target_ds_list = [] 291 | for target_train_data in train_data: 292 | dataset_class_args = deepcopy(args) 293 | if isinstance(target_train_data, str): 294 | dataset_class_args.train_data = target_train_data 295 | target_ds_list.append( 296 | DatasetForSpladeTraining(dataset_class_args, tokenizer) 297 | ) 298 | elif isinstance(target_train_data, dict): 299 | dataset_class_name = target_train_data.get("dataset_class") 300 | dataset_options = target_train_data.get("dataset_options", {}) 301 | # merge dataset_options 302 | dataset_class_args.dataset_options.update(dataset_options) 303 | 304 | if not dataset_class_name: 305 | raise ValueError(f"dataset_class is required, {target_train_data}") 306 | dataset_class_train_data = target_train_data.get("train_data") 307 | if 
dataset_class_train_data: 308 | dataset_class_args.train_data = dataset_class_train_data 309 | dataset_klass = detect_dataset_klass(dataset_class_name) 310 | target_ds_list.append(dataset_klass(dataset_class_args, tokenizer)) 311 | else: 312 | raise ValueError(f"Invalid type {target_train_data}") 313 | target_ds = torch.utils.data.ConcatDataset(target_ds_list) 314 | return target_ds # type: ignore 315 | 316 | 317 | def create_subword_indices( 318 | token_ids: torch.Tensor, subword_token_ids: set, noise_token_ids: set, tokenizer 319 | ) -> torch.Tensor: 320 | """ 321 | Generate subword indices from token IDs. 322 | Words that contain subwords are grouped under the same index, 323 | and standalone tokens are treated as -100. 324 | 325 | Args: 326 | token_ids (torch.Tensor): Token IDs (batch_size, seq_len) 327 | subword_token_ids (set): Set of token IDs treated as subwords 328 | noise_token_ids (set): Set of token IDs treated as noise tokens 329 | 330 | Returns: 331 | torch.Tensor: Subword indices (batch_size, seq_len) 332 | -100: standalone tokens (words without subwords) and padding 333 | >= 0: group index of a word that contains subwords 334 | """ 335 | batch_size, seq_len = token_ids.shape 336 | subword_indices = torch.full_like( 337 | token_ids, 338 | -100, # mask value for padding 339 | ) 340 | 341 | current_subword_group_idx = -1 342 | start_word_token_id = -1 343 | for b in range(batch_size): 344 | word_start_pos = -1 345 | in_subword_sequence = False 346 | 347 | for i in range(seq_len): 348 | token_id = token_ids[b, i].item() 349 | 350 | # Skip padding and masked tokens 351 | if token_id == -100: 352 | continue 353 | 354 | is_subword = token_id in subword_token_ids 355 | 356 | # Start of a new word 357 | if not is_subword and not in_subword_sequence: 358 | # Handle the previous word 359 | if word_start_pos != -1: 360 | # It was a standalone token 361 | if not in_subword_sequence: 362 | subword_indices[b, word_start_pos] = -100 363 | 364 | word_start_pos = i 365 | in_subword_sequence = False 366 | start_word_token_id = token_id 367 | 368 | # Start of a subword sequence 369 | elif is_subword and not in_subword_sequence: 370 | current_subword_group_idx += 1 371 | in_subword_sequence = True 372 | if word_start_pos != -1: 373 | if start_word_token_id not in noise_token_ids: 374 | # If start_word_token_id is not a noise token, 375 | # include the start token in the subword group as well 376 | subword_indices[b, word_start_pos : i + 1] = ( 377 | current_subword_group_idx 378 | ) 379 | else: 380 | # If the start token is a noise token, ignore it 381 | subword_indices[b, word_start_pos] = -100 382 | subword_indices[b, i] = current_subword_group_idx 383 | 384 | # Inside a subword sequence 385 | elif is_subword and in_subword_sequence: 386 | subword_indices[b, i] = current_subword_group_idx 387 | 388 | # End of a subword sequence 389 | if not is_subword and in_subword_sequence: 390 | word_start_pos = i 391 | in_subword_sequence = False 392 | 393 | # Handle the last word 394 | if word_start_pos != -1 and not in_subword_sequence: 395 | subword_indices[b, word_start_pos] = -100 396 | 397 | return subword_indices 398 | --------------------------------------------------------------------------------
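The registries defined above can be exercised directly, which is often the quickest way to sanity-check a loss or regularizer choice before editing a training yaml. The snippet below is an illustrative sketch only — it is not a file in this repository — and it assumes `yast` is importable (for example, when run from the project root); the tensor shapes and the 0.01 regularization weight are arbitrary placeholder values.

```
# Illustrative usage sketch (not part of the repository).
# Assumes `yast` is importable, e.g. by running from the project root.
import torch

from yast.losses import losses                # name -> loss class registry
from yast.regularizers import regularizers    # name -> regularizer function registry

batch_size, num_candidates, vocab_size = 4, 8, 32000

# Model scores and teacher scores: column 0 is the positive candidate.
scores = torch.randn(batch_size, num_candidates)
labels = torch.rand(batch_size, num_candidates)

loss_fn = losses["soft_ce"]()  # SoftCrossEntropyLoss with its default temperature
loss = loss_fn(scores, labels)

# Sparse, vocabulary-sized representations, as produced by a SPLADE head.
reps = torch.relu(torch.randn(batch_size, vocab_size))
reg = regularizers["adaptive_threshold"](reps)

total = loss + 0.01 * reg  # 0.01 is an arbitrary regularization weight
print(float(loss), float(reg), float(total))
```

Any other key from `losses` or `regularizers` can be swapped in the same way; functions such as `regularize_adaptive_threshold` also accept their keyword arguments (e.g. `init_threshold`, `target_density`) explicitly.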