├── .github
│   ├── dependabot.yml
│   ├── scripts
│   │   └── check_diff.py
│   └── workflows
│       ├── _lint.yml
│       ├── _release.yml
│       ├── _test.yml
│       ├── _test_release.yml
│       ├── ci.yml
│       ├── codeql.yml
│       └── zizmor.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── RELEASE.md
├── docs
│   ├── _extensions
│   │   └── gallery_directive.py
│   ├── _static
│   │   ├── favicon.png
│   │   ├── wordmark-api-dark.svg
│   │   └── wordmark-api.svg
│   ├── conf.py
│   ├── create_api_rst.py
│   └── templates
│       ├── class.rst
│       ├── enum.rst
│       ├── function.rst
│       ├── langchain_docs.html
│       ├── pydantic.rst
│       ├── runnable_non_pydantic.rst
│       ├── runnable_pydantic.rst
│       └── typeddict.rst
├── justfile
├── libs
│   ├── langchain-mongodb
│   │   ├── .gitignore
│   │   ├── CHANGELOG.md
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── justfile
│   │   ├── langchain_mongodb
│   │   │   ├── __init__.py
│   │   │   ├── agent_toolkit
│   │   │   │   ├── __init__.py
│   │   │   │   ├── database.py
│   │   │   │   ├── prompt.py
│   │   │   │   ├── tool.py
│   │   │   │   └── toolkit.py
│   │   │   ├── cache.py
│   │   │   ├── chat_message_histories.py
│   │   │   ├── docstores.py
│   │   │   ├── graphrag
│   │   │   │   ├── __init__.py
│   │   │   │   ├── example_templates.py
│   │   │   │   ├── graph.py
│   │   │   │   ├── prompts.py
│   │   │   │   └── schema.py
│   │   │   ├── index.py
│   │   │   ├── indexes.py
│   │   │   ├── loaders.py
│   │   │   ├── pipelines.py
│   │   │   ├── py.typed
│   │   │   ├── retrievers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── full_text_search.py
│   │   │   │   ├── graphrag.py
│   │   │   │   ├── hybrid_search.py
│   │   │   │   ├── parent_document.py
│   │   │   │   └── self_querying.py
│   │   │   ├── utils.py
│   │   │   └── vectorstores.py
│   │   ├── pyproject.toml
│   │   ├── tests
│   │   │   ├── __init__.py
│   │   │   ├── integration_tests
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_agent_toolkit.py
│   │   │   │   ├── test_cache.py
│   │   │   │   ├── test_chain_example.py
│   │   │   │   ├── test_chat_message_histories.py
│   │   │   │   ├── test_compile.py
│   │   │   │   ├── test_docstore.py
│   │   │   │   ├── test_graphrag.py
│   │   │   │   ├── test_index.py
│   │   │   │   ├── test_indexes.py
│   │   │   │   ├── test_loaders.py
│   │   │   │   ├── test_mmr.py
│   │   │   │   ├── test_mongodb_database.py
│   │   │   │   ├── test_parent_document.py
│   │   │   │   ├── test_retriever_selfquerying.py
│   │   │   │   ├── test_retrievers.py
│   │   │   │   ├── test_retrievers_standard.py
│   │   │   │   ├── test_tools.py
│   │   │   │   ├── test_vectorstore_add_delete.py
│   │   │   │   ├── test_vectorstore_from_documents.py
│   │   │   │   ├── test_vectorstore_from_texts.py
│   │   │   │   └── test_vectorstore_standard.py
│   │   │   ├── unit_tests
│   │   │   │   ├── __init__.py
│   │   │   │   ├── test_cache.py
│   │   │   │   ├── test_chat_message_histories.py
│   │   │   │   ├── test_imports.py
│   │   │   │   ├── test_index.py
│   │   │   │   ├── test_retrievers.py
│   │   │   │   ├── test_tools.py
│   │   │   │   └── test_vectorstores.py
│   │   │   └── utils.py
│   │   └── uv.lock
│   ├── langgraph-checkpoint-mongodb
│   │   ├── CHANGELOG.md
│   │   ├── README.md
│   │   ├── justfile
│   │   ├── langgraph
│   │   │   └── checkpoint
│   │   │       └── mongodb
│   │   │           ├── __init__.py
│   │   │           ├── aio.py
│   │   │           ├── saver.py
│   │   │           └── utils.py
│   │   ├── pyproject.toml
│   │   ├── tests
│   │   │   ├── __snapshots__
│   │   │   │   ├── test_pregel.ambr
│   │   │   │   └── test_pregel_async.ambr
│   │   │   ├── integration_tests
│   │   │   │   ├── README
│   │   │   │   ├── conftest.py
│   │   │   │   ├── test_interrupts.py
│   │   │   │   └── test_sanity.py
│   │   │   └── unit_tests
│   │   │       ├── conftest.py
│   │   │       ├── test_async.py
│   │   │       ├── test_delete_thread.py
│   │   │       └── test_sync.py
│   │   └── uv.lock
│   └── langgraph-store-mongodb
│       ├── CHANGELOG.md
│       ├── README.md
│       ├── justfile
│       ├── langgraph
│       │   └── store
│       │       └── mongodb
│       │           ├── __init__.py
│       │           └── base.py
│       ├── pyproject.toml
│       ├── tests
│       │   ├── integration_tests
│       │   │   └── test_store_semantic.py
│       │   └── unit_tests
│       │       ├── test_store.py
│       │       └── test_store_async.py
│       └── uv.lock
├── pyproject.toml
├── scripts
│   ├── setup_ollama.sh
│   ├── start_local_atlas.sh
│   └── update-locks.sh
└── uv.lock
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | # GitHub Actions
4 | - package-ecosystem: "github-actions"
5 | directory: "/"
6 | schedule:
7 | interval: "weekly"
8 | groups:
9 | actions:
10 | patterns:
11 | - "*"
12 | # Python
13 | - package-ecosystem: "pip"
14 | directory: "libs/langchain-mongodb"
15 | schedule:
16 | interval: "weekly"
17 | - package-ecosystem: "pip"
18 | directory: "libs/langgraph-checkpoint-mongodb"
19 | schedule:
20 | interval: "weekly"
21 | - package-ecosystem: "pip"
22 | directory: "libs/langgraph-store-mongodb"
23 | schedule:
24 | interval: "weekly"
25 |
--------------------------------------------------------------------------------
/.github/scripts/check_diff.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from typing import Dict
4 |
5 | LIB_DIRS = [
6 | "libs/langchain-mongodb",
7 | "libs/langgraph-checkpoint-mongodb",
8 | "libs/langgraph-store-mongodb",
9 | ]
10 |
11 | if __name__ == "__main__":
12 | files = sys.argv[1:] # changed files
13 |
14 | dirs_to_run: Dict[str, set] = {
15 | "lint": set(),
16 | "test": set(),
17 | }
18 |
19 | if len(files) == 300:
20 | # max diff length is 300 files - there are likely files missing
21 | raise ValueError("Max diff reached. Please manually run CI on changed libs.")
22 |
23 | for file in files:
24 | if any(file.startswith(dir_) for dir_ in (".github", "scripts")):
25 | # add all LIB_DIRS for infra changes
26 | dirs_to_run["test"].update(LIB_DIRS)
27 | dirs_to_run["lint"].update(LIB_DIRS)
28 |
29 | if any(file.startswith(dir_) for dir_ in LIB_DIRS):
30 | for dir_ in LIB_DIRS:
31 | if file.startswith(dir_):
32 | dirs_to_run["test"].add(dir_)
33 | dirs_to_run["lint"].add(dir_)
34 | elif file.startswith("libs/"):
35 | raise ValueError(
36 | f"Unknown lib: {file}. check_diff.py likely needs "
37 | "an update for this new library!"
38 | )
39 |
40 | outputs = {
41 | "dirs-to-lint": list(dirs_to_run["lint"]),
42 | "dirs-to-test": list(dirs_to_run["test"]),
43 | }
44 | for key, value in outputs.items():
45 | json_output = json.dumps(value)
46 | print(f"{key}={json_output}") # noqa: T201
47 |
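48 | # Example: ci.yml invokes this script as
49 | #   python .github/scripts/check_diff.py ${FILES} >> $GITHUB_OUTPUT
50 | # so each `key=<json list>` line printed above becomes a step output that
51 | # feeds the `dirs-to-lint` / `dirs-to-test` job matrices.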
--------------------------------------------------------------------------------
/.github/workflows/_lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 |
11 | env:
12 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
13 |
14 | # This env var allows us to get inline annotations when ruff has complaints.
15 | RUFF_OUTPUT_FORMAT: github
16 |
17 | jobs:
18 | build:
19 | name: "run lint #${{ matrix.python-version }}"
20 | runs-on: ubuntu-latest
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | # Only lint on the min and max supported Python versions.
25 | # It's extremely unlikely that there's a lint issue on any version in between
26 | # that doesn't show up on the min or max versions.
27 | #
28 | # GitHub rate-limits how many jobs can be running at any one time.
29 | # Starting new jobs is also relatively slow,
30 | # so linting on fewer versions makes CI faster.
31 | python-version:
32 | - "3.9"
33 | - "3.12"
34 | steps:
35 | - uses: actions/checkout@v4
36 | with:
37 | persist-credentials: false
38 | - name: Install uv
39 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5
40 | with:
41 | enable-cache: true
42 | python-version: ${{ matrix.python-version }}
43 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock"
44 |
45 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
46 |
47 | - name: Install dependencies
48 | working-directory: ${{ inputs.working-directory }}
49 | run: just install
50 |
51 | - name: Get .mypy_cache to speed up mypy
52 | uses: actions/cache@v4
53 | env:
54 | SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2"
55 | with:
56 | path: |
57 | ${{ env.WORKDIR }}/.mypy_cache
58 | key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/uv.lock', inputs.working-directory)) }}
59 |
60 |
61 | - name: Analysing the code with our lint
62 | working-directory: ${{ inputs.working-directory }}
63 | run: just lint
64 |
65 | - name: Checking the types
66 | working-directory: ${{ inputs.working-directory }}
67 | run: just typing
68 |
--------------------------------------------------------------------------------
/.github/workflows/_test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 |
11 | jobs:
12 | build:
13 | defaults:
14 | run:
15 | working-directory: ${{ inputs.working-directory }}
16 | runs-on: ubuntu-latest
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | python-version:
21 | - "3.9"
22 | - "3.12"
23 | name: "run test #${{ matrix.python-version }}"
24 | steps:
25 | - uses: actions/checkout@v4
26 | with:
27 | persist-credentials: false
28 | - name: Install uv
29 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5
30 | with:
31 | enable-cache: true
32 | python-version: ${{ matrix.python-version }}
33 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock"
34 |
35 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
36 |
37 | - name: Install dependencies
38 | shell: bash
39 | run: just install
40 |
41 | - name: Start MongoDB
42 | uses: supercharge/mongodb-github-action@90004df786821b6308fb02299e5835d0dae05d0d # 1.12.0
43 |
44 | - name: Run unit tests
45 | shell: bash
46 | run: just unit_tests
47 |
48 | - name: Start local Atlas
49 | working-directory: .
50 | run: bash scripts/start_local_atlas.sh
51 |
52 | - name: Install Ollama
53 | run: curl -fsSL https://ollama.com/install.sh | sh
54 |
55 | - name: Run Ollama
56 | working-directory: .
57 | run: |
58 | ollama serve &
59 | sleep 5 # wait for the Ollama server to be ready
60 | bash scripts/setup_ollama.sh
61 |
62 | - name: Run integration tests
63 | run: just integration_tests
64 |
65 | - name: Ensure the tests did not create any additional files
66 | shell: bash
67 | run: |
68 | set -eu
69 |
70 | STATUS="$(git status)"
71 | echo "$STATUS"
72 |
73 | # grep will exit non-zero if the target message isn't found,
74 | # and `set -e` above will cause the step to fail.
75 | echo "$STATUS" | grep 'nothing to commit, working tree clean'
76 |
77 | - name: Run unit tests with minimum dependency versions
78 | run: |
79 | uv sync --python=${{ matrix.python-version }} --resolution=lowest-direct
80 | just unit_tests
81 |
--------------------------------------------------------------------------------
/.github/workflows/_test_release.yml:
--------------------------------------------------------------------------------
1 | name: test-release
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 |
11 | env:
12 | PYTHON_VERSION: "3.10"
13 |
14 | jobs:
15 | build:
16 | if: github.ref == 'refs/heads/main'
17 | runs-on: ubuntu-latest
18 | defaults:
19 | run:
20 | working-directory: ${{ inputs.working-directory }}
21 | outputs:
22 | pkg-name: ${{ steps.check-version.outputs.pkg-name }}
23 | version: ${{ steps.check-version.outputs.version }}
24 |
25 | steps:
26 | - uses: actions/checkout@v4
27 | with:
28 | persist-credentials: false
29 | - name: Install uv
30 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5
31 | with:
32 | enable-cache: true
33 | python-version: ${{ env.PYTHON_VERSION }}
34 | cache-dependency-glob: "${{ inputs.working-directory }}/uv.lock"
35 |
36 | # We want to keep this build stage *separate* from the release stage,
37 | # so that there's no sharing of permissions between them.
38 | # The release stage has trusted publishing and GitHub repo contents write access,
39 | # and we want to keep the scope of that access limited just to the release job.
40 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency)
41 | # could get access to our GitHub or PyPI credentials.
42 | #
43 | # Per the trusted publishing GitHub Action:
44 | # > It is strongly advised to separate jobs for building [...]
45 | # > from the publish job.
46 | # https://github.com/pypa/gh-action-pypi-publish#non-goals
47 | - name: Build project for distribution
48 | run: uv build
49 |
50 | - name: Upload build
51 | uses: actions/upload-artifact@v4
52 | with:
53 | name: test-dist
54 | path: ${{ inputs.working-directory }}/dist/
55 |
56 | - name: Check Version
57 | id: check-version
58 | shell: bash
59 | run: |
60 | set -eu
61 | echo pkg-name="$(uvx hatch project metadata | jq -r .name)" >> $GITHUB_OUTPUT
62 | echo version="$(uvx hatch version)" >> $GITHUB_OUTPUT
63 |
64 | publish:
65 | needs:
66 | - build
67 | runs-on: ubuntu-latest
68 | permissions:
69 | # This permission is used for trusted publishing:
70 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/
71 | #
72 | # Trusted publishing has to also be configured on PyPI for each package:
73 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/
74 | id-token: write
75 |
76 | steps:
77 | - uses: actions/checkout@v4
78 | with:
79 | persist-credentials: false
80 | - uses: actions/download-artifact@v4
81 | with:
82 | name: test-dist
83 | path: ${{ inputs.working-directory }}/dist/
84 |
85 | - name: Publish to test PyPI
86 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # release/v1
87 | with:
88 | packages-dir: ${{ inputs.working-directory }}/dist/
89 | verbose: true
90 | print-hash: true
91 | repository-url: https://test.pypi.org/legacy/
92 |
93 | # We overwrite any existing distributions with the same name and version.
94 | # This is *only for CI use* and is *extremely dangerous* otherwise!
95 | # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates
96 | skip-existing: true
97 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0
98 | attestations: false
99 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: CI
3 |
4 | on:
5 | push:
6 | branches: [main]
7 | pull_request:
8 |
9 | # If another push to the same PR or branch happens while this workflow is still running,
10 | # cancel the earlier run in favor of the next run.
11 | #
12 | # There's no point in testing an outdated version of the code. GitHub only allows
13 | # a limited number of job runners to be active at the same time, so it's better to cancel
14 | # pointless jobs early so that more useful jobs can run sooner.
15 | concurrency:
16 | group: ${{ github.workflow }}-${{ github.ref }}
17 | cancel-in-progress: true
18 |
19 | jobs:
20 | build:
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v4
24 | with:
25 | persist-credentials: false
26 | - uses: actions/setup-python@v5
27 | with:
28 | python-version: '3.10'
29 | - id: files
30 | uses: Ana06/get-changed-files@25f79e676e7ea1868813e21465014798211fad8c # v2.3.0
31 | - id: set-matrix
32 | env:
33 | FILES: ${{ steps.files.outputs.all }}
34 | run: |
35 | python .github/scripts/check_diff.py ${FILES} >> $GITHUB_OUTPUT
36 | outputs:
37 | dirs-to-lint: ${{ steps.set-matrix.outputs.dirs-to-lint }}
38 | dirs-to-test: ${{ steps.set-matrix.outputs.dirs-to-test }}
39 | pre-commit:
40 | name: pre-commit
41 | runs-on: ubuntu-latest
42 |
43 | steps:
44 | - uses: actions/checkout@v4
45 | with:
46 | persist-credentials: false
47 | - uses: actions/setup-python@v5
48 | - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
49 | with:
50 | extra_args: --all-files --hook-stage=manual
51 | lint:
52 | name: cd ${{ matrix.working-directory }}
53 | needs: [ build ]
54 | if: ${{ needs.build.outputs.dirs-to-lint != '[]' }}
55 | strategy:
56 | fail-fast: false
57 | matrix:
58 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-lint) }}
59 | uses: ./.github/workflows/_lint.yml
60 | with:
61 | working-directory: ${{ matrix.working-directory }}
62 | secrets: inherit
63 |
64 | test:
65 | name: cd ${{ matrix.working-directory }}
66 | needs: [ build ]
67 | if: ${{ needs.build.outputs.dirs-to-test != '[]' }}
68 | strategy:
69 | fail-fast: false
70 | matrix:
71 | working-directory: ${{ fromJson(needs.build.outputs.dirs-to-test) }}
72 | uses: ./.github/workflows/_test.yml
73 | with:
74 | working-directory: ${{ matrix.working-directory }}
75 | secrets: inherit
76 |
77 | api-docs:
78 | name: API docs
79 | runs-on: ubuntu-latest
80 | steps:
81 | - uses: actions/checkout@v4
82 | with:
83 | persist-credentials: false
84 | - uses: extractions/setup-just@e33e0265a09d6d736e2ee1e0eb685ef1de4669ff # v3
85 | - name: Install uv
86 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5
87 | with:
88 | enable-cache: true
89 | python-version: 3.9
90 | - run: just docs
91 |
92 | ci_success:
93 | name: "CI Success"
94 | needs: [build, lint, test, pre-commit, api-docs]
95 | if: |
96 | always()
97 | runs-on: ubuntu-latest
98 | env:
99 | JOBS_JSON: ${{ toJSON(needs) }}
100 | RESULTS_JSON: ${{ toJSON(needs.*.result) }}
101 | EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}}
102 | steps:
103 | - name: "CI Success"
104 | run: |
105 | echo $JOBS_JSON
106 | echo $RESULTS_JSON
107 | echo "Exiting with $EXIT_CODE"
108 | exit $EXIT_CODE
109 |
--------------------------------------------------------------------------------
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | name: "CodeQL"
2 |
3 | on:
4 | push:
5 | branches: [ "main"]
6 | tags: ['*']
7 | pull_request:
8 | workflow_call:
9 | inputs:
10 | ref:
11 | required: true
12 | type: string
13 | schedule:
14 | - cron: '17 10 * * 2'
15 |
16 | jobs:
17 | analyze:
18 | name: Analyze (${{ matrix.language }})
19 | runs-on: "ubuntu-latest"
20 | timeout-minutes: 360
21 | permissions:
22 | # required for all workflows
23 | security-events: write
24 |
25 | # required to fetch internal or private CodeQL packs
26 | packages: read
27 |
28 | strategy:
29 | fail-fast: false
30 | matrix:
31 | include:
32 | - language: python
33 | build-mode: none
34 | - language: actions
35 | build-mode: none
36 | steps:
37 | - name: Checkout repository
38 | uses: actions/checkout@v4
39 | with:
40 | ref: ${{ inputs.ref }}
41 | persist-credentials: false
42 | - uses: actions/setup-python@v3
43 |
44 | # Initializes the CodeQL tools for scanning.
45 | - name: Initialize CodeQL
46 | uses: github/codeql-action/init@28deaeda66b76a05916b6923827895f2b14ab387 # v3
47 | with:
48 | languages: ${{ matrix.language }}
49 | build-mode: ${{ matrix.build-mode }}
50 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
51 | queries: security-extended
52 | config: |
53 | paths-ignore:
54 | - 'docs/**'
55 | - 'scripts/**'
56 | - '**/tests/**'
57 |
58 | - if: matrix.build-mode == 'manual'
59 | run: |
60 | pip install -e .
61 |
62 | - name: Perform CodeQL Analysis
63 | uses: github/codeql-action/analyze@28deaeda66b76a05916b6923827895f2b14ab387 # v3
64 | with:
65 | category: "/language:${{matrix.language}}"
66 |
--------------------------------------------------------------------------------
/.github/workflows/zizmor.yml:
--------------------------------------------------------------------------------
1 | name: GitHub Actions Security Analysis with zizmor 🌈
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["**"]
8 |
9 | jobs:
10 | zizmor:
11 | name: zizmor latest via PyPI
12 | runs-on: ubuntu-latest
13 | permissions:
14 | security-events: write
15 | # required for workflows in private repositories
16 | contents: read
17 | actions: read
18 | steps:
19 | - name: Checkout repository
20 | uses: actions/checkout@v4
21 | with:
22 | persist-credentials: false
23 |
24 | - name: Install the latest version of uv
25 | uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v5
26 |
27 | - name: Run zizmor 🌈
28 | run: uvx zizmor --format sarif . > results.sarif
29 | env:
30 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
31 |
32 | - name: Upload SARIF file
33 | uses: github/codeql-action/upload-sarif@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3
34 | with:
35 | sarif_file: results.sarif
36 | category: zizmor
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache
3 | .pytest_cache
4 | .ruff_cache
5 | .mypy_cache_test
6 | .env
7 | .venv*
8 | .local_atlas_uri
9 | docs/langchain_mongodb
10 | docs/langgraph_checkpoint_mongodb
11 | docs/index.md
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 |
2 | repos:
3 | - repo: https://github.com/pre-commit/pre-commit-hooks
4 | rev: v4.5.0
5 | hooks:
6 | - id: check-added-large-files
7 | - id: check-case-conflict
8 | - id: check-toml
9 | - id: check-yaml
10 | - id: debug-statements
11 | - id: end-of-file-fixer
12 | - id: forbid-new-submodules
13 | - id: trailing-whitespace
14 | exclude_types: [json]
15 | exclude: |
16 | (?x)^(.*\.ambr)$
17 |
18 | # We use the Python version instead of the original version which seems to require Docker
19 | # https://github.com/koalaman/shellcheck-precommit
20 | - repo: https://github.com/shellcheck-py/shellcheck-py
21 | rev: v0.9.0.6
22 | hooks:
23 | - id: shellcheck
24 | name: shellcheck
25 | args: ["--severity=warning"]
26 | stages: [manual]
27 |
28 | - repo: https://github.com/sirosen/check-jsonschema
29 | rev: 0.29.4
30 | hooks:
31 | - id: check-github-workflows
32 | args: ["--verbose"]
33 |
34 | - repo: https://github.com/codespell-project/codespell
35 | rev: "v2.2.6"
36 | hooks:
37 | - id: codespell
38 | args: ["-L", "nin"]
39 | stages: [manual]
40 |
41 | - repo: https://github.com/astral-sh/ruff-pre-commit
42 | # Ruff version.
43 | rev: v0.8.5
44 | hooks:
45 | # Run the linter.
46 | - id: ruff
47 | args: [ --fix ]
48 | # Run the formatter.
49 | - id: ruff-format
50 |
51 | - repo: https://github.com/tcort/markdown-link-check
52 | rev: v3.12.2
53 | hooks:
54 | - id: markdown-link-check
55 | args: [-q]
56 |
57 | - repo: local
58 | hooks:
59 | - id: update-locks
60 | name: update-locks
61 | entry: bash ./scripts/update-locks.sh
62 | language: python
63 | require_serial: true
64 | fail_fast: true
65 | additional_dependencies:
66 | - uv
67 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-22.04
4 | tools:
5 | python: "3.11"
6 | jobs:
7 | create_environment:
8 | - asdf plugin add uv
9 | - asdf install uv latest
10 | - asdf global uv latest
11 | install:
12 | - uv sync
13 | build:
14 | html:
15 | - uv run sphinx-build -T -b html docs $READTHEDOCS_OUTPUT/html
16 | sphinx:
17 | configuration: docs/conf.py
18 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 |
3 | We welcome contributions to this project! Please follow the guidance below to set up the project for development and start contributing.
4 |
5 | ### Fork and clone the repository
6 |
7 | To contribute to this project, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/exploring-projects-on-github/contributing-to-a-project) workflow. Please do not try to push directly to this repo unless you are a maintainer.
8 |
9 |
10 | ### Dependency Management: uv and other env/dependency managers
11 |
12 | This project utilizes [uv](https://docs.astral.sh/uv/) v0.5.15+ as a dependency manager.
13 |
14 | Install uv: **[documentation on how to install it](https://docs.astral.sh/uv/getting-started/installation/)**.
15 |
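16 | For example, on macOS and Linux you can use the standalone installer from the uv docs:
17 |
18 | ```bash
19 | curl -LsSf https://astral.sh/uv/install.sh | sh
20 | ```
21 |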
16 | ### Local Development Dependencies
17 |
18 | The project configuration and the `justfile` for running dev commands are located under each package directory in `libs/` (for example, `libs/langchain-mongodb`).
19 |
20 | `just` can be [installed](https://just.systems/man/en/packages.html) from many package managers, including `brew`.
21 |
22 | ```bash
23 | cd libs/langchain-mongodb
24 | ```
25 |
26 | Install langchain-mongodb development requirements (for running langchain, running examples, linting, formatting, tests, and coverage):
27 |
28 | ```bash
29 | just install
30 | ```
31 |
32 | Then verify the installation.
33 |
34 | ```bash
35 | just unit_tests
36 | ```
37 |
38 | In order to run the integration tests, you'll also need a `MONGODB_URI` for MongoDB Atlas, as well
39 | as either an `OPENAI_API_KEY` or a configured local version of [ollama](https://ollama.com/download).
40 |
41 | We have a convenience script to start a local Atlas instance, which requires `podman`:
42 |
43 | ```bash
44 | scripts/start_local_atlas.sh
45 | ```
46 |
47 | This will create a `.local_atlas_uri` file that has the `MONGODB_URI` set. The `justfiles` are configured
48 | to read the environment variable from this file.
49 |
50 | If using `ollama`, we have a convenience script to download the library used in our tests:
51 |
52 | ```bash
53 | scripts/setup_ollama.sh
54 | ```
55 |
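56 | Alternatively, you can point the integration tests at any Atlas cluster by exporting the variables yourself (placeholder values shown):
57 |
58 | ```bash
59 | export MONGODB_URI="mongodb+srv://<user>:<password>@<cluster>.mongodb.net/"
60 | export OPENAI_API_KEY="sk-..."  # not needed if you use a local ollama instead
61 | just integration_tests
62 | ```
63 |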
56 | ### Testing
57 |
58 | Unit tests cover modular logic that does not require calls to outside APIs.
59 | If you add new logic, please add a unit test.
60 |
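61 | For example, a new unit test can be as small as the following sketch (the module and helper names here are hypothetical):
62 |
63 | ```python
64 | # tests/unit_tests/test_my_feature.py  (hypothetical file)
65 | from langchain_mongodb.my_module import my_helper  # hypothetical import
66 |
67 |
68 | def test_my_helper_handles_empty_input():
69 |     assert my_helper("") == []
70 | ```
71 |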
61 | To run unit tests:
62 |
63 | ```bash
64 | just unit_tests
65 | ```
66 |
67 | Integration tests cover the end-to-end service calls as much as possible.
68 | However, in certain cases this might not be practical, so you can mock the
69 | service response for these tests. There are examples of this in the repo
70 | that can help you write your own tests. If you have suggestions to improve
71 | this, please raise an issue.
72 |
73 | To run the integration tests:
74 |
75 | ```bash
76 | just integration_tests
77 | ```
78 |
79 | ### Formatting and Linting
80 |
81 | Formatting keeps the code in this repo in a consistent style so that it is
82 | more presentable and readable; style errors are corrected automatically when
83 | you run the formatting command. Linting finds and highlights code errors and
84 | helps avoid coding practices that can lead to bugs.
85 |
86 | Run both of these locally before submitting a PR. The CI scripts will run these
87 | when you submit a PR, and you won't be able to merge changes without fixing
88 | issues identified by the CI.
89 |
90 | We use [pre-commit](https://pypi.org/project/pre-commit/) for automatic
91 | formatting of the codebase.
92 |
93 | To set up `pre-commit` locally, run:
94 |
95 | ```bash
96 | brew install pre-commit
97 | pre-commit install
98 | ```
99 |
100 | #### Manual Linting and Formatting
101 |
102 | Linting and formatting for this project is done via a combination of [ruff](https://docs.astral.sh/ruff/rules/) and [mypy](http://mypy-lang.org/).
103 |
104 | To run lint:
105 |
106 | ```bash
107 | just lint
108 | ```
109 |
110 | To run the type checker:
111 |
112 | ```bash
113 | just typing
114 | ```
115 |
116 | We recognize linting can be annoying - if you do not want to do it, please contact a project maintainer, and they can help you with it. We do not want this to be a blocker for good code getting contributed.
117 |
118 | #### Building the docs
119 |
120 | The docs are built from the top level of the repository by running `just docs`.
121 |
122 | #### Spellcheck
123 |
124 | Spellchecking for this project is done via [codespell](https://github.com/codespell-project/codespell).
125 | Note that `codespell` only finds common typos, so it can produce false positives (flagging correctly spelled but rarely used words) and false negatives (missing genuinely misspelled words).
126 |
127 | To check spelling for this project:
128 |
129 | ```bash
130 | just codespell
131 | ```
132 |
133 | If codespell is incorrectly flagging a word, you can skip spellcheck for that word by adding it to the codespell config in the `.pre-commit-config.yaml` file.
134 |
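135 | For example, to also ignore a hypothetical word `alo` alongside the existing `nin` entry:
136 |
137 | ```yaml
138 | - id: codespell
139 |   args: ["-L", "nin,alo"]
140 | ```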
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 LangChain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🦜️🔗 LangChain MongoDB
2 |
3 | This is a monorepo containing the MongoDB partner packages for LangChain.
4 | It includes integrations between MongoDB, MongoDB Atlas, LangChain, and LangGraph.
5 |
6 | It contains the following packages.
7 |
8 | - `langchain-mongodb` ([PyPI](https://pypi.org/project/langchain-mongodb/))
9 | - `langgraph-checkpoint-mongodb` ([PyPI](https://pypi.org/project/langgraph-checkpoint-mongodb/))
10 | - `langgraph-store-mongodb` ([PyPI](https://pypi.org/project/langgraph-store-mongodb/))
10 |
11 | **Note**: This repository replaces all MongoDB integrations currently present in the `langchain-community` package. Users are encouraged to migrate to this repository as soon as possible.
12 |
13 | ## Features
14 |
15 | ### LangChain
16 |
17 | #### Components
18 |
19 | - [MongoDBAtlasFullTextSearchRetriever](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#full-text-search-retriever)
20 | - [MongoDBAtlasHybridSearchRetriever](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#hybrid-search-retriever)
21 | - [MongoDBAtlasSemanticCache](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#mongodbatlassemanticcache)
22 | - [MongoDBAtlasVectorSearch](https://python.langchain.com/docs/integrations/vectorstores/mongodb_atlas/)
23 | - [MongoDBCache](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/#mongodbcache)
24 | - [MongoDBChatMessageHistory](https://python.langchain.com/docs/integrations/memory/mongodb_chat_message_history/)
25 |
26 | #### API Reference
27 |
28 | - [MongoDBAtlasParentDocumentRetriever](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/retrievers/langchain_mongodb.retrievers.parent_document.MongoDBAtlasParentDocumentRetriever.html#langchain_mongodb.retrievers.parent_document.MongoDBAtlasParentDocumentRetriever)
29 | - [MongoDBAtlasSelfQueryRetriever](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/retrievers/langchain_mongodb.retrievers.self_querying.MongoDBAtlasSelfQueryRetriever.html).
30 | - [MongoDBDatabaseToolkit](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/agent_toolkit/langchain_mongodb.agent_toolkit.toolkit.MongoDBDatabaseToolkit.html)
31 | - [MongoDBDatabase](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/agent_toolkit/langchain_mongodb.agent_toolkit.database.MongoDBDatabase.html#langchain_mongodb.agent_toolkit.database.MongoDBDatabase)
32 | - [MongoDBDocStore](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/docstores/langchain_mongodb.docstores.MongoDBDocStore.html#langchain_mongodb.docstores.MongoDBDocStore)
33 | - [MongoDBGraphStore](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/graphrag/langchain_mongodb.graphrag.graph.MongoDBGraphStore.html)
34 | - [MongoDBLoader](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/loaders/langchain_mongodb.loaders.MongoDBLoader.html#langchain_mongodb.loaders.MongoDBLoader)
35 | - [MongoDBRecordManager](https://langchain-mongodb.readthedocs.io/en/latest/langchain_mongodb/indexes/langchain_mongodb.indexes.MongoDBRecordManager.html#langchain_mongodb.indexes.MongoDBRecordManager)
36 |
37 | ### LangGraph
38 |
39 | - Checkpointing (BaseCheckpointSaver)
40 | - [MongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/saver/langgraph.checkpoint.mongodb.saver.MongoDBSaver.html#mongodbsaver)
41 | - [AsyncMongoDBSaver](https://langchain-mongodb.readthedocs.io/en/latest/langgraph_checkpoint_mongodb/aio/langgraph.checkpoint.mongodb.aio.AsyncMongoDBSaver.html#asyncmongodbsaver)
42 |
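43 | For example, a checkpointer can be attached to a compiled LangGraph graph roughly as follows (a minimal sketch: the connection string is a placeholder and `builder` stands for your own `StateGraph`):
44 |
45 | ```python
46 | from langgraph.checkpoint.mongodb import MongoDBSaver
47 |
48 | with MongoDBSaver.from_conn_string("mongodb://localhost:27017") as checkpointer:
49 |     graph = builder.compile(checkpointer=checkpointer)
50 | ```
51 |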
43 | ## Installation
44 |
45 | You can install the `langchain-mongodb` package from PyPI.
46 |
47 | ```bash
48 | pip install langchain-mongodb
49 | ```
50 |
51 | You can install the `langgraph-checkpoint-mongodb` package from PyPI as well:
52 |
53 | ```bash
54 | pip install langgraph-checkpoint-mongodb
55 | ```
56 |
57 | ## Usage
58 |
59 | See [langchain-mongodb usage](libs/langchain-mongodb/README.md#usage) and [langgraph-checkpoint-mongodb usage](libs/langgraph-checkpoint-mongodb/README.md#usage).
60 |
61 | For more detailed usage examples and documentation, please refer to the [LangChain documentation](https://python.langchain.com/docs/integrations/providers/mongodb_atlas/).
62 |
63 | API docs can be found on [ReadTheDocs](https://langchain-mongodb.readthedocs.io/en/latest/index.html).
64 |
65 | ## Contributing
66 |
67 | See the [Contributing Guide](CONTRIBUTING.md).
68 |
69 | ## License
70 |
71 | This project is licensed under the [MIT License](LICENSE).
72 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # LangChain MongoDB Releases
2 |
3 | ## Prep the JIRA Release
4 |
5 | - Go to the release in [JIRA](https://jira.mongodb.org/projects/INTPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin%3Arelease-page&status=unreleased).
6 |
7 | - Make sure there are no unfinished tickets. Move them to another version if need be.
8 |
9 | - Click on the triple dot icon to the right and select "Edit".
10 |
11 | - Update the description for a quick summary.
12 |
13 | ## Prep the Release
14 |
15 | - Create a PR to bump the version and update the changelog, including today's date.
16 | Bump the minor version for new features, patch for a bug fix.
17 |
18 | - Merge the PR.
19 |
20 | ## Run the Release Workflow
21 |
22 | - Go to the release [workflow](https://github.com/langchain-ai/langchain-mongodb/actions/workflows/_release.yml).
23 |
24 | - Click "Run Workflow".
25 |
26 | - Choose the appropriate library from the dropdown.
27 |
28 | - Click "Run Workflow".
29 |
30 | - The workflow will create the tag, release to PyPI, and create the GitHub Release.
31 |
32 | ## JIRA Release
33 |
34 | - Return to the JIRA release [list](https://jira.mongodb.org/projects/INTPYTHON?selectedItem=com.atlassian.jira.jira-projects-plugin%3Arelease-page&status=unreleased).
35 |
36 | - Click "Save".
37 |
38 | - Click on the triple dot again and select "Release".
39 |
40 | - Enter today's date, and click "Confirm".
41 |
42 | - Click "Release".
43 |
44 |
45 | ## Finish the Release
46 |
47 | - Return to the release action and wait for it to complete successfully.
48 |
49 | - Announce the release on Slack, e.g. "ANN: langchain-mongodb 0.5 with support for GraphRAG."
50 |
--------------------------------------------------------------------------------
/docs/_extensions/gallery_directive.py:
--------------------------------------------------------------------------------
1 | """A directive to generate a gallery of images from structured data.
2 |
3 | Generating a gallery of images that are all the same size is a common
4 | pattern in documentation, and this can be cumbersome if the gallery is
5 | generated programmatically. This directive wraps this particular use-case
6 | in a helper-directive to generate it with a single YAML configuration file.
7 |
8 | It currently exists for maintainers of the pydata-sphinx-theme,
9 | but might be abstracted into a standalone package if it proves useful.
10 | """
11 |
12 | from pathlib import Path
13 | from typing import Any, ClassVar, Dict, List
14 |
15 | from docutils import nodes
16 | from docutils.parsers.rst import directives
17 | from sphinx.application import Sphinx
18 | from sphinx.util import logging
19 | from sphinx.util.docutils import SphinxDirective
20 | from yaml import safe_load
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | TEMPLATE_GRID = """
26 | `````{{grid}} {columns}
27 | {options}
28 |
29 | {content}
30 |
31 | `````
32 | """
33 |
34 | GRID_CARD = """
35 | ````{{grid-item-card}} {title}
36 | {options}
37 |
38 | {content}
39 | ````
40 | """
41 |
42 |
43 | class GalleryGridDirective(SphinxDirective):
44 | """A directive to show a gallery of images and links in a Bootstrap grid.
45 |
46 | The grid can be generated from a YAML file that contains a list of items, or
47 | from the content of the directive (also formatted in YAML). Use the parameter
48 | "class-card" to add an additional CSS class to all cards. When specifying the grid
49 | items, you can use all parameters from "grid-item-card" directive to customize
50 | individual cards + ["image", "header", "content", "title"].
51 |
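52 | An illustrative item, as it might appear in the YAML file or directive content (these keys are examples, not an exhaustive schema):
53 |
54 |     - title: My project
55 |       image: /_static/favicon.png
56 |       content: A short description shown on the card.
57 |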
52 | Danger:
53 | This directive can only be used in the context of a Myst documentation page as
54 | the templates use Markdown flavored formatting.
55 | """
56 |
57 | name = "gallery-grid"
58 | has_content = True
59 | required_arguments = 0
60 | optional_arguments = 1
61 | final_argument_whitespace = True
62 | option_spec: ClassVar[dict[str, Any]] = {
63 | # A class to be added to the resulting container
64 | "grid-columns": directives.unchanged,
65 | "class-container": directives.unchanged,
66 | "class-card": directives.unchanged,
67 | }
68 |
69 | def run(self) -> List[nodes.Node]:
70 | """Create the gallery grid."""
71 | if self.arguments:
72 | # If an argument is given, assume it's a path to a YAML file
73 | # Parse it and load it into the directive content
74 | path_data_rel = Path(self.arguments[0])
75 | path_doc, _ = self.get_source_info()
76 | path_doc = Path(path_doc).parent
77 | path_data = (path_doc / path_data_rel).resolve()
78 | if not path_data.exists():
79 | logger.info(f"Could not find grid data at {path_data}.")
80 | # Report the missing file in the rendered output rather than discarding the message node.
81 | return [nodes.paragraph(text=f"No grid data found at {path_data}.")]
82 | yaml_string = path_data.read_text()
83 | else:
84 | yaml_string = "\n".join(self.content)
85 |
86 | # Use all the element with an img-bottom key as sites to show
87 | # and generate a card item for each of them
88 | grid_items = []
89 | for item in safe_load(yaml_string):
90 | # remove parameters that are not needed for the card options
91 | title = item.pop("title", "")
92 |
93 | # build the content of the card using some extra parameters
94 | header = f"{item.pop('header')} \n^^^ \n" if "header" in item else ""
95 | image = f"}) \n" if "image" in item else ""
96 | content = f"{item.pop('content')} \n" if "content" in item else ""
97 |
98 | # optional parameter that influence all cards
99 | if "class-card" in self.options:
100 | item["class-card"] = self.options["class-card"]
101 |
102 | loc_options_str = "\n".join(f":{k}: {v}" for k, v in item.items()) + " \n"
103 |
104 | card = GRID_CARD.format(
105 | options=loc_options_str, content=header + image + content, title=title
106 | )
107 | grid_items.append(card)
108 |
109 | # Parse the template with Sphinx Design to create an output container
110 | # Prep the options for the template grid
111 | class_ = "gallery-directive" + f' {self.options.get("class-container", "")}'
112 | options = {"gutter": 2, "class-container": class_}
113 | options_str = "\n".join(f":{k}: {v}" for k, v in options.items())
114 |
115 | # Create the directive string for the grid
116 | grid_directive = TEMPLATE_GRID.format(
117 | columns=self.options.get("grid-columns", "1 2 3 4"),
118 | options=options_str,
119 | content="\n".join(grid_items),
120 | )
121 |
122 | # Parse content as a directive so Sphinx Design processes it
123 | container = nodes.container()
124 | self.state.nested_parse([grid_directive], 0, container)
125 |
126 | # Sphinx Design outputs a container too, so just use that
127 | return [container.children[0]]
128 |
129 |
130 | def setup(app: Sphinx) -> Dict[str, Any]:
131 | """Add custom configuration to sphinx app.
132 |
133 | Args:
134 | app: the Sphinx application
135 |
136 | Returns:
137 | the parallel read/write safety flags, both set to ``True``.
138 | """
139 | app.add_directive("gallery-grid", GalleryGridDirective)
140 |
141 | return {
142 | "parallel_read_safe": True,
143 | "parallel_write_safe": True,
144 | }
145 |
--------------------------------------------------------------------------------
/docs/_static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/docs/_static/favicon.png
--------------------------------------------------------------------------------
/docs/_static/wordmark-api-dark.svg:
--------------------------------------------------------------------------------
1 |
12 |
--------------------------------------------------------------------------------
/docs/_static/wordmark-api.svg:
--------------------------------------------------------------------------------
1 |
12 |
--------------------------------------------------------------------------------
/docs/templates/class.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block methods %}
9 | {% if methods %}
10 | .. rubric:: {{ _('Methods') }}
11 |
12 | .. autosummary::
13 | {% for item in methods %}
14 | ~{{ item }}
15 | {%- endfor %}
16 |
17 | {% for item in methods %}
18 | .. automethod:: {{ item }}
19 | {%- endfor %}
20 |
21 | {% endif %}
22 | {% endblock %}
23 |
--------------------------------------------------------------------------------
/docs/templates/enum.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block attributes %}
9 | {% for item in attributes %}
10 | .. autoattribute:: {{ item }}
11 | {% endfor %}
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/docs/templates/function.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autofunction:: {{ objname }}
7 |
--------------------------------------------------------------------------------
/docs/templates/langchain_docs.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
9 |
10 |
11 | Docs
12 |
13 |
--------------------------------------------------------------------------------
/docs/templates/pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autopydantic_model:: {{ objname }}
7 | :model-show-json: False
8 | :model-show-config-summary: False
9 | :model-show-validator-members: False
10 | :model-show-field-summary: False
11 | :field-signature-prefix: param
12 | :members:
13 | :undoc-members:
14 | :inherited-members:
15 | :member-order: groupwise
16 | :show-inheritance: True
17 | :special-members: __call__
18 | :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, model_extra, model_fields_set, model_json_schema
19 |
20 |
21 | {% block attributes %}
22 | {% endblock %}
23 |
--------------------------------------------------------------------------------
/docs/templates/runnable_non_pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | .. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface `. 🏃
9 |
10 | The :py:class:`Runnable Interface ` has additional methods that are available on runnables, such as :py:meth:`with_types `, :py:meth:`with_retry `, :py:meth:`assign `, :py:meth:`bind `, :py:meth:`get_graph `, and more.
11 |
12 | {% block attributes %}
13 | {% if attributes %}
14 | .. rubric:: {{ _('Attributes') }}
15 |
16 | .. autosummary::
17 | {% for item in attributes %}
18 | ~{{ item }}
19 | {%- endfor %}
20 | {% endif %}
21 | {% endblock %}
22 |
23 | {% block methods %}
24 | {% if methods %}
25 | .. rubric:: {{ _('Methods') }}
26 |
27 | .. autosummary::
28 | {% for item in methods %}
29 | ~{{ item }}
30 | {%- endfor %}
31 |
32 | {% for item in methods %}
33 | .. automethod:: {{ item }}
34 | {%- endfor %}
35 |
36 | {% endif %}
37 | {% endblock %}
38 |
--------------------------------------------------------------------------------
/docs/templates/runnable_pydantic.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autopydantic_model:: {{ objname }}
7 | :model-show-json: False
8 | :model-show-config-summary: False
9 | :model-show-validator-members: False
10 | :model-show-field-summary: False
11 | :field-signature-prefix: param
12 | :members:
13 | :undoc-members:
14 | :inherited-members:
15 | :member-order: groupwise
16 | :show-inheritance: True
17 | :special-members: __call__
18 | :exclude-members: construct, copy, dict, from_orm, parse_file, parse_obj, parse_raw, schema, schema_json, update_forward_refs, validate, json, is_lc_serializable, to_json_not_implemented, lc_secrets, lc_attributes, lc_id, get_lc_namespace, astream_log, transform, atransform, get_output_schema, get_prompts, config_schema, map, pick, pipe, InputType, OutputType, config_specs, output_schema, get_input_schema, get_graph, get_name, input_schema, name, assign, as_tool, get_config_jsonschema, get_input_jsonschema, get_output_jsonschema, model_construct, model_copy, model_dump, model_dump_json, model_parametrized_name, model_post_init, model_rebuild, model_validate, model_validate_json, model_validate_strings, to_json, model_extra, model_fields_set, model_json_schema, predict, apredict, predict_messages, apredict_messages, generate, generate_prompt, agenerate, agenerate_prompt, call_as_llm
19 |
20 | .. NOTE:: {{objname}} implements the standard :py:class:`Runnable Interface `. 🏃
21 |
22 | The :py:class:`Runnable Interface ` has additional methods that are available on runnables, such as :py:meth:`with_types `, :py:meth:`with_retry `, :py:meth:`assign `, :py:meth:`bind `, :py:meth:`get_graph `, and more.
23 |
--------------------------------------------------------------------------------
/docs/templates/typeddict.rst:
--------------------------------------------------------------------------------
1 | {{ objname }}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block attributes %}
9 | {% for item in attributes %}
10 | .. autoattribute:: {{ item }}
11 | {% endfor %}
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 |
3 | # Default target executed when no arguments are given.
4 | [private]
5 | default:
6 | @just --list
7 |
8 | docs:
9 | uv run sphinx-build -T -b html docs docs/_build/html
10 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ---
4 |
5 | ## Changes in version 0.6.2 (2025/05/12)
6 |
7 | - Improve robustness of `MongoDBDatabase.run`.
8 |
9 | ## Changes in version 0.6.1 (2025/04/16)
10 |
11 | - Fixed a syntax error in a docstring.
12 | - Fixed some incorrect typings.
13 | - Added detail to README.
14 | - Added MongoDBDocStore to README.
15 |
16 | ## Changes in version 0.6 (2025/03/26)
17 |
18 | - Added Natural language to MQL Database tool.
19 | - Added `MongoDBAtlasSelfQueryRetriever`.
20 | - Added logic for vector stores to optionally create vector search indexes.
21 | - Added `close()` methods to classes to ensure proper cleanup of resources.
22 | - Changed the default `batch_size` to 100 to align with resource constraints on
23 | AI model APIs.
24 |
25 | ## Changes in version 0.5 (2025/02/25)
26 |
27 | - Added GraphRAG support via `MongoDBGraphStore`.
28 |
29 | ## Changes in version 0.4 (2025/01/09)
30 |
31 | - Added support for `MongoDBRecordManager`.
32 | - Added support for `MongoDBLoader`.
33 | - Added support for `numpy 2.0`.
34 | - Added zizmor GitHub Actions security scanning.
35 | - Added local LLM support for testing.
36 |
37 | ## Changes in version 0.3 (2024/12/13)
38 |
39 | - Added support for `MongoDBAtlasParentDocumentRetriever`.
40 | - Migrated to https://github.com/langchain-ai/langchain-mongodb.
41 |
42 | ## Changes in version 0.2 (2024/09/13)
43 |
44 | - Added support for `MongoDBAtlasFullTextSearchRetriever` and `MongoDBAtlasHybridSearchRetriever`.
45 |
46 | ## Changes in version 0.1 (2024/02/29)
47 |
48 | - Initial release, added support for `MongoDBAtlasVectorSearch`.
49 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 LangChain, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/README.md:
--------------------------------------------------------------------------------
3 | # langchain-mongodb
4 |
5 | # Installation
6 | ```
7 | pip install -U langchain-mongodb
8 | ```
9 |
10 | # Usage
11 | - [Integrate Atlas Vector Search with LangChain](https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/langchain/#get-started-with-the-langchain-integration) for a walkthrough on using your first LangChain implementation with MongoDB Atlas.
12 |
13 | ## Using MongoDBAtlasVectorSearch
14 | ```python
15 | import os
16 | from langchain_mongodb import MongoDBAtlasVectorSearch
17 | from langchain_openai import OpenAIEmbeddings
18 |
19 | # Pull MongoDB Atlas URI from environment variables
20 | MONGODB_ATLAS_CONNECTION_STRING = os.environ["MONGODB_CONNECTION_STRING"]
21 | DB_NAME = "langchain_db"
22 | COLLECTION_NAME = "test"
23 | VECTOR_SEARCH_INDEX_NAME = "index_name"
24 |
25 | MODEL_NAME = "text-embedding-3-large"
26 | OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
27 |
28 |
29 | vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
30 | connection_string=MONGODB_ATLAS_CONNECTION_STRING,
31 | namespace=DB_NAME + "." + COLLECTION_NAME,
32 | embedding=OpenAIEmbeddings(model=MODEL_NAME),
33 | index_name=VECTOR_SEARCH_INDEX_NAME,
34 | )
35 |
36 | retrieved_docs = vectorstore.similarity_search(
37 | "How do I deploy MongoDBAtlasVectorSearch in our production environment?")
38 | ```
39 |
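40 | The vector store can also be wrapped as a retriever for use in chains; a minimal sketch using standard LangChain methods:
41 |
42 | ```python
43 | retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
44 | docs = retriever.invoke("How do I deploy MongoDBAtlasVectorSearch in our production environment?")
45 | ```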
--------------------------------------------------------------------------------
/libs/langchain-mongodb/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 | set dotenv-load
3 | set dotenv-filename := "../../.local_atlas_uri"
4 |
5 | # Default target executed when no arguments are given.
6 | [private]
7 | default:
8 | @just --list
9 |
10 | install:
11 | uv sync --frozen
12 |
13 | [group('test')]
14 | integration_tests *args="":
15 | uv run pytest tests/integration_tests/ {{args}}
16 |
17 | [group('test')]
18 | unit_tests *args="":
19 | uv run pytest tests/unit_tests {{args}}
20 |
21 | [group('test')]
22 | tests *args="":
23 | uv run pytest {{args}}
24 |
25 | [group('test')]
26 | test_watch filename:
27 | uv run ptw --snapshot-update --now . -- -vv {{filename}}
28 |
29 | [group('lint')]
30 | lint:
31 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files
32 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files
33 |
34 | [group('lint')]
35 | typing:
36 | uv run mypy --install-types --non-interactive .
37 |
38 | [group('lint')]
39 | codespell:
40 | git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files
41 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Integrate your operational database and vector search in a single, unified,
3 | fully managed platform with full vector database capabilities on MongoDB Atlas.
4 |
5 |
6 | Store your operational data, metadata, and vector embeddings in our VectorStore,
7 | MongoDBAtlasVectorSearch.
8 | Insert into a Chain via a Vector, FullText, or Hybrid Retriever.
9 | """
10 |
11 | from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
12 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
13 | from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
14 |
15 | __all__ = [
16 | "MongoDBAtlasVectorSearch",
17 | "MongoDBChatMessageHistory",
18 | "MongoDBCache",
19 | "MongoDBAtlasSemanticCache",
20 | ]
21 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/__init__.py:
--------------------------------------------------------------------------------
1 | from .database import MongoDBDatabase
2 | from .prompt import MONGODB_AGENT_SYSTEM_PROMPT
3 | from .toolkit import MongoDBDatabaseToolkit
4 |
5 | __all__ = ["MongoDBDatabaseToolkit", "MongoDBDatabase", "MONGODB_AGENT_SYSTEM_PROMPT"]
6 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/prompt.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 |
4 | MONGODB_AGENT_SYSTEM_PROMPT = """You are an agent designed to interact with a MongoDB database.
5 | Given an input question, create a syntactically correct MongoDB query to run, then look at the results of the query and return the answer.
6 | Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
7 | You can order the results by a relevant field to return the most interesting examples in the database.
8 | Never query for all the fields from a specific collection, only ask for the relevant fields given the question.
9 |
10 | You have access to tools for interacting with the database.
11 | Only use the below tools. Only use the information returned by the below tools to construct your final answer.
12 | You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
13 |
14 | DO NOT make any update, insert, or delete operations.
15 |
16 | The query MUST include the collection name and the contents of the aggregation pipeline.
17 |
18 | An example query looks like:
19 |
20 | ```python
21 | db.Invoice.aggregate([ {{ "$group": {{ _id: "$BillingCountry", "totalSpent": {{ "$sum": "$Total" }} }} }}, {{ "$sort": {{ "totalSpent": -1 }} }}, {{ "$limit": 5 }} ])
22 | ```
23 |
24 | To start you should ALWAYS look at the collections in the database to see what you can query.
25 | Do NOT skip this step.
26 | Then you should query the schema of the most relevant collections."""
27 |
28 | MONGODB_SUFFIX = """Begin!
29 |
30 | Question: {input}
31 | Thought: I should look at the collections in the database to see what I can query. Then I should query the schema of the most relevant collections.
32 | {agent_scratchpad}"""
33 |
34 | MONGODB_FUNCTIONS_SUFFIX = """I should look at the collections in the database to see what I can query. Then I should query the schema of the most relevant collections."""
35 |
36 |
37 | MONGODB_QUERY_CHECKER = """
38 | {query}
39 |
40 | Double check the MongoDB query above for common mistakes, including:
41 | - Missing content in the aggregation pipeline
42 | - Improperly quoting identifiers
43 | - Improperly quoting operators
44 | - The content in the aggregation pipeline is not valid JSON
45 |
46 | If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
47 |
48 | Output the final MongoDB query only.
49 |
50 | MongoDB Query: """
51 |
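Because the example query in the prompt escapes its braces as `{{ }}`, the only template variable left is `{top_k}`; filling it with `str.format` yields a ready-to-use system message:

```python
from langchain_mongodb.agent_toolkit.prompt import MONGODB_AGENT_SYSTEM_PROMPT

# Limit agent queries to at most five results.
system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
```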
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/tool.py:
--------------------------------------------------------------------------------
1 | """Tools for interacting with a MongoDB database."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any, Dict, Optional, Type
6 |
7 | from langchain_core.callbacks import (
8 | CallbackManagerForToolRun,
9 | )
10 | from langchain_core.language_models import BaseLanguageModel
11 | from langchain_core.prompts import PromptTemplate
12 | from langchain_core.tools import BaseTool
13 | from pydantic import BaseModel, ConfigDict, Field, model_validator
14 | from pymongo.cursor import Cursor
15 |
16 | from .database import MongoDBDatabase
17 | from .prompt import MONGODB_QUERY_CHECKER
18 |
19 |
20 | class BaseMongoDBDatabaseTool(BaseModel):
21 | """Base tool for interacting with a MongoDB database."""
22 |
23 | db: MongoDBDatabase = Field(exclude=True)
24 |
25 | model_config = ConfigDict(
26 | arbitrary_types_allowed=True,
27 | )
28 |
29 |
30 | class _QueryMongoDBDatabaseToolInput(BaseModel):
31 | query: str = Field(..., description="A detailed and correct MongoDB query.")
32 |
33 |
34 | class QueryMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override]
35 | """Tool for querying a MongoDB database."""
36 |
37 | name: str = "mongodb_query"
38 | description: str = """
39 | Execute a MongoDB query against the database and get back the result.
40 | If the query is not correct, an error message will be returned.
41 | If an error is returned, rewrite the query, check the query, and try again.
42 | """
43 | args_schema: Type[BaseModel] = _QueryMongoDBDatabaseToolInput
44 |
45 | def _run(self, query: str, **kwargs: Any) -> str | Cursor:
46 | """Execute the query, return the results or an error message."""
47 | return self.db.run_no_throw(query)
48 |
49 |
50 | class _InfoMongoDBDatabaseToolInput(BaseModel):
51 | collection_names: str = Field(
52 | ...,
53 | description=(
54 | "A comma-separated list of the collection names for which to return the schema. "
55 | "Example input: 'collection1, collection2, collection3'"
56 | ),
57 | )
58 |
59 |
60 | class InfoMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override]
61 | """Tool for getting metadata about a MongoDB database."""
62 |
63 | name: str = "mongodb_schema"
64 | description: str = (
65 | "Get the schema and sample documents for the specified MongoDB collections."
66 | )
67 | args_schema: Type[BaseModel] = _InfoMongoDBDatabaseToolInput
68 |
69 | def _run(
70 | self,
71 | collection_names: str,
72 | run_manager: Optional[CallbackManagerForToolRun] = None,
73 | ) -> str:
74 | """Get the schema for collections in a comma-separated list."""
75 | return self.db.get_collection_info_no_throw(
76 | [t.strip() for t in collection_names.split(",")]
77 | )
78 |
79 |
80 | class _ListMongoDBDatabaseToolInput(BaseModel):
81 | tool_input: str = Field("", description="An empty string")
82 |
83 |
84 | class ListMongoDBDatabaseTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override]
85 | """Tool for getting collection names."""
86 |
87 | name: str = "mongodb_list_collections"
88 | description: str = "Input is an empty string, output is a comma-separated list of collections in the database."
89 | args_schema: Type[BaseModel] = _ListMongoDBDatabaseToolInput
90 |
91 | def _run(
92 | self,
93 | tool_input: str = "",
94 | run_manager: Optional[CallbackManagerForToolRun] = None,
95 | ) -> str:
96 | """Get a comma-separated list of collection names."""
97 | return ", ".join(self.db.get_usable_collection_names())
98 |
99 |
100 | class _QueryMongoDBCheckerToolInput(BaseModel):
101 |     query: str = Field(..., description="A detailed and correct MongoDB query to be checked.")
102 |
103 |
104 | class QueryMongoDBCheckerTool(BaseMongoDBDatabaseTool, BaseTool): # type: ignore[override, override]
105 | """Use an LLM to check if a query is correct.
106 | Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
107 |
108 | template: str = MONGODB_QUERY_CHECKER
109 | llm: BaseLanguageModel
110 | prompt: PromptTemplate = Field(init=False)
111 | name: str = "mongodb_query_checker"
112 | description: str = """
113 | Use this tool to double check if your query is correct before executing it.
114 | Always use this tool before executing a query with mongodb_query!
115 | """
116 | args_schema: Type[BaseModel] = _QueryMongoDBCheckerToolInput
117 |
118 | @model_validator(mode="before")
119 | @classmethod
120 | def initialize_prompt(cls, values: Dict[str, Any]) -> Any:
121 | if "prompt" not in values:
122 | values["prompt"] = PromptTemplate(
123 | template=MONGODB_QUERY_CHECKER, input_variables=["query"]
124 | )
125 |
126 | if values["prompt"].input_variables != ["query"]:
127 | raise ValueError(
128 | "Prompt for QueryCheckerTool must have input variables ['query']"
129 | )
130 |
131 | return values
132 |
133 | def _run(
134 | self,
135 | query: str,
136 | run_manager: Optional[CallbackManagerForToolRun] = None,
137 | ) -> str:
138 | """Use the LLM to check the query."""
139 | # TODO: check the query using pymongo first.
140 | chain = self.prompt | self.llm
141 | return chain.invoke(query) # type:ignore[arg-type]
142 |
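The tools can also be exercised directly, outside an agent. A sketch against a hypothetical local database; the query string follows the `db.<collection>.aggregate([...])` form shown in the system prompt:

```python
from langchain_mongodb.agent_toolkit.database import MongoDBDatabase
from langchain_mongodb.agent_toolkit.tool import (
    ListMongoDBDatabaseTool,
    QueryMongoDBDatabaseTool,
)

db = MongoDBDatabase.from_connection_string("mongodb://localhost:27017/chinook")

# List collections, then run a read-only aggregation.
print(ListMongoDBDatabaseTool(db=db).invoke(""))
print(QueryMongoDBDatabaseTool(db=db).invoke('db.Invoice.aggregate([{"$limit": 1}])'))
```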
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/toolkit.py:
--------------------------------------------------------------------------------
1 | """Toolkit for interacting with an MongoDB database."""
2 |
3 | from typing import List
4 |
5 | from langchain_core.caches import BaseCache as BaseCache
6 | from langchain_core.callbacks import Callbacks as Callbacks
7 | from langchain_core.language_models import BaseLanguageModel
8 | from langchain_core.tools import BaseTool
9 | from langchain_core.tools.base import BaseToolkit
10 | from pydantic import ConfigDict, Field
11 |
12 | from .database import MongoDBDatabase
13 | from .tool import (
14 | InfoMongoDBDatabaseTool,
15 | ListMongoDBDatabaseTool,
16 | QueryMongoDBCheckerTool,
17 | QueryMongoDBDatabaseTool,
18 | )
19 |
20 |
21 | class MongoDBDatabaseToolkit(BaseToolkit):
22 | """MongoDBDatabaseToolkit for interacting with MongoDB databases.
23 |
24 | Setup:
25 | Install ``langchain-mongodb``.
26 |
27 | .. code-block:: bash
28 |
29 | pip install -U langchain-mongodb
30 |
31 | Key init args:
32 | db: MongoDBDatabase
33 | The MongoDB database.
34 | llm: BaseLanguageModel
35 | The language model (for use with QueryMongoDBCheckerTool)
36 |
37 | Instantiate:
38 | .. code-block:: python
39 |
40 | from langchain_mongodb.agent_toolkit.toolkit import MongoDBDatabaseToolkit
41 | from langchain_mongodb.agent_toolkit.database import MongoDBDatabase
42 | from langchain_openai import ChatOpenAI
43 |
44 | db = MongoDBDatabase.from_connection_string("mongodb://localhost:27017/chinook")
45 | llm = ChatOpenAI(temperature=0)
46 |
47 | toolkit = MongoDBDatabaseToolkit(db=db, llm=llm)
48 |
49 | Tools:
50 | .. code-block:: python
51 |
52 | toolkit.get_tools()
53 |
54 | Use within an agent:
55 | .. code-block:: python
56 |
57 | from langchain import hub
58 | from langgraph.prebuilt import create_react_agent
59 | from langchain_mongodb.agent_toolkit import MONGODB_AGENT_SYSTEM_PROMPT
60 |
61 | # Pull prompt (or define your own)
62 | system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
63 |
64 | # Create agent
65 | agent_executor = create_react_agent(
66 | llm, toolkit.get_tools(), state_modifier=system_message
67 | )
68 |
69 | # Query agent
70 | example_query = "Which country's customers spent the most?"
71 |
72 | events = agent_executor.stream(
73 | {"messages": [("user", example_query)]},
74 | stream_mode="values",
75 | )
76 | for event in events:
77 | event["messages"][-1].pretty_print()
78 | """ # noqa: E501
79 |
80 | db: MongoDBDatabase = Field(exclude=True)
81 | llm: BaseLanguageModel = Field(exclude=True)
82 |
83 | model_config = ConfigDict(
84 | arbitrary_types_allowed=True,
85 | )
86 |
87 | def get_tools(self) -> List[BaseTool]:
88 | """Get the tools in the toolkit."""
89 | list_mongodb_database_tool = ListMongoDBDatabaseTool(db=self.db)
90 | info_mongodb_database_tool_description = (
91 | "Input to this tool is a comma-separated list of collections, output is the "
92 | "schema and sample rows for those collections. "
93 |         "Be sure that the collections actually exist by calling "
94 | f"{list_mongodb_database_tool.name} first! "
95 | "Example Input: collection1, collection2, collection3"
96 | )
97 | info_mongodb_database_tool = InfoMongoDBDatabaseTool(
98 | db=self.db, description=info_mongodb_database_tool_description
99 | )
100 | query_mongodb_database_tool_description = (
101 | "Input to this tool is a detailed and correct MongoDB query, output is a "
102 | "result from the database. If the query is not correct, an error message "
103 | "will be returned. If an error is returned, rewrite the query, check the "
104 | "query, and try again. If you encounter an issue with Unknown column "
105 | f"'xxxx' in 'field list', use {info_mongodb_database_tool.name} "
106 |         "to query the correct collection fields."
107 | )
108 | query_mongodb_database_tool = QueryMongoDBDatabaseTool(
109 | db=self.db, description=query_mongodb_database_tool_description
110 | )
111 | query_mongodb_checker_tool_description = (
112 | "Use this tool to double check if your query is correct before executing "
113 | "it. Always use this tool before executing a query with "
114 | f"{query_mongodb_database_tool.name}!"
115 | )
116 | query_mongodb_checker_tool = QueryMongoDBCheckerTool(
117 | db=self.db, llm=self.llm, description=query_mongodb_checker_tool_description
118 | )
119 | return [
120 | query_mongodb_database_tool,
121 | info_mongodb_database_tool,
122 | list_mongodb_database_tool,
123 | query_mongodb_checker_tool,
124 | ]
125 |
126 | def get_context(self) -> dict:
127 | """Return db context that you may want in agent prompt."""
128 | return self.db.get_context()
129 |
130 |
131 | MongoDBDatabaseToolkit.model_rebuild()
132 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from importlib.metadata import version
4 | from typing import Dict, List, Optional
5 |
6 | from langchain_core.chat_history import BaseChatMessageHistory
7 | from langchain_core.messages import (
8 | BaseMessage,
9 | message_to_dict,
10 | messages_from_dict,
11 | )
12 | from pymongo import MongoClient, errors
13 | from pymongo.driver_info import DriverInfo
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | DEFAULT_DBNAME = "chat_history"
18 | DEFAULT_COLLECTION_NAME = "message_store"
19 | DEFAULT_SESSION_ID_KEY = "SessionId"
20 | DEFAULT_HISTORY_KEY = "History"
21 |
22 |
23 | class MongoDBChatMessageHistory(BaseChatMessageHistory):
24 | """Chat message history that stores history in MongoDB.
25 |
26 | Setup:
27 | Install ``langchain-mongodb`` python package.
28 |
29 | .. code-block:: bash
30 |
31 | pip install langchain-mongodb
32 |
33 | Instantiate:
34 | .. code-block:: python
35 |
36 | from langchain_mongodb import MongoDBChatMessageHistory
37 |
38 |
39 | history = MongoDBChatMessageHistory(
40 | connection_string="mongodb://your-host:your-port/", # mongodb://localhost:27017/
41 | session_id = "your-session-id",
42 | )
43 |
44 | Add and retrieve messages:
45 | .. code-block:: python
46 |
47 | # Add single message
48 | history.add_message(message)
49 |
50 | # Add batch messages
51 | history.add_messages([message1, message2, message3, ...])
52 |
53 | # Add human message
54 | history.add_user_message(human_message)
55 |
56 | # Add ai message
57 | history.add_ai_message(ai_message)
58 |
59 | # Retrieve messages
60 | messages = history.messages
61 | """ # noqa: E501
62 |
63 | def __init__(
64 | self,
65 | connection_string: Optional[str],
66 | session_id: str,
67 | database_name: str = DEFAULT_DBNAME,
68 | collection_name: str = DEFAULT_COLLECTION_NAME,
69 | *,
70 | session_id_key: str = DEFAULT_SESSION_ID_KEY,
71 | history_key: str = DEFAULT_HISTORY_KEY,
72 | create_index: bool = True,
73 | history_size: Optional[int] = None,
74 | index_kwargs: Optional[Dict] = None,
75 | client: Optional[MongoClient] = None,
76 | ):
77 | """Initialize with a MongoDBChatMessageHistory instance.
78 |
79 | Args:
80 | connection_string: Optional[str]
81 | connection string to connect to MongoDB. Can be None if mongo_client is
82 | provided.
83 | session_id: str
84 | arbitrary key that is used to store the messages of
85 | a single chat session.
86 | database_name: Optional[str]
87 | name of the database to use.
88 | collection_name: Optional[str]
89 | name of the collection to use.
90 | session_id_key: Optional[str]
91 | name of the field that stores the session id.
92 | history_key: Optional[str]
93 | name of the field that stores the chat history.
94 | create_index: Optional[bool]
95 | whether to create an index on the session id field.
96 | history_size: Optional[int]
97 | count of (most recent) messages to fetch from MongoDB.
98 | index_kwargs: Optional[Dict]
99 | additional keyword arguments to pass to the index creation.
100 | client: Optional[MongoClient]
101 | an existing MongoClient instance.
102 | If provided, connection_string is ignored.
103 | """
104 | self.session_id = session_id
105 | self.database_name = database_name
106 | self.collection_name = collection_name
107 | self.session_id_key = session_id_key
108 | self.history_key = history_key
109 | self.history_size = history_size
110 |
111 | if client:
112 | if connection_string:
113 | raise ValueError("Must provide connection_string or client, not both")
114 | self.client = client
115 | elif connection_string:
116 | try:
117 | self.client = MongoClient(
118 | connection_string,
119 | driver=DriverInfo(
120 | name="Langchain", version=version("langchain-mongodb")
121 | ),
122 | )
123 | except errors.ConnectionFailure as error:
124 | logger.error(error)
125 | else:
126 | raise ValueError("Either connection_string or client must be provided")
127 |
128 | self.db = self.client[database_name]
129 | self.collection = self.db[collection_name]
130 |
131 | if create_index:
132 | index_kwargs = index_kwargs or {}
133 | self.collection.create_index(self.session_id_key, **index_kwargs)
134 |
135 | @property
136 | def messages(self) -> List[BaseMessage]: # type: ignore
137 | """Retrieve the messages from MongoDB"""
138 | try:
139 | if self.history_size is None:
140 | cursor = self.collection.find({self.session_id_key: self.session_id})
141 | else:
142 | skip_count = max(
143 | 0,
144 | self.collection.count_documents(
145 | {self.session_id_key: self.session_id}
146 | )
147 | - self.history_size,
148 | )
149 | cursor = self.collection.find(
150 | {self.session_id_key: self.session_id}, skip=skip_count
151 | )
152 |         except errors.OperationFailure as error:
153 |             logger.error(error)
154 |             cursor = None  # fall through and return no messages
155 | if cursor:
156 | items = [json.loads(document[self.history_key]) for document in cursor]
157 | else:
158 | items = []
159 |
160 | messages = messages_from_dict(items)
161 | return messages
162 |
163 | def close(self) -> None:
164 | """Close the resources used by the MongoDBChatMessageHistory."""
165 | self.client.close()
166 |
167 | def add_message(self, message: BaseMessage) -> None:
168 | """Append the message to the record in MongoDB"""
169 | try:
170 | self.collection.insert_one(
171 | {
172 | self.session_id_key: self.session_id,
173 | self.history_key: json.dumps(message_to_dict(message)),
174 | }
175 | )
176 | except errors.WriteError as err:
177 | logger.error(err)
178 |
179 | def clear(self) -> None:
180 | """Clear session memory from MongoDB"""
181 | try:
182 | self.collection.delete_many({self.session_id_key: self.session_id})
183 | except errors.WriteError as err:
184 | logger.error(err)
185 |
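Since the constructor accepts either a connection string or an existing `MongoClient`, one client can be shared across histories. A sketch, assuming a hypothetical local deployment:

```python
from pymongo import MongoClient

from langchain_mongodb import MongoDBChatMessageHistory

client = MongoClient("mongodb://localhost:27017/")
history = MongoDBChatMessageHistory(
    connection_string=None,  # must be None when client is given
    session_id="session-1",
    client=client,
    history_size=50,  # fetch only the 50 most recent messages
)
history.add_user_message("Hello!")
print(history.messages)
```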
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/docstores.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from importlib.metadata import version
4 | from typing import Any, Generator, Iterable, Iterator, List, Optional, Sequence, Union
5 |
6 | from langchain_core.documents import Document
7 | from langchain_core.stores import BaseStore
8 | from pymongo import MongoClient
9 | from pymongo.collection import Collection
10 | from pymongo.driver_info import DriverInfo
11 |
12 | from langchain_mongodb.utils import (
13 | make_serializable,
14 | )
15 |
16 | DEFAULT_INSERT_BATCH_SIZE = 100
17 |
18 |
19 | class MongoDBDocStore(BaseStore):
20 | """MongoDB Collection providing BaseStore interface.
21 |
22 | This is meant to be treated as a key-value store: [str, Document]
23 |
24 |
25 | In a MongoDB Collection, the field name _id is reserved for use as a primary key.
26 | Its value must be unique in the collection, is immutable,
27 | and may be of any type other than an array or regex.
28 | As this field is always indexed, it is the natural choice to hold keys.
29 |
30 | The value will be held simply in a field called "value".
31 | It can contain any valid BSON type.
32 |
33 | Example key value pair: {"_id": "foo", "value": "bar"}.
34 | """
35 |
36 | def __init__(self, collection: Collection, text_key: str = "page_content") -> None:
37 | self.collection = collection
38 | self._text_key = text_key
39 |
40 | @classmethod
41 | def from_connection_string(
42 | cls,
43 | connection_string: str,
44 | namespace: str,
45 | **kwargs: Any,
46 | ) -> MongoDBDocStore:
47 | """Construct a Key-Value Store from a MongoDB connection URI.
48 |
49 | Args:
50 | connection_string: A valid MongoDB connection URI.
51 | namespace: A valid MongoDB namespace (in form f"{database}.{collection}")
52 |
53 | Returns:
54 | A new MongoDBDocStore instance.
55 | """
56 | client: MongoClient = MongoClient(
57 | connection_string,
58 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
59 | )
60 | db_name, collection_name = namespace.split(".")
61 | collection = client[db_name][collection_name]
62 | return cls(collection=collection)
63 |
64 | def close(self) -> None:
65 | """Close the resources used by the MongoDBDocStore."""
66 | self.collection.database.client.close()
67 |
68 | def mget(self, keys: Sequence[str]) -> list[Optional[Document]]:
69 | """Get the values associated with the given keys.
70 |
71 | If a key is not found in the store, the corresponding value will be None.
72 | As returning None is not the default find behavior, we form a dictionary
73 | and loop over the keys.
74 |
75 | Args:
76 | keys (Sequence[str]): A sequence of keys.
77 |
78 | Returns: List of values associated with the given keys.
79 | """
80 | found_docs = {}
81 | for res in self.collection.find({"_id": {"$in": keys}}):
82 | text = res.pop(self._text_key)
83 | key = res.pop("_id")
84 | make_serializable(res)
85 | found_docs[key] = Document(page_content=text, metadata=res)
86 | return [found_docs.get(key, None) for key in keys]
87 |
88 | def mset(
89 | self,
90 | key_value_pairs: Sequence[tuple[str, Document]],
91 | batch_size: int = DEFAULT_INSERT_BATCH_SIZE,
92 | ) -> None:
93 | """Set the values for the given keys.
94 |
95 | Args:
96 | key_value_pairs: A sequence of key-value pairs.
97 | batch_size: Number of documents to insert at a time.
98 | Tuning this may help with performance and sidestep MongoDB limits.
99 | """
100 | keys, docs = zip(*key_value_pairs)
101 | n_docs = len(docs)
102 | start = 0
103 | for end in range(batch_size, n_docs + batch_size, batch_size):
104 | texts, metadatas = zip(
105 | *[(doc.page_content, doc.metadata) for doc in docs[start:end]]
106 | )
107 | self.insert_many(texts=texts, metadatas=metadatas, ids=keys[start:end]) # type: ignore
108 | start = end
109 |
110 | def mdelete(self, keys: Sequence[str]) -> None:
111 | """Delete the given keys and their associated values.
112 |
113 | Args:
114 | keys (Sequence[str]): A sequence of keys to delete.
115 | """
116 | self.collection.delete_many({"_id": {"$in": keys}})
117 |
118 | def yield_keys(
119 | self, *, prefix: Optional[str] = None
120 |     ) -> Iterator[str]:
121 | """Get an iterator over keys that match the given prefix.
122 |
123 | Args:
124 | prefix (str): The prefix to match.
125 |
126 | Yields:
127 |             Iterator[str]: An iterator over keys that match the given prefix.
128 |             Keys are held in the _id field, so the iterator always
129 |             yields strings.
130 | """
131 | query = {"_id": {"$regex": f"^{prefix}"}} if prefix else {}
132 | for document in self.collection.find(query, {"_id": 1}):
133 | yield document["_id"]
134 |
135 | def insert_many(
136 | self,
137 | texts: Union[List[str], Iterable[str]],
138 | metadatas: Union[List[dict], Generator[dict, Any, Any]],
139 | ids: List[str],
140 | ) -> None:
141 | """Bulk insert single batch of texts, embeddings, and optionally ids.
142 |
143 | insert_many in PyMongo does not overwrite existing documents.
144 | Instead, it attempts to insert each document as a new document.
145 | If a document with the same _id already exists in the collection,
146 | an error will be raised for that specific document. However, other documents
147 | in the batch that do not have conflicting _ids will still be inserted.
148 | """
149 | to_insert = [
150 | {"_id": i, self._text_key: t, **m} for i, t, m in zip(ids, texts, metadatas)
151 | ]
152 | self.collection.insert_many(to_insert) # type: ignore
153 |
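A round-trip sketch of the key-value interface described above, using a hypothetical URI and namespace:

```python
from langchain_core.documents import Document

from langchain_mongodb.docstores import MongoDBDocStore

store = MongoDBDocStore.from_connection_string(
    "mongodb://localhost:27017/", "langchain_db.docstore"
)
store.mset([("doc-1", Document(page_content="hello", metadata={"source": "demo"}))])
print(store.mget(["doc-1", "missing-key"]))  # second entry is None
store.mdelete(["doc-1"])
```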
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/graphrag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/langchain_mongodb/graphrag/__init__.py
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/graphrag/example_templates.py:
--------------------------------------------------------------------------------
1 | """These prompts serve as defaults, and templates when you wish to override them.
2 |
3 | LLM performance can be greatly increased by adding domain-specific examples.
4 | """
5 |
6 | entity_extraction = """
7 | ## Examples
8 | Use the following examples to guide your work.
9 |
10 | ### Example 1: Constrained entity and relationship types: Person, Friend
11 | #### Input
12 | Alice Palace has been the CEO of MongoDB since January 1, 2018.
13 | She maintains close friendships with Jarnail Singh, whom she has known since May 1, 2019,
14 | and Jasbinder Kaur, who she has been seeing weekly since May 1, 2015.
15 |
16 | #### Output
17 | (If `allowed_entity_types` is ["Person"] and `allowed_relationship_types` is ["Friend"])
18 | {{
19 | "entities": [
20 | {{
21 | "_id": "Alice Palace",
22 | "type": "Person",
23 | "attributes": {{
24 | "job": ["CEO of MongoDB"],
25 | "startDate": ["2018-01-01"]
26 | }},
27 | "relationships": {{
28 | "targets": ["Jasbinder Kaur", "Jarnail Singh"],
29 | "types": ["Friend", "Friend"],
30 | "attributes": [
31 | {{ "since": ["2019-05-01"] }},
32 | {{ "since": ["2015-05-01"], "frequency": ["weekly"] }}
33 | ]
34 | }}
35 | }},
36 | {{
37 | "_id": "Jarnail Singh",
38 | "type": "Person",
39 | "relationships": {{
40 | "targets": ["Alice Palace"],
41 | "types": ["Friend"],
42 | "attributes": [{{ "since": ["2019-05-01"] }}]
43 | }}
44 | }},
45 | {{
46 | "_id": "Jasbinder Kaur",
47 | "type": "Person",
48 | "relationships": {{
49 | "targets": ["Alice Palace"],
50 | "types": ["Friend"],
51 | "attributes": [{{ "since": ["2015-05-01"], "frequency": ["weekly"] }}]
52 | }}
53 | }}
54 | ]
55 | }}
56 |
57 | ### Example 2: Event Extraction
58 | #### Input
59 | The 2022 OpenAI Developer Conference took place on October 10, 2022, in San Francisco.
60 | Elon Musk and Sam Altman were keynote speakers at the event.
61 |
62 | #### Output
63 | (If `allowed_entity_types` is ["Event", "Person"] and `allowed_relationship_types` is ["Speaker"])
64 | {{
65 | "entities": [
66 | {{
67 | "_id": "2022 OpenAI Developer Conference",
68 | "type": "Event",
69 | "attributes": {{
70 | "date": ["2022-10-10"],
71 | "location": ["San Francisco"]
72 | }},
73 | "relationships": {{
74 | "targets": ["Elon Musk", "Sam Altman"],
75 | "types": ["Speaker", "Speaker"]
76 | }}
77 | }},
78 | {{ "_id": "Elon Musk", "type": "Person" }},
79 | {{ "_id": "Sam Altman", "type": "Person" }}
80 | ]
81 | }}
82 |
83 | ### Example 3: Concept Relationship
84 | #### Input
85 | Quantum computing is a field of study that focuses on developing computers based on the principles of quantum mechanics.
86 |
87 | #### Output
88 | (If `allowed_entity_types` is ["Concept"] and `allowed_relationship_types` is ["Based On"])
89 | {{
90 | "entities": [
91 | {{
92 | "_id": "Quantum Computing",
93 | "type": "Concept",
94 | "relationships": {{
95 | "targets": ["Quantum Mechanics"],
96 | "types": ["Based On"]
97 | }}
98 | }},
99 | {{ "_id": "Quantum Mechanics", "type": "Concept" }}
100 | ]
101 | }}
102 |
103 | ### Example 4: News Article
104 | #### Input
105 | On March 1, 2023, NASA successfully launched the Artemis II mission, sending astronauts to orbit the Moon.
106 | NASA Administrator Bill Nelson praised the historic achievement.
107 |
108 | #### Output
109 | (If `allowed_entity_types` is ["Organization", "Event", "Person"] and `allowed_relationship_types` is ["Managed By", "Praised By"])
110 | {{
111 | "entities": [
112 | {{
113 | "_id": "Artemis II Mission",
114 | "type": "Event",
115 | "attributes": {{ "date": ["2023-03-01"] }},
116 | "relationships": {{
117 | "targets": ["NASA"],
118 | "types": ["Managed By"]
119 | }}
120 | }},
121 | {{
122 | "_id": "NASA",
123 | "type": "Organization"
124 | }},
125 | {{
126 | "_id": "Bill Nelson",
127 | "type": "Person",
128 | "relationships": {{
129 | "targets": ["Artemis II Mission"],
130 | "types": ["Praised By"]
131 | }}
132 | }}
133 | ]
134 | }}
135 |
136 | ### Example 5: Technical Article
137 | #### Input
138 | Rust is a programming language that guarantees memory safety without requiring garbage collection.
139 | It is known for its strong ownership model, which prevents data races.
140 |
141 | #### Output
142 | (If `allowed_entity_types` is ["Programming Language", "Concept"] and `allowed_relationship_types` is ["Ensures", "Uses"])
143 | {{
144 | "entities": [
145 | {{
146 | "_id": "Rust",
147 | "type": "Programming Language",
148 | "relationships": {{
149 | "targets": ["Memory Safety"],
150 | "types": ["Ensures"]
151 | }}
152 | }},
153 | {{
154 | "_id": "Memory Safety",
155 | "type": "Concept",
156 | "relationships": {{
157 | "targets": ["Ownership Model"],
158 | "types": ["Uses"]
159 | }}
160 | }},
161 | {{ "_id": "Ownership Model", "type": "Concept" }}
162 | ]
163 | }}
164 | """
165 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/graphrag/schema.py:
--------------------------------------------------------------------------------
1 | """
2 | The following contains the JSON Schema for the Entities in the Collection representing the Knowledge Graph.
3 | If validate is set to True, the schema is enforced upon insert and update.
4 | See `$jsonSchema `_
5 |
6 | The following defines the default entity_schema.
7 | It allows all possible values of "type" and "relationships".
8 |
9 | If allowed_entity_types: List[str] is given to MongoDBGraphStore's constructor,
10 | then `self._schema["properties"]["type"]["enum"] = allowed_entity_types` is added.
11 |
12 | If allowed_relationship_types: List[str] is given to MongoDBGraphStore's constructor,
13 | additionalProperties is set to False, and relationship schema is provided for each key.
14 | """
15 |
16 | entity_schema = {
17 | "bsonType": "object",
18 | "required": ["_id", "type"],
19 | "properties": {
20 | "_id": {
21 | "bsonType": "string",
22 | "description": "Unique identifier for the entity",
23 | },
24 | "type": {
25 | "bsonType": "string",
26 | "description": "Type of the entity (e.g., 'Person', 'Organization')",
27 | # Note: When constrained, predefined types are added. For example:
28 | # "enum": ["Person", "Organization", "Location", "Event"],
29 | },
30 | "attributes": {
31 | "bsonType": "object",
32 | "description": "Key-value pairs describing the entity",
33 | "additionalProperties": {
34 | "bsonType": "array",
35 | "items": {"bsonType": "string"}, # Enforce array of strings
36 | },
37 | },
38 | "relationships": {
39 | "bsonType": "object",
40 | "description": "Key-value pairs of relationships",
41 | "required": ["target_ids"],
42 | "properties": {
43 | "target_ids": {
44 | "bsonType": "array",
45 | "description": "name/_id values of the target entities",
46 | "items": {"bsonType": "string"},
47 | },
48 | "types": {
49 | "bsonType": "array",
50 | "description": "An array of relationships to corresponding target_ids (in same array position).",
51 | "items": {"bsonType": "string"},
52 | # Note: When constrained, predefined types are added. For example:
53 | # "enum": ["used_in", "owns", "written_by", "located_in"], # Predefined types
54 | },
55 | "attributes": {
56 | "bsonType": "array",
57 | "description": "An array of attributes describing the relationships to corresponding target_ids (in same array position). Each element is an object containing key-value pairs, where values are arrays of strings.",
58 | "items": {
59 | "bsonType": "object",
60 | "additionalProperties": {
61 | "bsonType": "array",
62 | "items": {"bsonType": "string"},
63 | },
64 | },
65 | },
66 | },
67 | "additionalProperties": False,
68 | },
69 | },
70 | }
71 |
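The constraint mechanism the docstring describes can be illustrated directly; a sketch of what `MongoDBGraphStore` is documented as doing when `allowed_entity_types` is passed (the collection name below is hypothetical):

```python
from copy import deepcopy

from langchain_mongodb.graphrag.schema import entity_schema

constrained = deepcopy(entity_schema)
constrained["properties"]["type"]["enum"] = ["Person", "Organization"]

# With validate=True the schema is enforced on insert and update, e.g.:
# db.create_collection("entities", validator={"$jsonSchema": constrained})
```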
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/pipelines.py:
--------------------------------------------------------------------------------
1 | """Aggregation pipeline components used in Atlas Full-Text, Vector, and Hybrid Search
2 |
3 | See the following for more:
4 | - `Full-Text Search `_
5 | - `MongoDB Operators `_
6 | - `Vector Search `_
7 | - `Filter Example `_
8 | """
9 |
10 | from typing import Any, Dict, List, Optional
11 |
12 |
13 | def text_search_stage(
14 | query: str,
15 | search_field: str,
16 | index_name: str,
17 | limit: Optional[int] = None,
18 | filter: Optional[Dict[str, Any]] = None,
19 | include_scores: Optional[bool] = True,
20 | **kwargs: Any,
21 | ) -> List[Dict[str, Any]]: # noqa: E501
22 | """Full-Text search using Lucene's standard (BM25) analyzer
23 |
24 | Args:
25 | query: Input text to search for
26 | search_field: Field in Collection that will be searched
27 | index_name: Atlas Search Index name
28 | limit: Maximum number of documents to return. Default of no limit
29 | filter: Any MQL match expression comparing an indexed field
30 | include_scores: Scores provide measure of relative relevance
31 |
32 | Returns:
33 |         List of aggregation stages, beginning with the $search stage
34 | """
35 | pipeline = [
36 | {
37 | "$search": {
38 | "index": index_name,
39 | "text": {"query": query, "path": search_field},
40 | }
41 | }
42 | ]
43 | if filter:
44 | pipeline.append({"$match": filter}) # type: ignore
45 | if include_scores:
46 | pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
47 | if limit:
48 | pipeline.append({"$limit": limit}) # type: ignore
49 |
50 | return pipeline # type: ignore
51 |
52 |
53 | def vector_search_stage(
54 | query_vector: List[float],
55 | search_field: str,
56 | index_name: str,
57 | top_k: int = 4,
58 | filter: Optional[Dict[str, Any]] = None,
59 | oversampling_factor: int = 10,
60 | **kwargs: Any,
61 | ) -> Dict[str, Any]: # noqa: E501
62 | """Vector Search Stage without Scores.
63 |
64 | Scoring is applied later depending on strategy.
65 | vector search includes a vectorSearchScore that is typically used.
66 | hybrid uses Reciprocal Rank Fusion.
67 |
68 | Args:
69 |         query_vector: Embedding vector of the query
70 | search_field: Field in Collection containing embedding vectors
71 | index_name: Name of Atlas Vector Search Index tied to Collection
72 | top_k: Number of documents to return
73 |         oversampling_factor: This times top_k gives the number of candidates
74 | filter: MQL match expression comparing an indexed field.
75 | Some operators are not supported.
76 | See `vectorSearch filter docs `_
77 |
78 |
79 | Returns:
80 | Dictionary defining the $vectorSearch
81 | """
82 | stage = {
83 | "index": index_name,
84 | "path": search_field,
85 | "queryVector": query_vector,
86 | "numCandidates": top_k * oversampling_factor,
87 | "limit": top_k,
88 | }
89 | if filter:
90 | stage["filter"] = filter
91 | return {"$vectorSearch": stage}
92 |
93 |
94 | def combine_pipelines(
95 | pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
96 | ) -> None:
97 | """Combines two aggregations into a single result set in-place."""
98 | if pipeline:
99 | pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}})
100 | else:
101 | pipeline.extend(stage)
102 |
103 |
104 | def reciprocal_rank_stage(
105 | score_field: str, penalty: float = 0, **kwargs: Any
106 | ) -> List[Dict[str, Any]]:
107 | """Stage adds Reciprocal Rank Fusion weighting.
108 |
109 |     First, it pushes documents retrieved from the previous stage
110 |     into a temporary sub-document. It then unwinds to establish
111 |     the rank of each document and applies the penalty.
112 |
113 | Args:
114 | score_field: A unique string to identify the search being ranked
115 | penalty: A non-negative float.
116 |
117 |     Returns:
118 |         Aggregation stages that annotate each document
119 |         with its Reciprocal Rank Fusion (RRF) score.
120 |     """
121 |
122 | rrf_pipeline = [
123 | {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
124 | {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
125 | {
126 | "$addFields": {
127 | f"docs.{score_field}": {
128 | "$divide": [1.0, {"$add": ["$rank", penalty, 1]}]
129 | },
130 | "docs.rank": "$rank",
131 | "_id": "$docs._id",
132 | }
133 | },
134 | {"$replaceRoot": {"newRoot": "$docs"}},
135 | ]
136 |
137 | return rrf_pipeline # type: ignore
138 |
139 |
140 | def final_hybrid_stage(
141 | scores_fields: List[str], limit: int, **kwargs: Any
142 | ) -> List[Dict[str, Any]]:
143 | """Sum weighted scores, sort, and apply limit.
144 |
145 | Args:
146 | scores_fields: List of fields given to scores of vector and text searches
147 | limit: Number of documents to return
148 |
149 | Returns:
150 | Final aggregation stages
151 | """
152 |
153 | return [
154 | {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
155 | {"$replaceRoot": {"newRoot": "$docs"}},
156 | {"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}},
157 | {"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}},
158 | {"$sort": {"score": -1}},
159 | {"$limit": limit},
160 | ]
161 |
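The stage helpers compose into plain aggregation pipelines that can be run directly with PyMongo; a sketch with hypothetical collection and index names:

```python
from pymongo import MongoClient

from langchain_mongodb.pipelines import text_search_stage

collection = MongoClient("mongodb://localhost:27017/")["langchain_db"]["test"]
pipeline = text_search_stage(
    query="mongodb atlas",
    search_field="text",
    index_name="search_index",
    limit=5,
    filter={"metadata.source": "docs"},
)
for doc in collection.aggregate(pipeline):
    print(doc["score"], doc["text"])
```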
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/langchain_mongodb/py.typed
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/retrievers/__init__.py:
--------------------------------------------------------------------------------
1 | """Search Retrievers of various types.
2 |
3 | Use ``MongoDBAtlasVectorSearch.as_retriever(**)``
4 | to create MongoDB's core Vector Search Retriever.
5 | """
6 |
7 | from langchain_mongodb.retrievers.full_text_search import (
8 | MongoDBAtlasFullTextSearchRetriever,
9 | )
10 | from langchain_mongodb.retrievers.graphrag import MongoDBGraphRAGRetriever
11 | from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
12 | from langchain_mongodb.retrievers.parent_document import (
13 | MongoDBAtlasParentDocumentRetriever,
14 | )
15 | from langchain_mongodb.retrievers.self_querying import MongoDBAtlasSelfQueryRetriever
16 |
17 | __all__ = [
18 | "MongoDBAtlasHybridSearchRetriever",
19 | "MongoDBAtlasFullTextSearchRetriever",
20 | "MongoDBAtlasParentDocumentRetriever",
21 | "MongoDBGraphRAGRetriever",
22 | "MongoDBAtlasSelfQueryRetriever",
23 | ]
24 |
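As the module docstring notes, the core vector retriever is not defined here; it comes from the vectorstore itself. A sketch, assuming `vectorstore` is a configured `MongoDBAtlasVectorSearch`:

```python
# search_kwargs are forwarded to the vectorstore's similarity search.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
docs = retriever.invoke("What is Atlas Vector Search?")
```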
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Any, Dict, List, Optional
2 |
3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
4 | from langchain_core.documents import Document
5 | from langchain_core.retrievers import BaseRetriever
6 | from pydantic import Field
7 | from pymongo.collection import Collection
8 |
9 | from langchain_mongodb.pipelines import text_search_stage
10 | from langchain_mongodb.utils import make_serializable
11 |
12 |
13 | class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
14 | """Retriever performs full-text searches using Lucene's standard (BM25) analyzer."""
15 |
16 | collection: Collection
17 | """MongoDB Collection on an Atlas cluster"""
18 | search_index_name: str
19 | """Atlas Search Index name"""
20 | search_field: str
21 | """Collection field that contains the text to be searched. It must be indexed"""
22 | k: Optional[int] = None
23 | """Number of documents to return. Default is no limit"""
24 | filter: Optional[Dict[str, Any]] = None
25 | """(Optional) List of MQL match expression comparing an indexed field"""
26 | include_scores: bool = True
27 | """If True, include scores that provide measure of relative relevance"""
28 | top_k: Annotated[
29 | Optional[int], Field(deprecated='top_k is deprecated, use "k" instead')
30 | ] = None
31 | """Number of documents to return. Default is no limit"""
32 |
33 | def close(self) -> None:
34 | """Close the resources used by the MongoDBAtlasFullTextSearchRetriever."""
35 | self.collection.database.client.close()
36 |
37 | def _get_relevant_documents(
38 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
39 | ) -> List[Document]:
40 | """Retrieve documents that are highest scoring / most similar to query.
41 |
42 | Args:
43 | query: String to find relevant documents for
44 | run_manager: The callback handler to use
45 | Returns:
46 | List of relevant documents
47 | """
48 | default_k = self.k if self.k is not None else self.top_k
49 | pipeline = text_search_stage( # type: ignore
50 | query=query,
51 | search_field=self.search_field,
52 | index_name=self.search_index_name,
53 | limit=kwargs.get("k", default_k),
54 | filter=self.filter,
55 | include_scores=self.include_scores,
56 | )
57 |
58 | # Execution
59 | cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
60 |
61 | # Formatting
62 | docs = []
63 | for res in cursor:
64 | text = res.pop(self.search_field)
65 | make_serializable(res)
66 | docs.append(Document(page_content=text, metadata=res))
67 | return docs
68 |
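A sketch of standalone use, with hypothetical names; it requires an Atlas Search index covering the `text` field:

```python
from pymongo import MongoClient

from langchain_mongodb.retrievers import MongoDBAtlasFullTextSearchRetriever

collection = MongoClient("mongodb://localhost:27017/")["langchain_db"]["test"]
retriever = MongoDBAtlasFullTextSearchRetriever(
    collection=collection,
    search_index_name="search_index",
    search_field="text",
    k=5,
)
docs = retriever.invoke("atlas full-text search")
```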
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/retrievers/graphrag.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
4 | from langchain_core.documents import Document
5 | from langchain_core.retrievers import BaseRetriever
6 |
7 | from langchain_mongodb.graphrag.graph import MongoDBGraphStore
8 |
9 |
10 | class MongoDBGraphRAGRetriever(BaseRetriever):
11 | """RunnableSerializable API of MongoDB GraphRAG."""
12 |
13 | graph_store: MongoDBGraphStore
14 | """Underlying Knowledge Graph storing entities and their relationships."""
15 |
16 | def _get_relevant_documents(
17 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun
18 | ) -> List[Document]:
19 | """Retrieve list of Entities found via traversal of KnowledgeGraph.
20 |
21 | Each Document's page_content is a string representation of the Entity dict.
22 |
23 | Description and details are provided in the underlying Entity Graph:
24 | :class:`~langchain_mongodb.graphrag.graph.MongoDBGraphStore`
25 |
26 | Args:
27 | query: String to find relevant documents for
28 | run_manager: The callback handler to use if desired
29 | Returns:
30 | List of relevant documents.
31 | """
32 | return [Document(str(e)) for e in self.graph_store.similarity_search(query)]
33 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/retrievers/hybrid_search.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Any, Dict, List, Optional
2 |
3 | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
4 | from langchain_core.documents import Document
5 | from langchain_core.retrievers import BaseRetriever
6 | from pydantic import Field
7 | from pymongo.collection import Collection
8 |
9 | from langchain_mongodb import MongoDBAtlasVectorSearch
10 | from langchain_mongodb.pipelines import (
11 | combine_pipelines,
12 | final_hybrid_stage,
13 | reciprocal_rank_stage,
14 | text_search_stage,
15 | vector_search_stage,
16 | )
17 | from langchain_mongodb.utils import make_serializable
18 |
19 |
20 | class MongoDBAtlasHybridSearchRetriever(BaseRetriever):
21 | """Hybrid Search Retriever combines vector and full-text searches
22 | weighting them the via Reciprocal Rank Fusion (RRF) algorithm.
23 |
24 | Increasing the vector_penalty will reduce the importance on the vector search.
25 | Increasing the fulltext_penalty will correspondingly reduce the fulltext score.
26 | For more on the algorithm,see
27 | https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking
28 | """
29 |
30 | vectorstore: MongoDBAtlasVectorSearch
31 | """MongoDBAtlas VectorStore"""
32 | search_index_name: str
33 | """Atlas Search Index (full-text) name"""
34 | k: int = 4
35 | """Number of documents to return."""
36 | oversampling_factor: int = 10
37 | """This times k is the number of candidates chosen at each step"""
38 | pre_filter: Optional[Dict[str, Any]] = None
39 | """(Optional) Any MQL match expression comparing an indexed field"""
40 | post_filter: Optional[List[Dict[str, Any]]] = None
41 | """(Optional) Pipeline of MongoDB aggregation stages for postprocessing."""
42 | vector_penalty: float = 60.0
43 | """Penalty applied to vector search results in RRF: scores=1/(rank + penalty)"""
44 | fulltext_penalty: float = 60.0
45 | """Penalty applied to full-text search results in RRF: scores=1/(rank + penalty)"""
46 |     show_embeddings: bool = False
47 | """If true, returned Document metadata will include vectors."""
48 | top_k: Annotated[
49 | Optional[int], Field(deprecated='top_k is deprecated, use "k" instead')
50 | ] = None
51 | """Number of documents to return."""
52 |
53 | @property
54 | def collection(self) -> Collection:
55 | return self.vectorstore._collection
56 |
57 | def close(self) -> None:
58 | """Close the resources used by the MongoDBAtlasHybridSearchRetriever."""
59 | self.vectorstore.close()
60 |
61 | def _get_relevant_documents(
62 | self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
63 | ) -> List[Document]:
64 | """Retrieve documents that are highest scoring / most similar to query.
65 |
66 | Note that the same query is used in both searches,
67 | embedded for vector search, and as-is for full-text search.
68 |
69 | Args:
70 | query: String to find relevant documents for
71 | run_manager: The callback handler to use
72 | Returns:
73 | List of relevant documents
74 | """
75 |
76 | query_vector = self.vectorstore._embedding.embed_query(query)
77 |
78 | scores_fields = ["vector_score", "fulltext_score"]
79 | pipeline: List[Any] = []
80 |
81 | # Get the appropriate value for k.
82 | default_k = self.top_k if self.top_k is not None else self.k
83 | k = kwargs.get("k", default_k)
84 |
85 | # First we build up the aggregation pipeline,
86 | # then it is passed to the server to execute
87 | # Vector Search stage
88 | vector_pipeline = [
89 | vector_search_stage(
90 | query_vector=query_vector,
91 | search_field=self.vectorstore._embedding_key,
92 | index_name=self.vectorstore._index_name,
93 | top_k=k,
94 | filter=self.pre_filter,
95 | oversampling_factor=self.oversampling_factor,
96 | )
97 | ]
98 | vector_pipeline += reciprocal_rank_stage("vector_score", self.vector_penalty)
99 |
100 | combine_pipelines(pipeline, vector_pipeline, self.collection.name)
101 |
102 | # Full-Text Search stage
103 | text_pipeline = text_search_stage(
104 | query=query,
105 | search_field=self.vectorstore._text_key,
106 | index_name=self.search_index_name,
107 | limit=k,
108 | filter=self.pre_filter,
109 | )
110 |
111 | text_pipeline.extend(
112 | reciprocal_rank_stage("fulltext_score", self.fulltext_penalty)
113 | )
114 |
115 | combine_pipelines(pipeline, text_pipeline, self.collection.name)
116 |
117 | # Sum and sort stage
118 | pipeline.extend(final_hybrid_stage(scores_fields=scores_fields, limit=k))
119 |
120 | # Removal of embeddings unless requested.
121 | if not self.show_embeddings:
122 | pipeline.append({"$project": {self.vectorstore._embedding_key: 0}})
123 | # Post filtering
124 | if self.post_filter is not None:
125 | pipeline.extend(self.post_filter)
126 |
127 | # Execution
128 | cursor = self.collection.aggregate(pipeline) # type: ignore[arg-type]
129 |
130 | # Formatting
131 | docs = []
132 | for res in cursor:
133 | text = res.pop(self.vectorstore._text_key)
134 | # score = res.pop("score") # The score remains buried!
135 | make_serializable(res)
136 | docs.append(Document(page_content=text, metadata=res))
137 | return docs
138 |
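A sketch of instantiation, assuming `vectorstore` is a configured `MongoDBAtlasVectorSearch` whose collection also carries a full-text Atlas Search index named `search_index`:

```python
from langchain_mongodb.retrievers import MongoDBAtlasHybridSearchRetriever

retriever = MongoDBAtlasHybridSearchRetriever(
    vectorstore=vectorstore,
    search_index_name="search_index",
    k=5,
    vector_penalty=60.0,    # raise to down-weight vector results
    fulltext_penalty=60.0,  # raise to down-weight full-text results
)
docs = retriever.invoke("How does hybrid search rank results?")
```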
--------------------------------------------------------------------------------
/libs/langchain-mongodb/langchain_mongodb/utils.py:
--------------------------------------------------------------------------------
1 | """Various Utility Functions
2 |
3 | - Tools for handling bson.ObjectId
4 |
5 | These helpers manage IDs, which live as ObjectId in MongoDB and as str in LangChain and JSON.
6 |
7 |
8 | - Tools for the Maximal Marginal Relevance (MMR) reranking
9 |
10 | These are duplicated from langchain_community to avoid cross-dependencies.
11 |
12 | Functions "maximal_marginal_relevance" and "cosine_similarity"
13 | are duplicated in this utility respectively from modules:
14 |
15 | - "libs/community/langchain_community/vectorstores/utils.py"
16 | - "libs/community/langchain_community/utils/math.py"
17 | """
18 |
19 | from __future__ import annotations
20 |
21 | import logging
22 | from datetime import date, datetime
23 | from typing import Any, Dict, List, Union
24 |
25 | import numpy as np
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
30 |
31 |
32 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
33 | """Row-wise cosine similarity between two equal-width matrices."""
34 | if len(X) == 0 or len(Y) == 0:
35 | return np.array([])
36 |
37 | X = np.array(X)
38 | Y = np.array(Y)
39 | if X.shape[1] != Y.shape[1]:
40 | raise ValueError(
41 | f"Number of columns in X and Y must be the same. X has shape {X.shape} "
42 | f"and Y has shape {Y.shape}."
43 | )
44 | try:
45 | import simsimd as simd
46 |
47 | X = np.array(X, dtype=np.float32)
48 | Y = np.array(Y, dtype=np.float32)
49 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
50 | return Z
51 | except ImportError:
52 | logger.debug(
53 | "Unable to import simsimd, defaulting to NumPy implementation. If you want "
54 | "to use simsimd please install with `pip install simsimd`."
55 | )
56 | X_norm = np.linalg.norm(X, axis=1)
57 | Y_norm = np.linalg.norm(Y, axis=1)
58 | # Ignore divide by zero errors run time warnings as those are handled below.
59 | with np.errstate(divide="ignore", invalid="ignore"):
60 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
61 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
62 | return similarity
63 |
64 |
65 | def maximal_marginal_relevance(
66 | query_embedding: np.ndarray,
67 | embedding_list: list,
68 | lambda_mult: float = 0.5,
69 | k: int = 4,
70 | ) -> List[int]:
71 | r"""Compute Maximal Marginal Relevance (MMR).
72 |
73 | MMR is a technique used to select documents that are both relevant to the query
74 | and diverse among themselves. This function returns the indices
75 | of the top-k embeddings that maximize the marginal relevance.
76 |
77 | Args:
78 | query_embedding (np.ndarray): The embedding vector of the query.
79 | embedding_list (list of np.ndarray): A list containing the embedding vectors
80 | of the candidate documents.
81 | lambda_mult (float, optional): The trade-off parameter between
82 | relevance and diversity. Defaults to 0.5.
83 | k (int, optional): The number of embeddings to select. Defaults to 4.
84 |
85 | Returns:
86 | list of int: The indices of the embeddings that maximize the marginal relevance.
87 |
88 | Notes:
89 | The Maximal Marginal Relevance (MMR) is computed using the following formula:
90 |
91 | MMR = argmax_{D_i ∈ R \ S} [λ * Sim(D_i, Q) - (1 - λ) * max_{D_j ∈ S} Sim(D_i, D_j)]
92 |
93 | where:
94 | - R is the set of candidate documents,
95 | - S is the set of selected documents,
96 | - Q is the query embedding,
97 | - Sim(D_i, Q) is the similarity between document D_i and the query,
98 | - Sim(D_i, D_j) is the similarity between documents D_i and D_j,
99 | - λ is the trade-off parameter.
100 | """
101 |
102 | if min(k, len(embedding_list)) <= 0:
103 | return []
104 | if query_embedding.ndim == 1:
105 | query_embedding = np.expand_dims(query_embedding, axis=0)
106 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
107 | most_similar = int(np.argmax(similarity_to_query))
108 | idxs = [most_similar]
109 | selected = np.array([embedding_list[most_similar]])
110 | while len(idxs) < min(k, len(embedding_list)):
111 | best_score = -np.inf
112 | idx_to_add = -1
113 | similarity_to_selected = cosine_similarity(embedding_list, selected)
114 | for i, query_score in enumerate(similarity_to_query):
115 | if i in idxs:
116 | continue
117 | redundant_score = max(similarity_to_selected[i])
118 | equation_score = (
119 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score
120 | )
121 | if equation_score > best_score:
122 | best_score = equation_score
123 | idx_to_add = i
124 | idxs.append(idx_to_add)
125 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
126 | return idxs
127 |
128 |
129 | def str_to_oid(str_repr: str) -> Any | str:
130 | """Attempt to cast string representation of id to MongoDB's internal BSON ObjectId.
131 |
132 | To be consistent with ObjectId, input must be a 24 character hex string.
133 | If it is not, MongoDB will happily use the string in the main _id index.
134 | Importantly, the str representation that comes out of MongoDB will have this form.
135 |
136 | Args:
137 | str_repr: id as string.
138 |
139 | Returns:
140 |         ObjectId, or the original string if it is not a valid ObjectId.
141 | """
142 | from bson import ObjectId
143 | from bson.errors import InvalidId
144 |
145 | try:
146 | return ObjectId(str_repr)
147 | except InvalidId:
148 | logger.debug(
149 | "ObjectIds must be 12-character byte or 24-character hex strings. "
150 | "Examples: b'heres12bytes', '6f6e6568656c6c6f68656768'"
151 | )
152 | return str_repr
153 |
154 |
155 | def oid_to_str(oid: Any) -> str:
156 | """Convert MongoDB's internal BSON ObjectId into a simple str for compatibility.
157 |
158 | Instructive helper to show where data is coming out of MongoDB.
159 |
160 | Args:
161 | oid: bson.ObjectId
162 |
163 | Returns:
164 | 24 character hex string.
165 | """
166 | return str(oid)
167 |
168 |
169 | def make_serializable(
170 | obj: Dict[str, Any],
171 | ) -> None:
172 | """Recursively cast values in a dict to a form able to json.dump"""
173 |
174 | from bson import ObjectId
175 |
176 | for k, v in obj.items():
177 | if isinstance(v, dict):
178 | make_serializable(v)
179 | elif isinstance(v, list) and v and isinstance(v[0], (ObjectId, date, datetime)):
180 | obj[k] = [oid_to_str(item) for item in v]
181 | elif isinstance(v, ObjectId):
182 | obj[k] = oid_to_str(v)
183 | elif isinstance(v, (datetime, date)):
184 | obj[k] = v.isoformat()
185 |
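The MMR formula in the docstring above can be exercised in isolation. In this sketch, with `lambda_mult=0.3`, diversity dominates, so the near-duplicate of the first pick is skipped in favor of the orthogonal vector:

```python
import numpy as np

from langchain_mongodb.utils import maximal_marginal_relevance

query = np.array([1.0, 0.0])
candidates = [
    np.array([1.0, 0.0]),  # picked first: most similar to the query
    np.array([0.9, 0.1]),  # near-duplicate of the first pick
    np.array([0.1, 1.0]),  # diverse candidate
]
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.3, k=2))  # [0, 2]
```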
--------------------------------------------------------------------------------
/libs/langchain-mongodb/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [build-system]
3 | requires = ["hatchling>1.24"]
4 | build-backend = "hatchling.build"
5 |
6 | [project]
7 | name = "langchain-mongodb"
8 | version = "0.6.2"
9 | description = "An integration package connecting MongoDB and LangChain"
10 | readme = "README.md"
11 | requires-python = ">=3.9"
12 | dependencies = [
13 | "langchain-core>=0.3",
14 | "langchain>=0.3",
15 | "pymongo>=4.6.1",
16 | "langchain-text-splitters>=0.3",
17 | "numpy>=1.26",
18 | "lark<2.0.0,>=1.1.9",
19 | ]
20 |
21 | [dependency-groups]
22 | dev = [
23 | "freezegun>=1.2.2",
24 | "langchain>=0.3.14",
25 | "langchain-core>=0.3.29",
26 | "langchain-text-splitters>=0.3.5",
27 | "pytest-mock>=3.10.0",
28 | "pytest>=7.3.0",
29 | "syrupy>=4.0.2",
30 | "pytest-watcher>=0.3.4",
31 | "pytest-asyncio>=0.21.1",
32 | "mongomock>=4.2.0.post1",
33 | "pre-commit>=4.0",
34 | "mypy>=1.10",
35 | "simsimd>=5.0.0",
36 | "langchain-ollama>=0.2.2",
37 | "langchain-openai>=0.2.14",
38 | "langchain-community>=0.3.14",
39 | "pypdf>=5.0.1",
40 | "langgraph>=0.2.72",
41 | "flaky>=3.8.1",
42 | "langchain-tests==0.3.14",
43 | "pip>=25.0.1",
44 | "typing-extensions>=4.12.2",
45 | ]
46 |
47 | [tool.pytest.ini_options]
48 | addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
49 | markers = [
50 | "requires: mark tests as requiring a specific library",
51 | "compile: mark placeholder test used to compile integration tests without running them",
52 | ]
53 | asyncio_mode = "auto"
54 | asyncio_default_fixture_loop_scope = "function"
55 |
56 | [tool.mypy]
57 | disallow_untyped_defs = true
58 |
59 | [[tool.mypy.overrides]]
60 | module = ["tests.*"]
61 | disable_error_code = ["no-untyped-def", "no-untyped-call", "var-annotated"]
62 |
63 | [tool.ruff]
64 | lint.select = [
65 | "E", # pycodestyle
66 | "F", # Pyflakes
67 | "UP", # pyupgrade
68 | "B", # flake8-bugbear
69 | "I", # isort
70 | ]
71 | lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035"]
72 |
73 | [tool.coverage.run]
74 | omit = ["tests/*"]
75 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/__init__.py
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/integration_tests/__init__.py
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Generator, List
3 |
4 | import pytest
5 | from langchain_community.document_loaders import PyPDFLoader
6 | from langchain_core.documents import Document
7 | from langchain_core.embeddings import Embeddings
8 | from langchain_ollama.embeddings import OllamaEmbeddings
9 | from langchain_openai import OpenAIEmbeddings
10 | from pymongo import MongoClient
11 |
12 | from ..utils import CONNECTION_STRING
13 |
14 |
15 | @pytest.fixture(scope="session")
16 | def technical_report_pages() -> List[Document]:
17 | """Returns a Document for each of the 100 pages of a GPT-4 Technical Report"""
18 | loader = PyPDFLoader("https://arxiv.org/pdf/2303.08774.pdf")
19 | pages = loader.load()
20 | return pages
21 |
22 |
23 | @pytest.fixture(scope="session")
24 | def client() -> Generator[MongoClient, None, None]:
25 | client = MongoClient(CONNECTION_STRING)
26 | yield client
27 | client.close()
28 |
29 |
30 | @pytest.fixture(scope="session")
31 | def embedding() -> Embeddings:
32 | if os.environ.get("OPENAI_API_KEY"):
33 | return OpenAIEmbeddings(
34 | openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
35 | model="text-embedding-3-small",
36 | )
37 |
38 | return OllamaEmbeddings(model="all-minilm:l6-v2")
39 |
40 |
41 | @pytest.fixture(scope="session")
42 | def dimensions() -> int:
43 | if os.environ.get("OPENAI_API_KEY"):
44 | return 1536
45 | return 384
46 |
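# Sketch of how the session-scoped fixtures above compose inside a test
# (the test body is assumed for illustration):
#
#     def test_embedding_dimensions(embedding: Embeddings, dimensions: int) -> None:
#         vector = embedding.embed_query("hello world")
#         assert len(vector) == dimensions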
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_agent_toolkit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sqlite3
3 |
4 | import pytest
5 | import requests
6 | from flaky import flaky # type:ignore[import-untyped]
7 | from langchain_openai import ChatOpenAI
8 | from langgraph.prebuilt import create_react_agent
9 | from pymongo import MongoClient
10 |
11 | from langchain_mongodb.agent_toolkit import (
12 | MONGODB_AGENT_SYSTEM_PROMPT,
13 | MongoDBDatabase,
14 | MongoDBDatabaseToolkit,
15 | )
16 | from tests.utils import CONNECTION_STRING
17 |
18 | DB_NAME = "langchain_test_db_chinook"
19 |
20 |
21 | @pytest.fixture
22 | def db(client: MongoClient) -> MongoDBDatabase:
23 | # Load the raw data into sqlite.
24 | url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
25 | response = requests.get(url)
26 | sql_script = response.text
27 | con = sqlite3.connect(":memory:", check_same_thread=False)
28 | con.executescript(sql_script)
29 |
30 | # Convert the sqlite data to MongoDB data.
31 | con.row_factory = sqlite3.Row
32 | cursor = con.cursor()
33 | sql_query = """SELECT name FROM sqlite_master WHERE type='table';"""
34 | cursor.execute(sql_query)
35 | tables = [i[0] for i in cursor.fetchall()]
36 | cursor.close()
37 | for t in tables:
38 | coll = client[DB_NAME][t]
39 | coll.delete_many({})
40 | cursor = con.cursor()
41 | cursor.execute(f"select * from {t}")
42 | docs = [dict(i) for i in cursor.fetchall()]
43 | cursor.close()
44 | coll.insert_many(docs)
45 | return MongoDBDatabase(client, DB_NAME)
46 |
47 |
48 | @flaky(max_runs=5, min_passes=4)
49 | @pytest.mark.skipif(
50 | "OPENAI_API_KEY" not in os.environ, reason="test requires OpenAI for chat responses"
51 | )
52 | def test_toolkit_response(db):
53 | db_wrapper = MongoDBDatabase.from_connection_string(
54 | CONNECTION_STRING, database=DB_NAME
55 | )
56 | llm = ChatOpenAI(model="gpt-4o-mini", timeout=60)
57 |
58 | toolkit = MongoDBDatabaseToolkit(db=db_wrapper, llm=llm)
59 |
60 | system_message = MONGODB_AGENT_SYSTEM_PROMPT.format(top_k=5)
61 |
62 | test_query = "Which country's customers spent the most?"
63 | agent = create_react_agent(llm, toolkit.get_tools(), state_modifier=system_message)
64 | agent.step_timeout = 60
65 | events = agent.stream(
66 | {"messages": [("user", test_query)]},
67 | stream_mode="values",
68 | )
69 | messages = []
70 | for event in events:
71 | messages.extend(event["messages"])
72 | assert "USA" in messages[-1].content, messages[-1].content
73 |
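# The db fixture above is a miniature SQL-to-MongoDB migration. Reduced to a
# single table, the core pattern is (table name assumed for illustration):
#
#     con.row_factory = sqlite3.Row
#     rows = con.execute("SELECT * FROM Album").fetchall()
#     client[DB_NAME]["Album"].insert_many([dict(r) for r in rows])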
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_cache.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import Any, List, Union
4 |
5 | import pytest # type: ignore[import-not-found]
6 | from langchain_core.caches import BaseCache
7 | from langchain_core.globals import (
8 | get_llm_cache,
9 | set_llm_cache,
10 | )
11 | from langchain_core.load.dump import dumps
12 | from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
13 | from langchain_core.outputs import ChatGeneration, Generation, LLMResult
14 | from pymongo import MongoClient
15 | from pymongo.collection import Collection
16 |
17 | from langchain_mongodb.cache import MongoDBAtlasSemanticCache, MongoDBCache
18 | from langchain_mongodb.index import (
19 | create_vector_search_index,
20 | )
21 |
22 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, FakeChatModel, FakeLLM
23 |
24 | CONN_STRING = os.environ.get("MONGODB_URI")
25 | INDEX_NAME = "langchain-test-index-semantic-cache"
26 | COLLECTION = "langchain_test_cache"
27 |
28 | DIMENSIONS = 1536  # Matches the dimensionality of OpenAI's text-embedding models
29 | TIMEOUT = 60.0
30 |
31 |
32 | def random_string() -> str:
33 | return str(uuid.uuid4())
34 |
35 |
36 | @pytest.fixture(scope="module")
37 | def collection(client: MongoClient) -> Collection:
38 | """A Collection with both a Vector and a Full-text Search Index"""
39 | if COLLECTION not in client[DB_NAME].list_collection_names():
40 | clxn = client[DB_NAME].create_collection(COLLECTION)
41 | else:
42 | clxn = client[DB_NAME][COLLECTION]
43 |
44 | clxn.delete_many({})
45 |
46 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
47 | create_vector_search_index(
48 | collection=clxn,
49 | index_name=INDEX_NAME,
50 | dimensions=DIMENSIONS,
51 | path="embedding",
52 | filters=["llm_string"],
53 | similarity="cosine",
54 | wait_until_complete=TIMEOUT,
55 | )
56 |
57 | return clxn
58 |
59 |
60 | def llm_cache(cls: Any) -> BaseCache:
61 | set_llm_cache(
62 | cls(
63 | embedding=ConsistentFakeEmbeddings(dimensionality=DIMENSIONS),
64 | connection_string=CONN_STRING,
65 | collection_name=COLLECTION,
66 | database_name=DB_NAME,
67 | index_name=INDEX_NAME,
68 | score_threshold=0.5,
69 | wait_until_ready=TIMEOUT,
70 | )
71 | )
72 | assert get_llm_cache()
73 | return get_llm_cache()
74 |
75 |
76 | @pytest.fixture(scope="module", autouse=True)
77 | def reset_cache():
78 | """Prevents global cache being affected in other module's tests."""
79 | yield
80 | print("\nAll cache tests have finished. Setting global cache to None.")
81 | set_llm_cache(None)
82 |
83 |
84 | def _execute_test(
85 | prompt: Union[str, List[BaseMessage]],
86 | llm: Union[str, FakeLLM, FakeChatModel],
87 | response: List[Generation],
88 | ) -> None:
89 | # Fabricate an LLM String
90 |
91 | if not isinstance(llm, str):
92 | params = llm.dict()
93 | params["stop"] = None
94 | llm_string = str(sorted([(k, v) for k, v in params.items()]))
95 | else:
96 | llm_string = llm
97 |
98 | # If the prompt is a str then we should pass just the string
99 | dumped_prompt: str = prompt if isinstance(prompt, str) else dumps(prompt)
100 |
101 | # Update the cache
102 | get_llm_cache().update(dumped_prompt, llm_string, response)
103 |
104 | # Retrieve the cached result through 'generate' call
105 | output: Union[List[Generation], LLMResult, None]
106 | expected_output: Union[List[Generation], LLMResult]
107 |
108 | if isinstance(llm, str):
109 | output = get_llm_cache().lookup(dumped_prompt, llm) # type: ignore
110 | expected_output = response
111 | else:
112 | output = llm.generate([prompt]) # type: ignore
113 | expected_output = LLMResult(
114 | generations=[response],
115 | llm_output={},
116 | )
117 |
118 | assert output == expected_output # type: ignore
119 |
120 |
121 | @pytest.mark.parametrize(
122 | "prompt, llm, response",
123 | [
124 | ("foo", "bar", [Generation(text="fizz")]),
125 | ("foo", FakeLLM(), [Generation(text="fizz")]),
126 | (
127 | [HumanMessage(content="foo")],
128 | FakeChatModel(),
129 | [ChatGeneration(message=AIMessage(content="foo"))],
130 | ),
131 | ],
132 | ids=[
133 | "plain_cache",
134 | "cache_with_llm",
135 | "cache_with_chat",
136 | ],
137 | )
138 | @pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
139 | @pytest.mark.parametrize("remove_score", [True, False])
140 | def test_mongodb_cache(
141 | remove_score: bool,
142 | cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
143 | prompt: Union[str, List[BaseMessage]],
144 | llm: Union[str, FakeLLM, FakeChatModel],
145 | response: List[Generation],
146 | collection: Collection,
147 | ) -> None:
148 | llm_cache(cacher)
149 | if remove_score:
150 | get_llm_cache().score_threshold = None # type: ignore
151 | try:
152 | _execute_test(prompt, llm, response)
153 | finally:
154 | get_llm_cache().clear()
155 |
156 |
157 | @pytest.mark.parametrize(
158 | "prompts, generations",
159 | [
160 | # Single prompt, single generation
161 | ([random_string()], [[random_string()]]),
162 | # Single prompt, multiple generations
163 | ([random_string()], [[random_string(), random_string()]]),
164 | # Single prompt, multiple generations
165 | ([random_string()], [[random_string(), random_string(), random_string()]]),
166 | # Multiple prompts, multiple generations
167 | (
168 | [random_string(), random_string()],
169 | [[random_string()], [random_string(), random_string()]],
170 | ),
171 | ],
172 | ids=[
173 | "single_prompt_single_generation",
174 | "single_prompt_two_generations",
175 | "single_prompt_three_generations",
176 | "multiple_prompts_multiple_generations",
177 | ],
178 | )
179 | def test_mongodb_atlas_cache_matrix(
180 | prompts: List[str],
181 | generations: List[List[str]],
182 | collection: Collection,
183 | ) -> None:
184 | llm_cache(MongoDBAtlasSemanticCache)
185 | llm = FakeLLM()
186 |
187 | # Fabricate an LLM String
188 | params = llm.dict()
189 | params["stop"] = None
190 | llm_string = str(sorted([(k, v) for k, v in params.items()]))
191 |
192 | llm_generations = [
193 | [
194 | Generation(text=generation, generation_info=params)
195 | for generation in prompt_i_generations
196 | ]
197 | for prompt_i_generations in generations
198 | ]
199 |
200 | for prompt_i, llm_generations_i in zip(prompts, llm_generations):
201 | _execute_test(prompt_i, llm_string, llm_generations_i)
202 | assert llm.generate(prompts) == LLMResult(
203 | generations=llm_generations, llm_output={}
204 | )
205 | get_llm_cache().clear()
206 |
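# Outside the test harness, the semantic cache is wired up the same way
# (connection string and `my_embeddings` are assumed for illustration):
#
#     from langchain_core.globals import set_llm_cache
#     from langchain_mongodb.cache import MongoDBAtlasSemanticCache
#
#     set_llm_cache(
#         MongoDBAtlasSemanticCache(
#             connection_string="mongodb+srv://...",
#             embedding=my_embeddings,  # any langchain Embeddings implementation
#             collection_name=COLLECTION,
#             database_name=DB_NAME,
#             index_name=INDEX_NAME,
#         )
#     )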
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_chain_example.py:
--------------------------------------------------------------------------------
1 | "Demonstrates MongoDBAtlasVectorSearch.as_retriever() invoked in a chain" ""
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 |
7 | import pytest # type: ignore[import-not-found]
8 | from langchain_core.documents import Document
9 | from langchain_core.embeddings import Embeddings
10 | from langchain_core.output_parsers.string import StrOutputParser
11 | from langchain_core.prompts.chat import ChatPromptTemplate
12 | from langchain_core.runnables import RunnablePassthrough
13 | from langchain_openai import ChatOpenAI
14 | from pymongo import MongoClient
15 | from pymongo.collection import Collection
16 |
17 | from langchain_mongodb import index
18 |
19 | from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch
20 |
21 | COLLECTION_NAME = "langchain_test_chain_example"
22 | INDEX_NAME = "langchain-test-chain-example-vector-index"
23 | DIMENSIONS = 1536
24 | TIMEOUT = 60.0
25 | INTERVAL = 0.5
26 |
27 |
28 | @pytest.fixture(scope="module")
29 | def collection(client: MongoClient) -> Collection:
30 | """A Collection with both a Vector and a Full-text Search Index"""
31 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
32 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
33 | else:
34 | clxn = client[DB_NAME][COLLECTION_NAME]
35 |
36 | clxn.delete_many({})
37 |
38 | if all([INDEX_NAME != ix["name"] for ix in clxn.list_search_indexes()]):
39 | index.create_vector_search_index(
40 | collection=clxn,
41 | index_name=INDEX_NAME,
42 | dimensions=DIMENSIONS,
43 | path="embedding",
44 | similarity="cosine",
45 | filters=None,
46 | wait_until_complete=TIMEOUT,
47 | )
48 |
49 | return clxn
50 |
51 |
52 | @pytest.mark.skipif(
53 | not os.environ.get("OPENAI_API_KEY"),
54 | reason="Requires OpenAI for chat responses.",
55 | )
56 | def test_chain(
57 | collection: Collection,
58 | embedding: Embeddings,
59 | ) -> None:
60 | """Demonstrate usage of MongoDBAtlasVectorSearch in a realistic chain
61 |
62 | Follows example in the docs: https://python.langchain.com/docs/how_to/hybrid/
63 |
64 |     Requires OPENAI_API_KEY for embedding and chat model.
65 | Requires INDEX_NAME to have been set up on MONGODB_URI
66 | """
67 |
68 | vectorstore = PatchedMongoDBAtlasVectorSearch(
69 | collection=collection,
70 | embedding=embedding,
71 | index_name=INDEX_NAME,
72 | text_key="page_content",
73 | )
74 |
75 | texts = [
76 | "In 2023, I visited Paris",
77 | "In 2022, I visited New York",
78 | "In 2021, I visited New Orleans",
79 | "In 2019, I visited San Francisco",
80 | "In 2020, I visited Vancouver",
81 | ]
82 | vectorstore.add_texts(texts)
83 |
84 | query = "In the United States, what city did I visit last?"
85 | # One can do vector search on the vector store, using its various search types.
86 | k = len(texts)
87 |
88 | store_output = list(vectorstore.similarity_search(query=query, k=k))
89 | assert len(store_output) == k
90 | assert isinstance(store_output[0], Document)
91 |
92 | # Unfortunately, the VectorStore output cannot be given to a Chat Model
93 | # If we wish Chat Model to answer based on our own data,
94 | # we have to give it the right things to work with.
95 |     # The way that LangChain does this is by piping results along in
96 | # a Chain: https://python.langchain.com/v0.1/docs/modules/chains/
97 |
98 | # Now, we can turn our VectorStore into something Runnable in a Chain
99 | # by turning it into a Retriever.
100 | # For the simple VectorSearch Retriever, we can do this like so.
101 |
102 | retriever = vectorstore.as_retriever(search_kwargs=dict(k=k))
103 |
104 |     # This does little more than expose our search function
105 |     # as an invoke() method with a standard API, making it a Runnable.
106 | retriever_output = retriever.invoke(query)
107 | assert len(retriever_output) == len(texts)
108 | assert retriever_output[0].page_content == store_output[0].page_content
109 |
110 | # To get a natural language response to our question,
111 | # we need ChatOpenAI, a template to better frame the question as a prompt,
112 | # and a parser to send the output to a string.
113 | # Together, these become our Chain!
114 | # Here goes:
115 |
116 | template = """Answer the question based only on the following context.
117 | Answer in as few words as possible.
118 | {context}
119 | Question: {question}
120 | """
121 | prompt = ChatPromptTemplate.from_template(template)
122 |
123 | model = ChatOpenAI()
124 |
125 | chain = (
126 | {"context": retriever, "question": RunnablePassthrough()} # type: ignore
127 | | prompt
128 | | model
129 | | StrOutputParser()
130 | )
131 |
132 | answer = chain.invoke("What city did I visit last?")
133 |
134 | assert "Paris" in answer
135 |
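# The chain above relies on the prompt stringifying the retrieved list of
# Documents. A common variant (assumed here, not part of this test) inserts
# an explicit formatting step so the context is plain text:
#
#     from langchain_core.runnables import RunnableLambda
#
#     format_docs = RunnableLambda(lambda docs: "\n\n".join(d.page_content for d in docs))
#     chain = (
#         {"context": retriever | format_docs, "question": RunnablePassthrough()}
#         | prompt
#         | model
#         | StrOutputParser()
#     )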
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_chat_message_histories.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
4 | from langchain_core.messages import message_to_dict
5 |
6 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
7 |
8 | from ..utils import CONNECTION_STRING, DB_NAME
9 |
10 | COLLECTION = "langchain_test_chat"
11 |
12 |
13 | def test_memory_with_message_store() -> None:
14 | """Test the memory with a message store."""
15 | # setup MongoDB as a message store
16 | message_history = MongoDBChatMessageHistory(
17 | connection_string=CONNECTION_STRING,
18 | session_id="test-session",
19 | database_name=DB_NAME,
20 | collection_name=COLLECTION,
21 | )
22 | memory = ConversationBufferMemory(
23 | memory_key="baz", chat_memory=message_history, return_messages=True
24 | )
25 |
26 | # add some messages
27 | memory.chat_memory.add_ai_message("This is me, the AI")
28 | memory.chat_memory.add_user_message("This is me, the human")
29 |
30 | # get the message history from the memory store and turn it into a json
31 | messages = memory.chat_memory.messages
32 | messages_json = json.dumps([message_to_dict(msg) for msg in messages])
33 |
34 | assert "This is me, the AI" in messages_json
35 | assert "This is me, the human" in messages_json
36 |
37 | # remove the record from MongoDB, so the next test run won't pick it up
38 | memory.chat_memory.clear()
39 |
40 | assert memory.chat_memory.messages == []
41 |
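# MongoDBChatMessageHistory can also be driven directly, without the legacy
# ConversationBufferMemory wrapper used above (session id assumed):
#
#     history = MongoDBChatMessageHistory(
#         connection_string=CONNECTION_STRING,
#         session_id="another-session",
#         database_name=DB_NAME,
#         collection_name=COLLECTION,
#     )
#     history.add_user_message("hello")
#     history.add_ai_message("hi there")
#     assert len(history.messages) == 2
#     history.clear()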
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_compile.py:
--------------------------------------------------------------------------------
1 | import pytest # type: ignore[import-not-found]
2 |
3 |
4 | @pytest.mark.compile
5 | def test_placeholder() -> None:
6 | """Used for compiling integration tests without running any real tests."""
7 | pass
8 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_docstore.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from langchain_core.documents import Document
4 | from pymongo import MongoClient
5 |
6 | from langchain_mongodb.docstores import MongoDBDocStore
7 |
8 | from ..utils import DB_NAME
9 |
10 | COLLECTION_NAME = "langchain_test_docstore"
11 |
12 |
13 | def test_docstore(client: MongoClient, technical_report_pages: List[Document]) -> None:
14 | db = client[DB_NAME]
15 | db.drop_collection(COLLECTION_NAME)
16 | clxn = db[COLLECTION_NAME]
17 |
18 | n_docs = len(technical_report_pages)
19 | assert clxn.count_documents({}) == 0
20 | docstore = MongoDBDocStore(collection=clxn)
21 |
22 | docstore.mset([(str(i), technical_report_pages[i]) for i in range(n_docs)])
23 | assert clxn.count_documents({}) == n_docs
24 |
25 | twenties = list(docstore.yield_keys(prefix="2"))
26 | assert len(twenties) == 11 # includes 2, 20, 21, ..., 29
27 |
28 | docstore.mdelete([str(i) for i in range(20, 30)] + ["2"])
29 | assert clxn.count_documents({}) == n_docs - 11
30 | assert set(docstore.mget(twenties)) == {None}
31 |
32 | sample = docstore.mget(["8", "16", "24", "36"])
33 | assert sample[2] is None
34 | assert all(isinstance(sample[i], Document) for i in [0, 1, 3])
35 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_index.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, List, Optional
2 |
3 | import pytest
4 | from pymongo import MongoClient
5 | from pymongo.collection import Collection
6 |
7 | from langchain_mongodb import MongoDBAtlasVectorSearch, index
8 |
9 | from ..utils import ConsistentFakeEmbeddings
10 |
11 | DB_NAME = "langchain_test_index_db"
12 | COLLECTION_NAME = "test_index"
13 | VECTOR_INDEX_NAME = "vector_index"
14 |
15 | TIMEOUT = 120
16 | DIMENSIONS = 10
17 |
18 |
19 | @pytest.fixture
20 | def collection(client: MongoClient) -> Generator:
21 | """Depending on uri, this could point to any type of cluster."""
22 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
23 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
24 | else:
25 | clxn = client[DB_NAME][COLLECTION_NAME]
26 |
27 | clxn.delete_many({})
28 | yield clxn
29 | clxn.delete_many({})
30 |
31 |
32 | def test_search_index_drop_add_delete_commands(collection: Collection) -> None:
33 | index_name = VECTOR_INDEX_NAME
34 | dimensions = DIMENSIONS
35 | path = "embedding"
36 | similarity = "cosine"
37 | filters: Optional[List[str]] = None
38 | wait_until_complete = TIMEOUT
39 |
40 | for index_info in collection.list_search_indexes():
41 | index.drop_vector_search_index(
42 | collection, index_info["name"], wait_until_complete=wait_until_complete
43 | )
44 |
45 | assert len(list(collection.list_search_indexes())) == 0
46 |
47 | index.create_vector_search_index(
48 | collection=collection,
49 | index_name=index_name,
50 | dimensions=dimensions,
51 | path=path,
52 | similarity=similarity,
53 | filters=filters,
54 | wait_until_complete=wait_until_complete,
55 | )
56 |
57 | assert index._is_index_ready(collection, index_name)
58 | indexes = list(collection.list_search_indexes())
59 | assert len(indexes) == 1
60 | assert indexes[0]["name"] == index_name
61 |
62 | index.drop_vector_search_index(
63 | collection, index_name, wait_until_complete=wait_until_complete
64 | )
65 |
66 | indexes = list(collection.list_search_indexes())
67 | assert len(indexes) == 0
68 |
69 |
70 | @pytest.mark.skip("collection.update_vector_search_index requires [CLOUDP-275518]")
71 | def test_search_index_update_vector_search_index(collection: Collection) -> None:
72 | index_name = "INDEX_TO_UPDATE"
73 | similarity_orig = "cosine"
74 | similarity_new = "euclidean"
75 |
76 | # Create another index
77 | index.create_vector_search_index(
78 | collection=collection,
79 | index_name=index_name,
80 | dimensions=DIMENSIONS,
81 | path="embedding",
82 | similarity=similarity_orig,
83 | wait_until_complete=TIMEOUT,
84 | )
85 |
86 | assert index._is_index_ready(collection, index_name)
87 | indexes = list(collection.list_search_indexes())
88 | assert len(indexes) == 1
89 | assert indexes[0]["name"] == index_name
90 | assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_orig
91 |
92 | # Update the index and test new similarity
93 | index.update_vector_search_index(
94 | collection=collection,
95 | index_name=index_name,
96 | dimensions=DIMENSIONS,
97 | path="embedding",
98 | similarity=similarity_new,
99 | wait_until_complete=TIMEOUT,
100 | )
101 |
102 | assert index._is_index_ready(collection, index_name)
103 | indexes = list(collection.list_search_indexes())
104 | assert len(indexes) == 1
105 | assert indexes[0]["name"] == index_name
106 | assert indexes[0]["latestDefinition"]["fields"][0]["similarity"] == similarity_new
107 |
108 |
109 | def test_vectorstore_create_vector_search_index(collection: Collection) -> None:
110 | """Tests vectorstore wrapper around index command."""
111 |
112 | # Set up using the index module's api
113 | if len(list(collection.list_search_indexes())) != 0:
114 | index.drop_vector_search_index(
115 | collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT
116 | )
117 |
118 | # Test MongoDBAtlasVectorSearch's API
119 | _ = MongoDBAtlasVectorSearch(
120 | collection=collection,
121 | embedding=ConsistentFakeEmbeddings(),
122 | index_name=VECTOR_INDEX_NAME,
123 | dimensions=DIMENSIONS,
124 | auto_index_timeout=TIMEOUT,
125 | )
126 |
127 | assert index._is_index_ready(collection, VECTOR_INDEX_NAME)
128 | indexes = list(collection.list_search_indexes())
129 | assert len(indexes) == 1
130 | assert indexes[0]["name"] == VECTOR_INDEX_NAME
131 |
132 | # Tear down using the index module's api
133 | index.drop_vector_search_index(
134 | collection, VECTOR_INDEX_NAME, wait_until_complete=TIMEOUT
135 | )
136 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_loaders.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | from langchain_core.documents import Document
4 | from pymongo import MongoClient
5 |
6 | from langchain_mongodb.loaders import MongoDBLoader
7 |
8 | from ..utils import DB_NAME
9 |
10 | COLLECTION_NAME = "langchain_test_loader"
11 |
12 |
13 | def raw_docs() -> List[Dict]:
14 | return [
15 | {"_id": "1", "address": {"building": "1", "room": "1"}},
16 | {"_id": "2", "address": {"building": "2", "room": "2"}},
17 | {"_id": "3", "address": {"building": "3", "room": "2"}},
18 | ]
19 |
20 |
21 | def expected_documents() -> List[Document]:
22 | return [
23 | Document(
24 | page_content="2 2",
25 | metadata={"_id": "2", "database": DB_NAME, "collection": COLLECTION_NAME},
26 | ),
27 | Document(
28 | page_content="3 2",
29 | metadata={"_id": "3", "database": DB_NAME, "collection": COLLECTION_NAME},
30 | ),
31 | ]
32 |
33 |
34 | async def test_load_with_filters(client: MongoClient) -> None:
35 | filter_criteria = {"address.room": {"$eq": "2"}}
36 | field_names = ["address.building", "address.room"]
37 | metadata_names = ["_id"]
38 | include_db_collection_in_metadata = True
39 |
40 | collection = client[DB_NAME][COLLECTION_NAME]
41 | collection.delete_many({})
42 | collection.insert_many(raw_docs())
43 |
44 | loader = MongoDBLoader(
45 | collection,
46 | filter_criteria=filter_criteria,
47 | field_names=field_names,
48 | metadata_names=metadata_names,
49 | include_db_collection_in_metadata=include_db_collection_in_metadata,
50 | )
51 | documents = await loader.aload()
52 |
53 | assert documents == expected_documents()
54 |
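# Outside pytest-asyncio, the same loader call runs under asyncio directly
# (a sketch; the filter mirrors the test above):
#
#     import asyncio
#
#     loader = MongoDBLoader(collection, filter_criteria={"address.room": {"$eq": "2"}})
#     documents = asyncio.run(loader.aload())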
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_mmr.py:
--------------------------------------------------------------------------------
1 | """Test max_marginal_relevance_search."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest # type: ignore[import-not-found]
6 | from langchain_core.embeddings import Embeddings
7 | from pymongo import MongoClient
8 | from pymongo.collection import Collection
9 |
10 | from langchain_mongodb.index import (
11 | create_vector_search_index,
12 | )
13 |
14 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
15 |
16 | COLLECTION_NAME = "langchain_test_vectorstores"
17 | INDEX_NAME = "langchain-test-index-vectorstores"
18 | DIMENSIONS = 5
19 |
20 |
21 | @pytest.fixture()
22 | def collection(client: MongoClient) -> Collection:
23 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
24 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
25 | else:
26 | clxn = client[DB_NAME][COLLECTION_NAME]
27 |
28 | clxn.delete_many({})
29 |
30 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
31 | create_vector_search_index(
32 | collection=clxn,
33 | index_name=INDEX_NAME,
34 |             dimensions=DIMENSIONS,
35 | path="embedding",
36 | filters=["c"],
37 | similarity="cosine",
38 | wait_until_complete=60,
39 | )
40 |
41 | return clxn
42 |
43 |
44 | @pytest.fixture
45 | def embeddings() -> Embeddings:
46 | return ConsistentFakeEmbeddings(DIMENSIONS)
47 |
48 |
49 | def test_mmr(embeddings: Embeddings, collection: Collection) -> None:
50 | texts = ["foo", "foo", "fou", "foy"]
51 | collection.delete_many({})
52 | vectorstore = PatchedMongoDBAtlasVectorSearch.from_texts(
53 | texts,
54 | embedding=embeddings,
55 | collection=collection,
56 | index_name=INDEX_NAME,
57 | )
58 | query = "foo"
59 | output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
60 | assert len(output) == len(texts)
61 | assert output[0].page_content == "foo"
62 | assert output[1].page_content != "foo"
63 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_mongodb_database.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: E501
2 | """Test MongoDB database wrapper."""
3 |
4 | import json
5 |
6 | import pytest
7 | from pymongo import MongoClient
8 |
9 | from langchain_mongodb.agent_toolkit import MongoDBDatabase
10 |
11 | DB_NAME = "langchain_test_db_user"
12 |
13 |
14 | @pytest.fixture
15 | def db(client: MongoClient) -> MongoDBDatabase:
16 | client[DB_NAME].user.delete_many({})
17 | user = dict(name="Alice", bio="Engineer from Ohio")
18 | client[DB_NAME]["user"].insert_one(user)
19 | company = dict(name="Acme", location="Montana")
20 | client[DB_NAME]["company"].insert_one(company)
21 | return MongoDBDatabase(client, DB_NAME)
22 |
23 |
24 | def test_collection_info(db: MongoDBDatabase) -> None:
25 | """Test that collection info is constructed properly."""
26 | output = db.collection_info
27 | expected_header = f"""
28 | Database name: {DB_NAME}
29 | Collection name: company
30 | Schema from a sample of documents from the collection:
31 | _id: ObjectId
32 | name: String
33 | location: String""".strip()
34 |
35 | for line in expected_header.splitlines():
36 | assert line.strip() in output
37 |
38 |
39 | def test_collection_info_w_sample_docs(db: MongoDBDatabase) -> None:
40 | """Test that collection info is constructed properly."""
41 |
42 | # Provision.
43 | values = [
44 | {"name": "Harrison", "bio": "bio"},
45 | {"name": "Chase", "bio": "bio"},
46 | ]
47 | db._client[DB_NAME]["user"].delete_many({})
48 | db._client[DB_NAME]["user"].insert_many(values)
49 |
50 | # Query and verify.
51 | db = MongoDBDatabase(db._client, DB_NAME, sample_docs_in_collection_info=2)
52 | output = db.collection_info
53 |
54 | expected_header = f"""
55 | Database name: {DB_NAME}
56 | Collection name: company
57 | Schema from a sample of documents from the collection:
58 | _id: ObjectId
59 | name: String
60 | location: String
61 | """.strip()
62 |
63 | for line in expected_header.splitlines():
64 | assert line.strip() in output
65 |
66 |
67 | def test_database_run(db: MongoDBDatabase) -> None:
68 | """Verify running MQL expressions returning results as strings."""
69 |
70 | # Provision.
71 | db._client[DB_NAME]["user"].delete_many({})
72 | user = dict(name="Harrison", bio="That is my Bio " * 24)
73 | db._client[DB_NAME]["user"].insert_one(user)
74 |
75 | # Query and verify.
76 | command = """db.user.aggregate([ { "$match": { "name": "Harrison" } } ])"""
77 | output = db.run(command)
78 | assert isinstance(output, str)
79 | docs = json.loads(output.strip())
80 | del docs[0]["_id"]
81 | del user["_id"]
82 | assert docs[0] == user
83 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 | from typing import List
3 |
4 | from langchain_core.documents import Document
5 | from langchain_core.embeddings import Embeddings
6 | from langchain_text_splitters import RecursiveCharacterTextSplitter
7 | from pymongo import MongoClient
8 | from pymongo.driver_info import DriverInfo
9 |
10 | from langchain_mongodb.docstores import MongoDBDocStore
11 | from langchain_mongodb.index import create_vector_search_index
12 | from langchain_mongodb.retrievers import (
13 | MongoDBAtlasParentDocumentRetriever,
14 | )
15 |
16 | from ..utils import CONNECTION_STRING, DB_NAME, PatchedMongoDBAtlasVectorSearch
17 |
18 | COLLECTION_NAME = "langchain_test_parent_document_combined"
19 | VECTOR_INDEX_NAME = "langchain-test-parent-document-vector-index"
20 | EMBEDDING_FIELD = "embedding"
21 | TEXT_FIELD = "page_content"
22 | SIMILARITY = "cosine"
23 | TIMEOUT = 60.0
24 |
25 |
26 | def test_1clxn_retriever(
27 | technical_report_pages: List[Document],
28 | embedding: Embeddings,
29 | dimensions: int,
30 | ) -> None:
31 | # Setup
32 | client: MongoClient = MongoClient(
33 | CONNECTION_STRING,
34 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
35 | )
36 | db = client[DB_NAME]
37 | combined_clxn = db[COLLECTION_NAME]
38 | if COLLECTION_NAME not in db.list_collection_names():
39 | db.create_collection(COLLECTION_NAME)
40 | # Clean up
41 | combined_clxn.delete_many({})
42 | # Create Search Index if it doesn't exist
43 | sixs = list(combined_clxn.list_search_indexes())
44 | if len(sixs) == 0:
45 | create_vector_search_index(
46 | collection=combined_clxn,
47 | index_name=VECTOR_INDEX_NAME,
48 | dimensions=dimensions,
49 | path=EMBEDDING_FIELD,
50 | similarity=SIMILARITY,
51 | wait_until_complete=TIMEOUT,
52 | )
53 | # Create Vector and Doc Stores
54 | vectorstore = PatchedMongoDBAtlasVectorSearch(
55 | collection=combined_clxn,
56 | embedding=embedding,
57 | index_name=VECTOR_INDEX_NAME,
58 | text_key=TEXT_FIELD,
59 | embedding_key=EMBEDDING_FIELD,
60 | relevance_score_fn=SIMILARITY,
61 | )
62 | docstore = MongoDBDocStore(collection=combined_clxn, text_key=TEXT_FIELD)
63 | # Combine into a ParentDocumentRetriever
64 | retriever = MongoDBAtlasParentDocumentRetriever(
65 | vectorstore=vectorstore,
66 | docstore=docstore,
67 | child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
68 | )
69 | # Add documents (splitting, creating embedding, adding to vectorstore and docstore)
70 | retriever.add_documents(technical_report_pages)
71 | # invoke the retriever with a query
72 | question = "What percentage of the Uniform Bar Examination can GPT4 pass?"
73 | responses = retriever.invoke(question)
74 |
75 | assert len(responses) == 3
76 | assert all("GPT-4" in doc.page_content for doc in responses)
77 | assert {4, 5, 29} == set(doc.metadata["page"] for doc in responses)
78 |
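# Conceptually the retriever makes two hops per query (a sketch of the flow,
# not the class's exact code; the "doc_id" metadata key is the langchain
# default and assumed here): embed the question, vector-search the small
# child chunks, then fetch the matching parents from the docstore.
#
#     child_hits = vectorstore.similarity_search(question, k=10)
#     parent_ids = {hit.metadata["doc_id"] for hit in child_hits}
#     parents = docstore.mget(list(parent_ids))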
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_retrievers_standard.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 |
3 | from langchain_core.documents import Document
4 | from langchain_tests.integration_tests import (
5 | RetrieversIntegrationTests,
6 | )
7 | from pymongo import MongoClient
8 | from pymongo.collection import Collection
9 |
10 | from langchain_mongodb import MongoDBAtlasVectorSearch
11 | from langchain_mongodb.index import (
12 | create_fulltext_search_index,
13 | )
14 | from langchain_mongodb.retrievers import (
15 | MongoDBAtlasFullTextSearchRetriever,
16 | MongoDBAtlasHybridSearchRetriever,
17 | )
18 |
19 | from ..utils import (
20 | CONNECTION_STRING,
21 | DB_NAME,
22 | TIMEOUT,
23 | ConsistentFakeEmbeddings,
24 | PatchedMongoDBAtlasVectorSearch,
25 | )
26 |
27 | DIMENSIONS = 5
28 | COLLECTION_NAME = "langchain_test_retrievers_standard"
29 | VECTOR_INDEX_NAME = "vector_index"
30 | PAGE_CONTENT_FIELD = "text"
31 | SEARCH_INDEX_NAME = "text_index"
32 |
33 |
34 | def setup_test() -> tuple[Collection, MongoDBAtlasVectorSearch]:
35 | client = MongoClient(CONNECTION_STRING)
36 | coll = client[DB_NAME][COLLECTION_NAME]
37 |
38 | # Set up the vector search index and add the documents if needed.
39 | vs = PatchedMongoDBAtlasVectorSearch(
40 | coll,
41 | embedding=ConsistentFakeEmbeddings(DIMENSIONS),
42 | dimensions=DIMENSIONS,
43 | index_name=VECTOR_INDEX_NAME,
44 | text_key=PAGE_CONTENT_FIELD,
45 | auto_index_timeout=TIMEOUT,
46 | )
47 |
48 | if coll.count_documents({}) == 0:
49 | vs.add_documents(
50 | [
51 | Document(page_content="In 2023, I visited Paris"),
52 | Document(page_content="In 2022, I visited New York"),
53 | Document(page_content="In 2021, I visited New Orleans"),
54 | Document(page_content="Sandwiches are beautiful. Sandwiches are fine."),
55 | ]
56 | )
57 |
58 | # Set up the search index if needed.
59 | if not any([ix["name"] == SEARCH_INDEX_NAME for ix in coll.list_search_indexes()]):
60 | create_fulltext_search_index(
61 | collection=coll,
62 | index_name=SEARCH_INDEX_NAME,
63 | field=PAGE_CONTENT_FIELD,
64 | wait_until_complete=TIMEOUT,
65 | )
66 |
67 | return coll, vs
68 |
69 |
70 | class TestMongoDBAtlasFullTextSearchRetriever(RetrieversIntegrationTests):
71 | @property
72 | def retriever_constructor(self) -> Type[MongoDBAtlasFullTextSearchRetriever]:
73 | """Get a retriever for integration tests."""
74 | return MongoDBAtlasFullTextSearchRetriever
75 |
76 | @property
77 | def retriever_constructor_params(self) -> dict:
78 | coll, _ = setup_test()
79 | return {
80 | "collection": coll,
81 | "search_index_name": SEARCH_INDEX_NAME,
82 | "search_field": PAGE_CONTENT_FIELD,
83 | }
84 |
85 | @property
86 | def retriever_query_example(self) -> str:
87 | """
88 | Returns a str representing the "query" of an example retriever call.
89 | """
90 | return "When was the last time I visited new orleans?"
91 |
92 |
93 | class TestMongoDBAtlasHybridSearchRetriever(RetrieversIntegrationTests):
94 | @property
95 | def retriever_constructor(self) -> Type[MongoDBAtlasHybridSearchRetriever]:
96 | """Get a retriever for integration tests."""
97 | return MongoDBAtlasHybridSearchRetriever
98 |
99 | @property
100 | def retriever_constructor_params(self) -> dict:
101 | coll, vs = setup_test()
102 | return {
103 | "vectorstore": vs,
104 | "collection": coll,
105 | "search_index_name": SEARCH_INDEX_NAME,
106 | "search_field": PAGE_CONTENT_FIELD,
107 | }
108 |
109 | @property
110 | def retriever_query_example(self) -> str:
111 | """
112 | Returns a str representing the "query" of an example retriever call.
113 | """
114 | return "When was the last time I visited new orleans?"
115 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_tools.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 |
3 | from langchain_tests.integration_tests import ToolsIntegrationTests
4 |
5 | from langchain_mongodb.agent_toolkit.tool import (
6 | InfoMongoDBDatabaseTool,
7 | ListMongoDBDatabaseTool,
8 | QueryMongoDBCheckerTool,
9 | QueryMongoDBDatabaseTool,
10 | )
11 | from tests.utils import create_database, create_llm
12 |
13 |
14 | class TestQueryMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
15 | @property
16 | def tool_constructor(self) -> Type[QueryMongoDBDatabaseTool]:
17 | return QueryMongoDBDatabaseTool
18 |
19 | @property
20 | def tool_constructor_params(self) -> dict:
21 | return dict(db=create_database())
22 |
23 | @property
24 | def tool_invoke_params_example(self) -> dict:
25 | return dict(query='db.test.aggregate([{"$match": {}}])')
26 |
27 |
28 | class TestInfoMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
29 | @property
30 | def tool_constructor(self) -> Type[InfoMongoDBDatabaseTool]:
31 | return InfoMongoDBDatabaseTool
32 |
33 | @property
34 | def tool_constructor_params(self) -> dict:
35 | return dict(db=create_database())
36 |
37 | @property
38 | def tool_invoke_params_example(self) -> dict:
39 | return dict(collection_names="test")
40 |
41 |
42 | class TestListMongoDBDatabaseToolIntegration(ToolsIntegrationTests):
43 | @property
44 | def tool_constructor(self) -> Type[ListMongoDBDatabaseTool]:
45 | return ListMongoDBDatabaseTool
46 |
47 | @property
48 | def tool_constructor_params(self) -> dict:
49 | return dict(db=create_database())
50 |
51 | @property
52 | def tool_invoke_params_example(self) -> dict:
53 | return dict()
54 |
55 |
56 | class TestQueryMongoDBCheckerToolIntegration(ToolsIntegrationTests):
57 | @property
58 | def tool_constructor(self) -> Type[QueryMongoDBCheckerTool]:
59 | return QueryMongoDBCheckerTool
60 |
61 | @property
62 | def tool_constructor_params(self) -> dict:
63 | return dict(db=create_database(), llm=create_llm())
64 |
65 | @property
66 | def tool_invoke_params_example(self) -> dict:
67 | return dict(query='db.test.aggregate([{"$match": {}}])')
68 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_documents.py:
--------------------------------------------------------------------------------
1 | """Test MongoDBAtlasVectorSearch.from_documents."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import List
6 |
7 | import pytest # type: ignore[import-not-found]
8 | from langchain_core.documents import Document
9 | from langchain_core.embeddings import Embeddings
10 | from pymongo import MongoClient
11 | from pymongo.collection import Collection
12 |
13 | from langchain_mongodb.index import (
14 | create_vector_search_index,
15 | )
16 |
17 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
18 |
19 | COLLECTION_NAME = "langchain_test_from_documents"
20 | INDEX_NAME = "langchain-test-index-from-documents"
21 | DIMENSIONS = 5
22 |
23 |
24 | @pytest.fixture(scope="module")
25 | def collection(client: MongoClient) -> Collection:
26 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
27 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
28 | else:
29 | clxn = client[DB_NAME][COLLECTION_NAME]
30 |
31 | clxn.delete_many({})
32 |
33 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
34 | create_vector_search_index(
35 | collection=clxn,
36 | index_name=INDEX_NAME,
37 | dimensions=DIMENSIONS,
38 | path="embedding",
39 | similarity="cosine",
40 | wait_until_complete=60,
41 | )
42 |
43 | return clxn
44 |
45 |
46 | @pytest.fixture(scope="module")
47 | def example_documents() -> List[Document]:
48 | return [
49 | Document(page_content="Dogs are tough.", metadata={"a": 1}),
50 | Document(page_content="Cats have fluff.", metadata={"b": 1}),
51 | Document(page_content="What is a sandwich?", metadata={"c": 1}),
52 | Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
53 | ]
54 |
55 |
56 | @pytest.fixture(scope="module")
57 | def embeddings() -> Embeddings:
58 | return ConsistentFakeEmbeddings(DIMENSIONS)
59 |
60 |
61 | @pytest.fixture(scope="module")
62 | def vectorstore(
63 | collection: Collection, example_documents: List[Document], embeddings: Embeddings
64 | ) -> PatchedMongoDBAtlasVectorSearch:
65 | """VectorStore created with a few documents and a trivial embedding model.
66 |
67 | Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all
68 | but one important feature. It waits until all documents are fully indexed
69 | before returning control to the caller.
70 | """
71 | return PatchedMongoDBAtlasVectorSearch.from_documents(
72 | example_documents,
73 | embedding=embeddings,
74 | collection=collection,
75 | index_name=INDEX_NAME,
76 | )
77 |
78 |
79 | def test_default_search(
80 | vectorstore: PatchedMongoDBAtlasVectorSearch, example_documents: List[Document]
81 | ) -> None:
82 | """Test end to end construction and search."""
83 | output = vectorstore.similarity_search("Sandwich", k=1)
84 | assert len(output) == 1
85 | # Check for the presence of the metadata key
86 | assert any(
87 | [key.page_content == output[0].page_content for key in example_documents]
88 | )
89 | # Assert no presence of embeddings in results
90 | assert all(["embedding" not in key.metadata for key in output])
91 |
92 |
93 | def test_search_with_embeddings(vectorstore: PatchedMongoDBAtlasVectorSearch) -> None:
94 | output = vectorstore.similarity_search("Sandwich", k=2, include_embeddings=True)
95 | assert len(output) == 2
96 |
97 | # Assert embeddings in results
98 | assert all([key.metadata.get("embedding") for key in output])
99 |
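# from_documents is a thin convenience over from_texts: it splits each
# Document into its text and metadata (equivalence per the langchain-core
# VectorStore contract):
#
#     texts = [d.page_content for d in example_documents]
#     metadatas = [d.metadata for d in example_documents]
#     store = PatchedMongoDBAtlasVectorSearch.from_texts(
#         texts,
#         embedding=embeddings,
#         metadatas=metadatas,
#         collection=collection,
#         index_name=INDEX_NAME,
#     )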
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_from_texts.py:
--------------------------------------------------------------------------------
1 | """Test MongoDBAtlasVectorSearch.from_documents."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Dict, Generator, List
6 |
7 | import pytest # type: ignore[import-not-found]
8 | from langchain_core.embeddings import Embeddings
9 | from pymongo import MongoClient
10 | from pymongo.collection import Collection
11 |
12 | from langchain_mongodb import MongoDBAtlasVectorSearch
13 | from langchain_mongodb.index import (
14 | create_vector_search_index,
15 | )
16 |
17 | from ..utils import DB_NAME, ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
18 |
19 | COLLECTION_NAME = "langchain_test_from_texts"
20 | INDEX_NAME = "langchain-test-index-from-texts"
21 | DIMENSIONS = 5
22 |
23 |
24 | @pytest.fixture(scope="module")
25 | def collection(client: MongoClient) -> Collection:
26 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
27 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
28 | else:
29 | clxn = client[DB_NAME][COLLECTION_NAME]
30 |
31 | clxn.delete_many({})
32 |
33 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
34 | create_vector_search_index(
35 | collection=clxn,
36 | index_name=INDEX_NAME,
37 | dimensions=DIMENSIONS,
38 | path="embedding",
39 | filters=["c"],
40 | similarity="cosine",
41 | wait_until_complete=60,
42 | )
43 |
44 | return clxn
45 |
46 |
47 | @pytest.fixture(scope="module")
48 | def texts() -> List[str]:
49 | return [
50 | "Dogs are tough.",
51 | "Cats have fluff.",
52 | "What is a sandwich?",
53 | "That fence is purple.",
54 | ]
55 |
56 |
57 | @pytest.fixture(scope="module")
58 | def metadatas() -> List[Dict]:
59 | return [{"a": 1}, {"b": 1}, {"c": 1}, {"d": 1, "e": 2}]
60 |
61 |
62 | @pytest.fixture(scope="module")
63 | def embeddings() -> Embeddings:
64 | return ConsistentFakeEmbeddings(DIMENSIONS)
65 |
66 |
67 | @pytest.fixture(scope="module")
68 | def vectorstore(
69 | collection: Collection,
70 | texts: List[str],
71 | embeddings: Embeddings,
72 | metadatas: List[dict],
73 | ) -> Generator[MongoDBAtlasVectorSearch, None, None]:
74 | """VectorStore created with a few documents and a trivial embedding model.
75 |
76 | Note: PatchedMongoDBAtlasVectorSearch is MongoDBAtlasVectorSearch in all
77 | but one important feature. It waits until all documents are fully indexed
78 | before returning control to the caller.
79 | """
80 | vectorstore_from_texts = PatchedMongoDBAtlasVectorSearch.from_texts(
81 | texts=texts,
82 | embedding=embeddings,
83 | metadatas=metadatas,
84 | collection=collection,
85 | index_name=INDEX_NAME,
86 | )
87 | yield vectorstore_from_texts
88 |
89 | vectorstore_from_texts.collection.delete_many({})
90 |
91 |
92 | def test_search_with_metadatas_and_pre_filter(
93 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict]
94 | ) -> None:
95 | # Confirm the presence of metadata in output
96 | output = vectorstore.similarity_search("Sandwich", k=1)
97 | assert len(output) == 1
98 | metakeys = [list(d.keys())[0] for d in metadatas]
99 | assert any([key in output[0].metadata for key in metakeys])
100 |
101 |
102 | def test_search_filters_all(
103 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict]
104 | ) -> None:
105 | # Test filtering out
106 | does_not_match_filter = vectorstore.similarity_search(
107 | "Sandwich", k=1, pre_filter={"c": {"$lte": 0}}
108 | )
109 | assert does_not_match_filter == []
110 |
111 |
112 | def test_search_pre_filter(
113 | vectorstore: PatchedMongoDBAtlasVectorSearch, metadatas: List[Dict]
114 | ) -> None:
115 | # Test filtering with expected output
116 | matches_filter = vectorstore.similarity_search(
117 | "Sandwich", k=3, pre_filter={"c": {"$gt": 0}}
118 | )
119 | assert len(matches_filter) == 1
120 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/integration_tests/test_vectorstore_standard.py:
--------------------------------------------------------------------------------
1 | """Test MongoDBAtlasVectorSearch.from_documents."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest # type: ignore[import-not-found]
6 | from langchain_core.vectorstores import VectorStore
7 | from langchain_tests.integration_tests import VectorStoreIntegrationTests
8 | from pymongo import MongoClient
9 | from pymongo.collection import Collection
10 |
11 | from langchain_mongodb.index import (
12 | create_vector_search_index,
13 | )
14 |
15 | from ..utils import DB_NAME, PatchedMongoDBAtlasVectorSearch
16 |
17 | COLLECTION_NAME = "langchain_test_standard"
18 | INDEX_NAME = "langchain-test-index-standard"
19 | DIMENSIONS = 6
20 |
21 |
22 | @pytest.fixture
23 | def collection(client: MongoClient) -> Collection:
24 | if COLLECTION_NAME not in client[DB_NAME].list_collection_names():
25 | clxn = client[DB_NAME].create_collection(COLLECTION_NAME)
26 | else:
27 | clxn = client[DB_NAME][COLLECTION_NAME]
28 |
29 | clxn.delete_many({})
30 |
31 | if not any([INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]):
32 | create_vector_search_index(
33 | collection=clxn,
34 | index_name=INDEX_NAME,
35 | dimensions=DIMENSIONS,
36 | path="embedding",
37 | filters=["c"],
38 | similarity="cosine",
39 | wait_until_complete=60,
40 | )
41 |
42 | return clxn
43 |
44 |
45 | class TestMongoDBAtlasVectorSearch(VectorStoreIntegrationTests):
46 | @pytest.fixture()
47 | def vectorstore(self, collection) -> VectorStore: # type: ignore
48 | """Get an empty vectorstore for unit tests."""
49 | store = PatchedMongoDBAtlasVectorSearch(
50 | collection, self.get_embeddings(), index_name=INDEX_NAME
51 | )
52 | # note: store should be EMPTY at this point
53 | # if you need to delete data, you may do so here
54 | return store
55 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-mongodb/c83235a1caa90207f4c5441927d47aaa259afa0e/libs/langchain-mongodb/tests/unit_tests/__init__.py
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py:
--------------------------------------------------------------------------------
1 | import json
2 | from importlib.metadata import version
3 |
4 | import mongomock
5 | import pytest
6 | from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
7 | from langchain_core.messages import message_to_dict
8 | from pymongo.driver_info import DriverInfo
9 | from pytest_mock import MockerFixture
10 |
11 | from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
12 |
13 | from ..utils import MockCollection
14 |
15 |
16 | class PatchedMongoDBChatMessageHistory(MongoDBChatMessageHistory):
17 | def __init__(self) -> None:
18 | self.session_id = "test-session"
19 | self.database_name = "test-database"
20 | self.collection_name = "test-collection"
21 | self.collection = MockCollection()
22 | self.session_id_key = "SessionId"
23 | self.history_key = "History"
24 | self.history_size = None
25 | self.db = self.collection.database
26 | self.client = self.db.client
27 |
28 |
29 | def test_memory_with_message_store() -> None:
30 | """Test the memory with a message store."""
31 | # setup MongoDB as a message store
32 | message_history = PatchedMongoDBChatMessageHistory()
33 | memory = ConversationBufferMemory(
34 | memory_key="baz", chat_memory=message_history, return_messages=True
35 | )
36 |
37 | # add some messages
38 | memory.chat_memory.add_ai_message("This is me, the AI")
39 | memory.chat_memory.add_user_message("This is me, the human")
40 |
41 | # get the message history from the memory store and turn it into a json
42 | messages = memory.chat_memory.messages
43 | messages_json = json.dumps([message_to_dict(msg) for msg in messages])
44 |
45 | assert "This is me, the AI" in messages_json
46 | assert "This is me, the human" in messages_json
47 |
48 | # remove the record from MongoDB, so the next test run won't pick it up
49 | memory.chat_memory.clear()
50 |
51 | assert memory.chat_memory.messages == []
52 |
53 | message_history.close()
54 | assert message_history.client.is_closed
55 |
56 |
57 | def test_init_with_connection_string(mocker: MockerFixture) -> None:
58 | mock_mongo_client = mocker.patch(
59 | "langchain_mongodb.chat_message_histories.MongoClient"
60 | )
61 |
62 | history = MongoDBChatMessageHistory(
63 | connection_string="mongodb://localhost:27017/",
64 | session_id="test-session",
65 | database_name="test-database",
66 | collection_name="test-collection",
67 | )
68 |
69 | mock_mongo_client.assert_called_once_with(
70 | "mongodb://localhost:27017/",
71 | driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
72 | )
73 | assert history.session_id == "test-session"
74 | assert history.database_name == "test-database"
75 | assert history.collection_name == "test-collection"
76 | history.close()
77 |
78 |
79 | def test_init_with_existing_client() -> None:
80 | client = mongomock.MongoClient() # type: ignore[var-annotated]
81 |
82 | # Initialize MongoDBChatMessageHistory with the mock client
83 | history = MongoDBChatMessageHistory(
84 | connection_string=None,
85 | session_id="test-session",
86 | database_name="test-database",
87 | collection_name="test-collection",
88 | client=client,
89 | )
90 |
91 | assert history.session_id == "test-session"
92 |
93 | # Verify that the collection is correctly created within the specified database
94 | assert "test-database" in client.list_database_names()
95 | assert "test-collection" in client["test-database"].list_collection_names()
96 |
97 | history.close()
98 |
99 |
100 | def test_init_raises_error_without_connection_or_client() -> None:
101 | with pytest.raises(
102 | ValueError, match="Either connection_string or client must be provided"
103 | ):
104 | MongoDBChatMessageHistory(
105 | session_id="test_session",
106 | connection_string=None,
107 | client=None,
108 | )
109 |
110 |
111 | def test_init_raises_error_with_both_connection_and_client() -> None:
112 | client_mock = mongomock.MongoClient() # type: ignore[var-annotated]
113 |
114 | with pytest.raises(
115 | ValueError, match="Must provide connection_string or client, not both"
116 | ):
117 | MongoDBChatMessageHistory(
118 | connection_string="mongodb://localhost:27017/",
119 | session_id="test_session",
120 | client=client_mock,
121 | )
122 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/test_imports.py:
--------------------------------------------------------------------------------
1 | from langchain_mongodb import __all__
2 |
3 | EXPECTED_ALL = [
4 | "MongoDBAtlasVectorSearch",
5 | "MongoDBChatMessageHistory",
6 | "MongoDBCache",
7 | "MongoDBAtlasSemanticCache",
8 | ]
9 |
10 |
11 | def test_all_imports() -> None:
12 | assert sorted(EXPECTED_ALL) == sorted(__all__)
13 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/test_index.py:
--------------------------------------------------------------------------------
1 | from time import sleep
2 |
3 | import pytest
4 | from pymongo import MongoClient
5 | from pymongo.collection import Collection
6 | from pymongo.errors import OperationFailure, ServerSelectionTimeoutError
7 |
8 | from langchain_mongodb import index
9 |
10 | DIMENSION = 5
11 | TIMEOUT = 120
12 |
13 |
14 | @pytest.fixture
15 | def collection() -> Collection:
16 | """Collection on MongoDB Cluster, not an Atlas one."""
17 | client: MongoClient = MongoClient()
18 | return client["db"]["collection"]
19 |
20 |
21 | def test_create_vector_search_index(collection: Collection) -> None:
22 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
23 | index.create_vector_search_index(
24 | collection,
25 | "index_name",
26 | DIMENSION,
27 | "embedding",
28 | "cosine",
29 | [],
30 | wait_until_complete=TIMEOUT,
31 | )
32 |
33 |
34 | def test_drop_vector_search_index(collection: Collection) -> None:
35 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
36 | index.drop_vector_search_index(
37 | collection, "index_name", wait_until_complete=TIMEOUT
38 | )
39 |
40 |
41 | def test_update_vector_search_index(collection: Collection) -> None:
42 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
43 | index.update_vector_search_index(
44 | collection,
45 | "index_name",
46 | DIMENSION,
47 | "embedding",
48 | "cosine",
49 | [],
50 | wait_until_complete=TIMEOUT,
51 | )
52 |
53 |
54 | def test___is_index_ready(collection: Collection) -> None:
55 | with pytest.raises((OperationFailure, ServerSelectionTimeoutError)):
56 | index._is_index_ready(collection, "index_name")
57 |
58 |
59 | def test__wait_for_predicate() -> None:
60 | err = "error string"
61 | with pytest.raises(TimeoutError) as e:
62 | index._wait_for_predicate(lambda: sleep(5), err=err, timeout=0.5, interval=0.1)
63 | assert err in str(e)
64 |
65 | index._wait_for_predicate(lambda: True, err=err, timeout=1.0, interval=0.5)
66 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/test_retrievers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from langchain_text_splitters import RecursiveCharacterTextSplitter
3 |
4 | from langchain_mongodb.docstores import MongoDBDocStore
5 | from langchain_mongodb.retrievers import (
6 | MongoDBAtlasFullTextSearchRetriever,
7 | MongoDBAtlasHybridSearchRetriever,
8 | MongoDBAtlasParentDocumentRetriever,
9 | )
10 | from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
11 |
12 | from ..utils import ConsistentFakeEmbeddings, MockCollection
13 |
14 |
15 | @pytest.fixture()
16 | def collection() -> MockCollection:
17 | return MockCollection()
18 |
19 |
20 | @pytest.fixture()
21 | def embeddings() -> ConsistentFakeEmbeddings:
22 | return ConsistentFakeEmbeddings()
23 |
24 |
25 | def test_full_text_search(collection):
26 | search = MongoDBAtlasFullTextSearchRetriever(
27 | collection=collection, search_index_name="foo", search_field="bar"
28 | )
29 | search.close()
30 | assert collection.database.client.is_closed
31 |
32 |
33 | def test_hybrid_search(collection, embeddings):
34 | vs = MongoDBAtlasVectorSearch(collection, embeddings)
35 | search = MongoDBAtlasHybridSearchRetriever(vectorstore=vs, search_index_name="foo")
36 | search.close()
37 | assert collection.database.client.is_closed
38 |
39 |
40 | def test_parent_retriever(collection, embeddings):
41 | vs = MongoDBAtlasVectorSearch(collection, embeddings)
42 | ds = MongoDBDocStore(collection)
43 | cs = RecursiveCharacterTextSplitter(chunk_size=400)
44 | retriever = MongoDBAtlasParentDocumentRetriever(
45 | vectorstore=vs, docstore=ds, child_splitter=cs
46 | )
47 | retriever.close()
48 | assert collection.database.client.is_closed
49 |
--------------------------------------------------------------------------------
/libs/langchain-mongodb/tests/unit_tests/test_tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Type
4 |
5 | from langchain_tests.unit_tests import ToolsUnitTests
6 |
7 | from langchain_mongodb.agent_toolkit import MongoDBDatabase
8 | from langchain_mongodb.agent_toolkit.tool import (
9 | InfoMongoDBDatabaseTool,
10 | ListMongoDBDatabaseTool,
11 | QueryMongoDBCheckerTool,
12 | QueryMongoDBDatabaseTool,
13 | )
14 |
15 | from ..utils import FakeLLM, MockClient
16 |
17 |
18 | class TestQueryMongoDBDatabaseToolUnit(ToolsUnitTests):
19 | @property
20 | def tool_constructor(self) -> Type[QueryMongoDBDatabaseTool]:
21 | return QueryMongoDBDatabaseTool
22 |
23 | @property
24 | def tool_constructor_params(self) -> dict:
25 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type]
26 |
27 | @property
28 | def tool_invoke_params_example(self) -> dict:
29 | return dict(query="db.foo.aggregate()")
30 |
31 |
32 | class TestInfoMongoDBDatabaseToolUnit(ToolsUnitTests):
33 | @property
34 | def tool_constructor(self) -> Type[InfoMongoDBDatabaseTool]:
35 | return InfoMongoDBDatabaseTool
36 |
37 | @property
38 | def tool_constructor_params(self) -> dict:
39 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type]
40 |
41 | @property
42 | def tool_invoke_params_example(self) -> dict:
43 | return dict(collection_names="test")
44 |
45 |
46 | class TestListMongoDBDatabaseToolUnit(ToolsUnitTests):
47 | @property
48 | def tool_constructor(self) -> Type[ListMongoDBDatabaseTool]:
49 | return ListMongoDBDatabaseTool
50 |
51 | @property
52 | def tool_constructor_params(self) -> dict:
53 | return dict(db=MongoDBDatabase(MockClient(), "test")) # type:ignore[arg-type]
54 |
55 | @property
56 | def tool_invoke_params_example(self) -> dict:
57 | return dict()
58 |
59 |
60 | class TestQueryMongoDBCheckerToolUnit(ToolsUnitTests):
61 | @property
62 | def tool_constructor(self) -> Type[QueryMongoDBCheckerTool]:
63 | return QueryMongoDBCheckerTool
64 |
65 | @property
66 | def tool_constructor_params(self) -> dict:
67 | return dict(db=MongoDBDatabase(MockClient(), "test"), llm=FakeLLM()) # type:ignore[arg-type]
68 |
69 | @property
70 | def tool_invoke_params_example(self) -> dict:
71 | return dict(query="db.foo.aggregate()")
72 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ---
4 |
5 | ## Changes in version 0.1.3 (2025/04/01)
6 |
7 | - Add compatibility with `pymongo.AsyncMongoClient`.
8 |
9 | ## Changes in version 0.1.2 (2025/03/26)
10 |
11 | - Add compatibility with `langgraph-checkpoint` 2.0.23.
12 |
13 | ## Changes in version 0.1.1 (2025/02/26)
14 |
15 | - Remove dependency on `langgraph`.
16 |
17 | ## Changes in version 0.1 (2024/12/13)
18 |
19 | - Initial release, added support for `MongoDBSaver`.
20 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/README.md:
--------------------------------------------------------------------------------
1 | # LangGraph Checkpoint MongoDB
2 |
3 | An implementation of the LangGraph CheckpointSaver that uses MongoDB.
4 |
5 | ## Usage
6 |
7 | ```python
8 | from langgraph.checkpoint.mongodb import MongoDBSaver
9 |
10 | write_config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
11 | read_config = {"configurable": {"thread_id": "1"}}
12 |
13 | MONGODB_URI = "mongodb://localhost:27017"
14 | DB_NAME = "checkpoint_example"
15 |
16 | with MongoDBSaver.from_conn_string(MONGODB_URI, DB_NAME) as checkpointer:
17 | checkpoint = {
18 | "v": 1,
19 | "ts": "2024-07-31T20:14:19.804150+00:00",
20 | "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
21 | "channel_values": {
22 | "my_key": "meow",
23 | "node": "node"
24 | },
25 | "channel_versions": {
26 | "__start__": 2,
27 | "my_key": 3,
28 | "start:node": 3,
29 | "node": 3
30 | },
31 | "versions_seen": {
32 | "__input__": {},
33 | "__start__": {
34 | "__start__": 1
35 | },
36 | "node": {
37 | "start:node": 2
38 | }
39 | },
40 | "pending_sends": [],
41 | }
42 |
43 | # store checkpoint
44 | checkpointer.put(write_config, checkpoint, {}, {})
45 |
46 | # load checkpoint
47 | checkpointer.get(read_config)
48 |
49 | # list checkpoints
50 | list(checkpointer.list(read_config))
51 | ```
52 |
53 | ### Async
54 |
55 | ```python
56 | from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
57 |
58 | async with AsyncMongoDBSaver.from_conn_string(MONGODB_URI) as checkpointer:
59 | checkpoint = {
60 | "v": 1,
61 | "ts": "2024-07-31T20:14:19.804150+00:00",
62 | "id": "1ef4f797-8335-6428-8001-8a1503f9b875",
63 | "channel_values": {
64 | "my_key": "meow",
65 | "node": "node"
66 | },
67 | "channel_versions": {
68 | "__start__": 2,
69 | "my_key": 3,
70 | "start:node": 3,
71 | "node": 3
72 | },
73 | "versions_seen": {
74 | "__input__": {},
75 | "__start__": {
76 | "__start__": 1
77 | },
78 | "node": {
79 | "start:node": 2
80 | }
81 | },
82 | "pending_sends": [],
83 | }
84 |
85 | # store checkpoint
86 | await checkpointer.aput(write_config, checkpoint, {}, {})
87 |
88 | # load checkpoint
89 | await checkpointer.aget(read_config)
90 |
91 | # list checkpoints
92 | [c async for c in checkpointer.alist(read_config)]
93 | ```
94 |
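95 | ### With a graph
96 | 
97 | A minimal sketch of the typical pattern (the graph shape and names below are illustrative, not part of this package): pass the saver as `checkpointer=` when compiling a `StateGraph`, and each step of every invocation is checkpointed under the `thread_id` in the config.
98 | 
99 | ```python
100 | from typing import TypedDict
101 | 
102 | from langgraph.checkpoint.mongodb import MongoDBSaver
103 | from langgraph.graph import END, StateGraph
104 | 
105 | 
106 | class State(TypedDict):
107 |     value: int
108 | 
109 | 
110 | def increment(state: State) -> State:
111 |     return {"value": state["value"] + 1}
112 | 
113 | 
114 | builder = StateGraph(State)
115 | builder.add_node("increment", increment)
116 | builder.set_entry_point("increment")
117 | builder.add_edge("increment", END)
118 | 
119 | with MongoDBSaver.from_conn_string(MONGODB_URI, DB_NAME) as checkpointer:
120 |     graph = builder.compile(checkpointer=checkpointer)
121 |     config = {"configurable": {"thread_id": "thread-1"}}
122 |     graph.invoke({"value": 1}, config=config)  # each step is checkpointed
123 |     print(graph.get_state(config).values)  # {'value': 2}
124 | ```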
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 | set dotenv-load
3 | set dotenv-filename := "../../.local_atlas_uri"
4 |
5 | # Default target executed when no arguments are given.
6 | [private]
7 | default:
8 | @just --list
9 |
10 | install:
11 | uv sync --frozen
12 |
13 | [group('test')]
14 | integration_tests *args="":
15 | uv run pytest tests/integration_tests/ {{args}}
16 |
17 | [group('test')]
18 | unit_tests *args="":
19 | uv run pytest tests/unit_tests {{args}}
20 |
21 | [group('test')]
22 | tests *args="":
23 | uv run pytest {{args}}
24 |
25 | [group('test')]
26 | test_watch filename:
27 | uv run ptw --snapshot-update --now . -- -vv {{filename}}
28 |
29 | [group('lint')]
30 | lint:
31 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files
32 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files
33 |
34 | [group('lint')]
35 | typing:
36 | uv run mypy .
37 |
38 | [group('lint')]
39 | codespell:
40 | git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files
41 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/__init__.py:
--------------------------------------------------------------------------------
1 | from .aio import AsyncMongoDBSaver
2 | from .saver import MongoDBSaver
3 |
4 | __all__ = ["MongoDBSaver", "AsyncMongoDBSaver"]
5 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | :private:
3 | Utilities for langgraph-checkpoint-mongodb.
4 | """
5 |
6 | from typing import Any, Union
7 |
8 | from langgraph.checkpoint.base import CheckpointMetadata
9 | from langgraph.checkpoint.serde.base import SerializerProtocol
10 | from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
11 |
12 | serde: SerializerProtocol = JsonPlusSerializer()
13 |
14 |
15 | def loads_metadata(metadata: dict[str, Any]) -> CheckpointMetadata:
16 |     """Deserialize a metadata document.
17 | 
18 |     The CheckpointMetadata class itself cannot be stored directly in MongoDB,
19 |     but a plain dictionary can. For efficient filtering in MongoDB,
20 |     dict keys are kept as strings.
21 | 
22 |     Metadata is stored in the MongoDB collection with string keys and
23 |     serde-serialized values; this function reverses that encoding.
24 |     """
25 | if isinstance(metadata, dict):
26 | output = dict()
27 | for key, value in metadata.items():
28 | output[key] = loads_metadata(value)
29 | return output
30 | else:
31 | return serde.loads(metadata)
32 |
33 |
34 | def dumps_metadata(
35 | metadata: Union[CheckpointMetadata, Any],
36 | ) -> Union[bytes, dict[str, Any]]:
37 | """Serialize all values in metadata dictionary.
38 |
39 | Keep dict keys as strings for efficient filtering in MongoDB
40 | """
41 | if isinstance(metadata, dict):
42 | output = dict()
43 | for key, value in metadata.items():
44 | output[key] = dumps_metadata(value)
45 | return output
46 | else:
47 | return serde.dumps(metadata)
48 |
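49 | # A minimal round-trip sketch (illustrative only; these helpers are private):
50 | #
51 | #     meta = {"source": "input", "step": 2, "writes": {"foo": "bar"}}
52 | #     stored = dumps_metadata(meta)  # dict keys stay strings; leaf values become serde bytes
53 | #     assert loads_metadata(stored) == meta  # leaf values deserialize back to Python objects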
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling>1.24"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "langgraph-checkpoint-mongodb"
7 | version = "0.1.3"
8 | description = "Library with a MongoDB implementation of LangGraph checkpoint saver."
9 | readme = "README.md"
10 | requires-python = ">=3.9"
11 | dependencies = [
12 | "langgraph-checkpoint>=2.0.23,<3.0.0",
13 | "langchain-mongodb>=0.6.1",
14 | "pymongo>=4.10,<4.13",
15 | "motor>3.6.0",
16 | ]
17 |
18 | [dependency-groups]
19 | dev = [
20 | "anyio>=4.4.0",
21 | "langchain-core>=0.3.55",
22 | "langchain-ollama>=0.2.2",
23 | "langchain-openai>=0.2.14",
24 | "langgraph>=0.3.23",
25 | "langgraph-checkpoint>=2.0.9",
26 | "pytest-asyncio>=0.21.1",
27 | "pytest>=7.2.1",
28 | "pytest-mock>=3.11.1",
29 | "pytest-watch>=4.2.0",
30 | "pytest-repeat>=0.9.3",
31 | "syrupy>=4.0.2",
32 | "pre-commit>=4.0",
33 | "mypy>=1.10.0",
34 | ]
35 |
36 | [tool.hatch.build.targets.wheel]
37 | packages = ["langgraph"]
38 |
39 | [tool.pytest.ini_options]
40 | addopts = "--strict-markers --strict-config --durations=5 -vv"
41 | markers = [
42 | "requires: mark tests as requiring a specific library",
43 | "compile: mark placeholder test used to compile integration tests without running them",
44 | ]
45 | asyncio_mode = "auto"
46 |
47 | [tool.ruff]
48 | lint.select = [
49 | "E", # pycodestyle
50 | "F", # Pyflakes
51 | "UP", # pyupgrade
52 | "B", # flake8-bugbear
53 | "I", # isort
54 | ]
55 | lint.ignore = ["E501", "B008", "UP007", "UP006"]
56 |
57 | [tool.mypy]
58 | # https://mypy.readthedocs.io/en/stable/config_file.html
59 | disallow_untyped_defs = true
60 | explicit_package_bases = true
61 | warn_no_return = false
62 | warn_unused_ignores = true
63 | warn_redundant_casts = true
64 | allow_redefinition = true
65 | disable_error_code = "typeddict-item, return-value"
66 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/integration_tests/README:
--------------------------------------------------------------------------------
1 | tl;dr: Placeholder until https://jira.mongodb.org/browse/INTPYTHON-492 is resolved.
2 | 
3 | This directory previously held copies of test_pregel.py and test_pregel_async.py,
4 | along with their associated fixtures and utilities, copied from the langgraph
5 | library itself rather than from langgraph-checkpoint-X.
6 | [See here](https://github.com/langchain-ai/langgraph/blob/main/libs/langgraph/tests)
7 | Although those tests do not live in the checkpoint library, they exercised the
8 | checkpointers extensively and were crucial for covering edge cases. The copies
9 | were always intended to be temporary.
10 | 
11 | As LangGraph development continued, our copies began to fail.
12 | We will determine how to proceed with maintenance in conversation with the
13 | Langchain-AI / LangGraph team.
14 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/integration_tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from langchain_core.embeddings import Embeddings
5 | from langchain_ollama.embeddings import OllamaEmbeddings
6 | from langchain_openai import OpenAIEmbeddings
7 |
8 |
9 | @pytest.fixture(scope="session")
10 | def embedding() -> Embeddings:
11 | if os.environ.get("OPENAI_API_KEY"):
12 | return OpenAIEmbeddings(
13 | openai_api_key=os.environ["OPENAI_API_KEY"], # type: ignore # noqa
14 | model="text-embedding-3-small",
15 | )
16 |
17 | return OllamaEmbeddings(model="all-minilm:l6-v2")
18 |
19 |
20 | @pytest.fixture(scope="session")
21 | def dimensions() -> int:
22 | if os.environ.get("OPENAI_API_KEY"):
23 | return 1536
24 | return 384
25 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_interrupts.py:
--------------------------------------------------------------------------------
1 | """Follows https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/time-travel"""
2 |
3 | import os
4 | from collections.abc import Generator
5 | from typing import TypedDict
6 |
7 | import pytest
8 | from langchain_core.runnables import RunnableConfig
9 |
10 | from langgraph.checkpoint.base import BaseCheckpointSaver
11 | from langgraph.checkpoint.memory import InMemorySaver
12 | from langgraph.checkpoint.mongodb import MongoDBSaver
13 | from langgraph.graph import END, StateGraph
14 | from langgraph.graph.graph import CompiledGraph
15 |
16 | # --- Configuration ---
17 | MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
18 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
19 | CHECKPOINT_CLXN_NAME = "interrupts_checkpoints"
20 | WRITES_CLXN_NAME = "interrupts_writes"
21 |
22 |
23 | @pytest.fixture(scope="function")
24 | def checkpointer_memory() -> Generator[InMemorySaver, None, None]:
25 | yield InMemorySaver()
26 |
27 |
28 | @pytest.fixture(scope="function")
29 | def checkpointer_mongodb() -> Generator[MongoDBSaver, None, None]:
30 | with MongoDBSaver.from_conn_string(
31 | MONGODB_URI,
32 | db_name=DB_NAME,
33 | checkpoint_collection_name=CHECKPOINT_CLXN_NAME,
34 | writes_collection_name=WRITES_CLXN_NAME,
35 | ) as checkpointer:
36 | checkpointer.checkpoint_collection.delete_many({})
37 | checkpointer.writes_collection.delete_many({})
38 | yield checkpointer
39 | checkpointer.checkpoint_collection.drop()
40 | checkpointer.writes_collection.drop()
41 |
42 |
43 | ALL_CHECKPOINTERS_SYNC = [
44 | "checkpointer_memory",
45 | "checkpointer_mongodb",
46 | ]
47 |
48 |
49 | @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
50 | def test(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
51 | checkpointer: BaseCheckpointSaver = request.getfixturevalue(checkpointer_name)
52 | assert isinstance(checkpointer, BaseCheckpointSaver)
53 |
54 | # --- State Definition ---
55 | class State(TypedDict):
56 | value: int
57 | step: int
58 |
59 | # --- Node Definitions ---
60 | def node_inc(state: State) -> State:
61 | """Increments value and step by 1"""
62 | current_step = state.get("step", 0)
63 | return {"value": state["value"] + 1, "step": current_step + 1}
64 |
65 | def node_double(state: State) -> State:
66 | """Doubles value and increments step by 1"""
67 | current_step = state.get("step", 0)
68 | return {"value": state["value"] * 2, "step": current_step + 1}
69 |
70 | # --- Graph Construction ---
71 | builder = StateGraph(State)
72 | builder.add_node("increment", node_inc)
73 | builder.add_node("double", node_double)
74 | builder.set_entry_point("increment")
75 | builder.add_edge("increment", "double")
76 | builder.add_edge("double", END)
77 |
78 | # --- Compile Graph (with Interruption) ---
79 | # Using sync for simplicity in this demo
80 | graph: CompiledGraph = builder.compile(
81 | checkpointer=checkpointer, interrupt_after=["increment"]
82 | )
83 |
84 | # --- Configure ---
85 | config: RunnableConfig = {"configurable": {"thread_id": "thread_#1"}}
86 | initial_input = {"value": 10, "step": 0}
87 |
88 | # --- 1st invoke, with Interruption
89 | interrupted_state = graph.invoke(initial_input, config=config)
90 | assert interrupted_state == {"value": 10 + 1, "step": 1}
91 | state_history = list(graph.get_state_history(config))
92 | assert len(state_history) == 3
93 | # The states are returned in reverse chronological order.
94 | assert state_history[0].next == ("double",)
95 |
96 | # --- 2nd invoke, with input=None, and original config ==> continues from point of interruption
97 | final_state = graph.invoke(None, config=config)
98 | assert final_state == {"value": (10 + 1) * 2, "step": 2}
99 | state_history = list(graph.get_state_history(config))
100 | assert len(state_history) == 4
101 | assert state_history[0].next == ()
102 | assert state_history[-1].next == ("__start__",)
103 |
104 | # --- 3rd invoke, but with an input ===> the CompiledGraph is restarted.
105 | new_input = {"value": 100, "step": -100}
106 | third_state = graph.invoke(new_input, config=config)
107 | assert third_state == {"value": 101, "step": -99}
108 |
109 |     # The entire state history is preserved, however
110 | state_history = list(graph.get_state_history(config))
111 | assert len(state_history) == 7
112 | assert state_history[0].next == ("double",)
113 | assert state_history[2].next == ("__start__",)
114 |
115 |     # --- Update state and continue from interrupt
116 | updated_state = {"value": 1000, "step": 1000}
117 | updated_config = graph.update_state(config, updated_state)
118 | final_state = graph.invoke(input=None, config=updated_config)
119 | assert final_state == {"value": 2000, "step": 1001}
120 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_sanity.py:
--------------------------------------------------------------------------------
1 | # Placeholder so that Monorepo CI passes until after https://jira.mongodb.org/browse/INTPYTHON-492 is resolved.
2 | def test_sanity() -> None:
3 | assert True
4 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/unit_tests/conftest.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 | from langchain_core.runnables import RunnableConfig
5 |
6 | from langgraph.checkpoint.base import (
7 | CheckpointMetadata,
8 | create_checkpoint,
9 | empty_checkpoint,
10 | )
11 |
12 |
13 | @pytest.fixture(scope="session")
14 | def input_data() -> dict:
15 |     """Set up test inputs and store them conveniently in a single dictionary."""
16 | inputs: dict[str, Any] = {}
17 |
18 | inputs["config_1"] = RunnableConfig(
19 | configurable=dict(thread_id="thread-1", thread_ts="1", checkpoint_ns="")
20 | ) # config_1 tests deprecated thread_ts
21 |
22 | inputs["config_2"] = RunnableConfig(
23 | configurable=dict(thread_id="thread-2", checkpoint_id="2", checkpoint_ns="")
24 | )
25 |
26 | inputs["config_3"] = RunnableConfig(
27 | configurable=dict(
28 | thread_id="thread-2", checkpoint_id="2-inner", checkpoint_ns="inner"
29 | )
30 | )
31 |
32 | inputs["chkpnt_1"] = empty_checkpoint()
33 | inputs["chkpnt_2"] = create_checkpoint(inputs["chkpnt_1"], {}, 1)
34 | inputs["chkpnt_3"] = empty_checkpoint()
35 |
36 | inputs["metadata_1"] = CheckpointMetadata(
37 | source="input", step=2, writes={}, score=1
38 | )
39 | inputs["metadata_2"] = CheckpointMetadata(
40 | source="loop", step=1, writes={"foo": "bar"}, score=None
41 | )
42 | inputs["metadata_3"] = CheckpointMetadata()
43 |
44 | return inputs
45 |
--------------------------------------------------------------------------------
/libs/langgraph-checkpoint-mongodb/tests/unit_tests/test_async.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 |
4 | import pytest
5 | from bson.errors import InvalidDocument
6 | from motor.motor_asyncio import AsyncIOMotorClient
7 |
8 | from langgraph.checkpoint.mongodb.aio import AsyncMongoDBSaver
9 |
10 | MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")
11 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
12 | COLLECTION_NAME = "sync_checkpoints_aio"
13 |
14 |
15 | async def test_asearch(input_data: dict[str, Any]) -> None:
16 | # Clear collections if they exist
17 | client: AsyncIOMotorClient = AsyncIOMotorClient(MONGODB_URI)
18 | db = client[DB_NAME]
19 |
20 | for clxn in await db.list_collection_names():
21 | await db.drop_collection(clxn)
22 |
23 | async with AsyncMongoDBSaver.from_conn_string(
24 | MONGODB_URI, DB_NAME, COLLECTION_NAME
25 | ) as saver:
26 | # save checkpoints
27 | await saver.aput(
28 | input_data["config_1"],
29 | input_data["chkpnt_1"],
30 | input_data["metadata_1"],
31 | {},
32 | )
33 | await saver.aput(
34 | input_data["config_2"],
35 | input_data["chkpnt_2"],
36 | input_data["metadata_2"],
37 | {},
38 | )
39 | await saver.aput(
40 | input_data["config_3"],
41 | input_data["chkpnt_3"],
42 | input_data["metadata_3"],
43 | {},
44 | )
45 |
46 | # call method / assertions
47 | query_1 = {"source": "input"} # search by 1 key
48 | query_2 = {
49 | "step": 1,
50 | "writes": {"foo": "bar"},
51 | } # search by multiple keys
52 | query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
53 | query_4 = {"source": "update", "step": 1} # no match
54 |
55 | search_results_1 = [c async for c in saver.alist(None, filter=query_1)]
56 | assert len(search_results_1) == 1
57 | assert search_results_1[0].metadata == input_data["metadata_1"]
58 |
59 | search_results_2 = [c async for c in saver.alist(None, filter=query_2)]
60 | assert len(search_results_2) == 1
61 | assert search_results_2[0].metadata == input_data["metadata_2"]
62 |
63 | search_results_3 = [c async for c in saver.alist(None, filter=query_3)]
64 | assert len(search_results_3) == 3
65 |
66 | search_results_4 = [c async for c in saver.alist(None, filter=query_4)]
67 | assert len(search_results_4) == 0
68 |
69 | # search by config (defaults to checkpoints across all namespaces)
70 | search_results_5 = [
71 | c async for c in saver.alist({"configurable": {"thread_id": "thread-2"}})
72 | ]
73 | assert len(search_results_5) == 2
74 | assert {
75 | search_results_5[0].config["configurable"]["checkpoint_ns"],
76 | search_results_5[1].config["configurable"]["checkpoint_ns"],
77 | } == {"", "inner"}
78 |
79 |
80 | async def test_null_chars(input_data: dict[str, Any]) -> None:
81 |     """In MongoDB, string *values* can be any valid UTF-8, including nulls.
82 |     *Field names*, however, cannot contain null characters."""
83 | async with AsyncMongoDBSaver.from_conn_string(
84 | MONGODB_URI, DB_NAME, COLLECTION_NAME
85 | ) as saver:
86 | null_str = "\x00abc" # string containing null character
87 |
88 | # 1. null string in field *value*
89 | null_value_cfg = await saver.aput(
90 | input_data["config_1"],
91 | input_data["chkpnt_1"],
92 | {"my_key": null_str},
93 | {},
94 | )
95 | null_tuple = await saver.aget_tuple(null_value_cfg)
96 | assert null_tuple.metadata["my_key"] == null_str # type: ignore
97 | cps = [c async for c in saver.alist(None, filter={"my_key": null_str})]
98 | assert cps[0].metadata["my_key"] == null_str
99 |
100 | # 2. null string in field *name*
101 | with pytest.raises(InvalidDocument):
102 | await saver.aput(
103 | input_data["config_1"],
104 | input_data["chkpnt_1"],
105 | {null_str: "my_value"}, # type: ignore
106 | {},
107 | )
108 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ---
4 |
5 | ## Changes in version 0.0.1 (2025/05/09)
6 |
7 | - Initial release, added support for `MongoDBStore`.
8 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/README.md:
--------------------------------------------------------------------------------
1 | # langgraph-store-mongodb
2 |
3 | LangGraph long-term memory using MongoDB.
4 |
5 | The implementation is feature complete (v0.0.1), but we may revise the API based on feedback and further usage patterns.
6 |
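7 | ## Usage
8 | 
9 | A minimal sketch of basic use (the connection string and names below are illustrative), using the standard LangGraph `BaseStore` methods:
10 | 
11 | ```python
12 | from pymongo import MongoClient
13 | 
14 | from langgraph.store.mongodb import MongoDBStore
15 | 
16 | client = MongoClient("mongodb://localhost:27017")
17 | collection = client["langgraph-test"]["long_term_memory"]
18 | 
19 | store = MongoDBStore(collection)
20 | 
21 | # Namespaces are tuples of strings; values are dicts.
22 | store.put(("users", "alice"), "profile", {"language": "en"})
23 | item = store.get(("users", "alice"), "profile")
24 | assert item is not None and item.value == {"language": "en"}
25 | ```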
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/justfile:
--------------------------------------------------------------------------------
1 | set shell := ["bash", "-c"]
2 | set dotenv-load
3 | set dotenv-filename := "../../.local_atlas_uri"
4 |
5 | # Default target executed when no arguments are given.
6 | [private]
7 | default:
8 | @just --list
9 |
10 | install:
11 | uv sync --frozen
12 |
13 | [group('test')]
14 | integration_tests *args="":
15 | uv run pytest tests/integration_tests/ {{args}}
16 |
17 | [group('test')]
18 | unit_tests *args="":
19 | uv run pytest tests/unit_tests {{args}}
20 |
21 | [group('test')]
22 | tests *args="":
23 | uv run pytest {{args}}
24 |
25 | [group('test')]
26 | test_watch filename:
27 | uv run ptw --snapshot-update --now . -- -vv {{filename}}
28 |
29 | [group('lint')]
30 | lint:
31 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff --files
32 | git ls-files -- '*.py' | xargs uv run pre-commit run ruff-format --files
33 |
34 | [group('lint')]
35 | typing:
36 | uv run mypy .
37 |
38 | [group('lint')]
39 | codespell:
40 | git ls-files -- '*.py' | xargs uv run pre-commit run --hook-stage manual codespell --files
41 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/langgraph/store/mongodb/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import (
2 | MongoDBStore,
3 | VectorIndexConfig,
4 | create_vector_index_config,
5 | )
6 |
7 | __all__ = ["MongoDBStore", "VectorIndexConfig", "create_vector_index_config"]
8 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling>1.24"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "langgraph-store-mongodb"
7 | version = "0.0.1"
8 | description = "MongoDB implementation of the LangGraph long-term memory store."
9 | readme = "README.md"
10 | requires-python = ">=3.9"
11 | dependencies = [
12 | "langgraph-checkpoint>=2.0.23,<3.0.0",
13 | "langchain-mongodb>=0.6.1",
14 | ]
15 |
16 | [dependency-groups]
17 | dev = [
18 | "pytest-asyncio>=0.21.1",
19 | "pytest>=7.2.1",
20 | "pre-commit>=4.0",
21 | "mypy>=1.10.0",
22 | ]
23 |
24 | [tool.hatch.build.targets.wheel]
25 | packages = ["langgraph"]
26 |
27 | [tool.pytest.ini_options]
28 | addopts = "--strict-markers --strict-config --durations=5 -vv"
29 | markers = [
30 | "requires: mark tests as requiring a specific library",
31 | "compile: mark placeholder test used to compile integration tests without running them",
32 | ]
33 | asyncio_mode = "auto"
34 |
35 | [tool.ruff]
36 | lint.select = [
37 | "E", # pycodestyle
38 | "F", # Pyflakes
39 | "UP", # pyupgrade
40 | "B", # flake8-bugbear
41 | "I", # isort
42 | ]
43 | lint.ignore = ["E501", "B008", "UP007", "UP006"]
44 |
45 | [tool.mypy]
46 | # https://mypy.readthedocs.io/en/stable/config_file.html
47 | disallow_untyped_defs = true
48 | explicit_package_bases = true
49 | warn_no_return = false
50 | warn_unused_ignores = true
51 | warn_redundant_casts = true
52 | allow_redefinition = true
53 | disable_error_code = "typeddict-item, return-value"
54 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/tests/integration_tests/test_store_semantic.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import Generator
3 | from time import monotonic, sleep
4 | from typing import Callable
5 |
6 | import pytest
7 | from langchain_core.embeddings import Embeddings
8 | from pymongo import MongoClient
9 | from pymongo.collection import Collection
10 | from pymongo.errors import OperationFailure
11 |
12 | from langgraph.store.base import PutOp
13 | from langgraph.store.memory import InMemoryStore
14 | from langgraph.store.mongodb import (
15 | MongoDBStore,
16 | create_vector_index_config,
17 | )
18 |
19 | MONGODB_URI = os.environ.get(
20 | "MONGODB_URI", "mongodb://localhost:27017?directConnection=true"
21 | )
22 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
23 | COLLECTION_NAME = "semantic_search"
24 | INDEX_NAME = "vector_index"
25 | TIMEOUT, INTERVAL = 30, 1 # timeout to index new data
26 |
27 | DIMENSIONS = 5 # Dimensions of embedding model
28 |
29 |
30 | def wait(cond: Callable, timeout: int = 15, interval: int = 1) -> None:
31 | start = monotonic()
32 | while monotonic() - start < timeout:
33 | if cond():
34 | return
35 | else:
36 | sleep(interval)
37 |     raise TimeoutError(f"timeout waiting for: {cond}")
38 |
39 |
40 | class StaticEmbeddings(Embeddings):
41 | """ANN Search is not tested here. That is done in langchain-mongodb."""
42 |
43 | def embed_documents(self, texts: list[str]) -> list[list[float]]:
44 | vectors = []
45 | for txt in texts:
46 | vectors.append(self.embed_query(txt))
47 | return vectors
48 |
49 | def embed_query(self, text: str) -> list[float]:
50 | if "pears" in text:
51 | return [1.0] + [0.5] * (DIMENSIONS - 1)
52 | else:
53 | return [0.5] * DIMENSIONS
54 |
55 |
56 | @pytest.fixture
57 | def collection() -> Generator[Collection, None, None]:
58 | client: MongoClient = MongoClient(MONGODB_URI)
59 | db = client[DB_NAME]
60 | db.drop_collection(COLLECTION_NAME)
61 | collection = db.create_collection(COLLECTION_NAME)
62 | wait(lambda: collection.count_documents({}) == 0, TIMEOUT, INTERVAL)
63 | try:
64 | collection.drop_search_index(INDEX_NAME)
65 | except OperationFailure:
66 | pass
67 | wait(
68 | lambda: len(collection.list_search_indexes().to_list()) == 0, TIMEOUT, INTERVAL
69 | )
70 |
71 | yield collection
72 |
73 | client.close()
74 |
75 |
76 | def test_filters(collection: Collection) -> None:
77 | """Test permutations of namespace_prefix in filter."""
78 |
79 | index_config = create_vector_index_config(
80 | name=INDEX_NAME,
81 | dims=DIMENSIONS,
82 | fields=["product"],
83 | embed=StaticEmbeddings(), # embedding
84 | filters=["metadata.available"],
85 | )
86 | store_mdb = MongoDBStore(
87 | collection, index_config=index_config, auto_index_timeout=TIMEOUT
88 | )
89 | store_in_mem = InMemoryStore(index=index_config)
90 |
91 | namespaces = [
92 | ("a",),
93 | ("a", "b", "c"),
94 | ("a", "b", "c", "d"),
95 | ]
96 |
97 | products = ["apples", "oranges", "pears"]
98 |
99 | # Add some indexed data
100 | put_ops = []
101 | for i, ns in enumerate(namespaces):
102 | put_ops.append(
103 | PutOp(
104 | namespace=ns,
105 | key=f"id_{i}",
106 | value={
107 | "product": products[i],
108 | "metadata": {"available": bool(i % 2), "grade": "A" * (i + 1)},
109 | },
110 | )
111 | )
112 |
113 | store_mdb.batch(put_ops)
114 | store_in_mem.batch(put_ops)
115 |
116 | query = "What is the grade of our pears?"
117 |     # Case 1: filter on top-level namespace
118 | namespace_prefix = ("a",) # filter ("a",) catches all docs
119 | wait(
120 | lambda: len(store_mdb.search(namespace_prefix, query=query)) == len(products),
121 | TIMEOUT,
122 | INTERVAL,
123 | )
124 |
125 | result_mdb = store_mdb.search(namespace_prefix, query=query)
126 | assert result_mdb[0].value["product"] == "pears" # test sorted by score
127 |
128 | result_mem = store_in_mem.search(namespace_prefix, query=query)
129 | assert len(result_mem) == len(products)
130 |
131 | # Case 2: filter on 2nd namespace in hierarchy
132 | namespace_prefix = ("a", "b")
133 | result_mem = store_in_mem.search(namespace_prefix, query=query)
134 | result_mdb = store_mdb.search(namespace_prefix, query=query)
135 |     # prefix ("a", "b") excludes the doc in namespace ("a",)
136 | assert len(result_mem) == len(result_mdb) == len(products) - 1
137 | assert result_mdb[0].value["product"] == "pears"
138 |
139 | # Case 3: Empty namespace_prefix
140 | namespace_prefix = ("",)
141 | result_mem = store_in_mem.search(namespace_prefix, query=query)
142 | result_mdb = store_mdb.search(namespace_prefix, query=query)
143 | assert len(result_mem) == len(result_mdb) == 0
144 |
145 | # Case 4: With filter on value (nested)
146 | namespace_prefix = ("a",)
147 | available = {"metadata.available": True}
148 | result_mdb = store_mdb.search(namespace_prefix, query=query, filter=available)
149 | assert result_mdb[0].value["product"] == "oranges"
150 | assert len(result_mdb) == 1
151 |
--------------------------------------------------------------------------------
/libs/langgraph-store-mongodb/tests/unit_tests/test_store_async.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import Generator
3 | from typing import Union
4 |
5 | import pytest
6 | from pymongo import MongoClient
7 |
8 | from langgraph.store.base import (
9 | GetOp,
10 | ListNamespacesOp,
11 | PutOp,
12 | SearchOp,
13 | TTLConfig,
14 | )
15 | from langgraph.store.mongodb import (
16 | MongoDBStore,
17 | )
18 |
19 | MONGODB_URI = os.environ.get(
20 | "MONGODB_URI", "mongodb://localhost:27017?directConnection=true"
21 | )
22 | DB_NAME = os.environ.get("DB_NAME", "langgraph-test")
23 | COLLECTION_NAME = "async_store"
24 |
25 |
26 | @pytest.fixture
27 | def store() -> Generator:
28 |     """Create a simple store modeled on base's test_list_namespaces_basic."""
29 | client: MongoClient = MongoClient(MONGODB_URI)
30 | collection = client[DB_NAME][COLLECTION_NAME]
31 | collection.delete_many({})
32 | collection.drop_indexes()
33 |
34 | yield MongoDBStore(
35 | collection,
36 | ttl_config=TTLConfig(default_ttl=3600, refresh_on_read=True),
37 | )
38 |
39 | if client:
40 | client.close()
41 |
42 |
43 | async def test_batch_async(store: MongoDBStore) -> None:
44 | N = 100
45 | M = 5
46 | ops: list[Union[PutOp, GetOp, ListNamespacesOp, SearchOp]] = []
47 | for m in range(M):
48 | for i in range(N):
49 | ops.append(
50 | PutOp(
51 | ("test", "foo", "bar", "baz", str(m % 2)),
52 | f"key{i}",
53 | value={"foo": "bar" + str(i)},
54 | )
55 | )
56 | ops.append(
57 | GetOp(
58 | ("test", "foo", "bar", "baz", str(m % 2)),
59 | f"key{i}",
60 | )
61 | )
62 | ops.append(
63 | ListNamespacesOp(
64 | match_conditions=None,
65 | max_depth=m + 1,
66 | )
67 | )
68 | ops.append(
69 | SearchOp(
70 | ("test",),
71 | )
72 | )
73 | ops.append(
74 | PutOp(
75 | ("test", "foo", "bar", "baz", str(m % 2)),
76 | f"key{i}",
77 | value={"foo": "bar" + str(i)},
78 | )
79 | )
80 | ops.append(
81 | PutOp(("test", "foo", "bar", "baz", str(m % 2)), f"key{i}", None)
82 | )
83 |
84 | results = await store.abatch(ops)
85 | assert len(results) == M * N * 6
86 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "langchain-mongodb-monorepo"
3 | version = "0.1.0"
4 | description = "LangChain MongoDB mono-repo"
5 | readme = "README.md"
6 | requires-python = ">=3.9"
7 | license = { file = "LICENSE" }
8 | dependencies = []
9 |
10 | [dependency-groups]
11 | dev = [
12 | "autodoc-pydantic>=2.2.0",
13 | "pydata-sphinx-theme>=0.16.1",
14 | "sphinx-design>=0.6.1",
15 | "sphinx>=7.4.7",
16 | "sphinx-copybutton>=0.5.2",
17 | "toml>=0.10.2",
18 | "langchain-core>=0.3.30",
19 | "sphinxcontrib-googleanalytics>=0.4",
20 | "langchain-mongodb",
21 | "langgraph-checkpoint-mongodb",
22 | "langgraph-store-mongodb",
23 | "langchain-community>=0.3.14",
24 | "myst-parser>=3.0.1",
25 | "ipython>=8.18.1",
26 | ]
27 |
28 | [tool.uv.sources]
29 | langchain-mongodb = { path = "libs/langchain-mongodb", editable = true }
30 | langgraph-checkpoint-mongodb = { path = "libs/langgraph-checkpoint-mongodb", editable = true }
31 | langgraph-store-mongodb = { path = "libs/langgraph-store-mongodb", editable = true }
32 |
--------------------------------------------------------------------------------
/scripts/setup_ollama.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eu
3 |
4 | ollama pull all-minilm:l6-v2
5 | ollama pull llama3:8b
6 |
--------------------------------------------------------------------------------
/scripts/start_local_atlas.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eu
3 |
4 | echo "Starting the container"
5 |
6 | IMAGE=mongodb/mongodb-atlas-local:latest
7 | DOCKER=$(which docker || which podman)
8 |
9 | $DOCKER pull $IMAGE
10 |
11 | $DOCKER kill mongodb_atlas_local || true
12 |
13 | CONTAINER_ID=$($DOCKER run --rm -d --name mongodb_atlas_local -P $IMAGE)
14 |
15 | function wait() {
16 |     CONTAINER_ID=$1
17 |     echo "waiting for container to become healthy..."
18 |     until [ "$($DOCKER inspect --format='{{.State.Health.Status}}' "$CONTAINER_ID")" == "healthy" ]; do sleep 1; done
19 |     $DOCKER logs mongodb_atlas_local
20 | }
21 | wait "$CONTAINER_ID"
22 |
23 | EXPOSED_PORT=$($DOCKER inspect --format='{{ (index (index .NetworkSettings.Ports "27017/tcp") 0).HostPort }}' "$CONTAINER_ID")
24 | export MONGODB_URI="mongodb://127.0.0.1:$EXPOSED_PORT/?directConnection=true"
25 | SCRIPT_DIR=$(realpath "$(dirname ${BASH_SOURCE[0]})")
26 | ROOT_DIR=$(dirname $SCRIPT_DIR)
27 | echo "MONGODB_URI=$MONGODB_URI" > $ROOT_DIR/.local_atlas_uri
28 |
29 | # Sleep for a bit to let all services start.
30 | sleep 5
31 |
--------------------------------------------------------------------------------
/scripts/update-locks.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eu
3 |
4 | python -m uv lock
5 |
6 | pushd libs/langgraph-checkpoint-mongodb
7 | python -m uv lock
8 | popd
9 |
10 | pushd libs/langchain-mongodb
11 | python -m uv lock
12 | popd
13 | 
14 | pushd libs/langgraph-store-mongodb
15 | python -m uv lock
16 | popd
17 | 
--------------------------------------------------------------------------------