├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── assets └── images │ ├── logo.png │ ├── logo_v2.png │ ├── model2vec_logo.png │ ├── model2vec_model_diagram.png │ ├── model2vec_model_diagram_transparant_dark.png │ ├── model2vec_model_diagram_transparant_light.png │ ├── sentences_per_second_vs_average_score.png │ ├── speed_vs_accuracy.png │ ├── speed_vs_accuracy_v2.png │ ├── speed_vs_accuracy_v3.png │ ├── speed_vs_accuracy_v4.png │ ├── speed_vs_mteb_score.png │ ├── speed_vs_mteb_score_v2.png │ ├── speed_vs_mteb_score_v3.png │ ├── training_speed_vs_score.png │ └── tutorial_ezlo.png ├── docs ├── README.md ├── integrations.md ├── usage.md └── what_is_model2vec.md ├── model2vec ├── __init__.py ├── distill │ ├── __init__.py │ ├── distillation.py │ ├── inference.py │ └── utils.py ├── hf_utils.py ├── inference │ ├── README.md │ ├── __init__.py │ └── model.py ├── model.py ├── modelcards │ ├── classifier_template.md │ └── model_card_template.md ├── py.typed ├── quantization.py ├── tokenizer │ ├── __init__.py │ ├── datamodels.py │ ├── model.py │ ├── normalizer.py │ ├── pretokenizer.py │ └── tokenizer.py ├── train │ ├── README.md │ ├── __init__.py │ ├── base.py │ └── classifier.py ├── utils.py └── version.py ├── pyproject.toml ├── results ├── README.md └── make_speed_vs_mteb_plot.py ├── scripts └── export_to_onnx.py ├── tests ├── __init__.py ├── conftest.py ├── data │ └── test_tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer.json │ │ └── tokenizer_config.json ├── test_distillation.py ├── test_inference.py ├── test_model.py ├── test_quantization.py ├── test_tokenizer.py ├── test_trainable.py └── test_utils.py ├── tutorials ├── README.md ├── recipe_search.ipynb ├── semantic_chunking.ipynb └── train_classifier.ipynb └── uv.lock /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests and upload coverage 2 | 3 | on: 4 | push 5 | 6 | jobs: 7 | test: 8 | name: Run tests with pytest 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: ["ubuntu-latest", "windows-latest"] 13 | python-version: ["3.9", "3.10", "3.11", "3.12"] 14 | exclude: 15 | - os: windows-latest 16 | python-version: "3.9" 17 | - os: windows-latest 18 | python-version: "3.11" 19 | - os: windows-latest 20 | python-version: "3.12" 21 | fail-fast: false 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | allow-prereleases: true 31 | 32 | # Step for Windows: Create and activate a virtual environment 33 | - name: Create and activate a virtual environment (Windows) 34 | if: ${{ runner.os == 'Windows' }} 35 | run: | 36 | irm https://astral.sh/uv/install.ps1 | iex 37 | $env:Path = "C:\Users\runneradmin\.local\bin;$env:Path" 38 | uv venv .venv 39 | "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append 40 | "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append 41 | 42 | # Step for Unix: Create and activate a virtual environment 43 | - name: Create and activate a virtual environment (Unix) 44 | if: ${{ runner.os != 'Windows' }} 45 | run: | 46 | curl -LsSf https://astral.sh/uv/install.sh | sh 47 | uv venv .venv 48 | echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV 49 | echo "$PWD/.venv/bin" >> $GITHUB_PATH 50 | 51 | # Install dependencies using uv pip 52 | - name: Install dependencies 53 | run: 
make install-no-pre-commit 54 | 55 | # Run tests with coverage 56 | - name: Run tests under coverage 57 | run: | 58 | coverage run -m pytest 59 | coverage report 60 | 61 | # Upload results to Codecov 62 | - name: Upload results to Codecov 63 | uses: codecov/codecov-action@v4 64 | with: 65 | token: ${{ secrets.CODECOV_TOKEN }} 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | *.npy 164 | *.torch 165 | .venv 166 | .DS_Store 167 | models 168 | checkpoints/* 169 | features/* 170 | model2vec_models 171 | counts/* 172 | results_old/* 173 | local/* 174 | lightning_logs/* 175 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-ast 8 | description: Simply check whether files parse as valid python. 9 | - id: trailing-whitespace 10 | description: Trims trailing whitespace 11 | - id: end-of-file-fixer 12 | description: Makes sure files end in a newline and only a newline. 13 | - id: check-added-large-files 14 | args: ['--maxkb=5000'] 15 | description: Prevent giant files from being committed. 16 | - id: check-case-conflict 17 | description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows. 18 | - id: check-yaml 19 | description: Check yaml files for syntax errors. 
20 | - repo: https://github.com/jsh9/pydoclint 21 | rev: 0.5.3 22 | hooks: 23 | - id: pydoclint 24 | - repo: https://github.com/astral-sh/ruff-pre-commit 25 | rev: v0.4.10 26 | hooks: 27 | - id: ruff 28 | args: [ --fix ] 29 | - id: ruff-format 30 | - repo: local 31 | hooks: 32 | - id: mypy 33 | name: mypy 34 | entry: mypy 35 | language: python 36 | types: [python] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Thomas van Dongen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | clean: 2 | 3 | 4 | venv: 5 | uv venv 6 | 7 | install: 8 | uv sync --all-extras 9 | uv run pre-commit install 10 | 11 | install-no-pre-commit: 12 | uv pip install ".[dev,distill,inference,train]" 13 | uv pip install "torch<2.5.0" 14 | 15 | install-base: 16 | uv sync --extra dev 17 | 18 | fix: 19 | uv run pre-commit run --all-files 20 | 21 | test: 22 | uv run pytest --cov=model2vec --cov-report=term-missing 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

3 | Model2Vec logo
4 | Fast State-of-the-Art Static Embeddings
11 | 🤗 Models | 📚 Tutorials | 🌐 Blog | 🏆 Results | 📖 Docs
21 | Package version | Supported Python versions | Downloads | Codecov | Join Discord | License - MIT
42 | Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by a factor up to 50 and making the models up to 500 times faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.
47 | [Quickstart](#quickstart) • [Updates & Announcements](#updates--announcements) • [Main Features](#main-features) • [Model List](#model-list)
50 | 51 | ## Quickstart 52 | 53 | Install the lightweight base package with: 54 | 55 | ```bash 56 | pip install model2vec 57 | ``` 58 | 59 | You can start using Model2Vec by loading one of our [flagship models from the HuggingFace hub](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062). These models are pre-trained and ready to use. The following code snippet shows how to load a model and make embeddings, which you can use for any task, such as text classification, retrieval, clustering, or building a RAG system: 60 | ```python 61 | from model2vec import StaticModel 62 | 63 | # Load a model from the HuggingFace hub (in this case the potion-base-8M model) 64 | model = StaticModel.from_pretrained("minishlab/potion-base-8M") 65 | 66 | # Make embeddings 67 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 68 | 69 | # Make sequences of token embeddings 70 | token_embeddings = model.encode_as_sequence(["It's dangerous to go alone!", "It's a secret to everybody."]) 71 | ``` 72 | 73 | Instead of using one of our models, you can also distill your own Model2Vec model from a Sentence Transformer model. First, install the `distillation` extras with: 74 | 75 | ```bash 76 | pip install model2vec[distill] 77 | ``` 78 | 79 | 80 | Then, you can distill a model in ~30 seconds on a CPU with the following code snippet: 81 | 82 | ```python 83 | from model2vec.distill import distill 84 | 85 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model 86 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256) 87 | 88 | # Save the model 89 | m2v_model.save_pretrained("m2v_model") 90 | ``` 91 | 92 | After distillation, you can also fine-tune your own classification models on top of the distilled model, or on a pre-trained model. First, make sure you install the `training` extras with: 93 | 94 | ```bash 95 | pip install model2vec[training] 96 | ``` 97 | 98 | Then, you can fine-tune a model as follows: 99 | 100 | ```python 101 | import numpy as np 102 | from datasets import load_dataset 103 | from model2vec.train import StaticModelForClassification 104 | 105 | # Initialize a classifier from a pre-trained model 106 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M") 107 | 108 | # Load a dataset. Note: both single and multi-label classification datasets are supported 109 | ds = load_dataset("setfit/subj") 110 | 111 | # Train the classifier on text (X) and labels (y) 112 | classifier.fit(ds["train"]["text"], ds["train"]["label"]) 113 | 114 | # Evaluate the classifier 115 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"]) 116 | ``` 117 | 118 | For advanced usage, please refer to our [usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md). 119 | 120 | ## Updates & Announcements 121 | 122 | - **23/05/2025**: We released [potion-multilingual-128M](https://huggingface.co/minishlab/potion-multilingual-128M), a multilingual model trained on 101 languages. It is the best performing static embedding model for multilingual tasks, and is capable of generating embeddings for any text in any language. The results can be found in our [results](results/README.md#mmteb-results-multilingual) section. 123 | 124 | - **01/05/2025**: We released backend support for `BPE` and `Unigram` tokenizers, along with quantization and dimensionality reduction. 
New Model2Vec models are now 50% of the size of the original models, and can be quantized to int8 to be 25% of the size, without loss of performance. 125 | 126 | - **12/02/2025**: We released **Model2Vec training**, allowing you to fine-tune your own classification models on top of Model2Vec models. Find out more in our [training documentation](https://github.com/MinishLab/model2vec/blob/main/model2vec/train/README.md) and [results](results/README.md#training-results). 127 | 128 | - **30/01/2025**: We released two new models: [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) and [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M). [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is our most performant model to date, using a larger vocabulary and higher dimensions. [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) is a finetune of [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) that is optimized for retrieval tasks, and is the best performing static retrieval model currently available. 129 | 130 | - **30/10/2024**: We released three new models: [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M), [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M), and [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M). These models are trained using [Tokenlearn](https://github.com/MinishLab/tokenlearn). Find out more in our [blog post](https://minishlab.github.io/tokenlearn_blogpost/). NOTE: for users of any of our old English M2V models, we recommend switching to these new models as they [perform better on all tasks](https://github.com/MinishLab/model2vec/tree/main/results). 131 | 132 | ## Main Features 133 | 134 | - **State-of-the-Art Performance**: Model2Vec models outperform any other static embeddings (such as GloVe and BPEmb) by a large margin, as can be seen in our [results](results/README.md). 135 | - **Small**: Model2Vec reduces the size of a Sentence Transformer model by a factor of up to 50. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is just ~30 MB on disk, and our smallest model just ~8 MB (making it the smallest model on [MTEB](https://huggingface.co/spaces/mteb/leaderboard)!). 136 | - **Lightweight Dependencies**: the base package's only major dependency is `numpy`. 137 | - **Lightning-fast Inference**: up to 500 times faster on CPU than the original model. 138 | - **Fast, Dataset-free Distillation**: distill your own model in 30 seconds on a CPU, without a dataset. 139 | - **Fine-tuning**: fine-tune your own classification models on top of Model2Vec models. 140 | - **Integrated in many popular libraries**: Model2Vec is integrated directly into popular libraries such as [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) and [LangChain](https://github.com/langchain-ai/langchain). For more information, see our [integrations documentation](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md). 141 | - **Tightly integrated with HuggingFace hub**: easily share and load models from the HuggingFace hub, using the familiar `from_pretrained` and `push_to_hub`. Our own models can be found [here](https://huggingface.co/minishlab). 142 | 143 | ## What is Model2Vec? 144 | 145 | Model2Vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Like BPEmb, it can create subword embeddings, but with much better performance. Distillation doesn't need _any_ data, just a vocabulary and a model. 146 | 147 | The core idea is to forward pass a vocabulary through a sentence transformer model, creating static embeddings for the individual tokens. After this, there are a number of post-processing steps we do that result in our best models.
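To make this concrete, the snippet below is a minimal, illustrative sketch of the distillation idea, not the actual implementation (that lives in `model2vec/distill/`). It assumes `sentence-transformers` and `scikit-learn` are installed, encodes the token strings directly (the real code feeds token ids through the model), and uses a crude Zipf prior over the vocabulary as a stand-in for real corpus probabilities in the weighting step.

```python
import numpy as np
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("BAAI/bge-base-en-v1.5")

# Order the vocabulary by token id, so row i of the matrix corresponds to token id i
vocab = st_model.tokenizer.get_vocab()
tokens = [token for token, _ in sorted(vocab.items(), key=lambda item: item[1])]

# 1. Forward pass every token through the sentence transformer
token_embeddings = st_model.encode(tokens)

# 2. Post-process: PCA, then SIF-style down-weighting of frequent tokens
token_embeddings = PCA(n_components=256).fit_transform(token_embeddings)
ranks = np.arange(1, len(tokens) + 1)
proba = (1 / ranks) / np.sum(1 / ranks)  # rough Zipf prior, not real corpus probabilities
token_embeddings *= (1e-3 / (1e-3 + proba))[:, None]

# 3. Inference with the resulting static model is just a lookup and a mean
def embed(text: str) -> np.ndarray:
    ids = st_model.tokenizer.encode(text, add_special_tokens=False)
    return token_embeddings[ids].mean(axis=0)

print(embed("It's dangerous to go alone!").shape)  # (256,)
```

The real `distill` function additionally cleans the vocabulary, supports custom vocabularies, can quantize the embeddings, and normalizes the output; see the usage documentation for the supported options.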
For a more extensive deepdive, please refer to the following resources: 148 | - Our initial [Model2Vec blog post](https://huggingface.co/blog/Pringled/model2vec). Note that, while this post gives a good overview of the core idea, we've made a number of substantial improvements since then. 149 | - Our [Tokenlearn blog post](https://minishlab.github.io/tokenlearn_blogpost/). This post describes the Tokenlearn method we used to train our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062). 150 | - Our official [documentation](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md). This document provides a high-level overview of how Model2Vec works. 151 | 152 | ## Documentation 153 | 154 | Our official documentation can be found [here](https://github.com/MinishLab/model2vec/blob/main/docs/README.md). This includes: 155 | - [Usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md): provides a technical overview of how to use Model2Vec. 156 | - [Integrations documentation](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md): provides examples of how to use Model2Vec in various downstream libraries. 157 | - [Model2Vec technical documentation](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md): provides a high-level overview of how Model2Vec works. 158 | 159 | 160 | ## Model List 161 | 162 | We provide a number of models that can be used out of the box. These models are available on the [HuggingFace hub](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) and can be loaded using the `from_pretrained` method. The models are listed below.
163 | 164 | 165 | 166 | | Model | Language | Sentence Transformer | Params | Task | 167 | |-----------------------------------------------------------------------|------------|-----------------------------------------------------------------|---------|-----------| 168 | | [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 32.3M | General | 169 | | [potion-multilingual-128M](https://huggingface.co/minishlab/potion-multilingual-128M) | Multilingual | [bge-m3](https://huggingface.co/BAAI/bge-m3) | 128M | General | 170 | | [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 32.3M | Retrieval | 171 | | [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 7.5M | General | 172 | | [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 3.7M | General | 173 | | [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 1.8M | General | 174 | 175 | 176 | 177 | 178 | ## Results 179 | 180 | We have performed extensive experiments to evaluate the performance of Model2Vec models. The results are documented in the [results](results/README.md) folder. The results are presented in the following sections: 181 | - [MTEB Results](results/README.md#mteb-results) 182 | - [Training Results](results/README.md#training-results) 183 | - [Ablations](results/README.md#ablations) 184 | 185 | ## License 186 | 187 | MIT 188 | 189 | ## Citing 190 | 191 | If you use Model2Vec in your research, please cite the following: 192 | ```bibtex 193 | @article{minishlab2024model2vec, 194 | author = {Tulkens, Stephan and {van Dongen}, Thomas}, 195 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings}, 196 | year = {2024}, 197 | url = {https://github.com/MinishLab/model2vec} 198 | } 199 | ``` 200 | -------------------------------------------------------------------------------- /assets/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/logo.png -------------------------------------------------------------------------------- /assets/images/logo_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/logo_v2.png -------------------------------------------------------------------------------- /assets/images/model2vec_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_logo.png -------------------------------------------------------------------------------- /assets/images/model2vec_model_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram.png -------------------------------------------------------------------------------- 
/assets/images/model2vec_model_diagram_transparant_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram_transparant_dark.png -------------------------------------------------------------------------------- /assets/images/model2vec_model_diagram_transparant_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram_transparant_light.png -------------------------------------------------------------------------------- /assets/images/sentences_per_second_vs_average_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/sentences_per_second_vs_average_score.png -------------------------------------------------------------------------------- /assets/images/speed_vs_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy.png -------------------------------------------------------------------------------- /assets/images/speed_vs_accuracy_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v2.png -------------------------------------------------------------------------------- /assets/images/speed_vs_accuracy_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v3.png -------------------------------------------------------------------------------- /assets/images/speed_vs_accuracy_v4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v4.png -------------------------------------------------------------------------------- /assets/images/speed_vs_mteb_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score.png -------------------------------------------------------------------------------- /assets/images/speed_vs_mteb_score_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score_v2.png -------------------------------------------------------------------------------- /assets/images/speed_vs_mteb_score_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score_v3.png -------------------------------------------------------------------------------- /assets/images/training_speed_vs_score.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/training_speed_vs_score.png -------------------------------------------------------------------------------- /assets/images/tutorial_ezlo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/tutorial_ezlo.png -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | This directory contains the documentation for Model2Vec. The documentation is formatted in Markdown. The documentation is organized as follows: 4 | - [usage.md](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md): This document provides a technical overview of how to use Model2Vec. 5 | - [integrations.md](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md): This document provides examples of how to use Model2Vec in various downstream libraries. 6 | - [what_is_model2vec.md](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md): This document provides a high-level overview of how Model2Vec works. 7 | -------------------------------------------------------------------------------- /docs/integrations.md: -------------------------------------------------------------------------------- 1 | 2 | # Integrations 3 | 4 | Model2Vec can be used in a variety of downstream libraries. This document provides examples of how to use Model2Vec in some of these libraries. 5 | 6 | ## Table of Contents 7 | - [Sentence Transformers](#sentence-transformers) 8 | - [LangChain](#langchain) 9 | - [Txtai](#txtai) 10 | - [Chonkie](#chonkie) 11 | - [Transformers.js](#transformersjs) 12 | 13 | ## Sentence Transformers 14 | 15 | Model2Vec can be used directly in [Sentence Transformers](https://github.com/UKPLab/sentence-transformers): 16 | 17 | The following code snippet shows how to load a Model2Vec model into a Sentence Transformer model: 18 | ```python 19 | from sentence_transformers import SentenceTransformer 20 | 21 | # Load a Model2Vec model from the Hub 22 | model = SentenceTransformer("minishlab/potion-base-8M") 23 | # Make embeddings 24 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 25 | ``` 26 | 27 | The following code snippet shows how to distill a model directly into a Sentence Transformer model: 28 | 29 | ```python 30 | from sentence_transformers import SentenceTransformer 31 | from sentence_transformers.models import StaticEmbedding 32 | 33 | static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu", pca_dims=256) 34 | model = SentenceTransformer(modules=[static_embedding]) 35 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 36 | ``` 37 | 38 | For more documentation, please refer to the [Sentence Transformers documentation](https://sbert.net/docs/package_reference/sentence_transformer/models.html#sentence_transformers.models.StaticEmbedding). 39 | 40 | 41 | ## LangChain 42 | 43 | Model2Vec can be used in [LangChain](https://github.com/langchain-ai/langchain) using the `langchain-community` package. 
For more information, see the [LangChain Model2Vec docs](https://python.langchain.com/docs/integrations/text_embedding/model2vec/). The following code snippet shows how to use Model2Vec in LangChain after installing the `langchain-community` package with `pip install langchain-community`: 44 | 45 | ```python 46 | from langchain_community.embeddings import Model2vecEmbeddings 47 | from langchain_community.vectorstores import FAISS 48 | from langchain.schema import Document 49 | 50 | # Initialize a Model2Vec embedder 51 | embedder = Model2vecEmbeddings("minishlab/potion-base-8M") 52 | 53 | # Create some example texts 54 | texts = [ 55 | "Enduring Stew", 56 | "Hearty Elixir", 57 | "Mighty Mushroom Risotto", 58 | "Spicy Meat Skewer", 59 | "Fruit Salad", 60 | ] 61 | 62 | # Embed the texts 63 | embeddings = embedder.embed_documents(texts) 64 | 65 | # Or, create a vector store and query it 66 | documents = [Document(page_content=text) for text in texts] 67 | vector_store = FAISS.from_documents(documents, embedder) 68 | query = "Risotto" 69 | query_vector = embedder.embed_query(query) 70 | retrieved_docs = vector_store.similarity_search_by_vector(query_vector, k=1) 71 | ``` 72 | 73 | ## Txtai 74 | 75 | Model2Vec can be used in [txtai](https://github.com/neuml/txtai) for text embeddings, nearest-neighbors search, and any of the other functionalities that txtai offers. The following code snippet shows how to use Model2Vec in txtai after installing the `txtai` package (including the `vectors` dependency) with `pip install txtai[vectors]`: 76 | 77 | ```python 78 | from txtai import Embeddings 79 | 80 | # Load a model2vec model 81 | embeddings = Embeddings(path="minishlab/potion-base-8M", method="model2vec", backend="numpy") 82 | 83 | # Create some example texts 84 | texts = ["Enduring Stew", "Hearty Elixir", "Mighty Mushroom Risotto", "Spicy Meat Skewer", "Chilly Fruit Salad"] 85 | 86 | # Create embeddings for downstream tasks 87 | vectors = embeddings.batchtransform(texts) 88 | 89 | # Or create a nearest-neighbors index and search it 90 | embeddings.index(texts) 91 | result = embeddings.search("Risotto", 1) 92 | ``` 93 | 94 | ## Chonkie 95 | 96 | Model2Vec is the default model for semantic chunking in [Chonkie](https://github.com/bhavnicksm/chonkie). To use Model2Vec for semantic chunking in Chonkie, simply install Chonkie with `pip install chonkie[semantic]` and use one of the `potion` models in the `SemanticChunker` class. The following code snippet shows how to use Model2Vec in Chonkie: 97 | 98 | ```python 99 | from chonkie import SDPMChunker 100 | 101 | # Create some example text to chunk 102 | text = "It's dangerous to go alone! Take this." 
103 | 104 | # Initialize the SemanticChunker with a potion model 105 | chunker = SDPMChunker( 106 | embedding_model="minishlab/potion-base-8M", 107 | similarity_threshold=0.3 108 | ) 109 | 110 | # Chunk the text 111 | chunks = chunker.chunk(text) 112 | ``` 113 | 114 | ## Transformers.js 115 | 116 | To use a Model2Vec model in [transformers.js](https://github.com/huggingface/transformers.js), the following code snippet can be used as a starting point: 117 | 118 | ```javascript 119 | import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers'; 120 | 121 | const modelName = 'minishlab/potion-base-8M'; 122 | 123 | const modelConfig = { 124 | config: { model_type: 'model2vec' }, 125 | dtype: 'fp32', 126 | revision: 'refs/pr/1' 127 | }; 128 | const tokenizerConfig = { 129 | revision: 'refs/pr/2' 130 | }; 131 | 132 | const model = await AutoModel.from_pretrained(modelName, modelConfig); 133 | const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig); 134 | 135 | const texts = ['hello', 'hello world']; 136 | const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false }); 137 | 138 | const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []); 139 | const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))]; 140 | 141 | const flattened_input_ids = input_ids.flat(); 142 | const modelInputs = { 143 | input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]), 144 | offsets: new Tensor('int64', offsets, [offsets.length]) 145 | }; 146 | 147 | const { embeddings } = await model(modelInputs); 148 | console.log(embeddings.tolist()); // output matches python version 149 | ``` 150 | 151 | Note that this requires that the Model2Vec has a `model.onnx` file and several required tokenizers file. To generate these for a model that does not have them yet, the following code snippet can be used: 152 | 153 | ```bash 154 | python scripts/export_to_onnx.py --model_path --save_path "" 155 | ``` 156 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | 2 | # Usage 3 | 4 | This document provides an overview of how to use Model2Vec for inference, distillation, training, and evaluation. 5 | 6 | ## Table of Contents 7 | - [Inference](#inference) 8 | - [Inference with a pretrained model](#inference-with-a-pretrained-model) 9 | - [Inference with the Sentence Transformers library](#inference-with-the-sentence-transformers-library) 10 | - [Distillation](#distillation) 11 | - [Distilling from a Sentence Transformer](#distilling-from-a-sentence-transformer) 12 | - [Distilling from a loaded model](#distilling-from-a-loaded-model) 13 | - [Distilling with the Sentence Transformers library](#distilling-with-the-sentence-transformers-library) 14 | - [Distilling with a custom vocabulary](#distilling-with-a-custom-vocabulary) 15 | - [Training](#training) 16 | - [Training a classifier](#training-a-classifier) 17 | - [Evaluation](#evaluation) 18 | - [Installation](#installation) 19 | - [Evaluation Code](#evaluation-code) 20 | 21 | ## Inference 22 | 23 | ### Inference with a pretrained model 24 | 25 | Inference works as follows. The example shows one of our own models, but you can also just load a local one, or another one from the hub. 26 | ```python 27 | from model2vec import StaticModel 28 | 29 | # Load a model from the Hub. 
You can optionally pass a token when loading a private model 30 | model = StaticModel.from_pretrained(model_name="minishlab/potion-base-8M", token=None) 31 | 32 | # Make embeddings 33 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 34 | 35 | # Make sequences of token embeddings 36 | token_embeddings = model.encode_as_sequence(["It's dangerous to go alone!", "It's a secret to everybody."]) 37 | ``` 38 | 39 | ### Inference with the Sentence Transformers library 40 | 41 | The following code snippet shows how to use a Model2Vec model in the [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) library. This is useful if you want to use the model in a Sentence Transformers pipeline. 42 | 43 | ```python 44 | from sentence_transformers import SentenceTransformer 45 | 46 | # Load a Model2Vec model from the Hub 47 | model = SentenceTransformer("minishlab/potion-base-8M") 48 | 49 | # Make embeddings 50 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 51 | ``` 52 | 53 | ## Distillation 54 | 55 | ### Distilling from a Sentence Transformer 56 | 57 | The following code can be used to distill a model from a Sentence Transformer. As mentioned above, this leads to really small model that might be less performant. 58 | ```python 59 | from model2vec.distill import distill 60 | 61 | # Distill a Sentence Transformer model 62 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256) 63 | 64 | # Save the model 65 | m2v_model.save_pretrained("m2v_model") 66 | 67 | ``` 68 | 69 | ### Distilling from a loaded model 70 | 71 | If you already have a model loaded, or need to load a model in some special way, we also offer an interface to distill models in memory. 72 | 73 | ```python 74 | from transformers import AutoModel, AutoTokenizer 75 | 76 | from model2vec.distill import distill_from_model 77 | 78 | # Assuming a loaded model and tokenizer 79 | model_name = "baai/bge-base-en-v1.5" 80 | model = AutoModel.from_pretrained(model_name) 81 | tokenizer = AutoTokenizer.from_pretrained(model_name) 82 | 83 | m2v_model = distill_from_model(model=model, tokenizer=tokenizer, pca_dims=256) 84 | 85 | m2v_model.save_pretrained("m2v_model") 86 | 87 | ``` 88 | 89 | ### Distilling with the Sentence Transformers library 90 | 91 | The following code snippet shows how to distill a model using the [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) library. This is useful if you want to use the model in a Sentence Transformers pipeline. 92 | 93 | ```python 94 | from sentence_transformers import SentenceTransformer 95 | from sentence_transformers.models import StaticEmbedding 96 | 97 | static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu", pca_dims=256) 98 | model = SentenceTransformer(modules=[static_embedding]) 99 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."]) 100 | ``` 101 | 102 | ### Distilling with a custom vocabulary 103 | 104 | If you pass a vocabulary, you get a set of static word embeddings, together with a custom tokenizer for exactly that vocabulary. This is comparable to how you would use GLoVe or traditional word2vec, but doesn't actually require a corpus or data. 
105 | ```python 106 | from model2vec.distill import distill 107 | 108 | # Load a vocabulary as a list of strings 109 | vocabulary = ["word1", "word2", "word3"] 110 | 111 | # Distill a Sentence Transformer model with the custom vocabulary 112 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", vocabulary=vocabulary) 113 | 114 | # Save the model 115 | m2v_model.save_pretrained("m2v_model") 116 | 117 | # Or push it to the hub 118 | m2v_model.push_to_hub("my_organization/my_model", token="") 119 | ``` 120 | 121 | By default, this will distill a model with a subword tokenizer, combining the model's (subword) vocab with the new vocabulary. If you want to get a word-level tokenizer instead (with only the passed vocabulary), the `use_subword` parameter can be set to `False`, e.g.: 122 | 123 | ```python 124 | m2v_model = distill(model_name=model_name, vocabulary=vocabulary, use_subword=False) 125 | ``` 126 | 127 | **Important note:** we assume the passed vocabulary is sorted in rank frequency, i.e., we don't care about the actual word frequencies, but do assume that the most frequent word is first, and the least frequent word is last. If you're not sure whether this is the case, set `apply_zipf` to `False`. This disables the weighting, but will also make performance a little bit worse. 128 | 129 | ### Quantization 130 | 131 | Models can be quantized to `float16` (default) or `int8` during distillation, or when loading from disk. 132 | 133 | ```python 134 | from model2vec.distill import distill 135 | 136 | # Distill a Sentence Transformer model and quantize it to int8 137 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", quantize_to="int8") 138 | 139 | # Save the model. This model is now 25% of the size of a normal model. 140 | m2v_model.save_pretrained("m2v_model") 141 | ``` 142 | 143 | You can also quantize during loading. 144 | 145 | ```python 146 | from model2vec import StaticModel 147 | 148 | model = StaticModel.from_pretrained("minishlab/potion-base-8m", quantize_to="int8") 149 | ``` 150 | 151 | ### Dimensionality reduction 152 | 153 | Because almost all Model2Vec models have been distilled using PCA, and because PCA explicitly orders dimensions from most informative to least informative, we can perform dimensionality reduction during loading. This is very similar to how matryoshka embeddings work. 154 | 155 | ```python 156 | from model2vec import StaticModel 157 | 158 | model = StaticModel.from_pretrained("minishlab/potion-base-8m", dimensionality=32) 159 | 160 | print(model.embedding.shape) 161 | # (29528, 32) 162 | ``` 163 | 164 | ### Combining quantization and dimensionality reduction 165 | 166 | Combining these tricks can lead to extremely small models. For example, using this, we can reduce the size of `potion-base-8m`, which is normally 30MB, to only 1MB: 167 | 168 | ```python 169 | model = StaticModel.from_pretrained("minishlab/potion-base-8m", 170 | dimensionality=32, 171 | quantize_to="int8") 172 | print(model.embedding.nbytes) 173 | # 944896 bytes = 944kb 174 | ``` 175 | 176 | This should be enough to satisfy even the strongest hardware constraints. 177 | 178 | ## Training 179 | 180 | ### Training a classifier 181 | 182 | Model2Vec can be used to train a classifier on top of a distilled model, as shown in the following code snippet. For more advanced usage, as well as results, please refer to the [training documentation](https://github.com/MinishLab/model2vec/blob/main/model2vec/train/README.md).
183 | 184 | ```python 185 | import numpy as np 186 | from datasets import load_dataset 187 | from model2vec.train import StaticModelForClassification 188 | 189 | # Initialize a classifier from a pre-trained model 190 | classifier = StaticModelForClassification.from_pretrained("minishlab/potion-base-8M") 191 | 192 | # Load a dataset 193 | ds = load_dataset("setfit/subj") 194 | train = ds["train"] 195 | test = ds["test"] 196 | 197 | X_train, y_train = train["text"], train["label"] 198 | X_test, y_test = test["text"], test["label"] 199 | 200 | # Train the classifier 201 | classifier.fit(X_train, y_train) 202 | 203 | # Evaluate the classifier 204 | y_hat = classifier.predict(X_test) 205 | accuracy = np.mean(np.array(y_hat) == np.array(y_test)) * 100 206 | ``` 207 | 208 | ## Evaluation 209 | 210 | ### Installation 211 | 212 | Our models can be evaluated using our [evaluation package](https://github.com/MinishLab/evaluation). Install the evaluation package with: 213 | 214 | ```bash 215 | pip install git+https://github.com/MinishLab/evaluation.git@main 216 | ``` 217 | 218 | ### Evaluation Code 219 | 220 | The following code snippet shows how to evaluate a Model2Vec model: 221 | ```python 222 | from model2vec import StaticModel 223 | 224 | from evaluation import CustomMTEB, get_tasks, parse_mteb_results, make_leaderboard, summarize_results 225 | from mteb import ModelMeta 226 | 227 | # Get all available tasks 228 | tasks = get_tasks() 229 | # Define the CustomMTEB object with the specified tasks 230 | evaluation = CustomMTEB(tasks=tasks) 231 | 232 | # Load the model 233 | model_name = "m2v_model" 234 | model = StaticModel.from_pretrained(model_name) 235 | 236 | # Optionally, add model metadata in MTEB format 237 | model.mteb_model_meta = ModelMeta( 238 | name=model_name, revision="no_revision_available", release_date=None, languages=None 239 | ) 240 | 241 | # Run the evaluation 242 | results = evaluation.run(model, eval_splits=["test"], output_folder="results") 243 | 244 | # Parse the results and summarize them 245 | parsed_results = parse_mteb_results(mteb_results=results, model_name=model_name) 246 | task_scores = summarize_results(parsed_results) 247 | 248 | # Print the results in a leaderboard format 249 | print(make_leaderboard(task_scores)) 250 | ``` 251 | -------------------------------------------------------------------------------- /docs/what_is_model2vec.md: -------------------------------------------------------------------------------- 1 | # What is Model2Vec? 2 | 3 | This document provides a high-level overview of how Model2Vec works. 4 | 5 | The base model2vec technique works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using SIF weighting (previously zipf weighting). During inference, we simply take the mean of all token embeddings occurring in a sentence. 6 | 7 | Our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062) are pre-trained using [tokenlearn](https://github.com/MinishLab/tokenlearn), a technique to pre-train model2vec distillation models. These models are created with the following steps: 8 | - **Distillation**: We distill a Model2Vec model from a Sentence Transformer model, using the method described above. 9 | - **Sentence Transformer inference**: We use the Sentence Transformer model to create mean embeddings for a large number of texts from a corpus. 10 | - **Training**: We train a model to minimize the cosine distance between the mean embeddings generated by the Sentence Transformer model and the mean embeddings generated by the Model2Vec model. 11 | - **Post-training re-regularization**: We re-regularize the trained embeddings by first performing PCA, and then weighting the embeddings using `smooth inverse frequency (SIF)` weighting with the following formula: `w = 1e-3 / (1e-3 + proba)`. Here, `proba` is the probability of the token in the corpus we used for training.
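For intuition, here is a small numerical sketch of what this SIF weighting does. The probabilities below are made up for illustration; in the real pipeline `proba` comes from the training corpus. Frequent tokens are scaled down strongly, while rare tokens keep almost their full weight.

```python
import numpy as np

# Hypothetical token probabilities (frequent, medium, rare)
proba = np.array([0.05, 0.001, 0.00001])

# SIF weights: w = 1e-3 / (1e-3 + proba)
weights = 1e-3 / (1e-3 + proba)
print(weights)  # approximately [0.0196, 0.5, 0.9901]

# Each token embedding is then scaled by its weight
token_embeddings = np.random.rand(3, 256)
reweighted = token_embeddings * weights[:, None]
```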
12 | -------------------------------------------------------------------------------- /model2vec/__init__.py: -------------------------------------------------------------------------------- 1 | from model2vec.model import StaticModel 2 | from model2vec.version import __version__ 3 | 4 | __all__ = ["StaticModel", "__version__"] 5 | -------------------------------------------------------------------------------- /model2vec/distill/__init__.py: -------------------------------------------------------------------------------- 1 | from model2vec.utils import get_package_extras, importable 2 | 3 | _REQUIRED_EXTRA = "distill" 4 | 5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA): 6 | importable(extra_dependency, _REQUIRED_EXTRA) 7 | 8 | from model2vec.distill.distillation import distill, distill_from_model 9 | 10 | __all__ = ["distill", "distill_from_model"] 11 | -------------------------------------------------------------------------------- /model2vec/distill/distillation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import re 6 | from typing import Optional, cast 7 | 8 | import numpy as np 9 | from huggingface_hub import model_info 10 | from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast 11 | 12 | from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings 13 | from model2vec.distill.utils import select_optimal_device 14 | from model2vec.model import StaticModel 15 | from model2vec.quantization import DType, quantize_embeddings 16 | from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def distill_from_model( 22 | model: PreTrainedModel, 23 | tokenizer: PreTrainedTokenizerFast, 24 | vocabulary: list[str] | None = None, 25 | device: str | None = None, 26 | pca_dims: PCADimType = 256, 27 | apply_zipf: bool | None = None, 28 | sif_coefficient: float | None = 1e-4, 29 | token_remove_pattern: str | None = r"\[unused\d+\]", 30 | quantize_to: DType | str = DType.Float16, 31 | use_subword: bool | None = None, 32 | ) -> StaticModel: 33 | """ 34 | Distill a staticmodel from a sentence transformer. 35 | 36 | This function creates a set of embeddings from a sentence transformer. It does this by doing either 37 | a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary. 38 | 39 | If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary. 40 | If you don't pass a vocabulary, we use the model's tokenizer directly. 41 | 42 | :param model: The model to use. 43 | :param tokenizer: The tokenizer to use. 44 | :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary. 45 | :param device: The device to use.
46 | :param pca_dims: The number of components to use for PCA. 47 | If this is None, we don't apply PCA. 48 | If this is 'auto', we don't reduce dimensionality, but still apply PCA. 49 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied. 50 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied. 51 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied. 52 | Should be a value > 0 and < 1.0. A value of 1e-4 is a good default. 53 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary. 54 | If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error. 55 | :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents. 56 | :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything. 57 | :return: A StaticModel 58 | :raises: ValueError if the vocabulary is empty after preprocessing. 59 | 60 | """ 61 | if use_subword is not None: 62 | logger.warning( 63 | "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything." 64 | ) 65 | quantize_to = DType(quantize_to) 66 | backend_tokenizer = tokenizer.backend_tokenizer 67 | sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern) 68 | 69 | if vocabulary is None: 70 | vocabulary = [] 71 | 72 | device = select_optimal_device(device) 73 | 74 | n_tokens_before = len(vocabulary) 75 | # Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary. 76 | all_tokens, backend_tokenizer = clean_and_create_vocabulary( 77 | tokenizer, vocabulary, token_remove_regex=token_remove_regex 78 | ) 79 | n_tokens_after = len([token for token in all_tokens if not token.is_internal]) 80 | if n_tokens_before: 81 | logger.info( 82 | f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing." 83 | ) 84 | 85 | if not all_tokens: 86 | raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.") 87 | 88 | unk_token = cast(Optional[str], tokenizer.special_tokens_map.get("unk_token")) 89 | pad_token = cast(Optional[str], tokenizer.special_tokens_map.get("pad_token")) 90 | 91 | # Weird if to satsify mypy 92 | if pad_token is None: 93 | if unk_token is not None: 94 | pad_token = unk_token 95 | logger.warning( 96 | "The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token." 97 | ) 98 | else: 99 | pad_token = unk_token or all_tokens[0].form 100 | logger.warning( 101 | "The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token." 102 | ) 103 | 104 | # Replace the vocabulary in the tokenizer with the new vocabulary. 
105 | backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token) 106 | 107 | logger.info(f"Creating embeddings for {len(all_tokens)} tokens") 108 | # Convert tokens to IDs 109 | token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token) 110 | 111 | # Create the embeddings 112 | embeddings = create_embeddings( 113 | tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token] 114 | ) 115 | 116 | # Post process the embeddings by applying PCA and Zipf weighting. 117 | embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient) 118 | # Quantize the embeddings. 119 | embeddings = quantize_embeddings(embeddings, quantize_to) 120 | 121 | model_name = getattr(model, "name_or_path", "") 122 | 123 | config = { 124 | "model_type": "model2vec", 125 | "architectures": ["StaticModel"], 126 | "tokenizer_name": model_name, 127 | "apply_pca": pca_dims, 128 | "apply_zipf": apply_zipf, 129 | "sif_coefficient": sif_coefficient, 130 | "hidden_dim": embeddings.shape[1], 131 | "seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit. 132 | "normalize": True, 133 | } 134 | 135 | if os.path.exists(model_name): 136 | # Using a local model. Get the model name from the path. 137 | model_name = os.path.basename(model_name) 138 | language = None 139 | else: 140 | # Get the language from the model card. 141 | try: 142 | info = model_info(model_name) 143 | language = info.cardData.get("language", None) if info.cardData is not None else None 144 | except Exception as e: 145 | # NOTE: bare except because there's many reasons this can fail. 146 | logger.warning(f"Couldn't get the model info from the Hugging Face Hub: {e}. Setting language to None.") 147 | language = None 148 | 149 | return StaticModel( 150 | vectors=embeddings, 151 | tokenizer=backend_tokenizer, 152 | config=config, 153 | base_model_name=model_name, 154 | language=language, 155 | normalize=True, 156 | ) 157 | 158 | 159 | def _validate_parameters( 160 | apply_zipf: bool | None, 161 | sif_coefficient: float | None, 162 | token_remove_pattern: str | None, 163 | ) -> tuple[float | None, re.Pattern | None]: 164 | """ 165 | Validate the parameters passed to the distillation function. 166 | 167 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied. 168 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied. 169 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied. 170 | Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default. 171 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary. 172 | :return: The SIF coefficient to use. 173 | :raises: ValueError if the regex can't be compiled. 174 | 175 | """ 176 | if apply_zipf is not None: 177 | logger.warning( 178 | "The `apply_zipf` parameter is deprecated and will be removed in the next release. " 179 | "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, " 180 | "no weighting is applied." 181 | ) 182 | if apply_zipf and sif_coefficient is None: 183 | logger.warning("You set apply_zipf to True, but sif_coefficient is None. 
Setting sif_coefficient to 1e-4.") 184 | sif_coefficient = 1e-4 185 | elif not apply_zipf: 186 | logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.") 187 | sif_coefficient = None 188 | 189 | if sif_coefficient is not None: 190 | if not 0 < sif_coefficient < 1.0: 191 | raise ValueError("SIF coefficient must be a value > 0 and < 1.0.") 192 | 193 | token_remove_regex: re.Pattern | None = None 194 | if token_remove_pattern is not None: 195 | try: 196 | token_remove_regex = re.compile(token_remove_pattern) 197 | except re.error as e: 198 | raise ValueError(f"Couldn't compile the regex pattern: {e}") 199 | 200 | return sif_coefficient, token_remove_regex 201 | 202 | 203 | def distill( 204 | model_name: str, 205 | vocabulary: list[str] | None = None, 206 | device: str | None = None, 207 | pca_dims: PCADimType = 256, 208 | apply_zipf: bool | None = None, 209 | sif_coefficient: float | None = 1e-4, 210 | token_remove_pattern: str | None = r"\[unused\d+\]", 211 | trust_remote_code: bool = False, 212 | quantize_to: DType | str = DType.Float16, 213 | use_subword: bool | None = None, 214 | ) -> StaticModel: 215 | """ 216 | Distill a staticmodel from a sentence transformer. 217 | 218 | This function creates a set of embeddings from a sentence transformer. It does this by doing either 219 | a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary. 220 | 221 | If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary. 222 | If you don't pass a vocabulary, we use the model's tokenizer directly. 223 | 224 | :param model_name: The model name to use. Any sentencetransformer compatible model works. 225 | :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary. 226 | :param device: The device to use. 227 | :param pca_dims: The number of components to use for PCA. 228 | If this is None, we don't apply PCA. 229 | If this is 'auto', we don't reduce dimenionality, but still apply PCA. 230 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied. 231 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied. 232 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied. 233 | Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default. 234 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary. 235 | :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components. 236 | :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents. 237 | :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything. 
238 | :return: A StaticModel 239 | 240 | """ 241 | model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code) 242 | tokenizer = cast( 243 | PreTrainedTokenizerFast, 244 | AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True), 245 | ) 246 | 247 | return distill_from_model( 248 | model=model, 249 | tokenizer=tokenizer, 250 | vocabulary=vocabulary, 251 | device=device, 252 | pca_dims=pca_dims, 253 | apply_zipf=apply_zipf, 254 | token_remove_pattern=token_remove_pattern, 255 | sif_coefficient=sif_coefficient, 256 | quantize_to=quantize_to, 257 | use_subword=use_subword, 258 | ) 259 | -------------------------------------------------------------------------------- /model2vec/distill/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import inspect 5 | import logging 6 | from pathlib import Path 7 | from typing import Literal, Protocol, Union 8 | 9 | import numpy as np 10 | import torch 11 | from sklearn.decomposition import PCA 12 | from torch.nn.utils.rnn import pad_sequence 13 | from tqdm import tqdm 14 | from transformers import PreTrainedModel 15 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | PathLike = Union[Path, str] 21 | PCADimType = Union[int, None, float, Literal["auto"]] 22 | 23 | 24 | _DEFAULT_BATCH_SIZE = 256 25 | 26 | 27 | class ModulewithWeights(Protocol): 28 | weight: torch.nn.Parameter 29 | 30 | 31 | def create_embeddings( 32 | model: PreTrainedModel, 33 | tokenized: list[list[int]], 34 | device: str, 35 | pad_token_id: int, 36 | ) -> np.ndarray: 37 | """ 38 | Create output embeddings for a bunch of tokens using a pretrained model. 39 | 40 | It does a forward pass for all tokens passed in `tokens`. 41 | 42 | :param model: The model to use. 43 | This should be a transformers model. 44 | :param tokenized: All tokenized tokens. 45 | :param device: The torch device to use. 46 | :param pad_token_id: The pad token id. Used to pad sequences. 47 | :return: The output embeddings. 
48 | """ 49 | model = model.to(device) 50 | 51 | out_weights: np.ndarray 52 | intermediate_weights: list[np.ndarray] = [] 53 | 54 | # Add token_type_ids only if the model supports it 55 | add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args 56 | 57 | lengths = np.asarray([len(sequence) for sequence in tokenized]) 58 | sort_order = np.argsort(lengths) 59 | 60 | sorted_tokenized = [tokenized[i] for i in sort_order] 61 | 62 | pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens") 63 | 64 | for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE): 65 | batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]] 66 | 67 | encoded = {} 68 | encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id) 69 | encoded["attention_mask"] = encoded["input_ids"] != pad_token_id 70 | 71 | if add_token_type_ids: 72 | encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"]) 73 | 74 | out = _encode_mean_using_model(model, encoded) 75 | intermediate_weights.extend(out.numpy()) 76 | pbar.update(len(batch)) 77 | 78 | # Sort the output back to the original order 79 | intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)] 80 | out_weights = np.stack(intermediate_weights) 81 | 82 | out_weights = np.nan_to_num(out_weights) 83 | 84 | return out_weights 85 | 86 | 87 | @torch.no_grad() 88 | def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor: 89 | """ 90 | Encode a batch of tokens using a model. 91 | 92 | Note that if a token in the input batch does not have any embeddings, it will be output as a vector of zeros. 93 | So detection of these is necessary. 94 | 95 | :param model: The model to use. 96 | :param encodings: The encoded tokens to turn into features. 97 | :return: The mean of the output for each token. 98 | """ 99 | encodings = {k: v.to(model.device) for k, v in encodings.items()} 100 | encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings) 101 | out: torch.Tensor = encoded.last_hidden_state.cpu() 102 | # NOTE: If the dtype is bfloat 16, we convert to float32, 103 | # because numpy does not suport bfloat16 104 | # See here: https://github.com/numpy/numpy/issues/19808 105 | if out.dtype == torch.bfloat16: 106 | out = out.float() 107 | 108 | # Take the mean by averaging over the attention mask. 109 | mask = encodings["attention_mask"].cpu().float() 110 | mask /= mask.sum(1)[:, None] 111 | 112 | result = torch.bmm(mask[:, None, :].float(), out).squeeze(1) 113 | 114 | return result 115 | 116 | 117 | def post_process_embeddings( 118 | embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4 119 | ) -> np.ndarray: 120 | """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law.""" 121 | if pca_dims is not None: 122 | if pca_dims == "auto": 123 | pca_dims = embeddings.shape[1] 124 | if pca_dims > embeddings.shape[1]: 125 | logger.warning( 126 | f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). " 127 | "Applying PCA, but not reducing dimensionality. Is this is not desired, please set `pca_dims` to None. " 128 | "Applying PCA will probably improve performance, so consider just leaving it." 
129 | ) 130 | pca_dims = embeddings.shape[1] 131 | if pca_dims >= embeddings.shape[0]: 132 | logger.warning( 133 | f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA." 134 | ) 135 | elif pca_dims <= embeddings.shape[1]: 136 | if isinstance(pca_dims, float): 137 | logger.info(f"Applying PCA with {pca_dims} explained variance.") 138 | else: 139 | logger.info(f"Applying PCA with n_components {pca_dims}") 140 | 141 | orig_dims = embeddings.shape[1] 142 | p = PCA(n_components=pca_dims, svd_solver="full") 143 | embeddings = p.fit_transform(embeddings) 144 | 145 | if embeddings.shape[1] < orig_dims: 146 | explained_variance_ratio = np.sum(p.explained_variance_ratio_) 147 | explained_variance = np.sum(p.explained_variance_) 148 | logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.") 149 | logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.") 150 | logger.info(f"Explained variance: {explained_variance:.3f}.") 151 | 152 | if sif_coefficient is not None: 153 | logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.") 154 | inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2)) 155 | proba = inv_rank / np.sum(inv_rank) 156 | embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None] 157 | 158 | return embeddings 159 | -------------------------------------------------------------------------------- /model2vec/distill/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from logging import getLogger 4 | 5 | import torch 6 | 7 | logger = getLogger(__name__) 8 | 9 | 10 | def select_optimal_device(device: str | None) -> str: 11 | """ 12 | Guess what your optimal device should be based on backend availability. 13 | 14 | If you pass a device, we just pass it through. 15 | 16 | :param device: The device to use. If this is not None you get back what you passed. 17 | :return: The selected device. 18 | """ 19 | if device is None: 20 | if torch.cuda.is_available(): 21 | device = "cuda" 22 | elif torch.backends.mps.is_available(): 23 | device = "mps" 24 | else: 25 | device = "cpu" 26 | logger.info(f"Automatically selected device: {device}") 27 | 28 | return device 29 | -------------------------------------------------------------------------------- /model2vec/hf_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | from pathlib import Path 6 | from typing import Any, cast 7 | 8 | import huggingface_hub 9 | import numpy as np 10 | import safetensors 11 | from huggingface_hub import ModelCard, ModelCardData 12 | from safetensors.numpy import save_file 13 | from tokenizers import Tokenizer 14 | 15 | from model2vec.utils import SafeOpenProtocol 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def save_pretrained( 21 | folder_path: Path, 22 | embeddings: np.ndarray, 23 | tokenizer: Tokenizer, 24 | config: dict[str, Any], 25 | create_model_card: bool = True, 26 | subfolder: str | None = None, 27 | **kwargs: Any, 28 | ) -> None: 29 | """ 30 | Save a model to a folder. 31 | 32 | :param folder_path: The path to the folder. 33 | :param embeddings: The embeddings. 34 | :param tokenizer: The tokenizer. 35 | :param config: A metadata config. 36 | :param create_model_card: Whether to create a model card. 
37 | :param subfolder: The subfolder to save the model in. 38 | :param **kwargs: Any additional arguments. 39 | """ 40 | folder_path = folder_path / subfolder if subfolder else folder_path 41 | folder_path.mkdir(exist_ok=True, parents=True) 42 | save_file({"embeddings": embeddings}, folder_path / "model.safetensors") 43 | tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False) 44 | json.dump(config, open(folder_path / "config.json", "w"), indent=4) 45 | 46 | # Create modules.json 47 | modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}] 48 | if config.get("normalize"): 49 | # If normalize=True, add sentence_transformers.models.Normalize 50 | modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"}) 51 | json.dump(modules, open(folder_path / "modules.json", "w"), indent=4) 52 | 53 | logger.info(f"Saved model to {folder_path}") 54 | 55 | # Optionally create the model card 56 | if create_model_card: 57 | _create_model_card(folder_path, **kwargs) 58 | 59 | 60 | def _create_model_card( 61 | folder_path: Path, 62 | base_model_name: str = "unknown", 63 | license: str = "mit", 64 | language: list[str] | None = None, 65 | model_name: str | None = None, 66 | template_path: str = "modelcards/model_card_template.md", 67 | **kwargs: Any, 68 | ) -> None: 69 | """ 70 | Create a model card and store it in the specified path. 71 | 72 | :param folder_path: The path where the model card will be stored. 73 | :param base_model_name: The name of the base model. 74 | :param license: The license to use. 75 | :param language: The language of the model. 76 | :param model_name: The name of the model to use in the Model Card. 77 | :param template_path: The path to the template. 78 | :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.). 79 | """ 80 | folder_path = Path(folder_path) 81 | model_name = model_name or folder_path.name 82 | full_path = Path(__file__).parent / template_path 83 | 84 | model_card_data = ModelCardData( 85 | model_name=model_name, 86 | base_model=base_model_name, 87 | license=license, 88 | language=language, 89 | tags=["embeddings", "static-embeddings", "sentence-transformers"], 90 | library_name="model2vec", 91 | **kwargs, 92 | ) 93 | model_card = ModelCard.from_template(model_card_data, template_path=str(full_path)) 94 | model_card.save(folder_path / "README.md") 95 | 96 | 97 | def load_pretrained( 98 | folder_or_repo_path: str | Path, 99 | subfolder: str | None = None, 100 | token: str | None = None, 101 | from_sentence_transformers: bool = False, 102 | ) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]: 103 | """ 104 | Loads a pretrained model from a folder. 105 | 106 | :param folder_or_repo_path: The folder or repo path to load from. 107 | - If this is a local path, we will load from the local path. 108 | - If the local path is not found, we will attempt to load from the huggingface hub. 109 | :param subfolder: The subfolder to load from. 110 | :param token: The huggingface token to use. 111 | :param from_sentence_transformers: Whether to load the model from a sentence transformers model. 112 | :raises: FileNotFoundError if the folder exists, but the file does not exist locally. 113 | :return: The embeddings, tokenizer, config, and metadata. 
114 | 115 | """ 116 | if from_sentence_transformers: 117 | model_file = "0_StaticEmbedding/model.safetensors" 118 | tokenizer_file = "0_StaticEmbedding/tokenizer.json" 119 | config_name = "config_sentence_transformers.json" 120 | else: 121 | model_file = "model.safetensors" 122 | tokenizer_file = "tokenizer.json" 123 | config_name = "config.json" 124 | 125 | folder_or_repo_path = Path(folder_or_repo_path) 126 | 127 | local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path 128 | 129 | if local_folder.exists(): 130 | embeddings_path = local_folder / model_file 131 | if not embeddings_path.exists(): 132 | raise FileNotFoundError(f"Embeddings file does not exist in {local_folder}") 133 | 134 | config_path = local_folder / config_name 135 | if not config_path.exists(): 136 | raise FileNotFoundError(f"Config file does not exist in {local_folder}") 137 | 138 | tokenizer_path = local_folder / tokenizer_file 139 | if not tokenizer_path.exists(): 140 | raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}") 141 | 142 | # README is optional, so this is a bit finicky. 143 | readme_path = local_folder / "README.md" 144 | metadata = _get_metadata_from_readme(readme_path) 145 | 146 | else: 147 | logger.info("Folder does not exist locally, attempting to use huggingface hub.") 148 | embeddings_path = Path( 149 | huggingface_hub.hf_hub_download( 150 | folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder 151 | ) 152 | ) 153 | 154 | try: 155 | readme_path = Path( 156 | huggingface_hub.hf_hub_download( 157 | folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder 158 | ) 159 | ) 160 | metadata = _get_metadata_from_readme(Path(readme_path)) 161 | except Exception as e: 162 | # NOTE: we don't want to raise an error here, since the README is optional. 163 | logger.info(f"No README found in the model folder: {e} No model card loaded.") 164 | metadata = {} 165 | 166 | config_path = Path( 167 | huggingface_hub.hf_hub_download( 168 | folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder 169 | ) 170 | ) 171 | tokenizer_path = Path( 172 | huggingface_hub.hf_hub_download( 173 | folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder 174 | ) 175 | ) 176 | 177 | opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy")) 178 | if from_sentence_transformers: 179 | embeddings = opened_tensor_file.get_tensor("embedding.weight") 180 | else: 181 | embeddings = opened_tensor_file.get_tensor("embeddings") 182 | 183 | tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path)) 184 | config = json.load(open(config_path)) 185 | 186 | if len(tokenizer.get_vocab()) != len(embeddings): 187 | logger.warning( 188 | f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`" 189 | ) 190 | 191 | return embeddings, tokenizer, config, metadata 192 | 193 | 194 | def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]: 195 | """Get metadata from a README file.""" 196 | if not readme_path.exists(): 197 | logger.info(f"README file not found in {readme_path}. No model card loaded.") 198 | return {} 199 | model_card = ModelCard.load(readme_path) 200 | data: dict[str, Any] = model_card.data.to_dict() 201 | if not data: 202 | logger.info("File README.md exists, but was empty. 
No model card loaded.") 203 | return data 204 | 205 | 206 | def push_folder_to_hub( 207 | folder_path: Path, subfolder: str | None, repo_id: str, private: bool, token: str | None 208 | ) -> None: 209 | """ 210 | Push a model folder to the huggingface hub, including model card. 211 | 212 | :param folder_path: The path to the folder. 213 | :param subfolder: The subfolder to push to. 214 | If None, the folder will be pushed to the root of the repo. 215 | :param repo_id: The repo name. 216 | :param private: Whether the repo is private. 217 | :param token: The huggingface token. 218 | """ 219 | if not huggingface_hub.repo_exists(repo_id=repo_id, token=token): 220 | huggingface_hub.create_repo(repo_id, token=token, private=private) 221 | 222 | # Push model card and all model files to the Hugging Face hub 223 | huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder) 224 | 225 | logger.info(f"Pushed model to {repo_id}") 226 | -------------------------------------------------------------------------------- /model2vec/inference/README.md: -------------------------------------------------------------------------------- 1 | # Inference 2 | 3 | This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines. 4 | 5 | If you're looking for information on how to train a model, see [here](../train/README.md). 6 | 7 | # Usage 8 | 9 | Let's assume you're using our [potion-edu classifier](https://huggingface.co/minishlab/potion-8m-edu-classifier). 10 | 11 | ```python 12 | from model2vec.inference import StaticModelPipeline 13 | 14 | classifier = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier") 15 | label = classifier.predict("Attitudes towards cattle in the Alps: a study in letting go.") 16 | ``` 17 | 18 | This should just work. 
19 | -------------------------------------------------------------------------------- /model2vec/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from model2vec.utils import get_package_extras, importable 2 | 3 | _REQUIRED_EXTRA = "inference" 4 | 5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA): 6 | importable(extra_dependency, _REQUIRED_EXTRA) 7 | 8 | from model2vec.inference.model import StaticModelPipeline, evaluate_single_or_multi_label 9 | 10 | __all__ = ["StaticModelPipeline", "evaluate_single_or_multi_label"] 11 | -------------------------------------------------------------------------------- /model2vec/inference/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from pathlib import Path 5 | from tempfile import TemporaryDirectory 6 | from typing import Sequence, TypeVar 7 | 8 | import huggingface_hub 9 | import numpy as np 10 | import skops.io 11 | from sklearn.metrics import classification_report 12 | from sklearn.neural_network import MLPClassifier 13 | from sklearn.pipeline import Pipeline 14 | from sklearn.preprocessing import MultiLabelBinarizer 15 | 16 | from model2vec.hf_utils import _create_model_card 17 | from model2vec.model import PathLike, StaticModel 18 | 19 | _DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+") 20 | _DEFAULT_MODEL_FILENAME = "pipeline.skops" 21 | 22 | LabelType = TypeVar("LabelType", list[str], list[list[str]]) 23 | 24 | 25 | class StaticModelPipeline: 26 | def __init__(self, model: StaticModel, head: Pipeline) -> None: 27 | """Create a pipeline with a StaticModel encoder.""" 28 | self.model = model 29 | self.head = head 30 | classifier = self.head[-1] 31 | # Check if the classifier is a multilabel classifier. 32 | # NOTE: this doesn't look robust, but it is. 33 | # Different classifiers, such as OVR wrappers, support multilabel output natively, so we 34 | # can just use predict. 35 | self.multilabel = False 36 | if isinstance(classifier, MLPClassifier): 37 | if classifier.out_activation_ == "logistic": 38 | self.multilabel = True 39 | 40 | @property 41 | def classes_(self) -> np.ndarray: 42 | """The classes of the classifier.""" 43 | return self.head.classes_ 44 | 45 | @classmethod 46 | def from_pretrained( 47 | cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False 48 | ) -> StaticModelPipeline: 49 | """ 50 | Load a StaticModel from a local path or huggingface hub path. 51 | 52 | NOTE: if you load a private model from the huggingface hub, you need to pass a token. 53 | 54 | :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub 55 | :param token: The token to use to download the pipeline from the hub. 56 | :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `sklearn`. 57 | :return: The loaded pipeline. 
58 | """ 59 | model, head = _load_pipeline(path, token, trust_remote_code) 60 | model.embedding = np.nan_to_num(model.embedding) 61 | 62 | return cls(model, head) 63 | 64 | def save_pretrained(self, path: str) -> None: 65 | """Save the model to a folder.""" 66 | save_pipeline(self, path) 67 | 68 | def push_to_hub( 69 | self, repo_id: str, subfolder: str | None = None, token: str | None = None, private: bool = False 70 | ) -> None: 71 | """ 72 | Save a model to a folder, and then push that folder to the hf hub. 73 | 74 | :param repo_id: The id of the repository to push to. 75 | :param subfolder: The subfolder to push to. 76 | :param token: The token to use to push to the hub. 77 | :param private: Whether the repository should be private. 78 | """ 79 | from model2vec.hf_utils import push_folder_to_hub 80 | 81 | with TemporaryDirectory() as temp_dir: 82 | save_pipeline(self, temp_dir) 83 | self.model.save_pretrained(temp_dir) 84 | push_folder_to_hub(Path(temp_dir), subfolder, repo_id, private, token) 85 | 86 | def _encode_and_coerce_to_2d( 87 | self, 88 | X: Sequence[str], 89 | show_progress_bar: bool, 90 | max_length: int | None, 91 | batch_size: int, 92 | use_multiprocessing: bool, 93 | multiprocessing_threshold: int, 94 | ) -> np.ndarray: 95 | """Encode the instances and coerce the output to a matrix.""" 96 | encoded = self.model.encode( 97 | X, 98 | show_progress_bar=show_progress_bar, 99 | max_length=max_length, 100 | batch_size=batch_size, 101 | use_multiprocessing=use_multiprocessing, 102 | multiprocessing_threshold=multiprocessing_threshold, 103 | ) 104 | if np.ndim(encoded) == 1: 105 | encoded = encoded[None, :] 106 | 107 | return encoded 108 | 109 | def predict( 110 | self, 111 | X: Sequence[str], 112 | show_progress_bar: bool = False, 113 | max_length: int | None = 512, 114 | batch_size: int = 1024, 115 | use_multiprocessing: bool = True, 116 | multiprocessing_threshold: int = 10_000, 117 | threshold: float = 0.5, 118 | ) -> np.ndarray: 119 | """ 120 | Predict the labels of the input. 121 | 122 | :param X: The input data to predict. Can be a list of strings or a single string. 123 | :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. 124 | :param max_length: The maximum length of the input sequences. Defaults to 512. 125 | :param batch_size: The batch size for prediction. Defaults to 1024. 126 | :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. 127 | :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. 128 | :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel. 129 | :return: The predicted labels or probabilities. 
130 | """ 131 | encoded = self._encode_and_coerce_to_2d( 132 | X, 133 | show_progress_bar=show_progress_bar, 134 | max_length=max_length, 135 | batch_size=batch_size, 136 | use_multiprocessing=use_multiprocessing, 137 | multiprocessing_threshold=multiprocessing_threshold, 138 | ) 139 | 140 | if self.multilabel: 141 | out_labels = [] 142 | proba = self.head.predict_proba(encoded) 143 | for vector in proba: 144 | out_labels.append(self.classes_[vector > threshold]) 145 | return np.asarray(out_labels, dtype=object) 146 | 147 | return self.head.predict(encoded) 148 | 149 | def predict_proba( 150 | self, 151 | X: Sequence[str], 152 | show_progress_bar: bool = False, 153 | max_length: int | None = 512, 154 | batch_size: int = 1024, 155 | use_multiprocessing: bool = True, 156 | multiprocessing_threshold: int = 10_000, 157 | ) -> np.ndarray: 158 | """ 159 | Predict the labels of the input. 160 | 161 | :param X: The input data to predict. Can be a list of strings or a single string. 162 | :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False. 163 | :param max_length: The maximum length of the input sequences. Defaults to 512. 164 | :param batch_size: The batch size for prediction. Defaults to 1024. 165 | :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True. 166 | :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000. 167 | :return: The predicted labels or probabilities. 168 | """ 169 | encoded = self._encode_and_coerce_to_2d( 170 | X, 171 | show_progress_bar=show_progress_bar, 172 | max_length=max_length, 173 | batch_size=batch_size, 174 | use_multiprocessing=use_multiprocessing, 175 | multiprocessing_threshold=multiprocessing_threshold, 176 | ) 177 | 178 | return self.head.predict_proba(encoded) 179 | 180 | def evaluate( 181 | self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False 182 | ) -> str | dict[str, dict[str, float]]: 183 | """ 184 | Evaluate the classifier on a given dataset using scikit-learn's classification report. 185 | 186 | :param X: The texts to predict on. 187 | :param y: The ground truth labels. 188 | :param batch_size: The batch size. 189 | :param threshold: The threshold for multilabel classification. 190 | :param output_dict: Whether to output the classification report as a dictionary. 191 | :return: A classification report. 192 | """ 193 | predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold) 194 | report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict) 195 | 196 | return report 197 | 198 | 199 | def _load_pipeline( 200 | folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False 201 | ) -> tuple[StaticModel, Pipeline]: 202 | """ 203 | Load a model and an sklearn pipeline. 204 | 205 | This assumes the following files are present in the repo: 206 | - `pipeline.skops`: The head of the pipeline. 207 | - `config.json`: The configuration of the model. 208 | - `model.safetensors`: The weights of the model. 209 | - `tokenizer.json`: The tokenizer of the model. 210 | 211 | :param folder_or_repo_path: The path to the folder containing the pipeline. 212 | :param token: The token to use to download the pipeline from the hub. 
If this is None, you will only 213 | be able to load the pipeline from a local folder, public repository, or a repository that you have access to 214 | because you are logged in. 215 | :param trust_remote_code: Whether to trust the remote code. If this is False, 216 | we will only load components coming from `sklearn`. If this is True, we will load all components. 217 | If you set this to True, you are responsible for whatever happens. 218 | :return: The encoder model and the loaded head 219 | :raises FileNotFoundError: If the pipeline file does not exist in the folder. 220 | :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False. 221 | """ 222 | folder_or_repo_path = Path(folder_or_repo_path) 223 | model_filename = _DEFAULT_MODEL_FILENAME 224 | head_pipeline_path: str | Path 225 | if folder_or_repo_path.exists(): 226 | head_pipeline_path = folder_or_repo_path / model_filename 227 | if not head_pipeline_path.exists(): 228 | raise FileNotFoundError(f"Pipeline file does not exist in {folder_or_repo_path}") 229 | else: 230 | head_pipeline_path = huggingface_hub.hf_hub_download( 231 | folder_or_repo_path.as_posix(), model_filename, token=token 232 | ) 233 | 234 | model = StaticModel.from_pretrained(folder_or_repo_path) 235 | 236 | unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path) 237 | # If the user does not trust remote code, we should check that the unknown types are trusted. 238 | # By default, we trust everything coming from scikit-learn. 239 | if not trust_remote_code: 240 | for t in unknown_types: 241 | if not _DEFAULT_TRUST_PATTERN.match(t): 242 | raise ValueError(f"Untrusted type {t}.") 243 | head = skops.io.load(head_pipeline_path, trusted=unknown_types) 244 | 245 | return model, head 246 | 247 | 248 | def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None: 249 | """ 250 | Save a pipeline to a folder. 251 | 252 | :param pipeline: The pipeline to save. 253 | :param folder_path: The path to the folder to save the pipeline to. 254 | """ 255 | folder_path = Path(folder_path) 256 | folder_path.mkdir(parents=True, exist_ok=True) 257 | model_filename = _DEFAULT_MODEL_FILENAME 258 | head_pipeline_path = folder_path / model_filename 259 | skops.io.dump(pipeline.head, head_pipeline_path) 260 | pipeline.model.save_pretrained(folder_path) 261 | base_model_name = pipeline.model.base_model_name 262 | if isinstance(base_model_name, list) and base_model_name: 263 | name = base_model_name[0] 264 | elif isinstance(base_model_name, str): 265 | name = base_model_name 266 | else: 267 | name = "unknown" 268 | _create_model_card( 269 | folder_path, 270 | base_model_name=name, 271 | language=pipeline.model.language, 272 | template_path="modelcards/classifier_template.md", 273 | ) 274 | 275 | 276 | def _is_multi_label_shaped(y: LabelType) -> bool: 277 | """Check if the labels are in a multi-label shape.""" 278 | return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set)) 279 | 280 | 281 | def evaluate_single_or_multi_label( 282 | predictions: np.ndarray, 283 | y: LabelType, 284 | output_dict: bool = False, 285 | ) -> str | dict[str, dict[str, float]]: 286 | """ 287 | Evaluate the classifier on a given dataset using scikit-learn's classification report. 288 | 289 | :param predictions: The predictions. 290 | :param y: The ground truth labels. 291 | :param output_dict: Whether to output the classification report as a dictionary. 292 | :return: A classification report. 
293 | """ 294 | if _is_multi_label_shaped(y): 295 | classes = sorted(set([label for labels in y for label in labels])) 296 | mlb = MultiLabelBinarizer(classes=classes) 297 | y = mlb.fit_transform(y) 298 | predictions = mlb.transform(predictions) 299 | elif isinstance(y[0], (str, int)): 300 | classes = sorted(set(y)) 301 | 302 | report = classification_report( 303 | y, 304 | predictions, 305 | output_dict=output_dict, 306 | zero_division=0, 307 | ) 308 | 309 | return report 310 | -------------------------------------------------------------------------------- /model2vec/modelcards/classifier_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | # {{ model_name }} Model Card 6 | 7 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a fine-tuned version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Model2Vec model. It also includes a classifier head on top. 8 | 9 | ## Installation 10 | 11 | Install model2vec using pip: 12 | ``` 13 | pip install model2vec[inference] 14 | ``` 15 | 16 | ## Usage 17 | Load this model using the `from_pretrained` method: 18 | ```python 19 | from model2vec.inference import StaticModelPipeline 20 | 21 | # Load a pretrained Model2Vec model 22 | model = StaticModelPipeline.from_pretrained("{{ model_name }}") 23 | 24 | # Predict labels 25 | predicted = model.predict(["Example sentence"]) 26 | ``` 27 | 28 | ## Additional Resources 29 | 30 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec) 31 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) 32 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results) 33 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials) 34 | - [Website](https://minishlab.github.io/) 35 | 36 | ## Library Authors 37 | 38 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled). 39 | 40 | ## Citation 41 | 42 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work. 43 | ``` 44 | @article{minishlab2024model2vec, 45 | author = {Tulkens, Stephan and {van Dongen}, Thomas}, 46 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings}, 47 | year = {2024}, 48 | url = {https://github.com/MinishLab/model2vec} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /model2vec/modelcards/model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | # {{ model_name }} Model Card 6 | 7 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the {{ base_model }}(https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers. 
8 | 9 | 10 | ## Installation 11 | 12 | Install model2vec using pip: 13 | ``` 14 | pip install model2vec 15 | ``` 16 | 17 | ## Usage 18 | 19 | ### Using Model2Vec 20 | 21 | The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models. 22 | 23 | Load this model using the `from_pretrained` method: 24 | ```python 25 | from model2vec import StaticModel 26 | 27 | # Load a pretrained Model2Vec model 28 | model = StaticModel.from_pretrained("{{ model_name }}") 29 | 30 | # Compute text embeddings 31 | embeddings = model.encode(["Example sentence"]) 32 | ``` 33 | 34 | ### Using Sentence Transformers 35 | 36 | You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model: 37 | 38 | ```python 39 | from sentence_transformers import SentenceTransformer 40 | 41 | # Load a pretrained Sentence Transformer model 42 | model = SentenceTransformer("{{ model_name }}") 43 | 44 | # Compute text embeddings 45 | embeddings = model.encode(["Example sentence"]) 46 | ``` 47 | 48 | ### Distilling a Model2Vec model 49 | 50 | You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code: 51 | 52 | ```python 53 | from model2vec.distill import distill 54 | 55 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model 56 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256) 57 | 58 | # Save the model 59 | m2v_model.save_pretrained("m2v_model") 60 | ``` 61 | 62 | ## How it works 63 | 64 | Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec. 65 | 66 | It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence. 67 | 68 | ## Additional Resources 69 | 70 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec) 71 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) 72 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results) 73 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials) 74 | - [Website](https://minishlab.github.io/) 75 | 76 | 77 | ## Library Authors 78 | 79 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled). 80 | 81 | ## Citation 82 | 83 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work. 
84 | ``` 85 | @article{minishlab2024model2vec, 86 | author = {Tulkens, Stephan and {van Dongen}, Thomas}, 87 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings}, 88 | year = {2024}, 89 | url = {https://github.com/MinishLab/model2vec} 90 | } 91 | ``` 92 | -------------------------------------------------------------------------------- /model2vec/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/model2vec/py.typed -------------------------------------------------------------------------------- /model2vec/quantization.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import Enum 4 | 5 | import numpy as np 6 | 7 | 8 | class DType(str, Enum): 9 | Float16 = "float16" 10 | Float32 = "float32" 11 | Float64 = "float64" 12 | Int8 = "int8" 13 | 14 | 15 | def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray: 16 | """ 17 | Quantize embeddings to a specified data type to reduce memory usage. 18 | 19 | :param embeddings: The embeddings to quantize, as a numpy array. 20 | :param quantize_to: The data type to quantize to. 21 | :return: The quantized embeddings. 22 | :raises ValueError: If the quantization type is not valid. 23 | """ 24 | if quantize_to == DType.Float16: 25 | return embeddings.astype(np.float16) 26 | elif quantize_to == DType.Float32: 27 | return embeddings.astype(np.float32) 28 | elif quantize_to == DType.Float64: 29 | return embeddings.astype(np.float64) 30 | elif quantize_to == DType.Int8: 31 | # Normalize to [-128, 127] range for int8 32 | # We normalize to -127 to 127 to keep symmetry. 33 | scale = np.max(np.abs(embeddings)) / 127.0 34 | quantized = np.round(embeddings / scale).astype(np.int8) 35 | return quantized 36 | else: 37 | raise ValueError("Not a valid enum member of DType.") 38 | 39 | 40 | def quantize_and_reduce_dim( 41 | embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None 42 | ) -> np.ndarray: 43 | """ 44 | Quantize embeddings to a datatype and reduce dimensionality. 45 | 46 | :param embeddings: The embeddings to quantize and reduce, as a numpy array. 47 | :param quantize_to: The data type to quantize to. If None, no quantization is performed. 48 | :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed. 49 | :return: The quantized and reduced embeddings. 50 | :raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality. 
51 | """ 52 | if quantize_to is not None: 53 | quantize_to = DType(quantize_to) 54 | embeddings = quantize_embeddings(embeddings, quantize_to) 55 | 56 | if dimensionality is not None: 57 | if dimensionality > embeddings.shape[1]: 58 | raise ValueError( 59 | f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}" 60 | ) 61 | embeddings = embeddings[:, :dimensionality] 62 | 63 | return embeddings 64 | -------------------------------------------------------------------------------- /model2vec/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from model2vec.utils import importable 2 | 3 | importable("transformers", "tokenizer") 4 | 5 | from model2vec.tokenizer.tokenizer import ( 6 | clean_and_create_vocabulary, 7 | create_tokenizer, 8 | replace_vocabulary, 9 | turn_tokens_into_ids, 10 | ) 11 | 12 | __all__ = ["clean_and_create_vocabulary", "create_tokenizer", "turn_tokens_into_ids", "replace_vocabulary"] 13 | -------------------------------------------------------------------------------- /model2vec/tokenizer/datamodels.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Token: 6 | """A class to represent a token.""" 7 | 8 | form: str 9 | # The normalized and pretokenized form of the token 10 | normalized_form: str 11 | # Whether the word is a continuing subword. 12 | is_subword: bool 13 | # Whether the token is internal to the model. 14 | is_internal: bool 15 | -------------------------------------------------------------------------------- /model2vec/tokenizer/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | 7 | 8 | def process_tokenizer( 9 | tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None 10 | ) -> dict[str, Any]: 11 | """Process the WordPiece tokenizer JSON.""" 12 | if tokenizer_json["model"]["type"] == "Unigram": 13 | return _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token) 14 | tokenizer_json["model"]["type"] = "Unigram" 15 | tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None 16 | 17 | token_weights = np.asarray([_calculate_token_weight_for_unigram(token) for token in pre_tokenized_tokens]) 18 | proba = (token_weights / np.sum(token_weights)).tolist() 19 | tokenizer_json["model"]["vocab"] = [(token, np.log(p)) for token, p in zip(pre_tokenized_tokens, proba)] 20 | 21 | return tokenizer_json 22 | 23 | 24 | def _process_unigram( 25 | tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None 26 | ) -> dict[str, Any]: 27 | """Process the Unigram tokenizer JSON.""" 28 | current_probas = dict(tokenizer_json["model"]["vocab"]) 29 | avg_proba = sum(current_probas.values()) / len(current_probas) 30 | new_probas = [[word, current_probas.get(word, avg_proba)] for word in pre_tokenized_tokens] 31 | tokenizer_json["model"]["vocab"] = new_probas 32 | 33 | tokens, _ = zip(*tokenizer_json["model"]["vocab"]) 34 | if unk_token is not None: 35 | tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) 36 | 37 | return tokenizer_json 38 | 39 | 40 | def _calculate_token_weight_for_unigram(token: str) -> float: 41 | """Calculate the token weight for Unigram.""" 42 | # Always prefer longer tokens. 
43 | return len(token) + token.count("▁") + token.count("Ġ") 44 | -------------------------------------------------------------------------------- /model2vec/tokenizer/normalizer.py: -------------------------------------------------------------------------------- 1 | from string import punctuation 2 | 3 | from tokenizers import Regex, Tokenizer 4 | from tokenizers.normalizers import Normalizer, Replace, Sequence, Strip 5 | 6 | 7 | def replace_normalizer( 8 | tokenizer: Tokenizer, 9 | ) -> Tokenizer: 10 | """ 11 | Replace the normalizer for the tokenizer. 12 | 13 | The new normalizer will replace punctuation with a space before and after the punctuation. 14 | It will also replace multiple spaces with a single space and strip the right side of the string. 15 | If the tokenizer already has a normalizer, it will be added to the new normalizer. 16 | If the tokenizer does not have a normalizer, a new normalizer will be created. 17 | 18 | :param tokenizer: The tokenizer to change. 19 | :return: The tokenizer with a replaced normalizer. 20 | """ 21 | normalizer = tokenizer.normalizer 22 | new_normalizers = [] 23 | for char in punctuation: 24 | new_normalizers.append(Replace(char, f" {char} ")) 25 | 26 | new_normalizers.append(Replace(Regex(r"\s+"), " ")) 27 | new_normalizers.append(Strip(right=True)) 28 | if normalizer is None: 29 | normalizer = Sequence(new_normalizers) # type: ignore 30 | else: 31 | normalizer = Sequence([normalizer] + new_normalizers) # type: ignore 32 | tokenizer.normalizer = normalizer # type: ignore 33 | 34 | return tokenizer 35 | -------------------------------------------------------------------------------- /model2vec/tokenizer/pretokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from typing import Any 5 | 6 | from tokenizers import Tokenizer 7 | 8 | _FORBIDDEN_PRETOKENIZERS = ( 9 | "WhiteSpace", 10 | "WhitespaceSplit", 11 | "BertPreTokenizer", 12 | "CharDelimiterSplit", 13 | "Punctuation", 14 | "Split", 15 | "UnicodeScripts", 16 | ) 17 | _BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False} 18 | 19 | 20 | def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None: 21 | """Fixes a single pretokenizer to allow multiword units.""" 22 | if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS: 23 | return None 24 | if pre_tokenizer["type"] == "ByteLevel": 25 | pre_tokenizer["add_prefix_space"] = True 26 | pre_tokenizer["use_regex"] = False 27 | if pre_tokenizer["type"] == "Metaspace": 28 | pre_tokenizer["split"] = False 29 | pre_tokenizer["prepend_scheme"] = "always" 30 | 31 | return pre_tokenizer 32 | 33 | 34 | def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer: 35 | """Fixes a single pretokenizer to allow multiword units.""" 36 | tokenizer_json = json.loads(tokenizer.to_str()) 37 | pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None) 38 | 39 | if pre_tokenizer_json is None: 40 | pre_tokenizer_json = _BASIC_METASPACE 41 | 42 | elif pre_tokenizer_json["type"] == "Sequence": 43 | new_pretokenizers = [] 44 | for single_pretokenizer in pre_tokenizer_json["pretokenizers"]: 45 | new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer) 46 | if new_pretokenizer is not None: 47 | new_pretokenizers.append(new_pretokenizer) 48 | 49 | if new_pretokenizers: 50 | pre_tokenizer_json["pretokenizers"] = new_pretokenizers 51 | else: 52 | pre_tokenizer_json = _BASIC_METASPACE 53 | 54 | 
pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or _BASIC_METASPACE
55 | tokenizer_json["pre_tokenizer"] = pre_tokenizer_json
56 |
57 | return tokenizer.from_str(json.dumps(tokenizer_json))
58 |
-------------------------------------------------------------------------------- /model2vec/train/README.md: --------------------------------------------------------------------------------
1 | # Training
2 |
3 | Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
4 |
5 | We support both single and multi-label classification, which work seamlessly based on the labels you provide.
6 |
7 | # Installation
8 |
9 | To train, make sure you install the training extra:
10 |
11 | ```
12 | pip install model2vec[training]
13 | ```
14 |
15 | # Quickstart
16 |
17 | To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:
18 |
19 | ```python
20 | from model2vec.distill import distill
21 | from model2vec.train import StaticModelForClassification
22 |
23 | # From a distilled model
24 | distilled_model = distill("baai/bge-base-en-v1.5")
25 | classifier = StaticModelForClassification.from_static_model(model=distilled_model)
26 |
27 | # From a pre-trained model: potion is the default
28 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
29 | ```
30 |
31 | This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through some parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.
32 |
33 | Now that you have created the classifier, let's train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
34 |
35 | ```python
36 | from time import perf_counter
37 | from datasets import load_dataset
38 |
39 | # Load the subj dataset
40 | ds = load_dataset("setfit/subj")
41 | train = ds["train"]
42 | test = ds["test"]
43 |
44 | s = perf_counter()
45 | classifier = classifier.fit(train["text"], train["label"])
46 |
47 | print(f"Training took {int(perf_counter() - s)} seconds.")
48 | # Training took 81 seconds
49 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
50 | print(classification_report)
51 | # Achieved 91.0 test accuracy
52 | ```
53 |
54 | As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
55 |
56 | The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default the training loop splits the data into a train and validation split, with 90% of the data being used for training and 10% for validation. It runs with early stopping on the validation set accuracy, with a patience of 5.
57 |
58 | Note that this model is as fast as you're used to from us:
59 |
60 | ```python
61 | from time import perf_counter
62 |
63 | s = perf_counter()
64 | classifier.predict(test["text"])
65 | print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
66 | # Took 67 milliseconds for 2000 instances on CPU.
67 | ```
68 |
69 | ## Multi-label classification
70 |
71 | Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
72 |
73 | ```python
74 | from datasets import load_dataset
75 | from model2vec.train import StaticModelForClassification
76 |
77 | # Initialize a classifier from a pre-trained model
78 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
79 |
80 | # Load a multi-label dataset
81 | ds = load_dataset("google-research-datasets/go_emotions")
82 |
83 | # Inspect some of the labels
84 | print(ds["train"]["labels"][40:50])
85 | # [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
86 |
87 | # Train the classifier on text (X) and labels (y)
88 | classifier.fit(ds["train"]["text"], ds["train"]["labels"])
89 | ```
90 |
91 | Then, we can evaluate the classifier:
92 |
93 | ```python
94 | from sklearn import metrics
95 | from sklearn.preprocessing import MultiLabelBinarizer
96 |
97 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
98 | print(classification_report)
99 | # Accuracy: 0.410
100 | # Precision: 0.527
101 | # Recall: 0.410
102 | # F1: 0.439
103 | ```
104 |
105 | The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
106 |
107 | # Persistence
108 |
109 | You can turn a classifier into a scikit-learn compatible pipeline, as follows:
110 |
111 | ```python
112 | pipeline = classifier.to_pipeline()
113 | ```
114 |
115 | This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no need to install torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
116 |
117 | If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
118 |
119 | ```python
120 | pipeline.save_pretrained(path)
121 | pipeline.push_to_hub("my_cool/project")
122 | ```
123 |
124 | Later, you can load these as follows:
125 |
126 | ```python
127 | from model2vec.inference import StaticModelPipeline
128 |
129 | pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
130 | ```
131 |
132 | Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
133 |
134 |
135 | # Bring your own architecture
136 |
137 | Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.
138 |
139 | The core functionality of the `StaticModelForClassification` is contained in a couple of functions:
140 |
141 | * `construct_head`: This function constructs the classifier head on top of the StaticModel. For example, if you want to create a model that has LayerNorm, just subclass, and replace this function. This should be the main function to update if you want to change model behavior.
142 | * `train_test_split`: governs the train test split before classification.
146 | 
147 | The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.
148 | 
149 | # Results
150 | 
151 | We ran extensive benchmarks where we compared our model to several well-known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
152 | 
--------------------------------------------------------------------------------
/model2vec/train/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import get_package_extras, importable
2 | 
3 | _REQUIRED_EXTRA = "train"
4 | 
5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
6 |     importable(extra_dependency, _REQUIRED_EXTRA)
7 | 
8 | from model2vec.train.classifier import StaticModelForClassification
9 | 
10 | __all__ = ["StaticModelForClassification"]
11 | 
--------------------------------------------------------------------------------
/model2vec/train/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import logging
4 | from typing import Any, TypeVar
5 | 
6 | import numpy as np
7 | import torch
8 | from tokenizers import Encoding, Tokenizer
9 | from torch import nn
10 | from torch.nn.utils.rnn import pad_sequence
11 | from torch.utils.data import DataLoader, Dataset
12 | 
13 | from model2vec import StaticModel
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | 
18 | class FinetunableStaticModel(nn.Module):
19 |     def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
20 |         """
21 |         Initialize a trainable StaticModel from a StaticModel.
22 | 
23 |         :param vectors: The embeddings of the StaticModel.
24 |         :param tokenizer: The tokenizer.
25 |         :param out_dim: The output dimension of the head.
26 |         :param pad_id: The padding id. This is set to 0 in almost all model2vec models.
27 |         """
28 |         super().__init__()
29 |         self.pad_id = pad_id
30 |         self.out_dim = out_dim
31 |         self.embed_dim = vectors.shape[1]
32 | 
33 |         self.vectors = vectors
34 |         if self.vectors.dtype != torch.float32:
35 |             dtype = str(self.vectors.dtype)
36 |             logger.warning(
37 |                 f"Your vectors are {dtype} precision, converting to torch.float32 to avoid compatibility issues."
38 |             )
39 |             self.vectors = vectors.float()
40 | 
41 |         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
42 |         self.head = self.construct_head()
43 |         self.w = self.construct_weights()
44 |         self.tokenizer = tokenizer
45 | 
46 |     def construct_weights(self) -> nn.Parameter:
47 |         """Construct the weights for the model."""
48 |         weights = torch.zeros(len(self.vectors))
49 |         weights[self.pad_id] = -10_000
50 |         return nn.Parameter(weights)
51 | 
52 |     def construct_head(self) -> nn.Sequential:
53 |         """Construct the head. This method should be overridden by subclasses."""
54 |         return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
55 | 
56 |     @classmethod
57 |     def from_pretrained(
58 |         cls: type[ModelType], *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
59 |     ) -> ModelType:
60 |         """Load the model from a pretrained model2vec model."""
61 |         model = StaticModel.from_pretrained(model_name)
62 |         return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
63 | 
64 |     @classmethod
65 |     def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
66 |         """Load the model from a static model."""
67 |         model.embedding = np.nan_to_num(model.embedding)
68 |         embeddings_converted = torch.from_numpy(model.embedding)
69 |         return cls(
70 |             vectors=embeddings_converted,
71 |             pad_id=model.tokenizer.token_to_id("[PAD]"),
72 |             out_dim=out_dim,
73 |             tokenizer=model.tokenizer,
74 |             **kwargs,
75 |         )
76 | 
77 |     def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
78 |         """
79 |         A forward pass and mean pooling.
80 | 
81 |         This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
82 |         to pass through.
83 | 
84 |         :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
85 |         :return: The mean over the input ids, weighted by token weights.
86 |         """
87 |         w = self.w[input_ids]
88 |         w = torch.sigmoid(w)
89 |         zeros = (input_ids != self.pad_id).float()
90 |         w = w * zeros
91 |         # Add a small epsilon to avoid division by zero
92 |         length = zeros.sum(1) + 1e-16
93 |         embedded = self.embeddings(input_ids)
94 |         # Weigh each token
95 |         embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
96 |         # Mean pooling by dividing by the length
97 |         embedded = embedded / length[:, None]
98 | 
99 |         return nn.functional.normalize(embedded)
100 | 
101 |     def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
102 |         """Forward pass: encode the input ids, then apply the classifier head."""
103 |         encoded = self._encode(input_ids)
104 |         return self.head(encoded), encoded
105 | 
106 |     def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
107 |         """
108 |         Tokenize a bunch of strings into a single padded 2D tensor.
109 | 
110 |         Note that this is not used during training.
111 | 
112 |         :param texts: The texts to tokenize.
113 |         :param max_length: The maximum number of tokens per text. If this is None, sequences are not truncated.
114 | :return: A 2D padded tensor 115 | """ 116 | encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False) 117 | encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded] 118 | return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id) 119 | 120 | @property 121 | def device(self) -> str: 122 | """Get the device of the model.""" 123 | return self.embeddings.weight.device 124 | 125 | def to_static_model(self) -> StaticModel: 126 | """Convert the model to a static model.""" 127 | emb = self.embeddings.weight.detach().cpu().numpy() 128 | w = torch.sigmoid(self.w).detach().cpu().numpy() 129 | 130 | return StaticModel(emb * w[:, None], self.tokenizer, normalize=True) 131 | 132 | 133 | class TextDataset(Dataset): 134 | def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None: 135 | """ 136 | A dataset of texts. 137 | 138 | :param tokenized_texts: The tokenized texts. Each text is a list of token ids. 139 | :param targets: The targets. 140 | :raises ValueError: If the number of labels does not match the number of texts. 141 | """ 142 | if len(targets) != len(tokenized_texts): 143 | raise ValueError("Number of labels does not match number of texts.") 144 | self.tokenized_texts = tokenized_texts 145 | self.targets = targets 146 | 147 | def __len__(self) -> int: 148 | """Return the length of the dataset.""" 149 | return len(self.tokenized_texts) 150 | 151 | def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]: 152 | """Gets an item.""" 153 | return self.tokenized_texts[index], self.targets[index] 154 | 155 | @staticmethod 156 | def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]: 157 | """Collate function.""" 158 | texts, targets = zip(*batch) 159 | 160 | tensors = [torch.LongTensor(x) for x in texts] 161 | padded = pad_sequence(tensors, batch_first=True, padding_value=0) 162 | 163 | return padded, torch.stack(targets) 164 | 165 | def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader: 166 | """Convert the dataset to a DataLoader.""" 167 | return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size) 168 | 169 | 170 | ModelType = TypeVar("ModelType", bound=FinetunableStaticModel) 171 | -------------------------------------------------------------------------------- /model2vec/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import annotations 3 | 4 | import json 5 | import logging 6 | import re 7 | from importlib import import_module 8 | from importlib.metadata import metadata 9 | from pathlib import Path 10 | from typing import Any, Iterator, Protocol, cast 11 | 12 | import numpy as np 13 | import safetensors 14 | from joblib import Parallel 15 | from tokenizers import Tokenizer 16 | from tqdm import tqdm 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class ProgressParallel(Parallel): 22 | """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar.""" 23 | 24 | def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None: 25 | """ 26 | Initialize the ProgressParallel object. 27 | 28 | :param use_tqdm: Whether to show the progress bar. 29 | :param total: Total number of tasks (batches) you expect to process. If None, 30 | it updates the total dynamically to the number of dispatched tasks. 
31 | :param *args: Additional arguments to pass to `Parallel.__init__`. 32 | :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`. 33 | """ 34 | self._use_tqdm = use_tqdm 35 | self._total = total 36 | super().__init__(*args, **kwargs) 37 | 38 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 39 | """Create a tqdm context.""" 40 | with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar: 41 | self._pbar = self._pbar 42 | return super().__call__(*args, **kwargs) 43 | 44 | def print_progress(self) -> None: 45 | """Hook called by joblib as tasks complete. We update the tqdm bar here.""" 46 | if self._total is None: 47 | # If no fixed total was given, we dynamically set the total 48 | self._pbar.total = self.n_dispatched_tasks 49 | # Move the bar to the number of completed tasks 50 | self._pbar.n = self.n_completed_tasks 51 | self._pbar.refresh() 52 | 53 | 54 | class SafeOpenProtocol(Protocol): 55 | """Protocol to fix safetensors safe open.""" 56 | 57 | def get_tensor(self, key: str) -> np.ndarray: 58 | """Get a tensor.""" 59 | ... # pragma: no cover 60 | 61 | 62 | _MODULE_MAP = (("scikit-learn", "sklearn"),) 63 | _DIVIDERS = re.compile(r"[=<>!]+") 64 | 65 | 66 | def get_package_extras(package: str, extra: str) -> Iterator[str]: 67 | """Get the extras of the package.""" 68 | try: 69 | message = metadata(package) 70 | except Exception as e: 71 | raise ImportError(f"Could not retrieve metadata for package '{package}': {e}") 72 | 73 | all_packages = message.get_all("Requires-Dist") or [] 74 | for package in all_packages: 75 | name, *rest = package.split(";", maxsplit=1) 76 | if rest: 77 | # Extract and clean the extra requirement 78 | found_extra = rest[0].split("==")[-1].strip(" \"'") 79 | if found_extra == extra: 80 | prefix, *_ = _DIVIDERS.split(name) 81 | yield prefix.strip() 82 | 83 | 84 | def importable(module: str, extra: str) -> None: 85 | """Check if a module is importable.""" 86 | module = dict(_MODULE_MAP).get(module, module) 87 | try: 88 | import_module(module) 89 | except ImportError: 90 | raise ImportError( 91 | f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. 
`pip install model2vec[{extra}]`" 92 | ) 93 | 94 | 95 | def setup_logging() -> None: 96 | """Simple logging setup.""" 97 | from rich.logging import RichHandler 98 | 99 | logging.basicConfig( 100 | level="INFO", 101 | format="%(name)s - %(message)s", 102 | datefmt="%Y-%m-%d %H:%M:%S", 103 | handlers=[RichHandler(rich_tracebacks=True)], 104 | ) 105 | 106 | 107 | def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]: 108 | """Load a local model.""" 109 | embeddings_path = folder / "model.safetensors" 110 | tokenizer_path = folder / "tokenizer.json" 111 | config_path = folder / "config.json" 112 | 113 | opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy")) 114 | embeddings = opened_tensor_file.get_tensor("embeddings") 115 | 116 | if config_path.exists(): 117 | config = json.load(open(config_path)) 118 | else: 119 | config = {} 120 | 121 | tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path)) 122 | 123 | if len(tokenizer.get_vocab()) != len(embeddings): 124 | logger.warning( 125 | f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`" 126 | ) 127 | 128 | return embeddings, tokenizer, config 129 | -------------------------------------------------------------------------------- /model2vec/version.py: -------------------------------------------------------------------------------- 1 | __version_triple__ = (0, 6, 0) 2 | __version__ = ".".join(map(str, __version_triple__)) 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "model2vec" 3 | description = "Fast State-of-the-Art Static Embeddings" 4 | readme = { file = "README.md", content-type = "text/markdown" } 5 | license = { file = "LICENSE" } 6 | requires-python = ">=3.9" 7 | authors = [{ name = "Stéphan Tulkens", email = "stephantul@gmail.com"}, {name = "Thomas van Dongen", email = "thomas123@live.nl"}] 8 | dynamic = ["version"] 9 | 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 15 | "Topic :: Software Development :: Libraries", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Natural Language :: English", 23 | ] 24 | 25 | dependencies = [ 26 | "jinja2", 27 | "joblib", 28 | "numpy", 29 | "rich", 30 | "safetensors", 31 | "setuptools", 32 | "tokenizers>=0.20", 33 | "tqdm", 34 | ] 35 | 36 | [build-system] 37 | requires = ["setuptools>=64", "setuptools_scm>=8"] 38 | build-backend = "setuptools.build_meta" 39 | 40 | [tool.setuptools] 41 | packages = ["model2vec"] 42 | include-package-data = true 43 | 44 | [tool.setuptools.package-data] 45 | model2vec = [ 46 | "assets/modelcards/model_card_template.md", 47 | "assets/modelcards/classifier_template.md", 48 | "py.typed" 49 | ] 50 | 51 | [project.optional-dependencies] 52 | dev = [ 53 | "black", 54 | "ipython", 55 | "mypy", 56 | "pre-commit", 57 | "pytest", 58 | "pytest-cov", 59 | "ruff", 60 | ] 61 | 62 | distill = ["torch", "transformers<=4.52.1", "scikit-learn"] 63 | onnx = ["onnx", "torch"] 64 | # train 
also installs inference 65 | train = ["torch", "lightning", "scikit-learn", "skops"] 66 | inference = ["scikit-learn", "skops"] 67 | tokenizer = ["transformers"] 68 | 69 | [project.urls] 70 | "Homepage" = "https://github.com/MinishLab" 71 | "Bug Reports" = "https://github.com/MinishLab/model2vec/issues" 72 | "Source" = "https://github.com/MinishLab/model2vec" 73 | 74 | [tool.ruff] 75 | exclude = [".venv/"] 76 | line-length = 120 77 | target-version = "py310" 78 | 79 | [tool.ruff.lint] 80 | select = [ 81 | # Annotations: Enforce type annotations 82 | "ANN", 83 | # Complexity: Enforce a maximum cyclomatic complexity 84 | "C90", 85 | # Pydocstyle: Enforce docstrings 86 | "D", 87 | # Isort: Enforce import order 88 | "I", 89 | # Numpy: Enforce numpy style 90 | "NPY", 91 | # Print: Forbid print statements 92 | "T20", 93 | ] 94 | 95 | ignore = [ 96 | # Allow self and cls to be untyped, and allow Any type 97 | "ANN101", "ANN102", "ANN401", 98 | # Pydocstyle ignores 99 | "D100", "D101", "D104", "D203", "D212", "D401", 100 | # Allow use of f-strings in logging 101 | "G004" 102 | ] 103 | 104 | [tool.pydoclint] 105 | style = "sphinx" 106 | exclude = "test_" 107 | allow-init-docstring = true 108 | arg-type-hints-in-docstring = false 109 | check-return-types = false 110 | require-return-section-when-returning-nothing = false 111 | 112 | [tool.mypy] 113 | python_version = "3.10" 114 | warn_unused_configs = true 115 | ignore_missing_imports = true 116 | 117 | [tool.setuptools_scm] 118 | # can be empty if no extra settings are needed, presence enables setuptools_scm 119 | 120 | [tool.setuptools.dynamic] 121 | version = {attr = "model2vec.version.__version__"} 122 | -------------------------------------------------------------------------------- /results/make_speed_vs_mteb_plot.py: -------------------------------------------------------------------------------- 1 | """Script to benchmark the speed of various text embedding models and generate a plot of the MTEB score vs samples per second.""" 2 | 3 | import argparse 4 | import json 5 | import logging 6 | from pathlib import Path 7 | from time import perf_counter 8 | from typing import Any 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from bpemb import BPEmb 13 | from datasets import load_dataset 14 | from plotnine import ( 15 | aes, 16 | element_line, 17 | geom_point, 18 | geom_text, 19 | ggplot, 20 | guides, 21 | labs, 22 | scale_size, 23 | scale_y_continuous, 24 | theme, 25 | theme_classic, 26 | xlim, 27 | ylim, 28 | ) 29 | from sentence_transformers import SentenceTransformer 30 | 31 | from model2vec import StaticModel 32 | 33 | logging.basicConfig(level=logging.INFO) 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class BPEmbEmbedder: 39 | def __init__(self, vs: int = 50_000, dim: int = 300) -> None: 40 | """Initialize the BPEmbEmbedder.""" 41 | self.bpemb_en = BPEmb(lang="en", vs=vs, dim=dim) 42 | 43 | def mean_sentence_embedding(self, sentence: str) -> np.ndarray: 44 | """Encode a sentence to a mean embedding.""" 45 | encoded_ids = self.bpemb_en.encode_ids(sentence) 46 | embeddings = self.bpemb_en.vectors[encoded_ids] 47 | if embeddings.size == 0: 48 | return np.zeros(self.bpemb_en.dim) # Return a zero vector if no tokens are found 49 | return embeddings.mean(axis=0) 50 | 51 | def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray: 52 | """Encode a list of sentences to embeddings.""" 53 | return np.array([self.mean_sentence_embedding(sentence.lower()) for sentence in sentences]) 54 | 55 | 56 | def make_plot(df: 
pd.DataFrame) -> ggplot:
57 |     """Create a plot of the MTEB score vs samples per second."""
58 |     df["label_y"] = (
59 |         df["Average score"]
60 |         + 0.5  # a constant "base" offset for all bubbles
61 |         + 0.08 * np.sqrt(df["Params (Million)"])
62 |     )
63 |     plot = (
64 |         ggplot(df, aes(x="Samples per second", y="Average score"))
65 |         + geom_point(aes(size="Params (Million)", color="Model"))
66 |         + geom_text(aes(y="label_y", label="Model"), color="black", size=7)
67 |         + scale_size(range=(2, 30))
68 |         + theme_classic()
69 |         + labs(title="Average MTEB Score vs Samples per Second")
70 |         + ylim(df["Average score"].min(), df["Average score"].max() + 3)
71 |         + scale_y_continuous(breaks=range(30, 70, 5))
72 |         + theme(
73 |             panel_grid_major=element_line(color="lightgrey", size=0.5),
74 |             panel_grid_minor=element_line(color="lightgrey", size=0.25),
75 |             figure_size=(10, 6),
76 |         )
77 |         + xlim(0, df["Samples per second"].max() + 100)
78 |         + guides(None)
79 |     )
80 |     return plot
81 | 
82 | 
83 | def benchmark_model(name: str, info: list[str], texts: list[str]) -> dict[str, float | str]:
84 |     """Benchmark a single model."""
85 |     logger.info(f"Starting {name}")
86 |     if info[1] == "BPEmb":
87 |         model = BPEmbEmbedder(vs=50_000, dim=300)  # type: ignore
88 |     elif info[1] == "ST":
89 |         model = SentenceTransformer(info[0], device="cpu")  # type: ignore
90 |     else:
91 |         model = StaticModel.from_pretrained(info[0])  # type: ignore
92 | 
93 |     start = perf_counter()
94 |     if info[1] == "M2V":
95 |         # If the model is a model2vec model, disable multiprocessing for a fair comparison
96 |         model.encode(texts, use_multiprocessing=False)
97 |     else:
98 |         model.encode(texts)
99 | 
100 |     total_time = perf_counter() - start
101 |     docs_per_second = len(texts) / total_time
102 | 
103 |     logger.info(f"{name}: {docs_per_second} docs per second")
104 |     logger.info(f"Total time: {total_time}")
105 | 
106 |     return {"docs_per_second": docs_per_second, "total_time": total_time}
107 | 
108 | 
109 | def main(save_path: str, n_texts: int) -> None:
110 |     """Benchmark text embedding models and generate a plot."""
111 |     # Define the models to benchmark
112 |     models: dict[str, list[str]] = {
113 |         "BPEmb-50k-300d": ["", "BPEmb"],
114 |         "all-MiniLM-L6-v2": ["sentence-transformers/all-MiniLM-L6-v2", "ST"],
115 |         "bge-base-en-v1.5": ["BAAI/bge-base-en-v1.5", "ST"],
116 |         "GloVe 6B 300d": ["sentence-transformers/average_word_embeddings_glove.6B.300d", "ST"],
117 |         "potion-base-8M": ["minishlab/potion-base-8M", "M2V"],
118 |     }
119 | 
120 |     # Load the dataset
121 |     ds = load_dataset("wikimedia/wikipedia", data_files="20231101.en/train-00000-of-00041.parquet")["train"]
122 |     texts = ds["text"][:n_texts]
123 | 
124 |     summarized_results = [
125 |         {"Model": "potion-base-2M", "Average score": 44.77, "Samples per second": None, "Params (Million)": 1.875},
126 |         {"Model": "GloVe 6B 300d", "Average score": 42.36, "Samples per second": None, "Params (Million)": 120.000},
127 |         {"Model": "potion-base-4M", "Average score": 48.23, "Samples per second": None, "Params (Million)": 3.750},
128 |         {"Model": "all-MiniLM-L6-v2", "Average score": 56.09, "Samples per second": None, "Params (Million)": 23.000},
129 |         {"Model": "potion-base-8M", "Average score": 50.03, "Samples per second": None, "Params (Million)": 7.500},
130 |         {"Model": "bge-base-en-v1.5", "Average score": 63.56, "Samples per second": None, "Params (Million)": 109.000},
131 |         {"Model": "M2V_base_output", "Average score": 45.34, "Samples per second": None, "Params (Million)": 7.500},
132 |         {"Model": "BPEmb-50k-300d", "Average
score": 37.78, "Samples per second": None, "Params (Million)": 15.000}, 133 | {"Model": "potion-base-32M", "Average score": 51.66, "Samples per second": None, "Params (Million)": 32.300}, 134 | ] 135 | 136 | timings = {} 137 | 138 | for name, info in models.items(): 139 | timing = benchmark_model(name, info, texts) 140 | timings[name] = timing 141 | # Update summarized results 142 | for result in summarized_results: 143 | if result["Model"] == name: 144 | result["Samples per second"] = timing["docs_per_second"] 145 | 146 | # Set potion-base-8M as the reference speed for the other potion models 147 | potion_base_8m_speed = next( 148 | result["Samples per second"] for result in summarized_results if result["Model"] == "potion-base-8M" 149 | ) 150 | for model_name in ["M2V_base_output", "potion-base-2M", "potion-base-4M", "potion-base-32M"]: 151 | for result in summarized_results: 152 | if result["Model"] == model_name: 153 | result["Samples per second"] = potion_base_8m_speed 154 | 155 | # Ensure save_path is a directory 156 | save_dir = Path(save_path) 157 | save_dir.mkdir(parents=True, exist_ok=True) 158 | 159 | # Save timings to JSON 160 | json_path = save_dir / "speed_benchmark_results.json" 161 | with open(json_path, "w") as file: 162 | json.dump(timings, file, indent=4) 163 | 164 | # Create and save the plot 165 | df = pd.DataFrame(summarized_results) 166 | plot = make_plot(df) 167 | plot_path = save_dir / "speed_vs_mteb_plot.png" 168 | plot.save(plot_path, width=12, height=10) 169 | 170 | logger.info(f"Timings saved to {json_path}") 171 | logger.info(f"Plot saved to {plot_path}") 172 | 173 | 174 | if __name__ == "__main__": 175 | parser = argparse.ArgumentParser(description="Benchmark text embedding models and generate a plot.") 176 | parser.add_argument( 177 | "--save-path", type=str, required=True, help="Directory to save the benchmark results and plot." 178 | ) 179 | parser.add_argument( 180 | "--n-texts", type=int, default=100_000, help="Number of texts to use from the dataset for benchmarking." 
181 | ) 182 | args = parser.parse_args() 183 | 184 | main(save_path=args.save_path, n_texts=args.n_texts) 185 | -------------------------------------------------------------------------------- /scripts/export_to_onnx.py: -------------------------------------------------------------------------------- 1 | from model2vec.utils import get_package_extras, importable 2 | 3 | # Define the optional dependency group name 4 | _REQUIRED_EXTRA = "onnx" 5 | 6 | # Check if each dependency for the "onnx" group is importable 7 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA): 8 | importable(extra_dependency, _REQUIRED_EXTRA) 9 | 10 | import argparse 11 | import json 12 | import logging 13 | from pathlib import Path 14 | 15 | import torch 16 | from tokenizers import Tokenizer 17 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 18 | 19 | from model2vec import StaticModel 20 | 21 | logging.basicConfig(level=logging.INFO) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class TorchStaticModel(torch.nn.Module): 26 | def __init__(self, model: StaticModel) -> None: 27 | """Initialize the TorchStaticModel with a StaticModel instance.""" 28 | super().__init__() 29 | # Convert NumPy embeddings to a torch.nn.EmbeddingBag 30 | embeddings = torch.from_numpy(model.embedding) 31 | if embeddings.dtype in {torch.int8, torch.uint8}: 32 | embeddings = embeddings.to(torch.float16) 33 | self.embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True) 34 | self.normalize = model.normalize 35 | # Save tokenizer attributes 36 | self.tokenizer = model.tokenizer 37 | self.unk_token_id = model.unk_token_id 38 | self.median_token_length = model.median_token_length 39 | 40 | def forward(self, input_ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: 41 | """ 42 | Forward pass of the model. 43 | 44 | :param input_ids: The input token ids. 45 | :param offsets: The offsets to compute the mean pooling. 46 | :return: The embeddings. 47 | """ 48 | # Perform embedding lookup and mean pooling 49 | embeddings = self.embedding_bag(input_ids, offsets) 50 | # Normalize if required 51 | if self.normalize: 52 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1) 53 | return embeddings 54 | 55 | def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: 56 | """ 57 | Tokenize the input sentences. 58 | 59 | :param sentences: The input sentences. 60 | :param max_length: The maximum length of the input_ids. 61 | :return: The input_ids and offsets. 
62 | """ 63 | # Tokenization logic similar to your StaticModel 64 | if max_length is not None: 65 | m = max_length * self.median_token_length 66 | sentences = [sentence[:m] for sentence in sentences] 67 | encodings = self.tokenizer.encode_batch(sentences, add_special_tokens=False) 68 | encodings_ids = [encoding.ids for encoding in encodings] 69 | if self.unk_token_id is not None: 70 | # Remove unknown tokens 71 | encodings_ids = [ 72 | [token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids 73 | ] 74 | if max_length is not None: 75 | encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids] 76 | # Flatten input_ids and compute offsets 77 | offsets = torch.tensor([0] + [len(ids) for ids in encodings_ids[:-1]], dtype=torch.long).cumsum(dim=0) 78 | input_ids = torch.tensor( 79 | [token_id for token_ids in encodings_ids for token_id in token_ids], 80 | dtype=torch.long, 81 | ) 82 | return input_ids, offsets 83 | 84 | 85 | def export_model_to_onnx(model_path: str, save_path: Path) -> None: 86 | """ 87 | Export the StaticModel to ONNX format and save tokenizer files. 88 | 89 | :param model_path: The path to the pretrained StaticModel. 90 | :param save_path: The directory to save the model and related files. 91 | """ 92 | save_path.mkdir(parents=True, exist_ok=True) 93 | 94 | # Load the StaticModel 95 | model = StaticModel.from_pretrained(model_path) 96 | torch_model = TorchStaticModel(model) 97 | 98 | # Save the model using save_pretrained 99 | model.save_pretrained(save_path) 100 | 101 | # Prepare dummy input data 102 | texts = ["hello", "hello world"] 103 | input_ids, offsets = torch_model.tokenize(texts) 104 | 105 | # Export the model to ONNX 106 | onnx_model_path = save_path / "onnx/model.onnx" 107 | onnx_model_path.parent.mkdir(parents=True, exist_ok=True) 108 | torch.onnx.export( 109 | torch_model, 110 | (input_ids, offsets), 111 | str(onnx_model_path), 112 | export_params=True, 113 | opset_version=14, 114 | do_constant_folding=True, 115 | input_names=["input_ids", "offsets"], 116 | output_names=["embeddings"], 117 | dynamic_axes={ 118 | "input_ids": {0: "num_tokens"}, 119 | "offsets": {0: "batch_size"}, 120 | "embeddings": {0: "batch_size"}, 121 | }, 122 | ) 123 | 124 | logger.info(f"Model has been successfully exported to {onnx_model_path}") 125 | 126 | # Save the tokenizer files required for transformers.js 127 | save_tokenizer(model.tokenizer, save_path) 128 | logger.info(f"Tokenizer files have been saved to {save_path}") 129 | 130 | 131 | def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None: 132 | """ 133 | Save tokenizer files in a format compatible with Transformers. 134 | 135 | :param tokenizer: The tokenizer from the StaticModel. 136 | :param save_directory: The directory to save the tokenizer files. 137 | :raises FileNotFoundError: If config.json is not found in save_directory. 138 | :raises FileNotFoundError: If tokenizer_config.json is not found in save_directory. 139 | :raises ValueError: If tokenizer_name is not found in config.json. 
140 | """ 141 | tokenizer_json_path = save_directory / "tokenizer.json" 142 | tokenizer.save(str(tokenizer_json_path)) 143 | 144 | # Save vocab.txt 145 | vocab = tokenizer.get_vocab() 146 | vocab_path = save_directory / "vocab.txt" 147 | with open(vocab_path, "w", encoding="utf-8") as vocab_file: 148 | for token in sorted(vocab, key=vocab.get): 149 | vocab_file.write(f"{token}\n") 150 | 151 | # Load config.json to get tokenizer_name 152 | config_path = save_directory / "config.json" 153 | if config_path.exists(): 154 | with open(config_path, "r", encoding="utf-8") as f: 155 | config = json.load(f) 156 | else: 157 | raise FileNotFoundError(f"config.json not found in {save_directory}") 158 | 159 | tokenizer_name = config.get("tokenizer_name") 160 | if not tokenizer_name: 161 | raise ValueError("tokenizer_name not found in config.json") 162 | 163 | # Load the original tokenizer 164 | original_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 165 | 166 | # Extract special tokens and tokenizer class 167 | special_tokens = original_tokenizer.special_tokens_map 168 | tokenizer_class = original_tokenizer.__class__.__name__ 169 | 170 | # Load the tokenizer using PreTrainedTokenizerFast with special tokens 171 | fast_tokenizer = PreTrainedTokenizerFast( 172 | tokenizer_file=str(tokenizer_json_path), 173 | **special_tokens, 174 | ) 175 | 176 | # Save the tokenizer files 177 | fast_tokenizer.save_pretrained(str(save_directory)) 178 | # Modify tokenizer_config.json to set the correct tokenizer_class 179 | tokenizer_config_path = save_directory / "tokenizer_config.json" 180 | if tokenizer_config_path.exists(): 181 | with open(tokenizer_config_path, "r", encoding="utf-8") as f: 182 | tokenizer_config = json.load(f) 183 | else: 184 | raise FileNotFoundError(f"tokenizer_config.json not found in {save_directory}") 185 | 186 | # Update the tokenizer_class field 187 | tokenizer_config["tokenizer_class"] = tokenizer_class 188 | 189 | # Write the updated tokenizer_config.json back to disk 190 | with open(tokenizer_config_path, "w", encoding="utf-8") as f: 191 | json.dump(tokenizer_config, f, indent=4, sort_keys=True) 192 | 193 | 194 | if __name__ == "__main__": 195 | parser = argparse.ArgumentParser(description="Export StaticModel to ONNX format") 196 | parser.add_argument( 197 | "--model_path", 198 | type=str, 199 | required=True, 200 | help="Path to the pretrained StaticModel", 201 | ) 202 | parser.add_argument( 203 | "--save_path", 204 | type=str, 205 | required=True, 206 | help="Directory to save the exported model and files", 207 | ) 208 | args = parser.parse_args() 209 | 210 | export_model_to_onnx(args.model_path, Path(args.save_path)) 211 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, cast 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | from tokenizers import Tokenizer 9 | from tokenizers.models import BPE, Unigram, WordPiece 10 | from tokenizers.pre_tokenizers import Whitespace 11 | from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast 12 | 13 | from 
model2vec.inference import StaticModelPipeline 14 | from model2vec.train import StaticModelForClassification 15 | 16 | _TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"] 17 | 18 | 19 | @pytest.fixture(scope="session", params=_TOKENIZER_TYPES, ids=_TOKENIZER_TYPES) 20 | def mock_tokenizer(request: pytest.FixtureRequest) -> Tokenizer: 21 | """Create a mock tokenizer.""" 22 | vocab = ["[PAD]", "word1", "word2", "word3", "[UNK]"] 23 | unk_token = "[UNK]" 24 | 25 | tokenizer_type = request.param 26 | 27 | if tokenizer_type == "wordpiece": 28 | model = WordPiece( 29 | vocab={token: idx for idx, token in enumerate(vocab)}, unk_token=unk_token, max_input_chars_per_word=100 30 | ) 31 | elif tokenizer_type == "bpe": 32 | model = BPE( 33 | vocab={token: idx for idx, token in enumerate(vocab)}, 34 | merges=[], 35 | unk_token=unk_token, 36 | fuse_unk=True, 37 | ignore_merges=True, 38 | ) 39 | elif tokenizer_type == "unigram": 40 | model = Unigram(vocab=[(token, 0.0) for token in vocab], unk_id=0, byte_fallback=False) 41 | else: 42 | raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") 43 | tokenizer = Tokenizer(model) 44 | tokenizer.pre_tokenizer = Whitespace() # type: ignore # Tokenizer issue 45 | 46 | return tokenizer 47 | 48 | 49 | @pytest.fixture(scope="function") 50 | def mock_berttokenizer() -> PreTrainedTokenizerFast: 51 | """Load the real BertTokenizerFast from the provided tokenizer.json file.""" 52 | return cast(PreTrainedTokenizerFast, AutoTokenizer.from_pretrained("tests/data/test_tokenizer")) 53 | 54 | 55 | @pytest.fixture 56 | def mock_transformer() -> AutoModel: 57 | """Create a mock transformer model.""" 58 | 59 | class MockPreTrainedModel: 60 | def __init__(self) -> None: 61 | self.device = "cpu" 62 | self.name_or_path = "mock-model" 63 | 64 | def to(self, device: str) -> MockPreTrainedModel: 65 | self.device = device 66 | return self 67 | 68 | def forward(self, *args: Any, **kwargs: Any) -> Any: 69 | # Simulate a last_hidden_state output for a transformer model 70 | batch_size, seq_length = kwargs["input_ids"].shape 71 | # Return a tensor of shape (batch_size, seq_length, 768) 72 | return type( 73 | "BaseModelOutputWithPoolingAndCrossAttentions", 74 | (object,), 75 | { 76 | "last_hidden_state": torch.rand(batch_size, seq_length, 768) # Simulate 768 hidden units 77 | }, 78 | ) 79 | 80 | def __call__(self, *args: Any, **kwargs: Any) -> Any: 81 | # Simply call the forward method to simulate the same behavior as transformers models 82 | return self.forward(*args, **kwargs) 83 | 84 | return MockPreTrainedModel() 85 | 86 | 87 | @pytest.fixture(scope="session") 88 | def mock_vectors() -> np.ndarray: 89 | """Create mock vectors.""" 90 | return np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.0, 0.0], [0.0, 0.0]]) 91 | 92 | 93 | @pytest.fixture 94 | def mock_config() -> dict[str, str]: 95 | """Create a mock config.""" 96 | return {"some_config": "value"} 97 | 98 | 99 | @pytest.fixture(scope="session") 100 | def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification) -> StaticModelPipeline: 101 | """Mock pipeline.""" 102 | return mock_trained_pipeline.to_pipeline() 103 | 104 | 105 | @pytest.fixture( 106 | params=[ 107 | (False, "single_label", "str"), 108 | (False, "single_label", "int"), 109 | (True, "multilabel", "str"), 110 | (True, "multilabel", "int"), 111 | ], 112 | ids=lambda param: f"{param[1]}_{param[2]}", 113 | scope="session", 114 | ) 115 | def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification: 116 | """Mock 
StaticModelForClassification with different label formats.""" 117 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer 118 | torch.random.manual_seed(42) 119 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) 120 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") 121 | 122 | X = ["dog", "cat"] 123 | is_multilabel, label_type = request.param[0], request.param[2] 124 | 125 | if label_type == "str": 126 | y = [["a", "b"], ["a"]] if is_multilabel else ["a", "b"] # type: ignore 127 | else: 128 | y = [[0, 1], [0]] if is_multilabel else [0, 1] # type: ignore 129 | 130 | model.fit(X, y) 131 | 132 | return model 133 | -------------------------------------------------------------------------------- /tests/data/test_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /tests/data/test_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "added_tokens_decoder": { 3 | "0": { 4 | "content": "[PAD]", 5 | "lstrip": false, 6 | "normalized": false, 7 | "rstrip": false, 8 | "single_word": false, 9 | "special": true 10 | }, 11 | "100": { 12 | "content": "[UNK]", 13 | "lstrip": false, 14 | "normalized": false, 15 | "rstrip": false, 16 | "single_word": false, 17 | "special": true 18 | }, 19 | "101": { 20 | "content": "[CLS]", 21 | "lstrip": false, 22 | "normalized": false, 23 | "rstrip": false, 24 | "single_word": false, 25 | "special": true 26 | }, 27 | "102": { 28 | "content": "[SEP]", 29 | "lstrip": false, 30 | "normalized": false, 31 | "rstrip": false, 32 | "single_word": false, 33 | "special": true 34 | }, 35 | "103": { 36 | "content": "[MASK]", 37 | "lstrip": false, 38 | "normalized": false, 39 | "rstrip": false, 40 | "single_word": false, 41 | "special": true 42 | } 43 | }, 44 | "clean_up_tokenization_spaces": true, 45 | "cls_token": "[CLS]", 46 | "do_basic_tokenize": true, 47 | "do_lower_case": true, 48 | "mask_token": "[MASK]", 49 | "model_max_length": 1000000000000000019884624838656, 50 | "never_split": null, 51 | "pad_token": "[PAD]", 52 | "sep_token": "[SEP]", 53 | "strip_accents": null, 54 | "tokenize_chinese_chars": true, 55 | "tokenizer_class": "BertTokenizer", 56 | "unk_token": "[UNK]" 57 | } 58 | -------------------------------------------------------------------------------- /tests/test_distillation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from importlib import import_module 5 | from unittest.mock import MagicMock, patch 6 | 7 | import numpy as np 8 | import pytest 9 | from pytest import LogCaptureFixture 10 | from transformers import AutoModel, BertTokenizerFast 11 | 12 | from model2vec.distill.distillation import ( 13 | clean_and_create_vocabulary, 14 | distill, 15 | distill_from_model, 16 | post_process_embeddings, 17 | ) 18 | from model2vec.model import StaticModel 19 | 20 | try: 21 | # For huggingface_hub>=0.25.0 22 | from huggingface_hub.errors import RepositoryNotFoundError 23 | except ImportError: 24 | # For huggingface_hub<0.25.0 25 | from huggingface_hub.utils._errors import RepositoryNotFoundError # type: ignore 26 
| 27 | rng = np.random.default_rng() 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "vocabulary, pca_dims, apply_zipf", 32 | [ 33 | (None, 256, True), # Output vocab with subwords, PCA applied 34 | (["wordA", "wordB"], 4, False), # Custom vocab with subword, PCA applied 35 | (None, "auto", False), # Subword, PCA set to 'auto' 36 | (None, 1024, False), # Subword, PCA set to high number. 37 | (None, None, True), # No PCA applied 38 | (None, 0.9, True), # PCA as float applied 39 | ], 40 | ) 41 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") 42 | @patch("transformers.AutoModel.from_pretrained") 43 | def test_distill_from_model( 44 | mock_auto_model: MagicMock, 45 | mock_model_info: MagicMock, 46 | mock_berttokenizer: BertTokenizerFast, 47 | mock_transformer: AutoModel, 48 | vocabulary: list[str] | None, 49 | pca_dims: int | None, 50 | apply_zipf: bool, 51 | ) -> None: 52 | """Test distill function with different parameters.""" 53 | # Mock the return value of model_info to avoid calling the Hugging Face API 54 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) 55 | 56 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances 57 | # mock_auto_tokenizer.return_value = mock_berttokenizer 58 | mock_auto_model.return_value = mock_transformer 59 | 60 | # Call the distill function with the parametrized inputs 61 | static_model = distill_from_model( 62 | model=mock_transformer, 63 | tokenizer=mock_berttokenizer, 64 | vocabulary=vocabulary, 65 | device="cpu", 66 | pca_dims=pca_dims, 67 | apply_zipf=apply_zipf, 68 | token_remove_pattern=None, 69 | ) 70 | 71 | static_model2 = distill( 72 | model_name="tests/data/test_tokenizer", 73 | vocabulary=vocabulary, 74 | device="cpu", 75 | pca_dims=pca_dims, 76 | apply_zipf=apply_zipf, 77 | token_remove_pattern=None, 78 | ) 79 | 80 | assert static_model.embedding.shape == static_model2.embedding.shape 81 | assert static_model.config == static_model2.config 82 | assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str()) 83 | assert static_model.base_model_name == static_model2.base_model_name 84 | 85 | 86 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") 87 | @patch("transformers.AutoModel.from_pretrained") 88 | def test_distill_removal_pattern( 89 | mock_auto_model: MagicMock, 90 | mock_model_info: MagicMock, 91 | mock_berttokenizer: BertTokenizerFast, 92 | mock_transformer: AutoModel, 93 | ) -> None: 94 | """Test the removal pattern.""" 95 | # Mock the return value of model_info to avoid calling the Hugging Face API 96 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) 97 | 98 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances 99 | # mock_auto_tokenizer.return_value = mock_berttokenizer 100 | mock_auto_model.return_value = mock_transformer 101 | 102 | # The vocab size is 30522, but we remove 998 tokens: [CLS], [SEP], and [MASK], and all [unused] tokens. 
103 | expected_vocab_size = mock_berttokenizer.vocab_size - 998 104 | 105 | static_model = distill_from_model( 106 | model=mock_transformer, 107 | tokenizer=mock_berttokenizer, 108 | vocabulary=None, 109 | device="cpu", 110 | token_remove_pattern=None, 111 | ) 112 | 113 | assert len(static_model.embedding) == expected_vocab_size 114 | 115 | # No tokens removed, nonsensical pattern 116 | static_model = distill_from_model( 117 | model=mock_transformer, 118 | tokenizer=mock_berttokenizer, 119 | vocabulary=None, 120 | device="cpu", 121 | token_remove_pattern="£££££££££££££££££", 122 | ) 123 | 124 | assert len(static_model.embedding) == expected_vocab_size 125 | 126 | # Weird pattern. 127 | with pytest.raises(ValueError): 128 | static_model = distill_from_model( 129 | model=mock_transformer, 130 | tokenizer=mock_berttokenizer, 131 | vocabulary=None, 132 | device="cpu", 133 | token_remove_pattern="[...papapa", 134 | ) 135 | 136 | 137 | @pytest.mark.parametrize( 138 | "vocabulary, pca_dims, apply_zipf, sif_coefficient, expected_shape", 139 | [ 140 | (None, 256, True, None, (29524, 256)), # Output vocab with subwords, PCA applied 141 | (None, "auto", False, None, (29524, 768)), # Subword, PCA set to 'auto' 142 | (None, "auto", True, 1e-4, (29524, 768)), # Subword, PCA set to 'auto' 143 | (None, "auto", False, 1e-4, (29524, 768)), # Subword, PCA set to 'auto' 144 | (None, "auto", True, 0, None), # Sif too low 145 | (None, "auto", True, 1, None), # Sif too high 146 | (None, "auto", False, 0, (29524, 768)), # Sif too low, but apply_zipf is False 147 | (None, "auto", False, 1, (29524, 768)), # Sif too high, but apply_zipf is False 148 | (None, 1024, False, None, (29524, 768)), # Subword, PCA set to high number. 149 | (["wordA", "wordB"], 4, False, None, (29526, 4)), # Custom vocab with subword, PCA applied 150 | (None, None, True, None, (29524, 768)), # No PCA applied 151 | ], 152 | ) 153 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") 154 | @patch("transformers.AutoModel.from_pretrained") 155 | def test_distill( 156 | mock_auto_model: MagicMock, 157 | mock_model_info: MagicMock, 158 | mock_transformer: AutoModel, 159 | vocabulary: list[str] | None, 160 | pca_dims: int | None, 161 | apply_zipf: bool, 162 | sif_coefficient: float | None, 163 | expected_shape: tuple[int, int], 164 | ) -> None: 165 | """Test distill function with different parameters.""" 166 | # Mock the return value of model_info to avoid calling the Hugging Face API 167 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) 168 | 169 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances 170 | mock_auto_model.return_value = mock_transformer 171 | 172 | model_name = "tests/data/test_tokenizer" 173 | 174 | if ( 175 | apply_zipf is not None 176 | and apply_zipf 177 | and sif_coefficient is not None 178 | and (sif_coefficient <= 0 or sif_coefficient >= 1) 179 | ): 180 | with pytest.raises(ValueError): 181 | static_model = distill( 182 | model_name=model_name, 183 | vocabulary=vocabulary, 184 | device="cpu", 185 | pca_dims=pca_dims, 186 | apply_zipf=apply_zipf, 187 | sif_coefficient=sif_coefficient, 188 | ) 189 | 190 | else: 191 | # Call the distill function with the parametrized inputs 192 | static_model = distill( 193 | model_name=model_name, 194 | vocabulary=vocabulary, 195 | device="cpu", 196 | pca_dims=pca_dims, 197 | apply_zipf=apply_zipf, 198 | sif_coefficient=sif_coefficient, 199 | ) 200 | 201 | # Assert the model is 
correctly generated 202 | assert isinstance(static_model, StaticModel) 203 | assert static_model.embedding.shape == expected_shape 204 | assert "mock-model" in static_model.config["tokenizer_name"] 205 | assert static_model.tokenizer is not None 206 | 207 | 208 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") 209 | def test_missing_modelinfo( 210 | mock_model_info: MagicMock, 211 | mock_transformer: AutoModel, 212 | mock_berttokenizer: BertTokenizerFast, 213 | ) -> None: 214 | """Test that missing model info does not crash.""" 215 | mock_model_info.side_effect = RepositoryNotFoundError("Model not found") 216 | static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu") 217 | assert static_model.language is None 218 | 219 | 220 | @pytest.mark.parametrize( 221 | "embeddings, pca_dims, sif_coefficient, expected_shape", 222 | [ 223 | (rng.random((1000, 768)), 256, None, (1000, 256)), # PCA applied correctly 224 | (rng.random((1000, 768)), None, None, (1000, 768)), # No PCA applied, dimensions remain unchanged 225 | (rng.random((1000, 768)), 256, 1e-4, (1000, 256)), # PCA and Zipf applied 226 | (rng.random((10, 768)), 256, 1e-4, (10, 768)), # PCA dims higher than vocab size, no PCA applied 227 | ], 228 | ) 229 | def test__post_process_embeddings( 230 | embeddings: np.ndarray, pca_dims: int, sif_coefficient: float | None, expected_shape: tuple[int, int] 231 | ) -> None: 232 | """Test the _post_process_embeddings function.""" 233 | original_embeddings = embeddings.copy() # Copy embeddings to compare later 234 | 235 | # Test that the function raises an error if the PCA dims are larger than the number of dimensions 236 | if pca_dims and pca_dims > embeddings.shape[1]: 237 | with pytest.raises(ValueError): 238 | post_process_embeddings(embeddings, pca_dims, None) 239 | 240 | processed_embeddings = post_process_embeddings(embeddings, pca_dims, sif_coefficient) 241 | 242 | # Assert the shape is correct 243 | assert processed_embeddings.shape == expected_shape 244 | 245 | # If Zipf weighting is applied compare the original and processed embeddings 246 | # and check the weights are applied correctly 247 | if sif_coefficient and pca_dims is None: 248 | inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2)) 249 | proba = inv_rank / np.sum(inv_rank) 250 | sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None] 251 | 252 | expected_zipf_embeddings = original_embeddings * sif_weights 253 | assert np.allclose( 254 | processed_embeddings, expected_zipf_embeddings, rtol=1e-5 255 | ), "Zipf weighting not applied correctly" 256 | 257 | 258 | @pytest.mark.parametrize( 259 | "added_tokens, expected_output, expected_warnings", 260 | [ 261 | # Case: duplicates ("2010", "government") and an empty token ("") 262 | (["2010", "government", "nerv", ""], ["nerv"], ["Removed", "duplicate", "empty"]), 263 | # Case: No duplicates, no empty tokens 264 | (["worda", "wordb", "wordc"], ["worda", "wordb", "wordc"], []), 265 | # Case: Only empty token (""), should return an empty list 266 | ([""], [], ["Removed", "empty"]), 267 | ], 268 | ) 269 | def test_clean_and_create_vocabulary( 270 | mock_berttokenizer: BertTokenizerFast, 271 | added_tokens: list[str], 272 | expected_output: list[str], 273 | expected_warnings: list[str], 274 | caplog: LogCaptureFixture, 275 | ) -> None: 276 | """Test the _clean_vocabulary function.""" 277 | with caplog.at_level("WARNING"): 278 | tokens, _ = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None) 
279 | 280 | cleaned_vocab = [token.form for token in tokens if not token.is_internal] 281 | # Check the cleaned vocabulary matches the expected output 282 | assert cleaned_vocab == expected_output 283 | 284 | # Check the warnings were logged as expected 285 | logged_warnings = [record.message for record in caplog.records] 286 | 287 | # Ensure the expected warnings contain expected keywords like 'Removed', 'duplicate', or 'empty' 288 | for expected_warning in expected_warnings: 289 | assert any(expected_warning in logged_warning for logged_warning in logged_warnings) 290 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from tempfile import TemporaryDirectory 4 | from unittest.mock import patch 5 | 6 | import pytest 7 | 8 | from model2vec.inference import StaticModelPipeline 9 | 10 | 11 | def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None: 12 | """Test successful init and predict with StaticModelPipeline.""" 13 | target: list[str] | list[list[str]] 14 | if mock_inference_pipeline.multilabel: 15 | if isinstance(mock_inference_pipeline.classes_[0], str): 16 | target = [["a", "b"]] 17 | else: 18 | target = [[0, 1]] # type: ignore 19 | else: 20 | if isinstance(mock_inference_pipeline.classes_[0], str): 21 | target = ["b"] 22 | else: 23 | target = [1] # type: ignore 24 | assert mock_inference_pipeline.predict("dog").tolist() == target 25 | assert mock_inference_pipeline.predict(["dog"]).tolist() == target 26 | 27 | 28 | def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None: 29 | """Test successful init and predict_proba with StaticModelPipeline.""" 30 | assert mock_inference_pipeline.predict_proba("dog").argmax() == 1 31 | assert mock_inference_pipeline.predict_proba(["dog"]).argmax(1).tolist() == [1] 32 | 33 | 34 | def test_init_evaluate(mock_inference_pipeline: StaticModelPipeline) -> None: 35 | """Test successful init and evaluate with StaticModelPipeline.""" 36 | target: list[str] | list[list[str]] 37 | if mock_inference_pipeline.multilabel: 38 | if isinstance(mock_inference_pipeline.classes_[0], str): 39 | target = [["a", "b"]] 40 | else: 41 | target = [[0, 1]] # type: ignore 42 | else: 43 | if isinstance(mock_inference_pipeline.classes_[0], str): 44 | target = ["b"] 45 | else: 46 | target = [1] # type: ignore 47 | mock_inference_pipeline.evaluate("dog", target) # type: ignore 48 | 49 | 50 | def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None: 51 | """Test saving and loading the pipeline.""" 52 | with TemporaryDirectory() as temp_dir: 53 | mock_inference_pipeline.save_pretrained(temp_dir) 54 | loaded = StaticModelPipeline.from_pretrained(temp_dir) 55 | target: list[str] | list[list[str]] 56 | if mock_inference_pipeline.multilabel: 57 | if isinstance(mock_inference_pipeline.classes_[0], str): 58 | target = [["a", "b"]] 59 | else: 60 | target = [[0, 1]] # type: ignore 61 | else: 62 | if isinstance(mock_inference_pipeline.classes_[0], str): 63 | target = ["b"] 64 | else: 65 | target = [1] # type: ignore 66 | assert loaded.predict("dog").tolist() == target 67 | assert loaded.predict(["dog"]).tolist() == target 68 | assert loaded.predict_proba("dog").argmax() == 1 69 | assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1] 70 | 71 | 72 | @patch("model2vec.inference.model._DEFAULT_TRUST_PATTERN", re.compile("torch")) 73 | def 
test_roundtrip_save_mock_trust_pattern(mock_inference_pipeline: StaticModelPipeline) -> None: 74 | """Test saving and loading the pipeline.""" 75 | with TemporaryDirectory() as temp_dir: 76 | mock_inference_pipeline.save_pretrained(temp_dir) 77 | with pytest.raises(ValueError): 78 | StaticModelPipeline.from_pretrained(temp_dir) 79 | 80 | 81 | def test_roundtrip_save_file_gone(mock_inference_pipeline: StaticModelPipeline) -> None: 82 | """Test saving and loading the pipeline.""" 83 | with TemporaryDirectory() as temp_dir: 84 | mock_inference_pipeline.save_pretrained(temp_dir) 85 | # Rename the file to abc.pipeline, so that it looks like it was downloaded from the hub 86 | os.unlink(os.path.join(temp_dir, "pipeline.skops")) 87 | with pytest.raises(FileNotFoundError): 88 | StaticModelPipeline.from_pretrained(temp_dir) 89 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tempfile import TemporaryDirectory 3 | 4 | import numpy as np 5 | import pytest 6 | import safetensors 7 | from tokenizers import Tokenizer 8 | 9 | from model2vec import StaticModel 10 | 11 | 12 | def test_initialization(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 13 | """Test successful initialization of StaticModel.""" 14 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 15 | assert model.embedding.shape == (5, 2) 16 | assert len(model.tokens) == 5 17 | assert model.tokenizer == mock_tokenizer 18 | assert model.config == mock_config 19 | 20 | 21 | def test_initialization_token_vector_mismatch(mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 22 | """Test if error is raised when number of tokens and vectors don't match.""" 23 | mock_vectors = np.array([[0.1, 0.2], [0.2, 0.3]]) 24 | with pytest.raises(ValueError): 25 | StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 26 | 27 | 28 | def test_tokenize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 29 | """Test tokenization of a sentence.""" 30 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 31 | model._can_encode_fast = True 32 | tokens_fast = model.tokenize(["word1 word2"]) 33 | model._can_encode_fast = False 34 | tokens_slow = model.tokenize(["word1 word2"]) 35 | 36 | assert tokens_fast == tokens_slow 37 | 38 | 39 | def test_encode_batch_fast( 40 | mock_vectors: np.ndarray, mock_berttokenizer: Tokenizer, mock_config: dict[str, str] 41 | ) -> None: 42 | """Test tokenization of a sentence.""" 43 | if hasattr(mock_berttokenizer, "encode_batch_fast"): 44 | del mock_berttokenizer.encode_batch_fast 45 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_berttokenizer, config=mock_config) 46 | assert not model._can_encode_fast 47 | 48 | 49 | def test_encode_single_sentence( 50 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 51 | ) -> None: 52 | """Test encoding of a single sentence.""" 53 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 54 | encoded = model.encode("word1 word2") 55 | assert encoded.shape == (2,) 56 | 57 | 58 | def test_encode_single_sentence_empty( 59 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 60 | ) -> None: 61 | """Test encoding of a single empty sentence.""" 
62 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 63 | model.normalize = True 64 | encoded = model.encode("") 65 | assert not np.isnan(encoded).any() 66 | assert np.all(encoded == 0) 67 | 68 | 69 | def test_encode_multiple_sentences( 70 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 71 | ) -> None: 72 | """Test encoding of multiple sentences.""" 73 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 74 | encoded = model.encode(["word1 word2", "word1 word3"]) 75 | assert encoded.shape == (2, 2) 76 | 77 | 78 | def test_encode_as_sequence(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 79 | """Test encoding of sentences as tokens.""" 80 | sentences = ["word1 word2", "word1 word3"] 81 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 82 | encoded_sequence = model.encode_as_sequence(sentences) 83 | encoded = model.encode(sentences) 84 | 85 | assert len(encoded_sequence) == 2 86 | 87 | means = [np.mean(sequence, axis=0) for sequence in encoded_sequence] 88 | assert np.allclose(means, encoded) 89 | 90 | 91 | def test_encode_multiprocessing( 92 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 93 | ) -> None: 94 | """Test encoding with multiprocessing.""" 95 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 96 | # Generate a list of 15k inputs to test multiprocessing 97 | sentences = ["word1 word2"] * 15_000 98 | encoded = model.encode(sentences, use_multiprocessing=True) 99 | assert encoded.shape == (15000, 2) 100 | 101 | 102 | def test_encode_as_sequence_multiprocessing( 103 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 104 | ) -> None: 105 | """Test encoding of sentences as tokens with multiprocessing.""" 106 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 107 | # Generate a list of 15k inputs to test multiprocessing 108 | sentences = ["word1 word2"] * 15_000 109 | encoded = model.encode_as_sequence(sentences, use_multiprocessing=True) 110 | assert len(encoded) == 15_000 111 | 112 | 113 | def test_encode_as_tokens_empty( 114 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 115 | ) -> None: 116 | """Test encoding of empty sentences as sequences.""" 117 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 118 | encoded = model.encode_as_sequence("") 119 | assert np.array_equal(encoded, np.zeros(shape=(0, 2), dtype=model.embedding.dtype)) 120 | 121 | encoded = model.encode_as_sequence(["", ""]) 122 | out = [np.zeros(shape=(0, 2), dtype=model.embedding.dtype) for _ in range(2)] 123 | assert all(np.array_equal(x, y) for x, y in zip(encoded, out)) 124 | 125 | 126 | def test_encode_empty_sentence( 127 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 128 | ) -> None: 129 | """Test encoding with an empty sentence.""" 130 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 131 | encoded = model.encode("") 132 | assert np.array_equal(encoded, np.zeros((2,))) 133 | 134 | 135 | def test_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 136 | """Test normalization of vectors.""" 137 | s = "word1 word2 word3" 138 | model = StaticModel(vectors=mock_vectors,
tokenizer=mock_tokenizer, config=mock_config, normalize=False) 139 | X = model.encode(s) 140 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=True) 141 | normalized = model.encode(s) 142 | 143 | expected = X / np.linalg.norm(X) 144 | 145 | np.testing.assert_almost_equal(normalized, expected) 146 | 147 | 148 | def test_save_pretrained( 149 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 150 | ) -> None: 151 | """Test saving a pretrained model.""" 152 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 153 | 154 | # Save the model to the tmp_path 155 | save_path = tmp_path / "saved_model" 156 | model.save_pretrained(save_path) 157 | 158 | # Check that the save_path directory contains the saved files 159 | assert save_path.exists() 160 | 161 | assert (save_path / "model.safetensors").exists() 162 | assert (save_path / "tokenizer.json").exists() 163 | assert (save_path / "config.json").exists() 164 | assert (save_path / "modules.json").exists() 165 | 166 | 167 | def test_load_pretrained( 168 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 169 | ) -> None: 170 | """Test loading a pretrained model after saving it.""" 171 | # Save the model to a temporary path 172 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 173 | save_path = tmp_path / "saved_model" 174 | model.save_pretrained(save_path) 175 | 176 | # Load the model back from the same path 177 | loaded_model = StaticModel.from_pretrained(save_path) 178 | 179 | # Assert that the loaded model has the same properties as the original one 180 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors) 181 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab() 182 | assert loaded_model.config == mock_config 183 | 184 | 185 | def test_load_pretrained_quantized( 186 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str] 187 | ) -> None: 188 | """Test loading a pretrained model after saving it.""" 189 | # Save the model to a temporary path 190 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 191 | save_path = tmp_path / "saved_model" 192 | model.save_pretrained(save_path) 193 | 194 | # Load the model back from the same path 195 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="int8") 196 | 197 | # Assert that the loaded model has the same properties as the original one 198 | assert loaded_model.embedding.dtype == np.int8 199 | assert loaded_model.embedding.shape == mock_vectors.shape 200 | 201 | # Load the model back from the same path 202 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16") 203 | 204 | # Assert that the loaded model has the same properties as the original one 205 | assert loaded_model.embedding.dtype == np.float16 206 | assert loaded_model.embedding.shape == mock_vectors.shape 207 | 208 | # Load the model back from the same path 209 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32") 210 | # Assert that the loaded model has the same properties as the original one 211 | assert loaded_model.embedding.dtype == np.float32 212 | assert loaded_model.embedding.shape == mock_vectors.shape 213 | 214 | 215 | def test_load_pretrained_dim( 216 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: 
dict[str, str] 217 | ) -> None: 218 | """Test loading a pretrained model with dimensionality.""" 219 | # Save the model to a temporary path 220 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config) 221 | save_path = tmp_path / "saved_model" 222 | model.save_pretrained(save_path) 223 | 224 | loaded_model = StaticModel.from_pretrained(save_path, dimensionality=2) 225 | 226 | # Assert that the loaded model has the same properties as the original one 227 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2]) 228 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab() 229 | assert loaded_model.config == mock_config 230 | 231 | # Load the model back from the same path 232 | loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None) 233 | 234 | # Assert that the loaded model has the same properties as the original one 235 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors) 236 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab() 237 | assert loaded_model.config == mock_config 238 | 239 | # Load the model back from the same path 240 | with pytest.raises(ValueError): 241 | StaticModel.from_pretrained(save_path, dimensionality=3000) 242 | 243 | 244 | def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 245 | """Tests whether the normalization initialization is correct.""" 246 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None) 247 | assert not model.normalize 248 | 249 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=False) 250 | assert not model.normalize 251 | 252 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True) 253 | assert model.normalize 254 | 255 | model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": False}, normalize=True) 256 | assert model.normalize 257 | 258 | model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": True}, normalize=False) 259 | assert not model.normalize 260 | 261 | 262 | def test_set_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 263 | """Tests whether the normalize is set correctly.""" 264 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True) 265 | model.normalize = False 266 | assert model.config == {"normalize": False} 267 | model.normalize = True 268 | assert model.config == {"normalize": True} 269 | 270 | 271 | def test_dim(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None: 272 | """Tests the dimensionality of the model.""" 273 | model = StaticModel(mock_vectors, mock_tokenizer, mock_config) 274 | assert model.dim == 2 275 | assert model.dim == model.embedding.shape[1] 276 | 277 | 278 | def test_local_load_from_model(mock_tokenizer: Tokenizer) -> None: 279 | """Test local load from a model.""" 280 | x = np.ones((mock_tokenizer.get_vocab_size(), 2)) 281 | with TemporaryDirectory() as tempdir: 282 | tempdir_path = Path(tempdir) 283 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors") 284 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json")) 285 | 286 | model = StaticModel.load_local(tempdir_path) 287 | assert model.embedding.shape == x.shape 288 | assert model.tokenizer.to_str() == mock_tokenizer.to_str() 289 | assert model.config == {"normalize": False} 290 | 291 | 292 | def test_local_load_from_model_no_folder() -> None: 293 | """Test local load from a model with no folder.""" 294 | with 
pytest.raises(ValueError): 295 | StaticModel.load_local("woahbuddy_relax_this_is_just_a_test") 296 | -------------------------------------------------------------------------------- /tests/test_quantization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from model2vec.quantization import DType, quantize_embeddings 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "input_dtype,target_dtype,expected_dtype", 9 | [ 10 | (np.float32, DType.Float16, np.float16), 11 | (np.float16, DType.Float32, np.float32), 12 | (np.float32, DType.Float64, np.float64), 13 | (np.float32, DType.Int8, np.int8), 14 | ], 15 | ) 16 | def test_quantize_embeddings(input_dtype: DType, target_dtype: DType, expected_dtype: DType) -> None: 17 | """Test quantization to different dtypes.""" 18 | embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=input_dtype) 19 | # Use negative values for int8 test case 20 | if target_dtype == DType.Int8: 21 | embeddings = np.array([[-1.0, 2.0], [-3.0, 4.0]], dtype=input_dtype) 22 | 23 | quantized = quantize_embeddings(embeddings, target_dtype) 24 | assert quantized.dtype == expected_dtype 25 | 26 | if target_dtype == DType.Int8: 27 | # Check if the values are in the range [-127, 127] 28 | assert np.all(quantized >= -127) and np.all(quantized <= 127) 29 | else: 30 | assert np.allclose(quantized, embeddings.astype(expected_dtype)) 31 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | from transformers import PreTrainedTokenizerFast 5 | 6 | from model2vec.tokenizer.model import _calculate_token_weight_for_unigram, _process_unigram, process_tokenizer 7 | from model2vec.tokenizer.normalizer import replace_normalizer 8 | from model2vec.tokenizer.pretokenizer import _FORBIDDEN_PRETOKENIZERS, _fix_single_pretokenizer, replace_pretokenizer 9 | from model2vec.tokenizer.tokenizer import _rename_added_token, create_tokenizer 10 | 11 | 12 | def test_fix_single_pretokenizer() -> None: 13 | """Test the _fix_single_pretokenizer function.""" 14 | result = _fix_single_pretokenizer({"type": "ByteLevel", "add_prefix_space": False, "use_regex": True}) 15 | assert result == {"type": "ByteLevel", "add_prefix_space": True, "use_regex": False} 16 | 17 | for tokenizer_type in _FORBIDDEN_PRETOKENIZERS: 18 | result = _fix_single_pretokenizer({"type": tokenizer_type}) 19 | assert result is None 20 | 21 | result = _fix_single_pretokenizer( 22 | {"type": "Metaspace", "split": True, "prepend_scheme": "never", "replacement": "▁"} 23 | ) 24 | assert result == {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False} 25 | 26 | 27 | def test_replace_pretokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None: 28 | """Test the replace_pretokenizer function.""" 29 | tokenizer = replace_pretokenizer(mock_berttokenizer.backend_tokenizer) 30 | assert tokenizer.pre_tokenizer is not None 31 | assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace" 32 | assert tokenizer.pre_tokenizer.replacement == "▁" 33 | assert tokenizer.pre_tokenizer.prepend_scheme == "always" 34 | assert not tokenizer.pre_tokenizer.split 35 | 36 | tokenizer.pre_tokenizer = None # type: ignore 37 | tokenizer = replace_pretokenizer(tokenizer) 38 | assert tokenizer.pre_tokenizer is not None 39 | assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace" 40 | 
assert tokenizer.pre_tokenizer.replacement == "▁" 41 | assert tokenizer.pre_tokenizer.prepend_scheme == "always" 42 | assert tokenizer.pre_tokenizer.split is False 43 | 44 | 45 | def test_replace_normalizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None: 46 | """Test the replace_normalizer function.""" 47 | tokenizer = replace_normalizer(mock_berttokenizer.backend_tokenizer) 48 | assert tokenizer.normalizer is not None 49 | assert tokenizer.normalizer.__class__.__name__ == "Sequence" 50 | 51 | assert tokenizer.normalizer.normalize_str("Hello, World!") == "hello , world !" 52 | 53 | tokenizer.normalizer = None # type: ignore 54 | tokenizer = replace_normalizer(tokenizer) 55 | assert tokenizer.normalizer.normalize_str("Hello, World!") == "Hello , World !" 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "word,weight", 60 | [ 61 | ("dog", 3), 62 | ("cat", 3), 63 | ("▁longer▁word", 14), 64 | ("▁word", 6), 65 | ("▁", 2), # Single underscore 66 | ("", 0), # Empty string 67 | ("▁a" * 100, 300), # Long word with underscores 68 | ], 69 | ) 70 | def test_calculate_token_weight_for_unigram(word: str, weight: int) -> None: 71 | """Test the _calculate_token_weight_for_unigram function.""" 72 | assert _calculate_token_weight_for_unigram(word) == weight 73 | 74 | 75 | def test_process_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None: 76 | """Test the process_tokenizer function.""" 77 | vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"] 78 | tokenizer_json = json.loads(mock_berttokenizer.backend_tokenizer.to_str()) 79 | tokenizer_json = process_tokenizer(tokenizer_json=tokenizer_json, pre_tokenized_tokens=vocab, unk_token="[UNK]") 80 | 81 | assert tokenizer_json["model"]["type"] == "Unigram" 82 | assert tokenizer_json["model"]["unk_id"] == 5 # Index of "[UNK]" 83 | assert len(tokenizer_json["model"]["vocab"]) == 6 84 | assert all(isinstance(token, tuple) and len(token) == 2 for token in tokenizer_json["model"]["vocab"]) 85 | for (x, _), y in zip(tokenizer_json["model"]["vocab"], vocab): 86 | assert x == y, f"Expected {y}, but got {x}" 87 | 88 | 89 | def test_process_unigram() -> None: 90 | """Test the _process_unigram function.""" 91 | vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"] 92 | orig_vocab = [("dog", 0), ("cat", 0)] 93 | model = {"model": {"type": "Unigram", "vocab": orig_vocab}} 94 | processed_model = _process_unigram(model, vocab, "[UNK]") 95 | assert processed_model["model"]["type"] == "Unigram" 96 | assert processed_model["model"]["unk_id"] == 5 # Index of "[UNK]" 97 | assert len(processed_model["model"]["vocab"]) == 6 98 | assert all(isinstance(token, list) and len(token) == 2 for token in processed_model["model"]["vocab"]) 99 | 100 | for (x, score), y in zip(processed_model["model"]["vocab"], vocab): 101 | assert x == y, f"Expected {y}, but got {x}" 102 | if x in orig_vocab: 103 | assert score == 0 104 | 105 | assert process_tokenizer(model, vocab, "[UNK]") == processed_model 106 | 107 | 108 | def test_rename_added_token() -> None: 109 | """Test the _rename_added_token function.""" 110 | # Invalid input 111 | result = _rename_added_token(None, "a", [{"content": "a", "id": 0}], ["a"]) 112 | assert result == [{"content": "a", "id": 0}] 113 | 114 | # Rename 'a' to 'c' 115 | result = _rename_added_token("a", "c", [{"content": "a"}], ["a"]) 116 | assert result == [{"content": "c", "id": 0}] 117 | 118 | 119 | def test_create_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None: 120 | """Test the create_tokenizer function.""" 121 | 
tokenizer = create_tokenizer(tokenizer=mock_berttokenizer, vocabulary=["dog", "catssssss"], token_remove_regex=None) 122 | assert tokenizer.backend_tokenizer.get_vocab_size() == 29525 123 | assert tokenizer.encode("catssssss") == [29524] 124 | -------------------------------------------------------------------------------- /tests/test_trainable.py: -------------------------------------------------------------------------------- 1 | from tempfile import TemporaryDirectory 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from tokenizers import Tokenizer 7 | from transformers import AutoTokenizer 8 | 9 | from model2vec.model import StaticModel 10 | from model2vec.train import StaticModelForClassification 11 | from model2vec.train.base import FinetunableStaticModel, TextDataset 12 | 13 | 14 | @pytest.mark.parametrize("n_layers", [0, 1, 2, 3]) 15 | def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 16 | """Test successful initialization of StaticModelForClassification.""" 17 | vectors_torched = torch.from_numpy(mock_vectors) 18 | s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers) 19 | assert s.vectors.shape == mock_vectors.shape 20 | assert s.w.shape[0] == mock_vectors.shape[0] 21 | assert list(s.classes) == s.classes_ 22 | assert list(s.classes) == ["0", "1"] 23 | 24 | head = s.construct_head() 25 | assert head[0].in_features == mock_vectors.shape[1] 26 | head = s.construct_head() 27 | assert head[0].in_features == mock_vectors.shape[1] 28 | assert head[-1].out_features == 2 29 | 30 | 31 | def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 32 | """Test successful initialization of the base class.""" 33 | vectors_torched = torch.from_numpy(mock_vectors) 34 | s = FinetunableStaticModel(vectors=vectors_torched, tokenizer=mock_tokenizer) 35 | assert s.vectors.shape == mock_vectors.shape 36 | assert s.w.shape[0] == mock_vectors.shape[0] 37 | 38 | head = s.construct_head() 39 | assert head[0].in_features == mock_vectors.shape[1] 40 | 41 | 42 | def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 43 | """Test initialization from a static model.""" 44 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) 45 | s = FinetunableStaticModel.from_static_model(model=model) 46 | assert s.vectors.shape == mock_vectors.shape 47 | assert s.w.shape[0] == mock_vectors.shape[0] 48 | 49 | with TemporaryDirectory() as temp_dir: 50 | model.save_pretrained(temp_dir) 51 | s = FinetunableStaticModel.from_pretrained(model_name=temp_dir) 52 | assert s.vectors.shape == mock_vectors.shape 53 | assert s.w.shape[0] == mock_vectors.shape[0] 54 | 55 | 56 | def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None: 57 | """Test initialization from a static model.""" 58 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer) 59 | s = StaticModelForClassification.from_static_model(model=model) 60 | assert s.vectors.shape == mock_vectors.shape 61 | assert s.w.shape[0] == mock_vectors.shape[0] 62 | 63 | with TemporaryDirectory() as temp_dir: 64 | model.save_pretrained(temp_dir) 65 | s = StaticModelForClassification.from_pretrained(model_name=temp_dir) 66 | assert s.vectors.shape == mock_vectors.shape 67 | assert s.w.shape[0] == mock_vectors.shape[0] 68 | 69 | 70 | def test_encode(mock_trained_pipeline: StaticModelForClassification) -> None: 71 | """Test the encode function.""" 72 |
result = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()) 73 | assert result.shape == (2, 12) 74 | assert torch.allclose(result[0], result[1]) 75 | 76 | 77 | def test_tokenize(mock_trained_pipeline: StaticModelForClassification) -> None: 78 | """Test the tokenize function.""" 79 | result = mock_trained_pipeline.tokenize(["dog dog", "cat"]) 80 | assert result.shape == torch.Size([2, 2]) 81 | assert result[1, 1] == 0 82 | 83 | 84 | def test_device(mock_trained_pipeline: StaticModelForClassification) -> None: 85 | """Get the device.""" 86 | assert mock_trained_pipeline.device == torch.device(type="cpu") # type: ignore # False positive 87 | assert mock_trained_pipeline.device == mock_trained_pipeline.w.device 88 | 89 | 90 | def test_conversion(mock_trained_pipeline: StaticModelForClassification) -> None: 91 | """Test the conversion to numpy.""" 92 | staticmodel = mock_trained_pipeline.to_static_model() 93 | with torch.no_grad(): 94 | result_1 = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()).numpy() 95 | result_2 = staticmodel.embedding[[[0, 1], [1, 0]]].mean(0) 96 | result_2 /= np.linalg.norm(result_2, axis=1, keepdims=True) 97 | 98 | assert np.allclose(result_1, result_2) 99 | 100 | 101 | def test_textdataset_init() -> None: 102 | """Test the textdataset init.""" 103 | dataset = TextDataset([[0], [1]], torch.arange(2)) 104 | assert len(dataset) == 2 105 | 106 | 107 | def test_textdataset_init_incorrect() -> None: 108 | """Test the textdataset init.""" 109 | with pytest.raises(ValueError): 110 | TextDataset([[0]], torch.arange(2)) 111 | 112 | 113 | def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None: 114 | """Test the predict function.""" 115 | result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() 116 | if mock_trained_pipeline.multilabel: 117 | if type(mock_trained_pipeline.classes_[0]) == str: 118 | assert result == [["a", "b"], ["a", "b"]] 119 | else: 120 | assert result == [[0, 1], [0, 1]] 121 | else: 122 | if type(mock_trained_pipeline.classes_[0]) == str: 123 | assert result == ["b", "b"] 124 | else: 125 | assert result == [1, 1] 126 | 127 | 128 | def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None: 129 | """Test the predict_proba function.""" 130 | result = mock_trained_pipeline.predict_proba(["dog cat", "dog"]) 131 | assert result.shape == (2, 2) 132 | 133 | 134 | def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification) -> None: 135 | """Convert a model to a pipeline.""" 136 | mock_trained_pipeline.eval() 137 | pipeline = mock_trained_pipeline.to_pipeline() 138 | encoded_pipeline = pipeline.model.encode(["dog cat", "dog"]) 139 | encoded_model = mock_trained_pipeline(mock_trained_pipeline.tokenize(["dog cat", "dog"]))[1].detach().numpy() 140 | assert np.allclose(encoded_pipeline, encoded_model) 141 | a = pipeline.predict(["dog cat", "dog"]).tolist() 142 | b = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist() 143 | assert a == b 144 | p1 = pipeline.predict_proba(["dog cat", "dog"]) 145 | p2 = mock_trained_pipeline.predict_proba(["dog cat", "dog"]) 146 | assert np.allclose(p1, p2) 147 | 148 | 149 | def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -> None: 150 | """Test the train test split function.""" 151 | a, b, c, d = mock_trained_pipeline._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5) 152 | assert len(a) == 2 153 | assert len(b) == 2 154 | assert len(c) == len(a) 155 | assert len(d) == len(b) 156 |
157 | 158 | def test_y_val_none() -> None: 159 | """Test the y_val function.""" 160 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer 161 | torch.random.manual_seed(42) 162 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) 163 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") 164 | 165 | X = ["dog", "cat"] 166 | y = ["0", "1"] 167 | 168 | X_val = ["dog", "cat"] 169 | y_val = ["0", "1"] 170 | 171 | with pytest.raises(ValueError): 172 | model.fit(X, y, X_val=X_val, y_val=None) 173 | with pytest.raises(ValueError): 174 | model.fit(X, y, X_val=None, y_val=y_val) 175 | model.fit(X, y, X_val=None, y_val=None) 176 | 177 | 178 | @pytest.mark.parametrize( 179 | "y_multi,y_val_multi,should_crash", 180 | [[True, True, False], [False, False, False], [True, False, True], [False, True, True]], 181 | ) 182 | def test_y_val(y_multi: bool, y_val_multi: bool, should_crash: bool) -> None: 183 | """Test the y_val function.""" 184 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer 185 | torch.random.manual_seed(42) 186 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) 187 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") 188 | 189 | X = ["dog", "cat"] 190 | y = [["0", "1"], ["0"]] if y_multi else ["0", "1"] # type: ignore 191 | 192 | X_val = ["dog", "cat"] 193 | y_val = [["0", "1"], ["0"]] if y_val_multi else ["0", "1"] # type: ignore 194 | 195 | if should_crash: 196 | with pytest.raises(ValueError): 197 | model.fit(X, y, X_val=X_val, y_val=y_val) 198 | else: 199 | model.fit(X, y, X_val=X_val, y_val=y_val) 200 | 201 | 202 | def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None: 203 | """Test the evaluate function.""" 204 | if mock_trained_pipeline.multilabel: 205 | if type(mock_trained_pipeline.classes_[0]) == str: 206 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [["a", "b"], ["a"]]) 207 | else: 208 | # Ignore the type error since we don't support int labels in our typing, but the code does 209 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [[0, 1], [0]]) # type: ignore 210 | else: 211 | if type(mock_trained_pipeline.classes_[0]) == str: 212 | mock_trained_pipeline.evaluate(["dog cat", "dog"], ["a", "a"]) 213 | else: 214 | # Ignore the type error since we don't support int labels in our typing, but the code does 215 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [1, 1]) # type: ignore 216 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | from pathlib import Path 5 | from tempfile import NamedTemporaryFile, TemporaryDirectory 6 | from typing import Any 7 | from unittest.mock import patch 8 | 9 | import numpy as np 10 | import pytest 11 | import safetensors 12 | import safetensors.numpy 13 | from tokenizers import Tokenizer 14 | 15 | from model2vec.distill.utils import select_optimal_device 16 | from model2vec.hf_utils import _get_metadata_from_readme 17 | from model2vec.utils import get_package_extras, importable, load_local_model 18 | 19 | 20 | def test__get_metadata_from_readme_not_exists() -> None: 21 | """Test getting metadata from a README.""" 22 | assert _get_metadata_from_readme(Path("zzz")) == {} 23 | 24 | 25 | def 
test__get_metadata_from_readme_mocked_file() -> None: 26 | """Test getting metadata from a README.""" 27 | with NamedTemporaryFile() as f: 28 | f.write(b"---\nkey: value\n---\n") 29 | f.flush() 30 | assert _get_metadata_from_readme(Path(f.name))["key"] == "value" 31 | 32 | 33 | def test__get_metadata_from_readme_mocked_file_keys() -> None: 34 | """Test getting metadata from a README.""" 35 | with NamedTemporaryFile() as f: 36 | f.write(b"") 37 | f.flush() 38 | assert set(_get_metadata_from_readme(Path(f.name))) == set() 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "device, expected, cuda, mps", 43 | [ 44 | ("cpu", "cpu", True, True), 45 | ("cpu", "cpu", True, False), 46 | ("cpu", "cpu", False, True), 47 | ("cpu", "cpu", False, False), 48 | ("clown", "clown", False, False), 49 | (None, "cuda", True, True), 50 | (None, "cuda", True, False), 51 | (None, "mps", False, True), 52 | (None, "cpu", False, False), 53 | ], 54 | ) 55 | def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None: 56 | """Test whether the optimal device is selected.""" 57 | with ( 58 | patch("torch.cuda.is_available", return_value=cuda), 59 | patch("torch.backends.mps.is_available", return_value=mps), 60 | ): 61 | assert select_optimal_device(device) == expected 62 | 63 | 64 | def test_importable() -> None: 65 | """Test the importable function.""" 66 | with pytest.raises(ImportError): 67 | importable("clown", "clown") 68 | 69 | importable("os", "clown") 70 | 71 | 72 | def test_get_package_extras() -> None: 73 | """Test package extras.""" 74 | extras = set(get_package_extras("model2vec", "distill")) 75 | assert extras == {"torch", "transformers", "scikit-learn"} 76 | 77 | 78 | def test_get_package_extras_empty() -> None: 79 | """Test package extras with an empty package.""" 80 | assert not list(get_package_extras("tqdm", "")) 81 | 82 | 83 | @pytest.mark.parametrize( 84 | "config, expected", 85 | [ 86 | ({"dog": "cat"}, {"dog": "cat"}), 87 | ({}, {}), 88 | (None, {}), 89 | ], 90 | ) 91 | def test_local_load(mock_tokenizer: Tokenizer, config: dict[str, Any], expected: dict[str, Any]) -> None: 92 | """Test local loading.""" 93 | x = np.ones((mock_tokenizer.get_vocab_size(), 2)) 94 | 95 | with TemporaryDirectory() as tempdir: 96 | tempdir_path = Path(tempdir) 97 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors") 98 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json")) 99 | if config is not None: 100 | json.dump(config, open(tempdir_path / "config.json", "w")) 101 | arr, tokenizer, config = load_local_model(tempdir_path) 102 | assert config == expected 103 | assert tokenizer.to_str() == mock_tokenizer.to_str() 104 | assert arr.shape == x.shape 105 | 106 | 107 | def test_local_load_mismatch(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None: 108 | """Test local loading.""" 109 | x = np.ones((10, 2)) 110 | 111 | with TemporaryDirectory() as tempdir: 112 | tempdir_path = Path(tempdir) 113 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors") 114 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json")) 115 | 116 | load_local_model(tempdir_path) 117 | expected = ( 118 | f"Number of tokens does not match number of embeddings: `{len(mock_tokenizer.get_vocab())}` vs `{len(x)}`" 119 | ) 120 | assert len(caplog.records) == 1 121 | assert caplog.records[0].message == expected 122 | -------------------------------------------------------------------------------- /tutorials/README.md: 
-------------------------------------------------------------------------------- 1 |
2 | 3 | Tutorials 4 | 5 | 6 |
7 | 8 | # Tutorials 9 | 10 | This is a list of all our tutorials. They are all self-contained Jupyter notebooks. 11 | 12 | | | what? | Link | 13 | |--------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------| 14 | | **Recipe search** 🍝 | Learn how to do lightning-fast semantic search by distilling a small model. Compare a really tiny model to a larger one with a better vocabulary. Learn what Fattoush is (delicious). | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/recipe_search.ipynb) | 15 | | **Semantic chunking** 🧩 | Learn how to chunk your text into meaningful segments with [Chonkie](https://github.com/chonkie-inc/chonkie) at lightning-speed. Efficiently query your chunks with [Vicinity](https://github.com/MinishLab/vicinity). | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/semantic_chunking.ipynb) | 16 | | **Training a classifier** 🧩 | Learn how to train a classifier using model2vec. Lightning fast, great performance, especially on small datasets. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/train_classifier.ipynb) | 17 | -------------------------------------------------------------------------------- /tutorials/semantic_chunking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Semantic Chunking with Chonkie and Model2Vec**\n", 8 | "\n", 9 | "Semantic chunking is the task of identifying the semantic boundaries of a piece of text. In this tutorial, we will use the [Chonkie](https://github.com/bhavnicksm/chonkie) library to perform semantic chunking on the book War and Peace. Chonkie is a library that provides a lightweight and fast solution to semantic chunking using pre-trained models. It supports our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062) out of the box, which we will be using in this tutorial.\n", 10 | "\n", 11 | "After chunking our text, we will be using [Vicinity](https://github.com/MinishLab/vicinity), a lightweight nearest neighbors library, to create an index of our chunks and query them." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Install the necessary libraries\n", 21 | "!pip install -q datasets model2vec numpy tqdm vicinity \"chonkie[semantic]\"" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 1, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# Import the necessary libraries\n", 31 | "import random \n", 32 | "import re\n", 33 | "import requests\n", 34 | "from time import perf_counter\n", 35 | "from chonkie import SDPMChunker\n", 36 | "from model2vec import StaticModel\n", 37 | "from vicinity import Vicinity\n", 38 | "\n", 39 | "random.seed(0)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "**Loading and pre-processing**\n", 47 | "\n", 48 | "First, we will download War and Peace and apply some basic pre-processing."
49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# URL for War and Peace on Project Gutenberg\n", 58 | "url = \"https://www.gutenberg.org/files/2600/2600-0.txt\"\n", 59 | "\n", 60 | "# Download the book\n", 61 | "response = requests.get(url)\n", 62 | "book_text = response.text\n", 63 | "\n", 64 | "def preprocess_text(text: str, min_length: int = 5):\n", 65 | " \"\"\"Basic text preprocessing function.\"\"\"\n", 66 | " text = text.replace(\"\\n\", \" \")\n", 67 | " text = text.replace(\"\\r\", \" \")\n", 68 | " sentences = re.findall(r'[^.!?]*[.!?]', text)\n", 69 | " # Filter out sentences shorter than the specified minimum length\n", 70 | " filtered_sentences = [sentence.strip() for sentence in sentences if len(sentence.split()) >= min_length]\n", 71 | " # Recombine the filtered sentences\n", 72 | " return ' '.join(filtered_sentences)\n", 73 | "\n", 74 | "# Preprocess the text\n", 75 | "book_text = preprocess_text(book_text)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "**Chunking with Chonkie**\n", 83 | "\n", 84 | "Next, we will use Chonkie to chunk our text into semantic chunks." 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 22, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Number of chunks: 4436\n", 97 | "Time taken: 1.6311538339941762\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "# Initialize an SDPMChunker from Chonkie with the potion-base-32M model\n", 103 | "chunker = SDPMChunker(\n", 104 | " embedding_model=\"minishlab/potion-base-32M\",\n", 105 | " chunk_size = 512, \n", 106 | " skip_window=5, \n", 107 | " min_sentences=3\n", 108 | ")\n", 109 | "\n", 110 | "# Chunk the text\n", 111 | "time = perf_counter()\n", 112 | "chunks = chunker.chunk(book_text)\n", 113 | "print(f\"Number of chunks: {len(chunks)}\")\n", 114 | "print(f\"Time taken: {perf_counter() - time}\")" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "And that's it, we chunked the entirety of War and Peace in ~2 seconds. Not bad! Let's look at some example chunks." 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 23, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | " Wait and we shall see! As if fighting were fun. They are like children from whom one can’t get any sensible account of what has happened because they all want to show how well they can fight. But that’s not what is needed now. “And what ingenious maneuvers they all propose to me! \n", 134 | "\n", 135 | " The first thing he saw on riding up to the space where Túshin’s guns were stationed was an unharnessed horse with a broken leg, that lay screaming piteously beside the harnessed horses. Blood was gushing from its leg as from a spring. Among the limbers lay several dead men. \n", 136 | "\n", 137 | " Out of an army of a hundred thousand we must expect at least twenty thousand wounded, and we haven’t stretchers, or bunks, or dressers, or doctors enough for six thousand. We have ten thousand carts, but we need other things as well—we must manage as best we can!
” The strange thought that of the thousands of men, young and old, who had stared with merry surprise at his hat (perhaps the very men he had noticed), twenty thousand were inevitably doomed to wounds and death amazed Pierre. “They may die tomorrow; why are they thinking of anything but death? ” And by some latent sequence of thought the descent of the Mozháysk hill, the carts with the wounded, the ringing bells, the slanting rays of the sun, and the songs of the cavalrymen vividly recurred to his mind. “The cavalry ride to battle and meet the wounded and do not for a moment think of what awaits them, but pass by, winking at the wounded. Yet from among these men twenty thousand are doomed to die, and they wonder at my hat! ” thought Pierre, continuing his way to Tatárinova. \n", 138 | "\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "# Print a few example chunks\n", 144 | "for _ in range(3):\n", 145 | " chunk = random.choice(chunks)\n", 146 | " print(chunk.text, \"\\n\")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "Those look good. Next, let's create a vector search index with Vicinity and Model2Vec.\n", 154 | "\n", 155 | "**Creating a vector search index**" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 24, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "Time taken: 2.269912125004339\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "# Initialize an embedding model and encode the chunk texts\n", 173 | "time = perf_counter()\n", 174 | "model = StaticModel.from_pretrained(\"minishlab/potion-base-32M\")\n", 175 | "chunk_texts = [chunk.text for chunk in chunks]\n", 176 | "chunk_embeddings = model.encode(chunk_texts)\n", 177 | "\n", 178 | "# Create a Vicinity instance\n", 179 | "vicinity = Vicinity.from_vectors_and_items(vectors=chunk_embeddings, items=chunk_texts)\n", 180 | "print(f\"Time taken: {perf_counter() - time}\")" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "Done! We embedded all our chunks and created an index in ~1.5 seconds. Now that we have our index, let's query it with some queries.\n", 188 | "\n", 189 | "**Querying the index**" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 25, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "Query: Emperor Napoleon\n", 202 | "--------------------------------------------------\n", 203 | " In 1808 the Emperor Alexander went to Erfurt for a fresh interview with the Emperor Napoleon, and in the upper circles of Petersburg there was much talk of the grandeur of this important meeting. CHAPTER XXII In 1809 the intimacy between “the world’s two arbiters,” as Napoleon and Alexander were called, was such that when Napoleon declared war on Austria a Russian corps crossed the frontier to co-operate with our old enemy Bonaparte against our old ally the Emperor of Austria, and in court circles the possibility of marriage between Napoleon and one of Alexander’s sisters was spoken of. But besides considerations of foreign policy, the attention of Russian society was at that time keenly directed on the internal changes that were being undertaken in all the departments of government.
Life meanwhile—real life, with its essential interests of health and sickness, toil and rest, and its intellectual interests in thought, science, poetry, music, love, friendship, hatred, and passions—went on as usual, independently of and apart from political friendship or enmity with Napoleon Bonaparte and from all the schemes of reconstruction. BOOK SIX: 1808 - 10 CHAPTER I Prince Andrew had spent two years continuously in the country. All the plans Pierre had attempted on his estates—and constantly changing from one thing to another had never accomplished—were carried out by Prince Andrew without display and without perceptible difficulty. \n", 204 | "\n", 205 | " CHAPTER XXVI On August 25, the eve of the battle of Borodinó, M. de Beausset, prefect of the French Emperor’s palace, arrived at Napoleon’s quarters at Valúevo with Colonel Fabvier, the former from Paris and the latter from Madrid. Donning his court uniform, M. de Beausset ordered a box he had brought for the Emperor to be carried before him and entered the first compartment of Napoleon’s tent, where he began opening the box while conversing with Napoleon’s aides-de-camp who surrounded him. Fabvier, not entering the tent, remained at the entrance talking to some generals of his acquaintance. The Emperor Napoleon had not yet left his bedroom and was finishing his toilet. \n", 206 | "\n", 207 | " In Russia there was an Emperor, Alexander, who decided to restore order in Europe and therefore fought against Napoleon. In 1807 he suddenly made friends with him, but in 1811 they again quarreled and again began killing many people. Napoleon led six hundred thousand men into Russia and captured Moscow; then he suddenly ran away from Moscow, and the Emperor Alexander, helped by the advice of Stein and others, united Europe to arm against the disturber of its peace. All Napoleon’s allies suddenly became his enemies and their forces advanced against the fresh forces he raised. The Allies defeated Napoleon, entered Paris, forced Napoleon to abdicate, and sent him to the island of Elba, not depriving him of the title of Emperor and showing him every respect, though five years before and one year later they all regarded him as an outlaw and a brigand. Then Louis XVIII, who till then had been the laughingstock both of the French and the Allies, began to reign. And Napoleon, shedding tears before his Old Guards, renounced the throne and went into exile. \n", 208 | "\n", 209 | "Query: The battle of Austerlitz\n", 210 | "--------------------------------------------------\n", 211 | " Behave as you did at Austerlitz, Friedland, Vítebsk, and Smolénsk. Let our remotest posterity recall your achievements this day with pride. Let it be said of each of you: “He was in the great battle before Moscow! \n", 212 | "\n", 213 | " By a strange coincidence, this task, which turned out to be a most difficult and important one, was entrusted to Dokhtúrov—that same modest little Dokhtúrov whom no one had described to us as drawing up plans of battles, dashing about in front of regiments, showering crosses on batteries, and so on, and who was thought to be and was spoken of as undecided and undiscerning—but whom we find commanding wherever the position was most difficult all through the Russo-French wars from Austerlitz to the year 1813. At Austerlitz he remained last at the Augezd dam, rallying the regiments, saving what was possible when all were flying and perishing and not a single general was left in the rear guard. 
Ill with fever he went to Smolénsk with twenty thousand men to defend the town against Napoleon’s whole army. \n", 214 | "\n", 215 | " “Nothing is truer or sadder. These gentlemen ride onto the bridge alone and wave white handkerchiefs; they assure the officer on duty that they, the marshals, are on their way to negotiate with Prince Auersperg. He lets them enter the tête-de-pont. * They spin him a thousand gasconades, saying that the war is over, that the Emperor Francis is arranging a meeting with Bonaparte, that they desire to see Prince Auersperg, and so on. The officer sends for Auersperg; these gentlemen embrace the officers, crack jokes, sit on the cannon, and meanwhile a French battalion gets to the bridge unobserved, flings the bags of incendiary material into the water, and approaches the tête-de-pont. At length appears the lieutenant general, our dear Prince Auersperg von Mautern himself. Flower of the Austrian army, hero of the Turkish wars! Hostilities are ended, we can shake one another’s hand. The Emperor Napoleon burns with impatience to make Prince Auersperg’s acquaintance. \n", 216 | "\n", 217 | "Query: Paris\n", 218 | "--------------------------------------------------\n", 219 | " Paris is Talma, la Duchénois, Potier, the Sorbonne, the boulevards,” and noticing that his conclusion was weaker than what had gone before, he added quickly: “There is only one Paris in the world. You have been to Paris and have remained Russian. Well, I don’t esteem you the less for it. \n", 220 | "\n", 221 | " Look at our youths, look at our ladies! The French are our Gods: Paris is our Kingdom of Heaven. ” He began speaking louder, evidently to be heard by everyone. “French dresses, French ideas, French feelings! \n", 222 | "\n", 223 | " “Oh yes, one sees that plainly. A man who doesn’t know Paris is a savage. You can tell a Parisian two leagues off. \n", 224 | "\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "queries = [\"Emperor Napoleon\", \"The battle of Austerlitz\", \"Paris\"]\n", 230 | "for query in queries:\n", 231 | " print(f\"Query: {query}\\n{'-' * 50}\")\n", 232 | " query_embedding = model.encode(query)\n", 233 | " results = vicinity.query(query_embedding, k=3)[0]\n", 234 | "\n", 235 | " for result in results:\n", 236 | " print(result[0], \"\\n\")" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "These indeed look like relevant chunks, nice! That's it for this tutorial. We were able to chunk, index, and query War and Peace in about 3.5 seconds using Chonkie, Vicinity, and Model2Vec. Lightweight and fast, just how we like it." 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 3 (ipykernel)", 250 | "language": "python", 251 | "name": "python3" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | "file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.10.15" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 4 268 | } 269 | --------------------------------------------------------------------------------
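For quick reference, the semantic chunking notebook above boils down to a chunk, embed, index, and query loop. The sketch below is a minimal outline under the same assumptions as the notebook (the chonkie, model2vec, and vicinity packages as used there, and the minishlab/potion-base-32M model); it skips the text preprocessing step, and exact signatures may differ in other library versions.

```python
import requests
from chonkie import SDPMChunker
from model2vec import StaticModel
from vicinity import Vicinity

# Download the text and split it into semantic chunks (same settings as the notebook;
# the preprocessing step from the notebook is omitted here for brevity).
book_text = requests.get("https://www.gutenberg.org/files/2600/2600-0.txt").text
chunker = SDPMChunker(embedding_model="minishlab/potion-base-32M", chunk_size=512, skip_window=5, min_sentences=3)
chunk_texts = [chunk.text for chunk in chunker.chunk(book_text)]

# Embed the chunks with a static Model2Vec model and index them with Vicinity.
model = StaticModel.from_pretrained("minishlab/potion-base-32M")
vicinity = Vicinity.from_vectors_and_items(vectors=model.encode(chunk_texts), items=chunk_texts)

# Query the index; as in the notebook, the first element of each result is the stored chunk text.
for result in vicinity.query(model.encode("Emperor Napoleon"), k=3)[0]:
    print(result[0][:120])
```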