├── .github ├── actions │ └── uv_setup │ │ └── action.yml └── workflows │ ├── _lint.yml │ ├── _release.yml │ ├── _test.yml │ └── ci.yml ├── .gitignore ├── CONTRIBUTING.md ├── DEVELOPMENT.md ├── LICENSE ├── Makefile ├── README.md ├── docker-compose.yml ├── docs └── v2_design_overview.md ├── examples ├── migrate_pgvector_to_pgvectorstore.ipynb ├── migrate_pgvector_to_pgvectorstore.md ├── pg_vectorstore.ipynb ├── pg_vectorstore_how_to.ipynb └── vectorstore.ipynb ├── langchain_postgres ├── __init__.py ├── _utils.py ├── chat_message_histories.py ├── py.typed ├── translator.py ├── utils │ └── pgvector_migrator.py ├── v2 │ ├── __init__.py │ ├── async_vectorstore.py │ ├── engine.py │ ├── hybrid_search_config.py │ ├── indexes.py │ └── vectorstores.py └── vectorstores.py ├── pyproject.toml ├── security.md ├── tests ├── __init__.py ├── unit_tests │ ├── __init__.py │ ├── fake_embeddings.py │ ├── fixtures │ │ ├── __init__.py │ │ ├── filtering_test_cases.py │ │ └── metadata_filtering_data.py │ ├── query_constructors │ │ ├── __init__.py │ │ └── test_pgvector.py │ ├── test_imports.py │ ├── v1 │ │ ├── __init__.py │ │ ├── test_chat_histories.py │ │ ├── test_vectorstore.py │ │ └── test_vectorstore_standard_tests.py │ └── v2 │ │ ├── __init__.py │ │ ├── test_async_pg_vectorstore.py │ │ ├── test_async_pg_vectorstore_from_methods.py │ │ ├── test_async_pg_vectorstore_index.py │ │ ├── test_async_pg_vectorstore_search.py │ │ ├── test_engine.py │ │ ├── test_hybrid_search_config.py │ │ ├── test_indexes.py │ │ ├── test_pg_vectorstore.py │ │ ├── test_pg_vectorstore_from_methods.py │ │ ├── test_pg_vectorstore_index.py │ │ ├── test_pg_vectorstore_search.py │ │ └── test_pg_vectorstore_standard_suite.py └── utils.py └── uv.lock /.github/actions/uv_setup/action.yml: -------------------------------------------------------------------------------- 1 | # TODO: https://docs.astral.sh/uv/guides/integration/github/#caching 2 | 3 | name: uv-install 4 | description: Set up Python and uv 5 | 6 | inputs: 7 | python-version: 8 | description: Python version, supporting MAJOR.MINOR only 9 | required: true 10 | 11 | runs: 12 | using: composite 13 | steps: 14 | - name: Install uv and set the python version 15 | uses: astral-sh/setup-uv@v5 16 | with: 17 | version: ${{ env.UV_VERSION }} 18 | python-version: ${{ inputs.python-version }} 19 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | python-version: 11 | required: true 12 | type: string 13 | description: "Python version to use" 14 | 15 | env: 16 | WORKDIR: ${{ inputs.working-directory == '' && '.' || inputs.working-directory }} 17 | 18 | # This env var allows us to get inline annotations when ruff has complaints. 
19 | RUFF_OUTPUT_FORMAT: github 20 | UV_FROZEN: "true" 21 | 22 | jobs: 23 | build: 24 | name: "make lint #${{ inputs.python-version }}" 25 | runs-on: ubuntu-latest 26 | timeout-minutes: 20 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: Set up Python ${{ inputs.python-version }} + uv 31 | uses: "./.github/actions/uv_setup" 32 | with: 33 | python-version: ${{ inputs.python-version }} 34 | 35 | - name: Install dependencies 36 | working-directory: ${{ inputs.working-directory }} 37 | run: | 38 | uv sync --group test 39 | 40 | - name: Analysing the code with our lint 41 | working-directory: ${{ inputs.working-directory }} 42 | run: | 43 | make lint 44 | -------------------------------------------------------------------------------- /.github/workflows/_release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | run-name: Release ${{ inputs.working-directory }} by @${{ github.actor }} 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | workflow_dispatch: 11 | inputs: 12 | working-directory: 13 | description: "From which folder this pipeline executes" 14 | default: "libs/server" 15 | required: true 16 | type: choice 17 | options: 18 | - "." 19 | dangerous-nonmain-release: 20 | required: false 21 | type: boolean 22 | default: false 23 | description: "Release from a non-main branch (danger!)" 24 | 25 | env: 26 | PYTHON_VERSION: "3.11" 27 | UV_FROZEN: "true" 28 | UV_NO_SYNC: "true" 29 | 30 | jobs: 31 | build: 32 | if: github.ref == 'refs/heads/main' || inputs.dangerous-nonmain-release 33 | environment: Scheduled testing 34 | runs-on: ubuntu-latest 35 | 36 | outputs: 37 | pkg-name: ${{ steps.check-version.outputs.pkg-name }} 38 | version: ${{ steps.check-version.outputs.version }} 39 | 40 | steps: 41 | - uses: actions/checkout@v4 42 | 43 | - name: Set up Python + uv 44 | uses: "./.github/actions/uv_setup" 45 | with: 46 | python-version: ${{ env.PYTHON_VERSION }} 47 | 48 | # We want to keep this build stage *separate* from the release stage, 49 | # so that there's no sharing of permissions between them. 50 | # The release stage has trusted publishing and GitHub repo contents write access, 51 | # and we want to keep the scope of that access limited just to the release job. 52 | # Otherwise, a malicious `build` step (e.g. via a compromised dependency) 53 | # could get access to our GitHub or PyPI credentials. 54 | # 55 | # Per the trusted publishing GitHub Action: 56 | # > It is strongly advised to separate jobs for building [...] 57 | # > from the publish job. 
58 | # https://github.com/pypa/gh-action-pypi-publish#non-goals 59 | - name: Build project for distribution 60 | run: uv build 61 | working-directory: ${{ inputs.working-directory }} 62 | - name: Upload build 63 | uses: actions/upload-artifact@v4 64 | with: 65 | name: dist 66 | path: ${{ inputs.working-directory }}/dist/ 67 | 68 | - name: Check Version 69 | id: check-version 70 | shell: python 71 | working-directory: ${{ inputs.working-directory }} 72 | run: | 73 | import os 74 | import tomllib 75 | with open("pyproject.toml", "rb") as f: 76 | data = tomllib.load(f) 77 | pkg_name = data["project"]["name"] 78 | version = data["project"]["version"] 79 | with open(os.environ["GITHUB_OUTPUT"], "a") as f: 80 | f.write(f"pkg-name={pkg_name}\n") 81 | f.write(f"version={version}\n") 82 | 83 | publish: 84 | needs: 85 | - build 86 | runs-on: ubuntu-latest 87 | permissions: 88 | # This permission is used for trusted publishing: 89 | # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ 90 | # 91 | # Trusted publishing has to also be configured on PyPI for each package: 92 | # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ 93 | id-token: write 94 | 95 | defaults: 96 | run: 97 | working-directory: ${{ inputs.working-directory }} 98 | 99 | steps: 100 | - uses: actions/checkout@v4 101 | 102 | - name: Set up Python + uv 103 | uses: "./.github/actions/uv_setup" 104 | with: 105 | python-version: ${{ env.PYTHON_VERSION }} 106 | 107 | - uses: actions/download-artifact@v4 108 | with: 109 | name: dist 110 | path: ${{ inputs.working-directory }}/dist/ 111 | 112 | - name: Publish package distributions to PyPI 113 | uses: pypa/gh-action-pypi-publish@release/v1 114 | with: 115 | packages-dir: ${{ inputs.working-directory }}/dist/ 116 | verbose: true 117 | print-hash: true 118 | # Temp workaround since attestations are on by default as of gh-action-pypi-publish v1.11.0 119 | attestations: false 120 | 121 | mark-release: 122 | needs: 123 | - build 124 | - publish 125 | runs-on: ubuntu-latest 126 | permissions: 127 | # This permission is needed by `ncipollo/release-action` to 128 | # create the GitHub release. 
129 | contents: write 130 | 131 | defaults: 132 | run: 133 | working-directory: ${{ inputs.working-directory }} 134 | 135 | steps: 136 | - uses: actions/checkout@v4 137 | 138 | - name: Set up Python + uv 139 | uses: "./.github/actions/uv_setup" 140 | with: 141 | python-version: ${{ env.PYTHON_VERSION }} 142 | 143 | - uses: actions/download-artifact@v4 144 | with: 145 | name: dist 146 | path: ${{ inputs.working-directory }}/dist/ 147 | 148 | - name: Create Tag 149 | uses: ncipollo/release-action@v1 150 | with: 151 | artifacts: "dist/*" 152 | token: ${{ secrets.GITHUB_TOKEN }} 153 | generateReleaseNotes: true 154 | tag: ${{needs.build.outputs.pkg-name}}==${{ needs.build.outputs.version }} 155 | body: ${{ needs.release-notes.outputs.release-body }} 156 | commit: main 157 | makeLatest: true 158 | -------------------------------------------------------------------------------- /.github/workflows/_test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | working-directory: 7 | required: true 8 | type: string 9 | description: "From which folder this pipeline executes" 10 | 11 | env: 12 | UV_FROZEN: "true" 13 | UV_NO_SYNC: "true" 14 | 15 | jobs: 16 | build: 17 | defaults: 18 | run: 19 | working-directory: ${{ inputs.working-directory }} 20 | runs-on: ubuntu-latest 21 | services: 22 | postgres: 23 | # ensure postgres version this stays in sync with prod database 24 | # and with postgres version used in docker compose 25 | # Testing with postgres that has the pg vector extension 26 | image: ankane/pgvector 27 | env: 28 | # optional (defaults to `postgres`) 29 | POSTGRES_DB: langchain_test 30 | # required 31 | POSTGRES_PASSWORD: langchain 32 | # optional (defaults to `5432`) 33 | POSTGRES_PORT: 5432 34 | # optional (defaults to `postgres`) 35 | POSTGRES_USER: langchain 36 | ports: 37 | # maps tcp port 5432 on service container to the host 38 | - 5432:5432 39 | # set health checks to wait until postgres has started 40 | options: >- 41 | --health-cmd pg_isready 42 | --health-interval 3s 43 | --health-timeout 5s 44 | --health-retries 10 45 | strategy: 46 | matrix: 47 | python-version: 48 | # - "3.9" 49 | # - "3.10" 50 | # - "3.11" 51 | - "3.12" 52 | name: Python ${{ matrix.python-version }} 53 | steps: 54 | - uses: actions/checkout@v4 55 | - name: Install postgresql-client 56 | run: | 57 | sudo apt-get update 58 | sudo apt-get install -y postgresql-client 59 | - name: Test database connection 60 | run: | 61 | # Test psql connection 62 | psql -h localhost -p 5432 -U langchain -d langchain_test -c "SELECT 1;" 63 | 64 | if [ $? -ne 0 ]; then 65 | echo "Postgres connection failed" 66 | exit 1 67 | else 68 | echo "Postgres connection successful" 69 | fi 70 | env: 71 | # postgress password is required; alternatively, you can run: 72 | # `PGPASSWORD=postgres_password psql ...` 73 | PGPASSWORD: langchain 74 | - name: Set up Python ${{ inputs.python-version }} + uv 75 | uses: "./.github/actions/uv_setup" 76 | id: setup-python 77 | with: 78 | python-version: ${{ inputs.python-version }} 79 | - name: Install dependencies 80 | shell: bash 81 | run: uv sync --group test 82 | - name: Run unit tests 83 | shell: bash 84 | run: | 85 | make test 86 | - name: Ensure the tests did not create any additional files 87 | shell: bash 88 | run: | 89 | set -eu 90 | 91 | STATUS="$(git status)" 92 | echo "$STATUS" 93 | 94 | # grep will exit non-zero if the target message isn't found, 95 | # and `set -e` above will cause the step to fail. 
96 | echo "$STATUS" | grep 'nothing to commit, working tree clean' 97 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Run CI Tests 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | workflow_dispatch: # Allows to trigger the workflow manually in GitHub UI 9 | 10 | # If another push to the same PR or branch happens while this workflow is still running, 11 | # cancel the earlier run in favor of the next run. 12 | # 13 | # There's no point in testing an outdated version of the code. GitHub only allows 14 | # a limited number of job runners to be active at the same time, so it's better to cancel 15 | # pointless jobs early so that more useful jobs can run sooner. 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.ref }} 18 | cancel-in-progress: true 19 | 20 | env: 21 | UV_FROZEN: "true" 22 | UV_NO_SYNC: "true" 23 | UV_VERSION: "0.5.25" 24 | WORKDIR: "." 25 | 26 | jobs: 27 | lint: 28 | strategy: 29 | matrix: 30 | # Only lint on the min and max supported Python versions. 31 | # It's extremely unlikely that there's a lint issue on any version in between 32 | # that doesn't show up on the min or max versions. 33 | # 34 | # GitHub rate-limits how many jobs can be running at any one time. 35 | # Starting new jobs is also relatively slow, 36 | # so linting on fewer versions makes CI faster. 37 | python-version: 38 | - "3.12" 39 | uses: 40 | ./.github/workflows/_lint.yml 41 | with: 42 | working-directory: "." 43 | python-version: ${{ matrix.python-version }} 44 | secrets: inherit 45 | test: 46 | uses: 47 | ./.github/workflows/_test.yml 48 | with: 49 | working-directory: "." 50 | secrets: inherit 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | .DS_Store 163 | 164 | # Pycharm 165 | .idea 166 | 167 | # pyenv virtualenv 168 | .python-version 169 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to langchain-postgres 2 | 3 | This guide is intended to help you get started contributing to langchain-postgres. 4 | As an open-source project in a rapidly developing field, we are extremely open 5 | to contributions, whether it be in the form of a new feature, improved infra, or better documentation. 6 | 7 | To contribute to this project, please follow the [fork and pull request](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow. 8 | 9 | ## Reporting bugs or suggesting improvements 10 | 11 | Our [GitHub issues](https://github.com/langchain-ai/langchain-postgres/issues) page is kept up to date 12 | with bugs, improvements, and feature requests. There is a taxonomy of labels to help 13 | with sorting and discovery of issues of interest. [See this page](https://github.com/langchain-ai/langchain-postgres/labels) for an overview of 14 | the system we use to tag our issues and pull requests. 15 | 16 | If you're looking for help with your code, consider posting a question on the 17 | [GitHub Discussions board](https://github.com/langchain-ai/langchain/discussions). Please 18 | understand that we won't be able to provide individual support via email. We 19 | also believe that help is much more valuable if it's **shared publicly**, 20 | so that more people can benefit from it. 21 | 22 | - **Describing your issue:** Try to provide as many details as possible. What 23 | exactly goes wrong? _How_ is it failing? Is there an error? 24 | "XY doesn't work" usually isn't that helpful for tracking down problems. Always 25 | remember to include the code you ran and if possible, extract only the relevant 26 | parts and don't just dump your entire script. This will make it easier for us to 27 | reproduce the error. 28 | 29 | - **Sharing long blocks of code or logs:** If you need to include long code, 30 | logs or tracebacks, you can wrap them in `
<details>` and `</details>
`. This 31 | [collapses the content](https://developer.mozilla.org/en/docs/Web/HTML/Element/details) 32 | so it only becomes visible on click, making the issue easier to read and follow. 33 | 34 | ## Contributing code and documentation 35 | 36 | You can develop langchain-postgres locally and contribute to the Project! 37 | 38 | See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up and using a development environment. 39 | 40 | ## Opening a pull request 41 | 42 | Once you wrote and manually tested your change, you can start sending the patch to the main repository. 43 | 44 | - Open a new GitHub pull request with the patch against the `main` branch. 45 | - Ensure the PR title follows semantic commits conventions. 46 | - For example, `feat: add new feature`, `fix: correct issue with X`. 47 | - Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. 48 | -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Setting up a Development Environment 2 | 3 | This document details how to set up a local development environment that will 4 | allow you to contribute changes to the project. 5 | 6 | Acquire sources and create virtualenv. 7 | ```shell 8 | git clone https://github.com/langchain-ai/langchain-postgres 9 | cd langchain-postgres 10 | uv venv --python=3.13 11 | source .venv/bin/activate 12 | ``` 13 | 14 | Install package in editable mode. 15 | ```shell 16 | poetry install --with dev,test,lint 17 | ``` 18 | 19 | Start PostgreSQL/PGVector. 20 | ```shell 21 | docker run --rm -it --name pgvector-container \ 22 | -e POSTGRES_USER=langchain \ 23 | -e POSTGRES_PASSWORD=langchain \ 24 | -e POSTGRES_DB=langchain \ 25 | -p 6024:5432 pgvector/pgvector:pg16 \ 26 | postgres -c log_statement=all 27 | ``` 28 | 29 | Invoke test cases. 30 | ```shell 31 | pytest -vvv 32 | ``` 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LangChain, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all lint format test help 2 | 3 | # Default target executed when no arguments are given to make. 4 | all: help 5 | 6 | ###################### 7 | # TESTING AND COVERAGE 8 | ###################### 9 | 10 | # Define a variable for the test file path. 11 | TEST_FILE ?= tests/unit_tests/ 12 | 13 | test: 14 | uv run pytest --disable-socket --allow-unix-socket $(TEST_FILE) 15 | 16 | test_watch: 17 | uv run ptw . -- $(TEST_FILE) 18 | 19 | 20 | ###################### 21 | # LINTING AND FORMATTING 22 | ###################### 23 | 24 | # Define a variable for Python and notebook files. 25 | lint format: PYTHON_FILES=. 26 | lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=. --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') 27 | 28 | lint lint_diff: 29 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff 30 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff 31 | [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES) 32 | 33 | format format_diff: 34 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) 35 | [ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES) 36 | 37 | spell_check: 38 | uv run codespell --toml pyproject.toml 39 | 40 | spell_fix: 41 | uv run codespell --toml pyproject.toml -w 42 | 43 | ###################### 44 | # HELP 45 | ###################### 46 | 47 | help: 48 | @echo '====================' 49 | @echo '-- LINTING --' 50 | @echo 'format - run code formatters' 51 | @echo 'lint - run linters' 52 | @echo 'spell_check - run codespell on the project' 53 | @echo 'spell_fix - run codespell on the project and fix the errors' 54 | @echo '-- TESTS --' 55 | @echo 'coverage - run unit tests and generate coverage report' 56 | @echo 'test - run unit tests' 57 | @echo 'test TEST_FILE=<test_file> - run all tests in file' 58 | @echo '-- DOCUMENTATION tasks are from the top-level Makefile --' 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # langchain-postgres 2 | 3 | [![Release Notes](https://img.shields.io/github/release/langchain-ai/langchain-postgres)](https://github.com/langchain-ai/langchain-postgres/releases) 4 | [![CI](https://github.com/langchain-ai/langchain-postgres/actions/workflows/ci.yml/badge.svg)](https://github.com/langchain-ai/langchain-postgres/actions/workflows/ci.yml) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/langchainai.svg?style=social&label=Follow%20%40LangChainAI)](https://twitter.com/langchainai) 7 | [![](https://dcbadge.vercel.app/api/server/6adMQxSpJS?compact=true&style=flat)](https://discord.gg/6adMQxSpJS) 8 | [![Open Issues](https://img.shields.io/github/issues-raw/langchain-ai/langchain-postgres)](https://github.com/langchain-ai/langchain-postgres/issues) 9 | 10 | The `langchain-postgres` package provides implementations of core LangChain abstractions using `Postgres`. 11 | 12 | The package is released under the MIT license. 13 | 14 | Feel free to use the abstractions as provided, or modify/extend them as appropriate for your own application.
15 | 16 | ## Requirements 17 | 18 | The package supports the [asyncpg](https://github.com/MagicStack/asyncpg) and [psycopg3](https://www.psycopg.org/psycopg3/) drivers. 19 | 20 | ## Installation 21 | 22 | ```bash 23 | pip install -U langchain-postgres 24 | ``` 25 | 26 | ## Vectorstore 27 | 28 | > [!WARNING] 29 | > In v0.0.14+, `PGVector` is deprecated. Please migrate to `PGVectorStore` 30 | > for improved performance and manageability. 31 | > See the [migration guide](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/migrate_pgvector_to_pgvectorstore.ipynb) for details on how to migrate from `PGVector` to `PGVectorStore`. 32 | 33 | ### Documentation 34 | 35 | * [Quickstart](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/pg_vectorstore.ipynb) 36 | * [How-to](https://github.com/langchain-ai/langchain-postgres/blob/main/examples/pg_vectorstore_how_to.ipynb) 37 | 38 | ### Example 39 | 40 | ```python 41 | from langchain_core.documents import Document 42 | from langchain_core.embeddings import DeterministicFakeEmbedding 43 | from langchain_postgres import PGEngine, PGVectorStore 44 | 45 | # Replace the connection string with your own Postgres connection string 46 | CONNECTION_STRING = "postgresql+psycopg3://langchain:langchain@localhost:6024/langchain" 47 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 48 | 49 | # Replace the vector size with your own vector size 50 | VECTOR_SIZE = 768 51 | embedding = DeterministicFakeEmbedding(size=VECTOR_SIZE) 52 | 53 | TABLE_NAME = "my_doc_collection" 54 | 55 | engine.init_vectorstore_table( 56 | table_name=TABLE_NAME, 57 | vector_size=VECTOR_SIZE, 58 | ) 59 | 60 | store = PGVectorStore.create_sync( 61 | engine=engine, 62 | table_name=TABLE_NAME, 63 | embedding_service=embedding, 64 | ) 65 | 66 | docs = [ 67 | Document(page_content="Apples and oranges"), 68 | Document(page_content="Cars and airplanes"), 69 | Document(page_content="Train") 70 | ] 71 | 72 | store.add_documents(docs) 73 | 74 | query = "I'd like a fruit." 75 | docs = store.similarity_search(query) 76 | print(docs) 77 | ``` 78 | 79 | > [!TIP] 80 | > All synchronous functions have corresponding asynchronous functions 81 | 82 | ## ChatMessageHistory 83 | 84 | The chat message history abstraction helps to persist chat message history 85 | in a postgres table. 86 | 87 | PostgresChatMessageHistory is parameterized using a `table_name` and a `session_id`. 88 | 89 | The `table_name` is the name of the table in the database where 90 | the chat messages will be stored. 91 | 92 | The `session_id` is a unique identifier for the chat session. It can be assigned 93 | by the caller using `uuid.uuid4()`. 94 | 95 | ```python 96 | import uuid 97 | 98 | from langchain_core.messages import SystemMessage, AIMessage, HumanMessage 99 | from langchain_postgres import PostgresChatMessageHistory 100 | import psycopg 101 | 102 | # Establish a synchronous connection to the database 103 | # (or use psycopg.AsyncConnection for async) 104 | conn_info = ... 
# Fill in with your connection info 105 | sync_connection = psycopg.connect(conn_info) 106 | 107 | # Create the table schema (only needs to be done once) 108 | table_name = "chat_history" 109 | PostgresChatMessageHistory.create_tables(sync_connection, table_name) 110 | 111 | session_id = str(uuid.uuid4()) 112 | 113 | # Initialize the chat history manager 114 | chat_history = PostgresChatMessageHistory( 115 | table_name, 116 | session_id, 117 | sync_connection=sync_connection 118 | ) 119 | 120 | # Add messages to the chat history 121 | chat_history.add_messages([ 122 | SystemMessage(content="Meow"), 123 | AIMessage(content="woof"), 124 | HumanMessage(content="bark"), 125 | ]) 126 | 127 | print(chat_history.messages) 128 | ``` 129 | 130 | ## Google Cloud Integrations 131 | 132 | [Google Cloud](https://python.langchain.com/docs/integrations/providers/google/) provides Vector Store, Chat Message History, and Data Loader integrations for [AlloyDB](https://cloud.google.com/alloydb) and [Cloud SQL](https://cloud.google.com/sql) for PostgreSQL databases via the following PyPI packages: 133 | 134 | * [`langchain-google-alloydb-pg`](https://github.com/googleapis/langchain-google-alloydb-pg-python) 135 | 136 | * [`langchain-google-cloud-sql-pg`](https://github.com/googleapis/langchain-google-cloud-sql-pg-python) 137 | 138 | Using the Google Cloud integrations provides the following benefits: 139 | 140 | - **Enhanced Security**: Securely connect to Google Cloud databases utilizing IAM for authorization and database authentication without needing to manage SSL certificates, configure firewall rules, or enable authorized networks. 141 | - **Simplified and Secure Connections:** Connect to Google Cloud databases effortlessly using the instance name instead of complex connection strings. These integrations create a secure connection pool that can be easily shared across your application using the `engine` object.
142 | 143 | | Vector Store | Metadata filtering | Async support | Schema Flexibility | Improved metadata handling | Hybrid Search | 144 | |--------------------------|--------------------|----------------|--------------------|----------------------------|---------------| 145 | | Google AlloyDB | ✓ | ✓ | ✓ | ✓ | ✗ | 146 | | Google Cloud SQL Postgres| ✓ | ✓ | ✓ | ✓ | ✗ | 147 | 148 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | name: langchain-postgres 2 | 3 | services: 4 | pgvector: 5 | # postgres with the pgvector extension 6 | image: pgvector/pgvector:pg16 7 | environment: 8 | POSTGRES_DB: langchain_test 9 | POSTGRES_USER: langchain 10 | POSTGRES_PASSWORD: langchain 11 | ports: 12 | - "5432:5432" 13 | command: | 14 | postgres -c log_statement=all 15 | healthcheck: 16 | test: 17 | [ 18 | "CMD-SHELL", 19 | "psql postgresql://langchain:langchain@localhost/langchain --command 'SELECT 1;' || exit 1", 20 | ] 21 | interval: 5s 22 | retries: 60 23 | volumes: 24 | - postgres_data_pgvector_16:/var/lib/postgresql/data 25 | 26 | volumes: 27 | postgres_data_pgvector_16: 28 | -------------------------------------------------------------------------------- /docs/v2_design_overview.md: -------------------------------------------------------------------------------- 1 | # Design Overview for `PGVectorStore` 2 | 3 | This document outlines the design choices behind the PGVectorStore integration for LangChain, focusing on how it supports both synchronous and asynchronous usage on top of an async PostgreSQL driver. 4 | 5 | ## Motivation: Performance through Asynchronicity 6 | 7 | Database interactions are often I/O-bound, making asynchronous programming crucial for performance. 8 | 9 | - **Non-Blocking Operations:** Asynchronous code prevents the application from stalling while waiting for database responses, improving throughput and responsiveness. 10 | - **Asynchronous Foundation (`asyncio` and Drivers):** Built upon Python's `asyncio` library, the integration is designed to work with asynchronous PostgreSQL drivers to handle database operations efficiently. While compatible drivers are supported, the `asyncpg` driver is specifically recommended due to its high performance in concurrent scenarios. You can explore its benefits ([link](https://magic.io/blog/asyncpg-1m-rows-from-postgres-to-python/)) and performance benchmarks ([link](https://fernandoarteaga.dev/blog/psycopg-vs-asyncpg/)) for more details. 11 | 12 | This native async foundation ensures the core database interactions are fast and scalable. 13 | 14 | ## The Two-Class Approach: Enabling a Mixed Interface 15 | 16 | To cater to different application architectures while maintaining performance, we provide two classes: 17 | 18 | 1. **`AsyncPGVectorStore` (Core Asynchronous Implementation):** 19 | * This class contains the pure `async/await` logic for all database operations. 20 | * It's designed for **direct use within asynchronous applications**. Users working in an `asyncio` environment can `await` its methods for maximum efficiency and direct control within the event loop. 21 | * It represents the fundamental, non-blocking way of interacting with the database. 22 | 23 | 2. **`PGVectorStore` (Mixed Sync/Async API):** 24 | * This class provides both asynchronous and synchronous APIs. 25 | * When one of its methods is called, it internally invokes the corresponding `async` method from `AsyncPGVectorStore`.
26 | * It **manages the execution of this underlying asynchronous logic**, handling the necessary `asyncio` event loop interactions (e.g., starting/running the coroutine) behind the scenes. 27 | * This allows users of synchronous codebases to leverage the performance benefits of the asynchronous core without needing to rewrite their application structure. 28 | 29 | ## Benefits of this Dual Interface Design 30 | 31 | This two-class structure provides significant advantages: 32 | 33 | - **Interface Flexibility:** Developers can **choose the interface that best fits their needs**: 34 | * Use `PGVectorStore` for easy integration into existing synchronous applications. 35 | * Use `AsyncPGVectorStore` for optimal performance and integration within `asyncio`-based applications. 36 | - **Ease of Use:** `PGVectorStore` offers a familiar synchronous programming model, hiding the complexity of managing async execution from the end-user. 37 | - **Robustness:** The clear separation helps prevent common errors associated with mixing synchronous and asynchronous code incorrectly, such as blocking the event loop from synchronous calls within an async context. 38 | - **Efficiency for Async Users:** `AsyncPGVectorStore` provides a direct path for async applications, avoiding any potential overhead from the sync-to-async bridging layer present in `PGVectorStore`. 39 | -------------------------------------------------------------------------------- /examples/migrate_pgvector_to_pgvectorstore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Migrate a `PGVector` vector store to `PGVectorStore`\n", 8 | "\n", 9 | "This guide shows how to migrate from the [`PGVector`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstores.py) vector store class to the [`PGVectorStore`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstore.py) class.\n", 10 | "\n", 11 | "## Why migrate?\n", 12 | "\n", 13 | "This guide explains how to migrate your vector data from a PGVector-style database (two tables) to an PGVectoStore-style database (one table per collection) for improved performance and manageability.\n", 14 | "\n", 15 | "Migrating to the PGVectorStore interface provides the following benefits:\n", 16 | "\n", 17 | "- **Simplified management**: A single table contains data corresponding to a single collection, making it easier to query, update, and maintain.\n", 18 | "- **Improved metadata handling**: It stores metadata in columns instead of JSON, resulting in significant performance improvements.\n", 19 | "- **Schema flexibility**: The interface allows users to add tables into any database schema.\n", 20 | "- **Improved performance**: The single-table schema can lead to faster query execution, especially for large collections.\n", 21 | "- **Clear separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows.\n", 22 | "- **Secure Connections:** The PGVectorStore interface creates a secure connection pool that can be easily shared across your application using the `engine` object.\n", 23 | "\n", 24 | "## Migration process\n", 25 | "\n", 26 | "> **_NOTE:_** The langchain-core library is installed to use the Fake embeddings service. To use a different embedding service, you'll need to install the appropriate library for your chosen provider. 
Choose embeddings services from [LangChain's Embedding models](https://python.langchain.com/v0.2/docs/integrations/text_embedding/)." 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "IR54BmgvdHT_" 33 | }, 34 | "source": [ 35 | "### Library Installation\n", 36 | "Install the integration library, `langchain-postgres`." 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "colab": { 44 | "base_uri": "https://localhost:8080/", 45 | "height": 1000 46 | }, 47 | "id": "0ZITIDE160OD", 48 | "outputId": "e184bc0d-6541-4e0a-82d2-1e216db00a2d" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "%pip install --upgrade --quiet langchain-postgres langchain-core SQLAlchemy" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "f8f2830ee9ca1e01", 58 | "metadata": { 59 | "id": "f8f2830ee9ca1e01" 60 | }, 61 | "source": [ 62 | "## Data Migration" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "OMvzMWRrR6n7", 68 | "metadata": { 69 | "id": "OMvzMWRrR6n7" 70 | }, 71 | "source": [ 72 | "### Set the postgres connection url\n", 73 | "\n", 74 | "`PGVectorStore` can be used with the `asyncpg` and `psycopg3` drivers." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "irl7eMFnSPZr", 81 | "metadata": { 82 | "id": "irl7eMFnSPZr" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "# @title Set Your Values Here { display-mode: \"form\" }\n", 87 | "POSTGRES_USER = \"langchain\" # @param {type: \"string\"}\n", 88 | "POSTGRES_PASSWORD = \"langchain\" # @param {type: \"string\"}\n", 89 | "POSTGRES_HOST = \"localhost\" # @param {type: \"string\"}\n", 90 | "POSTGRES_PORT = \"6024\" # @param {type: \"string\"}\n", 91 | "POSTGRES_DB = \"langchain\" # @param {type: \"string\"}" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "QuQigs4UoFQ2", 97 | "metadata": { 98 | "id": "QuQigs4UoFQ2" 99 | }, 100 | "source": [ 101 | "### PGEngine Connection Pool\n", 102 | "\n", 103 | "One of the requirements and arguments to establish PostgreSQL as a vector store is a `PGEngine` object. The `PGEngine` configures a shared connection pool to your Postgres database. This is an industry best practice to manage number of connections and to reduce latency through cached database connections.\n", 104 | "\n", 105 | "To create a `PGEngine` using `PGEngine.from_connection_string()` you need to provide:\n", 106 | "\n", 107 | "1. `url` : Connection string using the `postgresql+asyncpg` driver.\n" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "**Note:** This tutorial demonstrates the async interface. All async methods have corresponding sync methods." 
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# See docker command above to launch a Postgres instance with pgvector enabled.\n", 124 | "CONNECTION_STRING = (\n", 125 | " f\"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}\"\n", 126 | " f\":{POSTGRES_PORT}/{POSTGRES_DB}\"\n", 127 | ")\n", 128 | "# To use psycopg3 driver, set your connection string to `postgresql+psycopg://`" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from langchain_postgres import PGEngine\n", 138 | "\n", 139 | "engine = PGEngine.from_connection_string(url=CONNECTION_STRING)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "To create a `PGEngine` using `PGEngine.from_engine()` you need to provide:\n", 147 | "\n", 148 | "1. `engine` : An object of `AsyncEngine`" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "from sqlalchemy.ext.asyncio import create_async_engine\n", 158 | "\n", 159 | "# Create an SQLAlchemy Async Engine\n", 160 | "pool = create_async_engine(\n", 161 | " CONNECTION_STRING,\n", 162 | ")\n", 163 | "\n", 164 | "engine = PGEngine.from_engine(engine=pool)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "### Get all collections\n", 172 | "\n", 173 | "This script migrates each collection to a new Vector Store table." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from langchain_postgres.utils.pgvector_migrator import alist_pgvector_collection_names\n", 183 | "\n", 184 | "all_collection_names = await alist_pgvector_collection_names(engine)\n", 185 | "print(all_collection_names)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "id": "D9Xs2qhm6X56" 192 | }, 193 | "source": [ 194 | "### Create a new table(s) to migrate existing data\n", 195 | "The `PGVectorStore` class requires a database table. The `PGEngine` engine has a helper method `ainit_vectorstore_table()` that can be used to create a table with the proper schema for you." 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "You can also specify a schema name by passing `schema_name` wherever you pass `table_name`. Eg:\n", 203 | "\n", 204 | "```python\n", 205 | "SCHEMA_NAME=\"my_schema\"\n", 206 | "\n", 207 | "await engine.ainit_vectorstore_table(\n", 208 | " table_name=TABLE_NAME,\n", 209 | " vector_size=768,\n", 210 | " schema_name=SCHEMA_NAME, # Default: \"public\"\n", 211 | ")\n", 212 | "```\n", 213 | "\n", 214 | "When creating your vectorstore table, you have the flexibility to define custom metadata and ID columns. This is particularly useful for:\n", 215 | "\n", 216 | "- **Filtering**: Metadata columns allow you to easily filter your data within the vectorstore. For example, you might store the document source, date, or author as metadata for efficient retrieval.\n", 217 | "- **Non-UUID Identifiers**: By default, the id_column uses UUIDs. 
If you need to use a different type of ID (e.g., an integer or string), you can define a custom id_column.\n", 218 | "\n", 219 | "```python\n", 220 | "metadata_columns = [\n", 221 | " Column(f\"col_0_{collection_name}\", \"VARCHAR\"),\n", 222 | " Column(f\"col_1_{collection_name}\", \"VARCHAR\"),\n", 223 | "]\n", 224 | "engine.init_vectorstore_table(\n", 225 | " table_name=\"destination_table\",\n", 226 | " vector_size=VECTOR_SIZE,\n", 227 | " metadata_columns=metadata_columns,\n", 228 | " id_column=Column(\"langchain_id\", \"VARCHAR\"),\n", 229 | ")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "id": "avlyHEMn6gzU" 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "# Vertex AI embeddings uses a vector size of 768.\n", 241 | "# Adjust this according to your embeddings service.\n", 242 | "VECTOR_SIZE = 768\n", 243 | "for collection_name in all_collection_names:\n", 244 | " engine.init_vectorstore_table(\n", 245 | " table_name=collection_name,\n", 246 | " vector_size=VECTOR_SIZE,\n", 247 | " )" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "### Create a vector store and migrate data\n", 255 | "\n", 256 | "> **_NOTE:_** The `FakeEmbeddings` embedding service is only used to initialize a vector store object, not to generate any embeddings. The embeddings are copied directly from the PGVector table.\n", 257 | "\n", 258 | "If you have any customizations on the metadata or the id columns, add them to the vector store as follows:\n", 259 | "\n", 260 | "```python\n", 261 | "from langchain_postgres import PGVectorStore\n", 262 | "from langchain_core.embeddings import FakeEmbeddings\n", 263 | "\n", 264 | "destination_vector_store = PGVectorStore.create_sync(\n", 265 | " engine,\n", 266 | " embedding_service=FakeEmbeddings(size=VECTOR_SIZE),\n", 267 | " table_name=DESTINATION_TABLE_NAME,\n", 268 | " metadata_columns=[col.name for col in metadata_columns],\n", 269 | " id_column=\"langchain_id\",\n", 270 | ")\n", 271 | "```" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": { 278 | "id": "z-AZyzAQ7bsf" 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "from langchain_core.embeddings import FakeEmbeddings\n", 283 | "\n", 284 | "from langchain_postgres import PGVectorStore\n", 285 | "from langchain_postgres.utils.pgvector_migrator import amigrate_pgvector_collection\n", 286 | "\n", 287 | "for collection_name in all_collection_names:\n", 288 | " destination_vector_store = await PGVectorStore.create(\n", 289 | " engine,\n", 290 | " embedding_service=FakeEmbeddings(size=VECTOR_SIZE),\n", 291 | " table_name=collection_name,\n", 292 | " )\n", 293 | "\n", 294 | " await amigrate_pgvector_collection(\n", 295 | " engine,\n", 296 | " # Set collection name here\n", 297 | " collection_name=collection_name,\n", 298 | " vector_store=destination_vector_store,\n", 299 | " # This deletes data from the original table upon migration. 
You can choose to turn it off.\n", 300 | " # The data will only be deleted from the original table once all of it has been successfully copied to the destination table.\n", 301 | " delete_pg_collection=True,\n", 302 | " )" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "colab": { 308 | "provenance": [], 309 | "toc_visible": true 310 | }, 311 | "kernelspec": { 312 | "display_name": "Python 3", 313 | "name": "python3" 314 | }, 315 | "language_info": { 316 | "codemirror_mode": { 317 | "name": "ipython", 318 | "version": 3 319 | }, 320 | "file_extension": ".py", 321 | "mimetype": "text/x-python", 322 | "name": "python", 323 | "nbconvert_exporter": "python", 324 | "pygments_lexer": "ipython3", 325 | "version": "3.12.3" 326 | } 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 0 330 | } 331 | -------------------------------------------------------------------------------- /examples/migrate_pgvector_to_pgvectorstore.md: -------------------------------------------------------------------------------- 1 | # Migrate a `PGVector` vector store to `PGVectorStore` 2 | 3 | This guide shows how to migrate from the [`PGVector`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstores.py) vector store class to the [`PGVectorStore`](https://github.com/langchain-ai/langchain-postgres/blob/main/langchain_postgres/vectorstore.py) class. 4 | 5 | ## Why migrate? 6 | 7 | This guide explains how to migrate your vector data from a PGVector-style database (two tables) to an PGVectoStore-style database (one table per collection) for improved performance and manageability. 8 | 9 | Migrating to the PGVectorStore interface provides the following benefits: 10 | 11 | - **Simplified management**: A single table contains data corresponding to a single collection, making it easier to query, update, and maintain. 12 | - **Improved metadata handling**: It stores metadata in columns instead of JSON, resulting in significant performance improvements. 13 | - **Schema flexibility**: The interface allows users to add tables into any database schema. 14 | - **Improved performance**: The single-table schema can lead to faster query execution, especially for large collections. 15 | - **Clear separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows. 16 | - **Secure Connections:** The PGVectorStore interface creates a secure connection pool that can be easily shared across your application using the `engine` object. 17 | 18 | ## Migration process 19 | 20 | > **_NOTE:_** The langchain-core library is installed to use the Fake embeddings service. To use a different embedding service, you'll need to install the appropriate library for your chosen provider. Choose embeddings services from [LangChain's Embedding models](https://python.langchain.com/v0.2/docs/integrations/text_embedding/). 21 | 22 | While you can use the existing PGVector database, we **strongly recommend** migrating your data to the PGVectorStore-style schema to take full advantage of the performance benefits. 23 | 24 | ### (Recommended) Data migration 25 | 26 | 1. **Create a PG engine.** 27 | 28 | ```python 29 | from langchain_postgres import PGEngine 30 | 31 | # Replace these variable values 32 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 33 | ``` 34 | 35 | > **_NOTE:_** All sync methods have corresponding async methods. 36 | 37 | 2. 
**Create a new table to migrate existing data.** 38 | 39 | ```python 40 | # Vertex AI embeddings uses a vector size of 768. 41 | # Adjust this according to your embeddings service. 42 | VECTOR_SIZE = 768 43 | 44 | engine.init_vectorstore_table( 45 | table_name="destination_table", 46 | vector_size=VECTOR_SIZE, 47 | ) 48 | ``` 49 | 50 | **(Optional) Customize your table.** 51 | 52 | When creating your vectorstore table, you have the flexibility to define custom metadata and ID columns. This is particularly useful for: 53 | 54 | - **Filtering**: Metadata columns allow you to easily filter your data within the vectorstore. For example, you might store the document source, date, or author as metadata for efficient retrieval. 55 | - **Non-UUID Identifiers**: By default, the id_column uses UUIDs. If you need to use a different type of ID (e.g., an integer or string), you can define a custom id_column. 56 | 57 | ```python 58 | metadata_columns = [ 59 | Column(f"col_0_{collection_name}", "VARCHAR"), 60 | Column(f"col_1_{collection_name}", "VARCHAR"), 61 | ] 62 | engine.init_vectorstore_table( 63 | table_name="destination_table", 64 | vector_size=VECTOR_SIZE, 65 | metadata_columns=metadata_columns, 66 | id_column=Column("langchain_id", "VARCHAR"), 67 | ) 68 | ``` 69 | 70 | 3. **Create a vector store object to interact with the new data.** 71 | 72 | > **_NOTE:_** The `FakeEmbeddings` embedding service is only used to initialise a vector store object, not to generate any embeddings. The embeddings are copied directly from the PGVector table. 73 | 74 | ```python 75 | from langchain_postgres import PGVectorStore 76 | from langchain_core.embeddings import FakeEmbeddings 77 | 78 | destination_vector_store = PGVectorStore.create_sync( 79 | engine, 80 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE), 81 | table_name="destination_table", 82 | ) 83 | ``` 84 | 85 | If you have any customisations on the metadata or the id columns, add them to the vector store as follows: 86 | 87 | ```python 88 | from langchain_postgres import PGVectorStore 89 | from langchain_core.embeddings import FakeEmbeddings 90 | 91 | destination_vector_store = PGVectorStore.create_sync( 92 | engine, 93 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE), 94 | table_name="destination_table", 95 | metadata_columns=[col.name for col in metadata_columns], 96 | id_column="langchain_id", 97 | ) 98 | ``` 99 | 100 | 4. **Migrate the data to the new table.** 101 | 102 | ```python 103 | from langchain_postgres.utils.pgvector_migrator import amigrate_pgvector_collection 104 | 105 | migrate_pgvector_collection( 106 | engine, 107 | # Set collection name here 108 | collection_name="collection_name", 109 | vector_store=destination_vector_store, 110 | # This deletes data from the original table upon migration. You can choose to turn it off. 111 | delete_pg_collection=True, 112 | ) 113 | ``` 114 | 115 | The data will only be deleted from the original table once all of it has been successfully copied to the destination table. 116 | 117 | > **TIP:** If you would like to migrate multiple collections, you can use the `alist_pgvector_collection_names` method to get the names of all collections, allowing you to iterate through them. 
118 | > 119 | > ```python 120 | > from langchain_postgres.utils.pgvector_migrator import alist_pgvector_collection_names 121 | > 122 | > all_collection_names = list_pgvector_collection_names(engine) 123 | > print(all_collection_names) 124 | > ``` 125 | 126 | ### (Not Recommended) Use PGVectorStore interface on PGVector databases 127 | 128 | If you choose not to migrate your data, you can still use the PGVectorStore interface with your existing PGVector database. However, you won't benefit from the performance improvements of the PGVectorStore-style schema. 129 | 130 | 1. **Create an PGVectorStore engine.** 131 | 132 | ```python 133 | from langchain_postgres import PGEngine 134 | 135 | # Replace these variable values 136 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 137 | ``` 138 | 139 | > **_NOTE:_** All sync methods have corresponding async methods. 140 | 141 | 2. **Create a vector store object to interact with the data.** 142 | 143 | Use the embeddings service used by your database. See [langchain docs](https://python.langchain.com/docs/integrations/text_embedding/) for reference. 144 | 145 | ```python 146 | from langchain_postgres import PGVectorStore 147 | from langchain_core.embeddings import FakeEmbeddings 148 | 149 | vector_store = PGVectorStore.create_sync( 150 | engine=engine, 151 | table_name="langchain_pg_embedding", 152 | embedding_service=FakeEmbeddings(size=VECTOR_SIZE), 153 | content_column="document", 154 | metadata_json_column="cmetadata", 155 | metadata_columns=["collection_id"], 156 | id_column="id", 157 | ) 158 | ``` 159 | 160 | 3. **Perform similarity search.** 161 | 162 | Filter by collection id: 163 | 164 | ```python 165 | vector_store.similarity_search("query", k=5, filter=f"collection_id='{uuid}'") 166 | ``` 167 | 168 | Filter by collection id and metadata: 169 | 170 | ```python 171 | vector_store.similarity_search( 172 | "query", k=5, filter=f"collection_id='{uuid}' and cmetadata->>'col_name' = 'value'" 173 | ) 174 | ``` -------------------------------------------------------------------------------- /examples/pg_vectorstore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# PGVectorStore\n", 8 | "\n", 9 | "`PGVectorStore` is a an implementation of the the LangChain vectorstore abstraction using `postgres` as the backend.\n", 10 | "\n", 11 | "## Requirements\n", 12 | "\n", 13 | "You'll need a PostgreSQL database with the `pgvector` extension enabled.\n", 14 | "\n", 15 | "\n", 16 | "For local development, you can use the following docker command to spin up the database:\n", 17 | "\n", 18 | "```shell\n", 19 | "docker run --name pgvector-container -e POSTGRES_USER=langchain -e POSTGRES_PASSWORD=langchain -e POSTGRES_DB=langchain -p 6024:5432 -d pgvector/pgvector:pg16\n", 20 | "```" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": { 26 | "id": "IR54BmgvdHT_" 27 | }, 28 | "source": [ 29 | "## Install\n", 30 | "\n", 31 | "Install the `langchain-postgres` package." 
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "colab": { 39 | "base_uri": "https://localhost:8080/", 40 | "height": 1000 41 | }, 42 | "id": "0ZITIDE160OD", 43 | "outputId": "e184bc0d-6541-4e0a-82d2-1e216db00a2d", 44 | "tags": [] 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "%pip install --upgrade --quiet langchain-postgres" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "QuQigs4UoFQ2" 55 | }, 56 | "source": [ 57 | "## Create an engine\n", 58 | "\n", 59 | "The first step is to create a `PGEngine` instance, which does the following:\n", 60 | "\n", 61 | "1. Allows you to create tables for storing documents and embeddings.\n", 62 | "2. Maintains a connection pool that manages connections to the database. This allows sharing of the connection pool and helps to reduce latency for database calls." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "tags": [] 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "from langchain_postgres import PGEngine\n", 74 | "\n", 75 | "# See docker command above to launch a Postgres instance with pgvector enabled.\n", 76 | "# Replace these values with your own configuration.\n", 77 | "POSTGRES_USER = \"langchain\"\n", 78 | "POSTGRES_PASSWORD = \"langchain\"\n", 79 | "POSTGRES_HOST = \"localhost\"\n", 80 | "POSTGRES_PORT = \"6024\"\n", 81 | "POSTGRES_DB = \"langchain\"\n", 82 | "\n", 83 | "CONNECTION_STRING = (\n", 84 | " f\"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}\"\n", 85 | " f\":{POSTGRES_PORT}/{POSTGRES_DB}\"\n", 86 | ")\n", 87 | "\n", 88 | "pg_engine = PGEngine.from_connection_string(url=CONNECTION_STRING)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "To use psycopg3 driver, set your connection string to `postgresql+psycopg://`" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "id": "D9Xs2qhm6X56" 102 | }, 103 | "source": [ 104 | "## Create a document collection\n", 105 | "\n", 106 | "Use the `PGEngine.ainit_vectorstore_table()` method to create a database table to store the documents and embeddings. This table will be created with appropriate schema." 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "tags": [] 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "TABLE_NAME = \"vectorstore\"\n", 118 | "\n", 119 | "# The vector size (also called embedding size) is determined by the embedding model you use!\n", 120 | "VECTOR_SIZE = 1536" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Use the `Column` class to customize the table schema. A Column is defined by a name and data type. Any Postgres [data type](https://www.postgresql.org/docs/current/datatype.html) can be used." 
128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": { 134 | "id": "avlyHEMn6gzU", 135 | "tags": [] 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "from sqlalchemy.exc import ProgrammingError\n", 140 | "\n", 141 | "from langchain_postgres import Column\n", 142 | "\n", 143 | "try:\n", 144 | " await pg_engine.ainit_vectorstore_table(\n", 145 | " table_name=TABLE_NAME,\n", 146 | " vector_size=VECTOR_SIZE,\n", 147 | " metadata_columns=[\n", 148 | " Column(\"likes\", \"INTEGER\"),\n", 149 | " Column(\"location\", \"TEXT\"),\n", 150 | " Column(\"topic\", \"TEXT\"),\n", 151 | " ],\n", 152 | " )\n", 153 | "except ProgrammingError:\n", 154 | " # Catching the exception here\n", 155 | " print(\"Table already exists. Skipping creation.\")" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Configure an embeddings model\n", 163 | "\n", 164 | "You need to configure a vectorstore with an embedding model. The embedding model will be used automatically when adding documents and when searching.\n", 165 | "\n", 166 | "We'll use `langchain-openai` as the embedding more here, but you can use any [LangChain embeddings model](https://python.langchain.com/docs/integrations/text_embedding/)." 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "tags": [] 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "%pip install --upgrade --quiet langchain-openai" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "colab": { 185 | "base_uri": "https://localhost:8080/" 186 | }, 187 | "id": "Vb2RJocV9_LQ", 188 | "outputId": "37f5dc74-2512-47b2-c135-f34c10afdcf4", 189 | "tags": [] 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "from langchain_openai import OpenAIEmbeddings\n", 194 | "\n", 195 | "embedding = OpenAIEmbeddings(model=\"text-embedding-3-small\")" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "e1tl0aNx7SWy" 202 | }, 203 | "source": [ 204 | "## Initialize the vectorstore\n", 205 | "\n", 206 | "Once the schema for the document collection exists, you can initialize a vectorstore that uses the schema.\n", 207 | "\n", 208 | "You can use the vectorstore to do basic operations; including:\n", 209 | "\n", 210 | "1. Add documents\n", 211 | "2. Delete documents\n", 212 | "3. Search through the documents" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": { 219 | "id": "z-AZyzAQ7bsf", 220 | "tags": [] 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "from langchain_postgres import PGVectorStore\n", 225 | "\n", 226 | "vectorstore = await PGVectorStore.create(\n", 227 | " engine=pg_engine,\n", 228 | " table_name=TABLE_NAME,\n", 229 | " embedding_service=embedding,\n", 230 | " metadata_columns=[\"location\", \"topic\"],\n", 231 | ")" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "## Add documents\n", 239 | "\n", 240 | "\n", 241 | "You can add documents using the `aadd_documents` method. \n", 242 | "\n", 243 | "* Assign unique IDs to documents to avoid duplicated content in your database.\n", 244 | "* Adding a document by ID implements has `upsert` semantics (i.e., create if does not exist, update if exists)." 
245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": { 251 | "tags": [] 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "import uuid\n", 256 | "\n", 257 | "from langchain_core.documents import Document\n", 258 | "\n", 259 | "docs = [\n", 260 | " Document(\n", 261 | " id=uuid.uuid4(),\n", 262 | " page_content=\"there are cats in the pond\",\n", 263 | " metadata={\"likes\": 1, \"location\": \"pond\", \"topic\": \"animals\"},\n", 264 | " ),\n", 265 | " Document(\n", 266 | " id=uuid.uuid4(),\n", 267 | " page_content=\"ducks are also found in the pond\",\n", 268 | " metadata={\"likes\": 30, \"location\": \"pond\", \"topic\": \"animals\"},\n", 269 | " ),\n", 270 | " Document(\n", 271 | " id=uuid.uuid4(),\n", 272 | " page_content=\"fresh apples are available at the market\",\n", 273 | " metadata={\"likes\": 20, \"location\": \"market\", \"topic\": \"food\"},\n", 274 | " ),\n", 275 | " Document(\n", 276 | " id=uuid.uuid4(),\n", 277 | " page_content=\"the market also sells fresh oranges\",\n", 278 | " metadata={\"likes\": 5, \"location\": \"market\", \"topic\": \"food\"},\n", 279 | " ),\n", 280 | "]\n", 281 | "\n", 282 | "\n", 283 | "await vectorstore.aadd_documents(documents=docs)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "## Delete Documents\n", 291 | "\n", 292 | "Documents can be deleted by ID." 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": { 299 | "tags": [] 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "# We'll use the ID of the first doc to delete it\n", 304 | "ids = [docs[0].id]\n", 305 | "await vectorstore.adelete(ids)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "## Search\n", 313 | "\n", 314 | "Search for similar documents using a natural language query." 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": { 321 | "tags": [] 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "query = \"I'd like a fruit.\"\n", 326 | "docs = await vectorstore.asimilarity_search(query)\n", 327 | "for doc in docs:\n", 328 | " print(repr(doc))" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "### Search by vector" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": { 342 | "tags": [] 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "query_vector = embedding.embed_query(query)\n", 347 | "docs = await vectorstore.asimilarity_search_by_vector(query_vector, k=2)\n", 348 | "print(docs)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "## Filtering" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "To enable search with filters, it is necessary to declare the columns that you want to filter on when creating the table. 
The vectorstore supports a set of filters that can be applied against the metadata fields of the documents.\n", 363 | "\n", 364 | "`PGVectorStore` currently supports the following operators.\n", 365 | "\n", 366 | "| Operator | Meaning/Category |\n", 367 | "|-----------|-------------------------|\n", 368 | "| \\$eq | Equality (==) |\n", 369 | "| \\$ne | Inequality (!=) |\n", 370 | "| \\$lt | Less than (<) |\n", 371 | "| \\$lte | Less than or equal (<=) |\n", 372 | "| \\$gt | Greater than (>) |\n", 373 | "| \\$gte | Greater than or equal (>=) |\n", 374 | "| \\$in | Special Cased (in) |\n", 375 | "| \\$nin | Special Cased (not in) |\n", 376 | "| \\$between | Special Cased (between) |\n", 377 | "| \\$exists | Special Cased (is null) |\n", 378 | "| \\$like | Text (like) |\n", 379 | "| \\$ilike | Text (case-insensitive like) |\n", 380 | "| \\$and | Logical (and) |\n", 381 | "| \\$or | Logical (or) |\n" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": null, 387 | "metadata": { 388 | "tags": [] 389 | }, 390 | "outputs": [], 391 | "source": [ 392 | "await vectorstore.asimilarity_search(\n", 393 | " \"birds\", filter={\"$or\": [{\"topic\": \"animals\"}, {\"location\": \"market\"}]}\n", 394 | ")" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": { 401 | "tags": [] 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "await vectorstore.asimilarity_search(\"apple\", filter={\"topic\": \"food\"})" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "tags": [] 413 | }, 414 | "outputs": [], 415 | "source": [ 416 | "await vectorstore.asimilarity_search(\n", 417 | " \"apple\", filter={\"topic\": {\"$in\": [\"food\", \"animals\"]}}\n", 418 | ")" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": { 425 | "tags": [] 426 | }, 427 | "outputs": [], 428 | "source": [ 429 | "await vectorstore.asimilarity_search(\n", 430 | " \"sales of fruit\", filter={\"topic\": {\"$ne\": \"animals\"}}\n", 431 | ")" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "## Optimization\n", 439 | "\n", 440 | "Speed up vector search queries by adding appropriate indexes. Learn more about [vector indexes](https://cloud.google.com/blog/products/databases/faster-similarity-search-performance-with-pgvector-indexes)." 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "### Add an Index" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": { 454 | "tags": [] 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "from langchain_postgres.v2.indexes import IVFFlatIndex\n", 459 | "\n", 460 | "index = IVFFlatIndex() # Add an index using a default index name\n", 461 | "await vectorstore.aapply_vector_index(index)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "### Re-index\n", 469 | "\n", 470 | "Rebuild an index using the data stored in the index's table, replacing the old copy of the index. Some index types may require re-indexing after a considerable amount of new data is added." 
471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "metadata": { 477 | "tags": [] 478 | }, 479 | "outputs": [], 480 | "source": [ 481 | "await vectorstore.areindex() # Re-index using default index name" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "### Drop an index\n", 489 | "\n", 490 | "You can delete indexes" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": { 497 | "tags": [] 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "await vectorstore.adrop_vector_index() # Drop index using default name" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "## Clean up\n", 509 | "\n", 510 | "**⚠️ WARNING: this can not be undone**\n", 511 | "\n", 512 | "Drop the vector store table." 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": {}, 519 | "outputs": [], 520 | "source": [ 521 | "await pg_engine.adrop_table(TABLE_NAME)" 522 | ] 523 | } 524 | ], 525 | "metadata": { 526 | "colab": { 527 | "provenance": [], 528 | "toc_visible": true 529 | }, 530 | "kernelspec": { 531 | "display_name": "Python 3 (ipykernel)", 532 | "language": "python", 533 | "name": "python3" 534 | }, 535 | "language_info": { 536 | "codemirror_mode": { 537 | "name": "ipython", 538 | "version": 3 539 | }, 540 | "file_extension": ".py", 541 | "mimetype": "text/x-python", 542 | "name": "python", 543 | "nbconvert_exporter": "python", 544 | "pygments_lexer": "ipython3", 545 | "version": "3.11.4" 546 | } 547 | }, 548 | "nbformat": 4, 549 | "nbformat_minor": 4 550 | } 551 | -------------------------------------------------------------------------------- /langchain_postgres/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | from langchain_postgres.chat_message_histories import PostgresChatMessageHistory 4 | from langchain_postgres.translator import PGVectorTranslator 5 | from langchain_postgres.v2.engine import Column, ColumnDict, PGEngine 6 | from langchain_postgres.v2.vectorstores import PGVectorStore 7 | from langchain_postgres.vectorstores import PGVector 8 | 9 | try: 10 | __version__ = metadata.version(__package__) 11 | except metadata.PackageNotFoundError: 12 | # Case where package metadata is not available. 13 | __version__ = "" 14 | 15 | __all__ = [ 16 | "__version__", 17 | "Column", 18 | "ColumnDict", 19 | "PGEngine", 20 | "PostgresChatMessageHistory", 21 | "PGVector", 22 | "PGVectorStore", 23 | "PGVectorTranslator", 24 | ] 25 | -------------------------------------------------------------------------------- /langchain_postgres/_utils.py: -------------------------------------------------------------------------------- 1 | """Copied over from langchain_community. 2 | 3 | This code should be moved to langchain proper or removed entirely. 
4 | """ 5 | 6 | import logging 7 | from typing import List, Union 8 | 9 | import numpy as np 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 14 | 15 | 16 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 17 | """Row-wise cosine similarity between two equal-width matrices.""" 18 | if len(X) == 0 or len(Y) == 0: 19 | return np.array([]) 20 | 21 | X = np.array(X) 22 | Y = np.array(Y) 23 | if X.shape[1] != Y.shape[1]: 24 | raise ValueError( 25 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 26 | f"and Y has shape {Y.shape}." 27 | ) 28 | try: 29 | import simsimd as simd # type: ignore 30 | 31 | X = np.array(X, dtype=np.float32) 32 | Y = np.array(Y, dtype=np.float32) 33 | Z = 1 - np.array(simd.cdist(X, Y, metric="cosine")) 34 | return Z 35 | except ImportError: 36 | logger.debug( 37 | "Unable to import simsimd, defaulting to NumPy implementation. If you want " 38 | "to use simsimd please install with `pip install simsimd`." 39 | ) 40 | X_norm = np.linalg.norm(X, axis=1) 41 | Y_norm = np.linalg.norm(Y, axis=1) 42 | # Ignore divide by zero errors run time warnings as those are handled below. 43 | with np.errstate(divide="ignore", invalid="ignore"): 44 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 45 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 46 | return similarity 47 | 48 | 49 | def maximal_marginal_relevance( 50 | query_embedding: np.ndarray, 51 | embedding_list: list, 52 | lambda_mult: float = 0.5, 53 | k: int = 4, 54 | ) -> List[int]: 55 | """Calculate maximal marginal relevance.""" 56 | if min(k, len(embedding_list)) <= 0: 57 | return [] 58 | if query_embedding.ndim == 1: 59 | query_embedding = np.expand_dims(query_embedding, axis=0) 60 | similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] 61 | most_similar = int(np.argmax(similarity_to_query)) 62 | idxs = [most_similar] 63 | selected = np.array([embedding_list[most_similar]]) 64 | while len(idxs) < min(k, len(embedding_list)): 65 | best_score = -np.inf 66 | idx_to_add = -1 67 | similarity_to_selected = cosine_similarity(embedding_list, selected) 68 | for i, query_score in enumerate(similarity_to_query): 69 | if i in idxs: 70 | continue 71 | redundant_score = max(similarity_to_selected[i]) 72 | equation_score = ( 73 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score 74 | ) 75 | if equation_score > best_score: 76 | best_score = equation_score 77 | idx_to_add = i 78 | idxs.append(idx_to_add) 79 | selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) 80 | return idxs 81 | -------------------------------------------------------------------------------- /langchain_postgres/chat_message_histories.py: -------------------------------------------------------------------------------- 1 | """Client for persisting chat message history in a Postgres database. 2 | 3 | This client provides support for both sync and async via psycopg 3. 
4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | import json 9 | import logging 10 | import re 11 | import uuid 12 | from typing import List, Optional, Sequence 13 | 14 | import psycopg 15 | from langchain_core.chat_history import BaseChatMessageHistory 16 | from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict 17 | from psycopg import sql 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def _create_table_and_index(table_name: str) -> List[sql.Composed]: 23 | """Make a SQL query to create a table.""" 24 | index_name = f"idx_{table_name}_session_id" 25 | statements = [ 26 | sql.SQL( 27 | """ 28 | CREATE TABLE IF NOT EXISTS {table_name} ( 29 | id SERIAL PRIMARY KEY, 30 | session_id UUID NOT NULL, 31 | message JSONB NOT NULL, 32 | created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() 33 | ); 34 | """ 35 | ).format(table_name=sql.Identifier(table_name)), 36 | sql.SQL( 37 | """ 38 | CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id); 39 | """ 40 | ).format( 41 | table_name=sql.Identifier(table_name), index_name=sql.Identifier(index_name) 42 | ), 43 | ] 44 | return statements 45 | 46 | 47 | def _get_messages_query(table_name: str) -> sql.Composed: 48 | """Make a SQL query to get messages for a given session.""" 49 | return sql.SQL( 50 | "SELECT message " 51 | "FROM {table_name} " 52 | "WHERE session_id = %(session_id)s " 53 | "ORDER BY id;" 54 | ).format(table_name=sql.Identifier(table_name)) 55 | 56 | 57 | def _delete_by_session_id_query(table_name: str) -> sql.Composed: 58 | """Make a SQL query to delete messages for a given session.""" 59 | return sql.SQL( 60 | "DELETE FROM {table_name} WHERE session_id = %(session_id)s;" 61 | ).format(table_name=sql.Identifier(table_name)) 62 | 63 | 64 | def _delete_table_query(table_name: str) -> sql.Composed: 65 | """Make a SQL query to delete a table.""" 66 | return sql.SQL("DROP TABLE IF EXISTS {table_name};").format( 67 | table_name=sql.Identifier(table_name) 68 | ) 69 | 70 | 71 | def _insert_message_query(table_name: str) -> sql.Composed: 72 | """Make a SQL query to insert a message.""" 73 | return sql.SQL( 74 | "INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)" 75 | ).format(table_name=sql.Identifier(table_name)) 76 | 77 | 78 | class PostgresChatMessageHistory(BaseChatMessageHistory): 79 | def __init__( 80 | self, 81 | table_name: str, 82 | session_id: str, 83 | /, 84 | *, 85 | sync_connection: Optional[psycopg.Connection] = None, 86 | async_connection: Optional[psycopg.AsyncConnection] = None, 87 | ) -> None: 88 | """Client for persisting chat message history in a Postgres database, 89 | 90 | This client provides support for both sync and async via psycopg >=3. 91 | 92 | The client can create schema in the database and provides methods to 93 | add messages, get messages, and clear the chat message history. 94 | 95 | The schema has the following columns: 96 | 97 | - id: A serial primary key. 98 | - session_id: The session ID for the chat message history. 99 | - message: The JSONB message content. 100 | - created_at: The timestamp of when the message was created. 101 | 102 | Messages are retrieved for a given session_id and are sorted by 103 | the id (which should be increasing monotonically), and correspond 104 | to the order in which the messages were added to the history. 105 | 106 | The "created_at" column is not returned by the interface, but 107 | has been added for the schema so the information is available in the database. 
108 | 109 | A session_id can be used to separate different chat histories in the same table, 110 | the session_id should be provided when initializing the client. 111 | 112 | This chat history client takes in a psycopg connection object (either 113 | Connection or AsyncConnection) and uses it to interact with the database. 114 | 115 | This design allows to reuse the underlying connection object across 116 | multiple instantiations of this class, making instantiation fast. 117 | 118 | This chat history client is designed for prototyping applications that 119 | involve chat and are based on Postgres. 120 | 121 | As your application grows, you will likely need to extend the schema to 122 | handle more complex queries. For example, a chat application 123 | may involve multiple tables like a user table, a table for storing 124 | chat sessions / conversations, and this table for storing chat messages 125 | for a given session. The application will require access to additional 126 | endpoints like deleting messages by user id, listing conversations by 127 | user id or ordering them based on last message time, etc. 128 | 129 | Feel free to adapt this implementation to suit your application's needs. 130 | 131 | Args: 132 | session_id: The session ID to use for the chat message history 133 | table_name: The name of the database table to use 134 | sync_connection: An existing psycopg connection instance 135 | async_connection: An existing psycopg async connection instance 136 | 137 | Usage: 138 | - Use the create_tables or acreate_tables method to set up the table 139 | schema in the database. 140 | - Initialize the class with the appropriate session ID, table name, 141 | and database connection. 142 | - Add messages to the database using add_messages or aadd_messages. 143 | - Retrieve messages with get_messages or aget_messages. 144 | - Clear the session history with clear or aclear when needed. 145 | 146 | Note: 147 | - At least one of sync_connection or async_connection must be provided. 148 | 149 | Examples: 150 | 151 | .. code-block:: python 152 | 153 | import uuid 154 | 155 | from langchain_core.messages import SystemMessage, AIMessage, HumanMessage 156 | from langchain_postgres import PostgresChatMessageHistory 157 | import psycopg 158 | 159 | # Establish a synchronous connection to the database 160 | # (or use psycopg.AsyncConnection for async) 161 | sync_connection = psycopg2.connect(conn_info) 162 | 163 | # Create the table schema (only needs to be done once) 164 | table_name = "chat_history" 165 | PostgresChatMessageHistory.create_tables(sync_connection, table_name) 166 | 167 | session_id = str(uuid.uuid4()) 168 | 169 | # Initialize the chat history manager 170 | chat_history = PostgresChatMessageHistory( 171 | table_name, 172 | session_id, 173 | sync_connection=sync_connection 174 | ) 175 | 176 | # Add messages to the chat history 177 | chat_history.add_messages([ 178 | SystemMessage(content="Meow"), 179 | AIMessage(content="woof"), 180 | HumanMessage(content="bark"), 181 | ]) 182 | 183 | print(chat_history.messages) 184 | """ 185 | if not sync_connection and not async_connection: 186 | raise ValueError("Must provide sync_connection or async_connection") 187 | 188 | self._connection = sync_connection 189 | self._aconnection = async_connection 190 | 191 | # Validate that session id is a UUID 192 | try: 193 | uuid.UUID(session_id) 194 | except ValueError: 195 | raise ValueError( 196 | f"Invalid session id. Session id must be a valid UUID. 
Got {session_id}" 197 | ) 198 | 199 | self._session_id = session_id 200 | 201 | if not re.match(r"^\w+$", table_name): 202 | raise ValueError( 203 | "Invalid table name. Table name must contain only alphanumeric " 204 | "characters and underscores." 205 | ) 206 | self._table_name = table_name 207 | 208 | @staticmethod 209 | def create_tables( 210 | connection: psycopg.Connection, 211 | table_name: str, 212 | /, 213 | ) -> None: 214 | """Create the table schema in the database and create relevant indexes.""" 215 | queries = _create_table_and_index(table_name) 216 | logger.info("Creating schema for table %s", table_name) 217 | with connection.cursor() as cursor: 218 | for query in queries: 219 | cursor.execute(query) 220 | connection.commit() 221 | 222 | @staticmethod 223 | async def acreate_tables( 224 | connection: psycopg.AsyncConnection, table_name: str, / 225 | ) -> None: 226 | """Create the table schema in the database and create relevant indexes.""" 227 | queries = _create_table_and_index(table_name) 228 | logger.info("Creating schema for table %s", table_name) 229 | async with connection.cursor() as cur: 230 | for query in queries: 231 | await cur.execute(query) 232 | await connection.commit() 233 | 234 | @staticmethod 235 | def drop_table(connection: psycopg.Connection, table_name: str, /) -> None: 236 | """Delete the table schema in the database. 237 | 238 | WARNING: 239 | This will delete the given table from the database including 240 | all the database in the table and the schema of the table. 241 | 242 | Args: 243 | connection: The database connection. 244 | table_name: The name of the table to create. 245 | """ 246 | 247 | query = _delete_table_query(table_name) 248 | logger.info("Dropping table %s", table_name) 249 | with connection.cursor() as cursor: 250 | cursor.execute(query) 251 | connection.commit() 252 | 253 | @staticmethod 254 | async def adrop_table( 255 | connection: psycopg.AsyncConnection, table_name: str, / 256 | ) -> None: 257 | """Delete the table schema in the database. 258 | 259 | WARNING: 260 | This will delete the given table from the database including 261 | all the database in the table and the schema of the table. 262 | 263 | Args: 264 | connection: Async database connection. 265 | table_name: The name of the table to create. 266 | """ 267 | query = _delete_table_query(table_name) 268 | logger.info("Dropping table %s", table_name) 269 | 270 | async with connection.cursor() as acur: 271 | await acur.execute(query) 272 | await connection.commit() 273 | 274 | def add_messages(self, messages: Sequence[BaseMessage]) -> None: 275 | """Add messages to the chat message history.""" 276 | if self._connection is None: 277 | raise ValueError( 278 | "Please initialize the PostgresChatMessageHistory " 279 | "with a sync connection or use the aadd_messages method instead." 280 | ) 281 | 282 | values = [ 283 | (self._session_id, json.dumps(message_to_dict(message))) 284 | for message in messages 285 | ] 286 | 287 | query = _insert_message_query(self._table_name) 288 | 289 | with self._connection.cursor() as cursor: 290 | cursor.executemany(query, values) 291 | self._connection.commit() 292 | 293 | async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None: 294 | """Add messages to the chat message history.""" 295 | if self._aconnection is None: 296 | raise ValueError( 297 | "Please initialize the PostgresChatMessageHistory " 298 | "with an async connection or use the sync add_messages method instead." 
299 | ) 300 | 301 | values = [ 302 | (self._session_id, json.dumps(message_to_dict(message))) 303 | for message in messages 304 | ] 305 | 306 | query = _insert_message_query(self._table_name) 307 | async with self._aconnection.cursor() as cursor: 308 | await cursor.executemany(query, values) 309 | await self._aconnection.commit() 310 | 311 | def get_messages(self) -> List[BaseMessage]: 312 | """Retrieve messages from the chat message history.""" 313 | if self._connection is None: 314 | raise ValueError( 315 | "Please initialize the PostgresChatMessageHistory " 316 | "with a sync connection or use the async aget_messages method instead." 317 | ) 318 | 319 | query = _get_messages_query(self._table_name) 320 | 321 | with self._connection.cursor() as cursor: 322 | cursor.execute(query, {"session_id": self._session_id}) 323 | items = [record[0] for record in cursor.fetchall()] 324 | 325 | messages = messages_from_dict(items) 326 | return messages 327 | 328 | async def aget_messages(self) -> List[BaseMessage]: 329 | """Retrieve messages from the chat message history.""" 330 | if self._aconnection is None: 331 | raise ValueError( 332 | "Please initialize the PostgresChatMessageHistory " 333 | "with an async connection or use the sync get_messages method instead." 334 | ) 335 | 336 | query = _get_messages_query(self._table_name) 337 | async with self._aconnection.cursor() as cursor: 338 | await cursor.execute(query, {"session_id": self._session_id}) 339 | items = [record[0] for record in await cursor.fetchall()] 340 | 341 | messages = messages_from_dict(items) 342 | return messages 343 | 344 | @property 345 | def messages(self) -> List[BaseMessage]: 346 | """The abstraction required a property.""" 347 | return self.get_messages() 348 | 349 | @messages.setter 350 | def messages(self, value: list[BaseMessage]) -> None: 351 | """Clear the stored messages and appends a list of messages.""" 352 | self.clear() 353 | self.add_messages(value) 354 | 355 | def clear(self) -> None: 356 | """Clear the chat message history for the GIVEN session.""" 357 | if self._connection is None: 358 | raise ValueError( 359 | "Please initialize the PostgresChatMessageHistory " 360 | "with a sync connection or use the async clear method instead." 361 | ) 362 | 363 | query = _delete_by_session_id_query(self._table_name) 364 | with self._connection.cursor() as cursor: 365 | cursor.execute(query, {"session_id": self._session_id}) 366 | self._connection.commit() 367 | 368 | async def aclear(self) -> None: 369 | """Clear the chat message history for the GIVEN session.""" 370 | if self._aconnection is None: 371 | raise ValueError( 372 | "Please initialize the PostgresChatMessageHistory " 373 | "with an async connection or use the sync clear method instead." 
374 | ) 375 | 376 | query = _delete_by_session_id_query(self._table_name) 377 | async with self._aconnection.cursor() as cursor: 378 | await cursor.execute(query, {"session_id": self._session_id}) 379 | await self._aconnection.commit() 380 | -------------------------------------------------------------------------------- /langchain_postgres/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/langchain_postgres/py.typed -------------------------------------------------------------------------------- /langchain_postgres/translator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Union 2 | 3 | from langchain_core.structured_query import ( 4 | Comparator, 5 | Comparison, 6 | Operation, 7 | Operator, 8 | StructuredQuery, 9 | Visitor, 10 | ) 11 | 12 | 13 | class PGVectorTranslator(Visitor): 14 | """Translate `PGVector` internal query language elements to valid filters.""" 15 | 16 | allowed_operators = [Operator.AND, Operator.OR] 17 | """Subset of allowed logical operators.""" 18 | allowed_comparators = [ 19 | Comparator.EQ, 20 | Comparator.NE, 21 | Comparator.GT, 22 | Comparator.LT, 23 | Comparator.IN, 24 | Comparator.NIN, 25 | Comparator.CONTAIN, 26 | Comparator.LIKE, 27 | ] 28 | """Subset of allowed logical comparators.""" 29 | 30 | def _format_func(self, func: Union[Operator, Comparator]) -> str: 31 | self._validate_func(func) 32 | return f"${func.value}" 33 | 34 | def visit_operation(self, operation: Operation) -> Dict: 35 | args = [arg.accept(self) for arg in operation.arguments] 36 | return {self._format_func(operation.operator): args} 37 | 38 | def visit_comparison(self, comparison: Comparison) -> Dict: 39 | return { 40 | comparison.attribute: { 41 | self._format_func(comparison.comparator): comparison.value 42 | } 43 | } 44 | 45 | def visit_structured_query( 46 | self, structured_query: StructuredQuery 47 | ) -> Tuple[str, dict]: 48 | if structured_query.filter is None: 49 | kwargs = {} 50 | else: 51 | kwargs = {"filter": structured_query.filter.accept(self)} 52 | return structured_query.query, kwargs 53 | -------------------------------------------------------------------------------- /langchain_postgres/utils/pgvector_migrator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import warnings 4 | from typing import Any, AsyncIterator, Iterator, Optional, Sequence, TypeVar 5 | 6 | from sqlalchemy import RowMapping, text 7 | from sqlalchemy.exc import ProgrammingError, SQLAlchemyError 8 | 9 | from ..v2.engine import PGEngine 10 | from ..v2.vectorstores import PGVectorStore 11 | 12 | COLLECTIONS_TABLE = "langchain_pg_collection" 13 | EMBEDDINGS_TABLE = "langchain_pg_embedding" 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | async def __aget_collection_uuid( 19 | engine: PGEngine, 20 | collection_name: str, 21 | ) -> str: 22 | """ 23 | Get the collection uuid for a collection present in PGVector tables. 24 | 25 | Args: 26 | engine (PGEngine): The PG engine corresponding to the Database. 27 | collection_name (str): The name of the collection to get the uuid for. 28 | Returns: 29 | The uuid corresponding to the collection. 
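    Raises:
        ValueError: If no collection with the given name exists.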
30 | """ 31 | query = f"SELECT name, uuid FROM {COLLECTIONS_TABLE} WHERE name = :collection_name" 32 | async with engine._pool.connect() as conn: 33 | result = await conn.execute( 34 | text(query), parameters={"collection_name": collection_name} 35 | ) 36 | result_map = result.mappings() 37 | result_fetch = result_map.fetchone() 38 | if result_fetch is None: 39 | raise ValueError(f"Collection, {collection_name} not found.") 40 | return result_fetch.uuid 41 | 42 | 43 | async def __aextract_pgvector_collection( 44 | engine: PGEngine, 45 | collection_name: str, 46 | batch_size: int = 1000, 47 | ) -> AsyncIterator[Sequence[RowMapping]]: 48 | """ 49 | Extract all data belonging to a PGVector collection. 50 | 51 | Args: 52 | engine (PGEngine): The PG engine corresponding to the Database. 53 | collection_name (str): The name of the collection to get the data for. 54 | batch_size (int): The batch size for collection extraction. 55 | Default: 1000. Optional. 56 | 57 | Yields: 58 | The data present in the collection. 59 | """ 60 | try: 61 | uuid_task = asyncio.create_task(__aget_collection_uuid(engine, collection_name)) 62 | query = f"SELECT * FROM {EMBEDDINGS_TABLE} WHERE collection_id = :id" 63 | async with engine._pool.connect() as conn: 64 | uuid = await uuid_task 65 | result_proxy = await conn.execute(text(query), parameters={"id": uuid}) 66 | while True: 67 | rows = result_proxy.fetchmany(size=batch_size) 68 | if not rows: 69 | break 70 | yield [row._mapping for row in rows] 71 | except ValueError: 72 | raise ValueError(f"Collection, {collection_name} does not exist.") 73 | except SQLAlchemyError as e: 74 | raise ProgrammingError( 75 | statement=f"Failed to extract data from collection '{collection_name}': {e}", 76 | params={"id": uuid}, 77 | orig=e, 78 | ) from e 79 | 80 | 81 | async def __concurrent_batch_insert( 82 | data_batches: AsyncIterator[Sequence[RowMapping]], 83 | vector_store: PGVectorStore, 84 | max_concurrency: int = 100, 85 | ) -> None: 86 | pending: set[Any] = set() 87 | async for batch_data in data_batches: 88 | pending.add( 89 | asyncio.ensure_future( 90 | vector_store.aadd_embeddings( 91 | texts=[data.document for data in batch_data], 92 | embeddings=[json.loads(data.embedding) for data in batch_data], 93 | metadatas=[data.cmetadata for data in batch_data], 94 | ids=[data.id for data in batch_data], 95 | ) 96 | ) 97 | ) 98 | if len(pending) >= max_concurrency: 99 | _, pending = await asyncio.wait( 100 | pending, return_when=asyncio.FIRST_COMPLETED 101 | ) 102 | if pending: 103 | await asyncio.wait(pending) 104 | 105 | 106 | async def __amigrate_pgvector_collection( 107 | engine: PGEngine, 108 | collection_name: str, 109 | vector_store: PGVectorStore, 110 | delete_pg_collection: Optional[bool] = False, 111 | insert_batch_size: int = 1000, 112 | ) -> None: 113 | """ 114 | Migrate all data present in a PGVector collection to use separate tables for each collection. 115 | The new data format is compatible with the PGVectoreStore interface. 116 | 117 | Args: 118 | engine (PGEngine): The PG engine corresponding to the Database. 119 | collection_name (str): The collection to migrate. 120 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table. 121 | delete_pg_collection (bool): An option to delete the original data upon migration. 122 | Default: False. Optional. 123 | insert_batch_size (int): Number of rows to insert at once in the table. 124 | Default: 1000. 
125 | """ 126 | destination_table = vector_store.get_table_name() 127 | 128 | # Get row count in PGVector collection 129 | uuid_task = asyncio.create_task(__aget_collection_uuid(engine, collection_name)) 130 | query = ( 131 | f"SELECT COUNT(*) FROM {EMBEDDINGS_TABLE} WHERE collection_id=:collection_id" 132 | ) 133 | async with engine._pool.connect() as conn: 134 | uuid = await uuid_task 135 | result = await conn.execute(text(query), parameters={"collection_id": uuid}) 136 | result_map = result.mappings() 137 | collection_data_len = result_map.fetchone() 138 | if collection_data_len is None: 139 | warnings.warn(f"Collection, {collection_name} contains no elements.") 140 | return 141 | 142 | # Extract data from the collection and batch insert into the new table 143 | data_batches = __aextract_pgvector_collection( 144 | engine, collection_name, batch_size=insert_batch_size 145 | ) 146 | await __concurrent_batch_insert(data_batches, vector_store, max_concurrency=100) 147 | 148 | # Validate data migration 149 | query = f"SELECT COUNT(*) FROM {destination_table}" 150 | async with engine._pool.connect() as conn: 151 | result = await conn.execute(text(query)) 152 | result_map = result.mappings() 153 | table_size = result_map.fetchone() 154 | if not table_size: 155 | raise ValueError(f"Table: {destination_table} does not exist.") 156 | 157 | if collection_data_len["count"] != table_size["count"]: 158 | raise ValueError( 159 | "All data not yet migrated.\n" 160 | f"Original row count: {collection_data_len['count']}\n" 161 | f"Collection table, {destination_table} row count: {table_size['count']}" 162 | ) 163 | elif delete_pg_collection: 164 | # Delete PGVector data 165 | query = f"DELETE FROM {EMBEDDINGS_TABLE} WHERE collection_id=:collection_id" 166 | async with engine._pool.connect() as conn: 167 | await conn.execute(text(query), parameters={"collection_id": uuid}) 168 | await conn.commit() 169 | 170 | query = f"DELETE FROM {COLLECTIONS_TABLE} WHERE name=:collection_name" 171 | async with engine._pool.connect() as conn: 172 | await conn.execute( 173 | text(query), parameters={"collection_name": collection_name} 174 | ) 175 | await conn.commit() 176 | print(f"Successfully deleted PGVector collection, {collection_name}") 177 | 178 | 179 | async def __alist_pgvector_collection_names( 180 | engine: PGEngine, 181 | ) -> list[str]: 182 | """Lists all collection names present in PGVector table.""" 183 | try: 184 | query = f"SELECT name from {COLLECTIONS_TABLE}" 185 | async with engine._pool.connect() as conn: 186 | result = await conn.execute(text(query)) 187 | result_map = result.mappings() 188 | all_rows = result_map.fetchall() 189 | return [row["name"] for row in all_rows] 190 | except ProgrammingError as e: 191 | raise ValueError( 192 | "Please provide the correct collection table name: " + str(e) 193 | ) from e 194 | 195 | 196 | async def aextract_pgvector_collection( 197 | engine: PGEngine, 198 | collection_name: str, 199 | batch_size: int = 1000, 200 | ) -> AsyncIterator[Sequence[RowMapping]]: 201 | """ 202 | Extract all data belonging to a PGVector collection. 203 | 204 | Args: 205 | engine (PGEngine): The PG engine corresponding to the Database. 206 | collection_name (str): The name of the collection to get the data for. 207 | batch_size (int): The batch size for collection extraction. 208 | Default: 1000. Optional. 209 | 210 | Yields: 211 | The data present in the collection. 
212 | """ 213 | iterator = __aextract_pgvector_collection(engine, collection_name, batch_size) 214 | while True: 215 | try: 216 | result = await engine._run_as_async(iterator.__anext__()) 217 | yield result 218 | except StopAsyncIteration: 219 | break 220 | 221 | 222 | async def alist_pgvector_collection_names( 223 | engine: PGEngine, 224 | ) -> list[str]: 225 | """Lists all collection names present in PGVector table.""" 226 | return await engine._run_as_async(__alist_pgvector_collection_names(engine)) 227 | 228 | 229 | async def amigrate_pgvector_collection( 230 | engine: PGEngine, 231 | collection_name: str, 232 | vector_store: PGVectorStore, 233 | delete_pg_collection: Optional[bool] = False, 234 | insert_batch_size: int = 1000, 235 | ) -> None: 236 | """ 237 | Migrate all data present in a PGVector collection to use separate tables for each collection. 238 | The new data format is compatible with the PGVectorStore interface. 239 | 240 | Args: 241 | engine (PGEngine): The PG engine corresponding to the Database. 242 | collection_name (str): The collection to migrate. 243 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table. 244 | use_json_metadata (bool): An option to keep the PGVector metadata as json in the new table. 245 | Default: False. Optional. 246 | delete_pg_collection (bool): An option to delete the original data upon migration. 247 | Default: False. Optional. 248 | insert_batch_size (int): Number of rows to insert at once in the table. 249 | Default: 1000. 250 | """ 251 | await engine._run_as_async( 252 | __amigrate_pgvector_collection( 253 | engine, 254 | collection_name, 255 | vector_store, 256 | delete_pg_collection, 257 | insert_batch_size, 258 | ) 259 | ) 260 | 261 | 262 | def extract_pgvector_collection( 263 | engine: PGEngine, 264 | collection_name: str, 265 | batch_size: int = 1000, 266 | ) -> Iterator[Sequence[RowMapping]]: 267 | """ 268 | Extract all data belonging to a PGVector collection. 269 | 270 | Args: 271 | engine (PGEngine): The PG engine corresponding to the Database. 272 | collection_name (str): The name of the collection to get the data for. 273 | batch_size (int): The batch size for collection extraction. 274 | Default: 1000. Optional. 275 | 276 | Yields: 277 | The data present in the collection. 278 | """ 279 | iterator = __aextract_pgvector_collection(engine, collection_name, batch_size) 280 | while True: 281 | try: 282 | result = engine._run_as_sync(iterator.__anext__()) 283 | yield result 284 | except StopAsyncIteration: 285 | break 286 | 287 | 288 | def list_pgvector_collection_names(engine: PGEngine) -> list[str]: 289 | """Lists all collection names present in PGVector table.""" 290 | return engine._run_as_sync(__alist_pgvector_collection_names(engine)) 291 | 292 | 293 | def migrate_pgvector_collection( 294 | engine: PGEngine, 295 | collection_name: str, 296 | vector_store: PGVectorStore, 297 | delete_pg_collection: Optional[bool] = False, 298 | insert_batch_size: int = 1000, 299 | ) -> None: 300 | """ 301 | Migrate all data present in a PGVector collection to use separate tables for each collection. 302 | The new data format is compatible with the PGVectorStore interface. 303 | 304 | Args: 305 | engine (PGEngine): The PG engine corresponding to the Database. 306 | collection_name (str): The collection to migrate. 307 | vector_store (PGVectorStore): The PGVectorStore object corresponding to the new collection table. 308 | delete_pg_collection (bool): An option to delete the original data upon migration. 
309 | Default: False. Optional. 310 | insert_batch_size (int): Number of rows to insert at once in the table. 311 | Default: 1000. 312 | """ 313 | engine._run_as_sync( 314 | __amigrate_pgvector_collection( 315 | engine, 316 | collection_name, 317 | vector_store, 318 | delete_pg_collection, 319 | insert_batch_size, 320 | ) 321 | ) 322 | -------------------------------------------------------------------------------- /langchain_postgres/v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/langchain_postgres/v2/__init__.py -------------------------------------------------------------------------------- /langchain_postgres/v2/hybrid_search_config.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from dataclasses import dataclass, field 3 | from typing import Any, Callable, Optional, Sequence 4 | 5 | from sqlalchemy import RowMapping 6 | 7 | 8 | def weighted_sum_ranking( 9 | primary_search_results: Sequence[RowMapping], 10 | secondary_search_results: Sequence[RowMapping], 11 | primary_results_weight: float = 0.5, 12 | secondary_results_weight: float = 0.5, 13 | fetch_top_k: int = 4, 14 | ) -> Sequence[dict[str, Any]]: 15 | """ 16 | Ranks documents using a weighted sum of scores from two sources. 17 | 18 | Args: 19 | primary_search_results: A list of (document, distance) tuples from 20 | the primary search. 21 | secondary_search_results: A list of (document, distance) tuples from 22 | the secondary search. 23 | primary_results_weight: The weight for the primary source's scores. 24 | Defaults to 0.5. 25 | secondary_results_weight: The weight for the secondary source's scores. 26 | Defaults to 0.5. 27 | fetch_top_k: The number of documents to fetch after merging the results. 28 | Defaults to 4. 29 | 30 | Returns: 31 | A list of (document, distance) tuples, sorted by weighted_score in 32 | descending order. 
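    Example:
        With the default weights of 0.5 each, a document that appears in both
        result sets with a primary distance of 0.9 and a secondary distance of
        0.4 receives a combined score of 0.5 * 0.9 + 0.5 * 0.4 = 0.65.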
33 | """ 34 | 35 | # stores computed metric with provided distance metric and weights 36 | weighted_scores: dict[str, dict[str, Any]] = {} 37 | 38 | # Process results from primary source 39 | for row in primary_search_results: 40 | values = list(row.values()) 41 | doc_id = str(values[0]) # first value is doc_id 42 | distance = float(values[-1]) # type: ignore # last value is distance 43 | row_values = dict(row) 44 | row_values["distance"] = primary_results_weight * distance 45 | weighted_scores[doc_id] = row_values 46 | 47 | # Process results from secondary source, 48 | # adding to existing scores or creating new ones 49 | for row in secondary_search_results: 50 | values = list(row.values()) 51 | doc_id = str(values[0]) # first value is doc_id 52 | distance = float(values[-1]) # type: ignore # last value is distance 53 | primary_score = ( 54 | weighted_scores[doc_id]["distance"] if doc_id in weighted_scores else 0.0 55 | ) 56 | row_values = dict(row) 57 | row_values["distance"] = distance * secondary_results_weight + primary_score 58 | weighted_scores[doc_id] = row_values 59 | 60 | # Sort the results by weighted score in descending order 61 | ranked_results = sorted( 62 | weighted_scores.values(), key=lambda item: item["distance"], reverse=True 63 | ) 64 | return ranked_results[:fetch_top_k] 65 | 66 | 67 | def reciprocal_rank_fusion( 68 | primary_search_results: Sequence[RowMapping], 69 | secondary_search_results: Sequence[RowMapping], 70 | rrf_k: float = 60, 71 | fetch_top_k: int = 4, 72 | ) -> Sequence[dict[str, Any]]: 73 | """ 74 | Ranks documents using Reciprocal Rank Fusion (RRF) of scores from two sources. 75 | 76 | Args: 77 | primary_search_results: A list of (document, distance) tuples from 78 | the primary search. 79 | secondary_search_results: A list of (document, distance) tuples from 80 | the secondary search. 81 | rrf_k: The RRF parameter k. 82 | Defaults to 60. 83 | fetch_top_k: The number of documents to fetch after merging the results. 84 | Defaults to 4. 85 | 86 | Returns: 87 | A list of (document_id, rrf_score) tuples, sorted by rrf_score 88 | in descending order. 
89 | """ 90 | rrf_scores: dict[str, dict[str, Any]] = {} 91 | 92 | # Process results from primary source 93 | for rank, row in enumerate( 94 | sorted(primary_search_results, key=lambda item: item["distance"], reverse=True) 95 | ): 96 | values = list(row.values()) 97 | doc_id = str(values[0]) 98 | row_values = dict(row) 99 | primary_score = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0 100 | primary_score += 1.0 / (rank + rrf_k) 101 | row_values["distance"] = primary_score 102 | rrf_scores[doc_id] = row_values 103 | 104 | # Process results from secondary source 105 | for rank, row in enumerate( 106 | sorted( 107 | secondary_search_results, key=lambda item: item["distance"], reverse=True 108 | ) 109 | ): 110 | values = list(row.values()) 111 | doc_id = str(values[0]) 112 | row_values = dict(row) 113 | secondary_score = ( 114 | rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0 115 | ) 116 | secondary_score += 1.0 / (rank + rrf_k) 117 | row_values["distance"] = secondary_score 118 | rrf_scores[doc_id] = row_values 119 | 120 | # Sort the results by rrf score in descending order 121 | # Sort the results by weighted score in descending order 122 | ranked_results = sorted( 123 | rrf_scores.values(), key=lambda item: item["distance"], reverse=True 124 | ) 125 | # Extract only the RowMapping for the top results 126 | return ranked_results[:fetch_top_k] 127 | 128 | 129 | @dataclass 130 | class HybridSearchConfig(ABC): 131 | """ 132 | AlloyDB Vector Store Hybrid Search Config. 133 | 134 | Queries might be slow if the hybrid search column does not exist. 135 | For best hybrid search performance, consider creating a TSV column 136 | and adding GIN index. 137 | """ 138 | 139 | tsv_column: Optional[str] = "" 140 | tsv_lang: Optional[str] = "pg_catalog.english" 141 | fts_query: Optional[str] = "" 142 | fusion_function: Callable[ 143 | [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any] 144 | ] = weighted_sum_ranking # Updated default 145 | fusion_function_parameters: dict[str, Any] = field(default_factory=dict) 146 | primary_top_k: int = 4 147 | secondary_top_k: int = 4 148 | index_name: str = "langchain_tsv_index" 149 | index_type: str = "GIN" 150 | -------------------------------------------------------------------------------- /langchain_postgres/v2/indexes.py: -------------------------------------------------------------------------------- 1 | """Index class to add vector indexes on the PGVectorStore. 2 | 3 | Learn more about vector indexes at https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing 4 | """ 5 | 6 | import enum 7 | import re 8 | import warnings 9 | from abc import ABC, abstractmethod 10 | from dataclasses import dataclass, field 11 | from typing import Optional 12 | 13 | 14 | @dataclass 15 | class StrategyMixin: 16 | operator: str 17 | search_function: str 18 | index_function: str 19 | 20 | 21 | class DistanceStrategy(StrategyMixin, enum.Enum): 22 | """Enumerator of the Distance strategies.""" 23 | 24 | EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops" 25 | COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops" 26 | INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops" 27 | 28 | 29 | DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE 30 | DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex" 31 | 32 | 33 | def validate_identifier(identifier: str) -> None: 34 | if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is None: 35 | raise ValueError( 36 | f"Invalid identifier: {identifier}. 
Identifiers must start with a letter or underscore, and subsequent characters can be letters, digits, or underscores." 37 | ) 38 | 39 | 40 | @dataclass 41 | class BaseIndex(ABC): 42 | """ 43 | Abstract base class for defining vector indexes. 44 | 45 | Attributes: 46 | name (Optional[str]): A human-readable name for the index. Defaults to None. 47 | index_type (str): A string identifying the type of index. Defaults to "base". 48 | distance_strategy (DistanceStrategy): The strategy used to calculate distances 49 | between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE. 50 | partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None. 51 | extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None. 52 | """ 53 | 54 | name: Optional[str] = None 55 | index_type: str = "base" 56 | distance_strategy: DistanceStrategy = field( 57 | default_factory=lambda: DistanceStrategy.COSINE_DISTANCE 58 | ) 59 | partial_indexes: Optional[list[str]] = None 60 | extension_name: Optional[str] = None 61 | 62 | @abstractmethod 63 | def index_options(self) -> str: 64 | """Set index query options for vector store initialization.""" 65 | raise NotImplementedError( 66 | "index_options method must be implemented by subclass" 67 | ) 68 | 69 | def get_index_function(self) -> str: 70 | return self.distance_strategy.index_function 71 | 72 | def __post_init__(self) -> None: 73 | """Check if initialization parameters are valid. 74 | 75 | Raises: 76 | ValueError: extension_name is a valid postgreSQL identifier 77 | """ 78 | 79 | if self.extension_name: 80 | validate_identifier(self.extension_name) 81 | if self.index_type: 82 | validate_identifier(self.index_type) 83 | 84 | 85 | @dataclass 86 | class ExactNearestNeighbor(BaseIndex): 87 | index_type: str = "exactnearestneighbor" 88 | 89 | 90 | @dataclass 91 | class QueryOptions(ABC): 92 | @abstractmethod 93 | def to_parameter(self) -> list[str]: 94 | """Convert index attributes to list of configurations.""" 95 | raise NotImplementedError("to_parameter method must be implemented by subclass") 96 | 97 | @abstractmethod 98 | def to_string(self) -> str: 99 | """Convert index attributes to string.""" 100 | raise NotImplementedError("to_string method must be implemented by subclass") 101 | 102 | 103 | @dataclass 104 | class HNSWIndex(BaseIndex): 105 | index_type: str = "hnsw" 106 | m: int = 16 107 | ef_construction: int = 64 108 | 109 | def index_options(self) -> str: 110 | """Set index query options for vector store initialization.""" 111 | return f"(m = {self.m}, ef_construction = {self.ef_construction})" 112 | 113 | 114 | @dataclass 115 | class HNSWQueryOptions(QueryOptions): 116 | ef_search: int = 40 117 | 118 | def to_parameter(self) -> list[str]: 119 | """Convert index attributes to list of configurations.""" 120 | return [f"hnsw.ef_search = {self.ef_search}"] 121 | 122 | def to_string(self) -> str: 123 | """Convert index attributes to string.""" 124 | warnings.warn( 125 | "to_string is deprecated, use to_parameter instead.", 126 | DeprecationWarning, 127 | ) 128 | return f"hnsw.ef_search = {self.ef_search}" 129 | 130 | 131 | @dataclass 132 | class IVFFlatIndex(BaseIndex): 133 | index_type: str = "ivfflat" 134 | lists: int = 100 135 | 136 | def index_options(self) -> str: 137 | """Set index query options for vector store initialization.""" 138 | return f"(lists = {self.lists})" 139 | 140 | 141 | @dataclass 142 | class IVFFlatQueryOptions(QueryOptions): 143 | probes: int = 1 
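    # `probes` above sets how many IVF lists are searched at query time;
    # higher values improve recall at the cost of query speed.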
144 | 145 | def to_parameter(self) -> list[str]: 146 | """Convert index attributes to list of configurations.""" 147 | return [f"ivfflat.probes = {self.probes}"] 148 | 149 | def to_string(self) -> str: 150 | """Convert index attributes to string.""" 151 | warnings.warn( 152 | "to_string is deprecated, use to_parameter instead.", 153 | DeprecationWarning, 154 | ) 155 | return f"ivfflat.probes = {self.probes}" 156 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "langchain-postgres" 3 | version = "0.0.14" 4 | description = "An integration package connecting Postgres and LangChain" 5 | authors = [] 6 | readme = "README.md" 7 | repository = "https://github.com/langchain-ai/langchain-postgres" 8 | requires-python = ">=3.9" 9 | license = "MIT" 10 | dependencies = [ 11 | "asyncpg>=0.30.0", 12 | "langchain-core>=0.2.13,<0.4.0", 13 | "pgvector>=0.2.5,<0.4", 14 | "psycopg>=3,<4", 15 | "psycopg-pool>=3.2.1,<4", 16 | "sqlalchemy>=2,<3", 17 | "numpy>=1.21,<3", 18 | ] 19 | 20 | [tool.poetry.urls] 21 | "Source Code" = "https://github.com/langchain-ai/langchain-postgres/tree/master/langchain_postgres" 22 | 23 | [dependency-groups] 24 | test = [ 25 | "langchain-tests==0.3.7", 26 | "mypy>=1.15.0", 27 | "pytest>=8.3.4", 28 | "pytest-asyncio>=0.25.3", 29 | "pytest-cov>=6.0.0", 30 | "pytest-mock>=3.14.0", 31 | "pytest-socket>=0.7.0", 32 | "pytest-timeout>=2.3.1", 33 | "ruff>=0.9.7", 34 | ] 35 | 36 | [tool.ruff.lint] 37 | select = [ 38 | "E", # pycodestyle 39 | "F", # pyflakes 40 | "I", # isort 41 | "T201", # print 42 | ] 43 | 44 | [tool.mypy] 45 | disallow_untyped_defs = "True" 46 | 47 | [tool.coverage.run] 48 | omit = ["tests/*"] 49 | 50 | [build-system] 51 | requires = ["hatchling"] 52 | build-backend = "hatchling.build" 53 | 54 | [tool.pytest.ini_options] 55 | # --strict-markers will raise errors on unknown marks. 56 | # https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks 57 | # 58 | # https://docs.pytest.org/en/7.1.x/reference/reference.html 59 | # --strict-config any warnings encountered while parsing the `pytest` 60 | # section of the configuration file raise errors. 61 | # 62 | # https://github.com/tophat/syrupy 63 | # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. 64 | addopts = "--strict-markers --strict-config --durations=5" 65 | # Global timeout for all tests. There should be a good reason for a test to 66 | # takemore than 30 seconds. 67 | timeout = 30 68 | # Registering custom markers. 
69 | # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers 70 | markers = [] 71 | asyncio_mode = "auto" 72 | 73 | 74 | [tool.codespell] 75 | skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig' 76 | ignore-regex = '.*(Stati Uniti|Tense=Pres).*' 77 | ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' 78 | 79 | [tool.ruff.lint.extend-per-file-ignores] 80 | "tests/unit_tests/v2/**/*.py" = [ 81 | "E501", 82 | ] 83 | 84 | "langchain_postgres/v2/**/*.py" = [ 85 | "E501", 86 | ] 87 | 88 | "langchain_postgres/utils/**/*.py" = [ 89 | "E501", 90 | "T201", # Allow print 91 | ] 92 | 93 | 94 | "examples/**/*.ipynb" = [ 95 | "E501", 96 | "T201", # Allow print 97 | ] 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /security.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Reporting OSS Vulnerabilities 4 | 5 | LangChain is partnered with [huntr by Protect AI](https://huntr.com/) to provide 6 | a bounty program for our open source projects. 7 | 8 | Please report security vulnerabilities associated with the LangChain 9 | open source projects by visiting the following link: 10 | 11 | [https://huntr.com/bounties/disclose/](https://huntr.com/bounties/disclose/?target=https%3A%2F%2Fgithub.com%2Flangchain-ai%2Flangchain&validSearch=true) 12 | 13 | Before reporting a vulnerability, please review: 14 | 15 | 1) In-Scope Targets and Out-of-Scope Targets below. 16 | 2) The [langchain-ai/langchain](https://python.langchain.com/docs/contributing/repo_structure) monorepo structure. 17 | 3) LangChain [security guidelines](https://python.langchain.com/docs/security) to 18 | understand what we consider to be a security vulnerability vs. developer 19 | responsibility. 20 | 21 | ### In-Scope Targets 22 | 23 | The following packages and repositories are eligible for bug bounties: 24 | 25 | - langchain-core 26 | - langchain (see exceptions) 27 | - langchain-community (see exceptions) 28 | - langgraph 29 | - langserve 30 | 31 | ### Out of Scope Targets 32 | 33 | All out of scope targets defined by huntr as well as: 34 | 35 | - **langchain-experimental**: This repository is for experimental code and is not 36 | eligible for bug bounties, bug reports to it will be marked as interesting or waste of 37 | time and published with no bounty attached. 38 | - **tools**: Tools in either langchain or langchain-community are not eligible for bug 39 | bounties. This includes the following directories 40 | - langchain/tools 41 | - langchain-community/tools 42 | - Please review our [security guidelines](https://python.langchain.com/docs/security) 43 | for more details, but generally tools interact with the real world. Developers are 44 | expected to understand the security implications of their code and are responsible 45 | for the security of their tools. 46 | - Code documented with security notices. This will be decided done on a case by 47 | case basis, but likely will not be eligible for a bounty as the code is already 48 | documented with guidelines for developers that should be followed for making their 49 | application secure. 50 | - Any LangSmith related repositories or APIs see below. 
51 | 52 | ## Reporting LangSmith Vulnerabilities 53 | 54 | Please report security vulnerabilities associated with LangSmith by email to `security@langchain.dev`. 55 | 56 | - LangSmith site: https://smith.langchain.com 57 | - SDK client: https://github.com/langchain-ai/langsmith-sdk 58 | 59 | ### Other Security Concerns 60 | 61 | For any other security concerns, please contact us at `security@langchain.dev`. 62 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/fake_embeddings.py: -------------------------------------------------------------------------------- 1 | """Copied from community.""" 2 | 3 | from typing import List 4 | 5 | from langchain_core.embeddings import Embeddings 6 | 7 | fake_texts = ["foo", "bar", "baz"] 8 | 9 | 10 | class FakeEmbeddings(Embeddings): 11 | """Fake embeddings functionality for testing.""" 12 | 13 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 14 | """Return simple embeddings. 15 | Embeddings encode each text as its index.""" 16 | return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] 17 | 18 | async def aembed_documents(self, texts: List[str]) -> List[List[float]]: 19 | return self.embed_documents(texts) 20 | 21 | def embed_query(self, text: str) -> List[float]: 22 | """Return constant query embeddings. 23 | Embeddings are identical to embed_documents(texts)[0]. 
24 | Distance to each text will be that text's index, 25 | as it was passed to embed_documents.""" 26 | return [float(1.0)] * 9 + [float(0.0)] 27 | 28 | async def aembed_query(self, text: str) -> List[float]: 29 | return self.embed_query(text) 30 | -------------------------------------------------------------------------------- /tests/unit_tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/fixtures/filtering_test_cases.py: -------------------------------------------------------------------------------- 1 | """Module needs to move to a stasndalone package.""" 2 | 3 | from langchain_core.documents import Document 4 | 5 | metadatas = [ 6 | { 7 | "name": "adam", 8 | "date": "2021-01-01", 9 | "count": 1, 10 | "is_active": True, 11 | "tags": ["a", "b"], 12 | "location": [1.0, 2.0], 13 | "id": 1, 14 | "height": 10.0, # Float column 15 | "happiness": 0.9, # Float column 16 | "sadness": 0.1, # Float column 17 | }, 18 | { 19 | "name": "bob", 20 | "date": "2021-01-02", 21 | "count": 2, 22 | "is_active": False, 23 | "tags": ["b", "c"], 24 | "location": [2.0, 3.0], 25 | "id": 2, 26 | "height": 5.7, # Float column 27 | "happiness": 0.8, # Float column 28 | "sadness": 0.1, # Float column 29 | }, 30 | { 31 | "name": "jane", 32 | "date": "2021-01-01", 33 | "count": 3, 34 | "is_active": True, 35 | "tags": ["b", "d"], 36 | "location": [3.0, 4.0], 37 | "id": 3, 38 | "height": 2.4, # Float column 39 | "happiness": None, 40 | # Sadness missing intentionally 41 | }, 42 | ] 43 | texts = ["id {id}".format(id=metadata["id"]) for metadata in metadatas] 44 | 45 | DOCUMENTS = [ 46 | Document(page_content=text, metadata=metadata) 47 | for text, metadata in zip(texts, metadatas) 48 | ] 49 | 50 | 51 | TYPE_1_FILTERING_TEST_CASES = [ 52 | # These tests only involve equality checks 53 | ( 54 | {"id": 1}, 55 | [1], 56 | ), 57 | # String field 58 | ( 59 | # check name 60 | {"name": "adam"}, 61 | [1], 62 | ), 63 | # Boolean fields 64 | ( 65 | {"is_active": True}, 66 | [1, 3], 67 | ), 68 | ( 69 | {"is_active": False}, 70 | [2], 71 | ), 72 | # And semantics for top level filtering 73 | ( 74 | {"id": 1, "is_active": True}, 75 | [1], 76 | ), 77 | ( 78 | {"id": 1, "is_active": False}, 79 | [], 80 | ), 81 | ] 82 | 83 | TYPE_2_FILTERING_TEST_CASES = [ 84 | # These involve equality checks and other operators 85 | # like $ne, $gt, $gte, $lt, $lte 86 | ( 87 | {"id": 1}, 88 | [1], 89 | ), 90 | ( 91 | {"id": {"$ne": 1}}, 92 | [2, 3], 93 | ), 94 | ( 95 | {"id": {"$gt": 1}}, 96 | [2, 3], 97 | ), 98 | ( 99 | {"id": {"$gte": 1}}, 100 | [1, 2, 3], 101 | ), 102 | ( 103 | {"id": {"$lt": 1}}, 104 | [], 105 | ), 106 | ( 107 | {"id": {"$lte": 1}}, 108 | [1], 109 | ), 110 | # Repeat all the same tests with name (string column) 111 | ( 112 | {"name": "adam"}, 113 | [1], 114 | ), 115 | ( 116 | {"name": "bob"}, 117 | [2], 118 | ), 119 | ( 120 | {"name": {"$eq": "adam"}}, 121 | [1], 122 | ), 123 | ( 124 | {"name": {"$ne": "adam"}}, 125 | [2, 3], 126 | ), 127 | # And also gt, gte, lt, lte relying on lexicographical ordering 128 | ( 129 | {"name": {"$gt": "jane"}}, 130 | [], 131 | ), 132 | ( 133 | {"name": {"$gte": "jane"}}, 134 | [3], 135 | ), 136 | ( 137 | {"name": {"$lt": "jane"}}, 138 | [1, 2], 139 | ), 140 | ( 141 | {"name": {"$lte": "jane"}}, 142 | [1, 
2, 3], 143 | ), 144 | ( 145 | {"is_active": {"$eq": True}}, 146 | [1, 3], 147 | ), 148 | ( 149 | {"is_active": {"$ne": True}}, 150 | [2], 151 | ), 152 | # Test float column. 153 | ( 154 | {"height": {"$gt": 5.0}}, 155 | [1, 2], 156 | ), 157 | ( 158 | {"height": {"$gte": 5.0}}, 159 | [1, 2], 160 | ), 161 | ( 162 | {"height": {"$lt": 5.0}}, 163 | [3], 164 | ), 165 | ( 166 | {"height": {"$lte": 5.8}}, 167 | [2, 3], 168 | ), 169 | ] 170 | 171 | TYPE_3_FILTERING_TEST_CASES = [ 172 | # These involve usage of AND, OR and NOT operators 173 | ( 174 | {"$or": [{"id": 1}, {"id": 2}]}, 175 | [1, 2], 176 | ), 177 | ( 178 | {"$or": [{"id": 1}, {"name": "bob"}]}, 179 | [1, 2], 180 | ), 181 | ( 182 | {"$and": [{"id": 1}, {"id": 2}]}, 183 | [], 184 | ), 185 | ( 186 | {"$or": [{"id": 1}, {"id": 2}, {"id": 3}]}, 187 | [1, 2, 3], 188 | ), 189 | # Test for $not operator 190 | ( 191 | {"$not": {"id": 1}}, 192 | [2, 3], 193 | ), 194 | ( 195 | {"$not": [{"id": 1}]}, 196 | [2, 3], 197 | ), 198 | ( 199 | {"$not": {"name": "adam"}}, 200 | [2, 3], 201 | ), 202 | ( 203 | {"$not": [{"name": "adam"}]}, 204 | [2, 3], 205 | ), 206 | ( 207 | {"$not": {"is_active": True}}, 208 | [2], 209 | ), 210 | ( 211 | {"$not": [{"is_active": True}]}, 212 | [2], 213 | ), 214 | ( 215 | {"$not": {"height": {"$gt": 5.0}}}, 216 | [3], 217 | ), 218 | ( 219 | {"$not": [{"height": {"$gt": 5.0}}]}, 220 | [3], 221 | ), 222 | ] 223 | 224 | TYPE_4_FILTERING_TEST_CASES = [ 225 | # These involve special operators like $in, $nin, $between 226 | # Test between 227 | ( 228 | {"id": {"$between": (1, 2)}}, 229 | [1, 2], 230 | ), 231 | ( 232 | {"id": {"$between": (1, 1)}}, 233 | [1], 234 | ), 235 | # Test in 236 | ( 237 | {"name": {"$in": ["adam", "bob"]}}, 238 | [1, 2], 239 | ), 240 | # With numeric fields 241 | ( 242 | {"id": {"$in": [1, 2]}}, 243 | [1, 2], 244 | ), 245 | # Test nin 246 | ( 247 | {"name": {"$nin": ["adam", "bob"]}}, 248 | [3], 249 | ), 250 | ## with numeric fields 251 | ( 252 | {"id": {"$nin": [1, 2]}}, 253 | [3], 254 | ), 255 | ] 256 | 257 | TYPE_5_FILTERING_TEST_CASES = [ 258 | # These involve special operators like $like, $ilike that 259 | # may be specified to certain databases. 
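# ($like is the case-sensitive SQL LIKE pattern match; $ilike, where
# supported, is the case-insensitive ILIKE variant.)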
260 | ( 261 | {"name": {"$like": "a%"}}, 262 | [1], 263 | ), 264 | ( 265 | {"name": {"$like": "%a%"}}, # adam and jane 266 | [1, 3], 267 | ), 268 | ] 269 | 270 | TYPE_6_FILTERING_TEST_CASES = [ 271 | # These involve the special operator $exists 272 | ( 273 | {"happiness": {"$exists": False}}, 274 | [], 275 | ), 276 | ( 277 | {"happiness": {"$exists": True}}, 278 | [1, 2, 3], 279 | ), 280 | ( 281 | {"sadness": {"$exists": False}}, 282 | [3], 283 | ), 284 | ( 285 | {"sadness": {"$exists": True}}, 286 | [1, 2], 287 | ), 288 | ] 289 | -------------------------------------------------------------------------------- /tests/unit_tests/fixtures/metadata_filtering_data.py: -------------------------------------------------------------------------------- 1 | METADATAS = [ 2 | { 3 | "name": "Wireless Headphones", 4 | "code": "WH001", 5 | "price": 149.99, 6 | "is_available": True, 7 | "release_date": "2023-10-26", 8 | "tags": ["audio", "wireless", "electronics"], 9 | "dimensions": [18.5, 7.2, 21.0], 10 | "inventory_location": [101, 102], 11 | "available_quantity": 50, 12 | }, 13 | { 14 | "name": "Ergonomic Office Chair", 15 | "code": "EC002", 16 | "price": 299.00, 17 | "is_available": True, 18 | "release_date": "2023-08-15", 19 | "tags": ["furniture", "office", "ergonomic"], 20 | "dimensions": [65.0, 60.0, 110.0], 21 | "inventory_location": [201], 22 | "available_quantity": 10, 23 | }, 24 | { 25 | "name": "Stainless Steel Water Bottle", 26 | "code": "WB003", 27 | "price": 25.50, 28 | "is_available": False, 29 | "release_date": "2024-01-05", 30 | "tags": ["hydration", "eco-friendly", "kitchen"], 31 | "dimensions": [7.5, 7.5, 25.0], 32 | "available_quantity": 0, 33 | }, 34 | { 35 | "name": "Smart Fitness Tracker", 36 | "code": "FT004", 37 | "price": 79.95, 38 | "is_available": True, 39 | "release_date": "2023-11-12", 40 | "tags": ["fitness", "wearable", "technology"], 41 | "dimensions": [2.0, 1.0, 25.0], 42 | "inventory_location": [401], 43 | "available_quantity": 100, 44 | }, 45 | ] 46 | 47 | FILTERING_TEST_CASES = [ 48 | # These tests only involve equality checks 49 | ( 50 | {"code": "FT004"}, 51 | ["FT004"], 52 | ), 53 | # String field 54 | ( 55 | # check name 56 | {"name": "Smart Fitness Tracker"}, 57 | ["FT004"], 58 | ), 59 | # Boolean fields 60 | ( 61 | {"is_available": True}, 62 | ["WH001", "FT004", "EC002"], 63 | ), 64 | # And semantics for top level filtering 65 | ( 66 | {"code": "WH001", "is_available": True}, 67 | ["WH001"], 68 | ), 69 | # These involve equality checks and other operators 70 | # like $ne, $gt, $gte, $lt, $lte 71 | ( 72 | {"available_quantity": {"$eq": 10}}, 73 | ["EC002"], 74 | ), 75 | ( 76 | {"available_quantity": {"$ne": 0}}, 77 | ["WH001", "FT004", "EC002"], 78 | ), 79 | ( 80 | {"available_quantity": {"$gt": 60}}, 81 | ["FT004"], 82 | ), 83 | ( 84 | {"available_quantity": {"$gte": 50}}, 85 | ["WH001", "FT004"], 86 | ), 87 | ( 88 | {"available_quantity": {"$lt": 5}}, 89 | ["WB003"], 90 | ), 91 | ( 92 | {"available_quantity": {"$lte": 10}}, 93 | ["WB003", "EC002"], 94 | ), 95 | # Repeat all the same tests with name (string column) 96 | ( 97 | {"code": {"$eq": "WH001"}}, 98 | ["WH001"], 99 | ), 100 | ( 101 | {"code": {"$ne": "WB003"}}, 102 | ["WH001", "FT004", "EC002"], 103 | ), 104 | # And also gt, gte, lt, lte relying on lexicographical ordering 105 | ( 106 | {"name": {"$gt": "Wireless Headphones"}}, 107 | [], 108 | ), 109 | ( 110 | {"name": {"$gte": "Wireless Headphones"}}, 111 | ["WH001"], 112 | ), 113 | ( 114 | {"name": {"$lt": "Smart Fitness Tracker"}}, 115 | ["EC002"], 
116 | ), 117 | ( 118 | {"name": {"$lte": "Smart Fitness Tracker"}}, 119 | ["FT004", "EC002"], 120 | ), 121 | ( 122 | {"is_available": {"$eq": True}}, 123 | ["WH001", "FT004", "EC002"], 124 | ), 125 | ( 126 | {"is_available": {"$ne": True}}, 127 | ["WB003"], 128 | ), 129 | # Test float column. 130 | ( 131 | {"price": {"$gt": 200.0}}, 132 | ["EC002"], 133 | ), 134 | ( 135 | {"price": {"$gte": 149.99}}, 136 | ["WH001", "EC002"], 137 | ), 138 | ( 139 | {"price": {"$lt": 50.0}}, 140 | ["WB003"], 141 | ), 142 | ( 143 | {"price": {"$lte": 79.95}}, 144 | ["FT004", "WB003"], 145 | ), 146 | # These involve usage of AND, OR and NOT operators 147 | ( 148 | {"$or": [{"code": "WH001"}, {"code": "EC002"}]}, 149 | ["WH001", "EC002"], 150 | ), 151 | ( 152 | {"$or": [{"code": "WH001"}, {"available_quantity": 10}]}, 153 | ["WH001", "EC002"], 154 | ), 155 | ( 156 | {"$and": [{"code": "WH001"}, {"code": "EC002"}]}, 157 | [], 158 | ), 159 | # Test for $not operator 160 | ( 161 | {"$not": {"code": "WB003"}}, 162 | ["WH001", "FT004", "EC002"], 163 | ), 164 | ( 165 | {"$not": [{"code": "WB003"}]}, 166 | ["WH001", "FT004", "EC002"], 167 | ), 168 | ( 169 | {"$not": {"available_quantity": 0}}, 170 | ["WH001", "FT004", "EC002"], 171 | ), 172 | ( 173 | {"$not": [{"available_quantity": 0}]}, 174 | ["WH001", "FT004", "EC002"], 175 | ), 176 | ( 177 | {"$not": {"is_available": True}}, 178 | ["WB003"], 179 | ), 180 | ( 181 | {"$not": [{"is_available": True}]}, 182 | ["WB003"], 183 | ), 184 | ( 185 | {"$not": {"price": {"$gt": 150.0}}}, 186 | ["WH001", "FT004", "WB003"], 187 | ), 188 | ( 189 | {"$not": [{"price": {"$gt": 150.0}}]}, 190 | ["WH001", "FT004", "WB003"], 191 | ), 192 | # These involve special operators like $in, $nin, $between 193 | # Test between 194 | ( 195 | {"available_quantity": {"$between": (40, 60)}}, 196 | ["WH001"], 197 | ), 198 | # Test in 199 | ( 200 | {"name": {"$in": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}}, 201 | ["FT004", "WB003"], 202 | ), 203 | # With numeric fields 204 | ( 205 | {"available_quantity": {"$in": [0, 10]}}, 206 | ["WB003", "EC002"], 207 | ), 208 | # Test nin 209 | ( 210 | {"name": {"$nin": ["Smart Fitness Tracker", "Stainless Steel Water Bottle"]}}, 211 | ["WH001", "EC002"], 212 | ), 213 | ## with numeric fields 214 | ( 215 | {"available_quantity": {"$nin": [50, 0, 10]}}, 216 | ["FT004"], 217 | ), 218 | # These involve special operators like $like, $ilike that 219 | # may be specified to certain databases. 
220 | ( 221 | {"name": {"$like": "Wireless%"}}, 222 | ["WH001"], 223 | ), 224 | ( 225 | {"name": {"$like": "%less%"}}, # adam and jane 226 | ["WH001", "WB003"], 227 | ), 228 | # These involve the special operator $exists 229 | ( 230 | {"tags": {"$exists": False}}, 231 | [], 232 | ), 233 | ( 234 | {"inventory_location": {"$exists": False}}, 235 | ["WB003"], 236 | ), 237 | ] 238 | 239 | NEGATIVE_TEST_CASES = [ 240 | {"$nor": [{"code": "WH001"}, {"code": "EC002"}]}, 241 | {"$and": {"is_available": True}}, 242 | {"is_available": {"$and": True}}, 243 | {"is_available": {"name": "{Wireless Headphones", "code": "EC002"}}, 244 | {"my column": {"$and": True}}, 245 | {"is_available": {"code": "WH001"}}, 246 | {"$and": {}}, 247 | {"$and": []}, 248 | {"$not": True}, 249 | ] 250 | -------------------------------------------------------------------------------- /tests/unit_tests/query_constructors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/query_constructors/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/query_constructors/test_pgvector.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import pytest as pytest 4 | from langchain_core.structured_query import ( 5 | Comparator, 6 | Comparison, 7 | Operation, 8 | Operator, 9 | StructuredQuery, 10 | ) 11 | 12 | from langchain_postgres import PGVectorTranslator 13 | 14 | DEFAULT_TRANSLATOR = PGVectorTranslator() 15 | 16 | 17 | def test_visit_comparison() -> None: 18 | comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1) 19 | expected = {"foo": {"$lt": 1}} 20 | actual = DEFAULT_TRANSLATOR.visit_comparison(comp) 21 | assert expected == actual 22 | 23 | 24 | @pytest.mark.skip("Not implemented") 25 | def test_visit_operation() -> None: 26 | op = Operation( 27 | operator=Operator.AND, 28 | arguments=[ 29 | Comparison(comparator=Comparator.LT, attribute="foo", value=2), 30 | Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), 31 | Comparison(comparator=Comparator.GT, attribute="abc", value=2.0), 32 | ], 33 | ) 34 | expected = { 35 | "foo": {"$lt": 2}, 36 | "bar": {"$eq": "baz"}, 37 | "abc": {"$gt": 2.0}, 38 | } 39 | actual = DEFAULT_TRANSLATOR.visit_operation(op) 40 | assert expected == actual 41 | 42 | 43 | def test_visit_structured_query() -> None: 44 | query = "What is the capital of France?" 
45 | structured_query = StructuredQuery( 46 | query=query, 47 | filter=None, 48 | ) 49 | expected: Tuple[str, Dict] = (query, {}) 50 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) 51 | assert expected == actual 52 | 53 | comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1) 54 | structured_query = StructuredQuery( 55 | query=query, 56 | filter=comp, 57 | ) 58 | expected = (query, {"filter": {"foo": {"$lt": 1}}}) 59 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) 60 | assert expected == actual 61 | 62 | op = Operation( 63 | operator=Operator.AND, 64 | arguments=[ 65 | Comparison(comparator=Comparator.LT, attribute="foo", value=2), 66 | Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), 67 | Comparison(comparator=Comparator.GT, attribute="abc", value=2.0), 68 | ], 69 | ) 70 | structured_query = StructuredQuery( 71 | query=query, 72 | filter=op, 73 | ) 74 | expected = ( 75 | query, 76 | { 77 | "filter": { 78 | "$and": [ 79 | {"foo": {"$lt": 2}}, 80 | {"bar": {"$eq": "baz"}}, 81 | {"abc": {"$gt": 2.0}}, 82 | ] 83 | } 84 | }, 85 | ) 86 | actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) 87 | assert expected == actual 88 | -------------------------------------------------------------------------------- /tests/unit_tests/test_imports.py: -------------------------------------------------------------------------------- 1 | from langchain_postgres import __all__ 2 | 3 | EXPECTED_ALL = [ 4 | "__version__", 5 | "Column", 6 | "ColumnDict", 7 | "PGEngine", 8 | "PGVector", 9 | "PGVectorStore", 10 | "PGVectorTranslator", 11 | "PostgresChatMessageHistory", 12 | ] 13 | 14 | 15 | def test_all_imports() -> None: 16 | """Test that __all__ is correctly defined.""" 17 | assert sorted(EXPECTED_ALL) == sorted(__all__) 18 | -------------------------------------------------------------------------------- /tests/unit_tests/v1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/v1/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/v1/test_chat_histories.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from langchain_core.messages import AIMessage, HumanMessage, SystemMessage 4 | 5 | from langchain_postgres.chat_message_histories import PostgresChatMessageHistory 6 | from tests.utils import asyncpg_client, syncpg_client 7 | 8 | 9 | def test_sync_chat_history() -> None: 10 | table_name = "chat_history" 11 | session_id = str(uuid.UUID(int=123)) 12 | with syncpg_client() as sync_connection: 13 | PostgresChatMessageHistory.drop_table(sync_connection, table_name) 14 | PostgresChatMessageHistory.create_tables(sync_connection, table_name) 15 | 16 | chat_history = PostgresChatMessageHistory( 17 | table_name, session_id, sync_connection=sync_connection 18 | ) 19 | 20 | messages = chat_history.messages 21 | assert messages == [] 22 | 23 | assert chat_history is not None 24 | 25 | # Get messages from the chat history 26 | messages = chat_history.messages 27 | assert messages == [] 28 | 29 | chat_history.add_messages( 30 | [ 31 | SystemMessage(content="Meow"), 32 | AIMessage(content="woof"), 33 | HumanMessage(content="bark"), 34 | ] 35 | ) 36 | 37 | # Get messages from the chat history 38 | messages = chat_history.messages 39 | assert len(messages) == 3 
40 | assert messages == [ 41 | SystemMessage(content="Meow"), 42 | AIMessage(content="woof"), 43 | HumanMessage(content="bark"), 44 | ] 45 | 46 | chat_history.add_messages( 47 | [ 48 | SystemMessage(content="Meow"), 49 | AIMessage(content="woof"), 50 | HumanMessage(content="bark"), 51 | ] 52 | ) 53 | 54 | messages = chat_history.messages 55 | assert len(messages) == 6 56 | assert messages == [ 57 | SystemMessage(content="Meow"), 58 | AIMessage(content="woof"), 59 | HumanMessage(content="bark"), 60 | SystemMessage(content="Meow"), 61 | AIMessage(content="woof"), 62 | HumanMessage(content="bark"), 63 | ] 64 | 65 | chat_history.clear() 66 | assert chat_history.messages == [] 67 | 68 | 69 | async def test_async_chat_history() -> None: 70 | """Test the async chat history.""" 71 | async with asyncpg_client() as async_connection: 72 | table_name = "chat_history" 73 | session_id = str(uuid.UUID(int=125)) 74 | await PostgresChatMessageHistory.adrop_table(async_connection, table_name) 75 | await PostgresChatMessageHistory.acreate_tables(async_connection, table_name) 76 | 77 | chat_history = PostgresChatMessageHistory( 78 | table_name, session_id, async_connection=async_connection 79 | ) 80 | 81 | messages = await chat_history.aget_messages() 82 | assert messages == [] 83 | 84 | # Add messages 85 | await chat_history.aadd_messages( 86 | [ 87 | SystemMessage(content="Meow"), 88 | AIMessage(content="woof"), 89 | HumanMessage(content="bark"), 90 | ] 91 | ) 92 | # Get the messages 93 | messages = await chat_history.aget_messages() 94 | assert len(messages) == 3 95 | assert messages == [ 96 | SystemMessage(content="Meow"), 97 | AIMessage(content="woof"), 98 | HumanMessage(content="bark"), 99 | ] 100 | 101 | # Add more messages 102 | await chat_history.aadd_messages( 103 | [ 104 | SystemMessage(content="Meow"), 105 | AIMessage(content="woof"), 106 | HumanMessage(content="bark"), 107 | ] 108 | ) 109 | # Get the messages 110 | messages = await chat_history.aget_messages() 111 | assert len(messages) == 6 112 | assert messages == [ 113 | SystemMessage(content="Meow"), 114 | AIMessage(content="woof"), 115 | HumanMessage(content="bark"), 116 | SystemMessage(content="Meow"), 117 | AIMessage(content="woof"), 118 | HumanMessage(content="bark"), 119 | ] 120 | 121 | # clear 122 | await chat_history.aclear() 123 | assert await chat_history.aget_messages() == [] 124 | -------------------------------------------------------------------------------- /tests/unit_tests/v1/test_vectorstore_standard_tests.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator, Generator 2 | 3 | import pytest 4 | from langchain_core.vectorstores import VectorStore 5 | from langchain_tests.integration_tests import VectorStoreIntegrationTests 6 | 7 | from tests.unit_tests.v1.test_vectorstore import aget_vectorstore, get_vectorstore 8 | 9 | 10 | class TestSync(VectorStoreIntegrationTests): 11 | @pytest.fixture() 12 | def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore 13 | """Get an empty vectorstore for unit tests.""" 14 | with get_vectorstore(embedding=self.get_embeddings()) as vstore: 15 | vstore.drop_tables() 16 | vstore.create_tables_if_not_exists() 17 | vstore.create_collection() 18 | yield vstore 19 | 20 | @property 21 | def has_async(self) -> bool: 22 | return False # Skip async tests for sync vector store 23 | 24 | 25 | class TestAsync(VectorStoreIntegrationTests): 26 | @pytest.fixture() 27 | async def vectorstore(self) -> 
AsyncGenerator[VectorStore, None]: # type: ignore 28 | """Get an empty vectorstore for unit tests.""" 29 | async with aget_vectorstore(embedding=self.get_embeddings()) as vstore: 30 | await vstore.adrop_tables() 31 | await vstore.acreate_tables_if_not_exists() 32 | await vstore.acreate_collection() 33 | yield vstore 34 | 35 | @property 36 | def has_sync(self) -> bool: 37 | return False # Skip sync tests for async vector store 38 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/langchain-ai/langchain-postgres/b342b29eff0ec986af128857716d56e7745be70b/tests/unit_tests/v2/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_async_pg_vectorstore.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import AsyncIterator, Sequence 3 | 4 | import pytest 5 | import pytest_asyncio 6 | from langchain_core.documents import Document 7 | from langchain_core.embeddings import DeterministicFakeEmbedding 8 | from sqlalchemy import text 9 | from sqlalchemy.engine.row import RowMapping 10 | 11 | from langchain_postgres import Column, PGEngine 12 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore 13 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 14 | 15 | DEFAULT_TABLE = "default" + str(uuid.uuid4()) 16 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()) 17 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()) 18 | VECTOR_SIZE = 768 19 | 20 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) 21 | 22 | texts = ["foo", "bar", "baz"] 23 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] 24 | docs = [ 25 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) 26 | ] 27 | 28 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))] 29 | 30 | 31 | async def aexecute(engine: PGEngine, query: str) -> None: 32 | async with engine._pool.connect() as conn: 33 | await conn.execute(text(query)) 34 | await conn.commit() 35 | 36 | 37 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]: 38 | async with engine._pool.connect() as conn: 39 | result = await conn.execute(text(query)) 40 | result_map = result.mappings() 41 | result_fetch = result_map.fetchall() 42 | return result_fetch 43 | 44 | 45 | @pytest.mark.enable_socket 46 | @pytest.mark.asyncio(scope="class") 47 | class TestVectorStore: 48 | @pytest_asyncio.fixture(scope="class") 49 | async def engine(self) -> AsyncIterator[PGEngine]: 50 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 51 | 52 | yield engine 53 | await engine.adrop_table(DEFAULT_TABLE) 54 | await engine.adrop_table(CUSTOM_TABLE) 55 | await engine.close() 56 | 57 | @pytest_asyncio.fixture(scope="class") 58 | async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: 59 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) 60 | vs = await AsyncPGVectorStore.create( 61 | engine, 62 | embedding_service=embeddings_service, 63 | table_name=DEFAULT_TABLE, 64 | ) 65 | yield vs 66 | 67 | @pytest_asyncio.fixture(scope="class") 68 | async def vs_custom(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: 69 | await engine._ainit_vectorstore_table( 70 | CUSTOM_TABLE, 71 | VECTOR_SIZE, 72 
| id_column="myid", 73 | content_column="mycontent", 74 | embedding_column="myembedding", 75 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 76 | metadata_json_column="mymeta", 77 | ) 78 | vs = await AsyncPGVectorStore.create( 79 | engine, 80 | embedding_service=embeddings_service, 81 | table_name=CUSTOM_TABLE, 82 | id_column="myid", 83 | content_column="mycontent", 84 | embedding_column="myembedding", 85 | metadata_columns=["page", "source"], 86 | metadata_json_column="mymeta", 87 | ) 88 | yield vs 89 | 90 | async def test_init_with_constructor(self, engine: PGEngine) -> None: 91 | with pytest.raises(Exception): 92 | AsyncPGVectorStore( 93 | key={}, 94 | engine=engine._pool, 95 | embedding_service=embeddings_service, 96 | table_name=CUSTOM_TABLE, 97 | id_column="myid", 98 | content_column="noname", 99 | embedding_column="myembedding", 100 | metadata_columns=["page", "source"], 101 | metadata_json_column="mymeta", 102 | ) 103 | 104 | async def test_post_init(self, engine: PGEngine) -> None: 105 | with pytest.raises(ValueError): 106 | await AsyncPGVectorStore.create( 107 | engine, 108 | embedding_service=embeddings_service, 109 | table_name=CUSTOM_TABLE, 110 | id_column="myid", 111 | content_column="noname", 112 | embedding_column="myembedding", 113 | metadata_columns=["page", "source"], 114 | metadata_json_column="mymeta", 115 | ) 116 | 117 | async def test_aadd_texts(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None: 118 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 119 | await vs.aadd_texts(texts, ids=ids) 120 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 121 | assert len(results) == 3 122 | 123 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 124 | await vs.aadd_texts(texts, metadatas, ids) 125 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 126 | assert len(results) == 6 127 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') 128 | 129 | async def test_aadd_texts_edge_cases( 130 | self, engine: PGEngine, vs: AsyncPGVectorStore 131 | ) -> None: 132 | texts = ["Taylor's", '"Swift"', "best-friend"] 133 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 134 | await vs.aadd_texts(texts, ids=ids) 135 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 136 | assert len(results) == 3 137 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') 138 | 139 | async def test_aadd_docs(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None: 140 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 141 | await vs.aadd_documents(docs, ids=ids) 142 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 143 | assert len(results) == 3 144 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') 145 | 146 | async def test_aadd_docs_no_ids( 147 | self, engine: PGEngine, vs: AsyncPGVectorStore 148 | ) -> None: 149 | await vs.aadd_documents(docs) 150 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 151 | assert len(results) == 3 152 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') 153 | 154 | async def test_adelete(self, engine: PGEngine, vs: AsyncPGVectorStore) -> None: 155 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 156 | await vs.aadd_texts(texts, ids=ids) 157 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 158 | assert len(results) == 3 159 | # delete an ID 160 | await vs.adelete([ids[0]]) 161 | results = await afetch(engine, f'SELECT * FROM "{DEFAULT_TABLE}"') 162 | assert len(results) == 2 163 | 
# delete with no ids 164 | result = await vs.adelete() 165 | assert not result 166 | await aexecute(engine, f'TRUNCATE TABLE "{DEFAULT_TABLE}"') 167 | 168 | ##### Custom Vector Store ##### 169 | async def test_aadd_embeddings( 170 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore 171 | ) -> None: 172 | await vs_custom.aadd_embeddings( 173 | texts=texts, embeddings=embeddings, metadatas=metadatas 174 | ) 175 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 176 | assert len(results) == 3 177 | assert results[0]["mycontent"] == "foo" 178 | assert results[0]["myembedding"] 179 | assert results[0]["page"] == "0" 180 | assert results[0]["source"] == "postgres" 181 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') 182 | 183 | async def test_aadd_texts_custom( 184 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore 185 | ) -> None: 186 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 187 | await vs_custom.aadd_texts(texts, ids=ids) 188 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 189 | assert len(results) == 3 190 | assert results[0]["mycontent"] == "foo" 191 | assert results[0]["myembedding"] 192 | assert results[0]["page"] is None 193 | assert results[0]["source"] is None 194 | 195 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 196 | await vs_custom.aadd_texts(texts, metadatas, ids) 197 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 198 | assert len(results) == 6 199 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') 200 | 201 | async def test_aadd_docs_custom( 202 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore 203 | ) -> None: 204 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 205 | docs = [ 206 | Document( 207 | page_content=texts[i], 208 | metadata={"page": str(i), "source": "postgres"}, 209 | ) 210 | for i in range(len(texts)) 211 | ] 212 | await vs_custom.aadd_documents(docs, ids=ids) 213 | 214 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 215 | assert len(results) == 3 216 | assert results[0]["mycontent"] == "foo" 217 | assert results[0]["myembedding"] 218 | assert results[0]["page"] == "0" 219 | assert results[0]["source"] == "postgres" 220 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') 221 | 222 | async def test_adelete_custom( 223 | self, engine: PGEngine, vs_custom: AsyncPGVectorStore 224 | ) -> None: 225 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 226 | await vs_custom.aadd_texts(texts, ids=ids) 227 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 228 | content = [result["mycontent"] for result in results] 229 | assert len(results) == 3 230 | assert "foo" in content 231 | # delete an ID 232 | await vs_custom.adelete([ids[0]]) 233 | results = await afetch(engine, f'SELECT * FROM "{CUSTOM_TABLE}"') 234 | content = [result["mycontent"] for result in results] 235 | assert len(results) == 2 236 | assert "foo" not in content 237 | await aexecute(engine, f'TRUNCATE TABLE "{CUSTOM_TABLE}"') 238 | 239 | async def test_ignore_metadata_columns(self, engine: PGEngine) -> None: 240 | column_to_ignore = "source" 241 | vs = await AsyncPGVectorStore.create( 242 | engine, 243 | embedding_service=embeddings_service, 244 | table_name=CUSTOM_TABLE, 245 | ignore_metadata_columns=[column_to_ignore], 246 | id_column="myid", 247 | content_column="mycontent", 248 | embedding_column="myembedding", 249 | metadata_json_column="mymeta", 250 | ) 251 | assert column_to_ignore not in vs.metadata_columns 252 | 253 | async def 
test_create_vectorstore_with_invalid_parameters_1( 254 | self, engine: PGEngine 255 | ) -> None: 256 | with pytest.raises(ValueError): 257 | await AsyncPGVectorStore.create( 258 | engine, 259 | embedding_service=embeddings_service, 260 | table_name=CUSTOM_TABLE, 261 | id_column="myid", 262 | content_column="mycontent", 263 | embedding_column="myembedding", 264 | metadata_columns=["random_column"], # invalid metadata column 265 | ) 266 | 267 | async def test_create_vectorstore_with_invalid_parameters_2( 268 | self, engine: PGEngine 269 | ) -> None: 270 | with pytest.raises(ValueError): 271 | await AsyncPGVectorStore.create( 272 | engine, 273 | embedding_service=embeddings_service, 274 | table_name=CUSTOM_TABLE, 275 | id_column="myid", 276 | content_column="langchain_id", # invalid content column type 277 | embedding_column="myembedding", 278 | metadata_columns=["random_column"], 279 | ) 280 | 281 | async def test_create_vectorstore_with_invalid_parameters_3( 282 | self, engine: PGEngine 283 | ) -> None: 284 | with pytest.raises(ValueError): 285 | await AsyncPGVectorStore.create( 286 | engine, 287 | embedding_service=embeddings_service, 288 | table_name=CUSTOM_TABLE, 289 | id_column="myid", 290 | content_column="mycontent", 291 | embedding_column="random_column", # invalid embedding column 292 | metadata_columns=["random_column"], 293 | ) 294 | 295 | async def test_create_vectorstore_with_invalid_parameters_4( 296 | self, engine: PGEngine 297 | ) -> None: 298 | with pytest.raises(ValueError): 299 | await AsyncPGVectorStore.create( 300 | engine, 301 | embedding_service=embeddings_service, 302 | table_name=CUSTOM_TABLE, 303 | id_column="myid", 304 | content_column="mycontent", 305 | embedding_column="langchain_id", # invalid embedding column data type 306 | metadata_columns=["random_column"], 307 | ) 308 | 309 | async def test_create_vectorstore_with_invalid_parameters_5( 310 | self, engine: PGEngine 311 | ) -> None: 312 | with pytest.raises(ValueError): 313 | await AsyncPGVectorStore.create( 314 | engine, 315 | embedding_service=embeddings_service, 316 | table_name=CUSTOM_TABLE, 317 | id_column="myid", 318 | content_column="mycontent", 319 | embedding_column="langchain_id", 320 | metadata_columns=["random_column"], 321 | ignore_metadata_columns=[ 322 | "one", 323 | "two", 324 | ], # invalid use of metadata_columns and ignore columns 325 | ) 326 | 327 | async def test_create_vectorstore_with_init(self, engine: PGEngine) -> None: 328 | with pytest.raises(Exception): 329 | AsyncPGVectorStore( 330 | key={}, 331 | engine=engine._pool, 332 | embedding_service=embeddings_service, 333 | table_name=CUSTOM_TABLE, 334 | id_column="myid", 335 | content_column="mycontent", 336 | embedding_column="myembedding", 337 | metadata_columns=["random_column"], # invalid metadata column 338 | ) 339 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_async_pg_vectorstore_from_methods.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncIterator, Sequence 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from langchain_core.documents import Document 8 | from langchain_core.embeddings import DeterministicFakeEmbedding 9 | from sqlalchemy import text 10 | from sqlalchemy.engine.row import RowMapping 11 | 12 | from langchain_postgres import Column, PGEngine 13 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore 14 | from tests.utils import 
VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 15 | 16 | DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") 17 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") 18 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") 19 | CUSTOM_TABLE_WITH_INT_ID = "custom_sync" + str(uuid.uuid4()).replace("-", "_") 20 | VECTOR_SIZE = 768 21 | 22 | 23 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) 24 | 25 | texts = ["foo", "bar", "baz"] 26 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] 27 | docs = [ 28 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) 29 | ] 30 | 31 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))] 32 | 33 | 34 | def get_env_var(key: str, desc: str) -> str: 35 | v = os.environ.get(key) 36 | if v is None: 37 | raise ValueError(f"Must set env var {key} to: {desc}") 38 | return v 39 | 40 | 41 | async def aexecute(engine: PGEngine, query: str) -> None: 42 | async with engine._pool.connect() as conn: 43 | await conn.execute(text(query)) 44 | await conn.commit() 45 | 46 | 47 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]: 48 | async with engine._pool.connect() as conn: 49 | result = await conn.execute(text(query)) 50 | result_map = result.mappings() 51 | result_fetch = result_map.fetchall() 52 | return result_fetch 53 | 54 | 55 | @pytest.mark.enable_socket 56 | @pytest.mark.asyncio 57 | class TestVectorStoreFromMethods: 58 | @pytest_asyncio.fixture 59 | async def engine(self) -> AsyncIterator[PGEngine]: 60 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 61 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) 62 | await engine._ainit_vectorstore_table( 63 | CUSTOM_TABLE, 64 | VECTOR_SIZE, 65 | id_column="myid", 66 | content_column="mycontent", 67 | embedding_column="myembedding", 68 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 69 | store_metadata=False, 70 | ) 71 | await engine._ainit_vectorstore_table( 72 | CUSTOM_TABLE_WITH_INT_ID, 73 | VECTOR_SIZE, 74 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False), 75 | content_column="mycontent", 76 | embedding_column="myembedding", 77 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 78 | store_metadata=False, 79 | ) 80 | yield engine 81 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") 82 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") 83 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}") 84 | await engine.close() 85 | 86 | async def test_afrom_texts(self, engine: PGEngine) -> None: 87 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 88 | await AsyncPGVectorStore.afrom_texts( 89 | texts, 90 | embeddings_service, 91 | engine, 92 | DEFAULT_TABLE, 93 | metadatas=metadatas, 94 | ids=ids, 95 | ) 96 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") 97 | assert len(results) == 3 98 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") 99 | 100 | async def test_afrom_docs(self, engine: PGEngine) -> None: 101 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 102 | await AsyncPGVectorStore.afrom_documents( 103 | docs, 104 | embeddings_service, 105 | engine, 106 | DEFAULT_TABLE, 107 | ids=ids, 108 | ) 109 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") 110 | assert len(results) == 3 111 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") 112 | 
113 | async def test_afrom_texts_custom(self, engine: PGEngine) -> None: 114 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 115 | await AsyncPGVectorStore.afrom_texts( 116 | texts, 117 | embeddings_service, 118 | engine, 119 | CUSTOM_TABLE, 120 | ids=ids, 121 | id_column="myid", 122 | content_column="mycontent", 123 | embedding_column="myembedding", 124 | metadata_columns=["page", "source"], 125 | ) 126 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") 127 | assert len(results) == 3 128 | assert results[0]["mycontent"] == "foo" 129 | assert results[0]["myembedding"] 130 | assert results[0]["page"] is None 131 | assert results[0]["source"] is None 132 | 133 | async def test_afrom_docs_custom(self, engine: PGEngine) -> None: 134 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 135 | docs = [ 136 | Document( 137 | page_content=texts[i], 138 | metadata={"page": str(i), "source": "postgres"}, 139 | ) 140 | for i in range(len(texts)) 141 | ] 142 | await AsyncPGVectorStore.afrom_documents( 143 | docs, 144 | embeddings_service, 145 | engine, 146 | CUSTOM_TABLE, 147 | ids=ids, 148 | id_column="myid", 149 | content_column="mycontent", 150 | embedding_column="myembedding", 151 | metadata_columns=["page", "source"], 152 | ) 153 | 154 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") 155 | assert len(results) == 3 156 | assert results[0]["mycontent"] == "foo" 157 | assert results[0]["myembedding"] 158 | assert results[0]["page"] == "0" 159 | assert results[0]["source"] == "postgres" 160 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}") 161 | 162 | async def test_afrom_docs_custom_with_int_id(self, engine: PGEngine) -> None: 163 | ids = [i for i in range(len(texts))] 164 | docs = [ 165 | Document( 166 | page_content=texts[i], 167 | metadata={"page": str(i), "source": "postgres"}, 168 | ) 169 | for i in range(len(texts)) 170 | ] 171 | await AsyncPGVectorStore.afrom_documents( 172 | docs, 173 | embeddings_service, 174 | engine, 175 | CUSTOM_TABLE_WITH_INT_ID, 176 | ids=ids, 177 | id_column="integer_id", 178 | content_column="mycontent", 179 | embedding_column="myembedding", 180 | metadata_columns=["page", "source"], 181 | ) 182 | 183 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}") 184 | assert len(results) == 3 185 | for row in results: 186 | assert isinstance(row["integer_id"], int) 187 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}") 188 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_async_pg_vectorstore_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncIterator 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from langchain_core.documents import Document 8 | from langchain_core.embeddings import DeterministicFakeEmbedding 9 | from sqlalchemy import text 10 | 11 | from langchain_postgres import PGEngine 12 | from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore 13 | from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig 14 | from langchain_postgres.v2.indexes import DistanceStrategy, HNSWIndex, IVFFlatIndex 15 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 16 | 17 | uuid_str = str(uuid.uuid4()).replace("-", "_") 18 | DEFAULT_TABLE = "default" + uuid_str 19 | DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str 20 | DEFAULT_INDEX_NAME = "index" + uuid_str 21 | VECTOR_SIZE = 768 22 | 
SIMPLE_TABLE = "default_table" 23 | 24 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) 25 | 26 | texts = ["foo", "bar", "baz"] 27 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 28 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] 29 | docs = [ 30 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) 31 | ] 32 | 33 | embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] 34 | 35 | 36 | def get_env_var(key: str, desc: str) -> str: 37 | v = os.environ.get(key) 38 | if v is None: 39 | raise ValueError(f"Must set env var {key} to: {desc}") 40 | return v 41 | 42 | 43 | async def aexecute(engine: PGEngine, query: str) -> None: 44 | async with engine._pool.connect() as conn: 45 | await conn.execute(text(query)) 46 | await conn.commit() 47 | 48 | 49 | @pytest.mark.enable_socket 50 | @pytest.mark.asyncio(scope="class") 51 | class TestIndex: 52 | @pytest_asyncio.fixture(scope="class") 53 | async def engine(self) -> AsyncIterator[PGEngine]: 54 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 55 | yield engine 56 | 57 | await engine._adrop_table(DEFAULT_TABLE) 58 | await engine._adrop_table(DEFAULT_HYBRID_TABLE) 59 | await engine._adrop_table(SIMPLE_TABLE) 60 | await engine.close() 61 | 62 | @pytest_asyncio.fixture(scope="class") 63 | async def vs(self, engine: PGEngine) -> AsyncIterator[AsyncPGVectorStore]: 64 | await engine._ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) 65 | vs = await AsyncPGVectorStore.create( 66 | engine, 67 | embedding_service=embeddings_service, 68 | table_name=DEFAULT_TABLE, 69 | ) 70 | 71 | await vs.aadd_texts(texts, ids=ids) 72 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME) 73 | yield vs 74 | 75 | async def test_apply_default_name_vector_index(self, engine: PGEngine) -> None: 76 | await engine._ainit_vectorstore_table( 77 | SIMPLE_TABLE, VECTOR_SIZE, overwrite_existing=True 78 | ) 79 | vs = await AsyncPGVectorStore.create( 80 | engine, 81 | embedding_service=embeddings_service, 82 | table_name=SIMPLE_TABLE, 83 | ) 84 | await vs.aadd_texts(texts, ids=ids) 85 | await vs.adrop_vector_index() 86 | index = HNSWIndex() 87 | await vs.aapply_vector_index(index) 88 | assert await vs.is_valid_index() 89 | await vs.adrop_vector_index() 90 | 91 | async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None: 92 | index = HNSWIndex(name=DEFAULT_INDEX_NAME) 93 | await vs.aapply_vector_index(index) 94 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME) 95 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME) 96 | 97 | async def test_aapply_vector_index_non_hybrid_search_vs( 98 | self, vs: AsyncPGVectorStore 99 | ) -> None: 100 | with pytest.raises(ValueError): 101 | await vs.aapply_hybrid_search_index() 102 | 103 | async def test_aapply_hybrid_search_index_table_without_tsv_column( 104 | self, engine: PGEngine, vs: AsyncPGVectorStore 105 | ) -> None: 106 | # overwriting vs to get a hybrid vs 107 | tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str 108 | vs = await AsyncPGVectorStore.create( 109 | engine, 110 | embedding_service=embeddings_service, 111 | table_name=DEFAULT_TABLE, 112 | hybrid_search_config=HybridSearchConfig(index_name=tsv_index_name), 113 | ) 114 | is_valid_index = await vs.is_valid_index(tsv_index_name) 115 | assert is_valid_index == False 116 | await vs.aapply_hybrid_search_index() 117 | assert await vs.is_valid_index(tsv_index_name) 118 | await vs.adrop_vector_index(tsv_index_name) 119 | is_valid_index = 
await vs.is_valid_index(tsv_index_name) 120 | assert is_valid_index == False 121 | 122 | async def test_aapply_hybrid_search_index_table_with_tsv_column( 123 | self, engine: PGEngine 124 | ) -> None: 125 | tsv_index_name = "tsv_index_on_table_without_tsv_column_" + uuid_str 126 | config = HybridSearchConfig( 127 | tsv_column="tsv_column", 128 | tsv_lang="pg_catalog.english", 129 | index_name=tsv_index_name, 130 | ) 131 | await engine._ainit_vectorstore_table( 132 | DEFAULT_HYBRID_TABLE, 133 | VECTOR_SIZE, 134 | hybrid_search_config=config, 135 | ) 136 | vs = await AsyncPGVectorStore.create( 137 | engine, 138 | embedding_service=embeddings_service, 139 | table_name=DEFAULT_HYBRID_TABLE, 140 | hybrid_search_config=config, 141 | ) 142 | is_valid_index = await vs.is_valid_index(tsv_index_name) 143 | assert is_valid_index == False 144 | await vs.aapply_hybrid_search_index() 145 | assert await vs.is_valid_index(tsv_index_name) 146 | await vs.areindex(tsv_index_name) 147 | assert await vs.is_valid_index(tsv_index_name) 148 | await vs.adrop_vector_index(tsv_index_name) 149 | is_valid_index = await vs.is_valid_index(tsv_index_name) 150 | assert is_valid_index == False 151 | 152 | async def test_areindex(self, vs: AsyncPGVectorStore) -> None: 153 | if not await vs.is_valid_index(DEFAULT_INDEX_NAME): 154 | index = HNSWIndex(name=DEFAULT_INDEX_NAME) 155 | await vs.aapply_vector_index(index) 156 | await vs.areindex(DEFAULT_INDEX_NAME) 157 | await vs.areindex(DEFAULT_INDEX_NAME) 158 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME) 159 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME) 160 | 161 | async def test_dropindex(self, vs: AsyncPGVectorStore) -> None: 162 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME) 163 | result = await vs.is_valid_index(DEFAULT_INDEX_NAME) 164 | assert not result 165 | 166 | async def test_aapply_vector_index_ivfflat(self, vs: AsyncPGVectorStore) -> None: 167 | index = IVFFlatIndex( 168 | name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN 169 | ) 170 | await vs.aapply_vector_index(index, concurrently=True) 171 | assert await vs.is_valid_index(DEFAULT_INDEX_NAME) 172 | index = IVFFlatIndex( 173 | name="secondindex", 174 | distance_strategy=DistanceStrategy.INNER_PRODUCT, 175 | ) 176 | await vs.aapply_vector_index(index) 177 | assert await vs.is_valid_index("secondindex") 178 | await vs.adrop_vector_index("secondindex") 179 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME) 180 | 181 | async def test_is_valid_index(self, vs: AsyncPGVectorStore) -> None: 182 | is_valid = await vs.is_valid_index("invalid_index") 183 | assert not is_valid 184 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_hybrid_search_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from langchain_postgres.v2.hybrid_search_config import ( 4 | reciprocal_rank_fusion, 5 | weighted_sum_ranking, 6 | ) 7 | 8 | 9 | # Helper to create mock input items that mimic RowMapping for the fusion functions 10 | def get_row(doc_id: str, score: float, content: str = "content") -> dict: 11 | """ 12 | Simulates a RowMapping-like dictionary. 13 | The fusion functions expect to extract doc_id as the first value and 14 | the initial score/distance as the last value when casting values from RowMapping. 15 | They then operate on dictionaries, using the 'distance' key for the fused score. 16 | """ 17 | # Python dicts maintain insertion order (Python 3.7+). 
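# For example, get_row("p1", 0.8) returns
# {"id_val": "p1", "content_field": "content", "distance": 0.8}.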
18 | # This structure ensures list(row.values())[0] is doc_id and 19 | # list(row.values())[-1] is score. 20 | return {"id_val": doc_id, "content_field": content, "distance": score} 21 | 22 | 23 | class TestWeightedSumRanking: 24 | def test_empty_inputs(self) -> None: 25 | results = weighted_sum_ranking([], []) 26 | assert results == [] 27 | 28 | def test_primary_only(self) -> None: 29 | primary = [get_row("p1", 0.8), get_row("p2", 0.6)] 30 | # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3 31 | results = weighted_sum_ranking( # type: ignore 32 | primary, # type: ignore 33 | [], 34 | primary_results_weight=0.5, 35 | secondary_results_weight=0.5, 36 | ) 37 | assert len(results) == 2 38 | assert results[0]["id_val"] == "p1" 39 | assert results[0]["distance"] == pytest.approx(0.4) 40 | assert results[1]["id_val"] == "p2" 41 | assert results[1]["distance"] == pytest.approx(0.3) 42 | 43 | def test_secondary_only(self) -> None: 44 | secondary = [get_row("s1", 0.9), get_row("s2", 0.7)] 45 | # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35 46 | results = weighted_sum_ranking( 47 | [], 48 | secondary, # type: ignore 49 | primary_results_weight=0.5, 50 | secondary_results_weight=0.5, 51 | ) 52 | assert len(results) == 2 53 | assert results[0]["id_val"] == "s1" 54 | assert results[0]["distance"] == pytest.approx(0.45) 55 | assert results[1]["id_val"] == "s2" 56 | assert results[1]["distance"] == pytest.approx(0.35) 57 | 58 | def test_mixed_results_default_weights(self) -> None: 59 | primary = [get_row("common", 0.8), get_row("p_only", 0.7)] 60 | secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] 61 | # Weights are 0.5, 0.5 62 | # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85 63 | # p_only_score = (0.7 * 0.5) = 0.35 64 | # s_only_score = (0.6 * 0.5) = 0.30 65 | # Order: common (0.85), p_only (0.35), s_only (0.30) 66 | 67 | results = weighted_sum_ranking(primary, secondary) # type: ignore 68 | assert len(results) == 3 69 | assert results[0]["id_val"] == "common" 70 | assert results[0]["distance"] == pytest.approx(0.85) 71 | assert results[1]["id_val"] == "p_only" 72 | assert results[1]["distance"] == pytest.approx(0.35) 73 | assert results[2]["id_val"] == "s_only" 74 | assert results[2]["distance"] == pytest.approx(0.30) 75 | 76 | def test_mixed_results_custom_weights(self) -> None: 77 | primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2 78 | secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4 79 | # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6 80 | 81 | results = weighted_sum_ranking( 82 | primary, # type: ignore 83 | secondary, # type: ignore 84 | primary_results_weight=0.2, 85 | secondary_results_weight=0.8, 86 | ) 87 | assert len(results) == 1 88 | assert results[0]["id_val"] == "d1" 89 | assert results[0]["distance"] == pytest.approx(0.6) 90 | 91 | def test_fetch_top_k(self) -> None: 92 | primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] 93 | # Scores: 1.0, 0.9, 0.8, 0.7, 0.6 94 | # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3 95 | results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore 96 | assert len(results) == 2 97 | assert results[0]["id_val"] == "p0" 98 | assert results[0]["distance"] == pytest.approx(0.5) 99 | assert results[1]["id_val"] == "p1" 100 | assert results[1]["distance"] == pytest.approx(0.45) 101 | 102 | 103 | class TestReciprocalRankFusion: 104 | def test_empty_inputs(self) -> None: 105 | results = reciprocal_rank_fusion([], []) 106 | assert results == [] 107 | 108 | def 
test_primary_only(self) -> None: 109 | primary = [ 110 | get_row("p1", 0.8), 111 | get_row("p2", 0.6), 112 | ] # p1 rank 0, p2 rank 1 113 | rrf_k = 60 114 | # p1_score = 1 / (0 + 60) 115 | # p2_score = 1 / (1 + 60) 116 | results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore 117 | assert len(results) == 2 118 | assert results[0]["id_val"] == "p1" 119 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) 120 | assert results[1]["id_val"] == "p2" 121 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) 122 | 123 | def test_secondary_only(self) -> None: 124 | secondary = [ 125 | get_row("s1", 0.9), 126 | get_row("s2", 0.7), 127 | ] # s1 rank 0, s2 rank 1 128 | rrf_k = 60 129 | results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore 130 | assert len(results) == 2 131 | assert results[0]["id_val"] == "s1" 132 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) 133 | assert results[1]["id_val"] == "s2" 134 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) 135 | 136 | def test_mixed_results_default_k(self) -> None: 137 | primary = [get_row("common", 0.8), get_row("p_only", 0.7)] 138 | secondary = [get_row("common", 0.9), get_row("s_only", 0.6)] 139 | rrf_k = 60 140 | # common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k 141 | # p_only_score = (1/(1+k))_prim = 1/(k+1) 142 | # s_only_score = (1/(1+k))_sec = 1/(k+1) 143 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore 144 | assert len(results) == 3 145 | assert results[0]["id_val"] == "common" 146 | assert results[0]["distance"] == pytest.approx(2.0 / rrf_k) 147 | # Check the next two elements, their order might vary due to tie in score 148 | next_ids = {results[1]["id_val"], results[2]["id_val"]} 149 | next_scores = {results[1]["distance"], results[2]["distance"]} 150 | assert next_ids == {"p_only", "s_only"} 151 | for score in next_scores: 152 | assert score == pytest.approx(1.0 / (1 + rrf_k)) 153 | 154 | def test_fetch_top_k_rrf(self) -> None: 155 | primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)] 156 | rrf_k = 1 157 | results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore 158 | assert len(results) == 2 159 | assert results[0]["id_val"] == "p0" 160 | assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k)) 161 | assert results[1]["id_val"] == "p1" 162 | assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k)) 163 | 164 | def test_rrf_content_preservation(self) -> None: 165 | primary = [get_row("doc1", 0.9, content="Primary Content")] 166 | secondary = [get_row("doc1", 0.8, content="Secondary Content")] 167 | # RRF processes primary then secondary. If a doc is in both, 168 | # the content from the secondary list will overwrite primary's. 169 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore 170 | assert len(results) == 1 171 | assert results[0]["id_val"] == "doc1" 172 | assert results[0]["content_field"] == "Secondary Content" 173 | 174 | # If only in primary 175 | results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore 176 | assert results_prim_only[0]["content_field"] == "Primary Content" 177 | 178 | def test_reordering_from_inputs_rrf(self) -> None: 179 | """ 180 | Tests that RRF fused ranking can be different from both primary and secondary 181 | input rankings. 
182 | Primary Order: A, B, C 183 | Secondary Order: C, B, A 184 | Fused Order: (A, C) tied, then B 185 | """ 186 | primary = [ 187 | get_row("docA", 0.9), 188 | get_row("docB", 0.8), 189 | get_row("docC", 0.1), 190 | ] 191 | secondary = [ 192 | get_row("docC", 0.9), 193 | get_row("docB", 0.5), 194 | get_row("docA", 0.2), 195 | ] 196 | rrf_k = 1.0 # Using 1.0 for k to simplify rank score calculation 197 | # docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3 198 | # docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1 199 | # docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3 200 | results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore 201 | assert len(results) == 3 202 | assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"} 203 | assert results[0]["distance"] == pytest.approx(4.0 / 3.0) 204 | assert results[1]["distance"] == pytest.approx(4.0 / 3.0) 205 | assert results[2]["id_val"] == "docB" 206 | assert results[2]["distance"] == pytest.approx(1.0) 207 | 208 | def test_reordering_from_inputs_weighted_sum(self) -> None: 209 | """ 210 | Tests that the fused ranking can be different from both primary and secondary 211 | input rankings. 212 | Primary Order: A (0.9), B (0.7) 213 | Secondary Order: B (0.8), A (0.2) 214 | Fusion (0.5/0.5 weights): 215 | docA_score = (0.9 * 0.5) + (0.2 * 0.5) = 0.45 + 0.10 = 0.55 216 | docB_score = (0.7 * 0.5) + (0.8 * 0.5) = 0.35 + 0.40 = 0.75 217 | Expected Fused Order: docB (0.75), docA (0.55) 218 | This differs from both the Primary (A, B) and Secondary (B, A) orderings, 219 | because the fusion re-weights each document's contribution to the final score. 220 | """ 221 | primary = [get_row("docA", 0.9), get_row("docB", 0.7)] 222 | secondary = [get_row("docB", 0.8), get_row("docA", 0.2)] 223 | 224 | results = weighted_sum_ranking(primary, secondary) # type: ignore 225 | assert len(results) == 2 226 | assert results[0]["id_val"] == "docB" 227 | assert results[0]["distance"] == pytest.approx(0.75) 228 | assert results[1]["id_val"] == "docA" 229 | assert results[1]["distance"] == pytest.approx(0.55) 230 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_indexes.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | 5 | from langchain_postgres.v2.indexes import ( 6 | DistanceStrategy, 7 | HNSWIndex, 8 | HNSWQueryOptions, 9 | IVFFlatIndex, 10 | IVFFlatQueryOptions, 11 | ) 12 | 13 | 14 | @pytest.mark.enable_socket 15 | class TestPGIndex: 16 | def test_distance_strategy(self) -> None: 17 | assert DistanceStrategy.EUCLIDEAN.operator == "<->" 18 | assert DistanceStrategy.EUCLIDEAN.search_function == "l2_distance" 19 | assert DistanceStrategy.EUCLIDEAN.index_function == "vector_l2_ops" 20 | 21 | assert DistanceStrategy.COSINE_DISTANCE.operator == "<=>" 22 | assert DistanceStrategy.COSINE_DISTANCE.search_function == "cosine_distance" 23 | assert DistanceStrategy.COSINE_DISTANCE.index_function == "vector_cosine_ops" 24 | 25 | assert DistanceStrategy.INNER_PRODUCT.operator == "<#>" 26 | assert DistanceStrategy.INNER_PRODUCT.search_function == "inner_product" 27 | assert DistanceStrategy.INNER_PRODUCT.index_function == "vector_ip_ops" 28 | 29 | def test_hnsw_index(self) -> None: 30 | index = HNSWIndex(name="test_index", m=32, ef_construction=128) 31 | assert index.index_type == "hnsw" 32 | assert index.m == 32 33 | assert index.ef_construction == 128 34 | assert index.index_options() == "(m = 32, 
ef_construction = 128)" 35 | 36 | def test_hnsw_query_options(self) -> None: 37 | options = HNSWQueryOptions(ef_search=80) 38 | assert options.to_parameter() == ["hnsw.ef_search = 80"] 39 | 40 | with warnings.catch_warnings(record=True) as w: 41 | options.to_string() 42 | 43 | assert len(w) == 1 44 | assert "to_string is deprecated, use to_parameter instead." in str( 45 | w[-1].message 46 | ) 47 | 48 | def test_ivfflat_index(self) -> None: 49 | index = IVFFlatIndex(name="test_index", lists=200) 50 | assert index.index_type == "ivfflat" 51 | assert index.lists == 200 52 | assert index.index_options() == "(lists = 200)" 53 | 54 | def test_ivfflat_query_options(self) -> None: 55 | options = IVFFlatQueryOptions(probes=2) 56 | assert options.to_parameter() == ["ivfflat.probes = 2"] 57 | 58 | with warnings.catch_warnings(record=True) as w: 59 | options.to_string() 60 | assert len(w) == 1 61 | assert "to_string is deprecated, use to_parameter instead." in str( 62 | w[-1].message 63 | ) 64 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_pg_vectorstore_from_methods.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncIterator, Sequence 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from langchain_core.documents import Document 8 | from langchain_core.embeddings import DeterministicFakeEmbedding 9 | from sqlalchemy import text 10 | from sqlalchemy.engine.row import RowMapping 11 | 12 | from langchain_postgres import Column, PGEngine, PGVectorStore 13 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 14 | 15 | DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_") 16 | DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_") 17 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") 18 | CUSTOM_TABLE_WITH_INT_ID = "custom_int_id" + str(uuid.uuid4()).replace("-", "_") 19 | CUSTOM_TABLE_WITH_INT_ID_SYNC = "custom_int_id_sync" + str(uuid.uuid4()).replace( 20 | "-", "_" 21 | ) 22 | VECTOR_SIZE = 768 23 | 24 | 25 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) 26 | 27 | texts = ["foo", "bar", "baz"] 28 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] 29 | docs = [ 30 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) 31 | ] 32 | 33 | embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))] 34 | 35 | 36 | def get_env_var(key: str, desc: str) -> str: 37 | v = os.environ.get(key) 38 | if v is None: 39 | raise ValueError(f"Must set env var {key} to: {desc}") 40 | return v 41 | 42 | 43 | async def aexecute( 44 | engine: PGEngine, 45 | query: str, 46 | ) -> None: 47 | async def run(engine: PGEngine, query: str) -> None: 48 | async with engine._pool.connect() as conn: 49 | await conn.execute(text(query)) 50 | await conn.commit() 51 | 52 | await engine._run_as_async(run(engine, query)) 53 | 54 | 55 | async def afetch(engine: PGEngine, query: str) -> Sequence[RowMapping]: 56 | async def run(engine: PGEngine, query: str) -> Sequence[RowMapping]: 57 | async with engine._pool.connect() as conn: 58 | result = await conn.execute(text(query)) 59 | result_map = result.mappings() 60 | result_fetch = result_map.fetchall() 61 | return result_fetch 62 | 63 | return await engine._run_as_async(run(engine, query)) 64 | 65 | 66 | @pytest.mark.enable_socket 67 | @pytest.mark.asyncio 68 | class 
TestVectorStoreFromMethods: 69 | @pytest_asyncio.fixture 70 | async def engine(self) -> AsyncIterator[PGEngine]: 71 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 72 | await engine.ainit_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) 73 | await engine.ainit_vectorstore_table( 74 | CUSTOM_TABLE, 75 | VECTOR_SIZE, 76 | id_column="myid", 77 | content_column="mycontent", 78 | embedding_column="myembedding", 79 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 80 | store_metadata=False, 81 | ) 82 | await engine.ainit_vectorstore_table( 83 | CUSTOM_TABLE_WITH_INT_ID, 84 | VECTOR_SIZE, 85 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False), 86 | content_column="mycontent", 87 | embedding_column="myembedding", 88 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 89 | store_metadata=False, 90 | ) 91 | yield engine 92 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") 93 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE}") 94 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID}") 95 | await engine.close() 96 | 97 | @pytest_asyncio.fixture 98 | async def engine_sync(self) -> AsyncIterator[PGEngine]: 99 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 100 | engine.init_vectorstore_table(DEFAULT_TABLE_SYNC, VECTOR_SIZE) 101 | engine.init_vectorstore_table( 102 | CUSTOM_TABLE_WITH_INT_ID_SYNC, 103 | VECTOR_SIZE, 104 | id_column=Column(name="integer_id", data_type="INTEGER", nullable=False), 105 | content_column="mycontent", 106 | embedding_column="myembedding", 107 | metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], 108 | store_metadata=False, 109 | ) 110 | 111 | yield engine 112 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_SYNC}") 113 | await aexecute(engine, f"DROP TABLE IF EXISTS {CUSTOM_TABLE_WITH_INT_ID_SYNC}") 114 | await engine.close() 115 | 116 | async def test_afrom_texts(self, engine: PGEngine) -> None: 117 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 118 | await PGVectorStore.afrom_texts( 119 | texts, 120 | embeddings_service, 121 | engine, 122 | DEFAULT_TABLE, 123 | metadatas=metadatas, 124 | ids=ids, 125 | ) 126 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") 127 | assert len(results) == 3 128 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") 129 | 130 | async def test_from_texts(self, engine_sync: PGEngine) -> None: 131 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 132 | PGVectorStore.from_texts( 133 | texts, 134 | embeddings_service, 135 | engine_sync, 136 | DEFAULT_TABLE_SYNC, 137 | metadatas=metadatas, 138 | ids=ids, 139 | ) 140 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") 141 | assert len(results) == 3 142 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") 143 | 144 | async def test_afrom_docs(self, engine: PGEngine) -> None: 145 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 146 | await PGVectorStore.afrom_documents( 147 | docs, 148 | embeddings_service, 149 | engine, 150 | DEFAULT_TABLE, 151 | ids=ids, 152 | ) 153 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE}") 154 | assert len(results) == 3 155 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE}") 156 | 157 | async def test_from_docs(self, engine_sync: PGEngine) -> None: 158 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 159 | PGVectorStore.from_documents( 160 | docs, 161 | embeddings_service, 162 | engine_sync, 
163 | DEFAULT_TABLE_SYNC, 164 | ids=ids, 165 | ) 166 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") 167 | assert len(results) == 3 168 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") 169 | 170 | async def test_afrom_docs_cross_env(self, engine_sync: PGEngine) -> None: 171 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 172 | await PGVectorStore.afrom_documents( 173 | docs, 174 | embeddings_service, 175 | engine_sync, 176 | DEFAULT_TABLE_SYNC, 177 | ids=ids, 178 | ) 179 | results = await afetch(engine_sync, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") 180 | assert len(results) == 3 181 | await aexecute(engine_sync, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") 182 | 183 | async def test_from_docs_cross_env( 184 | self, engine: PGEngine, engine_sync: PGEngine 185 | ) -> None: 186 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 187 | PGVectorStore.from_documents( 188 | docs, 189 | embeddings_service, 190 | engine, 191 | DEFAULT_TABLE_SYNC, 192 | ids=ids, 193 | ) 194 | results = await afetch(engine, f"SELECT * FROM {DEFAULT_TABLE_SYNC}") 195 | assert len(results) == 3 196 | await aexecute(engine, f"TRUNCATE TABLE {DEFAULT_TABLE_SYNC}") 197 | 198 | async def test_afrom_texts_custom(self, engine: PGEngine) -> None: 199 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 200 | await PGVectorStore.afrom_texts( 201 | texts, 202 | embeddings_service, 203 | engine, 204 | CUSTOM_TABLE, 205 | ids=ids, 206 | id_column="myid", 207 | content_column="mycontent", 208 | embedding_column="myembedding", 209 | metadata_columns=["page", "source"], 210 | ) 211 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") 212 | assert len(results) == 3 213 | assert results[0]["mycontent"] == "foo" 214 | assert results[0]["myembedding"] 215 | assert results[0]["page"] is None 216 | assert results[0]["source"] is None 217 | 218 | async def test_afrom_docs_custom(self, engine: PGEngine) -> None: 219 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 220 | docs = [ 221 | Document( 222 | page_content=texts[i], 223 | metadata={"page": str(i), "source": "postgres"}, 224 | ) 225 | for i in range(len(texts)) 226 | ] 227 | await PGVectorStore.afrom_documents( 228 | docs, 229 | embeddings_service, 230 | engine, 231 | CUSTOM_TABLE, 232 | ids=ids, 233 | id_column="myid", 234 | content_column="mycontent", 235 | embedding_column="myembedding", 236 | metadata_columns=["page", "source"], 237 | ) 238 | 239 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE}") 240 | assert len(results) == 3 241 | assert results[0]["mycontent"] == "foo" 242 | assert results[0]["myembedding"] 243 | assert results[0]["page"] == "0" 244 | assert results[0]["source"] == "postgres" 245 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE}") 246 | 247 | async def test_afrom_texts_custom_with_int_id(self, engine: PGEngine) -> None: 248 | ids = [i for i in range(len(texts))] 249 | await PGVectorStore.afrom_texts( 250 | texts, 251 | embeddings_service, 252 | engine, 253 | CUSTOM_TABLE_WITH_INT_ID, 254 | metadatas=metadatas, 255 | ids=ids, 256 | id_column="integer_id", 257 | content_column="mycontent", 258 | embedding_column="myembedding", 259 | metadata_columns=["page", "source"], 260 | ) 261 | results = await afetch(engine, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID}") 262 | assert len(results) == 3 263 | for row in results: 264 | assert isinstance(row["integer_id"], int) 265 | await aexecute(engine, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID}") 266 | 267 | async def 
test_from_texts_custom_with_int_id(self, engine_sync: PGEngine) -> None: 268 | ids = [i for i in range(len(texts))] 269 | PGVectorStore.from_texts( 270 | texts, 271 | embeddings_service, 272 | engine_sync, 273 | CUSTOM_TABLE_WITH_INT_ID_SYNC, 274 | ids=ids, 275 | id_column="integer_id", 276 | content_column="mycontent", 277 | embedding_column="myembedding", 278 | metadata_columns=["page", "source"], 279 | ) 280 | results = await afetch( 281 | engine_sync, f"SELECT * FROM {CUSTOM_TABLE_WITH_INT_ID_SYNC}" 282 | ) 283 | assert len(results) == 3 284 | for row in results: 285 | assert isinstance(row["integer_id"], int) 286 | await aexecute(engine_sync, f"TRUNCATE TABLE {CUSTOM_TABLE_WITH_INT_ID_SYNC}") 287 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_pg_vectorstore_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncIterator 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from langchain_core.documents import Document 8 | from langchain_core.embeddings import DeterministicFakeEmbedding 9 | from sqlalchemy import text 10 | 11 | from langchain_postgres import PGEngine, PGVectorStore 12 | from langchain_postgres.v2.indexes import ( 13 | DistanceStrategy, 14 | HNSWIndex, 15 | IVFFlatIndex, 16 | ) 17 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 18 | 19 | uuid_str = str(uuid.uuid4()).replace("-", "_") 20 | uuid_str_sync = str(uuid.uuid4()).replace("-", "_") 21 | DEFAULT_TABLE = "default" + uuid_str 22 | DEFAULT_TABLE_ASYNC = "default_sync" + uuid_str_sync 23 | CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_") 24 | DEFAULT_INDEX_NAME = "index" + uuid_str 25 | DEFAULT_INDEX_NAME_ASYNC = "index" + uuid_str_sync 26 | VECTOR_SIZE = 768 27 | 28 | embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) 29 | 30 | texts = ["foo", "bar", "baz"] 31 | ids = [str(uuid.uuid4()) for i in range(len(texts))] 32 | metadatas = [{"page": str(i), "source": "postgres"} for i in range(len(texts))] 33 | docs = [ 34 | Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) 35 | ] 36 | 37 | embeddings = [embeddings_service.embed_query("foo") for i in range(len(texts))] 38 | 39 | 40 | def get_env_var(key: str, desc: str) -> str: 41 | v = os.environ.get(key) 42 | if v is None: 43 | raise ValueError(f"Must set env var {key} to: {desc}") 44 | return v 45 | 46 | 47 | async def aexecute( 48 | engine: PGEngine, 49 | query: str, 50 | ) -> None: 51 | async def run(engine: PGEngine, query: str) -> None: 52 | async with engine._pool.connect() as conn: 53 | await conn.execute(text(query)) 54 | await conn.commit() 55 | 56 | await engine._run_as_async(run(engine, query)) 57 | 58 | 59 | @pytest.mark.enable_socket 60 | @pytest.mark.asyncio(scope="class") 61 | class TestIndex: 62 | @pytest_asyncio.fixture(scope="class") 63 | async def engine(self) -> AsyncIterator[PGEngine]: 64 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 65 | yield engine 66 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE}") 67 | await engine.close() 68 | 69 | @pytest_asyncio.fixture(scope="class") 70 | async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]: 71 | engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) 72 | vs = PGVectorStore.create_sync( 73 | engine, 74 | embedding_service=embeddings_service, 75 | table_name=DEFAULT_TABLE, 76 | ) 77 | 78 | vs.add_texts(texts, ids=ids) 
79 | vs.drop_vector_index(DEFAULT_INDEX_NAME) 80 | yield vs 81 | 82 | async def test_aapply_vector_index(self, vs: PGVectorStore) -> None: 83 | index = HNSWIndex(name=DEFAULT_INDEX_NAME) 84 | vs.apply_vector_index(index) 85 | assert vs.is_valid_index(DEFAULT_INDEX_NAME) 86 | vs.drop_vector_index(DEFAULT_INDEX_NAME) 87 | 88 | async def test_areindex(self, vs: PGVectorStore) -> None: 89 | if not vs.is_valid_index(DEFAULT_INDEX_NAME): 90 | index = HNSWIndex(name=DEFAULT_INDEX_NAME) 91 | vs.apply_vector_index(index) 92 | vs.reindex(DEFAULT_INDEX_NAME) 93 | vs.reindex(DEFAULT_INDEX_NAME) 94 | assert vs.is_valid_index(DEFAULT_INDEX_NAME) 95 | vs.drop_vector_index(DEFAULT_INDEX_NAME) 96 | 97 | async def test_dropindex(self, vs: PGVectorStore) -> None: 98 | vs.drop_vector_index(DEFAULT_INDEX_NAME) 99 | result = vs.is_valid_index(DEFAULT_INDEX_NAME) 100 | assert not result 101 | 102 | async def test_aapply_vector_index_ivfflat(self, vs: PGVectorStore) -> None: 103 | index = IVFFlatIndex( 104 | name=DEFAULT_INDEX_NAME, distance_strategy=DistanceStrategy.EUCLIDEAN 105 | ) 106 | vs.apply_vector_index(index, concurrently=True) 107 | assert vs.is_valid_index(DEFAULT_INDEX_NAME) 108 | index = IVFFlatIndex( 109 | name="secondindex", 110 | distance_strategy=DistanceStrategy.INNER_PRODUCT, 111 | ) 112 | vs.apply_vector_index(index) 113 | assert vs.is_valid_index("secondindex") 114 | vs.drop_vector_index("secondindex") 115 | vs.drop_vector_index(DEFAULT_INDEX_NAME) 116 | 117 | async def test_is_valid_index(self, vs: PGVectorStore) -> None: 118 | is_valid = vs.is_valid_index("invalid_index") 119 | assert not is_valid 120 | 121 | 122 | @pytest.mark.enable_socket 123 | @pytest.mark.asyncio(scope="class") 124 | class TestAsyncIndex: 125 | @pytest_asyncio.fixture(scope="class") 126 | async def engine(self) -> AsyncIterator[PGEngine]: 127 | engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 128 | yield engine 129 | await aexecute(engine, f"DROP TABLE IF EXISTS {DEFAULT_TABLE_ASYNC}") 130 | await engine.close() 131 | 132 | @pytest_asyncio.fixture(scope="class") 133 | async def vs(self, engine: PGEngine) -> AsyncIterator[PGVectorStore]: 134 | await engine.ainit_vectorstore_table(DEFAULT_TABLE_ASYNC, VECTOR_SIZE) 135 | vs = await PGVectorStore.create( 136 | engine, 137 | embedding_service=embeddings_service, 138 | table_name=DEFAULT_TABLE_ASYNC, 139 | ) 140 | 141 | await vs.aadd_texts(texts, ids=ids) 142 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) 143 | yield vs 144 | 145 | async def test_aapply_vector_index(self, vs: PGVectorStore) -> None: 146 | index = HNSWIndex(name=DEFAULT_INDEX_NAME_ASYNC) 147 | await vs.aapply_vector_index(index) 148 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) 149 | 150 | async def test_areindex(self, vs: PGVectorStore) -> None: 151 | if not await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC): 152 | index = HNSWIndex(name=DEFAULT_INDEX_NAME_ASYNC) 153 | await vs.aapply_vector_index(index) 154 | await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) 155 | await vs.areindex(DEFAULT_INDEX_NAME_ASYNC) 156 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) 157 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) 158 | 159 | async def test_dropindex(self, vs: PGVectorStore) -> None: 160 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) 161 | result = await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) 162 | assert not result 163 | 164 | async def test_aapply_vector_index_ivfflat(self, vs: PGVectorStore) -> None: 165 | index = IVFFlatIndex( 166 | 
name=DEFAULT_INDEX_NAME_ASYNC, distance_strategy=DistanceStrategy.EUCLIDEAN 167 | ) 168 | await vs.aapply_vector_index(index, concurrently=True) 169 | assert await vs.ais_valid_index(DEFAULT_INDEX_NAME_ASYNC) 170 | index = IVFFlatIndex( 171 | name="secondindex", 172 | distance_strategy=DistanceStrategy.INNER_PRODUCT, 173 | ) 174 | await vs.aapply_vector_index(index) 175 | assert await vs.ais_valid_index("secondindex") 176 | await vs.adrop_vector_index("secondindex") 177 | await vs.adrop_vector_index(DEFAULT_INDEX_NAME_ASYNC) 178 | 179 | async def test_is_valid_index(self, vs: PGVectorStore) -> None: 180 | is_valid = await vs.ais_valid_index("invalid_index") 181 | assert not is_valid 182 | -------------------------------------------------------------------------------- /tests/unit_tests/v2/test_pg_vectorstore_standard_suite.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import AsyncGenerator, AsyncIterator 4 | 5 | import pytest 6 | import pytest_asyncio 7 | from langchain_tests.integration_tests import VectorStoreIntegrationTests 8 | from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE 9 | from sqlalchemy import text 10 | 11 | from langchain_postgres import Column, PGEngine, PGVectorStore 12 | from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING 13 | 14 | DEFAULT_TABLE = "standard" + str(uuid.uuid4()) 15 | DEFAULT_TABLE_SYNC = "sync_standard" + str(uuid.uuid4()) 16 | 17 | 18 | def get_env_var(key: str, desc: str) -> str: 19 | v = os.environ.get(key) 20 | if v is None: 21 | raise ValueError(f"Must set env var {key} to: {desc}") 22 | return v 23 | 24 | 25 | async def aexecute( 26 | engine: PGEngine, 27 | query: str, 28 | ) -> None: 29 | async def run(engine: PGEngine, query: str) -> None: 30 | async with engine._pool.connect() as conn: 31 | await conn.execute(text(query)) 32 | await conn.commit() 33 | 34 | await engine._run_as_async(run(engine, query)) 35 | 36 | 37 | @pytest.mark.enable_socket 38 | # @pytest.mark.filterwarnings("ignore") 39 | @pytest.mark.asyncio 40 | class TestStandardSuiteSync(VectorStoreIntegrationTests): 41 | @pytest_asyncio.fixture(scope="function") 42 | async def sync_engine(self) -> AsyncGenerator[PGEngine, None]: 43 | sync_engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 44 | yield sync_engine 45 | await aexecute(sync_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"') 46 | await sync_engine.close() 47 | 48 | @pytest.fixture(scope="function") 49 | def vectorstore(self, sync_engine: PGEngine) -> PGVectorStore: # type: ignore 50 | """Get an empty vectorstore for unit tests.""" 51 | sync_engine.init_vectorstore_table( 52 | DEFAULT_TABLE_SYNC, 53 | EMBEDDING_SIZE, 54 | id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False), 55 | ) 56 | 57 | vs = PGVectorStore.create_sync( 58 | sync_engine, 59 | embedding_service=self.get_embeddings(), 60 | table_name=DEFAULT_TABLE_SYNC, 61 | ) 62 | yield vs 63 | 64 | 65 | @pytest.mark.enable_socket 66 | @pytest.mark.asyncio 67 | class TestStandardSuiteAsync(VectorStoreIntegrationTests): 68 | @pytest_asyncio.fixture(scope="function") 69 | async def async_engine(self) -> AsyncIterator[PGEngine]: 70 | async_engine = PGEngine.from_connection_string(url=CONNECTION_STRING) 71 | yield async_engine 72 | await aexecute(async_engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"') 73 | await async_engine.close() 74 | 75 | @pytest_asyncio.fixture(scope="function") 76 | async def vectorstore( 
# type: ignore[override] 77 | self, async_engine: PGEngine 78 | ) -> AsyncGenerator[PGVectorStore, None]: # type: ignore 79 | """Get an empty vectorstore for unit tests.""" 80 | await async_engine.ainit_vectorstore_table( 81 | DEFAULT_TABLE, 82 | EMBEDDING_SIZE, 83 | id_column=Column(name="langchain_id", data_type="VARCHAR", nullable=False), 84 | ) 85 | 86 | vs = await PGVectorStore.create( 87 | async_engine, 88 | embedding_service=self.get_embeddings(), 89 | table_name=DEFAULT_TABLE, 90 | ) 91 | 92 | yield vs 93 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | """Get fixtures for the database connection.""" 2 | 3 | import os 4 | from contextlib import asynccontextmanager, contextmanager 5 | 6 | import psycopg 7 | from typing_extensions import AsyncGenerator, Generator 8 | 9 | # Env variables match the default settings in the docker-compose file 10 | # located in the root of the repository: [root]/docker-compose.yml 11 | # Non-standard ports are used to avoid conflicts with other local postgres 12 | # instances. 13 | # To spin up the postgres service for testing, run: 14 | # cd [root] 15 | # docker-compose up pgvector 16 | POSTGRES_USER = os.environ.get("POSTGRES_USER", "langchain") 17 | POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost") 18 | POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "langchain") 19 | POSTGRES_DB = os.environ.get("POSTGRES_DB", "langchain_test") 20 | 21 | POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432") 22 | 23 | DSN = ( 24 | f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}" 25 | f":{POSTGRES_PORT}/{POSTGRES_DB}" 26 | ) 27 | 28 | # Connection string used primarily by the vectorstores tests. 29 | # It's written to work with SQLAlchemy (takes a driver name). 30 | # It points to a postgres instance that has the pgvector extension. 31 | VECTORSTORE_CONNECTION_STRING = ( 32 | f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_HOST}" 33 | f":{POSTGRES_PORT}/{POSTGRES_DB}" 34 | ) 35 | 36 | 37 | @asynccontextmanager 38 | async def asyncpg_client() -> AsyncGenerator[psycopg.AsyncConnection, None]: 39 | # Establish a connection to your test database 40 | conn = await psycopg.AsyncConnection.connect(conninfo=DSN) 41 | try: 42 | yield conn 43 | finally: 44 | # Cleanup: close the connection after the test is done 45 | await conn.close() 46 | 47 | 48 | @contextmanager 49 | def syncpg_client() -> Generator[psycopg.Connection, None, None]: 50 | # Establish a connection to your test database 51 | conn = psycopg.connect(conninfo=DSN) 52 | try: 53 | yield conn 54 | finally: 55 | # Cleanup: close the connection after the test is done 56 | conn.close() 57 | --------------------------------------------------------------------------------
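For orientation, the following is a minimal standalone sketch (not part of the test suite) of the async workflow the fixtures and index tests above exercise: create a table with `PGEngine`, bind a `PGVectorStore` to it, add texts, apply and verify an HNSW index, and clean up. It assumes the docker-compose `pgvector` service is running with the defaults from `tests/utils.py` (user/password `langchain`, database `langchain_test`, port `5432`); the `demo_*` table name and `hnsw_*` index name are illustrative only.

```python
# Hedged sketch: mirrors the calls used by the v2 tests above, not a canonical example.
import asyncio
import uuid

from langchain_core.embeddings import DeterministicFakeEmbedding

from langchain_postgres import PGEngine, PGVectorStore
from langchain_postgres.v2.indexes import HNSWIndex

# Same shape as VECTORSTORE_CONNECTION_STRING in tests/utils.py, with its defaults inlined.
CONNECTION_STRING = "postgresql+psycopg://langchain:langchain@localhost:5432/langchain_test"
VECTOR_SIZE = 768


async def main() -> None:
    engine = PGEngine.from_connection_string(url=CONNECTION_STRING)
    table = "demo_" + str(uuid.uuid4()).replace("-", "_")  # illustrative table name
    index_name = "hnsw_" + table  # illustrative index name

    # Create the backing table, then a store bound to it (as the fixtures do).
    await engine.ainit_vectorstore_table(table, VECTOR_SIZE)
    store = await PGVectorStore.create(
        engine,
        embedding_service=DeterministicFakeEmbedding(size=VECTOR_SIZE),
        table_name=table,
    )

    # Add a few rows, then build and verify an HNSW index, as the index tests do.
    await store.aadd_texts(["foo", "bar", "baz"])
    await store.aapply_vector_index(HNSWIndex(name=index_name))
    assert await store.ais_valid_index(index_name)

    # Clean up the index and close the engine.
    await store.adrop_vector_index(index_name)
    await engine.close()


if __name__ == "__main__":
    asyncio.run(main())
```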