├── .github
│   ├── actions
│   │   └── uv_setup
│   │       └── action.yml
│   └── workflows
│       ├── _lint.yml
│       ├── _release.yml
│       ├── _test.yml
│       └── ci.yml
├── .gitignore
├── CONTRIBUTING.md
├── DEVELOPMENT.md
├── LICENSE
├── Makefile
├── README.md
├── docker-compose.yml
├── docs
│   └── v2_design_overview.md
├── examples
│   ├── migrate_pgvector_to_pgvectorstore.ipynb
│   ├── migrate_pgvector_to_pgvectorstore.md
│   ├── pg_vectorstore.ipynb
│   ├── pg_vectorstore_how_to.ipynb
│   └── vectorstore.ipynb
├── langchain_postgres
│   ├── __init__.py
│   ├── _utils.py
│   ├── chat_message_histories.py
│   ├── py.typed
│   ├── translator.py
│   ├── utils
│   │   └── pgvector_migrator.py
│   ├── v2
│   │   ├── __init__.py
│   │   ├── async_vectorstore.py
│   │   ├── engine.py
│   │   ├── hybrid_search_config.py
│   │   ├── indexes.py
│   │   └── vectorstores.py
│   └── vectorstores.py
├── pyproject.toml
├── security.md
├── tests
│   ├── __init__.py
│   ├── unit_tests
│   │   ├── __init__.py
│   │   ├── fake_embeddings.py
│   │   ├── fixtures
│   │   │   ├── __init__.py
│   │   │   ├── filtering_test_cases.py
│   │   │   └── metadata_filtering_data.py
│   │   ├── query_constructors
│   │   │   ├── __init__.py
│   │   │   └── test_pgvector.py
│   │   ├── test_imports.py
│   │   ├── v1
│   │   │   ├── __init__.py
│   │   │   ├── test_chat_histories.py
│   │   │   ├── test_vectorstore.py
│   │   │   └── test_vectorstore_standard_tests.py
│   │   └── v2
│   │       ├── __init__.py
│   │       ├── test_async_pg_vectorstore.py
│   │       ├── test_async_pg_vectorstore_from_methods.py
│   │       ├── test_async_pg_vectorstore_index.py
│   │       ├── test_async_pg_vectorstore_search.py
│   │       ├── test_engine.py
│   │       ├── test_hybrid_search_config.py
│   │       ├── test_indexes.py
│   │       ├── test_pg_vectorstore.py
│   │       ├── test_pg_vectorstore_from_methods.py
│   │       ├── test_pg_vectorstore_index.py
│   │       ├── test_pg_vectorstore_search.py
│   │       └── test_pg_vectorstore_standard_suite.py
│   └── utils.py
└── uv.lock
/.github/actions/uv_setup/action.yml:
--------------------------------------------------------------------------------
1 | # TODO: https://docs.astral.sh/uv/guides/integration/github/#caching
2 |
3 | name: uv-install
4 | description: Set up Python and uv
5 |
6 | inputs:
7 | python-version:
8 | description: Python version, supporting MAJOR.MINOR only
9 | required: true
10 |
11 | runs:
12 | using: composite
13 | steps:
14 | - name: Install uv and set the python version
15 | uses: astral-sh/setup-uv@v5
16 | with:
17 | version: ${{ env.UV_VERSION }}
18 | python-version: ${{ inputs.python-version }}
19 |
--------------------------------------------------------------------------------
/.github/workflows/_lint.yml:
--------------------------------------------------------------------------------
1 | name: lint
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 | python-version:
11 | required: true
12 | type: string
13 | description: "Python version to use"
14 |
15 | env:
16 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }}
17 |
18 | # This env var allows us to get inline annotations when ruff has complaints.
19 | RUFF_OUTPUT_FORMAT: github
20 | UV_FROZEN: "true"
21 |
22 | jobs:
23 | build:
24 | name: "make lint #${{ inputs.python-version }}"
25 | runs-on: ubuntu-latest
26 | timeout-minutes: 20
27 | steps:
28 | - uses: actions/checkout@v4
29 |
30 | - name: Set up Python ${{ inputs.python-version }} + uv
31 | uses: "./.github/actions/uv_setup"
32 | with:
33 | python-version: ${{ inputs.python-version }}
34 |
35 | - name: Install dependencies
36 | working-directory: ${{ inputs.working-directory }}
37 | run: |
38 | uv sync --group test
39 |
40 | - name: Analysing the code with our lint
41 | working-directory: ${{ inputs.working-directory }}
42 | run: |
43 | make lint
44 |
--------------------------------------------------------------------------------
/.github/workflows/_release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }}
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 | workflow_dispatch:
11 | inputs:
12 | working-directory:
13 | description: "From which folder this pipeline executes"
14 |         default: "."
15 | required: true
16 | type: choice
17 | options:
18 | - "."
19 | dangerous-nonmain-release:
20 | required: false
21 | type: boolean
22 | default: false
23 | description: "Release from a non-main branch (danger!)"
24 |
25 | env:
26 | PYTHON_VERSION: "3.11"
27 | UV_FROZEN: "true"
28 | UV_NO_SYNC: "true"
29 |
30 | jobs:
31 | build:
32 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmain-release
33 | environment: Scheduled testing
34 | runs-on: ubuntu-latest
35 |
36 | outputs:
37 | pkg-name: ${{ steps.check-version.outputs.pkg-name }}
38 | version: ${{ steps.check-version.outputs.version }}
39 |
40 | steps:
41 | - uses: actions/checkout@v4
42 |
43 | - name: Set up Python + uv
44 | uses: "./.github/actions/uv_setup"
45 | with:
46 | python-version: ${{ env.PYTHON_VERSION }}
47 |
48 | # We want to keep this build stage *separate* from the release stage,
49 | # so that there's no sharing of permissions between them.
50 | # The release stage has trusted publishing and GitHub repo contents write access,
51 | # and we want to keep the scope of that access limited just to the release job.
52 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency)
53 | # could get access to our GitHub or PyPI credentials.
54 | #
55 | # Per the trusted publishing GitHub Action:
56 | # > It is strongly advised to separate jobs for building [...]
57 | # > from the publish job.
58 | # https://github.com/pypa/gh-action-pypi-publish#non-goals
59 | - name: Build project for distribution
60 | run: uv build
61 | working-directory: ${{ inputs.working-directory }}
62 | - name: Upload build
63 | uses: actions/upload-artifact@v4
64 | with:
65 | name: dist
66 | path: ${{ inputs.working-directory }}/dist/
67 |
68 | - name: Check Version
69 | id: check-version
70 | shell: python
71 | working-directory: ${{ inputs.working-directory }}
72 | run: |
73 | import os
74 | import tomllib
75 | with open("pyproject.toml", "rb") as f:
76 | data = tomllib.load(f)
77 | pkg_name = data["project"]["name"]
78 | version = data["project"]["version"]
79 | with open(os.environ["GITHUB_OUTPUT"], "a") as f:
80 | f.write(f"pkg-name={pkg_name}\n")
81 | f.write(f"version={version}\n")
82 |
83 | publish:
84 | needs:
85 | - build
86 | runs-on: ubuntu-latest
87 | permissions:
88 | # This permission is used for trusted publishing:
89 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/
90 | #
91 | # Trusted publishing has to also be configured on PyPI for each package:
92 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/
93 | id-token: write
94 |
95 | defaults:
96 | run:
97 | working-directory: ${{ inputs.working-directory }}
98 |
99 | steps:
100 | - uses: actions/checkout@v4
101 |
102 | - name: Set up Python + uv
103 | uses: "./.github/actions/uv_setup"
104 | with:
105 | python-version: ${{ env.PYTHON_VERSION }}
106 |
107 | - uses: actions/download-artifact@v4
108 | with:
109 | name: dist
110 | path: ${{ inputs.working-directory }}/dist/
111 |
112 | - name: Publish package distributions to PyPI
113 | uses: pypa/gh-action-pypi-publish@release/v1
114 | with:
115 | packages-dir: ${{ inputs.working-directory }}/dist/
116 | verbose: true
117 | print-hash: true
118 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0
119 | attestations: false
120 |
121 | mark-release:
122 | needs:
123 | - build
124 | - publish
125 | runs-on: ubuntu-latest
126 | permissions:
127 | # This permission is needed by `ncipollo/release-action` to
128 | # create the GitHub release.
129 | contents: write
130 |
131 | defaults:
132 | run:
133 | working-directory: ${{ inputs.working-directory }}
134 |
135 | steps:
136 | - uses: actions/checkout@v4
137 |
138 | - name: Set up Python + uv
139 | uses: "./.github/actions/uv_setup"
140 | with:
141 | python-version: ${{ env.PYTHON_VERSION }}
142 |
143 | - uses: actions/download-artifact@v4
144 | with:
145 | name: dist
146 | path: ${{ inputs.working-directory }}/dist/
147 |
148 | - name: Create Tag
149 | uses: ncipollo/release-action@v1
150 | with:
151 | artifacts: "dist/*"
152 | token: ${{ secrets.GITHUB_TOKEN }}
153 | generateReleaseNotes: true
154 | tag: ${{needs.build.outputs.pkg-name}}==${{ needs.build.outputs.version }}
155 | body: ${{ needs.release-notes.outputs.release-body }}
156 | commit: main
157 | makeLatest: true
158 |
--------------------------------------------------------------------------------
/.github/workflows/_test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on:
4 | workflow_call:
5 | inputs:
6 | working-directory:
7 | required: true
8 | type: string
9 | description: "From which folder this pipeline executes"
10 |
11 | env:
12 | UV_FROZEN: "true"
13 | UV_NO_SYNC: "true"
14 |
15 | jobs:
16 | build:
17 | defaults:
18 | run:
19 | working-directory: ${{ inputs.working-directory }}
20 | runs-on: ubuntu-latest
21 | services:
22 | postgres:
23 |           # ensure this postgres version stays in sync with the prod database
24 | # and with postgres version used in docker compose
25 | # Testing with postgres that has the pg vector extension
26 | image: ankane/pgvector
27 | env:
28 | # optional (defaults to `postgres`)
29 | POSTGRES_DB: langchain_test
30 | # required
31 | POSTGRES_PASSWORD: langchain
32 | # optional (defaults to `5432`)
33 | POSTGRES_PORT: 5432
34 | # optional (defaults to `postgres`)
35 | POSTGRES_USER: langchain
36 | ports:
37 | # maps tcp port 5432 on service container to the host
38 | - 5432:5432
39 | # set health checks to wait until postgres has started
40 | options: >-
41 | --health-cmd pg_isready
42 | --health-interval 3s
43 | --health-timeout 5s
44 | --health-retries 10
45 | strategy:
46 | matrix:
47 | python-version:
48 | # - "3.9"
49 | # - "3.10"
50 | # - "3.11"
51 | - "3.12"
52 | name: Python ${{ matrix.python-version }}
53 | steps:
54 | - uses: actions/checkout@v4
55 | - name: Install postgresql-client
56 | run: |
57 | sudo apt-get update
58 | sudo apt-get install -y postgresql-client
59 | - name: Test database connection
60 | run: |
61 | # Test psql connection
62 | psql -h localhost -p 5432 -U langchain -d langchain_test -c "SELECT 1;"
63 |
64 | if [ $? -ne 0 ]; then
65 | echo "Postgres connection failed"
66 | exit 1
67 | else
68 | echo "Postgres connection successful"
69 | fi
70 | env:
71 |         # postgres password is required; alternatively, you can run:
72 | # `PGPASSWORD=postgres_password psql ...`
73 | PGPASSWORD: langchain
74 |       - name: Set up Python ${{ matrix.python-version }} + uv
75 | uses: "./.github/actions/uv_setup"
76 | id: setup-python
77 | with:
78 |           python-version: ${{ matrix.python-version }}
79 | - name: Install dependencies
80 | shell: bash
81 | run: uv sync --group test
82 | - name: Run unit tests
83 | shell: bash
84 | run: |
85 | make test
86 | - name: Ensure the tests did not create any additional files
87 | shell: bash
88 | run: |
89 | set -eu
90 |
91 | STATUS="$(git status)"
92 | echo "$STATUS"
93 |
94 | # grep will exit non-zero if the target message isn't found,
95 | # and `set -e` above will cause the step to fail.
96 | echo "$STATUS" | grep 'nothing to commit, working tree clean'
97 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Run CI Tests
3 |
4 | on:
5 | push:
6 | branches: [ main ]
7 | pull_request:
8 |   workflow_dispatch: # Allows triggering the workflow manually from the GitHub UI
9 |
10 | # If another push to the same PR or branch happens while this workflow is still running,
11 | # cancel the earlier run in favor of the next run.
12 | #
13 | # There's no point in testing an outdated version of the code. GitHub only allows
14 | # a limited number of job runners to be active at the same time, so it's better to cancel
15 | # pointless jobs early so that more useful jobs can run sooner.
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.ref }}
18 | cancel-in-progress: true
19 |
20 | env:
21 | UV_FROZEN: "true"
22 | UV_NO_SYNC: "true"
23 | UV_VERSION: "0.5.25"
24 | WORKDIR: "."
25 |
26 | jobs:
27 | lint:
28 | strategy:
29 | matrix:
30 | # Only lint on the min and max supported Python versions.
31 | # It's extremely unlikely that there's a lint issue on any version in between
32 | # that doesn't show up on the min or max versions.
33 | #
34 | # GitHub rate-limits how many jobs can be running at any one time.
35 | # Starting new jobs is also relatively slow,
36 | # so linting on fewer versions makes CI faster.
37 | python-version:
38 | - "3.12"
39 | uses:
40 | ./.github/workflows/_lint.yml
41 | with:
42 | working-directory: "."
43 | python-version: ${{ matrix.python-version }}
44 | secrets: inherit
45 | test:
46 | uses:
47 | ./.github/workflows/_test.yml
48 | with:
49 | working-directory: "."
50 | secrets: inherit
51 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Python template
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 | .DS_Store
163 |
164 | # Pycharm
165 | .idea
166 |
167 | # pyenv virtualenv
168 | .python-version
169 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to langchain-postgres
2 |
3 | This guide is intended to help you get started contributing to langchain-postgres.
4 | As an open-source project in a rapidly developing field, we are extremely open
5 | to contributions, whether it be in the form of a new feature, improved infra, or better documentation.
6 |
7 | To contribute to this project, please follow the [fork and pull request](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
8 |
9 | ## Reporting bugs or suggesting improvements
10 |
11 | Our [GitHub issues](https://github.com/langchain-ai/langchain-postgres/issues) page is kept up to date
12 | with bugs, improvements, and feature requests. There is a taxonomy of labels to help
13 | with sorting and discovery of issues of interest. [See this page](https://github.com/langchain-ai/langchain-postgres/labels) for an overview of
14 | the system we use to tag our issues and pull requests.
15 |
16 | If you're looking for help with your code, consider posting a question on the
17 | [GitHub Discussions board](https://github.com/langchain-ai/langchain/discussions). Please
18 | understand that we won't be able to provide individual support via email. We
19 | also believe that help is much more valuable if it's **shared publicly**,
20 | so that more people can benefit from it.
21 |
22 | - **Describing your issue:** Try to provide as many details as possible. What
23 | exactly goes wrong? _How_ is it failing? Is there an error?
24 | "XY doesn't work" usually isn't that helpful for tracking down problems. Always
25 | remember to include the code you ran and if possible, extract only the relevant
26 | parts and don't just dump your entire script. This will make it easier for us to
27 | reproduce the error.
28 |
29 | - **Sharing long blocks of code or logs:** If you need to include long code,
30 |   logs or tracebacks, you can wrap them in `<details>` and `</details>` tags. This
31 | [collapses the content](https://developer.mozilla.org/en/docs/Web/HTML/Element/details)
32 | so it only becomes visible on click, making the issue easier to read and follow.
33 |
34 | ## Contributing code and documentation
35 |
36 | You can develop langchain-postgres locally and contribute to the Project!
37 |
38 | See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up and using a development environment.
39 |
40 | ## Opening a pull request
41 |
42 | Once you have written and manually tested your change, you can submit the patch to the main repository.
43 |
44 | - Open a new GitHub pull request with the patch against the `main` branch.
45 | - Ensure the PR title follows semantic commit conventions.
46 | - For example, `feat: add new feature`, `fix: correct issue with X`.
47 | - Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
48 |
--------------------------------------------------------------------------------
/DEVELOPMENT.md:
--------------------------------------------------------------------------------
1 | # Setting up a Development Environment
2 |
3 | This document details how to set up a local development environment that will
4 | allow you to contribute changes to the project.
5 |
6 | Acquire sources and create virtualenv.
7 | ```shell
8 | git clone https://github.com/langchain-ai/langchain-postgres
9 | cd langchain-postgres
10 | uv venv --python=3.13
11 | source .venv/bin/activate
12 | ```
13 |
14 | Install package and development dependencies.
15 | ```shell
16 | uv sync --group test
17 | ```
18 |
19 | Start PostgreSQL/PGVector.
20 | ```shell
21 | docker run --rm -it --name pgvector-container \
22 | -e POSTGRES_USER=langchain \
23 | -e POSTGRES_PASSWORD=langchain \
24 | -e POSTGRES_DB=langchain \
25 | -p 6024:5432 pgvector/pgvector:pg16 \
26 | postgres -c log_statement=all
27 | ```
28 |
29 | Invoke test cases.
30 | ```shell
31 | pytest -vvv
32 | ```
33 |
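34 | Run linters and code formatters via the Makefile targets (these wrap `ruff` and `mypy` through `uv run`).
35 | ```shell
36 | make lint
37 | make format
38 | ```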
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 LangChain, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: all lint lint_diff format format_diff test test_watch spell_check spell_fix help
2 |
3 | # Default target executed when no arguments are given to make.
4 | all: help
5 |
6 | ######################
7 | # TESTING AND COVERAGE
8 | ######################
9 |
10 | # Define a variable for the test file path.
11 | TEST_FILE ?= tests/unit_tests/
12 |
13 | test:
14 | uv run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
15 |
16 | test_watch:
17 | uv run ptw . -- $(TEST_FILE)
18 |
19 |
20 | ######################
21 | # LINTING AND FORMATTING
22 | ######################
23 |
24 | # Define a variable for Python and notebook files.
25 | lint format: PYTHON_FILES=.
26 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
27 |
28 | lint lint_diff:
29 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff
30 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff
31 | [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)
32 |
33 | format format_diff:
34 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES)
35 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES)
36 |
37 | spell_check:
38 | uv run codespell --toml pyproject.toml
39 |
40 | spell_fix:
41 | uv run codespell --toml pyproject.toml -w
42 |
43 | ######################
44 | # HELP
45 | ######################
46 |
47 | help:
48 | @echo '===================='
49 | @echo '-- LINTING --'
50 | @echo 'format - run code formatters'
51 | @echo 'lint - run linters'
52 | @echo 'spell_check - run codespell on the project'
53 | @echo 'spell_fix - run codespell on the project and fix the errors'
54 | @echo '-- TESTS --'
55 | @echo 'coverage - run unit tests and generate coverage report'
56 | @echo 'test - run unit tests'
57 | @echo 'test TEST_FILE= - run all tests in file'
58 | @echo '-- DOCUMENTATION tasks are from the top-level Makefile --'
59 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # langchain-postgres
2 |
3 | [](https://github.com/langchain-ai/langchain-postgres/releases)
4 | [](https://github.com/langchain-ai/langchain-postgres/actions/workflows/ci.yml)
5 | [](https://opensource.org/licenses/MIT)
6 | [](https://twitter.com/langchainai)
7 | [](https://discord.gg/6adMQxSpJS)
8 | [](https://github.com/langchain-ai/langchain-postgres/issues)
9 |
10 | The `langchain-postgres` package provides implementations of core LangChain abstractions using `Postgres`.
11 |
12 | The package is released under the MIT license.
13 |
14 | Feel free to use the abstractions as provided, or modify and extend them as appropriate for your own application.
15 |
16 | ## Requirements
17 |
18 | The package supports the [asyncpg](https://github.com/MagicStack/asyncpg) and [psycopg3](https://www.psycopg.org/psycopg3/) drivers.
19 |
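The driver is selected through the SQLAlchemy-style connection URL, for example (host and credentials below are placeholders):

```text
postgresql+asyncpg://user:password@localhost:5432/dbname   # asyncpg
postgresql+psycopg://user:password@localhost:5432/dbname   # psycopg3
```
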
20 | ## Installation
21 |
22 | ```bash
23 | pip install -U langchain-postgres
24 | ```
25 |
26 | ## Vectorstore
27 |
28 | > [!WARNING]
29 | > In v0.0.14+, `PGVector` is deprecated. Please migrate to `PGVectorStore`
30 | > for improved performance and manageability.
31 | > See the [migration guide](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/migrate_pgvector_to_pgvectorstore.ipynb) for details on how to migrate from `PGVector` to `PGVectorStore`.
32 |
33 | ### Documentation
34 |
35 | * [Quickstart](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/pg_vectorstore.ipynb)
36 | * [How-to](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/pg_vectorstore_how_to.ipynb)
37 |
38 | ### Example
39 |
40 | ```python
41 | from langchain_core.documents import Document
42 | from langchain_core.embeddings import DeterministicFakeEmbedding
43 | from langchain_postgres import PGEngine, PGVectorStore
44 |
45 | # Replace the connection string with your own Postgres connection string
46 | CONNECTION_STRING = "postgresql+psycopg://langchain:langchain@localhost:6024/langchain"
47 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
48 |
49 | # Replace the vector size with your own vector size
50 | VECTOR_SIZE = 768
51 | embedding = DeterministicFakeEmbedding(size=VECTOR_SIZE)
52 |
53 | TABLE_NAME = "my_doc_collection"
54 |
55 | engine.init_vectorstore_table(
56 | table_name=TABLE_NAME,
57 | vector_size=VECTOR_SIZE,
58 | )
59 |
60 | store = PGVectorStore.create_sync(
61 | engine=engine,
62 | table_name=TABLE_NAME,
63 | embedding_service=embedding,
64 | )
65 |
66 | docs = [
67 | Document(page_content="Apples and oranges"),
68 | Document(page_content="Cars and airplanes"),
69 | Document(page_content="Train")
70 | ]
71 |
72 | store.add_documents(docs)
73 |
74 | query = "I'd like a fruit."
75 | docs = store.similarity_search(query)
76 | print(docs)
77 | ```
78 |
79 | > [!TIP]
80 | > All synchronous functions have corresponding asynchronous functions
81 |
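For example, the vector store calls above can be awaited from async code (a sketch reusing the objects defined in the example; the `a`-prefixed methods are the async counterparts from the LangChain `VectorStore` interface):

```python
# Async sketch: assumes `engine`, `TABLE_NAME`, `embedding`, and `docs` from the
# example above, and must run inside an async function / event loop.
store = await PGVectorStore.create(
    engine,
    table_name=TABLE_NAME,
    embedding_service=embedding,
)
await store.aadd_documents(docs)
results = await store.asimilarity_search("I'd like a fruit.")
```
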
82 | ## ChatMessageHistory
83 |
84 | The chat message history abstraction helps to persist chat message history
85 | in a postgres table.
86 |
87 | PostgresChatMessageHistory is parameterized using a `table_name` and a `session_id`.
88 |
89 | The `table_name` is the name of the table in the database where
90 | the chat messages will be stored.
91 |
92 | The `session_id` is a unique identifier for the chat session. It can be assigned
93 | by the caller using `uuid.uuid4()`.
94 |
95 | ```python
96 | import uuid
97 |
98 | from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
99 | from langchain_postgres import PostgresChatMessageHistory
100 | import psycopg
101 |
102 | # Establish a synchronous connection to the database
103 | # (or use psycopg.AsyncConnection for async)
104 | conn_info = ... # Fill in with your connection info
105 | sync_connection = psycopg.connect(conn_info)
106 |
107 | # Create the table schema (only needs to be done once)
108 | table_name = "chat_history"
109 | PostgresChatMessageHistory.create_tables(sync_connection, table_name)
110 |
111 | session_id = str(uuid.uuid4())
112 |
113 | # Initialize the chat history manager
114 | chat_history = PostgresChatMessageHistory(
115 | table_name,
116 | session_id,
117 | sync_connection=sync_connection
118 | )
119 |
120 | # Add messages to the chat history
121 | chat_history.add_messages([
122 | SystemMessage(content="Meow"),
123 | AIMessage(content="woof"),
124 | HumanMessage(content="bark"),
125 | ])
126 |
127 | print(chat_history.messages)
128 | ```
129 |
130 | ## Google Cloud Integrations
131 |
132 | [Google Cloud](https://python.langchain.com/docs/integrations/providers/google/) provides Vector Store, Chat Message History, and Data Loader integrations for [AlloyDB](https://cloud.google.com/alloydb) and [Cloud SQL](https://cloud.google.com/sql) for PostgreSQL databases via the following PyPI packages:
133 |
134 | * [`langchain-google-alloydb-pg`](https://github.com/googleapis/langchain-google-alloydb-pg-python)
135 |
136 | * [`langchain-google-cloud-sql-pg`](https://github.com/googleapis/langchain-google-cloud-sql-pg-python)
137 |
138 | Using the Google Cloud integrations provides the following benefits:
139 |
140 | - **Enhanced Security**: Securely connect to Google Cloud databases utilizing IAM for authorization and database authentication without needing to manage SSL certificates, configure firewall rules, or enable authorized networks.
141 | - **Simplified and Secure Connections:** Connect to Google Cloud databases effortlessly using the instance name instead of complex connection strings. The integrations create a secure connection pool that can be easily shared across your application using the `engine` object.
142 |
143 | | Vector Store | Metadata filtering | Async support | Schema Flexibility | Improved metadata handling | Hybrid Search |
144 | |--------------------------|--------------------|----------------|--------------------|----------------------------|---------------|
145 | | Google AlloyDB | ✓ | ✓ | ✓ | ✓ | ✗ |
146 | | Google Cloud SQL Postgres| ✓ | ✓ | ✓ | ✓ | ✗ |
147 |
148 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | name: langchain-postgres
2 |
3 | services:
4 | pgvector:
5 | # postgres with the pgvector extension
6 | image: pgvector/pgvector:pg16
7 | environment:
8 | POSTGRES_DB: langchain_test
9 | POSTGRES_USER: langchain
10 | POSTGRES_PASSWORD: langchain
11 | ports:
12 | - "5432:5432"
13 | command: |
14 | postgres -c log_statement=all
15 | healthcheck:
16 | test:
17 | [
18 | "CMD-SHELL",
19 | "psql postgresql://langchain:langchain@localhost/langchain --command 'SELECT 1;' || exit 1",
20 | ]
21 | interval: 5s
22 | retries: 60
23 | volumes:
24 | - postgres_data_pgvector_16:/var/lib/postgresql/data
25 |
26 | volumes:
27 | postgres_data_pgvector_16:
28 |
--------------------------------------------------------------------------------
/docs/v2_design_overview.md:
--------------------------------------------------------------------------------
1 | # Design Overview for `PGVectorStore`
2 |
3 | This document outlines the design choices behind the PGVectorStore integration for LangChain, focusing on how an async PostgreSQL driver can support both synchronous and asynchronous usage.
4 |
5 | ## Motivation: Performance through Asynchronicity
6 |
7 | Database interactions are often I/O-bound, making asynchronous programming crucial for performance.
8 |
9 | - **Non-Blocking Operations:** Asynchronous code prevents the application from stalling while waiting for database responses, improving throughput and responsiveness.
10 | - **Asynchronous Foundation (`asyncio` and Drivers):** Built upon Python's `asyncio` library, the integration is designed to work with asynchronous PostgreSQL drivers to handle database operations efficiently. While compatible drivers are supported, the `asyncpg` driver is specifically recommended due to its high performance in concurrent scenarios. You can explore its benefits ([link](https://magic.io/blog/asyncpg-1m-rows-from-postgres-to-python/)) and performance benchmarks ([link](https://fernandoarteaga.dev/blog/psycopg-vs-asyncpg/)) for more details.
11 |
12 | This native async foundation ensures the core database interactions are fast and scalable.
13 |
14 | ## The Two-Class Approach: Enabling a Mixed Interface
15 |
16 | To cater to different application architectures while maintaining performance, we provide two classes:
17 |
18 | 1. **`AsyncPGVectorStore` (Core Asynchronous Implementation):**
19 | * This class contains the pure `async/await` logic for all database operations.
20 | * It's designed for **direct use within asynchronous applications**. Users working in an `asyncio` environment can `await` its methods for maximum efficiency and direct control within the event loop.
21 | * It represents the fundamental, non-blocking way of interacting with the database.
22 |
23 | 2. **`PGVectorStore` (Mixed Sync/Async API):**
24 | * This class provides both asynchronous & synchronous APIs.
25 | * When one of its methods is called, it internally invokes the corresponding `async` method from `AsyncPGVectorStore`.
26 | * It **manages the execution of this underlying asynchronous logic**, handling the necessary `asyncio` event loop interactions (e.g., starting/running the coroutine) behind the scenes.
27 | * This allows users of synchronous codebases to leverage the performance benefits of the asynchronous core without needing to rewrite their application structure.
28 |
29 | ## Benefits of this Dual Interface Design
30 |
31 | This two-class structure provides significant advantages:
32 |
33 | - **Interface Flexibility:** Developers can **choose the interface that best fits their needs**:
34 | * Use `PGVectorStore` for easy integration into existing synchronous applications.
35 | * Use `AsyncPGVectorStore` for optimal performance and integration within `asyncio`-based applications.
36 | - **Ease of Use:** `PGVectorStore` offers a familiar synchronous programming model, hiding the complexity of managing async execution from the end-user.
37 | - **Robustness:** The clear separation helps prevent common errors associated with mixing synchronous and asynchronous code incorrectly, such as blocking the event loop from synchronous calls within an async context.
38 | - **Efficiency for Async Users:** `AsyncPGVectorStore` provides a direct path for async applications, avoiding any potential overhead from the sync-to-async bridging layer present in `PGVectorStore`.
39 |
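40 | ## Usage sketch
41 | 
42 | A minimal sketch of the two entry points (connection details, table name, and embedding service below are illustrative; see the README and example notebooks for the canonical setup):
43 | 
44 | ```python
45 | from langchain_core.embeddings import DeterministicFakeEmbedding
46 | from langchain_postgres import PGEngine, PGVectorStore
47 | 
48 | CONNECTION_STRING = "postgresql+asyncpg://langchain:langchain@localhost:6024/langchain"
49 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
50 | embedding = DeterministicFakeEmbedding(size=768)
51 | engine.init_vectorstore_table(table_name="design_demo", vector_size=768)
52 | 
53 | # Synchronous interface: PGVectorStore drives the async core behind the scenes.
54 | store = PGVectorStore.create_sync(
55 |     engine=engine, table_name="design_demo", embedding_service=embedding
56 | )
57 | store.add_texts(["hello world"])
58 | print(store.similarity_search("hello"))
59 | 
60 | 
61 | # Asynchronous interface: await the async factory and the a-prefixed methods,
62 | # or use AsyncPGVectorStore directly in a pure-async application.
63 | async def main() -> None:
64 |     astore = await PGVectorStore.create(
65 |         engine, table_name="design_demo", embedding_service=embedding
66 |     )
67 |     await astore.aadd_texts(["hello world"])
68 |     print(await astore.asimilarity_search("hello"))
69 | ```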
--------------------------------------------------------------------------------
/examples/migrate_pgvector_to_pgvectorstore.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Migrate a `PGVector` vector store to `PGVectorStore`\n",
8 | "\n",
9 | "This guide shows how to migrate from the [`PGVector`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstores.py) vector store class to the [`PGVectorStore`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstore.py) class.\n",
10 | "\n",
11 | "## Why migrate?\n",
12 | "\n",
13 | "This guide explains how to migrate your vector data from a PGVector-style database (two tables) to an PGVectoStore-style database (one table per collection) for improved performance and manageability.\n",
14 | "\n",
15 | "Migrating to the PGVectorStore interface provides the following benefits:\n",
16 | "\n",
17 | "- **Simplified management**: A single table contains data corresponding to a single collection, making it easier to query, update, and maintain.\n",
18 | "- **Improved metadata handling**: It stores metadata in columns instead of JSON, resulting in significant performance improvements.\n",
19 | "- **Schema flexibility**: The interface allows users to add tables into any database schema.\n",
20 | "- **Improved performance**: The single-table schema can lead to faster query execution, especially for large collections.\n",
21 | "- **Clear separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows.\n",
22 | "- **Secure Connections:** The PGVectorStore interface creates a secure connection pool that can be easily shared across your application using the `engine` object.\n",
23 | "\n",
24 | "## Migration process\n",
25 | "\n",
26 | "> **_NOTE:_** The langchain-core library is installed to use the Fake embeddings service. To use a different embedding service, you'll need to install the appropriate library for your chosen provider. Choose embeddings services from [LangChain's Embedding models](https://python.langchain.com/v0.2/docs/integrations/text_embedding/)."
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "IR54BmgvdHT_"
33 | },
34 | "source": [
35 | "### Library Installation\n",
36 | "Install the integration library, `langchain-postgres`."
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "colab": {
44 | "base_uri": "https://localhost:8080/",
45 | "height": 1000
46 | },
47 | "id": "0ZITIDE160OD",
48 | "outputId": "e184bc0d-6541-4e0a-82d2-1e216db00a2d"
49 | },
50 | "outputs": [],
51 | "source": [
52 | "%pip install --upgrade --quiet langchain-postgres langchain-core SQLAlchemy"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "id": "f8f2830ee9ca1e01",
58 | "metadata": {
59 | "id": "f8f2830ee9ca1e01"
60 | },
61 | "source": [
62 | "## Data Migration"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "id": "OMvzMWRrR6n7",
68 | "metadata": {
69 | "id": "OMvzMWRrR6n7"
70 | },
71 | "source": [
72 | "### Set the postgres connection url\n",
73 | "\n",
74 | "`PGVectorStore` can be used with the `asyncpg` and `psycopg3` drivers."
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "irl7eMFnSPZr",
81 | "metadata": {
82 | "id": "irl7eMFnSPZr"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "# @title Set Your Values Here { display-mode: \"form\" }\n",
87 | "POSTGRES_USER = \"langchain\" # @param {type: \"string\"}\n",
88 | "POSTGRES_PASSWORD = \"langchain\" # @param {type: \"string\"}\n",
89 | "POSTGRES_HOST = \"localhost\" # @param {type: \"string\"}\n",
90 | "POSTGRES_PORT = \"6024\" # @param {type: \"string\"}\n",
91 | "POSTGRES_DB = \"langchain\" # @param {type: \"string\"}"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "id": "QuQigs4UoFQ2",
97 | "metadata": {
98 | "id": "QuQigs4UoFQ2"
99 | },
100 | "source": [
101 | "### PGEngine Connection Pool\n",
102 | "\n",
103 | "One of the requirements and arguments to establish PostgreSQL as a vector store is a `PGEngine` object. The `PGEngine` configures a shared connection pool to your Postgres database. This is an industry best practice to manage number of connections and to reduce latency through cached database connections.\n",
104 | "\n",
105 | "To create a `PGEngine` using `PGEngine.from_connection_string()` you need to provide:\n",
106 | "\n",
107 | "1. `url` : Connection string using the `postgresql+asyncpg` driver.\n"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "# See docker command above to launch a Postgres instance with pgvector enabled.\n",
124 | "CONNECTION_STRING = (\n",
125 | " f\"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}\"\n",
126 | " f\":{POSTGRES_PORT}/{POSTGRES_DB}\"\n",
127 | ")\n",
128 | "# To use psycopg3 driver, set your connection string to `postgresql+psycopg://`"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "from langchain_postgres import PGEngine\n",
138 | "\n",
139 | "engine = PGEngine.from_connection_string(url=CONNECTION_STRING)"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "To create a `PGEngine` using `PGEngine.from_engine()` you need to provide:\n",
147 | "\n",
148 | "1. `engine` : An object of `AsyncEngine`"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "from sqlalchemy.ext.asyncio import create_async_engine\n",
158 | "\n",
159 | "# Create an SQLAlchemy Async Engine\n",
160 | "pool = create_async_engine(\n",
161 | " CONNECTION_STRING,\n",
162 | ")\n",
163 | "\n",
164 | "engine = PGEngine.from_engine(engine=pool)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {},
170 | "source": [
171 | "### Get all collections\n",
172 | "\n",
173 | "This script migrates each collection to a new Vector Store table."
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "from langchain_postgres.utils.pgvector_migrator import alist_pgvector_collection_names\n",
183 | "\n",
184 | "all_collection_names = await alist_pgvector_collection_names(engine)\n",
185 | "print(all_collection_names)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {
191 | "id": "D9Xs2qhm6X56"
192 | },
193 | "source": [
194 | "### Create a new table(s) to migrate existing data\n",
195 | "The `PGVectorStore` class requires a database table. The `PGEngine` engine has a helper method `ainit_vectorstore_table()` that can be used to create a table with the proper schema for you."
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n",
203 | "\n",
204 | "```python\n",
205 | "SCHEMA_NAME=\"my_schema\"\n",
206 | "\n",
207 | "await engine.ainit_vectorstore_table(\n",
208 | " table_name=TABLE_NAME,\n",
209 | " vector_size=768,\n",
210 | " schema_name=SCHEMA_NAME, # Default: \"public\"\n",
211 | ")\n",
212 | "```\n",
213 | "\n",
214 | "When creating your vectorstore table, you have the flexibility to define custom metadata and ID columns. This is particularly useful for:\n",
215 | "\n",
216 | "- **Filtering**: Metadata columns allow you to easily filter your data within the vectorstore. For example, you might store the document source, date, or author as metadata for efficient retrieval.\n",
217 | "- **Non-UUID Identifiers**: By default, the id_column uses UUIDs. If you need to use a different type of ID (e.g., an integer or string), you can define a custom id_column.\n",
218 | "\n",
219 | "```python\n",
220 | "metadata_columns = [\n",
221 | " Column(f\"col_0_{collection_name}\", \"VARCHAR\"),\n",
222 | " Column(f\"col_1_{collection_name}\", \"VARCHAR\"),\n",
223 | "]\n",
224 | "engine.init_vectorstore_table(\n",
225 | " table_name=\"destination_table\",\n",
226 | " vector_size=VECTOR_SIZE,\n",
227 | " metadata_columns=metadata_columns,\n",
228 | " id_column=Column(\"langchain_id\", \"VARCHAR\"),\n",
229 | ")"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {
236 | "id": "avlyHEMn6gzU"
237 | },
238 | "outputs": [],
239 | "source": [
240 | "# Vertex AI embeddings uses a vector size of 768.\n",
241 | "# Adjust this according to your embeddings service.\n",
242 | "VECTOR_SIZE = 768\n",
243 | "for collection_name in all_collection_names:\n",
244 | " engine.init_vectorstore_table(\n",
245 | " table_name=collection_name,\n",
246 | " vector_size=VECTOR_SIZE,\n",
247 | " )"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "### Create a vector store and migrate data\n",
255 | "\n",
256 | "> **_NOTE:_** The `FakeEmbeddings` embedding service is only used to initialize a vector store object, not to generate any embeddings. The embeddings are copied directly from the PGVector table.\n",
257 | "\n",
258 | "If you have any customizations on the metadata or the id columns, add them to the vector store as follows:\n",
259 | "\n",
260 | "```python\n",
261 | "from langchain_postgres import PGVectorStore\n",
262 | "from langchain_core.embeddings import FakeEmbeddings\n",
263 | "\n",
264 | "destination_vector_store = PGVectorStore.create_sync(\n",
265 | " engine,\n",
266 | " embedding_service=FakeEmbeddings(size=VECTOR_SIZE),\n",
267 | " table_name=DESTINATION_TABLE_NAME,\n",
268 | " metadata_columns=[col.name for col in metadata_columns],\n",
269 | " id_column=\"langchain_id\",\n",
270 | ")\n",
271 | "```"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {
278 | "id": "z-AZyzAQ7bsf"
279 | },
280 | "outputs": [],
281 | "source": [
282 | "from langchain_core.embeddings import FakeEmbeddings\n",
283 | "\n",
284 | "from langchain_postgres import PGVectorStore\n",
285 | "from langchain_postgres.utils.pgvector_migrator import amigrate_pgvector_collection\n",
286 | "\n",
287 | "for collection_name in all_collection_names:\n",
288 | " destination_vector_store = await PGVectorStore.create(\n",
289 | " engine,\n",
290 | " embedding_service=FakeEmbeddings(size=VECTOR_SIZE),\n",
291 | " table_name=collection_name,\n",
292 | " )\n",
293 | "\n",
294 | " await amigrate_pgvector_collection(\n",
295 | " engine,\n",
296 | " # Set collection name here\n",
297 | " collection_name=collection_name,\n",
298 | " vector_store=destination_vector_store,\n",
299 | " # This deletes data from the original table upon migration. You can choose to turn it off.\n",
300 | " # The data will only be deleted from the original table once all of it has been successfully copied to the destination table.\n",
301 | " delete_pg_collection=True,\n",
302 | " )"
303 | ]
304 | }
305 | ],
306 | "metadata": {
307 | "colab": {
308 | "provenance": [],
309 | "toc_visible": true
310 | },
311 | "kernelspec": {
312 | "display_name": "Python 3",
313 | "name": "python3"
314 | },
315 | "language_info": {
316 | "codemirror_mode": {
317 | "name": "ipython",
318 | "version": 3
319 | },
320 | "file_extension": ".py",
321 | "mimetype": "text/x-python",
322 | "name": "python",
323 | "nbconvert_exporter": "python",
324 | "pygments_lexer": "ipython3",
325 | "version": "3.12.3"
326 | }
327 | },
328 | "nbformat": 4,
329 | "nbformat_minor": 0
330 | }
331 |
--------------------------------------------------------------------------------
/examples/migrate_pgvector_to_pgvectorstore.md:
--------------------------------------------------------------------------------
1 | # Migrate a `PGVector` vector store to `PGVectorStore`
2 |
3 | This guide shows how to migrate from the [`PGVector`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstores.py) vector store class to the [`PGVectorStore`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/v2/vectorstores.py) class.
4 |
5 | ## Why migrate?
6 |
7 | This guide explains how to migrate your vector data from a PGVector-style database (two tables) to a PGVectorStore-style database (one table per collection) for improved performance and manageability.
8 |
9 | Migrating to the PGVectorStore interface provides the following benefits:
10 |
11 | - **Simplified management**: A single table contains data corresponding to a single collection, making it easier to query, update, and maintain.
12 | - **Improved metadata handling**: It stores metadata in columns instead of JSON, resulting in significant performance improvements.
13 | - **Schema flexibility**: The interface allows users to add tables into any database schema.
14 | - **Improved performance**: The single-table schema can lead to faster query execution, especially for large collections.
15 | - **Clear separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows.
16 | - **Secure Connections:** The PGVectorStore interface creates a secure connection pool that can be easily shared across your application using the `engine` object.
17 |
18 | ## Migration process
19 |
20 | > **_NOTE:_** The langchain-core library is installed to use the Fake embeddings service. To use a different embedding service, you'll need to install the appropriate library for your chosen provider. Choose embeddings services from [LangChain's Embedding models](https://python.langchain.com/v0.2/docs/integrations/text_embedding/).
21 |
22 | While you can use the existing PGVector database, we **strongly recommend** migrating your data to the PGVectorStore-style schema to take full advantage of the performance benefits.
23 |
24 | ### (Recommended) Data migration
25 |
26 | 1. **Create a PG engine.**
27 |
28 | ```python
29 | from langchain_postgres import PGEngine
30 |
31 | # Replace these variable values
32 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
33 | ```
34 |
35 | > **_NOTE:_** All sync methods have corresponding async methods.
36 |
37 | 2. **Create a new table to migrate existing data.**
38 |
39 | ```python
40 | # Vertex AI embeddings uses a vector size of 768.
41 | # Adjust this according to your embeddings service.
42 | VECTOR_SIZE = 768
43 |
44 | engine.init_vectorstore_table(
45 | table_name="destination_table",
46 | vector_size=VECTOR_SIZE,
47 | )
48 | ```
49 |
50 | **(Optional) Customize your table.**
51 |
52 | When creating your vectorstore table, you have the flexibility to define custom metadata and ID columns. This is particularly useful for:
53 |
54 | - **Filtering**: Metadata columns allow you to easily filter your data within the vectorstore. For example, you might store the document source, date, or author as metadata for efficient retrieval.
55 | - **Non-UUID Identifiers**: By default, the id_column uses UUIDs. If you need to use a different type of ID (e.g., an integer or string), you can define a custom id_column.
56 |
57 | ```python
58 | metadata_columns = [
59 | Column(f"col_0_{collection_name}", "VARCHAR"),
60 | Column(f"col_1_{collection_name}", "VARCHAR"),
61 | ]
62 | engine.init_vectorstore_table(
63 | table_name="destination_table",
64 | vector_size=VECTOR_SIZE,
65 | metadata_columns=metadata_columns,
66 | id_column=Column("langchain_id", "VARCHAR"),
67 | )
68 | ```
69 |
70 | 3. **Create a vector store object to interact with the new data.**
71 |
72 | > **_NOTE:_** The `FakeEmbeddings` embedding service is only used to initialise a vector store object, not to generate any embeddings. The embeddings are copied directly from the PGVector table.
73 |
74 | ```python
75 | from langchain_postgres import PGVectorStore
76 | from langchain_core.embeddings import FakeEmbeddings
77 |
78 | destination_vector_store = PGVectorStore.create_sync(
79 | engine,
80 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE),
81 | table_name="destination_table",
82 | )
83 | ```
84 |
85 | If you have any customisations on the metadata or the id columns, add them to the vector store as follows:
86 |
87 | ```python
88 | from langchain_postgres import PGVectorStore
89 | from langchain_core.embeddings import FakeEmbeddings
90 |
91 | destination_vector_store = PGVectorStore.create_sync(
92 | engine,
93 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE),
94 | table_name="destination_table",
95 | metadata_columns=[col.name for col in metadata_columns],
96 | id_column="langchain_id",
97 | )
98 | ```
99 |
100 | 4. **Migrate the data to the new table.**
101 |
102 | ```python
103 | from langchain_postgres.utils.pgvector_migrator import migrate_pgvector_collection
104 | 
105 | migrate_pgvector_collection(
106 | engine,
107 | # Set collection name here
108 | collection_name="collection_name",
109 | vector_store=destination_vector_store,
110 | # This deletes data from the original table upon migration. You can choose to turn it off.
111 | delete_pg_collection=True,
112 | )
113 | ```
114 |
115 | The data will only be deleted from the original table once all of it has been successfully copied to the destination table.
116 |
117 | > **TIP:** If you would like to migrate multiple collections, you can use the `list_pgvector_collection_names` method to get the names of all collections, allowing you to iterate through them.
118 | >
119 | > ```python
120 | > from langchain_postgres.utils.pgvector_migrator import list_pgvector_collection_names
121 | >
122 | > all_collection_names = list_pgvector_collection_names(engine)
123 | > print(all_collection_names)
124 | > ```
125 |
126 | ### (Not Recommended) Use PGVectorStore interface on PGVector databases
127 |
128 | If you choose not to migrate your data, you can still use the PGVectorStore interface with your existing PGVector database. However, you won't benefit from the performance improvements of the PGVectorStore-style schema.
129 |
130 | 1. **Create an PGVectorStore engine.**
131 |
132 | ```python
133 | from langchain_postgres import PGEngine
134 |
135 | # Replace these variable values
136 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
137 | ```
138 |
139 | > **_NOTE:_** All sync methods have corresponding async methods.
140 |
141 | 2. **Create a vector store object to interact with the data.**
142 |
143 | Use the embeddings service used by your database. See [langchain docs](https://python.langchain.com/docs/integrations/text_embedding/) for reference.
144 |
145 | ```python
146 | from langchain_postgres import PGVectorStore
147 | from langchain_core.embeddings import FakeEmbeddings
148 |
149 | vector_store = PGVectorStore.create_sync(
150 | engine=engine,
151 | table_name="langchain_pg_embedding",
152 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE),
153 | content_column="document",
154 | metadata_json_column="cmetadata",
155 | metadata_columns=["collection_id"],
156 | id_column="id",
157 | )
158 | ```
159 |
160 | 3. **Perform similarity search.**
161 |
162 | Filter by collection id:
163 |
164 | ```python
165 | vector_store.similarity_search("query", k=5, filter=f"collection_id='{uuid}'")
166 | ```
167 |
168 | Filter by collection id and metadata:
169 |
170 | ```python
171 | vector_store.similarity_search(
172 | "query", k=5, filter=f"collection_id='{uuid}' and cmetadata->>'col_name' = 'value'"
173 | )
174 | ```
--------------------------------------------------------------------------------
/examples/pg_vectorstore.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# PGVectorStore\n",
8 | "\n",
9 | "`PGVectorStore` is a an implementation of the the LangChain vectorstore abstraction using `postgres` as the backend.\n",
10 | "\n",
11 | "## Requirements\n",
12 | "\n",
13 | "You'll need a PostgreSQL database with the `pgvector` extension enabled.\n",
14 | "\n",
15 | "\n",
16 | "For local development, you can use the following docker command to spin up the database:\n",
17 | "\n",
18 | "```shell\n",
19 | "docker run --name pgvector-container -e POSTGRES_USER=langchain -e POSTGRES_PASSWORD=langchain -e POSTGRES_DB=langchain -p 6024:5432 -d pgvector/pgvector:pg16\n",
20 | "```"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {
26 | "id": "IR54BmgvdHT_"
27 | },
28 | "source": [
29 | "## Install\n",
30 | "\n",
31 | "Install the `langchain-postgres` package."
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "colab": {
39 | "base_uri": "https://localhost:8080/",
40 | "height": 1000
41 | },
42 | "id": "0ZITIDE160OD",
43 | "outputId": "e184bc0d-6541-4e0a-82d2-1e216db00a2d",
44 | "tags": []
45 | },
46 | "outputs": [],
47 | "source": [
48 | "%pip install --upgrade --quiet langchain-postgres"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {
54 | "id": "QuQigs4UoFQ2"
55 | },
56 | "source": [
57 | "## Create an engine\n",
58 | "\n",
59 | "The first step is to create a `PGEngine` instance, which does the following:\n",
60 | "\n",
61 | "1. Allows you to create tables for storing documents and embeddings.\n",
62 | "2. Maintains a connection pool that manages connections to the database. This allows sharing of the connection pool and helps to reduce latency for database calls."
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {
69 | "tags": []
70 | },
71 | "outputs": [],
72 | "source": [
73 | "from langchain_postgres import PGEngine\n",
74 | "\n",
75 | "# See docker command above to launch a Postgres instance with pgvector enabled.\n",
76 | "# Replace these values with your own configuration.\n",
77 | "POSTGRES_USER = \"langchain\"\n",
78 | "POSTGRES_PASSWORD = \"langchain\"\n",
79 | "POSTGRES_HOST = \"localhost\"\n",
80 | "POSTGRES_PORT = \"6024\"\n",
81 | "POSTGRES_DB = \"langchain\"\n",
82 | "\n",
83 | "CONNECTION_STRING = (\n",
84 | " f\"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}\"\n",
85 | " f\":{POSTGRES_PORT}/{POSTGRES_DB}\"\n",
86 | ")\n",
87 | "\n",
88 | "pg_engine = PGEngine.from_connection_string(url=CONNECTION_STRING)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "metadata": {},
94 | "source": [
95 | "To use psycopg3 driver, set your connection string to `postgresql+psycopg://`"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "id": "D9Xs2qhm6X56"
102 | },
103 | "source": [
104 | "## Create a document collection\n",
105 | "\n",
106 | "Use the `PGEngine.ainit_vectorstore_table()` method to create a database table to store the documents and embeddings. This table will be created with appropriate schema."
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {
113 | "tags": []
114 | },
115 | "outputs": [],
116 | "source": [
117 | "TABLE_NAME = \"vectorstore\"\n",
118 | "\n",
119 | "# The vector size (also called embedding size) is determined by the embedding model you use!\n",
120 | "VECTOR_SIZE = 1536"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "Use the `Column` class to customize the table schema. A Column is defined by a name and data type. Any Postgres [data type](https://www.postgresql.org/docs/current/datatype.html) can be used."
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {
134 | "id": "avlyHEMn6gzU",
135 | "tags": []
136 | },
137 | "outputs": [],
138 | "source": [
139 | "from sqlalchemy.exc import ProgrammingError\n",
140 | "\n",
141 | "from langchain_postgres import Column\n",
142 | "\n",
143 | "try:\n",
144 | " await pg_engine.ainit_vectorstore_table(\n",
145 | " table_name=TABLE_NAME,\n",
146 | " vector_size=VECTOR_SIZE,\n",
147 | " metadata_columns=[\n",
148 | " Column(\"likes\", \"INTEGER\"),\n",
149 | " Column(\"location\", \"TEXT\"),\n",
150 | " Column(\"topic\", \"TEXT\"),\n",
151 | " ],\n",
152 | " )\n",
153 | "except ProgrammingError:\n",
154 | " # Catching the exception here\n",
155 | " print(\"Table already exists. Skipping creation.\")"
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {},
161 | "source": [
162 | "### Configure an embeddings model\n",
163 | "\n",
164 | "You need to configure a vectorstore with an embedding model. The embedding model will be used automatically when adding documents and when searching.\n",
165 | "\n",
166 |     "We'll use `langchain-openai` as the embedding model here, but you can use any [LangChain embeddings model](https://python.langchain.com/docs/integrations/text_embedding/)."
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {
173 | "tags": []
174 | },
175 | "outputs": [],
176 | "source": [
177 | "%pip install --upgrade --quiet langchain-openai"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "colab": {
185 | "base_uri": "https://localhost:8080/"
186 | },
187 | "id": "Vb2RJocV9_LQ",
188 | "outputId": "37f5dc74-2512-47b2-c135-f34c10afdcf4",
189 | "tags": []
190 | },
191 | "outputs": [],
192 | "source": [
193 | "from langchain_openai import OpenAIEmbeddings\n",
194 | "\n",
195 | "embedding = OpenAIEmbeddings(model=\"text-embedding-3-small\")"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {
201 | "id": "e1tl0aNx7SWy"
202 | },
203 | "source": [
204 | "## Initialize the vectorstore\n",
205 | "\n",
206 | "Once the schema for the document collection exists, you can initialize a vectorstore that uses the schema.\n",
207 | "\n",
208 |     "You can use the vectorstore to perform basic operations, including:\n",
209 | "\n",
210 | "1. Add documents\n",
211 | "2. Delete documents\n",
212 | "3. Search through the documents"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": null,
218 | "metadata": {
219 | "id": "z-AZyzAQ7bsf",
220 | "tags": []
221 | },
222 | "outputs": [],
223 | "source": [
224 | "from langchain_postgres import PGVectorStore\n",
225 | "\n",
226 | "vectorstore = await PGVectorStore.create(\n",
227 | " engine=pg_engine,\n",
228 | " table_name=TABLE_NAME,\n",
229 | " embedding_service=embedding,\n",
230 | " metadata_columns=[\"location\", \"topic\"],\n",
231 | ")"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "## Add documents\n",
239 | "\n",
240 | "\n",
241 | "You can add documents using the `aadd_documents` method. \n",
242 | "\n",
243 | "* Assign unique IDs to documents to avoid duplicated content in your database.\n",
244 |     "* Adding a document by ID has `upsert` semantics (i.e., create if it does not exist, update if it exists)."
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {
251 | "tags": []
252 | },
253 | "outputs": [],
254 | "source": [
255 | "import uuid\n",
256 | "\n",
257 | "from langchain_core.documents import Document\n",
258 | "\n",
259 | "docs = [\n",
260 | " Document(\n",
261 | " id=uuid.uuid4(),\n",
262 | " page_content=\"there are cats in the pond\",\n",
263 | " metadata={\"likes\": 1, \"location\": \"pond\", \"topic\": \"animals\"},\n",
264 | " ),\n",
265 | " Document(\n",
266 | " id=uuid.uuid4(),\n",
267 | " page_content=\"ducks are also found in the pond\",\n",
268 | " metadata={\"likes\": 30, \"location\": \"pond\", \"topic\": \"animals\"},\n",
269 | " ),\n",
270 | " Document(\n",
271 | " id=uuid.uuid4(),\n",
272 | " page_content=\"fresh apples are available at the market\",\n",
273 | " metadata={\"likes\": 20, \"location\": \"market\", \"topic\": \"food\"},\n",
274 | " ),\n",
275 | " Document(\n",
276 | " id=uuid.uuid4(),\n",
277 | " page_content=\"the market also sells fresh oranges\",\n",
278 | " metadata={\"likes\": 5, \"location\": \"market\", \"topic\": \"food\"},\n",
279 | " ),\n",
280 | "]\n",
281 | "\n",
282 | "\n",
283 | "await vectorstore.aadd_documents(documents=docs)"
284 | ]
285 | },
286 | {
287 | "cell_type": "markdown",
288 | "metadata": {},
289 | "source": [
290 | "## Delete Documents\n",
291 | "\n",
292 | "Documents can be deleted by ID."
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {
299 | "tags": []
300 | },
301 | "outputs": [],
302 | "source": [
303 | "# We'll use the ID of the first doc to delete it\n",
304 | "ids = [docs[0].id]\n",
305 | "await vectorstore.adelete(ids)"
306 | ]
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "metadata": {},
311 | "source": [
312 | "## Search\n",
313 | "\n",
314 | "Search for similar documents using a natural language query."
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "metadata": {
321 | "tags": []
322 | },
323 | "outputs": [],
324 | "source": [
325 | "query = \"I'd like a fruit.\"\n",
326 | "docs = await vectorstore.asimilarity_search(query)\n",
327 | "for doc in docs:\n",
328 | " print(repr(doc))"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "metadata": {},
334 | "source": [
335 | "### Search by vector"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "metadata": {
342 | "tags": []
343 | },
344 | "outputs": [],
345 | "source": [
346 | "query_vector = embedding.embed_query(query)\n",
347 | "docs = await vectorstore.asimilarity_search_by_vector(query_vector, k=2)\n",
348 | "print(docs)"
349 | ]
350 | },
351 | {
352 | "cell_type": "markdown",
353 | "metadata": {},
354 | "source": [
355 | "## Filtering"
356 | ]
357 | },
358 | {
359 | "cell_type": "markdown",
360 | "metadata": {},
361 | "source": [
362 | "To enable search with filters, it is necessary to declare the columns that you want to filter on when creating the table. The vectorstore supports a set of filters that can be applied against the metadata fields of the documents.\n",
363 | "\n",
364 | "`PGVectorStore` currently supports the following operators.\n",
365 | "\n",
366 | "| Operator | Meaning/Category |\n",
367 | "|-----------|-------------------------|\n",
368 | "| \\$eq | Equality (==) |\n",
369 | "| \\$ne | Inequality (!=) |\n",
370 | "| \\$lt | Less than (<) |\n",
371 | "| \\$lte | Less than or equal (<=) |\n",
372 | "| \\$gt | Greater than (>) |\n",
373 | "| \\$gte | Greater than or equal (>=) |\n",
374 | "| \\$in | Special Cased (in) |\n",
375 | "| \\$nin | Special Cased (not in) |\n",
376 | "| \\$between | Special Cased (between) |\n",
377 | "| \\$exists | Special Cased (is null) |\n",
378 | "| \\$like | Text (like) |\n",
379 | "| \\$ilike | Text (case-insensitive like) |\n",
380 | "| \\$and | Logical (and) |\n",
381 | "| \\$or | Logical (or) |\n"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "metadata": {
388 | "tags": []
389 | },
390 | "outputs": [],
391 | "source": [
392 | "await vectorstore.asimilarity_search(\n",
393 | " \"birds\", filter={\"$or\": [{\"topic\": \"animals\"}, {\"location\": \"market\"}]}\n",
394 | ")"
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": null,
400 | "metadata": {
401 | "tags": []
402 | },
403 | "outputs": [],
404 | "source": [
405 | "await vectorstore.asimilarity_search(\"apple\", filter={\"topic\": \"food\"})"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": null,
411 | "metadata": {
412 | "tags": []
413 | },
414 | "outputs": [],
415 | "source": [
416 | "await vectorstore.asimilarity_search(\n",
417 | " \"apple\", filter={\"topic\": {\"$in\": [\"food\", \"animals\"]}}\n",
418 | ")"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": null,
424 | "metadata": {
425 | "tags": []
426 | },
427 | "outputs": [],
428 | "source": [
429 | "await vectorstore.asimilarity_search(\n",
430 | " \"sales of fruit\", filter={\"topic\": {\"$ne\": \"animals\"}}\n",
431 | ")"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {},
437 | "source": [
438 | "## Optimization\n",
439 | "\n",
440 | "Speed up vector search queries by adding appropriate indexes. Learn more about [vector indexes](https://cloud.google.com/blog/products/databases/faster-similarity-search-performance-with-pgvector-indexes)."
441 | ]
442 | },
443 | {
444 | "cell_type": "markdown",
445 | "metadata": {},
446 | "source": [
447 | "### Add an Index"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": null,
453 | "metadata": {
454 | "tags": []
455 | },
456 | "outputs": [],
457 | "source": [
458 | "from langchain_postgres.v2.indexes import IVFFlatIndex\n",
459 | "\n",
460 | "index = IVFFlatIndex() # Add an index using a default index name\n",
461 | "await vectorstore.aapply_vector_index(index)"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | "### Re-index\n",
469 | "\n",
470 | "Rebuild an index using the data stored in the index's table, replacing the old copy of the index. Some index types may require re-indexing after a considerable amount of new data is added."
471 | ]
472 | },
473 | {
474 | "cell_type": "code",
475 | "execution_count": null,
476 | "metadata": {
477 | "tags": []
478 | },
479 | "outputs": [],
480 | "source": [
481 | "await vectorstore.areindex() # Re-index using default index name"
482 | ]
483 | },
484 | {
485 | "cell_type": "markdown",
486 | "metadata": {},
487 | "source": [
488 | "### Drop an index\n",
489 | "\n",
490 |     "You can delete indexes that are no longer needed."
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "metadata": {
497 | "tags": []
498 | },
499 | "outputs": [],
500 | "source": [
501 | "await vectorstore.adrop_vector_index() # Drop index using default name"
502 | ]
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {},
507 | "source": [
508 | "## Clean up\n",
509 | "\n",
510 |     "**⚠️ WARNING: this cannot be undone**\n",
511 | "\n",
512 | "Drop the vector store table."
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": null,
518 | "metadata": {},
519 | "outputs": [],
520 | "source": [
521 | "await pg_engine.adrop_table(TABLE_NAME)"
522 | ]
523 | }
524 | ],
525 | "metadata": {
526 | "colab": {
527 | "provenance": [],
528 | "toc_visible": true
529 | },
530 | "kernelspec": {
531 | "display_name": "Python 3 (ipykernel)",
532 | "language": "python",
533 | "name": "python3"
534 | },
535 | "language_info": {
536 | "codemirror_mode": {
537 | "name": "ipython",
538 | "version": 3
539 | },
540 | "file_extension": ".py",
541 | "mimetype": "text/x-python",
542 | "name": "python",
543 | "nbconvert_exporter": "python",
544 | "pygments_lexer": "ipython3",
545 | "version": "3.11.4"
546 | }
547 | },
548 | "nbformat": 4,
549 | "nbformat_minor": 4
550 | }
551 |
--------------------------------------------------------------------------------
/langchain_postgres/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 |
3 | from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
4 | from langchain_postgres.translator import PGVectorTranslator
5 | from langchain_postgres.v2.engine import Column, ColumnDict, PGEngine
6 | from langchain_postgres.v2.vectorstores import PGVectorStore
7 | from langchain_postgres.vectorstores import PGVector
8 |
9 | try:
10 | __version__ = metadata.version(__package__)
11 | except metadata.PackageNotFoundError:
12 | # Case where package metadata is not available.
13 | __version__ = ""
14 |
15 | __all__ = [
16 | "__version__",
17 | "Column",
18 | "ColumnDict",
19 | "PGEngine",
20 | "PostgresChatMessageHistory",
21 | "PGVector",
22 | "PGVectorStore",
23 | "PGVectorTranslator",
24 | ]
25 |
--------------------------------------------------------------------------------
/langchain_postgres/_utils.py:
--------------------------------------------------------------------------------
1 | """Copied over from langchain_community.
2 |
3 | This code should be moved to langchain proper or removed entirely.
4 | """
5 |
6 | import logging
7 | from typing import List, Union
8 |
9 | import numpy as np
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
14 |
15 |
16 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
17 | """Row-wise cosine similarity between two equal-width matrices."""
18 | if len(X) == 0 or len(Y) == 0:
19 | return np.array([])
20 |
21 | X = np.array(X)
22 | Y = np.array(Y)
23 | if X.shape[1] != Y.shape[1]:
24 | raise ValueError(
25 | f"Number of columns in X and Y must be the same. X has shape {X.shape} "
26 | f"and Y has shape {Y.shape}."
27 | )
28 | try:
29 | import simsimd as simd # type: ignore
30 |
31 | X = np.array(X, dtype=np.float32)
32 | Y = np.array(Y, dtype=np.float32)
33 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine"))
34 | return Z
35 | except ImportError:
36 | logger.debug(
37 | "Unable to import simsimd, defaulting to NumPy implementation. If you want "
38 | "to use simsimd please install with `pip install simsimd`."
39 | )
40 | X_norm = np.linalg.norm(X, axis=1)
41 | Y_norm = np.linalg.norm(Y, axis=1)
42 |         # Ignore divide-by-zero and invalid-value runtime warnings; those cases are handled below.
43 | with np.errstate(divide="ignore", invalid="ignore"):
44 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
45 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
46 | return similarity
47 |
48 |
49 | def maximal_marginal_relevance(
50 | query_embedding: np.ndarray,
51 | embedding_list: list,
52 | lambda_mult: float = 0.5,
53 | k: int = 4,
54 | ) -> List[int]:
55 | """Calculate maximal marginal relevance."""
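    # MMR iteratively picks the candidate that maximizes
    #   lambda_mult * sim(query, candidate) - (1 - lambda_mult) * max sim(candidate, selected)
    # trading off relevance to the query against redundancy with prior selections.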
56 | if min(k, len(embedding_list)) <= 0:
57 | return []
58 | if query_embedding.ndim == 1:
59 | query_embedding = np.expand_dims(query_embedding, axis=0)
60 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
61 | most_similar = int(np.argmax(similarity_to_query))
62 | idxs = [most_similar]
63 | selected = np.array([embedding_list[most_similar]])
64 | while len(idxs) < min(k, len(embedding_list)):
65 | best_score = -np.inf
66 | idx_to_add = -1
67 | similarity_to_selected = cosine_similarity(embedding_list, selected)
68 | for i, query_score in enumerate(similarity_to_query):
69 | if i in idxs:
70 | continue
71 | redundant_score = max(similarity_to_selected[i])
72 | equation_score = (
73 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score
74 | )
75 | if equation_score > best_score:
76 | best_score = equation_score
77 | idx_to_add = i
78 | idxs.append(idx_to_add)
79 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
80 | return idxs
81 |
--------------------------------------------------------------------------------
/langchain_postgres/chat_message_histories.py:
--------------------------------------------------------------------------------
1 | """Client for persisting chat message history in a Postgres database.
2 |
3 | This client provides support for both sync and async via psycopg 3.
4 | """
5 |
6 | from __future__ import annotations
7 |
8 | import json
9 | import logging
10 | import re
11 | import uuid
12 | from typing import List, Optional, Sequence
13 |
14 | import psycopg
15 | from langchain_core.chat_history import BaseChatMessageHistory
16 | from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
17 | from psycopg import sql
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def _create_table_and_index(table_name: str) -> List[sql.Composed]:
23 |     """Make the SQL statements to create the message table and its session_id index."""
24 | index_name = f"idx_{table_name}_session_id"
25 | statements = [
26 | sql.SQL(
27 | """
28 | CREATE TABLE IF NOT EXISTS {table_name} (
29 | id SERIAL PRIMARY KEY,
30 | session_id UUID NOT NULL,
31 | message JSONB NOT NULL,
32 | created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
33 | );
34 | """
35 | ).format(table_name=sql.Identifier(table_name)),
36 | sql.SQL(
37 | """
38 | CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
39 | """
40 | ).format(
41 | table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name)
42 | ),
43 | ]
44 | return statements
45 |
46 |
47 | def _get_messages_query(table_name: str) -> sql.Composed:
48 | """Make a SQL query to get messages for a given session."""
49 | return sql.SQL(
50 | "SELECT message "
51 | "FROM {table_name} "
52 | "WHERE session_id = %(session_id)s "
53 | "ORDER BY id;"
54 | ).format(table_name=sql.Identifier(table_name))
55 |
56 |
57 | def _delete_by_session_id_query(table_name: str) -> sql.Composed:
58 | """Make a SQL query to delete messages for a given session."""
59 | return sql.SQL(
60 | "DELETE FROM {table_name} WHERE session_id = %(session_id)s;"
61 | ).format(table_name=sql.Identifier(table_name))
62 |
63 |
64 | def _delete_table_query(table_name: str) -> sql.Composed:
65 | """Make a SQL query to delete a table."""
66 | return sql.SQL("DROP TABLE IF EXISTS {table_name};").format(
67 | table_name=sql.Identifier(table_name)
68 | )
69 |
70 |
71 | def _insert_message_query(table_name: str) -> sql.Composed:
72 | """Make a SQL query to insert a message."""
73 | return sql.SQL(
74 | "INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)"
75 | ).format(table_name=sql.Identifier(table_name))
76 |
77 |
78 | class PostgresChatMessageHistory(BaseChatMessageHistory):
79 | def __init__(
80 | self,
81 | table_name: str,
82 | session_id: str,
83 | /,
84 | *,
85 | sync_connection: Optional[psycopg.Connection] = None,
86 | async_connection: Optional[psycopg.AsyncConnection] = None,
87 | ) -> None:
88 |         """Client for persisting chat message history in a Postgres database.
89 |
90 | This client provides support for both sync and async via psycopg >=3.
91 |
92 | The client can create schema in the database and provides methods to
93 | add messages, get messages, and clear the chat message history.
94 |
95 | The schema has the following columns:
96 |
97 | - id: A serial primary key.
98 | - session_id: The session ID for the chat message history.
99 | - message: The JSONB message content.
100 | - created_at: The timestamp of when the message was created.
101 |
102 | Messages are retrieved for a given session_id and are sorted by
103 | the id (which should be increasing monotonically), and correspond
104 | to the order in which the messages were added to the history.
105 |
106 | The "created_at" column is not returned by the interface, but
107 | has been added for the schema so the information is available in the database.
108 |
109 |         A session_id can be used to separate different chat histories in the same table;
110 | the session_id should be provided when initializing the client.
111 |
112 | This chat history client takes in a psycopg connection object (either
113 | Connection or AsyncConnection) and uses it to interact with the database.
114 |
115 |         This design allows the underlying connection object to be reused across
116 |         multiple instantiations of this class, making instantiation fast.
117 |
118 | This chat history client is designed for prototyping applications that
119 | involve chat and are based on Postgres.
120 |
121 | As your application grows, you will likely need to extend the schema to
122 | handle more complex queries. For example, a chat application
123 | may involve multiple tables like a user table, a table for storing
124 | chat sessions / conversations, and this table for storing chat messages
125 | for a given session. The application will require access to additional
126 | endpoints like deleting messages by user id, listing conversations by
127 | user id or ordering them based on last message time, etc.
128 |
129 | Feel free to adapt this implementation to suit your application's needs.
130 |
131 | Args:
132 | session_id: The session ID to use for the chat message history
133 | table_name: The name of the database table to use
134 | sync_connection: An existing psycopg connection instance
135 | async_connection: An existing psycopg async connection instance
136 |
137 | Usage:
138 | - Use the create_tables or acreate_tables method to set up the table
139 | schema in the database.
140 | - Initialize the class with the appropriate session ID, table name,
141 | and database connection.
142 | - Add messages to the database using add_messages or aadd_messages.
143 | - Retrieve messages with get_messages or aget_messages.
144 | - Clear the session history with clear or aclear when needed.
145 |
146 | Note:
147 | - At least one of sync_connection or async_connection must be provided.
148 |
149 | Examples:
150 |
151 | .. code-block:: python
152 |
153 | import uuid
154 |
155 | from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
156 | from langchain_postgres import PostgresChatMessageHistory
157 | import psycopg
158 |
159 | # Establish a synchronous connection to the database
160 | # (or use psycopg.AsyncConnection for async)
161 |                 sync_connection = psycopg.connect(conn_info)
162 |
163 | # Create the table schema (only needs to be done once)
164 | table_name = "chat_history"
165 | PostgresChatMessageHistory.create_tables(sync_connection, table_name)
166 |
167 | session_id = str(uuid.uuid4())
168 |
169 | # Initialize the chat history manager
170 | chat_history = PostgresChatMessageHistory(
171 | table_name,
172 | session_id,
173 | sync_connection=sync_connection
174 | )
175 |
176 | # Add messages to the chat history
177 | chat_history.add_messages([
178 | SystemMessage(content="Meow"),
179 | AIMessage(content="woof"),
180 | HumanMessage(content="bark"),
181 | ])
182 |
183 | print(chat_history.messages)
184 | """
185 | if not sync_connection and not async_connection:
186 | raise ValueError("Must provide sync_connection or async_connection")
187 |
188 | self._connection = sync_connection
189 | self._aconnection = async_connection
190 |
191 | # Validate that session id is a UUID
192 | try:
193 | uuid.UUID(session_id)
194 | except ValueError:
195 | raise ValueError(
196 | f"Invalid session id. Session id must be a valid UUID. Got {session_id}"
197 | )
198 |
199 | self._session_id = session_id
200 |
201 | if not re.match(r"^\w+$", table_name):
202 | raise ValueError(
203 | "Invalid table name. Table name must contain only alphanumeric "
204 | "characters and underscores."
205 | )
206 | self._table_name = table_name
207 |
208 | @staticmethod
209 | def create_tables(
210 | connection: psycopg.Connection,
211 | table_name: str,
212 | /,
213 | ) -> None:
214 | """Create the table schema in the database and create relevant indexes."""
215 | queries = _create_table_and_index(table_name)
216 | logger.info("Creating schema for table %s", table_name)
217 | with connection.cursor() as cursor:
218 | for query in queries:
219 | cursor.execute(query)
220 | connection.commit()
221 |
222 | @staticmethod
223 | async def acreate_tables(
224 | connection: psycopg.AsyncConnection, table_name: str, /
225 | ) -> None:
226 | """Create the table schema in the database and create relevant indexes."""
227 | queries = _create_table_and_index(table_name)
228 | logger.info("Creating schema for table %s", table_name)
229 | async with connection.cursor() as cur:
230 | for query in queries:
231 | await cur.execute(query)
232 | await connection.commit()
233 |
234 | @staticmethod
235 | def drop_table(connection: psycopg.Connection, table_name: str, /) -> None:
236 | """Delete the table schema in the database.
237 |
238 | WARNING:
239 | This will delete the given table from the database including
240 |             all the data in the table and the schema of the table.
241 |
242 | Args:
243 | connection: The database connection.
244 |             table_name: The name of the table to drop.
245 | """
246 |
247 | query = _delete_table_query(table_name)
248 | logger.info("Dropping table %s", table_name)
249 | with connection.cursor() as cursor:
250 | cursor.execute(query)
251 | connection.commit()
252 |
253 | @staticmethod
254 | async def adrop_table(
255 | connection: psycopg.AsyncConnection, table_name: str, /
256 | ) -> None:
257 | """Delete the table schema in the database.
258 |
259 | WARNING:
260 | This will delete the given table from the database including
261 |             all the data in the table and the schema of the table.
262 |
263 | Args:
264 | connection: Async database connection.
265 |             table_name: The name of the table to drop.
266 | """
267 | query = _delete_table_query(table_name)
268 | logger.info("Dropping table %s", table_name)
269 |
270 | async with connection.cursor() as acur:
271 | await acur.execute(query)
272 | await connection.commit()
273 |
274 | def add_messages(self, messages: Sequence[BaseMessage]) -> None:
275 | """Add messages to the chat message history."""
276 | if self._connection is None:
277 | raise ValueError(
278 | "Please initialize the PostgresChatMessageHistory "
279 | "with a sync connection or use the aadd_messages method instead."
280 | )
281 |
282 | values = [
283 | (self._session_id, json.dumps(message_to_dict(message)))
284 | for message in messages
285 | ]
286 |
287 | query = _insert_message_query(self._table_name)
288 |
289 | with self._connection.cursor() as cursor:
290 | cursor.executemany(query, values)
291 | self._connection.commit()
292 |
293 | async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
294 | """Add messages to the chat message history."""
295 | if self._aconnection is None:
296 | raise ValueError(
297 | "Please initialize the PostgresChatMessageHistory "
298 | "with an async connection or use the sync add_messages method instead."
299 | )
300 |
301 | values = [
302 | (self._session_id, json.dumps(message_to_dict(message)))
303 | for message in messages
304 | ]
305 |
306 | query = _insert_message_query(self._table_name)
307 | async with self._aconnection.cursor() as cursor:
308 | await cursor.executemany(query, values)
309 | await self._aconnection.commit()
310 |
311 | def get_messages(self) -> List[BaseMessage]:
312 | """Retrieve messages from the chat message history."""
313 | if self._connection is None:
314 | raise ValueError(
315 | "Please initialize the PostgresChatMessageHistory "
316 | "with a sync connection or use the async aget_messages method instead."
317 | )
318 |
319 | query = _get_messages_query(self._table_name)
320 |
321 | with self._connection.cursor() as cursor:
322 | cursor.execute(query, {"session_id": self._session_id})
323 | items = [record[0] for record in cursor.fetchall()]
324 |
325 | messages = messages_from_dict(items)
326 | return messages
327 |
328 | async def aget_messages(self) -> List[BaseMessage]:
329 | """Retrieve messages from the chat message history."""
330 | if self._aconnection is None:
331 | raise ValueError(
332 | "Please initialize the PostgresChatMessageHistory "
333 | "with an async connection or use the sync get_messages method instead."
334 | )
335 |
336 | query = _get_messages_query(self._table_name)
337 | async with self._aconnection.cursor() as cursor:
338 | await cursor.execute(query, {"session_id": self._session_id})
339 | items = [record[0] for record in await cursor.fetchall()]
340 |
341 | messages = messages_from_dict(items)
342 | return messages
343 |
344 | @property
345 | def messages(self) -> List[BaseMessage]:
346 |         """The messages stored in the chat history (the base class requires a property)."""
347 | return self.get_messages()
348 |
349 | @messages.setter
350 | def messages(self, value: list[BaseMessage]) -> None:
351 |         """Clear the stored messages and append the given list of messages."""
352 | self.clear()
353 | self.add_messages(value)
354 |
355 | def clear(self) -> None:
356 | """Clear the chat message history for the GIVEN session."""
357 | if self._connection is None:
358 | raise ValueError(
359 | "Please initialize the PostgresChatMessageHistory "
360 | "with a sync connection or use the async clear method instead."
361 | )
362 |
363 | query = _delete_by_session_id_query(self._table_name)
364 | with self._connection.cursor() as cursor:
365 | cursor.execute(query, {"session_id": self._session_id})
366 | self._connection.commit()
367 |
368 | async def aclear(self) -> None:
369 | """Clear the chat message history for the GIVEN session."""
370 | if self._aconnection is None:
371 | raise ValueError(
372 | "Please initialize the PostgresChatMessageHistory "
373 | "with an async connection or use the sync clear method instead."
374 | )
375 |
376 | query = _delete_by_session_id_query(self._table_name)
377 | async with self._aconnection.cursor() as cursor:
378 | await cursor.execute(query, {"session_id": self._session_id})
379 | await self._aconnection.commit()
380 |
--------------------------------------------------------------------------------
/langchain_postgres/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/langchain_postgres/py.typed
--------------------------------------------------------------------------------
/langchain_postgres/translator.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Union
2 |
3 | from langchain_core.structured_query import (
4 | Comparator,
5 | Comparison,
6 | Operation,
7 | Operator,
8 | StructuredQuery,
9 | Visitor,
10 | )
11 |
12 |
13 | class PGVectorTranslator(Visitor):
14 | """Translate `PGVector` internal query language elements to valid filters."""
15 |
16 | allowed_operators = [Operator.AND, Operator.OR]
17 | """Subset of allowed logical operators."""
18 | allowed_comparators = [
19 | Comparator.EQ,
20 | Comparator.NE,
21 | Comparator.GT,
22 | Comparator.LT,
23 | Comparator.IN,
24 | Comparator.NIN,
25 | Comparator.CONTAIN,
26 | Comparator.LIKE,
27 | ]
28 | """Subset of allowed logical comparators."""
29 |
30 | def _format_func(self, func: Union[Operator, Comparator]) -> str:
31 | self._validate_func(func)
32 | return f"${func.value}"
33 |
34 | def visit_operation(self, operation: Operation) -> Dict:
35 | args = [arg.accept(self) for arg in operation.arguments]
36 | return {self._format_func(operation.operator): args}
37 |
38 | def visit_comparison(self, comparison: Comparison) -> Dict:
39 | return {
40 | comparison.attribute: {
41 | self._format_func(comparison.comparator): comparison.value
42 | }
43 | }
44 |
45 | def visit_structured_query(
46 | self, structured_query: StructuredQuery
47 | ) -> Tuple[str, dict]:
48 | if structured_query.filter is None:
49 | kwargs = {}
50 | else:
51 | kwargs = {"filter": structured_query.filter.accept(self)}
52 | return structured_query.query, kwargs
53 |
--------------------------------------------------------------------------------
/langchain_postgres/utils/pgvector_migrator.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import warnings
4 | from typing import Any, AsyncIterator, Iterator, Optional, Sequence, TypeVar
5 |
6 | from sqlalchemy import RowMapping, text
7 | from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
8 |
9 | from ..v2.engine import PGEngine
10 | from ..v2.vectorstores import PGVectorStore
11 |
12 | COLLECTIONS_TABLE = "langchain_pg_collection"
13 | EMBEDDINGS_TABLE = "langchain_pg_embedding"
14 |
15 | T = TypeVar("T")
16 |
17 |
18 | async def __aget_collection_uuid(
19 | engine: PGEngine,
20 | collection_name: str,
21 | ) -> str:
22 | """
23 | Get the collection uuid for a collection present in PGVector tables.
24 |
25 | Args:
26 | engine (PGEngine): The PG engine corresponding to the Database.
27 | collection_name (str): The name of the collection to get the uuid for.
28 | Returns:
29 | The uuid corresponding to the collection.
30 | """
31 | query = f"SELECT name, uuid FROM {COLLECTIONS_TABLE} WHERE name = :collection_name"
32 | async with engine._pool.connect() as conn:
33 | result = await conn.execute(
34 | text(query), parameters={"collection_name": collection_name}
35 | )
36 | result_map = result.mappings()
37 | result_fetch = result_map.fetchone()
38 | if result_fetch is None:
39 | raise ValueError(f"Collection, {collection_name} not found.")
40 | return result_fetch.uuid
41 |
42 |
43 | async def __aextract_pgvector_collection(
44 | engine: PGEngine,
45 | collection_name: str,
46 | batch_size: int = 1000,
47 | ) -> AsyncIterator[Sequence[RowMapping]]:
48 | """
49 | Extract all data belonging to a PGVector collection.
50 |
51 | Args:
52 | engine (PGEngine): The PG engine corresponding to the Database.
53 | collection_name (str): The name of the collection to get the data for.
54 | batch_size (int): The batch size for collection extraction.
55 | Default: 1000. Optional.
56 |
57 | Yields:
58 | The data present in the collection.
59 | """
60 | try:
61 | uuid_task = asyncio.create_task(__aget_collection_uuid(engine, collection_name))
62 | query = f"SELECT * FROM {EMBEDDINGS_TABLE} WHERE collection_id = :id"
63 | async with engine._pool.connect() as conn:
64 | uuid = await uuid_task
65 | result_proxy = await conn.execute(text(query), parameters={"id": uuid})
66 | while True:
67 | rows = result_proxy.fetchmany(size=batch_size)
68 | if not rows:
69 | break
70 | yield [row._mapping for row in rows]
71 | except ValueError:
72 | raise ValueError(f"Collection, {collection_name} does not exist.")
73 | except SQLAlchemyError as e:
74 | raise ProgrammingError(
75 | statement=f"Failed to extract data from collection '{collection_name}': {e}",
76 | params={"id": uuid},
77 | orig=e,
78 | ) from e
79 |
80 |
81 | async def __concurrent_batch_insert(
82 | data_batches: AsyncIterator[Sequence[RowMapping]],
83 | vector_store: PGVectorStore,
84 | max_concurrency: int = 100,
85 | ) -> None:
86 | pending: set[Any] = set()
87 | async for batch_data in data_batches:
88 | pending.add(
89 | asyncio.ensure_future(
90 | vector_store.aadd_embeddings(
91 | texts=[data.document for data in batch_data],
92 | embeddings=[json.loads(data.embedding) for data in batch_data],
93 | metadatas=[data.cmetadata for data in batch_data],
94 | ids=[data.id for data in batch_data],
95 | )
96 | )
97 | )
98 | if len(pending) >= max_concurrency:
99 | _, pending = await asyncio.wait(
100 | pending, return_when=asyncio.FIRST_COMPLETED
101 | )
102 | if pending:
103 | await asyncio.wait(pending)
104 |
105 |
106 | async def __amigrate_pgvector_collection(
107 | engine: PGEngine,
108 | collection_name: str,
109 | vector_store: PGVectorStore,
110 | delete_pg_collection: Optional[bool] = False,
111 | insert_batch_size: int = 1000,
112 | ) -> None:
113 | """
114 | Migrate all data present in a PGVector collection to use separate tables for each collection.
115 |     The new data format is compatible with the PGVectorStore interface.
116 |
117 | Args:
118 | engine (PGEngine): The PG engine corresponding to the Database.
119 | collection_name (str): The collection to migrate.
120 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table.
121 | delete_pg_collection (bool): An option to delete the original data upon migration.
122 | Default: False. Optional.
123 | insert_batch_size (int): Number of rows to insert at once in the table.
124 | Default: 1000.
125 | """
126 | destination_table = vector_store.get_table_name()
127 |
128 | # Get row count in PGVector collection
129 | uuid_task = asyncio.create_task(__aget_collection_uuid(engine, collection_name))
130 | query = (
131 | f"SELECT COUNT(*) FROM {EMBEDDINGS_TABLE} WHERE collection_id=:collection_id"
132 | )
133 | async with engine._pool.connect() as conn:
134 | uuid = await uuid_task
135 | result = await conn.execute(text(query), parameters={"collection_id": uuid})
136 | result_map = result.mappings()
137 | collection_data_len = result_map.fetchone()
138 | if collection_data_len is None:
139 | warnings.warn(f"Collection, {collection_name} contains no elements.")
140 | return
141 |
142 | # Extract data from the collection and batch insert into the new table
143 | data_batches = __aextract_pgvector_collection(
144 | engine, collection_name, batch_size=insert_batch_size
145 | )
146 | await __concurrent_batch_insert(data_batches, vector_store, max_concurrency=100)
147 |
148 | # Validate data migration
149 | query = f"SELECT COUNT(*) FROM {destination_table}"
150 | async with engine._pool.connect() as conn:
151 | result = await conn.execute(text(query))
152 | result_map = result.mappings()
153 | table_size = result_map.fetchone()
154 | if not table_size:
155 | raise ValueError(f"Table: {destination_table} does not exist.")
156 |
157 | if collection_data_len["count"] != table_size["count"]:
158 | raise ValueError(
159 | "All data not yet migrated.\n"
160 | f"Original row count: {collection_data_len['count']}\n"
161 | f"Collection table, {destination_table} row count: {table_size['count']}"
162 | )
163 | elif delete_pg_collection:
164 | # Delete PGVector data
165 | query = f"DELETE FROM {EMBEDDINGS_TABLE} WHERE collection_id=:collection_id"
166 | async with engine._pool.connect() as conn:
167 | await conn.execute(text(query), parameters={"collection_id": uuid})
168 | await conn.commit()
169 |
170 | query = f"DELETE FROM {COLLECTIONS_TABLE} WHERE name=:collection_name"
171 | async with engine._pool.connect() as conn:
172 | await conn.execute(
173 | text(query), parameters={"collection_name": collection_name}
174 | )
175 | await conn.commit()
176 | print(f"Successfully deleted PGVector collection, {collection_name}")
177 |
178 |
179 | async def __alist_pgvector_collection_names(
180 | engine: PGEngine,
181 | ) -> list[str]:
182 | """Lists all collection names present in PGVector table."""
183 | try:
184 | query = f"SELECT name from {COLLECTIONS_TABLE}"
185 | async with engine._pool.connect() as conn:
186 | result = await conn.execute(text(query))
187 | result_map = result.mappings()
188 | all_rows = result_map.fetchall()
189 | return [row["name"] for row in all_rows]
190 | except ProgrammingError as e:
191 | raise ValueError(
192 | "Please provide the correct collection table name: " + str(e)
193 | ) from e
194 |
195 |
196 | async def aextract_pgvector_collection(
197 | engine: PGEngine,
198 | collection_name: str,
199 | batch_size: int = 1000,
200 | ) -> AsyncIterator[Sequence[RowMapping]]:
201 | """
202 | Extract all data belonging to a PGVector collection.
203 |
204 | Args:
205 | engine (PGEngine): The PG engine corresponding to the Database.
206 | collection_name (str): The name of the collection to get the data for.
207 | batch_size (int): The batch size for collection extraction.
208 | Default: 1000. Optional.
209 |
210 | Yields:
211 | The data present in the collection.
212 | """
213 | iterator = __aextract_pgvector_collection(engine, collection_name, batch_size)
214 | while True:
215 | try:
216 | result = await engine._run_as_async(iterator.__anext__())
217 | yield result
218 | except StopAsyncIteration:
219 | break
220 |
221 |
222 | async def alist_pgvector_collection_names(
223 | engine: PGEngine,
224 | ) -> list[str]:
225 | """Lists all collection names present in PGVector table."""
226 | return await engine._run_as_async(__alist_pgvector_collection_names(engine))
227 |
228 |
229 | async def amigrate_pgvector_collection(
230 | engine: PGEngine,
231 | collection_name: str,
232 | vector_store: PGVectorStore,
233 | delete_pg_collection: Optional[bool] = False,
234 | insert_batch_size: int = 1000,
235 | ) -> None:
236 | """
237 | Migrate all data present in a PGVector collection to use separate tables for each collection.
238 | The new data format is compatible with the PGVectorStore interface.
239 |
240 | Args:
241 | engine (PGEngine): The PG engine corresponding to the Database.
242 | collection_name (str): The collection to migrate.
243 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table.
246 | delete_pg_collection (bool): An option to delete the original data upon migration.
247 | Default: False. Optional.
248 | insert_batch_size (int): Number of rows to insert at once in the table.
249 | Default: 1000.
250 | """
251 | await engine._run_as_async(
252 | __amigrate_pgvector_collection(
253 | engine,
254 | collection_name,
255 | vector_store,
256 | delete_pg_collection,
257 | insert_batch_size,
258 | )
259 | )
260 |
261 |
262 | def extract_pgvector_collection(
263 | engine: PGEngine,
264 | collection_name: str,
265 | batch_size: int = 1000,
266 | ) -> Iterator[Sequence[RowMapping]]:
267 | """
268 | Extract all data belonging to a PGVector collection.
269 |
270 | Args:
271 | engine (PGEngine): The PG engine corresponding to the Database.
272 | collection_name (str): The name of the collection to get the data for.
273 | batch_size (int): The batch size for collection extraction.
274 | Default: 1000. Optional.
275 |
276 | Yields:
277 | The data present in the collection.
278 | """
279 | iterator = __aextract_pgvector_collection(engine, collection_name, batch_size)
280 | while True:
281 | try:
282 | result = engine._run_as_sync(iterator.__anext__())
283 | yield result
284 | except StopAsyncIteration:
285 | break
286 |
287 |
288 | def list_pgvector_collection_names(engine: PGEngine) -> list[str]:
289 | """Lists all collection names present in PGVector table."""
290 | return engine._run_as_sync(__alist_pgvector_collection_names(engine))
291 |
292 |
293 | def migrate_pgvector_collection(
294 | engine: PGEngine,
295 | collection_name: str,
296 | vector_store: PGVectorStore,
297 | delete_pg_collection: Optional[bool] = False,
298 | insert_batch_size: int = 1000,
299 | ) -> None:
300 | """
301 | Migrate all data present in a PGVector collection to use separate tables for each collection.
302 | The new data format is compatible with the PGVectorStore interface.
303 |
304 | Args:
305 | engine (PGEngine): The PG engine corresponding to the Database.
306 | collection_name (str): The collection to migrate.
307 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table.
308 | delete_pg_collection (bool): An option to delete the original data upon migration.
309 | Default: False. Optional.
310 | insert_batch_size (int): Number of rows to insert at once in the table.
311 | Default: 1000.
312 | """
313 | engine._run_as_sync(
314 | __amigrate_pgvector_collection(
315 | engine,
316 | collection_name,
317 | vector_store,
318 | delete_pg_collection,
319 | insert_batch_size,
320 | )
321 | )
322 |
--------------------------------------------------------------------------------
/langchain_postgres/v2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/langchain_postgres/v2/__init__.py
--------------------------------------------------------------------------------
/langchain_postgres/v2/hybrid_search_config.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from dataclasses import dataclass, field
3 | from typing import Any, Callable, Optional, Sequence
4 |
5 | from sqlalchemy import RowMapping
6 |
7 |
8 | def weighted_sum_ranking(
9 | primary_search_results: Sequence[RowMapping],
10 | secondary_search_results: Sequence[RowMapping],
11 | primary_results_weight: float = 0.5,
12 | secondary_results_weight: float = 0.5,
13 | fetch_top_k: int = 4,
14 | ) -> Sequence[dict[str, Any]]:
15 | """
16 | Ranks documents using a weighted sum of scores from two sources.
17 |
18 | Args:
19 |         primary_search_results: A sequence of row mappings from the primary
20 |             search, where the last column holds the distance/score.
21 |         secondary_search_results: A sequence of row mappings from the
22 |             secondary search, where the last column holds the distance/score.
23 | primary_results_weight: The weight for the primary source's scores.
24 | Defaults to 0.5.
25 | secondary_results_weight: The weight for the secondary source's scores.
26 | Defaults to 0.5.
27 | fetch_top_k: The number of documents to fetch after merging the results.
28 | Defaults to 4.
29 |
30 | Returns:
31 |         A list of merged row dicts (each including a combined "distance"
32 |         score), sorted by that score in descending order.
33 | """
34 |
35 |     # Stores each row keyed by doc_id, with a combined weighted "distance" score
36 | weighted_scores: dict[str, dict[str, Any]] = {}
37 |
38 | # Process results from primary source
39 | for row in primary_search_results:
40 | values = list(row.values())
41 | doc_id = str(values[0]) # first value is doc_id
42 | distance = float(values[-1]) # type: ignore # last value is distance
43 | row_values = dict(row)
44 | row_values["distance"] = primary_results_weight * distance
45 | weighted_scores[doc_id] = row_values
46 |
47 | # Process results from secondary source,
48 | # adding to existing scores or creating new ones
49 | for row in secondary_search_results:
50 | values = list(row.values())
51 | doc_id = str(values[0]) # first value is doc_id
52 | distance = float(values[-1]) # type: ignore # last value is distance
53 | primary_score = (
54 | weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0
55 | )
56 | row_values = dict(row)
57 | row_values["distance"] = distance * secondary_results_weight + primary_score
58 | weighted_scores[doc_id] = row_values
59 |
60 | # Sort the results by weighted score in descending order
61 | ranked_results = sorted(
62 | weighted_scores.values(), key=lambda item: item["distance"], reverse=True
63 | )
64 | return ranked_results[:fetch_top_k]
65 |
66 |
67 | def reciprocal_rank_fusion(
68 | primary_search_results: Sequence[RowMapping],
69 | secondary_search_results: Sequence[RowMapping],
70 | rrf_k: float = 60,
71 | fetch_top_k: int = 4,
72 | ) -> Sequence[dict[str, Any]]:
73 | """
74 | Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two sources.
75 |
76 | Args:
77 |         primary_search_results: A sequence of row mappings from the primary
78 |             search, where the last column holds the distance/score.
79 |         secondary_search_results: A sequence of row mappings from the
80 |             secondary search, where the last column holds the distance/score.
81 | rrf_k: The RRF parameter k.
82 | Defaults to 60.
83 | fetch_top_k: The number of documents to fetch after merging the results.
84 | Defaults to 4.
85 |
86 | Returns:
87 |         A list of merged row dicts (each including a combined "distance"
88 |         score), sorted by that RRF score in descending order.
89 | """
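    # Each source contributes 1 / (rank + rrf_k) per document, so items ranked
    # highly in either result list accumulate a larger fused score without any
    # score normalization between the two sources.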
90 | rrf_scores: dict[str, dict[str, Any]] = {}
91 |
92 | # Process results from primary source
93 | for rank, row in enumerate(
94 | sorted(primary_search_results, key=lambda item: item["distance"], reverse=True)
95 | ):
96 | values = list(row.values())
97 | doc_id = str(values[0])
98 | row_values = dict(row)
99 | primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
100 | primary_score += 1.0 / (rank + rrf_k)
101 | row_values["distance"] = primary_score
102 | rrf_scores[doc_id] = row_values
103 |
104 | # Process results from secondary source
105 | for rank, row in enumerate(
106 | sorted(
107 | secondary_search_results, key=lambda item: item["distance"], reverse=True
108 | )
109 | ):
110 | values = list(row.values())
111 | doc_id = str(values[0])
112 | row_values = dict(row)
113 | secondary_score = (
114 | rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
115 | )
116 | secondary_score += 1.0 / (rank + rrf_k)
117 | row_values["distance"] = secondary_score
118 | rrf_scores[doc_id] = row_values
119 |
120 |     # Sort the results by RRF score in descending order
122 | ranked_results = sorted(
123 | rrf_scores.values(), key=lambda item: item["distance"], reverse=True
124 | )
125 |     # Return only the top fetch_top_k results
126 | return ranked_results[:fetch_top_k]
127 |
128 |
129 | @dataclass
130 | class HybridSearchConfig(ABC):
131 | """
132 |     PGVectorStore Hybrid Search Config.
133 |
134 | Queries might be slow if the hybrid search column does not exist.
135 | For best hybrid search performance, consider creating a TSV column
136 |     and adding a GIN index.
137 | """
138 |
139 | tsv_column: Optional[str] = ""
140 | tsv_lang: Optional[str] = "pg_catalog.english"
141 | fts_query: Optional[str] = ""
142 | fusion_function: Callable[
143 | [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any]
144 | ] = weighted_sum_ranking # Updated default
145 | fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
146 | primary_top_k: int = 4
147 | secondary_top_k: int = 4
148 | index_name: str = "langchain_tsv_index"
149 | index_type: str = "GIN"
150 |
--------------------------------------------------------------------------------
/langchain_postgres/v2/indexes.py:
--------------------------------------------------------------------------------
1 | """Index class to add vector indexes on the PGVectorStore.
2 |
3 | Learn more about vector indexes at https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
4 | """
5 |
6 | import enum
7 | import re
8 | import warnings
9 | from abc import ABC, abstractmethod
10 | from dataclasses import dataclass, field
11 | from typing import Optional
12 |
13 |
14 | @dataclass
15 | class StrategyMixin:
16 | operator: str
17 | search_function: str
18 | index_function: str
19 |
20 |
21 | class DistanceStrategy(StrategyMixin, enum.Enum):
22 | """Enumerator of the Distance strategies."""
23 |
24 | EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops"
25 | COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops"
26 | INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops"
27 |
28 |
29 | DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE
30 | DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex"
31 |
32 |
33 | def validate_identifier(identifier: str) -> None:
34 | if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is None:
35 | raise ValueError(
36 | f"Invalid identifier: {identifier}. Identifiers must start with a letter or underscore, and subsequent characters can be letters, digits, or underscores."
37 | )
38 |
39 |
40 | @dataclass
41 | class BaseIndex(ABC):
42 | """
43 | Abstract base class for defining vector indexes.
44 |
45 | Attributes:
46 | name (Optional[str]): A human-readable name for the index. Defaults to None.
47 | index_type (str): A string identifying the type of index. Defaults to "base".
48 | distance_strategy (DistanceStrategy): The strategy used to calculate distances
49 | between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE.
50 | partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None.
51 | extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None.
52 | """
53 |
54 | name: Optional[str] = None
55 | index_type: str = "base"
56 | distance_strategy: DistanceStrategy = field(
57 | default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
58 | )
59 | partial_indexes: Optional[list[str]] = None
60 | extension_name: Optional[str] = None
61 |
62 | @abstractmethod
63 | def index_options(self) -> str:
64 | """Set index query options for vector store initialization."""
65 | raise NotImplementedError(
66 | "index_options method must be implemented by subclass"
67 | )
68 |
69 | def get_index_function(self) -> str:
70 | return self.distance_strategy.index_function
71 |
72 | def __post_init__(self) -> None:
73 | """Check if initialization parameters are valid.
74 |
75 | Raises:
76 |             ValueError: if extension_name or index_type is not a valid PostgreSQL identifier.
77 | """
78 |
79 | if self.extension_name:
80 | validate_identifier(self.extension_name)
81 | if self.index_type:
82 | validate_identifier(self.index_type)
83 |
84 |
85 | @dataclass
86 | class ExactNearestNeighbor(BaseIndex):
87 | index_type: str = "exactnearestneighbor"
88 |
89 |
90 | @dataclass
91 | class QueryOptions(ABC):
92 | @abstractmethod
93 | def to_parameter(self) -> list[str]:
94 | """Convert index attributes to list of configurations."""
95 | raise NotImplementedError("to_parameter method must be implemented by subclass")
96 |
97 | @abstractmethod
98 | def to_string(self) -> str:
99 | """Convert index attributes to string."""
100 | raise NotImplementedError("to_string method must be implemented by subclass")
101 |
102 |
103 | @dataclass
104 | class HNSWIndex(BaseIndex):
105 | index_type: str = "hnsw"
106 | m: int = 16
107 | ef_construction: int = 64
108 |
109 | def index_options(self) -> str:
110 | """Set index query options for vector store initialization."""
111 | return f"(m = {self.m}, ef_construction = {self.ef_construction})"
112 |
113 |
114 | @dataclass
115 | class HNSWQueryOptions(QueryOptions):
116 | ef_search: int = 40
117 |
118 | def to_parameter(self) -> list[str]:
119 | """Convert index attributes to list of configurations."""
120 | return [f"hnsw.ef_search = {self.ef_search}"]
121 |
122 | def to_string(self) -> str:
123 | """Convert index attributes to string."""
124 | warnings.warn(
125 | "to_string is deprecated, use to_parameter instead.",
126 | DeprecationWarning,
127 | )
128 | return f"hnsw.ef_search = {self.ef_search}"
129 |
130 |
131 | @dataclass
132 | class IVFFlatIndex(BaseIndex):
133 | index_type: str = "ivfflat"
134 | lists: int = 100
135 |
136 | def index_options(self) -> str:
137 | """Set index query options for vector store initialization."""
138 | return f"(lists = {self.lists})"
139 |
140 |
141 | @dataclass
142 | class IVFFlatQueryOptions(QueryOptions):
143 | probes: int = 1
144 |
145 | def to_parameter(self) -> list[str]:
146 | """Convert index attributes to list of configurations."""
147 | return [f"ivfflat.probes = {self.probes}"]
148 |
149 | def to_string(self) -> str:
150 | """Convert index attributes to string."""
151 | warnings.warn(
152 | "to_string is deprecated, use to_parameter instead.",
153 | DeprecationWarning,
154 | )
155 | return f"ivfflat.probes = {self.probes}"
156 |
--------------------------------------------------------------------------------
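
A minimal usage sketch for the index classes above, stitched together from the method names exercised in the v2 test suite further down (aapply_vector_index, is_valid_index, adrop_vector_index). The connection URL, table name, and index name are placeholders, and importing HNSWQueryOptions from langchain_postgres.v2.indexes is an assumption based on this file; this is a sketch, not a definitive recipe.

# Hedged sketch: assumes a reachable Postgres instance with the pgvector extension.
from langchain_core.embeddings import DeterministicFakeEmbedding

from langchain_postgres import PGEngine
from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore
from langchain_postgres.v2.indexes import (
    DistanceStrategy,
    HNSWIndex,
    HNSWQueryOptions,
)


async def build_hnsw_index() -> None:
    # Placeholder URL; the test suite reads its own connection string from tests.utils.
    engine = PGEngine.from_connection_string(
        url="postgresql+asyncpg://user:pass@localhost:5432/db"
    )
    await engine._ainit_vectorstore_table("example_table", 768)
    vs = await AsyncPGVectorStore.create(
        engine,
        embedding_service=DeterministicFakeEmbedding(size=768),
        table_name="example_table",
    )
    # Build an HNSW index with the cosine distance default from BaseIndex.
    index = HNSWIndex(
        name="example_hnsw_index",
        m=16,
        ef_construction=64,
        distance_strategy=DistanceStrategy.COSINE_DISTANCE,
    )
    await vs.aapply_vector_index(index)
    assert await vs.is_valid_index("example_hnsw_index")
    # Query-time settings are rendered as SET parameters via to_parameter().
    print(HNSWQueryOptions(ef_search=40).to_parameter())  # ["hnsw.ef_search = 40"]
    await vs.adrop_vector_index("example_hnsw_index")
    await engine.close()

# To run: asyncio.run(build_hnsw_index()) against a live database.

--------------------------------------------------------------------------------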
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "langchain-postgres"
3 | version = "0.0.14"
4 | description = "An integration package connecting Postgres and LangChain"
5 | authors = []
6 | readme = "README.md"
7 | repository = "https://github.com/langchain-ai/langchain-postgres"
8 | requires-python = ">=3.9"
9 | license = "MIT"
10 | dependencies = [
11 | "asyncpg>=0.30.0",
12 | "langchain-core>=0.2.13,<0.4.0",
13 | "pgvector>=0.2.5,<0.4",
14 | "psycopg>=3,<4",
15 | "psycopg-pool>=3.2.1,<4",
16 | "sqlalchemy>=2,<3",
17 | "numpy>=1.21,<3",
18 | ]
19 |
20 | [tool.poetry.urls]
21 | "Source Code" = "https://github.com/langchain-ai/langchain-postgres/tree/master/langchain_postgres"
22 |
23 | [dependency-groups]
24 | test = [
25 | "langchain-tests==0.3.7",
26 | "mypy>=1.15.0",
27 | "pytest>=8.3.4",
28 | "pytest-asyncio>=0.25.3",
29 | "pytest-cov>=6.0.0",
30 | "pytest-mock>=3.14.0",
31 | "pytest-socket>=0.7.0",
32 | "pytest-timeout>=2.3.1",
33 | "ruff>=0.9.7",
34 | ]
35 |
36 | [tool.ruff.lint]
37 | select = [
38 | "E", # pycodestyle
39 | "F", # pyflakes
40 | "I", # isort
41 | "T201", # print
42 | ]
43 |
44 | [tool.mypy]
45 | disallow_untyped_defs = "True"
46 |
47 | [tool.coverage.run]
48 | omit = ["tests/*"]
49 |
50 | [build-system]
51 | requires = ["hatchling"]
52 | build-backend = "hatchling.build"
53 |
54 | [tool.pytest.ini_options]
55 | # --strict-markers will raise errors on unknown marks.
56 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
57 | #
58 | # https://docs.pytest.org/en/7.1.x/reference/reference.html
59 | # --strict-config any warnings encountered while parsing the `pytest`
60 | # section of the configuration file raise errors.
61 | #
62 | # https://github.com/tophat/syrupy
63 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than failing the test suite.
64 | addopts = "--strict-markers --strict-config --durations=5"
65 | # Global timeout for all tests. There should be a good reason for a test to
66 | # take more than 30 seconds.
67 | timeout = 30
68 | # Registering custom markers.
69 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
70 | markers = []
71 | asyncio_mode = "auto"
72 |
73 |
74 | [tool.codespell]
75 | skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig'
76 | ignore-regex = '.*(Stati Uniti|Tense=Pres).*'
77 | ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin'
78 |
79 | [tool.ruff.lint.extend-per-file-ignores]
80 | "tests/unit_tests/v2/**/*.py" = [
81 | "E501",
82 | ]
83 |
84 | "langchain_postgres/v2/**/*.py" = [
85 | "E501",
86 | ]
87 |
88 | "langchain_postgres/utils/**/*.py" = [
89 | "E501",
90 | "T201", # Allow print
91 | ]
92 |
93 |
94 | "examples/**/*.ipynb" = [
95 | "E501",
96 | "T201", # Allow print
97 | ]
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
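
Because asyncio_mode = "auto" is set above, pytest-asyncio collects plain async def test functions without per-function decorators, which is why most async tests in this repository carry only class-level marks. A minimal, hypothetical illustration (the file name is made up):

# tests/unit_tests/test_example_async.py (hypothetical)
import asyncio


async def test_sleeps_briefly() -> None:
    # Collected and awaited automatically under asyncio_mode = "auto";
    # stays far below the global 30-second timeout configured above.
    await asyncio.sleep(0)
    assert True

--------------------------------------------------------------------------------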
/security.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Reporting OSS Vulnerabilities
4 |
5 | LangChain is partnered with [huntr by Protect AI](https://huntr.com/) to provide
6 | a bounty program for our open source projects.
7 |
8 | Please report security vulnerabilities associated with the LangChain
9 | open source projects by visiting the following link:
10 |
11 | [https://huntr.com/bounties/disclose/](https://huntr.com/bounties/disclose/?target=https%3A%2F%2Fgithub.com%2Flangchain-ai%2Flangchain&validSearch=true)
12 |
13 | Before reporting a vulnerability, please review:
14 |
15 | 1) In-Scope Targets and Out-of-Scope Targets below.
16 | 2) The [langchain-ai/langchain](https://python.langchain.com/docs/contributing/repo_structure) monorepo structure.
17 | 3) LangChain [security guidelines](https://python.langchain.com/docs/security) to
18 | understand what we consider to be a security vulnerability vs. developer
19 | responsibility.
20 |
21 | ### In-Scope Targets
22 |
23 | The following packages and repositories are eligible for bug bounties:
24 |
25 | - langchain-core
26 | - langchain (see exceptions)
27 | - langchain-community (see exceptions)
28 | - langgraph
29 | - langserve
30 |
31 | ### Out of Scope Targets
32 |
33 | All out of scope targets defined by huntr as well as:
34 |
35 | - **langchain-experimental**: This repository is for experimental code and is not
36 |   eligible for bug bounties; bug reports to it will be marked as interesting or a waste of
37 | time and published with no bounty attached.
38 | - **tools**: Tools in either langchain or langchain-community are not eligible for bug
39 |   bounties. This includes the following directories:
40 | - langchain/tools
41 | - langchain-community/tools
42 | - Please review our [security guidelines](https://python.langchain.com/docs/security)
43 | for more details, but generally tools interact with the real world. Developers are
44 | expected to understand the security implications of their code and are responsible
45 | for the security of their tools.
46 | - Code documented with security notices. This will be decided on a case by
47 | case basis, but likely will not be eligible for a bounty as the code is already
48 | documented with guidelines for developers that should be followed for making their
49 | application secure.
50 | - Any LangSmith-related repositories or APIs (see below).
51 |
52 | ## Reporting LangSmith Vulnerabilities
53 |
54 | Please report security vulnerabilities associated with LangSmith by email to `security@langchain.dev`.
55 |
56 | - LangSmith site: https://smith.langchain.com
57 | - SDK client: https://github.com/langchain-ai/langsmith-sdk
58 |
59 | ### Other Security Concerns
60 |
61 | For any other security concerns, please contact us at `security@langchain.dev`.
62 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/fake_embeddings.py:
--------------------------------------------------------------------------------
1 | """Copied from community."""
2 |
3 | from typing import List
4 |
5 | from langchain_core.embeddings import Embeddings
6 |
7 | fake_texts = ["foo", "bar", "baz"]
8 |
9 |
10 | class FakeEmbeddings(Embeddings):
11 | """Fake embeddings functionality for testing."""
12 |
13 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
14 | """Return simple embeddings.
15 | Embeddings encode each text as its index."""
16 | return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
17 |
18 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
19 | return self.embed_documents(texts)
20 |
21 | def embed_query(self, text: str) -> List[float]:
22 | """Return constant query embeddings.
23 | Embeddings are identical to embed_documents(texts)[0].
24 | Distance to each text will be that text's index,
25 | as it was passed to embed_documents."""
26 | return [float(1.0)] * 9 + [float(0.0)]
27 |
28 | async def aembed_query(self, text: str) -> List[float]:
29 | return self.embed_query(text)
30 |
--------------------------------------------------------------------------------
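
A small worked example of the scheme the docstrings above describe: every document embedding is [1.0] * 9 plus its index, and the query embedding equals the first document's embedding, so the Euclidean distance from the query to document i is exactly i. A self-contained sketch (assumes it is run from the repository root so the tests package is importable):

import math

from tests.unit_tests.fake_embeddings import FakeEmbeddings, fake_texts

emb = FakeEmbeddings()
doc_vectors = emb.embed_documents(fake_texts)  # last components are 0.0, 1.0, 2.0
query_vector = emb.embed_query("anything")     # identical to doc_vectors[0]

for text, vec in zip(fake_texts, doc_vectors):
    # Prints: foo 0.0, bar 1.0, baz 2.0
    print(text, math.dist(query_vector, vec))

--------------------------------------------------------------------------------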
/tests/unit_tests/fixtures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/fixtures/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/fixtures/filtering_test_cases.py:
--------------------------------------------------------------------------------
1 | """Module needs to move to a stasndalone package."""
2 |
3 | from langchain_core.documents import Document
4 |
5 | metadatas = [
6 | {
7 | "name": "adam",
8 | "date": "2021-01-01",
9 | "count": 1,
10 | "is_active": True,
11 | "tags": ["a", "b"],
12 | "location": [1.0, 2.0],
13 | "id": 1,
14 | "height": 10.0, # Float column
15 | "happiness": 0.9, # Float column
16 | "sadness": 0.1, # Float column
17 | },
18 | {
19 | "name": "bob",
20 | "date": "2021-01-02",
21 | "count": 2,
22 | "is_active": False,
23 | "tags": ["b", "c"],
24 | "location": [2.0, 3.0],
25 | "id": 2,
26 | "height": 5.7, # Float column
27 | "happiness": 0.8, # Float column
28 | "sadness": 0.1, # Float column
29 | },
30 | {
31 | "name": "jane",
32 | "date": "2021-01-01",
33 | "count": 3,
34 | "is_active": True,
35 | "tags": ["b", "d"],
36 | "location": [3.0, 4.0],
37 | "id": 3,
38 | "height": 2.4, # Float column
39 | "happiness": None,
40 | # Sadness missing intentionally
41 | },
42 | ]
43 | texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas]
44 |
45 | DOCUMENTS = [
46 | Document(page_content=text, metadata=metadata)
47 | for text, metadata in zip(texts, metadatas)
48 | ]
49 |
50 |
51 | TYPE_1_FILTERING_TEST_CASES = [
52 | # These tests only involve equality checks
53 | (
54 | {"id": 1},
55 | [1],
56 | ),
57 | # String field
58 | (
59 | # check name
60 | {"name": "adam"},
61 | [1],
62 | ),
63 | # Boolean fields
64 | (
65 | {"is_active": True},
66 | [1, 3],
67 | ),
68 | (
69 | {"is_active": False},
70 | [2],
71 | ),
72 | # And semantics for top level filtering
73 | (
74 | {"id": 1, "is_active": True},
75 | [1],
76 | ),
77 | (
78 | {"id": 1, "is_active": False},
79 | [],
80 | ),
81 | ]
82 |
83 | TYPE_2_FILTERING_TEST_CASES = [
84 | # These involve equality checks and other operators
85 | # like $ne, $gt, $gte, $lt, $lte
86 | (
87 | {"id": 1},
88 | [1],
89 | ),
90 | (
91 | {"id": {"$ne": 1}},
92 | [2, 3],
93 | ),
94 | (
95 | {"id": {"$gt": 1}},
96 | [2, 3],
97 | ),
98 | (
99 | {"id": {"$gte": 1}},
100 | [1, 2, 3],
101 | ),
102 | (
103 | {"id": {"$lt": 1}},
104 | [],
105 | ),
106 | (
107 | {"id": {"$lte": 1}},
108 | [1],
109 | ),
110 | # Repeat all the same tests with name (string column)
111 | (
112 | {"name": "adam"},
113 | [1],
114 | ),
115 | (
116 | {"name": "bob"},
117 | [2],
118 | ),
119 | (
120 | {"name": {"$eq": "adam"}},
121 | [1],
122 | ),
123 | (
124 | {"name": {"$ne": "adam"}},
125 | [2, 3],
126 | ),
127 | # And also gt, gte, lt, lte relying on lexicographical ordering
128 | (
129 | {"name": {"$gt": "jane"}},
130 | [],
131 | ),
132 | (
133 | {"name": {"$gte": "jane"}},
134 | [3],
135 | ),
136 | (
137 | {"name": {"$lt": "jane"}},
138 | [1, 2],
139 | ),
140 | (
141 | {"name": {"$lte": "jane"}},
142 | [1, 2, 3],
143 | ),
144 | (
145 | {"is_active": {"$eq": True}},
146 | [1, 3],
147 | ),
148 | (
149 | {"is_active": {"$ne": True}},
150 | [2],
151 | ),
152 | # Test float column.
153 | (
154 | {"height": {"$gt": 5.0}},
155 | [1, 2],
156 | ),
157 | (
158 | {"height": {"$gte": 5.0}},
159 | [1, 2],
160 | ),
161 | (
162 | {"height": {"$lt": 5.0}},
163 | [3],
164 | ),
165 | (
166 | {"height": {"$lte": 5.8}},
167 | [2, 3],
168 | ),
169 | ]
170 |
171 | TYPE_3_FILTERING_TEST_CASES = [
172 | # These involve usage of AND, OR and NOT operators
173 | (
174 | {"$or": [{"id": 1}, {"id": 2}]},
175 | [1, 2],
176 | ),
177 | (
178 | {"$or": [{"id": 1}, {"name": "bob"}]},
179 | [1, 2],
180 | ),
181 | (
182 | {"$and": [{"id": 1}, {"id": 2}]},
183 | [],
184 | ),
185 | (
186 | {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]},
187 | [1, 2, 3],
188 | ),
189 | # Test for $not operator
190 | (
191 | {"$not": {"id": 1}},
192 | [2, 3],
193 | ),
194 | (
195 | {"$not": [{"id": 1}]},
196 | [2, 3],
197 | ),
198 | (
199 | {"$not": {"name": "adam"}},
200 | [2, 3],
201 | ),
202 | (
203 | {"$not": [{"name": "adam"}]},
204 | [2, 3],
205 | ),
206 | (
207 | {"$not": {"is_active": True}},
208 | [2],
209 | ),
210 | (
211 | {"$not": [{"is_active": True}]},
212 | [2],
213 | ),
214 | (
215 | {"$not": {"height": {"$gt": 5.0}}},
216 | [3],
217 | ),
218 | (
219 | {"$not": [{"height": {"$gt": 5.0}}]},
220 | [3],
221 | ),
222 | ]
223 |
224 | TYPE_4_FILTERING_TEST_CASES = [
225 | # These involve special operators like $in, $nin, $between
226 | # Test between
227 | (
228 | {"id": {"$between": (1, 2)}},
229 | [1, 2],
230 | ),
231 | (
232 | {"id": {"$between": (1, 1)}},
233 | [1],
234 | ),
235 | # Test in
236 | (
237 | {"name": {"$in": ["adam", "bob"]}},
238 | [1, 2],
239 | ),
240 | # With numeric fields
241 | (
242 | {"id": {"$in": [1, 2]}},
243 | [1, 2],
244 | ),
245 | # Test nin
246 | (
247 | {"name": {"$nin": ["adam", "bob"]}},
248 | [3],
249 | ),
250 | ## with numeric fields
251 | (
252 | {"id": {"$nin": [1, 2]}},
253 | [3],
254 | ),
255 | ]
256 |
257 | TYPE_5_FILTERING_TEST_CASES = [
258 | # These involve special operators like $like, $ilike that
259 | # may be specified to certain databases.
260 | (
261 | {"name": {"$like": "a%"}},
262 | [1],
263 | ),
264 | (
265 | {"name": {"$like": "%a%"}}, # adam and jane
266 | [1, 3],
267 | ),
268 | ]
269 |
270 | TYPE_6_FILTERING_TEST_CASES = [
271 | # These involve the special operator $exists
272 | (
273 | {"happiness": {"$exists": False}},
274 | [],
275 | ),
276 | (
277 | {"happiness": {"$exists": True}},
278 | [1, 2, 3],
279 | ),
280 | (
281 | {"sadness": {"$exists": False}},
282 | [3],
283 | ),
284 | (
285 | {"sadness": {"$exists": True}},
286 | [1, 2],
287 | ),
288 | ]
289 |
--------------------------------------------------------------------------------
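
These fixtures are intended to be fed into parametrized tests against a populated store. A hedged sketch of how the TYPE_1 cases might be consumed; the pgvector_store fixture is hypothetical, and the assumption that the store's similarity_search accepts a filter keyword and that ids can be read back from document metadata is illustrative rather than prescriptive:

import pytest

from tests.unit_tests.fixtures.filtering_test_cases import (
    DOCUMENTS,
    TYPE_1_FILTERING_TEST_CASES,
)


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
def test_equality_filters(pgvector_store, test_filter, expected_ids) -> None:
    # pgvector_store is a hypothetical fixture yielding an empty, connected store.
    pgvector_store.add_documents(DOCUMENTS)
    docs = pgvector_store.similarity_search(
        "id", k=len(DOCUMENTS), filter=test_filter
    )
    assert sorted(doc.metadata["id"] for doc in docs) == sorted(expected_ids)

--------------------------------------------------------------------------------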
/tests/unit_tests/fixtures/metadata_filtering_data.py:
--------------------------------------------------------------------------------
1 | METADATAS = [
2 | {
3 | "name": "Wireless Headphones",
4 | "code": "WH001",
5 | "price": 149.99,
6 | "is_available": True,
7 | "release_date": "2023-10-26",
8 | "tags": ["audio", "wireless", "electronics"],
9 | "dimensions": [18.5, 7.2, 21.0],
10 | "inventory_location": [101, 102],
11 | "available_quantity": 50,
12 | },
13 | {
14 | "name": "Ergonomic Office Chair",
15 | "code": "EC002",
16 | "price": 299.00,
17 | "is_available": True,
18 | "release_date": "2023-08-15",
19 | "tags": ["furniture", "office", "ergonomic"],
20 | "dimensions": [65.0, 60.0, 110.0],
21 | "inventory_location": [201],
22 | "available_quantity": 10,
23 | },
24 | {
25 | "name": "Stainless Steel Water Bottle",
26 | "code": "WB003",
27 | "price": 25.50,
28 | "is_available": False,
29 | "release_date": "2024-01-05",
30 | "tags": ["hydration", "eco-friendly", "kitchen"],
31 | "dimensions": [7.5, 7.5, 25.0],
32 | "available_quantity": 0,
33 | },
34 | {
35 | "name": "Smart Fitness Tracker",
36 | "code": "FT004",
37 | "price": 79.95,
38 | "is_available": True,
39 | "release_date": "2023-11-12",
40 | "tags": ["fitness", "wearable", "technology"],
41 | "dimensions": [2.0, 1.0, 25.0],
42 | "inventory_location": [401],
43 | "available_quantity": 100,
44 | },
45 | ]
46 |
47 | FILTERING_TEST_CASES = [
48 | # These tests only involve equality checks
49 | (
50 | {"code": "FT004"},
51 | ["FT004"],
52 | ),
53 | # String field
54 | (
55 | # check name
56 | {"name": "Smart Fitness Tracker"},
57 | ["FT004"],
58 | ),
59 | # Boolean fields
60 | (
61 | {"is_available": True},
62 | ["WH001", "FT004", "EC002"],
63 | ),
64 | # And semantics for top level filtering
65 | (
66 | {"code": "WH001", "is_available": True},
67 | ["WH001"],
68 | ),
69 | # These involve equality checks and other operators
70 | # like $ne, $gt, $gte, $lt, $lte
71 | (
72 | {"available_quantity": {"$eq": 10}},
73 | ["EC002"],
74 | ),
75 | (
76 | {"available_quantity": {"$ne": 0}},
77 | ["WH001", "FT004", "EC002"],
78 | ),
79 | (
80 | {"available_quantity": {"$gt": 60}},
81 | ["FT004"],
82 | ),
83 | (
84 | {"available_quantity": {"$gte": 50}},
85 | ["WH001", "FT004"],
86 | ),
87 | (
88 | {"available_quantity": {"$lt": 5}},
89 | ["WB003"],
90 | ),
91 | (
92 | {"available_quantity": {"$lte": 10}},
93 | ["WB003", "EC002"],
94 | ),
95 | # Repeat all the same tests with name (string column)
96 | (
97 | {"code": {"$eq": "WH001"}},
98 | ["WH001"],
99 | ),
100 | (
101 | {"code": {"$ne": "WB003"}},
102 | ["WH001", "FT004", "EC002"],
103 | ),
104 | # And also gt, gte, lt, lte relying on lexicographical ordering
105 | (
106 | {"name": {"$gt": "Wireless Headphones"}},
107 | [],
108 | ),
109 | (
110 | {"name": {"$gte": "Wireless Headphones"}},
111 | ["WH001"],
112 | ),
113 | (
114 | {"name": {"$lt": "Smart Fitness Tracker"}},
115 | ["EC002"],
116 | ),
117 | (
118 | {"name": {"$lte": "Smart Fitness Tracker"}},
119 | ["FT004", "EC002"],
120 | ),
121 | (
122 | {"is_available": {"$eq": True}},
123 | ["WH001", "FT004", "EC002"],
124 | ),
125 | (
126 | {"is_available": {"$ne": True}},
127 | ["WB003"],
128 | ),
129 | # Test float column.
130 | (
131 | {"price": {"$gt": 200.0}},
132 | ["EC002"],
133 | ),
134 | (
135 | {"price": {"$gte": 149.99}},
136 | ["WH001", "EC002"],
137 | ),
138 | (
139 | {"price": {"$lt": 50.0}},
140 | ["WB003"],
141 | ),
142 | (
143 | {"price": {"$lte": 79.95}},
144 | ["FT004", "WB003"],
145 | ),
146 | # These involve usage of AND, OR and NOT operators
147 | (
148 | {"$or": [{"code": "WH001"}, {"code": "EC002"}]},
149 | ["WH001", "EC002"],
150 | ),
151 | (
152 | {"$or": [{"code": "WH001"}, {"available_quantity": 10}]},
153 | ["WH001", "EC002"],
154 | ),
155 | (
156 | {"$and": [{"code": "WH001"}, {"code": "EC002"}]},
157 | [],
158 | ),
159 | # Test for $not operator
160 | (
161 | {"$not": {"code": "WB003"}},
162 | ["WH001", "FT004", "EC002"],
163 | ),
164 | (
165 | {"$not": [{"code": "WB003"}]},
166 | ["WH001", "FT004", "EC002"],
167 | ),
168 | (
169 | {"$not": {"available_quantity": 0}},
170 | ["WH001", "FT004", "EC002"],
171 | ),
172 | (
173 | {"$not": [{"available_quantity": 0}]},
174 | ["WH001", "FT004", "EC002"],
175 | ),
176 | (
177 | {"$not": {"is_available": True}},
178 | ["WB003"],
179 | ),
180 | (
181 | {"$not": [{"is_available": True}]},
182 | ["WB003"],
183 | ),
184 | (
185 | {"$not": {"price": {"$gt": 150.0}}},
186 | ["WH001", "FT004", "WB003"],
187 | ),
188 | (
189 | {"$not": [{"price": {"$gt": 150.0}}]},
190 | ["WH001", "FT004", "WB003"],
191 | ),
192 | # These involve special operators like $in, $nin, $between
193 | # Test between
194 | (
195 | {"available_quantity": {"$between": (40, 60)}},
196 | ["WH001"],
197 | ),
198 | # Test in
199 | (
200 | {"name": {"$in": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}},
201 | ["FT004", "WB003"],
202 | ),
203 | # With numeric fields
204 | (
205 | {"available_quantity": {"$in": [0, 10]}},
206 | ["WB003", "EC002"],
207 | ),
208 | # Test nin
209 | (
210 | {"name": {"$nin": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}},
211 | ["WH001", "EC002"],
212 | ),
213 | ## with numeric fields
214 | (
215 | {"available_quantity": {"$nin": [50, 0, 10]}},
216 | ["FT004"],
217 | ),
218 | # These involve special operators like $like, $ilike that
219 | # may be specified to certain databases.
220 | (
221 | {"name": {"$like": "Wireless%"}},
222 | ["WH001"],
223 | ),
224 | (
225 | {"name": {"$like": "%less%"}}, # adam and jane
226 | ["WH001", "WB003"],
227 | ),
228 | # These involve the special operator $exists
229 | (
230 | {"tags": {"$exists": False}},
231 | [],
232 | ),
233 | (
234 | {"inventory_location": {"$exists": False}},
235 | ["WB003"],
236 | ),
237 | ]
238 |
239 | NEGATIVE_TEST_CASES = [
240 | {"$nor": [{"code": "WH001"}, {"code": "EC002"}]},
241 | {"$and": {"is_available": True}},
242 | {"is_available": {"$and": True}},
243 | {"is_available": {"name": "{Wireless Headphones", "code": "EC002"}},
244 | {"my column": {"$and": True}},
245 | {"is_available": {"code": "WH001"}},
246 | {"$and": {}},
247 | {"$and": []},
248 | {"$not": True},
249 | ]
250 |
--------------------------------------------------------------------------------
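
NEGATIVE_TEST_CASES above are filters the store is expected to reject outright. A sketch of how they might be exercised; both the pg_vectorstore fixture and the choice of ValueError as the expected exception are assumptions for illustration:

import pytest

from tests.unit_tests.fixtures.metadata_filtering_data import NEGATIVE_TEST_CASES


@pytest.mark.parametrize("bad_filter", NEGATIVE_TEST_CASES)
def test_invalid_filters_are_rejected(pg_vectorstore, bad_filter) -> None:
    # pg_vectorstore is a hypothetical fixture yielding a populated store.
    with pytest.raises(ValueError):
        pg_vectorstore.similarity_search("query", k=4, filter=bad_filter)

--------------------------------------------------------------------------------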
/tests/unit_tests/query_constructors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/query_constructors/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/query_constructors/test_pgvector.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import pytest as pytest
4 | from langchain_core.structured_query import (
5 | Comparator,
6 | Comparison,
7 | Operation,
8 | Operator,
9 | StructuredQuery,
10 | )
11 |
12 | from langchain_postgres import PGVectorTranslator
13 |
14 | DEFAULT_TRANSLATOR = PGVectorTranslator()
15 |
16 |
17 | def test_visit_comparison() -> None:
18 | comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
19 | expected = {"foo": {"$lt": 1}}
20 | actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
21 | assert expected == actual
22 |
23 |
24 | @pytest.mark.skip("Not implemented")
25 | def test_visit_operation() -> None:
26 | op = Operation(
27 | operator=Operator.AND,
28 | arguments=[
29 | Comparison(comparator=Comparator.LT, attribute="foo", value=2),
30 | Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
31 | Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
32 | ],
33 | )
34 | expected = {
35 | "foo": {"$lt": 2},
36 | "bar": {"$eq": "baz"},
37 | "abc": {"$gt": 2.0},
38 | }
39 | actual = DEFAULT_TRANSLATOR.visit_operation(op)
40 | assert expected == actual
41 |
42 |
43 | def test_visit_structured_query() -> None:
44 | query = "What is the capital of France?"
45 | structured_query = StructuredQuery(
46 | query=query,
47 | filter=None,
48 | )
49 | expected: Tuple[str, Dict] = (query, {})
50 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
51 | assert expected == actual
52 |
53 | comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
54 | structured_query = StructuredQuery(
55 | query=query,
56 | filter=comp,
57 | )
58 | expected = (query, {"filter": {"foo": {"$lt": 1}}})
59 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
60 | assert expected == actual
61 |
62 | op = Operation(
63 | operator=Operator.AND,
64 | arguments=[
65 | Comparison(comparator=Comparator.LT, attribute="foo", value=2),
66 | Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
67 | Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
68 | ],
69 | )
70 | structured_query = StructuredQuery(
71 | query=query,
72 | filter=op,
73 | )
74 | expected = (
75 | query,
76 | {
77 | "filter": {
78 | "$and": [
79 | {"foo": {"$lt": 2}},
80 | {"bar": {"$eq": "baz"}},
81 | {"abc": {"$gt": 2.0}},
82 | ]
83 | }
84 | },
85 | )
86 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
87 | assert expected == actual
88 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_imports.py:
--------------------------------------------------------------------------------
1 | from langchain_postgres import __all__
2 |
3 | EXPECTED_ALL = [
4 | "__version__",
5 | "Column",
6 | "ColumnDict",
7 | "PGEngine",
8 | "PGVector",
9 | "PGVectorStore",
10 | "PGVectorTranslator",
11 | "PostgresChatMessageHistory",
12 | ]
13 |
14 |
15 | def test_all_imports() -> None:
16 | """Test that __all__ is correctly defined."""
17 | assert sorted(EXPECTED_ALL) == sorted(__all__)
18 |
--------------------------------------------------------------------------------
/tests/unit_tests/v1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/v1/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/v1/test_chat_histories.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
4 |
5 | from langchain_postgres.chat_message_histories import PostgresChatMessageHistory
6 | from tests.utils import asyncpg_client, syncpg_client
7 |
8 |
9 | def test_sync_chat_history() -> None:
10 | table_name = "chat_history"
11 | session_id = str(uuid.UUID(int=123))
12 | with syncpg_client() as sync_connection:
13 | PostgresChatMessageHistory.drop_table(sync_connection, table_name)
14 | PostgresChatMessageHistory.create_tables(sync_connection, table_name)
15 |
16 | chat_history = PostgresChatMessageHistory(
17 | table_name, session_id, sync_connection=sync_connection
18 | )
19 |
20 | messages = chat_history.messages
21 | assert messages == []
22 |
23 | assert chat_history is not None
24 |
25 | # Get messages from the chat history
26 | messages = chat_history.messages
27 | assert messages == []
28 |
29 | chat_history.add_messages(
30 | [
31 | SystemMessage(content="Meow"),
32 | AIMessage(content="woof"),
33 | HumanMessage(content="bark"),
34 | ]
35 | )
36 |
37 | # Get messages from the chat history
38 | messages = chat_history.messages
39 | assert len(messages) == 3
40 | assert messages == [
41 | SystemMessage(content="Meow"),
42 | AIMessage(content="woof"),
43 | HumanMessage(content="bark"),
44 | ]
45 |
46 | chat_history.add_messages(
47 | [
48 | SystemMessage(content="Meow"),
49 | AIMessage(content="woof"),
50 | HumanMessage(content="bark"),
51 | ]
52 | )
53 |
54 | messages = chat_history.messages
55 | assert len(messages) == 6
56 | assert messages == [
57 | SystemMessage(content="Meow"),
58 | AIMessage(content="woof"),
59 | HumanMessage(content="bark"),
60 | SystemMessage(content="Meow"),
61 | AIMessage(content="woof"),
62 | HumanMessage(content="bark"),
63 | ]
64 |
65 | chat_history.clear()
66 | assert chat_history.messages == []
67 |
68 |
69 | async def test_async_chat_history() -> None:
70 | """Test the async chat history."""
71 | async with asyncpg_client() as async_connection:
72 | table_name = "chat_history"
73 | session_id = str(uuid.UUID(int=125))
74 | await PostgresChatMessageHistory.adrop_table(async_connection, table_name)
75 | await PostgresChatMessageHistory.acreate_tables(async_connection, table_name)
76 |
77 | chat_history = PostgresChatMessageHistory(
78 | table_name, session_id, async_connection=async_connection
79 | )
80 |
81 | messages = await chat_history.aget_messages()
82 | assert messages == []
83 |
84 | # Add messages
85 | await chat_history.aadd_messages(
86 | [
87 | SystemMessage(content="Meow"),
88 | AIMessage(content="woof"),
89 | HumanMessage(content="bark"),
90 | ]
91 | )
92 | # Get the messages
93 | messages = await chat_history.aget_messages()
94 | assert len(messages) == 3
95 | assert messages == [
96 | SystemMessage(content="Meow"),
97 | AIMessage(content="woof"),
98 | HumanMessage(content="bark"),
99 | ]
100 |
101 | # Add more messages
102 | await chat_history.aadd_messages(
103 | [
104 | SystemMessage(content="Meow"),
105 | AIMessage(content="woof"),
106 | HumanMessage(content="bark"),
107 | ]
108 | )
109 | # Get the messages
110 | messages = await chat_history.aget_messages()
111 | assert len(messages) == 6
112 | assert messages == [
113 | SystemMessage(content="Meow"),
114 | AIMessage(content="woof"),
115 | HumanMessage(content="bark"),
116 | SystemMessage(content="Meow"),
117 | AIMessage(content="woof"),
118 | HumanMessage(content="bark"),
119 | ]
120 |
121 | # clear
122 | await chat_history.aclear()
123 | assert await chat_history.aget_messages() == []
124 |
--------------------------------------------------------------------------------
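
For reference, the same flow outside the test harness, using a direct psycopg connection instead of the syncpg_client helper from tests.utils. The DSN is a placeholder; everything else mirrors the calls made in the sync test above:

import uuid

import psycopg
from langchain_core.messages import AIMessage, HumanMessage

from langchain_postgres.chat_message_histories import PostgresChatMessageHistory

# Placeholder DSN; point this at a real Postgres instance.
with psycopg.connect("postgresql://user:pass@localhost:5432/db") as conn:
    table_name = "chat_history"
    PostgresChatMessageHistory.create_tables(conn, table_name)

    history = PostgresChatMessageHistory(
        table_name, str(uuid.uuid4()), sync_connection=conn
    )
    history.add_messages(
        [HumanMessage(content="hi"), AIMessage(content="hello")]
    )
    assert len(history.messages) == 2
    history.clear()

--------------------------------------------------------------------------------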
/tests/unit_tests/v1/test_vectorstore_standard_tests.py:
--------------------------------------------------------------------------------
1 | from typing import AsyncGenerator, Generator
2 |
3 | import pytest
4 | from langchain_core.vectorstores import VectorStore
5 | from langchain_tests.integration_tests import VectorStoreIntegrationTests
6 |
7 | from tests.unit_tests.v1.test_vectorstore import aget_vectorstore, get_vectorstore
8 |
9 |
10 | class TestSync(VectorStoreIntegrationTests):
11 | @pytest.fixture()
12 | def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore
13 | """Get an empty vectorstore for unit tests."""
14 | with get_vectorstore(embedding=self.get_embeddings()) as vstore:
15 | vstore.drop_tables()
16 | vstore.create_tables_if_not_exists()
17 | vstore.create_collection()
18 | yield vstore
19 |
20 | @property
21 | def has_async(self) -> bool:
22 | return False # Skip async tests for sync vector store
23 |
24 |
25 | class TestAsync(VectorStoreIntegrationTests):
26 | @pytest.fixture()
27 | async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore
28 | """Get an empty vectorstore for unit tests."""
29 | async with aget_vectorstore(embedding=self.get_embeddings()) as vstore:
30 | await vstore.adrop_tables()
31 | await vstore.acreate_tables_if_not_exists()
32 | await vstore.acreate_collection()
33 | yield vstore
34 |
35 | @property
36 | def has_sync(self) -> bool:
37 | return False # Skip sync tests for async vector store
38 |
--------------------------------------------------------------------------------
/tests/unit_tests/v2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/v2/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/v2/test_async_pg_vectorstore.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import AsyncIterator, Sequence
3 |
4 | import pytest
5 | import pytest_asyncio
6 | from langchain_core.documents import Document
7 | from langchain_core.embeddings import DeterministicFakeEmbedding
8 | from sqlalchemy import text
9 | from sqlalchemy.engine.row import RowMapping
10 |
11 | from langchain_postgres import Column, PGEngine
12 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore
13 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
14 |
15 | DEFAULT_TABLE = "default" + str(uuid.uuid4())
16 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4())
17 | CUSTOM_TABLE = "custom" + str(uuid.uuid4())
18 | VECTOR_SIZE = 768
19 |
20 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
21 |
22 | texts = ["foo", "bar", "baz"]
23 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
24 | docs = [
25 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
26 | ]
27 |
28 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
29 |
30 |
31 | async def aexecute(engine: PGEngine, query: str) -> None:
32 | async with engine._pool.connect() as conn:
33 | await conn.execute(text(query))
34 | await conn.commit()
35 |
36 |
37 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]:
38 | async with engine._pool.connect() as conn:
39 | result = await conn.execute(text(query))
40 | result_map = result.mappings()
41 | result_fetch = result_map.fetchall()
42 | return result_fetch
43 |
44 |
45 | @pytest.mark.enable_socket
46 | @pytest.mark.asyncio(scope="class")
47 | class TestVectorStore:
48 | @pytest_asyncio.fixture(scope="class")
49 | async def engine(self) -> AsyncIterator[PGEngine]:
50 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
51 |
52 | yield engine
53 | await engine.adrop_table(DEFAULT_TABLE)
54 | await engine.adrop_table(CUSTOM_TABLE)
55 | await engine.close()
56 |
57 | @pytest_asyncio.fixture(scope="class")
58 | async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
59 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
60 | vs = await AsyncPGVectorStore.create(
61 | engine,
62 | embedding_service=embeddings_service,
63 | table_name=DEFAULT_TABLE,
64 | )
65 | yield vs
66 |
67 | @pytest_asyncio.fixture(scope="class")
68 | async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
69 | await engine._ainit_vectorstore_table(
70 | CUSTOM_TABLE,
71 | VECTOR_SIZE,
72 | id_column="myid",
73 | content_column="mycontent",
74 | embedding_column="myembedding",
75 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
76 | metadata_json_column="mymeta",
77 | )
78 | vs = await AsyncPGVectorStore.create(
79 | engine,
80 | embedding_service=embeddings_service,
81 | table_name=CUSTOM_TABLE,
82 | id_column="myid",
83 | content_column="mycontent",
84 | embedding_column="myembedding",
85 | metadata_columns=["page", "source"],
86 | metadata_json_column="mymeta",
87 | )
88 | yield vs
89 |
90 | async def test_init_with_constructor(self, engine: PGEngine) -> None:
91 | with pytest.raises(Exception):
92 | AsyncPGVectorStore(
93 | key={},
94 | engine=engine._pool,
95 | embedding_service=embeddings_service,
96 | table_name=CUSTOM_TABLE,
97 | id_column="myid",
98 | content_column="noname",
99 | embedding_column="myembedding",
100 | metadata_columns=["page", "source"],
101 | metadata_json_column="mymeta",
102 | )
103 |
104 | async def test_post_init(self, engine: PGEngine) -> None:
105 | with pytest.raises(ValueError):
106 | await AsyncPGVectorStore.create(
107 | engine,
108 | embedding_service=embeddings_service,
109 | table_name=CUSTOM_TABLE,
110 | id_column="myid",
111 | content_column="noname",
112 | embedding_column="myembedding",
113 | metadata_columns=["page", "source"],
114 | metadata_json_column="mymeta",
115 | )
116 |
117 | async def test_aadd_texts(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None:
118 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
119 | await vs.aadd_texts(texts, ids=ids)
120 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
121 | assert len(results) == 3
122 |
123 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
124 | await vs.aadd_texts(texts, metadatas, ids)
125 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
126 | assert len(results) == 6
127 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
128 |
129 | async def test_aadd_texts_edge_cases(
130 | self, engine: PGEngine, vs: AsyncPGVectorStore
131 | ) -> None:
132 | texts = ["Taylor's", '"Swift"', "best-friend"]
133 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
134 | await vs.aadd_texts(texts, ids=ids)
135 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
136 | assert len(results) == 3
137 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
138 |
139 | async def test_aadd_docs(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None:
140 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
141 | await vs.aadd_documents(docs, ids=ids)
142 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
143 | assert len(results) == 3
144 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
145 |
146 | async def test_aadd_docs_no_ids(
147 | self, engine: PGEngine, vs: AsyncPGVectorStore
148 | ) -> None:
149 | await vs.aadd_documents(docs)
150 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
151 | assert len(results) == 3
152 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
153 |
154 | async def test_adelete(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None:
155 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
156 | await vs.aadd_texts(texts, ids=ids)
157 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
158 | assert len(results) == 3
159 | # delete an ID
160 | await vs.adelete([ids[0]])
161 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"')
162 | assert len(results) == 2
163 | # delete with no ids
164 | result = await vs.adelete()
165 | assert not result
166 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"')
167 |
168 | ##### Custom Vector Store #####
169 | async def test_aadd_embeddings(
170 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore
171 | ) -> None:
172 | await vs_custom.aadd_embeddings(
173 | texts=texts, embeddings=embeddings, metadatas=metadatas
174 | )
175 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
176 | assert len(results) == 3
177 | assert results[0]["mycontent"] == "foo"
178 | assert results[0]["myembedding"]
179 | assert results[0]["page"] == "0"
180 | assert results[0]["source"] == "postgres"
181 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
182 |
183 | async def test_aadd_texts_custom(
184 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore
185 | ) -> None:
186 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
187 | await vs_custom.aadd_texts(texts, ids=ids)
188 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
189 | assert len(results) == 3
190 | assert results[0]["mycontent"] == "foo"
191 | assert results[0]["myembedding"]
192 | assert results[0]["page"] is None
193 | assert results[0]["source"] is None
194 |
195 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
196 | await vs_custom.aadd_texts(texts, metadatas, ids)
197 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
198 | assert len(results) == 6
199 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
200 |
201 | async def test_aadd_docs_custom(
202 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore
203 | ) -> None:
204 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
205 | docs = [
206 | Document(
207 | page_content=texts[i],
208 | metadata={"page": str(i), "source": "postgres"},
209 | )
210 | for i in range(len(texts))
211 | ]
212 | await vs_custom.aadd_documents(docs, ids=ids)
213 |
214 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
215 | assert len(results) == 3
216 | assert results[0]["mycontent"] == "foo"
217 | assert results[0]["myembedding"]
218 | assert results[0]["page"] == "0"
219 | assert results[0]["source"] == "postgres"
220 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
221 |
222 | async def test_adelete_custom(
223 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore
224 | ) -> None:
225 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
226 | await vs_custom.aadd_texts(texts, ids=ids)
227 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
228 | content = [result["mycontent"] for result in results]
229 | assert len(results) == 3
230 | assert "foo" in content
231 | # delete an ID
232 | await vs_custom.adelete([ids[0]])
233 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"')
234 | content = [result["mycontent"] for result in results]
235 | assert len(results) == 2
236 | assert "foo" not in content
237 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"')
238 |
239 | async def test_ignore_metadata_columns(self, engine: PGEngine) -> None:
240 | column_to_ignore = "source"
241 | vs = await AsyncPGVectorStore.create(
242 | engine,
243 | embedding_service=embeddings_service,
244 | table_name=CUSTOM_TABLE,
245 | ignore_metadata_columns=[column_to_ignore],
246 | id_column="myid",
247 | content_column="mycontent",
248 | embedding_column="myembedding",
249 | metadata_json_column="mymeta",
250 | )
251 | assert column_to_ignore not in vs.metadata_columns
252 |
253 | async def test_create_vectorstore_with_invalid_parameters_1(
254 | self, engine: PGEngine
255 | ) -> None:
256 | with pytest.raises(ValueError):
257 | await AsyncPGVectorStore.create(
258 | engine,
259 | embedding_service=embeddings_service,
260 | table_name=CUSTOM_TABLE,
261 | id_column="myid",
262 | content_column="mycontent",
263 | embedding_column="myembedding",
264 | metadata_columns=["random_column"], # invalid metadata column
265 | )
266 |
267 | async def test_create_vectorstore_with_invalid_parameters_2(
268 | self, engine: PGEngine
269 | ) -> None:
270 | with pytest.raises(ValueError):
271 | await AsyncPGVectorStore.create(
272 | engine,
273 | embedding_service=embeddings_service,
274 | table_name=CUSTOM_TABLE,
275 | id_column="myid",
276 | content_column="langchain_id", # invalid content column type
277 | embedding_column="myembedding",
278 | metadata_columns=["random_column"],
279 | )
280 |
281 | async def test_create_vectorstore_with_invalid_parameters_3(
282 | self, engine: PGEngine
283 | ) -> None:
284 | with pytest.raises(ValueError):
285 | await AsyncPGVectorStore.create(
286 | engine,
287 | embedding_service=embeddings_service,
288 | table_name=CUSTOM_TABLE,
289 | id_column="myid",
290 | content_column="mycontent",
291 | embedding_column="random_column", # invalid embedding column
292 | metadata_columns=["random_column"],
293 | )
294 |
295 | async def test_create_vectorstore_with_invalid_parameters_4(
296 | self, engine: PGEngine
297 | ) -> None:
298 | with pytest.raises(ValueError):
299 | await AsyncPGVectorStore.create(
300 | engine,
301 | embedding_service=embeddings_service,
302 | table_name=CUSTOM_TABLE,
303 | id_column="myid",
304 | content_column="mycontent",
305 | embedding_column="langchain_id", # invalid embedding column data type
306 | metadata_columns=["random_column"],
307 | )
308 |
309 | async def test_create_vectorstore_with_invalid_parameters_5(
310 | self, engine: PGEngine
311 | ) -> None:
312 | with pytest.raises(ValueError):
313 | await AsyncPGVectorStore.create(
314 | engine,
315 | embedding_service=embeddings_service,
316 | table_name=CUSTOM_TABLE,
317 | id_column="myid",
318 | content_column="mycontent",
319 | embedding_column="langchain_id",
320 | metadata_columns=["random_column"],
321 | ignore_metadata_columns=[
322 | "one",
323 | "two",
324 | ], # invalid use of metadata_columns and ignore columns
325 | )
326 |
327 | async def test_create_vectorstore_with_init(self, engine: PGEngine) -> None:
328 | with pytest.raises(Exception):
329 | AsyncPGVectorStore(
330 | key={},
331 | engine=engine._pool,
332 | embedding_service=embeddings_service,
333 | table_name=CUSTOM_TABLE,
334 | id_column="myid",
335 | content_column="mycontent",
336 | embedding_column="myembedding",
337 | metadata_columns=["random_column"], # invalid metadata column
338 | )
339 |
--------------------------------------------------------------------------------
/tests/unit_tests/v2/test_async_pg_vectorstore_from_methods.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import AsyncIterator, Sequence
4 |
5 | import pytest
6 | import pytest_asyncio
7 | from langchain_core.documents import Document
8 | from langchain_core.embeddings import DeterministicFakeEmbedding
9 | from sqlalchemy import text
10 | from sqlalchemy.engine.row import RowMapping
11 |
12 | from langchain_postgres import Column, PGEngine
13 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore
14 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
15 |
16 | DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_")
17 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_")
18 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_")
19 | CUSTOM_TABLE_WITH_INT_ID = "custom_sync" + str(uuid.uuid4()).replace("-", "_")
20 | VECTOR_SIZE = 768
21 |
22 |
23 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
24 |
25 | texts = ["foo", "bar", "baz"]
26 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
27 | docs = [
28 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
29 | ]
30 |
31 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
32 |
33 |
34 | def get_env_var(key: str, desc: str) -> str:
35 | v = os.environ.get(key)
36 | if v is None:
37 | raise ValueError(f"Must set env var {key} to: {desc}")
38 | return v
39 |
40 |
41 | async def aexecute(engine: PGEngine, query: str) -> None:
42 | async with engine._pool.connect() as conn:
43 | await conn.execute(text(query))
44 | await conn.commit()
45 |
46 |
47 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]:
48 | async with engine._pool.connect() as conn:
49 | result = await conn.execute(text(query))
50 | result_map = result.mappings()
51 | result_fetch = result_map.fetchall()
52 | return result_fetch
53 |
54 |
55 | @pytest.mark.enable_socket
56 | @pytest.mark.asyncio
57 | class TestVectorStoreFromMethods:
58 | @pytest_asyncio.fixture
59 | async def engine(self) -> AsyncIterator[PGEngine]:
60 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
61 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
62 | await engine._ainit_vectorstore_table(
63 | CUSTOM_TABLE,
64 | VECTOR_SIZE,
65 | id_column="myid",
66 | content_column="mycontent",
67 | embedding_column="myembedding",
68 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
69 | store_metadata=False,
70 | )
71 | await engine._ainit_vectorstore_table(
72 | CUSTOM_TABLE_WITH_INT_ID,
73 | VECTOR_SIZE,
74 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False),
75 | content_column="mycontent",
76 | embedding_column="myembedding",
77 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
78 | store_metadata=False,
79 | )
80 | yield engine
81 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
82 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
83 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}")
84 | await engine.close()
85 |
86 | async def test_afrom_texts(self, engine: PGEngine) -> None:
87 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
88 | await AsyncPGVectorStore.afrom_texts(
89 | texts,
90 | embeddings_service,
91 | engine,
92 | DEFAULT_TABLE,
93 | metadatas=metadatas,
94 | ids=ids,
95 | )
96 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}")
97 | assert len(results) == 3
98 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}")
99 |
100 | async def test_afrom_docs(self, engine: PGEngine) -> None:
101 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
102 | await AsyncPGVectorStore.afrom_documents(
103 | docs,
104 | embeddings_service,
105 | engine,
106 | DEFAULT_TABLE,
107 | ids=ids,
108 | )
109 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}")
110 | assert len(results) == 3
111 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}")
112 |
113 | async def test_afrom_texts_custom(self, engine: PGEngine) -> None:
114 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
115 | await AsyncPGVectorStore.afrom_texts(
116 | texts,
117 | embeddings_service,
118 | engine,
119 | CUSTOM_TABLE,
120 | ids=ids,
121 | id_column="myid",
122 | content_column="mycontent",
123 | embedding_column="myembedding",
124 | metadata_columns=["page", "source"],
125 | )
126 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
127 | assert len(results) == 3
128 | assert results[0]["mycontent"] == "foo"
129 | assert results[0]["myembedding"]
130 | assert results[0]["page"] is None
131 | assert results[0]["source"] is None
132 |
133 | async def test_afrom_docs_custom(self, engine: PGEngine) -> None:
134 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
135 | docs = [
136 | Document(
137 | page_content=texts[i],
138 | metadata={"page": str(i), "source": "postgres"},
139 | )
140 | for i in range(len(texts))
141 | ]
142 | await AsyncPGVectorStore.afrom_documents(
143 | docs,
144 | embeddings_service,
145 | engine,
146 | CUSTOM_TABLE,
147 | ids=ids,
148 | id_column="myid",
149 | content_column="mycontent",
150 | embedding_column="myembedding",
151 | metadata_columns=["page", "source"],
152 | )
153 |
154 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
155 | assert len(results) == 3
156 | assert results[0]["mycontent"] == "foo"
157 | assert results[0]["myembedding"]
158 | assert results[0]["page"] == "0"
159 | assert results[0]["source"] == "postgres"
160 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}")
161 |
162 | async def test_afrom_docs_custom_with_int_id(self, engine: PGEngine) -> None:
163 | ids = [i for i in range(len(texts))]
164 | docs = [
165 | Document(
166 | page_content=texts[i],
167 | metadata={"page": str(i), "source": "postgres"},
168 | )
169 | for i in range(len(texts))
170 | ]
171 | await AsyncPGVectorStore.afrom_documents(
172 | docs,
173 | embeddings_service,
174 | engine,
175 | CUSTOM_TABLE_WITH_INT_ID,
176 | ids=ids,
177 | id_column="integer_id",
178 | content_column="mycontent",
179 | embedding_column="myembedding",
180 | metadata_columns=["page", "source"],
181 | )
182 |
183 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}")
184 | assert len(results) == 3
185 | for row in results:
186 | assert isinstance(row["integer_id"], int)
187 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}")
188 |
--------------------------------------------------------------------------------
/tests/unit_tests/v2/test_async_pg_vectorstore_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import AsyncIterator
4 |
5 | import pytest
6 | import pytest_asyncio
7 | from langchain_core.documents import Document
8 | from langchain_core.embeddings import DeterministicFakeEmbedding
9 | from sqlalchemy import text
10 |
11 | from langchain_postgres import PGEngine
12 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore
13 | from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig
14 | from langchain_postgres.v2.indexes import DistanceStrategy, HNSWIndex, IVFFlatIndex
15 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
16 |
17 | uuid_str = str(uuid.uuid4()).replace("-", "_")
18 | DEFAULT_TABLE = "default" + uuid_str
19 | DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str
20 | DEFAULT_INDEX_NAME = "index" + uuid_str
21 | VECTOR_SIZE = 768
22 | SIMPLE_TABLE = "default_table"
23 |
24 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
25 |
26 | texts = ["foo", "bar", "baz"]
27 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
28 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
29 | docs = [
30 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
31 | ]
32 |
33 | embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]
34 |
35 |
36 | def get_env_var(key: str, desc: str) -> str:
37 | v = os.environ.get(key)
38 | if v is None:
39 | raise ValueError(f"Must set env var {key} to: {desc}")
40 | return v
41 |
42 |
43 | async def aexecute(engine: PGEngine, query: str) -> None:
44 | async with engine._pool.connect() as conn:
45 | await conn.execute(text(query))
46 | await conn.commit()
47 |
48 |
49 | @pytest.mark.enable_socket
50 | @pytest.mark.asyncio(scope="class")
51 | class TestIndex:
52 | @pytest_asyncio.fixture(scope="class")
53 | async def engine(self) -> AsyncIterator[PGEngine]:
54 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
55 | yield engine
56 |
57 | await engine._adrop_table(DEFAULT_TABLE)
58 | await engine._adrop_table(DEFAULT_HYBRID_TABLE)
59 | await engine._adrop_table(SIMPLE_TABLE)
60 | await engine.close()
61 |
62 | @pytest_asyncio.fixture(scope="class")
63 | async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]:
64 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
65 | vs = await AsyncPGVectorStore.create(
66 | engine,
67 | embedding_service=embeddings_service,
68 | table_name=DEFAULT_TABLE,
69 | )
70 |
71 | await vs.aadd_texts(texts, ids=ids)
72 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
73 | yield vs
74 |
75 | async def test_apply_default_name_vector_index(self, engine: PGEngine) -> None:
76 | await engine._ainit_vectorstore_table(
77 | SIMPLE_TABLE, VECTOR_SIZE, overwrite_existing=True
78 | )
79 | vs = await AsyncPGVectorStore.create(
80 | engine,
81 | embedding_service=embeddings_service,
82 | table_name=SIMPLE_TABLE,
83 | )
84 | await vs.aadd_texts(texts, ids=ids)
85 | await vs.adrop_vector_index()
86 | index = HNSWIndex()
87 | await vs.aapply_vector_index(index)
88 | assert await vs.is_valid_index()
89 | await vs.adrop_vector_index()
90 |
91 | async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None:
92 | index = HNSWIndex(name=DEFAULT_INDEX_NAME)
93 | await vs.aapply_vector_index(index)
94 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
95 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
96 |
97 | async def test_aapply_vector_index_non_hybrid_search_vs(
98 | self, vs: AsyncPGVectorStore
99 | ) -> None:
100 | with pytest.raises(ValueError):
101 | await vs.aapply_hybrid_search_index()
102 |
103 | async def test_aapply_hybrid_search_index_table_without_tsv_column(
104 | self, engine: PGEngine, vs: AsyncPGVectorStore
105 | ) -> None:
106 | # overwriting vs to get a hybrid vs
107 | tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str
108 | vs = await AsyncPGVectorStore.create(
109 | engine,
110 | embedding_service=embeddings_service,
111 | table_name=DEFAULT_TABLE,
112 | hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name),
113 | )
114 | is_valid_index = await vs.is_valid_index(tsv_index_name)
115 | assert is_valid_index == False
116 | await vs.aapply_hybrid_search_index()
117 | assert await vs.is_valid_index(tsv_index_name)
118 | await vs.adrop_vector_index(tsv_index_name)
119 | is_valid_index = await vs.is_valid_index(tsv_index_name)
120 | assert is_valid_index == False
121 |
122 | async def test_aapply_hybrid_search_index_table_with_tsv_column(
123 | self, engine: PGEngine
124 | ) -> None:
125 | tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str
126 | config = HybridSearchConfig(
127 | tsv_column="tsv_column",
128 | tsv_lang="pg_catalog.english",
129 | index_name=tsv_index_name,
130 | )
131 | await engine._ainit_vectorstore_table(
132 | DEFAULT_HYBRID_TABLE,
133 | VECTOR_SIZE,
134 | hybrid_search_config=config,
135 | )
136 | vs = await AsyncPGVectorStore.create(
137 | engine,
138 | embedding_service=embeddings_service,
139 | table_name=DEFAULT_HYBRID_TABLE,
140 | hybrid_search_config=config,
141 | )
142 | is_valid_index = await vs.is_valid_index(tsv_index_name)
143 | assert is_valid_index == False
144 | await vs.aapply_hybrid_search_index()
145 | assert await vs.is_valid_index(tsv_index_name)
146 | await vs.areindex(tsv_index_name)
147 | assert await vs.is_valid_index(tsv_index_name)
148 | await vs.adrop_vector_index(tsv_index_name)
149 | is_valid_index = await vs.is_valid_index(tsv_index_name)
150 | assert is_valid_index == False
151 |
152 | async def test_areindex(self, vs: AsyncPGVectorStore) -> None:
153 | if not await vs.is_valid_index(DEFAULT_INDEX_NAME):
154 | index = HNSWIndex(name=DEFAULT_INDEX_NAME)
155 | await vs.aapply_vector_index(index)
156 | await vs.areindex(DEFAULT_INDEX_NAME)
157 | await vs.areindex(DEFAULT_INDEX_NAME)
158 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
159 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
160 |
161 | async def test_dropindex(self, vs: AsyncPGVectorStore) -> None:
162 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
163 | result = await vs.is_valid_index(DEFAULT_INDEX_NAME)
164 | assert not result
165 |
166 | async def test_aapply_vector_index_ivfflat(self, vs: AsyncPGVectorStore) -> None:
167 | index = IVFFlatIndex(
168 | name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN
169 | )
170 | await vs.aapply_vector_index(index, concurrently=True)
171 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
172 | index = IVFFlatIndex(
173 | name="secondindex",
174 | distance_strategy=DistanceStrategy.INNER_PRODUCT,
175 | )
176 | await vs.aapply_vector_index(index)
177 | assert await vs.is_valid_index("secondindex")
178 | await vs.adrop_vector_index("secondindex")
179 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME)
180 |
181 | async def test_is_valid_index(self, vs: AsyncPGVectorStore) -> None:
182 | is_valid = await vs.is_valid_index("invalid_index")
183 | assert not is_valid
184 |
--------------------------------------------------------------------------------
/tests/unit_tests/v2/test_hybrid_search_config.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from langchain_postgres.v2.hybrid_search_config import (
4 | reciprocal_rank_fusion,
5 | weighted_sum_ranking,
6 | )
7 |
8 |
9 | # Helper to create mock input items that mimic RowMapping for the fusion functions
10 | def get_row(doc_id: str, score: float, content: str = "content") -> dict:
11 | """
12 | Simulates a RowMapping-like dictionary.
13 | The fusion functions expect to extract doc_id as the first value and
14 | the initial score/distance as the last value when casting values from RowMapping.
15 | They then operate on dictionaries, using the 'distance' key for the fused score.
16 | """
17 | # Python dicts maintain insertion order (Python 3.7+).
18 | # This structure ensures list(row.values())[0] is doc_id and
19 | # list(row.values())[-1] is score.
20 | return {"id_val": doc_id, "content_field": content, "distance": score}
21 |
22 |
23 | class TestWeightedSumRanking:
24 | def test_empty_inputs(self) -> None:
25 | results = weighted_sum_ranking([], [])
26 | assert results == []
27 |
28 | def test_primary_only(self) -> None:
29 | primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
30 | # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
31 | results = weighted_sum_ranking( # type: ignore
32 | primary, # type: ignore
33 | [],
34 | primary_results_weight=0.5,
35 | secondary_results_weight=0.5,
36 | )
37 | assert len(results) == 2
38 | assert results[0]["id_val"] == "p1"
39 | assert results[0]["distance"] == pytest.approx(0.4)
40 | assert results[1]["id_val"] == "p2"
41 | assert results[1]["distance"] == pytest.approx(0.3)
42 |
43 | def test_secondary_only(self) -> None:
44 | secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
45 | # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
46 | results = weighted_sum_ranking(
47 | [],
48 | secondary, # type: ignore
49 | primary_results_weight=0.5,
50 | secondary_results_weight=0.5,
51 | )
52 | assert len(results) == 2
53 | assert results[0]["id_val"] == "s1"
54 | assert results[0]["distance"] == pytest.approx(0.45)
55 | assert results[1]["id_val"] == "s2"
56 | assert results[1]["distance"] == pytest.approx(0.35)
57 |
58 | def test_mixed_results_default_weights(self) -> None:
59 | primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
60 | secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
61 | # Weights are 0.5, 0.5
62 | # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85
63 | # p_only_score = (0.7 * 0.5) = 0.35
64 | # s_only_score = (0.6 * 0.5) = 0.30
65 | # Order: common (0.85), p_only (0.35), s_only (0.30)
66 |
67 | results = weighted_sum_ranking(primary, secondary) # type: ignore
68 | assert len(results) == 3
69 | assert results[0]["id_val"] == "common"
70 | assert results[0]["distance"] == pytest.approx(0.85)
71 | assert results[1]["id_val"] == "p_only"
72 | assert results[1]["distance"] == pytest.approx(0.35)
73 | assert results[2]["id_val"] == "s_only"
74 | assert results[2]["distance"] == pytest.approx(0.30)
75 |
76 | def test_mixed_results_custom_weights(self) -> None:
77 | primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2
78 | secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4
79 | # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6
80 |
81 | results = weighted_sum_ranking(
82 | primary, # type: ignore
83 | secondary, # type: ignore
84 | primary_results_weight=0.2,
85 | secondary_results_weight=0.8,
86 | )
87 | assert len(results) == 1
88 | assert results[0]["id_val"] == "d1"
89 | assert results[0]["distance"] == pytest.approx(0.6)
90 |
91 | def test_fetch_top_k(self) -> None:
92 | primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
93 | # Scores: 1.0, 0.9, 0.8, 0.7, 0.6
94 | # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
95 | results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore
96 | assert len(results) == 2
97 | assert results[0]["id_val"] == "p0"
98 | assert results[0]["distance"] == pytest.approx(0.5)
99 | assert results[1]["id_val"] == "p1"
100 | assert results[1]["distance"] == pytest.approx(0.45)
101 |
102 |
103 | class TestReciprocalRankFusion:
104 | def test_empty_inputs(self) -> None:
105 | results = reciprocal_rank_fusion([], [])
106 | assert results == []
107 |
108 | def test_primary_only(self) -> None:
109 | primary = [
110 | get_row("p1", 0.8),
111 | get_row("p2", 0.6),
112 | ] # p1 rank 0, p2 rank 1
113 | rrf_k = 60
114 | # p1_score = 1 / (0 + 60)
115 | # p2_score = 1 / (1 + 60)
116 | results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore
117 | assert len(results) == 2
118 | assert results[0]["id_val"] == "p1"
119 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
120 | assert results[1]["id_val"] == "p2"
121 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
122 |
123 | def test_secondary_only(self) -> None:
124 | secondary = [
125 | get_row("s1", 0.9),
126 | get_row("s2", 0.7),
127 | ] # s1 rank 0, s2 rank 1
128 | rrf_k = 60
129 | results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore
130 | assert len(results) == 2
131 | assert results[0]["id_val"] == "s1"
132 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
133 | assert results[1]["id_val"] == "s2"
134 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
135 |
136 | def test_mixed_results_default_k(self) -> None:
137 | primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
138 | secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
139 | rrf_k = 60
140 | # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
141 | # p_only_score = (1/(1+k))_prim = 1/(k+1)
142 | # s_only_score = (1/(1+k))_sec = 1/(k+1)
143 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
144 | assert len(results) == 3
145 | assert results[0]["id_val"] == "common"
146 | assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
147 |         # Check the next two elements; their order might vary due to the tie in score
148 | next_ids = {results[1]["id_val"], results[2]["id_val"]}
149 | next_scores = {results[1]["distance"], results[2]["distance"]}
150 | assert next_ids == {"p_only", "s_only"}
151 | for score in next_scores:
152 | assert score == pytest.approx(1.0 / (1 + rrf_k))
153 |
154 | def test_fetch_top_k_rrf(self) -> None:
155 | primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
156 | rrf_k = 1
157 | results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore
158 | assert len(results) == 2
159 | assert results[0]["id_val"] == "p0"
160 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
161 | assert results[1]["id_val"] == "p1"
162 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
163 |
164 | def test_rrf_content_preservation(self) -> None:
165 | primary = [get_row("doc1", 0.9, content="Primary Content")]
166 | secondary = [get_row("doc1", 0.8, content="Secondary Content")]
167 | # RRF processes primary then secondary. If a doc is in both,
168 | # the content from the secondary list will overwrite primary's.
169 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore
170 | assert len(results) == 1
171 | assert results[0]["id_val"] == "doc1"
172 | assert results[0]["content_field"] == "Secondary Content"
173 |
174 | # If only in primary
175 | results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore
176 | assert results_prim_only[0]["content_field"] == "Primary Content"
177 |
178 | def test_reordering_from_inputs_rrf(self) -> None:
179 | """
180 | Tests that RRF fused ranking can be different from both primary and secondary
181 | input rankings.
182 | Primary Order: A, B, C
183 | Secondary Order: C, B, A
184 | Fused Order: (A, C) tied, then B
185 | """
186 | primary = [
187 | get_row("docA", 0.9),
188 | get_row("docB", 0.8),
189 | get_row("docC", 0.1),
190 | ]
191 | secondary = [
192 | get_row("docC", 0.9),
193 | get_row("docB", 0.5),
194 | get_row("docA", 0.2),
195 | ]
196 | rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation
197 | # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
198 | # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
199 | # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
200 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
201 | assert len(results) == 3
202 | assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
203 | assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
204 | assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
205 | assert results[2]["id_val"] == "docB"
206 | assert results[2]["distance"] == pytest.approx(1.0)
207 |
208 | def test_reordering_from_inputs_weighted_sum(self) -> None:
209 | """
210 |         Tests that weighted-sum fusion can produce a ranking that differs from the
211 |         primary input ranking.
212 | Primary Order: A (0.9), B (0.7)
213 | Secondary Order: B (0.8), A (0.2)
214 | Fusion (0.5/0.5 weights):
215 | docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55
216 | docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75
217 | Expected Fused Order: docB (0.75), docA (0.55)
218 |         The fused order reverses the primary ranking (A, B); the secondary ranking
219 |         already favored docB, and the 0.5/0.5 weighting lets its contribution win.
220 | """
221 | primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
222 | secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]
223 |
224 | results = weighted_sum_ranking(primary, secondary) # type: ignore
225 | assert len(results) == 2
226 | assert results[0]["id_val"] == "docB"
227 | assert results[0]["distance"] == pytest.approx(0.75)
228 | assert results[1]["id_val"] == "docA"
229 | assert results[1]["distance"] == pytest.approx(0.55)
230 |
--------------------------------------------------------------------------------
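Taken together, these tests pin down the two fusion strategies: weighted_sum_ranking sums the weighted input scores per document, while reciprocal_rank_fusion scores each document as 1 / (rank + rrf_k) in every list it appears in and adds the contributions. A small, self-contained sketch of the same arithmetic; the literal dictionaries stand in for RowMapping rows exactly as get_row does above, and the expected values in the comments come from the assertions in the tests.

from langchain_postgres.v2.hybrid_search_config import (
    reciprocal_rank_fusion,
    weighted_sum_ranking,
)

primary = [
    {"id_val": "common", "content_field": "c", "distance": 0.8},
    {"id_val": "p_only", "content_field": "p", "distance": 0.7},
]
secondary = [
    {"id_val": "common", "content_field": "c", "distance": 0.9},
    {"id_val": "s_only", "content_field": "s", "distance": 0.6},
]

# Weighted sum with the default 0.5/0.5 weights:
#   common = 0.8 * 0.5 + 0.9 * 0.5 = 0.85, p_only = 0.35, s_only = 0.30
fused = weighted_sum_ranking(primary, secondary)
print([(row["id_val"], row["distance"]) for row in fused])

# Reciprocal rank fusion with rrf_k=60:
#   common = 1/60 + 1/60 = 2/60, p_only = s_only = 1/61
fused_rrf = reciprocal_rank_fusion(primary, secondary, rrf_k=60)
print([(row["id_val"], row["distance"]) for row in fused_rrf])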
/tests/unit_tests/v2/test_indexes.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import pytest
4 |
5 | from langchain_postgres.v2.indexes import (
6 | DistanceStrategy,
7 | HNSWIndex,
8 | HNSWQueryOptions,
9 | IVFFlatIndex,
10 | IVFFlatQueryOptions,
11 | )
12 |
13 |
14 | @pytest.mark.enable_socket
15 | class TestPGIndex:
16 | def test_distance_strategy(self) -> None:
17 | assert DistanceStrategy.EUCLIDEAN.operator == "<->"
18 | assert DistanceStrategy.EUCLIDEAN.search_function == "l2_distance"
19 | assert DistanceStrategy.EUCLIDEAN.index_function == "vector_l2_ops"
20 |
21 | assert DistanceStrategy.COSINE_DISTANCE.operator == "<=>"
22 | assert DistanceStrategy.COSINE_DISTANCE.search_function == "cosine_distance"
23 | assert DistanceStrategy.COSINE_DISTANCE.index_function == "vector_cosine_ops"
24 |
25 | assert DistanceStrategy.INNER_PRODUCT.operator == "<#>"
26 | assert DistanceStrategy.INNER_PRODUCT.search_function == "inner_product"
27 | assert DistanceStrategy.INNER_PRODUCT.index_function == "vector_ip_ops"
28 |
29 | def test_hnsw_index(self) -> None:
30 | index = HNSWIndex(name="test_index", m=32, ef_construction=128)
31 | assert index.index_type == "hnsw"
32 | assert index.m == 32
33 | assert index.ef_construction == 128
34 | assert index.index_options() == "(m = 32, ef_construction = 128)"
35 |
36 | def test_hnsw_query_options(self) -> None:
37 | options = HNSWQueryOptions(ef_search=80)
38 | assert options.to_parameter() == ["hnsw.ef_search = 80"]
39 |
40 | with warnings.catch_warnings(record=True) as w:
41 | options.to_string()
42 |
43 | assert len(w) == 1
44 | assert "to_string is deprecated, use to_parameter instead." in str(
45 | w[-1].message
46 | )
47 |
48 | def test_ivfflat_index(self) -> None:
49 | index = IVFFlatIndex(name="test_index", lists=200)
50 | assert index.index_type == "ivfflat"
51 | assert index.lists == 200
52 | assert index.index_options() == "(lists = 200)"
53 |
54 | def test_ivfflat_query_options(self) -> None:
55 | options = IVFFlatQueryOptions(probes=2)
56 | assert options.to_parameter() == ["ivfflat.probes = 2"]
57 |
58 | with warnings.catch_warnings(record=True) as w:
59 | options.to_string()
60 | assert len(w) == 1
61 | assert "to_string is deprecated, use to_parameter instead." in str(
62 | w[-1].message
63 | )
64 |
--------------------------------------------------------------------------------
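The assertions above double as a quick reference for how the index helpers expand into pgvector build options, operator classes, and query-time session parameters. A short illustration using the same classes; the names and values are arbitrary, and the expected output in the comments is taken directly from the assertions above.

from langchain_postgres.v2.indexes import (
    DistanceStrategy,
    HNSWIndex,
    HNSWQueryOptions,
    IVFFlatIndex,
    IVFFlatQueryOptions,
)

# Build-time options become the parenthesized options string for the index DDL.
hnsw = HNSWIndex(name="demo_hnsw", m=32, ef_construction=128)
print(hnsw.index_options())  # (m = 32, ef_construction = 128)

ivf = IVFFlatIndex(name="demo_ivf", lists=200)
print(ivf.index_options())  # (lists = 200)

# Each distance strategy carries its pgvector operator, search function,
# and operator class for index creation.
print(DistanceStrategy.COSINE_DISTANCE.operator)        # <=>
print(DistanceStrategy.COSINE_DISTANCE.index_function)  # vector_cosine_ops

# Query-time options are emitted as session parameters for the search.
print(HNSWQueryOptions(ef_search=80).to_parameter())   # ['hnsw.ef_search = 80']
print(IVFFlatQueryOptions(probes=2).to_parameter())    # ['ivfflat.probes = 2']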
/tests/unit_tests/v2/test_pg_vectorstore_from_methods.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import AsyncIterator, Sequence
4 |
5 | import pytest
6 | import pytest_asyncio
7 | from langchain_core.documents import Document
8 | from langchain_core.embeddings import DeterministicFakeEmbedding
9 | from sqlalchemy import text
10 | from sqlalchemy.engine.row import RowMapping
11 |
12 | from langchain_postgres import Column, PGEngine, PGVectorStore
13 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
14 |
15 | DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_")
16 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_")
17 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_")
18 | CUSTOM_TABLE_WITH_INT_ID = "custom_int_id" + str(uuid.uuid4()).replace("-", "_")
19 | CUSTOM_TABLE_WITH_INT_ID_SYNC = "custom_int_id_sync" + str(uuid.uuid4()).replace(
20 | "-", "_"
21 | )
22 | VECTOR_SIZE = 768
23 |
24 |
25 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
26 |
27 | texts = ["foo", "bar", "baz"]
28 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
29 | docs = [
30 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
31 | ]
32 |
33 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))]
34 |
35 |
36 | def get_env_var(key: str, desc: str) -> str:
37 | v = os.environ.get(key)
38 | if v is None:
39 | raise ValueError(f"Must set env var {key} to: {desc}")
40 | return v
41 |
42 |
43 | async def aexecute(
44 | engine: PGEngine,
45 | query: str,
46 | ) -> None:
47 | async def run(engine: PGEngine, query: str) -> None:
48 | async with engine._pool.connect() as conn:
49 | await conn.execute(text(query))
50 | await conn.commit()
51 |
52 | await engine._run_as_async(run(engine, query))
53 |
54 |
55 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]:
56 | async def run(engine: PGEngine, query: str) -> Sequence[RowMapping]:
57 | async with engine._pool.connect() as conn:
58 | result = await conn.execute(text(query))
59 | result_map = result.mappings()
60 | result_fetch = result_map.fetchall()
61 | return result_fetch
62 |
63 | return await engine._run_as_async(run(engine, query))
64 |
65 |
66 | @pytest.mark.enable_socket
67 | @pytest.mark.asyncio
68 | class TestVectorStoreFromMethods:
69 | @pytest_asyncio.fixture
70 | async def engine(self) -> AsyncIterator[PGEngine]:
71 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
72 | await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
73 | await engine.ainit_vectorstore_table(
74 | CUSTOM_TABLE,
75 | VECTOR_SIZE,
76 | id_column="myid",
77 | content_column="mycontent",
78 | embedding_column="myembedding",
79 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
80 | store_metadata=False,
81 | )
82 | await engine.ainit_vectorstore_table(
83 | CUSTOM_TABLE_WITH_INT_ID,
84 | VECTOR_SIZE,
85 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False),
86 | content_column="mycontent",
87 | embedding_column="myembedding",
88 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
89 | store_metadata=False,
90 | )
91 | yield engine
92 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
93 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}")
94 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}")
95 | await engine.close()
96 |
97 | @pytest_asyncio.fixture
98 | async def engine_sync(self) -> AsyncIterator[PGEngine]:
99 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
100 | engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE)
101 | engine.init_vectorstore_table(
102 | CUSTOM_TABLE_WITH_INT_ID_SYNC,
103 | VECTOR_SIZE,
104 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False),
105 | content_column="mycontent",
106 | embedding_column="myembedding",
107 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
108 | store_metadata=False,
109 | )
110 |
111 | yield engine
112 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}")
113 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID_SYNC}")
114 | await engine.close()
115 |
116 | async def test_afrom_texts(self, engine: PGEngine) -> None:
117 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
118 | await PGVectorStore.afrom_texts(
119 | texts,
120 | embeddings_service,
121 | engine,
122 | DEFAULT_TABLE,
123 | metadatas=metadatas,
124 | ids=ids,
125 | )
126 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}")
127 | assert len(results) == 3
128 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}")
129 |
130 | async def test_from_texts(self, engine_sync: PGEngine) -> None:
131 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
132 | PGVectorStore.from_texts(
133 | texts,
134 | embeddings_service,
135 | engine_sync,
136 | DEFAULT_TABLE_SYNC,
137 | metadatas=metadatas,
138 | ids=ids,
139 | )
140 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
141 | assert len(results) == 3
142 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}")
143 |
144 | async def test_afrom_docs(self, engine: PGEngine) -> None:
145 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
146 | await PGVectorStore.afrom_documents(
147 | docs,
148 | embeddings_service,
149 | engine,
150 | DEFAULT_TABLE,
151 | ids=ids,
152 | )
153 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}")
154 | assert len(results) == 3
155 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}")
156 |
157 | async def test_from_docs(self, engine_sync: PGEngine) -> None:
158 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
159 | PGVectorStore.from_documents(
160 | docs,
161 | embeddings_service,
162 | engine_sync,
163 | DEFAULT_TABLE_SYNC,
164 | ids=ids,
165 | )
166 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
167 | assert len(results) == 3
168 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}")
169 |
170 | async def test_afrom_docs_cross_env(self, engine_sync: PGEngine) -> None:
171 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
172 | await PGVectorStore.afrom_documents(
173 | docs,
174 | embeddings_service,
175 | engine_sync,
176 | DEFAULT_TABLE_SYNC,
177 | ids=ids,
178 | )
179 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
180 | assert len(results) == 3
181 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}")
182 |
183 | async def test_from_docs_cross_env(
184 | self, engine: PGEngine, engine_sync: PGEngine
185 | ) -> None:
186 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
187 | PGVectorStore.from_documents(
188 | docs,
189 | embeddings_service,
190 | engine,
191 | DEFAULT_TABLE_SYNC,
192 | ids=ids,
193 | )
194 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE_SYNC}")
195 | assert len(results) == 3
196 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}")
197 |
198 | async def test_afrom_texts_custom(self, engine: PGEngine) -> None:
199 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
200 | await PGVectorStore.afrom_texts(
201 | texts,
202 | embeddings_service,
203 | engine,
204 | CUSTOM_TABLE,
205 | ids=ids,
206 | id_column="myid",
207 | content_column="mycontent",
208 | embedding_column="myembedding",
209 | metadata_columns=["page", "source"],
210 | )
211 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
212 | assert len(results) == 3
213 | assert results[0]["mycontent"] == "foo"
214 | assert results[0]["myembedding"]
215 | assert results[0]["page"] is None
216 | assert results[0]["source"] is None
217 |
218 | async def test_afrom_docs_custom(self, engine: PGEngine) -> None:
219 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
220 | docs = [
221 | Document(
222 | page_content=texts[i],
223 | metadata={"page": str(i), "source": "postgres"},
224 | )
225 | for i in range(len(texts))
226 | ]
227 | await PGVectorStore.afrom_documents(
228 | docs,
229 | embeddings_service,
230 | engine,
231 | CUSTOM_TABLE,
232 | ids=ids,
233 | id_column="myid",
234 | content_column="mycontent",
235 | embedding_column="myembedding",
236 | metadata_columns=["page", "source"],
237 | )
238 |
239 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}")
240 | assert len(results) == 3
241 | assert results[0]["mycontent"] == "foo"
242 | assert results[0]["myembedding"]
243 | assert results[0]["page"] == "0"
244 | assert results[0]["source"] == "postgres"
245 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}")
246 |
247 | async def test_afrom_texts_custom_with_int_id(self, engine: PGEngine) -> None:
248 | ids = [i for i in range(len(texts))]
249 | await PGVectorStore.afrom_texts(
250 | texts,
251 | embeddings_service,
252 | engine,
253 | CUSTOM_TABLE_WITH_INT_ID,
254 | metadatas=metadatas,
255 | ids=ids,
256 | id_column="integer_id",
257 | content_column="mycontent",
258 | embedding_column="myembedding",
259 | metadata_columns=["page", "source"],
260 | )
261 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}")
262 | assert len(results) == 3
263 | for row in results:
264 | assert isinstance(row["integer_id"], int)
265 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}")
266 |
267 | async def test_from_texts_custom_with_int_id(self, engine_sync: PGEngine) -> None:
268 | ids = [i for i in range(len(texts))]
269 | PGVectorStore.from_texts(
270 | texts,
271 | embeddings_service,
272 | engine_sync,
273 | CUSTOM_TABLE_WITH_INT_ID_SYNC,
274 | ids=ids,
275 | id_column="integer_id",
276 | content_column="mycontent",
277 | embedding_column="myembedding",
278 | metadata_columns=["page", "source"],
279 | )
280 | results = await afetch(
281 | engine_sync, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID_SYNC}"
282 | )
283 | assert len(results) == 3
284 | for row in results:
285 | assert isinstance(row["integer_id"], int)
286 | await aexecute(engine_sync, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID_SYNC}")
287 |
--------------------------------------------------------------------------------
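Every test above follows the same pattern: initialize the table through the engine, then load data with one of the from_texts / from_documents class methods (sync or async). A minimal sketch of the async path, assuming a reachable Postgres with pgvector; the DSN is a placeholder, the from_* constructor is assumed to return the populated store as the standard LangChain VectorStore interface specifies, and the final query uses that interface's asimilarity_search method.

import asyncio
import uuid

from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding

from langchain_postgres import PGEngine, PGVectorStore


async def load_and_query(connection_string: str) -> None:
    engine = PGEngine.from_connection_string(url=connection_string)
    table = "demo_" + str(uuid.uuid4()).replace("-", "_")
    await engine.ainit_vectorstore_table(table, 768)

    docs = [
        Document(page_content=text, metadata={"page": str(i), "source": "postgres"})
        for i, text in enumerate(["foo", "bar", "baz"])
    ]
    store = await PGVectorStore.afrom_documents(
        docs,
        DeterministicFakeEmbedding(size=768),
        engine,
        table,
        ids=[str(uuid.uuid4()) for _ in docs],
    )
    results = await store.asimilarity_search("foo", k=1)
    print(results[0].page_content)
    await engine.close()


if __name__ == "__main__":  # placeholder DSN; point this at a real database first
    asyncio.run(
        load_and_query(
            "postgresql+psycopg://langchain:langchain@localhost:5432/langchain_test"
        )
    )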
/tests/unit_tests/v2/test_pg_vectorstore_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import AsyncIterator
4 |
5 | import pytest
6 | import pytest_asyncio
7 | from langchain_core.documents import Document
8 | from langchain_core.embeddings import DeterministicFakeEmbedding
9 | from sqlalchemy import text
10 |
11 | from langchain_postgres import PGEngine, PGVectorStore
12 | from langchain_postgres.v2.indexes import (
13 | DistanceStrategy,
14 | HNSWIndex,
15 | IVFFlatIndex,
16 | )
17 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
18 |
19 | uuid_str = str(uuid.uuid4()).replace("-", "_")
20 | uuid_str_sync = str(uuid.uuid4()).replace("-", "_")
21 | DEFAULT_TABLE = "default" + uuid_str
22 | DEFAULT_TABLE_ASYNC = "default_async" + uuid_str_sync
23 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_")
24 | DEFAULT_INDEX_NAME = "index" + uuid_str
25 | DEFAULT_INDEX_NAME_ASYNC = "index" + uuid_str_sync
26 | VECTOR_SIZE = 768
27 |
28 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
29 |
30 | texts = ["foo", "bar", "baz"]
31 | ids = [str(uuid.uuid4()) for i in range(len(texts))]
32 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))]
33 | docs = [
34 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts))
35 | ]
36 |
37 | embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))]
38 |
39 |
40 | def get_env_var(key: str, desc: str) -> str:
41 | v = os.environ.get(key)
42 | if v is None:
43 | raise ValueError(f"Must set env var {key} to: {desc}")
44 | return v
45 |
46 |
47 | async def aexecute(
48 | engine: PGEngine,
49 | query: str,
50 | ) -> None:
51 | async def run(engine: PGEngine, query: str) -> None:
52 | async with engine._pool.connect() as conn:
53 | await conn.execute(text(query))
54 | await conn.commit()
55 |
56 | await engine._run_as_async(run(engine, query))
57 |
58 |
59 | @pytest.mark.enable_socket
60 | @pytest.mark.asyncio(scope="class")
61 | class TestIndex:
62 | @pytest_asyncio.fixture(scope="class")
63 | async def engine(self) -> AsyncIterator[PGEngine]:
64 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
65 | yield engine
66 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}")
67 | await engine.close()
68 |
69 | @pytest_asyncio.fixture(scope="class")
70 | async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]:
71 | engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
72 | vs = PGVectorStore.create_sync(
73 | engine,
74 | embedding_service=embeddings_service,
75 | table_name=DEFAULT_TABLE,
76 | )
77 |
78 | vs.add_texts(texts, ids=ids)
79 | vs.drop_vector_index(DEFAULT_INDEX_NAME)
80 | yield vs
81 |
82 |     async def test_apply_vector_index(self, vs: PGVectorStore) -> None:
83 | index = HNSWIndex(name=DEFAULT_INDEX_NAME)
84 | vs.apply_vector_index(index)
85 | assert vs.is_valid_index(DEFAULT_INDEX_NAME)
86 | vs.drop_vector_index(DEFAULT_INDEX_NAME)
87 |
88 |     async def test_reindex(self, vs: PGVectorStore) -> None:
89 | if not vs.is_valid_index(DEFAULT_INDEX_NAME):
90 | index = HNSWIndex(name=DEFAULT_INDEX_NAME)
91 | vs.apply_vector_index(index)
92 | vs.reindex(DEFAULT_INDEX_NAME)
93 | vs.reindex(DEFAULT_INDEX_NAME)
94 | assert vs.is_valid_index(DEFAULT_INDEX_NAME)
95 | vs.drop_vector_index(DEFAULT_INDEX_NAME)
96 |
97 | async def test_dropindex(self, vs: PGVectorStore) -> None:
98 | vs.drop_vector_index(DEFAULT_INDEX_NAME)
99 | result = vs.is_valid_index(DEFAULT_INDEX_NAME)
100 | assert not result
101 |
102 |     async def test_apply_vector_index_ivfflat(self, vs: PGVectorStore) -> None:
103 | index = IVFFlatIndex(
104 | name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN
105 | )
106 | vs.apply_vector_index(index, concurrently=True)
107 | assert vs.is_valid_index(DEFAULT_INDEX_NAME)
108 | index = IVFFlatIndex(
109 | name="secondindex",
110 | distance_strategy=DistanceStrategy.INNER_PRODUCT,
111 | )
112 | vs.apply_vector_index(index)
113 | assert vs.is_valid_index("secondindex")
114 | vs.drop_vector_index("secondindex")
115 | vs.drop_vector_index(DEFAULT_INDEX_NAME)
116 |
117 | async def test_is_valid_index(self, vs: PGVectorStore) -> None:
118 | is_valid = vs.is_valid_index("invalid_index")
119 | assert not is_valid
120 |
121 |
122 | @pytest.mark.enable_socket
123 | @pytest.mark.asyncio(scope="class")
124 | class TestAsyncIndex:
125 | @pytest_asyncio.fixture(scope="class")
126 | async def engine(self) -> AsyncIterator[PGEngine]:
127 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
128 | yield engine
129 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}")
130 | await engine.close()
131 |
132 | @pytest_asyncio.fixture(scope="class")
133 | async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]:
134 | await engine.ainit_vectorstore_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE)
135 | vs = await PGVectorStore.create(
136 | engine,
137 | embedding_service=embeddings_service,
138 | table_name=DEFAULT_TABLE_ASYNC,
139 | )
140 |
141 | await vs.aadd_texts(texts, ids=ids)
142 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC)
143 | yield vs
144 |
145 | async def test_aapply_vector_index(self, vs: PGVectorStore) -> None:
146 | index = HNSWIndex(name=DEFAULT_INDEX_NAME_ASYNC)
147 | await vs.aapply_vector_index(index)
148 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC)
149 |
150 | async def test_areindex(self, vs: PGVectorStore) -> None:
151 | if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC):
152 | index = HNSWIndex(name=DEFAULT_INDEX_NAME_ASYNC)
153 | await vs.aapply_vector_index(index)
154 | await vs.areindex(DEFAULT_INDEX_NAME_ASYNC)
155 | await vs.areindex(DEFAULT_INDEX_NAME_ASYNC)
156 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC)
157 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC)
158 |
159 | async def test_dropindex(self, vs: PGVectorStore) -> None:
160 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC)
161 | result = await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC)
162 | assert not result
163 |
164 | async def test_aapply_vector_index_ivfflat(self, vs: PGVectorStore) -> None:
165 | index = IVFFlatIndex(
166 | name=DEFAULT_INDEX_NAME_ASYNC, distance_strategy=DistanceStrategy.EUCLIDEAN
167 | )
168 | await vs.aapply_vector_index(index, concurrently=True)
169 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC)
170 | index = IVFFlatIndex(
171 | name="secondindex",
172 | distance_strategy=DistanceStrategy.INNER_PRODUCT,
173 | )
174 | await vs.aapply_vector_index(index)
175 | assert await vs.ais_valid_index("secondindex")
176 | await vs.adrop_vector_index("secondindex")
177 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC)
178 |
179 | async def test_is_valid_index(self, vs: PGVectorStore) -> None:
180 | is_valid = await vs.ais_valid_index("invalid_index")
181 | assert not is_valid
182 |
--------------------------------------------------------------------------------
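The two classes above exercise the same index surface through both facades: the sync store created with PGVectorStore.create_sync exposes apply_vector_index / reindex / drop_vector_index / is_valid_index, and the async store mirrors each with an a-prefixed coroutine (aapply_vector_index, areindex, adrop_vector_index, ais_valid_index). A condensed sketch of the sync path, assuming a reachable Postgres with pgvector and a placeholder DSN:

import uuid

from langchain_core.embeddings import DeterministicFakeEmbedding

from langchain_postgres import PGEngine, PGVectorStore
from langchain_postgres.v2.indexes import HNSWIndex


def manage_index(connection_string: str) -> None:
    engine = PGEngine.from_connection_string(url=connection_string)
    table = "index_demo_" + str(uuid.uuid4()).replace("-", "_")
    engine.init_vectorstore_table(table, 768)

    store = PGVectorStore.create_sync(
        engine,
        embedding_service=DeterministicFakeEmbedding(size=768),
        table_name=table,
    )
    store.add_texts(["foo", "bar", "baz"])

    index_name = "hnsw_" + table
    # concurrently=True asks Postgres to build the index without blocking writes.
    store.apply_vector_index(HNSWIndex(name=index_name), concurrently=True)
    assert store.is_valid_index(index_name)
    store.reindex(index_name)
    store.drop_vector_index(index_name)


# manage_index("postgresql+psycopg://langchain:langchain@localhost:5432/langchain_test")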
/tests/unit_tests/v2/test_pg_vectorstore_standard_suite.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import AsyncGenerator, AsyncIterator
4 |
5 | import pytest
6 | import pytest_asyncio
7 | from langchain_tests.integration_tests import VectorStoreIntegrationTests
8 | from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE
9 | from sqlalchemy import text
10 |
11 | from langchain_postgres import Column, PGEngine, PGVectorStore
12 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
13 |
14 | DEFAULT_TABLE = "standard" + str(uuid.uuid4())
15 | DEFAULT_TABLE_SYNC = "sync_standard" + str(uuid.uuid4())
16 |
17 |
18 | def get_env_var(key: str, desc: str) -> str:
19 | v = os.environ.get(key)
20 | if v is None:
21 | raise ValueError(f"Must set env var {key} to: {desc}")
22 | return v
23 |
24 |
25 | async def aexecute(
26 | engine: PGEngine,
27 | query: str,
28 | ) -> None:
29 | async def run(engine: PGEngine, query: str) -> None:
30 | async with engine._pool.connect() as conn:
31 | await conn.execute(text(query))
32 | await conn.commit()
33 |
34 | await engine._run_as_async(run(engine, query))
35 |
36 |
37 | @pytest.mark.enable_socket
38 | # @pytest.mark.filterwarnings("ignore")
39 | @pytest.mark.asyncio
40 | class TestStandardSuiteSync(VectorStoreIntegrationTests):
41 | @pytest_asyncio.fixture(scope="function")
42 | async def sync_engine(self) -> AsyncGenerator[PGEngine, None]:
43 | sync_engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
44 | yield sync_engine
45 | await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')
46 | await sync_engine.close()
47 |
48 | @pytest.fixture(scope="function")
49 | def vectorstore(self, sync_engine: PGEngine) -> PGVectorStore: # type: ignore
50 | """Get an empty vectorstore for unit tests."""
51 | sync_engine.init_vectorstore_table(
52 | DEFAULT_TABLE_SYNC,
53 | EMBEDDING_SIZE,
54 | id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False),
55 | )
56 |
57 | vs = PGVectorStore.create_sync(
58 | sync_engine,
59 | embedding_service=self.get_embeddings(),
60 | table_name=DEFAULT_TABLE_SYNC,
61 | )
62 | yield vs
63 |
64 |
65 | @pytest.mark.enable_socket
66 | @pytest.mark.asyncio
67 | class TestStandardSuiteAsync(VectorStoreIntegrationTests):
68 | @pytest_asyncio.fixture(scope="function")
69 | async def async_engine(self) -> AsyncIterator[PGEngine]:
70 | async_engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
71 | yield async_engine
72 | await aexecute(async_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
73 | await async_engine.close()
74 |
75 | @pytest_asyncio.fixture(scope="function")
76 | async def vectorstore( # type: ignore[override]
77 | self, async_engine: PGEngine
78 | ) -> AsyncGenerator[PGVectorStore, None]: # type: ignore
79 | """Get an empty vectorstore for unit tests."""
80 | await async_engine.ainit_vectorstore_table(
81 | DEFAULT_TABLE,
82 | EMBEDDING_SIZE,
83 | id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False),
84 | )
85 |
86 | vs = await PGVectorStore.create(
87 | async_engine,
88 | embedding_service=self.get_embeddings(),
89 | table_name=DEFAULT_TABLE,
90 | )
91 |
92 | yield vs
93 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """Get fixtures for the database connection."""
2 |
3 | import os
4 | from contextlib import asynccontextmanager, contextmanager
5 |
6 | import psycopg
7 | from typing_extensions import AsyncGenerator, Generator
8 |
9 | # Env variables match the default settings in the docker-compose file
10 | # located in the root of the repository: [root]/docker-compose.yml
11 | # Non-standard ports are used to avoid conflicts with other local postgres
12 | # instances.
13 | # To spin up the postgres service for testing, run the following from the
14 | # repository root (the directory containing docker-compose.yml):
15 | # docker-compose up pgvector
16 | POSTGRES_USER = os.environ.get("POSTGRES_USER", "langchain")
17 | POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost")
18 | POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "langchain")
19 | POSTGRES_DB = os.environ.get("POSTGRES_DB", "langchain_test")
20 |
21 | POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432")
22 |
23 | DSN = (
24 | f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}"
25 | f":{POSTGRES_PORT}/{POSTGRES_DB}"
26 | )
27 |
28 | # Connection string used primarily by the vectorstores tests.
29 | # It is written for SQLAlchemy (includes a driver name) and targets a
30 | # postgres instance that has the pgvector extension installed.
31 | VECTORSTORE_CONNECTION_STRING = (
32 | f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}"
33 | f":{POSTGRES_PORT}/{POSTGRES_DB}"
34 | )
35 |
36 |
37 | @asynccontextmanager
38 | async def asyncpg_client() -> AsyncGenerator[psycopg.AsyncConnection, None]:
39 | # Establish a connection to your test database
40 | conn = await psycopg.AsyncConnection.connect(conninfo=DSN)
41 | try:
42 | yield conn
43 | finally:
44 | # Cleanup: close the connection after the test is done
45 | await conn.close()
46 |
47 |
48 | @contextmanager
49 | def syncpg_client() -> Generator[psycopg.Connection, None, None]:
50 | # Establish a connection to your test database
51 | conn = psycopg.connect(conninfo=DSN)
52 | try:
53 | yield conn
54 | finally:
55 | # Cleanup: close the connection after the test is done
56 | conn.close()
57 |
--------------------------------------------------------------------------------
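A minimal usage sketch of the two connection helpers above, assuming the docker-compose pgvector service (or an equivalent database) is reachable with the env defaults defined in this module; the query is a throwaway SELECT 1 just to prove the round trip.

import asyncio

from tests.utils import asyncpg_client, syncpg_client


def check_sync() -> None:
    with syncpg_client() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
            print(cur.fetchone())


async def check_async() -> None:
    async with asyncpg_client() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT 1")
            print(await cur.fetchone())


if __name__ == "__main__":
    check_sync()
    asyncio.run(check_async())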