├── .github └── workflows │ ├── publish_dev_package.yml │ └── publish_on_release.yml ├── .gitignore ├── LICENSE ├── README.md ├── _version_helper.py ├── pyproject.toml └── src └── pymilvus └── model ├── __init__.py ├── base.py ├── dense ├── __init__.py ├── cohere.py ├── gemini.py ├── instructor.py ├── instructor_embedding │ └── instructor_impl.py ├── jinaai.py ├── mistralai.py ├── model2vec_embed.py ├── nomic.py ├── onnx.py ├── openai.py ├── sentence_transformer.py ├── tei.py └── voyageai.py ├── hybrid ├── __init__.py ├── bge_m3.py ├── mgte.py └── mgte_embedding │ └── gte_impl.py ├── reranker ├── __init__.py ├── bgereranker.py ├── cohere.py ├── cross_encoder.py ├── jinaai.py ├── tei.py └── voyageai.py ├── sparse ├── __init__.py ├── bm25 │ ├── __init__.py │ ├── bm25.py │ ├── lang.yaml │ └── tokenizers.py ├── splade.py ├── splade_embedding │ └── splade_impl.py └── utils.py └── utils ├── __init__.py └── dependency_control.py /.github/workflows/publish_dev_package.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to TestPyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to TestPyPI 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Check out from Git 15 | uses: actions/checkout@v4 16 | - name: Get history and tags for SCM versioning 17 | run: | 18 | git fetch --prune --unshallow 19 | git fetch --depth=1 origin +refs/tags/*:refs/tags/* 20 | - name: Set up Python 3.12 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: 3.12 24 | - name: Install pypa/build 25 | run: >- 26 | python -m 27 | pip install 28 | build 29 | --user 30 | - name: Build a binary wheel and a source tarball 31 | run: >- 32 | PYTHONPATH=. 33 | python -m 34 | build 35 | --sdist 36 | --wheel 37 | --outdir dist/ 38 | . 39 | - name: Publish distribution 📦 to Test PyPI 40 | uses: pypa/gh-action-pypi-publish@release/v1 41 | with: 42 | password: ${{ secrets.TOKEN_TEST_PYPI }} 43 | repository-url: https://test.pypi.org/legacy/ 44 | 45 | -------------------------------------------------------------------------------- /.github/workflows/publish_on_release.yml: -------------------------------------------------------------------------------- 1 | name: Publish 🐍 On Release to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-n-publish: 9 | name: Build and publish Python 🐍 distributions 📦 to PyPI 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Check out from Git 14 | uses: actions/checkout@v4 15 | - name: Get history and tags for SCM versioning 16 | run: | 17 | git fetch --prune --unshallow 18 | git fetch --depth=1 origin +refs/tags/*:refs/tags/* 19 | - name: Set up Python 3.12 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.12 23 | - name: Install pypa/build 24 | run: >- 25 | python -m 26 | pip install 27 | build 28 | --user 29 | - name: Build a binary wheel and a source tarball 30 | run: >- 31 | PYTHONPATH=. 32 | python -m 33 | build 34 | --sdist 35 | --wheel 36 | --outdir dist/ 37 | . 
38 | - name: Publish distribution 📦 to Test PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | with: 41 | password: ${{ secrets.TOKEN_TEST_PYPI }} 42 | repository-url: https://test.pypi.org/legacy/ 43 | - name: Publish distribution 📦 to PyPI 44 | if: startsWith(github.ref, 'refs/tags') 45 | uses: pypa/gh-action-pypi-publish@release/v1 46 | with: 47 | password: ${{ secrets.PYPI_API_TOKEN }} 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | ## Pycharm IDE 3 | /.idea 4 | 5 | .vscode 6 | 7 | # Python files 8 | *.pyc 9 | dist 10 | docs/build/ 11 | *.egg 12 | *.egg-info 13 | *.eggs/ 14 | **/__pycache__/ 15 | .pytest*/ 16 | /build/ 17 | .remember 18 | venv/ 19 | 20 | # Env 21 | .env 22 | 23 | # Local Temp 24 | temp/ 25 | *.swp 26 | assets/ 27 | TODO 28 | 29 | # GitHub 30 | .coverage 31 | htmlcov/ 32 | debug/ 33 | .codecov.yml 34 | coverage.xml 35 | 36 | # Example data 37 | /examples/bulk_writer 38 | 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Milvus Model Lib 2 | 3 | The `milvus-model` library provides integration with common embedding and reranker models for Milvus, a high-performance, open-source vector database built for AI applications. `milvus-model` is included as a dependency of `pymilvus`, the Python SDK of Milvus.
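As a quick taste of the API (installation instructions follow below), here is a minimal usage sketch. It uses the bundled `DefaultEmbeddingFunction`, which this package aliases to the local ONNX embedding wrapper, and assumes the ONNX wrapper exposes the same `encode_documents` / `encode_queries` / `dim` interface as the other dense functions in this repository; the exact default model and its output dimension may vary between releases:

```python
from pymilvus import model

# DefaultEmbeddingFunction is an alias for dense.onnx.OnnxEmbeddingFunction,
# a small local model that needs no external service or API key.
ef = model.DefaultEmbeddingFunction()

docs = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
]

doc_embeddings = ef.encode_documents(docs)                   # vectors to index
query_embeddings = ef.encode_queries(["Who pioneered AI?"])  # vectors to search with

print(len(doc_embeddings), ef.dim)  # number of vectors, embedding dimension
```

Provider-backed functions such as `OpenAIEmbeddingFunction` or `CohereEmbeddingFunction` expose the same `encode_documents` / `encode_queries` interface and additionally accept a `model_name` and an `api_key`.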
4 | 5 | `milvus-model` supports embedding and reranker models from service providers such as OpenAI, Voyage AI, and Cohere, as well as open-source models through SentenceTransformers or Hugging Face [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference). 6 | 7 | `milvus-model` supports Python 3.8 and above. 8 | 9 | ## Installation 10 | 11 | If you use `pymilvus`, you can install `milvus-model` through the `pymilvus[model]` extra: 12 | ```bash 13 | pip install pymilvus[model] 14 | # or pip install "pymilvus[model]" for zsh. 15 | ``` 16 | 17 | You can also install it directly: 18 | ```bash 19 | pip install pymilvus.model 20 | ``` 21 | 22 | To upgrade milvus-model to the latest version, use: 23 | ```bash 24 | pip install pymilvus.model --upgrade 25 | ``` 26 | If milvus-model was initially installed as part of the PyMilvus optional components, you should also upgrade PyMilvus to ensure compatibility. This can be done with: 27 | ```bash 28 | pip install pymilvus[model] --upgrade 29 | ``` 30 | If you need to install a specific version of milvus-model, specify the version number: 31 | ```bash 32 | pip install pymilvus.model==0.3.0 33 | ``` 34 | This command installs version 0.3.0 of milvus-model. 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /_version_helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is a hack, only in place to allow setuptools 3 | to use the attribute for the version. 4 | 5 | It works only if the backend-path of the build-system section 6 | from pyproject.toml is respected. 7 | """ 8 | 9 | from __future__ import annotations 10 | 11 | import logging 12 | from typing import Callable 13 | 14 | from setuptools import build_meta as build_meta # noqa 15 | 16 | from setuptools_scm import _types as _t 17 | from setuptools_scm import Configuration 18 | from setuptools_scm import get_version 19 | from setuptools_scm import git 20 | from setuptools_scm import hg 21 | from setuptools_scm.fallbacks import parse_pkginfo 22 | from setuptools_scm.version import ( 23 | get_no_local_node, 24 | _parse_version_tag, 25 | guess_next_simple_semver, 26 | SEMVER_MINOR, 27 | guess_next_version, 28 | ScmVersion, 29 | ) 30 | 31 | log = logging.getLogger("setuptools_scm") 32 | # TODO: take fake entrypoints from pyproject.toml 33 | try_parse: list[Callable[[_t.PathT, Configuration], ScmVersion | None]] = [ 34 | parse_pkginfo, 35 | git.parse, 36 | hg.parse, 37 | git.parse_archival, 38 | hg.parse_archival, 39 | ] 40 | 41 | 42 | def parse(root: str, config: Configuration) -> ScmVersion | None: 43 | for maybe_parse in try_parse: 44 | try: 45 | parsed = maybe_parse(root, config) 46 | except OSError as e: 47 | log.warning("parse with %s failed with: %s", maybe_parse, e) 48 | else: 49 | if parsed is not None: 50 | return parsed 51 | 52 | 53 | fmt = "{guessed}.rc{distance}" 54 | 55 | 56 | def custom_version(version: ScmVersion) -> str: 57 | if version.exact: 58 | return version.format_with("{tag}") 59 | if version.branch is not None: 60 | # Does the branch name (stripped of namespace) parse as a version? 61 | branch_ver_data = _parse_version_tag(version.branch.split("/")[-1], version.config) 62 | if branch_ver_data is not None: 63 | branch_ver = branch_ver_data["version"] 64 | if branch_ver[0] == "v": 65 | # Allow branches that start with 'v', similar to Version. 66 | branch_ver = branch_ver[1:] 67 | # Does the branch version up to the minor part match the tag?
If not, it 68 | # might be something like an issue number and not a version number, so 69 | # we only want to use it if it matches. 70 | tag_ver_up_to_minor = str(version.tag).split(".")[:SEMVER_MINOR] 71 | branch_ver_up_to_minor = branch_ver.split(".")[:SEMVER_MINOR] 72 | if branch_ver_up_to_minor == tag_ver_up_to_minor: 73 | # We're in a release/maintenance branch, next is a patch/rc/beta bump: 74 | return version.format_next_version(guess_next_version, fmt=fmt) 75 | # We're in a development branch, next is a minor bump: 76 | # return version.format_next_version(guess_next_simple_semver, retain=SEMVER_MINOR, fmt=fmt) 77 | return version.format_next_version(guess_next_version, fmt=fmt) 78 | 79 | 80 | def scm_version() -> str: 81 | return get_version( 82 | relative_to=__file__, 83 | parse=parse, 84 | version_scheme=custom_version, 85 | local_scheme=get_no_local_node, 86 | ) 87 | 88 | 89 | version: str 90 | 91 | 92 | def __getattr__(name: str) -> str: 93 | if name == "version": 94 | global version 95 | version = scm_version() 96 | return version 97 | raise AttributeError(name) 98 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 75", 4 | "wheel", 5 | "gitpython", 6 | "setuptools_scm >= 8.0", 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | [project] 10 | name = "pymilvus.model" 11 | authors = [ 12 | {name = "Milvus Team", email = "milvus-team@zilliz.com"}, 13 | ] 14 | requires-python = ">=3.8" 15 | description = "Model components for PyMilvus, the Python SDK for Milvus" 16 | readme = "README.md" 17 | dependencies = [ 18 | "transformers >= 4.36.0", 19 | "onnxruntime", 20 | "scipy >= 1.10.0", 21 | "protobuf", 22 | "numpy" 23 | ] 24 | 25 | classifiers = [ 26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: Apache Software License", 28 | ] 29 | 30 | dynamic = ["version"] 31 | 32 | [project.urls] 33 | repository = "https://github.com/milvus-io/milvus-model" 34 | 35 | [tool.setuptools] 36 | package-dir = { "pymilvus" = "pymilvus" } 37 | include-package-data = true 38 | 39 | [tool.setuptools.packages.find] 40 | where = ["src"] 41 | 42 | [tool.setuptools.package-data] 43 | "pymilvus.model.sparse.bm25" = ["lang.yaml"] 44 | 45 | [tool.setuptools.dynamic] 46 | version = { attr = "_version_helper.version" } 47 | 48 | [tool.setuptools_scm] 49 | 50 | [tool.black] 51 | line-length = 100 52 | target-version = ['py37'] 53 | include = '\.pyi?$' 54 | 55 | 56 | [tool.ruff] 57 | src = ["src"] 58 | lint.select = [ 59 | "E", 60 | "F", 61 | "C90", 62 | "I", 63 | "N", 64 | "B", "C", "G", 65 | "A", 66 | "ANN001", 67 | "S", "T", "W", "ARG", "BLE", "COM", "DJ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT" 68 | ] 69 | lint.ignore = [ 70 | "N818", 71 | "DTZ", # datetime related 72 | "BLE", # blind-except (BLE001) 73 | "SLF", # SLF001 Private member accessed: `_fetch_handler` [E] 74 | "PD003", 75 | "TRY003", # [ruff] TRY003 Avoid specifying long messages outside the exception class [E] TODO 76 | "PLR2004", # Magic value used in comparison, consider replacing 65535 with a constant variable [E] TODO 77 | "TRY301", # [ruff] TRY301 Abstract `raise` to an inner function [E] 78 | "FBT001", # [ruff] FBT001
Boolean positional arg in function definition [E] TODO 79 | "FBT002", # [ruff] FBT002 Boolean default value in function definition [E] TODO 80 | "PLR0911", # Too many return statements (15 > 6) [E] 81 | "G004", # [ruff] G004 Logging statement uses f-string [E] 82 | "S603", # [ruff] S603 `subprocess` call: check for execution of untrusted input [E] 83 | "N802", #[ruff] N802 Function name `OK` should be lowercase [E] TODO 84 | "PD011", # [ruff] PD011 Use `.to_numpy()` instead of `.values` [E] 85 | "COM812", 86 | "FBT003", # [ruff] FBT003 Boolean positional value in function call [E] TODO 87 | "ARG002", 88 | "E501", # black takes care of it 89 | "ARG005", # [ruff] ARG005 Unused lambda argument: `disable` [E] 90 | "TRY400", 91 | "PLR0912", # TODO 92 | "C901", # TODO 93 | "PYI041", # TODO 94 | ] 95 | 96 | # Allow autofix for all enabled rules (when `--fix`) is provided. 97 | lint.fixable = [ 98 | "A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", 99 | "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", 100 | "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", 101 | "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", 102 | "YTT", 103 | ] 104 | lint.unfixable = [] 105 | 106 | show-fixes = true 107 | 108 | # Exclude a variety of commonly ignored directories. 109 | exclude = [ 110 | ".bzr", 111 | ".direnv", 112 | ".eggs", 113 | ".git", 114 | ".git-rewrite", 115 | ".hg", 116 | ".mypy_cache", 117 | ".nox", 118 | ".pants.d", 119 | ".pytype", 120 | ".ruff_cache", 121 | ".svn", 122 | ".tox", 123 | ".venv", 124 | "__pypackages__", 125 | "_build", 126 | "buck-out", 127 | "build", 128 | "dist", 129 | "node_modules", 130 | "venv", 131 | "grpc_gen", 132 | "__pycache__", 133 | "pymilvus/client/stub.py", 134 | "tests", 135 | ] 136 | 137 | # Same as Black. 138 | line-length = 100 139 | 140 | # Allow unused variables when underscore-prefixed. 141 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 142 | 143 | # Assume Python 3.7 144 | target-version = "py37" 145 | 146 | [tool.ruff.lint.mccabe] 147 | # Unlike Flake8, default to a complexity level of 10. 148 | max-complexity = 18 149 | 150 | [tool.ruff.lint.pycodestyle] 151 | max-doc-length = 100 152 | 153 | [tool.ruff.lint.pylint] 154 | max-args = 20 155 | max-branches = 15 156 | 157 | [tool.ruff.lint.flake8-builtins] 158 | builtins-ignorelist = [ 159 | "format", 160 | "next", 161 | "object", # TODO 162 | "id", 163 | "dict", # TODO 164 | "filter", 165 | ] 166 | -------------------------------------------------------------------------------- /src/pymilvus/model/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid", "reranker", "utils"] 2 | 3 | from . 
import dense, hybrid, sparse, reranker, utils 4 | 5 | DefaultEmbeddingFunction = dense.onnx.OnnxEmbeddingFunction 6 | -------------------------------------------------------------------------------- /src/pymilvus/model/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List 3 | 4 | 5 | class BaseEmbeddingFunction: 6 | @abstractmethod 7 | def __call__(self, texts: List[str]): 8 | """Encode a list of texts into embeddings.""" 9 | 10 | @abstractmethod 11 | def encode_queries(self, queries: List[str]): 12 | """Encode a list of search queries into embeddings.""" 13 | 14 | 15 | class BaseRerankFunction: 16 | @abstractmethod 17 | def __call__(self, query: str, documents: List[str], top_k: int): 18 | """Rerank documents by relevance to the query, returning the top_k results.""" 19 | 20 | 21 | class RerankResult: 22 | def __init__(self, text: str, score: float, index: int): 23 | self.text = text 24 | self.score = score 25 | self.index = index 26 | 27 | def to_dict(self): 28 | return {"text": self.text, "score": self.score, "index": self.index} 29 | 30 | def __str__(self): 31 | return f"RerankResult(text={self.text!r}, score={self.score}, index={self.index})" 32 | 33 | def __repr__(self): 34 | return f"RerankResult(text={self.text!r}, score={self.score}, index={self.index})" 35 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/__init__.py: -------------------------------------------------------------------------------- 1 | from pymilvus.model.dense.openai import OpenAIEmbeddingFunction 2 | from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction 3 | from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction 4 | from pymilvus.model.dense.jinaai import JinaEmbeddingFunction 5 | from pymilvus.model.dense.tei import TEIEmbeddingFunction 6 | from pymilvus.model.dense.onnx import OnnxEmbeddingFunction 7 | from pymilvus.model.dense.cohere import CohereEmbeddingFunction 8 | from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction 9 | from pymilvus.model.dense.nomic import NomicEmbeddingFunction 10 | from pymilvus.model.dense.instructor import InstructorEmbeddingFunction 11 | from pymilvus.model.dense.model2vec_embed import Model2VecEmbeddingFunction 12 | from pymilvus.model.dense.gemini import GeminiEmbeddingFunction 13 | 14 | __all__ = [ 15 | "OpenAIEmbeddingFunction", 16 | "SentenceTransformerEmbeddingFunction", 17 | "VoyageEmbeddingFunction", 18 | "JinaEmbeddingFunction", 19 | "TEIEmbeddingFunction", 20 | "OnnxEmbeddingFunction", 21 | "CohereEmbeddingFunction", 22 | "MistralAIEmbeddingFunction", 23 | "NomicEmbeddingFunction", 24 | "InstructorEmbeddingFunction", 25 | "Model2VecEmbeddingFunction", 26 | "GeminiEmbeddingFunction", 27 | ] 28 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/cohere.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import struct 3 | from collections import defaultdict 4 | import numpy as np 5 | 6 | from pymilvus.model.base import BaseEmbeddingFunction 7 | from pymilvus.model.utils import import_cohere 8 | 9 | 10 | 11 | class CohereEmbeddingFunction(BaseEmbeddingFunction): 12 | def __init__(self, 13 | model_name: str = "embed-english-light-v3.0", 14 | api_key: Optional[str] = None, 15 | input_type: str = "search_document", 16 | embedding_types: Optional[List[str]] = None, 17 | truncate: Optional[str] = None, 18 | **kwargs): 19 | self.model_name = model_name 20 | self.input_type = input_type
21 | self.embedding_types = embedding_types 22 | self.truncate = truncate 23 | 24 | import_cohere() 25 | import cohere 26 | 27 | if isinstance(embedding_types, list): 28 | if len(embedding_types) > 1: 29 | raise ValueError("Only one embedding type can be specified with the current PyMilvus model library.") 30 | elif embedding_types[0] in ("int8", "uint8"): 31 | raise ValueError("The int8 and uint8 embedding types are not yet supported by the PyMilvus model library.") 32 | else: 33 | pass 34 | 35 | self.client = cohere.Client(api_key, **kwargs) 36 | self._cohereai_model_meta_info = defaultdict(dict) 37 | self._cohereai_model_meta_info["embed-english-v3.0"]["dim"] = 1024 38 | self._cohereai_model_meta_info["embed-english-light-v3.0"]["dim"] = 384 39 | self._cohereai_model_meta_info["embed-english-v2.0"]["dim"] = 4096 40 | self._cohereai_model_meta_info["embed-english-light-v2.0"]["dim"] = 1024 41 | self._cohereai_model_meta_info["embed-multilingual-v3.0"]["dim"] = 1024 42 | self._cohereai_model_meta_info["embed-multilingual-light-v3.0"]["dim"] = 384 43 | self._cohereai_model_meta_info["embed-multilingual-v2.0"]["dim"] = 768 44 | 45 | def _call_cohere_api(self, texts: List[str], input_type: str) -> List[np.array]: 46 | embeddings = self.client.embed( 47 | texts=texts, 48 | model=self.model_name, 49 | input_type=input_type, 50 | embedding_types=self.embedding_types, 51 | truncate=self.truncate 52 | ).embeddings 53 | if self.embedding_types is None: 54 | results = [np.array(data, dtype=np.float32) for data in embeddings] 55 | else: 56 | results = getattr(embeddings, self.embedding_types[0], None) 57 | if self.embedding_types[0] == "binary": 58 | results = [struct.pack('b' * len(int8_vector), *int8_vector) for int8_vector in results] 59 | elif self.embedding_types[0] == "ubinary": 60 | results = [struct.pack('B' * len(uint8_vector), *uint8_vector) for uint8_vector in results] 61 | elif self.embedding_types[0] == "float": 62 | results = [np.array(result, dtype=np.float32) for result in results] 63 | else: 64 | pass 65 | return results 66 | 67 | def encode_documents(self, documents: List[str]) -> List[np.array]: 68 | return self._call_cohere_api(documents, input_type="search_document") 69 | 70 | def encode_queries(self, queries: List[str]) -> List[np.array]: 71 | return self._call_cohere_api(queries, input_type="search_query") 72 | 73 | def __call__(self, texts: List[str]) -> List[np.array]: 74 | return self._call_cohere_api(texts, self.input_type) 75 | 76 | @property 77 | def dim(self): 78 | return self._cohereai_model_meta_info[self.model_name]["dim"] 79 | 80 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/gemini.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Optional 3 | 4 | 5 | import numpy as np 6 | 7 | from pymilvus.model.base import BaseEmbeddingFunction 8 | from pymilvus.model.utils import import_google 9 | 10 | 11 | class GeminiEmbeddingFunction(BaseEmbeddingFunction): 12 | def __init__( 13 | self, 14 | model_name: str = "gemini-embedding-exp-03-07", 15 | api_key: Optional[str] = None, 16 | config: Optional['types.EmbedContentConfig'] = None, 17 | **kwargs, 18 | ): 19 | import_google() 20 | from google import genai 21 | from google.genai import types 22 | self.model_name = model_name 23 | self.client = genai.Client(api_key=api_key, **kwargs) 24 | self.config: Optional[types.EmbedContentConfig] = config 25 | 26 |
self._gemini_model_meta_info = defaultdict(dict) 27 | self._gemini_model_meta_info["gemini-embedding-exp-03-07"]["dim"] = 3072 28 | self._gemini_model_meta_info["models/embedding-001"]["dim"] = 768 29 | self._gemini_model_meta_info["models/text-embedding-004"]["dim"] = 768 30 | 31 | def encode_queries(self, queries: List[str]) -> List[np.array]: 32 | return self._encode(queries) 33 | 34 | def encode_documents(self, documents: List[str]) -> List[np.array]: 35 | return self._encode(documents) 36 | 37 | @property 38 | def dim(self): 39 | if self.config is None or self.config.output_dimensionality is None: 40 | return self._gemini_model_meta_info[self.model_name]["dim"] 41 | else: 42 | return self.config.output_dimensionality 43 | 44 | def __call__(self, texts: List[str]) -> List[np.array]: 45 | return self._encode(texts) 46 | 47 | def _encode_query(self, query: str) -> np.array: 48 | return self._encode(query)[0] 49 | 50 | def _encode_document(self, document: str) -> np.array: 51 | return self._encode(document)[0] 52 | 53 | def _encode(self, texts: List[str]): 54 | result = self.client.models.embed_content( 55 | model=self.model_name, 56 | contents=texts, 57 | config=self.config 58 | ) 59 | return [np.array(data.values) for data in result.embeddings] 60 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/instructor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import numpy as np 3 | 4 | from pymilvus.model.base import BaseEmbeddingFunction 5 | 6 | 7 | class InstructorEmbeddingFunction(BaseEmbeddingFunction): 8 | def __init__( 9 | self, 10 | model_name: str = "hkunlp/instructor-xl", 11 | batch_size: int = 32, 12 | query_instruction: str = "Represent the question for retrieval:", 13 | doc_instruction: str = "Represent the document for retrieval:", 14 | device: str = "cpu", 15 | normalize_embeddings: bool = True, 16 | **kwargs, 17 | ): 18 | from .instructor_embedding.instructor_impl import Instructor 19 | 20 | self.model_name = model_name 21 | self.query_instruction = query_instruction 22 | self.doc_instruction = doc_instruction 23 | self.batch_size = batch_size 24 | self.normalize_embeddings = normalize_embeddings 25 | 26 | _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs) 27 | self.model = Instructor(**_model_config) 28 | 29 | def __call__(self, texts: List[str]) -> List[np.array]: 30 | return self._encode([[self.doc_instruction, text] for text in texts]) 31 | 32 | def _encode(self, texts: List[str]) -> List[np.array]: 33 | embs = self.model.encode( 34 | texts, 35 | batch_size=self.batch_size, 36 | show_progress_bar=False, 37 | convert_to_numpy=True, 38 | normalize_embeddings=self.normalize_embeddings 39 | ) 40 | return list(embs) 41 | 42 | @property 43 | def dim(self): 44 | return self.model.get_sentence_embedding_dimension() 45 | 46 | def encode_queries(self, queries: List[str]) -> List[np.array]: 47 | instructed_queries = [[self.query_instruction, query] for query in queries] 48 | return self._encode(instructed_queries) 49 | 50 | def encode_documents(self, documents: List[str]) -> List[np.array]: 51 | instructed_documents = [[self.doc_instruction, document] for document in documents] 52 | return self._encode(instructed_documents) 53 | 54 | def _encode_query(self, query: str) -> np.array: 55 | instructed_query = [self.query_instruction, query]  # (instruction, text) pair, matching encode_queries 56 | embs = self.model.encode( 57 | sentences=[instructed_query], 58 |
batch_size=1, 59 | show_progress_bar=False, 60 | convert_to_numpy=True, 61 | normalize_embeddings=self.normalize_embeddings, 62 | ) 63 | return embs[0] 64 | 65 | def _encode_document(self, document: str) -> np.array: 66 | instructed_document = [self.doc_instruction, document]  # (instruction, text) pair, matching encode_documents 67 | embs = self.model.encode( 68 | sentences=[instructed_document], 69 | batch_size=1, 70 | show_progress_bar=False, 71 | convert_to_numpy=True, 72 | normalize_embeddings=self.normalize_embeddings, 73 | ) 74 | return embs[0] 75 | 76 | 77 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/instructor_embedding/instructor_impl.py: -------------------------------------------------------------------------------- 1 | # The following code is adapted from/inspired by the 'instructor-embedding' project: 2 | # https://github.com/xlang-ai/instructor-embedding 3 | # Specifically, instructor-embedding/InstructorEmbedding/instructor.py 4 | # 5 | # License: Apache License 2.0 (January 2004) 6 | # 7 | # For more information on the original project, visit the GitHub repository: 8 | # https://github.com/xlang-ai/instructor-embedding 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at: 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 21 | 22 | # This script is based on the modifications from https://github.com/UKPLab/sentence-transformers 23 | import importlib 24 | import json 25 | import os 26 | from collections import OrderedDict 27 | from typing import Union 28 | 29 | from pymilvus.model.utils import import_sentence_transformers, import_huggingface_hub, import_torch 30 | 31 | import_sentence_transformers() 32 | import_huggingface_hub() 33 | import_torch() 34 | 35 | import numpy as np 36 | import torch 37 | from sentence_transformers import SentenceTransformer 38 | from sentence_transformers.models import Transformer 39 | from torch import Tensor, nn 40 | from transformers import AutoConfig, AutoTokenizer 41 | from sentence_transformers.util import disabled_tqdm 42 | from huggingface_hub import snapshot_download 43 | 44 | 45 | def batch_to_device(batch, target_device: str): 46 | for key in batch: 47 | if isinstance(batch[key], Tensor): 48 | batch[key] = batch[key].to(target_device) 49 | return batch 50 | 51 | 52 | class InstructorPooling(nn.Module): 53 | """Performs pooling (max or mean) on the token embeddings. 54 | 55 | Using pooling, it generates from a variable-sized sentence a fixed-size sentence embedding. 56 | This layer also allows using the CLS token if it is returned by the underlying word embedding model. 57 | You can concatenate multiple poolings together. 58 | 59 | :param word_embedding_dimension: Dimensions for the word embeddings 60 | :param pooling_mode: Can be a string: mean/max/cls. If set, overwrites the other pooling_mode_* settings 61 | :param pooling_mode_cls_token: Use the first token (CLS token) as text representations 62 | :param pooling_mode_max_tokens: Use max in each dimension over all tokens.
63 | :param pooling_mode_mean_tokens: Perform mean-pooling 64 | :param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length). 65 | :param pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling, 66 | see https://arxiv.org/abs/2202.08904 67 | :param pooling_mode_lasttoken: Perform last token pooling, 68 | see https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005 69 | """ 70 | 71 | def __init__( 72 | self, 73 | word_embedding_dimension: int, 74 | pooling_mode: Union[str, None] = None, 75 | pooling_mode_cls_token: bool = False, 76 | pooling_mode_max_tokens: bool = False, 77 | pooling_mode_mean_tokens: bool = True, 78 | pooling_mode_mean_sqrt_len_tokens: bool = False, 79 | pooling_mode_weightedmean_tokens: bool = False, 80 | pooling_mode_lasttoken: bool = False, 81 | ): 82 | super().__init__() 83 | 84 | self.config_keys = [ 85 | "word_embedding_dimension", 86 | "pooling_mode_cls_token", 87 | "pooling_mode_mean_tokens", 88 | "pooling_mode_max_tokens", 89 | "pooling_mode_mean_sqrt_len_tokens", 90 | "pooling_mode_weightedmean_tokens", 91 | "pooling_mode_lasttoken", 92 | ] 93 | 94 | if pooling_mode is not None: # Set pooling mode by string 95 | pooling_mode = pooling_mode.lower() 96 | assert pooling_mode in ["mean", "max", "cls", "weightedmean", "lasttoken"] 97 | pooling_mode_cls_token = pooling_mode == "cls" 98 | pooling_mode_max_tokens = pooling_mode == "max" 99 | pooling_mode_mean_tokens = pooling_mode == "mean" 100 | pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean" 101 | pooling_mode_lasttoken = pooling_mode == "lasttoken" 102 | 103 | self.word_embedding_dimension = word_embedding_dimension 104 | self.pooling_mode_cls_token = pooling_mode_cls_token 105 | self.pooling_mode_mean_tokens = pooling_mode_mean_tokens 106 | self.pooling_mode_max_tokens = pooling_mode_max_tokens 107 | self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens 108 | self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens 109 | self.pooling_mode_lasttoken = pooling_mode_lasttoken 110 | 111 | pooling_mode_multiplier = sum( 112 | [ 113 | pooling_mode_cls_token, 114 | pooling_mode_max_tokens, 115 | pooling_mode_mean_tokens, 116 | pooling_mode_mean_sqrt_len_tokens, 117 | pooling_mode_weightedmean_tokens, 118 | pooling_mode_lasttoken, 119 | ] 120 | ) 121 | self.pooling_output_dimension = ( 122 | pooling_mode_multiplier * word_embedding_dimension 123 | ) 124 | 125 | def __repr__(self): 126 | return f"Pooling({self.get_config_dict()})" 127 | 128 | def get_pooling_mode_str(self) -> str: 129 | """ 130 | Returns the pooling mode as string 131 | """ 132 | modes = [] 133 | if self.pooling_mode_cls_token: 134 | modes.append("cls") 135 | if self.pooling_mode_mean_tokens: 136 | modes.append("mean") 137 | if self.pooling_mode_max_tokens: 138 | modes.append("max") 139 | if self.pooling_mode_mean_sqrt_len_tokens: 140 | modes.append("mean_sqrt_len_tokens") 141 | if self.pooling_mode_weightedmean_tokens: 142 | modes.append("weightedmean") 143 | if self.pooling_mode_lasttoken: 144 | modes.append("lasttoken") 145 | 146 | return "+".join(modes) 147 | 148 | def forward(self, features): 149 | # print(features.keys()) 150 | token_embeddings = features["token_embeddings"] 151 | attention_mask = features["attention_mask"] 152 | 153 | ## Pooling strategy 154 | output_vectors = [] 155 | if self.pooling_mode_cls_token: 156 | cls_token = features.get( 157 | "cls_token_embeddings", token_embeddings[:, 0] 158 | ) # Take first token by 
default 159 | output_vectors.append(cls_token) 160 | if self.pooling_mode_max_tokens: 161 | input_mask_expanded = ( 162 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 163 | ) 164 | token_embeddings[ 165 | input_mask_expanded == 0 166 | ] = -1e9 # Set padding tokens to large negative value 167 | max_over_time = torch.max(token_embeddings, 1)[0] 168 | output_vectors.append(max_over_time) 169 | if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens: 170 | input_mask_expanded = ( 171 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 172 | ) 173 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 174 | 175 | # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present 176 | if "token_weights_sum" in features: 177 | sum_mask = ( 178 | features["token_weights_sum"] 179 | .unsqueeze(-1) 180 | .expand(sum_embeddings.size()) 181 | ) 182 | else: 183 | sum_mask = input_mask_expanded.sum(1) 184 | 185 | sum_mask = torch.clamp(sum_mask, min=1e-9) 186 | 187 | if self.pooling_mode_mean_tokens: 188 | output_vectors.append(sum_embeddings / sum_mask) 189 | if self.pooling_mode_mean_sqrt_len_tokens: 190 | output_vectors.append(sum_embeddings / torch.sqrt(sum_mask)) 191 | if self.pooling_mode_weightedmean_tokens: 192 | input_mask_expanded = ( 193 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 194 | ) 195 | # token_embeddings shape: bs, seq, hidden_dim 196 | weights = ( 197 | torch.arange(start=1, end=token_embeddings.shape[1] + 1) 198 | .unsqueeze(0) 199 | .unsqueeze(-1) 200 | .expand(token_embeddings.size()) 201 | .float() 202 | .to(token_embeddings.device) 203 | ) 204 | assert weights.shape == token_embeddings.shape == input_mask_expanded.shape 205 | input_mask_expanded = input_mask_expanded * weights 206 | 207 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 208 | 209 | # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present 210 | if "token_weights_sum" in features: 211 | sum_mask = ( 212 | features["token_weights_sum"] 213 | .unsqueeze(-1) 214 | .expand(sum_embeddings.size()) 215 | ) 216 | else: 217 | sum_mask = input_mask_expanded.sum(1) 218 | 219 | sum_mask = torch.clamp(sum_mask, min=1e-9) 220 | output_vectors.append(sum_embeddings / sum_mask) 221 | if self.pooling_mode_lasttoken: 222 | batch_size, _, hidden_dim = token_embeddings.shape 223 | # attention_mask shape: (bs, seq_len) 224 | # Get shape [bs] indices of the last token (i.e. 
the last token for each batch item) 225 | # argmin gives us the index of the first 0 in the attention mask; 226 | # We get the last 1 index by subtracting 1 227 | gather_indices = ( 228 | torch.argmin(attention_mask, 1, keepdim=False) - 1 229 | ) # Shape [bs] 230 | 231 | # There are empty sequences, where the index would become -1 which will crash 232 | gather_indices = torch.clamp(gather_indices, min=0) 233 | 234 | # Turn indices from shape [bs] --> [bs, 1, hidden_dim] 235 | gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim) 236 | gather_indices = gather_indices.unsqueeze(1) 237 | assert gather_indices.shape == (batch_size, 1, hidden_dim) 238 | 239 | # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim) 240 | # Actually no need for the attention mask as we gather the last token where attn_mask = 1 241 | # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we 242 | # use the attention mask to ignore them again 243 | input_mask_expanded = ( 244 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 245 | ) 246 | embedding = torch.gather( 247 | token_embeddings * input_mask_expanded, 1, gather_indices 248 | ).squeeze(dim=1) 249 | output_vectors.append(embedding) 250 | 251 | output_vector = torch.cat(output_vectors, 1) 252 | features.update({"sentence_embedding": output_vector}) 253 | return features 254 | 255 | def get_sentence_embedding_dimension(self): 256 | return self.pooling_output_dimension 257 | 258 | def get_config_dict(self): 259 | return {key: self.__dict__[key] for key in self.config_keys} 260 | 261 | def save(self, output_path): 262 | with open( 263 | os.path.join(output_path, "config.json"), "w", encoding="UTF-8" 264 | ) as config_file: 265 | json.dump(self.get_config_dict(), config_file, indent=2) 266 | 267 | @staticmethod 268 | def load(input_path): 269 | with open( 270 | os.path.join(input_path, "config.json"), encoding="UTF-8" 271 | ) as config_file: 272 | config = json.load(config_file) 273 | 274 | return InstructorPooling(**config) 275 | 276 | 277 | def import_from_string(dotted_path): 278 | """ 279 | Import a dotted module path and return the attribute/class designated by the 280 | last name in the path. Raise ImportError if the import failed. 
281 | """ 282 | try: 283 | module_path, class_name = dotted_path.rsplit(".", 1) 284 | except ValueError: 285 | msg = f"{dotted_path} doesn't look like a module path" 286 | raise ImportError(msg) 287 | 288 | try: 289 | module = importlib.import_module(dotted_path) 290 | except: 291 | module = importlib.import_module(module_path) 292 | 293 | try: 294 | return getattr(module, class_name) 295 | except AttributeError: 296 | msg = f"Module {module_path} does not define a {class_name} attribute/class" 297 | raise ImportError(msg) 298 | 299 | 300 | class InstructorTransformer(Transformer): 301 | def __init__( 302 | self, 303 | model_name_or_path: str, 304 | max_seq_length=None, 305 | model_args=None, 306 | cache_dir=None, 307 | tokenizer_args=None, 308 | do_lower_case: bool = False, 309 | tokenizer_name_or_path: Union[str, None] = None, 310 | load_model: bool = True, 311 | ): 312 | super().__init__(model_name_or_path) 313 | if model_args is None: 314 | model_args = {} 315 | if tokenizer_args is None: 316 | tokenizer_args = {} 317 | self.config_keys = ["max_seq_length", "do_lower_case"] 318 | self.do_lower_case = do_lower_case 319 | self.model_name_or_path = model_name_or_path 320 | if model_name_or_path == "bi-contriever": 321 | model_name_or_path = "facebook/contriever" 322 | if model_name_or_path.startswith("bigtr"): 323 | model_name_or_path = model_name_or_path.split("#")[1] 324 | if "bigtr" in model_name_or_path and os.path.isdir(model_name_or_path): 325 | config = AutoConfig.from_pretrained( 326 | os.path.join(model_name_or_path, "with_prompt"), 327 | **model_args, 328 | cache_dir=cache_dir, 329 | ) 330 | else: 331 | config = AutoConfig.from_pretrained( 332 | model_name_or_path, **model_args, cache_dir=cache_dir 333 | ) 334 | 335 | if load_model: 336 | self._load_model(self.model_name_or_path, config, cache_dir, **model_args) 337 | self.tokenizer = AutoTokenizer.from_pretrained( 338 | tokenizer_name_or_path 339 | if tokenizer_name_or_path is not None 340 | else model_name_or_path, 341 | cache_dir=cache_dir, 342 | **tokenizer_args, 343 | ) 344 | 345 | if max_seq_length is None: 346 | if ( 347 | hasattr(self.auto_model, "config") 348 | and hasattr(self.auto_model.config, "max_position_embeddings") 349 | and hasattr(self.tokenizer, "model_max_length") 350 | ): 351 | max_seq_length = min( 352 | self.auto_model.config.max_position_embeddings, 353 | self.tokenizer.model_max_length, 354 | ) 355 | 356 | self.max_seq_length = max_seq_length 357 | if tokenizer_name_or_path is not None: 358 | self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__ 359 | 360 | def forward(self, features): 361 | input_features = { 362 | "input_ids": features["input_ids"], 363 | "attention_mask": features["attention_mask"], 364 | } 365 | if "token_type_ids" in features: 366 | input_features["token_type_ids"] = features["token_type_ids"] 367 | 368 | instruction_mask = features["instruction_mask"] 369 | output_states = self.auto_model(**input_features, return_dict=False) 370 | output_tokens = output_states[0] 371 | attention_mask = features["attention_mask"] 372 | instruction_mask = features["instruction_mask"] 373 | attention_mask = attention_mask * instruction_mask 374 | features.update( 375 | {"token_embeddings": output_tokens, "attention_mask": attention_mask} 376 | ) 377 | 378 | if self.auto_model.config.output_hidden_states: 379 | all_layer_idx = 2 380 | if ( 381 | len(output_states) < 3 382 | ): # Some models only output last_hidden_states and all_hidden_states 383 | all_layer_idx = 1 384 | 
hidden_states = output_states[all_layer_idx] 385 | features.update({"all_layer_embeddings": hidden_states}) 386 | 387 | return features 388 | 389 | @staticmethod 390 | def load(input_path: str): 391 | # Old classes used other config names than 'sentence_bert_config.json' 392 | for config_name in [ 393 | "sentence_bert_config.json", 394 | "sentence_roberta_config.json", 395 | "sentence_distilbert_config.json", 396 | "sentence_camembert_config.json", 397 | "sentence_albert_config.json", 398 | "sentence_xlm-roberta_config.json", 399 | "sentence_xlnet_config.json", 400 | ]: 401 | sbert_config_path = os.path.join(input_path, config_name) 402 | if os.path.exists(sbert_config_path): 403 | break 404 | 405 | with open(sbert_config_path, encoding="UTF-8") as config_file: 406 | config = json.load(config_file) 407 | return InstructorTransformer(model_name_or_path=input_path, **config) 408 | 409 | def tokenize(self, texts): 410 | """ 411 | Tokenizes a text and maps tokens to token-ids 412 | """ 413 | output = {} 414 | if isinstance(texts[0], str): 415 | to_tokenize = [texts] 416 | to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] 417 | 418 | # Lowercase 419 | if self.do_lower_case: 420 | to_tokenize = [[s.lower() for s in col] for col in to_tokenize] 421 | 422 | input_features = self.tokenizer( 423 | *to_tokenize, 424 | padding="max_length", 425 | truncation="longest_first", 426 | return_tensors="pt", 427 | max_length=self.max_seq_length, 428 | ) 429 | 430 | elif isinstance(texts[0], list): 431 | assert isinstance(texts[0][1], str) 432 | assert ( 433 | len(texts[0]) == 2 434 | ), "The input should have both instruction and input text" 435 | 436 | instructions = [] 437 | instruction_prepended_input_texts = [] 438 | for pair in texts: 439 | instruction = pair[0].strip() 440 | text = pair[1].strip() 441 | if self.do_lower_case: 442 | instruction = instruction.lower() 443 | text = text.lower() 444 | instructions.append(instruction) 445 | instruction_prepended_input_texts.append("".join([instruction, text])) 446 | 447 | input_features = self.tokenize(instruction_prepended_input_texts) 448 | instruction_features = self.tokenize(instructions) 449 | input_features = Instructor.prepare_input_features( 450 | input_features, instruction_features 451 | ) 452 | else: 453 | raise ValueError("Unsupported input format: expected strings or [instruction, text] pairs") 454 | 455 | output.update(input_features) 456 | return output 457 | 458 | 459 | class Instructor(SentenceTransformer): 460 | @staticmethod 461 | def prepare_input_features( 462 | input_features, instruction_features, return_data_type: str = "pt" 463 | ): 464 | if return_data_type == "np": 465 | input_features["attention_mask"] = torch.from_numpy( 466 | input_features["attention_mask"] 467 | ) 468 | instruction_features["attention_mask"] = torch.from_numpy( 469 | instruction_features["attention_mask"] 470 | ) 471 | 472 | input_attention_mask_shape = input_features["attention_mask"].shape 473 | instruction_attention_mask = instruction_features["attention_mask"] 474 | 475 | # reducing the attention length by 1 in order to omit the attention corresponding to the end_token 476 | instruction_attention_mask = instruction_attention_mask[:, 1:] 477 | 478 | # creating an instruction attention matrix equal in size to the input attention matrix 479 | expanded_instruction_attention_mask = torch.zeros( 480 | input_attention_mask_shape, dtype=torch.int64 481 | ) 482 | # assigning the actual instruction attention matrix to the expanded_instruction_attention_mask 483 | # eg: 484 | #
instruction_attention_mask: 3x3 485 | # [[1,1,1], 486 | # [1,1,0], 487 | # [1,0,0]] 488 | # expanded_instruction_attention_mask: 3x4 489 | # [[1,1,1,0], 490 | # [1,1,0,0], 491 | # [1,0,0,0]] 492 | expanded_instruction_attention_mask[ 493 | : instruction_attention_mask.size(0), : instruction_attention_mask.size(1) 494 | ] = instruction_attention_mask 495 | 496 | # In the pooling layer we want to consider only the tokens corresponding to the input text 497 | # and not the instruction. This is achieved by inverting the 498 | # attention_mask corresponding to the instruction. 499 | expanded_instruction_attention_mask = 1 - expanded_instruction_attention_mask 500 | input_features["instruction_mask"] = expanded_instruction_attention_mask 501 | if return_data_type == "np": 502 | input_features["attention_mask"] = input_features["attention_mask"].numpy() 503 | instruction_features["attention_mask"] = instruction_features[ 504 | "attention_mask" 505 | ].numpy() 506 | return input_features 507 | 508 | def smart_batching_collate(self, batch): 509 | num_texts = len(batch[0].texts) 510 | texts = [[] for _ in range(num_texts)] 511 | labels = [] 512 | 513 | for example in batch: 514 | for idx, text in enumerate(example.texts): 515 | texts[idx].append(text) 516 | labels.append(example.label) 517 | 518 | labels = torch.tensor(labels) 519 | batched_input_features = [] 520 | 521 | for idx in range(num_texts): 522 | assert isinstance(texts[idx][0], list) 523 | assert ( 524 | len(texts[idx][0]) == 2 525 | ), "The input should have both instruction and input text" 526 | 527 | num = len(texts[idx]) 528 | instructions = [] 529 | instruction_prepended_input_texts = [] 530 | for local_idx in range(num): 531 | assert len(texts[idx][local_idx]) == 2 532 | instructions.append(texts[idx][local_idx][0]) 533 | instruction_prepended_input_texts.append("".join(texts[idx][local_idx])) 534 | assert isinstance(instructions[-1], str) 535 | assert isinstance(instruction_prepended_input_texts[-1], str) 536 | 537 | input_features = self.tokenize(instruction_prepended_input_texts) 538 | instruction_features = self.tokenize(instructions) 539 | input_features = Instructor.prepare_input_features( 540 | input_features, instruction_features 541 | ) 542 | batched_input_features.append(input_features) 543 | 544 | return batched_input_features, labels 545 | 546 | def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False, **kwargs): 547 | """ 548 | Loads a full sentence-transformers model 549 | """ 550 | # Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544 551 | download_kwargs = { 552 | "repo_id": model_path, 553 | "revision": revision, 554 | "library_name": "sentence-transformers", 555 | "token": token, 556 | "cache_dir": cache_folder, 557 | "tqdm_class": disabled_tqdm, 558 | } 559 | model_path = snapshot_download(**download_kwargs) 560 | 561 | # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework) 562 | config_sentence_transformers_json_path = os.path.join( 563 | model_path, "config_sentence_transformers.json" 564 | ) 565 | if os.path.exists(config_sentence_transformers_json_path): 566 | with open( 567 | config_sentence_transformers_json_path, encoding="UTF-8" 568 | ) as config_file: 569 | self._model_config = json.load(config_file) 570 | 571 | # Check if a readme exists 572 | model_card_path = os.path.join(model_path, "README.md") 573 | if 
os.path.exists(model_card_path): 574 | try: 575 | with open(model_card_path, encoding="utf8") as config_file: 576 | self._model_card_text = config_file.read() 577 | except Exception: 578 | pass 579 | 580 | # Load the modules of sentence transformer 581 | modules_json_path = os.path.join(model_path, "modules.json") 582 | with open(modules_json_path, encoding="UTF-8") as config_file: 583 | modules_config = json.load(config_file) 584 | 585 | modules = OrderedDict() 586 | for module_config in modules_config: 587 | if module_config["idx"] == 0: 588 | module_class = InstructorTransformer 589 | elif module_config["idx"] == 1: 590 | module_class = InstructorPooling 591 | else: 592 | module_class = import_from_string(module_config["type"]) 593 | module = module_class.load(os.path.join(model_path, module_config["path"])) 594 | modules[module_config["name"]] = module 595 | 596 | return modules 597 | 598 | def encode( 599 | self, 600 | sentences, 601 | batch_size: int = 32, 602 | show_progress_bar: Union[bool, None] = None, 603 | output_value: str = "sentence_embedding", 604 | convert_to_numpy: bool = True, 605 | convert_to_tensor: bool = False, 606 | device: Union[str, None] = None, 607 | normalize_embeddings: bool = False, 608 | ): 609 | """ 610 | Computes sentence embeddings 611 | 612 | :param sentences: the sentences to embed 613 | :param batch_size: the batch size used for the computation 614 | :param show_progress_bar: Whether to display a progress bar while encoding sentences 615 | :param output_value: Default sentence_embedding, to get sentence embeddings. 616 | Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values 617 | :param convert_to_numpy: If true, the output is a list of numpy vectors. 618 | Else, it is a list of pytorch tensors. 619 | :param convert_to_tensor: If true, a single stacked tensor is returned. 620 | Overrides any setting from convert_to_numpy 621 | :param device: Which torch.device to use for the computation 622 | :param normalize_embeddings: If set to true, returned vectors will have length 1. 623 | In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity. 624 | 625 | :return: 626 | By default, a list of tensors is returned. If convert_to_tensor, 627 | a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned. 
628 | """ 629 | self.eval() 630 | if show_progress_bar is None: 631 | show_progress_bar = False 632 | 633 | if convert_to_tensor: 634 | convert_to_numpy = False 635 | 636 | if output_value != "sentence_embedding": 637 | convert_to_tensor = False 638 | convert_to_numpy = False 639 | 640 | input_was_string = False 641 | if isinstance(sentences, str) or not hasattr( 642 | sentences, "__len__" 643 | ): # Cast an individual sentence to a list with length 1 644 | sentences = [sentences] 645 | input_was_string = True 646 | 647 | if device is None: 648 | device = self.device 649 | 650 | self.to(device) 651 | 652 | all_embeddings = [] 653 | if isinstance(sentences[0], list): 654 | lengths = [] 655 | for sen in sentences: 656 | lengths.append(-self._text_length(sen[1])) 657 | length_sorted_idx = np.argsort(lengths) 658 | else: 659 | length_sorted_idx = np.argsort( 660 | [-self._text_length(sen) for sen in sentences] 661 | ) 662 | sentences_sorted = [sentences[idx] for idx in length_sorted_idx] 663 | 664 | for start_index in range(0, len(sentences), batch_size): 665 | sentences_batch = sentences_sorted[start_index : start_index + batch_size] 666 | features = self.tokenize(sentences_batch) 667 | features = batch_to_device(features, device) 668 | 669 | with torch.no_grad(): 670 | out_features = self.forward(features) 671 | 672 | if output_value == "token_embeddings": 673 | embeddings = [] 674 | for token_emb, attention in zip( 675 | out_features[output_value], out_features["attention_mask"] 676 | ): 677 | last_mask_id = len(attention) - 1 678 | while last_mask_id > 0 and attention[last_mask_id].item() == 0: 679 | last_mask_id -= 1 680 | 681 | embeddings.append(token_emb[0 : last_mask_id + 1]) 682 | elif output_value is None: # Return all outputs 683 | embeddings = [] 684 | for sent_idx in range(len(out_features["sentence_embedding"])): 685 | row = { 686 | name: out_features[name][sent_idx] for name in out_features 687 | } 688 | embeddings.append(row) 689 | else: # Sentence embeddings 690 | embeddings = out_features[output_value] 691 | embeddings = embeddings.detach() 692 | if normalize_embeddings: 693 | embeddings = torch.nn.functional.normalize( 694 | embeddings, p=2, dim=1 695 | ) 696 | 697 | # fixes for #522 and #487 to avoid oom problems on gpu with large datasets 698 | if convert_to_numpy: 699 | embeddings = embeddings.cpu() 700 | 701 | all_embeddings.extend(embeddings) 702 | 703 | all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] 704 | 705 | if convert_to_tensor: 706 | all_embeddings = torch.stack(all_embeddings) 707 | elif convert_to_numpy: 708 | all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) 709 | 710 | if input_was_string: 711 | all_embeddings = all_embeddings[0] 712 | 713 | return all_embeddings 714 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/jinaai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import requests 6 | 7 | from pymilvus.model.base import BaseEmbeddingFunction 8 | 9 | API_URL = "https://api.jina.ai/v1/embeddings" 10 | 11 | 12 | class JinaEmbeddingFunction(BaseEmbeddingFunction): 13 | def __init__( 14 | self, 15 | model_name: str = "jina-embeddings-v3", 16 | api_key: Optional[str] = None, 17 | task: str = 'retrieval.passage', 18 | dimensions: Optional[int] = None, 19 | late_chunking: Optional[bool] = False, 20 | **kwargs, 21 | ): 22 | if api_key is 
None: 23 | if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]: 24 | self.api_key = os.environ["JINAAI_API_KEY"] 25 | else: 26 | error_message = ( 27 | "Did not find api_key, please add an environment variable" 28 | " `JINAAI_API_KEY` which contains it, or pass" 29 | " `api_key` as a named parameter." 30 | ) 31 | raise ValueError(error_message) 32 | else: 33 | self.api_key = api_key 34 | self.model_name = model_name 35 | self._session = requests.Session() 36 | self._session.headers.update( 37 | {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} 38 | ) 39 | self.task = task 40 | self._dim = dimensions 41 | self.late_chunking = late_chunking 42 | 43 | @property 44 | def dim(self): 45 | if self._dim is None: 46 | self._dim = self._call_jina_api([""])[0].shape[0] 47 | return self._dim 48 | 49 | def encode_queries(self, queries: List[str]) -> List[np.array]: 50 | return self._call_jina_api(queries, task='retrieval.query') 51 | 52 | def encode_documents(self, documents: List[str]) -> List[np.array]: 53 | return self._call_jina_api(documents, task='retrieval.passage') 54 | 55 | def __call__(self, texts: List[str]) -> List[np.array]: 56 | return self._call_jina_api(texts, task=self.task) 57 | 58 | def _call_jina_api(self, texts: List[str], task: Optional[str] = None): 59 | data = { 60 | "input": texts, 61 | "model": self.model_name, 62 | "task": task, 63 | "late_chunking": self.late_chunking, 64 | } 65 | if self._dim is not None: 66 | data["dimensions"] = self._dim 67 | resp = self._session.post( # type: ignore[assignment] 68 | API_URL, 69 | json=data, 70 | ).json() 71 | if "data" not in resp: 72 | raise RuntimeError(resp["detail"]) 73 | 74 | embeddings = resp["data"] 75 | 76 | # Sort resulting embeddings by index 77 | sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return] 78 | return [np.array(result["embedding"]) for result in sorted_embeddings] 79 | --------------------------------------------------------------------------------
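A minimal usage sketch for JinaEmbeddingFunction (illustrative only, not part of the package source; it assumes a valid JINAAI_API_KEY is exported in the environment):

    from pymilvus.model.dense.jinaai import JinaEmbeddingFunction

    # With api_key left unset, the constructor falls back to JINAAI_API_KEY
    jina_ef = JinaEmbeddingFunction(model_name="jina-embeddings-v3")
    docs = ["Milvus is a vector database.", "BM25 is a lexical ranking function."]
    doc_vecs = jina_ef.encode_documents(docs)                 # task='retrieval.passage'
    query_vecs = jina_ef.encode_queries(["What is Milvus?"])  # task='retrieval.query'
    print(len(doc_vecs), doc_vecs[0].shape)                   # one numpy vector per input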
/src/pymilvus/model/dense/mistralai.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import os 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | from pymilvus.model.base import BaseEmbeddingFunction 7 | from pymilvus.model.utils import import_mistralai 8 | 9 | 10 | 11 | class MistralAIEmbeddingFunction(BaseEmbeddingFunction): 12 | def __init__( 13 | self, 14 | api_key: Optional[str] = None, 15 | model_name: str = "mistral-embed", 16 | **kwargs, 17 | ): 18 | import_mistralai() 19 | from mistralai import Mistral 20 | 21 | self._mistral_model_meta_info = defaultdict(dict) 22 | self._mistral_model_meta_info[model_name]["dim"] = 1024 # fixed dimension 23 | 24 | if api_key is None: 25 | if "MISTRALAI_API_KEY" in os.environ and os.environ["MISTRALAI_API_KEY"]: 26 | self.api_key = os.environ["MISTRALAI_API_KEY"] 27 | else: 28 | error_message = ( 29 | "Did not find api_key, please add an environment variable" 30 | " `MISTRALAI_API_KEY` which contains it, or pass" 31 | " `api_key` as a named parameter." 32 | ) 33 | raise ValueError(error_message) 34 | else: 35 | self.api_key = api_key 36 | self.model_name = model_name 37 | self.client = Mistral(api_key=self.api_key) 38 | self._encode_config = {"model": model_name, **kwargs} 39 | 40 | def encode_queries(self, queries: List[str]) -> List[np.array]: 41 | return self._encode(queries) 42 | 43 | def encode_documents(self, documents: List[str]) -> List[np.array]: 44 | return self._encode(documents) 45 | 46 | @property 47 | def dim(self): 48 | return self._mistral_model_meta_info[self.model_name]["dim"] 49 | 50 | def __call__(self, texts: List[str]) -> List[np.array]: 51 | return self._encode(texts) 52 | 53 | def _encode_query(self, query: str) -> np.array: 54 | return self._encode([query])[0] 55 | 56 | def _encode_document(self, document: str) -> np.array: 57 | return self._encode([document])[0] 58 | 59 | def _call_mistral_api(self, texts: List[str]): 60 | embeddings_batch_response = self.client.embeddings.create( 61 | inputs=texts, 62 | **self._encode_config 63 | ) 64 | return [np.array(data.embedding) for data in embeddings_batch_response.data] 65 | 66 | def _encode(self, texts: List[str]): 67 | return self._call_mistral_api(texts) 68 | --------------------------------------------------------------------------------
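A minimal usage sketch for MistralAIEmbeddingFunction (illustrative only, not part of the package source; assumes MISTRALAI_API_KEY is exported; mistral-embed returns fixed 1024-dimensional vectors):

    from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction

    # With no explicit api_key, the constructor falls back to MISTRALAI_API_KEY
    mistral_ef = MistralAIEmbeddingFunction()
    vecs = mistral_ef.encode_documents(["Embeddings map text to vectors."])
    print(mistral_ef.dim, vecs[0].shape)  # 1024, (1024,)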
23 | """ 24 | import_model2vec() 25 | from model2vec import StaticModel 26 | 27 | self.model_source = model_source 28 | model_path = Path(model_source) 29 | 30 | if model_path.exists() and model_path.is_dir(): 31 | self.model = StaticModel.load_local(model_path) 32 | else: 33 | self.model = StaticModel.from_pretrained(model_source, **kwargs) 34 | 35 | dummy_embedding = self.model.encode(["dummy"]) 36 | self._dim = dummy_embedding[0].shape[0] 37 | 38 | @property 39 | def dim(self) -> int: 40 | return self._dim 41 | 42 | def encode_queries(self, queries: List[str]) -> List[np.array]: 43 | return self._encode(queries) 44 | 45 | def encode_documents(self, documents: List[str]) -> List[np.array]: 46 | return self._encode(documents) 47 | 48 | def _encode_query(self, query: str) -> np.array: 49 | return self._encode([query])[0] 50 | 51 | def _encode_document(self, document: str) -> np.array: 52 | return self._encode([document])[0] 53 | 54 | def __call__(self, texts: List[str]) -> List[np.array]: 55 | return self._encode(texts) 56 | 57 | def _encode(self, texts: List[str]) -> List[np.array]: 58 | embeddings = self.model.encode(texts) 59 | return [embeddings[i] for i in range(embeddings.shape[0])] -------------------------------------------------------------------------------- /src/pymilvus/model/dense/nomic.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import os 4 | from collections import defaultdict 5 | 6 | from pymilvus.model.base import BaseEmbeddingFunction 7 | from pymilvus.model.utils import import_nomic 8 | 9 | 10 | class NomicEmbeddingFunction(BaseEmbeddingFunction): 11 | def __init__( 12 | self, 13 | model_name: str = "nomic-embed-text-v1.5", 14 | task_type: str = "search_document", 15 | dimensions: int = 768, 16 | **kwargs, 17 | ): 18 | self._nomic_model_meta_info = defaultdict(dict) 19 | self._nomic_model_meta_info[model_name]["dim"] = dimensions # set the dimension 20 | 21 | self.model_name = model_name 22 | self.task_type = task_type 23 | self.dimensionality = dimensions 24 | if "dimensionality" in kwargs: 25 | self.dimensionality = kwargs["dimensionality"] 26 | kwargs.pop("dimensionality") 27 | 28 | self._encode_config = { 29 | "model": model_name, 30 | "task_type": task_type, 31 | "dimensionality": self.dimensionality, 32 | **kwargs, 33 | } 34 | 35 | def encode_queries(self, queries: List[str]) -> List[np.array]: 36 | return self._encode(queries, task_type="search_query") 37 | 38 | def encode_documents(self, documents: List[str]) -> List[np.array]: 39 | return self._encode(documents, task_type="search_document") 40 | 41 | @property 42 | def dim(self): 43 | return self._nomic_model_meta_info[self.model_name]["dim"] 44 | 45 | def __call__(self, texts: List[str]) -> List[np.array]: 46 | return self._encode(texts, task_type=self.task_type) 47 | 48 | def _encode_query(self, query: str) -> np.array: 49 | return self._encode([query], task_type="search_query")[0] 50 | 51 | def _encode_document(self, document: str) -> np.array: 52 | return self._encode([document], task_type="search_document")[0] 53 | 54 | def _call_nomic_api(self, texts: List[str], task_type: str): 55 | import_nomic() 56 | from nomic import embed 57 | 58 | embeddings_batch_response = embed.text( 59 | texts=texts, 60 | **self._encode_config 61 | ) 62 | return [np.array(embedding) for embedding in embeddings_batch_response["embeddings"]] 63 | 64 | def _encode(self, texts: List[str], task_type: str): 65 | return 
self._call_nomic_api(texts, task_type) 66 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/onnx.py: -------------------------------------------------------------------------------- 1 | 2 | import onnxruntime 3 | 4 | from transformers import AutoTokenizer, AutoConfig 5 | import numpy as np 6 | from typing import List 7 | 8 | from pymilvus.model.base import BaseEmbeddingFunction 9 | from pymilvus.model.utils import import_huggingface_hub 10 | 11 | 12 | class OnnxEmbeddingFunction(BaseEmbeddingFunction): 13 | def __init__(self, model_name: str = "GPTCache/paraphrase-albert-onnx", tokenizer_name: str = "GPTCache/paraphrase-albert-small-v2"): 14 | import_huggingface_hub() 15 | from huggingface_hub import hf_hub_download 16 | 17 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 18 | self.model_name = model_name 19 | onnx_model_path = hf_hub_download(repo_id=model_name, filename="model.onnx") 20 | self.ort_session = onnxruntime.InferenceSession(onnx_model_path) 21 | config = AutoConfig.from_pretrained( 22 | tokenizer_name 23 | ) 24 | self.__dimension = config.hidden_size 25 | 26 | def __call__(self, texts: List[str]) -> List[np.array]: 27 | return self._encode(texts) 28 | 29 | def encode_queries(self, queries: List[str]) -> List[np.array]: 30 | return self._encode(queries) 31 | 32 | def encode_documents(self, documents: List[str]) -> List[np.array]: 33 | return self._encode(documents) 34 | 35 | def _encode(self, texts: List[str]) -> List[np.array]: 36 | return [self._to_embedding(text) for text in texts] 37 | 38 | def _to_embedding(self, data: str, **_): 39 | encoded_text = self.tokenizer.encode_plus(data, padding="max_length", truncation=True) 40 | 41 | ort_inputs = { 42 | "input_ids": np.array(encoded_text["input_ids"]).astype("int64").reshape(1, -1), 43 | "attention_mask": np.array(encoded_text["attention_mask"]).astype("int64").reshape(1, -1), 44 | "token_type_ids": np.array(encoded_text["token_type_ids"]).astype("int64").reshape(1, -1), 45 | } 46 | 47 | ort_outputs = self.ort_session.run(None, ort_inputs) 48 | ort_feat = ort_outputs[0] 49 | emb = self._post_proc(ort_feat, ort_inputs["attention_mask"]) 50 | emb = emb.flatten() 51 | return emb / np.linalg.norm(emb) 52 | 53 | def _post_proc(self, token_embeddings, attention_mask): 54 | input_mask_expanded = ( 55 | np.expand_dims(attention_mask, -1) 56 | .repeat(token_embeddings.shape[-1], -1) 57 | .astype(float) 58 | ) 59 | sentence_embs = np.sum(token_embeddings * input_mask_expanded, 1) / np.maximum( 60 | input_mask_expanded.sum(1), 1e-9 61 | ) 62 | return sentence_embs 63 | 64 | @property 65 | def dim(self): 66 | return self.__dimension 67 | 68 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/openai.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | 6 | from pymilvus.model.base import BaseEmbeddingFunction 7 | from pymilvus.model.utils import import_openai 8 | 9 | 10 | class OpenAIEmbeddingFunction(BaseEmbeddingFunction): 11 | def __init__( 12 | self, 13 | model_name: str = "text-embedding-ada-002", 14 | api_key: Optional[str] = None, 15 | base_url: Optional[str] = None, 16 | dimensions: Optional[int] = None, 17 | **kwargs, 18 | ): 19 | import_openai() 20 | from openai import OpenAI 21 | 22 | self._openai_model_meta_info = defaultdict(dict) 23 | 
self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536 24 | self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072 25 | self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536 26 | 27 | self._model_config = dict({"api_key": api_key, "base_url": base_url}, **kwargs) 28 | additional_encode_config = {} 29 | if dimensions is not None: 30 | additional_encode_config = {"dimensions": dimensions} 31 | self._openai_model_meta_info[model_name]["dim"] = dimensions 32 | 33 | self._encode_config = {"model": model_name, **additional_encode_config} 34 | self.model_name = model_name 35 | self.client = OpenAI(**self._model_config) 36 | 37 | def encode_queries(self, queries: List[str]) -> List[np.array]: 38 | return self._encode(queries) 39 | 40 | def encode_documents(self, documents: List[str]) -> List[np.array]: 41 | return self._encode(documents) 42 | 43 | @property 44 | def dim(self): 45 | return self._openai_model_meta_info[self.model_name]["dim"] 46 | 47 | def __call__(self, texts: List[str]) -> List[np.array]: 48 | return self._encode(texts) 49 | 50 | def _encode_query(self, query: str) -> np.array: 51 | return self._encode(query)[0] 52 | 53 | def _encode_document(self, document: str) -> np.array: 54 | return self._encode(document)[0] 55 | 56 | def _call_openai_api(self, texts: List[str]): 57 | results = self.client.embeddings.create(input=texts, **self._encode_config).data 58 | return [np.array(data.embedding) for data in results] 59 | 60 | def _encode(self, texts: List[str]): 61 | return self._call_openai_api(texts) 62 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/sentence_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | from pymilvus.model.base import BaseEmbeddingFunction 6 | from pymilvus.model.utils import import_sentence_transformers 7 | 8 | 9 | class SentenceTransformerEmbeddingFunction(BaseEmbeddingFunction): 10 | def __init__( 11 | self, 12 | model_name: str = "all-MiniLM-L6-v2", 13 | batch_size: int = 32, 14 | query_instruction: str = "", 15 | doc_instruction: str = "", 16 | device: str = "cpu", 17 | normalize_embeddings: bool = True, 18 | **kwargs, 19 | ): 20 | import_sentence_transformers() 21 | from sentence_transformers import SentenceTransformer 22 | self.model_name = model_name 23 | self.query_instruction = query_instruction 24 | self.doc_instruction = doc_instruction 25 | self.batch_size = batch_size 26 | self.normalize_embeddings = normalize_embeddings 27 | 28 | _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs) 29 | self.model = SentenceTransformer(**_model_config) 30 | 31 | def __call__(self, texts: List[str]) -> List[np.array]: 32 | return self._encode(texts) 33 | 34 | def _encode(self, texts: List[str]) -> List[np.array]: 35 | embs = self.model.encode( 36 | texts, 37 | batch_size=self.batch_size, 38 | show_progress_bar=False, 39 | convert_to_numpy=True, 40 | normalize_embeddings=self.normalize_embeddings 41 | ) 42 | return list(embs) 43 | 44 | @property 45 | def dim(self): 46 | return self.model.get_sentence_embedding_dimension() 47 | 48 | def encode_queries(self, queries: List[str]) -> List[np.array]: 49 | instructed_queries = [self.query_instruction + query for query in queries] 50 | return self._encode(instructed_queries) 51 | 52 | def encode_documents(self, documents: List[str]) -> List[np.array]: 53 | instructed_documents = 
[self.doc_instruction + document for document in documents] 54 | return self._encode(instructed_documents) 55 | 56 | def _encode_query(self, query: str) -> np.array: 57 | instructed_query = self.query_instruction + query 58 | embs = self.model.encode( 59 | sentences=[instructed_query], 60 | batch_size=1, 61 | show_progress_bar=False, 62 | convert_to_numpy=True, 63 | normalize_embeddings=self.normalize_embeddings, 64 | ) 65 | return embs[0] 66 | 67 | def _encode_document(self, document: str) -> np.array: 68 | instructed_document = self.doc_instruction + document 69 | embs = self.model.encode( 70 | sentences=[instructed_document], 71 | batch_size=1, 72 | show_progress_bar=False, 73 | convert_to_numpy=True, 74 | normalize_embeddings=self.normalize_embeddings, 75 | ) 76 | return embs[0] 77 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/tei.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import requests 5 | 6 | from pymilvus.model.base import BaseEmbeddingFunction 7 | 8 | 9 | class TEIEmbeddingFunction(BaseEmbeddingFunction): 10 | def __init__( 11 | self, 12 | api_url: str, 13 | dimensions: Optional[int] = None, 14 | ): 15 | self.api_url = api_url + "/v1/embeddings" 16 | self._session = requests.Session() 17 | self._dim = dimensions 18 | 19 | @property 20 | def dim(self): 21 | if self._dim is None: 22 | # This works by sending a dummy message to the API to retrieve the vector dimension, 23 | # as the original API does not directly provide this information 24 | self._dim = self._call_api(["get dim"])[0].shape[0] 25 | return self._dim 26 | 27 | def encode_queries(self, queries: List[str]) -> List[np.array]: 28 | return self._call_api(queries) 29 | 30 | def encode_documents(self, documents: List[str]) -> List[np.array]: 31 | return self._call_api(documents) 32 | 33 | def __call__(self, texts: List[str]) -> List[np.array]: 34 | return self._call_api(texts) 35 | 36 | def _call_api(self, texts: List[str]): 37 | data = {"input": texts} 38 | resp = self._session.post( # type: ignore[assignment] 39 | self.api_url, 40 | json=data, 41 | ).json() 42 | if "data" not in resp: 43 | raise RuntimeError(resp["message"]) 44 | 45 | embeddings = resp["data"] 46 | 47 | # Sort resulting embeddings by index 48 | sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return] 49 | return [np.array(result["embedding"]) for result in sorted_embeddings] 50 | -------------------------------------------------------------------------------- /src/pymilvus/model/dense/voyageai.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import struct 6 | 7 | from pymilvus.model.base import BaseEmbeddingFunction 8 | from pymilvus.model.utils import import_voyageai 9 | 10 | 11 | class VoyageEmbeddingFunction(BaseEmbeddingFunction): 12 | def __init__(self, 13 | model_name: str = "voyage-3-large", 14 | api_key: Optional[str] = None, 15 | embedding_type: Optional[str] = None, 16 | truncate: Optional[bool] = None, 17 | dimension: Optional[int] = None, 18 | **kwargs): 19 | import_voyageai() 20 | import voyageai 21 | 22 | self.model_name = model_name 23 | self.truncate = truncate 24 | self._voyageai_model_meta_info = defaultdict(dict) 25 | self._voyageai_model_meta_info["voyage-3-large"]["dim"] = [1024, 256, 
512, 2048] 26 | self._voyageai_model_meta_info["voyage-code-3"]["dim"] = [1024, 256, 512, 2048] 27 | self._voyageai_model_meta_info["voyage-3"]["dim"] = [1024] 28 | self._voyageai_model_meta_info["voyage-3-lite"]["dim"] = [512] 29 | self._voyageai_model_meta_info["voyage-finance-2"]["dim"] = [1024] 30 | self._voyageai_model_meta_info["voyage-multilingual-2"]["dim"] = [1024] 31 | self._voyageai_model_meta_info["voyage-law-2"]["dim"] = [1024] 32 | #old model 33 | self._voyageai_model_meta_info["voyage-large-2"]["dim"] = [1536] 34 | self._voyageai_model_meta_info["voyage-code-2"]["dim"] = [1536] 35 | self._voyageai_model_meta_info["voyage-2"]["dim"] = [1024] 36 | self._voyageai_model_meta_info["voyage-lite-02-instruct"]["dim"] = [1024] 37 | 38 | if dimension is not None and dimension not in self._voyageai_model_meta_info[self.model_name]["dim"]: 39 | raise ValueError(f"The provided dimension ({dimension}) is not supported by the selected model ({self.model_name}). " 40 | "Leave this parameter empty to use the default dimension for the model. " 41 | "Please check the supported dimensions here: https://docs.voyageai.com/docs/embeddings" 42 | ) 43 | 44 | if embedding_type == "int8" or embedding_type == "uint8": 45 | raise ValueError("Currently int8 or uint8 is not supported with PyMilvus model library.") 46 | 47 | if self.model_name in ['voyage-3-large', 'voyage-code-3']: 48 | if embedding_type is not None and embedding_type not in ['float', 'binary', 'ubinary']: 49 | raise ValueError(f"The provided embedding_type ({embedding_type}) is not supported by the selected model " 50 | f"({self.model_name}). Leave this parameter empty for the default embedding_type (float). " 51 | f"Please check the supported embedding_type values here: https://docs.voyageai.com/docs/embeddings") 52 | else: 53 | if embedding_type is not None and embedding_type != 'float': 54 | raise ValueError(f"The provided embedding_type ({embedding_type}) is not supported by the selected model " 55 | f"({self.model_name}). Leave this parameter empty for the default embedding_type (float). 
" 56 | f"Please check the supported embedding_type values here: https://docs.voyageai.com/docs/embeddings") 57 | 58 | self.embedding_type = embedding_type 59 | self.dimension = dimension 60 | self.client = voyageai.Client(api_key, **kwargs) 61 | 62 | @property 63 | def dim(self): 64 | if self.dimension is None: 65 | return self._voyageai_model_meta_info[self.model_name]["dim"][0] 66 | return self.dimension 67 | 68 | def encode_queries(self, queries: List[str]) -> List[np.array]: 69 | return self._call_voyage_api(queries, input_type="query") 70 | 71 | def encode_documents(self, documents: List[str]) -> List[np.array]: 72 | return self._call_voyage_api(documents, input_type="document") 73 | 74 | def __call__(self, texts: List[str]) -> List[np.array]: 75 | return self._call_voyage_api(texts) 76 | 77 | def _call_voyage_api(self, texts: List[str], input_type: Optional[str] = None): 78 | embeddings = self.client.embed( 79 | texts=texts, 80 | model=self.model_name, 81 | input_type=input_type, 82 | truncation=self.truncate, 83 | output_dtype=self.embedding_type, 84 | output_dimension=self.dim, 85 | ).embeddings 86 | 87 | if self.embedding_type is None or self.embedding_type == "float": 88 | results = [np.array(data, dtype=np.float32) for data in embeddings] 89 | elif self.embedding_type == "binary": 90 | results = [ 91 | np.unpackbits((np.array(result, dtype=np.int16) + 128).astype(np.uint8)).astype(bool) 92 | for result in embeddings 93 | ] 94 | elif self.embedding_type == "ubinary": 95 | results = [np.unpackbits(np.array(result, dtype=np.uint8)).astype(bool) for result in embeddings] 96 | else: 97 | raise ValueError(f"The provided embedding_type ({self.embedding_type}) is not supported by the selected model " 98 | f"({self.model_name}). Leave this parameter empty for the default embedding_type (float). 
" 99 | f"Please check the supported embedding_type values here: https://docs.voyageai.com/docs/embeddings") 100 | return results 101 | -------------------------------------------------------------------------------- /src/pymilvus/model/hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from pymilvus.model.hybrid.bge_m3 import BGEM3EmbeddingFunction 2 | from pymilvus.model.hybrid.mgte import MGTEEmbeddingFunction 3 | 4 | __all__ = ["BGEM3EmbeddingFunction", "MGTEEmbeddingFunction"] 5 | -------------------------------------------------------------------------------- /src/pymilvus/model/hybrid/bge_m3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List 3 | 4 | from scipy.sparse import csr_array 5 | import numpy as np 6 | 7 | from pymilvus.model.base import BaseEmbeddingFunction 8 | from pymilvus.model.utils import import_FlagEmbedding, import_datasets 9 | from pymilvus.model.sparse.utils import stack_sparse_embeddings 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.DEBUG) 14 | 15 | 16 | class BGEM3EmbeddingFunction(BaseEmbeddingFunction): 17 | def __init__( 18 | self, 19 | model_name: str = "BAAI/bge-m3", 20 | batch_size: int = 16, 21 | device: str = None, 22 | normalize_embeddings: bool = True, 23 | use_fp16: bool = False, 24 | return_dense: bool = True, 25 | return_sparse: bool = True, 26 | return_colbert_vecs: bool = False, 27 | **kwargs, 28 | ): 29 | import_datasets() 30 | import_FlagEmbedding() 31 | 32 | try: 33 | from FlagEmbedding import BGEM3FlagModel 34 | except AttributeError as e: 35 | import sys 36 | if "google.colab" in sys.modules and "ListView" in str(e): 37 | print( 38 | "\033[91mIt looks like you're running on Google Colab. Please restart the session to resolve this issue.\033[0m") 39 | print( 40 | "\033[91mFor further details, visit: https://github.com/milvus-io/milvus-model/issues/32.\033[0m") 41 | raise 42 | 43 | self.model_name = model_name 44 | self.batch_size = batch_size 45 | self.normalize_embeddings = normalize_embeddings 46 | self.device = device 47 | self.use_fp16 = use_fp16 48 | 49 | if device == "cpu" and use_fp16 is True: 50 | logger.warning( 51 | "Using fp16 with CPU can lead to runtime errors such as 'LayerNormKernelImpl', It's recommended to set 'use_fp16 = False' when using cpu. 
" 52 | ) 53 | 54 | if "devices" in kwargs: 55 | device = kwargs["devices"] 56 | kwargs.pop("devices") 57 | 58 | _model_config = dict( 59 | { 60 | "model_name_or_path": model_name, 61 | "devices": device, 62 | "normalize_embeddings": normalize_embeddings, 63 | "use_fp16": use_fp16, 64 | }, 65 | **kwargs, 66 | ) 67 | _encode_config = { 68 | "batch_size": batch_size, 69 | "return_dense": return_dense, 70 | "return_sparse": return_sparse, 71 | "return_colbert_vecs": return_colbert_vecs, 72 | } 73 | self._model_config = _model_config 74 | self._encode_config = _encode_config 75 | 76 | self.model = BGEM3FlagModel(**self._model_config) 77 | 78 | def __call__(self, texts: List[str]) -> Dict: 79 | return self._encode(texts) 80 | 81 | @property 82 | def dim(self) -> Dict: 83 | return { 84 | "dense": self.model.model.model.config.hidden_size, 85 | "colbert_vecs": self.model.model.colbert_linear.out_features, 86 | "sparse": len(self.model.tokenizer), 87 | } 88 | 89 | def _encode(self, texts: List[str]) -> Dict: 90 | output = self.model.encode(sentences=texts, **self._encode_config) 91 | results = {} 92 | if self._encode_config["return_dense"] is True: 93 | results["dense"] = list(output["dense_vecs"]) 94 | if self._encode_config["return_sparse"] is True: 95 | sparse_dim = self.dim["sparse"] 96 | results["sparse"] = [] 97 | for sparse_vec in output["lexical_weights"]: 98 | indices = [int(k) for k in sparse_vec] 99 | values = np.array(list(sparse_vec.values()), dtype=np.float64) 100 | row_indices = [0] * len(indices) 101 | csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim)) 102 | results["sparse"].append(csr) 103 | results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr() 104 | if self._encode_config["return_colbert_vecs"] is True: 105 | results["colbert_vecs"] = output["colbert_vecs"] 106 | return results 107 | 108 | 109 | def encode_queries(self, queries: List[str]) -> Dict: 110 | return self._encode(queries) 111 | 112 | def encode_documents(self, documents: List[str]) -> Dict: 113 | return self._encode(documents) 114 | -------------------------------------------------------------------------------- /src/pymilvus/model/hybrid/mgte.py: -------------------------------------------------------------------------------- 1 | """ 2 | The following code is adapted from/inspired from : 3 | https://huggingface.co/Alibaba-NLP/gte-multilingual-base/blob/main/scripts/gte_embedding.py 4 | 5 | # Copyright 2024 The GTE Team Authors and Alibaba Group. 
6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | """ 8 | import logging 9 | 10 | from typing import Dict, List, Optional 11 | 12 | from pymilvus.model.base import BaseEmbeddingFunction 13 | from pymilvus.model.sparse.utils import stack_sparse_embeddings 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.DEBUG) 17 | 18 | 19 | 20 | class MGTEEmbeddingFunction(BaseEmbeddingFunction): 21 | def __init__( 22 | self, 23 | model_name: str = "Alibaba-NLP/gte-multilingual-base", 24 | batch_size: int = 16, 25 | device: str = "", 26 | normalize_embeddings: bool = True, 27 | dimensions: Optional[int] = None, 28 | use_fp16: bool = False, 29 | return_dense: bool = True, 30 | return_sparse: bool = True, 31 | **kwargs, 32 | ): 33 | from .mgte_embedding.gte_impl import _GTEEmbedding 34 | self.model_name = model_name 35 | self.batch_size = batch_size 36 | self.normalize_embeddings = normalize_embeddings 37 | self.device = device 38 | self.use_fp16 = use_fp16 39 | self.dimensions = dimensions 40 | 41 | if "dimension" in kwargs: 42 | self.dimensions = kwargs["dimension"] 43 | kwargs.pop("dimension") 44 | 45 | if device == "cpu" and use_fp16 is True: 46 | logger.warning( 47 | "Using fp16 with CPU can lead to runtime errors such as 'LayerNormKernelImpl'; it's recommended to set 'use_fp16 = False' when using CPU." 48 | ) 49 | 50 | _model_config = dict( 51 | { 52 | "model_name": model_name, 53 | "device": device, 54 | "normalized": normalize_embeddings, 55 | "use_fp16": use_fp16, 56 | }, 57 | **kwargs, 58 | ) 59 | _encode_config = { 60 | "batch_size": batch_size, 61 | "return_dense": return_dense, 62 | "return_sparse": return_sparse, 63 | } 64 | self._model_config = _model_config 65 | self._encode_config = _encode_config 66 | 67 | self.model = _GTEEmbedding(**self._model_config) 68 | 69 | _encode_config["dimension"] = self.dimensions 70 | 71 | if self.dimensions is None: 72 | self.dimensions = self.model.model.config.hidden_size 73 | 74 | def __call__(self, texts: List[str]) -> Dict: 75 | return self._encode(texts) 76 | 77 | @property 78 | def dim(self) -> Dict: 79 | return { 80 | "dense": self.dimensions, 81 | "sparse": len(self.model.tokenizer), 82 | } 83 | 84 | def _encode(self, texts: List[str]) -> Dict: 85 | from scipy.sparse import csr_array 86 | 87 | output = self.model.encode(texts=texts, **self._encode_config) 88 | results = {} 89 | if self._encode_config["return_dense"] is True: 90 | results["dense"] = list(output["dense_embeddings"]) 91 | if self._encode_config["return_sparse"] is True: 92 | sparse_dim = self.dim["sparse"] 93 | results["sparse"] = [] 94 | for sparse_vec in output["token_weights"]: 95 | indices = [int(k) for k in sparse_vec] 96 | values = list(sparse_vec.values()) 97 | row_indices = [0] * len(indices) 98 | csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim)) 99 | results["sparse"].append(csr) 100 | results["sparse"] = stack_sparse_embeddings(results["sparse"]).tocsr() 101 | return results 102 | 103 | def encode_queries(self, queries: List[str]) -> Dict: 104 | return self._encode(queries) 105 | 106 | def encode_documents(self, documents: List[str]) -> Dict: 107 | return self._encode(documents) 108 | --------------------------------------------------------------------------------
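A minimal usage sketch for MGTEEmbeddingFunction (illustrative only, not part of the package source; the first call downloads the Alibaba-NLP/gte-multilingual-base weights from Hugging Face):

    from pymilvus.model.hybrid import MGTEEmbeddingFunction

    mgte_ef = MGTEEmbeddingFunction(device="cpu", use_fp16=False)
    out = mgte_ef(["Hybrid retrieval combines dense and sparse signals."])
    print(out["dense"][0].shape)  # dense vector (hidden_size unless dimensions is set)
    print(out["sparse"].shape)    # scipy CSR array, one row per input text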
/src/pymilvus/model/hybrid/mgte_embedding/gte_impl.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | from collections import defaultdict 3 | 4 | from pymilvus.model.utils import import_torch, import_scipy, import_transformers 5 | 6 | import_torch() 7 | import_scipy() 8 | import_transformers() 9 | 10 | import numpy as np 11 | import torch 12 | from transformers import AutoModelForTokenClassification, AutoTokenizer 13 | from transformers.utils import is_torch_npu_available 14 | 15 | 16 | class _GTEEmbedding(torch.nn.Module): 17 | def __init__(self, 18 | model_name: str = None, 19 | normalized: bool = True, 20 | use_fp16: bool = True, 21 | device: str = None 22 | ): 23 | super().__init__() 24 | self.normalized = normalized 25 | if device: 26 | self.device = torch.device(device) 27 | else: 28 | if torch.cuda.is_available(): 29 | self.device = torch.device("cuda") 30 | elif torch.backends.mps.is_available(): 31 | self.device = torch.device("mps") 32 | elif is_torch_npu_available(): 33 | self.device = torch.device("npu") 34 | else: 35 | self.device = torch.device("cpu") 36 | use_fp16 = False 37 | self.use_fp16 = use_fp16 38 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 39 | self.model = AutoModelForTokenClassification.from_pretrained( 40 | model_name, trust_remote_code=True, torch_dtype=torch.float16 if self.use_fp16 else None 41 | ) 42 | self.vocab_size = self.model.config.vocab_size 43 | self.model.to(self.device) 44 | 45 | def _process_token_weights(self, token_weights: np.ndarray, input_ids: list): 46 | # convert to dict 47 | result = defaultdict(int) 48 | unused_tokens = {self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, 49 | self.tokenizer.unk_token_id} 50 | for w, idx in zip(token_weights, input_ids): 51 | if idx not in unused_tokens and w > 0: 52 | token = int(idx) 53 | if w > result[token]: 54 | result[token] = w 55 | return result 56 | 57 | @torch.no_grad() 58 | def encode(self, 59 | texts: Union[str, List[str]] = None, 60 | dimension: int = None, 61 | max_length: int = 8192, 62 | batch_size: int = 16, 63 | return_dense: bool = True, 64 | return_sparse: bool = False): 65 | if isinstance(texts, str): 66 | texts = [texts] 67 | num_texts = len(texts) 68 | all_dense_vecs = [] 69 | all_token_weights = [] 70 | for i in range(0, num_texts, batch_size): 71 | batch = texts[i: i + batch_size] 72 | result = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse) 73 | if return_dense: 74 | all_dense_vecs.append(result['dense_embeddings']) 75 | if return_sparse: 76 | all_token_weights.extend(result['token_weights']) 77 | all_dense_vecs = torch.cat(all_dense_vecs, dim=0) 78 | return { 79 | "dense_embeddings": all_dense_vecs, 80 | "token_weights": all_token_weights 81 | } 82 | 83 | @torch.no_grad() 84 | def _encode(self, 85 | texts: List[str] = None, 86 | dimension: int = None, 87 | max_length: int = 1024, 88 | batch_size: int = 16, 89 | return_dense: bool = True, 90 | return_sparse: bool = False): 91 | 92 | text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length) 93 | text_input = {k: v.to(self.model.device) for k, v in text_input.items()} 94 | model_out = self.model(**text_input, return_dict=True) 95 | 96 | output = {} 97 | if return_dense: 98 | dense_vecs = model_out.last_hidden_state[:, 0, :dimension] 99 | if self.normalized: 100 | dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1) 101 | output['dense_embeddings'] = dense_vecs 102 | if return_sparse: 103 | token_weights = torch.relu(model_out.logits).squeeze(-1) 104 | token_weights = list(map(self._process_token_weights,
token_weights.detach().cpu().numpy().tolist(), 105 | text_input['input_ids'].cpu().numpy().tolist())) 106 | output['token_weights'] = token_weights 107 | 108 | return output -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/__init__.py: -------------------------------------------------------------------------------- 1 | from pymilvus.model.reranker.cohere import CohereRerankFunction 2 | from pymilvus.model.reranker.bgereranker import BGERerankFunction 3 | from pymilvus.model.reranker.voyageai import VoyageRerankFunction 4 | from pymilvus.model.reranker.cross_encoder import CrossEncoderRerankFunction 5 | from pymilvus.model.reranker.jinaai import JinaRerankFunction 6 | from pymilvus.model.reranker.tei import TEIRerankFunction 7 | 8 | __all__ = [ 9 | "CohereRerankFunction", 10 | "BGERerankFunction", 11 | "VoyageRerankFunction", 12 | "CrossEncoderRerankFunction", 13 | "JinaRerankFunction", 14 | "TEIRerankFunction", 15 | ] 16 | -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/bgereranker.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Union 2 | 3 | from pymilvus.model.base import BaseRerankFunction, RerankResult 4 | from pymilvus.model.utils import import_FlagEmbedding, import_transformers 5 | 6 | 7 | class BGERerankFunction(BaseRerankFunction): 8 | def __init__( 9 | self, 10 | model_name: str = "BAAI/bge-reranker-v2-m3", 11 | use_fp16: bool = True, 12 | batch_size: int = 32, 13 | normalize: bool = True, 14 | device: Optional[Union[str, List]] = None, 15 | query_max_length: int = 256, 16 | max_length: int = 512, 17 | **kwargs: Any, 18 | ): 19 | import_FlagEmbedding() 20 | import_transformers() 21 | from FlagEmbedding import FlagAutoReranker 22 | 23 | self.model_name = model_name 24 | self.batch_size = batch_size 25 | self.normalize = normalize 26 | self.device = device 27 | 28 | if "devices" in kwargs: 29 | device = kwargs["devices"] 30 | kwargs.pop("devices") 31 | 32 | _model_config = dict( 33 | { 34 | "model_name_or_path": model_name, 35 | "batch_size": batch_size, 36 | "use_fp16": use_fp16, 37 | "devices": device, 38 | "max_length": max_length, 39 | "query_max_length": query_max_length, 40 | "normalize": normalize, 41 | }, 42 | **kwargs, 43 | ) 44 | self.reranker = FlagAutoReranker.from_finetuned(**_model_config) 45 | 46 | 47 | def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]: 48 | return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)] 49 | 50 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 51 | batched_texts = self._batchify(documents, self.batch_size) 52 | scores = [] 53 | for batched_text in batched_texts: 54 | query_document_pairs = [[query, text] for text in batched_text] 55 | batch_score = self.reranker.compute_score( 56 | query_document_pairs, normalize=self.normalize 57 | ) 58 | scores.extend(batch_score) 59 | ranked_order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) 60 | 61 | if top_k: 62 | ranked_order = ranked_order[:top_k] 63 | 64 | results = [] 65 | for index in ranked_order: 66 | results.append(RerankResult(text=documents[index], score=scores[index], index=index)) 67 | return results 68 | 69 | --------------------------------------------------------------------------------
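A minimal usage sketch for BGERerankFunction (illustrative only, not part of the package source; the first call downloads the BAAI/bge-reranker-v2-m3 weights):

    from pymilvus.model.reranker import BGERerankFunction

    bge_rf = BGERerankFunction(device="cpu", use_fp16=False)
    results = bge_rf(
        query="What is a vector database?",
        documents=["Milvus stores and searches embeddings.", "BM25 is a lexical scorer."],
        top_k=2,
    )
    for r in results:
        print(r.index, round(float(r.score), 4), r.text)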
/src/pymilvus/model/reranker/cohere.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from pymilvus.model.base import BaseRerankFunction, RerankResult 4 | from pymilvus.model.utils import import_cohere 5 | 6 | 7 | class CohereRerankFunction(BaseRerankFunction): 8 | def __init__(self, model_name: str = "rerank-english-v3.0", api_key: Optional[str] = None, return_documents=True, **kwargs): 9 | import_cohere() 10 | import cohere 11 | 12 | self.model_name = model_name 13 | self.client = cohere.ClientV2(api_key) 14 | self.rerank_config = {"return_documents": return_documents, **kwargs} 15 | 16 | 17 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 18 | co_results = self.client.rerank( 19 | query=query, documents=documents, top_n=top_k, model=self.model_name, **self.rerank_config) 20 | results = [] 21 | for co_result in co_results.results: 22 | document_text = "" 23 | if self.rerank_config["return_documents"] is True: 24 | document_text = co_result.document.text 25 | results.append( 26 | RerankResult( 27 | text=document_text, 28 | score=co_result.relevance_score, 29 | index=co_result.index, 30 | ) 31 | ) 32 | return results 33 | -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/cross_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from pymilvus.model.base import BaseRerankFunction, RerankResult 4 | from pymilvus.model.utils import import_sentence_transformers 5 | 6 | 7 | class CrossEncoderRerankFunction(BaseRerankFunction): 8 | def __init__( 9 | self, 10 | model_name: str = "", 11 | device: str = "", 12 | batch_size: int = 32, 13 | activation_fct: Any = None, 14 | **kwargs, 15 | ): 16 | import_sentence_transformers() 17 | import sentence_transformers 18 | 19 | self.model_name = model_name 20 | self.device = device 21 | self.batch_size = batch_size 22 | self.activation_fct = activation_fct 23 | self.model = sentence_transformers.cross_encoder.CrossEncoder( 24 | model_name=model_name, device=self.device, default_activation_function=activation_fct 25 | ) 26 | 27 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 28 | query_document_pairs = [[query, doc] for doc in documents] 29 | scores = self.model.predict(query_document_pairs, batch_size=self.batch_size) 30 | 31 | ranked_order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) 32 | if top_k: 33 | ranked_order = ranked_order[:top_k] 34 | 35 | results = [] 36 | for index in ranked_order: 37 | results.append(RerankResult(text=documents[index], score=scores[index], index=index)) 38 | return results 39 | -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/jinaai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import requests 5 | 6 | from pymilvus.model.base import BaseRerankFunction, RerankResult 7 | 8 | API_URL = "https://api.jina.ai/v1/rerank" 9 | 10 | 11 | class JinaRerankFunction(BaseRerankFunction): 12 | def __init__(self, model_name: str = "jina-reranker-v2-base-multilingual", api_key: Optional[str] = None): 13 | if api_key is None: 14 | if "JINAAI_API_KEY" in os.environ and os.environ["JINAAI_API_KEY"]: 15 | self.api_key = os.environ["JINAAI_API_KEY"] 16 | else: 17 | error_message = ( 18 | "Did not find api_key, please add an environment variable" 19 | " `JINAAI_API_KEY` which contains it, or pass" 20 | " `api_key` as a named parameter.
21 | ) 22 | raise ValueError(error_message) 23 | else: 24 | self.api_key = api_key 25 | self.model_name = model_name 26 | self._session = requests.Session() 27 | self._session.headers.update( 28 | {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} 29 | ) 30 | 31 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 32 | resp = self._session.post( # type: ignore[assignment] 33 | API_URL, 34 | json={ 35 | "query": query, 36 | "documents": documents, 37 | "model": self.model_name, 38 | "top_n": top_k, 39 | }, 40 | ).json() 41 | if "results" not in resp: 42 | raise RuntimeError(resp["detail"]) 43 | 44 | results = [] 45 | for res in resp["results"]: 46 | results.append( 47 | RerankResult( 48 | text=res["document"]["text"], score=res["relevance_score"], index=res["index"] 49 | ) 50 | ) 51 | return results 52 | -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/tei.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import requests 4 | 5 | from pymilvus.model.base import BaseRerankFunction, RerankResult 6 | 7 | 8 | class TEIRerankFunction(BaseRerankFunction): 9 | def __init__(self, api_url: str): 10 | self.api_url = api_url + "/rerank" 11 | self._session = requests.Session() 12 | 13 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 14 | resp = self._session.post( # type: ignore[assignment] 15 | self.api_url, 16 | json={ 17 | "query": query, 18 | "return_text": True, 19 | "texts": documents, 20 | }, 21 | ).json() 22 | if "error" in resp: 23 | raise RuntimeError(resp["error"]) 24 | 25 | results = [] 26 | for res in resp[:top_k]: 27 | results.append(RerankResult(text=res["text"], score=res["score"], index=res["index"])) 28 | return results 29 | -------------------------------------------------------------------------------- /src/pymilvus/model/reranker/voyageai.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from pymilvus.model.base import BaseRerankFunction, RerankResult 4 | from pymilvus.model.utils import import_voyageai 5 | 6 | 7 | class VoyageRerankFunction(BaseRerankFunction): 8 | def __init__(self, model_name: str = "rerank-2", api_key: Optional[str] = None): 9 | import_voyageai() 10 | import voyageai 11 | 12 | self.model_name = model_name 13 | self.client = voyageai.Client(api_key=api_key) 14 | 15 | def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: 16 | vo_results = self.client.rerank(query, documents, model=self.model_name, top_k=top_k) 17 | results = [] 18 | for vo_result in vo_results.results: 19 | results.append( 20 | RerankResult( 21 | text=vo_result.document, score=vo_result.relevance_score, index=vo_result.index 22 | ) 23 | ) 24 | return results 25 | -------------------------------------------------------------------------------- /src/pymilvus/model/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from pymilvus.model.sparse.bm25 import BM25EmbeddingFunction 2 | from pymilvus.model.sparse.splade import SpladeEmbeddingFunction 3 | 4 | __all__ = ["SpladeEmbeddingFunction", "BM25EmbeddingFunction"] 5 | --------------------------------------------------------------------------------
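A minimal usage sketch for TEIRerankFunction (illustrative only, not part of the package source; http://localhost:8080 is a placeholder for a running text-embeddings-inference reranker endpoint):

    from pymilvus.model.reranker import TEIRerankFunction

    tei_rf = TEIRerankFunction(api_url="http://localhost:8080")
    results = tei_rf(
        "What is Milvus?",
        ["Milvus is a vector database.", "TEI serves rerank models."],
        top_k=2,
    )
    print([(r.index, r.score) for r in results])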
/src/pymilvus/model/sparse/bm25/__init__.py: -------------------------------------------------------------------------------- 1 | from pymilvus.model.sparse.bm25.bm25 import BM25EmbeddingFunction 2 | from pymilvus.model.sparse.bm25.tokenizers import Analyzer, build_analyzer_from_yaml, build_default_analyzer 3 | 4 | __all__ = [ 5 | "BM25EmbeddingFunction", 6 | "Analyzer", 7 | "build_analyzer_from_yaml", 8 | "build_default_analyzer", 9 | ] 10 | -------------------------------------------------------------------------------- /src/pymilvus/model/sparse/bm25/bm25.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file incorporates components from the 'rank_bm25' project by Dorian Brown: 3 | https://github.com/dorianbrown/rank_bm25 4 | Specifically, the rank_bm25.py file. 5 | 6 | The incorporated components are licensed under the Apache License, Version 2.0 (the "License"); 7 | you may not use these components except in compliance with the License. 8 | You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import json 18 | import logging 19 | import math 20 | from collections import defaultdict 21 | from multiprocessing import Pool 22 | from pathlib import Path 23 | from typing import Dict, List, Optional 24 | 25 | import requests 26 | from scipy.sparse import csr_array, vstack 27 | import numpy as np 28 | 29 | from pymilvus.model.base import BaseEmbeddingFunction 30 | from pymilvus.model.sparse.bm25.tokenizers import Analyzer, build_default_analyzer 31 | 32 | logger = logging.getLogger(__name__) 33 | logger.setLevel(logging.DEBUG) 34 | console_handler = logging.StreamHandler() 35 | console_handler.setLevel(logging.INFO) 36 | logger.addHandler(console_handler) 37 | 38 | 39 | class BM25EmbeddingFunction(BaseEmbeddingFunction): 40 | def __init__( 41 | self, 42 | analyzer: Analyzer = None, 43 | corpus: Optional[List] = None, 44 | k1: float = 1.5, 45 | b: float = 0.75, 46 | epsilon: float = 0.25, 47 | num_workers: int = 1, 48 | ): 49 | if analyzer is None: 50 | analyzer = build_default_analyzer(language="en") 51 | self.analyzer = analyzer 52 | self.corpus_size = 0 53 | self.avgdl = 0 54 | self.idf = {} 55 | self.k1 = k1 56 | self.b = b 57 | self.epsilon = epsilon 58 | self.num_workers = num_workers 59 | 60 | if corpus is not None: 61 | self.fit(corpus) 62 | 63 | def _calc_term_indices(self): 64 | for index, word in enumerate(self.idf): 65 | self.idf[word][1] = index 66 | 67 | def _compute_statistics(self, corpus: List[str]): 68 | term_document_frequencies = defaultdict(int) 69 | total_word_count = 0 70 | for document in corpus: 71 | total_word_count += len(document) 72 | 73 | frequencies = defaultdict(int) 74 | for word in document: 75 | frequencies[word] += 1 76 | 77 | for word, _ in frequencies.items(): 78 | term_document_frequencies[word] += 1 79 | self.corpus_size += 1 80 | self.avgdl = total_word_count / self.corpus_size 81 | return term_document_frequencies 82 | 83 | def _tokenize_corpus(self, corpus: List[str]): 84 | if self.num_workers == 1: 85 | return [self.analyzer(text) for text in corpus] 86 | with Pool(self.num_workers) as pool: 87 | return pool.map(self.analyzer, corpus) 88 | 
89 | def _calc_idf(self, term_document_frequencies: Dict): 90 | # collect idf sum to calculate an average idf for epsilon value 91 | idf_sum = 0 92 | # collect words with negative idf to set them a special epsilon value. 93 | # idf can be negative if word is contained in more than half of documents 94 | negative_idfs = [] 95 | for word, freq in term_document_frequencies.items(): 96 | idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5) 97 | if word not in self.idf: 98 | self.idf[word] = [0.0, 0] 99 | self.idf[word][0] = idf 100 | idf_sum += idf 101 | if idf < 0: 102 | negative_idfs.append(word) 103 | self.average_idf = idf_sum / len(self.idf) 104 | 105 | eps = self.epsilon * self.average_idf 106 | for word in negative_idfs: 107 | self.idf[word][0] = eps 108 | 109 | def _rebuild(self, corpus: List[str]): 110 | self._clear() 111 | corpus = self._tokenize_corpus(corpus) 112 | term_document_frequencies = self._compute_statistics(corpus) 113 | self._calc_idf(term_document_frequencies) 114 | self._calc_term_indices() 115 | 116 | def _clear(self): 117 | self.corpus_size = 0 118 | # idf records the (value, index) 119 | self.idf = defaultdict(list) 120 | 121 | @property 122 | def dim(self): 123 | return len(self.idf) 124 | 125 | def fit(self, corpus: List[str]): 126 | self._rebuild(corpus) 127 | 128 | def _encode_query(self, query: str) -> csr_array: 129 | terms = self.analyzer(query) 130 | values, rows, cols = [], [], [] 131 | for term in terms: 132 | if term in self.idf: 133 | values.append(self.idf[term][0]) 134 | rows.append(0) 135 | cols.append(self.idf[term][1]) 136 | return csr_array((values, (rows, cols)), shape=(1, len(self.idf))).astype(np.float32) 137 | 138 | def _encode_document(self, doc: str) -> csr_array: 139 | terms = self.analyzer(doc) 140 | frequencies = defaultdict(int) 141 | doc_len = len(terms) 142 | term_set = set() 143 | for term in terms: 144 | frequencies[term] += 1 145 | term_set.add(term) 146 | values, rows, cols = [], [], [] 147 | for term in term_set: 148 | if term in self.idf: 149 | term_freq = frequencies[term] 150 | value = ( 151 | term_freq 152 | * (self.k1 + 1) 153 | / (term_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)) 154 | ) 155 | rows.append(0) 156 | cols.append(self.idf[term][1]) 157 | values.append(value) 158 | return csr_array((values, (rows, cols)), shape=(1, len(self.idf))).astype(np.float32) 159 | 160 | def encode_queries(self, queries: List[str]) -> csr_array: 161 | sparse_embs = [self._encode_query(query) for query in queries] 162 | return vstack(sparse_embs).tocsr() 163 | 164 | def __call__(self, texts: List[str]) -> csr_array: 165 | error_message = "Unsupported function called, please check the documentation of 'BM25EmbeddingFunction'." 166 | raise ValueError(error_message) 167 | 168 | def encode_documents(self, documents: List[str]) -> csr_array: 169 | sparse_embs = [self._encode_document(document) for document in documents] 170 | return vstack(sparse_embs).tocsr() 171 | 172 | def save(self, path: str): 173 | bm25_params = {} 174 | bm25_params["version"] = "v1" 175 | bm25_params["corpus_size"] = self.corpus_size 176 | bm25_params["avgdl"] = self.avgdl 177 | bm25_params["idf_word"] = [None] * len(self.idf) 178 | bm25_params["idf_value"] = [None] * len(self.idf) 179 | for word, values in self.idf.items(): 180 | bm25_params["idf_word"][values[1]] = word 181 | bm25_params["idf_value"][values[1]] = values[0] 182 | 183 | bm25_params["k1"] = self.k1 184 | bm25_params["b"] = self.b 185 | bm25_params["epsilon"] = self.epsilon 186 | 187 | with Path(path).open("w") as json_file: 188 | json.dump(bm25_params, json_file) 189 | 190 | def load(self, path: Optional[str] = None): 191 | default_meta_filename = "bm25_msmarco_v1.json" 192 | default_meta_url = "https://github.com/milvus-io/pymilvus-assets/releases/download/v0.1-bm25v1/bm25_msmarco_v1.json" 193 | if path is None: 194 | logger.info(f"path is None, using default {default_meta_filename}.") 195 | if not Path(default_meta_filename).exists(): 196 | try: 197 | logger.info( 198 | f"{default_meta_filename} not found, downloading from {default_meta_url} to ./{default_meta_filename}." 199 | ) 200 | response = requests.get(default_meta_url, timeout=30) 201 | response.raise_for_status() 202 | with Path(default_meta_filename).open("wb") as f: 203 | f.write(response.content) 204 | logger.info(f"{default_meta_filename} has been downloaded successfully.") 205 | except requests.exceptions.RequestException as e: 206 | error_message = f"Failed to download the file: {e}" 207 | raise RuntimeError(error_message) from e 208 | path = default_meta_filename 209 | try: 210 | with Path(path).open() as json_file: 211 | bm25_params = json.load(json_file) 212 | except OSError as e: 213 | error_message = f"Error opening file {path}: {e}" 214 | raise RuntimeError(error_message) from e 215 | self.corpus_size = bm25_params["corpus_size"] 216 | self.avgdl = bm25_params["avgdl"] 217 | self.idf = {} 218 | for i in range(len(bm25_params["idf_word"])): 219 | self.idf[bm25_params["idf_word"][i]] = [bm25_params["idf_value"][i], i] 220 | self.k1 = bm25_params["k1"] 221 | self.b = bm25_params["b"] 222 | self.epsilon = bm25_params["epsilon"] 223 | --------------------------------------------------------------------------------
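A minimal usage sketch for BM25EmbeddingFunction (illustrative only, not part of the package source). The function is stateful: fit() or load() must run before encoding, and document and query encodings share the idf vocabulary built from the corpus:

    from pymilvus.model.sparse.bm25 import BM25EmbeddingFunction, build_default_analyzer

    analyzer = build_default_analyzer(language="en")
    bm25_ef = BM25EmbeddingFunction(analyzer)
    corpus = ["Milvus is a vector database.", "BM25 ranks documents by term statistics."]
    bm25_ef.fit(corpus)                          # builds idf and avgdl from the corpus
    doc_vecs = bm25_ef.encode_documents(corpus)  # scipy CSR matrix, shape (n_docs, vocab_size)
    query_vecs = bm25_ef.encode_queries(["vector database"])
    bm25_ef.save("bm25_params.json")             # persist idf/k1/b/epsilon for a later load()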
166 | raise ValueError(error_message) 167 | 168 | def encode_documents(self, documents: List[str]) -> csr_array: 169 | sparse_embs = [self._encode_document(document) for document in documents] 170 | return vstack(sparse_embs).tocsr() 171 | 172 | def save(self, path: str): 173 | bm25_params = {} 174 | bm25_params["version"] = "v1" 175 | bm25_params["corpus_size"] = self.corpus_size 176 | bm25_params["avgdl"] = self.avgdl 177 | bm25_params["idf_word"] = [None for _ in range(len(self.idf))] 178 | bm25_params["idf_value"] = [None for _ in range(len(self.idf))] 179 | for word, values in self.idf.items(): 180 | bm25_params["idf_word"][values[1]] = word 181 | bm25_params["idf_value"][values[1]] = values[0] 182 | 183 | bm25_params["k1"] = self.k1 184 | bm25_params["b"] = self.b 185 | bm25_params["epsilon"] = self.epsilon 186 | 187 | with Path(path).open("w") as json_file: 188 | json.dump(bm25_params, json_file) 189 | 190 | def load(self, path: Optional[str] = None): 191 | default_meta_filename = "bm25_msmarco_v1.json" 192 | default_meta_url = "https://github.com/milvus-io/pymilvus-assets/releases/download/v0.1-bm25v1/bm25_msmarco_v1.json" 193 | if path is None: 194 | logger.info(f"path is None, using default {default_meta_filename}.") 195 | if not Path(default_meta_filename).exists(): 196 | try: 197 | logger.info( 198 | f"{default_meta_filename} not found, start downloading from {default_meta_url} to ./{default_meta_filename}." 199 | ) 200 | response = requests.get(default_meta_url, timeout=30) 201 | response.raise_for_status() 202 | with Path(default_meta_filename).open("wb") as f: 203 | f.write(response.content) 204 | logger.info(f"{default_meta_filename} has been downloaded successfully.") 205 | except requests.exceptions.RequestException as e: 206 | error_message = f"Failed to download the file: {e}" 207 | raise RuntimeError(error_message) from e 208 | path = default_meta_filename 209 | try: 210 | with Path(path).open() as json_file: 211 | bm25_params = json.load(json_file) 212 | except OSError as e: 213 | error_message = f"Error opening file {path}: {e}" 214 | raise RuntimeError(error_message) from e 215 | self.corpus_size = bm25_params["corpus_size"] 216 | self.avgdl = bm25_params["avgdl"] 217 | self.idf = {} 218 | for i in range(len(bm25_params["idf_word"])): 219 | self.idf[bm25_params["idf_word"][i]] = [bm25_params["idf_value"][i], i] 220 | self.k1 = bm25_params["k1"] 221 | self.b = bm25_params["b"] 222 | self.epsilon = bm25_params["epsilon"] 223 | -------------------------------------------------------------------------------- /src/pymilvus/model/sparse/bm25/lang.yaml: -------------------------------------------------------------------------------- 1 | en: 2 | tokenizer: 3 | class: StandardTokenizer 4 | params: {} 5 | filters: 6 | - class: LowercaseFilter 7 | params: {} 8 | - class: PunctuationFilter 9 | params: {} 10 | - class: StopwordFilter 11 | params: 12 | language: 'english' 13 | - class: StemmingFilter 14 | params: 15 | language: 'english' 16 | de: 17 | tokenizer: 18 | class: StandardTokenizer 19 | params: {} 20 | filters: 21 | - class: LowercaseFilter 22 | params: {} 23 | - class: StopwordFilter 24 | params: 25 | language: 'german' 26 | - class: PunctuationFilter 27 | params: {} 28 | - class: StemmingFilter 29 | params: 30 | language: 'german' 31 | fr: 32 | tokenizer: 33 | class: StandardTokenizer 34 | params: {} 35 | filters: 36 | - class: LowercaseFilter 37 | params: {} 38 | - class: PunctuationFilter 39 | params: {} 40 | - class: StopwordFilter 41 | params: 42 | language: 
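
Usage sketch (illustrative, not part of the source tree; the three-document corpus and the "bm25_params.json" path are made up). It shows the intended flow: fit IDF statistics on a corpus, encode documents and queries into scipy CSR rows, then persist and restore the fitted parameters.

from pymilvus.model.sparse.bm25 import BM25EmbeddingFunction, build_default_analyzer

analyzer = build_default_analyzer(language="en")
bm25_ef = BM25EmbeddingFunction(analyzer)

corpus = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
bm25_ef.fit(corpus)  # builds idf/avgdl and fixes the output dimension

doc_embs = bm25_ef.encode_documents(corpus)              # csr_array, shape (3, bm25_ef.dim)
query_embs = bm25_ef.encode_queries(["Where was Turing born?"])

bm25_ef.save("bm25_params.json")                         # hypothetical path
restored = BM25EmbeddingFunction(analyzer)
restored.load("bm25_params.json")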
--------------------------------------------------------------------------------
/src/pymilvus/model/sparse/bm25/lang.yaml:
--------------------------------------------------------------------------------
en:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'english'
    - class: StemmingFilter
      params:
        language: 'english'
de:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'german'
    - class: PunctuationFilter
      params: {}
    - class: StemmingFilter
      params:
        language: 'german'
fr:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'french'
    - class: StemmingFilter
      params:
        language: 'french'
ru:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'russian'
    - class: StemmingFilter
      params:
        language: 'russian'
sp:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params:
        extras: '¡¿'
    - class: StopwordFilter
      params:
        language: 'spanish'
    - class: StemmingFilter
      params:
        language: 'spanish'
it:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'italian'
    - class: StemmingFilter
      params:
        language: 'italian'
pt:
  tokenizer:
    class: StandardTokenizer
    params: {}
  filters:
    - class: LowercaseFilter
      params: {}
    - class: PunctuationFilter
      params: {}
    - class: StopwordFilter
      params:
        language: 'portuguese'
    - class: StemmingFilter
      params:
        language: 'portuguese'
zh:
  tokenizer:
    class: JiebaTokenizer
    params: {}
  filters:
    - class: StopwordFilter
      params:
        language: 'chinese'
    - class: PunctuationFilter
      params:
        # The doubled '' below is a YAML-escaped literal single quote.
        extras: ' 、"#$%&''()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰–—‘’‛“”„‟…‧﹏﹑﹔·.!?。。'
jp:
  tokenizer:
    class: MecabTokenizer
    params: {}
  preprocessors:
    - class: CharacterfilterPreprocessor
      params:
        chars_to_replace: ['、', '。', '「', '」', '『', '』', '【', '】', '(', ')', '{', '}', '・', ':', ';', '!', '?', 'ー', '〜', '…', '‥', '[', ']']
  filters:
    - class: StopwordFilter
      params: {}
    - class: PunctuationFilter
      params: {}
kr:
  tokenizer:
    class: KonlpyTokenizer
    params: {}
  filters:
    - class: StopwordFilter
      params: {}
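
The keys above are exactly the language codes accepted by build_default_analyzer; note the non-ISO spellings sp, jp, and kr for Spanish, Japanese, and Korean. A small sketch of loading analyzers from this schema ("my_lang.yaml" and its "en_nostem" key are hypothetical names for illustration):

from pymilvus.model.sparse.bm25 import build_analyzer_from_yaml, build_default_analyzer

# Built-in configuration shipped with the package:
en_analyzer = build_default_analyzer(language="en")

# A user-supplied file following the same tokenizer/preprocessors/filters schema:
custom_analyzer = build_analyzer_from_yaml("my_lang.yaml", "en_nostem")
tokens = custom_analyzer("The quick brown foxes are jumping.")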
--------------------------------------------------------------------------------
/src/pymilvus/model/sparse/bm25/tokenizers.py:
--------------------------------------------------------------------------------
import logging
import re
import string
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, List, Match, Optional, Type

from pymilvus.model.utils import (
    import_jieba,
    import_kiwi,
    import_konlpy,
    import_mecab,
    import_nltk,
    import_unidic_lite,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

_class_registry = {}


def register_class(register_as: str):
    def decorator(cls: Type[Any]):
        _class_registry[register_as] = cls
        return cls

    return decorator


class Preprocessor:
    def apply(self, text: str):
        error_message = "Each preprocessor must implement its 'apply' method."
        raise NotImplementedError(error_message)


@register_class("CharacterfilterPreprocessor")
class CharacterfilterPreprocessor(Preprocessor):
    def __init__(self, chars_to_replace: List[str]):
        # Accepts any iterable of characters (the shipped config passes a list).
        self.replacement_table = str.maketrans({char: " " for char in chars_to_replace})

    def apply(self, text: str):
        return text.translate(self.replacement_table)


@register_class("ReplacePreprocessor")
class ReplacePreprocessor(Preprocessor):
    def __init__(self, replacement_mapping: Dict[str, str]):
        self.replacement_mapping = replacement_mapping
        self.pattern = re.compile("|".join(map(re.escape, replacement_mapping.keys())))

    def _replacement_function(self, match: Match):
        return self.replacement_mapping[match.group(0)]

    def apply(self, text: str):
        return self.pattern.sub(self._replacement_function, text)


@register_class("StandardTokenizer")
class StandardTokenizer:
    def __init__(self):
        import_nltk()
        import nltk
        from nltk import word_tokenize

        try:
            word_tokenize("this is a simple test.")
        except LookupError:
            nltk.download("punkt_tab")

    def tokenize(self, text: str):
        from nltk import word_tokenize

        return word_tokenize(text)


class TextFilter:
    def apply(self, tokens: List[str]):
        error_message = "Each filter must implement the 'apply' method."
        raise NotImplementedError(error_message)


@register_class("LowercaseFilter")
class LowercaseFilter(TextFilter):
    def apply(self, tokens: List[str]):
        return [token.lower() for token in tokens]


@register_class("StopwordFilter")
class StopwordFilter(TextFilter):
    def __init__(self, language: str = "english", stopword_list: Optional[List[str]] = None):
        import_nltk()
        import nltk
        from nltk.corpus import stopwords

        try:
            stopwords.words(language)
        except LookupError:
            nltk.download("stopwords")

        if stopword_list is None:
            stopword_list = []
        self.stopwords = set(stopwords.words(language) + stopword_list)

    def apply(self, tokens: List[str]):
        return [token for token in tokens if token not in self.stopwords]


@register_class("PunctuationFilter")
class PunctuationFilter(TextFilter):
    def __init__(self, extras: str = ""):
        self.punctuation = set(string.punctuation + extras)

    def apply(self, tokens: List[str]):
        return [token for token in tokens if token not in self.punctuation]


@register_class("StemmingFilter")
class StemmingFilter(TextFilter):
    def __init__(self, language: str = "english"):
        import_nltk()
        from nltk.stem.snowball import SnowballStemmer

        self.stemmer = SnowballStemmer(language)

    def apply(self, tokens: List[str]):
        return [self.stemmer.stem(token) for token in tokens]


class Tokenizer:
    def tokenize(self, text: str):
        error_message = "Each tokenizer must implement its 'tokenize' method."
        raise NotImplementedError(error_message)


@register_class("JiebaTokenizer")
class JiebaTokenizer(Tokenizer):
    def __init__(self):
        import_jieba()
        if find_spec("jieba") is None:
            error_message = "jieba is required for JiebaTokenizer but is not installed. Please install it using 'pip install jieba'."
            logger.error(error_message)
            raise ImportError(error_message)

    def tokenize(self, text: str):
        import jieba

        return jieba.lcut(text)


@register_class("MecabTokenizer")
class MecabTokenizer(Tokenizer):
    def __init__(self):
        import_unidic_lite()
        import_mecab()
        if find_spec("MeCab") is None:
            error_message = "MeCab is required for MecabTokenizer but is not installed. Please install it using 'pip install mecab-python3'."
            logger.error(error_message)
            raise ImportError(error_message)
        import MeCab

        # Build the tagger once; constructing it on every call is expensive.
        self.wakati = MeCab.Tagger("-Owakati")

    def tokenize(self, text: str):
        return self.wakati.parse(text).split()


@register_class("KonlpyTokenizer")
class KonlpyTokenizer(Tokenizer):
    def __init__(self):
        import_konlpy()
        if find_spec("konlpy") is None:
            error_message = "konlpy is required for KonlpyTokenizer but is not installed. Please install it using 'pip install konlpy'."
            logger.error(error_message)
            raise ImportError(error_message)
        from konlpy.tag import Kkma

        # Kkma startup (it boots a JVM) is costly, so instantiate it once.
        self.kkma = Kkma()

    def tokenize(self, text: str):
        return self.kkma.nouns(text)


@register_class("KiwiTokenizer")
class KiwiTokenizer(Tokenizer):
    def __init__(self):
        import_kiwi()
        if find_spec("kiwipiepy") is None:
            error_message = "kiwipiepy is required for KiwiTokenizer but is not installed. Please install it using 'pip install kiwipiepy'."
            logger.error(error_message)
            raise ImportError(error_message)
        from kiwipiepy import Kiwi

        # Load the model once instead of on every tokenize() call.
        self.kiwi = Kiwi()

    def tokenize(self, text: str):
        return [t.form for t in self.kiwi.tokenize(text, normalize_coda=True)]


class Analyzer:
    def __init__(
        self,
        name: str,
        tokenizer: Tokenizer,
        preprocessors: Optional[List[Preprocessor]] = None,
        filters: Optional[List[TextFilter]] = None,
    ):
        self.name = name
        self.tokenizer = tokenizer
        # Default to empty lists so __call__ does not fail when either is omitted.
        self.preprocessors = preprocessors or []
        self.filters = filters or []

    def __call__(self, text: str):
        for preprocessor in self.preprocessors:
            text = preprocessor.apply(text)
        tokens = self.tokenizer.tokenize(text)
        for _filter in self.filters:
            tokens = _filter.apply(tokens)
        return tokens


def build_default_analyzer(language: str = "en"):
    default_config_path = Path(__file__).parent / "lang.yaml"
    return build_analyzer_from_yaml(default_config_path, language)


def build_analyzer_from_yaml(filepath: str, name: str):
    import yaml

    with Path(filepath).open(encoding="utf-8") as file:
        config = yaml.safe_load(file)

    lang_config = config.get(name)
    if not lang_config:
        error_message = f"No analyzer configuration found for '{name}' in {filepath}."
        raise ValueError(error_message)

    tokenizer_class_type = _class_registry[lang_config["tokenizer"]["class"]]
    tokenizer_params = lang_config["tokenizer"]["params"]

    tokenizer = tokenizer_class_type(**tokenizer_params)
    preprocessors = []
    filters = []
    if "preprocessors" in lang_config:
        preprocessors = [
            _class_registry[preprocessor_config["class"]](**preprocessor_config["params"])
            for preprocessor_config in lang_config["preprocessors"]
        ]
    if "filters" in lang_config:
        filters = [
            _class_registry[filter_config["class"]](**filter_config["params"])
            for filter_config in lang_config["filters"]
        ]

    return Analyzer(name=name, tokenizer=tokenizer, preprocessors=preprocessors, filters=filters)
--------------------------------------------------------------------------------
/src/pymilvus/model/sparse/splade.py:
--------------------------------------------------------------------------------
"""
The following code is adapted from/inspired by the 'neural-cherche' project:
https://github.com/raphaelsty/neural-cherche
Specifically, neural-cherche/neural_cherche/models/splade.py

MIT License

Copyright (c) 2023 Raphael Sourty

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import logging
from typing import List, Optional

from scipy.sparse import csr_array

from pymilvus.model.base import BaseEmbeddingFunction

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class SpladeEmbeddingFunction(BaseEmbeddingFunction):
    model_name: str

    def __init__(
        self,
        model_name: str = "naver/splade-cocondenser-ensembledistil",
        batch_size: int = 32,
        query_instruction: str = "",
        doc_instruction: str = "",
        device: Optional[str] = "cpu",
        k_tokens_query: Optional[int] = None,
        k_tokens_document: Optional[int] = None,
        **kwargs,
    ):
        from .splade_embedding.splade_impl import _SpladeImplementation

        self.model_name = model_name

        _model_config = dict(
            {"model_name_or_path": model_name, "batch_size": batch_size, "device": device},
            **kwargs,
        )
        self._model_config = _model_config
        self.model = _SpladeImplementation(**self._model_config)
        self.device = device
        self.k_tokens_query = k_tokens_query
        self.k_tokens_document = k_tokens_document
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction

    def __call__(self, texts: List[str]) -> csr_array:
        return self._encode(texts, None)

    def encode_documents(self, documents: List[str]) -> csr_array:
        return self._encode(
            [self.doc_instruction + document for document in documents],
            self.k_tokens_document,
        )

    def encode_queries(self, queries: List[str]) -> csr_array:
        return self._encode(
            [self.query_instruction + query for query in queries],
            self.k_tokens_query,
        )

    def _encode(self, texts: List[str], k_tokens: Optional[int]) -> csr_array:
        return self.model.forward(texts, k_tokens=k_tokens)

    @property
    def dim(self) -> int:
        return len(self.model.tokenizer)

    def _encode_query(self, query: str) -> csr_array:
        return self.model.forward([self.query_instruction + query], k_tokens=self.k_tokens_query)[0]

    def _encode_document(self, document: str) -> csr_array:
        return self.model.forward(
            [self.doc_instruction + document], k_tokens=self.k_tokens_document
        )[0]
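
A usage sketch for the wrapper above (illustrative; the model name is the class default, and k_tokens_document is an optional pruning knob, not a requirement):

from pymilvus.model.sparse.splade import SpladeEmbeddingFunction

splade_ef = SpladeEmbeddingFunction(
    model_name="naver/splade-cocondenser-ensembledistil",
    device="cpu",
    k_tokens_document=256,  # keep only the 256 strongest activations per document
)

docs = ["SPLADE expands each document into weighted vocabulary terms."]
doc_embs = splade_ef.encode_documents(docs)        # csr_array, shape (1, splade_ef.dim)
query_embs = splade_ef.encode_queries(["what is splade?"])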
27 | """ 28 | 29 | from typing import Dict, List, Optional 30 | 31 | from pymilvus.model.utils import import_transformers, import_scipy, import_torch 32 | from pymilvus.model.sparse.utils import stack_sparse_embeddings 33 | 34 | import_torch() 35 | import_scipy() 36 | import_transformers() 37 | 38 | import torch 39 | from scipy.sparse import csr_array 40 | from transformers import AutoModelForMaskedLM, AutoTokenizer 41 | 42 | 43 | class _SpladeImplementation: 44 | def __init__( 45 | self, 46 | model_name_or_path: Optional[str] = None, 47 | device: Optional[str] = None, 48 | batch_size: int = 32, 49 | **kwargs, 50 | ): 51 | self.device = device 52 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 53 | self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **kwargs) 54 | self.model.to(self.device) 55 | self.batch_size = batch_size 56 | 57 | self.relu = torch.nn.ReLU() 58 | self.relu.to(self.device) 59 | self.model.config.output_hidden_states = True 60 | 61 | def _encode(self, texts: List[str]): 62 | encoded_input = self.tokenizer.batch_encode_plus( 63 | texts, 64 | truncation=True, 65 | max_length=self.tokenizer.model_max_length, 66 | return_tensors="pt", 67 | add_special_tokens=True, 68 | padding=True, 69 | ) 70 | encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()} 71 | output = self.model(**encoded_input) 72 | return output.logits 73 | 74 | def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]: 75 | return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)] 76 | 77 | def forward(self, texts: List[str], k_tokens: int) -> csr_array: 78 | with torch.no_grad(): 79 | batched_texts = self._batchify(texts, self.batch_size) 80 | sparse_embs = [] 81 | for batch_texts in batched_texts: 82 | logits = self._encode(texts=batch_texts) 83 | activations = self._get_activation(logits=logits) 84 | if k_tokens is None: 85 | nonzero_indices = torch.nonzero(activations["sparse_activations"]) 86 | activations["activations"] = nonzero_indices 87 | else: 88 | activations = self._update_activations(**activations, k_tokens=k_tokens) 89 | batch_csr = self._convert_to_csr_array(activations) 90 | sparse_embs.extend(batch_csr) 91 | return stack_sparse_embeddings(sparse_embs).tocsr() 92 | 93 | def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]: 94 | return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)} 95 | 96 | def _update_activations(self, sparse_activations: torch.Tensor, k_tokens: int) -> torch.Tensor: 97 | activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices 98 | 99 | # Set value of max sparse_activations which are not in top k to 0. 
--------------------------------------------------------------------------------
/src/pymilvus/model/sparse/utils.py:
--------------------------------------------------------------------------------
from typing import List

from scipy.sparse import csr_array, vstack


def stack_sparse_embeddings(sparse_embs: List[csr_array]):
    # Normalize each embedding to a single row, then stack the rows vertically;
    # callers convert the stacked result to CSR with .tocsr().
    return vstack([sparse_emb.reshape((1, -1)) for sparse_emb in sparse_embs])
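
A quick sketch of the helper's contract (values are illustrative): each input becomes one row of the stacked output.

import numpy as np
from scipy.sparse import csr_array
from pymilvus.model.sparse.utils import stack_sparse_embeddings

rows = [
    csr_array(np.array([[0.0, 1.5, 0.0]], dtype=np.float32)),
    csr_array(np.array([[2.0, 0.0, 0.3]], dtype=np.float32)),
]
stacked = stack_sparse_embeddings(rows).tocsr()  # shape (2, 3)
print(stacked.toarray())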
--------------------------------------------------------------------------------
/src/pymilvus/model/utils/__init__.py:
--------------------------------------------------------------------------------
__all__ = [
    "import_openai",
    "import_sentence_transformers",
    "import_FlagEmbedding",
    "import_nltk",
    "import_transformers",
    "import_jieba",
    "import_konlpy",
    "import_kiwi",
    "import_mecab",
    "import_scipy",
    "import_protobuf",
    "import_unidic_lite",
    "import_cohere",
    "import_voyageai",
    "import_torch",
    "import_huggingface_hub",
    "import_mistralai",
    "import_nomic",
    "import_instructor",
    "import_datasets",
    "import_model2vec",
    "import_google",
]

import importlib.util
from typing import Optional

from pymilvus.model.utils.dependency_control import prompt_install

def import_google():
    # The import name is "google.genai"; "google-genai" is only the PyPI name.
    _check_library("google.genai", package="google-genai>=1.7.0")

def import_openai():
    _check_library("openai", package="openai>=1.12.0")

def import_sentence_transformers():
    _check_library("sentence_transformers", package="sentence-transformers")

def import_FlagEmbedding():
    _check_library("peft", package="peft")
    _check_library("FlagEmbedding", package="FlagEmbedding>=1.3.3")

def import_nltk():
    _check_library("nltk", package="nltk>=3.9.1")

def import_transformers():
    _check_library("transformers", package="transformers>=4.36.0")

def import_jieba():
    _check_library("jieba", package="jieba")

def import_konlpy():
    _check_library("konlpy", package="konlpy")

def import_kiwi():
    _check_library("kiwipiepy", package="kiwipiepy")

def import_mecab():
    # Check for the "MeCab" module (not "konlpy"); it ships in mecab-python3.
    _check_library("MeCab", package="mecab-python3")

def import_scipy():
    _check_library("scipy", package="scipy>=1.10.0")

def import_protobuf():
    # The import name is "google.protobuf", not "protobuf".
    _check_library("google.protobuf", package="protobuf==3.20.2")

def import_unidic_lite():
    # The module name uses an underscore; the PyPI name uses a hyphen.
    _check_library("unidic_lite", package="unidic-lite")

def import_cohere():
    _check_library("cohere", package="cohere>=5.10.0")

def import_voyageai():
    _check_library("voyageai", package="voyageai>=0.2.0")

def import_torch():
    _check_library("torch", package="torch")

def import_huggingface_hub():
    _check_library("huggingface_hub", package="huggingface-hub")

def import_mistralai():
    _check_library("mistralai", package="mistralai")

def import_nomic():
    _check_library("nomic", package="nomic")

def import_instructor():
    _check_library("InstructorEmbedding", package="InstructorEmbedding")

def import_datasets():
    _check_library("datasets", package="datasets")

def import_model2vec():
    _check_library("model2vec", package="model2vec")

def _check_library(libname: str, prompt: bool = True, package: Optional[str] = None):
    is_avail = importlib.util.find_spec(libname) is not None
    if not is_avail and prompt:
        prompt_install(package if package else libname)
    return is_avail
--------------------------------------------------------------------------------
/src/pymilvus/model/utils/dependency_control.py:
--------------------------------------------------------------------------------
import subprocess


def prompt_install(package: str, warn: bool = False):  # pragma: no cover
    cmd = f"pip install -q {package}"
    try:
        if warn and input(f"Install {package}? Y/n: ") != "Y":
            raise ModuleNotFoundError(f"No module named {package}")
        print(f"start to install package: {package}")
        subprocess.check_call(cmd, shell=True)
        print(f"successfully installed package: {package}")
    except subprocess.CalledProcessError as e:
        raise ValueError(f"Failed to install {package}: {e}") from e
--------------------------------------------------------------------------------
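
The lazy-dependency pattern above is what every optional integration goes through: a cheap find_spec check on the import name, then prompt_install with the (possibly different) PyPI name and version pin. A sketch of how a new integration would typically use it; the import_foo helper and the foo package are hypothetical:

import importlib.util

from pymilvus.model.utils import import_torch
from pymilvus.model.utils.dependency_control import prompt_install


def import_foo():
    # Mirrors the helpers above: check the import name, install the PyPI
    # package; the two can differ (e.g. "MeCab" vs. "mecab-python3").
    if importlib.util.find_spec("foo") is None:
        prompt_install("foo>=1.0")


# Call the guard right before the heavyweight import:
import_torch()
import torch  # import deferred until the dependency is ensured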