├── .github ├── actions │ └── poetry_setup │ │ └── action.yml ├── scripts │ ├── check_diff.py │ └── get_min_versions.py └── workflows │ ├── _all_ci.yml │ ├── _codespell.yml │ ├── _compile_integration_test.yml │ ├── _integration_test.yml │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ ├── _test_release.yml │ ├── check_diffs.yml │ └── extract_ignored_words_list.py ├── .gitignore ├── LICENSE ├── PULL_REQUEST_TEMPLATE.md ├── README.md ├── libs ├── community │ ├── .gitignore │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── langchain_google_community │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── bigquery.py │ │ ├── bigquery_vector_search.py │ │ ├── bq_storage_vectorstores │ │ │ ├── _base.py │ │ │ ├── bigquery.py │ │ │ ├── featurestore.py │ │ │ └── utils.py │ │ ├── calendar │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── create_event.py │ │ │ ├── current_datetime.py │ │ │ ├── delete_event.py │ │ │ ├── get_calendars_info.py │ │ │ ├── move_event.py │ │ │ ├── search_events.py │ │ │ ├── toolkit.py │ │ │ ├── update_event.py │ │ │ └── utils.py │ │ ├── docai.py │ │ ├── documentai_warehouse.py │ │ ├── drive.py │ │ ├── gcs_directory.py │ │ ├── gcs_file.py │ │ ├── geocoding.py │ │ ├── gmail │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── create_draft.py │ │ │ ├── get_message.py │ │ │ ├── get_thread.py │ │ │ ├── loader.py │ │ │ ├── search.py │ │ │ ├── send_message.py │ │ │ ├── toolkit.py │ │ │ └── utils.py │ │ ├── google_speech_to_text.py │ │ ├── places_api.py │ │ ├── search.py │ │ ├── texttospeech.py │ │ ├── translate.py │ │ ├── vertex_ai_search.py │ │ ├── vertex_check_grounding.py │ │ ├── vertex_rank.py │ │ └── vision.py │ ├── poetry.lock │ ├── pyproject.toml │ ├── scripts │ │ ├── check_imports.py │ │ └── lint_imports.sh │ └── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── integration_tests │ │ ├── .env.example │ │ ├── __init__.py │ │ ├── fake.py │ │ ├── feature_store │ │ │ ├── test_feature_store_bq_vectorstore.py │ │ │ └── test_feature_store_fs_vectorstore.py │ 
│ ├── terraform │ │ │ └── main.tf │ │ ├── test_bigquery.py │ │ ├── test_bigquery_vector_search.py │ │ ├── test_check_grounding.py │ │ ├── test_docai.py │ │ ├── test_docai_warehoure_retriever.py │ │ ├── test_geocoding_integration.py │ │ ├── test_googlesearch_api.py │ │ ├── test_placeholder.py │ │ ├── test_rank.py │ │ ├── test_vertex_ai_search.py │ │ └── test_vision.py │ │ └── unit_tests │ │ ├── __init__.py │ │ ├── test_check_grounding.py │ │ ├── test_docai.py │ │ ├── test_drive.py │ │ ├── test_google_calendar.py │ │ ├── test_googlesearch_api.py │ │ ├── test_placeholder.py │ │ ├── test_rank.py │ │ ├── test_utils.py │ │ └── test_vertex_ai_search.py ├── genai │ ├── .gitignore │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── langchain_google_genai │ │ ├── __init__.py │ │ ├── _common.py │ │ ├── _enums.py │ │ ├── _function_utils.py │ │ ├── _genai_extension.py │ │ ├── _image_utils.py │ │ ├── chat_models.py │ │ ├── embeddings.py │ │ ├── genai_aqa.py │ │ ├── google_vector_store.py │ │ ├── llms.py │ │ └── py.typed │ ├── poetry.lock │ ├── pyproject.toml │ ├── scripts │ │ ├── check_imports.py │ │ └── lint_imports.sh │ └── tests │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── integration_tests │ │ ├── .env.example │ │ ├── __init__.py │ │ ├── terraform │ │ │ └── main.tf │ │ ├── test_callbacks.py │ │ ├── test_chat_models.py │ │ ├── test_compile.py │ │ ├── test_embeddings.py │ │ ├── test_function_call.py │ │ ├── test_llms.py │ │ ├── test_standard.py │ │ └── test_tools.py │ │ └── unit_tests │ │ ├── __init__.py │ │ ├── __snapshots__ │ │ └── test_standard.ambr │ │ ├── test_chat_models.py │ │ ├── test_common.py │ │ ├── test_embeddings.py │ │ ├── test_function_utils.py │ │ ├── test_genai_aqa.py │ │ ├── test_google_vector_store.py │ │ ├── test_imports.py │ │ ├── test_llms.py │ │ └── test_standard.py └── vertexai │ ├── .gitignore │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── langchain_google_vertexai │ ├── __init__.py │ ├── _anthropic_parsers.py │ ├── _anthropic_utils.py │ ├── 
_base.py │ ├── _client_utils.py │ ├── _enums.py │ ├── _image_utils.py │ ├── _retry.py │ ├── _utils.py │ ├── callbacks.py │ ├── chains.py │ ├── chat_models.py │ ├── embeddings.py │ ├── evaluators │ │ ├── __init__.py │ │ ├── _core.py │ │ └── evaluation.py │ ├── functions_utils.py │ ├── gemma.py │ ├── llms.py │ ├── model_garden.py │ ├── model_garden_maas │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── llama.py │ │ └── mistral.py │ ├── py.typed │ ├── utils.py │ ├── vectorstores │ │ ├── __init__.py │ │ ├── _sdk_manager.py │ │ ├── _searcher.py │ │ ├── _utils.py │ │ ├── document_storage.py │ │ └── vectorstores.py │ └── vision_models.py │ ├── poetry.lock │ ├── pyproject.toml │ ├── scripts │ ├── check_imports.py │ └── lint_imports.sh │ └── tests │ ├── __init__.py │ ├── conftest.py │ ├── integration_tests │ ├── .env.example │ ├── TODO.md │ ├── __init__.py │ ├── conftest.py │ ├── terraform │ │ └── main.tf │ ├── test_anthropic_cache.py │ ├── test_anthropic_files.py │ ├── test_callbacks.py │ ├── test_chains.py │ ├── test_chat_models.py │ ├── test_compile.py │ ├── test_embeddings.py │ ├── test_evaluation.py │ ├── test_gemma.py │ ├── test_image_utils.py │ ├── test_llms.py │ ├── test_llms_safety.py │ ├── test_maas.py │ ├── test_medlm.py │ ├── test_model_garden.py │ ├── test_standard.py │ ├── test_vectorstores.py │ └── test_vision_models.py │ └── unit_tests │ ├── __init__.py │ ├── __snapshots__ │ └── test_standard.ambr │ ├── test_anthropic_utils.py │ ├── test_chains.py │ ├── test_chat_models.py │ ├── test_embeddings.py │ ├── test_evaluation.py │ ├── test_function_utils.py │ ├── test_image_utils.py │ ├── test_imports.py │ ├── test_llm.py │ ├── test_maas.py │ ├── test_model_garden_retry.py │ ├── test_model_garden_timeout.py │ ├── test_standard.py │ ├── test_utils.py │ ├── test_vectorstores.py │ └── test_vision_models.py └── terraform ├── cloudbuild ├── main.tf └── variables.tf ├── github-connection ├── main.tf └── variables.tf └── secrets ├── main.tf └── variables.tf 
/.github/actions/poetry_setup/action.yml: -------------------------------------------------------------------------------- 1 | # An action for setting up poetry install with caching. 2 | # Using a custom action since the default action does not 3 | # take poetry install groups into account. 4 | # Action code from: 5 | # https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 6 | name: poetry-install-with-caching 7 | description: Poetry install with support for caching of dependency groups. 8 | 9 | inputs: 10 | python-version: 11 | description: Python version, supporting MAJOR.MINOR only 12 | required: true 13 | 14 | poetry-version: 15 | description: Poetry version 16 | required: true 17 | 18 | cache-key: 19 | description: Cache key to use for manual handling of caching 20 | required: true 21 | 22 | working-directory: 23 | description: Directory whose poetry.lock file should be cached 24 | required: true 25 | 26 | runs: 27 | using: composite 28 | steps: 29 | - uses: actions/setup-python@v5 30 | name: Setup python ${{ inputs.python-version }} 31 | id: setup-python 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - uses: actions/cache@v4 36 | id: cache-bin-poetry 37 | name: Cache Poetry binary - Python ${{ inputs.python-version }} 38 | env: 39 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" 40 | with: 41 | path: | 42 | /opt/pipx/venvs/poetry 43 | # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. 44 | key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} 45 | 46 | - name: Refresh shell hashtable and fixup softlinks 47 | if: steps.cache-bin-poetry.outputs.cache-hit == 'true' 48 | shell: bash 49 | env: 50 | POETRY_VERSION: ${{ inputs.poetry-version }} 51 | PYTHON_VERSION: ${{ inputs.python-version }} 52 | run: | 53 | set -eux 54 | 55 | # Refresh the shell hashtable, to ensure correct `which` output. 
56 | hash -r 57 | 58 | # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. 59 | # Delete and recreate the softlinks pipx expects to have. 60 | rm /opt/pipx/venvs/poetry/bin/python 61 | cd /opt/pipx/venvs/poetry/bin 62 | ln -s "$(which "python$PYTHON_VERSION")" python 63 | chmod +x python 64 | cd /opt/pipx_bin/ 65 | ln -s /opt/pipx/venvs/poetry/bin/poetry poetry 66 | chmod +x poetry 67 | 68 | # Ensure everything got set up correctly. 69 | /opt/pipx/venvs/poetry/bin/python --version 70 | /opt/pipx_bin/poetry --version 71 | 72 | - name: Install poetry 73 | if: steps.cache-bin-poetry.outputs.cache-hit != 'true' 74 | shell: bash 75 | env: 76 | POETRY_VERSION: ${{ inputs.poetry-version }} 77 | PYTHON_VERSION: ${{ inputs.python-version }} 78 | # Install poetry using the python version installed by setup-python step. 79 | run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose 80 | 81 | - name: Restore pip and poetry cached dependencies 82 | uses: actions/cache@v4 83 | env: 84 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" 85 | WORKDIR: ${{ inputs.working-directory == '' && '.' 
"""Decide which library folders need CI based on the files changed in a diff.

Invoked by ``.github/workflows/check_diffs.yml`` with the changed file paths
as argv; prints a ``dirs-to-run=<json list>`` line that the workflow appends
to ``$GITHUB_OUTPUT``.
"""
import json
import sys

# Library folders making up this monorepo; shared CI infrastructure changes
# trigger all of them.
LANGCHAIN_DIRS = {
    "libs/genai",
    "libs/vertexai",
    "libs/community",
}

# Paths whose modification affects every library's CI.
_INFRA_PREFIXES = (
    ".github/workflows",
    ".github/tools",
    ".github/actions",
    ".github/scripts/check_diff.py",
)


def determine_dirs_to_run(files):
    """Map changed file paths to the set of library dirs whose CI must run.

    Args:
        files: changed file paths, as produced by the get-changed-files action.

    Returns:
        A set of library directory strings (subset of ``LANGCHAIN_DIRS``).

    Raises:
        ValueError: if exactly 300 files are given — the diff action caps its
            output at 300 entries, so files are likely missing from the list.
    """
    if len(files) == 300:
        # max diff length is 300 files - there are likely files missing
        raise ValueError("Max diff reached. Please manually run CI on changed libs.")

    dirs_to_run = set()
    for file in files:
        if any(file.startswith(prefix) for prefix in _INFRA_PREFIXES):
            dirs_to_run.update(LANGCHAIN_DIRS)
        elif "libs/genai" in file:
            dirs_to_run.add("libs/genai")
        elif "libs/vertexai" in file:
            dirs_to_run.add("libs/vertexai")
        elif "libs/community" in file:
            dirs_to_run.add("libs/community")
        # anything else (docs, root files, ...) triggers no library CI
    return dirs_to_run


if __name__ == "__main__":
    dirs = determine_dirs_to_run(sys.argv[1:])
    print(f"dirs-to-run={json.dumps(list(dirs))}")
"""Print the minimum pinned versions of selected dependencies from a pyproject.

Used by CI to install and test against the oldest supported versions:
``python get_min_versions.py <pyproject.toml> <release|pull_request>``.
"""
import re
import sys

if sys.version_info >= (3, 11):
    import tomllib
else:
    # for python 3.10 and below, which doesn't have stdlib tomllib
    import tomli as tomllib

from packaging.version import parse as parse_version

# Libraries whose minimum pinned version is installed and tested in CI.
MIN_VERSION_LIBS = ["langchain-core"]

# Libraries only checked on release, because of simultaneous changes.
SKIP_IF_PULL_REQUEST = ["langchain-core"]


def get_min_version(version: str) -> str:
    """Return the minimum version implied by a poetry version specifier.

    Supports the specifier shapes used in this repo's pyproject files:
    ``^x.y.z``, ``>=x.y.z,<a.b.c`` and an exact ``x.y.z`` pin.

    Raises:
        ValueError: if the specifier matches none of the supported shapes,
            or if a range's lower bound is not below its upper bound.
    """
    # base regex for x.x.x with cases for rc/post/etc
    # valid strings: https://peps.python.org/pep-0440/#public-version-identifiers
    vstring = r"\d+(?:\.\d+){0,2}(?:(?:a|b|rc|\.post|\.dev)\d+)?"

    # case ^x.x.x
    _match = re.match(f"^\\^({vstring})$", version)
    if _match:
        return _match.group(1)

    # case >=x.x.x,<y.y.y
    # NOTE(review): reconstructed — the "<" characters of this branch were
    # lost to markup escaping in the copy under review.
    _match = re.match(f"^>=({vstring}),<({vstring})$", version)
    if _match:
        _min = _match.group(1)
        _max = _match.group(2)
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not parse_version(_min) < parse_version(_max):
            raise ValueError(f"Invalid version range: {version}")
        return _min

    # case x.x.x
    _match = re.match(f"^({vstring})$", version)
    if _match:
        return _match.group(1)

    raise ValueError(f"Unrecognized version format: {version}")


def get_min_version_from_toml(toml_path: str, versions_for: str) -> dict:
    """Collect minimum versions for ``MIN_VERSION_LIBS`` from a pyproject.toml.

    Args:
        toml_path: path to the pyproject.toml to read.
        versions_for: either ``"release"`` or ``"pull_request"``; in the
            latter case libs in ``SKIP_IF_PULL_REQUEST`` are skipped.

    Returns:
        Mapping of library name to its minimum supported version string.
    """
    with open(toml_path, "rb") as file:
        toml_data = tomllib.load(file)

    # Dependencies live under tool.poetry.dependencies in these projects.
    dependencies = toml_data["tool"]["poetry"]["dependencies"]

    min_versions = {}
    for lib in MIN_VERSION_LIBS:
        if versions_for == "pull_request" and lib in SKIP_IF_PULL_REQUEST:
            # some libs only get checked on release because of simultaneous
            # changes
            continue
        if lib in dependencies:
            version_string = dependencies[lib]
            # poetry also allows a table form, e.g. {version = "...", extras = [...]}
            if isinstance(version_string, dict):
                version_string = version_string["version"]
            min_versions[lib] = get_min_version(version_string)

    return min_versions


if __name__ == "__main__":
    toml_file = sys.argv[1]
    versions_for = sys.argv[2]
    # Explicit check instead of `assert`: asserts are stripped under -O.
    if versions_for not in ("release", "pull_request"):
        raise SystemExit(f"Unknown mode {versions_for!r}; expected 'release' or 'pull_request'")

    min_versions = get_min_version_from_toml(toml_file, versions_for)
    print(" ".join([f"{lib}=={version}" for lib, version in min_versions.items()]))
29 | concurrency: 30 | group: ${{ github.workflow }}-${{ github.ref }}-${{ inputs.working-directory }} 31 | cancel-in-progress: true 32 | 33 | env: 34 | POETRY_VERSION: "1.7.1" 35 | 36 | jobs: 37 | lint: 38 | name: "-" 39 | uses: ./.github/workflows/_lint.yml 40 | with: 41 | working-directory: ${{ inputs.working-directory }} 42 | secrets: inherit 43 | 44 | test: 45 | name: "-" 46 | uses: ./.github/workflows/_test.yml 47 | with: 48 | working-directory: ${{ inputs.working-directory }} 49 | secrets: inherit 50 | 51 | compile-integration-tests: 52 | name: "-" 53 | uses: ./.github/workflows/_compile_integration_test.yml 54 | with: 55 | working-directory: ${{ inputs.working-directory }} 56 | secrets: inherit 57 | -------------------------------------------------------------------------------- /.github/workflows/_codespell.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: make spell_check 3 | 4 | on: 5 | workflow_call: 6 | inputs: 7 | working-directory: 8 | required: true 9 | type: string 10 | description: "From which folder this pipeline executes" 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | codespell: 17 | name: (Check for spelling errors) 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v4 23 | 24 | - name: Install Dependencies 25 | run: | 26 | pip install toml 27 | 28 | - name: Extract Ignore Words List 29 | working-directory: ${{ inputs.working-directory }} 30 | run: | 31 | # Use a Python script to extract the ignore words list from pyproject.toml 32 | python ../../.github/workflows/extract_ignored_words_list.py 33 | id: extract_ignore_words 34 | 35 | - name: Codespell 36 | uses: codespell-project/actions-codespell@v2 37 | with: 38 | skip: guide_imports.json 39 | ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }} 40 | -------------------------------------------------------------------------------- 
/.github/workflows/_compile_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: compile-integration-test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | POETRY_VERSION: "1.7.1" 13 | 14 | jobs: 15 | build: 16 | defaults: 17 | run: 18 | working-directory: ${{ inputs.working-directory }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.9" 24 | - "3.10" 25 | - "3.11" 26 | - "3.12" 27 | name: "poetry run pytest -m compile tests/integration_tests #${{ matrix.python-version }}" 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 32 | uses: "./.github/actions/poetry_setup" 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | poetry-version: ${{ env.POETRY_VERSION }} 36 | working-directory: ${{ inputs.working-directory }} 37 | cache-key: compile-integration 38 | 39 | - name: Install integration dependencies 40 | shell: bash 41 | run: poetry install --with=test_integration,test 42 | 43 | - name: Check integration tests compile 44 | shell: bash 45 | run: poetry run pytest -m compile tests/integration_tests 46 | 47 | - name: Ensure the tests did not create any additional files 48 | shell: bash 49 | run: | 50 | set -eu 51 | 52 | STATUS="$(git status)" 53 | echo "$STATUS" 54 | 55 | # grep will exit non-zero if the target message isn't found, 56 | # and `set -e` above will cause the step to fail. 
57 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 58 | -------------------------------------------------------------------------------- /.github/workflows/_integration_test.yml: -------------------------------------------------------------------------------- 1 | name: Integration tests 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | python-version: 10 | required: true 11 | type: string 12 | description: "Python version to use" 13 | 14 | env: 15 | POETRY_VERSION: "1.7.1" 16 | 17 | jobs: 18 | build: 19 | defaults: 20 | run: 21 | working-directory: ${{ inputs.working-directory }} 22 | runs-on: ubuntu-latest 23 | name: Python ${{ inputs.python-version }} 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python ${{ inputs.python-version }} + Poetry ${{ env.POETRY_VERSION }} 28 | uses: "./.github/actions/poetry_setup" 29 | with: 30 | python-version: ${{ inputs.python-version }} 31 | poetry-version: ${{ env.POETRY_VERSION }} 32 | working-directory: ${{ inputs.working-directory }} 33 | cache-key: core 34 | 35 | - name: Install dependencies 36 | shell: bash 37 | run: poetry install --with test,test_integration 38 | 39 | - name: 'Authenticate to Google Cloud' 40 | id: 'auth' 41 | uses: google-github-actions/auth@v2 42 | with: 43 | credentials_json: '${{ secrets.GOOGLE_CREDENTIALS }}' 44 | 45 | - name: Run integration tests 46 | shell: bash 47 | env: 48 | GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} 49 | GOOGLE_SEARCH_API_KEY: ${{ secrets.GOOGLE_SEARCH_API_KEY }} 50 | GOOGLE_CSE_ID: ${{ secrets.GOOGLE_CSE_ID }} 51 | GOOGLE_VERTEX_AI_WEB_CREDENTIALS: ${{ secrets.GOOGLE_VERTEX_AI_WEB_CREDENTIALS }} 52 | run: | 53 | make integration_tests 54 | 55 | - name: Ensure the tests did not create any additional files 56 | shell: bash 57 | run: | 58 | set -eu 59 | 60 | STATUS="$(git status)" 61 | echo "$STATUS" 62 | 63 | # grep will exit non-zero if the target message isn't found, 64 | # 
and `set -e` above will cause the step to fail. 65 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 66 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | langchain-location: 11 | required: false 12 | type: string 13 | description: "Relative path to the langchain library folder" 14 | 15 | env: 16 | POETRY_VERSION: "1.7.1" 17 | 18 | jobs: 19 | build: 20 | defaults: 21 | run: 22 | working-directory: ${{ inputs.working-directory }} 23 | runs-on: ubuntu-latest 24 | strategy: 25 | matrix: 26 | python-version: 27 | - "3.9" 28 | - "3.10" 29 | - "3.11" 30 | - "3.12" 31 | name: "make test #${{ matrix.python-version }}" 32 | steps: 33 | - uses: actions/checkout@v4 34 | 35 | - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} 36 | uses: "./.github/actions/poetry_setup" 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | poetry-version: ${{ env.POETRY_VERSION }} 40 | working-directory: ${{ inputs.working-directory }} 41 | cache-key: core 42 | 43 | - name: Install dependencies 44 | shell: bash 45 | run: poetry install --with test 46 | 47 | - name: Install langchain editable 48 | working-directory: ${{ inputs.working-directory }} 49 | if: ${{ inputs.langchain-location }} 50 | env: 51 | LANGCHAIN_LOCATION: ${{ inputs.langchain-location }} 52 | run: | 53 | poetry run pip install -e "$LANGCHAIN_LOCATION" 54 | 55 | - name: Run core tests 56 | shell: bash 57 | run: | 58 | make test 59 | 60 | - name: Ensure the tests did not create any additional files 61 | shell: bash 62 | run: | 63 | set -eu 64 | 65 | STATUS="$(git status)" 66 | echo "$STATUS" 67 | 68 | # grep will exit non-zero if the target message 
isn't found, 69 | # and `set -e` above will cause the step to fail. 70 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 71 | -------------------------------------------------------------------------------- /.github/workflows/check_diffs.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | # If another push to the same PR or branch happens while this workflow is still running, 10 | # cancel the earlier run in favor of the next run. 11 | # 12 | # There's no point in testing an outdated version of the code. GitHub only allows 13 | # a limited number of job runners to be active at the same time, so it's better to cancel 14 | # pointless jobs early so that more useful jobs can run sooner. 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | build: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | - uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.10' 27 | - id: files 28 | uses: Ana06/get-changed-files@v2.2.0 29 | - id: set-matrix 30 | run: | 31 | python .github/scripts/check_diff.py ${{ steps.files.outputs.all }} >> $GITHUB_OUTPUT 32 | outputs: 33 | dirs-to-run: ${{ steps.set-matrix.outputs.dirs-to-run }} 34 | codespell: 35 | name: cd ${{ matrix.working-directory }} 36 | needs: [build] 37 | strategy: 38 | matrix: 39 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-run) }} 40 | uses: ./.github/workflows/_codespell.yml 41 | with: 42 | working-directory: ${{ matrix.working-directory }} 43 | 44 | ci: 45 | name: cd ${{ matrix.working-directory }} 46 | needs: [ build ] 47 | strategy: 48 | matrix: 49 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-run) }} 50 | uses: ./.github/workflows/_all_ci.yml 51 | with: 52 | working-directory: ${{ matrix.working-directory }} 53 | 
"""Extract the codespell ignore-words list from pyproject.toml.

Run by ``.github/workflows/_codespell.yml`` inside a library folder; exposes
the list as a step output named ``ignore_words_list``.
"""
import os


def extract_ignore_words(pyproject: dict) -> str:
    """Return ``tool.codespell.ignore-words-list`` from parsed pyproject data.

    Returns an empty string (rather than the literal string "None") when the
    key is absent, so codespell gets no ignore list instead of a bogus word.
    """
    value = pyproject.get("tool", {}).get("codespell", {}).get("ignore-words-list")
    return "" if value is None else str(value)


def main() -> None:
    # Third-party parser; installed by the workflow's "pip install toml" step.
    import toml

    ignore_words_list = extract_ignore_words(toml.load("pyproject.toml"))

    # The ``::set-output`` workflow command is deprecated and disabled on
    # current GitHub runners; step outputs must be written to the file named
    # by $GITHUB_OUTPUT. Keep the old command only as a legacy fallback.
    github_output = os.environ.get("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a", encoding="utf-8") as fh:
            fh.write(f"ignore_words_list={ignore_words_list}\n")
    else:
        print(f"::set-output name=ignore_words_list::{ignore_words_list}")


if __name__ == "__main__":
    main()
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | docs/docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | notebooks/ 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .envrc 118 | .venv 119 | .venvs 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # macOS display setting files 145 | .DS_Store 146 | 147 | # Wandb directory 148 | wandb/ 149 | 150 | # asdf tool versions 151 | .tool-versions 152 | /.ruff_cache/ 153 | 154 | *.pkl 155 | *.bin 156 | 157 | # integration test artifacts 158 | data_map* 159 | \[('_type', 'fake'), ('stop', None)] 160 | 161 | # Replit files 162 | *replit* 163 | 164 | node_modules 165 | docs/.yarn/ 166 | docs/node_modules/ 167 | docs/.docusaurus/ 168 | docs/.cache-loader/ 169 | docs/_dist 170 | docs/api_reference/*api_reference.rst 171 | docs/api_reference/_build 172 | docs/api_reference/*/ 173 | !docs/api_reference/_static/ 174 | !docs/api_reference/templates/ 175 | !docs/api_reference/themes/ 176 | docs/docs/build 177 | docs/docs/node_modules 178 | docs/docs/yarn.lock 179 | _dist 180 | docs/docs/templates 181 | 182 | #terraform 183 | **/.terraform 184 | **/terraform.tfstate* 185 | **/terraform.tfvars 186 | **/*.auto.tfvars 187 | **/*.tfvars 188 | **/.terraform.tfstate.lock.info 189 | **/.terraform.lock.hcl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 4 | 5 | 30 | 31 | 43 | 44 | ## PR Description 45 | 46 | 47 | 48 | ## Relevant issues 49 | 50 | 51 | 52 | ## Type 53 | 54 | 55 | 56 | 57 | 🆕 New Feature 58 | 🐛 Bug Fix 59 | 🧹 Refactoring 60 | 📖 Documentation 61 | 🚄 Infrastructure 62 | ✅ Test 63 | 64 | ## Changes(optional) 65 | 66 | 67 | 68 | ## Testing(optional) 69 | 70 | 71 | 72 | 73 | ## Note(optional) 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /libs/community/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /libs/community/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/community/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 10 | 11 | test tests integration_test integration_tests: 12 | poetry run pytest $(TEST_FILE) 13 | 14 | # Run unit tests and generate a coverage report. 15 | coverage: 16 | poetry run pytest --cov \ 17 | --cov-config=.coveragerc \ 18 | --cov-report xml \ 19 | --cov-report term-missing:skip-covered \ 20 | $(TEST_FILE) 21 | 22 | ###################### 23 | # LINTING AND FORMATTING 24 | ###################### 25 | 26 | # Define a variable for Python and notebook files. 27 | PYTHON_FILES=. 28 | MYPY_CACHE=.mypy_cache 29 | lint format: PYTHON_FILES=. 30 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/community --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 31 | lint_package: PYTHON_FILES=langchain_google_community 32 | lint_tests: PYTHON_FILES=tests 33 | lint_tests: MYPY_CACHE=.mypy_cache_test 34 | 35 | lint lint_diff lint_package lint_tests: 36 | ./scripts/lint_imports.sh 37 | poetry run ruff check . 
38 | poetry run ruff format $(PYTHON_FILES) --diff 39 | poetry run ruff check --select I $(PYTHON_FILES) 40 | mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 41 | 42 | format format_diff: 43 | poetry run ruff format $(PYTHON_FILES) 44 | poetry run ruff check --select I --fix $(PYTHON_FILES) 45 | 46 | spell_check: 47 | poetry run codespell --toml pyproject.toml 48 | 49 | spell_fix: 50 | poetry run codespell --toml pyproject.toml -w 51 | 52 | check_imports: $(shell find langchain_google_community -name '*.py') 53 | poetry run python ./scripts/check_imports.py $^ 54 | 55 | ###################### 56 | # HELP 57 | ###################### 58 | 59 | help: 60 | @echo '----' 61 | @echo 'check_imports - check imports' 62 | @echo 'format - run code formatters' 63 | @echo 'lint - run linters' 64 | @echo 'test - run unit tests' 65 | @echo 'tests - run unit tests' 66 | @echo 'test TEST_FILE= - run all tests in file' 67 | -------------------------------------------------------------------------------- /libs/community/README.md: -------------------------------------------------------------------------------- 1 | # langchain-google-community 2 | 3 | This package contains the LangChain integrations for Google products that are not part of `langchain-google-vertexai` or `langchain-google-genai` packages. 
4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install -U langchain-google-community 9 | ``` 10 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/__init__.py: -------------------------------------------------------------------------------- 1 | from langchain_google_community.bigquery import BigQueryLoader 2 | from langchain_google_community.bigquery_vector_search import BigQueryVectorSearch 3 | from langchain_google_community.bq_storage_vectorstores.bigquery import ( 4 | BigQueryVectorStore, 5 | ) 6 | from langchain_google_community.bq_storage_vectorstores.featurestore import ( 7 | VertexFSVectorStore, 8 | ) 9 | from langchain_google_community.calendar.toolkit import ( 10 | CalendarCreateEvent, 11 | CalendarDeleteEvent, 12 | CalendarMoveEvent, 13 | CalendarSearchEvents, 14 | CalendarToolkit, 15 | CalendarUpdateEvent, 16 | GetCalendarsInfo, 17 | GetCurrentDatetime, 18 | ) 19 | from langchain_google_community.docai import DocAIParser, DocAIParsingResults 20 | from langchain_google_community.documentai_warehouse import DocumentAIWarehouseRetriever 21 | from langchain_google_community.drive import GoogleDriveLoader 22 | from langchain_google_community.gcs_directory import GCSDirectoryLoader 23 | from langchain_google_community.gcs_file import GCSFileLoader 24 | from langchain_google_community.geocoding import ( 25 | GoogleGeocodingAPIWrapper, 26 | GoogleGeocodingTool, 27 | ) 28 | from langchain_google_community.gmail.loader import GMailLoader 29 | from langchain_google_community.gmail.toolkit import GmailToolkit 30 | from langchain_google_community.google_speech_to_text import SpeechToTextLoader 31 | from langchain_google_community.places_api import ( 32 | GooglePlacesAPIWrapper, 33 | GooglePlacesTool, 34 | ) 35 | from langchain_google_community.search import ( 36 | GoogleSearchAPIWrapper, 37 | GoogleSearchResults, 38 | GoogleSearchRun, 39 | ) 40 | from langchain_google_community.texttospeech 
import TextToSpeechTool 41 | from langchain_google_community.translate import GoogleTranslateTransformer 42 | from langchain_google_community.vertex_ai_search import ( 43 | VertexAIMultiTurnSearchRetriever, 44 | VertexAISearchRetriever, 45 | VertexAISearchSummaryTool, 46 | ) 47 | from langchain_google_community.vertex_check_grounding import ( 48 | VertexAICheckGroundingWrapper, 49 | ) 50 | from langchain_google_community.vertex_rank import VertexAIRank 51 | from langchain_google_community.vision import CloudVisionLoader, CloudVisionParser 52 | 53 | __all__ = [ 54 | "BigQueryLoader", 55 | "BigQueryVectorStore", 56 | "BigQueryVectorSearch", 57 | "CalendarCreateEvent", 58 | "CalendarDeleteEvent", 59 | "CalendarMoveEvent", 60 | "CalendarSearchEvents", 61 | "CalendarUpdateEvent", 62 | "GetCalendarsInfo", 63 | "GetCurrentDatetime", 64 | "CalendarToolkit", 65 | "CloudVisionLoader", 66 | "CloudVisionParser", 67 | "DocAIParser", 68 | "DocAIParsingResults", 69 | "DocumentAIWarehouseRetriever", 70 | "GCSDirectoryLoader", 71 | "GCSFileLoader", 72 | "GMailLoader", 73 | "GmailToolkit", 74 | "GoogleDriveLoader", 75 | "GoogleGeocodingAPIWrapper", 76 | "GoogleGeocodingTool", 77 | "GooglePlacesAPIWrapper", 78 | "GooglePlacesTool", 79 | "GoogleSearchAPIWrapper", 80 | "GoogleSearchResults", 81 | "GoogleSearchRun", 82 | "GoogleTranslateTransformer", 83 | "SpeechToTextLoader", 84 | "TextToSpeechTool", 85 | "VertexAIMultiTurnSearchRetriever", 86 | "VertexAISearchRetriever", 87 | "VertexAISearchSummaryTool", 88 | "VertexAIRank", 89 | "VertexAICheckGroundingWrapper", 90 | "VertexFSVectorStore", 91 | ] 92 | 93 | 94 | from importlib import metadata 95 | 96 | try: 97 | __version__ = metadata.version(__package__) 98 | except metadata.PackageNotFoundError: 99 | # Case where package metadata is not available. 
100 | __version__ = "" 101 | del metadata # optional, avoids polluting the results of dir(__package__) 102 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/bq_storage_vectorstores/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from google.cloud.exceptions import NotFound 4 | 5 | 6 | def validate_column_in_bq_schema( 7 | columns: dict, column_name: str, expected_types: list, expected_modes: list 8 | ) -> None: 9 | """Validates a column within a BigQuery schema. 10 | 11 | Args: 12 | columns: A dictionary of BigQuery SchemaField objects representing 13 | the table schema. 14 | column_name: The name of the column to validate. 15 | expected_types: A list of acceptable data types for the column. 16 | expected_modes: A list of acceptable modes for the column. 17 | 18 | Raises: 19 | ValueError: If the column doesn't exist, has an unacceptable type, 20 | or has an unacceptable mode. 
21 | """ 22 | 23 | if column_name not in columns: 24 | raise ValueError(f"Column {column_name} is missing from the schema.") 25 | 26 | column = columns[column_name] 27 | 28 | if column.field_type not in expected_types: 29 | raise ValueError( 30 | f"Column {column_name} must be one of the following types: {expected_types}" 31 | ) 32 | 33 | if column.mode not in expected_modes: 34 | raise ValueError( 35 | f"Column {column_name} must be one of the following modes: {expected_modes}" 36 | ) 37 | 38 | 39 | def doc_match_filter(document: Dict[str, Any], filter: Dict[str, Any]) -> bool: 40 | for column, value in filter.items(): 41 | # ignore fields that are not part of the document 42 | if document.get(column, value) != value: 43 | return False 44 | return True 45 | 46 | 47 | def cast_proto_type(column: str, value: Any) -> Any: 48 | if column.startswith("int"): 49 | return int(value) 50 | elif column.startswith("double"): 51 | return float(value) 52 | elif column.startswith("bool"): 53 | return bool(value) 54 | return value 55 | 56 | 57 | def check_bq_dataset_exists(client: Any, dataset_id: str) -> bool: 58 | from google.cloud import bigquery # type: ignore[attr-defined] 59 | 60 | if not isinstance(client, bigquery.Client): 61 | raise TypeError("client must be an instance of bigquery.Client") 62 | 63 | try: 64 | client.get_dataset(dataset_id) # Make an API request. 
65 | return True 66 | except NotFound: 67 | return False 68 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/__init__.py: -------------------------------------------------------------------------------- 1 | """Google Calendar toolkit.""" 2 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/base.py: -------------------------------------------------------------------------------- 1 | """Base class for Google Calendar tools.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from langchain_core.tools import BaseTool 8 | from pydantic import Field 9 | 10 | from langchain_google_community.calendar.utils import build_calendar_service 11 | 12 | if TYPE_CHECKING: 13 | # This is for linting and IDE typehints 14 | from googleapiclient.discovery import Resource # type: ignore[import] 15 | else: 16 | try: 17 | # We do this so pydantic can resolve the types when instantiating 18 | from googleapiclient.discovery import Resource 19 | except ImportError: 20 | pass 21 | 22 | 23 | class CalendarBaseTool(BaseTool): # type: ignore[override] 24 | """Base class for Google Calendar tools.""" 25 | 26 | api_resource: Resource = Field(default_factory=build_calendar_service) 27 | 28 | @classmethod 29 | def from_api_resource(cls, api_resource: Resource) -> "CalendarBaseTool": 30 | """Create a tool from an api resource. 31 | 32 | Args: 33 | api_resource: The api resource to use. 34 | 35 | Returns: 36 | A tool. 
37 | """ 38 | return cls(service=api_resource) # type: ignore[call-arg] 39 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/current_datetime.py: -------------------------------------------------------------------------------- 1 | """Get the current datetime according to the calendar timezone.""" 2 | 3 | from datetime import datetime 4 | from typing import Optional, Type 5 | 6 | from langchain_core.callbacks import CallbackManagerForToolRun 7 | from pydantic import BaseModel, Field 8 | from zoneinfo import ZoneInfo 9 | 10 | from langchain_google_community.calendar.base import CalendarBaseTool 11 | 12 | 13 | class CurrentDatetimeSchema(BaseModel): 14 | """Input for GetCurrentDatetime.""" 15 | 16 | calendar_id: Optional[str] = Field( 17 | default="primary", description="The calendar ID. Defaults to 'primary'." 18 | ) 19 | 20 | 21 | class GetCurrentDatetime(CalendarBaseTool): # type: ignore[override, override] 22 | """Tool that gets the current datetime according to the calendar timezone.""" 23 | 24 | name: str = "get_current_datetime" 25 | description: str = ( 26 | "Use this tool to get the current datetime according to the calendar timezone." 
27 | "The output datetime format is 'YYYY-MM-DD HH:MM:SS'" 28 | ) 29 | args_schema: Type[CurrentDatetimeSchema] = CurrentDatetimeSchema 30 | 31 | def get_timezone(self, calendar_id: Optional[str]) -> str: 32 | """Get the timezone of the specified calendar.""" 33 | calendars = self.api_resource.calendarList().list().execute().get("items", []) 34 | if not calendars: 35 | raise ValueError("No calendars found.") 36 | if calendar_id == "primary": 37 | return calendars[0]["timeZone"] 38 | else: 39 | for item in calendars: 40 | if item["id"] == calendar_id and item["accessRole"] != "reader": 41 | return item["timeZone"] 42 | raise ValueError(f"Timezone not found for calendar ID: {calendar_id}") 43 | 44 | def _run( 45 | self, 46 | calendar_id: Optional[str] = "primary", 47 | run_manager: Optional[CallbackManagerForToolRun] = None, 48 | ) -> str: 49 | """Run the tool to create an event in Google Calendar.""" 50 | try: 51 | timezone = self.get_timezone(calendar_id) 52 | date_time = datetime.now(ZoneInfo(timezone)).strftime("%Y-%m-%d %H:%M:%S") 53 | return f"Time zone: {timezone}, Date and time: {date_time}" 54 | except Exception as error: 55 | raise Exception(f"An error occurred: {error}") from error 56 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/delete_event.py: -------------------------------------------------------------------------------- 1 | """Delete an event in Google Calendar.""" 2 | 3 | from typing import Optional, Type 4 | 5 | from langchain_core.callbacks import CallbackManagerForToolRun 6 | from pydantic import BaseModel, Field 7 | 8 | from langchain_google_community.calendar.base import CalendarBaseTool 9 | 10 | 11 | class DeleteEventSchema(BaseModel): 12 | """Input for CalendarDeleteEvent.""" 13 | 14 | event_id: str = Field(..., description="The event ID to delete.") 15 | calendar_id: Optional[str] = Field( 16 | default="primary", description="The origin calendar ID." 
17 | ) 18 | send_updates: Optional[str] = Field( 19 | default=None, 20 | description=( 21 | "Whether to send updates to attendees." 22 | "Allowed values are 'all', 'externalOnly', or 'none'." 23 | ), 24 | ) 25 | 26 | 27 | class CalendarDeleteEvent(CalendarBaseTool): # type: ignore[override, override] 28 | """Tool that delete an event in Google Calendar.""" 29 | 30 | name: str = "delete_calendar_event" 31 | description: str = "Use this tool to delete an event." 32 | args_schema: Type[DeleteEventSchema] = DeleteEventSchema 33 | 34 | def _run( 35 | self, 36 | event_id: str, 37 | calendar_id: Optional[str] = "primary", 38 | send_updates: Optional[str] = None, 39 | run_manager: Optional[CallbackManagerForToolRun] = None, 40 | ) -> str: 41 | """Run the tool to delete an event in Google Calendar.""" 42 | try: 43 | self.api_resource.events().delete( 44 | eventId=event_id, calendarId=calendar_id, sendUpdates=send_updates 45 | ).execute() 46 | return "Event deleted" 47 | except Exception as error: 48 | raise Exception(f"An error occurred: {error}") from error 49 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/get_calendars_info.py: -------------------------------------------------------------------------------- 1 | """Get information about the calendars in Google Calendar.""" 2 | 3 | import json 4 | from typing import Optional 5 | 6 | from langchain_core.callbacks import CallbackManagerForToolRun 7 | 8 | from langchain_google_community.calendar.base import CalendarBaseTool 9 | 10 | 11 | class GetCalendarsInfo(CalendarBaseTool): # type: ignore[override, override] 12 | """Tool that get information about the calendars in Google Calendar.""" 13 | 14 | name: str = "get_calendars_info" 15 | description: str = ( 16 | "Use this tool to get information about the calendars in Google Calendar." 
17 | ) 18 | 19 | def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: 20 | """Run the tool to get information about the calendars in Google Calendar.""" 21 | try: 22 | calendars = self.api_resource.calendarList().list().execute() 23 | data = [] 24 | for item in calendars.get("items", []): 25 | data.append( 26 | { 27 | "id": item["id"], 28 | "summary": item["summary"], 29 | "timeZone": item["timeZone"], 30 | } 31 | ) 32 | return json.dumps(data) 33 | except Exception as error: 34 | raise Exception(f"An error occurred: {error}") from error 35 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/move_event.py: -------------------------------------------------------------------------------- 1 | """Move an event between calendars in Google Calendar.""" 2 | 3 | from typing import Optional, Type 4 | 5 | from langchain_core.callbacks import CallbackManagerForToolRun 6 | from pydantic import BaseModel, Field 7 | 8 | from langchain_google_community.calendar.base import CalendarBaseTool 9 | 10 | 11 | class MoveEventSchema(BaseModel): 12 | """Input for CalendarMoveEvent.""" 13 | 14 | event_id: str = Field(..., description="The event ID to move.") 15 | origin_calenddar_id: str = Field(..., description="The origin calendar ID.") 16 | destination_calendar_id: str = Field( 17 | ..., description="The destination calendar ID." 18 | ) 19 | send_updates: Optional[str] = Field( 20 | default=None, 21 | description=( 22 | "Whether to send updates to attendees." 23 | "Allowed values are 'all', 'externalOnly', or 'none'." 24 | ), 25 | ) 26 | 27 | 28 | class CalendarMoveEvent(CalendarBaseTool): # type: ignore[override, override] 29 | """Tool that move an event between calendars in Google Calendar.""" 30 | 31 | name: str = "move_calendar_event" 32 | description: str = "Use this tool to move an event between calendars." 
33 | args_schema: Type[MoveEventSchema] = MoveEventSchema 34 | 35 | def _run( 36 | self, 37 | event_id: str, 38 | origin_calendar_id: str, 39 | destination_calendar_id: str, 40 | send_updates: Optional[str] = None, 41 | run_manager: Optional[CallbackManagerForToolRun] = None, 42 | ) -> str: 43 | """Run the tool to update an event in Google Calendar.""" 44 | try: 45 | result = ( 46 | self.api_resource.events() 47 | .move( 48 | eventId=event_id, 49 | calendarId=origin_calendar_id, 50 | destination=destination_calendar_id, 51 | sendUpdates=send_updates, 52 | ) 53 | .execute() 54 | ) 55 | return f"Event moved: {result.get('htmlLink')}" 56 | except Exception as error: 57 | raise Exception(f"An error occurred: {error}") from error 58 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/toolkit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, List 4 | 5 | from langchain_core.tools import BaseTool 6 | from langchain_core.tools.base import BaseToolkit 7 | from pydantic import ConfigDict, Field 8 | 9 | from langchain_google_community.calendar.create_event import CalendarCreateEvent 10 | from langchain_google_community.calendar.current_datetime import GetCurrentDatetime 11 | from langchain_google_community.calendar.delete_event import CalendarDeleteEvent 12 | from langchain_google_community.calendar.get_calendars_info import GetCalendarsInfo 13 | from langchain_google_community.calendar.move_event import CalendarMoveEvent 14 | from langchain_google_community.calendar.search_events import CalendarSearchEvents 15 | from langchain_google_community.calendar.update_event import CalendarUpdateEvent 16 | from langchain_google_community.calendar.utils import build_calendar_service 17 | 18 | if TYPE_CHECKING: 19 | # This is for linting and IDE typehints 20 | from googleapiclient.discovery 
import Resource # type: ignore[import] 21 | else: 22 | try: 23 | # We do this so pydantic can resolve the types when instantiating 24 | from googleapiclient.discovery import Resource 25 | except ImportError: 26 | pass 27 | 28 | 29 | SCOPES = ["https://www.googleapis.com/auth/calendar"] 30 | 31 | 32 | class CalendarToolkit(BaseToolkit): 33 | """Toolkit for interacting with Google Calendar. 34 | 35 | *Security Note*: This toolkit contains tools that can read and modify 36 | the state of a service; e.g., by reading, creating, updating, deleting 37 | data associated with this service. 38 | 39 | For example, this toolkit can be used to create events on behalf of the 40 | associated account. 41 | 42 | See https://python.langchain.com/docs/security for more information. 43 | """ 44 | 45 | api_resource: Resource = Field(default_factory=build_calendar_service) 46 | 47 | model_config = ConfigDict( 48 | arbitrary_types_allowed=True, 49 | ) 50 | 51 | def get_tools(self) -> List[BaseTool]: 52 | """Get the tools in the toolkit.""" 53 | return [ 54 | CalendarCreateEvent(api_resource=self.api_resource), 55 | CalendarSearchEvents(api_resource=self.api_resource), 56 | CalendarUpdateEvent(api_resource=self.api_resource), 57 | GetCalendarsInfo(api_resource=self.api_resource), 58 | CalendarMoveEvent(api_resource=self.api_resource), 59 | CalendarDeleteEvent(api_resource=self.api_resource), 60 | GetCurrentDatetime(api_resource=self.api_resource), 61 | ] 62 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/calendar/utils.py: -------------------------------------------------------------------------------- 1 | """Google Calendar tool utils.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | import warnings 7 | from datetime import datetime 8 | from typing import TYPE_CHECKING, Optional 9 | 10 | from langchain_google_community._utils import ( 11 | get_google_credentials, 12 | 
import_googleapiclient_resource_builder, 13 | ) 14 | 15 | if TYPE_CHECKING: 16 | from google.oauth2.credentials import Credentials 17 | from googleapiclient.discovery import Resource 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | DEFAULT_SCOPES = ["https://www.googleapis.com/auth/calendar"] 23 | 24 | 25 | def build_calendar_service( 26 | credentials: Optional[Credentials] = None, 27 | service_name: str = "calendar", 28 | service_version: str = "v3", 29 | ) -> Resource: 30 | """Build a Google Calendar service.""" 31 | credentials = credentials or get_google_credentials(scopes=DEFAULT_SCOPES) 32 | builder = import_googleapiclient_resource_builder() 33 | return builder(service_name, service_version, credentials=credentials) 34 | 35 | 36 | def build_resouce_service( 37 | credentials: Optional[Credentials] = None, 38 | service_name: str = "calendar", 39 | service_version: str = "v3", 40 | ) -> Resource: 41 | warnings.warn( 42 | "build_resource_service is deprecated and will be removed in a future version." 
43 | "Use build_calendar_service instead.", 44 | DeprecationWarning, 45 | stacklevel=2, 46 | ) 47 | return build_calendar_service(credentials, service_name, service_version) 48 | 49 | 50 | def is_all_day_event(start_datetime: str, end_datetime: str) -> bool: 51 | """Check if the event is all day.""" 52 | try: 53 | datetime.strptime(start_datetime, "%Y-%m-%d") 54 | datetime.strptime(end_datetime, "%Y-%m-%d") 55 | return True 56 | except ValueError: 57 | return False 58 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gcs_directory.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional 2 | 3 | from langchain_core.document_loaders import BaseLoader 4 | from langchain_core.documents import Document 5 | 6 | from langchain_google_community._utils import get_client_info 7 | from langchain_google_community.gcs_file import GCSFileLoader 8 | 9 | 10 | class GCSDirectoryLoader(BaseLoader): 11 | """Load from GCS directory.""" 12 | 13 | def __init__( 14 | self, 15 | project_name: str, 16 | bucket: str, 17 | prefix: str = "", 18 | loader_func: Optional[Callable[[str], BaseLoader]] = None, 19 | ): 20 | """Initialize with bucket and key name. 21 | 22 | Args: 23 | project_name: The ID of the project for the GCS bucket. 24 | bucket: The name of the GCS bucket. 25 | prefix: The prefix of the GCS bucket. 26 | loader_func: A loader function that instantiates a loader based on a 27 | file_path argument. If nothing is provided, the GCSFileLoader 28 | would use its default loader. 
29 | """ 30 | self.project_name = project_name 31 | self.bucket = bucket 32 | self.prefix = prefix 33 | self._loader_func = loader_func 34 | 35 | def load(self) -> List[Document]: 36 | """Load documents.""" 37 | try: 38 | from google.cloud import storage # type: ignore[attr-defined] 39 | except ImportError: 40 | raise ImportError( 41 | "Could not import google-cloud-storage python package. " 42 | "Please, install gcs dependency group: " 43 | "`pip install langchain-google-community[gcs]`" 44 | ) 45 | client = storage.Client( 46 | project=self.project_name, 47 | client_info=get_client_info(module="google-cloud-storage"), 48 | ) 49 | docs = [] 50 | for blob in client.list_blobs(self.bucket, prefix=self.prefix): 51 | # we shall just skip directories since GCSFileLoader creates 52 | # intermediate directories on the fly 53 | if blob.name.endswith("/"): 54 | continue 55 | loader = GCSFileLoader( 56 | self.project_name, self.bucket, blob.name, loader_func=self._loader_func 57 | ) 58 | docs.extend(loader.load()) 59 | return docs 60 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gcs_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Callable, List, Optional 4 | 5 | from langchain_core.document_loaders import BaseLoader 6 | from langchain_core.documents import Document 7 | 8 | from langchain_google_community._utils import get_client_info 9 | 10 | 11 | class GCSFileLoader(BaseLoader): 12 | """Load from GCS file.""" 13 | 14 | def __init__( 15 | self, 16 | project_name: str, 17 | bucket: str, 18 | blob: str, 19 | loader_func: Optional[Callable[[str], BaseLoader]] = None, 20 | ): 21 | """Initialize with bucket and key name. 22 | 23 | Args: 24 | project_name: The name of the project to load 25 | bucket: The name of the GCS bucket. 26 | blob: The name of the GCS blob to load. 
27 | loader_func: A loader function that instantiates a loader based on a 28 | file_path argument. If nothing is provided, the 29 | UnstructuredFileLoader is used. 30 | 31 | Examples: 32 | To use an alternative PDF loader: 33 | >> from from langchain_community.document_loaders import PyPDFLoader 34 | >> loader = GCSFileLoader(..., loader_func=PyPDFLoader) 35 | 36 | To use UnstructuredFileLoader with additional arguments: 37 | >> loader = GCSFileLoader(..., 38 | >> loader_func=lambda x: UnstructuredFileLoader(x, mode="elements")) 39 | 40 | """ 41 | self.bucket = bucket 42 | self.blob = blob 43 | self.project_name = project_name 44 | 45 | def default_loader_func(file_path: str) -> BaseLoader: 46 | try: 47 | from langchain_community.document_loaders.unstructured import ( 48 | UnstructuredFileLoader, 49 | ) 50 | except ImportError: 51 | message = ( 52 | "UnstructuredFileLoader loader not found! Either provide a " 53 | "custom loader with loader_func argument, or install " 54 | "`pip install langchain-google-community`" 55 | ) 56 | print(message) 57 | return UnstructuredFileLoader(file_path) 58 | 59 | self._loader_func = loader_func if loader_func else default_loader_func 60 | 61 | def load(self) -> List[Document]: 62 | """Load documents.""" 63 | try: 64 | from google.cloud import storage # type: ignore[attr-defined] 65 | except ImportError: 66 | raise ImportError( 67 | "Could not import google-cloud-storage python package. 
" 68 | "Please, install gcs dependency group: " 69 | "`pip install langchain-google-community[gcs]`" 70 | ) 71 | 72 | # initialize a client 73 | storage_client = storage.Client( 74 | self.project_name, client_info=get_client_info("google-cloud-storage") 75 | ) 76 | # Create a bucket object for our bucket 77 | bucket = storage_client.get_bucket(self.bucket) 78 | # Create a blob object from the filepath 79 | blob = bucket.blob(self.blob) 80 | # retrieve custom metadata associated with the blob 81 | metadata = bucket.get_blob(self.blob).metadata 82 | with tempfile.TemporaryDirectory() as temp_dir: 83 | file_path = f"{temp_dir}/{self.blob}" 84 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 85 | # Download the file to a destination 86 | blob.download_to_filename(file_path) 87 | loader = self._loader_func(file_path) 88 | docs = loader.load() 89 | for doc in docs: 90 | if "source" in doc.metadata: 91 | doc.metadata["source"] = f"gs://{self.bucket}/{self.blob}" 92 | if metadata: 93 | doc.metadata.update(metadata) 94 | return docs 95 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/community/langchain_google_community/gmail/__init__.py -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/base.py: -------------------------------------------------------------------------------- 1 | """Base class for Gmail tools.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from langchain_core.tools import BaseTool 8 | from pydantic import Field 9 | 10 | from langchain_google_community.gmail.utils import build_gmail_service 11 | 12 | if TYPE_CHECKING: 13 | # This is for linting 
and IDE typehints 14 | from googleapiclient.discovery import Resource # type: ignore[import] 15 | else: 16 | try: 17 | # We do this so pydantic can resolve the types when instantiating 18 | from googleapiclient.discovery import Resource 19 | except ImportError: 20 | pass 21 | 22 | 23 | class GmailBaseTool(BaseTool): 24 | """Base class for Gmail tools.""" 25 | 26 | api_resource: Resource = Field(default_factory=build_gmail_service) 27 | 28 | @classmethod 29 | def from_api_resource(cls, api_resource: Resource) -> "GmailBaseTool": 30 | """Create a tool from an api resource. 31 | 32 | Args: 33 | api_resource: The api resource to use. 34 | 35 | Returns: 36 | A tool. 37 | """ 38 | return cls(service=api_resource) # type: ignore[call-arg] 39 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/create_draft.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from email.message import EmailMessage 3 | from typing import List, Optional, Type 4 | 5 | from langchain_core.callbacks import CallbackManagerForToolRun 6 | from pydantic import BaseModel, Field 7 | 8 | from langchain_google_community.gmail.base import GmailBaseTool 9 | 10 | 11 | class CreateDraftSchema(BaseModel): 12 | """Input for CreateDraftTool.""" 13 | 14 | message: str = Field( 15 | ..., 16 | description="The message to include in the draft.", 17 | ) 18 | to: List[str] = Field( 19 | ..., 20 | description="The list of recipients.", 21 | ) 22 | subject: str = Field( 23 | ..., 24 | description="The subject of the message.", 25 | ) 26 | cc: Optional[List[str]] = Field( 27 | default=None, 28 | description="The list of CC recipients.", 29 | ) 30 | bcc: Optional[List[str]] = Field( 31 | default=None, 32 | description="The list of BCC recipients.", 33 | ) 34 | 35 | 36 | class GmailCreateDraft(GmailBaseTool): 37 | """Tool that creates a draft email for Gmail.""" 38 | 39 | name: str = 
"create_gmail_draft" 40 | description: str = ( 41 | "Use this tool to create a draft email with the provided message fields." 42 | ) 43 | args_schema: Type[CreateDraftSchema] = CreateDraftSchema 44 | 45 | def _prepare_draft_message( 46 | self, 47 | message: str, 48 | to: List[str], 49 | subject: str, 50 | cc: Optional[List[str]] = None, 51 | bcc: Optional[List[str]] = None, 52 | ) -> dict: 53 | draft_message = EmailMessage() 54 | draft_message.set_content(message) 55 | 56 | draft_message["To"] = ", ".join(to) 57 | draft_message["Subject"] = subject 58 | if cc is not None: 59 | draft_message["Cc"] = ", ".join(cc) 60 | 61 | if bcc is not None: 62 | draft_message["Bcc"] = ", ".join(bcc) 63 | 64 | encoded_message = base64.urlsafe_b64encode(draft_message.as_bytes()).decode() 65 | return {"message": {"raw": encoded_message}} 66 | 67 | def _run( 68 | self, 69 | message: str, 70 | to: List[str], 71 | subject: str, 72 | cc: Optional[List[str]] = None, 73 | bcc: Optional[List[str]] = None, 74 | run_manager: Optional[CallbackManagerForToolRun] = None, 75 | ) -> str: 76 | try: 77 | create_message = self._prepare_draft_message(message, to, subject, cc, bcc) 78 | draft = ( 79 | self.api_resource.users() 80 | .drafts() 81 | .create(userId="me", body=create_message) 82 | .execute() 83 | ) 84 | output = f'Draft created. 
Draft Id: {draft["id"]}' 85 | return output 86 | except Exception as e: 87 | raise Exception(f"An error occurred: {e}") 88 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/get_message.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import email 3 | from typing import Dict, Optional, Type 4 | 5 | from langchain_core.callbacks import CallbackManagerForToolRun 6 | from pydantic import BaseModel, Field 7 | 8 | from langchain_google_community.gmail.base import GmailBaseTool 9 | from langchain_google_community.gmail.utils import clean_email_body 10 | 11 | 12 | class SearchArgsSchema(BaseModel): 13 | """Input for GetMessageTool.""" 14 | 15 | message_id: str = Field( 16 | ..., 17 | description="The unique ID of the email message, retrieved from a search.", 18 | ) 19 | 20 | 21 | class GmailGetMessage(GmailBaseTool): 22 | """Tool that gets a message by ID from Gmail.""" 23 | 24 | name: str = "get_gmail_message" 25 | description: str = ( 26 | "Use this tool to fetch an email by message ID." 27 | " Returns the thread ID, snippet, body, subject, and sender." 
28 | ) 29 | args_schema: Type[SearchArgsSchema] = SearchArgsSchema 30 | 31 | def _run( 32 | self, 33 | message_id: str, 34 | run_manager: Optional[CallbackManagerForToolRun] = None, 35 | ) -> Dict: 36 | """Run the tool.""" 37 | query = ( 38 | self.api_resource.users() 39 | .messages() 40 | .get(userId="me", format="raw", id=message_id) 41 | ) 42 | message_data = query.execute() 43 | raw_message = base64.urlsafe_b64decode(message_data["raw"]) 44 | 45 | email_msg = email.message_from_bytes(raw_message) 46 | 47 | subject = email_msg["Subject"] 48 | sender = email_msg["From"] 49 | 50 | message_body = "" 51 | if email_msg.is_multipart(): 52 | for part in email_msg.walk(): 53 | ctype = part.get_content_type() 54 | cdispo = str(part.get("Content-Disposition")) 55 | if ctype == "text/plain" and "attachment" not in cdispo: 56 | message_body = part.get_payload(decode=True).decode("utf-8") # type: ignore[union-attr] 57 | break 58 | else: 59 | message_body = email_msg.get_payload(decode=True).decode("utf-8") # type: ignore[union-attr] 60 | 61 | body = clean_email_body(message_body) 62 | 63 | return { 64 | "id": message_id, 65 | "threadId": message_data["threadId"], 66 | "snippet": message_data["snippet"], 67 | "body": body, 68 | "subject": subject, 69 | "sender": sender, 70 | } 71 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/get_thread.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Type 2 | 3 | from langchain_core.callbacks import CallbackManagerForToolRun 4 | from pydantic import BaseModel, Field 5 | 6 | from langchain_google_community.gmail.base import GmailBaseTool 7 | 8 | 9 | class GetThreadSchema(BaseModel): 10 | """Input for GetMessageTool.""" 11 | 12 | # From https://support.google.com/mail/answer/7190?hl=en 13 | thread_id: str = Field( 14 | ..., 15 | description="The thread ID.", 16 | ) 17 | 18 | 19 | class 
class GmailGetThread(GmailBaseTool):
    """Tool that gets a thread by ID from Gmail."""

    name: str = "get_gmail_thread"
    # Bug fix: the previous description was copied from the search tool
    # ("Use this tool to search for email messages...") and misdescribed
    # this tool, which fetches a single thread by its thread ID.
    description: str = (
        "Use this tool to fetch an entire email thread by thread ID."
        " The input must be a valid Gmail thread ID."
        " The output is the thread with a trimmed list of its messages."
    )
    args_schema: Type[GetThreadSchema] = GetThreadSchema

    def _run(
        self,
        thread_id: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict:
        """Fetch the thread and trim each message to a few small fields.

        Args:
            thread_id: The Gmail thread ID to fetch.
            run_manager: Optional callback manager (unused).

        Returns:
            The raw thread payload with each entry of ``messages`` reduced
            to its ``id`` and ``snippet`` fields.

        Raises:
            ValueError: If the API response is not a dict.
        """
        query = self.api_resource.users().threads().get(userId="me", id=thread_id)
        thread_data = query.execute()
        if not isinstance(thread_data, dict):
            # Bug fix: the message previously said "must be a list",
            # contradicting the isinstance(..., dict) check above.
            raise ValueError("The output of the query must be a dict.")
        messages = thread_data["messages"]
        thread_data["messages"] = []
        # Bug fix: "snippet" was listed twice; the duplicate was a no-op
        # because the dict comprehension below de-duplicates keys.
        keys_to_keep = ["id", "snippet"]
        # TODO: Parse body.
        for message in messages:
            thread_data["messages"].append(
                {k: message[k] for k in keys_to_keep if k in message}
            )
        return thread_data
description="The subject of the message.", 28 | ) 29 | cc: Optional[Union[str, List[str]]] = Field( 30 | default=None, 31 | description="The list of CC recipients.", 32 | ) 33 | bcc: Optional[Union[str, List[str]]] = Field( 34 | default=None, 35 | description="The list of BCC recipients.", 36 | ) 37 | 38 | 39 | class GmailSendMessage(GmailBaseTool): 40 | """Tool that sends a message to Gmail.""" 41 | 42 | name: str = "send_gmail_message" 43 | description: str = ( 44 | "Use this tool to send email messages." " The input is the message, recipients" 45 | ) 46 | args_schema: Type[SendMessageSchema] = SendMessageSchema 47 | 48 | def _prepare_message( 49 | self, 50 | message: str, 51 | to: Union[str, List[str]], 52 | subject: str, 53 | cc: Optional[Union[str, List[str]]] = None, 54 | bcc: Optional[Union[str, List[str]]] = None, 55 | ) -> Dict[str, Any]: 56 | """Create a message for an email.""" 57 | mime_message = MIMEMultipart() 58 | mime_message.attach(MIMEText(message, "html")) 59 | 60 | mime_message["To"] = ", ".join(to if isinstance(to, list) else [to]) 61 | mime_message["Subject"] = subject 62 | if cc is not None: 63 | mime_message["Cc"] = ", ".join(cc if isinstance(cc, list) else [cc]) 64 | 65 | if bcc is not None: 66 | mime_message["Bcc"] = ", ".join(bcc if isinstance(bcc, list) else [bcc]) 67 | 68 | encoded_message = base64.urlsafe_b64encode(mime_message.as_bytes()).decode() 69 | return {"raw": encoded_message} 70 | 71 | def _run( 72 | self, 73 | message: str, 74 | to: Union[str, List[str]], 75 | subject: str, 76 | cc: Optional[Union[str, List[str]]] = None, 77 | bcc: Optional[Union[str, List[str]]] = None, 78 | run_manager: Optional[CallbackManagerForToolRun] = None, 79 | ) -> str: 80 | """Run the tool.""" 81 | try: 82 | create_message = self._prepare_message(message, to, subject, cc=cc, bcc=bcc) 83 | send_message = ( 84 | self.api_resource.users() 85 | .messages() 86 | .send(userId="me", body=create_message) 87 | ) 88 | sent_message = send_message.execute() 89 
| return f'Message sent. Message Id: {sent_message["id"]}' 90 | except Exception as error: 91 | raise Exception(f"An error occurred: {error}") 92 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/gmail/toolkit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, List 4 | 5 | from langchain_community.agent_toolkits.base import BaseToolkit 6 | from langchain_core.tools import BaseTool 7 | from pydantic import ConfigDict, Field 8 | 9 | from langchain_google_community.gmail.create_draft import GmailCreateDraft 10 | from langchain_google_community.gmail.get_message import GmailGetMessage 11 | from langchain_google_community.gmail.get_thread import GmailGetThread 12 | from langchain_google_community.gmail.search import GmailSearch 13 | from langchain_google_community.gmail.send_message import GmailSendMessage 14 | from langchain_google_community.gmail.utils import build_gmail_service 15 | 16 | if TYPE_CHECKING: 17 | # This is for linting and IDE typehints 18 | from googleapiclient.discovery import Resource # type: ignore[import] 19 | else: 20 | try: 21 | # We do this so pydantic can resolve the types when instantiating 22 | from googleapiclient.discovery import Resource 23 | except ImportError: 24 | pass 25 | 26 | 27 | SCOPES = ["https://mail.google.com/"] 28 | 29 | 30 | class GmailToolkit(BaseToolkit): 31 | """Toolkit for interacting with Gmail. 32 | 33 | *Security Note*: This toolkit contains tools that can read and modify 34 | the state of a service; e.g., by reading, creating, updating, deleting 35 | data associated with this service. 36 | 37 | For example, this toolkit can be used to send emails on behalf of the 38 | associated account. 39 | 40 | See https://python.langchain.com/docs/security for more information. 
41 | """ 42 | 43 | api_resource: Resource = Field(default_factory=build_gmail_service) 44 | 45 | model_config = ConfigDict( 46 | arbitrary_types_allowed=True, 47 | ) 48 | 49 | def get_tools(self) -> List[BaseTool]: 50 | """Get the tools in the toolkit.""" 51 | return [ 52 | GmailCreateDraft(api_resource=self.api_resource), 53 | GmailSendMessage(api_resource=self.api_resource), 54 | GmailSearch(api_resource=self.api_resource), 55 | GmailGetMessage(api_resource=self.api_resource), 56 | GmailGetThread(api_resource=self.api_resource), 57 | ] 58 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/texttospeech.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import tempfile 4 | from typing import TYPE_CHECKING, Any, Optional 5 | 6 | from langchain_core.callbacks import CallbackManagerForToolRun 7 | from langchain_core.tools import BaseTool 8 | 9 | from langchain_google_community._utils import get_client_info 10 | 11 | if TYPE_CHECKING: 12 | from google.cloud import texttospeech # type: ignore[attr-defined] 13 | 14 | 15 | def _import_google_cloud_texttospeech() -> Any: 16 | try: 17 | from google.cloud import texttospeech # type: ignore[attr-defined] 18 | except ImportError as e: 19 | raise ImportError( 20 | "Could not import google-cloud-texttospeech python package. 
" 21 | "Please, install texttospeech dependency group: " 22 | "`pip install langchain-google-community[texttospeech]`" 23 | ) from e 24 | return texttospeech 25 | 26 | 27 | def _encoding_file_extension_map(encoding: texttospeech.AudioEncoding) -> Optional[str]: 28 | texttospeech = _import_google_cloud_texttospeech() 29 | 30 | ENCODING_FILE_EXTENSION_MAP = { 31 | texttospeech.AudioEncoding.LINEAR16: ".wav", 32 | texttospeech.AudioEncoding.MP3: ".mp3", 33 | texttospeech.AudioEncoding.OGG_OPUS: ".ogg", 34 | texttospeech.AudioEncoding.MULAW: ".wav", 35 | texttospeech.AudioEncoding.ALAW: ".wav", 36 | } 37 | return ENCODING_FILE_EXTENSION_MAP.get(encoding) 38 | 39 | 40 | class TextToSpeechTool(BaseTool): 41 | """Tool that queries the Google Cloud Text to Speech API. 42 | 43 | In order to set this up, follow instructions at: 44 | https://cloud.google.com/text-to-speech/docs/before-you-begin 45 | """ 46 | 47 | name: str = "google_cloud_texttospeech" 48 | description: str = ( 49 | "A wrapper around Google Cloud Text-to-Speech. " 50 | "Useful for when you need to synthesize audio from text. " 51 | "It supports multiple languages, including English, German, Polish, " 52 | "Spanish, Italian, French, Portuguese, and Hindi. 
" 53 | ) 54 | 55 | _client: Any 56 | 57 | def __init__(self, **kwargs: Any) -> None: 58 | """Initializes private fields.""" 59 | texttospeech = _import_google_cloud_texttospeech() 60 | 61 | super().__init__(**kwargs) 62 | 63 | self._client = texttospeech.TextToSpeechClient( 64 | client_info=get_client_info(module="text-to-speech") 65 | ) 66 | 67 | def _run( 68 | self, 69 | input_text: str, 70 | language_code: str = "en-US", 71 | ssml_gender: Optional[texttospeech.SsmlVoiceGender] = None, 72 | audio_encoding: Optional[texttospeech.AudioEncoding] = None, 73 | run_manager: Optional[CallbackManagerForToolRun] = None, 74 | ) -> str: 75 | """Use the tool.""" 76 | texttospeech = _import_google_cloud_texttospeech() 77 | ssml_gender = ssml_gender or texttospeech.SsmlVoiceGender.NEUTRAL 78 | audio_encoding = audio_encoding or texttospeech.AudioEncoding.MP3 79 | 80 | response = self._client.synthesize_speech( 81 | input=texttospeech.SynthesisInput(text=input_text), 82 | voice=texttospeech.VoiceSelectionParams( 83 | language_code=language_code, ssml_gender=ssml_gender 84 | ), 85 | audio_config=texttospeech.AudioConfig(audio_encoding=audio_encoding), 86 | ) 87 | 88 | suffix = _encoding_file_extension_map(audio_encoding) 89 | 90 | with tempfile.NamedTemporaryFile(mode="bx", suffix=suffix, delete=False) as f: 91 | f.write(response.audio_content) 92 | return f.name 93 | -------------------------------------------------------------------------------- /libs/community/langchain_google_community/vision.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional 2 | 3 | from langchain_core.document_loaders import BaseBlobParser, BaseLoader 4 | from langchain_core.document_loaders.blob_loaders import Blob 5 | from langchain_core.documents import Document 6 | 7 | from langchain_google_community._utils import get_client_info 8 | 9 | 10 | class CloudVisionParser(BaseBlobParser): 11 | def __init__(self, project: 
class CloudVisionParser(BaseBlobParser):
    """Parse images via the Google Cloud Vision text-detection (OCR) API."""

    def __init__(self, project: Optional[str] = None):
        """Create the Vision client.

        Args:
            project: Optional GCP project ID used as the quota project.

        Raises:
            ImportError: If google-cloud-vision is not installed.
        """
        try:
            from google.cloud import vision  # type: ignore[attr-defined]
        except ImportError as e:
            raise ImportError(
                "Could not import google-cloud-vision python package. "
                "Please, install vision dependency group: "
                # Consistency fix: CloudVisionLoader in this same module
                # directs users to pip; the previous hint here suggested a
                # `poetry install --with vision` command that only works
                # from a source checkout of this repository.
                "`pip install langchain-google-community[vision]`"
            ) from e
        client_options = None
        if project:
            client_options = {"quota_project_id": project}
        self._client = vision.ImageAnnotatorClient(
            client_options=client_options,
            client_info=get_client_info(module="cloud-vision"),
        )

    def load(self, gcs_uri: str) -> Document:
        """Loads an image from GCS path to a Document, only the text."""
        from google.cloud import vision  # type: ignore[attr-defined]

        image = vision.Image(source=vision.ImageSource(image_uri=gcs_uri))
        text_detection_response = self._client.text_detection(image=image)
        annotations = text_detection_response.text_annotations

        if annotations:
            text = annotations[0].description
        else:
            text = ""
        return Document(page_content=text, metadata={"source": gcs_uri})

    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
        """Yield a single Document parsed from the blob's GCS path."""
        yield self.load(blob.path)  # type: ignore[arg-type]
"""Import-check every Python file passed on the command line.

Usage: python check_imports.py <file> [<file> ...]

Exits non-zero if any file fails to import, printing the offending file
and its traceback so CI logs show exactly what broke.
"""

import importlib.util
import sys
import traceback
from typing import List


def check_files(files: List[str]) -> bool:
    """Attempt to import each file; report whether any import failed.

    Args:
        files: Paths of Python source files to import.

    Returns:
        True when at least one file raised during import, else False.
    """
    has_failure = False
    for file in files:
        try:
            # Modern replacement for the deprecated
            # SourceFileLoader(...).load_module() API.
            spec = importlib.util.spec_from_file_location("x", file)
            if spec is None or spec.loader is None:
                raise ImportError(f"could not create a module spec for {file}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception:
            # Bug fix: this previously assigned a misspelled variable
            # (`has_faillure`), so `has_failure` never became True and the
            # script always exited 0 — import failures never failed CI.
            has_failure = True
            print(file)
            traceback.print_exc()
            print()
    return has_failure


if __name__ == "__main__":
    sys.exit(1 if check_files(sys.argv[1:]) else 0)
langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/community/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/community/tests/__init__.py -------------------------------------------------------------------------------- /libs/community/tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests configuration to be executed before tests execution. 3 | """ 4 | 5 | from typing import List 6 | 7 | import pytest 8 | 9 | _RELEASE_FLAG = "release" 10 | _GPU_FLAG = "gpu" 11 | _LONG_FLAG = "long" 12 | _EXTENDED_FLAG = "extended" 13 | 14 | _PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG] 15 | 16 | 17 | def pytest_addoption(parser: pytest.Parser) -> None: 18 | """ 19 | Add flags accepted by pytest CLI. 20 | 21 | Args: 22 | parser: The pytest parser object. 23 | 24 | Returns: 25 | 26 | """ 27 | for flag in _PYTEST_FLAGS: 28 | parser.addoption( 29 | f"--{flag}", action="store_true", default=False, help=f"run {flag} tests" 30 | ) 31 | 32 | 33 | def pytest_configure(config: pytest.Config) -> None: 34 | """ 35 | Add pytest custom configuration. 36 | 37 | Args: 38 | config: The pytest config object. 
39 | 40 | Returns: 41 | """ 42 | for flag in _PYTEST_FLAGS: 43 | config.addinivalue_line( 44 | "markers", f"{flag}: mark test to run as {flag} only test" 45 | ) 46 | 47 | 48 | def pytest_collection_modifyitems( 49 | config: pytest.Config, items: List[pytest.Item] 50 | ) -> None: 51 | """ 52 | Skip tests with a marker from our list that were not explicitly invoked. 53 | 54 | Args: 55 | config: The pytest config object. 56 | items: The list of tests to be executed. 57 | 58 | Returns: 59 | """ 60 | for item in items: 61 | keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS)) 62 | if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)): 63 | skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run") 64 | item.add_marker(skip) 65 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/.env.example: -------------------------------------------------------------------------------- 1 | PROJECT_ID=project_id 2 | DATA_STORE_ID=data_store_id 3 | IMAGE_GCS_PATH=image_gcs_path 4 | GOOGLE_API_KEY=google_api_key 5 | GOOGLE_CSE_ID=google_cse_id 6 | GOOGLE_MAPS_API_KEY=google_maps_api_key -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/community/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/fake.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from langchain_core.embeddings import Embeddings 5 | from pydantic import BaseModel 6 | 7 | 8 | class FakeEmbeddings(Embeddings, BaseModel): 9 | """Fake embedding 
model.""" 10 | 11 | size: int 12 | """The size of the embedding vector.""" 13 | 14 | def _get_embedding(self) -> List[float]: 15 | return list(np.random.normal(size=self.size)) 16 | 17 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 18 | return [self._get_embedding() for _ in texts] 19 | 20 | def embed_query(self, text: str) -> List[float]: 21 | return self._get_embedding() 22 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/terraform/main.tf: -------------------------------------------------------------------------------- 1 | module "cloudbuild" { 2 | source = "./../../../../../terraform/cloudbuild" 3 | 4 | library = "community" 5 | project_id = "" 6 | cloudbuildv2_repository_id = "" 7 | cloudbuild_env_vars = { 8 | DATA_STORE_ID = "", 9 | IMAGE_GCS_PATH = "gs://cloud-samples-data/vision/label/wakeupcat.jpg", 10 | PROCESSOR_NAME = "" 11 | } 12 | cloudbuild_secret_vars = { 13 | GOOGLE_API_KEY = "" 14 | GOOGLE_CSE_ID = "" 15 | } 16 | } -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_bigquery.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from langchain_google_community.bigquery import BigQueryLoader 4 | 5 | 6 | @pytest.mark.extended 7 | def test_bigquery_loader_no_options() -> None: 8 | loader = BigQueryLoader("SELECT 1 AS a, 2 AS b") 9 | docs = loader.load() 10 | 11 | assert len(docs) == 1 12 | assert docs[0].page_content == "a: 1\nb: 2" 13 | assert docs[0].metadata == {} 14 | 15 | 16 | @pytest.mark.extended 17 | def test_bigquery_loader_page_content_columns() -> None: 18 | loader = BigQueryLoader( 19 | "SELECT 1 AS a, 2 AS b UNION ALL SELECT 3 AS a, 4 AS b", 20 | page_content_columns=["a"], 21 | ) 22 | docs = loader.load() 23 | 24 | assert len(docs) == 2 25 | assert docs[0].page_content == "a: 1" 26 | assert docs[0].metadata == {} 27 
| 28 | assert docs[1].page_content == "a: 3" 29 | assert docs[1].metadata == {} 30 | 31 | 32 | @pytest.mark.extended 33 | def test_bigquery_loader_metadata_columns() -> None: 34 | loader = BigQueryLoader( 35 | "SELECT 1 AS a, 2 AS b", 36 | page_content_columns=["a"], 37 | metadata_columns=["b"], 38 | ) 39 | docs = loader.load() 40 | 41 | assert len(docs) == 1 42 | assert docs[0].page_content == "a: 1" 43 | assert docs[0].metadata == {"b": 2} 44 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_bigquery_vector_search.py: -------------------------------------------------------------------------------- 1 | """Test BigQuery Vector Search. 2 | In order to run this test, you need to install Google Cloud BigQuery SDK 3 | pip install google-cloud-bigquery 4 | Your end-user credentials would be used to make the calls (make sure you've run 5 | `gcloud auth login` first). 6 | """ 7 | 8 | import os 9 | import uuid 10 | 11 | import pytest 12 | 13 | from langchain_google_community import BigQueryVectorSearch 14 | from tests.integration_tests.fake import FakeEmbeddings 15 | 16 | TEST_TABLE_NAME = "langchain_test_table" 17 | 18 | 19 | @pytest.fixture(scope="class") 20 | def store(request: pytest.FixtureRequest) -> BigQueryVectorSearch: 21 | """BigQueryVectorStore tests context. 22 | 23 | In order to run this test, you define PROJECT_ID environment variable 24 | with GCP project id. 25 | 26 | Example: 27 | export PROJECT_ID=... 
28 | """ 29 | from google.cloud import bigquery # type: ignore[attr-defined] 30 | 31 | bigquery.Client(location="US").create_dataset( 32 | TestBigQueryVectorStore.dataset_name, exists_ok=True 33 | ) 34 | TestBigQueryVectorStore.store = BigQueryVectorSearch( 35 | project_id=os.environ.get("PROJECT_ID", None), # type: ignore[arg-type] 36 | embedding=FakeEmbeddings(), # type: ignore[call-arg] 37 | dataset_name=TestBigQueryVectorStore.dataset_name, 38 | table_name=TEST_TABLE_NAME, 39 | ) 40 | TestBigQueryVectorStore.store.add_texts( 41 | TestBigQueryVectorStore.texts, TestBigQueryVectorStore.metadatas 42 | ) 43 | 44 | def teardown() -> None: 45 | bigquery.Client(location="US").delete_dataset( 46 | TestBigQueryVectorStore.dataset_name, 47 | delete_contents=True, 48 | not_found_ok=True, 49 | ) 50 | 51 | request.addfinalizer(teardown) 52 | return TestBigQueryVectorStore.store 53 | 54 | 55 | class TestBigQueryVectorStore: 56 | """BigQueryVectorStore tests class.""" 57 | 58 | dataset_name = uuid.uuid4().hex 59 | store: BigQueryVectorSearch 60 | texts = ["apple", "ice cream", "Saturn", "candy", "banana"] 61 | metadatas = [ 62 | { 63 | "kind": "fruit", 64 | }, 65 | { 66 | "kind": "treat", 67 | }, 68 | { 69 | "kind": "planet", 70 | }, 71 | { 72 | "kind": "treat", 73 | }, 74 | { 75 | "kind": "fruit", 76 | }, 77 | ] 78 | 79 | @pytest.mark.skip(reason="investigating") 80 | @pytest.mark.extended 81 | def test_semantic_search(self, store: BigQueryVectorSearch) -> None: 82 | """Test on semantic similarity.""" 83 | docs = store.similarity_search("food", k=4) 84 | print(docs) # noqa: T201 85 | kinds = [d.metadata["kind"] for d in docs] 86 | assert "fruit" in kinds 87 | assert "treat" in kinds 88 | assert "planet" not in kinds 89 | 90 | @pytest.mark.skip(reason="investigating") 91 | @pytest.mark.extended 92 | def test_semantic_search_filter_fruits(self, store: BigQueryVectorSearch) -> None: 93 | """Test on semantic similarity with metadata filter.""" 94 | docs = 
store.similarity_search("food", filter={"kind": "fruit"}) 95 | kinds = [d.metadata["kind"] for d in docs] 96 | assert "fruit" in kinds 97 | assert "treat" not in kinds 98 | assert "planet" not in kinds 99 | 100 | @pytest.mark.skip(reason="investigating") 101 | @pytest.mark.extended 102 | def test_get_doc_by_filter(self, store: BigQueryVectorSearch) -> None: 103 | """Test on document retrieval with metadata filter.""" 104 | docs = store.get_documents(filter={"kind": "fruit"}) 105 | kinds = [d.metadata["kind"] for d in docs] 106 | assert "fruit" in kinds 107 | assert "treat" not in kinds 108 | assert "planet" not in kinds 109 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_docai.py: -------------------------------------------------------------------------------- 1 | """Integration tests for the Google Cloud DocAI parser.""" 2 | 3 | import os 4 | 5 | import pytest 6 | from langchain_core.document_loaders.blob_loaders import Blob 7 | 8 | from langchain_google_community.docai import DocAIParser 9 | 10 | 11 | @pytest.mark.extended 12 | def test_docai_layout_parser() -> None: 13 | processor_name = os.environ["PROCESSOR_NAME"] 14 | parser = DocAIParser(processor_name=processor_name, location="us") 15 | assert parser._use_layout_parser is True 16 | blob = Blob( 17 | data=None, 18 | path="gs://cloud-samples-data/gen-app-builder/search/alphabet-investor-pdfs/2022Q1_alphabet_earnings_release.pdf", 19 | ) 20 | docs = list(parser.online_process(blob=blob)) 21 | assert len(docs) == 11 22 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_docai_warehoure_retriever.py: -------------------------------------------------------------------------------- 1 | """Test Google Cloud Document AI Warehouse retriever.""" 2 | 3 | import os 4 | 5 | import pytest 6 | from langchain_core.documents import Document 7 | 8 | from 
langchain_google_community import DocumentAIWarehouseRetriever 9 | 10 | 11 | @pytest.mark.extended 12 | @pytest.mark.skip(reason="CI/CD not ready.") 13 | def test_google_documentai_warehoure_retriever() -> None: 14 | """In order to run this test, you should provide a project_id and user_ldap. 15 | 16 | Example: 17 | export USER_LDAP=... 18 | export PROJECT_NUMBER=... 19 | """ 20 | project_number = os.environ["PROJECT_NUMBER"] 21 | user_ldap = os.environ["USER_LDAP"] 22 | docai_wh_retriever = DocumentAIWarehouseRetriever(project_number=project_number) 23 | documents = docai_wh_retriever.get_relevant_documents( 24 | "What are Alphabet's Other Bets?", user_ldap=user_ldap 25 | ) 26 | assert len(documents) > 0 27 | for doc in documents: 28 | assert isinstance(doc, Document) 29 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_geocoding_integration.py: -------------------------------------------------------------------------------- 1 | """Integration tests for Google Geocoding API.""" 2 | 3 | import os 4 | 5 | import pytest 6 | from pydantic import SecretStr 7 | 8 | from langchain_google_community import GoogleGeocodingAPIWrapper, GoogleGeocodingTool 9 | 10 | 11 | @pytest.fixture 12 | def api_wrapper() -> GoogleGeocodingAPIWrapper: 13 | """Create API wrapper with credentials from environment.""" 14 | api_key = os.getenv("GOOGLE_MAPS_API_KEY") 15 | if not api_key: 16 | pytest.skip("GOOGLE_MAPS_API_KEY environment variable not set") 17 | return GoogleGeocodingAPIWrapper(google_api_key=SecretStr(api_key)) 18 | 19 | 20 | @pytest.fixture 21 | def geocoding_tool(api_wrapper: GoogleGeocodingAPIWrapper) -> GoogleGeocodingTool: 22 | """Create geocoding tool with the API wrapper.""" 23 | return GoogleGeocodingTool( 24 | api_wrapper=api_wrapper, 25 | max_results=5, 26 | include_bounds=True, 27 | include_metadata=True, 28 | ) 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def 
test_geocode_async(api_wrapper: GoogleGeocodingAPIWrapper) -> None: 33 | """Test async geocoding functionality.""" 34 | result = await api_wrapper.geocode_async("Statue of Liberty, New York") 35 | 36 | assert result["status"] == "OK" 37 | assert len(result["results"]) > 0 38 | 39 | location = result["results"][0] 40 | assert "address" in location 41 | assert "New York" in location["address"]["full"] 42 | 43 | # Check coordinates are roughly correct for Statue of Liberty 44 | lat = location["geometry"]["location"]["lat"] 45 | lng = location["geometry"]["location"]["lng"] 46 | assert 40.68 < lat < 40.69 # Statue of Liberty latitude 47 | assert -74.05 < lng < -74.04 # Statue of Liberty longitude 48 | 49 | 50 | def test_geocode_batch(api_wrapper: GoogleGeocodingAPIWrapper) -> None: 51 | """Test batch geocoding with multiple locations.""" 52 | locations = ["Times Square, New York", "Big Ben, London", "Sydney Opera House"] 53 | 54 | result = api_wrapper.batch_geocode(locations) 55 | assert result["status"] == "OK" 56 | assert len(result["results"]) == len(locations) 57 | 58 | # Verify each location has valid data 59 | for location in result["results"]: 60 | assert "address" in location 61 | assert "full" in location["address"] 62 | assert "geometry" in location 63 | assert "location" in location["geometry"] 64 | assert "lat" in location["geometry"]["location"] 65 | assert "lng" in location["geometry"]["location"] 66 | 67 | 68 | def test_geocode_error_handling(api_wrapper: GoogleGeocodingAPIWrapper) -> None: 69 | """Test error handling with invalid queries.""" 70 | result = api_wrapper.raw_results("NonexistentPlace12345!@#$%") 71 | 72 | assert result["status"] == "ZERO_RESULTS" 73 | assert len(result["results"]) == 0 74 | 75 | 76 | def test_tool_batch_query(geocoding_tool: GoogleGeocodingTool) -> None: 77 | """Test geocoding tool with batch query.""" 78 | query = "Times Square, Central Park, Empire State Building" 79 | result, metadata = geocoding_tool._run(query) 80 | 81 
| assert len(result) == 3 82 | assert metadata["status"] == "OK" 83 | assert all("geometry" in loc for loc in result) 84 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_googlesearch_api.py: -------------------------------------------------------------------------------- 1 | """Integration test for Google Search API Wrapper.""" 2 | 3 | import os 4 | 5 | import pytest 6 | 7 | from langchain_google_community.search import GoogleSearchAPIWrapper 8 | 9 | 10 | @pytest.mark.extended 11 | def test_call() -> None: 12 | """Test that call gives the correct answer.""" 13 | google_api_key = os.environ["GOOGLE_API_KEY"] 14 | google_cse_id = os.environ["GOOGLE_CSE_ID"] 15 | search = GoogleSearchAPIWrapper( # type: ignore[call-arg] 16 | google_api_key=google_api_key, google_cse_id=google_cse_id 17 | ) 18 | output = search.run("What was Obama's first name?") 19 | assert "Barack Hussein Obama II" in output 20 | 21 | 22 | @pytest.mark.extended 23 | def test_result_with_params_call() -> None: 24 | """Test that call gives the correct answer with extra params.""" 25 | google_api_key = os.environ["GOOGLE_API_KEY"] 26 | google_cse_id = os.environ["GOOGLE_CSE_ID"] 27 | search = GoogleSearchAPIWrapper( # type: ignore[call-arg] 28 | google_api_key=google_api_key, google_cse_id=google_cse_id 29 | ) 30 | output = search.results( 31 | query="What was Obama's first name?", 32 | num_results=5, 33 | search_params={"cr": "us", "safe": "active"}, 34 | ) 35 | assert len(output) 36 | -------------------------------------------------------------------------------- /libs/community/tests/integration_tests/test_placeholder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | pass 7 | -------------------------------------------------------------------------------- 
/libs/community/tests/integration_tests/test_vision.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from langchain_core.document_loaders.blob_loaders import Blob 5 | from langchain_core.documents import Document 6 | 7 | from langchain_google_community import CloudVisionLoader, CloudVisionParser 8 | 9 | 10 | @pytest.mark.extended 11 | def test_parse_image() -> None: 12 | gcs_path = os.environ["IMAGE_GCS_PATH"] 13 | project = os.environ["PROJECT_ID"] 14 | blob = Blob(path=gcs_path, data="") # type: ignore 15 | loader = CloudVisionParser(project=project) 16 | documents = loader.parse(blob) 17 | assert len(documents) == 1 18 | assert isinstance(documents[0], Document) 19 | assert len(documents[0].page_content) > 1 20 | 21 | 22 | @pytest.mark.extended 23 | def test_load_image() -> None: 24 | gcs_path = os.environ["IMAGE_GCS_PATH"] 25 | project = os.environ["PROJECT_ID"] 26 | loader = CloudVisionLoader(project=project, file_path=gcs_path) 27 | documents = loader.load() 28 | assert len(documents) == 1 29 | assert isinstance(documents[0], Document) 30 | assert len(documents[0].page_content) > 1 31 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/community/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_docai.py: -------------------------------------------------------------------------------- 1 | """Tests for the Google Cloud DocAI parser.""" 2 | 3 | from unittest.mock import ANY, patch 4 | 5 | import pytest 6 | 7 | from langchain_google_community.docai import DocAIParser 8 | 9 | 10 | def test_docai_parser_valid_processor_name() -> 
None: 11 | processor_name = "projects/123456/locations/us-central1/processors/ab123dfg" 12 | with patch("google.cloud.documentai.DocumentProcessorServiceClient") as test_client: 13 | parser = DocAIParser(processor_name=processor_name, location="us") 14 | test_client.assert_called_once_with(client_options=ANY, client_info=ANY) 15 | assert parser._processor_name == processor_name 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "processor_name", 20 | ["projects/123456/locations/us-central1/processors/ab123dfg:publish", "ab123dfg"], 21 | ) 22 | def test_docai_parser_invalid_processor_name(processor_name: str) -> None: 23 | with patch("google.cloud.documentai.DocumentProcessorServiceClient"): 24 | with pytest.raises(ValueError): 25 | _ = DocAIParser(processor_name=processor_name, location="us") 26 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_drive.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from langchain_google_community.drive import GoogleDriveLoader 4 | 5 | 6 | def test_drive_default_scope() -> None: 7 | """Test that default scope is set correctly.""" 8 | loader = GoogleDriveLoader(folder_id="dummy_folder") 9 | assert loader.scopes == ["https://www.googleapis.com/auth/drive.file"] 10 | 11 | 12 | def test_drive_custom_scope() -> None: 13 | """Test setting custom scope.""" 14 | custom_scopes = ["https://www.googleapis.com/auth/drive.readonly"] 15 | loader = GoogleDriveLoader(folder_id="dummy_folder", scopes=custom_scopes) 16 | assert loader.scopes == custom_scopes 17 | 18 | 19 | def test_drive_multiple_scopes() -> None: 20 | """Test setting multiple valid scopes.""" 21 | custom_scopes = [ 22 | "https://www.googleapis.com/auth/drive.readonly", 23 | "https://www.googleapis.com/auth/drive.metadata.readonly", 24 | ] 25 | loader = GoogleDriveLoader(folder_id="dummy_folder", scopes=custom_scopes) 26 | assert loader.scopes == 
custom_scopes 27 | 28 | 29 | def test_drive_empty_scope_list() -> None: 30 | """Test that empty scope list raises error.""" 31 | with pytest.raises(ValueError, match="At least one scope must be provided"): 32 | GoogleDriveLoader(folder_id="dummy_folder", scopes=[]) 33 | 34 | 35 | def test_drive_invalid_scope() -> None: 36 | """Test that invalid scope raises error.""" 37 | invalid_scopes = ["https://www.googleapis.com/auth/drive.invalid"] 38 | with pytest.raises(ValueError, match="Invalid Google Drive API scope"): 39 | GoogleDriveLoader(folder_id="dummy_folder", scopes=invalid_scopes) 40 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_googlesearch_api.py: -------------------------------------------------------------------------------- 1 | """Integration test for Google Search API Wrapper.""" 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | from langchain_google_community.search import GoogleSearchAPIWrapper 7 | 8 | 9 | @pytest.mark.extended 10 | def test_no_result_call() -> None: 11 | """Test that call gives no result.""" 12 | with patch("googleapiclient.discovery.build") as search_engine: 13 | search_engine.return_value.cse.return_value.list.return_value.execute.return_value = {} # noqa: E501 14 | search = GoogleSearchAPIWrapper( # type: ignore[call-arg] 15 | google_api_key="key", google_cse_id="cse" 16 | ) 17 | output = search.run("test") 18 | search_engine.assert_called_once_with("customsearch", "v1", developerKey="key") 19 | assert "No good Google Search Result was found" == output 20 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_placeholder.py: -------------------------------------------------------------------------------- 1 | def test_placeholder() -> None: 2 | pass 3 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_rank.py: 
-------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | from google.cloud import discoveryengine_v1alpha # type: ignore 5 | from langchain_core.documents import Document 6 | from pytest import approx 7 | 8 | from langchain_google_community.vertex_rank import VertexAIRank 9 | 10 | 11 | # Fixtures for common setup 12 | @pytest.fixture 13 | def mock_rank_service_client() -> Mock: 14 | mock_client = Mock(spec=discoveryengine_v1alpha.RankServiceClient) 15 | mock_client.rank.return_value = discoveryengine_v1alpha.RankResponse( 16 | records=[ 17 | discoveryengine_v1alpha.RankingRecord( 18 | id="1", content="Document 1", title="Title 1", score=0.9 19 | ), 20 | discoveryengine_v1alpha.RankingRecord( 21 | id="2", content="Document 2", title="Title 2", score=0.8 22 | ), 23 | ] 24 | ) 25 | return mock_client 26 | 27 | 28 | @pytest.fixture 29 | def ranker(mock_rank_service_client: Mock) -> VertexAIRank: 30 | return VertexAIRank( 31 | project_id="test-project", 32 | location_id="test-location", 33 | ranking_config="test-config", 34 | title_field="source", 35 | client=mock_rank_service_client, 36 | ) 37 | 38 | 39 | # Unit tests 40 | def test_vertex_ai_ranker_initialization(mock_rank_service_client: Mock) -> None: 41 | ranker = VertexAIRank( 42 | project_id="test-project", 43 | location_id="test-location", 44 | ranking_config="test-config", 45 | title_field="source", 46 | client=mock_rank_service_client, 47 | ) 48 | assert ranker.project_id == "test-project" 49 | assert ranker.location_id == "test-location" 50 | assert ranker.ranking_config == "test-config" 51 | assert ranker.title_field == "source" 52 | 53 | 54 | @patch("google.cloud.discoveryengine_v1alpha.RankServiceClient") 55 | def test_rerank_documents(mock_rank_service_client: Mock, ranker: VertexAIRank) -> None: 56 | documents = [ 57 | Document(page_content="Document 1", metadata={"source": "Title 1"}), 58 | 
Document(page_content="Document 2", metadata={"source": "Title 2"}), 59 | ] 60 | reranked_documents = ranker._rerank_documents( 61 | query="test query", documents=documents 62 | ) 63 | print(reranked_documents) 64 | assert len(reranked_documents) == 2 65 | assert reranked_documents[0].page_content == "Document 1" 66 | assert reranked_documents[0].metadata["relevance_score"] == approx(0.9) 67 | assert reranked_documents[0].metadata["source"] == "Title 1" 68 | assert reranked_documents[1].page_content == "Document 2" 69 | assert reranked_documents[1].metadata["relevance_score"] == approx(0.8) 70 | assert reranked_documents[1].metadata["source"] == "Title 2" 71 | -------------------------------------------------------------------------------- /libs/community/tests/unit_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | from langchain_google_community._utils import ( 4 | get_user_agent, 5 | ) 6 | 7 | 8 | @patch("langchain_google_community._utils.os.environ.get") 9 | @patch("langchain_google_community._utils.metadata.version") 10 | def test_get_user_agent_with_telemetry_env_variable( 11 | mock_version: MagicMock, mock_environ_get: MagicMock 12 | ) -> None: 13 | mock_version.return_value = "1.2.3" 14 | mock_environ_get.return_value = True 15 | client_lib_version, user_agent_str = get_user_agent(module="test-module") 16 | assert client_lib_version == "1.2.3-test-module+remote_reasoning_engine" 17 | assert user_agent_str == ( 18 | "langchain-google-community/1.2.3-test-module+remote_reasoning_engine" 19 | ) 20 | 21 | 22 | @patch("langchain_google_community._utils.os.environ.get") 23 | @patch("langchain_google_community._utils.metadata.version") 24 | def test_get_user_agent_without_telemetry_env_variable( 25 | mock_version: MagicMock, mock_environ_get: MagicMock 26 | ) -> None: 27 | mock_version.return_value = "1.2.3" 28 | mock_environ_get.return_value = False 29 | 
client_lib_version, user_agent_str = get_user_agent(module="test-module") 30 | assert client_lib_version == "1.2.3-test-module" 31 | assert user_agent_str == "langchain-google-community/1.2.3-test-module" 32 | -------------------------------------------------------------------------------- /libs/genai/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /libs/genai/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /libs/genai/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests help 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 10 | 11 | test tests integration_test integration_tests: 12 | poetry run pytest $(TEST_FILE) 13 | 14 | check_imports: $(shell find langchain_google_genai -name '*.py') 15 | poetry run python ./scripts/check_imports.py $^ 16 | 17 | test_watch: 18 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 19 | 20 | # Run unit tests and generate a coverage report. 21 | coverage: 22 | poetry run pytest --cov \ 23 | --cov-config=.coveragerc \ 24 | --cov-report xml \ 25 | --cov-report term-missing:skip-covered \ 26 | $(TEST_FILE) 27 | 28 | ###################### 29 | # LINTING AND FORMATTING 30 | ###################### 31 | 32 | # Define a variable for Python and notebook files. 33 | PYTHON_FILES=. 34 | MYPY_CACHE=.mypy_cache 35 | lint format: PYTHON_FILES=. 36 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') 37 | lint_package: PYTHON_FILES=langchain_google_genai 38 | lint_tests: PYTHON_FILES=tests 39 | lint_tests: MYPY_CACHE=.mypy_cache_test 40 | 41 | lint lint_diff lint_package lint_tests: 42 | ./scripts/lint_imports.sh 43 | poetry run ruff check . 
44 | poetry run ruff format $(PYTHON_FILES) --diff 45 | poetry run ruff check --select I $(PYTHON_FILES) 46 | mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) 47 | 48 | format format_diff: 49 | poetry run ruff format $(PYTHON_FILES) 50 | poetry run ruff check --select I --fix $(PYTHON_FILES) 51 | 52 | spell_check: 53 | poetry run codespell --toml pyproject.toml 54 | 55 | spell_fix: 56 | poetry run codespell --toml pyproject.toml -w 57 | 58 | check_imports: $(shell find langchain_google_genai -name '*.py') 59 | poetry run python ./scripts/check_imports.py $^ 60 | 61 | ###################### 62 | # HELP 63 | ###################### 64 | 65 | help: 66 | @echo '----' 67 | @echo 'check_imports - check imports' 68 | @echo 'format - run code formatters' 69 | @echo 'lint - run linters' 70 | @echo 'test - run unit tests' 71 | @echo 'tests - run unit tests' 72 | @echo 'integration_test - run integration tests(NOTE: "export GOOGLE_API_KEY=..." is needed.)' 73 | @echo 'test TEST_FILE= - run all tests in file' 74 | -------------------------------------------------------------------------------- /libs/genai/langchain_google_genai/__init__.py: -------------------------------------------------------------------------------- 1 | """**LangChain Google Generative AI Integration** 2 | 3 | This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities. 4 | 5 | **Chat Models** 6 | 7 | The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications. 8 | 9 | **LLMs** 10 | 11 | The `GoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini LLMs. 
It allows users to generate text using a specified Gemini model. 12 | 13 | **Embeddings** 14 | 15 | The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models. 16 | These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more. 17 | **Installation** 18 | 19 | To install the package, use pip: 20 | 21 | ```python 22 | pip install -U langchain-google-genai 23 | ``` 24 | ## Using Chat Models 25 | 26 | After setting up your environment with the required API key, you can interact with the Google Gemini models. 27 | 28 | ```python 29 | from langchain_google_genai import ChatGoogleGenerativeAI 30 | 31 | llm = ChatGoogleGenerativeAI(model="gemini-pro") 32 | llm.invoke("Sing a ballad of LangChain.") 33 | ``` 34 | 35 | ## Using LLMs 36 | 37 | The package also supports generating text with Google's models. 38 | 39 | ```python 40 | from langchain_google_genai import GoogleGenerativeAI 41 | 42 | llm = GoogleGenerativeAI(model="gemini-pro") 43 | llm.invoke("Once upon a time, a library called LangChain") 44 | ``` 45 | 46 | ## Embedding Generation 47 | 48 | The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications. 
49 | 50 | ```python 51 | from langchain_google_genai import GoogleGenerativeAIEmbeddings 52 | 53 | embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") 54 | embeddings.embed_query("hello, world!") 55 | ``` 56 | """ # noqa: E501 57 | 58 | from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory, Modality 59 | from langchain_google_genai.chat_models import ChatGoogleGenerativeAI 60 | from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings 61 | from langchain_google_genai.genai_aqa import ( 62 | AqaInput, 63 | AqaOutput, 64 | GenAIAqa, 65 | ) 66 | from langchain_google_genai.google_vector_store import ( 67 | DoesNotExistsException, 68 | GoogleVectorStore, 69 | ) 70 | from langchain_google_genai.llms import GoogleGenerativeAI 71 | 72 | __all__ = [ 73 | "AqaInput", 74 | "AqaOutput", 75 | "ChatGoogleGenerativeAI", 76 | "DoesNotExistsException", 77 | "GenAIAqa", 78 | "GoogleGenerativeAIEmbeddings", 79 | "GoogleGenerativeAI", 80 | "GoogleVectorStore", 81 | "HarmBlockThreshold", 82 | "HarmCategory", 83 | "Modality", 84 | "DoesNotExistsException", 85 | ] 86 | -------------------------------------------------------------------------------- /libs/genai/langchain_google_genai/_enums.py: -------------------------------------------------------------------------------- 1 | import google.ai.generativelanguage_v1beta as genai 2 | 3 | HarmBlockThreshold = genai.SafetySetting.HarmBlockThreshold 4 | HarmCategory = genai.HarmCategory 5 | Modality = genai.GenerationConfig.Modality 6 | 7 | __all__ = ["HarmBlockThreshold", "HarmCategory", "Modality"] 8 | -------------------------------------------------------------------------------- /libs/genai/langchain_google_genai/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/genai/langchain_google_genai/py.typed 
-------------------------------------------------------------------------------- /libs/genai/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "langchain-google-genai" 3 | version = "2.1.5" 4 | description = "An integration package connecting Google's genai package and LangChain" 5 | authors = [] 6 | readme = "README.md" 7 | repository = "https://github.com/langchain-ai/langchain-google" 8 | license = "MIT" 9 | 10 | [tool.poetry.urls] 11 | "Source Code" = "https://github.com/langchain-ai/langchain-google/tree/main/libs/genai" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9,<4.0" 15 | langchain-core = "^0.3.62" 16 | google-ai-generativelanguage = "^0.6.18" 17 | pydantic = ">=2,<3" 18 | filetype = "^1.2.0" 19 | 20 | [tool.poetry.group.test] 21 | optional = true 22 | 23 | [tool.poetry.group.test.dependencies] 24 | pytest = "^7.3.0" 25 | freezegun = "^1.2.2" 26 | pytest-mock = "^3.10.0" 27 | syrupy = "^4.0.2" 28 | pytest-watcher = "^0.3.4" 29 | pytest-asyncio = "^0.21.1" 30 | pytest-retry = "^1.7.0" 31 | numpy = ">=1.26.2" 32 | langchain-tests = "0.3.19" 33 | 34 | [tool.codespell] 35 | ignore-words-list = "rouge" 36 | 37 | 38 | [tool.poetry.group.codespell] 39 | optional = true 40 | 41 | [tool.poetry.group.codespell.dependencies] 42 | codespell = "^2.2.0" 43 | 44 | 45 | [tool.poetry.group.test_integration] 46 | optional = true 47 | 48 | [tool.poetry.group.test_integration.dependencies] 49 | pytest = "^7.3.0" 50 | 51 | 52 | [tool.poetry.group.lint] 53 | optional = true 54 | 55 | [tool.poetry.group.lint.dependencies] 56 | ruff = "^0.1.5" 57 | 58 | 59 | [tool.poetry.group.typing.dependencies] 60 | mypy = "^1.10" 61 | types-requests = "^2.28.11.5" 62 | types-google-cloud-ndb = "^2.2.0.1" 63 | types-protobuf = "^4.24.0.20240302" 64 | numpy = ">=1.26.2" 65 | 66 | 67 | [tool.poetry.group.dev] 68 | optional = true 69 | 70 | [tool.poetry.group.dev.dependencies] 71 | types-requests = 
"^2.31.0.10" 72 | types-google-cloud-ndb = "^2.2.0.1" 73 | 74 | [tool.ruff.lint] 75 | select = [ 76 | "E", # pycodestyle 77 | "F", # pyflakes 78 | "I", # isort 79 | ] 80 | 81 | [tool.mypy] 82 | disallow_untyped_defs = "True" 83 | 84 | [tool.coverage.run] 85 | omit = ["tests/*"] 86 | 87 | [build-system] 88 | requires = ["poetry-core>=1.0.0"] 89 | build-backend = "poetry.core.masonry.api" 90 | 91 | [tool.pytest.ini_options] 92 | # --strict-markers will raise errors on unknown marks. 93 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 94 | # 95 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 96 | # --strict-config any warnings encountered while parsing the `pytest` 97 | # section of the configuration file raise errors. 98 | # 99 | # https://github.com/tophat/syrupy 100 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 101 | #addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 102 | # Registering custom markers. 
103 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 104 | markers = [ 105 | "requires: mark tests as requiring a specific library", 106 | "asyncio: mark tests as requiring asyncio", 107 | "compile: mark placeholder test used to compile integration tests without running them", 108 | ] 109 | asyncio_mode = "auto" 110 | -------------------------------------------------------------------------------- /libs/genai/scripts/check_imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | from importlib.machinery import SourceFileLoader 4 | 5 | if __name__ == "__main__": 6 | files = sys.argv[1:] 7 | has_failure = False 8 | for file in files: 9 | try: 10 | SourceFileLoader("x", file).load_module() 11 | except Exception: 12 | has_faillure = True 13 | print(file) 14 | traceback.print_exc() 15 | print() 16 | 17 | sys.exit(1 if has_failure else 0) 18 | -------------------------------------------------------------------------------- /libs/genai/scripts/lint_imports.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | # Initialize a variable to keep track of errors 6 | errors=0 7 | 8 | # make sure not importing from langchain or langchain_experimental 9 | git --no-pager grep '^from langchain\.' . && errors=$((errors+1)) 10 | git --no-pager grep '^from langchain_experimental\.' . 
&& errors=$((errors+1)) 11 | 12 | # Decide on an exit status based on the errors 13 | if [ "$errors" -gt 0 ]; then 14 | exit 1 15 | else 16 | exit 0 17 | fi 18 | -------------------------------------------------------------------------------- /libs/genai/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/genai/tests/__init__.py -------------------------------------------------------------------------------- /libs/genai/tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests configuration to be executed before tests execution. 3 | """ 4 | 5 | from typing import List 6 | 7 | import pytest 8 | 9 | _RELEASE_FLAG = "release" 10 | _GPU_FLAG = "gpu" 11 | _LONG_FLAG = "long" 12 | _EXTENDED_FLAG = "extended" 13 | 14 | _PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG] 15 | 16 | 17 | def pytest_addoption(parser: pytest.Parser) -> None: 18 | """ 19 | Add flags accepted by pytest CLI. 20 | 21 | Args: 22 | parser: The pytest parser object. 23 | 24 | Returns: 25 | 26 | """ 27 | for flag in _PYTEST_FLAGS: 28 | parser.addoption( 29 | f"--{flag}", action="store_true", default=False, help=f"run {flag} tests" 30 | ) 31 | 32 | 33 | def pytest_configure(config: pytest.Config) -> None: 34 | """ 35 | Add pytest custom configuration. 36 | 37 | Args: 38 | config: The pytest config object. 39 | 40 | Returns: 41 | """ 42 | for flag in _PYTEST_FLAGS: 43 | config.addinivalue_line( 44 | "markers", f"{flag}: mark test to run as {flag} only test" 45 | ) 46 | 47 | 48 | def pytest_collection_modifyitems( 49 | config: pytest.Config, items: List[pytest.Item] 50 | ) -> None: 51 | """ 52 | Skip tests with a marker from our list that were not explicitly invoked. 53 | 54 | Args: 55 | config: The pytest config object. 56 | items: The list of tests to be executed. 
57 | 58 | Returns: 59 | """ 60 | for item in items: 61 | keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS)) 62 | if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)): 63 | skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run") 64 | item.add_marker(skip) 65 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/.env.example: -------------------------------------------------------------------------------- 1 | PROJECT_ID=project_id -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/genai/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/terraform/main.tf: -------------------------------------------------------------------------------- 1 | module "cloudbuild" { 2 | source = "./../../../../../terraform/cloudbuild" 3 | 4 | library = "genai" 5 | project_id = "" 6 | cloudbuildv2_repository_id = "" 7 | cloudbuild_env_vars = { 8 | } 9 | cloudbuild_secret_vars = { 10 | GOOGLE_API_KEY = "" 11 | } 12 | } -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from langchain_core.callbacks import BaseCallbackHandler 4 | from langchain_core.outputs import LLMResult 5 | from langchain_core.prompts import PromptTemplate 6 | 7 | from langchain_google_genai import ChatGoogleGenerativeAI 8 | 9 | 10 | class StreamingLLMCallbackHandler(BaseCallbackHandler): 11 | def __init__(self, **kwargs: Any): 12 | 
super().__init__(**kwargs) 13 | self.tokens: List[Any] = [] 14 | self.generations: List[Any] = [] 15 | 16 | def on_llm_new_token(self, token: str, **kwargs: Any) -> None: 17 | self.tokens.append(token) 18 | 19 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: 20 | self.generations.append(response.generations[0][0].text) 21 | 22 | 23 | def test_streaming_callback() -> None: 24 | prompt_template = "Tell me details about the Company {name} with 2 bullet point?" 25 | cb = StreamingLLMCallbackHandler() 26 | llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001", callbacks=[cb]) 27 | llm_chain = PromptTemplate.from_template(prompt_template) | llm 28 | for t in llm_chain.stream({"name": "Google"}): 29 | pass 30 | assert len(cb.tokens) > 1 31 | assert len(cb.generations) == 1 32 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_function_call.py: -------------------------------------------------------------------------------- 1 | """Test ChatGoogleGenerativeAI function call.""" 2 | 3 | import json 4 | 5 | from langchain_core.messages import AIMessage 6 | from langchain_core.tools import tool 7 | from pydantic import BaseModel 8 | 9 | from langchain_google_genai.chat_models import ( 10 | ChatGoogleGenerativeAI, 11 | ) 12 | 13 | 14 | def test_function_call() -> None: 15 | functions = [ 16 | { 17 | "name": "get_weather", 18 | "description": "Determine weather in my location", 19 | "parameters": { 20 | "type": "object", 21 | "properties": { 22 | "location": { 23 | "type": "string", 24 | 
"description": "The city and state e.g. San Francisco, CA", 25 | }, 26 | "unit": {"type": "string", "enum": ["c", "f"]}, 27 | }, 28 | "required": ["location"], 29 | }, 30 | } 31 | ] 32 | llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind( 33 | functions=functions 34 | ) 35 | res = llm.invoke("what weather is today in san francisco?") 36 | assert res 37 | assert res.additional_kwargs 38 | assert "function_call" in res.additional_kwargs 39 | assert "get_weather" == res.additional_kwargs["function_call"]["name"] 40 | arguments_str = res.additional_kwargs["function_call"]["arguments"] 41 | assert isinstance(arguments_str, str) 42 | arguments = json.loads(arguments_str) 43 | assert "location" in arguments 44 | 45 | 46 | def test_tool_call() -> None: 47 | @tool 48 | def search_tool(query: str) -> str: 49 | """Searches the web for `query` and returns the result.""" 50 | raise NotImplementedError 51 | 52 | llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind( 53 | functions=[search_tool] 54 | ) 55 | response = llm.invoke("weather in san francisco") 56 | assert isinstance(response, AIMessage) 57 | assert isinstance(response.content, str) 58 | assert response.content == "" 59 | function_call = response.additional_kwargs.get("function_call") 60 | assert function_call 61 | assert function_call["name"] == "search_tool" 62 | arguments_str = function_call.get("arguments") 63 | assert arguments_str 64 | arguments = json.loads(arguments_str) 65 | assert "query" in arguments 66 | 67 | 68 | class MyModel(BaseModel): 69 | name: str 70 | age: int 71 | 72 | 73 | def test_pydantic_call() -> None: 74 | llm = ChatGoogleGenerativeAI(model="models/gemini-2.0-flash-001").bind( 75 | functions=[MyModel] 76 | ) 77 | response = llm.invoke("my name is Erick and I am 27 years old") 78 | assert isinstance(response, AIMessage) 79 | assert isinstance(response.content, str) 80 | assert response.content == "" 81 | function_call = 
response.additional_kwargs.get("function_call") 82 | assert function_call 83 | assert function_call["name"] == "MyModel" 84 | arguments_str = function_call.get("arguments") 85 | assert arguments_str 86 | arguments = json.loads(arguments_str) 87 | assert arguments == { 88 | "name": "Erick", 89 | "age": 27.0, 90 | } 91 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_llms.py: -------------------------------------------------------------------------------- 1 | """Test Google GenerativeAI API wrapper. 2 | 3 | Note: This test must be run with the GOOGLE_API_KEY environment variable set to a 4 | valid API key. 5 | """ 6 | 7 | from typing import Dict, Generator 8 | 9 | import pytest 10 | from langchain_core.outputs import LLMResult 11 | 12 | from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory 13 | 14 | model_names = ["gemini-1.5-flash-latest"] 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "model_name", 19 | model_names, 20 | ) 21 | def test_google_generativeai_call(model_name: str) -> None: 22 | """Test valid call to Google GenerativeAI text API.""" 23 | if model_name: 24 | llm = GoogleGenerativeAI(max_tokens=10, model=model_name) 25 | else: 26 | llm = GoogleGenerativeAI(max_tokens=10) # type: ignore[call-arg] 27 | output = llm("Say foo:") 28 | assert isinstance(output, str) 29 | assert llm._llm_type == "google_gemini" 30 | assert llm.client.model == f"models/{model_name}" 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "model_name", 35 | model_names, 36 | ) 37 | def test_google_generativeai_generate(model_name: str) -> None: 38 | llm = GoogleGenerativeAI(temperature=0.3, model=model_name) 39 | output = llm.generate(["Say foo:"]) 40 | assert isinstance(output, LLMResult) 41 | assert len(output.generations) == 1 42 | assert len(output.generations[0]) == 1 43 | # check the usage data 44 | generation_info = output.generations[0][0].generation_info 45 | assert 
def test_generativeai_stream() -> None:
    """Streaming a prompt must yield at least one string chunk."""
    model = GoogleGenerativeAI(temperature=0, model="gemini-1.5-flash-latest")
    chunks = [piece for piece in model.stream("Please say foo:")]
    assert isinstance(chunks[0], str)
safety_settings=safety_settings, 96 | temperature=0, 97 | ) 98 | output = llm.generate(prompts=["how to make a bomb?"]) 99 | assert isinstance(output, LLMResult) 100 | assert len(output.generations[0]) > 0 101 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_standard.py: -------------------------------------------------------------------------------- 1 | """Standard LangChain interface tests""" 2 | 3 | from typing import Dict, List, Literal, Type 4 | 5 | import pytest 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_core.rate_limiters import InMemoryRateLimiter 8 | from langchain_core.tools import BaseTool 9 | from langchain_tests.integration_tests import ChatModelIntegrationTests 10 | 11 | from langchain_google_genai import ChatGoogleGenerativeAI 12 | 13 | rate_limiter = InMemoryRateLimiter(requests_per_second=0.25) 14 | 15 | 16 | class TestGeminiAI2Standard(ChatModelIntegrationTests): 17 | @property 18 | def chat_model_class(self) -> Type[BaseChatModel]: 19 | return ChatGoogleGenerativeAI 20 | 21 | @property 22 | def chat_model_params(self) -> dict: 23 | return { 24 | "model": "models/gemini-2.0-flash-001", 25 | "rate_limiter": rate_limiter, 26 | } 27 | 28 | @property 29 | def supports_image_inputs(self) -> bool: 30 | return True 31 | 32 | @property 33 | def supports_image_urls(self) -> bool: 34 | return True 35 | 36 | @property 37 | def supports_image_tool_message(self) -> bool: 38 | return True 39 | 40 | @property 41 | def supports_pdf_inputs(self) -> bool: 42 | return True 43 | 44 | @property 45 | def supports_audio_inputs(self) -> bool: 46 | return True 47 | 48 | @pytest.mark.xfail( 49 | reason="Likely a bug in genai: prompt_token_count inconsistent in final chunk." 
50 | ) 51 | def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: 52 | super().test_usage_metadata_streaming(model) 53 | 54 | @pytest.mark.xfail(reason="investigate") 55 | def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None: 56 | super().test_bind_runnables_as_tools(model) 57 | 58 | @pytest.mark.xfail(reason=("investigate")) 59 | def test_tool_calling_with_no_arguments(self, model: BaseChatModel) -> None: 60 | super().test_tool_calling_with_no_arguments(model) 61 | 62 | 63 | class TestGeminiAIStandard(ChatModelIntegrationTests): 64 | @property 65 | def chat_model_class(self) -> Type[BaseChatModel]: 66 | return ChatGoogleGenerativeAI 67 | 68 | @property 69 | def chat_model_params(self) -> dict: 70 | return { 71 | "model": "models/gemini-1.5-pro-latest", 72 | "rate_limiter": rate_limiter, 73 | } 74 | 75 | @pytest.mark.xfail(reason="Not yet supported") 76 | def test_tool_message_histories_list_content( 77 | self, model: BaseChatModel, my_adder_tool: BaseTool 78 | ) -> None: 79 | super().test_tool_message_histories_list_content(model, my_adder_tool) 80 | 81 | @pytest.mark.xfail( 82 | reason="Investigate: prompt_token_count inconsistent in final chunk." 
83 | ) 84 | def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: 85 | super().test_usage_metadata_streaming(model) 86 | 87 | @property 88 | def supported_usage_metadata_details( 89 | self, 90 | ) -> Dict[ 91 | Literal["invoke", "stream"], 92 | List[ 93 | Literal[ 94 | "audio_input", 95 | "audio_output", 96 | "reasoning_output", 97 | "cache_read_input", 98 | "cache_creation_input", 99 | ] 100 | ], 101 | ]: 102 | return {"invoke": [], "stream": []} 103 | -------------------------------------------------------------------------------- /libs/genai/tests/integration_tests/test_tools.py: -------------------------------------------------------------------------------- 1 | from langchain_core.tools import tool 2 | 3 | from langchain_google_genai import ChatGoogleGenerativeAI 4 | 5 | 6 | @tool 7 | def check_weather(location: str) -> str: 8 | """Return the weather forecast for the specified location.""" 9 | return f"It's always sunny in {location}" 10 | 11 | 12 | @tool 13 | def check_live_traffic(location: str) -> str: 14 | """Return the live traffic for the specified location.""" 15 | return f"The traffic is standstill in {location}" 16 | 17 | 18 | @tool 19 | def check_tennis_score(player: str) -> str: 20 | """Return the latest player's tennis score.""" 21 | return f"{player} is currently winning 6-0" 22 | 23 | 24 | def test_multiple_tools() -> None: 25 | tools = [check_weather, check_live_traffic, check_tennis_score] 26 | 27 | model = ChatGoogleGenerativeAI( 28 | model="gemini-1.5-flash-001", 29 | ) 30 | 31 | model_with_tools = model.bind_tools(tools) 32 | 33 | input = "What is the latest tennis score for Leonid?" 
34 | 35 | result = model_with_tools.invoke(input) 36 | assert len(result.tool_calls) == 1 # type: ignore 37 | assert result.tool_calls[0]["name"] == "check_tennis_score" # type: ignore 38 | -------------------------------------------------------------------------------- /libs/genai/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/genai/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: TestGeminiAIStandard.test_serdes[serialized] 3 | dict({ 4 | 'id': list([ 5 | 'langchain_google_genai', 6 | 'chat_models', 7 | 'ChatGoogleGenerativeAI', 8 | ]), 9 | 'kwargs': dict({ 10 | 'default_metadata': list([ 11 | ]), 12 | 'google_api_key': dict({ 13 | 'id': list([ 14 | 'GOOGLE_API_KEY', 15 | ]), 16 | 'lc': 1, 17 | 'type': 'secret', 18 | }), 19 | 'max_output_tokens': 100, 20 | 'max_retries': 2, 21 | 'model': 'models/gemini-1.0-pro-001', 22 | 'model_kwargs': dict({ 23 | 'stop': list([ 24 | ]), 25 | }), 26 | 'n': 1, 27 | 'temperature': 0.0, 28 | 'timeout': 60.0, 29 | }), 30 | 'lc': 1, 31 | 'name': 'ChatGoogleGenerativeAI', 32 | 'type': 'constructor', 33 | }) 34 | # --- 35 | # name: TestGemini_15_AIStandard.test_serdes[serialized] 36 | dict({ 37 | 'id': list([ 38 | 'langchain_google_genai', 39 | 'chat_models', 40 | 'ChatGoogleGenerativeAI', 41 | ]), 42 | 'kwargs': dict({ 43 | 'default_metadata': list([ 44 | ]), 45 | 'google_api_key': dict({ 46 | 'id': list([ 47 | 'GOOGLE_API_KEY', 48 | ]), 49 | 'lc': 1, 50 | 'type': 'secret', 51 | }), 52 | 'max_output_tokens': 100, 53 | 'max_retries': 2, 54 | 'model': 'models/gemini-1.5-pro-001', 55 | 'model_kwargs': dict({ 56 | 'stop': 
def test_get_user_agent_with_telemetry_env_variable() -> None:
    """The telemetry env flag appends the remote reasoning engine suffix."""
    version_target = "langchain_google_genai._common.metadata.version"
    environ_target = "langchain_google_genai._common.os.environ.get"
    with patch(version_target) as mock_version:
        with patch(environ_target) as mock_environ_get:
            mock_version.return_value = "1.2.3"
            mock_environ_get.return_value = True
            client_lib_version, user_agent_str = get_user_agent(module="test-module")
    assert client_lib_version == "1.2.3-test-module+remote_reasoning_engine"
    assert user_agent_str == (
        "langchain-google-genai/1.2.3-test-module+remote_reasoning_engine"
    )


def test_get_user_agent_without_telemetry_env_variable() -> None:
    """Without the telemetry env flag the plain user agent is produced."""
    version_target = "langchain_google_genai._common.metadata.version"
    environ_target = "langchain_google_genai._common.os.environ.get"
    with patch(version_target) as mock_version:
        with patch(environ_target) as mock_environ_get:
            mock_version.return_value = "1.2.3"
            mock_environ_get.return_value = False
            client_lib_version, user_agent_str = get_user_agent(module="test-module")
    assert client_lib_version == "1.2.3-test-module"
    assert user_agent_str == "langchain-google-genai/1.2.3-test-module"
google.ai.generativelanguage as genai 4 | import pytest 5 | 6 | from langchain_google_genai import ( 7 | AqaInput, 8 | GenAIAqa, 9 | ) 10 | from langchain_google_genai import _genai_extension as genaix 11 | 12 | # Make sure the tests do not hit actual production servers. 13 | genaix.set_config( 14 | genaix.Config( 15 | api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend", 16 | testing=True, 17 | ) 18 | ) 19 | 20 | 21 | @pytest.mark.requires("google.ai.generativelanguage") 22 | def test_it_can_be_constructed() -> None: 23 | GenAIAqa() 24 | 25 | 26 | @pytest.mark.requires("google.ai.generativelanguage") 27 | @patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") 28 | def test_invoke(mock_generate_answer: MagicMock) -> None: 29 | # Arrange 30 | mock_generate_answer.return_value = genai.GenerateAnswerResponse( 31 | answer=genai.Candidate( 32 | content=genai.Content(parts=[genai.Part(text="42")]), 33 | grounding_attributions=[ 34 | genai.GroundingAttribution( 35 | content=genai.Content( 36 | parts=[genai.Part(text="Meaning of life is 42.")] 37 | ), 38 | source_id=genai.AttributionSourceId( 39 | grounding_passage=genai.AttributionSourceId.GroundingPassageId( 40 | passage_id="corpora/123/documents/456/chunks/789", 41 | part_index=0, 42 | ) 43 | ), 44 | ), 45 | ], 46 | finish_reason=genai.Candidate.FinishReason.STOP, 47 | ), 48 | answerable_probability=0.7, 49 | ) 50 | 51 | # Act 52 | aqa = GenAIAqa( 53 | temperature=0.5, 54 | answer_style=genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE, 55 | safety_settings=[ 56 | genai.SafetySetting( 57 | category=genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, 58 | threshold=genai.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, 59 | ) 60 | ], 61 | ) 62 | output = aqa.invoke( 63 | input=AqaInput( 64 | prompt="What is the meaning of life?", 65 | source_passages=["It's 42."], 66 | ) 67 | ) 68 | 69 | # Assert 70 | assert output.answer == "42" 71 | assert output.attributed_passages == ["Meaning 
"""Smoke test: the package's public API surface matches expectations."""

from langchain_google_genai import __all__

# Names the package is expected to export.
# NOTE: the original list contained "DoesNotExistsException" twice, which made
# the sorted-list comparison depend on __all__ carrying the same duplicate.
EXPECTED_ALL = [
    "AqaInput",
    "AqaOutput",
    "ChatGoogleGenerativeAI",
    "DoesNotExistsException",
    "GenAIAqa",
    "GoogleGenerativeAIEmbeddings",
    "GoogleGenerativeAI",
    "GoogleVectorStore",
    "HarmBlockThreshold",
    "HarmCategory",
    "Modality",
]


def test_all_imports() -> None:
    """`__all__` must expose exactly the expected names (duplicates ignored)."""
    assert sorted(set(EXPECTED_ALL)) == sorted(set(__all__))
GoogleGenerativeAI(model="gemini-pro", google_api_key="foo") # type: ignore[call-arg] 9 | ls_params = llm._get_ls_params() 10 | assert ls_params == { 11 | "ls_provider": "google_genai", 12 | "ls_model_type": "llm", 13 | "ls_model_name": "gemini-pro", 14 | "ls_temperature": 0.7, 15 | } 16 | 17 | llm = GoogleGenerativeAI( 18 | model="gemini-pro", 19 | temperature=0.1, 20 | max_output_tokens=10, 21 | google_api_key="foo", # type: ignore[call-arg] 22 | ) 23 | ls_params = llm._get_ls_params() 24 | assert ls_params == { 25 | "ls_provider": "google_genai", 26 | "ls_model_type": "llm", 27 | "ls_model_name": "gemini-pro", 28 | "ls_temperature": 0.1, 29 | "ls_max_tokens": 10, 30 | } 31 | 32 | # Test initialization with an invalid argument to check warning 33 | with patch("langchain_google_genai.llms.logger.warning") as mock_warning: 34 | llm = GoogleGenerativeAI( 35 | model="gemini-pro", 36 | google_api_key="foo", # type: ignore[call-arg] 37 | safety_setting={ 38 | "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_LOW_AND_ABOVE" 39 | }, # Invalid arg 40 | ) 41 | assert llm.model == "gemini-pro" 42 | ls_params = llm._get_ls_params() 43 | assert ls_params["ls_model_name"] == "gemini-pro" 44 | mock_warning.assert_called_once() 45 | call_args = mock_warning.call_args[0][0] 46 | assert "Unexpected argument 'safety_setting'" in call_args 47 | assert "Did you mean: 'safety_settings'?" 
in call_args 48 | -------------------------------------------------------------------------------- /libs/genai/tests/unit_tests/test_standard.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type 2 | 3 | from langchain_core.language_models import BaseChatModel 4 | from langchain_tests.unit_tests import ChatModelUnitTests 5 | 6 | from langchain_google_genai import ChatGoogleGenerativeAI 7 | 8 | 9 | class TestGeminiAIStandard(ChatModelUnitTests): 10 | @property 11 | def chat_model_class(self) -> Type[BaseChatModel]: 12 | return ChatGoogleGenerativeAI 13 | 14 | @property 15 | def chat_model_params(self) -> dict: 16 | return {"model": "models/gemini-1.0-pro-001"} 17 | 18 | @property 19 | def init_from_env_params(self) -> Tuple[dict, dict, dict]: 20 | return ( 21 | {"GOOGLE_API_KEY": "api_key"}, 22 | self.chat_model_params, 23 | {"google_api_key": "api_key"}, 24 | ) 25 | 26 | 27 | class TestGemini_15_AIStandard(ChatModelUnitTests): 28 | @property 29 | def chat_model_class(self) -> Type[BaseChatModel]: 30 | return ChatGoogleGenerativeAI 31 | 32 | @property 33 | def chat_model_params(self) -> dict: 34 | return {"model": "models/gemini-1.5-pro-001"} 35 | 36 | @property 37 | def init_from_env_params(self) -> Tuple[dict, dict, dict]: 38 | return ( 39 | {"GOOGLE_API_KEY": "api_key"}, 40 | self.chat_model_params, 41 | {"google_api_key": "api_key"}, 42 | ) 43 | -------------------------------------------------------------------------------- /libs/vertexai/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache_test 3 | -------------------------------------------------------------------------------- /libs/vertexai/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LangChain, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /libs/vertexai/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all format lint test tests integration_tests docker_tests help extended_tests 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | # Define a variable for the test file path. 7 | TEST_FILE ?= tests/unit_tests/ 8 | 9 | integration_test integration_tests: TEST_FILE = tests/integration_tests/ 10 | 11 | test tests integration_test integration_tests: 12 | poetry run pytest --release $(TEST_FILE) 13 | 14 | test_watch: 15 | poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) 16 | 17 | # Run unit tests and generate a coverage report. 
coverage:
	poetry run pytest --cov \
		--cov-config=.coveragerc \
		--cov-report xml \
		--cov-report term-missing:skip-covered \
		$(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/google-vertexai --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_google_vertexai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

# mkdir -p so re-running lint does not error when the mypy cache dir exists.
lint lint_diff lint_package lint_tests:
	./scripts/lint_imports.sh
	poetry run ruff check .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff check --select I $(PYTHON_FILES)
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

# Modern ruff requires an explicit subcommand: bare `ruff --select I --fix`
# (the old implicit form) is rejected, so invoke `ruff check` here.
format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_google_vertexai -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE= - run all tests in file'
contains the LangChain integrations for Google Cloud generative models. 4 | 5 | ## Contents 6 | 7 | 1. [Installation](#installation) 8 | 2. [Chat Models](#chat-models) 9 | * [Multimodal inputs](#multimodal-inputs) 10 | 3. [Embeddings](#embeddings) 11 | 4. [LLMs](#llms) 12 | 5. [Code Generation](#code-generation) 13 | * [Example: Generate a Python function](#example-generate-a-python-function) 14 | * [Example: Generate JavaScript code](#example-generate-javascript-code) 15 | * [Notes](#notes) 16 | 17 | ## Installation 18 | 19 | ```bash 20 | pip install -U langchain-google-vertexai 21 | ``` 22 | 23 | ## Chat Models 24 | 25 | `ChatVertexAI` class exposes models such as `gemini-pro` and other Gemini variants. 26 | 27 | To use, you should have a Google Cloud project with APIs enabled, and configured credentials. Initialize the model as: 28 | 29 | ```python 30 | from langchain_google_vertexai import ChatVertexAI 31 | 32 | llm = ChatVertexAI(model_name="gemini-pro") 33 | llm.invoke("Sing a ballad of LangChain.") 34 | ``` 35 | 36 | ### Multimodal inputs 37 | 38 | Gemini supports image inputs when providing a single chat message. 
Example: 39 | 40 | ```python 41 | from langchain_core.messages import HumanMessage 42 | from langchain_google_vertexai import ChatVertexAI 43 | 44 | llm = ChatVertexAI(model_name="gemini-2.0-flash-001") 45 | message = HumanMessage( 46 | content=[ 47 | { 48 | "type": "text", 49 | "text": "What's in this image?", 50 | }, 51 | {"type": "image_url", "image_url": {"url": "https://picsum.photos/seed/picsum/200/300"}}, 52 | ] 53 | ) 54 | llm.invoke([message]) 55 | ``` 56 | 57 | The value of `image_url` can be: 58 | 59 | * A public image URL 60 | * An accessible Google Cloud Storage (GCS) file (e.g., `"gs://path/to/file.png"`) 61 | * A base64 encoded image (e.g., `"data:image/png;base64,abcd124"`) 62 | 63 | ## Embeddings 64 | 65 | Google Cloud embeddings models can be used as: 66 | 67 | ```python 68 | from langchain_google_vertexai import VertexAIEmbeddings 69 | 70 | embeddings = VertexAIEmbeddings() 71 | embeddings.embed_query("hello, world!") 72 | ``` 73 | 74 | ## LLMs 75 | 76 | Use Google Cloud's generative AI models as LangChain LLMs: 77 | 78 | ```python 79 | from langchain_core.prompts import PromptTemplate 80 | from langchain_google_vertexai import ChatVertexAI 81 | 82 | template = """Question: {question} 83 | 84 | Answer: Let's think step by step.""" 85 | prompt = PromptTemplate.from_template(template) 86 | 87 | llm = ChatVertexAI(model_name="gemini-pro") 88 | chain = prompt | llm 89 | 90 | question = "Who was the president of the USA in 1994?" 91 | print(chain.invoke({"question": question})) 92 | ``` 93 | 94 | ## Code Generation 95 | 96 | You can use Gemini models for code generation tasks to generate code snippets, functions, or scripts in various programming languages.
97 | 98 | ### Example: Generate a Python function 99 | 100 | ```python 101 | from langchain_google_vertexai import ChatVertexAI 102 | 103 | llm = ChatVertexAI(model_name="gemini-pro", temperature=0.3, max_output_tokens=1000) 104 | 105 | prompt = "Write a Python function that checks if a string is a valid email address." 106 | 107 | generated_code = llm.invoke(prompt) 108 | print(generated_code) 109 | ``` 110 | 111 | ### Example: Generate JavaScript code 112 | 113 | ```python 114 | from langchain_google_vertexai import ChatVertexAI 115 | 116 | llm = ChatVertexAI(model_name="gemini-pro", temperature=0.3, max_output_tokens=1000) 117 | prompt_js = "Write a JavaScript function that returns the factorial of a number." 118 | 119 | print(llm.invoke(prompt_js)) 120 | ``` 121 | 122 | ### Notes 123 | 124 | * Adjust `temperature` to control creativity (higher values increase randomness). 125 | * Use `max_output_tokens` to limit the length of the generated code. 126 | * Gemini models are well-suited for code generation tasks with advanced understanding of programming concepts. 
from typing import Any, List, Optional, Type, Union

from langchain_core.messages import AIMessage, ToolCall
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from pydantic import BaseModel, ConfigDict


class ToolsOutputParser(BaseGenerationOutputParser):
    """Output parser that extracts tool calls from chat generations."""

    # Return a single call (or None) instead of a list.
    first_tool_only: bool = False
    # Return each call's args dict instead of the full tool-call dict.
    args_only: bool = False
    # When set, each call is instantiated into the schema matching its name.
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.
            partial: Unused; accepted for interface compatibility.

        Returns:
            Structured output: a list of tool calls / args / pydantic objects,
            or a single item (possibly None) when ``first_tool_only`` is set.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []

        message = result[0].message
        calls: List[Any] = []
        if isinstance(message, AIMessage) and message.tool_calls:
            calls = message.tool_calls
        elif isinstance(message.content, list):
            blocks: Any = message.content
            calls = _extract_tool_calls(blocks)

        if self.pydantic_schemas:
            calls = [self._pydantic_parse(c) for c in calls]
        elif self.args_only:
            calls = [c["args"] for c in calls]

        if self.first_tool_only:
            return calls[0] if calls else None
        return list(calls)

    def _pydantic_parse(self, call: dict) -> BaseModel:
        """Instantiate the pydantic schema whose class name matches the call."""
        by_name = {schema.__name__: schema for schema in self.pydantic_schemas or []}
        return by_name[call["name"]](**call["args"])


def _extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
    """Extract tool calls from a list of content blocks."""
    if not isinstance(content, list):
        return []
    calls: List[ToolCall] = []
    for block in content:
        # Plain-string blocks and non-tool blocks carry no tool call.
        if isinstance(block, str) or block["type"] != "tool_use":
            continue
        calls.append(
            tool_call(name=block["name"], args=block["input"], id=block["id"])
        )
    return calls
class VertexAICallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks VertexAI info."""

    # Running totals, aggregated over every completed LLM call.
    prompt_tokens: int = 0
    prompt_characters: int = 0
    completion_tokens: int = 0
    completion_characters: int = 0
    successful_requests: int = 0
    total_tokens: int = 0
    cached_tokens: int = 0

    def __init__(self) -> None:
        super().__init__()
        # Serializes counter updates; callbacks may fire from several threads.
        self._lock = threading.Lock()

    def __repr__(self) -> str:
        parts = [
            f"\tPrompt tokens: {self.prompt_tokens}",
            f"\tPrompt characters: {self.prompt_characters}",
            f"\tCompletion tokens: {self.completion_tokens}",
            f"\tCompletion characters: {self.completion_characters}",
            f"\tCached tokens: {self.cached_tokens}",
            f"\tTotal tokens: {self.total_tokens}",
            f"Successful requests: {self.successful_requests}",
        ]
        return "\n".join(parts) + "\n"

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Runs when LLM starts running."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Runs on new LLM token. Only available when streaming is enabled."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collects token usage."""
        new_completion_tokens = 0
        new_prompt_tokens = 0
        new_total_tokens = 0
        new_cached_tokens = 0
        new_completion_chars = 0
        new_prompt_chars = 0
        for candidates in response.generations:
            # Usage metadata is read from the first candidate only.
            if not candidates or not candidates[0].generation_info:
                continue
            usage = candidates[0].generation_info.get("usage_metadata", {})
            new_completion_tokens += usage.get("candidates_token_count", 0)
            new_prompt_tokens += usage.get("prompt_token_count", 0)
            new_total_tokens += usage.get("total_token_count", 0)
            new_cached_tokens += usage.get("cached_content_token_count", 0)
            new_completion_chars += usage.get("candidates_billable_characters", 0)
            new_prompt_chars += usage.get("prompt_billable_characters", 0)

        with self._lock:
            self.prompt_characters += new_prompt_chars
            self.prompt_tokens += new_prompt_tokens
            self.completion_characters += new_completion_chars
            self.completion_tokens += new_completion_tokens
            self.successful_requests += 1
            self.total_tokens += new_total_tokens
            self.cached_tokens += new_cached_tokens
def get_vertex_maas_model(model_name, **kwargs):
    """Return a corresponding Vertex MaaS instance.

    A factory method based on model's name.
    """
    # Dispatch on the model family; anything outside the known lists is
    # rejected with the same error as before.
    if model_name in _MISTRAL_MODELS:
        # Imported lazily so langchain_mistralai stays an optional dependency.
        from langchain_google_vertexai.model_garden_maas.mistral import (  # noqa: F401
            VertexModelGardenMistral,
        )

        return VertexModelGardenMistral(model=model_name, **kwargs)
    if model_name in _LLAMA_MODELS:
        return VertexModelGardenLlama(model=model_name, **kwargs)
    raise ValueError(f"model name {model_name} is not supported!")
def create_context_cache(
    model: ChatVertexAI,
    messages: List[BaseMessage],
    expire_time: Optional[datetime] = None,
    time_to_live: Optional[timedelta] = None,
    tools: Optional[_ToolsType] = None,
    tool_config: Optional[_ToolConfigDict] = None,
) -> str:
    """Creates a cache for content in some model.

    Args:
        model: ChatVertexAI model. Must be at least gemini-1.5 pro or flash.
        messages: List of messages to cache.
        expire_time: Timestamp of when this resource is considered expired.
            At most one of expire_time and ttl can be set. If neither is set,
            default TTL on the API side will be used (currently 1 hour).
        time_to_live: The TTL for this resource. If provided, the expiration
            time is computed: created_time + TTL.
            At most one of expire_time and ttl can be set. If neither is set,
            default TTL on the API side will be used (currently 1 hour).
        tools: A list of tool definitions to bind to this chat model.
            Can be a pydantic model, callable, or BaseTool. Pydantic
            models, callables, and BaseTools will be automatically converted to
            their schema dictionary representation.
        tool_config: Optional. Immutable. Tool config. This config is shared
            for all tools.

    Raises:
        ValueError: If the model doesn't support context caching.

    Returns:
        String with the identifier of the created cache.
    """

    if not model._is_gemini_advanced:
        # Typo fix: the message previously read "context catching".
        error_msg = f"Model {model.full_model_name} doesn't support context caching"
        raise ValueError(error_msg)

    system_instruction, contents = _parse_chat_history_gemini(
        messages, ImageBytesLoader(project=model.project)
    )

    if tool_config:
        tool_config = _format_tool_config(tool_config)

    if tools is not None:
        # The whole list is converted into a single gapic Tool definition.
        tools = [_format_to_gapic_tool(tools)]

    cached_content = caching.CachedContent.create(
        model_name=model.full_model_name,
        system_instruction=system_instruction,
        contents=contents,
        ttl=time_to_live,
        expire_time=expire_time,
        tool_config=tool_config,
        tools=tools,
    )

    return cached_content.name
import sys
import traceback
from importlib.machinery import SourceFileLoader


def main(files):
    """Try to import each given file; return 1 if any import fails, else 0.

    Args:
        files: Paths of Python source files to import-check.

    Returns:
        Process exit code: 1 when at least one file failed to import, else 0.
    """
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            # BUG FIX: this line previously assigned `has_faillure` (typo),
            # so the failure flag was never set and the script always
            # exited 0 even when imports failed.
            has_failure = True
            print(file)
            traceback.print_exc()
            print()

    return 1 if has_failure else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
3 | """ 4 | 5 | from typing import List 6 | 7 | import pytest 8 | 9 | _RELEASE_FLAG = "release" 10 | _GPU_FLAG = "gpu" 11 | _LONG_FLAG = "long" 12 | _EXTENDED_FLAG = "extended" 13 | 14 | _PYTEST_FLAGS = [_RELEASE_FLAG, _GPU_FLAG, _LONG_FLAG, _EXTENDED_FLAG, "first"] 15 | 16 | 17 | def pytest_addoption(parser: pytest.Parser) -> None: 18 | """ 19 | Add flags accepted by pytest CLI. 20 | 21 | Args: 22 | parser: The pytest parser object. 23 | 24 | Returns: 25 | 26 | """ 27 | for flag in _PYTEST_FLAGS: 28 | parser.addoption( 29 | f"--{flag}", action="store_true", default=False, help=f"run {flag} tests" 30 | ) 31 | 32 | 33 | def pytest_configure(config: pytest.Config) -> None: 34 | """ 35 | Add pytest custom configuration. 36 | 37 | Args: 38 | config: The pytest config object. 39 | 40 | Returns: 41 | """ 42 | for flag in _PYTEST_FLAGS: 43 | config.addinivalue_line( 44 | "markers", f"{flag}: mark test to run as {flag} only test" 45 | ) 46 | 47 | 48 | def pytest_collection_modifyitems( 49 | config: pytest.Config, items: List[pytest.Item] 50 | ) -> None: 51 | """ 52 | Skip tests with a marker from our list that were not explicitly invoked. 53 | 54 | Args: 55 | config: The pytest config object. 56 | items: The list of tests to be executed. 
57 | 58 | Returns: 59 | """ 60 | for item in items: 61 | keywords = list(set(item.keywords).intersection(_PYTEST_FLAGS)) 62 | if keywords and not any((config.getoption(f"--{kw}") for kw in keywords)): 63 | skip = pytest.mark.skip(reason=f"need --{keywords[0]} option to run") 64 | item.add_marker(skip) 65 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/.env.example: -------------------------------------------------------------------------------- 1 | PROJECT_ID=projecy_id 2 | FALCON_ENDPOINT_ID=falcon_endpoint_id 3 | GEMMA_ENDPOINT_ID=gemma_endpoint_id 4 | LLAMA_ENDPOINT_ID=llama_endpoint_id 5 | IMAGE_GCS_PATH=image_gcs_path 6 | VECTOR_SEARCH_STAGING_BUCKET=VECTOR_SEARCH_STAGING_BUCKET 7 | VECTOR_SEARCH_STREAM_INDEX_ID=VECTOR_SEARCH_STREAM_INDEX_ID 8 | VECTOR_SEARCH_STREAM_ENDPOINT_ID=VECTOR_SEARCH_STREAM_ENDPOINT_ID 9 | VECTOR_SEARCH_BATCH_INDEX_ID=VECTOR_SEARCH_BATCH_INDEX_ID 10 | VECTOR_SEARCH_BATCH_ENDPOINT_ID=VECTOR_SEARCH_BATCH_ENDPOINT_ID -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/TODO.md: -------------------------------------------------------------------------------- 1 | # Integration tests TODO 2 | 3 | ## Required 4 | - [ ] TBD 5 | 6 | ## Optional 7 | - [ ] Find a solution not to rely on the "REGION" env var's default value in test_vectorstores.py -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/vertexai/tests/integration_tests/__init__.py -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | _DEFAULT_MODEL_NAME = "gemini-2.0-flash-001" 4 | _DEFAULT_THINKING_MODEL_NAME = "gemini-2.5-flash-preview-04-17" 5 | 6 | 7 | @pytest.fixture 8 | def base64_image() -> str: 9 | return ( 10 | "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" 11 | "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" 12 | "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" 13 | "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" 14 | "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" 15 | "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" 16 | "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" 17 | "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" 18 | "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" 19 | "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" 20 | "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" 21 | "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" 22 | "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" 23 | "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" 24 | "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" 25 | "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" 26 | "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" 27 | "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" 28 | "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" 29 | ) 30 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/terraform/main.tf: -------------------------------------------------------------------------------- 1 | module 
"cloudbuild" { 2 | source = "./../../../../../terraform/cloudbuild" 3 | 4 | library = "vertexai" 5 | project_id = "" 6 | cloudbuildv2_repository_id = "" 7 | cloudbuild_env_vars = { 8 | FALCON_ENDPOINT_ID = "", 9 | GEMMA_ENDPOINT_ID = "", 10 | LLAMA_ENDPOINT_ID = "", 11 | IMAGE_GCS_PATH = "", 12 | VECTOR_SEARCH_STAGING_BUCKET="", 13 | VECTOR_SEARCH_STREAM_INDEX_ID="", 14 | VECTOR_SEARCH_STREAM_ENDPOINT_ID="", 15 | VECTOR_SEARCH_BATCH_INDEX_ID="", 16 | VECTOR_SEARCH_BATCH_ENDPOINT_ID="", 17 | } 18 | cloudbuild_secret_vars = { 19 | GOOGLE_API_KEY = "" 20 | } 21 | } -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_anthropic_files.py: -------------------------------------------------------------------------------- 1 | """Test ChatGoogleVertexAI chat model.""" 2 | import os 3 | 4 | import pytest 5 | 6 | from langchain_google_vertexai._image_utils import image_bytes_to_b64_string 7 | from langchain_google_vertexai._utils import load_image_from_gcs 8 | from langchain_google_vertexai.model_garden import ChatAnthropicVertex 9 | 10 | 11 | @pytest.mark.extended 12 | def test_pdf_gcs_uri(): 13 | gcs_uri = "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf" 14 | llm = ChatAnthropicVertex( 15 | model="claude-3-5-sonnet-v2@20241022", 16 | location="us-east5", 17 | temperature=0.8, 18 | project=os.environ["PROJECT_ID"], 19 | ) 20 | 21 | res = llm.invoke( 22 | [ 23 | { 24 | "role": "user", 25 | "content": [ 26 | "Parse this pdf.", 27 | {"type": "image_url", "image_url": {"url": gcs_uri}}, 28 | ], 29 | } 30 | ] 31 | ) 32 | assert len(res.content) > 100 33 | 34 | 35 | @pytest.mark.extended 36 | def test_pdf_byts(): 37 | gcs_uri = "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf" 38 | llm = ChatAnthropicVertex( 39 | model="claude-3-5-sonnet-v2@20241022", 40 | location="us-east5", 41 | temperature=0.8, 42 | project=os.environ["PROJECT_ID"], 43 | ) 44 | image = load_image_from_gcs(gcs_uri, 
"kuligin-sandbox1") 45 | image_data = image_bytes_to_b64_string(image.data, "ascii", "pdf") 46 | 47 | res = llm.invoke( 48 | [ 49 | { 50 | "role": "user", 51 | "content": [ 52 | "Parse this pdf.", 53 | {"type": "image_url", "image_url": {"url": image_data}}, 54 | ], 55 | } 56 | ] 57 | ) 58 | assert len(res.content) > 100 59 | 60 | 61 | @pytest.mark.extended 62 | def test_https_image(): 63 | uri = "https://picsum.photos/seed/picsum/200/300.jpg" 64 | 65 | llm = ChatAnthropicVertex( 66 | model="claude-3-5-sonnet-v2@20241022", 67 | location="us-east5", 68 | temperature=0.8, 69 | project=os.environ["PROJECT_ID"], 70 | ) 71 | 72 | res = llm.invoke( 73 | [ 74 | { 75 | "role": "user", 76 | "content": [ 77 | "Parse this pdf.", 78 | {"type": "image_url", "image_url": {"url": uri}}, 79 | ], 80 | } 81 | ] 82 | ) 83 | assert len(res.content) > 10 84 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from langchain_core.messages import HumanMessage 3 | 4 | from langchain_google_vertexai.callbacks import VertexAICallbackHandler 5 | from langchain_google_vertexai.chat_models import ChatVertexAI 6 | from langchain_google_vertexai.llms import VertexAI 7 | from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME 8 | 9 | 10 | @pytest.mark.release 11 | @pytest.mark.parametrize( 12 | "model_name", 13 | [_DEFAULT_MODEL_NAME], 14 | ) 15 | def test_llm_invoke(model_name: str) -> None: 16 | vb = VertexAICallbackHandler() 17 | llm = VertexAI(model_name=model_name, temperature=0.0, callbacks=[vb]) 18 | _ = llm.invoke("2+2") 19 | assert vb.successful_requests == 1 20 | assert vb.prompt_tokens > 0 21 | assert vb.completion_tokens > 0 22 | prompt_tokens = vb.prompt_tokens 23 | completion_tokens = vb.completion_tokens 24 | _ = llm.invoke("2+2") 25 | assert vb.successful_requests == 2 26 | assert 
@pytest.mark.release
@pytest.mark.parametrize(
    "model_name",
    [_DEFAULT_MODEL_NAME],
)
def test_chat_call(model_name: str) -> None:
    """Chat-model callbacks accumulate usage counters across calls."""
    vb = VertexAICallbackHandler()
    llm = ChatVertexAI(model_name=model_name, temperature=0.0, callbacks=[vb])
    message = HumanMessage(content="Hello")
    # Consistency fix: `llm([message])` relied on the deprecated
    # BaseChatModel.__call__ API; use .invoke() like the other tests here.
    _ = llm.invoke([message])
    assert vb.successful_requests == 1
    assert vb.prompt_tokens > 0
    assert vb.completion_tokens > 0
    prompt_tokens = vb.prompt_tokens
    completion_tokens = vb.completion_tokens
    _ = llm.invoke([message])
    assert vb.successful_requests == 2
    assert vb.prompt_tokens > prompt_tokens
    assert vb.completion_tokens > completion_tokens


@pytest.mark.release
@pytest.mark.parametrize(
    "model_name",
    [_DEFAULT_MODEL_NAME],
)
def test_invoke_config(model_name: str) -> None:
    """Callbacks passed via the `config` argument are honored as well."""
    vb = VertexAICallbackHandler()
    llm = VertexAI(model_name=model_name, temperature=0.0)
    llm.invoke("2+2", config={"callbacks": [vb]})
    assert vb.successful_requests == 1
    assert vb.prompt_tokens > 0
    assert vb.completion_tokens > 0
    prompt_tokens = vb.prompt_tokens
    completion_tokens = vb.completion_tokens
    llm.invoke("2+2", config={"callbacks": [vb]})
    assert vb.successful_requests == 2
    assert vb.prompt_tokens > prompt_tokens
    assert vb.completion_tokens > completion_tokens


@pytest.mark.release
def test_llm_stream() -> None:
    """Streaming still produces a single aggregated usage record."""
    vb = VertexAICallbackHandler()
    llm = VertexAI(model_name=_DEFAULT_MODEL_NAME, temperature=0.0, callbacks=[vb])
    for _ in llm.stream("2+2"):
        pass
    assert vb.successful_requests == 1
    assert vb.prompt_tokens > 0
    assert vb.completion_tokens > 0
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | from langchain_core.messages import ( 5 | AIMessage, 6 | ) 7 | from langchain_core.prompts import ChatPromptTemplate 8 | from pydantic import BaseModel, Field 9 | 10 | from langchain_google_vertexai import ChatVertexAI, create_structured_runnable 11 | from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME 12 | 13 | 14 | class RecordPerson(BaseModel): 15 | """Record some identifying information about a person.""" 16 | 17 | name: str = Field(..., description="The person's name") 18 | age: int = Field(..., description="The person's age") 19 | fav_food: Optional[str] = Field( 20 | default=None, description="The person's favorite food" 21 | ) 22 | 23 | 24 | class RecordDog(BaseModel): 25 | """Record some identifying information about a dog.""" 26 | 27 | name: str = Field(..., description="The dog's name") 28 | color: str = Field(..., description="The dog's color") 29 | fav_food: Optional[str] = Field(default=None, description="The dog's favorite food") 30 | 31 | 32 | @pytest.mark.release 33 | def test_create_structured_runnable() -> None: 34 | llm = ChatVertexAI(model_name=_DEFAULT_MODEL_NAME) 35 | prompt = ChatPromptTemplate.from_template( 36 | "You are a world class algorithm for recording entities.\nMake calls to the " 37 | "relevant function to record the entities in the following input:\n {input}\n" 38 | "Tip: Make sure to answer in the correct format" 39 | ) 40 | chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt) 41 | res = chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) 42 | assert isinstance(res, RecordDog) 43 | 44 | 45 | @pytest.mark.release 46 | def test_create_structured_runnable_with_prompt() -> None: 47 | llm = ChatVertexAI(model_name=_DEFAULT_MODEL_NAME, temperature=0) 48 | prompt = ChatPromptTemplate.from_template( 49 | "Describe a random {class} and 
mention their name, {attr} and favorite food" 50 | ) 51 | chain = create_structured_runnable( 52 | [RecordPerson, RecordDog], llm, prompt=prompt, use_extra_step=True 53 | ) 54 | res = chain.invoke({"class": "person", "attr": "age"}) 55 | assert isinstance(res, RecordPerson) 56 | 57 | 58 | @pytest.mark.release 59 | def test_reflection() -> None: 60 | class Reflection(BaseModel): 61 | reflections: str = Field( 62 | description="The critique and reflections on the sufficiency, superfluency," 63 | " and general quality of the response" 64 | ) 65 | score: int = Field( 66 | description="Score from 0-10 on the quality of the candidate response.", 67 | # gte=0, 68 | # lte=10, 69 | ) 70 | found_solution: bool = Field( 71 | description="Whether the response has fully solved the question or task." 72 | ) 73 | 74 | def as_message(self): 75 | return AIMessage( 76 | content=f"Reasoning: {self.reflections}\nScore: {self.score}" 77 | ) 78 | 79 | @property 80 | def normalized_score(self) -> float: 81 | return self.score / 10.0 82 | 83 | llm = ChatVertexAI(model_name=_DEFAULT_MODEL_NAME) 84 | 85 | prompt = ChatPromptTemplate.from_messages( 86 | [ 87 | ( 88 | "system", 89 | "Reflect and grade the assistant response to the user question below.", 90 | ), 91 | ( 92 | "user", 93 | "Which planet is the closest to the Earth?", 94 | ), 95 | ("ai", "{input}"), 96 | ] 97 | ) 98 | 99 | reflection_llm_chain = prompt | llm.with_structured_output(Reflection) 100 | res = reflection_llm_chain.invoke({"input": "Mars"}) 101 | assert isinstance(res, Reflection) 102 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | 
-------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from langchain_google_vertexai import ( 6 | VertexPairWiseStringEvaluator, 7 | VertexStringEvaluator, 8 | ) 9 | 10 | 11 | @pytest.mark.extended 12 | def test_evaluate() -> None: 13 | evaluator = VertexStringEvaluator( 14 | metric="bleu", project_id=os.environ["PROJECT_ID"] 15 | ) 16 | result = evaluator.evaluate( 17 | examples=[ 18 | {"reference": "This is a test."}, 19 | {"reference": "This is another test."}, 20 | ], 21 | predictions=[ 22 | {"prediction": "This is a test."}, 23 | {"prediction": "This is another one."}, 24 | ], 25 | ) 26 | assert len(result) == 2 27 | assert result[0]["score"] == 1.0 28 | assert result[1]["score"] < 1.0 29 | 30 | 31 | @pytest.mark.extended 32 | @pytest.mark.flaky(retries=3) 33 | def test_evaluate_strings() -> None: 34 | evaluator = VertexStringEvaluator( 35 | metric="safety", project_id=os.environ["PROJECT_ID"] 36 | ) 37 | result = evaluator._evaluate_strings(prediction="This is a test") 38 | assert isinstance(result, dict) 39 | assert "score" in result 40 | assert "explanation" in result 41 | 42 | 43 | @pytest.mark.extended 44 | @pytest.mark.flaky(retries=3) 45 | async def test_aevaluate_strings() -> None: 46 | evaluator = VertexStringEvaluator( 47 | metric="question_answering_quality", project_id=os.environ["PROJECT_ID"] 48 | ) 49 | result = await evaluator._aevaluate_strings( 50 | prediction="London", 51 | input="What is the capital of Great Britain?", 52 | instruction="Be concise", 53 | ) 54 | assert isinstance(result, dict) 55 | assert "score" in result 56 | assert "explanation" in result 57 | 58 | 59 | @pytest.mark.extended 60 | @pytest.mark.xfail(reason="TODO: investigate (started failing 2025-03-25).") 61 | async def test_evaluate_pairwise() -> None: 62 | 
evaluator = VertexPairWiseStringEvaluator( 63 | metric="pairwise_question_answering_quality", 64 | project_id=os.environ["PROJECT_ID"], 65 | ) 66 | result = evaluator.evaluate_string_pairs( 67 | prediction="London", 68 | prediction_b="Berlin", 69 | input="What is the capital of Great Britain?", 70 | instruction="Be concise", 71 | ) 72 | assert isinstance(result, dict) 73 | assert "confidence" in result 74 | assert "explanation" in result 75 | assert result["pairwise_choice"] == "BASELINE" 76 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_gemma.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from langchain_core.messages import ( 5 | AIMessage, 6 | HumanMessage, 7 | ) 8 | 9 | from langchain_google_vertexai import ( 10 | GemmaChatLocalKaggle, 11 | GemmaChatVertexAIModelGarden, 12 | GemmaLocalKaggle, 13 | GemmaVertexAIModelGarden, 14 | ) 15 | 16 | 17 | @pytest.mark.extended 18 | def test_gemma_model_garden() -> None: 19 | """In order to run this test, you should provide endpoint names. 20 | 21 | Example: 22 | export GEMMA_ENDPOINT_ID=... 23 | export PROJECT_ID=... 24 | """ 25 | endpoint_id = os.environ["GEMMA_ENDPOINT_ID"] 26 | project = os.environ["PROJECT_ID"] 27 | location = "us-central1" 28 | llm = GemmaVertexAIModelGarden( 29 | endpoint_id=endpoint_id, 30 | project=project, 31 | location=location, 32 | ) 33 | output = llm.invoke("What is the meaning of life?") 34 | assert isinstance(output, str) 35 | assert len(output) > 2 36 | assert llm._llm_type == "gemma_vertexai_model_garden" 37 | 38 | 39 | @pytest.mark.extended 40 | def test_gemma_chat_model_garden() -> None: 41 | """In order to run this test, you should provide endpoint names. 42 | 43 | Example: 44 | export GEMMA_ENDPOINT_ID=... 45 | export PROJECT_ID=... 
46 | """ 47 | endpoint_id = os.environ["GEMMA_ENDPOINT_ID"] 48 | project = os.environ["PROJECT_ID"] 49 | location = "us-central1" 50 | llm = GemmaChatVertexAIModelGarden( 51 | endpoint_id=endpoint_id, 52 | project=project, 53 | location=location, 54 | ) 55 | assert llm._llm_type == "gemma_vertexai_model_garden" 56 | 57 | text_question1, text_answer1 = "How much is 2+2?", "4" 58 | text_question2 = "How much is 3+3?" 59 | message1 = HumanMessage(content=text_question1) 60 | message2 = AIMessage(content=text_answer1) 61 | message3 = HumanMessage(content=text_question2) 62 | output = llm.invoke([message1]) 63 | assert isinstance(output, AIMessage) 64 | assert len(output.content) > 2 65 | output = llm.invoke([message1, message2, message3]) 66 | assert isinstance(output, AIMessage) 67 | assert len(output.content) > 2 68 | 69 | 70 | @pytest.mark.gpu 71 | def test_gemma_kaggle() -> None: 72 | llm = GemmaLocalKaggle(model_name="gemma_2b_en") 73 | output = llm.invoke("What is the meaning of life?") 74 | assert isinstance(output, str) 75 | print(output) 76 | assert len(output) > 2 77 | 78 | 79 | @pytest.mark.gpu 80 | def test_gemma_chat_kaggle() -> None: 81 | llm = GemmaChatLocalKaggle(model_name="gemma_2b_en") 82 | text_question1, text_answer1 = "How much is 2+2?", "4" 83 | text_question2 = "How much is 3+3?" 
84 | message1 = HumanMessage(content=text_question1) 85 | message2 = AIMessage(content=text_answer1) 86 | message3 = HumanMessage(content=text_question2) 87 | output = llm.invoke([message1]) 88 | assert isinstance(output, AIMessage) 89 | assert len(output.content) > 2 90 | output = llm.invoke([message1, message2, message3]) 91 | assert isinstance(output, AIMessage) 92 | assert len(output.content) > 2 93 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_image_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from google.cloud import storage # type: ignore[attr-defined, unused-ignore] 3 | from google.cloud.exceptions import NotFound 4 | 5 | from langchain_google_vertexai._image_utils import ImageBytesLoader 6 | 7 | 8 | @pytest.mark.skip("CI testing not set up") 9 | def test_image_utils(): 10 | base64_image = ( 11 | "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" 12 | "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" 13 | "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" 14 | "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" 15 | "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" 16 | "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" 17 | "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" 18 | "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" 19 | "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" 20 | "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" 21 | "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" 22 | "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" 23 | "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" 24 | 
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" 25 | "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" 26 | "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" 27 | "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" 28 | "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" 29 | "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" 30 | ) 31 | 32 | loader = ImageBytesLoader() 33 | 34 | image_bytes = loader.load_bytes(base64_image) 35 | 36 | assert isinstance(image_bytes, bytes) 37 | 38 | # Check loads image from blob 39 | 40 | bucket_name = "test_image_utils" 41 | blob_name = "my_image.png" 42 | 43 | client = storage.Client() 44 | bucket = client.bucket(bucket_name=bucket_name) 45 | blob = bucket.blob(blob_name) 46 | 47 | try: 48 | blob.upload_from_string(data=image_bytes) 49 | except NotFound: 50 | client.create_bucket(bucket) 51 | blob.upload_from_string(data=image_bytes) 52 | 53 | gcs_uri = f"gs://{bucket.name}/{blob.name}" 54 | 55 | gcs_image_bytes = loader.load_bytes(gcs_uri) 56 | 57 | assert image_bytes == gcs_image_bytes 58 | 59 | # Checks loads image from url 60 | url = ( 61 | "https://www.google.co.jp/images/branding/" 62 | "googlelogo/1x/googlelogo_color_272x92dp.png" 63 | ) 64 | 65 | image_bytes = loader.load_bytes(url) 66 | assert isinstance(image_bytes, bytes) 67 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_medlm.py: -------------------------------------------------------------------------------- 1 | """Test MedlLM models. 
2 | - medlm-medium@latest is part of GEMINI family, 3 | - should return str for VertexAI/Text Completion, 4 | - should returnAIMessage for ChatVertexAI/Chat Completion""" 5 | 6 | import pytest 7 | 8 | from langchain_google_vertexai import VertexAI 9 | 10 | 11 | @pytest.mark.extended 12 | def test_invoke_medlm_large_palm_error() -> None: 13 | with pytest.raises(ValueError): 14 | model = VertexAI(model_name="medlm-large") 15 | model.invoke("How you can help me?") 16 | -------------------------------------------------------------------------------- /libs/vertexai/tests/integration_tests/test_standard.py: -------------------------------------------------------------------------------- 1 | """Standard LangChain interface tests""" 2 | 3 | from typing import Type 4 | 5 | import pytest 6 | from langchain_core.language_models import BaseChatModel 7 | from langchain_core.rate_limiters import InMemoryRateLimiter 8 | from langchain_tests.integration_tests import ChatModelIntegrationTests 9 | 10 | from langchain_google_vertexai import ChatVertexAI 11 | 12 | rate_limiter = InMemoryRateLimiter(requests_per_second=0.5) 13 | 14 | 15 | @pytest.mark.first 16 | class TestGemini2AIStandard(ChatModelIntegrationTests): 17 | @property 18 | def chat_model_class(self) -> Type[BaseChatModel]: 19 | return ChatVertexAI 20 | 21 | @property 22 | def chat_model_params(self) -> dict: 23 | return { 24 | "model_name": "gemini-2.0-flash-001", 25 | "rate_limiter": rate_limiter, 26 | "temperature": 0, 27 | "api_transport": None, 28 | } 29 | 30 | @property 31 | def supports_image_inputs(self) -> bool: 32 | return True 33 | 34 | @property 35 | def supports_image_urls(self) -> bool: 36 | return True 37 | 38 | @property 39 | def supports_pdf_inputs(self) -> bool: 40 | return True 41 | 42 | @property 43 | def supports_video_inputs(self) -> bool: 44 | return True 45 | 46 | @property 47 | def supports_audio_inputs(self) -> bool: 48 | return True 49 | 50 | @property 51 | def supports_json_mode(self) -> bool: 
52 | return True 53 | 54 | 55 | class TestGemini_15_AIStandard(ChatModelIntegrationTests): 56 | @property 57 | def chat_model_class(self) -> Type[BaseChatModel]: 58 | return ChatVertexAI 59 | 60 | @property 61 | def chat_model_params(self) -> dict: 62 | return { 63 | "model_name": "gemini-1.5-pro-002", 64 | "rate_limiter": rate_limiter, 65 | "temperature": 0, 66 | "api_transport": None, 67 | } 68 | 69 | @property 70 | def supports_image_inputs(self) -> bool: 71 | return True 72 | 73 | @property 74 | def supports_image_urls(self) -> bool: 75 | return True 76 | 77 | @property 78 | def supports_pdf_inputs(self) -> bool: 79 | return True 80 | 81 | @property 82 | def supports_video_inputs(self) -> bool: 83 | return True 84 | 85 | @property 86 | def supports_audio_inputs(self) -> bool: 87 | return True 88 | 89 | @property 90 | def supports_json_mode(self) -> bool: 91 | return True 92 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/vertexai/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: TestGemini_15_AIStandard.test_serdes[serialized] 3 | dict({ 4 | 'id': list([ 5 | 'langchain', 6 | 'chat_models', 7 | 'vertexai', 8 | 'ChatVertexAI', 9 | ]), 10 | 'kwargs': dict({ 11 | 'default_metadata': list([ 12 | ]), 13 | 'endpoint_version': 'v1beta1', 14 | 'location': 'us-central1', 15 | 'max_output_tokens': 100, 16 | 'max_retries': 2, 17 | 'model_family': '2', 18 | 'model_kwargs': dict({ 19 | 'api_key': 'test', 20 | 'timeout': 60, 21 | }), 22 | 'model_name': 
'gemini-1.5-pro-001', 23 | 'n': 1, 24 | 'perform_literal_eval_on_string_raw_content': True, 25 | 'project': 'test-proj', 26 | 'request_parallelism': 5, 27 | 'stop': list([ 28 | ]), 29 | 'temperature': 0.0, 30 | }), 31 | 'lc': 1, 32 | 'name': 'ChatVertexAI', 33 | 'type': 'constructor', 34 | }) 35 | # --- 36 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_chains.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-google/3b02c56fea3136c06f4253ad9058d9123fda3949/libs/vertexai/tests/unit_tests/test_chains.py -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from unittest.mock import MagicMock, patch 3 | 4 | import pytest 5 | from pydantic import model_validator 6 | from typing_extensions import Self 7 | 8 | from langchain_google_vertexai import VertexAIEmbeddings 9 | from langchain_google_vertexai.embeddings import ( 10 | EmbeddingTaskTypes, 11 | GoogleEmbeddingModelType, 12 | ) 13 | 14 | 15 | def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: 16 | mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001") 17 | assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT 18 | with pytest.raises(NotImplementedError) as e: 19 | mock_embeddings.embed_images(["test"])[0] 20 | assert e.value == "Only supported for multimodal models" 21 | 22 | 23 | def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None: 24 | mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001") 25 | default_batch_size = mock_embeddings.instance["batch_size"] 26 | texts = ["text {i}" for i in range(default_batch_size * 2)] 27 | # It should only return one batch (out of two) still to 
process 28 | _, batches = mock_embeddings._prepare_and_validate_batches(texts=texts) 29 | assert len(batches) == 1 30 | # The second time it should return the batches unchanged 31 | _, batches = mock_embeddings._prepare_and_validate_batches(texts=texts) 32 | assert len(batches) == 2 33 | 34 | 35 | @patch.object(VertexAIEmbeddings, "embed") 36 | def test_embed_documents_with_question_answering_task(mock_embed) -> None: 37 | mock_embeddings = MockVertexAIEmbeddings("text-embedding-005") 38 | texts = [f"text {i}" for i in range(5)] 39 | 40 | embedding_dimension = 768 41 | embeddings_task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING" 42 | 43 | mock_embed.return_value = [[0.001] * embedding_dimension for _ in texts] 44 | 45 | embeddings = mock_embeddings.embed_documents( 46 | texts=texts, embeddings_task_type=embeddings_task_type 47 | ) 48 | 49 | assert isinstance(embeddings, list) 50 | assert len(embeddings) == len(texts) 51 | assert len(embeddings[0]) == embedding_dimension 52 | 53 | # Verify embed() was called correctly 54 | mock_embed.assert_called_once_with(texts, 0, embeddings_task_type) 55 | 56 | 57 | @patch.object(VertexAIEmbeddings, "embed") 58 | def test_embed_query_with_question_answering_task(mock_embed) -> None: 59 | mock_embeddings = MockVertexAIEmbeddings("text-embedding-005") 60 | text = "text 0" 61 | 62 | embedding_dimension = 768 63 | embeddings_task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING" 64 | 65 | mock_embed.return_value = [[0.001] * embedding_dimension] 66 | 67 | embedding = mock_embeddings.embed_query( 68 | text=text, embeddings_task_type=embeddings_task_type 69 | ) 70 | 71 | assert isinstance(embedding, list) 72 | assert len(embedding) == embedding_dimension 73 | 74 | # Verify embed() was called correctly 75 | mock_embed.assert_called_once_with([text], 1, embeddings_task_type) 76 | 77 | 78 | class MockVertexAIEmbeddings(VertexAIEmbeddings): 79 | """ 80 | A mock class for avoiding instantiating VertexAI and the EmbeddingModel client 81 | 
instance during init 82 | """ 83 | 84 | def __init__(self, model_name, **kwargs: Any) -> None: 85 | super().__init__(model_name, **kwargs) 86 | 87 | @classmethod 88 | def _init_vertexai(cls, values: Dict) -> None: 89 | pass 90 | 91 | @model_validator(mode="after") 92 | def validate_environment(self) -> Self: 93 | self.client = MagicMock() 94 | return self 95 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_google_vertexai import __all__ 2 | 3 | EXPECTED_ALL = [ 4 | "ChatVertexAI", 5 | "create_structured_runnable", 6 | "DataStoreDocumentStorage", 7 | "FunctionCallingConfig", 8 | "FunctionDeclaration", 9 | "GCSDocumentStorage", 10 | "GemmaChatLocalHF", 11 | "GemmaChatLocalKaggle", 12 | "GemmaChatVertexAIModelGarden", 13 | "GemmaLocalHF", 14 | "GemmaLocalKaggle", 15 | "GemmaVertexAIModelGarden", 16 | "HarmBlockThreshold", 17 | "HarmCategory", 18 | "PydanticFunctionsOutputParser", 19 | "SafetySetting", 20 | "Schema", 21 | "ToolConfig", 22 | "Type", 23 | "VectorSearchVectorStore", 24 | "VectorSearchVectorStoreDatastore", 25 | "VectorSearchVectorStoreGCS", 26 | "VertexAI", 27 | "VertexAIEmbeddings", 28 | "VertexAIImageCaptioning", 29 | "VertexAIImageCaptioningChat", 30 | "VertexAIImageEditorChat", 31 | "VertexAIImageGeneratorChat", 32 | "VertexAIModelGarden", 33 | "VertexAIVisualQnAChat", 34 | "VertexPairWiseStringEvaluator", 35 | "VertexStringEvaluator", 36 | "create_context_cache", 37 | "get_vertex_maas_model", 38 | ] 39 | 40 | 41 | def test_all_imports() -> None: 42 | assert sorted(EXPECTED_ALL) == sorted(__all__) 43 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_model_garden_retry.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | 
import pytest 4 | from anthropic import APIError 5 | 6 | from langchain_google_vertexai.model_garden import ( 7 | _create_retry_decorator, 8 | ) 9 | 10 | 11 | def create_api_error(): 12 | """Helper function to create an APIError with required arguments.""" 13 | mock_request = MagicMock() 14 | mock_request.method = "POST" 15 | mock_request.url = "test-url" 16 | mock_request.headers = {} 17 | mock_request.body = None 18 | return APIError( 19 | message="Test error", 20 | request=mock_request, 21 | body={"error": {"message": "Test error"}}, 22 | ) 23 | 24 | 25 | def test_retry_on_errors(): 26 | """Test that the retry decorator works with sync functions.""" 27 | max_retries = 2 28 | wait_exponential_kwargs = {"multiplier": 1.0, "min": 1.0, "max": 10.0} 29 | mock_function = MagicMock(side_effect=[create_api_error(), "success"]) 30 | 31 | decorator = _create_retry_decorator( 32 | max_retries=max_retries, wait_exponential_kwargs=wait_exponential_kwargs 33 | ) 34 | wrapped_func = decorator(mock_function) 35 | 36 | result = wrapped_func() 37 | assert result == "success" 38 | assert mock_function.call_count == 2 39 | 40 | 41 | def test_max_retries_exceeded(): 42 | """Test that the retry decorator fails after max retries.""" 43 | max_retries = 2 44 | wait_exponential_kwargs = {"multiplier": 1.0, "min": 1.0, "max": 10.0} 45 | mock_function = MagicMock(side_effect=[create_api_error(), create_api_error()]) 46 | 47 | decorator = _create_retry_decorator( 48 | max_retries=max_retries, wait_exponential_kwargs=wait_exponential_kwargs 49 | ) 50 | wrapped_func = decorator(mock_function) 51 | 52 | with pytest.raises(APIError): 53 | wrapped_func() 54 | assert mock_function.call_count == 2 55 | 56 | 57 | @pytest.mark.asyncio 58 | async def test_async_retry_on_errors(): 59 | """Test that the retry decorator works with async functions.""" 60 | max_retries = 2 61 | wait_exponential_kwargs = {"multiplier": 1.0, "min": 1.0, "max": 10.0} 62 | 63 | class AsyncMock: 64 | def __init__(self): 65 | 
self.call_count = 0 66 | 67 | async def __call__(self): 68 | self.call_count += 1 69 | if self.call_count == 1: 70 | raise create_api_error() 71 | return "success" 72 | 73 | mock_async = AsyncMock() 74 | 75 | decorator = _create_retry_decorator( 76 | max_retries=max_retries, wait_exponential_kwargs=wait_exponential_kwargs 77 | ) 78 | wrapped_func = decorator(mock_async) 79 | 80 | result = await wrapped_func() 81 | assert result == "success" 82 | assert mock_async.call_count == 2 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_async_max_retries_exceeded(): 87 | """Test that the async retry decorator fails after max retries.""" 88 | max_retries = 2 89 | wait_exponential_kwargs = {"multiplier": 1.0, "min": 1.0, "max": 10.0} 90 | 91 | class AsyncMock: 92 | def __init__(self): 93 | self.call_count = 0 94 | 95 | async def __call__(self): 96 | self.call_count += 1 97 | raise create_api_error() 98 | 99 | mock_async = AsyncMock() 100 | 101 | decorator = _create_retry_decorator( 102 | max_retries=max_retries, wait_exponential_kwargs=wait_exponential_kwargs 103 | ) 104 | wrapped_func = decorator(mock_async) 105 | 106 | with pytest.raises(APIError): 107 | await wrapped_func() 108 | assert mock_async.call_count == 2 109 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_model_garden_timeout.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import httpx 4 | import pytest 5 | 6 | from langchain_google_vertexai.model_garden import ChatAnthropicVertex 7 | 8 | TIMEOUT_TEST_CASES = [ 9 | pytest.param(30.0, id="float_timeout"), 10 | pytest.param(httpx.Timeout(timeout=30.0), id="httpx_timeout"), 11 | pytest.param(None, id="none_timeout"), 12 | pytest.param( 13 | ..., # No timeout specified in constructor 14 | id="default_timeout", 15 | ), 16 | ] 17 | 18 | 19 | @pytest.mark.parametrize("timeout_value", TIMEOUT_TEST_CASES) 20 | 
def test_timeout_configuration(timeout_value): 21 | """Test that different timeout values are correctly handled.""" 22 | with patch("anthropic.AnthropicVertex") as mock_sync_client, patch( 23 | "anthropic.AsyncAnthropicVertex" 24 | ) as mock_async_client: 25 | mock_sync_instance = Mock() 26 | mock_sync_instance.timeout = None if timeout_value is ... else timeout_value 27 | mock_sync_client.return_value = mock_sync_instance 28 | 29 | mock_async_instance = Mock() 30 | mock_async_instance.timeout = None if timeout_value is ... else timeout_value 31 | mock_async_client.return_value = mock_async_instance 32 | 33 | # Create chat instance with or without timeout parameter 34 | chat_kwargs = {"project": "test-project"} 35 | if timeout_value is not ...: 36 | chat_kwargs["timeout"] = timeout_value 37 | 38 | chat = ChatAnthropicVertex(**chat_kwargs) 39 | 40 | # Verify initialization 41 | mock_sync_client.assert_called_once() 42 | expected_timeout = None if timeout_value is ... else timeout_value 43 | assert ( 44 | mock_sync_client.call_args.kwargs["timeout"] == expected_timeout 45 | ), "Synchronous Anthropic instance not initialized with correct timeout" 46 | 47 | mock_async_client.assert_called_once() 48 | assert ( 49 | mock_async_client.call_args.kwargs["timeout"] == expected_timeout 50 | ), "Asynchronous Anthropic instance not initialized with correct timeout" 51 | 52 | # Verify the clients have the correct timeout after initialization 53 | assert ( 54 | chat.client.timeout == expected_timeout 55 | ), "Sync client timeout not set correctly after initialization" 56 | assert ( 57 | chat.async_client.timeout == expected_timeout 58 | ), "Async client timeout not set correctly after initialization" 59 | 60 | 61 | def test_timeout_invalid(): 62 | """Test that invalid timeout values raise appropriate errors.""" 63 | with pytest.raises(ValueError) as exc_info: 64 | ChatAnthropicVertex( 65 | project="test-project", 66 | timeout="invalid", 67 | ) 68 | assert "Input should be a valid 
number" in str(exc_info.value) 69 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_standard.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from langchain_core.language_models import BaseChatModel 4 | from langchain_tests.unit_tests import ChatModelUnitTests 5 | 6 | from langchain_google_vertexai import ChatVertexAI 7 | 8 | 9 | class TestGemini_15_AIStandard(ChatModelUnitTests): 10 | @property 11 | def chat_model_class(self) -> Type[BaseChatModel]: 12 | return ChatVertexAI 13 | 14 | @property 15 | def chat_model_params(self) -> dict: 16 | return {"model_name": "gemini-1.5-pro-001"} 17 | -------------------------------------------------------------------------------- /libs/vertexai/tests/unit_tests/test_vision_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vertexai.vision_models import Image # type: ignore[import-untyped] 3 | 4 | from langchain_google_vertexai.vision_models import _BaseImageTextModel 5 | 6 | 7 | def test_get_image_from_message_part(base64_image: str): 8 | model = _BaseImageTextModel() 9 | 10 | # Should work with a well formatted dictionary: 11 | message = {"type": "image_url", "image_url": {"url": base64_image}} 12 | image = model._get_image_from_message_part(message) 13 | assert isinstance(image, Image) 14 | 15 | # Should not work with a simple string 16 | simple_string = base64_image 17 | image = model._get_image_from_message_part(simple_string) 18 | assert image is None 19 | 20 | # Should not work with a string message 21 | message = {"type": "text", "text": "I'm a text message"} 22 | image = model._get_image_from_message_part(message) 23 | assert image is None 24 | 25 | 26 | def test_get_text_from_message_part(): 27 | DUMMY_MESSAGE = "Some message" 28 | model = _BaseImageTextModel() 29 | 30 | # Should not work with an image 31 | 
message = {"type": "image_url", "image_url": {"url": base64_image}} 32 | text = model._get_text_from_message_part(message) 33 | assert text is None 34 | 35 | # Should work with a simple string 36 | simple_message = DUMMY_MESSAGE 37 | text = model._get_text_from_message_part(simple_message) 38 | assert text == DUMMY_MESSAGE 39 | 40 | # Should work with a text message 41 | message = {"type": "text", "text": DUMMY_MESSAGE} 42 | text = model._get_text_from_message_part(message) 43 | assert text == DUMMY_MESSAGE 44 | 45 | 46 | @pytest.fixture 47 | def base64_image() -> str: 48 | return ( 49 | "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA" 50 | "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" 51 | "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" 52 | "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" 53 | "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" 54 | "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" 55 | "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" 56 | "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" 57 | "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" 58 | "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" 59 | "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" 60 | "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" 61 | "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" 62 | "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" 63 | "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" 64 | "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" 65 | "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" 66 | "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" 67 | 
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" 68 | ) 69 | -------------------------------------------------------------------------------- /terraform/cloudbuild/main.tf: -------------------------------------------------------------------------------- 1 | provider "google" { 2 | project = var.project_id 3 | } 4 | 5 | 6 | locals { 7 | cloudbuild_service_account_roles = [ 8 | "roles/bigquery.user", #BigQuery User 9 | "roles/discoveryengine.editor", #Discovery Engine Editor 10 | "roles/logging.logWriter", #Logs Writer 11 | "roles/secretmanager.secretAccessor", #Secret Manager Secret Accessor 12 | "roles/serviceusage.serviceUsageConsumer", #Service Usage Consumer 13 | "roles/aiplatform.user", #Vertex AI User 14 | ] 15 | cloudbuild_env_vars = merge( 16 | { 17 | for key, value in var.cloudbuild_env_vars : 18 | "_${upper(key)}" => value 19 | }, 20 | { _LIB = var.library }, 21 | { _POETRY_VERSION = var.poetry_version }, 22 | { _PYTHON_VERSION = var.python_version }, 23 | ) 24 | #TODO: multiline 25 | cloudbuild_config = "python -m pip install -q poetry==$${_POETRY_VERSION} --verbose && cd libs/$${_LIB} && poetry install -q --with test,test_integration --all-extras && poetry run pytest --extended --release tests/integration_tests/" 26 | } 27 | 28 | resource "google_cloudbuild_trigger" "cloudbuild_trigger" { 29 | name = "${var.prefix}-${var.library}" 30 | location = var.region 31 | service_account = google_service_account.cloudbuild_service_account.id 32 | 33 | included_files = ["libs/${var.library}/**"] 34 | 35 | repository_event_config { 36 | repository = var.cloudbuildv2_repository_id 37 | pull_request { 38 | branch = "^main$" 39 | #comment_control = "COMMENTS_ENABLED_FOR_EXTERNAL_CONTRIBUTORS_ONLY" #Configure builds to run whether a repository owner or collaborator need to comment /gcbrun 40 | } 41 | } 42 | 43 | substitutions = local.cloudbuild_env_vars 44 | 45 | include_build_logs = "INCLUDE_BUILD_LOGS_WITH_STATUS" 46 | 47 | build { 48 | step { 49 | id = 
"integration tests" 50 | name = "python:$${_PYTHON_VERSION}" 51 | args = ["-c", local.cloudbuild_config] 52 | entrypoint = "bash" 53 | env = concat(["PROJECT_ID=$PROJECT_ID"], [for env_name, env_value in var.cloudbuild_env_vars : "${env_name}=$_${env_name}"]) 54 | secret_env = keys(var.cloudbuild_secret_vars) 55 | } 56 | 57 | options { 58 | logging = "CLOUD_LOGGING_ONLY" 59 | } 60 | 61 | available_secrets { 62 | dynamic "secret_manager" { 63 | for_each = var.cloudbuild_secret_vars 64 | content { 65 | env = secret_manager.key 66 | version_name = "projects/$${PROJECT_ID}/secrets/${secret_manager.value}/versions/latest" 67 | } 68 | } 69 | } 70 | } 71 | } 72 | 73 | resource "google_service_account" "cloudbuild_service_account" { 74 | account_id = "${var.library}-cb-sa" 75 | } 76 | 77 | resource "google_project_iam_member" "cloudbuild_service_account_iam" { 78 | project = var.project_id 79 | for_each = toset(local.cloudbuild_service_account_roles) 80 | role = each.value 81 | member = "serviceAccount:${google_service_account.cloudbuild_service_account.email}" 82 | } 83 | -------------------------------------------------------------------------------- /terraform/cloudbuild/variables.tf: -------------------------------------------------------------------------------- 1 | variable "project_id" { 2 | type = string 3 | description = "" 4 | } 5 | 6 | variable "region" { 7 | type = string 8 | default = "us-central1" 9 | description = "" 10 | } 11 | 12 | variable "prefix" { 13 | type = string 14 | default = "langchain-google" 15 | description = "" 16 | } 17 | 18 | variable "cloudbuildv2_repository_id" { 19 | type = string 20 | description = "" 21 | } 22 | 23 | variable "poetry_version" { 24 | type = string 25 | default = "1.7.1" 26 | description = "" 27 | } 28 | 29 | variable "python_version" { 30 | type = string 31 | default = "3.11" 32 | description = "" 33 | } 34 | 35 | variable "library" { 36 | type = string 37 | description = "" 38 | } 39 | 40 | variable 
"cloudbuild_env_vars" { 41 | type = map(string) 42 | } 43 | 44 | variable "cloudbuild_secret_vars" { 45 | type = map(string) 46 | } -------------------------------------------------------------------------------- /terraform/github-connection/main.tf: -------------------------------------------------------------------------------- 1 | provider "google" { 2 | project = var.project_id 3 | } 4 | 5 | resource "google_cloudbuildv2_connection" "langchain_google_github_connection" { 6 | location = var.region 7 | name = "${var.prefix}-connection" 8 | 9 | github_config { 10 | app_installation_id = var.github_app_installation_id 11 | authorizer_credential { 12 | oauth_token_secret_version = "${var.github_oauth_token_secret_id}/versions/latest" 13 | } 14 | } 15 | } 16 | 17 | resource "google_cloudbuildv2_repository" "langchain_google_repository" { 18 | name = "${var.prefix}-repository" 19 | parent_connection = google_cloudbuildv2_connection.langchain_google_github_connection.id 20 | remote_uri = var.langchain_github_repo 21 | } 22 | 23 | output "langchain_google_repository_id" { 24 | value = google_cloudbuildv2_repository.langchain_google_repository.id 25 | } -------------------------------------------------------------------------------- /terraform/github-connection/variables.tf: -------------------------------------------------------------------------------- 1 | variable "project_id" { 2 | type = string 3 | description = "" 4 | } 5 | 6 | variable "region" { 7 | type = string 8 | default = "us-central1" 9 | description = "" 10 | } 11 | 12 | variable "prefix" { 13 | type = string 14 | default = "langchain-google" 15 | description = "" 16 | } 17 | 18 | variable "github_oauth_token_secret_id" { 19 | type = string 20 | description = "" 21 | } 22 | 23 | variable "langchain_github_repo" { 24 | type = string 25 | description = "" 26 | default = "https://github.com/langchain-ai/langchain-google.git" 27 | } 28 | 29 | variable "github_app_installation_id" { 30 | type = string 31 | 
description = "Your installation ID can be found in the URL of your Cloud Build GitHub App. In the following URL, https://github.com/settings/installations/1234567, the installation ID is the numerical value 1234567." 32 | } -------------------------------------------------------------------------------- /terraform/secrets/main.tf: -------------------------------------------------------------------------------- 1 | provider "google" { 2 | project = var.project_id 3 | } 4 | 5 | resource "google_secret_manager_secret" "google_api_key_secret" { 6 | secret_id = "${var.prefix}-google-api-key" 7 | replication { 8 | auto {} 9 | } 10 | } 11 | 12 | resource "google_secret_manager_secret_version" "google_api_key_secret_version" { 13 | secret = google_secret_manager_secret.google_api_key_secret.id 14 | secret_data = var.google_api_key 15 | } 16 | 17 | output "google_api_key_secret_id" { 18 | value = google_secret_manager_secret.google_api_key_secret.id 19 | } 20 | 21 | resource "google_secret_manager_secret" "google_cse_id_secret" { 22 | secret_id = "${var.prefix}-google-cse-id" 23 | replication { 24 | auto {} 25 | } 26 | } 27 | 28 | resource "google_secret_manager_secret_version" "google_cse_id_secret_version" { 29 | secret = google_secret_manager_secret.google_cse_id_secret.id 30 | secret_data = var.google_cse_id 31 | } 32 | 33 | output "google_cse_id_secret_id" { 34 | value = google_secret_manager_secret.google_cse_id_secret.id 35 | } 36 | 37 | resource "google_secret_manager_secret" "github_oauth_token_secret" { 38 | secret_id = "${var.prefix}-github-oauth-token" 39 | replication { 40 | auto {} 41 | } 42 | } 43 | 44 | resource "google_secret_manager_secret_version" "github_oauth_token_secret_version" { 45 | secret = google_secret_manager_secret.github_oauth_token_secret.id 46 | secret_data = var.github_oauth_token 47 | } 48 | 49 | output "github_oauth_token_secret_id" { 50 | value = google_secret_manager_secret.github_oauth_token_secret.id 51 | } 52 | 53 | #autocreate api 
key with scope restrictions 54 | #add link to google_cse_id creation -------------------------------------------------------------------------------- /terraform/secrets/variables.tf: -------------------------------------------------------------------------------- 1 | variable "project_id" { 2 | type = string 3 | description = "" 4 | } 5 | 6 | variable "prefix" { 7 | type = string 8 | default = "langchain-google" 9 | description = "" 10 | } 11 | 12 | variable "github_oauth_token" { 13 | type = string 14 | description = "" 15 | } 16 | 17 | variable "google_api_key" { 18 | type = string 19 | description = "" 20 | } 21 | 22 | variable "google_cse_id" { 23 | type = string 24 | description = "" 25 | } --------------------------------------------------------------------------------