├── .github ├── actions │ └── poetry_setup │ │ └── action.yml ├── dependabot.yml ├── scripts │ ├── check_diff.py │ └── get_min_versions.py └── workflows │ ├── _codespell.yml │ ├── _compile_integration_test.yml │ ├── _integration_test.yml │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ ├── _test_release.yml │ ├── check_diffs.yml │ └── extract_ignored_words_list.py ├── .gitignore ├── LICENSE ├── README.md ├── libs ├── aws │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── langchain_aws │ │ ├── __init__.py │ │ ├── agents │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── types.py │ │ │ └── utils.py │ │ ├── chains │ │ │ ├── __init__.py │ │ │ └── graph_qa │ │ │ │ ├── __init__.py │ │ │ │ ├── neptune_cypher.py │ │ │ │ ├── neptune_sparql.py │ │ │ │ └── prompts.py │ │ ├── chat_models │ │ │ ├── __init__.py │ │ │ ├── bedrock.py │ │ │ ├── bedrock_converse.py │ │ │ └── sagemaker_endpoint.py │ │ ├── document_compressors │ │ │ ├── __init__.py │ │ │ └── rerank.py │ │ ├── embeddings │ │ │ ├── __init__.py │ │ │ └── bedrock.py │ │ ├── function_calling.py │ │ ├── graphs │ │ │ ├── __init__.py │ │ │ ├── neptune_graph.py │ │ │ └── neptune_rdf_graph.py │ │ ├── llms │ │ │ ├── __init__.py │ │ │ ├── bedrock.py │ │ │ └── sagemaker_endpoint.py │ │ ├── py.typed │ │ ├── retrievers │ │ │ ├── __init__.py │ │ │ ├── bedrock.py │ │ │ └── kendra.py │ │ ├── runnables │ │ │ ├── __init__.py │ │ │ └── q_business.py │ │ ├── utilities │ │ │ ├── math.py │ │ │ ├── redis.py │ │ │ └── utils.py │ │ ├── utils.py │ │ └── vectorstores │ │ │ ├── __init__.py │ │ │ └── inmemorydb │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── cache.py │ │ │ ├── constants.py │ │ │ ├── filters.py │ │ │ └── schema.py │ ├── poetry.lock │ ├── pyproject.toml │ ├── scripts │ │ ├── check_imports.py │ │ └── lint_imports.sh │ └── tests │ │ ├── __init__.py │ │ ├── callbacks.py │ │ ├── integration_tests │ │ ├── __init__.py │ │ ├── agents │ │ │ └── test_bedrock_agents.py │ │ ├── chat_models │ │ │ ├── __init__.py │ │ │ ├── test_bedrock.py │ │ │ ├── test_bedrock_converse.py │ │ │ ├── test_sagemaker_endpoint.py │ │ │ └── test_standard.py │ │ ├── embeddings │ │ │ ├── __init__.py │ │ │ └── test_bedrock_embeddings.py │ │ ├── graphs │ │ │ └── __init__.py │ │ ├── llms │ │ │ ├── __init__.py │ │ │ ├── test_bedrock.py │ │ │ └── test_sagemaker_endpoint.py │ │ ├── retrievers │ │ │ ├── __init__.py │ │ │ ├── test_amazon_kendra_retriever.py │ │ │ └── test_amazon_knowledgebases_retriever.py │ │ └── test_compile.py │ │ └── unit_tests │ │ ├── __init__.py │ │ ├── __snapshots__ │ │ └── test_standard.ambr │ │ ├── agents │ │ ├── test_bedrock_agents.py │ │ └── test_utils.py │ │ ├── chat_models │ │ ├── __init__.py │ │ ├── __snapshots__ │ │ │ └── test_bedrock_converse.ambr │ │ ├── test_bedrock.py │ │ ├── test_bedrock_converse.py │ │ └── test_sagemaker_endpoint.py │ │ ├── document_compressors │ │ ├── __init__.py │ │ └── test_rerank.py │ │ ├── llms │ │ ├── __init__.py │ │ └── test_bedrock.py │ │ ├── retrievers │ │ └── test_bedrock.py │ │ ├── test_imports.py │ │ ├── test_standard.py │ │ └── test_utils.py └── langgraph-checkpoint-aws │ ├── .gitignore │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── langgraph_checkpoint_aws │ ├── __init__.py │ ├── constants.py │ ├── models.py │ ├── saver.py │ ├── session.py │ └── utils.py │ ├── poetry.lock │ ├── pyproject.toml │ └── tests │ ├── __init__.py │ ├── integration_tests │ ├── __init__.py │ ├── saver │ │ ├── __init__.py │ │ └── test_saver.py │ └── 
test_compile.py │ └── unit_tests │ ├── __init__.py │ ├── conftest.py │ ├── test_saver.py │ ├── test_session.py │ └── test_utils.py └── samples ├── agents ├── agents_with_nova.ipynb ├── bedrock_agent_langgraph.ipynb ├── bedrock_agents_code_interpreter.ipynb ├── bedrock_agents_roc.ipynb ├── inline_agent_langraph.ipynb └── inline_agent_runnable_roc.ipynb ├── document_compressors └── rerank.ipynb ├── inmemory ├── memorydb-guide.pdf ├── retriever.ipynb ├── semantic_cache.ipynb └── vectorestore.ipynb └── models └── getting_started_with_nova.ipynb /.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 
79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 86 | with: 87 | path: | 88 | ~/.cache/pip 89 | ~/.cache/pypoetry/virtualenvs 90 | ~/.cache/pypoetry/cache 91 | ~/.cache/pypoetry/artifacts 92 | ${{ env.WORKDIR }}/.venv 93 | key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles(format('{0}/**/poetry.lock', env.WORKDIR)) }} 94 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | groups: 8 | actions-dependencies: 9 | patterns: 10 | - "*" 11 | -------------------------------------------------------------------------------- /.github/scripts/check_diff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import Dict 4 | 5 | LIB_DIRS = ["libs/aws", "libs/langgraph-checkpoint-aws"] 6 | 7 | if __name__ == "__main__": 8 | files = sys.argv[1:] 9 | 10 | dirs_to_run: Dict[str, set] = { 11 | "lint": set(), 12 | "test": set(), 13 | } 14 | 15 | if len(files) == 300: 16 | # max diff length is 300 files - there are likely files missing 17 | raise ValueError("Max diff reached. Please manually run CI on changed libs.") 18 | 19 | for file in files: 20 | if any( 21 | file.startswith(dir_) 22 | for dir_ in ( 23 | ".github/workflows", 24 | ".github/tools", 25 | ".github/actions", 26 | ".github/scripts/check_diff.py", 27 | ) 28 | ): 29 | # add all LIB_DIRS for infra changes 30 | dirs_to_run["test"].update(LIB_DIRS) 31 | 32 | if any(file.startswith(dir_) for dir_ in LIB_DIRS): 33 | for dir_ in LIB_DIRS: 34 | if file.startswith(dir_): 35 | dirs_to_run["test"].add(dir_) 36 | elif file.startswith("libs/"): 37 | raise ValueError( 38 | f"Unknown lib: {file}. check_diff.py likely needs " 39 | "an update for this new library!" 40 | ) 41 | 42 | outputs = { 43 | "dirs-to-lint": list(dirs_to_run["lint"] | dirs_to_run["test"]), 44 | "dirs-to-test": list(dirs_to_run["test"]), 45 | } 46 | for key, value in outputs.items(): 47 | json_output = json.dumps(value) 48 | print(f"{key}={json_output}") # noqa: T201 49 | -------------------------------------------------------------------------------- /.github/scripts/get_min_versions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info >= (3, 11): 4 | import tomllib 5 | else: 6 | # for python 3.10 and below, which doesnt have stdlib tomllib 7 | import tomli as tomllib 8 | 9 | from packaging.version import parse as parse_version 10 | import re 11 | 12 | MIN_VERSION_LIBS = ["langchain-core"] 13 | 14 | SKIP_IF_PULL_REQUEST = ["langchain-core"] 15 | 16 | 17 | def get_min_version(version: str) -> str: 18 | # base regex for x.x.x with cases for rc/post/etc 19 | # valid strings: https://peps.python.org/pep-0440/#public-version-identifiers 20 | vstring = r"\d+(?:\.\d+){0,2}(?:(?:a|b|rc|\.post|\.dev)\d+)?" 
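    # For reference: this pattern accepts PEP 440-style public versions such as "1", "0.2.1",
    # "1.2.3rc1", or "2.0.0.post1", so the constraint cases handled below (e.g. "^0.2.1" or
    # ">=0.2.1,<0.3") each resolve to the lower bound, "0.2.1" in both of these examples.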
21 | # case ^x.x.x 22 | _match = re.match(f"^\\^({vstring})$", version) 23 | if _match: 24 | return _match.group(1) 25 | 26 | # case >=x.x.x,<x.x.x 27 | _match = re.match(f"^>=({vstring}),<({vstring})$", version) 28 | if _match: 29 | _min = _match.group(1) 30 | _max = _match.group(2) 31 | assert parse_version(_min) < parse_version(_max) 32 | return _min 33 | 34 | # case x.x.x 35 | _match = re.match(f"^({vstring})$", version) 36 | if _match: 37 | return _match.group(1) 38 | 39 | raise ValueError(f"Unrecognized version format: {version}") 40 | 41 | 42 | def get_min_version_from_toml(toml_path: str, versions_for: str): 43 | # Parse the TOML file 44 | with open(toml_path, "rb") as file: 45 | toml_data = tomllib.load(file) 46 | 47 | # Get the dependencies from tool.poetry.dependencies 48 | dependencies = toml_data["tool"]["poetry"]["dependencies"] 49 | 50 | # Initialize a dictionary to store the minimum versions 51 | min_versions = {} 52 | 53 | # Iterate over the libs in MIN_VERSION_LIBS 54 | for lib in MIN_VERSION_LIBS: 55 | if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST: 56 | # some libs only get checked on release because of simultaneous 57 | # changes 58 | continue 59 | # Check if the lib is present in the dependencies 60 | if lib in dependencies: 61 | # Get the version string 62 | version_string = dependencies[lib] 63 | 64 | if isinstance(version_string, dict): 65 | version_string = version_string["version"] 66 | 67 | # Use parse_version to get the minimum supported version from version_string 68 | min_version = get_min_version(version_string) 69 | 70 | # Store the minimum version in the min_versions dictionary 71 | min_versions[lib] = min_version 72 | 73 | return min_versions 74 | 75 | 76 | if __name__ == "__main__": 77 | # Get the TOML file path from the command line argument 78 | toml_file = sys.argv[1] 79 | versions_for = sys.argv[2] 80 | assert versions_for in ["release", "pull_request"] 81 | 82 | # Call the function to get the minimum versions 83 | min_versions = get_min_version_from_toml(toml_file, versions_for) 84 | 85 | print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()])) 86 | -------------------------------------------------------------------------------- /.github/workflows/_codespell.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: make spell_check 3 | 4 | on: 5 | workflow_call: 6 | inputs: 7 | working-directory: 8 | required: true 9 | type: string 10 | description: "From which folder this pipeline executes" 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | codespell: 17 | name: (Check for spelling errors) 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Dependencies 25 | run: | 26 | pip install toml 27 | 28 | - name: Extract Ignore Words List 29 | working-directory: ${{ inputs.working-directory }} 30 | run: | 31 | # Use a Python script to extract the ignore words list from pyproject.toml 32 | python ../../.github/workflows/extract_ignored_words_list.py 33 | id: extract_ignore_words 34 | 35 | - name: Codespell 36 | uses: codespell-project/actions-codespell@v2 37 | with: 38 | skip: guide_imports.json 39 | ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }} 40 | -------------------------------------------------------------------------------- /.github/workflows/_compile_integration_test.yml: -------------------------------------------------------------------------------- 1 | name:
compile-integration-test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.10" 25 | - "3.11" 26 | - "3.12" 27 | name: "poetry run pytest -m compile tests/integration_tests #${{ matrix.python-version }}" 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 32 | uses: "./.github/actions/poetry_setup" 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | poetry-version: ${{ env.POETRY_VERSION }} 36 | working-directory: ${{ inputs.working-directory }} 37 | cache-key: compile-integration 38 | 39 | - name: Install integration dependencies 40 | shell: bash 41 | run: poetry install --with=test_integration,test 42 | 43 | - name: Check integration tests compile 44 | shell: bash 45 | run: poetry run pytest -m compile tests/integration_tests 46 | 47 | - name: Ensure the tests did not create any additional files 48 | shell: bash 49 | run: | 50 | set -eu 51 | 52 | STATUS="$(git status)" 53 | echo "$STATUS" 54 | 55 | # grep will exit non-zero if the target message isn't found, 56 | # and `set -e` above will cause the step to fail. 57 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 58 | -------------------------------------------------------------------------------- /.github/workflows/_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: integration-test 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: choice 9 | options: 10 | - libs/aws 11 | - libs/langgraph-checkpoint-aws 12 | fork: 13 | required: true 14 | type: string 15 | default: 'langchain-ai' 16 | description: "Which fork to run this test against" 17 | branch: 18 | required: true 19 | type: string 20 | default: 'main' 21 | description: "Which branch to run this test against" 22 | test-file: 23 | required: true 24 | type: string 25 | default: "tests/integration_tests/**/test*.py" 26 | description: "Which test file to run" 27 | 28 | env: 29 | PYTHON_VERSION: "3.11" 30 | POETRY_VERSION: "1.7.1" 31 | 32 | jobs: 33 | build: 34 | defaults: 35 | run: 36 | working-directory: ${{ inputs.working-directory }} 37 | runs-on: ubuntu-latest 38 | name: "make integration_test" 39 | steps: 40 | - uses: actions/checkout@v4 41 | with: 42 | repository: "${{ inputs.fork }}/langchain-aws" 43 | ref: "${{ inputs.branch }}" 44 | 45 | - name: Set up Python ${{ env.PYTHON_VERSION }} + Poetry ${{ env.POETRY_VERSION }} 46 | uses: "./.github/actions/poetry_setup" 47 | with: 48 | python-version: ${{ env.PYTHON_VERSION }} 49 | poetry-version: ${{ env.POETRY_VERSION }} 50 | working-directory: ${{ inputs.working-directory }} 51 | cache-key: core 52 | 53 | - name: Install dependencies 54 | shell: bash 55 | run: poetry install --with test 56 | 57 | - name: Configure AWS Credentials 58 | uses: aws-actions/configure-aws-credentials@v4 59 | with: 60 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 61 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 62 | aws-region: ${{ secrets.AWS_REGION }} 63 | 64 | - name: Run integration tests 65 | shell: bash 66 | run: | 67 | poetry run pytest -vv -s 
${{ inputs.test-file }} 68 | 69 | - name: Ensure the tests did not create any additional files 70 | shell: bash 71 | run: | 72 | set -eu 73 | 74 | STATUS="$(git status)" 75 | echo "$STATUS" 76 | 77 | # grep will exit non-zero if the target message isn't found, 78 | # and `set -e` above will cause the step to fail. 79 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 80 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 14 | 15 | # This env var allows us to get inline annotations when ruff has complaints. 16 | RUFF_OUTPUT_FORMAT: github 17 | 18 | jobs: 19 | build: 20 | name: "make lint #${{ matrix.python-version }}" 21 | runs-on: ubuntu-latest 22 | strategy: 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 31 | python-version: 32 | - "3.9" 33 | - "3.12" 34 | steps: 35 | - uses: actions/checkout@v4 36 | 37 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 38 | uses: "./.github/actions/poetry_setup" 39 | with: 40 | python-version: ${{ matrix.python-version }} 41 | poetry-version: ${{ env.POETRY_VERSION }} 42 | working-directory: ${{ inputs.working-directory }} 43 | cache-key: lint-with-extras 44 | 45 | - name: Check Poetry File 46 | shell: bash 47 | working-directory: ${{ inputs.working-directory }} 48 | run: | 49 | poetry check 50 | 51 | - name: Check lock file 52 | shell: bash 53 | working-directory: ${{ inputs.working-directory }} 54 | run: | 55 | poetry lock --check 56 | 57 | - name: Install dependencies 58 | # Also installs dev/lint/test/typing dependencies, to ensure we have 59 | # type hints for as many of our libraries as possible. 60 | # This helps catch errors that require dependencies to be spotted, for example: 61 | # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 62 | # 63 | # If you change this configuration, make sure to change the `cache-key` 64 | # in the `poetry_setup` action above to stop using the old cache. 65 | # It doesn't matter how you change it, any change will cause a cache-bust. 
66 | working-directory: ${{ inputs.working-directory }} 67 | run: | 68 | poetry install --with lint,typing 69 | 70 | - name: Get .mypy_cache to speed up mypy 71 | uses: actions/cache@v4 72 | env: 73 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 74 | with: 75 | path: | 76 | ${{ env.WORKDIR }}/.mypy_cache 77 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 78 | 79 | 80 | - name: Analysing the code with our lint 81 | working-directory: ${{ inputs.working-directory }} 82 | run: | 83 | make lint_package 84 | 85 | - name: Install unit+integration test dependencies 86 | working-directory: ${{ inputs.working-directory }} 87 | run: | 88 | poetry install --with test,test_integration 89 | 90 | - name: Get .mypy_cache_test to speed up mypy 91 | uses: actions/cache@v4 92 | env: 93 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 94 | with: 95 | path: | 96 | ${{ env.WORKDIR }}/.mypy_cache_test 97 | key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }} 98 | 99 | - name: Analysing the code with our lint 100 | working-directory: ${{ inputs.working-directory }} 101 | run: | 102 | make lint_tests 103 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.10" 25 | - "3.11" 26 | - "3.12" 27 | name: "make test #${{ matrix.python-version }}" 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 32 | uses: "./.github/actions/poetry_setup" 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | poetry-version: ${{ env.POETRY_VERSION }} 36 | working-directory: ${{ inputs.working-directory }} 37 | cache-key: core 38 | 39 | - name: Install dependencies 40 | shell: bash 41 | run: poetry install --with test --with test_integration 42 | 43 | - name: Run core tests 44 | shell: bash 45 | run: | 46 | make test 47 | 48 | - name: Ensure the tests did not create any additional files 49 | shell: bash 50 | run: | 51 | set -eu 52 | 53 | STATUS="$(git status)" 54 | echo "$STATUS" 55 | 56 | # grep will exit non-zero if the target message isn't found, 57 | # and `set -e` above will cause the step to fail. 
58 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 59 | -------------------------------------------------------------------------------- /.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | dangerous-nonmaster-release: 11 | required: false 12 | type: boolean 13 | default: false 14 | description: "Release from a non-master branch (danger!)" 15 | 16 | env: 17 | POETRY_VERSION: "1.7.1" 18 | PYTHON_VERSION: "3.10" 19 | 20 | jobs: 21 | build: 22 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmaster-release 23 | runs-on: ubuntu-latest 24 | 25 | outputs: 26 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 27 | version: ${{ steps.check-version.outputs.version }} 28 | 29 | steps: 30 | - uses: actions/checkout@v4 31 | 32 | - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} 33 | uses: "./.github/actions/poetry_setup" 34 | with: 35 | python-version: ${{ env.PYTHON_VERSION }} 36 | poetry-version: ${{ env.POETRY_VERSION }} 37 | working-directory: ${{ inputs.working-directory }} 38 | cache-key: release 39 | 40 | # We want to keep this build stage *separate* from the release stage, 41 | # so that there's no sharing of permissions between them. 42 | # The release stage has trusted publishing and GitHub repo contents write access, 43 | # and we want to keep the scope of that access limited just to the release job. 44 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 45 | # could get access to our GitHub or PyPI credentials. 46 | # 47 | # Per the trusted publishing GitHub Action: 48 | # > It is strongly advised to separate jobs for building [...] 49 | # > from the publish job. 50 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 51 | - name: Build project for distribution 52 | run: poetry build 53 | working-directory: ${{ inputs.working-directory }} 54 | 55 | - name: Upload build 56 | uses: actions/upload-artifact@v4 57 | with: 58 | name: test-dist 59 | path: ${{ inputs.working-directory }}/dist/ 60 | 61 | - name: Check Version 62 | id: check-version 63 | shell: bash 64 | working-directory: ${{ inputs.working-directory }} 65 | run: | 66 | echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT 67 | echo version="$(poetry version --short)" >> $GITHUB_OUTPUT 68 | 69 | publish: 70 | needs: 71 | - build 72 | runs-on: ubuntu-latest 73 | permissions: 74 | # This permission is used for trusted publishing: 75 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 76 | # 77 | # Trusted publishing has to also be configured on PyPI for each package: 78 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 79 | id-token: write 80 | 81 | steps: 82 | - uses: actions/checkout@v4 83 | 84 | - uses: actions/download-artifact@v4 85 | with: 86 | name: test-dist 87 | path: ${{ inputs.working-directory }}/dist/ 88 | 89 | - name: Publish to test PyPI 90 | uses: pypa/gh-action-pypi-publish@release/v1 91 | with: 92 | packages-dir: ${{ inputs.working-directory }}/dist/ 93 | verbose: true 94 | print-hash: true 95 | repository-url: https://test.pypi.org/legacy/ 96 | 97 | # We overwrite any existing distributions with the same name and version. 98 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
99 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 100 | skip-existing: true 101 | attestations: false 102 | -------------------------------------------------------------------------------- /.github/workflows/check_diffs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | # If another push to the same PR or branch happens while this workflow is still running, 10 | # cancel the earlier run in favor of the next run. 11 | # 12 | # There's no point in testing an outdated version of the code. GitHub only allows 13 | # a limited number of job runners to be active at the same time, so it's better to cancel 14 | # pointless jobs early so that more useful jobs can run sooner. 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | env: 20 | POETRY_VERSION: "1.7.1" 21 | 22 | jobs: 23 | build: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.10' 30 | - id: files 31 | uses: Ana06/get-changed-files@v2.3.0 32 | - id: set-matrix 33 | run: | 34 | python .github/scripts/check_diff.py ${{ steps.files.outputs.all }} >> $GITHUB_OUTPUT 35 | outputs: 36 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }} 37 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }} 38 | lint: 39 | name: cd ${{ matrix.working-directory }} 40 | needs: [ build ] 41 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }} 42 | strategy: 43 | matrix: 44 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }} 45 | uses: ./.github/workflows/_lint.yml 46 | with: 47 | working-directory: ${{ matrix.working-directory }} 48 | secrets: inherit 49 | 50 | test: 51 | name: cd ${{ matrix.working-directory }} 52 | needs: [ build ] 53 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 54 | strategy: 55 | matrix: 56 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 57 | uses: ./.github/workflows/_test.yml 58 | with: 59 | working-directory: ${{ matrix.working-directory }} 60 | secrets: inherit 61 | 62 | compile-integration-tests: 63 | name: cd ${{ matrix.working-directory }} 64 | needs: [ build ] 65 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 66 | strategy: 67 | matrix: 68 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 69 | uses: ./.github/workflows/_compile_integration_test.yml 70 | with: 71 | working-directory: ${{ matrix.working-directory }} 72 | secrets: inherit 73 | ci_success: 74 | name: "CI Success" 75 | needs: [build, lint, test, compile-integration-tests] 76 | if: | 77 | always() 78 | runs-on: ubuntu-latest 79 | env: 80 | JOBS_JSON: ${{ toJSON(needs) }} 81 | RESULTS_JSON: ${{ toJSON(needs.*.result) }} 82 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}} 83 | steps: 84 | - name: "CI Success" 85 | run: | 86 | echo $JOBS_JSON 87 | echo $RESULTS_JSON 88 | echo "Exiting with $EXIT_CODE" 89 | exit $EXIT_CODE 90 | -------------------------------------------------------------------------------- /.github/workflows/extract_ignored_words_list.py: -------------------------------------------------------------------------------- 1 | import toml 2 | 3 | pyproject_toml = toml.load("pyproject.toml") 4 | 5 | # Extract the ignore words list (adjust the key as per your TOML structure) 6 | 
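# Example of the table this script reads (matching the [tool.codespell] convention shown in CONTRIBUTING.md):
#   [tool.codespell]
#   ignore-words-list = 'momento,collison,ned,foor'
# The value printed below is consumed by _codespell.yml as steps.extract_ignore_words.outputs.ignore_words_list.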
ignore_words_list = ( 7 | pyproject_toml.get("tool", {}).get("codespell", {}).get("ignore-words-list") 8 | ) 9 | 10 | print(f"::set-output name=ignore_words_list::{ignore_words_list}") 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vs/ 2 | .vscode/ 3 | .idea/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # Google GitHub Actions credentials files created by: 35 | # https://github.com/google-github-actions/auth 36 | # 37 | # That action recommends adding this gitignore to prevent accidentally committing keys. 38 | gha-creds-*.json 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | .mypy_cache_test/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | docs/docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | notebooks/ 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .envrc 120 | .venv* 121 | venv* 122 | env/ 123 | ENV/ 124 | env.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # macOS display setting files 145 | .DS_Store 146 | 147 | # Wandb directory 148 | wandb/ 149 | 150 | # asdf tool versions 151 | .tool-versions 152 | /.ruff_cache/ 153 | 154 | *.pkl 155 | *.bin 156 | 157 | # integration test artifacts 158 | data_map* 159 | \[('_type', 'fake'), ('stop', None)] 160 | 161 | # Replit files 162 | *replit* 163 | 164 | node_modules 165 | docs/.yarn/ 166 | docs/node_modules/ 167 | docs/.docusaurus/ 168 | docs/.cache-loader/ 169 | docs/_dist 170 | docs/api_reference/*api_reference.rst 171 | docs/api_reference/_build 172 | docs/api_reference/*/ 173 | !docs/api_reference/_static/ 174 | !docs/api_reference/templates/ 175 | !docs/api_reference/themes/ 176 | docs/docs/build 177 | docs/docs/node_modules 178 | docs/docs/yarn.lock 179 | _dist 180 | docs/docs/templates 181 | 182 | prof 183 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦜️🔗 LangChain 🤝 Amazon Web Services (AWS) 2 | 3 | This monorepo provides LangChain and LangGraph components for various AWS services. It aims to replace and expand upon the existing LangChain AWS components found in the `langchain-community` package in the LangChain repository. 
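If you are migrating from `langchain-community`, the change is typically just the import path. Here is a minimal sketch, assuming you previously used the community `BedrockChat`/`Bedrock` classes (see the Usage section below for full examples):

```python
# Before: AWS integrations imported from the community package
# from langchain_community.chat_models import BedrockChat
# from langchain_community.llms import Bedrock

# After: the equivalent classes ship in this repository's package
from langchain_aws import ChatBedrock, BedrockLLM
```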
4 | 5 | The following packages are hosted in this repository: 6 | - `langchain-aws` ([PyPi](https://pypi.org/project/langchain-aws/)) 7 | - `langgraph-checkpoint-aws` ([PyPi](https://pypi.org/project/langgraph-checkpoint-aws/)) 8 | 9 | ## Features 10 | 11 | ### LangChain 12 | - **LLMs**: Includes LLM classes for AWS services like [Bedrock](https://aws.amazon.com/bedrock) and [SageMaker Endpoints](https://aws.amazon.com/sagemaker/deploy/), allowing you to leverage their language models within LangChain. 13 | - **Retrievers**: Supports retrievers for services like [Amazon Kendra](https://aws.amazon.com/kendra/) and [KnowledgeBases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/), enabling efficient retrieval of relevant information in your RAG applications. 14 | - **Graphs**: Provides components for working with [AWS Neptune](https://aws.amazon.com/neptune/) graphs within LangChain. 15 | - **Agents**: Includes Runnables to support [Amazon Bedrock Agents](https://aws.amazon.com/bedrock/agents/), allowing you to leverage Bedrock Agents within LangChain and LangGraph. 16 | 17 | ### LangGraph 18 | - **Checkpointers**: Provides a custom checkpointing solution for LangGraph agents using the [AWS Bedrock Session Management Service](https://docs.aws.amazon.com/bedrock/latest/userguide/sessions.html). 19 | 20 | ...and more to come. This repository will continue to expand and offer additional components for various AWS services as development progresses. 21 | 22 | **Note**: This repository will replace all AWS integrations currently present in the `langchain-community` package. Users are encouraged to migrate to this repository as soon as possible. 23 | 24 | ## Installation 25 | 26 | You can install the `langchain-aws` package from PyPI. 27 | 28 | ```bash 29 | pip install langchain-aws 30 | ``` 31 | 32 | The `langgraph-checkpoint-aws` package can also be installed from PyPI. 33 | 34 | ```bash 35 | pip install langgraph-checkpoint-aws 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### LangChain 41 | 42 | Here's a simple example of how to use the `langchain-aws` package. 43 | 44 | ```python 45 | from langchain_aws import ChatBedrock 46 | 47 | # Initialize the Bedrock chat model 48 | llm = ChatBedrock( 49 | model="anthropic.claude-3-sonnet-20240229-v1:0", 50 | beta_use_converse_api=True 51 | ) 52 | 53 | # Invoke the llm 54 | response = llm.invoke("Hello! How are you today?") 55 | print(response) 56 | ``` 57 | 58 | For more detailed usage examples and documentation, please refer to the [LangChain docs](https://python.langchain.com/docs/integrations/platforms/aws/). 59 | 60 | ### LangGraph 61 | 62 | You can find usage examples for `langgraph-checkpoint-aws` [here](https://github.com/michaelnchin/langchain-aws/blob/main/libs/langgraph-checkpoint-aws/README.md#usage). 63 | 64 | ## Contributing 65 | 66 | We welcome contributions to this repository! To get started, please follow the contribution guide for your specific project of interest: 67 | 68 | - For `langchain-aws`, see [here](https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/CONTRIBUTING.md). 69 | - For `langgraph-checkpoint-aws`, see [here](https://github.com/langchain-ai/langchain-aws/blob/main/libs/langgraph-checkpoint-aws/CONTRIBUTING.md). 70 | 71 | Each guide provides detailed instructions on how to set up the project for development and guidance on how to contribute effectively. 72 | 73 | ## License 74 | 75 | This project is licensed under the [MIT License](LICENSE).
76 | -------------------------------------------------------------------------------- /libs/aws/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | conduct@langchain.dev. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. 
Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations -------------------------------------------------------------------------------- /libs/aws/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribute Code 2 | 3 | To contribute to this project, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow. Please do not try to push directly to this repo. 4 | 5 | Note related issues and tag relevant maintainers in pull requests. 6 | 7 | Pull requests cannot land without passing the formatting, linting, and testing checks first. See [Testing](#testing) and 8 | [Formatting and Linting](#formatting-and-linting) for how to run these checks locally. 
9 | 10 | It's essential that we maintain great documentation and testing. Add or update relevant unit or integration test when possible. 11 | These live in `tests/unit_tests` and `tests/integration_tests`. Example notebooks and documentation lives in `/docs` inside the 12 | LangChain repo [here](https://github.com/langchain-ai/langchain/tree/master/docs). 13 | 14 | We are a small, progress-oriented team. If there's something you'd like to add or change, opening a pull request is the 15 | best way to get our attention. 16 | 17 | ## 🚀 Quick Start 18 | 19 | This quick start guide explains how to setup the repository locally for development. 20 | 21 | ### Dependency Management: Poetry and other env/dependency managers 22 | 23 | This project utilizes [Poetry](https://python-poetry.org/) v1.7.1+ as a dependency manager. 24 | 25 | ❗Note: *Before installing Poetry*, if you use `Conda`, create and activate a new Conda env (e.g. `conda create -n langchain python=3.9`) 26 | 27 | Install Poetry: **[documentation on how to install it](https://python-poetry.org/docs/#installation)**. 28 | 29 | ❗Note: If you use `Conda` or `Pyenv` as your environment/package manager, after installing Poetry, 30 | tell Poetry to use the virtualenv python environment (`poetry config virtualenvs.prefer-active-python true`) 31 | 32 | The instructions here assume that you run all commands from the `libs/aws` directory. 33 | 34 | ```bash 35 | cd libs/aws 36 | ``` 37 | 38 | ### Install for development 39 | 40 | ```bash 41 | poetry install --with lint,typing,test,test_integration,dev 42 | ``` 43 | 44 | Then verify the installation. 45 | 46 | ```bash 47 | make test 48 | ``` 49 | 50 | If during installation you receive a `WheelFileValidationError` for `debugpy`, please make sure you are running 51 | Poetry v1.6.1+. This bug was present in older versions of Poetry (e.g. 1.4.1) and has been resolved in newer releases. 52 | If you are still seeing this bug on v1.6.1+, you may also try disabling "modern installation" 53 | (`poetry config installer.modern-installation false`) and re-installing requirements. 54 | See [this `debugpy` issue](https://github.com/microsoft/debugpy/issues/1246) for more details. 55 | 56 | ### Testing 57 | 58 | Unit tests cover modular logic that does not require calls to outside APIs. 59 | If you add new logic, please add a unit test. 60 | 61 | To run unit tests: 62 | 63 | ```bash 64 | make test 65 | ``` 66 | 67 | Integration tests cover the end-to-end service calls as much as possible. 68 | However, in certain cases this might not be practical, so you can mock the 69 | service response for these tests. There are examples of this in the repo, 70 | that can help you write your own tests. If you have suggestions to improve 71 | this, please get in touch with us. 72 | 73 | To run the integration tests: 74 | 75 | ```bash 76 | make integration_test 77 | ``` 78 | 79 | ### Formatting and Linting 80 | 81 | Formatting ensures that the code in this repo has consistent style so that the 82 | code looks more presentable and readable. It corrects these errors when you run 83 | the formatting command. Linting finds and highlights the code errors and helps 84 | avoid coding practices that can lead to errors. 85 | 86 | Run both of these locally before submitting a PR. The CI scripts will run these 87 | when you submit a PR, and you won't be able to merge changes without fixing 88 | issues identified by the CI. 
89 | 90 | #### Code Formatting 91 | 92 | Formatting for this project is done via [ruff](https://docs.astral.sh/ruff/rules/). 93 | 94 | To run format: 95 | 96 | ```bash 97 | make format 98 | ``` 99 | 100 | Additionally, you can run the formatter only on the files that have been modified in your current branch 101 | as compared to the master branch using the `format_diff` command. This is especially useful when you have 102 | made changes to a subset of the project and want to ensure your changes are properly formatted without 103 | affecting the rest of the codebase. 104 | 105 | ```bash 106 | make format_diff 107 | ``` 108 | 109 | #### Linting 110 | 111 | Linting for this project is done via a combination of [ruff](https://docs.astral.sh/ruff/rules/) and [mypy](http://mypy-lang.org/). 112 | 113 | To run lint: 114 | 115 | ```bash 116 | make lint 117 | ``` 118 | 119 | In addition, you can run the linter only on the files that have been modified in your current branch as compared to the master branch using the `lint_diff` command. This can be very helpful when you've made changes to only certain parts of the project and want to ensure your changes meet the linting standards without having to check the entire codebase. 120 | 121 | ```bash 122 | make lint_diff 123 | ``` 124 | 125 | We recognize linting can be annoying - if you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed. 126 | 127 | #### Spellcheck 128 | 129 | Spellchecking for this project is done via [codespell](https://github.com/codespell-project/codespell). 130 | Note that `codespell` finds common typos, so it could have false-positive (correctly spelled but rarely used) and false-negatives (not finding misspelled) words. 131 | 132 | To check spelling for this project: 133 | 134 | ```bash 135 | make spell_check 136 | ``` 137 | 138 | To fix spelling in place: 139 | 140 | ```bash 141 | make spell_fix 142 | ``` 143 | 144 | If codespell is incorrectly flagging a word, you can skip spellcheck for that word by adding it to the codespell config in the `pyproject.toml` file. 145 | 146 | ```python 147 | [tool.codespell] 148 | ... 149 | # Add here: 150 | ignore-words-list = 'momento,collison,ned,foor,reworked,path,whats,apply,misogyny,unsecure' 151 | ``` 152 | -------------------------------------------------------------------------------- /libs/aws/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/aws/Makefile: -------------------------------------------------------------------------------- 1 | ###################### 2 | # NON FILE TARGETS 3 | ###################### 4 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 5 | 6 | ###################### 7 | # ALL TARGETS 8 | ###################### 9 | 10 | all: help ## Default target = help 11 | 12 | ###################### 13 | # TEST CASES 14 | ###################### 15 | 16 | # Define a variable for the test file path. 17 | test test_watch tests: TEST_FILE ?= tests/unit_tests/ 18 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 19 | 20 | # Define a variable for Python and notebook files. 21 | PYTHON_FILES=. 22 | 23 | tests: ## Run all unit tests 24 | poetry run pytest $(TEST_FILE) 25 | 26 | test: ## Run individual unit test: make test TEST_FILE=tests/unit_test/test.py 27 | poetry run pytest $(TEST_FILE) 28 | 29 | integration_tests: ## Run all integration tests 30 | poetry run pytest $(TEST_FILE) 31 | 32 | integration_test: ## Run individual integration test: make integration_test TEST_FILE=tests/integration_tests/integ_test.py 33 | poetry run pytest $(TEST_FILE) 34 | 35 | test_watch: ## Run and interactively watch unit tests 36 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 37 | 38 | ###################### 39 | # LINTING AND FORMATTING 40 | ###################### 41 | 42 | # Define a variable for Python and notebook files. 43 | PYTHON_FILES=. 44 | MYPY_CACHE=.mypy_cache 45 | lint format: PYTHON_FILES=. 
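# The assignments above and below are target-specific variables: `make lint_diff` and `make format_diff`
# only operate on .py/.ipynb files changed relative to main, while `make lint_package` and `make lint_tests`
# restrict checks to langchain_aws/ and tests/ respectively.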
46 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/aws --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 47 | lint_package: PYTHON_FILES=langchain_aws 48 | lint_tests: PYTHON_FILES=tests 49 | lint_tests: MYPY_CACHE=.mypy_cache_test 50 | 51 | lint: ## Run linter 52 | poetry run ruff check 53 | 54 | lint_diff: ## Run linter 55 | poetry run ruff format $(PYTHON_FILES) --diff 56 | 57 | lint_package: ## Run linter on package 58 | poetry run ruff check --select I $(PYTHON_FILES) 59 | 60 | lint_tests: ## Run linter tests 61 | mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 62 | 63 | format: ## Run code formatter 64 | poetry run ruff format $(PYTHON_FILES) 65 | 66 | format_diff: ## Run code formatter and show differences 67 | poetry run ruff check --select I --fix $(PYTHON_FILES) 68 | 69 | spell_check: ## Run code spell check 70 | poetry run codespell --toml pyproject.toml 71 | 72 | spell_fix: ## Run code spell fix 73 | poetry run codespell --toml pyproject.toml -w 74 | 75 | ###################### 76 | # DEPENDENCIES 77 | ###################### 78 | 79 | install_dev: ## Install development environment 80 | @pip install --no-cache -U poetry 81 | @poetry install --with dev,test,codespell,lint,typing 82 | 83 | check_imports: $(shell find langchain_aws -name '*.py') ## Check missing imports 84 | @poetry run python ./scripts/check_imports.py $^ 85 | 86 | ###################### 87 | # HELP 88 | ###################### 89 | 90 | help: ## Print this help 91 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 92 | -------------------------------------------------------------------------------- /libs/aws/README.md: -------------------------------------------------------------------------------- 1 | # langchain-aws 2 | 3 | This package contains the LangChain integrations with AWS. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install -U langchain-aws 9 | ``` 10 | All integrations in this package assume that you have the credentials setup to connect with AWS services. 11 | 12 | ## Chat Models 13 | 14 | `ChatBedrock` class exposes chat models from Bedrock. 15 | 16 | ```python 17 | from langchain_aws import ChatBedrock 18 | 19 | llm = ChatBedrock() 20 | llm.invoke("Sing a ballad of LangChain.") 21 | ``` 22 | 23 | ## Embeddings 24 | 25 | `BedrockEmbeddings` class exposes embeddings from Bedrock. 26 | 27 | ```python 28 | from langchain_aws import BedrockEmbeddings 29 | 30 | embeddings = BedrockEmbeddings() 31 | embeddings.embed_query("What is the meaning of life?") 32 | ``` 33 | 34 | ## LLMs 35 | `BedrockLLM` class exposes LLMs from Bedrock. 36 | 37 | ```python 38 | from langchain_aws import BedrockLLM 39 | 40 | llm = BedrockLLM() 41 | llm.invoke("The meaning of life is") 42 | ``` 43 | 44 | ## Retrievers 45 | `AmazonKendraRetriever` class provides a retriever to connect with Amazon Kendra. 46 | 47 | ```python 48 | from langchain_aws import AmazonKendraRetriever 49 | 50 | retriever = AmazonKendraRetriever( 51 | index_id="561be2b6d-9804c7e7-f6a0fbb8-5ccd350" 52 | ) 53 | 54 | retriever.get_relevant_documents(query="What is the meaning of life?") 55 | ``` 56 | 57 | `AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases. 
58 | 59 | ```python 60 | from langchain_aws import AmazonKnowledgeBasesRetriever 61 | 62 | retriever = AmazonKnowledgeBasesRetriever( 63 | knowledge_base_id="IAPJ4QPUEU", 64 | retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, 65 | ) 66 | 67 | retriever.get_relevant_documents(query="What is the meaning of life?") 68 | ``` 69 | ## VectorStores 70 | `InMemoryVectorStore` class provides a vectorstore to connect with Amazon MemoryDB. 71 | 72 | ```python 73 | from langchain_aws.vectorstores.inmemorydb import InMemoryVectorStore 74 | 75 | vds = InMemoryVectorStore.from_documents( 76 | chunks, 77 | embeddings, 78 | redis_url="rediss://cluster_endpoint:6379/ssl=True ssl_cert_reqs=none", 79 | vector_schema=vector_schema, 80 | index_name=INDEX_NAME, 81 | ) 82 | ``` 83 | 84 | ## MemoryDB as Retriever 85 | 86 | Here we go over different options for using the vector store as a retriever. 87 | 88 | There are three different search methods we can use to do retrieval. By default, it will use semantic similarity. 89 | 90 | ```python 91 | retriever=vds.as_retriever() 92 | ``` 93 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.chains import ( 2 | create_neptune_opencypher_qa_chain, 3 | create_neptune_sparql_qa_chain, 4 | ) 5 | from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse 6 | from langchain_aws.document_compressors.rerank import BedrockRerank 7 | from langchain_aws.embeddings import BedrockEmbeddings 8 | from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph 9 | from langchain_aws.llms import BedrockLLM, SagemakerEndpoint 10 | from langchain_aws.retrievers import ( 11 | AmazonKendraRetriever, 12 | AmazonKnowledgeBasesRetriever, 13 | ) 14 | from langchain_aws.vectorstores.inmemorydb import ( 15 | InMemorySemanticCache, 16 | InMemoryVectorStore, 17 | ) 18 | 19 | 20 | def setup_logging(): 21 | import logging 22 | import os 23 | 24 | if os.environ.get("LANGCHAIN_AWS_DEBUG", "FALSE").lower() in ["true", "1"]: 25 | DEFAULT_LOG_FORMAT = ( 26 | "%(asctime)s %(levelname)s | [%(filename)s:%(lineno)s]" 27 | "| %(name)s - %(message)s" 28 | ) 29 | log_formatter = logging.Formatter(DEFAULT_LOG_FORMAT) 30 | log_handler = logging.StreamHandler() 31 | log_handler.setFormatter(log_formatter) 32 | logging.getLogger("langchain_aws").addHandler(log_handler) 33 | logging.getLogger("langchain_aws").setLevel(logging.DEBUG) 34 | 35 | 36 | setup_logging() 37 | 38 | __all__ = [ 39 | "BedrockEmbeddings", 40 | "BedrockLLM", 41 | "ChatBedrock", 42 | "ChatBedrockConverse", 43 | "SagemakerEndpoint", 44 | "AmazonKendraRetriever", 45 | "AmazonKnowledgeBasesRetriever", 46 | "create_neptune_opencypher_qa_chain", 47 | "create_neptune_sparql_qa_chain", 48 | "NeptuneAnalyticsGraph", 49 | "NeptuneGraph", 50 | "InMemoryVectorStore", 51 | "InMemorySemanticCache", 52 | "BedrockRerank", 53 | ] 54 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.agents.base import ( 2 | BedrockAgentsRunnable, 3 | BedrockInlineAgentsRunnable, 4 | ) 5 | from langchain_aws.agents.types import ( 6 | BedrockAgentAction, 7 | BedrockAgentFinish, 8 | ) 9 | 10 | __all__ = [ 11 | "BedrockAgentAction", 12 | "BedrockAgentFinish", 13 | "BedrockAgentsRunnable", 14 | 
"BedrockInlineAgentsRunnable", 15 | ] 16 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/agents/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, List, Optional, Union 4 | 5 | from langchain_core.agents import AgentAction, AgentFinish 6 | from langchain_core.tools import BaseTool 7 | from typing_extensions import TypedDict 8 | 9 | _DEFAULT_ACTION_GROUP_NAME = "DEFAULT_AG_" 10 | _TEST_AGENT_ALIAS_ID = "TSTALIASID" 11 | 12 | 13 | class BedrockAgentFinish(AgentFinish): 14 | """AgentFinish with session id information. 15 | 16 | Parameters: 17 | session_id: Session id 18 | trace_log: trace log as string when enable_trace flag is set 19 | """ 20 | 21 | session_id: str 22 | trace_log: Optional[str] 23 | 24 | @classmethod 25 | def is_lc_serializable(cls) -> bool: 26 | """Check if the class is serializable by LangChain. 27 | 28 | Returns: 29 | False 30 | """ 31 | return False 32 | 33 | 34 | class BedrockAgentAction(AgentAction): 35 | """AgentAction with session id information. 36 | 37 | Parameters: 38 | session_id: session id 39 | trace_log: trace log as string when enable_trace flag is set 40 | """ 41 | 42 | session_id: str 43 | trace_log: Optional[str] 44 | invocation_id: Optional[str] 45 | 46 | @classmethod 47 | def is_lc_serializable(cls) -> bool: 48 | """Check if the class is serializable by LangChain. 49 | 50 | Returns: 51 | False 52 | """ 53 | return False 54 | 55 | 56 | OutputType = Union[List[BedrockAgentAction], BedrockAgentFinish] 57 | 58 | 59 | class GuardrailConfiguration(TypedDict): 60 | guardrail_identifier: str 61 | guardrail_version: str 62 | 63 | 64 | class KnowledgebaseConfiguration(TypedDict, total=False): 65 | description: str 66 | knowledgeBaseId: str 67 | retrievalConfiguration: Dict 68 | 69 | 70 | class InlineAgentConfiguration(TypedDict, total=False): 71 | """Configurations for an Inline Agent.""" 72 | 73 | foundation_model: str 74 | instruction: str 75 | enable_trace: Optional[bool] 76 | tools: List[BaseTool] 77 | enable_human_input: Optional[bool] 78 | enable_code_interpreter: Optional[bool] 79 | customer_encryption_key_arn: Optional[str] 80 | idle_session_ttl_in_seconds: Optional[int] 81 | guardrail_configuration: Optional[GuardrailConfiguration] 82 | knowledge_bases: Optional[KnowledgebaseConfiguration] 83 | prompt_override_configuration: Optional[Dict] 84 | inline_session_state: Optional[Dict] 85 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chains/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.chains.graph_qa import ( 2 | create_neptune_opencypher_qa_chain, 3 | create_neptune_sparql_qa_chain, 4 | ) 5 | 6 | __all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"] 7 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chains/graph_qa/__init__.py: -------------------------------------------------------------------------------- 1 | from .neptune_cypher import create_neptune_opencypher_qa_chain 2 | from .neptune_sparql import create_neptune_sparql_qa_chain 3 | 4 | __all__ = ["create_neptune_opencypher_qa_chain", "create_neptune_sparql_qa_chain"] 5 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chains/graph_qa/neptune_cypher.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from typing import Optional, Union 5 | 6 | from langchain_core.language_models import BaseLanguageModel 7 | from langchain_core.prompts.base import BasePromptTemplate 8 | from langchain_core.runnables import Runnable, RunnablePassthrough 9 | 10 | from langchain_aws.graphs import BaseNeptuneGraph 11 | 12 | from .prompts import ( 13 | CYPHER_QA_PROMPT, 14 | NEPTUNE_OPENCYPHER_GENERATION_PROMPT, 15 | NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT, 16 | ) 17 | 18 | INTERMEDIATE_STEPS_KEY = "intermediate_steps" 19 | 20 | 21 | def trim_query(query: str) -> str: 22 | """Trim the query to only include Cypher keywords.""" 23 | keywords = ( 24 | "CALL", 25 | "CREATE", 26 | "DELETE", 27 | "DETACH", 28 | "LIMIT", 29 | "MATCH", 30 | "MERGE", 31 | "OPTIONAL", 32 | "ORDER", 33 | "REMOVE", 34 | "RETURN", 35 | "SET", 36 | "SKIP", 37 | "UNWIND", 38 | "WITH", 39 | "WHERE", 40 | "//", 41 | ) 42 | 43 | lines = query.split("\n") 44 | new_query = "" 45 | 46 | for line in lines: 47 | if line.strip().upper().startswith(keywords): 48 | new_query += line + "\n" 49 | 50 | return new_query 51 | 52 | 53 | def extract_cypher(text: str) -> str: 54 | """Extract Cypher code from text using Regex.""" 55 | # The pattern to find Cypher code enclosed in triple backticks 56 | pattern = r"```(.*?)```" 57 | 58 | # Find all matches in the input text 59 | matches = re.findall(pattern, text, re.DOTALL) 60 | 61 | return matches[0] if matches else text 62 | 63 | 64 | def use_simple_prompt(llm: BaseLanguageModel) -> bool: 65 | """Decides whether to use the simple prompt""" 66 | if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore 67 | return True 68 | 69 | # Bedrock anthropic 70 | if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore 71 | return True 72 | 73 | return False 74 | 75 | 76 | def get_prompt(llm: BaseLanguageModel) -> BasePromptTemplate: 77 | """Selects the final prompt""" 78 | if use_simple_prompt(llm): 79 | return NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT 80 | else: 81 | return NEPTUNE_OPENCYPHER_GENERATION_PROMPT 82 | 83 | 84 | def create_neptune_opencypher_qa_chain( 85 | llm: BaseLanguageModel, 86 | graph: BaseNeptuneGraph, 87 | qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, 88 | cypher_prompt: Optional[BasePromptTemplate] = None, 89 | return_intermediate_steps: bool = False, 90 | return_direct: bool = False, 91 | extra_instructions: Optional[str] = None, 92 | allow_dangerous_requests: bool = False, 93 | ) -> Runnable: 94 | """Chain for question-answering against a Neptune graph 95 | by generating openCypher statements. 96 | 97 | *Security note*: Make sure that the database connection uses credentials 98 | that are narrowly-scoped to only include necessary permissions. 99 | Failure to do so may result in data corruption or loss, since the calling 100 | code may attempt commands that would result in deletion, mutation 101 | of data if appropriately prompted or reading sensitive data if such 102 | data is present in the database. 103 | The best way to guard against such negative outcomes is to (as appropriate) 104 | limit the permissions granted to the credentials used with this tool. 105 | 106 | See https://python.langchain.com/docs/security for more information. 107 | 108 | Example: 109 | .. 
code-block:: python 110 | 111 | chain = create_neptune_opencypher_qa_chain( 112 | llm=llm, 113 | graph=graph 114 | ) 115 | response = chain.invoke({"query": "your_query_here"}) 116 | """ 117 | 118 | if allow_dangerous_requests is not True: 119 | raise ValueError( 120 | "In order to use this chain, you must acknowledge that it can make " 121 | "dangerous requests by setting `allow_dangerous_requests` to `True`. " 122 | "You must narrowly scope the permissions of the database connection " 123 | "to only include necessary permissions. Failure to do so may result " 124 | "in data corruption or loss or reading sensitive data if such data is " 125 | "present in the database. " 126 | "Only use this chain if you understand the risks and have taken the " 127 | "necessary precautions. " 128 | "See https://python.langchain.com/docs/security for more information." 129 | ) 130 | 131 | qa_chain = qa_prompt | llm 132 | 133 | _cypher_prompt = cypher_prompt or get_prompt(llm) 134 | cypher_generation_chain = _cypher_prompt | llm 135 | 136 | def normalize_input(raw_input: Union[str, dict]) -> dict: 137 | if isinstance(raw_input, str): 138 | return {"query": raw_input} 139 | return raw_input 140 | 141 | def execute_graph_query(cypher_query: str) -> dict: 142 | return graph.query(cypher_query) 143 | 144 | def get_cypher_inputs(inputs: dict) -> dict: 145 | return { 146 | "question": inputs["query"], 147 | "schema": graph.get_schema, 148 | "extra_instructions": extra_instructions or "", 149 | } 150 | 151 | def get_qa_inputs(inputs: dict) -> dict: 152 | return { 153 | "question": inputs["query"], 154 | "context": inputs["context"], 155 | } 156 | 157 | def format_response(inputs: dict) -> dict: 158 | intermediate_steps = [{"query": inputs["cypher"]}] 159 | 160 | if return_direct: 161 | final_response = {"result": inputs["context"]} 162 | else: 163 | final_response = {"result": inputs["qa_result"]} 164 | intermediate_steps.append({"context": inputs["context"]}) 165 | 166 | if return_intermediate_steps: 167 | final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps 168 | 169 | return final_response 170 | 171 | chain_result = ( 172 | normalize_input 173 | | RunnablePassthrough.assign(cypher_generation_inputs=get_cypher_inputs) 174 | | { 175 | "query": lambda x: x["query"], 176 | "cypher": (lambda x: x["cypher_generation_inputs"]) 177 | | cypher_generation_chain 178 | | (lambda x: extract_cypher(x.content)) 179 | | trim_query, 180 | } 181 | | RunnablePassthrough.assign(context=lambda x: execute_graph_query(x["cypher"])) 182 | | RunnablePassthrough.assign(qa_result=(lambda x: get_qa_inputs(x)) | qa_chain) 183 | | format_response 184 | ) 185 | 186 | return chain_result 187 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chains/graph_qa/neptune_sparql.py: -------------------------------------------------------------------------------- 1 | """ 2 | Question answering over an RDF or OWL graph using SPARQL. 
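Example (an illustrative sketch; the Neptune endpoint, port, and Bedrock model id below are placeholders, and the ``NeptuneRdfGraph`` host/port constructor arguments are an assumption about your setup):

.. code-block:: python

    from langchain_aws import ChatBedrockConverse, create_neptune_sparql_qa_chain
    from langchain_aws.graphs import NeptuneRdfGraph

    # Assumes Bedrock model access and a reachable Neptune cluster.
    graph = NeptuneRdfGraph(host="<your-neptune-endpoint>", port=8182)
    llm = ChatBedrockConverse(model="anthropic.claude-3-sonnet-20240229-v1:0")

    chain = create_neptune_sparql_qa_chain(
        llm=llm,
        graph=graph,
        allow_dangerous_requests=True,  # acknowledge the security note below
    )
    response = chain.invoke({"query": "How many classes are defined in the graph?"})
    # response["result"] holds the natural-language answer.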
3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Any, Optional, Union 8 | 9 | from langchain_core.language_models import BaseLanguageModel 10 | from langchain_core.prompts.base import BasePromptTemplate 11 | from langchain_core.prompts.prompt import PromptTemplate 12 | from langchain_core.runnables import Runnable, RunnablePassthrough 13 | 14 | from langchain_aws.graphs import NeptuneRdfGraph 15 | 16 | from .prompts import ( 17 | NEPTUNE_SPARQL_GENERATION_PROMPT, 18 | NEPTUNE_SPARQL_GENERATION_TEMPLATE, 19 | SPARQL_QA_PROMPT, 20 | ) 21 | 22 | INTERMEDIATE_STEPS_KEY = "intermediate_steps" 23 | 24 | 25 | def extract_sparql(query: str) -> str: 26 | """Extract SPARQL code from a text. 27 | 28 | Args: 29 | query: Text to extract SPARQL code from. 30 | 31 | Returns: 32 | SPARQL code extracted from the text. 33 | """ 34 | query = query.strip() 35 | querytoks = query.split("```") 36 | if len(querytoks) == 3: 37 | query = querytoks[1] 38 | 39 | if query.startswith("sparql"): 40 | query = query[6:] 41 | elif query.startswith("") and query.endswith(""): 42 | query = query[8:-9] 43 | return query 44 | 45 | 46 | def get_prompt(examples: str) -> BasePromptTemplate: 47 | """Selects the final prompt.""" 48 | template_to_use = NEPTUNE_SPARQL_GENERATION_TEMPLATE 49 | if examples: 50 | template_to_use = template_to_use.replace("Examples:", "Examples: " + examples) 51 | return PromptTemplate( 52 | input_variables=["schema", "prompt"], template=template_to_use 53 | ) 54 | return NEPTUNE_SPARQL_GENERATION_PROMPT 55 | 56 | 57 | def create_neptune_sparql_qa_chain( 58 | llm: BaseLanguageModel, 59 | graph: NeptuneRdfGraph, 60 | qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, 61 | sparql_prompt: Optional[BasePromptTemplate] = None, 62 | return_intermediate_steps: bool = False, 63 | return_direct: bool = False, 64 | extra_instructions: Optional[str] = None, 65 | allow_dangerous_requests: bool = False, 66 | examples: Optional[str] = None, 67 | ) -> Runnable[Any, dict]: 68 | """Chain for question-answering against a Neptune graph 69 | by generating SPARQL statements. 70 | 71 | *Security note*: Make sure that the database connection uses credentials 72 | that are narrowly-scoped to only include necessary permissions. 73 | Failure to do so may result in data corruption or loss, since the calling 74 | code may attempt commands that would result in deletion, mutation 75 | of data if appropriately prompted or reading sensitive data if such 76 | data is present in the database. 77 | The best way to guard against such negative outcomes is to (as appropriate) 78 | limit the permissions granted to the credentials used with this tool. 79 | 80 | See https://python.langchain.com/docs/security for more information. 81 | 82 | Example: 83 | .. code-block:: python 84 | 85 | chain = create_neptune_sparql_qa_chain( 86 | llm=llm, 87 | graph=graph 88 | ) 89 | response = chain.invoke({"query": "your_query_here"}) 90 | """ 91 | if allow_dangerous_requests is not True: 92 | raise ValueError( 93 | "In order to use this chain, you must acknowledge that it can make " 94 | "dangerous requests by setting `allow_dangerous_requests` to `True`. " 95 | "You must narrowly scope the permissions of the database connection " 96 | "to only include necessary permissions. Failure to do so may result " 97 | "in data corruption or loss or reading sensitive data if such data is " 98 | "present in the database. " 99 | "Only use this chain if you understand the risks and have taken the " 100 | "necessary precautions. 
" 101 | "See https://python.langchain.com/docs/security for more information." 102 | ) 103 | 104 | qa_chain = qa_prompt | llm 105 | 106 | _sparql_prompt = sparql_prompt or get_prompt(examples) 107 | sparql_generation_chain = _sparql_prompt | llm 108 | 109 | def normalize_input(raw_input: Union[str, dict]) -> dict: 110 | if isinstance(raw_input, str): 111 | return {"query": raw_input} 112 | return raw_input 113 | 114 | def execute_graph_query(sparql_query: str) -> dict: 115 | return graph.query(sparql_query) 116 | 117 | def get_sparql_inputs(inputs: dict) -> dict: 118 | return { 119 | "prompt": inputs["query"], 120 | "schema": graph.get_schema, 121 | "extra_instructions": extra_instructions or "", 122 | } 123 | 124 | def get_qa_inputs(inputs: dict) -> dict: 125 | return { 126 | "prompt": inputs["query"], 127 | "context": inputs["context"], 128 | } 129 | 130 | def format_response(inputs: dict) -> dict: 131 | intermediate_steps = [{"query": inputs["sparql"]}] 132 | 133 | if return_direct: 134 | final_response = {"result": inputs["context"]} 135 | else: 136 | final_response = {"result": inputs["qa_result"]} 137 | intermediate_steps.append({"context": inputs["context"]}) 138 | 139 | if return_intermediate_steps: 140 | final_response[INTERMEDIATE_STEPS_KEY] = intermediate_steps 141 | 142 | return final_response 143 | 144 | chain_result = ( 145 | normalize_input 146 | | RunnablePassthrough.assign(sparql_generation_inputs=get_sparql_inputs) 147 | | { 148 | "query": lambda x: x["query"], 149 | "sparql": (lambda x: x["sparql_generation_inputs"]) 150 | | sparql_generation_chain 151 | | (lambda x: extract_sparql(x.content)), 152 | } 153 | | RunnablePassthrough.assign(context=lambda x: execute_graph_query(x["sparql"])) 154 | | RunnablePassthrough.assign(qa_result=(lambda x: get_qa_inputs(x)) | qa_chain) 155 | | format_response 156 | ) 157 | 158 | return chain_result 159 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chains/graph_qa/prompts.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from langchain_core.prompts.prompt import PromptTemplate 3 | 4 | CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database. 5 | Instructions: 6 | Use only the provided relationship types and properties in the schema. 7 | Do not use any other relationship types or properties that are not provided. 8 | Schema: 9 | {schema} 10 | Note: Do not include any explanations or apologies in your responses. 11 | Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. 12 | Do not include any text except the generated Cypher statement. 13 | 14 | The question is: 15 | {question}""" 16 | CYPHER_GENERATION_PROMPT = PromptTemplate( 17 | input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE 18 | ) 19 | 20 | CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. 21 | The information part contains the provided information that you must use to construct an answer. 22 | The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it. 23 | Make the answer sound as a response to the question. Do not mention that you based the result on the given information. 24 | Here is an example: 25 | 26 | Question: Which managers own Neo4j stocks? 
27 | Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] 28 | Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. 29 | 30 | Follow this example when generating answers. 31 | If the provided information is empty, say that you don't know the answer. 32 | Information: 33 | {context} 34 | 35 | Question: {question} 36 | Helpful Answer:""" 37 | CYPHER_QA_PROMPT = PromptTemplate( 38 | input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE 39 | ) 40 | 41 | SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query. 42 | You are an assistant that creates well-written and human understandable answers. 43 | The information part contains the information provided, which you can use to construct an answer. 44 | The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it. 45 | Make your response sound like the information is coming from an AI assistant, but don't add any information. 46 | Information: 47 | {context} 48 | 49 | Question: {prompt} 50 | Helpful Answer:""" 51 | SPARQL_QA_PROMPT = PromptTemplate( 52 | input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE 53 | ) 54 | 55 | NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """ 56 | Instructions: 57 | Generate the query in openCypher format and follow these rules: 58 | Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions. 59 | Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results. 60 | Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions} 61 | \n""" 62 | 63 | NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( 64 | "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS 65 | ) 66 | 67 | NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( 68 | input_variables=["schema", "question", "extra_instructions"], 69 | template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, 70 | ) 71 | 72 | NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ 73 | Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions} 74 | Question: "{question}". 75 | Here is the property graph schema: 76 | {schema} 77 | \n""" 78 | 79 | NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( 80 | input_variables=["schema", "question", "extra_instructions"], 81 | template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, 82 | ) 83 | 84 | NEPTUNE_SPARQL_GENERATION_TEMPLATE = """ 85 | Task: Generate a SPARQL SELECT statement for querying a graph database. 86 | For instance, to find all email addresses of John Doe, the following 87 | query in backticks would be suitable: 88 | ``` 89 | PREFIX foaf: 90 | SELECT ?email 91 | WHERE {{ 92 | ?person foaf:name "John Doe" . 93 | ?person foaf:mbox ?email . 94 | }} 95 | ``` 96 | Instructions: 97 | Use only the node types and properties provided in the schema. 98 | Do not use any node types and properties that are not explicitly provided. 99 | Include all necessary prefixes. 100 | 101 | Examples: 102 | 103 | Schema: 104 | {schema} 105 | Note: Be as concise as possible. 106 | Do not include any explanations or apologies in your responses. 107 | Do not respond to any questions that ask for anything else than 108 | for you to construct a SPARQL query. 109 | Do not include any text except the SPARQL query generated. 
110 | 111 | The question is: 112 | {prompt}""" 113 | 114 | NEPTUNE_SPARQL_GENERATION_PROMPT = PromptTemplate( 115 | input_variables=["schema", "prompt"], template=NEPTUNE_SPARQL_GENERATION_TEMPLATE 116 | ) 117 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/chat_models/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.chat_models.bedrock import ChatBedrock 2 | from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse 3 | 4 | __all__ = ["ChatBedrock", "ChatBedrockConverse"] 5 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/document_compressors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/langchain_aws/document_compressors/__init__.py -------------------------------------------------------------------------------- /libs/aws/langchain_aws/document_compressors/rerank.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict, List, Optional, Sequence, Union 3 | 4 | from langchain_core.callbacks.manager import Callbacks 5 | from langchain_core.documents import BaseDocumentCompressor, Document 6 | from langchain_core.utils import from_env, secret_from_env 7 | from pydantic import ConfigDict, Field, SecretStr, model_validator 8 | 9 | from langchain_aws.utils import create_aws_client 10 | 11 | 12 | class BedrockRerank(BaseDocumentCompressor): 13 | """Document compressor that uses AWS Bedrock Rerank API.""" 14 | 15 | model_arn: str 16 | """The ARN of the reranker model.""" 17 | 18 | client: Any = Field(default=None, exclude=True) #: :meta private: 19 | """Bedrock client to use for compressing documents.""" 20 | 21 | top_n: Optional[int] = 3 22 | """Number of documents to return.""" 23 | 24 | region_name: Optional[str] = None 25 | """The aws region, e.g., `us-west-2`. 26 | 27 | Falls back to AWS_REGION or AWS_DEFAULT_REGION env variable or region specified in 28 | ~/.aws/config in case it is not provided here. 29 | """ 30 | 31 | credentials_profile_name: Optional[str] = Field( 32 | default_factory=from_env("AWS_PROFILE", default=None) 33 | ) 34 | """AWS profile for authentication, optional.""" 35 | 36 | aws_access_key_id: Optional[SecretStr] = Field( 37 | default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None) 38 | ) 39 | """AWS access key id. 40 | 41 | If provided, aws_secret_access_key must also be provided. 42 | If not specified, the default credential profile or, if on an EC2 instance, 43 | credentials from IMDS will be used. 44 | See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 45 | 46 | If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable. 47 | """ 48 | 49 | aws_secret_access_key: Optional[SecretStr] = Field( 50 | default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None) 51 | ) 52 | """AWS secret_access_key. 53 | 54 | If provided, aws_access_key_id must also be provided. 55 | If not specified, the default credential profile or, if on an EC2 instance, 56 | credentials from IMDS will be used. 57 | See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 58 | 59 | If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable. 
60 | """ 61 | 62 | aws_session_token: Optional[SecretStr] = Field( 63 | default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None) 64 | ) 65 | """AWS session token. 66 | 67 | If provided, aws_access_key_id and aws_secret_access_key must 68 | also be provided. Not required unless using temporary credentials. 69 | See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 70 | 71 | If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable. 72 | """ 73 | 74 | endpoint_url: Optional[str] = Field(default=None, alias="base_url") 75 | """Needed if you don't want to default to us-east-1 endpoint""" 76 | 77 | config: Any = None 78 | """An optional botocore.config.Config instance to pass to the client.""" 79 | 80 | model_config = ConfigDict( 81 | extra="forbid", 82 | arbitrary_types_allowed=True, 83 | ) 84 | 85 | @model_validator(mode="before") 86 | @classmethod 87 | def initialize_client(cls, values: Dict[str, Any]) -> Any: 88 | """Initialize the AWS Bedrock client.""" 89 | if not values.get("client"): 90 | values["client"] = create_aws_client( 91 | region_name=values.get("region_name"), 92 | credentials_profile_name=values.get("credentials_profile_name"), 93 | aws_access_key_id=values.get("aws_access_key_id"), 94 | aws_secret_access_key=values.get("aws_secret_access_key"), 95 | aws_session_token=values.get("aws_session_token"), 96 | endpoint_url=values.get("endpoint_url"), 97 | config=values.get("config"), 98 | service_name="bedrock-agent-runtime", 99 | ) 100 | return values 101 | 102 | def rerank( 103 | self, 104 | documents: Sequence[Union[str, Document, dict]], 105 | query: str, 106 | top_n: Optional[int] = None, 107 | additional_model_request_fields: Optional[Dict[str, Any]] = None, 108 | ) -> List[Dict[str, Any]]: 109 | """Returns an ordered list of documents based on their relevance to the query. 110 | 111 | Args: 112 | query: The query to use for reranking. 113 | documents: A sequence of documents to rerank. 114 | top_n: The number of top-ranked results to return. Defaults to self.top_n. 115 | additional_model_request_fields: Additional fields to pass to the model. 116 | 117 | Returns: 118 | List[Dict[str, Any]]: A list of ranked documents with relevance scores. 
119 | """ 120 | if len(documents) == 0: 121 | return [] 122 | 123 | # Serialize documents for the Bedrock API 124 | serialized_documents = [ 125 | {"textDocument": {"text": doc.page_content}, "type": "TEXT"} 126 | if isinstance(doc, Document) 127 | else {"textDocument": {"text": doc}, "type": "TEXT"} 128 | if isinstance(doc, str) 129 | else {"jsonDocument": doc, "type": "JSON"} 130 | for doc in documents 131 | ] 132 | 133 | request_body = { 134 | "queries": [{"textQuery": {"text": query}, "type": "TEXT"}], 135 | "rerankingConfiguration": { 136 | "bedrockRerankingConfiguration": { 137 | "modelConfiguration": { 138 | "modelArn": self.model_arn, 139 | "additionalModelRequestFields": additional_model_request_fields 140 | or {}, 141 | }, 142 | "numberOfResults": top_n or self.top_n, 143 | }, 144 | "type": "BEDROCK_RERANKING_MODEL", 145 | }, 146 | "sources": [ 147 | {"inlineDocumentSource": doc, "type": "INLINE"} 148 | for doc in serialized_documents 149 | ], 150 | } 151 | 152 | response = self.client.rerank(**request_body) 153 | response_body = response.get("results", []) 154 | 155 | results = [ 156 | {"index": result["index"], "relevance_score": result["relevanceScore"]} 157 | for result in response_body 158 | ] 159 | 160 | return results 161 | 162 | def compress_documents( 163 | self, 164 | documents: Sequence[Document], 165 | query: str, 166 | callbacks: Optional[Callbacks] = None, 167 | ) -> Sequence[Document]: 168 | """ 169 | Compress documents using Bedrock's rerank API. 170 | 171 | Args: 172 | documents: A sequence of documents to compress. 173 | query: The query to use for compressing the documents. 174 | callbacks: Callbacks to run during the compression process. 175 | 176 | Returns: 177 | A sequence of compressed documents. 178 | """ 179 | compressed = [] 180 | for res in self.rerank(documents, query): 181 | doc = documents[res["index"]] 182 | doc_copy = Document(**doc.model_dump()) 183 | doc_copy.metadata["relevance_score"] = res["relevance_score"] 184 | compressed.append(doc_copy) 185 | return compressed 186 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.embeddings.bedrock import BedrockEmbeddings 2 | 3 | __all__ = ["BedrockEmbeddings"] 4 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/function_calling.py: -------------------------------------------------------------------------------- 1 | """Methods for creating function specs in the style of Bedrock Functions 2 | for supported model providers""" 3 | 4 | import json 5 | from typing import ( 6 | Annotated, 7 | Any, 8 | Callable, 9 | Dict, 10 | List, 11 | Literal, 12 | Optional, 13 | Union, 14 | cast, 15 | ) 16 | 17 | from langchain_core.messages import ToolCall 18 | from langchain_core.output_parsers import BaseGenerationOutputParser 19 | from langchain_core.outputs import ChatGeneration, Generation 20 | from langchain_core.prompts.chat import AIMessage 21 | from langchain_core.tools import BaseTool 22 | from langchain_core.utils.function_calling import convert_to_openai_tool 23 | from langchain_core.utils.pydantic import TypeBaseModel 24 | from pydantic import BaseModel, ConfigDict, SkipValidation 25 | from typing_extensions import TypedDict 26 | 27 | PYTHON_TO_JSON_TYPES = { 28 | "str": "string", 29 | "int": "integer", 30 | "float": "number", 31 | "bool": "boolean", 32 | } 33 | 34 | 
SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question. 35 | 36 | You may call them like this: 37 | <function_calls> 38 | <invoke> 39 | <tool_name>$TOOL_NAME</tool_name> 40 | <parameters> 41 | <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME> 42 | ... 43 | </parameters> 44 | </invoke> 45 | </function_calls> 46 | 47 | Here are the tools available: 48 | 49 | {formatted_tools} 50 | """ # noqa: E501 51 | 52 | TOOL_FORMAT = """<tool_description> 53 | <tool_name>{tool_name}</tool_name> 54 | <description>{tool_description}</description> 55 | <parameters> 56 | {formatted_parameters} 57 | </parameters> 58 | </tool_description>""" 59 | 60 | TOOL_PARAMETER_FORMAT = """<parameter> 61 | <name>{parameter_name}</name> 62 | <type>{parameter_type}</type> 63 | <description>{parameter_description}</description> 64 | </parameter>""" 65 | 66 | 67 | class AnthropicTool(TypedDict): 68 | name: str 69 | description: str 70 | input_schema: Dict[str, Any] 71 | 72 | 73 | def _tools_in_params(params: dict) -> bool: 74 | return "tools" in params or ( 75 | "extra_body" in params and params["extra_body"].get("tools") 76 | ) 77 | 78 | 79 | class _AnthropicToolUse(TypedDict): 80 | type: Literal["tool_use"] 81 | name: str 82 | input: dict 83 | id: str 84 | 85 | 86 | def _lc_tool_calls_to_anthropic_tool_use_blocks( 87 | tool_calls: List[ToolCall], 88 | ) -> List[_AnthropicToolUse]: 89 | blocks = [] 90 | for tool_call in tool_calls: 91 | blocks.append( 92 | _AnthropicToolUse( 93 | type="tool_use", 94 | name=tool_call["name"], 95 | input=tool_call["args"], 96 | id=cast(str, tool_call["id"]), 97 | ) 98 | ) 99 | return blocks 100 | 101 | 102 | def _get_type(parameter: Dict[str, Any]) -> str: 103 | if "type" in parameter: 104 | return parameter["type"] 105 | if "anyOf" in parameter: 106 | return json.dumps({"anyOf": parameter["anyOf"]}) 107 | if "allOf" in parameter: 108 | return json.dumps({"allOf": parameter["allOf"]}) 109 | return json.dumps(parameter) 110 | 111 | 112 | def get_system_message(tools: List[AnthropicTool]) -> str: 113 | tools_data: List[Dict] = [ 114 | { 115 | "tool_name": tool["name"], 116 | "tool_description": tool["description"], 117 | "formatted_parameters": "\n".join( 118 | [ 119 | TOOL_PARAMETER_FORMAT.format( 120 | parameter_name=name, 121 | parameter_type=_get_type(parameter), 122 | parameter_description=parameter.get("description"), 123 | ) 124 | for name, parameter in tool["input_schema"]["properties"].items() 125 | ] 126 | ), 127 | } 128 | for tool in tools 129 | ] 130 | tools_formatted = "\n".join( 131 | [ 132 | TOOL_FORMAT.format( 133 | tool_name=tool["tool_name"], 134 | tool_description=tool["tool_description"], 135 | formatted_parameters=tool["formatted_parameters"], 136 | ) 137 | for tool in tools_data 138 | ] 139 | ) 140 | return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted) 141 | 142 | 143 | class FunctionDescription(TypedDict): 144 | """Representation of a callable function to send to an LLM.""" 145 | 146 | name: str 147 | """The name of the function.""" 148 | description: str 149 | """A description of the function.""" 150 | parameters: dict 151 | """The parameters of the function.""" 152 | 153 | 154 | class ToolDescription(TypedDict): 155 | """Representation of a callable function to the OpenAI API.""" 156 | 157 | type: Literal["function"] 158 | function: FunctionDescription 159 | 160 | 161 | class ToolsOutputParser(BaseGenerationOutputParser): 162 | first_tool_only: bool = False 163 | args_only: bool = False 164 | pydantic_schemas: Optional[List[Annotated[TypeBaseModel, SkipValidation()]]] = None 165 | 166 | model_config = ConfigDict( 167 | extra="forbid", 168 | ) 169 | 170 | def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: 171 | """Parse a list of candidate model Generations into a 
specific format. 172 | 173 | Args: 174 | result: A list of Generations to be parsed. The Generations are assumed 175 | to be different candidate outputs for a single model input. 176 | 177 | Returns: 178 | Structured output. 179 | """ 180 | if ( 181 | not result 182 | or not isinstance(result[0], ChatGeneration) 183 | or not isinstance(result[0].message, AIMessage) 184 | or not result[0].message.tool_calls 185 | ): 186 | return None if self.first_tool_only else [] 187 | tool_calls: Any = result[0].message.tool_calls 188 | if self.pydantic_schemas: 189 | tool_calls = [self._pydantic_parse(tc) for tc in tool_calls] 190 | elif self.args_only: 191 | tool_calls = [tc["args"] for tc in tool_calls] 192 | else: 193 | pass 194 | 195 | if self.first_tool_only: 196 | return tool_calls[0] 197 | else: 198 | return tool_calls 199 | 200 | def _pydantic_parse(self, tool_call: ToolCall) -> BaseModel: 201 | cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[ 202 | tool_call["name"] 203 | ] 204 | return cls_(**tool_call["args"]) 205 | 206 | 207 | def convert_to_anthropic_tool( 208 | tool: Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool], 209 | ) -> Union[AnthropicTool, Dict]: 210 | # already in Anthropic tool format 211 | if isinstance(tool, dict) and all( 212 | k in tool for k in ("name", "description", "input_schema") 213 | ): 214 | tool["description"] = tool["description"] or tool["name"] 215 | return AnthropicTool(tool) # type: ignore 216 | else: 217 | formatted = convert_to_openai_tool(tool)["function"] 218 | formatted["description"] = formatted.get("description") or formatted["name"] 219 | return AnthropicTool( 220 | name=formatted["name"], 221 | description=formatted["description"], 222 | input_schema=formatted["parameters"], 223 | ) 224 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.graphs.neptune_graph import ( 2 | BaseNeptuneGraph, 3 | NeptuneAnalyticsGraph, 4 | NeptuneGraph, 5 | ) 6 | from langchain_aws.graphs.neptune_rdf_graph import NeptuneRdfGraph 7 | 8 | __all__ = [ 9 | "BaseNeptuneGraph", 10 | "NeptuneAnalyticsGraph", 11 | "NeptuneGraph", 12 | "NeptuneRdfGraph", 13 | ] 14 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/llms/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.llms.bedrock import ( 2 | ALTERNATION_ERROR, 3 | BedrockBase, 4 | BedrockLLM, 5 | LLMInputOutputAdapter, 6 | ) 7 | from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint 8 | 9 | __all__ = [ 10 | "ALTERNATION_ERROR", 11 | "BedrockBase", 12 | "BedrockLLM", 13 | "LLMInputOutputAdapter", 14 | "SagemakerEndpoint", 15 | ] 16 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/langchain_aws/py.typed -------------------------------------------------------------------------------- /libs/aws/langchain_aws/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.retrievers.bedrock import AmazonKnowledgeBasesRetriever 2 | from langchain_aws.retrievers.kendra import 
AmazonKendraRetriever 3 | 4 | __all__ = [ 5 | "AmazonKendraRetriever", 6 | "AmazonKendraRetriever", 7 | "AmazonKnowledgeBasesRetriever", 8 | ] 9 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/runnables/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_aws.runnables.q_business import AmazonQ 2 | 3 | __all__ = ["AmazonQ"] -------------------------------------------------------------------------------- /libs/aws/langchain_aws/runnables/q_business.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List, Optional, Union 3 | 4 | from langchain_core._api.beta_decorator import beta 5 | from langchain_core.messages.ai import AIMessage 6 | from langchain_core.prompt_values import ChatPromptValue 7 | from langchain_core.runnables import Runnable 8 | from langchain_core.runnables.config import RunnableConfig 9 | from pydantic import ConfigDict 10 | from typing_extensions import Self 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @beta(message="This API is in beta and can change in future.") 16 | class AmazonQ(Runnable[Union[str,ChatPromptValue, List[ChatPromptValue]], ChatPromptValue]): 17 | """Amazon Q Runnable wrapper. 18 | 19 | To authenticate, the AWS client uses the following methods to 20 | automatically load credentials: 21 | https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html 22 | 23 | Make sure the credentials / roles used have the required policies to 24 | access the Amazon Q service. 25 | """ 26 | 27 | region_name: Optional[str] = None 28 | """AWS region name. If not provided, will be extracted from environment.""" 29 | 30 | credentials: Optional[Any] = None 31 | """Amazon Q credentials used to instantiate the client if the client is not provided.""" 32 | 33 | client: Optional[Any] = None 34 | """Amazon Q client.""" 35 | 36 | application_id: str = None 37 | 38 | """Store the full response from Amazon Q.""" 39 | 40 | parent_message_id: Optional[str] = None 41 | 42 | conversation_id: Optional[str] = None 43 | 44 | chat_mode: str = "RETRIEVAL_MODE" 45 | 46 | model_config = ConfigDict( 47 | extra="forbid", 48 | ) 49 | 50 | def __init__( 51 | self, 52 | region_name: Optional[str] = None, 53 | credentials: Optional[Any] = None, 54 | client: Optional[Any] = None, 55 | application_id: str = None, 56 | parent_message_id: Optional[str] = None, 57 | conversation_id: Optional[str] = None, 58 | chat_mode: str = "RETRIEVAL_MODE", 59 | ): 60 | self.region_name = region_name 61 | self.credentials = credentials 62 | self.client = client or self.validate_environment() 63 | self.application_id = application_id 64 | self.parent_message_id = parent_message_id 65 | self.conversation_id = conversation_id 66 | self.chat_mode = chat_mode 67 | 68 | def invoke( 69 | self, 70 | input: Union[str,ChatPromptValue], 71 | config: Optional[RunnableConfig] = None, 72 | **kwargs: Any 73 | ) -> ChatPromptValue: 74 | """Call out to Amazon Q service. 75 | 76 | Args: 77 | input: The prompt to pass into the model. 78 | 79 | Returns: 80 | The string generated by the model. 81 | 82 | Example: 83 | .. 
code-block:: python 84 | 85 | model = AmazonQ( 86 | credentials=your_credentials, 87 | application_id=your_app_id 88 | ) 89 | response = model.invoke("Tell me a joke") 90 | """ 91 | try: 92 | # Prepare the request 93 | request = { 94 | 'applicationId': self.application_id, 95 | 'userMessage': self.convert_langchain_messages_to_q_input(input), # Langchain's input comes in the form of an array of "messages". We must convert to a single string for Amazon Q's use 96 | 'chatMode': self.chat_mode, 97 | } 98 | if self.conversation_id: 99 | request.update({ 100 | 'conversationId': self.conversation_id, 101 | 'parentMessageId': self.parent_message_id, 102 | }) 103 | 104 | # Call Amazon Q 105 | response = self.client.chat_sync(**request) 106 | 107 | # Extract the response text 108 | if 'systemMessage' in response: 109 | return AIMessage(content=response["systemMessage"], response_metadata=response) 110 | else: 111 | raise ValueError("Unexpected response format from Amazon Q") 112 | 113 | except Exception as e: 114 | if "Prompt Length" in str(e): 115 | logger.info(f"Prompt Length: {len(input)}") 116 | logger.info(f"""Prompt: 117 | {input}""") 118 | raise ValueError(f"Error raised by Amazon Q service: {e}") 119 | 120 | def validate_environment(self) -> Self: 121 | """Don't do anything if client provided externally""" 122 | #If the client is not provided, and the user_id is not provided in the class constructor, throw an error saying one or the other needs to be provided 123 | if self.credentials is None: 124 | raise ValueError( 125 | "Either the credentials or the client needs to be provided." 126 | ) 127 | 128 | """Validate that AWS credentials to and python package exists in environment.""" 129 | try: 130 | import boto3 131 | 132 | try: 133 | if self.region_name is not None: 134 | client = boto3.client('qbusiness', self.region_name, **self.credentials) 135 | else: 136 | # use default region 137 | client = boto3.client('qbusiness', **self.credentials) 138 | 139 | except Exception as e: 140 | raise ValueError( 141 | "Could not load credentials to authenticate with AWS client. " 142 | "Please check that credentials in the specified " 143 | "profile name are valid." 144 | ) from e 145 | 146 | except ImportError: 147 | raise ImportError( 148 | "Could not import boto3 python package. " 149 | "Please install it with `pip install boto3`." 150 | ) 151 | return client 152 | def convert_langchain_messages_to_q_input(self, input: Union[str,ChatPromptValue,List[ChatPromptValue]]) -> str: 153 | #If it is just a string and not a ChatPromptTemplate collection just return string 154 | if type(input) is str: 155 | return input 156 | return input.to_string() 157 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/utilities/math.py: -------------------------------------------------------------------------------- 1 | """Math utils.""" 2 | 3 | import logging 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 11 | 12 | 13 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 14 | """Row-wise cosine similarity between two equal-width matrices.""" 15 | if len(X) == 0 or len(Y) == 0: 16 | return np.array([]) 17 | 18 | X = np.array(X) 19 | Y = np.array(Y) 20 | if X.shape[1] != Y.shape[1]: 21 | raise ValueError( 22 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 23 | f"and Y has shape {Y.shape}." 
24 | ) 25 | try: 26 | import simsimd as simd 27 | 28 | X = np.array(X, dtype=np.float32) 29 | Y = np.array(Y, dtype=np.float32) 30 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) 31 | return Z 32 | except ImportError: 33 | logger.debug( 34 | "Unable to import simsimd, defaulting to NumPy implementation. If you want " 35 | "to use simsimd please install with `pip install simsimd`." 36 | ) 37 | X_norm = np.linalg.norm(X, axis=1) 38 | Y_norm = np.linalg.norm(Y, axis=1) 39 | # Ignore divide by zero errors run time warnings as those are handled below. 40 | with np.errstate(divide="ignore", invalid="ignore"): 41 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 42 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 43 | return similarity 44 | 45 | 46 | def cosine_similarity_top_k( 47 | X: Matrix, 48 | Y: Matrix, 49 | top_k: Optional[int] = 5, 50 | score_threshold: Optional[float] = None, 51 | ) -> Tuple[List[Tuple[int, int]], List[float]]: 52 | """Row-wise cosine similarity with optional top-k and score threshold filtering. 53 | 54 | Args: 55 | X: Matrix. 56 | Y: Matrix, same width as X. 57 | top_k: Max number of results to return. 58 | score_threshold: Minimum cosine similarity of results. 59 | 60 | Returns: 61 | Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx), 62 | second contains corresponding cosine similarities. 63 | """ 64 | if len(X) == 0 or len(Y) == 0: 65 | return [], [] 66 | score_array = cosine_similarity(X, Y) 67 | score_threshold = score_threshold or -1.0 68 | score_array[score_array < score_threshold] = 0 69 | top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) 70 | top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] # type: ignore 71 | top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] 72 | ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) 73 | scores = score_array.ravel()[top_k_idxs].tolist() 74 | return list(zip(*ret_idxs)), scores # type: ignore 75 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/utilities/redis.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import re 5 | from typing import TYPE_CHECKING, Any, List, Optional, Pattern 6 | 7 | import numpy as np 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | if TYPE_CHECKING: 12 | from redis.client import Redis as RedisType # type: ignore[import-untyped] 13 | 14 | 15 | def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: 16 | return np.array(array).astype(dtype).tobytes() 17 | 18 | 19 | def _buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: 20 | return np.frombuffer(buffer, dtype=dtype).tolist() 21 | 22 | 23 | class TokenEscaper: 24 | """ 25 | Escape punctuation within an input string. 26 | """ 27 | 28 | # Characters that RediSearch requires us to escape during queries. 
29 | # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization 30 | DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" 31 | 32 | def __init__(self, escape_chars_re: Optional[Pattern] = None): 33 | if escape_chars_re: 34 | self.escaped_chars_re = escape_chars_re 35 | else: 36 | self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS) 37 | 38 | def escape(self, value: str) -> str: 39 | if not isinstance(value, str): 40 | raise TypeError( 41 | "Value must be a string object for token escaping." 42 | f"Got type {type(value)}" 43 | ) 44 | 45 | def escape_symbol(match: re.Match) -> str: 46 | value = match.group(0) 47 | return f"\\{value}" 48 | 49 | return self.escaped_chars_re.sub(escape_symbol, value) 50 | 51 | 52 | def get_client(redis_url: str, **kwargs: Any) -> RedisType: 53 | """Get a redis client from the connection url given. This helper accepts 54 | urls for Redis server (TCP with/without TLS or UnixSocket) as well as 55 | Redis Sentinel connections. 56 | 57 | Before creating a connection the existence of the database driver is checked 58 | and ValueError raised otherwise. 59 | 60 | To use, you should have the ``redis`` python package installed. 61 | 62 | Example: 63 | .. code-block:: python 64 | 65 | from langchain_community.utilities.redis import get_client 66 | redis_client = get_client( 67 | redis_url="redis://username:password@localhost:6379" 68 | index_name="my-index", 69 | embedding_function=embeddings.embed_query, 70 | ) 71 | 72 | """ 73 | 74 | # Initialize with necessary components. 75 | try: 76 | import redis # type: ignore[import-untyped] 77 | except ImportError: 78 | raise ImportError( 79 | "Could not import redis python package. " 80 | "Please install it with `pip install redis>=4.1.0`." 
81 | ) 82 | 83 | # Connect to redis server from url, reconnect with cluster client if needed 84 | redis_client = redis.from_url(redis_url, **kwargs) 85 | if _check_for_cluster(redis_client): 86 | redis_client.close() 87 | redis_client = _redis_cluster_client(redis_url, **kwargs) 88 | 89 | return redis_client 90 | 91 | 92 | def _check_for_cluster(redis_client: RedisType) -> bool: 93 | import redis 94 | 95 | try: 96 | cluster_info = redis_client.info("cluster") 97 | return cluster_info["cluster_enabled"] == 1 98 | except redis.exceptions.RedisError: 99 | return False 100 | 101 | 102 | def _redis_cluster_client(redis_url: str, **kwargs: Any) -> RedisType: 103 | from redis.cluster import RedisCluster # type: ignore[import-untyped] 104 | 105 | return RedisCluster.from_url(redis_url, **kwargs) # type: ignore[return-value] 106 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/utilities/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for working with vectors and vectorstores.""" 2 | 3 | from enum import Enum 4 | from typing import List, Tuple, Type 5 | 6 | import numpy as np 7 | from langchain_core.documents import Document 8 | 9 | from langchain_aws.utilities.math import cosine_similarity 10 | 11 | 12 | class DistanceStrategy(str, Enum): 13 | """Enumerator of the Distance strategies for calculating distances 14 | between vectors.""" 15 | 16 | EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE" 17 | MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT" 18 | DOT_PRODUCT = "DOT_PRODUCT" 19 | JACCARD = "JACCARD" 20 | COSINE = "COSINE" 21 | 22 | 23 | def maximal_marginal_relevance( 24 | query_embedding: np.ndarray, 25 | embedding_list: list, 26 | lambda_mult: float = 0.5, 27 | k: int = 4, 28 | ) -> List[int]: 29 | """Calculate maximal marginal relevance.""" 30 | if min(k, len(embedding_list)) <= 0: 31 | return [] 32 | if query_embedding.ndim == 1: 33 | query_embedding = np.expand_dims(query_embedding, axis=0) 34 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] 35 | most_similar = int(np.argmax(similarity_to_query)) 36 | idxs = [most_similar] 37 | selected = np.array([embedding_list[most_similar]]) 38 | while len(idxs) < min(k, len(embedding_list)): 39 | best_score = -np.inf 40 | idx_to_add = -1 41 | similarity_to_selected = cosine_similarity(embedding_list, selected) 42 | for i, query_score in enumerate(similarity_to_query): 43 | if i in idxs: 44 | continue 45 | redundant_score = max(similarity_to_selected[i]) 46 | equation_score = ( 47 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score 48 | ) 49 | if equation_score > best_score: 50 | best_score = equation_score 51 | idx_to_add = i 52 | idxs.append(idx_to_add) 53 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) 54 | return idxs 55 | 56 | 57 | def filter_complex_metadata( 58 | documents: List[Document], 59 | *, 60 | allowed_types: Tuple[Type, ...] 
= (str, bool, int, float), 61 | ) -> List[Document]: 62 | """Filter out metadata types that are not supported for a vector store.""" 63 | updated_documents = [] 64 | for document in documents: 65 | filtered_metadata = {} 66 | for key, value in document.metadata.items(): 67 | if not isinstance(value, allowed_types): 68 | continue 69 | filtered_metadata[key] = value 70 | 71 | document.metadata = filtered_metadata 72 | updated_documents.append(document) 73 | 74 | return updated_documents 75 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from abc import abstractmethod 4 | from typing import Any, Dict, Generic, Iterator, List, Literal, Optional, TypeVar, Union 5 | 6 | from botocore.exceptions import BotoCoreError, UnknownServiceError 7 | from packaging import version 8 | from pydantic import SecretStr 9 | 10 | MESSAGE_ROLES = Literal["system", "user", "assistant"] 11 | MESSAGE_FORMAT = Dict[Literal["role", "content"], Union[MESSAGE_ROLES, str]] 12 | 13 | INPUT_TYPE = TypeVar( 14 | "INPUT_TYPE", bound=Union[str, List[str], MESSAGE_FORMAT, List[MESSAGE_FORMAT]] 15 | ) 16 | OUTPUT_TYPE = TypeVar( 17 | "OUTPUT_TYPE", 18 | bound=Union[str, List[List[float]], MESSAGE_FORMAT, List[MESSAGE_FORMAT], Iterator], 19 | ) 20 | 21 | 22 | class ContentHandlerBase(Generic[INPUT_TYPE, OUTPUT_TYPE]): 23 | """A handler class to transform input from LLM and BaseChatModel to a 24 | 25 | format that SageMaker endpoint expects. 26 | 27 | Similarly, the class handles transforming output from the 28 | SageMaker endpoint to a format that LLM & BaseChatModel class expects. 29 | """ 30 | 31 | """ 32 | Example: 33 | .. code-block:: python 34 | 35 | class ContentHandler(ContentHandlerBase): 36 | content_type = "application/json" 37 | accepts = "application/json" 38 | 39 | def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: 40 | input_str = json.dumps({prompt: prompt, **model_kwargs}) 41 | return input_str.encode('utf-8') 42 | 43 | def transform_output(self, output: bytes) -> str: 44 | response_json = json.loads(output.read().decode("utf-8")) 45 | return response_json[0]["generated_text"] 46 | """ 47 | 48 | content_type: Optional[str] = "text/plain" 49 | """The MIME type of the input data passed to endpoint""" 50 | 51 | accepts: Optional[str] = "text/plain" 52 | """The MIME type of the response data returned from endpoint""" 53 | 54 | @abstractmethod 55 | def transform_input(self, prompt: INPUT_TYPE, model_kwargs: Dict) -> bytes: 56 | """Transforms the input to a format that model can accept 57 | as the request Body. Should return bytes or seekable file 58 | like object in the format specified in the content_type 59 | request header. 60 | """ 61 | 62 | @abstractmethod 63 | def transform_output(self, output: bytes) -> OUTPUT_TYPE: 64 | """Transforms the output from the model to string that 65 | the LLM class expects. 
66 | """ 67 | 68 | 69 | def enforce_stop_tokens(text: str, stop: List[str]) -> str: 70 | """Cut off the text as soon as any stop words occur.""" 71 | return re.split("|".join(stop), text, maxsplit=1)[0] 72 | 73 | 74 | def anthropic_tokens_supported() -> bool: 75 | """Check if all requirements for Anthropic count_tokens() are met.""" 76 | try: 77 | import anthropic 78 | except ImportError: 79 | return False 80 | 81 | if version.parse(anthropic.__version__) > version.parse("0.38.0"): 82 | return False 83 | 84 | try: 85 | import httpx 86 | 87 | if version.parse(httpx.__version__) > version.parse("0.27.2"): 88 | raise ImportError() 89 | except ImportError: 90 | raise ImportError("httpx<=0.27.2 is required.") 91 | 92 | return True 93 | 94 | 95 | def _get_anthropic_client() -> Any: 96 | import anthropic 97 | 98 | return anthropic.Anthropic() 99 | 100 | 101 | def get_num_tokens_anthropic(text: str) -> int: 102 | """Get the number of tokens in a string of text.""" 103 | client = _get_anthropic_client() 104 | return client.count_tokens(text=text) 105 | 106 | 107 | def get_token_ids_anthropic(text: str) -> List[int]: 108 | """Get the token ids for a string of text.""" 109 | client = _get_anthropic_client() 110 | tokenizer = client.get_tokenizer() 111 | encoded_text = tokenizer.encode(text) 112 | return encoded_text.ids 113 | 114 | 115 | def create_aws_client( 116 | service_name: str, 117 | region_name: Optional[str] = None, 118 | credentials_profile_name: Optional[str] = None, 119 | aws_access_key_id: Optional[SecretStr] = None, 120 | aws_secret_access_key: Optional[SecretStr] = None, 121 | aws_session_token: Optional[SecretStr] = None, 122 | endpoint_url: Optional[str] = None, 123 | config: Any = None, 124 | ): 125 | """Helper function to validate AWS credentials and create an AWS client. 126 | 127 | Args: 128 | service_name: The name of the AWS service to create a client for. 129 | region_name: AWS region name. If not provided, will try to get from environment variables. 130 | credentials_profile_name: The name of the AWS credentials profile to use. 131 | aws_access_key_id: AWS access key ID. 132 | aws_secret_access_key: AWS secret access key. 133 | aws_session_token: AWS session token. 134 | endpoint_url: The complete URL to use for the constructed client. 135 | config: Advanced client configuration options. 136 | Returns: 137 | boto3.client: An AWS service client instance. 
138 | 139 | """ 140 | 141 | try: 142 | import boto3 143 | 144 | region_name = ( 145 | region_name 146 | or os.getenv("AWS_REGION") 147 | or os.getenv("AWS_DEFAULT_REGION") 148 | ) 149 | 150 | client_params = { 151 | "service_name": service_name, 152 | "region_name": region_name, 153 | "endpoint_url": endpoint_url, 154 | "config": config, 155 | } 156 | client_params = { 157 | k: v for k, v in client_params.items() if v 158 | } 159 | 160 | needs_session = bool( 161 | credentials_profile_name or 162 | aws_access_key_id or 163 | aws_secret_access_key or 164 | aws_session_token 165 | ) 166 | 167 | if not needs_session: 168 | return boto3.client(**client_params) 169 | 170 | if credentials_profile_name: 171 | session = boto3.Session(profile_name=credentials_profile_name) 172 | elif aws_access_key_id and aws_secret_access_key: 173 | session_params = { 174 | "aws_access_key_id": aws_access_key_id.get_secret_value(), 175 | "aws_secret_access_key": aws_secret_access_key.get_secret_value(), 176 | } 177 | if aws_session_token: 178 | session_params["aws_session_token"] = aws_session_token.get_secret_value() 179 | session = boto3.Session(**session_params) 180 | else: 181 | raise ValueError( 182 | "If providing credentials, both aws_access_key_id and " 183 | "aws_secret_access_key must be specified." 184 | ) 185 | 186 | if not client_params.get("region_name") and session.region_name: 187 | client_params["region_name"] = session.region_name 188 | 189 | return session.client(**client_params) 190 | 191 | except UnknownServiceError as e: 192 | raise ModuleNotFoundError( 193 | f"Ensure that you have installed the latest boto3 package " 194 | f"that contains the API for `{service_name}`." 195 | ) from e 196 | except BotoCoreError as e: 197 | raise ValueError( 198 | "Could not load credentials to authenticate with AWS client. " 199 | "Please check that the specified profile name and/or its credentials are valid. 
" 200 | f"Service error: {e}" 201 | ) from e 202 | except Exception as e: 203 | raise ValueError(f"Error raised by service:\n\n{e}") from e 204 | 205 | 206 | def thinking_in_params(params: dict) -> bool: 207 | """Check if the thinking parameter is enabled in the request.""" 208 | return params.get("thinking", {}).get("type") == "enabled" 209 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/vectorstores/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import TYPE_CHECKING, Any 3 | 4 | if TYPE_CHECKING: 5 | from langchain_aws.vectorstores.inmemorydb import ( 6 | InMemorySemanticCache, 7 | InMemoryVectorStore, 8 | ) 9 | __all__ = [ 10 | "InMemoryVectorStore", 11 | "InMemorySemanticCache", 12 | ] 13 | 14 | _module_lookup = { 15 | "InMemoryVectorStore": "langchain_aws.vectorstores.inmemorydb", 16 | "InMemorySemanticCache": "langchain_aws.vectorstores.inmemorydb", 17 | } 18 | 19 | 20 | def __getattr__(name: str) -> Any: 21 | if name in _module_lookup: 22 | module = importlib.import_module(_module_lookup[name]) 23 | return getattr(module, name) 24 | raise AttributeError(f"module {__name__} has no attribute {name}") 25 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/vectorstores/inmemorydb/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import InMemoryVectorStore, InMemoryVectorStoreRetriever 2 | from .cache import InMemorySemanticCache 3 | from .filters import ( 4 | InMemoryDBFilter, 5 | InMemoryDBNum, 6 | InMemoryDBTag, 7 | InMemoryDBText, 8 | ) 9 | 10 | __all__ = [ 11 | "InMemoryVectorStore", 12 | "InMemoryDBFilter", 13 | "InMemoryDBTag", 14 | "InMemoryDBText", 15 | "InMemoryDBNum", 16 | "InMemoryVectorStoreRetriever", 17 | "InMemorySemanticCache", 18 | ] 19 | -------------------------------------------------------------------------------- /libs/aws/langchain_aws/vectorstores/inmemorydb/constants.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import numpy as np 4 | 5 | # required modules 6 | 7 | # distance metrics 8 | INMEMORYDB_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"] 9 | 10 | # supported vector datatypes 11 | INMEMORYDB_VECTOR_DTYPE_MAP: Dict[str, Any] = { 12 | "FLOAT32": np.float32, 13 | "FLOAT64": np.float64, 14 | } 15 | 16 | INMEMORYDB_TAG_SEPARATOR = "," 17 | -------------------------------------------------------------------------------- /libs/aws/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-aws" 3 | version = "0.2.25" 4 | description = "An integration package connecting AWS and LangChain" 5 | authors = [] 6 | readme = "README.md" 7 | repository = "https://github.com/langchain-ai/langchain-aws" 8 | license = "MIT" 9 | 10 | [tool.poetry.urls] 11 | "Source Code" = "https://github.com/langchain-ai/langchain-aws/tree/main/libs/aws" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9" 15 | langchain-core = "^0.3.64" 16 | boto3 = ">=1.37.24" 17 | pydantic = ">=2.10.0,<3" 18 | numpy = [ 19 | { version = "^1", python = "<3.12" }, 20 | { version = ">=1.26.0,<3", python = ">=3.12" }, 21 | ] 22 | 23 | [tool.poetry.group.test] 24 | optional = true 25 | 26 | [tool.poetry.group.test.dependencies] 27 | pytest = "^7.4.3" 28 | pytest-cov = "^4.1.0" 29 | 
syrupy = { version = "^4.0.2", python = "<4.0" } 30 | pytest-asyncio = "^0.23.2" 31 | pytest-watcher = { version = "^0.3.4", python = "<4.0" } 32 | langchain-tests = { version = "^0.3.18", python = "<4.0" } 33 | langchain = { version = "^0.3.7", python = "<4.0" } 34 | 35 | [tool.poetry.group.codespell] 36 | optional = true 37 | 38 | [tool.codespell] 39 | ignore-words-list = "HAS, notIn" 40 | 41 | [tool.poetry.group.codespell.dependencies] 42 | codespell = "^2.2.6" 43 | 44 | 45 | 46 | [tool.poetry.group.test_integration] 47 | optional = true 48 | 49 | [tool.poetry.group.test_integration.dependencies] 50 | 51 | 52 | 53 | [tool.poetry.group.lint] 54 | optional = true 55 | 56 | [tool.poetry.group.lint.dependencies] 57 | ruff = "^0.1.8" 58 | 59 | 60 | 61 | [tool.poetry.group.typing.dependencies] 62 | mypy = "^1.7" 63 | types-requests = "^2.28.11.5" 64 | 65 | 66 | 67 | [tool.poetry.group.dev] 68 | optional = true 69 | 70 | [tool.poetry.group.dev.dependencies] 71 | 72 | [tool.ruff.lint] 73 | select = [ 74 | "E", # pycodestyle 75 | "F", # pyflakes 76 | "I", # isort 77 | "T201", # print 78 | ] 79 | 80 | [tool.mypy] 81 | ignore_missing_imports = "True" 82 | disallow_untyped_defs = "True" 83 | exclude = ["notebooks", "samples"] 84 | 85 | [tool.coverage.run] 86 | omit = ["tests/*"] 87 | 88 | [build-system] 89 | requires = ["poetry-core>=1.0.0"] 90 | build-backend = "poetry.core.masonry.api" 91 | 92 | [tool.pytest.ini_options] 93 | # --strict-markers will raise errors on unknown marks. 94 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 95 | # 96 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 97 | # --strict-config any warnings encountered while parsing the `pytest` 98 | # section of the configuration file raise errors. 99 | # 100 | # https://github.com/tophat/syrupy 101 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 102 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 --cov=langchain_aws" 103 | # Registering custom markers. 
104 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 105 | markers = [ 106 | "requires: mark tests as requiring a specific library", 107 | "asyncio: mark tests as requiring asyncio", 108 | "compile: mark placeholder test used to compile integration tests without running them", 109 | "scheduled: mark tests to run in scheduled testing", 110 | ] 111 | asyncio_mode = "auto" 112 | -------------------------------------------------------------------------------- /libs/aws/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | from importlib.machinery import SourceFileLoader 4 | 5 | if __name__ == "__main__": 6 | files = sys.argv[1:] 7 | has_failure = False 8 | 9 | for file in files: 10 | try: 11 | SourceFileLoader("x", file).load_module() 12 | except Exception: 13 | has_failure = True 14 | print(file) # noqa: T201 15 | traceback.print_exc() 16 | print() # noqa: T201 17 | 18 | sys.exit(1 if has_failure else 0) 19 | -------------------------------------------------------------------------------- /libs/aws/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain, langchain_experimental, or langchain_community 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1)) 12 | 13 | # Decide on an exit status based on the errors 14 | if [ "$errors" -gt 0 ]; then 15 | exit 1 16 | else 17 | exit 0 18 | fi 19 | -------------------------------------------------------------------------------- /libs/aws/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/chat_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/chat_models/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/chat_models/test_standard.py: -------------------------------------------------------------------------------- 1 | """Standard LangChain interface tests""" 2 | 3 | from typing import Type 4 | 5 | import pytest 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_tests.integration_tests import ChatModelIntegrationTests 8 | 9 | from langchain_aws.chat_models.bedrock import ChatBedrock 10 | 11 | 12 | class TestBedrockStandard(ChatModelIntegrationTests): 13 | @property 14 | def chat_model_class(self) -> Type[BaseChatModel]: 15 | return 
ChatBedrock 16 | 17 | @property 18 | def chat_model_params(self) -> dict: 19 | return {"model_id": "us.anthropic.claude-sonnet-4-20250514-v1:0"} 20 | 21 | @property 22 | def standard_chat_model_params(self) -> dict: 23 | return {"temperature": 0, "max_tokens": 100} 24 | 25 | @property 26 | def supports_image_inputs(self) -> bool: 27 | return True 28 | 29 | @pytest.mark.xfail(reason="Not implemented.") 30 | def test_double_messages_conversation(self, model: BaseChatModel) -> None: 31 | super().test_double_messages_conversation(model) 32 | 33 | 34 | class TestBedrockUseConverseStandard(ChatModelIntegrationTests): 35 | @property 36 | def chat_model_class(self) -> Type[BaseChatModel]: 37 | return ChatBedrock 38 | 39 | @property 40 | def chat_model_params(self) -> dict: 41 | return { 42 | "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", 43 | "beta_use_converse_api": True, 44 | } 45 | 46 | @property 47 | def standard_chat_model_params(self) -> dict: 48 | return { 49 | "temperature": 0, 50 | "max_tokens": 100, 51 | "stop_sequences": [], 52 | "model_kwargs": { 53 | "stop": [], 54 | }, 55 | } 56 | 57 | @property 58 | def supports_image_inputs(self) -> bool: 59 | return True 60 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/embeddings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/embeddings/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | import numpy as np 3 | import pytest 4 | 5 | from langchain_aws import BedrockEmbeddings 6 | from langchain_aws.embeddings import bedrock 7 | 8 | 9 | @pytest.fixture 10 | def bedrock_embeddings() -> BedrockEmbeddings: 11 | return BedrockEmbeddings(model_id="amazon.titan-embed-text-v1") 12 | 13 | 14 | @pytest.fixture 15 | def bedrock_embeddings_v2() -> BedrockEmbeddings: 16 | return BedrockEmbeddings( 17 | model_id="amazon.titan-embed-text-v2:0", 18 | model_kwargs={"dimensions": 256, "normalize": True}, 19 | ) 20 | 21 | 22 | @pytest.fixture 23 | def cohere_embeddings_v3() -> BedrockEmbeddings: 24 | return BedrockEmbeddings( 25 | model_id="cohere.embed-english-v3", 26 | ) 27 | 28 | 29 | @pytest.mark.scheduled 30 | def test_bedrock_embedding_documents(bedrock_embeddings) -> None: 31 | documents = ["foo bar"] 32 | output = bedrock_embeddings.embed_documents(documents) 33 | assert len(output) == 1 34 | assert len(output[0]) == 1536 35 | 36 | 37 | @pytest.mark.scheduled 38 | def test_bedrock_embedding_documents_with_v2(bedrock_embeddings_v2) -> None: 39 | documents = ["foo bar"] 40 | output = bedrock_embeddings_v2.embed_documents(documents) 41 | assert len(output) == 1 42 | assert len(output[0]) == 256 43 | 44 | 45 | @pytest.mark.scheduled 46 | def test_bedrock_embedding_documents_multiple(bedrock_embeddings) -> None: 47 | documents = ["foo bar", "bar foo", "foo"] 48 | output = bedrock_embeddings.embed_documents(documents) 49 | assert len(output) == 3 50 | assert len(output[0]) == 1536 51 | assert len(output[1]) == 1536 52 | assert len(output[2]) == 1536 53 | 54 | 55 | @pytest.mark.scheduled 56 | async def test_bedrock_embedding_documents_async_multiple(bedrock_embeddings) -> None: 57 | documents = 
["foo bar", "bar foo", "foo"] 58 | output = await bedrock_embeddings.aembed_documents(documents) 59 | assert len(output) == 3 60 | assert len(output[0]) == 1536 61 | assert len(output[1]) == 1536 62 | assert len(output[2]) == 1536 63 | 64 | 65 | @pytest.mark.scheduled 66 | def test_bedrock_embedding_query(bedrock_embeddings) -> None: 67 | document = "foo bar" 68 | output = bedrock_embeddings.embed_query(document) 69 | assert len(output) == 1536 70 | 71 | 72 | @pytest.mark.scheduled 73 | async def test_bedrock_embedding_async_query(bedrock_embeddings) -> None: 74 | document = "foo bar" 75 | output = await bedrock_embeddings.aembed_query(document) 76 | assert len(output) == 1536 77 | 78 | 79 | @pytest.mark.skip(reason="Unblock scheduled testing. TODO: fix.") 80 | @pytest.mark.scheduled 81 | def test_bedrock_embedding_with_empty_string(bedrock_embeddings) -> None: 82 | document = ["", "abc"] 83 | output = bedrock_embeddings.embed_documents(document) 84 | assert len(output) == 2 85 | assert len(output[0]) == 1536 86 | 87 | 88 | @pytest.mark.scheduled 89 | def test_embed_documents_normalized(bedrock_embeddings) -> None: 90 | bedrock_embeddings.normalize = True 91 | output = bedrock_embeddings.embed_documents(["foo walked to the market"]) 92 | assert np.isclose(np.linalg.norm(output[0]), 1.0) 93 | 94 | 95 | @pytest.mark.scheduled 96 | def test_embed_query_normalized(bedrock_embeddings) -> None: 97 | bedrock_embeddings.normalize = True 98 | output = bedrock_embeddings.embed_query("foo walked to the market") 99 | assert np.isclose(np.linalg.norm(output), 1.0) 100 | 101 | 102 | @pytest.mark.scheduled 103 | def test_embed_query_with_size(bedrock_embeddings_v2) -> None: 104 | prompt_data = """Priority should be funding retirement through ROTH/IRA/401K 105 | over HAS extra. You need to fund your HAS for reasonable and expected medical 106 | expenses. 
107 | """ 108 | response = bedrock_embeddings_v2.embed_documents([prompt_data]) 109 | output = bedrock_embeddings_v2.embed_query(prompt_data) 110 | assert len(response[0]) == 256 111 | assert len(output) == 256 112 | 113 | 114 | @pytest.mark.scheduled 115 | def test_bedrock_cohere_embedding_documents(cohere_embeddings_v3) -> None: 116 | documents = ["foo bar"] 117 | output = cohere_embeddings_v3.embed_documents(documents) 118 | assert len(output) == 1 119 | assert len(output[0]) == 1024 120 | 121 | 122 | @pytest.mark.scheduled 123 | def test_bedrock_cohere_embedding_documents_multiple(cohere_embeddings_v3) -> None: 124 | documents = ["foo bar", "bar foo", "foo"] 125 | output = cohere_embeddings_v3.embed_documents(documents) 126 | assert len(output) == 3 127 | assert len(output[0]) == 1024 128 | assert len(output[1]) == 1024 129 | assert len(output[2]) == 1024 130 | 131 | 132 | @pytest.mark.scheduled 133 | def test_bedrock_cohere_batching() -> None: 134 | # Test maximum text batch 135 | documents = [f"{val}" for val in range(200)] 136 | assert len(list(bedrock._batch_cohere_embedding_texts(documents))) == 3 137 | 138 | # Test large character batch 139 | large_char_batch = ["foo", "bar", "a" * 2045, "baz"] 140 | assert list(bedrock._batch_cohere_embedding_texts(large_char_batch)) == [ 141 | ["foo", "bar"], 142 | ["a" * 2045, "baz"], 143 | ] 144 | 145 | # Should be fine with exactly 2048 characters 146 | assert list(bedrock._batch_cohere_embedding_texts(["a" * 2048])) == [["a" * 2048]] 147 | 148 | # But raise an error if it's more than that 149 | with pytest.raises(ValueError): 150 | list(bedrock._batch_cohere_embedding_texts(["a" * 2049])) 151 | 152 | 153 | @pytest.mark.scheduled 154 | def test_bedrock_cohere_embedding_large_document_set(cohere_embeddings_v3) -> None: 155 | lots_of_documents = 200 156 | documents = [f"text_{val}" for val in range(lots_of_documents)] 157 | output = cohere_embeddings_v3.embed_documents(documents) 158 | assert len(output) == 200 159 | assert len(output[0]) == 1024 160 | assert len(output[1]) == 1024 161 | assert len(output[2]) == 1024 162 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/graphs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/graphs/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/llms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/llms/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/llms/test_bedrock.py: -------------------------------------------------------------------------------- 1 | from langchain_aws import BedrockLLM 2 | 3 | 4 | def test_bedrock_llm() -> None: 5 | llm = BedrockLLM(model_id="anthropic.claude-v2:1") # type: ignore[call-arg] 6 | response = llm.invoke("Hello") 7 | assert isinstance(response, str) 8 | assert len(response) > 0 9 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/llms/test_sagemaker_endpoint.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict 3 | from unittest.mock import Mock 4 | 5 | from langchain_aws.llms import SagemakerEndpoint 6 | from langchain_aws.llms.sagemaker_endpoint import LLMContentHandler 7 | 8 | 9 | class DefaultHandler(LLMContentHandler): 10 | accepts = "application/json" 11 | content_type = "application/json" 12 | 13 | def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: 14 | return prompt.encode() 15 | 16 | def transform_output(self, output: bytes) -> str: 17 | body = json.loads(output.decode()) 18 | return body[0]["generated_text"] 19 | 20 | 21 | def test_sagemaker_endpoint_invoke() -> None: 22 | client = Mock() 23 | response = { 24 | "ContentType": "application/json", 25 | "Body": b'[{"generated_text": "SageMaker Endpoint"}]', 26 | } 27 | client.invoke_endpoint.return_value = response 28 | 29 | llm = SagemakerEndpoint( 30 | endpoint_name="my-endpoint", 31 | region_name="us-west-2", 32 | content_handler=DefaultHandler(), 33 | model_kwargs={ 34 | "parameters": { 35 | "max_new_tokens": 50, 36 | } 37 | }, 38 | client=client, 39 | ) 40 | 41 | service_response = llm.invoke("What is Sagemaker endpoints?") 42 | 43 | assert service_response == "SageMaker Endpoint" 44 | client.invoke_endpoint.assert_called_once_with( 45 | EndpointName="my-endpoint", 46 | Body=b"What is Sagemaker endpoints?", 47 | ContentType="application/json", 48 | Accept="application/json", 49 | ) 50 | 51 | 52 | def test_sagemaker_endpoint_inference_component_invoke() -> None: 53 | client = Mock() 54 | response = { 55 | "ContentType": "application/json", 56 | "Body": b'[{"generated_text": "SageMaker Endpoint"}]', 57 | } 58 | client.invoke_endpoint.return_value = response 59 | 60 | llm = SagemakerEndpoint( 61 | endpoint_name="my-endpoint", 62 | inference_component_name="my-inference-component", 63 | region_name="us-west-2", 64 | content_handler=DefaultHandler(), 65 | model_kwargs={ 66 | "parameters": { 67 | "max_new_tokens": 50, 68 | } 69 | }, 70 | client=client, 71 | ) 72 | 73 | service_response = llm.invoke("What is Sagemaker endpoints?") 74 | 75 | assert service_response == "SageMaker Endpoint" 76 | client.invoke_endpoint.assert_called_once_with( 77 | EndpointName="my-endpoint", 78 | Body=b"What is Sagemaker endpoints?", 79 | ContentType="application/json", 80 | Accept="application/json", 81 | InferenceComponentName="my-inference-component", 82 | ) 83 | 84 | 85 | def test_sagemaker_endpoint_stream() -> None: 86 | class ContentHandler(LLMContentHandler): 87 | accepts = "application/json" 88 | content_type = "application/json" 89 | 90 | def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: 91 | body = json.dumps({"inputs": prompt, **model_kwargs}) 92 | return body.encode() 93 | 94 | def transform_output(self, output: bytes) -> str: 95 | body = json.loads(output) 96 | return body.get("outputs")[0] 97 | 98 | body = ( 99 | {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}}, 100 | {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}}, 101 | {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}}, 102 | ) 103 | 104 | response = {"ContentType": "application/json", "Body": body} 105 | 106 | client = Mock() 107 | client.invoke_endpoint_with_response_stream.return_value = response 108 | 109 | llm = SagemakerEndpoint( 110 | endpoint_name="my-endpoint", 111 | region_name="us-west-2", 112 | content_handler=ContentHandler(), 113 | client=client, 114 | model_kwargs={"parameters": {"max_new_tokens": 50}}, 115 
| ) 116 | 117 | expected_body = json.dumps( 118 | {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}} 119 | ).encode() 120 | 121 | chunks = ["S", "age", "Maker"] 122 | service_chunks = [] 123 | 124 | for chunk in llm.stream("What is Sagemaker endpoints?"): 125 | service_chunks.append(chunk) 126 | 127 | assert service_chunks == chunks 128 | client.invoke_endpoint_with_response_stream.assert_called_once_with( 129 | EndpointName="my-endpoint", 130 | Body=expected_body, 131 | ContentType="application/json", 132 | ) 133 | 134 | 135 | def test_sagemaker_endpoint_inference_component_stream() -> None: 136 | class ContentHandler(LLMContentHandler): 137 | accepts = "application/json" 138 | content_type = "application/json" 139 | 140 | def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: 141 | body = json.dumps({"inputs": prompt, **model_kwargs}) 142 | return body.encode() 143 | 144 | def transform_output(self, output: bytes) -> str: 145 | body = json.loads(output) 146 | return body.get("outputs")[0] 147 | 148 | body = ( 149 | {"PayloadPart": {"Bytes": b'{"outputs": ["S"]}\n'}}, 150 | {"PayloadPart": {"Bytes": b'{"outputs": ["age"]}\n'}}, 151 | {"PayloadPart": {"Bytes": b'{"outputs": ["Maker"]}\n'}}, 152 | ) 153 | 154 | response = {"ContentType": "application/json", "Body": body} 155 | 156 | client = Mock() 157 | client.invoke_endpoint_with_response_stream.return_value = response 158 | 159 | llm = SagemakerEndpoint( 160 | endpoint_name="my-endpoint", 161 | inference_component_name="my_inference_component", 162 | region_name="us-west-2", 163 | content_handler=ContentHandler(), 164 | client=client, 165 | model_kwargs={"parameters": {"max_new_tokens": 50}}, 166 | ) 167 | 168 | expected_body = json.dumps( 169 | {"inputs": "What is Sagemaker endpoints?", "parameters": {"max_new_tokens": 50}} 170 | ).encode() 171 | 172 | chunks = ["S", "age", "Maker"] 173 | service_chunks = [] 174 | 175 | for chunk in llm.stream("What is Sagemaker endpoints?"): 176 | service_chunks.append(chunk) 177 | 178 | assert service_chunks == chunks 179 | client.invoke_endpoint_with_response_stream.assert_called_once_with( 180 | EndpointName="my-endpoint", 181 | Body=expected_body, 182 | ContentType="application/json", 183 | InferenceComponentName="my_inference_component", 184 | ) 185 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/retrievers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/integration_tests/retrievers/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/retrievers/test_amazon_kendra_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | 6 | from langchain_aws import AmazonKendraRetriever 7 | from langchain_aws.retrievers.kendra import RetrieveResultItem 8 | 9 | 10 | @pytest.fixture 11 | def mock_client() -> Mock: 12 | mock_client = Mock() 13 | return mock_client 14 | 15 | 16 | @pytest.fixture 17 | def retriever(mock_client: Any) -> AmazonKendraRetriever: 18 | return AmazonKendraRetriever( 19 | index_id="test_index_id", client=mock_client, top_k=3, min_score_confidence=0.6 20 | ) 21 | 22 | 23 | def test_get_relevant_documents(retriever, 
mock_client) -> None: # type: ignore[no-untyped-def] 24 | # Mock data for Kendra response 25 | mock_retrieve_result = { 26 | "QueryId": "test_query_id", 27 | "ResultItems": [ 28 | RetrieveResultItem( 29 | Id="doc1", 30 | DocumentId="doc1", 31 | DocumentURI="https://example.com/doc1", 32 | DocumentTitle="Document 1", 33 | Content="This is the content of Document 1.", 34 | ScoreAttributes={"ScoreConfidence": "HIGH"}, 35 | ), 36 | RetrieveResultItem( 37 | Id="doc2", 38 | DocumentId="doc2", 39 | DocumentURI="https://example.com/doc2", 40 | DocumentTitle="Document 2", 41 | Content="This is the content of Document 2.", 42 | ScoreAttributes={"ScoreConfidence": "MEDIUM"}, 43 | ), 44 | RetrieveResultItem( 45 | Id="doc3", 46 | DocumentId="doc3", 47 | DocumentURI="https://example.com/doc3", 48 | DocumentTitle="Document 3", 49 | Content="This is the content of Document 3.", 50 | ScoreAttributes={"ScoreConfidence": "HIGH"}, 51 | ), 52 | ], 53 | } 54 | 55 | mock_client.retrieve.return_value = mock_retrieve_result 56 | 57 | query = "test query" 58 | 59 | docs = retriever.invoke(query) 60 | 61 | # Only documents with confidence score of HIGH are returned 62 | assert len(docs) == 2 63 | assert docs[0].page_content == ( 64 | "Document Title: Document 1\nDocument Excerpt: \n" 65 | "This is the content of Document 1.\n" 66 | ) 67 | assert docs[1].page_content == ( 68 | "Document Title: Document 3\nDocument Excerpt: \n" 69 | "This is the content of Document 3.\n" 70 | ) 71 | 72 | # Assert that the mock methods were called with the expected arguments 73 | mock_client.retrieve.assert_called_with( 74 | IndexId="test_index_id", QueryText="test query", PageSize=3 75 | ) 76 | mock_client.query.assert_not_called() 77 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import pytest 4 | from langchain_core.documents import Document 5 | 6 | from langchain_aws import AmazonKnowledgeBasesRetriever 7 | 8 | 9 | @pytest.fixture 10 | def mock_client() -> Mock: 11 | return Mock() 12 | 13 | 14 | @pytest.fixture 15 | def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever: 16 | return AmazonKnowledgeBasesRetriever( 17 | knowledge_base_id="test-knowledge-base", 18 | client=mock_client, 19 | retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type] 20 | min_score_confidence=0.0, 21 | ) 22 | 23 | 24 | def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore[no-untyped-def] 25 | response = { 26 | "retrievalResults": [ 27 | { 28 | "content": {"text": "This is the first result."}, 29 | "location": "location1", 30 | "score": 0.9, 31 | }, 32 | { 33 | "content": {"text": "This is the second result."}, 34 | "location": "location2", 35 | "score": 0.8, 36 | }, 37 | {"content": {"text": "This is the third result."}, "location": "location3"}, 38 | { 39 | "content": {"text": "This is the fourth result."}, 40 | "metadata": {"key1": "value1", "key2": "value2"}, 41 | }, 42 | ] 43 | } 44 | mock_client.retrieve.return_value = response 45 | 46 | query = "test query" 47 | 48 | expected_documents = [ 49 | Document( 50 | page_content="This is the first result.", 51 | metadata={"location": "location1", "score": 0.9, "type": "TEXT"}, 52 | ), 53 | Document( 54 | page_content="This is the second result.", 55 | metadata={"location": "location2", 
"score": 0.8, "type": "TEXT"}, 56 | ), 57 | Document( 58 | page_content="This is the third result.", 59 | metadata={"location": "location3", "score": 0.0, "type": "TEXT"}, 60 | ), 61 | Document( 62 | page_content="This is the fourth result.", 63 | metadata={ 64 | "type": "TEXT", 65 | "score": 0.0, 66 | "source_metadata": { 67 | "key1": "value1", 68 | "key2": "value2", 69 | }, 70 | }, 71 | ), 72 | ] 73 | 74 | documents = retriever.invoke(query) 75 | 76 | assert documents == expected_documents 77 | 78 | mock_client.retrieve.assert_called_once_with( 79 | retrievalQuery={"text": "test query"}, 80 | knowledgeBaseId="test-knowledge-base", 81 | retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}}, 82 | ) 83 | 84 | 85 | def test_get_relevant_documents_with_score(retriever, mock_client) -> None: # type: ignore[no-untyped-def] 86 | response = { 87 | "retrievalResults": [ 88 | { 89 | "content": {"text": "This is the first result."}, 90 | "location": "location1", 91 | "score": 0.9, 92 | }, 93 | { 94 | "content": {"text": "This is the second result."}, 95 | "location": "location2", 96 | "score": 0.8, 97 | }, 98 | {"content": {"text": "This is the third result."}, "location": "location3"}, 99 | { 100 | "content": {"text": "This is the fourth result."}, 101 | "metadata": {"key1": "value1", "key2": "value2"}, 102 | }, 103 | ] 104 | } 105 | mock_client.retrieve.return_value = response 106 | 107 | query = "test query" 108 | 109 | expected_documents = [ 110 | Document( 111 | page_content="This is the first result.", 112 | metadata={"location": "location1", "score": 0.9, "type": "TEXT"}, 113 | ), 114 | Document( 115 | page_content="This is the second result.", 116 | metadata={"location": "location2", "score": 0.8, "type": "TEXT"}, 117 | ), 118 | ] 119 | 120 | retriever.min_score_confidence = 0.80 121 | documents = retriever.invoke(query) 122 | 123 | assert documents == expected_documents 124 | -------------------------------------------------------------------------------- /libs/aws/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | """All unit tests (lightweight tests).""" 2 | 3 | from typing import Any 4 | 5 | 6 | def assert_all_importable(module: Any) -> None: 7 | for attr in module.__all__: 8 | getattr(module, attr) 9 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/__snapshots__/test_standard.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: TestBedrockAsConverseStandard.test_serdes[serialized] 3 | dict({ 4 | 'id': list([ 5 | 'langchain', 6 | 'chat_models', 7 | 'bedrock', 8 | 'ChatBedrock', 9 | ]), 10 | 'kwargs': dict({ 11 | 'beta_use_converse_api': True, 12 | 'guardrails': dict({ 13 | 'guardrailIdentifier': None, 14 | 'guardrailVersion': None, 15 | 'trace': None, 16 | }), 17 | 'max_tokens': 100, 18 | 'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0', 19 | 'model_kwargs': dict({ 20 | 'stop': list([ 21 | ]), 22 | }), 23 | 'provider_stop_reason_key_map': dict({ 24 | 'ai21': 'finishReason', 25 | 
'amazon': 'completionReason', 26 | 'anthropic': 'stop_reason', 27 | 'cohere': 'finish_reason', 28 | 'mistral': 'stop_reason', 29 | }), 30 | 'provider_stop_sequence_key_name_map': dict({ 31 | 'ai21': 'stop_sequences', 32 | 'amazon': 'stopSequences', 33 | 'anthropic': 'stop_sequences', 34 | 'cohere': 'stop_sequences', 35 | 'mistral': 'stop_sequences', 36 | }), 37 | 'region_name': 'us-east-1', 38 | 'temperature': 0, 39 | }), 40 | 'lc': 1, 41 | 'name': 'ChatBedrock', 42 | 'type': 'constructor', 43 | }) 44 | # --- 45 | # name: TestBedrockStandard.test_serdes[serialized] 46 | dict({ 47 | 'id': list([ 48 | 'langchain', 49 | 'chat_models', 50 | 'bedrock', 51 | 'ChatBedrock', 52 | ]), 53 | 'kwargs': dict({ 54 | 'guardrails': dict({ 55 | 'guardrailIdentifier': None, 56 | 'guardrailVersion': None, 57 | 'trace': None, 58 | }), 59 | 'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0', 60 | 'model_kwargs': dict({ 61 | }), 62 | 'provider_stop_reason_key_map': dict({ 63 | 'ai21': 'finishReason', 64 | 'amazon': 'completionReason', 65 | 'anthropic': 'stop_reason', 66 | 'cohere': 'finish_reason', 67 | 'mistral': 'stop_reason', 68 | }), 69 | 'provider_stop_sequence_key_name_map': dict({ 70 | 'ai21': 'stop_sequences', 71 | 'amazon': 'stopSequences', 72 | 'anthropic': 'stop_sequences', 73 | 'cohere': 'stop_sequences', 74 | 'mistral': 'stop_sequences', 75 | }), 76 | 'region_name': 'us-east-1', 77 | }), 78 | 'lc': 1, 79 | 'name': 'ChatBedrock', 80 | 'type': 'constructor', 81 | }) 82 | # --- 83 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/agents/test_utils.py: -------------------------------------------------------------------------------- 1 | from botocore.config import Config 2 | 3 | from langchain_aws.agents.utils import SDK_USER_AGENT, get_boto_session 4 | 5 | 6 | def test_get_boto3_session() -> None: 7 | client_params, session = get_boto_session() 8 | assert "config" in client_params 9 | config = client_params["config"] 10 | assert SDK_USER_AGENT in config.user_agent_extra 11 | 12 | 13 | def test_get_boto_session_with_config() -> None: 14 | # Set default client parameters 15 | fake_config = Config( 16 | connect_timeout=240, 17 | read_timeout=240, 18 | retries={"max_attempts": 1}, 19 | ) 20 | client_params, session = get_boto_session(config=fake_config) 21 | assert "config" in client_params 22 | config = client_params["config"] 23 | assert SDK_USER_AGENT in config.user_agent_extra 24 | assert config.connect_timeout == fake_config.connect_timeout 25 | assert config.read_timeout == fake_config.read_timeout 26 | assert config.retries["max_attempts"] == fake_config.retries["max_attempts"] 27 | 28 | 29 | def test_get_boto_session_with_user_agent() -> None: 30 | # Set default client parameters 31 | fake_config = Config( 32 | connect_timeout=240, 33 | read_timeout=240, 34 | retries={"max_attempts": 1}, 35 | user_agent_extra="MY_USER_AGENT_EXTRA", 36 | ) 37 | client_params, session = get_boto_session(config=fake_config) 38 | assert "config" in client_params 39 | config = client_params["config"] 40 | assert SDK_USER_AGENT in config.user_agent_extra 41 | assert fake_config.user_agent_extra in config.user_agent_extra 42 | assert config.connect_timeout == fake_config.connect_timeout 43 | assert config.read_timeout == fake_config.read_timeout 44 | assert config.retries["max_attempts"] == fake_config.retries["max_attempts"] 45 | -------------------------------------------------------------------------------- 
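The `get_boto_session` tests above verify two things: a caller-supplied botocore `Config` is preserved (timeouts, retries), and the package's `SDK_USER_AGENT` suffix is appended to `user_agent_extra` rather than replacing it. A minimal usage sketch consistent with those assertions follows; it assumes the returned `session` is a standard `boto3.Session`, and the `"bedrock-agent-runtime"` service name is an illustrative placeholder, not taken from the test file:

```python
# Illustrative sketch based only on the assertions in test_utils.py above.
# Assumptions: the returned `session` behaves like a boto3.Session, and
# "bedrock-agent-runtime" is just a placeholder service name.
from botocore.config import Config

from langchain_aws.agents.utils import SDK_USER_AGENT, get_boto_session

# Caller-supplied configuration, including a custom user-agent suffix.
custom_config = Config(
    connect_timeout=240,
    read_timeout=240,
    retries={"max_attempts": 1},
    user_agent_extra="my-app/1.0",
)

client_params, session = get_boto_session(config=custom_config)

# The merged Config keeps the caller's settings and user agent while adding
# the package's SDK user-agent suffix, as the tests above assert.
merged: Config = client_params["config"]
assert SDK_USER_AGENT in merged.user_agent_extra
assert "my-app/1.0" in merged.user_agent_extra
assert merged.connect_timeout == 240

# Hand the merged Config to a client; the service name is a placeholder.
client = session.client("bedrock-agent-runtime", config=merged)
```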
/libs/aws/tests/unit_tests/chat_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/unit_tests/chat_models/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/chat_models/__snapshots__/test_bedrock_converse.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: TestBedrockStandard.test_serdes[serialized] 3 | dict({ 4 | 'id': list([ 5 | 'langchain_aws', 6 | 'chat_models', 7 | 'ChatBedrockConverse', 8 | ]), 9 | 'kwargs': dict({ 10 | 'max_tokens': 100, 11 | 'model_id': 'anthropic.claude-3-sonnet-20240229-v1:0', 12 | 'provider': 'anthropic', 13 | 'region_name': 'us-west-1', 14 | 'stop_sequences': list([ 15 | ]), 16 | 'supports_tool_choice_values': list([ 17 | 'auto', 18 | 'any', 19 | 'tool', 20 | ]), 21 | 'temperature': 0.0, 22 | }), 23 | 'lc': 1, 24 | 'name': 'ChatBedrockConverse', 25 | 'type': 'constructor', 26 | }) 27 | # --- -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/chat_models/test_sagemaker_endpoint.py: -------------------------------------------------------------------------------- 1 | # type:ignore 2 | """Test chat model integration.""" 3 | import json 4 | from typing import Dict 5 | from unittest.mock import Mock 6 | 7 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage 8 | 9 | from langchain_aws.chat_models.sagemaker_endpoint import ( 10 | ChatModelContentHandler, 11 | ChatSagemakerEndpoint, 12 | _messages_to_sagemaker, 13 | ) 14 | 15 | 16 | class DefaultHandler(ChatModelContentHandler): 17 | content_type = "application/json" 18 | accepts = "application/json" 19 | 20 | def transform_input(self, prompt, model_kwargs: Dict) -> bytes: 21 | return json.dumps(prompt).encode("utf-8") 22 | 23 | def transform_output(self, output: bytes) -> str: 24 | response_json = json.loads(output.decode()) 25 | return AIMessage(content=response_json[0]["generated_text"]) 26 | 27 | 28 | def test_format_messages_request() -> None: 29 | client = Mock() 30 | messages = [ 31 | SystemMessage("Output everything you have."), # type: ignore[misc] 32 | HumanMessage("What is an llm?"), # type: ignore[misc] 33 | ] 34 | kwargs = {} 35 | 36 | llm = ChatSagemakerEndpoint( 37 | endpoint_name="my-endpoint", 38 | region_name="us-west-2", 39 | content_handler=DefaultHandler(), 40 | model_kwargs={ 41 | "parameters": { 42 | "max_new_tokens": 50, 43 | } 44 | }, 45 | client=client, 46 | ) 47 | invocation_params = llm._format_messages_request(messages=messages, **kwargs) 48 | 49 | expected_invocation_params = { 50 | "EndpointName": "my-endpoint", 51 | "Body": b"""[{"role": "system", "content": "Output everything you have."}, {"role": "user", "content": "What is an llm?"}]""", 52 | "ContentType": "application/json", 53 | "Accept": "application/json", 54 | } 55 | assert invocation_params == expected_invocation_params 56 | 57 | 58 | def test__messages_to_sagemaker() -> None: 59 | messages = [ 60 | SystemMessage("foo"), # type: ignore[misc] 61 | HumanMessage("bar"), # type: ignore[misc] 62 | AIMessage("some answer"), 63 | HumanMessage("follow-up question"), # type: ignore[misc] 64 | ] 65 | expected = [ 66 | {"role": "system", "content": "foo"}, 67 | {"role": "user", "content": "bar"}, 68 | {"role": "assistant", "content": "some answer"}, 69 | 
{"role": "user", "content": "follow-up question"}, 70 | ] 71 | actual = _messages_to_sagemaker(messages) 72 | assert expected == actual 73 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/document_compressors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/unit_tests/document_compressors/__init__.py -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/document_compressors/test_rerank.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | from langchain_core.documents import Document 5 | 6 | from langchain_aws.document_compressors.rerank import BedrockRerank 7 | 8 | 9 | @pytest.fixture 10 | def reranker() -> BedrockRerank: 11 | reranker = BedrockRerank( 12 | model_arn="arn:aws:bedrock:us-west-2::foundation-model/amazon.rerank-v1:0", 13 | region_name="us-east-1", 14 | ) 15 | reranker.client = MagicMock() 16 | return reranker 17 | 18 | 19 | @patch("boto3.Session") 20 | def test_initialize_client( 21 | mock_boto_session: MagicMock, reranker: BedrockRerank 22 | ) -> None: 23 | session_instance = MagicMock() 24 | mock_boto_session.return_value = session_instance 25 | session_instance.client.return_value = MagicMock() 26 | assert reranker.client is not None 27 | 28 | 29 | @patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank") 30 | def test_rerank(mock_rerank: MagicMock, reranker: BedrockRerank) -> None: 31 | mock_rerank.return_value = [ 32 | {"index": 0, "relevance_score": 0.9}, 33 | {"index": 1, "relevance_score": 0.8}, 34 | ] 35 | 36 | documents = [Document(page_content="Doc 1"), Document(page_content="Doc 2")] 37 | query = "Example Query" 38 | results = reranker.rerank(documents, query) 39 | 40 | assert len(results) == 2 41 | assert results[0]["index"] == 0 42 | assert results[0]["relevance_score"] == 0.9 43 | assert results[1]["index"] == 1 44 | assert results[1]["relevance_score"] == 0.8 45 | 46 | 47 | @patch("langchain_aws.document_compressors.rerank.BedrockRerank.rerank") 48 | def test_compress_documents(mock_rerank: MagicMock, reranker: BedrockRerank) -> None: 49 | mock_rerank.return_value = [ 50 | {"index": 0, "relevance_score": 0.95}, 51 | {"index": 1, "relevance_score": 0.85}, 52 | ] 53 | 54 | documents = [ 55 | Document(page_content="Content 1", id="doc1"), 56 | Document(page_content="Content 2", id="doc2"), 57 | ] 58 | query = "Relevant query" 59 | compressed_docs = reranker.compress_documents(documents, query) 60 | 61 | assert compressed_docs[0].id == "doc1" 62 | assert compressed_docs[0].page_content == "Content 1" 63 | assert compressed_docs[1].id == "doc2" 64 | assert compressed_docs[1].page_content == "Content 2" 65 | 66 | assert len(compressed_docs) == 2 67 | assert compressed_docs[0].metadata["relevance_score"] == 0.95 68 | assert compressed_docs[1].metadata["relevance_score"] == 0.85 69 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/llms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/aws/tests/unit_tests/llms/__init__.py 
-------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | from pathlib import Path 4 | 5 | 6 | def test_importable_all() -> None: 7 | for path in glob.glob("../langchain_aws/*"): 8 | relative_path = Path(path).parts[-1] 9 | if relative_path.endswith(".typed"): 10 | continue 11 | module_name = relative_path.split(".")[0] 12 | module = importlib.import_module("langchain_aws." + module_name) 13 | all_ = getattr(module, "__all__", []) 14 | for cls_ in all_: 15 | getattr(module, cls_) 16 | -------------------------------------------------------------------------------- /libs/aws/tests/unit_tests/test_standard.py: -------------------------------------------------------------------------------- 1 | """Standard LangChain interface tests""" 2 | 3 | from typing import Type 4 | 5 | import pytest 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_tests.unit_tests import ChatModelUnitTests 8 | 9 | from langchain_aws.chat_models.bedrock import ChatBedrock 10 | 11 | 12 | class TestBedrockStandard(ChatModelUnitTests): 13 | @property 14 | def chat_model_class(self) -> Type[BaseChatModel]: 15 | return ChatBedrock 16 | 17 | @property 18 | def chat_model_params(self) -> dict: 19 | return { 20 | "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", 21 | "region_name": "us-east-1", 22 | } 23 | 24 | @property 25 | def standard_chat_model_params(self) -> dict: 26 | return {} 27 | 28 | @pytest.mark.xfail(reason="Not implemented.") 29 | def test_standard_params(self, model: BaseChatModel) -> None: 30 | super().test_standard_params(model) 31 | 32 | 33 | class TestBedrockAsConverseStandard(ChatModelUnitTests): 34 | @property 35 | def chat_model_class(self) -> Type[BaseChatModel]: 36 | return ChatBedrock 37 | 38 | @property 39 | def chat_model_params(self) -> dict: 40 | return { 41 | "model_id": "anthropic.claude-3-sonnet-20240229-v1:0", 42 | "region_name": "us-east-1", 43 | "beta_use_converse_api": True, 44 | } 45 | 46 | @property 47 | def standard_chat_model_params(self) -> dict: 48 | return { 49 | "model_kwargs": { 50 | "temperature": 0, 51 | "max_tokens": 100, 52 | "stop": [], 53 | } 54 | } 55 | 56 | @pytest.mark.xfail(reason="Not implemented.") 57 | def test_standard_params(self, model: BaseChatModel) -> None: 58 | super().test_standard_params(model) 59 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache_test/* 3 | .coverage -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
-------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | 31 | ## Finding contributions to work on 32 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 33 | 34 | 35 | ## Code of Conduct 36 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 37 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 38 | opensource-codeofconduct@amazon.com with any additional questions or comments. 39 | 40 | 41 | ## Security issue notifications 42 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 43 | 44 | 45 | ## Licensing 46 | 47 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) LangChain, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/Makefile: -------------------------------------------------------------------------------- 1 | ###################### 2 | # VARIABLES 3 | ###################### 4 | 5 | PYTHON_FILES=. 6 | MYPY_CACHE=.mypy_cache 7 | PACKAGE_NAME=langgraph_checkpoint_aws 8 | 9 | .PHONY: help lint format install_dev install_test install_lint install_typing install_codespell check_imports spell_check spell_fix 10 | 11 | ###################### 12 | # LINTING AND FORMATTING 13 | ###################### 14 | 15 | # Define different lint targets 16 | lint format: PYTHON_FILES=. 17 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 18 | lint_package: PYTHON_FILES=$(PACKAGE_NAME) 19 | lint_tests: PYTHON_FILES=tests 20 | lint_tests: MYPY_CACHE=.mypy_cache_test 21 | 22 | lint: ## Run linter 23 | poetry run ruff check $(PYTHON_FILES) 24 | 25 | lint_fix: ## Run linter and fix issues 26 | poetry run ruff check --fix $(PYTHON_FILES) 27 | 28 | lint_diff: ## Run linter on changed files 29 | poetry run ruff check $(PYTHON_FILES) 30 | 31 | lint_package: ## Run linter on package 32 | poetry run ruff check --select I $(PYTHON_FILES) 33 | 34 | lint_tests: ## Run type checking on tests 35 | mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 36 | 37 | format: ## Run code formatter 38 | poetry run ruff format $(PYTHON_FILES) 39 | 40 | format_diff: ## Run code formatter and show differences 41 | poetry run ruff format $(PYTHON_FILES) --diff 42 | 43 | spell_check: ## Run code spell check 44 | poetry run codespell --toml pyproject.toml 45 | 46 | spell_fix: ## Run code spell fix 47 | poetry run codespell --toml pyproject.toml -w 48 | 49 | ###################### 50 | # TESTING 51 | ###################### 52 | 53 | # Define a variable for the test file path. 54 | test test_watch tests: TEST_FILE ?= tests/unit_tests/ 55 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 56 | 57 | # Define a variable for Python and notebook files. 58 | PYTHON_FILES=. 
59 | 60 | tests: ## Run all unit tests 61 | poetry run pytest $(TEST_FILE) 62 | 63 | test: ## Run individual unit test: make test TEST_FILE=tests/unit_test/test.py 64 | poetry run pytest $(TEST_FILE) 65 | 66 | integration_tests: ## Run all integration tests 67 | poetry run pytest $(TEST_FILE) 68 | 69 | integration_test: ## Run individual integration test: make integration_test TEST_FILE=tests/integration_tests/integ_test.py 70 | poetry run pytest $(TEST_FILE) 71 | 72 | test_watch: ## Run and interactively watch unit tests 73 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 74 | 75 | ###################### 76 | # DEPENDENCIES 77 | ###################### 78 | 79 | install: ## Install package 80 | @pip install --no-cache -U poetry 81 | @poetry install 82 | 83 | install_dev: ## Install development environment 84 | @pip install --no-cache -U poetry 85 | @poetry install --with dev 86 | 87 | install_test: ## Install test dependencies 88 | @pip install --no-cache -U poetry 89 | @poetry install --with test 90 | 91 | install_lint: ## Install lint dependencies 92 | @pip install --no-cache -U poetry 93 | @poetry install --with lint 94 | 95 | install_typing: ## Install typing dependencies 96 | @pip install --no-cache -U poetry 97 | @poetry install --with typing 98 | 99 | install_codespell: ## Install codespell dependencies 100 | @pip install --no-cache -U poetry 101 | @poetry install --with codespell 102 | 103 | install_all: ## Install all dependencies including optional groups 104 | @pip install --no-cache -U poetry 105 | @poetry install --with dev,test,lint,typing,codespell 106 | 107 | check_imports: $(shell find $(PACKAGE_NAME) -name '*.py') ## Check missing imports 108 | @poetry run python ./scripts/check_imports.py $^ 109 | 110 | ###################### 111 | # CLEANING 112 | ###################### 113 | 114 | clean: ## Clean all generated files 115 | find . -type d -name "__pycache__" -exec rm -rf {} + 116 | find . -type f -name "*.pyc" -delete 117 | find . -type f -name "*.pyo" -delete 118 | find . -type f -name "*.pyd" -delete 119 | find . -type f -name ".coverage" -delete 120 | find . -type d -name "*.egg-info" -exec rm -rf {} + 121 | find . -type d -name "*.egg" -exec rm -rf {} + 122 | find . -type d -name ".pytest_cache" -exec rm -rf {} + 123 | find . -type d -name ".mypy_cache" -exec rm -rf {} + 124 | find . -type d -name ".ruff_cache" -exec rm -rf {} + 125 | find . -type d -name "dist" -exec rm -rf {} + 126 | find . -type d -name "build" -exec rm -rf {} + 127 | 128 | ###################### 129 | # HELP 130 | ###################### 131 | 132 | help: ## Print this help 133 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 134 | 135 | .DEFAULT_GOAL := help 136 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/README.md: -------------------------------------------------------------------------------- 1 | # LangGraph Checkpoint AWS 2 | A custom LangChain checkpointer implementation that uses Bedrock Session Management Service to enable stateful and resumable LangGraph agents through efficient state persistence and retrieval. 3 | 4 | ## Overview 5 | This package provides a custom checkpointing solution for LangGraph agents using AWS Bedrock Session Management Service. It enables: 6 | 1. Stateful conversations and interactions 7 | 2. Resumable agent sessions 8 | 3. Efficient state persistence and retrieval 9 | 4. 
Seamless integration with AWS Bedrock 10 | 11 | ## Installation 12 | You can install the package using pip: 13 | 14 | ```bash 15 | pip install langgraph-checkpoint-aws 16 | ``` 17 | Or with Poetry: 18 | ```bash 19 | poetry add langgraph-checkpoint-aws 20 | ``` 21 | 22 | ## Requirements 23 | ```text 24 | Python >=3.9 25 | langgraph-checkpoint >=2.0.0 26 | langgraph >=0.2.55 27 | boto3 >=1.37.3 28 | ``` 29 | 30 | ## Usage 31 | 32 | ```python 33 | from langgraph.graph import StateGraph 34 | from langgraph_checkpoint_aws.saver import BedrockSessionSaver 35 | 36 | # Initialize the saver 37 | session_saver = BedrockSessionSaver( 38 | region_name="us-west-2", # Your AWS region 39 | credentials_profile_name="default", # Optional: AWS credentials profile 40 | ) 41 | 42 | # Create a session 43 | session_id = session_saver.session_client.create_session().session_id 44 | 45 | # Use with LangGraph 46 | builder = StateGraph(int) 47 | builder.add_node("add_one", lambda x: x + 1) 48 | builder.set_entry_point("add_one") 49 | builder.set_finish_point("add_one") 50 | 51 | graph = builder.compile(checkpointer=session_saver) 52 | config = {"configurable": {"thread_id": session_id}} 53 | graph.invoke(1, config) 54 | ``` 55 | 56 | ## Configuration Options 57 | 58 | `BedrockSessionSaver` accepts the following parameters: 59 | 60 | ```python 61 | def __init__( 62 | region_name: Optional[str] = None, 63 | credentials_profile_name: Optional[str] = None, 64 | aws_access_key_id: Optional[SecretStr] = None, 65 | aws_secret_access_key: Optional[SecretStr] = None, 66 | aws_session_token: Optional[SecretStr] = None, 67 | endpoint_url: Optional[str] = None, 68 | config: Optional[Config] = None, 69 | ) 70 | ``` 71 | 72 | - `region_name`: AWS region where Bedrock is available 73 | - `credentials_profile_name`: Name of AWS credentials profile to use 74 | - `aws_access_key_id`: AWS access key ID for authentication 75 | - `aws_secret_access_key`: AWS secret access key for authentication 76 | - `aws_session_token`: AWS session token for temporary credentials 77 | - `endpoint_url`: Custom endpoint URL for the Bedrock service 78 | - `config`: Botocore configuration object 79 | ## Development 80 | Setting Up Development Environment 81 | 82 | * Clone the repository: 83 | ```bash 84 | git clone 85 | cd libs/aws/langgraph-checkpoint-aws 86 | ``` 87 | * Install development dependencies: 88 | ```bash 89 | make install_all 90 | ``` 91 | * Or install specific components: 92 | ```bash 93 | make install_dev # Basic development tools 94 | make install_test # Testing tools 95 | make install_lint # Linting tools 96 | make install_typing # Type checking tools 97 | make install_codespell # Spell checking tools 98 | ``` 99 | 100 | ## Running Tests 101 | ```bash 102 | make tests # Run all tests 103 | make test_watch # Run tests in watch mode 104 | 105 | ``` 106 | 107 | ## Code Quality 108 | ```bash 109 | make lint # Run linter 110 | make format # Format code 111 | make spell_check # Check spelling 112 | ``` 113 | 114 | ## Clean Up 115 | ```bash 116 | make clean # Remove all generated files 117 | ``` 118 | 119 | ## AWS Configuration 120 | 121 | Ensure you have AWS credentials configured using one of these methods: 122 | 1. Environment variables 123 | 2. AWS credentials file (~/.aws/credentials) 124 | 3. IAM roles 125 | 4. 
Direct credential injection via constructor parameters 126 | 127 | ## Required AWS permissions: 128 | 129 | ```json 130 | { 131 | "Version": "2012-10-17", 132 | "Statement": [ 133 | { 134 | "Sid": "Statement1", 135 | "Effect": "Allow", 136 | "Action": [ 137 | "bedrock:CreateSession", 138 | "bedrock:GetSession", 139 | "bedrock:UpdateSession", 140 | "bedrock:DeleteSession", 141 | "bedrock:EndSession", 142 | "bedrock:ListSessions", 143 | "bedrock:CreateInvocation", 144 | "bedrock:ListInvocations", 145 | "bedrock:PutInvocationStep", 146 | "bedrock:GetInvocationStep", 147 | "bedrock:ListInvocationSteps" 148 | ], 149 | "Resource": [ 150 | "*" 151 | ] 152 | }, 153 | { 154 | "Effect": "Allow", 155 | "Action": [ 156 | "kms:Decrypt", 157 | "kms:Encrypt", 158 | "kms:GenerateDataKey", 159 | "kms:DescribeKey" 160 | ], 161 | "Resource": "arn:aws:kms:{region}:{account}:key/{kms-key-id}" 162 | }, 163 | { 164 | "Effect": "Allow", 165 | "Action": [ 166 | "bedrock:TagResource", 167 | "bedrock:UntagResource", 168 | "bedrock:ListTagsForResource" 169 | ], 170 | "Resource": "arn:aws:bedrock:{region}:{account}:session/*" 171 | } 172 | ] 173 | } 174 | ``` 175 | 176 | ## Security Considerations 177 | * Never commit AWS credentials 178 | * Use environment variables or AWS IAM roles for authentication 179 | * Follow AWS security best practices 180 | * Use IAM roles and temporary credentials when possible 181 | * Implement proper access controls for session management 182 | 183 | ## Contributing 184 | * Fork the repository 185 | * Create a feature branch 186 | * Make your changes 187 | * Run tests and linting 188 | * Submit a pull request 189 | 190 | ## License 191 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 192 | 193 | ## Acknowledgments 194 | * LangChain team for the base LangGraph framework 195 | * AWS Bedrock team for the session management service -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LangGraph Checkpoint AWS - A LangChain checkpointer implementation using Bedrock Session Management Service. 
3 | """ 4 | 5 | __version__ = "0.1.0" 6 | SDK_USER_AGENT = f"LangGraphCheckpointAWS#{__version__}" 7 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/constants.py: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PREFIX = "CHECKPOINTS" 2 | 3 | WRITES_PREFIX = "WRITES" 4 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/session.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import boto3 4 | from botocore.config import Config 5 | from pydantic import SecretStr 6 | 7 | from langgraph_checkpoint_aws.models import ( 8 | CreateInvocationRequest, 9 | CreateInvocationResponse, 10 | CreateSessionRequest, 11 | CreateSessionResponse, 12 | DeleteSessionRequest, 13 | EndSessionRequest, 14 | EndSessionResponse, 15 | GetInvocationStepRequest, 16 | GetInvocationStepResponse, 17 | GetSessionRequest, 18 | GetSessionResponse, 19 | ListInvocationsRequest, 20 | ListInvocationsResponse, 21 | ListInvocationStepsRequest, 22 | ListInvocationStepsResponse, 23 | PutInvocationStepRequest, 24 | PutInvocationStepResponse, 25 | ) 26 | from langgraph_checkpoint_aws.utils import process_aws_client_args, to_boto_params 27 | 28 | 29 | class BedrockAgentRuntimeSessionClient: 30 | """ 31 | Client for AWS Bedrock Agent Runtime API 32 | 33 | This class provides an interface to interact with AWS Bedrock Agent Runtime service. 34 | It handles session management, invocations and invocation steps through the Bedrock Agent Runtime API. 35 | 36 | The client supports operations like: 37 | - Session management (create, get, end, delete) 38 | - Invocation management (create, list) 39 | - Invocation step management (put, get, list) 40 | 41 | Attributes: 42 | client: Boto3 client for bedrock-agent-runtime service 43 | """ 44 | 45 | def __init__( 46 | self, 47 | region_name: Optional[str] = None, 48 | credentials_profile_name: Optional[str] = None, 49 | aws_access_key_id: Optional[SecretStr] = None, 50 | aws_secret_access_key: Optional[SecretStr] = None, 51 | aws_session_token: Optional[SecretStr] = None, 52 | endpoint_url: Optional[str] = None, 53 | config: Optional[Config] = None, 54 | ): 55 | """ 56 | Initialize BedrockAgentRuntime with AWS configuration 57 | 58 | Args: 59 | region_name: AWS region (e.g., us-west-2) 60 | credentials_profile_name: AWS credentials profile name 61 | aws_access_key_id: AWS access key ID 62 | aws_secret_access_key: AWS secret access key 63 | aws_session_token: AWS session token 64 | endpoint_url: Custom endpoint URL 65 | config: Boto3 config object 66 | """ 67 | _session_kwargs, _client_kwargs = process_aws_client_args( 68 | region_name, 69 | credentials_profile_name, 70 | aws_access_key_id, 71 | aws_secret_access_key, 72 | aws_session_token, 73 | endpoint_url, 74 | config, 75 | ) 76 | session = boto3.Session(**_session_kwargs) 77 | self.client = session.client("bedrock-agent-runtime", **_client_kwargs) 78 | 79 | def create_session( 80 | self, request: Optional[CreateSessionRequest] = None 81 | ) -> CreateSessionResponse: 82 | """ 83 | Create a new session 84 | 85 | Args: 86 | request (CreateSessionRequest): Optional object containing session creation details 87 | 88 | Returns: 89 | CreateSessionResponse: Response object containing session identifier and metadata 90 | """ 91 | 92 | response = 
self.client.create_session( 93 | **to_boto_params(request) if request else {} 94 | ) 95 | return CreateSessionResponse(**response) 96 | 97 | def get_session(self, request: GetSessionRequest) -> GetSessionResponse: 98 | """ 99 | Get details of an existing session 100 | 101 | Args: 102 | request (GetSessionRequest): Object containing session identifier 103 | 104 | Returns: 105 | GetSessionResponse: Response object containing session details and metadata 106 | """ 107 | response = self.client.get_session(**to_boto_params(request)) 108 | return GetSessionResponse(**response) 109 | 110 | def end_session(self, request: EndSessionRequest) -> EndSessionResponse: 111 | """ 112 | End an existing session 113 | 114 | Args: 115 | request (EndSessionRequest): Object containing session identifier 116 | 117 | Returns: 118 | EndSessionResponse: Response object containing the ended session details 119 | """ 120 | response = self.client.end_session(**to_boto_params(request)) 121 | return EndSessionResponse(**response) 122 | 123 | def delete_session(self, request: DeleteSessionRequest) -> None: 124 | """ 125 | Delete an existing session 126 | 127 | Args: 128 | request (DeleteSessionRequest): Object containing session identifier 129 | """ 130 | self.client.delete_session(**to_boto_params(request)) 131 | 132 | def create_invocation( 133 | self, request: CreateInvocationRequest 134 | ) -> CreateInvocationResponse: 135 | """ 136 | Create a new invocation 137 | 138 | Args: 139 | request (CreateInvocationRequest): Object containing invocation details 140 | 141 | Returns: 142 | CreateInvocationResponse: Response object containing invocation identifier and metadata 143 | """ 144 | response = self.client.create_invocation(**to_boto_params(request)) 145 | return CreateInvocationResponse(**response) 146 | 147 | def list_invocations( 148 | self, request: ListInvocationsRequest 149 | ) -> ListInvocationsResponse: 150 | """ 151 | List invocations for a session 152 | 153 | Args: 154 | request (ListInvocationsRequest): Object containing session identifier 155 | 156 | Returns: 157 | ListInvocationsResponse: Response object containing list of invocations and pagination token 158 | """ 159 | response = self.client.list_invocations(**to_boto_params(request)) 160 | return ListInvocationsResponse(**response) 161 | 162 | def put_invocation_step( 163 | self, request: PutInvocationStepRequest 164 | ) -> PutInvocationStepResponse: 165 | """ 166 | Put a step in an invocation 167 | 168 | Args: 169 | request (PutInvocationStepRequest): Object containing invocation identifier and step payload 170 | 171 | Returns: 172 | PutInvocationStepResponse: Response object containing invocation step identifier 173 | """ 174 | response = self.client.put_invocation_step(**to_boto_params(request)) 175 | return PutInvocationStepResponse(**response) 176 | 177 | def get_invocation_step( 178 | self, request: GetInvocationStepRequest 179 | ) -> GetInvocationStepResponse: 180 | """ 181 | Get a step in an invocation 182 | 183 | Args: 184 | request (GetInvocationStepRequest): Object containing invocation and step identifiers 185 | 186 | Returns: 187 | GetInvocationStepResponse: Response object containing invocation step identifier and payload 188 | """ 189 | response = self.client.get_invocation_step(**to_boto_params(request)) 190 | return GetInvocationStepResponse(**response) 191 | 192 | def list_invocation_steps( 193 | self, request: ListInvocationStepsRequest 194 | ) -> ListInvocationStepsResponse: 195 | """ 196 | List steps in an invocation 197 | 198 | 
Args: 199 | request (ListInvocationStepsRequest): Object containing invocation step id and pagination token 200 | 201 | Returns: 202 | ListInvocationStepsResponse: Response object containing list of invocation steps and pagination token 203 | """ 204 | response = self.client.list_invocation_steps(**to_boto_params(request)) 205 | return ListInvocationStepsResponse(**response) 206 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "langgraph-checkpoint-aws" 7 | version = "0.1.0" 8 | description = "A LangChain checkpointer implementation that uses Bedrock Session Management Service to enable stateful and resumable LangGraph agents." 9 | authors = [] 10 | license = "MIT" 11 | readme = "README.md" 12 | repository = "https://github.com/langchain-ai/langchain-aws" 13 | keywords = ["aws", "bedrock", "langchain", "langgraph", "checkpointer"] 14 | 15 | [tool.poetry.urls] 16 | "Source Code" = "https://github.com/langchain-ai/langchain-aws/tree/main/libs/langgraph-checkpoint-aws" 17 | 18 | [tool.poetry.dependencies] 19 | python = ">=3.9,<4.0" 20 | langgraph-checkpoint = ">=2.0.0" 21 | langgraph = ">=0.2.55" 22 | boto3 = ">=1.37.3" 23 | 24 | [tool.poetry.group.dev] 25 | optional = true 26 | 27 | [tool.poetry.group.dev.dependencies] 28 | ruff = ">=0.1.9" 29 | mypy = ">=1.7.1" 30 | codespell = ">=2.2.6" 31 | 32 | [tool.poetry.group.test] 33 | optional = true 34 | 35 | [tool.poetry.group.test.dependencies] 36 | pytest = ">=7.4.3" 37 | pytest-cov = ">=4.1.0" 38 | 39 | [tool.poetry.group.test_integration] 40 | optional = true 41 | 42 | [tool.poetry.group.test_integration.dependencies] 43 | langchain-aws = ">=0.2.14" 44 | 45 | [tool.poetry.group.lint] 46 | optional = true 47 | 48 | [tool.poetry.group.lint.dependencies] 49 | ruff = ">=0.1.9" 50 | 51 | [tool.poetry.group.typing] 52 | optional = true 53 | 54 | [tool.poetry.group.typing.dependencies] 55 | mypy = ">=1.7.1" 56 | 57 | [tool.poetry.group.codespell] 58 | optional = true 59 | 60 | [tool.poetry.group.codespell.dependencies] 61 | codespell = ">=2.2.6" 62 | 63 | [tool.ruff] 64 | lint.select = [ 65 | "E", # pycodestyle 66 | "F", # Pyflakes 67 | "UP", # pyupgrade 68 | "B", # flake8-bugbear 69 | "I", # isort 70 | "T201", # print 71 | ] 72 | 73 | [tool.codespell] 74 | skip = '.git,*.pdf,*.svg' 75 | ignore-words-list = '' 76 | 77 | [tool.mypy] 78 | ignore_missing_imports = "True" 79 | 80 | [tool.coverage.run] 81 | omit = ["tests/*"] 82 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/langgraph-checkpoint-aws/tests/__init__.py -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/langgraph-checkpoint-aws/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- 
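
The `BedrockAgentRuntimeSessionClient` defined in `session.py` above can also be used directly, outside of `BedrockSessionSaver`. Below is a minimal sketch of the raw session lifecycle (create, get, end, delete). It assumes the get/end/delete request models expose a snake_case `session_identifier` field, inferred from the `sessionIdentifier` boto3 parameter used in the integration tests, so treat that field name as illustrative rather than authoritative.

```python
# Minimal sketch of the raw Bedrock session lifecycle via the typed client.
# Assumption: GetSessionRequest / EndSessionRequest / DeleteSessionRequest take
# a snake_case `session_identifier` field that to_boto_params maps to the
# boto3 `sessionIdentifier` parameter.
from langgraph_checkpoint_aws.models import (
    DeleteSessionRequest,
    EndSessionRequest,
    GetSessionRequest,
)
from langgraph_checkpoint_aws.session import BedrockAgentRuntimeSessionClient

client = BedrockAgentRuntimeSessionClient(region_name="us-west-2")

# Create a session; with no request object the service applies its defaults.
session_id = client.create_session().session_id

# Read the session back to confirm it exists.
session = client.get_session(GetSessionRequest(session_identifier=session_id))

# Sessions must be ended before they can be deleted.
client.end_session(EndSessionRequest(session_identifier=session_id))
client.delete_session(DeleteSessionRequest(session_identifier=session_id))
```

In normal use `BedrockSessionSaver` performs this bookkeeping for you; driving the client directly is mainly useful for setup and cleanup, which is how the integration tests below manage sessions through the underlying boto3 client.
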
/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/__init__.py -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_saver.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Literal 3 | 4 | import pytest 5 | from langchain_aws import ChatBedrock 6 | from langchain_core.tools import tool 7 | from langgraph.checkpoint.base import Checkpoint, uuid6 8 | from langgraph.prebuilt import create_react_agent 9 | 10 | from langgraph_checkpoint_aws.saver import BedrockSessionSaver 11 | 12 | 13 | @tool 14 | def get_weather(city: Literal["nyc", "sf"]): 15 | """Use this to get weather information.""" 16 | if city == "nyc": 17 | return "It might be cloudy in nyc" 18 | elif city == "sf": 19 | return "It's always sunny in sf" 20 | else: 21 | raise AssertionError("Unknown city") 22 | 23 | 24 | class TestBedrockMemorySaver: 25 | @pytest.fixture 26 | def tools(self): 27 | # Setup tools 28 | return [get_weather] 29 | 30 | @pytest.fixture 31 | def model(self): 32 | # Setup model 33 | return ChatBedrock( 34 | model="anthropic.claude-3-sonnet-20240229-v1:0", region="us-west-2" 35 | ) 36 | 37 | @pytest.fixture 38 | def session_saver(self): 39 | # Setup session saver 40 | return BedrockSessionSaver(region_name="us-west-2") 41 | 42 | @pytest.fixture 43 | def boto_session_client(self, session_saver): 44 | return session_saver.session_client.client 45 | 46 | def test_weather_tool_responses(self): 47 | # Test weather tool directly 48 | assert get_weather.invoke("sf") == "It's always sunny in sf" 49 | assert get_weather.invoke("nyc") == "It might be cloudy in nyc" 50 | 51 | def test_checkpoint_save_and_retrieve(self, boto_session_client, session_saver): 52 | # Create session 53 | session_id = boto_session_client.create_session()["sessionId"] 54 | assert session_id, "Session ID should not be empty" 55 | 56 | config = {"configurable": {"thread_id": session_id, "checkpoint_ns": ""}} 57 | checkpoint = Checkpoint( 58 | v=1, 59 | id=str(uuid6(clock_seq=-2)), 60 | ts=datetime.datetime.now(datetime.timezone.utc).isoformat(), 61 | channel_values={"key": "value"}, 62 | channel_versions={}, 63 | versions_seen={}, 64 | pending_sends=[], 65 | ) 66 | checkpoint_metadata = {"source": "input", "step": 1, "writes": {"key": "value"}} 67 | 68 | try: 69 | saved_config = session_saver.put( 70 | config, 71 | checkpoint, 72 | checkpoint_metadata, 73 | {}, 74 | ) 75 | assert saved_config == { 76 | "configurable": { 77 | "checkpoint_id": checkpoint["id"], 78 | "checkpoint_ns": "", 79 | "thread_id": session_id, 80 | } 81 | } 82 | 83 | checkpoint_tuple = session_saver.get_tuple(saved_config) 84 | assert checkpoint_tuple.checkpoint == checkpoint 85 | assert checkpoint_tuple.metadata == checkpoint_metadata 86 | assert checkpoint_tuple.config == saved_config 87 | 88 | finally: 89 | boto_session_client.end_session(sessionIdentifier=session_id) 90 | boto_session_client.delete_session(sessionIdentifier=session_id) 91 | 92 | def test_weather_query_and_checkpointing( 93 | self, boto_session_client, tools, model, session_saver 94 | ): 95 | # Create session 96 | session_id = boto_session_client.create_session()["sessionId"] 97 | 
assert session_id, "Session ID should not be empty" 98 | try: 99 | # Create graph and config 100 | graph = create_react_agent(model, tools=tools, checkpointer=session_saver) 101 | config = {"configurable": {"thread_id": session_id}} 102 | 103 | # Test weather query 104 | response = graph.invoke( 105 | {"messages": [("human", "what's the weather in sf")]}, config 106 | ) 107 | assert response, "Response should not be empty" 108 | 109 | # Test checkpoint retrieval 110 | checkpoint = session_saver.get(config) 111 | assert checkpoint, "Checkpoint should not be empty" 112 | 113 | # Test checkpoint listing 114 | checkpoint_tuples = list(session_saver.list(config)) 115 | assert checkpoint_tuples, "Checkpoint tuples should not be empty" 116 | assert isinstance( 117 | checkpoint_tuples, list 118 | ), "Checkpoint tuples should be a list" 119 | finally: 120 | boto_session_client.end_session(sessionIdentifier=session_id) 121 | boto_session_client.delete_session(sessionIdentifier=session_id) 122 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/libs/langgraph-checkpoint-aws/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-aws/tests/unit_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from unittest.mock import Mock, patch 3 | 4 | import boto3 5 | import pytest 6 | from botocore.config import Config 7 | from langgraph.checkpoint.base import CheckpointTuple 8 | from langgraph.checkpoint.serde.base import SerializerProtocol 9 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer 10 | 11 | from langgraph_checkpoint_aws import SDK_USER_AGENT 12 | from langgraph_checkpoint_aws.models import ( 13 | CreateSessionRequest, 14 | ) 15 | from langgraph_checkpoint_aws.utils import ( 16 | construct_checkpoint_tuple, 17 | deserialize_data, 18 | deserialize_from_base64, 19 | generate_checkpoint_id, 20 | generate_deterministic_uuid, 21 | generate_write_id, 22 | process_aws_client_args, 23 | serialize_data, 24 | serialize_to_base64, 25 | to_boto_params, 26 | ) 27 | 28 | 29 | class TestUtils: 30 | @pytest.fixture 31 | def json_serializer(self): 32 | return JsonPlusSerializer() 33 | 34 | def test_session_models_conversion(self, sample_metadata, sample_kms_key_arn): 35 | request = CreateSessionRequest( 36 | session_metadata=sample_metadata, 37 | encryption_key_arn=sample_kms_key_arn, 38 | ) 39 | 40 | result = to_boto_params(request) 41 | 42 | assert result == { 43 | "encryptionKeyArn": sample_kms_key_arn, 44 | "sessionMetadata": {"key1": "value1", "key2": "value2"}, 45 | } 46 | 47 | # Test without optional fields 48 | request = CreateSessionRequest() 49 | result = to_boto_params(request) 50 | assert result == {} 51 | 52 | @pytest.mark.parametrize( 53 | "test_case", 54 | [ 55 
| ("", "d41d8cd9-8f00-b204-e980-0998ecf8427e"), 56 | ("test-string", "661f8009-fa8e-56a9-d0e9-4a0a644397d7"), 57 | ( 58 | "checkpoint$abc|def$11111111-1111-1111-1111-111111111111", 59 | "321a564a-a10d-4ffe-ae32-b32c1131af27", 60 | ), 61 | ], 62 | ) 63 | def test_generate_deterministic_uuid(self, test_case): 64 | input_string, expected_uuid = test_case 65 | input_string_bytes = input_string.encode("utf-8") 66 | result_as_str = generate_deterministic_uuid(input_string) 67 | result_as_bytes = generate_deterministic_uuid(input_string_bytes) 68 | 69 | assert isinstance(result_as_str, uuid.UUID) 70 | assert isinstance(result_as_bytes, uuid.UUID) 71 | # Test deterministic behavior 72 | assert str(result_as_str) == expected_uuid 73 | assert str(result_as_bytes) == expected_uuid 74 | 75 | def test__generate_checkpoint_id_success(self): 76 | input_str = "test_namespace" 77 | result = generate_checkpoint_id(input_str) 78 | assert result == "72f4457f-e6bb-e1db-49ee-06cd9901904f" 79 | 80 | def test__generate_write_id_success(self): 81 | checkpoint_ns = "test_namespace" 82 | checkpoint_id = "test_checkpoint" 83 | result = generate_write_id(checkpoint_ns, checkpoint_id) 84 | assert result == "f75c463a-a608-0629-401e-f4d270073c0c" 85 | 86 | def test_serialize_deserialize_success(self, json_serializer): 87 | sample_dict = {"key": "value"} 88 | serialized = serialize_data(json_serializer, sample_dict) 89 | deserialized = deserialize_data(json_serializer, serialized) 90 | assert deserialized == sample_dict 91 | 92 | def test_serialize_deserialize_base64_success(self, json_serializer): 93 | sample_dict = {"key": "value"} 94 | serialized = serialize_to_base64(json_serializer, sample_dict) 95 | deserialized = deserialize_from_base64(json_serializer, *serialized) 96 | assert deserialized == sample_dict 97 | 98 | @patch("langgraph_checkpoint_aws.utils.deserialize_from_base64") 99 | @patch("langgraph_checkpoint_aws.utils.deserialize_data") 100 | def test__construct_checkpoint_tuple( 101 | self, 102 | mock_deserialize_data, 103 | mock_deserialize_from_base64, 104 | sample_session_checkpoint, 105 | sample_session_pending_write, 106 | ): 107 | # Arrange 108 | thread_id = "test_thread_id" 109 | checkpoint_ns = "test_namespace" 110 | 111 | serde = Mock(spec=SerializerProtocol) 112 | mock_deserialize_data.return_value = {} 113 | mock_deserialize_from_base64.return_value = {} 114 | 115 | # Act 116 | result = construct_checkpoint_tuple( 117 | thread_id, 118 | checkpoint_ns, 119 | sample_session_checkpoint, 120 | [sample_session_pending_write], 121 | [], 122 | serde, 123 | ) 124 | 125 | # Assert 126 | assert isinstance(result, CheckpointTuple) 127 | assert result.config["configurable"]["thread_id"] == thread_id 128 | assert result.config["configurable"]["checkpoint_ns"] == checkpoint_ns 129 | 130 | 131 | @patch("botocore.client.BaseClient._make_request") 132 | def test_process_aws_client_args_user_agent(mock_make_request): 133 | # Setup 134 | config = Config(user_agent_extra="existing_agent") 135 | 136 | # Process args 137 | session_kwargs, client_kwargs = process_aws_client_args( 138 | region_name="us-west-2", config=config 139 | ) 140 | 141 | # Create session and client 142 | session = boto3.Session(**session_kwargs) 143 | client = session.client("bedrock-agent-runtime", **client_kwargs) 144 | 145 | # Trigger a request to capture the user agent 146 | try: 147 | client.create_session() 148 | except Exception: 149 | pass 150 | 151 | # Verify user agent in the request headers 152 | actual_user_agent = 
mock_make_request.call_args[0][1]["headers"]["User-Agent"] 153 | assert "existing_agent" in actual_user_agent 154 | assert f"md/sdk_user_agent/{SDK_USER_AGENT}" in actual_user_agent 155 | -------------------------------------------------------------------------------- /samples/document_compressors/rerank.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Rerank Document Compressor\n", 8 | "\n", 9 | "In this notebook we will go through how you can use a rerank document compressor with Bedrock.\n", 10 | "\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import boto3\n", 20 | "\n", 21 | "session = boto3.Session()\n", 22 | "client = session.client('bedrock')\n", 23 | "foundation_model = client.get_foundation_model(modelIdentifier=\"amazon.rerank-v1:0\")\n", 24 | "\n", 25 | "model_arn = foundation_model[\"modelDetails\"][\"modelArn\"]" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "The code below processes a list of documents to determine their relevance to a given query using AWS Bedrock's reranking capabilities. It initializes a BedrockRerank instance, providing a list of documents and a query. The `compress_documents` method then evaluates and ranks the documents based on relevance, ensuring that the most relevant ones are prioritized." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "Content: AWS Bedrock enables access to AI models.\n", 45 | "Score: 0.07081620395183563\n", 46 | "Content: Artificial intelligence is transforming the world.\n", 47 | "Score: 2.8350802949717036e-06\n", 48 | "Content: LangChain is a powerful library for LLMs.\n", 49 | "Score: 1.5903378880466335e-06\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "from langchain_core.documents import Document\n", 55 | "from langchain_aws import BedrockRerank\n", 56 | "\n", 57 | "# Initialize the class\n", 58 | "reranker = BedrockRerank(model_arn=model_arn)\n", 59 | "\n", 60 | "# List of documents to rerank\n", 61 | "documents = [\n", 62 | " Document(page_content=\"LangChain is a powerful library for LLMs.\"),\n", 63 | " Document(page_content=\"AWS Bedrock enables access to AI models.\"),\n", 64 | " Document(page_content=\"Artificial intelligence is transforming the world.\"),\n", 65 | "]\n", 66 | "\n", 67 | "# Query for reranking\n", 68 | "query = \"What is AWS Bedrock?\"\n", 69 | "\n", 70 | "# Call the rerank method\n", 71 | "results = reranker.compress_documents(documents, query)\n", 72 | "\n", 73 | "# Display the most relevant documents\n", 74 | "for doc in results:\n", 75 | " print(f\"Content: {doc.page_content}\")\n", 76 | " print(f\"Score: {doc.metadata['relevance_score']}\")" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Now let's enhance our base retriever by wrapping it with a `ContextualCompressionRetriever`. Here, we integrate `BedrockRerank`, which leverages AWS Bedrock's reranking capabilities to refine the retrieved results.\n", 84 | "\n", 85 | "When a query is executed, the retriever first retrieves relevant documents using FAISS and then reranks them based on relevance, providing more accurate and meaningful responses." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Content: AWS Bedrock provides cloud-based AI models.\n", 98 | "Score: 0.07585818320512772\n", 99 | "Content: Machine learning can be used for predictions.\n", 100 | "Score: 2.8573158488143235e-06\n", 101 | "Content: LangChain integrates LLM models.\n", 102 | "Score: 1.640820528336917e-06\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "from langchain_aws import BedrockEmbeddings\n", 108 | "from langchain.retrievers.contextual_compression import ContextualCompressionRetriever\n", 109 | "from langchain.vectorstores import FAISS\n", 110 | "from langchain_core.documents import Document\n", 111 | "from langchain_aws import BedrockRerank\n", 112 | "\n", 113 | "# Create a vector store using FAISS with Bedrock embeddings\n", 114 | "documents = [\n", 115 | " Document(page_content=\"LangChain integrates LLM models.\"),\n", 116 | " Document(page_content=\"AWS Bedrock provides cloud-based AI models.\"),\n", 117 | " Document(page_content=\"Machine learning can be used for predictions.\"),\n", 118 | "]\n", 119 | "embeddings = BedrockEmbeddings()\n", 120 | "vectorstore = FAISS.from_documents(documents, embeddings)\n", 121 | "\n", 122 | "# Create the document compressor using BedrockRerank\n", 123 | "reranker = BedrockRerank(model_arn=model_arn)\n", 124 | "\n", 125 | "# Create the retriever with contextual compression\n", 126 | "retriever = ContextualCompressionRetriever(\n", 127 | " base_compressor=reranker,\n", 128 | " base_retriever=vectorstore.as_retriever(),\n", 129 | ")\n", 130 | "\n", 131 | "# Execute a query\n", 132 | "query = \"How does AWS Bedrock work?\"\n", 133 | "retrieved_docs = retriever.invoke(query)\n", 134 | "\n", 135 | "# Display the most relevant documents\n", 136 | "for doc in retrieved_docs:\n", 137 | " print(f\"Content: {doc.page_content}\")\n", 138 | " print(f\"Score: {doc.metadata.get('relevance_score', 'N/A')}\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "Unlike `compress_documents`, which works with structured Document objects, the rerank method allows passing plain text strings. This simplifies the process of evaluating and ranking text data." 
146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 6, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "Index: 1, Score: 0.07159119844436646\n", 158 | "Document: AWS Bedrock provides access to cloud-based models.\n", 159 | "Index: 2, Score: 9.666109690442681e-06\n", 160 | "Document: Machine learning is revolutionizing the world.\n", 161 | "Index: 0, Score: 8.25057043130073e-07\n", 162 | "Document: LangChain is used to integrate LLM models.\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "from langchain_aws import BedrockRerank\n", 168 | "\n", 169 | "# Initialize BedrockRerank\n", 170 | "reranker = BedrockRerank(model_arn=model_arn)\n", 171 | "\n", 172 | "# Unstructured documents\n", 173 | "documents = [\n", 174 | " \"LangChain is used to integrate LLM models.\",\n", 175 | " \"AWS Bedrock provides access to cloud-based models.\",\n", 176 | " \"Machine learning is revolutionizing the world.\",\n", 177 | "]\n", 178 | "\n", 179 | "# Query\n", 180 | "query = \"What is the role of AWS Bedrock?\"\n", 181 | "\n", 182 | "# Rerank the documents\n", 183 | "results = reranker.rerank(query=query, documents=documents)\n", 184 | "\n", 185 | "# Display the results\n", 186 | "for res in results:\n", 187 | " print(f\"Index: {res['index']}, Score: {res['relevance_score']}\")\n", 188 | " print(f\"Document: {documents[res['index']]}\")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": ".venv", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.10.15" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /samples/inmemory/memorydb-guide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-aws/9b2c03ef7e9114fae9bd4cc44a28dfd621d195e0/samples/inmemory/memorydb-guide.pdf -------------------------------------------------------------------------------- /samples/inmemory/semantic_cache.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "04ad9c80-2b3a-4df1-bead-fb347d51359f", 6 | "metadata": {}, 7 | "source": [ 8 | "## 1. Set environment vairable for MemoryDB cluster " 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "99ba6c98-8486-45c6-a15c-6025777df3cb", 14 | "metadata": {}, 15 | "source": [ 16 | "## 2. 
Install packages" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "e6ecf91d-454a-44c4-ab16-31faf6c541ae", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "# Install a pip package in the current Jupyter kernel\n", 27 | "import sys\n", 28 | "!{sys.executable} -m pip install langchain_core\n", 29 | "!{sys.executable} -m pip install langchain_aws\n", 30 | "!{sys.executable} -m pip install redis" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "d10cf385-93ac-4b03-8532-30efebfb6061", 37 | "metadata": { 38 | "tags": [] 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "import os\n", 43 | "from langchain_core.globals import set_llm_cache\n", 44 | "from langchain_aws import InMemorySemanticCache\n", 45 | "from langchain_aws import ChatBedrock\n", 46 | "from langchain_aws.embeddings import BedrockEmbeddings\n", 47 | "import redis\n", 48 | "from redis.cluster import RedisCluster as MemoryDB" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "15a34675-5c66-4ea1-9a1c-b5a205302216", 54 | "metadata": {}, 55 | "source": [ 56 | "## Initialize the ChatBedrock and embeddings " 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "ccfa25cf-26aa-45f9-9a2a-df1a1606fa68", 63 | "metadata": { 64 | "tags": [] 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "# create the Anthropic Model\n", 69 | "model_kwargs = {\n", 70 | " \"temperature\": 0, \n", 71 | " \"top_k\": 250, \n", 72 | " \"top_p\": 1,\n", 73 | " \"stop_sequences\": [\"\\\\n\\\\nHuman:\"]\n", 74 | "} " 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "2481182f-3432-4e59-a14e-b00f31d0b84f", 81 | "metadata": { 82 | "tags": [] 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# use the Anthropic Claude model\n", 87 | "llm = ChatBedrock(\n", 88 | " model_id=\"anthropic.claude-3-sonnet-20240229-v1:0\",\n", 89 | " model_kwargs=model_kwargs\n", 90 | ")\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "6ba15759-b194-4fd8-a9a6-1be417b47887", 97 | "metadata": { 98 | "tags": [] 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "# create a Titan Embeddings client\n", 103 | "embeddings = BedrockEmbeddings()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "74c7aec2-8efa-4f2d-bca3-09a148d38efc", 109 | "metadata": {}, 110 | "source": [ 111 | "## Connect to MemoryDB" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "27319029-fba1-401e-84b1-f63cfcfb1f17", 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "%%time\n", 124 | "memorydb_host = os.environ.get(\"MEMORYDB_HOST\", \"localhost\")\n", 125 | "memorydb_port = os.environ.get(\"MEMORYDB_PORT\", 6379)\n", 126 | "# print(f\"MemoryDB Url = {memorydb_host}:{memorydb_port}\")\n", 127 | "rc = MemoryDB(host=memorydb_host, port=memorydb_port, ssl=False, decode_responses=False, ssl_cert_reqs=\"none\")\n", 128 | "rc.ping()\n", 129 | "#rc.flushall()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "1f7fe013-f713-4797-ae38-fe7a0f3e32b3", 135 | "metadata": {}, 136 | "source": [ 137 | "## Submit a query without setting up cache" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "68fe670e-9c9c-4d2f-9c39-29c2279a9512", 144 | "metadata": { 145 | "tags": [] 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "%%time\n", 150 | 
"response=llm.invoke(\"Tell me about mission to moon\")\n", 151 | "print(response.content)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "id": "2d883259-1b8d-43e9-bd08-c8e6438b0b19", 157 | "metadata": {}, 158 | "source": [ 159 | "## Enable MemoryDB for durable semantic caching " 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "4f00da6f-a370-418f-a305-fc2a59f0fc78", 166 | "metadata": { 167 | "tags": [] 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "set_llm_cache(\n", 172 | " InMemorySemanticCache(redis_url=f\"redis://{memorydb_host}:{memorydb_port}/ssl=True&ssl_cert_reqs=none\",\n", 173 | " embedding=embeddings)\n", 174 | ")" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "id": "0ac943fd-6ffe-4c99-ac5c-e9fbf93b7e44", 180 | "metadata": {}, 181 | "source": [ 182 | "### Submit a query to the LLM and Re-run the same block to see the improvemnt in response time. " 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "0ad4cc96-69eb-4a47-ab6f-b21d5592d9fc", 189 | "metadata": { 190 | "tags": [] 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "%%time\n", 195 | "response=llm.invoke(\"Tell me about mission to moon\")\n", 196 | "print(response.content)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "2c668188-5658-492d-8fce-bea9063d2cfc", 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "%%time\n", 207 | "response=llm.invoke(\"Who first invented a telescope\")\n", 208 | "print(response.content)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "6665d884-08fe-46a8-a6b4-fcf42df44d43", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "%%time\n", 219 | "response=llm.invoke(\"Who first invented a car\")\n", 220 | "print(response.content)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "7d645f3c-d721-490c-a734-8c22158b227f", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "%%time\n", 231 | "respone3=llm.invoke(\"Who first a Telescope\")\n", 232 | "print(respone3.content)" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "conda_python3", 239 | "language": "python", 240 | "name": "conda_python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.10.14" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 5 257 | } 258 | --------------------------------------------------------------------------------