├── .cargo └── config.toml ├── .dockerignore ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── feature-request.yml │ └── new-model-addition.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── build.yaml │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── liniting.yaml │ ├── matrix.json │ ├── test.yaml │ ├── trufflehog.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── Dockerfile ├── Dockerfile-cuda ├── Dockerfile-cuda-all ├── Dockerfile-intel ├── LICENSE ├── Makefile ├── README.md ├── assets ├── bs1-lat.png ├── bs1-tp.png ├── bs32-lat.png └── bs32-tp.png ├── backends ├── Cargo.toml ├── candle │ ├── Cargo.toml │ ├── build.rs │ ├── src │ │ ├── alibi.rs │ │ ├── compute_cap.rs │ │ ├── flash_attn.rs │ │ ├── layers │ │ │ ├── cublaslt.rs │ │ │ ├── layer_norm.rs │ │ │ ├── linear.rs │ │ │ ├── mod.rs │ │ │ ├── rms_norm.rs │ │ │ └── rotary.rs │ │ ├── lib.rs │ │ └── models │ │ │ ├── bert.rs │ │ │ ├── distilbert.rs │ │ │ ├── flash_bert.rs │ │ │ ├── flash_distilbert.rs │ │ │ ├── flash_gte.rs │ │ │ ├── flash_jina.rs │ │ │ ├── flash_jina_code.rs │ │ │ ├── flash_mistral.rs │ │ │ ├── flash_modernbert.rs │ │ │ ├── flash_nomic.rs │ │ │ ├── flash_qwen2.rs │ │ │ ├── gte.rs │ │ │ ├── jina.rs │ │ │ ├── jina_code.rs │ │ │ ├── mistral.rs │ │ │ ├── mod.rs │ │ │ ├── modernbert.rs │ │ │ ├── mpnet.rs │ │ │ ├── nomic.rs │ │ │ └── qwen2.rs │ └── tests │ │ ├── common.rs │ │ ├── snapshots │ │ ├── test_bert__bert_batch.snap │ │ ├── test_bert__bert_batch_pooled.snap │ │ ├── test_bert__bert_batch_raw.snap │ │ ├── test_bert__bert_classification_single.snap │ │ ├── test_bert__bert_single.snap │ │ ├── test_bert__bert_single_pooled.snap │ │ ├── test_bert__bert_single_raw.snap │ │ ├── test_bert__emotions_batch.snap │ │ ├── test_bert__emotions_single.snap │ │ ├── test_flash_bert__bert_classification_single.snap │ │ ├── test_flash_bert__emotions_batch.snap │ │ ├── test_flash_bert__emotions_single.snap │ │ ├── test_flash_bert__mini_batch.snap │ │ ├── test_flash_bert__mini_batch_pooled.snap │ │ ├── test_flash_bert__mini_batch_raw.snap │ │ ├── test_flash_bert__mini_single.snap │ │ ├── test_flash_bert__mini_single_pooled.snap │ │ ├── test_flash_bert__mini_single_raw.snap │ │ ├── test_flash_gte__gte_batch.snap │ │ ├── test_flash_gte__gte_classification_single.snap │ │ ├── test_flash_gte__gte_single.snap │ │ ├── test_flash_jina__jina_batch.snap │ │ ├── test_flash_jina__jina_single.snap │ │ ├── test_flash_jina_code__jina_code_batch.snap │ │ ├── test_flash_jina_code__jina_code_single.snap │ │ ├── test_flash_mistral__mistral_batch.snap │ │ ├── test_flash_mistral__mistral_single.snap │ │ ├── test_flash_nomic__nomic_batch.snap │ │ ├── test_flash_nomic__nomic_moe_batch.snap │ │ ├── test_flash_nomic__nomic_moe_single.snap │ │ ├── test_flash_nomic__nomic_single.snap │ │ ├── test_flash_qwen2__qwen2_batch.snap │ │ ├── test_flash_qwen2__qwen2_single.snap │ │ ├── test_gte__alibaba_gte_batch.snap │ │ ├── test_gte__alibaba_gte_single.snap │ │ ├── test_gte__alibaba_new_gte_batch.snap │ │ ├── test_gte__alibaba_new_gte_single.snap │ │ ├── test_gte__gte_classification_single.snap │ │ ├── test_gte__snowflake_gte_batch.snap │ │ ├── test_gte__snowflake_gte_single.snap │ │ ├── test_jina__jina_batch.snap │ │ ├── test_jina__jina_single.snap │ │ ├── test_jina__jinabert_reranker_single.snap │ │ ├── test_jina_code__jina_code_batch.snap │ │ ├── test_jina_code__jina_code_single.snap │ 
│ ├── test_modernbert__modernbert_batch.snap │ │ ├── test_modernbert__modernbert_batch_flash.snap │ │ ├── test_modernbert__modernbert_batch_pooled.snap │ │ ├── test_modernbert__modernbert_batch_pooled_flash.snap │ │ ├── test_modernbert__modernbert_batch_raw.snap │ │ ├── test_modernbert__modernbert_batch_raw_flash.snap │ │ ├── test_modernbert__modernbert_classification_mean_pooling.snap │ │ ├── test_modernbert__modernbert_classification_single.snap │ │ ├── test_modernbert__modernbert_single.snap │ │ ├── test_modernbert__modernbert_single_flash.snap │ │ ├── test_modernbert__modernbert_single_pooled.snap │ │ ├── test_modernbert__modernbert_single_pooled_flash.snap │ │ ├── test_modernbert__modernbert_single_raw.snap │ │ ├── test_modernbert__modernbert_single_raw_flash.snap │ │ ├── test_mpnet__mpnet_batch.snap │ │ ├── test_mpnet__mpnet_batch_pooled.snap │ │ ├── test_mpnet__mpnet_batch_raw.snap │ │ ├── test_mpnet__mpnet_single.snap │ │ ├── test_mpnet__mpnet_single_pooled.snap │ │ ├── test_mpnet__mpnet_single_raw.snap │ │ ├── test_nomic__nomic_batch.snap │ │ ├── test_nomic__nomic_moe_batch.snap │ │ ├── test_nomic__nomic_moe_single.snap │ │ └── test_nomic__nomic_single.snap │ │ ├── test_bert.rs │ │ ├── test_flash_bert.rs │ │ ├── test_flash_gte.rs │ │ ├── test_flash_jina.rs │ │ ├── test_flash_jina_code.rs │ │ ├── test_flash_mistral.rs │ │ ├── test_flash_nomic.rs │ │ ├── test_flash_qwen2.rs │ │ ├── test_gte.rs │ │ ├── test_jina.rs │ │ ├── test_jina_code.rs │ │ ├── test_modernbert.rs │ │ ├── test_mpnet.rs │ │ └── test_nomic.rs ├── core │ ├── Cargo.toml │ └── src │ │ └── lib.rs ├── grpc-client │ ├── Cargo.toml │ ├── build.rs │ └── src │ │ ├── client.rs │ │ ├── lib.rs │ │ └── pb │ │ └── .gitignore ├── grpc-metadata │ ├── Cargo.toml │ └── src │ │ └── lib.rs ├── ort │ ├── Cargo.toml │ └── src │ │ └── lib.rs ├── proto │ └── embed.proto ├── python │ ├── Cargo.toml │ ├── server │ │ ├── .gitignore │ │ ├── Makefile │ │ ├── Makefile-flash-att │ │ ├── Makefile-flash-att-v2 │ │ ├── README.md │ │ ├── poetry.lock │ │ ├── pyproject.toml │ │ ├── requirements-hpu.txt │ │ ├── requirements-intel.txt │ │ ├── requirements.txt │ │ └── text_embeddings_server │ │ │ ├── __init__.py │ │ │ ├── cli.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── classification_model.py │ │ │ ├── default_model.py │ │ │ ├── flash_bert.py │ │ │ ├── flash_mistral.py │ │ │ ├── jinaBert_model.py │ │ │ ├── masked_model.py │ │ │ ├── model.py │ │ │ ├── pooling.py │ │ │ └── types.py │ │ │ ├── pb │ │ │ └── .gitignore │ │ │ ├── server.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── device.py │ │ │ ├── flash_attn.py │ │ │ ├── interceptor.py │ │ │ └── tracing.py │ └── src │ │ ├── lib.rs │ │ ├── logging.rs │ │ └── management.rs └── src │ ├── dtype.rs │ └── lib.rs ├── core ├── Cargo.toml └── src │ ├── download.rs │ ├── infer.rs │ ├── lib.rs │ ├── queue.rs │ └── tokenization.rs ├── cuda-all-entrypoint.sh ├── docs ├── index.html ├── openapi.json └── source │ └── en │ ├── _toctree.yml │ ├── cli_arguments.md │ ├── custom_container.md │ ├── examples.md │ ├── index.md │ ├── intel_container.md │ ├── local_cpu.md │ ├── local_gpu.md │ ├── local_metal.md │ ├── private_models.md │ ├── quick_tour.md │ ├── supported_models.md │ └── tei_cloud_run.md ├── flake.lock ├── flake.nix ├── load_tests ├── load.js ├── load_grpc.js └── load_grpc_stream.js ├── proto └── tei.proto ├── router ├── Cargo.toml ├── build.rs ├── src │ ├── grpc │ │ ├── mod.rs │ │ ├── pb │ │ │ └── .gitignore │ │ └── server.rs │ ├── http │ │ ├── mod.rs │ │ ├── server.rs │ │ └── types.rs │ ├── lib.rs │ ├── 
logging.rs │ ├── main.rs │ ├── prometheus.rs │ └── shutdown.rs └── tests │ ├── common.rs │ ├── snapshots │ ├── test_http_embed__embeddings_batch.snap │ ├── test_http_embed__embeddings_raw.snap │ ├── test_http_embed__embeddings_single.snap │ ├── test_http_predict__predictions_batch.snap │ ├── test_http_predict__predictions_single.snap │ └── test_http_rerank__ranks.snap │ ├── test_http_embed.rs │ ├── test_http_predict.rs │ └── test_http_rerank.rs ├── rust-toolchain.toml ├── sagemaker-entrypoint-cuda-all.sh └── sagemaker-entrypoint.sh /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-C", "target-cpu=native"] 3 | 4 | [target.wasm32-unknown-unknown] 5 | rustflags = ["-C", "target-feature=+simd128"] 6 | 7 | [target.x86_64-apple-darwin] 8 | rustflags = ["-C", "target-feature=-avx,-avx2"] 9 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .idea 2 | target 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.png filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve text-embeddings-inference 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info 8 | description: | 9 | Please share your system info with us (`text-generation-launcher --env` if installed locally). 10 | The full command line used that causes issues: 11 | OS version: 12 | Rust version (if self-compiling, `cargo version`): 13 | Model being used (`curl 127.0.0.1:8080/info | jq`): 14 | If local model please explicit the kind of model and/or equivalents. 15 | Hardware used (GPUs, how many, on which cloud) (`nvidia-smi`): 16 | Deployment specificities (Kubernetes, EKS, AKS, any particular deployments): 17 | The current version being used: 18 | 19 | placeholder: text-embeddings-inference version, platform, python version, ... 20 | validations: 21 | required: true 22 | 23 | - type: checkboxes 24 | id: information-scripts-examples 25 | attributes: 26 | label: Information 27 | description: 'The problem arises when using:' 28 | options: 29 | - label: "Docker" 30 | - label: "The CLI directly" 31 | 32 | - type: checkboxes 33 | id: information-tasks 34 | attributes: 35 | label: Tasks 36 | description: "The thing I am working on is:" 37 | options: 38 | - label: "An officially supported command" 39 | - label: "My own modifications" 40 | 41 | - type: textarea 42 | id: reproduction 43 | validations: 44 | required: true 45 | attributes: 46 | label: Reproduction 47 | description: | 48 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 49 | If you have code snippets, error messages, stack traces please provide them here as well. 50 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 51 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 
52 | 53 | placeholder: | 54 | Steps to reproduce the behavior: 55 | 56 | 1. 57 | 2. 58 | 3. 59 | 60 | 61 | - type: textarea 62 | id: expected-behavior 63 | validations: 64 | required: true 65 | attributes: 66 | label: Expected behavior 67 | description: "A clear and concise description of what you would expect to happen." 68 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | version: 2.1 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new text-embeddings-inference feature 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/text-embeddings-inference/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new-model-addition.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F31F New model addition" 2 | description: Submit a proposal/request to implement a new model 3 | labels: [ "New model" ] 4 | 5 | body: 6 | - type: textarea 7 | id: description-request 8 | validations: 9 | required: true 10 | attributes: 11 | label: Model description 12 | description: | 13 | Put any and all important information relative to the model 14 | 15 | - type: checkboxes 16 | id: information-tasks 17 | attributes: 18 | label: Open source status 19 | description: | 20 | Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `transformers`. 21 | options: 22 | - label: "The model implementation is available" 23 | - label: "The model weights are available" 24 | 25 | - type: textarea 26 | id: additional-info 27 | attributes: 28 | label: Provide useful links for the implementation 29 | description: | 30 | Please provide information regarding the implementation, the weights, and the authors. 31 | Please mention the authors by @gh-username if you're aware of their usernames. 32 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 
2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), 21 | Pull Request section? 22 | - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link 23 | to it if that's the case. 24 | - [ ] Did you make sure to update the documentation with your changes? Here are the 25 | [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and 26 | [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). 27 | - [ ] Did you write any new necessary tests? 28 | 29 | 30 | ## Who can review? 31 | 32 | Anyone in the community is free to review the PR once the tests have passed. Feel free to tag 33 | members/contributors who may be interested in your PR. 34 | 35 | 41 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | paths: 6 | - "docs/source/**" 7 | branches: 8 | - main 9 | - doc-builder* 10 | - v*-release 11 | 12 | jobs: 13 | build: 14 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 15 | with: 16 | commit_sha: ${{ github.sha }} 17 | package: text-embeddings-inference 18 | additional_args: --not_python_module 19 | languages: en 20 | secrets: 21 | token: ${{ secrets.HUGGINGFACE_PUSH }} 22 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 23 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - "docs/source/**" 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | build: 14 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 15 | with: 16 | commit_sha: ${{ github.event.pull_request.head.sha }} 17 | pr_number: ${{ github.event.number }} 18 | package: text-embeddings-inference 19 | additional_args: --not_python_module 20 | languages: en 21 | -------------------------------------------------------------------------------- /.github/workflows/liniting.yaml: -------------------------------------------------------------------------------- 1 | name: Linting Tests 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | run_tests: 12 | runs-on: ubuntu-latest 13 | 14 | env: 15 | SCCACHE_GHA_ENABLED: "on" 16 | RUSTC_WRAPPER: /usr/local/bin/sccache 17 | SCCACHE: 0.10.0 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Install Rust 22 | uses: actions-rs/toolchain@v1 23 | with: 24 | # Released on: 28 December, 2023 25 | # Branched from master on: 10 November, 2023 26 | # https://releases.rs/docs/1.85.0/ 27 | toolchain: 1.85.0 28 | override: true 29 | components: rustfmt, clippy 30 | - name: Install Protoc 31 | uses: arduino/setup-protoc@v1 32 | - name: Clean 
unused files 33 | run: | 34 | sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android 35 | sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET 36 | - name: Install sccache 37 | run: | 38 | curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache 39 | chmod +x /usr/local/bin/sccache 40 | - name: configure sccache 41 | uses: actions/github-script@v6 42 | with: 43 | script: | 44 | core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || ''); 45 | core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); 46 | core.exportVariable('SCCACHE_GHA_CACHE_TO', 'sccache-${{runner.os}}-${{github.ref_name}}'); 47 | core.exportVariable('SCCACHE_GHA_CACHE_FROM', 'sccache-${{runner.os}}-main,sccache-${{runner.os}}-'); 48 | - name: cargo registry cache 49 | uses: actions/cache@v3 50 | with: 51 | key: cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}-${{ github.sha }} 52 | restore-keys: | 53 | cargo-${{ runner.os }}-${{ hashFiles('**/Cargo.toml') }}- 54 | cargo-${{ runner.os }}- 55 | path: | 56 | ~/.cargo/registry 57 | ~/.cargo/git 58 | - name: Build 59 | run: | 60 | cargo build 61 | cargo build -F candle -F grpc --no-default-features 62 | - name: Pre-commit checks 63 | run: | 64 | pip install pre-commit 65 | pre-commit install 66 | pre-commit run --all-files 67 | - name: sccache stats 68 | run: | 69 | /usr/local/bin/sccache --show-stats 70 | -------------------------------------------------------------------------------- /.github/workflows/matrix.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "turing", 4 | "imageNamePrefix": "turing-", 5 | "runOn": "always", 6 | "sccache": true, 7 | "cudaComputeCap": 75, 8 | "extraBuildArgs": "DEFAULT_USE_FLASH_ATTENTION=False", 9 | "grpc": true, 10 | "dockerfile": "Dockerfile-cuda" 11 | }, 12 | { 13 | "name": "ampere", 14 | "imageNamePrefix": "", 15 | "runOn": "always", 16 | "sccache": true, 17 | "cudaComputeCap": 80, 18 | "grpc": true, 19 | "dockerfile": "Dockerfile-cuda" 20 | }, 21 | { 22 | "name": "a10", 23 | "imageNamePrefix": "86-", 24 | "runOn": "always", 25 | "sccache": true, 26 | "cudaComputeCap": 86, 27 | "grpc": true, 28 | "dockerfile": "Dockerfile-cuda" 29 | }, 30 | { 31 | "name": "RTX 4000", 32 | "imageNamePrefix": "89-", 33 | "runOn": "always", 34 | "sccache": true, 35 | "cudaComputeCap": 89, 36 | "grpc": true, 37 | "dockerfile": "Dockerfile-cuda" 38 | }, 39 | { 40 | "name": "Hopper", 41 | "imageNamePrefix": "hopper-", 42 | "runOn": "always", 43 | "sccache": true, 44 | "cudaComputeCap": 90, 45 | "grpc": true, 46 | "dockerfile": "Dockerfile-cuda" 47 | }, 48 | { 49 | "name": "All", 50 | "imageNamePrefix": "cuda-", 51 | "runOn": "always", 52 | "sccache": false, 53 | "grpc": false, 54 | "dockerfile": "Dockerfile-cuda-all" 55 | }, 56 | { 57 | "name": "cpu", 58 | "imageNamePrefix": "cpu-", 59 | "runOn": "always", 60 | "sccache": true, 61 | "grpc": true, 62 | "dockerfile": "Dockerfile" 63 | }, 64 | { 65 | "name": "cpu-ipex", 66 | "imageNamePrefix": "cpu-ipex-", 67 | "runOn": "always", 68 | "sccache": true, 69 | "extraBuildArgs": "PLATFORM=cpu", 70 | "grpc": true, 71 | "dockerfile": "Dockerfile-intel" 72 | }, 73 | { 74 | "name": "xpu-ipex", 75 | "imageNamePrefix": "xpu-ipex-", 76 | "runOn": "always", 77 | "sccache": true, 78 
| "extraBuildArgs": "PLATFORM=xpu", 79 | "grpc": true, 80 | "dockerfile": "Dockerfile-intel" 81 | }, 82 | { 83 | "name": "hpu", 84 | "imageNamePrefix": "hpu-", 85 | "runOn": "always", 86 | "sccache": true, 87 | "extraBuildArgs": "PLATFORM=hpu", 88 | "grpc": true, 89 | "dockerfile": "Dockerfile-intel" 90 | } 91 | ] 92 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Run basic tests 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - 'main' 8 | pull_request: 9 | paths: 10 | - ".github/workflows/build.yaml" 11 | - ".github/workflows/matrix.json" 12 | - "integration-tests/**" 13 | - "backends/**" 14 | - "core/**" 15 | - "router/**" 16 | - "Cargo.lock" 17 | - "rust-toolchain.toml" 18 | - "Dockerfile" 19 | branches: 20 | - 'main' 21 | 22 | jobs: 23 | tests: 24 | concurrency: 25 | group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} 26 | cancel-in-progress: true 27 | runs-on: 28 | group: aws-highmemory-32-plus-priv 29 | steps: 30 | - name: Checkout repository 31 | uses: actions/checkout@v4 32 | - uses: actions-rust-lang/setup-rust-toolchain@v1 33 | - name: Run sccache-cache 34 | uses: mozilla-actions/sccache-action@v0.0.9 35 | with: 36 | version: "v0.10.0" 37 | - name: Compile project 38 | env: 39 | SCCACHE_GHA_ENABLED: "true" 40 | RUSTC_WRAPPER: "sccache" 41 | run: | 42 | sudo apt-get update && sudo apt-get install protobuf-compiler -y 43 | cargo test --profile=release-debug 44 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main 16 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: text-embeddings-inference 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | target 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/text-embeddings-inference/d51a8b99d01d00d5ee7a52ea8dec53ed1c1b7d7c/.gitmodules -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - 
id: trailing-whitespace 8 | exclude: docs/source/basic_tutorials/launcher.md 9 | - repo: https://github.com/doublify/pre-commit-rust 10 | rev: v1.0 11 | hooks: 12 | - id: fmt 13 | - id: cargo-check 14 | - id: clippy 15 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "backends", 4 | "backends/candle", 5 | "backends/ort", 6 | "backends/core", 7 | "backends/python", 8 | "backends/grpc-client", 9 | "core", 10 | "router", 11 | ] 12 | default-members = [ 13 | "backends", 14 | "backends/candle", 15 | "backends/ort", 16 | "backends/core", 17 | "backends/python", 18 | "backends/grpc-client", 19 | "core", 20 | "router", 21 | ] 22 | resolver = "2" 23 | 24 | [workspace.package] 25 | version = "1.7.1" 26 | edition = "2021" 27 | authors = ["Olivier Dehaene", "Nicolas Patry", "Alvaro Bartolome"] 28 | homepage = "https://github.com/huggingface/text-embeddings-inference" 29 | 30 | [workspace.dependencies] 31 | anyhow = "1.0.75" 32 | clap = { version = "4.1", features = ["derive", "env"] } 33 | hf-hub = { version = "0.4", features = ["tokio"], default-features = false } 34 | metrics = "0.23" 35 | nohash-hasher = "0.2" 36 | num_cpus = "1.16.0" 37 | tokenizers = { version = "0.21.0", default-features = false, features = ["onig", "esaxx_fast"] } 38 | tokio = { version = "1.25", features = ["rt", "rt-multi-thread", "parking_lot", "sync", "signal"] } 39 | tracing = "0.1" 40 | serde = { version = "1.0", features = ["serde_derive"] } 41 | serde_json = "1.0" 42 | thiserror = "1.0" 43 | rand = "0.9" 44 | serial_test = "2.0.0" 45 | cudarc = { version = "0.13" , features =["cuda-12020"], default-features = false} 46 | intel-mkl-src = { version = "0.8"} 47 | candle = { version = "0.8", package = "candle-core" } 48 | candle-nn = { version = "0.8" } 49 | candle-transformers = { version = "0.8" } 50 | candle-flash-attn = { version = "0.8" } 51 | candle-cublaslt= { version = "0.0.1" } 52 | candle-layer-norm = { version = "0.0.1" } 53 | candle-rotary = { version = "0.0.1" } 54 | candle-flash-attn-v1 = { version = "0.0.1" } 55 | half = { version = "2.3.1", features = ["num-traits"] } 56 | 57 | [patch.crates-io] 58 | cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9"} 59 | candle = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-core" } 60 | candle-nn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-nn" } 61 | candle-transformers = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-transformers" } 62 | candle-flash-attn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-flash-attn" } 63 | 64 | [profile.release] 65 | debug = 0 66 | lto = "fat" 67 | opt-level = 3 68 | codegen-units = 1 69 | strip = "symbols" 70 | panic = "abort" 71 | 72 | [profile.release-debug] 73 | inherits = "release" 74 | debug = 1 75 | lto = "thin" 76 | codegen-units = 16 77 | strip = "none" 78 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | integration-tests: 2 | cargo test 3 | 4 | cuda-integration-tests: 5 | cargo test -F 
text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --profile release-debug 6 | 7 | integration-tests-review: 8 | cargo insta test --review 9 | 10 | cuda-integration-tests-review: 11 | cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --profile release-debug 12 | -------------------------------------------------------------------------------- /assets/bs1-lat.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:778b29d7d21382004fef2c528973f66bb175951ab7cd168d588cd245e36bd629 3 | size 15202 4 | -------------------------------------------------------------------------------- /assets/bs1-tp.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:478984ace4f33044bc0a53b0503a0cbfcd0a64f601922e2a13cc34d52c2b7c2b 3 | size 17169 4 | -------------------------------------------------------------------------------- /assets/bs32-lat.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:769326aad7e582a2e5271dd2d73c3bb5289684add10eb7146ddadd00d3b2077f 3 | size 17596 4 | -------------------------------------------------------------------------------- /assets/bs32-tp.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c227c5adbb8664af7aa3d59aaa408557b2865dcfbd3c6c6353caf71f2eb5b7bc 3 | size 18521 4 | -------------------------------------------------------------------------------- /backends/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-backend" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | clap = { workspace = true, optional = true } 10 | hf-hub = { workspace = true } 11 | serde_json = { workspace = true } 12 | text-embeddings-backend-core = { path = "core" } 13 | text-embeddings-backend-python = { path = "python", optional = true } 14 | text-embeddings-backend-candle = { path = "candle", optional = true } 15 | text-embeddings-backend-ort = { path = "ort", optional = true } 16 | tokio = { workspace = true } 17 | tracing = { workspace = true } 18 | rand = { workspace = true } 19 | 20 | [features] 21 | clap = ["dep:clap", "text-embeddings-backend-core/clap"] 22 | python = ["dep:text-embeddings-backend-python"] 23 | ort = ["dep:text-embeddings-backend-ort"] 24 | candle = ["dep:text-embeddings-backend-candle"] 25 | cuda = ["text-embeddings-backend-candle?/cuda"] 26 | metal = ["text-embeddings-backend-candle?/metal"] 27 | mkl = ["text-embeddings-backend-candle?/mkl"] 28 | accelerate = ["text-embeddings-backend-candle?/accelerate"] 29 | flash-attn = ["text-embeddings-backend-candle?/flash-attn"] 30 | flash-attn-v1 = ["text-embeddings-backend-candle?/flash-attn-v1"] 31 | -------------------------------------------------------------------------------- /backends/candle/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-backend-candle" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = 
true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | anyhow = { workspace = true } 10 | accelerate-src = { version = "0.3.2", optional = true } 11 | intel-mkl-src = { workspace = true, optional = true } 12 | candle = { workspace = true } 13 | candle-nn = { workspace = true } 14 | candle-transformers = { workspace = true } 15 | candle-flash-attn = { workspace = true, optional = true} 16 | candle-flash-attn-v1 = { workspace = true, optional = true } 17 | candle-cublaslt = { workspace = true, optional = true } 18 | candle-layer-norm = { workspace = true, optional = true } 19 | candle-rotary = { workspace = true, optional = true } 20 | nohash-hasher = { workspace = true } 21 | text-embeddings-backend-core = { path = "../core" } 22 | tracing = { workspace = true } 23 | safetensors = "^0.4" 24 | thiserror = { workspace = true } 25 | serde = { workspace = true } 26 | serde_json = { workspace = true } 27 | memmap2 = "^0.9" 28 | 29 | [dev-dependencies] 30 | insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] } 31 | is_close = "0.1.3" 32 | hf-hub = { workspace = true, features = ["ureq"] } 33 | anyhow = { workspace = true } 34 | tokenizers = { workspace = true } 35 | serial_test = { workspace = true } 36 | 37 | [build-dependencies] 38 | anyhow = { version = "1", features = ["backtrace"] } 39 | 40 | [features] 41 | accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] 42 | metal = ["candle/metal", "candle-nn/metal"] 43 | mkl = ["dep:intel-mkl-src", "candle/_mkl"] 44 | cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"] 45 | flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"] 46 | flash-attn = ["dep:candle-flash-attn", "cuda"] 47 | -------------------------------------------------------------------------------- /backends/candle/build.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{bail, Context, Result}; 2 | 3 | fn main() { 4 | println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); 5 | if let Ok(compute_cap) = set_compute_cap() { 6 | println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap}"); 7 | } 8 | } 9 | 10 | fn set_compute_cap() -> Result { 11 | // Try to parse compute caps from env 12 | let compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { 13 | compute_cap_str 14 | .parse::() 15 | .context("Could not parse code")? 16 | } else { 17 | // Use nvidia-smi to get the current compute cap 18 | let out = std::process::Command::new("nvidia-smi") 19 | .arg("--query-gpu=compute_cap") 20 | .arg("--format=csv") 21 | .output() 22 | .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; 23 | let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; 24 | let mut lines = out.lines(); 25 | if lines.next().context("missing line in stdout")? != "compute_cap" { 26 | bail!("First line should be `compute_cap`"); 27 | } 28 | let cap = lines 29 | .next() 30 | .context("missing line in stdout")? 31 | .replace('.', ""); 32 | cap.parse::().context("cannot parse as int")? 
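// Note: `nvidia-smi --query-gpu=compute_cap --format=csv` prints a header line
// ("compute_cap") followed by a value such as "8.6"; removing the '.' above turns it
// into the integer form (86) that is compared against the compile-time cap in
// backends/candle/src/compute_cap.rs and listed as `cudaComputeCap` in
// .github/workflows/matrix.json.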
33 | }; 34 | Ok(compute_cap) 35 | } 36 | -------------------------------------------------------------------------------- /backends/candle/src/alibi.rs: -------------------------------------------------------------------------------- 1 | // coding=utf-8 2 | // Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | // Copyright (c) 2023 Jina AI GmbH. All rights reserved. 4 | // 5 | // Licensed under the Apache License, Version 2.0 (the "License"); 6 | // you may not use this file except in compliance with the License. 7 | // You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | use candle::{DType, Device, Result, Tensor}; 17 | 18 | fn get_slopes_power_of_2(n: usize) -> Vec { 19 | let start: f64 = 2_f64.powf(-(2_f64.powf(-((n as f64).log2() - 3_f64)))); 20 | 21 | (0..n).map(|i| start * start.powi(i as i32)).collect() 22 | } 23 | 24 | pub fn alibi_head_slopes(num_attention_heads: usize) -> Vec { 25 | if (num_attention_heads as f64).log2().fract() == 0.0 { 26 | // `num_attention_heads` is a power of 2 27 | get_slopes_power_of_2(num_attention_heads) 28 | } else { 29 | let closest_power_of_2 = 30 | 2_f64.powi((num_attention_heads as f64).log2().floor() as i32) as usize; 31 | 32 | let mut slopes = get_slopes_power_of_2(closest_power_of_2); 33 | let additional_slopes: Vec = get_slopes_power_of_2(2 * closest_power_of_2) 34 | .into_iter() 35 | .enumerate() 36 | // Filter odd indices 37 | .filter(|(i, _)| i % 2 == 0) 38 | // Remove i 39 | .map(|(_, v)| v) 40 | .collect(); 41 | 42 | // Extend slopes 43 | slopes.extend_from_slice(&additional_slopes[0..(num_attention_heads - closest_power_of_2)]); 44 | 45 | slopes 46 | } 47 | } 48 | 49 | pub fn build_alibi_tensor( 50 | num_positions: usize, 51 | num_heads: usize, 52 | device: &Device, 53 | dtype: DType, 54 | ) -> Result { 55 | let context_positions = 56 | Tensor::arange(0.0, num_positions as f64, &Device::Cpu)?.unsqueeze(1)?; 57 | let memory_positions = Tensor::arange(0.0, num_positions as f64, &Device::Cpu)?.unsqueeze(0)?; 58 | 59 | let relative_positions = memory_positions.broadcast_sub(&context_positions)?.abs()?; 60 | // [num_heads, num_positions, num_positions] 61 | let relative_positions = 62 | relative_positions 63 | .unsqueeze(0)? 64 | .expand((num_heads, num_positions, num_positions))?; 65 | 66 | // [num_heads, 1, 1] 67 | let slopes = (Tensor::from_vec( 68 | alibi_head_slopes(num_heads), 69 | (num_heads, 1, 1), 70 | &Device::Cpu, 71 | )? * -1_f64)?; 72 | 73 | // [num_heads, num_positions, num_positions] 74 | let alibi = relative_positions.broadcast_mul(&slopes)?; 75 | 76 | alibi 77 | .reshape((1, num_heads, num_positions, num_positions))? 78 | .to_dtype(dtype)? 
79 |         .to_device(device)
80 | }
81 | 
-------------------------------------------------------------------------------- /backends/candle/src/compute_cap.rs: --------------------------------------------------------------------------------
1 | use anyhow::Context;
2 | use candle::cuda_backend::cudarc::driver;
3 | use candle::cuda_backend::cudarc::driver::sys::CUdevice_attribute::{
4 |     CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
5 | };
6 | use candle::cuda_backend::cudarc::driver::CudaDevice;
7 | 
8 | pub fn get_compile_compute_cap() -> Result<usize, anyhow::Error> {
9 |     env!("CUDA_COMPUTE_CAP")
10 |         .parse::<usize>()
11 |         .context("Could not retrieve compile time CUDA_COMPUTE_CAP")
12 | }
13 | 
14 | pub fn get_runtime_compute_cap() -> Result<usize, anyhow::Error> {
15 |     driver::result::init().context("CUDA is not available")?;
16 |     let device = CudaDevice::new(0).context("CUDA is not available")?;
17 |     let major = device
18 |         .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
19 |         .context("Could not retrieve device compute capability major")?;
20 |     let minor = device
21 |         .attribute(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
22 |         .context("Could not retrieve device compute capability minor")?;
23 |     Ok((major * 10 + minor) as usize)
24 | }
25 | 
26 | fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) -> bool {
27 |     match (runtime_compute_cap, compile_compute_cap) {
28 |         (75, 75) => true,
29 |         (80..=89, 80) => true,
30 |         (86..=89, 80..=86) => true,
31 |         (89, 89) => true,
32 |         (90, 90) => true,
33 |         (_, _) => false,
34 |     }
35 | }
36 | 
37 | pub fn compatible_compute_cap() -> Result<bool, anyhow::Error> {
38 |     let compile_compute_cap = get_compile_compute_cap()?;
39 |     let runtime_compute_cap = get_runtime_compute_cap()?;
40 |     Ok(compute_cap_matching(
41 |         runtime_compute_cap,
42 |         compile_compute_cap,
43 |     ))
44 | }
45 | 
46 | #[cfg(test)]
47 | mod tests {
48 |     use crate::compute_cap::compute_cap_matching;
49 | 
50 |     #[test]
51 |     fn test_compute_cap() {
52 |         assert!(compute_cap_matching(75, 75));
53 |         assert!(compute_cap_matching(80, 80));
54 |         assert!(compute_cap_matching(86, 86));
55 |         assert!(compute_cap_matching(89, 89));
56 |         assert!(compute_cap_matching(90, 90));
57 | 
58 |         assert!(compute_cap_matching(86, 80));
59 |         assert!(compute_cap_matching(89, 80));
60 |         assert!(compute_cap_matching(89, 86));
61 | 
62 |         assert!(!compute_cap_matching(75, 80));
63 |         assert!(!compute_cap_matching(75, 86));
64 |         assert!(!compute_cap_matching(75, 89));
65 |         assert!(!compute_cap_matching(75, 90));
66 | 
67 |         assert!(!compute_cap_matching(80, 75));
68 |         assert!(!compute_cap_matching(80, 86));
69 |         assert!(!compute_cap_matching(80, 89));
70 |         assert!(!compute_cap_matching(80, 90));
71 | 
72 |         assert!(!compute_cap_matching(86, 75));
73 |         assert!(!compute_cap_matching(86, 89));
74 |         assert!(!compute_cap_matching(86, 90));
75 | 
76 |         assert!(!compute_cap_matching(89, 75));
77 |         assert!(!compute_cap_matching(89, 90));
78 | 
79 |         assert!(!compute_cap_matching(90, 75));
80 |         assert!(!compute_cap_matching(90, 80));
81 |         assert!(!compute_cap_matching(90, 86));
82 |         assert!(!compute_cap_matching(90, 89));
83 |     }
84 | }
85 | 
-------------------------------------------------------------------------------- /backends/candle/src/flash_attn.rs: --------------------------------------------------------------------------------
1 | use candle::Tensor;
2 | use std::sync::Once;
3 | 
4 | static INIT: Once = Once::new();
5 | static mut RUNTIME_COMPUTE_CAP: usize = 0;
6 | fn init_runtime_compute_cap() {
7 |     unsafe {
8 |
INIT.call_once(|| { 9 | use crate::compute_cap::get_runtime_compute_cap; 10 | RUNTIME_COMPUTE_CAP = get_runtime_compute_cap().unwrap(); 11 | }); 12 | } 13 | } 14 | 15 | pub fn get_runtime_compute_cap() -> usize { 16 | unsafe { 17 | init_runtime_compute_cap(); 18 | RUNTIME_COMPUTE_CAP 19 | } 20 | } 21 | 22 | #[allow(clippy::too_many_arguments, unused)] 23 | pub(crate) fn flash_attn_varlen( 24 | q: &Tensor, 25 | k: &Tensor, 26 | v: &Tensor, 27 | alibi_slopes: Option<&Tensor>, 28 | seqlens_q: &Tensor, 29 | seqlens_k: &Tensor, 30 | max_seqlen_q: usize, 31 | max_seqlen_k: usize, 32 | softmax_scale: f32, 33 | causal: bool, 34 | window_size_left: Option, 35 | window_size_right: Option, 36 | ) -> Result { 37 | let runtime_compute_cap = get_runtime_compute_cap(); 38 | 39 | if runtime_compute_cap == 75 { 40 | if alibi_slopes.is_some() { 41 | candle::bail!("Flash attention v1 does not support alibi"); 42 | } 43 | if window_size_left.is_some() | window_size_right.is_some() { 44 | candle::bail!("Flash attention v1 does not support attention windowing"); 45 | } 46 | 47 | #[cfg(feature = "flash-attn-v1")] 48 | { 49 | use candle_flash_attn_v1::flash_attn_varlen; 50 | return flash_attn_varlen( 51 | q, 52 | k, 53 | v, 54 | seqlens_q, 55 | seqlens_k, 56 | max_seqlen_q, 57 | max_seqlen_k, 58 | softmax_scale, 59 | causal, 60 | ); 61 | } 62 | #[cfg(not(feature = "flash-attn-v1"))] 63 | candle::bail!("Flash attention v1 is not installed. Use `flash-attn-v1` feature.") 64 | } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 { 65 | #[cfg(feature = "flash-attn")] 66 | { 67 | use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed}; 68 | 69 | let window_size_right = if causal { 70 | Some(0) 71 | } else if window_size_right.is_some() { 72 | window_size_right 73 | } else { 74 | None 75 | }; 76 | 77 | let attention = if let Some(alibi_slopes) = alibi_slopes { 78 | flash_attn_varlen_alibi_windowed( 79 | q, 80 | k, 81 | v, 82 | alibi_slopes, 83 | seqlens_q, 84 | seqlens_k, 85 | max_seqlen_q, 86 | max_seqlen_k, 87 | softmax_scale, 88 | window_size_left, 89 | window_size_right, 90 | ) 91 | } else { 92 | flash_attn_varlen_windowed( 93 | q, 94 | k, 95 | v, 96 | seqlens_q, 97 | seqlens_k, 98 | max_seqlen_q, 99 | max_seqlen_k, 100 | softmax_scale, 101 | window_size_left, 102 | window_size_right, 103 | ) 104 | }; 105 | 106 | return attention; 107 | } 108 | #[cfg(not(feature = "flash-attn"))] 109 | candle::bail!("Flash attention is not installed. 
Use `flash-attn` feature.") 110 | } 111 | candle::bail!( 112 | "GPU with CUDA capability {} is not supported", 113 | runtime_compute_cap 114 | ); 115 | } 116 | -------------------------------------------------------------------------------- /backends/candle/src/layers/cublaslt.rs: -------------------------------------------------------------------------------- 1 | use crate::layers::HiddenAct; 2 | use candle::{Device, Result, Tensor}; 3 | use std::sync::Once; 4 | 5 | #[cfg(feature = "cuda")] 6 | use candle_cublaslt::{fused_batch_matmul, fused_matmul, Activation, CublasLt}; 7 | 8 | static INIT: Once = Once::new(); 9 | static mut CUBLASLT: Option = None; 10 | 11 | pub fn get_cublas_lt_wrapper() -> Option<&'static CublasLtWrapper> { 12 | unsafe { 13 | INIT.call_once(|| { 14 | #[cfg(not(feature = "cuda"))] 15 | { 16 | CUBLASLT = None; 17 | } 18 | 19 | #[cfg(feature = "cuda")] 20 | { 21 | // Check if we can call the driver 22 | // Then check if we can create a device 23 | // Then check that the device is CUDA 24 | use candle::cuda_backend::cudarc::driver; 25 | CUBLASLT = driver::result::init() 26 | .ok() 27 | .and_then(|_| Device::cuda_if_available(0).ok()) 28 | .and_then(|device| match device { 29 | Device::Cuda(_) => Some(CublasLtWrapper { 30 | cublaslt: CublasLt::new(&device).unwrap(), 31 | }), 32 | _ => None, 33 | }); 34 | } 35 | }); 36 | #[allow(static_mut_refs)] 37 | CUBLASLT.as_ref() 38 | } 39 | } 40 | 41 | #[derive(Debug, Clone)] 42 | pub struct CublasLtWrapper { 43 | #[cfg(feature = "cuda")] 44 | pub cublaslt: CublasLt, 45 | } 46 | 47 | impl CublasLtWrapper { 48 | #[allow(clippy::too_many_arguments)] 49 | pub fn matmul( 50 | &self, 51 | a: &Tensor, 52 | b: &Tensor, 53 | out: Option<&Tensor>, 54 | alpha: Option, 55 | beta: Option, 56 | bias: Option<&Tensor>, 57 | act: Option, 58 | ) -> Result { 59 | #[cfg(feature = "cuda")] 60 | { 61 | let inner_act = match act { 62 | Some(HiddenAct::Gelu) => Some(Activation::Gelu), 63 | Some(HiddenAct::Relu) => Some(Activation::Relu), 64 | _ => None, 65 | }; 66 | 67 | let mut result = fused_matmul( 68 | a, 69 | b, 70 | out, 71 | alpha, 72 | beta, 73 | bias, 74 | inner_act, 75 | self.cublaslt.clone(), 76 | )?; 77 | 78 | if Some(HiddenAct::Swiglu) == act { 79 | result = candle_nn::ops::swiglu(&result)?; 80 | } 81 | Ok(result) 82 | } 83 | #[cfg(not(feature = "cuda"))] 84 | { 85 | candle::bail!("`cuda` feature is not enabled") 86 | } 87 | } 88 | 89 | #[allow(clippy::too_many_arguments)] 90 | pub fn batch_matmul( 91 | &self, 92 | a: &Tensor, 93 | b: &Tensor, 94 | out: Option<&Tensor>, 95 | alpha: Option, 96 | beta: Option, 97 | bias: Option<&Tensor>, 98 | act: Option, 99 | ) -> Result { 100 | #[cfg(feature = "cuda")] 101 | { 102 | let inner_act = match act { 103 | Some(HiddenAct::Gelu) => Some(Activation::Gelu), 104 | Some(HiddenAct::Relu) => Some(Activation::Relu), 105 | _ => None, 106 | }; 107 | 108 | let mut result = fused_batch_matmul( 109 | a, 110 | b, 111 | out, 112 | alpha, 113 | beta, 114 | bias, 115 | inner_act, 116 | self.cublaslt.clone(), 117 | )?; 118 | 119 | if Some(HiddenAct::Swiglu) == act { 120 | result = candle_nn::ops::swiglu(&result)?; 121 | } 122 | Ok(result) 123 | } 124 | #[cfg(not(feature = "cuda"))] 125 | { 126 | candle::bail!("`cuda` feature is not enabled") 127 | } 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /backends/candle/src/layers/linear.rs: -------------------------------------------------------------------------------- 1 | use 
crate::layers::cublaslt::get_cublas_lt_wrapper;
2 | use candle::{Device, Result, Tensor};
3 | use serde::Deserialize;
4 | 
5 | #[derive(Debug, Deserialize, PartialEq, Clone)]
6 | #[serde(rename_all = "lowercase")]
7 | pub enum HiddenAct {
8 |     Gelu,
9 |     Relu,
10 |     #[serde(alias = "silu")]
11 |     Swiglu,
12 | }
13 | 
14 | impl HiddenAct {
15 |     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
16 |         match self {
17 |             Self::Gelu => x.gelu(),
18 |             Self::Relu => x.relu(),
19 |             Self::Swiglu => candle_nn::ops::swiglu(x),
20 |         }
21 |     }
22 | }
23 | 
24 | #[derive(Debug)]
25 | pub struct Linear {
26 |     weight: Tensor,
27 |     bias: Option<Tensor>,
28 |     act: Option<HiddenAct>,
29 |     span: tracing::Span,
30 | }
31 | 
32 | impl Linear {
33 |     pub fn new(weight: Tensor, bias: Option<Tensor>, act: Option<HiddenAct>) -> Self {
34 |         let span = tracing::span!(tracing::Level::TRACE, "linear");
35 | 
36 |         Self {
37 |             weight,
38 |             bias,
39 |             act,
40 |             span,
41 |         }
42 |     }
43 | 
44 |     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
45 |         let _enter = self.span.enter();
46 | 
47 |         #[allow(unused)]
48 |         if let (Device::Cuda(_), Some(cublaslt)) = (x.device(), get_cublas_lt_wrapper()) {
49 |             match x.dims() {
50 |                 &[bsize, _, _] => cublaslt.batch_matmul(
51 |                     &self.weight.broadcast_left(bsize)?,
52 |                     x,
53 |                     None,
54 |                     None,
55 |                     None,
56 |                     self.bias.as_ref(),
57 |                     self.act.clone(),
58 |                 ),
59 |                 _ => cublaslt.matmul(
60 |                     &self.weight,
61 |                     x,
62 |                     None,
63 |                     None,
64 |                     None,
65 |                     self.bias.as_ref(),
66 |                     self.act.clone(),
67 |                 ),
68 |             }
69 |         } else {
70 |             let w = match x.dims() {
71 |                 &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
72 |                 _ => self.weight.t()?,
73 |             };
74 |             let x = x.matmul(&w)?;
75 |             let x = match &self.bias {
76 |                 None => Ok(x),
77 |                 Some(bias) => x.broadcast_add(bias),
78 |             }?;
79 |             if let Some(act) = &self.act {
80 |                 match act {
81 |                     HiddenAct::Gelu => x.gelu(),
82 |                     HiddenAct::Relu => x.relu(),
83 |                     HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
84 |                 }
85 |             } else {
86 |                 Ok(x)
87 |             }
88 |         }
89 |     }
90 | }
91 | 
-------------------------------------------------------------------------------- /backends/candle/src/layers/mod.rs: --------------------------------------------------------------------------------
1 | #[allow(dead_code, unused)]
2 | mod cublaslt;
3 | mod layer_norm;
4 | mod linear;
5 | #[allow(dead_code, unused)]
6 | mod rms_norm;
7 | mod rotary;
8 | 
9 | pub use cublaslt::get_cublas_lt_wrapper;
10 | pub use layer_norm::{LayerNorm, LayerNormNoBias};
11 | pub use linear::{HiddenAct, Linear};
12 | #[allow(unused_imports)]
13 | pub use rms_norm::RMSNorm;
14 | pub use rotary::{apply_rotary, get_cos_sin, get_inv_freqs, RopeScaling};
15 | 
-------------------------------------------------------------------------------- /backends/candle/src/layers/rms_norm.rs: --------------------------------------------------------------------------------
1 | use candle::{DType, Device, Result, Tensor, D};
2 | use candle_nn::VarBuilder;
3 | 
4 | #[derive(Debug)]
5 | pub struct RMSNorm {
6 |     weight: Tensor,
7 |     epsilon: f32,
8 |     span: tracing::Span,
9 | }
10 | 
11 | impl RMSNorm {
12 |     pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
13 |         Ok(Self {
14 |             weight: vb
15 |                 .get(hidden_size, "weight")
16 |                 .or_else(|_| vb.get(hidden_size, "gamma"))?,
17 |             epsilon,
18 |             span: tracing::span!(tracing::Level::TRACE, "rms-norm"),
19 |         })
20 |     }
21 | 
22 |     pub fn forward(
23 |         &self,
24 |         hidden_states: &Tensor,
25 |         residual: Option<&Tensor>,
26 |     ) -> Result<(Tensor, Tensor)> {
27 |         let _enter = self.span.enter();
28 | 
29 |         match hidden_states.device()
{ 30 | Device::Cpu | Device::Metal(_) => { 31 | let mut hidden_states = hidden_states.clone(); 32 | let residual_add = if let Some(residual) = residual { 33 | let residual_add = hidden_states.add(residual)?; 34 | hidden_states = residual_add.clone(); 35 | residual_add 36 | } else { 37 | hidden_states.clone() 38 | }; 39 | 40 | let hidden_states_dtype = hidden_states.dtype(); 41 | let internal_dtype = match hidden_states_dtype { 42 | DType::F16 | DType::BF16 => DType::F32, 43 | d => d, 44 | }; 45 | let hidden_size = hidden_states.dim(D::Minus1)?; 46 | let hidden_states = hidden_states.to_dtype(internal_dtype)?; 47 | let norm_hidden_states = 48 | (hidden_states.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; 49 | let hidden_states_normed = hidden_states 50 | .broadcast_div(&(norm_hidden_states + self.epsilon as f64)?.sqrt()?)?; 51 | Ok(( 52 | hidden_states_normed 53 | .to_dtype(hidden_states_dtype)? 54 | .broadcast_mul(&self.weight)?, 55 | residual_add, 56 | )) 57 | } 58 | Device::Cuda(_) => { 59 | #[cfg(feature = "cuda")] 60 | { 61 | use candle_layer_norm::{fused_add_rms_norm, rms_norm}; 62 | 63 | let original_shape = hidden_states.shape(); 64 | let hidden_states = hidden_states.flatten_to(D::Minus2)?; 65 | 66 | if let Some(residual) = residual { 67 | let residual = residual.flatten_to(D::Minus2)?; 68 | 69 | let (result, residual_add) = fused_add_rms_norm( 70 | &hidden_states, 71 | &residual, 72 | &self.weight, 73 | None, 74 | self.epsilon, 75 | )?; 76 | Ok(( 77 | result.reshape(original_shape)?, 78 | residual_add.reshape(original_shape)?, 79 | )) 80 | } else { 81 | let residual_add = hidden_states.clone(); 82 | 83 | let result = rms_norm(&hidden_states, &self.weight, None, self.epsilon)?; 84 | 85 | Ok(( 86 | result.reshape(original_shape)?, 87 | residual_add.reshape(original_shape)?, 88 | )) 89 | } 90 | } 91 | #[cfg(not(feature = "cuda"))] 92 | candle::bail!("`cuda` feature is not enabled") 93 | } 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /backends/candle/src/layers/rotary.rs: -------------------------------------------------------------------------------- 1 | use candle::{DType, Device, Result, Tensor, D}; 2 | use serde::Deserialize; 3 | 4 | #[derive(Debug, Clone, PartialEq, Deserialize)] 5 | pub struct NTKScaling { 6 | pub factor: f32, 7 | } 8 | 9 | #[derive(Debug, Clone, PartialEq, Deserialize)] 10 | #[serde(tag = "type", rename_all = "kebab-case")] 11 | pub enum RopeScaling { 12 | Ntk(NTKScaling), 13 | } 14 | 15 | pub fn get_inv_freqs( 16 | dim: usize, 17 | base: f32, 18 | device: &Device, 19 | rope_scaling: Option<&RopeScaling>, 20 | ) -> Result { 21 | let get_inv_freqs_inner = |dim: usize, base: f32, device: &Device| { 22 | let inv_freq: Vec<_> = (0..dim) 23 | .step_by(2) 24 | .map(|i| 1f32 / base.powf(i as f32 / dim as f32)) 25 | .collect(); 26 | let inv_freq_len = inv_freq.len(); 27 | Tensor::from_vec(inv_freq, (1, inv_freq_len), device) 28 | }; 29 | 30 | if let Some(rope_scaling) = rope_scaling { 31 | match rope_scaling { 32 | RopeScaling::Ntk(ntk_scaling) => { 33 | let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?; 34 | let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64; 35 | return inv_freqs / s; 36 | } 37 | } 38 | } 39 | get_inv_freqs_inner(dim, base, device) 40 | } 41 | 42 | pub fn get_cos_sin( 43 | length: usize, 44 | inv_freqs: &Tensor, 45 | dtype: DType, 46 | repeat_freqs: bool, 47 | ) -> Result<(Tensor, Tensor)> { 48 | let t = Tensor::arange(0u32, length as u32, 
inv_freqs.device())? 49 | .to_dtype(DType::F32)? 50 | .reshape((length, 1))?; 51 | let mut freqs = t.matmul(inv_freqs)?; 52 | if repeat_freqs { 53 | freqs = Tensor::cat(&[&freqs, &freqs], 1)?; 54 | } 55 | 56 | let cos = freqs.cos()?.to_dtype(dtype)?; 57 | let sin = freqs.sin()?.to_dtype(dtype)?; 58 | Ok((cos, sin)) 59 | } 60 | 61 | pub fn apply_rotary( 62 | x: &Tensor, 63 | cos: &Tensor, 64 | sin: &Tensor, 65 | attention_head_size: usize, 66 | ) -> Result { 67 | let dim = attention_head_size / 2; 68 | let x1 = x.narrow(D::Minus1, 0, dim)?; 69 | let x2 = x.narrow(D::Minus1, dim, dim)?; 70 | let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; 71 | let rope = (x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?)?; 72 | Ok(rope) 73 | } 74 | -------------------------------------------------------------------------------- /backends/candle/src/models/mistral.rs: -------------------------------------------------------------------------------- 1 | use crate::layers::HiddenAct; 2 | use serde::Deserialize; 3 | 4 | #[derive(Debug, Clone, PartialEq, Deserialize)] 5 | pub struct MistralConfig { 6 | pub vocab_size: usize, 7 | pub hidden_size: usize, 8 | pub intermediate_size: usize, 9 | pub num_hidden_layers: usize, 10 | pub num_attention_heads: usize, 11 | pub num_key_value_heads: usize, 12 | pub hidden_act: HiddenAct, 13 | pub max_position_embeddings: usize, 14 | pub initializer_range: f64, 15 | pub rms_norm_eps: f32, 16 | pub model_type: Option, 17 | pub rope_theta: f32, 18 | pub sliding_window: Option, 19 | } 20 | -------------------------------------------------------------------------------- /backends/candle/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "mkl")] 2 | extern crate intel_mkl_src; 3 | 4 | #[cfg(feature = "accelerate")] 5 | extern crate accelerate_src; 6 | 7 | mod bert; 8 | mod distilbert; 9 | mod jina; 10 | mod jina_code; 11 | mod mistral; 12 | mod modernbert; 13 | mod nomic; 14 | 15 | #[cfg(feature = "cuda")] 16 | mod flash_bert; 17 | 18 | #[cfg(feature = "cuda")] 19 | mod flash_jina; 20 | 21 | #[cfg(feature = "cuda")] 22 | mod flash_jina_code; 23 | 24 | #[cfg(feature = "cuda")] 25 | mod flash_nomic; 26 | 27 | #[cfg(feature = "cuda")] 28 | mod flash_distilbert; 29 | 30 | #[cfg(feature = "cuda")] 31 | mod flash_gte; 32 | 33 | #[cfg(feature = "cuda")] 34 | mod flash_mistral; 35 | 36 | #[cfg(feature = "cuda")] 37 | mod flash_qwen2; 38 | 39 | #[cfg(feature = "cuda")] 40 | mod flash_modernbert; 41 | 42 | mod gte; 43 | mod mpnet; 44 | mod qwen2; 45 | 46 | pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; 47 | use candle::{Result, Tensor}; 48 | pub use distilbert::{DistilBertConfig, DistilBertModel}; 49 | #[allow(unused_imports)] 50 | pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP}; 51 | pub use jina::JinaBertModel; 52 | pub use jina_code::JinaCodeBertModel; 53 | pub use mistral::MistralConfig; 54 | pub use modernbert::{ModernBertConfig, ModernBertModel}; 55 | pub use mpnet::{MPNetConfig, MPNetModel}; 56 | pub use nomic::{NomicBertModel, NomicConfig}; 57 | pub use qwen2::Qwen2Config; 58 | use text_embeddings_backend_core::Batch; 59 | 60 | #[cfg(feature = "cuda")] 61 | pub use flash_bert::FlashBertModel; 62 | 63 | #[cfg(feature = "cuda")] 64 | pub use flash_jina::FlashJinaBertModel; 65 | 66 | #[cfg(feature = "cuda")] 67 | pub use flash_jina_code::FlashJinaCodeBertModel; 68 | 69 | #[cfg(feature = "cuda")] 70 | pub use flash_nomic::FlashNomicBertModel; 71 | 72 | #[cfg(feature = 
"cuda")] 73 | pub use flash_distilbert::FlashDistilBertModel; 74 | 75 | #[cfg(feature = "cuda")] 76 | pub use flash_mistral::FlashMistralModel; 77 | 78 | #[cfg(feature = "cuda")] 79 | pub use flash_gte::FlashGTEModel; 80 | 81 | #[cfg(feature = "cuda")] 82 | pub use flash_qwen2::FlashQwen2Model; 83 | 84 | #[cfg(feature = "cuda")] 85 | pub use flash_modernbert::FlashModernBertModel; 86 | 87 | pub(crate) trait Model { 88 | fn is_padded(&self) -> bool; 89 | 90 | fn embed(&self, _batch: Batch) -> Result<(Option, Option)> { 91 | candle::bail!("`embed` is not implemented for this model"); 92 | } 93 | 94 | fn predict(&self, _batch: Batch) -> Result { 95 | candle::bail!("`predict` is not implemented for this model"); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /backends/candle/src/models/qwen2.rs: -------------------------------------------------------------------------------- 1 | use crate::layers::HiddenAct; 2 | use serde::Deserialize; 3 | 4 | #[derive(Debug, Clone, PartialEq, Deserialize)] 5 | pub struct Qwen2Config { 6 | pub vocab_size: usize, 7 | pub hidden_size: usize, 8 | pub intermediate_size: usize, 9 | pub num_hidden_layers: usize, 10 | pub num_attention_heads: usize, 11 | pub num_key_value_heads: usize, 12 | pub hidden_act: HiddenAct, 13 | pub max_position_embeddings: usize, 14 | pub rms_norm_eps: f32, 15 | pub rope_theta: f32, 16 | pub sliding_window: Option, 17 | pub use_sliding_window: bool, 18 | } 19 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_bert__bert_classification_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_bert.rs 3 | assertion_line: 211 4 | expression: predictions_single 5 | --- 6 | - - 2.8580017 7 | - -2.9722357 8 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_bert__emotions_batch.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_bert.rs 3 | expression: predictions_batch 4 | --- 5 | - - -6.548559 6 | - -6.302024 7 | - -4.8671727 8 | - -3.9600255 9 | - -4.6329865 10 | - -6.2816987 11 | - -6.069644 12 | - -5.7742686 13 | - -6.9259467 14 | - -6.1909447 15 | - -5.67395 16 | - -6.1698227 17 | - -7.513461 18 | - -6.865867 19 | - -7.186479 20 | - -7.128109 21 | - -8.210709 22 | - -7.0171394 23 | - -7.1321163 24 | - -8.533409 25 | - -6.2294865 26 | - -8.742306 27 | - -5.7792044 28 | - -8.657227 29 | - -8.258305 30 | - -6.64832 31 | - -7.4060283 32 | - 3.046496 33 | - - -5.8167515 34 | - -6.6119466 35 | - -5.2771955 36 | - -2.6306503 37 | - -4.6419163 38 | - -5.579778 39 | - -5.797174 40 | - -6.0305815 41 | - -5.8720746 42 | - 0.45377323 43 | - -3.0235887 44 | - -5.3944407 45 | - -5.186683 46 | - -6.2649117 47 | - -6.1962767 48 | - -6.97937 49 | - -5.5674877 50 | - -5.521044 51 | - -5.8899207 52 | - -4.8699703 53 | - -5.6259933 54 | - -7.6109924 55 | - -4.3881936 56 | - -6.039008 57 | - -4.934696 58 | - -0.6715916 59 | - -6.399376 60 | - -2.4499295 61 | - - -6.548559 62 | - -6.302024 63 | - -4.8671727 64 | - -3.9600255 65 | - -4.6329865 66 | - -6.2816987 67 | - -6.069644 68 | - -5.7742686 69 | - -6.9259467 70 | - -6.1909447 71 | - -5.67395 72 | - -6.1698227 73 | - -7.513461 74 | - -6.865867 75 | - -7.186479 76 | - -7.128109 77 | - -8.210709 78 | - -7.0171394 79 | - -7.1321163 80 | - -8.533409 81 | - 
-6.2294865 82 | - -8.742306 83 | - -5.7792044 84 | - -8.657227 85 | - -8.258305 86 | - -6.64832 87 | - -7.4060283 88 | - 3.046496 89 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_bert__emotions_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_bert.rs 3 | expression: predictions_single 4 | --- 5 | - - -6.548218 6 | - -6.3022184 7 | - -4.866901 8 | - -3.9598548 9 | - -4.6330333 10 | - -6.281371 11 | - -6.070333 12 | - -5.7753787 13 | - -6.9252844 14 | - -6.1905437 15 | - -5.674121 16 | - -6.169409 17 | - -7.5133595 18 | - -6.8658547 19 | - -7.185884 20 | - -7.1283603 21 | - -8.210392 22 | - -7.016874 23 | - -7.1315 24 | - -8.53309 25 | - -6.229343 26 | - -8.741868 27 | - -5.7791805 28 | - -8.657056 29 | - -8.258206 30 | - -6.6477957 31 | - -7.406438 32 | - 3.0466576 33 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_flash_bert__bert_classification_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_flash_bert.rs 3 | expression: predictions_single 4 | --- 5 | - - 2.8574219 6 | - -2.9726563 7 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_flash_bert__emotions_batch.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_flash_bert.rs 3 | expression: predictions_batch 4 | --- 5 | - - -6.5507813 6 | - -6.3007813 7 | - -4.8671875 8 | - -3.9589844 9 | - -4.6328125 10 | - -6.28125 11 | - -6.0703125 12 | - -5.7773438 13 | - -6.9257813 14 | - -6.1875 15 | - -5.671875 16 | - -6.1679688 17 | - -7.5117188 18 | - -6.8671875 19 | - -7.1835938 20 | - -7.1289063 21 | - -8.2109375 22 | - -7.015625 23 | - -7.1328125 24 | - -8.53125 25 | - -6.2304688 26 | - -8.7421875 27 | - -5.7773438 28 | - -8.65625 29 | - -8.2578125 30 | - -6.6484375 31 | - -7.40625 32 | - 3.046875 33 | - - -5.8164063 34 | - -6.6132813 35 | - -5.2773438 36 | - -2.6328125 37 | - -4.640625 38 | - -5.578125 39 | - -5.8007813 40 | - -6.03125 41 | - -5.8710938 42 | - 0.45166016 43 | - -3.0253906 44 | - -5.3945313 45 | - -5.1875 46 | - -6.265625 47 | - -6.1992188 48 | - -6.9804688 49 | - -5.5664063 50 | - -5.5195313 51 | - -5.890625 52 | - -4.8710938 53 | - -5.625 54 | - -7.609375 55 | - -4.3867188 56 | - -6.0390625 57 | - -4.9375 58 | - -0.6699219 59 | - -6.4023438 60 | - -2.4492188 61 | - - -6.5507813 62 | - -6.3007813 63 | - -4.8671875 64 | - -3.9589844 65 | - -4.6328125 66 | - -6.28125 67 | - -6.0703125 68 | - -5.7773438 69 | - -6.9257813 70 | - -6.1875 71 | - -5.671875 72 | - -6.1679688 73 | - -7.5117188 74 | - -6.8671875 75 | - -7.1835938 76 | - -7.1289063 77 | - -8.2109375 78 | - -7.015625 79 | - -7.1328125 80 | - -8.53125 81 | - -6.2304688 82 | - -8.7421875 83 | - -5.7773438 84 | - -8.65625 85 | - -8.2578125 86 | - -6.6484375 87 | - -7.40625 88 | - 3.046875 89 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_flash_bert__emotions_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_flash_bert.rs 3 | expression: predictions_single 4 | --- 5 | - - -6.546875 6 | - -6.3007813 7 | - -4.8671875 8 | - -3.9589844 9 | - 
-4.6328125 10 | - -6.28125 11 | - -6.0703125 12 | - -5.7773438 13 | - -6.9257813 14 | - -6.1914063 15 | - -5.671875 16 | - -6.1679688 17 | - -7.5117188 18 | - -6.8671875 19 | - -7.1835938 20 | - -7.1289063 21 | - -8.2109375 22 | - -7.015625 23 | - -7.1328125 24 | - -8.53125 25 | - -6.2304688 26 | - -8.7421875 27 | - -5.7773438 28 | - -8.65625 29 | - -8.2578125 30 | - -6.6484375 31 | - -7.40625 32 | - 3.046875 33 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_flash_gte__gte_classification_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_flash_gte.rs 3 | expression: predictions_single 4 | --- 5 | - - -0.7426758 6 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_gte__gte_classification_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_gte.rs 3 | expression: predictions_single 4 | --- 5 | - - -0.74173266 6 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_jina__jinabert_reranker_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_jina.rs 3 | expression: predictions 4 | --- 5 | - - -0.6045344 6 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_modernbert__modernbert_classification_mean_pooling.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_modernbert.rs 3 | expression: predictions_single 4 | --- 5 | - - -0.30617672 6 | -------------------------------------------------------------------------------- /backends/candle/tests/snapshots/test_modernbert__modernbert_classification_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: backends/candle/tests/test_modernbert.rs 3 | expression: predictions_single 4 | --- 5 | - - 2.13616 6 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_gte.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | mod common; 3 | 4 | use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; 5 | use anyhow::Result; 6 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; 7 | use text_embeddings_backend_candle::CandleBackend; 8 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 9 | 10 | #[test] 11 | #[serial_test::serial] 12 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 13 | fn test_flash_gte() -> Result<()> { 14 | let model_root = download_artifacts("Alibaba-NLP/gte-base-en-v1.5", None)?; 15 | let tokenizer = load_tokenizer(&model_root)?; 16 | 17 | let backend = CandleBackend::new( 18 | &model_root, 19 | "float16".to_string(), 20 | ModelType::Embedding(Pool::Cls), 21 | )?; 22 | 23 | let input_batch = batch( 24 | vec![ 25 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 26 | tokenizer.encode("Deep Learning is...", true).unwrap(), 27 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 28 | ], 29 | [0, 1, 2].to_vec(), 30 | vec![], 31 | ); 32 | 33 | 
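    // Editorial note, not part of the repository source: `batch` is a helper from `tests/common.rs` (not shown in this dump); judging from its call sites it appears to take the tokenized encodings, the indices of sequences whose pooled embeddings should be returned, and the indices of sequences whose raw per-token embeddings should be returned, and to build the flattened `Batch` defined in `backends/core/src/lib.rs`. In this test all three sequences request pooled embeddings and none request raw ones.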
let matcher = cosine_matcher(); 34 | 35 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 36 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 37 | insta::assert_yaml_snapshot!("gte_batch", embeddings_batch, &matcher); 38 | 39 | let input_single = batch( 40 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 41 | [0].to_vec(), 42 | vec![], 43 | ); 44 | 45 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 46 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 47 | 48 | insta::assert_yaml_snapshot!("gte_single", embeddings_single, &matcher); 49 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 50 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 51 | 52 | Ok(()) 53 | } 54 | 55 | #[test] 56 | #[serial_test::serial] 57 | #[cfg(all( 58 | feature = "cuda", 59 | any(feature = "flash-attn", feature = "flash-attn-v1") 60 | ))] 61 | fn test_flash_gte_classification() -> Result<()> { 62 | let model_root = download_artifacts("Alibaba-NLP/gte-multilingual-reranker-base", None)?; 63 | let tokenizer = load_tokenizer(&model_root)?; 64 | 65 | let backend = CandleBackend::new(&model_root, "float16".to_string(), ModelType::Classifier)?; 66 | 67 | let input_single = batch( 68 | vec![tokenizer 69 | .encode(("What is Deep Learning?", "Deep Learning is not..."), true) 70 | .unwrap()], 71 | [0].to_vec(), 72 | vec![], 73 | ); 74 | 75 | let predictions: Vec> = backend 76 | .predict(input_single)? 77 | .into_iter() 78 | .map(|(_, v)| v) 79 | .collect(); 80 | let predictions_single = SnapshotScores::from(predictions); 81 | 82 | let matcher = relative_matcher(); 83 | insta::assert_yaml_snapshot!("gte_classification_single", predictions_single, &matcher); 84 | 85 | Ok(()) 86 | } 87 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_jina.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | mod common; 3 | 4 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 5 | use anyhow::Result; 6 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 7 | use text_embeddings_backend_candle::CandleBackend; 8 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 9 | 10 | #[test] 11 | #[serial_test::serial] 12 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 13 | fn test_flash_jina_small() -> Result<()> { 14 | let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en", None)?; 15 | let tokenizer = load_tokenizer(&model_root)?; 16 | 17 | let backend = CandleBackend::new( 18 | &model_root, 19 | "float16".to_string(), 20 | ModelType::Embedding(Pool::Mean), 21 | )?; 22 | 23 | let input_batch = batch( 24 | vec![ 25 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 26 | tokenizer.encode("Deep Learning is...", true).unwrap(), 27 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 28 | ], 29 | [0, 1, 2].to_vec(), 30 | vec![], 31 | ); 32 | 33 | let matcher = cosine_matcher(); 34 | 35 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 36 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 37 | insta::assert_yaml_snapshot!("jina_batch", embeddings_batch, &matcher); 38 | 39 | let input_single = batch( 40 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 41 | [0].to_vec(), 42 | vec![], 43 | ); 44 | 45 | let (pooled_embeddings, _) = 
sort_embeddings(backend.embed(input_single)?); 46 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 47 | 48 | insta::assert_yaml_snapshot!("jina_single", embeddings_single, &matcher); 49 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 50 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 51 | 52 | Ok(()) 53 | } 54 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_jina_code.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | mod common; 3 | 4 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 5 | use anyhow::Result; 6 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 7 | use text_embeddings_backend_candle::CandleBackend; 8 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 9 | 10 | #[test] 11 | #[serial_test::serial] 12 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 13 | fn test_flash_jina_code_base() -> Result<()> { 14 | let model_root = download_artifacts("jinaai/jina-embeddings-v2-base-code", None)?; 15 | let tokenizer = load_tokenizer(&model_root)?; 16 | 17 | let backend = CandleBackend::new( 18 | &model_root, 19 | "float16".to_string(), 20 | ModelType::Embedding(Pool::Mean), 21 | )?; 22 | 23 | let input_batch = batch( 24 | vec![ 25 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 26 | tokenizer.encode("Deep Learning is...", true).unwrap(), 27 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 28 | ], 29 | [0, 1, 2].to_vec(), 30 | vec![], 31 | ); 32 | 33 | let matcher = cosine_matcher(); 34 | 35 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 36 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 37 | insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher); 38 | 39 | let input_single = batch( 40 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 41 | [0].to_vec(), 42 | vec![], 43 | ); 44 | 45 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 46 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 47 | 48 | insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher); 49 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 50 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 51 | 52 | Ok(()) 53 | } 54 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_mistral.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | mod common; 3 | 4 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 5 | use anyhow::Result; 6 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 7 | use text_embeddings_backend_candle::CandleBackend; 8 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 9 | 10 | #[test] 11 | #[serial_test::serial] 12 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 13 | fn test_flash_mistral() -> Result<()> { 14 | let model_root = download_artifacts("Salesforce/SFR-Embedding-2_R", None)?; 15 | let tokenizer = load_tokenizer(&model_root)?; 16 | 17 | let backend = CandleBackend::new( 18 | &model_root, 19 | "float16".to_string(), 20 | ModelType::Embedding(Pool::Mean), 21 | )?; 22 | 23 | let input_batch = batch( 24 | vec![ 25 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 26 | 
tokenizer.encode("Deep Learning is...", true).unwrap(), 27 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 28 | ], 29 | [0, 1, 2].to_vec(), 30 | vec![], 31 | ); 32 | 33 | let matcher = cosine_matcher(); 34 | 35 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 36 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 37 | insta::assert_yaml_snapshot!("mistral_batch", embeddings_batch, &matcher); 38 | 39 | let input_single = batch( 40 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 41 | [0].to_vec(), 42 | vec![], 43 | ); 44 | 45 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 46 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 47 | 48 | insta::assert_yaml_snapshot!("mistral_single", embeddings_single, &matcher); 49 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 50 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 51 | 52 | Ok(()) 53 | } 54 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_nomic.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | mod common; 3 | 4 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 5 | use anyhow::Result; 6 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 7 | use text_embeddings_backend_candle::CandleBackend; 8 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 9 | 10 | #[test] 11 | #[serial_test::serial] 12 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 13 | fn test_flash_nomic_small() -> Result<()> { 14 | let model_root = download_artifacts("nomic-ai/nomic-embed-text-v1.5", None)?; 15 | let tokenizer = load_tokenizer(&model_root)?; 16 | 17 | let backend = CandleBackend::new( 18 | &model_root, 19 | "float16".to_string(), 20 | ModelType::Embedding(Pool::Mean), 21 | )?; 22 | 23 | let input_batch = batch( 24 | vec![ 25 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 26 | tokenizer.encode("Deep Learning is...", true).unwrap(), 27 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 28 | ], 29 | [0, 1, 2].to_vec(), 30 | vec![], 31 | ); 32 | 33 | let matcher = cosine_matcher(); 34 | 35 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 36 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 37 | insta::assert_yaml_snapshot!("nomic_batch", embeddings_batch, &matcher); 38 | 39 | let input_single = batch( 40 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 41 | [0].to_vec(), 42 | vec![], 43 | ); 44 | 45 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 46 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 47 | 48 | insta::assert_yaml_snapshot!("nomic_single", embeddings_single, &matcher); 49 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 50 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 51 | 52 | Ok(()) 53 | } 54 | 55 | #[test] 56 | #[serial_test::serial] 57 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 58 | fn test_flash_nomic_moe() -> Result<()> { 59 | let model_root = download_artifacts("nomic-ai/nomic-embed-text-v2-moe", None)?; 60 | let tokenizer = load_tokenizer(&model_root)?; 61 | 62 | let backend = CandleBackend::new( 63 | &model_root, 64 | "float16".to_string(), 65 | ModelType::Embedding(Pool::Mean), 66 | )?; 67 | 68 | let input_batch = batch( 69 | vec![ 70 | 
tokenizer.encode("What is Deep Learning?", true).unwrap(), 71 | tokenizer.encode("Deep Learning is...", true).unwrap(), 72 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 73 | ], 74 | [0, 1, 2].to_vec(), 75 | vec![], 76 | ); 77 | 78 | let matcher = cosine_matcher(); 79 | 80 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 81 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 82 | insta::assert_yaml_snapshot!("nomic_moe_batch", embeddings_batch, &matcher); 83 | 84 | let input_single = batch( 85 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 86 | [0].to_vec(), 87 | vec![], 88 | ); 89 | 90 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 91 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 92 | 93 | insta::assert_yaml_snapshot!("nomic_moe_single", embeddings_single, &matcher); 94 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 95 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 96 | 97 | Ok(()) 98 | } 99 | -------------------------------------------------------------------------------- /backends/candle/tests/test_flash_qwen2.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code, unused_imports)] 2 | 3 | mod common; 4 | 5 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 6 | use anyhow::Result; 7 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 8 | use text_embeddings_backend_candle::CandleBackend; 9 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 10 | use tokenizers::processors::sequence::Sequence; 11 | use tokenizers::processors::template::TemplateProcessing; 12 | use tokenizers::{PostProcessorWrapper, Tokenizer}; 13 | 14 | #[test] 15 | #[serial_test::serial] 16 | #[cfg(all(feature = "cuda", feature = "flash-attn"))] 17 | fn test_flash_qwen2() -> Result<()> { 18 | let model_root = download_artifacts("Alibaba-NLP/gte-Qwen2-1.5B-instruct", None)?; 19 | let mut tokenizer = load_tokenizer(&model_root)?; 20 | // Qwen2 updates the post processor manually instead of into the tokenizer.json... 
21 | // https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct/blob/main/tokenization_qwen.py#L246 22 | let template = TemplateProcessing::builder() 23 | .try_single("$A:0 <|endoftext|>:0") 24 | .unwrap() 25 | .try_pair("$A:0 <|endoftext|>:0 $B:1 <|endoftext|>:1") 26 | .unwrap() 27 | .special_tokens(vec![("<|endoftext|>", 151643)]) 28 | .build() 29 | .unwrap(); 30 | match tokenizer.get_post_processor() { 31 | None => tokenizer.with_post_processor(Some(template)), 32 | Some(post_processor) => { 33 | let post_processor = Sequence::new(vec![ 34 | post_processor.clone(), 35 | PostProcessorWrapper::Template(template), 36 | ]); 37 | tokenizer.with_post_processor(Some(post_processor)) 38 | } 39 | }; 40 | 41 | let backend = CandleBackend::new( 42 | &model_root, 43 | "float16".to_string(), 44 | ModelType::Embedding(Pool::LastToken), 45 | )?; 46 | 47 | let input_batch = batch( 48 | vec![ 49 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 50 | tokenizer.encode("Deep Learning is...", true).unwrap(), 51 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 52 | ], 53 | [0, 1, 2].to_vec(), 54 | vec![], 55 | ); 56 | 57 | let matcher = cosine_matcher(); 58 | 59 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 60 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 61 | insta::assert_yaml_snapshot!("qwen2_batch", embeddings_batch, &matcher); 62 | 63 | let input_single = batch( 64 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 65 | [0].to_vec(), 66 | vec![], 67 | ); 68 | 69 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 70 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 71 | 72 | insta::assert_yaml_snapshot!("qwen2_single", embeddings_single, &matcher); 73 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 74 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 75 | 76 | Ok(()) 77 | } 78 | -------------------------------------------------------------------------------- /backends/candle/tests/test_jina.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | 3 | use crate::common::{sort_embeddings, SnapshotEmbeddings, SnapshotScores}; 4 | use anyhow::Result; 5 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer, relative_matcher}; 6 | use text_embeddings_backend_candle::CandleBackend; 7 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 8 | 9 | #[test] 10 | fn test_jina_small() -> Result<()> { 11 | let model_root = download_artifacts("jinaai/jina-embeddings-v2-small-en", None)?; 12 | let tokenizer = load_tokenizer(&model_root)?; 13 | 14 | let backend = CandleBackend::new( 15 | &model_root, 16 | "float32".to_string(), 17 | ModelType::Embedding(Pool::Mean), 18 | )?; 19 | 20 | let input_batch = batch( 21 | vec![ 22 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 23 | tokenizer.encode("Deep Learning is...", true).unwrap(), 24 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 25 | ], 26 | [0, 1, 2].to_vec(), 27 | vec![], 28 | ); 29 | 30 | let matcher = cosine_matcher(); 31 | 32 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 33 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 34 | insta::assert_yaml_snapshot!("jina_batch", embeddings_batch, &matcher); 35 | 36 | let input_single = batch( 37 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 38 | [0].to_vec(), 39 | vec![], 40 | ); 41 | 42 | let 
(pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 43 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 44 | 45 | insta::assert_yaml_snapshot!("jina_single", embeddings_single, &matcher); 46 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 47 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 48 | 49 | Ok(()) 50 | } 51 | 52 | #[test] 53 | #[serial_test::serial] 54 | fn test_jina_rerank() -> Result<()> { 55 | let model_root = download_artifacts("jinaai/jina-reranker-v1-tiny-en", Some("refs/pr/11"))?; 56 | let tokenizer = load_tokenizer(&model_root)?; 57 | 58 | let backend = CandleBackend::new(&model_root, "float32".to_string(), ModelType::Classifier)?; 59 | 60 | let input_single = batch( 61 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 62 | [0].to_vec(), 63 | vec![], 64 | ); 65 | 66 | let predictions: Vec> = backend 67 | .predict(input_single)? 68 | .into_iter() 69 | .map(|(_, v)| v) 70 | .collect(); 71 | 72 | let predictions = SnapshotScores::from(predictions); 73 | insta::assert_yaml_snapshot!("jinabert_reranker_single", predictions, &relative_matcher()); 74 | 75 | Ok(()) 76 | } 77 | -------------------------------------------------------------------------------- /backends/candle/tests/test_jina_code.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | 3 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 4 | use anyhow::Result; 5 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 6 | use text_embeddings_backend_candle::CandleBackend; 7 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 8 | 9 | #[test] 10 | fn test_jina_code_base() -> Result<()> { 11 | let model_root = download_artifacts("jinaai/jina-embeddings-v2-base-code", None)?; 12 | let tokenizer = load_tokenizer(&model_root)?; 13 | 14 | let backend = CandleBackend::new( 15 | &model_root, 16 | "float32".to_string(), 17 | ModelType::Embedding(Pool::Mean), 18 | )?; 19 | 20 | let input_batch = batch( 21 | vec![ 22 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 23 | tokenizer.encode("Deep Learning is...", true).unwrap(), 24 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 25 | ], 26 | [0, 1, 2].to_vec(), 27 | vec![], 28 | ); 29 | 30 | let matcher = cosine_matcher(); 31 | 32 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 33 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 34 | insta::assert_yaml_snapshot!("jina_code_batch", embeddings_batch, &matcher); 35 | 36 | let input_single = batch( 37 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 38 | [0].to_vec(), 39 | vec![], 40 | ); 41 | 42 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 43 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 44 | 45 | insta::assert_yaml_snapshot!("jina_code_single", embeddings_single, &matcher); 46 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 47 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 48 | 49 | Ok(()) 50 | } 51 | -------------------------------------------------------------------------------- /backends/candle/tests/test_nomic.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | 3 | use crate::common::{sort_embeddings, SnapshotEmbeddings}; 4 | use anyhow::Result; 5 | use common::{batch, cosine_matcher, download_artifacts, load_tokenizer}; 6 | use 
text_embeddings_backend_candle::CandleBackend; 7 | use text_embeddings_backend_core::{Backend, ModelType, Pool}; 8 | 9 | #[test] 10 | fn test_nomic_small() -> Result<()> { 11 | let model_root = download_artifacts("nomic-ai/nomic-embed-text-v1.5", None)?; 12 | let tokenizer = load_tokenizer(&model_root)?; 13 | 14 | let backend = CandleBackend::new( 15 | &model_root, 16 | "float32".to_string(), 17 | ModelType::Embedding(Pool::Mean), 18 | )?; 19 | 20 | let input_batch = batch( 21 | vec![ 22 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 23 | tokenizer.encode("Deep Learning is...", true).unwrap(), 24 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 25 | ], 26 | [0, 1, 2].to_vec(), 27 | vec![], 28 | ); 29 | 30 | let matcher = cosine_matcher(); 31 | 32 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 33 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 34 | insta::assert_yaml_snapshot!("nomic_batch", embeddings_batch, &matcher); 35 | 36 | let input_single = batch( 37 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 38 | [0].to_vec(), 39 | vec![], 40 | ); 41 | 42 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 43 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 44 | 45 | insta::assert_yaml_snapshot!("nomic_single", embeddings_single, &matcher); 46 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 47 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 48 | 49 | Ok(()) 50 | } 51 | 52 | #[test] 53 | fn test_nomic_moe() -> Result<()> { 54 | let model_root = download_artifacts("nomic-ai/nomic-embed-text-v2-moe", None)?; 55 | let tokenizer = load_tokenizer(&model_root)?; 56 | 57 | let backend = CandleBackend::new( 58 | &model_root, 59 | "float32".to_string(), 60 | ModelType::Embedding(Pool::Mean), 61 | )?; 62 | 63 | let input_batch = batch( 64 | vec![ 65 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 66 | tokenizer.encode("Deep Learning is...", true).unwrap(), 67 | tokenizer.encode("What is Deep Learning?", true).unwrap(), 68 | ], 69 | [0, 1, 2].to_vec(), 70 | vec![], 71 | ); 72 | 73 | let matcher = cosine_matcher(); 74 | 75 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_batch)?); 76 | let embeddings_batch = SnapshotEmbeddings::from(pooled_embeddings); 77 | insta::assert_yaml_snapshot!("nomic_moe_batch", embeddings_batch, &matcher); 78 | 79 | let input_single = batch( 80 | vec![tokenizer.encode("What is Deep Learning?", true).unwrap()], 81 | [0].to_vec(), 82 | vec![], 83 | ); 84 | 85 | let (pooled_embeddings, _) = sort_embeddings(backend.embed(input_single)?); 86 | let embeddings_single = SnapshotEmbeddings::from(pooled_embeddings); 87 | 88 | insta::assert_yaml_snapshot!("nomic_moe_single", embeddings_single, &matcher); 89 | assert_eq!(embeddings_batch[0], embeddings_single[0]); 90 | assert_eq!(embeddings_batch[2], embeddings_single[0]); 91 | 92 | Ok(()) 93 | } 94 | -------------------------------------------------------------------------------- /backends/core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-backend-core" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | thiserror = { workspace = true } 10 | clap = { workspace = true, optional = true } 11 | nohash-hasher = { workspace = true } 12 | serde = { workspace = true } 13 | 14 | 
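# Editorial note, not part of the repository source: the optional `clap` dependency above is surfaced through the `clap` feature declared below; in `src/lib.rs` it gates the `ValueEnum` derive on the `Pool` enum, presumably so the pooling strategy can be selected from a command line.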
[features] 15 | clap = ["dep:clap"] 16 | -------------------------------------------------------------------------------- /backends/core/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "clap")] 2 | use clap::ValueEnum; 3 | use nohash_hasher::IntMap; 4 | use serde::Deserialize; 5 | use std::fmt; 6 | use thiserror::Error; 7 | 8 | #[derive(Debug)] 9 | pub struct Batch { 10 | pub input_ids: Vec<u32>, 11 | pub token_type_ids: Vec<u32>, 12 | pub position_ids: Vec<u32>, 13 | pub cumulative_seq_lengths: Vec<u32>, 14 | pub max_length: u32, 15 | pub pooled_indices: Vec<u32>, 16 | pub raw_indices: Vec<u32>, 17 | } 18 | 19 | impl Batch { 20 | pub fn len(&self) -> usize { 21 | self.cumulative_seq_lengths.len() - 1 22 | } 23 | 24 | pub fn is_empty(&self) -> bool { 25 | self.len() == 0 26 | } 27 | } 28 | 29 | pub enum Embedding { 30 | Pooled(Vec<f32>), 31 | All(Vec<Vec<f32>>), 32 | } 33 | 34 | pub type Embeddings = IntMap<usize, Embedding>; 35 | pub type Predictions = IntMap<usize, Vec<f32>>; 36 | 37 | pub trait Backend { 38 | fn health(&self) -> Result<(), BackendError>; 39 | fn max_batch_size(&self) -> Option<usize> { 40 | None 41 | } 42 | 43 | fn is_padded(&self) -> bool; 44 | 45 | fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError>; 46 | 47 | fn predict(&self, batch: Batch) -> Result<Predictions, BackendError>; 48 | } 49 | 50 | #[derive(Debug, PartialEq, Clone)] 51 | pub enum ModelType { 52 | Classifier, 53 | Embedding(Pool), 54 | } 55 | 56 | #[derive(Debug, PartialEq, Clone, Deserialize)] 57 | #[cfg_attr(feature = "clap", derive(ValueEnum))] 58 | #[serde(rename_all = "snake_case")] 59 | pub enum Pool { 60 | /// Select the CLS token as embedding 61 | Cls, 62 | /// Apply Mean pooling to the model embeddings 63 | Mean, 64 | /// Apply SPLADE (Sparse Lexical and Expansion) to the model embeddings. 65 | /// This option is only available if the loaded model is a `ForMaskedLM` Transformer 66 | /// model. 
67 | Splade, 68 | /// Select the last token as embedding 69 | LastToken, 70 | } 71 | 72 | impl fmt::Display for Pool { 73 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 74 | match self { 75 | Pool::Cls => write!(f, "cls"), 76 | Pool::Mean => write!(f, "mean"), 77 | Pool::Splade => write!(f, "splade"), 78 | Pool::LastToken => write!(f, "last_token"), 79 | } 80 | } 81 | } 82 | 83 | #[derive(Debug, Error, Clone)] 84 | pub enum BackendError { 85 | #[error("No backend found")] 86 | NoBackend, 87 | #[error("Could not start backend: {0}")] 88 | Start(String), 89 | #[error("{0}")] 90 | Inference(String), 91 | #[error("Backend is unhealthy")] 92 | Unhealthy, 93 | #[error("Weights not found: {0}")] 94 | WeightsNotFound(String), 95 | } 96 | -------------------------------------------------------------------------------- /backends/grpc-client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "backend-grpc-client" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | grpc-metadata = { path = "../grpc-metadata" } 10 | prost = "^0.11" 11 | thiserror = "^1.0" 12 | tokio = { version = "^1.25", features = ["sync"] } 13 | tonic = "^0.9" 14 | tower = "^0.4" 15 | tracing = "^0.1" 16 | 17 | [build-dependencies] 18 | tonic-build = "0.9.2" 19 | prost-build = "0.11.6" 20 | -------------------------------------------------------------------------------- /backends/grpc-client/build.rs: -------------------------------------------------------------------------------- 1 | use std::fs; 2 | 3 | fn main() -> Result<(), Box<dyn std::error::Error>> { 4 | println!("cargo:rerun-if-changed=../proto/embed.proto"); 5 | fs::create_dir("src/pb").unwrap_or(()); 6 | 7 | let mut config = prost_build::Config::new(); 8 | config.protoc_arg("--experimental_allow_proto3_optional"); 9 | 10 | tonic_build::configure() 11 | .build_client(true) 12 | .build_server(false) 13 | .out_dir("src/pb") 14 | .include_file("mod.rs") 15 | .compile_with_config(config, &["../proto/embed.proto"], &["../proto"]) 16 | .unwrap_or_else(|e| panic!("protobuf compilation failed: {e}")); 17 | 18 | Ok(()) 19 | } 20 | -------------------------------------------------------------------------------- /backends/grpc-client/src/client.rs: -------------------------------------------------------------------------------- 1 | /// Single shard Client 2 | use crate::pb::embedding::v1::embedding_service_client::EmbeddingServiceClient; 3 | use crate::pb::embedding::v1::*; 4 | use crate::Result; 5 | use grpc_metadata::InjectTelemetryContext; 6 | use tonic::transport::{Channel, Uri}; 7 | use tracing::instrument; 8 | 9 | /// Text Generation Inference gRPC client 10 | #[derive(Debug, Clone)] 11 | pub struct Client { 12 | stub: EmbeddingServiceClient<Channel>, 13 | } 14 | 15 | impl Client { 16 | /// Returns a client connected to the given url 17 | pub async fn connect(uri: Uri) -> Result<Self> { 18 | let channel = Channel::builder(uri).connect().await?; 19 | 20 | Ok(Self { 21 | stub: EmbeddingServiceClient::new(channel), 22 | }) 23 | } 24 | 25 | /// Returns a client connected to the given unix socket 26 | pub async fn connect_uds(path: String) -> Result<Self> { 27 | let channel = Channel::from_shared("http://[::]:50051".to_string()) 28 | .unwrap() 29 | .connect_with_connector(tower::service_fn(move |_: Uri| { 30 | tokio::net::UnixStream::connect(path.clone()) 31 | })) 32 | .await?; 33 | 34 | Ok(Self { 35 | stub: EmbeddingServiceClient::new(channel), 36 | }) 37 | }
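    // Editorial note, not part of the repository source: a minimal usage sketch for this client, assuming a backend server is already listening on the default unix socket used by the Python server (`/tmp/text-embeddings-server`); the token ids below are made up for illustration.
    //
    // async fn example() -> Result<Vec<Embedding>> {
    //     let mut client = Client::connect_uds("/tmp/text-embeddings-server".to_string()).await?;
    //     client.health().await?;
    //     // A single sequence of four tokens: cumulative sequence lengths are [0, 4].
    //     client
    //         .embed(vec![101, 2054, 2003, 102], vec![0; 4], vec![0, 1, 2, 3], vec![0, 4], 4)
    //         .await
    // }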
38 | 39 | /// Get backend health 40 | #[instrument(skip(self))] 41 | pub async fn health(&mut self) -> Result<HealthResponse> { 42 | let request = tonic::Request::new(HealthRequest {}).inject_context(); 43 | let response = self.stub.health(request).await?.into_inner(); 44 | Ok(response) 45 | } 46 | 47 | #[instrument(skip_all)] 48 | pub async fn embed( 49 | &mut self, 50 | input_ids: Vec<u32>, 51 | token_type_ids: Vec<u32>, 52 | position_ids: Vec<u32>, 53 | cu_seq_lengths: Vec<u32>, 54 | max_length: u32, 55 | ) -> Result<Vec<Embedding>> { 56 | let request = tonic::Request::new(EmbedRequest { 57 | input_ids, 58 | token_type_ids, 59 | position_ids, 60 | max_length, 61 | cu_seq_lengths, 62 | }) 63 | .inject_context(); 64 | let response = self.stub.embed(request).await?.into_inner(); 65 | Ok(response.embeddings) 66 | } 67 | 68 | #[instrument(skip_all)] 69 | pub async fn predict( 70 | &mut self, 71 | input_ids: Vec<u32>, 72 | token_type_ids: Vec<u32>, 73 | position_ids: Vec<u32>, 74 | cu_seq_lengths: Vec<u32>, 75 | max_length: u32, 76 | ) -> Result<Vec<Score>> { 77 | let request = tonic::Request::new(EmbedRequest { 78 | input_ids, 79 | token_type_ids, 80 | position_ids, 81 | max_length, 82 | cu_seq_lengths, 83 | }) 84 | .inject_context(); 85 | let response = self.stub.predict(request).await?.into_inner(); 86 | Ok(response.scores) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /backends/grpc-client/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Text Embedding backend gRPC client library 2 | 3 | mod client; 4 | #[allow(clippy::derive_partial_eq_without_eq)] 5 | mod pb; 6 | 7 | pub use client::Client; 8 | pub use pb::embedding::v1::Embedding; 9 | pub use pb::embedding::v1::HealthResponse; 10 | use thiserror::Error; 11 | use tonic::transport; 12 | use tonic::Status; 13 | 14 | #[derive(Error, Debug, Clone)] 15 | pub enum ClientError { 16 | #[error("Could not connect to Text Embedding server: {0}")] 17 | Connection(String), 18 | #[error("Server error: {0}")] 19 | Inference(String), 20 | } 21 | 22 | impl From<Status> for ClientError { 23 | fn from(err: Status) -> Self { 24 | let err = Self::Inference(err.message().to_string()); 25 | tracing::error!("{err}"); 26 | err 27 | } 28 | } 29 | 30 | impl From<transport::Error> for ClientError { 31 | fn from(err: transport::Error) -> Self { 32 | let err = Self::Connection(err.to_string()); 33 | tracing::error!("{err}"); 34 | err 35 | } 36 | } 37 | 38 | pub type Result<T> = std::result::Result<T, ClientError>; 39 | -------------------------------------------------------------------------------- /backends/grpc-client/src/pb/.gitignore: -------------------------------------------------------------------------------- 1 | *.rs 2 | -------------------------------------------------------------------------------- /backends/grpc-metadata/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "grpc-metadata" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | opentelemetry = "^0.19" 8 | tonic = "^0.9" 9 | tracing = "^0.1" 10 | tracing-opentelemetry = "^0.19" 11 | -------------------------------------------------------------------------------- /backends/grpc-metadata/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A crate to extract and inject a OpenTelemetry context from and to a gRPC request. 2 | //! 
Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples 3 | 4 | use opentelemetry::global; 5 | use opentelemetry::propagation::Injector; 6 | use tracing_opentelemetry::OpenTelemetrySpanExt; 7 | 8 | // /// Extract context metadata from a gRPC request's metadata 9 | // struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap); 10 | // 11 | // impl<'a> Extractor for MetadataExtractor<'a> { 12 | // /// Get a value for a key from the MetadataMap. If the value can't be converted to &str, returns None 13 | // fn get(&self, key: &str) -> Option<&str> { 14 | // self.0.get(key).and_then(|metadata| metadata.to_str().ok()) 15 | // } 16 | // 17 | // /// Collect all the keys from the MetadataMap. 18 | // fn keys(&self) -> Vec<&str> { 19 | // self.0 20 | // .keys() 21 | // .map(|key| match key { 22 | // tonic::metadata::KeyRef::Ascii(v) => v.as_str(), 23 | // tonic::metadata::KeyRef::Binary(v) => v.as_str(), 24 | // }) 25 | // .collect::>() 26 | // } 27 | // } 28 | 29 | /// Inject context in the metadata of a gRPC request. 30 | struct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap); 31 | 32 | impl Injector for MetadataInjector<'_> { 33 | /// Set a key and value in the MetadataMap. Does nothing if the key or value are not valid inputs 34 | fn set(&mut self, key: &str, value: String) { 35 | if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) { 36 | if let Ok(val) = value.parse() { 37 | self.0.insert(key, val); 38 | } 39 | } 40 | } 41 | } 42 | 43 | /// Get a context from the global context and inject the span into a gRPC request's metadata. 44 | fn inject(metadata: &mut tonic::metadata::MetadataMap) { 45 | global::get_text_map_propagator(|propagator| { 46 | propagator.inject_context( 47 | &tracing::Span::current().context(), 48 | &mut MetadataInjector(metadata), 49 | ) 50 | }) 51 | } 52 | 53 | pub trait InjectTelemetryContext { 54 | fn inject_context(self) -> Self; 55 | } 56 | 57 | impl InjectTelemetryContext for tonic::Request { 58 | fn inject_context(mut self) -> Self { 59 | inject(self.metadata_mut()); 60 | self 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /backends/ort/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-backend-ort" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | anyhow = { workspace = true } 10 | nohash-hasher = { workspace = true } 11 | ndarray = "0.16.1" 12 | num_cpus = { workspace = true } 13 | ort = { version = "2.0.0-rc.8", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] } 14 | text-embeddings-backend-core = { path = "../core" } 15 | tracing = { workspace = true } 16 | thiserror = { workspace = true } 17 | serde = { workspace = true } 18 | serde_json = { workspace = true } 19 | -------------------------------------------------------------------------------- /backends/proto/embed.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package embedding.v1; 4 | 5 | service EmbeddingService { 6 | /// Decode token for a list of prefilled batches 7 | rpc Embed (EmbedRequest) returns (EmbedResponse); 8 | /// Health check 9 | rpc Health (HealthRequest) returns (HealthResponse); 10 | /// Predict 11 | rpc Predict (EmbedRequest) returns (PredictResponse); 12 | } 13 | 14 | message 
HealthRequest {} 15 | message HealthResponse {} 16 | 17 | message EmbedRequest { 18 | repeated uint32 input_ids = 1; 19 | repeated uint32 token_type_ids = 2; 20 | repeated uint32 position_ids = 3; 21 | repeated uint32 cu_seq_lengths = 4; 22 | /// Length of the longest request 23 | uint32 max_length = 5; 24 | } 25 | 26 | message Embedding { 27 | repeated float values = 1; 28 | } 29 | 30 | message EmbedResponse { 31 | repeated Embedding embeddings = 1; 32 | } 33 | 34 | message Score { 35 | repeated float values = 1; 36 | } 37 | 38 | message PredictResponse { 39 | repeated Score scores = 1; 40 | } 41 | -------------------------------------------------------------------------------- /backends/python/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-backend-python" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | backend-grpc-client = { path = "../grpc-client" } 10 | nohash-hasher = "^0.2" 11 | serde = { version = "^1.0", features = ["derive"] } 12 | serde_json = "^1.0" 13 | text-embeddings-backend-core = { path = "../core" } 14 | thiserror = "^1.0" 15 | tokio = { version = "^1.25", features = ["sync"] } 16 | tracing = "^0.1" 17 | -------------------------------------------------------------------------------- /backends/python/server/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | text_generation_server/__pycache__/ 4 | text_generation_server/pb/__pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | transformers 158 | safetensors 159 | flash-attention/ 160 | flash-attention-v2/ 161 | vllm/ 162 | -------------------------------------------------------------------------------- /backends/python/server/Makefile: -------------------------------------------------------------------------------- 1 | include Makefile-flash-att 2 | include Makefile-flash-att-v2 3 | 4 | unit-tests: 5 | pytest -s -vv -m "not private" tests 6 | 7 | gen-server: 8 | # Compile protos 9 | pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir 10 | mkdir text_embeddings_server/pb || true 11 | python -m grpc_tools.protoc -I../../proto --python_out=text_embeddings_server/pb \ 12 | --grpc_python_out=text_embeddings_server/pb --mypy_out=text_embeddings_server/pb ../../proto/embed.proto 13 | find text_embeddings_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; 14 | touch text_embeddings_server/pb/__init__.py 15 | 16 | install: gen-server 17 | pip install pip --upgrade 18 | pip install --no-deps -r requirements.txt 19 | pip install -e . 
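# Editorial note, not part of the repository source: the typical development flow, assuming a working Python environment with pip and poetry available:
#   make install   # generates the gRPC stubs from ../../proto/embed.proto, then installs the server package
#   make run-dev   # serves BAAI/bge-small-en via text_embeddings_server/cli.py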
20 | 21 | run-dev: 22 | python text_embeddings_server/cli.py serve BAAI/bge-small-en 23 | 24 | export-requirements: 25 | poetry export -o requirements.txt --without-hashes 26 | -------------------------------------------------------------------------------- /backends/python/server/Makefile-flash-att: -------------------------------------------------------------------------------- 1 | flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec 2 | 3 | flash-attention: 4 | # Clone flash attention 5 | pip install packaging 6 | git clone https://github.com/HazyResearch/flash-attention.git 7 | 8 | build-flash-attention: flash-attention 9 | cd flash-attention && git fetch && git checkout $(flash_att_commit) 10 | cd flash-attention && python setup.py build 11 | cd flash-attention/csrc/layer_norm && python setup.py build 12 | 13 | install-flash-attention: build-flash-attention 14 | pip uninstall flash_attn dropout_layer_norm -y || true 15 | cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install 16 | -------------------------------------------------------------------------------- /backends/python/server/Makefile-flash-att-v2: -------------------------------------------------------------------------------- 1 | flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c 2 | 3 | flash-attention-v2: 4 | # Clone flash attention 5 | pip install packaging 6 | git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 7 | 8 | build-flash-attention-v2: flash-attention-v2 9 | cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) 10 | cd flash-attention-v2 && python setup.py build 11 | 12 | install-flash-attention-v2: build-flash-attention-v2 13 | cd flash-attention-v2 && python setup.py install 14 | -------------------------------------------------------------------------------- /backends/python/server/README.md: -------------------------------------------------------------------------------- 1 | # Text Embeddings Inference Python gRPC Server 2 | 3 | A Python gRPC server for Text Embeddings Inference 4 | 5 | ## Install 6 | 7 | ```shell 8 | make install 9 | ``` 10 | 11 | ## Run 12 | 13 | ```shell 14 | make run-dev 15 | ``` 16 | -------------------------------------------------------------------------------- /backends/python/server/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "text-embeddings-server" 3 | version = "0.1.0" 4 | description = "Text Embeddings Python gRPC Server" 5 | authors = ["Olivier Dehaene "] 6 | 7 | [tool.poetry.scripts] 8 | python-text-embeddings-server = 'text_embeddings_server.cli:app' 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.9,<3.13" 12 | protobuf = ">=4.25.3,<6" 13 | grpcio = "^1.51.1" 14 | grpcio-status = "^1.51.1" 15 | grpcio-reflection = "^1.51.1" 16 | grpc-interceptor = "^0.15.0" 17 | typer = "^0.6.1" 18 | safetensors = "^0.4" 19 | loguru = "^0.6.0" 20 | opentelemetry-api = "^1.25.0" 21 | opentelemetry-exporter-otlp = "^1.25.0" 22 | opentelemetry-instrumentation-grpc = "^0.46b0" 23 | sentence-transformers = "^3.3.1" 24 | 25 | [tool.poetry.extras] 26 | 27 | [tool.poetry.group.dev.dependencies] 28 | grpcio-tools = "^1.51.1" 29 | pytest = "^7.3.0" 30 | 31 | [[tool.poetry.source]] 32 | name = "pytorch-gpu-src" 33 | url = "https://download.pytorch.org/whl/cu118" 34 | priority = "explicit" 35 | 36 | [tool.pytest.ini_options] 37 | markers = ["private: marks tests as requiring an admin hf token (deselect with 
'-m \"not private\"')"] 38 | 39 | [build-system] 40 | requires = ["poetry-core>=1.0.0"] 41 | build-backend = "poetry.core.masonry.api" 42 | -------------------------------------------------------------------------------- /backends/python/server/requirements-hpu.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" 2 | backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" 3 | certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" 4 | charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" 5 | click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" 6 | colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") 7 | deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" 8 | filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13" 9 | fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" 10 | fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" 11 | googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" 12 | grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" 13 | grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 14 | grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 15 | grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 16 | huggingface-hub==0.24.5 ; python_version >= "3.9" and python_version < "3.13" 17 | humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13" 18 | idna==3.4 ; python_version >= "3.9" and python_version < "3.13" 19 | importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13" 20 | jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13" 21 | loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" 22 | markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13" 23 | mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" 24 | networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13" 25 | numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" 26 | opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 27 | opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 28 | opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 29 | opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 30 | opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 31 | opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 32 | opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 33 | opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 34 | opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 35 | packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" 36 | pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13" 37 | pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13" 38 | protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" 39 | pyyaml==6.0.1 ; python_version >= 
"3.9" and python_version < "3.13" 40 | regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13" 41 | requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" 42 | safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13" 43 | setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" 44 | six==1.16.0 ; python_version >= "3.9" and python_version < "3.13" 45 | sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" 46 | tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" 47 | tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" 48 | transformers==4.51.3 ; python_version >= "3.9" and python_version < "3.13" 49 | transformers[sentencepiece]==4.51.3 ; python_version >= "3.9" and python_version < "3.13" 50 | typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" 51 | typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" 52 | tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13" 53 | urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" 54 | win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" 55 | wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 56 | xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13" 57 | yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13" 58 | zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13" 59 | pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13" 60 | einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13" 61 | -------------------------------------------------------------------------------- /backends/python/server/requirements-intel.txt: -------------------------------------------------------------------------------- 1 | backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" 2 | certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13" 3 | charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" 4 | click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" 5 | colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") 6 | deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" 7 | filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" 8 | fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" 9 | googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" 10 | grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" 11 | grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 12 | grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 13 | grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" 14 | huggingface-hub==0.19.3 ; python_version >= "3.9" and python_version < "3.13" 15 | idna==3.4 ; python_version >= "3.9" and python_version < "3.13" 16 | jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" 17 | loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" 18 | markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" 19 | mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" 20 | networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" 21 | opentelemetry-api==1.15.0 ; python_version >= "3.9" and 
python_version < "3.13" 22 | opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 23 | opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 24 | opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 25 | opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 26 | opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 27 | opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 28 | opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 29 | opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" 30 | packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" 31 | protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" 32 | pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" 33 | requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" 34 | safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13" 35 | setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" 36 | sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" 37 | tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" 38 | typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" 39 | typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" 40 | urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" 41 | win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" 42 | wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13" 43 | transformers==4.51.3 ; python_version >= "3.9" and python_version < "3.13" 44 | pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13" 45 | einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13" 46 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/text-embeddings-inference/d51a8b99d01d00d5ee7a52ea8dec53ed1c1b7d7c/backends/python/server/text_embeddings_server/__init__.py -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import typer 3 | 4 | from pathlib import Path 5 | from loguru import logger 6 | from typing import Optional 7 | from enum import Enum 8 | 9 | app = typer.Typer() 10 | 11 | 12 | class Dtype(str, Enum): 13 | float32 = "float32" 14 | float16 = "float16" 15 | bloat16 = "bfloat16" 16 | 17 | 18 | @app.command() 19 | def serve( 20 | model_path: Path, 21 | dtype: Dtype = "float32", 22 | uds_path: Path = "/tmp/text-embeddings-server", 23 | logger_level: str = "INFO", 24 | json_output: bool = False, 25 | otlp_endpoint: Optional[str] = None, 26 | otlp_service_name: str = "text-embeddings-inference.server", 27 | pool: str = "cls", 28 | ): 29 | # Remove default handler 30 | logger.remove() 31 | logger.add( 32 | sys.stdout, 33 | format="{message}", 34 | filter="text_embeddings_server", 35 | level=logger_level, 36 | serialize=json_output, 37 | backtrace=True, 38 | 
diagnose=False, 39 | ) 40 | 41 | # Import here after the logger is added to log potential import exceptions 42 | from text_embeddings_server import server 43 | from text_embeddings_server.utils.tracing import setup_tracing 44 | 45 | # Setup OpenTelemetry distributed tracing 46 | if otlp_endpoint is not None: 47 | setup_tracing(otlp_endpoint=otlp_endpoint, otlp_service_name=otlp_service_name) 48 | 49 | # Downgrade enum into str for easier management later on 50 | dtype = None if dtype is None else dtype.value 51 | server.serve(model_path, dtype, uds_path, pool) 52 | 53 | 54 | if __name__ == "__main__": 55 | app() 56 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/classification_model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | 4 | from pathlib import Path 5 | from typing import Type, List 6 | from transformers import AutoModelForSequenceClassification 7 | from opentelemetry import trace 8 | 9 | from text_embeddings_server.models import Model 10 | from text_embeddings_server.models.types import PaddedBatch, Embedding, Score 11 | 12 | tracer = trace.get_tracer(__name__) 13 | 14 | 15 | class ClassificationModel(Model): 16 | def __init__( 17 | self, 18 | model_path: Path, 19 | device: torch.device, 20 | dtype: torch.dtype, 21 | trust_remote: bool = False, 22 | ): 23 | model = AutoModelForSequenceClassification.from_pretrained( 24 | model_path, trust_remote_code=trust_remote 25 | ) 26 | model = model.to(dtype).to(device) 27 | 28 | self.hidden_size = model.config.hidden_size 29 | position_offset = 0 30 | model_type = model.config.model_type 31 | if model_type in ["xlm-roberta", "camembert", "roberta"]: 32 | position_offset = model.config.pad_token_id + 1 33 | if hasattr(model.config, "max_seq_length"): 34 | self.max_input_length = model.config.max_seq_length 35 | else: 36 | self.max_input_length = ( 37 | model.config.max_position_embeddings - position_offset 38 | ) 39 | 40 | self.has_position_ids = ( 41 | inspect.signature(model.forward).parameters.get("position_ids", None) 42 | is not None 43 | ) 44 | self.has_token_type_ids = ( 45 | inspect.signature(model.forward).parameters.get("token_type_ids", None) 46 | is not None 47 | ) 48 | 49 | super(ClassificationModel, self).__init__( 50 | model=model, dtype=dtype, device=device 51 | ) 52 | 53 | @property 54 | def batch_type(self) -> Type[PaddedBatch]: 55 | return PaddedBatch 56 | 57 | @tracer.start_as_current_span("embed") 58 | def embed(self, batch: PaddedBatch) -> List[Embedding]: 59 | pass 60 | 61 | @tracer.start_as_current_span("predict") 62 | def predict(self, batch: PaddedBatch) -> List[Score]: 63 | kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} 64 | if self.has_token_type_ids: 65 | kwargs["token_type_ids"] = batch.token_type_ids 66 | if self.has_position_ids: 67 | kwargs["position_ids"] = batch.position_ids 68 | 69 | output = self.model(**kwargs, return_dict=True) 70 | all_scores = output.logits.tolist() 71 | return [Score(values=scores) for scores in all_scores] 72 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/default_model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | 4 | from pathlib import Path 5 | from typing import Type, List 6 | from transformers import AutoModel 7 | 
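`DefaultModel` below runs the padded batch through an `AutoModel` and then reduces the per-token hidden states to one vector per sequence via `DefaultPooling`, which wraps `sentence_transformers.models.Pooling`. As a rough, self-contained sketch of what the common `mean` pooling mode computes over a padded batch (the function name, shapes, and demo values are illustrative assumptions, not part of this module):

```python
import torch


def masked_mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """token_embeddings: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] with 1 for real tokens."""
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)  # [batch, seq_len, 1]
    summed = (token_embeddings * mask).sum(dim=1)                   # padded positions contribute nothing
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # number of real tokens per sequence
    return summed / counts                                          # [batch, hidden]


if __name__ == "__main__":
    embeddings = torch.randn(2, 4, 8)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    print(masked_mean_pool(embeddings, mask).shape)  # torch.Size([2, 8])
```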
from opentelemetry import trace 8 | from text_embeddings_server.models.pooling import DefaultPooling 9 | 10 | from text_embeddings_server.models import Model 11 | from text_embeddings_server.models.types import PaddedBatch, Embedding, Score 12 | 13 | tracer = trace.get_tracer(__name__) 14 | 15 | 16 | class DefaultModel(Model): 17 | def __init__( 18 | self, 19 | model_path: Path, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | pool: str, 23 | trust_remote: bool = False, 24 | ): 25 | model = ( 26 | AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote) 27 | .to(dtype) 28 | .to(device) 29 | ) 30 | self.hidden_size = model.config.hidden_size 31 | self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool) 32 | 33 | position_offset = 0 34 | model_type = model.config.model_type 35 | if model_type in ["xlm-roberta", "camembert", "roberta"]: 36 | position_offset = model.config.pad_token_id + 1 37 | if hasattr(model.config, "max_seq_length"): 38 | self.max_input_length = model.config.max_seq_length 39 | else: 40 | self.max_input_length = ( 41 | model.config.max_position_embeddings - position_offset 42 | ) 43 | 44 | self.has_position_ids = ( 45 | inspect.signature(model.forward).parameters.get("position_ids", None) 46 | is not None 47 | ) 48 | self.has_token_type_ids = ( 49 | inspect.signature(model.forward).parameters.get("token_type_ids", None) 50 | is not None 51 | ) 52 | 53 | super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device) 54 | 55 | @property 56 | def batch_type(self) -> Type[PaddedBatch]: 57 | return PaddedBatch 58 | 59 | @tracer.start_as_current_span("embed") 60 | def embed(self, batch: PaddedBatch) -> List[Embedding]: 61 | kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} 62 | if self.has_token_type_ids: 63 | kwargs["token_type_ids"] = batch.token_type_ids 64 | if self.has_position_ids: 65 | kwargs["position_ids"] = batch.position_ids 66 | output = self.model(**kwargs) 67 | 68 | embedding = self.pooling.forward(output, batch.attention_mask) 69 | 70 | cpu_results = embedding.view(-1).tolist() 71 | 72 | return [ 73 | Embedding( 74 | values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] 75 | ) 76 | for i in range(len(batch)) 77 | ] 78 | 79 | @tracer.start_as_current_span("predict") 80 | def predict(self, batch: PaddedBatch) -> List[Score]: 81 | pass 82 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/masked_model.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | 4 | from pathlib import Path 5 | from typing import Type, List 6 | from transformers import AutoModelForMaskedLM 7 | from opentelemetry import trace 8 | 9 | from text_embeddings_server.models import Model 10 | from text_embeddings_server.models.types import PaddedBatch, Embedding, Score 11 | from text_embeddings_server.models.pooling import SpladePooling 12 | 13 | tracer = trace.get_tracer(__name__) 14 | 15 | 16 | class MaskedLanguageModel(Model): 17 | def __init__( 18 | self, 19 | model_path: Path, 20 | device: torch.device, 21 | dtype: torch.dtype, 22 | trust_remote: bool = False, 23 | ): 24 | model = ( 25 | AutoModelForMaskedLM.from_pretrained( 26 | model_path, trust_remote_code=trust_remote 27 | ) 28 | .to(dtype) 29 | .to(device) 30 | ) 31 | self.pooling = SpladePooling() 32 | position_offset = 0 33 | model_type = model.config.model_type 34 | if model_type in 
["xlm-roberta", "camembert", "roberta"]: 35 | position_offset = model.config.pad_token_id + 1 36 | if hasattr(model.config, "max_seq_length"): 37 | self.max_input_length = model.config.max_seq_length 38 | else: 39 | self.max_input_length = ( 40 | model.config.max_position_embeddings - position_offset 41 | ) 42 | self.has_position_ids = ( 43 | inspect.signature(model.forward).parameters.get("position_ids", None) 44 | is not None 45 | ) 46 | self.has_token_type_ids = ( 47 | inspect.signature(model.forward).parameters.get("token_type_ids", None) 48 | is not None 49 | ) 50 | 51 | super(MaskedLanguageModel, self).__init__( 52 | model=model, dtype=dtype, device=device 53 | ) 54 | 55 | @property 56 | def batch_type(self) -> Type[PaddedBatch]: 57 | return PaddedBatch 58 | 59 | @tracer.start_as_current_span("embed") 60 | def embed(self, batch: PaddedBatch) -> List[Embedding]: 61 | kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask} 62 | if self.has_token_type_ids: 63 | kwargs["token_type_ids"] = batch.token_type_ids 64 | if self.has_position_ids: 65 | kwargs["position_ids"] = batch.position_ids 66 | output = self.model(**kwargs) 67 | embedding = self.pooling.forward(output, batch.attention_mask) 68 | cpu_results = embedding.view(-1).tolist() 69 | 70 | step_size = embedding.shape[-1] 71 | return [ 72 | Embedding(values=cpu_results[i * step_size : (i + 1) * step_size]) 73 | for i in range(len(batch)) 74 | ] 75 | 76 | @tracer.start_as_current_span("predict") 77 | def predict(self, batch: PaddedBatch) -> List[Score]: 78 | pass 79 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List, TypeVar, Type 5 | 6 | from text_embeddings_server.models.types import Batch, Embedding 7 | 8 | B = TypeVar("B", bound=Batch) 9 | 10 | 11 | class Model(ABC): 12 | def __init__( 13 | self, 14 | model, 15 | dtype: torch.dtype, 16 | device: torch.device, 17 | ): 18 | self.model = model 19 | self.dtype = dtype 20 | self.device = device 21 | 22 | @property 23 | @abstractmethod 24 | def batch_type(self) -> Type[B]: 25 | raise NotImplementedError 26 | 27 | @abstractmethod 28 | def embed(self, batch: B) -> List[Embedding]: 29 | raise NotImplementedError 30 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/pooling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from opentelemetry import trace 5 | from sentence_transformers.models import Pooling 6 | from torch import Tensor 7 | 8 | tracer = trace.get_tracer(__name__) 9 | 10 | 11 | class _Pooling(ABC): 12 | @abstractmethod 13 | def forward(self, model_output, attention_mask) -> Tensor: 14 | pass 15 | 16 | 17 | class DefaultPooling(_Pooling): 18 | def __init__(self, hidden_size, pooling_mode) -> None: 19 | assert ( 20 | pooling_mode != "splade" 21 | ), "Splade pooling is not supported for DefaultPooling" 22 | self.pooling = Pooling(hidden_size, pooling_mode=pooling_mode) 23 | 24 | @tracer.start_as_current_span("pooling") 25 | def forward(self, model_output, attention_mask) -> Tensor: 26 | pooling_features = { 27 | "token_embeddings": model_output[0], 28 | "attention_mask": attention_mask, 29 | } 30 | return 
self.pooling.forward(pooling_features)["sentence_embedding"] 31 | 32 | 33 | class SpladePooling(_Pooling): 34 | @tracer.start_as_current_span("pooling") 35 | def forward(self, model_output, attention_mask) -> Tensor: 36 | # Implement Splade pooling 37 | hidden_states = torch.relu(model_output[0]) 38 | hidden_states = (1 + hidden_states).log() 39 | hidden_states = torch.mul(hidden_states, attention_mask.unsqueeze(-1)) 40 | return hidden_states.max(dim=1).values 41 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/models/types.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | from opentelemetry import trace 8 | 9 | from text_embeddings_server.pb import embed_pb2 10 | from text_embeddings_server.pb.embed_pb2 import Embedding, Score 11 | 12 | tracer = trace.get_tracer(__name__) 13 | PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) 14 | 15 | 16 | def round_up(number, k): 17 | return (number + k - 1) // k * k 18 | 19 | 20 | class Batch(ABC): 21 | @classmethod 22 | @abstractmethod 23 | def from_pb(cls, pb: embed_pb2.EmbedRequest, device: torch.device) -> "Batch": 24 | raise NotImplementedError 25 | 26 | @abstractmethod 27 | def __len__(self): 28 | raise NotImplementedError 29 | 30 | 31 | @dataclass 32 | class PaddedBatch(Batch): 33 | input_ids: torch.Tensor 34 | token_type_ids: torch.Tensor 35 | position_ids: torch.Tensor 36 | attention_mask: torch.Tensor 37 | 38 | @classmethod 39 | @tracer.start_as_current_span("from_pb") 40 | def from_pb( 41 | cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int 42 | ) -> "PaddedBatch": 43 | if pb.max_length > max_input_length: 44 | raise RuntimeError(f"input length exceeds model config's max_input_length") 45 | 46 | batch_size = len(pb.cu_seq_lengths) - 1 47 | if device.type == "hpu": 48 | # To better utilize HPU, we need to do batch/seq_len bucketing 49 | max_length = round_up(pb.max_length, PAD_SEQUENCE_TO_MULTIPLE_OF) 50 | max_length = min(max_length, max_input_length) 51 | new_bs = 2 ** math.ceil(math.log2(batch_size)) 52 | else: 53 | new_bs = batch_size 54 | max_length = pb.max_length 55 | # Allocate padded tensors all at once 56 | all_tensors = torch.zeros([4, new_bs, max_length], dtype=torch.int32) 57 | 58 | for i, start_index in enumerate(pb.cu_seq_lengths[:-1]): 59 | end_index = pb.cu_seq_lengths[i + 1] 60 | input_length = end_index - start_index 61 | 62 | all_tensors[0, i, :input_length] = torch.tensor( 63 | pb.input_ids[start_index:end_index], dtype=torch.int32 64 | ) 65 | all_tensors[1, i, :input_length] = torch.tensor( 66 | pb.token_type_ids[start_index:end_index], dtype=torch.int32 67 | ) 68 | all_tensors[2, i, :input_length] = torch.tensor( 69 | pb.position_ids[start_index:end_index], dtype=torch.int32 70 | ) 71 | all_tensors[3, i, :input_length] = 1 72 | 73 | # Move padded tensors all at once 74 | all_tensors = all_tensors.to(device) 75 | 76 | return PaddedBatch( 77 | input_ids=all_tensors[0], 78 | token_type_ids=all_tensors[1], 79 | position_ids=all_tensors[2], 80 | attention_mask=all_tensors[3], 81 | ) 82 | 83 | def __len__(self): 84 | return len(self.input_ids) 85 | 86 | 87 | @dataclass 88 | class FlashBatch(Batch): 89 | input_ids: torch.Tensor 90 | token_type_ids: torch.Tensor 91 | position_ids: torch.Tensor 92 | 93 | cu_seqlens: 
torch.Tensor 94 | max_s: int 95 | size: int 96 | 97 | @classmethod 98 | @tracer.start_as_current_span("from_pb") 99 | def from_pb( 100 | cls, pb: embed_pb2.EmbedRequest, device: torch.device, max_input_length: int 101 | ) -> "FlashBatch": 102 | batch_input_ids = torch.tensor(pb.input_ids, dtype=torch.int32, device=device) 103 | batch_token_type_ids = torch.tensor( 104 | pb.token_type_ids, dtype=torch.int32, device=device 105 | ) 106 | batch_position_ids = torch.tensor( 107 | pb.position_ids, dtype=torch.int32, device=device 108 | ) 109 | 110 | cu_seqlens = torch.tensor(pb.cu_seq_lengths, dtype=torch.int32, device=device) 111 | 112 | return FlashBatch( 113 | input_ids=batch_input_ids, 114 | token_type_ids=batch_token_type_ids, 115 | position_ids=batch_position_ids, 116 | cu_seqlens=cu_seqlens, 117 | max_s=pb.max_length, 118 | size=len(cu_seqlens) - 1, 119 | ) 120 | 121 | def __len__(self): 122 | return self.size 123 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/pb/.gitignore: -------------------------------------------------------------------------------- 1 | *.py 2 | *.pyi 3 | *.py-e 4 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/server.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import torch 3 | from grpc import aio 4 | from loguru import logger 5 | 6 | from grpc_reflection.v1alpha import reflection 7 | from pathlib import Path 8 | from typing import Optional 9 | 10 | from text_embeddings_server.models import Model, get_model 11 | from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2 12 | from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor 13 | from text_embeddings_server.utils.interceptor import ExceptionInterceptor 14 | 15 | 16 | class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer): 17 | def __init__(self, model: Model): 18 | self.model = model 19 | # Force inference mode for the lifetime of EmbeddingService 20 | self._inference_mode_raii_guard = torch._C._InferenceMode(True) 21 | 22 | async def Health(self, request, context): 23 | if self.model.device.type == "cuda": 24 | torch.zeros((2, 2), device="cuda") 25 | return embed_pb2.HealthResponse() 26 | 27 | async def Embed(self, request, context): 28 | max_input_length = self.model.max_input_length 29 | batch = self.model.batch_type.from_pb( 30 | request, self.model.device, max_input_length 31 | ) 32 | 33 | embeddings = self.model.embed(batch) 34 | 35 | return embed_pb2.EmbedResponse(embeddings=embeddings) 36 | 37 | async def Predict(self, request, context): 38 | max_input_length = self.model.max_input_length 39 | batch = self.model.batch_type.from_pb( 40 | request, self.model.device, max_input_length 41 | ) 42 | 43 | scores = self.model.predict(batch) 44 | 45 | return embed_pb2.PredictResponse(scores=scores) 46 | 47 | 48 | def serve( 49 | model_path: Path, 50 | dtype: Optional[str], 51 | uds_path: Path, 52 | pool: str, 53 | ): 54 | async def serve_inner( 55 | model_path: Path, 56 | dtype: Optional[str] = None, 57 | ): 58 | unix_socket = f"unix://{uds_path}" 59 | 60 | try: 61 | model = get_model(model_path, dtype, pool) 62 | except Exception: 63 | logger.exception("Error when initializing model") 64 | raise 65 | 66 | server = aio.server( 67 | interceptors=[ 68 | ExceptionInterceptor(), 69 | UDSOpenTelemetryAioServerInterceptor(), 70 | ] 71 | ) 72 | 
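`PaddedBatch.from_pb` above buckets work on HPU devices: the padded sequence length is rounded up to a multiple of `PAD_SEQUENCE_TO_MULTIPLE_OF` and the batch size up to the next power of two, which limits the number of distinct tensor shapes the device has to handle. A minimal, standalone sketch of that bucketing arithmetic (the helper name and example values are assumptions for illustration only):

```python
import math

PAD_SEQUENCE_TO_MULTIPLE_OF = 128  # default used in types.py above


def round_up(number: int, k: int) -> int:
    return (number + k - 1) // k * k


def hpu_bucket(batch_size: int, max_length: int, max_input_length: int) -> tuple:
    padded_length = min(round_up(max_length, PAD_SEQUENCE_TO_MULTIPLE_OF), max_input_length)
    padded_batch_size = 2 ** math.ceil(math.log2(batch_size))
    return padded_batch_size, padded_length


print(hpu_bucket(batch_size=3, max_length=70, max_input_length=512))   # (4, 128)
print(hpu_bucket(batch_size=5, max_length=300, max_input_length=512))  # (8, 384)
```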
embed_pb2_grpc.add_EmbeddingServiceServicer_to_server( 73 | EmbeddingService(model), server 74 | ) 75 | SERVICE_NAMES = ( 76 | embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name, 77 | reflection.SERVICE_NAME, 78 | ) 79 | reflection.enable_server_reflection(SERVICE_NAMES, server) 80 | server.add_insecure_port(unix_socket) 81 | 82 | await server.start() 83 | 84 | logger.info(f"Server started at {unix_socket}") 85 | 86 | try: 87 | await server.wait_for_termination() 88 | except KeyboardInterrupt: 89 | logger.info("Signal received. Shutting down") 90 | await server.stop(0) 91 | 92 | asyncio.run(serve_inner(model_path, dtype)) 93 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/text-embeddings-inference/d51a8b99d01d00d5ee7a52ea8dec53ed1c1b7d7c/backends/python/server/text_embeddings_server/utils/__init__.py -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/utils/device.py: -------------------------------------------------------------------------------- 1 | import os 2 | from loguru import logger 3 | import importlib.metadata 4 | import importlib.util 5 | from packaging import version 6 | import torch 7 | import subprocess 8 | 9 | ALLOW_REDUCED_PRECISION = os.getenv( 10 | "ALLOW_REDUCED_PRECISION_FP16_BF16", "true" 11 | ).lower() in [ 12 | "true", 13 | "1", 14 | ] 15 | 16 | 17 | def _is_ipex_available(): 18 | def get_major_and_minor_from_version(full_version): 19 | return ( 20 | str(version.parse(full_version).major) 21 | + "." 22 | + str(version.parse(full_version).minor) 23 | ) 24 | 25 | _torch_version = importlib.metadata.version("torch") 26 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 27 | return False 28 | _ipex_version = "N/A" 29 | try: 30 | _ipex_version = importlib.metadata.version("intel_extension_for_pytorch") 31 | except importlib.metadata.PackageNotFoundError: 32 | return False 33 | torch_major_and_minor = get_major_and_minor_from_version(_torch_version) 34 | ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) 35 | if torch_major_and_minor != ipex_major_and_minor: 36 | logger.warning( 37 | f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," 38 | f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
39 | ) 40 | return False 41 | return True 42 | 43 | 44 | def is_hpu() -> bool: 45 | is_hpu_available = True 46 | try: 47 | subprocess.run(["hl-smi"], capture_output=True, check=True) 48 | except: 49 | is_hpu_available = False 50 | return is_hpu_available 51 | 52 | 53 | def use_ipex() -> bool: 54 | value = os.environ.get("USE_IPEX", "True").lower() 55 | return value in ["true", "1"] and _is_ipex_available() 56 | 57 | 58 | def get_device(): 59 | device = torch.device("cpu") 60 | if torch.cuda.is_available(): 61 | device = torch.device("cuda") 62 | elif is_hpu(): 63 | import habana_frameworks.torch.core as htcore 64 | 65 | # WA for perf degradation from pytorch 2.5 66 | if ALLOW_REDUCED_PRECISION: 67 | torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) 68 | if hasattr(torch, "hpu") and torch.hpu.is_available(): # type: ignore 69 | device = torch.device("hpu") 70 | elif use_ipex(): 71 | import intel_extension_for_pytorch as ipex 72 | 73 | if hasattr(torch, "xpu") and torch.xpu.is_available(): 74 | device = torch.device("xpu") 75 | 76 | return device 77 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/utils/interceptor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import grpc 3 | 4 | from google.rpc import status_pb2, code_pb2 5 | from grpc_status import rpc_status 6 | from grpc_interceptor.server import AsyncServerInterceptor 7 | from loguru import logger 8 | from typing import Callable, Any 9 | 10 | 11 | class ExceptionInterceptor(AsyncServerInterceptor): 12 | async def intercept( 13 | self, 14 | method: Callable, 15 | request_or_iterator: Any, 16 | context: grpc.ServicerContext, 17 | method_name: str, 18 | ) -> Any: 19 | try: 20 | response = method(request_or_iterator, context) 21 | return await response 22 | except Exception as err: 23 | method_name = method_name.split("/")[-1] 24 | logger.exception(f"Method {method_name} encountered an error.") 25 | 26 | if torch.cuda.is_available(): 27 | torch.cuda.empty_cache() 28 | 29 | await context.abort_with_status( 30 | rpc_status.to_status( 31 | status_pb2.Status(code=code_pb2.INTERNAL, message=str(err)) 32 | ) 33 | ) 34 | -------------------------------------------------------------------------------- /backends/python/server/text_embeddings_server/utils/tracing.py: -------------------------------------------------------------------------------- 1 | import grpc 2 | 3 | from opentelemetry import trace 4 | from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter 5 | from opentelemetry.instrumentation.grpc._aio_server import ( 6 | OpenTelemetryAioServerInterceptor, 7 | ) 8 | from opentelemetry.semconv.trace import SpanAttributes 9 | from opentelemetry.sdk.resources import Resource 10 | from opentelemetry.sdk.trace import TracerProvider 11 | from opentelemetry.sdk.trace.export import ( 12 | BatchSpanProcessor, 13 | ) 14 | 15 | 16 | class UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor): 17 | def __init__(self): 18 | super().__init__(trace.get_tracer(__name__)) 19 | 20 | def _start_span(self, handler_call_details, context, set_status_on_exception=False): 21 | """ 22 | Rewrite _start_span method to support Unix Domain Socket gRPC contexts 23 | """ 24 | 25 | # standard attributes 26 | attributes = { 27 | SpanAttributes.RPC_SYSTEM: "grpc", 28 | SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], 29 | } 30 | 31 | # if we have details about the call, 
split into service and method 32 | if handler_call_details.method: 33 | service, method = handler_call_details.method.lstrip("/").split("/", 1) 34 | attributes.update( 35 | { 36 | SpanAttributes.RPC_METHOD: method, 37 | SpanAttributes.RPC_SERVICE: service, 38 | } 39 | ) 40 | 41 | # add some attributes from the metadata 42 | metadata = dict(context.invocation_metadata()) 43 | if "user-agent" in metadata: 44 | attributes["rpc.user_agent"] = metadata["user-agent"] 45 | 46 | # We use gRPC over a UNIX socket 47 | attributes.update({SpanAttributes.NET_TRANSPORT: "unix"}) 48 | 49 | return self._tracer.start_as_current_span( 50 | name=handler_call_details.method, 51 | kind=trace.SpanKind.SERVER, 52 | attributes=attributes, 53 | set_status_on_exception=set_status_on_exception, 54 | ) 55 | 56 | 57 | def setup_tracing(otlp_endpoint: str, otlp_service_name: str): 58 | resource = Resource.create(attributes={"service.name": otlp_service_name}) 59 | span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True) 60 | span_processor = BatchSpanProcessor(span_exporter) 61 | 62 | trace.set_tracer_provider(TracerProvider(resource=resource)) 63 | trace.get_tracer_provider().add_span_processor(span_processor) 64 | -------------------------------------------------------------------------------- /backends/python/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod logging; 2 | mod management; 3 | 4 | use backend_grpc_client::Client; 5 | use nohash_hasher::BuildNoHashHasher; 6 | use std::collections::HashMap; 7 | use text_embeddings_backend_core::{ 8 | Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, 9 | }; 10 | use tokio::runtime::Runtime; 11 | 12 | pub struct PythonBackend { 13 | _backend_process: management::BackendProcess, 14 | tokio_runtime: Runtime, 15 | backend_client: Client, 16 | } 17 | 18 | impl PythonBackend { 19 | pub fn new( 20 | model_path: String, 21 | dtype: String, 22 | model_type: ModelType, 23 | uds_path: String, 24 | otlp_endpoint: Option, 25 | otlp_service_name: String, 26 | ) -> Result { 27 | let pool = match model_type { 28 | ModelType::Classifier => Pool::Cls, 29 | ModelType::Embedding(pool) => pool, 30 | }; 31 | 32 | let backend_process = management::BackendProcess::new( 33 | model_path, 34 | dtype, 35 | &uds_path, 36 | otlp_endpoint, 37 | otlp_service_name, 38 | pool, 39 | )?; 40 | let tokio_runtime = tokio::runtime::Builder::new_current_thread() 41 | .enable_all() 42 | .build() 43 | .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; 44 | 45 | let backend_client = tokio_runtime 46 | .block_on(Client::connect_uds(uds_path)) 47 | .map_err(|err| { 48 | BackendError::Start(format!("Could not connect to backend process: {err}")) 49 | })?; 50 | 51 | Ok(Self { 52 | _backend_process: backend_process, 53 | tokio_runtime, 54 | backend_client, 55 | }) 56 | } 57 | } 58 | 59 | impl Backend for PythonBackend { 60 | fn health(&self) -> Result<(), BackendError> { 61 | if self 62 | .tokio_runtime 63 | .block_on(self.backend_client.clone().health()) 64 | .is_err() 65 | { 66 | return Err(BackendError::Unhealthy); 67 | } 68 | Ok(()) 69 | } 70 | 71 | fn is_padded(&self) -> bool { 72 | false 73 | } 74 | 75 | fn embed(&self, batch: Batch) -> Result { 76 | if !batch.raw_indices.is_empty() { 77 | return Err(BackendError::Inference( 78 | "raw embeddings are not supported for the Python backend.".to_string(), 79 | )); 80 | } 81 | let batch_size = batch.len(); 82 | 83 | let results = self 84 
| .tokio_runtime 85 | .block_on(self.backend_client.clone().embed( 86 | batch.input_ids, 87 | batch.token_type_ids, 88 | batch.position_ids, 89 | batch.cumulative_seq_lengths, 90 | batch.max_length, 91 | )) 92 | .map_err(|err| BackendError::Inference(err.to_string()))?; 93 | let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); 94 | 95 | let mut embeddings = 96 | HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); 97 | for (i, e) in pooled_embeddings.into_iter().enumerate() { 98 | embeddings.insert(i, Embedding::Pooled(e)); 99 | } 100 | 101 | Ok(embeddings) 102 | } 103 | 104 | fn predict(&self, batch: Batch) -> Result { 105 | if !batch.raw_indices.is_empty() { 106 | return Err(BackendError::Inference( 107 | "raw embeddings are not supported for the Python backend.".to_string(), 108 | )); 109 | } 110 | let batch_size = batch.len(); 111 | let results = self 112 | .tokio_runtime 113 | .block_on(self.backend_client.clone().predict( 114 | batch.input_ids, 115 | batch.token_type_ids, 116 | batch.position_ids, 117 | batch.cumulative_seq_lengths, 118 | batch.max_length, 119 | )) 120 | .map_err(|err| BackendError::Inference(err.to_string()))?; 121 | let raw_results: Vec> = results.into_iter().map(|r| r.values).collect(); 122 | 123 | let mut predictions = 124 | HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); 125 | 126 | for (i, r) in raw_results.into_iter().enumerate() { 127 | predictions.insert(i, r); 128 | } 129 | 130 | Ok(predictions) 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /backends/python/src/logging.rs: -------------------------------------------------------------------------------- 1 | use serde::Deserialize; 2 | use std::io::{BufRead, Lines}; 3 | 4 | #[derive(Deserialize)] 5 | #[serde(rename_all = "UPPERCASE")] 6 | enum PythonLogLevelEnum { 7 | Trace, 8 | Debug, 9 | Info, 10 | Success, 11 | Warning, 12 | Error, 13 | Critical, 14 | } 15 | 16 | #[derive(Deserialize)] 17 | struct PythonLogLevel { 18 | name: PythonLogLevelEnum, 19 | } 20 | 21 | #[derive(Deserialize)] 22 | struct PythonLogRecord { 23 | level: PythonLogLevel, 24 | } 25 | 26 | #[derive(Deserialize)] 27 | struct PythonLogMessage { 28 | text: String, 29 | record: PythonLogRecord, 30 | } 31 | 32 | impl PythonLogMessage { 33 | fn trace(&self) { 34 | match self.record.level.name { 35 | PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text), 36 | PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text), 37 | PythonLogLevelEnum::Info => tracing::info!("{}", self.text), 38 | PythonLogLevelEnum::Success => tracing::info!("{}", self.text), 39 | PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text), 40 | PythonLogLevelEnum::Error => tracing::error!("{}", self.text), 41 | PythonLogLevelEnum::Critical => tracing::error!("{}", self.text), 42 | } 43 | } 44 | } 45 | 46 | impl TryFrom<&String> for PythonLogMessage { 47 | type Error = serde_json::Error; 48 | 49 | fn try_from(value: &String) -> Result { 50 | serde_json::from_str::(value) 51 | } 52 | } 53 | 54 | pub(crate) fn log_lines(lines: Lines) { 55 | for line in lines.map_while(Result::ok) { 56 | match PythonLogMessage::try_from(&line) { 57 | Ok(log) => log.trace(), 58 | Err(_) => tracing::debug!("{line}"), 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /backends/src/dtype.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 
| 3 | #[cfg(feature = "clap")] 4 | use clap::ValueEnum; 5 | 6 | #[derive(Debug, PartialEq)] 7 | #[cfg_attr(feature = "clap", derive(Clone, ValueEnum))] 8 | pub enum DType { 9 | // Float16 is not available on accelerate 10 | #[cfg(any( 11 | feature = "python", 12 | all(feature = "candle", not(feature = "accelerate")) 13 | ))] 14 | Float16, 15 | #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] 16 | Float32, 17 | #[cfg(feature = "python")] 18 | Bfloat16, 19 | } 20 | 21 | impl fmt::Display for DType { 22 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 23 | match self { 24 | // Float16 is not available on accelerate 25 | #[cfg(any( 26 | feature = "python", 27 | all(feature = "candle", not(feature = "accelerate")) 28 | ))] 29 | DType::Float16 => write!(f, "float16"), 30 | #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] 31 | DType::Float32 => write!(f, "float32"), 32 | #[cfg(feature = "python")] 33 | DType::Bfloat16 => write!(f, "bfloat16"), 34 | } 35 | } 36 | } 37 | 38 | #[allow(clippy::derivable_impls)] 39 | impl Default for DType { 40 | fn default() -> Self { 41 | #[cfg(any(feature = "accelerate", feature = "mkl", feature = "ort"))] 42 | { 43 | DType::Float32 44 | } 45 | #[cfg(not(any( 46 | feature = "accelerate", 47 | feature = "mkl", 48 | feature = "ort", 49 | feature = "python" 50 | )))] 51 | { 52 | DType::Float16 53 | } 54 | #[cfg(feature = "python")] 55 | { 56 | DType::Bfloat16 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-core" 3 | version.workspace = true 4 | edition.workspace = true 5 | authors.workspace = true 6 | homepage.workspace = true 7 | 8 | [dependencies] 9 | async-channel = "^2.3" 10 | hf-hub = { workspace = true } 11 | metrics = { workspace = true } 12 | serde_json = { workspace = true } 13 | text-embeddings-backend = { path = "../backends" } 14 | thiserror = { workspace = true } 15 | tokenizers = { workspace = true } 16 | tracing = { workspace = true } 17 | tokio = { workspace = true } 18 | -------------------------------------------------------------------------------- /core/src/download.rs: -------------------------------------------------------------------------------- 1 | use hf_hub::api::tokio::{ApiError, ApiRepo}; 2 | use std::path::PathBuf; 3 | use tracing::instrument; 4 | 5 | // Old classes used other config names than 'sentence_bert_config.json' 6 | pub const ST_CONFIG_NAMES: [&str; 7] = [ 7 | "sentence_bert_config.json", 8 | "sentence_roberta_config.json", 9 | "sentence_distilbert_config.json", 10 | "sentence_camembert_config.json", 11 | "sentence_albert_config.json", 12 | "sentence_xlm-roberta_config.json", 13 | "sentence_xlnet_config.json", 14 | ]; 15 | 16 | #[instrument(skip_all)] 17 | pub async fn download_artifacts(api: &ApiRepo, pool_config: bool) -> Result { 18 | let start = std::time::Instant::now(); 19 | 20 | tracing::info!("Starting download"); 21 | 22 | // Optionally download the pooling config. 
23 | if pool_config { 24 | // If a pooling config exist, download it 25 | let _ = download_pool_config(api).await.map_err(|err| { 26 | tracing::warn!("Download failed: {err}"); 27 | err 28 | }); 29 | } 30 | 31 | // Download legacy sentence transformers config 32 | // We don't warn on failure as it is a legacy file 33 | let _ = download_st_config(api).await; 34 | // Download new sentence transformers config 35 | let _ = download_new_st_config(api).await.map_err(|err| { 36 | tracing::warn!("Download failed: {err}"); 37 | err 38 | }); 39 | 40 | tracing::info!("Downloading `config.json`"); 41 | api.get("config.json").await?; 42 | 43 | tracing::info!("Downloading `tokenizer.json`"); 44 | let tokenizer_path = api.get("tokenizer.json").await?; 45 | 46 | let model_root = tokenizer_path.parent().unwrap().to_path_buf(); 47 | tracing::info!("Model artifacts downloaded in {:?}", start.elapsed()); 48 | Ok(model_root) 49 | } 50 | 51 | #[instrument(skip_all)] 52 | pub async fn download_pool_config(api: &ApiRepo) -> Result { 53 | tracing::info!("Downloading `1_Pooling/config.json`"); 54 | let pool_config_path = api.get("1_Pooling/config.json").await?; 55 | Ok(pool_config_path) 56 | } 57 | 58 | #[instrument(skip_all)] 59 | pub async fn download_st_config(api: &ApiRepo) -> Result { 60 | // Try default path 61 | let err = match api.get(ST_CONFIG_NAMES[0]).await { 62 | Ok(st_config_path) => return Ok(st_config_path), 63 | Err(err) => err, 64 | }; 65 | 66 | for name in &ST_CONFIG_NAMES[1..] { 67 | if let Ok(st_config_path) = api.get(name).await { 68 | return Ok(st_config_path); 69 | } 70 | } 71 | 72 | Err(err) 73 | } 74 | 75 | #[instrument(skip_all)] 76 | pub async fn download_new_st_config(api: &ApiRepo) -> Result { 77 | tracing::info!("Downloading `config_sentence_transformers.json`"); 78 | let pool_config_path = api.get("config_sentence_transformers.json").await?; 79 | Ok(pool_config_path) 80 | } 81 | -------------------------------------------------------------------------------- /core/src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod download; 2 | pub mod infer; 3 | pub mod queue; 4 | pub mod tokenization; 5 | 6 | use text_embeddings_backend::BackendError; 7 | use thiserror::Error; 8 | use tokio::sync::TryAcquireError; 9 | 10 | #[derive(Error, Debug)] 11 | pub enum TextEmbeddingsError { 12 | #[error("tokenizer error {0}")] 13 | Tokenizer(#[from] tokenizers::Error), 14 | #[error("Input validation error: {0}")] 15 | Validation(String), 16 | #[error("Model is overloaded")] 17 | Overloaded(#[from] TryAcquireError), 18 | #[error("Backend error: {0}")] 19 | Backend(#[from] BackendError), 20 | } 21 | -------------------------------------------------------------------------------- /cuda-all-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if ! command -v nvidia-smi &> /dev/null; then 4 | echo "Error: 'nvidia-smi' command not found." 
5 | exit 1 6 | fi 7 | 8 | compute_cap=$(nvidia-smi --query-gpu=compute_cap --format=csv | sed -n '2p' | sed 's/\.//g') 9 | 10 | if [ ${compute_cap} -eq 75 ] 11 | then 12 | exec text-embeddings-router-75 "$@" 13 | elif [ ${compute_cap} -ge 80 -a ${compute_cap} -lt 90 ] 14 | then 15 | exec text-embeddings-router-80 "$@" 16 | elif [ ${compute_cap} -eq 90 ] 17 | then 18 | exec text-embeddings-router-90 "$@" 19 | else 20 | echo "cuda compute cap ${compute_cap} is not supported"; exit 1 21 | fi 22 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Text Embeddings Inference API 7 | 8 | 9 |
10 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /docs/source/en/_toctree.yml: -------------------------------------------------------------------------------- 1 | - sections: 2 | - local: index 3 | title: Text Embeddings Inference 4 | - local: quick_tour 5 | title: Quick Tour 6 | - local: supported_models 7 | title: Supported models and hardware 8 | title: Getting started 9 | - sections: 10 | - local: local_cpu 11 | title: Using TEI locally with CPU 12 | - local: local_metal 13 | title: Using TEI locally with Metal 14 | - local: local_gpu 15 | title: Using TEI locally with GPU 16 | - local: private_models 17 | title: Serving private and gated models 18 | - local: custom_container 19 | title: Build custom container for TEI 20 | - local: intel_container 21 | title: Using TEI container with Intel Hardware 22 | - local: examples 23 | title: Example uses 24 | title: Tutorials 25 | - sections: 26 | - local: tei_cloud_run 27 | title: Cloud Run 28 | title: Deploying TEI on Google Cloud 29 | - sections: 30 | - local: cli_arguments 31 | title: CLI arguments 32 | title: Reference 33 | -------------------------------------------------------------------------------- /docs/source/en/custom_container.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Build a custom container for TEI 18 | 19 | You can build your own CPU or CUDA TEI container using Docker. To build a CPU container, run the following command in the 20 | directory containing your custom Dockerfile: 21 | 22 | ```shell 23 | docker build . 24 | ``` 25 | 26 | To build a CUDA container, it is essential to determine the compute capability (compute cap) of the GPU that will be 27 | used at runtime. This information is crucial for the proper configuration of the CUDA containers. The following are 28 | examples of runtime compute capabilities for various GPU types: 29 | 30 | - Turing (T4, RTX 2000 series, ...) - `runtime_compute_cap=75` 31 | - A100 - `runtime_compute_cap=80` 32 | - A10 - `runtime_compute_cap=86` 33 | - Ada Lovelace (RTX 4000 series, ...) - `runtime_compute_cap=89` 34 | - H100 - `runtime_compute_cap=90` 35 | 36 | Once you have determined the compute capability, set it as the `runtime_compute_cap` variable and build 37 | the container as shown in the example below: 38 | 39 | ```shell 40 | # Get submodule dependencies 41 | git submodule update --init 42 | 43 | runtime_compute_cap=80 44 | 45 | docker build . -f Dockerfile-cuda --build-arg CUDA_COMPUTE_CAP=$runtime_compute_cap 46 | ``` 47 | -------------------------------------------------------------------------------- /docs/source/en/examples.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Example uses 18 | 19 | - [Set up an Inference Endpoint with TEI](https://huggingface.co/learn/cookbook/automatic_embedding_tei_inference_endpoints) 20 | - [RAG containers with TEI](https://github.com/plaggy/rag-containers) 21 | -------------------------------------------------------------------------------- /docs/source/en/index.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Text Embeddings Inference 18 | 19 | Text Embeddings Inference (TEI) is a comprehensive toolkit designed for efficient deployment and serving of open source 20 | text embeddings models. It enables high-performance extraction for the most popular models, including FlagEmbedding, Ember, GTE, and E5.
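To get a feel for the request shape, here is a minimal client sketch. It assumes an instance is already running locally on port 8080 and mirrors the payload used by `load_tests/load.js` later in this repository; adjust the host, port, and route to your deployment.

```python
# Minimal sketch: request an embedding from a running TEI instance.
# Host, port, and route are assumptions; the payload mirrors load_tests/load.js.
import json
import urllib.request

payload = json.dumps({"inputs": "What is Deep Learning?", "truncate": True}).encode("utf-8")
request = urllib.request.Request(
    "http://127.0.0.1:8080/",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.loads(response.read())

# For embedding models the response is expected to be one vector per input.
print(len(result[0]))
```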
21 | 22 | TEI offers multiple features tailored to optimize the deployment process and enhance overall performance. 23 | 24 | **Key Features:** 25 | 26 | * **Streamlined Deployment:** TEI eliminates the need for a model graph compilation step for an easier deployment process. 27 | * **Efficient Resource Utilization:** Benefit from small Docker images and rapid boot times, allowing for true serverless capabilities. 28 | * **Dynamic Batching:** TEI incorporates token-based dynamic batching thus optimizing resource utilization during inference. 29 | * **Optimized Inference:** TEI leverages [Flash Attention](https://github.com/HazyResearch/flash-attention), [Candle](https://github.com/huggingface/candle), and [cuBLASLt](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api) by using optimized transformers code for inference. 30 | * **Safetensors weight loading:** TEI loads [Safetensors](https://github.com/huggingface/safetensors) weights for faster boot times. 31 | * **Production-Ready:** TEI supports distributed tracing through Open Telemetry and exports Prometheus metrics. 32 | 33 | **Benchmarks** 34 | 35 | Benchmark for [BAAI/bge-base-en-v1.5](https://hf.co/BAAI/bge-large-en-v1.5) on an NVIDIA A10 with a sequence length of 512 tokens: 36 | 37 |

38 | [figure: Latency comparison for batch size of 1] 39 | [figure: Throughput comparison for batch size of 1] 40 | 41 | 42 | [figure: Latency comparison for batch size of 32] 43 | [figure: Throughput comparison for batch size of 32] 44 |

45 | 46 | **Getting Started:** 47 | 48 | To start using TEI, check the [Quick Tour](quick_tour) guide. 49 | -------------------------------------------------------------------------------- /docs/source/en/intel_container.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Using TEI Container with Intel® Hardware 18 | 19 | This guide explains how to build and deploy `text-embeddings-inference` containers optimized for Intel® hardware, including CPUs, XPUs, and HPUs. 20 | 21 | ## CPU 22 | 23 | ### Build Docker Image 24 | 25 | To build a container optimized for Intel® CPUs, run the following command: 26 | 27 | ```shell 28 | platform="cpu" 29 | 30 | docker build . -f Dockerfile-intel --build-arg PLATFORM=$platform -t tei_cpu_ipex 31 | ``` 32 | 33 | ### Deploy Docker Container 34 | 35 | To deploy your model on an Intel® CPU, use the following command: 36 | 37 | ```shell 38 | model='BAAI/bge-large-en-v1.5' 39 | volume=$PWD/data 40 | 41 | docker run -p 8080:80 -v $volume:/data tei_cpu_ipex --model-id $model 42 | ``` 43 | 44 | ## XPU 45 | 46 | ### Build Docker Image 47 | 48 | To build a container optimized for Intel® XPUs, run the following command: 49 | 50 | ```shell 51 | platform="xpu" 52 | 53 | docker build . -f Dockerfile-intel --build-arg PLATFORM=$platform -t tei_xpu_ipex 54 | ``` 55 | 56 | ### Deploy Docker Container 57 | 58 | To deploy your model on an Intel® XPU, use the following command: 59 | 60 | ```shell 61 | model='BAAI/bge-large-en-v1.5' 62 | volume=$PWD/data 63 | 64 | docker run -p 8080:80 -v $volume:/data --device=/dev/dri -v /dev/dri/by-path:/dev/dri/by-path tei_xpu_ipex --model-id $model --dtype float16 65 | ``` 66 | 67 | ## HPU 68 | 69 | ### Build Docker Image 70 | 71 | To build a container optimized for Intel® HPUs (Gaudi), run the following command: 72 | 73 | ```shell 74 | platform="hpu" 75 | 76 | docker build . -f Dockerfile-intel --build-arg PLATFORM=$platform -t tei_hpu 77 | ``` 78 | 79 | ### Deploy Docker Container 80 | 81 | To deploy your model on an Intel® HPU (Gaudi), use the following command: 82 | 83 | ```shell 84 | model='BAAI/bge-large-en-v1.5' 85 | volume=$PWD/data 86 | 87 | docker run -p 8080:80 -v $volume:/data --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e MAX_WARMUP_SEQUENCE_LENGTH=512 tei_hpu --model-id $model --dtype bfloat16 88 | ``` 89 | 90 | ## Prebuilt Docker Images 91 | 92 | For convenience, prebuilt Docker images are available on GitHub Container Registry (GHCR). You can pull these images directly without the need to build them manually: 93 | 94 | ### CPU 95 | To use the prebuilt image optimized for Intel® CPUs, run: 96 | ```shell 97 | docker pull ghcr.io/huggingface/text-embeddings-inference:cpu-ipex-latest 98 | ``` 99 | 100 | ### XPU 101 | To use the prebuilt image optimized for Intel® XPUs, run: 102 | ```shell 103 | docker pull ghcr.io/huggingface/text-embeddings-inference:xpu-ipex-latest 104 | ``` 105 | 106 | ### HPU 107 | To use the prebuilt image optimized for Intel® HPUs (Gaudi), run: 108 | ```shell 109 | docker pull ghcr.io/huggingface/text-embeddings-inference:hpu-latest 110 | ``` 111 | -------------------------------------------------------------------------------- /docs/source/en/local_cpu.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Using TEI locally with CPU 18 | 19 | You can install `text-embeddings-inference` locally to run it on your own machine. 
Here are the step-by-step instructions for installation: 20 | 21 | ## Step 1: Install Rust 22 | 23 | [Install Rust](https://rustup.rs/) on your machine by running the following in your terminal, then follow the instructions: 24 | 25 | ```shell 26 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 27 | ``` 28 | 29 | ## Step 2: Install necessary packages 30 | 31 | Depending on your machine's architecture, run one of the following commands: 32 | 33 | ### For x86 Machines 34 | 35 | ```shell 36 | cargo install --path router -F mkl 37 | ``` 38 | 39 | ### For M1 or M2 Machines 40 | 41 | ```shell 42 | cargo install --path router -F metal 43 | ``` 44 | 45 | ## Step 3: Launch Text Embeddings Inference 46 | 47 | Once the installation is successfully complete, you can launch Text Embeddings Inference on CPU with the following command: 48 | 49 | ```shell 50 | model=BAAI/bge-large-en-v1.5 51 | revision=refs/pr/5 52 | 53 | text-embeddings-router --model-id $model --revision $revision --port 8080 54 | ``` 55 | 56 | 57 | 58 | In some cases, you might also need the OpenSSL libraries and gcc installed. On Linux machines, run the following command: 59 | 60 | ```shell 61 | sudo apt-get install libssl-dev gcc -y 62 | ``` 63 | 64 | 65 | 66 | Now you are ready to use `text-embeddings-inference` locally on your machine. 67 | If you want to run TEI locally with a GPU, check out the [Using TEI locally with GPU](local_gpu) page. 68 | -------------------------------------------------------------------------------- /docs/source/en/local_gpu.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Using TEI locally with GPU 18 | 19 | You can install `text-embeddings-inference` locally to run it on your own machine with a GPU. 20 | To make sure that your hardware is supported, check out the [Supported models and hardware](supported_models) page. 21 | 22 | ## Step 1: CUDA and NVIDIA drivers 23 | 24 | Make sure you have CUDA and the NVIDIA drivers installed - NVIDIA drivers on your device need to be compatible with CUDA version 12.2 or higher. 25 | 26 | Add the NVIDIA binaries to your path: 27 | 28 | ```shell 29 | export PATH=$PATH:/usr/local/cuda/bin 30 | ``` 31 | 32 | ## Step 2: Install Rust 33 | 34 | [Install Rust](https://rustup.rs/) on your machine by running the following in your terminal, then follow the instructions: 35 | 36 | ```shell 37 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 38 | ``` 39 | 40 | ## Step 3: Install necessary packages 41 | 42 | This step can take a while as we need to compile a lot of CUDA kernels. 43 | 44 | ### For Turing GPUs (T4, RTX 2000 series ...
) 45 | 46 | ```shell 47 | cargo install --path router -F candle-cuda-turing -F http --no-default-features 48 | ``` 49 | 50 | ### For Ampere and Hopper 51 | 52 | ```shell 53 | cargo install --path router -F candle-cuda -F http --no-default-features 54 | ``` 55 | 56 | ## Step 4: Launch Text Embeddings Inference 57 | 58 | You can now launch Text Embeddings Inference on GPU with: 59 | 60 | ```shell 61 | model=BAAI/bge-large-en-v1.5 62 | revision=refs/pr/5 63 | 64 | text-embeddings-router --model-id $model --revision $revision --port 8080 65 | ``` 66 | -------------------------------------------------------------------------------- /docs/source/en/local_metal.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Using TEI locally with Metal 18 | 19 | You can install `text-embeddings-inference` locally to run it on your own Mac with Metal support. 20 | Here are the step-by-step instructions for installation: 21 | 22 | ## Step 1: Install Rust 23 | 24 | [Install Rust](https://rustup.rs/) on your machine by running the following in your terminal, then follow the instructions: 25 | 26 | ```shell 27 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh 28 | ``` 29 | 30 | ## Step 2: Install with Metal support 31 | 32 | ```shell 33 | cargo install --path router -F metal 34 | ``` 35 | 36 | ## Step 3: Launch Text Embeddings Inference 37 | 38 | Once the installation is successfully complete, you can launch Text Embeddings Inference with Metal with the following command: 39 | 40 | ```shell 41 | model=BAAI/bge-large-en-v1.5 42 | revision=refs/pr/5 43 | 44 | text-embeddings-router --model-id $model --revision $revision --port 8080 45 | ``` 46 | 47 | Now you are ready to use `text-embeddings-inference` locally on your machine. 48 | -------------------------------------------------------------------------------- /docs/source/en/private_models.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | # Serving private and gated models 18 | 19 | If the model you wish to serve is behind gated access or resides in a private model repository on Hugging Face Hub, 20 | you will need to have access to the model to serve it. 21 | 22 | Once you have confirmed that you have access to the model: 23 | 24 | - Navigate to your account's [Profile | Settings | Access Tokens page](https://huggingface.co/settings/tokens). 25 | - Generate and copy a read token. 26 | 27 | If you're using the CLI, set the `HF_TOKEN` environment variable.
For example: 28 | 29 | ```shell 30 | export HF_TOKEN= 31 | ``` 32 | 33 | Alternatively, you can provide the token when deploying the model with Docker: 34 | 35 | ```shell 36 | model= 37 | volume=$PWD/data 38 | token= 39 | 40 | docker run --gpus all -e HF_TOKEN=$token -p 8080:80 -v $volume:/data --pull always ghcr.io/huggingface/text-embeddings-inference:1.7 --model-id $model 41 | ``` 42 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "crane": { 4 | "locked": { 5 | "lastModified": 1742394900, 6 | "narHash": "sha256-vVOAp9ahvnU+fQoKd4SEXB2JG2wbENkpqcwlkIXgUC0=", 7 | "owner": "ipetkov", 8 | "repo": "crane", 9 | "rev": "70947c1908108c0c551ddfd73d4f750ff2ea67cd", 10 | "type": "github" 11 | }, 12 | "original": { 13 | "owner": "ipetkov", 14 | "repo": "crane", 15 | "type": "github" 16 | } 17 | }, 18 | "flake-utils": { 19 | "inputs": { 20 | "systems": "systems" 21 | }, 22 | "locked": { 23 | "lastModified": 1731533236, 24 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 25 | "owner": "numtide", 26 | "repo": "flake-utils", 27 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "numtide", 32 | "repo": "flake-utils", 33 | "type": "github" 34 | } 35 | }, 36 | "nixpkgs": { 37 | "locked": { 38 | "lastModified": 1743076231, 39 | "narHash": "sha256-yQugdVfi316qUfqzN8JMaA2vixl+45GxNm4oUfXlbgw=", 40 | "owner": "NixOS", 41 | "repo": "nixpkgs", 42 | "rev": "6c5963357f3c1c840201eda129a99d455074db04", 43 | "type": "github" 44 | }, 45 | "original": { 46 | "owner": "NixOS", 47 | "ref": "nixpkgs-unstable", 48 | "repo": "nixpkgs", 49 | "type": "github" 50 | } 51 | }, 52 | "root": { 53 | "inputs": { 54 | "crane": "crane", 55 | "flake-utils": "flake-utils", 56 | "nixpkgs": "nixpkgs", 57 | "rust-overlay": "rust-overlay" 58 | } 59 | }, 60 | "rust-overlay": { 61 | "inputs": { 62 | "nixpkgs": [ 63 | "nixpkgs" 64 | ] 65 | }, 66 | "locked": { 67 | "lastModified": 1743129211, 68 | "narHash": "sha256-gE8t+U9miTwm2NYWS9dFY8H1/QB4ifaFDq1KdV9KEqo=", 69 | "owner": "oxalica", 70 | "repo": "rust-overlay", 71 | "rev": "f93da1d26ba9963f34f94a6872b67a7939699543", 72 | "type": "github" 73 | }, 74 | "original": { 75 | "owner": "oxalica", 76 | "repo": "rust-overlay", 77 | "type": "github" 78 | } 79 | }, 80 | "systems": { 81 | "locked": { 82 | "lastModified": 1681028828, 83 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 84 | "owner": "nix-systems", 85 | "repo": "default", 86 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 87 | "type": "github" 88 | }, 89 | "original": { 90 | "owner": "nix-systems", 91 | "repo": "default", 92 | "type": "github" 93 | } 94 | } 95 | }, 96 | "root": "root", 97 | "version": 7 98 | } 99 | -------------------------------------------------------------------------------- /load_tests/load.js: -------------------------------------------------------------------------------- 1 | import {check} from 'k6'; 2 | import http from 'k6/http'; 3 | import {Trend} from 'k6/metrics'; 4 | 5 | const host = __ENV.HOST || '127.0.0.1:3000'; 6 | 7 | const totalTime = new Trend('total_time', true); 8 | const tokenizationTIme = new Trend('tokenization_time', true); 9 | const queueTime = new Trend('queue_time', true); 10 | const inferenceTime = new Trend('inference_time', true); 11 | 12 | export const inputs = 'A path from a point approximately 330 metres east of the most south 
westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; 13 | 14 | export const options = { 15 | thresholds: { 16 | http_req_failed: ['rate==0'], 17 | }, 18 | scenarios: { 19 | // throughput: { 20 | // executor: 'shared-iterations', 21 | // vus: 5000, 22 | // iterations: 5000, 23 | // maxDuration: '2m', 24 | // gracefulStop: '1s', 25 | // }, 26 | load_test: { 27 | executor: 'constant-arrival-rate', 28 | duration: '30s', 29 | preAllocatedVUs: 5000, 30 | rate: 50, 31 | timeUnit: '1s', 32 | gracefulStop: '1s', 33 | }, 34 | }, 35 | }; 36 | 37 | export default function () { 38 | const payload = JSON.stringify({ 39 | inputs: inputs, 40 | // query: inputs, 41 | // texts: [inputs], 42 | truncate: true, 43 | }); 44 | 45 | const headers = {'Content-Type': 'application/json'}; 46 | const res = http.post(`http://${host}/`, payload, { 47 | headers, timeout: '20m' 48 | }); 49 | 50 | check(res, { 51 | 'Post status is 200': (r) => res.status === 200, 52 | }); 53 | 54 | if (res.status === 200) { 55 | totalTime.add(res.headers["X-Total-Time"]); 56 | 
tokenizationTIme.add(res.headers["X-Tokenization-Time"]); 57 | queueTime.add(res.headers["X-Queue-Time"]); 58 | inferenceTime.add(res.headers["X-Inference-Time"]); 59 | } else { 60 | console.log(res.error); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /load_tests/load_grpc.js: -------------------------------------------------------------------------------- 1 | import {check} from 'k6'; 2 | import grpc from 'k6/experimental/grpc'; 3 | import {Trend} from 'k6/metrics'; 4 | 5 | const host = __ENV.HOST || '127.0.0.1:3000'; 6 | 7 | const totalTime = new Trend('total_time', true); 8 | const tokenizationTIme = new Trend('tokenization_time', true); 9 | const queueTime = new Trend('queue_time', true); 10 | const inferenceTime = new Trend('inference_time', true); 11 | 12 | export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; 13 | 14 | export const options = { 15 | thresholds: { 16 | http_req_failed: ['rate==0'], 17 | }, 18 | scenarios: { 19 | // 
throughput: { 20 | // executor: 'shared-iterations', 21 | // vus: 10000, 22 | // iterations: 10000, 23 | // maxDuration: '2m', 24 | // gracefulStop: '1s', 25 | // }, 26 | load_test: { 27 | executor: 'constant-arrival-rate', 28 | duration: '5m', 29 | preAllocatedVUs: 5000, 30 | rate: 1000, 31 | timeUnit: '1s', 32 | gracefulStop: '1s', 33 | }, 34 | }, 35 | }; 36 | 37 | 38 | const client = new grpc.Client(); 39 | 40 | client.load([], '../proto/tei.proto'); 41 | 42 | export default function () { 43 | if (__ITER == 0) { 44 | client.connect(host, { 45 | plaintext: true 46 | }); 47 | } 48 | 49 | const payload = { 50 | inputs: inputs, 51 | truncate: true, 52 | }; 53 | 54 | const res = client.invoke('tei.v1.Embed/Embed', payload); 55 | 56 | check(res, { 57 | 'status is OK': (r) => r && r.status === grpc.StatusOK, 58 | }); 59 | 60 | if (res.status === grpc.StatusOK) { 61 | totalTime.add(res.headers["x-total-time"]); 62 | tokenizationTIme.add(res.headers["x-tokenization-time"]); 63 | queueTime.add(res.headers["x-queue-time"]); 64 | inferenceTime.add(res.headers["x-inference-time"]); 65 | } else { 66 | console.log(res.error); 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /load_tests/load_grpc_stream.js: -------------------------------------------------------------------------------- 1 | import grpc from 'k6/experimental/grpc'; 2 | import {Counter, Trend} from 'k6/metrics'; 3 | 4 | const host = __ENV.HOST || '127.0.0.1:8080'; 5 | 6 | const streamCounter = new Counter('stream_counter'); 7 | const totalTime = new Trend('total_time', true); 8 | const tokenizationTIme = new Trend('tokenization_time', true); 9 | const queueTime = new Trend('queue_time', true); 10 | const inferenceTime = new Trend('inference_time', true); 11 | 12 | export const inputs = 'A path from a point approximately 330 metres east of the most south westerleasterly corner of Unit 4 Foundry Industrial Estate, then proceeding in a generally east-north-east direction for approximately 64 metres to a point approximately 282 metres east-south-east of the most easterly corner of Unit 2 Foundry Industrial Estate, Victoria Street, Widnes and approximately 259 metres east of the most southerly corner of Unit 4 Foundry Industrial Estate, Victoria Street, then proceeding in a generally east-north-east direction for approximately 350 metres to a point approximately 3 metres west-north-west of the most north westerly corner of the boundary fence of the scrap metal yard on the south side of Cornubia Road, Widnes, and approximately 47 metres west-south-west of the stub end of Cornubia Road be diverted to a 3 metre wide path from a point approximately 183 metres east-south-east of the most easterly corner of Unit 5 Foundry Industrial Estate, Victoria Street and approximately 272 metres east of the most north-easterly corner of 26 Ann Street West, Widnes, then proceeding in a generally north easterly direction for approximately 58 metres to a point approximately 216 metres east-south-east of the most easterly corner of Unit 4 Foundry Industrial Estate, Victoria Street and approximately 221 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, Victoria Street, then proceeding in a generally easterly direction for approximately 45 metres to a point approximately 265 metres east-south-east of the most north-easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 265 metres east of the most southerly corner of Unit 5 Foundry Industrial Estate, 
Victoria Street, then proceeding in a generally east-south-east direction for approximately 102 metres to a point approximately 366 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 463 metres east of the most north easterly corner of 22 Ann Street West, Widnes, then proceeding in a generally north-north-easterly direction for approximately 19 metres to a point approximately 368 metres east-south-east of the most easterly corner of Unit 3 Foundry Industrial Estate, Victoria Street and approximately 512 metres east of the most south easterly corner of 17 Batherton Close, Widnes then proceeding in a generally east-south, easterly direction for approximately 16 metres to a point approximately 420 metres east-south-east of the most southerly corner of Unit 2 Foundry'; 13 | 14 | export const options = { 15 | scenarios: { 16 | throughput: { 17 | executor: 'shared-iterations', 18 | vus: 1, 19 | iterations: 1, 20 | maxDuration: '2m', 21 | gracefulStop: '1s', 22 | }, 23 | }, 24 | }; 25 | 26 | 27 | const client = new grpc.Client(); 28 | 29 | client.load([], '../proto/tei.proto'); 30 | 31 | export default function () { 32 | if (__ITER == 0) { 33 | client.connect(host, { 34 | plaintext: true 35 | }); 36 | } 37 | 38 | const stream = new grpc.Stream(client, 'tei.v1.Embed/EmbedStream'); 39 | 40 | stream.on('data', (res) => { 41 | totalTime.add(res.metadata.totalTimeNs / 1e6); 42 | tokenizationTIme.add(res.metadata.tokenizationTimeNs / 1e6); 43 | queueTime.add(res.metadata.queueTimeNs / 1e6); 44 | inferenceTime.add(res.metadata.inferenceTimeNs / 1e6); 45 | }); 46 | 47 | stream.on('error', (err) => { 48 | console.log('Stream Error: ' + JSON.stringify(err)); 49 | }); 50 | 51 | const payload = { 52 | inputs: inputs, 53 | truncate: true, 54 | }; 55 | 56 | // send 10000 requests 57 | for (let i = 0; i < 10000; i++) { 58 | stream.write(payload); 59 | } 60 | 61 | // close the client stream 62 | stream.end(); 63 | } 64 | -------------------------------------------------------------------------------- /router/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "text-embeddings-router" 3 | description = "Text Embedding Webserver" 4 | build = "build.rs" 5 | version.workspace = true 6 | edition.workspace = true 7 | authors.workspace = true 8 | homepage.workspace = true 9 | 10 | [lib] 11 | path = "src/lib.rs" 12 | 13 | [[bin]] 14 | name = "text-embeddings-router" 15 | path = "src/main.rs" 16 | 17 | [dependencies] 18 | anyhow = { workspace = true } 19 | text-embeddings-backend = { path = "../backends", features = ["clap"] } 20 | text-embeddings-core = { path = "../core" } 21 | clap = { workspace = true } 22 | futures = "^0.3" 23 | init-tracing-opentelemetry = { version = "0.18.1", features = ["opentelemetry-otlp"] } 24 | hf-hub = { workspace = true } 25 | http = "1.0.0" 26 | num_cpus = { workspace = true } 27 | metrics = { workspace = true } 28 | metrics-exporter-prometheus = { version = "0.15.1", features = [] } 29 | opentelemetry = "0.23.0" 30 | opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio"] } 31 | opentelemetry-otlp = "0.16.0" 32 | reqwest = { version = "0.12.5", features = [] } 33 | simsimd = "4.4.0" 34 | serde = { workspace = true } 35 | serde_json = { workspace = true } 36 | thiserror = { workspace = true } 37 | tokenizers = { workspace = true } 38 | tokio = { workspace = true } 39 | tracing = { workspace = true } 40 | tracing-opentelemetry = "0.24.0" 41 | 
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } 42 | veil = "0.1.6" 43 | 44 | # HTTP dependencies 45 | axum = { version = "0.7.4", features = ["json"], optional = true } 46 | axum-tracing-opentelemetry = { version = "0.18.1", optional = true } 47 | base64 = { version = "0.22.1", optional = true } 48 | tower-http = { version = "0.5.1", features = ["cors"], optional = true } 49 | utoipa = { version = "4.2", features = ["axum_extras"], optional = true } 50 | utoipa-swagger-ui = { version = "7.1", features = ["axum", "vendored"], optional = true } 51 | 52 | # gRPC dependencies 53 | async-stream = { version = "0.3.5", optional = true } 54 | prost = { version = "0.12.1", optional = true } 55 | tonic = { version = "0.11.0", optional = true } 56 | tonic-health = { version = "0.11.0", optional = true } 57 | tonic-reflection = { version = "0.11.0", optional = true } 58 | tokio-stream = { version = "0.1.14", optional = true } 59 | 60 | # Optional 61 | cudarc = { workspace = true, optional = true } 62 | intel-mkl-src = { workspace = true, optional = true } 63 | 64 | # Malloc trim hack for linux 65 | [target.'cfg(target_os = "linux")'.dependencies] 66 | libc = "0.2.149" 67 | # else use mimalloc 68 | [target.'cfg(not(target_os = "linux"))'.dependencies] 69 | mimalloc = { version = "*", features = ["no_thp"] } 70 | 71 | [dev-dependencies] 72 | insta = { git = "https://github.com/OlivierDehaene/insta", rev = "f4f98c0410b91fb5a28b10df98e4422955be9c2c", features = ["yaml"] } 73 | is_close = "0.1.3" 74 | reqwest = { version = "0.12.5", features = ["json"] } 75 | serial_test = { workspace = true } 76 | 77 | [build-dependencies] 78 | vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] } 79 | tonic-build = { version = "0.11.0", optional = true } 80 | 81 | [features] 82 | default = ["candle", "http", "dynamic-linking"] 83 | http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:base64", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"] 84 | grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"] 85 | metal = ["text-embeddings-backend/metal"] 86 | mkl = ["text-embeddings-backend/mkl"] 87 | accelerate = ["text-embeddings-backend/accelerate"] 88 | python = ["text-embeddings-backend/python"] 89 | ort = ["text-embeddings-backend/ort"] 90 | candle = ["text-embeddings-backend/candle"] 91 | candle-cuda = ["candle", "text-embeddings-backend/flash-attn", "dep:cudarc"] 92 | candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1", "dep:cudarc"] 93 | candle-cuda-volta = ["candle", "text-embeddings-backend/cuda", "dep:cudarc"] 94 | static-linking = ["cudarc?/static-linking", "intel-mkl-src?/mkl-static-lp64-iomp"] 95 | dynamic-linking = ["cudarc?/dynamic-linking", "intel-mkl-src?/mkl-dynamic-lp64-iomp"] 96 | google = [] 97 | -------------------------------------------------------------------------------- /router/build.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use vergen::EmitBuilder; 3 | 4 | fn main() -> Result<(), Box> { 5 | // Try to get the git sha from the local git repository 6 | if EmitBuilder::builder() 7 | .fail_on_error() 8 | .git_sha(false) 9 | .emit() 10 | .is_err() 11 | { 12 | // Unable to get the git sha 13 | if let Ok(sha) = std::env::var("GIT_SHA") { 14 | // Set it from an env var 15 | println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); 16 | } 17 
| } 18 | 19 | // Set docker label if present 20 | if let Ok(label) = std::env::var("DOCKER_LABEL") { 21 | // Set it from an env var 22 | println!("cargo:rustc-env=DOCKER_LABEL={label}"); 23 | } 24 | 25 | #[cfg(feature = "grpc")] 26 | { 27 | use std::env; 28 | use std::fs; 29 | use std::path::PathBuf; 30 | 31 | fs::create_dir("src/grpc/pb").unwrap_or(()); 32 | 33 | let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); 34 | tonic_build::configure() 35 | .build_client(false) 36 | .build_server(true) 37 | .file_descriptor_set_path(out_dir.join("descriptor.bin")) 38 | .out_dir("src/grpc/pb") 39 | .include_file("mod.rs") 40 | .compile(&["../proto/tei.proto"], &["../proto"]) 41 | .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e)); 42 | } 43 | 44 | Ok(()) 45 | } 46 | -------------------------------------------------------------------------------- /router/src/grpc/mod.rs: -------------------------------------------------------------------------------- 1 | mod pb; 2 | pub(crate) mod server; 3 | 4 | use pb::tei::v1::{ 5 | embed_server::EmbedServer, info_server::InfoServer, predict_server::PredictServer, 6 | rerank_server::RerankServer, tokenize_server::TokenizeServer, *, 7 | }; 8 | -------------------------------------------------------------------------------- /router/src/grpc/pb/.gitignore: -------------------------------------------------------------------------------- 1 | *.rs 2 | -------------------------------------------------------------------------------- /router/src/http/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod server; 2 | mod types; 3 | -------------------------------------------------------------------------------- /router/src/prometheus.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use metrics_exporter_prometheus::{BuildError, Matcher, PrometheusBuilder}; 4 | 5 | pub(crate) fn prometheus_builer( 6 | addr: SocketAddr, 7 | port: u16, 8 | max_input_length: usize, 9 | ) -> Result { 10 | let mut addr = addr; 11 | addr.set_port(port); 12 | 13 | // Duration buckets 14 | let duration_matcher = Matcher::Suffix(String::from("duration")); 15 | let n_duration_buckets = 35; 16 | let mut duration_buckets = Vec::with_capacity(n_duration_buckets); 17 | // Minimum duration in seconds 18 | let mut value = 0.00001; 19 | for _ in 0..n_duration_buckets { 20 | // geometric sequence 21 | value *= 1.5; 22 | duration_buckets.push(value); 23 | } 24 | 25 | // Input Length buckets 26 | let input_length_matcher = Matcher::Full(String::from("te_request_input_length")); 27 | let input_length_buckets: Vec = (0..20) 28 | .map(|x| 2.0_f64.powi(x)) 29 | .filter(|x| (*x as usize) <= max_input_length) 30 | .collect(); 31 | 32 | // Batch size buckets 33 | let batch_size_matcher = Matcher::Full(String::from("te_batch_next_size")); 34 | let batch_size_buckets: Vec = (0..13).map(|x| 2.0_f64.powi(x)).collect(); 35 | 36 | // Batch tokens buckets 37 | let batch_tokens_matcher = Matcher::Full(String::from("te_batch_next_tokens")); 38 | let batch_tokens_buckets: Vec = (0..21).map(|x| 2.0_f64.powi(x)).collect(); 39 | 40 | // Prometheus handler 41 | PrometheusBuilder::new() 42 | .with_http_listener(addr) 43 | .set_buckets_for_metric(duration_matcher, &duration_buckets)? 44 | .set_buckets_for_metric(input_length_matcher, &input_length_buckets)? 45 | .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)? 
46 | .set_buckets_for_metric(batch_tokens_matcher, &batch_tokens_buckets) 47 | } 48 | -------------------------------------------------------------------------------- /router/src/shutdown.rs: -------------------------------------------------------------------------------- 1 | use tokio::signal; 2 | 3 | /// Shutdown signal handler 4 | pub(crate) async fn shutdown_signal() { 5 | let ctrl_c = async { 6 | signal::ctrl_c() 7 | .await 8 | .expect("failed to install Ctrl+C handler"); 9 | }; 10 | 11 | #[cfg(unix)] 12 | let terminate = async { 13 | signal::unix::signal(signal::unix::SignalKind::terminate()) 14 | .expect("failed to install signal handler") 15 | .recv() 16 | .await; 17 | }; 18 | 19 | #[cfg(not(unix))] 20 | let terminate = std::future::pending::<()>(); 21 | 22 | tokio::select! { 23 | _ = ctrl_c => {}, 24 | _ = terminate => {}, 25 | } 26 | 27 | tracing::info!("signal received, starting graceful shutdown"); 28 | } 29 | -------------------------------------------------------------------------------- /router/tests/common.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use serde::{Deserialize, Serialize}; 3 | use std::time::Duration; 4 | use text_embeddings_backend::DType; 5 | use text_embeddings_router::run; 6 | use tokio::time::Instant; 7 | 8 | #[derive(Serialize, Deserialize, Debug)] 9 | pub struct Score(f32); 10 | 11 | impl Score { 12 | fn is_close(&self, other: &Self, abs_tol: f32) -> bool { 13 | is_close::default() 14 | .abs_tol(abs_tol) 15 | .is_close(self.0, other.0) 16 | } 17 | } 18 | 19 | impl PartialEq for Score { 20 | fn eq(&self, other: &Self) -> bool { 21 | // Default tolerance for equality 22 | self.is_close(other, 4e-3) 23 | } 24 | } 25 | 26 | async fn check_health(port: u16, timeout: Duration) -> Result<()> { 27 | let addr = format!("http://0.0.0.0:{port}/health"); 28 | let client = reqwest::ClientBuilder::new() 29 | .timeout(timeout) 30 | .build() 31 | .unwrap(); 32 | 33 | let start = Instant::now(); 34 | loop { 35 | if client.get(&addr).send().await.is_ok() { 36 | return Ok(()); 37 | } 38 | if start.elapsed() < timeout { 39 | tokio::time::sleep(Duration::from_secs(1)).await; 40 | } else { 41 | anyhow::bail!("Backend is not healthy"); 42 | } 43 | } 44 | } 45 | 46 | pub async fn start_server(model_id: String, revision: Option, dtype: DType) -> Result<()> { 47 | let server_task = tokio::spawn({ 48 | run( 49 | model_id, 50 | revision, 51 | Some(1), 52 | Some(dtype), 53 | None, 54 | 4, 55 | 1024, 56 | None, 57 | 32, 58 | false, 59 | None, 60 | None, 61 | None, 62 | None, 63 | 8090, 64 | None, 65 | None, 66 | 2_000_000, 67 | None, 68 | None, 69 | "text-embeddings-inference.server".to_owned(), 70 | 9000, 71 | None, 72 | ) 73 | }); 74 | 75 | tokio::select! 
{ 76 | err = server_task => err?, 77 | _ = check_health(8090, Duration::from_secs(60)) => Ok(()) 78 | }?; 79 | Ok(()) 80 | } 81 | -------------------------------------------------------------------------------- /router/tests/snapshots/test_http_predict__predictions_single.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: router/tests/test_http_predict.rs 3 | expression: predictions_single 4 | --- 5 | - score: 0.9972132 6 | label: neutral 7 | - score: 0.0007965705 8 | label: approval 9 | - score: 0.00029795355 10 | label: annoyance 11 | - score: 0.00020120452 12 | label: admiration 13 | - score: 0.00019649488 14 | label: realization 15 | - score: 0.00013194351 16 | label: excitement 17 | - score: 0.000102370424 18 | label: disappointment 19 | - score: 0.000100379 20 | label: disgust 21 | - score: 0.000099765304 22 | label: anger 23 | - score: 0.000085338215 24 | label: disapproval 25 | - score: 0.00007990062 26 | label: joy 27 | - score: 0.00007926568 28 | label: curiosity 29 | - score: 0.00007919242 30 | label: amusement 31 | - score: 0.000066899396 32 | label: confusion 33 | - score: 0.0000655515 34 | label: optimism 35 | - score: 0.000063790605 36 | label: love 37 | - score: 0.00006372412 38 | label: sadness 39 | - score: 0.00004477469 40 | label: fear 41 | - score: 0.000041168856 42 | label: desire 43 | - score: 0.000040866962 44 | label: surprise 45 | - score: 0.000033375803 46 | label: caring 47 | - score: 0.000027021126 48 | label: gratitude 49 | - score: 0.000025252299 50 | label: embarrassment 51 | - score: 0.000016195578 52 | label: pride 53 | - score: 0.000013492137 54 | label: relief 55 | - score: 0.000013144917 56 | label: grief 57 | - score: 0.000011318838 58 | label: nervousness 59 | - score: 0.0000098800765 60 | label: remorse 61 | -------------------------------------------------------------------------------- /router/tests/snapshots/test_http_rerank__ranks.snap: -------------------------------------------------------------------------------- 1 | --- 2 | source: router/tests/test_http_rerank.rs 3 | assertion_line: 42 4 | expression: ranks 5 | --- 6 | - index: 2 7 | score: 0.9997739 8 | text: test 9 | - index: 0 10 | score: 0.9997739 11 | text: test 12 | - index: 1 13 | score: 0.3638598 14 | text: other 15 | -------------------------------------------------------------------------------- /router/tests/test_http_embed.rs: -------------------------------------------------------------------------------- 1 | // mod common; 2 | // 3 | // use crate::common::{start_server, Score}; 4 | // use anyhow::Result; 5 | // use insta::internals::YamlMatcher; 6 | // use serde_json::json; 7 | // use text_embeddings_backend::DType; 8 | // 9 | // #[tokio::test] 10 | // #[cfg(feature = "http")] 11 | // async fn test_embeddings() -> Result<()> { 12 | // start_server( 13 | // "sentence-transformers/all-MiniLM-L6-v2".to_string(), 14 | // None, 15 | // DType::Float32, 16 | // ) 17 | // .await?; 18 | // 19 | // let request = json!({ 20 | // "inputs": "test" 21 | // }); 22 | // let client = reqwest::Client::new(); 23 | // let res = client 24 | // .post("http://0.0.0.0:8090/embed") 25 | // .json(&request) 26 | // .send() 27 | // .await?; 28 | // 29 | // let embeddings_single = res.json::>>().await?; 30 | // let matcher = YamlMatcher::>>::new(); 31 | // insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher); 32 | // 33 | // let test_tokens = vec![[101, 3231, 102]]; // tokenized "test" 34 | // let request = json!({"inputs": 
&test_tokens}); 35 | // let res = client 36 | // .post("http://0.0.0.0:8090/embed") 37 | // .json(&request) 38 | // .send() 39 | // .await?; 40 | // 41 | // let embeddings_single = res.json::>>().await?; 42 | // let matcher = YamlMatcher::>>::new(); 43 | // insta::assert_yaml_snapshot!("embeddings_single", embeddings_single, &matcher); 44 | // 45 | // let request = json!({ 46 | // "inputs": vec!["test", "test", "test", "test", "test"], 47 | // }); 48 | // 49 | // let client = reqwest::Client::new(); 50 | // let res = client 51 | // .post("http://0.0.0.0:8090/embed") 52 | // .json(&request) 53 | // .send() 54 | // .await?; 55 | // let embeddings_batch = res.json::>>().await?; 56 | // insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher); 57 | // for embeddings in &embeddings_batch { 58 | // assert_eq!(embeddings, &embeddings_single[0]); 59 | // } 60 | // 61 | // let request = 62 | // json!({"inputs": &test_tokens.repeat(request["inputs"].as_array().unwrap().len())}); 63 | // let res = client 64 | // .post("http://0.0.0.0:8090/embed") 65 | // .json(&request) 66 | // .send() 67 | // .await?; 68 | // 69 | // let embeddings_batch = res.json::>>().await?; 70 | // insta::assert_yaml_snapshot!("embeddings_batch", embeddings_batch, &matcher); 71 | // for embeddings in &embeddings_batch { 72 | // assert_eq!(embeddings, &embeddings_single[0]); 73 | // } 74 | // 75 | // let request = json!({ 76 | // "inputs": "test" 77 | // }); 78 | // 79 | // let client = reqwest::Client::new(); 80 | // let res = client 81 | // .post("http://0.0.0.0:8090/embed_all") 82 | // .json(&request) 83 | // .send() 84 | // .await?; 85 | // 86 | // let embeddings_raw = res.json::>>>().await?; 87 | // let matcher = YamlMatcher::>>>::new(); 88 | // insta::assert_yaml_snapshot!("embeddings_raw", embeddings_raw, &matcher); 89 | // 90 | // Ok(()) 91 | // } 92 | -------------------------------------------------------------------------------- /router/tests/test_http_predict.rs: -------------------------------------------------------------------------------- 1 | // mod common; 2 | // 3 | // use crate::common::{start_server, Score}; 4 | // use anyhow::Result; 5 | // use insta::internals::YamlMatcher; 6 | // use serde::{Deserialize, Serialize}; 7 | // use serde_json::json; 8 | // use text_embeddings_backend::DType; 9 | // 10 | // #[derive(Serialize, Deserialize, Debug, PartialEq)] 11 | // pub struct SnapshotPrediction { 12 | // score: Score, 13 | // label: String, 14 | // } 15 | // 16 | // #[tokio::test] 17 | // #[serial_test::serial] 18 | // #[cfg(feature = "http")] 19 | // async fn test_predict() -> Result<()> { 20 | // let model_id = if cfg!(feature = "ort") { 21 | // "SamLowe/roberta-base-go_emotions-onnx" 22 | // } else { 23 | // "SamLowe/roberta-base-go_emotions" 24 | // }; 25 | // 26 | // start_server(model_id.to_string(), None, DType::Float32).await?; 27 | // 28 | // let request = json!({ 29 | // "inputs": "test" 30 | // }); 31 | // 32 | // let client = reqwest::Client::new(); 33 | // let res = client 34 | // .post("http://0.0.0.0:8090/predict") 35 | // .json(&request) 36 | // .send() 37 | // .await?; 38 | // 39 | // let predictions_single = res.json::>().await?; 40 | // let matcher = YamlMatcher::>::new(); 41 | // insta::assert_yaml_snapshot!("predictions_single", predictions_single, &matcher); 42 | // 43 | // let request = json!({ 44 | // "inputs": vec![ 45 | // vec!["test"], 46 | // vec!["test"], 47 | // vec!["test"], 48 | // vec!["test"], 49 | // vec!["test"], 50 | // ], 51 | // }); 52 | // 53 | // let 
client = reqwest::Client::new(); 54 | // let res = client 55 | // .post("http://0.0.0.0:8090/predict") 56 | // .json(&request) 57 | // .send() 58 | // .await?; 59 | // 60 | // let predictions_batch = res.json::>>().await?; 61 | // let matcher = YamlMatcher::>>::new(); 62 | // insta::assert_yaml_snapshot!("predictions_batch", predictions_batch, &matcher); 63 | // 64 | // for predictions in &predictions_batch { 65 | // assert_eq!(predictions, &predictions_single); 66 | // } 67 | // 68 | // Ok(()) 69 | // } 70 | -------------------------------------------------------------------------------- /router/tests/test_http_rerank.rs: -------------------------------------------------------------------------------- 1 | // mod common; 2 | // 3 | // use crate::common::{start_server, Score}; 4 | // use anyhow::Result; 5 | // use insta::internals::YamlMatcher; 6 | // use serde::{Deserialize, Serialize}; 7 | // use serde_json::json; 8 | // use text_embeddings_backend::DType; 9 | // 10 | // #[derive(Serialize, Deserialize, Debug, PartialEq)] 11 | // pub struct SnapshotRank { 12 | // index: usize, 13 | // score: Score, 14 | // text: String, 15 | // } 16 | // 17 | // #[tokio::test] 18 | // #[cfg(feature = "http")] 19 | // async fn test_rerank() -> Result<()> { 20 | // start_server("BAAI/bge-reranker-base".to_string(), None, DType::Float32).await?; 21 | // 22 | // let request = json!({ 23 | // "query": "test", 24 | // "texts": vec!["test", "other", "test"], 25 | // "return_text": true 26 | // }); 27 | // 28 | // let client = reqwest::Client::new(); 29 | // let res = client 30 | // .post("http://0.0.0.0:8090/rerank") 31 | // .json(&request) 32 | // .send() 33 | // .await?; 34 | // 35 | // let ranks = res.json::>().await?; 36 | // let matcher = YamlMatcher::>::new(); 37 | // insta::assert_yaml_snapshot!("ranks", ranks, &matcher); 38 | // 39 | // assert_eq!(ranks[0].index, 2); 40 | // assert_eq!(ranks[1].index, 0); 41 | // assert_eq!(ranks[0].score, ranks[1].score); 42 | // 43 | // Ok(()) 44 | // } 45 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.85.1" 3 | components = ["rustfmt", "clippy"] 4 | -------------------------------------------------------------------------------- /sagemaker-entrypoint-cuda-all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | verlte() { 4 | [ "$1" = "$2" ] && return 1 || [ "$2" = "$(echo -e "$1\n$2" | sort -V | head -n1)" ] 5 | } 6 | 7 | if [ -f /usr/local/cuda/compat/libcuda.so.1 ]; then 8 | CUDA_COMPAT_MAX_DRIVER_VERSION=$(readlink /usr/local/cuda/compat/libcuda.so.1 | cut -d"." 
-f 3-) 9 | echo "CUDA compat package requires Nvidia driver ≤${CUDA_COMPAT_MAX_DRIVER_VERSION}" 10 | cat /proc/driver/nvidia/version 11 | NVIDIA_DRIVER_VERSION=$(sed -n 's/^NVRM.*Kernel Module *\([0-9.]*\).*$/\1/p' /proc/driver/nvidia/version 2>/dev/null || true) 12 | echo "Current installed Nvidia driver version is ${NVIDIA_DRIVER_VERSION}" 13 | if [ $(verlte "$CUDA_COMPAT_MAX_DRIVER_VERSION" "$NVIDIA_DRIVER_VERSION") ]; then 14 | echo "Setup CUDA compatibility libs path to LD_LIBRARY_PATH" 15 | export LD_LIBRARY_PATH=/usr/local/cuda/compat:$LD_LIBRARY_PATH 16 | echo $LD_LIBRARY_PATH 17 | else 18 | echo "Skip CUDA compat libs setup as newer Nvidia driver is installed" 19 | fi 20 | else 21 | echo "Skip CUDA compat libs setup as package not found" 22 | fi 23 | 24 | if [[ -z "${HF_MODEL_ID}" ]]; then 25 | echo "HF_MODEL_ID must be set" 26 | exit 1 27 | fi 28 | export MODEL_ID="${HF_MODEL_ID}" 29 | 30 | if [[ -n "${HF_MODEL_REVISION}" ]]; then 31 | export REVISION="${HF_MODEL_REVISION}" 32 | fi 33 | 34 | if ! command -v nvidia-smi &> /dev/null; then 35 | echo "Error: 'nvidia-smi' command not found." 36 | exit 1 37 | fi 38 | 39 | # Query GPU name using nvidia-smi 40 | gpu_name=$(nvidia-smi --query-gpu=gpu_name --format=csv | awk 'NR==2') 41 | if [ $? -ne 0 ]; then 42 | echo "Error: $gpu_name" 43 | echo "Query gpu_name failed" 44 | else 45 | echo "Query gpu_name succeeded. Printing output: $gpu_name" 46 | fi 47 | 48 | # Function to get compute capability based on GPU name 49 | get_compute_cap() { 50 | gpu_name="$1" 51 | 52 | # Check if the GPU name contains "A10G" 53 | if [[ "$gpu_name" == *"A10G"* ]]; then 54 | echo "86" 55 | # Check if the GPU name contains "A100" 56 | elif [[ "$gpu_name" == *"A100"* ]]; then 57 | echo "80" 58 | # Check if the GPU name contains "H100" 59 | elif [[ "$gpu_name" == *"H100"* ]]; then 60 | echo "90" 61 | # Cover Nvidia T4 62 | elif [[ "$gpu_name" == *"T4"* ]]; then 63 | echo "75" 64 | # Cover Nvidia L4 65 | elif [[ "$gpu_name" == *"L4"* ]]; then 66 | echo "89" 67 | else 68 | echo "80" # Default compute capability 69 | fi 70 | } 71 | 72 | if [[ -z "${CUDA_COMPUTE_CAP}" ]] 73 | then 74 | compute_cap=$(get_compute_cap "$gpu_name") 75 | echo "the compute_cap is $compute_cap" 76 | else 77 | compute_cap=$CUDA_COMPUTE_CAP 78 | fi 79 | 80 | if [[ ${compute_cap} -eq 75 ]] 81 | then 82 | text-embeddings-router-75 --port 8080 --json-output 83 | elif [[ ${compute_cap} -ge 80 && ${compute_cap} -lt 90 ]] 84 | then 85 | text-embeddings-router-80 --port 8080 --json-output 86 | elif [[ ${compute_cap} -eq 90 ]] 87 | then 88 | text-embeddings-router-90 --port 8080 --json-output 89 | else 90 | echo "cuda compute cap ${compute_cap} is not supported"; exit 1 91 | fi 92 | -------------------------------------------------------------------------------- /sagemaker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -z "${HF_MODEL_ID}" ]]; then 4 | echo "HF_MODEL_ID must be set" 5 | exit 1 6 | fi 7 | export MODEL_ID="${HF_MODEL_ID}" 8 | 9 | if [[ -n "${HF_MODEL_REVISION}" ]]; then 10 | export REVISION="${HF_MODEL_REVISION}" 11 | fi 12 | 13 | text-embeddings-router --port 8080 --json-output 14 | --------------------------------------------------------------------------------
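
Once any of the launch paths above has `text-embeddings-router` listening on a port (8080 in the documentation and entrypoint examples), it can be exercised over plain HTTP in the same way `router/tests/test_http_embed.rs` and the `load_tests` scripts do. The snippet below is a minimal smoke-test sketch, not part of the repository: the host, port, and sample sentence are assumptions, while the `/embed` route, the JSON `inputs` field, and the `X-*-Time` response headers come from the test and load-test sources above.

```shell
# Minimal smoke test against a locally running text-embeddings-router (assumed host/port).
# The /embed route and the {"inputs": ...} payload mirror router/tests/test_http_embed.rs.
curl -s http://127.0.0.1:8080/embed \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "What is Deep Learning?"}'
```

A successful call returns one embedding vector per input as a JSON array, and the `X-Total-Time`, `X-Tokenization-Time`, `X-Queue-Time`, and `X-Inference-Time` headers read by `load_tests/load.js` can be inspected (for example with `curl -i`) for a quick latency breakdown.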