├── .github
│   └── workflows
│       └── ci.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── assets
│   └── images
│       ├── logo.png
│       ├── logo_v2.png
│       ├── model2vec_logo.png
│       ├── model2vec_model_diagram.png
│       ├── model2vec_model_diagram_transparant_dark.png
│       ├── model2vec_model_diagram_transparant_light.png
│       ├── sentences_per_second_vs_average_score.png
│       ├── speed_vs_accuracy.png
│       ├── speed_vs_accuracy_v2.png
│       ├── speed_vs_accuracy_v3.png
│       ├── speed_vs_accuracy_v4.png
│       ├── speed_vs_mteb_score.png
│       ├── speed_vs_mteb_score_v2.png
│       ├── speed_vs_mteb_score_v3.png
│       ├── training_speed_vs_score.png
│       └── tutorial_ezlo.png
├── docs
│   ├── README.md
│   ├── integrations.md
│   ├── usage.md
│   └── what_is_model2vec.md
├── model2vec
│   ├── __init__.py
│   ├── distill
│   │   ├── __init__.py
│   │   ├── distillation.py
│   │   ├── inference.py
│   │   └── utils.py
│   ├── hf_utils.py
│   ├── inference
│   │   ├── README.md
│   │   ├── __init__.py
│   │   └── model.py
│   ├── model.py
│   ├── modelcards
│   │   ├── classifier_template.md
│   │   └── model_card_template.md
│   ├── py.typed
│   ├── quantization.py
│   ├── tokenizer
│   │   ├── __init__.py
│   │   ├── datamodels.py
│   │   ├── model.py
│   │   ├── normalizer.py
│   │   ├── pretokenizer.py
│   │   └── tokenizer.py
│   ├── train
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── classifier.py
│   ├── utils.py
│   └── version.py
├── pyproject.toml
├── results
│   ├── README.md
│   └── make_speed_vs_mteb_plot.py
├── scripts
│   └── export_to_onnx.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── data
│   │   └── test_tokenizer
│   │       ├── special_tokens_map.json
│   │       ├── tokenizer.json
│   │       └── tokenizer_config.json
│   ├── test_distillation.py
│   ├── test_inference.py
│   ├── test_model.py
│   ├── test_quantization.py
│   ├── test_tokenizer.py
│   ├── test_trainable.py
│   └── test_utils.py
├── tutorials
│   ├── README.md
│   ├── recipe_search.ipynb
│   ├── semantic_chunking.ipynb
│   └── train_classifier.ipynb
└── uv.lock
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: Run tests and upload coverage
2 |
3 | on:
4 | push
5 |
6 | jobs:
7 | test:
8 | name: Run tests with pytest
9 | runs-on: ${{ matrix.os }}
10 | strategy:
11 | matrix:
12 | os: ["ubuntu-latest", "windows-latest"]
13 | python-version: ["3.9", "3.10", "3.11", "3.12"]
14 | exclude:
15 | - os: windows-latest
16 | python-version: "3.9"
17 | - os: windows-latest
18 | python-version: "3.11"
19 | - os: windows-latest
20 | python-version: "3.12"
21 | fail-fast: false
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 |
26 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }}
27 | uses: actions/setup-python@v5
28 | with:
29 | python-version: ${{ matrix.python-version }}
30 | allow-prereleases: true
31 |
32 | # Step for Windows: Create and activate a virtual environment
33 | - name: Create and activate a virtual environment (Windows)
34 | if: ${{ runner.os == 'Windows' }}
35 | run: |
36 | irm https://astral.sh/uv/install.ps1 | iex
37 | $env:Path = "C:\Users\runneradmin\.local\bin;$env:Path"
38 | uv venv .venv
39 | "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append
40 | "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append
41 |
42 | # Step for Unix: Create and activate a virtual environment
43 | - name: Create and activate a virtual environment (Unix)
44 | if: ${{ runner.os != 'Windows' }}
45 | run: |
46 | curl -LsSf https://astral.sh/uv/install.sh | sh
47 | uv venv .venv
48 | echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV
49 | echo "$PWD/.venv/bin" >> $GITHUB_PATH
50 |
51 | # Install dependencies using uv pip
52 | - name: Install dependencies
53 | run: make install-no-pre-commit
54 |
55 | # Run tests with coverage
56 | - name: Run tests under coverage
57 | run: |
58 | coverage run -m pytest
59 | coverage report
60 |
61 | # Upload results to Codecov
62 | - name: Upload results to Codecov
63 | uses: codecov/codecov-action@v4
64 | with:
65 | token: ${{ secrets.CODECOV_TOKEN }}
66 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | *.npy
164 | *.torch
165 | .venv
166 | .DS_Store
167 | models
168 | checkpoints/*
169 | features/*
170 | model2vec_models
171 | counts/*
172 | results_old/*
173 | local/*
174 | lightning_logs/*
175 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.4.0
6 | hooks:
7 | - id: check-ast
8 | description: Simply check whether files parse as valid python.
9 | - id: trailing-whitespace
10 | description: Trims trailing whitespace
11 | - id: end-of-file-fixer
12 | description: Makes sure files end in a newline and only a newline.
13 | - id: check-added-large-files
14 | args: ['--maxkb=5000']
15 | description: Prevent giant files from being committed.
16 | - id: check-case-conflict
17 | description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows.
18 | - id: check-yaml
19 | description: Check yaml files for syntax errors.
20 | - repo: https://github.com/jsh9/pydoclint
21 | rev: 0.5.3
22 | hooks:
23 | - id: pydoclint
24 | - repo: https://github.com/astral-sh/ruff-pre-commit
25 | rev: v0.4.10
26 | hooks:
27 | - id: ruff
28 | args: [ --fix ]
29 | - id: ruff-format
30 | - repo: local
31 | hooks:
32 | - id: mypy
33 | name: mypy
34 | entry: mypy
35 | language: python
36 | types: [python]
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Thomas van Dongen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | clean:
2 |
3 |
4 | venv:
5 | uv venv
6 |
7 | install:
8 | uv sync --all-extras
9 | uv run pre-commit install
10 |
11 | install-no-pre-commit:
12 | uv pip install ".[dev,distill,inference,train]"
13 | uv pip install "torch<2.5.0"
14 |
15 | install-base:
16 | uv sync --extra dev
17 |
18 | fix:
19 | uv run pre-commit run --all-files
20 |
21 | test:
22 | uv run pytest --cov=model2vec --cov-report=term-missing
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | 
4 | Fast State-of-the-Art Static Embeddings
5 |
6 |
7 |
8 |
9 |
18 |
19 |
37 |
38 |
39 |
40 |
41 |
42 | Model2Vec is a technique to turn any sentence transformer into a really small static model, reducing model size by a factor of up to 50 and making the models up to 500 times faster, with a small drop in performance. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is the most performant static embedding model in the world. See our results [here](results/README.md), or dive in to see how it works.
43 |
44 |
45 |
46 |
47 | [Quickstart](#quickstart) • [Updates & Announcements](#updates--announcements) • [Main Features](#main-features) • [Model List](#model-list)
48 |
49 |
50 |
51 | ## Quickstart
52 |
53 | Install the lightweight base package with:
54 |
55 | ```bash
56 | pip install model2vec
57 | ```
58 |
59 | You can start using Model2Vec by loading one of our [flagship models from the HuggingFace hub](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062). These models are pre-trained and ready to use. The following code snippet shows how to load a model and make embeddings, which you can use for any task, such as text classification, retrieval, clustering, or building a RAG system:
60 | ```python
61 | from model2vec import StaticModel
62 |
63 | # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
64 | model = StaticModel.from_pretrained("minishlab/potion-base-8M")
65 |
66 | # Make embeddings
67 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
68 |
69 | # Make sequences of token embeddings
70 | token_embeddings = model.encode_as_sequence(["It's dangerous to go alone!", "It's a secret to everybody."])
71 | ```
72 |
73 | Instead of using one of our models, you can also distill your own Model2Vec model from a Sentence Transformer model. First, install the `distillation` extras with:
74 |
75 | ```bash
76 | pip install model2vec[distill]
77 | ```
78 |
79 |
80 | Then, you can distill a model in ~30 seconds on a CPU with the following code snippet:
81 |
82 | ```python
83 | from model2vec.distill import distill
84 |
85 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
86 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
87 |
88 | # Save the model
89 | m2v_model.save_pretrained("m2v_model")
90 | ```
91 |
92 | After distillation, you can also fine-tune your own classification models on top of the distilled model, or on a pre-trained model. First, make sure you install the `training` extras with:
93 |
94 | ```bash
95 | pip install model2vec[training]
96 | ```
97 |
98 | Then, you can fine-tune a model as follows:
99 |
100 | ```python
101 | import numpy as np
102 | from datasets import load_dataset
103 | from model2vec.train import StaticModelForClassification
104 |
105 | # Initialize a classifier from a pre-trained model
106 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
107 |
108 | # Load a dataset. Note: both single and multi-label classification datasets are supported
109 | ds = load_dataset("setfit/subj")
110 |
111 | # Train the classifier on text (X) and labels (y)
112 | classifier.fit(ds["train"]["text"], ds["train"]["label"])
113 |
114 | # Evaluate the classifier
115 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
116 | ```
117 |
118 | For advanced usage, please refer to our [usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md).
119 |
120 | ## Updates & Announcements
121 |
122 | - **23/05/2025**: We released [potion-multilingual-128M](https://huggingface.co/minishlab/potion-multilingual-128M), a multilingual model trained on 101 languages. It is the best performing static embedding model for multilingual tasks, and is capable of generating embeddings for any text in any language. The results can be found in our [results](results/README.md#mmteb-results-multilingual) section.
123 |
124 | - **01/05/2025**: We released backend support for `BPE` and `Unigram` tokenizers, along with quantization and dimensionality reduction. New Model2Vec models are now 50% of the size of the original models, and can be quantized to int8 to reach 25% of the original size, without loss of performance.
125 |
126 | - **12/02/2025**: We released **Model2Vec training**, allowing you to fine-tune your own classification models on top of Model2Vec models. Find out more in our [training documentation](https://github.com/MinishLab/model2vec/blob/main/model2vec/train/README.md) and [results](results/README.md#training-results).
127 |
128 | - **30/01/2025**: We released two new models: [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) and [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M). [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) is our most performant model to date, using a larger vocabulary and higher dimensions. [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) is a finetune of [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) that is optimized for retrieval tasks, and is the best performing static retrieval model currently available.
129 |
130 | - **30/10/2024**: We released three new models: [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M), [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M), and [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M). These models are trained using [Tokenlearn](https://github.com/MinishLab/tokenlearn). Find out more in our [blog post](https://minishlab.github.io/tokenlearn_blogpost/). NOTE: for users of any of our old English M2V models, we recommend switching to these new models as they [perform better on all tasks](https://github.com/MinishLab/model2vec/tree/main/results).
131 |
132 | ## Main Features
133 |
134 | - **State-of-the-Art Performance**: Model2Vec models outperform any other static embeddings (such as GLoVe and BPEmb) by a large margin, as can be seen in our [results](results/README.md).
135 | - **Small**: Model2Vec reduces the size of a Sentence Transformer model by a factor of up to 50. Our [best model](https://huggingface.co/minishlab/potion-base-8M) is just ~30 MB on disk, and our smallest model just ~8 MB (making it the smallest model on [MTEB](https://huggingface.co/spaces/mteb/leaderboard)!).
136 | - **Lightweight Dependencies**: the base package's only major dependency is `numpy`.
137 | - **Lightning-fast Inference**: up to 500 times faster on CPU than the original model.
138 | - **Fast, Dataset-free Distillation**: distill your own model in 30 seconds on a CPU, without a dataset.
139 | - **Fine-tuning**: fine-tune your own classification models on top of Model2Vec models.
140 | - **Integrated in many popular libraries**: Model2Vec is integrated directly into popular libraries such as [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) and [LangChain](https://github.com/langchain-ai/langchain). For more information, see our [integrations documentation](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md).
141 | - **Tightly integrated with HuggingFace hub**: easily share and load models from the HuggingFace hub, using the familiar `from_pretrained` and `push_to_hub`. Our own models can be found [here](https://huggingface.co/minishlab).
142 |
143 | ## What is Model2Vec?
144 |
145 | Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Like BPEmb, it can create subword embeddings, but with much better performance. Distillation doesn't need _any_ data, just a vocabulary and a model.
146 |
147 | The core idea is to forward pass a vocabulary through a sentence transformer model, creating static embeddings for the individual tokens. After this, we apply a number of post-processing steps that result in our best models. For a more extensive deep dive, please refer to the following resources:
148 | - Our initial [Model2Vec blog post](https://huggingface.co/blog/Pringled/model2vec). Note that, while this post gives a good overview of the core idea, we've made a number of substantial improvements since then.
149 | - Our [Tokenlearn blog post](https://minishlab.github.io/tokenlearn_blogpost/). This post describes the Tokenlearn method we used to train our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062).
150 | - Our official [documentation](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md). This document provides a high-level overview of how Model2Vec works.
151 |
152 | ## Documentation
153 |
154 | Our official documentation can be found [here](https://github.com/MinishLab/model2vec/blob/main/docs/README.md). This includes:
155 | - [Usage documentation](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md): provides a technical overview of how to use Model2Vec.
156 | - [Integrations documentation](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md): provides examples of how to use Model2Vec in various downstream libraries.
157 | - [Model2Vec technical documentation](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md): provides a high-level overview of how Model2Vec works.
158 |
159 |
160 | ## Model List
161 |
162 | We provide a number of models that can be used out of the box. These models are available on the [HuggingFace hub](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) and can be loaded using the `from_pretrained` method. The models are listed below.
163 |
164 |
165 |
166 | | Model | Language | Sentence Transformer | Params | Task |
167 | |-----------------------------------------------------------------------|------------|-----------------------------------------------------------------|---------|-----------|
168 | | [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 32.3M | General |
169 | | [potion-multilingual-128M](https://huggingface.co/minishlab/potion-multilingual-128M) | Multilingual | [bge-m3](https://huggingface.co/BAAI/bge-m3) | 128M | General |
170 | | [potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 32.3M | Retrieval |
171 | | [potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 7.5M | General |
172 | | [potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 3.7M | General |
173 | | [potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) | English | [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 1.8M | General |
174 |
175 |
176 |
177 |
178 | ## Results
179 |
180 | We have performed extensive experiments to evaluate the performance of Model2Vec models. The results are documented in the [results](results/README.md) folder. The results are presented in the following sections:
181 | - [MTEB Results](results/README.md#mteb-results)
182 | - [Training Results](results/README.md#training-results)
183 | - [Ablations](results/README.md#ablations)
184 |
185 | ## License
186 |
187 | MIT
188 |
189 | ## Citing
190 |
191 | If you use Model2Vec in your research, please cite the following:
192 | ```bibtex
193 | @article{minishlab2024model2vec,
194 | author = {Tulkens, Stephan and {van Dongen}, Thomas},
195 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
196 | year = {2024},
197 | url = {https://github.com/MinishLab/model2vec}
198 | }
199 | ```
200 |
--------------------------------------------------------------------------------
/assets/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/logo.png
--------------------------------------------------------------------------------
/assets/images/logo_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/logo_v2.png
--------------------------------------------------------------------------------
/assets/images/model2vec_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_logo.png
--------------------------------------------------------------------------------
/assets/images/model2vec_model_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram.png
--------------------------------------------------------------------------------
/assets/images/model2vec_model_diagram_transparant_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram_transparant_dark.png
--------------------------------------------------------------------------------
/assets/images/model2vec_model_diagram_transparant_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/model2vec_model_diagram_transparant_light.png
--------------------------------------------------------------------------------
/assets/images/sentences_per_second_vs_average_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/sentences_per_second_vs_average_score.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_accuracy_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v2.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_accuracy_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v3.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_accuracy_v4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_accuracy_v4.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_mteb_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_mteb_score_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score_v2.png
--------------------------------------------------------------------------------
/assets/images/speed_vs_mteb_score_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/speed_vs_mteb_score_v3.png
--------------------------------------------------------------------------------
/assets/images/training_speed_vs_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/training_speed_vs_score.png
--------------------------------------------------------------------------------
/assets/images/tutorial_ezlo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/assets/images/tutorial_ezlo.png
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | This directory contains the documentation for Model2Vec. The documentation is formatted in Markdown. The documentation is organized as follows:
4 | - [usage.md](https://github.com/MinishLab/model2vec/blob/main/docs/usage.md): This document provides a technical overview of how to use Model2Vec.
5 | - [integrations.md](https://github.com/MinishLab/model2vec/blob/main/docs/integrations.md): This document provides examples of how to use Model2Vec in various downstream libraries.
6 | - [what_is_model2vec.md](https://github.com/MinishLab/model2vec/blob/main/docs/what_is_model2vec.md): This document provides a high-level overview of how Model2Vec works.
7 |
--------------------------------------------------------------------------------
/docs/integrations.md:
--------------------------------------------------------------------------------
1 |
2 | # Integrations
3 |
4 | Model2Vec can be used in a variety of downstream libraries. This document provides examples of how to use Model2Vec in some of these libraries.
5 |
6 | ## Table of Contents
7 | - [Sentence Transformers](#sentence-transformers)
8 | - [LangChain](#langchain)
9 | - [Txtai](#txtai)
10 | - [Chonkie](#chonkie)
11 | - [Transformers.js](#transformersjs)
12 |
13 | ## Sentence Transformers
14 |
15 | Model2Vec can be used directly in [Sentence Transformers](https://github.com/UKPLab/sentence-transformers):
16 |
17 | The following code snippet shows how to load a Model2Vec model into a Sentence Transformer model:
18 | ```python
19 | from sentence_transformers import SentenceTransformer
20 |
21 | # Load a Model2Vec model from the Hub
22 | model = SentenceTransformer("minishlab/potion-base-8M")
23 | # Make embeddings
24 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
25 | ```
26 |
27 | The following code snippet shows how to distill a model directly into a Sentence Transformer model:
28 |
29 | ```python
30 | from sentence_transformers import SentenceTransformer
31 | from sentence_transformers.models import StaticEmbedding
32 |
33 | static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu", pca_dims=256)
34 | model = SentenceTransformer(modules=[static_embedding])
35 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
36 | ```
37 |
38 | For more documentation, please refer to the [Sentence Transformers documentation](https://sbert.net/docs/package_reference/sentence_transformer/models.html#sentence_transformers.models.StaticEmbedding).
39 |
40 |
41 | ## LangChain
42 |
43 | Model2Vec can be used in [LangChain](https://github.com/langchain-ai/langchain) using the `langchain-community` package. For more information, see the [LangChain Model2Vec docs](https://python.langchain.com/docs/integrations/text_embedding/model2vec/). The following code snippet shows how to use Model2Vec in LangChain after installing the `langchain-community` package with `pip install langchain-community`:
44 |
45 | ```python
46 | from langchain_community.embeddings import Model2vecEmbeddings
47 | from langchain_community.vectorstores import FAISS
48 | from langchain.schema import Document
49 |
50 | # Initialize a Model2Vec embedder
51 | embedder = Model2vecEmbeddings("minishlab/potion-base-8M")
52 |
53 | # Create some example texts
54 | texts = [
55 | "Enduring Stew",
56 | "Hearty Elixir",
57 | "Mighty Mushroom Risotto",
58 | "Spicy Meat Skewer",
59 | "Fruit Salad",
60 | ]
61 |
62 | # Embed the texts
63 | embeddings = embedder.embed_documents(texts)
64 |
65 | # Or, create a vector store and query it
66 | documents = [Document(page_content=text) for text in texts]
67 | vector_store = FAISS.from_documents(documents, embedder)
68 | query = "Risotto"
69 | query_vector = embedder.embed_query(query)
70 | retrieved_docs = vector_store.similarity_search_by_vector(query_vector, k=1)
71 | ```
72 |
73 | ## Txtai
74 |
75 | Model2Vec can be used in [txtai](https://github.com/neuml/txtai) for text embeddings, nearest-neighbors search, and any of the other functionalities that txtai offers. The following code snippet shows how to use Model2Vec in txtai after installing the `txtai` package (including the `vectors` dependency) with `pip install txtai[vectors]`:
76 |
77 | ```python
78 | from txtai import Embeddings
79 |
80 | # Load a model2vec model
81 | embeddings = Embeddings(path="minishlab/potion-base-8M", method="model2vec", backend="numpy")
82 |
83 | # Create some example texts
84 | texts = ["Enduring Stew", "Hearty Elixir", "Mighty Mushroom Risotto", "Spicy Meat Skewer", "Chilly Fruit Salad"]
85 |
86 | # Create embeddings for downstream tasks
87 | vectors = embeddings.batchtransform(texts)
88 |
89 | # Or create a nearest-neighbors index and search it
90 | embeddings.index(texts)
91 | result = embeddings.search("Risotto", 1)
92 | ```
93 |
94 | ## Chonkie
95 |
96 | Model2Vec is the default model for semantic chunking in [Chonkie](https://github.com/bhavnicksm/chonkie). To use Model2Vec for semantic chunking in Chonkie, simply install Chonkie with `pip install chonkie[semantic]` and use one of the `potion` models in the `SemanticChunker` class. The following code snippet shows how to use Model2Vec in Chonkie:
97 |
98 | ```python
99 | from chonkie import SDPMChunker
100 |
101 | # Create some example text to chunk
102 | text = "It's dangerous to go alone! Take this."
103 |
104 | # Initialize the SemanticChunker with a potion model
105 | chunker = SDPMChunker(
106 | embedding_model="minishlab/potion-base-8M",
107 | similarity_threshold=0.3
108 | )
109 |
110 | # Chunk the text
111 | chunks = chunker.chunk(text)
112 | ```
113 |
114 | ## Transformers.js
115 |
116 | To use a Model2Vec model in [transformers.js](https://github.com/huggingface/transformers.js), the following code snippet can be used as a starting point:
117 |
118 | ```javascript
119 | import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
120 |
121 | const modelName = 'minishlab/potion-base-8M';
122 |
123 | const modelConfig = {
124 | config: { model_type: 'model2vec' },
125 | dtype: 'fp32',
126 | revision: 'refs/pr/1'
127 | };
128 | const tokenizerConfig = {
129 | revision: 'refs/pr/2'
130 | };
131 |
132 | const model = await AutoModel.from_pretrained(modelName, modelConfig);
133 | const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
134 |
135 | const texts = ['hello', 'hello world'];
136 | const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });
137 |
138 | const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
139 | const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];
140 |
141 | const flattened_input_ids = input_ids.flat();
142 | const modelInputs = {
143 | input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
144 | offsets: new Tensor('int64', offsets, [offsets.length])
145 | };
146 |
147 | const { embeddings } = await model(modelInputs);
148 | console.log(embeddings.tolist()); // output matches python version
149 | ```
150 |
151 | Note that this requires that the Model2Vec model has a `model.onnx` file and several required tokenizer files. To generate these for a model that does not have them yet, the following code snippet can be used:
152 |
153 | ```bash
154 | python scripts/export_to_onnx.py --model_path <path-to-a-model2vec-model> --save_path "<path-to-save-the-exported-model>"
155 | ```
156 |
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 |
2 | # Usage
3 |
4 | This document provides an overview of how to use Model2Vec for inference, distillation, training, and evaluation.
5 |
6 | ## Table of Contents
7 | - [Inference](#inference)
8 | - [Inference with a pretrained model](#inference-with-a-pretrained-model)
9 | - [Inference with the Sentence Transformers library](#inference-with-the-sentence-transformers-library)
10 | - [Distillation](#distillation)
11 | - [Distilling from a Sentence Transformer](#distilling-from-a-sentence-transformer)
12 | - [Distilling from a loaded model](#distilling-from-a-loaded-model)
13 | - [Distilling with the Sentence Transformers library](#distilling-with-the-sentence-transformers-library)
14 | - [Distilling with a custom vocabulary](#distilling-with-a-custom-vocabulary)
15 | - [Training](#training)
16 | - [Training a classifier](#training-a-classifier)
17 | - [Evaluation](#evaluation)
18 | - [Installation](#installation)
19 | - [Evaluation Code](#evaluation-code)
20 |
21 | ## Inference
22 |
23 | ### Inference with a pretrained model
24 |
25 | Inference works as follows. The example shows one of our own models, but you can also just load a local one, or another one from the hub.
26 | ```python
27 | from model2vec import StaticModel
28 |
29 | # Load a model from the Hub. You can optionally pass a token when loading a private model
30 | model = StaticModel.from_pretrained(model_name="minishlab/potion-base-8M", token=None)
31 |
32 | # Make embeddings
33 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
34 |
35 | # Make sequences of token embeddings
36 | token_embeddings = model.encode_as_sequence(["It's dangerous to go alone!", "It's a secret to everybody."])
37 | ```
38 |
39 | ### Inference with the Sentence Transformers library
40 |
41 | The following code snippet shows how to use a Model2Vec model in the [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) library. This is useful if you want to use the model in a Sentence Transformers pipeline.
42 |
43 | ```python
44 | from sentence_transformers import SentenceTransformer
45 |
46 | # Load a Model2Vec model from the Hub
47 | model = SentenceTransformer("minishlab/potion-base-8M")
48 |
49 | # Make embeddings
50 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
51 | ```
52 |
53 | ## Distillation
54 |
55 | ### Distilling from a Sentence Transformer
56 |
57 | The following code can be used to distill a model from a Sentence Transformer. As mentioned above, this leads to a really small model that might be less performant.
58 | ```python
59 | from model2vec.distill import distill
60 |
61 | # Distill a Sentence Transformer model
62 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
63 |
64 | # Save the model
65 | m2v_model.save_pretrained("m2v_model")
66 |
67 | ```
68 |
69 | ### Distilling from a loaded model
70 |
71 | If you already have a model loaded, or need to load a model in some special way, we also offer an interface to distill models in memory.
72 |
73 | ```python
74 | from transformers import AutoModel, AutoTokenizer
75 |
76 | from model2vec.distill import distill_from_model
77 |
78 | # Assuming a loaded model and tokenizer
79 | model_name = "baai/bge-base-en-v1.5"
80 | model = AutoModel.from_pretrained(model_name)
81 | tokenizer = AutoTokenizer.from_pretrained(model_name)
82 |
83 | m2v_model = distill_from_model(model=model, tokenizer=tokenizer, pca_dims=256)
84 |
85 | m2v_model.save_pretrained("m2v_model")
86 |
87 | ```
88 |
89 | ### Distilling with the Sentence Transformers library
90 |
91 | The following code snippet shows how to distill a model using the [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) library. This is useful if you want to use the model in a Sentence Transformers pipeline.
92 |
93 | ```python
94 | from sentence_transformers import SentenceTransformer
95 | from sentence_transformers.models import StaticEmbedding
96 |
97 | static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cpu", pca_dims=256)
98 | model = SentenceTransformer(modules=[static_embedding])
99 | embeddings = model.encode(["It's dangerous to go alone!", "It's a secret to everybody."])
100 | ```
101 |
102 | ### Distilling with a custom vocabulary
103 |
104 | If you pass a vocabulary, you get a set of static word embeddings, together with a custom tokenizer for exactly that vocabulary. This is comparable to how you would use GLoVe or traditional word2vec, but doesn't actually require a corpus or data.
105 | ```python
106 | from model2vec.distill import distill
107 |
108 | # Load a vocabulary as a list of strings
109 | vocabulary = ["word1", "word2", "word3"]
110 |
111 | # Distill a Sentence Transformer model with the custom vocabulary
112 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", vocabulary=vocabulary)
113 |
114 | # Save the model
115 | m2v_model.save_pretrained("m2v_model")
116 |
117 | # Or push it to the hub
118 | m2v_model.push_to_hub("my_organization/my_model", token="")
119 | ```
120 |
121 | By default, this will distill a model with a subword tokenizer, combining the model's (subword) vocabulary with the new vocabulary. If you want to get a word-level tokenizer instead (with only the passed vocabulary), the `use_subword` parameter can be set to `False`, e.g.:
122 |
123 | ```python
124 | m2v_model = distill(model_name=model_name, vocabulary=vocabulary, use_subword=False)
125 | ```
126 |
127 | **Important note:** we assume the passed vocabulary is sorted by rank frequency, i.e., we don't care about the actual word frequencies, but we do assume that the most frequent word comes first and the least frequent word comes last. If you're not sure whether this is the case, set `apply_zipf` to `False`. This disables the frequency weighting, but will also make performance a little bit worse.
128 |
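As a minimal sketch (reusing the same `BAAI/bge-base-en-v1.5` base model as in the examples above), distilling with a vocabulary of unknown frequency order could look like this:

```python
from model2vec.distill import distill

# Vocabulary that is not sorted by frequency, so disable the frequency-based weighting.
# (Setting sif_coefficient=None has the same effect in newer versions.)
vocabulary = ["word1", "word2", "word3"]

m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",
    vocabulary=vocabulary,
    apply_zipf=False,
)
```
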
129 | ### Quantization
130 |
131 | Models can be quantized to `float16` (default) or `int8` during distillation, or when loading from disk.
132 |
133 | ```python
134 | from model2vec.distill import distill
135 |
136 | # Distill a Sentence Transformer model and quantize it to int8
137 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", quantize_to="int8")
138 |
139 | # Save the model. This model is now 25% of the size of a normal model.
140 | m2v_model.save_pretrained("m2v_model")
141 | ```
142 |
143 | You can also quantize during loading.
144 |
145 | ```python
146 | from model2vec import StaticModel
147 |
148 | model = StaticModel.from_pretrained("minishlab/potion-base-8m", quantize_to="int8")
149 | ```
150 |
151 | ### Dimensionality reduction
152 |
153 | Because almost all Model2Vec models have been distilled using PCA, and because PCA explicitly orders dimensions from most informative to least informative, we can perform dimensionality reduction during loading. This is very similar to how matryoshka embeddings work.
154 |
155 | ```python
156 | from model2vec import StaticModel
157 |
158 | model = StaticModel.from_pretrained("minishlab/potion-base-8m", dimensionality=32)
159 |
160 | print(model.embedding.shape)
161 | # (29528, 32)
162 | ```
163 |
164 | ### Combining quantization and dimensionality reduction
165 |
166 | Combining these tricks can lead to extremely small models. For example, using this, we can reduce the size of `potion-base-8m`, which is now 30MB, to only 1MB:
167 |
168 | ```python
169 | model = StaticModel.from_pretrained("minishlab/potion-base-8m",
170 | dimensionality=32,
171 | quantize_to="int8")
172 | print(model.embedding.nbytes)
173 | # 944896 bytes = 944kb
174 | ```
175 |
176 | This should be enough to satisfy even the strongest hardware constraints.
177 |
178 | ## Training
179 |
180 | ### Training a classifier
181 |
182 | Model2Vec can be used to train a classifier on top of a distilled model. The following code snippet shows how to train a classifier on top of a distilled model. For more advanced usage, as well as results, please refer to the [training documentation](https://github.com/MinishLab/model2vec/blob/main/model2vec/train/README.md).
183 |
184 | ```python
185 | import numpy as np
186 | from datasets import load_dataset
187 | from model2vec.train import StaticModelForClassification
188 |
189 | # Initialize a classifier from a pre-trained model
190 | classifier = StaticModelForClassification.from_pretrained("minishlab/potion-base-8M")
191 |
192 | # Load a dataset
193 | ds = load_dataset("setfit/subj")
194 | train = ds["train"]
195 | test = ds["test"]
196 |
197 | X_train, y_train = train["text"], train["label"]
198 | X_test, y_test = test["text"], test["label"]
199 |
200 | # Train the classifier
201 | classifier.fit(X_train, y_train)
202 |
203 | # Evaluate the classifier
204 | y_hat = classifier.predict(X_test)
205 | accuracy = np.mean(np.array(y_hat) == np.array(y_test)) * 100
206 | ```
207 |
208 | ## Evaluation
209 |
210 | ### Installation
211 |
212 | Our models can be evaluated using our [evaluation package](https://github.com/MinishLab/evaluation). Install the evaluation package with:
213 |
214 | ```bash
215 | pip install git+https://github.com/MinishLab/evaluation.git@main
216 | ```
217 |
218 | ### Evaluation Code
219 |
220 | The following code snippet shows how to evaluate a Model2Vec model:
221 | ```python
222 | from model2vec import StaticModel
223 |
224 | from evaluation import CustomMTEB, get_tasks, parse_mteb_results, make_leaderboard, summarize_results
225 | from mteb import ModelMeta
226 |
227 | # Get all available tasks
228 | tasks = get_tasks()
229 | # Define the CustomMTEB object with the specified tasks
230 | evaluation = CustomMTEB(tasks=tasks)
231 |
232 | # Load the model
233 | model_name = "m2v_model"
234 | model = StaticModel.from_pretrained(model_name)
235 |
236 | # Optionally, add model metadata in MTEB format
237 | model.mteb_model_meta = ModelMeta(
238 | name=model_name, revision="no_revision_available", release_date=None, languages=None
239 | )
240 |
241 | # Run the evaluation
242 | results = evaluation.run(model, eval_splits=["test"], output_folder=f"results")
243 |
244 | # Parse the results and summarize them
245 | parsed_results = parse_mteb_results(mteb_results=results, model_name=model_name)
246 | task_scores = summarize_results(parsed_results)
247 |
248 | # Print the results in a leaderboard format
249 | print(make_leaderboard(task_scores))
250 | ```
251 |
--------------------------------------------------------------------------------
/docs/what_is_model2vec.md:
--------------------------------------------------------------------------------
1 | # What is Model2Vec?
2 |
3 | This document provides a high-level overview of how Model2Vec works.
4 |
5 | The base model2vec technique works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using SIF weighting (previously zipf weighting). During inference, we simply take the mean of all token embeddings occurring in a sentence.
6 |
7 | Our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062) are pre-trained using [tokenlearn](https://github.com/MinishLab/tokenlearn), a technique to pre-train model2vec distillation models. These models are created with the following steps:
8 | - **Distillation**: We distill a Model2Vec model from a Sentence Transformer model, using the method described above.
9 | - **Sentence Transformer inference**: We use the Sentence Transformer model to create mean embeddings for a large number of texts from a corpus.
10 | - **Training**: We train a model to minimize the cosine distance between the mean embeddings generated by the Sentence Transformer model and the mean embeddings generated by the Model2Vec model.
11 | - **Post-training re-regularization**: We re-regularize the trained embeddings by first performing PCA, and then applying `smooth inverse frequency (SIF)` weighting with the following formula: `w = 1e-3 / (1e-3 + proba)`. Here, `proba` is the probability of the token in the corpus we used for training.
12 |
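To make the weighting and inference steps concrete, here is a minimal numpy sketch of the idea. This is purely illustrative and not the library's implementation: the array shapes, the random placeholder inputs, and the `embed` helper are assumptions made for the example.

```python
import numpy as np

# Placeholder inputs: one pre-computed (and PCA-reduced) embedding per vocabulary
# token, plus each token's probability in a reference corpus.
token_embeddings = np.random.randn(30_000, 256).astype(np.float32)
token_probas = np.random.dirichlet(np.ones(30_000))

# SIF weighting: w = 1e-3 / (1e-3 + proba)
sif_weights = 1e-3 / (1e-3 + token_probas)
weighted_embeddings = token_embeddings * sif_weights[:, None]

# Inference: a sentence embedding is simply the mean of its token embeddings.
def embed(token_ids: list[int]) -> np.ndarray:
    return weighted_embeddings[token_ids].mean(axis=0)

sentence_embedding = embed([12, 4051, 987])  # token ids produced by the tokenizer
```
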
--------------------------------------------------------------------------------
/model2vec/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.model import StaticModel
2 | from model2vec.version import __version__
3 |
4 | __all__ = ["StaticModel", "__version__"]
5 |
--------------------------------------------------------------------------------
/model2vec/distill/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import get_package_extras, importable
2 |
3 | _REQUIRED_EXTRA = "distill"
4 |
5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
6 | importable(extra_dependency, _REQUIRED_EXTRA)
7 |
8 | from model2vec.distill.distillation import distill, distill_from_model
9 |
10 | __all__ = ["distill", "distill_from_model"]
11 |
--------------------------------------------------------------------------------
/model2vec/distill/distillation.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 | import re
6 | from typing import Optional, cast
7 |
8 | import numpy as np
9 | from huggingface_hub import model_info
10 | from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
11 |
12 | from model2vec.distill.inference import PCADimType, create_embeddings, post_process_embeddings
13 | from model2vec.distill.utils import select_optimal_device
14 | from model2vec.model import StaticModel
15 | from model2vec.quantization import DType, quantize_embeddings
16 | from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | def distill_from_model(
22 | model: PreTrainedModel,
23 | tokenizer: PreTrainedTokenizerFast,
24 | vocabulary: list[str] | None = None,
25 | device: str | None = None,
26 | pca_dims: PCADimType = 256,
27 | apply_zipf: bool | None = None,
28 | sif_coefficient: float | None = 1e-4,
29 | token_remove_pattern: str | None = r"\[unused\d+\]",
30 | quantize_to: DType | str = DType.Float16,
31 | use_subword: bool | None = None,
32 | ) -> StaticModel:
33 | """
34 | Distill a staticmodel from a sentence transformer.
35 |
36 | This function creates a set of embeddings from a sentence transformer. It does this by doing either
37 | a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.
38 |
39 | If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
40 | If you don't pass a vocabulary, we use the model's tokenizer directly.
41 |
42 | :param model: The model to use.
43 | :param tokenizer: The tokenizer to use.
44 | :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
45 | :param device: The device to use.
46 | :param pca_dims: The number of components to use for PCA.
47 | If this is None, we don't apply PCA.
48 | If this is 'auto', we don't reduce dimensionality, but still apply PCA.
49 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
50 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
51 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
52 | Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
53 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
54 | If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
55 | :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
56 | :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
57 | :return: A StaticModel
58 | :raises: ValueError if the vocabulary is empty after preprocessing.
59 |
60 | """
61 | if use_subword is not None:
62 | logger.warning(
63 | "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
64 | )
65 | quantize_to = DType(quantize_to)
66 | backend_tokenizer = tokenizer.backend_tokenizer
67 | sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)
68 |
69 | if vocabulary is None:
70 | vocabulary = []
71 |
72 | device = select_optimal_device(device)
73 |
74 | n_tokens_before = len(vocabulary)
75 | # Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
76 | all_tokens, backend_tokenizer = clean_and_create_vocabulary(
77 | tokenizer, vocabulary, token_remove_regex=token_remove_regex
78 | )
79 | n_tokens_after = len([token for token in all_tokens if not token.is_internal])
80 | if n_tokens_before:
81 | logger.info(
82 | f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
83 | )
84 |
85 | if not all_tokens:
86 | raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")
87 |
88 | unk_token = cast(Optional[str], tokenizer.special_tokens_map.get("unk_token"))
89 | pad_token = cast(Optional[str], tokenizer.special_tokens_map.get("pad_token"))
90 |
91 | # Weird if to satisfy mypy
92 | if pad_token is None:
93 | if unk_token is not None:
94 | pad_token = unk_token
95 | logger.warning(
96 | "The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
97 | )
98 | else:
99 | pad_token = unk_token or all_tokens[0].form
100 | logger.warning(
101 | "The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
102 | )
103 |
104 | # Replace the vocabulary in the tokenizer with the new vocabulary.
105 | backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
106 |
107 | logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
108 | # Convert tokens to IDs
109 | token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
110 |
111 | # Create the embeddings
112 | embeddings = create_embeddings(
113 | tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
114 | )
115 |
116 | # Post process the embeddings by applying PCA and Zipf weighting.
117 | embeddings = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
118 | # Quantize the embeddings.
119 | embeddings = quantize_embeddings(embeddings, quantize_to)
120 |
121 | model_name = getattr(model, "name_or_path", "")
122 |
123 | config = {
124 | "model_type": "model2vec",
125 | "architectures": ["StaticModel"],
126 | "tokenizer_name": model_name,
127 | "apply_pca": pca_dims,
128 | "apply_zipf": apply_zipf,
129 | "sif_coefficient": sif_coefficient,
130 | "hidden_dim": embeddings.shape[1],
131 | "seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit.
132 | "normalize": True,
133 | }
134 |
135 | if os.path.exists(model_name):
136 | # Using a local model. Get the model name from the path.
137 | model_name = os.path.basename(model_name)
138 | language = None
139 | else:
140 | # Get the language from the model card.
141 | try:
142 | info = model_info(model_name)
143 | language = info.cardData.get("language", None) if info.cardData is not None else None
144 | except Exception as e:
145 |         # NOTE: broad except because there are many reasons this can fail.
146 | logger.warning(f"Couldn't get the model info from the Hugging Face Hub: {e}. Setting language to None.")
147 | language = None
148 |
149 | return StaticModel(
150 | vectors=embeddings,
151 | tokenizer=backend_tokenizer,
152 | config=config,
153 | base_model_name=model_name,
154 | language=language,
155 | normalize=True,
156 | )
157 |
158 |
159 | def _validate_parameters(
160 | apply_zipf: bool | None,
161 | sif_coefficient: float | None,
162 | token_remove_pattern: str | None,
163 | ) -> tuple[float | None, re.Pattern | None]:
164 | """
165 | Validate the parameters passed to the distillation function.
166 |
167 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
168 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
169 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
170 |         Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
171 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
172 |     :return: The SIF coefficient to use and the compiled token removal regex (or None for either).
173 | :raises: ValueError if the regex can't be compiled.
174 |
175 | """
176 | if apply_zipf is not None:
177 | logger.warning(
178 | "The `apply_zipf` parameter is deprecated and will be removed in the next release. "
179 | "Zipf weighting is applied based on the sif_coefficient parameter. If this is set to None, "
180 | "no weighting is applied."
181 | )
182 | if apply_zipf and sif_coefficient is None:
183 | logger.warning("You set apply_zipf to True, but sif_coefficient is None. Setting sif_coefficient to 1e-4.")
184 | sif_coefficient = 1e-4
185 | elif not apply_zipf:
186 | logger.warning("Because you set apply_zipf to False, we ignore the sif_coefficient parameter.")
187 | sif_coefficient = None
188 |
189 | if sif_coefficient is not None:
190 | if not 0 < sif_coefficient < 1.0:
191 | raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
192 |
193 | token_remove_regex: re.Pattern | None = None
194 | if token_remove_pattern is not None:
195 | try:
196 | token_remove_regex = re.compile(token_remove_pattern)
197 | except re.error as e:
198 | raise ValueError(f"Couldn't compile the regex pattern: {e}")
199 |
200 | return sif_coefficient, token_remove_regex
201 |
202 |
203 | def distill(
204 | model_name: str,
205 | vocabulary: list[str] | None = None,
206 | device: str | None = None,
207 | pca_dims: PCADimType = 256,
208 | apply_zipf: bool | None = None,
209 | sif_coefficient: float | None = 1e-4,
210 | token_remove_pattern: str | None = r"\[unused\d+\]",
211 | trust_remote_code: bool = False,
212 | quantize_to: DType | str = DType.Float16,
213 | use_subword: bool | None = None,
214 | ) -> StaticModel:
215 | """
216 |     Distill a StaticModel from a Sentence Transformer.
217 |
218 | This function creates a set of embeddings from a sentence transformer. It does this by doing either
219 | a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.
220 |
221 |     If you pass a vocabulary, we create a custom word tokenizer for that vocabulary.
222 | If you don't pass a vocabulary, we use the model's tokenizer directly.
223 |
224 |     :param model_name: The model name to use. Any Sentence Transformer-compatible model works.
225 | :param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
226 | :param device: The device to use.
227 | :param pca_dims: The number of components to use for PCA.
228 | If this is None, we don't apply PCA.
229 |         If this is 'auto', we don't reduce dimensionality, but still apply PCA.
230 | :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
231 | Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
232 | :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
233 |         Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
234 | :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
235 | :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
236 | :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
237 | :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
238 | :return: A StaticModel
239 |
240 | """
241 | model: PreTrainedModel = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code)
242 | tokenizer = cast(
243 | PreTrainedTokenizerFast,
244 | AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True),
245 | )
246 |
247 | return distill_from_model(
248 | model=model,
249 | tokenizer=tokenizer,
250 | vocabulary=vocabulary,
251 | device=device,
252 | pca_dims=pca_dims,
253 | apply_zipf=apply_zipf,
254 | token_remove_pattern=token_remove_pattern,
255 | sif_coefficient=sif_coefficient,
256 | quantize_to=quantize_to,
257 | use_subword=use_subword,
258 | )
259 |
--------------------------------------------------------------------------------
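The sketch below shows how `distill_from_model` (defined above) can be used directly when the transformer and tokenizer are already loaded, for example to add a custom vocabulary and quantize to int8. It assumes `distill_from_model` is importable from `model2vec.distill`; the model name, vocabulary, and output path are illustrative.

```python
from transformers import AutoModel, AutoTokenizer

from model2vec.distill import distill_from_model

model_name = "BAAI/bge-base-en-v1.5"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Add a few domain-specific words on top of the model's own subword vocabulary,
# and store the embeddings as int8 to shrink the model further.
m2v_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    vocabulary=["model2vec", "static embeddings"],
    pca_dims=256,
    quantize_to="int8",
)
m2v_model.save_pretrained("m2v-bge-int8")
```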
/model2vec/distill/inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import annotations
3 |
4 | import inspect
5 | import logging
6 | from pathlib import Path
7 | from typing import Literal, Protocol, Union
8 |
9 | import numpy as np
10 | import torch
11 | from sklearn.decomposition import PCA
12 | from torch.nn.utils.rnn import pad_sequence
13 | from tqdm import tqdm
14 | from transformers import PreTrainedModel
15 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | PathLike = Union[Path, str]
21 | PCADimType = Union[int, None, float, Literal["auto"]]
22 |
23 |
24 | _DEFAULT_BATCH_SIZE = 256
25 |
26 |
27 | class ModulewithWeights(Protocol):
28 | weight: torch.nn.Parameter
29 |
30 |
31 | def create_embeddings(
32 | model: PreTrainedModel,
33 | tokenized: list[list[int]],
34 | device: str,
35 | pad_token_id: int,
36 | ) -> np.ndarray:
37 | """
38 | Create output embeddings for a bunch of tokens using a pretrained model.
39 |
40 |     It does a forward pass for all token sequences passed in `tokenized`.
41 |
42 | :param model: The model to use.
43 | This should be a transformers model.
44 | :param tokenized: All tokenized tokens.
45 | :param device: The torch device to use.
46 | :param pad_token_id: The pad token id. Used to pad sequences.
47 | :return: The output embeddings.
48 | """
49 | model = model.to(device)
50 |
51 | out_weights: np.ndarray
52 | intermediate_weights: list[np.ndarray] = []
53 |
54 | # Add token_type_ids only if the model supports it
55 | add_token_type_ids = "token_type_ids" in inspect.getfullargspec(model.forward).args
56 |
57 | lengths = np.asarray([len(sequence) for sequence in tokenized])
58 | sort_order = np.argsort(lengths)
59 |
60 | sorted_tokenized = [tokenized[i] for i in sort_order]
61 |
62 | pbar = tqdm(total=len(sorted_tokenized), desc="Encoding tokens", unit=" tokens")
63 |
64 | for batch_idx in range(0, len(sorted_tokenized), _DEFAULT_BATCH_SIZE):
65 | batch = [torch.Tensor(x).long() for x in sorted_tokenized[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE]]
66 |
67 | encoded = {}
68 | encoded["input_ids"] = pad_sequence(batch, batch_first=True, padding_value=pad_token_id)
69 | encoded["attention_mask"] = encoded["input_ids"] != pad_token_id
70 |
71 | if add_token_type_ids:
72 | encoded["token_type_ids"] = torch.zeros_like(encoded["input_ids"])
73 |
74 | out = _encode_mean_using_model(model, encoded)
75 | intermediate_weights.extend(out.numpy())
76 | pbar.update(len(batch))
77 |
78 | # Sort the output back to the original order
79 | intermediate_weights = [intermediate_weights[i] for i in np.argsort(sort_order)]
80 | out_weights = np.stack(intermediate_weights)
81 |
82 | out_weights = np.nan_to_num(out_weights)
83 |
84 | return out_weights
85 |
86 |
87 | @torch.no_grad()
88 | def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.Tensor]) -> torch.Tensor:
89 | """
90 | Encode a batch of tokens using a model.
91 |
92 |     Note that if a token in the input batch does not have any embeddings, it comes out as a vector of zeros,
93 |     so downstream code needs to detect these.
94 |
95 | :param model: The model to use.
96 | :param encodings: The encoded tokens to turn into features.
97 | :return: The mean of the output for each token.
98 | """
99 | encodings = {k: v.to(model.device) for k, v in encodings.items()}
100 | encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
101 | out: torch.Tensor = encoded.last_hidden_state.cpu()
102 |     # NOTE: If the dtype is bfloat16, we convert to float32,
103 |     # because numpy does not support bfloat16.
104 | # See here: https://github.com/numpy/numpy/issues/19808
105 | if out.dtype == torch.bfloat16:
106 | out = out.float()
107 |
108 | # Take the mean by averaging over the attention mask.
109 | mask = encodings["attention_mask"].cpu().float()
110 | mask /= mask.sum(1)[:, None]
111 |
112 | result = torch.bmm(mask[:, None, :].float(), out).squeeze(1)
113 |
114 | return result
115 |
116 |
117 | def post_process_embeddings(
118 | embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
119 | ) -> np.ndarray:
120 | """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
121 | if pca_dims is not None:
122 | if pca_dims == "auto":
123 | pca_dims = embeddings.shape[1]
124 | if pca_dims > embeddings.shape[1]:
125 | logger.warning(
126 | f"PCA dimension ({pca_dims}) is larger than the number of dimensions in the embeddings ({embeddings.shape[1]}). "
127 |                 "Applying PCA, but not reducing dimensionality. If this is not desired, please set `pca_dims` to None. "
128 | "Applying PCA will probably improve performance, so consider just leaving it."
129 | )
130 | pca_dims = embeddings.shape[1]
131 | if pca_dims >= embeddings.shape[0]:
132 | logger.warning(
133 | f"PCA dimension ({pca_dims}) is larger than the number of tokens in the vocabulary ({embeddings.shape[0]}). Not applying PCA."
134 | )
135 | elif pca_dims <= embeddings.shape[1]:
136 | if isinstance(pca_dims, float):
137 | logger.info(f"Applying PCA with {pca_dims} explained variance.")
138 | else:
139 | logger.info(f"Applying PCA with n_components {pca_dims}")
140 |
141 | orig_dims = embeddings.shape[1]
142 | p = PCA(n_components=pca_dims, svd_solver="full")
143 | embeddings = p.fit_transform(embeddings)
144 |
145 | if embeddings.shape[1] < orig_dims:
146 | explained_variance_ratio = np.sum(p.explained_variance_ratio_)
147 | explained_variance = np.sum(p.explained_variance_)
148 | logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
149 | logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
150 | logger.info(f"Explained variance: {explained_variance:.3f}.")
151 |
152 | if sif_coefficient is not None:
153 | logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
154 | inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
155 | proba = inv_rank / np.sum(inv_rank)
156 | embeddings *= (sif_coefficient / (sif_coefficient + proba))[:, None]
157 |
158 | return embeddings
159 |
--------------------------------------------------------------------------------
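To make the SIF weighting in `post_process_embeddings` above concrete, here is a small numpy-only sketch of the same computation on a toy embedding matrix; as in the function, the rank order is assumed to follow the vocabulary order.

```python
import numpy as np

sif_coefficient = 1e-4
n_tokens, dim = 5, 4
embeddings = np.ones((n_tokens, dim))

# Zipf's law: token probability is proportional to 1 / rank (ranks start at 2, as above).
inv_rank = 1 / np.arange(2, n_tokens + 2)
proba = inv_rank / inv_rank.sum()

# SIF: frequent (low-rank) tokens get smaller weights.
weights = sif_coefficient / (sif_coefficient + proba)
print(weights.round(4))  # monotonically increasing: rarer tokens are weighted more heavily

weighted_embeddings = embeddings * weights[:, None]
```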
/model2vec/distill/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from logging import getLogger
4 |
5 | import torch
6 |
7 | logger = getLogger(__name__)
8 |
9 |
10 | def select_optimal_device(device: str | None) -> str:
11 | """
12 | Guess what your optimal device should be based on backend availability.
13 |
14 | If you pass a device, we just pass it through.
15 |
16 | :param device: The device to use. If this is not None you get back what you passed.
17 | :return: The selected device.
18 | """
19 | if device is None:
20 | if torch.cuda.is_available():
21 | device = "cuda"
22 | elif torch.backends.mps.is_available():
23 | device = "mps"
24 | else:
25 | device = "cpu"
26 | logger.info(f"Automatically selected device: {device}")
27 |
28 | return device
29 |
--------------------------------------------------------------------------------
/model2vec/hf_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import logging
5 | from pathlib import Path
6 | from typing import Any, cast
7 |
8 | import huggingface_hub
9 | import numpy as np
10 | import safetensors
11 | from huggingface_hub import ModelCard, ModelCardData
12 | from safetensors.numpy import save_file
13 | from tokenizers import Tokenizer
14 |
15 | from model2vec.utils import SafeOpenProtocol
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def save_pretrained(
21 | folder_path: Path,
22 | embeddings: np.ndarray,
23 | tokenizer: Tokenizer,
24 | config: dict[str, Any],
25 | create_model_card: bool = True,
26 | subfolder: str | None = None,
27 | **kwargs: Any,
28 | ) -> None:
29 | """
30 | Save a model to a folder.
31 |
32 | :param folder_path: The path to the folder.
33 | :param embeddings: The embeddings.
34 | :param tokenizer: The tokenizer.
35 | :param config: A metadata config.
36 | :param create_model_card: Whether to create a model card.
37 | :param subfolder: The subfolder to save the model in.
38 | :param **kwargs: Any additional arguments.
39 | """
40 | folder_path = folder_path / subfolder if subfolder else folder_path
41 | folder_path.mkdir(exist_ok=True, parents=True)
42 | save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
43 | tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
44 | json.dump(config, open(folder_path / "config.json", "w"), indent=4)
45 |
46 | # Create modules.json
47 | modules = [{"idx": 0, "name": "0", "path": ".", "type": "sentence_transformers.models.StaticEmbedding"}]
48 | if config.get("normalize"):
49 | # If normalize=True, add sentence_transformers.models.Normalize
50 | modules.append({"idx": 1, "name": "1", "path": "1_Normalize", "type": "sentence_transformers.models.Normalize"})
51 | json.dump(modules, open(folder_path / "modules.json", "w"), indent=4)
52 |
53 | logger.info(f"Saved model to {folder_path}")
54 |
55 | # Optionally create the model card
56 | if create_model_card:
57 | _create_model_card(folder_path, **kwargs)
58 |
59 |
60 | def _create_model_card(
61 | folder_path: Path,
62 | base_model_name: str = "unknown",
63 | license: str = "mit",
64 | language: list[str] | None = None,
65 | model_name: str | None = None,
66 | template_path: str = "modelcards/model_card_template.md",
67 | **kwargs: Any,
68 | ) -> None:
69 | """
70 | Create a model card and store it in the specified path.
71 |
72 | :param folder_path: The path where the model card will be stored.
73 | :param base_model_name: The name of the base model.
74 | :param license: The license to use.
75 | :param language: The language of the model.
76 | :param model_name: The name of the model to use in the Model Card.
77 | :param template_path: The path to the template.
78 | :param **kwargs: Additional metadata for the model card (e.g., model_name, base_model, etc.).
79 | """
80 | folder_path = Path(folder_path)
81 | model_name = model_name or folder_path.name
82 | full_path = Path(__file__).parent / template_path
83 |
84 | model_card_data = ModelCardData(
85 | model_name=model_name,
86 | base_model=base_model_name,
87 | license=license,
88 | language=language,
89 | tags=["embeddings", "static-embeddings", "sentence-transformers"],
90 | library_name="model2vec",
91 | **kwargs,
92 | )
93 | model_card = ModelCard.from_template(model_card_data, template_path=str(full_path))
94 | model_card.save(folder_path / "README.md")
95 |
96 |
97 | def load_pretrained(
98 | folder_or_repo_path: str | Path,
99 | subfolder: str | None = None,
100 | token: str | None = None,
101 | from_sentence_transformers: bool = False,
102 | ) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
103 | """
104 | Loads a pretrained model from a folder.
105 |
106 | :param folder_or_repo_path: The folder or repo path to load from.
107 | - If this is a local path, we will load from the local path.
108 | - If the local path is not found, we will attempt to load from the huggingface hub.
109 | :param subfolder: The subfolder to load from.
110 | :param token: The huggingface token to use.
111 | :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
112 | :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
113 | :return: The embeddings, tokenizer, config, and metadata.
114 |
115 | """
116 | if from_sentence_transformers:
117 | model_file = "0_StaticEmbedding/model.safetensors"
118 | tokenizer_file = "0_StaticEmbedding/tokenizer.json"
119 | config_name = "config_sentence_transformers.json"
120 | else:
121 | model_file = "model.safetensors"
122 | tokenizer_file = "tokenizer.json"
123 | config_name = "config.json"
124 |
125 | folder_or_repo_path = Path(folder_or_repo_path)
126 |
127 | local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128 |
129 | if local_folder.exists():
130 | embeddings_path = local_folder / model_file
131 | if not embeddings_path.exists():
132 | raise FileNotFoundError(f"Embeddings file does not exist in {local_folder}")
133 |
134 | config_path = local_folder / config_name
135 | if not config_path.exists():
136 | raise FileNotFoundError(f"Config file does not exist in {local_folder}")
137 |
138 | tokenizer_path = local_folder / tokenizer_file
139 | if not tokenizer_path.exists():
140 | raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}")
141 |
142 | # README is optional, so this is a bit finicky.
143 | readme_path = local_folder / "README.md"
144 | metadata = _get_metadata_from_readme(readme_path)
145 |
146 | else:
147 | logger.info("Folder does not exist locally, attempting to use huggingface hub.")
148 | embeddings_path = Path(
149 | huggingface_hub.hf_hub_download(
150 | folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
151 | )
152 | )
153 |
154 | try:
155 | readme_path = Path(
156 | huggingface_hub.hf_hub_download(
157 | folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
158 | )
159 | )
160 | metadata = _get_metadata_from_readme(Path(readme_path))
161 | except Exception as e:
162 | # NOTE: we don't want to raise an error here, since the README is optional.
163 |             logger.info(f"No README found in the model folder: {e}. No model card loaded.")
164 | metadata = {}
165 |
166 | config_path = Path(
167 | huggingface_hub.hf_hub_download(
168 | folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
169 | )
170 | )
171 | tokenizer_path = Path(
172 | huggingface_hub.hf_hub_download(
173 | folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
174 | )
175 | )
176 |
177 | opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
178 | if from_sentence_transformers:
179 | embeddings = opened_tensor_file.get_tensor("embedding.weight")
180 | else:
181 | embeddings = opened_tensor_file.get_tensor("embeddings")
182 |
183 | tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
184 | config = json.load(open(config_path))
185 |
186 | if len(tokenizer.get_vocab()) != len(embeddings):
187 | logger.warning(
188 | f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
189 | )
190 |
191 | return embeddings, tokenizer, config, metadata
192 |
193 |
194 | def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
195 | """Get metadata from a README file."""
196 | if not readme_path.exists():
197 | logger.info(f"README file not found in {readme_path}. No model card loaded.")
198 | return {}
199 | model_card = ModelCard.load(readme_path)
200 | data: dict[str, Any] = model_card.data.to_dict()
201 | if not data:
202 | logger.info("File README.md exists, but was empty. No model card loaded.")
203 | return data
204 |
205 |
206 | def push_folder_to_hub(
207 | folder_path: Path, subfolder: str | None, repo_id: str, private: bool, token: str | None
208 | ) -> None:
209 | """
210 | Push a model folder to the huggingface hub, including model card.
211 |
212 | :param folder_path: The path to the folder.
213 | :param subfolder: The subfolder to push to.
214 | If None, the folder will be pushed to the root of the repo.
215 | :param repo_id: The repo name.
216 | :param private: Whether the repo is private.
217 | :param token: The huggingface token.
218 | """
219 | if not huggingface_hub.repo_exists(repo_id=repo_id, token=token):
220 | huggingface_hub.create_repo(repo_id, token=token, private=private)
221 |
222 | # Push model card and all model files to the Hugging Face hub
223 | huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)
224 |
225 | logger.info(f"Pushed model to {repo_id}")
226 |
--------------------------------------------------------------------------------
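A minimal sketch of the lower-level load path above; most users would call `StaticModel.from_pretrained` instead, which builds on these helpers. The repo id is assumed to follow the standard model2vec layout (`model.safetensors`, `tokenizer.json`, `config.json`).

```python
from model2vec.hf_utils import load_pretrained

# Reads the files locally if the folder exists, otherwise downloads them from the Hugging Face Hub.
embeddings, tokenizer, config, metadata = load_pretrained("minishlab/potion-base-32M")
print(embeddings.shape, config.get("hidden_dim"))
```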
/model2vec/inference/README.md:
--------------------------------------------------------------------------------
1 | # Inference
2 |
3 | This subpackage mainly contains helper functions for inference with trained models that have been exported to `scikit-learn` compatible pipelines.
4 |
5 | If you're looking for information on how to train a model, see [here](../train/README.md).
6 |
7 | # Usage
8 |
9 | Let's assume you're using our [potion-edu classifier](https://huggingface.co/minishlab/potion-8m-edu-classifier).
10 |
11 | ```python
12 | from model2vec.inference import StaticModelPipeline
13 |
14 | classifier = StaticModelPipeline.from_pretrained("minishlab/potion-8m-edu-classifier")
15 | label = classifier.predict("Attitudes towards cattle in the Alps: a study in letting go.")
16 | ```
17 |
18 | This should just work.
19 |
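Beyond `predict`, the pipeline also exposes `predict_proba` and `evaluate` (see `model2vec/inference/model.py`). A quick sketch, with made-up texts and labels:

```python
# Class probabilities instead of hard labels
probas = classifier.predict_proba(["Attitudes towards cattle in the Alps: a study in letting go."])

# A scikit-learn classification report over a (hypothetical) labeled evaluation set
report = classifier.evaluate(["some text", "some other text"], ["label_a", "label_b"])
print(report)
```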
--------------------------------------------------------------------------------
/model2vec/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import get_package_extras, importable
2 |
3 | _REQUIRED_EXTRA = "inference"
4 |
5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
6 | importable(extra_dependency, _REQUIRED_EXTRA)
7 |
8 | from model2vec.inference.model import StaticModelPipeline, evaluate_single_or_multi_label
9 |
10 | __all__ = ["StaticModelPipeline", "evaluate_single_or_multi_label"]
11 |
--------------------------------------------------------------------------------
/model2vec/inference/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import re
4 | from pathlib import Path
5 | from tempfile import TemporaryDirectory
6 | from typing import Sequence, TypeVar
7 |
8 | import huggingface_hub
9 | import numpy as np
10 | import skops.io
11 | from sklearn.metrics import classification_report
12 | from sklearn.neural_network import MLPClassifier
13 | from sklearn.pipeline import Pipeline
14 | from sklearn.preprocessing import MultiLabelBinarizer
15 |
16 | from model2vec.hf_utils import _create_model_card
17 | from model2vec.model import PathLike, StaticModel
18 |
19 | _DEFAULT_TRUST_PATTERN = re.compile(r"sklearn\..+")
20 | _DEFAULT_MODEL_FILENAME = "pipeline.skops"
21 |
22 | LabelType = TypeVar("LabelType", list[str], list[list[str]])
23 |
24 |
25 | class StaticModelPipeline:
26 | def __init__(self, model: StaticModel, head: Pipeline) -> None:
27 | """Create a pipeline with a StaticModel encoder."""
28 | self.model = model
29 | self.head = head
30 | classifier = self.head[-1]
31 | # Check if the classifier is a multilabel classifier.
32 | # NOTE: this doesn't look robust, but it is.
33 | # Different classifiers, such as OVR wrappers, support multilabel output natively, so we
34 | # can just use predict.
35 | self.multilabel = False
36 | if isinstance(classifier, MLPClassifier):
37 | if classifier.out_activation_ == "logistic":
38 | self.multilabel = True
39 |
40 | @property
41 | def classes_(self) -> np.ndarray:
42 | """The classes of the classifier."""
43 | return self.head.classes_
44 |
45 | @classmethod
46 | def from_pretrained(
47 | cls: type[StaticModelPipeline], path: PathLike, token: str | None = None, trust_remote_code: bool = False
48 | ) -> StaticModelPipeline:
49 | """
50 |         Load a StaticModelPipeline from a local path or huggingface hub path.
51 |
52 | NOTE: if you load a private model from the huggingface hub, you need to pass a token.
53 |
54 | :param path: The path to the folder containing the pipeline, or a repository on the Hugging Face Hub
55 | :param token: The token to use to download the pipeline from the hub.
56 | :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `sklearn`.
57 | :return: The loaded pipeline.
58 | """
59 | model, head = _load_pipeline(path, token, trust_remote_code)
60 | model.embedding = np.nan_to_num(model.embedding)
61 |
62 | return cls(model, head)
63 |
64 | def save_pretrained(self, path: str) -> None:
65 | """Save the model to a folder."""
66 | save_pipeline(self, path)
67 |
68 | def push_to_hub(
69 | self, repo_id: str, subfolder: str | None = None, token: str | None = None, private: bool = False
70 | ) -> None:
71 | """
72 | Save a model to a folder, and then push that folder to the hf hub.
73 |
74 | :param repo_id: The id of the repository to push to.
75 | :param subfolder: The subfolder to push to.
76 | :param token: The token to use to push to the hub.
77 | :param private: Whether the repository should be private.
78 | """
79 | from model2vec.hf_utils import push_folder_to_hub
80 |
81 | with TemporaryDirectory() as temp_dir:
82 | save_pipeline(self, temp_dir)
83 | self.model.save_pretrained(temp_dir)
84 | push_folder_to_hub(Path(temp_dir), subfolder, repo_id, private, token)
85 |
86 | def _encode_and_coerce_to_2d(
87 | self,
88 | X: Sequence[str],
89 | show_progress_bar: bool,
90 | max_length: int | None,
91 | batch_size: int,
92 | use_multiprocessing: bool,
93 | multiprocessing_threshold: int,
94 | ) -> np.ndarray:
95 | """Encode the instances and coerce the output to a matrix."""
96 | encoded = self.model.encode(
97 | X,
98 | show_progress_bar=show_progress_bar,
99 | max_length=max_length,
100 | batch_size=batch_size,
101 | use_multiprocessing=use_multiprocessing,
102 | multiprocessing_threshold=multiprocessing_threshold,
103 | )
104 | if np.ndim(encoded) == 1:
105 | encoded = encoded[None, :]
106 |
107 | return encoded
108 |
109 | def predict(
110 | self,
111 | X: Sequence[str],
112 | show_progress_bar: bool = False,
113 | max_length: int | None = 512,
114 | batch_size: int = 1024,
115 | use_multiprocessing: bool = True,
116 | multiprocessing_threshold: int = 10_000,
117 | threshold: float = 0.5,
118 | ) -> np.ndarray:
119 | """
120 | Predict the labels of the input.
121 |
122 | :param X: The input data to predict. Can be a list of strings or a single string.
123 | :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
124 | :param max_length: The maximum length of the input sequences. Defaults to 512.
125 | :param batch_size: The batch size for prediction. Defaults to 1024.
126 | :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
127 | :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
128 | :param threshold: The threshold for multilabel classification. Defaults to 0.5. Ignored if not multilabel.
129 |         :return: The predicted labels, or an array of label arrays for multilabel heads.
130 | """
131 | encoded = self._encode_and_coerce_to_2d(
132 | X,
133 | show_progress_bar=show_progress_bar,
134 | max_length=max_length,
135 | batch_size=batch_size,
136 | use_multiprocessing=use_multiprocessing,
137 | multiprocessing_threshold=multiprocessing_threshold,
138 | )
139 |
140 | if self.multilabel:
141 | out_labels = []
142 | proba = self.head.predict_proba(encoded)
143 | for vector in proba:
144 | out_labels.append(self.classes_[vector > threshold])
145 | return np.asarray(out_labels, dtype=object)
146 |
147 | return self.head.predict(encoded)
148 |
149 | def predict_proba(
150 | self,
151 | X: Sequence[str],
152 | show_progress_bar: bool = False,
153 | max_length: int | None = 512,
154 | batch_size: int = 1024,
155 | use_multiprocessing: bool = True,
156 | multiprocessing_threshold: int = 10_000,
157 | ) -> np.ndarray:
158 | """
159 |         Predict the class probabilities of the input.
160 |
161 | :param X: The input data to predict. Can be a list of strings or a single string.
162 | :param show_progress_bar: Whether to display a progress bar during prediction. Defaults to False.
163 | :param max_length: The maximum length of the input sequences. Defaults to 512.
164 | :param batch_size: The batch size for prediction. Defaults to 1024.
165 | :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
166 | :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
167 |         :return: The predicted class probabilities.
168 | """
169 | encoded = self._encode_and_coerce_to_2d(
170 | X,
171 | show_progress_bar=show_progress_bar,
172 | max_length=max_length,
173 | batch_size=batch_size,
174 | use_multiprocessing=use_multiprocessing,
175 | multiprocessing_threshold=multiprocessing_threshold,
176 | )
177 |
178 | return self.head.predict_proba(encoded)
179 |
180 | def evaluate(
181 | self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
182 | ) -> str | dict[str, dict[str, float]]:
183 | """
184 | Evaluate the classifier on a given dataset using scikit-learn's classification report.
185 |
186 | :param X: The texts to predict on.
187 | :param y: The ground truth labels.
188 | :param batch_size: The batch size.
189 | :param threshold: The threshold for multilabel classification.
190 | :param output_dict: Whether to output the classification report as a dictionary.
191 | :return: A classification report.
192 | """
193 | predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
194 | report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)
195 |
196 | return report
197 |
198 |
199 | def _load_pipeline(
200 | folder_or_repo_path: PathLike, token: str | None = None, trust_remote_code: bool = False
201 | ) -> tuple[StaticModel, Pipeline]:
202 | """
203 | Load a model and an sklearn pipeline.
204 |
205 | This assumes the following files are present in the repo:
206 | - `pipeline.skops`: The head of the pipeline.
207 | - `config.json`: The configuration of the model.
208 | - `model.safetensors`: The weights of the model.
209 | - `tokenizer.json`: The tokenizer of the model.
210 |
211 | :param folder_or_repo_path: The path to the folder containing the pipeline.
212 | :param token: The token to use to download the pipeline from the hub. If this is None, you will only
213 | be able to load the pipeline from a local folder, public repository, or a repository that you have access to
214 | because you are logged in.
215 | :param trust_remote_code: Whether to trust the remote code. If this is False,
216 | we will only load components coming from `sklearn`. If this is True, we will load all components.
217 | If you set this to True, you are responsible for whatever happens.
218 | :return: The encoder model and the loaded head
219 | :raises FileNotFoundError: If the pipeline file does not exist in the folder.
220 | :raises ValueError: If an untrusted type is found in the pipeline, and `trust_remote_code` is False.
221 | """
222 | folder_or_repo_path = Path(folder_or_repo_path)
223 | model_filename = _DEFAULT_MODEL_FILENAME
224 | head_pipeline_path: str | Path
225 | if folder_or_repo_path.exists():
226 | head_pipeline_path = folder_or_repo_path / model_filename
227 | if not head_pipeline_path.exists():
228 | raise FileNotFoundError(f"Pipeline file does not exist in {folder_or_repo_path}")
229 | else:
230 | head_pipeline_path = huggingface_hub.hf_hub_download(
231 | folder_or_repo_path.as_posix(), model_filename, token=token
232 | )
233 |
234 | model = StaticModel.from_pretrained(folder_or_repo_path)
235 |
236 | unknown_types = skops.io.get_untrusted_types(file=head_pipeline_path)
237 | # If the user does not trust remote code, we should check that the unknown types are trusted.
238 | # By default, we trust everything coming from scikit-learn.
239 | if not trust_remote_code:
240 | for t in unknown_types:
241 | if not _DEFAULT_TRUST_PATTERN.match(t):
242 | raise ValueError(f"Untrusted type {t}.")
243 | head = skops.io.load(head_pipeline_path, trusted=unknown_types)
244 |
245 | return model, head
246 |
247 |
248 | def save_pipeline(pipeline: StaticModelPipeline, folder_path: str | Path) -> None:
249 | """
250 | Save a pipeline to a folder.
251 |
252 | :param pipeline: The pipeline to save.
253 | :param folder_path: The path to the folder to save the pipeline to.
254 | """
255 | folder_path = Path(folder_path)
256 | folder_path.mkdir(parents=True, exist_ok=True)
257 | model_filename = _DEFAULT_MODEL_FILENAME
258 | head_pipeline_path = folder_path / model_filename
259 | skops.io.dump(pipeline.head, head_pipeline_path)
260 | pipeline.model.save_pretrained(folder_path)
261 | base_model_name = pipeline.model.base_model_name
262 | if isinstance(base_model_name, list) and base_model_name:
263 | name = base_model_name[0]
264 | elif isinstance(base_model_name, str):
265 | name = base_model_name
266 | else:
267 | name = "unknown"
268 | _create_model_card(
269 | folder_path,
270 | base_model_name=name,
271 | language=pipeline.model.language,
272 | template_path="modelcards/classifier_template.md",
273 | )
274 |
275 |
276 | def _is_multi_label_shaped(y: LabelType) -> bool:
277 | """Check if the labels are in a multi-label shape."""
278 | return isinstance(y, (list, tuple)) and len(y) > 0 and isinstance(y[0], (list, tuple, set))
279 |
280 |
281 | def evaluate_single_or_multi_label(
282 | predictions: np.ndarray,
283 | y: LabelType,
284 | output_dict: bool = False,
285 | ) -> str | dict[str, dict[str, float]]:
286 | """
287 | Evaluate the classifier on a given dataset using scikit-learn's classification report.
288 |
289 | :param predictions: The predictions.
290 | :param y: The ground truth labels.
291 | :param output_dict: Whether to output the classification report as a dictionary.
292 | :return: A classification report.
293 | """
294 | if _is_multi_label_shaped(y):
295 | classes = sorted(set([label for labels in y for label in labels]))
296 | mlb = MultiLabelBinarizer(classes=classes)
297 | y = mlb.fit_transform(y)
298 | predictions = mlb.transform(predictions)
299 | elif isinstance(y[0], (str, int)):
300 | classes = sorted(set(y))
301 |
302 | report = classification_report(
303 | y,
304 | predictions,
305 | output_dict=output_dict,
306 | zero_division=0,
307 | )
308 |
309 | return report
310 |
--------------------------------------------------------------------------------
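A small sketch of `evaluate_single_or_multi_label` above, showing that the same function handles both flat labels and lists of labels per instance; the labels themselves are made up.

```python
import numpy as np

from model2vec.inference import evaluate_single_or_multi_label

# Single-label: flat sequences of labels.
y_true = ["positive", "negative", "positive"]
y_pred = np.asarray(["positive", "positive", "positive"])
print(evaluate_single_or_multi_label(predictions=y_pred, y=y_true))

# Multi-label: each instance has a list of gold labels; predictions are arrays of labels
# per instance, as returned by StaticModelPipeline.predict with a multilabel head.
y_true_multi = [["a", "b"], ["b"]]
y_pred_multi = np.asarray([np.asarray(["a", "b"]), np.asarray(["b"])], dtype=object)
print(evaluate_single_or_multi_label(predictions=y_pred_multi, y=y_true_multi))
```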
/model2vec/modelcards/classifier_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 | # {{ model_name }} Model Card
6 |
7 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a fine-tuned version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Model2Vec model. It also includes a classifier head on top.
8 |
9 | ## Installation
10 |
11 | Install model2vec using pip:
12 | ```
13 | pip install model2vec[inference]
14 | ```
15 |
16 | ## Usage
17 | Load this model using the `from_pretrained` method:
18 | ```python
19 | from model2vec.inference import StaticModelPipeline
20 |
21 | # Load a pretrained Model2Vec model
22 | model = StaticModelPipeline.from_pretrained("{{ model_name }}")
23 |
24 | # Predict labels
25 | predicted = model.predict(["Example sentence"])
26 | ```
27 |
28 | ## Additional Resources
29 |
30 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec)
31 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
32 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
33 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
34 | - [Website](https://minishlab.github.io/)
35 |
36 | ## Library Authors
37 |
38 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
39 |
40 | ## Citation
41 |
42 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
43 | ```
44 | @article{minishlab2024model2vec,
45 | author = {Tulkens, Stephan and {van Dongen}, Thomas},
46 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
47 | year = {2024},
48 | url = {https://github.com/MinishLab/model2vec}
49 | }
50 | ```
51 |
--------------------------------------------------------------------------------
/model2vec/modelcards/model_card_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 | # {{ model_name }} Model Card
6 |
7 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of {% if base_model %}the [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %}a{% endif %} Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers.
8 |
9 |
10 | ## Installation
11 |
12 | Install model2vec using pip:
13 | ```
14 | pip install model2vec
15 | ```
16 |
17 | ## Usage
18 |
19 | ### Using Model2Vec
20 |
21 | The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models.
22 |
23 | Load this model using the `from_pretrained` method:
24 | ```python
25 | from model2vec import StaticModel
26 |
27 | # Load a pretrained Model2Vec model
28 | model = StaticModel.from_pretrained("{{ model_name }}")
29 |
30 | # Compute text embeddings
31 | embeddings = model.encode(["Example sentence"])
32 | ```
33 |
34 | ### Using Sentence Transformers
35 |
36 | You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model:
37 |
38 | ```python
39 | from sentence_transformers import SentenceTransformer
40 |
41 | # Load a pretrained Sentence Transformer model
42 | model = SentenceTransformer("{{ model_name }}")
43 |
44 | # Compute text embeddings
45 | embeddings = model.encode(["Example sentence"])
46 | ```
47 |
48 | ### Distilling a Model2Vec model
49 |
50 | You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code:
51 |
52 | ```python
53 | from model2vec.distill import distill
54 |
55 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
56 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)
57 |
58 | # Save the model
59 | m2v_model.save_pretrained("m2v_model")
60 | ```
61 |
62 | ## How it works
63 |
64 | Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec.
65 |
66 | It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence.
67 |
68 | ## Additional Resources
69 |
70 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec)
71 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
72 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
73 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials)
74 | - [Website](https://minishlab.github.io/)
75 |
76 |
77 | ## Library Authors
78 |
79 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).
80 |
81 | ## Citation
82 |
83 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
84 | ```
85 | @article{minishlab2024model2vec,
86 | author = {Tulkens, Stephan and {van Dongen}, Thomas},
87 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
88 | year = {2024},
89 | url = {https://github.com/MinishLab/model2vec}
90 | }
91 | ```
92 |
--------------------------------------------------------------------------------
/model2vec/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/model2vec/py.typed
--------------------------------------------------------------------------------
/model2vec/quantization.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from enum import Enum
4 |
5 | import numpy as np
6 |
7 |
8 | class DType(str, Enum):
9 | Float16 = "float16"
10 | Float32 = "float32"
11 | Float64 = "float64"
12 | Int8 = "int8"
13 |
14 |
15 | def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
16 | """
17 | Quantize embeddings to a specified data type to reduce memory usage.
18 |
19 | :param embeddings: The embeddings to quantize, as a numpy array.
20 | :param quantize_to: The data type to quantize to.
21 | :return: The quantized embeddings.
22 | :raises ValueError: If the quantization type is not valid.
23 | """
24 | if quantize_to == DType.Float16:
25 | return embeddings.astype(np.float16)
26 | elif quantize_to == DType.Float32:
27 | return embeddings.astype(np.float32)
28 | elif quantize_to == DType.Float64:
29 | return embeddings.astype(np.float64)
30 | elif quantize_to == DType.Int8:
31 |         # Scale to the [-127, 127] range for int8.
32 |         # We use -127 to 127 (not -128) to keep the range symmetric.
33 | scale = np.max(np.abs(embeddings)) / 127.0
34 | quantized = np.round(embeddings / scale).astype(np.int8)
35 | return quantized
36 | else:
37 | raise ValueError("Not a valid enum member of DType.")
38 |
39 |
40 | def quantize_and_reduce_dim(
41 | embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None
42 | ) -> np.ndarray:
43 | """
44 | Quantize embeddings to a datatype and reduce dimensionality.
45 |
46 | :param embeddings: The embeddings to quantize and reduce, as a numpy array.
47 | :param quantize_to: The data type to quantize to. If None, no quantization is performed.
48 | :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
49 | :return: The quantized and reduced embeddings.
50 | :raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality.
51 | """
52 | if quantize_to is not None:
53 | quantize_to = DType(quantize_to)
54 | embeddings = quantize_embeddings(embeddings, quantize_to)
55 |
56 | if dimensionality is not None:
57 | if dimensionality > embeddings.shape[1]:
58 | raise ValueError(
59 | f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
60 | )
61 | embeddings = embeddings[:, :dimensionality]
62 |
63 | return embeddings
64 |
--------------------------------------------------------------------------------
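A worked example of the int8 path in `quantize_embeddings` above: the scale is the largest absolute value divided by 127, and every entry is rounded after dividing by that scale.

```python
import numpy as np

from model2vec.quantization import DType, quantize_embeddings

embeddings = np.asarray([[0.5, -1.27], [0.01, 1.0]], dtype=np.float32)

# scale = max(|x|) / 127 = 1.27 / 127 = 0.01
# 0.5 -> 50, -1.27 -> -127, 0.01 -> 1, 1.0 -> 100
quantized = quantize_embeddings(embeddings, DType.Int8)
print(quantized)  # int8 values in the symmetric range [-127, 127]
```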
/model2vec/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import importable
2 |
3 | importable("transformers", "tokenizer")
4 |
5 | from model2vec.tokenizer.tokenizer import (
6 | clean_and_create_vocabulary,
7 | create_tokenizer,
8 | replace_vocabulary,
9 | turn_tokens_into_ids,
10 | )
11 |
12 | __all__ = ["clean_and_create_vocabulary", "create_tokenizer", "turn_tokens_into_ids", "replace_vocabulary"]
13 |
--------------------------------------------------------------------------------
/model2vec/tokenizer/datamodels.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class Token:
6 | """A class to represent a token."""
7 |
8 | form: str
9 | # The normalized and pretokenized form of the token
10 | normalized_form: str
11 | # Whether the word is a continuing subword.
12 | is_subword: bool
13 | # Whether the token is internal to the model.
14 | is_internal: bool
15 |
--------------------------------------------------------------------------------
/model2vec/tokenizer/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | import numpy as np
6 |
7 |
8 | def process_tokenizer(
9 | tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
10 | ) -> dict[str, Any]:
11 |     """Process the tokenizer JSON, converting the vocabulary into a Unigram model."""
12 | if tokenizer_json["model"]["type"] == "Unigram":
13 | return _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
14 | tokenizer_json["model"]["type"] = "Unigram"
15 | tokenizer_json["model"]["unk_id"] = pre_tokenized_tokens.index(unk_token) if unk_token else None
16 |
17 | token_weights = np.asarray([_calculate_token_weight_for_unigram(token) for token in pre_tokenized_tokens])
18 | proba = (token_weights / np.sum(token_weights)).tolist()
19 | tokenizer_json["model"]["vocab"] = [(token, np.log(p)) for token, p in zip(pre_tokenized_tokens, proba)]
20 |
21 | return tokenizer_json
22 |
23 |
24 | def _process_unigram(
25 | tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
26 | ) -> dict[str, Any]:
27 | """Process the Unigram tokenizer JSON."""
28 | current_probas = dict(tokenizer_json["model"]["vocab"])
29 | avg_proba = sum(current_probas.values()) / len(current_probas)
30 | new_probas = [[word, current_probas.get(word, avg_proba)] for word in pre_tokenized_tokens]
31 | tokenizer_json["model"]["vocab"] = new_probas
32 |
33 | tokens, _ = zip(*tokenizer_json["model"]["vocab"])
34 | if unk_token is not None:
35 | tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token)
36 |
37 | return tokenizer_json
38 |
39 |
40 | def _calculate_token_weight_for_unigram(token: str) -> float:
41 | """Calculate the token weight for Unigram."""
42 | # Always prefer longer tokens.
43 | return len(token) + token.count("▁") + token.count("Ġ")
44 |
--------------------------------------------------------------------------------
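A small numpy sketch of the weighting used in `process_tokenizer` above when converting a vocabulary to a Unigram model: longer tokens (and tokens containing the `▁`/`Ġ` markers) get larger weights, so multiword units end up with higher log-probabilities and are preferred during tokenization. The tokens are illustrative.

```python
import numpy as np

tokens = ["▁new", "▁york", "▁new▁york"]

# Same weight function as _calculate_token_weight_for_unigram: length plus marker counts.
weights = np.asarray([len(t) + t.count("▁") + t.count("Ġ") for t in tokens])
proba = weights / weights.sum()
vocab = [(t, float(np.log(p))) for t, p in zip(tokens, proba)]
print(vocab)  # "▁new▁york" gets the highest (least negative) log-probability
```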
/model2vec/tokenizer/normalizer.py:
--------------------------------------------------------------------------------
1 | from string import punctuation
2 |
3 | from tokenizers import Regex, Tokenizer
4 | from tokenizers.normalizers import Normalizer, Replace, Sequence, Strip
5 |
6 |
7 | def replace_normalizer(
8 | tokenizer: Tokenizer,
9 | ) -> Tokenizer:
10 | """
11 | Replace the normalizer for the tokenizer.
12 |
13 | The new normalizer will replace punctuation with a space before and after the punctuation.
14 | It will also replace multiple spaces with a single space and strip the right side of the string.
15 | If the tokenizer already has a normalizer, it will be added to the new normalizer.
16 | If the tokenizer does not have a normalizer, a new normalizer will be created.
17 |
18 | :param tokenizer: The tokenizer to change.
19 | :return: The tokenizer with a replaced normalizer.
20 | """
21 | normalizer = tokenizer.normalizer
22 | new_normalizers = []
23 | for char in punctuation:
24 | new_normalizers.append(Replace(char, f" {char} "))
25 |
26 | new_normalizers.append(Replace(Regex(r"\s+"), " "))
27 | new_normalizers.append(Strip(right=True))
28 | if normalizer is None:
29 | normalizer = Sequence(new_normalizers) # type: ignore
30 | else:
31 | normalizer = Sequence([normalizer] + new_normalizers) # type: ignore
32 | tokenizer.normalizer = normalizer # type: ignore
33 |
34 | return tokenizer
35 |
--------------------------------------------------------------------------------
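A quick sketch of what `replace_normalizer` above does to input text. The toy `WordLevel` tokenizer only exists to have something to attach the normalizer to; its vocabulary is arbitrary.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel

from model2vec.tokenizer.normalizer import replace_normalizer

tokenizer = Tokenizer(WordLevel({"hello": 0, ",": 1, "world": 2, "[UNK]": 3}, unk_token="[UNK]"))
tokenizer = replace_normalizer(tokenizer)

print(tokenizer.normalizer.normalize_str("hello,world  "))
# -> "hello , world": punctuation is padded with spaces, whitespace runs are collapsed,
#    and the right side is stripped.
```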
/model2vec/tokenizer/pretokenizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from typing import Any
5 |
6 | from tokenizers import Tokenizer
7 |
8 | _FORBIDDEN_PRETOKENIZERS = (
9 |     "Whitespace",
10 | "WhitespaceSplit",
11 | "BertPreTokenizer",
12 | "CharDelimiterSplit",
13 | "Punctuation",
14 | "Split",
15 | "UnicodeScripts",
16 | )
17 | _BASIC_METASPACE = {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
18 |
19 |
20 | def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] | None:
21 | """Fixes a single pretokenizer to allow multiword units."""
22 | if pre_tokenizer["type"] in _FORBIDDEN_PRETOKENIZERS:
23 | return None
24 | if pre_tokenizer["type"] == "ByteLevel":
25 | pre_tokenizer["add_prefix_space"] = True
26 | pre_tokenizer["use_regex"] = False
27 | if pre_tokenizer["type"] == "Metaspace":
28 | pre_tokenizer["split"] = False
29 | pre_tokenizer["prepend_scheme"] = "always"
30 |
31 | return pre_tokenizer
32 |
33 |
34 | def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
35 |     """Replace the tokenizer's pretokenizer(s) so that multiword units are allowed."""
36 | tokenizer_json = json.loads(tokenizer.to_str())
37 | pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None)
38 |
39 | if pre_tokenizer_json is None:
40 | pre_tokenizer_json = _BASIC_METASPACE
41 |
42 | elif pre_tokenizer_json["type"] == "Sequence":
43 | new_pretokenizers = []
44 | for single_pretokenizer in pre_tokenizer_json["pretokenizers"]:
45 | new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
46 | if new_pretokenizer is not None:
47 | new_pretokenizers.append(new_pretokenizer)
48 |
49 | if new_pretokenizers:
50 | pre_tokenizer_json["pretokenizers"] = new_pretokenizers
51 | else:
52 | pre_tokenizer_json = _BASIC_METASPACE
53 |
54 | pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or _BASIC_METASPACE
55 | tokenizer_json["pre_tokenizer"] = pre_tokenizer_json
56 |
57 | return tokenizer.from_str(json.dumps(tokenizer_json))
58 |
--------------------------------------------------------------------------------
/model2vec/train/README.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
4 |
5 | We support both single and multi-label classification, which work seamlessly based on the labels you provide.
6 |
7 | # Installation
8 |
9 | To train, make sure you install the training extra:
10 |
11 | ```
12 | pip install model2vec[training]
13 | ```
14 |
15 | # Quickstart
16 |
17 | To train a model, simply initialize it using a `StaticModel`, or from a pre-trained model, as follows:
18 |
19 | ```python
20 | from model2vec.distill import distill
21 | from model2vec.train import StaticModelForClassification
22 |
23 | # From a distilled model
24 | distilled_model = distill("baai/bge-base-en-v1.5")
25 | classifier = StaticModelForClassification.from_static_model(model=distilled_model)
26 |
27 | # From a pre-trained model: potion is the default
28 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
29 | ```
30 |
31 | This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units via parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.
32 |
33 | Now that you have created the classifier, let's just train a model. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
34 |
35 | ```python
36 | from time import perf_counter
37 | from datasets import load_dataset
38 |
39 | # Load the subj dataset
40 | ds = load_dataset("setfit/subj")
41 | train = ds["train"]
42 | test = ds["test"]
43 |
44 | s = perf_counter()
45 | classifier = classifier.fit(train["text"], train["label"])
46 |
47 | print(f"Training took {int(perf_counter() - s)} seconds.")
48 | # Training took 81 seconds
49 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
50 | print(classification_report)
51 | # Achieved 91.0 test accuracy
52 | ```
53 |
54 | As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
55 |
56 | The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default, the data is split into a train and validation set (90% train, 10% validation), and training runs with early stopping on validation set accuracy, with a patience of 5.
57 |
58 | Note that this model is as fast as you're used to from us:
59 |
60 | ```python
61 | from time import perf_counter
62 |
63 | s = perf_counter()
64 | classifier.predict(test["text"])
65 | print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
66 | # Took 67 milliseconds for 2000 instances on CPU.
67 | ```
68 |
69 | ## Multi-label classification
70 |
71 | Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
72 |
73 | ```python
74 | from datasets import load_dataset
75 | from model2vec.train import StaticModelForClassification
76 |
77 | # Initialize a classifier from a pre-trained model
78 | classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
79 |
80 | # Load a multi-label dataset
81 | ds = load_dataset("google-research-datasets/go_emotions")
82 |
83 | # Inspect some of the labels
84 | print(ds["train"]["labels"][40:50])
85 | # [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
86 |
87 | # Train the classifier on text (X) and labels (y)
88 | classifier.fit(ds["train"]["text"], ds["train"]["labels"])
89 | ```
90 |
91 | Then, we can evaluate the classifier:
92 |
93 | ```python
94 | from sklearn import metrics
95 | from sklearn.preprocessing import MultiLabelBinarizer
96 |
97 | classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
98 | print(classification_report)
99 | # Accuracy: 0.410
100 | # Precision: 0.527
101 | # Recall: 0.410
102 | # F1: 0.439
103 | ```
104 |
105 | The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
106 |
107 | # Persistence
108 |
109 | You can turn a classifier into a scikit-learn compatible pipeline, as follows:
110 |
111 | ```python
112 | pipeline = classifier.to_pipeline()
113 | ```
114 |
115 | This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/), as shown below. This makes it easy to use your model in inference pipelines (no need to install torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
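
For example, here is a minimal sketch of persisting and restoring the pipeline with `joblib` (the file name is arbitrary):

```python
from joblib import dump, load

# Persist the scikit-learn compatible pipeline to disk
dump(pipeline, "classifier_pipeline.joblib")

# Later, restore it and predict without needing torch installed
pipeline = load("classifier_pipeline.joblib")
print(pipeline.predict(["this text is about sports"]))
```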
116 |
117 | If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
118 |
119 | ```python
120 | pipeline.save_pretrained(path)
121 | pipeline.push_to_hub("my_cool/project")
122 | ```
123 |
124 | Later, you can load these as follows:
125 |
126 | ```python
127 | from model2vec.inference import StaticModelPipeline
128 |
129 | pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
130 | ```
131 |
132 | Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
133 |
134 |
135 | # Bring your own architecture
136 |
137 | Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.
138 |
139 | The core functionality of `StaticModelForClassification` is contained in a handful of methods:
140 |
141 | * `construct_head`: constructs the classifier head on top of the `StaticModel`. For example, if you want a model with LayerNorm, just subclass and replace this function (see the sketch after this list). This is the main function to override if you want to change model behavior.
142 | * `train_test_split`: governs the train/validation split performed before training.
143 | * `prepare_dataset`: selects the `torch.utils.data.Dataset` that will be used in the `DataLoader` during training.
144 | * `_encode`: the encoding function used by the model.
145 | * `fit`: contains all the lightning-related fitting logic.
146 |
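As an illustration, here is a minimal sketch of a subclass that swaps in a head with LayerNorm, following the default `construct_head` defined in `FinetunableStaticModel`; the hidden size of 256 is arbitrary:

```python
from torch import nn

from model2vec.train import StaticModelForClassification


class LayerNormClassifier(StaticModelForClassification):
    def construct_head(self) -> nn.Sequential:
        """Classification head with a LayerNorm-regularized hidden layer."""
        return nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, self.out_dim),
        )
```

Everything else (training, prediction, persistence) is inherited unchanged, so you can instantiate it with `LayerNormClassifier.from_pretrained(model_name="minishlab/potion-base-32M")` just like the base class.
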
147 | The training of the model is done in a `lightning.LightningModule`, which can be modified but is intentionally very basic.
148 |
149 | # Results
150 |
151 | We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.
152 |
--------------------------------------------------------------------------------
/model2vec/train/__init__.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import get_package_extras, importable
2 |
3 | _REQUIRED_EXTRA = "train"
4 |
5 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
6 | importable(extra_dependency, _REQUIRED_EXTRA)
7 |
8 | from model2vec.train.classifier import StaticModelForClassification
9 |
10 | __all__ = ["StaticModelForClassification"]
11 |
--------------------------------------------------------------------------------
/model2vec/train/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from typing import Any, TypeVar
5 |
6 | import numpy as np
7 | import torch
8 | from tokenizers import Encoding, Tokenizer
9 | from torch import nn
10 | from torch.nn.utils.rnn import pad_sequence
11 | from torch.utils.data import DataLoader, Dataset
12 |
13 | from model2vec import StaticModel
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class FinetunableStaticModel(nn.Module):
19 | def __init__(self, *, vectors: torch.Tensor, tokenizer: Tokenizer, out_dim: int = 2, pad_id: int = 0) -> None:
20 | """
21 | Initialize a trainable StaticModel from a StaticModel.
22 |
23 | :param vectors: The embeddings of the staticmodel.
24 | :param tokenizer: The tokenizer.
25 | :param out_dim: The output dimension of the head.
26 | :param pad_id: The padding id. This is set to 0 in almost all model2vec models
27 | """
28 | super().__init__()
29 | self.pad_id = pad_id
30 | self.out_dim = out_dim
31 | self.embed_dim = vectors.shape[1]
32 |
33 | self.vectors = vectors
34 | if self.vectors.dtype != torch.float32:
35 | dtype = str(self.vectors.dtype)
36 | logger.warning(
37 | f"Your vectors are {dtype} precision, converting to to torch.float32 to avoid compatibility issues."
38 | )
39 | self.vectors = vectors.float()
40 |
41 | self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
42 | self.head = self.construct_head()
43 | self.w = self.construct_weights()
44 | self.tokenizer = tokenizer
45 |
46 | def construct_weights(self) -> nn.Parameter:
47 | """Construct the weights for the model."""
48 | weights = torch.zeros(len(self.vectors))
49 | weights[self.pad_id] = -10_000
50 | return nn.Parameter(weights)
51 |
52 | def construct_head(self) -> nn.Sequential:
53 | """Method should be overridden for various other classes."""
54 | return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
55 |
56 | @classmethod
57 | def from_pretrained(
58 | cls: type[ModelType], *, out_dim: int = 2, model_name: str = "minishlab/potion-base-32m", **kwargs: Any
59 | ) -> ModelType:
60 | """Load the model from a pretrained model2vec model."""
61 | model = StaticModel.from_pretrained(model_name)
62 | return cls.from_static_model(model=model, out_dim=out_dim, **kwargs)
63 |
64 | @classmethod
65 | def from_static_model(cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, **kwargs: Any) -> ModelType:
66 | """Load the model from a static model."""
67 | model.embedding = np.nan_to_num(model.embedding)
68 | embeddings_converted = torch.from_numpy(model.embedding)
69 | return cls(
70 | vectors=embeddings_converted,
71 | pad_id=model.tokenizer.token_to_id("[PAD]"),
72 | out_dim=out_dim,
73 | tokenizer=model.tokenizer,
74 | **kwargs,
75 | )
76 |
77 | def _encode(self, input_ids: torch.Tensor) -> torch.Tensor:
78 | """
79 | A forward pass and mean pooling.
80 |
81 | This function is analogous to `StaticModel.encode`, but reimplemented to allow gradients
82 | to pass through.
83 |
84 |         :param input_ids: A 2D tensor of input ids. All input ids have to be within bounds.
85 | :return: The mean over the input ids, weighted by token weights.
86 | """
87 | w = self.w[input_ids]
88 | w = torch.sigmoid(w)
89 | zeros = (input_ids != self.pad_id).float()
90 | w = w * zeros
91 | # Add a small epsilon to avoid division by zero
92 | length = zeros.sum(1) + 1e-16
93 | embedded = self.embeddings(input_ids)
94 |         # Weighted sum over the token embeddings
95 | embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
96 | # Mean pooling by dividing by the length
97 | embedded = embedded / length[:, None]
98 |
99 | return nn.functional.normalize(embedded)
100 |
101 | def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
102 | """Forward pass through the mean, and a classifier layer after."""
103 | encoded = self._encode(input_ids)
104 | return self.head(encoded), encoded
105 |
106 | def tokenize(self, texts: list[str], max_length: int | None = 512) -> torch.Tensor:
107 | """
108 | Tokenize a bunch of strings into a single padded 2D tensor.
109 |
110 | Note that this is not used during training.
111 |
112 | :param texts: The texts to tokenize.
113 |         :param max_length: The maximum sequence length. If None, sequences are not truncated (defaults to 512).
114 | :return: A 2D padded tensor
115 | """
116 | encoded: list[Encoding] = self.tokenizer.encode_batch_fast(texts, add_special_tokens=False)
117 | encoded_ids: list[torch.Tensor] = [torch.Tensor(encoding.ids[:max_length]).long() for encoding in encoded]
118 | return pad_sequence(encoded_ids, batch_first=True, padding_value=self.pad_id)
119 |
120 | @property
121 |     def device(self) -> torch.device:
122 | """Get the device of the model."""
123 | return self.embeddings.weight.device
124 |
125 | def to_static_model(self) -> StaticModel:
126 | """Convert the model to a static model."""
127 | emb = self.embeddings.weight.detach().cpu().numpy()
128 | w = torch.sigmoid(self.w).detach().cpu().numpy()
129 |
130 | return StaticModel(emb * w[:, None], self.tokenizer, normalize=True)
131 |
132 |
133 | class TextDataset(Dataset):
134 | def __init__(self, tokenized_texts: list[list[int]], targets: torch.Tensor) -> None:
135 | """
136 | A dataset of texts.
137 |
138 | :param tokenized_texts: The tokenized texts. Each text is a list of token ids.
139 | :param targets: The targets.
140 | :raises ValueError: If the number of labels does not match the number of texts.
141 | """
142 | if len(targets) != len(tokenized_texts):
143 | raise ValueError("Number of labels does not match number of texts.")
144 | self.tokenized_texts = tokenized_texts
145 | self.targets = targets
146 |
147 | def __len__(self) -> int:
148 | """Return the length of the dataset."""
149 | return len(self.tokenized_texts)
150 |
151 | def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
152 | """Gets an item."""
153 | return self.tokenized_texts[index], self.targets[index]
154 |
155 | @staticmethod
156 | def collate_fn(batch: list[tuple[list[list[int]], int]]) -> tuple[torch.Tensor, torch.Tensor]:
157 | """Collate function."""
158 | texts, targets = zip(*batch)
159 |
160 | tensors = [torch.LongTensor(x) for x in texts]
161 | padded = pad_sequence(tensors, batch_first=True, padding_value=0)
162 |
163 | return padded, torch.stack(targets)
164 |
165 | def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
166 | """Convert the dataset to a DataLoader."""
167 | return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
168 |
169 |
170 | ModelType = TypeVar("ModelType", bound=FinetunableStaticModel)
171 |
--------------------------------------------------------------------------------
/model2vec/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import annotations
3 |
4 | import json
5 | import logging
6 | import re
7 | from importlib import import_module
8 | from importlib.metadata import metadata
9 | from pathlib import Path
10 | from typing import Any, Iterator, Protocol, cast
11 |
12 | import numpy as np
13 | import safetensors
14 | from joblib import Parallel
15 | from tokenizers import Tokenizer
16 | from tqdm import tqdm
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class ProgressParallel(Parallel):
22 | """A drop-in replacement for joblib.Parallel that shows a tqdm progress bar."""
23 |
24 | def __init__(self, use_tqdm: bool = True, total: int | None = None, *args: Any, **kwargs: Any) -> None:
25 | """
26 | Initialize the ProgressParallel object.
27 |
28 | :param use_tqdm: Whether to show the progress bar.
29 | :param total: Total number of tasks (batches) you expect to process. If None,
30 | it updates the total dynamically to the number of dispatched tasks.
31 | :param *args: Additional arguments to pass to `Parallel.__init__`.
32 | :param **kwargs: Additional keyword arguments to pass to `Parallel.__init__`.
33 | """
34 | self._use_tqdm = use_tqdm
35 | self._total = total
36 | super().__init__(*args, **kwargs)
37 |
38 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
39 | """Create a tqdm context."""
40 | with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
41 |             # The progress bar is stored on `self` so that `print_progress` can update it
42 | return super().__call__(*args, **kwargs)
43 |
44 | def print_progress(self) -> None:
45 | """Hook called by joblib as tasks complete. We update the tqdm bar here."""
46 | if self._total is None:
47 | # If no fixed total was given, we dynamically set the total
48 | self._pbar.total = self.n_dispatched_tasks
49 | # Move the bar to the number of completed tasks
50 | self._pbar.n = self.n_completed_tasks
51 | self._pbar.refresh()
52 |
53 |
54 | class SafeOpenProtocol(Protocol):
55 | """Protocol to fix safetensors safe open."""
56 |
57 | def get_tensor(self, key: str) -> np.ndarray:
58 | """Get a tensor."""
59 | ... # pragma: no cover
60 |
61 |
62 | _MODULE_MAP = (("scikit-learn", "sklearn"),)
63 | _DIVIDERS = re.compile(r"[=<>!]+")
64 |
65 |
66 | def get_package_extras(package: str, extra: str) -> Iterator[str]:
67 | """Get the extras of the package."""
68 | try:
69 | message = metadata(package)
70 | except Exception as e:
71 | raise ImportError(f"Could not retrieve metadata for package '{package}': {e}")
72 |
73 | all_packages = message.get_all("Requires-Dist") or []
74 | for package in all_packages:
75 | name, *rest = package.split(";", maxsplit=1)
76 | if rest:
77 | # Extract and clean the extra requirement
78 | found_extra = rest[0].split("==")[-1].strip(" \"'")
79 | if found_extra == extra:
80 | prefix, *_ = _DIVIDERS.split(name)
81 | yield prefix.strip()
82 |
83 |
84 | def importable(module: str, extra: str) -> None:
85 | """Check if a module is importable."""
86 | module = dict(_MODULE_MAP).get(module, module)
87 | try:
88 | import_module(module)
89 | except ImportError:
90 | raise ImportError(
91 | f"`{module}`, is required. Please reinstall model2vec with the `{extra}` extra. `pip install model2vec[{extra}]`"
92 | )
93 |
94 |
95 | def setup_logging() -> None:
96 | """Simple logging setup."""
97 | from rich.logging import RichHandler
98 |
99 | logging.basicConfig(
100 | level="INFO",
101 | format="%(name)s - %(message)s",
102 | datefmt="%Y-%m-%d %H:%M:%S",
103 | handlers=[RichHandler(rich_tracebacks=True)],
104 | )
105 |
106 |
107 | def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
108 | """Load a local model."""
109 | embeddings_path = folder / "model.safetensors"
110 | tokenizer_path = folder / "tokenizer.json"
111 | config_path = folder / "config.json"
112 |
113 | opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
114 | embeddings = opened_tensor_file.get_tensor("embeddings")
115 |
116 | if config_path.exists():
117 | config = json.load(open(config_path))
118 | else:
119 | config = {}
120 |
121 | tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
122 |
123 | if len(tokenizer.get_vocab()) != len(embeddings):
124 | logger.warning(
125 | f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
126 | )
127 |
128 | return embeddings, tokenizer, config
129 |
--------------------------------------------------------------------------------
/model2vec/version.py:
--------------------------------------------------------------------------------
1 | __version_triple__ = (0, 6, 0)
2 | __version__ = ".".join(map(str, __version_triple__))
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "model2vec"
3 | description = "Fast State-of-the-Art Static Embeddings"
4 | readme = { file = "README.md", content-type = "text/markdown" }
5 | license = { file = "LICENSE" }
6 | requires-python = ">=3.9"
7 | authors = [{ name = "Stéphan Tulkens", email = "stephantul@gmail.com"}, {name = "Thomas van Dongen", email = "thomas123@live.nl"}]
8 | dynamic = ["version"]
9 |
10 | classifiers = [
11 | "Development Status :: 4 - Beta",
12 | "Intended Audience :: Developers",
13 | "Intended Audience :: Science/Research",
14 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
15 | "Topic :: Software Development :: Libraries",
16 | "License :: OSI Approved :: MIT License",
17 | "Programming Language :: Python :: 3 :: Only",
18 | "Programming Language :: Python :: 3.9",
19 | "Programming Language :: Python :: 3.10",
20 | "Programming Language :: Python :: 3.11",
21 | "Programming Language :: Python :: 3.12",
22 | "Natural Language :: English",
23 | ]
24 |
25 | dependencies = [
26 | "jinja2",
27 | "joblib",
28 | "numpy",
29 | "rich",
30 | "safetensors",
31 | "setuptools",
32 | "tokenizers>=0.20",
33 | "tqdm",
34 | ]
35 |
36 | [build-system]
37 | requires = ["setuptools>=64", "setuptools_scm>=8"]
38 | build-backend = "setuptools.build_meta"
39 |
40 | [tool.setuptools]
41 | packages = ["model2vec"]
42 | include-package-data = true
43 |
44 | [tool.setuptools.package-data]
45 | model2vec = [
46 | "assets/modelcards/model_card_template.md",
47 | "assets/modelcards/classifier_template.md",
48 | "py.typed"
49 | ]
50 |
51 | [project.optional-dependencies]
52 | dev = [
53 | "black",
54 | "ipython",
55 | "mypy",
56 | "pre-commit",
57 | "pytest",
58 | "pytest-cov",
59 | "ruff",
60 | ]
61 |
62 | distill = ["torch", "transformers<=4.52.1", "scikit-learn"]
63 | onnx = ["onnx", "torch"]
64 | # train also installs inference
65 | train = ["torch", "lightning", "scikit-learn", "skops"]
66 | inference = ["scikit-learn", "skops"]
67 | tokenizer = ["transformers"]
68 |
69 | [project.urls]
70 | "Homepage" = "https://github.com/MinishLab"
71 | "Bug Reports" = "https://github.com/MinishLab/model2vec/issues"
72 | "Source" = "https://github.com/MinishLab/model2vec"
73 |
74 | [tool.ruff]
75 | exclude = [".venv/"]
76 | line-length = 120
77 | target-version = "py310"
78 |
79 | [tool.ruff.lint]
80 | select = [
81 | # Annotations: Enforce type annotations
82 | "ANN",
83 | # Complexity: Enforce a maximum cyclomatic complexity
84 | "C90",
85 | # Pydocstyle: Enforce docstrings
86 | "D",
87 | # Isort: Enforce import order
88 | "I",
89 | # Numpy: Enforce numpy style
90 | "NPY",
91 | # Print: Forbid print statements
92 | "T20",
93 | ]
94 |
95 | ignore = [
96 | # Allow self and cls to be untyped, and allow Any type
97 | "ANN101", "ANN102", "ANN401",
98 | # Pydocstyle ignores
99 | "D100", "D101", "D104", "D203", "D212", "D401",
100 | # Allow use of f-strings in logging
101 | "G004"
102 | ]
103 |
104 | [tool.pydoclint]
105 | style = "sphinx"
106 | exclude = "test_"
107 | allow-init-docstring = true
108 | arg-type-hints-in-docstring = false
109 | check-return-types = false
110 | require-return-section-when-returning-nothing = false
111 |
112 | [tool.mypy]
113 | python_version = "3.10"
114 | warn_unused_configs = true
115 | ignore_missing_imports = true
116 |
117 | [tool.setuptools_scm]
118 | # can be empty if no extra settings are needed, presence enables setuptools_scm
119 |
120 | [tool.setuptools.dynamic]
121 | version = {attr = "model2vec.version.__version__"}
122 |
--------------------------------------------------------------------------------
/results/make_speed_vs_mteb_plot.py:
--------------------------------------------------------------------------------
1 | """Script to benchmark the speed of various text embedding models and generate a plot of the MTEB score vs samples per second."""
2 |
3 | import argparse
4 | import json
5 | import logging
6 | from pathlib import Path
7 | from time import perf_counter
8 | from typing import Any
9 |
10 | import numpy as np
11 | import pandas as pd
12 | from bpemb import BPEmb
13 | from datasets import load_dataset
14 | from plotnine import (
15 | aes,
16 | element_line,
17 | geom_point,
18 | geom_text,
19 | ggplot,
20 | guides,
21 | labs,
22 | scale_size,
23 | scale_y_continuous,
24 | theme,
25 | theme_classic,
26 | xlim,
27 | ylim,
28 | )
29 | from sentence_transformers import SentenceTransformer
30 |
31 | from model2vec import StaticModel
32 |
33 | logging.basicConfig(level=logging.INFO)
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 |
38 | class BPEmbEmbedder:
39 | def __init__(self, vs: int = 50_000, dim: int = 300) -> None:
40 | """Initialize the BPEmbEmbedder."""
41 | self.bpemb_en = BPEmb(lang="en", vs=vs, dim=dim)
42 |
43 | def mean_sentence_embedding(self, sentence: str) -> np.ndarray:
44 | """Encode a sentence to a mean embedding."""
45 | encoded_ids = self.bpemb_en.encode_ids(sentence)
46 | embeddings = self.bpemb_en.vectors[encoded_ids]
47 | if embeddings.size == 0:
48 | return np.zeros(self.bpemb_en.dim) # Return a zero vector if no tokens are found
49 | return embeddings.mean(axis=0)
50 |
51 | def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
52 | """Encode a list of sentences to embeddings."""
53 | return np.array([self.mean_sentence_embedding(sentence.lower()) for sentence in sentences])
54 |
55 |
56 | def make_plot(df: pd.DataFrame) -> ggplot:
57 | """Create a plot of the MTEB score vs samples per second."""
58 | df["label_y"] = (
59 | df["Average score"]
60 | + 0.5 # a constant "base" offset for all bubbles
61 | + 0.08 * np.sqrt(df["Params (Million)"])
62 | )
63 | plot = (
64 | ggplot(df, aes(x="Samples per second", y="Average score"))
65 | + geom_point(aes(size="Params (Million)", color="Model"))
66 | + geom_text(aes(y="label_y", label="Model"), color="black", size=7)
67 | + scale_size(range=(2, 30))
68 | + theme_classic()
69 | + labs(title="Average MTEB Score vs Samples per Second")
70 | + ylim(df["Average score"].min(), df["Average score"].max() + 3)
71 | + scale_y_continuous(breaks=range(30, 70, 5))
72 | + theme(
73 | panel_grid_major=element_line(color="lightgrey", size=0.5),
74 | panel_grid_minor=element_line(color="lightgrey", size=0.25),
75 | figure_size=(10, 6),
76 | )
77 | + xlim(0, df["Samples per second"].max() + 100)
78 | + guides(None)
79 | )
80 | return plot
81 |
82 |
83 | def benchmark_model(name: str, info: list[str], texts: list[str]) -> dict[str, float | str]:
84 | """Benchmark a single model."""
85 | logger.info("Starting", name)
86 | if info[1] == "BPEmb":
87 | model = BPEmbEmbedder(vs=50_000, dim=300) # type: ignore
88 | elif info[1] == "ST":
89 | model = SentenceTransformer(info[0], device="cpu") # type: ignore
90 | else:
91 | model = StaticModel.from_pretrained(info[0]) # type: ignore
92 |
93 | start = perf_counter()
94 | if info[1] == "M2V":
95 | # If the model is a model2vec model, disable multiprocessing for a fair comparison
96 | model.encode(texts, use_multiprocessing=False)
97 | else:
98 | model.encode(texts)
99 |
100 | total_time = perf_counter() - start
101 | docs_per_second = len(texts) / total_time
102 |
103 | logger.info(f"{name}: {docs_per_second} docs per second")
104 | logger.info(f"Total time: {total_time}")
105 |
106 | return {"docs_per_second": docs_per_second, "total_time": total_time}
107 |
108 |
109 | def main(save_path: str, n_texts: int) -> None:
110 | """Benchmark text embedding models and generate a plot."""
111 | # Define the models to benchmark
112 | models: dict[str, list[str]] = {
113 | "BPEmb-50k-300d": ["", "BPEmb"],
114 | "all-MiniLM-L6-v2": ["sentence-transformers/all-MiniLM-L6-v2", "ST"],
115 | "bge-base-en-v1.5": ["BAAI/bge-base-en-v1.5", "ST"],
116 | "GloVe 6B 300d": ["sentence-transformers/average_word_embeddings_glove.6B.300d", "ST"],
117 | "potion-base-8M": ["minishlab/potion-base-8M", "M2V"],
118 | }
119 |
120 | # Load the dataset
121 | ds = load_dataset("wikimedia/wikipedia", data_files="20231101.en/train-00000-of-00041.parquet")["train"]
122 | texts = ds["text"][:n_texts]
123 |
124 | summarized_results = [
125 | {"Model": "potion-base-2M", "Average score": 44.77, "Samples per second": None, "Params (Million)": 1.875},
126 | {"Model": "GloVe 6B 300d", "Average score": 42.36, "Samples per second": None, "Params (Million)": 120.000},
127 | {"Model": "potion-base-4M", "Average score": 48.23, "Samples per second": None, "Params (Million)": 3.750},
128 | {"Model": "all-MiniLM-L6-v2", "Average score": 56.09, "Samples per second": None, "Params (Million)": 23.000},
129 | {"Model": "potion-base-8M", "Average score": 50.03, "Samples per second": None, "Params (Million)": 7.500},
130 | {"Model": "bge-base-en-v1.5", "Average score": 63.56, "Samples per second": None, "Params (Million)": 109.000},
131 | {"Model": "M2V_base_output", "Average score": 45.34, "Samples per second": None, "Params (Million)": 7.500},
132 | {"Model": "BPEmb-50k-300d", "Average score": 37.78, "Samples per second": None, "Params (Million)": 15.000},
133 | {"Model": "potion-base-32M", "Average score": 51.66, "Samples per second": None, "Params (Million)": 32.300},
134 | ]
135 |
136 | timings = {}
137 |
138 | for name, info in models.items():
139 | timing = benchmark_model(name, info, texts)
140 | timings[name] = timing
141 | # Update summarized results
142 | for result in summarized_results:
143 | if result["Model"] == name:
144 | result["Samples per second"] = timing["docs_per_second"]
145 |
146 | # Set potion-base-8M as the reference speed for the other potion models
147 | potion_base_8m_speed = next(
148 | result["Samples per second"] for result in summarized_results if result["Model"] == "potion-base-8M"
149 | )
150 | for model_name in ["M2V_base_output", "potion-base-2M", "potion-base-4M", "potion-base-32M"]:
151 | for result in summarized_results:
152 | if result["Model"] == model_name:
153 | result["Samples per second"] = potion_base_8m_speed
154 |
155 | # Ensure save_path is a directory
156 | save_dir = Path(save_path)
157 | save_dir.mkdir(parents=True, exist_ok=True)
158 |
159 | # Save timings to JSON
160 | json_path = save_dir / "speed_benchmark_results.json"
161 | with open(json_path, "w") as file:
162 | json.dump(timings, file, indent=4)
163 |
164 | # Create and save the plot
165 | df = pd.DataFrame(summarized_results)
166 | plot = make_plot(df)
167 | plot_path = save_dir / "speed_vs_mteb_plot.png"
168 | plot.save(plot_path, width=12, height=10)
169 |
170 | logger.info(f"Timings saved to {json_path}")
171 | logger.info(f"Plot saved to {plot_path}")
172 |
173 |
174 | if __name__ == "__main__":
175 | parser = argparse.ArgumentParser(description="Benchmark text embedding models and generate a plot.")
176 | parser.add_argument(
177 | "--save-path", type=str, required=True, help="Directory to save the benchmark results and plot."
178 | )
179 | parser.add_argument(
180 | "--n-texts", type=int, default=100_000, help="Number of texts to use from the dataset for benchmarking."
181 | )
182 | args = parser.parse_args()
183 |
184 | main(save_path=args.save_path, n_texts=args.n_texts)
185 |
--------------------------------------------------------------------------------
/scripts/export_to_onnx.py:
--------------------------------------------------------------------------------
1 | from model2vec.utils import get_package_extras, importable
2 |
3 | # Define the optional dependency group name
4 | _REQUIRED_EXTRA = "onnx"
5 |
6 | # Check if each dependency for the "onnx" group is importable
7 | for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
8 | importable(extra_dependency, _REQUIRED_EXTRA)
9 |
10 | import argparse
11 | import json
12 | import logging
13 | from pathlib import Path
14 |
15 | import torch
16 | from tokenizers import Tokenizer
17 | from transformers import AutoTokenizer, PreTrainedTokenizerFast
18 |
19 | from model2vec import StaticModel
20 |
21 | logging.basicConfig(level=logging.INFO)
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class TorchStaticModel(torch.nn.Module):
26 | def __init__(self, model: StaticModel) -> None:
27 | """Initialize the TorchStaticModel with a StaticModel instance."""
28 | super().__init__()
29 | # Convert NumPy embeddings to a torch.nn.EmbeddingBag
30 | embeddings = torch.from_numpy(model.embedding)
31 | if embeddings.dtype in {torch.int8, torch.uint8}:
32 | embeddings = embeddings.to(torch.float16)
33 | self.embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True)
34 | self.normalize = model.normalize
35 | # Save tokenizer attributes
36 | self.tokenizer = model.tokenizer
37 | self.unk_token_id = model.unk_token_id
38 | self.median_token_length = model.median_token_length
39 |
40 | def forward(self, input_ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
41 | """
42 | Forward pass of the model.
43 |
44 | :param input_ids: The input token ids.
45 | :param offsets: The offsets to compute the mean pooling.
46 | :return: The embeddings.
47 | """
48 | # Perform embedding lookup and mean pooling
49 | embeddings = self.embedding_bag(input_ids, offsets)
50 | # Normalize if required
51 | if self.normalize:
52 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
53 | return embeddings
54 |
55 | def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
56 | """
57 | Tokenize the input sentences.
58 |
59 | :param sentences: The input sentences.
60 | :param max_length: The maximum length of the input_ids.
61 | :return: The input_ids and offsets.
62 | """
63 | # Tokenization logic similar to your StaticModel
64 | if max_length is not None:
65 | m = max_length * self.median_token_length
66 | sentences = [sentence[:m] for sentence in sentences]
67 | encodings = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
68 | encodings_ids = [encoding.ids for encoding in encodings]
69 | if self.unk_token_id is not None:
70 | # Remove unknown tokens
71 | encodings_ids = [
72 | [token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
73 | ]
74 | if max_length is not None:
75 | encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
76 | # Flatten input_ids and compute offsets
77 | offsets = torch.tensor([0] + [len(ids) for ids in encodings_ids[:-1]], dtype=torch.long).cumsum(dim=0)
78 | input_ids = torch.tensor(
79 | [token_id for token_ids in encodings_ids for token_id in token_ids],
80 | dtype=torch.long,
81 | )
82 | return input_ids, offsets
83 |
84 |
85 | def export_model_to_onnx(model_path: str, save_path: Path) -> None:
86 | """
87 | Export the StaticModel to ONNX format and save tokenizer files.
88 |
89 | :param model_path: The path to the pretrained StaticModel.
90 | :param save_path: The directory to save the model and related files.
91 | """
92 | save_path.mkdir(parents=True, exist_ok=True)
93 |
94 | # Load the StaticModel
95 | model = StaticModel.from_pretrained(model_path)
96 | torch_model = TorchStaticModel(model)
97 |
98 | # Save the model using save_pretrained
99 | model.save_pretrained(save_path)
100 |
101 | # Prepare dummy input data
102 | texts = ["hello", "hello world"]
103 | input_ids, offsets = torch_model.tokenize(texts)
104 |
105 | # Export the model to ONNX
106 | onnx_model_path = save_path / "onnx/model.onnx"
107 | onnx_model_path.parent.mkdir(parents=True, exist_ok=True)
108 | torch.onnx.export(
109 | torch_model,
110 | (input_ids, offsets),
111 | str(onnx_model_path),
112 | export_params=True,
113 | opset_version=14,
114 | do_constant_folding=True,
115 | input_names=["input_ids", "offsets"],
116 | output_names=["embeddings"],
117 | dynamic_axes={
118 | "input_ids": {0: "num_tokens"},
119 | "offsets": {0: "batch_size"},
120 | "embeddings": {0: "batch_size"},
121 | },
122 | )
123 |
124 | logger.info(f"Model has been successfully exported to {onnx_model_path}")
125 |
126 | # Save the tokenizer files required for transformers.js
127 | save_tokenizer(model.tokenizer, save_path)
128 | logger.info(f"Tokenizer files have been saved to {save_path}")
129 |
130 |
131 | def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None:
132 | """
133 | Save tokenizer files in a format compatible with Transformers.
134 |
135 | :param tokenizer: The tokenizer from the StaticModel.
136 | :param save_directory: The directory to save the tokenizer files.
137 | :raises FileNotFoundError: If config.json is not found in save_directory.
138 | :raises FileNotFoundError: If tokenizer_config.json is not found in save_directory.
139 | :raises ValueError: If tokenizer_name is not found in config.json.
140 | """
141 | tokenizer_json_path = save_directory / "tokenizer.json"
142 | tokenizer.save(str(tokenizer_json_path))
143 |
144 | # Save vocab.txt
145 | vocab = tokenizer.get_vocab()
146 | vocab_path = save_directory / "vocab.txt"
147 | with open(vocab_path, "w", encoding="utf-8") as vocab_file:
148 | for token in sorted(vocab, key=vocab.get):
149 | vocab_file.write(f"{token}\n")
150 |
151 | # Load config.json to get tokenizer_name
152 | config_path = save_directory / "config.json"
153 | if config_path.exists():
154 | with open(config_path, "r", encoding="utf-8") as f:
155 | config = json.load(f)
156 | else:
157 | raise FileNotFoundError(f"config.json not found in {save_directory}")
158 |
159 | tokenizer_name = config.get("tokenizer_name")
160 | if not tokenizer_name:
161 | raise ValueError("tokenizer_name not found in config.json")
162 |
163 | # Load the original tokenizer
164 | original_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
165 |
166 | # Extract special tokens and tokenizer class
167 | special_tokens = original_tokenizer.special_tokens_map
168 | tokenizer_class = original_tokenizer.__class__.__name__
169 |
170 | # Load the tokenizer using PreTrainedTokenizerFast with special tokens
171 | fast_tokenizer = PreTrainedTokenizerFast(
172 | tokenizer_file=str(tokenizer_json_path),
173 | **special_tokens,
174 | )
175 |
176 | # Save the tokenizer files
177 | fast_tokenizer.save_pretrained(str(save_directory))
178 | # Modify tokenizer_config.json to set the correct tokenizer_class
179 | tokenizer_config_path = save_directory / "tokenizer_config.json"
180 | if tokenizer_config_path.exists():
181 | with open(tokenizer_config_path, "r", encoding="utf-8") as f:
182 | tokenizer_config = json.load(f)
183 | else:
184 | raise FileNotFoundError(f"tokenizer_config.json not found in {save_directory}")
185 |
186 | # Update the tokenizer_class field
187 | tokenizer_config["tokenizer_class"] = tokenizer_class
188 |
189 | # Write the updated tokenizer_config.json back to disk
190 | with open(tokenizer_config_path, "w", encoding="utf-8") as f:
191 | json.dump(tokenizer_config, f, indent=4, sort_keys=True)
192 |
193 |
194 | if __name__ == "__main__":
195 | parser = argparse.ArgumentParser(description="Export StaticModel to ONNX format")
196 | parser.add_argument(
197 | "--model_path",
198 | type=str,
199 | required=True,
200 | help="Path to the pretrained StaticModel",
201 | )
202 | parser.add_argument(
203 | "--save_path",
204 | type=str,
205 | required=True,
206 | help="Directory to save the exported model and files",
207 | )
208 | args = parser.parse_args()
209 |
210 | export_model_to_onnx(args.model_path, Path(args.save_path))
211 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/model2vec/a3c42a0bf33d23bab6a0c3fba5d1b96cf5797a8f/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, cast
4 |
5 | import numpy as np
6 | import pytest
7 | import torch
8 | from tokenizers import Tokenizer
9 | from tokenizers.models import BPE, Unigram, WordPiece
10 | from tokenizers.pre_tokenizers import Whitespace
11 | from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
12 |
13 | from model2vec.inference import StaticModelPipeline
14 | from model2vec.train import StaticModelForClassification
15 |
16 | _TOKENIZER_TYPES = ["wordpiece", "bpe", "unigram"]
17 |
18 |
19 | @pytest.fixture(scope="session", params=_TOKENIZER_TYPES, ids=_TOKENIZER_TYPES)
20 | def mock_tokenizer(request: pytest.FixtureRequest) -> Tokenizer:
21 | """Create a mock tokenizer."""
22 | vocab = ["[PAD]", "word1", "word2", "word3", "[UNK]"]
23 | unk_token = "[UNK]"
24 |
25 | tokenizer_type = request.param
26 |
27 | if tokenizer_type == "wordpiece":
28 | model = WordPiece(
29 | vocab={token: idx for idx, token in enumerate(vocab)}, unk_token=unk_token, max_input_chars_per_word=100
30 | )
31 | elif tokenizer_type == "bpe":
32 | model = BPE(
33 | vocab={token: idx for idx, token in enumerate(vocab)},
34 | merges=[],
35 | unk_token=unk_token,
36 | fuse_unk=True,
37 | ignore_merges=True,
38 | )
39 | elif tokenizer_type == "unigram":
40 | model = Unigram(vocab=[(token, 0.0) for token in vocab], unk_id=0, byte_fallback=False)
41 | else:
42 | raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
43 | tokenizer = Tokenizer(model)
44 | tokenizer.pre_tokenizer = Whitespace() # type: ignore # Tokenizer issue
45 |
46 | return tokenizer
47 |
48 |
49 | @pytest.fixture(scope="function")
50 | def mock_berttokenizer() -> PreTrainedTokenizerFast:
51 | """Load the real BertTokenizerFast from the provided tokenizer.json file."""
52 | return cast(PreTrainedTokenizerFast, AutoTokenizer.from_pretrained("tests/data/test_tokenizer"))
53 |
54 |
55 | @pytest.fixture
56 | def mock_transformer() -> AutoModel:
57 | """Create a mock transformer model."""
58 |
59 | class MockPreTrainedModel:
60 | def __init__(self) -> None:
61 | self.device = "cpu"
62 | self.name_or_path = "mock-model"
63 |
64 | def to(self, device: str) -> MockPreTrainedModel:
65 | self.device = device
66 | return self
67 |
68 | def forward(self, *args: Any, **kwargs: Any) -> Any:
69 | # Simulate a last_hidden_state output for a transformer model
70 | batch_size, seq_length = kwargs["input_ids"].shape
71 | # Return a tensor of shape (batch_size, seq_length, 768)
72 | return type(
73 | "BaseModelOutputWithPoolingAndCrossAttentions",
74 | (object,),
75 | {
76 | "last_hidden_state": torch.rand(batch_size, seq_length, 768) # Simulate 768 hidden units
77 | },
78 | )
79 |
80 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
81 | # Simply call the forward method to simulate the same behavior as transformers models
82 | return self.forward(*args, **kwargs)
83 |
84 | return MockPreTrainedModel()
85 |
86 |
87 | @pytest.fixture(scope="session")
88 | def mock_vectors() -> np.ndarray:
89 | """Create mock vectors."""
90 | return np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.0, 0.0], [0.0, 0.0]])
91 |
92 |
93 | @pytest.fixture
94 | def mock_config() -> dict[str, str]:
95 | """Create a mock config."""
96 | return {"some_config": "value"}
97 |
98 |
99 | @pytest.fixture(scope="session")
100 | def mock_inference_pipeline(mock_trained_pipeline: StaticModelForClassification) -> StaticModelPipeline:
101 | """Mock pipeline."""
102 | return mock_trained_pipeline.to_pipeline()
103 |
104 |
105 | @pytest.fixture(
106 | params=[
107 | (False, "single_label", "str"),
108 | (False, "single_label", "int"),
109 | (True, "multilabel", "str"),
110 | (True, "multilabel", "int"),
111 | ],
112 | ids=lambda param: f"{param[1]}_{param[2]}",
113 | scope="session",
114 | )
115 | def mock_trained_pipeline(request: pytest.FixtureRequest) -> StaticModelForClassification:
116 | """Mock StaticModelForClassification with different label formats."""
117 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
118 | torch.random.manual_seed(42)
119 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
120 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
121 |
122 | X = ["dog", "cat"]
123 | is_multilabel, label_type = request.param[0], request.param[2]
124 |
125 | if label_type == "str":
126 | y = [["a", "b"], ["a"]] if is_multilabel else ["a", "b"] # type: ignore
127 | else:
128 | y = [[0, 1], [0]] if is_multilabel else [0, 1] # type: ignore
129 |
130 | model.fit(X, y)
131 |
132 | return model
133 |
--------------------------------------------------------------------------------
/tests/data/test_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "mask_token": "[MASK]",
4 | "pad_token": "[PAD]",
5 | "sep_token": "[SEP]",
6 | "unk_token": "[UNK]"
7 | }
8 |
--------------------------------------------------------------------------------
/tests/data/test_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "added_tokens_decoder": {
3 | "0": {
4 | "content": "[PAD]",
5 | "lstrip": false,
6 | "normalized": false,
7 | "rstrip": false,
8 | "single_word": false,
9 | "special": true
10 | },
11 | "100": {
12 | "content": "[UNK]",
13 | "lstrip": false,
14 | "normalized": false,
15 | "rstrip": false,
16 | "single_word": false,
17 | "special": true
18 | },
19 | "101": {
20 | "content": "[CLS]",
21 | "lstrip": false,
22 | "normalized": false,
23 | "rstrip": false,
24 | "single_word": false,
25 | "special": true
26 | },
27 | "102": {
28 | "content": "[SEP]",
29 | "lstrip": false,
30 | "normalized": false,
31 | "rstrip": false,
32 | "single_word": false,
33 | "special": true
34 | },
35 | "103": {
36 | "content": "[MASK]",
37 | "lstrip": false,
38 | "normalized": false,
39 | "rstrip": false,
40 | "single_word": false,
41 | "special": true
42 | }
43 | },
44 | "clean_up_tokenization_spaces": true,
45 | "cls_token": "[CLS]",
46 | "do_basic_tokenize": true,
47 | "do_lower_case": true,
48 | "mask_token": "[MASK]",
49 | "model_max_length": 1000000000000000019884624838656,
50 | "never_split": null,
51 | "pad_token": "[PAD]",
52 | "sep_token": "[SEP]",
53 | "strip_accents": null,
54 | "tokenize_chinese_chars": true,
55 | "tokenizer_class": "BertTokenizer",
56 | "unk_token": "[UNK]"
57 | }
58 |
--------------------------------------------------------------------------------
/tests/test_distillation.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from importlib import import_module
5 | from unittest.mock import MagicMock, patch
6 |
7 | import numpy as np
8 | import pytest
9 | from pytest import LogCaptureFixture
10 | from transformers import AutoModel, BertTokenizerFast
11 |
12 | from model2vec.distill.distillation import (
13 | clean_and_create_vocabulary,
14 | distill,
15 | distill_from_model,
16 | post_process_embeddings,
17 | )
18 | from model2vec.model import StaticModel
19 |
20 | try:
21 | # For huggingface_hub>=0.25.0
22 | from huggingface_hub.errors import RepositoryNotFoundError
23 | except ImportError:
24 | # For huggingface_hub<0.25.0
25 | from huggingface_hub.utils._errors import RepositoryNotFoundError # type: ignore
26 |
27 | rng = np.random.default_rng()
28 |
29 |
30 | @pytest.mark.parametrize(
31 | "vocabulary, pca_dims, apply_zipf",
32 | [
33 | (None, 256, True), # Output vocab with subwords, PCA applied
34 | (["wordA", "wordB"], 4, False), # Custom vocab with subword, PCA applied
35 | (None, "auto", False), # Subword, PCA set to 'auto'
36 | (None, 1024, False), # Subword, PCA set to high number.
37 | (None, None, True), # No PCA applied
38 | (None, 0.9, True), # PCA as float applied
39 | ],
40 | )
41 | @patch.object(import_module("model2vec.distill.distillation"), "model_info")
42 | @patch("transformers.AutoModel.from_pretrained")
43 | def test_distill_from_model(
44 | mock_auto_model: MagicMock,
45 | mock_model_info: MagicMock,
46 | mock_berttokenizer: BertTokenizerFast,
47 | mock_transformer: AutoModel,
48 | vocabulary: list[str] | None,
49 | pca_dims: int | None,
50 | apply_zipf: bool,
51 | ) -> None:
52 | """Test distill function with different parameters."""
53 | # Mock the return value of model_info to avoid calling the Hugging Face API
54 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
55 |
56 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
57 | # mock_auto_tokenizer.return_value = mock_berttokenizer
58 | mock_auto_model.return_value = mock_transformer
59 |
60 | # Call the distill function with the parametrized inputs
61 | static_model = distill_from_model(
62 | model=mock_transformer,
63 | tokenizer=mock_berttokenizer,
64 | vocabulary=vocabulary,
65 | device="cpu",
66 | pca_dims=pca_dims,
67 | apply_zipf=apply_zipf,
68 | token_remove_pattern=None,
69 | )
70 |
71 | static_model2 = distill(
72 | model_name="tests/data/test_tokenizer",
73 | vocabulary=vocabulary,
74 | device="cpu",
75 | pca_dims=pca_dims,
76 | apply_zipf=apply_zipf,
77 | token_remove_pattern=None,
78 | )
79 |
80 | assert static_model.embedding.shape == static_model2.embedding.shape
81 | assert static_model.config == static_model2.config
82 | assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str())
83 | assert static_model.base_model_name == static_model2.base_model_name
84 |
85 |
86 | @patch.object(import_module("model2vec.distill.distillation"), "model_info")
87 | @patch("transformers.AutoModel.from_pretrained")
88 | def test_distill_removal_pattern(
89 | mock_auto_model: MagicMock,
90 | mock_model_info: MagicMock,
91 | mock_berttokenizer: BertTokenizerFast,
92 | mock_transformer: AutoModel,
93 | ) -> None:
94 | """Test the removal pattern."""
95 | # Mock the return value of model_info to avoid calling the Hugging Face API
96 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
97 |
98 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
99 | # mock_auto_tokenizer.return_value = mock_berttokenizer
100 | mock_auto_model.return_value = mock_transformer
101 |
102 | # The vocab size is 30522, but we remove 998 tokens: [CLS], [SEP], and [MASK], and all [unused] tokens.
103 | expected_vocab_size = mock_berttokenizer.vocab_size - 998
104 |
105 | static_model = distill_from_model(
106 | model=mock_transformer,
107 | tokenizer=mock_berttokenizer,
108 | vocabulary=None,
109 | device="cpu",
110 | token_remove_pattern=None,
111 | )
112 |
113 | assert len(static_model.embedding) == expected_vocab_size
114 |
115 | # No tokens removed, nonsensical pattern
116 | static_model = distill_from_model(
117 | model=mock_transformer,
118 | tokenizer=mock_berttokenizer,
119 | vocabulary=None,
120 | device="cpu",
121 | token_remove_pattern="£££££££££££££££££",
122 | )
123 |
124 | assert len(static_model.embedding) == expected_vocab_size
125 |
126 | # Weird pattern.
127 | with pytest.raises(ValueError):
128 | static_model = distill_from_model(
129 | model=mock_transformer,
130 | tokenizer=mock_berttokenizer,
131 | vocabulary=None,
132 | device="cpu",
133 | token_remove_pattern="[...papapa",
134 | )
135 |
136 |
137 | @pytest.mark.parametrize(
138 | "vocabulary, pca_dims, apply_zipf, sif_coefficient, expected_shape",
139 | [
140 | (None, 256, True, None, (29524, 256)), # Output vocab with subwords, PCA applied
141 | (None, "auto", False, None, (29524, 768)), # Subword, PCA set to 'auto'
142 | (None, "auto", True, 1e-4, (29524, 768)), # Subword, PCA set to 'auto'
143 | (None, "auto", False, 1e-4, (29524, 768)), # Subword, PCA set to 'auto'
144 | (None, "auto", True, 0, None), # Sif too low
145 | (None, "auto", True, 1, None), # Sif too high
146 | (None, "auto", False, 0, (29524, 768)), # Sif too low, but apply_zipf is False
147 | (None, "auto", False, 1, (29524, 768)), # Sif too high, but apply_zipf is False
148 | (None, 1024, False, None, (29524, 768)), # Subword, PCA set to high number.
149 | (["wordA", "wordB"], 4, False, None, (29526, 4)), # Custom vocab with subword, PCA applied
150 | (None, None, True, None, (29524, 768)), # No PCA applied
151 | ],
152 | )
153 | @patch.object(import_module("model2vec.distill.distillation"), "model_info")
154 | @patch("transformers.AutoModel.from_pretrained")
155 | def test_distill(
156 | mock_auto_model: MagicMock,
157 | mock_model_info: MagicMock,
158 | mock_transformer: AutoModel,
159 | vocabulary: list[str] | None,
160 | pca_dims: int | None,
161 | apply_zipf: bool,
162 | sif_coefficient: float | None,
163 | expected_shape: tuple[int, int],
164 | ) -> None:
165 | """Test distill function with different parameters."""
166 | # Mock the return value of model_info to avoid calling the Hugging Face API
167 | mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
168 |
169 | # Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
170 | mock_auto_model.return_value = mock_transformer
171 |
172 | model_name = "tests/data/test_tokenizer"
173 |
174 | if (
175 | apply_zipf is not None
176 | and apply_zipf
177 | and sif_coefficient is not None
178 | and (sif_coefficient <= 0 or sif_coefficient >= 1)
179 | ):
180 | with pytest.raises(ValueError):
181 | static_model = distill(
182 | model_name=model_name,
183 | vocabulary=vocabulary,
184 | device="cpu",
185 | pca_dims=pca_dims,
186 | apply_zipf=apply_zipf,
187 | sif_coefficient=sif_coefficient,
188 | )
189 |
190 | else:
191 | # Call the distill function with the parametrized inputs
192 | static_model = distill(
193 | model_name=model_name,
194 | vocabulary=vocabulary,
195 | device="cpu",
196 | pca_dims=pca_dims,
197 | apply_zipf=apply_zipf,
198 | sif_coefficient=sif_coefficient,
199 | )
200 |
201 | # Assert the model is correctly generated
202 | assert isinstance(static_model, StaticModel)
203 | assert static_model.embedding.shape == expected_shape
204 | assert "mock-model" in static_model.config["tokenizer_name"]
205 | assert static_model.tokenizer is not None
206 |
207 |
208 | @patch.object(import_module("model2vec.distill.distillation"), "model_info")
209 | def test_missing_modelinfo(
210 | mock_model_info: MagicMock,
211 | mock_transformer: AutoModel,
212 | mock_berttokenizer: BertTokenizerFast,
213 | ) -> None:
214 | """Test that missing model info does not crash."""
215 | mock_model_info.side_effect = RepositoryNotFoundError("Model not found")
216 | static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu")
217 | assert static_model.language is None
218 |
219 |
220 | @pytest.mark.parametrize(
221 | "embeddings, pca_dims, sif_coefficient, expected_shape",
222 | [
223 | (rng.random((1000, 768)), 256, None, (1000, 256)), # PCA applied correctly
224 | (rng.random((1000, 768)), None, None, (1000, 768)), # No PCA applied, dimensions remain unchanged
225 | (rng.random((1000, 768)), 256, 1e-4, (1000, 256)), # PCA and Zipf applied
226 | (rng.random((10, 768)), 256, 1e-4, (10, 768)), # PCA dims higher than vocab size, no PCA applied
227 | ],
228 | )
229 | def test__post_process_embeddings(
230 | embeddings: np.ndarray, pca_dims: int, sif_coefficient: float | None, expected_shape: tuple[int, int]
231 | ) -> None:
232 | """Test the _post_process_embeddings function."""
233 | original_embeddings = embeddings.copy() # Copy embeddings to compare later
234 |
235 | # Test that the function raises an error if the PCA dims are larger than the number of dimensions
236 | if pca_dims and pca_dims > embeddings.shape[1]:
237 | with pytest.raises(ValueError):
238 | post_process_embeddings(embeddings, pca_dims, None)
239 |
240 | processed_embeddings = post_process_embeddings(embeddings, pca_dims, sif_coefficient)
241 |
242 | # Assert the shape is correct
243 | assert processed_embeddings.shape == expected_shape
244 |
245 | # If Zipf weighting is applied compare the original and processed embeddings
246 | # and check the weights are applied correctly
247 | if sif_coefficient and pca_dims is None:
248 | inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
249 | proba = inv_rank / np.sum(inv_rank)
250 | sif_weights = (sif_coefficient / (sif_coefficient + proba))[:, None]
251 |
252 | expected_zipf_embeddings = original_embeddings * sif_weights
253 | assert np.allclose(
254 | processed_embeddings, expected_zipf_embeddings, rtol=1e-5
255 | ), "Zipf weighting not applied correctly"
256 |
257 |
258 | @pytest.mark.parametrize(
259 | "added_tokens, expected_output, expected_warnings",
260 | [
261 | # Case: duplicates ("2010", "government") and an empty token ("")
262 | (["2010", "government", "nerv", ""], ["nerv"], ["Removed", "duplicate", "empty"]),
263 | # Case: No duplicates, no empty tokens
264 | (["worda", "wordb", "wordc"], ["worda", "wordb", "wordc"], []),
265 | # Case: Only empty token (""), should return an empty list
266 | ([""], [], ["Removed", "empty"]),
267 | ],
268 | )
269 | def test_clean_and_create_vocabulary(
270 | mock_berttokenizer: BertTokenizerFast,
271 | added_tokens: list[str],
272 | expected_output: list[str],
273 | expected_warnings: list[str],
274 | caplog: LogCaptureFixture,
275 | ) -> None:
276 | """Test the _clean_vocabulary function."""
277 | with caplog.at_level("WARNING"):
278 | tokens, _ = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None)
279 |
280 | cleaned_vocab = [token.form for token in tokens if not token.is_internal]
281 | # Check the cleaned vocabulary matches the expected output
282 | assert cleaned_vocab == expected_output
283 |
284 | # Check the warnings were logged as expected
285 | logged_warnings = [record.message for record in caplog.records]
286 |
287 | # Ensure the expected warnings contain expected keywords like 'Removed', 'duplicate', or 'empty'
288 | for expected_warning in expected_warnings:
289 | assert any(expected_warning in logged_warning for logged_warning in logged_warnings)
290 |
--------------------------------------------------------------------------------
/tests/test_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from tempfile import TemporaryDirectory
4 | from unittest.mock import patch
5 |
6 | import pytest
7 |
8 | from model2vec.inference import StaticModelPipeline
9 |
10 |
11 | def test_init_predict(mock_inference_pipeline: StaticModelPipeline) -> None:
12 | """Test successful init and predict with StaticModelPipeline."""
13 | target: list[str] | list[list[str]]
14 | if mock_inference_pipeline.multilabel:
15 | if isinstance(mock_inference_pipeline.classes_[0], str):
16 | target = [["a", "b"]]
17 | else:
18 | target = [[0, 1]] # type: ignore
19 | else:
20 | if isinstance(mock_inference_pipeline.classes_[0], str):
21 | target = ["b"]
22 | else:
23 | target = [1] # type: ignore
24 | assert mock_inference_pipeline.predict("dog").tolist() == target
25 | assert mock_inference_pipeline.predict(["dog"]).tolist() == target
26 |
27 |
28 | def test_init_predict_proba(mock_inference_pipeline: StaticModelPipeline) -> None:
29 | """Test successful init and predict_proba with StaticModelPipeline."""
30 | assert mock_inference_pipeline.predict_proba("dog").argmax() == 1
31 | assert mock_inference_pipeline.predict_proba(["dog"]).argmax(1).tolist() == [1]
32 |
33 |
34 | def test_init_evaluate(mock_inference_pipeline: StaticModelPipeline) -> None:
35 | """Test successful init and evaluate with StaticModelPipeline."""
36 | target: list[str] | list[list[str]]
37 | if mock_inference_pipeline.multilabel:
38 | if isinstance(mock_inference_pipeline.classes_[0], str):
39 | target = [["a", "b"]]
40 | else:
41 | target = [[0, 1]] # type: ignore
42 | else:
43 | if isinstance(mock_inference_pipeline.classes_[0], str):
44 | target = ["b"]
45 | else:
46 | target = [1] # type: ignore
47 | mock_inference_pipeline.evaluate("dog", target) # type: ignore
48 |
49 |
50 | def test_roundtrip_save(mock_inference_pipeline: StaticModelPipeline) -> None:
51 | """Test saving and loading the pipeline."""
52 | with TemporaryDirectory() as temp_dir:
53 | mock_inference_pipeline.save_pretrained(temp_dir)
54 | loaded = StaticModelPipeline.from_pretrained(temp_dir)
55 | target: list[str] | list[list[str]]
56 | if mock_inference_pipeline.multilabel:
57 | if isinstance(mock_inference_pipeline.classes_[0], str):
58 | target = [["a", "b"]]
59 | else:
60 | target = [[0, 1]] # type: ignore
61 | else:
62 | if isinstance(mock_inference_pipeline.classes_[0], str):
63 | target = ["b"]
64 | else:
65 | target = [1] # type: ignore
66 | assert loaded.predict("dog").tolist() == target
67 | assert loaded.predict(["dog"]).tolist() == target
68 | assert loaded.predict_proba("dog").argmax() == 1
69 | assert loaded.predict_proba(["dog"]).argmax(1).tolist() == [1]
70 |
71 |
72 | @patch("model2vec.inference.model._DEFAULT_TRUST_PATTERN", re.compile("torch"))
73 | def test_roundtrip_save_mock_trust_pattern(mock_inference_pipeline: StaticModelPipeline) -> None:
74 | """Test saving and loading the pipeline."""
75 | with TemporaryDirectory() as temp_dir:
76 | mock_inference_pipeline.save_pretrained(temp_dir)
77 | with pytest.raises(ValueError):
78 | StaticModelPipeline.from_pretrained(temp_dir)
79 |
80 |
81 | def test_roundtrip_save_file_gone(mock_inference_pipeline: StaticModelPipeline) -> None:
82 | """Test saving and loading the pipeline."""
83 | with TemporaryDirectory() as temp_dir:
84 | mock_inference_pipeline.save_pretrained(temp_dir)
85 | # Rename the file to abc.pipeline, so that it looks like it was downloaded from the hub
86 | os.unlink(os.path.join(temp_dir, "pipeline.skops"))
87 | with pytest.raises(FileNotFoundError):
88 | StaticModelPipeline.from_pretrained(temp_dir)
89 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from tempfile import TemporaryDirectory
3 |
4 | import numpy as np
5 | import pytest
6 | import safetensors
7 | from tokenizers import Tokenizer
8 |
9 | from model2vec import StaticModel
10 |
11 |
12 | def test_initialization(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
13 | """Test successful initialization of StaticModel."""
14 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
15 | assert model.embedding.shape == (5, 2)
16 | assert len(model.tokens) == 5
17 | assert model.tokenizer == mock_tokenizer
18 | assert model.config == mock_config
19 |
20 |
21 | def test_initialization_token_vector_mismatch(mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
22 | """Test if error is raised when number of tokens and vectors don't match."""
23 | mock_vectors = np.array([[0.1, 0.2], [0.2, 0.3]])
24 | with pytest.raises(ValueError):
25 | StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
26 |
27 |
28 | def test_tokenize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
29 | """Test tokenization of a sentence."""
30 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
31 | model._can_encode_fast = True
32 | tokens_fast = model.tokenize(["word1 word2"])
33 | model._can_encode_fast = False
34 | tokens_slow = model.tokenize(["word1 word2"])
35 |
36 | assert tokens_fast == tokens_slow
37 |
38 |
39 | def test_encode_batch_fast(
40 | mock_vectors: np.ndarray, mock_berttokenizer: Tokenizer, mock_config: dict[str, str]
41 | ) -> None:
42 | """Test tokenization of a sentence."""
43 | if hasattr(mock_berttokenizer, "encode_batch_fast"):
44 | del mock_berttokenizer.encode_batch_fast
45 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_berttokenizer, config=mock_config)
46 | assert not model._can_encode_fast
47 |
48 |
49 | def test_encode_single_sentence(
50 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
51 | ) -> None:
52 | """Test encoding of a single sentence."""
53 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
54 | encoded = model.encode("word1 word2")
55 | assert encoded.shape == (2,)
56 |
57 |
58 | def test_encode_single_sentence_empty(
59 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
60 | ) -> None:
61 | """Test encoding of a single empty sentence."""
62 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
63 | model.normalize = True
64 | encoded = model.encode("")
65 | assert not np.isnan(encoded).any()
66 | assert np.all(encoded == 0)
67 |
68 |
69 | def test_encode_multiple_sentences(
70 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
71 | ) -> None:
72 | """Test encoding of multiple sentences."""
73 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
74 | encoded = model.encode(["word1 word2", "word1 word3"])
75 | assert encoded.shape == (2, 2)
76 |
77 |
78 | def test_encode_as_sequence(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
79 | """Test encoding of sentences as tokens."""
80 | sentences = ["word1 word2", "word1 word3"]
81 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
82 | encoded_sequence = model.encode_as_sequence(sentences)
83 | encoded = model.encode(sentences)
84 |
85 | assert len(encoded_sequence) == 2
86 |
87 | means = [np.mean(sequence, axis=0) for sequence in encoded_sequence]
88 | assert np.allclose(means, encoded)
89 |
90 |
91 | def test_encode_multiprocessing(
92 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
93 | ) -> None:
94 | """Test encoding with multiprocessing."""
95 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
96 | # Generate a list of 15k inputs to test multiprocessing
97 | sentences = ["word1 word2"] * 15_000
98 | encoded = model.encode(sentences, use_multiprocessing=True)
99 | assert encoded.shape == (15000, 2)
100 |
101 |
102 | def test_encode_as_sequence_multiprocessing(
103 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
104 | ) -> None:
105 | """Test encoding of sentences as tokens with multiprocessing."""
106 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
107 | # Generate a list of 15k inputs to test multiprocessing
108 | sentences = ["word1 word2"] * 15_000
109 | encoded = model.encode_as_sequence(sentences, use_multiprocessing=True)
110 | assert len(encoded) == 15_000
111 |
112 |
113 | def test_encode_as_tokens_empty(
114 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
115 | ) -> None:
116 | """Test encoding of an empty list of sentences."""
117 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
118 | encoded = model.encode_as_sequence("")
119 | assert np.array_equal(encoded, np.zeros(shape=(0, 2), dtype=model.embedding.dtype))
120 |
121 | encoded = model.encode_as_sequence(["", ""])
122 | out = [np.zeros(shape=(0, 2), dtype=model.embedding.dtype) for _ in range(2)]
123 |     assert all(np.array_equal(x, y) for x, y in zip(encoded, out))
124 |
125 |
126 | def test_encode_empty_sentence(
127 | mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
128 | ) -> None:
129 | """Test encoding with an empty sentence."""
130 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
131 | encoded = model.encode("")
132 | assert np.array_equal(encoded, np.zeros((2,)))
133 |
134 |
135 | def test_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
136 | """Test normalization of vectors."""
137 | s = "word1 word2 word3"
138 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=False)
139 | X = model.encode(s)
140 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=True)
141 | normalized = model.encode(s)
142 |
143 | expected = X / np.linalg.norm(X)
144 |
145 | np.testing.assert_almost_equal(normalized, expected)
146 |
147 |
148 | def test_save_pretrained(
149 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
150 | ) -> None:
151 | """Test saving a pretrained model."""
152 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
153 |
154 | # Save the model to the tmp_path
155 | save_path = tmp_path / "saved_model"
156 | model.save_pretrained(save_path)
157 |
158 | # Check that the save_path directory contains the saved files
159 | assert save_path.exists()
160 |
161 | assert (save_path / "model.safetensors").exists()
162 | assert (save_path / "tokenizer.json").exists()
163 | assert (save_path / "config.json").exists()
164 | assert (save_path / "modules.json").exists()
165 |
166 |
167 | def test_load_pretrained(
168 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
169 | ) -> None:
170 | """Test loading a pretrained model after saving it."""
171 | # Save the model to a temporary path
172 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
173 | save_path = tmp_path / "saved_model"
174 | model.save_pretrained(save_path)
175 |
176 | # Load the model back from the same path
177 | loaded_model = StaticModel.from_pretrained(save_path)
178 |
179 | # Assert that the loaded model has the same properties as the original one
180 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
181 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
182 | assert loaded_model.config == mock_config
183 |
184 |
185 | def test_load_pretrained_quantized(
186 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
187 | ) -> None:
188 | """Test loading a pretrained model after saving it."""
189 | # Save the model to a temporary path
190 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
191 | save_path = tmp_path / "saved_model"
192 | model.save_pretrained(save_path)
193 |
194 | # Load the model back from the same path
195 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="int8")
196 |
197 | # Assert that the loaded model has the same properties as the original one
198 | assert loaded_model.embedding.dtype == np.int8
199 | assert loaded_model.embedding.shape == mock_vectors.shape
200 |
201 | # Load the model back from the same path
202 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16")
203 |
204 | # Assert that the loaded model has the same properties as the original one
205 | assert loaded_model.embedding.dtype == np.float16
206 | assert loaded_model.embedding.shape == mock_vectors.shape
207 |
208 | # Load the model back from the same path
209 | loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")
210 | # Assert that the loaded model has the same properties as the original one
211 | assert loaded_model.embedding.dtype == np.float32
212 | assert loaded_model.embedding.shape == mock_vectors.shape
213 |
214 |
215 | def test_load_pretrained_dim(
216 | tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
217 | ) -> None:
218 | """Test loading a pretrained model with dimensionality."""
219 | # Save the model to a temporary path
220 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
221 | save_path = tmp_path / "saved_model"
222 | model.save_pretrained(save_path)
223 |
224 | loaded_model = StaticModel.from_pretrained(save_path, dimensionality=2)
225 |
226 | # Assert that the loaded model has the same properties as the original one
227 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2])
228 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
229 | assert loaded_model.config == mock_config
230 |
231 | # Load the model back from the same path
232 | loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None)
233 |
234 | # Assert that the loaded model has the same properties as the original one
235 | np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
236 | assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
237 | assert loaded_model.config == mock_config
238 |
239 | # Load the model back from the same path
240 | with pytest.raises(ValueError):
241 | StaticModel.from_pretrained(save_path, dimensionality=3000)
242 |
243 |
244 | def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
245 | """Tests whether the normalization initialization is correct."""
246 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)
247 | assert not model.normalize
248 |
249 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=False)
250 | assert not model.normalize
251 |
252 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
253 | assert model.normalize
254 |
255 | model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": False}, normalize=True)
256 | assert model.normalize
257 |
258 | model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": True}, normalize=False)
259 | assert not model.normalize
260 |
261 |
262 | def test_set_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
263 | """Tests whether the normalize is set correctly."""
264 | model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
265 | model.normalize = False
266 | assert model.config == {"normalize": False}
267 | model.normalize = True
268 | assert model.config == {"normalize": True}
269 |
270 |
271 | def test_dim(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
272 | """Tests the dimensionality of the model."""
273 | model = StaticModel(mock_vectors, mock_tokenizer, mock_config)
274 | assert model.dim == 2
275 | assert model.dim == model.embedding.shape[1]
276 |
277 |
278 | def test_local_load_from_model(mock_tokenizer: Tokenizer) -> None:
279 | """Test local load from a model."""
280 | x = np.ones((mock_tokenizer.get_vocab_size(), 2))
281 | with TemporaryDirectory() as tempdir:
282 | tempdir_path = Path(tempdir)
283 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
284 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
285 |
286 | model = StaticModel.load_local(tempdir_path)
287 | assert model.embedding.shape == x.shape
288 | assert model.tokenizer.to_str() == mock_tokenizer.to_str()
289 | assert model.config == {"normalize": False}
290 |
291 |
292 | def test_local_load_from_model_no_folder() -> None:
293 | """Test local load from a model with no folder."""
294 | with pytest.raises(ValueError):
295 | StaticModel.load_local("woahbuddy_relax_this_is_just_a_test")
296 |
--------------------------------------------------------------------------------
/tests/test_quantization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from model2vec.quantization import DType, quantize_embeddings
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "input_dtype,target_dtype,expected_dtype",
9 | [
10 | (np.float32, DType.Float16, np.float16),
11 | (np.float16, DType.Float32, np.float32),
12 | (np.float32, DType.Float64, np.float64),
13 | (np.float32, DType.Int8, np.int8),
14 | ],
15 | )
16 | def test_quantize_embeddings(input_dtype: DType, target_dtype: DType, expected_dtype: DType) -> None:
17 | """Test quantization to different dtypes."""
18 | embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=input_dtype)
19 | # Use negative values for int8 test case
20 | if target_dtype == DType.Int8:
21 | embeddings = np.array([[-1.0, 2.0], [-3.0, 4.0]], dtype=input_dtype)
22 |
23 | quantized = quantize_embeddings(embeddings, target_dtype)
24 | assert quantized.dtype == expected_dtype
25 |
26 | if target_dtype == DType.Int8:
27 | # Check if the values are in the range [-127, 127]
28 | assert np.all(quantized >= -127) and np.all(quantized <= 127)
29 | else:
30 | assert np.allclose(quantized, embeddings.astype(expected_dtype))
31 |
--------------------------------------------------------------------------------
/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import pytest
4 | from transformers import PreTrainedTokenizerFast
5 |
6 | from model2vec.tokenizer.model import _calculate_token_weight_for_unigram, _process_unigram, process_tokenizer
7 | from model2vec.tokenizer.normalizer import replace_normalizer
8 | from model2vec.tokenizer.pretokenizer import _FORBIDDEN_PRETOKENIZERS, _fix_single_pretokenizer, replace_pretokenizer
9 | from model2vec.tokenizer.tokenizer import _rename_added_token, create_tokenizer
10 |
11 |
12 | def test_fix_single_pretokenizer() -> None:
13 | """Test the _fix_single_pretokenizer function."""
14 | result = _fix_single_pretokenizer({"type": "ByteLevel", "add_prefix_space": False, "use_regex": True})
15 | assert result == {"type": "ByteLevel", "add_prefix_space": True, "use_regex": False}
16 |
17 | for tokenizer_type in _FORBIDDEN_PRETOKENIZERS:
18 | result = _fix_single_pretokenizer({"type": tokenizer_type})
19 | assert result is None
20 |
21 | result = _fix_single_pretokenizer(
22 | {"type": "Metaspace", "split": True, "prepend_scheme": "never", "replacement": "▁"}
23 | )
24 | assert result == {"type": "Metaspace", "replacement": "▁", "prepend_scheme": "always", "split": False}
25 |
26 |
27 | def test_replace_pretokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
28 | """Test the replace_pretokenizer function."""
29 | tokenizer = replace_pretokenizer(mock_berttokenizer.backend_tokenizer)
30 | assert tokenizer.pre_tokenizer is not None
31 | assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
32 | assert tokenizer.pre_tokenizer.replacement == "▁"
33 | assert tokenizer.pre_tokenizer.prepend_scheme == "always"
34 | assert not tokenizer.pre_tokenizer.split
35 |
36 | tokenizer.pre_tokenizer = None # type: ignore
37 | tokenizer = replace_pretokenizer(tokenizer)
38 | assert tokenizer.pre_tokenizer is not None
39 | assert tokenizer.pre_tokenizer.__class__.__name__ == "Metaspace"
40 | assert tokenizer.pre_tokenizer.replacement == "▁"
41 | assert tokenizer.pre_tokenizer.prepend_scheme == "always"
42 | assert tokenizer.pre_tokenizer.split is False
43 |
44 |
45 | def test_replace_normalizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
46 | """Test the replace_normalizer function."""
47 | tokenizer = replace_normalizer(mock_berttokenizer.backend_tokenizer)
48 | assert tokenizer.normalizer is not None
49 | assert tokenizer.normalizer.__class__.__name__ == "Sequence"
50 |
51 | assert tokenizer.normalizer.normalize_str("Hello, World!") == "hello , world !"
52 |
53 | tokenizer.normalizer = None # type: ignore
54 | tokenizer = replace_normalizer(tokenizer)
55 | assert tokenizer.normalizer.normalize_str("Hello, World!") == "Hello , World !"
56 |
57 |
58 | @pytest.mark.parametrize(
59 | "word,weight",
60 | [
61 | ("dog", 3),
62 | ("cat", 3),
63 | ("▁longer▁word", 14),
64 | ("▁word", 6),
65 | ("▁", 2), # Single underscore
66 | ("", 0), # Empty string
67 | ("▁a" * 100, 300), # Long word with underscores
68 | ],
69 | )
70 | def test_calculate_token_weight_for_unigram(word: str, weight: int) -> None:
71 | """Test the _calculate_token_weight_for_unigram function."""
72 | assert _calculate_token_weight_for_unigram(word) == weight
73 |
74 |
75 | def test_process_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
76 | """Test the process_tokenizer function."""
77 | vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
78 | tokenizer_json = json.loads(mock_berttokenizer.backend_tokenizer.to_str())
79 | tokenizer_json = process_tokenizer(tokenizer_json=tokenizer_json, pre_tokenized_tokens=vocab, unk_token="[UNK]")
80 |
81 | assert tokenizer_json["model"]["type"] == "Unigram"
82 | assert tokenizer_json["model"]["unk_id"] == 5 # Index of "[UNK]"
83 | assert len(tokenizer_json["model"]["vocab"]) == 6
84 | assert all(isinstance(token, tuple) and len(token) == 2 for token in tokenizer_json["model"]["vocab"])
85 | for (x, _), y in zip(tokenizer_json["model"]["vocab"], vocab):
86 | assert x == y, f"Expected {y}, but got {x}"
87 |
88 |
89 | def test_process_unigram() -> None:
90 | """Test the _process_unigram function."""
91 | vocab = ["dog", "cat", "longer_word", "word", "a" * 100, "[UNK]"]
92 | orig_vocab = [("dog", 0), ("cat", 0)]
93 | model = {"model": {"type": "Unigram", "vocab": orig_vocab}}
94 | processed_model = _process_unigram(model, vocab, "[UNK]")
95 | assert processed_model["model"]["type"] == "Unigram"
96 | assert processed_model["model"]["unk_id"] == 5 # Index of "[UNK]"
97 | assert len(processed_model["model"]["vocab"]) == 6
98 | assert all(isinstance(token, list) and len(token) == 2 for token in processed_model["model"]["vocab"])
99 |
100 | for (x, score), y in zip(processed_model["model"]["vocab"], vocab):
101 | assert x == y, f"Expected {y}, but got {x}"
102 | if x in orig_vocab:
103 | assert score == 0
104 |
105 | assert process_tokenizer(model, vocab, "[UNK]") == processed_model
106 |
107 |
108 | def test_rename_added_token() -> None:
109 | """Test the _rename_added_token function."""
110 | # Invalid input
111 | result = _rename_added_token(None, "a", [{"content": "a", "id": 0}], ["a"])
112 | assert result == [{"content": "a", "id": 0}]
113 |
114 | # Rename 'a' to 'c'
115 | result = _rename_added_token("a", "c", [{"content": "a"}], ["a"])
116 | assert result == [{"content": "c", "id": 0}]
117 |
118 |
119 | def test_create_tokenizer(mock_berttokenizer: PreTrainedTokenizerFast) -> None:
120 | """Test the create_tokenizer function."""
121 | tokenizer = create_tokenizer(tokenizer=mock_berttokenizer, vocabulary=["dog", "catssssss"], token_remove_regex=None)
122 | assert tokenizer.backend_tokenizer.get_vocab_size() == 29525
123 | assert tokenizer.encode("catssssss") == [29524]
124 |
--------------------------------------------------------------------------------
/tests/test_trainable.py:
--------------------------------------------------------------------------------
1 | from tempfile import TemporaryDirectory
2 |
3 | import numpy as np
4 | import pytest
5 | import torch
6 | from tokenizers import Tokenizer
7 | from transformers import AutoTokenizer
8 |
9 | from model2vec.model import StaticModel
10 | from model2vec.train import StaticModelForClassification
11 | from model2vec.train.base import FinetunableStaticModel, TextDataset
12 |
13 |
14 | @pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
15 | def test_init_predict(n_layers: int, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
16 | """Test successful initialization of StaticModelForClassification."""
17 | vectors_torched = torch.from_numpy(mock_vectors)
18 | s = StaticModelForClassification(vectors=vectors_torched, tokenizer=mock_tokenizer, n_layers=n_layers)
19 | assert s.vectors.shape == mock_vectors.shape
20 | assert s.w.shape[0] == mock_vectors.shape[0]
21 | assert list(s.classes) == s.classes_
22 | assert list(s.classes) == ["0", "1"]
23 |
24 | head = s.construct_head()
25 | assert head[0].in_features == mock_vectors.shape[1]
26 | head = s.construct_head()
27 | assert head[0].in_features == mock_vectors.shape[1]
28 | assert head[-1].out_features == 2
29 |
30 |
31 | def test_init_base_class(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
32 | """Test successful initialization of the base class."""
33 | vectors_torched = torch.from_numpy(mock_vectors)
34 | s = FinetunableStaticModel(vectors=vectors_torched, tokenizer=mock_tokenizer)
35 | assert s.vectors.shape == mock_vectors.shape
36 | assert s.w.shape[0] == mock_vectors.shape[0]
37 |
38 | head = s.construct_head()
39 | assert head[0].in_features == mock_vectors.shape[1]
40 |
41 |
42 | def test_init_base_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
43 | """Test initializion from a static model."""
44 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer)
45 | s = FinetunableStaticModel.from_static_model(model=model)
46 | assert s.vectors.shape == mock_vectors.shape
47 | assert s.w.shape[0] == mock_vectors.shape[0]
48 |
49 | with TemporaryDirectory() as temp_dir:
50 | model.save_pretrained(temp_dir)
51 | s = FinetunableStaticModel.from_pretrained(model_name=temp_dir)
52 | assert s.vectors.shape == mock_vectors.shape
53 | assert s.w.shape[0] == mock_vectors.shape[0]
54 |
55 |
56 | def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
57 | """Test initializion from a static model."""
58 | model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer)
59 | s = StaticModelForClassification.from_static_model(model=model)
60 | assert s.vectors.shape == mock_vectors.shape
61 | assert s.w.shape[0] == mock_vectors.shape[0]
62 |
63 | with TemporaryDirectory() as temp_dir:
64 | model.save_pretrained(temp_dir)
65 | s = StaticModelForClassification.from_pretrained(model_name=temp_dir)
66 | assert s.vectors.shape == mock_vectors.shape
67 | assert s.w.shape[0] == mock_vectors.shape[0]
68 |
69 |
70 | def test_encode(mock_trained_pipeline: StaticModelForClassification) -> None:
71 | """Test the encode function."""
72 | result = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long())
73 | assert result.shape == (2, 12)
74 | assert torch.allclose(result[0], result[1])
75 |
76 |
77 | def test_tokenize(mock_trained_pipeline: StaticModelForClassification) -> None:
78 | """Test the encode function."""
79 | result = mock_trained_pipeline.tokenize(["dog dog", "cat"])
80 | assert result.shape == torch.Size([2, 2])
81 | assert result[1, 1] == 0
82 |
83 |
84 | def test_device(mock_trained_pipeline: StaticModelForClassification) -> None:
85 | """Get the device."""
86 | assert mock_trained_pipeline.device == torch.device(type="cpu") # type: ignore # False positive
87 | assert mock_trained_pipeline.device == mock_trained_pipeline.w.device
88 |
89 |
90 | def test_conversion(mock_trained_pipeline: StaticModelForClassification) -> None:
91 | """Test the conversion to numpy."""
92 | staticmodel = mock_trained_pipeline.to_static_model()
93 | with torch.no_grad():
94 | result_1 = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long()).numpy()
95 | result_2 = staticmodel.embedding[[[0, 1], [1, 0]]].mean(0)
96 | result_2 /= np.linalg.norm(result_2, axis=1, keepdims=True)
97 |
98 | assert np.allclose(result_1, result_2)
99 |
100 |
101 | def test_textdataset_init() -> None:
102 | """Test the textdataset init."""
103 | dataset = TextDataset([[0], [1]], torch.arange(2))
104 | assert len(dataset) == 2
105 |
106 |
107 | def test_textdataset_init_incorrect() -> None:
108 | """Test the textdataset init."""
109 | with pytest.raises(ValueError):
110 | TextDataset([[0]], torch.arange(2))
111 |
112 |
113 | def test_predict(mock_trained_pipeline: StaticModelForClassification) -> None:
114 | """Test the predict function."""
115 | result = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist()
116 | if mock_trained_pipeline.multilabel:
117 | if type(mock_trained_pipeline.classes_[0]) == str:
118 | assert result == [["a", "b"], ["a", "b"]]
119 | else:
120 | assert result == [[0, 1], [0, 1]]
121 | else:
122 | if type(mock_trained_pipeline.classes_[0]) == str:
123 | assert result == ["b", "b"]
124 | else:
125 | assert result == [1, 1]
126 |
127 |
128 | def test_predict_proba(mock_trained_pipeline: StaticModelForClassification) -> None:
129 | """Test the predict function."""
130 | result = mock_trained_pipeline.predict_proba(["dog cat", "dog"])
131 | assert result.shape == (2, 2)
132 |
133 |
134 | def test_convert_to_pipeline(mock_trained_pipeline: StaticModelForClassification) -> None:
135 | """Convert a model to a pipeline."""
136 | mock_trained_pipeline.eval()
137 | pipeline = mock_trained_pipeline.to_pipeline()
138 | encoded_pipeline = pipeline.model.encode(["dog cat", "dog"])
139 | encoded_model = mock_trained_pipeline(mock_trained_pipeline.tokenize(["dog cat", "dog"]))[1].detach().numpy()
140 | assert np.allclose(encoded_pipeline, encoded_model)
141 | a = pipeline.predict(["dog cat", "dog"]).tolist()
142 | b = mock_trained_pipeline.predict(["dog cat", "dog"]).tolist()
143 | assert a == b
144 | p1 = pipeline.predict_proba(["dog cat", "dog"])
145 | p2 = mock_trained_pipeline.predict_proba(["dog cat", "dog"])
146 | assert np.allclose(p1, p2)
147 |
148 |
149 | def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -> None:
150 | """Test the train test split function."""
151 | a, b, c, d = mock_trained_pipeline._train_test_split(["0", "1", "2", "3"], ["1", "1", "0", "0"], 0.5)
152 | assert len(a) == 2
153 | assert len(b) == 2
154 | assert len(c) == len(a)
155 | assert len(d) == len(b)
156 |
157 |
158 | def test_y_val_none() -> None:
159 | """Test the y_val function."""
160 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
161 | torch.random.manual_seed(42)
162 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
163 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
164 |
165 | X = ["dog", "cat"]
166 | y = ["0", "1"]
167 |
168 | X_val = ["dog", "cat"]
169 | y_val = ["0", "1"]
170 |
171 | with pytest.raises(ValueError):
172 | model.fit(X, y, X_val=X_val, y_val=None)
173 | with pytest.raises(ValueError):
174 | model.fit(X, y, X_val=None, y_val=y_val)
175 | model.fit(X, y, X_val=None, y_val=None)
176 |
177 |
178 | @pytest.mark.parametrize(
179 | "y_multi,y_val_multi,should_crash",
180 | [[True, True, False], [False, False, False], [True, False, True], [False, True, True]],
181 | )
182 | def test_y_val(y_multi: bool, y_val_multi: bool, should_crash: bool) -> None:
183 | """Test the y_val function."""
184 | tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
185 | torch.random.manual_seed(42)
186 | vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
187 | model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
188 |
189 | X = ["dog", "cat"]
190 | y = [["0", "1"], ["0"]] if y_multi else ["0", "1"] # type: ignore
191 |
192 | X_val = ["dog", "cat"]
193 | y_val = [["0", "1"], ["0"]] if y_val_multi else ["0", "1"] # type: ignore
194 |
195 | if should_crash:
196 | with pytest.raises(ValueError):
197 | model.fit(X, y, X_val=X_val, y_val=y_val)
198 | else:
199 | model.fit(X, y, X_val=X_val, y_val=y_val)
200 |
201 |
202 | def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
203 | """Test the evaluate function."""
204 | if mock_trained_pipeline.multilabel:
205 | if type(mock_trained_pipeline.classes_[0]) == str:
206 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [["a", "b"], ["a"]])
207 | else:
208 | # Ignore the type error since we don't support int labels in our typing, but the code does
209 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [[0, 1], [0]]) # type: ignore
210 | else:
211 | if type(mock_trained_pipeline.classes_[0]) == str:
212 | mock_trained_pipeline.evaluate(["dog cat", "dog"], ["a", "a"])
213 | else:
214 | # Ignore the type error since we don't support int labels in our typing, but the code does
215 | mock_trained_pipeline.evaluate(["dog cat", "dog"], [1, 1]) # type: ignore
216 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from pathlib import Path
5 | from tempfile import NamedTemporaryFile, TemporaryDirectory
6 | from typing import Any
7 | from unittest.mock import patch
8 |
9 | import numpy as np
10 | import pytest
11 | import safetensors
12 | import safetensors.numpy
13 | from tokenizers import Tokenizer
14 |
15 | from model2vec.distill.utils import select_optimal_device
16 | from model2vec.hf_utils import _get_metadata_from_readme
17 | from model2vec.utils import get_package_extras, importable, load_local_model
18 |
19 |
20 | def test__get_metadata_from_readme_not_exists() -> None:
21 | """Test getting metadata from a README."""
22 | assert _get_metadata_from_readme(Path("zzz")) == {}
23 |
24 |
25 | def test__get_metadata_from_readme_mocked_file() -> None:
26 | """Test getting metadata from a README."""
27 | with NamedTemporaryFile() as f:
28 | f.write(b"---\nkey: value\n---\n")
29 | f.flush()
30 | assert _get_metadata_from_readme(Path(f.name))["key"] == "value"
31 |
32 |
33 | def test__get_metadata_from_readme_mocked_file_keys() -> None:
34 | """Test getting metadata from a README."""
35 | with NamedTemporaryFile() as f:
36 | f.write(b"")
37 | f.flush()
38 | assert set(_get_metadata_from_readme(Path(f.name))) == set()
39 |
40 |
41 | @pytest.mark.parametrize(
42 | "device, expected, cuda, mps",
43 | [
44 | ("cpu", "cpu", True, True),
45 | ("cpu", "cpu", True, False),
46 | ("cpu", "cpu", False, True),
47 | ("cpu", "cpu", False, False),
48 | ("clown", "clown", False, False),
49 | (None, "cuda", True, True),
50 | (None, "cuda", True, False),
51 | (None, "mps", False, True),
52 | (None, "cpu", False, False),
53 | ],
54 | )
55 | def test_select_optimal_device(device: str | None, expected: str, cuda: bool, mps: bool) -> None:
56 | """Test whether the optimal device is selected."""
57 | with (
58 | patch("torch.cuda.is_available", return_value=cuda),
59 | patch("torch.backends.mps.is_available", return_value=mps),
60 | ):
61 | assert select_optimal_device(device) == expected
62 |
63 |
64 | def test_importable() -> None:
65 | """Test the importable function."""
66 | with pytest.raises(ImportError):
67 | importable("clown", "clown")
68 |
69 | importable("os", "clown")
70 |
71 |
72 | def test_get_package_extras() -> None:
73 | """Test package extras."""
74 | extras = set(get_package_extras("model2vec", "distill"))
75 | assert extras == {"torch", "transformers", "scikit-learn"}
76 |
77 |
78 | def test_get_package_extras_empty() -> None:
79 | """Test package extras with an empty package."""
80 | assert not list(get_package_extras("tqdm", ""))
81 |
82 |
83 | @pytest.mark.parametrize(
84 | "config, expected",
85 | [
86 | ({"dog": "cat"}, {"dog": "cat"}),
87 | ({}, {}),
88 | (None, {}),
89 | ],
90 | )
91 | def test_local_load(mock_tokenizer: Tokenizer, config: dict[str, Any], expected: dict[str, Any]) -> None:
92 | """Test local loading."""
93 | x = np.ones((mock_tokenizer.get_vocab_size(), 2))
94 |
95 | with TemporaryDirectory() as tempdir:
96 | tempdir_path = Path(tempdir)
97 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
98 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
99 | if config is not None:
100 | json.dump(config, open(tempdir_path / "config.json", "w"))
101 | arr, tokenizer, config = load_local_model(tempdir_path)
102 | assert config == expected
103 | assert tokenizer.to_str() == mock_tokenizer.to_str()
104 | assert arr.shape == x.shape
105 |
106 |
107 | def test_local_load_mismatch(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None:
108 | """Test local loading."""
109 | x = np.ones((10, 2))
110 |
111 | with TemporaryDirectory() as tempdir:
112 | tempdir_path = Path(tempdir)
113 | safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
114 | mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
115 |
116 | load_local_model(tempdir_path)
117 | expected = (
118 | f"Number of tokens does not match number of embeddings: `{len(mock_tokenizer.get_vocab())}` vs `{len(x)}`"
119 | )
120 | assert len(caplog.records) == 1
121 | assert caplog.records[0].message == expected
122 |
--------------------------------------------------------------------------------
/tutorials/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | # Tutorials
9 |
10 | This is a list of all our tutorials. They are all self-contained ipython notebooks.
11 |
12 | | | what? | Link |
13 | |--------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------|
14 | | **Recipe search** 🍝 | Learn how to do lightning-fast semantic search by distilling a small model. Compare a really tiny model to a larger one with a better vocabulary. Learn what Fattoush is (delicious). | [Open in Colab](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/recipe_search.ipynb) |
15 | | **Semantic chunking** 🧩 | Learn how to chunk your text into meaningful segments with [Chonkie](https://github.com/chonkie-inc/chonkie) at lightning speed. Efficiently query your chunks with [Vicinity](https://github.com/MinishLab/vicinity). | [Open in Colab](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/semantic_chunking.ipynb) |
16 | | **Training a classifier** 🧩 | Learn how to train a classifier using model2vec. Lightning fast, with great performance, especially on small datasets. | [Open in Colab](https://colab.research.google.com/github/minishlab/model2vec/blob/master/tutorials/train_classifier.ipynb) |
17 |
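18 | If you just want a quick feel for the Model2Vec API before opening a notebook, the sketch below loads a pre-trained potion model and encodes a couple of sentences. The model name and example sentences are only illustrative; any Model2Vec model from the Hub works the same way.
19 | 
20 | ```python
21 | # Minimal sketch: load a pre-trained static model and embed some text.
22 | from model2vec import StaticModel
23 | 
24 | model = StaticModel.from_pretrained("minishlab/potion-base-8M")
25 | embeddings = model.encode(["lightning-fast semantic search", "training a classifier"])
26 | print(embeddings.shape)  # (2, embedding_dim)
27 | ```
28 | 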
--------------------------------------------------------------------------------
/tutorials/semantic_chunking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "**Semantic Chunking with Chonkie and Model2Vec**\n",
8 | "\n",
9 | "Semantic chunking is a task of identifying the semantic boundaries of a piece of text. In this tutorial, we will use the [Chonkie](https://github.com/bhavnicksm/chonkie) library to perform semantic chunking on the book War and Peace. Chonkie is a library that provides a lightweight and fast solution to semantic chunking using pre-trained models. It supports our [potion models](https://huggingface.co/collections/minishlab/potion-6721e0abd4ea41881417f062) out of the box, which we will be using in this tutorial.\n",
10 | "\n",
11 | "After chunking our text, we will be using [Vicinity](https://github.com/MinishLab/vicinity), a lightweight nearest neighbors library, to create an index of our chunks and query them."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "# Install the necessary libraries\n",
21 | "!pip install -q datasets model2vec numpy tqdm vicinity \"chonkie[semantic]\""
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 1,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# Import the necessary libraries\n",
31 | "import random \n",
32 | "import re\n",
33 | "import requests\n",
34 | "from time import perf_counter\n",
35 | "from chonkie import SDPMChunker\n",
36 | "from model2vec import StaticModel\n",
37 | "from vicinity import Vicinity\n",
38 | "\n",
39 | "random.seed(0)"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "**Loading and pre-processing**\n",
47 | "\n",
48 | "First, we will download War and Peace and apply some basic pre-processing."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# URL for War and Peace on Project Gutenberg\n",
58 | "url = \"https://www.gutenberg.org/files/2600/2600-0.txt\"\n",
59 | "\n",
60 | "# Download the book\n",
61 | "response = requests.get(url)\n",
62 | "book_text = response.text\n",
63 | "\n",
64 | "def preprocess_text(text: str, min_length: int = 5):\n",
65 | " \"\"\"Basic text preprocessing function.\"\"\"\n",
66 | " text = text.replace(\"\\n\", \" \")\n",
67 | " text = text.replace(\"\\r\", \" \")\n",
68 | " sentences = re.findall(r'[^.!?]*[.!?]', text)\n",
69 | " # Filter out sentences shorter than the specified minimum length\n",
70 | " filtered_sentences = [sentence.strip() for sentence in sentences if len(sentence.split()) >= min_length]\n",
71 | " # Recombine the filtered sentences\n",
72 | " return ' '.join(filtered_sentences)\n",
73 | "\n",
74 | "# Preprocess the text\n",
75 | "book_text = preprocess_text(book_text)"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "**Chunking with Chonkie**\n",
83 | "\n",
84 | "Next, we will use Chonkie to chunk our text into semantic chunks."
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 22,
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "Number of chunks: 4436\n",
97 | "Time taken: 1.6311538339941762\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 |     "# Initialize an SDPMChunker from Chonkie with the potion-base-32M model\n",
103 | "chunker = SDPMChunker(\n",
104 | " embedding_model=\"minishlab/potion-base-32M\",\n",
105 | " chunk_size = 512, \n",
106 | " skip_window=5, \n",
107 | " min_sentences=3\n",
108 | ")\n",
109 | "\n",
110 | "# Chunk the text\n",
111 | "time = perf_counter()\n",
112 | "chunks = chunker.chunk(book_text)\n",
113 | "print(f\"Number of chunks: {len(chunks)}\")\n",
114 | "print(f\"Time taken: {perf_counter() - time}\")"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "And that's it, we chunked the entirety of War and Peace in ~2 seconds. Not bad! Let's look at some example chunks."
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 23,
127 | "metadata": {},
128 | "outputs": [
129 | {
130 | "name": "stdout",
131 | "output_type": "stream",
132 | "text": [
133 | " Wait and we shall see! As if fighting were fun. They are like children from whom one can’t get any sensible account of what has happened because they all want to show how well they can fight. But that’s not what is needed now. “And what ingenious maneuvers they all propose to me! \n",
134 | "\n",
135 | " The first thing he saw on riding up to the space where Túshin’s guns were stationed was an unharnessed horse with a broken leg, that lay screaming piteously beside the harnessed horses. Blood was gushing from its leg as from a spring. Among the limbers lay several dead men. \n",
136 | "\n",
137 | " Out of an army of a hundred thousand we must expect at least twenty thousand wounded, and we haven’t stretchers, or bunks, or dressers, or doctors enough for six thousand. We have ten thousand carts, but we need other things as well—we must manage as best we can! ” The strange thought that of the thousands of men, young and old, who had stared with merry surprise at his hat (perhaps the very men he had noticed), twenty thousand were inevitably doomed to wounds and death amazed Pierre. “They may die tomorrow; why are they thinking of anything but death? ” And by some latent sequence of thought the descent of the Mozháysk hill, the carts with the wounded, the ringing bells, the slanting rays of the sun, and the songs of the cavalrymen vividly recurred to his mind. “The cavalry ride to battle and meet the wounded and do not for a moment think of what awaits them, but pass by, winking at the wounded. Yet from among these men twenty thousand are doomed to die, and they wonder at my hat! ” thought Pierre, continuing his way to Tatárinova. \n",
138 | "\n"
139 | ]
140 | }
141 | ],
142 | "source": [
143 | "# Print a few example chunks\n",
144 | "for _ in range(3):\n",
145 | " chunk = random.choice(chunks)\n",
146 | " print(chunk.text, \"\\n\")"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "Those look good. Next, let's create a vector search index with Vicinity and Model2Vec.\n",
154 | "\n",
155 | "**Creating a vector search index**"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 24,
161 | "metadata": {},
162 | "outputs": [
163 | {
164 | "name": "stdout",
165 | "output_type": "stream",
166 | "text": [
167 | "Time taken: 2.269912125004339\n"
168 | ]
169 | }
170 | ],
171 | "source": [
172 | "# Initialize an embedding model and encode the chunk texts\n",
173 | "time = perf_counter()\n",
174 | "model = StaticModel.from_pretrained(\"minishlab/potion-base-32M\")\n",
175 | "chunk_texts = [chunk.text for chunk in chunks]\n",
176 | "chunk_embeddings = model.encode(chunk_texts)\n",
177 | "\n",
178 | "# Create a Vicinity instance\n",
179 | "vicinity = Vicinity.from_vectors_and_items(vectors=chunk_embeddings, items=chunk_texts)\n",
180 | "print(f\"Time taken: {perf_counter() - time}\")"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 |     "Done! We embedded all our chunks and created an index in ~2 seconds. Now that we have our index, let's query it with a few queries.\n",
188 | "\n",
189 | "**Querying the index**"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 25,
195 | "metadata": {},
196 | "outputs": [
197 | {
198 | "name": "stdout",
199 | "output_type": "stream",
200 | "text": [
201 | "Query: Emperor Napoleon\n",
202 | "--------------------------------------------------\n",
203 | " In 1808 the Emperor Alexander went to Erfurt for a fresh interview with the Emperor Napoleon, and in the upper circles of Petersburg there was much talk of the grandeur of this important meeting. CHAPTER XXII In 1809 the intimacy between “the world’s two arbiters,” as Napoleon and Alexander were called, was such that when Napoleon declared war on Austria a Russian corps crossed the frontier to co-operate with our old enemy Bonaparte against our old ally the Emperor of Austria, and in court circles the possibility of marriage between Napoleon and one of Alexander’s sisters was spoken of. But besides considerations of foreign policy, the attention of Russian society was at that time keenly directed on the internal changes that were being undertaken in all the departments of government. Life meanwhile—real life, with its essential interests of health and sickness, toil and rest, and its intellectual interests in thought, science, poetry, music, love, friendship, hatred, and passions—went on as usual, independently of and apart from political friendship or enmity with Napoleon Bonaparte and from all the schemes of reconstruction. BOOK SIX: 1808 - 10 CHAPTER I Prince Andrew had spent two years continuously in the country. All the plans Pierre had attempted on his estates—and constantly changing from one thing to another had never accomplished—were carried out by Prince Andrew without display and without perceptible difficulty. \n",
204 | "\n",
205 | " CHAPTER XXVI On August 25, the eve of the battle of Borodinó, M. de Beausset, prefect of the French Emperor’s palace, arrived at Napoleon’s quarters at Valúevo with Colonel Fabvier, the former from Paris and the latter from Madrid. Donning his court uniform, M. de Beausset ordered a box he had brought for the Emperor to be carried before him and entered the first compartment of Napoleon’s tent, where he began opening the box while conversing with Napoleon’s aides-de-camp who surrounded him. Fabvier, not entering the tent, remained at the entrance talking to some generals of his acquaintance. The Emperor Napoleon had not yet left his bedroom and was finishing his toilet. \n",
206 | "\n",
207 | " In Russia there was an Emperor, Alexander, who decided to restore order in Europe and therefore fought against Napoleon. In 1807 he suddenly made friends with him, but in 1811 they again quarreled and again began killing many people. Napoleon led six hundred thousand men into Russia and captured Moscow; then he suddenly ran away from Moscow, and the Emperor Alexander, helped by the advice of Stein and others, united Europe to arm against the disturber of its peace. All Napoleon’s allies suddenly became his enemies and their forces advanced against the fresh forces he raised. The Allies defeated Napoleon, entered Paris, forced Napoleon to abdicate, and sent him to the island of Elba, not depriving him of the title of Emperor and showing him every respect, though five years before and one year later they all regarded him as an outlaw and a brigand. Then Louis XVIII, who till then had been the laughingstock both of the French and the Allies, began to reign. And Napoleon, shedding tears before his Old Guards, renounced the throne and went into exile. \n",
208 | "\n",
209 | "Query: The battle of Austerlitz\n",
210 | "--------------------------------------------------\n",
211 | " Behave as you did at Austerlitz, Friedland, Vítebsk, and Smolénsk. Let our remotest posterity recall your achievements this day with pride. Let it be said of each of you: “He was in the great battle before Moscow! \n",
212 | "\n",
213 | " By a strange coincidence, this task, which turned out to be a most difficult and important one, was entrusted to Dokhtúrov—that same modest little Dokhtúrov whom no one had described to us as drawing up plans of battles, dashing about in front of regiments, showering crosses on batteries, and so on, and who was thought to be and was spoken of as undecided and undiscerning—but whom we find commanding wherever the position was most difficult all through the Russo-French wars from Austerlitz to the year 1813. At Austerlitz he remained last at the Augezd dam, rallying the regiments, saving what was possible when all were flying and perishing and not a single general was left in the rear guard. Ill with fever he went to Smolénsk with twenty thousand men to defend the town against Napoleon’s whole army. \n",
214 | "\n",
215 | " “Nothing is truer or sadder. These gentlemen ride onto the bridge alone and wave white handkerchiefs; they assure the officer on duty that they, the marshals, are on their way to negotiate with Prince Auersperg. He lets them enter the tête-de-pont. * They spin him a thousand gasconades, saying that the war is over, that the Emperor Francis is arranging a meeting with Bonaparte, that they desire to see Prince Auersperg, and so on. The officer sends for Auersperg; these gentlemen embrace the officers, crack jokes, sit on the cannon, and meanwhile a French battalion gets to the bridge unobserved, flings the bags of incendiary material into the water, and approaches the tête-de-pont. At length appears the lieutenant general, our dear Prince Auersperg von Mautern himself. Flower of the Austrian army, hero of the Turkish wars! Hostilities are ended, we can shake one another’s hand. The Emperor Napoleon burns with impatience to make Prince Auersperg’s acquaintance. \n",
216 | "\n",
217 | "Query: Paris\n",
218 | "--------------------------------------------------\n",
219 | " Paris is Talma, la Duchénois, Potier, the Sorbonne, the boulevards,” and noticing that his conclusion was weaker than what had gone before, he added quickly: “There is only one Paris in the world. You have been to Paris and have remained Russian. Well, I don’t esteem you the less for it. \n",
220 | "\n",
221 | " Look at our youths, look at our ladies! The French are our Gods: Paris is our Kingdom of Heaven. ” He began speaking louder, evidently to be heard by everyone. “French dresses, French ideas, French feelings! \n",
222 | "\n",
223 | " “Oh yes, one sees that plainly. A man who doesn’t know Paris is a savage. You can tell a Parisian two leagues off. \n",
224 | "\n"
225 | ]
226 | }
227 | ],
228 | "source": [
229 | "queries = [\"Emperor Napoleon\", \"The battle of Austerlitz\", \"Paris\"]\n",
230 | "for query in queries:\n",
231 | " print(f\"Query: {query}\\n{'-' * 50}\")\n",
232 | " query_embedding = model.encode(query)\n",
233 | " results = vicinity.query(query_embedding, k=3)[0]\n",
234 | "\n",
235 | " for result in results:\n",
236 | " print(result[0], \"\\n\")"
237 | ]
238 | },
239 | {
240 | "cell_type": "markdown",
241 | "metadata": {},
242 | "source": [
243 | "These indeed look like relevant chunks, nice! That's it for this tutorial. We were able to chunk, index, and query War and Peace in about 3.5 seconds using Chonkie, Vicinity, and Model2Vec. Lightweight and fast, just how we like it."
244 | ]
245 | }
246 | ],
247 | "metadata": {
248 | "kernelspec": {
249 | "display_name": "Python 3 (ipykernel)",
250 | "language": "python",
251 | "name": "python3"
252 | },
253 | "language_info": {
254 | "codemirror_mode": {
255 | "name": "ipython",
256 | "version": 3
257 | },
258 | "file_extension": ".py",
259 | "mimetype": "text/x-python",
260 | "name": "python",
261 | "nbconvert_exporter": "python",
262 | "pygments_lexer": "ipython3",
263 | "version": "3.10.15"
264 | }
265 | },
266 | "nbformat": 4,
267 | "nbformat_minor": 4
268 | }
269 |
--------------------------------------------------------------------------------