├── .envrc ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── .releaserc ├── Cargo.toml ├── LICENSE ├── README.md ├── flake.lock ├── flake.nix ├── src ├── common.rs ├── image_embedding │ ├── impl.rs │ ├── init.rs │ ├── mod.rs │ └── utils.rs ├── lib.rs ├── models │ ├── image_embedding.rs │ ├── mod.rs │ ├── model_info.rs │ ├── quantization.rs │ ├── reranking.rs │ ├── sparse.rs │ └── text_embedding.rs ├── output │ ├── embedding_output.rs │ ├── mod.rs │ └── output_precedence.rs ├── pooling.rs ├── reranking │ ├── impl.rs │ ├── init.rs │ └── mod.rs ├── sparse_text_embedding │ ├── impl.rs │ ├── init.rs │ └── mod.rs └── text_embedding │ ├── impl.rs │ ├── init.rs │ ├── mod.rs │ └── output.rs └── tests ├── assets ├── image_0.png ├── image_1.png └── sample_text.txt ├── embeddings.rs └── optimum_cli_export.rs /.envrc: -------------------------------------------------------------------------------- 1 | use flake 2 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Semantic Release 2 | on: 3 | workflow_dispatch: 4 | 5 | env: 6 | CARGO_TERM_COLOR: always 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: "☁️ checkout repository" 12 | uses: actions/checkout@v4 13 | with: 14 | fetch-depth: 0 15 | 16 | - uses: actions/setup-node@v4 17 | with: 18 | node-version: 20 19 | 20 | - name: "🔧 setup Bun" 21 | uses: oven-sh/setup-bun@v1 22 | 23 | - name: "📦 install dependencies" 24 | run: bun install -D @semantic-release/git conventional-changelog-conventionalcommits semantic-release-cargo 25 | 26 | - name: Get Author Name and Email 27 | run: | 28 | AUTHOR_NAME=$(git log -1 --pretty=format:%an ${{ github.sha }}) 29 | AUTHOR_EMAIL=$(git log -1 --pretty=format:%ae ${{ github.sha }}) 30 | echo "AUTHOR_NAME=$AUTHOR_NAME" >> $GITHUB_OUTPUT 31 | echo "AUTHOR_EMAIL=$AUTHOR_EMAIL" >> $GITHUB_OUTPUT 32 | id: author_info 33 | 34 | - name: "Semantic release🚀" 35 | id: release 36 | env: 37 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 38 | CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_TOKEN }} 39 | GIT_COMMITTER_NAME: "github-actions[bot]" 40 | GIT_COMMITTER_EMAIL: "41898282+github-actions[bot]@users.noreply.github.com" 41 | GIT_AUTHOR_NAME: ${{ steps.author_info.outputs.AUTHOR_NAME }} 42 | GIT_AUTHOR_EMAIL: ${{ steps.author_info.outputs.AUTHOR_EMAIL }} 43 | run: | 44 | bun x semantic-release 45 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: "Cargo Tests" 2 | on: 3 | pull_request: 4 | schedule: 5 | - cron: 0 0 * * * 6 | 7 | env: 8 | CARGO_TERM_COLOR: always 9 | RUSTFLAGS: "-Dwarnings" 10 | ONNX_VERSION: v1.20.1 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Restore Builds 20 | id: cache-build-restore 21 | uses: actions/cache/restore@v4 22 | with: 23 | key: '${{ runner.os }}-onnxruntime-${{ env.ONNX_VERSION }}' 24 | path: | 25 | onnxruntime/build/Linux/Release/ 26 | 27 | - name: Compile ONNX Runtime for Linux 28 | if: steps.cache-build-restore.outputs.cache-hit != 'true' 29 | run: | 30 | echo Cloning ONNX Runtime repository... 
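# Shallow-clone only the pinned $ONNX_VERSION tag (single branch, depth 1) to keep the checkout small.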
31 | git clone https://github.com/microsoft/onnxruntime --recursive --branch $ONNX_VERSION --single-branch --depth 1 32 | cd onnxruntime 33 | ./build.sh --update --build --config Release --parallel --compile_no_warning_as_error --skip_submodule_sync 34 | cd .. 35 | 36 | - name: Cargo Test With Release Build 37 | run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --release --no-default-features --features hf-hub-native-tls 38 | 39 | - name: Cargo Test Offline 40 | run: ORT_LIB_LOCATION="$(pwd)/onnxruntime/build/Linux/Release" cargo test --no-default-features 41 | 42 | - name: Cargo Clippy 43 | run: cargo clippy 44 | 45 | - name: Cargo FMT 46 | run: cargo fmt --all -- --check 47 | 48 | - name: Always Save Cache 49 | id: cache-build-save 50 | if: always() && steps.cache-build-restore.outputs.cache-hit != 'true' 51 | uses: actions/cache/save@v4 52 | with: 53 | key: '${{ steps.cache-build-restore.outputs.cache-primary-key }}' 54 | path: | 55 | onnxruntime/build/Linux/Release/ 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## File system 2 | .DS_Store 3 | desktop.ini 4 | 5 | ## Editor 6 | *.swp 7 | *.swo 8 | Session.vim 9 | .cproject 10 | .idea 11 | *.iml 12 | .vscode 13 | .project 14 | .favorites.json 15 | .settings/ 16 | .vs/ 17 | .zed/ 18 | 19 | ## Tool 20 | .valgrindrc 21 | .cargo 22 | # Included because it is part of the test case 23 | !/tests/run-make/thumb-none-qemu/example/.cargo 24 | 25 | ## Configuration 26 | /config.toml 27 | /Makefile 28 | config.mk 29 | config.stamp 30 | no_llvm_build 31 | 32 | ## Build 33 | /dl/ 34 | /doc/ 35 | /inst/ 36 | /llvm/ 37 | /mingw-build/ 38 | build/ 39 | !/compiler/rustc_mir_build/src/build/ 40 | /build-rust-analyzer/ 41 | /dist/ 42 | /unicode-downloads 43 | /target 44 | /src/bootstrap/target 45 | /src/tools/x/target 46 | # Created by default with `src/ci/docker/run.sh` 47 | /obj/ 48 | 49 | ## Temporary files 50 | *~ 51 | \#* 52 | \#*\# 53 | .#* 54 | 55 | ## Tags 56 | tags 57 | tags.* 58 | TAGS 59 | TAGS.* 60 | 61 | ## Python 62 | __pycache__/ 63 | *.py[cod] 64 | *$py.class 65 | 66 | ## Node 67 | node_modules 68 | 69 | ## Rustdoc GUI tests 70 | tests/rustdoc-gui/src/**.lock 71 | 72 | ## Model cache 73 | .fastembed_cache 74 | 75 | ## Rust files 76 | main.rs 77 | Cargo.lock 78 | 79 | ## Nix 80 | /.direnv -------------------------------------------------------------------------------- /.releaserc: -------------------------------------------------------------------------------- 1 | { 2 | "branches": [ 3 | "main", 4 | "next", 5 | { 6 | "name": "beta", 7 | "prerelease": true 8 | } 9 | ], 10 | "plugins": [ 11 | [ 12 | "@semantic-release/commit-analyzer", 13 | { 14 | "preset": "conventionalcommits", 15 | "releaseRules": [ 16 | { 17 | "breaking": true, 18 | "release": "major" 19 | }, 20 | { 21 | "type": "feat", 22 | "release": "minor" 23 | }, 24 | { 25 | "type": "fix", 26 | "release": "patch" 27 | }, 28 | { 29 | "type": "perf", 30 | "release": "patch" 31 | }, 32 | { 33 | "type": "revert", 34 | "release": "patch" 35 | }, 36 | { 37 | "type": "docs", 38 | "release": "patch" 39 | }, 40 | { 41 | "type": "style", 42 | "release": "patch" 43 | }, 44 | { 45 | "type": "refactor", 46 | "release": "patch" 47 | }, 48 | { 49 | "type": "test", 50 | "release": "patch" 51 | }, 52 | { 53 | "type": "build", 54 | "release": "patch" 55 | }, 56 | { 57 | "type": "ci", 58 | "release": "patch" 59 | }, 60 | { 61 | "type": "chore", 
62 | "release": "patch" 63 | } 64 | ] 65 | } 66 | ], 67 | "@semantic-release/release-notes-generator", 68 | "@semantic-release/github", 69 | [ 70 | "semantic-release-cargo", 71 | { 72 | "allFeatures": true, 73 | "check": true, 74 | "publishArgs": [ 75 | "--no-verify" 76 | ] 77 | } 78 | ], 79 | [ 80 | "@semantic-release/git", 81 | { 82 | "assets": [ 83 | "Cargo.toml" 84 | ], 85 | "message": "chore(release): ${nextRelease.version} [skip ci]\n\n${nextRelease.notes}" 86 | } 87 | ], 88 | [ 89 | "@semantic-release/release-notes-generator", 90 | { 91 | "preset": "conventionalcommits", 92 | "parserOpts": { 93 | "noteKeywords": [ 94 | "BREAKING CHANGE", 95 | "BREAKING CHANGES", 96 | "BREAKING" 97 | ] 98 | }, 99 | "writerOpts": { 100 | "commitsSort": [ 101 | "subject", 102 | "scope" 103 | ] 104 | }, 105 | "presetConfig": { 106 | "types": [ 107 | { 108 | "type": "feat", 109 | "section": "🍕 Features" 110 | }, 111 | { 112 | "type": "feature", 113 | "section": "🍕 Features" 114 | }, 115 | { 116 | "type": "fix", 117 | "section": "🐛 Bug Fixes" 118 | }, 119 | { 120 | "type": "perf", 121 | "section": "🔥 Performance Improvements" 122 | }, 123 | { 124 | "type": "revert", 125 | "section": "⏩ Reverts" 126 | }, 127 | { 128 | "type": "docs", 129 | "section": "📝 Documentation" 130 | }, 131 | { 132 | "type": "style", 133 | "section": "🎨 Styles" 134 | }, 135 | { 136 | "type": "refactor", 137 | "section": "🧑‍💻 Code Refactoring" 138 | }, 139 | { 140 | "type": "test", 141 | "section": "✅ Tests" 142 | }, 143 | { 144 | "type": "build", 145 | "section": "🤖 Build System" 146 | }, 147 | { 148 | "type": "ci", 149 | "section": "🔁 Continuous Integration" 150 | }, 151 | { 152 | "type": "chore", 153 | "section": "🧹 Chores" 154 | } 155 | ] 156 | } 157 | } 158 | ] 159 | ] 160 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "fastembed" 3 | version = "4.9.1" 4 | edition = "2021" 5 | description = "Library for generating vector embeddings, reranking locally." 
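# Note: the version field above is bumped automatically by semantic-release-cargo, as configured in .releaserc.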
6 | license = "Apache-2.0" 7 | authors = [ 8 | "Anush008 ", 9 | "Josh Niemelä ", 10 | "GrisiaEvy ", 11 | "George MacKerron ", 12 | "Timon Vonk ", 13 | "Luya Wang ", 14 | "Tri ", 15 | "Denny Wong ", 16 | "Alex Rozgo ", 17 | ] 18 | documentation = "https://docs.rs/fastembed" 19 | repository = "https://github.com/Anush008/fastembed-rs" 20 | homepage = "https://crates.io/crates/fastembed" 21 | 22 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 23 | 24 | [dependencies] 25 | anyhow = { version = "1" } 26 | hf-hub = { version = "0.4.1", default-features = false, optional = true } 27 | image = "0.25.2" 28 | ndarray = { version = "0.16", default-features = false } 29 | ort = { version = "=2.0.0-rc.9", default-features = false, features = [ 30 | "ndarray", 31 | ] } 32 | rayon = { version = "1.10", default-features = false } 33 | serde_json = { version = "1" } 34 | tokenizers = { version = "0.21", default-features = false, features = ["onig"] } 35 | ort-sys = { version = "=2.0.0-rc.9", default-features = false } 36 | 37 | [features] 38 | default = ["ort-download-binaries", "hf-hub-native-tls"] 39 | 40 | hf-hub = ["dep:hf-hub", "hf-hub?/ureq"] 41 | hf-hub-native-tls = ["hf-hub", "hf-hub?/native-tls"] 42 | hf-hub-rustls-tls = ["hf-hub", "hf-hub?/rustls-tls"] 43 | 44 | ort-download-binaries = ["ort/download-binaries"] 45 | ort-load-dynamic = ["ort/load-dynamic"] 46 | 47 | # This feature does not change any code, but is used to limit tests if 48 | # the user does not have `optimum-cli` or even python installed. 49 | optimum-cli = [] 50 | 51 | # For compatibility recommend using hf-hub-native-tls 52 | online = ["hf-hub-native-tls"] 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # FastEmbed-rs 🦀 3 | 4 | ### Rust library for generating vector embeddings, reranking locally! 5 | 6 | [badges: Crates.io · MIT Licensed · Semantic release] 7 |
8 | 9 | ## 🍕 Features 10 | 11 | - Supports synchronous usage. No dependency on Tokio. 12 | - Uses [@pykeio/ort](https://github.com/pykeio/ort) for performant ONNX inference. 13 | - Uses [@huggingface/tokenizers](https://github.com/huggingface/tokenizers) for fast encodings. 14 | - Supports batch embeddings generation with parallelism using [@rayon-rs/rayon](https://github.com/rayon-rs/rayon). 15 | 16 | ## 🔍 Not looking for Rust? 17 | 18 | - Python 🐍: [fastembed](https://github.com/qdrant/fastembed) 19 | - Go 🐳: [fastembed-go](https://github.com/Anush008/fastembed-go) 20 | - JavaScript 🌐: [fastembed-js](https://github.com/Anush008/fastembed-js) 21 | 22 | ## 🤖 Models 23 | 24 | ### Text Embedding 25 | 26 | - [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default 27 | - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) 28 | - [**mixedbread-ai/mxbai-embed-large-v1**](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) 29 | - [**Qdrant/clip-ViT-B-32-text**](https://huggingface.co/Qdrant/clip-ViT-B-32-text) - pairs with `clip-ViT-B-32-vision` for image-to-text search 30 | - [**BAAI/bge-large-en-v1.5**](https://huggingface.co/BAAI/bge-large-en-v1.5) 31 | - [**BAAI/bge-small-zh-v1.5**](https://huggingface.co/BAAI/bge-small-zh-v1.5) 32 | - [**BAAI/bge-large-zh-v1.5**](https://huggingface.co/BAAI/bge-large-zh-v1.5) 33 | - [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5) 34 | - [**sentence-transformers/all-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) 35 | - [**sentence-transformers/paraphrase-MiniLM-L12-v2**](https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L12-v2) 36 | - [**sentence-transformers/paraphrase-multilingual-mpnet-base-v2**](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) 37 | - [**lightonai/ModernBERT-embed-large**](https://huggingface.co/lightonai/modernbert-embed-large) 38 | - [**nomic-ai/nomic-embed-text-v1**](https://huggingface.co/nomic-ai/nomic-embed-text-v1) 39 | - [**nomic-ai/nomic-embed-text-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) - pairs with `nomic-embed-vision-v1.5` for image-to-text search 40 | - [**intfloat/multilingual-e5-small**](https://huggingface.co/intfloat/multilingual-e5-small) 41 | - [**intfloat/multilingual-e5-base**](https://huggingface.co/intfloat/multilingual-e5-base) 42 | - [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large) 43 | - [**Alibaba-NLP/gte-base-en-v1.5**](https://huggingface.co/Alibaba-NLP/gte-base-en-v1.5) 44 | - [**Alibaba-NLP/gte-large-en-v1.5**](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5) 45 | 46 | ### Sparse Text Embedding 47 | 48 | - [**prithivida/Splade_PP_en_v1**](https://huggingface.co/prithivida/Splade_PP_en_v1) - Default 49 | 50 | ### Image Embedding 51 | 52 | - [**Qdrant/clip-ViT-B-32-vision**](https://huggingface.co/Qdrant/clip-ViT-B-32-vision) - Default 53 | - [**Qdrant/resnet50-onnx**](https://huggingface.co/Qdrant/resnet50-onnx) 54 | - [**Qdrant/Unicom-ViT-B-16**](https://huggingface.co/Qdrant/Unicom-ViT-B-16) 55 | - [**Qdrant/Unicom-ViT-B-32**](https://huggingface.co/Qdrant/Unicom-ViT-B-32) 56 | - [**nomic-ai/nomic-embed-vision-v1.5**](https://huggingface.co/nomic-ai/nomic-embed-vision-v1.5) 57 | 58 | ### Reranking 59 | 60 | - [**BAAI/bge-reranker-base**](https://huggingface.co/BAAI/bge-reranker-base) - Default 61 | - 
[**BAAI/bge-reranker-v2-m3**](https://huggingface.co/BAAI/bge-reranker-v2-m3) 62 | - [**jinaai/jina-reranker-v1-turbo-en**](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en) 63 | - [**jinaai/jina-reranker-v2-base-multilingual**](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) 64 | 65 | ## 🚀 Installation 66 | 67 | Run the following command in your project directory: 68 | 69 | ```bash 70 | cargo add fastembed 71 | ``` 72 | 73 | Or add the following line to your Cargo.toml: 74 | 75 | ```toml 76 | [dependencies] 77 | fastembed = "4" 78 | ``` 79 | 80 | ## 📖 Usage 81 | 82 | ### Text Embeddings 83 | 84 | ```rust 85 | use fastembed::{TextEmbedding, InitOptions, EmbeddingModel}; 86 | 87 | // With default InitOptions 88 | let model = TextEmbedding::try_new(Default::default())?; 89 | 90 | // With custom InitOptions 91 | let model = TextEmbedding::try_new( 92 | InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true), 93 | )?; 94 | 95 | let documents = vec![ 96 | "passage: Hello, World!", 97 | "query: Hello, World!", 98 | "passage: This is an example passage.", 99 | // You can leave out the prefix, but it's recommended 100 | "fastembed-rs is licensed under Apache 2.0" 101 | ]; 102 | 103 | // Generate embeddings with the default batch size, 256 104 | let embeddings = model.embed(documents, None)?; 105 | 106 | println!("Embeddings length: {}", embeddings.len()); // -> Embeddings length: 4 107 | println!("Embedding dimension: {}", embeddings[0].len()); // -> Embedding dimension: 384 108 | 109 | ``` 110 | 111 | ### Image Embeddings 112 | 113 | ```rust 114 | use fastembed::{ImageEmbedding, ImageInitOptions, ImageEmbeddingModel}; 115 | 116 | // With default InitOptions 117 | let model = ImageEmbedding::try_new(Default::default())?; 118 | 119 | // With custom InitOptions 120 | let model = ImageEmbedding::try_new( 121 | ImageInitOptions::new(ImageEmbeddingModel::ClipVitB32).with_show_download_progress(true), 122 | )?; 123 | 124 | let images = vec!["assets/image_0.png", "assets/image_1.png"]; 125 | 126 | // Generate embeddings with the default batch size, 256 127 | let embeddings = model.embed(images, None)?; 128 | 129 | println!("Embeddings length: {}", embeddings.len()); // -> Embeddings length: 2 130 | println!("Embedding dimension: {}", embeddings[0].len()); // -> Embedding dimension: 512 131 | ``` 132 | 133 | ### Candidates Reranking 134 | 135 | ```rust 136 | use fastembed::{TextRerank, RerankInitOptions, RerankerModel}; 137 | 138 | let model = TextRerank::try_new( 139 | RerankInitOptions::new(RerankerModel::BGERerankerBase).with_show_download_progress(true), 140 | )?; 141 | 142 | let documents = vec![ 143 | "hi", 144 | "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear, is a bear species endemic to China.", 145 | "panda is animal", 146 | "i dont know", 147 | "kind of mammal", 148 | ]; 149 | 150 | // Rerank with the default batch size (256) and return document contents 151 | let results = model.rerank("what is panda?", documents, true, None)?; 152 | println!("Rerank result: {:?}", results); 153 | ``` 154 | 155 | Alternatively, local model files can be used for inference via the `try_new_from_user_defined(...)` methods of the respective structs (see the sketch below). 156 | 157 | ## ✊ Support 158 | 159 | To support the library, please donate to our primary upstream dependency, [`ort`](https://github.com/pykeio/ort?tab=readme-ov-file#-sponsor-ort) - The Rust wrapper for the ONNX runtime.
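A minimal sketch of that bring-your-own-model path, shown for `ImageEmbedding` (whose `try_new_from_user_defined` signature appears in `src/image_embedding/impl.rs`). The file paths are placeholders for your own exported ONNX model and preprocessor config:

```rust
use std::fs;
use fastembed::{ImageEmbedding, ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel};

// Placeholder paths: point these at a locally exported model.
let onnx_file = fs::read("path/to/model.onnx")?;
let preprocessor_file = fs::read("path/to/preprocessor_config.json")?;

let model = ImageEmbedding::try_new_from_user_defined(
    UserDefinedImageEmbeddingModel::new(onnx_file, preprocessor_file),
    ImageInitOptionsUserDefined::new(),
)?;

// Embedding generation then works exactly as with a downloaded model.
let embeddings = model.embed(vec!["assets/image_0.png"], None)?;
```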
160 | 161 | ## 📄 LICENSE 162 | 163 | Apache 2.0 © [2024](https://github.com/Anush008/fastembed-rs/blob/main/LICENSE) 164 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1731533236, 9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "ref": "main", 18 | "repo": "flake-utils", 19 | "type": "github" 20 | } 21 | }, 22 | "nixpkgs": { 23 | "locked": { 24 | "lastModified": 1740367490, 25 | "narHash": "sha256-WGaHVAjcrv+Cun7zPlI41SerRtfknGQap281+AakSAw=", 26 | "owner": "NixOS", 27 | "repo": "nixpkgs", 28 | "rev": "0196c0175e9191c474c26ab5548db27ef5d34b05", 29 | "type": "github" 30 | }, 31 | "original": { 32 | "owner": "NixOS", 33 | "ref": "nixos-unstable", 34 | "repo": "nixpkgs", 35 | "type": "github" 36 | } 37 | }, 38 | "root": { 39 | "inputs": { 40 | "flake-utils": "flake-utils", 41 | "nixpkgs": "nixpkgs" 42 | } 43 | }, 44 | "systems": { 45 | "locked": { 46 | "lastModified": 1681028828, 47 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 48 | "owner": "nix-systems", 49 | "repo": "default", 50 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 51 | "type": "github" 52 | }, 53 | "original": { 54 | "owner": "nix-systems", 55 | "repo": "default", 56 | "type": "github" 57 | } 58 | } 59 | }, 60 | "root": "root", 61 | "version": 7 62 | } 63 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | inputs = { 3 | nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; 4 | flake-utils.url = "github:numtide/flake-utils?ref=main"; 5 | }; 6 | 7 | outputs = inputs: 8 | inputs.flake-utils.lib.eachDefaultSystem (system: 9 | let 10 | pkgs = inputs.nixpkgs.legacyPackages.${system}; 11 | 12 | in { 13 | devShells.default = pkgs.mkShell { 14 | packages = (with pkgs; [ 15 | openssl 16 | pkg-config 17 | ]); 18 | }; 19 | }); 20 | } 21 | -------------------------------------------------------------------------------- /src/common.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | #[cfg(feature = "hf-hub")] 3 | use hf_hub::api::sync::{ApiBuilder, ApiRepo}; 4 | use std::io::Read; 5 | use std::{fs::File, path::PathBuf}; 6 | use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams}; 7 | 8 | const DEFAULT_CACHE_DIR: &str = ".fastembed_cache"; 9 | 10 | pub fn get_cache_dir() -> String { 11 | std::env::var("FASTEMBED_CACHE_DIR").unwrap_or(DEFAULT_CACHE_DIR.into()) 12 | } 13 | 14 | pub struct SparseEmbedding { 15 | pub indices: Vec<usize>, 16 | pub values: Vec<f32>, 17 | } 18 | 19 | /// Type alias for the embedding vector 20 | pub type Embedding = Vec<f32>; 21 | 22 | /// Type alias for the error type 23 | pub type Error = anyhow::Error; 24 | 25 | // Tokenizer files for "bring your own" models 26 | #[derive(Debug, Clone, PartialEq, Eq)] 27 | pub struct TokenizerFiles { 28 | pub tokenizer_file: Vec<u8>, 29 | pub config_file: Vec<u8>, 30 | pub special_tokens_map_file: Vec<u8>, 31 | pub tokenizer_config_file: Vec<u8>, 32 | } 33 | 34 | /// The procedure for loading tokenizer
files from the Hugging Face Hub is separated 35 | /// from the main load_tokenizer function (which is expecting bytes, from any source). 36 | #[cfg(feature = "hf-hub")] 37 | pub fn load_tokenizer_hf_hub(model_repo: ApiRepo, max_length: usize) -> Result<Tokenizer> { 38 | let tokenizer_files: TokenizerFiles = TokenizerFiles { 39 | tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json")?)?, 40 | config_file: read_file_to_bytes(&model_repo.get("config.json")?)?, 41 | special_tokens_map_file: read_file_to_bytes(&model_repo.get("special_tokens_map.json")?)?, 42 | 43 | tokenizer_config_file: read_file_to_bytes(&model_repo.get("tokenizer_config.json")?)?, 44 | }; 45 | 46 | load_tokenizer(tokenizer_files, max_length) 47 | } 48 | 49 | /// Function can be called directly from the try_new_from_user_defined function (providing file bytes) 50 | /// 51 | /// Or indirectly from the try_new function via load_tokenizer_hf_hub (converting HF files to bytes) 52 | pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Result<Tokenizer> { 53 | let base_error_message = 54 | "Error building TokenizerFiles for UserDefinedEmbeddingModel. Could not read {} file."; 55 | 56 | // Deserialize each tokenizer file 57 | let config: serde_json::Value = 58 | serde_json::from_slice(&tokenizer_files.config_file).map_err(|_| { 59 | std::io::Error::new( 60 | std::io::ErrorKind::InvalidData, 61 | base_error_message.replace("{}", "config.json"), 62 | ) 63 | })?; 64 | let special_tokens_map: serde_json::Value = 65 | serde_json::from_slice(&tokenizer_files.special_tokens_map_file).map_err(|_| { 66 | std::io::Error::new( 67 | std::io::ErrorKind::InvalidData, 68 | base_error_message.replace("{}", "special_tokens_map.json"), 69 | ) 70 | })?; 71 | let tokenizer_config: serde_json::Value = 72 | serde_json::from_slice(&tokenizer_files.tokenizer_config_file).map_err(|_| { 73 | std::io::Error::new( 74 | std::io::ErrorKind::InvalidData, 75 | base_error_message.replace("{}", "tokenizer_config.json"), 76 | ) 77 | })?; 78 | let mut tokenizer: tokenizers::Tokenizer = 79 | tokenizers::Tokenizer::from_bytes(tokenizer_files.tokenizer_file).map_err(|_| { 80 | std::io::Error::new( 81 | std::io::ErrorKind::InvalidData, 82 | base_error_message.replace("{}", "tokenizer.json"), 83 | ) 84 | })?; 85 | 86 | // For BGEBaseSmall, the model_max_length value is set to 1000000000000000019884624838656, which fits in an f64 87 | let model_max_length = tokenizer_config["model_max_length"] 88 | .as_f64() 89 | .expect("Error reading model_max_length from tokenizer_config.json") 90 | as f32; 91 | let max_length = max_length.min(model_max_length as usize); 92 | let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32; 93 | let pad_token = tokenizer_config["pad_token"] 94 | .as_str() 95 | .expect("Error reading pad_token from tokenizer_config.json") 96 | .into(); 97 | 98 | let mut tokenizer = tokenizer 99 | .with_padding(Some(PaddingParams { 100 | // TODO: the user should be able to choose the padding strategy 101 | strategy: PaddingStrategy::BatchLongest, 102 | pad_token, 103 | pad_id, 104 | ..Default::default() 105 | })) 106 | .with_truncation(Some(TruncationParams { 107 | max_length, 108 | ..Default::default() 109 | })) 110 | .map_err(anyhow::Error::msg)?
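// Note: with_padding/with_truncation return mutable references, so we clone to obtain an owned Tokenizer below.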
111 | .clone(); 112 | if let serde_json::Value::Object(root_object) = special_tokens_map { 113 | for (_, value) in root_object.iter() { 114 | if value.is_string() { 115 | tokenizer.add_special_tokens(&[AddedToken { 116 | content: value.as_str().unwrap().into(), 117 | special: true, 118 | ..Default::default() 119 | }]); 120 | } else if value.is_object() { 121 | tokenizer.add_special_tokens(&[AddedToken { 122 | content: value["content"].as_str().unwrap().into(), 123 | special: true, 124 | single_word: value["single_word"].as_bool().unwrap(), 125 | lstrip: value["lstrip"].as_bool().unwrap(), 126 | rstrip: value["rstrip"].as_bool().unwrap(), 127 | normalized: value["normalized"].as_bool().unwrap(), 128 | }]); 129 | } 130 | } 131 | } 132 | Ok(tokenizer.into()) 133 | } 134 | 135 | pub fn normalize(v: &[f32]) -> Vec<f32> { 136 | let norm = (v.iter().map(|val| val * val).sum::<f32>()).sqrt(); 137 | let epsilon = 1e-12; 138 | 139 | // We add the super-small epsilon to avoid dividing by zero 140 | v.iter().map(|&val| val / (norm + epsilon)).collect() 141 | } 142 | 143 | /// Public function to read a file to bytes. 144 | /// To be used when loading local model files. 145 | /// 146 | /// Could be used to read the onnx file from a local cache in order to constitute a UserDefinedEmbeddingModel. 147 | pub fn read_file_to_bytes(file: &PathBuf) -> Result<Vec<u8>> { 148 | let mut file = File::open(file)?; 149 | let file_size = file.metadata()?.len() as usize; 150 | let mut buffer = Vec::with_capacity(file_size); 151 | file.read_to_end(&mut buffer)?; 152 | Ok(buffer) 153 | } 154 | 155 | /// Pulls a model repo from Hugging Face. 156 | /// HF_HOME decides the location of the cache folder, and 157 | /// HF_ENDPOINT modifies the URL for the Hugging Face location. 158 | #[cfg(feature = "hf-hub")] 159 | pub fn pull_from_hf( 160 | model_name: String, 161 | default_cache_dir: PathBuf, 162 | show_download_progress: bool, 163 | ) -> anyhow::Result<ApiRepo> { 164 | use std::env; 165 | 166 | let cache_dir = env::var("HF_HOME") 167 | .map(PathBuf::from) 168 | .unwrap_or(default_cache_dir); 169 | 170 | let endpoint = env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string()); 171 | 172 | let api = ApiBuilder::new() 173 | .with_cache_dir(cache_dir) 174 | .with_endpoint(endpoint) 175 | .with_progress(show_download_progress) 176 | .build()?; 177 | 178 | let repo = api.model(model_name); 179 | Ok(repo) 180 | } 181 | -------------------------------------------------------------------------------- /src/image_embedding/impl.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "hf-hub")] 2 | use hf_hub::api::sync::ApiRepo; 3 | use image::DynamicImage; 4 | use ndarray::{Array3, ArrayView3}; 5 | use ort::{ 6 | session::{builder::GraphOptimizationLevel, Session}, 7 | value::Value, 8 | }; 9 | #[cfg(feature = "hf-hub")] 10 | use std::path::PathBuf; 11 | use std::{io::Cursor, path::Path, thread::available_parallelism}; 12 | 13 | use crate::{ 14 | common::normalize, models::image_embedding::models_list, Embedding, ImageEmbeddingModel, 15 | ModelInfo, 16 | }; 17 | use anyhow::anyhow; 18 | #[cfg(feature = "hf-hub")] 19 | use anyhow::Context; 20 | 21 | #[cfg(feature = "hf-hub")] 22 | use super::ImageInitOptions; 23 | use super::{ 24 | init::{ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel}, 25 | utils::{Compose, Transform, TransformData}, 26 | ImageEmbedding, DEFAULT_BATCH_SIZE, 27 | }; 28 | use rayon::prelude::*; 29 | 30 | impl ImageEmbedding { 31 | /// Try to generate a new
ImageEmbedding instance 32 | /// 33 | /// Uses the highest level of graph optimization 34 | /// 35 | /// Uses the total number of CPUs available as the number of intra-threads 36 | #[cfg(feature = "hf-hub")] 37 | pub fn try_new(options: ImageInitOptions) -> anyhow::Result<Self> { 38 | let ImageInitOptions { 39 | model_name, 40 | execution_providers, 41 | cache_dir, 42 | show_download_progress, 43 | } = options; 44 | 45 | let threads = available_parallelism()?.get(); 46 | 47 | let model_repo = ImageEmbedding::retrieve_model( 48 | model_name.clone(), 49 | cache_dir.clone(), 50 | show_download_progress, 51 | )?; 52 | 53 | let preprocessor_file = model_repo 54 | .get("preprocessor_config.json") 55 | .context("Failed to retrieve preprocessor_config.json")?; 56 | let preprocessor = Compose::from_file(preprocessor_file)?; 57 | 58 | let model_file_name = ImageEmbedding::get_model_info(&model_name).model_file; 59 | let model_file_reference = model_repo 60 | .get(&model_file_name) 61 | .context(format!("Failed to retrieve {}", model_file_name))?; 62 | 63 | let session = Session::builder()? 64 | .with_execution_providers(execution_providers)? 65 | .with_optimization_level(GraphOptimizationLevel::Level3)? 66 | .with_intra_threads(threads)? 67 | .commit_from_file(model_file_reference)?; 68 | 69 | Ok(Self::new(preprocessor, session)) 70 | } 71 | 72 | /// Create an ImageEmbedding instance from model files provided by the user. 73 | /// 74 | /// This can be used for 'bring your own' embedding models 75 | pub fn try_new_from_user_defined( 76 | model: UserDefinedImageEmbeddingModel, 77 | options: ImageInitOptionsUserDefined, 78 | ) -> anyhow::Result<Self> { 79 | let ImageInitOptionsUserDefined { 80 | execution_providers, 81 | } = options; 82 | 83 | let threads = available_parallelism()?.get(); 84 | 85 | let preprocessor = Compose::from_bytes(model.preprocessor_file)?; 86 | 87 | let session = Session::builder()? 88 | .with_execution_providers(execution_providers)? 89 | .with_optimization_level(GraphOptimizationLevel::Level3)? 90 | .with_intra_threads(threads)?
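// For user-defined models, the session is built from the caller-supplied in-memory ONNX bytes rather than a downloaded file: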
91 | .commit_from_memory(&model.onnx_file)?; 92 | 93 | Ok(Self::new(preprocessor, session)) 94 | } 95 | 96 | /// Private method to return an instance 97 | fn new(preprocessor: Compose, session: Session) -> Self { 98 | Self { 99 | preprocessor, 100 | session, 101 | } 102 | } 103 | 104 | /// Return the ImageEmbedding model's directory from cache or remote retrieval 105 | #[cfg(feature = "hf-hub")] 106 | fn retrieve_model( 107 | model: ImageEmbeddingModel, 108 | cache_dir: PathBuf, 109 | show_download_progress: bool, 110 | ) -> anyhow::Result<ApiRepo> { 111 | use crate::common::pull_from_hf; 112 | 113 | pull_from_hf(model.to_string(), cache_dir, show_download_progress) 114 | } 115 | 116 | /// Retrieve a list of supported models 117 | pub fn list_supported_models() -> Vec<ModelInfo<ImageEmbeddingModel>> { 118 | models_list() 119 | } 120 | 121 | /// Get ModelInfo from ImageEmbeddingModel 122 | pub fn get_model_info(model: &ImageEmbeddingModel) -> ModelInfo<ImageEmbeddingModel> { 123 | ImageEmbedding::list_supported_models() 124 | .into_iter() 125 | .find(|m| &m.model == model) 126 | .expect("Model not found.") 127 | } 128 | 129 | /// Method to generate image embeddings for a Vec of image bytes 130 | pub fn embed_bytes( 131 | &self, 132 | images: &[&[u8]], 133 | batch_size: Option<usize>, 134 | ) -> anyhow::Result<Vec<Embedding>> { 135 | let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); 136 | 137 | let output = images 138 | .par_chunks(batch_size) 139 | .map(|batch| { 140 | // Decode the images in the batch 141 | let inputs = batch 142 | .iter() 143 | .map(|img| { 144 | image::ImageReader::new(Cursor::new(img)) 145 | .with_guessed_format()? 146 | .decode() 147 | .map_err(|err| anyhow!("image decode: {}", err)) 148 | }) 149 | .collect::<anyhow::Result<_>>()?; 150 | 151 | self.embed_images(inputs) 152 | }) 153 | .collect::<anyhow::Result<Vec<_>>>()? 154 | .into_iter() 155 | .flatten() 156 | .collect(); 157 | 158 | Ok(output) 159 | } 160 | 161 | /// Method to generate image embeddings for a Vec of image paths 162 | // Generic type to accept String, &str, OsString, &OsStr 163 | pub fn embed<S: AsRef<Path> + Send + Sync>( 164 | &self, 165 | images: Vec<S>, 166 | batch_size: Option<usize>, 167 | ) -> anyhow::Result<Vec<Embedding>> { 168 | // Determine the batch size, default if not specified 169 | let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); 170 | 171 | let output = images 172 | .par_chunks(batch_size) 173 | .map(|batch| { 174 | // Decode the images in the batch 175 | let inputs = batch 176 | .iter() 177 | .map(|img| { 178 | image::ImageReader::open(img)? 179 | .decode() 180 | .map_err(|err| anyhow!("image decode: {}", err)) 181 | }) 182 | .collect::<anyhow::Result<_>>()?; 183 | 184 | self.embed_images(inputs) 185 | }) 186 | .collect::<anyhow::Result<Vec<_>>>()?
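// Each parallel chunk above yields a Vec of embeddings; flatten them back into a single list, preserving input order.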
187 | .into_iter() 188 | .flatten() 189 | .collect(); 190 | 191 | Ok(output) 192 | } 193 | 194 | /// Embed DynamicImages 195 | pub fn embed_images(&self, imgs: Vec<DynamicImage>) -> anyhow::Result<Vec<Embedding>> { 196 | let inputs = imgs 197 | .into_iter() 198 | .map(|img| { 199 | let pixels = self.preprocessor.transform(TransformData::Image(img))?; 200 | match pixels { 201 | TransformData::NdArray(array) => Ok(array), 202 | _ => Err(anyhow!("Preprocessor configuration error!")), 203 | } 204 | }) 205 | .collect::<anyhow::Result<Vec<Array3<f32>>>>()?; 206 | 207 | // Stack the per-image arrays into a single batch tensor 208 | let inputs_view: Vec<ArrayView3<f32>> = inputs.iter().map(|img| img.view()).collect(); 209 | let pixel_values_array = ndarray::stack(ndarray::Axis(0), &inputs_view)?; 210 | 211 | let input_name = self.session.inputs[0].name.clone(); 212 | let session_inputs = ort::inputs![ 213 | input_name => Value::from_array(pixel_values_array)?, 214 | ]?; 215 | 216 | let outputs = self.session.run(session_inputs)?; 217 | 218 | // Try to get the only output key 219 | // If multiple, then default to a few known keys `image_embeds` and `last_hidden_state` 220 | let last_hidden_state_key = match outputs.len() { 221 | 1 => vec![outputs.keys().next().unwrap()], 222 | _ => vec!["image_embeds", "last_hidden_state"], 223 | }; 224 | 225 | // Extract tensor and handle different dimensionalities 226 | let output_data = last_hidden_state_key 227 | .iter() 228 | .find_map(|&key| { 229 | outputs 230 | .get(key) 231 | .and_then(|v| v.try_extract_tensor::<f32>().ok()) 232 | }) 233 | .ok_or_else(|| anyhow!("Could not extract tensor from any known output key"))?; 234 | let shape = output_data.shape(); 235 | 236 | let embeddings = match shape.len() { 237 | 3 => { 238 | // For 3D output [batch_size, sequence_length, hidden_size] 239 | // Take only the first token, sequence_length[0] (CLS token), embedding 240 | // and return [batch_size, hidden_size] 241 | (0..shape[0]) 242 | .map(|batch_idx| { 243 | let cls_embedding = 244 | output_data.slice(ndarray::s![batch_idx, 0, ..]).to_vec(); 245 | normalize(&cls_embedding) 246 | }) 247 | .collect() 248 | } 249 | 2 => { 250 | // For 2D output [batch_size, hidden_size] 251 | output_data 252 | .rows() 253 | .into_iter() 254 | .map(|row| normalize(row.as_slice().unwrap())) 255 | .collect() 256 | } 257 | _ => return Err(anyhow!("Unexpected output tensor shape: {:?}", shape)), 258 | }; 259 | 260 | Ok(embeddings) 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /src/image_embedding/init.rs: -------------------------------------------------------------------------------- 1 | use std::path::{Path, PathBuf}; 2 | 3 | use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; 4 | 5 | use crate::{get_cache_dir, ImageEmbeddingModel}; 6 | 7 | use super::{utils::Compose, DEFAULT_EMBEDDING_MODEL}; 8 | 9 | /// Options for initializing the ImageEmbedding model 10 | #[derive(Debug, Clone)] 11 | #[non_exhaustive] 12 | pub struct ImageInitOptions { 13 | pub model_name: ImageEmbeddingModel, 14 | pub execution_providers: Vec<ExecutionProviderDispatch>, 15 | pub cache_dir: PathBuf, 16 | pub show_download_progress: bool, 17 | } 18 | 19 | impl ImageInitOptions { 20 | pub fn new(model_name: ImageEmbeddingModel) -> Self { 21 | Self { 22 | model_name, 23 | ..Default::default() 24 | } 25 | } 26 | 27 | pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { 28 | self.cache_dir = cache_dir; 29 | self 30 | } 31 | 32 | pub fn with_execution_providers( 33 | mut self, 34 | execution_providers: Vec<ExecutionProviderDispatch>, 35 | ) -> Self { 36 | self.execution_providers =
execution_providers; 37 | self 38 | } 39 | 40 | pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { 41 | self.show_download_progress = show_download_progress; 42 | self 43 | } 44 | } 45 | 46 | impl Default for ImageInitOptions { 47 | fn default() -> Self { 48 | Self { 49 | model_name: DEFAULT_EMBEDDING_MODEL, 50 | execution_providers: Default::default(), 51 | cache_dir: Path::new(&get_cache_dir()).to_path_buf(), 52 | show_download_progress: true, 53 | } 54 | } 55 | } 56 | 57 | /// Options for initializing UserDefinedImageEmbeddingModel 58 | /// 59 | /// Model files are held by the UserDefinedImageEmbeddingModel struct 60 | #[derive(Debug, Clone, Default)] 61 | #[non_exhaustive] 62 | pub struct ImageInitOptionsUserDefined { 63 | pub execution_providers: Vec<ExecutionProviderDispatch>, 64 | } 65 | 66 | impl ImageInitOptionsUserDefined { 67 | pub fn new() -> Self { 68 | Self::default() 69 | } 70 | 71 | pub fn with_execution_providers( 72 | mut self, 73 | execution_providers: Vec<ExecutionProviderDispatch>, 74 | ) -> Self { 75 | self.execution_providers = execution_providers; 76 | self 77 | } 78 | } 79 | 80 | /// Convert ImageInitOptions to ImageInitOptionsUserDefined 81 | /// 82 | /// This is useful for when the user wants to use the same options for both the default and user-defined models 83 | impl From<ImageInitOptions> for ImageInitOptionsUserDefined { 84 | fn from(options: ImageInitOptions) -> Self { 85 | ImageInitOptionsUserDefined { 86 | execution_providers: options.execution_providers, 87 | } 88 | } 89 | } 90 | 91 | /// Struct for "bring your own" embedding models 92 | /// 93 | /// The onnx_file and preprocessor_file fields expect the files' bytes 94 | #[derive(Debug, Clone, PartialEq, Eq)] 95 | #[non_exhaustive] 96 | pub struct UserDefinedImageEmbeddingModel { 97 | pub onnx_file: Vec<u8>, 98 | pub preprocessor_file: Vec<u8>, 99 | } 100 | 101 | impl UserDefinedImageEmbeddingModel { 102 | pub fn new(onnx_file: Vec<u8>, preprocessor_file: Vec<u8>) -> Self { 103 | Self { 104 | onnx_file, 105 | preprocessor_file, 106 | } 107 | } 108 | } 109 | 110 | /// Rust representation of the ImageEmbedding model 111 | pub struct ImageEmbedding { 112 | pub(crate) preprocessor: Compose, 113 | pub(crate) session: Session, 114 | } 115 | -------------------------------------------------------------------------------- /src/image_embedding/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::models::image_embedding::ImageEmbeddingModel; 2 | const DEFAULT_BATCH_SIZE: usize = 256; 3 | const DEFAULT_EMBEDDING_MODEL: ImageEmbeddingModel = ImageEmbeddingModel::ClipVitB32; 4 | 5 | mod utils; 6 | 7 | mod init; 8 | pub use init::*; 9 | 10 | mod r#impl; 11 | -------------------------------------------------------------------------------- /src/image_embedding/utils.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use image::{imageops::FilterType, DynamicImage, GenericImageView}; 3 | use ndarray::{Array, Array3}; 4 | use std::ops::{Div, Sub}; 5 | #[cfg(feature = "hf-hub")] 6 | use std::{fs::read_to_string, path::Path}; 7 | 8 | pub enum TransformData { 9 | Image(DynamicImage), 10 | NdArray(Array3<f32>), 11 | } 12 | 13 | impl TransformData { 14 | pub fn image(self) -> anyhow::Result<DynamicImage> { 15 | match self { 16 | TransformData::Image(img) => Ok(img), 17 | _ => Err(anyhow!("TransformData convert error")), 18 | } 19 | } 20 | 21 | pub fn array(self) -> anyhow::Result<Array3<f32>> { 22 | match self { 23 | TransformData::NdArray(array) => Ok(array), 24 | _ =>
Err(anyhow!("TransformData convert error")), 25 | } 26 | } 27 | } 28 | 29 | pub trait Transform: Send + Sync { 30 | fn transform(&self, images: TransformData) -> anyhow::Result<TransformData>; 31 | } 32 | 33 | struct ConvertToRGB; 34 | 35 | impl Transform for ConvertToRGB { 36 | fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> { 37 | let image = data.image()?; 38 | let image = image.into_rgb8().into(); 39 | Ok(TransformData::Image(image)) 40 | } 41 | } 42 | 43 | pub struct Resize { 44 | pub size: (u32, u32), 45 | pub resample: FilterType, 46 | } 47 | 48 | impl Transform for Resize { 49 | fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> { 50 | let image = data.image()?; 51 | let image = image.resize_exact(self.size.0, self.size.1, self.resample); 52 | Ok(TransformData::Image(image)) 53 | } 54 | } 55 | 56 | pub struct CenterCrop { 57 | pub size: (u32, u32), 58 | } 59 | 60 | impl Transform for CenterCrop { 61 | fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> { 62 | let mut image = data.image()?; 63 | let (mut origin_width, mut origin_height) = image.dimensions(); 64 | let (crop_width, crop_height) = self.size; 65 | if origin_width >= crop_width && origin_height >= crop_height { 66 | // cropped area is within image boundaries 67 | let x = (origin_width - crop_width) / 2; 68 | let y = (origin_height - crop_height) / 2; 69 | let image = image.crop_imm(x, y, crop_width, crop_height); 70 | Ok(TransformData::Image(image)) 71 | } else { 72 | if origin_width > crop_width || origin_height > crop_height { 73 | let (new_width, new_height) = 74 | (origin_width.min(crop_width), origin_height.min(crop_height)); 75 | let (x, y) = if origin_width > crop_width { 76 | ((origin_width - crop_width) / 2, 0) 77 | } else { 78 | (0, (origin_height - crop_height) / 2) 79 | }; 80 | image = image.crop_imm(x, y, new_width, new_height); 81 | (origin_width, origin_height) = image.dimensions(); 82 | } 83 | let mut pixels_array = 84 | Array3::zeros((3usize, crop_height as usize, crop_width as usize)); 85 | let offset_x = (crop_width - origin_width) / 2; 86 | let offset_y = (crop_height - origin_height) / 2; 87 | // whc -> chw 88 | for (x, y, pixel) in image.to_rgb8().enumerate_pixels() { 89 | pixels_array[[0, (y + offset_y) as usize, (x + offset_x) as usize]] = 90 | pixel[0] as f32; 91 | pixels_array[[1, (y + offset_y) as usize, (x + offset_x) as usize]] = 92 | pixel[1] as f32; 93 | pixels_array[[2, (y + offset_y) as usize, (x + offset_x) as usize]] = 94 | pixel[2] as f32; 95 | } 96 | Ok(TransformData::NdArray(pixels_array)) 97 | } 98 | } 99 | } 100 | 101 | struct PILToNDarray; 102 | 103 | impl Transform for PILToNDarray { 104 | fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> { 105 | match data { 106 | TransformData::Image(image) => { 107 | let image = image.to_rgb8(); 108 | let (width, height) = image.dimensions(); 109 | // whc -> chw 110 | let mut pixels_array = Array3::zeros((3usize, height as usize, width as usize)); 111 | for (x, y, pixel) in image.enumerate_pixels() { 112 | pixels_array[[0, y as usize, x as usize]] = pixel[0] as f32; 113 | pixels_array[[1, y as usize, x as usize]] = pixel[1] as f32; 114 | pixels_array[[2, y as usize, x as usize]] = pixel[2] as f32; 115 | } 116 | Ok(TransformData::NdArray(pixels_array)) 117 | } 118 | ndarray => Ok(ndarray), 119 | } 120 | } 121 | } 122 | 123 | pub struct Rescale { 124 | pub scale: f32, 125 | } 126 | 127 | impl Transform for Rescale { 128 | fn transform(&self, data: TransformData) -> anyhow::Result<TransformData> { 129 | let array = data.array()?; 130 | let
array = array * self.scale; 131 | Ok(TransformData::NdArray(array)) 132 | } 133 | } 134 | 135 | pub struct Normalize { 136 | pub mean: Vec, 137 | pub std: Vec, 138 | } 139 | 140 | impl Transform for Normalize { 141 | fn transform(&self, data: TransformData) -> anyhow::Result { 142 | let array = data.array()?; 143 | let mean = Array::from_vec(self.mean.clone()) 144 | .into_shape_with_order((3, 1, 1)) 145 | .unwrap(); 146 | let std = Array::from_vec(self.std.clone()) 147 | .into_shape_with_order((3, 1, 1)) 148 | .unwrap(); 149 | 150 | let shape = array.shape().to_vec(); 151 | match shape.as_slice() { 152 | [c, h, w] => { 153 | let array_normalized = array 154 | .sub(mean.broadcast((*c, *h, *w)).unwrap()) 155 | .div(std.broadcast((*c, *h, *w)).unwrap()); 156 | Ok(TransformData::NdArray(array_normalized)) 157 | } 158 | _ => Err(anyhow!( 159 | "Transformer convert error. Normlize operator get error shape." 160 | )), 161 | } 162 | } 163 | } 164 | 165 | pub struct Compose { 166 | transforms: Vec>, 167 | } 168 | 169 | impl Compose { 170 | fn new(transforms: Vec>) -> Self { 171 | Self { transforms } 172 | } 173 | 174 | #[cfg(feature = "hf-hub")] 175 | pub fn from_file>(file: P) -> anyhow::Result { 176 | let content = read_to_string(file)?; 177 | let config = serde_json::from_str(&content)?; 178 | load_preprocessor(config) 179 | } 180 | 181 | pub fn from_bytes>(bytes: P) -> anyhow::Result { 182 | let config = serde_json::from_slice(bytes.as_ref())?; 183 | load_preprocessor(config) 184 | } 185 | } 186 | 187 | impl Transform for Compose { 188 | fn transform(&self, mut image: TransformData) -> anyhow::Result { 189 | for transform in &self.transforms { 190 | image = transform.transform(image)?; 191 | } 192 | Ok(image) 193 | } 194 | } 195 | 196 | fn load_preprocessor(config: serde_json::Value) -> anyhow::Result { 197 | let mut transformers: Vec> = vec![]; 198 | transformers.push(Box::new(ConvertToRGB)); 199 | 200 | let mode = config["image_processor_type"] 201 | .as_str() 202 | .unwrap_or("CLIPImageProcessor"); 203 | match mode { 204 | "CLIPImageProcessor" => { 205 | if config["do_resize"].as_bool().unwrap_or(false) { 206 | let size = config["size"].clone(); 207 | let shortest_edge = size["shortest_edge"].as_u64(); 208 | let (height, width) = (size["height"].as_u64(), size["width"].as_u64()); 209 | 210 | if let Some(shortest_edge) = shortest_edge { 211 | let size = (shortest_edge as u32, shortest_edge as u32); 212 | transformers.push(Box::new(Resize { 213 | size, 214 | resample: FilterType::CatmullRom, 215 | })); 216 | } else if let (Some(height), Some(width)) = (height, width) { 217 | let size = (height as u32, width as u32); 218 | transformers.push(Box::new(Resize { 219 | size, 220 | resample: FilterType::CatmullRom, 221 | })); 222 | } else { 223 | return Err(anyhow!( 224 | "Size must contain either 'shortest_edge' or 'height' and 'width'." 
198 | 
199 | fn load_preprocessor(config: serde_json::Value) -> anyhow::Result<Compose> {
200 |     let mut transformers: Vec<Box<dyn Transform>> = vec![];
201 |     transformers.push(Box::new(ConvertToRGB));
202 | 
203 |     let mode = config["image_processor_type"]
204 |         .as_str()
205 |         .unwrap_or("CLIPImageProcessor");
206 |     match mode {
207 |         "CLIPImageProcessor" => {
208 |             if config["do_resize"].as_bool().unwrap_or(false) {
209 |                 let size = config["size"].clone();
210 |                 let shortest_edge = size["shortest_edge"].as_u64();
211 |                 let (height, width) = (size["height"].as_u64(), size["width"].as_u64());
212 | 
213 |                 if let Some(shortest_edge) = shortest_edge {
214 |                     let size = (shortest_edge as u32, shortest_edge as u32);
215 |                     transformers.push(Box::new(Resize {
216 |                         size,
217 |                         resample: FilterType::CatmullRom,
218 |                     }));
219 |                 } else if let (Some(height), Some(width)) = (height, width) {
220 |                     let size = (height as u32, width as u32);
221 |                     transformers.push(Box::new(Resize {
222 |                         size,
223 |                         resample: FilterType::CatmullRom,
224 |                     }));
225 |                 } else {
226 |                     return Err(anyhow!(
227 |                         "Size must contain either 'shortest_edge' or 'height' and 'width'."
228 |                     ));
229 |                 }
230 |             }
231 | 
232 |             if config["do_center_crop"].as_bool().unwrap_or(false) {
233 |                 let crop_size = config["crop_size"].clone();
234 |                 let (height, width) = if crop_size.is_u64() {
235 |                     let size = crop_size.as_u64().unwrap() as u32;
236 |                     (size, size)
237 |                 } else if crop_size.is_object() {
238 |                     (
239 |                         crop_size["height"]
240 |                             .as_u64()
241 |                             .map(|height| height as u32)
242 |                             .ok_or(anyhow!("crop_size must contain 'height'"))?,
243 |                         crop_size["width"]
244 |                             .as_u64()
245 |                             .map(|width| width as u32)
246 |                             .ok_or(anyhow!("crop_size must contain 'width'"))?,
247 |                     )
248 |                 } else {
249 |                     return Err(anyhow!("Invalid crop size: {:?}", crop_size));
250 |                 };
251 |                 transformers.push(Box::new(CenterCrop {
252 |                     size: (width, height),
253 |                 }));
254 |             }
255 |         }
256 |         "ConvNextFeatureExtractor" => {
257 |             let shortest_edge = config["size"]["shortest_edge"].as_u64();
258 |             if shortest_edge.is_none() {
259 |                 return Err(anyhow!("Size dictionary must contain 'shortest_edge' key."));
260 |             }
261 |             let shortest_edge = shortest_edge.unwrap() as u32;
262 |             let crop_pct = config["crop_pct"].as_f64().unwrap_or(0.875);
263 |             if shortest_edge < 384 {
264 |                 let resize_shortest_edge = shortest_edge as f64 / crop_pct;
265 |                 transformers.push(Box::new(Resize {
266 |                     size: (resize_shortest_edge as u32, resize_shortest_edge as u32),
267 |                     resample: FilterType::CatmullRom,
268 |                 }));
269 |                 transformers.push(Box::new(CenterCrop {
270 |                     size: (shortest_edge, shortest_edge),
271 |                 }))
272 |             } else {
273 |                 transformers.push(Box::new(Resize {
274 |                     size: (shortest_edge, shortest_edge),
275 |                     resample: FilterType::CatmullRom,
276 |                 }));
277 |             }
278 |         }
279 |         "BitImageProcessor" => {
280 |             if config["do_convert_rgb"].as_bool().unwrap_or(false) {
281 |                 transformers.push(Box::new(ConvertToRGB));
282 |             }
283 |             if config["do_resize"].as_bool().unwrap_or(false) {
284 |                 let size = config["size"].clone();
285 |                 let shortest_edge = size["shortest_edge"].as_u64();
286 |                 let (height, width) = (size["height"].as_u64(), size["width"].as_u64());
287 | 
288 |                 if let Some(shortest_edge) = shortest_edge {
289 |                     let size = (shortest_edge as u32, shortest_edge as u32);
290 |                     transformers.push(Box::new(Resize {
291 |                         size,
292 |                         resample: FilterType::CatmullRom,
293 |                     }));
294 |                 } else if let (Some(height), Some(width)) = (height, width) {
295 |                     let size = (height as u32, width as u32);
296 |                     transformers.push(Box::new(Resize {
297 |                         size,
298 |                         resample: FilterType::CatmullRom,
299 |                     }));
300 |                 } else {
301 |                     return Err(anyhow!(
302 |                         "Size must contain either 'shortest_edge' or 'height' and 'width'."
303 |                     ));
304 |                 }
305 |             }
306 | 
307 |             if config["do_center_crop"].as_bool().unwrap_or(false) {
308 |                 let crop_size = config["crop_size"].clone();
309 |                 let (height, width) = if crop_size.is_u64() {
310 |                     let size = crop_size.as_u64().unwrap() as u32;
311 |                     (size, size)
312 |                 } else if crop_size.is_object() {
313 |                     (
314 |                         crop_size["height"]
315 |                             .as_u64()
316 |                             .map(|height| height as u32)
317 |                             .ok_or(anyhow!("crop_size must contain 'height'"))?,
318 |                         crop_size["width"]
319 |                             .as_u64()
320 |                             .map(|width| width as u32)
321 |                             .ok_or(anyhow!("crop_size must contain 'width'"))?,
322 |                     )
323 |                 } else {
324 |                     return Err(anyhow!("Invalid crop size: {:?}", crop_size));
325 |                 };
326 |                 transformers.push(Box::new(CenterCrop {
327 |                     size: (width, height),
328 |                 }));
329 |             }
330 |         }
331 |         mode => return Err(anyhow!("Preprocessor {} is not supported", mode)),
332 |     }
333 | 
334 |     transformers.push(Box::new(PILToNDarray));
335 | 
336 |     if config["do_rescale"].as_bool().unwrap_or(true) {
337 |         let rescale_factor = config["rescale_factor"].as_f64().unwrap_or(1.0f64 / 255.0);
338 |         transformers.push(Box::new(Rescale {
339 |             scale: rescale_factor as f32,
340 |         }));
341 |     }
342 | 
343 |     if config["do_normalize"].as_bool().unwrap_or(false) {
344 |         let mean = config["image_mean"]
345 |             .as_array()
346 |             .ok_or(anyhow!("config must contain 'image_mean'"))?
347 |             .iter()
348 |             .map(|value| {
349 |                 value
350 |                     .as_f64()
351 |                     .map(|num| num as f32)
352 |                     .ok_or(anyhow!("'image_mean' values must be floats"))
353 |             })
354 |             .collect::<anyhow::Result<Vec<f32>>>()?;
355 |         let std = config["image_std"]
356 |             .as_array()
357 |             .ok_or(anyhow!("config must contain 'image_std'"))?
358 |             .iter()
359 |             .map(|value| {
360 |                 value
361 |                     .as_f64()
362 |                     .map(|num| num as f32)
363 |                     .ok_or(anyhow!("'image_std' values must be floats"))
364 |             })
365 |             .collect::<anyhow::Result<Vec<f32>>>()?;
366 |         transformers.push(Box::new(Normalize { mean, std }));
367 |     }
368 | 
369 |     Ok(Compose::new(transformers))
370 | }
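`Compose::from_bytes` (and `from_file`) drive `load_preprocessor`, which always opens the pipeline with `ConvertToRGB`, appends the processor-specific resize/crop steps, and closes with `PILToNDarray`, `Rescale`, and `Normalize`. A sketch with a hand-written CLIP-style config; the numeric values are illustrative, not taken from any real model repo:

```rust
let config = br#"{
    "image_processor_type": "CLIPImageProcessor",
    "do_resize": true,
    "size": { "shortest_edge": 224 },
    "do_center_crop": true,
    "crop_size": 224,
    "do_rescale": true,
    "rescale_factor": 0.00392156862745098,
    "do_normalize": true,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711]
}"#;
// Builds: ConvertToRGB -> Resize -> CenterCrop -> PILToNDarray -> Rescale -> Normalize
let preprocessor = Compose::from_bytes(config)?;
```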
371 | -------------------------------------------------------------------------------- /src/lib.rs: --------------------------------------------------------------------------------
1 | //! [FastEmbed](https://github.com/Anush008/fastembed-rs) - Fast, light, accurate library built for retrieval embedding generation.
2 | //!
3 | //! The library provides the TextEmbedding struct to interface with text embedding models.
4 | //!
5 | #![cfg_attr(
6 |     feature = "hf-hub",
7 |     doc = r#"
8 | ### Instantiating [TextEmbedding](crate::TextEmbedding)
9 | ```
10 | use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
11 | 
12 | # fn model_demo() -> anyhow::Result<()> {
13 | // With default InitOptions
14 | let model = TextEmbedding::try_new(Default::default())?;
15 | 
16 | // List all supported models
17 | dbg!(TextEmbedding::list_supported_models());
18 | 
19 | // With custom InitOptions
20 | let model = TextEmbedding::try_new(
21 |     InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
22 | )?;
23 | # Ok(())
24 | # }
25 | ```
26 | "#
27 | )]
28 | //! Find more info about the available options in the [InitOptions](crate::InitOptions) documentation.
29 | //!
30 | #![cfg_attr(
31 |     feature = "hf-hub",
32 |     doc = r#"
33 | ### Embeddings generation
34 | ```
35 | # use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
36 | # fn embedding_demo() -> anyhow::Result<()> {
37 | # let model: TextEmbedding = TextEmbedding::try_new(Default::default())?;
38 | let documents = vec![
39 |     "passage: Hello, World!",
40 |     "query: Hello, World!",
41 |     "passage: This is an example passage.",
42 |     // You can leave out the prefix but it's recommended
43 |     "fastembed-rs is licensed under MIT"
44 | ];
45 | 
46 | // Generate embeddings with the default batch size, 256
47 | let embeddings = model.embed(documents, None)?;
48 | 
49 | println!("Embeddings length: {}", embeddings.len()); // -> Embeddings length: 4
50 | # Ok(())
51 | # }
52 | ```
53 | "#
54 | )]
55 | 
56 | mod common;
57 | mod image_embedding;
58 | mod models;
59 | pub mod output;
60 | mod pooling;
61 | mod reranking;
62 | mod sparse_text_embedding;
63 | mod text_embedding;
64 | 
65 | pub use ort::execution_providers::ExecutionProviderDispatch;
66 | 
67 | pub use crate::common::{
68 |     get_cache_dir, read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles,
69 | };
70 | pub use crate::models::{
71 |     model_info::ModelInfo, model_info::RerankerModelInfo, quantization::QuantizationMode,
72 | };
73 | pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput};
74 | pub use crate::pooling::Pooling;
75 | 
76 | // For Text Embedding
77 | pub use crate::models::text_embedding::EmbeddingModel;
78 | pub use crate::text_embedding::{
79 |     InitOptions, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel,
80 | };
81 | 
82 | // For Sparse Text Embedding
83 | pub use crate::models::sparse::SparseModel;
84 | pub use crate::sparse_text_embedding::{
85 |     SparseInitOptions, SparseTextEmbedding, UserDefinedSparseModel,
86 | };
87 | 
88 | // For Image Embedding
89 | pub use crate::image_embedding::{
90 |     ImageEmbedding, ImageInitOptions, ImageInitOptionsUserDefined, UserDefinedImageEmbeddingModel,
91 | };
92 | pub use crate::models::image_embedding::ImageEmbeddingModel;
93 | 
94 | // For Reranking
95 | pub use crate::models::reranking::RerankerModel;
96 | pub use crate::reranking::{
97 |     OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
98 |     UserDefinedRerankingModel,
99 | };
100 | -------------------------------------------------------------------------------- /src/models/image_embedding.rs: --------------------------------------------------------------------------------
1 | use std::{fmt::Display, str::FromStr};
2 | 
3 | use super::model_info::ModelInfo;
4 | 
5 | #[derive(Debug, Clone, PartialEq, Eq)]
6 | pub enum ImageEmbeddingModel {
7 |     /// Qdrant/clip-ViT-B-32-vision
8 |     ClipVitB32,
9 |     /// Qdrant/resnet50-onnx
10 |     Resnet50,
11 |     /// Qdrant/Unicom-ViT-B-16
12 |     UnicomVitB16,
13 |     /// Qdrant/Unicom-ViT-B-32
14 |     UnicomVitB32,
15 |     /// nomic-ai/nomic-embed-vision-v1.5
16 |     NomicEmbedVisionV15,
17 | }
18 | 
19 | pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
20 |     let models_list = vec![
21 |         ModelInfo {
22 |             model: ImageEmbeddingModel::ClipVitB32,
23 |             dim: 512,
24 |             description: String::from("CLIP vision encoder based on ViT-B/32"),
25 |             model_code: String::from("Qdrant/clip-ViT-B-32-vision"),
26 |             model_file: String::from("model.onnx"),
27 |             additional_files: Vec::new(),
28 |         },
29 |         ModelInfo {
30 |             model: ImageEmbeddingModel::Resnet50,
31 |             dim: 2048,
32 |             description: String::from("ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__."),
33 |             model_code: String::from("Qdrant/resnet50-onnx"),
34 |             model_file: String::from("model.onnx"),
35 |             additional_files: Vec::new(),
36 |         },
37 |         ModelInfo {
38 |             model: ImageEmbeddingModel::UnicomVitB16,
39 |             dim: 768,
40 |             description: String::from("Unicom Unicom-ViT-B-16 from open-metric-learning"),
41 |             model_code: String::from("Qdrant/Unicom-ViT-B-16"),
42 |             model_file: String::from("model.onnx"),
43 |             additional_files: Vec::new(),
44 |         },
45 |         ModelInfo {
46 |             model: ImageEmbeddingModel::UnicomVitB32,
47 |             dim: 512,
48 |             description: String::from("Unicom Unicom-ViT-B-32 from open-metric-learning"),
49 |             model_code: String::from("Qdrant/Unicom-ViT-B-32"),
50 |             model_file: String::from("model.onnx"),
51 |             additional_files: Vec::new(),
52 |         },
53 |         ModelInfo {
54 |             model: ImageEmbeddingModel::NomicEmbedVisionV15,
55 |             dim: 768,
56 |             description: String::from("Nomic NomicEmbedVisionV15"),
57 |             model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
58 |             model_file: String::from("onnx/model.onnx"),
59 |             additional_files: Vec::new(),
60 |         },
61 |     ];
62 | 
63 |     // TODO: Use when out in stable
64 |     // assert_eq!(
65 |     //     std::mem::variant_count::<ImageEmbeddingModel>(),
66 |     //     models_list.len(),
67 |     //     "models::models() is not exhaustive"
68 |     // );
69 | 
70 |     models_list
71 | }
72 | 
73 | impl Display for ImageEmbeddingModel {
74 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 |         let model_info = models_list()
76 |             .into_iter()
77 |             .find(|model| model.model == *self)
78 |             .unwrap();
79 |         write!(f, "{}", model_info.model_code)
80 |     }
81 | }
82 | 
83 | impl FromStr for ImageEmbeddingModel {
84 |     type Err = String;
85 | 
86 |     fn from_str(s: &str) -> Result<Self, Self::Err> {
87 |         models_list()
88 |             .into_iter()
89 |             .find(|m| m.model_code.eq_ignore_ascii_case(s))
90 |             .map(|m| m.model)
91 |             .ok_or_else(|| format!("Unknown embedding model: {s}"))
92 |     }
93 | }
94 | 
95 | impl TryFrom<String> for ImageEmbeddingModel {
96 |     type Error = String;
97 | 
98 |     fn try_from(value: String) -> Result<Self, Self::Error> {
99 |         value.parse()
100 |     }
101 | }
102 | -------------------------------------------------------------------------------- /src/models/mod.rs: --------------------------------------------------------------------------------
1 | pub mod image_embedding;
2 | pub mod model_info;
3 | pub mod quantization;
4 | pub mod reranking;
5 | pub mod sparse;
6 | pub mod text_embedding;
7 | -------------------------------------------------------------------------------- /src/models/model_info.rs: --------------------------------------------------------------------------------
1 | use crate::RerankerModel;
2 | 
3 | /// Data struct about the available models
4 | #[derive(Debug, Clone)]
5 | pub struct ModelInfo<T> {
6 |     pub model: T,
7 |     pub dim: usize,
8 |     pub description: String,
9 |     pub model_code: String,
10 |     pub model_file: String,
11 |     pub additional_files: Vec<String>,
12 | }
13 | 
14 | /// Data struct about the available reranker models
15 | #[derive(Debug, Clone)]
16 | pub struct RerankerModelInfo {
17 |     pub model: RerankerModel,
18 |     pub description: String,
19 |     pub model_code: String,
20 |     pub model_file: String,
21 |     pub additional_files: Vec<String>,
22 | }
23 | -------------------------------------------------------------------------------- /src/models/quantization.rs: --------------------------------------------------------------------------------
1 | /// Enum for quantization mode.
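2 | /// `Static` marks models quantized ahead of time with calibration-derived
3 | /// activation scales, `Dynamic` marks models whose activation scales are
4 | /// computed at inference time, and `None` marks full-precision models.
5 | ///
6 | /// ```
7 | /// use fastembed::QuantizationMode;
8 | ///
9 | /// assert_eq!(QuantizationMode::default(), QuantizationMode::None);
10 | /// ```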
11 | #[derive(Debug, Clone, Copy, PartialEq, Eq)]
12 | pub enum QuantizationMode {
13 |     None,
14 |     Static,
15 |     Dynamic,
16 | }
17 | 
18 | impl Default for QuantizationMode {
19 |     fn default() -> Self {
20 |         Self::None
21 |     }
22 | }
23 | -------------------------------------------------------------------------------- /src/models/reranking.rs: --------------------------------------------------------------------------------
1 | use std::{fmt::Display, str::FromStr};
2 | 
3 | use crate::RerankerModelInfo;
4 | 
5 | #[derive(Debug, Clone, PartialEq, Eq)]
6 | pub enum RerankerModel {
7 |     /// BAAI/bge-reranker-base
8 |     BGERerankerBase,
9 |     /// rozgo/bge-reranker-v2-m3
10 |     BGERerankerV2M3,
11 |     /// jinaai/jina-reranker-v1-turbo-en
12 |     JINARerankerV1TurboEn,
13 |     /// jinaai/jina-reranker-v2-base-multilingual
14 |     JINARerankerV2BaseMultiligual,
15 | }
16 | 
17 | pub fn reranker_model_list() -> Vec<RerankerModelInfo> {
18 |     let reranker_model_list = vec![
19 |         RerankerModelInfo {
20 |             model: RerankerModel::BGERerankerBase,
21 |             description: String::from("reranker model for English and Chinese"),
22 |             model_code: String::from("BAAI/bge-reranker-base"),
23 |             model_file: String::from("onnx/model.onnx"),
24 |             additional_files: vec![],
25 |         },
26 |         RerankerModelInfo {
27 |             model: RerankerModel::BGERerankerV2M3,
28 |             description: String::from("reranker model for multilingual"),
29 |             model_code: String::from("rozgo/bge-reranker-v2-m3"),
30 |             model_file: String::from("model.onnx"),
31 |             additional_files: vec![String::from("model.onnx.data")],
32 |         },
33 |         RerankerModelInfo {
34 |             model: RerankerModel::JINARerankerV1TurboEn,
35 |             description: String::from("reranker model for English"),
36 |             model_code: String::from("jinaai/jina-reranker-v1-turbo-en"),
37 |             model_file: String::from("onnx/model.onnx"),
38 |             additional_files: vec![],
39 |         },
40 |         RerankerModelInfo {
41 |             model: RerankerModel::JINARerankerV2BaseMultiligual,
42 |             description: String::from("reranker model for multilingual"),
43 |             model_code: String::from("jinaai/jina-reranker-v2-base-multilingual"),
44 |             model_file: String::from("onnx/model.onnx"),
45 |             additional_files: vec![],
46 |         },
47 |     ];
48 |     reranker_model_list
49 | }
50 | 
51 | impl Display for RerankerModel {
52 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 |         let model_info = reranker_model_list()
54 |             .into_iter()
55 |             .find(|model| model.model == *self)
56 |             .expect("Model not found in supported models list.");
57 |         write!(f, "{}", model_info.model_code)
58 |     }
59 | }
60 | 
61 | impl FromStr for RerankerModel {
62 |     type Err = String;
63 | 
64 |     fn from_str(s: &str) -> Result<Self, Self::Err> {
65 |         reranker_model_list()
66 |             .into_iter()
67 |             .find(|m| m.model_code.eq_ignore_ascii_case(s))
68 |             .map(|m| m.model)
69 |             .ok_or_else(|| format!("Unknown reranker model: {s}"))
70 |     }
71 | }
72 | 
73 | impl TryFrom<String> for RerankerModel {
74 |     type Error = String;
75 | 
76 |     fn try_from(value: String) -> Result<Self, Self::Error> {
77 |         value.parse()
78 |     }
79 | }
80 | -------------------------------------------------------------------------------- /src/models/sparse.rs: --------------------------------------------------------------------------------
use, v1"), 16 | model_code: String::from("Qdrant/Splade_PP_en_v1"), 17 | model_file: String::from("model.onnx"), 18 | additional_files: Vec::new(), 19 | }] 20 | } 21 | 22 | impl Display for SparseModel { 23 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 24 | let model_info = models_list() 25 | .into_iter() 26 | .find(|model| model.model == *self) 27 | .unwrap(); 28 | write!(f, "{}", model_info.model_code) 29 | } 30 | } 31 | 32 | impl FromStr for SparseModel { 33 | type Err = String; 34 | 35 | fn from_str(s: &str) -> Result { 36 | models_list() 37 | .into_iter() 38 | .find(|m| m.model_code.eq_ignore_ascii_case(s)) 39 | .map(|m| m.model) 40 | .ok_or_else(|| format!("Unknown sparse model: {s}")) 41 | } 42 | } 43 | 44 | impl TryFrom for SparseModel { 45 | type Error = String; 46 | 47 | fn try_from(value: String) -> Result { 48 | value.parse() 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/models/text_embedding.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, convert::TryFrom, fmt::Display, str::FromStr, sync::OnceLock}; 2 | 3 | use super::model_info::ModelInfo; 4 | 5 | /// Lazy static list of all available models. 6 | static MODEL_MAP: OnceLock>> = OnceLock::new(); 7 | 8 | #[derive(Debug, Clone, PartialEq, Eq, Hash)] 9 | pub enum EmbeddingModel { 10 | /// sentence-transformers/all-MiniLM-L6-v2 11 | AllMiniLML6V2, 12 | /// Quantized sentence-transformers/all-MiniLM-L6-v2 13 | AllMiniLML6V2Q, 14 | /// sentence-transformers/all-MiniLM-L12-v2 15 | AllMiniLML12V2, 16 | /// Quantized sentence-transformers/all-MiniLM-L12-v2 17 | AllMiniLML12V2Q, 18 | /// BAAI/bge-base-en-v1.5 19 | BGEBaseENV15, 20 | /// Quantized BAAI/bge-base-en-v1.5 21 | BGEBaseENV15Q, 22 | /// BAAI/bge-large-en-v1.5 23 | BGELargeENV15, 24 | /// Quantized BAAI/bge-large-en-v1.5 25 | BGELargeENV15Q, 26 | /// BAAI/bge-small-en-v1.5 - Default 27 | BGESmallENV15, 28 | /// Quantized BAAI/bge-small-en-v1.5 29 | BGESmallENV15Q, 30 | /// nomic-ai/nomic-embed-text-v1 31 | NomicEmbedTextV1, 32 | /// nomic-ai/nomic-embed-text-v1.5 33 | NomicEmbedTextV15, 34 | /// Quantized v1.5 nomic-ai/nomic-embed-text-v1.5 35 | NomicEmbedTextV15Q, 36 | /// sentence-transformers/paraphrase-MiniLM-L6-v2 37 | ParaphraseMLMiniLML12V2, 38 | /// Quantized sentence-transformers/paraphrase-MiniLM-L6-v2 39 | ParaphraseMLMiniLML12V2Q, 40 | /// sentence-transformers/paraphrase-mpnet-base-v2 41 | ParaphraseMLMpnetBaseV2, 42 | /// BAAI/bge-small-zh-v1.5 43 | BGESmallZHV15, 44 | /// BAAI/bge-large-zh-v1.5 45 | BGELargeZHV15, 46 | /// lightonai/modernbert-embed-large 47 | ModernBertEmbedLarge, 48 | /// intfloat/multilingual-e5-small 49 | MultilingualE5Small, 50 | /// intfloat/multilingual-e5-base 51 | MultilingualE5Base, 52 | /// intfloat/multilingual-e5-large 53 | MultilingualE5Large, 54 | /// mixedbread-ai/mxbai-embed-large-v1 55 | MxbaiEmbedLargeV1, 56 | /// Quantized mixedbread-ai/mxbai-embed-large-v1 57 | MxbaiEmbedLargeV1Q, 58 | /// Alibaba-NLP/gte-base-en-v1.5 59 | GTEBaseENV15, 60 | /// Quantized Alibaba-NLP/gte-base-en-v1.5 61 | GTEBaseENV15Q, 62 | /// Alibaba-NLP/gte-large-en-v1.5 63 | GTELargeENV15, 64 | /// Quantized Alibaba-NLP/gte-large-en-v1.5 65 | GTELargeENV15Q, 66 | /// Qdrant/clip-ViT-B-32-text 67 | ClipVitB32, 68 | /// jinaai/jina-embeddings-v2-base-code 69 | JinaEmbeddingsV2BaseCode, 70 | } 71 | 72 | /// Centralized function to initialize the models map. 
71 | 
72 | /// Centralized function to initialize the models map.
73 | fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
74 |     let models_list = vec![
75 |         ModelInfo {
76 |             model: EmbeddingModel::AllMiniLML6V2,
77 |             dim: 384,
78 |             description: String::from("Sentence Transformer model, MiniLM-L6-v2"),
79 |             model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
80 |             model_file: String::from("model.onnx"),
81 |             additional_files: Vec::new(),
82 |         },
83 |         ModelInfo {
84 |             model: EmbeddingModel::AllMiniLML6V2Q,
85 |             dim: 384,
86 |             description: String::from("Quantized Sentence Transformer model, MiniLM-L6-v2"),
87 |             model_code: String::from("Xenova/all-MiniLM-L6-v2"),
88 |             model_file: String::from("onnx/model_quantized.onnx"),
89 |             additional_files: Vec::new(),
90 |         },
91 |         ModelInfo {
92 |             model: EmbeddingModel::AllMiniLML12V2,
93 |             dim: 384,
94 |             description: String::from("Sentence Transformer model, MiniLM-L12-v2"),
95 |             model_code: String::from("Xenova/all-MiniLM-L12-v2"),
96 |             model_file: String::from("onnx/model.onnx"),
97 |             additional_files: Vec::new(),
98 |         },
99 |         ModelInfo {
100 |             model: EmbeddingModel::AllMiniLML12V2Q,
101 |             dim: 384,
102 |             description: String::from("Quantized Sentence Transformer model, MiniLM-L12-v2"),
103 |             model_code: String::from("Xenova/all-MiniLM-L12-v2"),
104 |             model_file: String::from("onnx/model_quantized.onnx"),
105 |             additional_files: Vec::new(),
106 |         },
107 |         ModelInfo {
108 |             model: EmbeddingModel::BGEBaseENV15,
109 |             dim: 768,
110 |             description: String::from("v1.5 release of the base English model"),
111 |             model_code: String::from("Xenova/bge-base-en-v1.5"),
112 |             model_file: String::from("onnx/model.onnx"),
113 |             additional_files: Vec::new(),
114 |         },
115 |         ModelInfo {
116 |             model: EmbeddingModel::BGEBaseENV15Q,
117 |             dim: 768,
118 |             description: String::from("Quantized v1.5 release of the base English model"),
119 |             model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
120 |             model_file: String::from("model_optimized.onnx"),
121 |             additional_files: Vec::new(),
122 |         },
123 |         ModelInfo {
124 |             model: EmbeddingModel::BGELargeENV15,
125 |             dim: 1024,
126 |             description: String::from("v1.5 release of the large English model"),
127 |             model_code: String::from("Xenova/bge-large-en-v1.5"),
128 |             model_file: String::from("onnx/model.onnx"),
129 |             additional_files: Vec::new(),
130 |         },
131 |         ModelInfo {
132 |             model: EmbeddingModel::BGELargeENV15Q,
133 |             dim: 1024,
134 |             description: String::from("Quantized v1.5 release of the large English model"),
135 |             model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
136 |             model_file: String::from("model_optimized.onnx"),
137 |             additional_files: Vec::new(),
138 |         },
139 |         ModelInfo {
140 |             model: EmbeddingModel::BGESmallENV15,
141 |             dim: 384,
142 |             description: String::from("v1.5 release of the fast and default English model"),
143 |             model_code: String::from("Xenova/bge-small-en-v1.5"),
144 |             model_file: String::from("onnx/model.onnx"),
145 |             additional_files: Vec::new(),
146 |         },
147 |         ModelInfo {
148 |             model: EmbeddingModel::BGESmallENV15Q,
149 |             dim: 384,
150 |             description: String::from(
151 |                 "Quantized v1.5 release of the fast and default English model",
152 |             ),
153 |             model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
154 |             model_file: String::from("model_optimized.onnx"),
155 |             additional_files: Vec::new(),
156 |         },
157 |         ModelInfo {
158 |             model: EmbeddingModel::NomicEmbedTextV1,
159 |             dim: 768,
160 |             description: String::from("8192 context length english model"),
161 |             model_code: String::from("nomic-ai/nomic-embed-text-v1"),
162 |             model_file: 
String::from("onnx/model.onnx"), 163 | additional_files: Vec::new(), 164 | }, 165 | ModelInfo { 166 | model: EmbeddingModel::NomicEmbedTextV15, 167 | dim: 768, 168 | description: String::from("v1.5 release of the 8192 context length english model"), 169 | model_code: String::from("nomic-ai/nomic-embed-text-v1.5"), 170 | model_file: String::from("onnx/model.onnx"), 171 | additional_files: Vec::new(), 172 | }, 173 | ModelInfo { 174 | model: EmbeddingModel::NomicEmbedTextV15Q, 175 | dim: 768, 176 | description: String::from( 177 | "Quantized v1.5 release of the 8192 context length english model", 178 | ), 179 | model_code: String::from("nomic-ai/nomic-embed-text-v1.5"), 180 | model_file: String::from("onnx/model_quantized.onnx"), 181 | additional_files: Vec::new(), 182 | }, 183 | ModelInfo { 184 | model: EmbeddingModel::ParaphraseMLMiniLML12V2Q, 185 | dim: 384, 186 | description: String::from("Quantized Multi-lingual model"), 187 | model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"), 188 | model_file: String::from("model_optimized.onnx"), 189 | additional_files: Vec::new(), 190 | }, 191 | ModelInfo { 192 | model: EmbeddingModel::ParaphraseMLMiniLML12V2, 193 | dim: 384, 194 | description: String::from("Multi-lingual model"), 195 | model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"), 196 | model_file: String::from("onnx/model.onnx"), 197 | additional_files: Vec::new(), 198 | }, 199 | ModelInfo { 200 | model: EmbeddingModel::ParaphraseMLMpnetBaseV2, 201 | dim: 768, 202 | description: String::from( 203 | "Sentence-transformers model for tasks like clustering or semantic search", 204 | ), 205 | model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"), 206 | model_file: String::from("onnx/model.onnx"), 207 | additional_files: Vec::new(), 208 | }, 209 | ModelInfo { 210 | model: EmbeddingModel::BGESmallZHV15, 211 | dim: 512, 212 | description: String::from("v1.5 release of the small Chinese model"), 213 | model_code: String::from("Xenova/bge-small-zh-v1.5"), 214 | model_file: String::from("onnx/model.onnx"), 215 | additional_files: Vec::new(), 216 | }, 217 | ModelInfo { 218 | model: EmbeddingModel::BGELargeZHV15, 219 | dim: 1024, 220 | description: String::from("v1.5 release of the large Chinese model"), 221 | model_code: String::from("Xenova/bge-large-zh-v1.5"), 222 | model_file: String::from("onnx/model.onnx"), 223 | additional_files: Vec::new(), 224 | }, 225 | ModelInfo { 226 | model: EmbeddingModel::ModernBertEmbedLarge, 227 | dim: 1024, 228 | description: String::from("Large model of ModernBert Text Embeddings"), 229 | model_code: String::from("lightonai/modernbert-embed-large"), 230 | model_file: String::from("onnx/model.onnx"), 231 | additional_files: Vec::new(), 232 | }, 233 | ModelInfo { 234 | model: EmbeddingModel::MultilingualE5Small, 235 | dim: 384, 236 | description: String::from("Small model of multilingual E5 Text Embeddings"), 237 | model_code: String::from("intfloat/multilingual-e5-small"), 238 | model_file: String::from("onnx/model.onnx"), 239 | additional_files: Vec::new(), 240 | }, 241 | ModelInfo { 242 | model: EmbeddingModel::MultilingualE5Base, 243 | dim: 768, 244 | description: String::from("Base model of multilingual E5 Text Embeddings"), 245 | model_code: String::from("intfloat/multilingual-e5-base"), 246 | model_file: String::from("onnx/model.onnx"), 247 | additional_files: Vec::new(), 248 | }, 249 | ModelInfo { 250 | model: EmbeddingModel::MultilingualE5Large, 251 | dim: 1024, 252 | description: 
String::from("Large model of multilingual E5 Text Embeddings"), 253 | model_code: String::from("Qdrant/multilingual-e5-large-onnx"), 254 | model_file: String::from("model.onnx"), 255 | additional_files: vec!["model.onnx_data".to_string()], 256 | }, 257 | ModelInfo { 258 | model: EmbeddingModel::MxbaiEmbedLargeV1, 259 | dim: 1024, 260 | description: String::from("Large English embedding model from MixedBreed.ai"), 261 | model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"), 262 | model_file: String::from("onnx/model.onnx"), 263 | additional_files: Vec::new(), 264 | }, 265 | ModelInfo { 266 | model: EmbeddingModel::MxbaiEmbedLargeV1Q, 267 | dim: 1024, 268 | description: String::from("Quantized Large English embedding model from MixedBreed.ai"), 269 | model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"), 270 | model_file: String::from("onnx/model_quantized.onnx"), 271 | additional_files: Vec::new(), 272 | }, 273 | ModelInfo { 274 | model: EmbeddingModel::GTEBaseENV15, 275 | dim: 768, 276 | description: String::from("Large multilingual embedding model from Alibaba"), 277 | model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"), 278 | model_file: String::from("onnx/model.onnx"), 279 | additional_files: Vec::new(), 280 | }, 281 | ModelInfo { 282 | model: EmbeddingModel::GTEBaseENV15Q, 283 | dim: 768, 284 | description: String::from("Quantized Large multilingual embedding model from Alibaba"), 285 | model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"), 286 | model_file: String::from("onnx/model_quantized.onnx"), 287 | additional_files: Vec::new(), 288 | }, 289 | ModelInfo { 290 | model: EmbeddingModel::GTELargeENV15, 291 | dim: 1024, 292 | description: String::from("Large multilingual embedding model from Alibaba"), 293 | model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"), 294 | model_file: String::from("onnx/model.onnx"), 295 | additional_files: Vec::new(), 296 | }, 297 | ModelInfo { 298 | model: EmbeddingModel::GTELargeENV15Q, 299 | dim: 1024, 300 | description: String::from("Quantized Large multilingual embedding model from Alibaba"), 301 | model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"), 302 | model_file: String::from("onnx/model_quantized.onnx"), 303 | additional_files: Vec::new(), 304 | }, 305 | ModelInfo { 306 | model: EmbeddingModel::ClipVitB32, 307 | dim: 512, 308 | description: String::from("CLIP text encoder based on ViT-B/32"), 309 | model_code: String::from("Qdrant/clip-ViT-B-32-text"), 310 | model_file: String::from("model.onnx"), 311 | additional_files: Vec::new(), 312 | }, 313 | ModelInfo { 314 | model: EmbeddingModel::JinaEmbeddingsV2BaseCode, 315 | dim: 768, 316 | description: String::from("Jina embeddings v2 base code"), 317 | model_code: String::from("jinaai/jina-embeddings-v2-base-code"), 318 | model_file: String::from("onnx/model.onnx"), 319 | additional_files: Vec::new(), 320 | }, 321 | ]; 322 | 323 | // TODO: Use when out in stable 324 | // assert_eq!( 325 | // std::mem::variant_count::(), 326 | // models_list.len(), 327 | // "models::models() is not exhaustive" 328 | // ); 329 | 330 | models_list 331 | .into_iter() 332 | .fold(HashMap::new(), |mut map, model| { 333 | // Insert the model into the map 334 | map.insert(model.model.clone(), model); 335 | map 336 | }) 337 | } 338 | 339 | /// Get a map of all available models. 340 | pub fn models_map() -> &'static HashMap> { 341 | MODEL_MAP.get_or_init(init_models_map) 342 | } 343 | 344 | /// Get model information by model code. 
344 | /// Get model information by model code.
345 | pub fn get_model_info(model: &EmbeddingModel) -> Option<&ModelInfo<EmbeddingModel>> {
346 |     models_map().get(model)
347 | }
348 | 
349 | /// Get a list of all available models.
350 | ///
351 | /// This will assign new memory to the models list; where possible, use
352 | /// [`models_map`] instead.
353 | pub fn models_list() -> Vec<ModelInfo<EmbeddingModel>> {
354 |     models_map().values().cloned().collect()
355 | }
356 | 
357 | impl Display for EmbeddingModel {
358 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 |         let model_info = get_model_info(self).expect("Model not found.");
360 |         write!(f, "{}", model_info.model_code)
361 |     }
362 | }
363 | 
364 | impl FromStr for EmbeddingModel {
365 |     type Err = String;
366 | 
367 |     fn from_str(s: &str) -> Result<Self, Self::Err> {
368 |         models_list()
369 |             .into_iter()
370 |             .find(|m| m.model_code.eq_ignore_ascii_case(s))
371 |             .map(|m| m.model)
372 |             .ok_or_else(|| format!("Unknown embedding model: {s}"))
373 |     }
374 | }
375 | 
376 | impl TryFrom<String> for EmbeddingModel {
377 |     type Error = String;
378 | 
379 |     fn try_from(value: String) -> Result<Self, Self::Error> {
380 |         value.parse()
381 |     }
382 | }
383 | -------------------------------------------------------------------------------- /src/output/embedding_output.rs: --------------------------------------------------------------------------------
1 | use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
2 | use ort::session::SessionOutputs;
3 | 
4 | use crate::pooling;
5 | 
6 | use super::{OutputKey, OutputPrecedence};
7 | 
8 | /// [`SingleBatchOutput`] contains the output of a single batch of inference.
9 | ///
10 | /// In the future, each batch will need to deal with its own post-processing, such as
11 | /// pooling etc. This struct should contain all the necessary information for the
12 | /// post-processing to be performed.
13 | pub struct SingleBatchOutput<'r, 's> {
14 |     pub session_outputs: SessionOutputs<'r, 's>,
15 |     pub attention_mask_array: Array2<i64>,
16 | }
17 | 
18 | impl SingleBatchOutput<'_, '_> {
19 |     /// Select the output from the session outputs based on the given precedence.
20 |     ///
21 |     /// This returns a view into the tensor, which can be used to perform further
22 |     /// operations.
23 |     pub fn select_output(
24 |         &self,
25 |         precedence: &impl OutputPrecedence,
26 |     ) -> anyhow::Result<ArrayView<f32, Dim<IxDynImpl>>> {
27 |         let ort_output: &ort::value::Value = precedence
28 |             .key_precedence()
29 |             .find_map(|key| match key {
30 |                 OutputKey::OnlyOne => self
31 |                     .session_outputs
32 |                     .get(self.session_outputs.keys().nth(0)?),
33 |                 OutputKey::ByOrder(idx) => {
34 |                     let x = self
35 |                         .session_outputs
36 |                         .get(self.session_outputs.keys().nth(*idx)?);
37 |                     x
38 |                 }
39 |                 OutputKey::ByName(name) => self.session_outputs.get(name),
40 |             })
41 |             .ok_or_else(|| {
42 |                 anyhow::Error::msg(format!(
43 |                     "No suitable output found in the session outputs. Available outputs: {:?}",
44 |                     self.session_outputs.keys().collect::<Vec<_>>()
45 |                 ))
46 |             })?;
47 | 
48 |         ort_output
49 |             .try_extract_tensor::<f32>()
50 |             .map_err(anyhow::Error::new)
51 |     }
52 | 
53 |     /// Select the output from the session outputs based on the given precedence and pool it.
54 |     ///
55 |     /// This function will pool the output based on the given pooling option, if any.
56 |     pub fn select_and_pool_output(
57 |         &self,
58 |         precedence: &impl OutputPrecedence,
59 |         pooling_opt: Option<pooling::Pooling>,
60 |     ) -> anyhow::Result<Array2<f32>> {
61 |         let tensor = self.select_output(precedence)?;
62 | 
63 |         // If no pooling is specified, default to CLS so as not to break the existing implementations.
64 |         // TODO: Consider returning the output as is, to support custom models that have a built-in pooling layer:
65 |         //       - [ ] Add a model with built-in pooling to the list of supported models in ``models::text_embedding::models_list``
66 |         //       - [ ] Write a unit test for the new model
67 |         //       - [ ] Update ``pooling::Pooling`` to include a None type
68 |         //       - [ ] Change the line below to return the output as is
69 |         //       - [ ] Release a major version because of the breaking change
70 |         match pooling_opt.unwrap_or_default() {
71 |             pooling::Pooling::Cls => pooling::cls(&tensor),
72 |             pooling::Pooling::Mean => pooling::mean(&tensor, self.attention_mask_array.clone()),
73 |         }
74 |     }
75 | }
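A sketch of per-batch extraction, assuming `batch` is a `SingleBatchOutput` obtained from one inference call:

```rust
use fastembed::{OutputKey, Pooling};

// Prefer `last_hidden_state`, fall back to the first output, then mean-pool.
let precedence: &[OutputKey] = &[
    OutputKey::ByName("last_hidden_state"),
    OutputKey::ByOrder(0),
];
let pooled = batch.select_and_pool_output(&precedence, Some(Pooling::Mean))?;
assert_eq!(pooled.ndim(), 2); // (batch_size, hidden_dim)
```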
76 | 
77 | /// Container struct with all the outputs from the embedding layer.
78 | ///
79 | /// This will contain one [`SingleBatchOutput`] object per batch/inference call.
80 | pub struct EmbeddingOutput<'r, 's> {
81 |     batches: Vec<SingleBatchOutput<'r, 's>>,
82 | }
83 | 
84 | impl<'r, 's> EmbeddingOutput<'r, 's> {
85 |     /// Create a new [`EmbeddingOutput`] from a [`ort::SessionOutputs`] object.
86 |     pub fn new(batches: impl IntoIterator<Item = SingleBatchOutput<'r, 's>>) -> Self {
87 |         Self {
88 |             batches: batches.into_iter().collect(),
89 |         }
90 |     }
91 | 
92 |     /// Consume this [`EmbeddingOutput`] and return the raw session outputs.
93 |     ///
94 |     /// This allows the user to perform their custom extractions outside of this
95 |     /// library.
96 |     pub fn into_raw(self) -> Vec<SingleBatchOutput<'r, 's>> {
97 |         self.batches
98 |     }
99 | 
100 |     /// Export the output using the given output transformer.
101 |     ///
102 |     /// The transformer shall be responsible for:
103 |     /// - selecting the output from the session outputs based on the precedence order,
104 |     /// - extracting the tensor from the output, then
105 |     /// - transforming the tensor into the desired output.
106 |     ///
107 |     /// The transformer function should take a slice of [`SingleBatchOutput`], and return
108 |     /// the desired output type.
109 |     ///
110 |     /// If any of the steps fail, this function will return an error, including when
111 |     /// the session output does not contain the expected precedence keys.
112 |     pub fn export_with_transformer<T>(
113 |         &self,
114 |         // TODO: Convert this to a trait alias when it's stabilized.
115 |         // https://github.com/rust-lang/rust/issues/41517
116 |         transformer: impl Fn(&[SingleBatchOutput]) -> anyhow::Result<T>,
117 |     ) -> anyhow::Result<T> {
118 |         transformer(&self.batches)
119 |     }
120 | }
121 | -------------------------------------------------------------------------------- /src/output/mod.rs: --------------------------------------------------------------------------------
1 | //! Utilities to help with the embeddings output.
2 | //!
3 | //! Typically, [`ort::Session::run`] will generate a [`ort::SessionOutputs`] object.
4 | //! This object contains all the keys and values of the outputs generated by the model,
5 | //! which could all be useful to the caller.
6 | //!
7 | //! This module wraps the [`ort::SessionOutputs`] objects created from batching,
8 | //! and provides a more refined and controlled way to access all the outputs.
9 | //!
10 | //! # Notable structs
11 | //!
12 | //! - [`OutputPrecedence`]: This trait defines the order of precedence for selecting the output
13 | //! from the session outputs. This is simply an iterator of [`OutputKey`].
14 | //! 
- [`OutputKey`]: This enum defines a single way of selecting the output from the session outputs. 15 | //! It could be by order, by name, or the only option available. This can be further 16 | //! extended to include more ways of selecting the output. 17 | //! - [`SingleBatchOutput`]: This struct contains the output of a single batch of inference. It 18 | //! should also include all the necessary information to perform per-batch post-processing such 19 | //! as pooling. 20 | //! - [`EmbeddingOutput`]: This struct wraps the [`ort::SessionOutputs`] objects, acting as a 21 | //! staging area for the raw model outputs. Models that have multiple output types, or 22 | //! have different dimensions as expected can be handled within. 23 | //! 24 | //! It provides [`EmbeddingOutput::export_with_transformer`] which allows the user to 25 | //! provide a custom transformer to extract the output from the [`SingleBatchOutput`] objects. 26 | //! 27 | //! # Implementation 28 | //! 29 | //! Modules which generate text embeddings should each define their own default [`OutputPrecedence`] 30 | //! and public helper functions to create the array transformers. These default implementations 31 | //! allow the current [embed] methods to function as before, but also allow the user more flexibility in 32 | //! extracting the output by using the [transform] method with a custom transformer. 33 | //! 34 | //! [embed]: TextEmbedding::embed 35 | //! [transform]: TextEmbedding::transform 36 | 37 | mod embedding_output; 38 | mod output_precedence; 39 | // mod output_type; 40 | 41 | pub use embedding_output::*; 42 | pub use output_precedence::*; 43 | // pub use output_type::*; 44 | 45 | #[cfg(doc)] 46 | use crate::TextEmbedding; 47 | -------------------------------------------------------------------------------- /src/output/output_precedence.rs: -------------------------------------------------------------------------------- 1 | //! Defines the precedence of the output keys in the session outputs. 2 | //! 3 | //! # Note 4 | //! 5 | //! The purpose of this module is to replicate the existing output key selection mechanism 6 | //! in the library. This is an acceptable solution in lieu of a model-specific solution, 7 | //! e.g. reading the output keys from the model file. 8 | 9 | /// Enum for defining the key of the output. 10 | #[derive(Debug, Clone)] 11 | pub enum OutputKey { 12 | OnlyOne, 13 | ByOrder(usize), 14 | ByName(&'static str), 15 | } 16 | 17 | impl Default for OutputKey { 18 | fn default() -> Self { 19 | Self::OnlyOne 20 | } 21 | } 22 | 23 | /// Trait for defining a precedence of keys in the output. 24 | /// 25 | /// This defines the order of precedence for selecting the output from the session outputs. 26 | /// By convention, an ONNX model will have at least one output called `last_hidden_state`, 27 | /// which is however not guaranteed. This trait allows the user to define the order of 28 | /// precedence for selecting the output. 29 | /// 30 | /// Any [`OutputPrecedence`] should be usable multiple times, and should not consume itself; 31 | /// this is due to use of [`rayon`] parallelism, which means 32 | /// [`OutputPrecedence::key_precedence`] will have to be called once per batch. 33 | pub trait OutputPrecedence { 34 | /// Get the precedence of the keys in the output. 35 | fn key_precedence(&self) -> impl Iterator; 36 | } 37 | 38 | /// Any slices of [`OutputKey`] can be used as an [`OutputPrecedence`]. 
39 | impl OutputPrecedence for &[OutputKey] {
40 |     fn key_precedence(&self) -> impl Iterator<Item = &OutputKey> {
41 |         self.iter()
42 |     }
43 | }
44 | -------------------------------------------------------------------------------- /src/pooling.rs: --------------------------------------------------------------------------------
1 | use ndarray::{s, Array2, ArrayView, Dim, Dimension, IxDynImpl};
2 | 
3 | #[derive(Debug, Clone, PartialEq, Eq)]
4 | pub enum Pooling {
5 |     Cls,
6 |     Mean,
7 | }
8 | 
9 | impl Default for Pooling {
10 |     /// Change this to define the default pooling strategy.
11 |     ///
12 |     /// Currently this is set to [`Self::Cls`] for backward compatibility.
13 |     fn default() -> Self {
14 |         Self::Cls
15 |     }
16 | }
17 | 
18 | pub fn cls(tensor: &ArrayView<f32, Dim<IxDynImpl>>) -> anyhow::Result<Array2<f32>> {
19 |     match tensor.dim().ndim() {
20 |         2 => Ok(tensor.slice(s![.., ..]).to_owned()),
21 |         3 => Ok(tensor.slice(s![.., 0, ..]).to_owned()),
22 |         _ => Err(anyhow::Error::msg(format!(
23 |             "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
24 |             shape = tensor.dim()
25 |         ))),
26 |     }
27 | }
28 | 
29 | /// Pool the previous layer output by taking the element-wise arithmetic mean of the token-level embeddings after applying the attention mask.
30 | /// * `token_embeddings` - token embeddings in form of a tensor output of the encoding.
31 | /// * `attention_mask_array` - is the same mask generated by Tokenizer and used for encoding.
32 | // Please refer to the original python implementation for more details:
33 | // https://github.com/UKPLab/sentence-transformers/blob/c0fc0e8238f7f48a1e92dc90f6f96c86f69f1e02/sentence_transformers/models/Pooling.py#L151
34 | pub fn mean(
35 |     token_embeddings: &ArrayView<f32, Dim<IxDynImpl>>,
36 |     attention_mask_array: Array2<i64>,
37 | ) -> anyhow::Result<Array2<f32>> {
38 |     let attention_mask_original_dim = attention_mask_array.dim();
39 | 
40 |     if token_embeddings.dim().ndim() == 2 {
41 |         // There are no means to speak of if Axis(1) is missing.
42 |         // Typically we'll see a dimension of (batch_size, feature_count) here.
43 |         // It can be assumed that pooling is already done within the model.
44 |         return Ok(token_embeddings.slice(s![.., ..]).to_owned());
45 |     } else if token_embeddings.dim().ndim() != 3 {
46 |         return Err(anyhow::Error::msg(format!(
47 |             "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
48 |             shape = token_embeddings.dim()
49 |         )));
50 |     }
51 | 
52 |     let token_embeddings =
53 |         // If the token_embeddings is 3D, return the whole thing.
54 |         // Using `slice` here to assert the dimension.
55 |         token_embeddings
56 |             .slice(s![.., .., ..]);
57 | 
58 |     // Compute the attention mask
59 |     let attention_mask = attention_mask_array
60 |         .insert_axis(ndarray::Axis(2))
61 |         .broadcast(token_embeddings.dim())
62 |         .unwrap_or_else(|| {
63 |             panic!(
64 |                 "Could not broadcast attention mask from {:?} to {:?}",
65 |                 attention_mask_original_dim,
66 |                 token_embeddings.dim()
67 |             )
68 |         })
69 |         .mapv(|x| x as f32);
70 | 
71 |     let masked_tensor = &attention_mask * &token_embeddings;
72 |     let sum = masked_tensor.sum_axis(ndarray::Axis(1));
73 |     let mask_sum = attention_mask.sum_axis(ndarray::Axis(1));
74 |     // Avoid division by zero for fully-masked rows.
75 |     let mask_sum = mask_sum.mapv(|x| if x == 0f32 { 1.0 } else { x });
76 |     Ok(&sum / &mask_sum)
77 | }
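A small numeric sketch of `mean` (in-module, since `pooling` is not re-exported): with one of two tokens masked out, the pooled vector equals the surviving token's embedding:

```rust
use ndarray::{array, Array2};

// Shape (batch=1, tokens=2, dim=2); the second token is masked out.
let token_embeddings = array![[[1.0f32, 2.0], [9.0, 9.0]]].into_dyn();
let mask: Array2<i64> = array![[1, 0]];
let pooled = mean(&token_embeddings.view(), mask).unwrap();
assert_eq!(pooled, array![[1.0f32, 2.0]]);
```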
78 | -------------------------------------------------------------------------------- /src/reranking/impl.rs: --------------------------------------------------------------------------------
1 | #[cfg(feature = "hf-hub")]
2 | use anyhow::Context;
3 | use anyhow::Result;
4 | use ort::{
5 |     session::{builder::GraphOptimizationLevel, Session},
6 |     value::Value,
7 | };
8 | use std::thread::available_parallelism;
9 | 
10 | #[cfg(feature = "hf-hub")]
11 | use crate::common::load_tokenizer_hf_hub;
12 | use crate::{
13 |     common::load_tokenizer, models::reranking::reranker_model_list, RerankerModel,
14 |     RerankerModelInfo,
15 | };
16 | #[cfg(feature = "hf-hub")]
17 | use hf_hub::{api::sync::ApiBuilder, Cache};
18 | use ndarray::{s, Array};
19 | use rayon::{iter::ParallelIterator, slice::ParallelSlice};
20 | use tokenizers::Tokenizer;
21 | 
22 | #[cfg(feature = "hf-hub")]
23 | use super::RerankInitOptions;
24 | use super::{
25 |     OnnxSource, RerankInitOptionsUserDefined, RerankResult, TextRerank, UserDefinedRerankingModel,
26 |     DEFAULT_BATCH_SIZE,
27 | };
28 | 
29 | impl TextRerank {
30 |     fn new(tokenizer: Tokenizer, session: Session) -> Self {
31 |         let need_token_type_ids = session
32 |             .inputs
33 |             .iter()
34 |             .any(|input| input.name == "token_type_ids");
35 |         Self {
36 |             tokenizer,
37 |             session,
38 |             need_token_type_ids,
39 |         }
40 |     }
41 | 
42 |     pub fn get_model_info(model: &RerankerModel) -> RerankerModelInfo {
43 |         TextRerank::list_supported_models()
44 |             .into_iter()
45 |             .find(|m| &m.model == model)
46 |             .expect("Model not found.")
47 |     }
48 | 
49 |     pub fn list_supported_models() -> Vec<RerankerModelInfo> {
50 |         reranker_model_list()
51 |     }
52 | 
53 |     #[cfg(feature = "hf-hub")]
54 |     pub fn try_new(options: RerankInitOptions) -> Result<Self> {
55 |         use super::RerankInitOptions;
56 | 
57 |         let RerankInitOptions {
58 |             model_name,
59 |             execution_providers,
60 |             max_length,
61 |             cache_dir,
62 |             show_download_progress,
63 |         } = options;
64 | 
65 |         let threads = available_parallelism()?.get();
66 | 
67 |         let cache = Cache::new(cache_dir);
68 |         let api = ApiBuilder::from_cache(cache)
69 |             .with_progress(show_download_progress)
70 |             .build()
71 |             .expect("Failed to build API from cache");
72 |         let model_repo = api.model(model_name.to_string());
73 | 
74 |         let model_file_name = TextRerank::get_model_info(&model_name).model_file;
75 |         let model_file_reference = model_repo.get(&model_file_name).context(format!(
76 |             "Failed to retrieve model file: {}",
77 |             model_file_name
78 |         ))?;
79 |         let additional_files = TextRerank::get_model_info(&model_name).additional_files;
80 |         for additional_file in additional_files {
81 |             let _additional_file_reference = model_repo.get(&additional_file).context(format!(
82 |                 "Failed to retrieve additional file: {}",
83 |                 additional_file
84 |             ))?;
85 |         }
86 | 
87 |         let session = Session::builder()?
88 |             .with_execution_providers(execution_providers)?
89 |             .with_optimization_level(GraphOptimizationLevel::Level3)?
90 |             .with_intra_threads(threads)?
91 |             .commit_from_file(model_file_reference)?;
92 | 
93 |         let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?;
94 |         Ok(Self::new(tokenizer, session))
95 |     }
96 | 
97 |     /// Create a TextRerank instance from model files provided by the user.
98 |     ///
99 |     /// This can be used for 'bring your own' reranking models.
100 |     pub fn try_new_from_user_defined(
101 |         model: UserDefinedRerankingModel,
102 |         options: RerankInitOptionsUserDefined,
103 |     ) -> Result<Self> {
104 |         let RerankInitOptionsUserDefined {
105 |             execution_providers,
106 |             max_length,
107 |         } = options;
108 | 
109 |         let threads = available_parallelism()?.get();
110 | 
111 |         let session = Session::builder()?
112 |             .with_execution_providers(execution_providers)?
113 |             .with_optimization_level(GraphOptimizationLevel::Level3)?
114 |             .with_intra_threads(threads)?;
115 | 
116 |         let session = match &model.onnx_source {
117 |             OnnxSource::Memory(bytes) => session.commit_from_memory(bytes)?,
118 |             OnnxSource::File(path) => session.commit_from_file(path)?,
119 |         };
120 | 
121 |         let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?;
122 |         Ok(Self::new(tokenizer, session))
123 |     }
124 | 
125 |     /// Rerank documents against the query and return the results sorted by score in descending order.
126 |     pub fn rerank<S: AsRef<str> + Send + Sync>(
127 |         &self,
128 |         query: S,
129 |         documents: Vec<S>,
130 |         return_documents: bool,
131 |         batch_size: Option<usize>,
132 |     ) -> Result<Vec<RerankResult>> {
133 |         let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
134 | 
135 |         let q = query.as_ref();
136 | 
137 |         let scores: Vec<f32> = documents
138 |             .par_chunks(batch_size)
139 |             .map(|batch| {
140 |                 let inputs = batch.iter().map(|d| (q, d.as_ref())).collect();
141 | 
142 |                 let encodings = self
143 |                     .tokenizer
144 |                     .encode_batch(inputs, true)
145 |                     .expect("Failed to encode batch");
146 | 
147 |                 let encoding_length = encodings[0].len();
148 |                 let batch_size = batch.len();
149 | 
150 |                 let max_size = encoding_length * batch_size;
151 | 
152 |                 let mut ids_array = Vec::with_capacity(max_size);
153 |                 let mut mask_array = Vec::with_capacity(max_size);
154 |                 let mut type_ids_array = Vec::with_capacity(max_size);
155 | 
156 |                 encodings.iter().for_each(|encoding| {
157 |                     let ids = encoding.get_ids();
158 |                     let mask = encoding.get_attention_mask();
159 |                     let type_ids = encoding.get_type_ids();
160 | 
161 |                     ids_array.extend(ids.iter().map(|x| *x as i64));
162 |                     mask_array.extend(mask.iter().map(|x| *x as i64));
163 |                     type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
164 |                 });
165 | 
166 |                 let inputs_ids_array =
167 |                     Array::from_shape_vec((batch_size, encoding_length), ids_array)?;
168 | 
169 |                 let attention_mask_array =
170 |                     Array::from_shape_vec((batch_size, encoding_length), mask_array)?;
171 | 
172 |                 let token_type_ids_array =
173 |                     Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;
174 | 
175 |                 let mut session_inputs = ort::inputs![
176 |                     "input_ids" => Value::from_array(inputs_ids_array)?,
177 |                     "attention_mask" => Value::from_array(attention_mask_array)?,
178 |                 ]?;
179 | 
180 |                 if self.need_token_type_ids {
181 |                     session_inputs.push((
182 |                         "token_type_ids".into(),
183 |                         Value::from_array(token_type_ids_array)?.into(),
184 |                     ));
185 |                 }
186 | 
187 |                 let outputs = self.session.run(session_inputs)?;
188 | 
189 |                 let outputs = outputs["logits"]
190 |                     .try_extract_tensor::<f32>()
191 |                     .expect("Failed to extract logits tensor");
192 | 
193 |                 let scores: Vec<f32> = outputs
194 |                     .slice(s![.., 0])
195 |                     .rows()
196 |                     .into_iter()
197 |                     .flat_map(|row| row.to_vec())
198 |                     .collect();
199 | 
200 |                 Ok(scores)
201 |             })
202 |             .collect::<Result<Vec<Vec<f32>>>>()?
203 |             .into_iter()
204 |             .flatten()
205 |             .collect();
206 | 
207 |         // Build the results ordered by score in descending order (no binary heap needed).
208 |         let mut top_n_result: Vec<RerankResult> = scores
209 |             .into_iter()
210 |             .enumerate()
211 |             .map(|(index, score)| RerankResult {
212 |                 document: return_documents.then(|| documents[index].as_ref().to_string()),
213 |                 score,
214 |                 index,
215 |             })
216 |             .collect();
217 | 
218 |         top_n_result.sort_by(|a, b| a.score.total_cmp(&b.score).reverse());
219 | 
220 |         Ok(top_n_result)
221 |     }
222 | }
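End-to-end usage sketch (requires the `hf-hub` feature; the query and documents are illustrative):

```rust
use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

let reranker = TextRerank::try_new(RerankInitOptions::new(RerankerModel::BGERerankerBase))?;
let results = reranker.rerank(
    "what is vector search?",
    vec![
        "Vector search ranks documents by embedding similarity.",
        "A sourdough starter needs regular feeding.",
    ],
    true,
    None,
)?;
// Results arrive sorted by score, descending; `index` points back into `documents`.
println!("top hit: {:?}", results.first());
```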
tensor"); 192 | 193 | let scores: Vec = outputs 194 | .slice(s![.., 0]) 195 | .rows() 196 | .into_iter() 197 | .flat_map(|row| row.to_vec()) 198 | .collect(); 199 | 200 | Ok(scores) 201 | }) 202 | .collect::>>()? 203 | .into_iter() 204 | .flatten() 205 | .collect(); 206 | 207 | // Return top_n_result of type Vec ordered by score in descending order, don't use binary heap 208 | let mut top_n_result: Vec = scores 209 | .into_iter() 210 | .enumerate() 211 | .map(|(index, score)| RerankResult { 212 | document: return_documents.then(|| documents[index].as_ref().to_string()), 213 | score, 214 | index, 215 | }) 216 | .collect(); 217 | 218 | top_n_result.sort_by(|a, b| a.score.total_cmp(&b.score).reverse()); 219 | 220 | Ok(top_n_result.to_vec()) 221 | } 222 | } 223 | -------------------------------------------------------------------------------- /src/reranking/init.rs: -------------------------------------------------------------------------------- 1 | use std::path::{Path, PathBuf}; 2 | 3 | use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; 4 | use tokenizers::Tokenizer; 5 | 6 | use crate::{common::get_cache_dir, RerankerModel, TokenizerFiles}; 7 | 8 | use super::{DEFAULT_MAX_LENGTH, DEFAULT_RE_RANKER_MODEL}; 9 | 10 | #[derive(Debug)] 11 | pub struct TextRerank { 12 | pub tokenizer: Tokenizer, 13 | pub(crate) session: Session, 14 | pub(crate) need_token_type_ids: bool, 15 | } 16 | 17 | /// Options for initializing the reranking model 18 | #[derive(Debug, Clone)] 19 | #[non_exhaustive] 20 | pub struct RerankInitOptions { 21 | pub model_name: RerankerModel, 22 | pub execution_providers: Vec, 23 | pub max_length: usize, 24 | pub cache_dir: PathBuf, 25 | pub show_download_progress: bool, 26 | } 27 | 28 | impl RerankInitOptions { 29 | pub fn new(model_name: RerankerModel) -> Self { 30 | Self { 31 | model_name, 32 | ..Default::default() 33 | } 34 | } 35 | 36 | pub fn with_max_length(mut self, max_length: usize) -> Self { 37 | self.max_length = max_length; 38 | self 39 | } 40 | 41 | pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { 42 | self.cache_dir = cache_dir; 43 | self 44 | } 45 | 46 | pub fn with_execution_providers( 47 | mut self, 48 | execution_providers: Vec, 49 | ) -> Self { 50 | self.execution_providers = execution_providers; 51 | self 52 | } 53 | 54 | pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { 55 | self.show_download_progress = show_download_progress; 56 | self 57 | } 58 | } 59 | 60 | impl Default for RerankInitOptions { 61 | fn default() -> Self { 62 | Self { 63 | model_name: DEFAULT_RE_RANKER_MODEL, 64 | execution_providers: Default::default(), 65 | max_length: DEFAULT_MAX_LENGTH, 66 | cache_dir: Path::new(&get_cache_dir()).to_path_buf(), 67 | show_download_progress: true, 68 | } 69 | } 70 | } 71 | 72 | /// Options for initializing UserDefinedRerankerModel 73 | /// 74 | /// Model files are held by the UserDefinedRerankerModel struct 75 | #[derive(Debug, Clone)] 76 | #[non_exhaustive] 77 | pub struct RerankInitOptionsUserDefined { 78 | pub execution_providers: Vec, 79 | pub max_length: usize, 80 | } 81 | 82 | impl Default for RerankInitOptionsUserDefined { 83 | fn default() -> Self { 84 | Self { 85 | execution_providers: Default::default(), 86 | max_length: DEFAULT_MAX_LENGTH, 87 | } 88 | } 89 | } 90 | 91 | /// Convert RerankInitOptions to RerankInitOptionsUserDefined 92 | /// 93 | /// This is useful for when the user wants to use the same options for both the default and user-defined models 94 | impl From for 
94 | impl From<RerankInitOptions> for RerankInitOptionsUserDefined {
95 |     fn from(options: RerankInitOptions) -> Self {
96 |         RerankInitOptionsUserDefined {
97 |             execution_providers: options.execution_providers,
98 |             max_length: options.max_length,
99 |         }
100 |     }
101 | }
102 | 
103 | /// Enum for the source of the onnx file
104 | ///
105 | /// User-defined models can either be in memory or on disk
106 | #[derive(Debug, Clone, PartialEq, Eq)]
107 | pub enum OnnxSource {
108 |     Memory(Vec<u8>),
109 |     File(PathBuf),
110 | }
111 | 
112 | impl From<Vec<u8>> for OnnxSource {
113 |     fn from(bytes: Vec<u8>) -> Self {
114 |         OnnxSource::Memory(bytes)
115 |     }
116 | }
117 | 
118 | impl From<PathBuf> for OnnxSource {
119 |     fn from(path: PathBuf) -> Self {
120 |         OnnxSource::File(path)
121 |     }
122 | }
123 | 
124 | /// Struct for "bring your own" reranking models
125 | ///
126 | /// `onnx_source` holds the model bytes (or a path to the onnx file), and `tokenizer_files` holds the tokenizer files' bytes
127 | #[derive(Debug, Clone, PartialEq, Eq)]
128 | #[non_exhaustive]
129 | pub struct UserDefinedRerankingModel {
130 |     pub onnx_source: OnnxSource,
131 |     pub tokenizer_files: TokenizerFiles,
132 | }
133 | 
134 | impl UserDefinedRerankingModel {
135 |     pub fn new(onnx_source: impl Into<OnnxSource>, tokenizer_files: TokenizerFiles) -> Self {
136 |         Self {
137 |             onnx_source: onnx_source.into(),
138 |             tokenizer_files,
139 |         }
140 |     }
141 | }
142 | 
143 | /// Rerank result.
144 | #[derive(Debug, PartialEq, Clone)]
145 | pub struct RerankResult {
146 |     pub document: Option<String>,
147 |     pub score: f32,
148 |     pub index: usize,
149 | }
150 | -------------------------------------------------------------------------------- /src/reranking/mod.rs: --------------------------------------------------------------------------------
1 | use crate::RerankerModel;
2 | 
3 | const DEFAULT_RE_RANKER_MODEL: RerankerModel = RerankerModel::BGERerankerBase;
4 | const DEFAULT_MAX_LENGTH: usize = 512;
5 | const DEFAULT_BATCH_SIZE: usize = 256;
6 | 
7 | mod init;
8 | pub use init::*;
9 | 
10 | mod r#impl;
11 | -------------------------------------------------------------------------------- /src/sparse_text_embedding/impl.rs: --------------------------------------------------------------------------------
1 | #[cfg(feature = "hf-hub")]
2 | use crate::common::load_tokenizer_hf_hub;
3 | use crate::{
4 |     models::sparse::{models_list, SparseModel},
5 |     ModelInfo, SparseEmbedding,
6 | };
7 | #[cfg(feature = "hf-hub")]
8 | use anyhow::Context;
9 | use anyhow::Result;
10 | #[cfg(feature = "hf-hub")]
11 | use hf_hub::api::sync::ApiRepo;
12 | use ndarray::{Array, ArrayViewD, Axis, CowArray, Dim};
13 | use ort::{session::Session, value::Value};
14 | #[cfg_attr(not(feature = "hf-hub"), allow(unused_imports))]
15 | use rayon::{iter::ParallelIterator, slice::ParallelSlice};
16 | #[cfg(feature = "hf-hub")]
17 | use std::path::PathBuf;
18 | use tokenizers::Tokenizer;
19 | 
20 | #[cfg_attr(not(feature = "hf-hub"), allow(unused_imports))]
21 | use std::thread::available_parallelism;
22 | 
23 | #[cfg(feature = "hf-hub")]
24 | use super::SparseInitOptions;
25 | use super::{SparseTextEmbedding, DEFAULT_BATCH_SIZE};
26 | 
27 | impl SparseTextEmbedding {
28 |     /// Try to generate a new SparseTextEmbedding Instance
29 |     ///
30 |     /// Uses the highest level of Graph optimization
31 |     ///
32 |     /// Uses the total number of CPUs available as the number of intra-threads
33 |     #[cfg(feature = "hf-hub")]
34 |     pub fn try_new(options: SparseInitOptions) -> Result<Self> {
35 |         use super::SparseInitOptions;
36 |         use ort::{session::builder::GraphOptimizationLevel, session::Session};
37 | 
38 |         let SparseInitOptions {
39 |             model_name,
-------------------------------------------------------------------------------- /src/reranking/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::RerankerModel; 2 | 3 | const DEFAULT_RE_RANKER_MODEL: RerankerModel = RerankerModel::BGERerankerBase; 4 | const DEFAULT_MAX_LENGTH: usize = 512; 5 | const DEFAULT_BATCH_SIZE: usize = 256; 6 | 7 | mod init; 8 | pub use init::*; 9 | 10 | mod r#impl; 11 | -------------------------------------------------------------------------------- /src/sparse_text_embedding/impl.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "hf-hub")] 2 | use crate::common::load_tokenizer_hf_hub; 3 | use crate::{ 4 | models::sparse::{models_list, SparseModel}, 5 | ModelInfo, SparseEmbedding, 6 | }; 7 | #[cfg(feature = "hf-hub")] 8 | use anyhow::Context; 9 | use anyhow::Result; 10 | #[cfg(feature = "hf-hub")] 11 | use hf_hub::api::sync::ApiRepo; 12 | use ndarray::{Array, ArrayViewD, Axis, CowArray, Dim}; 13 | use ort::{session::Session, value::Value}; 14 | #[cfg_attr(not(feature = "hf-hub"), allow(unused_imports))] 15 | use rayon::{iter::ParallelIterator, slice::ParallelSlice}; 16 | #[cfg(feature = "hf-hub")] 17 | use std::path::PathBuf; 18 | use tokenizers::Tokenizer; 19 | 20 | #[cfg_attr(not(feature = "hf-hub"), allow(unused_imports))] 21 | use std::thread::available_parallelism; 22 | 23 | #[cfg(feature = "hf-hub")] 24 | use super::SparseInitOptions; 25 | use super::{SparseTextEmbedding, DEFAULT_BATCH_SIZE}; 26 | 27 | impl SparseTextEmbedding { 28 | /// Try to generate a new SparseTextEmbedding instance 29 | /// 30 | /// Uses the highest level of graph optimization 31 | /// 32 | /// Uses the total number of CPUs available as the number of intra-threads 33 | #[cfg(feature = "hf-hub")] 34 | pub fn try_new(options: SparseInitOptions) -> Result<SparseTextEmbedding> { 35 | use super::SparseInitOptions; 36 | use ort::{session::builder::GraphOptimizationLevel, session::Session}; 37 | 38 | let SparseInitOptions { 39 | model_name, 40 | execution_providers, 41 | max_length, 42 | cache_dir, 43 | show_download_progress, 44 | } = options; 45 | 46 | let threads = available_parallelism()?.get(); 47 | 48 | let model_repo = SparseTextEmbedding::retrieve_model( 49 | model_name.clone(), 50 | cache_dir.clone(), 51 | show_download_progress, 52 | )?; 53 | 54 | let model_file_name = SparseTextEmbedding::get_model_info(&model_name).model_file; 55 | let model_file_reference = model_repo 56 | .get(&model_file_name) 57 | .context(format!("Failed to retrieve {}", model_file_name))?; 58 | 59 | let session = Session::builder()? 60 | .with_execution_providers(execution_providers)? 61 | .with_optimization_level(GraphOptimizationLevel::Level3)? 62 | .with_intra_threads(threads)? 63 | .commit_from_file(model_file_reference)?; 64 | 65 | let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?; 66 | Ok(Self::new(tokenizer, session, model_name)) 67 | } 68 | 69 | /// Private method to return an instance 70 | #[cfg_attr(not(feature = "hf-hub"), allow(dead_code))] 71 | fn new(tokenizer: Tokenizer, session: Session, model: SparseModel) -> Self { 72 | let need_token_type_ids = session 73 | .inputs 74 | .iter() 75 | .any(|input| input.name == "token_type_ids"); 76 | Self { 77 | tokenizer, 78 | session, 79 | need_token_type_ids, 80 | model, 81 | } 82 | } 83 | /// Return the SparseTextEmbedding model's directory from cache or remote retrieval 84 | #[cfg(feature = "hf-hub")] 85 | fn retrieve_model( 86 | model: SparseModel, 87 | cache_dir: PathBuf, 88 | show_download_progress: bool, 89 | ) -> Result<ApiRepo> { 90 | use crate::common::pull_from_hf; 91 | 92 | pull_from_hf(model.to_string(), cache_dir, show_download_progress) 93 | } 94 | 95 | /// Retrieve a list of supported models 96 | pub fn list_supported_models() -> Vec<ModelInfo<SparseModel>> { 97 | models_list() 98 | } 99 | 100 | /// Get ModelInfo from SparseModel 101 | pub fn get_model_info(model: &SparseModel) -> ModelInfo<SparseModel> { 102 | SparseTextEmbedding::list_supported_models() 103 | .into_iter() 104 | .find(|m| &m.model == model) 105 | .expect("Model not found.") 106 | } 107 | 108 | /// Method to generate sentence embeddings for a Vec of texts 109 | // Generic type to accept String, &str, Cow<str>, etc. 110 | pub fn embed<S: AsRef<str> + Send + Sync>( 111 | &self, 112 | texts: Vec<S>, 113 | batch_size: Option<usize>, 114 | ) -> Result<Vec<SparseEmbedding>> { 115 | // Determine the batch size, default if not specified 116 | let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); 117 | 118 | let output = texts 119 | .par_chunks(batch_size) 120 | .map(|batch| { 121 | // Encode the texts in the batch 122 | let inputs = batch.iter().map(|text| text.as_ref()).collect(); 123 | let encodings = self.tokenizer.encode_batch(inputs, true).unwrap(); 124 | 125 | // Extract the encoding length and batch size 126 | let encoding_length = encodings[0].len(); 127 | let batch_size = batch.len(); 128 | 129 | let max_size = encoding_length * batch_size; 130 | 131 | // Preallocate arrays with the maximum size 132 | let mut ids_array = Vec::with_capacity(max_size); 133 | let mut mask_array = Vec::with_capacity(max_size); 134 | let mut type_ids_array = Vec::with_capacity(max_size); 135 | 136 | // Not using par_iter because the closure needs to be FnMut 137 | encodings.iter().for_each(|encoding| { 138 | let ids = encoding.get_ids(); 139 | let mask = encoding.get_attention_mask(); 140 | let type_ids = encoding.get_type_ids(); 141 | 142 | // Extend the preallocated arrays with the current encoding 143 | // Requires the closure to be FnMut 144 | ids_array.extend(ids.iter().map(|x| *x as i64));
145 | mask_array.extend(mask.iter().map(|x| *x as i64)); 146 | type_ids_array.extend(type_ids.iter().map(|x| *x as i64)); 147 | }); 148 | 149 | // Create CowArrays from vectors 150 | let inputs_ids_array = 151 | Array::from_shape_vec((batch_size, encoding_length), ids_array)?; 152 | let owned_attention_mask = 153 | Array::from_shape_vec((batch_size, encoding_length), mask_array)?; 154 | let attention_mask_array = CowArray::from(&owned_attention_mask); 155 | 156 | let token_type_ids_array = 157 | Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?; 158 | 159 | let mut session_inputs = ort::inputs![ 160 | "input_ids" => Value::from_array(inputs_ids_array)?, 161 | "attention_mask" => Value::from_array(&attention_mask_array)?, 162 | ]?; 163 | 164 | if self.need_token_type_ids { 165 | session_inputs.push(( 166 | "token_type_ids".into(), 167 | Value::from_array(token_type_ids_array)?.into(), 168 | )); 169 | } 170 | 171 | let outputs = self.session.run(session_inputs)?; 172 | 173 | // Try to get the only output key 174 | // If multiple, then default to `last_hidden_state` 175 | let last_hidden_state_key = match outputs.len() { 176 | 1 => outputs.keys().next().unwrap(), 177 | _ => "last_hidden_state", 178 | }; 179 | 180 | let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?; 181 | 182 | let embeddings = SparseTextEmbedding::post_process( 183 | &self.model, 184 | &output_data, 185 | &attention_mask_array, 186 | ); 187 | 188 | Ok(embeddings) 189 | }) 190 | .collect::<Result<Vec<_>>>()? 191 | .into_iter() 192 | .flatten() 193 | .collect(); 194 | 195 | Ok(output) 196 | } 197 | 198 | fn post_process( 199 | model_name: &SparseModel, 200 | model_output: &ArrayViewD<f32>, 201 | attention_mask: &CowArray<i64, Dim<[usize; 2]>>, 202 | ) -> Vec<SparseEmbedding> { 203 | match model_name { 204 | SparseModel::SPLADEPPV1 => { 205 | // Apply ReLU and logarithm transformation 206 | let relu_log = model_output.mapv(|x| (1.0 + x.max(0.0)).ln()); 207 | 208 | // Convert to f32 and expand the dimensions 209 | let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2)); 210 | 211 | // Weight the transformed values by the attention mask 212 | let weighted_log = relu_log * attention_mask; 213 | 214 | // Get the max scores 215 | let scores = weighted_log.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v)); 216 | 217 | scores 218 | .rows() 219 | .into_iter() 220 | .map(|row_scores| { 221 | let mut values: Vec<f32> = Vec::with_capacity(scores.len()); 222 | let mut indices: Vec<usize> = Vec::with_capacity(scores.len()); 223 | 224 | row_scores.into_iter().enumerate().for_each(|(idx, f)| { 225 | if *f > 0.0 { 226 | values.push(*f); 227 | indices.push(idx); 228 | } 229 | }); 230 | 231 | SparseEmbedding { values, indices } 232 | }) 233 | .collect() 234 | } 235 | } 236 | } 237 | } 238 |
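The post_process above is the SPLADE activation: log(1 + ReLU(x)), masked by attention and max-pooled over the sequence, so only vocabulary positions with positive weight survive into the sparse vector. A short usage sketch under the `hf-hub` feature (the example text is arbitrary):

use fastembed::{SparseInitOptions, SparseTextEmbedding};

fn main() -> anyhow::Result<()> {
    // Downloads the default SPLADE++ model on first use.
    let model = SparseTextEmbedding::try_new(SparseInitOptions::default())?;

    let embeddings = model.embed(vec!["sparse vectors keep only non-zero weights"], None)?;

    // Each SparseEmbedding pairs a vocabulary index with its positive weight.
    for (index, value) in embeddings[0].indices.iter().zip(&embeddings[0].values) {
        println!("token {index}: {value:.4}");
    }
    Ok(())
}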
-------------------------------------------------------------------------------- /src/sparse_text_embedding/init.rs: -------------------------------------------------------------------------------- 1 | use std::path::{Path, PathBuf}; 2 | 3 | use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; 4 | use tokenizers::Tokenizer; 5 | 6 | use crate::{common::get_cache_dir, models::sparse::SparseModel, TokenizerFiles}; 7 | 8 | use super::{DEFAULT_EMBEDDING_MODEL, DEFAULT_MAX_LENGTH}; 9 | 10 | /// Options for initializing the SparseTextEmbedding model 11 | #[derive(Debug, Clone)] 12 | #[non_exhaustive] 13 | pub struct SparseInitOptions { 14 | pub model_name: SparseModel, 15 | pub execution_providers: Vec<ExecutionProviderDispatch>, 16 | pub max_length: usize, 17 | pub cache_dir: PathBuf, 18 | pub show_download_progress: bool, 19 | } 20 | 21 | impl SparseInitOptions { 22 | pub fn new(model_name: SparseModel) -> Self { 23 | Self { 24 | model_name, 25 | ..Default::default() 26 | } 27 | } 28 | 29 | pub fn with_max_length(mut self, max_length: usize) -> Self { 30 | self.max_length = max_length; 31 | self 32 | } 33 | 34 | pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { 35 | self.cache_dir = cache_dir; 36 | self 37 | } 38 | 39 | pub fn with_execution_providers( 40 | mut self, 41 | execution_providers: Vec<ExecutionProviderDispatch>, 42 | ) -> Self { 43 | self.execution_providers = execution_providers; 44 | self 45 | } 46 | 47 | pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { 48 | self.show_download_progress = show_download_progress; 49 | self 50 | } 51 | } 52 | 53 | impl Default for SparseInitOptions { 54 | fn default() -> Self { 55 | Self { 56 | model_name: DEFAULT_EMBEDDING_MODEL, 57 | execution_providers: Default::default(), 58 | max_length: DEFAULT_MAX_LENGTH, 59 | cache_dir: Path::new(&get_cache_dir()).to_path_buf(), 60 | show_download_progress: true, 61 | } 62 | } 63 | } 64 | 65 | /// Struct for "bring your own" embedding models 66 | /// 67 | /// The onnx_file and tokenizer_files are expecting the files' bytes 68 | #[derive(Debug, Clone, PartialEq, Eq)] 69 | #[non_exhaustive] 70 | pub struct UserDefinedSparseModel { 71 | pub onnx_file: Vec<u8>, 72 | pub tokenizer_files: TokenizerFiles, 73 | } 74 | 75 | impl UserDefinedSparseModel { 76 | pub fn new(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> Self { 77 | Self { 78 | onnx_file, 79 | tokenizer_files, 80 | } 81 | } 82 | } 83 | 84 | /// Rust representation of the SparseTextEmbedding model 85 | pub struct SparseTextEmbedding { 86 | pub tokenizer: Tokenizer, 87 | pub(crate) session: Session, 88 | pub(crate) need_token_type_ids: bool, 89 | pub(crate) model: SparseModel, 90 | } 91 | -------------------------------------------------------------------------------- /src/sparse_text_embedding/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::models::sparse::SparseModel; 2 | 3 | const DEFAULT_BATCH_SIZE: usize = 256; 4 | const DEFAULT_MAX_LENGTH: usize = 512; 5 | const DEFAULT_EMBEDDING_MODEL: SparseModel = SparseModel::SPLADEPPV1; 6 | 7 | mod init; 8 | pub use init::*; 9 | 10 | mod r#impl; 11 |
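Because every `with_*` method on `SparseInitOptions` consumes and returns `self`, configuration chains freely from `new`. A sketch, assuming `SparseModel` is re-exported at the crate root like the other model enums:

use std::path::PathBuf;
use fastembed::{SparseInitOptions, SparseModel, SparseTextEmbedding};

fn build() -> anyhow::Result<SparseTextEmbedding> {
    // Fields left unset keep the Default values defined above.
    let options = SparseInitOptions::new(SparseModel::SPLADEPPV1)
        .with_max_length(256)
        .with_cache_dir(PathBuf::from(".fastembed_cache"))
        .with_show_download_progress(false);
    SparseTextEmbedding::try_new(options)
}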
-------------------------------------------------------------------------------- /src/text_embedding/impl.rs: -------------------------------------------------------------------------------- 1 | //! The definition of the main struct for text embeddings - [`TextEmbedding`]. 2 | 3 | #[cfg(feature = "hf-hub")] 4 | use crate::common::load_tokenizer_hf_hub; 5 | use crate::{ 6 | common::load_tokenizer, 7 | models::text_embedding::{get_model_info, models_list}, 8 | pooling::Pooling, 9 | Embedding, EmbeddingModel, EmbeddingOutput, ModelInfo, QuantizationMode, SingleBatchOutput, 10 | }; 11 | #[cfg(feature = "hf-hub")] 12 | use anyhow::Context; 13 | use anyhow::Result; 14 | #[cfg(feature = "hf-hub")] 15 | use hf_hub::api::sync::ApiRepo; 16 | use ndarray::Array; 17 | use ort::{ 18 | session::{builder::GraphOptimizationLevel, Session}, 19 | value::Value, 20 | }; 21 | use rayon::{ 22 | iter::{FromParallelIterator, ParallelIterator}, 23 | slice::ParallelSlice, 24 | }; 25 | #[cfg(feature = "hf-hub")] 26 | use std::path::PathBuf; 27 | use std::thread::available_parallelism; 28 | use tokenizers::Tokenizer; 29 | 30 | #[cfg(feature = "hf-hub")] 31 | use super::InitOptions; 32 | use super::{ 33 | output, InitOptionsUserDefined, TextEmbedding, UserDefinedEmbeddingModel, DEFAULT_BATCH_SIZE, 34 | }; 35 | 36 | impl TextEmbedding { 37 | /// Try to generate a new TextEmbedding instance 38 | /// 39 | /// Uses the highest level of graph optimization 40 | /// 41 | /// Uses the total number of CPUs available as the number of intra-threads 42 | #[cfg(feature = "hf-hub")] 43 | pub fn try_new(options: InitOptions) -> Result<TextEmbedding> { 44 | let InitOptions { 45 | model_name, 46 | execution_providers, 47 | max_length, 48 | cache_dir, 49 | show_download_progress, 50 | } = options; 51 | 52 | let threads = available_parallelism()?.get(); 53 | 54 | let model_repo = TextEmbedding::retrieve_model( 55 | model_name.clone(), 56 | cache_dir.clone(), 57 | show_download_progress, 58 | )?; 59 | 60 | let model_info = TextEmbedding::get_model_info(&model_name)?; 61 | let model_file_name = &model_info.model_file; 62 | let model_file_reference = model_repo 63 | .get(model_file_name) 64 | .context(format!("Failed to retrieve {}", model_file_name))?; 65 | 66 | if !model_info.additional_files.is_empty() { 67 | for file in &model_info.additional_files { 68 | model_repo 69 | .get(file) 70 | .context(format!("Failed to retrieve {}", file))?; 71 | } 72 | } 73 | 74 | // Prioritise loading the pooling config if available; if not (thanks, qdrant!), fall back to the hardcoded defaults 75 | let post_processing = TextEmbedding::get_default_pooling_method(&model_name); 76 | 77 | let session = Session::builder()? 78 | .with_execution_providers(execution_providers)? 79 | .with_optimization_level(GraphOptimizationLevel::Level3)? 80 | .with_intra_threads(threads)? 81 | .commit_from_file(model_file_reference)?; 82 | 83 | let tokenizer = load_tokenizer_hf_hub(model_repo, max_length)?; 84 | Ok(Self::new( 85 | tokenizer, 86 | session, 87 | post_processing, 88 | TextEmbedding::get_quantization_mode(&model_name), 89 | )) 90 | } 91 | 92 | /// Create a TextEmbedding instance from model files provided by the user. 93 | /// 94 | /// This can be used for 'bring your own' embedding models 95 | pub fn try_new_from_user_defined( 96 | model: UserDefinedEmbeddingModel, 97 | options: InitOptionsUserDefined, 98 | ) -> Result<TextEmbedding> { 99 | let InitOptionsUserDefined { 100 | execution_providers, 101 | max_length, 102 | } = options; 103 | 104 | let threads = available_parallelism()?.get(); 105 | 106 | let session = Session::builder()? 107 | .with_execution_providers(execution_providers)? 108 | .with_optimization_level(GraphOptimizationLevel::Level3)? 109 | .with_intra_threads(threads)?
110 | .commit_from_memory(&model.onnx_file)?; 111 | 112 | let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?; 113 | Ok(Self::new( 114 | tokenizer, 115 | session, 116 | model.pooling, 117 | model.quantization, 118 | )) 119 | } 120 | 121 | /// Private method to return an instance 122 | fn new( 123 | tokenizer: Tokenizer, 124 | session: Session, 125 | post_process: Option<Pooling>, 126 | quantization: QuantizationMode, 127 | ) -> Self { 128 | let need_token_type_ids = session 129 | .inputs 130 | .iter() 131 | .any(|input| input.name == "token_type_ids"); 132 | 133 | Self { 134 | tokenizer, 135 | session, 136 | need_token_type_ids, 137 | pooling: post_process, 138 | quantization, 139 | } 140 | } 141 | /// Return the TextEmbedding model's directory from cache or remote retrieval 142 | #[cfg(feature = "hf-hub")] 143 | fn retrieve_model( 144 | model: EmbeddingModel, 145 | cache_dir: PathBuf, 146 | show_download_progress: bool, 147 | ) -> anyhow::Result<ApiRepo> { 148 | use crate::common::pull_from_hf; 149 | 150 | pull_from_hf(model.to_string(), cache_dir, show_download_progress) 151 | } 152 | 153 | pub fn get_default_pooling_method(model_name: &EmbeddingModel) -> Option<Pooling> { 154 | match model_name { 155 | EmbeddingModel::AllMiniLML6V2 => Some(Pooling::Mean), 156 | EmbeddingModel::AllMiniLML6V2Q => Some(Pooling::Mean), 157 | EmbeddingModel::AllMiniLML12V2 => Some(Pooling::Mean), 158 | EmbeddingModel::AllMiniLML12V2Q => Some(Pooling::Mean), 159 | 160 | EmbeddingModel::BGEBaseENV15 => Some(Pooling::Cls), 161 | EmbeddingModel::BGEBaseENV15Q => Some(Pooling::Cls), 162 | EmbeddingModel::BGELargeENV15 => Some(Pooling::Cls), 163 | EmbeddingModel::BGELargeENV15Q => Some(Pooling::Cls), 164 | EmbeddingModel::BGESmallENV15 => Some(Pooling::Cls), 165 | EmbeddingModel::BGESmallENV15Q => Some(Pooling::Cls), 166 | EmbeddingModel::BGESmallZHV15 => Some(Pooling::Cls), 167 | EmbeddingModel::BGELargeZHV15 => Some(Pooling::Cls), 168 | 169 | EmbeddingModel::NomicEmbedTextV1 => Some(Pooling::Mean), 170 | EmbeddingModel::NomicEmbedTextV15 => Some(Pooling::Mean), 171 | EmbeddingModel::NomicEmbedTextV15Q => Some(Pooling::Mean), 172 | 173 | EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean), 174 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean), 175 | EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean), 176 | 177 | EmbeddingModel::ModernBertEmbedLarge => Some(Pooling::Mean), 178 | 179 | EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean), 180 | EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean), 181 | EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean), 182 | 183 | EmbeddingModel::MxbaiEmbedLargeV1 => Some(Pooling::Cls), 184 | EmbeddingModel::MxbaiEmbedLargeV1Q => Some(Pooling::Cls), 185 | 186 | EmbeddingModel::GTEBaseENV15 => Some(Pooling::Cls), 187 | EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls), 188 | EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls), 189 | EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls), 190 | 191 | EmbeddingModel::ClipVitB32 => Some(Pooling::Mean), 192 | 193 | EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean), 194 | } 195 | } 196 |
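get_default_pooling_method above hard-codes whether each model's token embeddings are averaged (Mean) or reduced to the first token (Cls). As a rough illustration of the difference, here is a plain-Vec sketch; the crate itself performs this on ndarray tensors inside SingleBatchOutput, so this is illustrative only:

/// Illustrative only: average one sentence's token embeddings,
/// counting only positions where the attention mask is 1.
fn mean_pool(token_embeddings: &[Vec<f32>], attention_mask: &[i64]) -> Vec<f32> {
    let dim = token_embeddings[0].len();
    let mut pooled = vec![0.0f32; dim];
    let mut count = 0.0f32;
    for (tokens, &mask) in token_embeddings.iter().zip(attention_mask) {
        if mask == 1 {
            for (p, t) in pooled.iter_mut().zip(tokens) {
                *p += t;
            }
            count += 1.0;
        }
    }
    pooled.iter_mut().for_each(|p| *p /= count.max(1.0));
    pooled
}

/// CLS pooling simply keeps the first token's vector.
fn cls_pool(token_embeddings: &[Vec<f32>]) -> Vec<f32> {
    token_embeddings[0].clone()
}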
197 | /// Get the quantization mode of the model. 198 | /// 199 | /// Any models with a `Q` suffix in their name are quantized models. 200 | /// 201 | /// Currently only 6 supported models have dynamic quantization: 202 | /// - Alibaba-NLP/gte-base-en-v1.5 203 | /// - Alibaba-NLP/gte-large-en-v1.5 204 | /// - mixedbread-ai/mxbai-embed-large-v1 205 | /// - nomic-ai/nomic-embed-text-v1.5 206 | /// - Xenova/all-MiniLM-L12-v2 207 | /// - Xenova/all-MiniLM-L6-v2 208 | /// 209 | // TODO: Update this list when more models are added 210 | pub fn get_quantization_mode(model_name: &EmbeddingModel) -> QuantizationMode { 211 | match model_name { 212 | EmbeddingModel::AllMiniLML6V2Q => QuantizationMode::Dynamic, 213 | EmbeddingModel::AllMiniLML12V2Q => QuantizationMode::Dynamic, 214 | EmbeddingModel::BGEBaseENV15Q => QuantizationMode::Static, 215 | EmbeddingModel::BGELargeENV15Q => QuantizationMode::Static, 216 | EmbeddingModel::BGESmallENV15Q => QuantizationMode::Static, 217 | EmbeddingModel::NomicEmbedTextV15Q => QuantizationMode::Dynamic, 218 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => QuantizationMode::Static, 219 | EmbeddingModel::MxbaiEmbedLargeV1Q => QuantizationMode::Dynamic, 220 | EmbeddingModel::GTEBaseENV15Q => QuantizationMode::Dynamic, 221 | EmbeddingModel::GTELargeENV15Q => QuantizationMode::Dynamic, 222 | _ => QuantizationMode::None, 223 | } 224 | } 225 | 226 | /// Retrieve a list of supported models 227 | pub fn list_supported_models() -> Vec<ModelInfo<EmbeddingModel>> { 228 | models_list() 229 | } 230 | 231 | /// Get ModelInfo from EmbeddingModel 232 | pub fn get_model_info(model: &EmbeddingModel) -> Result<&ModelInfo<EmbeddingModel>> { 233 | get_model_info(model).ok_or_else(|| { 234 | anyhow::Error::msg(format!( 235 | "Model {model:?} not found. Please check if the model is supported \ 236 | by the current version." 237 | )) 238 | }) 239 | } 240 |
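Since dynamically quantized models reject batch sizes smaller than the input (see the error in `transform` below), a caller can pick the batch size from the quantization mode up front. A minimal sketch, using a model name from the list above:

use fastembed::{EmbeddingModel, InitOptions, QuantizationMode, TextEmbedding};

fn embed_quantized(texts: Vec<&str>) -> anyhow::Result<Vec<Vec<f32>>> {
    let model_name = EmbeddingModel::AllMiniLML6V2Q;
    let model = TextEmbedding::try_new(InitOptions::new(model_name.clone()))?;

    // Dynamic quantization adjusts the data range per batch, so the whole
    // input must go through as one batch; `None` achieves exactly that.
    let batch_size = match TextEmbedding::get_quantization_mode(&model_name) {
        QuantizationMode::Dynamic => None,
        _ => Some(256),
    };
    model.embed(texts, batch_size)
}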
241 | /// Method to generate an [`ort::SessionOutputs`] wrapped in an [`EmbeddingOutput`] 242 | /// instance, which can be used to extract the embeddings with default or custom 243 | /// methods as well as output key precedence. 244 | /// 245 | /// Metadata that could be useful for creating the array transformer is 246 | /// returned alongside the [`EmbeddingOutput`] instance, such as pooling methods 247 | /// etc. 248 | /// 249 | /// # Note 250 | /// 251 | /// This is a lower level method than [`TextEmbedding::embed`], and is useful 252 | /// when you need to extract the session outputs in a custom way. 253 | /// 254 | /// If you want to extract the embeddings directly, use [`TextEmbedding::embed`]. 255 | /// 256 | /// If you want to use the raw session outputs, use [`EmbeddingOutput::into_raw`] 257 | /// on the output of this method. 258 | /// 259 | /// If you want to choose a different export key or customize the way the batch 260 | /// arrays are aggregated, you can define your own array transformer 261 | /// and use it on [`EmbeddingOutput::export_with_transformer`] to extract the 262 | /// embeddings with your custom output type. 263 | pub fn transform<'e, 'r, 's, S: AsRef<str> + Send + Sync>( 264 | &'e self, 265 | texts: Vec<S>, 266 | batch_size: Option<usize>, 267 | ) -> Result<EmbeddingOutput<'r, 's>> 268 | where 269 | 'e: 'r, 270 | 'e: 's, 271 | { 272 | // Determine the batch size according to the quantization method used. 273 | // Default if not specified 274 | let batch_size = match self.quantization { 275 | QuantizationMode::Dynamic => { 276 | if let Some(batch_size) = batch_size { 277 | if batch_size < texts.len() { 278 | Err(anyhow::Error::msg( 279 | "Dynamic quantization cannot be used with batching. \ 280 | This is due to the dynamic quantization process adjusting \ 281 | the data range to fit each batch, making the embeddings \ 282 | incompatible across batches. Try specifying a batch size \ 283 | of `None`, or use a model with static or no quantization.", 284 | )) 285 | } else { 286 | Ok(texts.len()) 287 | } 288 | } else { 289 | Ok(texts.len()) 290 | } 291 | } 292 | _ => Ok(batch_size.unwrap_or(DEFAULT_BATCH_SIZE)), 293 | }?; 294 | 295 | let batches = Result::<Vec<_>>::from_par_iter(texts.par_chunks(batch_size).map(|batch| { 296 | // Encode the texts in the batch 297 | let inputs = batch.iter().map(|text| text.as_ref()).collect(); 298 | let encodings = self.tokenizer.encode_batch(inputs, true).map_err(|e| { 299 | anyhow::Error::msg(e.to_string()).context("Failed to encode the batch.") 300 | })?; 301 | 302 | // Extract the encoding length and batch size 303 | let encoding_length = encodings[0].len(); 304 | let batch_size = batch.len(); 305 | 306 | let max_size = encoding_length * batch_size; 307 | 308 | // Preallocate arrays with the maximum size 309 | let mut ids_array = Vec::with_capacity(max_size); 310 | let mut mask_array = Vec::with_capacity(max_size); 311 | let mut type_ids_array = Vec::with_capacity(max_size); 312 | 313 | // Not using par_iter because the closure needs to be FnMut 314 | encodings.iter().for_each(|encoding| { 315 | let ids = encoding.get_ids(); 316 | let mask = encoding.get_attention_mask(); 317 | let type_ids = encoding.get_type_ids(); 318 | 319 | // Extend the preallocated arrays with the current encoding 320 | // Requires the closure to be FnMut 321 | ids_array.extend(ids.iter().map(|x| *x as i64)); 322 | mask_array.extend(mask.iter().map(|x| *x as i64)); 323 | type_ids_array.extend(type_ids.iter().map(|x| *x as i64)); 324 | }); 325 | 326 | // Create CowArrays from vectors 327 | let inputs_ids_array = Array::from_shape_vec((batch_size, encoding_length), ids_array)?; 328 | 329 | let attention_mask_array = 330 | Array::from_shape_vec((batch_size, encoding_length), mask_array)?; 331 | 332 | let token_type_ids_array = 333 | Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?; 334 | 335 | let mut session_inputs = ort::inputs![ 336 | "input_ids" => Value::from_array(inputs_ids_array)?, 337 | "attention_mask" => Value::from_array(attention_mask_array.view())?, 338 | ]?; 339 | 340 | if self.need_token_type_ids { 341 | session_inputs.push(( 342 | "token_type_ids".into(), 343 | Value::from_array(token_type_ids_array)?.into(), 344 | )); 345 | } 346 | 347 | Ok( 348 | // Package all the data required for post-processing (e.g. pooling) 349 | // into a SingleBatchOutput struct. 350 | SingleBatchOutput { 351 | session_outputs: self 352 | .session 353 | .run(session_inputs) 354 | .map_err(anyhow::Error::new)?, 355 | attention_mask_array, 356 | }, 357 | ) 358 | }))?; 359 | 360 | Ok(EmbeddingOutput::new(batches)) 361 | } 362 |
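For custom post-processing, `transform` plus `export_with_transformer` reproduces what `embed` does internally. A sketch, assuming the crate-internal `output` module items are in scope as they are here; outside the crate you would supply your own transformer closure instead:

fn embed_lower_level(model: &TextEmbedding, texts: Vec<&str>) -> anyhow::Result<Vec<Embedding>> {
    // Tokenize and run the ONNX session, keeping the raw per-batch outputs.
    let batches = model.transform(texts, None)?;

    // Select the output tensor by precedence, then pool, normalize and concatenate.
    batches.export_with_transformer(output::transformer_with_precedence(
        output::OUTPUT_TYPE_PRECEDENCE,
        Some(Pooling::Mean),
    ))
}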
363 | /// Method to generate sentence embeddings for a Vec of texts. 364 | /// 365 | /// Accepts a [`Vec`] consisting of elements of either [`String`] or &[`str`]; 366 | /// any type that implements [`AsRef<str>`] works. 367 | /// 368 | /// The output is a [`Vec`] of [`Embedding`]s. 369 | /// 370 | /// # Note 371 | /// 372 | /// This is a higher level method than [`TextEmbedding::transform`]; it applies 373 | /// the default output precedence and array transformer for the [`TextEmbedding`] model. 374 | pub fn embed<S: AsRef<str> + Send + Sync>( 375 | &self, 376 | texts: Vec<S>, 377 | batch_size: Option<usize>, 378 | ) -> Result<Vec<Embedding>> { 379 | let batches = self.transform(texts, batch_size)?; 380 | 381 | batches.export_with_transformer(output::transformer_with_precedence( 382 | output::OUTPUT_TYPE_PRECEDENCE, 383 | self.pooling.clone(), 384 | )) 385 | } 386 | } 387 | -------------------------------------------------------------------------------- /src/text_embedding/init.rs: -------------------------------------------------------------------------------- 1 | //! Initialization options for the text embedding models. 2 | //! 3 | 4 | use crate::{ 5 | common::TokenizerFiles, get_cache_dir, pooling::Pooling, EmbeddingModel, QuantizationMode, 6 | }; 7 | use ort::{execution_providers::ExecutionProviderDispatch, session::Session}; 8 | use std::path::{Path, PathBuf}; 9 | use tokenizers::Tokenizer; 10 | 11 | use super::{DEFAULT_EMBEDDING_MODEL, DEFAULT_MAX_LENGTH}; 12 | 13 | /// Options for initializing the TextEmbedding model 14 | #[derive(Debug, Clone)] 15 | #[non_exhaustive] 16 | pub struct InitOptions { 17 | pub model_name: EmbeddingModel, 18 | pub execution_providers: Vec<ExecutionProviderDispatch>, 19 | pub max_length: usize, 20 | pub cache_dir: PathBuf, 21 | pub show_download_progress: bool, 22 | } 23 | 24 | impl InitOptions { 25 | /// Create a new InitOptions with the given model name 26 | pub fn new(model_name: EmbeddingModel) -> Self { 27 | Self { 28 | model_name, 29 | ..Default::default() 30 | } 31 | } 32 | 33 | /// Set the maximum length of the input text 34 | pub fn with_max_length(mut self, max_length: usize) -> Self { 35 | self.max_length = max_length; 36 | self 37 | } 38 | 39 | /// Set the cache directory for the model files 40 | pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { 41 | self.cache_dir = cache_dir; 42 | self 43 | } 44 | 45 | /// Set the execution providers for the model 46 | pub fn with_execution_providers( 47 | mut self, 48 | execution_providers: Vec<ExecutionProviderDispatch>, 49 | ) -> Self { 50 | self.execution_providers = execution_providers; 51 | self 52 | } 53 | 54 | /// Set whether to show download progress 55 | pub fn with_show_download_progress(mut self, show_download_progress: bool) -> Self { 56 | self.show_download_progress = show_download_progress; 57 | self 58 | } 59 | } 60 | 61 | impl Default for InitOptions { 62 | fn default() -> Self { 63 | Self { 64 | model_name: DEFAULT_EMBEDDING_MODEL, 65 | execution_providers: Default::default(), 66 | max_length: DEFAULT_MAX_LENGTH, 67 | cache_dir: Path::new(&get_cache_dir()).to_path_buf(), 68 | show_download_progress: true, 69 | } 70 | } 71 | } 72 | 73 | /// Options for initializing UserDefinedEmbeddingModel 74 | /// 75 | /// Model files are held by the UserDefinedEmbeddingModel struct 76 | #[derive(Debug, Clone)] 77 | #[non_exhaustive] 78 | pub struct InitOptionsUserDefined { 79 | pub execution_providers: Vec<ExecutionProviderDispatch>, 80 | pub max_length: usize, 81 | } 82 | 83 | impl InitOptionsUserDefined { 84 | pub fn new() -> Self { 85 | Self { 86 | ..Default::default() 87 | } 88 | } 89 | 90 | pub fn with_execution_providers( 91 | mut self, 92 | execution_providers: Vec<ExecutionProviderDispatch>, 93 | ) -> Self { 94 | self.execution_providers = execution_providers; 95 | self 96 | } 97 | 98 | pub fn with_max_length(mut self, max_length: usize) -> Self { 99 | self.max_length = max_length; 100 | self 101 | } 102 | } 103 | 104 | impl Default for InitOptionsUserDefined { 105 | fn default() -> Self { 106 | Self { 107 | execution_providers: Default::default(), 108 | max_length: DEFAULT_MAX_LENGTH, 109 | } 110 | } 111 | } 112 |
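The builder methods above chain in the same way the crate's tests use them. A minimal sketch; the cache directory shown is just the crate's default name:

use std::path::PathBuf;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

fn build_text_model() -> anyhow::Result<TextEmbedding> {
    // Fields left unset keep the Default values defined above.
    let options = InitOptions::new(EmbeddingModel::BGESmallENV15)
        .with_max_length(384)
        .with_cache_dir(PathBuf::from(".fastembed_cache"))
        .with_show_download_progress(false);
    TextEmbedding::try_new(options)
}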
113 | /// Convert InitOptions to InitOptionsUserDefined 114 | /// 115 | /// This is useful when the user wants to use the same options for both the default and user-defined models 116 | impl From<InitOptions> for InitOptionsUserDefined { 117 | fn from(options: InitOptions) -> Self { 118 | InitOptionsUserDefined { 119 | execution_providers: options.execution_providers, 120 | max_length: options.max_length, 121 | } 122 | } 123 | } 124 | 125 | /// Struct for "bring your own" embedding models 126 | /// 127 | /// The onnx_file and tokenizer_files are expecting the files' bytes 128 | #[derive(Debug, Clone, PartialEq, Eq)] 129 | #[non_exhaustive] 130 | pub struct UserDefinedEmbeddingModel { 131 | pub onnx_file: Vec<u8>, 132 | pub tokenizer_files: TokenizerFiles, 133 | pub pooling: Option<Pooling>, 134 | pub quantization: QuantizationMode, 135 | } 136 | 137 | impl UserDefinedEmbeddingModel { 138 | pub fn new(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> Self { 139 | Self { 140 | onnx_file, 141 | tokenizer_files, 142 | quantization: QuantizationMode::None, 143 | pooling: None, 144 | } 145 | } 146 | 147 | pub fn with_quantization(mut self, quantization: QuantizationMode) -> Self { 148 | self.quantization = quantization; 149 | self 150 | } 151 | 152 | pub fn with_pooling(mut self, pooling: Pooling) -> Self { 153 | self.pooling = Some(pooling); 154 | self 155 | } 156 | } 157 | 158 | /// Rust representation of the TextEmbedding model 159 | pub struct TextEmbedding { 160 | pub tokenizer: Tokenizer, 161 | pub(crate) pooling: Option<Pooling>, 162 | pub(crate) session: Session, 163 | pub(crate) need_token_type_ids: bool, 164 | pub(crate) quantization: QuantizationMode, 165 | } 166 | -------------------------------------------------------------------------------- /src/text_embedding/mod.rs: -------------------------------------------------------------------------------- 1 | //! Text embedding module, containing the main struct [TextEmbedding] and its 2 | //! initialization options. 3 | 4 | use crate::models::text_embedding::EmbeddingModel; 5 | 6 | // Constants. 7 | const DEFAULT_BATCH_SIZE: usize = 256; 8 | const DEFAULT_MAX_LENGTH: usize = 512; 9 | const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15; 10 | 11 | // Output precedence and transforming functions. 12 | pub mod output; 13 | 14 | // Initialization options. 15 | mod init; 16 | pub use init::*; 17 | 18 | // The implementation of the embedding models. 19 | mod r#impl; 20 |
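`UserDefinedEmbeddingModel` defaults to no pooling and no quantization, so both usually need to be set explicitly for an exported model. A sketch mirroring how the crate's own tests wire this up; obtaining the onnx and tokenizer bytes is left to the caller:

use fastembed::{
    InitOptionsUserDefined, Pooling, QuantizationMode, TextEmbedding, TokenizerFiles,
    UserDefinedEmbeddingModel,
};

fn bring_your_own(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> anyhow::Result<TextEmbedding> {
    // Pooling defaults to None and quantization to QuantizationMode::None,
    // so declare what the exported model actually needs.
    let model = UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files)
        .with_pooling(Pooling::Mean)
        .with_quantization(QuantizationMode::None);
    TextEmbedding::try_new_from_user_defined(model, InitOptionsUserDefined::default())
}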
-------------------------------------------------------------------------------- /src/text_embedding/output.rs: -------------------------------------------------------------------------------- 1 | //! Output types and functions for the [`TextEmbedding`] model. 2 | //! 3 | use crate::{ 4 | common::{normalize, Embedding}, 5 | output::{OutputKey, OutputPrecedence, SingleBatchOutput}, 6 | pooling::Pooling, 7 | }; 8 | 9 | #[cfg(doc)] 10 | use super::TextEmbedding; 11 | 12 | /// The default output precedence for the TextEmbedding model. 13 | pub const OUTPUT_TYPE_PRECEDENCE: &[OutputKey] = &[ 14 | OutputKey::OnlyOne, 15 | OutputKey::ByName("last_hidden_state"), 16 | OutputKey::ByName("sentence_embedding"), 17 | // Better not to expose this unless the user explicitly asks for it. 18 | // OutputKey::ByName("token_embeddings"), 19 | ]; 20 | 21 | /// Generates the default array transformer for the [`TextEmbedding`] model using the 22 | /// provided output precedence. 23 | /// 24 | // TODO (denwong47): now that pooling is done in SingleBatchOutput, it is possible that 25 | // all the models will use this same generic transformer. Move this into SingleBatchOutput? 26 | #[allow(unused_variables)] 27 | pub fn transformer_with_precedence( 28 | output_precedence: impl OutputPrecedence, 29 | pooling: Option<Pooling>, 30 | ) -> impl Fn(&[SingleBatchOutput]) -> anyhow::Result<Vec<Embedding>> { 31 | move |batches| { 32 | // Not using `par_iter` here: the operations here are probably not 33 | // computationally expensive enough to warrant the spin-up costs of the threads. 34 | batches 35 | .iter() 36 | .map(|batch| { 37 | batch 38 | .select_and_pool_output(&output_precedence, pooling.clone()) 39 | .map(|array| { 40 | array 41 | .rows() 42 | .into_iter() 43 | .map(|row| normalize(row.as_slice().unwrap())) 44 | .collect::<Vec<_>>() 45 | }) 46 | }) 47 | .try_fold(Vec::new(), |mut acc, res| { 48 | acc.extend(res?); 49 | Ok(acc) 50 | }) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /tests/assets/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anush008/fastembed-rs/6dd82151820847417f6f4696c7af005787c5d363/tests/assets/image_0.png -------------------------------------------------------------------------------- /tests/assets/image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anush008/fastembed-rs/6dd82151820847417f6f4696c7af005787c5d363/tests/assets/image_1.png -------------------------------------------------------------------------------- /tests/assets/sample_text.txt: -------------------------------------------------------------------------------- 1 | animals environment general health health general weight philosophy ethics Being vegetarian helps the environment Becoming a vegetarian is an environmentally friendly thing to do. Modern farming is one of the main sources of pollution in our rivers. Beef farming is one of the main causes of deforestation, and as long as people continue to buy fast food in their billions, there will be a financial incentive to continue cutting down trees to make room for cattle. Because of our desire to eat fish, our rivers and seas are being emptied of fish and many species are facing extinction. Energy resources are used up much more greedily by meat farming than my farming cereals, pulses etc. Eating meat and fish not only causes cruelty to animals, it causes serious harm to the environment and to biodiversity. For example consider Meat production related pollution and deforestation At Toronto’s 1992 Royal Agricultural Winter Fair, Agriculture Canada displayed two contrasting statistics: “it takes four football fields of land (about 1.6 hectares) to feed each Canadian” and “one apple tree produces enough fruit to make 320 pies.” Think about it — a couple of apple trees and a few rows of wheat on a mere fraction of a hectare could produce enough food for one person! [1] The 2006 U.N. Food and Agriculture Organization (FAO) report concluded that worldwide livestock farming generates 18% of the planet's greenhouse gas emissions — by comparison, all the world's cars, trains, planes and boats account for a combined 13% of greenhouse gas emissions. [2] As a result of the above point producing meat damages the environment. The demand for meat drives deforestation.
Daniel Cesar Avelino of Brazil's Federal Public Prosecution Office says “We know that the single biggest driver of deforestation in the Amazon is cattle.” This clearing of tropical rainforests such as the Amazon for agriculture is estimated to produce 17% of the world's greenhouse gas emissions. [3] Not only this but the production of meat takes a lot more energy than it ultimately gives us chicken meat production consumes energy in a 4:1 ratio to protein output; beef cattle production requires an energy input to protein output ratio of 54:1. The same is true with water use due to the same phenomenon of meat being inefficient to produce in terms of the amount of grain needed to produce the same weight of meat, production requires a lot of water. Water is another scarce resource that we will soon not have enough of in various areas of the globe. Grain-fed beef production takes 100,000 liters of water for every kilogram of food. Raising broiler chickens takes 3,500 liters of water to make a kilogram of meat. In comparison, soybean production uses 2,000 liters for kilogram of food produced; rice, 1,912; wheat, 900; and potatoes, 500 liters. [4] This is while there are areas of the globe that have severe water shortages. With farming using up to 70 times more water than is used for domestic purposes: cooking and washing. A third of the population of the world is already suffering from a shortage of water. [5] Groundwater levels are falling all over the world and rivers are beginning to dry up. Already some of the biggest rivers such as China’s Yellow river do not reach the sea. [6] With a rising population becoming vegetarian is the only responsible way to eat. [1] Stephen Leckie, ‘How Meat-centred Eating Patterns Affect Food Security and the Environment’, International development research center [2] Bryan Walsh, Meat: Making Global Warming Worse, Time magazine, 10 September 2008 . [3] David Adam, Supermarket suppliers ‘helping to destroy Amazon rainforest’, The Guardian, 21st June 2009. [4] Roger Segelken, U.S. could feed 800 million people with grain that livestock eat, Cornell Science News, 7th August 1997. [5] Fiona Harvey, Water scarcity affects one in three, FT.com, 21st August 2003 [6] Rupert Wingfield-Hayes, Yellow river ‘drying up’, BBC News, 29th July 2004 -------------------------------------------------------------------------------- /tests/embeddings.rs: -------------------------------------------------------------------------------- 1 | #![cfg(feature = "hf-hub")] 2 | 3 | use std::fs; 4 | use std::path::Path; 5 | 6 | use hf_hub::Repo; 7 | use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; 8 | 9 | use fastembed::{ 10 | get_cache_dir, read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding, 11 | ImageEmbeddingModel, ImageInitOptions, InitOptions, InitOptionsUserDefined, ModelInfo, 12 | OnnxSource, Pooling, QuantizationMode, RerankInitOptions, RerankInitOptionsUserDefined, 13 | RerankerModel, RerankerModelInfo, SparseInitOptions, SparseTextEmbedding, TextEmbedding, 14 | TextRerank, TokenizerFiles, UserDefinedEmbeddingModel, UserDefinedRerankingModel, 15 | }; 16 | 17 | /// A small epsilon value for floating point comparisons. 18 | const EPS: f32 = 1e-2; 19 | 20 | /// Precalculated embeddings for the supported models using #99 21 | /// (4f09b6842ce1fcfaf6362678afcad9a176e05304). 22 | /// 23 | /// These are the sum of all embedding values for each document. 
While not 24 | /// perfect, they should be good enough to verify that the embeddings are being 25 | /// generated correctly. 26 | /// 27 | /// If you have just inserted a new `EmbeddingModel` variant, please update the 28 | /// expected embeddings. 29 | /// 30 | /// # Returns 31 | /// 32 | /// If the embeddings are correct, this function returns `Ok(())`. If there are 33 | /// any mismatches, it returns `Err(Vec<usize>)` with the indices of the 34 | /// mismatched embeddings. 35 | #[allow(unreachable_patterns)] 36 | fn verify_embeddings(model: &EmbeddingModel, embeddings: &[Embedding]) -> Result<(), Vec<usize>> { 37 | let expected = match model { 38 | EmbeddingModel::AllMiniLML12V2 => [-0.12147753, 0.30144796, -0.06882502, -0.6303331], 39 | EmbeddingModel::AllMiniLML12V2Q => [-0.07808663, 0.27919534, -0.0770612, -0.75660324], 40 | EmbeddingModel::AllMiniLML6V2 => [0.59605527, 0.36542925, -0.16450031, -0.40903988], 41 | EmbeddingModel::AllMiniLML6V2Q => [0.5677276, 0.40180072, -0.15454668, -0.4672576], 42 | EmbeddingModel::BGEBaseENV15 => [-0.51290065, -0.4844747, -0.53036124, -0.5337459], 43 | EmbeddingModel::BGEBaseENV15Q => [-0.5130697, -0.48461288, -0.53067875, -0.5337806], 44 | EmbeddingModel::BGELargeENV15 => [-0.19347441, -0.28394595, -0.1549195, -0.22201893], 45 | EmbeddingModel::BGELargeENV15Q => [-0.19366685, -0.2842059, -0.15471499, -0.22216901], 46 | EmbeddingModel::BGESmallENV15 => [0.09881669, 0.15151203, 0.12057499, 0.13641948], 47 | EmbeddingModel::BGESmallENV15Q => [0.09881936, 0.15154803, 0.12057378, 0.13639033], 48 | EmbeddingModel::BGESmallZHV15 => [-1.1194772, -1.0928253, -1.0325904, -1.0050416], 49 | EmbeddingModel::BGELargeZHV15 => [-0.62066114, -0.76666945, -0.7013123, -0.86202735], 50 | EmbeddingModel::GTEBaseENV15 => [-1.6900877, -1.7148916, -1.7333382, -1.5121834], 51 | EmbeddingModel::GTEBaseENV15Q => [-1.7032102, -1.7076654, -1.729326, -1.5317788], 52 | EmbeddingModel::GTELargeENV15 => [-1.6457459, -1.6582386, -1.6809471, -1.6070237], 53 | EmbeddingModel::GTELargeENV15Q => [-1.6044945, -1.6469251, -1.6828246, -1.6265479], 54 | EmbeddingModel::ModernBertEmbedLarge => [0.24799639, 0.32174295, 0.17255782, 0.32919246], 55 | EmbeddingModel::MultilingualE5Base => [-0.057211064, -0.14287914, -0.071678676, -0.17549144], 56 | EmbeddingModel::MultilingualE5Large => [-0.7473163, -0.76040405, -0.7537941, -0.72920954], 57 | EmbeddingModel::MultilingualE5Small => [-0.2640718, -0.13929011, -0.08091972, -0.12388548], 58 | EmbeddingModel::MxbaiEmbedLargeV1 => [-0.2032495, -0.29803938, -0.15803768, -0.23155808], 59 | EmbeddingModel::MxbaiEmbedLargeV1Q => [-0.1811538, -0.2884392, -0.1636593, -0.21548103], 60 | EmbeddingModel::NomicEmbedTextV1 => [0.13788113, 0.10750078, 0.050809078, 0.09284662], 61 | EmbeddingModel::NomicEmbedTextV15 => [0.1932303, 0.13795732, 0.14700879, 0.14940643], 62 | EmbeddingModel::NomicEmbedTextV15Q => [0.20999804, 0.17161125, 0.14427708, 0.19436662], 63 | EmbeddingModel::ParaphraseMLMiniLML12V2 => [-0.07795018, -0.059113946, -0.043668486, -0.1880083], 64 | EmbeddingModel::ParaphraseMLMiniLML12V2Q => [-0.07749095, -0.058981877, -0.043487836, -0.18775631], 65 | EmbeddingModel::ParaphraseMLMpnetBaseV2 => [0.39132136, 0.49490625, 0.65497226, 0.34237382], 66 | EmbeddingModel::ClipVitB32 => [0.7057363, 1.3549932, 0.46823958, 0.52351093], 67 | EmbeddingModel::JinaEmbeddingsV2BaseCode => [-0.31383067, -0.3758629, -0.24878195, -0.35373706], 68 | _ => panic!("Model {model} not found. If you have just inserted this `EmbeddingModel` variant, please update the expected embeddings."), 69 | }; 70 | 71 | let mismatched_indices = embeddings 72 | .iter() 73 | .map(|embedding| embedding.iter().sum::<f32>()) 74 | .zip(expected.iter()) 75 | .enumerate() 76 | .filter_map(|(i, (sum, &expected))| { 77 | if (sum - expected).abs() > EPS { 78 | eprintln!( 79 | "Mismatched embeddings for model {model} at index {i}: {sum} != {expected}", 80 | model = model, 81 | i = i, 82 | sum = sum, 83 | expected = expected 84 | ); 85 | Some(i) 86 | } else { 87 | None 88 | } 89 | }) 90 | .collect::<Vec<_>>(); 91 | 92 | if mismatched_indices.is_empty() { 93 | Ok(()) 94 | } else { 95 | Err(mismatched_indices) 96 | } 97 | } 98 |
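When a new `EmbeddingModel` variant is added, the expected checksums above have to be regenerated. One way to do that, a small hypothetical helper rather than anything in the test suite:

// Print the per-document checksum (sum of all embedding values) so the
// resulting numbers can be pasted into the `expected` match above.
fn print_checksums(embeddings: &[Vec<f32>]) {
    for (i, embedding) in embeddings.iter().enumerate() {
        let sum: f32 = embedding.iter().sum();
        println!("document {i}: {sum}");
    }
}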
99 | macro_rules! create_embeddings_test { 100 | ( 101 | name: $name:ident, 102 | batch_size: $batch_size:expr, 103 | ) => { 104 | #[test] 105 | fn $name() { 106 | TextEmbedding::list_supported_models() 107 | .par_iter() 108 | .for_each(|supported_model| { 109 | let model: TextEmbedding = TextEmbedding::try_new(InitOptions::new(supported_model.model.clone())) 110 | .unwrap(); 111 | 112 | let documents = vec![ 113 | "Hello, World!", 114 | "This is an example passage.", 115 | "fastembed-rs is licensed under Apache-2.0", 116 | "Some other short text here blah blah blah", 117 | ]; 118 | 119 | // Generate embeddings with the configured batch size (None = default of 256) 120 | let batch_size = $batch_size; 121 | let embeddings = model.embed(documents.clone(), batch_size); 122 | 123 | if matches!( 124 | (batch_size, TextEmbedding::get_quantization_mode(&supported_model.model)), 125 | (Some(n), QuantizationMode::Dynamic) if n < documents.len() 126 | ) { 127 | // For Dynamic quantization, the batch size must be greater than or equal to the number of documents 128 | // Otherwise, an error is expected 129 | assert!(embeddings.is_err(), "Expected error for batch size < document count for {model} using dynamic quantization.", model=supported_model.model); 130 | } else { 131 | let embeddings = embeddings.unwrap_or_else( 132 | |exc| panic!("Expected embeddings for {model} to be generated successfully: {exc}", model=supported_model.model, exc=exc), 133 | ); 134 | assert_eq!(embeddings.len(), documents.len()); 135 | 136 | for embedding in &embeddings { 137 | assert_eq!(embedding.len(), supported_model.dim); 138 | } 139 | 140 | match verify_embeddings(&supported_model.model, &embeddings) { 141 | Ok(_) => {} 142 | Err(mismatched_indices) => { 143 | panic!( 144 | "Mismatched embeddings for model {model}: {sentences:?}", 145 | model = supported_model.model, 146 | sentences = &mismatched_indices 147 | .iter() 148 | .map(|&i| documents[i]) 149 | .collect::<Vec<_>>() 150 | ); 151 | } 152 | } 153 | } 154 | }); 155 | } 156 | 157 | }; 158 | } 159 | 160 | create_embeddings_test!( 161 | name: test_batch_size_default, 162 | batch_size: None, 163 | ); 164 | 165 | create_embeddings_test!( 166 | name: test_with_batch_size, 167 | batch_size: Some(70), 168 | ); 169 | 170 | #[test] 171 | fn test_sparse_embeddings() { 172 | SparseTextEmbedding::list_supported_models() 173 | .par_iter() 174 | .for_each(|supported_model| { 175 | let model: SparseTextEmbedding = 176 | SparseTextEmbedding::try_new(SparseInitOptions::new(supported_model.model.clone())) 177 | .unwrap(); 178 | 179 | let documents = vec![ 180 | "Hello, World!", 181 | "This is an example passage.", 182 | "fastembed-rs is licensed under Apache-2.0", 183 | "Some other short text here blah blah blah", 184 | ]; 185 | 186 | // Generate embeddings with the default batch size, 256 187 | let embeddings =
model.embed(documents.clone(), None).unwrap(); 188 | 189 | assert_eq!(embeddings.len(), documents.len()); 190 | embeddings.into_iter().for_each(|embedding| { 191 | assert!(embedding.values.iter().all(|&v| v > 0.0)); 192 | assert!(embedding.indices.len() < 100); 193 | assert_eq!(embedding.indices.len(), embedding.values.len()); 194 | }); 195 | 196 | // Clear the model cache to avoid running out of space on GitHub Actions. 197 | if std::env::var("CI").is_ok() { 198 | clean_cache(supported_model.model_code.clone()) 199 | } 200 | }); 201 | } 202 | 203 | #[test] 204 | fn test_user_defined_embedding_model() { 205 | // Constitute the model in order to ensure it's downloaded and cached 206 | let test_model_info = TextEmbedding::get_model_info(&EmbeddingModel::AllMiniLML6V2).unwrap(); 207 | 208 | TextEmbedding::try_new(InitOptions::new(test_model_info.model.clone())).unwrap(); 209 | 210 | // Get the directory of the model 211 | let model_name = test_model_info.model_code.replace('/', "--"); 212 | let model_dir = Path::new(&get_cache_dir()).join(format!("models--{}", model_name)); 213 | 214 | // Find the "snapshots" sub-directory 215 | let snapshots_dir = model_dir.join("snapshots"); 216 | 217 | // Get the first sub-directory in snapshots 218 | let model_files_dir = snapshots_dir 219 | .read_dir() 220 | .unwrap() 221 | .next() 222 | .unwrap() 223 | .unwrap() 224 | .path(); 225 | 226 | // Find the onnx file - it will be any file ending with .onnx 227 | let onnx_file = read_file_to_bytes( 228 | &model_files_dir 229 | .read_dir() 230 | .unwrap() 231 | .find(|entry| { 232 | entry 233 | .as_ref() 234 | .unwrap() 235 | .path() 236 | .extension() 237 | .unwrap() 238 | .to_str() 239 | .unwrap() 240 | == "onnx" 241 | }) 242 | .unwrap() 243 | .unwrap() 244 | .path(), 245 | ) 246 | .expect("Could not read onnx file"); 247 | 248 | // Load the tokenizer files 249 | let tokenizer_files = TokenizerFiles { 250 | tokenizer_file: read_file_to_bytes(&model_files_dir.join("tokenizer.json")) 251 | .expect("Could not read tokenizer.json"), 252 | config_file: read_file_to_bytes(&model_files_dir.join("config.json")) 253 | .expect("Could not read config.json"), 254 | special_tokens_map_file: read_file_to_bytes( 255 | &model_files_dir.join("special_tokens_map.json"), 256 | ) 257 | .expect("Could not read special_tokens_map.json"), 258 | tokenizer_config_file: read_file_to_bytes(&model_files_dir.join("tokenizer_config.json")) 259 | .expect("Could not read tokenizer_config.json"), 260 | }; 261 | // Create a UserDefinedEmbeddingModel 262 | let user_defined_model = 263 | UserDefinedEmbeddingModel::new(onnx_file, tokenizer_files).with_pooling(Pooling::Mean); 264 | 265 | // Try creating a TextEmbedding instance from the user-defined model 266 | let user_defined_text_embedding = TextEmbedding::try_new_from_user_defined( 267 | user_defined_model, 268 | InitOptionsUserDefined::default(), 269 | ) 270 | .unwrap(); 271 | 272 | let documents = vec![ 273 | "Hello, World!", 274 | "This is an example passage.", 275 | "fastembed-rs is licensed under Apache-2.0", 276 | "Some other short text here blah blah blah", 277 | ]; 278 | 279 | // Generate embeddings over documents 280 | let embeddings = user_defined_text_embedding 281 | .embed(documents.clone(), None) 282 | .unwrap(); 283 | assert_eq!(embeddings.len(), documents.len()); 284 | for embedding in embeddings { 285 | assert_eq!(embedding.len(), test_model_info.dim); 286 | } 287 | } 288 | 289 | #[test] 290 | fn test_rerank() { 291 | let test_one_model = |supported_model: &RerankerModelInfo| { 
292 | println!("supported_model: {:?}", supported_model); 293 | 294 | let result = 295 | TextRerank::try_new(RerankInitOptions::new(supported_model.model.clone())).unwrap(); 296 | 297 | let documents = vec![ 298 | "hi", 299 | "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.", 300 | "panda is an animal", 301 | "i dont know", 302 | "kind of mammal", 303 | ]; 304 | 305 | let results = result 306 | .rerank("what is panda?", documents.clone(), true, None) 307 | .unwrap(); 308 | 309 | assert_eq!( 310 | results.len(), 311 | documents.len(), 312 | "rerank model {:?} failed", 313 | supported_model 314 | ); 315 | 316 | let option_a = "panda is an animal"; 317 | let option_b = "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China."; 318 | 319 | assert!( 320 | results[0].document.as_ref().unwrap() == option_a 321 | || results[0].document.as_ref().unwrap() == option_b 322 | ); 323 | assert!( 324 | results[1].document.as_ref().unwrap() == option_a 325 | || results[1].document.as_ref().unwrap() == option_b 326 | ); 327 | assert_ne!( 328 | results[0].document, results[1].document, 329 | "The top two results should be different" 330 | ); 331 | 332 | // Clear the model cache to avoid running out of space on GitHub Actions. 333 | clean_cache(supported_model.model_code.clone()) 334 | }; 335 | TextRerank::list_supported_models() 336 | .par_iter() 337 | .for_each(test_one_model); 338 | } 339 | 340 | #[ignore] 341 | #[test] 342 | fn test_user_defined_reranking_large_model() { 343 | // Setup model to download from Hugging Face 344 | let cache = hf_hub::Cache::new(std::path::PathBuf::from(&fastembed::get_cache_dir())); 345 | let api = hf_hub::api::sync::ApiBuilder::from_cache(cache) 346 | .with_progress(true) 347 | .build() 348 | .expect("Failed to build API from cache"); 349 | let model_repo = api.model("rozgo/bge-reranker-v2-m3".to_string()); 350 | 351 | // Download the onnx model file 352 | let onnx_file = model_repo.download("model.onnx").unwrap(); 353 | // Onnx model exceeds the limit of 2GB for a file, so we need to download the data file separately 354 | let _onnx_data_file = model_repo.get("model.onnx.data").unwrap(); 355 | 356 | // OnnxSource::File is used to load the onnx file using onnx session builder commit_from_file 357 | let onnx_source = OnnxSource::File(onnx_file); 358 | 359 | // Load the tokenizer files 360 | let tokenizer_files: TokenizerFiles = TokenizerFiles { 361 | tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json").unwrap()).unwrap(), 362 | config_file: read_file_to_bytes(&model_repo.get("config.json").unwrap()).unwrap(), 363 | special_tokens_map_file: read_file_to_bytes( 364 | &model_repo.get("special_tokens_map.json").unwrap(), 365 | ) 366 | .unwrap(), 367 | 368 | tokenizer_config_file: read_file_to_bytes( 369 | &model_repo.get("tokenizer_config.json").unwrap(), 370 | ) 371 | .unwrap(), 372 | }; 373 | 374 | let model = UserDefinedRerankingModel::new(onnx_source, tokenizer_files); 375 | 376 | let user_defined_reranker = 377 | TextRerank::try_new_from_user_defined(model, Default::default()).unwrap(); 378 | 379 | let documents = vec![ 380 | "Hello, World!", 381 | "This is an example passage.", 382 | "fastembed-rs is licensed under Apache-2.0", 383 | "Some other short text here blah blah blah", 384 | ]; 385 | 386 | let results = user_defined_reranker 387 | .rerank("Ciao, Earth!", documents.clone(), false, None) 388 | .unwrap(); 389 | 390 | assert_eq!(results.len(), 
documents.len()); 391 | assert_eq!(results.first().unwrap().index, 0); 392 | } 393 | 394 | #[test] 395 | fn test_user_defined_reranking_model() { 396 | // Constitute the model in order to ensure it's downloaded and cached 397 | let test_model_info: fastembed::RerankerModelInfo = 398 | TextRerank::get_model_info(&RerankerModel::JINARerankerV1TurboEn); 399 | 400 | TextRerank::try_new(RerankInitOptions::new(test_model_info.model)).unwrap(); 401 | 402 | // Get the directory of the model 403 | let model_name = test_model_info.model_code.replace('/', "--"); 404 | let model_dir = Path::new(&get_cache_dir()).join(format!("models--{}", model_name)); 405 | 406 | // Find the "snapshots" sub-directory 407 | let snapshots_dir = model_dir.join("snapshots"); 408 | 409 | // Get the first sub-directory in snapshots 410 | let model_files_dir = snapshots_dir 411 | .read_dir() 412 | .unwrap() 413 | .next() 414 | .unwrap() 415 | .unwrap() 416 | .path(); 417 | 418 | // Find the onnx file - it will be any file in ./onnx ending with .onnx 419 | let onnx_file = read_file_to_bytes( 420 | &model_files_dir 421 | .join("onnx") 422 | .read_dir() 423 | .unwrap() 424 | .find(|entry| { 425 | entry 426 | .as_ref() 427 | .unwrap() 428 | .path() 429 | .extension() 430 | .unwrap() 431 | .to_str() 432 | .unwrap() 433 | == "onnx" 434 | }) 435 | .unwrap() 436 | .unwrap() 437 | .path(), 438 | ) 439 | .expect("Could not read onnx file"); 440 | 441 | // Load the tokenizer files 442 | let tokenizer_files = TokenizerFiles { 443 | tokenizer_file: read_file_to_bytes(&model_files_dir.join("tokenizer.json")) 444 | .expect("Could not read tokenizer.json"), 445 | config_file: read_file_to_bytes(&model_files_dir.join("config.json")) 446 | .expect("Could not read config.json"), 447 | special_tokens_map_file: read_file_to_bytes( 448 | &model_files_dir.join("special_tokens_map.json"), 449 | ) 450 | .expect("Could not read special_tokens_map.json"), 451 | tokenizer_config_file: read_file_to_bytes(&model_files_dir.join("tokenizer_config.json")) 452 | .expect("Could not read tokenizer_config.json"), 453 | }; 454 | // Create a UserDefinedRerankingModel 455 | let user_defined_model = UserDefinedRerankingModel::new(onnx_file, tokenizer_files); 456 | 457 | // Try creating a TextRerank instance from the user-defined model 458 | let user_defined_reranker = TextRerank::try_new_from_user_defined( 459 | user_defined_model, 460 | RerankInitOptionsUserDefined::default(), 461 | ) 462 | .unwrap(); 463 | 464 | let documents = vec![ 465 | "Hello, World!", 466 | "This is an example passage.", 467 | "fastembed-rs is licensed under Apache-2.0", 468 | "Some other short text here blah blah blah", 469 | ]; 470 | 471 | // Rerank the documents against a query 472 | let results = user_defined_reranker 473 | .rerank("Ciao, Earth!", documents.clone(), false, None) 474 | .unwrap(); 475 | 476 | assert_eq!(results.len(), documents.len()); 477 | assert_eq!(results.first().unwrap().index, 0); 478 | } 479 | 480 | #[test] 481 | fn test_image_embedding_model() { 482 | let test_one_model = |supported_model: &ModelInfo<ImageEmbeddingModel>| { 483 | let model: ImageEmbedding = 484 | ImageEmbedding::try_new(ImageInitOptions::new(supported_model.model.clone())).unwrap(); 485 | 486 | let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"]; 487 | 488 | // Generate embeddings with the default batch size, 256 489 | let embeddings = model.embed(images.clone(), None).unwrap(); 490 | 491 | assert_eq!(embeddings.len(), images.len()); 492 | }; 493 | ImageEmbedding::list_supported_models() 494 |
.par_iter() 495 | .for_each(test_one_model); 496 | } 497 | 498 | #[test] 499 | #[ignore] 500 | fn test_nomic_embed_vision_v1_5() { 501 | fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { 502 | let dot_product = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>(); 503 | let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt(); 504 | let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt(); 505 | dot_product / (norm_a * norm_b) 506 | } 507 | 508 | fn cosine_similarity_matrix( 509 | embeddings_a: &[Vec<f32>], 510 | embeddings_b: &[Vec<f32>], 511 | ) -> Vec<Vec<f32>> { 512 | embeddings_a 513 | .iter() 514 | .map(|a| { 515 | embeddings_b 516 | .iter() 517 | .map(|b| cosine_similarity(a, b)) 518 | .collect() 519 | }) 520 | .collect() 521 | } 522 | 523 | // Test the NomicEmbedVisionV15 model specifically because it outputs a 3D tensor with a different 524 | // output key ('last_hidden_state') compared to other models. This test ensures our tensor extraction 525 | // logic can handle both standard output keys and this model's specific naming convention. 526 | let image_model = ImageEmbedding::try_new(ImageInitOptions::new( 527 | fastembed::ImageEmbeddingModel::NomicEmbedVisionV15, 528 | )) 529 | .unwrap(); 530 | 531 | // tests/assets/image_0.png is a blue cat 532 | // tests/assets/image_1.png is a red cat 533 | let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"]; 534 | let image_embeddings = image_model.embed(images.clone(), None).unwrap(); 535 | assert_eq!(image_embeddings.len(), images.len()); 536 | 537 | let text_model = TextEmbedding::try_new(InitOptions::new( 538 | fastembed::EmbeddingModel::NomicEmbedTextV15, 539 | )) 540 | .unwrap(); 541 | let texts = vec!["green cat", "blue cat", "red cat", "yellow cat", "dog"]; 542 | let text_embeddings = text_model.embed(texts.clone(), None).unwrap(); 543 | 544 | // Generate similarity matrix 545 | let similarity_matrix = cosine_similarity_matrix(&text_embeddings, &image_embeddings); 546 | // Print the similarity matrix with text labels 547 | for (i, row) in similarity_matrix.iter().enumerate() { 548 | println!("{}: {:?}", texts[i], row); 549 | } 550 | 551 | assert_eq!(text_embeddings.len(), texts.len()); 552 | assert_eq!(text_embeddings[0].len(), 768); 553 | } 554 | 555 | fn clean_cache(model_code: String) { 556 | let repo = Repo::model(model_code); 557 | let cache_dir = format!("{}/{}", &get_cache_dir(), repo.folder_name()); 558 | fs::remove_dir_all(cache_dir).ok(); 559 | } 560 | 561 | // This is item "test-environment-aeghhgwpe-pro02a" of the [ArguAna corpus](http://argumentation.bplaced.net/arguana/data) 562 | fn get_sample_text() -> String { 563 | let t = include_str!("assets/sample_text.txt"); 564 | t.to_string() 565 | } 566 | 567 | #[test] 568 | fn test_batch_size_does_not_change_output() { 569 | let model = TextEmbedding::try_new( 570 | InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_max_length(384), 571 | ) 572 | .expect("Create model successfully"); 573 | 574 | let sentences = vec![ 575 | "Books are no more threatened by Kindle than stairs by elevators.", 576 | "You are who you are when nobody's watching.", 577 | "An original idea. That can't be too hard. The library must be full of them.", 578 | "Gaia visited her daughter Mnemosyne, who was busy being unpronounceable.", 579 | "You can never be overdressed or overeducated.", 580 | "I don't want to go to heaven. None of my friends are there.", 581 | "I never travel without my diary. One should always have something sensational to read in the train.",

    // Test the NomicEmbedVisionV15 model specifically because it outputs a 3D tensor with a different
    // output key ('last_hidden_state') compared to other models. This test ensures our tensor extraction
    // logic can handle both standard output keys and this model's specific naming convention.
    let image_model = ImageEmbedding::try_new(ImageInitOptions::new(
        fastembed::ImageEmbeddingModel::NomicEmbedVisionV15,
    ))
    .unwrap();

    // tests/assets/image_0.png is a blue cat
    // tests/assets/image_1.png is a red cat
    let images = vec!["tests/assets/image_0.png", "tests/assets/image_1.png"];
    let image_embeddings = image_model.embed(images.clone(), None).unwrap();
    assert_eq!(image_embeddings.len(), images.len());

    let text_model = TextEmbedding::try_new(InitOptions::new(
        fastembed::EmbeddingModel::NomicEmbedTextV15,
    ))
    .unwrap();
    let texts = vec!["green cat", "blue cat", "red cat", "yellow cat", "dog"];
    let text_embeddings = text_model.embed(texts.clone(), None).unwrap();

    // Generate the similarity matrix
    let similarity_matrix = cosine_similarity_matrix(&text_embeddings, &image_embeddings);
    // Print the similarity matrix with text labels
    for (i, row) in similarity_matrix.iter().enumerate() {
        println!("{}: {:?}", texts[i], row);
    }

    assert_eq!(text_embeddings.len(), texts.len());
    assert_eq!(text_embeddings[0].len(), 768);
}

fn clean_cache(model_code: String) {
    let repo = Repo::model(model_code);
    let cache_dir = format!("{}/{}", &get_cache_dir(), repo.folder_name());
    fs::remove_dir_all(cache_dir).ok();
}

// This is item "test-environment-aeghhgwpe-pro02a" of the [ArguAna corpus](http://argumentation.bplaced.net/arguana/data)
fn get_sample_text() -> String {
    let t = include_str!("assets/sample_text.txt");
    t.to_string()
}

#[test]
fn test_batch_size_does_not_change_output() {
    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_max_length(384),
    )
    .expect("Create model successfully");

    let sentences = vec![
        "Books are no more threatened by Kindle than stairs by elevators.",
        "You are who you are when nobody's watching.",
        "An original idea. That can't be too hard. The library must be full of them.",
        "Gaia visited her daughter Mnemosyne, who was busy being unpronounceable.",
        "You can never be overdressed or overeducated.",
        "I don't want to go to heaven. None of my friends are there.",
        "I never travel without my diary. One should always have something sensational to read in the train.",
        "I can resist anything except temptation.",
        "It is absurd to divide people into good and bad. People are either charming or tedious.",
    ];

    let single_batch = model
        .embed(sentences.clone(), None)
        .expect("create successfully");
    let small_batch = model
        .embed(sentences, Some(3))
        .expect("create successfully");

    assert_eq!(single_batch.len(), small_batch.len());
    for (a, b) in single_batch.into_iter().zip(small_batch.into_iter()) {
        assert!(a == b, "Expected each sentence embedding to be equal.");
    }
}

#[test]
fn test_bgesmallen1point5_match_python_counterpart() {
    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::BGESmallENV15).with_max_length(384),
    )
    .expect("Create model successfully");

    let text = get_sample_text();

    // The baseline is generated in Python using Xenova/bge-small-en-v1.5.onnx.
    // Tokenized with the SentenceTransformer("BAAI/bge-small-en-v1.5") default tokenizer,
    // with (text, padding="max_length", max_length=384, truncation=True, return_tensors="np").
    // Normalized and pooled with SentenceTransformer("BAAI/bge-small-en-v1.5") default pooling
    // settings. We only take the first 10 values to keep the test file short;
    // a sketch of how this baseline could be regenerated follows this test.
    let baseline: Vec<f32> = vec![
        4.208_193_7e-2,
        -2.748_133_2e-2,
        6.742_810_5e-2,
        2.282_790_5e-2,
        4.257_192e-2,
        -4.163_983_5e-2,
        6.814_807_4e-6,
        -9.643_933e-3,
        -3.475_583e-3,
        6.606_272e-2,
    ];

    let embeddings = model.embed(vec![text], None).expect("create successfully");
    let tolerance: f32 = 1e-3;
    for (actual, expected) in embeddings[0]
        .clone()
        .into_iter()
        .take(baseline.len())
        .zip(baseline.into_iter())
    {
        assert!((actual - expected).abs() < tolerance);
    }
}
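
// For reference, a minimal Python sketch of how the baseline above could be
// regenerated (an illustration, not the exact original script; it assumes
// `transformers`, `onnxruntime`, and `numpy` are installed, that the exported
// model is saved locally as model.onnx, and that the first session output is
// the hidden-state tensor):
//
//     import numpy as np
//     import onnxruntime as ort
//     from transformers import AutoTokenizer
//
//     text = open("tests/assets/sample_text.txt").read()
//     tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
//     inputs = tokenizer(text, padding="max_length", max_length=384,
//                        truncation=True, return_tensors="np")
//     hidden = ort.InferenceSession("model.onnx").run(None, dict(inputs))[0]
//     emb = hidden[:, 0]  # CLS pooling, the bge default
//     emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalize
//     print(emb[0][:10])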

#[test]
fn test_allminilml6v2_match_python_counterpart() {
    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_max_length(384),
    )
    .expect("Create model successfully");

    let text = get_sample_text();

    // The baseline is generated in Python using qdrant/all-mini-lm-l6-v2.onnx.
    // Tokenized with the SentenceTransformer("all-mini-lm-l6-v2") default tokenizer,
    // with (text, padding="max_length", max_length=384, truncation=True, return_tensors="np").
    // Normalized and pooled with SentenceTransformer("all-mini-lm-l6-v2") default pooling settings.
    // We only take the first 10 values to keep the test file short; the pooling
    // difference from the bge baseline is noted after this test.
    let baseline: Vec<f32> = vec![
        3.510_517_6e-2,
        1.046_043e-2,
        3.767_998_5e-2,
        7.073_633_4e-2,
        9.097_775e-2,
        -2.507_714_7e-2,
        -2.214_382e-2,
        -1.016_435_9e-2,
        4.660_127_3e-2,
        7.431_366e-2,
    ];

    let embeddings = model.embed(vec![text], None).expect("create successfully");
    let tolerance: f32 = 1e-6;
    for (actual, expected) in embeddings[0]
        .clone()
        .into_iter()
        .take(baseline.len())
        .zip(baseline.into_iter())
    {
        assert!((actual - expected).abs() < tolerance);
    }
}
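
// The all-MiniLM-L6-v2 baseline can be reproduced with the same sketch as above,
// swapping in the all-MiniLM tokenizer/model and replacing the CLS pooling line
// with masked mean pooling (again an illustration, not the original script):
//
//     mask = inputs["attention_mask"][..., None]
//     emb = (hidden * mask).sum(axis=1) / mask.sum(axis=1)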

--------------------------------------------------------------------------------
/tests/optimum_cli_export.rs:
--------------------------------------------------------------------------------
#![cfg(feature = "hf-hub")]
#![cfg(feature = "optimum-cli")]
//! Test the use of ``optimum-cli`` to pull models from the Hugging Face Hub,
//! and generate embeddings successfully with the pulled model.
//!
//! Models exported by optimum can have different output types - `last_hidden_state`
//! may not be the default output. This test ensures that the correct output key
//! is used when generating embeddings.

use std::{path::PathBuf, process};

use fastembed::{
    get_cache_dir, Pooling, QuantizationMode, TextEmbedding, TokenizerFiles,
    UserDefinedEmbeddingModel,
};

const EPS: f32 = 1e-4;

/// Check if ``optimum-cli`` is available.
fn has_optimum_cli() -> bool {
    process::Command::new("optimum-cli")
        .arg("--help")
        .output()
        .is_ok()
}

/// Pull a model from the Hugging Face Hub using ``optimum-cli``.
///
/// This function assumes you have already checked that ``optimum-cli`` is available.
/// The returned error will not distinguish between a missing ``optimum-cli`` and a failed download.
fn pull_model(
    model_name: &str,
    output: &PathBuf,
    pooling: Option<Pooling>,
) -> anyhow::Result<TextEmbedding> {
    eprintln!("Pulling {model_name} from the Hugging Face Hub...");
    process::Command::new("optimum-cli")
        .args(&[
            "export",
            "onnx",
            "--model",
            model_name,
            output
                .as_os_str()
                .to_str()
                .expect("Failed to convert path to string"),
        ])
        .output()
        .map_err(|e| anyhow::anyhow!("Failed to pull model: {}", e))?;

    load_model(output, pooling)
}
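
// For reference, `pull_model` above shells out to the equivalent of the following
// command (a sketch; the output directory is derived from the cache dir by the
// test macro below):
//
//     optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 \
//         <cache_dir>/exported--sentence-transformers--all-MiniLM-L6-v2-onnx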

/// Load bytes from a file, with a nicer error message.
fn load_bytes_from_file(path: &PathBuf) -> anyhow::Result<Vec<u8>> {
    std::fs::read(path).map_err(|e| anyhow::anyhow!("Failed to read file at {:?}: {}", path, e))
}

/// Load a model from a local directory.
fn load_model(output: &PathBuf, pooling: Option<Pooling>) -> anyhow::Result<TextEmbedding> {
    let model = UserDefinedEmbeddingModel {
        onnx_file: load_bytes_from_file(&output.join("model.onnx"))?,
        tokenizer_files: TokenizerFiles {
            tokenizer_file: load_bytes_from_file(&output.join("tokenizer.json"))?,
            config_file: load_bytes_from_file(&output.join("config.json"))?,
            special_tokens_map_file: load_bytes_from_file(&output.join("special_tokens_map.json"))?,
            tokenizer_config_file: load_bytes_from_file(&output.join("tokenizer_config.json"))?,
        },
        pooling,
        quantization: QuantizationMode::None,
    };

    TextEmbedding::try_new_from_user_defined(model, Default::default())
}

macro_rules! create_test {
    (
        repo_name: $repo_name:literal,
        repo_owner: $repo_owner:literal,
        name: $name:ident,
        pooling: $pooling:expr,
        expected_embedding_dim: $expected_embedding_dim:literal,
        expected: $expected:expr
    ) => {
        #[test]
        fn $name() {
            let repo_name = $repo_name;
            let repo_owner = $repo_owner;
            let model_name = format!("{}/{}", repo_owner, repo_name);
            let cache_dir = get_cache_dir();
            let output_path = format!("{cache_dir}/exported--{repo_owner}--{repo_name}-onnx");
            let output = PathBuf::from(output_path);

            assert!(
                has_optimum_cli(),
                "optimum-cli is not available. Please install it with `pip install optimum[exporters]`"
            );

            let model = load_model(&output, $pooling).unwrap_or_else(|_| {
                pull_model(&model_name, &output, $pooling).expect("Failed to pull model")
            });

            let documents = vec![
                "Hello, World!",
                "This is an example passage.",
                "fastembed-rs is licensed under Apache-2.0",
                "Some other short text here blah blah blah",
            ];
            let expected_length = documents.len();

            // Generate embeddings with a small batch size
            let embeddings = model
                .embed(documents.clone(), Some(3))
                .expect("Failed to generate embeddings");

            assert_eq!(embeddings.len(), expected_length);
            assert_eq!(embeddings[0].len(), $expected_embedding_dim);

            embeddings
                .into_iter()
                .map(|embedding| embedding.iter().sum::<f32>())
                .zip($expected.iter())
                .enumerate()
                .for_each(|(index, (actual, expected))| {
                    assert!(
                        (actual - expected).abs() < EPS,
                        "Mismatched embeddings sum for '{}': Expected: {}, Got: {}",
                        documents[index],
                        expected,
                        actual
                    );
                });
        }
    };
}

create_test! {
    repo_name: "all-MiniLM-L6-v2",
    repo_owner: "sentence-transformers",
    name: optimum_cli_export_all_minilm_l6_v2_mean,
    pooling: Some(Pooling::Mean), // Mean does not matter here because the output is 2D
    expected_embedding_dim: 384,
    // These are generated by Python; there could be accumulated variations
    // when summed.
    expected: [0.5960538, 0.36542776, -0.16450086, -0.40904027]
}
create_test! {
    repo_name: "all-MiniLM-L6-v2",
    repo_owner: "sentence-transformers",
    name: optimum_cli_export_all_minilm_l6_v2_cls,
    pooling: Some(Pooling::Cls),
    expected_embedding_dim: 384,
    // These are generated by Python; there could be accumulated variations
    // when summed.
    expected: [0.5960538, 0.36542776, -0.16450086, -0.40904027]
}
create_test! {
    repo_name: "all-mpnet-base-v2",
    repo_owner: "sentence-transformers",
    name: optimum_cli_export_all_mpnet_base_v2_mean,
    pooling: Some(Pooling::Mean),
    expected_embedding_dim: 768,
    // These are generated by Python; there could be accumulated variations
    // when summed.
    expected: [-0.21253565, -0.05080119, 0.14072478, -0.29081905]
}
create_test! {
    repo_name: "all-mpnet-base-v2",
    repo_owner: "sentence-transformers",
    name: optimum_cli_export_all_mpnet_base_v2_cls,
    pooling: Some(Pooling::Cls),
    expected_embedding_dim: 768,
    // These are generated by Python; there could be accumulated variations
    // when summed.
    expected: [-0.21253565, -0.05080119, 0.14072478, -0.29081905]
}
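
// To run these export tests locally (a sketch; assumes Python's optimum is
// installed and the crate features named in the cfg attributes above are enabled):
//
//     pip install "optimum[exporters]"
//     cargo test --features "hf-hub,optimum-cli" --test optimum_cli_export
--------------------------------------------------------------------------------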