├── .devcontainer ├── Dockerfile ├── README.md └── devcontainer.json ├── .github ├── actions │ └── poetry_setup │ │ └── action.yml ├── scripts │ ├── check_diff.py │ └── get_min_versions.py └── workflows │ ├── _codespell.yml │ ├── _compile_integration_test.yml │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ ├── _test_release.yml │ ├── check_diffs.yml │ └── extract_ignored_words_list.py ├── .gitignore ├── LICENSE ├── README.md └── libs ├── azure-ai ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docs │ ├── azure_cosmos_db.ipynb │ └── azure_cosmos_db_no_sql.ipynb ├── langchain_azure_ai │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ └── tracers │ │ │ ├── __init__.py │ │ │ ├── _semantic_conventions_gen_ai.py │ │ │ └── inference_tracing.py │ ├── chat_message_histories │ │ ├── __init__.py │ │ └── cosmos_db.py │ ├── chat_models │ │ ├── __init__.py │ │ └── inference.py │ ├── embeddings │ │ ├── __init__.py │ │ └── inference.py │ ├── query_constructors │ │ ├── __init__.py │ │ └── cosmosdb_no_sql.py │ ├── utils │ │ ├── math.py │ │ └── utils.py │ └── vectorstores │ │ ├── __init__.py │ │ ├── azure_cosmos_db_mongo_vcore.py │ │ ├── azure_cosmos_db_no_sql.py │ │ ├── cache.py │ │ └── utils.py ├── poetry.lock ├── pyproject.toml ├── scripts │ ├── check_imports.py │ └── lint_imports.sh └── tests │ ├── __init__.py │ ├── integration_tests │ ├── __init__.py │ ├── cache │ │ ├── test_azure_cosmos_db_mongo_vcore_cache.py │ │ └── test_azure_cosmos_db_no_sql_cache.py │ ├── test_compile.py │ └── vectorstores │ │ ├── __init__.py │ │ ├── test_azure_cosmos_db_mongo_vcore.py │ │ └── test_azure_cosmos_db_no_sql.py │ └── unit_tests │ ├── __init__.py │ ├── test_chat_models.py │ ├── test_embeddings.py │ └── test_queryconstructor_no_sql.py ├── azure-dynamic-sessions ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── langchain_azure_dynamic_sessions │ ├── __init__.py │ ├── py.typed │ └── tools │ │ ├── __init__.py │ │ └── sessions.py ├── poetry.lock ├── pyproject.toml ├── scripts │ ├── check_imports.py │ └── lint_imports.sh └── tests │ ├── __init__.py │ ├── integration_tests │ ├── __init__.py │ ├── data │ │ └── testdata.txt │ ├── test_compile.py │ └── test_end_to_end.py │ └── unit_tests │ ├── __init__.py │ ├── test_imports.py │ └── test_sessions_python_repl_tool.py └── sqlserver ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── langchain_sqlserver ├── __init__.py ├── py.typed └── vectorstores.py ├── poetry.lock ├── pyproject.toml ├── scripts ├── check_imports.py └── lint_imports.sh └── tests ├── __init__.py ├── integration_tests ├── __init__.py ├── test_compile.py └── test_vectorstores.py ├── unit_tests ├── __init__.py ├── test_imports.py └── test_vectorstores.py └── utils ├── fake_embeddings.py └── filtering_test_cases.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mcr.microsoft.com/devcontainers/base:jammy 2 | # FROM mcr.microsoft.com/devcontainers/base:jammy 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | ARG USER=vscode 6 | 7 | RUN DEBIAN_FRONTEND=noninteractive \ 8 | && apt-get update \ 9 | && apt-get install -y build-essential --no-install-recommends make \ 10 | ca-certificates \ 11 | git \ 12 | libssl-dev \ 13 | zlib1g-dev \ 14 | libbz2-dev \ 15 | libreadline-dev \ 16 | libsqlite3-dev \ 17 | wget \ 18 | curl \ 19 | llvm \ 20 | libncurses5-dev \ 21 | xz-utils \ 22 | tk-dev \ 23 | libxml2-dev \ 24 | libxmlsec1-dev \ 25 | libffi-dev \ 26 | liblzma-dev 27 | 28 | # Python and poetry installation 29 | USER $USER 30 | 
ARG HOME="/home/$USER" 31 | ARG PYTHON_VERSION=3.9 32 | # ARG PYTHON_VERSION=3.10 33 | 34 | ENV PYENV_ROOT="${HOME}/.pyenv" 35 | ENV PATH="${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${HOME}/.local/bin:$PATH" 36 | 37 | RUN echo "done 0" \ 38 | && curl https://pyenv.run | bash \ 39 | && echo "done 1" \ 40 | && pyenv install ${PYTHON_VERSION} \ 41 | && echo "done 2" \ 42 | && pyenv global ${PYTHON_VERSION} \ 43 | && echo "done 3" \ 44 | && curl -sSL https://install.python-poetry.org | python3 - \ 45 | && poetry config virtualenvs.in-project true -------------------------------------------------------------------------------- /.devcontainer/README.md: -------------------------------------------------------------------------------- 1 | # Dev container 2 | 3 | This project includes a [dev container](https://containers.dev/), which lets you use a container as a full-featured dev environment. 4 | 5 | You can use the dev container configuration in this folder to build and run the app without needing to install any of its tools locally! This container comes with Python 3.9, Azure CLI and Poetry pre-installed. You can use it in [GitHub Codespaces](https://github.com/features/codespaces) or the [VS Code Dev Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers). 6 | 7 | ## GitHub Codespaces 8 | [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/langchain-ai/langchain-azure) 9 | 10 | You may use the button above, or follow these steps to open this repo in a Codespace: 11 | 1. Click the **Code** drop-down menu at the top of https://github.com/langchain-ai/langchain-azure. 12 | 1. Click on the **Codespaces** tab. 13 | 1. Click **Create codespace on master**. 14 | 15 | For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). 16 | 17 | ## VS Code Dev Containers 18 | [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langchain-ai/langchain-azure) 19 | 20 | Note: If you click the link above you will open the main repo (langchain-ai/langchain) and not your local cloned repo. This is fine if you only want to run and test the library, but if you want to contribute you can use the link below and replace with your username and cloned repo name: 21 | ``` 22 | https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com// 23 | 24 | ``` 25 | Then you will have a local cloned repo where you can contribute and then create pull requests. 26 | 27 | If you already have VS Code and Docker installed, you can use the button above to get started. This will cause VS Code to automatically install the Dev Containers extension if needed, clone the source code into a container volume, and spin up a dev container for use. 28 | 29 | Alternatively you can also follow these steps to open this repo in a container using the VS Code Dev Containers extension: 30 | 31 | 1. If this is your first time using a development container, please ensure your system meets the pre-reqs (i.e. have Docker installed) in the [getting started steps](https://aka.ms/vscode-remote/containers/getting-started). 32 | 33 | 2. 
Open a locally cloned copy of the code: 34 | 35 | - Fork and Clone this repository to your local filesystem. 36 | - Press F1 and select the **Dev Containers: Open Folder in Container...** command. 37 | - Select the cloned copy of this folder, wait for the container to start, and try things out! 38 | 39 | You can learn more in the [Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). 40 | 41 | ## Tips and tricks 42 | 43 | * If you are working with the same repository folder in a container and Windows, you'll want consistent line endings (otherwise you may see hundreds of changes in the SCM view). The `.gitattributes` file in the root of this repo will disable line ending conversion and should prevent this. See [tips and tricks](https://code.visualstudio.com/docs/devcontainers/tips-and-tricks#_resolving-git-line-ending-issues-in-containers-resulting-in-many-modified-files) for more info. 44 | * If you'd like to review the contents of the image used in this dev container, you can check it out in the [devcontainers/images](https://github.com/devcontainers/images/tree/main/src/python) repo. 45 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "poetry3-poetry-pyenv", 3 | "build": { 4 | "dockerfile": "Dockerfile" 5 | }, 6 | 7 | // Features to add to the Dev Container. More info: https://containers.dev/implementors/features. 8 | "features": { 9 | "ghcr.io/devcontainers/features/azure-cli:1": { 10 | } 11 | }, 12 | 13 | // 👇 Use 'forwardPorts' to make a list of ports inside the container available locally. 14 | // "forwardPorts": [], 15 | 16 | // 👇 Use 'postCreateCommand' to run commands after the container is created. 17 | // "postCreateCommand": "", 18 | 19 | // 👇 Configure tool-specific properties. 20 | "customizations": { 21 | "vscode": { 22 | "extensions":["ms-python.python", "njpwerner.autodocstring"] 23 | } 24 | } 25 | 26 | // 👇 Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 27 | // "remoteUser": "root" 28 | } 29 | -------------------------------------------------------------------------------- /.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 
8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python || true 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry || true 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' 
|| inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} 94 | -------------------------------------------------------------------------------- /.github/scripts/check_diff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import Dict 4 | 5 | LIB_DIRS = [ 6 | "libs/azure-dynamic-sessions", 7 | "libs/sqlserver", 8 | "libs/azure-ai", 9 | ] 10 | 11 | if __name__ == "__main__": 12 | files = sys.argv[1:] 13 | 14 | dirs_to_run: Dict[str, set] = { 15 | "lint": set(), 16 | "test": set(), 17 | } 18 | 19 | if len(files) == 300: 20 | # max diff length is 300 files - there are likely files missing 21 | raise ValueError("Max diff reached. Please manually run CI on changed libs.") 22 | 23 | for file in files: 24 | if any( 25 | file.startswith(dir_) 26 | for dir_ in ( 27 | ".github/workflows", 28 | ".github/tools", 29 | ".github/actions", 30 | ".github/scripts/check_diff.py", 31 | ) 32 | ): 33 | # add all LANGCHAIN_DIRS for infra changes 34 | dirs_to_run["test"].update(LIB_DIRS) 35 | 36 | if any(file.startswith(dir_) for dir_ in LIB_DIRS): 37 | for dir_ in LIB_DIRS: 38 | if file.startswith(dir_): 39 | dirs_to_run["test"].add(dir_) 40 | elif file.startswith("libs/"): 41 | raise ValueError( 42 | f"Unknown lib: {file}. check_diff.py likely needs " 43 | "an update for this new library!" 44 | ) 45 | 46 | outputs = { 47 | "dirs-to-lint": list(dirs_to_run["lint"] | dirs_to_run["test"]), 48 | "dirs-to-test": list(dirs_to_run["test"]), 49 | } 50 | for key, value in outputs.items(): 51 | json_output = json.dumps(value) 52 | print(f"{key}={json_output}") # noqa: T201 53 | -------------------------------------------------------------------------------- /.github/scripts/get_min_versions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import tomllib 4 | from packaging.version import parse as parse_version 5 | import re 6 | 7 | MIN_VERSION_LIBS = ["langchain-core"] 8 | 9 | 10 | def get_min_version(version: str) -> str: 11 | # case ^x.x.x 12 | _match = re.match(r"^\^(\d+(?:\.\d+){0,2})$", version) 13 | if _match: 14 | return _match.group(1) 15 | 16 | # case >=x.x.x,<y.y.y 17 | _match = re.match(r"^>=(\d+(?:\.\d+){0,2}),<(\d+(?:\.\d+){0,2})$", version) 18 | if _match: 19 | _min = _match.group(1) 20 | _max = _match.group(2) 21 | assert parse_version(_min) < parse_version(_max) 22 | return _min 23 | 24 | # case x.x.x 25 | _match = re.match(r"^(\d+(?:\.\d+){0,2})$", version) 26 | if _match: 27 | return _match.group(1) 28 | 29 | raise ValueError(f"Unrecognized version format: {version}") 30 | 31 | 32 | def get_min_version_from_toml(toml_path: str): 33 | # Parse the TOML file 34 | with open(toml_path, "rb") as file: 35 | toml_data = tomllib.load(file) 36 | 37 | # Get the dependencies from tool.poetry.dependencies 38 | dependencies = toml_data["tool"]["poetry"]["dependencies"] 39 | 40 | # Initialize a dictionary to store the minimum versions 41 | min_versions = {} 42 | 43 | # Iterate over the libs in MIN_VERSION_LIBS 44 | for lib in MIN_VERSION_LIBS: 45 | # Check if the lib is present in the dependencies 46 | if lib in dependencies: 47 | # Get the version string 48 |
version_string = dependencies[lib] 49 | 50 | # Use parse_version to get the minimum supported version from version_string 51 | min_version = get_min_version(version_string) 52 | 53 | # Store the minimum version in the min_versions dictionary 54 | min_versions[lib] = min_version 55 | 56 | return min_versions 57 | 58 | 59 | # Get the TOML file path from the command line argument 60 | toml_file = sys.argv[1] 61 | 62 | # Call the function to get the minimum versions 63 | min_versions = get_min_version_from_toml(toml_file) 64 | 65 | print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()])) 66 | -------------------------------------------------------------------------------- /.github/workflows/_codespell.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: make spell_check 3 | 4 | on: 5 | workflow_call: 6 | inputs: 7 | working-directory: 8 | required: true 9 | type: string 10 | description: "From which folder this pipeline executes" 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | codespell: 17 | name: (Check for spelling errors) 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Dependencies 25 | run: | 26 | pip install toml 27 | 28 | - name: Extract Ignore Words List 29 | working-directory: ${{ inputs.working-directory }} 30 | run: | 31 | # Use a Python script to extract the ignore words list from pyproject.toml 32 | python ../../.github/workflows/extract_ignored_words_list.py 33 | id: extract_ignore_words 34 | 35 | - name: Codespell 36 | uses: codespell-project/actions-codespell@v2 37 | with: 38 | skip: guide_imports.json 39 | ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }} 40 | -------------------------------------------------------------------------------- /.github/workflows/_compile_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: compile-integration-test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.12" 25 | name: "poetry run pytest -m compile tests/integration_tests #${{ matrix.python-version }}" 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 30 | uses: "./.github/actions/poetry_setup" 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | poetry-version: ${{ env.POETRY_VERSION }} 34 | working-directory: ${{ inputs.working-directory }} 35 | cache-key: compile-integration 36 | 37 | - name: Install integration dependencies 38 | shell: bash 39 | run: poetry install --with=test_integration,test 40 | 41 | - name: Check integration tests compile 42 | shell: bash 43 | run: poetry run pytest -m compile tests/integration_tests 44 | 45 | - name: Ensure the tests did not create any additional files 46 | shell: bash 47 | run: | 48 | set -eu 49 | 50 | STATUS="$(git status)" 51 | echo "$STATUS" 52 | 53 | # grep will exit non-zero if the target message isn't found, 54 | # and `set -e` above will cause the step to fail. 
55 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 56 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 14 | 15 | # This env var allows us to get inline annotations when ruff has complaints. 16 | RUFF_OUTPUT_FORMAT: github 17 | 18 | jobs: 19 | build: 20 | name: "make lint #${{ matrix.python-version }}" 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 31 | python-version: 32 | - "3.9" 33 | - "3.12" 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 38 | uses: "./.github/actions/poetry_setup" 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | poetry-version: ${{ env.POETRY_VERSION }} 42 | working-directory: ${{ inputs.working-directory }} 43 | cache-key: lint-with-extras 44 | 45 | - name: Check Poetry File 46 | shell: bash 47 | working-directory: ${{ inputs.working-directory }} 48 | run: | 49 | poetry check 50 | 51 | - name: Check lock file 52 | shell: bash 53 | working-directory: ${{ inputs.working-directory }} 54 | run: | 55 | poetry lock --check 56 | 57 | - name: Install dependencies 58 | # Also installs dev/lint/test/typing dependencies, to ensure we have 59 | # type hints for as many of our libraries as possible. 60 | # This helps catch errors that require dependencies to be spotted, for example: 61 | # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 62 | # 63 | # If you change this configuration, make sure to change the `cache-key` 64 | # in the `poetry_setup` action above to stop using the old cache. 65 | # It doesn't matter how you change it, any change will cause a cache-bust. 
66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | poetry install --with lint,typing --all-extras 69 | 70 | - name: Get .mypy_cache to speed up mypy 71 | uses: actions/cache@v4 72 | env: 73 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 74 | with: 75 | path: | 76 | ${{ env.WORKDIR }}/.mypy_cache 77 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 78 | 79 | 80 | - name: Analysing the code with our lint 81 | working-directory: ${{ inputs.working-directory }} 82 | run: | 83 | make lint_package 84 | 85 | - name: Install unit+integration test dependencies 86 | working-directory: ${{ inputs.working-directory }} 87 | run: | 88 | poetry install --with test,test_integration --all-extras 89 | 90 | - name: Get .mypy_cache_test to speed up mypy 91 | uses: actions/cache@v4 92 | env: 93 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 94 | with: 95 | path: | 96 | ${{ env.WORKDIR }}/.mypy_cache_test 97 | key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 98 | 99 | - name: Analysing the code with our lint 100 | working-directory: ${{ inputs.working-directory }} 101 | run: | 102 | make lint_tests 103 | -------------------------------------------------------------------------------- /.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | required: true 14 | type: string 15 | default: 'libs/azure-dynamic-sessions' 16 | 17 | env: 18 | PYTHON_VERSION: "3.11" 19 | POETRY_VERSION: "1.7.1" 20 | 21 | jobs: 22 | build: 23 | if: github.ref == 'refs/heads/main' 24 | runs-on: ubuntu-latest 25 | 26 | outputs: 27 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 28 | version: ${{ steps.check-version.outputs.version }} 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | 33 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 34 | uses: "./.github/actions/poetry_setup" 35 | with: 36 | python-version: ${{ env.PYTHON_VERSION }} 37 | poetry-version: ${{ env.POETRY_VERSION }} 38 | working-directory: ${{ inputs.working-directory }} 39 | cache-key: release 40 | 41 | # We want to keep this build stage *separate* from the release stage, 42 | # so that there's no sharing of permissions between them. 43 | # The release stage has trusted publishing and GitHub repo contents write access, 44 | # and we want to keep the scope of that access limited just to the release job. 45 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 46 | # could get access to our GitHub or PyPI credentials. 47 | # 48 | # Per the trusted publishing GitHub Action: 49 | # > It is strongly advised to separate jobs for building [...] 50 | # > from the publish job. 
51 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 52 | - name: Build project for distribution 53 | run: poetry build 54 | working-directory: ${{ inputs.working-directory }} 55 | 56 | - name: Upload build 57 | uses: actions/upload-artifact@v4 58 | with: 59 | name: dist 60 | path: ${{ inputs.working-directory }}/dist/ 61 | 62 | - name: Check Version 63 | id: check-version 64 | shell: bash 65 | working-directory: ${{ inputs.working-directory }} 66 | run: | 67 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 68 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 69 | 70 | test-pypi-publish: 71 | needs: 72 | - build 73 | uses: 74 | ./.github/workflows/_test_release.yml 75 | permissions: write-all 76 | with: 77 | working-directory: ${{ inputs.working-directory }} 78 | secrets: inherit 79 | 80 | pre-release-checks: 81 | needs: 82 | - build 83 | - test-pypi-publish 84 | runs-on: ubuntu-latest 85 | steps: 86 | - uses: actions/checkout@v4 87 | 88 | # We explicitly *don't* set up caching here. This ensures our tests are 89 | # maximally sensitive to catching breakage. 90 | # 91 | # For example, here's a way that caching can cause a falsely-passing test: 92 | # - Make the langchain package manifest no longer list a dependency package 93 | # as a requirement. This means it won't be installed by `pip install`, 94 | # and attempting to use it would cause a crash. 95 | # - That dependency used to be required, so it may have been cached. 96 | # When restoring the venv packages from cache, that dependency gets included. 97 | # - Tests pass, because the dependency is present even though it wasn't specified. 98 | # - The package is published, and it breaks on the missing dependency when 99 | # used in the real world. 100 | 101 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 102 | uses: "./.github/actions/poetry_setup" 103 | with: 104 | python-version: ${{ env.PYTHON_VERSION }} 105 | poetry-version: ${{ env.POETRY_VERSION }} 106 | working-directory: ${{ inputs.working-directory }} 107 | 108 | - name: Import published package 109 | shell: bash 110 | working-directory: ${{ inputs.working-directory }} 111 | env: 112 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 113 | VERSION: ${{ needs.build.outputs.version }} 114 | # Here we use: 115 | # - The default regular PyPI index as the *primary* index, meaning 116 | # that it takes priority (https://pypi.org/simple) 117 | # - The test PyPI index as an extra index, so that any dependencies that 118 | # are not found on test PyPI can be resolved and installed anyway. 119 | # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION 120 | # package because VERSION will not have been uploaded to regular PyPI yet. 121 | # - attempt install again after 5 seconds if it fails because there is 122 | # sometimes a delay in availability on test pypi 123 | run: | 124 | poetry run pip install \ 125 | --extra-index-url https://test.pypi.org/simple/ \ 126 | "$PKG_NAME==$VERSION" || \ 127 | ( \ 128 | sleep 5 && \ 129 | poetry run pip install \ 130 | --extra-index-url https://test.pypi.org/simple/ \ 131 | "$PKG_NAME==$VERSION" \ 132 | ) 133 | 134 | # Replace all dashes in the package name with underscores, 135 | # since that's how Python imports packages with dashes in the name. 
136 | IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" 137 | 138 | poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" 139 | 140 | - name: Import test dependencies 141 | run: poetry install --with test,test_integration 142 | working-directory: ${{ inputs.working-directory }} 143 | 144 | # Overwrite the local version of the package with the test PyPI version. 145 | - name: Import published package (again) 146 | working-directory: ${{ inputs.working-directory }} 147 | shell: bash 148 | env: 149 | PKG_NAME: ${{ needs.build.outputs.pkg-name }} 150 | VERSION: ${{ needs.build.outputs.version }} 151 | run: | 152 | poetry run pip install \ 153 | --extra-index-url https://test.pypi.org/simple/ \ 154 | "$PKG_NAME==$VERSION" 155 | 156 | - name: Run unit tests 157 | run: make tests 158 | working-directory: ${{ inputs.working-directory }} 159 | 160 | - name: Run integration tests 161 | run: make integration_tests 162 | working-directory: ${{ inputs.working-directory }} 163 | 164 | - name: Get minimum versions 165 | working-directory: ${{ inputs.working-directory }} 166 | id: min-version 167 | run: | 168 | poetry run pip install packaging 169 | min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml)" 170 | echo "min-versions=$min_versions" >> "$GITHUB_OUTPUT" 171 | echo "min-versions=$min_versions" 172 | 173 | - name: Run unit tests with minimum dependency versions 174 | if: ${{ steps.min-version.outputs.min-versions != '' }} 175 | env: 176 | MIN_VERSIONS: ${{ steps.min-version.outputs.min-versions }} 177 | run: | 178 | poetry run pip install $MIN_VERSIONS 179 | make tests 180 | working-directory: ${{ inputs.working-directory }} 181 | 182 | publish: 183 | needs: 184 | - build 185 | - test-pypi-publish 186 | - pre-release-checks 187 | runs-on: ubuntu-latest 188 | permissions: 189 | # This permission is used for trusted publishing: 190 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 191 | # 192 | # Trusted publishing has to also be configured on PyPI for each package: 193 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 194 | id-token: write 195 | 196 | defaults: 197 | run: 198 | working-directory: ${{ inputs.working-directory }} 199 | 200 | steps: 201 | - uses: actions/checkout@v4 202 | 203 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 204 | uses: "./.github/actions/poetry_setup" 205 | with: 206 | python-version: ${{ env.PYTHON_VERSION }} 207 | poetry-version: ${{ env.POETRY_VERSION }} 208 | working-directory: ${{ inputs.working-directory }} 209 | cache-key: release 210 | 211 | - uses: actions/download-artifact@v4 212 | with: 213 | name: dist 214 | path: ${{ inputs.working-directory }}/dist/ 215 | 216 | - name: Publish package distributions to PyPI 217 | uses: pypa/gh-action-pypi-publish@release/v1 218 | with: 219 | packages-dir: ${{ inputs.working-directory }}/dist/ 220 | verbose: true 221 | print-hash: true 222 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1\.11\.0 223 | attestations: false 224 | 225 | mark-release: 226 | needs: 227 | - build 228 | - test-pypi-publish 229 | - pre-release-checks 230 | - publish 231 | runs-on: ubuntu-latest 232 | permissions: 233 | # This permission is needed by `ncipollo/release-action` to 234 | # create the GitHub release. 
235 | contents: write 236 | 237 | defaults: 238 | run: 239 | working-directory: ${{ inputs.working-directory }} 240 | 241 | steps: 242 | - uses: actions/checkout@v4 243 | 244 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 245 | uses: "./.github/actions/poetry_setup" 246 | with: 247 | python-version: ${{ env.PYTHON_VERSION }} 248 | poetry-version: ${{ env.POETRY_VERSION }} 249 | working-directory: ${{ inputs.working-directory }} 250 | cache-key: release 251 | 252 | - uses: actions/download-artifact@v4 253 | with: 254 | name: dist 255 | path: ${{ inputs.working-directory }}/dist/ 256 | 257 | - name: Create Release 258 | uses: ncipollo/release-action@v1 259 | with: 260 | artifacts: "dist/*" 261 | token: ${{ secrets.GITHUB_TOKEN }} 262 | draft: false 263 | generateReleaseNotes: true 264 | tag: ${{ inputs.working-directory }}/v${{ needs.build.outputs.version }} 265 | commit: main 266 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.12" 25 | name: "make test #${{ matrix.python-version }}" 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 30 | uses: "./.github/actions/poetry_setup" 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | poetry-version: ${{ env.POETRY_VERSION }} 34 | working-directory: ${{ inputs.working-directory }} 35 | cache-key: core 36 | 37 | - name: Install dependencies 38 | shell: bash 39 | run: poetry install --with test 40 | 41 | - name: Run core tests 42 | shell: bash 43 | run: | 44 | make test 45 | 46 | - name: Ensure the tests did not create any additional files 47 | shell: bash 48 | run: | 49 | set -eu 50 | 51 | STATUS="$(git status)" 52 | echo "$STATUS" 53 | 54 | # grep will exit non-zero if the target message isn't found, 55 | # and `set -e` above will cause the step to fail. 
56 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 57 | -------------------------------------------------------------------------------- /.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | PYTHON_VERSION: "3.10" 14 | 15 | jobs: 16 | build: 17 | if: github.ref == 'refs/heads/main' 18 | runs-on: ubuntu-latest 19 | 20 | outputs: 21 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 22 | version: ${{ steps.check-version.outputs.version }} 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 28 | uses: "./.github/actions/poetry_setup" 29 | with: 30 | python-version: ${{ env.PYTHON_VERSION }} 31 | poetry-version: ${{ env.POETRY_VERSION }} 32 | working-directory: ${{ inputs.working-directory }} 33 | cache-key: release 34 | 35 | # We want to keep this build stage *separate* from the release stage, 36 | # so that there's no sharing of permissions between them. 37 | # The release stage has trusted publishing and GitHub repo contents write access, 38 | # and we want to keep the scope of that access limited just to the release job. 39 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 40 | # could get access to our GitHub or PyPI credentials. 41 | # 42 | # Per the trusted publishing GitHub Action: 43 | # > It is strongly advised to separate jobs for building [...] 44 | # > from the publish job. 45 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 46 | - name: Build project for distribution 47 | run: poetry build 48 | working-directory: ${{ inputs.working-directory }} 49 | 50 | - name: Upload build 51 | uses: actions/upload-artifact@v4 52 | with: 53 | name: test-dist 54 | path: ${{ inputs.working-directory }}/dist/ 55 | 56 | - name: Check Version 57 | id: check-version 58 | shell: bash 59 | working-directory: ${{ inputs.working-directory }} 60 | run: | 61 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 62 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 63 | 64 | publish: 65 | needs: 66 | - build 67 | runs-on: ubuntu-latest 68 | permissions: 69 | # This permission is used for trusted publishing: 70 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 71 | # 72 | # Trusted publishing has to also be configured on PyPI for each package: 73 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 74 | id-token: write 75 | 76 | steps: 77 | - uses: actions/checkout@v4 78 | 79 | - uses: actions/download-artifact@v4 80 | with: 81 | name: test-dist 82 | path: ${{ inputs.working-directory }}/dist/ 83 | 84 | - name: Publish to test PyPI 85 | uses: pypa/gh-action-pypi-publish@release/v1 86 | with: 87 | packages-dir: ${{ inputs.working-directory }}/dist/ 88 | verbose: true 89 | print-hash: true 90 | repository-url: https://test.pypi.org/legacy/ 91 | 92 | # We overwrite any existing distributions with the same name and version. 93 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
94 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 95 | skip-existing: true 96 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0 97 | attestations: false 98 | -------------------------------------------------------------------------------- /.github/workflows/check_diffs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | # If another push to the same PR or branch happens while this workflow is still running, 10 | # cancel the earlier run in favor of the next run. 11 | # 12 | # There's no point in testing an outdated version of the code. GitHub only allows 13 | # a limited number of job runners to be active at the same time, so it's better to cancel 14 | # pointless jobs early so that more useful jobs can run sooner. 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | env: 20 | POETRY_VERSION: "1.7.1" 21 | 22 | jobs: 23 | build: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | - id: files 31 | uses: Ana06/get-changed-files@v2.2.0 32 | - id: set-matrix 33 | run: | 34 | python .github/scripts/check_diff.py ${{ steps.files.outputs.all }} >> $GITHUB_OUTPUT 35 | outputs: 36 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }} 37 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }} 38 | lint: 39 | name: cd ${{ matrix.working-directory }} 40 | needs: [ build ] 41 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }} 42 | strategy: 43 | matrix: 44 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }} 45 | uses: ./.github/workflows/_lint.yml 46 | with: 47 | working-directory: ${{ matrix.working-directory }} 48 | secrets: inherit 49 | 50 | test: 51 | name: cd ${{ matrix.working-directory }} 52 | needs: [ build ] 53 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 54 | strategy: 55 | matrix: 56 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 57 | uses: ./.github/workflows/_test.yml 58 | with: 59 | working-directory: ${{ matrix.working-directory }} 60 | secrets: inherit 61 | 62 | compile-integration-tests: 63 | name: cd ${{ matrix.working-directory }} 64 | needs: [ build ] 65 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 66 | strategy: 67 | matrix: 68 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 69 | uses: ./.github/workflows/_compile_integration_test.yml 70 | with: 71 | working-directory: ${{ matrix.working-directory }} 72 | secrets: inherit 73 | ci_success: 74 | name: "CI Success" 75 | needs: [build, lint, test, compile-integration-tests] 76 | if: | 77 | always() 78 | runs-on: ubuntu-latest 79 | env: 80 | JOBS_JSON: ${{ toJSON(needs) }} 81 | RESULTS_JSON: ${{ toJSON(needs.*.result) }} 82 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}} 83 | steps: 84 | - name: "CI Success" 85 | run: | 86 | echo $JOBS_JSON 87 | echo $RESULTS_JSON 88 | echo "Exiting with $EXIT_CODE" 89 | exit $EXIT_CODE 90 | -------------------------------------------------------------------------------- /.github/workflows/extract_ignored_words_list.py: -------------------------------------------------------------------------------- 1 | import toml 2 | 3 | pyproject_toml = toml.load("pyproject.toml") 4 | 5 | # 
Extract the ignore words list (adjust the key as per your TOML structure) 6 | ignore_words_list = ( 7 | pyproject_toml.get("tool", {}).get("codespell", {}).get("ignore-words-list") 8 | ) 9 | 10 | print(f"::set-output name=ignore_words_list::{ignore_words_list}") 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | .pytest_cache 4 | .ruff_cache 5 | .mypy_cache_test 6 | .env 7 | .venv* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜️🔗 LangChain Azure 2 | 3 | This repository contains the following packages with Azure integrations with LangChain: 4 | 5 | - [langchain-azure-ai](https://pypi.org/project/langchain-azure-ai/) 6 | - [langchain-azure-dynamic-sessions](https://pypi.org/project/langchain-azure-dynamic-sessions/) 7 | - [langchain-sqlserver](https://pypi.org/project/langchain-sqlserver/) 8 | 9 | **Note**: This repository will replace all Azure integrations currently present in the `langchain-community` package. Users are encouraged to migrate to this repository as soon as possible. 10 | 11 | # Quick Start with langchain-azure-ai 12 | 13 | The `langchain-azure-ai` package uses the [Azure AI Foundry SDK](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/sdk-overview?tabs=sync&pivots=programming-language-python). This means you can use the package with a range of models including AzureOpenAI, Cohere, Llama, Phi-3/4, and DeepSeek-R1 to name a few. 14 | 15 | Here's a quick start example to show you how to get started with the Chat Completions model. For more details and tutorials see [Develop with LangChain and LangGraph and models from Azure AI Foundry](https://aka.ms/azureai/langchain). 
16 | 17 | ### Install langchain-azure 18 | 19 | ```bash 20 | pip install -U langchain-azure-ai 21 | ``` 22 | 23 | ### Azure AI Chat Completions Model with Azure OpenAI 24 | 25 | ```python 26 | 27 | from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel 28 | from langchain_core.messages import HumanMessage, SystemMessage 29 | 30 | model = AzureAIChatCompletionsModel( 31 | endpoint="https://{your-resource-name}.services.ai.azure.com/openai/v1", 32 | credential="your-api-key", #if using Entra ID you should use DefaultAzureCredential() instead 33 | model="gpt-4o" 34 | ) 35 | 36 | messages = [ 37 | SystemMessage( 38 | content="Translate the following from English into Italian" 39 | ), 40 | HumanMessage(content="hi!"), 41 | ] 42 | 43 | model.invoke(messages) 44 | ``` 45 | 46 | ```python 47 | AIMessage(content='Ciao!', additional_kwargs={}, response_metadata={'model': 'gpt-4o', 'token_usage': {'input_tokens': 20, 'output_tokens': 3, 'total_tokens': 23}, 'finish_reason': 'stop'}, id='run-0758e7ec-99cd-440b-bfa2-3a1078335133-0', usage_metadata={'input_tokens': 20, 'output_tokens': 3, 'total_tokens': 23}) 48 | ``` 49 | 50 | ### Azure AI Chat Completions Model with DeepSeek-R1 51 | 52 | ```python 53 | 54 | from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel 55 | from langchain_core.messages import HumanMessage, SystemMessage 56 | 57 | model = AzureAIChatCompletionsModel( 58 | endpoint="https://{your-resource-name}.services.ai.azure.com/models", 59 | credential="your-api-key", #if using Entra ID you should use DefaultAzureCredential() instead 60 | model="DeepSeek-R1", 61 | ) 62 | 63 | messages = [ 64 | HumanMessage(content="Translate the following from English into Italian: \"hi!\"") 65 | ] 66 | 67 | message_stream = model.stream(messages) 68 | print(' '.join(chunk.content for chunk in message_stream)) 69 | ``` 70 | 71 | ```python 72 | 73 | Okay , the user just sent " hi !" and I need to translate that into Italian . Let me think . " Hi " is an informal greeting , so in Italian , the equivalent would be " C iao !" But wait , there are other options too . Sometimes people use " Sal ve ," which is a bit more neutral , but " C iao " is more common in casual settings . The user probably wants a straightforward translation , so " C iao !" is the safest bet here . Let me double -check to make sure there 's no nuance I 'm missing . N ope , " C iao " is definitely the right choice for translating " hi !" in an informal context . I 'll go with that . 74 | 75 | 76 | C iao ! 77 | ``` 78 | 79 | 80 | # Welcome Contributors 81 | 82 | Hi there! Thank you for even being interested in contributing to LangChain-Azure. 83 | As an open-source project in a rapidly developing field, we are extremely open to contributions, whether they involve new features, improved infrastructure, better documentation, or bug fixes. 84 | 85 | 86 | # Contribute Code 87 | 88 | To contribute to this project, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow. 89 | 90 | Please follow the checked-in pull request template when opening pull requests. Note related issues and tag relevant 91 | maintainers. 92 | 93 | Pull requests cannot land without passing the formatting, linting, and testing checks first. See [Testing](#testing) and 94 | [Formatting and Linting](#formatting-and-linting) for how to run these checks locally. 95 | 96 | It's essential that we maintain great documentation and testing.
If you: 97 | - Fix a bug 98 | - Add a relevant unit or integration test when possible. 99 | - Make an improvement 100 | - Update unit and integration tests when relevant. 101 | - Add a feature 102 | - Add unit and integration tests. 103 | 104 | If there's something you'd like to add or change, opening a pull request is the 105 | best way to get our attention. Please tag one of our maintainers for review. 106 | 107 | ## Dependency Management: Poetry and other env/dependency managers 108 | 109 | This project utilizes [Poetry](https://python-poetry.org/) v1.7.1+ as a dependency manager. 110 | 111 | ❗Note: *Before installing Poetry*, if you use `Conda`, create and activate a new Conda env (e.g. `conda create -n langchain python=3.9`) 112 | 113 | Install Poetry: **[documentation on how to install it](https://python-poetry.org/docs/#installation)**. 114 | 115 | ❗Note: If you use `Conda` or `Pyenv` as your environment/package manager, after installing Poetry, 116 | tell Poetry to use the virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`) 117 | 118 | ## Different packages 119 | 120 | This repository contains three packages with Azure integrations with LangChain: 121 | - [langchain-azure-ai](https://pypi.org/project/langchain-azure-ai/) 122 | - [langchain-azure-dynamic-sessions](https://pypi.org/project/langchain-azure-dynamic-sessions/) 123 | - [langchain-sqlserver](https://pypi.org/project/langchain-sqlserver/) 124 | 125 | Each of these has its own development environment. Docs are run from the top-level makefile, but development 126 | is split across separate test & release flows. 127 | 128 | ## Repository Structure 129 | 130 | If you plan on contributing to LangChain-Azure code or documentation, it can be useful 131 | to understand the high level structure of the repository. 132 | 133 | LangChain-Azure is organized as a [monorepo](https://en.wikipedia.org/wiki/Monorepo) that contains multiple packages. 134 | 135 | Here's the structure visualized as a tree: 136 | 137 | ```text 138 | . 139 | ├── libs 140 | │ ├── azure-ai 141 | │ ├── azure-dynamic-sessions 142 | │ ├── sqlserver 143 | ``` 144 | 145 | ## Local Development Dependencies 146 | 147 | Install development requirements (for running langchain, running examples, linting, formatting, tests, and coverage): 148 | 149 | ```bash 150 | poetry install --with lint,typing,test,test_integration 151 | ``` 152 | 153 | Then verify dependency installation: 154 | 155 | ```bash 156 | make test 157 | ``` 158 | 159 | If during installation you receive a `WheelFileValidationError` for `debugpy`, please make sure you are running 160 | Poetry v1.6.1+. This bug was present in older versions of Poetry (e.g. 1.4.1) and has been resolved in newer releases. 161 | If you are still seeing this bug on v1.6.1+, you may also try disabling "modern installation" 162 | (`poetry config installer.modern-installation false`) and re-installing requirements. 163 | See [this `debugpy` issue](https://github.com/microsoft/debugpy/issues/1246) for more details. 164 | 165 | ## Code Formatting 166 | 167 | Formatting for this project is done via [ruff](https://docs.astral.sh/ruff/rules/).
168 | 169 | To run formatting for a library, run the same command from the relevant library directory: 170 | 171 | ```bash 172 | cd libs/{LIBRARY} 173 | make format 174 | ``` 175 | 176 | Additionally, you can run the formatter only on the files that have been modified in your current branch as compared to the master branch using the format_diff command: 177 | 178 | ```bash 179 | make format_diff 180 | ``` 181 | 182 | This is especially useful when you have made changes to a subset of the project and want to ensure your changes are properly formatted without affecting the rest of the codebase. 183 | 184 | ## Linting 185 | 186 | Linting for this project is done via a combination of [ruff](https://docs.astral.sh/ruff/rules/) and [mypy](http://mypy-lang.org/). 187 | 188 | To run linting for docs, cookbook and templates: 189 | 190 | ```bash 191 | make lint 192 | ``` 193 | 194 | To run linting for a library, run the same command from the relevant library directory: 195 | 196 | ```bash 197 | cd libs/{LIBRARY} 198 | make lint 199 | ``` 200 | 201 | In addition, you can run the linter only on the files that have been modified in your current branch as compared to the master branch using the lint_diff command: 202 | 203 | ```bash 204 | make lint_diff 205 | ``` 206 | 207 | This can be very helpful when you've made changes to only certain parts of the project and want to ensure your changes meet the linting standards without having to check the entire codebase. 208 | 209 | We recognize linting can be annoying - if you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed. 210 | 211 | ## Spellcheck 212 | 213 | Spellchecking for this project is done via [codespell](https://github.com/codespell-project/codespell). 214 | Note that `codespell` finds common typos, so it could have false-positive (correctly spelled but rarely used) and false-negatives (not finding misspelled) words. 215 | 216 | To check spelling for this project: 217 | 218 | ```bash 219 | make spell_check 220 | ``` 221 | 222 | To fix spelling in place: 223 | 224 | ```bash 225 | make spell_fix 226 | ``` 227 | 228 | If codespell is incorrectly flagging a word, you can skip spellcheck for that word by adding it to the codespell config in the `pyproject.toml` file. 229 | 230 | ```python 231 | [tool.codespell] 232 | ... 233 | # Add here: 234 | ignore-words-list =... 235 | ``` 236 | 237 | ## Testing 238 | 239 | All of our packages have unit tests and integration tests, and we favor unit tests over integration tests. 240 | 241 | Unit tests run on every pull request, so they should be fast and reliable. 242 | 243 | Integration tests run once a day, and they require more setup, so they should be reserved for confirming interface points with external services. 244 | 245 | ### Unit Tests 246 | 247 | Unit tests cover modular logic that does not require calls to outside APIs. 248 | If you add new logic, please add a unit test. 249 | In unit tests we check pre/post processing and mocking all external dependencies. 
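To make that concrete, here is a minimal sketch of the kind of unit test meant above, using `unittest.mock.patch` so the test exercises only the pre/post-processing logic and never touches the network. The `service_health` helper (and its use of the `requests` package) is a hypothetical stand-in for illustration; it is not part of this repository.

```python
from unittest import mock

import requests


def service_health(url: str) -> str:
    """Hypothetical helper: call an external endpoint and post-process its JSON."""
    payload = requests.get(url, timeout=10).json()
    return "healthy" if payload.get("status") == "ok" else "degraded"


def test_service_health_reports_healthy_when_status_ok() -> None:
    # Stub the outbound HTTP call so the unit test stays fast, deterministic, and offline.
    fake_response = mock.Mock()
    fake_response.json.return_value = {"status": "ok"}

    with mock.patch("requests.get", return_value=fake_response) as mocked_get:
        assert service_health("https://example.invalid/health") == "healthy"

    # What the test verified is the post-processing logic, not the network call.
    mocked_get.assert_called_once_with("https://example.invalid/health", timeout=10)
```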
250 | 251 | To install dependencies for unit tests: 252 | 253 | ```bash 254 | poetry install --with test 255 | ``` 256 | 257 | To run unit tests: 258 | 259 | ```bash 260 | make test 261 | ``` 262 | 263 | To run unit tests in Docker: 264 | 265 | ```bash 266 | make docker_tests 267 | ``` 268 | 269 | To run a specific test: 270 | 271 | ```bash 272 | TEST_FILE=tests/unit_tests/test_imports.py make test 273 | ``` 274 | 275 | ### Integration Tests 276 | 277 | Integration tests cover logic that requires making calls to outside APIs (often integration with other services). 278 | If you add support for a new external API, please add a new integration test. 279 | 280 | **Warning:** Almost no tests should be integration tests. 281 | 282 | Tests that require making network connections make it difficult for other 283 | developers to test the code. 284 | 285 | Instead favor relying on `responses` library and/or mock.patch to mock 286 | requests using small fixtures. 287 | 288 | To install dependencies for integration tests: 289 | 290 | ```bash 291 | poetry install --with test,test_integration 292 | ``` 293 | 294 | To run integration tests: 295 | 296 | ```bash 297 | make integration_tests 298 | ``` 299 | 300 | 301 | For detailed information on how to contribute, see [LangChain contribution guide](https://python.langchain.com/docs/contributing/). 302 | 303 | -------------------------------------------------------------------------------- /libs/azure-ai/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ruff_cache 3 | .pytest_cache 4 | -------------------------------------------------------------------------------- /libs/azure-ai/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/azure-ai/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 
7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | test: 10 | poetry run pytest $(TEST_FILE) 11 | 12 | tests: 13 | poetry run pytest $(TEST_FILE) 14 | 15 | test_watch: 16 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 17 | 18 | 19 | ###################### 20 | # LINTING AND FORMATTING 21 | ###################### 22 | 23 | # Define a variable for Python and notebook files. 24 | PYTHON_FILES=. 25 | MYPY_CACHE=.mypy_cache 26 | lint format: PYTHON_FILES=. 27 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/azure-ai --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 28 | lint_package: PYTHON_FILES=langchain_azure_ai 29 | lint_tests: PYTHON_FILES=tests 30 | lint_tests: MYPY_CACHE=.mypy_cache_test 31 | 32 | lint lint_diff lint_package lint_tests: 33 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES) --fix 34 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff 35 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 36 | 37 | format format_diff: 38 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) 39 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES) 40 | 41 | spell_check: 42 | poetry run codespell --toml pyproject.toml 43 | 44 | spell_fix: 45 | poetry run codespell --toml pyproject.toml -w 46 | 47 | check_imports: $(shell find langchain_azure_ai -name '*.py') 48 | poetry run python ./scripts/check_imports.py $^ 49 | 50 | ###################### 51 | # HELP 52 | ###################### 53 | 54 | help: 55 | @echo '----' 56 | @echo 'check_imports - check imports' 57 | @echo 'format - run code formatters' 58 | @echo 'lint - run linters' 59 | @echo 'test - run unit tests' 60 | @echo 'tests - run unit tests' 61 | @echo 'test TEST_FILE= - run all tests in file' 62 | -------------------------------------------------------------------------------- /libs/azure-ai/README.md: -------------------------------------------------------------------------------- 1 | # langchain-azure-ai 2 | 3 | This package contains the LangChain integration for Azure AI Foundry. To learn more about how to use this package, see the LangChain documentation in [Azure AI Foundry](https://aka.ms/azureai/langchain). 4 | 5 | > [!NOTE] 6 | > This package is in Public Preview. For more information, see [Supplemental Terms of Use for Microsoft Azure Previews](https://azure.microsoft.com/support/legal/preview-supplemental-terms/). 7 | 8 | ## Installation 9 | 10 | ```bash 11 | pip install -U langchain-azure-ai 12 | ``` 13 | 14 | To use the tracing capabilities with OpenTelemetry, install the package with the `opentelemetry` extra: 15 | 16 | ```bash 17 | pip install -U langchain-azure-ai[opentelemetry] 18 | ``` 19 | 20 | ## Changelog 21 | 22 | - **0.1.4**: 23 | 24 | - Bug fix [#91](https://github.com/langchain-ai/langchain-azure/pull/91). 25 | 26 | - **0.1.3**: 27 | 28 | - **[Breaking change]:** We renamed the parameter `model_name` in `AzureAIEmbeddingsModel` and `AzureAIChatCompletionsModel` to `model`, which is the parameter expected by the method `langchain.chat_models.init_chat_model`. 29 | - We fixed an issue with JSON mode in chat models [#81](https://github.com/langchain-ai/langchain-azure/issues/81). 30 | - We fixed the dependencies for NumPy [#70](https://github.com/langchain-ai/langchain-azure/issues/70). 31 | - We fixed an issue when tracing Pydantic objects in the inputs [#65](https://github.com/langchain-ai/langchain-azure/issues/65).
32 | - We made `connection_string` parameter optional as suggested at [#65](https://github.com/langchain-ai/langchain-azure/issues/65). 33 | 34 | - **0.1.2**: 35 | 36 | - Bug fix [#35](https://github.com/langchain-ai/langchain-azure/issues/35). 37 | 38 | - **0.1.1**: 39 | 40 | - Adding `AzureCosmosDBNoSqlVectorSearch` and `AzureCosmosDBNoSqlSemanticCache` for vector search and full text search. 41 | - Adding `AzureCosmosDBMongoVCoreVectorSearch` and `AzureCosmosDBMongoVCoreSemanticCache` for vector search. 42 | - You can now create `AzureAIEmbeddingsModel` and `AzureAIChatCompletionsModel` clients directly from your AI project's connection string using the parameter `project_connection_string`. Your default Azure AI Services connection is used to find the model requested. This requires to have `azure-ai-projects` package installed. 43 | - Support for native LLM structure outputs. Use `with_structured_output(method="json_schema")` to use native structured schema support. Use `with_structured_output(method="json_mode")` to use native JSON outputs capabilities. By default, LangChain uses `method="function_calling"` which uses tool calling capabilities to generate valid structure JSON payloads. This requires to have `azure-ai-inference >= 1.0.0b7`. 44 | - Bug fix [#18](https://github.com/langchain-ai/langchain-azure/issues/18) and [#31](https://github.com/langchain-ai/langchain-azure/issues/31). 45 | 46 | - **0.1.0**: 47 | 48 | - Introduce `AzureAIEmbeddingsModel` for embedding generation and `AzureAIChatCompletionsModel` for chat completions generation using the Azure AI Inference API. This client also supports GitHub Models endpoint. 49 | - Introduce `AzureAIInferenceTracer` for tracing with OpenTelemetry and Azure Application Insights. 50 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/__init__.py: -------------------------------------------------------------------------------- 1 | """LangChain integrations for Azure AI.""" 2 | 3 | from importlib import metadata 4 | 5 | try: 6 | __version__ = metadata.version(__package__) 7 | except metadata.PackageNotFoundError: 8 | # Case where package metadata is not available 9 | __version__ = "" 10 | del metadata # optional, avoids polluting the results of dir(__package__) 11 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | """Callables for Azure AI.""" 2 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/callbacks/tracers/__init__.py: -------------------------------------------------------------------------------- 1 | """Tracing capabilities for Azure AI Foundry.""" 2 | 3 | from langchain_azure_ai.callbacks.tracers.inference_tracing import ( 4 | AzureAIInferenceTracer, 5 | ) 6 | 7 | __all__ = ["AzureAIInferenceTracer"] 8 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/callbacks/tracers/_semantic_conventions_gen_ai.py: -------------------------------------------------------------------------------- 1 | GEN_AI_MESSAGE_ID = "gen_ai.message.id" 2 | GEN_AI_MESSAGE_STATUS = "gen_ai.message.status" 3 | GEN_AI_THREAD_ID = "gen_ai.thread.id" 4 | GEN_AI_THREAD_RUN_ID = "gen_ai.thread.run.id" 5 | GEN_AI_AGENT_ID = "gen_ai.agent.id" 6 | GEN_AI_AGENT_NAME = "gen_ai.agent.name" 7 | 
GEN_AI_AGENT_DESCRIPTION = "gen_ai.agent.description" 8 | GEN_AI_OPERATION_NAME = "gen_ai.operation.name" 9 | GEN_AI_THREAD_RUN_STATUS = "gen_ai.thread.run.status" 10 | GEN_AI_REQUEST_MODEL = "gen_ai.request.model" 11 | GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature" 12 | GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p" 13 | GEN_AI_REQUEST_MAX_INPUT_TOKENS = "gen_ai.request.max_input_tokens" 14 | GEN_AI_REQUEST_MAX_OUTPUT_TOKENS = "gen_ai.request.max_output_tokens" 15 | GEN_AI_RESPONSE_MODEL = "gen_ai.response.model" 16 | GEN_AI_SYSTEM = "gen_ai.system" 17 | SERVER_ADDRESS = "server.address" 18 | AZ_AI_AGENT_SYSTEM = "az.ai.agents" 19 | GEN_AI_TOOL_NAME = "gen_ai.tool.name" 20 | GEN_AI_TOOL_CALL_ID = "gen_ai.tool.call.id" 21 | GEN_AI_REQUEST_RESPONSE_FORMAT = "gen_ai.request.response_format" 22 | GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" 23 | GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" 24 | GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" 25 | GEN_AI_SYSTEM_MESSAGE = "gen_ai.system.message" 26 | GEN_AI_EVENT_CONTENT = "gen_ai.event.content" 27 | ERROR_TYPE = "error.type" 28 | INPUTS = "inputs" 29 | OUTPUTS = "outputs" 30 | TAGS = "tags" 31 | GEN_AI_GENERATED_MESSAGE = "gen_ai.generated_message" 32 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/chat_message_histories/__init__.py: -------------------------------------------------------------------------------- 1 | """**Chat message history** stores a history of the message interactions in a chat. 2 | 3 | **Class hierarchy:** 4 | 5 | .. code-block:: 6 | 7 | BaseChatMessageHistory --> ChatMessageHistory # Examples: CosmosDBChatMessageHistory 8 | 9 | **Main helpers:** 10 | 11 | .. code-block:: 12 | 13 | AIMessage, HumanMessage, BaseMessage 14 | 15 | """ # noqa: E501 16 | 17 | import importlib 18 | from typing import TYPE_CHECKING, Any 19 | 20 | if TYPE_CHECKING: 21 | from langchain_azure_ai.chat_message_histories.cosmos_db import ( 22 | CosmosDBChatMessageHistory, 23 | ) 24 | 25 | __all__ = [ 26 | "CosmosDBChatMessageHistory", 27 | ] 28 | 29 | _module_lookup = { 30 | "CosmosDBChatMessageHistory": "langchain_azure_ai.chat_message_histories.cosmos_db", 31 | } 32 | 33 | 34 | def __getattr__(name: str) -> Any: 35 | if name in _module_lookup: 36 | module = importlib.import_module(_module_lookup[name]) 37 | return getattr(module, name) 38 | raise AttributeError(f"module {__name__} has no attribute {name}") 39 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/chat_message_histories/cosmos_db.py: -------------------------------------------------------------------------------- 1 | """Azure CosmosDB Memory History.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | from types import TracebackType 7 | from typing import TYPE_CHECKING, Any, List, Optional, Type 8 | 9 | from langchain_core.chat_history import BaseChatMessageHistory 10 | from langchain_core.messages import ( 11 | BaseMessage, 12 | messages_from_dict, 13 | messages_to_dict, 14 | ) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | if TYPE_CHECKING: 19 | from azure.cosmos import ContainerProxy 20 | 21 | USER_AGENT = ("LangChainAzure-CDBNoSql-ChatHistory-Python",) 22 | 23 | 24 | class CosmosDBChatMessageHistory(BaseChatMessageHistory): 25 | """Chat message history backed by Azure CosmosDB.""" 26 | 27 | def __init__( 28 | self, 29 | cosmos_endpoint: str, 30 | cosmos_database: str, 31 | 
cosmos_container: str, 32 | session_id: str, 33 | user_id: str, 34 | credential: Any = None, 35 | connection_string: Optional[str] = None, 36 | ttl: Optional[int] = None, 37 | cosmos_client_kwargs: Optional[dict] = None, 38 | ): 39 | """Initializes a new instance of the CosmosDBChatMessageHistory class. 40 | 41 | Make sure to call prepare_cosmos or use the context manager to make 42 | sure your database is ready. 43 | 44 | Either a credential or a connection string must be provided. 45 | 46 | :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account. 47 | :param cosmos_database: The name of the database to use. 48 | :param cosmos_container: The name of the container to use. 49 | :param session_id: The session ID to use, can be overwritten while loading. 50 | :param user_id: The user ID to use, can be overwritten while loading. 51 | :param credential: The credential to use to authenticate to Azure Cosmos DB. 52 | :param connection_string: The connection string to use to authenticate. 53 | :param ttl: The time to live (in seconds) to use for documents in the container. 54 | :param cosmos_client_kwargs: Additional kwargs to pass to the CosmosClient. 55 | """ 56 | self.cosmos_endpoint = cosmos_endpoint 57 | self.cosmos_database = cosmos_database 58 | self.cosmos_container = cosmos_container 59 | self.credential = credential 60 | self.conn_string = connection_string 61 | self.session_id = session_id 62 | self.user_id = user_id 63 | self.ttl = ttl 64 | 65 | self.messages: List[BaseMessage] = [] 66 | try: 67 | from azure.cosmos import ( # pylint: disable=import-outside-toplevel 68 | CosmosClient, 69 | ) 70 | except ImportError as exc: 71 | raise ImportError( 72 | "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 73 | "Please install it with `pip install azure-cosmos`." 74 | ) from exc 75 | if self.credential: 76 | self._client = CosmosClient( 77 | url=self.cosmos_endpoint, 78 | credential=self.credential, 79 | user_agent=USER_AGENT, 80 | **cosmos_client_kwargs or {}, 81 | ) 82 | elif self.conn_string: 83 | self._client = CosmosClient.from_connection_string( 84 | conn_str=self.conn_string, 85 | user_agent=USER_AGENT, 86 | **cosmos_client_kwargs or {}, 87 | ) 88 | else: 89 | raise ValueError("Either a connection string or a credential must be set.") 90 | self._container: Optional[ContainerProxy] = None 91 | 92 | def prepare_cosmos(self) -> None: 93 | """Prepare the CosmosDB client. 94 | 95 | Use this function or the context manager to make sure your database is ready. 96 | """ 97 | try: 98 | from azure.cosmos import ( # pylint: disable=import-outside-toplevel 99 | PartitionKey, 100 | ) 101 | except ImportError as exc: 102 | raise ImportError( 103 | "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 104 | "Please install it with `pip install azure-cosmos`." 
105 | ) from exc 106 | database = self._client.create_database_if_not_exists(self.cosmos_database) 107 | self._container = database.create_container_if_not_exists( 108 | self.cosmos_container, 109 | partition_key=PartitionKey("/user_id"), 110 | default_ttl=self.ttl, 111 | ) 112 | self.load_messages() 113 | 114 | def __enter__(self) -> "CosmosDBChatMessageHistory": 115 | """Context manager entry point.""" 116 | self._client.__enter__() 117 | self.prepare_cosmos() 118 | return self 119 | 120 | def __exit__( 121 | self, 122 | exc_type: Optional[Type[BaseException]], 123 | exc_val: Optional[BaseException], 124 | traceback: Optional[TracebackType], 125 | ) -> None: 126 | """Context manager exit.""" 127 | self.upsert_messages() 128 | self._client.__exit__(exc_type, exc_val, traceback) 129 | 130 | def load_messages(self) -> None: 131 | """Retrieve the messages from Cosmos.""" 132 | if not self._container: 133 | raise ValueError("Container not initialized") 134 | try: 135 | from azure.cosmos.exceptions import ( # pylint: disable=import-outside-toplevel 136 | CosmosHttpResponseError, 137 | ) 138 | except ImportError as exc: 139 | raise ImportError( 140 | "You must install the azure-cosmos package to use the CosmosDBChatMessageHistory." # noqa: E501 141 | "Please install it with `pip install azure-cosmos`." 142 | ) from exc 143 | try: 144 | item = self._container.read_item( 145 | item=self.session_id, partition_key=self.user_id 146 | ) 147 | except CosmosHttpResponseError: 148 | logger.info("no session found") 149 | return 150 | if "messages" in item and len(item["messages"]) > 0: 151 | self.messages = messages_from_dict(item["messages"]) 152 | 153 | def add_message(self, message: BaseMessage) -> None: 154 | """Add a self-created message to the store.""" 155 | self.messages.append(message) 156 | self.upsert_messages() 157 | 158 | def upsert_messages(self) -> None: 159 | """Update the cosmosdb item.""" 160 | if not self._container: 161 | raise ValueError("Container not initialized") 162 | self._container.upsert_item( 163 | body={ 164 | "id": self.session_id, 165 | "user_id": self.user_id, 166 | "messages": messages_to_dict(self.messages), 167 | } 168 | ) 169 | 170 | def clear(self) -> None: 171 | """Clear session memory from this memory and cosmos.""" 172 | self.messages = [] 173 | if self._container: 174 | self._container.delete_item( 175 | item=self.session_id, partition_key=self.user_id 176 | ) 177 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/chat_models/__init__.py: -------------------------------------------------------------------------------- 1 | """Chat completions model for Azure AI.""" 2 | 3 | from langchain_azure_ai.chat_models.inference import AzureAIChatCompletionsModel 4 | 5 | __all__ = ["AzureAIChatCompletionsModel"] 6 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | """Embedding model for Azure AI.""" 2 | 3 | from langchain_azure_ai.embeddings.inference import AzureAIEmbeddingsModel 4 | 5 | __all__ = ["AzureAIEmbeddingsModel"] 6 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/embeddings/inference.py: -------------------------------------------------------------------------------- 1 | """Azure AI embeddings model inference API.""" 2 | 3 | import logging 4 | from typing 
import ( 5 | Any, 6 | AsyncGenerator, 7 | Dict, 8 | Generator, 9 | Mapping, 10 | Optional, 11 | Union, 12 | ) 13 | 14 | from azure.ai.inference import EmbeddingsClient 15 | from azure.ai.inference.aio import EmbeddingsClient as EmbeddingsClientAsync 16 | from azure.ai.inference.models import EmbeddingInputType 17 | from azure.core.credentials import AzureKeyCredential, TokenCredential 18 | from azure.core.exceptions import HttpResponseError 19 | from langchain_core.embeddings import Embeddings 20 | from langchain_core.utils import get_from_dict_or_env, pre_init 21 | from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator 22 | 23 | from langchain_azure_ai.utils.utils import get_endpoint_from_project 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class AzureAIEmbeddingsModel(BaseModel, Embeddings): 29 | """Azure AI model inference for embeddings. 30 | 31 | Examples: 32 | .. code-block:: python 33 | from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel 34 | 35 | embed_model = AzureAIEmbeddingsModel( 36 | endpoint="https://[your-endpoint].inference.ai.azure.com", 37 | credential="your-api-key", 38 | ) 39 | 40 | If your endpoint supports multiple models, indicate the parameter `model_name`: 41 | 42 | .. code-block:: python 43 | from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel 44 | 45 | embed_model = AzureAIEmbeddingsModel( 46 | endpoint="https://[your-service].services.ai.azure.com/models", 47 | credential="your-api-key", 48 | model="cohere-embed-v3-multilingual" 49 | ) 50 | 51 | Troubleshooting: 52 | To diagnostic issues with the model, you can enable debug logging: 53 | 54 | .. code-block:: python 55 | import sys 56 | import logging 57 | from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel 58 | 59 | logger = logging.getLogger("azure") 60 | 61 | # Set the desired logging level. 62 | logger.setLevel(logging.DEBUG) 63 | 64 | handler = logging.StreamHandler(stream=sys.stdout) 65 | logger.addHandler(handler) 66 | 67 | model = AzureAIEmbeddingsModel( 68 | endpoint="https://[your-service].services.ai.azure.com/models", 69 | credential="your-api-key", 70 | model="cohere-embed-v3-multilingual", 71 | client_kwargs={ "logging_enable": True } 72 | ) 73 | """ 74 | 75 | model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) 76 | 77 | project_connection_string: Optional[str] = None 78 | """The connection string to use for the Azure AI project. If this is specified, 79 | then the `endpoint` parameter becomes optional and `credential` has to be of type 80 | `TokenCredential`.""" 81 | 82 | endpoint: Optional[str] = None 83 | """The endpoint URI where the model is deployed. Either this or the 84 | `project_connection_string` parameter must be specified.""" 85 | 86 | credential: Union[str, AzureKeyCredential, TokenCredential] 87 | """The API key or credential to use for the Azure AI model inference.""" 88 | 89 | api_version: Optional[str] = None 90 | """The API version to use for the Azure AI model inference API. If None, the 91 | default version is used.""" 92 | 93 | model_name: Optional[str] = Field(default=None, alias="model") 94 | """The name of the model to use for inference, if the endpoint is running more 95 | than one model. If not, this parameter is ignored.""" 96 | 97 | embed_batch_size: int = 1024 98 | """The batch size for embedding requests. The default is 1024.""" 99 | 100 | dimensions: Optional[int] = None 101 | """The number of dimensions in the embeddings to generate. 
If None, the model's 102 | default is used.""" 103 | 104 | model_kwargs: Dict[str, Any] = {} 105 | """Additional kwargs model parameters.""" 106 | 107 | client_kwargs: Dict[str, Any] = {} 108 | """Additional kwargs for the Azure AI client used.""" 109 | 110 | _client: EmbeddingsClient = PrivateAttr() 111 | _async_client: EmbeddingsClientAsync = PrivateAttr() 112 | _embed_input_type: Optional[EmbeddingInputType] = PrivateAttr() 113 | _model_name: Optional[str] = PrivateAttr() 114 | 115 | @pre_init 116 | def validate_environment(cls, values: Dict) -> Any: 117 | """Validate that api key exists in environment.""" 118 | values["endpoint"] = get_from_dict_or_env( 119 | values, "endpoint", "AZURE_INFERENCE_ENDPOINT" 120 | ) 121 | values["credential"] = get_from_dict_or_env( 122 | values, "credential", "AZURE_INFERENCE_CREDENTIAL" 123 | ) 124 | 125 | if values["api_version"]: 126 | values["client_kwargs"]["api_version"] = values["api_version"] 127 | 128 | return values 129 | 130 | @model_validator(mode="after") 131 | def initialize_client(self) -> "AzureAIEmbeddingsModel": 132 | """Initialize the Azure AI model inference client.""" 133 | if self.project_connection_string: 134 | if not isinstance(self.credential, TokenCredential): 135 | raise ValueError( 136 | "When using the `project_connection_string` parameter, the " 137 | "`credential` parameter must be of type `TokenCredential`." 138 | ) 139 | self.endpoint, self.credential = get_endpoint_from_project( 140 | self.project_connection_string, self.credential 141 | ) 142 | 143 | credential = ( 144 | AzureKeyCredential(self.credential) 145 | if isinstance(self.credential, str) 146 | else self.credential 147 | ) 148 | 149 | self._client = EmbeddingsClient( 150 | endpoint=self.endpoint, # type: ignore[arg-type] 151 | credential=credential, # type: ignore[arg-type] 152 | model=self.model_name, 153 | user_agent="langchain-azure-ai", 154 | **self.client_kwargs, 155 | ) 156 | 157 | self._async_client = EmbeddingsClientAsync( 158 | endpoint=self.endpoint, # type: ignore[arg-type] 159 | credential=credential, # type: ignore[arg-type] 160 | model=self.model_name, 161 | user_agent="langchain-azure-ai", 162 | **self.client_kwargs, 163 | ) 164 | 165 | if not self.model_name: 166 | try: 167 | # Get model info from the endpoint. This method may not be supported 168 | # by all endpoints. 169 | model_info = self._client.get_model_info() 170 | self._model_name = model_info.get("model_name", None) 171 | self._embed_input_type = ( 172 | None 173 | if model_info.get("model_provider_name", None).lower() == "cohere" 174 | else EmbeddingInputType.TEXT 175 | ) 176 | except HttpResponseError: 177 | logger.warning( 178 | f"Endpoint '{self.endpoint}' does not support model metadata " 179 | "retrieval. Unable to populate model attributes." 
180 | ) 181 | self._model_name = "" 182 | self._embed_input_type = EmbeddingInputType.TEXT 183 | else: 184 | self._embed_input_type = ( 185 | None if "cohere" in self.model_name.lower() else EmbeddingInputType.TEXT 186 | ) 187 | 188 | return self 189 | 190 | def _get_model_params(self, **kwargs: Dict[str, Any]) -> Mapping[str, Any]: 191 | params: Dict[str, Any] = {} 192 | if self.dimensions: 193 | params["dimensions"] = self.dimensions 194 | if self.model_kwargs: 195 | params["model_extras"] = self.model_kwargs 196 | 197 | params.update(kwargs) 198 | return params 199 | 200 | def _embed( 201 | self, texts: list[str], input_type: EmbeddingInputType 202 | ) -> Generator[list[float], None, None]: 203 | for text_batch in range(0, len(texts), self.embed_batch_size): 204 | response = self._client.embed( 205 | input=texts[text_batch : text_batch + self.embed_batch_size], 206 | input_type=self._embed_input_type or input_type, 207 | **self._get_model_params(), 208 | ) 209 | 210 | for data in response.data: 211 | yield data.embedding # type: ignore 212 | 213 | async def _embed_async( 214 | self, texts: list[str], input_type: EmbeddingInputType 215 | ) -> AsyncGenerator[list[float], None]: 216 | for text_batch in range(0, len(texts), self.embed_batch_size): 217 | response = await self._async_client.embed( 218 | input=texts[text_batch : text_batch + self.embed_batch_size], 219 | input_type=self._embed_input_type or input_type, 220 | **self._get_model_params(), 221 | ) 222 | 223 | for data in response.data: 224 | yield data.embedding # type: ignore 225 | 226 | def embed_documents(self, texts: list[str]) -> list[list[float]]: 227 | """Embed search docs. 228 | 229 | Args: 230 | texts: List of text to embed. 231 | 232 | Returns: 233 | List of embeddings. 234 | """ 235 | return list(self._embed(texts, EmbeddingInputType.DOCUMENT)) 236 | 237 | def embed_query(self, text: str) -> list[float]: 238 | """Embed query text. 239 | 240 | Args: 241 | text: Text to embed. 242 | 243 | Returns: 244 | Embedding. 245 | """ 246 | return list(self._embed([text], EmbeddingInputType.QUERY))[0] 247 | 248 | async def aembed_documents(self, texts: list[str]) -> list[list[float]]: 249 | """Asynchronous Embed search docs. 250 | 251 | Args: 252 | texts: List of text to embed. 253 | 254 | Returns: 255 | List of embeddings. 256 | """ 257 | return self._embed_async(texts, EmbeddingInputType.DOCUMENT) # type: ignore[return-value] 258 | 259 | async def aembed_query(self, text: str) -> list[float]: 260 | """Asynchronous Embed query text. 261 | 262 | Args: 263 | text: Text to embed. 264 | 265 | Returns: 266 | Embedding. 
267 | """ 268 | async for item in self._embed_async([text], EmbeddingInputType.QUERY): 269 | return item 270 | return [] 271 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/query_constructors/__init__.py: -------------------------------------------------------------------------------- 1 | """This module defines the query constructors for the Azure AI integrations.""" 2 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/query_constructors/cosmosdb_no_sql.py: -------------------------------------------------------------------------------- 1 | """Translator that converts a StructuredQuery into a CosmosDB NoSQL query.""" 2 | 3 | from typing import Any, Dict, Tuple 4 | 5 | from langchain_core.structured_query import ( 6 | Comparator, 7 | Comparison, 8 | Operation, 9 | Operator, 10 | StructuredQuery, 11 | Visitor, 12 | ) 13 | 14 | SQL_COMPARATOR = { 15 | Comparator.EQ: "=", 16 | Comparator.NE: "!=", 17 | Comparator.GT: ">", 18 | Comparator.GTE: ">=", 19 | Comparator.LT: "<", 20 | Comparator.LTE: "<=", 21 | Comparator.LIKE: "LIKE", 22 | Comparator.IN: "IN", 23 | Comparator.NIN: "NOT IN", 24 | } 25 | 26 | SQL_OPERATOR = { 27 | Operator.AND: "AND", 28 | Operator.OR: "OR", 29 | Operator.NOT: "NOT", 30 | } 31 | 32 | 33 | class AzureCosmosDbNoSQLTranslator(Visitor): 34 | """A visitor that converts a StructuredQuery into an CosmosDB NO SQL query.""" 35 | 36 | def __init__(self, table_name: str = "c") -> None: 37 | """Initialize the translator with the table name.""" 38 | self.table_name = table_name 39 | 40 | def visit_comparison(self, comparison: Comparison) -> str: 41 | """Visit a comparison operation and convert it into an SQL condition.""" 42 | operator = SQL_COMPARATOR.get(comparison.comparator) 43 | value = comparison.value 44 | field = f"{self.table_name}.{comparison.attribute}" 45 | 46 | if operator is None: 47 | raise ValueError(f"Unsupported operator: {comparison.comparator}") 48 | 49 | # Correct value formatting 50 | if isinstance(value, str): 51 | value = f"'{value}'" 52 | elif isinstance(value, (list, tuple)): # Handle IN clause 53 | if comparison.comparator not in [Comparator.IN, Comparator.NIN]: 54 | raise ValueError( 55 | f"Invalid comparator for list value: {comparison.comparator}" 56 | ) 57 | value = ( 58 | "(" 59 | + ", ".join(f"'{v}'" if isinstance(v, str) else str(v) for v in value) 60 | + ")" 61 | ) 62 | 63 | return f"{field} {operator} {value}" 64 | 65 | def visit_operation(self, operation: Operation) -> str: 66 | """Visit logical operations and convert them into SQL expressions. 67 | 68 | Uses parentheses to ensure correct precedence. 
69 | """ 70 | operator = SQL_OPERATOR.get(operation.operator) 71 | if operator is None: 72 | raise ValueError(f"Unsupported operator: {operation.operator}") 73 | 74 | expressions = [arg.accept(self) for arg in operation.arguments] 75 | 76 | if operation.operator == Operator.NOT: 77 | return f"NOT ({expressions[0]})" 78 | 79 | return f"({f' {operator} '.join(expressions)})" 80 | 81 | def visit_structured_query( 82 | self, structured_query: StructuredQuery 83 | ) -> Tuple[str, Dict[str, Any]]: 84 | """Visit a structured query and convert it into parameter for vectorstore.""" 85 | if structured_query.filter is None: 86 | kwargs = {} 87 | else: 88 | kwargs = {"where": structured_query.filter.accept(self)} 89 | return structured_query.query, kwargs 90 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/utils/math.py: -------------------------------------------------------------------------------- 1 | """Math utils.""" 2 | 3 | import logging 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 11 | 12 | 13 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 14 | """Row-wise cosine similarity between two equal-width matrices.""" 15 | if len(X) == 0 or len(Y) == 0: 16 | return np.array([]) 17 | 18 | X = np.array(X) 19 | Y = np.array(Y) 20 | if X.shape[1] != Y.shape[1]: 21 | raise ValueError( 22 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 23 | f"and Y has shape {Y.shape}." 24 | ) 25 | try: 26 | import simsimd as simd 27 | 28 | X = np.array(X, dtype=np.float32) 29 | Y = np.array(Y, dtype=np.float32) 30 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) 31 | return Z 32 | except ImportError: 33 | logger.debug( 34 | "Unable to import simsimd, defaulting to NumPy implementation. If you want " 35 | "to use simsimd please install with `pip install simsimd`." 36 | ) 37 | X_norm = np.linalg.norm(X, axis=1) 38 | Y_norm = np.linalg.norm(Y, axis=1) 39 | # Ignore divide by zero errors run time warnings as those are handled below. 40 | with np.errstate(divide="ignore", invalid="ignore"): 41 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 42 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 43 | return similarity 44 | 45 | 46 | def cosine_similarity_top_k( 47 | X: Matrix, 48 | Y: Matrix, 49 | top_k: Optional[int] = 5, 50 | score_threshold: Optional[float] = None, 51 | ) -> Tuple[List[Tuple[int, int]], List[float]]: 52 | """Row-wise cosine similarity with optional top-k and score threshold filtering. 53 | 54 | Args: 55 | X: Matrix. 56 | Y: Matrix, same width as X. 57 | top_k: Max number of results to return. 58 | score_threshold: Minimum cosine similarity of results. 59 | 60 | Returns: 61 | Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx), 62 | second contains corresponding cosine similarities. 
63 | """ 64 | if len(X) == 0 or len(Y) == 0: 65 | return [], [] 66 | score_array = cosine_similarity(X, Y) 67 | score_threshold = score_threshold or -1.0 68 | score_array[score_array < score_threshold] = 0 69 | top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) 70 | top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] 71 | top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] 72 | ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) 73 | scores = score_array.ravel()[top_k_idxs].tolist() 74 | return list(zip(*ret_idxs)), scores # type: ignore 75 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/utils/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for LangChain Azure AI package.""" 2 | 3 | import dataclasses 4 | import json 5 | from typing import Any, Tuple, Union 6 | 7 | from azure.core.credentials import AzureKeyCredential, TokenCredential 8 | from pydantic import BaseModel 9 | 10 | 11 | class JSONObjectEncoder(json.JSONEncoder): 12 | """Custom JSON encoder for objects in LangChain.""" 13 | 14 | def default(self, o: Any) -> Any: 15 | """Serialize the object to JSON string. 16 | 17 | Args: 18 | o (Any): Object to be serialized. 19 | """ 20 | if isinstance(o, dict): 21 | if "callbacks" in o: 22 | del o["callbacks"] 23 | return o 24 | 25 | if dataclasses.is_dataclass(o): 26 | return dataclasses.asdict(o) # type: ignore 27 | 28 | if hasattr(o, "to_json"): 29 | return o.to_json() 30 | 31 | if isinstance(o, BaseModel) and hasattr(o, "model_dump_json"): 32 | return o.model_dump_json() 33 | 34 | if "__slots__" in dir(o): 35 | # Handle objects with __slots__ that are not dataclasses 36 | return { 37 | "__class__": o.__class__.__name__, 38 | **{slot: getattr(o, slot) for slot in o.__slots__}, 39 | } 40 | 41 | return super().default(o) 42 | 43 | 44 | def get_endpoint_from_project( 45 | project_connection_string: str, credential: TokenCredential 46 | ) -> Tuple[str, Union[AzureKeyCredential, TokenCredential]]: 47 | """Retrieves the default inference endpoint and credentials from a project. 48 | 49 | It uses the Azure AI project's connection string to retrieve the inference 50 | defaults. The default connection of type Azure AI Services is used to 51 | retrieve the endpoint and credentials. 52 | 53 | Args: 54 | project_connection_string (str): Connection string for the Azure AI project. 55 | credential (TokenCredential): Azure credential object. Credentials must be of 56 | type `TokenCredential` when using the `project_connection_string` 57 | parameter. 58 | 59 | Returns: 60 | Tuple[str, Union[AzureKeyCredential, TokenCredential]]: Endpoint URL and 61 | credentials. 62 | """ 63 | try: 64 | from azure.ai.projects import AIProjectClient # type: ignore[import-untyped] 65 | from azure.ai.projects.models import ( # type: ignore[import-untyped] 66 | ConnectionType, 67 | ) 68 | except ImportError: 69 | raise ImportError( 70 | "The `azure.ai.projects` package is required to use the " 71 | "`project_connection_string` parameter. Please install it with " 72 | "`pip install azure-ai-projects`." 
73 | ) 74 | 75 | project = AIProjectClient.from_connection_string( 76 | conn_str=project_connection_string, 77 | credential=credential, 78 | ) 79 | 80 | connection = project.connections.get_default( 81 | connection_type=ConnectionType.AZURE_AI_SERVICES, include_credentials=True 82 | ) 83 | 84 | if not connection: 85 | raise ValueError( 86 | "No Azure AI Services connection found in the project. See " 87 | "https://aka.ms/azureai/modelinference/connection for more " 88 | "information." 89 | ) 90 | 91 | if connection.endpoint_url.endswith("/models"): 92 | endpoint = connection.endpoint_url 93 | elif connection.endpoint_url.endswith("/"): 94 | endpoint = connection.endpoint_url + "models" 95 | else: 96 | endpoint = connection.endpoint_url + "/models" 97 | 98 | return endpoint, connection.key or connection.token_credential 99 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | """**Vector store** stores embedded data and performs vector search. 2 | 3 | One of the most common ways to store and search over unstructured data is to 4 | embed it and store the resulting embedding vectors, and then query the store 5 | and retrieve the data that are 'most similar' to the embedded query. 6 | 7 | **Class hierarchy:** 8 | 9 | .. code-block:: 10 | 11 | VectorStore --> # Examples: Annoy, FAISS, Milvus 12 | 13 | BaseRetriever --> VectorStoreRetriever --> Retriever # Example: VespaRetriever 14 | 15 | **Main helpers:** 16 | 17 | .. code-block:: 18 | 19 | Embeddings, Document 20 | """ # noqa: E501 21 | 22 | from langchain_azure_ai.vectorstores.azure_cosmos_db_mongo_vcore import ( 23 | AzureCosmosDBMongoVCoreVectorSearch, 24 | ) 25 | from langchain_azure_ai.vectorstores.azure_cosmos_db_no_sql import ( 26 | AzureCosmosDBNoSqlVectorSearch, 27 | ) 28 | 29 | __all__ = [ 30 | "AzureCosmosDBNoSqlVectorSearch", 31 | "AzureCosmosDBMongoVCoreVectorSearch", 32 | ] 33 | 34 | _module_lookup = { 35 | "AzureCosmosDBMongoVCoreVectorSearch": "langchain_azure_ai.vectorstores.azure_cosmos_db_mongo_vcore", # noqa: E501 36 | "AzureCosmosDBNoSqlVectorSearch": "langchain_azure_ai.vectorstores.azure_cosmos_db_no_sql", # noqa: E501 37 | } 38 | -------------------------------------------------------------------------------- /libs/azure-ai/langchain_azure_ai/vectorstores/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for working with vectors and vectorstores.""" 2 | 3 | from enum import Enum 4 | from typing import List, Tuple, Type 5 | 6 | import numpy as np 7 | from langchain_core.documents import Document 8 | 9 | from langchain_azure_ai.utils.math import cosine_similarity 10 | 11 | 12 | class DistanceStrategy(str, Enum): 13 | """Enumerator of the Distance strategies for calculating distances between vectors.""" # noqa: E501 14 | 15 | EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" 16 | MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" 17 | DOT_PRODUCT = "DOT_PRODUCT" 18 | JACCARD = "JACCARD" 19 | COSINE = "COSINE" 20 | 21 | 22 | def maximal_marginal_relevance( 23 | query_embedding: np.ndarray, 24 | embedding_list: list, 25 | lambda_mult: float = 0.5, 26 | k: int = 4, 27 | ) -> List[int]: 28 | """Calculate maximal marginal relevance.""" 29 | if min(k, len(embedding_list)) <= 0: 30 | return [] 31 | if query_embedding.ndim == 1: 32 | query_embedding = np.expand_dims(query_embedding, axis=0) 33 | 
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] 34 | most_similar = int(np.argmax(similarity_to_query)) 35 | idxs = [most_similar] 36 | selected = np.array([embedding_list[most_similar]]) 37 | while len(idxs) < min(k, len(embedding_list)): 38 | best_score = -np.inf 39 | idx_to_add = -1 40 | similarity_to_selected = cosine_similarity(embedding_list, selected) 41 | for i, query_score in enumerate(similarity_to_query): 42 | if i in idxs: 43 | continue 44 | redundant_score = max(similarity_to_selected[i]) 45 | equation_score = ( 46 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score 47 | ) 48 | if equation_score > best_score: 49 | best_score = equation_score 50 | idx_to_add = i 51 | idxs.append(idx_to_add) 52 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) 53 | return idxs 54 | 55 | 56 | def filter_complex_metadata( 57 | documents: List[Document], 58 | *, 59 | allowed_types: Tuple[Type, ...] = (str, bool, int, float), 60 | ) -> List[Document]: 61 | """Filter out metadata types that are not supported for a vector store.""" 62 | updated_documents = [] 63 | for document in documents: 64 | filtered_metadata = {} 65 | for key, value in document.metadata.items(): 66 | if not isinstance(value, allowed_types): 67 | continue 68 | filtered_metadata[key] = value 69 | 70 | document.metadata = filtered_metadata 71 | updated_documents.append(document) 72 | 73 | return updated_documents 74 | -------------------------------------------------------------------------------- /libs/azure-ai/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-azure-ai" 3 | version = "0.1.4" 4 | description = "An integration package to support Azure AI Foundry capabilities for model inference in LangChain." 
5 | authors = [] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9,<4.0" 11 | langchain-core = "^0.3.0" 12 | langchain-openai ="^0.3.0" 13 | azure-core = "^1.32.0" 14 | azure-cosmos = "^4.9.0" 15 | azure-identity = "^1.15.0" 16 | azure-ai-inference = {extras = ["opentelemetry"], version = "^1.0.0b7"} 17 | aiohttp = "^3.10.0" 18 | azure-monitor-opentelemetry = { "version" = "^1.6.4", optional = true } 19 | opentelemetry-semantic-conventions-ai = { "version" = "^0.4.2", optional = true } 20 | opentelemetry-instrumentation-threading = { "version" = "^0.49b2", optional = true } 21 | numpy = [ 22 | { version = ">=1.26.2", markers = "python_version < '3.13'" }, 23 | { version = ">=2.1.0", markers = "python_version >= '3.13'" } 24 | ] 25 | 26 | [tool.poetry.extras] 27 | opentelemetry = ["azure-monitor-opentelemetry", "opentelemetry-semantic-conventions-ai", "opentelemetry-instrumentation-threading"] 28 | 29 | 30 | [tool.poetry.group.codespell.dependencies] 31 | codespell = "^2.2.0" 32 | 33 | [tool.poetry.group.dev.dependencies] 34 | langchain = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain"} 35 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 36 | ipykernel = "^6.29.5" 37 | 38 | [tool.poetry.group.lint.dependencies] 39 | ruff = "^0.5" 40 | python-dotenv = "^1.0.1" 41 | pytest = "^7.4.3" 42 | pymongo = "^4.5.0" 43 | simsimd = "^6.0.0" 44 | 45 | 46 | [tool.poetry.group.test.dependencies] 47 | pydantic = "^2.9.2" 48 | pytest = "^7.4.3" 49 | pytest-mock = "^3.10.0" 50 | pytest-watcher = "^0.3.4" 51 | pytest-asyncio = "^0.21.1" 52 | python-dotenv = "^1.0.1" 53 | syrupy = "^4.7.2" 54 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 55 | 56 | [tool.poetry.group.test_integration.dependencies] 57 | pytest = "^7.3.0" 58 | python-dotenv = "^1.0.1" 59 | 60 | [tool.poetry.urls] 61 | "Source Code" = "https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-ai" 62 | "Release Notes" = "https://github.com/langchain-ai/langchain-azure/releases" 63 | 64 | [tool.mypy] 65 | disallow_untyped_defs = "True" 66 | 67 | [tool.poetry.group.typing.dependencies] 68 | mypy = "^1.10" 69 | 70 | [tool.ruff.lint] 71 | select = ["E", "F", "I", "D"] 72 | 73 | [tool.coverage.run] 74 | omit = ["tests/*"] 75 | 76 | [tool.pytest.ini_options] 77 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 78 | markers = [ 79 | "requires: mark tests as requiring a specific library", 80 | "compile: mark placeholder test used to compile integration tests without running them", 81 | ] 82 | asyncio_mode = "auto" 83 | 84 | [tool.poetry.group.test] 85 | optional = true 86 | 87 | [tool.poetry.group.test_integration] 88 | optional = true 89 | 90 | [tool.poetry.group.codespell] 91 | optional = true 92 | 93 | [tool.poetry.group.lint] 94 | optional = true 95 | 96 | [tool.poetry.group.dev] 97 | optional = true 98 | 99 | [tool.ruff.lint.pydocstyle] 100 | convention = "google" 101 | 102 | [tool.ruff.lint.per-file-ignores] 103 | "tests/**" = ["D"] 104 | 105 | [build-system] 106 | requires = ["poetry-core"] 107 | build-backend = "poetry.core.masonry.api" 108 | 109 | [tool.codespell] 110 | ignore-words-list = "nin" 111 | -------------------------------------------------------------------------------- /libs/azure-ai/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | 
"""This module checks for specific import statements in the codebase.""" 2 | 3 | import sys 4 | import traceback 5 | from importlib.machinery import SourceFileLoader 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | has_failure = False 10 | for file in files: 11 | try: 12 | SourceFileLoader("x", file).load_module() 13 | except Exception: 14 | has_failure = True 15 | print(file) 16 | traceback.print_exc() 17 | print() 18 | 19 | sys.exit(1 if has_failure else 0) 20 | -------------------------------------------------------------------------------- /libs/azure-ai/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-ai/tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-ai/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-ai/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-ai/tests/integration_tests/cache/test_azure_cosmos_db_no_sql_cache.py: -------------------------------------------------------------------------------- 1 | """Test` Azure CosmosDB NoSql cache functionality.""" 2 | 3 | import os 4 | from typing import Any, Dict 5 | 6 | import pytest 7 | from langchain_core.globals import get_llm_cache, set_llm_cache 8 | from langchain_core.outputs import Generation 9 | from langchain_openai.embeddings import OpenAIEmbeddings 10 | 11 | from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel 12 | from langchain_azure_ai.vectorstores.cache import AzureCosmosDBNoSqlSemanticCache 13 | 14 | HOST = "COSMOS_DB_URI" 15 | KEY = "COSMOS_DB_KEY" 16 | model_name = os.getenv("OPENAI_EMBEDDINGS_MODEL_NAME", "text-embedding-ada-002") 17 | 18 | 19 | @pytest.fixture() 20 | def cosmos_client() -> Any: 21 | from azure.cosmos import CosmosClient 22 | 23 | return CosmosClient(HOST, KEY) 24 | 25 | 26 | @pytest.fixture() 27 | def partition_key() -> Any: 28 | from azure.cosmos import PartitionKey 29 | 30 | return PartitionKey(path="/id") 31 | 32 | 33 | @pytest.fixture() 34 | def azure_openai_embeddings() -> OpenAIEmbeddings: 35 | openai_embeddings: OpenAIEmbeddings = OpenAIEmbeddings( 36 | model=model_name, 37 | chunk_size=1, 38 | ) 39 | return openai_embeddings 40 | 41 | 42 | # cosine, euclidean, innerproduct 43 | def indexing_policy(index_type: str) -> dict: 44 | return { 45 | "indexingMode": "consistent", 46 | "includedPaths": [{"path": "/*"}], 47 | "excludedPaths": [{"path": '/"_etag"/?'}], 48 | "vectorIndexes": [{"path": "/embedding", "type": 
index_type}], 49 | } 50 | 51 | 52 | def vector_embedding_policy(distance_function: str) -> dict: 53 | return { 54 | "vectorEmbeddings": [ 55 | { 56 | "path": "/embedding", 57 | "dataType": "float32", 58 | "distanceFunction": distance_function, 59 | "dimensions": 1536, 60 | } 61 | ] 62 | } 63 | 64 | 65 | cosmos_container_properties_test = {"partition_key": partition_key} 66 | cosmos_database_properties_test: Dict[str, Any] = {} 67 | 68 | 69 | def test_azure_cosmos_db_nosql_semantic_cache_cosine_quantizedflat( 70 | cosmos_client: Any, 71 | azure_openai_embeddings: OpenAIEmbeddings, 72 | ) -> None: 73 | set_llm_cache( 74 | AzureCosmosDBNoSqlSemanticCache( 75 | cosmos_client=cosmos_client, 76 | embedding=azure_openai_embeddings, 77 | vector_embedding_policy=vector_embedding_policy("cosine"), 78 | indexing_policy=indexing_policy("quantizedFlat"), 79 | cosmos_container_properties=cosmos_container_properties_test, 80 | cosmos_database_properties=cosmos_database_properties_test, 81 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 82 | ) 83 | ) 84 | 85 | llm = AzureAIChatCompletionsModel() 86 | params = llm.dict() 87 | params["stop"] = None 88 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 89 | get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) 90 | 91 | # foo and bar will have the same embedding produced by FakeEmbeddings 92 | cache_output = get_llm_cache().lookup("bar", llm_string) 93 | assert cache_output == [Generation(text="fizz")] 94 | 95 | # clear the cache 96 | get_llm_cache().clear(llm_string=llm_string) 97 | 98 | 99 | def test_azure_cosmos_db_nosql_semantic_cache_cosine_flat( 100 | cosmos_client: Any, 101 | azure_openai_embeddings: OpenAIEmbeddings, 102 | ) -> None: 103 | set_llm_cache( 104 | AzureCosmosDBNoSqlSemanticCache( 105 | cosmos_client=cosmos_client, 106 | embedding=azure_openai_embeddings, 107 | vector_embedding_policy=vector_embedding_policy("cosine"), 108 | indexing_policy=indexing_policy("flat"), 109 | cosmos_container_properties=cosmos_container_properties_test, 110 | cosmos_database_properties=cosmos_database_properties_test, 111 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 112 | ) 113 | ) 114 | 115 | llm = AzureAIChatCompletionsModel() 116 | params = llm.dict() 117 | params["stop"] = None 118 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 119 | get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) 120 | 121 | # foo and bar will have the same embedding produced by FakeEmbeddings 122 | cache_output = get_llm_cache().lookup("bar", llm_string) 123 | assert cache_output == [Generation(text="fizz")] 124 | 125 | # clear the cache 126 | get_llm_cache().clear(llm_string=llm_string) 127 | 128 | 129 | def test_azure_cosmos_db_nosql_semantic_cache_dotproduct_quantizedflat( 130 | cosmos_client: Any, 131 | azure_openai_embeddings: OpenAIEmbeddings, 132 | ) -> None: 133 | set_llm_cache( 134 | AzureCosmosDBNoSqlSemanticCache( 135 | cosmos_client=cosmos_client, 136 | embedding=azure_openai_embeddings, 137 | vector_embedding_policy=vector_embedding_policy("dotProduct"), 138 | indexing_policy=indexing_policy("quantizedFlat"), 139 | cosmos_container_properties=cosmos_container_properties_test, 140 | cosmos_database_properties=cosmos_database_properties_test, 141 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 142 | ) 143 | ) 144 | 145 | llm = AzureAIChatCompletionsModel() 146 | params = llm.dict() 147 | params["stop"] = None 148 | 
llm_string = str(sorted([(k, v) for k, v in params.items()])) 149 | get_llm_cache().update( 150 | "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")] 151 | ) 152 | 153 | # foo and bar will have the same embedding produced by FakeEmbeddings 154 | cache_output = get_llm_cache().lookup("bar", llm_string) 155 | assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")] 156 | 157 | # clear the cache 158 | get_llm_cache().clear(llm_string=llm_string) 159 | 160 | 161 | def test_azure_cosmos_db_nosql_semantic_cache_dotproduct_flat( 162 | cosmos_client: Any, 163 | azure_openai_embeddings: OpenAIEmbeddings, 164 | ) -> None: 165 | set_llm_cache( 166 | AzureCosmosDBNoSqlSemanticCache( 167 | cosmos_client=cosmos_client, 168 | embedding=azure_openai_embeddings, 169 | vector_embedding_policy=vector_embedding_policy("dotProduct"), 170 | indexing_policy=indexing_policy("flat"), 171 | cosmos_container_properties=cosmos_container_properties_test, 172 | cosmos_database_properties=cosmos_database_properties_test, 173 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 174 | ) 175 | ) 176 | 177 | llm = AzureAIChatCompletionsModel() 178 | params = llm.dict() 179 | params["stop"] = None 180 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 181 | get_llm_cache().update( 182 | "foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")] 183 | ) 184 | 185 | # foo and bar will have the same embedding produced by FakeEmbeddings 186 | cache_output = get_llm_cache().lookup("bar", llm_string) 187 | assert cache_output == [Generation(text="fizz"), Generation(text="Buzz")] 188 | 189 | # clear the cache 190 | get_llm_cache().clear(llm_string=llm_string) 191 | 192 | 193 | def test_azure_cosmos_db_nosql_semantic_cache_euclidean_quantizedflat( 194 | cosmos_client: Any, 195 | azure_openai_embeddings: OpenAIEmbeddings, 196 | ) -> None: 197 | set_llm_cache( 198 | AzureCosmosDBNoSqlSemanticCache( 199 | cosmos_client=cosmos_client, 200 | embedding=azure_openai_embeddings, 201 | vector_embedding_policy=vector_embedding_policy("euclidean"), 202 | indexing_policy=indexing_policy("quantizedFlat"), 203 | cosmos_container_properties=cosmos_container_properties_test, 204 | cosmos_database_properties=cosmos_database_properties_test, 205 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 206 | ) 207 | ) 208 | 209 | llm = AzureAIChatCompletionsModel() 210 | params = llm.dict() 211 | params["stop"] = None 212 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 213 | get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) 214 | 215 | # foo and bar will have the same embedding produced by FakeEmbeddings 216 | cache_output = get_llm_cache().lookup("bar", llm_string) 217 | assert cache_output == [Generation(text="fizz")] 218 | 219 | # clear the cache 220 | get_llm_cache().clear(llm_string=llm_string) 221 | 222 | 223 | def test_azure_cosmos_db_nosql_semantic_cache_euclidean_flat( 224 | cosmos_client: Any, 225 | azure_openai_embeddings: OpenAIEmbeddings, 226 | ) -> None: 227 | set_llm_cache( 228 | AzureCosmosDBNoSqlSemanticCache( 229 | cosmos_client=cosmos_client, 230 | embedding=azure_openai_embeddings, 231 | vector_embedding_policy=vector_embedding_policy("euclidean"), 232 | indexing_policy=indexing_policy("flat"), 233 | cosmos_container_properties=cosmos_container_properties_test, 234 | cosmos_database_properties=cosmos_database_properties_test, 235 | vector_search_fields={"text_field": "text", 
"embedding_field": "embedding"}, 236 | ) 237 | ) 238 | 239 | llm = AzureAIChatCompletionsModel() 240 | params = llm.dict() 241 | params["stop"] = None 242 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 243 | get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) 244 | 245 | # foo and bar will have the same embedding produced by FakeEmbeddings 246 | cache_output = get_llm_cache().lookup("bar", llm_string) 247 | assert cache_output == [Generation(text="fizz")] 248 | 249 | # clear the cache 250 | get_llm_cache().clear(llm_string=llm_string) 251 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest # type: ignore[import-not-found] 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/integration_tests/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | """Test vectorstores""" 2 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/integration_tests/vectorstores/test_azure_cosmos_db_no_sql.py: -------------------------------------------------------------------------------- 1 | """Test AzureCosmosDBNoSqlVectorSearch functionality.""" 2 | 3 | import logging 4 | import os 5 | from time import sleep 6 | from typing import Any, Dict, List, Tuple 7 | 8 | import pytest 9 | from langchain_core.documents import Document 10 | from langchain_openai.embeddings import OpenAIEmbeddings 11 | 12 | from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel 13 | from langchain_azure_ai.vectorstores.azure_cosmos_db_no_sql import ( 14 | AzureCosmosDBNoSqlVectorSearch, 15 | ) 16 | 17 | logging.basicConfig(level=logging.DEBUG) 18 | 19 | model_deployment = os.getenv("OPENAI_EMBEDDINGS_DEPLOYMENT", "embeddings") 20 | model_name = os.getenv("OPENAI_EMBEDDINGS_MODEL_NAME", "text-embedding-ada-002") 21 | 22 | # Host and Key for CosmosDB No SQl 23 | HOST = os.getenv("HOST", "default_host") 24 | KEY = os.getenv("KEY", "default_key") 25 | 26 | database_name = "langchain_python_db" 27 | container_name = "langchain_python_container" 28 | 29 | 30 | @pytest.fixture() 31 | def cosmos_client() -> Any: 32 | from azure.cosmos import CosmosClient 33 | 34 | return CosmosClient(HOST, KEY) 35 | 36 | 37 | @pytest.fixture() 38 | def partition_key() -> Any: 39 | from azure.cosmos import PartitionKey 40 | 41 | return PartitionKey(path="/id") 42 | 43 | 44 | @pytest.fixture() 45 | def azure_openai_embeddings() -> Any: 46 | openai_embeddings: OpenAIEmbeddings = OpenAIEmbeddings( 47 | model=model_name, 48 | chunk_size=1, 49 | ) 50 | return openai_embeddings 51 | 52 | 53 | def safe_delete_database(cosmos_client: Any) -> None: 54 | cosmos_client.delete_database(database_name) 55 | 56 | 57 | def get_vector_indexing_policy(embedding_type: str) -> dict: 58 | return { 59 | "indexingMode": "consistent", 60 | "includedPaths": [{"path": "/*"}], 61 | "excludedPaths": [{"path": '/"_etag"/?'}], 62 | "vectorIndexes": [{"path": "/embedding", "type": embedding_type}], 63 | "fullTextIndexes": [{"path": "/text"}], 64 | } 65 | 66 | 67 | def get_vector_embedding_policy( 68 | distance_function: str, data_type: str, dimensions: int 69 
| ) -> dict: 70 | return { 71 | "vectorEmbeddings": [ 72 | { 73 | "path": "/embedding", 74 | "dataType": data_type, 75 | "dimensions": dimensions, 76 | "distanceFunction": distance_function, 77 | } 78 | ] 79 | } 80 | 81 | 82 | def get_full_text_policy() -> dict: 83 | return { 84 | "defaultLanguage": "en-US", 85 | "fullTextPaths": [{"path": "/text", "language": "en-US"}], 86 | } 87 | 88 | 89 | class TestAzureCosmosDBNoSqlVectorSearch: 90 | def test_from_documents_cosine_distance( 91 | self, 92 | cosmos_client: Any, 93 | partition_key: Any, 94 | azure_openai_embeddings: AzureAIEmbeddingsModel, 95 | ) -> None: 96 | """Test end to end construction and search.""" 97 | documents = self._get_documents() 98 | 99 | store = AzureCosmosDBNoSqlVectorSearch.from_documents( 100 | documents, 101 | embedding=azure_openai_embeddings, 102 | cosmos_client=cosmos_client, 103 | database_name=database_name, 104 | container_name=container_name, 105 | vector_embedding_policy=get_vector_embedding_policy( 106 | "cosine", "float32", 400 107 | ), 108 | indexing_policy=get_vector_indexing_policy("flat"), 109 | cosmos_container_properties={"partition_key": partition_key}, 110 | cosmos_database_properties={}, 111 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 112 | full_text_policy=get_full_text_policy(), 113 | full_text_search_enabled=True, 114 | ) 115 | sleep(1) # waits for Cosmos DB to save contents to the collection 116 | 117 | output = store.similarity_search("Which dog breed is considered a herder?", k=5) 118 | 119 | assert output 120 | assert len(output) == 5 121 | assert "Border Collies" in output[0].page_content 122 | safe_delete_database(cosmos_client) 123 | 124 | def test_from_documents_cosine_distance_custom_projection( 125 | self, 126 | cosmos_client: Any, 127 | partition_key: Any, 128 | azure_openai_embeddings: AzureAIEmbeddingsModel, 129 | ) -> None: 130 | """Test end to end construction and search.""" 131 | documents = self._get_documents() 132 | 133 | store = AzureCosmosDBNoSqlVectorSearch.from_documents( 134 | documents, 135 | embedding=azure_openai_embeddings, 136 | cosmos_client=cosmos_client, 137 | database_name=database_name, 138 | container_name=container_name, 139 | vector_embedding_policy=get_vector_embedding_policy( 140 | "cosine", "float32", 400 141 | ), 142 | indexing_policy=get_vector_indexing_policy("flat"), 143 | cosmos_container_properties={"partition_key": partition_key}, 144 | cosmos_database_properties={}, 145 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 146 | full_text_policy=get_full_text_policy(), 147 | full_text_search_enabled=True, 148 | ) 149 | sleep(1) # waits for Cosmos DB to save contents to the collection 150 | 151 | projection_mapping = { 152 | "text": "text", 153 | } 154 | output = store.similarity_search( 155 | "Which dog breed is considered a herder?", 156 | k=5, 157 | projection_mapping=projection_mapping, 158 | ) 159 | 160 | assert output 161 | assert len(output) == 5 162 | assert "Border Collies" in output[0].page_content 163 | safe_delete_database(cosmos_client) 164 | 165 | def test_from_texts_cosine_distance_delete_one( 166 | self, 167 | cosmos_client: Any, 168 | partition_key: Any, 169 | azure_openai_embeddings: AzureAIEmbeddingsModel, 170 | ) -> None: 171 | texts, metadatas = self._get_texts_and_metadata() 172 | 173 | store = AzureCosmosDBNoSqlVectorSearch.from_texts( 174 | texts, 175 | azure_openai_embeddings, 176 | metadatas, 177 | cosmos_client=cosmos_client, 178 | database_name=database_name, 179 
| container_name=container_name, 180 | vector_embedding_policy=get_vector_embedding_policy( 181 | "cosine", "float32", 400 182 | ), 183 | indexing_policy=get_vector_indexing_policy("flat"), 184 | cosmos_container_properties={"partition_key": partition_key}, 185 | cosmos_database_properties={}, 186 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 187 | full_text_policy=get_full_text_policy(), 188 | full_text_search_enabled=True, 189 | ) 190 | sleep(1) # waits for Cosmos DB to save contents to the collection 191 | 192 | output = store.similarity_search("Which dog breed is considered a herder?", k=1) 193 | assert output 194 | assert len(output) == 1 195 | assert "Border Collies" in output[0].page_content 196 | 197 | # delete one document 198 | store.delete_document_by_id(str(output[0].metadata["id"])) 199 | sleep(2) 200 | 201 | output2 = store.similarity_search( 202 | "Which dog breed is considered a herder?", k=1 203 | ) # noqa: E501 204 | assert output2 205 | assert len(output2) == 1 206 | assert "Border Collies" not in output2[0].page_content 207 | safe_delete_database(cosmos_client) 208 | 209 | def test_from_documents_cosine_distance_with_filtering( 210 | self, 211 | cosmos_client: Any, 212 | partition_key: Any, 213 | azure_openai_embeddings: AzureAIEmbeddingsModel, 214 | ) -> None: 215 | """Test end to end construction and search.""" 216 | documents = self._get_documents() 217 | 218 | store = AzureCosmosDBNoSqlVectorSearch.from_documents( 219 | documents, 220 | embedding=azure_openai_embeddings, 221 | cosmos_client=cosmos_client, 222 | database_name=database_name, 223 | container_name=container_name, 224 | vector_embedding_policy=get_vector_embedding_policy( 225 | "cosine", "float32", 400 226 | ), 227 | indexing_policy=get_vector_indexing_policy("diskANN"), 228 | cosmos_container_properties={"partition_key": partition_key}, 229 | cosmos_database_properties={}, 230 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 231 | full_text_policy=get_full_text_policy(), 232 | full_text_search_enabled=True, 233 | ) 234 | sleep(1) # waits for Cosmos DB to save contents to the collection 235 | 236 | output = store.similarity_search("Which dog breed is considered a herder?", k=4) 237 | assert len(output) == 4 238 | assert "Border Collies" in output[0].page_content 239 | assert output[0].metadata["a"] == 1 240 | 241 | where = "c.metadata.a = 1" 242 | output = store.similarity_search( 243 | "Which dog breed is considered a herder?", 244 | k=4, 245 | where=where, 246 | with_embedding=True, 247 | ) 248 | 249 | assert len(output) == 3 250 | assert "Border Collies" in output[0].page_content 251 | assert output[0].metadata["a"] == 1 252 | 253 | offset_limit = "OFFSET 0 LIMIT 1" 254 | 255 | output = store.similarity_search( 256 | "Which dog breed is considered a herder?", 257 | k=4, 258 | where=where, 259 | offset_limit=offset_limit, 260 | ) 261 | 262 | assert len(output) == 1 263 | assert "Border Collies" in output[0].page_content 264 | assert output[0].metadata["a"] == 1 265 | safe_delete_database(cosmos_client) 266 | 267 | def test_from_documents_full_text_and_hybrid( 268 | self, 269 | cosmos_client: Any, 270 | partition_key: Any, 271 | azure_openai_embeddings: AzureAIEmbeddingsModel, 272 | ) -> None: 273 | """Test end to end construction and search.""" 274 | documents = self._get_documents() 275 | 276 | store = AzureCosmosDBNoSqlVectorSearch.from_documents( 277 | documents, 278 | embedding=azure_openai_embeddings, 279 | cosmos_client=cosmos_client, 
280 | database_name=database_name, 281 | container_name=container_name, 282 | vector_embedding_policy=get_vector_embedding_policy( 283 | "cosine", "float32", 1536 284 | ), 285 | full_text_policy=get_full_text_policy(), 286 | indexing_policy=get_vector_indexing_policy("diskANN"), 287 | cosmos_container_properties={"partition_key": partition_key}, 288 | cosmos_database_properties={}, 289 | vector_search_fields={"text_field": "text", "embedding_field": "embedding"}, 290 | full_text_search_enabled=True, 291 | ) 292 | 293 | sleep(480) # waits for Cosmos DB to save contents to the collection 294 | 295 | # Full text search contains any 296 | where = "FullTextContainsAny(c.text, 'intelligent', 'herders')" 297 | output = store.similarity_search( 298 | "Which dog breed is considered a herder?", 299 | k=5, 300 | where=where, 301 | query_type="full_text_search", 302 | ) 303 | 304 | assert output 305 | assert len(output) == 3 306 | assert "Border Collies" in output[0].page_content 307 | 308 | # Full text search contains all 309 | where = "FullTextContainsAll(c.text, 'intelligent', 'herders')" 310 | 311 | output = store.similarity_search( 312 | "Which dog breed is considered a herder?", 313 | k=5, 314 | where=where, 315 | query_type="full_text_search", 316 | ) 317 | 318 | assert output 319 | assert len(output) == 1 320 | assert "Border Collies" in output[0].page_content 321 | 322 | # Full text search BM25 ranking 323 | full_text_rank_filter = [ 324 | {"search_field": "text", "search_text": "intelligent herders"} 325 | ] 326 | output = store.similarity_search( 327 | "Which dog breed is considered a herder?", 328 | k=5, 329 | query_type="full_text_ranking", 330 | full_text_rank_filter=full_text_rank_filter, 331 | ) 332 | 333 | assert output 334 | assert len(output) == 5 335 | assert "Standard Poodles" in output[0].page_content 336 | 337 | # Full text search successfully queries for data with a single quote 338 | full_text_rank_filter = [{"search_field": "text", "search_text": "'Herders'"}] 339 | output = store.similarity_search( 340 | "Which dog breed is considered a herder?", 341 | k=5, 342 | query_type="full_text_search", 343 | full_text_rank_filter=full_text_rank_filter, 344 | ) 345 | 346 | assert output 347 | assert len(output) == 5 348 | assert "Retrievers" in output[0].page_content 349 | 350 | # Full text search BM25 ranking with filtering 351 | where = "c.metadata.a = 1" 352 | full_text_rank_filter = [ 353 | {"search_field": "text", "search_text": "intelligent herders"} 354 | ] 355 | output = store.similarity_search( 356 | "Which dog breed is considered a herder?", 357 | k=5, 358 | where=where, 359 | query_type="full_text_ranking", 360 | full_text_rank_filter=full_text_rank_filter, 361 | ) 362 | 363 | assert output 364 | assert len(output) == 3 365 | assert "Border Collies" in output[0].page_content 366 | 367 | # Hybrid search RRF ranking combination of full text search and vector search 368 | full_text_rank_filter = [ 369 | {"search_field": "text", "search_text": "intelligent herders"} 370 | ] 371 | output = store.similarity_search( 372 | "Which dog breed is considered a herder?", 373 | k=5, 374 | query_type="hybrid", 375 | full_text_rank_filter=full_text_rank_filter, 376 | ) 377 | 378 | assert output 379 | assert len(output) == 5 380 | assert "Border Collies" in output[0].page_content 381 | 382 | # Hybrid search successfully queries for data with a single quote 383 | full_text_rank_filter = [{"search_field": "text", "search_text": "'energetic'"}] 384 | output = store.similarity_search( 385 | 
"Which breed is energetic?", 386 | k=5, 387 | query_type="hybrid", 388 | full_text_rank_filter=full_text_rank_filter, 389 | ) 390 | 391 | assert output 392 | assert len(output) == 5 393 | assert "Border Collies" in output[0].page_content 394 | 395 | # Hybrid search RRF ranking with filtering 396 | where = "c.metadata.a = 1" 397 | full_text_rank_filter = [ 398 | {"search_field": "text", "search_text": "intelligent herders"} 399 | ] 400 | output = store.similarity_search( 401 | "Which dog breed is considered a herder?", 402 | k=5, 403 | where=where, 404 | query_type="hybrid", 405 | full_text_rank_filter=full_text_rank_filter, 406 | ) 407 | 408 | assert output 409 | assert len(output) == 3 410 | assert "Border Collies" in output[0].page_content 411 | 412 | # Full text search BM25 ranking with full text filtering 413 | where = "FullTextContains(c.text, 'energetic')" 414 | 415 | full_text_rank_filter = [ 416 | {"search_field": "text", "search_text": "intelligent herders"} 417 | ] 418 | output = store.similarity_search( 419 | "Which dog breed is considered a herder?", 420 | k=5, 421 | where=where, 422 | query_type="full_text_ranking", 423 | full_text_rank_filter=full_text_rank_filter, 424 | ) 425 | 426 | assert output 427 | assert len(output) == 3 428 | assert "Border Collies" in output[0].page_content 429 | 430 | # Full text search BM25 ranking with full text filtering 431 | where = "FullTextContains(c.text, 'energetic') AND c.metadata.a = 2" 432 | full_text_rank_filter = [ 433 | {"search_field": "text", "search_text": "intelligent herders"} 434 | ] 435 | output = store.similarity_search( 436 | "intelligent herders", 437 | k=5, 438 | where=where, 439 | query_type="full_text_ranking", 440 | full_text_rank_filter=full_text_rank_filter, 441 | ) 442 | 443 | assert output 444 | assert len(output) == 2 445 | assert "Standard Poodles" in output[0].page_content 446 | 447 | def _get_documents(self) -> List[Document]: 448 | return [ 449 | Document( 450 | page_content="Border Collies are intelligent, energetic " 451 | "herders skilled in outdoor activities.", 452 | metadata={"a": 1}, 453 | ), 454 | Document( 455 | page_content="Golden Retrievers are friendly, loyal companions " 456 | "with excellent retrieving skills.", 457 | metadata={"a": 2}, 458 | ), 459 | Document( 460 | page_content="Labrador Retrievers are playful, eager " 461 | "learners and skilled retrievers.", 462 | metadata={"a": 1}, 463 | ), 464 | Document( 465 | page_content="Australian Shepherds are agile, energetic " 466 | "herders excelling in outdoor tasks.", 467 | metadata={"a": 2, "b": 1}, 468 | ), 469 | Document( 470 | page_content="German Shepherds are brave, loyal protectors " 471 | "excelling in versatile tasks.", 472 | metadata={"a": 1, "b": 2}, 473 | ), 474 | Document( 475 | page_content="Standard Poodles are intelligent, energetic " 476 | "learners excelling in agility.", 477 | metadata={"a": 2, "b": 3}, 478 | ), 479 | ] 480 | 481 | def _get_texts_and_metadata(self) -> Tuple[List[str], List[Dict[str, Any]]]: 482 | texts = [ 483 | "Border Collies are intelligent, " 484 | "energetic herders skilled in outdoor activities.", 485 | "Golden Retrievers are friendly, " 486 | "loyal companions with excellent retrieving skills.", 487 | "Labrador Retrievers are playful, " 488 | "eager learners and skilled retrievers.", 489 | "Australian Shepherds are agile, " 490 | "energetic herders excelling in outdoor tasks.", 491 | "German Shepherds are brave, " 492 | "loyal protectors excelling in versatile tasks.", 493 | "Standard Poodles are 
intelligent, " 494 | "energetic learners excelling in agility.", 495 | ] 496 | metadatas = [ 497 | {"a": 1}, 498 | {"a": 2}, 499 | {"a": 1}, 500 | {"a": 2, "b": 1}, 501 | {"a": 1, "b": 2}, 502 | {"a": 2, "b": 1}, 503 | ] 504 | return texts, metadatas 505 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-ai/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-ai/tests/unit_tests/test_chat_models.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import os 5 | from typing import Any, Generator 6 | from unittest import mock 7 | 8 | # import aiohttp to force Pants to include it in the required dependencies 9 | import aiohttp # noqa 10 | import pytest 11 | from azure.ai.inference.models import ( 12 | ChatChoice, 13 | ChatCompletions, 14 | ChatCompletionsToolCall, 15 | ChatResponseMessage, 16 | CompletionsFinishReason, 17 | ModelInfo, 18 | ) 19 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall 20 | 21 | from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel 22 | from langchain_azure_ai.chat_models.inference import ( 23 | _format_tool_call_for_azure_inference, 24 | ) 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | @pytest.fixture(scope="session") 30 | def loop() -> Generator[asyncio.AbstractEventLoop, None, None]: 31 | try: 32 | loop = asyncio.get_running_loop() 33 | except RuntimeError: 34 | loop = asyncio.new_event_loop() 35 | yield loop 36 | loop.close() 37 | 38 | 39 | @pytest.fixture(scope="session") 40 | def test_params() -> dict: 41 | return { 42 | "input": [ 43 | SystemMessage( 44 | content="You are a helpful assistant. 
When you are asked about if this " 45 | "is a test, you always reply 'Yes, this is a test.'", 46 | ), 47 | HumanMessage(role="user", content="Is this a test?"), 48 | ], 49 | } 50 | 51 | 52 | @pytest.fixture(scope="session") 53 | def test_llm() -> AzureAIChatCompletionsModel: 54 | with mock.patch( 55 | "langchain_azure_ai.chat_models.inference.ChatCompletionsClient", autospec=True 56 | ): 57 | with mock.patch( 58 | "langchain_azure_ai.chat_models.inference.ChatCompletionsClientAsync", 59 | autospec=True, 60 | ): 61 | llm = AzureAIChatCompletionsModel( 62 | endpoint="https://my-endpoint.inference.ai.azure.com", 63 | credential="my-api-key", 64 | ) 65 | llm._client.complete.return_value = ChatCompletions( # type: ignore 66 | choices=[ 67 | ChatChoice( 68 | index=0, 69 | finish_reason=CompletionsFinishReason.STOPPED, 70 | message=ChatResponseMessage( 71 | content="Yes, this is a test.", role="assistant" 72 | ), 73 | ), 74 | ] 75 | ) 76 | llm._client.get_model_info.return_value = ModelInfo( # type: ignore 77 | model_name="my_model_name", 78 | model_provider_name="my_provider_name", 79 | model_type="chat-completions", 80 | ) 81 | llm._async_client.complete = mock.AsyncMock( # type: ignore 82 | return_value=ChatCompletions( # type: ignore 83 | choices=[ 84 | ChatChoice( 85 | index=0, 86 | finish_reason=CompletionsFinishReason.STOPPED, 87 | message=ChatResponseMessage( 88 | content="Yes, this is a test.", role="assistant" 89 | ), 90 | ), 91 | ] 92 | ) 93 | ) 94 | return llm 95 | 96 | 97 | @pytest.fixture() 98 | def test_llm_json() -> AzureAIChatCompletionsModel: 99 | with mock.patch( 100 | "langchain_azure_ai.chat_models.inference.ChatCompletionsClient", autospec=True 101 | ): 102 | llm = AzureAIChatCompletionsModel( 103 | endpoint="https://my-endpoint.inference.ai.azure.com", 104 | credential="my-api-key", 105 | ) 106 | llm._client.complete.return_value = ChatCompletions( # type: ignore 107 | choices=[ 108 | ChatChoice( 109 | index=0, 110 | finish_reason=CompletionsFinishReason.STOPPED, 111 | message=ChatResponseMessage( 112 | content='{ "message": "Yes, this is a test." }', role="assistant" 113 | ), 114 | ), 115 | ] 116 | ) 117 | return llm 118 | 119 | 120 | @pytest.fixture() 121 | def test_llm_tools() -> AzureAIChatCompletionsModel: 122 | with mock.patch( 123 | "langchain_azure_ai.chat_models.inference.ChatCompletionsClient", autospec=True 124 | ): 125 | llm = AzureAIChatCompletionsModel( 126 | endpoint="https://my-endpoint.inference.ai.azure.com", 127 | credential="my-api-key", 128 | ) 129 | llm._client.complete.return_value = ChatCompletions( # type: ignore 130 | choices=[ 131 | ChatChoice( 132 | index=0, 133 | finish_reason=CompletionsFinishReason.TOOL_CALLS, 134 | message=ChatResponseMessage( 135 | role="assistant", 136 | content="", 137 | tool_calls=[ 138 | ChatCompletionsToolCall( 139 | { 140 | "id": "abc0dF1gh", 141 | "type": "function", 142 | "function": { 143 | "name": "echo", 144 | "arguments": '{ "message": "Is this a test?" }', 145 | "call_id": None, 146 | }, 147 | } 148 | ) 149 | ], 150 | ), 151 | ) 152 | ] 153 | ) 154 | return llm 155 | 156 | 157 | def test_chat_completion( 158 | test_llm: AzureAIChatCompletionsModel, test_params: dict 159 | ) -> None: 160 | """Tests the basic chat completion functionality.""" 161 | response = test_llm.invoke(**test_params) 162 | 163 | assert isinstance(response, AIMessage) 164 | if isinstance(response.content, str): 165 | assert response.content.strip() == "Yes, this is a test." 
166 | 167 | 168 | def test_achat_completion( 169 | test_llm: AzureAIChatCompletionsModel, 170 | loop: asyncio.AbstractEventLoop, 171 | test_params: dict, 172 | ) -> None: 173 | """Tests the basic chat completion functionality asynchronously.""" 174 | response = loop.run_until_complete(test_llm.ainvoke(**test_params)) 175 | 176 | assert isinstance(response, AIMessage) 177 | if isinstance(response.content, str): 178 | assert response.content.strip() == "Yes, this is a test." 179 | 180 | 181 | @pytest.mark.skipif( 182 | not { 183 | "AZURE_INFERENCE_ENDPOINT", 184 | "AZURE_INFERENCE_CREDENTIAL", 185 | }.issubset(set(os.environ)), 186 | reason="Azure AI endpoint and/or credential are not set.", 187 | ) 188 | def test_stream_chat_completion(test_params: dict) -> None: 189 | """Tests the basic chat completion functionality with streaming.""" 190 | model_name = os.environ.get("AZURE_INFERENCE_MODEL", None) 191 | 192 | llm = AzureAIChatCompletionsModel(model=model_name) 193 | 194 | response_stream = llm.stream(**test_params) 195 | 196 | buffer = "" 197 | for chunk in response_stream: 198 | buffer += chunk.content # type: ignore 199 | 200 | assert buffer.strip() == "Yes, this is a test." 201 | 202 | 203 | @pytest.mark.skipif( 204 | not { 205 | "AZURE_INFERENCE_ENDPOINT", 206 | "AZURE_INFERENCE_CREDENTIAL", 207 | }.issubset(set(os.environ)), 208 | reason="Azure AI endpoint and/or credential are not set.", 209 | ) 210 | def test_astream_chat_completion( 211 | test_params: dict, loop: asyncio.AbstractEventLoop 212 | ) -> None: 213 | """Tests the basic chat completion functionality with streaming.""" 214 | model_name = os.environ.get("AZURE_INFERENCE_MODEL", None) 215 | 216 | llm = AzureAIChatCompletionsModel(model=model_name) 217 | 218 | async def iterate() -> str: 219 | stream = llm.astream(**test_params) 220 | buffer = "" 221 | async for chunk in stream: 222 | buffer += chunk.content # type: ignore 223 | 224 | return buffer 225 | 226 | response = loop.run_until_complete(iterate()) 227 | assert response.strip() == "Yes, this is a test." 228 | 229 | 230 | def test_chat_completion_kwargs( 231 | test_llm_json: AzureAIChatCompletionsModel, 232 | ) -> None: 233 | """Tests chat completions using extra parameters.""" 234 | test_llm_json.model_kwargs.update({"response_format": {"type": "json_object"}}) 235 | response = test_llm_json.invoke( 236 | [ 237 | SystemMessage( 238 | content="You are a helpful assistant. When you are asked about if " 239 | "this is a test, you always reply 'Yes, this is a test.' in a JSON " 240 | "object with key 'message'.", 241 | ), 242 | HumanMessage(content="Is this a test?"), 243 | ], 244 | temperature=0.0, 245 | top_p=1.0, 246 | ) 247 | 248 | assert isinstance(response, AIMessage) 249 | if isinstance(response.content, str): 250 | assert ( 251 | json.loads(response.content.strip()).get("message") 252 | == "Yes, this is a test." 253 | ) 254 | 255 | 256 | def test_chat_completion_with_tools( 257 | test_llm_tools: AzureAIChatCompletionsModel, 258 | ) -> None: 259 | """Tests the chat completion functionality with the help of tools.""" 260 | 261 | def echo(message: str) -> str: 262 | """Echoes the user's message. 263 | 264 | Args: 265 | message: The message to echo 266 | """ 267 | print("Echo: " + message) 268 | return message 269 | 270 | model_with_tools = test_llm_tools.bind_tools([echo]) 271 | 272 | response = model_with_tools.invoke( 273 | [ 274 | SystemMessage( 275 | content="You are an assistant that always echoes the user's message. 
" 276 | "To echo a message, use the 'Echo' tool.", 277 | ), 278 | HumanMessage(content="Is this a test?"), 279 | ] 280 | ) 281 | 282 | assert isinstance(response, AIMessage) 283 | assert len(response.additional_kwargs["tool_calls"]) == 1 284 | assert response.additional_kwargs["tool_calls"][0]["name"] == "echo" 285 | 286 | 287 | def test_with_structured_output_json_mode( 288 | test_llm_json: AzureAIChatCompletionsModel, 289 | ) -> None: 290 | """Tests with_structured_output using method='json_mode'.""" 291 | # The schema is not actually used by the model in json_mode, but for 292 | # completeness, pass a dict. 293 | schema = {"type": "object", "properties": {"message": {"type": "string"}}} 294 | 295 | runnable = test_llm_json.with_structured_output(schema, method="json_mode") 296 | 297 | messages = [ 298 | SystemMessage( 299 | content="You are a helpful assistant. When you are asked if this is " 300 | "a test, reply with a JSON object with key 'message'." 301 | ), 302 | HumanMessage(content="Is this a test?"), 303 | ] 304 | 305 | response = runnable.invoke(messages) 306 | # The output should be a dict after parsing 307 | assert isinstance(response, dict) 308 | assert response.get("message") == "Yes, this is a test." 309 | 310 | 311 | @pytest.mark.skipif( 312 | not { 313 | "AZURE_INFERENCE_ENDPOINT", 314 | "AZURE_INFERENCE_CREDENTIAL", 315 | }.issubset(set(os.environ)), 316 | reason="Azure AI endpoint and/or credential are not set.", 317 | ) 318 | def test_chat_completion_gpt4o_api_version(test_params: dict) -> None: 319 | """Test chat completions endpoint with api_version indicated for a GPT model.""" 320 | # In case the endpoint being tested serves more than one model 321 | model_name = os.environ.get("AZURE_INFERENCE_MODEL", "gpt-4o") 322 | 323 | llm = AzureAIChatCompletionsModel( 324 | model=model_name, api_version="2024-05-01-preview" 325 | ) 326 | 327 | response = llm.invoke(**test_params) 328 | 329 | assert isinstance(response, AIMessage) 330 | if isinstance(response.content, str): 331 | assert response.content.strip() == "Yes, this is a test." 332 | 333 | 334 | def test_get_metadata(test_llm: AzureAIChatCompletionsModel, caplog: Any) -> None: 335 | """Tests if we can get model metadata back from the endpoint. If so, 336 | `_model_name` should not be 'unknown'. Some endpoints may not support this 337 | and in those cases a warning should be logged. 
338 | """ 339 | assert ( 340 | test_llm._model_name != "unknown" 341 | or "does not support model metadata retrieval" in caplog.text 342 | ) 343 | 344 | 345 | def test_format_tool_call_has_function_type() -> None: 346 | tool_call = ToolCall( 347 | id="test-id-123", 348 | name="echo", 349 | args=json.loads('{"message": "Is this a test?"}'), 350 | ) 351 | result = _format_tool_call_for_azure_inference(tool_call) 352 | assert result.get("type") == "function" 353 | assert result.get("function", {}).get("name") == "echo" 354 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/unit_tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from unittest import mock 3 | 4 | # import aiohttp to force Pants to include it in the required dependencies 5 | import aiohttp # noqa 6 | import pytest 7 | from azure.ai.inference.models import EmbeddingItem, EmbeddingsResult 8 | from langchain_core.documents import Document 9 | from langchain_core.vectorstores import InMemoryVectorStore 10 | 11 | from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel 12 | 13 | 14 | @pytest.fixture() 15 | def test_embed_model() -> AzureAIEmbeddingsModel: 16 | with mock.patch( 17 | "langchain_azure_ai.embeddings.inference.EmbeddingsClient", autospec=True 18 | ): 19 | embed_model = AzureAIEmbeddingsModel( 20 | endpoint="https://my-endpoint.inference.ai.azure.com", 21 | credential="my-api-key", 22 | model="my_model_name", 23 | ) 24 | embed_model._client.embed.return_value = EmbeddingsResult( # type: ignore 25 | data=[EmbeddingItem(embedding=[1.0, 2.0, 3.0], index=0)] 26 | ) 27 | return embed_model 28 | 29 | 30 | def test_embed(test_embed_model: AzureAIEmbeddingsModel) -> None: 31 | """Test the basic embedding functionality.""" 32 | # In case the endpoint being tested serves more than one model 33 | documents = [ 34 | Document( 35 | id="1", 36 | page_content="Before college the two main things I worked on, " 37 | "outside of school, were writing and programming.", 38 | ) 39 | ] 40 | vector_store = InMemoryVectorStore(test_embed_model) 41 | vector_store.add_documents(documents=documents) 42 | 43 | results = vector_store.similarity_search(query="Before college", k=1) 44 | 45 | assert len(results) == len(documents) 46 | assert results[0].page_content == documents[0].page_content 47 | 48 | 49 | def test_get_metadata(test_embed_model: AzureAIEmbeddingsModel, caplog: Any) -> None: 50 | """Tests if we can get model metadata back from the endpoint. If so, 51 | model_name should not be 'unknown'. Some endpoints may not support this 52 | and in those cases a warning should be logged. 
53 | """ 54 | assert ( 55 | test_embed_model.model_name != "unknown" 56 | or "does not support model metadata retrieval" in caplog.text 57 | ) 58 | -------------------------------------------------------------------------------- /libs/azure-ai/tests/unit_tests/test_queryconstructor_no_sql.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from langchain_core.structured_query import ( 3 | Comparator, 4 | Comparison, 5 | Operation, 6 | Operator, 7 | StructuredQuery, 8 | ) 9 | 10 | from langchain_azure_ai.query_constructors.cosmosdb_no_sql import ( 11 | AzureCosmosDbNoSQLTranslator, 12 | ) 13 | 14 | 15 | def test_visit_structured_query_basic() -> None: 16 | constructor = AzureCosmosDbNoSQLTranslator() 17 | structured_query = StructuredQuery(query="my search terms", limit=None, filter=None) 18 | query, filter = constructor.visit_structured_query(structured_query) 19 | assert query == "my search terms" 20 | assert filter == {} 21 | 22 | 23 | def test_visit_structured_query_with_limit() -> None: 24 | constructor = AzureCosmosDbNoSQLTranslator(table_name="t") 25 | structured_query = StructuredQuery(query="my search terms", limit=10, filter=None) 26 | query, filter = constructor.visit_structured_query(structured_query) 27 | assert query == "my search terms" 28 | assert filter == {} 29 | 30 | 31 | def test_visit_structured_query_with_filter() -> None: 32 | constructor = AzureCosmosDbNoSQLTranslator() 33 | comparison = Comparison(attribute="age", comparator=Comparator.GT, value=30) 34 | structured_query = StructuredQuery(query="my search", limit=None, filter=comparison) 35 | query, filter = constructor.visit_structured_query(structured_query) 36 | assert query == "my search" 37 | assert filter == {"where": "c.age > 30"} 38 | 39 | 40 | def test_visit_comparison_basic() -> None: 41 | constructor = AzureCosmosDbNoSQLTranslator() 42 | comparison = Comparison(attribute="age", comparator=Comparator.GT, value=30) 43 | result = constructor.visit_comparison(comparison) 44 | assert result == "c.age > 30" 45 | 46 | 47 | def test_visit_comparison_with_string() -> None: 48 | constructor = AzureCosmosDbNoSQLTranslator() 49 | comparison = Comparison(attribute="name", comparator=Comparator.EQ, value="John") 50 | result = constructor.visit_comparison(comparison) 51 | assert result == "c.name = 'John'" 52 | 53 | 54 | def test_visit_comparison_with_list() -> None: 55 | constructor = AzureCosmosDbNoSQLTranslator() 56 | comparison = Comparison( 57 | attribute="age", comparator=Comparator.IN, value=[25, 30, 35] 58 | ) 59 | result = constructor.visit_comparison(comparison) 60 | assert result == "c.age IN (25, 30, 35)" 61 | 62 | 63 | def test_visit_comparison_unsupported_operator() -> None: 64 | constructor = AzureCosmosDbNoSQLTranslator() 65 | comparison = Comparison(attribute="age", comparator=Comparator.CONTAIN, value=30) 66 | with pytest.raises(ValueError, match="Unsupported operator"): 67 | constructor.visit_comparison(comparison) 68 | 69 | 70 | def test_visit_operation_basic() -> None: 71 | constructor = AzureCosmosDbNoSQLTranslator() 72 | operation = Operation( 73 | operator=Operator.AND, 74 | arguments=[ 75 | Comparison(attribute="age", comparator=Comparator.GT, value=30), 76 | Comparison(attribute="name", comparator=Comparator.EQ, value="John"), 77 | ], 78 | ) 79 | result = constructor.visit_operation(operation) 80 | assert result == "(c.age > 30 AND c.name = 'John')" 81 | 82 | 83 | def test_visit_operation_not() -> None: 84 | constructor = 
AzureCosmosDbNoSQLTranslator() 85 | operation = Operation( 86 | operator=Operator.NOT, 87 | arguments=[Comparison(attribute="age", comparator=Comparator.GT, value=30)], 88 | ) 89 | result = constructor.visit_operation(operation) 90 | assert result == "NOT (c.age > 30)" 91 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | test: 10 | poetry run pytest $(TEST_FILE) 11 | 12 | tests: 13 | poetry run pytest $(TEST_FILE) 14 | 15 | test_watch: 16 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 17 | 18 | 19 | ###################### 20 | # LINTING AND FORMATTING 21 | ###################### 22 | 23 | # Define a variable for Python and notebook files. 24 | PYTHON_FILES=. 25 | MYPY_CACHE=.mypy_cache 26 | lint format: PYTHON_FILES=. 
27 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 28 | lint_package: PYTHON_FILES=langchain_azure_dynamic_sessions 29 | lint_tests: PYTHON_FILES=tests 30 | lint_tests: MYPY_CACHE=.mypy_cache_test 31 | 32 | lint lint_diff lint_package lint_tests: 33 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES) 34 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff 35 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 36 | 37 | format format_diff: 38 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) 39 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES) 40 | 41 | spell_check: 42 | poetry run codespell --toml pyproject.toml 43 | 44 | spell_fix: 45 | poetry run codespell --toml pyproject.toml -w 46 | 47 | check_imports: $(shell find langchain_azure_dynamic_sessions -name '*.py') 48 | poetry run python ./scripts/check_imports.py $^ 49 | 50 | ###################### 51 | # HELP 52 | ###################### 53 | 54 | help: 55 | @echo '----' 56 | @echo 'check_imports - check imports' 57 | @echo 'format - run code formatters' 58 | @echo 'lint - run linters' 59 | @echo 'test - run unit tests' 60 | @echo 'tests - run unit tests' 61 | @echo 'test TEST_FILE= - run all tests in file' 62 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/README.md: -------------------------------------------------------------------------------- 1 | # langchain-azure-dynamic-sessions 2 | 3 | This package contains the LangChain integration for Azure Container Apps dynamic sessions. You can use it to add a secure and scalable code interpreter to your agents. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install -U langchain-azure-dynamic-sessions 9 | ``` 10 | 11 | ## Usage 12 | 13 | You first need to create an Azure Container Apps session pool and obtain its management endpoint. Then you can use the `SessionsPythonREPLTool` tool to give your agent the ability to execute Python code. 14 | 15 | ```python 16 | from langchain_azure_dynamic_sessions import SessionsPythonREPLTool 17 | 18 | 19 | # get the management endpoint from the session pool in the Azure portal 20 | tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) 21 | 22 | prompt = hub.pull("hwchase17/react") 23 | tools=[tool] 24 | react_agent = create_react_agent( 25 | llm=llm, 26 | tools=tools, 27 | prompt=prompt, 28 | ) 29 | 30 | react_agent_executor = AgentExecutor(agent=react_agent, tools=tools, verbose=True, handle_parsing_errors=True) 31 | 32 | react_agent_executor.invoke({"input": "What is the current time in Vancouver, Canada?"}) 33 | ``` 34 | 35 | By default, the tool uses `DefaultAzureCredential` to authenticate with Azure. If you're using a user-assigned managed identity, you must set the `AZURE_CLIENT_ID` environment variable to the ID of the managed identity. 
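For instance, here is a minimal sketch (the `POOL_MANAGEMENT_ENDPOINT` and client ID values below are placeholders you would supply yourself) of the two ways you might wire up a user-assigned managed identity: setting `AZURE_CLIENT_ID` so that `DefaultAzureCredential` discovers it, or passing your own `access_token_provider` to the tool.

```python
import os

from azure.identity import DefaultAzureCredential
from langchain_azure_dynamic_sessions import SessionsPythonREPLTool

POOL_MANAGEMENT_ENDPOINT = "<your-session-pool-management-endpoint>"  # placeholder
MANAGED_IDENTITY_CLIENT_ID = "<your-managed-identity-client-id>"      # placeholder

# Option 1: let DefaultAzureCredential pick up the user-assigned identity.
os.environ["AZURE_CLIENT_ID"] = MANAGED_IDENTITY_CLIENT_ID
tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT)

# Option 2: supply an explicit token provider for the dynamic sessions scope.
credential = DefaultAzureCredential(
    managed_identity_client_id=MANAGED_IDENTITY_CLIENT_ID
)

def token_provider() -> str:
    # Scope used by Azure Container Apps dynamic sessions.
    return credential.get_token("https://dynamicsessions.io/.default").token

tool = SessionsPythonREPLTool(
    pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT,
    access_token_provider=token_provider,
)

print(tool.invoke("1 + 1"))
```

The `access_token_provider` callable is invoked before each code-execution or file request, so a provider like the one above keeps the tool authenticated as tokens are refreshed.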
36 | 37 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/langchain_azure_dynamic_sessions/__init__.py: -------------------------------------------------------------------------------- 1 | """This package provides tools for managing dynamic sessions in Azure.""" 2 | 3 | from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool 4 | 5 | __all__ = [ 6 | "SessionsPythonREPLTool", 7 | ] 8 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/langchain_azure_dynamic_sessions/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-dynamic-sessions/langchain_azure_dynamic_sessions/py.typed -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """This package provides tools for managing dynamic sessions in Azure.""" 2 | 3 | from langchain_azure_dynamic_sessions.tools.sessions import SessionsPythonREPLTool 4 | 5 | __all__ = [ 6 | "SessionsPythonREPLTool", 7 | ] 8 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/sessions.py: -------------------------------------------------------------------------------- 1 | """This is the Azure Dynamic Sessions module. 2 | 3 | This module provides the SessionsPythonREPLTool class for 4 | managing dynamic sessions in Azure. 5 | """ 6 | 7 | import importlib.metadata 8 | import json 9 | import os 10 | import re 11 | import urllib 12 | from copy import deepcopy 13 | from dataclasses import dataclass 14 | from datetime import datetime, timedelta, timezone 15 | from io import BytesIO 16 | from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple 17 | from uuid import uuid4 18 | 19 | import requests 20 | from azure.core.credentials import AccessToken 21 | from azure.identity import DefaultAzureCredential 22 | from langchain_core.tools import BaseTool 23 | 24 | try: 25 | _package_version = importlib.metadata.version("langchain-azure-dynamic-sessions") 26 | except importlib.metadata.PackageNotFoundError: 27 | _package_version = "0.0.0" 28 | USER_AGENT = f"langchain-azure-dynamic-sessions/{_package_version} (Language=Python)" 29 | 30 | 31 | def _access_token_provider_factory() -> Callable[[], Optional[str]]: 32 | """Factory function for creating an access token provider function. 33 | 34 | Returns: 35 | Callable[[], Optional[str]]: The access token provider function 36 | """ 37 | access_token: Optional[AccessToken] = None 38 | 39 | def access_token_provider() -> Optional[str]: 40 | nonlocal access_token 41 | if access_token is None or datetime.fromtimestamp( 42 | access_token.expires_on, timezone.utc 43 | ) < datetime.now(timezone.utc) + timedelta(minutes=5): 44 | credential = DefaultAzureCredential() 45 | access_token = credential.get_token("https://dynamicsessions.io/.default") 46 | return access_token.token 47 | 48 | return access_token_provider 49 | 50 | 51 | def _sanitize_input(query: str) -> str: 52 | """Sanitize input to the python REPL. 
53 | 54 | Remove whitespace, backtick & python (if llm mistakes python console as terminal) 55 | 56 | Args: 57 | query: The query to sanitize 58 | 59 | Returns: 60 | str: The sanitized query 61 | """ 62 | # Removes `, whitespace & python from start 63 | query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query) 64 | # Removes whitespace & ` from end 65 | query = re.sub(r"(\s|`)*$", "", query) 66 | return query 67 | 68 | 69 | @dataclass 70 | class RemoteFileMetadata: 71 | """Metadata for a file in the session.""" 72 | 73 | filename: str 74 | """The filename relative to `/mnt/data`.""" 75 | 76 | size_in_bytes: int 77 | """The size of the file in bytes.""" 78 | 79 | @property 80 | def full_path(self) -> str: 81 | """Get the full path of the file.""" 82 | return f"/mnt/data/{self.filename}" 83 | 84 | @staticmethod 85 | def from_dict(data: dict) -> "RemoteFileMetadata": 86 | """Create a RemoteFileMetadata object from a dictionary.""" 87 | properties = data.get("properties", {}) 88 | return RemoteFileMetadata( 89 | filename=properties.get("filename"), 90 | size_in_bytes=properties.get("size"), 91 | ) 92 | 93 | 94 | class SessionsPythonREPLTool(BaseTool): 95 | r"""Azure Dynamic Sessions tool. 96 | 97 | Setup: 98 | Install ``langchain-azure-dynamic-sessions`` and create a session pool, which you can do by following the instructions [here](https://learn.microsoft.com/en-us/azure/container-apps/sessions-code-interpreter?tabs=azure-cli#create-a-session-pool-with-azure-cli). 99 | 100 | .. code-block:: bash 101 | 102 | pip install -U langchain-azure-dynamic-sessions 103 | 104 | .. code-block:: python 105 | 106 | import getpass 107 | 108 | POOL_MANAGEMENT_ENDPOINT = getpass.getpass("Enter the management endpoint of the session pool: ") 109 | 110 | Instantiation: 111 | .. code-block:: python 112 | 113 | from langchain_azure_dynamic_sessions import SessionsPythonREPLTool 114 | 115 | tool = SessionsPythonREPLTool( 116 | pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT 117 | ) 118 | 119 | 120 | Invocation with args: 121 | .. code-block:: python 122 | 123 | tool.invoke("6 * 7") 124 | 125 | .. code-block:: python 126 | 127 | '{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}' 128 | 129 | Invocation with ToolCall: 130 | 131 | .. code-block:: python 132 | 133 | tool.invoke({"args": {"input":"6 * 7"}, "id": "1", "name": tool.name, "type": "tool_call"}) 134 | 135 | .. code-block:: python 136 | 137 | '{\\n "result": 42,\\n "stdout": "",\\n "stderr": ""\\n}' 138 | """ # noqa: E501 139 | 140 | name: str = "Python_REPL" 141 | description: str = ( 142 | "A Python shell. Use this to execute python commands " 143 | "when you need to perform calculations or computations. " 144 | "Input should be a valid python command. " 145 | "Returns a JSON object with the result, stdout, and stderr. " 146 | ) 147 | 148 | sanitize_input: bool = True 149 | """Whether to sanitize input to the python REPL.""" 150 | 151 | pool_management_endpoint: str 152 | """The management endpoint of the session pool. Should end with a '/'.""" 153 | 154 | access_token_provider: Callable[[], Optional[str]] = ( 155 | _access_token_provider_factory() 156 | ) 157 | """A function that returns the access token to use for the session pool.""" 158 | 159 | session_id: str = str(uuid4()) 160 | """The session ID to use for the code interpreter. 
Defaults to a random UUID.""" 161 | 162 | response_format: Literal["content_and_artifact"] = "content_and_artifact" 163 | 164 | def _build_url(self, path: str) -> str: 165 | pool_management_endpoint = self.pool_management_endpoint 166 | if not pool_management_endpoint: 167 | raise ValueError("pool_management_endpoint is not set") 168 | if not pool_management_endpoint.endswith("/"): 169 | pool_management_endpoint += "/" 170 | encoded_session_id = urllib.parse.quote(self.session_id) 171 | query = f"identifier={encoded_session_id}&api-version=2024-02-02-preview" 172 | query_separator = "&" if "?" in pool_management_endpoint else "?" 173 | full_url = pool_management_endpoint + path + query_separator + query 174 | return full_url 175 | 176 | def execute(self, python_code: str) -> Any: 177 | """Execute Python code in the session.""" 178 | if self.sanitize_input: 179 | python_code = _sanitize_input(python_code) 180 | 181 | access_token = self.access_token_provider() 182 | api_url = self._build_url("code/execute") 183 | headers = { 184 | "Authorization": f"Bearer {access_token}", 185 | "Content-Type": "application/json", 186 | "User-Agent": USER_AGENT, 187 | } 188 | body = { 189 | "properties": { 190 | "codeInputType": "inline", 191 | "executionType": "synchronous", 192 | "code": python_code, 193 | } 194 | } 195 | 196 | response = requests.post(api_url, headers=headers, json=body) 197 | response.raise_for_status() 198 | response_json = response.json() 199 | properties = response_json.get("properties", {}) 200 | return properties 201 | 202 | def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]: 203 | response = self.execute(python_code) 204 | 205 | # if the result is an image, remove the base64 data 206 | result = deepcopy(response.get("result")) 207 | if isinstance(result, dict): 208 | if result.get("type") == "image" and "base64_data" in result: 209 | result.pop("base64_data") 210 | 211 | content = json.dumps( 212 | { 213 | "result": result, 214 | "stdout": response.get("stdout"), 215 | "stderr": response.get("stderr"), 216 | }, 217 | indent=2, 218 | ) 219 | return content, response 220 | 221 | def upload_file( 222 | self, 223 | *, 224 | data: Optional[BinaryIO] = None, 225 | remote_file_path: Optional[str] = None, 226 | local_file_path: Optional[str] = None, 227 | ) -> RemoteFileMetadata: 228 | """Upload a file to the session. 229 | 230 | Args: 231 | data: The data to upload. 232 | remote_file_path: The path to upload the file to, relative to 233 | `/mnt/data`. If local_file_path is provided, this is defaulted 234 | to its filename. 235 | local_file_path: The path to the local file to upload. 
236 | 237 | Returns: 238 | RemoteFileMetadata: The metadata for the uploaded file 239 | """ 240 | if data and local_file_path: 241 | raise ValueError("data and local_file_path cannot be provided together") 242 | 243 | if data: 244 | file_data = data 245 | elif local_file_path: 246 | if not remote_file_path: 247 | remote_file_path = os.path.basename(local_file_path) 248 | file_data = open(local_file_path, "rb") 249 | 250 | access_token = self.access_token_provider() 251 | api_url = self._build_url("files/upload") 252 | headers = { 253 | "Authorization": f"Bearer {access_token}", 254 | "User-Agent": USER_AGENT, 255 | } 256 | files = [("file", (remote_file_path, file_data, "application/octet-stream"))] 257 | 258 | response = requests.request( 259 | "POST", api_url, headers=headers, data={}, files=files 260 | ) 261 | response.raise_for_status() 262 | 263 | response_json = response.json() 264 | return RemoteFileMetadata.from_dict(response_json["value"][0]) 265 | 266 | def download_file( 267 | self, *, remote_file_path: str, local_file_path: Optional[str] = None 268 | ) -> BinaryIO: 269 | """Download a file from the session. 270 | 271 | Args: 272 | remote_file_path: The path to download the file from, 273 | relative to `/mnt/data`. 274 | local_file_path: The path to save the downloaded file to. 275 | If not provided, the file is returned as a BufferedReader. 276 | 277 | Returns: 278 | BinaryIO: The data of the downloaded file. 279 | """ 280 | access_token = self.access_token_provider() 281 | encoded_remote_file_path = urllib.parse.quote(remote_file_path) 282 | api_url = self._build_url(f"files/content/{encoded_remote_file_path}") 283 | headers = { 284 | "Authorization": f"Bearer {access_token}", 285 | "User-Agent": USER_AGENT, 286 | } 287 | 288 | response = requests.get(api_url, headers=headers) 289 | response.raise_for_status() 290 | 291 | if local_file_path: 292 | with open(local_file_path, "wb") as f: 293 | f.write(response.content) 294 | 295 | return BytesIO(response.content) 296 | 297 | def list_files(self) -> List[RemoteFileMetadata]: 298 | """List the files in the session. 
299 | 300 | Returns: 301 | list[RemoteFileMetadata]: The metadata for the files in the session 302 | """ 303 | access_token = self.access_token_provider() 304 | api_url = self._build_url("files") 305 | headers = { 306 | "Authorization": f"Bearer {access_token}", 307 | "User-Agent": USER_AGENT, 308 | } 309 | 310 | response = requests.get(api_url, headers=headers) 311 | response.raise_for_status() 312 | 313 | response_json = response.json() 314 | return [RemoteFileMetadata.from_dict(entry) for entry in response_json["value"]] 315 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "langchain-azure-dynamic-sessions" 7 | version = "0.2.0" 8 | description = "An integration package connecting Azure Container Apps dynamic sessions and LangChain" 9 | authors = [] 10 | readme = "README.md" 11 | repository = "https://github.com/langchain-ai/langchain-azure" 12 | license = "MIT" 13 | 14 | [tool.mypy] 15 | disallow_untyped_defs = "True" 16 | 17 | [tool.poetry.urls] 18 | "Source Code" = "https://github.com/langchain-ai/langchain-azure/tree/main/libs/azure-dynamic-sessions" 19 | "Release Notes" = "https://github.com/langchain-ai/langchain-azure/releases" 20 | 21 | [tool.poetry.dependencies] 22 | python = ">=3.9,<4.0" 23 | langchain-core = "^0.3.0" 24 | azure-identity = "^1.16.0" 25 | requests = "^2.31.0" 26 | 27 | [tool.ruff.lint] 28 | select = ["E", "F", "I", "D"] 29 | 30 | [tool.coverage.run] 31 | omit = ["tests/*"] 32 | 33 | [tool.pytest.ini_options] 34 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 35 | markers = [ 36 | "requires: mark tests as requiring a specific library", 37 | "compile: mark placeholder test used to compile integration tests without running them", 38 | ] 39 | asyncio_mode = "auto" 40 | 41 | [tool.poetry.group.test] 42 | optional = true 43 | 44 | [tool.poetry.group.test_integration] 45 | optional = true 46 | 47 | [tool.poetry.group.codespell] 48 | optional = true 49 | 50 | [tool.poetry.group.lint] 51 | optional = true 52 | 53 | [tool.poetry.group.dev] 54 | optional = true 55 | 56 | [tool.ruff.lint.pydocstyle] 57 | convention = "google" 58 | 59 | [tool.ruff.lint.per-file-ignores] 60 | "tests/**" = ["D"] 61 | 62 | [tool.poetry.group.test.dependencies] 63 | pytest = "^7.3.0" 64 | freezegun = "^1.2.2" 65 | pytest-mock = "^3.10.0" 66 | syrupy = "^4.0.2" 67 | pytest-watcher = "^0.3.4" 68 | pytest-asyncio = "^0.21.1" 69 | python-dotenv = "^1.0.1" 70 | # TODO: hack to fix 3.9 builds 71 | cffi = [ 72 | { version = "<1.17.1", python = "<3.10" }, 73 | { version = "*", python = ">=3.10" }, 74 | ] 75 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 76 | 77 | [tool.poetry.group.test_integration.dependencies] 78 | pytest = "^7.3.0" 79 | python-dotenv = "^1.0.1" 80 | 81 | [tool.poetry.group.codespell.dependencies] 82 | codespell = "^2.2.0" 83 | 84 | [tool.poetry.group.lint.dependencies] 85 | ruff = "^0.5" 86 | python-dotenv = "^1.0.1" 87 | pytest = "^7.3.0" 88 | # TODO: hack to fix 3.9 builds 89 | cffi = [ 90 | { version = "<1.17.1", python = "<3.10" }, 91 | { version = "*", python = ">=3.10" }, 92 | ] 93 | 94 | [tool.poetry.group.dev.dependencies] 95 | ipykernel = "^6.29.4" 96 | langchainhub = "^0.1.15" 97 | 
langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 98 | langchain-openai = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/partners/openai"} 99 | 100 | [tool.poetry.group.typing.dependencies] 101 | mypy = "^1.10" 102 | types-requests = "^2.31.0.20240406" 103 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 104 | 105 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | """This module checks for specific import statements in the codebase.""" 2 | 3 | import sys 4 | import traceback 5 | from importlib.machinery import SourceFileLoader 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | has_failure = False 10 | for file in files: 11 | try: 12 | SourceFileLoader("x", file).load_module() 13 | except Exception: 14 | has_failure = True 15 | print(file) 16 | traceback.print_exc() 17 | print() 18 | 19 | sys.exit(1 if has_failure else 0) 20 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-dynamic-sessions/tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-dynamic-sessions/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/integration_tests/data/testdata.txt: -------------------------------------------------------------------------------- 1 | test file content -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest # type: ignore[import-not-found] 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/integration_tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import json 2 | 
import os 3 | from io import BytesIO 4 | 5 | import dotenv # type: ignore[import-not-found] 6 | 7 | from langchain_azure_dynamic_sessions import SessionsPythonREPLTool 8 | 9 | dotenv.load_dotenv() 10 | 11 | POOL_MANAGEMENT_ENDPOINT = os.getenv("AZURE_DYNAMIC_SESSIONS_POOL_MANAGEMENT_ENDPOINT") 12 | TEST_DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "testdata.txt") 13 | TEST_DATA_CONTENT = open(TEST_DATA_PATH, "rb").read() 14 | 15 | 16 | def test_end_to_end() -> None: 17 | tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) # type: ignore[arg-type] 18 | result = tool.run("print('hello world')\n1 + 1") 19 | assert json.loads(result) == { 20 | "result": 2, 21 | "stdout": "hello world\n", 22 | "stderr": "", 23 | } 24 | 25 | # upload file content 26 | uploaded_file1_metadata = tool.upload_file( 27 | remote_file_path="test1.txt", data=BytesIO(b"hello world!!!!!") 28 | ) 29 | assert uploaded_file1_metadata.filename == "test1.txt" 30 | assert uploaded_file1_metadata.size_in_bytes == 16 31 | assert uploaded_file1_metadata.full_path == "/mnt/data/test1.txt" 32 | downloaded_file1 = tool.download_file(remote_file_path="test1.txt") 33 | assert downloaded_file1.read() == b"hello world!!!!!" 34 | 35 | # upload file from buffer 36 | with open(TEST_DATA_PATH, "rb") as f: 37 | uploaded_file2_metadata = tool.upload_file(remote_file_path="test2.txt", data=f) 38 | assert uploaded_file2_metadata.filename == "test2.txt" 39 | downloaded_file2 = tool.download_file(remote_file_path="test2.txt") 40 | assert downloaded_file2.read() == TEST_DATA_CONTENT 41 | 42 | # upload file from disk, specifying remote file path 43 | uploaded_file3_metadata = tool.upload_file( 44 | remote_file_path="test3.txt", local_file_path=TEST_DATA_PATH 45 | ) 46 | assert uploaded_file3_metadata.filename == "test3.txt" 47 | downloaded_file3 = tool.download_file(remote_file_path="test3.txt") 48 | assert downloaded_file3.read() == TEST_DATA_CONTENT 49 | 50 | # upload file from disk, without specifying remote file path 51 | uploaded_file4_metadata = tool.upload_file(local_file_path=TEST_DATA_PATH) 52 | assert uploaded_file4_metadata.filename == os.path.basename(TEST_DATA_PATH) 53 | downloaded_file4 = tool.download_file( 54 | remote_file_path=uploaded_file4_metadata.filename 55 | ) 56 | assert downloaded_file4.read() == TEST_DATA_CONTENT 57 | 58 | # list files 59 | remote_files_metadata = tool.list_files() 60 | assert len(remote_files_metadata) == 4 61 | remote_file_paths = [metadata.filename for metadata in remote_files_metadata] 62 | expected_filenames = [ 63 | "test1.txt", 64 | "test2.txt", 65 | "test3.txt", 66 | os.path.basename(TEST_DATA_PATH), 67 | ] 68 | assert set(remote_file_paths) == set(expected_filenames) 69 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/azure-dynamic-sessions/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_azure_dynamic_sessions import __all__ 2 | 3 | EXPECTED_ALL = [ 4 | "SessionsPythonREPLTool", 5 | ] 6 | 7 | 8 | def test_all_imports() -> None: 9 | assert 
sorted(EXPECTED_ALL) == sorted(__all__) 10 | -------------------------------------------------------------------------------- /libs/azure-dynamic-sessions/tests/unit_tests/test_sessions_python_repl_tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import time 4 | from unittest import mock 5 | from urllib.parse import parse_qs, urlparse 6 | 7 | from azure.core.credentials import AccessToken 8 | 9 | from langchain_azure_dynamic_sessions import SessionsPythonREPLTool 10 | from langchain_azure_dynamic_sessions.tools.sessions import ( 11 | _access_token_provider_factory, 12 | ) 13 | 14 | POOL_MANAGEMENT_ENDPOINT = "https://westus2.dynamicsessions.io/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/sessions-rg/sessionPools/my-pool" 15 | 16 | 17 | def test_default_access_token_provider_returns_token() -> None: 18 | access_token_provider = _access_token_provider_factory() 19 | with mock.patch( 20 | "azure.identity.DefaultAzureCredential.get_token" 21 | ) as mock_get_token: 22 | mock_get_token.return_value = AccessToken("token_value", 0) 23 | access_token = access_token_provider() 24 | assert access_token == "token_value" 25 | 26 | 27 | def test_default_access_token_provider_returns_cached_token() -> None: 28 | access_token_provider = _access_token_provider_factory() 29 | with mock.patch( 30 | "azure.identity.DefaultAzureCredential.get_token" 31 | ) as mock_get_token: 32 | mock_get_token.return_value = AccessToken( 33 | "token_value", int(time.time() + 1000) 34 | ) 35 | access_token = access_token_provider() 36 | assert access_token == "token_value" 37 | assert mock_get_token.call_count == 1 38 | 39 | mock_get_token.return_value = AccessToken( 40 | "new_token_value", int(time.time() + 1000) 41 | ) 42 | access_token = access_token_provider() 43 | assert access_token == "token_value" 44 | assert mock_get_token.call_count == 1 45 | 46 | 47 | def test_default_access_token_provider_refreshes_expiring_token() -> None: 48 | access_token_provider = _access_token_provider_factory() 49 | with mock.patch( 50 | "azure.identity.DefaultAzureCredential.get_token" 51 | ) as mock_get_token: 52 | mock_get_token.return_value = AccessToken("token_value", int(time.time() - 1)) 53 | access_token = access_token_provider() 54 | assert access_token == "token_value" 55 | assert mock_get_token.call_count == 1 56 | 57 | mock_get_token.return_value = AccessToken( 58 | "new_token_value", int(time.time() + 1000) 59 | ) 60 | access_token = access_token_provider() 61 | assert access_token == "new_token_value" 62 | assert mock_get_token.call_count == 2 63 | 64 | 65 | @mock.patch("requests.post") 66 | @mock.patch("azure.identity.DefaultAzureCredential.get_token") 67 | def test_code_execution_calls_api( 68 | mock_get_token: mock.MagicMock, mock_post: mock.MagicMock 69 | ) -> None: 70 | tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) 71 | mock_post.return_value.json.return_value = { 72 | "$id": "1", 73 | "properties": { 74 | "$id": "2", 75 | "status": "Success", 76 | "stdout": "hello world\n", 77 | "stderr": "", 78 | "result": "", 79 | "executionTimeInMilliseconds": 33, 80 | }, 81 | } 82 | mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000)) 83 | 84 | result = tool.run("print('hello world')") 85 | 86 | assert json.loads(result) == { 87 | "result": "", 88 | "stdout": "hello world\n", 89 | "stderr": "", 90 | } 91 | 92 | api_url = f"{POOL_MANAGEMENT_ENDPOINT}/code/execute" 93 | headers = 
{ 94 | "Authorization": "Bearer token_value", 95 | "Content-Type": "application/json", 96 | "User-Agent": mock.ANY, 97 | } 98 | body = { 99 | "properties": { 100 | "codeInputType": "inline", 101 | "executionType": "synchronous", 102 | "code": "print('hello world')", 103 | } 104 | } 105 | mock_post.assert_called_once_with(mock.ANY, headers=headers, json=body) 106 | 107 | called_headers = mock_post.call_args.kwargs["headers"] 108 | assert re.match( 109 | r"^langchain-azure-dynamic-sessions/\d+\.\d+\.\d+.* \(Language=Python\)", 110 | called_headers["User-Agent"], 111 | ) 112 | 113 | called_api_url = mock_post.call_args.args[0] 114 | assert called_api_url.startswith(api_url) 115 | 116 | 117 | @mock.patch("requests.post") 118 | @mock.patch("azure.identity.DefaultAzureCredential.get_token") 119 | def test_uses_specified_session_id( 120 | mock_get_token: mock.MagicMock, mock_post: mock.MagicMock 121 | ) -> None: 122 | tool = SessionsPythonREPLTool( 123 | pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, 124 | session_id="00000000-0000-0000-0000-000000000003", 125 | ) 126 | mock_post.return_value.json.return_value = { 127 | "$id": "1", 128 | "properties": { 129 | "$id": "2", 130 | "status": "Success", 131 | "stdout": "", 132 | "stderr": "", 133 | "result": "2", 134 | "executionTimeInMilliseconds": 33, 135 | }, 136 | } 137 | mock_get_token.return_value = AccessToken("token_value", int(time.time() + 1000)) 138 | tool.run("1 + 1") 139 | call_url = mock_post.call_args.args[0] 140 | parsed_url = urlparse(call_url) 141 | call_identifier = parse_qs(parsed_url.query)["identifier"][0] 142 | assert call_identifier == "00000000-0000-0000-0000-000000000003" 143 | 144 | 145 | def test_sanitizes_input() -> None: 146 | tool = SessionsPythonREPLTool(pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT) 147 | with mock.patch("requests.post") as mock_post: 148 | mock_post.return_value.json.return_value = { 149 | "$id": "1", 150 | "properties": { 151 | "$id": "2", 152 | "status": "Success", 153 | "stdout": "", 154 | "stderr": "", 155 | "result": "", 156 | "executionTimeInMilliseconds": 33, 157 | }, 158 | } 159 | tool.run("```python\nprint('hello world')\n```") 160 | body = mock_post.call_args.kwargs["json"] 161 | assert body["properties"]["code"] == "print('hello world')" 162 | 163 | 164 | def test_does_not_sanitize_input() -> None: 165 | tool = SessionsPythonREPLTool( 166 | pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, sanitize_input=False 167 | ) 168 | with mock.patch("requests.post") as mock_post: 169 | mock_post.return_value.json.return_value = { 170 | "$id": "1", 171 | "properties": { 172 | "$id": "2", 173 | "status": "Success", 174 | "stdout": "", 175 | "stderr": "", 176 | "result": "", 177 | "executionTimeInMilliseconds": 33, 178 | }, 179 | } 180 | tool.run("```python\nprint('hello world')\n```") 181 | body = mock_post.call_args.kwargs["json"] 182 | assert body["properties"]["code"] == "```python\nprint('hello world')\n```" 183 | 184 | 185 | def test_uses_custom_access_token_provider() -> None: 186 | def custom_access_token_provider() -> str: 187 | return "custom_token" 188 | 189 | tool = SessionsPythonREPLTool( 190 | pool_management_endpoint=POOL_MANAGEMENT_ENDPOINT, 191 | access_token_provider=custom_access_token_provider, 192 | ) 193 | 194 | with mock.patch("requests.post") as mock_post: 195 | mock_post.return_value.json.return_value = { 196 | "$id": "1", 197 | "properties": { 198 | "$id": "2", 199 | "status": "Success", 200 | "stdout": "", 201 | "stderr": "", 202 | "result": "", 203 | 
"executionTimeInMilliseconds": 33, 204 | }, 205 | } 206 | tool.run("print('hello world')") 207 | headers = mock_post.call_args.kwargs["headers"] 208 | assert headers["Authorization"] == "Bearer custom_token" 209 | -------------------------------------------------------------------------------- /libs/sqlserver/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ruff_cache 3 | .pytest_cache 4 | -------------------------------------------------------------------------------- /libs/sqlserver/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/sqlserver/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | test: 10 | poetry run pytest $(TEST_FILE) 11 | 12 | tests: 13 | poetry run pytest $(TEST_FILE) 14 | 15 | test_watch: 16 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 17 | 18 | 19 | ###################### 20 | # LINTING AND FORMATTING 21 | ###################### 22 | 23 | # Define a variable for Python and notebook files. 24 | PYTHON_FILES=. 25 | MYPY_CACHE=.mypy_cache 26 | lint format: PYTHON_FILES=. 
27 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/azure --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 28 | lint_package: PYTHON_FILES=langchain_sqlserver 29 | lint_tests: PYTHON_FILES=tests 30 | lint_tests: MYPY_CACHE=.mypy_cache_test 31 | 32 | lint lint_diff lint_package lint_tests: 33 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES) 34 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff 35 | [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 36 | 37 | format format_diff: 38 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) 39 | [ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES) 40 | 41 | spell_check: 42 | poetry run codespell --toml pyproject.toml 43 | 44 | spell_fix: 45 | poetry run codespell --toml pyproject.toml -w 46 | 47 | check_imports: $(shell find langchain_sqlserver -name '*.py') 48 | poetry run python ./scripts/check_imports.py $^ 49 | 50 | ###################### 51 | # HELP 52 | ###################### 53 | 54 | help: 55 | @echo '----' 56 | @echo 'check_imports - check imports' 57 | @echo 'format - run code formatters' 58 | @echo 'lint - run linters' 59 | @echo 'test - run unit tests' 60 | @echo 'tests - run unit tests' 61 | @echo 'test TEST_FILE= - run all tests in file' 62 | -------------------------------------------------------------------------------- /libs/sqlserver/README.md: -------------------------------------------------------------------------------- 1 | # langchain-sqlserver 2 | 3 | This package contains the LangChain integration for Azure SQL and SQL Server. 4 | 5 | > [!NOTE] 6 | > Vector Functions are in Public Preview. Learn the details about vectors in Azure SQL here: https://aka.ms/azure-sql-vector-public-preview 7 | 8 | ## Installation 9 | 10 | ```bash 11 | pip install -U langchain-sqlserver 12 | ``` 13 | -------------------------------------------------------------------------------- /libs/sqlserver/langchain_sqlserver/__init__.py: -------------------------------------------------------------------------------- 1 | """LangChain integration for SQL Server.""" 2 | 3 | from langchain_sqlserver.vectorstores import SQLServer_VectorStore 4 | 5 | __all__ = [ 6 | "SQLServer_VectorStore", 7 | ] 8 | -------------------------------------------------------------------------------- /libs/sqlserver/langchain_sqlserver/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/sqlserver/langchain_sqlserver/py.typed -------------------------------------------------------------------------------- /libs/sqlserver/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-sqlserver" 3 | version = "0.1.2" 4 | description = "An integration package to support SQL Server in LangChain." 
5 | authors = [] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9,<4.0" 11 | SQLAlchemy = ">=2.0.0,<3" 12 | azure-identity = "^1.16.0" 13 | langchain-core = "^0.3.0" 14 | pyodbc = ">=5.0.0,<6.0.0" 15 | numpy = "^1" 16 | 17 | [tool.poetry.group.codespell.dependencies] 18 | codespell = "^2.2.0" 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 22 | 23 | [tool.poetry.group.lint.dependencies] 24 | ruff = "^0.5" 25 | python-dotenv = "^1.0.1" 26 | pytest = "^7.4.3" 27 | 28 | [tool.poetry.group.test.dependencies] 29 | pydantic = "^2.9.2" 30 | pytest = "^7.4.3" 31 | pytest-mock = "^3.10.0" 32 | pytest-watcher = "^0.3.4" 33 | pytest-asyncio = "^0.21.1" 34 | python-dotenv = "^1.0.1" 35 | syrupy = "^4.7.2" 36 | langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"} 37 | langchain-text-splitters = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/text-splitters"} 38 | 39 | [tool.poetry.group.test_integration.dependencies] 40 | pytest = "^7.3.0" 41 | python-dotenv = "^1.0.1" 42 | 43 | [tool.poetry.urls] 44 | "Source Code" = "https://github.com/langchain-ai/langchain-azure/tree/main/libs/sqlserver" 45 | "Release Notes" = "https://github.com/langchain-ai/langchain-azure/releases" 46 | 47 | [tool.mypy] 48 | disallow_untyped_defs = "True" 49 | 50 | [tool.poetry.group.typing.dependencies] 51 | mypy = "^1.10" 52 | 53 | [tool.ruff.lint] 54 | select = ["E", "F", "I", "D"] 55 | 56 | [tool.coverage.run] 57 | omit = ["tests/*"] 58 | 59 | [tool.pytest.ini_options] 60 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 61 | markers = [ 62 | "requires: mark tests as requiring a specific library", 63 | "compile: mark placeholder test used to compile integration tests without running them", 64 | ] 65 | asyncio_mode = "auto" 66 | 67 | [tool.poetry.group.test] 68 | optional = true 69 | 70 | [tool.poetry.group.test_integration] 71 | optional = true 72 | 73 | [tool.poetry.group.codespell] 74 | optional = true 75 | 76 | [tool.poetry.group.lint] 77 | optional = true 78 | 79 | [tool.poetry.group.dev] 80 | optional = true 81 | 82 | [tool.ruff.lint.pydocstyle] 83 | convention = "google" 84 | 85 | [tool.ruff.lint.per-file-ignores] 86 | "tests/**" = ["D"] 87 | 88 | [build-system] 89 | requires = ["poetry-core"] 90 | build-backend = "poetry.core.masonry.api" 91 | -------------------------------------------------------------------------------- /libs/sqlserver/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | """This module checks for specific import statements in the codebase.""" 2 | 3 | import sys 4 | import traceback 5 | from importlib.machinery import SourceFileLoader 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | has_failure = False 10 | for file in files: 11 | try: 12 | SourceFileLoader("x", file).load_module() 13 | except Exception: 14 | has_failure = True 15 | print(file) 16 | traceback.print_exc() 17 | print() 18 | 19 | sys.exit(1 if has_failure else 0) 20 | -------------------------------------------------------------------------------- /libs/sqlserver/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing 
from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/sqlserver/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/sqlserver/tests/__init__.py -------------------------------------------------------------------------------- /libs/sqlserver/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/sqlserver/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/sqlserver/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/sqlserver/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-azure/f0412aa74539ee8fb826fbe51e3d0da703b8dd89/libs/sqlserver/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/sqlserver/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_sqlserver import __all__ 2 | 3 | EXPECTED_ALL = [ 4 | "SQLServer_VectorStore", 5 | ] 6 | 7 | 8 | def test_all_imports() -> None: 9 | assert sorted(EXPECTED_ALL) == sorted(__all__) 10 | -------------------------------------------------------------------------------- /libs/sqlserver/tests/unit_tests/test_vectorstores.py: -------------------------------------------------------------------------------- 1 | """Test SQLServer_VectorStore functionality.""" 2 | 3 | import os 4 | from unittest import mock 5 | from unittest.mock import Mock 6 | 7 | import pytest 8 | 9 | from langchain_sqlserver.vectorstores import SQLServer_VectorStore 10 | from tests.utils.fake_embeddings import DeterministicFakeEmbedding 11 | 12 | pytest.skip( 13 | "Skipping these tests pending resource availability", allow_module_level=True 14 | ) 15 | 16 | # Connection String values should be provided in the 17 | # environment running this test suite. 
18 | # 19 | _CONNECTION_STRING_WITH_UID_AND_PWD = str( 20 | os.environ.get("TEST_AZURESQLSERVER_CONNECTION_STRING_WITH_UID") 21 | ) 22 | _CONNECTION_STRING_WITH_TRUSTED_CONNECTION = str( 23 | os.environ.get("TEST_AZURESQLSERVER_TRUSTED_CONNECTION") 24 | ) 25 | _ENTRA_ID_CONNECTION_STRING_NO_PARAMS = str( 26 | os.environ.get("TEST_ENTRA_ID_CONNECTION_STRING_NO_PARAMS") 27 | ) 28 | _ENTRA_ID_CONNECTION_STRING_TRUSTED_CONNECTION_NO = str( 29 | os.environ.get("TEST_ENTRA_ID_CONNECTION_STRING_TRUSTED_CONNECTION_NO") 30 | ) 31 | _TABLE_NAME = "langchain_vector_store_tests" 32 | EMBEDDING_LENGTH = 1536 33 | 34 | 35 | # We need to mock this so that actual connection is not attempted 36 | # after mocking _provide_token. 37 | @mock.patch("sqlalchemy.dialects.mssql.dialect.initialize") 38 | @mock.patch("langchain_sqlserver.vectorstores.SQLServer_VectorStore._provide_token") 39 | @mock.patch( 40 | "langchain_sqlserver.vectorstores.SQLServer_VectorStore._prepare_json_data_type" 41 | ) 42 | def test_that_given_a_valid_entra_id_connection_string_entra_id_authentication_is_used( 43 | prep_data_type: Mock, 44 | provide_token: Mock, 45 | dialect_initialize: Mock, 46 | ) -> None: 47 | """Test that if a valid entra_id connection string is passed in 48 | to SQLServer_VectorStore object, entra id authentication is used 49 | and connection is successful.""" 50 | 51 | # Connection string is of the form below. 52 | # "mssql+pyodbc://lc-test.database.windows.net,1433/lcvectorstore 53 | # ?driver=ODBC+Driver+17+for+SQL+Server" 54 | store = connect_to_vector_store(_ENTRA_ID_CONNECTION_STRING_NO_PARAMS) 55 | # _provide_token is called only during Entra ID authentication. 56 | provide_token.assert_called() 57 | store.drop() 58 | 59 | # reset the mock so that it can be reused. 60 | provide_token.reset_mock() 61 | 62 | # "mssql+pyodbc://lc-test.database.windows.net,1433/lcvectorstore 63 | # ?driver=ODBC+Driver+17+for+SQL+Server&Trusted_Connection=no" 64 | store = connect_to_vector_store(_ENTRA_ID_CONNECTION_STRING_TRUSTED_CONNECTION_NO) 65 | provide_token.assert_called() 66 | store.drop() 67 | 68 | 69 | # We need to mock this so that actual connection is not attempted 70 | # after mocking _provide_token. 71 | @mock.patch("sqlalchemy.dialects.mssql.dialect.initialize") 72 | @mock.patch("langchain_sqlserver.vectorstores.SQLServer_VectorStore._provide_token") 73 | @mock.patch( 74 | "langchain_sqlserver.vectorstores.SQLServer_VectorStore._prepare_json_data_type" 75 | ) 76 | def test_that_given_a_connection_string_with_uid_and_pwd_entra_id_auth_is_not_used( 77 | prep_data_type: Mock, 78 | provide_token: Mock, 79 | dialect_initialize: Mock, 80 | ) -> None: 81 | """Test that if a connection string is provided to SQLServer_VectorStore object, 82 | and connection string has username and password, entra id authentication is not 83 | used and connection is successful.""" 84 | 85 | # Connection string contains username and password, 86 | # mssql+pyodbc://username:password@lc-test.database.windows.net,1433/lcvectorstore 87 | # ?driver=ODBC+Driver+17+for+SQL+Server" 88 | store = connect_to_vector_store(_CONNECTION_STRING_WITH_UID_AND_PWD) 89 | # _provide_token is called only during Entra ID authentication. 90 | provide_token.assert_not_called() 91 | store.drop() 92 | 93 | 94 | # We need to mock this so that actual connection is not attempted 95 | # after mocking _provide_token. 
96 | @mock.patch("sqlalchemy.dialects.mssql.dialect.initialize") 97 | @mock.patch("langchain_sqlserver.vectorstores.SQLServer_VectorStore._provide_token") 98 | @mock.patch( 99 | "langchain_sqlserver.vectorstores.SQLServer_VectorStore._prepare_json_data_type" 100 | ) 101 | def test_that_connection_string_with_trusted_connection_yes_does_not_use_entra_id_auth( 102 | prep_data_type: Mock, 103 | provide_token: Mock, 104 | dialect_initialize: Mock, 105 | ) -> None: 106 | """Test that if a connection string is provided to SQLServer_VectorStore object, 107 | and connection string has `trusted_connection` set to `yes`, entra id 108 | authentication is not used and connection is successful.""" 109 | 110 | # Connection string is of the form below. 111 | # mssql+pyodbc://@lc-test.database.windows.net,1433/lcvectorstore 112 | # ?driver=ODBC+Driver+17+for+SQL+Server&trusted_connection=yes" 113 | store = connect_to_vector_store(_CONNECTION_STRING_WITH_TRUSTED_CONNECTION) 114 | # _provide_token is called only during Entra ID authentication. 115 | provide_token.assert_not_called() 116 | store.drop() 117 | 118 | 119 | def connect_to_vector_store(conn_string: str) -> SQLServer_VectorStore: 120 | return SQLServer_VectorStore( 121 | connection_string=conn_string, 122 | embedding_length=EMBEDDING_LENGTH, 123 | # DeterministicFakeEmbedding returns embeddings of the same 124 | # size as `embedding_length`. 125 | embedding_function=DeterministicFakeEmbedding(size=EMBEDDING_LENGTH), 126 | table_name=_TABLE_NAME, 127 | ) 128 | -------------------------------------------------------------------------------- /libs/sqlserver/tests/utils/fake_embeddings.py: -------------------------------------------------------------------------------- 1 | """Copied from LangChain Community.""" 2 | 3 | import hashlib 4 | from typing import List 5 | 6 | import numpy as np 7 | from langchain_core.embeddings import Embeddings 8 | from pydantic import BaseModel 9 | 10 | 11 | class DeterministicFakeEmbedding(Embeddings, BaseModel): 12 | """ 13 | Fake embedding model that always returns 14 | the same embedding vector for the same text. 15 | """ 16 | 17 | size: int 18 | """The size of the embedding vector.""" 19 | 20 | def _get_embedding(self, seed: int) -> List[float]: 21 | # set the seed for the random generator 22 | np.random.seed(seed) 23 | return list(abs(np.random.normal(size=self.size))) 24 | 25 | def _get_seed(self, text: str) -> int: 26 | """ 27 | Get a seed for the random generator, using the hash of the text. 28 | """ 29 | return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**8 30 | 31 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 32 | return [self._get_embedding(seed=self._get_seed(_)) for _ in texts] 33 | 34 | def embed_query(self, text: str) -> List[float]: 35 | return self._get_embedding(seed=self._get_seed(text)) 36 | -------------------------------------------------------------------------------- /libs/sqlserver/tests/utils/filtering_test_cases.py: -------------------------------------------------------------------------------- 1 | """Copied from LangChain community. 2 | 3 | Module contains test cases for testing filtering of documents in vector stores. 
4 | """ 5 | 6 | from langchain_core.documents import Document 7 | 8 | metadatas = [ 9 | { 10 | "name": "adam", 11 | "date": "2021-01-01", 12 | "count": 1, 13 | "is_active": True, 14 | "tags": ["a", "b"], 15 | "location": [1.0, 2.0], 16 | "id": 1, 17 | "height": 10.0, # Float column 18 | "happiness": 0.9, # Float column 19 | "sadness": 0.1, # Float column 20 | }, 21 | { 22 | "name": "bob", 23 | "date": "2021-01-02", 24 | "count": 2, 25 | "is_active": False, 26 | "tags": ["b", "c"], 27 | "location": [2.0, 3.0], 28 | "id": 2, 29 | "height": 5.7, # Float column 30 | "happiness": 0.8, # Float column 31 | "sadness": 0.1, # Float column 32 | }, 33 | { 34 | "name": "jane", 35 | "date": "2021-01-01", 36 | "count": 3, 37 | "is_active": True, 38 | "tags": ["b", "d"], 39 | "location": [3.0, 4.0], 40 | "id": 3, 41 | "height": 2.4, # Float column 42 | "happiness": None, 43 | # Sadness missing intentionally 44 | }, 45 | ] 46 | texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas] 47 | IDS = [str(metadata["id"]) for metadata in metadatas] 48 | DOCUMENTS = [ 49 | Document(page_content=text, metadata=metadata) 50 | for text, metadata in zip(texts, metadatas) 51 | ] 52 | 53 | 54 | TYPE_1_FILTERING_TEST_CASES = [ 55 | # These tests only involve equality checks 56 | ( 57 | {"id": 1}, 58 | [1], 59 | ), 60 | # String field 61 | ( 62 | # check name 63 | {"name": "adam"}, 64 | [1], 65 | ), 66 | # Boolean fields 67 | ( 68 | {"is_active": True}, 69 | [1, 3], 70 | ), 71 | ( 72 | {"is_active": False}, 73 | [2], 74 | ), 75 | # And semantics for top level filtering 76 | ( 77 | {"id": 1, "is_active": True}, 78 | [1], 79 | ), 80 | ( 81 | {"id": 1, "is_active": False}, 82 | [], 83 | ), 84 | ] 85 | 86 | TYPE_2_FILTERING_TEST_CASES = [ 87 | # These involve equality checks and other operators 88 | # like $ne, $gt, $gte, $lt, $lte, $not 89 | ( 90 | {"id": 1}, 91 | [1], 92 | ), 93 | ( 94 | {"id": {"$ne": 1}}, 95 | [2, 3], 96 | ), 97 | ( 98 | {"id": {"$gt": 1}}, 99 | [2, 3], 100 | ), 101 | ( 102 | {"id": {"$gte": 1}}, 103 | [1, 2, 3], 104 | ), 105 | ( 106 | {"id": {"$lt": 1}}, 107 | [], 108 | ), 109 | ( 110 | {"id": {"$lte": 1}}, 111 | [1], 112 | ), 113 | # Repeat all the same tests with name (string column) 114 | ( 115 | {"name": "adam"}, 116 | [1], 117 | ), 118 | ( 119 | {"name": "bob"}, 120 | [2], 121 | ), 122 | ( 123 | {"name": {"$eq": "adam"}}, 124 | [1], 125 | ), 126 | ( 127 | {"name": {"$ne": "adam"}}, 128 | [2, 3], 129 | ), 130 | # And also gt, gte, lt, lte relying on lexicographical ordering 131 | ( 132 | {"name": {"$gt": "jane"}}, 133 | [], 134 | ), 135 | ( 136 | {"name": {"$gte": "jane"}}, 137 | [3], 138 | ), 139 | ( 140 | {"name": {"$lt": "jane"}}, 141 | [1, 2], 142 | ), 143 | ( 144 | {"name": {"$lte": "jane"}}, 145 | [1, 2, 3], 146 | ), 147 | ( 148 | {"is_active": {"$eq": True}}, 149 | [1, 3], 150 | ), 151 | ( 152 | {"is_active": {"$ne": True}}, 153 | [2], 154 | ), 155 | # Test float column. 
156 | ( 157 | {"height": {"$gt": 5.0}}, 158 | [1, 2], 159 | ), 160 | ( 161 | {"height": {"$gte": 5.0}}, 162 | [1, 2], 163 | ), 164 | ( 165 | {"height": {"$lt": 5.0}}, 166 | [3], 167 | ), 168 | ( 169 | {"height": {"$lte": 5.8}}, 170 | [2, 3], 171 | ), 172 | ] 173 | 174 | TYPE_3_FILTERING_TEST_CASES = [ 175 | # These involve usage of AND and OR operators 176 | ( 177 | {"$or": [{"id": 1}, {"id": 2}]}, 178 | [1, 2], 179 | ), 180 | ( 181 | {"$or": [{"id": 1}, {"name": "bob"}]}, 182 | [1, 2], 183 | ), 184 | ( 185 | {"$and": [{"id": 1}, {"id": 2}]}, 186 | [], 187 | ), 188 | ( 189 | {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, 190 | [1, 2, 3], 191 | ), 192 | ] 193 | 194 | TYPE_4_FILTERING_TEST_CASES = [ 195 | # These involve special operators like $in, $nin, $between 196 | # Test between 197 | ( 198 | {"id": {"$between": (1, 2)}}, 199 | [1, 2], 200 | ), 201 | ( 202 | {"id": {"$between": (1, 1)}}, 203 | [1], 204 | ), 205 | ( 206 | {"name": {"$in": ["adam", "bob"]}}, 207 | [1, 2], 208 | ), 209 | ] 210 | 211 | TYPE_5_FILTERING_TEST_CASES = [ 212 | # These involve special operators like $like, $ilike that 213 | # may be specified to certain databases. 214 | ( 215 | {"name": {"$like": "a%"}}, 216 | [1], 217 | ), 218 | ( 219 | {"name": {"$like": "%a%"}}, # adam and jane 220 | [1, 3], 221 | ), 222 | ] 223 | --------------------------------------------------------------------------------
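For orientation, the listing ends with the SQL Server vector store unit tests and their filtering fixtures. Below is a minimal usage sketch (an editorial addition, not a file from the repository) assembled only from what appears in this listing: the `SQLServer_VectorStore` constructor arguments used in `tests/unit_tests/test_vectorstores.py`, the `DeterministicFakeEmbedding` helper from `tests/utils/fake_embeddings.py`, and the filter operator style from `tests/utils/filtering_test_cases.py`. The `add_documents`/`similarity_search` calls and the `filter` keyword follow the generic LangChain vector store interface and are assumptions; the exact query parameters of this implementation are not shown in the files above.

# Minimal sketch, assuming the standard LangChain VectorStore interface.
# Names taken from this listing: SQLServer_VectorStore, DeterministicFakeEmbedding,
# the TEST_AZURESQLSERVER_* environment variable, and the test table name.
import os

from langchain_core.documents import Document

from langchain_sqlserver import SQLServer_VectorStore
from tests.utils.fake_embeddings import DeterministicFakeEmbedding

EMBEDDING_LENGTH = 1536  # vector size used throughout the unit tests

store = SQLServer_VectorStore(
    # Any of the connection-string forms exercised in the unit tests works here
    # (UID/password, trusted connection, or Entra ID without credentials).
    connection_string=os.environ["TEST_AZURESQLSERVER_CONNECTION_STRING_WITH_UID"],
    embedding_length=EMBEDDING_LENGTH,
    # DeterministicFakeEmbedding returns vectors of size EMBEDDING_LENGTH,
    # so no external embedding service is required for this sketch.
    embedding_function=DeterministicFakeEmbedding(size=EMBEDDING_LENGTH),
    table_name="langchain_vector_store_tests",
)

# Index a document, then query it back. The filter dict follows the operator
# style shown in TYPE_2_FILTERING_TEST_CASES ($eq, $ne, $gt, $in, $between, ...);
# the `filter` keyword itself is an assumption, not confirmed by this listing.
store.add_documents(
    [Document(page_content="id 1", metadata={"name": "adam", "id": 1})]
)
results = store.similarity_search("id 1", k=1, filter={"name": {"$eq": "adam"}})

Against a real database, the fake embedding would be replaced by a production embedding model, and the connection string chosen according to the authentication paths covered in the unit tests.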