├── .github ├── dependabot.yml ├── scripts │ └── check_diff.py └── workflows │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ ├── _test_release.yml │ ├── ci.yml │ ├── codeql.yml │ └── zizmor.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── RELEASE.md ├── docs ├── _extensions │ └── gallery_directive.py ├── _static │ ├── favicon.png │ ├── wordmark-api-dark.svg │ └── wordmark-api.svg ├── conf.py ├── create_api_rst.py └── templates │ ├── class.rst │ ├── enum.rst │ ├── function.rst │ ├── langchain_docs.html │ ├── pydantic.rst │ ├── runnable_non_pydantic.rst │ ├── runnable_pydantic.rst │ └── typeddict.rst ├── justfile ├── libs ├── langchain-mongodb │ ├── .gitignore │ ├── CHANGELOG.md │ ├── LICENSE │ ├── README.md │ ├── justfile │ ├── langchain_mongodb │ │ ├── __init__.py │ │ ├── agent_toolkit │ │ │ ├── __init__.py │ │ │ ├── database.py │ │ │ ├── prompt.py │ │ │ ├── tool.py │ │ │ └── toolkit.py │ │ ├── cache.py │ │ ├── chat_message_histories.py │ │ ├── docstores.py │ │ ├── graphrag │ │ │ ├── __init__.py │ │ │ ├── example_templates.py │ │ │ ├── graph.py │ │ │ ├── prompts.py │ │ │ └── schema.py │ │ ├── index.py │ │ ├── indexes.py │ │ ├── loaders.py │ │ ├── pipelines.py │ │ ├── py.typed │ │ ├── retrievers │ │ │ ├── __init__.py │ │ │ ├── full_text_search.py │ │ │ ├── graphrag.py │ │ │ ├── hybrid_search.py │ │ │ ├── parent_document.py │ │ │ └── self_querying.py │ │ ├── utils.py │ │ └── vectorstores.py │ ├── pyproject.toml │ ├── tests │ │ ├── __init__.py │ │ ├── integration_tests │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ ├── test_agent_toolkit.py │ │ │ ├── test_cache.py │ │ │ ├── test_chain_example.py │ │ │ ├── test_chat_message_histories.py │ │ │ ├── test_compile.py │ │ │ ├── test_docstore.py │ │ │ ├── test_graphrag.py │ │ │ ├── test_index.py │ │ │ ├── test_indexes.py │ │ │ ├── test_loaders.py │ │ │ ├── test_mmr.py │ │ │ ├── test_mongodb_database.py │ │ │ ├── test_parent_document.py │ │ │ ├── test_retriever_selfquerying.py │ │ │ ├── test_retrievers.py │ │ │ ├── test_retrievers_standard.py │ │ │ ├── test_tools.py │ │ │ ├── test_vectorstore_add_delete.py │ │ │ ├── test_vectorstore_from_documents.py │ │ │ ├── test_vectorstore_from_texts.py │ │ │ └── test_vectorstore_standard.py │ │ ├── unit_tests │ │ │ ├── __init__.py │ │ │ ├── test_cache.py │ │ │ ├── test_chat_message_histories.py │ │ │ ├── test_imports.py │ │ │ ├── test_index.py │ │ │ ├── test_retrievers.py │ │ │ ├── test_tools.py │ │ │ └── test_vectorstores.py │ │ └── utils.py │ └── uv.lock ├── langgraph-checkpoint-mongodb │ ├── CHANGELOG.md │ ├── README.md │ ├── justfile │ ├── langgraph │ │ └── checkpoint │ │ │ └── mongodb │ │ │ ├── __init__.py │ │ │ ├── aio.py │ │ │ ├── saver.py │ │ │ └── utils.py │ ├── pyproject.toml │ ├── tests │ │ ├── __snapshots__ │ │ │ ├── test_pregel.ambr │ │ │ └── test_pregel_async.ambr │ │ ├── integration_tests │ │ │ ├── README │ │ │ ├── conftest.py │ │ │ ├── test_interrupts.py │ │ │ └── test_sanity.py │ │ └── unit_tests │ │ │ ├── conftest.py │ │ │ ├── test_async.py │ │ │ ├── test_delete_thread.py │ │ │ └── test_sync.py │ └── uv.lock └── langgraph-store-mongodb │ ├── CHANGELOG.md │ ├── README.md │ ├── justfile │ ├── langgraph │ └── store │ │ └── mongodb │ │ ├── __init__.py │ │ └── base.py │ ├── pyproject.toml │ ├── tests │ ├── integration_tests │ │ └── test_store_semantic.py │ └── unit_tests │ │ ├── test_store.py │ │ └── test_store_async.py │ └── uv.lock ├── pyproject.toml ├── scripts ├── setup_ollama.sh ├── start_local_atlas.sh └── 
update-locks.sh └── uv.lock /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | groups: 9 | actions: 10 | patterns: 11 | - "*" 12 | # Python 13 | - package-ecosystem: "pip" 14 | directory: "libs/langchain-mongodb" 15 | schedule: 16 | interval: "weekly" 17 | - package-ecosystem: "pip" 18 | directory: "libs/langgraph-checkpoint-mongodb" 19 | schedule: 20 | interval: "weekly" 21 | - package-ecosystem: "pip" 22 | directory: "libs/langgraph-store-mongodb" 23 | schedule: 24 | interval: "weekly" 25 | -------------------------------------------------------------------------------- /.github/scripts/check_diff.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from typing import Dict 4 | 5 | LIB_DIRS = [ 6 | "libs/langchain-mongodb", 7 | "libs/langgraph-checkpoint-mongodb", 8 | "libs/langgraph-store-mongodb", 9 | ] 10 | 11 | if __name__ == "__main__": 12 | files = sys.argv[1:] # changed files 13 | 14 | dirs_to_run: Dict[str, set] = { 15 | "lint": set(), 16 | "test": set(), 17 | } 18 | 19 | if len(files) == 300: 20 | # max diff length is 300 files - there are likely files missing 21 | raise ValueError("Max diff reached. Please manually run CI on changed libs.") 22 | 23 | for file in files: 24 | if any(file.startswith(dir_) for dir_ in (".github", "scripts")): 25 | # add all LIB_DIRS for infra changes 26 | dirs_to_run["test"].update(LIB_DIRS) 27 | dirs_to_run["lint"].update(LIB_DIRS) 28 | 29 | if any(file.startswith(dir_) for dir_ in LIB_DIRS): 30 | for dir_ in LIB_DIRS: 31 | if file.startswith(dir_): 32 | dirs_to_run["test"].add(dir_) 33 | dirs_to_run["lint"].add(dir_) 34 | elif file.startswith("libs/"): 35 | raise ValueError( 36 | f"Unknown lib: {file}. check_diff.py likely needs " 37 | "an update for this new library!" 38 | ) 39 | 40 | outputs = { 41 | "dirs-to-lint": list(dirs_to_run["lint"]), 42 | "dirs-to-test": list(dirs_to_run["test"]), 43 | } 44 | for key, value in outputs.items(): 45 | json_output = json.dumps(value) 46 | print(f"{key}={json_output}") # noqa: T201 47 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 13 | 14 | # This env var allows us to get inline annotations when ruff has complaints. 15 | RUFF_OUTPUT_FORMAT: github 16 | 17 | jobs: 18 | build: 19 | name: "run lint #${{ matrix.python-version }}" 20 | runs-on: ubuntu-latest 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | # Only lint on the min and max supported Python versions. 25 | # It's extremely unlikely that there's a lint issue on any version in between 26 | # that doesn't show up on the min or max versions. 27 | # 28 | # GitHub rate-limits how many jobs can be running at any one time. 29 | # Starting new jobs is also relatively slow, 30 | # so linting on fewer versions makes CI faster. 
31 | python-version: 32 | - "3.9" 33 | - "3.12" 34 | steps: 35 | - uses: actions/checkout@v4 36 | with: 37 | persist-credentials: false 38 | - name: Install uv 39 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5 40 | with: 41 | enable-cache: true 42 | python-version: ${{ matrix.python-version }} 43 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock" 44 | 45 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3 46 | 47 | - name: Install dependencies 48 | working-directory: ${{ inputs.working-directory }} 49 | run: just install 50 | 51 | - name: Get .mypy_cache to speed up mypy 52 | uses: actions/cache@v4 53 | env: 54 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" 55 | with: 56 | path: | 57 | ${{ env.WORKDIR }}/.mypy_cache 58 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/uv.lock', inputs.working-directory)) }} 59 | 60 | 61 | - name: Analysing the code with our lint 62 | working-directory: ${{ inputs.working-directory }} 63 | run: just lint 64 | 65 | - name: Checking the types 66 | working-directory: ${{ inputs.working-directory }} 67 | run: just typing 68 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | jobs: 12 | build: 13 | defaults: 14 | run: 15 | working-directory: ${{ inputs.working-directory }} 16 | runs-on: ubuntu-latest 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | python-version: 21 | - "3.9" 22 | - "3.12" 23 | name: "run test #${{ matrix.python-version }}" 24 | steps: 25 | - uses: actions/checkout@v4 26 | with: 27 | persist-credentials: false 28 | - name: Install uv 29 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5 30 | with: 31 | enable-cache: true 32 | python-version: ${{ matrix.python-version }} 33 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock" 34 | 35 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3 36 | 37 | - name: Install dependencies 38 | shell: bash 39 | run: just install 40 | 41 | - name: Start MongoDB 42 | uses: supercharge/mongodb-github-action@90004df786821b6308fb02299e5835d0dae05d0d # 1.12.0 43 | 44 | - name: Run unit tests 45 | shell: bash 46 | run: just unit_tests 47 | 48 | - name: Start local Atlas 49 | working-directory: . 50 | run: bash scripts/start_local_atlas.sh 51 | 52 | - name: Install Ollama 53 | run: curl -fsSL https://ollama.com/install.sh | sh 54 | 55 | - name: Run Ollama 56 | working-directory: . 57 | run: | 58 | ollama serve & 59 | sleep 5 # wait for the Ollama server to be ready 60 | bash scripts/setup_ollama.sh 61 | 62 | - name: Run integration tests 63 | run: just integration_tests 64 | 65 | - name: Ensure the tests did not create any additional files 66 | shell: bash 67 | run: | 68 | set -eu 69 | 70 | STATUS="$(git status)" 71 | echo "$STATUS" 72 | 73 | # grep will exit non-zero if the target message isn't found, 74 | # and `set -e` above will cause the step to fail. 
75 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 76 | 77 | - name: Run unit tests with minimum dependency versions 78 | run: | 79 | uv sync --python=${{ matrix.python-version }} --resolution=lowest-direct 80 | just unit_tests 81 | -------------------------------------------------------------------------------- /.github/workflows/_test_release.yml: -------------------------------------------------------------------------------- 1 | name: test-release 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | PYTHON_VERSION: "3.10" 13 | 14 | jobs: 15 | build: 16 | if: github.ref == 'refs/heads/main' 17 | runs-on: ubuntu-latest 18 | defaults: 19 | run: 20 | working-directory: ${{ inputs.working-directory }} 21 | outputs: 22 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 23 | version: ${{ steps.check-version.outputs.version }} 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | with: 28 | persist-credentials: false 29 | - name: Install uv 30 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5 31 | with: 32 | enable-cache: true 33 | python-version: ${{ env.PYTHON_VERSION }} 34 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock" 35 | 36 | # We want to keep this build stage *separate* from the release stage, 37 | # so that there's no sharing of permissions between them. 38 | # The release stage has trusted publishing and GitHub repo contents write access, 39 | # and we want to keep the scope of that access limited just to the release job. 40 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 41 | # could get access to our GitHub or PyPI credentials. 42 | # 43 | # Per the trusted publishing GitHub Action: 44 | # > It is strongly advised to separate jobs for building [...] 45 | # > from the publish job. 46 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 47 | - name: Build project for distribution 48 | run: uv build 49 | 50 | - name: Upload build 51 | uses: actions/upload-artifact@v4 52 | with: 53 | name: test-dist 54 | path: ${{ inputs.working-directory }}/dist/ 55 | 56 | - name: Check Version 57 | id: check-version 58 | shell: bash 59 | run: | 60 | set -eu 61 | echo pkg-name="$(uvx hatch project metadata | jq -r .name)" >> $GITHUB_OUTPUT 62 | echo version="$(uvx hatch version)" >> $GITHUB_OUTPUT 63 | 64 | publish: 65 | needs: 66 | - build 67 | runs-on: ubuntu-latest 68 | permissions: 69 | # This permission is used for trusted publishing: 70 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 71 | # 72 | # Trusted publishing has to also be configured on PyPI for each package: 73 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 74 | id-token: write 75 | 76 | steps: 77 | - uses: actions/checkout@v4 78 | with: 79 | persist-credentials: false 80 | - uses: actions/download-artifact@v4 81 | with: 82 | name: test-dist 83 | path: ${{ inputs.working-directory }}/dist/ 84 | 85 | - name: Publish to test PyPI 86 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # release/v1 87 | with: 88 | packages-dir: ${{ inputs.working-directory }}/dist/ 89 | verbose: true 90 | print-hash: true 91 | repository-url: https://test.pypi.org/legacy/ 92 | 93 | # We overwrite any existing distributions with the same name and version. 94 | # This is *only for CI use* and is *extremely dangerous* otherwise! 
95 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates 96 | skip-existing: true 97 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0 98 | attestations: false 99 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CI 3 | 4 | on: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | 9 | # If another push to the same PR or branch happens while this workflow is still running, 10 | # cancel the earlier run in favor of the next run. 11 | # 12 | # There's no point in testing an outdated version of the code. GitHub only allows 13 | # a limited number of job runners to be active at the same time, so it's better to cancel 14 | # pointless jobs early so that more useful jobs can run sooner. 15 | concurrency: 16 | group: ${{ github.workflow }}-${{ github.ref }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | build: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | with: 25 | persist-credentials: false 26 | - uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.10' 29 | - id: files 30 | uses: Ana06/get-changed-files@25f79e676e7ea1868813e21465014798211fad8c # v2.3.0 31 | - id: set-matrix 32 | env: 33 | FILES: ${{ steps.files.outputs.all }} 34 | run: | 35 | python .github/scripts/check_diff.py ${FILES} >> $GITHUB_OUTPUT 36 | outputs: 37 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }} 38 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }} 39 | pre-commit: 40 | name: pre-commit 41 | runs-on: ubuntu-latest 42 | 43 | steps: 44 | - uses: actions/checkout@v4 45 | with: 46 | persist-credentials: false 47 | - uses: actions/setup-python@v5 48 | - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 49 | with: 50 | extra_args: --all-files --hook-stage=manual 51 | lint: 52 | name: cd ${{ matrix.working-directory }} 53 | needs: [ build ] 54 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }} 55 | strategy: 56 | fail-fast: false 57 | matrix: 58 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }} 59 | uses: ./.github/workflows/_lint.yml 60 | with: 61 | working-directory: ${{ matrix.working-directory }} 62 | secrets: inherit 63 | 64 | test: 65 | name: cd ${{ matrix.working-directory }} 66 | needs: [ build ] 67 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }} 68 | strategy: 69 | fail-fast: false 70 | matrix: 71 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }} 72 | uses: ./.github/workflows/_test.yml 73 | with: 74 | working-directory: ${{ matrix.working-directory }} 75 | secrets: inherit 76 | 77 | api-docs: 78 | name: API docs 79 | runs-on: ubuntu-latest 80 | steps: 81 | - uses: actions/checkout@v4 82 | with: 83 | persist-credentials: false 84 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3 85 | - name: Install uv 86 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5 87 | with: 88 | enable-cache: true 89 | python-version: 3.9 90 | - run: just docs 91 | 92 | ci_success: 93 | name: "CI Success" 94 | needs: [build, lint, test, pre-commit, api-docs] 95 | if: | 96 | always() 97 | runs-on: ubuntu-latest 98 | env: 99 | JOBS_JSON: ${{ toJSON(needs) }} 100 | RESULTS_JSON: ${{ toJSON(needs.*.result) }} 101 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || 
'1'}} 102 | steps: 103 | - name: "CI Success" 104 | run: | 105 | echo $JOBS_JSON 106 | echo $RESULTS_JSON 107 | echo "Exiting with $EXIT_CODE" 108 | exit $EXIT_CODE 109 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | branches: [ "main"] 6 | tags: ['*'] 7 | pull_request: 8 | workflow_call: 9 | inputs: 10 | ref: 11 | required: true 12 | type: string 13 | schedule: 14 | - cron: '17 10 * * 2' 15 | 16 | jobs: 17 | analyze: 18 | name: Analyze (${{ matrix.language }}) 19 | runs-on: "ubuntu-latest" 20 | timeout-minutes: 360 21 | permissions: 22 | # required for all workflows 23 | security-events: write 24 | 25 | # required to fetch internal or private CodeQL packs 26 | packages: read 27 | 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | include: 32 | - language: python 33 | build-mode: none 34 | - language: actions 35 | build-mode: none 36 | steps: 37 | - name: Checkout repository 38 | uses: actions/checkout@v4 39 | with: 40 | ref: ${{ inputs.ref }} 41 | persist-credentials: false 42 | - uses: actions/setup-python@v3 43 | 44 | # Initializes the CodeQL tools for scanning. 45 | - name: Initialize CodeQL 46 | uses: github/codeql-action/init@28deaeda66b76a05916b6923827895f2b14ab387 # v3 47 | with: 48 | languages: ${{ matrix.language }} 49 | build-mode: ${{ matrix.build-mode }} 50 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 51 | queries: security-extended 52 | config: | 53 | paths-ignore: 54 | - 'docs/**' 55 | - 'scripts/**' 56 | - '**/tests/**' 57 | 58 | - if: matrix.build-mode == 'manual' 59 | run: | 60 | pip install -e . 61 | 62 | - name: Perform CodeQL Analysis 63 | uses: github/codeql-action/analyze@28deaeda66b76a05916b6923827895f2b14ab387 # v3 64 | with: 65 | category: "/language:${{matrix.language}}" 66 | -------------------------------------------------------------------------------- /.github/workflows/zizmor.yml: -------------------------------------------------------------------------------- 1 | name: GitHub Actions Security Analysis with zizmor 🌈 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["**"] 8 | 9 | jobs: 10 | zizmor: 11 | name: zizmor latest via PyPI 12 | runs-on: ubuntu-latest 13 | permissions: 14 | security-events: write 15 | # required for workflows in private repositories 16 | contents: read 17 | actions: read 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | with: 22 | persist-credentials: false 23 | 24 | - name: Install the latest version of uv 25 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5 26 | 27 | - name: Run zizmor 🌈 28 | run: uvx zizmor --format sarif . 
> results.sarif
29 |         env:
30 |           GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31 |
32 |       - name: Upload SARIF file
33 |         uses: github/codeql-action/upload-sarif@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
34 |         with:
35 |           sarif_file: results.sarif
36 |           category: zizmor
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache
3 | .pytest_cache
4 | .ruff_cache
5 | .mypy_cache_test
6 | .env
7 | .venv*
8 | .local_atlas_uri
9 | docs/langchain_mongodb
10 | docs/langgraph_checkpoint_mongodb
11 | docs/index.md
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 |
2 | repos:
3 |   - repo: https://github.com/pre-commit/pre-commit-hooks
4 |     rev: v4.5.0
5 |     hooks:
6 |       - id: check-added-large-files
7 |       - id: check-case-conflict
8 |       - id: check-toml
9 |       - id: check-yaml
10 |       - id: debug-statements
11 |       - id: end-of-file-fixer
12 |       - id: forbid-new-submodules
13 |       - id: trailing-whitespace
14 |         exclude_types: [json]
15 |         exclude: |
16 |           (?x)^(.*.ambr)$
17 |
18 |   # We use the Python version instead of the original version, which seems to require Docker
19 |   # https://github.com/koalaman/shellcheck-precommit
20 |   - repo: https://github.com/shellcheck-py/shellcheck-py
21 |     rev: v0.9.0.6
22 |     hooks:
23 |       - id: shellcheck
24 |         name: shellcheck
25 |         args: ["--severity=warning"]
26 |         stages: [manual]
27 |
28 |   - repo: https://github.com/sirosen/check-jsonschema
29 |     rev: 0.29.4
30 |     hooks:
31 |       - id: check-github-workflows
32 |         args: ["--verbose"]
33 |
34 |   - repo: https://github.com/codespell-project/codespell
35 |     rev: "v2.2.6"
36 |     hooks:
37 |       - id: codespell
38 |         args: ["-L", "nin"]
39 |         stages: [manual]
40 |
41 |   - repo: https://github.com/astral-sh/ruff-pre-commit
42 |     # Ruff version.
43 |     rev: v0.8.5
44 |     hooks:
45 |       # Run the linter.
46 |       - id: ruff
47 |         args: [ --fix ]
48 |       # Run the formatter.
49 |       - id: ruff-format
50 |
51 |   - repo: https://github.com/tcort/markdown-link-check
52 |     rev: v3.12.2
53 |     hooks:
54 |       - id: markdown-link-check
55 |         args: [-q]
56 |
57 |   - repo: local
58 |     hooks:
59 |       - id: update-locks
60 |         name: update-locks
61 |         entry: bash ./scripts/update-locks.sh
62 |         language: python
63 |         require_serial: true
64 |         fail_fast: true
65 |         additional_dependencies:
66 |           - uv
67 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 |   os: ubuntu-22.04
4 |   tools:
5 |     python: "3.11"
6 |   jobs:
7 |     create_environment:
8 |       - asdf plugin add uv
9 |       - asdf install uv latest
10 |       - asdf global uv latest
11 |     install:
12 |       - uv sync
13 |     build:
14 |       html:
15 |         - uv run sphinx-build -T -b html docs $READTHEDOCS_OUTPUT/html
16 | sphinx:
17 |   configuration: docs/conf.py
18 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 |
3 | We welcome contributions to this project! Please follow the guidance below to set up the project for development and start contributing.
4 |
5 | ### Fork and clone the repository
6 |
7 | To contribute to this project, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/exploring-projects-on-github/contributing-to-a-project) workflow. Please do not try to push directly to this repo unless you are a maintainer.
8 |
9 |
10 | ### Dependency Management: uv and other env/dependency managers
11 |
12 | This project utilizes [uv](https://docs.astral.sh/uv/) v0.5.15+ as a dependency manager.
13 |
14 | Install uv: **[documentation on how to install it](https://docs.astral.sh/uv/getting-started/installation/)**.
15 |
16 | ### Local Development Dependencies
17 |
18 | The project configuration and the `justfile` for running dev commands are located under the `libs/langchain-mongodb` or `libs/langgraph-checkpoint-mongodb` directories.
19 |
20 | `just` can be [installed](https://just.systems/man/en/packages.html) from many package managers, including `brew`.
21 |
22 | ```bash
23 | cd libs/langchain-mongodb
24 | ```
25 |
26 | Install langchain-mongodb development requirements (for running langchain, running examples, linting, formatting, tests, and coverage):
27 |
28 | ```bash
29 | just install
30 | ```
31 |
32 | Then verify the installation.
33 |
34 | ```bash
35 | just unit_tests
36 | ```
37 |
38 | In order to run the integration tests, you'll also need a `MONGODB_URI` for MongoDB Atlas, as well
39 | as either an `OPENAI_API_KEY` or a configured local version of [ollama](https://ollama.com/download).
40 |
41 | We have a convenience script to start a local Atlas instance, which requires `podman`:
42 |
43 | ```bash
44 | scripts/start_local_atlas.sh
45 | ```
46 |
47 | This will create a `.local_atlas_uri` file that has the `MONGODB_URI` set. The `justfiles` are configured
48 | to read the environment variable from this file.
49 |
50 | If using `ollama`, we have a convenience script to download the model used in our tests:
51 |
52 | ```bash
53 | scripts/setup_ollama.sh
54 | ```
55 |
56 | ### Testing
57 |
58 | Unit tests cover modular logic that does not require calls to outside APIs.
59 | If you add new logic, please add a unit test.
60 |
61 | To run unit tests:
62 |
63 | ```bash
64 | just unit_tests
65 | ```
66 |
67 | Integration tests cover the end-to-end service calls as much as possible.
68 | However, in certain cases this might not be practical, so you can mock the
69 | service response for these tests. There are examples of this in the repo
70 | that can help you write your own tests. If you have suggestions to improve
71 | this, please raise an issue.
72 |
73 | To run the integration tests:
74 |
75 | ```bash
76 | just integration_tests
77 | ```
78 |
79 | ### Formatting and Linting
80 |
81 | Formatting ensures that the code in this repo has a consistent style so that it
82 | is more presentable and readable. Running the formatting command corrects style
83 | issues automatically. Linting finds and highlights code errors and helps avoid
84 | coding practices that can lead to bugs.
85 |
86 | Run both of these locally before submitting a PR. The CI scripts will run these
87 | when you submit a PR, and you won't be able to merge changes without fixing
88 | issues identified by the CI.
89 |
90 | We use [pre-commit](https://pypi.org/project/pre-commit/) for
91 | automatic formatting of the codebase.
92 |
93 | To set up `pre-commit` locally, run:
94 |
95 | ```bash
96 | brew install pre-commit
97 | pre-commit install
98 | ```
99 |
100 | #### Manual Linting and Formatting
101 |
102 | Linting and formatting for this project are done via a combination of [ruff](https://docs.astral.sh/ruff/rules/) and [mypy](http://mypy-lang.org/).
103 |
104 | To run lint:
105 |
106 | ```bash
107 | just lint
108 | ```
109 |
110 | To run the type checker:
111 |
112 | ```bash
113 | just typing
114 | ```
115 |
116 | We recognize linting can be annoying - if you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed.
117 |
118 | #### Building the docs
119 |
120 | The docs are built from the top level of the repository by running `just docs`.
121 |
122 | #### Spellcheck
123 |
124 | Spellchecking for this project is done via [codespell](https://github.com/codespell-project/codespell).
125 | Note that `codespell` finds common typos, so it can produce false positives (flagging correctly spelled but rarely used words) and false negatives (missing misspelled words).
126 |
127 | To check spelling for this project:
128 |
129 | ```bash
130 | just codespell
131 | ```
132 |
133 | If codespell is incorrectly flagging a word, you can skip spellcheck for that word by adding it to the codespell config in the `.pre-commit-config.yaml` file.
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 LangChain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🦜️🔗 LangChain MongoDB
2 |
3 | This is a monorepo containing partner packages of MongoDB and LangChain AI.
4 | It includes integrations between MongoDB, Atlas, LangChain, and LangGraph.
5 |
6 | It contains the following packages.
7 |
8 | - `langchain-mongodb` ([PyPI](https://pypi.org/project/langchain-mongodb/))
9 | - `langgraph-checkpoint-mongodb` ([PyPI](https://pypi.org/project/langgraph-checkpoint-mongodb/))
10 |
11 | **Note**: This repository replaces all MongoDB integrations currently present in the `langchain-community` package. Users are encouraged to migrate to this repository as soon as possible.
12 |
13 | ## Features
14 |
15 | ### LangChain
16 |
17 | #### Components
18 |
19 | - [MongoDBAtlasFullTextSearchRetriever](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#full-text-search-retriever)
20 | - [MongoDBAtlasHybridSearchRetriever](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#hybrid-search-retriever)
21 | - [MongoDBAtlasSemanticCache](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#mongodbatlassemanticcache)
22 | - [MongoDBAtlasVectorSearch](https://python.langchain.com/docs/integrations/vectorstores/mongodb_atlas/)
23 | - [MongoDBCache](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#mongodbcache)
24 | - [MongoDBChatMessageHistory](https://python.langchain.com/docs/integrations/memory/mongodb_chat_message_history/)
25 |
26 | #### API Reference
27 |
28 | - [MongoDBAtlasParentDocumentRetriever](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/retrievers/langchain_mongodb.retrievers.parent_document.MongoDBAtlasParentDocumentRetriever.html#langchain_mongodb.retrievers.parent_document.MongoDBAtlasParentDocumentRetriever)
29 | - [MongoDBAtlasSelfQueryRetriever](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/retrievers/langchain_mongodb.retrievers.self_querying.MongoDBAtlasSelfQueryRetriever.html)
30 | - [MongoDBDatabaseToolkit](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/agent_toolkit/langchain_mongodb.agent_toolkit.toolkit.MongoDBDatabaseToolkit.html)
31 | - [MongoDBDatabase](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/agent_toolkit/langchain_mongodb.agent_toolkit.database.MongoDBDatabase.html#langchain_mongodb.agent_toolkit.database.MongoDBDatabase)
32 | - [MongoDBDocStore](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/docstores/langchain_mongodb.docstores.MongoDBDocStore.html#langchain_mongodb.docstores.MongoDBDocStore)
33 | - [MongoDBGraphStore](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/graphrag/langchain_mongodb.graphrag.graph.MongoDBGraphStore.html)
34 | - [MongoDBLoader](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/loaders/langchain_mongodb.loaders.MongoDBLoader.html#langchain_mongodb.loaders.MongoDBLoader)
35 | - [MongoDBRecordManager](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/indexes/langchain_mongodb.indexes.MongoDBRecordManager.html#langchain_mongodb.indexes.MongoDBRecordManager)
36 |
37 | ### LangGraph
38 |
39 | - Checkpointing (BaseCheckpointSaver)
40 |   - [MongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/saver/langgraph.checkpoint.mongodb.saver.MongoDBSaver.html#mongodbsaver)
41 |   - [AsyncMongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/aio/langgraph.checkpoint.mongodb.aio.AsyncMongoDBSaver.html#asyncmongodbsaver)
42 |
43 | ## Installation
44 |
45 | You can install the `langchain-mongodb` package from PyPI.
46 |
47 | ```bash
48 | pip install langchain-mongodb
49 | ```
50 |
51 | You can install the `langgraph-checkpoint-mongodb` package from PyPI as well:
52 |
53 | ```bash
54 | pip install langgraph-checkpoint-mongodb
55 | ```
56 |
57 | ## Usage
58 |
59 | See [langchain-mongodb usage](libs/langchain-mongodb/README.md#usage) and [langgraph-checkpoint-mongodb usage](libs/langgraph-checkpoint-mongodb/README.md#usage).
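
For a quick sense of the checkpointing integration, here is a minimal, hedged sketch. It is not canonical documentation: the connection string, thread id, and the toy `State` graph are placeholders you would replace with your own.

```python
from typing import TypedDict

from langgraph.checkpoint.mongodb import MongoDBSaver
from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    count: int


def increment(state: State) -> State:
    # A trivial node that bumps the counter each time the graph runs.
    return {"count": state["count"] + 1}


builder = StateGraph(State)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")
builder.add_edge("increment", END)

# The URI below is a placeholder for your own MongoDB deployment.
with MongoDBSaver.from_conn_string("mongodb://localhost:27017") as checkpointer:
    graph = builder.compile(checkpointer=checkpointer)
    # Runs that share a thread_id read and write the same stored checkpoint.
    config = {"configurable": {"thread_id": "example-thread"}}
    print(graph.invoke({"count": 0}, config))  # {'count': 1}
    print(graph.get_state(config).values)  # the latest checkpointed state
```

Swapping in `AsyncMongoDBSaver` gives the same behavior for async graphs.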
60 |
61 | For more detailed usage examples and documentation, please refer to the [LangChain documentation](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/).
62 |
63 | API docs can be found on [ReadTheDocs](https://langchain-mongodb.readthedocs.io/en/latest/index.html).
64 |
65 | ## Contributing
66 |
67 | See the [Contributing Guide](CONTRIBUTING.md).
68 |
69 | ## License
70 |
71 | This project is licensed under the [MIT License](LICENSE).
72 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # LangChain MongoDB Releases
2 |
3 | ## Prep the JIRA Release
4 |
5 | - Go to the release in [JIRA](https://jira.mongodb.org/projects/INTPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin%3Arelease-page&status=unreleased).
6 |
7 | - Make sure there are no unfinished tickets. Move them to another version if need be.
8 |
9 | - Click on the triple dot icon to the right and select "Edit".
10 |
11 | - Update the description for a quick summary.
12 |
13 | ## Prep the Release
14 |
15 | - Create a PR to bump the version and update the changelog, including today's date.
16 |   Bump the minor version for new features, patch for a bug fix.
17 |
18 | - Merge the PR.
19 |
20 | ## Run the Release Workflow
21 |
22 | - Go to the release [workflow](https://github.com/langchain-ai/langchain-mongodb/actions/workflows/_release.yml).
23 |
24 | - Click "Run Workflow".
25 |
26 | - Choose the appropriate library from the dropdown.
27 |
28 | - Click "Run Workflow".
29 |
30 | - The workflow will create the tag, release to PyPI, and create the GitHub Release.
31 |
32 | ## JIRA Release
33 |
34 | - Return to the JIRA release [list](https://jira.mongodb.org/projects/INTPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin%3Arelease-page&status=unreleased).
35 |
36 | - Click "Save".
37 |
38 | - Click on the triple dot again and select "Release".
39 |
40 | - Enter today's date, and click "Confirm".
41 |
42 | - Click "Release".
43 |
44 |
45 | ## Finish the Release
46 |
47 | - Return to the release action and wait for it to complete successfully.
48 |
49 | - Announce the release on Slack, e.g., "ANN: langchain-mongodb 0.5 with support for GraphRAG."
50 |
--------------------------------------------------------------------------------
/docs/_extensions/gallery_directive.py:
--------------------------------------------------------------------------------
1 | """A directive to generate a gallery of images from structured data.
2 |
3 | Generating a gallery of images that are all the same size is a common
4 | pattern in documentation, and this can be cumbersome if the gallery is
5 | generated programmatically. This directive wraps this particular use-case
6 | in a helper-directive to generate it with a single YAML configuration file.
7 |
8 | It currently exists for maintainers of the pydata-sphinx-theme,
9 | but might be abstracted into a standalone package if it proves useful.
10 | """ 11 | 12 | from pathlib import Path 13 | from typing import Any, ClassVar, Dict, List 14 | 15 | from docutils import nodes 16 | from docutils.parsers.rst import directives 17 | from sphinx.application import Sphinx 18 | from sphinx.util import logging 19 | from sphinx.util.docutils import SphinxDirective 20 | from yaml import safe_load 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | TEMPLATE_GRID = """ 26 | `````{{grid}} {columns} 27 | {options} 28 | 29 | {content} 30 | 31 | ````` 32 | """ 33 | 34 | GRID_CARD = """ 35 | ````{{grid-item-card}} {title} 36 | {options} 37 | 38 | {content} 39 | ```` 40 | """ 41 | 42 | 43 | class GalleryGridDirective(SphinxDirective): 44 | """A directive to show a gallery of images and links in a Bootstrap grid. 45 | 46 | The grid can be generated from a YAML file that contains a list of items, or 47 | from the content of the directive (also formatted in YAML). Use the parameter 48 | "class-card" to add an additional CSS class to all cards. When specifying the grid 49 | items, you can use all parameters from "grid-item-card" directive to customize 50 | individual cards + ["image", "header", "content", "title"]. 51 | 52 | Danger: 53 | This directive can only be used in the context of a Myst documentation page as 54 | the templates use Markdown flavored formatting. 55 | """ 56 | 57 | name = "gallery-grid" 58 | has_content = True 59 | required_arguments = 0 60 | optional_arguments = 1 61 | final_argument_whitespace = True 62 | option_spec: ClassVar[dict[str, Any]] = { 63 | # A class to be added to the resulting container 64 | "grid-columns": directives.unchanged, 65 | "class-container": directives.unchanged, 66 | "class-card": directives.unchanged, 67 | } 68 | 69 | def run(self) -> List[nodes.Node]: 70 | """Create the gallery grid.""" 71 | if self.arguments: 72 | # If an argument is given, assume it's a path to a YAML file 73 | # Parse it and load it into the directive content 74 | path_data_rel = Path(self.arguments[0]) 75 | path_doc, _ = self.get_source_info() 76 | path_doc = Path(path_doc).parent 77 | path_data = (path_doc / path_data_rel).resolve() 78 | if not path_data.exists(): 79 | logger.info(f"Could not find grid data at {path_data}.") 80 | nodes.text("No grid data found at {path_data}.") 81 | return 82 | yaml_string = path_data.read_text() 83 | else: 84 | yaml_string = "\n".join(self.content) 85 | 86 | # Use all the element with an img-bottom key as sites to show 87 | # and generate a card item for each of them 88 | grid_items = [] 89 | for item in safe_load(yaml_string): 90 | # remove parameters that are not needed for the card options 91 | title = item.pop("title", "") 92 | 93 | # build the content of the card using some extra parameters 94 | header = f"{item.pop('header')} \n^^^ \n" if "header" in item else "" 95 | image = f"![image]({item.pop('image')}) \n" if "image" in item else "" 96 | content = f"{item.pop('content')} \n" if "content" in item else "" 97 | 98 | # optional parameter that influence all cards 99 | if "class-card" in self.options: 100 | item["class-card"] = self.options["class-card"] 101 | 102 | loc_options_str = "\n".join(f":{k}: {v}" for k, v in item.items()) + " \n" 103 | 104 | card = GRID_CARD.format( 105 | options=loc_options_str, content=header + image + content, title=title 106 | ) 107 | grid_items.append(card) 108 | 109 | # Parse the template with Sphinx Design to create an output container 110 | # Prep the options for the template grid 111 | class_ = "gallery-directive" + f' 
{self.options.get("class-container", "")}' 112 | options = {"gutter": 2, "class-container": class_} 113 | options_str = "\n".join(f":{k}: {v}" for k, v in options.items()) 114 | 115 | # Create the directive string for the grid 116 | grid_directive = TEMPLATE_GRID.format( 117 | columns=self.options.get("grid-columns", "1 2 3 4"), 118 | options=options_str, 119 | content="\n".join(grid_items), 120 | ) 121 | 122 | # Parse content as a directive so Sphinx Design processes it 123 | container = nodes.container() 124 | self.state.nested_parse([grid_directive], 0, container) 125 | 126 | # Sphinx Design outputs a container too, so just use that 127 | return [container.children[0]] 128 | 129 | 130 | def setup(app: Sphinx) -> Dict[str, Any]: 131 | """Add custom configuration to sphinx app. 132 | 133 | Args: 134 | app: the Sphinx application 135 | 136 | Returns: 137 | the 2 parallel parameters set to ``True``. 138 | """ 139 | app.add_directive("gallery-grid", GalleryGridDirective) 140 | 141 | return { 142 | "parallel_read_safe": True, 143 | "parallel_write_safe": True, 144 | } 145 | -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/_static/wordmark-api-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /docs/_static/wordmark-api.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /docs/templates/class.rst: -------------------------------------------------------------------------------- 1 | {{ objname }} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | {% if methods %} 10 | .. rubric:: {{ _('Methods') }} 11 | 12 | .. autosummary:: 13 | {% for item in methods %} 14 | ~{{ item }} 15 | {%- endfor %} 16 | 17 | {% for item in methods %} 18 | .. automethod:: {{ item }} 19 | {%- endfor %} 20 | 21 | {% endif %} 22 | {% endblock %} 23 | -------------------------------------------------------------------------------- /docs/templates/enum.rst: -------------------------------------------------------------------------------- 1 | {{ objname }} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block attributes %} 9 | {% for item in attributes %} 10 | .. autoattribute:: {{ item }} 11 | {% endfor %} 12 | {% endblock %} 13 | -------------------------------------------------------------------------------- /docs/templates/function.rst: -------------------------------------------------------------------------------- 1 | {{ objname }} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. 
autofunction:: {{ objname }}
7 |
--------------------------------------------------------------------------------
/docs/templates/langchain_docs.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
9 |
10 |
11 | Docs
12 |
13 |
--------------------------------------------------------------------------------
/docs/templates/pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autopydantic_model:: {{ objname }}
7 |     :model-show-json: False
8 |     :model-show-config-summary: False
9 |     :model-show-validator-members: False
10 |     :model-show-field-summary: False
11 |     :field-signature-prefix: param
12 |     :members:
13 |     :undoc-members:
14 |     :inherited-members:
15 |     :member-order: groupwise
16 |     :show-inheritance: True
17 |     :special-members: __call__
18 |     :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, model_extra, model_fields_set, model_json_schema
19 |
20 |
21 | {% block attributes %}
22 | {% endblock %}
23 |
--------------------------------------------------------------------------------
/docs/templates/runnable_non_pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 |    .. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>`. 🏃
9 |
10 |       The :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>` has additional methods that are available on runnables, such as :py:meth:`with_types <langchain_core.runnables.base.Runnable.with_types>`, :py:meth:`with_retry <langchain_core.runnables.base.Runnable.with_retry>`, :py:meth:`assign <langchain_core.runnables.base.Runnable.assign>`, :py:meth:`bind <langchain_core.runnables.base.Runnable.bind>`, :py:meth:`get_graph <langchain_core.runnables.base.Runnable.get_graph>`, and more.
11 |
12 | {% block attributes %}
13 | {% if attributes %}
14 | .. rubric:: {{ _('Attributes') }}
15 |
16 | .. autosummary::
17 | {% for item in attributes %}
18 |    ~{{ item }}
19 | {%- endfor %}
20 | {% endif %}
21 | {% endblock %}
22 |
23 | {% block methods %}
24 | {% if methods %}
25 | .. rubric:: {{ _('Methods') }}
26 |
27 | .. autosummary::
28 | {% for item in methods %}
29 |    ~{{ item }}
30 | {%- endfor %}
31 |
32 | {% for item in methods %}
33 | .. automethod:: {{ item }}
34 | {%- endfor %}
35 |
36 | {% endif %}
37 | {% endblock %}
38 |
--------------------------------------------------------------------------------
/docs/templates/runnable_pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | ..
autopydantic_model:: {{ objname }}
7 |     :model-show-json: False
8 |     :model-show-config-summary: False
9 |     :model-show-validator-members: False
10 |     :model-show-field-summary: False
11 |     :field-signature-prefix: param
12 |     :members:
13 |     :undoc-members:
14 |     :inherited-members:
15 |     :member-order: groupwise
16 |     :show-inheritance: True
17 |     :special-members: __call__
18 |     :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema, predict, apredict, predict_messages, apredict_messages, generate, generate_prompt, agenerate, agenerate_prompt, call_as_llm
19 |
20 | .. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>`. 🏃
21 |
22 |    The :py:class:`Runnable Interface <langchain_core.runnables.base.Runnable>` has additional methods that are available on runnables, such as :py:meth:`with_types <langchain_core.runnables.base.Runnable.with_types>`, :py:meth:`with_retry <langchain_core.runnables.base.Runnable.with_retry>`, :py:meth:`assign <langchain_core.runnables.base.Runnable.assign>`, :py:meth:`bind <langchain_core.runnables.base.Runnable.bind>`, :py:meth:`get_graph <langchain_core.runnables.base.Runnable.get_graph>`, and more.
--------------------------------------------------------------------------------
/docs/templates/typeddict.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block attributes %}
9 | {% for item in attributes %}
10 | .. autoattribute:: {{ item }}
11 | {% endfor %}
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 |
3 | # Default target executed when no arguments are given.
4 | [private]
5 | default:
6 |     @just --list
7 |
8 | docs:
9 |     uv run sphinx-build -T -b html docs docs/_build/html
--------------------------------------------------------------------------------
/libs/langchain-mongodb/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/libs/langchain-mongodb/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ---
4 |
5 | ## Changes in version 0.6.2 (2025/05/12)
6 |
7 | - Improve robustness of `MongoDBDatabase.run`.
8 |
9 | ## Changes in version 0.6.1 (2025/04/16)
10 |
11 | - Fixed a syntax error in a docstring.
12 | - Fixed some incorrect typings.
13 | - Added detail to README.
14 | - Added MongoDBDocStore to README.
15 |
16 | ## Changes in version 0.6 (2025/03/26)
17 |
18 | - Added Natural language to MQL Database tool.
19 | - Added `MongoDBAtlasSelfQueryRetriever`.
20 | - Added logic for vector stores to optionally create vector search indexes.
21 | - Added `close()` methods to classes to ensure proper cleanup of resources.
22 | - Changed the default `batch_size` to 100 to align with resource constraints on
23 |   AI model APIs.
24 |
25 | ## Changes in version 0.5 (2025/02/25)
26 |
27 | - Added GraphRAG support via `MongoDBGraphStore`.
28 |
29 | ## Changes in version 0.4 (2025/01/09)
30 |
31 | - Added support for `MongoDBRecordManager`.
32 | - Added support for `MongoDBLoader`.
33 | - Added support for `numpy 2.0`.
34 | - Added zizmor GitHub Actions security scanning.
35 | - Added local LLM support for testing.
36 |
37 | ## Changes in version 0.3 (2024/12/13)
38 |
39 | - Added support for `MongoDBAtlasParentDocumentRetriever`.
40 | - Migrated to https://github.com/langchain-ai/langchain-mongodb.
41 |
42 | ## Changes in version 0.2 (2024/09/13)
43 |
44 | - Added support for `MongoDBAtlasFullTextSearchRetriever` and `MongoDBAtlasHybridSearchRetriever`.
45 |
46 | ## Changes in version 0.1 (2024/02/29)
47 |
48 | - Initial release, added support for `MongoDBAtlasVectorSearch`.
49 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 LangChain, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # langchain-mongodb
4 |
5 | # Installation
6 | ```
7 | pip install -U langchain-mongodb
8 | ```
9 |
10 | # Usage
11 | - [Integrate Atlas Vector Search with LangChain](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.
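
Beyond the vector store shown below, the package ships smaller building blocks that follow the same connection pattern. As a hedged sketch (the database, collection, and session names here are illustrative placeholders), persisting chat history looks like this:

```python
import os

from langchain_mongodb import MongoDBChatMessageHistory

history = MongoDBChatMessageHistory(
    connection_string=os.environ["MONGODB_CONNECTION_STRING"],
    session_id="example-session",      # placeholder session key
    database_name="langchain_db",      # placeholder database name
    collection_name="chat_histories",  # placeholder collection name
)

history.add_user_message("What is a vector search index?")
history.add_ai_message("An index that supports approximate nearest-neighbor search.")

print(history.messages)  # all messages stored for this session, in order
```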
12 |
13 | ## Using MongoDBAtlasVectorSearch
14 | ```python
15 | import os
16 | from langchain_mongodb import MongoDBAtlasVectorSearch
17 | from langchain_openai import OpenAIEmbeddings
18 |
19 | # Pull the MongoDB Atlas URI from environment variables
20 | MONGODB_ATLAS_CONNECTION_STRING = os.environ["MONGODB_CONNECTION_STRING"]
21 | DB_NAME = "langchain_db"
22 | COLLECTION_NAME = "test"
23 | VECTOR_SEARCH_INDEX_NAME = "index_name"
24 |
25 | MODEL_NAME = "text-embedding-3-large"
26 | OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
27 |
28 |
29 | vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
30 |     connection_string=MONGODB_ATLAS_CONNECTION_STRING,
31 |     namespace=DB_NAME + "." + COLLECTION_NAME,
32 |     embedding=OpenAIEmbeddings(model=MODEL_NAME),
33 |     index_name=VECTOR_SEARCH_INDEX_NAME,
34 | )
35 |
36 | retrieved_docs = vectorstore.similarity_search(
37 |     "How do I deploy MongoDBAtlasVectorSearch in our production environment?")
38 | ```
39 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 | set dotenv-load
3 | set dotenv-filename := "../../.local_atlas_uri"
4 |
5 | # Default target executed when no arguments are given.
6 | [private]
7 | default:
8 |     @just --list
9 |
10 | install:
11 |     uv sync --frozen
12 |
13 | [group('test')]
14 | integration_tests *args="":
15 |     uv run pytest tests/integration_tests/ {{args}}
16 |
17 | [group('test')]
18 | unit_tests *args="":
19 |     uv run pytest tests/unit_tests {{args}}
20 |
21 | [group('test')]
22 | tests *args="":
23 |     uv run pytest {{args}}
24 |
25 | [group('test')]
26 | test_watch filename:
27 |     uv run ptw --snapshot-update --now . -- -vv {{filename}}
28 |
29 | [group('lint')]
30 | lint:
31 |     git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files
32 |     git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files
33 |
34 | [group('lint')]
35 | typing:
36 |     uv run mypy --install-types --non-interactive .
37 |
38 | [group('lint')]
39 | codespell:
40 |     git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files
41 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Integrate your operational database and vector search in a single, unified,
3 | fully managed platform with full vector database capabilities on MongoDB Atlas.
4 |
5 |
6 | Store your operational data, metadata, and vector embeddings in our VectorStore,
7 | MongoDBAtlasVectorSearch.
8 | Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
9 | """ 10 | 11 | from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache 12 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory 13 | from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch 14 | 15 | __all__ = [ 16 | "MongoDBAtlasVectorSearch", 17 | "MongoDBChatMessageHistory", 18 | "MongoDBCache", 19 | "MongoDBAtlasSemanticCache", 20 | ] 21 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/agent_toolkit/__init__.py: -------------------------------------------------------------------------------- 1 | from .database import MongoDBDatabase 2 | from .prompt import MONGODB_AGENT_SYSTEM_PROMPT 3 | from .toolkit import MongoDBDatabaseToolkit 4 | 5 | __all__ = ["MongoDBDatabaseToolkit", "MongoDBDatabase", "MONGODB_AGENT_SYSTEM_PROMPT"] 6 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/agent_toolkit/prompt.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | 4 | MONGODB_AGENT_SYSTEM_PROMPT = """You are an agent designed to interact with a MongoDB database. 5 | Given an input question, create a syntactically correct MongoDB query to run, then look at the results of the query and return the answer. 6 | Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. 7 | You can order the results by a relevant field to return the most interesting examples in the database. 8 | Never query for all the fields from a specific collection, only ask for the relevant fields given the question. 9 | 10 | You have access to tools for interacting with the database. 11 | Only use the below tools. Only use the information returned by the below tools to construct your final answer. 12 | You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. 13 | 14 | DO NOT make any update, insert, or delete operations. 15 | 16 | The query MUST include the collection name and the contents of the aggregation pipeline. 17 | 18 | An example query looks like: 19 | 20 | ```python 21 | db.Invoice.aggregate([ {{ "$group": {{ _id: "$BillingCountry", "totalSpent": {{ "$sum": "$Total" }} }} }}, {{ "$sort": {{ "totalSpent": -1 }} }}, {{ "$limit": 5 }} ]) 22 | ``` 23 | 24 | To start you should ALWAYS look at the collections in the database to see what you can query. 25 | Do NOT skip this step. 26 | Then you should query the schema of the most relevant collections.""" 27 | 28 | MONGODB_SUFFIX = """Begin! 29 | 30 | Question: {input} 31 | Thought: I should look at the collections in the database to see what I can query. Then I should query the schema of the most relevant collections. 32 | {agent_scratchpad}""" 33 | 34 | MONGODB_FUNCTIONS_SUFFIX = """I should look at the collections in the database to see what I can query. Then I should query the schema of the most relevant collections.""" 35 | 36 | 37 | MONGODB_QUERY_CHECKER = """ 38 | {query} 39 | 40 | Double check the MongoDB query above for common mistakes, including: 41 | - Missing content in the aggegregation pipeline 42 | - Improperly quoting identifiers 43 | - Improperly quoting operators 44 | - The content in the aggregation pipeline is not valid JSON 45 | 46 | If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query. 
47 | 48 | Output the final MongoDB query only. 49 | 50 | MongoDB Query: """ 51 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/agent_toolkit/tool.py: -------------------------------------------------------------------------------- 1 | """Tools for interacting with a MongoDB database.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, Dict, Optional, Type 6 | 7 | from langchain_core.callbacks import ( 8 | CallbackManagerForToolRun, 9 | ) 10 | from langchain_core.language_models import BaseLanguageModel 11 | from langchain_core.prompts import PromptTemplate 12 | from langchain_core.tools import BaseTool 13 | from pydantic import BaseModel, ConfigDict, Field, model_validator 14 | from pymongo.cursor import Cursor 15 | 16 | from .database import MongoDBDatabase 17 | from .prompt import MONGODB_QUERY_CHECKER 18 | 19 | 20 | class BaseMongoDBDatabaseTool(BaseModel): 21 | """Base tool for interacting with a MongoDB database.""" 22 | 23 | db: MongoDBDatabase = Field(exclude=True) 24 | 25 | model_config = ConfigDict( 26 | arbitrary_types_allowed=True, 27 | ) 28 | 29 | 30 | class _QueryMongoDBDatabaseToolInput(BaseModel): 31 | query: str = Field(..., description="A detailed and correct MongoDB query.") 32 | 33 | 34 | class QueryMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override] 35 | """Tool for querying a MongoDB database.""" 36 | 37 | name: str = "mongodb_query" 38 | description: str = """ 39 | Execute a MongoDB query against the database and get back the result. 40 | If the query is not correct, an error message will be returned. 41 | If an error is returned, rewrite the query, check the query, and try again. 42 | """ 43 | args_schema: Type[BaseModel] = _QueryMongoDBDatabaseToolInput 44 | 45 | def _run(self, query: str, **kwargs: Any) -> str | Cursor: 46 | """Execute the query, return the results or an error message.""" 47 | return self.db.run_no_throw(query) 48 | 49 | 50 | class _InfoMongoDBDatabaseToolInput(BaseModel): 51 | collection_names: str = Field( 52 | ..., 53 | description=( 54 | "A comma-separated list of the collection names for which to return the schema. " 55 | "Example input: 'collection1, collection2, collection3'" 56 | ), 57 | ) 58 | 59 | 60 | class InfoMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override] 61 | """Tool for getting metadata about a MongoDB database.""" 62 | 63 | name: str = "mongodb_schema" 64 | description: str = ( 65 | "Get the schema and sample documents for the specified MongoDB collections." 66 | ) 67 | args_schema: Type[BaseModel] = _InfoMongoDBDatabaseToolInput 68 | 69 | def _run( 70 | self, 71 | collection_names: str, 72 | run_manager: Optional[CallbackManagerForToolRun] = None, 73 | ) -> str: 74 | """Get the schema for collections in a comma-separated list.""" 75 | return self.db.get_collection_info_no_throw( 76 | [t.strip() for t in collection_names.split(",")] 77 | ) 78 | 79 | 80 | class _ListMongoDBDatabaseToolInput(BaseModel): 81 | tool_input: str = Field("", description="An empty string") 82 | 83 | 84 | class ListMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override] 85 | """Tool for getting collection names.""" 86 | 87 | name: str = "mongodb_list_collections" 88 | description: str = "Input is an empty string, output is a comma-separated list of collections in the database." 
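# Usage sketch (hypothetical, for illustration only; assumes a connected MongoDBDatabase named db): # ListMongoDBDatabaseTool(db=db).invoke("") might return "Album, Artist, Customer" -- the names shown are illustrative, not real output.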
89 | args_schema: Type[BaseModel] = _ListMongoDBDatabaseToolInput 90 | 91 | def _run( 92 | self, 93 | tool_input: str = "", 94 | run_manager: Optional[CallbackManagerForToolRun] = None, 95 | ) -> str: 96 | """Get a comma-separated list of collection names.""" 97 | return ", ".join(self.db.get_usable_collection_names()) 98 | 99 | 100 | class _QueryMongoDBCheckerToolInput(BaseModel): 101 | query: str = Field(..., description="A detailed and correct MongoDB query to be checked.") 102 | 103 | 104 | class QueryMongoDBCheckerTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override] 105 | """Use an LLM to check if a query is correct. 106 | Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" 107 | 108 | template: str = MONGODB_QUERY_CHECKER 109 | llm: BaseLanguageModel 110 | prompt: PromptTemplate = Field(init=False) 111 | name: str = "mongodb_query_checker" 112 | description: str = """ 113 | Use this tool to double check if your query is correct before executing it. 114 | Always use this tool before executing a query with mongodb_query! 115 | """ 116 | args_schema: Type[BaseModel] = _QueryMongoDBCheckerToolInput 117 | 118 | @model_validator(mode="before") 119 | @classmethod 120 | def initialize_prompt(cls, values: Dict[str, Any]) -> Any: 121 | if "prompt" not in values: 122 | values["prompt"] = PromptTemplate( 123 | template=MONGODB_QUERY_CHECKER, input_variables=["query"] 124 | ) 125 | 126 | if values["prompt"].input_variables != ["query"]: 127 | raise ValueError( 128 | "Prompt for QueryCheckerTool must have input variables ['query']" 129 | ) 130 | 131 | return values 132 | 133 | def _run( 134 | self, 135 | query: str, 136 | run_manager: Optional[CallbackManagerForToolRun] = None, 137 | ) -> str: 138 | """Use the LLM to check the query.""" 139 | # TODO: check the query using pymongo first. 140 | chain = self.prompt | self.llm 141 | return chain.invoke(query) # type:ignore[arg-type] 142 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/agent_toolkit/toolkit.py: -------------------------------------------------------------------------------- 1 | """Toolkit for interacting with a MongoDB database.""" 2 | 3 | from typing import List 4 | 5 | from langchain_core.caches import BaseCache as BaseCache 6 | from langchain_core.callbacks import Callbacks as Callbacks 7 | from langchain_core.language_models import BaseLanguageModel 8 | from langchain_core.tools import BaseTool 9 | from langchain_core.tools.base import BaseToolkit 10 | from pydantic import ConfigDict, Field 11 | 12 | from .database import MongoDBDatabase 13 | from .tool import ( 14 | InfoMongoDBDatabaseTool, 15 | ListMongoDBDatabaseTool, 16 | QueryMongoDBCheckerTool, 17 | QueryMongoDBDatabaseTool, 18 | ) 19 | 20 | 21 | class MongoDBDatabaseToolkit(BaseToolkit): 22 | """MongoDBDatabaseToolkit for interacting with MongoDB databases. 23 | 24 | Setup: 25 | Install ``langchain-mongodb``. 26 | 27 | .. code-block:: bash 28 | 29 | pip install -U langchain-mongodb 30 | 31 | Key init args: 32 | db: MongoDBDatabase 33 | The MongoDB database. 34 | llm: BaseLanguageModel 35 | The language model (for use with QueryMongoDBCheckerTool) 36 | 37 | Instantiate: 38 | ..
code-block:: python 39 | 40 | from langchain_mongodb.agent_toolkit.toolkit import MongoDBDatabaseToolkit 41 | from langchain_mongodb.agent_toolkit.database import MongoDBDatabase 42 | from langchain_openai import ChatOpenAI 43 | 44 | db = MongoDBDatabase.from_connection_string("mongodb://localhost:27017/chinook") 45 | llm = ChatOpenAI(temperature=0) 46 | 47 | toolkit = MongoDBDatabaseToolkit(db=db, llm=llm) 48 | 49 | Tools: 50 | .. code-block:: python 51 | 52 | toolkit.get_tools() 53 | 54 | Use within an agent: 55 | .. code-block:: python 56 | 57 | from langchain import hub 58 | from langgraph.prebuilt import create_react_agent 59 | from langchain_mongodb.agent_toolkit import MONGODB_AGENT_SYSTEM_PROMPT 60 | 61 | # Pull prompt (or define your own) 62 | system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5) 63 | 64 | # Create agent 65 | agent_executor = create_react_agent( 66 | llm, toolkit.get_tools(), state_modifier=system_message 67 | ) 68 | 69 | # Query agent 70 | example_query = "Which country's customers spent the most?" 71 | 72 | events = agent_executor.stream( 73 | {"messages": [("user", example_query)]}, 74 | stream_mode="values", 75 | ) 76 | for event in events: 77 | event["messages"][-1].pretty_print() 78 | """ # noqa: E501 79 | 80 | db: MongoDBDatabase = Field(exclude=True) 81 | llm: BaseLanguageModel = Field(exclude=True) 82 | 83 | model_config = ConfigDict( 84 | arbitrary_types_allowed=True, 85 | ) 86 | 87 | def get_tools(self) -> List[BaseTool]: 88 | """Get the tools in the toolkit.""" 89 | list_mongodb_database_tool = ListMongoDBDatabaseTool(db=self.db) 90 | info_mongodb_database_tool_description = ( 91 | "Input to this tool is a comma-separated list of collections, output is the " 92 | "schema and sample rows for those collections. " 93 | "Be sure that the collections actually exist by calling " 94 | f"{list_mongodb_database_tool.name} first! " 95 | "Example Input: collection1, collection2, collection3" 96 | ) 97 | info_mongodb_database_tool = InfoMongoDBDatabaseTool( 98 | db=self.db, description=info_mongodb_database_tool_description 99 | ) 100 | query_mongodb_database_tool_description = ( 101 | "Input to this tool is a detailed and correct MongoDB query, output is a " 102 | "result from the database. If the query is not correct, an error message " 103 | "will be returned. If an error is returned, rewrite the query, check the " 104 | "query, and try again. If you encounter an issue with Unknown column " 105 | f"'xxxx' in 'field list', use {info_mongodb_database_tool.name} " 106 | "to query the correct collection fields." 107 | ) 108 | query_mongodb_database_tool = QueryMongoDBDatabaseTool( 109 | db=self.db, description=query_mongodb_database_tool_description 110 | ) 111 | query_mongodb_checker_tool_description = ( 112 | "Use this tool to double check if your query is correct before executing " 113 | "it. Always use this tool before executing a query with " 114 | f"{query_mongodb_database_tool.name}!"
115 | ) 116 | query_mongodb_checker_tool = QueryMongoDBCheckerTool( 117 | db=self.db, llm=self.llm, description=query_mongodb_checker_tool_description 118 | ) 119 | return [ 120 | query_mongodb_database_tool, 121 | info_mongodb_database_tool, 122 | list_mongodb_database_tool, 123 | query_mongodb_checker_tool, 124 | ] 125 | 126 | def get_context(self) -> dict: 127 | """Return db context that you may want in agent prompt.""" 128 | return self.db.get_context() 129 | 130 | 131 | MongoDBDatabaseToolkit.model_rebuild() 132 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from importlib.metadata import version 4 | from typing import Dict, List, Optional 5 | 6 | from langchain_core.chat_history import BaseChatMessageHistory 7 | from langchain_core.messages import ( 8 | BaseMessage, 9 | message_to_dict, 10 | messages_from_dict, 11 | ) 12 | from pymongo import MongoClient, errors 13 | from pymongo.driver_info import DriverInfo 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | DEFAULT_DBNAME = "chat_history" 18 | DEFAULT_COLLECTION_NAME = "message_store" 19 | DEFAULT_SESSION_ID_KEY = "SessionId" 20 | DEFAULT_HISTORY_KEY = "History" 21 | 22 | 23 | class MongoDBChatMessageHistory(BaseChatMessageHistory): 24 | """Chat message history that stores history in MongoDB. 25 | 26 | Setup: 27 | Install ``langchain-mongodb`` python package. 28 | 29 | .. code-block:: bash 30 | 31 | pip install langchain-mongodb 32 | 33 | Instantiate: 34 | .. code-block:: python 35 | 36 | from langchain_mongodb import MongoDBChatMessageHistory 37 | 38 | 39 | history = MongoDBChatMessageHistory( 40 | connection_string="mongodb://your-host:your-port/", # mongodb://localhost:27017/ 41 | session_id = "your-session-id", 42 | ) 43 | 44 | Add and retrieve messages: 45 | .. code-block:: python 46 | 47 | # Add single message 48 | history.add_message(message) 49 | 50 | # Add batch messages 51 | history.add_messages([message1, message2, message3, ...]) 52 | 53 | # Add human message 54 | history.add_user_message(human_message) 55 | 56 | # Add ai message 57 | history.add_ai_message(ai_message) 58 | 59 | # Retrieve messages 60 | messages = history.messages 61 | """ # noqa: E501 62 | 63 | def __init__( 64 | self, 65 | connection_string: Optional[str], 66 | session_id: str, 67 | database_name: str = DEFAULT_DBNAME, 68 | collection_name: str = DEFAULT_COLLECTION_NAME, 69 | *, 70 | session_id_key: str = DEFAULT_SESSION_ID_KEY, 71 | history_key: str = DEFAULT_HISTORY_KEY, 72 | create_index: bool = True, 73 | history_size: Optional[int] = None, 74 | index_kwargs: Optional[Dict] = None, 75 | client: Optional[MongoClient] = None, 76 | ): 77 | """Initialize a MongoDBChatMessageHistory instance. 78 | 79 | Args: 80 | connection_string: Optional[str] 81 | connection string to connect to MongoDB. Can be None if client is 82 | provided. 83 | session_id: str 84 | arbitrary key that is used to store the messages of 85 | a single chat session. 86 | database_name: Optional[str] 87 | name of the database to use. 88 | collection_name: Optional[str] 89 | name of the collection to use. 90 | session_id_key: Optional[str] 91 | name of the field that stores the session id. 92 | history_key: Optional[str] 93 | name of the field that stores the chat history.
94 | create_index: Optional[bool] 95 | whether to create an index on the session id field. 96 | history_size: Optional[int] 97 | count of (most recent) messages to fetch from MongoDB. 98 | index_kwargs: Optional[Dict] 99 | additional keyword arguments to pass to the index creation. 100 | client: Optional[MongoClient] 101 | an existing MongoClient instance. 102 | If provided, connection_string is ignored. 103 | """ 104 | self.session_id = session_id 105 | self.database_name = database_name 106 | self.collection_name = collection_name 107 | self.session_id_key = session_id_key 108 | self.history_key = history_key 109 | self.history_size = history_size 110 | 111 | if client: 112 | if connection_string: 113 | raise ValueError("Must provide connection_string or client, not both") 114 | self.client = client 115 | elif connection_string: 116 | try: 117 | self.client = MongoClient( 118 | connection_string, 119 | driver=DriverInfo( 120 | name="Langchain", version=version("langchain-mongodb") 121 | ), 122 | ) 123 | except errors.ConnectionFailure as error: 124 | logger.error(error) 125 | else: 126 | raise ValueError("Either connection_string or client must be provided") 127 | 128 | self.db = self.client[database_name] 129 | self.collection = self.db[collection_name] 130 | 131 | if create_index: 132 | index_kwargs = index_kwargs or {} 133 | self.collection.create_index(self.session_id_key, **index_kwargs) 134 | 135 | @property 136 | def messages(self) -> List[BaseMessage]: # type: ignore 137 | """Retrieve the messages from MongoDB""" 138 | try: 139 | if self.history_size is None: 140 | cursor = self.collection.find({self.session_id_key: self.session_id}) 141 | else: 142 | skip_count = max( 143 | 0, 144 | self.collection.count_documents( 145 | {self.session_id_key: self.session_id} 146 | ) 147 | - self.history_size, 148 | ) 149 | cursor = self.collection.find( 150 | {self.session_id_key: self.session_id}, skip=skip_count 151 | ) 152 | except errors.OperationFailure as error: 153 | logger.error(error) 154 | cursor = None  # fall through with an empty history rather than a NameError below 155 | if cursor: 156 | items = [json.loads(document[self.history_key]) for document in cursor] 157 | else: 158 | items = [] 159 | 160 | messages = messages_from_dict(items) 161 | return messages 162 | 163 | def close(self) -> None: 164 | """Close the resources used by the MongoDBChatMessageHistory.""" 165 | self.client.close() 166 | 167 | def add_message(self, message: BaseMessage) -> None: 168 | """Append the message to the record in MongoDB""" 169 | try: 170 | self.collection.insert_one( 171 | { 172 | self.session_id_key: self.session_id, 173 | self.history_key: json.dumps(message_to_dict(message)), 174 | } 175 | ) 176 | except errors.WriteError as err: 177 | logger.error(err) 178 | 179 | def clear(self) -> None: 180 | """Clear session memory from MongoDB""" 181 | try: 182 | self.collection.delete_many({self.session_id_key: self.session_id}) 183 | except errors.WriteError as err: 184 | logger.error(err) 185 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/docstores.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from importlib.metadata import version 4 | from typing import Any, Generator, Iterable, Iterator, List, Optional, Sequence, Union 5 | 6 | from langchain_core.documents import Document 7 | from langchain_core.stores import BaseStore 8 | from pymongo import MongoClient 9 | from pymongo.collection import Collection 10 | from
pymongo.driver_info import DriverInfo 11 | 12 | from langchain_mongodb.utils import ( 13 | make_serializable, 14 | ) 15 | 16 | DEFAULT_INSERT_BATCH_SIZE = 100 17 | 18 | 19 | class MongoDBDocStore(BaseStore): 20 | """MongoDB Collection providing BaseStore interface. 21 | 22 | This is meant to be treated as a key-value store: [str, Document] 23 | 24 | 25 | In a MongoDB Collection, the field name _id is reserved for use as a primary key. 26 | Its value must be unique in the collection, is immutable, 27 | and may be of any type other than an array or regex. 28 | As this field is always indexed, it is the natural choice to hold keys. 29 | 30 | The value will be held simply in a field called "value". 31 | It can contain any valid BSON type. 32 | 33 | Example key value pair: {"_id": "foo", "value": "bar"}. 34 | """ 35 | 36 | def __init__(self, collection: Collection, text_key: str = "page_content") -> None: 37 | self.collection = collection 38 | self._text_key = text_key 39 | 40 | @classmethod 41 | def from_connection_string( 42 | cls, 43 | connection_string: str, 44 | namespace: str, 45 | **kwargs: Any, 46 | ) -> MongoDBDocStore: 47 | """Construct a Key-Value Store from a MongoDB connection URI. 48 | 49 | Args: 50 | connection_string: A valid MongoDB connection URI. 51 | namespace: A valid MongoDB namespace (in form f"{database}.{collection}") 52 | 53 | Returns: 54 | A new MongoDBDocStore instance. 55 | """ 56 | client: MongoClient = MongoClient( 57 | connection_string, 58 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), 59 | ) 60 | db_name, collection_name = namespace.split(".") 61 | collection = client[db_name][collection_name] 62 | return cls(collection=collection) 63 | 64 | def close(self) -> None: 65 | """Close the resources used by the MongoDBDocStore.""" 66 | self.collection.database.client.close() 67 | 68 | def mget(self, keys: Sequence[str]) -> list[Optional[Document]]: 69 | """Get the values associated with the given keys. 70 | 71 | If a key is not found in the store, the corresponding value will be None. 72 | As returning None is not the default find behavior, we form a dictionary 73 | and loop over the keys. 74 | 75 | Args: 76 | keys (Sequence[str]): A sequence of keys. 77 | 78 | Returns: List of values associated with the given keys. 79 | """ 80 | found_docs = {} 81 | for res in self.collection.find({"_id": {"$in": keys}}): 82 | text = res.pop(self._text_key) 83 | key = res.pop("_id") 84 | make_serializable(res) 85 | found_docs[key] = Document(page_content=text, metadata=res) 86 | return [found_docs.get(key, None) for key in keys] 87 | 88 | def mset( 89 | self, 90 | key_value_pairs: Sequence[tuple[str, Document]], 91 | batch_size: int = DEFAULT_INSERT_BATCH_SIZE, 92 | ) -> None: 93 | """Set the values for the given keys. 94 | 95 | Args: 96 | key_value_pairs: A sequence of key-value pairs. 97 | batch_size: Number of documents to insert at a time. 98 | Tuning this may help with performance and sidestep MongoDB limits. 99 | """ 100 | keys, docs = zip(*key_value_pairs) 101 | n_docs = len(docs) 102 | start = 0 103 | for end in range(batch_size, n_docs + batch_size, batch_size): 104 | texts, metadatas = zip( 105 | *[(doc.page_content, doc.metadata) for doc in docs[start:end]] 106 | ) 107 | self.insert_many(texts=texts, metadatas=metadatas, ids=keys[start:end]) # type: ignore 108 | start = end 109 | 110 | def mdelete(self, keys: Sequence[str]) -> None: 111 | """Delete the given keys and their associated values. 
112 | 113 | Args: 114 | keys (Sequence[str]): A sequence of keys to delete. 115 | """ 116 | self.collection.delete_many({"_id": {"$in": keys}}) 117 | 118 | def yield_keys( 119 | self, *, prefix: Optional[str] = None 120 | ) -> Iterator[str]: 121 | """Get an iterator over keys that match the given prefix. 122 | 123 | Args: 124 | prefix (str): The prefix to match. 125 | 126 | Yields: 127 | Iterator[str]: An iterator over keys that match the given prefix. 128 | Keys in this store are always str, 129 | as they are held in the string-valued _id field. 130 | """ 131 | query = {"_id": {"$regex": f"^{prefix}"}} if prefix else {} 132 | for document in self.collection.find(query, {"_id": 1}): 133 | yield document["_id"] 134 | 135 | def insert_many( 136 | self, 137 | texts: Union[List[str], Iterable[str]], 138 | metadatas: Union[List[dict], Generator[dict, Any, Any]], 139 | ids: List[str], 140 | ) -> None: 141 | """Bulk insert single batch of texts, embeddings, and optionally ids. 142 | 143 | insert_many in PyMongo does not overwrite existing documents. 144 | Instead, it attempts to insert each document as a new document. 145 | If a document with the same _id already exists in the collection, 146 | an error will be raised for that specific document. However, other documents 147 | in the batch that do not have conflicting _ids will still be inserted. 148 | """ 149 | to_insert = [ 150 | {"_id": i, self._text_key: t, **m} for i, t, m in zip(ids, texts, metadatas) 151 | ] 152 | self.collection.insert_many(to_insert) # type: ignore 153 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/graphrag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/langchain_mongodb/graphrag/__init__.py -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/graphrag/example_templates.py: -------------------------------------------------------------------------------- 1 | """These prompts serve as defaults, and templates when you wish to override them. 2 | 3 | LLM performance can be greatly increased by adding domain-specific examples. 4 | """ 5 | 6 | entity_extraction = """ 7 | ## Examples 8 | Use the following examples to guide your work. 9 | 10 | ### Example 1: Constrained entity and relationship types: Person, Friend 11 | #### Input 12 | Alice Palace has been the CEO of MongoDB since January 1, 2018. 13 | She maintains close friendships with Jarnail Singh, whom she has known since May 1, 2019, 14 | and Jasbinder Kaur, whom she has been seeing weekly since May 1, 2015.
15 | 16 | #### Output 17 | (If `allowed_entity_types` is ["Person"] and `allowed_relationship_types` is ["Friend"]) 18 | {{ 19 | "entities": [ 20 | {{ 21 | "_id": "Alice Palace", 22 | "type": "Person", 23 | "attributes": {{ 24 | "job": ["CEO of MongoDB"], 25 | "startDate": ["2018-01-01"] 26 | }}, 27 | "relationships": {{ 28 | "targets": ["Jasbinder Kaur", "Jarnail Singh"], 29 | "types": ["Friend", "Friend"], 30 | "attributes": [ 31 | {{ "since": ["2019-05-01"] }}, 32 | {{ "since": ["2015-05-01"], "frequency": ["weekly"] }} 33 | ] 34 | }} 35 | }}, 36 | {{ 37 | "_id": "Jarnail Singh", 38 | "type": "Person", 39 | "relationships": {{ 40 | "targets": ["Alice Palace"], 41 | "types": ["Friend"], 42 | "attributes": [{{ "since": ["2019-05-01"] }}] 43 | }} 44 | }}, 45 | {{ 46 | "_id": "Jasbinder Kaur", 47 | "type": "Person", 48 | "relationships": {{ 49 | "targets": ["Alice Palace"], 50 | "types": ["Friend"], 51 | "attributes": [{{ "since": ["2015-05-01"], "frequency": ["weekly"] }}] 52 | }} 53 | }} 54 | ] 55 | }} 56 | 57 | ### Example 2: Event Extraction 58 | #### Input 59 | The 2022 OpenAI Developer Conference took place on October 10, 2022, in San Francisco. 60 | Elon Musk and Sam Altman were keynote speakers at the event. 61 | 62 | #### Output 63 | (If `allowed_entity_types` is ["Event", "Person"] and `allowed_relationship_types` is ["Speaker"]) 64 | {{ 65 | "entities": [ 66 | {{ 67 | "_id": "2022 OpenAI Developer Conference", 68 | "type": "Event", 69 | "attributes": {{ 70 | "date": ["2022-10-10"], 71 | "location": ["San Francisco"] 72 | }}, 73 | "relationships": {{ 74 | "targets": ["Elon Musk", "Sam Altman"], 75 | "types": ["Speaker", "Speaker"] 76 | }} 77 | }}, 78 | {{ "_id": "Elon Musk", "type": "Person" }}, 79 | {{ "_id": "Sam Altman", "type": "Person" }} 80 | ] 81 | }} 82 | 83 | ### Example 3: Concept Relationship 84 | #### Input 85 | Quantum computing is a field of study that focuses on developing computers based on the principles of quantum mechanics. 86 | 87 | #### Output 88 | (If `allowed_entity_types` is ["Concept"] and `allowed_relationship_types` is ["Based On"]) 89 | {{ 90 | "entities": [ 91 | {{ 92 | "_id": "Quantum Computing", 93 | "type": "Concept", 94 | "relationships": {{ 95 | "targets": ["Quantum Mechanics"], 96 | "types": ["Based On"] 97 | }} 98 | }}, 99 | {{ "_id": "Quantum Mechanics", "type": "Concept" }} 100 | ] 101 | }} 102 | 103 | ### Example 4: News Article 104 | #### Input 105 | On March 1, 2023, NASA successfully launched the Artemis II mission, sending astronauts to orbit the Moon. 106 | NASA Administrator Bill Nelson praised the historic achievement. 107 | 108 | #### Output 109 | (If `allowed_entity_types` is ["Organization", "Event", "Person"] and `allowed_relationship_types` is ["Managed By", "Praised By"]) 110 | {{ 111 | "entities": [ 112 | {{ 113 | "_id": "Artemis II Mission", 114 | "type": "Event", 115 | "attributes": {{ "date": ["2023-03-01"] }}, 116 | "relationships": {{ 117 | "targets": ["NASA"], 118 | "types": ["Managed By"] 119 | }} 120 | }}, 121 | {{ 122 | "_id": "NASA", 123 | "type": "Organization" 124 | }}, 125 | {{ 126 | "_id": "Bill Nelson", 127 | "type": "Person", 128 | "relationships": {{ 129 | "targets": ["Artemis II Mission"], 130 | "types": ["Praised By"] 131 | }} 132 | }} 133 | ] 134 | }} 135 | 136 | ### Example 5: Technical Article 137 | #### Input 138 | Rust is a programming language that guarantees memory safety without requiring garbage collection. 139 | It is known for its strong ownership model, which prevents data races. 
140 | 141 | #### Output 142 | (If `allowed_entity_types` is ["Programming Language", "Concept"] and `allowed_relationship_types` is ["Ensures", "Uses"]) 143 | {{ 144 | "entities": [ 145 | {{ 146 | "_id": "Rust", 147 | "type": "Programming Language", 148 | "relationships": {{ 149 | "targets": ["Memory Safety"], 150 | "types": ["Ensures"] 151 | }} 152 | }}, 153 | {{ 154 | "_id": "Memory Safety", 155 | "type": "Concept", 156 | "relationships": {{ 157 | "targets": ["Ownership Model"], 158 | "types": ["Uses"] 159 | }} 160 | }}, 161 | {{ "_id": "Ownership Model", "type": "Concept" }} 162 | ] 163 | }} 164 | """ 165 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py: -------------------------------------------------------------------------------- 1 | """ 2 | The following contains the JSON Schema for the Entities in the Collection representing the Knowledge Graph. 3 | If validate is set to True, the schema is enforced upon insert and update. 4 | See `$jsonSchema `_ 5 | 6 | The following defines the default entity_schema. 7 | It allows all possible values of "type" and "relationship". 8 | 9 | If allowed_entity_types: List[str] is given to MongoDBGraphStore's constructor, 10 | then `self._schema["properties"]["type"]["enum"] = allowed_entity_types` is added. 11 | 12 | If allowed_relationship_types: List[str] is given to MongoDBGraphStore's constructor, 13 | additionalProperties is set to False, and relationship schema is provided for each key. 14 | """ 15 | 16 | entity_schema = { 17 | "bsonType": "object", 18 | "required": ["_id", "type"], 19 | "properties": { 20 | "_id": { 21 | "bsonType": "string", 22 | "description": "Unique identifier for the entity", 23 | }, 24 | "type": { 25 | "bsonType": "string", 26 | "description": "Type of the entity (e.g., 'Person', 'Organization')", 27 | # Note: When constrained, predefined types are added. For example: 28 | # "enum": ["Person", "Organization", "Location", "Event"], 29 | }, 30 | "attributes": { 31 | "bsonType": "object", 32 | "description": "Key-value pairs describing the entity", 33 | "additionalProperties": { 34 | "bsonType": "array", 35 | "items": {"bsonType": "string"}, # Enforce array of strings 36 | }, 37 | }, 38 | "relationships": { 39 | "bsonType": "object", 40 | "description": "Key-value pairs of relationships", 41 | "required": ["target_ids"], 42 | "properties": { 43 | "target_ids": { 44 | "bsonType": "array", 45 | "description": "name/_id values of the target entities", 46 | "items": {"bsonType": "string"}, 47 | }, 48 | "types": { 49 | "bsonType": "array", 50 | "description": "An array of relationships to corresponding target_ids (in same array position).", 51 | "items": {"bsonType": "string"}, 52 | # Note: When constrained, predefined types are added. For example: 53 | # "enum": ["used_in", "owns", "written_by", "located_in"], # Predefined types 54 | }, 55 | "attributes": { 56 | "bsonType": "array", 57 | "description": "An array of attributes describing the relationships to corresponding target_ids (in same array position). 
Each element is an object containing key-value pairs, where values are arrays of strings.", 58 | "items": { 59 | "bsonType": "object", 60 | "additionalProperties": { 61 | "bsonType": "array", 62 | "items": {"bsonType": "string"}, 63 | }, 64 | }, 65 | }, 66 | }, 67 | "additionalProperties": False, 68 | }, 69 | }, 70 | } 71 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/pipelines.py: -------------------------------------------------------------------------------- 1 | """Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search 2 | 3 | See the following for more: 4 | - `Full-Text Search `_ 5 | - `MongoDB Operators `_ 6 | - `Vector Search `_ 7 | - `Filter Example `_ 8 | """ 9 | 10 | from typing import Any, Dict, List, Optional 11 | 12 | 13 | def text_search_stage( 14 | query: str, 15 | search_field: str, 16 | index_name: str, 17 | limit: Optional[int] = None, 18 | filter: Optional[Dict[str, Any]] = None, 19 | include_scores: Optional[bool] = True, 20 | **kwargs: Any, 21 | ) -> List[Dict[str, Any]]: # noqa: E501 22 | """Full-Text search using Lucene's standard (BM25) analyzer 23 | 24 | Args: 25 | query: Input text to search for 26 | search_field: Field in Collection that will be searched 27 | index_name: Atlas Search Index name 28 | limit: Maximum number of documents to return. Default of no limit 29 | filter: Any MQL match expression comparing an indexed field 30 | include_scores: Scores provide measure of relative relevance 31 | 32 | Returns: 33 | List of aggregation pipeline stages, beginning with the $search stage 34 | """ 35 | pipeline = [ 36 | { 37 | "$search": { 38 | "index": index_name, 39 | "text": {"query": query, "path": search_field}, 40 | } 41 | } 42 | ] 43 | if filter: 44 | pipeline.append({"$match": filter}) # type: ignore 45 | if include_scores: 46 | pipeline.append({"$set": {"score": {"$meta": "searchScore"}}}) 47 | if limit: 48 | pipeline.append({"$limit": limit}) # type: ignore 49 | 50 | return pipeline # type: ignore 51 | 52 | 53 | def vector_search_stage( 54 | query_vector: List[float], 55 | search_field: str, 56 | index_name: str, 57 | top_k: int = 4, 58 | filter: Optional[Dict[str, Any]] = None, 59 | oversampling_factor: int = 10, 60 | **kwargs: Any, 61 | ) -> Dict[str, Any]: # noqa: E501 62 | """Vector Search Stage without Scores. 63 | 64 | Scoring is applied later depending on strategy. 65 | vector search includes a vectorSearchScore that is typically used. 66 | hybrid uses Reciprocal Rank Fusion. 67 | 68 | Args: 69 | query_vector: Query embedding, as a list of floats 70 | search_field: Field in Collection containing embedding vectors 71 | index_name: Name of Atlas Vector Search Index tied to Collection 72 | top_k: Number of documents to return 73 | oversampling_factor: This times top_k is the number of candidates 74 | filter: MQL match expression comparing an indexed field. 75 | Some operators are not supported.
76 | See `vectorSearch filter docs `_ 77 | 78 | 79 | Returns: 80 | Dictionary defining the $vectorSearch stage 81 | """ 82 | stage = { 83 | "index": index_name, 84 | "path": search_field, 85 | "queryVector": query_vector, 86 | "numCandidates": top_k * oversampling_factor, 87 | "limit": top_k, 88 | } 89 | if filter: 90 | stage["filter"] = filter 91 | return {"$vectorSearch": stage} 92 | 93 | 94 | def combine_pipelines( 95 | pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str 96 | ) -> None: 97 | """Combines two aggregations into a single result set in-place.""" 98 | if pipeline: 99 | pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}}) 100 | else: 101 | pipeline.extend(stage) 102 | 103 | 104 | def reciprocal_rank_stage( 105 | score_field: str, penalty: float = 0, **kwargs: Any 106 | ) -> List[Dict[str, Any]]: 107 | """Stage adds Reciprocal Rank Fusion weighting. 108 | 109 | First, it pushes documents retrieved from previous stage 110 | into a temporary sub-document. It then unwinds to establish 111 | the rank of each document and applies the penalty. 112 | 113 | Args: 114 | score_field: A unique string to identify the search being ranked 115 | penalty: A non-negative float. 116 | 117 | 118 | Returns: 119 | Pipeline stages that add an RRF score field to each document 120 | """ 121 | 122 | rrf_pipeline = [ 123 | {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, 124 | {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, 125 | { 126 | "$addFields": { 127 | f"docs.{score_field}": { 128 | "$divide": [1.0, {"$add": ["$rank", penalty, 1]}] 129 | }, 130 | "docs.rank": "$rank", 131 | "_id": "$docs._id", 132 | } 133 | }, 134 | {"$replaceRoot": {"newRoot": "$docs"}}, 135 | ] 136 | 137 | return rrf_pipeline # type: ignore 138 | 139 | 140 | def final_hybrid_stage( 141 | scores_fields: List[str], limit: int, **kwargs: Any 142 | ) -> List[Dict[str, Any]]: 143 | """Sum weighted scores, sort, and apply limit. 144 | 145 | Args: 146 | scores_fields: List of fields holding the scores of the vector and text searches 147 | limit: Number of documents to return 148 | 149 | Returns: 150 | Final aggregation stages 151 | """ 152 | 153 | return [ 154 | {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}}, 155 | {"$replaceRoot": {"newRoot": "$docs"}}, 156 | {"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}}, 157 | {"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}}, 158 | {"$sort": {"score": -1}}, 159 | {"$limit": limit}, 160 | ] 161 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/langchain_mongodb/py.typed -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | """Search Retrievers of various types. 2 | 3 | Use ``MongoDBAtlasVectorSearch.as_retriever(**)`` 4 | to create MongoDB's core Vector Search Retriever.
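As a sketch (assuming an existing ``MongoDBAtlasVectorSearch`` named ``vectorstore`` and a full-text Atlas Search index named ``"text_index"``), the hybrid retriever exported below can be built with ``MongoDBAtlasHybridSearchRetriever(vectorstore=vectorstore, search_index_name="text_index", k=4)``.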
5 | """ 6 | 7 | from langchain_mongodb.retrievers.full_text_search import ( 8 | MongoDBAtlasFullTextSearchRetriever, 9 | ) 10 | from langchain_mongodb.retrievers.graphrag import MongoDBGraphRAGRetriever 11 | from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever 12 | from langchain_mongodb.retrievers.parent_document import ( 13 | MongoDBAtlasParentDocumentRetriever, 14 | ) 15 | from langchain_mongodb.retrievers.self_querying import MongoDBAtlasSelfQueryRetriever 16 | 17 | __all__ = [ 18 | "MongoDBAtlasHybridSearchRetriever", 19 | "MongoDBAtlasFullTextSearchRetriever", 20 | "MongoDBAtlasParentDocumentRetriever", 21 | "MongoDBGraphRAGRetriever", 22 | "MongoDBAtlasSelfQueryRetriever", 23 | ] 24 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Any, Dict, List, Optional 2 | 3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun 4 | from langchain_core.documents import Document 5 | from langchain_core.retrievers import BaseRetriever 6 | from pydantic import Field 7 | from pymongo.collection import Collection 8 | 9 | from langchain_mongodb.pipelines import text_search_stage 10 | from langchain_mongodb.utils import make_serializable 11 | 12 | 13 | class MongoDBAtlasFullTextSearchRetriever(BaseRetriever): 14 | """Retriever performs full-text searches using Lucene's standard (BM25) analyzer.""" 15 | 16 | collection: Collection 17 | """MongoDB Collection on an Atlas cluster""" 18 | search_index_name: str 19 | """Atlas Search Index name""" 20 | search_field: str 21 | """Collection field that contains the text to be searched. It must be indexed""" 22 | k: Optional[int] = None 23 | """Number of documents to return. Default is no limit""" 24 | filter: Optional[Dict[str, Any]] = None 25 | """(Optional) List of MQL match expression comparing an indexed field""" 26 | include_scores: bool = True 27 | """If True, include scores that provide measure of relative relevance""" 28 | top_k: Annotated[ 29 | Optional[int], Field(deprecated='top_k is deprecated, use "k" instead') 30 | ] = None 31 | """Number of documents to return. Default is no limit""" 32 | 33 | def close(self) -> None: 34 | """Close the resources used by the MongoDBAtlasFullTextSearchRetriever.""" 35 | self.collection.database.client.close() 36 | 37 | def _get_relevant_documents( 38 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any 39 | ) -> List[Document]: 40 | """Retrieve documents that are highest scoring / most similar to query. 
41 | 42 | Args: 43 | query: String to find relevant documents for 44 | run_manager: The callback handler to use 45 | Returns: 46 | List of relevant documents 47 | """ 48 | default_k = self.k if self.k is not None else self.top_k 49 | pipeline = text_search_stage( # type: ignore 50 | query=query, 51 | search_field=self.search_field, 52 | index_name=self.search_index_name, 53 | limit=kwargs.get("k", default_k), 54 | filter=self.filter, 55 | include_scores=self.include_scores, 56 | ) 57 | 58 | # Execution 59 | cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type] 60 | 61 | # Formatting 62 | docs = [] 63 | for res in cursor: 64 | text = res.pop(self.search_field) 65 | make_serializable(res) 66 | docs.append(Document(page_content=text, metadata=res)) 67 | return docs 68 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/retrievers/graphrag.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun 4 | from langchain_core.documents import Document 5 | from langchain_core.retrievers import BaseRetriever 6 | 7 | from langchain_mongodb.graphrag.graph import MongoDBGraphStore 8 | 9 | 10 | class MongoDBGraphRAGRetriever(BaseRetriever): 11 | """RunnableSerializable API of MongoDB GraphRAG.""" 12 | 13 | graph_store: MongoDBGraphStore 14 | """Underlying Knowledge Graph storing entities and their relationships.""" 15 | 16 | def _get_relevant_documents( 17 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun 18 | ) -> List[Document]: 19 | """Retrieve list of Entities found via traversal of KnowledgeGraph. 20 | 21 | Each Document's page_content is a string representation of the Entity dict. 22 | 23 | Description and details are provided in the underlying Entity Graph: 24 | :class:`~langchain_mongodb.graphrag.graph.MongoDBGraphStore` 25 | 26 | Args: 27 | query: String to find relevant documents for 28 | run_manager: The callback handler to use if desired 29 | Returns: 30 | List of relevant documents. 31 | """ 32 | return [Document(str(e)) for e in self.graph_store.similarity_search(query)] 33 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/retrievers/hybrid_search.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Any, Dict, List, Optional 2 | 3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun 4 | from langchain_core.documents import Document 5 | from langchain_core.retrievers import BaseRetriever 6 | from pydantic import Field 7 | from pymongo.collection import Collection 8 | 9 | from langchain_mongodb import MongoDBAtlasVectorSearch 10 | from langchain_mongodb.pipelines import ( 11 | combine_pipelines, 12 | final_hybrid_stage, 13 | reciprocal_rank_stage, 14 | text_search_stage, 15 | vector_search_stage, 16 | ) 17 | from langchain_mongodb.utils import make_serializable 18 | 19 | 20 | class MongoDBAtlasHybridSearchRetriever(BaseRetriever): 21 | """Hybrid Search Retriever combines vector and full-text searches, 22 | weighting them via the Reciprocal Rank Fusion (RRF) algorithm. 23 | 24 | Increasing the vector_penalty will reduce the importance of the vector search. 25 | Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
26 | For more on the algorithm, see 27 | https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking 28 | """ 29 | 30 | vectorstore: MongoDBAtlasVectorSearch 31 | """MongoDBAtlas VectorStore""" 32 | search_index_name: str 33 | """Atlas Search Index (full-text) name""" 34 | k: int = 4 35 | """Number of documents to return.""" 36 | oversampling_factor: int = 10 37 | """This times k is the number of candidates chosen at each step""" 38 | pre_filter: Optional[Dict[str, Any]] = None 39 | """(Optional) Any MQL match expression comparing an indexed field""" 40 | post_filter: Optional[List[Dict[str, Any]]] = None 41 | """(Optional) Pipeline of MongoDB aggregation stages for postprocessing.""" 42 | vector_penalty: float = 60.0 43 | """Penalty applied to vector search results in RRF: scores=1/(rank + penalty)""" 44 | fulltext_penalty: float = 60.0 45 | """Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)""" 46 | show_embeddings: bool = False 47 | """If true, returned Document metadata will include vectors.""" 48 | top_k: Annotated[ 49 | Optional[int], Field(deprecated='top_k is deprecated, use "k" instead') 50 | ] = None 51 | """Number of documents to return.""" 52 | 53 | @property 54 | def collection(self) -> Collection: 55 | return self.vectorstore._collection 56 | 57 | def close(self) -> None: 58 | """Close the resources used by the MongoDBAtlasHybridSearchRetriever.""" 59 | self.vectorstore.close() 60 | 61 | def _get_relevant_documents( 62 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any 63 | ) -> List[Document]: 64 | """Retrieve documents that are highest scoring / most similar to query. 65 | 66 | Note that the same query is used in both searches, 67 | embedded for vector search, and as-is for full-text search. 68 | 69 | Args: 70 | query: String to find relevant documents for 71 | run_manager: The callback handler to use 72 | Returns: 73 | List of relevant documents 74 | """ 75 | 76 | query_vector = self.vectorstore._embedding.embed_query(query) 77 | 78 | scores_fields = ["vector_score", "fulltext_score"] 79 | pipeline: List[Any] = [] 80 | 81 | # Get the appropriate value for k. 82 | default_k = self.top_k if self.top_k is not None else self.k 83 | k = kwargs.get("k", default_k) 84 | 85 | # First we build up the aggregation pipeline, 86 | # then it is passed to the server to execute 87 | # Vector Search stage 88 | vector_pipeline = [ 89 | vector_search_stage( 90 | query_vector=query_vector, 91 | search_field=self.vectorstore._embedding_key, 92 | index_name=self.vectorstore._index_name, 93 | top_k=k, 94 | filter=self.pre_filter, 95 | oversampling_factor=self.oversampling_factor, 96 | ) 97 | ] 98 | vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty) 99 | 100 | combine_pipelines(pipeline, vector_pipeline, self.collection.name) 101 | 102 | # Full-Text Search stage 103 | text_pipeline = text_search_stage( 104 | query=query, 105 | search_field=self.vectorstore._text_key, 106 | index_name=self.search_index_name, 107 | limit=k, 108 | filter=self.pre_filter, 109 | ) 110 | 111 | text_pipeline.extend( 112 | reciprocal_rank_stage("fulltext_score", self.fulltext_penalty) 113 | ) 114 | 115 | combine_pipelines(pipeline, text_pipeline, self.collection.name) 116 | 117 | # Sum and sort stage 118 | pipeline.extend(final_hybrid_stage(scores_fields=scores_fields, limit=k)) 119 | 120 | # Removal of embeddings unless requested.
121 | if not self.show_embeddings: 122 | pipeline.append({"$project": {self.vectorstore._embedding_key: 0}}) 123 | # Post filtering 124 | if self.post_filter is not None: 125 | pipeline.extend(self.post_filter) 126 | 127 | # Execution 128 | cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type] 129 | 130 | # Formatting 131 | docs = [] 132 | for res in cursor: 133 | text = res.pop(self.vectorstore._text_key) 134 | # score = res.pop("score") # The score remains buried! 135 | make_serializable(res) 136 | docs.append(Document(page_content=text, metadata=res)) 137 | return docs 138 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/langchain_mongodb/utils.py: -------------------------------------------------------------------------------- 1 | """Various Utility Functions 2 | 3 | - Tools for handling bson.ObjectId 4 | 5 | These helpers convert IDs, which live as ObjectId in MongoDB and str in Langchain and JSON. 6 | 7 | 8 | - Tools for the Maximal Marginal Relevance (MMR) reranking 9 | 10 | These are duplicated from langchain_community to avoid cross-dependencies. 11 | 12 | Functions "maximal_marginal_relevance" and "cosine_similarity" 13 | are duplicated in this utility respectively from modules: 14 | 15 | - "libs/community/langchain_community/vectorstores/utils.py" 16 | - "libs/community/langchain_community/utils/math.py" 17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | import logging 22 | from datetime import date, datetime 23 | from typing import Any, Dict, List, Union 24 | 25 | import numpy as np 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 30 | 31 | 32 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 33 | """Row-wise cosine similarity between two equal-width matrices.""" 34 | if len(X) == 0 or len(Y) == 0: 35 | return np.array([]) 36 | 37 | X = np.array(X) 38 | Y = np.array(Y) 39 | if X.shape[1] != Y.shape[1]: 40 | raise ValueError( 41 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 42 | f"and Y has shape {Y.shape}." 43 | ) 44 | try: 45 | import simsimd as simd 46 | 47 | X = np.array(X, dtype=np.float32) 48 | Y = np.array(Y, dtype=np.float32) 49 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) 50 | return Z 51 | except ImportError: 52 | logger.debug( 53 | "Unable to import simsimd, defaulting to NumPy implementation. If you want " 54 | "to use simsimd please install with `pip install simsimd`." 55 | ) 56 | X_norm = np.linalg.norm(X, axis=1) 57 | Y_norm = np.linalg.norm(Y, axis=1) 58 | # Ignore divide-by-zero runtime warnings, as those are handled below. 59 | with np.errstate(divide="ignore", invalid="ignore"): 60 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 61 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 62 | return similarity 63 | 64 | 65 | def maximal_marginal_relevance( 66 | query_embedding: np.ndarray, 67 | embedding_list: list, 68 | lambda_mult: float = 0.5, 69 | k: int = 4, 70 | ) -> List[int]: 71 | r"""Compute Maximal Marginal Relevance (MMR). 72 | 73 | MMR is a technique used to select documents that are both relevant to the query 74 | and diverse among themselves. This function returns the indices 75 | of the top-k embeddings that maximize the marginal relevance. 76 | 77 | Args: 78 | query_embedding (np.ndarray): The embedding vector of the query. 79 | embedding_list (list of np.ndarray): A list containing the embedding vectors 80 | of the candidate documents.
81 | lambda_mult (float, optional): The trade-off parameter between 82 | relevance and diversity. Defaults to 0.5. 83 | k (int, optional): The number of embeddings to select. Defaults to 4. 84 | 85 | Returns: 86 | list of int: The indices of the embeddings that maximize the marginal relevance. 87 | 88 | Notes: 89 | The Maximal Marginal Relevance (MMR) is computed using the following formula: 90 | 91 | MMR = argmax_{D_i ∈ R \ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)] 92 | 93 | where: 94 | - R is the set of candidate documents, 95 | - S is the set of selected documents, 96 | - Q is the query embedding, 97 | - Sim(D_i, Q) is the similarity between document D_i and the query, 98 | - Sim(D_i, D_j) is the similarity between documents D_i and D_j, 99 | - λ is the trade-off parameter. 100 | """ 101 | 102 | if min(k, len(embedding_list)) <= 0: 103 | return [] 104 | if query_embedding.ndim == 1: 105 | query_embedding = np.expand_dims(query_embedding, axis=0) 106 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] 107 | most_similar = int(np.argmax(similarity_to_query)) 108 | idxs = [most_similar] 109 | selected = np.array([embedding_list[most_similar]]) 110 | while len(idxs) < min(k, len(embedding_list)): 111 | best_score = -np.inf 112 | idx_to_add = -1 113 | similarity_to_selected = cosine_similarity(embedding_list, selected) 114 | for i, query_score in enumerate(similarity_to_query): 115 | if i in idxs: 116 | continue 117 | redundant_score = max(similarity_to_selected[i]) 118 | equation_score = ( 119 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score 120 | ) 121 | if equation_score > best_score: 122 | best_score = equation_score 123 | idx_to_add = i 124 | idxs.append(idx_to_add) 125 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) 126 | return idxs 127 | 128 | 129 | def str_to_oid(str_repr: str) -> Any | str: 130 | """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId. 131 | 132 | To be consistent with ObjectId, input must be a 24 character hex string. 133 | If it is not, MongoDB will happily use the string in the main _id index. 134 | Importantly, the str representation that comes out of MongoDB will have this form. 135 | 136 | Args: 137 | str_repr: id as string. 138 | 139 | Returns: 140 | ObjectID 141 | """ 142 | from bson import ObjectId 143 | from bson.errors import InvalidId 144 | 145 | try: 146 | return ObjectId(str_repr) 147 | except InvalidId: 148 | logger.debug( 149 | "ObjectIds must be 12-character byte or 24-character hex strings. " 150 | "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'" 151 | ) 152 | return str_repr 153 | 154 | 155 | def oid_to_str(oid: Any) -> str: 156 | """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility. 157 | 158 | Instructive helper to show where data is coming out of MongoDB. 159 | 160 | Args: 161 | oid: bson.ObjectId 162 | 163 | Returns: 164 | 24 character hex string. 
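For example (illustrative): ``oid_to_str(ObjectId("6f6e6568656c6c6f68656768"))`` returns ``'6f6e6568656c6c6f68656768'``.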
165 | """ 166 | return str(oid) 167 | 168 | 169 | def make_serializable( 170 | obj: Dict[str, Any], 171 | ) -> None: 172 | """Recursively cast values in a dict to a form able to json.dump""" 173 | 174 | from bson import ObjectId 175 | 176 | for k, v in obj.items(): 177 | if isinstance(v, dict): 178 | make_serializable(v) 179 | elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)): 180 | obj[k] = [oid_to_str(item) for item in v] 181 | elif isinstance(v, ObjectId): 182 | obj[k] = oid_to_str(v) 183 | elif isinstance(v, (datetime, date)): 184 | obj[k] = v.isoformat() 185 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [build-system] 3 | requires = ["hatchling>1.24"] 4 | build-backend = "hatchling.build" 5 | 6 | [project] 7 | name = "langchain-mongodb" 8 | version = "0.6.2" 9 | description = "An integration package connecting MongoDB and LangChain" 10 | readme = "README.md" 11 | requires-python = ">=3.9" 12 | dependencies = [ 13 | "langchain-core>=0.3", 14 | "langchain>=0.3", 15 | "pymongo>=4.6.1", 16 | "langchain-text-splitters>=0.3", 17 | "numpy>=1.26", 18 | "lark<2.0.0,>=1.1.9", 19 | ] 20 | 21 | [dependency-groups] 22 | dev = [ 23 | "freezegun>=1.2.2", 24 | "langchain>=0.3.14", 25 | "langchain-core>=0.3.29", 26 | "langchain-text-splitters>=0.3.5", 27 | "pytest-mock>=3.10.0", 28 | "pytest>=7.3.0", 29 | "syrupy>=4.0.2", 30 | "pytest-watcher>=0.3.4", 31 | "pytest-asyncio>=0.21.1", 32 | "mongomock>=4.2.0.post1", 33 | "pre-commit>=4.0", 34 | "mypy>=1.10", 35 | "simsimd>=5.0.0", 36 | "langchain-ollama>=0.2.2", 37 | "langchain-openai>=0.2.14", 38 | "langchain-community>=0.3.14", 39 | "pypdf>=5.0.1", 40 | "langgraph>=0.2.72", 41 | "flaky>=3.8.1", 42 | "langchain-tests==0.3.14", 43 | "pip>=25.0.1", 44 | "typing-extensions>=4.12.2", 45 | ] 46 | 47 | [tool.pytest.ini_options] 48 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5" 49 | markers = [ 50 | "requires: mark tests as requiring a specific library", 51 | "compile: mark placeholder test used to compile integration tests without running them", 52 | ] 53 | asyncio_mode = "auto" 54 | asyncio_default_fixture_loop_scope = "function" 55 | 56 | [tool.mypy] 57 | disallow_untyped_defs = true 58 | 59 | [[tool.mypy.overrides]] 60 | module = ["tests.*"] 61 | disable_error_code = ["no-untyped-def", "no-untyped-call", "var-annotated"] 62 | 63 | [tool.ruff] 64 | lint.select = [ 65 | "E", # pycodestyle 66 | "F", # Pyflakes 67 | "UP", # pyupgrade 68 | "B", # flake8-bugbear 69 | "I", # isort 70 | ] 71 | lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035"] 72 | 73 | [tool.coverage.run] 74 | omit = ["tests/*"] 75 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/__init__.py -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/integration_tests/__init__.py 
-------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Generator, List 3 | 4 | import pytest 5 | from langchain_community.document_loaders import PyPDFLoader 6 | from langchain_core.documents import Document 7 | from langchain_core.embeddings import Embeddings 8 | from langchain_ollama.embeddings import OllamaEmbeddings 9 | from langchain_openai import OpenAIEmbeddings 10 | from pymongo import MongoClient 11 | 12 | from ..utils import CONNECTION_STRING 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def technical_report_pages() -> List[Document]: 17 | """Returns a Document for each of the 100 pages of a GPT-4 Technical Report""" 18 | loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf") 19 | pages = loader.load() 20 | return pages 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def client() -> Generator[MongoClient, None, None]: 25 | client = MongoClient(CONNECTION_STRING) 26 | yield client 27 | client.close() 28 | 29 | 30 | @pytest.fixture(scope="session") 31 | def embedding() -> Embeddings: 32 | if os.environ.get("OPENAI_API_KEY"): 33 | return OpenAIEmbeddings( 34 | openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa 35 | model="text-embedding-3-small", 36 | ) 37 | 38 | return OllamaEmbeddings(model="all-minilm:l6-v2") 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def dimensions() -> int: 43 | if os.environ.get("OPENAI_API_KEY"): 44 | return 1536 45 | return 384 46 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_agent_toolkit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | 4 | import pytest 5 | import requests 6 | from flaky import flaky # type:ignore[import-untyped] 7 | from langchain_openai import ChatOpenAI 8 | from langgraph.prebuilt import create_react_agent 9 | from pymongo import MongoClient 10 | 11 | from langchain_mongodb.agent_toolkit import ( 12 | MONGODB_AGENT_SYSTEM_PROMPT, 13 | MongoDBDatabase, 14 | MongoDBDatabaseToolkit, 15 | ) 16 | from tests.utils import CONNECTION_STRING 17 | 18 | DB_NAME = "langchain_test_db_chinook" 19 | 20 | 21 | @pytest.fixture 22 | def db(client: MongoClient) -> MongoDBDatabase: 23 | # Load the raw data into sqlite. 24 | url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql" 25 | response = requests.get(url) 26 | sql_script = response.text 27 | con = sqlite3.connect(":memory:", check_same_thread=False) 28 | con.executescript(sql_script) 29 | 30 | # Convert the sqlite data to MongoDB data. 
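# Each sqlite table becomes a MongoDB collection, with one document per row.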
31 | con.row_factory = sqlite3.Row 32 | cursor = con.cursor() 33 | sql_query = """SELECT name FROM sqlite_master WHERE type='table';""" 34 | cursor.execute(sql_query) 35 | tables = [i[0] for i in cursor.fetchall()] 36 | cursor.close() 37 | for t in tables: 38 | coll = client[DB_NAME][t] 39 | coll.delete_many({}) 40 | cursor = con.cursor() 41 | cursor.execute(f"select * from {t}") 42 | docs = [dict(i) for i in cursor.fetchall()] 43 | cursor.close() 44 | coll.insert_many(docs) 45 | return MongoDBDatabase(client, DB_NAME) 46 | 47 | 48 | @flaky(max_runs=5, min_passes=4) 49 | @pytest.mark.skipif( 50 | "OPENAI_API_KEY" not in os.environ, reason="test requires OpenAI for chat responses" 51 | ) 52 | def test_toolkit_response(db): 53 | db_wrapper = MongoDBDatabase.from_connection_string( 54 | CONNECTION_STRING, database=DB_NAME 55 | ) 56 | llm = ChatOpenAI(model="gpt-4o-mini", timeout=60) 57 | 58 | toolkit = MongoDBDatabaseToolkit(db=db_wrapper, llm=llm) 59 | 60 | system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5) 61 | 62 | test_query = "Which country's customers spent the most?" 63 | agent = create_react_agent(llm, toolkit.get_tools(), state_modifier=system_message) 64 | agent.step_timeout = 60 65 | events = agent.stream( 66 | {"messages": [("user", test_query)]}, 67 | stream_mode="values", 68 | ) 69 | messages = [] 70 | for event in events: 71 | messages.extend(event["messages"]) 72 | assert "USA" in messages[-1].content, messages[-1].content 73 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import Any, List, Union 4 | 5 | import pytest # type: ignore[import-not-found] 6 | from langchain_core.caches import BaseCache 7 | from langchain_core.globals import ( 8 | get_llm_cache, 9 | set_llm_cache, 10 | ) 11 | from langchain_core.load.dump import dumps 12 | from langchain_core.messages import AIMessage, BaseMessage, HumanMessage 13 | from langchain_core.outputs import ChatGeneration, Generation, LLMResult 14 | from pymongo import MongoClient 15 | from pymongo.collection import Collection 16 | 17 | from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache 18 | from langchain_mongodb.index import ( 19 | create_vector_search_index, 20 | ) 21 | 22 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, FakeChatModel, FakeLLM 23 | 24 | CONN_STRING = os.environ.get("MONGODB_URI") 25 | INDEX_NAME = "langchain-test-index-semantic-cache" 26 | COLLECTION = "langchain_test_cache" 27 | 28 | DIMENSIONS = 1536 # Meets OpenAI model 29 | TIMEOUT = 60.0 30 | 31 | 32 | def random_string() -> str: 33 | return str(uuid.uuid4()) 34 | 35 | 36 | @pytest.fixture(scope="module") 37 | def collection(client: MongoClient) -> Collection: 38 | """A Collection with both a Vector and a Full-text Search Index""" 39 | if COLLECTION not in client[DB_NAME].list_collection_names(): 40 | clxn = client[DB_NAME].create_collection(COLLECTION) 41 | else: 42 | clxn = client[DB_NAME][COLLECTION] 43 | 44 | clxn.delete_many({}) 45 | 46 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): 47 | create_vector_search_index( 48 | collection=clxn, 49 | index_name=INDEX_NAME, 50 | dimensions=DIMENSIONS, 51 | path="embedding", 52 | filters=["llm_string"], 53 | similarity="cosine", 54 | wait_until_complete=TIMEOUT, 55 | ) 56 | 57 | return clxn 58 | 59 | 60 | def llm_cache(cls: 
Any) -> BaseCache: 61 | set_llm_cache( 62 | cls( 63 | embedding=ConsistentFakeEmbeddings(dimensionality=DIMENSIONS), 64 | connection_string=CONN_STRING, 65 | collection_name=COLLECTION, 66 | database_name=DB_NAME, 67 | index_name=INDEX_NAME, 68 | score_threshold=0.5, 69 | wait_until_ready=TIMEOUT, 70 | ) 71 | ) 72 | assert get_llm_cache() 73 | return get_llm_cache() 74 | 75 | 76 | @pytest.fixture(scope="module", autouse=True) 77 | def reset_cache(): 78 | """Prevents global cache being affected in other module's tests.""" 79 | yield 80 | print("\nAll cache tests have finished. Setting global cache to None.") 81 | set_llm_cache(None) 82 | 83 | 84 | def _execute_test( 85 | prompt: Union[str, List[BaseMessage]], 86 | llm: Union[str, FakeLLM, FakeChatModel], 87 | response: List[Generation], 88 | ) -> None: 89 | # Fabricate an LLM String 90 | 91 | if not isinstance(llm, str): 92 | params = llm.dict() 93 | params["stop"] = None 94 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 95 | else: 96 | llm_string = llm 97 | 98 | # If the prompt is a str then we should pass just the string 99 | dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt) 100 | 101 | # Update the cache 102 | get_llm_cache().update(dumped_prompt, llm_string, response) 103 | 104 | # Retrieve the cached result through 'generate' call 105 | output: Union[List[Generation], LLMResult, None] 106 | expected_output: Union[List[Generation], LLMResult] 107 | 108 | if isinstance(llm, str): 109 | output = get_llm_cache().lookup(dumped_prompt, llm) # type: ignore 110 | expected_output = response 111 | else: 112 | output = llm.generate([prompt]) # type: ignore 113 | expected_output = LLMResult( 114 | generations=[response], 115 | llm_output={}, 116 | ) 117 | 118 | assert output == expected_output # type: ignore 119 | 120 | 121 | @pytest.mark.parametrize( 122 | "prompt, llm, response", 123 | [ 124 | ("foo", "bar", [Generation(text="fizz")]), 125 | ("foo", FakeLLM(), [Generation(text="fizz")]), 126 | ( 127 | [HumanMessage(content="foo")], 128 | FakeChatModel(), 129 | [ChatGeneration(message=AIMessage(content="foo"))], 130 | ), 131 | ], 132 | ids=[ 133 | "plain_cache", 134 | "cache_with_llm", 135 | "cache_with_chat", 136 | ], 137 | ) 138 | @pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache]) 139 | @pytest.mark.parametrize("remove_score", [True, False]) 140 | def test_mongodb_cache( 141 | remove_score: bool, 142 | cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache], 143 | prompt: Union[str, List[BaseMessage]], 144 | llm: Union[str, FakeLLM, FakeChatModel], 145 | response: List[Generation], 146 | collection: Collection, 147 | ) -> None: 148 | llm_cache(cacher) 149 | if remove_score: 150 | get_llm_cache().score_threshold = None # type: ignore 151 | try: 152 | _execute_test(prompt, llm, response) 153 | finally: 154 | get_llm_cache().clear() 155 | 156 | 157 | @pytest.mark.parametrize( 158 | "prompts, generations", 159 | [ 160 | # Single prompt, single generation 161 | ([random_string()], [[random_string()]]), 162 | # Single prompt, multiple generations 163 | ([random_string()], [[random_string(), random_string()]]), 164 | # Single prompt, multiple generations 165 | ([random_string()], [[random_string(), random_string(), random_string()]]), 166 | # Multiple prompts, multiple generations 167 | ( 168 | [random_string(), random_string()], 169 | [[random_string()], [random_string(), random_string()]], 170 | ), 171 | ], 172 | ids=[ 173 | "single_prompt_single_generation", 174 | 
"single_prompt_two_generations", 175 | "single_prompt_three_generations", 176 | "multiple_prompts_multiple_generations", 177 | ], 178 | ) 179 | def test_mongodb_atlas_cache_matrix( 180 | prompts: List[str], 181 | generations: List[List[str]], 182 | collection: Collection, 183 | ) -> None: 184 | llm_cache(MongoDBAtlasSemanticCache) 185 | llm = FakeLLM() 186 | 187 | # Fabricate an LLM String 188 | params = llm.dict() 189 | params["stop"] = None 190 | llm_string = str(sorted([(k, v) for k, v in params.items()])) 191 | 192 | llm_generations = [ 193 | [ 194 | Generation(text=generation, generation_info=params) 195 | for generation in prompt_i_generations 196 | ] 197 | for prompt_i_generations in generations 198 | ] 199 | 200 | for prompt_i, llm_generations_i in zip(prompts, llm_generations): 201 | _execute_test(prompt_i, llm_string, llm_generations_i) 202 | assert llm.generate(prompts) == LLMResult( 203 | generations=llm_generations, llm_output={} 204 | ) 205 | get_llm_cache().clear() 206 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_chain_example.py: -------------------------------------------------------------------------------- 1 | "Demonstrates MongoDBAtlasVectorSearch.as_retriever() invoked in a chain" "" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | 7 | import pytest # type: ignore[import-not-found] 8 | from langchain_core.documents import Document 9 | from langchain_core.embeddings import Embeddings 10 | from langchain_core.output_parsers.string import StrOutputParser 11 | from langchain_core.prompts.chat import ChatPromptTemplate 12 | from langchain_core.runnables import RunnablePassthrough 13 | from langchain_openai import ChatOpenAI 14 | from pymongo import MongoClient 15 | from pymongo.collection import Collection 16 | 17 | from langchain_mongodb import index 18 | 19 | from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch 20 | 21 | COLLECTION_NAME = "langchain_test_chain_example" 22 | INDEX_NAME = "langchain-test-chain-example-vector-index" 23 | DIMENSIONS = 1536 24 | TIMEOUT = 60.0 25 | INTERVAL = 0.5 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def collection(client: MongoClient) -> Collection: 30 | """A Collection with both a Vector and a Full-text Search Index""" 31 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 32 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME) 33 | else: 34 | clxn = client[DB_NAME][COLLECTION_NAME] 35 | 36 | clxn.delete_many({}) 37 | 38 | if all([INDEX_NAME != ix["name"] for ix in clxn.list_search_indexes()]): 39 | index.create_vector_search_index( 40 | collection=clxn, 41 | index_name=INDEX_NAME, 42 | dimensions=DIMENSIONS, 43 | path="embedding", 44 | similarity="cosine", 45 | filters=None, 46 | wait_until_complete=TIMEOUT, 47 | ) 48 | 49 | return clxn 50 | 51 | 52 | @pytest.mark.skipif( 53 | not os.environ.get("OPENAI_API_KEY"), 54 | reason="Requires OpenAI for chat responses.", 55 | ) 56 | def test_chain( 57 | collection: Collection, 58 | embedding: Embeddings, 59 | ) -> None: 60 | """Demonstrate usage of MongoDBAtlasVectorSearch in a realistic chain 61 | 62 | Follows example in the docs: https://python.langchain.com/docs/how_to/hybrid/ 63 | 64 | Requires OpenAI_API_KEY for embedding and chat model. 
65 | Requires INDEX_NAME to have been set up on MONGODB_URI 66 | """ 67 | 68 | vectorstore = PatchedMongoDBAtlasVectorSearch( 69 | collection=collection, 70 | embedding=embedding, 71 | index_name=INDEX_NAME, 72 | text_key="page_content", 73 | ) 74 | 75 | texts = [ 76 | "In 2023, I visited Paris", 77 | "In 2022, I visited New York", 78 | "In 2021, I visited New Orleans", 79 | "In 2019, I visited San Francisco", 80 | "In 2020, I visited Vancouver", 81 | ] 82 | vectorstore.add_texts(texts) 83 | 84 | query = "In the United States, what city did I visit last?" 85 | # One can do vector search on the vector store, using its various search types. 86 | k = len(texts) 87 | 88 | store_output = list(vectorstore.similarity_search(query=query, k=k)) 89 | assert len(store_output) == k 90 | assert isinstance(store_output[0], Document) 91 | 92 | # Unfortunately, the VectorStore output cannot be given directly to a Chat Model. 93 | # If we wish the Chat Model to answer based on our own data, 94 | # we have to give it the right things to work with. 95 | # The way that LangChain does this is by piping results along in 96 | # a Chain: https://python.langchain.com/v0.1/docs/modules/chains/ 97 | 98 | # Now, we can turn our VectorStore into something Runnable in a Chain 99 | # by turning it into a Retriever. 100 | # For the simple VectorSearch Retriever, we can do this like so. 101 | 102 | retriever = vectorstore.as_retriever(search_kwargs=dict(k=k)) 103 | 104 | # This does not do much other than expose our search function 105 | # as an invoke() method with a certain API, a Runnable. 106 | retriever_output = retriever.invoke(query) 107 | assert len(retriever_output) == len(texts) 108 | assert retriever_output[0].page_content == store_output[0].page_content 109 | 110 | # To get a natural language response to our question, 111 | # we need ChatOpenAI, a template to better frame the question as a prompt, 112 | # and a parser to convert the output to a string. 113 | # Together, these become our Chain! 114 | # Here goes: 115 | 116 | template = """Answer the question based only on the following context. 117 | Answer in as few words as possible.
118 | {context} 119 | Question: {question} 120 | """ 121 | prompt = ChatPromptTemplate.from_template(template) 122 | 123 | model = ChatOpenAI() 124 | 125 | chain = ( 126 | {"context": retriever, "question": RunnablePassthrough()} # type: ignore 127 | | prompt 128 | | model 129 | | StrOutputParser() 130 | ) 131 | 132 | answer = chain.invoke("What city did I visit last?") 133 | 134 | assert "Paris" in answer 135 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found] 4 | from langchain_core.messages import message_to_dict 5 | 6 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory 7 | 8 | from ..utils import CONNECTION_STRING, DB_NAME 9 | 10 | COLLECTION = "langchain_test_chat" 11 | 12 | 13 | def test_memory_with_message_store() -> None: 14 | """Test the memory with a message store.""" 15 | # setup MongoDB as a message store 16 | message_history = MongoDBChatMessageHistory( 17 | connection_string=CONNECTION_STRING, 18 | session_id="test-session", 19 | database_name=DB_NAME, 20 | collection_name=COLLECTION, 21 | ) 22 | memory = ConversationBufferMemory( 23 | memory_key="baz", chat_memory=message_history, return_messages=True 24 | ) 25 | 26 | # add some messages 27 | memory.chat_memory.add_ai_message("This is me, the AI") 28 | memory.chat_memory.add_user_message("This is me, the human") 29 | 30 | # get the message history from the memory store and turn it into a json 31 | messages = memory.chat_memory.messages 32 | messages_json = json.dumps([message_to_dict(msg) for msg in messages]) 33 | 34 | assert "This is me, the AI" in messages_json 35 | assert "This is me, the human" in messages_json 36 | 37 | # remove the record from MongoDB, so the next test run won't pick it up 38 | memory.chat_memory.clear() 39 | 40 | assert memory.chat_memory.messages == [] 41 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest # type: ignore[import-not-found] 2 | 3 | 4 | @pytest.mark.compile 5 | def test_placeholder() -> None: 6 | """Used for compiling integration tests without running any real tests.""" 7 | pass 8 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_docstore.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.documents import Document 4 | from pymongo import MongoClient 5 | 6 | from langchain_mongodb.docstores import MongoDBDocStore 7 | 8 | from ..utils import DB_NAME 9 | 10 | COLLECTION_NAME = "langchain_test_docstore" 11 | 12 | 13 | def test_docstore(client: MongoClient, technical_report_pages: List[Document]) -> None: 14 | db = client[DB_NAME] 15 | db.drop_collection(COLLECTION_NAME) 16 | clxn = db[COLLECTION_NAME] 17 | 18 | n_docs = len(technical_report_pages) 19 | assert clxn.count_documents({}) == 0 20 | docstore = MongoDBDocStore(collection=clxn) 21 | 22 | docstore.mset([(str(i), technical_report_pages[i]) for i in range(n_docs)]) 23 | assert clxn.count_documents({}) == n_docs 24 | 25 | twenties = 
list(docstore.yield_keys(prefix="2")) 26 | assert len(twenties) == 11 # includes 2, 20, 21, ..., 29 27 | 28 | docstore.mdelete([str(i) for i in range(20, 30)] + ["2"]) 29 | assert clxn.count_documents({}) == n_docs - 11 30 | assert set(docstore.mget(twenties)) == {None} 31 | 32 | sample = docstore.mget(["8", "16", "24", "36"]) 33 | assert sample[2] is None 34 | assert all(isinstance(sample[i], Document) for i in [0, 1, 3]) 35 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_index.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, List, Optional 2 | 3 | import pytest 4 | from pymongo import MongoClient 5 | from pymongo.collection import Collection 6 | 7 | from langchain_mongodb import MongoDBAtlasVectorSearch, index 8 | 9 | from ..utils import ConsistentFakeEmbeddings 10 | 11 | DB_NAME = "langchain_test_index_db" 12 | COLLECTION_NAME = "test_index" 13 | VECTOR_INDEX_NAME = "vector_index" 14 | 15 | TIMEOUT = 120 16 | DIMENSIONS = 10 17 | 18 | 19 | @pytest.fixture 20 | def collection(client: MongoClient) -> Generator: 21 | """Depending on uri, this could point to any type of cluster.""" 22 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 23 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME) 24 | else: 25 | clxn = client[DB_NAME][COLLECTION_NAME] 26 | clxn = client[DB_NAME][COLLECTION_NAME] 27 | clxn.delete_many({}) 28 | yield clxn 29 | clxn.delete_many({}) 30 | 31 | 32 | def test_search_index_drop_add_delete_commands(collection: Collection) -> None: 33 | index_name = VECTOR_INDEX_NAME 34 | dimensions = DIMENSIONS 35 | path = "embedding" 36 | similarity = "cosine" 37 | filters: Optional[List[str]] = None 38 | wait_until_complete = TIMEOUT 39 | 40 | for index_info in collection.list_search_indexes(): 41 | index.drop_vector_search_index( 42 | collection, index_info["name"], wait_until_complete=wait_until_complete 43 | ) 44 | 45 | assert len(list(collection.list_search_indexes())) == 0 46 | 47 | index.create_vector_search_index( 48 | collection=collection, 49 | index_name=index_name, 50 | dimensions=dimensions, 51 | path=path, 52 | similarity=similarity, 53 | filters=filters, 54 | wait_until_complete=wait_until_complete, 55 | ) 56 | 57 | assert index._is_index_ready(collection, index_name) 58 | indexes = list(collection.list_search_indexes()) 59 | assert len(indexes) == 1 60 | assert indexes[0]["name"] == index_name 61 | 62 | index.drop_vector_search_index( 63 | collection, index_name, wait_until_complete=wait_until_complete 64 | ) 65 | 66 | indexes = list(collection.list_search_indexes()) 67 | assert len(indexes) == 0 68 | 69 | 70 | @pytest.mark.skip("collection.update_vector_search_index requires [CLOUDP-275518]") 71 | def test_search_index_update_vector_search_index(collection: Collection) -> None: 72 | index_name = "INDEX_TO_UPDATE" 73 | similarity_orig = "cosine" 74 | similarity_new = "euclidean" 75 | 76 | # Create another index 77 | index.create_vector_search_index( 78 | collection=collection, 79 | index_name=index_name, 80 | dimensions=DIMENSIONS, 81 | path="embedding", 82 | similarity=similarity_orig, 83 | wait_until_complete=TIMEOUT, 84 | ) 85 | 86 | assert index._is_index_ready(collection, index_name) 87 | indexes = list(collection.list_search_indexes()) 88 | assert len(indexes) == 1 89 | assert indexes[0]["name"] == index_name 90 | assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == 
similarity_orig 91 | 92 | # Update the index and test new similarity 93 | index.update_vector_search_index( 94 | collection=collection, 95 | index_name=index_name, 96 | dimensions=DIMENSIONS, 97 | path="embedding", 98 | similarity=similarity_new, 99 | wait_until_complete=TIMEOUT, 100 | ) 101 | 102 | assert index._is_index_ready(collection, index_name) 103 | indexes = list(collection.list_search_indexes()) 104 | assert len(indexes) == 1 105 | assert indexes[0]["name"] == index_name 106 | assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_new 107 | 108 | 109 | def test_vectorstore_create_vector_search_index(collection: Collection) -> None: 110 | """Tests vectorstore wrapper around index command.""" 111 | 112 | # Set up using the index module's api 113 | if len(list(collection.list_search_indexes())) != 0: 114 | index.drop_vector_search_index( 115 | collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT 116 | ) 117 | 118 | # Test MongoDBAtlasVectorSearch's API 119 | _ = MongoDBAtlasVectorSearch( 120 | collection=collection, 121 | embedding=ConsistentFakeEmbeddings(), 122 | index_name=VECTOR_INDEX_NAME, 123 | dimensions=DIMENSIONS, 124 | auto_index_timeout=TIMEOUT, 125 | ) 126 | 127 | assert index._is_index_ready(collection, VECTOR_INDEX_NAME) 128 | indexes = list(collection.list_search_indexes()) 129 | assert len(indexes) == 1 130 | assert indexes[0]["name"] == VECTOR_INDEX_NAME 131 | 132 | # Tear down using the index module's api 133 | index.drop_vector_search_index( 134 | collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT 135 | ) 136 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_loaders.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from langchain_core.documents import Document 4 | from pymongo import MongoClient 5 | 6 | from langchain_mongodb.loaders import MongoDBLoader 7 | 8 | from ..utils import DB_NAME 9 | 10 | COLLECTION_NAME = "langchain_test_loader" 11 | 12 | 13 | def raw_docs() -> List[Dict]: 14 | return [ 15 | {"_id": "1", "address": {"building": "1", "room": "1"}}, 16 | {"_id": "2", "address": {"building": "2", "room": "2"}}, 17 | {"_id": "3", "address": {"building": "3", "room": "2"}}, 18 | ] 19 | 20 | 21 | def expected_documents() -> List[Document]: 22 | return [ 23 | Document( 24 | page_content="2 2", 25 | metadata={"_id": "2", "database": DB_NAME, "collection": COLLECTION_NAME}, 26 | ), 27 | Document( 28 | page_content="3 2", 29 | metadata={"_id": "3", "database": DB_NAME, "collection": COLLECTION_NAME}, 30 | ), 31 | ] 32 | 33 | 34 | async def test_load_with_filters(client: MongoClient) -> None: 35 | filter_criteria = {"address.room": {"$eq": "2"}} 36 | field_names = ["address.building", "address.room"] 37 | metadata_names = ["_id"] 38 | include_db_collection_in_metadata = True 39 | 40 | collection = client[DB_NAME][COLLECTION_NAME] 41 | collection.delete_many({}) 42 | collection.insert_many(raw_docs()) 43 | 44 | loader = MongoDBLoader( 45 | collection, 46 | filter_criteria=filter_criteria, 47 | field_names=field_names, 48 | metadata_names=metadata_names, 49 | include_db_collection_in_metadata=include_db_collection_in_metadata, 50 | ) 51 | documents = await loader.aload() 52 | 53 | assert documents == expected_documents() 54 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_mmr.py: 
-------------------------------------------------------------------------------- 1 | """Test max_marginal_relevance_search.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest # type: ignore[import-not-found] 6 | from langchain_core.embeddings import Embeddings 7 | from pymongo import MongoClient 8 | from pymongo.collection import Collection 9 | 10 | from langchain_mongodb.index import ( 11 | create_vector_search_index, 12 | ) 13 | 14 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch 15 | 16 | COLLECTION_NAME = "langchain_test_vectorstores" 17 | INDEX_NAME = "langchain-test-index-vectorstores" 18 | DIMENSIONS = 5 19 | 20 | 21 | @pytest.fixture() 22 | def collection(client: MongoClient) -> Collection: 23 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 24 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME) 25 | else: 26 | clxn = client[DB_NAME][COLLECTION_NAME] 27 | 28 | clxn.delete_many({}) 29 | 30 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): 31 | create_vector_search_index( 32 | collection=clxn, 33 | index_name=INDEX_NAME, 34 | dimensions=5, 35 | path="embedding", 36 | filters=["c"], 37 | similarity="cosine", 38 | wait_until_complete=60, 39 | ) 40 | 41 | return clxn 42 | 43 | 44 | @pytest.fixture 45 | def embeddings() -> Embeddings: 46 | return ConsistentFakeEmbeddings(DIMENSIONS) 47 | 48 | 49 | def test_mmr(embeddings: Embeddings, collection: Collection) -> None: 50 | texts = ["foo", "foo", "fou", "foy"] 51 | collection.delete_many({}) 52 | vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts( 53 | texts, 54 | embedding=embeddings, 55 | collection=collection, 56 | index_name=INDEX_NAME, 57 | ) 58 | query = "foo" 59 | output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1) 60 | assert len(output) == len(texts) 61 | assert output[0].page_content == "foo" 62 | assert output[1].page_content != "foo" 63 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_mongodb_database.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: E501 2 | """Test MongoDB database wrapper.""" 3 | 4 | import json 5 | 6 | import pytest 7 | from pymongo import MongoClient 8 | 9 | from langchain_mongodb.agent_toolkit import MongoDBDatabase 10 | 11 | DB_NAME = "langchain_test_db_user" 12 | 13 | 14 | @pytest.fixture 15 | def db(client: MongoClient) -> MongoDBDatabase: 16 | client[DB_NAME].user.delete_many({}) 17 | user = dict(name="Alice", bio="Engineer from Ohio") 18 | client[DB_NAME]["user"].insert_one(user) 19 | company = dict(name="Acme", location="Montana") 20 | client[DB_NAME]["company"].insert_one(company) 21 | return MongoDBDatabase(client, DB_NAME) 22 | 23 | 24 | def test_collection_info(db: MongoDBDatabase) -> None: 25 | """Test that collection info is constructed properly.""" 26 | output = db.collection_info 27 | expected_header = f""" 28 | Database name: {DB_NAME} 29 | Collection name: company 30 | Schema from a sample of documents from the collection: 31 | _id: ObjectId 32 | name: String 33 | location: String""".strip() 34 | 35 | for line in expected_header.splitlines(): 36 | assert line.strip() in output 37 | 38 | 39 | def test_collection_info_w_sample_docs(db: MongoDBDatabase) -> None: 40 | """Test that collection info is constructed properly.""" 41 | 42 | # Provision. 
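# sample_docs_in_collection_info=2 (used below) includes sample documents in
# each collection's schema description, analogous to SQLDatabase's
# sample_rows_in_table_info in the SQL toolkit.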
43 | values = [ 44 | {"name": "Harrison", "bio": "bio"}, 45 | {"name": "Chase", "bio": "bio"}, 46 | ] 47 | db._client[DB_NAME]["user"].delete_many({}) 48 | db._client[DB_NAME]["user"].insert_many(values) 49 | 50 | # Query and verify. 51 | db = MongoDBDatabase(db._client, DB_NAME, sample_docs_in_collection_info=2) 52 | output = db.collection_info 53 | 54 | expected_header = f""" 55 | Database name: {DB_NAME} 56 | Collection name: company 57 | Schema from a sample of documents from the collection: 58 | _id: ObjectId 59 | name: String 60 | location: String 61 | """.strip() 62 | 63 | for line in expected_header.splitlines(): 64 | assert line.strip() in output 65 | 66 | 67 | def test_database_run(db: MongoDBDatabase) -> None: 68 | """Verify running MQL expressions returning results as strings.""" 69 | 70 | # Provision. 71 | db._client[DB_NAME]["user"].delete_many({}) 72 | user = dict(name="Harrison", bio="That is my Bio " * 24) 73 | db._client[DB_NAME]["user"].insert_one(user) 74 | 75 | # Query and verify. 76 | command = """db.user.aggregate([ { "$match": { "name": "Harrison" } } ])""" 77 | output = db.run(command) 78 | assert isinstance(output, str) 79 | docs = json.loads(output.strip()) 80 | del docs[0]["_id"] 81 | del user["_id"] 82 | assert docs[0] == user 83 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_parent_document.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | from typing import List 3 | 4 | from langchain_core.documents import Document 5 | from langchain_core.embeddings import Embeddings 6 | from langchain_text_splitters import RecursiveCharacterTextSplitter 7 | from pymongo import MongoClient 8 | from pymongo.driver_info import DriverInfo 9 | 10 | from langchain_mongodb.docstores import MongoDBDocStore 11 | from langchain_mongodb.index import create_vector_search_index 12 | from langchain_mongodb.retrievers import ( 13 | MongoDBAtlasParentDocumentRetriever, 14 | ) 15 | 16 | from ..utils import CONNECTION_STRING, DB_NAME, PatchedMongoDBAtlasVectorSearch 17 | 18 | COLLECTION_NAME = "langchain_test_parent_document_combined" 19 | VECTOR_INDEX_NAME = "langchain-test-parent-document-vector-index" 20 | EMBEDDING_FIELD = "embedding" 21 | TEXT_FIELD = "page_content" 22 | SIMILARITY = "cosine" 23 | TIMEOUT = 60.0 24 | 25 | 26 | def test_1clxn_retriever( 27 | technical_report_pages: List[Document], 28 | embedding: Embeddings, 29 | dimensions: int, 30 | ) -> None: 31 | # Setup 32 | client: MongoClient = MongoClient( 33 | CONNECTION_STRING, 34 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), 35 | ) 36 | db = client[DB_NAME] 37 | combined_clxn = db[COLLECTION_NAME] 38 | if COLLECTION_NAME not in db.list_collection_names(): 39 | db.create_collection(COLLECTION_NAME) 40 | # Clean up 41 | combined_clxn.delete_many({}) 42 | # Create Search Index if it doesn't exist 43 | sixs = list(combined_clxn.list_search_indexes()) 44 | if len(sixs) == 0: 45 | create_vector_search_index( 46 | collection=combined_clxn, 47 | index_name=VECTOR_INDEX_NAME, 48 | dimensions=dimensions, 49 | path=EMBEDDING_FIELD, 50 | similarity=SIMILARITY, 51 | wait_until_complete=TIMEOUT, 52 | ) 53 | # Create Vector and Doc Stores 54 | vectorstore = PatchedMongoDBAtlasVectorSearch( 55 | collection=combined_clxn, 56 | embedding=embedding, 57 | index_name=VECTOR_INDEX_NAME, 58 | text_key=TEXT_FIELD, 59 | embedding_key=EMBEDDING_FIELD, 60 
| relevance_score_fn=SIMILARITY, 61 | ) 62 | docstore = MongoDBDocStore(collection=combined_clxn, text_key=TEXT_FIELD) 63 | # Combine into a ParentDocumentRetriever 64 | retriever = MongoDBAtlasParentDocumentRetriever( 65 | vectorstore=vectorstore, 66 | docstore=docstore, 67 | child_splitter=RecursiveCharacterTextSplitter(chunk_size=400), 68 | ) 69 | # Add documents (splitting, creating embedding, adding to vectorstore and docstore) 70 | retriever.add_documents(technical_report_pages) 71 | # invoke the retriever with a query 72 | question = "What percentage of the Uniform Bar Examination can GPT4 pass?" 73 | responses = retriever.invoke(question) 74 | 75 | assert len(responses) == 3 76 | assert all("GPT-4" in doc.page_content for doc in responses) 77 | assert {4, 5, 29} == set(doc.metadata["page"] for doc in responses) 78 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_retrievers_standard.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from langchain_core.documents import Document 4 | from langchain_tests.integration_tests import ( 5 | RetrieversIntegrationTests, 6 | ) 7 | from pymongo import MongoClient 8 | from pymongo.collection import Collection 9 | 10 | from langchain_mongodb import MongoDBAtlasVectorSearch 11 | from langchain_mongodb.index import ( 12 | create_fulltext_search_index, 13 | ) 14 | from langchain_mongodb.retrievers import ( 15 | MongoDBAtlasFullTextSearchRetriever, 16 | MongoDBAtlasHybridSearchRetriever, 17 | ) 18 | 19 | from ..utils import ( 20 | CONNECTION_STRING, 21 | DB_NAME, 22 | TIMEOUT, 23 | ConsistentFakeEmbeddings, 24 | PatchedMongoDBAtlasVectorSearch, 25 | ) 26 | 27 | DIMENSIONS = 5 28 | COLLECTION_NAME = "langchain_test_retrievers_standard" 29 | VECTOR_INDEX_NAME = "vector_index" 30 | PAGE_CONTENT_FIELD = "text" 31 | SEARCH_INDEX_NAME = "text_index" 32 | 33 | 34 | def setup_test() -> tuple[Collection, MongoDBAtlasVectorSearch]: 35 | client = MongoClient(CONNECTION_STRING) 36 | coll = client[DB_NAME][COLLECTION_NAME] 37 | 38 | # Set up the vector search index and add the documents if needed. 39 | vs = PatchedMongoDBAtlasVectorSearch( 40 | coll, 41 | embedding=ConsistentFakeEmbeddings(DIMENSIONS), 42 | dimensions=DIMENSIONS, 43 | index_name=VECTOR_INDEX_NAME, 44 | text_key=PAGE_CONTENT_FIELD, 45 | auto_index_timeout=TIMEOUT, 46 | ) 47 | 48 | if coll.count_documents({}) == 0: 49 | vs.add_documents( 50 | [ 51 | Document(page_content="In 2023, I visited Paris"), 52 | Document(page_content="In 2022, I visited New York"), 53 | Document(page_content="In 2021, I visited New Orleans"), 54 | Document(page_content="Sandwiches are beautiful. Sandwiches are fine."), 55 | ] 56 | ) 57 | 58 | # Set up the search index if needed. 
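# create_fulltext_search_index provisions an Atlas Search (lexical) index on
# the page-content field; wait_until_complete blocks until the index is
# queryable, so the retriever tests below do not race index creation.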
59 | if not any([ix["name"] == SEARCH_INDEX_NAME for ix in coll.list_search_indexes()]): 60 | create_fulltext_search_index( 61 | collection=coll, 62 | index_name=SEARCH_INDEX_NAME, 63 | field=PAGE_CONTENT_FIELD, 64 | wait_until_complete=TIMEOUT, 65 | ) 66 | 67 | return coll, vs 68 | 69 | 70 | class TestMongoDBAtlasFullTextSearchRetriever(RetrieversIntegrationTests): 71 | @property 72 | def retriever_constructor(self) -> Type[MongoDBAtlasFullTextSearchRetriever]: 73 | """Get a retriever for integration tests.""" 74 | return MongoDBAtlasFullTextSearchRetriever 75 | 76 | @property 77 | def retriever_constructor_params(self) -> dict: 78 | coll, _ = setup_test() 79 | return { 80 | "collection": coll, 81 | "search_index_name": SEARCH_INDEX_NAME, 82 | "search_field": PAGE_CONTENT_FIELD, 83 | } 84 | 85 | @property 86 | def retriever_query_example(self) -> str: 87 | """ 88 | Returns a str representing the "query" of an example retriever call. 89 | """ 90 | return "When was the last time I visited new orleans?" 91 | 92 | 93 | class TestMongoDBAtlasHybridSearchRetriever(RetrieversIntegrationTests): 94 | @property 95 | def retriever_constructor(self) -> Type[MongoDBAtlasHybridSearchRetriever]: 96 | """Get a retriever for integration tests.""" 97 | return MongoDBAtlasHybridSearchRetriever 98 | 99 | @property 100 | def retriever_constructor_params(self) -> dict: 101 | coll, vs = setup_test() 102 | return { 103 | "vectorstore": vs, 104 | "collection": coll, 105 | "search_index_name": SEARCH_INDEX_NAME, 106 | "search_field": PAGE_CONTENT_FIELD, 107 | } 108 | 109 | @property 110 | def retriever_query_example(self) -> str: 111 | """ 112 | Returns a str representing the "query" of an example retriever call. 113 | """ 114 | return "When was the last time I visited new orleans?" 
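Outside the standard-test harness, the two retrievers exercised above can also be invoked directly. A minimal sketch, reusing this module's `setup_test()` and constants, and assuming the indexes it provisions are ready:

```python
# Sketch only: direct use of the retrievers verified by the test classes above.
# Assumes setup_test() has already provisioned the collection and both indexes.
from langchain_mongodb.retrievers import (
    MongoDBAtlasFullTextSearchRetriever,
    MongoDBAtlasHybridSearchRetriever,
)

coll, vs = setup_test()

fulltext = MongoDBAtlasFullTextSearchRetriever(
    collection=coll,
    search_index_name=SEARCH_INDEX_NAME,
    search_field=PAGE_CONTENT_FIELD,
)
hybrid = MongoDBAtlasHybridSearchRetriever(
    vectorstore=vs,
    collection=coll,
    search_index_name=SEARCH_INDEX_NAME,
    search_field=PAGE_CONTENT_FIELD,
)

for retriever in (fulltext, hybrid):
    docs = retriever.invoke("When was the last time I visited new orleans?")
    print(type(retriever).__name__, [d.page_content for d in docs])
```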
115 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_tools.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from langchain_tests.integration_tests import ToolsIntegrationTests 4 | 5 | from langchain_mongodb.agent_toolkit.tool import ( 6 | InfoMongoDBDatabaseTool, 7 | ListMongoDBDatabaseTool, 8 | QueryMongoDBCheckerTool, 9 | QueryMongoDBDatabaseTool, 10 | ) 11 | from tests.utils import create_database, create_llm 12 | 13 | 14 | class TestQueryMongoDBDatabaseToolIntegration(ToolsIntegrationTests): 15 | @property 16 | def tool_constructor(self) -> Type[QueryMongoDBDatabaseTool]: 17 | return QueryMongoDBDatabaseTool 18 | 19 | @property 20 | def tool_constructor_params(self) -> dict: 21 | return dict(db=create_database()) 22 | 23 | @property 24 | def tool_invoke_params_example(self) -> dict: 25 | return dict(query='db.test.aggregate([{"$match": {}}])') 26 | 27 | 28 | class TestInfoMongoDBDatabaseToolIntegration(ToolsIntegrationTests): 29 | @property 30 | def tool_constructor(self) -> Type[InfoMongoDBDatabaseTool]: 31 | return InfoMongoDBDatabaseTool 32 | 33 | @property 34 | def tool_constructor_params(self) -> dict: 35 | return dict(db=create_database()) 36 | 37 | @property 38 | def tool_invoke_params_example(self) -> dict: 39 | return dict(collection_names="test") 40 | 41 | 42 | class TestListMongoDBDatabaseToolIntegration(ToolsIntegrationTests): 43 | @property 44 | def tool_constructor(self) -> Type[ListMongoDBDatabaseTool]: 45 | return ListMongoDBDatabaseTool 46 | 47 | @property 48 | def tool_constructor_params(self) -> dict: 49 | return dict(db=create_database()) 50 | 51 | @property 52 | def tool_invoke_params_example(self) -> dict: 53 | return dict() 54 | 55 | 56 | class TestQueryMongoDBCheckerToolIntegration(ToolsIntegrationTests): 57 | @property 58 | def tool_constructor(self) -> Type[QueryMongoDBCheckerTool]: 59 | return QueryMongoDBCheckerTool 60 | 61 | @property 62 | def tool_constructor_params(self) -> dict: 63 | return dict(db=create_database(), llm=create_llm()) 64 | 65 | @property 66 | def tool_invoke_params_example(self) -> dict: 67 | return dict(query='db.test.aggregate([{"$match": {}}])') 68 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py: -------------------------------------------------------------------------------- 1 | """Test MongoDBAtlasVectorSearch.from_documents.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List 6 | 7 | import pytest # type: ignore[import-not-found] 8 | from langchain_core.documents import Document 9 | from langchain_core.embeddings import Embeddings 10 | from pymongo import MongoClient 11 | from pymongo.collection import Collection 12 | 13 | from langchain_mongodb.index import ( 14 | create_vector_search_index, 15 | ) 16 | 17 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch 18 | 19 | COLLECTION_NAME = "langchain_test_from_documents" 20 | INDEX_NAME = "langchain-test-index-from-documents" 21 | DIMENSIONS = 5 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def collection(client: MongoClient) -> Collection: 26 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 27 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME) 28 | else: 29 | clxn = client[DB_NAME][COLLECTION_NAME] 30 | 31 | 
clxn.delete_many({}) 32 | 33 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): 34 | create_vector_search_index( 35 | collection=clxn, 36 | index_name=INDEX_NAME, 37 | dimensions=DIMENSIONS, 38 | path="embedding", 39 | similarity="cosine", 40 | wait_until_complete=60, 41 | ) 42 | 43 | return clxn 44 | 45 | 46 | @pytest.fixture(scope="module") 47 | def example_documents() -> List[Document]: 48 | return [ 49 | Document(page_content="Dogs are tough.", metadata={"a": 1}), 50 | Document(page_content="Cats have fluff.", metadata={"b": 1}), 51 | Document(page_content="What is a sandwich?", metadata={"c": 1}), 52 | Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}), 53 | ] 54 | 55 | 56 | @pytest.fixture(scope="module") 57 | def embeddings() -> Embeddings: 58 | return ConsistentFakeEmbeddings(DIMENSIONS) 59 | 60 | 61 | @pytest.fixture(scope="module") 62 | def vectorstore( 63 | collection: Collection, example_documents: List[Document], embeddings: Embeddings 64 | ) -> PatchedMongoDBAtlasVectorSearch: 65 | """VectorStore created with a few documents and a trivial embedding model. 66 | 67 | Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all 68 | but one important feature. It waits until all documents are fully indexed 69 | before returning control to the caller. 70 | """ 71 | return PatchedMongoDBAtlasVectorSearch.from_documents( 72 | example_documents, 73 | embedding=embeddings, 74 | collection=collection, 75 | index_name=INDEX_NAME, 76 | ) 77 | 78 | 79 | def test_default_search( 80 | vectorstore: PatchedMongoDBAtlasVectorSearch, example_documents: List[Document] 81 | ) -> None: 82 | """Test end to end construction and search.""" 83 | output = vectorstore.similarity_search("Sandwich", k=1) 84 | assert len(output) == 1 85 | # Check that the result matches one of the example documents 86 | assert any( 87 | [key.page_content == output[0].page_content for key in example_documents] 88 | ) 89 | # Assert no presence of embeddings in results 90 | assert all(["embedding" not in key.metadata for key in output]) 91 | 92 | 93 | def test_search_with_embeddings(vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: 94 | output = vectorstore.similarity_search("Sandwich", k=2, include_embeddings=True) 95 | assert len(output) == 2 96 | 97 | # Assert embeddings in results 98 | assert all([key.metadata.get("embedding") for key in output]) 99 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py: -------------------------------------------------------------------------------- 1 | """Test MongoDBAtlasVectorSearch.from_texts.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Dict, Generator, List 6 | 7 | import pytest # type: ignore[import-not-found] 8 | from langchain_core.embeddings import Embeddings 9 | from pymongo import MongoClient 10 | from pymongo.collection import Collection 11 | 12 | from langchain_mongodb import MongoDBAtlasVectorSearch 13 | from langchain_mongodb.index import ( 14 | create_vector_search_index, 15 | ) 16 | 17 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch 18 | 19 | COLLECTION_NAME = "langchain_test_from_texts" 20 | INDEX_NAME = "langchain-test-index-from-texts" 21 | DIMENSIONS = 5 22 | 23 | 24 | @pytest.fixture(scope="module") 25 | def collection(client: MongoClient) -> Collection: 26 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 27 | clxn =
client[DB_NAME].create_collection(COLLECTION_NAME) 28 | else: 29 | clxn = client[DB_NAME][COLLECTION_NAME] 30 | 31 | clxn.delete_many({}) 32 | 33 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): 34 | create_vector_search_index( 35 | collection=clxn, 36 | index_name=INDEX_NAME, 37 | dimensions=DIMENSIONS, 38 | path="embedding", 39 | filters=["c"], 40 | similarity="cosine", 41 | wait_until_complete=60, 42 | ) 43 | 44 | return clxn 45 | 46 | 47 | @pytest.fixture(scope="module") 48 | def texts() -> List[str]: 49 | return [ 50 | "Dogs are tough.", 51 | "Cats have fluff.", 52 | "What is a sandwich?", 53 | "That fence is purple.", 54 | ] 55 | 56 | 57 | @pytest.fixture(scope="module") 58 | def metadatas() -> List[Dict]: 59 | return [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}] 60 | 61 | 62 | @pytest.fixture(scope="module") 63 | def embeddings() -> Embeddings: 64 | return ConsistentFakeEmbeddings(DIMENSIONS) 65 | 66 | 67 | @pytest.fixture(scope="module") 68 | def vectorstore( 69 | collection: Collection, 70 | texts: List[str], 71 | embeddings: Embeddings, 72 | metadatas: List[dict], 73 | ) -> Generator[MongoDBAtlasVectorSearch]: 74 | """VectorStore created with a few documents and a trivial embedding model. 75 | 76 | Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all 77 | but one important feature. It waits until all documents are fully indexed 78 | before returning control to the caller. 79 | """ 80 | vectorstore_from_texts = PatchedMongoDBAtlasVectorSearch.from_texts( 81 | texts=texts, 82 | embedding=embeddings, 83 | metadatas=metadatas, 84 | collection=collection, 85 | index_name=INDEX_NAME, 86 | ) 87 | yield vectorstore_from_texts 88 | 89 | vectorstore_from_texts.collection.delete_many({}) 90 | 91 | 92 | def test_search_with_metadatas_and_pre_filter( 93 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] 94 | ) -> None: 95 | # Confirm the presence of metadata in output 96 | output = vectorstore.similarity_search("Sandwich", k=1) 97 | assert len(output) == 1 98 | metakeys = [list(d.keys())[0] for d in metadatas] 99 | assert any([key in output[0].metadata for key in metakeys]) 100 | 101 | 102 | def test_search_filters_all( 103 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] 104 | ) -> None: 105 | # Test filtering out 106 | does_not_match_filter = vectorstore.similarity_search( 107 | "Sandwich", k=1, pre_filter={"c": {"$lte": 0}} 108 | ) 109 | assert does_not_match_filter == [] 110 | 111 | 112 | def test_search_pre_filter( 113 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict] 114 | ) -> None: 115 | # Test filtering with expected output 116 | matches_filter = vectorstore.similarity_search( 117 | "Sandwich", k=3, pre_filter={"c": {"$gt": 0}} 118 | ) 119 | assert len(matches_filter) == 1 120 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/integration_tests/test_vectorstore_standard.py: -------------------------------------------------------------------------------- 1 | """Run the standard VectorStore integration tests against MongoDBAtlasVectorSearch.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest # type: ignore[import-not-found] 6 | from langchain_core.vectorstores import VectorStore 7 | from langchain_tests.integration_tests import VectorStoreIntegrationTests 8 | from pymongo import MongoClient 9 | from pymongo.collection import Collection 10 | 11 | from langchain_mongodb.index import ( 12 | create_vector_search_index, 13 | ) 14 | 15
from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch 16 | 17 | COLLECTION_NAME = "langchain_test_standard" 18 | INDEX_NAME = "langchain-test-index-standard" 19 | DIMENSIONS = 6 20 | 21 | 22 | @pytest.fixture 23 | def collection(client: MongoClient) -> Collection: 24 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names(): 25 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME) 26 | else: 27 | clxn = client[DB_NAME][COLLECTION_NAME] 28 | 29 | clxn.delete_many({}) 30 | 31 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): 32 | create_vector_search_index( 33 | collection=clxn, 34 | index_name=INDEX_NAME, 35 | dimensions=DIMENSIONS, 36 | path="embedding", 37 | filters=["c"], 38 | similarity="cosine", 39 | wait_until_complete=60, 40 | ) 41 | 42 | return clxn 43 | 44 | 45 | class TestMongoDBAtlasVectorSearch(VectorStoreIntegrationTests): 46 | @pytest.fixture() 47 | def vectorstore(self, collection) -> VectorStore: # type: ignore 48 | """Get an empty vectorstore for unit tests.""" 49 | store = PatchedMongoDBAtlasVectorSearch( 50 | collection, self.get_embeddings(), index_name=INDEX_NAME 51 | ) 52 | # note: store should be EMPTY at this point 53 | # if you need to delete data, you may do so here 54 | return store 55 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py: -------------------------------------------------------------------------------- 1 | import json 2 | from importlib.metadata import version 3 | 4 | import mongomock 5 | import pytest 6 | from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found] 7 | from langchain_core.messages import message_to_dict 8 | from pymongo.driver_info import DriverInfo 9 | from pytest_mock import MockerFixture 10 | 11 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory 12 | 13 | from ..utils import MockCollection 14 | 15 | 16 | class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory): 17 | def __init__(self) -> None: 18 | self.session_id = "test-session" 19 | self.database_name = "test-database" 20 | self.collection_name = "test-collection" 21 | self.collection = MockCollection() 22 | self.session_id_key = "SessionId" 23 | self.history_key = "History" 24 | self.history_size = None 25 | self.db = self.collection.database 26 | self.client = self.db.client 27 | 28 | 29 | def test_memory_with_message_store() -> None: 30 | """Test the memory with a message store.""" 31 | # setup MongoDB as a message store 32 | message_history = PatchedMongoDBChatMessageHistory() 33 | memory = ConversationBufferMemory( 34 | memory_key="baz", chat_memory=message_history, return_messages=True 35 | ) 36 | 37 | # add some messages 38 | memory.chat_memory.add_ai_message("This is me, the AI") 39 | memory.chat_memory.add_user_message("This is me, the human") 40 | 41 | # get the message history from the memory store and turn it into a json 42 | messages = memory.chat_memory.messages 43 | messages_json = json.dumps([message_to_dict(msg) for msg in messages]) 44 | 45 | assert "This is me, the AI" in 
messages_json 46 | assert "This is me, the human" in messages_json 47 | 48 | # remove the record from MongoDB, so the next test run won't pick it up 49 | memory.chat_memory.clear() 50 | 51 | assert memory.chat_memory.messages == [] 52 | 53 | message_history.close() 54 | assert message_history.client.is_closed 55 | 56 | 57 | def test_init_with_connection_string(mocker: MockerFixture) -> None: 58 | mock_mongo_client = mocker.patch( 59 | "langchain_mongodb.chat_message_histories.MongoClient" 60 | ) 61 | 62 | history = MongoDBChatMessageHistory( 63 | connection_string="mongodb://localhost:27017/", 64 | session_id="test-session", 65 | database_name="test-database", 66 | collection_name="test-collection", 67 | ) 68 | 69 | mock_mongo_client.assert_called_once_with( 70 | "mongodb://localhost:27017/", 71 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), 72 | ) 73 | assert history.session_id == "test-session" 74 | assert history.database_name == "test-database" 75 | assert history.collection_name == "test-collection" 76 | history.close() 77 | 78 | 79 | def test_init_with_existing_client() -> None: 80 | client = mongomock.MongoClient() # type: ignore[var-annotated] 81 | 82 | # Initialize MongoDBChatMessageHistory with the mock client 83 | history = MongoDBChatMessageHistory( 84 | connection_string=None, 85 | session_id="test-session", 86 | database_name="test-database", 87 | collection_name="test-collection", 88 | client=client, 89 | ) 90 | 91 | assert history.session_id == "test-session" 92 | 93 | # Verify that the collection is correctly created within the specified database 94 | assert "test-database" in client.list_database_names() 95 | assert "test-collection" in client["test-database"].list_collection_names() 96 | 97 | history.close() 98 | 99 | 100 | def test_init_raises_error_without_connection_or_client() -> None: 101 | with pytest.raises( 102 | ValueError, match="Either connection_string or client must be provided" 103 | ): 104 | MongoDBChatMessageHistory( 105 | session_id="test_session", 106 | connection_string=None, 107 | client=None, 108 | ) 109 | 110 | 111 | def test_init_raises_error_with_both_connection_and_client() -> None: 112 | client_mock = mongomock.MongoClient() # type: ignore[var-annotated] 113 | 114 | with pytest.raises( 115 | ValueError, match="Must provide connection_string or client, not both" 116 | ): 117 | MongoDBChatMessageHistory( 118 | connection_string="mongodb://localhost:27017/", 119 | session_id="test_session", 120 | client=client_mock, 121 | ) 122 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_mongodb import __all__ 2 | 3 | EXPECTED_ALL = [ 4 | "MongoDBAtlasVectorSearch", 5 | "MongoDBChatMessageHistory", 6 | "MongoDBCache", 7 | "MongoDBAtlasSemanticCache", 8 | ] 9 | 10 | 11 | def test_all_imports() -> None: 12 | assert sorted(EXPECTED_ALL) == sorted(__all__) 13 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/test_index.py: -------------------------------------------------------------------------------- 1 | from time import sleep 2 | 3 | import pytest 4 | from pymongo import MongoClient 5 | from pymongo.collection import Collection 6 | from pymongo.errors import OperationFailure, ServerSelectionTimeoutError 7 | 8 | from langchain_mongodb import index 9 | 10 | DIMENSION = 5 
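# These unit tests target a plain local MongoClient rather than Atlas, so the
# index helpers are expected to fail fast: OperationFailure where search
# indexes are unsupported, or ServerSelectionTimeoutError when no server is
# reachable, as the pytest.raises assertions below verify.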
11 | TIMEOUT = 120 12 | 13 | 14 | @pytest.fixture 15 | def collection() -> Collection: 16 | """Collection on MongoDB Cluster, not an Atlas one.""" 17 | client: MongoClient = MongoClient() 18 | return client["db"]["collection"] 19 | 20 | 21 | def test_create_vector_search_index(collection: Collection) -> None: 22 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): 23 | index.create_vector_search_index( 24 | collection, 25 | "index_name", 26 | DIMENSION, 27 | "embedding", 28 | "cosine", 29 | [], 30 | wait_until_complete=TIMEOUT, 31 | ) 32 | 33 | 34 | def test_drop_vector_search_index(collection: Collection) -> None: 35 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): 36 | index.drop_vector_search_index( 37 | collection, "index_name", wait_until_complete=TIMEOUT 38 | ) 39 | 40 | 41 | def test_update_vector_search_index(collection: Collection) -> None: 42 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): 43 | index.update_vector_search_index( 44 | collection, 45 | "index_name", 46 | DIMENSION, 47 | "embedding", 48 | "cosine", 49 | [], 50 | wait_until_complete=TIMEOUT, 51 | ) 52 | 53 | 54 | def test___is_index_ready(collection: Collection) -> None: 55 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)): 56 | index._is_index_ready(collection, "index_name") 57 | 58 | 59 | def test__wait_for_predicate() -> None: 60 | err = "error string" 61 | with pytest.raises(TimeoutError) as e: 62 | index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1) 63 | assert err in str(e) 64 | 65 | index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5) 66 | -------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/test_retrievers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from langchain_text_splitters import RecursiveCharacterTextSplitter 3 | 4 | from langchain_mongodb.docstores import MongoDBDocStore 5 | from langchain_mongodb.retrievers import ( 6 | MongoDBAtlasFullTextSearchRetriever, 7 | MongoDBAtlasHybridSearchRetriever, 8 | MongoDBAtlasParentDocumentRetriever, 9 | ) 10 | from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch 11 | 12 | from ..utils import ConsistentFakeEmbeddings, MockCollection 13 | 14 | 15 | @pytest.fixture() 16 | def collection() -> MockCollection: 17 | return MockCollection() 18 | 19 | 20 | @pytest.fixture() 21 | def embeddings() -> ConsistentFakeEmbeddings: 22 | return ConsistentFakeEmbeddings() 23 | 24 | 25 | def test_full_text_search(collection): 26 | search = MongoDBAtlasFullTextSearchRetriever( 27 | collection=collection, search_index_name="foo", search_field="bar" 28 | ) 29 | search.close() 30 | assert collection.database.client.is_closed 31 | 32 | 33 | def test_hybrid_search(collection, embeddings): 34 | vs = MongoDBAtlasVectorSearch(collection, embeddings) 35 | search = MongoDBAtlasHybridSearchRetriever(vectorstore=vs, search_index_name="foo") 36 | search.close() 37 | assert collection.database.client.is_closed 38 | 39 | 40 | def test_parent_retriever(collection, embeddings): 41 | vs = MongoDBAtlasVectorSearch(collection, embeddings) 42 | ds = MongoDBDocStore(collection) 43 | cs = RecursiveCharacterTextSplitter(chunk_size=400) 44 | retriever = MongoDBAtlasParentDocumentRetriever( 45 | vectorstore=vs, docstore=ds, child_splitter=cs 46 | ) 47 | retriever.close() 48 | assert collection.database.client.is_closed 49 | 
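The `close()` behavior asserted above matters in application code as well. A minimal sketch of the lifecycle; the URI, collection, and index names are placeholders:

```python
# Minimal sketch of the close() lifecycle verified above; all names are
# placeholders. close() releases the retriever's underlying MongoClient.
from pymongo import MongoClient

from langchain_mongodb.retrievers import MongoDBAtlasFullTextSearchRetriever

client = MongoClient("mongodb://localhost:27017")
retriever = MongoDBAtlasFullTextSearchRetriever(
    collection=client["db"]["collection"],
    search_index_name="text_index",
    search_field="page_content",
)
try:
    docs = retriever.invoke("example query")
finally:
    retriever.close()
```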
-------------------------------------------------------------------------------- /libs/langchain-mongodb/tests/unit_tests/test_tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Type 4 | 5 | from langchain_tests.unit_tests import ToolsUnitTests 6 | 7 | from langchain_mongodb.agent_toolkit import MongoDBDatabase 8 | from langchain_mongodb.agent_toolkit.tool import ( 9 | InfoMongoDBDatabaseTool, 10 | ListMongoDBDatabaseTool, 11 | QueryMongoDBCheckerTool, 12 | QueryMongoDBDatabaseTool, 13 | ) 14 | 15 | from ..utils import FakeLLM, MockClient 16 | 17 | 18 | class TestQueryMongoDBDatabaseToolUnit(ToolsUnitTests): 19 | @property 20 | def tool_constructor(self) -> Type[QueryMongoDBDatabaseTool]: 21 | return QueryMongoDBDatabaseTool 22 | 23 | @property 24 | def tool_constructor_params(self) -> dict: 25 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type] 26 | 27 | @property 28 | def tool_invoke_params_example(self) -> dict: 29 | return dict(query="db.foo.aggregate()") 30 | 31 | 32 | class TestInfoMongoDBDatabaseToolUnit(ToolsUnitTests): 33 | @property 34 | def tool_constructor(self) -> Type[InfoMongoDBDatabaseTool]: 35 | return InfoMongoDBDatabaseTool 36 | 37 | @property 38 | def tool_constructor_params(self) -> dict: 39 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type] 40 | 41 | @property 42 | def tool_invoke_params_example(self) -> dict: 43 | return dict(collection_names="test") 44 | 45 | 46 | class TestListMongoDBDatabaseToolUnit(ToolsUnitTests): 47 | @property 48 | def tool_constructor(self) -> Type[ListMongoDBDatabaseTool]: 49 | return ListMongoDBDatabaseTool 50 | 51 | @property 52 | def tool_constructor_params(self) -> dict: 53 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type] 54 | 55 | @property 56 | def tool_invoke_params_example(self) -> dict: 57 | return dict() 58 | 59 | 60 | class TestQueryMongoDBCheckerToolUnit(ToolsUnitTests): 61 | @property 62 | def tool_constructor(self) -> Type[QueryMongoDBCheckerTool]: 63 | return QueryMongoDBCheckerTool 64 | 65 | @property 66 | def tool_constructor_params(self) -> dict: 67 | return dict(db=MongoDBDatabase(MockClient(), "test"), llm=FakeLLM()) # type:ignore[arg-type] 68 | 69 | @property 70 | def tool_invoke_params_example(self) -> dict: 71 | return dict(query="db.foo.aggregate()") 72 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | --- 4 | 5 | ## Changes in version 0.1.3 (2025/04/01) 6 | 7 | - Add compatibility with `pymongo.AsyncMongoClient`. 8 | 9 | ## Changes in version 0.1.2 (2025/03/26) 10 | 11 | - Add compatibility with `langgraph-checkpoint` 2.0.23. 12 | 13 | ## Changes in version 0.1.1 (2025/02/26) 14 | 15 | - Remove dependency on `langgraph`. 16 | 17 | ## Changes in version 0.1 (2024/12/13) 18 | 19 | - Initial release, added support for `MongoDBSaver`. 20 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/README.md: -------------------------------------------------------------------------------- 1 | # LangGraph Checkpoint MongoDB 2 | 3 | Implementation of LangGraph CheckpointSaver that uses MongoDB. 
4 | 5 | ## Usage 6 | 7 | ```python 8 | from langgraph.checkpoint.mongodb import MongoDBSaver 9 | 10 | write_config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} 11 | read_config = {"configurable": {"thread_id": "1"}} 12 | 13 | MONGODB_URI = "mongodb://localhost:27017" 14 | DB_NAME = "checkpoint_example" 15 | 16 | with MongoDBSaver.from_conn_string(MONGODB_URI, DB_NAME) as checkpointer: 17 | checkpoint = { 18 | "v": 1, 19 | "ts": "2024-07-31T20:14:19.804150+00:00", 20 | "id": "1ef4f797-8335-6428-8001-8a1503f9b875", 21 | "channel_values": { 22 | "my_key": "meow", 23 | "node": "node" 24 | }, 25 | "channel_versions": { 26 | "__start__": 2, 27 | "my_key": 3, 28 | "start:node": 3, 29 | "node": 3 30 | }, 31 | "versions_seen": { 32 | "__input__": {}, 33 | "__start__": { 34 | "__start__": 1 35 | }, 36 | "node": { 37 | "start:node": 2 38 | } 39 | }, 40 | "pending_sends": [], 41 | } 42 | 43 | # store checkpoint 44 | checkpointer.put(write_config, checkpoint, {}, {}) 45 | 46 | # load checkpoint 47 | checkpointer.get(read_config) 48 | 49 | # list checkpoints 50 | list(checkpointer.list(read_config)) 51 | ``` 52 | 53 | ### Async 54 | 55 | ```python 56 | from langgraph.checkpoint.mongodb import AsyncMongoDBSaver 57 | # MONGODB_URI, write_config, and read_config are defined in the sync example above 58 | async with AsyncMongoDBSaver.from_conn_string(MONGODB_URI) as checkpointer: 59 | checkpoint = { 60 | "v": 1, 61 | "ts": "2024-07-31T20:14:19.804150+00:00", 62 | "id": "1ef4f797-8335-6428-8001-8a1503f9b875", 63 | "channel_values": { 64 | "my_key": "meow", 65 | "node": "node" 66 | }, 67 | "channel_versions": { 68 | "__start__": 2, 69 | "my_key": 3, 70 | "start:node": 3, 71 | "node": 3 72 | }, 73 | "versions_seen": { 74 | "__input__": {}, 75 | "__start__": { 76 | "__start__": 1 77 | }, 78 | "node": { 79 | "start:node": 2 80 | } 81 | }, 82 | "pending_sends": [], 83 | } 84 | 85 | # store checkpoint 86 | await checkpointer.aput(write_config, checkpoint, {}, {}) 87 | 88 | # load checkpoint 89 | await checkpointer.aget(read_config) 90 | 91 | # list checkpoints 92 | [c async for c in checkpointer.alist(read_config)] 93 | ``` 94 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/justfile: -------------------------------------------------------------------------------- 1 | set shell := ["bash", "-c"] 2 | set dotenv-load 3 | set dotenv-filename := "../../.local_atlas_uri" 4 | 5 | # Default target executed when no arguments are given. 6 | [private] 7 | default: 8 | @just --list 9 | 10 | install: 11 | uv sync --frozen 12 | 13 | [group('test')] 14 | integration_tests *args="": 15 | uv run pytest tests/integration_tests/ {{args}} 16 | 17 | [group('test')] 18 | unit_tests *args="": 19 | uv run pytest tests/unit_tests {{args}} 20 | 21 | [group('test')] 22 | tests *args="": 23 | uv run pytest {{args}} 24 | 25 | [group('test')] 26 | test_watch filename: 27 | uv run ptw --snapshot-update --now . -- -vv {{filename}} 28 | 29 | [group('lint')] 30 | lint: 31 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files 32 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files 33 | 34 | [group('lint')] 35 | typing: 36 | uv run mypy .
37 | 38 | [group('lint')] 39 | codespell: 40 | git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files 41 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py: -------------------------------------------------------------------------------- 1 | from .aio import AsyncMongoDBSaver 2 | from .saver import MongoDBSaver 3 | 4 | __all__ = ["MongoDBSaver", "AsyncMongoDBSaver"] 5 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | :private: 3 | Utilities for langgraph-checkpoint-mongodb. 4 | """ 5 | 6 | from typing import Any, Union 7 | 8 | from langgraph.checkpoint.base import CheckpointMetadata 9 | from langgraph.checkpoint.serde.base import SerializerProtocol 10 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer 11 | 12 | serde: SerializerProtocol = JsonPlusSerializer() 13 | 14 | 15 | def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata: 16 | """Deserialize a metadata document. 17 | 18 | The CheckpointMetadata class itself cannot be stored directly in MongoDB, 19 | but as a dictionary it can. For efficient filtering in MongoDB, 20 | we keep dict keys as strings. 21 | 22 | Metadata is stored in the MongoDB collection with string keys and 23 | serde-serialized values. 24 | """ 25 | if isinstance(metadata, dict): 26 | output = dict() 27 | for key, value in metadata.items(): 28 | output[key] = loads_metadata(value) 29 | return output 30 | else: 31 | return serde.loads(metadata) 32 | 33 | 34 | def dumps_metadata( 35 | metadata: Union[CheckpointMetadata, Any], 36 | ) -> Union[bytes, dict[str, Any]]: 37 | """Serialize all values in a metadata dictionary. 38 | 39 | Keep dict keys as strings for efficient filtering in MongoDB. 40 | """ 41 | if isinstance(metadata, dict): 42 | output = dict() 43 | for key, value in metadata.items(): 44 | output[key] = dumps_metadata(value) 45 | return output 46 | else: 47 | return serde.dumps(metadata) 48 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling>1.24"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "langgraph-checkpoint-mongodb" 7 | version = "0.1.3" 8 | description = "Library with a MongoDB implementation of LangGraph checkpoint saver."
9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | dependencies = [ 12 | "langgraph-checkpoint>=2.0.23,<3.0.0", 13 | "langchain-mongodb>=0.6.1", 14 | "pymongo>=4.10,<4.13", 15 | "motor>3.6.0", 16 | ] 17 | 18 | [dependency-groups] 19 | dev = [ 20 | "anyio>=4.4.0", 21 | "langchain-core>=0.3.55", 22 | "langchain-ollama>=0.2.2", 23 | "langchain-openai>=0.2.14", 24 | "langgraph>=0.3.23", 25 | "langgraph-checkpoint>=2.0.9", 26 | "pytest-asyncio>=0.21.1", 27 | "pytest>=7.2.1", 28 | "pytest-mock>=3.11.1", 29 | "pytest-watch>=4.2.0", 30 | "pytest-repeat>=0.9.3", 31 | "syrupy>=4.0.2", 32 | "pre-commit>=4.0", 33 | "mypy>=1.10.0", 34 | ] 35 | 36 | [tool.hatch.build.targets.wheel] 37 | packages = ["langgraph"] 38 | 39 | [tool.pytest.ini_options] 40 | addopts = "--strict-markers --strict-config --durations=5 -vv" 41 | markers = [ 42 | "requires: mark tests as requiring a specific library", 43 | "compile: mark placeholder test used to compile integration tests without running them", 44 | ] 45 | asyncio_mode = "auto" 46 | 47 | [tool.ruff] 48 | lint.select = [ 49 | "E", # pycodestyle 50 | "F", # Pyflakes 51 | "UP", # pyupgrade 52 | "B", # flake8-bugbear 53 | "I", # isort 54 | ] 55 | lint.ignore = ["E501", "B008", "UP007", "UP006"] 56 | 57 | [tool.mypy] 58 | # https://mypy.readthedocs.io/en/stable/config_file.html 59 | disallow_untyped_defs = true 60 | explicit_package_bases = true 61 | warn_no_return = false 62 | warn_unused_ignores = true 63 | warn_redundant_casts = true 64 | allow_redefinition = true 65 | disable_error_code = "typeddict-item, return-value" 66 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/integration_tests/README: -------------------------------------------------------------------------------- 1 | tl;dr: Placeholder until https://jira.mongodb.org/browse/INTPYTHON-492 is resolved. 2 | 3 | This directory previously held copies of test_pregel.py and test_pregel_async.py, 4 | along with their associated fixtures and utilities. 5 | [See here](https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/tests) 6 | Those copies were always intended to be temporary. 7 | The files were copied from the langgraph library itself, not from langgraph-checkpoint-X. 8 | While not part of the checkpoint lib, they provided extensive testing of the checkpointers 9 | and were crucial for covering edge cases. 10 | 11 | As LangGraph development continued, our copies began to fail. 12 | We will determine how to proceed with maintenance in conversation with the 13 | Langchain-AI / LangGraph team.
14 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/integration_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from langchain_core.embeddings import Embeddings 5 | from langchain_ollama.embeddings import OllamaEmbeddings 6 | from langchain_openai import OpenAIEmbeddings 7 | 8 | 9 | @pytest.fixture(scope="session") 10 | def embedding() -> Embeddings: 11 | if os.environ.get("OPENAI_API_KEY"): 12 | return OpenAIEmbeddings( 13 | openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa 14 | model="text-embedding-3-small", 15 | ) 16 | 17 | return OllamaEmbeddings(model="all-minilm:l6-v2") 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def dimensions() -> int: 22 | if os.environ.get("OPENAI_API_KEY"): 23 | return 1536 24 | return 384 25 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_interrupts.py: -------------------------------------------------------------------------------- 1 | """Follows https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/time-travel""" 2 | 3 | import os 4 | from collections.abc import Generator 5 | from typing import TypedDict 6 | 7 | import pytest 8 | from langchain_core.runnables import RunnableConfig 9 | 10 | from langgraph.checkpoint.base import BaseCheckpointSaver 11 | from langgraph.checkpoint.memory import InMemorySaver 12 | from langgraph.checkpoint.mongodb import MongoDBSaver 13 | from langgraph.graph import END, StateGraph 14 | from langgraph.graph.graph import CompiledGraph 15 | 16 | # --- Configuration --- 17 | MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017") 18 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test") 19 | CHECKPOINT_CLXN_NAME = "interrupts_checkpoints" 20 | WRITES_CLXN_NAME = "interrupts_writes" 21 | 22 | 23 | @pytest.fixture(scope="function") 24 | def checkpointer_memory() -> Generator[InMemorySaver, None, None]: 25 | yield InMemorySaver() 26 | 27 | 28 | @pytest.fixture(scope="function") 29 | def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]: 30 | with MongoDBSaver.from_conn_string( 31 | MONGODB_URI, 32 | db_name=DB_NAME, 33 | checkpoint_collection_name=CHECKPOINT_CLXN_NAME, 34 | writes_collection_name=WRITES_CLXN_NAME, 35 | ) as checkpointer: 36 | checkpointer.checkpoint_collection.delete_many({}) 37 | checkpointer.writes_collection.delete_many({}) 38 | yield checkpointer 39 | checkpointer.checkpoint_collection.drop() 40 | checkpointer.writes_collection.drop() 41 | 42 | 43 | ALL_CHECKPOINTERS_SYNC = [ 44 | "checkpointer_memory", 45 | "checkpointer_mongodb", 46 | ] 47 | 48 | 49 | @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) 50 | def test(request: pytest.FixtureRequest, checkpointer_name: str) -> None: 51 | checkpointer: BaseCheckpointSaver = request.getfixturevalue(checkpointer_name) 52 | assert isinstance(checkpointer, BaseCheckpointSaver) 53 | 54 | # --- State Definition --- 55 | class State(TypedDict): 56 | value: int 57 | step: int 58 | 59 | # --- Node Definitions --- 60 | def node_inc(state: State) -> State: 61 | """Increments value and step by 1""" 62 | current_step = state.get("step", 0) 63 | return {"value": state["value"] + 1, "step": current_step + 1} 64 | 65 | def node_double(state: State) -> State: 66 | """Doubles value and increments step by 1""" 67 | current_step = 
state.get("step", 0) 68 | return {"value": state["value"] * 2, "step": current_step + 1} 69 | 70 | # --- Graph Construction --- 71 | builder = StateGraph(State) 72 | builder.add_node("increment", node_inc) 73 | builder.add_node("double", node_double) 74 | builder.set_entry_point("increment") 75 | builder.add_edge("increment", "double") 76 | builder.add_edge("double", END) 77 | 78 | # --- Compile Graph (with Interruption) --- 79 | # Using sync for simplicity in this demo 80 | graph: CompiledGraph = builder.compile( 81 | checkpointer=checkpointer, interrupt_after=["increment"] 82 | ) 83 | 84 | # --- Configure --- 85 | config: RunnableConfig = {"configurable": {"thread_id": "thread_#1"}} 86 | initial_input = {"value": 10, "step": 0} 87 | 88 | # --- 1st invoke, with Interruption 89 | interrupted_state = graph.invoke(initial_input, config=config) 90 | assert interrupted_state == {"value": 10 + 1, "step": 1} 91 | state_history = list(graph.get_state_history(config)) 92 | assert len(state_history) == 3 93 | # The states are returned in reverse chronological order. 94 | assert state_history[0].next == ("double",) 95 | 96 | # --- 2nd invoke, with input=None, and original config ==> continues from point of interruption 97 | final_state = graph.invoke(None, config=config) 98 | assert final_state == {"value": (10 + 1) * 2, "step": 2} 99 | state_history = list(graph.get_state_history(config)) 100 | assert len(state_history) == 4 101 | assert state_history[0].next == () 102 | assert state_history[-1].next == ("__start__",) 103 | 104 | # --- 3rd invoke, but with an input ===> the CompiledGraph is restarted. 105 | new_input = {"value": 100, "step": -100} 106 | third_state = graph.invoke(new_input, config=config) 107 | assert third_state == {"value": 101, "step": -99} 108 | 109 | # The entire state history is preserved however 110 | state_history = list(graph.get_state_history(config)) 111 | assert len(state_history) == 7 112 | assert state_history[0].next == ("double",) 113 | assert state_history[2].next == ("__start__",) 114 | 115 | # --- Upstate state and continue from interrupt 116 | updated_state = {"value": 1000, "step": 1000} 117 | updated_config = graph.update_state(config, updated_state) 118 | final_state = graph.invoke(input=None, config=updated_config) 119 | assert final_state == {"value": 2000, "step": 1001} 120 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_sanity.py: -------------------------------------------------------------------------------- 1 | # Placeholder so that Monorepo CI passes until after https://jira.mongodb.org/browse/INTPYTHON-492 is resolved. 
2 | def test_sanity() -> None: 3 | assert True 4 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/unit_tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from langchain_core.runnables import RunnableConfig 5 | 6 | from langgraph.checkpoint.base import ( 7 | CheckpointMetadata, 8 | create_checkpoint, 9 | empty_checkpoint, 10 | ) 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def input_data() -> dict: 15 | """Setup and store conveniently in a single dictionary.""" 16 | inputs: dict[str, Any] = {} 17 | 18 | inputs["config_1"] = RunnableConfig( 19 | configurable=dict(thread_id="thread-1", thread_ts="1", checkpoint_ns="") 20 | ) # config_1 tests deprecated thread_ts 21 | 22 | inputs["config_2"] = RunnableConfig( 23 | configurable=dict(thread_id="thread-2", checkpoint_id="2", checkpoint_ns="") 24 | ) 25 | 26 | inputs["config_3"] = RunnableConfig( 27 | configurable=dict( 28 | thread_id="thread-2", checkpoint_id="2-inner", checkpoint_ns="inner" 29 | ) 30 | ) 31 | 32 | inputs["chkpnt_1"] = empty_checkpoint() 33 | inputs["chkpnt_2"] = create_checkpoint(inputs["chkpnt_1"], {}, 1) 34 | inputs["chkpnt_3"] = empty_checkpoint() 35 | 36 | inputs["metadata_1"] = CheckpointMetadata( 37 | source="input", step=2, writes={}, score=1 38 | ) 39 | inputs["metadata_2"] = CheckpointMetadata( 40 | source="loop", step=1, writes={"foo": "bar"}, score=None 41 | ) 42 | inputs["metadata_3"] = CheckpointMetadata() 43 | 44 | return inputs 45 | -------------------------------------------------------------------------------- /libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | import pytest 5 | from bson.errors import InvalidDocument 6 | from motor.motor_asyncio import AsyncIOMotorClient 7 | 8 | from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver 9 | 10 | MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017") 11 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test") 12 | COLLECTION_NAME = "sync_checkpoints_aio" 13 | 14 | 15 | async def test_asearch(input_data: dict[str, Any]) -> None: 16 | # Clear collections if they exist 17 | client: AsyncIOMotorClient = AsyncIOMotorClient(MONGODB_URI) 18 | db = client[DB_NAME] 19 | 20 | for clxn in await db.list_collection_names(): 21 | await db.drop_collection(clxn) 22 | 23 | async with AsyncMongoDBSaver.from_conn_string( 24 | MONGODB_URI, DB_NAME, COLLECTION_NAME 25 | ) as saver: 26 | # save checkpoints 27 | await saver.aput( 28 | input_data["config_1"], 29 | input_data["chkpnt_1"], 30 | input_data["metadata_1"], 31 | {}, 32 | ) 33 | await saver.aput( 34 | input_data["config_2"], 35 | input_data["chkpnt_2"], 36 | input_data["metadata_2"], 37 | {}, 38 | ) 39 | await saver.aput( 40 | input_data["config_3"], 41 | input_data["chkpnt_3"], 42 | input_data["metadata_3"], 43 | {}, 44 | ) 45 | 46 | # call method / assertions 47 | query_1 = {"source": "input"} # search by 1 key 48 | query_2 = { 49 | "step": 1, 50 | "writes": {"foo": "bar"}, 51 | } # search by multiple keys 52 | query_3: dict[str, Any] = {} # search by no keys, return all checkpoints 53 | query_4 = {"source": "update", "step": 1} # no match 54 | 55 | search_results_1 = [c async for c in saver.alist(None, filter=query_1)] 56 | assert len(search_results_1) == 1 57 | assert 
search_results_1[0].metadata == input_data["metadata_1"] 58 | 59 | search_results_2 = [c async for c in saver.alist(None, filter=query_2)] 60 | assert len(search_results_2) == 1 61 | assert search_results_2[0].metadata == input_data["metadata_2"] 62 | 63 | search_results_3 = [c async for c in saver.alist(None, filter=query_3)] 64 | assert len(search_results_3) == 3 65 | 66 | search_results_4 = [c async for c in saver.alist(None, filter=query_4)] 67 | assert len(search_results_4) == 0 68 | 69 | # search by config (defaults to checkpoints across all namespaces) 70 | search_results_5 = [ 71 | c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}}) 72 | ] 73 | assert len(search_results_5) == 2 74 | assert { 75 | search_results_5[0].config["configurable"]["checkpoint_ns"], 76 | search_results_5[1].config["configurable"]["checkpoint_ns"], 77 | } == {"", "inner"} 78 | 79 | 80 | async def test_null_chars(input_data: dict[str, Any]) -> None: 81 | """In MongoDB, string *values* can be any valid UTF-8, including null characters. 82 | *Field names*, however, cannot contain null characters.""" 83 | async with AsyncMongoDBSaver.from_conn_string( 84 | MONGODB_URI, DB_NAME, COLLECTION_NAME 85 | ) as saver: 86 | null_str = "\x00abc" # string containing a null character 87 | 88 | # 1. null string in field *value* 89 | null_value_cfg = await saver.aput( 90 | input_data["config_1"], 91 | input_data["chkpnt_1"], 92 | {"my_key": null_str}, 93 | {}, 94 | ) 95 | null_tuple = await saver.aget_tuple(null_value_cfg) 96 | assert null_tuple.metadata["my_key"] == null_str # type: ignore 97 | cps = [c async for c in saver.alist(None, filter={"my_key": null_str})] 98 | assert cps[0].metadata["my_key"] == null_str 99 | 100 | # 2. null string in field *name* 101 | with pytest.raises(InvalidDocument): 102 | await saver.aput( 103 | input_data["config_1"], 104 | input_data["chkpnt_1"], 105 | {null_str: "my_value"}, # type: ignore 106 | {}, 107 | ) 108 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | --- 4 | 5 | ## Changes in version 0.0.1 (2025/05/09) 6 | 7 | - Initial release, added support for `MongoDBStore`. 8 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/README.md: -------------------------------------------------------------------------------- 1 | # langgraph-store-mongodb 2 | 3 | LangGraph long-term memory using MongoDB. 4 | 5 | The implementation is feature-complete (v0.0.1), but we may revise the API based on feedback and further usage patterns. 6 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/justfile: -------------------------------------------------------------------------------- 1 | set shell := ["bash", "-c"] 2 | set dotenv-load 3 | set dotenv-filename := "../../.local_atlas_uri" 4 | 5 | # Default target executed when no arguments are given. 6 | [private] 7 | default: 8 | @just --list 9 | 10 | install: 11 | uv sync --frozen 12 | 13 | [group('test')] 14 | integration_tests *args="": 15 | uv run pytest tests/integration_tests/ {{args}} 16 | 17 | [group('test')] 18 | unit_tests *args="": 19 | uv run pytest tests/unit_tests {{args}} 20 | 21 | [group('test')] 22 | tests *args="": 23 | uv run pytest {{args}} 24 | 25 | [group('test')] 26 | test_watch filename: 27 | uv run ptw --snapshot-update --now .
-- -vv {{filename}} 28 | 29 | [group('lint')] 30 | lint: 31 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files 32 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files 33 | 34 | [group('lint')] 35 | typing: 36 | uv run mypy . 37 | 38 | [group('lint')] 39 | codespell: 40 | git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files 41 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/langgraph/store/mongodb/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | MongoDBStore, 3 | VectorIndexConfig, 4 | create_vector_index_config, 5 | ) 6 | 7 | __all__ = ["MongoDBStore", "VectorIndexConfig", "create_vector_index_config"] 8 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling>1.24"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "langgraph-store-mongodb" 7 | version = "0.0.1" 8 | description = "MongoDB implementation of the LangGraph long-term memory store." 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | dependencies = [ 12 | "langgraph-checkpoint>=2.0.23,<3.0.0", 13 | "langchain-mongodb>=0.6.1", 14 | ] 15 | 16 | [dependency-groups] 17 | dev = [ 18 | "pytest-asyncio>=0.21.1", 19 | "pytest>=7.2.1", 20 | "pre-commit>=4.0", 21 | "mypy>=1.10.0", 22 | ] 23 | 24 | [tool.hatch.build.targets.wheel] 25 | packages = ["langgraph"] 26 | 27 | [tool.pytest.ini_options] 28 | addopts = "--strict-markers --strict-config --durations=5 -vv" 29 | markers = [ 30 | "requires: mark tests as requiring a specific library", 31 | "compile: mark placeholder test used to compile integration tests without running them", 32 | ] 33 | asyncio_mode = "auto" 34 | 35 | [tool.ruff] 36 | lint.select = [ 37 | "E", # pycodestyle 38 | "F", # Pyflakes 39 | "UP", # pyupgrade 40 | "B", # flake8-bugbear 41 | "I", # isort 42 | ] 43 | lint.ignore = ["E501", "B008", "UP007", "UP006"] 44 | 45 | [tool.mypy] 46 | # https://mypy.readthedocs.io/en/stable/config_file.html 47 | disallow_untyped_defs = true 48 | explicit_package_bases = true 49 | warn_no_return = false 50 | warn_unused_ignores = true 51 | warn_redundant_casts = true 52 | allow_redefinition = true 53 | disable_error_code = "typeddict-item, return-value" 54 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/tests/integration_tests/test_store_semantic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Generator 3 | from time import monotonic, sleep 4 | from typing import Callable 5 | 6 | import pytest 7 | from langchain_core.embeddings import Embeddings 8 | from pymongo import MongoClient 9 | from pymongo.collection import Collection 10 | from pymongo.errors import OperationFailure 11 | 12 | from langgraph.store.base import PutOp 13 | from langgraph.store.memory import InMemoryStore 14 | from langgraph.store.mongodb import ( 15 | MongoDBStore, 16 | create_vector_index_config, 17 | ) 18 | 19 | MONGODB_URI = os.environ.get( 20 | "MONGODB_URI", "mongodb://localhost:27017?directConnection=true" 21 | ) 22 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test") 23 | COLLECTION_NAME = "semantic_search" 24 | INDEX_NAME = 
"vector_index" 25 | TIMEOUT, INTERVAL = 30, 1 # timeout to index new data 26 | 27 | DIMENSIONS = 5 # Dimensions of embedding model 28 | 29 | 30 | def wait(cond: Callable, timeout: int = 15, interval: int = 1) -> None: 31 | start = monotonic() 32 | while monotonic() - start < timeout: 33 | if cond(): 34 | return 35 | else: 36 | sleep(interval) 37 | raise TimeoutError("timeout waiting for: ", cond) 38 | 39 | 40 | class StaticEmbeddings(Embeddings): 41 | """ANN Search is not tested here. That is done in langchain-mongodb.""" 42 | 43 | def embed_documents(self, texts: list[str]) -> list[list[float]]: 44 | vectors = [] 45 | for txt in texts: 46 | vectors.append(self.embed_query(txt)) 47 | return vectors 48 | 49 | def embed_query(self, text: str) -> list[float]: 50 | if "pears" in text: 51 | return [1.0] + [0.5] * (DIMENSIONS - 1) 52 | else: 53 | return [0.5] * DIMENSIONS 54 | 55 | 56 | @pytest.fixture 57 | def collection() -> Generator[Collection, None, None]: 58 | client: MongoClient = MongoClient(MONGODB_URI) 59 | db = client[DB_NAME] 60 | db.drop_collection(COLLECTION_NAME) 61 | collection = db.create_collection(COLLECTION_NAME) 62 | wait(lambda: collection.count_documents({}) == 0, TIMEOUT, INTERVAL) 63 | try: 64 | collection.drop_search_index(INDEX_NAME) 65 | except OperationFailure: 66 | pass 67 | wait( 68 | lambda: len(collection.list_search_indexes().to_list()) == 0, TIMEOUT, INTERVAL 69 | ) 70 | 71 | yield collection 72 | 73 | client.close() 74 | 75 | 76 | def test_filters(collection: Collection) -> None: 77 | """Test permutations of namespace_prefix in filter.""" 78 | 79 | index_config = create_vector_index_config( 80 | name=INDEX_NAME, 81 | dims=DIMENSIONS, 82 | fields=["product"], 83 | embed=StaticEmbeddings(), # embedding 84 | filters=["metadata.available"], 85 | ) 86 | store_mdb = MongoDBStore( 87 | collection, index_config=index_config, auto_index_timeout=TIMEOUT 88 | ) 89 | store_in_mem = InMemoryStore(index=index_config) 90 | 91 | namespaces = [ 92 | ("a",), 93 | ("a", "b", "c"), 94 | ("a", "b", "c", "d"), 95 | ] 96 | 97 | products = ["apples", "oranges", "pears"] 98 | 99 | # Add some indexed data 100 | put_ops = [] 101 | for i, ns in enumerate(namespaces): 102 | put_ops.append( 103 | PutOp( 104 | namespace=ns, 105 | key=f"id_{i}", 106 | value={ 107 | "product": products[i], 108 | "metadata": {"available": bool(i % 2), "grade": "A" * (i + 1)}, 109 | }, 110 | ) 111 | ) 112 | 113 | store_mdb.batch(put_ops) 114 | store_in_mem.batch(put_ops) 115 | 116 | query = "What is the grade of our pears?" 
117 | # Case 1: filter on top-level namespace in hierarchy 118 | namespace_prefix = ("a",) # filter ("a",) catches all docs 119 | wait( 120 | lambda: len(store_mdb.search(namespace_prefix, query=query)) == len(products), 121 | TIMEOUT, 122 | INTERVAL, 123 | ) 124 | 125 | result_mdb = store_mdb.search(namespace_prefix, query=query) 126 | assert result_mdb[0].value["product"] == "pears" # test sorted by score 127 | 128 | result_mem = store_in_mem.search(namespace_prefix, query=query) 129 | assert len(result_mem) == len(products) 130 | 131 | # Case 2: filter on 2nd namespace in hierarchy 132 | namespace_prefix = ("a", "b") 133 | result_mem = store_in_mem.search(namespace_prefix, query=query) 134 | result_mdb = store_mdb.search(namespace_prefix, query=query) 135 | # filter ("a", "b") catches all but the doc in the top-level namespace ("a",) 136 | assert len(result_mem) == len(result_mdb) == len(products) - 1 137 | assert result_mdb[0].value["product"] == "pears" 138 | 139 | # Case 3: Empty namespace_prefix 140 | namespace_prefix = ("",) 141 | result_mem = store_in_mem.search(namespace_prefix, query=query) 142 | result_mdb = store_mdb.search(namespace_prefix, query=query) 143 | assert len(result_mem) == len(result_mdb) == 0 144 | 145 | # Case 4: With filter on value (nested) 146 | namespace_prefix = ("a",) 147 | available = {"metadata.available": True} 148 | result_mdb = store_mdb.search(namespace_prefix, query=query, filter=available) 149 | assert result_mdb[0].value["product"] == "oranges" 150 | assert len(result_mdb) == 1 151 | -------------------------------------------------------------------------------- /libs/langgraph-store-mongodb/tests/unit_tests/test_store_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Generator 3 | from typing import Union 4 | 5 | import pytest 6 | from pymongo import MongoClient 7 | 8 | from langgraph.store.base import ( 9 | GetOp, 10 | ListNamespacesOp, 11 | PutOp, 12 | SearchOp, 13 | TTLConfig, 14 | ) 15 | from langgraph.store.mongodb import ( 16 | MongoDBStore, 17 | ) 18 | 19 | MONGODB_URI = os.environ.get( 20 | "MONGODB_URI", "mongodb://localhost:27017?directConnection=true" 21 | ) 22 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test") 23 | COLLECTION_NAME = "async_store" 24 | 25 | 26 | @pytest.fixture 27 | def store() -> Generator: 28 | """Create a simple store modeled on the one in base's test_list_namespaces_basic.""" 29 | client: MongoClient = MongoClient(MONGODB_URI) 30 | collection = client[DB_NAME][COLLECTION_NAME] 31 | collection.delete_many({}) 32 | collection.drop_indexes() 33 | 34 | yield MongoDBStore( 35 | collection, 36 | ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True), 37 | ) 38 | 39 | if client: 40 | client.close() 41 | 42 | 43 | async def test_batch_async(store: MongoDBStore) -> None: 44 | N = 100 45 | M = 5 46 | ops: list[Union[PutOp, GetOp, ListNamespacesOp, SearchOp]] = [] 47 | for m in range(M): 48 | for i in range(N): 49 | ops.append( 50 | PutOp( 51 | ("test", "foo", "bar", "baz", str(m % 2)), 52 | f"key{i}", 53 | value={"foo": "bar" + str(i)}, 54 | ) 55 | ) 56 | ops.append( 57 | GetOp( 58 | ("test", "foo", "bar", "baz", str(m % 2)), 59 | f"key{i}", 60 | ) 61 | ) 62 | ops.append( 63 | ListNamespacesOp( 64 | match_conditions=None, 65 | max_depth=m + 1, 66 | ) 67 | ) 68 | ops.append( 69 | SearchOp( 70 | ("test",), 71 | ) 72 | ) 73 | ops.append( 74 | PutOp( 75 | ("test", "foo", "bar", "baz", str(m % 2)), 76 | f"key{i}", 77 | value={"foo": "bar" + str(i)}, 78 | ) 79 | ) 80 | ops.append( 81 | PutOp(("test",
"foo", "bar", "baz", str(m % 2)), f"key{i}", None) 82 | ) 83 | 84 | results = await store.abatch(ops) 85 | assert len(results) == M * N * 6 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "langchain-mongodb-monorepo" 3 | version = "0.1.0" 4 | description = "LangChain MongoDB mono-repo" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | license = { file = "LICENSE" } 8 | dependencies = [] 9 | 10 | [dependency-groups] 11 | dev = [ 12 | "autodoc-pydantic>=2.2.0", 13 | "pydata-sphinx-theme>=0.16.1", 14 | "sphinx-design>=0.6.1", 15 | "sphinx>=7.4.7", 16 | "sphinx-copybutton>=0.5.2", 17 | "toml>=0.10.2", 18 | "langchain-core>=0.3.30", 19 | "sphinxcontrib-googleanalytics>=0.4", 20 | "langchain-mongodb", 21 | "langgraph-checkpoint-mongodb", 22 | "langgraph-store-mongodb", 23 | "langchain-community>=0.3.14", 24 | "myst-parser>=3.0.1", 25 | "ipython>=8.18.1", 26 | ] 27 | 28 | [tool.uv.sources] 29 | langchain-mongodb = { path = "libs/langchain-mongodb", editable = true } 30 | langgraph-checkpoint-mongodb = { path = "libs/langgraph-checkpoint-mongodb", editable = true } 31 | langgraph-store-mongodb = { path = "libs/langgraph-store-mongodb", editable = true } 32 | -------------------------------------------------------------------------------- /scripts/setup_ollama.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | ollama pull all-minilm:l6-v2 5 | ollama pull llama3:8b 6 | -------------------------------------------------------------------------------- /scripts/start_local_atlas.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | echo "Starting the container" 5 | 6 | IMAGE=mongodb/mongodb-atlas-local:latest 7 | DOCKER=$(which docker || which podman) 8 | 9 | $DOCKER pull $IMAGE 10 | 11 | $DOCKER kill mongodb_atlas_local || true 12 | 13 | CONTAINER_ID=$($DOCKER run --rm -d --name mongodb_atlas_local -P $IMAGE) 14 | 15 | function wait() { 16 | CONTAINER_ID=$1 17 | echo "waiting for container to become healthy..." 18 | $DOCKER logs mongodb_atlas_local 19 | } 20 | 21 | wait "$CONTAINER_ID" 22 | 23 | EXPOSED_PORT=$($DOCKER inspect --format='{{ (index (index .NetworkSettings.Ports "27017/tcp") 0).HostPort }}' "$CONTAINER_ID") 24 | export MONGODB_URI="mongodb://127.0.0.1:$EXPOSED_PORT/?directConnection=true" 25 | SCRIPT_DIR=$(realpath "$(dirname ${BASH_SOURCE[0]})") 26 | ROOT_DIR=$(dirname $SCRIPT_DIR) 27 | echo "MONGODB_URI=$MONGODB_URI" > $ROOT_DIR/.local_atlas_uri 28 | 29 | # Sleep for a bit to let all services start. 30 | sleep 5 31 | -------------------------------------------------------------------------------- /scripts/update-locks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | python -m uv lock 5 | 6 | pushd libs/langgraph-checkpoint-mongodb 7 | python -m uv lock 8 | popd 9 | 10 | pushd libs/langchain-mongodb 11 | python -m uv lock 12 | popd 13 | --------------------------------------------------------------------------------