├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yaml │ ├── feature-request.yml │ └── model-request.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── ci.yml │ ├── python-publish.yml │ ├── python-tests.yml │ └── type-checkers.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── RELEASE.md ├── docs ├── Getting Started.ipynb ├── assets │ └── favicon.png ├── examples │ ├── ColBERT_with_FastEmbed.ipynb │ ├── FastEmbed_GPU.ipynb │ ├── FastEmbed_Multi_GPU.ipynb │ ├── FastEmbed_vs_HF_Comparison.ipynb │ ├── Hindi_Tamil_RAG_with_Navarasa7B.ipynb │ ├── Hybrid_Search.ipynb │ ├── Image_Embedding.ipynb │ ├── SPLADE_with_FastEmbed.ipynb │ └── Supported_Models.ipynb ├── experimental │ ├── Accuracy_vs_SamplingRate.png │ └── Binary Quantization from Scratch.ipynb ├── index.md ├── overrides │ └── main.html └── qdrant │ ├── Binary_Quantization_with_Qdrant.ipynb │ ├── Retrieval_with_FastEmbed.ipynb │ └── Usage_With_Qdrant.ipynb ├── experiments ├── 01_ONNX_Port.ipynb ├── 02_SPLADE_to_ONNX.ipynb ├── Example. Convert Resnet50 to ONNX.ipynb ├── Throughput_Across_Models.ipynb ├── attention_export.py └── try_attention_export.py ├── fastembed ├── __init__.py ├── common │ ├── __init__.py │ ├── model_description.py │ ├── model_management.py │ ├── onnx_model.py │ ├── preprocessor_utils.py │ ├── types.py │ └── utils.py ├── embedding.py ├── image │ ├── __init__.py │ ├── image_embedding.py │ ├── image_embedding_base.py │ ├── onnx_embedding.py │ ├── onnx_image_model.py │ └── transform │ │ ├── functional.py │ │ └── operators.py ├── late_interaction │ ├── __init__.py │ ├── colbert.py │ ├── jina_colbert.py │ ├── late_interaction_embedding_base.py │ ├── late_interaction_text_embedding.py │ └── token_embeddings.py ├── late_interaction_multimodal │ ├── __init__.py │ ├── colpali.py │ ├── late_interaction_multimodal_embedding.py │ ├── late_interaction_multimodal_embedding_base.py │ └── onnx_multimodal_model.py ├── parallel_processor.py ├── py.typed ├── rerank │ └── cross_encoder │ │ ├── __init__.py │ │ ├── custom_text_cross_encoder.py │ │ ├── onnx_text_cross_encoder.py │ │ ├── onnx_text_model.py │ │ ├── text_cross_encoder.py │ │ └── text_cross_encoder_base.py ├── sparse │ ├── __init__.py │ ├── bm25.py │ ├── bm42.py │ ├── minicoil.py │ ├── sparse_embedding_base.py │ ├── sparse_text_embedding.py │ ├── splade_pp.py │ └── utils │ │ ├── minicoil_encoder.py │ │ ├── sparse_vectors_converter.py │ │ ├── tokenizer.py │ │ └── vocab_resolver.py └── text │ ├── __init__.py │ ├── clip_embedding.py │ ├── custom_text_embedding.py │ ├── multitask_embedding.py │ ├── onnx_embedding.py │ ├── onnx_text_model.py │ ├── pooled_embedding.py │ ├── pooled_normalized_embedding.py │ ├── text_embedding.py │ └── text_embedding_base.py ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── config.py ├── misc ├── image.jpeg └── small_image.jpeg ├── profiling.py ├── test_attention_embeddings.py ├── test_common.py ├── test_custom_models.py ├── test_image_onnx_embeddings.py ├── test_late_interaction_embeddings.py ├── test_late_interaction_multimodal.py ├── test_multi_gpu.py ├── test_sparse_embeddings.py ├── test_text_cross_encoder.py ├── test_text_multitask_embeddings.py ├── test_text_onnx_embeddings.py ├── type_stub.py └── utils.py /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: Bug 2 | description: File a bug report 3 | title: "[Bug]: " 4 | body: 5 | - type: markdown 6 | attributes: 7 | 
value: | 8 | Thanks for taking the time to fill out this bug report! 9 | - type: textarea 10 | id: what-happened 11 | attributes: 12 | label: What happened? 13 | description: Describe the error you encountered. 14 | placeholder: 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: expected 19 | attributes: 20 | label: What is the expected behaviour? 21 | description: Describe the way you expected the code to behave. 22 | placeholder: 23 | - type: textarea 24 | id: code-snippet 25 | attributes: 26 | label: A minimal reproducible example 27 | description: It would really help us to fix the problem if you could provide a code snippet that reproduces the issue. 28 | placeholder: 29 | - type: textarea 30 | id: python-version 31 | attributes: 32 | label: What Python version are you on? e.g. python --version 33 | description: Also tell us, what package manager are you using e.g. conda, pip, poetry? 34 | placeholder: Python3.10 35 | validations: 36 | required: true 37 | - type: textarea 38 | id: version 39 | attributes: 40 | label: FastEmbed version 41 | description: What version of FastEmbed are you running? python -c "import fastembed; print(fastembed.__version__)". If you're not on the latest, please upgrade and see if the problem persists. 42 | placeholder: v0.5.1 43 | validations: 44 | required: true 45 | - type: dropdown 46 | id: os 47 | attributes: 48 | label: What os are you seeing the problem on? 49 | multiple: true 50 | options: 51 | - Linux 52 | - MacOS 53 | - Windows 54 | - type: textarea 55 | id: logs 56 | attributes: 57 | label: Relevant stack traces and/or logs 58 | description: Please copy and paste any relevant raised exceptions. This will be automatically formatted into code, so no need for backticks. 59 | render: shell 60 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yaml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: GitHub Community Support 4 | url: https://github.com/qdrant/fastembed/discussions 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: Feature 2 | description: New functionality request 3 | title: "[Feature]: " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this report! 9 | - type: textarea 10 | id: feature-description 11 | attributes: 12 | label: What feature would you like to request? 13 | description: Please provide the description of the feature you would like to request. 14 | placeholder: 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: additional-info 19 | attributes: 20 | label: Is there any additional information you would like to provide? 21 | description: Please provide any additional information that you think might be useful. 22 | placeholder: 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/model-request.yml: -------------------------------------------------------------------------------- 1 | name: Model 2 | description: Request a new model 3 | title: "[Model]: " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this report! 
9 | - type: textarea 10 | id: model-name 11 | attributes: 12 | label: Which model would you like to support? 13 | description: Please provide the name of the model you would like to see supported. 14 | placeholder: Link to the model (e.g. on HuggingFace) 15 | validations: 16 | required: true 17 | - type: textarea 18 | id: motivation 19 | attributes: 20 | label: What are the main advantages of this model? 21 | description: Please describe the main advantages of this model compared to the existing ones and provide links to benchmarks if there are any. 22 | placeholder: 23 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### All Submissions: 2 | 3 | * [ ] Have you followed the guidelines in our Contributing document? 4 | * [ ] Have you checked to ensure there aren't other open [Pull Requests](../../../pulls) for the same update/change? 5 | 6 | 7 | 8 | ### New Feature Submissions: 9 | 10 | * [ ] Does your submission pass the existing tests? 11 | * [ ] Have you added tests for your feature? 12 | * [ ] Have you installed `pre-commit` with `pip3 install pre-commit` and set up hooks with `pre-commit install`? 13 | 14 | ### New Model Submissions: 15 | 16 | * [ ] Have you added an explanation of why it's important to include this model? 17 | * [ ] Have you added tests for the new model? Were canonical values for tests computed via the original model? 18 | * [ ] Have you added the code snippet for how canonical values were computed? 19 | * [ ] Have you successfully run tests with your changes locally? 20 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | permissions: 8 | contents: write 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.x 17 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 18 | - uses: actions/cache@v3 19 | with: 20 | key: mkdocs-material-${{ env.cache_id }} 21 | path: .cache 22 | restore-keys: | 23 | mkdocs-material- 24 | - run: pip install mkdocs-material mkdocstrings==0.27.0 pillow cairosvg mknotebooks 25 | - run: mkdocs gh-deploy --force 26 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | workflow_dispatch: 13 | push: 14 | # Pattern matched against refs/tags 15 | tags: 16 | - 'v*' # Push events to every version tag 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: '3.9.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install poetry 32 | poetry install 33 | - name: Build package 34 | run: poetry build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | pull_request: 5 | branches: [ master, main, gpu ] 6 | workflow_dispatch: 7 | 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | test: 14 | 15 | strategy: 16 | matrix: 17 | python-version: 18 | - '3.9.x' 19 | - '3.10.x' 20 | - '3.11.x' 21 | - '3.12.x' 22 | - '3.13.x' 23 | os: 24 | - ubuntu-latest 25 | - macos-latest 26 | - windows-latest 27 | 28 | runs-on: ${{ matrix.os }} 29 | 30 | name: Python ${{ matrix.python-version }} on ${{ matrix.os }} test 31 | 32 | steps: 33 | - uses: actions/checkout@v3 34 | - name: Set up Python 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install poetry 41 | poetry config virtualenvs.create false 42 | poetry install --no-interaction --no-ansi --without dev,docs 43 | 44 | - name: Run pytest 45 | env: 46 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 47 | run: | 48 | poetry run pytest -------------------------------------------------------------------------------- /.github/workflows/type-checkers.yml: -------------------------------------------------------------------------------- 1 | name: type-checkers 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | fail-fast: true 10 | matrix: 11 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 12 | os: [ubuntu-latest] 13 | 14 | name: Python ${{ matrix.python-version }} test 15 | 16 | steps: 17 | - uses: actions/checkout@v1 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip poetry 27 | poetry install --no-interaction --no-ansi --without dev,docs,test 28 | 29 | - name: mypy 30 | run: | 31 | poetry run mypy fastembed \ 32 | --disallow-incomplete-defs \ 33 | --disallow-untyped-defs \ 34 | --disable-error-code=import-untyped 35 | 36 | - name: pyright 37 | run: | 38 | poetry run pyright tests/type_stub.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 
| *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | .python-version 89 | 90 | .pdm.toml 91 | 92 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 93 | __pypackages__/ 94 | 95 | # Celery stuff 96 | celerybeat-schedule 97 | celerybeat.pid 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # pytype static type analyzer 130 | .pytype/ 131 | 132 | # Cython debug symbols 133 | cython_debug/ 134 | 135 | .idea/ 136 | .DS_Store 137 | *.tar.gz 138 | **/local_cache/ 139 | docs/experimental/*.parquet 140 | docs/experimental/*.bin 141 | qdrant_storage/* 142 | experiments/models/* 143 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.3.4 4 | hooks: 5 | - id: ruff 6 | types_or: [ python, pyi, jupyter ] 7 | args: [ --fix ] 8 | - id: ruff-format 9 | types_or: [ python, pyi, jupyter ] 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to FastEmbed! 2 | 3 | :+1::tada: First off, thanks for taking the time to contribute! :tada::+1: 4 | 5 | The following is a set of guidelines for contributing to FastEmbed. These are mostly guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. 
6 | 7 | ## Table Of Contents 8 | 9 | [I don't want to read this whole thing, I just have a question!!!](#i-dont-want-to-read-this-whole-thing-i-just-have-a-question) 10 | 11 | [How Can I Contribute?](#how-can-i-contribute) 12 | * [Your First Code Contribution](#your-first-code-contribution) 13 | * [Adding New Models](#adding-new-models) 14 | 15 | [Styleguides](#styleguides) 16 | * [Code Lint](#code-lint) 17 | * [Pre-Commit Hooks](#pre-commit-hooks) 18 | 19 | ## I don't want to read this whole thing, I just have a question!!! 20 | 21 | > **Note:** Please don't file an issue to ask a question. You'll get faster results by using the resources below: 22 | 23 | * [FastEmbed Docs](https://qdrant.github.io/fastembed/) 24 | * [Qdrant Discord](https://discord.gg/Qy6HCJK9Dc) 25 | 26 | ## How Can I Contribute? 27 | 28 | ## How Do I Submit A (Good) Bug Report? 29 | 30 | Bugs are tracked as [GitHub issues](https://guides.github.com/features/issues/). 31 | 32 | Explain the problem and include additional details to help maintainers reproduce the problem: 33 | 34 | * **Use a clear and descriptive title** for the issue to identify the problem. 35 | * **Describe the exact steps which reproduce the problem** in as much detail as possible. For example, start by explaining how you are using FastEmbed, e.g. with Langchain, Qdrant Client, or Llama Index, and exactly which command you used. When listing steps, **don't just say what you did, but explain how you did it**. 36 | * **Provide specific examples to demonstrate the steps**. Include links to files or GitHub projects, or copy/pasteable snippets, which you use in those examples. If you're providing snippets in the issue, use [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines). 37 | * **Describe the behavior you observed after following the steps** and point out what exactly is the problem with that behavior. 38 | * **Explain which behavior you expected to see instead and why.** 39 | * **If the problem is related to performance or memory**, include a [call stack profile capture](https://github.com/joerick/pyinstrument) and your observations. 40 | 41 | Include details about your configuration and environment: 42 | 43 | * **Which version of FastEmbed are you using?** You can get the exact version by running `python -c "import fastembed; print(fastembed.__version__)"`. 44 | * **What's the name and version of the OS you're using?** 45 | * **Which packages do you have installed?** You can get that list by running `pip freeze`. 46 | 47 | ### Your First Code Contribution 48 | 49 | Unsure where to begin contributing to FastEmbed? You can start by looking through these `good-first-issue` issues: 50 | 51 | * [Good First Issue](https://github.com/qdrant/fastembed/labels/good%20first%20issue) - issues which should only require a few lines of code, and a test or two. These are a great way to get started with FastEmbed. This includes adding new models which are already tested and ready on the Hugging Face Hub. 52 | 53 | ## Pull Requests 54 | 55 | The best way to learn about the mechanics of FastEmbed is to start working on it. 56 | 57 | ### Your First Code Contribution 58 | Your first code contribution can be small bug fixes: 59 | 1. This PR adds a small bug fix for a single input: https://github.com/qdrant/fastembed/pull/148 60 | 2. This PR adds a check for the right file location and extension, specific to an OS: https://github.com/qdrant/fastembed/pull/128 61 | 62 | Even documentation improvements and tests are most welcome: 63 | 1.
This PR fixes a README link: https://github.com/qdrant/fastembed/pull/143 64 | 65 | ### Adding New Models 66 | 1. Open Requests for New Models are [here](https://github.com/qdrant/fastembed/labels/model%20request). 67 | 2. There are quite a few pull requests that were merged for this purpose, and you can use them as a reference. Here is an example: https://github.com/qdrant/fastembed/pull/129 68 | 3. Make sure to add tests for the new model: 69 | - The CANONICAL_VECTOR values must come from a reference implementation, usually Hugging Face Transformers or Sentence Transformers. 70 | - Here is a reference [Colab Notebook](https://colab.research.google.com/drive/1tNdV3DsiwsJzu2AXnUnoeF5av1Hp8HF1?usp=sharing) for how we will evaluate whether the VECTOR values in your test are correct or not. 71 | 72 | ## Styleguides 73 | 74 | ### Code Lint 75 | We use ruff for code linting. It should be installed with poetry since it's a dev dependency. 76 | 77 | ### Pre-Commit Hooks 78 | We use pre-commit hooks to ensure that the code is linted before it's committed. You can install pre-commit hooks by running `pre-commit install` in the root directory of the project. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Qdrant 2 | 3 | This product includes software developed by Qdrant 4 | 5 | This distribution includes the following Jina AI models, each with its respective license: 6 | - jinaai/jina-colbert-v2 7 | - License: cc-by-nc-4.0 8 | - jinaai/jina-reranker-v2-base-multilingual 9 | - License: cc-by-nc-4.0 10 | - jinaai/jina-embeddings-v3 11 | - License: cc-by-nc-4.0 12 | 13 | These models are developed by Jina (https://jina.ai/) and are subject to Jina AI's licensing terms. 14 | 15 | This distribution includes the following Google models, each with its respective license: 16 | - vidore/colpali-v1.3 17 | - License: gemma 18 | 19 | Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms 20 | 21 | Additional Notes: 22 | This project also includes third-party libraries with their respective licenses. Please refer to the documentation of each library for details regarding its usage and licensing terms. 23 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Releasing FastEmbed 2 | 3 | This is a guide on how to release the `fastembed` and `fastembed-gpu` packages. 4 | 5 | ## How to 6 | 7 | 1. Accumulate changes in the `main` branch. 8 | 2. Bump the version in `pyproject.toml`. 9 | 10 | 3. Rebase the `gpu` branch on `main` and resolve conflicts if any occur: 11 | 12 | ```bash 13 | git checkout gpu 14 | git rebase main 15 | git push -f origin gpu 16 | ``` 17 | 18 | 4. Draft release notes. 19 | 5. Check out `main` and create a tag, e.g.: 20 | 21 | ```bash 22 | git checkout main 23 | git tag -a v0.1.0 -m "Release v0.1.0" 24 | ``` 25 | 26 | 6. Check out `gpu` and create a tag, e.g.: 27 | 28 | ```bash 29 | git checkout gpu 30 | git tag -a v0.1.0-gpu -m "Release v0.1.0" 31 | ``` 32 | 33 | 7. Push tags: 34 | 35 | ```bash 36 | git push --tags 37 | ``` 38 | 39 | 8. Verify that both packages have been published successfully on PyPI. Try installing them and verify imports (see the sketch after this list). 40 | 9. Create a release on GitHub with the written release notes.
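A minimal sketch for the verification in step 8, assuming the release tag was `v0.1.0` — substitute the version you actually released. It reuses the same `fastembed.__version__` check that the bug-report template asks for; run each install in a clean environment:

```bash
# CPU package: install from PyPI and confirm the version that gets imported
pip install "fastembed==0.1.0"
python -c "import fastembed; print(fastembed.__version__)"

# GPU package: same check, in a separate clean environment
pip install "fastembed-gpu==0.1.0"
python -c "import fastembed; print(fastembed.__version__)"
```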
41 | 42 | -------------------------------------------------------------------------------- /docs/assets/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/fastembed/a260022ae6059f4e7568cbd57cd6191cdaab8f33/docs/assets/favicon.png -------------------------------------------------------------------------------- /docs/examples/FastEmbed_Multi_GPU.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Fastembed Multi-GPU Tutorial\n", 8 | "This tutorial demonstrates how to leverage multi-GPU support in Fastembed. Fastembed supports embedding text and images utilizing modern GPUs for acceleration. Let's explore how to use Fastembed with multiple GPUs step by step." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "#### Prerequisites\n", 16 | "To get started, ensure you have the following installed:\n", 17 | "- Python 3.9 or later\n", 18 | "- Fastembed (`pip install fastembed-gpu`)\n", 19 | "- Refer to [this](https://github.com/qdrant/fastembed/blob/main/docs/examples/FastEmbed_GPU.ipynb) tutorial if you have issues with GPU dependencies\n", 20 | "- Access to a multi-GPU server" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "### Multi-GPU using cuda argument with TextEmbedding Model" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from fastembed import TextEmbedding\n", 37 | "\n", 38 | "# define the documents to embed\n", 39 | "docs = [\"hello world\", \"flag embedding\"] * 100\n", 40 | "\n", 41 | "# define gpu ids\n", 42 | "device_ids = [0, 1]\n", 43 | "\n", 44 | "if __name__ == \"__main__\":\n", 45 | " # initialize a TextEmbedding model using CUDA\n", 46 | " text_model = TextEmbedding(\n", 47 | " model_name=\"sentence-transformers/all-MiniLM-L6-v2\",\n", 48 | " cuda=True,\n", 49 | " device_ids=device_ids,\n", 50 | " lazy_load=True,\n", 51 | " )\n", 52 | "\n", 53 | " # generate embeddings\n", 54 | " text_embeddings = list(text_model.embed(docs, batch_size=2, parallel=len(device_ids)))\n", 55 | " print(text_embeddings)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "In this snippet:\n", 63 | "- `cuda=True` enables GPU acceleration.\n", 64 | "- `device_ids=[0, 1]` specifies GPUs to use. Replace `[0, 1]` with available GPU IDs.\n", 65 | "- `lazy_load=True`\n", 66 | "\n", 67 | "**NOTE**: When using multi-GPU settings, it is important to configure `parallel` and `lazy_load` properly to avoid inefficiencies:\n", 68 | "\n", 69 | "`parallel`: This parameter enables multi-GPU support by spawning child processes for each GPU specified in device_ids. To ensure proper utilization, the value of `parallel` must match the number of GPUs in device_ids. If using a single GPU, this parameter is not necessary.\n", 70 | "\n", 71 | "`lazy_load`: Enabling `lazy_load` prevents redundant memory usage. Without `lazy_load`, the model is initially loaded into the memory of the first GPU by the main process. When child processes are spawned for each GPU, the model is reloaded on the first GPU, causing redundant memory consumption and inefficiencies." 
72 | ] 73 | } 74 | ], 75 | "metadata": { 76 | "kernelspec": { 77 | "display_name": ".venv", 78 | "language": "python", 79 | "name": "python3" 80 | }, 81 | "language_info": { 82 | "name": "python", 83 | "version": "3.10.15" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 2 88 | } 89 | -------------------------------------------------------------------------------- /docs/examples/Image_Embedding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "aa0a86859809102", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "source": [ 10 | "# Image Embedding\n", 11 | "As of version 0.3.0 fastembed supports computation of image embeddings.\n", 12 | "\n", 13 | "The process is as easy and straightforward as with text embeddings. Let's see how it works." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 5, 19 | "id": "cea8fd5c019571fe", 20 | "metadata": { 21 | "ExecuteTime": { 22 | "end_time": "2024-06-02T11:35:40.126023Z", 23 | "start_time": "2024-06-02T11:35:39.864701Z" 24 | }, 25 | "collapsed": false 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 47482.69it/s]\n" 33 | ] 34 | }, 35 | { 36 | "data": { 37 | "text/plain": "[array([0. , 0. , 0. , ..., 0. , 0.01139933,\n 0. ], dtype=float32),\n array([0.02169187, 0. , 0. , ..., 0. , 0.00848291,\n 0. ], dtype=float32)]" 38 | }, 39 | "execution_count": 5, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | } 43 | ], 44 | "source": [ 45 | "from fastembed import ImageEmbedding\n", 46 | "\n", 47 | "model = ImageEmbedding(\"Qdrant/resnet50-onnx\")\n", 48 | "\n", 49 | "embeddings_generator = model.embed(\n", 50 | " [\"../../tests/misc/image.jpeg\", \"../../tests/misc/small_image.jpeg\"]\n", 51 | ")\n", 52 | "embeddings_list = list(embeddings_generator)\n", 53 | "embeddings_list" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "3f838f18523ad1e0", 59 | "metadata": { 60 | "collapsed": false 61 | }, 62 | "source": [ 63 | "## Preprocessing\n", 64 | "\n", 65 | "Preprocessing is encapsulated in the ImageEmbedding class, applied operations are identical to the ones provided by [Hugging Face Transformers](https://huggingface.co/docs/transformers/en/index).\n", 66 | "You don't need to think about batching, opening/closing files, resizing images, etc., Fastembed will take care of it." 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "894b33ff9b385d72", 72 | "metadata": { 73 | "collapsed": false 74 | }, 75 | "source": [ 76 | "## Supported models\n", 77 | "\n", 78 | "List of supported image embedding models can either be found [here](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-image-embedding-models) or by calling the `ImageEmbedding.list_supported_models()` method." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "id": "6d6a4cbbd2200d14", 85 | "metadata": { 86 | "ExecuteTime": { 87 | "end_time": "2024-06-02T11:40:19.313226Z", 88 | "start_time": "2024-06-02T11:40:19.309845Z" 89 | }, 90 | "collapsed": false 91 | }, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": "[{'model': 'Qdrant/clip-ViT-B-32-vision',\n 'dim': 512,\n 'description': 'CLIP vision encoder based on ViT-B/32',\n 'size_in_GB': 0.34,\n 'sources': {'hf': 'Qdrant/clip-ViT-B-32-vision'},\n 'model_file': 'model.onnx'},\n {'model': 'Qdrant/resnet50-onnx',\n 'dim': 2048,\n 'description': 'ResNet-50 from `Deep Residual Learning for Image Recognition `__.',\n 'size_in_GB': 0.1,\n 'sources': {'hf': 'Qdrant/resnet50-onnx'},\n 'model_file': 'model.onnx'}]" 96 | }, 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "ImageEmbedding.list_supported_models()" 104 | ] 105 | } 106 | ], 107 | "metadata": { 108 | "kernelspec": { 109 | "display_name": "Python 3", 110 | "language": "python", 111 | "name": "python3" 112 | }, 113 | "language_info": { 114 | "codemirror_mode": { 115 | "name": "ipython", 116 | "version": 2 117 | }, 118 | "file_extension": ".py", 119 | "mimetype": "text/x-python", 120 | "name": "python", 121 | "nbconvert_exporter": "python", 122 | "pygments_lexer": "ipython2", 123 | "version": "2.7.6" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 5 128 | } 129 | -------------------------------------------------------------------------------- /docs/experimental/Accuracy_vs_SamplingRate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/fastembed/a260022ae6059f4e7568cbd57cd6191cdaab8f33/docs/experimental/Accuracy_vs_SamplingRate.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # ⚡️ What is FastEmbed? 2 | 3 | FastEmbed is a lightweight, fast, Python library built for embedding generation. We [support popular text models](https://qdrant.github.io/fastembed/examples/Supported_Models/). Please [open a Github issue](https://github.com/qdrant/fastembed/issues/new) if you want us to add a new model. 4 | 5 | 1. Light & Fast 6 | - Quantized model weights 7 | - ONNX Runtime for inference 8 | 9 | 2. Accuracy/Recall 10 | - Better than OpenAI Ada-002 11 | - Default is Flag Embedding, which has shown good results on the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard 12 | - List of [supported models](https://qdrant.github.io/fastembed/examples/Supported_Models/) - including multilingual models 13 | 14 | Here is an example for [Retrieval Embedding Generation](https://qdrant.github.io/fastembed/examples/Retrieval%20with%20FastEmbed/) and how to use [FastEmbed with Qdrant](https://qdrant.github.io/fastembed/examples/Usage_With_Qdrant/). 15 | 16 | ## 🚀 Installation 17 | 18 | To install the FastEmbed library, pip works: 19 | 20 | ```bash 21 | pip install fastembed 22 | ``` 23 | 24 | ## 📖 Usage 25 | 26 | ```python 27 | from fastembed import TextEmbedding 28 | 29 | documents: list[str] = [ 30 | "passage: Hello, World!", 31 | "query: Hello, World!", 32 | "passage: This is an example passage.", 33 | "fastembed is supported by and maintained by Qdrant." 
34 | ] 35 | embedding_model = TextEmbedding() 36 | embeddings: list[np.ndarray] = list(embedding_model.embed(documents)) # embed() returns a generator; materialize it into a list 37 | ``` 38 | 39 | ## Usage with Qdrant 40 | 41 | Installation with Qdrant Client in Python: 42 | 43 | ```bash 44 | pip install qdrant-client[fastembed] 45 | ``` 46 | 47 | Might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh. 48 | 49 | ```python 50 | from qdrant_client import QdrantClient 51 | 52 | # Initialize the client 53 | client = QdrantClient(":memory:") # Using an in-process Qdrant 54 | 55 | # Prepare your documents, metadata, and IDs 56 | docs = ["Qdrant has Langchain integrations", "Qdrant also has Llama Index integrations"] 57 | metadata = [ 58 | {"source": "Langchain-docs"}, 59 | {"source": "Llama-index-docs"}, 60 | ] 61 | ids = [42, 2] 62 | 63 | client.add( 64 | collection_name="demo_collection", 65 | documents=docs, 66 | metadata=metadata, 67 | ids=ids 68 | ) 69 | 70 | search_result = client.query( 71 | collection_name="demo_collection", 72 | query_text="This is a query document" 73 | ) 74 | print(search_result) 75 | ``` 76 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | {% if page.nb_url %} 5 | 8 | {% endif %} 9 | 10 | {{ super() }} 11 | 12 | {% endblock content %} 13 | 14 | {% block announce %} 15 |
16 | If you're using FastEmbed from Qdrant, join the 17 | 18 | 19 | {% include ".icons/fontawesome/brands/discord.svg" %} 20 | 21 | Qdrant Discord server 22 | 23 | to get help and share your work! Or check out Qdrant Cloud to 25 | get started with vector search! 26 |
27 | {% endblock %} 28 | -------------------------------------------------------------------------------- /docs/qdrant/Retrieval_with_FastEmbed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ⚓️ Retrieval with FastEmbed\n", 8 | "\n", 9 | "This notebook demonstrates how to use FastEmbed to perform vector search and retrieval. It consists of the following sections:\n", 10 | "\n", 11 | "1. Setup: Installing the necessary packages.\n", 12 | "2. Importing Libraries: Importing FastEmbed and other libraries.\n", 13 | "3. Data Preparation: Example data and embedding generation.\n", 14 | "4. Querying: Defining a function to search documents based on a query.\n", 15 | "5. Running Queries: Running example queries.\n", 16 | "\n", 17 | "## Setup\n", 18 | "\n", 19 | "First, we need to install the dependencies. `fastembed` to create embeddings and perform retrieval." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# !pip install fastembed --quiet --upgrade" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "Importing the necessary libraries:" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import numpy as np\n", 45 | "from fastembed import TextEmbedding" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Data Preparation\n", 53 | "We initialize the embedding model and generate embeddings for the documents.\n", 54 | "\n", 55 | "### 💡 Tip: Prefer using `query_embed` for queries and `passage_embed` for documents." 
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "(384,) 10\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# Example list of documents\n", 73 | "documents: list[str] = [\n", 74 | " \"Maharana Pratap was a Rajput warrior king from Mewar\",\n", 75 | " \"He fought against the Mughal Empire led by Akbar\",\n", 76 | " \"The Battle of Haldighati in 1576 was his most famous battle\",\n", 77 | " \"He refused to submit to Akbar and continued guerrilla warfare\",\n", 78 | " \"His capital was Chittorgarh, which he lost to the Mughals\",\n", 79 | " \"He died in 1597 at the age of 57\",\n", 80 | " \"Maharana Pratap is considered a symbol of Rajput resistance against foreign rule\",\n", 81 | " \"His legacy is celebrated in Rajasthan through festivals and monuments\",\n", 82 | " \"He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar\",\n", 83 | " \"His life has been depicted in various films, TV shows, and books\",\n", 84 | "]\n", 85 | "# Initialize the DefaultEmbedding class with the desired parameters\n", 86 | "embedding_model = TextEmbedding(model_name=\"BAAI/bge-small-en\")\n", 87 | "\n", 88 | "# We'll use the passage_embed method to get the embeddings for the documents\n", 89 | "embeddings: list[np.ndarray] = list(\n", 90 | " embedding_model.passage_embed(documents)\n", 91 | ") # notice that we are casting the generator to a list\n", 92 | "\n", 93 | "print(embeddings[0].shape, len(embeddings))" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## Querying\n", 101 | "\n", 102 | "We'll define a function to print the top k documents based on a query, and prepare a sample query." 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 4, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "query = \"Who was Maharana Pratap?\"\n", 112 | "query_embedding = list(embedding_model.query_embed(query))[0]\n", 113 | "plain_query_embedding = list(embedding_model.embed(query))[0]\n", 114 | "\n", 115 | "\n", 116 | "def print_top_k(query_embedding, embeddings, documents, k=5):\n", 117 | " # use numpy to calculate the cosine similarity between the query and the documents\n", 118 | " scores = np.dot(embeddings, query_embedding)\n", 119 | " # sort the scores in descending order\n", 120 | " sorted_scores = np.argsort(scores)[::-1]\n", 121 | " # print the top 5\n", 122 | " for i in range(k):\n", 123 | " print(f\"Rank {i+1}: {documents[sorted_scores[i]]}\")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/plain": [ 134 | "(array([-0.06002192, 0.04322132, -0.00545516, -0.04419701, -0.00542277],\n", 135 | " dtype=float32),\n", 136 | " array([-0.06002192, 0.04322132, -0.00545516, -0.04419701, -0.00542277],\n", 137 | " dtype=float32))" 138 | ] 139 | }, 140 | "execution_count": 5, 141 | "metadata": {}, 142 | "output_type": "execute_result" 143 | } 144 | ], 145 | "source": [ 146 | "query_embedding[:5], plain_query_embedding[:5]" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "The `query_embed` is specifically designed for queries, leading to more relevant and context-aware results. 
The retrieved documents tend to align closely with the query's intent.\n", 154 | "\n", 155 | "In contrast, `embed` is a more general-purpose representation that might not capture the nuances of the query as effectively. The retrieved documents using plain embeddings might be less relevant or ordered differently compared to the results obtained using query embeddings.\n", 156 | "\n", 157 | "Conclusion: Using query and passage embeddings leads to more relevant and context-aware results." 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "fst", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.10.13" 178 | }, 179 | "orig_nbformat": 4 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 2 183 | } 184 | -------------------------------------------------------------------------------- /experiments/Example. Convert Resnet50 to ONNX.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4bdb2a91-fa2a-4cee-ad5a-176cc957394d", 7 | "metadata": { 8 | "ExecuteTime": { 9 | "end_time": "2024-05-23T12:15:28.171586Z", 10 | "start_time": "2024-05-23T12:15:28.076314Z" 11 | } 12 | }, 13 | "outputs": [ 14 | { 15 | "ename": "ModuleNotFoundError", 16 | "evalue": "No module named 'torch'", 17 | "output_type": "error", 18 | "traceback": [ 19 | "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", 20 | "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)", 21 | "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[1;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorch\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01monnx\u001B[39;00m\n\u001B[1;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtorchvision\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m\n", 22 | "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'torch'" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import torch\n", 28 | "import torch.onnx\n", 29 | "import torchvision.models as models\n", 30 | "import torchvision.transforms as transforms\n", 31 | "from PIL import Image\n", 32 | "import numpy as np\n", 33 | "from tests.config import TEST_MISC_DIR\n", 34 | "\n", 35 | "# Load pre-trained ResNet-50 model\n", 36 | "resnet = models.resnet50(pretrained=True)\n", 37 | "resnet = torch.nn.Sequential(*(list(resnet.children())[:-1])) # Remove the last fully connected layer\n", 38 | "resnet.eval()\n", 39 | "\n", 40 | "# Define preprocessing transform\n", 41 | "preprocess = transforms.Compose([\n", 42 | " transforms.Resize(256),\n", 43 | " transforms.CenterCrop(224),\n", 44 | " transforms.ToTensor(),\n", 45 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 46 | "])\n", 47 | "\n", 48 | "# Load and preprocess the image\n", 49 | "def preprocess_image(image_path):\n", 50 | " input_image = Image.open(image_path)\n", 51 | " input_tensor = 
preprocess(input_image)\n", 52 | " input_batch = input_tensor.unsqueeze(0) # Add batch dimension\n", 53 | " return input_batch\n", 54 | "\n", 55 | "# Example input for exporting\n", 56 | "input_image = preprocess_image('example.jpg')\n", 57 | "\n", 58 | "# Export the model to ONNX with dynamic axes\n", 59 | "torch.onnx.export(\n", 60 | " resnet, \n", 61 | " input_image, \n", 62 | " \"model.onnx\", \n", 63 | " export_params=True, \n", 64 | " opset_version=9, \n", 65 | " input_names=['input'], \n", 66 | " output_names=['output'],\n", 67 | " dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}\n", 68 | ")\n", 69 | "\n", 70 | "# Load ONNX model\n", 71 | "import onnx\n", 72 | "import onnxruntime as ort\n", 73 | "\n", 74 | "onnx_model = onnx.load(\"model.onnx\")\n", 75 | "ort_session = ort.InferenceSession(\"model.onnx\")\n", 76 | "\n", 77 | "# Run inference and extract feature vectors\n", 78 | "def extract_feature_vectors(image_paths):\n", 79 | " input_images = [preprocess_image(image_path) for image_path in image_paths]\n", 80 | " input_batch = torch.cat(input_images, dim=0) # Combine images into a single batch\n", 81 | " ort_inputs = {ort_session.get_inputs()[0].name: input_batch.numpy()}\n", 82 | " ort_outs = ort_session.run(None, ort_inputs)\n", 83 | " return ort_outs[0]\n", 84 | "\n", 85 | "# Example usage\n", 86 | "images = [TEST_MISC_DIR / \"image.jpeg\", str(TEST_MISC_DIR / \"small_image.jpeg\")] # Replace with your image paths\n", 87 | "feature_vectors = extract_feature_vectors(images)\n", 88 | "print(\"Feature vector shape:\", feature_vectors.shape)\n" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "outputs": [], 94 | "source": [], 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "id": "baa650c4cb3e0e6d" 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "Python 3 (ipykernel)", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.12.2" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 5 122 | } 123 | -------------------------------------------------------------------------------- /experiments/attention_export.py: -------------------------------------------------------------------------------- 1 | from optimum.exporters.onnx import main_export 2 | from transformers import AutoTokenizer 3 | 4 | model_id = "sentence-transformers/paraphrase-MiniLM-L6-v2" 5 | output_dir = f"models/{model_id.replace('/', '_')}" 6 | model_kwargs = {"output_attentions": True, "return_dict": True} 7 | tokenizer = AutoTokenizer.from_pretrained(model_id) 8 | 9 | # export if the output model does not exist 10 | # try: 11 | # sess = onnxruntime.InferenceSession(f"{output_dir}/model.onnx") 12 | # print("Model already exported") 13 | # except FileNotFoundError: 14 | print(f"Exporting model to {output_dir}") 15 | main_export( 16 | model_id, output=output_dir, no_post_process=True, model_kwargs=model_kwargs 17 | ) 18 | -------------------------------------------------------------------------------- /experiments/try_attention_export.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import onnxruntime 4 | from transformers import AutoTokenizer 5 | 6 | model_id = 
"sentence-transformers/paraphrase-MiniLM-L6-v2" 7 | output_dir = f"models/{model_id.replace('/', '_')}" 8 | model_kwargs = {"output_attentions": True, "return_dict": True} 9 | tokenizer = AutoTokenizer.from_pretrained(model_id) 10 | 11 | model_path = f"{output_dir}/model.onnx" 12 | onnx_model = onnx.load(model_path) 13 | ort_session = onnxruntime.InferenceSession(model_path) 14 | text = "This is a test sentence" 15 | tokenizer_output = tokenizer(text, return_tensors="np") 16 | input_ids = tokenizer_output["input_ids"] 17 | attention_mask = tokenizer_output["attention_mask"] 18 | print(attention_mask) 19 | # Prepare the input 20 | input_ids = np.array(input_ids).astype( 21 | np.int64 22 | ) # Replace your_input_ids with actual input data 23 | 24 | # Run the ONNX model 25 | outputs = ort_session.run( 26 | None, {"input_ids": input_ids, "attention_mask": attention_mask} 27 | ) 28 | 29 | # Get the attention weights 30 | attentions = outputs[-1] 31 | 32 | # Print the attention weights for the first layer and first head 33 | print(attentions[0][0]) 34 | -------------------------------------------------------------------------------- /fastembed/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from fastembed.image import ImageEmbedding 4 | from fastembed.late_interaction import LateInteractionTextEmbedding 5 | from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding 6 | from fastembed.sparse import SparseEmbedding, SparseTextEmbedding 7 | from fastembed.text import TextEmbedding 8 | 9 | try: 10 | version = importlib.metadata.version("fastembed") 11 | except importlib.metadata.PackageNotFoundError as _: 12 | version = importlib.metadata.version("fastembed-gpu") 13 | 14 | __version__ = version 15 | __all__ = [ 16 | "TextEmbedding", 17 | "SparseTextEmbedding", 18 | "SparseEmbedding", 19 | "ImageEmbedding", 20 | "LateInteractionTextEmbedding", 21 | "LateInteractionMultimodalEmbedding", 22 | ] 23 | -------------------------------------------------------------------------------- /fastembed/common/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.common.types import ImageInput, OnnxProvider, PathInput 2 | 3 | __all__ = ["OnnxProvider", "ImageInput", "PathInput"] 4 | -------------------------------------------------------------------------------- /fastembed/common/model_description.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | from typing import Optional, Any 4 | 5 | 6 | @dataclass(frozen=True) 7 | class ModelSource: 8 | hf: Optional[str] = None 9 | url: Optional[str] = None 10 | _deprecated_tar_struct: bool = False 11 | 12 | @property 13 | def deprecated_tar_struct(self) -> bool: 14 | return self._deprecated_tar_struct 15 | 16 | def __post_init__(self) -> None: 17 | if self.hf is None and self.url is None: 18 | raise ValueError( 19 | f"At least one source should be set, current sources: hf={self.hf}, url={self.url}" 20 | ) 21 | 22 | 23 | @dataclass(frozen=True) 24 | class BaseModelDescription: 25 | model: str 26 | sources: ModelSource 27 | model_file: str 28 | description: str 29 | license: str 30 | size_in_GB: float 31 | additional_files: list[str] = field(default_factory=list) 32 | 33 | 34 | @dataclass(frozen=True) 35 | class DenseModelDescription(BaseModelDescription): 36 | dim: Optional[int] = None 37 | tasks: 
Optional[dict[str, Any]] = field(default_factory=dict) 38 | 39 | def __post_init__(self) -> None: 40 | assert self.dim is not None, "dim is required for dense model description" 41 | 42 | 43 | @dataclass(frozen=True) 44 | class SparseModelDescription(BaseModelDescription): 45 | requires_idf: Optional[bool] = None 46 | vocab_size: Optional[int] = None 47 | 48 | 49 | class PoolingType(str, Enum): 50 | CLS = "CLS" 51 | MEAN = "MEAN" 52 | DISABLED = "DISABLED" 53 | -------------------------------------------------------------------------------- /fastembed/common/onnx_model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Any, Generic, Iterable, Optional, Sequence, Type, TypeVar 5 | 6 | import numpy as np 7 | import onnxruntime as ort 8 | 9 | from numpy.typing import NDArray 10 | from tokenizers import Tokenizer 11 | 12 | from fastembed.common.types import OnnxProvider, NumpyArray 13 | from fastembed.parallel_processor import Worker 14 | 15 | # Holds type of the embedding result 16 | T = TypeVar("T") 17 | 18 | 19 | @dataclass 20 | class OnnxOutputContext: 21 | model_output: NumpyArray 22 | attention_mask: Optional[NDArray[np.int64]] = None 23 | input_ids: Optional[NDArray[np.int64]] = None 24 | 25 | 26 | class OnnxModel(Generic[T]): 27 | @classmethod 28 | def _get_worker_class(cls) -> Type["EmbeddingWorker[T]"]: 29 | raise NotImplementedError("Subclasses must implement this method") 30 | 31 | def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]: 32 | """Post-process the ONNX model output to convert it into a usable format. 33 | 34 | Args: 35 | output (OnnxOutputContext): The raw output from the ONNX model. 36 | **kwargs: Additional keyword arguments that may be needed by specific implementations. 37 | 38 | Returns: 39 | Iterable[T]: Post-processed output as an iterable of type T. 40 | """ 41 | raise NotImplementedError("Subclasses must implement this method") 42 | 43 | def __init__(self) -> None: 44 | self.model: Optional[ort.InferenceSession] = None 45 | self.tokenizer: Optional[Tokenizer] = None 46 | 47 | def _preprocess_onnx_input( 48 | self, onnx_input: dict[str, NumpyArray], **kwargs: Any 49 | ) -> dict[str, NumpyArray]: 50 | """ 51 | Preprocess the onnx input. 
52 | """ 53 | return onnx_input 54 | 55 | def _load_onnx_model( 56 | self, 57 | model_dir: Path, 58 | model_file: str, 59 | threads: Optional[int], 60 | providers: Optional[Sequence[OnnxProvider]] = None, 61 | cuda: bool = False, 62 | device_id: Optional[int] = None, 63 | ) -> None: 64 | model_path = model_dir / model_file 65 | # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers 66 | 67 | if cuda and providers is not None: 68 | warnings.warn( 69 | f"`cuda` and `providers` are mutually exclusive parameters, cuda: {cuda}, providers: {providers}", 70 | category=UserWarning, 71 | stacklevel=6, 72 | ) 73 | 74 | if providers is not None: 75 | onnx_providers = list(providers) 76 | elif cuda: 77 | if device_id is None: 78 | onnx_providers = ["CUDAExecutionProvider"] 79 | else: 80 | onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})] 81 | else: 82 | onnx_providers = ["CPUExecutionProvider"] 83 | 84 | available_providers = ort.get_available_providers() 85 | requested_provider_names: list[str] = [] 86 | for provider in onnx_providers: 87 | # check providers available 88 | provider_name = provider if isinstance(provider, str) else provider[0] 89 | requested_provider_names.append(provider_name) 90 | if provider_name not in available_providers: 91 | raise ValueError( 92 | f"Provider {provider_name} is not available. Available providers: {available_providers}" 93 | ) 94 | 95 | so = ort.SessionOptions() 96 | so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL 97 | 98 | if threads is not None: 99 | so.intra_op_num_threads = threads 100 | so.inter_op_num_threads = threads 101 | 102 | self.model = ort.InferenceSession( 103 | str(model_path), providers=onnx_providers, sess_options=so 104 | ) 105 | if "CUDAExecutionProvider" in requested_provider_names: 106 | assert self.model is not None 107 | current_providers = self.model.get_providers() 108 | if "CUDAExecutionProvider" not in current_providers: 109 | warnings.warn( 110 | f"Attempt to set CUDAExecutionProvider failed. Current providers: {current_providers}." 
111 | "If you are using CUDA 12.x, install onnxruntime-gpu via " 112 | "`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`", 113 | RuntimeWarning, 114 | ) 115 | 116 | def load_onnx_model(self) -> None: 117 | raise NotImplementedError("Subclasses must implement this method") 118 | 119 | def onnx_embed(self, *args: Any, **kwargs: Any) -> OnnxOutputContext: 120 | raise NotImplementedError("Subclasses must implement this method") 121 | 122 | 123 | class EmbeddingWorker(Worker, Generic[T]): 124 | def init_embedding( 125 | self, 126 | model_name: str, 127 | cache_dir: str, 128 | **kwargs: Any, 129 | ) -> OnnxModel[T]: 130 | raise NotImplementedError() 131 | 132 | def __init__( 133 | self, 134 | model_name: str, 135 | cache_dir: str, 136 | **kwargs: Any, 137 | ): 138 | self.model = self.init_embedding(model_name, cache_dir, **kwargs) 139 | 140 | @classmethod 141 | def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker[T]": 142 | return cls(model_name=model_name, cache_dir=cache_dir, **kwargs) 143 | 144 | def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: 145 | raise NotImplementedError("Subclasses must implement this method") 146 | -------------------------------------------------------------------------------- /fastembed/common/preprocessor_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | from pathlib import Path 4 | 5 | from tokenizers import AddedToken, Tokenizer 6 | 7 | from fastembed.image.transform.operators import Compose 8 | 9 | 10 | def load_special_tokens(model_dir: Path) -> dict[str, Any]: 11 | tokens_map_path = model_dir / "special_tokens_map.json" 12 | if not tokens_map_path.exists(): 13 | raise ValueError(f"Could not find special_tokens_map.json in {model_dir}") 14 | 15 | with open(str(tokens_map_path)) as tokens_map_file: 16 | tokens_map = json.load(tokens_map_file) 17 | 18 | return tokens_map 19 | 20 | 21 | def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]: 22 | config_path = model_dir / "config.json" 23 | if not config_path.exists(): 24 | raise ValueError(f"Could not find config.json in {model_dir}") 25 | 26 | tokenizer_path = model_dir / "tokenizer.json" 27 | if not tokenizer_path.exists(): 28 | raise ValueError(f"Could not find tokenizer.json in {model_dir}") 29 | 30 | tokenizer_config_path = model_dir / "tokenizer_config.json" 31 | if not tokenizer_config_path.exists(): 32 | raise ValueError(f"Could not find tokenizer_config.json in {model_dir}") 33 | 34 | with open(str(config_path)) as config_file: 35 | config = json.load(config_file) 36 | 37 | with open(str(tokenizer_config_path)) as tokenizer_config_file: 38 | tokenizer_config = json.load(tokenizer_config_file) 39 | assert "model_max_length" in tokenizer_config or "max_length" in tokenizer_config, ( 40 | "Models without model_max_length or max_length are not supported." 
41 | ) 42 | if "model_max_length" not in tokenizer_config: 43 | max_context = tokenizer_config["max_length"] 44 | elif "max_length" not in tokenizer_config: 45 | max_context = tokenizer_config["model_max_length"] 46 | else: 47 | max_context = min(tokenizer_config["model_max_length"], tokenizer_config["max_length"]) 48 | 49 | tokens_map = load_special_tokens(model_dir) 50 | 51 | tokenizer = Tokenizer.from_file(str(tokenizer_path)) 52 | tokenizer.enable_truncation(max_length=max_context) 53 | tokenizer.enable_padding( 54 | pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"] 55 | ) 56 | 57 | for token in tokens_map.values(): 58 | if isinstance(token, str): 59 | tokenizer.add_special_tokens([token]) 60 | elif isinstance(token, dict): 61 | tokenizer.add_special_tokens([AddedToken(**token)]) 62 | 63 | special_token_to_id: dict[str, int] = {} 64 | 65 | for token in tokens_map.values(): 66 | if isinstance(token, str): 67 | special_token_to_id[token] = tokenizer.token_to_id(token) 68 | elif isinstance(token, dict): 69 | token_str = token.get("content", "") 70 | special_token_to_id[token_str] = tokenizer.token_to_id(token_str) 71 | 72 | return tokenizer, special_token_to_id 73 | 74 | 75 | def load_preprocessor(model_dir: Path) -> Compose: 76 | preprocessor_config_path = model_dir / "preprocessor_config.json" 77 | if not preprocessor_config_path.exists(): 78 | raise ValueError(f"Could not find preprocessor_config.json in {model_dir}") 79 | 80 | with open(str(preprocessor_config_path)) as preprocessor_config_file: 81 | preprocessor_config = json.load(preprocessor_config_file) 82 | transforms = Compose.from_config(preprocessor_config) 83 | return transforms 84 | -------------------------------------------------------------------------------- /fastembed/common/types.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | from PIL import Image 4 | from typing import Any, Union 5 | import numpy as np 6 | from numpy.typing import NDArray 7 | 8 | if sys.version_info >= (3, 10): 9 | from typing import TypeAlias 10 | else: 11 | from typing_extensions import TypeAlias 12 | 13 | 14 | PathInput: TypeAlias = Union[str, Path] 15 | ImageInput: TypeAlias = Union[PathInput, Image.Image] 16 | 17 | OnnxProvider: TypeAlias = Union[str, tuple[str, dict[Any, Any]]] 18 | NumpyArray = Union[ 19 | NDArray[np.float64], 20 | NDArray[np.float32], 21 | NDArray[np.float16], 22 | NDArray[np.int8], 23 | NDArray[np.int64], 24 | NDArray[np.int32], 25 | ] 26 | -------------------------------------------------------------------------------- /fastembed/common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import tempfile 5 | import unicodedata 6 | from pathlib import Path 7 | from itertools import islice 8 | from typing import Iterable, Optional, TypeVar 9 | 10 | import numpy as np 11 | from numpy.typing import NDArray 12 | 13 | from fastembed.common.types import NumpyArray 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> NumpyArray: 19 | # Calculate the Lp norm along the specified dimension 20 | norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True) 21 | norm = np.maximum(norm, eps) # Avoid division by zero 22 | normalized_array = input_array / norm 23 | return normalized_array 24 | 25 | 26 | def mean_pooling(input_array: NumpyArray, attention_mask: 
NDArray[np.int64]) -> NumpyArray: 27 | input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64) 28 | input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1])) 29 | sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1) 30 | sum_mask = np.sum(input_mask_expanded, axis=1) 31 | pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9) 32 | return pooled_embeddings 33 | 34 | 35 | def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]: 36 | """ 37 | >>> list(iter_batch([1,2,3,4,5], 3)) 38 | [[1, 2, 3], [4, 5]] 39 | """ 40 | source_iter = iter(iterable) 41 | while source_iter: 42 | b = list(islice(source_iter, size)) 43 | if len(b) == 0: 44 | break 45 | yield b 46 | 47 | 48 | def define_cache_dir(cache_dir: Optional[str] = None) -> Path: 49 | """ 50 | Define the cache directory for fastembed 51 | """ 52 | if cache_dir is None: 53 | default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") 54 | cache_path = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) 55 | else: 56 | cache_path = Path(cache_dir) 57 | cache_path.mkdir(parents=True, exist_ok=True) 58 | 59 | return cache_path 60 | 61 | 62 | def get_all_punctuation() -> set[str]: 63 | return set( 64 | chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") 65 | ) 66 | 67 | 68 | def remove_non_alphanumeric(text: str) -> str: 69 | return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE) 70 | -------------------------------------------------------------------------------- /fastembed/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | 3 | from loguru import logger 4 | 5 | from fastembed import TextEmbedding 6 | 7 | logger.warning( 8 | "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated." 9 | "Use from fastembed import TextEmbedding instead." 10 | ) 11 | 12 | DefaultEmbedding = TextEmbedding 13 | FlagEmbedding = TextEmbedding 14 | 15 | 16 | class JinaEmbedding(TextEmbedding): 17 | def __init__( 18 | self, 19 | model_name: str = "jinaai/jina-embeddings-v2-base-en", 20 | cache_dir: Optional[str] = None, 21 | threads: Optional[int] = None, 22 | **kwargs: Any, 23 | ): 24 | super().__init__(model_name, cache_dir, threads, **kwargs) 25 | -------------------------------------------------------------------------------- /fastembed/image/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.image.image_embedding import ImageEmbedding 2 | 3 | __all__ = ["ImageEmbedding"] 4 | -------------------------------------------------------------------------------- /fastembed/image/image_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, Sequence, Type, Union 2 | from dataclasses import asdict 3 | 4 | from fastembed.common.types import NumpyArray 5 | from fastembed.common import ImageInput, OnnxProvider 6 | from fastembed.image.image_embedding_base import ImageEmbeddingBase 7 | from fastembed.image.onnx_embedding import OnnxImageEmbedding 8 | from fastembed.common.model_description import DenseModelDescription 9 | 10 | 11 | class ImageEmbedding(ImageEmbeddingBase): 12 | EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding] 13 | 14 | @classmethod 15 | def list_supported_models(cls) -> list[dict[str, Any]]: 16 | """ 17 | Lists the supported models. 
18 | 19 | Returns: 20 | list[dict[str, Any]]: A list of dictionaries containing the model information. 21 | 22 | Example: 23 | ``` 24 | [ 25 | { 26 | "model": "Qdrant/clip-ViT-B-32-vision", 27 | "dim": 512, 28 | "description": "CLIP vision encoder based on ViT-B/32", 29 | "license": "mit", 30 | "size_in_GB": 0.33, 31 | "sources": { 32 | "hf": "Qdrant/clip-ViT-B-32-vision", 33 | }, 34 | "model_file": "model.onnx", 35 | } 36 | ] 37 | ``` 38 | """ 39 | return [asdict(model) for model in cls._list_supported_models()] 40 | 41 | @classmethod 42 | def _list_supported_models(cls) -> list[DenseModelDescription]: 43 | result: list[DenseModelDescription] = [] 44 | for embedding in cls.EMBEDDINGS_REGISTRY: 45 | result.extend(embedding._list_supported_models()) 46 | return result 47 | 48 | def __init__( 49 | self, 50 | model_name: str, 51 | cache_dir: Optional[str] = None, 52 | threads: Optional[int] = None, 53 | providers: Optional[Sequence[OnnxProvider]] = None, 54 | cuda: bool = False, 55 | device_ids: Optional[list[int]] = None, 56 | lazy_load: bool = False, 57 | **kwargs: Any, 58 | ): 59 | super().__init__(model_name, cache_dir, threads, **kwargs) 60 | for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: 61 | supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() 62 | if any(model_name.lower() == model.model.lower() for model in supported_models): 63 | self.model = EMBEDDING_MODEL_TYPE( 64 | model_name, 65 | cache_dir, 66 | threads=threads, 67 | providers=providers, 68 | cuda=cuda, 69 | device_ids=device_ids, 70 | lazy_load=lazy_load, 71 | **kwargs, 72 | ) 73 | return 74 | 75 | raise ValueError( 76 | f"Model {model_name} is not supported in ImageEmbedding." 77 | "Please check the supported models using `ImageEmbedding.list_supported_models()`" 78 | ) 79 | 80 | @property 81 | def embedding_size(self) -> int: 82 | """Get the embedding size of the current model""" 83 | if self._embedding_size is None: 84 | self._embedding_size = self.get_embedding_size(self.model_name) 85 | return self._embedding_size 86 | 87 | @classmethod 88 | def get_embedding_size(cls, model_name: str) -> int: 89 | """Get the embedding size of the passed model 90 | 91 | Args: 92 | model_name (str): The name of the model to get embedding size for. 93 | 94 | Returns: 95 | int: The size of the embedding. 96 | 97 | Raises: 98 | ValueError: If the model name is not found in the supported models. 99 | """ 100 | descriptions = cls._list_supported_models() 101 | embedding_size: Optional[int] = None 102 | for description in descriptions: 103 | if description.model.lower() == model_name.lower(): 104 | embedding_size = description.dim 105 | break 106 | if embedding_size is None: 107 | model_names = [description.model for description in descriptions] 108 | raise ValueError( 109 | f"Embedding size for model {model_name} was None. " 110 | f"Available model names: {model_names}" 111 | ) 112 | return embedding_size 113 | 114 | def embed( 115 | self, 116 | images: Union[ImageInput, Iterable[ImageInput]], 117 | batch_size: int = 16, 118 | parallel: Optional[int] = None, 119 | **kwargs: Any, 120 | ) -> Iterable[NumpyArray]: 121 | """ 122 | Encode a list of images into list of embeddings. 123 | 124 | Args: 125 | images: Iterator of image paths or single image path to embed 126 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 127 | parallel: 128 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 129 | If 0, use all available cores. 
130 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 131 | 132 | Returns: 133 | List of embeddings, one per document 134 | """ 135 | yield from self.model.embed(images, batch_size, parallel, **kwargs) 136 | -------------------------------------------------------------------------------- /fastembed/image/image_embedding_base.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional, Any, Union 2 | 3 | from fastembed.common.model_description import DenseModelDescription 4 | from fastembed.common.types import NumpyArray 5 | from fastembed.common.model_management import ModelManagement 6 | from fastembed.common.types import ImageInput 7 | 8 | 9 | class ImageEmbeddingBase(ModelManagement[DenseModelDescription]): 10 | def __init__( 11 | self, 12 | model_name: str, 13 | cache_dir: Optional[str] = None, 14 | threads: Optional[int] = None, 15 | **kwargs: Any, 16 | ): 17 | self.model_name = model_name 18 | self.cache_dir = cache_dir 19 | self.threads = threads 20 | self._local_files_only = kwargs.pop("local_files_only", False) 21 | self._embedding_size: Optional[int] = None 22 | 23 | def embed( 24 | self, 25 | images: Union[ImageInput, Iterable[ImageInput]], 26 | batch_size: int = 16, 27 | parallel: Optional[int] = None, 28 | **kwargs: Any, 29 | ) -> Iterable[NumpyArray]: 30 | """ 31 | Embeds a list of images into a list of embeddings. 32 | 33 | Args: 34 | images: The list of image paths to preprocess and embed. 35 | batch_size: Batch size for encoding 36 | parallel: 37 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 38 | If 0, use all available cores. 39 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 40 | **kwargs: Additional keyword argument to pass to the embed method. 41 | 42 | Yields: 43 | Iterable[NdArray]: The embeddings. 
44 | """ 45 | raise NotImplementedError() 46 | 47 | @classmethod 48 | def get_embedding_size(cls, model_name: str) -> int: 49 | """Returns embedding size of the chosen model.""" 50 | raise NotImplementedError("Subclasses must implement this method") 51 | 52 | @property 53 | def embedding_size(self) -> int: 54 | """Returns embedding size for the current model""" 55 | raise NotImplementedError("Subclasses must implement this method") 56 | -------------------------------------------------------------------------------- /fastembed/image/onnx_image_model.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from multiprocessing import get_all_start_methods 4 | from pathlib import Path 5 | from typing import Any, Iterable, Optional, Sequence, Type, Union 6 | 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from fastembed.image.transform.operators import Compose 11 | from fastembed.common.types import NumpyArray 12 | from fastembed.common import ImageInput, OnnxProvider 13 | from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T 14 | from fastembed.common.preprocessor_utils import load_preprocessor 15 | from fastembed.common.utils import iter_batch 16 | from fastembed.parallel_processor import ParallelWorkerPool 17 | 18 | # Holds type of the embedding result 19 | 20 | 21 | class OnnxImageModel(OnnxModel[T]): 22 | @classmethod 23 | def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]: 24 | raise NotImplementedError("Subclasses must implement this method") 25 | 26 | def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]: 27 | """Post-process the ONNX model output to convert it into a usable format. 28 | 29 | Args: 30 | output (OnnxOutputContext): The raw output from the ONNX model. 31 | **kwargs: Additional keyword arguments that may be needed by specific implementations. 32 | 33 | Returns: 34 | Iterable[T]: Post-processed output as an iterable of type T. 35 | """ 36 | raise NotImplementedError("Subclasses must implement this method") 37 | 38 | def __init__(self) -> None: 39 | super().__init__() 40 | self.processor: Optional[Compose] = None 41 | 42 | def _preprocess_onnx_input( 43 | self, onnx_input: dict[str, NumpyArray], **kwargs: Any 44 | ) -> dict[str, NumpyArray]: 45 | """ 46 | Preprocess the onnx input. 
47 | """ 48 | return onnx_input 49 | 50 | def _load_onnx_model( 51 | self, 52 | model_dir: Path, 53 | model_file: str, 54 | threads: Optional[int], 55 | providers: Optional[Sequence[OnnxProvider]] = None, 56 | cuda: bool = False, 57 | device_id: Optional[int] = None, 58 | ) -> None: 59 | super()._load_onnx_model( 60 | model_dir=model_dir, 61 | model_file=model_file, 62 | threads=threads, 63 | providers=providers, 64 | cuda=cuda, 65 | device_id=device_id, 66 | ) 67 | self.processor = load_preprocessor(model_dir=model_dir) 68 | 69 | def load_onnx_model(self) -> None: 70 | raise NotImplementedError("Subclasses must implement this method") 71 | 72 | def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]: 73 | input_name = self.model.get_inputs()[0].name # type: ignore[union-attr] 74 | return {input_name: encoded} 75 | 76 | def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext: 77 | with contextlib.ExitStack(): 78 | image_files = [ 79 | Image.open(image) if not isinstance(image, Image.Image) else image 80 | for image in images 81 | ] 82 | assert self.processor is not None, "Processor is not initialized" 83 | encoded = np.array(self.processor(image_files)) 84 | onnx_input = self._build_onnx_input(encoded) 85 | onnx_input = self._preprocess_onnx_input(onnx_input) 86 | model_output = self.model.run(None, onnx_input) # type: ignore[union-attr] 87 | embeddings = model_output[0].reshape(len(images), -1) 88 | return OnnxOutputContext(model_output=embeddings) 89 | 90 | def _embed_images( 91 | self, 92 | model_name: str, 93 | cache_dir: str, 94 | images: Union[ImageInput, Iterable[ImageInput]], 95 | batch_size: int = 256, 96 | parallel: Optional[int] = None, 97 | providers: Optional[Sequence[OnnxProvider]] = None, 98 | cuda: bool = False, 99 | device_ids: Optional[list[int]] = None, 100 | local_files_only: bool = False, 101 | specific_model_path: Optional[str] = None, 102 | **kwargs: Any, 103 | ) -> Iterable[T]: 104 | is_small = False 105 | 106 | if isinstance(images, (str, Path, Image.Image)): 107 | images = [images] 108 | is_small = True 109 | 110 | if isinstance(images, list) and len(images) < batch_size: 111 | is_small = True 112 | 113 | if parallel is None or is_small: 114 | if not hasattr(self, "model") or self.model is None: 115 | self.load_onnx_model() 116 | 117 | for batch in iter_batch(images, batch_size): 118 | yield from self._post_process_onnx_output(self.onnx_embed(batch), **kwargs) 119 | else: 120 | if parallel == 0: 121 | parallel = os.cpu_count() 122 | 123 | start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" 124 | params = { 125 | "model_name": model_name, 126 | "cache_dir": cache_dir, 127 | "providers": providers, 128 | "local_files_only": local_files_only, 129 | "specific_model_path": specific_model_path, 130 | **kwargs, 131 | } 132 | 133 | pool = ParallelWorkerPool( 134 | num_workers=parallel or 1, 135 | worker=self._get_worker_class(), 136 | cuda=cuda, 137 | device_ids=device_ids, 138 | start_method=start_method, 139 | ) 140 | for batch in pool.ordered_map(iter_batch(images, batch_size), **params): 141 | yield from self._post_process_onnx_output(batch, **kwargs) # type: ignore 142 | 143 | 144 | class ImageEmbeddingWorker(EmbeddingWorker[T]): 145 | def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: 146 | for idx, batch in items: 147 | embeddings = self.model.onnx_embed(batch) 148 | yield idx, embeddings 149 | 
-------------------------------------------------------------------------------- /fastembed/image/transform/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from fastembed.common.types import NumpyArray 7 | 8 | 9 | def convert_to_rgb(image: Image.Image) -> Image.Image: 10 | if image.mode == "RGB": 11 | return image 12 | 13 | image = image.convert("RGB") 14 | return image 15 | 16 | 17 | def center_crop( 18 | image: Union[Image.Image, NumpyArray], 19 | size: tuple[int, int], 20 | ) -> NumpyArray: 21 | if isinstance(image, np.ndarray): 22 | _, orig_height, orig_width = image.shape 23 | else: 24 | orig_height, orig_width = image.height, image.width 25 | # (H, W, C) -> (C, H, W) 26 | image = np.array(image).transpose((2, 0, 1)) 27 | 28 | crop_height, crop_width = size 29 | 30 | # left upper corner (0, 0) 31 | top = (orig_height - crop_height) // 2 32 | bottom = top + crop_height 33 | left = (orig_width - crop_width) // 2 34 | right = left + crop_width 35 | 36 | # Check if cropped area is within image boundaries 37 | if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: 38 | image = image[..., top:bottom, left:right] 39 | return image 40 | 41 | # Padding with zeros 42 | new_height = max(crop_height, orig_height) 43 | new_width = max(crop_width, orig_width) 44 | new_shape = image.shape[:-2] + (new_height, new_width) 45 | new_image = np.zeros_like(image, shape=new_shape, dtype=np.float32) 46 | 47 | top_pad = (new_height - orig_height) // 2 48 | bottom_pad = top_pad + orig_height 49 | left_pad = (new_width - orig_width) // 2 50 | right_pad = left_pad + orig_width 51 | new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image 52 | 53 | top += top_pad 54 | bottom += top_pad 55 | left += left_pad 56 | right += left_pad 57 | 58 | new_image = new_image[ 59 | ..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right) 60 | ] 61 | 62 | return new_image 63 | 64 | 65 | def normalize( 66 | image: NumpyArray, 67 | mean: Union[float, list[float]], 68 | std: Union[float, list[float]], 69 | ) -> NumpyArray: 70 | num_channels = image.shape[1] if len(image.shape) == 4 else image.shape[0] 71 | 72 | if not np.issubdtype(image.dtype, np.floating): 73 | image = image.astype(np.float32) 74 | 75 | mean_list = mean if isinstance(mean, list) else [mean] * num_channels 76 | 77 | if len(mean_list) != num_channels: 78 | raise ValueError( 79 | f"mean must have the same number of channels as the image, image has {num_channels} channels, got " 80 | f"{len(mean_list)}" 81 | ) 82 | 83 | mean_arr = np.array(mean_list, dtype=np.float32) 84 | 85 | std_list = std if isinstance(std, list) else [std] * num_channels 86 | if len(std_list) != num_channels: 87 | raise ValueError( 88 | f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std_list)}" 89 | ) 90 | 91 | std_arr = np.array(std_list, dtype=np.float32) 92 | 93 | image_upd = ((image.T - mean_arr) / std_arr).T 94 | return image_upd 95 | 96 | 97 | def resize( 98 | image: Image.Image, 99 | size: Union[int, tuple[int, int]], 100 | resample: Union[int, Image.Resampling] = Image.Resampling.BILINEAR, 101 | ) -> Image.Image: 102 | if isinstance(size, tuple): 103 | return image.resize(size, resample) 104 | 105 | height, width = image.height, image.width 106 | short, long = (width, height) if width <= height else (height, width) 107 | 108 | new_short, new_long 
= size, int(size * long / short) 109 | if width <= height: 110 | new_size = (new_short, new_long) 111 | else: 112 | new_size = (new_long, new_short) 113 | return image.resize(new_size, resample) 114 | 115 | 116 | def rescale(image: NumpyArray, scale: float, dtype: type = np.float32) -> NumpyArray: 117 | return (image * scale).astype(dtype) 118 | 119 | 120 | def pil2ndarray(image: Union[Image.Image, NumpyArray]) -> NumpyArray: 121 | if isinstance(image, Image.Image): 122 | return np.asarray(image).transpose((2, 0, 1)) 123 | return image 124 | 125 | 126 | def pad2square( 127 | image: Image.Image, 128 | size: int, 129 | fill_color: Union[str, int, tuple[int, ...]] = 0, 130 | ) -> Image.Image: 131 | height, width = image.height, image.width 132 | 133 | left, right = 0, width 134 | top, bottom = 0, height 135 | 136 | crop_required = False 137 | if width > size: 138 | left = (width - size) // 2 139 | right = left + size 140 | crop_required = True 141 | 142 | if height > size: 143 | top = (height - size) // 2 144 | bottom = top + size 145 | crop_required = True 146 | 147 | new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) 148 | new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image) 149 | return new_image 150 | -------------------------------------------------------------------------------- /fastembed/late_interaction/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.late_interaction.late_interaction_text_embedding import ( 2 | LateInteractionTextEmbedding, 3 | ) 4 | 5 | __all__ = ["LateInteractionTextEmbedding"] 6 | -------------------------------------------------------------------------------- /fastembed/late_interaction/jina_colbert.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type 2 | 3 | from fastembed.common.types import NumpyArray 4 | from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker 5 | from fastembed.common.model_description import DenseModelDescription, ModelSource 6 | 7 | supported_jina_colbert_models: list[DenseModelDescription] = [ 8 | DenseModelDescription( 9 | model="jinaai/jina-colbert-v2", 10 | dim=128, 11 | description="New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year", 12 | license="cc-by-nc-4.0", 13 | size_in_GB=2.24, 14 | sources=ModelSource(hf="jinaai/jina-colbert-v2"), 15 | model_file="onnx/model.onnx", 16 | additional_files=["onnx/model.onnx_data"], 17 | ) 18 | ] 19 | 20 | 21 | class JinaColbert(Colbert): 22 | QUERY_MARKER_TOKEN_ID = 250002 23 | DOCUMENT_MARKER_TOKEN_ID = 250003 24 | MIN_QUERY_LENGTH = 31 # it's 32, we add one additional special token in the beginning 25 | MASK_TOKEN = "" 26 | 27 | @classmethod 28 | def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]: 29 | return JinaColbertEmbeddingWorker 30 | 31 | @classmethod 32 | def _list_supported_models(cls) -> list[DenseModelDescription]: 33 | """Lists the supported models. 34 | 35 | Returns: 36 | list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 
37 | """ 38 | return supported_jina_colbert_models 39 | 40 | def _preprocess_onnx_input( 41 | self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any 42 | ) -> dict[str, NumpyArray]: 43 | onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc) 44 | 45 | # the attention mask for jina-colbert-v2 is always 1 in queries 46 | if not is_doc: 47 | onnx_input["attention_mask"][:] = 1 48 | return onnx_input 49 | 50 | 51 | class JinaColbertEmbeddingWorker(ColbertEmbeddingWorker): 52 | def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> JinaColbert: 53 | return JinaColbert( 54 | model_name=model_name, 55 | cache_dir=cache_dir, 56 | threads=1, 57 | **kwargs, 58 | ) 59 | -------------------------------------------------------------------------------- /fastembed/late_interaction/late_interaction_embedding_base.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional, Union, Any 2 | 3 | from fastembed.common.model_description import DenseModelDescription 4 | from fastembed.common.types import NumpyArray 5 | from fastembed.common.model_management import ModelManagement 6 | 7 | 8 | class LateInteractionTextEmbeddingBase(ModelManagement[DenseModelDescription]): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | cache_dir: Optional[str] = None, 13 | threads: Optional[int] = None, 14 | **kwargs: Any, 15 | ): 16 | self.model_name = model_name 17 | self.cache_dir = cache_dir 18 | self.threads = threads 19 | self._local_files_only = kwargs.pop("local_files_only", False) 20 | self._embedding_size: Optional[int] = None 21 | 22 | def embed( 23 | self, 24 | documents: Union[str, Iterable[str]], 25 | batch_size: int = 256, 26 | parallel: Optional[int] = None, 27 | **kwargs: Any, 28 | ) -> Iterable[NumpyArray]: 29 | raise NotImplementedError() 30 | 31 | def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: 32 | """ 33 | Embeds a list of text passages into a list of embeddings. 34 | 35 | Args: 36 | texts (Iterable[str]): The list of texts to embed. 37 | **kwargs: Additional keyword argument to pass to the embed method. 38 | 39 | Yields: 40 | Iterable[NdArray]: The embeddings. 41 | """ 42 | 43 | # This is model-specific, so that different models can have specialized implementations 44 | yield from self.embed(texts, **kwargs) 45 | 46 | def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: 47 | """ 48 | Embeds queries 49 | 50 | Args: 51 | query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. 52 | 53 | Returns: 54 | Iterable[NdArray]: The embeddings. 
55 | """ 56 | 57 | # This is model-specific, so that different models can have specialized implementations 58 | if isinstance(query, str): 59 | yield from self.embed([query], **kwargs) 60 | else: 61 | yield from self.embed(query, **kwargs) 62 | 63 | @classmethod 64 | def get_embedding_size(cls, model_name: str) -> int: 65 | """Returns embedding size of the chosen model.""" 66 | raise NotImplementedError("Subclasses must implement this method") 67 | 68 | @property 69 | def embedding_size(self) -> int: 70 | """Returns embedding size for the current model""" 71 | raise NotImplementedError("Subclasses must implement this method") 72 | -------------------------------------------------------------------------------- /fastembed/late_interaction/late_interaction_text_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, Sequence, Type, Union 2 | from dataclasses import asdict 3 | 4 | from fastembed.common.model_description import DenseModelDescription 5 | from fastembed.common.types import NumpyArray 6 | from fastembed.common import OnnxProvider 7 | from fastembed.late_interaction.colbert import Colbert 8 | from fastembed.late_interaction.jina_colbert import JinaColbert 9 | from fastembed.late_interaction.late_interaction_embedding_base import ( 10 | LateInteractionTextEmbeddingBase, 11 | ) 12 | 13 | 14 | class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase): 15 | EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert] 16 | 17 | @classmethod 18 | def list_supported_models(cls) -> list[dict[str, Any]]: 19 | """ 20 | Lists the supported models. 21 | 22 | Returns: 23 | list[dict[str, Any]]: A list of dictionaries containing the model information. 24 | 25 | Example: 26 | ``` 27 | [ 28 | { 29 | "model": "colbert-ir/colbertv2.0", 30 | "dim": 128, 31 | "description": "Late interaction model", 32 | "license": "mit", 33 | "size_in_GB": 0.44, 34 | "sources": { 35 | "hf": "colbert-ir/colbertv2.0", 36 | }, 37 | "model_file": "model.onnx", 38 | }, 39 | ] 40 | ``` 41 | """ 42 | return [asdict(model) for model in cls._list_supported_models()] 43 | 44 | @classmethod 45 | def _list_supported_models(cls) -> list[DenseModelDescription]: 46 | result: list[DenseModelDescription] = [] 47 | for embedding in cls.EMBEDDINGS_REGISTRY: 48 | result.extend(embedding._list_supported_models()) 49 | return result 50 | 51 | def __init__( 52 | self, 53 | model_name: str, 54 | cache_dir: Optional[str] = None, 55 | threads: Optional[int] = None, 56 | providers: Optional[Sequence[OnnxProvider]] = None, 57 | cuda: bool = False, 58 | device_ids: Optional[list[int]] = None, 59 | lazy_load: bool = False, 60 | **kwargs: Any, 61 | ): 62 | super().__init__(model_name, cache_dir, threads, **kwargs) 63 | for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: 64 | supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() 65 | if any(model_name.lower() == model.model.lower() for model in supported_models): 66 | self.model = EMBEDDING_MODEL_TYPE( 67 | model_name, 68 | cache_dir, 69 | threads=threads, 70 | providers=providers, 71 | cuda=cuda, 72 | device_ids=device_ids, 73 | lazy_load=lazy_load, 74 | **kwargs, 75 | ) 76 | return 77 | 78 | raise ValueError( 79 | f"Model {model_name} is not supported in LateInteractionTextEmbedding." 
80 | "Please check the supported models using `LateInteractionTextEmbedding.list_supported_models()`" 81 | ) 82 | 83 | @property 84 | def embedding_size(self) -> int: 85 | """Get the embedding size of the current model""" 86 | if self._embedding_size is None: 87 | self._embedding_size = self.get_embedding_size(self.model_name) 88 | return self._embedding_size 89 | 90 | @classmethod 91 | def get_embedding_size(cls, model_name: str) -> int: 92 | """Get the embedding size of the passed model 93 | 94 | Args: 95 | model_name (str): The name of the model to get embedding size for. 96 | 97 | Returns: 98 | int: The size of the embedding. 99 | 100 | Raises: 101 | ValueError: If the model name is not found in the supported models. 102 | """ 103 | descriptions = cls._list_supported_models() 104 | embedding_size: Optional[int] = None 105 | for description in descriptions: 106 | if description.model.lower() == model_name.lower(): 107 | embedding_size = description.dim 108 | break 109 | if embedding_size is None: 110 | model_names = [description.model for description in descriptions] 111 | raise ValueError( 112 | f"Embedding size for model {model_name} was None. " 113 | f"Available model names: {model_names}" 114 | ) 115 | return embedding_size 116 | 117 | def embed( 118 | self, 119 | documents: Union[str, Iterable[str]], 120 | batch_size: int = 256, 121 | parallel: Optional[int] = None, 122 | **kwargs: Any, 123 | ) -> Iterable[NumpyArray]: 124 | """ 125 | Encode a list of documents into list of embeddings. 126 | We use mean pooling with attention so that the model can handle variable-length inputs. 127 | 128 | Args: 129 | documents: Iterator of documents or single document to embed 130 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 131 | parallel: 132 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 133 | If 0, use all available cores. 134 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 135 | 136 | Returns: 137 | List of embeddings, one per document 138 | """ 139 | yield from self.model.embed(documents, batch_size, parallel, **kwargs) 140 | 141 | def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: 142 | """ 143 | Embeds queries 144 | 145 | Args: 146 | query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. 147 | 148 | Returns: 149 | Iterable[NdArray]: The embeddings. 
150 | """ 151 | 152 | # This is model-specific, so that different models can have specialized implementations 153 | yield from self.model.query_embed(query, **kwargs) 154 | -------------------------------------------------------------------------------- /fastembed/late_interaction/token_embeddings.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Union, Iterable, Optional, Any, Type 3 | 4 | from fastembed.common.model_description import DenseModelDescription, ModelSource 5 | from fastembed.common.onnx_model import OnnxOutputContext 6 | from fastembed.common.types import NumpyArray 7 | from fastembed.late_interaction.late_interaction_embedding_base import ( 8 | LateInteractionTextEmbeddingBase, 9 | ) 10 | from fastembed.text.onnx_embedding import OnnxTextEmbedding 11 | from fastembed.text.onnx_text_model import TextEmbeddingWorker 12 | 13 | 14 | supported_token_embeddings_models = [ 15 | DenseModelDescription( 16 | model="jinaai/jina-embeddings-v2-small-en-tokens", 17 | dim=512, 18 | description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation," 19 | " Prefixes for queries/documents: not necessary, 2023 year.", 20 | license="apache-2.0", 21 | size_in_GB=0.12, 22 | sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"), 23 | model_file="onnx/model.onnx", 24 | ), 25 | ] 26 | 27 | 28 | class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): 29 | @classmethod 30 | def _list_supported_models(cls) -> list[DenseModelDescription]: 31 | """Lists the supported models. 32 | 33 | Returns: 34 | list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 35 | """ 36 | return supported_token_embeddings_models 37 | 38 | @classmethod 39 | def list_supported_models(cls) -> list[dict[str, Any]]: 40 | """Lists the supported models. 41 | 42 | Returns: 43 | list[dict[str, Any]]: A list of dictionaries containing the model information. 
44 | """ 45 | return [asdict(model) for model in cls._list_supported_models()] 46 | 47 | @classmethod 48 | def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]: 49 | return TokensEmbeddingWorker 50 | 51 | def _post_process_onnx_output( 52 | self, output: OnnxOutputContext, **kwargs: Any 53 | ) -> Iterable[NumpyArray]: 54 | # Size: (batch_size, sequence_length, hidden_size) 55 | embeddings = output.model_output 56 | # Size: (batch_size, sequence_length) 57 | assert output.attention_mask is not None 58 | masks = output.attention_mask 59 | 60 | # For each document we only select those embeddings that are not masked out 61 | for i in range(embeddings.shape[0]): 62 | yield embeddings[i, masks[i] == 1] 63 | 64 | def embed( 65 | self, 66 | documents: Union[str, Iterable[str]], 67 | batch_size: int = 256, 68 | parallel: Optional[int] = None, 69 | **kwargs: Any, 70 | ) -> Iterable[NumpyArray]: 71 | yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs) 72 | 73 | 74 | class TokensEmbeddingWorker(TextEmbeddingWorker[NumpyArray]): 75 | def init_embedding( 76 | self, model_name: str, cache_dir: str, **kwargs: Any 77 | ) -> TokenEmbeddingsModel: 78 | return TokenEmbeddingsModel( 79 | model_name=model_name, 80 | cache_dir=cache_dir, 81 | threads=1, 82 | **kwargs, 83 | ) 84 | -------------------------------------------------------------------------------- /fastembed/late_interaction_multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import ( 2 | LateInteractionMultimodalEmbedding, 3 | ) 4 | 5 | __all__ = ["LateInteractionMultimodalEmbedding"] 6 | -------------------------------------------------------------------------------- /fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, Sequence, Type, Union 2 | from dataclasses import asdict 3 | 4 | from fastembed.common import OnnxProvider, ImageInput 5 | from fastembed.common.types import NumpyArray 6 | from fastembed.late_interaction_multimodal.colpali import ColPali 7 | 8 | from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import ( 9 | LateInteractionMultimodalEmbeddingBase, 10 | ) 11 | from fastembed.common.model_description import DenseModelDescription 12 | 13 | 14 | class LateInteractionMultimodalEmbedding(LateInteractionMultimodalEmbeddingBase): 15 | EMBEDDINGS_REGISTRY: list[Type[LateInteractionMultimodalEmbeddingBase]] = [ColPali] 16 | 17 | @classmethod 18 | def list_supported_models(cls) -> list[dict[str, Any]]: 19 | """ 20 | Lists the supported models. 21 | 22 | Returns: 23 | list[dict[str, Any]]: A list of dictionaries containing the model information. 
24 | 25 | Example: 26 | ``` 27 | [ 28 | { 29 | "model": "Qdrant/colpali-v1.3-fp16", 30 | "dim": 128, 31 | "description": "Text embeddings, Unimodal (text), Aligned to image latent space, ColBERT-compatible, 512 tokens max, 2024.", 32 | "license": "mit", 33 | "size_in_GB": 6.06, 34 | "sources": { 35 | "hf": "Qdrant/colpali-v1.3-fp16", 36 | }, 37 | "additional_files": [ 38 | "model.onnx_data", 39 | ], 40 | "model_file": "model.onnx", 41 | }, 42 | ] 43 | ``` 44 | """ 45 | return [asdict(model) for model in cls._list_supported_models()] 46 | 47 | @classmethod 48 | def _list_supported_models(cls) -> list[DenseModelDescription]: 49 | result: list[DenseModelDescription] = [] 50 | for embedding in cls.EMBEDDINGS_REGISTRY: 51 | result.extend(embedding._list_supported_models()) 52 | return result 53 | 54 | def __init__( 55 | self, 56 | model_name: str, 57 | cache_dir: Optional[str] = None, 58 | threads: Optional[int] = None, 59 | providers: Optional[Sequence[OnnxProvider]] = None, 60 | cuda: bool = False, 61 | device_ids: Optional[list[int]] = None, 62 | lazy_load: bool = False, 63 | **kwargs: Any, 64 | ): 65 | super().__init__(model_name, cache_dir, threads, **kwargs) 66 | for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: 67 | supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() 68 | if any(model_name.lower() == model.model.lower() for model in supported_models): 69 | self.model = EMBEDDING_MODEL_TYPE( 70 | model_name, 71 | cache_dir, 72 | threads=threads, 73 | providers=providers, 74 | cuda=cuda, 75 | device_ids=device_ids, 76 | lazy_load=lazy_load, 77 | **kwargs, 78 | ) 79 | return 80 | 81 | raise ValueError( 82 | f"Model {model_name} is not supported in LateInteractionMultimodalEmbedding." 83 | "Please check the supported models using `LateInteractionMultimodalEmbedding.list_supported_models()`" 84 | ) 85 | 86 | @property 87 | def embedding_size(self) -> int: 88 | """Get the embedding size of the current model""" 89 | if self._embedding_size is None: 90 | self._embedding_size = self.get_embedding_size(self.model_name) 91 | return self._embedding_size 92 | 93 | @classmethod 94 | def get_embedding_size(cls, model_name: str) -> int: 95 | """Get the embedding size of the passed model 96 | 97 | Args: 98 | model_name (str): The name of the model to get embedding size for. 99 | 100 | Returns: 101 | int: The size of the embedding. 102 | 103 | Raises: 104 | ValueError: If the model name is not found in the supported models. 105 | """ 106 | descriptions = cls._list_supported_models() 107 | embedding_size: Optional[int] = None 108 | for description in descriptions: 109 | if description.model.lower() == model_name.lower(): 110 | embedding_size = description.dim 111 | break 112 | if embedding_size is None: 113 | model_names = [description.model for description in descriptions] 114 | raise ValueError( 115 | f"Embedding size for model {model_name} was None. " 116 | f"Available model names: {model_names}" 117 | ) 118 | return embedding_size 119 | 120 | def embed_text( 121 | self, 122 | documents: Union[str, Iterable[str]], 123 | batch_size: int = 256, 124 | parallel: Optional[int] = None, 125 | **kwargs: Any, 126 | ) -> Iterable[NumpyArray]: 127 | """ 128 | Encode a list of documents into list of embeddings. 
129 | 130 | Args: 131 | documents: Iterator of documents or single document to embed 132 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 133 | parallel: 134 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 135 | If 0, use all available cores. 136 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 137 | 138 | Returns: 139 | List of embeddings, one per document 140 | """ 141 | yield from self.model.embed_text(documents, batch_size, parallel, **kwargs) 142 | 143 | def embed_image( 144 | self, 145 | images: Union[ImageInput, Iterable[ImageInput]], 146 | batch_size: int = 16, 147 | parallel: Optional[int] = None, 148 | **kwargs: Any, 149 | ) -> Iterable[NumpyArray]: 150 | """ 151 | Encode a list of images into list of embeddings. 152 | 153 | Args: 154 | images: Iterator of image paths or single image path to embed 155 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 156 | parallel: 157 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 158 | If 0, use all available cores. 159 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 160 | 161 | Returns: 162 | List of embeddings, one per image 163 | """ 164 | yield from self.model.embed_image(images, batch_size, parallel, **kwargs) 165 | -------------------------------------------------------------------------------- /fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional, Union, Any 2 | 3 | 4 | from fastembed.common import ImageInput 5 | from fastembed.common.model_description import DenseModelDescription 6 | from fastembed.common.model_management import ModelManagement 7 | from fastembed.common.types import NumpyArray 8 | 9 | 10 | class LateInteractionMultimodalEmbeddingBase(ModelManagement[DenseModelDescription]): 11 | def __init__( 12 | self, 13 | model_name: str, 14 | cache_dir: Optional[str] = None, 15 | threads: Optional[int] = None, 16 | **kwargs: Any, 17 | ): 18 | self.model_name = model_name 19 | self.cache_dir = cache_dir 20 | self.threads = threads 21 | self._local_files_only = kwargs.pop("local_files_only", False) 22 | self._embedding_size: Optional[int] = None 23 | 24 | def embed_text( 25 | self, 26 | documents: Union[str, Iterable[str]], 27 | batch_size: int = 256, 28 | parallel: Optional[int] = None, 29 | **kwargs: Any, 30 | ) -> Iterable[NumpyArray]: 31 | """ 32 | Embeds a list of documents into a list of embeddings. 33 | 34 | Args: 35 | documents (Iterable[str]): The list of texts to embed. 36 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 37 | parallel: 38 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 39 | If 0, use all available cores. 40 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 41 | **kwargs: Additional keyword argument to pass to the embed method. 42 | 43 | Yields: 44 | Iterable[NumpyArray]: The embeddings. 
45 | """ 46 | raise NotImplementedError() 47 | 48 | def embed_image( 49 | self, 50 | images: Union[ImageInput, Iterable[ImageInput]], 51 | batch_size: int = 16, 52 | parallel: Optional[int] = None, 53 | **kwargs: Any, 54 | ) -> Iterable[NumpyArray]: 55 | """ 56 | Encode a list of images into list of embeddings. 57 | Args: 58 | images: Iterator of image paths or single image path to embed 59 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 60 | parallel: 61 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 62 | If 0, use all available cores. 63 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 64 | 65 | Returns: 66 | List of embeddings, one per image 67 | """ 68 | raise NotImplementedError() 69 | 70 | @classmethod 71 | def get_embedding_size(cls, model_name: str) -> int: 72 | """Returns embedding size of the chosen model.""" 73 | raise NotImplementedError("Subclasses must implement this method") 74 | 75 | @property 76 | def embedding_size(self) -> int: 77 | """Returns embedding size for the current model""" 78 | raise NotImplementedError("Subclasses must implement this method") 79 | -------------------------------------------------------------------------------- /fastembed/py.typed: -------------------------------------------------------------------------------- 1 | partial -------------------------------------------------------------------------------- /fastembed/rerank/cross_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.rerank.cross_encoder.text_cross_encoder import TextCrossEncoder 2 | 3 | __all__ = ["TextCrossEncoder"] 4 | -------------------------------------------------------------------------------- /fastembed/rerank/cross_encoder/custom_text_cross_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Any 2 | 3 | from fastembed.common import OnnxProvider 4 | from fastembed.common.model_description import BaseModelDescription 5 | from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder 6 | 7 | 8 | class CustomTextCrossEncoder(OnnxTextCrossEncoder): 9 | SUPPORTED_MODELS: list[BaseModelDescription] = [] 10 | 11 | def __init__( 12 | self, 13 | model_name: str, 14 | cache_dir: Optional[str] = None, 15 | threads: Optional[int] = None, 16 | providers: Optional[Sequence[OnnxProvider]] = None, 17 | cuda: bool = False, 18 | device_ids: Optional[list[int]] = None, 19 | lazy_load: bool = False, 20 | device_id: Optional[int] = None, 21 | specific_model_path: Optional[str] = None, 22 | **kwargs: Any, 23 | ): 24 | super().__init__( 25 | model_name=model_name, 26 | cache_dir=cache_dir, 27 | threads=threads, 28 | providers=providers, 29 | cuda=cuda, 30 | device_ids=device_ids, 31 | lazy_load=lazy_load, 32 | device_id=device_id, 33 | specific_model_path=specific_model_path, 34 | **kwargs, 35 | ) 36 | 37 | @classmethod 38 | def _list_supported_models(cls) -> list[BaseModelDescription]: 39 | return cls.SUPPORTED_MODELS 40 | 41 | @classmethod 42 | def add_model( 43 | cls, 44 | model_description: BaseModelDescription, 45 | ) -> None: 46 | cls.SUPPORTED_MODELS.append(model_description) 47 | -------------------------------------------------------------------------------- /fastembed/rerank/cross_encoder/onnx_text_model.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import get_all_start_methods 3 | from pathlib import Path 4 | from typing import Any, Iterable, Optional, Sequence, Type 5 | 6 | import numpy as np 7 | from tokenizers import Encoding 8 | 9 | from fastembed.common.onnx_model import ( 10 | EmbeddingWorker, 11 | OnnxModel, 12 | OnnxOutputContext, 13 | OnnxProvider, 14 | ) 15 | from fastembed.common.types import NumpyArray 16 | from fastembed.common.preprocessor_utils import load_tokenizer 17 | from fastembed.common.utils import iter_batch 18 | from fastembed.parallel_processor import ParallelWorkerPool 19 | 20 | 21 | class OnnxCrossEncoderModel(OnnxModel[float]): 22 | ONNX_OUTPUT_NAMES: Optional[list[str]] = None 23 | 24 | @classmethod 25 | def _get_worker_class(cls) -> Type["TextRerankerWorker"]: 26 | raise NotImplementedError("Subclasses must implement this method") 27 | 28 | def _load_onnx_model( 29 | self, 30 | model_dir: Path, 31 | model_file: str, 32 | threads: Optional[int], 33 | providers: Optional[Sequence[OnnxProvider]] = None, 34 | cuda: bool = False, 35 | device_id: Optional[int] = None, 36 | ) -> None: 37 | super()._load_onnx_model( 38 | model_dir=model_dir, 39 | model_file=model_file, 40 | threads=threads, 41 | providers=providers, 42 | cuda=cuda, 43 | device_id=device_id, 44 | ) 45 | self.tokenizer, _ = load_tokenizer(model_dir=model_dir) 46 | assert self.tokenizer is not None 47 | 48 | def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]: 49 | return self.tokenizer.encode_batch(pairs) # type: ignore[union-attr] 50 | 51 | def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]: 52 | input_names: set[str] = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] 53 | inputs: dict[str, NumpyArray] = { 54 | "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64), 55 | } 56 | if "token_type_ids" in input_names: 57 | inputs["token_type_ids"] = np.array( 58 | [enc.type_ids for enc in tokenized_input], dtype=np.int64 59 | ) 60 | if "attention_mask" in input_names: 61 | inputs["attention_mask"] = np.array( 62 | [enc.attention_mask for enc in tokenized_input], dtype=np.int64 63 | ) 64 | return inputs 65 | 66 | def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext: 67 | pairs = [(query, doc) for doc in documents] 68 | return self.onnx_embed_pairs(pairs, **kwargs) 69 | 70 | def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext: 71 | tokenized_input = self.tokenize(pairs, **kwargs) 72 | inputs = self._build_onnx_input(tokenized_input) 73 | onnx_input = self._preprocess_onnx_input(inputs, **kwargs) 74 | outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] 75 | relevant_output = outputs[0] 76 | scores: NumpyArray = relevant_output[:, 0] 77 | return OnnxOutputContext(model_output=scores) 78 | 79 | def _rerank_documents( 80 | self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any 81 | ) -> Iterable[float]: 82 | if not hasattr(self, "model") or self.model is None: 83 | self.load_onnx_model() 84 | for batch in iter_batch(documents, batch_size): 85 | yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs)) 86 | 87 | def _rerank_pairs( 88 | self, 89 | model_name: str, 90 | cache_dir: str, 91 | pairs: Iterable[tuple[str, str]], 92 | batch_size: int, 93 | parallel: Optional[int] = 
None, 94 | providers: Optional[Sequence[OnnxProvider]] = None, 95 | cuda: bool = False, 96 | device_ids: Optional[list[int]] = None, 97 | local_files_only: bool = False, 98 | specific_model_path: Optional[str] = None, 99 | **kwargs: Any, 100 | ) -> Iterable[float]: 101 | is_small = False 102 | 103 | if isinstance(pairs, tuple): 104 | pairs = [pairs] 105 | is_small = True 106 | 107 | if isinstance(pairs, list): 108 | if len(pairs) < batch_size: 109 | is_small = True 110 | 111 | if parallel is None or is_small: 112 | if not hasattr(self, "model") or self.model is None: 113 | self.load_onnx_model() 114 | for batch in iter_batch(pairs, batch_size): 115 | yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs)) 116 | else: 117 | if parallel == 0: 118 | parallel = os.cpu_count() 119 | 120 | start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" 121 | params = { 122 | "model_name": model_name, 123 | "cache_dir": cache_dir, 124 | "providers": providers, 125 | "local_files_only": local_files_only, 126 | "specific_model_path": specific_model_path, 127 | **kwargs, 128 | } 129 | 130 | pool = ParallelWorkerPool( 131 | num_workers=parallel or 1, 132 | worker=self._get_worker_class(), 133 | cuda=cuda, 134 | device_ids=device_ids, 135 | start_method=start_method, 136 | ) 137 | for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params): 138 | yield from self._post_process_onnx_output(batch) # type: ignore 139 | 140 | def _post_process_onnx_output( 141 | self, output: OnnxOutputContext, **kwargs: Any 142 | ) -> Iterable[float]: 143 | """Post-process the ONNX model output to convert it into a usable format. 144 | 145 | Args: 146 | output (OnnxOutputContext): The raw output from the ONNX model. 147 | **kwargs: Additional keyword arguments that may be needed by specific implementations. 148 | 149 | Returns: 150 | Iterable[float]: Post-processed output as an iterable of float values. 151 | """ 152 | raise NotImplementedError("Subclasses must implement this method") 153 | 154 | def _preprocess_onnx_input( 155 | self, onnx_input: dict[str, NumpyArray], **kwargs: Any 156 | ) -> dict[str, NumpyArray]: 157 | """ 158 | Preprocess the onnx input. 
159 | """ 160 | return onnx_input 161 | 162 | 163 | class TextRerankerWorker(EmbeddingWorker[float]): 164 | def __init__( 165 | self, 166 | model_name: str, 167 | cache_dir: str, 168 | **kwargs: Any, 169 | ): 170 | self.model: OnnxCrossEncoderModel 171 | super().__init__(model_name, cache_dir, **kwargs) 172 | 173 | def init_embedding( 174 | self, 175 | model_name: str, 176 | cache_dir: str, 177 | **kwargs: Any, 178 | ) -> OnnxCrossEncoderModel: 179 | raise NotImplementedError() 180 | 181 | def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]: 182 | for idx, batch in items: 183 | onnx_output = self.model.onnx_embed_pairs(batch) 184 | yield idx, onnx_output 185 | -------------------------------------------------------------------------------- /fastembed/rerank/cross_encoder/text_cross_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, Sequence, Type 2 | from dataclasses import asdict 3 | 4 | from fastembed.common import OnnxProvider 5 | from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder 6 | from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder 7 | 8 | from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase 9 | from fastembed.common.model_description import ( 10 | ModelSource, 11 | BaseModelDescription, 12 | ) 13 | 14 | 15 | class TextCrossEncoder(TextCrossEncoderBase): 16 | CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [ 17 | OnnxTextCrossEncoder, 18 | CustomTextCrossEncoder, 19 | ] 20 | 21 | @classmethod 22 | def list_supported_models(cls) -> list[dict[str, Any]]: 23 | """Lists the supported models. 24 | 25 | Returns: 26 | list[BaseModelDescription]: A list of dictionaries containing the model information. 27 | 28 | Example: 29 | ``` 30 | [ 31 | { 32 | "model": "Xenova/ms-marco-MiniLM-L-6-v2", 33 | "size_in_GB": 0.08, 34 | "sources": { 35 | "hf": "Xenova/ms-marco-MiniLM-L-6-v2", 36 | }, 37 | "model_file": "onnx/model.onnx", 38 | "description": "MiniLM-L-6-v2 model optimized for re-ranking tasks.", 39 | "license": "apache-2.0", 40 | } 41 | ] 42 | ``` 43 | """ 44 | return [asdict(model) for model in cls._list_supported_models()] 45 | 46 | @classmethod 47 | def _list_supported_models(cls) -> list[BaseModelDescription]: 48 | result: list[BaseModelDescription] = [] 49 | for encoder in cls.CROSS_ENCODER_REGISTRY: 50 | result.extend(encoder._list_supported_models()) 51 | return result 52 | 53 | def __init__( 54 | self, 55 | model_name: str, 56 | cache_dir: Optional[str] = None, 57 | threads: Optional[int] = None, 58 | providers: Optional[Sequence[OnnxProvider]] = None, 59 | cuda: bool = False, 60 | device_ids: Optional[list[int]] = None, 61 | lazy_load: bool = False, 62 | **kwargs: Any, 63 | ): 64 | super().__init__(model_name, cache_dir, threads, **kwargs) 65 | 66 | for CROSS_ENCODER_TYPE in self.CROSS_ENCODER_REGISTRY: 67 | supported_models = CROSS_ENCODER_TYPE._list_supported_models() 68 | if any(model_name.lower() == model.model.lower() for model in supported_models): 69 | self.model = CROSS_ENCODER_TYPE( 70 | model_name=model_name, 71 | cache_dir=cache_dir, 72 | threads=threads, 73 | providers=providers, 74 | cuda=cuda, 75 | device_ids=device_ids, 76 | lazy_load=lazy_load, 77 | **kwargs, 78 | ) 79 | return 80 | 81 | raise ValueError( 82 | f"Model {model_name} is not supported in TextCrossEncoder." 
83 | "Please check the supported models using `TextCrossEncoder.list_supported_models()`" 84 | ) 85 | 86 | def rerank( 87 | self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any 88 | ) -> Iterable[float]: 89 | """Rerank a list of documents based on a query. 90 | 91 | Args: 92 | query: Query to rerank the documents against 93 | documents: Iterator of documents to rerank 94 | batch_size: Batch size for reranking 95 | 96 | Returns: 97 | Iterable of scores for each document 98 | """ 99 | yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs) 100 | 101 | def rerank_pairs( 102 | self, 103 | pairs: Iterable[tuple[str, str]], 104 | batch_size: int = 64, 105 | parallel: Optional[int] = None, 106 | **kwargs: Any, 107 | ) -> Iterable[float]: 108 | """ 109 | Rerank a list of query-document pairs. 110 | 111 | Args: 112 | pairs (Iterable[tuple[str, str]]): An iterable of tuples, where each tuple contains a query and a document 113 | to be scored together. 114 | batch_size (int, optional): The number of query-document pairs to process in a single batch. Defaults to 64. 115 | parallel (Optional[int], optional): The number of parallel processes to use for reranking. 116 | If None, parallelization is disabled. Defaults to None. 117 | **kwargs (Any): Additional arguments to pass to the underlying reranking model. 118 | 119 | Returns: 120 | Iterable[float]: An iterable of scores corresponding to each query-document pair in the input. 121 | Higher scores indicate a stronger match between the query and the document. 122 | 123 | Example: 124 | >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2") 125 | >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")] 126 | >>> scores = list(encoder.rerank_pairs(pairs)) 127 | >>> print(list(map(lambda x: round(x, 2), scores))) 128 | [-1.24, -10.6] 129 | """ 130 | yield from self.model.rerank_pairs( 131 | pairs, batch_size=batch_size, parallel=parallel, **kwargs 132 | ) 133 | 134 | @classmethod 135 | def add_custom_model( 136 | cls, 137 | model: str, 138 | sources: ModelSource, 139 | model_file: str = "onnx/model.onnx", 140 | description: str = "", 141 | license: str = "", 142 | size_in_gb: float = 0.0, 143 | additional_files: Optional[list[str]] = None, 144 | ) -> None: 145 | registered_models = cls._list_supported_models() 146 | for registered_model in registered_models: 147 | if model == registered_model.model: 148 | raise ValueError( 149 | f"Model {model} is already registered in CrossEncoderModel, if you still want to add this model, " 150 | f"please use another model name" 151 | ) 152 | 153 | CustomTextCrossEncoder.add_model( 154 | BaseModelDescription( 155 | model=model, 156 | sources=sources, 157 | model_file=model_file, 158 | description=description, 159 | license=license, 160 | size_in_GB=size_in_gb, 161 | additional_files=additional_files or [], 162 | ) 163 | ) 164 | -------------------------------------------------------------------------------- /fastembed/rerank/cross_encoder/text_cross_encoder_base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional 2 | 3 | from fastembed.common.model_description import BaseModelDescription 4 | from fastembed.common.model_management import ModelManagement 5 | 6 | 7 | class TextCrossEncoderBase(ModelManagement[BaseModelDescription]): 8 | def __init__( 9 | self, 10 | model_name: str, 11 | cache_dir: Optional[str] = None, 12 | threads: 
Optional[int] = None, 13 | **kwargs: Any, 14 | ): 15 | self.model_name = model_name 16 | self.cache_dir = cache_dir 17 | self.threads = threads 18 | self._local_files_only = kwargs.pop("local_files_only", False) 19 | 20 | def rerank( 21 | self, 22 | query: str, 23 | documents: Iterable[str], 24 | batch_size: int = 64, 25 | **kwargs: Any, 26 | ) -> Iterable[float]: 27 | """Rerank a list of documents given a query. 28 | 29 | Args: 30 | query (str): The query to rerank the documents against. 31 | documents (Iterable[str]): The list of texts to rerank. 32 | batch_size (int): The batch size to use for reranking. 33 | **kwargs: Additional keyword arguments to pass to the rerank method. 34 | 35 | Yields: 36 | Iterable[float]: The scores of the reranked documents. 37 | """ 38 | raise NotImplementedError("This method should be overridden by subclasses") 39 | 40 | def rerank_pairs( 41 | self, 42 | pairs: Iterable[tuple[str, str]], 43 | batch_size: int = 64, 44 | parallel: Optional[int] = None, 45 | **kwargs: Any, 46 | ) -> Iterable[float]: 47 | """Rerank query-document pairs. 48 | Args: 49 | pairs (Iterable[tuple[str, str]]): Query-document pairs to rerank. 50 | batch_size (int): The batch size to use for reranking. 51 | parallel: 52 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 53 | If 0, use all available cores. 54 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 55 | **kwargs: Additional keyword arguments to pass to the rerank method. 56 | Yields: 57 | Iterable[float]: Scores for each individual pair 58 | """ 59 | raise NotImplementedError("This method should be overridden by subclasses") 60 | -------------------------------------------------------------------------------- /fastembed/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.sparse.sparse_embedding_base import SparseEmbedding 2 | from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding 3 | 4 | __all__ = ["SparseEmbedding", "SparseTextEmbedding"] 5 | -------------------------------------------------------------------------------- /fastembed/sparse/sparse_embedding_base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable, Optional, Union, Any 3 | 4 | import numpy as np 5 | from numpy.typing import NDArray 6 | 7 | from fastembed.common.model_description import SparseModelDescription 8 | from fastembed.common.types import NumpyArray 9 | from fastembed.common.model_management import ModelManagement 10 | 11 | 12 | @dataclass 13 | class SparseEmbedding: 14 | values: NumpyArray 15 | indices: Union[NDArray[np.int64], NDArray[np.int32]] 16 | 17 | def as_object(self) -> dict[str, NumpyArray]: 18 | return { 19 | "values": self.values, 20 | "indices": self.indices, 21 | } 22 | 23 | def as_dict(self) -> dict[int, float]: 24 | return {int(i): float(v) for i, v in zip(self.indices, self.values)} # type: ignore 25 | 26 | @classmethod 27 | def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding": 28 | if len(data) == 0: 29 | return cls(values=np.array([]), indices=np.array([])) 30 | indices, values = zip(*data.items()) 31 | return cls(values=np.array(values), indices=np.array(indices)) 32 | 33 | 34 | class SparseTextEmbeddingBase(ModelManagement[SparseModelDescription]): 35 | def __init__( 36 | self, 37 | model_name: str, 38 | cache_dir: Optional[str] = None, 39 |
threads: Optional[int] = None, 40 | **kwargs: Any, 41 | ): 42 | self.model_name = model_name 43 | self.cache_dir = cache_dir 44 | self.threads = threads 45 | self._local_files_only = kwargs.pop("local_files_only", False) 46 | 47 | def embed( 48 | self, 49 | documents: Union[str, Iterable[str]], 50 | batch_size: int = 256, 51 | parallel: Optional[int] = None, 52 | **kwargs: Any, 53 | ) -> Iterable[SparseEmbedding]: 54 | raise NotImplementedError() 55 | 56 | def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]: 57 | """ 58 | Embeds a list of text passages into a list of embeddings. 59 | 60 | Args: 61 | texts (Iterable[str]): The list of texts to embed. 62 | **kwargs: Additional keyword argument to pass to the embed method. 63 | 64 | Yields: 65 | Iterable[SparseEmbedding]: The sparse embeddings. 66 | """ 67 | 68 | # This is model-specific, so that different models can have specialized implementations 69 | yield from self.embed(texts, **kwargs) 70 | 71 | def query_embed( 72 | self, query: Union[str, Iterable[str]], **kwargs: Any 73 | ) -> Iterable[SparseEmbedding]: 74 | """ 75 | Embeds queries 76 | 77 | Args: 78 | query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. 79 | 80 | Returns: 81 | Iterable[SparseEmbedding]: The sparse embeddings. 82 | """ 83 | 84 | # This is model-specific, so that different models can have specialized implementations 85 | if isinstance(query, str): 86 | yield from self.embed([query], **kwargs) 87 | else: 88 | yield from self.embed(query, **kwargs) 89 | -------------------------------------------------------------------------------- /fastembed/sparse/sparse_text_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Optional, Sequence, Type, Union 2 | from dataclasses import asdict 3 | 4 | from fastembed.common import OnnxProvider 5 | from fastembed.sparse.bm25 import Bm25 6 | from fastembed.sparse.bm42 import Bm42 7 | from fastembed.sparse.minicoil import MiniCOIL 8 | from fastembed.sparse.sparse_embedding_base import ( 9 | SparseEmbedding, 10 | SparseTextEmbeddingBase, 11 | ) 12 | from fastembed.sparse.splade_pp import SpladePP 13 | import warnings 14 | from fastembed.common.model_description import SparseModelDescription 15 | 16 | 17 | class SparseTextEmbedding(SparseTextEmbeddingBase): 18 | EMBEDDINGS_REGISTRY: list[Type[SparseTextEmbeddingBase]] = [SpladePP, Bm42, Bm25, MiniCOIL] 19 | 20 | @classmethod 21 | def list_supported_models(cls) -> list[dict[str, Any]]: 22 | """ 23 | Lists the supported models. 24 | 25 | Returns: 26 | list[dict[str, Any]]: A list of dictionaries containing the model information. 
27 | 28 | Example: 29 | ``` 30 | [ 31 | { 32 | "model": "prithvida/SPLADE_PP_en_v1", 33 | "vocab_size": 30522, 34 | "description": "Independent Implementation of SPLADE++ Model for English", 35 | "license": "apache-2.0", 36 | "size_in_GB": 0.532, 37 | "sources": { 38 | "hf": "qdrant/SPLADE_PP_en_v1", 39 | }, 40 | } 41 | ] 42 | ``` 43 | """ 44 | return [asdict(model) for model in cls._list_supported_models()] 45 | 46 | @classmethod 47 | def _list_supported_models(cls) -> list[SparseModelDescription]: 48 | result: list[SparseModelDescription] = [] 49 | for embedding in cls.EMBEDDINGS_REGISTRY: 50 | result.extend(embedding._list_supported_models()) 51 | return result 52 | 53 | def __init__( 54 | self, 55 | model_name: str, 56 | cache_dir: Optional[str] = None, 57 | threads: Optional[int] = None, 58 | providers: Optional[Sequence[OnnxProvider]] = None, 59 | cuda: bool = False, 60 | device_ids: Optional[list[int]] = None, 61 | lazy_load: bool = False, 62 | **kwargs: Any, 63 | ): 64 | super().__init__(model_name, cache_dir, threads, **kwargs) 65 | if model_name.lower() == "prithvida/Splade_PP_en_v1".lower(): 66 | warnings.warn( 67 | "The right spelling is prithivida/Splade_PP_en_v1. " 68 | "Support of this name will be removed soon, please fix the model_name", 69 | DeprecationWarning, 70 | stacklevel=2, 71 | ) 72 | model_name = "prithivida/Splade_PP_en_v1" 73 | 74 | for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: 75 | supported_models = EMBEDDING_MODEL_TYPE._list_supported_models() 76 | if any(model_name.lower() == model.model.lower() for model in supported_models): 77 | self.model = EMBEDDING_MODEL_TYPE( 78 | model_name, 79 | cache_dir, 80 | threads=threads, 81 | providers=providers, 82 | cuda=cuda, 83 | device_ids=device_ids, 84 | lazy_load=lazy_load, 85 | **kwargs, 86 | ) 87 | return 88 | 89 | raise ValueError( 90 | f"Model {model_name} is not supported in SparseTextEmbedding. " 91 | "Please check the supported models using `SparseTextEmbedding.list_supported_models()`" 92 | ) 93 | 94 | def embed( 95 | self, 96 | documents: Union[str, Iterable[str]], 97 | batch_size: int = 256, 98 | parallel: Optional[int] = None, 99 | **kwargs: Any, 100 | ) -> Iterable[SparseEmbedding]: 101 | """ 102 | Encode a list of documents into a list of sparse embeddings. 103 | The exact encoding depends on the chosen sparse model (SPLADE++, BM42, BM25, or miniCOIL). 104 | 105 | Args: 106 | documents: Iterator of documents or single document to embed 107 | batch_size: Batch size for encoding -- higher values will use more memory, but be faster 108 | parallel: 109 | If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. 110 | If 0, use all available cores. 111 | If None, don't use data-parallel processing, use default onnxruntime threading instead. 112 | 113 | Returns: 114 | List of embeddings, one per document 115 | """ 116 | yield from self.model.embed(documents, batch_size, parallel, **kwargs) 117 | 118 | def query_embed( 119 | self, query: Union[str, Iterable[str]], **kwargs: Any 120 | ) -> Iterable[SparseEmbedding]: 121 | """ 122 | Embeds queries. 123 | 124 | Args: 125 | query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. 126 | 127 | Returns: 128 | Iterable[SparseEmbedding]: The sparse embeddings.
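Example (illustrative sketch only; "prithivida/Splade_PP_en_v1" is one of the registered models and is downloaded on first use, so the resulting sparse vectors depend on the chosen model):
    >>> from fastembed.sparse import SparseTextEmbedding
    >>> model = SparseTextEmbedding("prithivida/Splade_PP_en_v1")
    >>> sparse_vectors = list(model.query_embed("What is vector search?"))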
129 | """ 130 | yield from self.model.query_embed(query, **kwargs) 131 | -------------------------------------------------------------------------------- /fastembed/sparse/utils/minicoil_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pure numpy implementation of encoder model for a single word. 3 | 4 | This model is not trainable, and should only be used for inference. 5 | """ 6 | 7 | import numpy as np 8 | from fastembed.common.types import NumpyArray 9 | 10 | 11 | class Encoder: 12 | """ 13 | Encoder(768, 4, 10000) 14 | 15 | Will look like this: 16 | 17 | 18 | Per-word 19 | Encoder Matrix 20 | ┌─────────────────────┐ 21 | │ Token Embedding(768)├──────┐ (10k, 768, 4) 22 | └─────────────────────┘ │ ┌─────────┐ 23 | │ │ │ 24 | ┌─────────────────────┐ │ ┌─┴───────┐ │ 25 | │ │ │ │ │ │ 26 | └─────────────────────┘ │ ┌─┴───────┐ │ │ ┌─────────┐ 27 | └────►│ │ │ ├─────►│Tanh │ 28 | ┌─────────────────────┐ │ │ │ │ └─────────┘ 29 | │ │ │ │ ├─┘ 30 | └─────────────────────┘ │ ├─┘ 31 | │ │ 32 | ┌─────────────────────┐ └─────────┘ 33 | │ │ 34 | └─────────────────────┘ 35 | 36 | Final linear transformation is accompanied by a non-linear activation function: Tanh. 37 | 38 | Tanh is used to ensure that the output is in the range [-1, 1]. 39 | It would be easier to visually interpret the output of the model, assuming that each dimension 40 | would need to encode a type of semantic cluster. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | weights: NumpyArray, 46 | ): 47 | self.weights = weights 48 | self.vocab_size, self.input_dim, self.output_dim = weights.shape 49 | 50 | self.encoder_weights: NumpyArray = weights 51 | 52 | # Activation function 53 | self.activation = np.tanh 54 | 55 | @staticmethod 56 | def convert_vocab_ids(vocab_ids: NumpyArray) -> NumpyArray: 57 | """ 58 | Convert vocab_ids of shape (batch_size, seq_len) into (batch_size, seq_len, 2) 59 | by appending batch_id alongside each vocab_id. 
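Example (illustrative): vocab_ids [[5, 7]] (batch_size=1, seq_len=2) becomes [[[5, 0], [7, 0]]], i.e. each vocab_id is paired with the index of the batch row it came from.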
60 | """ 61 | batch_size, seq_len = vocab_ids.shape 62 | batch_ids = np.arange(batch_size, dtype=vocab_ids.dtype).reshape(batch_size, 1) 63 | batch_ids = np.repeat(batch_ids, seq_len, axis=1) 64 | # Stack vocab_ids and batch_ids along the last dimension 65 | combined: NumpyArray = np.stack((vocab_ids, batch_ids), axis=2).astype(np.int32) 66 | return combined 67 | 68 | @classmethod 69 | def avg_by_vocab_ids( 70 | cls, vocab_ids: NumpyArray, embeddings: NumpyArray 71 | ) -> tuple[NumpyArray, NumpyArray]: 72 | """ 73 | Takes: 74 | vocab_ids: (batch_size, seq_len) int array 75 | embeddings: (batch_size, seq_len, input_dim) float array 76 | 77 | Returns: 78 | unique_flattened_vocab_ids: (total_unique, 2) array of [vocab_id, batch_id] 79 | unique_flattened_embeddings: (total_unique, input_dim) averaged embeddings 80 | """ 81 | input_dim = embeddings.shape[2] 82 | 83 | # Flatten vocab_ids and embeddings 84 | # flattened_vocab_ids: (batch_size*seq_len, 2) 85 | flattened_vocab_ids = cls.convert_vocab_ids(vocab_ids).reshape(-1, 2) 86 | 87 | # flattened_embeddings: (batch_size*seq_len, input_dim) 88 | flattened_embeddings = embeddings.reshape(-1, input_dim) 89 | 90 | # Find unique (vocab_id, batch_id) pairs 91 | unique_flattened_vocab_ids, inverse_indices = np.unique( 92 | flattened_vocab_ids, axis=0, return_inverse=True 93 | ) 94 | 95 | # Prepare arrays to accumulate sums 96 | unique_count = unique_flattened_vocab_ids.shape[0] 97 | unique_flattened_embeddings = np.zeros((unique_count, input_dim), dtype=np.float32) 98 | unique_flattened_count = np.zeros(unique_count, dtype=np.int32) 99 | 100 | # Use np.add.at to accumulate sums based on inverse indices 101 | np.add.at(unique_flattened_embeddings, inverse_indices, flattened_embeddings) 102 | np.add.at(unique_flattened_count, inverse_indices, 1) 103 | 104 | # Compute averages 105 | unique_flattened_embeddings /= unique_flattened_count[:, None] 106 | 107 | return unique_flattened_vocab_ids.astype(np.int32), unique_flattened_embeddings.astype( 108 | np.float32 109 | ) 110 | 111 | def forward( 112 | self, vocab_ids: NumpyArray, embeddings: NumpyArray 113 | ) -> tuple[NumpyArray, NumpyArray]: 114 | """ 115 | Args: 116 | vocab_ids: (batch_size, seq_len) int array 117 | embeddings: (batch_size, seq_len, input_dim) float array 118 | 119 | Returns: 120 | unique_flattened_vocab_ids_and_batch_ids: (total_unique, 2) 121 | unique_flattened_encoded: (total_unique, output_dim) 122 | """ 123 | # Average embeddings for duplicate vocab_ids 124 | unique_flattened_vocab_ids_and_batch_ids, unique_flattened_embeddings = ( 125 | self.avg_by_vocab_ids(vocab_ids, embeddings) 126 | ) 127 | 128 | # Select the encoder weights for each unique vocab_id 129 | unique_flattened_vocab_ids = unique_flattened_vocab_ids_and_batch_ids[:, 0].astype( 130 | np.int32 131 | ) 132 | 133 | # unique_encoder_weights: (total_unique, input_dim, output_dim) 134 | unique_encoder_weights = self.encoder_weights[unique_flattened_vocab_ids] 135 | 136 | # Compute linear transform: (total_unique, output_dim) 137 | # Using Einstein summation for matrix multiplication: 138 | # 'bi,bio->bo' means: for each "b" (batch element), multiply embeddings (b,i) by weights (b,i,o) -> (b,o) 139 | unique_flattened_encoded = np.einsum( 140 | "bi,bio->bo", unique_flattened_embeddings, unique_encoder_weights 141 | ) 142 | 143 | # Apply Tanh activation and ensure float32 type 144 | unique_flattened_encoded = self.activation(unique_flattened_encoded).astype(np.float32) 145 | 146 | return 
unique_flattened_vocab_ids_and_batch_ids.astype(np.int32), unique_flattened_encoded 147 | -------------------------------------------------------------------------------- /fastembed/sparse/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # This code is a modified copy of the `NLTKWordTokenizer` class from `NLTK` library. 2 | 3 | import re 4 | 5 | 6 | class SimpleTokenizer: 7 | @staticmethod 8 | def tokenize(text: str) -> list[str]: 9 | text = re.sub(r"[^\w]", " ", text.lower()) 10 | text = re.sub(r"\s+", " ", text) 11 | 12 | return text.strip().split() 13 | 14 | 15 | class WordTokenizer: 16 | """The tokenizer is "destructive" such that the regexes applied will munge the 17 | input string to a state beyond re-construction. 18 | """ 19 | 20 | # Starting quotes. 21 | STARTING_QUOTES = [ 22 | (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "), 23 | (re.compile(r"^\""), r"``"), 24 | (re.compile(r"(``)"), r" \1 "), 25 | (re.compile(r"([ \(\[{<])(\"|\'{2})"), r"\1 `` "), 26 | (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"), 27 | ] 28 | 29 | # Ending quotes. 30 | ENDING_QUOTES = [ 31 | (re.compile("([»”’])", re.U), r" \1 "), 32 | (re.compile(r"''"), " '' "), 33 | (re.compile(r'"'), " '' "), 34 | (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "), 35 | (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "), 36 | ] 37 | 38 | # Punctuation. 39 | PUNCTUATION = [ 40 | (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "), 41 | (re.compile(r"([:,])([^\d])"), r" \1 \2"), 42 | (re.compile(r"([:,])$"), r" \1 "), 43 | ( 44 | re.compile(r"\.{2,}", re.U), 45 | r" \g<0> ", 46 | ), 47 | (re.compile(r"[;@#$%&]"), r" \g<0> "), 48 | ( 49 | re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'), 50 | r"\1 \2\3 ", 51 | ), # Handles the final period. 52 | (re.compile(r"[?!]"), r" \g<0> "), 53 | (re.compile(r"([^'])' "), r"\1 ' "), 54 | ( 55 | re.compile(r"[*]", re.U), 56 | r" \g<0> ", 57 | ), 58 | ] 59 | 60 | # Pads parentheses 61 | PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ") 62 | DOUBLE_DASHES = (re.compile(r"--"), r" -- ") 63 | 64 | # List of contractions adapted from Robert MacIntyre's tokenizer. 65 | CONTRACTIONS2 = [ 66 | re.compile(pattern) 67 | for pattern in ( 68 | r"(?i)\b(can)(?#X)(not)\b", 69 | r"(?i)\b(d)(?#X)('ye)\b", 70 | r"(?i)\b(gim)(?#X)(me)\b", 71 | r"(?i)\b(gon)(?#X)(na)\b", 72 | r"(?i)\b(got)(?#X)(ta)\b", 73 | r"(?i)\b(lem)(?#X)(me)\b", 74 | r"(?i)\b(more)(?#X)('n)\b", 75 | r"(?i)\b(wan)(?#X)(na)(?=\s)", 76 | ) 77 | ] 78 | CONTRACTIONS3 = [ 79 | re.compile(pattern) for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b") 80 | ] 81 | 82 | @classmethod 83 | def tokenize(cls, text: str) -> list[str]: 84 | """Return a tokenized copy of `text`. 85 | 86 | >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York.''' 87 | >>> WordTokenizer().tokenize(s) 88 | ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36', 'euros', ')', 'in', 'New', 'York', '.'] 89 | 90 | Args: 91 | text: The text to be tokenized. 92 | 93 | Returns: 94 | A list of tokens. 95 | """ 96 | for regexp, substitution in cls.STARTING_QUOTES: 97 | text = regexp.sub(substitution, text) 98 | 99 | for regexp, substitution in cls.PUNCTUATION: 100 | text = regexp.sub(substitution, text) 101 | 102 | # Handles parentheses. 103 | regexp, substitution = cls.PARENS_BRACKETS 104 | text = regexp.sub(substitution, text) 105 | 106 | # Handles double dash. 
107 | regexp, substitution = cls.DOUBLE_DASHES 108 | text = regexp.sub(substitution, text) 109 | 110 | # add extra space to make things easier 111 | text = " " + text + " " 112 | 113 | for regexp, substitution in cls.ENDING_QUOTES: 114 | text = regexp.sub(substitution, text) 115 | 116 | for regexp in cls.CONTRACTIONS2: 117 | text = regexp.sub(r" \1 \2 ", text) 118 | for regexp in cls.CONTRACTIONS3: 119 | text = regexp.sub(r" \1 \2 ", text) 120 | return text.split() 121 | -------------------------------------------------------------------------------- /fastembed/sparse/utils/vocab_resolver.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Iterable 3 | 4 | from py_rust_stemmers import SnowballStemmer 5 | import numpy as np 6 | from tokenizers import Tokenizer 7 | from numpy.typing import NDArray 8 | 9 | from fastembed.common.types import NumpyArray 10 | 11 | 12 | class VocabTokenizerBase: 13 | def tokenize(self, sentence: str) -> NumpyArray: 14 | raise NotImplementedError() 15 | 16 | def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]: 17 | raise NotImplementedError() 18 | 19 | 20 | class VocabTokenizer(VocabTokenizerBase): 21 | def __init__(self, tokenizer: Tokenizer): 22 | self.tokenizer = tokenizer 23 | 24 | def tokenize(self, sentence: str) -> NumpyArray: 25 | return np.array(self.tokenizer.encode(sentence).ids) 26 | 27 | def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]: 28 | return [self.tokenizer.id_to_token(token_id) for token_id in token_ids] 29 | 30 | 31 | class VocabResolver: 32 | def __init__(self, tokenizer: VocabTokenizerBase, stopwords: set[str], stemmer: SnowballStemmer): 33 | # Word to id mapping 34 | self.vocab: dict[str, int] = {} 35 | # Id to word mapping 36 | self.words: list[str] = [] 37 | # Lemma to word mapping 38 | self.stem_mapping: dict[str, str] = {} 39 | self.tokenizer: VocabTokenizerBase = tokenizer 40 | self.stemmer = stemmer 41 | self.stopwords: set[str] = stopwords 42 | 43 | def tokenize(self, sentence: str) -> NumpyArray: 44 | return self.tokenizer.tokenize(sentence) 45 | 46 | def lookup_word(self, word_id: int) -> str: 47 | if word_id == 0: 48 | return "UNK" 49 | return self.words[word_id - 1] 50 | 51 | def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]: 52 | return self.tokenizer.convert_ids_to_tokens(token_ids) 53 | 54 | def vocab_size(self) -> int: 55 | # We need +1 for UNK token 56 | return len(self.vocab) + 1 57 | 58 | def save_vocab(self, path: str) -> None: 59 | with open(path, "w") as f: 60 | for word in self.words: 61 | f.write(word + "\n") 62 | 63 | def save_json_vocab(self, path: str) -> None: 64 | import json 65 | 66 | with open(path, "w") as f: 67 | json.dump({"vocab": self.words, "stem_mapping": self.stem_mapping}, f, indent=2) 68 | 69 | def load_json_vocab(self, path: str) -> None: 70 | import json 71 | 72 | with open(path, "r") as f: 73 | data = json.load(f) 74 | self.words = data["vocab"] 75 | self.vocab = {word: idx + 1 for idx, word in enumerate(self.words)} 76 | self.stem_mapping = data["stem_mapping"] 77 | 78 | def add_word(self, word: str) -> None: 79 | if word not in self.vocab: 80 | self.vocab[word] = len(self.vocab) + 1 81 | self.words.append(word) 82 | stem = self.stemmer.stem_word(word) 83 | if stem not in self.stem_mapping: 84 | self.stem_mapping[stem] = word 85 | else: 86 | existing_word = self.stem_mapping[stem] 87 | if len(existing_word) > len(word): 88 | # Prefer shorter words 
for the same stem 89 | # Example: "swim" is preferred over "swimming" 90 | self.stem_mapping[stem] = word 91 | 92 | def load_vocab(self, path: str) -> None: 93 | with open(path, "r") as f: 94 | for line in f: 95 | self.add_word(line.strip()) 96 | 97 | @classmethod 98 | def _reconstruct_bpe( 99 | cls, bpe_tokens: Iterable[tuple[int, str]] 100 | ) -> list[tuple[str, list[int]]]: 101 | result: list[tuple[str, list[int]]] = [] 102 | acc: str = "" 103 | acc_idx: list[int] = [] 104 | 105 | continuing_subword_prefix = "##" 106 | continuing_subword_prefix_len = len(continuing_subword_prefix) 107 | 108 | for idx, token in bpe_tokens: 109 | if token.startswith(continuing_subword_prefix): 110 | acc += token[continuing_subword_prefix_len:] 111 | acc_idx.append(idx) 112 | else: 113 | if acc: 114 | result.append((acc, acc_idx)) 115 | acc_idx = [] 116 | acc = token 117 | acc_idx.append(idx) 118 | 119 | if acc: 120 | result.append((acc, acc_idx)) 121 | return result 122 | 123 | def resolve_tokens( 124 | self, token_ids: NDArray[np.int64] 125 | ) -> tuple[NDArray[np.int64], dict[int, int], dict[str, int], dict[str, list[str]]]: 126 | """ 127 | Mark known tokens (including composed tokens) with vocab ids. 128 | 129 | Args: 130 | token_ids: (seq_len) - list of ids of tokens 131 | Example: 132 | [ 133 | 101, 3897, 19332, 12718, 23348, 134 | 1010, 1996, 7151, 2296, 4845, 135 | 2359, 2005, 4234, 1010, 4332, 136 | 2871, 3191, 2062, 102 137 | ] 138 | 139 | returns: 140 | - token_ids with vocab ids 141 | [ 142 | 0, 151, 151, 0, 0, 143 | 912, 0, 0, 0, 332, 144 | 332, 332, 0, 7121, 191, 145 | 0, 0, 332, 0 146 | ] 147 | - counts of each token 148 | { 149 | 151: 1, 150 | 332: 3, 151 | 7121: 1, 152 | 191: 1, 153 | 912: 1 154 | } 155 | - oov counts of each token 156 | { 157 | "the": 1, 158 | "a": 1, 159 | "[CLS]": 1, 160 | "[SEP]": 1, 161 | ... 
162 | } 163 | - forms of each token 164 | { 165 | "hello": ["hello"], 166 | "world": ["worlds", "world", "worlding"], 167 | } 168 | 169 | """ 170 | tokens = self.convert_ids_to_tokens(token_ids) 171 | tokens_mapping = self._reconstruct_bpe(enumerate(tokens)) 172 | 173 | counts: dict[int, int] = defaultdict(int) 174 | oov_count: dict[str, int] = defaultdict(int) 175 | 176 | forms: dict[str, list[str]] = defaultdict(list) 177 | 178 | for token, mapped_token_ids in tokens_mapping: 179 | vocab_id = 0 180 | if token in self.stopwords: 181 | vocab_id = 0 182 | elif token in self.vocab: 183 | vocab_id = self.vocab[token] 184 | forms[token].append(token) 185 | elif token in self.stem_mapping: 186 | vocab_id = self.vocab[self.stem_mapping[token]] 187 | forms[self.stem_mapping[token]].append(token) 188 | else: 189 | stem = self.stemmer.stem_word(token) 190 | if stem in self.stem_mapping: 191 | vocab_id = self.vocab[self.stem_mapping[stem]] 192 | forms[self.stem_mapping[stem]].append(token) 193 | 194 | for token_id in mapped_token_ids: 195 | token_ids[token_id] = vocab_id 196 | 197 | if vocab_id == 0: 198 | oov_count[token] += 1 199 | else: 200 | counts[vocab_id] += 1 201 | return token_ids, counts, oov_count, forms 202 | 203 | -------------------------------------------------------------------------------- /fastembed/text/__init__.py: -------------------------------------------------------------------------------- 1 | from fastembed.text.text_embedding import TextEmbedding 2 | 3 | __all__ = ["TextEmbedding"] 4 | -------------------------------------------------------------------------------- /fastembed/text/clip_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Type 2 | 3 | from fastembed.common.types import NumpyArray 4 | from fastembed.common.onnx_model import OnnxOutputContext 5 | from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker 6 | from fastembed.common.model_description import DenseModelDescription, ModelSource 7 | 8 | supported_clip_models: list[DenseModelDescription] = [ 9 | DenseModelDescription( 10 | model="Qdrant/clip-ViT-B-32-text", 11 | dim=512, 12 | description=( 13 | "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, " 14 | "Prefixes for queries/documents: not necessary, 2021 year" 15 | ), 16 | license="mit", 17 | size_in_GB=0.25, 18 | sources=ModelSource(hf="Qdrant/clip-ViT-B-32-text"), 19 | model_file="model.onnx", 20 | ), 21 | ] 22 | 23 | 24 | class CLIPOnnxEmbedding(OnnxTextEmbedding): 25 | @classmethod 26 | def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: 27 | return CLIPEmbeddingWorker 28 | 29 | @classmethod 30 | def _list_supported_models(cls) -> list[DenseModelDescription]: 31 | """Lists the supported models. 32 | 33 | Returns: 34 | list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 
35 | """ 36 | return supported_clip_models 37 | 38 | def _post_process_onnx_output( 39 | self, output: OnnxOutputContext, **kwargs: Any 40 | ) -> Iterable[NumpyArray]: 41 | return output.model_output 42 | 43 | 44 | class CLIPEmbeddingWorker(OnnxTextEmbeddingWorker): 45 | def init_embedding( 46 | self, 47 | model_name: str, 48 | cache_dir: str, 49 | **kwargs: Any, 50 | ) -> OnnxTextEmbedding: 51 | return CLIPOnnxEmbedding( 52 | model_name=model_name, 53 | cache_dir=cache_dir, 54 | threads=1, 55 | **kwargs, 56 | ) 57 | -------------------------------------------------------------------------------- /fastembed/text/custom_text_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Any, Iterable 2 | 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | from numpy.typing import NDArray 7 | 8 | from fastembed.common import OnnxProvider 9 | from fastembed.common.model_description import ( 10 | PoolingType, 11 | DenseModelDescription, 12 | ) 13 | from fastembed.common.onnx_model import OnnxOutputContext 14 | from fastembed.common.types import NumpyArray 15 | from fastembed.common.utils import normalize, mean_pooling 16 | from fastembed.text.onnx_embedding import OnnxTextEmbedding 17 | 18 | 19 | @dataclass(frozen=True) 20 | class PostprocessingConfig: 21 | pooling: PoolingType 22 | normalization: bool 23 | 24 | 25 | class CustomTextEmbedding(OnnxTextEmbedding): 26 | SUPPORTED_MODELS: list[DenseModelDescription] = [] 27 | POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {} 28 | 29 | def __init__( 30 | self, 31 | model_name: str, 32 | cache_dir: Optional[str] = None, 33 | threads: Optional[int] = None, 34 | providers: Optional[Sequence[OnnxProvider]] = None, 35 | cuda: bool = False, 36 | device_ids: Optional[list[int]] = None, 37 | lazy_load: bool = False, 38 | device_id: Optional[int] = None, 39 | specific_model_path: Optional[str] = None, 40 | **kwargs: Any, 41 | ): 42 | super().__init__( 43 | model_name=model_name, 44 | cache_dir=cache_dir, 45 | threads=threads, 46 | providers=providers, 47 | cuda=cuda, 48 | device_ids=device_ids, 49 | lazy_load=lazy_load, 50 | device_id=device_id, 51 | specific_model_path=specific_model_path, 52 | **kwargs, 53 | ) 54 | self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling 55 | self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization 56 | 57 | @classmethod 58 | def _list_supported_models(cls) -> list[DenseModelDescription]: 59 | return cls.SUPPORTED_MODELS 60 | 61 | def _post_process_onnx_output( 62 | self, output: OnnxOutputContext, **kwargs: Any 63 | ) -> Iterable[NumpyArray]: 64 | return self._normalize(self._pool(output.model_output, output.attention_mask)) 65 | 66 | def _pool( 67 | self, embeddings: NumpyArray, attention_mask: Optional[NDArray[np.int64]] = None 68 | ) -> NumpyArray: 69 | if self._pooling == PoolingType.CLS: 70 | return embeddings[:, 0] 71 | 72 | if self._pooling == PoolingType.MEAN: 73 | if attention_mask is None: 74 | raise ValueError("attention_mask must be provided for mean pooling") 75 | return mean_pooling(embeddings, attention_mask) 76 | 77 | if self._pooling == PoolingType.DISABLED: 78 | return embeddings 79 | 80 | raise ValueError( 81 | f"Unsupported pooling type {self._pooling}. " 82 | f"Supported types are: {PoolingType.CLS}, {PoolingType.MEAN}, {PoolingType.DISABLED}." 
83 | ) 84 | 85 | def _normalize(self, embeddings: NumpyArray) -> NumpyArray: 86 | return normalize(embeddings) if self._normalization else embeddings 87 | 88 | @classmethod 89 | def add_model( 90 | cls, 91 | model_description: DenseModelDescription, 92 | pooling: PoolingType, 93 | normalization: bool, 94 | ) -> None: 95 | cls.SUPPORTED_MODELS.append(model_description) 96 | cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig( 97 | pooling=pooling, normalization=normalization 98 | ) 99 | -------------------------------------------------------------------------------- /fastembed/text/multitask_embedding.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Type, Iterable, Union, Optional 3 | 4 | import numpy as np 5 | 6 | from fastembed.common.onnx_model import OnnxOutputContext 7 | from fastembed.common.types import NumpyArray 8 | from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding 9 | from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker 10 | from fastembed.common.model_description import DenseModelDescription, ModelSource 11 | 12 | supported_multitask_models: list[DenseModelDescription] = [ 13 | DenseModelDescription( 14 | model="jinaai/jina-embeddings-v3", 15 | dim=1024, 16 | tasks={ 17 | "retrieval.query": 0, 18 | "retrieval.passage": 1, 19 | "separation": 2, 20 | "classification": 3, 21 | "text-matching": 4, 22 | }, 23 | description=( 24 | "Multi-task unimodal (text) embedding model, multi-lingual (~100), " 25 | "1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year." 26 | ), 27 | license="cc-by-nc-4.0", 28 | size_in_GB=2.29, 29 | sources=ModelSource(hf="jinaai/jina-embeddings-v3"), 30 | model_file="onnx/model.onnx", 31 | additional_files=["onnx/model.onnx_data"], 32 | ), 33 | ] 34 | 35 | 36 | class Task(int, Enum): 37 | RETRIEVAL_QUERY = 0 38 | RETRIEVAL_PASSAGE = 1 39 | SEPARATION = 2 40 | CLASSIFICATION = 3 41 | TEXT_MATCHING = 4 42 | 43 | 44 | class JinaEmbeddingV3(PooledNormalizedEmbedding): 45 | PASSAGE_TASK = Task.RETRIEVAL_PASSAGE 46 | QUERY_TASK = Task.RETRIEVAL_QUERY 47 | 48 | def __init__(self, *args: Any, task_id: Optional[int] = None, **kwargs: Any): 49 | super().__init__(*args, **kwargs) 50 | self.default_task_id: Union[Task, int] = ( 51 | task_id if task_id is not None else self.PASSAGE_TASK 52 | ) 53 | 54 | @classmethod 55 | def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: 56 | return JinaEmbeddingV3Worker 57 | 58 | @classmethod 59 | def _list_supported_models(cls) -> list[DenseModelDescription]: 60 | return supported_multitask_models 61 | 62 | def _preprocess_onnx_input( 63 | self, 64 | onnx_input: dict[str, NumpyArray], 65 | task_id: Optional[Union[int, Task]] = None, 66 | **kwargs: Any, 67 | ) -> dict[str, NumpyArray]: 68 | if task_id is None: 69 | raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>") 70 | onnx_input["task_id"] = np.array(task_id, dtype=np.int64) 71 | return onnx_input 72 | 73 | def embed( 74 | self, 75 | documents: Union[str, Iterable[str]], 76 | batch_size: int = 256, 77 | parallel: Optional[int] = None, 78 | task_id: Optional[int] = None, 79 | **kwargs: Any, 80 | ) -> Iterable[NumpyArray]: 81 | task_id = ( 82 | task_id if task_id is not None else self.default_task_id 83 | ) # required for multiprocessing 84 | yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs) 85 | 86 | def 
query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: 87 | yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs) 88 | 89 | def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: 90 | yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs) 91 | 92 | 93 | class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker): 94 | def init_embedding( 95 | self, 96 | model_name: str, 97 | cache_dir: str, 98 | **kwargs: Any, 99 | ) -> JinaEmbeddingV3: 100 | return JinaEmbeddingV3( 101 | model_name=model_name, 102 | cache_dir=cache_dir, 103 | threads=1, 104 | **kwargs, 105 | ) 106 | 107 | def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]: 108 | self.model: JinaEmbeddingV3 # mypy complaints `self.model` does not have `default_task_id` 109 | for idx, batch in items: 110 | onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id) 111 | yield idx, onnx_output 112 | -------------------------------------------------------------------------------- /fastembed/text/onnx_text_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import get_all_start_methods 3 | from pathlib import Path 4 | from typing import Any, Iterable, Optional, Sequence, Type, Union 5 | 6 | import numpy as np 7 | from numpy.typing import NDArray 8 | from tokenizers import Encoding, Tokenizer 9 | 10 | from fastembed.common.types import NumpyArray, OnnxProvider 11 | from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T 12 | from fastembed.common.preprocessor_utils import load_tokenizer 13 | from fastembed.common.utils import iter_batch 14 | from fastembed.parallel_processor import ParallelWorkerPool 15 | 16 | 17 | class OnnxTextModel(OnnxModel[T]): 18 | ONNX_OUTPUT_NAMES: Optional[list[str]] = None 19 | 20 | @classmethod 21 | def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]: 22 | raise NotImplementedError("Subclasses must implement this method") 23 | 24 | def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]: 25 | """Post-process the ONNX model output to convert it into a usable format. 26 | 27 | Args: 28 | output (OnnxOutputContext): The raw output from the ONNX model. 29 | **kwargs: Additional keyword arguments that may be needed by specific implementations. 30 | 31 | Returns: 32 | Iterable[T]: Post-processed output as an iterable of type T. 33 | """ 34 | raise NotImplementedError("Subclasses must implement this method") 35 | 36 | def __init__(self) -> None: 37 | super().__init__() 38 | self.tokenizer: Optional[Tokenizer] = None 39 | self.special_token_to_id: dict[str, int] = {} 40 | 41 | def _preprocess_onnx_input( 42 | self, onnx_input: dict[str, NumpyArray], **kwargs: Any 43 | ) -> dict[str, Union[NumpyArray, NDArray[np.int64]]]: 44 | """ 45 | Preprocess the onnx input. 
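This is a no-op by default; subclasses can override it to inject extra inputs before inference (for example, JinaEmbeddingV3's override adds a `task_id` tensor).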
46 | """ 47 | return onnx_input 48 | 49 | def _load_onnx_model( 50 | self, 51 | model_dir: Path, 52 | model_file: str, 53 | threads: Optional[int], 54 | providers: Optional[Sequence[OnnxProvider]] = None, 55 | cuda: bool = False, 56 | device_id: Optional[int] = None, 57 | ) -> None: 58 | super()._load_onnx_model( 59 | model_dir=model_dir, 60 | model_file=model_file, 61 | threads=threads, 62 | providers=providers, 63 | cuda=cuda, 64 | device_id=device_id, 65 | ) 66 | self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir) 67 | 68 | def load_onnx_model(self) -> None: 69 | raise NotImplementedError("Subclasses must implement this method") 70 | 71 | def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]: 72 | return self.tokenizer.encode_batch(documents) # type: ignore[union-attr] 73 | 74 | def onnx_embed( 75 | self, 76 | documents: list[str], 77 | **kwargs: Any, 78 | ) -> OnnxOutputContext: 79 | encoded = self.tokenize(documents, **kwargs) 80 | input_ids = np.array([e.ids for e in encoded]) 81 | attention_mask = np.array([e.attention_mask for e in encoded]) 82 | input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr] 83 | onnx_input: dict[str, NumpyArray] = { 84 | "input_ids": np.array(input_ids, dtype=np.int64), 85 | } 86 | if "attention_mask" in input_names: 87 | onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64) 88 | if "token_type_ids" in input_names: 89 | onnx_input["token_type_ids"] = np.array( 90 | [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 91 | ) 92 | onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) 93 | 94 | model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) # type: ignore[union-attr] 95 | return OnnxOutputContext( 96 | model_output=model_output[0], 97 | attention_mask=onnx_input.get("attention_mask", attention_mask), 98 | input_ids=onnx_input.get("input_ids", input_ids), 99 | ) 100 | 101 | def _embed_documents( 102 | self, 103 | model_name: str, 104 | cache_dir: str, 105 | documents: Union[str, Iterable[str]], 106 | batch_size: int = 256, 107 | parallel: Optional[int] = None, 108 | providers: Optional[Sequence[OnnxProvider]] = None, 109 | cuda: bool = False, 110 | device_ids: Optional[list[int]] = None, 111 | local_files_only: bool = False, 112 | specific_model_path: Optional[str] = None, 113 | **kwargs: Any, 114 | ) -> Iterable[T]: 115 | is_small = False 116 | 117 | if isinstance(documents, str): 118 | documents = [documents] 119 | is_small = True 120 | 121 | if isinstance(documents, list): 122 | if len(documents) < batch_size: 123 | is_small = True 124 | 125 | if parallel is None or is_small: 126 | if not hasattr(self, "model") or self.model is None: 127 | self.load_onnx_model() 128 | for batch in iter_batch(documents, batch_size): 129 | yield from self._post_process_onnx_output( 130 | self.onnx_embed(batch, **kwargs), **kwargs 131 | ) 132 | else: 133 | if parallel == 0: 134 | parallel = os.cpu_count() 135 | 136 | start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn" 137 | params = { 138 | "model_name": model_name, 139 | "cache_dir": cache_dir, 140 | "providers": providers, 141 | "local_files_only": local_files_only, 142 | "specific_model_path": specific_model_path, 143 | **kwargs, 144 | } 145 | 146 | pool = ParallelWorkerPool( 147 | num_workers=parallel or 1, 148 | worker=self._get_worker_class(), 149 | cuda=cuda, 150 | device_ids=device_ids, 151 | start_method=start_method, 152 | ) 153 | for batch in 
pool.ordered_map(iter_batch(documents, batch_size), **params): 154 | yield from self._post_process_onnx_output(batch, **kwargs) # type: ignore 155 | 156 | 157 | class TextEmbeddingWorker(EmbeddingWorker[T]): 158 | def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]: 159 | for idx, batch in items: 160 | onnx_output = self.model.onnx_embed(batch) 161 | yield idx, onnx_output 162 | -------------------------------------------------------------------------------- /fastembed/text/pooled_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Type 2 | 3 | import numpy as np 4 | from numpy.typing import NDArray 5 | 6 | from fastembed.common.types import NumpyArray 7 | from fastembed.common.onnx_model import OnnxOutputContext 8 | from fastembed.common.utils import mean_pooling 9 | from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker 10 | from fastembed.common.model_description import DenseModelDescription, ModelSource 11 | 12 | supported_pooled_models: list[DenseModelDescription] = [ 13 | DenseModelDescription( 14 | model="nomic-ai/nomic-embed-text-v1.5", 15 | dim=768, 16 | description=( 17 | "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, " 18 | "Prefixes for queries/documents: necessary, 2024 year." 19 | ), 20 | license="apache-2.0", 21 | size_in_GB=0.52, 22 | sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"), 23 | model_file="onnx/model.onnx", 24 | ), 25 | DenseModelDescription( 26 | model="nomic-ai/nomic-embed-text-v1.5-Q", 27 | dim=768, 28 | description=( 29 | "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, " 30 | "Prefixes for queries/documents: necessary, 2024 year." 31 | ), 32 | license="apache-2.0", 33 | size_in_GB=0.13, 34 | sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"), 35 | model_file="onnx/model_quantized.onnx", 36 | ), 37 | DenseModelDescription( 38 | model="nomic-ai/nomic-embed-text-v1", 39 | dim=768, 40 | description=( 41 | "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, " 42 | "Prefixes for queries/documents: necessary, 2024 year." 43 | ), 44 | license="apache-2.0", 45 | size_in_GB=0.52, 46 | sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1"), 47 | model_file="onnx/model.onnx", 48 | ), 49 | DenseModelDescription( 50 | model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 51 | dim=384, 52 | description=( 53 | "Text embeddings, Unimodal (text), Multilingual (~50 languages), 512 input tokens truncation, " 54 | "Prefixes for queries/documents: not necessary, 2019 year." 55 | ), 56 | license="apache-2.0", 57 | size_in_GB=0.22, 58 | sources=ModelSource(hf="qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"), 59 | model_file="model_optimized.onnx", 60 | ), 61 | DenseModelDescription( 62 | model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", 63 | dim=768, 64 | description=( 65 | "Text embeddings, Unimodal (text), Multilingual (~50 languages), 384 input tokens truncation, " 66 | "Prefixes for queries/documents: not necessary, 2021 year." 
67 | ), 68 | license="apache-2.0", 69 | size_in_GB=1.00, 70 | sources=ModelSource(hf="xenova/paraphrase-multilingual-mpnet-base-v2"), 71 | model_file="onnx/model.onnx", 72 | ), 73 | DenseModelDescription( 74 | model="intfloat/multilingual-e5-large", 75 | dim=1024, 76 | description=( 77 | "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, " 78 | "Prefixes for queries/documents: necessary, 2024 year." 79 | ), 80 | license="mit", 81 | size_in_GB=2.24, 82 | sources=ModelSource( 83 | hf="qdrant/multilingual-e5-large-onnx", 84 | url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", 85 | _deprecated_tar_struct=True, 86 | ), 87 | model_file="model.onnx", 88 | additional_files=["model.onnx_data"], 89 | ), 90 | ] 91 | 92 | 93 | class PooledEmbedding(OnnxTextEmbedding): 94 | @classmethod 95 | def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: 96 | return PooledEmbeddingWorker 97 | 98 | @classmethod 99 | def mean_pooling( 100 | cls, model_output: NumpyArray, attention_mask: NDArray[np.int64] 101 | ) -> NumpyArray: 102 | return mean_pooling(model_output, attention_mask) 103 | 104 | @classmethod 105 | def _list_supported_models(cls) -> list[DenseModelDescription]: 106 | """Lists the supported models. 107 | 108 | Returns: 109 | list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 110 | """ 111 | return supported_pooled_models 112 | 113 | def _post_process_onnx_output( 114 | self, output: OnnxOutputContext, **kwargs: Any 115 | ) -> Iterable[NumpyArray]: 116 | if output.attention_mask is None: 117 | raise ValueError("attention_mask must be provided for document post-processing") 118 | 119 | embeddings = output.model_output 120 | attn_mask = output.attention_mask 121 | return self.mean_pooling(embeddings, attn_mask) 122 | 123 | 124 | class PooledEmbeddingWorker(OnnxTextEmbeddingWorker): 125 | def init_embedding( 126 | self, 127 | model_name: str, 128 | cache_dir: str, 129 | **kwargs: Any, 130 | ) -> OnnxTextEmbedding: 131 | return PooledEmbedding( 132 | model_name=model_name, 133 | cache_dir=cache_dir, 134 | threads=1, 135 | **kwargs, 136 | ) 137 | -------------------------------------------------------------------------------- /fastembed/text/pooled_normalized_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Type 2 | 3 | 4 | from fastembed.common.types import NumpyArray 5 | from fastembed.common.onnx_model import OnnxOutputContext 6 | from fastembed.common.utils import normalize 7 | from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker 8 | from fastembed.text.pooled_embedding import PooledEmbedding 9 | from fastembed.common.model_description import DenseModelDescription, ModelSource 10 | 11 | supported_pooled_normalized_models: list[DenseModelDescription] = [ 12 | DenseModelDescription( 13 | model="sentence-transformers/all-MiniLM-L6-v2", 14 | dim=384, 15 | description=( 16 | "Text embeddings, Unimodal (text), English, 256 input tokens truncation, " 17 | "Prefixes for queries/documents: not necessary, 2021 year." 
18 | ), 19 | license="apache-2.0", 20 | size_in_GB=0.09, 21 | sources=ModelSource( 22 | url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", 23 | hf="qdrant/all-MiniLM-L6-v2-onnx", 24 | _deprecated_tar_struct=True, 25 | ), 26 | model_file="model.onnx", 27 | ), 28 | DenseModelDescription( 29 | model="jinaai/jina-embeddings-v2-base-en", 30 | dim=768, 31 | description=( 32 | "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, " 33 | "Prefixes for queries/documents: not necessary, 2023 year." 34 | ), 35 | license="apache-2.0", 36 | size_in_GB=0.52, 37 | sources=ModelSource(hf="xenova/jina-embeddings-v2-base-en"), 38 | model_file="onnx/model.onnx", 39 | ), 40 | DenseModelDescription( 41 | model="jinaai/jina-embeddings-v2-small-en", 42 | dim=512, 43 | description=( 44 | "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, " 45 | "Prefixes for queries/documents: not necessary, 2023 year." 46 | ), 47 | license="apache-2.0", 48 | size_in_GB=0.12, 49 | sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"), 50 | model_file="onnx/model.onnx", 51 | ), 52 | DenseModelDescription( 53 | model="jinaai/jina-embeddings-v2-base-de", 54 | dim=768, 55 | description=( 56 | "Text embeddings, Unimodal (text), Multilingual (German, English), 8192 input tokens truncation, " 57 | "Prefixes for queries/documents: not necessary, 2024 year." 58 | ), 59 | license="apache-2.0", 60 | size_in_GB=0.32, 61 | sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-de"), 62 | model_file="onnx/model_fp16.onnx", 63 | ), 64 | DenseModelDescription( 65 | model="jinaai/jina-embeddings-v2-base-code", 66 | dim=768, 67 | description=( 68 | "Text embeddings, Unimodal (text), Multilingual (English, 30 programming languages), " 69 | "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." 70 | ), 71 | license="apache-2.0", 72 | size_in_GB=0.64, 73 | sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-code"), 74 | model_file="onnx/model.onnx", 75 | ), 76 | DenseModelDescription( 77 | model="jinaai/jina-embeddings-v2-base-zh", 78 | dim=768, 79 | description=( 80 | "Text embeddings, Unimodal (text), supports mixed Chinese-English input text, " 81 | "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." 82 | ), 83 | license="apache-2.0", 84 | size_in_GB=0.64, 85 | sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-zh"), 86 | model_file="onnx/model.onnx", 87 | ), 88 | DenseModelDescription( 89 | model="jinaai/jina-embeddings-v2-base-es", 90 | dim=768, 91 | description=( 92 | "Text embeddings, Unimodal (text), supports mixed Spanish-English input text, " 93 | "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." 94 | ), 95 | license="apache-2.0", 96 | size_in_GB=0.64, 97 | sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-es"), 98 | model_file="onnx/model.onnx", 99 | ), 100 | DenseModelDescription( 101 | model="thenlper/gte-base", 102 | dim=768, 103 | description=( 104 | "General text embeddings, Unimodal (text), supports English only input text, " 105 | "512 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year." 
106 | ), 107 | license="mit", 108 | size_in_GB=0.44, 109 | sources=ModelSource(hf="thenlper/gte-base"), 110 | model_file="onnx/model.onnx", 111 | ), 112 | DenseModelDescription( 113 | model="thenlper/gte-large", 114 | dim=1024, 115 | description=( 116 | "Text embeddings, Unimodal (text), English, 512 input tokens truncation, " 117 | "Prefixes for queries/documents: not necessary, 2023 year." 118 | ), 119 | license="mit", 120 | size_in_GB=1.20, 121 | sources=ModelSource(hf="qdrant/gte-large-onnx"), 122 | model_file="model.onnx", 123 | ), 124 | ] 125 | 126 | 127 | class PooledNormalizedEmbedding(PooledEmbedding): 128 | @classmethod 129 | def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]: 130 | return PooledNormalizedEmbeddingWorker 131 | 132 | @classmethod 133 | def _list_supported_models(cls) -> list[DenseModelDescription]: 134 | """Lists the supported models. 135 | 136 | Returns: 137 | list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information. 138 | """ 139 | return supported_pooled_normalized_models 140 | 141 | def _post_process_onnx_output( 142 | self, output: OnnxOutputContext, **kwargs: Any 143 | ) -> Iterable[NumpyArray]: 144 | if output.attention_mask is None: 145 | raise ValueError("attention_mask must be provided for document post-processing") 146 | 147 | embeddings = output.model_output 148 | attn_mask = output.attention_mask 149 | return normalize(self.mean_pooling(embeddings, attn_mask)) 150 | 151 | 152 | class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker): 153 | def init_embedding( 154 | self, 155 | model_name: str, 156 | cache_dir: str, 157 | **kwargs: Any, 158 | ) -> OnnxTextEmbedding: 159 | return PooledNormalizedEmbedding( 160 | model_name=model_name, 161 | cache_dir=cache_dir, 162 | threads=1, 163 | **kwargs, 164 | ) 165 | -------------------------------------------------------------------------------- /fastembed/text/text_embedding_base.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional, Union, Any 2 | 3 | from fastembed.common.model_description import DenseModelDescription 4 | from fastembed.common.types import NumpyArray 5 | from fastembed.common.model_management import ModelManagement 6 | 7 | 8 | class TextEmbeddingBase(ModelManagement[DenseModelDescription]): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | cache_dir: Optional[str] = None, 13 | threads: Optional[int] = None, 14 | **kwargs: Any, 15 | ): 16 | self.model_name = model_name 17 | self.cache_dir = cache_dir 18 | self.threads = threads 19 | self._local_files_only = kwargs.pop("local_files_only", False) 20 | self._embedding_size: Optional[int] = None 21 | 22 | def embed( 23 | self, 24 | documents: Union[str, Iterable[str]], 25 | batch_size: int = 256, 26 | parallel: Optional[int] = None, 27 | **kwargs: Any, 28 | ) -> Iterable[NumpyArray]: 29 | raise NotImplementedError() 30 | 31 | def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]: 32 | """ 33 | Embeds a list of text passages into a list of embeddings. 34 | 35 | Args: 36 | texts (Iterable[str]): The list of texts to embed. 37 | **kwargs: Additional keyword argument to pass to the embed method. 38 | 39 | Yields: 40 | Iterable[NumpyArray]: The embeddings. 
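Example (illustrative sketch only; "sentence-transformers/all-MiniLM-L6-v2" is just one of the supported models and is downloaded on first use):
    >>> from fastembed.text import TextEmbedding
    >>> model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2")
    >>> embeddings = list(model.passage_embed(["FastEmbed is fast.", "Qdrant is a vector database."]))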
41 | """ 42 | 43 | # This is model-specific, so that different models can have specialized implementations 44 | yield from self.embed(texts, **kwargs) 45 | 46 | def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]: 47 | """ 48 | Embeds queries 49 | 50 | Args: 51 | query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. 52 | 53 | Returns: 54 | Iterable[NumpyArray]: The embeddings. 55 | """ 56 | 57 | # This is model-specific, so that different models can have specialized implementations 58 | if isinstance(query, str): 59 | yield from self.embed([query], **kwargs) 60 | else: 61 | yield from self.embed(query, **kwargs) 62 | 63 | @classmethod 64 | def get_embedding_size(cls, model_name: str) -> int: 65 | """Returns embedding size of the passed model.""" 66 | raise NotImplementedError("Subclasses must implement this method") 67 | 68 | @property 69 | def embedding_size(self) -> int: 70 | """Returns embedding size for the current model""" 71 | raise NotImplementedError("Subclasses must implement this method") 72 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: FastEmbed 2 | site_url: https://qdrant.github.io/fastembed/ 3 | site_author: Nirant Kasliwal 4 | repo_url: https://github.com/qdrant/fastembed/ 5 | repo_name: qdrant/fastembed 6 | 7 | remote_branch: gh-pages 8 | remote_name: origin 9 | 10 | copyright: | 11 | Maintained by Qdrant. Originally created by Nirant Kasliwal. 12 | 13 | theme: 14 | name: material 15 | logo: assets/favicon.png 16 | custom_dir: docs/overrides 17 | icon: 18 | repo: fontawesome/brands/github 19 | features: 20 | - search.suggest 21 | - search.highlight 22 | - navigation.instant 23 | - navigation.tracking 24 | - navigation.expand 25 | - navigation.sections 26 | - content.code.annotate 27 | - toc.follow 28 | - header.autohide 29 | - announce.dismiss 30 | accent: 31 | # Primary color 32 | color: "#3f51b5" 33 | # Text color for primary color 34 | text: "#ffffff" 35 | 36 | palette: 37 | # Palette toggle for light mode 38 | - scheme: default 39 | toggle: 40 | icon: material/brightness-7 41 | name: Switch to dark mode 42 | 43 | # Palette toggle for dark mode 44 | - scheme: slate 45 | toggle: 46 | icon: material/brightness-4 47 | name: Switch to light mode 48 | 49 | markdown_extensions: 50 | - abbr 51 | - admonition 52 | - attr_list 53 | # - highlight 54 | - def_list 55 | - toc: 56 | permalink: true 57 | toc_depth: 3 58 | 59 | plugins: 60 | - search 61 | - mkdocstrings: 62 | default_handler: python 63 | handlers: 64 | python: 65 | options: 66 | show_source: false 67 | show_bases: false 68 | show_if_no_docstring: true 69 | merge_init_into_class: true 70 | show_root_toc_entry: false 71 | show_inheritance: true 72 | show_private: false 73 | show_special_members: false 74 | - mknotebooks: 75 | execute: false 76 | timeout: 100 77 | allow_errors: false 78 | tag_remove_configs: 79 | remove_cell_tags: 80 | - Remove_cell 81 | remove_all_outputs_tags: 82 | - Remove_all_output 83 | remove_single_output_tags: 84 | - Remove_single_output 85 | remove_input_tags: 86 | - Remove_input 87 | 88 | markdown_extensions: 89 | - pymdownx.superfences: 90 | custom_fences: 91 | - name: mermaid 92 | class: mermaid 93 | format: !!python/name:pymdownx.superfences.fence_code_format 94 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fastembed" 3 | version = "0.7.0" 4 | description = "Fast, light, accurate library built for retrieval embedding generation" 5 | authors = ["Qdrant Team ", "NirantK "] 6 | license = "Apache License" 7 | readme = "README.md" 8 | packages = [{include = "fastembed"}] 9 | homepage = "https://github.com/qdrant/fastembed" 10 | repository = "https://github.com/qdrant/fastembed" 11 | keywords = ["vector", "embedding", "neural", "search", "qdrant", "sentence-transformers"] 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9.0" 15 | numpy = [ 16 | { version = ">=1.21", python = ">=3.10,<3.12" }, 17 | { version = ">=1.26", python = ">=3.12,<3.13" }, 18 | { version = ">=2.1.0", python = ">=3.13" }, 19 | { version = ">=1.21,<2.1.0", python = "<3.10" }, 20 | ] 21 | onnxruntime = [ 22 | { version = ">=1.17.0,<1.20.0", python = "<3.10" }, 23 | { version = ">1.20.0", python = ">=3.13" }, 24 | { version = ">=1.17.0,!=1.20.0", python = ">=3.10,<3.13" }, 25 | ] 26 | tqdm = "^4.66" 27 | requests = "^2.31" 28 | tokenizers = ">=0.15,<1.0" 29 | huggingface-hub = ">=0.20,<1.0" 30 | loguru = "^0.7.2" 31 | pillow = ">=10.3.0,<12.0.0" 32 | mmh3 = ">=4.1.0,<6.0.0" 33 | py-rust-stemmers = "^0.1.0" 34 | 35 | [tool.poetry.group.test.dependencies] 36 | pytest = "^7.4.2" 37 | ruff = ">=0.3.1,<1.0" 38 | 39 | [tool.poetry.group.dev.dependencies] 40 | notebook = ">=7.0.2" 41 | pre-commit = "^3.6.2" 42 | onnx = ">=1.15.0" 43 | 44 | [tool.poetry.group.docs.dependencies] 45 | mkdocs-material = "^9.5.10" 46 | mkdocstrings = "^0.24.0" 47 | pillow = ">=10.3.0,<12.0.0" 48 | cairosvg = "^2.7.1" 49 | mknotebooks = "^0.8.0" 50 | 51 | [tool.poetry.group.types.dependencies] 52 | pyright = ">=1.1.293" 53 | mypy = "^1.0.0" 54 | 55 | [build-system] 56 | requires = ["poetry-core"] 57 | build-backend = "poetry.core.masonry.api" 58 | 59 | [tool.pyright] 60 | typeCheckingMode = "strict" 61 | 62 | [tool.ruff] 63 | line-length = 99 64 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # disable DeprecationWarning https://github.com/jupyter/jupyter_core/issues/398 4 | os.environ["JUPYTER_PLATFORM_DIRS"] = "1" 5 | -------------------------------------------------------------------------------- /tests/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | TEST_DIR = Path(__file__).parent 4 | TEST_MISC_DIR = TEST_DIR / "misc" 5 | -------------------------------------------------------------------------------- /tests/misc/image.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/fastembed/a260022ae6059f4e7568cbd57cd6191cdaab8f33/tests/misc/image.jpeg -------------------------------------------------------------------------------- /tests/misc/small_image.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/fastembed/a260022ae6059f4e7568cbd57cd6191cdaab8f33/tests/misc/small_image.jpeg -------------------------------------------------------------------------------- /tests/profiling.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | # # 🤗 Huggingface vs ⚡ FastEmbed️ 3 | # 4 | # Comparing the performance of 
Huggingface's 🤗 Transformers and ⚡ FastEmbed️ on a simple task on the following machine: Apple M2 Max, 32 GB RAM 5 | # 6 | # ## 📦 Imports 7 | # 8 | # Importing the necessary libraries for this comparison. 9 | 10 | # %% 11 | import time 12 | from typing import Callable 13 | 14 | import matplotlib.pyplot as plt 15 | import torch.nn.functional as F 16 | from transformers import AutoModel, AutoTokenizer 17 | 18 | from fastembed.embedding import DefaultEmbedding 19 | 20 | # %% [markdown] 21 | # ## 📖 Data 22 | # 23 | # data is a list of strings, each string is a document. 24 | 25 | # %% 26 | documents: list[str] = [ 27 | "Chandrayaan-3 is India's third lunar mission", 28 | "It aimed to land a rover on the Moon's surface - joining the US, China and Russia", 29 | "The mission is a follow-up to Chandrayaan-2, which had partial success", 30 | "Chandrayaan-3 will be launched by the Indian Space Research Organisation (ISRO)", 31 | "The estimated cost of the mission is around $35 million", 32 | "It will carry instruments to study the lunar surface and atmosphere", 33 | "Chandrayaan-3 landed on the Moon's surface on 23rd August 2023", 34 | "It consists of a lander named Vikram and a rover named Pragyan similar to Chandrayaan-2. Its propulsion module would act like an orbiter.", 35 | "The propulsion module carries the lander and rover configuration until the spacecraft is in a 100-kilometre (62 mi) lunar orbit", 36 | "The mission used GSLV Mk III rocket for its launch", 37 | "Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota", 38 | "Chandrayaan-3 was launched earlier in the year 2023", 39 | ] 40 | len(documents) 41 | 42 | # %% [markdown] 43 | # ## Setting up 🤗 Huggingface 44 | # 45 | # We'll be using the [Huggingface Transformers](https://huggingface.co/transformers/) with PyTorch library to generate embeddings. We'll be using the same model across both libraries for a fair(er?) comparison. 46 | 47 | 48 | # %% 49 | class HF: 50 | """ 51 | HuggingFace Transformer implementation of FlagEmbedding 52 | Based on https://huggingface.co/BAAI/bge-base-en 53 | """ 54 | 55 | def __init__(self, model_id: str): 56 | self.model = AutoModel.from_pretrained(model_id) 57 | self.tokenizer = AutoTokenizer.from_pretrained(model_id) 58 | 59 | def embed(self, texts: list[str]): 60 | encoded_input = self.tokenizer( 61 | texts, max_length=512, padding=True, truncation=True, return_tensors="pt" 62 | ) 63 | model_output = self.model(**encoded_input) 64 | sentence_embeddings = model_output[0][:, 0] 65 | sentence_embeddings = F.normalize(sentence_embeddings) 66 | return sentence_embeddings 67 | 68 | 69 | hf = HF(model_id="BAAI/bge-small-en") 70 | hf.embed(documents).shape 71 | 72 | # %% [markdown] 73 | # ## Setting up ⚡️FastEmbed 74 | # 75 | # Sorry, don't have a lot to set up here. We'll be using the default model, which is Flag Embedding, same as the Huggingface model. 76 | 77 | # %% 78 | embedding_model = DefaultEmbedding() 79 | 80 | # %% [markdown] 81 | # ## 📊 Comparison 82 | # 83 | # We'll be comparing the following metrics: Minimum, Maximum, Mean, across k runs. 
Let's write a function to do that: 84 | # 85 | # ### 🚀 Calculating Stats 86 | 87 | 88 | # %% 89 | def calculate_time_stats( 90 | embed_func: Callable, documents: list, k: int 91 | ) -> tuple[float, float, float]: 92 | times = [] 93 | for _ in range(k): 94 | # Timing the embed_func call 95 | start_time = time.time() 96 | embed_func(documents) 97 | end_time = time.time() 98 | 99 | times.append(end_time - start_time) 100 | 101 | # Returning mean, max, and min time for the call 102 | return (sum(times) / k, max(times), min(times)) 103 | 104 | 105 | # %% 106 | hf_stats = calculate_time_stats(hf.embed, documents, k=2) 107 | print(f"Huggingface Transformers (Average, Max, Min): {hf_stats}") 108 | fst_stats = calculate_time_stats(lambda x: list(embedding_model.embed(x)), documents, k=2) 109 | print(f"FastEmbed (Average, Max, Min): {fst_stats}") 110 | 111 | 112 | # %% 113 | def plot_character_per_second_comparison( 114 | hf_stats: tuple[float, float, float], 115 | fst_stats: tuple[float, float, float], 116 | documents: list, 117 | ): 118 | # Calculating total characters in documents 119 | total_characters = sum(len(doc) for doc in documents) 120 | 121 | # Calculating characters per second for each model 122 | hf_chars_per_sec = total_characters / hf_stats[0] # Mean time is at index 0 123 | fst_chars_per_sec = total_characters / fst_stats[0] 124 | 125 | # Plotting the bar chart 126 | models = ["HF Embed (Torch)", "FastEmbed"] 127 | chars_per_sec = [hf_chars_per_sec, fst_chars_per_sec] 128 | 129 | bars = plt.bar(models, chars_per_sec, color=["#1f356c", "#dd1f4b"]) 130 | plt.ylabel("Characters per Second") 131 | plt.title("Characters Processed per Second Comparison") 132 | 133 | # Adding the number at the top of each bar 134 | for bar, chars in zip(bars, chars_per_sec): 135 | plt.text( 136 | bar.get_x() + bar.get_width() / 2, 137 | bar.get_height(), 138 | f"{chars:.1f}", 139 | ha="center", 140 | va="bottom", 141 | color="#1f356c", 142 | fontsize=12, 143 | ) 144 | 145 | plt.show() 146 | 147 | 148 | plot_character_per_second_comparison(hf_stats, fst_stats, documents) 149 | -------------------------------------------------------------------------------- /tests/test_attention_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from fastembed import SparseTextEmbedding 7 | from tests.utils import delete_model_cache 8 | 9 | 10 | @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"]) 11 | def test_attention_embeddings(model_name: str) -> None: 12 | is_ci = os.getenv("CI") 13 | model = SparseTextEmbedding(model_name=model_name) 14 | 15 | output = list( 16 | model.query_embed( 17 | [ 18 | "I must not fear. Fear is the mind-killer.", 19 | ] 20 | ) 21 | ) 22 | 23 | assert len(output) == 1 24 | 25 | for result in output: 26 | assert len(result.indices) == len(result.values) 27 | assert np.allclose(result.values, np.ones(len(result.values))) 28 | 29 | quotes = [ 30 | "I must not fear. Fear is the mind-killer.", 31 | "All animals are equal, but some animals are more equal than others.", 32 | "It was a pleasure to burn.", 33 | "The sky above the port was the color of television, tuned to a dead channel.", 34 | "In the beginning, the universe was created." 
35 | " This has made a lot of people very angry and been widely regarded as a bad move.", 36 | "It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains.", 37 | "War is peace. Freedom is slavery. Ignorance is strength.", 38 | "We're not in Infinity; we're in the suburbs.", 39 | "I was a thousand times more evil than thou!", 40 | "History is merely a list of surprises... It can only prepare us to be surprised yet again.", 41 | ".", # Empty string 42 | ] 43 | 44 | output = list(model.embed(quotes)) 45 | 46 | assert len(output) == len(quotes) 47 | 48 | for result in output[:-1]: 49 | assert len(result.indices) == len(result.values) 50 | assert len(result.indices) > 0 51 | 52 | assert len(output[-1].indices) == 0 53 | 54 | # Test support for unknown languages 55 | output = list( 56 | model.query_embed( 57 | [ 58 | "привет мир!", 59 | ] 60 | ) 61 | ) 62 | 63 | assert len(output) == 1 64 | 65 | for result in output: 66 | assert len(result.indices) == len(result.values) 67 | assert len(result.indices) == 2 68 | 69 | if is_ci: 70 | delete_model_cache(model.model._model_dir) 71 | 72 | 73 | @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"]) 74 | def test_parallel_processing(model_name: str) -> None: 75 | is_ci = os.getenv("CI") 76 | 77 | model = SparseTextEmbedding(model_name=model_name) 78 | 79 | docs = ["hello world", "attention embedding", "Mangez-vous vraiment des grenouilles?"] * 100 80 | embeddings = list(model.embed(docs, batch_size=10, parallel=2)) 81 | 82 | embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) 83 | 84 | embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0)) 85 | 86 | assert len(embeddings) == len(docs) 87 | 88 | for emb_1, emb_2, emb_3 in zip(embeddings, embeddings_2, embeddings_3): 89 | assert np.allclose(emb_1.indices, emb_2.indices) 90 | assert np.allclose(emb_1.indices, emb_3.indices) 91 | assert np.allclose(emb_1.values, emb_2.values) 92 | assert np.allclose(emb_1.values, emb_3.values) 93 | 94 | if is_ci: 95 | delete_model_cache(model.model._model_dir) 96 | 97 | 98 | @pytest.mark.parametrize("model_name", ["Qdrant/bm25"]) 99 | def test_multilanguage(model_name: str) -> None: 100 | is_ci = os.getenv("CI") 101 | 102 | docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"] 103 | 104 | model = SparseTextEmbedding(model_name=model_name, language="french") 105 | embeddings = list(model.embed(docs))[:2] 106 | assert embeddings[0].values.shape == (3,) 107 | assert embeddings[0].indices.shape == (3,) 108 | 109 | assert embeddings[1].values.shape == (1,) 110 | assert embeddings[1].indices.shape == (1,) 111 | 112 | model = SparseTextEmbedding(model_name=model_name, language="english") 113 | embeddings = list(model.embed(docs))[:2] 114 | assert embeddings[0].values.shape == (5,) 115 | assert embeddings[0].indices.shape == (5,) 116 | 117 | assert embeddings[1].values.shape == (4,) 118 | assert embeddings[1].indices.shape == (4,) 119 | 120 | if is_ci: 121 | delete_model_cache(model.model._model_dir) 122 | 123 | 124 | @pytest.mark.parametrize("model_name", ["Qdrant/bm25"]) 125 | def test_special_characters(model_name: str) -> None: 126 | is_ci = os.getenv("CI") 127 | 128 | docs = [ 129 | "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!", 130 | "L'élève français s'écrie : « Où est mon crayon ? 
J'ai besoin de finir cet exercice avant la récréation!", 131 | "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.", 132 | "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!", 133 | "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»", 134 | "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad", 135 | ] 136 | 137 | model = SparseTextEmbedding(model_name=model_name, language="english") 138 | embeddings = list(model.embed(docs)) 139 | for idx, shape in enumerate([14, 18, 15, 10, 15]): 140 | assert embeddings[idx].values.shape == (shape,) 141 | assert embeddings[idx].indices.shape == (shape,) 142 | 143 | if is_ci: 144 | delete_model_cache(model.model._model_dir) 145 | 146 | 147 | @pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"]) 148 | def test_lazy_load(model_name: str) -> None: 149 | model = SparseTextEmbedding(model_name=model_name, lazy_load=True) 150 | assert not hasattr(model.model, "model") 151 | docs = ["hello world", "flag embedding"] 152 | list(model.embed(docs)) 153 | assert hasattr(model.model, "model") 154 | 155 | model = SparseTextEmbedding(model_name=model_name, lazy_load=True) 156 | list(model.query_embed(docs)) 157 | 158 | model = SparseTextEmbedding(model_name=model_name, lazy_load=True) 159 | list(model.passage_embed(docs)) 160 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | from fastembed import ( 2 | TextEmbedding, 3 | SparseTextEmbedding, 4 | ImageEmbedding, 5 | LateInteractionMultimodalEmbedding, 6 | LateInteractionTextEmbedding, 7 | ) 8 | 9 | 10 | def test_text_list_supported_models(): 11 | for model_type in [ 12 | TextEmbedding, 13 | SparseTextEmbedding, 14 | ImageEmbedding, 15 | LateInteractionMultimodalEmbedding, 16 | LateInteractionTextEmbedding, 17 | ]: 18 | supported_models = model_type.list_supported_models() 19 | assert isinstance(supported_models, list) 20 | description = supported_models[0] 21 | assert isinstance(description, dict) 22 | 23 | assert "model" in description and description["model"] 24 | if model_type != SparseTextEmbedding: 25 | assert "dim" in description and description["dim"] 26 | assert "license" in description and description["license"] 27 | assert "size_in_GB" in description and description["size_in_GB"] 28 | assert "model_file" in description and description["model_file"] 29 | assert "sources" in description and description["sources"] 30 | assert "hf" in description["sources"] or "url" in description["sources"] 31 | -------------------------------------------------------------------------------- /tests/test_image_onnx_embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | 4 | import numpy as np 5 | import pytest 6 | import requests 7 | from PIL import Image 8 | 9 | from fastembed import ImageEmbedding 10 | from tests.config import TEST_MISC_DIR 11 | from tests.utils import delete_model_cache, should_test_model 12 | 13 | CANONICAL_VECTOR_VALUES = { 14 | "Qdrant/clip-ViT-B-32-vision": np.array([-0.0098, 0.0128, -0.0274, 0.002, -0.0059]), 15 | "Qdrant/resnet50-onnx": np.array( 16 | [0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.01046245, 0.01171397, 0.00705971, 0.0] 17 | ), 18 | "Qdrant/Unicom-ViT-B-16": np.array( 19 | [0.0170, -0.0361, 0.0125, -0.0428, -0.0232, 0.0232, -0.0602, -0.0333, 0.0155, 0.0497] 20 | ), 21 | "Qdrant/Unicom-ViT-B-32": np.array( 22 | [0.0418, 0.0550, 0.0003, 0.0253, -0.0185, 0.0016, -0.0368, -0.0402, -0.0891, -0.0186] 23 | ), 24 | "jinaai/jina-clip-v1": np.array( 25 | [-0.029, 0.0216, 0.0396, 0.0283, -0.0023, 0.0151, 0.011, -0.0235, 0.0251, -0.0343] 26 | ), 27 | } 28 | 29 | 30 | @pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"]) 31 | def test_embedding(model_name: str) -> None: 32 | is_ci = os.getenv("CI") 33 | is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" 34 | 35 | for model_desc in ImageEmbedding._list_supported_models(): 36 | if not should_test_model(model_desc, model_name, is_ci, is_manual): 37 | continue 38 | 39 | dim = model_desc.dim 40 | 41 | model = ImageEmbedding(model_name=model_desc.model) 42 | 43 | images = [ 44 | TEST_MISC_DIR / "image.jpeg", 45 | str(TEST_MISC_DIR / "small_image.jpeg"), 46 | Image.open((TEST_MISC_DIR / "small_image.jpeg")), 47 | Image.open(BytesIO(requests.get("https://qdrant.tech/img/logo.png").content)), 48 | ] 49 | embeddings = list(model.embed(images)) 50 | embeddings = np.stack(embeddings, axis=0) 51 | assert embeddings.shape == (len(images), dim) 52 | 53 | canonical_vector = CANONICAL_VECTOR_VALUES[model_desc.model] 54 | 55 | assert np.allclose( 56 | embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3 57 | ), model_desc.model 58 | 59 | assert np.allclose(embeddings[1], embeddings[2]), model_desc.model 60 | 61 | if is_ci: 62 | delete_model_cache(model.model._model_dir) 63 | 64 | 65 | @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")]) 66 | def test_batch_embedding(n_dims: int, model_name: str) -> None: 67 | is_ci = os.getenv("CI") 68 | model = ImageEmbedding(model_name=model_name) 69 | n_images = 32 70 | test_images = [ 71 | TEST_MISC_DIR / "image.jpeg", 72 | str(TEST_MISC_DIR / "small_image.jpeg"), 73 | Image.open(TEST_MISC_DIR / "small_image.jpeg"), 74 | ] 75 | images = test_images * n_images 76 | 77 | embeddings = list(model.embed(images, batch_size=10)) 78 | embeddings = np.stack(embeddings, axis=0) 79 | assert np.allclose(embeddings[1], embeddings[2]) 80 | 81 | canonical_vector = CANONICAL_VECTOR_VALUES[model_name] 82 | 83 | assert embeddings.shape == (len(test_images) * n_images, n_dims) 84 | assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3) 85 | if is_ci: 86 | delete_model_cache(model.model._model_dir) 87 | 88 | 89 | @pytest.mark.parametrize("n_dims,model_name", [(512, "Qdrant/clip-ViT-B-32-vision")]) 90 | def test_parallel_processing(n_dims: int, model_name: str) -> None: 91 | is_ci = os.getenv("CI") 92 | model = ImageEmbedding(model_name=model_name) 93 | 94 | n_images = 32 95 | test_images = [ 96 | TEST_MISC_DIR / "image.jpeg", 97 | str(TEST_MISC_DIR / "small_image.jpeg"), 98 | Image.open(TEST_MISC_DIR / "small_image.jpeg"), 99 | ] 100 | images = test_images * n_images 101 | embeddings = list(model.embed(images, batch_size=10, parallel=2)) 102 | embeddings = np.stack(embeddings, axis=0) 103 | 104 | embeddings_2 = list(model.embed(images, batch_size=10, parallel=None)) 105 | embeddings_2 = np.stack(embeddings_2, axis=0) 106 | 107 | embeddings_3 = list(model.embed(images, batch_size=10, parallel=0)) 108 | embeddings_3 = np.stack(embeddings_3, axis=0) 109 | 110 | assert embeddings.shape == (n_images * 
len(test_images), n_dims) 111 | assert np.allclose(embeddings, embeddings_2, atol=1e-3) 112 | assert np.allclose(embeddings, embeddings_3, atol=1e-3) 113 | if is_ci: 114 | delete_model_cache(model.model._model_dir) 115 | 116 | 117 | @pytest.mark.parametrize("model_name", ["Qdrant/clip-ViT-B-32-vision"]) 118 | def test_lazy_load(model_name: str) -> None: 119 | is_ci = os.getenv("CI") 120 | model = ImageEmbedding(model_name=model_name, lazy_load=True) 121 | assert not hasattr(model.model, "model") 122 | images = [ 123 | TEST_MISC_DIR / "image.jpeg", 124 | str(TEST_MISC_DIR / "small_image.jpeg"), 125 | ] 126 | list(model.embed(images)) 127 | assert hasattr(model.model, "model") 128 | if is_ci: 129 | delete_model_cache(model.model._model_dir) 130 | 131 | 132 | def test_get_embedding_size() -> None: 133 | assert ImageEmbedding.get_embedding_size(model_name="Qdrant/clip-ViT-B-32-vision") == 512 134 | assert ImageEmbedding.get_embedding_size(model_name="Qdrant/clip-vit-b-32-vision") == 512 135 | 136 | 137 | def test_embedding_size() -> None: 138 | is_ci = os.getenv("CI") 139 | model_name = "Qdrant/clip-ViT-B-32-vision" 140 | model = ImageEmbedding(model_name=model_name, lazy_load=True) 141 | assert model.embedding_size == 512 142 | 143 | model_name = "Qdrant/clip-vit-b-32-vision" 144 | model = ImageEmbedding(model_name=model_name, lazy_load=True) 145 | assert model.embedding_size == 512 146 | if is_ci: 147 | delete_model_cache(model.model._model_dir) 148 | -------------------------------------------------------------------------------- /tests/test_late_interaction_multimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from PIL import Image 5 | import numpy as np 6 | 7 | from fastembed import LateInteractionMultimodalEmbedding 8 | from tests.config import TEST_MISC_DIR 9 | 10 | 11 | # vectors are abridged and rounded for brevity 12 | CANONICAL_IMAGE_VALUES = { 13 | "Qdrant/colpali-v1.3-fp16": np.array( 14 | [ 15 | [-0.0345, -0.022, 0.0567, -0.0518, -0.0782, 0.1714, -0.1738], 16 | [-0.1181, -0.099, 0.0268, 0.0774, 0.0228, 0.0563, -0.1021], 17 | [-0.117, -0.0683, 0.0371, 0.0921, 0.0107, 0.0659, -0.0666], 18 | [-0.1393, -0.0948, 0.037, 0.0951, -0.0126, 0.0678, -0.087], 19 | [-0.0957, -0.081, 0.0404, 0.052, 0.0409, 0.0335, -0.064], 20 | [-0.0626, -0.0445, 0.056, 0.0592, -0.0229, 0.0409, -0.0301], 21 | [-0.1299, -0.0691, 0.1097, 0.0728, 0.0123, 0.0519, 0.0122], 22 | ] 23 | ), 24 | } 25 | 26 | CANONICAL_QUERY_VALUES = { 27 | "Qdrant/colpali-v1.3-fp16": np.array( 28 | [ 29 | [-0.0023, 0.1477, 0.1594, 0.046, -0.0196, 0.0554, 0.1567], 30 | [-0.0139, -0.0057, 0.0932, 0.0052, -0.0678, 0.0131, 0.0537], 31 | [0.0054, 0.0364, 0.2078, -0.074, 0.0355, 0.061, 0.1593], 32 | [-0.0076, -0.0154, 0.2266, 0.0103, 0.0089, -0.024, 0.098], 33 | [-0.0274, 0.0098, 0.2106, -0.0634, 0.0616, -0.0021, 0.0708], 34 | [0.0074, 0.0025, 0.1631, -0.0802, 0.0418, -0.0219, 0.1022], 35 | [-0.0165, -0.0106, 0.1672, -0.0768, 0.0389, -0.0038, 0.1137], 36 | ] 37 | ), 38 | } 39 | 40 | queries = ["hello world", "flag embedding"] 41 | images = [ 42 | TEST_MISC_DIR / "image.jpeg", 43 | str(TEST_MISC_DIR / "image.jpeg"), 44 | Image.open((TEST_MISC_DIR / "image.jpeg")), 45 | ] 46 | 47 | 48 | def test_batch_embedding(): 49 | if os.getenv("CI"): 50 | pytest.skip("Colpali is too large to test in CI") 51 | 52 | for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): 53 | print("evaluating", model_name) 54 | model = 
LateInteractionMultimodalEmbedding(model_name=model_name) 55 | result = list(model.embed_image(images, batch_size=2)) 56 | 57 | for value in result: 58 | token_num, abridged_dim = expected_result.shape 59 | assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=2e-3) 60 | 61 | 62 | def test_single_embedding(): 63 | if os.getenv("CI"): 64 | pytest.skip("Colpali is too large to test in CI") 65 | 66 | for model_name, expected_result in CANONICAL_IMAGE_VALUES.items(): 67 | print("evaluating", model_name) 68 | model = LateInteractionMultimodalEmbedding(model_name=model_name) 69 | result = next(iter(model.embed_image(images, batch_size=6))) 70 | token_num, abridged_dim = expected_result.shape 71 | assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) 72 | 73 | 74 | def test_single_embedding_query(): 75 | if os.getenv("CI"): 76 | pytest.skip("Colpali is too large to test in CI") 77 | 78 | for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): 79 | print("evaluating", model_name) 80 | model = LateInteractionMultimodalEmbedding(model_name=model_name) 81 | result = next(iter(model.embed_text(queries))) 82 | token_num, abridged_dim = expected_result.shape 83 | assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) 84 | 85 | 86 | def test_get_embedding_size(): 87 | model_name = "Qdrant/colpali-v1.3-fp16" 88 | assert LateInteractionMultimodalEmbedding.get_embedding_size(model_name) == 128 89 | 90 | model_name = "Qdrant/ColPali-v1.3-fp16" 91 | assert LateInteractionMultimodalEmbedding.get_embedding_size(model_name) == 128 92 | 93 | 94 | def test_embedding_size(): 95 | if os.getenv("CI"): 96 | pytest.skip("Colpali is too large to test in CI") 97 | model_name = "Qdrant/colpali-v1.3-fp16" 98 | model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) 99 | assert model.embedding_size == 128 100 | 101 | model_name = "Qdrant/ColPali-v1.3-fp16" 102 | model = LateInteractionMultimodalEmbedding(model_name=model_name, lazy_load=True) 103 | assert model.embedding_size == 128 104 | -------------------------------------------------------------------------------- /tests/test_text_cross_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from fastembed.rerank.cross_encoder import TextCrossEncoder 7 | from tests.utils import delete_model_cache, should_test_model 8 | 9 | CANONICAL_SCORE_VALUES = { 10 | "Xenova/ms-marco-MiniLM-L-6-v2": np.array([8.500708, -2.541011]), 11 | "Xenova/ms-marco-MiniLM-L-12-v2": np.array([9.330912, -2.0380247]), 12 | "BAAI/bge-reranker-base": np.array([6.15733337, -3.65939403]), 13 | "jinaai/jina-reranker-v1-tiny-en": np.array([2.5911, 0.1122]), 14 | "jinaai/jina-reranker-v1-turbo-en": np.array([1.8295, -2.8908]), 15 | "jinaai/jina-reranker-v2-base-multilingual": np.array([1.6533, -1.6455]), 16 | } 17 | 18 | 19 | @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) 20 | def test_rerank(model_name: str) -> None: 21 | is_ci = os.getenv("CI") 22 | is_manual = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch" 23 | 24 | for model_desc in TextCrossEncoder._list_supported_models(): 25 | if not should_test_model(model_desc, model_name, is_ci, is_manual): 26 | continue 27 | 28 | model = TextCrossEncoder(model_name=model_name) 29 | 30 | query = "What is the capital of France?" 
31 | documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] 32 | scores = np.array(list(model.rerank(query, documents))) 33 | 34 | pairs = [(query, doc) for doc in documents] 35 | scores2 = np.array(list(model.rerank_pairs(pairs))) 36 | assert np.allclose( 37 | scores, scores2, atol=1e-5 38 | ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" 39 | 40 | canonical_scores = CANONICAL_SCORE_VALUES[model_name] 41 | assert np.allclose( 42 | scores, canonical_scores, atol=1e-3 43 | ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" 44 | if is_ci: 45 | delete_model_cache(model.model._model_dir) 46 | 47 | 48 | @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) 49 | def test_batch_rerank(model_name: str) -> None: 50 | is_ci = os.getenv("CI") 51 | 52 | model = TextCrossEncoder(model_name=model_name) 53 | 54 | query = "What is the capital of France?" 55 | documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 50 56 | scores = np.array(list(model.rerank(query, documents, batch_size=10))) 57 | 58 | pairs = [(query, doc) for doc in documents] 59 | scores2 = np.array(list(model.rerank_pairs(pairs))) 60 | assert np.allclose( 61 | scores, scores2, atol=1e-5 62 | ), f"Model: {model_name}, Scores: {scores}, Scores2: {scores2}" 63 | 64 | canonical_scores = np.tile(CANONICAL_SCORE_VALUES[model_name], 50) 65 | 66 | assert scores.shape == canonical_scores.shape, f"Unexpected shape for model {model_name}" 67 | assert np.allclose( 68 | scores, canonical_scores, atol=1e-3 69 | ), f"Model: {model_name}, Scores: {scores}, Expected: {canonical_scores}" 70 | if is_ci: 71 | delete_model_cache(model.model._model_dir) 72 | 73 | 74 | @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) 75 | def test_lazy_load(model_name: str) -> None: 76 | is_ci = os.getenv("CI") 77 | model = TextCrossEncoder(model_name=model_name, lazy_load=True) 78 | assert not hasattr(model.model, "model") 79 | query = "What is the capital of France?" 80 | documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] 81 | list(model.rerank(query, documents)) 82 | assert hasattr(model.model, "model") 83 | 84 | if is_ci: 85 | delete_model_cache(model.model._model_dir) 86 | 87 | 88 | @pytest.mark.parametrize("model_name", ["Xenova/ms-marco-MiniLM-L-6-v2"]) 89 | def test_rerank_pairs_parallel(model_name: str) -> None: 90 | is_ci = os.getenv("CI") 91 | 92 | model = TextCrossEncoder(model_name=model_name) 93 | query = "What is the capital of France?" 
94 | documents = ["Paris is the capital of France.", "Berlin is the capital of Germany."] * 10 95 | pairs = [(query, doc) for doc in documents] 96 | scores_parallel = np.array(list(model.rerank_pairs(pairs, parallel=2, batch_size=10))) 97 | scores_sequential = np.array(list(model.rerank_pairs(pairs, batch_size=10))) 98 | assert np.allclose( 99 | scores_parallel, scores_sequential, atol=1e-5 100 | ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Scores (Sequential): {scores_sequential}" 101 | canonical_scores = CANONICAL_SCORE_VALUES[model_name] 102 | assert np.allclose( 103 | scores_parallel[: len(canonical_scores)], canonical_scores, atol=1e-3 104 | ), f"Model: {model_name}, Scores (Parallel): {scores_parallel}, Expected: {canonical_scores}" 105 | if is_ci: 106 | delete_model_cache(model.model._model_dir) 107 | -------------------------------------------------------------------------------- /tests/type_stub.py: -------------------------------------------------------------------------------- 1 | from fastembed import TextEmbedding, LateInteractionTextEmbedding, SparseTextEmbedding 2 | from fastembed.sparse.bm25 import Bm25 3 | from fastembed.rerank.cross_encoder import TextCrossEncoder 4 | 5 | 6 | text_embedder = TextEmbedding(cache_dir="models") 7 | late_interaction_embedder = LateInteractionTextEmbedding(model_name="", cache_dir="models") 8 | reranker = TextCrossEncoder(model_name="", cache_dir="models") 9 | sparse_embedder = SparseTextEmbedding(model_name="", cache_dir="models") 10 | bm25_embedder = Bm25( 11 | model_name="", 12 | k=1.0, 13 | b=1.0, 14 | avg_len=1.0, 15 | language="", 16 | token_max_length=1, 17 | disable_stemmer=False, 18 | specific_model_path="models", 19 | ) 20 | 21 | text_embedder.list_supported_models() 22 | text_embedder.embed(documents=[""], batch_size=1, parallel=1) 23 | text_embedder.embed(documents="", parallel=None, task_id=1) 24 | text_embedder.query_embed(query=[""], batch_size=1, parallel=1) 25 | text_embedder.query_embed(query="", parallel=None) 26 | text_embedder.passage_embed(texts=[""], batch_size=1, parallel=1) 27 | text_embedder.passage_embed(texts=[""], parallel=None) 28 | 29 | late_interaction_embedder.list_supported_models() 30 | late_interaction_embedder.embed(documents=[""], batch_size=1, parallel=1) 31 | late_interaction_embedder.embed(documents="", parallel=None) 32 | late_interaction_embedder.query_embed(query=[""], batch_size=1, parallel=1) 33 | late_interaction_embedder.query_embed(query="", parallel=None) 34 | late_interaction_embedder.passage_embed(texts=[""], batch_size=1, parallel=1) 35 | late_interaction_embedder.passage_embed(texts=[""], parallel=None) 36 | 37 | reranker.list_supported_models() 38 | reranker.rerank(query="", documents=[""], batch_size=1, parallel=1) 39 | reranker.rerank(query="", documents=[""], parallel=None) 40 | reranker.rerank_pairs(pairs=[("", "")], batch_size=1, parallel=1) 41 | reranker.rerank_pairs(pairs=[("", "")], parallel=None) 42 | 43 | sparse_embedder.list_supported_models() 44 | sparse_embedder.embed(documents=[""], batch_size=1, parallel=1) 45 | sparse_embedder.embed(documents="", batch_size=1, parallel=None) 46 | sparse_embedder.query_embed(query=[""], batch_size=1, parallel=1) 47 | sparse_embedder.query_embed(query="", batch_size=1, parallel=None) 48 | sparse_embedder.passage_embed(texts=[""], batch_size=1, parallel=1) 49 | sparse_embedder.passage_embed(texts=[""], batch_size=1, parallel=None) 50 | 51 | bm25_embedder.list_supported_models() 52 | bm25_embedder.embed(documents=[""], 
batch_size=1, parallel=1) 53 | bm25_embedder.embed(documents="", batch_size=1, parallel=None) 54 | bm25_embedder.query_embed(query=[""], batch_size=1, parallel=1) 55 | bm25_embedder.query_embed(query="", batch_size=1, parallel=None) 56 | bm25_embedder.raw_embed(documents=[""]) 57 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import traceback 3 | 4 | from pathlib import Path 5 | from types import TracebackType 6 | from typing import Union, Callable, Any, Type, Optional 7 | 8 | from fastembed.common.model_description import BaseModelDescription 9 | 10 | 11 | def delete_model_cache(model_dir: Union[str, Path]) -> None: 12 | """Delete the model cache directory. 13 | 14 | If a model was downloaded from the HuggingFace model hub, then _model_dir points to the snapshots directory; removing 15 | it alone won't free the space, because the data lives in the blobs directory. 16 | If a model was downloaded from GCS, then we can simply remove model_dir. 17 | 18 | Args: 19 | model_dir (Union[str, Path]): The path to the model cache directory. 20 | """ 21 | 22 | def on_error( 23 | func: Callable[..., Any], 24 | path: str, 25 | exc_info: tuple[Type[BaseException], BaseException, TracebackType], 26 | ) -> None: 27 | print("Failed to remove: ", path) 28 | print("Exception: ", exc_info) 29 | traceback.print_exception(*exc_info) 30 | 31 | if isinstance(model_dir, str): 32 | model_dir = Path(model_dir) 33 | 34 | if model_dir.parent.parent.name.startswith("models--"): 35 | model_dir = model_dir.parent.parent 36 | 37 | if model_dir.exists(): 38 | # todo: PermissionDenied is raised on blobs removal in Windows, with blobs > 2GB 39 | shutil.rmtree(model_dir, onerror=on_error) 40 | 41 | 42 | def should_test_model( 43 | model_desc: BaseModelDescription, 44 | autotest_model_name: str, 45 | is_ci: Optional[str], 46 | is_manual: bool, 47 | ) -> bool: 48 | """Determine if a model should be tested based on the environment. 49 | 50 | Tests can be run either in ci or locally. 51 | Testing all models each time in ci is too long. 52 | The testing schemes in ci and on a local machine are different, therefore there are 3 possible scenarios: 53 | 1) Run lightweight tests in ci: 54 | - test only one model that has been manually chosen as a representative for a certain class family 55 | 2) Run heavyweight (manual) tests in ci: 56 | - test all models 57 | Running tests in ci each time is too expensive; however, it's fine to run them once with a manual dispatch 58 | 3) Run tests locally: 59 | - test all models that are not too heavy, since network speed might be a bottleneck 60 | 61 | """ 62 | if not is_ci: 63 | if model_desc.size_in_GB > 1: 64 | return False 65 | elif not is_manual and model_desc.model != autotest_model_name: 66 | return False 67 | return True 68 | --------------------------------------------------------------------------------