├── .github ├── actions │ └── poetry_setup │ │ └── action.yml ├── scripts │ ├── check_diff.py │ └── get_min_versions.py └── workflows │ ├── _codespell.yml │ ├── _integration_test.yml │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ ├── _test_release.yml │ ├── check_diffs.yml │ └── extract_ignored_words_list.py ├── LICENSE ├── README.md └── libs └── elasticsearch ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── langchain_elasticsearch ├── __init__.py ├── _async │ ├── cache.py │ ├── chat_history.py │ ├── embeddings.py │ ├── retrievers.py │ └── vectorstores.py ├── _sync │ ├── cache.py │ ├── chat_history.py │ ├── embeddings.py │ ├── retrievers.py │ └── vectorstores.py ├── _utilities.py ├── cache.py ├── chat_history.py ├── client.py ├── embeddings.py ├── py.typed ├── retrievers.py └── vectorstores.py ├── poetry.lock ├── pyproject.toml ├── scripts ├── check_imports.py ├── lint_imports.sh └── run_unasync.py └── tests ├── __init__.py ├── _async └── fake_embeddings.py ├── _sync └── fake_embeddings.py ├── conftest.py ├── fake_embeddings.py ├── integration_tests ├── __init__.py ├── _async │ ├── __init__.py │ ├── _test_utilities.py │ ├── test_cache.py │ ├── test_chat_history.py │ ├── test_embeddings.py │ ├── test_retrievers.py │ └── test_vectorstores.py ├── _sync │ ├── __init__.py │ ├── _test_utilities.py │ ├── test_cache.py │ ├── test_chat_history.py │ ├── test_embeddings.py │ ├── test_retrievers.py │ └── test_vectorstores.py └── docker-compose.yml └── unit_tests ├── __init__.py ├── _async ├── __init__.py ├── test_cache.py └── test_vectorstores.py ├── _sync ├── __init__.py ├── test_cache.py └── test_vectorstores.py └── test_imports.py /.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 
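# For example, with Python 3.11 and Poetry 1.7.1 on an ubuntu-latest (Linux x64) runner,
# the key below resolves to: bin-poetry-Linux-X64-py-3.11-1.7.1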
44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} 94 | -------------------------------------------------------------------------------- /.github/scripts/check_diff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import Dict 4 | 5 | LIB_DIRS = ["libs/elasticsearch"] 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | 10 | dirs_to_run: Dict[str, set] = { 11 | "lint": set(), 12 | "test": set(), 13 | } 14 | 15 | if len(files) == 300: 16 | # max diff length is 300 files - there are likely files missing 17 | raise ValueError("Max diff reached. Please manually run CI on changed libs.") 18 | 19 | for file in files: 20 | if any( 21 | file.startswith(dir_) 22 | for dir_ in ( 23 | ".github/workflows", 24 | ".github/tools", 25 | ".github/actions", 26 | ".github/scripts/check_diff.py", 27 | ) 28 | ): 29 | # add all LANGCHAIN_DIRS for infra changes 30 | dirs_to_run["test"].update(LIB_DIRS) 31 | 32 | if any(file.startswith(dir_) for dir_ in LIB_DIRS): 33 | for dir_ in LIB_DIRS: 34 | if file.startswith(dir_): 35 | dirs_to_run["test"].add(dir_) 36 | elif file.startswith("libs/"): 37 | raise ValueError( 38 | f"Unknown lib: {file}. check_diff.py likely needs " 39 | "an update for this new library!" 
40 | )
41 |
42 | outputs = {
43 | "dirs-to-lint": list(dirs_to_run["lint"] | dirs_to_run["test"]),
44 | "dirs-to-test": list(dirs_to_run["test"]),
45 | }
46 | for key, value in outputs.items():
47 | json_output = json.dumps(value)
48 | print(f"{key}={json_output}") # noqa: T201
49 | -------------------------------------------------------------------------------- /.github/scripts/get_min_versions.py: --------------------------------------------------------------------------------
1 | import sys
2 |
3 | if sys.version_info >= (3, 11):
4 | import tomllib
5 | else:
6 | # for python 3.10 and below, which doesn't have stdlib tomllib
7 | import tomli as tomllib
8 |
9 | from packaging.version import parse as parse_version
10 | import re
11 |
12 | MIN_VERSION_LIBS = ["langchain-core"]
13 |
14 | SKIP_IF_PULL_REQUEST = ["langchain-core"]
15 |
16 |
17 | def get_min_version(version: str) -> str:
18 | # base regex for x.x.x with cases for rc/post/etc
19 | # valid strings: https://peps.python.org/pep-0440/#public-version-identifiers
20 | vstring = r"\d+(?:\.\d+){0,2}(?:(?:a|b|rc|\.post|\.dev)\d+)?"
21 | # case ^x.x.x
22 | _match = re.match(f"^\\^({vstring})$", version)
23 | if _match:
24 | return _match.group(1)
25 |
26 | # case >=x.x.x,<y.y.y
27 | _match = re.match(f"^>=({vstring}),<({vstring})$", version)
28 | if _match:
29 | _min = _match.group(1)
30 | _max = _match.group(2)
31 | assert parse_version(_min) < parse_version(_max)
32 | return _min
33 |
34 | # case x.x.x
35 | _match = re.match(f"^({vstring})$", version)
36 | if _match:
37 | return _match.group(1)
38 |
39 | raise ValueError(f"Unrecognized version format: {version}")
40 |
41 |
42 | def get_min_version_from_toml(toml_path: str, versions_for: str):
43 | # Parse the TOML file
44 | with open(toml_path, "rb") as file:
45 | toml_data = tomllib.load(file)
46 |
47 | # Get the dependencies from tool.poetry.dependencies
48 | dependencies = toml_data["tool"]["poetry"]["dependencies"]
49 |
50 | # Initialize a dictionary to store the minimum versions
51 | min_versions = {}
52 |
53 | # Iterate over the libs in MIN_VERSION_LIBS
54 | for lib in MIN_VERSION_LIBS:
55 | if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST:
56 | # some libs only get checked on release because of simultaneous
57 | # changes
58 | continue
59 | # Check if the lib is present in the dependencies
60 | if lib in dependencies:
61 | # Get the version string
62 | version_string = dependencies[lib]
63 |
64 | if isinstance(version_string, dict):
65 | version_string = version_string["version"]
66 |
67 | # Use parse_version to get the minimum supported version from version_string
68 | min_version = get_min_version(version_string)
69 |
70 | # Store the minimum version in the min_versions dictionary
71 | min_versions[lib] = min_version
72 |
73 | return min_versions
74 |
75 |
76 | if __name__ == "__main__":
77 | # Get the TOML file path from the command line argument
78 | toml_file = sys.argv[1]
79 | versions_for = sys.argv[2]
80 | assert versions_for in ["release", "pull_request"]
81 |
82 | # Call the function to get the minimum versions
83 | min_versions = get_min_version_from_toml(toml_file, versions_for)
84 |
85 | print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()]))
86 | -------------------------------------------------------------------------------- /.github/workflows/_codespell.yml: --------------------------------------------------------------------------------
1 | ---
2 | name: make spell_check
3 |
4 | on:
5 | workflow_call:
6 | inputs:
7 | working-directory:
8 | required: true
9
| type: string 10 | description: "From which folder this pipeline executes" 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | codespell: 17 | name: (Check for spelling errors) 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Dependencies 25 | run: | 26 | pip install toml 27 | 28 | - name: Extract Ignore Words List 29 | working-directory: ${{ inputs.working-directory }} 30 | run: | 31 | # Use a Python script to extract the ignore words list from pyproject.toml 32 | python ../../.github/workflows/extract_ignored_words_list.py 33 | id: extract_ignore_words 34 | 35 | - name: Codespell 36 | uses: codespell-project/actions-codespell@v2 37 | with: 38 | skip: guide_imports.json 39 | ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }} 40 | -------------------------------------------------------------------------------- /.github/workflows/_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: integration-test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | required: true 14 | type: string 15 | description: "From which folder this pipeline executes" 16 | 17 | env: 18 | POETRY_VERSION: "1.7.1" 19 | 20 | jobs: 21 | build: 22 | name: "make integration_tests" 23 | defaults: 24 | run: 25 | working-directory: ${{ inputs.working-directory }} 26 | runs-on: ubuntu-latest 27 | strategy: 28 | matrix: 29 | python-version: 30 | - "3.9" 31 | - "3.10" 32 | - "3.11" 33 | services: 34 | elasticsearch: 35 | image: elasticsearch:8.13.0 36 | env: 37 | discovery.type: single-node 38 | xpack.license.self_generated.type: trial 39 | xpack.security.enabled: false # disable password and TLS; never do this in production! 40 | ports: 41 | - 9200:9200 42 | options: >- 43 | --health-cmd "curl --fail http://localhost:9200/_cluster/health" 44 | --health-start-period 10s 45 | --health-timeout 3s 46 | --health-interval 3s 47 | --health-retries 10 48 | steps: 49 | - uses: actions/checkout@v4 50 | 51 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 52 | uses: "./.github/actions/poetry_setup" 53 | with: 54 | python-version: ${{ matrix.python-version }} 55 | poetry-version: ${{ env.POETRY_VERSION }} 56 | working-directory: ${{ inputs.working-directory }} 57 | cache-key: integration-tests 58 | 59 | - name: Install dependencies 60 | shell: bash 61 | run: poetry install --with=test_integration,test 62 | 63 | - name: Run integration tests 64 | shell: bash 65 | run: make integration_test 66 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 14 | 15 | # This env var allows us to get inline annotations when ruff has complaints. 
16 | RUFF_OUTPUT_FORMAT: github 17 | 18 | jobs: 19 | build: 20 | name: "make lint #${{ matrix.python-version }}" 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 31 | python-version: 32 | - "3.8" 33 | - "3.11" 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 38 | uses: "./.github/actions/poetry_setup" 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | poetry-version: ${{ env.POETRY_VERSION }} 42 | working-directory: ${{ inputs.working-directory }} 43 | cache-key: lint-with-extras 44 | 45 | - name: Check Poetry File 46 | shell: bash 47 | working-directory: ${{ inputs.working-directory }} 48 | run: | 49 | poetry check 50 | 51 | - name: Check lock file 52 | shell: bash 53 | working-directory: ${{ inputs.working-directory }} 54 | run: | 55 | poetry lock --check 56 | 57 | - name: Install dependencies 58 | # Also installs dev/lint/test/typing dependencies, to ensure we have 59 | # type hints for as many of our libraries as possible. 60 | # This helps catch errors that require dependencies to be spotted, for example: 61 | # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 62 | # 63 | # If you change this configuration, make sure to change the `cache-key` 64 | # in the `poetry_setup` action above to stop using the old cache. 65 | # It doesn't matter how you change it, any change will cause a cache-bust. 
66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | poetry install --with lint,typing 69 | 70 | - name: Get .mypy_cache to speed up mypy 71 | uses: actions/cache@v4 72 | env: 73 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 74 | with: 75 | path: | 76 | ${{ env.WORKDIR }}/.mypy_cache 77 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 78 | 79 | 80 | - name: Analysing the code with our lint 81 | working-directory: ${{ inputs.working-directory }} 82 | run: | 83 | make lint_package 84 | 85 | - name: Install unit+integration test dependencies 86 | working-directory: ${{ inputs.working-directory }} 87 | run: | 88 | poetry install --with test,test_integration 89 | 90 | - name: Get .mypy_cache_test to speed up mypy 91 | uses: actions/cache@v4 92 | env: 93 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 94 | with: 95 | path: | 96 | ${{ env.WORKDIR }}/.mypy_cache_test 97 | key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 98 | 99 | - name: Analysing the code with our lint 100 | working-directory: ${{ inputs.working-directory }} 101 | run: | 102 | make lint_tests 103 | -------------------------------------------------------------------------------- /.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | required: true 14 | type: string 15 | default: 'libs/elasticsearch' 16 | dangerous-nonmaster-release: 17 | required: false 18 | type: boolean 19 | default: false 20 | description: "Release from a non-master branch (danger!)" 21 | 22 | env: 23 | PYTHON_VERSION: "3.11" 24 | POETRY_VERSION: "1.7.1" 25 | 26 | jobs: 27 | build: 28 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmaster-release 29 | runs-on: ubuntu-latest 30 | 31 | outputs: 32 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 33 | version: ${{ steps.check-version.outputs.version }} 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | 38 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 39 | uses: "./.github/actions/poetry_setup" 40 | with: 41 | python-version: ${{ env.PYTHON_VERSION }} 42 | poetry-version: ${{ env.POETRY_VERSION }} 43 | working-directory: ${{ inputs.working-directory }} 44 | cache-key: release 45 | 46 | # We want to keep this build stage *separate* from the release stage, 47 | # so that there's no sharing of permissions between them. 48 | # The release stage has trusted publishing and GitHub repo contents write access, 49 | # and we want to keep the scope of that access limited just to the release job. 50 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 51 | # could get access to our GitHub or PyPI credentials. 52 | # 53 | # Per the trusted publishing GitHub Action: 54 | # > It is strongly advised to separate jobs for building [...] 55 | # > from the publish job. 
56 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 57 | - name: Build project for distribution 58 | run: poetry build 59 | working-directory: ${{ inputs.working-directory }} 60 | 61 | - name: Upload build 62 | uses: actions/upload-artifact@v4 63 | with: 64 | name: dist 65 | path: ${{ inputs.working-directory }}/dist/ 66 | 67 | - name: Check Version 68 | id: check-version 69 | shell: bash 70 | working-directory: ${{ inputs.working-directory }} 71 | run: | 72 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 73 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 74 | 75 | test-pypi-publish: 76 | needs: 77 | - build 78 | uses: 79 | ./.github/workflows/_test_release.yml 80 | permissions: write-all 81 | with: 82 | working-directory: ${{ inputs.working-directory }} 83 | dangerous-nonmaster-release: ${{ inputs.dangerous-nonmaster-release }} 84 | secrets: inherit 85 | 86 | pre-release-checks: 87 | needs: 88 | - build 89 | - test-pypi-publish 90 | runs-on: ubuntu-latest 91 | services: 92 | elasticsearch: 93 | image: elasticsearch:8.13.0 94 | env: 95 | discovery.type: single-node 96 | xpack.license.self_generated.type: trial 97 | xpack.security.enabled: false # disable password and TLS; never do this in production! 98 | ports: 99 | - 9200:9200 100 | options: >- 101 | --health-cmd "curl --fail http://localhost:9200/_cluster/health" 102 | --health-start-period 10s 103 | --health-timeout 3s 104 | --health-interval 3s 105 | --health-retries 10 106 | steps: 107 | - uses: actions/checkout@v4 108 | 109 | # We explicitly *don't* set up caching here. This ensures our tests are 110 | # maximally sensitive to catching breakage. 111 | # 112 | # For example, here's a way that caching can cause a falsely-passing test: 113 | # - Make the langchain package manifest no longer list a dependency package 114 | # as a requirement. This means it won't be installed by `pip install`, 115 | # and attempting to use it would cause a crash. 116 | # - That dependency used to be required, so it may have been cached. 117 | # When restoring the venv packages from cache, that dependency gets included. 118 | # - Tests pass, because the dependency is present even though it wasn't specified. 119 | # - The package is published, and it breaks on the missing dependency when 120 | # used in the real world. 121 | 122 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 123 | uses: "./.github/actions/poetry_setup" 124 | with: 125 | python-version: ${{ env.PYTHON_VERSION }} 126 | poetry-version: ${{ env.POETRY_VERSION }} 127 | working-directory: ${{ inputs.working-directory }} 128 | 129 | - name: Import published package 130 | shell: bash 131 | working-directory: ${{ inputs.working-directory }} 132 | env: 133 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 134 | VERSION: ${{ needs.build.outputs.version }} 135 | # Here we use: 136 | # - The default regular PyPI index as the *primary* index, meaning 137 | # that it takes priority (https://pypi.org/simple) 138 | # - The test PyPI index as an extra index, so that any dependencies that 139 | # are not found on test PyPI can be resolved and installed anyway. 140 | # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION 141 | # package because VERSION will not have been uploaded to regular PyPI yet. 
142 | # - attempt install again after 5 seconds if it fails because there is 143 | # sometimes a delay in availability on test pypi 144 | run: | 145 | poetry run pip install \ 146 | --extra-index-url https://test.pypi.org/simple/ \ 147 | "$PKG_NAME==$VERSION" || \ 148 | ( \ 149 | sleep 5 && \ 150 | poetry run pip install \ 151 | --extra-index-url https://test.pypi.org/simple/ \ 152 | "$PKG_NAME==$VERSION" \ 153 | ) 154 | 155 | # Replace all dashes in the package name with underscores, 156 | # since that's how Python imports packages with dashes in the name. 157 | IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" 158 | 159 | poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" 160 | 161 | - name: Import test dependencies 162 | run: poetry install --with test,test_integration 163 | working-directory: ${{ inputs.working-directory }} 164 | 165 | # Overwrite the local version of the package with the test PyPI version. 166 | - name: Import published package (again) 167 | working-directory: ${{ inputs.working-directory }} 168 | shell: bash 169 | env: 170 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 171 | VERSION: ${{ needs.build.outputs.version }} 172 | run: | 173 | poetry run pip install \ 174 | --extra-index-url https://test.pypi.org/simple/ \ 175 | "$PKG_NAME==$VERSION" 176 | 177 | - name: Run unit tests 178 | run: make tests 179 | working-directory: ${{ inputs.working-directory }} 180 | 181 | - name: Run integration tests 182 | run: make integration_test 183 | working-directory: ${{ inputs.working-directory }} 184 | 185 | - name: Get minimum versions 186 | working-directory: ${{ inputs.working-directory }} 187 | id: min-version 188 | run: | 189 | poetry run pip install packaging 190 | min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml release)" 191 | echo "min-versions=$min_versions" >> "$GITHUB_OUTPUT" 192 | echo "min-versions=$min_versions" 193 | 194 | - name: Run unit tests with minimum dependency versions 195 | if: ${{ steps.min-version.outputs.min-versions != '' }} 196 | env: 197 | MIN_VERSIONS: ${{ steps.min-version.outputs.min-versions }} 198 | run: | 199 | poetry run pip install $MIN_VERSIONS 200 | make tests 201 | working-directory: ${{ inputs.working-directory }} 202 | 203 | publish: 204 | needs: 205 | - build 206 | - test-pypi-publish 207 | - pre-release-checks 208 | runs-on: ubuntu-latest 209 | permissions: 210 | # This permission is used for trusted publishing: 211 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 212 | # 213 | # Trusted publishing has to also be configured on PyPI for each package: 214 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 215 | id-token: write 216 | 217 | defaults: 218 | run: 219 | working-directory: ${{ inputs.working-directory }} 220 | 221 | steps: 222 | - uses: actions/checkout@v4 223 | 224 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 225 | uses: "./.github/actions/poetry_setup" 226 | with: 227 | python-version: ${{ env.PYTHON_VERSION }} 228 | poetry-version: ${{ env.POETRY_VERSION }} 229 | working-directory: ${{ inputs.working-directory }} 230 | cache-key: release 231 | 232 | - uses: actions/download-artifact@v4 233 | with: 234 | name: dist 235 | path: ${{ inputs.working-directory }}/dist/ 236 | 237 | - name: Publish package distributions to PyPI 238 | uses: pypa/gh-action-pypi-publish@release/v1 239 | with: 240 | packages-dir: ${{ inputs.working-directory }}/dist/ 241 | verbose: true 242 | print-hash: true 243 | # Temp 
workaround since attestations are on by default as of gh-action-pypi-publish v1\.11\.0 244 | attestations: false 245 | 246 | mark-release: 247 | needs: 248 | - build 249 | - test-pypi-publish 250 | - pre-release-checks 251 | - publish 252 | runs-on: ubuntu-latest 253 | permissions: 254 | # This permission is needed by `ncipollo/release-action` to 255 | # create the GitHub release. 256 | contents: write 257 | 258 | defaults: 259 | run: 260 | working-directory: ${{ inputs.working-directory }} 261 | 262 | steps: 263 | - uses: actions/checkout@v4 264 | 265 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 266 | uses: "./.github/actions/poetry_setup" 267 | with: 268 | python-version: ${{ env.PYTHON_VERSION }} 269 | poetry-version: ${{ env.POETRY_VERSION }} 270 | working-directory: ${{ inputs.working-directory }} 271 | cache-key: releaseartifact@v4 272 | 273 | - uses: actions/download-artifact@v4 274 | with: 275 | name: dist 276 | path: ${{ inputs.working-directory }}/dist/ 277 | 278 | - name: Create Release 279 | uses: ncipollo/release-action@v1 280 | with: 281 | artifacts: "dist/*" 282 | token: ${{ secrets.GITHUB_TOKEN }} 283 | draft: false 284 | generateReleaseNotes: true 285 | tag: ${{ inputs.working-directory }}/v${{ needs.build.outputs.version }} 286 | commit: ${{ github.sha }} 287 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.8" 24 | - "3.9" 25 | - "3.10" 26 | - "3.11" 27 | name: "make test #${{ matrix.python-version }}" 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 32 | uses: "./.github/actions/poetry_setup" 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | poetry-version: ${{ env.POETRY_VERSION }} 36 | working-directory: ${{ inputs.working-directory }} 37 | cache-key: core 38 | 39 | - name: Install dependencies 40 | shell: bash 41 | run: poetry install --with test 42 | 43 | - name: Run core tests 44 | shell: bash 45 | run: | 46 | make test 47 | 48 | - name: Ensure the tests did not create any additional files 49 | shell: bash 50 | run: | 51 | set -eu 52 | 53 | STATUS="$(git status)" 54 | echo "$STATUS" 55 | 56 | # grep will exit non-zero if the target message isn't found, 57 | # and `set -e` above will cause the step to fail. 
58 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 59 | -------------------------------------------------------------------------------- /.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | dangerous-nonmaster-release: 11 | required: false 12 | type: boolean 13 | default: false 14 | description: "Release from a non-master branch (danger!)" 15 | 16 | env: 17 | POETRY_VERSION: "1.7.1" 18 | PYTHON_VERSION: "3.10" 19 | 20 | jobs: 21 | build: 22 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmaster-release 23 | runs-on: ubuntu-latest 24 | 25 | outputs: 26 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 27 | version: ${{ steps.check-version.outputs.version }} 28 | 29 | steps: 30 | - uses: actions/checkout@v4 31 | 32 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 33 | uses: "./.github/actions/poetry_setup" 34 | with: 35 | python-version: ${{ env.PYTHON_VERSION }} 36 | poetry-version: ${{ env.POETRY_VERSION }} 37 | working-directory: ${{ inputs.working-directory }} 38 | cache-key: release 39 | 40 | # We want to keep this build stage *separate* from the release stage, 41 | # so that there's no sharing of permissions between them. 42 | # The release stage has trusted publishing and GitHub repo contents write access, 43 | # and we want to keep the scope of that access limited just to the release job. 44 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 45 | # could get access to our GitHub or PyPI credentials. 46 | # 47 | # Per the trusted publishing GitHub Action: 48 | # > It is strongly advised to separate jobs for building [...] 49 | # > from the publish job. 50 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 51 | - name: Build project for distribution 52 | run: poetry build 53 | working-directory: ${{ inputs.working-directory }} 54 | 55 | - name: Upload build 56 | uses: actions/upload-artifact@v4 57 | with: 58 | name: test-dist 59 | path: ${{ inputs.working-directory }}/dist/ 60 | 61 | - name: Check Version 62 | id: check-version 63 | shell: bash 64 | working-directory: ${{ inputs.working-directory }} 65 | run: | 66 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 67 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 68 | 69 | publish: 70 | needs: 71 | - build 72 | runs-on: ubuntu-latest 73 | permissions: 74 | # This permission is used for trusted publishing: 75 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 76 | # 77 | # Trusted publishing has to also be configured on PyPI for each package: 78 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 79 | id-token: write 80 | 81 | steps: 82 | - uses: actions/checkout@v4 83 | 84 | - uses: actions/download-artifact@v4 85 | with: 86 | name: test-dist 87 | path: ${{ inputs.working-directory }}/dist/ 88 | 89 | - name: Publish to test PyPI 90 | uses: pypa/gh-action-pypi-publish@release/v1 91 | with: 92 | packages-dir: ${{ inputs.working-directory }}/dist/ 93 | verbose: true 94 | print-hash: true 95 | repository-url: https://test.pypi.org/legacy/ 96 | 97 | # We overwrite any existing distributions with the same name and version. 98 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
99 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 100 | skip-existing: true 101 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0 102 | attestations: false 103 | -------------------------------------------------------------------------------- /.github/workflows/check_diffs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | # If another push to the same PR or branch happens while this workflow is still running, 10 | # cancel the earlier run in favor of the next run. 11 | # 12 | # There's no point in testing an outdated version of the code. GitHub only allows 13 | # a limited number of job runners to be active at the same time, so it's better to cancel 14 | # pointless jobs early so that more useful jobs can run sooner. 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | env: 20 | POETRY_VERSION: "1.7.1" 21 | 22 | jobs: 23 | build: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | - id: files 31 | uses: Ana06/get-changed-files@v2.2.0 32 | - id: set-matrix 33 | run: | 34 | python .github/scripts/check_diff.py ${{ steps.files.outputs.all }} >> $GITHUB_OUTPUT 35 | outputs: 36 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }} 37 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }} 38 | lint: 39 | name: cd ${{ matrix.working-directory }} 40 | needs: [ build ] 41 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }} 42 | strategy: 43 | matrix: 44 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }} 45 | uses: ./.github/workflows/_lint.yml 46 | with: 47 | working-directory: ${{ matrix.working-directory }} 48 | secrets: inherit 49 | 50 | test: 51 | name: cd ${{ matrix.working-directory }} 52 | needs: [ build ] 53 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 54 | strategy: 55 | matrix: 56 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 57 | uses: ./.github/workflows/_test.yml 58 | with: 59 | working-directory: ${{ matrix.working-directory }} 60 | secrets: inherit 61 | 62 | integration-test: 63 | name: cd ${{ matrix.working-directory }} 64 | needs: [ build ] 65 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 66 | strategy: 67 | matrix: 68 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 69 | uses: ./.github/workflows/_integration_test.yml 70 | with: 71 | working-directory: ${{ matrix.working-directory }} 72 | secrets: inherit 73 | 74 | ci_success: 75 | name: "CI Success" 76 | needs: [build, lint, test, integration-test] 77 | if: | 78 | always() 79 | runs-on: ubuntu-latest 80 | env: 81 | JOBS_JSON: ${{ toJSON(needs) }} 82 | RESULTS_JSON: ${{ toJSON(needs.*.result) }} 83 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}} 84 | steps: 85 | - name: "CI Success" 86 | run: | 87 | echo $JOBS_JSON 88 | echo $RESULTS_JSON 89 | echo "Exiting with $EXIT_CODE" 90 | exit $EXIT_CODE 91 | -------------------------------------------------------------------------------- /.github/workflows/extract_ignored_words_list.py: -------------------------------------------------------------------------------- 1 | import toml 2 | 3 | pyproject_toml = toml.load("pyproject.toml") 4 | 5 | # Extract the 
ignore words list (adjust the key as per your TOML structure) 6 | ignore_words_list = ( 7 | pyproject_toml.get("tool", {}).get("codespell", {}).get("ignore-words-list") 8 | ) 9 | 10 | print(f"::set-output name=ignore_words_list::{ignore_words_list}") 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜️🔗 LangChain Elastic 2 | 3 | This repository contains 1 package with Elasticsearch integrations with LangChain: 4 | 5 | - [langchain-elasticsearch](https://pypi.org/project/langchain-elasticsearch/) integrates [Elasticsearch](https://www.elastic.co/elasticsearch). 6 | - [ElasticsearchStore](https://python.langchain.com/docs/integrations/vectorstores/elasticsearch/) 7 | - [ElasticsearchRetriever](https://python.langchain.com/docs/integrations/retrievers/elasticsearch_retriever/) 8 | - [ElasticsearchEmbeddings](https://python.langchain.com/docs/integrations/text_embedding/elasticsearch/) 9 | - [ElasticsearchChatMessageHistory](https://python.langchain.com/docs/integrations/memory/elasticsearch_chat_message_history/) 10 | - [ElasticsearchCache](https://python.langchain.com/docs/integrations/llm_caching/#elasticsearch-cache) 11 | - [ElasticsearchEmbeddingsCache](https://python.langchain.com/docs/integrations/stores/elasticsearch/) 12 | -------------------------------------------------------------------------------- /libs/elasticsearch/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /libs/elasticsearch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/elasticsearch/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | install: 7 | poetry install 8 | 9 | # Define a variable for the test file path. 10 | TEST_FILE ?= tests/unit_tests/ 11 | integration_test integration_tests: TEST_FILE=tests/integration_tests/ 12 | 13 | test tests integration_test integration_tests: 14 | poetry run pytest $(TEST_FILE) 15 | 16 | 17 | ###################### 18 | # LINTING AND FORMATTING 19 | ###################### 20 | 21 | # Define a variable for Python and notebook files. 22 | PYTHON_FILES=. 23 | MYPY_CACHE=.mypy_cache 24 | lint format: PYTHON_FILES=. 25 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/elasticsearch --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 26 | lint_package: PYTHON_FILES=langchain_elasticsearch 27 | lint_tests: PYTHON_FILES=tests 28 | lint_tests: MYPY_CACHE=.mypy_cache_test 29 | 30 | lint lint_diff lint_package lint_tests: 31 | poetry run ruff . 
32 | poetry run ruff format $(PYTHON_FILES) --diff
33 | poetry run ruff --select I $(PYTHON_FILES)
34 | mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
35 |
36 | format format_diff:
37 | poetry run ruff format $(PYTHON_FILES)
38 | poetry run ruff --select I --fix $(PYTHON_FILES)
39 |
40 | spell_check:
41 | poetry run codespell --toml pyproject.toml
42 |
43 | spell_fix:
44 | poetry run codespell --toml pyproject.toml -w
45 |
46 | check_imports: $(shell find langchain_elasticsearch -name '*.py')
47 | poetry run python ./scripts/check_imports.py $^
48 |
49 | run_unasync:
50 | poetry run python ./scripts/run_unasync.py
51 |
52 | run_unasync_check:
53 | poetry run python ./scripts/run_unasync.py --check
54 |
55 | ######################
56 | # HELP
57 | ######################
58 |
59 | help:
60 | @echo '----'
61 | @echo 'check_imports - check imports'
62 | @echo 'format - run code formatters'
63 | @echo 'lint - run linters'
64 | @echo 'test - run unit tests'
65 | @echo 'tests - run unit tests'
66 | @echo 'test TEST_FILE=<test_file> - run all tests in file'
67 | -------------------------------------------------------------------------------- /libs/elasticsearch/README.md: --------------------------------------------------------------------------------
1 | # langchain-elasticsearch
2 |
3 | This package contains the LangChain integration with Elasticsearch.
4 |
5 | ## Installation
6 |
7 | ```bash
8 | pip install -U langchain-elasticsearch
9 | ```
10 |
11 | ## Elasticsearch setup
12 |
13 | ### Elastic Cloud
14 |
15 | You need a running Elasticsearch deployment. The easiest way to start one is through [Elastic Cloud](https://cloud.elastic.co/).
16 | You can sign up for a [free trial](https://www.elastic.co/cloud/cloud-trial-overview).
17 |
18 | 1. [Create a deployment](https://www.elastic.co/guide/en/cloud/current/ec-create-deployment.html)
19 | 2. Get your Cloud ID:
20 | 1. In the [Elastic Cloud console](https://cloud.elastic.co), click "Manage" next to your deployment
21 | 2. Copy the Cloud ID and paste it into the `es_cloud_id` parameter below
22 | 3. Create an API key:
23 | 1. In the [Elastic Cloud console](https://cloud.elastic.co), click "Open" next to your deployment
24 | 2. In the left-hand side menu, go to "Stack Management", then to "API Keys"
25 | 3. Click "Create API key"
26 | 4. Enter a name for the API key and click "Create"
27 | 5. Copy the API key and paste it into the `es_api_key` parameter below
28 |
29 | ### Local Elasticsearch (Docker)
30 |
31 | Alternatively, you can run Elasticsearch via Docker as described in the [docs](https://python.langchain.com/docs/integrations/vectorstores/elasticsearch).
32 |
33 | ## Usage
34 |
35 | ### ElasticsearchStore
36 |
37 | The `ElasticsearchStore` class exposes Elasticsearch as a vector store.
38 |
39 | ```python
40 | from langchain_elasticsearch import ElasticsearchStore
41 |
42 | embeddings = ... # use a LangChain Embeddings class or ElasticsearchEmbeddings
43 |
44 | vectorstore = ElasticsearchStore(
45 | es_cloud_id="your-cloud-id",
46 | es_api_key="your-api-key",
47 | index_name="your-index-name",
48 | embedding=embeddings,
49 | )
50 | ```
51 |
52 | ### ElasticsearchRetriever
53 |
54 | The `ElasticsearchRetriever` class can be used to implement more complex queries.
55 | This can be useful for power users and necessary if data was ingested outside of LangChain
56 | (for example using a web crawler).
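The example that follows assumes the imports and the `text_field` placeholder shown in this short setup snippet; the field name is hypothetical, so point it at whatever field in your index holds the document text:

```python
from typing import Dict

from langchain_elasticsearch import ElasticsearchRetriever

# Hypothetical field name; replace it with the field that stores your document text.
text_field = "your-text-field"
```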
57 |
58 | ```python
59 | def fuzzy_query(search_query: str) -> Dict:
60 | return {
61 | "query": {
62 | "match": {
63 | text_field: {
64 | "query": search_query,
65 | "fuzziness": "AUTO",
66 | }
67 | },
68 | },
69 | }
70 |
71 |
72 | fuzzy_retriever = ElasticsearchRetriever.from_es_params(
73 | es_cloud_id="your-cloud-id",
74 | es_api_key="your-api-key",
75 | index_name="your-index-name",
76 | body_func=fuzzy_query,
77 | content_field=text_field,
78 | )
79 |
80 | fuzzy_retriever.get_relevant_documents("fooo")
81 | ```
82 |
83 | ### ElasticsearchEmbeddings
84 |
85 | The `ElasticsearchEmbeddings` class provides an interface to generate embeddings using a model
86 | deployed in an Elasticsearch cluster.
87 |
88 | ```python
89 | from langchain_elasticsearch import ElasticsearchEmbeddings
90 |
91 | embeddings = ElasticsearchEmbeddings.from_credentials(
92 | model_id="your-model-id",
93 | input_field="your-input-field",
94 | es_cloud_id="your-cloud-id",
95 | es_api_key="your-api-key",
96 | )
97 | ```
98 |
99 | ### ElasticsearchChatMessageHistory
100 |
101 | The `ElasticsearchChatMessageHistory` class stores chat histories in Elasticsearch.
102 |
103 | ```python
104 | from langchain_elasticsearch import ElasticsearchChatMessageHistory
105 |
106 | chat_history = ElasticsearchChatMessageHistory(
107 | index="your-index-name",
108 | session_id="your-session-id",
109 | es_cloud_id="your-cloud-id",
110 | es_api_key="your-api-key",
111 | )
112 | ```
113 |
114 |
115 | ### ElasticsearchCache
116 |
117 | A caching layer for LLMs that uses Elasticsearch.
118 |
119 | Simple example:
120 |
121 | ```python
122 | from langchain.globals import set_llm_cache
123 |
124 | from langchain_elasticsearch import ElasticsearchCache
125 |
126 | set_llm_cache(
127 | ElasticsearchCache(
128 | es_url="http://localhost:9200",
129 | index_name="llm-chat-cache",
130 | metadata={"project": "my_chatgpt_project"},
131 | )
132 | )
133 | ```
134 |
135 | The `index_name` parameter can also accept aliases. This makes it possible to use
136 | [ILM: Manage the index lifecycle](https://www.elastic.co/guide/en/elasticsearch/reference/current/index-lifecycle-management.html),
137 | which we suggest considering for managing retention and controlling cache growth.
138 |
139 | See the class docstring for all parameters.
140 |
141 | #### Index the generated text
142 |
143 | The cached data won't be searchable by default.
144 | You can customize how the Elasticsearch document is built in order to add indexed text fields,
145 | for example to hold the text generated by the LLM.
146 |
147 | This can be done by subclassing and overriding methods.
148 | The new cache class can also be applied to a pre-existing cache index:
149 |
150 | ```python
151 | import json
152 | from typing import Any, Dict, List
153 |
154 | from langchain.globals import set_llm_cache
155 | from langchain_core.caches import RETURN_VAL_TYPE
156 |
157 | from langchain_elasticsearch import ElasticsearchCache
158 |
159 |
160 | class SearchableElasticsearchCache(ElasticsearchCache):
161 | @property
162 | def mapping(self) -> Dict[str, Any]:
163 | mapping = super().mapping
164 | mapping["mappings"]["properties"]["parsed_llm_output"] = {
165 | "type": "text",
166 | "analyzer": "english",
167 | }
168 | return mapping
169 |
170 | def build_document(
171 | self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
172 | ) -> Dict[str, Any]:
173 | body = super().build_document(prompt, llm_string, return_val)
174 | body["parsed_llm_output"] = self._parse_output(body["llm_output"])
175 | return body
176 |
177 | @staticmethod
178 | def _parse_output(data: List[str]) -> List[str]:
179 | return [
180 | json.loads(output)["kwargs"]["message"]["kwargs"]["content"]
181 | for output in data
182 | ]
183 |
184 |
185 | set_llm_cache(
186 | SearchableElasticsearchCache(
187 | es_url="http://localhost:9200",
188 | index_name="llm-chat-cache"
189 | )
190 | )
191 | ```
192 |
193 | When overriding the mapping and the document building,
194 | please only make additive modifications, keeping the base mapping intact.
195 |
196 | ### ElasticsearchEmbeddingsCache
197 |
198 | Store and temporarily cache embeddings.
199 |
200 | Embedding caching is provided by [CacheBackedEmbeddings](https://python.langchain.com/docs/modules/data_connection/text_embedding/caching_embeddings), which can be instantiated using the `CacheBackedEmbeddings.from_bytes_store` method.
201 |
202 | ```python
203 | from langchain.embeddings import CacheBackedEmbeddings
204 | from langchain_openai import OpenAIEmbeddings
205 |
206 | from langchain_elasticsearch import ElasticsearchEmbeddingsCache
207 |
208 | underlying_embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
209 |
210 | store = ElasticsearchEmbeddingsCache(
211 | es_url="http://localhost:9200",
212 | index_name="llm-chat-cache",
213 | metadata={"project": "my_chatgpt_project"},
214 | namespace="my_chatgpt_project",
215 | )
216 |
217 | embeddings = CacheBackedEmbeddings.from_bytes_store(
218 | underlying_embeddings=underlying_embeddings,
219 | document_embedding_cache=store,
220 | query_embedding_cache=store,
221 | )
222 | ```
223 |
224 | Similar to the chat cache, you can subclass `ElasticsearchEmbeddingsCache` in order to index vectors for search.
225 | 226 | ```python 227 | from typing import Any, Dict, List 228 | from langchain_elasticsearch import ElasticsearchEmbeddingsCache 229 | 230 | class SearchableElasticsearchStore(ElasticsearchEmbeddingsCache): 231 | @property 232 | def mapping(self) -> Dict[str, Any]: 233 | mapping = super().mapping 234 | mapping["mappings"]["properties"]["vector"] = { 235 | "type": "dense_vector", 236 | "dims": 1536, 237 | "index": True, 238 | "similarity": "dot_product", 239 | } 240 | return mapping 241 | 242 | def build_document(self, llm_input: str, vector: List[float]) -> Dict[str, Any]: 243 | body = super().build_document(llm_input, vector) 244 | body["vector"] = vector 245 | return body 246 | ``` 247 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/__init__.py: -------------------------------------------------------------------------------- 1 | from elasticsearch.helpers.vectorstore import ( 2 | AsyncBM25Strategy, 3 | AsyncDenseVectorScriptScoreStrategy, 4 | AsyncDenseVectorStrategy, 5 | AsyncRetrievalStrategy, 6 | AsyncSparseVectorStrategy, 7 | BM25Strategy, 8 | DenseVectorScriptScoreStrategy, 9 | DenseVectorStrategy, 10 | DistanceMetric, 11 | RetrievalStrategy, 12 | SparseVectorStrategy, 13 | ) 14 | 15 | from langchain_elasticsearch.cache import ( 16 | AsyncElasticsearchCache, 17 | AsyncElasticsearchEmbeddingsCache, 18 | ElasticsearchCache, 19 | ElasticsearchEmbeddingsCache, 20 | ) 21 | from langchain_elasticsearch.chat_history import ( 22 | AsyncElasticsearchChatMessageHistory, 23 | ElasticsearchChatMessageHistory, 24 | ) 25 | from langchain_elasticsearch.embeddings import ( 26 | AsyncElasticsearchEmbeddings, 27 | ElasticsearchEmbeddings, 28 | ) 29 | from langchain_elasticsearch.retrievers import ( 30 | AsyncElasticsearchRetriever, 31 | ElasticsearchRetriever, 32 | ) 33 | from langchain_elasticsearch.vectorstores import ( 34 | ApproxRetrievalStrategy, 35 | AsyncElasticsearchStore, 36 | BM25RetrievalStrategy, 37 | ElasticsearchStore, 38 | ExactRetrievalStrategy, 39 | SparseRetrievalStrategy, 40 | ) 41 | 42 | __all__ = [ 43 | "AsyncElasticsearchCache", 44 | "AsyncElasticsearchChatMessageHistory", 45 | "AsyncElasticsearchEmbeddings", 46 | "AsyncElasticsearchEmbeddingsCache", 47 | "AsyncElasticsearchRetriever", 48 | "AsyncElasticsearchStore", 49 | "ElasticsearchCache", 50 | "ElasticsearchChatMessageHistory", 51 | "ElasticsearchEmbeddings", 52 | "ElasticsearchEmbeddingsCache", 53 | "ElasticsearchRetriever", 54 | "ElasticsearchStore", 55 | # retrieval strategies 56 | "AsyncBM25Strategy", 57 | "AsyncDenseVectorScriptScoreStrategy", 58 | "AsyncDenseVectorStrategy", 59 | "AsyncRetrievalStrategy", 60 | "AsyncSparseVectorStrategy", 61 | "BM25Strategy", 62 | "DenseVectorScriptScoreStrategy", 63 | "DenseVectorStrategy", 64 | "DistanceMetric", 65 | "RetrievalStrategy", 66 | "SparseVectorStrategy", 67 | # deprecated retrieval strategies 68 | "ApproxRetrievalStrategy", 69 | "BM25RetrievalStrategy", 70 | "ExactRetrievalStrategy", 71 | "SparseRetrievalStrategy", 72 | ] 73 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_async/chat_history.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from time import time 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence 5 | 6 | from langchain_core.chat_history import BaseChatMessageHistory 7 | from langchain_core.messages 
import BaseMessage, message_to_dict, messages_from_dict 8 | 9 | from langchain_elasticsearch._utilities import async_with_user_agent_header 10 | from langchain_elasticsearch.client import create_async_elasticsearch_client 11 | 12 | if TYPE_CHECKING: 13 | from elasticsearch import AsyncElasticsearch 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class AsyncElasticsearchChatMessageHistory(BaseChatMessageHistory): 19 | """Chat message history that stores history in Elasticsearch. 20 | 21 | Args: 22 | es_url: URL of the Elasticsearch instance to connect to. 23 | es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. 24 | es_user: Username to use when connecting to Elasticsearch. 25 | es_password: Password to use when connecting to Elasticsearch. 26 | es_api_key: API key to use when connecting to Elasticsearch. 27 | es_connection: Optional pre-existing Elasticsearch connection. 28 | esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True. 29 | index: Name of the index to use. 30 | session_id: Arbitrary key that is used to store the messages 31 | of a single chat session. 32 | 33 | For synchronous applications, use the `ElasticsearchChatMessageHistory` class. 34 | For asyhchronous applications, use the `AsyncElasticsearchChatMessageHistory` class. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | index: str, 40 | session_id: str, 41 | *, 42 | es_connection: Optional["AsyncElasticsearch"] = None, 43 | es_url: Optional[str] = None, 44 | es_cloud_id: Optional[str] = None, 45 | es_user: Optional[str] = None, 46 | es_api_key: Optional[str] = None, 47 | es_password: Optional[str] = None, 48 | esnsure_ascii: Optional[bool] = True, 49 | ): 50 | self.index: str = index 51 | self.session_id: str = session_id 52 | self.ensure_ascii = esnsure_ascii 53 | 54 | # Initialize Elasticsearch client from passed client arg or connection info 55 | if es_connection is not None: 56 | self.client = es_connection 57 | elif es_url is not None or es_cloud_id is not None: 58 | try: 59 | self.client = create_async_elasticsearch_client( 60 | url=es_url, 61 | username=es_user, 62 | password=es_password, 63 | cloud_id=es_cloud_id, 64 | api_key=es_api_key, 65 | ) 66 | except Exception as err: 67 | logger.error(f"Error connecting to Elasticsearch: {err}") 68 | raise err 69 | else: 70 | raise ValueError( 71 | """Either provide a pre-existing Elasticsearch connection, \ 72 | or valid credentials for creating a new connection.""" 73 | ) 74 | 75 | self.client = async_with_user_agent_header(self.client, "langchain-py-ms") 76 | self.created = False 77 | 78 | async def create_if_missing(self) -> None: 79 | if not self.created: 80 | if await self.client.indices.exists(index=self.index): 81 | logger.debug( 82 | ( 83 | f"Chat history index {self.index} already exists, " 84 | "skipping creation." 
85 | ) 86 | ) 87 | else: 88 | logger.debug(f"Creating index {self.index} for storing chat history.") 89 | 90 | await self.client.indices.create( 91 | index=self.index, 92 | mappings={ 93 | "properties": { 94 | "session_id": {"type": "keyword"}, 95 | "created_at": {"type": "date"}, 96 | "history": {"type": "text"}, 97 | } 98 | }, 99 | ) 100 | self.created = True 101 | 102 | async def aget_messages(self) -> List[BaseMessage]: # type: ignore[override] 103 | """Retrieve the messages from Elasticsearch""" 104 | from elasticsearch import ApiError 105 | 106 | await self.create_if_missing() 107 | 108 | search_after: Dict[str, Any] = {} 109 | items = [] 110 | while True: 111 | try: 112 | result = await self.client.search( 113 | index=self.index, 114 | query={"term": {"session_id": self.session_id}}, 115 | sort="created_at:asc", 116 | size=100, 117 | **search_after, 118 | ) 119 | except ApiError as err: 120 | logger.error(f"Could not retrieve messages from Elasticsearch: {err}") 121 | raise err 122 | 123 | if result and len(result["hits"]["hits"]) > 0: 124 | items += [ 125 | json.loads(document["_source"]["history"]) 126 | for document in result["hits"]["hits"] 127 | ] 128 | search_after = {"search_after": result["hits"]["hits"][-1]["sort"]} 129 | else: 130 | break 131 | 132 | return messages_from_dict(items) 133 | 134 | async def aadd_message(self, message: BaseMessage) -> None: 135 | """Add messages to the chat session in Elasticsearch""" 136 | try: 137 | from elasticsearch import ApiError 138 | 139 | await self.create_if_missing() 140 | await self.client.index( 141 | index=self.index, 142 | document={ 143 | "session_id": self.session_id, 144 | "created_at": round(time() * 1000), 145 | "history": json.dumps( 146 | message_to_dict(message), 147 | ensure_ascii=bool(self.ensure_ascii), 148 | ), 149 | }, 150 | refresh=True, 151 | ) 152 | except ApiError as err: 153 | logger.error(f"Could not add message to Elasticsearch: {err}") 154 | raise err 155 | 156 | async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: 157 | for message in messages: 158 | await self.aadd_message(message) 159 | 160 | async def aclear(self) -> None: 161 | """Clear session memory in Elasticsearch""" 162 | try: 163 | from elasticsearch import ApiError 164 | 165 | await self.create_if_missing() 166 | await self.client.delete_by_query( 167 | index=self.index, 168 | query={"term": {"session_id": self.session_id}}, 169 | refresh=True, 170 | ) 171 | except ApiError as err: 172 | logger.error(f"Could not clear session memory in Elasticsearch: {err}") 173 | raise err 174 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, List, Optional 4 | 5 | from elasticsearch import AsyncElasticsearch 6 | from elasticsearch.helpers.vectorstore import AsyncEmbeddingService 7 | from langchain_core.embeddings import Embeddings 8 | from langchain_core.utils import get_from_env 9 | 10 | if TYPE_CHECKING: 11 | from elasticsearch._async.client.ml import MlClient 12 | 13 | 14 | class AsyncElasticsearchEmbeddings(Embeddings): 15 | """Elasticsearch embedding models. 16 | 17 | This class provides an interface to generate embeddings using a model deployed 18 | in an Elasticsearch cluster. 
It requires an Elasticsearch connection object 19 | and the model_id of the model deployed in the cluster. 20 | 21 | In Elasticsearch you need to have an embedding model loaded and deployed. 22 | - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html 23 | - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html 24 | 25 | For synchronous applications, use the `ElasticsearchEmbeddings` class. 26 | For asynchronous applications, use the `AsyncElasticsearchEmbeddings` class. 27 | """ # noqa: E501 28 | 29 | def __init__( 30 | self, 31 | client: MlClient, 32 | model_id: str, 33 | *, 34 | input_field: str = "text_field", 35 | ): 36 | """ 37 | Initialize the ElasticsearchEmbeddings instance. 38 | 39 | Args: 40 | client (MlClient): An Elasticsearch ML client object. 41 | model_id (str): The model_id of the model deployed in the Elasticsearch 42 | cluster. 43 | input_field (str): The name of the key for the input text field in the 44 | document. Defaults to 'text_field'. 45 | """ 46 | self.client = client 47 | self.model_id = model_id 48 | self.input_field = input_field 49 | 50 | @classmethod 51 | def from_credentials( 52 | cls, 53 | model_id: str, 54 | *, 55 | es_cloud_id: Optional[str] = None, 56 | es_api_key: Optional[str] = None, 57 | input_field: str = "text_field", 58 | ) -> AsyncElasticsearchEmbeddings: 59 | """Instantiate embeddings from Elasticsearch credentials. 60 | 61 | Args: 62 | model_id (str): The model_id of the model deployed in the Elasticsearch 63 | cluster. 64 | input_field (str): The name of the key for the input text field in the 65 | document. Defaults to 'text_field'. 66 | es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. 67 | es_api_key: (str, optional): The Elasticsearch API key to use when 68 | connecting to Elasticsearch. 69 | 70 | Example: 71 | .. code-block:: python 72 | 73 | from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings 74 | 75 | # Define the model ID and input field name (if different from default) 76 | model_id = "your_model_id" 77 | # Optional, only if different from 'text_field' 78 | input_field = "your_input_field" 79 | 80 | # Credentials can be passed in two ways. Either set the env vars 81 | # ES_CLOUD_ID and ES_API_KEY and they will be automatically 82 | # pulled in, or pass them in directly as kwargs. 83 | embeddings = ElasticsearchEmbeddings.from_credentials( 84 | model_id, 85 | input_field=input_field, 86 | # es_cloud_id="foo", 87 | # es_api_key="bar", 88 | # (or set the ES_CLOUD_ID and ES_API_KEY env vars) 89 | ) 90 | 91 | documents = [ 92 | "This is an example document.", 93 | "Another example document to generate embeddings for.", 94 | ] 95 | embeddings.embed_documents(documents) 96 | """ 97 | from elasticsearch._async.client.ml import MlClient 98 | 99 | es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID") 100 | es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY") 101 | 102 | # Connect to Elasticsearch 103 | es_connection = AsyncElasticsearch(cloud_id=es_cloud_id, api_key=es_api_key) 104 | client = MlClient(es_connection) 105 | return cls(client, model_id, input_field=input_field) 106 | 107 | @classmethod 108 | def from_es_connection( 109 | cls, 110 | model_id: str, 111 | es_connection: AsyncElasticsearch, 112 | input_field: str = "text_field", 113 | ) -> AsyncElasticsearchEmbeddings: 114 | """ 115 | Instantiate embeddings from an existing Elasticsearch connection.
116 | 117 | This method provides a way to create an instance of the ElasticsearchEmbeddings 118 | class using an existing Elasticsearch connection. The connection object is used 119 | to create an MlClient, which is then used to initialize the 120 | ElasticsearchEmbeddings instance. 121 | 122 | Args: 123 | model_id (str): The model_id of the model deployed in the Elasticsearch cluster. 124 | es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch 125 | connection object. input_field (str, optional): The name of the key for the 126 | input text field in the document. Defaults to 'text_field'. 127 | 128 | Returns: 129 | ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. 130 | 131 | Example: 132 | .. code-block:: python 133 | 134 | from elasticsearch import Elasticsearch 135 | 136 | from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings 137 | 138 | # Define the model ID and input field name (if different from default) 139 | model_id = "your_model_id" 140 | # Optional, only if different from 'text_field' 141 | input_field = "your_input_field" 142 | 143 | # Create Elasticsearch connection 144 | es_connection = Elasticsearch( 145 | hosts=["localhost:9200"], http_auth=("user", "password") 146 | ) 147 | 148 | # Instantiate ElasticsearchEmbeddings using the existing connection 149 | embeddings = ElasticsearchEmbeddings.from_es_connection( 150 | model_id, 151 | es_connection, 152 | input_field=input_field, 153 | ) 154 | 155 | documents = [ 156 | "This is an example document.", 157 | "Another example document to generate embeddings for.", 158 | ] 159 | embeddings_generator.embed_documents(documents) 160 | """ 161 | from elasticsearch._async.client.ml import MlClient 162 | 163 | # Create an MlClient from the given Elasticsearch connection 164 | client = MlClient(es_connection) 165 | 166 | # Return a new instance of the ElasticsearchEmbeddings class with 167 | # the MlClient, model_id, and input_field 168 | return cls(client, model_id, input_field=input_field) 169 | 170 | async def _embedding_func(self, texts: List[str]) -> List[List[float]]: 171 | """ 172 | Generate embeddings for the given texts using the Elasticsearch model. 173 | 174 | Args: 175 | texts (List[str]): A list of text strings to generate embeddings for. 176 | 177 | Returns: 178 | List[List[float]]: A list of embeddings, one for each text in the input 179 | list. 180 | """ 181 | response = await self.client.infer_trained_model( 182 | model_id=self.model_id, docs=[{self.input_field: text} for text in texts] 183 | ) 184 | 185 | embeddings = [doc["predicted_value"] for doc in response["inference_results"]] 186 | return embeddings 187 | 188 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]: 189 | """ 190 | Generate embeddings for a list of documents. 191 | 192 | Args: 193 | texts (List[str]): A list of document text strings to generate embeddings 194 | for. 195 | 196 | Returns: 197 | List[List[float]]: A list of embeddings, one for each document in the input 198 | list. 199 | """ 200 | return await self._embedding_func(texts) 201 | 202 | async def aembed_query(self, text: str) -> List[float]: 203 | """ 204 | Generate an embedding for a single query text. 205 | 206 | Args: 207 | text (str): The query text to generate an embedding for. 208 | 209 | Returns: 210 | List[float]: The embedding for the input query text. 
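        A minimal usage sketch (illustrative only; "your_model_id" is a placeholder
        and this assumes ES_CLOUD_ID and ES_API_KEY are set in the environment):

        .. code-block:: python

            embeddings = AsyncElasticsearchEmbeddings.from_credentials("your_model_id")
            vector = await embeddings.aembed_query("hello world")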
211 | """ 212 | return (await self._embedding_func([text]))[0] 213 | 214 | 215 | class AsyncEmbeddingServiceAdapter(AsyncEmbeddingService): 216 | """ 217 | Adapter for LangChain Embeddings to support the EmbeddingService interface from 218 | elasticsearch.helpers.vectorstore. 219 | """ 220 | 221 | def __init__(self, langchain_embeddings: Embeddings): 222 | self._langchain_embeddings = langchain_embeddings 223 | 224 | def __eq__(self, other): # type: ignore[no-untyped-def] 225 | if isinstance(other, self.__class__): 226 | return self.__dict__ == other.__dict__ 227 | else: 228 | return False 229 | 230 | async def embed_documents(self, texts: List[str]) -> List[List[float]]: 231 | """ 232 | Generate embeddings for a list of documents. 233 | 234 | Args: 235 | texts (List[str]): A list of document text strings to generate embeddings 236 | for. 237 | 238 | Returns: 239 | List[List[float]]: A list of embeddings, one for each document in the input 240 | list. 241 | """ 242 | return await self._langchain_embeddings.aembed_documents(texts) 243 | 244 | async def embed_query(self, text: str) -> List[float]: 245 | """ 246 | Generate an embedding for a single query text. 247 | 248 | Args: 249 | text (str): The query text to generate an embedding for. 250 | 251 | Returns: 252 | List[float]: The embedding for the input query text. 253 | """ 254 | return await self._langchain_embeddings.aembed_query(text) 255 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast 3 | 4 | from elasticsearch import AsyncElasticsearch 5 | from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun 6 | from langchain_core.documents import Document 7 | from langchain_core.retrievers import BaseRetriever 8 | 9 | from langchain_elasticsearch._utilities import async_with_user_agent_header 10 | from langchain_elasticsearch.client import create_async_elasticsearch_client 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class AsyncElasticsearchRetriever(BaseRetriever): 16 | """ 17 | Elasticsearch retriever 18 | 19 | Args: 20 | es_client: Elasticsearch client connection. Alternatively you can use the 21 | `from_es_params` method with parameters to initialize the client. 22 | index_name: The name of the index to query. Can also be a list of names. 23 | body_func: Function to create an Elasticsearch DSL query body from a search 24 | string. The returned query body must fit what you would normally send in a 25 | POST request the the _search endpoint. If applicable, it also includes 26 | parameters the `size` parameter etc. 27 | content_field: The document field name that contains the page content. If 28 | multiple indices are queried, specify a dict {index_name: field_name} here. 29 | document_mapper: Function to map Elasticsearch hits to LangChain Documents. 30 | 31 | For synchronous applications, use the ``ElasticsearchRetriever`` class. 32 | For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. 
33 | """ 34 | 35 | es_client: AsyncElasticsearch 36 | index_name: Union[str, Sequence[str]] 37 | body_func: Callable[[str], Dict] 38 | content_field: Optional[Union[str, Mapping[str, str]]] = None 39 | document_mapper: Optional[Callable[[Mapping], Document]] = None 40 | 41 | def __init__(self, **kwargs: Any) -> None: 42 | super().__init__(**kwargs) 43 | 44 | if self.content_field is None and self.document_mapper is None: 45 | raise ValueError("One of content_field or document_mapper must be defined.") 46 | if self.content_field is not None and self.document_mapper is not None: 47 | raise ValueError( 48 | "Both content_field and document_mapper are defined. " 49 | "Please provide only one." 50 | ) 51 | 52 | if not self.document_mapper: 53 | if isinstance(self.content_field, str): 54 | self.document_mapper = self._single_field_mapper 55 | elif isinstance(self.content_field, Mapping): 56 | self.document_mapper = self._multi_field_mapper 57 | else: 58 | raise ValueError( 59 | "unknown type for content_field, expected string or dict." 60 | ) 61 | 62 | self.es_client = async_with_user_agent_header(self.es_client, "langchain-py-r") 63 | 64 | @classmethod 65 | def from_es_params( 66 | cls, 67 | index_name: Union[str, Sequence[str]], 68 | body_func: Callable[[str], Dict], 69 | content_field: Optional[Union[str, Mapping[str, str]]] = None, 70 | document_mapper: Optional[Callable[[Mapping], Document]] = None, 71 | url: Optional[str] = None, 72 | cloud_id: Optional[str] = None, 73 | api_key: Optional[str] = None, 74 | username: Optional[str] = None, 75 | password: Optional[str] = None, 76 | params: Optional[Dict[str, Any]] = None, 77 | ) -> "AsyncElasticsearchRetriever": 78 | client = None 79 | try: 80 | client = create_async_elasticsearch_client( 81 | url=url, 82 | cloud_id=cloud_id, 83 | api_key=api_key, 84 | username=username, 85 | password=password, 86 | params=params, 87 | ) 88 | except Exception as err: 89 | logger.error(f"Error connecting to Elasticsearch: {err}") 90 | raise err 91 | 92 | return cls( 93 | es_client=client, 94 | index_name=index_name, 95 | body_func=body_func, 96 | content_field=content_field, 97 | document_mapper=document_mapper, 98 | ) 99 | 100 | async def _aget_relevant_documents( 101 | self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun 102 | ) -> List[Document]: 103 | if not self.es_client or not self.document_mapper: 104 | raise ValueError("faulty configuration") # should not happen 105 | 106 | body = self.body_func(query) 107 | results = await self.es_client.search(index=self.index_name, body=body) 108 | return [self.document_mapper(hit) for hit in results["hits"]["hits"]] 109 | 110 | def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: 111 | content = hit["_source"].pop(self.content_field) 112 | return Document(page_content=content, metadata=hit) 113 | 114 | def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: 115 | self.content_field = cast(Mapping, self.content_field) 116 | field = self.content_field[hit["_index"]] 117 | content = hit["_source"].pop(field) 118 | return Document(page_content=content, metadata=hit) 119 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_sync/chat_history.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from time import time 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence 5 | 6 | from langchain_core.chat_history import 
BaseChatMessageHistory 7 | from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict 8 | 9 | from langchain_elasticsearch._utilities import with_user_agent_header 10 | from langchain_elasticsearch.client import create_elasticsearch_client 11 | 12 | if TYPE_CHECKING: 13 | from elasticsearch import Elasticsearch 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class ElasticsearchChatMessageHistory(BaseChatMessageHistory): 19 | """Chat message history that stores history in Elasticsearch. 20 | 21 | Args: 22 | es_url: URL of the Elasticsearch instance to connect to. 23 | es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. 24 | es_user: Username to use when connecting to Elasticsearch. 25 | es_password: Password to use when connecting to Elasticsearch. 26 | es_api_key: API key to use when connecting to Elasticsearch. 27 | es_connection: Optional pre-existing Elasticsearch connection. 28 | esnsure_ascii: Used to escape ASCII symbols in json.dumps. Defaults to True. 29 | index: Name of the index to use. 30 | session_id: Arbitrary key that is used to store the messages 31 | of a single chat session. 32 | 33 | For synchronous applications, use the `ElasticsearchChatMessageHistory` class. 34 | For asyhchronous applications, use the `AsyncElasticsearchChatMessageHistory` class. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | index: str, 40 | session_id: str, 41 | *, 42 | es_connection: Optional["Elasticsearch"] = None, 43 | es_url: Optional[str] = None, 44 | es_cloud_id: Optional[str] = None, 45 | es_user: Optional[str] = None, 46 | es_api_key: Optional[str] = None, 47 | es_password: Optional[str] = None, 48 | esnsure_ascii: Optional[bool] = True, 49 | ): 50 | self.index: str = index 51 | self.session_id: str = session_id 52 | self.ensure_ascii = esnsure_ascii 53 | 54 | # Initialize Elasticsearch client from passed client arg or connection info 55 | if es_connection is not None: 56 | self.client = es_connection 57 | elif es_url is not None or es_cloud_id is not None: 58 | try: 59 | self.client = create_elasticsearch_client( 60 | url=es_url, 61 | username=es_user, 62 | password=es_password, 63 | cloud_id=es_cloud_id, 64 | api_key=es_api_key, 65 | ) 66 | except Exception as err: 67 | logger.error(f"Error connecting to Elasticsearch: {err}") 68 | raise err 69 | else: 70 | raise ValueError( 71 | """Either provide a pre-existing Elasticsearch connection, \ 72 | or valid credentials for creating a new connection.""" 73 | ) 74 | 75 | self.client = with_user_agent_header(self.client, "langchain-py-ms") 76 | self.created = False 77 | 78 | def create_if_missing(self) -> None: 79 | if not self.created: 80 | if self.client.indices.exists(index=self.index): 81 | logger.debug( 82 | ( 83 | f"Chat history index {self.index} already exists, " 84 | "skipping creation." 
85 | ) 86 | ) 87 | else: 88 | logger.debug(f"Creating index {self.index} for storing chat history.") 89 | 90 | self.client.indices.create( 91 | index=self.index, 92 | mappings={ 93 | "properties": { 94 | "session_id": {"type": "keyword"}, 95 | "created_at": {"type": "date"}, 96 | "history": {"type": "text"}, 97 | } 98 | }, 99 | ) 100 | self.created = True 101 | 102 | def get_messages(self) -> List[BaseMessage]: # type: ignore[override] 103 | """Retrieve the messages from Elasticsearch""" 104 | from elasticsearch import ApiError 105 | 106 | self.create_if_missing() 107 | 108 | search_after: Dict[str, Any] = {} 109 | items = [] 110 | while True: 111 | try: 112 | result = self.client.search( 113 | index=self.index, 114 | query={"term": {"session_id": self.session_id}}, 115 | sort="created_at:asc", 116 | size=100, 117 | **search_after, 118 | ) 119 | except ApiError as err: 120 | logger.error(f"Could not retrieve messages from Elasticsearch: {err}") 121 | raise err 122 | 123 | if result and len(result["hits"]["hits"]) > 0: 124 | items += [ 125 | json.loads(document["_source"]["history"]) 126 | for document in result["hits"]["hits"] 127 | ] 128 | search_after = {"search_after": result["hits"]["hits"][-1]["sort"]} 129 | else: 130 | break 131 | 132 | return messages_from_dict(items) 133 | 134 | def add_message(self, message: BaseMessage) -> None: 135 | """Add messages to the chat session in Elasticsearch""" 136 | try: 137 | from elasticsearch import ApiError 138 | 139 | self.create_if_missing() 140 | self.client.index( 141 | index=self.index, 142 | document={ 143 | "session_id": self.session_id, 144 | "created_at": round(time() * 1000), 145 | "history": json.dumps( 146 | message_to_dict(message), 147 | ensure_ascii=bool(self.ensure_ascii), 148 | ), 149 | }, 150 | refresh=True, 151 | ) 152 | except ApiError as err: 153 | logger.error(f"Could not add message to Elasticsearch: {err}") 154 | raise err 155 | 156 | def add_messages(self, messages: Sequence[BaseMessage]) -> None: 157 | for message in messages: 158 | self.add_message(message) 159 | 160 | def clear(self) -> None: 161 | """Clear session memory in Elasticsearch""" 162 | try: 163 | from elasticsearch import ApiError 164 | 165 | self.create_if_missing() 166 | self.client.delete_by_query( 167 | index=self.index, 168 | query={"term": {"session_id": self.session_id}}, 169 | refresh=True, 170 | ) 171 | except ApiError as err: 172 | logger.error(f"Could not clear session memory in Elasticsearch: {err}") 173 | raise err 174 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_sync/embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, List, Optional 4 | 5 | from elasticsearch import Elasticsearch 6 | from elasticsearch.helpers.vectorstore import EmbeddingService 7 | from langchain_core.embeddings import Embeddings 8 | from langchain_core.utils import get_from_env 9 | 10 | if TYPE_CHECKING: 11 | from elasticsearch._sync.client.ml import MlClient 12 | 13 | 14 | class ElasticsearchEmbeddings(Embeddings): 15 | """Elasticsearch embedding models. 16 | 17 | This class provides an interface to generate embeddings using a model deployed 18 | in an Elasticsearch cluster. It requires an Elasticsearch connection object 19 | and the model_id of the model deployed in the cluster. 20 | 21 | In Elasticsearch you need to have an embedding model loaded and deployed. 
22 | - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html 23 | - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html 24 | 25 | For synchronous applications, use the `ElasticsearchEmbeddings` class. 26 | For asyhchronous applications, use the `AsyncElasticsearchEmbeddings` class. 27 | """ # noqa: E501 28 | 29 | def __init__( 30 | self, 31 | client: MlClient, 32 | model_id: str, 33 | *, 34 | input_field: str = "text_field", 35 | ): 36 | """ 37 | Initialize the ElasticsearchEmbeddings instance. 38 | 39 | Args: 40 | client (MlClient): An Elasticsearch ML client object. 41 | model_id (str): The model_id of the model deployed in the Elasticsearch 42 | cluster. 43 | input_field (str): The name of the key for the input text field in the 44 | document. Defaults to 'text_field'. 45 | """ 46 | self.client = client 47 | self.model_id = model_id 48 | self.input_field = input_field 49 | 50 | @classmethod 51 | def from_credentials( 52 | cls, 53 | model_id: str, 54 | *, 55 | es_cloud_id: Optional[str] = None, 56 | es_api_key: Optional[str] = None, 57 | input_field: str = "text_field", 58 | ) -> ElasticsearchEmbeddings: 59 | """Instantiate embeddings from Elasticsearch credentials. 60 | 61 | Args: 62 | model_id (str): The model_id of the model deployed in the Elasticsearch 63 | cluster. 64 | input_field (str): The name of the key for the input text field in the 65 | document. Defaults to 'text_field'. 66 | es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. 67 | es_user: (str, optional): Elasticsearch username. 68 | es_password: (str, optional): Elasticsearch password. 69 | 70 | Example: 71 | .. code-block:: python 72 | 73 | from langchain_elasticserach.embeddings import ElasticsearchEmbeddings 74 | 75 | # Define the model ID and input field name (if different from default) 76 | model_id = "your_model_id" 77 | # Optional, only if different from 'text_field' 78 | input_field = "your_input_field" 79 | 80 | # Credentials can be passed in two ways. Either set the env vars 81 | # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically 82 | # pulled in, or pass them in directly as kwargs. 83 | embeddings = ElasticsearchEmbeddings.from_credentials( 84 | model_id, 85 | input_field=input_field, 86 | # es_cloud_id="foo", 87 | # es_user="bar", 88 | # es_password="baz", 89 | ) 90 | 91 | documents = [ 92 | "This is an example document.", 93 | "Another example document to generate embeddings for.", 94 | ] 95 | embeddings_generator.embed_documents(documents) 96 | """ 97 | from elasticsearch._sync.client.ml import MlClient 98 | 99 | es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID") 100 | es_api_key = es_api_key or get_from_env("es_api_key", "ES_API_KEY") 101 | 102 | # Connect to Elasticsearch 103 | es_connection = Elasticsearch(cloud_id=es_cloud_id, api_key=es_api_key) 104 | client = MlClient(es_connection) 105 | return cls(client, model_id, input_field=input_field) 106 | 107 | @classmethod 108 | def from_es_connection( 109 | cls, 110 | model_id: str, 111 | es_connection: Elasticsearch, 112 | input_field: str = "text_field", 113 | ) -> ElasticsearchEmbeddings: 114 | """ 115 | Instantiate embeddings from an existing Elasticsearch connection. 116 | 117 | This method provides a way to create an instance of the ElasticsearchEmbeddings 118 | class using an existing Elasticsearch connection. 
The connection object is used 119 | to create an MlClient, which is then used to initialize the 120 | ElasticsearchEmbeddings instance. 121 | 122 | Args: 123 | model_id (str): The model_id of the model deployed in the Elasticsearch cluster. 124 | es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch 125 | connection object. input_field (str, optional): The name of the key for the 126 | input text field in the document. Defaults to 'text_field'. 127 | 128 | Returns: 129 | ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. 130 | 131 | Example: 132 | .. code-block:: python 133 | 134 | from elasticsearch import Elasticsearch 135 | 136 | from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings 137 | 138 | # Define the model ID and input field name (if different from default) 139 | model_id = "your_model_id" 140 | # Optional, only if different from 'text_field' 141 | input_field = "your_input_field" 142 | 143 | # Create Elasticsearch connection 144 | es_connection = Elasticsearch( 145 | hosts=["localhost:9200"], http_auth=("user", "password") 146 | ) 147 | 148 | # Instantiate ElasticsearchEmbeddings using the existing connection 149 | embeddings = ElasticsearchEmbeddings.from_es_connection( 150 | model_id, 151 | es_connection, 152 | input_field=input_field, 153 | ) 154 | 155 | documents = [ 156 | "This is an example document.", 157 | "Another example document to generate embeddings for.", 158 | ] 159 | embeddings_generator.embed_documents(documents) 160 | """ 161 | from elasticsearch._sync.client.ml import MlClient 162 | 163 | # Create an MlClient from the given Elasticsearch connection 164 | client = MlClient(es_connection) 165 | 166 | # Return a new instance of the ElasticsearchEmbeddings class with 167 | # the MlClient, model_id, and input_field 168 | return cls(client, model_id, input_field=input_field) 169 | 170 | def _embedding_func(self, texts: List[str]) -> List[List[float]]: 171 | """ 172 | Generate embeddings for the given texts using the Elasticsearch model. 173 | 174 | Args: 175 | texts (List[str]): A list of text strings to generate embeddings for. 176 | 177 | Returns: 178 | List[List[float]]: A list of embeddings, one for each text in the input 179 | list. 180 | """ 181 | response = self.client.infer_trained_model( 182 | model_id=self.model_id, docs=[{self.input_field: text} for text in texts] 183 | ) 184 | 185 | embeddings = [doc["predicted_value"] for doc in response["inference_results"]] 186 | return embeddings 187 | 188 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 189 | """ 190 | Generate embeddings for a list of documents. 191 | 192 | Args: 193 | texts (List[str]): A list of document text strings to generate embeddings 194 | for. 195 | 196 | Returns: 197 | List[List[float]]: A list of embeddings, one for each document in the input 198 | list. 199 | """ 200 | return self._embedding_func(texts) 201 | 202 | def embed_query(self, text: str) -> List[float]: 203 | """ 204 | Generate an embedding for a single query text. 205 | 206 | Args: 207 | text (str): The query text to generate an embedding for. 208 | 209 | Returns: 210 | List[float]: The embedding for the input query text. 211 | """ 212 | return (self._embedding_func([text]))[0] 213 | 214 | 215 | class EmbeddingServiceAdapter(EmbeddingService): 216 | """ 217 | Adapter for LangChain Embeddings to support the EmbeddingService interface from 218 | elasticsearch.helpers.vectorstore. 
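    A minimal sketch of adapting an existing LangChain ``Embeddings`` object
    (``my_embeddings`` below stands in for any such implementation):

    .. code-block:: python

        service = EmbeddingServiceAdapter(my_embeddings)
        vectors = service.embed_documents(["some text", "more text"])
        query_vector = service.embed_query("a question")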
219 | """ 220 | 221 | def __init__(self, langchain_embeddings: Embeddings): 222 | self._langchain_embeddings = langchain_embeddings 223 | 224 | def __eq__(self, other): # type: ignore[no-untyped-def] 225 | if isinstance(other, self.__class__): 226 | return self.__dict__ == other.__dict__ 227 | else: 228 | return False 229 | 230 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 231 | """ 232 | Generate embeddings for a list of documents. 233 | 234 | Args: 235 | texts (List[str]): A list of document text strings to generate embeddings 236 | for. 237 | 238 | Returns: 239 | List[List[float]]: A list of embeddings, one for each document in the input 240 | list. 241 | """ 242 | return self._langchain_embeddings.embed_documents(texts) 243 | 244 | def embed_query(self, text: str) -> List[float]: 245 | """ 246 | Generate an embedding for a single query text. 247 | 248 | Args: 249 | text (str): The query text to generate an embedding for. 250 | 251 | Returns: 252 | List[float]: The embedding for the input query text. 253 | """ 254 | return self._langchain_embeddings.embed_query(text) 255 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast 3 | 4 | from elasticsearch import Elasticsearch 5 | from langchain_core.callbacks import CallbackManagerForRetrieverRun 6 | from langchain_core.documents import Document 7 | from langchain_core.retrievers import BaseRetriever 8 | 9 | from langchain_elasticsearch._utilities import with_user_agent_header 10 | from langchain_elasticsearch.client import create_elasticsearch_client 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class ElasticsearchRetriever(BaseRetriever): 16 | """ 17 | Elasticsearch retriever 18 | 19 | Args: 20 | es_client: Elasticsearch client connection. Alternatively you can use the 21 | `from_es_params` method with parameters to initialize the client. 22 | index_name: The name of the index to query. Can also be a list of names. 23 | body_func: Function to create an Elasticsearch DSL query body from a search 24 | string. The returned query body must fit what you would normally send in a 25 | POST request the the _search endpoint. If applicable, it also includes 26 | parameters the `size` parameter etc. 27 | content_field: The document field name that contains the page content. If 28 | multiple indices are queried, specify a dict {index_name: field_name} here. 29 | document_mapper: Function to map Elasticsearch hits to LangChain Documents. 30 | 31 | For synchronous applications, use the ``ElasticsearchRetriever`` class. 32 | For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. 
33 | """ 34 | 35 | es_client: Elasticsearch 36 | index_name: Union[str, Sequence[str]] 37 | body_func: Callable[[str], Dict] 38 | content_field: Optional[Union[str, Mapping[str, str]]] = None 39 | document_mapper: Optional[Callable[[Mapping], Document]] = None 40 | 41 | def __init__(self, **kwargs: Any) -> None: 42 | super().__init__(**kwargs) 43 | 44 | if self.content_field is None and self.document_mapper is None: 45 | raise ValueError("One of content_field or document_mapper must be defined.") 46 | if self.content_field is not None and self.document_mapper is not None: 47 | raise ValueError( 48 | "Both content_field and document_mapper are defined. " 49 | "Please provide only one." 50 | ) 51 | 52 | if not self.document_mapper: 53 | if isinstance(self.content_field, str): 54 | self.document_mapper = self._single_field_mapper 55 | elif isinstance(self.content_field, Mapping): 56 | self.document_mapper = self._multi_field_mapper 57 | else: 58 | raise ValueError( 59 | "unknown type for content_field, expected string or dict." 60 | ) 61 | 62 | self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") 63 | 64 | @classmethod 65 | def from_es_params( 66 | cls, 67 | index_name: Union[str, Sequence[str]], 68 | body_func: Callable[[str], Dict], 69 | content_field: Optional[Union[str, Mapping[str, str]]] = None, 70 | document_mapper: Optional[Callable[[Mapping], Document]] = None, 71 | url: Optional[str] = None, 72 | cloud_id: Optional[str] = None, 73 | api_key: Optional[str] = None, 74 | username: Optional[str] = None, 75 | password: Optional[str] = None, 76 | params: Optional[Dict[str, Any]] = None, 77 | ) -> "ElasticsearchRetriever": 78 | client = None 79 | try: 80 | client = create_elasticsearch_client( 81 | url=url, 82 | cloud_id=cloud_id, 83 | api_key=api_key, 84 | username=username, 85 | password=password, 86 | params=params, 87 | ) 88 | except Exception as err: 89 | logger.error(f"Error connecting to Elasticsearch: {err}") 90 | raise err 91 | 92 | return cls( 93 | es_client=client, 94 | index_name=index_name, 95 | body_func=body_func, 96 | content_field=content_field, 97 | document_mapper=document_mapper, 98 | ) 99 | 100 | def _get_relevant_documents( 101 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun 102 | ) -> List[Document]: 103 | if not self.es_client or not self.document_mapper: 104 | raise ValueError("faulty configuration") # should not happen 105 | 106 | body = self.body_func(query) 107 | results = self.es_client.search(index=self.index_name, body=body) 108 | return [self.document_mapper(hit) for hit in results["hits"]["hits"]] 109 | 110 | def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: 111 | content = hit["_source"].pop(self.content_field) 112 | return Document(page_content=content, metadata=hit) 113 | 114 | def _multi_field_mapper(self, hit: Mapping[str, Any]) -> Document: 115 | self.content_field = cast(Mapping, self.content_field) 116 | field = self.content_field[hit["_index"]] 117 | content = hit["_source"].pop(field) 118 | return Document(page_content=content, metadata=hit) 119 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/cache.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterator, List, Optional, Sequence, Tuple 2 | 3 | from elasticsearch import Elasticsearch, exceptions, helpers # noqa: F401 4 | from elasticsearch.helpers import BulkIndexError # noqa: F401 5 | from 
langchain_core.caches import RETURN_VAL_TYPE, BaseCache # noqa: F401 6 | from langchain_core.load import dumps, loads # noqa: F401 7 | from langchain_core.stores import ByteStore # noqa: F401 8 | 9 | from langchain_elasticsearch._async.cache import ( 10 | AsyncElasticsearchCache as _AsyncElasticsearchCache, 11 | ) 12 | from langchain_elasticsearch._async.cache import ( 13 | AsyncElasticsearchEmbeddingsCache as _AsyncElasticsearchEmbeddingsCache, 14 | ) 15 | from langchain_elasticsearch._sync.cache import ( 16 | ElasticsearchCache as _ElasticsearchCache, 17 | ) 18 | from langchain_elasticsearch._sync.cache import ( 19 | ElasticsearchEmbeddingsCache as _ElasticsearchEmbeddingsCache, 20 | ) 21 | from langchain_elasticsearch.client import ( # noqa: F401 22 | create_async_elasticsearch_client, 23 | create_elasticsearch_client, 24 | ) 25 | 26 | 27 | # langchain defines some sync methods as abstract in its base class 28 | # so we have to add dummy methods for them, even though we only use the async versions 29 | class AsyncElasticsearchCache(_AsyncElasticsearchCache): 30 | def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: 31 | raise NotImplementedError("This class is asynchronous, use alookup()") 32 | 33 | def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: 34 | raise NotImplementedError("This class is asynchronous, use aupdate()") 35 | 36 | def clear(self, **kwargs: Any) -> None: 37 | raise NotImplementedError("This class is asynchronous, use aclear()") 38 | 39 | 40 | # langchain defines some sync methods as abstract in its base class 41 | # so we have to add dummy methods for them, even though we only use the async versions 42 | class AsyncElasticsearchEmbeddingsCache(_AsyncElasticsearchEmbeddingsCache): 43 | def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: 44 | raise NotImplementedError("This class is asynchronous, use amget()") 45 | 46 | def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: 47 | raise NotImplementedError("This class is asynchronous, use amset()") 48 | 49 | def mdelete(self, keys: Sequence[str]) -> None: 50 | raise NotImplementedError("This class is asynchronous, use amdelete()") 51 | 52 | def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: 53 | raise NotImplementedError("This class is asynchronous, use ayield_keys()") 54 | 55 | 56 | # these are only defined here so that they are picked up by Langchain's docs generator 57 | class ElasticsearchCache(_ElasticsearchCache): 58 | pass 59 | 60 | 61 | class ElasticsearchEmbeddingsCache(_ElasticsearchEmbeddingsCache): 62 | pass 63 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/chat_history.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.messages import BaseMessage 4 | 5 | from langchain_elasticsearch._async.chat_history import ( 6 | AsyncElasticsearchChatMessageHistory as _AsyncElasticsearchChatMessageHistory, 7 | ) 8 | from langchain_elasticsearch._sync.chat_history import ( 9 | ElasticsearchChatMessageHistory as _ElasticsearchChatMessageHistory, 10 | ) 11 | 12 | 13 | # add the messages property which is only in the sync version 14 | class ElasticsearchChatMessageHistory(_ElasticsearchChatMessageHistory): 15 | @property 16 | def messages(self) -> List[BaseMessage]: # type: ignore[override] 17 | return self.get_messages() 18 | 19 | 20 | # langchain defines 
some sync methods as abstract in its base class 21 | # so we have to add dummy methods for them, even though we only use the async versions 22 | class AsyncElasticsearchChatMessageHistory(_AsyncElasticsearchChatMessageHistory): 23 | def clear(self) -> None: 24 | raise NotImplementedError("This class is asynchronous, use aclear()") 25 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/client.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from elasticsearch import AsyncElasticsearch, Elasticsearch 4 | 5 | 6 | def create_elasticsearch_client( 7 | url: Optional[str] = None, 8 | cloud_id: Optional[str] = None, 9 | api_key: Optional[str] = None, 10 | username: Optional[str] = None, 11 | password: Optional[str] = None, 12 | params: Optional[Dict[str, Any]] = None, 13 | ) -> Elasticsearch: 14 | if url and cloud_id: 15 | raise ValueError( 16 | "Both es_url and cloud_id are defined. Please provide only one." 17 | ) 18 | 19 | connection_params: Dict[str, Any] = {} 20 | 21 | if url: 22 | connection_params["hosts"] = [url] 23 | elif cloud_id: 24 | connection_params["cloud_id"] = cloud_id 25 | else: 26 | raise ValueError("Please provide either elasticsearch_url or cloud_id.") 27 | 28 | if api_key: 29 | connection_params["api_key"] = api_key 30 | elif username and password: 31 | connection_params["basic_auth"] = (username, password) 32 | 33 | if params is not None: 34 | connection_params.update(params) 35 | 36 | es_client = Elasticsearch(**connection_params) 37 | 38 | es_client.info() # test connection 39 | 40 | return es_client 41 | 42 | 43 | def create_async_elasticsearch_client( 44 | url: Optional[str] = None, 45 | cloud_id: Optional[str] = None, 46 | api_key: Optional[str] = None, 47 | username: Optional[str] = None, 48 | password: Optional[str] = None, 49 | params: Optional[Dict[str, Any]] = None, 50 | ) -> AsyncElasticsearch: 51 | if url and cloud_id: 52 | raise ValueError( 53 | "Both es_url and cloud_id are defined. Please provide only one." 
54 | ) 55 | 56 | connection_params: Dict[str, Any] = {} 57 | 58 | if url: 59 | connection_params["hosts"] = [url] 60 | elif cloud_id: 61 | connection_params["cloud_id"] = cloud_id 62 | else: 63 | raise ValueError("Please provide either elasticsearch_url or cloud_id.") 64 | 65 | if api_key: 66 | connection_params["api_key"] = api_key 67 | elif username and password: 68 | connection_params["basic_auth"] = (username, password) 69 | 70 | if params is not None: 71 | connection_params.update(params) 72 | 73 | es_client = AsyncElasticsearch(**connection_params) 74 | return es_client 75 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.embeddings import Embeddings # noqa: F401 4 | 5 | from langchain_elasticsearch._async.embeddings import ( 6 | AsyncElasticsearchEmbeddings as _AsyncElasticsearchEmbeddings, 7 | ) 8 | from langchain_elasticsearch._async.embeddings import ( 9 | AsyncEmbeddingService as _AsyncEmbeddingService, 10 | ) 11 | from langchain_elasticsearch._async.embeddings import ( 12 | AsyncEmbeddingServiceAdapter as _AsyncEmbeddingServiceAdapter, 13 | ) 14 | from langchain_elasticsearch._sync.embeddings import ( 15 | ElasticsearchEmbeddings as _ElasticsearchEmbeddings, 16 | ) 17 | from langchain_elasticsearch._sync.embeddings import ( 18 | EmbeddingService as _EmbeddingService, 19 | ) 20 | from langchain_elasticsearch._sync.embeddings import ( 21 | EmbeddingServiceAdapter as _EmbeddingServiceAdapter, 22 | ) 23 | 24 | 25 | # langchain defines some sync methods as abstract in its base class 26 | # so we have to add dummy methods for them, even though we only use the async versions 27 | class AsyncElasticsearchEmbeddings(_AsyncElasticsearchEmbeddings): 28 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 29 | raise NotImplementedError("This class is asynchronous, use aembed_documents()") 30 | 31 | def embed_query(self, text: str) -> List[float]: 32 | raise NotImplementedError("This class is asynchronous, use aembed_query()") 33 | 34 | 35 | # these are only defined here so that they are picked up by Langchain's docs generator 36 | class ElasticsearchEmbeddings(_ElasticsearchEmbeddings): 37 | pass 38 | 39 | 40 | class EmbeddingService(_EmbeddingService): 41 | pass 42 | 43 | 44 | class EmbeddingServiceAdapter(_EmbeddingServiceAdapter): 45 | pass 46 | 47 | 48 | class AsyncEmbeddingService(_AsyncEmbeddingService): 49 | pass 50 | 51 | 52 | class AsyncEmbeddingServiceAdapter(_AsyncEmbeddingServiceAdapter): 53 | pass 54 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/langchain_elasticsearch/py.typed -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/retrievers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from langchain_core.callbacks import CallbackManagerForRetrieverRun 4 | from langchain_core.documents import Document 5 | 6 | from langchain_elasticsearch._async.retrievers import ( 7 | AsyncElasticsearchRetriever 
as _AsyncElasticsearchRetriever, 8 | ) 9 | from langchain_elasticsearch._sync.retrievers import ( 10 | ElasticsearchRetriever as _ElasticsearchRetriever, 11 | ) 12 | 13 | 14 | # langchain defines some sync methods as abstract in its base class 15 | # so we have to add dummy methods for them, even though we only use the async versions 16 | class AsyncElasticsearchRetriever(_AsyncElasticsearchRetriever): 17 | def _get_relevant_documents( 18 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any 19 | ) -> List[Document]: 20 | raise NotImplementedError( 21 | "This class is asynchronous, use _aget_relevant_documents()" 22 | ) 23 | 24 | 25 | # this is only defined here so that it is picked up by Langchain's docs generator 26 | class ElasticsearchRetriever(_ElasticsearchRetriever): 27 | pass 28 | -------------------------------------------------------------------------------- /libs/elasticsearch/langchain_elasticsearch/vectorstores.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from langchain_elasticsearch._async.vectorstores import ( 4 | AsyncBM25Strategy as _AsyncBM25Strategy, 5 | ) 6 | from langchain_elasticsearch._async.vectorstores import ( 7 | AsyncDenseVectorScriptScoreStrategy as _AsyncDenseVectorScriptScoreStrategy, 8 | ) 9 | from langchain_elasticsearch._async.vectorstores import ( 10 | AsyncDenseVectorStrategy as _AsyncDenseVectorStrategy, 11 | ) 12 | from langchain_elasticsearch._async.vectorstores import ( 13 | AsyncElasticsearchStore as _AsyncElasticsearchStore, 14 | ) 15 | from langchain_elasticsearch._async.vectorstores import ( 16 | AsyncRetrievalStrategy as _AsyncRetrievalStrategy, 17 | ) 18 | from langchain_elasticsearch._async.vectorstores import ( 19 | AsyncSparseVectorStrategy as _AsyncSparseVectorStrategy, 20 | ) 21 | from langchain_elasticsearch._async.vectorstores import ( 22 | DistanceMetric, # noqa: F401 23 | Document, 24 | Embeddings, 25 | ) 26 | from langchain_elasticsearch._sync.vectorstores import ( 27 | BM25Strategy as _BM25Strategy, 28 | ) 29 | from langchain_elasticsearch._sync.vectorstores import ( 30 | DenseVectorScriptScoreStrategy as _DenseVectorScriptScoreStrategy, 31 | ) 32 | from langchain_elasticsearch._sync.vectorstores import ( 33 | DenseVectorStrategy as _DenseVectorStrategy, 34 | ) 35 | from langchain_elasticsearch._sync.vectorstores import ( 36 | ElasticsearchStore as _ElasticsearchStore, 37 | ) 38 | from langchain_elasticsearch._sync.vectorstores import ( 39 | RetrievalStrategy as _RetrievalStrategy, 40 | ) 41 | from langchain_elasticsearch._sync.vectorstores import ( 42 | SparseVectorStrategy as _SparseVectorStrategy, 43 | ) 44 | 45 | # deprecated strategy classes 46 | from langchain_elasticsearch._utilities import ( # noqa: F401 47 | ApproxRetrievalStrategy, 48 | BaseRetrievalStrategy, 49 | BM25RetrievalStrategy, 50 | DistanceStrategy, 51 | ExactRetrievalStrategy, 52 | SparseRetrievalStrategy, 53 | ) 54 | 55 | 56 | # langchain defines some sync methods as abstract in its base class 57 | # so we have to add dummy methods for them, even though we only use the async versions 58 | class AsyncElasticsearchStore(_AsyncElasticsearchStore): 59 | @classmethod 60 | def from_texts( 61 | cls, 62 | texts: list[str], 63 | embedding: Embeddings, 64 | metadatas: Optional[list[dict]] = None, 65 | *, 66 | ids: Optional[list[str]] = None, 67 | **kwargs: Any, 68 | ) -> "AsyncElasticsearchStore": 69 | raise NotImplementedError("This class is asynchronous, use 
afrom_texts()") 70 | 71 | def similarity_search( 72 | self, query: str, k: int = 4, **kwargs: Any 73 | ) -> list[Document]: 74 | raise NotImplementedError( 75 | "This class is asynchronous, use asimilarity_search()" 76 | ) 77 | 78 | 79 | # these are only defined here so that they are picked up by Langchain's docs generator 80 | class ElasticsearchStore(_ElasticsearchStore): 81 | pass 82 | 83 | 84 | class BM25Strategy(_BM25Strategy): 85 | pass 86 | 87 | 88 | class DenseVectorScriptScoreStrategy(_DenseVectorScriptScoreStrategy): 89 | pass 90 | 91 | 92 | class DenseVectorStrategy(_DenseVectorStrategy): 93 | pass 94 | 95 | 96 | class RetrievalStrategy(_RetrievalStrategy): 97 | pass 98 | 99 | 100 | class SparseVectorStrategy(_SparseVectorStrategy): 101 | pass 102 | 103 | 104 | class AsyncBM25Strategy(_AsyncBM25Strategy): 105 | pass 106 | 107 | 108 | class AsyncDenseVectorScriptScoreStrategy(_AsyncDenseVectorScriptScoreStrategy): 109 | pass 110 | 111 | 112 | class AsyncDenseVectorStrategy(_AsyncDenseVectorStrategy): 113 | pass 114 | 115 | 116 | class AsyncRetrievalStrategy(_AsyncRetrievalStrategy): 117 | pass 118 | 119 | 120 | class AsyncSparseVectorStrategy(_AsyncSparseVectorStrategy): 121 | pass 122 | -------------------------------------------------------------------------------- /libs/elasticsearch/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-elasticsearch" 3 | version = "0.3.2" 4 | description = "An integration package connecting Elasticsearch and LangChain" 5 | authors = [] 6 | readme = "README.md" 7 | repository = "https://github.com/langchain-ai/langchain-elastic" 8 | license = "MIT" 9 | 10 | [tool.poetry.urls] 11 | "Source Code" = "https://github.com/langchain-ai/langchain-elastic/tree/main/libs/elasticsearch" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9,<4.0" 15 | langchain-core = "^0.3.0" 16 | elasticsearch = {version = "^8.15.1", extras = ["vectorstore_mmr"]} 17 | 18 | [tool.poetry.group.test] 19 | optional = true 20 | 21 | [tool.poetry.group.test.dependencies] 22 | pytest = "^7.3.0" 23 | freezegun = "^1.2.2" 24 | pytest-mock = "^3.10.0" 25 | syrupy = "^4.0.2" 26 | pytest-watcher = "^0.3.4" 27 | pytest-asyncio = "^0.21.1" 28 | langchain = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain" } 29 | langchain-community = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/community" } 30 | langchain-text-splitters = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/text-splitters" } 31 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } 32 | 33 | [tool.poetry.group.codespell] 34 | optional = true 35 | 36 | [tool.poetry.group.codespell.dependencies] 37 | codespell = "^2.2.0" 38 | 39 | [tool.poetry.group.lint] 40 | optional = true 41 | 42 | [tool.poetry.group.lint.dependencies] 43 | ruff = "^0.1.5" 44 | 45 | [tool.poetry.group.typing.dependencies] 46 | mypy = "^0.991" 47 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } 48 | 49 | [tool.poetry.group.dev] 50 | optional = true 51 | 52 | [tool.poetry.group.dev.dependencies] 53 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } 54 | 55 | [tool.poetry.group.test_integration] 56 | optional = true 57 | 58 | [tool.poetry.group.test_integration.dependencies] 59 | 60 | 
[tool.poetry.group.codegen.dependencies] 61 | unasync = "^0.6.0" 62 | 63 | [tool.ruff] 64 | select = [ 65 | "E", # pycodestyle 66 | "F", # pyflakes 67 | "I", # isort 68 | ] 69 | 70 | [tool.mypy] 71 | disallow_untyped_defs = "True" 72 | 73 | [tool.coverage.run] 74 | omit = ["tests/*"] 75 | 76 | [build-system] 77 | requires = ["poetry-core>=1.0.0"] 78 | build-backend = "poetry.core.masonry.api" 79 | 80 | [tool.pytest.ini_options] 81 | # --strict-markers will raise errors on unknown marks. 82 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 83 | # 84 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 85 | # --strict-config any warnings encountered while parsing the `pytest` 86 | # section of the configuration file raise errors. 87 | # 88 | # https://github.com/tophat/syrupy 89 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 90 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 91 | # Registering custom markers. 92 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 93 | markers = [ 94 | "requires: mark tests as requiring a specific library", 95 | "asyncio: mark tests as requiring asyncio", 96 | "sync: mark tests as performing I/O without asyncio", 97 | ] 98 | asyncio_mode = "auto" 99 | -------------------------------------------------------------------------------- /libs/elasticsearch/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | from importlib.machinery import SourceFileLoader 4 | 5 | if __name__ == "__main__": 6 | files = sys.argv[1:] 7 | has_failure = False 8 | for file in files: 9 | try: 10 | SourceFileLoader("x", file).load_module() 11 | except Exception: 12 | has_failure = True 13 | print(file) 14 | traceback.print_exc() 15 | print() 16 | 17 | sys.exit(1 if has_failure else 0) 18 | -------------------------------------------------------------------------------- /libs/elasticsearch/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' .
&& errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/elasticsearch/scripts/run_unasync.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | from glob import glob 5 | from pathlib import Path 6 | 7 | import unasync 8 | 9 | 10 | def main(check=False): 11 | # the list of directories that need to be processed with unasync 12 | # each entry has two paths: 13 | # - the source path with the async sources 14 | # - the destination path where the sync sources should be written 15 | source_dirs = [ 16 | ( 17 | "langchain_elasticsearch/_async/", 18 | "langchain_elasticsearch/_sync/", 19 | ), 20 | ("tests/_async/", "tests/_sync/"), 21 | ("tests/integration_tests/_async/", "tests/integration_tests/_sync/"), 22 | ("tests/unit_tests/_async/", "tests/unit_tests/_sync/"), 23 | ] 24 | 25 | # Unasync all the generated async code 26 | additional_replacements = { 27 | "_async": "_sync", 28 | "AsyncElasticsearch": "Elasticsearch", 29 | "AsyncTransport": "Transport", 30 | "AsyncBM25Strategy": "BM25Strategy", 31 | "AsyncDenseVectorScriptScoreStrategy": "DenseVectorScriptScoreStrategy", 32 | "AsyncDenseVectorStrategy": "DenseVectorStrategy", 33 | "AsyncRetrievalStrategy": "RetrievalStrategy", 34 | "AsyncSparseVectorStrategy": "SparseVectorStrategy", 35 | "AsyncVectorStore": "VectorStore", 36 | "AsyncElasticsearchStore": "ElasticsearchStore", 37 | "AsyncElasticsearchEmbeddings": "ElasticsearchEmbeddings", 38 | "AsyncElasticsearchEmbeddingsCache": "ElasticsearchEmbeddingsCache", 39 | "AsyncEmbeddingServiceAdapter": "EmbeddingServiceAdapter", 40 | "AsyncEmbeddingService": "EmbeddingService", 41 | "AsyncElasticsearchRetriever": "ElasticsearchRetriever", 42 | "AsyncElasticsearchCache": "ElasticsearchCache", 43 | "AsyncElasticsearchChatMessageHistory": "ElasticsearchChatMessageHistory", 44 | "AsyncCallbackManagerForRetrieverRun": "CallbackManagerForRetrieverRun", 45 | "AsyncFakeEmbeddings": "FakeEmbeddings", 46 | "AsyncConsistentFakeEmbeddings": "ConsistentFakeEmbeddings", 47 | "AsyncRequestSavingTransport": "RequestSavingTransport", 48 | "AsyncMock": "Mock", 49 | "Embeddings": "Embeddings", 50 | "AsyncGenerator": "Generator", 51 | "AsyncIterator": "Iterator", 52 | "create_async_elasticsearch_client": "create_elasticsearch_client", 53 | "aadd_texts": "add_texts", 54 | "aadd_embeddings": "add_embeddings", 55 | "aadd_documents": "add_documents", 56 | "afrom_texts": "from_texts", 57 | "afrom_documents": "from_documents", 58 | "amax_marginal_relevance_search": "max_marginal_relevance_search", 59 | "asimilarity_search": "similarity_search", 60 | "asimilarity_search_by_vector_with_relevance_scores": "similarity_search_by_vector_with_relevance_scores", # noqa: E501 61 | "asimilarity_search_with_score": "similarity_search_with_score", 62 | "asimilarity_search_with_relevance_scores": "similarity_search_with_relevance_scores", # noqa: E501 63 | "adelete": "delete", 64 | "aclose": "close", 65 | "ainvoke": "invoke", 66 | "aembed_documents": "embed_documents", 67 | "aembed_query": "embed_query", 68 | "_aget_relevant_documents": "_get_relevant_documents", 69 | "aget_relevant_documents": "get_relevant_documents", 70 | "alookup": "lookup", 71 | "aupdate": "update", 72 | "aclear": "clear", 73 | "amget": "mget", 74 | "amset": "mset", 75 | "amdelete": "mdelete", 
76 | "ayield_keys": "yield_keys", 77 | "asearch": "search", 78 | "aget_messages": "get_messages", 79 | "aadd_messages": "add_messages", 80 | "aadd_message": "add_message", 81 | "aencode_vector": "encode_vector", 82 | "assert_awaited_with": "assert_called_with", 83 | "async_es_client_fx": "es_client_fx", 84 | "async_es_embeddings_cache_fx": "es_embeddings_cache_fx", 85 | "async_es_cache_fx": "es_cache_fx", 86 | "async_bulk": "bulk", 87 | "async_with_user_agent_header": "with_user_agent_header", 88 | "asyncio": "sync", 89 | } 90 | rules = [ 91 | unasync.Rule( 92 | fromdir=dir[0], 93 | todir=f"{dir[0]}_sync_check/" if check else dir[1], 94 | additional_replacements=additional_replacements, 95 | ) 96 | for dir in source_dirs 97 | ] 98 | 99 | filepaths = [] 100 | for root, _, filenames in os.walk(Path(__file__).absolute().parent.parent): 101 | if "/site-packages" in root or "/." in root or "__pycache__" in root: 102 | continue 103 | for filename in filenames: 104 | if filename.rpartition(".")[-1] in ( 105 | "py", 106 | "pyi", 107 | ) and not filename.startswith("utils.py"): 108 | filepaths.append(os.path.join(root, filename)) 109 | 110 | unasync.unasync_files(filepaths, rules) 111 | for dir in source_dirs: 112 | output_dir = f"{dir[0]}_sync_check/" if check else dir[1] 113 | subprocess.check_call(["ruff", "format", "--target-version=py38", output_dir]) 114 | subprocess.check_call(["ruff", "check", "--fix", "--select", "I", output_dir]) 115 | for file in glob("*.py", root_dir=dir[0]): 116 | subprocess.check_call( 117 | [ 118 | "sed", 119 | "-i.bak", 120 | "s/pytest.mark.asyncio/pytest.mark.sync/", 121 | f"{output_dir}{file}", 122 | ] 123 | ) 124 | subprocess.check_call( 125 | [ 126 | "sed", 127 | "-i.bak", 128 | "s/get_messages()/messages/", 129 | f"{output_dir}{file}", 130 | ] 131 | ) 132 | subprocess.check_call(["rm", f"{output_dir}{file}.bak"]) 133 | 134 | if check: 135 | # make sure there are no differences between _sync and _sync_check 136 | subprocess.check_call( 137 | [ 138 | "diff", 139 | f"{dir[1]}{file}", 140 | f"{output_dir}{file}", 141 | ] 142 | ) 143 | 144 | if check: 145 | subprocess.check_call(["rm", "-rf", output_dir]) 146 | 147 | 148 | if __name__ == "__main__": 149 | main(check="--check" in sys.argv) 150 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/_async/fake_embeddings.py: -------------------------------------------------------------------------------- 1 | """Fake Embedding class for testing purposes.""" 2 | 3 | from typing import List 4 | 5 | from langchain_core.embeddings import Embeddings 6 | 7 | fake_texts = ["foo", "bar", "baz"] 8 | 9 | 10 | class AsyncFakeEmbeddings(Embeddings): 11 | """Fake embeddings functionality for testing.""" 12 | 13 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]: 14 | """Return simple embeddings. 15 | Embeddings encode each text as its index.""" 16 | return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] 17 | 18 | async def aembed_query(self, text: str) -> List[float]: 19 | """Return constant query embeddings. 20 | Embeddings are identical to embed_documents(texts)[0]. 
21 | Distance to each text will be that text's index, 22 | as it was passed to embed_documents.""" 23 | return [float(1.0)] * 9 + [float(0.0)] 24 | 25 | 26 | class AsyncConsistentFakeEmbeddings(AsyncFakeEmbeddings): 27 | """Fake embeddings which remember all the texts seen so far to return consistent 28 | vectors for the same texts.""" 29 | 30 | def __init__(self, dimensionality: int = 10) -> None: 31 | self.known_texts: List[str] = [] 32 | self.dimensionality = dimensionality 33 | 34 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]: 35 | """Return consistent embeddings for each text seen so far.""" 36 | out_vectors = [] 37 | for text in texts: 38 | if text not in self.known_texts: 39 | self.known_texts.append(text) 40 | vector = [float(1.0)] * (self.dimensionality - 1) + [ 41 | float(self.known_texts.index(text)) 42 | ] 43 | out_vectors.append(vector) 44 | return out_vectors 45 | 46 | async def aembed_query(self, text: str) -> List[float]: 47 | """Return consistent embeddings for the text, if seen before, or a constant 48 | one if the text is unknown.""" 49 | return (await self.aembed_documents([text]))[0] 50 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/_sync/fake_embeddings.py: -------------------------------------------------------------------------------- 1 | """Fake Embedding class for testing purposes.""" 2 | 3 | from typing import List 4 | 5 | from langchain_core.embeddings import Embeddings 6 | 7 | fake_texts = ["foo", "bar", "baz"] 8 | 9 | 10 | class FakeEmbeddings(Embeddings): 11 | """Fake embeddings functionality for testing.""" 12 | 13 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 14 | """Return simple embeddings. 15 | Embeddings encode each text as its index.""" 16 | return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] 17 | 18 | def embed_query(self, text: str) -> List[float]: 19 | """Return constant query embeddings. 20 | Embeddings are identical to embed_documents(texts)[0]. 
21 | Distance to each text will be that text's index, 22 | as it was passed to embed_documents.""" 23 | return [float(1.0)] * 9 + [float(0.0)] 24 | 25 | 26 | class ConsistentFakeEmbeddings(FakeEmbeddings): 27 | """Fake embeddings which remember all the texts seen so far to return consistent 28 | vectors for the same texts.""" 29 | 30 | def __init__(self, dimensionality: int = 10) -> None: 31 | self.known_texts: List[str] = [] 32 | self.dimensionality = dimensionality 33 | 34 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 35 | """Return consistent embeddings for each text seen so far.""" 36 | out_vectors = [] 37 | for text in texts: 38 | if text not in self.known_texts: 39 | self.known_texts.append(text) 40 | vector = [float(1.0)] * (self.dimensionality - 1) + [ 41 | float(self.known_texts.index(text)) 42 | ] 43 | out_vectors.append(vector) 44 | return out_vectors 45 | 46 | def embed_query(self, text: str) -> List[float]: 47 | """Return consistent embeddings for the text, if seen before, or a constant 48 | one if the text is unknown.""" 49 | return (self.embed_documents([text]))[0] 50 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | from unittest import mock 3 | from unittest.mock import AsyncMock, MagicMock 4 | 5 | import pytest 6 | from elasticsearch import AsyncElasticsearch, Elasticsearch 7 | from elasticsearch._async.client import IndicesClient as AsyncIndicesClient 8 | from elasticsearch._sync.client import IndicesClient 9 | from langchain_community.chat_models.fake import FakeMessagesListChatModel 10 | from langchain_core.language_models import BaseChatModel 11 | from langchain_core.messages import AIMessage 12 | 13 | from langchain_elasticsearch import ( 14 | AsyncElasticsearchCache, 15 | AsyncElasticsearchEmbeddingsCache, 16 | ElasticsearchCache, 17 | ElasticsearchEmbeddingsCache, 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def es_client_fx() -> Generator[MagicMock, None, None]: 23 | client_mock = MagicMock(spec=Elasticsearch) 24 | client_mock.return_value.indices = MagicMock(spec=IndicesClient) 25 | yield client_mock() 26 | 27 | 28 | @pytest.fixture 29 | def async_es_client_fx() -> Generator[MagicMock, None, None]: 30 | client_mock = MagicMock(spec=AsyncElasticsearch) 31 | client_mock.return_value.indices = MagicMock(spec=AsyncIndicesClient) 32 | # coroutines need to be mocked explicitly 33 | client_mock.return_value.indices.exists_alias = AsyncMock() 34 | client_mock.return_value.indices.put_mapping = AsyncMock() 35 | client_mock.return_value.indices.exists = AsyncMock() 36 | client_mock.return_value.indices.create = AsyncMock() 37 | yield client_mock() 38 | 39 | 40 | @pytest.fixture 41 | def es_embeddings_cache_fx( 42 | es_client_fx: MagicMock, 43 | ) -> Generator[ElasticsearchEmbeddingsCache, None, None]: 44 | with mock.patch( 45 | "langchain_elasticsearch._sync.cache.create_elasticsearch_client", 46 | return_value=es_client_fx, 47 | ): 48 | yield ElasticsearchEmbeddingsCache( 49 | es_url="http://localhost:9200", 50 | index_name="test_index", 51 | store_input=True, 52 | namespace="test", 53 | metadata={"project": "test_project"}, 54 | ) 55 | 56 | 57 | @pytest.fixture 58 | def async_es_embeddings_cache_fx( 59 | async_es_client_fx: MagicMock, 60 | ) -> Generator[AsyncElasticsearchEmbeddingsCache, None, None]: 61 | with mock.patch( 62 | 
"langchain_elasticsearch._async.cache.create_async_elasticsearch_client", 63 | return_value=async_es_client_fx, 64 | ): 65 | yield AsyncElasticsearchEmbeddingsCache( 66 | es_url="http://localhost:9200", 67 | index_name="test_index", 68 | store_input=True, 69 | namespace="test", 70 | metadata={"project": "test_project"}, 71 | ) 72 | 73 | 74 | @pytest.fixture 75 | def es_cache_fx( 76 | es_client_fx: MagicMock, 77 | ) -> Generator[ElasticsearchCache, None, None]: 78 | with mock.patch( 79 | "langchain_elasticsearch._sync.cache.create_elasticsearch_client", 80 | return_value=es_client_fx, 81 | ): 82 | yield ElasticsearchCache( 83 | es_url="http://localhost:30096", 84 | index_name="test_index", 85 | store_input=True, 86 | store_input_params=True, 87 | metadata={"project": "test_project"}, 88 | ) 89 | 90 | 91 | @pytest.fixture 92 | def async_es_cache_fx( 93 | async_es_client_fx: MagicMock, 94 | ) -> Generator[AsyncElasticsearchCache, None, None]: 95 | with mock.patch( 96 | "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", 97 | return_value=async_es_client_fx, 98 | ): 99 | yield AsyncElasticsearchCache( 100 | es_url="http://localhost:30096", 101 | index_name="test_index", 102 | store_input=True, 103 | store_input_params=True, 104 | metadata={"project": "test_project"}, 105 | ) 106 | 107 | 108 | @pytest.fixture 109 | def fake_chat_fx() -> Generator[BaseChatModel, None, None]: 110 | yield FakeMessagesListChatModel( 111 | cache=True, responses=[AIMessage(content="test output")] 112 | ) 113 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/fake_embeddings.py: -------------------------------------------------------------------------------- 1 | """Fake Embedding class for testing purposes.""" 2 | 3 | from typing import List 4 | 5 | from ._async.fake_embeddings import ( 6 | AsyncConsistentFakeEmbeddings as _AsyncConsistentFakeEmbeddings, 7 | ) 8 | from ._async.fake_embeddings import AsyncFakeEmbeddings as _AsyncFakeEmbeddings 9 | from ._sync.fake_embeddings import ( # noqa: F401 10 | ConsistentFakeEmbeddings, 11 | FakeEmbeddings, 12 | ) 13 | 14 | 15 | # langchain defines embed_documents and embed_query as abstract in its base class 16 | # so we have to add dummy methods for them, even though we only use the async versions 17 | class AsyncFakeEmbeddings(_AsyncFakeEmbeddings): 18 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 19 | raise NotImplementedError("This class is asynchronous, use aembed_documents()") 20 | 21 | def embed_query(self, text: str) -> List[float]: 22 | raise NotImplementedError("This class is asynchronous, use aembed_query()") 23 | 24 | 25 | class AsyncConsistentFakeEmbeddings(_AsyncConsistentFakeEmbeddings): 26 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 27 | raise NotImplementedError("This class is asynchronous, use aembed_documents()") 28 | 29 | def embed_query(self, text: str) -> List[float]: 30 | raise NotImplementedError("This class is asynchronous, use aembed_query()") 31 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- 
/libs/elasticsearch/tests/integration_tests/_async/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/integration_tests/_async/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_async/_test_utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional 3 | 4 | from elastic_transport import AsyncTransport 5 | from elasticsearch import ( 6 | AsyncElasticsearch, 7 | BadRequestError, 8 | ConflictError, 9 | NotFoundError, 10 | ) 11 | 12 | 13 | def read_env() -> Dict: 14 | url = os.environ.get("ES_URL", "http://localhost:9200") 15 | cloud_id = os.environ.get("ES_CLOUD_ID") 16 | api_key = os.environ.get("ES_API_KEY") 17 | 18 | if cloud_id: 19 | return {"es_cloud_id": cloud_id, "es_api_key": api_key} 20 | return {"es_url": url} 21 | 22 | 23 | class AsyncRequestSavingTransport(AsyncTransport): 24 | def __init__(self, *args: Any, **kwargs: Any) -> None: 25 | super().__init__(*args, **kwargs) 26 | self.requests: List[Dict] = [] 27 | 28 | async def perform_request(self, *args, **kwargs): # type: ignore 29 | self.requests.append(kwargs) 30 | return await super().perform_request(*args, **kwargs) 31 | 32 | 33 | def create_es_client( 34 | es_params: Optional[Dict[str, str]] = None, 35 | es_kwargs: Dict = {}, 36 | ) -> AsyncElasticsearch: 37 | if es_params is None: 38 | es_params = read_env() 39 | if not es_kwargs: 40 | es_kwargs = {} 41 | 42 | if "es_cloud_id" in es_params: 43 | return AsyncElasticsearch( 44 | cloud_id=es_params["es_cloud_id"], 45 | api_key=es_params["es_api_key"], 46 | **es_kwargs, 47 | ) 48 | 49 | return AsyncElasticsearch(hosts=[es_params["es_url"]], **es_kwargs) 50 | 51 | 52 | def requests_saving_es_client() -> AsyncElasticsearch: 53 | return create_es_client(es_kwargs={"transport_class": AsyncRequestSavingTransport}) 54 | 55 | 56 | async def clear_test_indices(es: AsyncElasticsearch) -> None: 57 | index_names_response = await es.indices.get(index="_all") 58 | index_names = index_names_response.keys() 59 | for index_name in index_names: 60 | if index_name.startswith("test_"): 61 | await es.indices.delete(index=index_name) 62 | await es.indices.refresh(index="_all") 63 | 64 | 65 | async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool: 66 | try: 67 | dummy = {"x": "y"} 68 | await client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) 69 | return True 70 | except NotFoundError: 71 | return False 72 | except ConflictError: 73 | return False 74 | except BadRequestError: 75 | # This error is expected because we do not know the expected document 76 | # shape and just use a dummy doc above. 
77 | return True 78 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_async/test_cache.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Dict, Union 2 | 3 | import pytest 4 | from elasticsearch.helpers import BulkIndexError 5 | from langchain.embeddings.cache import _value_serializer 6 | from langchain.globals import set_llm_cache 7 | from langchain_core.language_models import BaseChatModel 8 | 9 | from langchain_elasticsearch import ( 10 | AsyncElasticsearchCache, 11 | AsyncElasticsearchEmbeddingsCache, 12 | ) 13 | 14 | from ._test_utilities import clear_test_indices, create_es_client, read_env 15 | 16 | 17 | @pytest.fixture 18 | async def es_env_fx() -> Union[dict, AsyncGenerator]: 19 | params = read_env() 20 | es = create_es_client(params) 21 | await es.options(ignore_status=404).indices.delete(index="test_index1") 22 | await es.options(ignore_status=404).indices.delete(index="test_index2") 23 | await es.indices.create(index="test_index1") 24 | await es.indices.create(index="test_index2") 25 | await es.indices.put_alias(index="test_index1", name="test_alias") 26 | await es.indices.put_alias( 27 | index="test_index2", name="test_alias", is_write_index=True 28 | ) 29 | yield params 30 | await es.options(ignore_status=404).indices.delete_alias( 31 | index="test_index1,test_index2", name="test_alias" 32 | ) 33 | await clear_test_indices(es) 34 | await es.close() 35 | 36 | 37 | @pytest.mark.asyncio 38 | async def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 39 | cache = AsyncElasticsearchCache( 40 | **es_env_fx, index_name="test_index1", metadata={"project": "test"} 41 | ) 42 | es_client = cache._es_client 43 | set_llm_cache(cache) 44 | await fake_chat_fx.ainvoke("test") 45 | assert (await es_client.count(index="test_index1"))["count"] == 1 46 | await fake_chat_fx.ainvoke("test") 47 | assert (await es_client.count(index="test_index1"))["count"] == 1 48 | record = (await es_client.search(index="test_index1"))["hits"]["hits"][0]["_source"] 49 | assert "test output" in record.get("llm_output", [""])[0] 50 | assert record.get("llm_input") 51 | assert record.get("timestamp") 52 | assert record.get("llm_params") 53 | assert record.get("metadata") == {"project": "test"} 54 | cache2 = AsyncElasticsearchCache( 55 | **es_env_fx, 56 | index_name="test_index1", 57 | metadata={"project": "test"}, 58 | store_input=False, 59 | store_input_params=False, 60 | ) 61 | set_llm_cache(cache2) 62 | await fake_chat_fx.ainvoke("test") 63 | assert (await es_client.count(index="test_index1"))["count"] == 1 64 | await fake_chat_fx.ainvoke("test2") 65 | assert (await es_client.count(index="test_index1"))["count"] == 2 66 | await fake_chat_fx.ainvoke("test2") 67 | records = [ 68 | record["_source"] 69 | for record in (await es_client.search(index="test_index1"))["hits"]["hits"] 70 | ] 71 | assert all("test output" in record.get("llm_output", [""])[0] for record in records) 72 | assert not all(record.get("llm_input", "") for record in records) 73 | assert all(record.get("timestamp", "") for record in records) 74 | assert not all(record.get("llm_params", "") for record in records) 75 | assert all(record.get("metadata") == {"project": "test"} for record in records) 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 80 | cache = AsyncElasticsearchCache( 81 | **es_env_fx, 
index_name="test_alias", metadata={"project": "test"} 82 | ) 83 | es_client = cache._es_client 84 | set_llm_cache(cache) 85 | await fake_chat_fx.ainvoke("test") 86 | assert (await es_client.count(index="test_index2"))["count"] == 1 87 | await fake_chat_fx.ainvoke("test2") 88 | assert (await es_client.count(index="test_index2"))["count"] == 2 89 | await es_client.indices.put_alias( 90 | index="test_index2", name="test_alias", is_write_index=False 91 | ) 92 | await es_client.indices.put_alias( 93 | index="test_index1", name="test_alias", is_write_index=True 94 | ) 95 | await fake_chat_fx.ainvoke("test3") 96 | assert (await es_client.count(index="test_index1"))["count"] == 1 97 | await fake_chat_fx.ainvoke("test2") 98 | assert (await es_client.count(index="test_index1"))["count"] == 1 99 | await es_client.indices.delete_alias(index="test_index2", name="test_alias") 100 | # we cache the response for prompt "test2" on both test_index1 and test_index2 101 | await fake_chat_fx.ainvoke("test2") 102 | assert (await es_client.count(index="test_index1"))["count"] == 2 103 | await es_client.indices.put_alias(index="test_index2", name="test_alias") 104 | # we just test the latter scenario is working 105 | assert await fake_chat_fx.ainvoke("test2") 106 | 107 | 108 | @pytest.mark.asyncio 109 | async def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 110 | cache = AsyncElasticsearchCache( 111 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 112 | ) 113 | es_client = cache._es_client 114 | set_llm_cache(cache) 115 | await fake_chat_fx.ainvoke("test") 116 | await fake_chat_fx.ainvoke("test2") 117 | await es_client.indices.put_alias( 118 | index="test_index2", name="test_alias", is_write_index=False 119 | ) 120 | await es_client.indices.put_alias( 121 | index="test_index1", name="test_alias", is_write_index=True 122 | ) 123 | await fake_chat_fx.ainvoke("test3") 124 | assert (await es_client.count(index="test_alias"))["count"] == 3 125 | await cache.aclear() 126 | assert (await es_client.count(index="test_alias"))["count"] == 0 127 | 128 | 129 | @pytest.mark.asyncio 130 | async def test_mdelete_cache_store(es_env_fx: Dict) -> None: 131 | store = AsyncElasticsearchEmbeddingsCache( 132 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 133 | ) 134 | 135 | recors = ["my little tests", "my little tests2", "my little tests3"] 136 | await store.amset( 137 | [ 138 | (recors[0], _value_serializer([1, 2, 3])), 139 | (recors[1], _value_serializer([1, 2, 3])), 140 | (recors[2], _value_serializer([1, 2, 3])), 141 | ] 142 | ) 143 | 144 | assert (await store._es_client.count(index="test_alias"))["count"] == 3 145 | 146 | await store.amdelete(recors[:2]) 147 | assert (await store._es_client.count(index="test_alias"))["count"] == 1 148 | 149 | await store.amdelete(recors[2:]) 150 | assert (await store._es_client.count(index="test_alias"))["count"] == 0 151 | 152 | with pytest.raises(BulkIndexError): 153 | await store.amdelete(recors) 154 | 155 | 156 | @pytest.mark.asyncio 157 | async def test_mset_cache_store(es_env_fx: Dict) -> None: 158 | store = AsyncElasticsearchEmbeddingsCache( 159 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 160 | ) 161 | 162 | records = ["my little tests", "my little tests2", "my little tests3"] 163 | 164 | await store.amset([(records[0], _value_serializer([1, 2, 3]))]) 165 | assert (await store._es_client.count(index="test_alias"))["count"] == 1 166 | await store.amset([(records[0], _value_serializer([1, 2, 
3]))]) 167 | assert (await store._es_client.count(index="test_alias"))["count"] == 1 168 | await store.amset( 169 | [ 170 | (records[1], _value_serializer([1, 2, 3])), 171 | (records[2], _value_serializer([1, 2, 3])), 172 | ] 173 | ) 174 | assert (await store._es_client.count(index="test_alias"))["count"] == 3 175 | 176 | 177 | @pytest.mark.asyncio 178 | async def test_mget_cache_store(es_env_fx: Dict) -> None: 179 | store_no_alias = AsyncElasticsearchEmbeddingsCache( 180 | **es_env_fx, 181 | index_name="test_index3", 182 | metadata={"project": "test"}, 183 | namespace="test", 184 | ) 185 | 186 | records = ["my little tests", "my little tests2", "my little tests3"] 187 | docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] 188 | 189 | await store_no_alias.amset(docs) 190 | assert (await store_no_alias._es_client.count(index="test_index3"))["count"] == 3 191 | 192 | cached_records = await store_no_alias.amget([d[0] for d in docs]) 193 | assert all(cached_records) 194 | assert all([r == d[1] for r, d in zip(cached_records, docs)]) 195 | 196 | store_alias = AsyncElasticsearchEmbeddingsCache( 197 | **es_env_fx, 198 | index_name="test_alias", 199 | metadata={"project": "test"}, 200 | namespace="test", 201 | maximum_duplicates_allowed=1, 202 | ) 203 | 204 | await store_alias.amset(docs) 205 | assert (await store_alias._es_client.count(index="test_alias"))["count"] == 3 206 | 207 | cached_records = await store_alias.amget([d[0] for d in docs]) 208 | assert all(cached_records) 209 | assert all([r == d[1] for r, d in zip(cached_records, docs)]) 210 | 211 | 212 | @pytest.mark.asyncio 213 | async def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: 214 | """verify the logic of deduplication of keys in the cache store""" 215 | 216 | store_alias = AsyncElasticsearchEmbeddingsCache( 217 | **es_env_fx, 218 | index_name="test_alias", 219 | metadata={"project": "test"}, 220 | namespace="test", 221 | maximum_duplicates_allowed=2, 222 | ) 223 | 224 | es_client = store_alias._es_client 225 | 226 | records = ["my little tests", "my little tests2", "my little tests3"] 227 | docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] 228 | 229 | await store_alias.amset(docs) 230 | assert (await es_client.count(index="test_alias"))["count"] == 3 231 | 232 | store_no_alias = AsyncElasticsearchEmbeddingsCache( 233 | **es_env_fx, 234 | index_name="test_index3", 235 | metadata={"project": "test"}, 236 | namespace="test", 237 | maximum_duplicates_allowed=1, 238 | ) 239 | 240 | new_records = records + ["my little tests4", "my little tests5"] 241 | new_docs = [ 242 | (r, _value_serializer([0.1, 2, i + 100])) for i, r in enumerate(new_records) 243 | ] 244 | 245 | # store the same 3 previous records and 2 more in a fresh index 246 | await store_no_alias.amset(new_docs) 247 | assert (await es_client.count(index="test_index3"))["count"] == 5 248 | 249 | # update the alias to point to the new index and verify the cache 250 | await es_client.indices.update_aliases( 251 | actions=[ 252 | { 253 | "add": { 254 | "index": "test_index3", 255 | "alias": "test_alias", 256 | } 257 | } 258 | ] 259 | ) 260 | 261 | # the alias now point to two indices that contains multiple records 262 | # of the same keys, the cache store should return the latest records. 
263 | cached_records = await store_alias.amget([d[0] for d in new_docs]) 264 | assert all(cached_records) 265 | assert len(cached_records) == 5 266 | assert (await es_client.count(index="test_alias"))["count"] == 8 267 | assert cached_records[:3] != [ 268 | d[1] for d in docs 269 | ], "the first 3 records should be updated" 270 | assert cached_records == [ 271 | d[1] for d in new_docs 272 | ], "new records should be returned and the updated ones" 273 | assert all([r == d[1] for r, d in zip(cached_records, new_docs)]) 274 | await es_client.options(ignore_status=404).indices.delete_alias( 275 | index="test_index3", name="test_alias" 276 | ) 277 | 278 | 279 | @pytest.mark.asyncio 280 | async def test_build_document_cache_store(es_env_fx: Dict) -> None: 281 | store = AsyncElasticsearchEmbeddingsCache( 282 | **es_env_fx, 283 | index_name="test_alias", 284 | metadata={"project": "test"}, 285 | namespace="test", 286 | ) 287 | 288 | await store.amset([("my little tests", _value_serializer([0.1, 2, 3]))]) 289 | record = (await store._es_client.search(index="test_alias"))["hits"]["hits"][0][ 290 | "_source" 291 | ] 292 | 293 | assert record.get("metadata") == {"project": "test"} 294 | assert record.get("namespace") == "test" 295 | assert record.get("timestamp") 296 | assert record.get("text_input") == "my little tests" 297 | assert record.get("vector_dump") == AsyncElasticsearchEmbeddingsCache.encode_vector( 298 | _value_serializer([0.1, 2, 3]) 299 | ) 300 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_async/test_chat_history.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import AsyncIterator 3 | 4 | import pytest 5 | from langchain.memory import ConversationBufferMemory 6 | from langchain_core.messages import AIMessage, HumanMessage, message_to_dict 7 | 8 | from langchain_elasticsearch.chat_history import AsyncElasticsearchChatMessageHistory 9 | 10 | from ._test_utilities import clear_test_indices, create_es_client, read_env 11 | 12 | """ 13 | cd tests/integration_tests 14 | docker-compose up elasticsearch 15 | 16 | By default runs against local docker instance of Elasticsearch. 
17 | To run against Elastic Cloud, set the following environment variables: 18 | - ES_CLOUD_ID 19 | - ES_USERNAME 20 | - ES_PASSWORD 21 | """ 22 | 23 | 24 | class TestElasticsearch: 25 | @pytest.fixture 26 | async def elasticsearch_connection(self) -> AsyncIterator[dict]: 27 | params = read_env() 28 | es = create_es_client(params) 29 | 30 | yield params 31 | 32 | await clear_test_indices(es) 33 | await es.close() 34 | 35 | @pytest.fixture(scope="function") 36 | def index_name(self) -> str: 37 | """Return the index name.""" 38 | return f"test_{uuid.uuid4().hex}" 39 | 40 | async def test_memory_with_message_store( 41 | self, elasticsearch_connection: dict, index_name: str 42 | ) -> None: 43 | """Test the memory with a message store.""" 44 | # setup Elasticsearch as a message store 45 | message_history = AsyncElasticsearchChatMessageHistory( 46 | **elasticsearch_connection, index=index_name, session_id="test-session" 47 | ) 48 | 49 | memory = ConversationBufferMemory( 50 | memory_key="baz", chat_memory=message_history, return_messages=True 51 | ) 52 | 53 | # add some messages 54 | await memory.chat_memory.aadd_messages( 55 | [ 56 | AIMessage("This is me, the AI (1)"), 57 | HumanMessage("This is me, the human (1)"), 58 | AIMessage("This is me, the AI (2)"), 59 | HumanMessage("This is me, the human (2)"), 60 | AIMessage("This is me, the AI (3)"), 61 | HumanMessage("This is me, the human (3)"), 62 | AIMessage("This is me, the AI (4)"), 63 | HumanMessage("This is me, the human (4)"), 64 | AIMessage("This is me, the AI (5)"), 65 | HumanMessage("This is me, the human (5)"), 66 | AIMessage("This is me, the AI (6)"), 67 | HumanMessage("This is me, the human (6)"), 68 | AIMessage("This is me, the AI (7)"), 69 | HumanMessage("This is me, the human (7)"), 70 | ] 71 | ) 72 | 73 | # get the message history from the memory store and turn it into a json 74 | messages = [ 75 | message_to_dict(msg) for msg in (await memory.chat_memory.aget_messages()) 76 | ] 77 | 78 | assert len(messages) == 14 79 | for i in range(7): 80 | assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})" 81 | assert ( 82 | messages[i * 2 + 1]["data"]["content"] 83 | == f"This is me, the human ({i+1})" 84 | ) 85 | 86 | # remove the record from Elasticsearch, so the next test run won't pick it up 87 | await memory.chat_memory.aclear() 88 | 89 | assert await memory.chat_memory.aget_messages() == [] 90 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_async/test_embeddings.py: -------------------------------------------------------------------------------- 1 | """Test elasticsearch_embeddings embeddings.""" 2 | 3 | import os 4 | 5 | import pytest 6 | from elasticsearch import AsyncElasticsearch 7 | 8 | from langchain_elasticsearch.embeddings import AsyncElasticsearchEmbeddings 9 | 10 | from ._test_utilities import model_is_deployed 11 | 12 | # deployed with 13 | # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html 14 | MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") 15 | NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) 16 | 17 | ES_URL = os.environ.get("ES_URL", "http://localhost:9200") 18 | 19 | 20 | @pytest.mark.asyncio 21 | async def test_elasticsearch_embedding_documents() -> None: 22 | """Test Elasticsearch embedding documents.""" 23 | client = AsyncElasticsearch(hosts=[ES_URL]) 24 | if not (await model_is_deployed(client, MODEL_ID)): 25 | await 
client.close() 26 | pytest.skip( 27 | reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" 28 | ) 29 | 30 | documents = ["foo bar", "bar foo", "foo"] 31 | embedding = AsyncElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) 32 | output = await embedding.aembed_documents(documents) 33 | await client.close() 34 | assert len(output) == 3 35 | assert len(output[0]) == NUM_DIMENSIONS 36 | assert len(output[1]) == NUM_DIMENSIONS 37 | assert len(output[2]) == NUM_DIMENSIONS 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_elasticsearch_embedding_query() -> None: 42 | """Test Elasticsearch embedding query.""" 43 | client = AsyncElasticsearch(hosts=[ES_URL]) 44 | if not (await model_is_deployed(client, MODEL_ID)): 45 | await client.close() 46 | pytest.skip( 47 | reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" 48 | ) 49 | 50 | document = "foo bar" 51 | embedding = AsyncElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) 52 | output = await embedding.aembed_query(document) 53 | await client.close() 54 | assert len(output) == NUM_DIMENSIONS 55 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py: -------------------------------------------------------------------------------- 1 | """Test ElasticsearchRetriever functionality.""" 2 | 3 | import os 4 | import re 5 | import uuid 6 | from typing import Any, Dict 7 | 8 | import pytest 9 | from elasticsearch import AsyncElasticsearch 10 | from langchain_core.documents import Document 11 | 12 | from langchain_elasticsearch.retrievers import AsyncElasticsearchRetriever 13 | 14 | from ._test_utilities import requests_saving_es_client 15 | 16 | """ 17 | cd tests/integration_tests 18 | docker-compose up elasticsearch 19 | 20 | By default runs against local docker instance of Elasticsearch. 21 | To run against Elastic Cloud, set the following environment variables: 22 | - ES_CLOUD_ID 23 | - ES_API_KEY 24 | """ 25 | 26 | 27 | async def index_test_data( 28 | es_client: AsyncElasticsearch, index_name: str, field_name: str 29 | ) -> None: 30 | docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")] 31 | for identifier, text in docs: 32 | await es_client.index( 33 | index=index_name, 34 | document={field_name: text, "another_field": 1}, 35 | id=str(identifier), 36 | refresh=True, 37 | ) 38 | 39 | 40 | class TestElasticsearchRetriever: 41 | @pytest.fixture(scope="function") 42 | async def es_client(self) -> Any: 43 | client = requests_saving_es_client() 44 | yield client 45 | await client.close() 46 | 47 | @pytest.fixture(scope="function") 48 | def index_name(self) -> str: 49 | """Return the index name.""" 50 | return f"test_{uuid.uuid4().hex}" 51 | 52 | @pytest.mark.asyncio 53 | async def test_user_agent_header( 54 | self, es_client: AsyncElasticsearch, index_name: str 55 | ) -> None: 56 | """Test that the user agent header is set correctly.""" 57 | 58 | retriever = AsyncElasticsearchRetriever( 59 | index_name=index_name, 60 | body_func=lambda _: {"query": {"match_all": {}}}, 61 | content_field="text", 62 | es_client=es_client, 63 | ) 64 | 65 | assert retriever.es_client 66 | user_agent = retriever.es_client._headers["User-Agent"] 67 | assert ( 68 | re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) 69 | is not None 70 | ), f"The string '{user_agent}' does not match the expected pattern." 
71 | 72 | await index_test_data(es_client, index_name, "text") 73 | await retriever.aget_relevant_documents("foo") 74 | 75 | search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] 76 | user_agent = search_request["headers"]["User-Agent"] 77 | assert ( 78 | re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) 79 | is not None 80 | ), f"The string '{user_agent}' does not match the expected pattern." 81 | 82 | @pytest.mark.asyncio 83 | async def test_init_url(self, index_name: str) -> None: 84 | """Test end-to-end indexing and search.""" 85 | 86 | text_field = "text" 87 | 88 | def body_func(query: str) -> Dict: 89 | return {"query": {"match": {text_field: {"query": query}}}} 90 | 91 | es_url = os.environ.get("ES_URL", "http://localhost:9200") 92 | cloud_id = os.environ.get("ES_CLOUD_ID") 93 | api_key = os.environ.get("ES_API_KEY") 94 | 95 | config = ( 96 | {"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url} 97 | ) 98 | 99 | retriever = AsyncElasticsearchRetriever.from_es_params( 100 | index_name=index_name, 101 | body_func=body_func, 102 | content_field=text_field, 103 | **config, # type: ignore[arg-type] 104 | ) 105 | 106 | await index_test_data(retriever.es_client, index_name, text_field) 107 | result = await retriever.aget_relevant_documents("foo") 108 | 109 | assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} 110 | assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} 111 | for r in result: 112 | assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} 113 | assert text_field not in r.metadata["_source"] 114 | assert "another_field" in r.metadata["_source"] 115 | 116 | @pytest.mark.asyncio 117 | async def test_init_client( 118 | self, es_client: AsyncElasticsearch, index_name: str 119 | ) -> None: 120 | """Test end-to-end indexing and search.""" 121 | 122 | text_field = "text" 123 | 124 | def body_func(query: str) -> Dict: 125 | return {"query": {"match": {text_field: {"query": query}}}} 126 | 127 | retriever = AsyncElasticsearchRetriever( 128 | index_name=index_name, 129 | body_func=body_func, 130 | content_field=text_field, 131 | es_client=es_client, 132 | ) 133 | 134 | await index_test_data(es_client, index_name, text_field) 135 | result = await retriever.aget_relevant_documents("foo") 136 | 137 | assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} 138 | assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} 139 | for r in result: 140 | assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} 141 | assert text_field not in r.metadata["_source"] 142 | assert "another_field" in r.metadata["_source"] 143 | 144 | @pytest.mark.asyncio 145 | async def test_multiple_index_and_content_fields( 146 | self, es_client: AsyncElasticsearch, index_name: str 147 | ) -> None: 148 | """Test multiple content fields""" 149 | index_name_1 = f"{index_name}_1" 150 | index_name_2 = f"{index_name}_2" 151 | text_field_1 = "text_1" 152 | text_field_2 = "text_2" 153 | 154 | def body_func(query: str) -> Dict: 155 | return { 156 | "query": { 157 | "multi_match": { 158 | "query": query, 159 | "fields": [text_field_1, text_field_2], 160 | } 161 | } 162 | } 163 | 164 | retriever = AsyncElasticsearchRetriever( 165 | index_name=[index_name_1, index_name_2], 166 | content_field={index_name_1: text_field_1, index_name_2: text_field_2}, 167 | body_func=body_func, 168 | es_client=es_client, 169 | ) 170 | 171 | await index_test_data(es_client, index_name_1, text_field_1) 172 | 
await index_test_data(es_client, index_name_2, text_field_2) 173 | result = await retriever.aget_relevant_documents("foo") 174 | 175 | # matches from both indices 176 | assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ 177 | ("foo", index_name_1), 178 | ("foo", index_name_2), 179 | ("foo bar", index_name_1), 180 | ("foo bar", index_name_2), 181 | ("foo baz", index_name_1), 182 | ("foo baz", index_name_2), 183 | ] 184 | 185 | @pytest.mark.asyncio 186 | async def test_custom_mapper( 187 | self, es_client: AsyncElasticsearch, index_name: str 188 | ) -> None: 189 | """Test custom document maper""" 190 | 191 | text_field = "text" 192 | meta = {"some_field": 12} 193 | 194 | def body_func(query: str) -> Dict: 195 | return {"query": {"match": {text_field: {"query": query}}}} 196 | 197 | def id_as_content(hit: Dict) -> Document: 198 | return Document(page_content=hit["_id"], metadata=meta) 199 | 200 | retriever = AsyncElasticsearchRetriever( 201 | index_name=index_name, 202 | body_func=body_func, 203 | document_mapper=id_as_content, 204 | es_client=es_client, 205 | ) 206 | 207 | await index_test_data(es_client, index_name, text_field) 208 | result = await retriever.aget_relevant_documents("foo") 209 | 210 | assert [r.page_content for r in result] == ["3", "1", "5"] 211 | assert [r.metadata for r in result] == [meta, meta, meta] 212 | 213 | @pytest.mark.asyncio 214 | async def test_fail_content_field_and_mapper( 215 | self, es_client: AsyncElasticsearch 216 | ) -> None: 217 | """Raise exception if both content_field and document_mapper are specified.""" 218 | 219 | with pytest.raises(ValueError): 220 | AsyncElasticsearchRetriever( 221 | content_field="text", 222 | document_mapper=lambda x: x, 223 | index_name="foo", 224 | body_func=lambda x: x, 225 | es_client=es_client, 226 | ) 227 | 228 | @pytest.mark.asyncio 229 | async def test_fail_neither_content_field_nor_mapper( 230 | self, es_client: AsyncElasticsearch 231 | ) -> None: 232 | """Raise exception if neither content_field nor document_mapper are specified""" 233 | 234 | with pytest.raises(ValueError): 235 | AsyncElasticsearchRetriever( 236 | index_name="foo", 237 | body_func=lambda x: x, 238 | es_client=es_client, 239 | ) 240 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/integration_tests/_sync/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/_test_utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional 3 | 4 | from elastic_transport import Transport 5 | from elasticsearch import ( 6 | BadRequestError, 7 | ConflictError, 8 | Elasticsearch, 9 | NotFoundError, 10 | ) 11 | 12 | 13 | def read_env() -> Dict: 14 | url = os.environ.get("ES_URL", "http://localhost:9200") 15 | cloud_id = os.environ.get("ES_CLOUD_ID") 16 | api_key = os.environ.get("ES_API_KEY") 17 | 18 | if cloud_id: 19 | return {"es_cloud_id": cloud_id, "es_api_key": api_key} 20 | return {"es_url": url} 21 | 22 | 23 | class RequestSavingTransport(Transport): 24 | def __init__(self, *args: Any, **kwargs: Any) -> None: 25 | super().__init__(*args, **kwargs) 26 | 
self.requests: List[Dict] = [] 27 | 28 | def perform_request(self, *args, **kwargs): # type: ignore 29 | self.requests.append(kwargs) 30 | return super().perform_request(*args, **kwargs) 31 | 32 | 33 | def create_es_client( 34 | es_params: Optional[Dict[str, str]] = None, 35 | es_kwargs: Dict = {}, 36 | ) -> Elasticsearch: 37 | if es_params is None: 38 | es_params = read_env() 39 | if not es_kwargs: 40 | es_kwargs = {} 41 | 42 | if "es_cloud_id" in es_params: 43 | return Elasticsearch( 44 | cloud_id=es_params["es_cloud_id"], 45 | api_key=es_params["es_api_key"], 46 | **es_kwargs, 47 | ) 48 | 49 | return Elasticsearch(hosts=[es_params["es_url"]], **es_kwargs) 50 | 51 | 52 | def requests_saving_es_client() -> Elasticsearch: 53 | return create_es_client(es_kwargs={"transport_class": RequestSavingTransport}) 54 | 55 | 56 | def clear_test_indices(es: Elasticsearch) -> None: 57 | index_names_response = es.indices.get(index="_all") 58 | index_names = index_names_response.keys() 59 | for index_name in index_names: 60 | if index_name.startswith("test_"): 61 | es.indices.delete(index=index_name) 62 | es.indices.refresh(index="_all") 63 | 64 | 65 | def model_is_deployed(client: Elasticsearch, model_id: str) -> bool: 66 | try: 67 | dummy = {"x": "y"} 68 | client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) 69 | return True 70 | except NotFoundError: 71 | return False 72 | except ConflictError: 73 | return False 74 | except BadRequestError: 75 | # This error is expected because we do not know the expected document 76 | # shape and just use a dummy doc above. 77 | return True 78 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/test_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Generator, Union 2 | 3 | import pytest 4 | from elasticsearch.helpers import BulkIndexError 5 | from langchain.embeddings.cache import _value_serializer 6 | from langchain.globals import set_llm_cache 7 | from langchain_core.language_models import BaseChatModel 8 | 9 | from langchain_elasticsearch import ( 10 | ElasticsearchCache, 11 | ElasticsearchEmbeddingsCache, 12 | ) 13 | 14 | from ._test_utilities import clear_test_indices, create_es_client, read_env 15 | 16 | 17 | @pytest.fixture 18 | def es_env_fx() -> Union[dict, Generator]: 19 | params = read_env() 20 | es = create_es_client(params) 21 | es.options(ignore_status=404).indices.delete(index="test_index1") 22 | es.options(ignore_status=404).indices.delete(index="test_index2") 23 | es.indices.create(index="test_index1") 24 | es.indices.create(index="test_index2") 25 | es.indices.put_alias(index="test_index1", name="test_alias") 26 | es.indices.put_alias(index="test_index2", name="test_alias", is_write_index=True) 27 | yield params 28 | es.options(ignore_status=404).indices.delete_alias( 29 | index="test_index1,test_index2", name="test_alias" 30 | ) 31 | clear_test_indices(es) 32 | es.close() 33 | 34 | 35 | @pytest.mark.sync 36 | def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 37 | cache = ElasticsearchCache( 38 | **es_env_fx, index_name="test_index1", metadata={"project": "test"} 39 | ) 40 | es_client = cache._es_client 41 | set_llm_cache(cache) 42 | fake_chat_fx.invoke("test") 43 | assert (es_client.count(index="test_index1"))["count"] == 1 44 | fake_chat_fx.invoke("test") 45 | assert (es_client.count(index="test_index1"))["count"] == 1 46 | record = 
(es_client.search(index="test_index1"))["hits"]["hits"][0]["_source"] 47 | assert "test output" in record.get("llm_output", [""])[0] 48 | assert record.get("llm_input") 49 | assert record.get("timestamp") 50 | assert record.get("llm_params") 51 | assert record.get("metadata") == {"project": "test"} 52 | cache2 = ElasticsearchCache( 53 | **es_env_fx, 54 | index_name="test_index1", 55 | metadata={"project": "test"}, 56 | store_input=False, 57 | store_input_params=False, 58 | ) 59 | set_llm_cache(cache2) 60 | fake_chat_fx.invoke("test") 61 | assert (es_client.count(index="test_index1"))["count"] == 1 62 | fake_chat_fx.invoke("test2") 63 | assert (es_client.count(index="test_index1"))["count"] == 2 64 | fake_chat_fx.invoke("test2") 65 | records = [ 66 | record["_source"] 67 | for record in (es_client.search(index="test_index1"))["hits"]["hits"] 68 | ] 69 | assert all("test output" in record.get("llm_output", [""])[0] for record in records) 70 | assert not all(record.get("llm_input", "") for record in records) 71 | assert all(record.get("timestamp", "") for record in records) 72 | assert not all(record.get("llm_params", "") for record in records) 73 | assert all(record.get("metadata") == {"project": "test"} for record in records) 74 | 75 | 76 | @pytest.mark.sync 77 | def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 78 | cache = ElasticsearchCache( 79 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 80 | ) 81 | es_client = cache._es_client 82 | set_llm_cache(cache) 83 | fake_chat_fx.invoke("test") 84 | assert (es_client.count(index="test_index2"))["count"] == 1 85 | fake_chat_fx.invoke("test2") 86 | assert (es_client.count(index="test_index2"))["count"] == 2 87 | es_client.indices.put_alias( 88 | index="test_index2", name="test_alias", is_write_index=False 89 | ) 90 | es_client.indices.put_alias( 91 | index="test_index1", name="test_alias", is_write_index=True 92 | ) 93 | fake_chat_fx.invoke("test3") 94 | assert (es_client.count(index="test_index1"))["count"] == 1 95 | fake_chat_fx.invoke("test2") 96 | assert (es_client.count(index="test_index1"))["count"] == 1 97 | es_client.indices.delete_alias(index="test_index2", name="test_alias") 98 | # we cache the response for prompt "test2" on both test_index1 and test_index2 99 | fake_chat_fx.invoke("test2") 100 | assert (es_client.count(index="test_index1"))["count"] == 2 101 | es_client.indices.put_alias(index="test_index2", name="test_alias") 102 | # we just test the latter scenario is working 103 | assert fake_chat_fx.invoke("test2") 104 | 105 | 106 | @pytest.mark.sync 107 | def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: 108 | cache = ElasticsearchCache( 109 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 110 | ) 111 | es_client = cache._es_client 112 | set_llm_cache(cache) 113 | fake_chat_fx.invoke("test") 114 | fake_chat_fx.invoke("test2") 115 | es_client.indices.put_alias( 116 | index="test_index2", name="test_alias", is_write_index=False 117 | ) 118 | es_client.indices.put_alias( 119 | index="test_index1", name="test_alias", is_write_index=True 120 | ) 121 | fake_chat_fx.invoke("test3") 122 | assert (es_client.count(index="test_alias"))["count"] == 3 123 | cache.clear() 124 | assert (es_client.count(index="test_alias"))["count"] == 0 125 | 126 | 127 | @pytest.mark.sync 128 | def test_mdelete_cache_store(es_env_fx: Dict) -> None: 129 | store = ElasticsearchEmbeddingsCache( 130 | **es_env_fx, index_name="test_alias", metadata={"project": 
"test"} 131 | ) 132 | 133 | recors = ["my little tests", "my little tests2", "my little tests3"] 134 | store.mset( 135 | [ 136 | (recors[0], _value_serializer([1, 2, 3])), 137 | (recors[1], _value_serializer([1, 2, 3])), 138 | (recors[2], _value_serializer([1, 2, 3])), 139 | ] 140 | ) 141 | 142 | assert (store._es_client.count(index="test_alias"))["count"] == 3 143 | 144 | store.mdelete(recors[:2]) 145 | assert (store._es_client.count(index="test_alias"))["count"] == 1 146 | 147 | store.mdelete(recors[2:]) 148 | assert (store._es_client.count(index="test_alias"))["count"] == 0 149 | 150 | with pytest.raises(BulkIndexError): 151 | store.mdelete(recors) 152 | 153 | 154 | @pytest.mark.sync 155 | def test_mset_cache_store(es_env_fx: Dict) -> None: 156 | store = ElasticsearchEmbeddingsCache( 157 | **es_env_fx, index_name="test_alias", metadata={"project": "test"} 158 | ) 159 | 160 | records = ["my little tests", "my little tests2", "my little tests3"] 161 | 162 | store.mset([(records[0], _value_serializer([1, 2, 3]))]) 163 | assert (store._es_client.count(index="test_alias"))["count"] == 1 164 | store.mset([(records[0], _value_serializer([1, 2, 3]))]) 165 | assert (store._es_client.count(index="test_alias"))["count"] == 1 166 | store.mset( 167 | [ 168 | (records[1], _value_serializer([1, 2, 3])), 169 | (records[2], _value_serializer([1, 2, 3])), 170 | ] 171 | ) 172 | assert (store._es_client.count(index="test_alias"))["count"] == 3 173 | 174 | 175 | @pytest.mark.sync 176 | def test_mget_cache_store(es_env_fx: Dict) -> None: 177 | store_no_alias = ElasticsearchEmbeddingsCache( 178 | **es_env_fx, 179 | index_name="test_index3", 180 | metadata={"project": "test"}, 181 | namespace="test", 182 | ) 183 | 184 | records = ["my little tests", "my little tests2", "my little tests3"] 185 | docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] 186 | 187 | store_no_alias.mset(docs) 188 | assert (store_no_alias._es_client.count(index="test_index3"))["count"] == 3 189 | 190 | cached_records = store_no_alias.mget([d[0] for d in docs]) 191 | assert all(cached_records) 192 | assert all([r == d[1] for r, d in zip(cached_records, docs)]) 193 | 194 | store_alias = ElasticsearchEmbeddingsCache( 195 | **es_env_fx, 196 | index_name="test_alias", 197 | metadata={"project": "test"}, 198 | namespace="test", 199 | maximum_duplicates_allowed=1, 200 | ) 201 | 202 | store_alias.mset(docs) 203 | assert (store_alias._es_client.count(index="test_alias"))["count"] == 3 204 | 205 | cached_records = store_alias.mget([d[0] for d in docs]) 206 | assert all(cached_records) 207 | assert all([r == d[1] for r, d in zip(cached_records, docs)]) 208 | 209 | 210 | @pytest.mark.sync 211 | def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: 212 | """verify the logic of deduplication of keys in the cache store""" 213 | 214 | store_alias = ElasticsearchEmbeddingsCache( 215 | **es_env_fx, 216 | index_name="test_alias", 217 | metadata={"project": "test"}, 218 | namespace="test", 219 | maximum_duplicates_allowed=2, 220 | ) 221 | 222 | es_client = store_alias._es_client 223 | 224 | records = ["my little tests", "my little tests2", "my little tests3"] 225 | docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] 226 | 227 | store_alias.mset(docs) 228 | assert (es_client.count(index="test_alias"))["count"] == 3 229 | 230 | store_no_alias = ElasticsearchEmbeddingsCache( 231 | **es_env_fx, 232 | index_name="test_index3", 233 | metadata={"project": "test"}, 234 | namespace="test", 235 | 
maximum_duplicates_allowed=1, 236 | ) 237 | 238 | new_records = records + ["my little tests4", "my little tests5"] 239 | new_docs = [ 240 | (r, _value_serializer([0.1, 2, i + 100])) for i, r in enumerate(new_records) 241 | ] 242 | 243 | # store the same 3 previous records and 2 more in a fresh index 244 | store_no_alias.mset(new_docs) 245 | assert (es_client.count(index="test_index3"))["count"] == 5 246 | 247 | # update the alias to point to the new index and verify the cache 248 | es_client.indices.update_aliases( 249 | actions=[ 250 | { 251 | "add": { 252 | "index": "test_index3", 253 | "alias": "test_alias", 254 | } 255 | } 256 | ] 257 | ) 258 | 259 | # the alias now point to two indices that contains multiple records 260 | # of the same keys, the cache store should return the latest records. 261 | cached_records = store_alias.mget([d[0] for d in new_docs]) 262 | assert all(cached_records) 263 | assert len(cached_records) == 5 264 | assert (es_client.count(index="test_alias"))["count"] == 8 265 | assert cached_records[:3] != [ 266 | d[1] for d in docs 267 | ], "the first 3 records should be updated" 268 | assert cached_records == [ 269 | d[1] for d in new_docs 270 | ], "new records should be returned and the updated ones" 271 | assert all([r == d[1] for r, d in zip(cached_records, new_docs)]) 272 | es_client.options(ignore_status=404).indices.delete_alias( 273 | index="test_index3", name="test_alias" 274 | ) 275 | 276 | 277 | @pytest.mark.sync 278 | def test_build_document_cache_store(es_env_fx: Dict) -> None: 279 | store = ElasticsearchEmbeddingsCache( 280 | **es_env_fx, 281 | index_name="test_alias", 282 | metadata={"project": "test"}, 283 | namespace="test", 284 | ) 285 | 286 | store.mset([("my little tests", _value_serializer([0.1, 2, 3]))]) 287 | record = (store._es_client.search(index="test_alias"))["hits"]["hits"][0]["_source"] 288 | 289 | assert record.get("metadata") == {"project": "test"} 290 | assert record.get("namespace") == "test" 291 | assert record.get("timestamp") 292 | assert record.get("text_input") == "my little tests" 293 | assert record.get("vector_dump") == ElasticsearchEmbeddingsCache.encode_vector( 294 | _value_serializer([0.1, 2, 3]) 295 | ) 296 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/test_chat_history.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Iterator 3 | 4 | import pytest 5 | from langchain.memory import ConversationBufferMemory 6 | from langchain_core.messages import AIMessage, HumanMessage, message_to_dict 7 | 8 | from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory 9 | 10 | from ._test_utilities import clear_test_indices, create_es_client, read_env 11 | 12 | """ 13 | cd tests/integration_tests 14 | docker-compose up elasticsearch 15 | 16 | By default runs against local docker instance of Elasticsearch. 
17 | To run against Elastic Cloud, set the following environment variables: 18 | - ES_CLOUD_ID 19 | - ES_USERNAME 20 | - ES_PASSWORD 21 | """ 22 | 23 | 24 | class TestElasticsearch: 25 | @pytest.fixture 26 | def elasticsearch_connection(self) -> Iterator[dict]: 27 | params = read_env() 28 | es = create_es_client(params) 29 | 30 | yield params 31 | 32 | clear_test_indices(es) 33 | es.close() 34 | 35 | @pytest.fixture(scope="function") 36 | def index_name(self) -> str: 37 | """Return the index name.""" 38 | return f"test_{uuid.uuid4().hex}" 39 | 40 | def test_memory_with_message_store( 41 | self, elasticsearch_connection: dict, index_name: str 42 | ) -> None: 43 | """Test the memory with a message store.""" 44 | # setup Elasticsearch as a message store 45 | message_history = ElasticsearchChatMessageHistory( 46 | **elasticsearch_connection, index=index_name, session_id="test-session" 47 | ) 48 | 49 | memory = ConversationBufferMemory( 50 | memory_key="baz", chat_memory=message_history, return_messages=True 51 | ) 52 | 53 | # add some messages 54 | memory.chat_memory.add_messages( 55 | [ 56 | AIMessage("This is me, the AI (1)"), 57 | HumanMessage("This is me, the human (1)"), 58 | AIMessage("This is me, the AI (2)"), 59 | HumanMessage("This is me, the human (2)"), 60 | AIMessage("This is me, the AI (3)"), 61 | HumanMessage("This is me, the human (3)"), 62 | AIMessage("This is me, the AI (4)"), 63 | HumanMessage("This is me, the human (4)"), 64 | AIMessage("This is me, the AI (5)"), 65 | HumanMessage("This is me, the human (5)"), 66 | AIMessage("This is me, the AI (6)"), 67 | HumanMessage("This is me, the human (6)"), 68 | AIMessage("This is me, the AI (7)"), 69 | HumanMessage("This is me, the human (7)"), 70 | ] 71 | ) 72 | 73 | # get the message history from the memory store and turn it into a json 74 | messages = [message_to_dict(msg) for msg in (memory.chat_memory.messages)] 75 | 76 | assert len(messages) == 14 77 | for i in range(7): 78 | assert messages[i * 2]["data"]["content"] == f"This is me, the AI ({i+1})" 79 | assert ( 80 | messages[i * 2 + 1]["data"]["content"] 81 | == f"This is me, the human ({i+1})" 82 | ) 83 | 84 | # remove the record from Elasticsearch, so the next test run won't pick it up 85 | memory.chat_memory.clear() 86 | 87 | assert memory.chat_memory.messages == [] 88 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/test_embeddings.py: -------------------------------------------------------------------------------- 1 | """Test elasticsearch_embeddings embeddings.""" 2 | 3 | import os 4 | 5 | import pytest 6 | from elasticsearch import Elasticsearch 7 | 8 | from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings 9 | 10 | from ._test_utilities import model_is_deployed 11 | 12 | # deployed with 13 | # https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html 14 | MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3") 15 | NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384")) 16 | 17 | ES_URL = os.environ.get("ES_URL", "http://localhost:9200") 18 | 19 | 20 | @pytest.mark.sync 21 | def test_elasticsearch_embedding_documents() -> None: 22 | """Test Elasticsearch embedding documents.""" 23 | client = Elasticsearch(hosts=[ES_URL]) 24 | if not (model_is_deployed(client, MODEL_ID)): 25 | client.close() 26 | pytest.skip( 27 | reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" 28 | ) 29 | 30 | 
documents = ["foo bar", "bar foo", "foo"] 31 | embedding = ElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) 32 | output = embedding.embed_documents(documents) 33 | client.close() 34 | assert len(output) == 3 35 | assert len(output[0]) == NUM_DIMENSIONS 36 | assert len(output[1]) == NUM_DIMENSIONS 37 | assert len(output[2]) == NUM_DIMENSIONS 38 | 39 | 40 | @pytest.mark.sync 41 | def test_elasticsearch_embedding_query() -> None: 42 | """Test Elasticsearch embedding query.""" 43 | client = Elasticsearch(hosts=[ES_URL]) 44 | if not (model_is_deployed(client, MODEL_ID)): 45 | client.close() 46 | pytest.skip( 47 | reason=f"{MODEL_ID} model is not deployed in ML Node, skipping test" 48 | ) 49 | 50 | document = "foo bar" 51 | embedding = ElasticsearchEmbeddings.from_es_connection(MODEL_ID, client) 52 | output = embedding.embed_query(document) 53 | client.close() 54 | assert len(output) == NUM_DIMENSIONS 55 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py: -------------------------------------------------------------------------------- 1 | """Test ElasticsearchRetriever functionality.""" 2 | 3 | import os 4 | import re 5 | import uuid 6 | from typing import Any, Dict 7 | 8 | import pytest 9 | from elasticsearch import Elasticsearch 10 | from langchain_core.documents import Document 11 | 12 | from langchain_elasticsearch.retrievers import ElasticsearchRetriever 13 | 14 | from ._test_utilities import requests_saving_es_client 15 | 16 | """ 17 | cd tests/integration_tests 18 | docker-compose up elasticsearch 19 | 20 | By default runs against local docker instance of Elasticsearch. 21 | To run against Elastic Cloud, set the following environment variables: 22 | - ES_CLOUD_ID 23 | - ES_API_KEY 24 | """ 25 | 26 | 27 | def index_test_data(es_client: Elasticsearch, index_name: str, field_name: str) -> None: 28 | docs = [(1, "foo bar"), (2, "bar"), (3, "foo"), (4, "baz"), (5, "foo baz")] 29 | for identifier, text in docs: 30 | es_client.index( 31 | index=index_name, 32 | document={field_name: text, "another_field": 1}, 33 | id=str(identifier), 34 | refresh=True, 35 | ) 36 | 37 | 38 | class TestElasticsearchRetriever: 39 | @pytest.fixture(scope="function") 40 | def es_client(self) -> Any: 41 | client = requests_saving_es_client() 42 | yield client 43 | client.close() 44 | 45 | @pytest.fixture(scope="function") 46 | def index_name(self) -> str: 47 | """Return the index name.""" 48 | return f"test_{uuid.uuid4().hex}" 49 | 50 | @pytest.mark.sync 51 | def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> None: 52 | """Test that the user agent header is set correctly.""" 53 | 54 | retriever = ElasticsearchRetriever( 55 | index_name=index_name, 56 | body_func=lambda _: {"query": {"match_all": {}}}, 57 | content_field="text", 58 | es_client=es_client, 59 | ) 60 | 61 | assert retriever.es_client 62 | user_agent = retriever.es_client._headers["User-Agent"] 63 | assert ( 64 | re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) 65 | is not None 66 | ), f"The string '{user_agent}' does not match the expected pattern." 
67 | 68 | index_test_data(es_client, index_name, "text") 69 | retriever.get_relevant_documents("foo") 70 | 71 | search_request = es_client.transport.requests[-1] # type: ignore[attr-defined] 72 | user_agent = search_request["headers"]["User-Agent"] 73 | assert ( 74 | re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) 75 | is not None 76 | ), f"The string '{user_agent}' does not match the expected pattern." 77 | 78 | @pytest.mark.sync 79 | def test_init_url(self, index_name: str) -> None: 80 | """Test end-to-end indexing and search.""" 81 | 82 | text_field = "text" 83 | 84 | def body_func(query: str) -> Dict: 85 | return {"query": {"match": {text_field: {"query": query}}}} 86 | 87 | es_url = os.environ.get("ES_URL", "http://localhost:9200") 88 | cloud_id = os.environ.get("ES_CLOUD_ID") 89 | api_key = os.environ.get("ES_API_KEY") 90 | 91 | config = ( 92 | {"cloud_id": cloud_id, "api_key": api_key} if cloud_id else {"url": es_url} 93 | ) 94 | 95 | retriever = ElasticsearchRetriever.from_es_params( 96 | index_name=index_name, 97 | body_func=body_func, 98 | content_field=text_field, 99 | **config, # type: ignore[arg-type] 100 | ) 101 | 102 | index_test_data(retriever.es_client, index_name, text_field) 103 | result = retriever.get_relevant_documents("foo") 104 | 105 | assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} 106 | assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} 107 | for r in result: 108 | assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} 109 | assert text_field not in r.metadata["_source"] 110 | assert "another_field" in r.metadata["_source"] 111 | 112 | @pytest.mark.sync 113 | def test_init_client(self, es_client: Elasticsearch, index_name: str) -> None: 114 | """Test end-to-end indexing and search.""" 115 | 116 | text_field = "text" 117 | 118 | def body_func(query: str) -> Dict: 119 | return {"query": {"match": {text_field: {"query": query}}}} 120 | 121 | retriever = ElasticsearchRetriever( 122 | index_name=index_name, 123 | body_func=body_func, 124 | content_field=text_field, 125 | es_client=es_client, 126 | ) 127 | 128 | index_test_data(es_client, index_name, text_field) 129 | result = retriever.get_relevant_documents("foo") 130 | 131 | assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} 132 | assert {r.metadata["_id"] for r in result} == {"3", "1", "5"} 133 | for r in result: 134 | assert set(r.metadata.keys()) == {"_index", "_id", "_score", "_source"} 135 | assert text_field not in r.metadata["_source"] 136 | assert "another_field" in r.metadata["_source"] 137 | 138 | @pytest.mark.sync 139 | def test_multiple_index_and_content_fields( 140 | self, es_client: Elasticsearch, index_name: str 141 | ) -> None: 142 | """Test multiple content fields""" 143 | index_name_1 = f"{index_name}_1" 144 | index_name_2 = f"{index_name}_2" 145 | text_field_1 = "text_1" 146 | text_field_2 = "text_2" 147 | 148 | def body_func(query: str) -> Dict: 149 | return { 150 | "query": { 151 | "multi_match": { 152 | "query": query, 153 | "fields": [text_field_1, text_field_2], 154 | } 155 | } 156 | } 157 | 158 | retriever = ElasticsearchRetriever( 159 | index_name=[index_name_1, index_name_2], 160 | content_field={index_name_1: text_field_1, index_name_2: text_field_2}, 161 | body_func=body_func, 162 | es_client=es_client, 163 | ) 164 | 165 | index_test_data(es_client, index_name_1, text_field_1) 166 | index_test_data(es_client, index_name_2, text_field_2) 167 | result = retriever.get_relevant_documents("foo") 168 | 
169 | # matches from both indices 170 | assert sorted([(r.page_content, r.metadata["_index"]) for r in result]) == [ 171 | ("foo", index_name_1), 172 | ("foo", index_name_2), 173 | ("foo bar", index_name_1), 174 | ("foo bar", index_name_2), 175 | ("foo baz", index_name_1), 176 | ("foo baz", index_name_2), 177 | ] 178 | 179 | @pytest.mark.sync 180 | def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None: 181 | """Test custom document mapper""" 182 | 183 | text_field = "text" 184 | meta = {"some_field": 12} 185 | 186 | def body_func(query: str) -> Dict: 187 | return {"query": {"match": {text_field: {"query": query}}}} 188 | 189 | def id_as_content(hit: Dict) -> Document: 190 | return Document(page_content=hit["_id"], metadata=meta) 191 | 192 | retriever = ElasticsearchRetriever( 193 | index_name=index_name, 194 | body_func=body_func, 195 | document_mapper=id_as_content, 196 | es_client=es_client, 197 | ) 198 | 199 | index_test_data(es_client, index_name, text_field) 200 | result = retriever.get_relevant_documents("foo") 201 | 202 | assert [r.page_content for r in result] == ["3", "1", "5"] 203 | assert [r.metadata for r in result] == [meta, meta, meta] 204 | 205 | @pytest.mark.sync 206 | def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None: 207 | """Raise exception if both content_field and document_mapper are specified.""" 208 | 209 | with pytest.raises(ValueError): 210 | ElasticsearchRetriever( 211 | content_field="text", 212 | document_mapper=lambda x: x, 213 | index_name="foo", 214 | body_func=lambda x: x, 215 | es_client=es_client, 216 | ) 217 | 218 | @pytest.mark.sync 219 | def test_fail_neither_content_field_nor_mapper( 220 | self, es_client: Elasticsearch 221 | ) -> None: 222 | """Raise exception if neither content_field nor document_mapper is specified.""" 223 | 224 | with pytest.raises(ValueError): 225 | ElasticsearchRetriever( 226 | index_name="foo", 227 | body_func=lambda x: x, 228 | es_client=es_client, 229 | ) 230 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/integration_tests/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | elasticsearch: 5 | image: elasticsearch:8.13.0 6 | environment: 7 | - discovery.type=single-node 8 | - xpack.license.self_generated.type=trial 9 | - xpack.security.enabled=false # disable password and TLS; never do this in production! 
10 | ports: 11 | - "9200:9200" 12 | healthcheck: 13 | test: 14 | [ 15 | "CMD-SHELL", 16 | "curl --silent --fail http://localhost:9200/_cluster/health || exit 1" 17 | ] 18 | interval: 10s 19 | retries: 60 20 | 21 | kibana: 22 | image: kibana:8.13.0 23 | environment: 24 | - ELASTICSEARCH_URL=http://elasticsearch:9200 25 | ports: 26 | - "5601:5601" 27 | healthcheck: 28 | test: 29 | [ 30 | "CMD-SHELL", 31 | "curl --silent --fail http://localhost:5601/login || exit 1" 32 | ] 33 | interval: 10s 34 | retries: 60 35 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/_async/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/unit_tests/_async/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/_async/test_cache.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Any, Dict 3 | from unittest import mock 4 | from unittest.mock import ANY, MagicMock, patch 5 | 6 | import pytest 7 | from _pytest.fixtures import FixtureRequest 8 | from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig 9 | from elasticsearch import NotFoundError 10 | from langchain.embeddings.cache import _value_serializer 11 | from langchain_core.load import dumps 12 | from langchain_core.outputs import Generation 13 | 14 | from langchain_elasticsearch import ( 15 | AsyncElasticsearchCache, 16 | AsyncElasticsearchEmbeddingsCache, 17 | ) 18 | 19 | 20 | def serialize_encode_vector(vector: Any) -> str: 21 | return AsyncElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) 22 | 23 | 24 | @pytest.mark.asyncio 25 | async def test_initialization_llm_cache(async_es_client_fx: MagicMock) -> None: 26 | async_es_client_fx.ping.return_value = True 27 | async_es_client_fx.indices.exists_alias.return_value = True 28 | with mock.patch( 29 | "langchain_elasticsearch._sync.cache.create_elasticsearch_client", 30 | return_value=async_es_client_fx, 31 | ): 32 | with mock.patch( 33 | "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", 34 | return_value=async_es_client_fx, 35 | ): 36 | cache = AsyncElasticsearchCache( 37 | es_url="http://localhost:9200", index_name="test_index" 38 | ) 39 | assert await cache.is_alias() 40 | async_es_client_fx.indices.exists_alias.assert_awaited_with( 41 | name="test_index" 42 | ) 43 | async_es_client_fx.indices.put_mapping.assert_awaited_with( 44 | index="test_index", body=cache.mapping["mappings"] 45 | ) 46 | async_es_client_fx.indices.exists_alias.return_value = False 47 | async_es_client_fx.indices.exists.return_value = False 48 | cache = AsyncElasticsearchCache( 49 | es_url="http://localhost:9200", index_name="test_index" 50 | ) 51 | assert not (await cache.is_alias()) 52 | async_es_client_fx.indices.create.assert_awaited_with( 53 | index="test_index", body=cache.mapping 54 | ) 55 | 56 | 57 | def 
test_mapping_llm_cache( 58 | async_es_cache_fx: AsyncElasticsearchCache, request: FixtureRequest 59 | ) -> None: 60 | mapping = request.getfixturevalue("es_cache_fx").mapping 61 | assert mapping.get("mappings") 62 | assert mapping["mappings"].get("properties") 63 | 64 | 65 | def test_key_generation_llm_cache(es_cache_fx: AsyncElasticsearchCache) -> None: 66 | key1 = es_cache_fx._key("test_prompt", "test_llm_string") 67 | assert key1 and isinstance(key1, str) 68 | key2 = es_cache_fx._key("test_prompt", "test_llm_string1") 69 | assert key2 and key1 != key2 70 | key3 = es_cache_fx._key("test_prompt1", "test_llm_string") 71 | assert key3 and key1 != key3 72 | 73 | 74 | def test_clear_llm_cache( 75 | es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache 76 | ) -> None: 77 | es_cache_fx.clear() 78 | es_client_fx.delete_by_query.assert_called_once_with( 79 | index="test_index", 80 | body={"query": {"match_all": {}}}, 81 | refresh=True, 82 | wait_for_completion=True, 83 | ) 84 | 85 | 86 | def test_build_document_llm_cache(es_cache_fx: AsyncElasticsearchCache) -> None: 87 | doc = es_cache_fx.build_document( 88 | "test_prompt", "test_llm_string", [Generation(text="test_prompt")] 89 | ) 90 | assert doc["llm_input"] == "test_prompt" 91 | assert doc["llm_params"] == "test_llm_string" 92 | assert isinstance(doc["llm_output"], list) 93 | assert all(isinstance(gen, str) for gen in doc["llm_output"]) 94 | assert datetime.fromisoformat(str(doc["timestamp"])) 95 | assert doc["metadata"] == es_cache_fx._metadata 96 | 97 | 98 | def test_update_llm_cache( 99 | es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache 100 | ) -> None: 101 | es_cache_fx.update("test_prompt", "test_llm_string", [Generation(text="test")]) 102 | timestamp = es_client_fx.index.call_args.kwargs["body"]["timestamp"] 103 | doc = es_cache_fx.build_document( 104 | "test_prompt", "test_llm_string", [Generation(text="test")] 105 | ) 106 | doc["timestamp"] = timestamp 107 | es_client_fx.index.assert_called_once_with( 108 | index=es_cache_fx._index_name, 109 | id=es_cache_fx._key("test_prompt", "test_llm_string"), 110 | body=doc, 111 | require_alias=es_cache_fx._is_alias, 112 | refresh=True, 113 | ) 114 | 115 | 116 | def test_lookup_llm_cache( 117 | es_client_fx: MagicMock, es_cache_fx: AsyncElasticsearchCache 118 | ) -> None: 119 | cache_key = es_cache_fx._key("test_prompt", "test_llm_string") 120 | doc: Dict[str, Any] = { 121 | "_source": { 122 | "llm_output": [dumps(Generation(text="test"))], 123 | "timestamp": "2024-03-07T13:25:36.410756", 124 | } 125 | } 126 | es_cache_fx._is_alias = False 127 | es_client_fx.get.side_effect = NotFoundError( 128 | "not found", 129 | ApiResponseMeta(404, "0", HttpHeaders(), 0, NodeConfig("http", "xxx", 80)), 130 | "", 131 | ) 132 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None 133 | es_client_fx.get.assert_called_once_with( 134 | index="test_index", id=cache_key, source=["llm_output"] 135 | ) 136 | es_client_fx.get.side_effect = None 137 | es_client_fx.get.return_value = doc 138 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ 139 | Generation(text="test") 140 | ] 141 | es_cache_fx._is_alias = True 142 | es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} 143 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None 144 | es_client_fx.search.assert_called_once_with( 145 | index="test_index", 146 | body={ 147 | "query": {"term": {"_id": cache_key}}, 148 | "sort": {"timestamp": {"order": "asc"}}, 149 | }, 150 | 
source_includes=["llm_output"], 151 | ) 152 | doc2 = { 153 | "_source": { 154 | "llm_output": [dumps(Generation(text="test2"))], 155 | "timestamp": "2024-03-08T13:25:36.410756", 156 | }, 157 | } 158 | es_client_fx.search.return_value = { 159 | "hits": {"total": {"value": 2}, "hits": [doc2, doc]} 160 | } 161 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ 162 | Generation(text="test2") 163 | ] 164 | 165 | 166 | def test_key_generation_cache_store( 167 | es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, 168 | ) -> None: 169 | key1 = es_embeddings_cache_fx._key("test_text") 170 | assert key1 and isinstance(key1, str) 171 | key2 = es_embeddings_cache_fx._key("test_text2") 172 | assert key2 and key1 != key2 173 | es_embeddings_cache_fx._namespace = "other" 174 | key3 = es_embeddings_cache_fx._key("test_text") 175 | assert key3 and key1 != key3 176 | es_embeddings_cache_fx._namespace = None 177 | key4 = es_embeddings_cache_fx._key("test_text") 178 | assert key4 and key1 != key4 and key3 != key4 179 | 180 | 181 | def test_build_document_cache_store( 182 | es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, 183 | ) -> None: 184 | doc = es_embeddings_cache_fx.build_document( 185 | "test_text", _value_serializer([1.5, 2, 3.6]) 186 | ) 187 | assert doc["text_input"] == "test_text" 188 | assert doc["vector_dump"] == serialize_encode_vector([1.5, 2, 3.6]) 189 | assert datetime.fromisoformat(str(doc["timestamp"])) 190 | assert doc["metadata"] == es_embeddings_cache_fx._metadata 191 | 192 | 193 | def test_mget_cache_store( 194 | es_client_fx: MagicMock, es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache 195 | ) -> None: 196 | cache_keys = [ 197 | es_embeddings_cache_fx._key("test_text1"), 198 | es_embeddings_cache_fx._key("test_text2"), 199 | es_embeddings_cache_fx._key("test_text3"), 200 | ] 201 | docs = { 202 | "docs": [ 203 | {"_index": "test_index", "_id": cache_keys[0], "found": False}, 204 | { 205 | "_index": "test_index", 206 | "_id": cache_keys[1], 207 | "found": True, 208 | "_source": {"vector_dump": serialize_encode_vector([1.5, 2, 3.6])}, 209 | }, 210 | { 211 | "_index": "test_index", 212 | "_id": cache_keys[2], 213 | "found": True, 214 | "_source": {"vector_dump": serialize_encode_vector([5, 6, 7.1])}, 215 | }, 216 | ] 217 | } 218 | es_embeddings_cache_fx._is_alias = False 219 | es_client_fx.mget.return_value = docs 220 | assert es_embeddings_cache_fx.mget([]) == [] 221 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 222 | None, 223 | _value_serializer([1.5, 2, 3.6]), 224 | _value_serializer([5, 6, 7.1]), 225 | ] 226 | es_client_fx.mget.assert_called_with( 227 | index="test_index", ids=cache_keys, source_includes=["vector_dump"] 228 | ) 229 | es_embeddings_cache_fx._is_alias = True 230 | es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} 231 | assert es_embeddings_cache_fx.mget([]) == [] 232 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 233 | None, 234 | None, 235 | None, 236 | ] 237 | es_client_fx.search.assert_called_with( 238 | index="test_index", 239 | body={ 240 | "query": {"ids": {"values": cache_keys}}, 241 | "size": 3, 242 | }, 243 | source_includes=["vector_dump", "timestamp"], 244 | ) 245 | resp = { 246 | "hits": {"total": {"value": 3}, "hits": [d for d in docs["docs"] if d["found"]]} 247 | } 248 | es_client_fx.search.return_value = resp 249 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 250 | 
None, 251 | _value_serializer([1.5, 2, 3.6]), 252 | _value_serializer([5, 6, 7.1]), 253 | ] 254 | 255 | 256 | def test_deduplicate_hits( 257 | es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, 258 | ) -> None: 259 | hits = [ 260 | { 261 | "_id": "1", 262 | "_source": { 263 | "timestamp": "2022-01-01T00:00:00", 264 | "vector_dump": serialize_encode_vector([1, 2, 3]), 265 | }, 266 | }, 267 | { 268 | "_id": "1", 269 | "_source": { 270 | "timestamp": "2022-01-02T00:00:00", 271 | "vector_dump": serialize_encode_vector([4, 5, 6]), 272 | }, 273 | }, 274 | { 275 | "_id": "2", 276 | "_source": { 277 | "timestamp": "2022-01-01T00:00:00", 278 | "vector_dump": serialize_encode_vector([7, 8, 9]), 279 | }, 280 | }, 281 | ] 282 | 283 | result = es_embeddings_cache_fx._deduplicate_hits(hits) 284 | 285 | assert len(result) == 2 286 | assert result["1"] == _value_serializer([4, 5, 6]) 287 | assert result["2"] == _value_serializer([7, 8, 9]) 288 | 289 | 290 | def test_mget_duplicate_keys_cache_store( 291 | es_client_fx: MagicMock, es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache 292 | ) -> None: 293 | cache_keys = [ 294 | es_embeddings_cache_fx._key("test_text1"), 295 | es_embeddings_cache_fx._key("test_text2"), 296 | ] 297 | 298 | resp = { 299 | "hits": { 300 | "total": {"value": 3}, 301 | "hits": [ 302 | { 303 | "_index": "test_index", 304 | "_id": cache_keys[1], 305 | "found": True, 306 | "_source": { 307 | "vector_dump": serialize_encode_vector([1.5, 2, 3.6]), 308 | "timestamp": "2024-03-07T13:25:36.410756", 309 | }, 310 | }, 311 | { 312 | "_index": "test_index", 313 | "_id": cache_keys[0], 314 | "found": True, 315 | "_source": { 316 | "vector_dump": serialize_encode_vector([1, 6, 7.1]), 317 | "timestamp": "2024-03-07T13:25:46.410756", 318 | }, 319 | }, 320 | { 321 | "_index": "test_index", 322 | "_id": cache_keys[0], 323 | "found": True, 324 | "_source": { 325 | "vector_dump": serialize_encode_vector([2, 6, 7.1]), 326 | "timestamp": "2024-03-07T13:27:46.410756", 327 | }, 328 | }, 329 | ], 330 | } 331 | } 332 | 333 | es_embeddings_cache_fx._is_alias = True 334 | es_client_fx.search.return_value = resp 335 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2"]) == [ 336 | _value_serializer([2, 6, 7.1]), 337 | _value_serializer([1.5, 2, 3.6]), 338 | ] 339 | es_client_fx.search.assert_called_with( 340 | index="test_index", 341 | body={ 342 | "query": {"ids": {"values": cache_keys}}, 343 | "size": len(cache_keys), 344 | }, 345 | source_includes=["vector_dump", "timestamp"], 346 | ) 347 | 348 | 349 | def _del_timestamp(doc: Dict[str, Any]) -> Dict[str, Any]: 350 | del doc["_source"]["timestamp"] 351 | return doc 352 | 353 | 354 | def test_mset_cache_store( 355 | es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, 356 | ) -> None: 357 | input = [ 358 | ("test_text1", _value_serializer([1.5, 2, 3.6])), 359 | ("test_text2", _value_serializer([5, 6, 7.1])), 360 | ] 361 | actions = [ 362 | { 363 | "_op_type": "index", 364 | "_id": es_embeddings_cache_fx._key(k), 365 | "_source": es_embeddings_cache_fx.build_document(k, v), 366 | } 367 | for k, v in input 368 | ] 369 | es_embeddings_cache_fx._is_alias = False 370 | with patch("elasticsearch.helpers.bulk") as bulk_mock: 371 | es_embeddings_cache_fx.mset([]) 372 | bulk_mock.assert_called_once() 373 | es_embeddings_cache_fx.mset(input) 374 | bulk_mock.assert_called_with( 375 | client=es_embeddings_cache_fx._es_client, 376 | actions=ANY, 377 | index="test_index", 378 | require_alias=False, 379 | refresh=True, 380 | ) 381 | assert 
[_del_timestamp(d) for d in bulk_mock.call_args.kwargs["actions"]] == [ 382 | _del_timestamp(d) for d in actions 383 | ] 384 | 385 | 386 | def test_mdelete_cache_store( 387 | es_embeddings_cache_fx: AsyncElasticsearchEmbeddingsCache, 388 | ) -> None: 389 | input = ["test_text1", "test_text2"] 390 | actions = [ 391 | {"_op_type": "delete", "_id": es_embeddings_cache_fx._key(k)} for k in input 392 | ] 393 | es_embeddings_cache_fx._is_alias = False 394 | with patch("elasticsearch.helpers.bulk") as bulk_mock: 395 | es_embeddings_cache_fx.mdelete([]) 396 | bulk_mock.assert_called_once() 397 | es_embeddings_cache_fx.mdelete(input) 398 | bulk_mock.assert_called_with( 399 | client=es_embeddings_cache_fx._es_client, 400 | actions=ANY, 401 | index="test_index", 402 | require_alias=False, 403 | refresh=True, 404 | ) 405 | assert list(bulk_mock.call_args.kwargs["actions"]) == actions 406 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/_sync/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-elastic/83ef83351f81c07396567ecc896f70b73759f5f7/libs/elasticsearch/tests/unit_tests/_sync/__init__.py -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/_sync/test_cache.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Any, Dict 3 | from unittest import mock 4 | from unittest.mock import ANY, MagicMock, patch 5 | 6 | import pytest 7 | from _pytest.fixtures import FixtureRequest 8 | from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig 9 | from elasticsearch import NotFoundError 10 | from langchain.embeddings.cache import _value_serializer 11 | from langchain_core.load import dumps 12 | from langchain_core.outputs import Generation 13 | 14 | from langchain_elasticsearch import ( 15 | ElasticsearchCache, 16 | ElasticsearchEmbeddingsCache, 17 | ) 18 | 19 | 20 | def serialize_encode_vector(vector: Any) -> str: 21 | return ElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) 22 | 23 | 24 | @pytest.mark.sync 25 | def test_initialization_llm_cache(es_client_fx: MagicMock) -> None: 26 | es_client_fx.ping.return_value = True 27 | es_client_fx.indices.exists_alias.return_value = True 28 | with mock.patch( 29 | "langchain_elasticsearch._sync.cache.create_elasticsearch_client", 30 | return_value=es_client_fx, 31 | ): 32 | with mock.patch( 33 | "langchain_elasticsearch._async.cache.create_async_elasticsearch_client", 34 | return_value=es_client_fx, 35 | ): 36 | cache = ElasticsearchCache( 37 | es_url="http://localhost:9200", index_name="test_index" 38 | ) 39 | assert cache.is_alias() 40 | es_client_fx.indices.exists_alias.assert_called_with(name="test_index") 41 | es_client_fx.indices.put_mapping.assert_called_with( 42 | index="test_index", body=cache.mapping["mappings"] 43 | ) 44 | es_client_fx.indices.exists_alias.return_value = False 45 | es_client_fx.indices.exists.return_value = False 46 | cache = ElasticsearchCache( 47 | es_url="http://localhost:9200", index_name="test_index" 48 | ) 49 | assert not (cache.is_alias()) 50 | es_client_fx.indices.create.assert_called_with( 51 | index="test_index", body=cache.mapping 52 | ) 53 | 54 | 55 | def test_mapping_llm_cache( 56 | es_cache_fx: ElasticsearchCache, request: FixtureRequest 57 | ) -> None: 58 | mapping = 
request.getfixturevalue("es_cache_fx").mapping 59 | assert mapping.get("mappings") 60 | assert mapping["mappings"].get("properties") 61 | 62 | 63 | def test_key_generation_llm_cache(es_cache_fx: ElasticsearchCache) -> None: 64 | key1 = es_cache_fx._key("test_prompt", "test_llm_string") 65 | assert key1 and isinstance(key1, str) 66 | key2 = es_cache_fx._key("test_prompt", "test_llm_string1") 67 | assert key2 and key1 != key2 68 | key3 = es_cache_fx._key("test_prompt1", "test_llm_string") 69 | assert key3 and key1 != key3 70 | 71 | 72 | def test_clear_llm_cache( 73 | es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache 74 | ) -> None: 75 | es_cache_fx.clear() 76 | es_client_fx.delete_by_query.assert_called_once_with( 77 | index="test_index", 78 | body={"query": {"match_all": {}}}, 79 | refresh=True, 80 | wait_for_completion=True, 81 | ) 82 | 83 | 84 | def test_build_document_llm_cache(es_cache_fx: ElasticsearchCache) -> None: 85 | doc = es_cache_fx.build_document( 86 | "test_prompt", "test_llm_string", [Generation(text="test_prompt")] 87 | ) 88 | assert doc["llm_input"] == "test_prompt" 89 | assert doc["llm_params"] == "test_llm_string" 90 | assert isinstance(doc["llm_output"], list) 91 | assert all(isinstance(gen, str) for gen in doc["llm_output"]) 92 | assert datetime.fromisoformat(str(doc["timestamp"])) 93 | assert doc["metadata"] == es_cache_fx._metadata 94 | 95 | 96 | def test_update_llm_cache( 97 | es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache 98 | ) -> None: 99 | es_cache_fx.update("test_prompt", "test_llm_string", [Generation(text="test")]) 100 | timestamp = es_client_fx.index.call_args.kwargs["body"]["timestamp"] 101 | doc = es_cache_fx.build_document( 102 | "test_prompt", "test_llm_string", [Generation(text="test")] 103 | ) 104 | doc["timestamp"] = timestamp 105 | es_client_fx.index.assert_called_once_with( 106 | index=es_cache_fx._index_name, 107 | id=es_cache_fx._key("test_prompt", "test_llm_string"), 108 | body=doc, 109 | require_alias=es_cache_fx._is_alias, 110 | refresh=True, 111 | ) 112 | 113 | 114 | def test_lookup_llm_cache( 115 | es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache 116 | ) -> None: 117 | cache_key = es_cache_fx._key("test_prompt", "test_llm_string") 118 | doc: Dict[str, Any] = { 119 | "_source": { 120 | "llm_output": [dumps(Generation(text="test"))], 121 | "timestamp": "2024-03-07T13:25:36.410756", 122 | } 123 | } 124 | es_cache_fx._is_alias = False 125 | es_client_fx.get.side_effect = NotFoundError( 126 | "not found", 127 | ApiResponseMeta(404, "0", HttpHeaders(), 0, NodeConfig("http", "xxx", 80)), 128 | "", 129 | ) 130 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None 131 | es_client_fx.get.assert_called_once_with( 132 | index="test_index", id=cache_key, source=["llm_output"] 133 | ) 134 | es_client_fx.get.side_effect = None 135 | es_client_fx.get.return_value = doc 136 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ 137 | Generation(text="test") 138 | ] 139 | es_cache_fx._is_alias = True 140 | es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} 141 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") is None 142 | es_client_fx.search.assert_called_once_with( 143 | index="test_index", 144 | body={ 145 | "query": {"term": {"_id": cache_key}}, 146 | "sort": {"timestamp": {"order": "asc"}}, 147 | }, 148 | source_includes=["llm_output"], 149 | ) 150 | doc2 = { 151 | "_source": { 152 | "llm_output": [dumps(Generation(text="test2"))], 153 | "timestamp": 
"2024-03-08T13:25:36.410756", 154 | }, 155 | } 156 | es_client_fx.search.return_value = { 157 | "hits": {"total": {"value": 2}, "hits": [doc2, doc]} 158 | } 159 | assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ 160 | Generation(text="test2") 161 | ] 162 | 163 | 164 | def test_key_generation_cache_store( 165 | es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, 166 | ) -> None: 167 | key1 = es_embeddings_cache_fx._key("test_text") 168 | assert key1 and isinstance(key1, str) 169 | key2 = es_embeddings_cache_fx._key("test_text2") 170 | assert key2 and key1 != key2 171 | es_embeddings_cache_fx._namespace = "other" 172 | key3 = es_embeddings_cache_fx._key("test_text") 173 | assert key3 and key1 != key3 174 | es_embeddings_cache_fx._namespace = None 175 | key4 = es_embeddings_cache_fx._key("test_text") 176 | assert key4 and key1 != key4 and key3 != key4 177 | 178 | 179 | def test_build_document_cache_store( 180 | es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, 181 | ) -> None: 182 | doc = es_embeddings_cache_fx.build_document( 183 | "test_text", _value_serializer([1.5, 2, 3.6]) 184 | ) 185 | assert doc["text_input"] == "test_text" 186 | assert doc["vector_dump"] == serialize_encode_vector([1.5, 2, 3.6]) 187 | assert datetime.fromisoformat(str(doc["timestamp"])) 188 | assert doc["metadata"] == es_embeddings_cache_fx._metadata 189 | 190 | 191 | def test_mget_cache_store( 192 | es_client_fx: MagicMock, es_embeddings_cache_fx: ElasticsearchEmbeddingsCache 193 | ) -> None: 194 | cache_keys = [ 195 | es_embeddings_cache_fx._key("test_text1"), 196 | es_embeddings_cache_fx._key("test_text2"), 197 | es_embeddings_cache_fx._key("test_text3"), 198 | ] 199 | docs = { 200 | "docs": [ 201 | {"_index": "test_index", "_id": cache_keys[0], "found": False}, 202 | { 203 | "_index": "test_index", 204 | "_id": cache_keys[1], 205 | "found": True, 206 | "_source": {"vector_dump": serialize_encode_vector([1.5, 2, 3.6])}, 207 | }, 208 | { 209 | "_index": "test_index", 210 | "_id": cache_keys[2], 211 | "found": True, 212 | "_source": {"vector_dump": serialize_encode_vector([5, 6, 7.1])}, 213 | }, 214 | ] 215 | } 216 | es_embeddings_cache_fx._is_alias = False 217 | es_client_fx.mget.return_value = docs 218 | assert es_embeddings_cache_fx.mget([]) == [] 219 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 220 | None, 221 | _value_serializer([1.5, 2, 3.6]), 222 | _value_serializer([5, 6, 7.1]), 223 | ] 224 | es_client_fx.mget.assert_called_with( 225 | index="test_index", ids=cache_keys, source_includes=["vector_dump"] 226 | ) 227 | es_embeddings_cache_fx._is_alias = True 228 | es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} 229 | assert es_embeddings_cache_fx.mget([]) == [] 230 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 231 | None, 232 | None, 233 | None, 234 | ] 235 | es_client_fx.search.assert_called_with( 236 | index="test_index", 237 | body={ 238 | "query": {"ids": {"values": cache_keys}}, 239 | "size": 3, 240 | }, 241 | source_includes=["vector_dump", "timestamp"], 242 | ) 243 | resp = { 244 | "hits": {"total": {"value": 3}, "hits": [d for d in docs["docs"] if d["found"]]} 245 | } 246 | es_client_fx.search.return_value = resp 247 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ 248 | None, 249 | _value_serializer([1.5, 2, 3.6]), 250 | _value_serializer([5, 6, 7.1]), 251 | ] 252 | 253 | 254 | def test_deduplicate_hits( 255 | es_embeddings_cache_fx: 
ElasticsearchEmbeddingsCache, 256 | ) -> None: 257 | hits = [ 258 | { 259 | "_id": "1", 260 | "_source": { 261 | "timestamp": "2022-01-01T00:00:00", 262 | "vector_dump": serialize_encode_vector([1, 2, 3]), 263 | }, 264 | }, 265 | { 266 | "_id": "1", 267 | "_source": { 268 | "timestamp": "2022-01-02T00:00:00", 269 | "vector_dump": serialize_encode_vector([4, 5, 6]), 270 | }, 271 | }, 272 | { 273 | "_id": "2", 274 | "_source": { 275 | "timestamp": "2022-01-01T00:00:00", 276 | "vector_dump": serialize_encode_vector([7, 8, 9]), 277 | }, 278 | }, 279 | ] 280 | 281 | result = es_embeddings_cache_fx._deduplicate_hits(hits) 282 | 283 | assert len(result) == 2 284 | assert result["1"] == _value_serializer([4, 5, 6]) 285 | assert result["2"] == _value_serializer([7, 8, 9]) 286 | 287 | 288 | def test_mget_duplicate_keys_cache_store( 289 | es_client_fx: MagicMock, es_embeddings_cache_fx: ElasticsearchEmbeddingsCache 290 | ) -> None: 291 | cache_keys = [ 292 | es_embeddings_cache_fx._key("test_text1"), 293 | es_embeddings_cache_fx._key("test_text2"), 294 | ] 295 | 296 | resp = { 297 | "hits": { 298 | "total": {"value": 3}, 299 | "hits": [ 300 | { 301 | "_index": "test_index", 302 | "_id": cache_keys[1], 303 | "found": True, 304 | "_source": { 305 | "vector_dump": serialize_encode_vector([1.5, 2, 3.6]), 306 | "timestamp": "2024-03-07T13:25:36.410756", 307 | }, 308 | }, 309 | { 310 | "_index": "test_index", 311 | "_id": cache_keys[0], 312 | "found": True, 313 | "_source": { 314 | "vector_dump": serialize_encode_vector([1, 6, 7.1]), 315 | "timestamp": "2024-03-07T13:25:46.410756", 316 | }, 317 | }, 318 | { 319 | "_index": "test_index", 320 | "_id": cache_keys[0], 321 | "found": True, 322 | "_source": { 323 | "vector_dump": serialize_encode_vector([2, 6, 7.1]), 324 | "timestamp": "2024-03-07T13:27:46.410756", 325 | }, 326 | }, 327 | ], 328 | } 329 | } 330 | 331 | es_embeddings_cache_fx._is_alias = True 332 | es_client_fx.search.return_value = resp 333 | assert es_embeddings_cache_fx.mget(["test_text1", "test_text2"]) == [ 334 | _value_serializer([2, 6, 7.1]), 335 | _value_serializer([1.5, 2, 3.6]), 336 | ] 337 | es_client_fx.search.assert_called_with( 338 | index="test_index", 339 | body={ 340 | "query": {"ids": {"values": cache_keys}}, 341 | "size": len(cache_keys), 342 | }, 343 | source_includes=["vector_dump", "timestamp"], 344 | ) 345 | 346 | 347 | def _del_timestamp(doc: Dict[str, Any]) -> Dict[str, Any]: 348 | del doc["_source"]["timestamp"] 349 | return doc 350 | 351 | 352 | def test_mset_cache_store( 353 | es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, 354 | ) -> None: 355 | input = [ 356 | ("test_text1", _value_serializer([1.5, 2, 3.6])), 357 | ("test_text2", _value_serializer([5, 6, 7.1])), 358 | ] 359 | actions = [ 360 | { 361 | "_op_type": "index", 362 | "_id": es_embeddings_cache_fx._key(k), 363 | "_source": es_embeddings_cache_fx.build_document(k, v), 364 | } 365 | for k, v in input 366 | ] 367 | es_embeddings_cache_fx._is_alias = False 368 | with patch("elasticsearch.helpers.bulk") as bulk_mock: 369 | es_embeddings_cache_fx.mset([]) 370 | bulk_mock.assert_called_once() 371 | es_embeddings_cache_fx.mset(input) 372 | bulk_mock.assert_called_with( 373 | client=es_embeddings_cache_fx._es_client, 374 | actions=ANY, 375 | index="test_index", 376 | require_alias=False, 377 | refresh=True, 378 | ) 379 | assert [_del_timestamp(d) for d in bulk_mock.call_args.kwargs["actions"]] == [ 380 | _del_timestamp(d) for d in actions 381 | ] 382 | 383 | 384 | def test_mdelete_cache_store( 385 | 
es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, 386 | ) -> None: 387 | input = ["test_text1", "test_text2"] 388 | actions = [ 389 | {"_op_type": "delete", "_id": es_embeddings_cache_fx._key(k)} for k in input 390 | ] 391 | es_embeddings_cache_fx._is_alias = False 392 | with patch("elasticsearch.helpers.bulk") as bulk_mock: 393 | es_embeddings_cache_fx.mdelete([]) 394 | bulk_mock.assert_called_once() 395 | es_embeddings_cache_fx.mdelete(input) 396 | bulk_mock.assert_called_with( 397 | client=es_embeddings_cache_fx._es_client, 398 | actions=ANY, 399 | index="test_index", 400 | require_alias=False, 401 | refresh=True, 402 | ) 403 | assert list(bulk_mock.call_args.kwargs["actions"]) == actions 404 | -------------------------------------------------------------------------------- /libs/elasticsearch/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_elasticsearch import __all__ 2 | 3 | EXPECTED_ALL = sorted( 4 | [ 5 | "ElasticsearchCache", 6 | "ElasticsearchChatMessageHistory", 7 | "ElasticsearchEmbeddings", 8 | "ElasticsearchEmbeddingsCache", 9 | "ElasticsearchRetriever", 10 | "ElasticsearchStore", 11 | "AsyncElasticsearchCache", 12 | "AsyncElasticsearchChatMessageHistory", 13 | "AsyncElasticsearchEmbeddings", 14 | "AsyncElasticsearchEmbeddingsCache", 15 | "AsyncElasticsearchRetriever", 16 | "AsyncElasticsearchStore", 17 | # retrieval strategies 18 | "BM25Strategy", 19 | "DenseVectorScriptScoreStrategy", 20 | "DenseVectorStrategy", 21 | "DistanceMetric", 22 | "RetrievalStrategy", 23 | "SparseVectorStrategy", 24 | "AsyncBM25Strategy", 25 | "AsyncDenseVectorScriptScoreStrategy", 26 | "AsyncDenseVectorStrategy", 27 | "AsyncRetrievalStrategy", 28 | "AsyncSparseVectorStrategy", 29 | # deprecated retrieval strategies 30 | "ApproxRetrievalStrategy", 31 | "BM25RetrievalStrategy", 32 | "ExactRetrievalStrategy", 33 | "SparseRetrievalStrategy", 34 | ] 35 | ) 36 | 37 | 38 | def test_all_imports() -> None: 39 | assert sorted(EXPECTED_ALL) == sorted(__all__) 40 | --------------------------------------------------------------------------------
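The test files above exercise the package's public constructors directly. As a rough usage sketch (not part of the repository itself), the same entry points seen in the retriever and cache tests can be wired against the local docker-compose Elasticsearch; the index names and the match query below are illustrative assumptions, while the constructor and method calls mirror those made in the tests:

from typing import Dict

from langchain_elasticsearch import ElasticsearchCache, ElasticsearchRetriever


def body_func(query: str) -> Dict:
    # Same shape of query body as used by the retriever integration tests.
    return {"query": {"match": {"text": {"query": query}}}}


# Connect via URL, as in test_init_url; "my-docs" is a hypothetical index name.
retriever = ElasticsearchRetriever.from_es_params(
    index_name="my-docs",
    body_func=body_func,
    content_field="text",
    url="http://localhost:9200",
)
docs = retriever.get_relevant_documents("foo")

# An LLM response cache on the same cluster, mirroring the unit-test constructor.
cache = ElasticsearchCache(es_url="http://localhost:9200", index_name="llm-cache")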