├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── pyproject.toml ├── tokenlearn ├── __init__.py ├── featurize.py ├── pretrain.py ├── train.py ├── utils.py └── version.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | local/ 166 | results/ 167 | local_models/ 168 | data/ 169 | lightning_logs/ 170 | models/ 171 | wandb/ 172 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-ast 8 | description: Simply check whether files parse as valid python. 9 | - id: trailing-whitespace 10 | description: Trims trailing whitespace 11 | - id: end-of-file-fixer 12 | description: Makes sure files end in a newline and only a newline. 13 | - id: check-added-large-files 14 | args: ['--maxkb=5000'] 15 | description: Prevent giant files from being committed. 16 | - id: check-case-conflict 17 | description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows. 18 | - repo: https://github.com/jsh9/pydoclint 19 | rev: 0.5.3 20 | hooks: 21 | - id: pydoclint 22 | - repo: https://github.com/astral-sh/ruff-pre-commit 23 | rev: v0.4.10 24 | hooks: 25 | - id: ruff 26 | args: [ --fix ] 27 | - id: ruff-format 28 | - repo: local 29 | hooks: 30 | - id: mypy 31 | name: mypy 32 | entry: mypy 33 | language: python 34 | types: [python] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 The Minish Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | venv:
2 | 	uv venv
3 | 
4 | install:
5 | 	uv sync --all-extras
6 | 	uv run pre-commit install
7 | 
8 | fix:
9 | 	uv run pre-commit run --all-files
10 | 
11 | test:
12 | 	uv run pytest --cov=tokenlearn --cov-report=term-missing
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tokenlearn
2 | Tokenlearn is a method to pre-train [Model2Vec](https://github.com/MinishLab/model2vec).
3 | 
4 | The method is described in detail in our [Tokenlearn blogpost](https://minishlab.github.io/tokenlearn_blogpost/).
5 | 
6 | ## Quickstart
7 | 
8 | Install the package with:
9 | 
10 | ```bash
11 | pip install tokenlearn
12 | ```
13 | 
14 | The basic usage of Tokenlearn consists of two CLI scripts: `featurize` and `train`.
15 | 
16 | Tokenlearn is trained on mean token embeddings ("means") computed with a sentence transformer. To create the means, the `tokenlearn.featurize` script can be used:
17 | 
18 | ```bash
19 | python3 -m tokenlearn.featurize --model-name "baai/bge-base-en-v1.5" --output-dir "data/c4_features"
20 | ```
21 | 
22 | NOTE: featurization uses the C4 dataset by default. To featurize a different dataset, pass the dataset arguments explicitly (shown here with the C4 defaults):
23 | 
24 | ```bash
25 | python3 -m tokenlearn.featurize \
26 |     --model-name "baai/bge-base-en-v1.5" \
27 |     --output-dir "data/c4_features" \
28 |     --dataset-path "allenai/c4" \
29 |     --dataset-name "en" \
30 |     --dataset-split "train"
31 | ```
32 | 
33 | To train a model on the featurized data, the `tokenlearn.train` script can be used:
34 | ```bash
35 | python3 -m tokenlearn.train --model-name "baai/bge-base-en-v1.5" --data-path "data/c4_features" --save-path "<path-to-save-model>"
36 | ```
37 | 
38 | Training will create two models:
39 | - The base trained model.
40 | - The base model with weighting applied. This is the model that should be used for downstream tasks.
41 | 
42 | NOTE: the code assumes that the padding token ID in your tokenizer is 0. If this is not the case, you will need to modify the code.
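You can verify this up front. The following is a minimal sketch (assuming the Hugging Face `transformers` tokenizer of the teacher model; `transformers` is installed as a dependency of `sentence-transformers`):

```python
from transformers import AutoTokenizer

# Inspect the teacher tokenizer's padding token; Tokenlearn expects ID 0
# (e.g. "[PAD]" with ID 0 for BERT-style vocabularies such as bge-base-en-v1.5).
tokenizer = AutoTokenizer.from_pretrained("baai/bge-base-en-v1.5")
print(tokenizer.pad_token, tokenizer.pad_token_id)
```

If the ID is not 0, adapt the collate function in `tokenlearn/pretrain.py`, which pads batches with the value 0.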
43 | 44 | ### Evaluation 45 | 46 | To evaluate a model, you can use the following command after installing the optional evaluation dependencies: 47 | 48 | ```bash 49 | pip install evaluation@git+https://github.com/MinishLab/evaluation@main 50 | 51 | ``` 52 | 53 | ```python 54 | from model2vec import StaticModel 55 | 56 | from evaluation import CustomMTEB, get_tasks, parse_mteb_results, make_leaderboard, summarize_results 57 | from mteb import ModelMeta 58 | 59 | # Get all available tasks 60 | tasks = get_tasks() 61 | # Define the CustomMTEB object with the specified tasks 62 | evaluation = CustomMTEB(tasks=tasks) 63 | 64 | # Load a trained model 65 | model_name = "tokenlearn_model" 66 | model = StaticModel.from_pretrained(model_name) 67 | 68 | # Optionally, add model metadata in MTEB format 69 | model.mteb_model_meta = ModelMeta( 70 | name=model_name, revision="no_revision_available", release_date=None, languages=None 71 | ) 72 | 73 | # Run the evaluation 74 | results = evaluation.run(model, eval_splits=["test"], output_folder=f"results") 75 | 76 | # Parse the results and summarize them 77 | parsed_results = parse_mteb_results(mteb_results=results, model_name=model_name) 78 | task_scores = summarize_results(parsed_results) 79 | 80 | # Print the results in a leaderboard format 81 | print(make_leaderboard(task_scores)) 82 | ``` 83 | 84 | ## License 85 | 86 | MIT 87 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tokenlearn" 3 | description = "Pre-train Static Embedders" 4 | readme = "README.md" 5 | license = { file = "LICENSE" } 6 | requires-python = ">=3.9" 7 | authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}] 8 | dynamic = ["version"] 9 | 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 15 | "Topic :: Software Development :: Libraries", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Natural Language :: English", 23 | ] 24 | 25 | dependencies = [ 26 | "model2vec[distill]>=0.5.0", 27 | "sentence-transformers", 28 | "torch", 29 | "datasets", 30 | "more-itertools>=10.5.0", 31 | ] 32 | 33 | [build-system] 34 | requires = ["setuptools>=64", "setuptools_scm>=8"] 35 | build-backend = "setuptools.build_meta" 36 | 37 | [project.optional-dependencies] 38 | dev = [ 39 | "black", 40 | "ipython", 41 | "mypy", 42 | "pre-commit", 43 | "pytest", 44 | "pytest-coverage", 45 | "ruff", 46 | ] 47 | 48 | [project.urls] 49 | "Homepage" = "https://github.com/MinishLab" 50 | "Bug Reports" = "https://github.com/MinishLab/tokenlearn/issues" 51 | "Source" = "https://github.com/MinishLab/tokenlearn" 52 | 53 | [tool.ruff] 54 | exclude = [".venv/"] 55 | line-length = 120 56 | target-version = "py310" 57 | 58 | [tool.ruff.lint] 59 | select = [ 60 | # Annotations: Enforce type annotations 61 | "ANN", 62 | # Complexity: Enforce a maximum cyclomatic complexity 63 | "C90", 64 | # Pydocstyle: Enforce docstrings 65 | "D", 66 | # Isort: Enforce import order 67 | "I", 68 | # Numpy: Enforce 
numpy style 69 | "NPY", 70 | # Print: Forbid print statements 71 | "T20", 72 | ] 73 | ignore = [ 74 | # Allow self and cls to be untyped, and allow Any type 75 | "ANN101", "ANN102", "ANN401", 76 | # Pydocstyle ignores 77 | "D100", "D101", "D104", "D203", "D212", "D401", 78 | # Allow use of f-strings in logging 79 | "G004" 80 | ] 81 | 82 | [tool.pydoclint] 83 | style = "sphinx" 84 | exclude = "test_" 85 | allow-init-docstring = true 86 | arg-type-hints-in-docstring = false 87 | check-return-types = false 88 | require-return-section-when-returning-nothing = false 89 | 90 | [tool.mypy] 91 | python_version = "3.10" 92 | warn_unused_configs = true 93 | ignore_missing_imports = true 94 | 95 | [tool.setuptools] 96 | packages = ["tokenlearn"] 97 | 98 | [tool.setuptools_scm] 99 | # can be empty if no extra settings are needed, presence enables setuptools_scm 100 | 101 | [tool.setuptools.dynamic] 102 | version = {attr = "tokenlearn.version.__version__"} 103 | -------------------------------------------------------------------------------- /tokenlearn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/tokenlearn/cafc6d6aae39f99e52bdeb7433d179259ca1e9d5/tokenlearn/__init__.py -------------------------------------------------------------------------------- /tokenlearn/featurize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from typing import Iterator 6 | 7 | import numpy as np 8 | from datasets import load_dataset 9 | from more_itertools import batched 10 | from sentence_transformers import SentenceTransformer 11 | from tqdm import tqdm 12 | from transformers.tokenization_utils import PreTrainedTokenizer 13 | 14 | _SAVE_EVERY = 32 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def featurize( 21 | dataset: Iterator[dict[str, str]], 22 | model: SentenceTransformer, 23 | output_dir: str, 24 | max_means: int, 25 | batch_size: int, 26 | text_key: str, 27 | ) -> None: 28 | """Make a directory and dump all kinds of data in it.""" 29 | output_dir_path = Path(output_dir) 30 | output_dir_path.mkdir(parents=True, exist_ok=True) 31 | 32 | # Ugly hack 33 | largest_batch = max([int(x.stem.split("_")[1]) for x in list(output_dir_path.glob("*.json"))], default=0) 34 | if largest_batch: 35 | logger.info(f"Resuming from batch {largest_batch}, skipping previous batches.") 36 | 37 | texts = [] 38 | embeddings = [] 39 | dim = model.get_sentence_embedding_dimension() 40 | if dim is None: 41 | raise ValueError("Model has no sentence embedding dimension.") 42 | 43 | tokenizer: PreTrainedTokenizer = model.tokenizer 44 | # Binding i in case the dataset is empty. 
45 | i = 0 46 | for i, batch in tqdm(enumerate(batched(dataset, n=batch_size))): 47 | if i * batch_size >= max_means: 48 | logger.info(f"Reached maximum number of means: {max_means}") 49 | break 50 | if largest_batch and i <= largest_batch: 51 | continue 52 | batch = [x[text_key] for x in batch] 53 | 54 | if not all(isinstance(x, str) for x in batch): 55 | raise ValueError(f"Detected non-string at batch: {i}") 56 | 57 | batch_embeddings = model.encode(batch, output_value="token_embeddings") # type: ignore # Annoying 58 | for text, embedding in zip(batch, batch_embeddings): 59 | texts.append(_truncate_text(tokenizer, text)) 60 | embeddings.append(embedding[1:-1].float().mean(axis=0).cpu().numpy()) 61 | if i and i % _SAVE_EVERY == 0: 62 | json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4) 63 | np.save(output_dir_path / f"feature_{i}.npy", embeddings) 64 | texts = [] 65 | embeddings = [] 66 | if texts: 67 | json.dump(texts, open(output_dir_path / f"feature_{i}.json", "w"), indent=4) 68 | np.save(output_dir_path / f"feature_{i}.npy", embeddings) 69 | 70 | 71 | def _truncate_text(tokenizer: PreTrainedTokenizer, text: str) -> str: 72 | """Truncate text to fit the tokenizer's maximum length.""" 73 | tokens = tokenizer.encode( 74 | text, 75 | truncation=True, 76 | max_length=tokenizer.model_max_length, 77 | ) 78 | return tokenizer.decode(tokens, skip_special_tokens=True) 79 | 80 | 81 | def main() -> None: 82 | """Main function to featurize texts using a sentence transformer.""" 83 | parser = argparse.ArgumentParser(description="Featurize texts using a sentence transformer.") 84 | parser.add_argument( 85 | "--model-name", 86 | type=str, 87 | default="baai/bge-base-en-v1.5", 88 | help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').", 89 | ) 90 | parser.add_argument( 91 | "--output-dir", 92 | type=str, 93 | default=None, 94 | help="Directory to save the featurized texts.", 95 | ) 96 | parser.add_argument( 97 | "--dataset-path", 98 | type=str, 99 | default="allenai/c4", 100 | help="The dataset path or name (e.g. 
'allenai/c4').", 101 | ) 102 | parser.add_argument( 103 | "--dataset-name", 104 | type=str, 105 | default="en", 106 | help="The dataset configuration name (e.g., 'en' for C4).", 107 | ) 108 | parser.add_argument( 109 | "--dataset-split", 110 | type=str, 111 | default="train", 112 | help="The dataset split (e.g., 'train', 'validation').", 113 | ) 114 | parser.add_argument( 115 | "--no-streaming", 116 | action="store_false", 117 | help="Disable streaming mode when loading the dataset.", 118 | ) 119 | parser.add_argument( 120 | "--max-means", 121 | type=int, 122 | default=1000000, 123 | help="The maximum number of mean embeddings to generate.", 124 | ) 125 | parser.add_argument( 126 | "--key", 127 | type=str, 128 | default="text", 129 | help="The key of the text field in the dataset to featurize (default: 'text').", 130 | ) 131 | parser.add_argument( 132 | "--batch-size", 133 | type=int, 134 | default=32, 135 | help="Batch size to use for encoding the texts.", 136 | ) 137 | 138 | args = parser.parse_args() 139 | 140 | if args.output_dir is None: 141 | model_name = args.model_name.replace("/", "_") 142 | dataset_path = args.dataset_path.replace("/", "_") 143 | output_dir = f"{model_name}_{dataset_path}_featurized" 144 | else: 145 | output_dir = args.output_dir 146 | 147 | model = SentenceTransformer(args.model_name) 148 | dataset = load_dataset( 149 | args.dataset_path, 150 | name=args.dataset_name, 151 | split=args.dataset_split, 152 | streaming=args.no_streaming, 153 | ) 154 | 155 | featurize(iter(dataset), model, output_dir, args.max_means, args.batch_size, args.key) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /tokenlearn/pretrain.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from model2vec import StaticModel 8 | from model2vec.distill.utils import select_optimal_device 9 | from tokenizers import Tokenizer 10 | from torch import nn 11 | from torch.nn.utils.rnn import pad_sequence 12 | from torch.utils.data import DataLoader, Dataset 13 | from tqdm import tqdm 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class StaticModelFineTuner(nn.Module): 19 | def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None: 20 | """ 21 | Initialize from a model. 22 | 23 | :param vectors: The vectors to use. 24 | :param out_dim: The output dimension. 25 | :param pad_id: The padding id. 
26 | """ 27 | super().__init__() 28 | self.pad_id = pad_id 29 | norms = vectors.norm(dim=1) 30 | # Normalize the vectors 31 | vectors = vectors / norms[:, None] 32 | self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id) 33 | self.n_out = out_dim 34 | self.out_layer = nn.Linear(vectors.shape[1], self.n_out) 35 | weights = torch.Tensor(norms) 36 | weights[pad_id] = 0 37 | self.w = nn.Parameter(weights) 38 | 39 | def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor: 40 | """Forward pass through the mean.""" 41 | w = self.w[input_ids] 42 | zeros = (input_ids != self.pad_id).float() 43 | w = w * zeros 44 | # Add a small epsilon to avoid division by zero 45 | length = zeros.sum(1) + 1e-16 46 | embedded = self.embeddings(input_ids) 47 | # Zero out the padding 48 | embedded = torch.bmm(w[:, None, :], embedded).squeeze(1) 49 | # Simulate actual mean 50 | embedded = embedded / length[:, None] 51 | 52 | return embedded 53 | 54 | def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 55 | """Forward pass through the mean, and a classifier layer after.""" 56 | embedded = self.sub_forward(x) 57 | return self.out_layer(embedded), embedded 58 | 59 | @property 60 | def device(self) -> torch.device: 61 | """Get the device of the model.""" 62 | return self.embeddings.weight.device 63 | 64 | 65 | class TextDataset(Dataset): 66 | def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None: 67 | """ 68 | Initialize the dataset. 69 | 70 | :param texts: The texts to tokenize. 71 | :param targets: The targets. 72 | :param tokenizer: The tokenizer to use. 73 | :raises ValueError: If the number of labels does not match the number of texts. 74 | """ 75 | if len(targets) != len(texts): 76 | raise ValueError("Number of labels does not match number of texts.") 77 | self.texts = [x[:20_000] for x in texts] 78 | self.tokenized_texts: list[list[int]] = [ 79 | encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False) 80 | ] 81 | self.targets = targets 82 | self.tokenizer = tokenizer 83 | 84 | def __len__(self) -> int: 85 | """Return the length of the dataset.""" 86 | return len(self.tokenized_texts) 87 | 88 | def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]: 89 | """Gets an item.""" 90 | return self.tokenized_texts[index], self.targets[index] 91 | 92 | @staticmethod 93 | def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]: 94 | """Collate function.""" 95 | texts, targets = zip(*batch) 96 | 97 | tensors = [torch.LongTensor(x).int() for x in texts] 98 | padded = pad_sequence(tensors, batch_first=True, padding_value=0) 99 | 100 | return padded, torch.stack(targets) 101 | 102 | def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader: 103 | """Convert the dataset to a DataLoader.""" 104 | return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size) 105 | 106 | 107 | def train_supervised( # noqa: C901 108 | train_dataset: TextDataset, 109 | validation_dataset: TextDataset, 110 | model: StaticModel, 111 | patience: int | None = 5, 112 | device: str | None = None, 113 | batch_size: int = 256, 114 | lr: float = 1e-3, 115 | ) -> StaticModel: 116 | """ 117 | Train a tokenlearn model. 118 | 119 | :param train_dataset: The training dataset. 120 | :param validation_dataset: The validation dataset. 121 | :param model: The model to train. 
122 | :param patience: The number of epochs to wait before early stopping. 123 | :param device: The device to train on. 124 | :param batch_size: The batch size. 125 | :param lr: The learning rate. 126 | :return: The trained model. 127 | """ 128 | device = select_optimal_device(device) 129 | train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size) 130 | 131 | # Initialize the model 132 | trainable_model = StaticModelFineTuner( 133 | torch.from_numpy(model.embedding), 134 | out_dim=train_dataset.targets.shape[1], 135 | pad_id=model.tokenizer.token_to_id("[PAD]"), 136 | ) 137 | trainable_model.to(device) 138 | 139 | # Separate parameters for model and linear layer 140 | model_params = ( 141 | list(trainable_model.embeddings.parameters()) 142 | + [trainable_model.w] 143 | + list(trainable_model.out_layer.parameters()) 144 | ) 145 | 146 | # Create optimizer with separate parameter groups 147 | optimizer = torch.optim.AdamW(params=model_params, lr=lr) 148 | 149 | lowest_loss = float("inf") 150 | param_dict = trainable_model.state_dict() 151 | curr_patience = patience 152 | stop = False 153 | 154 | criterion = nn.MSELoss() 155 | 156 | try: 157 | for epoch in range(100_000): 158 | logger.info(f"Epoch {epoch}") 159 | trainable_model.train() 160 | 161 | # Track train loss separately 162 | train_losses = [] 163 | barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]") 164 | 165 | for idx, (x, y) in enumerate(barred_train): 166 | optimizer.zero_grad() 167 | x = x.to(trainable_model.device) 168 | y_hat, _ = trainable_model(x) 169 | # Separate loss components 170 | train_loss = criterion(y_hat, y.to(trainable_model.device)).mean() 171 | 172 | # Apply weights 173 | train_loss.backward() 174 | 175 | optimizer.step() 176 | train_losses.append(train_loss.item()) 177 | 178 | barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}") 179 | 180 | # Evaluate every 1000 steps and at the end of the epoch 181 | if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1: 182 | trainable_model.eval() 183 | with torch.no_grad(): 184 | validation_losses = [] 185 | barred_val = tqdm( 186 | validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation" 187 | ) 188 | for x_val, y_val in barred_val: 189 | x_val = x_val.to(trainable_model.device) 190 | y_hat_val, _ = trainable_model(x_val) 191 | val_loss = criterion(y_hat_val, y_val.to(trainable_model.device)).mean() 192 | validation_losses.append(val_loss.item()) 193 | barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}") 194 | 195 | validation_loss = np.mean(validation_losses) 196 | # Early stopping logic based on validation loss 197 | if patience is not None and curr_patience is not None: 198 | if (lowest_loss - validation_loss) > 1e-4: 199 | param_dict = trainable_model.state_dict() # Save best model state based on training loss 200 | curr_patience = patience 201 | lowest_loss = validation_loss 202 | else: 203 | curr_patience -= 1 204 | if curr_patience == 0: 205 | stop = True 206 | break 207 | logger.info(f"Patience level: {patience - curr_patience}") 208 | logger.info(f"Validation loss: {validation_loss:.3f}") 209 | logger.info(f"Lowest loss: {lowest_loss:.3f}") 210 | 211 | trainable_model.train() 212 | 213 | if stop: 214 | logger.info("Early stopping") 215 | break 216 | 217 | except KeyboardInterrupt: 218 | logger.info("Training interrupted") 219 | 220 | trainable_model.eval() 221 | # Load best model based on training loss 222 | 
trainable_model.load_state_dict(param_dict) 223 | 224 | # Move the embeddings to the device (GPU) 225 | embeddings_weight = trainable_model.embeddings.weight.to(device) 226 | 227 | # Perform the forward pass on GPU 228 | with torch.no_grad(): 229 | vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy() 230 | 231 | new_model = StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config) 232 | 233 | return new_model 234 | -------------------------------------------------------------------------------- /tokenlearn/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import numpy as np 7 | import torch 8 | from model2vec import StaticModel 9 | from model2vec.distill import distill 10 | from sklearn.decomposition import PCA 11 | 12 | from tokenlearn.pretrain import TextDataset, train_supervised 13 | from tokenlearn.utils import collect_means_and_texts, create_vocab 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _MAX_N_VAL_SAMPLES = 10_000 20 | 21 | 22 | def train_model( 23 | model: StaticModel, 24 | train_txt: list[str], 25 | train_vec: np.ndarray, 26 | device: str = "cpu", 27 | pca_dims: int = 256, 28 | ) -> StaticModel: 29 | """ 30 | Train a tokenlearn model. 31 | 32 | :param model: The static model to distill further. 33 | :param train_txt: List of texts to train on. 34 | :param train_vec: List of vectors to train on. 35 | :param device: Device to run the training on. 36 | :param pca_dims: Number of dimensions to reduce the target embeddings to using PCA. 37 | :return: The trained model. 
38 | """ 39 | pca_for_targets = PCA(n_components=pca_dims) 40 | train_vec = pca_for_targets.fit_transform(train_vec) 41 | var = np.cumsum(pca_for_targets.explained_variance_ratio_)[-1] 42 | logger.info(f"Explained variance of target embeddings: {var:.2f}") 43 | 44 | # Split the data into training and validation sets 45 | # We use a max of 10k samples as validation data 46 | val_samples = min(_MAX_N_VAL_SAMPLES, len(train_txt) // 10) 47 | train_txt, train_vec, val_txt, val_vec = ( 48 | train_txt[:-val_samples], 49 | train_vec[:-val_samples], 50 | train_txt[-val_samples:], 51 | train_vec[-val_samples:], 52 | ) 53 | 54 | train_data = TextDataset(train_txt, torch.from_numpy(train_vec), model.tokenizer) 55 | val_data = TextDataset(val_txt, torch.from_numpy(val_vec), model.tokenizer) 56 | 57 | # Train the model 58 | model = train_supervised(train_dataset=train_data, validation_dataset=val_data, model=model, device=device) 59 | return model 60 | 61 | 62 | def main() -> None: 63 | """Main function to train and save a Model2Vec model using tokenlearn.""" 64 | parser = argparse.ArgumentParser(description="Train a Model2Vec using tokenlearn.") 65 | parser.add_argument( 66 | "--model-name", 67 | type=str, 68 | default="baai/bge-base-en-v1.5", 69 | help="The model name for distillation (e.g., 'baai/bge-base-en-v1.5').", 70 | ) 71 | parser.add_argument( 72 | "--data-path", 73 | type=str, 74 | default="data/fineweb_bgebase", 75 | help="Path to the directory containing the dataset.", 76 | ) 77 | parser.add_argument( 78 | "--save-path", 79 | type=str, 80 | required=True, 81 | help="Path to save the trained model.", 82 | ) 83 | parser.add_argument( 84 | "--device", 85 | type=str, 86 | default="cpu", 87 | help="Device to run the training on (e.g., 'cpu', 'cuda').", 88 | ) 89 | parser.add_argument( 90 | "--vocab-size", 91 | type=int, 92 | default=56000, 93 | help="The vocabulary size to use for training.", 94 | ) 95 | parser.add_argument( 96 | "--trust-remote-code", 97 | action="store_true", 98 | help="Trust remote code when loading the model.", 99 | ) 100 | parser.add_argument( 101 | "--pca-dims", 102 | type=int, 103 | default=256, 104 | help="Number of dimensions to reduce the target embeddings to using PCA.", 105 | ) 106 | args = parser.parse_args() 107 | 108 | # Collect paths for training data 109 | paths = sorted(Path(args.data_path).glob("*.json")) 110 | train_txt, train_vec = collect_means_and_texts(paths) 111 | 112 | pca_dims = args.pca_dims 113 | 114 | vocab_size = args.vocab_size 115 | if vocab_size: 116 | # Create a vocabulary if a vocab size is specified 117 | vocab = create_vocab(texts=train_txt, vocab_size=vocab_size) 118 | logger.info(f"Vocabulary created with {len(vocab)} tokens.") 119 | else: 120 | vocab = None 121 | model = distill( 122 | model_name=args.model_name, quantize_to="float32", vocabulary=vocab, pca_dims=pca_dims, trust_remote_code=True 123 | ) 124 | 125 | # Train the model 126 | model = train_model(model, train_txt, train_vec, device=args.device, pca_dims=pca_dims) 127 | model.save_pretrained(args.save_path) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /tokenlearn/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from collections import Counter 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import regex 8 | from tqdm import tqdm 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | 
def create_vocab(texts: list[str], vocab_size: int = 56_000) -> list[str]: 14 | """ 15 | Create a vocabulary from a list of texts. 16 | 17 | :param texts: The list of texts to create the vocabulary from. 18 | :param vocab_size: The size of the vocabulary. Defaults to 56,000, which is the vocab_size used for our 32M models. 19 | :return: The vocabulary. 20 | """ 21 | tokenizer_regex = regex.compile(r"\w+|[^\w\s]+") 22 | 23 | # Tokenize all texts 24 | tokens = [] 25 | for text in tqdm(texts, desc="Tokenizing texts"): 26 | tokens.extend(tokenizer_regex.findall(text.lower())) 27 | 28 | # Count the tokens 29 | token_counts = Counter(tokens) 30 | 31 | # Get the most common tokens as the vocabulary 32 | vocab = [word for word, _ in token_counts.most_common(vocab_size)] 33 | return vocab 34 | 35 | 36 | def collect_means_and_texts(paths: list[Path]) -> tuple[list[str], np.ndarray]: 37 | """Collect means and texts from a list of paths.""" 38 | txts = [] 39 | vectors_list = [] 40 | for items_path in tqdm(paths, desc="Collecting means and texts"): 41 | if not items_path.name.endswith(".json"): 42 | continue 43 | base_path = items_path.with_name(items_path.stem.replace("", "")) 44 | vectors_path = items_path.with_name(base_path.name.replace(".json", "") + ".npy") 45 | try: 46 | with open(items_path, "r") as f: 47 | items = json.load(f) 48 | vectors = np.load(vectors_path, allow_pickle=False) 49 | vectors = vectors.astype(np.float32) 50 | except (KeyError, FileNotFoundError, ValueError) as e: 51 | logger.info(f"Error loading data from {base_path}: {e}") 52 | continue 53 | 54 | # Filter out any NaN vectors before appending 55 | vectors = np.stack(vectors) 56 | items = np.array(items) 57 | non_nan_indices = ~np.isnan(vectors).any(axis=1) 58 | valid_vectors = vectors[non_nan_indices] 59 | valid_items = items[non_nan_indices] 60 | txts.extend(valid_items.tolist()) 61 | vectors_list.append(valid_vectors) 62 | 63 | if vectors_list: 64 | all_vectors = np.concatenate(vectors_list, axis=0) 65 | else: 66 | all_vectors = np.array([]) 67 | return txts, all_vectors 68 | -------------------------------------------------------------------------------- /tokenlearn/version.py: -------------------------------------------------------------------------------- 1 | __version_triple__ = (0, 2, 1) 2 | __version__ = ".".join(map(str, __version_triple__)) 3 | --------------------------------------------------------------------------------