├── docs ├── index.md ├── changelog.md ├── reference │ ├── errors.md │ ├── neo4j_client.md │ ├── neo4j_retriever.md │ ├── neo4j_store.md │ ├── serialization │ │ ├── types.md │ │ └── query_parameters_marshaller.md │ ├── metadata_filter │ │ ├── parser.md │ │ └── neo4j_query_converter.md │ ├── neo4j_query_reader.md │ └── neo4j_query_writer.md ├── license.md ├── assets │ └── panda-logo.png └── stylesheets │ └── mkdocstrings.css ├── src └── neo4j_haystack │ ├── __about__.py │ ├── document_stores │ ├── __init__.py │ └── utils.py │ ├── client │ └── __init__.py │ ├── serialization │ ├── types.py │ ├── __init__.py │ └── query_parameters_marshaller.py │ ├── components │ ├── __init__.py │ ├── neo4j_query_writer.py │ ├── neo4j_query_reader.py │ └── neo4j_retriever.py │ ├── metadata_filter │ ├── __init__.py │ └── parser.py │ ├── __init__.py │ └── errors.py ├── tests ├── __init__.py └── neo4j_haystack │ ├── client │ └── test_client_config.py │ ├── components │ ├── test_neo4j_retriever.py │ ├── test_neo4j_query_reader.py │ ├── test_neo4j_dynamic_retriever.py │ └── test_neo4j_query_writer.py │ ├── document_stores │ ├── test_ds_filters.py │ ├── test_ds_extended.py │ └── test_ds_protocol.py │ ├── conftest.py │ ├── metadata_filter │ └── test_filters.py │ └── samples │ └── movies.json ├── .gitignore ├── scripts ├── setup.sh └── load_movies.py ├── .github └── workflows │ ├── release.yml │ ├── api-docs.yml │ ├── ci.yml │ └── test.yml ├── LICENSE.txt ├── examples ├── indexing_pipeline.py ├── rag_pipeline.py ├── rag-parent-child │ ├── rag-pipeline.py │ └── indexing-pipeline.py └── rag_pipeline_cypher.py ├── mkdocs.yml ├── CONTRIBUTING.md ├── pyproject.toml └── CHANGELOG.md /docs/index.md: -------------------------------------------------------------------------------- 1 | --8<-- "README.md" -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | --8<-- 
"CHANGELOG.md" -------------------------------------------------------------------------------- /docs/reference/errors.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.errors 2 | -------------------------------------------------------------------------------- /src/neo4j_haystack/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.1.0" 2 | -------------------------------------------------------------------------------- /docs/reference/neo4j_client.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.client.neo4j_client 2 | -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | # License 2 | 3 | ``` 4 | --8<-- "LICENSE.txt" 5 | ``` 6 | -------------------------------------------------------------------------------- /docs/reference/neo4j_retriever.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.components.neo4j_retriever 2 | -------------------------------------------------------------------------------- /docs/reference/neo4j_store.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.document_stores.neo4j_store 2 | -------------------------------------------------------------------------------- /docs/reference/serialization/types.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.serialization.types 2 | -------------------------------------------------------------------------------- /docs/reference/metadata_filter/parser.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.metadata_filter.parser 2 | 
-------------------------------------------------------------------------------- /docs/reference/neo4j_query_reader.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.components.neo4j_query_reader 2 | -------------------------------------------------------------------------------- /docs/reference/neo4j_query_writer.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.components.neo4j_query_writer 2 | -------------------------------------------------------------------------------- /docs/assets/panda-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sidagarwal04/neo4j-haystack/main/docs/assets/panda-logo.png -------------------------------------------------------------------------------- /docs/reference/metadata_filter/neo4j_query_converter.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.metadata_filter.neo4j_query_converter 2 | -------------------------------------------------------------------------------- /docs/reference/serialization/query_parameters_marshaller.md: -------------------------------------------------------------------------------- 1 | ::: neo4j_haystack.serialization.query_parameters_marshaller 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2023-present Sergey Bondarenco 2 | # 3 | # SPDX-License-Identifier: MIT 4 | -------------------------------------------------------------------------------- /src/neo4j_haystack/document_stores/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore 2 | 3 | __all__ = ("Neo4jDocumentStore",) 4 
| -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | __pycache__/ 4 | *.py[cod] 5 | dist/ 6 | *.egg-info/ 7 | build/ 8 | htmlcov/ 9 | site/ 10 | .coverage* 11 | __pypackages__/ 12 | pip-wheel-metadata/ 13 | .pytest_cache/ 14 | .mypy_cache/ 15 | .ruff_cache/ 16 | .python-version 17 | .venv/ 18 | .cache/ 19 | .env 20 | -------------------------------------------------------------------------------- /src/neo4j_haystack/client/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.client.neo4j_client import ( 2 | Neo4jClient, 3 | Neo4jClientConfig, 4 | Neo4jRecord, 5 | VectorStoreIndexInfo, 6 | ) 7 | 8 | __all__ = ( 9 | "Neo4jClient", 10 | "Neo4jClientConfig", 11 | "Neo4jRecord", 12 | "VectorStoreIndexInfo", 13 | ) 14 | -------------------------------------------------------------------------------- /src/neo4j_haystack/serialization/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol 2 | 3 | 4 | class QueryParametersMarshaller(Protocol): 5 | """ 6 | A Protocol to be used by marshaller implementations which convert Neo4j query parameters to appropriate types. 7 | """ 8 | 9 | def supports(self, obj: Any) -> bool: 10 | pass 11 | 12 | def marshal(self, obj: Any) -> Any: 13 | pass 14 | -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Installs `hatch` globally using `pipx` 5 | 6 | if ! command -v hatch &>/dev/null; then 7 | if ! command -v pipx &>/dev/null; then 8 | echo "Installing 'pipx'..." 9 | python3 -m pip install --user pipx 10 | fi 11 | echo "Installing 'hatch'..." 
12 | pipx install hatch 13 | else 14 | echo "'hatch' is already installed: $(which hatch)" 15 | fi 16 | 17 | -------------------------------------------------------------------------------- /src/neo4j_haystack/components/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.components.neo4j_query_reader import Neo4jQueryReader 2 | from neo4j_haystack.components.neo4j_query_writer import Neo4jQueryWriter 3 | from neo4j_haystack.components.neo4j_retriever import ( 4 | Neo4jDynamicDocumentRetriever, 5 | Neo4jEmbeddingRetriever, 6 | ) 7 | 8 | __all__ = ( 9 | "Neo4jEmbeddingRetriever", 10 | "Neo4jDynamicDocumentRetriever", 11 | "Neo4jQueryWriter", 12 | "Neo4jQueryReader", 13 | ) 14 | -------------------------------------------------------------------------------- /src/neo4j_haystack/metadata_filter/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.metadata_filter.neo4j_query_converter import Neo4jQueryConverter 2 | from neo4j_haystack.metadata_filter.parser import ( 3 | AST, 4 | COMPARISON_OPS, 5 | LOGICAL_OPS, 6 | FilterParser, 7 | FilterType, 8 | OperatorAST, 9 | ) 10 | 11 | __all__ = ( 12 | "Neo4jQueryConverter", 13 | "FilterParser", 14 | "AST", 15 | "FilterType", 16 | "LOGICAL_OPS", 17 | "COMPARISON_OPS", 18 | "OperatorAST", 19 | ) 20 | -------------------------------------------------------------------------------- /src/neo4j_haystack/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.serialization.query_parameters_marshaller import ( 2 | DataclassQueryParametersMarshaller, 3 | DocumentQueryParametersMarshaller, 4 | Neo4jQueryParametersMarshaller, 5 | ) 6 | from neo4j_haystack.serialization.types import QueryParametersMarshaller 7 | 8 | __all__ = ( 9 | "QueryParametersMarshaller", 10 | "Neo4jQueryParametersMarshaller", 11 | 
"DataclassQueryParametersMarshaller", 12 | "DocumentQueryParametersMarshaller", 13 | ) 14 | -------------------------------------------------------------------------------- /src/neo4j_haystack/__init__.py: -------------------------------------------------------------------------------- 1 | from neo4j_haystack.client import Neo4jClient, Neo4jClientConfig, VectorStoreIndexInfo 2 | from neo4j_haystack.components import ( 3 | Neo4jDynamicDocumentRetriever, 4 | Neo4jEmbeddingRetriever, 5 | Neo4jQueryWriter, 6 | ) 7 | from neo4j_haystack.document_stores import Neo4jDocumentStore 8 | 9 | __all__ = ( 10 | "Neo4jDocumentStore", 11 | "Neo4jClient", 12 | "Neo4jClientConfig", 13 | "VectorStoreIndexInfo", 14 | "Neo4jEmbeddingRetriever", 15 | "Neo4jDynamicDocumentRetriever", 16 | "Neo4jQueryWriter", 17 | ) 18 | -------------------------------------------------------------------------------- /src/neo4j_haystack/errors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from haystack.document_stores.errors import DocumentStoreError 4 | 5 | 6 | class Neo4jDocumentStoreError(DocumentStoreError): 7 | """Error for issues that occur in a Neo4j Document Store""" 8 | 9 | def __init__(self, message: Optional[str] = None): 10 | super().__init__(message=message) 11 | 12 | 13 | class Neo4jClientError(DocumentStoreError): 14 | """Error for issues that occur in a Neo4j client""" 15 | 16 | def __init__(self, message: Optional[str] = None): 17 | super().__init__(message=message) 18 | 19 | 20 | class Neo4jFilterParserError(Exception): 21 | """Error is raised when metadata filters are failing to parse""" 22 | 23 | pass 24 | -------------------------------------------------------------------------------- /src/neo4j_haystack/document_stores/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import MutableMapping 2 | from itertools import islice 3 | 4 | 5 | def 
flatten_dict(dictionary: MutableMapping, parent_key="", separator="."): 6 | items = [] 7 | for key, value in dictionary.items(): 8 | new_key = parent_key + separator + key if parent_key else key 9 | if isinstance(value, MutableMapping): 10 | items.extend(flatten_dict(value, new_key, separator=separator).items()) 11 | else: 12 | items.append((new_key, value)) 13 | return dict(items) 14 | 15 | 16 | def get_batches_from_generator(iterable, n): 17 | """ 18 | Batch elements of an iterable into fixed-length chunks or blocks. 19 | """ 20 | it = iter(iterable) 21 | x = tuple(islice(it, n)) 22 | while x: 23 | yield x 24 | x = tuple(islice(it, n)) 25 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: pypi-release 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | tags: 7 | - "v[0-9].[0-9]+.[0-9]+*" 8 | 9 | concurrency: 10 | group: pypi-release-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | deploy: 18 | runs-on: ubuntu-latest 19 | 20 | permissions: 21 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 22 | 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v3 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: 3.x 31 | 32 | - name: Install dependencies 33 | run: | 34 | pip install --upgrade pip 35 | pip install hatch 36 | 37 | - name: Build package 38 | run: hatch build 39 | 40 | - name: Publish package distributions to PyPI 41 | uses: pypa/gh-action-pypi-publish@release/v1 42 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present Sergey Bondarenco 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /.github/workflows/api-docs.yml: -------------------------------------------------------------------------------- 1 | name: api-docs 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | tags: 7 | - "v[0-9].[0-9]+.[0-9]+*" 8 | 9 | permissions: 10 | contents: write 11 | 12 | concurrency: 13 | group: api-docs-${{ github.head_ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | publish-api-docs: 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout 22 | uses: actions/checkout@v3 23 | 24 | - name: Setup Python 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: 3.x 28 | 29 | - name: Install Hatch 30 | run: pip install --upgrade hatch 31 | 32 | - name: Configure Git Credentials 33 | run: | 34 | git config user.name github-actions[bot] 35 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 36 | 37 | - run: echo "cache_id=$(date --utc '+%V')" >> 
$GITHUB_ENV 38 | - uses: actions/cache@v3 39 | with: 40 | key: mkdocs-material-${{ env.cache_id }} 41 | path: .cache 42 | restore-keys: | 43 | mkdocs-material- 44 | 45 | - name: Publish Docs 46 | run: hatch run docs:publish 47 | -------------------------------------------------------------------------------- /docs/stylesheets/mkdocstrings.css: -------------------------------------------------------------------------------- 1 | /* Indentation. */ 2 | div.doc-contents:not(.first) { 3 | padding-left: 25px; 4 | border-left: .05rem solid var(--md-typeset-table-color); 5 | } 6 | 7 | /* Mark external links as such. */ 8 | a.external::after, 9 | a.autorefs-external::after { 10 | /* https://primer.style/octicons/arrow-up-right-24 */ 11 | mask-image: url('data:image/svg+xml,'); 12 | -webkit-mask-image: url('data:image/svg+xml,'); 13 | content: ' '; 14 | 15 | display: inline-block; 16 | vertical-align: middle; 17 | position: relative; 18 | 19 | height: 1em; 20 | width: 1em; 21 | background-color: var(--md-typeset-a-color); 22 | } 23 | 24 | a.external:hover::after, 25 | a.autorefs-external:hover::after { 26 | background-color: var(--md-accent-fg-color); 27 | } -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | concurrency: 10 | group: ci-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | quality: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v3 20 | 21 | - name: Setup Python 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.8" 25 | 26 | - name: Install Hatch 27 | run: | 28 | pip install --upgrade pip 29 | pip install --upgrade hatch 30 | 31 | - name: Code quality checks 32 | run: hatch run lint:all 33 | 34 | test: 35 | strategy: 36 | matrix: 37 | python-version: 38 | 
- "3.8.x" 39 | - "3.9.x" 40 | - "3.10.x" 41 | os: 42 | - ubuntu-20.04 43 | # - windows-latest 44 | # - macos-latest 45 | 46 | runs-on: ${{ matrix.os }} 47 | 48 | name: Python ${{ matrix.python-version }} test on ${{ matrix.os }} 49 | 50 | steps: 51 | - name: Check out repository 52 | uses: actions/checkout@v3 53 | 54 | - name: Set up python 55 | id: setup-python 56 | uses: actions/setup-python@v4 57 | with: 58 | python-version: ${{ matrix.python-version }} 59 | 60 | - name: Install Hatch 61 | run: pip install --upgrade hatch 62 | 63 | - name: Run tests 64 | run: | 65 | hatch run test 66 | -------------------------------------------------------------------------------- /scripts/load_movies.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, Set 3 | 4 | from datasets import load_dataset 5 | 6 | NUM_ROWS_TO_SAVE = 1000 7 | DATA_FILE = "data/movies.json" 8 | 9 | 10 | def valid_examples(example: Dict[str, Any]) -> bool: 11 | return ( 12 | example["original_language"] == "en" 13 | and example["genres"] 14 | and example["overview"] 15 | and example["title"] 16 | and example["release_date"] 17 | ) 18 | 19 | 20 | unique_ids: Set[Any] = set() 21 | 22 | 23 | def is_unique(example: Dict[str, Any]) -> bool: 24 | if example["id"] in unique_ids: 25 | return False 26 | else: 27 | unique_ids.add(example["id"]) 28 | return True 29 | 30 | 31 | def convert_to_document(example: Dict[str, Any]) -> Dict[str, Any]: 32 | return { 33 | "id": str(example["id"]), 34 | "content": example["overview"], 35 | "meta": { 36 | "title": example["title"], 37 | "runtime": example["runtime"], 38 | "vote_average": example["vote_average"], 39 | "release_date": example["release_date"], 40 | "genres": example["genres"].split("-"), 41 | }, 42 | } 43 | 44 | 45 | movies_dataset = ( 46 | load_dataset("wykonos/movies", split="train") 47 | .filter(valid_examples) 48 | .filter(is_unique) 49 | .shuffle(seed=42) 50 | 
.select(range(NUM_ROWS_TO_SAVE)) 51 | ) 52 | 53 | movie_documents = [convert_to_document(movie) for movie in movies_dataset] 54 | 55 | with open(DATA_FILE, "w") as outfile: 56 | json.dump(movie_documents, outfile) 57 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | PYTHON_VERSION: 7 | description: "List of python versions" 8 | required: true 9 | default: '["3.8.x", "3.9.x", "3.10.x"]' 10 | OS_NAME: 11 | description: "List of OS for test execution" 12 | required: true 13 | default: '["ubuntu-20.04"]' 14 | NEO4J_VERSION: 15 | description: "List of Neo4j docker images" 16 | required: true 17 | default: '["neo4j:5.13.0", "neo4j:5.14.0", "neo4j:5.15.0", "neo4j:5.16.0", "neo4j:5.17.0"]' 18 | 19 | concurrency: 20 | group: test-${{ github.head_ref }} 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | test: 25 | strategy: 26 | matrix: 27 | python-version: ${{ fromJSON(github.event.inputs.PYTHON_VERSION) }} 28 | os: ${{ fromJSON(github.event.inputs.OS_NAME) }} 29 | neo4j-version: ${{ fromJSON(github.event.inputs.NEO4J_VERSION) }} 30 | 31 | runs-on: ${{ matrix.os }} 32 | 33 | name: Python ${{ matrix.python-version }} test on Neo4j:${{ matrix.neo4j-version }} OS:${{ matrix.os }} 34 | 35 | steps: 36 | - name: Check out repository 37 | uses: actions/checkout@v3 38 | 39 | - name: Set up python 40 | id: setup-python 41 | uses: actions/setup-python@v4 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | 45 | - name: Install Hatch 46 | run: pip install --upgrade hatch 47 | 48 | - name: Run tests 49 | env: 50 | NEO4J_VERSION: ${{ matrix.neo4j-version }} 51 | run: | 52 | hatch run test 53 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/client/test_client_config.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from neo4j import basic_auth 3 | 4 | from neo4j_haystack.client import Neo4jClientConfig 5 | 6 | 7 | @pytest.mark.unit 8 | def test_url_is_mandatory(): 9 | with pytest.raises(ValueError): 10 | Neo4jClientConfig(url=None) 11 | 12 | 13 | @pytest.mark.unit 14 | def test_auth_is_mandatory(): 15 | with pytest.raises(ValueError): 16 | Neo4jClientConfig(username=None) 17 | 18 | with pytest.raises(ValueError): 19 | Neo4jClientConfig(password=None) 20 | 21 | # should not raise as `auth` credentials provided 22 | Neo4jClientConfig(username=None, password=None, auth=basic_auth("username", "password")) 23 | 24 | 25 | @pytest.mark.unit 26 | def test_client_config_to_dict(): 27 | config = Neo4jClientConfig( 28 | url="bolt://localhost:7687", 29 | database="neo4j", 30 | username="username", 31 | password="password", 32 | driver_config={"connection_timeout": 60.0, "keep_alive": True}, 33 | session_config={"fetch_size": 1000}, 34 | transaction_config={"timeout": 10}, 35 | use_env=False, 36 | ) 37 | data = config.to_dict() 38 | 39 | assert data == { 40 | "type": "neo4j_haystack.client.neo4j_client.Neo4jClientConfig", 41 | "init_parameters": { 42 | "url": "bolt://localhost:7687", 43 | "database": "neo4j", 44 | "username": "username", 45 | "password": "password", 46 | "driver_config": {"connection_timeout": 60.0, "keep_alive": True}, 47 | "session_config": {"fetch_size": 1000}, 48 | "transaction_config": {"timeout": 10}, 49 | "use_env": False, 50 | }, 51 | } 52 | 53 | 54 | @pytest.mark.unit 55 | def test_client_config_from_dict(): 56 | data = { 57 | "type": "neo4j_haystack.client.neo4j_client.Neo4jClientConfig", 58 | "init_parameters": { 59 | "url": "bolt://localhost:7687", 60 | "database": "neo4j", 61 | "username": "username", 62 | "password": "password", 63 | "driver_config": {"connection_timeout": 60.0, "keep_alive": True}, 64 | "session_config": {"fetch_size": 1000}, 65 | 
"transaction_config": {"timeout": 10}, 66 | }, 67 | } 68 | 69 | config = Neo4jClientConfig.from_dict(data) 70 | 71 | assert config.url == "bolt://localhost:7687" 72 | assert config.database == "neo4j" 73 | assert config.username == "username" 74 | assert config.password == "password" 75 | assert config.driver_config == {"connection_timeout": 60.0, "keep_alive": True} 76 | assert config.session_config == {"fetch_size": 1000} 77 | assert config.transaction_config == {"timeout": 10} 78 | -------------------------------------------------------------------------------- /examples/indexing_pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import zipfile 4 | from io import BytesIO 5 | from pathlib import Path 6 | 7 | import requests 8 | from haystack import Pipeline 9 | from haystack.components.converters import TextFileToDocument 10 | from haystack.components.embedders import SentenceTransformersDocumentEmbedder 11 | from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter 12 | from haystack.components.writers import DocumentWriter 13 | 14 | from neo4j_haystack import Neo4jDocumentStore 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def fetch_archive_from_http(url: str, output_dir: str): 20 | if Path(output_dir).is_dir(): 21 | logger.warning(f"'{output_dir}' directory already exists. Skipping data download") 22 | return 23 | 24 | with requests.get(url, timeout=10, stream=True) as response: 25 | with zipfile.ZipFile(BytesIO(response.content)) as zip_ref: 26 | zip_ref.extractall(output_dir) 27 | 28 | 29 | # Let's first get some files that we want to use 30 | docs_dir = "data/docs" 31 | fetch_archive_from_http( 32 | url="https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/wiki_gameofthrones_txt6.zip", 33 | output_dir=docs_dir, 34 | ) 35 | 36 | # Make sure you have a running Neo4j database, e.g. 
with Docker: 37 | # docker run \ 38 | # --restart always \ 39 | # --publish=7474:7474 --publish=7687:7687 \ 40 | # --env NEO4J_AUTH=neo4j/passw0rd \ 41 | # neo4j:5.15.0 42 | 43 | document_store = Neo4jDocumentStore( 44 | url="bolt://localhost:7687", 45 | username="neo4j", 46 | password="passw0rd", 47 | database="neo4j", 48 | embedding_dim=384, 49 | similarity="cosine", 50 | recreate_index=True, 51 | ) 52 | 53 | # Create components and an indexing pipeline that converts txt to documents, cleans and splits them, and 54 | # indexes them for dense retrieval. 55 | p = Pipeline() 56 | p.add_component("text_file_converter", TextFileToDocument()) 57 | p.add_component("cleaner", DocumentCleaner()) 58 | p.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30)) 59 | p.add_component("embedder", SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")) 60 | p.add_component("writer", DocumentWriter(document_store=document_store)) 61 | 62 | p.connect("text_file_converter.documents", "cleaner.documents") 63 | p.connect("cleaner.documents", "splitter.documents") 64 | p.connect("splitter.documents", "embedder.documents") 65 | p.connect("embedder.documents", "writer.documents") 66 | 67 | # Take the docs data directory as input and run the pipeline 68 | file_paths = [docs_dir / Path(name) for name in os.listdir(docs_dir)] 69 | result = p.run({"text_file_converter": {"sources": file_paths}}) 70 | 71 | # Assuming you have a docker container running navigate to http://localhost:7474 to open Neo4j Browser 72 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "neo4j-haystack" 2 | site_description: "Haystack Document Store for Neo4j." 
3 | site_url: "https://prosto.github.io/neo4j-haystack/" 4 | repo_url: "https://github.com/prosto/neo4j-haystack" 5 | repo_name: "prosto/neo4j-haystack" 6 | site_dir: "site" 7 | watch: 8 | - README.md 9 | - README.md 10 | - src 11 | 12 | nav: 13 | - Home: 14 | - Overview: index.md 15 | - Changelog: changelog.md 16 | - License: license.md 17 | - Code Reference: 18 | - Neo4jDocumentStore: reference/neo4j_store.md 19 | - Neo4jRetriever: reference/neo4j_retriever.md 20 | - Neo4jQueryWriter: reference/neo4j_query_writer.md 21 | - Neo4jQueryReader: reference/neo4j_query_reader.md 22 | - Neo4jClient: reference/neo4j_client.md 23 | - MetadataFilter: 24 | - FilterParser: reference/metadata_filter/parser.md 25 | - Neo4jQueryConverter: reference/metadata_filter/neo4j_query_converter.md 26 | - Serialization: 27 | - Neo4jQueryParametersMarshaller: reference/serialization/query_parameters_marshaller.md 28 | - types: reference/serialization/types.md 29 | - errors: reference/errors.md 30 | - Haystack Documentation: https://docs.haystack.deepset.ai/v2.0/docs/intro 31 | 32 | theme: 33 | name: material 34 | logo: assets/panda-logo.png 35 | features: 36 | - navigation.tabs 37 | - navigation.tabs.sticky 38 | palette: 39 | - media: "(prefers-color-scheme: dark)" 40 | scheme: slate 41 | primary: black 42 | accent: lime 43 | toggle: 44 | icon: material/weather-night 45 | name: Switch to light mode 46 | - media: "(prefers-color-scheme: light)" 47 | scheme: default 48 | primary: teal 49 | accent: purple 50 | toggle: 51 | icon: material/weather-sunny 52 | name: Switch to dark mode 53 | 54 | extra_css: 55 | - stylesheets/mkdocstrings.css 56 | 57 | markdown_extensions: 58 | - admonition 59 | - pymdownx.details 60 | - pymdownx.superfences 61 | - pymdownx.highlight: 62 | anchor_linenums: true 63 | line_spans: __span 64 | pygments_lang_class: true 65 | - pymdownx.inlinehilite 66 | - pymdownx.snippets 67 | - attr_list 68 | - toc: 69 | permalink: "¤" 70 | 71 | plugins: 72 | - search 73 | - 
section-index 74 | - mkdocstrings: 75 | handlers: 76 | python: 77 | paths: [src] 78 | import: 79 | - https://docs.python.org/3/objects.inv 80 | - https://mkdocstrings.github.io/objects.inv 81 | - https://mkdocstrings.github.io/griffe/objects.inv 82 | - https://neo4j.com/docs/api/python-driver/current/objects.inv 83 | options: 84 | docstring_options: 85 | ignore_init_summary: yes 86 | merge_init_into_class: no 87 | separate_signature: true 88 | show_signature_annotations: true 89 | line_length: 80 90 | show_source: yes 91 | show_root_full_path: no 92 | docstring_section_style: list 93 | annotations_path: brief 94 | members_order: source 95 | filters: [] # show private methods 96 | 97 | extra: 98 | social: 99 | - icon: fontawesome/brands/github 100 | link: https://github.com/prosto 101 | 102 | -------------------------------------------------------------------------------- /examples/rag_pipeline.py: -------------------------------------------------------------------------------- 1 | from haystack import GeneratedAnswer, Pipeline 2 | from haystack.components.builders.answer_builder import AnswerBuilder 3 | from haystack.components.builders.prompt_builder import PromptBuilder 4 | from haystack.components.embedders import SentenceTransformersTextEmbedder 5 | from haystack.components.generators import HuggingFaceAPIGenerator 6 | from haystack.utils import Secret 7 | 8 | from neo4j_haystack import Neo4jDocumentStore, Neo4jEmbeddingRetriever 9 | 10 | # Load HF Token from environment variables. 11 | HF_TOKEN = Secret.from_env_var("HF_API_TOKEN") 12 | 13 | # Make sure you have a running Neo4j database with indexed documents available (see `indexing_pipeline.py`), 14 | # e.g. 
with Docker: 15 | # docker run \ 16 | # --restart always \ 17 | # --publish=7474:7474 --publish=7687:7687 \ 18 | # --env NEO4J_AUTH=neo4j/passw0rd \ 19 | # neo4j:5.15.0 20 | 21 | document_store = Neo4jDocumentStore( 22 | url="bolt://localhost:7687", 23 | username="neo4j", 24 | password="passw0rd", 25 | database="neo4j", 26 | embedding_dim=384, 27 | similarity="cosine", 28 | recreate_index=False, # Do not delete index as it was created by indexer pipeline 29 | create_index_if_missing=False, 30 | ) 31 | 32 | # Build a RAG pipeline with a Retriever to get relevant documents to the query and a HuggingFaceTGIGenerator 33 | # interacting with LLMs using a custom prompt. 34 | prompt_template = """ 35 | Given these documents, answer the question.\nDocuments: 36 | {% for doc in documents %} 37 | {{ doc.content }} 38 | {% endfor %} 39 | 40 | \nQuestion: {{question}} 41 | \nAnswer: 42 | """ 43 | rag_pipeline = Pipeline() 44 | rag_pipeline.add_component( 45 | "query_embedder", 46 | SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2", progress_bar=False), 47 | ) 48 | rag_pipeline.add_component("retriever", Neo4jEmbeddingRetriever(document_store=document_store)) 49 | rag_pipeline.add_component("prompt_builder", PromptBuilder(template=prompt_template)) 50 | rag_pipeline.add_component( 51 | "llm", 52 | HuggingFaceAPIGenerator( 53 | api_type="serverless_inference_api", api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"}, token=HF_TOKEN 54 | ), 55 | ) 56 | rag_pipeline.add_component("answer_builder", AnswerBuilder()) 57 | 58 | rag_pipeline.connect("query_embedder", "retriever.query_embedding") 59 | rag_pipeline.connect("retriever.documents", "prompt_builder.documents") 60 | rag_pipeline.connect("prompt_builder.prompt", "llm.prompt") 61 | rag_pipeline.connect("llm.replies", "answer_builder.replies") 62 | rag_pipeline.connect("llm.meta", "answer_builder.meta") 63 | rag_pipeline.connect("retriever", "answer_builder.documents") 64 | 65 | # Ask a 
question on the data you just added. 66 | question = "Who created the Dothraki vocabulary?" 67 | result = rag_pipeline.run( 68 | { 69 | "query_embedder": {"text": question}, 70 | "retriever": {"top_k": 3}, 71 | "prompt_builder": {"question": question}, 72 | "answer_builder": {"query": question}, 73 | } 74 | ) 75 | 76 | # For details, like which documents were used to generate the answer, look into the GeneratedAnswer object 77 | answer: GeneratedAnswer = result["answer_builder"]["answers"][0] 78 | 79 | # ruff: noqa: T201 80 | print("Query: ", answer.query) 81 | print("Answer: ", answer.data) 82 | print("== Sources:") 83 | for doc in answer.documents: 84 | print("-> ", doc.meta["file_path"]) 85 | -------------------------------------------------------------------------------- /examples/rag-parent-child/rag-pipeline.py: -------------------------------------------------------------------------------- 1 | from haystack import GeneratedAnswer, Pipeline 2 | from haystack.components.builders.answer_builder import AnswerBuilder 3 | from haystack.components.builders.prompt_builder import PromptBuilder 4 | from haystack.components.embedders import SentenceTransformersTextEmbedder 5 | from haystack.components.generators import HuggingFaceAPIGenerator 6 | from haystack.utils.auth import Secret 7 | 8 | from neo4j_haystack import Neo4jClientConfig, Neo4jDynamicDocumentRetriever 9 | 10 | # Load HF Token from environment variables. 
11 | HF_TOKEN = Secret.from_env_var("HF_API_TOKEN") 12 | 13 | # Make sure you have a running Neo4j database with indexed documents available (run `indexing_pipeline.py` first 14 | # and keep docker running) 15 | client_config = Neo4jClientConfig( 16 | url="bolt://localhost:7687", 17 | username="neo4j", 18 | password="passw0rd", 19 | database="neo4j", 20 | ) 21 | 22 | cypher_query = """ 23 | // Query Child documents by $query_embedding 24 | CALL db.index.vector.queryNodes($index, $top_k, $query_embedding) 25 | YIELD node as chunk, score 26 | 27 | // Find Parent document for previously retrieved child (e.g. extend RAG context) 28 | MATCH (chunk)-[:HAS_PARENT]->(parent:Document) 29 | WITH parent, max(score) AS score // deduplicate parents 30 | RETURN parent{.*, score} 31 | """ 32 | 33 | prompt_template = """ 34 | Given these documents, answer the question.\nDocuments: 35 | {% for doc in documents %} 36 | {{ doc.content }} 37 | {% endfor %} 38 | 39 | \nQuestion: {{question}} 40 | \nAnswer: 41 | """ 42 | rag_pipeline = Pipeline() 43 | rag_pipeline.add_component( 44 | "query_embedder", 45 | SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2", progress_bar=False), 46 | ) 47 | rag_pipeline.add_component( 48 | "retriever", 49 | Neo4jDynamicDocumentRetriever( 50 | client_config=client_config, 51 | runtime_parameters=["query_embedding"], 52 | doc_node_name="parent", 53 | verify_connectivity=True, 54 | ), 55 | ) 56 | rag_pipeline.add_component("prompt_builder", PromptBuilder(template=prompt_template)) 57 | rag_pipeline.add_component( 58 | "llm", 59 | HuggingFaceAPIGenerator( 60 | api_type="serverless_inference_api", 61 | api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"}, 62 | generation_kwargs={"max_new_tokens": 120, "stop_sequences": ["."]}, 63 | token=HF_TOKEN, 64 | ), 65 | ) 66 | rag_pipeline.add_component("answer_builder", AnswerBuilder()) 67 | 68 | rag_pipeline.connect("query_embedder", "retriever.query_embedding") 69 | 
rag_pipeline.connect("retriever.documents", "prompt_builder.documents") 70 | rag_pipeline.connect("prompt_builder.prompt", "llm.prompt") 71 | rag_pipeline.connect("llm.replies", "answer_builder.replies") 72 | rag_pipeline.connect("llm.meta", "answer_builder.meta") 73 | rag_pipeline.connect("retriever", "answer_builder.documents") 74 | 75 | question = "Why did author suppressed technology in the Dune universe?" 76 | result = rag_pipeline.run( 77 | { 78 | "query_embedder": {"text": question}, 79 | "retriever": { 80 | "query": cypher_query, 81 | "parameters": {"index": "chunk-embeddings", "top_k": 5}, 82 | }, 83 | "prompt_builder": {"question": question}, 84 | "answer_builder": {"query": question}, 85 | } 86 | ) 87 | 88 | # For details, like which documents were used to generate the answer, look into the GeneratedAnswer object 89 | answer: GeneratedAnswer = result["answer_builder"]["answers"][0] 90 | 91 | # ruff: noqa: T201 92 | print("Query: ", answer.query) 93 | print("Answer: ", answer.data) 94 | print("== Sources:") 95 | for doc in answer.documents: 96 | print("-> Score: ", doc.score, "Content: ", doc.content[:200] + "...") 97 | -------------------------------------------------------------------------------- /examples/rag_pipeline_cypher.py: -------------------------------------------------------------------------------- 1 | from haystack import GeneratedAnswer, Pipeline 2 | from haystack.components.builders.answer_builder import AnswerBuilder 3 | from haystack.components.builders.prompt_builder import PromptBuilder 4 | from haystack.components.embedders import SentenceTransformersTextEmbedder 5 | from haystack.components.generators import HuggingFaceAPIGenerator 6 | from haystack.utils import Secret 7 | 8 | from neo4j_haystack import Neo4jClientConfig, Neo4jDynamicDocumentRetriever 9 | 10 | # Load HF Token from environment variables. 
11 | HF_TOKEN = Secret.from_env_var("HF_API_TOKEN") 12 | 13 | # Make sure you have a running Neo4j database with indexed documents available (see `indexing_pipeline.py`), 14 | # e.g. with Docker: 15 | # docker run \ 16 | # --restart always \ 17 | # --publish=7474:7474 --publish=7687:7687 \ 18 | # --env NEO4J_AUTH=neo4j/passw0rd \ 19 | # neo4j:5.15.0 20 | 21 | client_config = Neo4jClientConfig( 22 | url="bolt://localhost:7687", 23 | username="neo4j", 24 | password="passw0rd", 25 | database="neo4j", 26 | ) 27 | 28 | cypher_query = """ 29 | CALL db.index.vector.queryNodes($index, $top_k, $query_embedding) 30 | YIELD node as doc, score 31 | MATCH (doc) 32 | RETURN doc{.*, score}, score 33 | ORDER BY score DESC LIMIT $top_k 34 | """ 35 | 36 | # Build a RAG pipeline with a Retriever to get relevant documents to the query and a HuggingFaceTGIGenerator 37 | # interacting with LLMs using a custom prompt. 38 | prompt_template = """ 39 | Given these documents, answer the question.\nDocuments: 40 | {% for doc in documents %} 41 | {{ doc.content }} 42 | {% endfor %} 43 | 44 | \nQuestion: {{question}} 45 | \nAnswer: 46 | """ 47 | rag_pipeline = Pipeline() 48 | rag_pipeline.add_component( 49 | "query_embedder", 50 | SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2", progress_bar=False), 51 | ) 52 | rag_pipeline.add_component( 53 | "retriever", 54 | Neo4jDynamicDocumentRetriever( 55 | client_config=client_config, 56 | runtime_parameters=["query_embedding"], 57 | doc_node_name="doc", 58 | verify_connectivity=True, 59 | ), 60 | ) 61 | rag_pipeline.add_component("prompt_builder", PromptBuilder(template=prompt_template)) 62 | rag_pipeline.add_component( 63 | "llm", 64 | HuggingFaceAPIGenerator( 65 | api_type="serverless_inference_api", api_params={"model": "mistralai/Mistral-7B-Instruct-v0.3"}, token=HF_TOKEN 66 | ), 67 | ) 68 | rag_pipeline.add_component("answer_builder", AnswerBuilder()) 69 | 70 | rag_pipeline.connect("query_embedder", 
"retriever.query_embedding") 71 | rag_pipeline.connect("retriever.documents", "prompt_builder.documents") 72 | rag_pipeline.connect("prompt_builder.prompt", "llm.prompt") 73 | rag_pipeline.connect("llm.replies", "answer_builder.replies") 74 | rag_pipeline.connect("llm.meta", "answer_builder.meta") 75 | rag_pipeline.connect("retriever", "answer_builder.documents") 76 | 77 | # Ask a question on the data you just added. 78 | question = "Who created the Dothraki vocabulary?" 79 | result = rag_pipeline.run( 80 | { 81 | "query_embedder": {"text": question}, 82 | "retriever": { 83 | "query": cypher_query, 84 | "parameters": {"index": "document-embeddings", "top_k": 3}, 85 | }, 86 | "prompt_builder": {"question": question}, 87 | "answer_builder": {"query": question}, 88 | } 89 | ) 90 | 91 | # For details, like which documents were used to generate the answer, look into the GeneratedAnswer object 92 | answer: GeneratedAnswer = result["answer_builder"]["answers"][0] 93 | 94 | # ruff: noqa: T201 95 | print("Query: ", answer.query) 96 | print("Answer: ", answer.data) 97 | print("== Sources:") 98 | for doc in answer.documents: 99 | print("-> ", doc.meta["file_path"]) 100 | -------------------------------------------------------------------------------- /src/neo4j_haystack/serialization/query_parameters_marshaller.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, is_dataclass 2 | from typing import Any, Dict 3 | 4 | from haystack import Document 5 | 6 | from neo4j_haystack.serialization.types import QueryParametersMarshaller 7 | 8 | 9 | class DocumentQueryParametersMarshaller(QueryParametersMarshaller): 10 | """ 11 | A marshaller which converts `haystack.Document` to a dictionary when it is used as query parameter in Cypher query 12 | 13 | Haystack's native `to_dict` method is called to produce a flattened dictionary of Document data along with its meta 14 | fields. 
15 | """ 16 | 17 | def supports(self, obj: Any) -> bool: 18 | """ 19 | Checks if given object is `haystack.Document` instance 20 | """ 21 | return isinstance(obj, Document) 22 | 23 | def marshal(self, obj: Any) -> Dict[str, Any]: 24 | """ 25 | Converts `haystack.Document` to dictionary so it could be used as Cypher query parameter 26 | """ 27 | return obj.to_dict(flatten=True) 28 | 29 | 30 | class DataclassQueryParametersMarshaller(QueryParametersMarshaller): 31 | """ 32 | A marshaller which converts a `dataclass` to a dictionary when encountered as Cypher query parameter. 33 | """ 34 | 35 | def supports(self, obj: Any) -> bool: 36 | """ 37 | Checks if given object is a python `dataclass` instance 38 | """ 39 | return is_dataclass(obj) and not isinstance(obj, type) 40 | 41 | def marshal(self, obj: Any) -> Dict[str, Any]: 42 | """ 43 | Converts `dataclass` to dictionary so it could be used as Cypher query parameter. 44 | """ 45 | return asdict(obj) 46 | 47 | 48 | class Neo4jQueryParametersMarshaller(QueryParametersMarshaller): 49 | """ 50 | The marshaller converts Cypher query parameters to types which can be consumed in Neo4j query execution. In some 51 | cases query parameters contain complex data types (e.g. 
dataclasses) which can be converted to `dict` types and 52 | thus become eligible for running a Cypher query: 53 | 54 | ```py title="Example: Running a query with a dataclass parameter" 55 | from dataclasses import dataclass 56 | from neo4j_haystack.client import Neo4jClient, Neo4jClientConfig 57 | from neo4j_haystack.serialization import Neo4jQueryParametersMarshaller 58 | 59 | @dataclass 60 | class YearInfo: 61 | year: int = 2024 62 | 63 | neo4j_config = Neo4jClientConfig("bolt://localhost:7687", database="neo4j", username="neo4j", password="passw0rd") 64 | neo4j_client = Neo4jClient(client_config=neo4j_config) 65 | 66 | marshaller = Neo4jQueryParametersMarshaller() 67 | query_params = marshaller.marshall({"year_info": YearInfo()}) 68 | 69 | neo4j_client.execute_read( 70 | "MATCH (doc:`Document`) WHERE doc.year=$year_info.year RETURN doc", 71 | parameters=query_params 72 | ) 73 | ``` 74 | The above example would fail without marshaller. With marshaller `query_params` will become 75 | `:::py {"year_info": {"year": 2024}}` which is eligible input for Neo4j Driver. 
76 | """ 77 | 78 | def __init__(self): 79 | self._marshallers = [DocumentQueryParametersMarshaller(), DataclassQueryParametersMarshaller()] 80 | 81 | def supports(self, _obj: Any) -> bool: 82 | """ 83 | Supports conversion of any type 84 | """ 85 | return True 86 | 87 | def marshal(self, data) -> Any: 88 | if isinstance(data, dict): 89 | return {key: self.marshal(value) for key, value in data.items()} 90 | elif isinstance(data, (list, tuple, set)): 91 | return [self.marshal(value) for value in data] 92 | 93 | # Find a marshaller which supports given type of parameter otherwise return parameter as is 94 | marshaller = next((m for m in self._marshallers if m.supports(data)), None) 95 | 96 | return self.marshal(marshaller.marshal(data)) if marshaller else data 97 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/components/test_neo4j_retriever.py: -------------------------------------------------------------------------------- 1 | from operator import attrgetter 2 | from typing import Callable, List 3 | from unittest import mock 4 | 5 | import pytest 6 | from haystack import Document 7 | 8 | from neo4j_haystack.components.neo4j_retriever import Neo4jEmbeddingRetriever 9 | from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore 10 | 11 | 12 | @pytest.fixture 13 | def movie_document_store( 14 | doc_store_factory: Callable[..., Neo4jDocumentStore], 15 | movie_documents_with_embeddings: List[Document], 16 | ) -> Neo4jDocumentStore: 17 | doc_store = doc_store_factory(embedding_dim=384, similarity="cosine", recreate_index=True) 18 | doc_store.write_documents(movie_documents_with_embeddings) 19 | 20 | return doc_store 21 | 22 | 23 | def test_retrieve_documents(movie_document_store: Neo4jDocumentStore, text_embedder: Callable[[str], List[float]]): 24 | retriever = Neo4jEmbeddingRetriever(document_store=movie_document_store) 25 | 26 | query_embedding = text_embedder( 27 | "A young fella pretending to be a 
@pytest.mark.unit
def test_retriever_to_dict():
    """Serialization test for `Neo4jEmbeddingRetriever.to_dict`.

    Uses a fully mocked document store (no running Neo4j required), so it is marked
    `unit` — consistent with the analogous mock-based serde tests in
    `test_neo4j_query_reader.py`, which use `@pytest.mark.unit`.
    """
    doc_store = mock.create_autospec(Neo4jDocumentStore)
    doc_store.to_dict.return_value = {"ds": "yes"}

    retriever = Neo4jEmbeddingRetriever(
        document_store=doc_store,
        filters={"field": "num", "operator": ">", "value": 10},
        top_k=11,
        scale_score=False,
        return_embedding=True,
    )
    data = retriever.to_dict()

    # `to_dict` must capture every init parameter, embedding the store's own serialized form.
    assert data == {
        "type": "neo4j_haystack.components.neo4j_retriever.Neo4jEmbeddingRetriever",
        "init_parameters": {
            "document_store": {"ds": "yes"},
            "filters": {"field": "num", "operator": ">", "value": 10},
            "top_k": 11,
            "scale_score": False,
            "return_embedding": True,
        },
    }
"neo4j_haystack.components.neo4j_retriever.Neo4jEmbeddingRetriever", 90 | "init_parameters": { 91 | "document_store": {"ds": "yes"}, 92 | "filters": {"field": "num", "operator": ">", "value": 10}, 93 | "top_k": 11, 94 | "scale_score": False, 95 | "return_embedding": True, 96 | }, 97 | } 98 | doc_store = mock.create_autospec(Neo4jDocumentStore) 99 | from_dict_mock.return_value = doc_store 100 | 101 | retriever = Neo4jEmbeddingRetriever.from_dict(data) 102 | 103 | assert retriever._document_store == doc_store 104 | assert retriever._filters == {"field": "num", "operator": ">", "value": 10} 105 | assert retriever._top_k == 11 106 | assert retriever._scale_score is False 107 | assert retriever._return_embedding is True 108 | -------------------------------------------------------------------------------- /examples/rag-parent-child/indexing-pipeline.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from haystack import Pipeline 4 | from haystack.components.converters import TextFileToDocument 5 | from haystack.components.embedders import SentenceTransformersDocumentEmbedder 6 | from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter 7 | 8 | from neo4j_haystack.client.neo4j_client import Neo4jClientConfig 9 | from neo4j_haystack.components.neo4j_query_writer import Neo4jQueryWriter 10 | 11 | # Make sure you have a running Neo4j database, e.g. 
with Docker: 12 | # docker run \ 13 | # --restart always \ 14 | # --publish=7474:7474 --publish=7687:7687 \ 15 | # --env NEO4J_AUTH=neo4j/passw0rd \ 16 | # neo4j:5.16.0 17 | 18 | client_config = Neo4jClientConfig( 19 | url="bolt://localhost:7687", 20 | username="neo4j", 21 | password="passw0rd", 22 | database="neo4j", 23 | ) 24 | 25 | pipe = Pipeline() 26 | pipe.add_component("text_file_converter", TextFileToDocument()) 27 | pipe.add_component("cleaner", DocumentCleaner()) 28 | pipe.add_component("parent_splitter", DocumentSplitter(split_by="word", split_length=512, split_overlap=30)) 29 | pipe.add_component("child_splitter", DocumentSplitter(split_by="word", split_length=100, split_overlap=24)) 30 | pipe.add_component("embedder", SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")) 31 | pipe.add_component( 32 | "neo4j_writer", 33 | Neo4jQueryWriter(client_config=client_config, runtime_parameters=["child_documents", "parent_documents"]), 34 | ) 35 | pipe.add_component( 36 | "neo4j_vector_index", Neo4jQueryWriter(client_config=client_config, runtime_parameters=["query_status"]) 37 | ) 38 | 39 | pipe.connect("text_file_converter.documents", "cleaner.documents") 40 | pipe.connect("cleaner.documents", "parent_splitter.documents") 41 | pipe.connect("parent_splitter.documents", "child_splitter.documents") 42 | pipe.connect("child_splitter.documents", "embedder.documents") 43 | pipe.connect("embedder.documents", "neo4j_writer.child_documents") 44 | pipe.connect("parent_splitter.documents", "neo4j_writer.parent_documents") 45 | pipe.connect("neo4j_writer.query_status", "neo4j_vector_index.query_status") 46 | 47 | # Take the docs data directory as input and run the pipeline 48 | file_paths = [Path(__file__).resolve().parent / "dune.txt"] 49 | 50 | cypher_query_create_documents = """ 51 | // Creating Parent documents 52 | UNWIND $parent_documents AS parent_doc 53 | MERGE (parent:`Document` {id: parent_doc.id}) 54 | SET parent += 
parent_doc{.*, embedding: null, _split_overlap: null} // _split_overlap is a dictionary and is skipped 55 | 56 | // Creating Child documents for a given 'parent' document 57 | WITH parent 58 | UNWIND $child_documents AS child_doc 59 | WITH parent, child_doc WHERE child_doc.source_id = parent.id 60 | MERGE (child:`Chunk` {id: child_doc.id})-[:HAS_PARENT]->(parent) 61 | SET child += child_doc{.*, embedding: null, _split_overlap: null} // _split_overlap is a dictionary and is skipped 62 | WITH child, child_doc 63 | CALL { WITH child, child_doc 64 | MATCH(child:`Chunk` {id: child_doc.id}) WHERE child_doc.embedding IS NOT NULL 65 | CALL db.create.setNodeVectorProperty(child, 'embedding', child_doc.embedding) 66 | } 67 | """ 68 | 69 | cypher_query_create_index = """ 70 | CREATE VECTOR INDEX `chunk-embeddings` IF NOT EXISTS 71 | FOR (child_doc:Chunk) 72 | ON (child_doc.embedding) 73 | OPTIONS {indexConfig: { 74 | `vector.dimensions`: $vector_dimensions, 75 | `vector.similarity_function`: $vector_similarity_function 76 | }} 77 | """ 78 | 79 | result = pipe.run( 80 | { 81 | "text_file_converter": {"sources": file_paths}, 82 | "neo4j_writer": { 83 | "query": cypher_query_create_documents, 84 | "parameters": {"parent_label": "Parent", "child_label": "Chunk"}, 85 | }, 86 | "neo4j_vector_index": { 87 | "query": cypher_query_create_index, 88 | "parameters": { 89 | "vector_dimensions": 384, 90 | "vector_similarity_function": "cosine", 91 | }, 92 | }, 93 | } 94 | ) 95 | 96 | # Assuming you have a docker container running navigate to http://localhost:7474 to open Neo4j Browser 97 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Environment setup 4 | 5 | Fork and clone the repository, then: 6 | 7 | ```bash 8 | cd scripts 9 | ./setup.sh 10 | ``` 11 | 12 | > NOTE: 13 | > If it fails for some reason, 14 | > you'll need to 
install 15 | > [Hatch](https://hatch.pypa.io/latest/) 16 | > manually. 17 | > 18 | > You can install it with: 19 | > 20 | > ```bash 21 | > python3 -m pip install --user pipx 22 | > pipx install hatch 23 | > ``` 24 | > 25 | > Now you can try running `hatch env show` to see 26 | > all available environments and provided scripts. 27 | 28 | Hatch automatically manages dependencies for given environments. Default environment can be used for running unit/integration tests. 29 | 30 | For local development it is suggested to configure Hatch to create virtual environments inside project's folder. First find the location of the configuration file: 31 | 32 | ```bash 33 | hatch config find # see where configuration file is located 34 | ``` 35 | 36 | Update configuration with the following: 37 | 38 | ```toml 39 | [dirs.env] 40 | virtual = ".venv" 41 | ``` 42 | 43 | It is suggested to develop locally with minimal supported python version (see `requires-python` in `pyproject.toml`). 44 | To make sure hatch is creating environments with python version you need set `HATCH_PYTHON` as follows: 45 | 46 | ```bash 47 | export HATCH_PYTHON=~/.pyenv/versions/3.8.16/bin/python # You might have a different version 48 | hatch run python --version # Checks python version by creating default hatch environment 49 | ``` 50 | 51 | If you run `hatch env create dev` you should see `.venv/dev` directory created with dependencies you need for development. 52 | You could point your IDE to this particular python environment/interpreter. 53 | 54 | ## Development 55 | 56 | 1. create a new branch: `git switch -c feature-or-bugfix-name` 57 | 1. edit the code and/or the documentation 58 | 59 | **Before committing:** 60 | 61 | 1. run `hatch run lint:all` for code quality checks and formatting 62 | 1. run `hatch run test` to run the tests (fix any issue) 63 | 1. if you updated the documentation or the project dependencies: 64 | 1. run `hatch run docs:serve` 65 | 1. 
go to http://localhost:8000/neo4j-haystack/ and check that everything looks good 66 | 1. follow [commit message convention](#commit-message-convention) 67 | 68 | Don't bother updating the changelog, we will take care of this. 69 | 70 | ## Commit message convention 71 | 72 | Commit messages must follow the convention based on the 73 | [Angular style](https://gist.github.com/stephenparish/9941e89d80e2bc58a153#format-of-the-commit-message) 74 | or the [Karma convention](https://karma-runner.github.io/4.0/dev/git-commit-msg.html): 75 | 76 | ```text 77 | [(scope)]: Subject 78 | 79 | [Body] 80 | ``` 81 | 82 | **Subject and body must be valid Markdown.** 83 | Subject must have proper casing (uppercase for first letter 84 | if it makes sense), but no dot at the end, and no punctuation 85 | in general. 86 | 87 | Scope and body are optional. Type can be: 88 | 89 | - `build`: About packaging, building wheels, etc. 90 | - `chore`: About packaging or repo/files management. 91 | - `ci`: About Continuous Integration. 92 | - `deps`: Dependencies update. 93 | - `docs`: About documentation. 94 | - `feat`: New feature. 95 | - `fix`: Bug fix. 96 | - `perf`: About performance. 97 | - `refactor`: Changes that are not features or bug fixes. 98 | - `style`: A change in code style/format. 99 | - `tests`: About tests. 100 | 101 | If you write a body, please add trailers at the end 102 | (for example issues and PR references, or co-authors), 103 | without relying on GitHub's flavored Markdown: 104 | 105 | ```text 106 | Body. 107 | 108 | Issue #10: https://github.com///issues/10 109 | Related to PR namespace/other-project#15: https://github.com///pull/15 110 | ``` 111 | 112 | These "trailers" must appear at the end of the body, 113 | without any blank lines between them. The trailer title 114 | can contain any character except colons `:`. 
115 | We expect a full URI for each trailer, not just GitHub autolinks 116 | (for example, full GitHub URLs for commits and issues, 117 | not the hash or the #issue-number). 118 | 119 | We do not enforce a line length on commit messages summary and body, 120 | but please avoid very long summaries, and very long lines in the body, 121 | unless they are part of code blocks that must not be wrapped. 122 | 123 | ## Pull requests guidelines 124 | 125 | Link to any related issue in the Pull Request message. 126 | 127 | During the review, we recommend using fixups: 128 | 129 | ```bash 130 | # SHA is the SHA of the commit you want to fix 131 | git commit --fixup=SHA 132 | ``` 133 | 134 | Once all the changes are approved, you can squash your commits: 135 | 136 | ```bash 137 | git rebase -i --autosquash main 138 | ``` 139 | 140 | And force-push: 141 | 142 | ```bash 143 | git push -f 144 | ``` 145 | 146 | If this seems all too complicated, you can push or force-push each new commit, 147 | and we will squash them ourselves if needed, before merging. 
148 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/document_stores/test_ds_filters.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | from haystack import Document 5 | 6 | from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore 7 | 8 | 9 | @pytest.mark.integration 10 | def test_eq_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 11 | doc_store.write_documents(documents) 12 | result = doc_store.filter_documents(filters={"field": "year", "operator": "==", "value": 2020}) 13 | assert len(result) == 3 14 | 15 | 16 | @pytest.mark.integration 17 | def test_ne_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 18 | doc_store.write_documents(documents) 19 | 20 | result = doc_store.filter_documents(filters={"field": "year", "operator": "!=", "value": 2020}) 21 | 22 | # Neo4j teats null values (absent properties) as incomparable in logical expressions, thus documents 23 | # not having `year` property are ignored 24 | assert len(result) == 3 25 | 26 | 27 | @pytest.mark.integration 28 | def test_in_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 29 | doc_store.write_documents(documents) 30 | result = doc_store.filter_documents(filters={"field": "year", "operator": "in", "value": [2020, 2021, "n.a."]}) 31 | assert len(result) == 6 32 | 33 | 34 | @pytest.mark.integration 35 | def test_not_in_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 36 | """ 37 | Neo4j does not consider properties with null values during filtering, e.g. 
"year NOT IN [2020, 2026]" will ignore 38 | documents where ``year`` is absent (or null which is equivalent in Neo4j) 39 | """ 40 | doc_store.write_documents(documents) 41 | result = doc_store.filter_documents(filters={"field": "year", "operator": "not in", "value": [2020, 2026]}) 42 | assert len(result) == 3 43 | 44 | result = doc_store.filter_documents(filters={"field": "year", "operator": "not in", "value": [2020, 2021]}) 45 | assert len(result) == 0 46 | 47 | 48 | @pytest.mark.integration 49 | def test_comparison_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 50 | doc_store.write_documents(documents) 51 | 52 | result = doc_store.filter_documents(filters={"field": "year", "operator": ">", "value": 2020}) 53 | assert len(result) == 3 54 | 55 | result = doc_store.filter_documents(filters={"field": "year", "operator": ">=", "value": 2020}) 56 | assert len(result) == 6 57 | 58 | result = doc_store.filter_documents(filters={"field": "year", "operator": "<", "value": 2021}) 59 | assert len(result) == 3 60 | 61 | result = doc_store.filter_documents(filters={"field": "year", "operator": "<=", "value": 2021}) 62 | assert len(result) == 6 63 | 64 | 65 | @pytest.mark.integration 66 | def test_compound_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 67 | doc_store.write_documents(documents) 68 | 69 | result = doc_store.filter_documents( 70 | filters={ 71 | "operator": "AND", 72 | "conditions": [ 73 | {"field": "year", "operator": ">=", "value": 2020}, 74 | {"field": "year", "operator": "<=", "value": 2021}, 75 | ], 76 | } 77 | ) 78 | assert len(result) == 6 79 | 80 | 81 | @pytest.mark.integration 82 | def test_nested_condition_filters(doc_store: Neo4jDocumentStore, documents: List[Document]): 83 | doc_store.write_documents(documents) 84 | 85 | filters = { 86 | "operator": "AND", 87 | "conditions": [ 88 | {"field": "year", "operator": ">=", "value": 2020}, 89 | {"field": "year", "operator": "<=", "value": 2021}, 90 | { 91 | "operator": 
"OR", 92 | "conditions": [ 93 | {"field": "name", "operator": "in", "value": ["name_0", "name_1"]}, 94 | {"field": "numbers", "operator": "<", "value": 5.0}, 95 | ], 96 | }, 97 | ], 98 | } 99 | result = doc_store.filter_documents(filters=filters) 100 | assert len(result) == 4 101 | 102 | filters = { 103 | "operator": "AND", 104 | "conditions": [ 105 | {"field": "year", "operator": ">=", "value": 2020}, 106 | {"field": "year", "operator": "<=", "value": 2021}, 107 | { 108 | "operator": "OR", 109 | "conditions": [ 110 | {"field": "name", "operator": "in", "value": ["name_0", "name_1"]}, 111 | { 112 | "operator": "AND", 113 | "conditions": [ 114 | {"field": "name", "operator": "==", "value": "name_2"}, 115 | { 116 | "operator": "NOT", 117 | "conditions": [ 118 | {"field": "month", "operator": "==", "value": "01"}, 119 | ], 120 | }, 121 | ], 122 | }, 123 | ], 124 | }, 125 | ], 126 | } 127 | result = doc_store.filter_documents(filters=filters) 128 | assert len(result) == 5 129 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/components/test_neo4j_query_reader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from unittest import mock 3 | 4 | import pytest 5 | from haystack import Document 6 | 7 | from neo4j_haystack.client.neo4j_client import Neo4jClientConfig 8 | from neo4j_haystack.components.neo4j_query_reader import Neo4jQueryReader 9 | from neo4j_haystack.components.neo4j_retriever import Neo4jClient 10 | from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore 11 | 12 | 13 | @pytest.fixture 14 | def client_config( 15 | neo4j_database: Neo4jClientConfig, 16 | doc_store: Neo4jDocumentStore, 17 | documents: List[Document], 18 | ) -> Neo4jClientConfig: 19 | doc_store.write_documents(documents) 20 | return neo4j_database 21 | 22 | 23 | @pytest.mark.integration 24 | def test_query_reader(client_config: Neo4jClientConfig): 25 | query 
= "MATCH (doc:`Document`) WHERE doc.year=$year RETURN doc.name as name, doc.year as year" 26 | reader = Neo4jQueryReader( 27 | client_config=client_config, 28 | query=query, 29 | verify_connectivity=True, 30 | runtime_parameters=["year"], 31 | ) 32 | 33 | result = reader.run(year=2020) 34 | 35 | assert result["records"] == [ 36 | {"name": "name_0", "year": 2020}, 37 | {"name": "name_1", "year": 2020}, 38 | {"name": "name_2", "year": 2020}, 39 | ] 40 | 41 | 42 | @pytest.mark.integration 43 | def test_query_reader_single_record(client_config: Neo4jClientConfig): 44 | reader = Neo4jQueryReader(client_config=client_config, verify_connectivity=True) 45 | 46 | result = reader.run( 47 | query=("MATCH (doc:`Document` {name: $name}) RETURN doc.name as name, doc.year as year"), 48 | parameters={"name": "name_1"}, 49 | ) 50 | 51 | assert result["first_record"] == {"name": "name_1", "year": 2020} 52 | 53 | 54 | @pytest.mark.integration 55 | def test_query_reader_error_result(client_config: Neo4jClientConfig): 56 | reader = Neo4jQueryReader(client_config=client_config, raise_on_failure=False) 57 | 58 | result = reader.run( 59 | query=("MATCH (doc:`Document` {name: $name}) RETURN_ doc.name as name, doc.year as year"), 60 | parameters={"name": "name_1"}, 61 | ) 62 | 63 | assert "Invalid input 'RETURN_'" in result["error_message"] 64 | 65 | 66 | @pytest.mark.integration 67 | def test_query_reader_raises_error(client_config: Neo4jClientConfig): 68 | reader = Neo4jQueryReader(client_config=client_config, raise_on_failure=True) 69 | 70 | with pytest.raises(Exception): # noqa: B017 71 | reader.run( 72 | query=("MATCH (doc:`Document` {name: $name}) RETURN_ doc.name as name, doc.year as year"), 73 | parameters={"name": "name_1"}, 74 | ) 75 | 76 | 77 | @pytest.mark.unit 78 | @mock.patch("neo4j_haystack.components.neo4j_query_reader.Neo4jClientConfig", spec=Neo4jClientConfig) 79 | @mock.patch("neo4j_haystack.components.neo4j_query_reader.Neo4jClient", spec=Neo4jClient) 80 | def 
@pytest.mark.unit
@mock.patch.object(Neo4jClientConfig, "from_dict")
@mock.patch("neo4j_haystack.components.neo4j_query_reader.Neo4jClient", spec=Neo4jClient)
def test_neo4j_query_reader_from_dict(neo4j_client_mock, from_dict_mock):
    """Deserializing component data restores every init parameter and verifies connectivity."""
    neo4j_client = neo4j_client_mock.return_value  # instance created inside Neo4jQueryReader
    restored_config = mock.Mock(spec=Neo4jClientConfig)
    from_dict_mock.return_value = restored_config

    init_parameters = {
        "query": "cypher",
        "runtime_parameters": ["year"],
        "verify_connectivity": True,
        "raise_on_failure": False,
        "client_config": {"mock": "mock"},
    }

    reader = Neo4jQueryReader.from_dict(
        {
            "type": "neo4j_haystack.components.neo4j_query_reader.Neo4jQueryReader",
            "init_parameters": init_parameters,
        }
    )

    # Private attributes mirror the serialized init parameters one-to-one.
    assert reader._query == "cypher"
    assert reader._client_config == restored_config
    assert reader._runtime_parameters == ["year"]
    assert reader._verify_connectivity is True
    assert reader._raise_on_failure is False

    neo4j_client.verify_connectivity.assert_called_once()
53 | Source = "https://github.com/prosto/neo4j-haystack" 54 | 55 | [tool.hatch.version] 56 | source = "regex_commit" 57 | path = "src/neo4j_haystack/__about__.py" 58 | tag_sign = false 59 | tag_name = "v{new_version}" 60 | commit_message = "🚀 Version {new_version}" 61 | commit_extra_args = ["-a"] # make sure all other updates get into single commit 62 | check_dirty = false 63 | commit = true 64 | 65 | [tool.hatch.envs.default] 66 | features = ["tests"] 67 | template = "default" 68 | 69 | [tool.hatch.envs.default.scripts] 70 | test = "pytest {args:tests}" 71 | test-cov = "coverage run -m pytest {args:tests}" 72 | cov-report = ["- coverage combine", "coverage html"] 73 | cov = ["test-cov", "cov-report"] 74 | 75 | [tool.hatch.envs.dev] 76 | features = ["all"] 77 | description = """ 78 | Environment can be used locally for development as it installs all dependencies. Usefull when working in IDE. 79 | Run `hatch env create dev` in order to create the environment. Works well when you configure hatch to install 80 | virtual environments into project speicifc dir (e.g. `neo4j_haystack/.venv`). See `dirs.env` setting by running `hatch config show`. 
81 | """ 82 | 83 | [tool.hatch.envs.data] 84 | features = ["examples"] 85 | template = "data" 86 | 87 | [tool.hatch.envs.data.scripts] 88 | load-movies = "python scripts/load_movies.py" 89 | 90 | [tool.pytest.ini_options] 91 | minversion = "6.0" 92 | addopts = "--strict-markers" 93 | markers = [ 94 | "unit: unit tests", 95 | "integration: integration tests", 96 | "neo4j: requires Neo4j container", 97 | ] 98 | log_cli = true 99 | log_cli_level = "INFO" 100 | 101 | 102 | [tool.hatch.envs.docs] 103 | template = "docs" 104 | features = ["docs", "typing"] 105 | 106 | [tool.hatch.envs.docs.scripts] 107 | build = "mkdocs build --clean --strict" 108 | serve = "mkdocs serve --dev-addr localhost:8000" 109 | publish = "mkdocs gh-deploy --force" 110 | 111 | [tool.hatch.envs.maintain] 112 | template = "maintain" 113 | features = ["maintain"] 114 | 115 | [tool.hatch.envs.maintain.scripts] 116 | changelog = "git-changelog --bump 'v{args:.}'" 117 | bump = [ 118 | "changelog", # this will update CHANGELOG.md 119 | "hatch version {args:.}", # this will bump version in __about__.py and commit CHANGELOG.md with new tag 120 | ] 121 | 122 | [tool.git-changelog] 123 | bump = "none" 124 | convention = "conventional" 125 | in-place = true 126 | output = "CHANGELOG.md" 127 | parse-refs = false 128 | parse-trailers = true 129 | sections = ["build", "deps", "feat", "fix", "refactor"] 130 | template = "keepachangelog" 131 | 132 | [tool.hatch.envs.lint] 133 | template = "lint" 134 | features = ["quality", "typing"] 135 | 136 | [tool.hatch.envs.lint.scripts] 137 | typing = "mypy --install-types --non-interactive {args:src/neo4j_haystack tests examples scripts}" 138 | style = ["ruff check {args:.}", "black --check --diff {args:.}"] 139 | fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] 140 | all = ["style", "typing"] 141 | 142 | [tool.black] 143 | target-version = ["py38"] 144 | line-length = 120 145 | skip-string-normalization = true 146 | 147 | [tool.ruff] 148 | target-version = "py38" 
149 | line-length = 120 150 | 151 | lint.select = [ 152 | "A", 153 | "ARG", 154 | "B", 155 | "C", 156 | "DTZ", 157 | "E", 158 | "EM", 159 | "F", 160 | "FBT", 161 | "I", 162 | "ICN", 163 | "ISC", 164 | "N", 165 | "PLC", 166 | "PLE", 167 | "PLR", 168 | "PLW", 169 | "Q", 170 | "RUF", 171 | "S", 172 | "T", 173 | "TID", 174 | "UP", 175 | "W", 176 | "YTT", 177 | ] 178 | lint.ignore = [ 179 | # Allow function arguments shadowing a Python builtin (e.g. `id`) 180 | "A002", 181 | "A003", 182 | # Allow non-abstract empty methods in abstract base classes 183 | "B027", 184 | # Allow boolean-typed positional argument in function definition 185 | "FBT001", 186 | # Allow boolean default positional argument in function definition 187 | "FBT002", 188 | # Allow boolean positional values in function calls, like `dict.get(... True)` 189 | "FBT003", 190 | # Ignore checks for possible passwords 191 | "S105", 192 | "S106", 193 | "S107", 194 | # Ignore complexity 195 | "C901", 196 | "PLR0911", 197 | "PLR0912", 198 | "PLR0913", 199 | "PLR0915", 200 | 201 | "EM101", 202 | "EM102", 203 | ] 204 | lint.unfixable = [ 205 | # Don't touch unused imports 206 | "F401", 207 | ] 208 | 209 | [tool.mypy] 210 | ignore_missing_imports = true 211 | 212 | [tool.ruff.lint.isort] 213 | known-first-party = ["neo4j_haystack"] 214 | 215 | [tool.ruff.lint.flake8-tidy-imports] 216 | ban-relative-imports = "all" 217 | 218 | [tool.ruff.lint.per-file-ignores] 219 | # Tests can use magic values, assertions, and relative imports 220 | "tests/**/*" = ["PLR2004", "S101", "TID252"] 221 | 222 | [tool.coverage.run] 223 | source_pkgs = ["neo4j_haystack", "tests"] 224 | branch = true 225 | parallel = true 226 | omit = ["src/neo4j_haystack/__about__.py"] 227 | 228 | [tool.coverage.paths] 229 | neo4j_haystack = ["src/neo4j_haystack", "*/neo4j-haystack/src/neo4j_haystack"] 230 | tests = ["tests", "*/neo4j-haystack/tests"] 231 | 232 | [tool.coverage.report] 233 | exclude_lines = ["no cov", "if __name__ == .__main__.:", "if 
TYPE_CHECKING:"] 234 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/components/test_neo4j_dynamic_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from unittest import mock 3 | 4 | import pytest 5 | from haystack import Document, Pipeline, component 6 | 7 | from neo4j_haystack.components.neo4j_retriever import ( 8 | Neo4jClient, 9 | Neo4jClientConfig, 10 | Neo4jDynamicDocumentRetriever, 11 | ) 12 | from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore 13 | 14 | 15 | @pytest.fixture 16 | def client_config( 17 | neo4j_database: Neo4jClientConfig, 18 | doc_store: Neo4jDocumentStore, 19 | documents: List[Document], 20 | ) -> Neo4jClientConfig: 21 | doc_store.write_documents(documents) 22 | return neo4j_database 23 | 24 | 25 | @pytest.mark.integration 26 | def test_retrieve_all_documents(client_config: Neo4jClientConfig): 27 | retriever = Neo4jDynamicDocumentRetriever( 28 | client_config=client_config, doc_node_name="doc", verify_connectivity=True 29 | ) 30 | 31 | result = retriever.run(query="MATCH (doc:Document) RETURN doc") 32 | documents: List[Document] = result["documents"] 33 | 34 | assert len(documents) == 9 35 | 36 | 37 | @pytest.mark.integration 38 | def test_compose_docs_from_result(client_config: Neo4jClientConfig): 39 | query = """MATCH (doc:Document) 40 | WITH doc, 10 as score 41 | RETURN doc.id as id, doc.content as content, score""" 42 | 43 | retriever = Neo4jDynamicDocumentRetriever( 44 | client_config=client_config, 45 | query=query, 46 | compose_doc_from_result=True, 47 | verify_connectivity=True, 48 | ) 49 | 50 | result = retriever.run() 51 | documents: List[Document] = result["documents"] 52 | 53 | assert len(documents) == 9 54 | for doc in documents: 55 | assert doc.id 56 | assert doc.content 57 | assert doc.score == 10 58 | assert not doc.meta 59 | 60 | 61 | @pytest.mark.integration 62 | def 
test_filter_query(client_config: Neo4jClientConfig): 63 | retriever = Neo4jDynamicDocumentRetriever( 64 | client_config=client_config, doc_node_name="doc", verify_connectivity=True 65 | ) 66 | 67 | result = retriever.run( 68 | query="""MATCH (doc:Document) 69 | WHERE doc.year > 2020 OR doc.year is NULL 70 | RETURN doc""" 71 | ) 72 | documents: List[Document] = result["documents"] 73 | assert len(documents) == 6 74 | 75 | for doc in documents: 76 | assert doc.meta.get("year", -1) != 2020 77 | 78 | 79 | @pytest.mark.integration 80 | def test_pipeline_execution(client_config: Neo4jClientConfig): 81 | @component 82 | class YearProvider: 83 | @component.output_types(year_start=int, year_end=int) 84 | def run(self, year_start: int, year_end: int): 85 | return {"year_start": year_start, "year_end": year_end} 86 | 87 | retriever = Neo4jDynamicDocumentRetriever( 88 | client_config=client_config, 89 | runtime_parameters=["year_start", "year_end"], 90 | doc_node_name="doc", 91 | verify_connectivity=True, 92 | ) 93 | 94 | query = """MATCH (doc:Document) 95 | WHERE (doc.year >= $year_start and doc.year <= $year_end) AND doc.month = $month 96 | RETURN doc LIMIT $num_return""" 97 | 98 | pipeline = Pipeline() 99 | pipeline.add_component("year_provider", YearProvider()) 100 | pipeline.add_component("retriever", retriever) 101 | pipeline.connect("year_provider.year_start", "retriever.year_start") 102 | pipeline.connect("year_provider.year_end", "retriever.year_end") 103 | 104 | result = pipeline.run( 105 | data={ 106 | "year_provider": {"year_start": 2020, "year_end": 2021}, 107 | "retriever": { 108 | "query": query, 109 | "parameters": { 110 | "month": "02", 111 | "num_return": 2, 112 | }, 113 | }, 114 | } 115 | ) 116 | 117 | documents = result["retriever"]["documents"] 118 | assert len(documents) == 2 119 | 120 | for doc in documents: 121 | assert doc.meta.get("year") == 2021 122 | 123 | 124 | @pytest.mark.unit 125 | 
@mock.patch("neo4j_haystack.components.neo4j_retriever.Neo4jClientConfig", spec=Neo4jClientConfig) 126 | @mock.patch("neo4j_haystack.components.neo4j_retriever.Neo4jClient", spec=Neo4jClient) 127 | def test_document_retriever_to_dict(neo4j_client_mock, client_config_mock): 128 | neo4j_client = neo4j_client_mock.return_value # capturing instance created in Neo4jDynamicDocumentRetriever 129 | client_config_mock.configure_mock(**{"to_dict.return_value": {"mock": "mock"}}) 130 | 131 | retriever = Neo4jDynamicDocumentRetriever( 132 | client_config=client_config_mock, 133 | query="cypher", 134 | runtime_parameters=["year"], 135 | doc_node_name="doc", 136 | compose_doc_from_result=True, 137 | verify_connectivity=True, 138 | ) 139 | 140 | data = retriever.to_dict() 141 | 142 | assert data == { 143 | "type": "neo4j_haystack.components.neo4j_retriever.Neo4jDynamicDocumentRetriever", 144 | "init_parameters": { 145 | "query": "cypher", 146 | "runtime_parameters": ["year"], 147 | "doc_node_name": "doc", 148 | "compose_doc_from_result": True, 149 | "verify_connectivity": True, 150 | "client_config": {"mock": "mock"}, 151 | }, 152 | } 153 | neo4j_client.verify_connectivity.assert_called_once() 154 | 155 | 156 | @pytest.mark.unit 157 | @mock.patch.object(Neo4jClientConfig, "from_dict") 158 | @mock.patch("neo4j_haystack.components.neo4j_retriever.Neo4jClient", spec=Neo4jClient) 159 | def test_document_retriever_from_dict(neo4j_client_mock, from_dict_mock): 160 | neo4j_client = neo4j_client_mock.return_value # capturing instance created in Neo4jDynamicDocumentRetriever 161 | expected_client_config = mock.Mock(spec=Neo4jClientConfig) 162 | from_dict_mock.return_value = expected_client_config 163 | 164 | data = { 165 | "type": "neo4j_haystack.components.neo4j_retriever.Neo4jDynamicDocumentRetriever", 166 | "init_parameters": { 167 | "query": "cypher", 168 | "runtime_parameters": ["year"], 169 | "doc_node_name": "doc", 170 | "compose_doc_from_result": True, 171 | "verify_connectivity": 
True, 172 | "client_config": {"mock": "mock"}, 173 | }, 174 | } 175 | 176 | retriever = Neo4jDynamicDocumentRetriever.from_dict(data) 177 | 178 | assert retriever._query == "cypher" 179 | assert retriever._client_config == expected_client_config 180 | assert retriever._runtime_parameters == ["year"] 181 | assert retriever._doc_node_name == "doc" 182 | assert retriever._compose_doc_from_result is True 183 | assert retriever._verify_connectivity is True 184 | 185 | neo4j_client.verify_connectivity.assert_called_once() 186 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to 6 | [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 7 | 8 | 9 | ## [v2.1.0](https://github.com/prosto/neo4j-haystack/releases/tag/v2.1.0) - 2024-09-17 10 | 11 | [Compare with v2.0.3](https://github.com/prosto/neo4j-haystack/compare/v2.0.3...v2.1.0) 12 | 13 | ### Dependencies 14 | 15 | - Update versions of dependencies ([b422aca](https://github.com/prosto/neo4j-haystack/commit/b422aca73c2a449213c1d87128418d0d77d37cb5) by Sergey Bondarenco). 16 | 17 | ### Features 18 | 19 | - Allow Cypher query to be provided during creation of Neo4jDynamicDocumentRetriever component ([ddfbd2e](https://github.com/prosto/neo4j-haystack/commit/ddfbd2e277999d05c7f580c1cb61e0341b91783d) by Sergey Bondarenco). 20 | - Allow a custom marshaller for converting Document to Neo4j record ([9ca023e](https://github.com/prosto/neo4j-haystack/commit/9ca023e059bfd6535eb0666e8c7518453f49ec46) by Sergey Bondarenco). 
21 | - Neo4jQueryReader component to run arbitrary Cypher queries to extract data from Neo4j ([bc23597](https://github.com/prosto/neo4j-haystack/commit/bc23597b66342e447a90fb12e9c8874894c9ccf0) by Sergey Bondarenco). 22 | - A module responsible of serialization of complex types in Cypher query parameters ([99ff860](https://github.com/prosto/neo4j-haystack/commit/99ff86009f20adecab1bd38351632b47bf52a031) by Sergey Bondarenco). 23 | - Example RAG pipeline for Parent-Child document setup ([525d166](https://github.com/prosto/neo4j-haystack/commit/525d1665ad43383d1abdea6d3395505f72d21153) by Sergey Bondarenco). 24 | - Add raise_on_failure setting for Neo4jQueryWriter component for better error handling control ([51c819c](https://github.com/prosto/neo4j-haystack/commit/51c819c347d9633d59c404f63c04f5bdec74241e) by Sergey Bondarenco). 25 | - Neo4jQueryWriter component to run Cypher queries which write data to Neo4j ([ceb569a](https://github.com/prosto/neo4j-haystack/commit/ceb569aded92e5657a054fa4fa0fa975ac9fa571) by Sergey Bondarenco). 26 | 27 | ### Bug Fixes 28 | 29 | - Use HuggingFaceAPIGenerator in example scripts as HuggingFaceTGIGenerator is no longer available ([42ed60d](https://github.com/prosto/neo4j-haystack/commit/42ed60d3b873cb7306a4d0be9b5de682c533d8a0) by Sergey Bondarenco). 30 | 31 | ### Code Refactoring 32 | 33 | - Adjust component outputs and parameter serialization for Neo4jQueryWriter ([0d93b21](https://github.com/prosto/neo4j-haystack/commit/0d93b2102c6b677739cb878316c711ddd4a890d2) by Sergey Bondarenco). 34 | 35 | ## [v2.0.3](https://github.com/prosto/neo4j-haystack/releases/tag/v2.0.3) - 2024-02-08 36 | 37 | [Compare with v2.0.2](https://github.com/prosto/neo4j-haystack/compare/v2.0.2...v2.0.3) 38 | 39 | ### Build 40 | 41 | - Update settings as per latest ruff requirements in pyproject.toml ([652d1f1](https://github.com/prosto/neo4j-haystack/commit/652d1f1ac6666d508edde825ed78c93d87ed6c4b) by Sergey Bondarenco). 
42 | 43 | ### Features 44 | 45 | - Introducing `execute_write` method in Neo4jClient to run arbitrary Cypher queries which modify data ([88e89bb](https://github.com/prosto/neo4j-haystack/commit/88e89bbe405a72e9185cf56de18aaabcebe71219) by Sergey Bondarenco). 46 | 47 | ### Bug Fixes 48 | 49 | - Return number of written documents from `write_documents` as per Protocol for document stores ([f421ed5](https://github.com/prosto/neo4j-haystack/commit/f421ed54c671c14cabc0fb1a00d5b68c156dda6c) by Sergey Bondarenco). 50 | 51 | ### Code Refactoring 52 | 53 | - Read token values using Secret util ([f75fa32](https://github.com/prosto/neo4j-haystack/commit/f75fa3258a6a53a610c7b7356a891a6ee63f2f08) by Sergey Bondarenco). 54 | 55 | ## [v2.0.2](https://github.com/prosto/neo4j-haystack/releases/tag/v2.0.2) - 2024-01-19 56 | 57 | [Compare with v2.0.1](https://github.com/prosto/neo4j-haystack/compare/v2.0.1...v2.0.2) 58 | 59 | ### Bug Fixes 60 | 61 | - Change imports for DuplicateDocumentError and DuplicatePolicy as per latest changes in haystack ([7a1f053](https://github.com/prosto/neo4j-haystack/commit/7a1f0535b143ef3b4a3e558174e369630079a824) by Sergey Bondarenco). 62 | - Rename 'model_name_or_path' to 'model' as per latest changes in haystack ([0131059](https://github.com/prosto/neo4j-haystack/commit/0131059df8f9966568fea8716d3ba1910801542c) by Sergey Bondarenco). 63 | 64 | ## [v2.0.1](https://github.com/prosto/neo4j-haystack/releases/tag/v2.0.1) - 2024-01-15 65 | 66 | [Compare with v2.0.0](https://github.com/prosto/neo4j-haystack/compare/v2.0.0...v2.0.1) 67 | 68 | ### Bug Fixes 69 | 70 | - Rename metadata slot to meta in example pipelines ([1831d40](https://github.com/prosto/neo4j-haystack/commit/1831d4071bacd1cff4cd99f186cf7a7a1a4d1edc) by Sergey Bondarenco). 
71 | 72 | ## [v2.0.0](https://github.com/prosto/neo4j-haystack/releases/tag/v2.0.0) - 2024-01-15 73 | 74 | [Compare with v1.0.0](https://github.com/prosto/neo4j-haystack/compare/v1.0.0...v2.0.0) 75 | 76 | ### Build 77 | 78 | - Update haystack dependency to 2.0 ([bd3e925](https://github.com/prosto/neo4j-haystack/commit/bd3e92543674ab4f3dd8f988a3bc882bbd00042a) by Sergey Bondarenco). 79 | 80 | ### Features 81 | 82 | - Update examples based on haystack 2.0 pipelines ([a7e3bf1](https://github.com/prosto/neo4j-haystack/commit/a7e3bf1788ac9f6b87e82497740feea056386f87) by Sergey Bondarenco). 83 | - Retriever component for documents stored in Neo4j ([b411ebc](https://github.com/prosto/neo4j-haystack/commit/b411ebc5f850272e0050307f03cc6157b7bc6e26) by Sergey Bondarenco). 84 | - Update DocumentStore protocol implementation to match haystack 2.0 requirements ([9748e7d](https://github.com/prosto/neo4j-haystack/commit/9748e7d4f27087b80c8f028b8612f76ed1daf8a8) by Sergey Bondarenco). 85 | - Update metadata filter parser for document store ([6ce780c](https://github.com/prosto/neo4j-haystack/commit/6ce780c846576d690b7216e37793532841a54dc3) by Sergey Bondarenco). 86 | 87 | ### Code Refactoring 88 | 89 | - Organize modules into packages for better separation of concerns ([6a101e8](https://github.com/prosto/neo4j-haystack/commit/6a101e8047bcd2dac2b49598701f7233390bae88) by Sergey Bondarenco). 90 | - Change name of retriever component as per documented naming convention ([f79a952](https://github.com/prosto/neo4j-haystack/commit/f79a952fbe59be0d1d5d13e03ae58401f6403ce9) by Sergey Bondarenco). 
91 | 92 | ## [v1.0.0](https://github.com/prosto/neo4j-haystack/releases/tag/v1.0.0) - 2023-12-19 93 | 94 | [Compare with first commit](https://github.com/prosto/neo4j-haystack/compare/f801a10c8cf6eb7d784c77d8b72005cf5985dffc...v1.0.0) 95 | 96 | ### Build 97 | 98 | - Script to bump a new release by tagging commit and generating changelog ([84a923d](https://github.com/prosto/neo4j-haystack/commit/84a923dc5d8b1f5ff8602fbdf4f86ff5c682e565) by Sergey Bondarenco). 99 | 100 | ### Features 101 | 102 | - Example of Neo4jDocumentStore used in a question answering pipeline ([e7628c6](https://github.com/prosto/neo4j-haystack/commit/e7628c672489f609c14d539859d110e8facda848) by Sergey Bondarenco). 103 | - Sample script to download movie data from HF datasets ([ed44127](https://github.com/prosto/neo4j-haystack/commit/ed44127329454b555e906e1b5463fa8b9f4e8fe7) by Sergey Bondarenco). 104 | - Project setup and initial implementation of Neo4jDataStore ([f801a10](https://github.com/prosto/neo4j-haystack/commit/f801a10c8cf6eb7d784c77d8b72005cf5985dffc) by Sergey Bondarenco). 
@pytest.fixture
def documents() -> List[Document]:
    """Nine sample documents: per index i, a "Foo" and a "Bar" document with random
    embeddings plus one document without an embedding."""
    docs: List[Document] = []
    for i in range(3):
        docs.extend(
            [
                Document(
                    content=f"A Foo Document {i}",
                    meta={"name": f"name_{i}", "year": 2020, "month": "01", "numbers": [2, 4]},
                    embedding=np.random.rand(EMBEDDING_DIM).astype(np.float32),
                ),
                Document(
                    content=f"A Bar Document {i}",
                    meta={"name": f"name_{i}", "year": 2021, "month": "02", "numbers": [-2, -4]},
                    embedding=np.random.rand(EMBEDDING_DIM).astype(np.float32),
                ),
                Document(
                    content=f"Document {i} without embeddings",
                    meta={"name": f"name_{i}", "no_embedding": True, "month": "03"},
                ),
            ]
        )
    return docs
doc_json in docs_json] 65 | 66 | return documents 67 | 68 | 69 | def _get_free_tcp_port(): 70 | tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 71 | tcp.bind(("", 0)) 72 | addr, port = tcp.getsockname() 73 | tcp.close() 74 | return port 75 | 76 | 77 | def _connection_established(db_driver: Driver) -> bool: 78 | """ 79 | Periodically check neo4j database connectivity and return connection status (:const:`True` if has been established) 80 | """ 81 | timeout = 120 82 | stop_time = 3 83 | elapsed_time = 0 84 | connection_established = False 85 | while not connection_established and elapsed_time < timeout: 86 | try: 87 | db_driver.verify_connectivity() 88 | connection_established = True 89 | except Exception: 90 | time.sleep(stop_time) 91 | elapsed_time += stop_time 92 | return connection_established 93 | 94 | 95 | @pytest.fixture(scope="module") 96 | def neo4j_database(): 97 | """ 98 | Starts Neo4j docker container and waits until Neo4j database is ready. 99 | Returns Neo4j client configuration which represents the database in the docker container. 100 | Container is removed after test suite execution. The `scope` is set to ``session`` to keep only one docker 101 | container instance fof the whole duration of tests execution to speedup the process. 
102 | """ 103 | neo4j_port = _get_free_tcp_port() 104 | neo4j_version = os.environ.get("NEO4J_VERSION", "neo4j:5.13.0") 105 | neo4j_container = f"test_neo4j_haystack-{neo4j_port}" 106 | 107 | config = Neo4jClientConfig( 108 | f"bolt://localhost:{neo4j_port}", database="neo4j", username="neo4j", password="passw0rd" 109 | ) 110 | 111 | client = docker.from_env() 112 | container = client.containers.run( 113 | image=neo4j_version, 114 | auto_remove=True, 115 | environment={ 116 | "NEO4J_AUTH": f"{config.username}/{config.password}", 117 | }, 118 | name=neo4j_container, 119 | ports={"7687/tcp": ("127.0.0.1", neo4j_port)}, 120 | detach=True, 121 | remove=True, 122 | ) 123 | 124 | db_driver = GraphDatabase.driver(config.url, database=config.database, auth=config.auth) 125 | 126 | if not _connection_established(db_driver): 127 | pytest.exit("Could not startup neo4j docker container and establish connection with database") 128 | 129 | logger.info(f"Started neo4j docker container: {neo4j_container}, image: {neo4j_version}, port: {neo4j_port}") 130 | 131 | yield config 132 | 133 | logger.info(f"Stopping neo4j docker container: {neo4j_container}") 134 | 135 | db_driver.close() 136 | container.stop() 137 | 138 | 139 | @pytest.fixture 140 | def doc_store_factory(neo4j_database: Neo4jClientConfig) -> Callable[..., Neo4jDocumentStore]: 141 | """ 142 | A factory function to create `Neo4jDocumentStore`. It depends on the `neo4j_database` fixture (a running 143 | docker container with neo4j database). Can be used to construct different flavours of `Neo4jDocumentStore` 144 | by providing necessary initialization params through keyword arguments (see `doc_store_params`). 
145 | """ 146 | 147 | def _doc_store(**doc_store_params) -> Neo4jDocumentStore: 148 | return Neo4jDocumentStore( 149 | **dict( 150 | { 151 | "url": neo4j_database.url, 152 | "username": neo4j_database.username, 153 | "password": neo4j_database.password, 154 | "database": neo4j_database.database, 155 | "embedding_dim": EMBEDDING_DIM, 156 | "embedding_field": "embedding", 157 | "index": "document-embeddings", 158 | "node_label": "Document", 159 | }, 160 | **doc_store_params, 161 | ) 162 | ) 163 | 164 | return _doc_store 165 | 166 | 167 | @pytest.fixture 168 | def doc_store(doc_store_factory: Callable[..., Neo4jDocumentStore]) -> Generator[Neo4jDocumentStore, Any, Any]: 169 | """ 170 | A default instance of the document store to be used in tests. 171 | Please notice data (and index) is removed after each test execution to provide a clean state for the next test. 172 | """ 173 | ds = doc_store_factory() 174 | 175 | yield ds 176 | 177 | # Remove all data from DB to start a new test with clean state 178 | ds.delete_index() 179 | 180 | 181 | @pytest.fixture(scope="session") 182 | def text_embedder() -> Callable[[str], List[float]]: 183 | text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") 184 | text_embedder.warm_up() 185 | 186 | def _text_embedder(text: str) -> List[float]: 187 | return text_embedder.run(text)["embedding"] 188 | 189 | return _text_embedder 190 | 191 | 192 | @pytest.fixture(scope="session") 193 | def doc_embedder() -> Callable[[List[Document]], List[Document]]: 194 | doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2") 195 | doc_embedder.warm_up() 196 | 197 | def _doc_embedder(documents: List[Document]) -> List[Document]: 198 | return doc_embedder.run(documents)["documents"] 199 | 200 | return _doc_embedder 201 | 202 | 203 | @pytest.fixture 204 | def movie_documents_with_embeddings( 205 | movie_documents: List[Document], 206 | doc_embedder: 
@pytest.mark.unit
def test_document_store_connection_parameters():
    """Connection settings may be given either as plain kwargs or as a `Neo4jClientConfig`."""
    no_db_interactions = {
        "create_index_if_missing": False,
        "verify_connectivity": False,
        "recreate_index": False,
    }

    connection_config = {
        "url": "bolt://db:7687",
        "username": "username",
        "password": "password",
        "database": "database",
    }

    # Plain keyword arguments are folded into the store's client config.
    store = Neo4jDocumentStore(**connection_config, **no_db_interactions)
    assert store.client_config.url == "bolt://db:7687"
    assert store.client_config.database == "database"

    # An explicit config object is kept as-is.
    client_config = Neo4jClientConfig(**connection_config)
    store = Neo4jDocumentStore(client_config=client_config, **no_db_interactions)
    assert store.client_config == client_config
@pytest.mark.integration
def test_get_document_by_id(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """A single document can be fetched back by its id with content intact."""
    doc_store.write_documents(documents)
    doc = doc_store.get_document_by_id(documents[0].id)

    assert doc
    assert doc.id == documents[0].id
    assert doc.content == documents[0].content


@pytest.mark.integration
def test_get_documents_by_id(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Batched id lookup returns exactly the requested documents."""
    doc_store.write_documents(documents)
    ids = [doc.id for doc in documents]
    # Small batch_size forces multiple round-trips through the batching code path.
    result = {doc.id for doc in doc_store.get_documents_by_id(ids, batch_size=2)}
    assert set(ids) == result


@pytest.mark.integration
def test_count_documents_with_filter(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Counting supports no filter (all docs) as well as `in`/`==` metadata filters."""
    doc_store.write_documents(documents)
    assert doc_store.count_documents_with_filter() == len(documents)
    assert doc_store.count_documents_with_filter({"field": "year", "operator": "in", "value": [2020]}) == 3
    assert doc_store.count_documents_with_filter({"field": "month", "operator": "==", "value": "02"}) == 3
@pytest.mark.integration
def test_get_embedding_count(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """
    We expect 6 docs with embeddings because only 6 documents in the documents fixture for this class contain
    embeddings.
    """
    doc_store.write_documents(documents)
    assert doc_store.count_documents_with_filter({"field": "embedding", "operator": "exists", "value": True}) == 6


@pytest.mark.integration
def test_delete_all_documents(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """`delete_all_documents` with no arguments removes every stored document."""
    doc_store.write_documents(documents)
    doc_store.delete_all_documents()
    assert doc_store.count_documents() == 0


@pytest.mark.integration
def test_delete_documents_with_filters(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Only documents matching the metadata filter are deleted; the rest remain."""
    doc_store.write_documents(documents)
    doc_store.delete_all_documents(filters={"field": "year", "operator": "in", "value": [2020, 2021]})

    # Fix: the previous version assigned `filter_documents()` to an unused local that
    # also shadowed the `documents` fixture parameter. Assert on the survivors instead.
    remaining = doc_store.filter_documents()
    assert doc_store.count_documents() == 3
    assert all(doc.meta.get("year") not in (2020, 2021) for doc in remaining)


@pytest.mark.integration
def test_delete_documents_by_id(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Deleting by explicit id list removes exactly those documents (3 of 9 here)."""
    doc_store.write_documents(documents)
    docs_to_delete = doc_store.filter_documents(filters={"field": "year", "operator": "==", "value": 2020})
    doc_store.delete_all_documents(document_ids=[doc.id for doc in docs_to_delete])
    assert doc_store.count_documents() == 6


@pytest.mark.integration
def test_delete_documents_by_id_with_filters(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """When both ids and a filter are given, only documents matching BOTH are deleted."""
    doc_store.write_documents(documents)
    docs_to_delete = doc_store.filter_documents(filters={"field": "year", "operator": "==", "value": 2020})
    # this should delete only 1 document out of the 3 ids passed
    doc_store.delete_all_documents(
        document_ids=[doc.id for doc in docs_to_delete], filters={"field": "name", "operator": "==", "value": "name_0"}
    )
    assert doc_store.count_documents() == 8


@pytest.mark.integration
def test_delete_index(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Deleting the index removes both the documents and the vector index itself."""
    doc_store.write_documents(documents)
    assert doc_store.count_documents() == len(documents)

    doc_store.delete_index()
    assert doc_store.count_documents() == 0
    assert (
        doc_store.neo4j_client.retrieve_vector_index(doc_store.index, doc_store.node_label, doc_store.embedding_field)
        is None
    )
@pytest.mark.integration
def test_delete_index_does_not_raise_if_not_exists(doc_store_factory: Callable[..., Neo4jDocumentStore]):
    """By default Neo4j will trigger DatabaseError (server) if trying to remove index which does not exist"""
    # create_index_if_missing=False guarantees no index exists when delete_index runs.
    doc_store = doc_store_factory(index="document-embeddings", create_index_if_missing=False)
    doc_store.delete_index()


@pytest.mark.integration
def test_get_all_documents_large_quantities(doc_store: Neo4jDocumentStore):
    """Writing and reading back 1000 documents round-trips without loss."""
    docs_to_write = [
        Document.from_dict(
            {"meta": {"name": f"name_{i}"}, "content": f"text_{i}", "embedding": np.random.rand(768).astype(np.float32)}
        )
        for i in range(1000)
    ]
    doc_store.write_documents(docs_to_write)
    documents = doc_store.filter_documents()

    assert len(documents) == len(docs_to_write)


@pytest.mark.integration
def test_update_embeddings(
    doc_store_factory: Callable[..., Neo4jDocumentStore],
    movie_documents: List[Document],
    movie_documents_with_embeddings: List[Document],
):
    """Embeddings can be attached to already-stored documents via `update_embeddings`."""
    # embedding_dim=384 matches the all-MiniLM-L6-v2 model used by the embedder fixtures.
    doc_store = doc_store_factory(embedding_dim=384, similarity="cosine", recreate_index=True)
    doc_store.write_documents(movie_documents)  # no embeddings at this point

    documents = doc_store.filter_documents()
    assert all(not doc.embedding for doc in documents)

    doc_store.update_embeddings(documents=movie_documents_with_embeddings)

    documents = doc_store.filter_documents()
    assert all(doc.embedding for doc in documents)


@pytest.mark.integration
def test_query_embeddings(
    doc_store_factory: Callable[..., Neo4jDocumentStore],
    movie_documents_with_embeddings: List[Document],
    text_embedder: Callable[..., List[float]],
):
    """Vector search returns semantically relevant documents and honors metadata filters."""
    doc_store = doc_store_factory(embedding_dim=384, similarity="cosine", recreate_index=True)
    doc_store.write_documents(documents=movie_documents_with_embeddings)

    query_embedding = text_embedder(
        "A young fella pretending to be a good citizen but actually planning to commit a crime"
    )

    documents = doc_store.query_by_embedding(query_embedding, top_k=5)
    assert len(documents) == 5

    # The semantically closest movie plot should be among the top-5 results.
    expected_content = "A film student robs a bank under the guise of shooting a short film about a bank robbery."
    retrieved_contents = list(map(attrgetter("content"), documents))
    assert expected_content in retrieved_contents

    documents = doc_store.query_by_embedding(
        query_embedding,
        top_k=5,
        filters={"field": "release_date", "operator": "==", "value": "2018-12-09"},
    )
    assert len(documents) == 1


@pytest.mark.integration
def test_update_meta(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """`update_document_meta` replaces metadata fields on an existing document."""
    doc_store.write_documents(documents)
    doc = documents[0]
    doc_store.update_document_meta(doc.id, meta={"year": 2099, "month": "12"})
    doc = doc_store.get_document_by_id(doc.id)
    assert doc.meta["year"] == 2099
    assert doc.meta["month"] == "12"
import logging
from typing import Any, Dict, List, Literal, Optional

from haystack import component, default_from_dict, default_to_dict
from typing_extensions import NotRequired, TypedDict

from neo4j_haystack.client import Neo4jClient, Neo4jClientConfig
from neo4j_haystack.serialization import (
    Neo4jQueryParametersMarshaller,
    QueryParametersMarshaller,
)

logger = logging.getLogger(__name__)


class QueryResult(TypedDict):
    # "success" when the Cypher query executed; "error" when it raised and
    # `raise_on_failure` was disabled. The error keys are present only on failure.
    query_status: Literal["success", "error"]
    error_message: NotRequired[str]
    error: NotRequired[Exception]


@component
class Neo4jQueryWriter:
    """
    A component for writing arbitrary data to Neo4j database using plain Cypher query.

    This component gives flexible way to write data to Neo4j by running arbitrary Cypher query with
    parameters. Query parameters can be supplied in a pipeline from other components (or pipeline data).
    You could use such queries to write Documents with additional graph nodes for a more complex RAG scenarios.
    The difference between [DocumentWriter](https://docs.haystack.deepset.ai/docs/documentwriter) and `Neo4jQueryWriter`
    is that the latter can write any data to Neo4j - not just Documents.

    Note:
        Please consider [data types mappings](https://neo4j.com/docs/api/python-driver/current/api.html#data-types) in \
        Cypher query when working with query parameters. Neo4j Python Driver handles type conversions/mappings.
        Specifically you can figure out in the documentation of the driver how to work with temporal types.

    ```py title="Example: Creating a Document node with Neo4jQueryWriter"
    from neo4j_haystack.client.neo4j_client import Neo4jClientConfig
    from neo4j_haystack.components.neo4j_query_writer import Neo4jQueryWriter

    client_config = Neo4jClientConfig("bolt://localhost:7687", database="neo4j", username="neo4j", password="passw0rd")

    doc_meta = {"year": 2020, "source_url": "https://www.deepset.ai/blog"}

    writer = Neo4jQueryWriter(client_config=client_config, verify_connectivity=True, runtime_parameters=["doc_meta"])

    result = writer.run(
        query=(
            "MERGE (doc:`Document` {id: $doc_id})"
            "SET doc += {id: $doc_id, content: $content, year: $doc_meta.year, source_url: $doc_meta.source_url}"
        ),
        parameters={"doc_id": "123", "content": "beautiful graph"},
        doc_meta=doc_meta
    )
    ```

    Output:
        `>>> {'query_status': 'success'}`

    In case query execution results in error and `raise_on_failure=False` the output will contain the error, e.g.:

    Output:
        `>>> {'query_status': 'error', 'error_message': 'Invalid cypher syntax', 'error': <exception instance>}`

    In RAG pipeline runtime parameters could be connected from other components. Make sure during component creation to
    specify which `runtime_parameters` are expected to become as input slots for the component. In the example above
    `doc_meta` can be connected , e.g. `pipeline.connect("other_component.output", "writer.doc_meta")`.

    Important:
        At the moment parameters support simple data types, dictionaries (see `doc_meta` in the example above) and
        python dataclasses (which can be converted to `dict`). For example `haystack.Document` or `haystack.ChatMessage`
        instances are valid query parameter inputs. However, currently Neo4j Python Driver does not convert dataclasses
        to dictionaries automatically for us. By default \
        [Neo4jQueryParametersMarshaller][neo4j_haystack.serialization.query_parameters_marshaller.Neo4jQueryParametersMarshaller]
        is used to handle such conversions. You can change this logic by creating your own marshaller (see the
        `query_parameters_marshaller` attribute)
    """

    def __init__(
        self,
        client_config: Neo4jClientConfig,
        query: Optional[str] = None,
        runtime_parameters: Optional[List[str]] = None,
        verify_connectivity: Optional[bool] = False,
        raise_on_failure: bool = True,
        query_parameters_marshaller: Optional[QueryParametersMarshaller] = None,
    ):
        """
        Create a Neo4jQueryWriter component.

        Args:
            client_config: Neo4j client configuration to connect to database (e.g. credentials and connection settings).
            query: Optional Cypher query for document retrieval. If `None` should be provided as component input.
            runtime_parameters: list of input parameters/slots for connecting components in a pipeline.
            verify_connectivity: If `True` will verify connectivity with Neo4j database configured by `client_config`.
            raise_on_failure: If `True` raises an exception if it fails to execute given Cypher query.
            query_parameters_marshaller: Marshaller responsible for converting query parameters which can be used in
                Cypher query, e.g. python dataclasses to be converted to dictionary. `Neo4jQueryParametersMarshaller`
                is the default marshaller implementation.
        """
        self._client_config = client_config
        self._query = query
        self._runtime_parameters = runtime_parameters or []
        self._verify_connectivity = verify_connectivity
        self._raise_on_failure = raise_on_failure

        self._neo4j_client = Neo4jClient(client_config)
        self._query_parameters_marshaller = query_parameters_marshaller or Neo4jQueryParametersMarshaller()

        # setup inputs: `query`/`parameters` are always accepted; each declared
        # runtime parameter becomes an additional named input slot for pipeline wiring.
        run_input_slots = {"query": str, "parameters": Optional[Dict[str, Any]]}
        kwargs_input_slots = {param: Optional[Any] for param in self._runtime_parameters}
        component.set_input_types(self, **run_input_slots, **kwargs_input_slots)

        # setup outputs (mirrors the `QueryResult` TypedDict shape)
        component.set_output_types(self, query_status=str, error=Optional[Exception], error_message=Optional[str])

        if verify_connectivity:
            self._neo4j_client.verify_connectivity()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        data = default_to_dict(
            self,
            query=self._query,
            runtime_parameters=self._runtime_parameters,
            verify_connectivity=self._verify_connectivity,
            raise_on_failure=self._raise_on_failure,
        )

        # The client config has its own serialization logic, so it is attached
        # to the init parameters explicitly rather than passed to default_to_dict.
        data["init_parameters"]["client_config"] = self._client_config.to_dict()

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Neo4jQueryWriter":
        """
        Deserialize this component from a dictionary.
        """
        # Rehydrate the nested client config before delegating to the default deserializer.
        client_config = Neo4jClientConfig.from_dict(data["init_parameters"]["client_config"])
        data["init_parameters"]["client_config"] = client_config
        return default_from_dict(cls, data)

    def run(self, query: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, **kwargs) -> QueryResult:
        """
        Runs the arbitrary Cypher `query` with `parameters` to write data to Neo4j.

        Once data is written to Neo4j the component returns back some execution stats.

        Args:
            query: Cypher query to run.
            parameters: Cypher query parameters which can be used as placeholders in the `query`.
            kwargs: Runtime parameters from connected components in a pipeline, e.g.
                `pipeline.connect("year_provider.year_start", "writer.year_start")`, where `year_start` will be part
                of `kwargs`.

        Returns:
            Output: Query execution stats.

            Example: `:::py {'query_status': 'success'}`

        Raises:
            ValueError: If no `query` was provided either at construction time or as run input.
        """
        # The query given at run time takes precedence over the constructor default.
        query = query or self._query
        if query is None:
            raise ValueError(
                "`query` is mandatory input and should be provided either in component's constructor, pipeline input or"
                "connection"
            )
        kwargs = kwargs or {}
        parameters = parameters or {}
        # Explicit `parameters` win over runtime (pipeline-connected) parameters on key clashes.
        parameters_combined = {**kwargs, **parameters}

        try:
            self._neo4j_client.execute_write(
                query,
                parameters=self._serialize_parameters(parameters_combined),
            )

            return {"query_status": "success"}
        except Exception as ex:
            if self._raise_on_failure:
                logger.error("Couldn't execute Neo4j write query %s", ex)
                raise ex

            # Best-effort mode: report the failure in the output instead of raising.
            return {"query_status": "error", "error_message": str(ex), "error": ex}

    def _serialize_parameters(self, parameters: Any) -> Any:
        """
        Serializes `parameters` into a data structure which can be accepted by Neo4j Python Driver (and a Cypher query
        respectively). See \
        [Neo4jQueryParametersMarshaller][neo4j_haystack.serialization.query_parameters_marshaller.Neo4jQueryParametersMarshaller]
        for more details.
        """
        return self._query_parameters_marshaller.marshal(parameters)
@pytest.mark.integration
def test_filter_documents_duplicate_text_value(doc_store: Neo4jDocumentStore):
    """Documents with identical content are distinguishable by their metadata filters."""
    documents = [
        Document(content="duplicated", meta={"meta_field": "0"}),
        Document(content="duplicated", meta={"meta_field": "1", "name": "file.txt"}),
        Document(content="Doc2", meta={"name": "file_2.txt"}),
    ]
    doc_store.write_documents(documents)

    documents = doc_store.filter_documents(filters={"field": "meta_field", "operator": "in", "value": ["1"]})
    assert len(documents) == 1
    assert documents[0].content == "duplicated"
    assert documents[0].meta["name"] == "file.txt"

    documents = doc_store.filter_documents(filters={"field": "meta_field", "operator": "in", "value": ["0"]})
    assert len(documents) == 1
    assert documents[0].content == "duplicated"
    assert documents[0].meta.get("name") is None

    documents = doc_store.filter_documents(filters={"field": "name", "operator": "==", "value": "file_2.txt"})
    assert len(documents) == 1
    assert documents[0].content == "Doc2"
    assert documents[0].meta.get("meta_field") is None


@pytest.mark.integration
def test_filter_documents_with_correct_filters(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """The `in` operator matches documents whose field value is in the given list."""
    doc_store.write_documents(documents)

    result = doc_store.filter_documents(filters={"field": "year", "operator": "in", "value": [2020]})
    assert len(result) == 3

    documents = doc_store.filter_documents(filters={"field": "year", "operator": "in", "value": [2020, 2021]})
    assert len(documents) == 6


@pytest.mark.integration
def test_filter_documents_with_incorrect_filter_name(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """Filtering on a non-existing metadata field returns no documents (no error)."""
    doc_store.write_documents(documents)

    result = doc_store.filter_documents(
        filters={"field": "non_existing_meta_field", "operator": "in", "value": ["whatever"]}
    )
    assert len(result) == 0


@pytest.mark.integration
def test_filter_documents_with_incorrect_filter_value(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """A value of the wrong type simply matches nothing rather than raising."""
    doc_store.write_documents(documents)
    result = doc_store.filter_documents(filters={"field": "year", "operator": "==", "value": "nope"})
    assert len(result) == 0


@pytest.mark.integration
def test_duplicate_documents_skip(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """With `DuplicatePolicy.SKIP`, re-writing existing ids leaves stored documents untouched."""
    doc_store.write_documents(documents)

    updated_docs = []
    for doc in documents:
        updated_d = Document.from_dict(doc.to_dict())
        updated_d.meta["name"] = "Updated"
        updated_docs.append(updated_d)

    doc_store.write_documents(updated_docs, policy=DuplicatePolicy.SKIP)
    for doc in doc_store.filter_documents():
        assert doc.meta.get("name") != "Updated"


@pytest.mark.integration
def test_duplicate_documents_overwrite(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """With `DuplicatePolicy.OVERWRITE`, re-writing existing ids replaces stored documents."""
    doc_store.write_documents(documents)

    updated_docs = []
    for doc in documents:
        updated_d = Document.from_dict(doc.to_dict())
        updated_d.meta["name"] = "Updated"
        updated_docs.append(updated_d)

    doc_store.write_documents(updated_docs, policy=DuplicatePolicy.OVERWRITE)
    for doc in doc_store.filter_documents():
        assert doc.meta["name"] == "Updated"


@pytest.mark.integration
def test_duplicate_documents_fail(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """With `DuplicatePolicy.FAIL`, re-writing existing ids raises `DuplicateDocumentError`."""
    doc_store.write_documents(documents)

    updated_docs = []
    for doc in documents:
        updated_d = Document.from_dict(doc.to_dict())
        updated_d.meta["name"] = "Updated"
        updated_docs.append(updated_d)

    with pytest.raises(DuplicateDocumentError):
        doc_store.write_documents(updated_docs, policy=DuplicatePolicy.FAIL)


@pytest.mark.integration
def test_write_document_meta(doc_store: Neo4jDocumentStore):
    """Metadata round-trips correctly whether documents are built from dicts or constructed directly."""
    doc_store.write_documents(
        [
            Document.from_dict({"content": "dict_without_meta", "id": "1"}),
            Document.from_dict({"content": "dict_with_meta", "meta_field": "test2", "id": "2"}),
            Document(content="document_object_without_meta", id="3"),
            Document(content="document_object_with_meta", meta={"meta_field": "test4"}, id="4"),
        ]
    )

    doc1, doc2, doc3, doc4 = doc_store.filter_documents(
        filters={"field": "id", "operator": "in", "value": ["1", "2", "3", "4"]}
    )

    assert not doc1.meta
    assert doc2.meta["meta_field"] == "test2"
    assert not doc3.meta
    assert doc4.meta["meta_field"] == "test4"


@pytest.mark.integration
def test_delete_documents(doc_store: Neo4jDocumentStore, documents: List[Document]):
    """The protocol-level `delete_documents` removes exactly the documents with the given ids."""
    doc_store.write_documents(documents)

    docs_to_delete = doc_store.filter_documents(filters={"field": "year", "operator": "==", "value": 2020})
    doc_store.delete_documents(document_ids=[doc.id for doc in docs_to_delete])

    assert doc_store.count_documents() == 6


@mock.patch.object(Neo4jClientConfig, "to_dict")
@mock.patch.object(Neo4jClient, "verify_connectivity")
def test_document_store_to_dict(verify_connectivity_mock, to_dict_mock):
    """`to_dict` serializes all init parameters, delegating client config serialization."""
    to_dict_mock.return_value = {"mock": "mock"}

    doc_store = Neo4jDocumentStore(
        url="bolt://localhost:7687",
        database="database",
        username="neo4j",
        password="passw0rd",
        index="document-embeddings",
        node_label="Document",
        embedding_dim=384,
        embedding_field="embedding",
        similarity="cosine",
        progress_bar=False,
        create_index_if_missing=False,
        recreate_index=False,
        write_batch_size=111,
        verify_connectivity=True,
    )
    data = doc_store.to_dict()

    assert data == {
        "type": "neo4j_haystack.document_stores.neo4j_store.Neo4jDocumentStore",
        "init_parameters": {
            "index": "document-embeddings",
            "node_label": "Document",
            "embedding_dim": 384,
            "embedding_field": "embedding",
            "similarity": "cosine",
            "progress_bar": False,
            "create_index_if_missing": False,
            "recreate_index": False,
            "write_batch_size": 111,
            "verify_connectivity": True,
            "client_config": {"mock": "mock"},
        },
    }
    verify_connectivity_mock.assert_called_once()


# NOTE(review): this test is fully mocked yet carries the `integration` marker
# (unlike the sibling `to_dict` test above) — confirm whether the marker is intentional.
@pytest.mark.integration
@mock.patch.object(Neo4jClientConfig, "from_dict")
@mock.patch.object(Neo4jClient, "verify_connectivity")
def test_document_store_from_dict(verify_connectivity_mock, from_dict_mock):
    """`from_dict` restores every init parameter and rehydrates the nested client config."""
    data = {
        "type": "neo4j_haystack.document_stores.neo4j_store.Neo4jDocumentStore",
        "init_parameters": {
            "index": "document-embeddings",
            "node_label": "Document",
            "embedding_dim": 384,
            "embedding_field": "embedding",
            "similarity": "cosine",
            "progress_bar": False,
            "create_index_if_missing": False,
            "recreate_index": False,
            "write_batch_size": 111,
            "client_config": {"url": "bolt://localhost:7687"},
        },
    }
    expected_client_config = Neo4jClientConfig(url="bolt://localhost:7687")
    from_dict_mock.return_value = expected_client_config

    doc_store = Neo4jDocumentStore.from_dict(data)

    assert doc_store.client_config == expected_client_config
    assert doc_store.index == "document-embeddings"
    assert doc_store.node_label == "Document"
    assert doc_store.embedding_dim == 384
    assert doc_store.embedding_field == "embedding"
    assert doc_store.similarity == "cosine"
    assert doc_store.progress_bar is False
    assert doc_store.create_index_if_missing is False
    assert doc_store.recreate_index is False
    assert doc_store.write_batch_size == 111

    verify_connectivity_mock.assert_called_once()
def in_comparison_cypher(field_name: str, op: Literal["IN", "NOT IN"], field_param_name: str, prefix="doc"):
    """Build the expected Cypher fragment for an ``IN`` / ``NOT IN`` membership comparison.

    Mirrors the converter's output: a ``CASE`` expression that applies the list
    predicate ``any()`` element-wise when the stored property is itself a list,
    and a plain membership test otherwise.
    """
    negation = "NOT " if op == "NOT IN" else ""
    prop = f"{prefix}.{field_name}"

    parts = [
        f"CASE valueType({prop}) STARTS WITH 'LIST' ",
        f"WHEN true THEN any(val IN {prop} WHERE {negation}val IN {field_param_name}) ",
        f"ELSE {negation}{prop} IN {field_param_name} END",
    ]
    return "".join(parts)
cypher_query="doc.flag = $fv_flag", 51 | params={"fv_flag": True}, 52 | ), 53 | TestCase( 54 | "!= num", 55 | filters={"field": "num", "operator": "!=", "value": 10}, 56 | cypher_query="doc.num <> $fv_num", 57 | params={"fv_num": 10}, 58 | ), 59 | TestCase( 60 | "!= str", 61 | filters={"field": "name", "operator": "!=", "value": "test"}, 62 | cypher_query="doc.name <> $fv_name", 63 | params={"fv_name": "test"}, 64 | ), 65 | TestCase( 66 | "!= bool", 67 | filters={"field": "flag", "operator": "!=", "value": False}, 68 | cypher_query="doc.flag <> $fv_flag", 69 | params={"fv_flag": False}, 70 | ), 71 | TestCase( 72 | "> num", 73 | filters={"field": "num", "operator": ">", "value": 10}, 74 | cypher_query="doc.num > $fv_num", 75 | params={"fv_num": 10}, 76 | ), 77 | TestCase( 78 | ">= num", 79 | filters={"field": "num", "operator": ">=", "value": 10}, 80 | cypher_query="doc.num >= $fv_num", 81 | params={"fv_num": 10}, 82 | ), 83 | TestCase( 84 | "< num", 85 | filters={"field": "num", "operator": "<", "value": 10}, 86 | cypher_query="doc.num < $fv_num", 87 | params={"fv_num": 10}, 88 | ), 89 | TestCase( 90 | "<= num", 91 | filters={"field": "num", "operator": "<=", "value": 10}, 92 | cypher_query="doc.num <= $fv_num", 93 | params={"fv_num": 10}, 94 | ), 95 | TestCase( 96 | "in str_list", 97 | filters={"field": "name", "operator": "in", "value": ["test1", "test2"]}, 98 | cypher_query=in_comparison_cypher("name", "IN", "$fv_name"), 99 | params={"fv_name": ["test1", "test2"]}, 100 | ), 101 | TestCase( 102 | "in number_list", 103 | filters={"field": "num", "operator": "in", "value": [1, 2]}, 104 | cypher_query=in_comparison_cypher("num", "IN", "$fv_num"), 105 | params={"fv_num": [1, 2]}, 106 | ), 107 | TestCase( 108 | "in number_set", 109 | filters={"field": "num", "operator": "in", "value": {2, 1}}, 110 | cypher_query=in_comparison_cypher("num", "IN", "$fv_num"), 111 | params={"fv_num": [1, 2]}, 112 | ), 113 | TestCase( 114 | "not in str_list", 115 | filters={"field": 
"name", "operator": "not in", "value": ["test1", "test2"]}, 116 | cypher_query=in_comparison_cypher("name", "NOT IN", "$fv_name"), 117 | params={"fv_name": ["test1", "test2"]}, 118 | ), 119 | TestCase( 120 | "not in num_set", 121 | filters={"field": "num", "operator": "not in", "value": {5, 2}}, 122 | cypher_query=in_comparison_cypher("num", "NOT IN", "$fv_num"), 123 | params={"fv_num": [2, 5]}, 124 | ), 125 | TestCase( 126 | "not in num_tuple", 127 | filters={"field": "num", "operator": "not in", "value": (1, 2)}, 128 | cypher_query=in_comparison_cypher("num", "NOT IN", "$fv_num"), 129 | params={"fv_num": (1, 2)}, 130 | ), 131 | TestCase( 132 | "exists true", 133 | filters={"field": "name", "operator": "exists", "value": True}, 134 | cypher_query="doc.name IS NOT NULL", 135 | ), 136 | TestCase( 137 | "exists false", filters={"field": "name", "operator": "exists", "value": False}, cypher_query="doc.name IS NULL" 138 | ), 139 | TestCase( 140 | "AND many_operands", 141 | filters={ 142 | "operator": "AND", 143 | "conditions": [ 144 | {"field": "type", "operator": "==", "value": "article"}, 145 | {"field": "rating", "operator": ">=", "value": 3}, 146 | ], 147 | }, 148 | cypher_query="(doc.type = $fv_type AND doc.rating >= $fv_rating)", 149 | params={"fv_type": "article", "fv_rating": 3}, 150 | ), 151 | TestCase( 152 | "AND single_operand", 153 | filters={ 154 | "operator": "AND", 155 | "conditions": [ 156 | {"field": "rating", "operator": "==", "value": 3}, 157 | ], 158 | }, 159 | cypher_query="doc.rating = $fv_rating", 160 | params={"fv_rating": 3}, 161 | ), 162 | TestCase( 163 | "OR many_operands", 164 | filters={ 165 | "operator": "OR", 166 | "conditions": [ 167 | {"field": "type", "operator": "==", "value": "article"}, 168 | {"field": "rating", "operator": ">=", "value": 3}, 169 | ], 170 | }, 171 | cypher_query="(doc.type = $fv_type OR doc.rating >= $fv_rating)", 172 | params={"fv_type": "article", "fv_rating": 3}, 173 | ), 174 | TestCase( 175 | "OR 
single_operand", 176 | filters={ 177 | "operator": "OR", 178 | "conditions": [ 179 | {"field": "type", "operator": "==", "value": "article"}, 180 | ], 181 | }, 182 | cypher_query="doc.type = $fv_type", 183 | params={"fv_type": "article"}, 184 | ), 185 | TestCase( 186 | "flatten_field_name_with_underscores", 187 | filters={ 188 | "operator": "AND", 189 | "conditions": [ 190 | {"field": "meta.type", "operator": "==", "value": "article"}, 191 | {"field": "meta.rating", "operator": ">=", "value": 3}, 192 | {"field": "likes", "operator": ">", "value": 100}, 193 | ], 194 | }, 195 | cypher_query="(doc.meta_type = $fv_meta_type AND doc.meta_rating >= $fv_meta_rating AND doc.likes > $fv_likes)", 196 | params={"fv_meta_type": "article", "fv_meta_rating": 3, "fv_likes": 100}, 197 | ), 198 | TestCase( 199 | "complex_query", 200 | filters={ 201 | "operator": "AND", 202 | "conditions": [ 203 | {"field": "type", "operator": "==", "value": "article"}, 204 | {"field": "year", "operator": ">=", "value": 2015}, 205 | {"field": "year", "operator": "<", "value": 2021}, 206 | {"field": "rating", "operator": ">=", "value": 3}, 207 | { 208 | "operator": "OR", 209 | "conditions": [ 210 | {"field": "genre", "operator": "in", "value": ["economy", "politics"]}, 211 | {"field": "publisher", "operator": "==", "value": "nytimes"}, 212 | ], 213 | }, 214 | ], 215 | }, 216 | cypher_query=( 217 | "(doc.type = $fv_type AND doc.year >= $fv_year AND doc.year < $fv_year_1 AND doc.rating >= $fv_rating AND" 218 | f" ({in_comparison_cypher('genre', 'IN', '$fv_genre')} OR doc.publisher = $fv_publisher))" 219 | ), 220 | params={ 221 | "fv_type": "article", 222 | "fv_year": 2015, 223 | "fv_year_1": 2021, 224 | "fv_rating": 3, 225 | "fv_genre": ["economy", "politics"], 226 | "fv_publisher": "nytimes", 227 | }, 228 | ), 229 | TestCase( 230 | "logical_operators_same_level", 231 | filters={ 232 | "operator": "OR", 233 | "conditions": [ 234 | { 235 | "operator": "AND", 236 | "conditions": [ 237 | {"field": "type", 
"operator": "==", "value": "news"}, 238 | {"field": "likes", "operator": "<", "value": 100}, 239 | ], 240 | }, 241 | { 242 | "operator": "AND", 243 | "conditions": [ 244 | {"field": "type", "operator": "==", "value": "blog"}, 245 | {"field": "likes", "operator": ">=", "value": 500}, 246 | ], 247 | }, 248 | ], 249 | }, 250 | cypher_query=( 251 | "((doc.type = $fv_type AND doc.likes < $fv_likes) OR (doc.type = $fv_type_1 AND doc.likes >= $fv_likes_1))" 252 | ), 253 | params={"fv_type": "news", "fv_likes": 100, "fv_type_1": "blog", "fv_likes_1": 500}, 254 | ), 255 | TestCase( 256 | "empty_filter", 257 | filters={}, 258 | expectation=pytest.raises(Neo4jFilterParserError), 259 | ), 260 | TestCase( 261 | "unsupported_logical_operator", 262 | filters={ 263 | "operator": "XOR", 264 | "conditions": [ 265 | {"field": "type", "operator": "==", "value": "blog"}, 266 | {"field": "likes", "operator": ">=", "value": 500}, 267 | ], 268 | }, 269 | expectation=pytest.raises(Neo4jFilterParserError), 270 | ), 271 | TestCase( 272 | "unsupported_comparison_operator", 273 | filters={ 274 | "operator": "XOR", 275 | "conditions": [ 276 | {"field": "type", "operator": "==", "value": "blog"}, 277 | {"field": "likes", "operator": "mod", "value": 500}, # here 278 | ], 279 | }, 280 | expectation=pytest.raises(Neo4jFilterParserError), 281 | ), 282 | TestCase( 283 | "empty_conditions", 284 | filters={"operator": "AND", "conditions": []}, 285 | expectation=pytest.raises(Neo4jFilterParserError), 286 | ), 287 | TestCase( 288 | "wrong_conditions_type", 289 | filters={"operator": "AND", "conditions": 2}, 290 | expectation=pytest.raises(Neo4jFilterParserError), 291 | ), 292 | TestCase( 293 | "no_logical_operator_provided", 294 | filters={"conditions": [{"field": "type", "operator": "==", "value": "blog"}]}, 295 | expectation=pytest.raises(Neo4jFilterParserError), 296 | ), 297 | TestCase( 298 | "no_comparison_operator_provided", 299 | filters={"operator": "AND", "conditions": [{"field": "type", "value": 
@pytest.mark.unit
@pytest.mark.parametrize("filters,expected_cypher,expected_params,expectation", [tc.as_param() for tc in TEST_CASES])
def test_filter_converter(filters, expected_cypher, expected_params, expectation):
    """Parse ``filters`` into an AST and convert it to Cypher, checking both query text and parameters.

    Error scenarios are expressed through the ``expectation`` context manager
    (``pytest.raises(...)`` for invalid filters, a no-op context otherwise).
    """
    with expectation:
        filter_ast = FilterParser().parse(filters)
        actual_cypher, actual_params = Neo4jQueryConverter("doc").convert(filter_ast)

        assert actual_cypher == expected_cypher
        assert actual_params == expected_params
@component
class Neo4jQueryReader:
    """
    A component for reading arbitrary data from Neo4j database using plain Cypher query.

    This component gives a flexible way to read data from Neo4j by running a custom Cypher query along with query
    parameters. Query parameters can be supplied in a pipeline from other components (or pipeline inputs).
    You could use such queries to read data from Neo4j to enhance your RAG pipelines. For example a
    prompt to LLM can produce Cypher query based on given context and then `Neo4jQueryReader` can be used to run the
    query and extract results. [OutputAdapter](https://docs.haystack.deepset.ai/docs/outputadapter) component might
    become handy in such scenarios - it can be used as a connection from the `Neo4jQueryReader` to convert (transform)
    results accordingly.

    Note:
        Please consider [data types mappings](https://neo4j.com/docs/api/python-driver/current/api.html#data-types) in \
        Cypher query when working with query parameters. Neo4j Python Driver handles type conversions/mappings.
        Specifically you can figure out in the documentation of the driver how to work with temporal types.

    ```py title="Example: Find a Document node with Neo4jQueryReader and extract data"
    from neo4j_haystack.client.neo4j_client import Neo4jClientConfig
    from neo4j_haystack.components.neo4j_query_reader import Neo4jQueryReader

    client_config = Neo4jClientConfig("bolt://localhost:7687", database="neo4j", username="neo4j", password="passw0rd")

    reader = Neo4jQueryReader(client_config=client_config, runtime_parameters=["year"])

    # Get all documents with "year"=2020 and return "name" and "embedding" attributes for each found record
    result = reader.run(
        query=("MATCH (doc:`Document`) WHERE doc.year=$year RETURN doc.name as name, doc.embedding as embedding"),
        year=2020,
    )
    ```

    The result contains the following output:

    - `records` - A list of dictionaries, will have all the records returned by Cypher query. You can control record
        outputs as per your needs, e.g. an aggregation function could be used to return a single result.
    - `first_record` - In case the `records` contains at least one item, `first_record` will have the first record
        from the list (put simply, first_record=records[0]). It was introduced as a syntax convenience.

    If your Cypher query produces an error (e.g. invalid syntax) and `raise_on_failure` is `False`, the component
    outputs `error` (the original exception) and `error_message` instead. The `error_message` output can be used in
    `Loop-Based Auto-Correction` pipelines to ask LLM to auto correct the query based on the error message,
    afterwards run the query again.

    When configuring [Query parameters](https://neo4j.com/docs/python-manual/current/query-simple/#query-parameters) \
    for `Neo4jQueryReader` component, consider the following:

    - Parameters can be provided at the component creation time, see `parameters`
    - In RAG pipeline runtime parameters could be connected from other components.
      Make sure during creation time to specify which `runtime_parameters` are expected.

    Important:
        At the moment parameters support simple data types, dictionaries and python dataclasses (which can be converted
        to `dict`). For example `haystack.ChatMessage` instance is a valid query parameter input. If you supply custom
        classes as query parameters it will result in error. In such rare cases `query_parameters_marshaller`
        attribute can be used to provide a custom marshaller implementation for the type being used as query
        parameter value.
    """

    def __init__(
        self,
        client_config: Neo4jClientConfig,
        query: Optional[str] = None,
        runtime_parameters: Optional[List[str]] = None,
        verify_connectivity: Optional[bool] = False,
        raise_on_failure: bool = False,
        query_parameters_marshaller: Optional[QueryParametersMarshaller] = None,
    ):
        """
        Creates a Neo4jQueryReader component.

        Args:
            client_config: Neo4j client configuration to connect to database (e.g. credentials and connection settings).
            query: Optional Cypher query if known at component creation time. If `None` should be provided as component
                input.
            runtime_parameters: list of input parameters/slots for connecting components in a pipeline.
            verify_connectivity: If `True` will verify connectivity with Neo4j database configured by `client_config`.
            raise_on_failure: If `True` raises an exception if it fails to execute given Cypher query.
            query_parameters_marshaller: Marshaller responsible for converting query parameters which can be used in
                Cypher query, e.g. python dataclasses to be converted to dictionary. `Neo4jQueryParametersMarshaller`
                is the default marshaller implementation.
        """
        self._client_config = client_config
        self._query = query
        self._runtime_parameters = runtime_parameters or []
        self._verify_connectivity = verify_connectivity
        self._raise_on_failure = raise_on_failure

        self._neo4j_client = Neo4jClient(client_config)
        self._query_parameters_marshaller = query_parameters_marshaller or Neo4jQueryParametersMarshaller()

        # setup inputs
        run_input_slots = {"query": str, "parameters": Optional[Dict[str, Any]]}
        kwargs_input_slots = {param: Optional[Any] for param in self._runtime_parameters}
        component.set_input_types(self, **run_input_slots, **kwargs_input_slots)

        # setup outputs. The declared slots mirror the keys of `QueryResult` returned by `run` so that downstream
        # components can connect to any of them. (Fixes a mismatch where `first_record` was declared as
        # `expanded_record`, `error_message` was missing and `error` was typed as `str` instead of `Exception`.)
        component.set_output_types(
            self,
            records=List[Dict[str, Any]],
            first_record=Optional[Dict[str, Any]],
            error=Optional[Exception],
            error_message=Optional[str],
        )

        if verify_connectivity:
            self._neo4j_client.verify_connectivity()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        data = default_to_dict(
            self,
            query=self._query,
            runtime_parameters=self._runtime_parameters,
            verify_connectivity=self._verify_connectivity,
            raise_on_failure=self._raise_on_failure,
        )

        # `Neo4jClientConfig` has its own serialization logic (e.g. to avoid leaking credentials as plain values)
        data["init_parameters"]["client_config"] = self._client_config.to_dict()

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Neo4jQueryReader":
        """
        Deserialize this component from a dictionary.
        """
        client_config = Neo4jClientConfig.from_dict(data["init_parameters"]["client_config"])
        data["init_parameters"]["client_config"] = client_config
        return default_from_dict(cls, data)

    def run(self, query: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, **kwargs) -> QueryResult:
        """
        Runs the arbitrary Cypher `query` with `parameters` to read data from Neo4j.

        Args:
            query: Cypher query to run. Falls back to the `query` given in the constructor when omitted.
            parameters: Cypher query parameters which can be used as placeholders in the `query`.
            kwargs: Arbitrary parameters supplied in a pipeline execution from other component's output slots, e.g.
                `pipeline.connect("year_provider.year_start", "reader.year_start")`, where `year_start` will be part
                of `kwargs`.

        Raises:
            ValueError: If no `query` was given either at construction time or as run input.
            Exception: Re-raises the underlying driver/query error when `raise_on_failure` is `True`.

        Returns:
            Output: Records returned from Cypher query in case request was successful, e.g.
                `{'records': [{...}, {...}], 'first_record': {...}}`, or error details, e.g.
                `{'error_message': 'Invalid Cypher syntax...', 'error': <original exception>}` if there was an
                error during Cypher query execution (`raise_on_failure` should be `False`).
        """
        query = query or self._query
        if query is None:
            raise ValueError(
                "`query` is mandatory input and should be provided either in component's constructor, pipeline input "
                "or connection"
            )
        kwargs = kwargs or {}
        parameters = parameters or {}
        # explicit `parameters` take precedence over values wired in from other components
        parameters_combined = {**kwargs, **parameters}

        try:
            _, records = self._neo4j_client.execute_read(
                query,
                parameters=self._serialize_parameters(parameters_combined),
            )

            return {"records": records, "first_record": records[0] if len(records) > 0 else None}
        except Exception as ex:
            if self._raise_on_failure:
                logger.error("Couldn't execute Neo4j read query %s", ex)
                raise  # bare `raise` keeps the original traceback intact

            return {
                "error": ex,
                "error_message": str(ex),
            }

    def _serialize_parameters(self, parameters: Any) -> Any:
        """
        Serializes `parameters` into a data structure which can be accepted by Neo4j Python Driver (and a Cypher query
        respectively). See \
        [Neo4jQueryParametersMarshaller][neo4j_haystack.serialization.query_parameters_marshaller.Neo4jQueryParametersMarshaller]
        for more details.
        """
        return self._query_parameters_marshaller.marshal(parameters)
@pytest.mark.integration
def test_query_writer_with_document_input(client_config: Neo4jClientConfig, neo4j_client: Neo4jClient):
    """A haystack `Document` passed as a runtime parameter is marshalled to a dict and written to Neo4j."""
    source_doc = Document(id="t2_123", content="beautiful graph", meta={"year": 2024})

    writer = Neo4jQueryWriter(client_config=client_config, verify_connectivity=True, runtime_parameters=["document"])

    outcome = writer.run(
        query=(
            "MERGE (doc:`Document` {id: $document.id})"
            "SET doc += {id: $document.id, content: $document.content, year: $document.year}"
        ),
        document=source_doc,  # Document will be serialized with `Document.to_dict`
    )

    assert outcome["query_status"] == "success"

    matched = list(
        neo4j_client.query_nodes(
            "MATCH (doc:Document {id:$doc_id}) RETURN doc", parameters={"doc_id": source_doc.id}
        )
    )
    assert len(matched) == 1

    # round-trip: the node stored in Neo4j must deserialize back into the original Document
    stored_doc = matched[0].data().get("doc")
    assert Document.from_dict(stored_doc) == source_doc
p += { 98 | age: $age, 99 | milestones: $milestones, 100 | `first_friend.name`: $friends[0].name, 101 | `first_friend.age`: $friends[0].age, 102 | `location.area`: $location.area, 103 | `location.coordinates.x`: $location.coordinates.x, 104 | `location.coordinates.y`: $location.coordinates.y 105 | } 106 | """ 107 | 108 | writer = Neo4jQueryWriter(client_config=client_config, query=query) 109 | 110 | runtime_params = { 111 | "name": "emily", 112 | "age": 30, 113 | "milestones": ["first", "second"], 114 | "friends": [{"name": "john", "age": 29}], 115 | "interests": {"summer": "swimming", "winter": "skiing"}, 116 | "location": Location("Nice", Coordinates(51.4934, 0.0098)), 117 | } 118 | result = writer.run(parameters=runtime_params) 119 | 120 | assert result["query_status"] == "success" 121 | 122 | records = list( 123 | neo4j_client.query_nodes("MATCH (person:Person {name:$name}) RETURN person", parameters={"name": "emily"}) 124 | ) 125 | assert len(records) == 1 126 | 127 | person_in_neo4j: Dict[str, Any] = records[0].data().get("person", {}) 128 | assert person_in_neo4j == { 129 | "name": "emily", 130 | "age": 30, 131 | "milestones": ["first", "second"], 132 | "first_friend.name": "john", 133 | "first_friend.age": 29, 134 | "location.area": "Nice", 135 | "location.coordinates.x": 51.4934, 136 | "location.coordinates.y": 0.0098, 137 | } 138 | 139 | 140 | @pytest.mark.integration 141 | def test_query_writer_with_dataclass_input(client_config: Neo4jClientConfig, neo4j_client: Neo4jClient): 142 | session_id = "123" 143 | chat_messages = [ 144 | ChatMessage.from_system("You are are helpful chatbot"), 145 | ChatMessage.from_assistant("How can I help you?"), 146 | ChatMessage.from_user("Tell me your secret"), 147 | ] 148 | 149 | writer = Neo4jQueryWriter( 150 | client_config=client_config, verify_connectivity=True, runtime_parameters=["chat_messages"] 151 | ) 152 | 153 | query = """ 154 | // Create session node with a related first chat message from the list 155 | MERGE 
(session:`Session` {id: $session_id}) 156 | WITH session 157 | CREATE (session)-[:FIRST_MESSAGE]->(msg_node:`ChatMessage`) 158 | WITH msg_node 159 | MATCH (msg_node) SET msg_node += { role: $chat_messages[0].role, content: $chat_messages[0].content } 160 | WITH msg_node as first_message 161 | 162 | // Create remaining chat message nodes 163 | UNWIND tail($chat_messages) as tail_message 164 | CREATE (msg_node:`ChatMessage`) 165 | WITH msg_node, tail_message, first_message 166 | MATCH (msg_node) SET msg_node += { role: tail_message.role, content: tail_message.content } 167 | 168 | // Connect chat messages with :NEXT_MESSAGE relationship 169 | WITH collect(msg_node) AS message_nodes, first_message 170 | WITH [first_message] + message_nodes as all_nodes 171 | FOREACH (i IN range(0, size(all_nodes)-2) 172 | | FOREACH (n1 IN [all_nodes[i]] 173 | | FOREACH (n2 IN [all_nodes[i+1]] 174 | | CREATE (n1)-[:NEXT_MESSAGE]->(n2)))) 175 | """ 176 | 177 | runtime_params = {"chat_messages": chat_messages} 178 | result = writer.run( 179 | query=query, 180 | parameters={"session_id": session_id}, 181 | **runtime_params, # Will be serialized to list of dictionaries 182 | ) 183 | 184 | assert result["query_status"] == "success" 185 | 186 | cypher_query_get_messages = """ 187 | MATCH (session:Session {id:$session_id})-[:FIRST_MESSAGE|NEXT_MESSAGE*]->(message) 188 | RETURN session, collect(message) as messages 189 | """ 190 | 191 | records = list( 192 | neo4j_client.query_nodes( 193 | cypher_query_get_messages, 194 | parameters={"session_id": session_id}, 195 | ) 196 | ) 197 | assert len(records) == 1 198 | 199 | session_in_neo4j = records[0].data().get("session") 200 | assert session_in_neo4j == {"id": session_id} 201 | 202 | messages_in_neo4j = records[0].data().get("messages") 203 | assert messages_in_neo4j == [ 204 | {"role": "system", "content": "You are are helpful chatbot"}, 205 | {"role": "assistant", "content": "How can I help you?"}, 206 | {"role": "user", "content": "Tell me 
@pytest.mark.integration
def test_query_writer_with_non_supported_parameters(client_config: Neo4jClientConfig):
    """A plain (non-dataclass) custom class is not marshallable into a query parameter and must fail."""

    class Location:
        def __init__(self, area: str):
            self.area = area

    writer = Neo4jQueryWriter(client_config=client_config, verify_connectivity=True, runtime_parameters=["data"])

    unsupported_params = {"location": Location("Nice")}

    # The default marshaller only understands primitives, dicts and dataclasses.
    with pytest.raises(Exception):  # noqa: B017
        writer.run(
            query=("MERGE (p:`Person` {name: $name}) SET p += { area: $location.area }"),
            parameters=unsupported_params,
        )
@pytest.mark.integration
def test_query_writer_raises_on_failure(client_config: Neo4jClientConfig):
    """With `raise_on_failure=True` a Cypher syntax error must surface as an exception."""
    writer = Neo4jQueryWriter(
        client_config=client_config,
        verify_connectivity=False,
        raise_on_failure=True,
    )

    broken_cypher = "CREATE (n:Universe {id=`42`})"  # should be ":" instead of "="

    with pytest.raises(Exception):  # noqa: B017
        writer.run(broken_cypher)


@pytest.mark.integration
def test_query_writer_return_on_failure(client_config: Neo4jClientConfig):
    """With `raise_on_failure=False` a Cypher error is reported via the component's outputs instead of raising."""
    writer = Neo4jQueryWriter(client_config=client_config, raise_on_failure=False)

    broken_cypher = "CREATE (n:Universe {id=`42`})"  # should be ":" instead of "="
    outcome = writer.run(broken_cypher)

    assert outcome["query_status"] == "error"
    assert "Invalid input" in outcome["error_message"]
@pytest.mark.unit
@mock.patch.object(Neo4jClientConfig, "from_dict")
@mock.patch("neo4j_haystack.components.neo4j_query_writer.Neo4jClient", spec=Neo4jClient)
def test_neo4j_query_writer_from_dict(neo4j_client_mock, from_dict_mock):
    """Deserialization restores every init parameter and triggers a connectivity check."""
    client_instance = neo4j_client_mock.return_value  # capturing instance created in Neo4jQueryWriter
    config_stub = mock.Mock(spec=Neo4jClientConfig)
    from_dict_mock.return_value = config_stub

    serialized = {
        "type": "neo4j_haystack.components.neo4j_query_writer.Neo4jQueryWriter",
        "init_parameters": {
            "query": "cypher",
            "runtime_parameters": ["year"],
            "verify_connectivity": True,
            "raise_on_failure": False,
            "client_config": {"mock": "mock"},
        },
    }

    writer = Neo4jQueryWriter.from_dict(serialized)

    assert writer._query == "cypher"
    assert writer._client_config == config_stub
    assert writer._runtime_parameters == ["year"]
    assert writer._verify_connectivity is True
    assert writer._raise_on_failure is False

    # `verify_connectivity=True` in the payload must have reached the freshly built client
    client_instance.verify_connectivity.assert_called_once()
from neo4j_haystack.errors import Neo4jFilterParserError 6 | 7 | FilterType = Dict[str, Any] 8 | 9 | FieldValuePrimitive = Union[int, float, str, bool] 10 | FieldValueType = Union[FieldValuePrimitive, Iterable[FieldValuePrimitive]] 11 | 12 | LogicalOpsTuple = namedtuple("LogicalOpsTuple", "OP_AND, OP_OR, OP_NOT") 13 | LOGICAL_OPS = LogicalOpsTuple("AND", "OR", "NOT") 14 | 15 | ComparisonOpsTuple = namedtuple( 16 | "ComparisonOpsTuple", "OP_EQ, OP_NEQ, OP_IN, OP_NIN, OP_GT, OP_GTE, OP_LT, OP_LTE, OP_EXISTS" 17 | ) 18 | COMPARISON_OPS = ComparisonOpsTuple("==", "!=", "in", "not in", ">", ">=", "<", "<=", "exists") 19 | 20 | 21 | class OpType(str, Enum): 22 | LOGICAL = "logical" 23 | COMPARISON = "comparison" 24 | UNKNOWN = "unknown" 25 | 26 | @staticmethod 27 | def from_op(op: Optional[str]) -> "OpType": 28 | if op in LOGICAL_OPS: 29 | return OpType.LOGICAL 30 | elif op in COMPARISON_OPS: 31 | return OpType.COMPARISON 32 | 33 | return OpType.UNKNOWN 34 | 35 | 36 | class AST: 37 | """ 38 | A base class for nodes of an abstract syntax tree for metadata filters. Its `descriptor` property provides a python 39 | friendly name (e.g. potentially used as a method name) 40 | """ 41 | 42 | @property 43 | def descriptor(self) -> str: 44 | return type(self).__name__.lower() 45 | 46 | 47 | class LogicalOp(AST): 48 | """ 49 | `AST` node which represents a logical operator e.g. "AND", "OR", "NOT". 50 | Logical operator is comprised of an operator (e.g. "AND") and respective operands - expressions participating in 51 | the logical evaluation syntax. 
For example one could read it the following way: ``"operand1 AND operand2"``, where 52 | ``"AND"`` is the operator and ``"operand1", "operand2"`` are `AST` nodes which might evaluate into: 53 | 54 | * simple expressions like ``"field1 = 1 AND field2 >= 2"`` 55 | * more complex expressions like ``"(field1 = 1 OR field2 < 4) AND field3 > 2"`` 56 | 57 | Please notice the actual representation of expressions is managed by a separate component which knows how to parse 58 | the syntax tree and translate its nodes (`AST`) into DocumentStore's specific filtering syntax. 59 | """ 60 | 61 | def __init__(self, operands: Sequence[AST], op: str): 62 | if OpType.from_op(op) != OpType.LOGICAL: 63 | raise Neo4jFilterParserError( 64 | f"The '{op}' logical operator is not supported. Consider using one of: {LOGICAL_OPS}." 65 | ) 66 | 67 | self.operands = operands 68 | self.op = op 69 | 70 | @property 71 | def descriptor(self) -> str: 72 | return "logical_op" 73 | 74 | def __repr__(self): 75 | return f"" 76 | 77 | 78 | class ComparisonOp(AST): 79 | """ 80 | This `AST` node represents a comparison operator in filters syntax tree (e.g. "==", "in", "<=" etc). 81 | Comparison operator is comprised of an operator, a field name and field value. For example one could 82 | read it in the following way: ``"age == 20"``, where 83 | 84 | - ``"=="`` - is the operator 85 | - ``"age"`` - is a field name 86 | - "20" - is a comparison value. 87 | 88 | Please notice the actual representation of comparison expressions is managed by a separate component which knows how 89 | to translate the `ComparisonOp` into DocumentStore's specific filtering syntax. 90 | """ 91 | 92 | def __init__(self, field_name: str, op: str, field_value: FieldValueType): 93 | if OpType.from_op(op) != OpType.COMPARISON: 94 | raise Neo4jFilterParserError( 95 | f"The '{op}' comparison operator is not supported. Consider using one of: {COMPARISON_OPS}." 
96 | ) 97 | 98 | self.field_name = field_name 99 | self.op = op 100 | self.field_value = field_value 101 | 102 | @property 103 | def descriptor(self) -> str: 104 | return "comparison_op" 105 | 106 | def __repr__(self): 107 | return f"" 108 | 109 | 110 | OperatorAST = Union[ComparisonOp, LogicalOp] 111 | 112 | 113 | class FilterParser: 114 | """ 115 | The implementation of metadata filter parser into an abstract syntax tree comprised of respective 116 | `AST` nodes. The tree structure has a single root node and, depending on actual filters provided, can result 117 | into a number of Logical nodes as well as Comparison (leaf) nodes. The parsing logic takes into consideration rules 118 | documented in the following document [Metadata Filtering](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering). 119 | 120 | `FilterParser` does not depend on actual DocumentStore implementation. Its single purpose is to parse filters into a 121 | tree of `AST` nodes by applying metadata filtering rules (including default operators behavior). 
122 | 123 | With a given example of a metadata filter parsing below: 124 | 125 | ```py 126 | filters = { 127 | "operator": "OR", 128 | "conditions": [ 129 | { 130 | "operator": "AND", 131 | "conditions": [ 132 | {"field": "meta.type", "operator": "==", "value": "news"}, 133 | {"field": "meta.likes", "operator": "!=", "value": 100}, 134 | ], 135 | }, 136 | { 137 | "operator": "AND", 138 | "conditions": [ 139 | {"field": "meta.type", "operator": "==", "value": "blog"}, 140 | {"field": "meta.likes", "operator": ">=", "value": 500}, 141 | ], 142 | }, 143 | ], 144 | } 145 | op_tree = FilterParser().parse(filters) 146 | ``` 147 | 148 | We should expect the following tree structure (`op_tree`) after parsing: 149 | 150 | ```console 151 | +-------------+ 152 | | | 153 | | op: "OR" | 154 | +-------------+ 155 | +-------------+ operands +----------------+ 156 | | +-------------+ | 157 | | | 158 | +-----+-------+ +-----+-------+ 159 | | | | | 160 | | op: "AND" | | op: "AND" | 161 | +-------------+ +-------------+ 162 | +--------+ operands +----+ +-------+ operands +-------+ 163 | | +-------------+ | | +-------------+ | 164 | | | | | 165 | +----------+---------+ +--------------+------+ +------+-------------+ +-----------+---------+ 166 | | | | | | | | | 167 | | | | | | | | | 168 | | field_name: "type" | | field_name: "likes" | | field_name: "type" | | field_name: "likes" | 169 | | op: "==" | | op: ">=" | | op: "==" | | op: "!=" | 170 | | field_value: "blog"| | field_value: 500 | | field_value: "news"| | field_value: 100 | 171 | +--------------------+ +---------------------+ +--------------------+ +---------------------+ 172 | ``` 173 | 174 | Having such a tree DocumentStore should be able to traverse it and interpret into a Document Store specific syntax. 
175 | """ 176 | 177 | def __init__(self, flatten_field_name=True) -> None: 178 | self.flatten_field_name = flatten_field_name 179 | 180 | def comparison_op(self, field_name: str, op: str, field_value: FieldValueType) -> ComparisonOp: 181 | return ComparisonOp(field_name, op, field_value) 182 | 183 | def logical_op(self, op: str, operands: Sequence[AST]) -> LogicalOp: 184 | return LogicalOp(operands, op) 185 | 186 | def combine(self, *operands: Optional[OperatorAST], default_op: str = LOGICAL_OPS.OP_AND) -> Optional[OperatorAST]: 187 | """ 188 | Combines several operands (standalone filter values) into a logical operator if number of operands is greater 189 | than one. Operands with `None` value are skipped (e.g. empty strings). 190 | 191 | Args: 192 | default_op: Default operator to be used to construct the `LogicalOp`, defaults to `OP_AND` 193 | 194 | Returns: 195 | Logical operator with `operands` or the operand itself if it is the only provided in arguments. 196 | """ 197 | valid_operands = [op for op in operands if op] 198 | 199 | if len(valid_operands) == 0: 200 | return None 201 | 202 | return valid_operands[0] if len(valid_operands) == 1 else self.logical_op(default_op, valid_operands) 203 | 204 | def _parse_comparison_op(self, filters: FilterType) -> ComparisonOp: 205 | """ 206 | Parsing a comparison operator dictionary. 207 | 208 | Args: 209 | filters: Comparison filter dictionary with `field`, `operator` and `value` keys expected. 210 | 211 | Raises: 212 | Neo4jFilterParserError: If required `field` or `value` comparison dictionary keys are missing. 213 | 214 | Returns: 215 | `AST` node representing the comparison expression in abstract syntax tree. 
216 |         """
217 | 
218 |         operator = filters["operator"]
219 |         field_name = filters.get("field")
220 |         filter_value = filters.get("value")
221 | 
222 |         if not field_name:
223 |             raise Neo4jFilterParserError(f"`field` is mandatory in comparison filter dictionary: `{filters}`.")
224 | 
225 |         if filter_value is None:
226 |             raise Neo4jFilterParserError(f"`value` is mandatory in comparison filter dictionary: `{filters}`.")
227 | 
228 |         return self.comparison_op(field_name, operator, filter_value)
229 | 
230 |     def _parse_logical_op(self, filters: FilterType) -> LogicalOp:
231 |         """
232 |         This method is responsible for parsing logical operators. It returns `LogicalOp` with specific operator
233 |         (e.g. "OR") and its operands (list of `AST` nodes). Operands are parsed by calling `:::py self._parse_tree`
234 |         which might result in further recursive parsing flow.
235 | 
236 |         Args:
237 |             filters: Logical filter with conditions.
238 | 
239 |         Raises:
240 |             Neo4jFilterParserError: If required `conditions` logic dictionary key is missing or is not a `list` with at
241 |                 least one item.
242 | 
243 |         Returns:
244 |             The instance of `LogicalOp` with the list of parsed operands.
245 |         """
246 | 
247 |         op = filters["operator"]
248 |         conditions = filters.get("conditions")
249 | 
250 |         if not conditions or not isinstance(conditions, list) or len(conditions) < 1:
251 |             raise Neo4jFilterParserError("Can not parse logical operator with empty or absent conditions")
252 | 
253 |         operands: List[AST] = [self._parse_tree(condition) for condition in conditions]
254 | 
255 |         return self.logical_op(op, operands)
256 | 
257 |     def _parse_tree(self, filters: FilterType) -> OperatorAST:
258 |         """
259 |         This parses filters dictionary and identifies operations based on its "operator" type,
260 |         e.g. "AND" would be resolved to a logical operator type (`OpType.LOGICAL`). Once recognized the parsing
261 |         of the operator and its filter will be delegated to a respective method (e.g. 
`self._parse_logical_op` or 262 |         `self._parse_comparison_op`).
263 | 
264 |         Args:
265 |             filters: Metadata filters dictionary. Could be the full filter value or a smaller filters chunk.
266 | 
267 |         Raises:
268 |             Neo4jFilterParserError: If required `operator` dictionary key is missing.
269 |             Neo4jFilterParserError: If `filters` is not a dictionary.
270 |             Neo4jFilterParserError: If operator value is unknown.
271 | 
272 |         Returns:
273 |             A root `AST` node of parsed filter.
274 |         """
275 | 
276 |         if not isinstance(filters, dict):
277 |             raise Neo4jFilterParserError("Filter must be a dictionary.")
278 | 
279 |         if "operator" not in filters:
280 |             raise Neo4jFilterParserError("`operator` must be present in both comparison and logic dictionaries.")
281 | 
282 |         operator = filters["operator"]
283 |         op_type = OpType.from_op(operator)
284 | 
285 |         if op_type == OpType.LOGICAL:
286 |             return self._parse_logical_op(filters)
287 |         elif op_type == OpType.COMPARISON:
288 |             return self._parse_comparison_op(filters)
289 |         else:
290 |             raise Neo4jFilterParserError(
291 |                 f"Unknown operator({operator}) in filter dictionary. Should be either "
292 |                 f"comparison({COMPARISON_OPS}) or logical({LOGICAL_OPS})"
293 |             )
294 | 
295 |     def parse(self, filters: FilterType) -> OperatorAST:
296 |         """
297 |         This is the entry point to parse a given metadata filter into an abstract syntax tree. The implementation
298 |         delegates the parsing logic to the private `self._parse_tree` method.
299 | 
300 |         Args:
301 |             filters: Metadata filters to be parsed.
302 | 
303 |         Raises:
304 |             Neo4jFilterParserError: In case parsing results in empty `AST` tree or more than one root node has been
305 |                 created after parsing.
306 | 
307 |         Returns:
308 |             Abstract syntax tree representing `filters`. You should expect a single root operator returned, which 
309 |             could be used to traverse the whole tree. 
310 | """ 311 | 312 | return self._parse_tree(filters) 313 | -------------------------------------------------------------------------------- /src/neo4j_haystack/components/neo4j_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, cast 2 | 3 | from haystack import ( 4 | ComponentError, 5 | Document, 6 | component, 7 | default_from_dict, 8 | default_to_dict, 9 | ) 10 | 11 | from neo4j_haystack.client import Neo4jClient, Neo4jClientConfig 12 | from neo4j_haystack.document_stores import Neo4jDocumentStore 13 | 14 | 15 | @component 16 | class Neo4jEmbeddingRetriever: 17 | """ 18 | A component for retrieving documents from Neo4jDocumentStore. 19 | 20 | ```py title="Retrieving documents assuming documents have been previously indexed" 21 | from haystack import Document, Pipeline 22 | from haystack.components.embedders import SentenceTransformersTextEmbedder 23 | 24 | from neo4j_haystack import Neo4jDocumentStore, Neo4jEmbeddingRetriever 25 | 26 | model_name = "sentence-transformers/all-MiniLM-L6-v2" 27 | 28 | # Document store with default credentials 29 | document_store = Neo4jDocumentStore( 30 | url="bolt://localhost:7687", 31 | embedding_dim=384, # same as the embedding model 32 | ) 33 | 34 | pipeline = Pipeline() 35 | pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model=model_name)) 36 | pipeline.add_component("retriever", Neo4jEmbeddingRetriever(document_store=document_store)) 37 | pipeline.connect("text_embedder.embedding", "retriever.query_embedding") 38 | 39 | result = pipeline.run( 40 | data={ 41 | "text_embedder": {"text": "Query to be embedded"}, 42 | "retriever": { 43 | "top_k": 5, 44 | "filters": {"field": "release_date", "operator": "==", "value": "2018-12-09"}, 45 | }, 46 | } 47 | ) 48 | 49 | # Obtain retrieved documents from pipeline execution 50 | documents: List[Document] = result["retriever"]["documents"] 51 | ``` 52 | """ 53 | 54 | def 
__init__( 55 | self, 56 | document_store: Neo4jDocumentStore, 57 | filters: Optional[Dict[str, Any]] = None, 58 | top_k: int = 10, 59 | scale_score: bool = True, 60 | return_embedding: bool = False, 61 | ): 62 | """ 63 | Create a Neo4jEmbeddingRetriever component. 64 | 65 | Args: 66 | document_store: An instance of `Neo4jDocumentStore`. 67 | filters: A dictionary with filters to narrow down the search space. 68 | top_k: The maximum number of documents to retrieve. 69 | scale_score: Whether to scale the scores of the retrieved documents or not. 70 | return_embedding: Whether to return the embedding of the retrieved Documents. 71 | 72 | Raises: 73 | ValueError: If `document_store` is not an instance of `Neo4jDocumentStore`. 74 | """ 75 | 76 | if not isinstance(document_store, Neo4jDocumentStore): 77 | msg = "document_store must be an instance of Neo4jDocumentStore" 78 | raise ValueError(msg) 79 | 80 | self._document_store = document_store 81 | 82 | self._filters = filters 83 | self._top_k = top_k 84 | self._scale_score = scale_score 85 | self._return_embedding = return_embedding 86 | 87 | def to_dict(self) -> Dict[str, Any]: 88 | """ 89 | Serialize this component to a dictionary. 90 | """ 91 | data = default_to_dict( 92 | self, 93 | document_store=self._document_store, 94 | filters=self._filters, 95 | top_k=self._top_k, 96 | scale_score=self._scale_score, 97 | return_embedding=self._return_embedding, 98 | ) 99 | data["init_parameters"]["document_store"] = self._document_store.to_dict() 100 | 101 | return data 102 | 103 | @classmethod 104 | def from_dict(cls, data: Dict[str, Any]) -> "Neo4jEmbeddingRetriever": 105 | """ 106 | Deserialize this component from a dictionary. 
107 | """ 108 | document_store = Neo4jDocumentStore.from_dict(data["init_parameters"]["document_store"]) 109 | data["init_parameters"]["document_store"] = document_store 110 | return default_from_dict(cls, data) 111 | 112 | @component.output_types(documents=List[Document]) 113 | def run( 114 | self, 115 | query_embedding: List[float], 116 | filters: Optional[Dict[str, Any]] = None, 117 | top_k: Optional[int] = None, 118 | scale_score: Optional[bool] = None, 119 | return_embedding: Optional[bool] = None, 120 | ): 121 | """ 122 | Run the Embedding Retriever on the given input data. 123 | 124 | Args: 125 | query_embedding: Embedding of the query. 126 | filters: A dictionary with filters to narrow down the search space. 127 | top_k: The maximum number of documents to return. 128 | scale_score: Whether to scale the scores of the retrieved documents or not. 129 | return_embedding: Whether to return the embedding of the retrieved Documents. 130 | 131 | Returns: 132 | The retrieved documents. 133 | """ 134 | docs = self._document_store.query_by_embedding( 135 | query_embedding=query_embedding, 136 | filters=filters or self._filters, 137 | top_k=top_k or self._top_k, 138 | scale_score=scale_score or self._scale_score, 139 | return_embedding=return_embedding or self._return_embedding, 140 | ) 141 | 142 | return {"documents": docs} 143 | 144 | 145 | @component 146 | class Neo4jDynamicDocumentRetriever: 147 | """ 148 | A component for retrieving Documents from Neo4j database using plain Cypher query. 149 | 150 | This component gives flexible way to retrieve data from Neo4j by running arbitrary Cypher query along with query 151 | parameters. Query parameters can be supplied in a pipeline from other components (or pipeline data). 
152 | 153 | See the following documentation on how to compose Cypher queries with parameters: 154 | 155 | - [Overview of Cypher query syntax](https://neo4j.com/docs/cypher-manual/current/queries/) 156 | - [Cypher Query Parameters](https://neo4j.com/docs/cypher-manual/current/syntax/parameters/) 157 | 158 | Above are resources which will help understand better Cypher query syntax and parameterization. Under the hood 159 | [Neo4j Python Driver](https://neo4j.com/docs/python-manual/current/) is used to query database and fetch results. 160 | You might be interested in the following documentation: 161 | 162 | - [Query the database](https://neo4j.com/docs/python-manual/current/query-simple/) 163 | - [Query parameters](https://neo4j.com/docs/python-manual/current/query-simple/#query-parameters) 164 | - [Data types and mapping to Cypher types](https://neo4j.com/docs/python-manual/current/data-types/) 165 | 166 | Note: 167 | Please consider data types mappings in Cypher query when working with parameters. Neo4j Python Driver handles 168 | type conversions/mappings. Specifically you can figure out in the documentation of the driver how to work with 169 | temporal types (e.g. `DateTime`). 170 | 171 | Query execution results will be mapped/converted to `haystack.Document` type. See more details in the 172 | [RETURN clause](https://neo4j.com/docs/cypher-manual/current/clauses/return/) documentation. There are two 173 | ways how Documents are being composed from query results. 
174 | 175 | (1) Converting documents from [nodes](https://neo4j.com/docs/cypher-manual/current/clauses/return/#return-nodes) 176 | 177 | ```py title="Convert Neo4j `node` to `haystack.Document`" 178 | client_config = Neo4jClientConfig( 179 | "bolt://localhost:7687", database="neo4j", username="neo4j", password="passw0rd" 180 | ) 181 | 182 | retriever = Neo4jDynamicDocumentRetriever( 183 | client_config=client_config, doc_node_name="doc", verify_connectivity=True 184 | ) 185 | 186 | result = retriever.run( 187 | query="MATCH (doc:Document) WHERE doc.year > $year OR doc.year is NULL RETURN doc", 188 | parameters={"year": 2020} 189 | ) 190 | documents: List[Document] = result["documents"] 191 | ``` 192 | 193 | Please notice how `doc_node_name` attribute assumes `"doc"` node is going to be returned from the query. 194 | `Neo4jDynamicDocumentRetriever` will convert properties of the node (e.g. `id`, `content` etc) to 195 | `haystack.Document` type. 196 | 197 | (2) Converting documents from query output keys (e.g. [column aliases](https://neo4j.com/docs/cypher-manual/current/clauses/return/#return-column-alias)) 198 | 199 | You might want to run a complex query which aggregates information from multiple sources (nodes) in Neo4j. In such 200 | case you can compose final Document from several dta points. 201 | 202 | ```py title="Convert Neo4j `node` to `haystack.Document`" 203 | # Configuration with default settings 204 | client_config=Neo4jClientConfig() 205 | 206 | retriever = Neo4jDynamicDocumentRetriever(client_config=client_config, compose_doc_from_result=True) 207 | 208 | result = retriever.run( 209 | query=( 210 | "MATCH (doc:Document) " 211 | "WHERE doc.year > $year OR doc.year is NULL " 212 | "RETURN doc.id as id, doc.content as content, doc.year as year" 213 | ), 214 | parameters={"year": 2020}, 215 | ) 216 | documents: List[Document] = result["documents"] 217 | ``` 218 | 219 | The above will produce Documents with `id`, `content` and `year`(meta) fields. 
Please notice 220 | `compose_doc_from_result` is set to `True` to enable such Document construction behavior. 221 | 222 | Below is an example of a pipeline which explores all ways how parameters could be supplied to the 223 | `Neo4jDynamicDocumentRetriever` component in the pipeline. 224 | 225 | ```py 226 | @component 227 | class YearProvider: 228 | @component.output_types(year_start=int, year_end=int) 229 | def run(self, year_start: int, year_end: int): 230 | return {"year_start": year_start, "year_end": year_end} 231 | 232 | # Configuration with default settings 233 | client_config=Neo4jClientConfig() 234 | 235 | retriever = Neo4jDynamicDocumentRetriever( 236 | client_config=client_config, 237 | runtime_parameters=["year_start", "year_end"], 238 | ) 239 | 240 | query = ( 241 | "MATCH (doc:Document) " 242 | "WHERE (doc.year >= $year_start and doc.year <= $year_end) AND doc.month = $month" 243 | "RETURN doc LIMIT $num_return" 244 | ) 245 | 246 | pipeline = Pipeline() 247 | pipeline.add_component("year_provider", YearProvider()) 248 | pipeline.add_component("retriever", retriever) 249 | pipeline.connect("year_provider.year_start", "retriever.year_start") 250 | pipeline.connect("year_provider.year_end", "retriever.year_end") 251 | 252 | result = pipeline.run( 253 | data={ 254 | "year_provider": {"year_start": 2020, "year_end": 2021}, 255 | "retriever": { 256 | "query": query, 257 | "parameters": { 258 | "month": "02", 259 | "num_return": 2, 260 | }, 261 | }, 262 | } 263 | ) 264 | 265 | documents = result["retriever"]["documents"] 266 | ``` 267 | 268 | Please notice the following from the example above: 269 | 270 | - `runtime_parameters` is a list of parameter names which are going to be input slots when connecting components 271 | in a pipeline. In our case `year_start` and `year_end` parameters flow from the `year_provider` component into 272 | `retriever`. The `query` uses those parameters in the `WHERE` clause. 
273 | - `pipeline.run` specifies additional parameters to the `retriever` component which can be referenced in the 274 | `query`. If parameter names clash those provided in the pipeline's data take precedence. 275 | """ 276 | 277 | def __init__( 278 | self, 279 | client_config: Neo4jClientConfig, 280 | query: Optional[str] = None, 281 | runtime_parameters: Optional[List[str]] = None, 282 | doc_node_name: Optional[str] = "doc", 283 | compose_doc_from_result: Optional[bool] = False, 284 | verify_connectivity: Optional[bool] = False, 285 | ): 286 | """ 287 | Create a Neo4jDynamicDocumentRetriever component. 288 | 289 | Args: 290 | client_config: Neo4j client configuration to connect to database (e.g. credentials and connection settings). 291 | query: Optional Cypher query for document retrieval. If `None` should be provided as component input. 292 | runtime_parameters: list of input parameters/slots for connecting components in a pipeline. 293 | doc_node_name: the name of the variable which is returned from Cypher query which contains Document 294 | attributes (e.g. `id`, `content`, `meta` fields). 295 | compose_doc_from_result: If `True` Document attributes will be constructed from Cypher query outputs (keys). 296 | `doc_node_name` setting will be ignored in this case. 297 | verify_connectivity: If `True` will verify connectivity with Neo4j database configured by `client_config`. 298 | 299 | Raises: 300 | ComponentError: In case neither `compose_doc_from_result` nor `doc_node_name` are defined. 301 | """ 302 | if not compose_doc_from_result and not doc_node_name: 303 | raise ComponentError( 304 | "Please specify how Document is being composed out of Neo4j query response. " 305 | "With `compose_doc_from_result` set to `True` documents will be created out of properties/keys " 306 | "returned by the query." 
307 | ) 308 | 309 | self._client_config = client_config 310 | self._query = query 311 | self._runtime_parameters = runtime_parameters or [] 312 | self._doc_node_name = doc_node_name 313 | self._compose_doc_from_result = compose_doc_from_result 314 | self._verify_connectivity = verify_connectivity 315 | 316 | self._neo4j_client = Neo4jClient(client_config) 317 | 318 | # setup inputs 319 | run_input_slots = {"query": str, "parameters": Optional[Dict[str, Any]]} 320 | kwargs_input_slots = {param: Optional[Any] for param in self._runtime_parameters} 321 | component.set_input_types(self, **run_input_slots, **kwargs_input_slots) 322 | 323 | # setup outputs 324 | component.set_output_types(self, documents=List[Document]) 325 | 326 | if verify_connectivity: 327 | self._neo4j_client.verify_connectivity() 328 | 329 | def to_dict(self) -> Dict[str, Any]: 330 | """ 331 | Serialize this component to a dictionary. 332 | """ 333 | data = default_to_dict( 334 | self, 335 | query=self._query, 336 | runtime_parameters=self._runtime_parameters, 337 | doc_node_name=self._doc_node_name, 338 | compose_doc_from_result=self._compose_doc_from_result, 339 | verify_connectivity=self._verify_connectivity, 340 | ) 341 | 342 | data["init_parameters"]["client_config"] = self._client_config.to_dict() 343 | 344 | return data 345 | 346 | @classmethod 347 | def from_dict(cls, data: Dict[str, Any]) -> "Neo4jDynamicDocumentRetriever": 348 | """ 349 | Deserialize this component from a dictionary. 350 | """ 351 | client_config = Neo4jClientConfig.from_dict(data["init_parameters"]["client_config"]) 352 | data["init_parameters"]["client_config"] = client_config 353 | return default_from_dict(cls, data) 354 | 355 | def run(self, query: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, **kwargs: Dict[str, Any]): 356 | """ 357 | Runs the arbitrary Cypher `query` with `parameters` and returns Documents. 358 | 359 | Args: 360 | query: Cypher query to run. 
361 | parameters: Cypher query parameters which can be used as placeholders in the `query`. 362 | kwargs: Arbitrary parameters supplied in a pipeline execution from other component's output slots, e.g. 363 | `pipeline.connect("year_provider.year_start", "retriever.year_start")`, where `year_start` will be part 364 | of `kwargs`. 365 | 366 | Returns: 367 | Retrieved documents. 368 | """ 369 | query = query or self._query 370 | if query is None: 371 | raise ValueError( 372 | "`query` is mandatory input and should be provided either in component's constructor, pipeline input or" 373 | "connection" 374 | ) 375 | kwargs = kwargs or {} 376 | parameters = parameters or {} 377 | parameters_combined = {**kwargs, **parameters} 378 | 379 | documents: List[Document] = [] 380 | neo4j_query_result = self._neo4j_client.query_nodes(query, parameters_combined) 381 | 382 | for record in neo4j_query_result: 383 | data = record.data() 384 | document_dict = data if self._compose_doc_from_result else data.get(cast(str, self._doc_node_name)) 385 | documents.append(Document.from_dict(document_dict)) 386 | 387 | return {"documents": documents} 388 | -------------------------------------------------------------------------------- /tests/neo4j_haystack/samples/movies.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "670470", 4 | "content": "When a young man moves across the hall from his high school crush\u00a0basic lust turns to obsession and voyeurism will not be enough to quench his desire for her!", 5 | "meta": { 6 | "title": "She's Mine", 7 | "runtime": 78.0, 8 | "vote_average": 10.0, 9 | "release_date": "2018-01-01", 10 | "genres": ["Drama", "Horror", "Romance"] 11 | } 12 | }, 13 | { 14 | "id": "484543", 15 | "content": "The making of Jerzy Kosinski. The BBC documentary on the life and art of enigmatic novelist Jerzy Kosinski. 
Through interviews with his second wife Kiki von Fraunhofer-Kosinski friends and fellow authors and Polish villagers who knew Kosinski when he was a child hiding from the scourge of Nazism this program attempts to assess the verity of Kosinski's \"autobiographical\" fiction the need for him to maintain a nebulous mystique about his early life and to understand his obsession with S&M sex clubs in Manhattan during the 1970s and 1980s.", 16 | "meta": { 17 | "title": "Sex, Lies and Jerzy Kosinski", 18 | "runtime": 59.0, 19 | "vote_average": 0.0, 20 | "release_date": "1995-04-11", 21 | "genres": ["Documentary"] 22 | } 23 | }, 24 | { 25 | "id": "253965", 26 | "content": "The cartoon chronicles Hitler's birth brief childhood and eventual rise to power.", 27 | "meta": { 28 | "title": "Bury the Axis", 29 | "runtime": 6.0, 30 | "vote_average": 7.0, 31 | "release_date": "1943-02-03", 32 | "genres": ["Animation"] 33 | } 34 | }, 35 | { 36 | "id": "776813", 37 | "content": "I lived off Lombard Street in a private terrace apartment overlooking the Golden Gate Bridge for free. I had free rent because I was the caretaker for the owner of the building who lived upstairs. Her name was Stephanie Bauer. She was about 80 years old. 
It was a very ideal situation except for the fact that my neighbor was a handsome French man who had loud sex most nights.", 38 | "meta": { 39 | "title": "What a Boring and Disappointing Life (brown)", 40 | "runtime": 21.0, 41 | "vote_average": 0.0, 42 | "release_date": "2017-01-01", 43 | "genres": ["Documentary"] 44 | } 45 | }, 46 | { 47 | "id": "741789", 48 | "content": "Set to Leonard Cohen's 1988 song \"I'm Your Man\" this animated short offers a playful meditation on romance and the clich\u00e9s that go with it.", 49 | "meta": { 50 | "title": "I'm Your Man", 51 | "runtime": 4.0, 52 | "vote_average": 0.0, 53 | "release_date": "1996-10-14", 54 | "genres": ["Animation", "Music"] 55 | } 56 | }, 57 | { 58 | "id": "392408", 59 | "content": "1969-\u201cEstirpe\u201d was the first Spanish graphic novel by a troubled youngster under a pseudonym. He mysteriously disappeared after that. 44 years later a production company that wanted to adapt it to the big screen hired Mar\u00eda a young lawyer to find the author and get the rights from him.", 60 | "meta": { 61 | "title": "Estirpe", 62 | "runtime": 0.0, 63 | "vote_average": 0.0, 64 | "release_date": "2016-04-23", 65 | "genres": ["Science Fiction"] 66 | } 67 | }, 68 | { 69 | "id": "420543", 70 | "content": "A group of workers decide to join the army in the Great War. The indulge themselves in the side benefits to being soldiers and one of them marries a French waitress.", 71 | "meta": { 72 | "title": "Carry on, Sergeant!", 73 | "runtime": 107.0, 74 | "vote_average": 5.4, 75 | "release_date": "1928-11-10", 76 | "genres": ["Drama", "War"] 77 | } 78 | }, 79 | { 80 | "id": "2280", 81 | "content": "When a young boy makes a wish at a carnival machine to be big\u2014he wakes up the following morning to find that it has been granted and his body has grown older overnight. But he is still the same 13-year-old boy inside. 
Now he must learn how to cope with the unfamiliar world of grown-ups including getting a job and having his first romantic encounter with a woman.", 82 | "meta": { 83 | "title": "Big", 84 | "runtime": 104.0, 85 | "vote_average": 7.148, 86 | "release_date": "1988-06-03", 87 | "genres": ["Fantasy", "Drama", "Comedy", "Romance", "Family"] 88 | } 89 | }, 90 | { 91 | "id": "686860", 92 | "content": "Russ a forty-something acting teacher advises his students of the impending arrival of casting director Lolly. In preparation the remaining classes provide an opportunity to rehearse their selected audition scenes from cult and classic movies.", 93 | "meta": { 94 | "title": "Class", 95 | "runtime": 10.0, 96 | "vote_average": 0.0, 97 | "release_date": "2020-01-18", 98 | "genres": ["Comedy"] 99 | } 100 | }, 101 | { 102 | "id": "744092", 103 | "content": "The primacy of culture. Culture over race. \u201cculture leap (non-linear)\u201d is composed of found home movies from South Africa and discarded 16mm film from the anthropology department at the University of Cape Town salvaged by filmmaker Roger Horn while completing his PhD in Anthropology.", 104 | "meta": { 105 | "title": "culture leap (non-linear)", 106 | "runtime": 3.0, 107 | "vote_average": 0.0, 108 | "release_date": "2021-09-09", 109 | "genres": ["Documentary"] 110 | } 111 | }, 112 | { 113 | "id": "93753", 114 | "content": "Lila a college student discovers a latent ability to see into the past. When a professor hires her for some psychological tests she becomes the target of the killer of the woman she sees in her dreams.", 115 | "meta": { 116 | "title": "Sensation", 117 | "runtime": 102.0, 118 | "vote_average": 3.4, 119 | "release_date": "1994-10-06", 120 | "genres": ["Thriller", "Mystery"] 121 | } 122 | }, 123 | { 124 | "id": "704911", 125 | "content": "On the 18th May 2015 Peter Hook & the Light preformed a special Joy Division show live at Christ Church Ian Curtis\u2019s boyhood church in Macclesfield England. 
This concert premiered 40 years after Curtis' death to celebrate the life of Curtis. The set list contains the entire Joy Division/Warsaw discography played in full.", 126 | "meta": { 127 | "title": "Peter Hook & The Light: So This Is Permanent", 128 | "runtime": 180.0, 129 | "vote_average": 10.0, 130 | "release_date": "2020-05-18", 131 | "genres": ["Music"] 132 | } 133 | }, 134 | { 135 | "id": "451999", 136 | "content": "A 1916 film directed by Chester M. Franklin.", 137 | "meta": { 138 | "title": "Martha's Vindication", 139 | "runtime": 50.0, 140 | "vote_average": 0.0, 141 | "release_date": "1916-02-20", 142 | "genres": ["Drama"] 143 | } 144 | }, 145 | { 146 | "id": "922412", 147 | "content": "Jimmy Jump appears as a female impersonator in an amateur show who is forced to wear his costume home. He meets with varied experiences and is such an attractive looking \"woman\" that the men all want to make dates with him.", 148 | "meta": { 149 | "title": "The Perfect Lady", 150 | "runtime": 0.0, 151 | "vote_average": 0.0, 152 | "release_date": "1924-02-24", 153 | "genres": ["Comedy"] 154 | } 155 | }, 156 | { 157 | "id": "847547", 158 | "content": "Debra is the new intern assigned to the mail room of a large corporation. Megan and Brittney are the veteran mail room employees assigned to train her. As Debra learns to sort mail Megan and Brittney answer her occasional question but spend most of their time drinking coffee and sharing erotic tales about various employees in the different divisions of their corporation. 
It seems those two girls have all the dirt on who's doing who at the office.", 159 | "meta": { 160 | "title": "Intimate Touch", 161 | "runtime": 76.0, 162 | "vote_average": 1.0, 163 | "release_date": "2001-01-01", 164 | "genres": ["Drama"] 165 | } 166 | }, 167 | { 168 | "id": "613153", 169 | "content": "This short super 8 experimental film is composed of three elements: footage taken by LaBruce in a mosh pit at a Toronto hardcore punk show featuring the bands Scream The Mr. T. Experience and MDC; found footage of porn star Al Parker having penetrative anal sex; and a soundtrack by The Carpenters.", 170 | "meta": { 171 | "title": "Slam!", 172 | "runtime": 12.0, 173 | "vote_average": 0.0, 174 | "release_date": "1992-11-04", 175 | "genres": ["Documentary"] 176 | } 177 | }, 178 | { 179 | "id": "624341", 180 | "content": "Apocalypse at the Little Bighorn Custer's final battle.", 181 | "meta": { 182 | "title": "Contested Ground", 183 | "runtime": 47.0, 184 | "vote_average": 0.0, 185 | "release_date": "2016-01-27", 186 | "genres": ["Documentary", "History"] 187 | } 188 | }, 189 | { 190 | "id": "853959", 191 | "content": "A young lady baby sits for a mysterious woman and wakes up in a hospital with vivid memories of an occurrence she can't prove happened. It's left to an equally mysterious chain smoking detective to find out the truth. What happened? 
Did it happen?", 192 | "meta": { 193 | "title": "Eje", 194 | "runtime": 15.0, 195 | "vote_average": 0.0, 196 | "release_date": "2021-07-22", 197 | "genres": ["Horror", "Fantasy"] 198 | } 199 | }, 200 | { 201 | "id": "776697", 202 | "content": "After his father gets into a fight at a bowling alley Darious begins to investigate the limitations of his own manhood.", 203 | "meta": { 204 | "title": "Bruiser", 205 | "runtime": 10.0, 206 | "vote_average": 0.0, 207 | "release_date": "2021-03-16", 208 | "genres": ["Drama"] 209 | } 210 | }, 211 | { 212 | "id": "357370", 213 | "content": "The 23rd issue of the long running industry cinemagazine. Features the articles: 'Safety First' 'Paying For It' and ' A Star Drops In'.", 214 | "meta": { 215 | "title": "Mining Review 2nd Year No. 11", 216 | "runtime": 10.0, 217 | "vote_average": 0.0, 218 | "release_date": "1949-07-01", 219 | "genres": ["Documentary"] 220 | } 221 | }, 222 | { 223 | "id": "541018", 224 | "content": "From instilling fear in asylum inmates with brutal treatments to the modern application of restraint senseless drugging and electroshock psychiatrists have a long and hidden history of force intimidation and outright terror. 
This is where Germany enters the picture for it was here that psychiatry was born here where psychiatry was nurtured and grew and here where psychiatry would commit one of the world's most horrific atrocities.", 225 | "meta": { 226 | "title": "Age of Fear: Psychiatry's Reign of Terror", 227 | "runtime": 0.0, 228 | "vote_average": 0.0, 229 | "release_date": "2012-01-01", 230 | "genres": ["Documentary"] 231 | } 232 | }, 233 | { 234 | "id": "812977", 235 | "content": "With lush visual effects and evocative soundscape this movement tape is a poetic exploration of one man's fall from grace as he becomes disillusioned with gay ideals of beauty.", 236 | "meta": { 237 | "title": "Angel", 238 | "runtime": 5.0, 239 | "vote_average": 0.0, 240 | "release_date": "1998-08-09", 241 | "genres": ["Music"] 242 | } 243 | }, 244 | { 245 | "id": "467915", 246 | "content": "In Dublin two couples (Jim and Danielle; Yvonne and Chris) are seemingly living in marital bliss. However when Chris's behaviour begins to change Yvonne seeks solace in the arms of Jim and before long they are in the midst of an affair. When a life-changing secret is later revealed all four are forced to re-evaluate their lives their marriages and their friendships - but can anything be salvaged from the wreckage?", 247 | "meta": { 248 | "title": "The Delinquent Season", 249 | "runtime": 103.0, 250 | "vote_average": 6.6, 251 | "release_date": "2018-04-27", 252 | "genres": ["Drama", "Romance"] 253 | } 254 | }, 255 | { 256 | "id": "292767", 257 | "content": "An ex-convict tries to make an honest living and take care of his girlfriend and her mentally slow brother.", 258 | "meta": { 259 | "title": "The Clean and Narrow", 260 | "runtime": 79.0, 261 | "vote_average": 0.0, 262 | "release_date": "2000-01-21", 263 | "genres": ["Drama"] 264 | } 265 | }, 266 | { 267 | "id": "578681", 268 | "content": "Not everything that happens on a school playground is play. 
Short film about LGBT effects in school.", 269 | "meta": { 270 | "title": "They Say", 271 | "runtime": 19.0, 272 | "vote_average": 7.0, 273 | "release_date": "2011-03-12", 274 | "genres": ["Drama"] 275 | } 276 | }, 277 | { 278 | "id": "924525", 279 | "content": "On her birthday Carmela a 12-year-old girl is forced to meet her father in a family meeting centre due to the case of gender violence that he has brought up against her mother.", 280 | "meta": { 281 | "title": "Fed Up", 282 | "runtime": 23.0, 283 | "vote_average": 0.0, 284 | "release_date": "2021-12-03", 285 | "genres": ["Drama"] 286 | } 287 | }, 288 | { 289 | "id": "714270", 290 | "content": "'Chrono' is a short film dipped in red manipulating the concepts of time and isolation.", 291 | "meta": { 292 | "title": "Chrono", 293 | "runtime": 2.0, 294 | "vote_average": 0.0, 295 | "release_date": "2020-06-10", 296 | "genres": ["Science Fiction", "Thriller", "Horror"] 297 | } 298 | }, 299 | { 300 | "id": "363372", 301 | "content": "Based on the Buster Brown comic by R.F. 
Outcault.", 302 | "meta": { 303 | "title": "At the Dentist", 304 | "runtime": 3.0, 305 | "vote_average": 5.0, 306 | "release_date": "1918-02-25", 307 | "genres": ["Animation", "Comedy"] 308 | } 309 | }, 310 | { 311 | "id": "356229", 312 | "content": "Senza Peso a mini Opera created by world-class musicians and artists unveils a beautifully dark world of lost souls and redemption.", 313 | "meta": { 314 | "title": "Senza Peso", 315 | "runtime": 10.0, 316 | "vote_average": 0.0, 317 | "release_date": "2014-06-03", 318 | "genres": ["Drama", "Music"] 319 | } 320 | }, 321 | { 322 | "id": "939639", 323 | "content": "After a stock goes bust (cheese stock that\u2019s \u201cnot supposed to break it\u2019s cheese\u201d) Bob Pauley and Robert Russo must \u2018light up\u2019 and reconcile with their past each other and the situation with their wives.", 324 | "meta": { 325 | "title": "A Couple Of Regular Fellas", 326 | "runtime": 140.0, 327 | "vote_average": 10.0, 328 | "release_date": "2022-02-13", 329 | "genres": ["Drama", "Comedy"] 330 | } 331 | }, 332 | { 333 | "id": "942206", 334 | "content": "The solution to a murder hinges on two witnesses: a deaf woman and a blind man.", 335 | "meta": { 336 | "title": "A Voice in the Dark", 337 | "runtime": 50.0, 338 | "vote_average": 0.0, 339 | "release_date": "1921-06-05", 340 | "genres": ["Mystery"] 341 | } 342 | }, 343 | { 344 | "id": "498535", 345 | "content": "The repercussions are immediate when activist history professor Arthur Collins - out of frustration - assigns his class to design a plan to assassinate the President of the United States. Life becomes particularly tough for the professor's teaching assistant T.K. 
whose reluctance to choose a side plants the young man right in the middle.", 346 | "meta": { 347 | "title": "The Paradigm Shift", 348 | "runtime": 26.0, 349 | "vote_average": 0.0, 350 | "release_date": "2008-02-15", 351 | "genres": ["Comedy", "Drama"] 352 | } 353 | }, 354 | { 355 | "id": "807484", 356 | "content": "A gang member a hustler and a small-time dealer discover that walking out of the prison gates is just the beginning of the journey to go straight.", 357 | "meta": { 358 | "title": "A Hard Straight", 359 | "runtime": 75.0, 360 | "vote_average": 0.0, 361 | "release_date": "2004-03-18", 362 | "genres": ["Documentary"] 363 | } 364 | }, 365 | { 366 | "id": "589366", 367 | "content": "Director Jeanie Finlay charts a transgender man's path to parenthood after he decides to carry his child himself. The pregnancy prompts an unexpected and profound reckoning with conventions of masculinity self-definition and biology", 368 | "meta": { 369 | "title": "Seahorse: The Dad Who Gave Birth", 370 | "runtime": 91.0, 371 | "vote_average": 6.5, 372 | "release_date": "2020-01-28", 373 | "genres": ["Documentary"] 374 | } 375 | }, 376 | { 377 | "id": "624669", 378 | "content": "A film student robs a bank under the guise of shooting a short film about a bank robbery.", 379 | "meta": { 380 | "title": "Bank Robbery Student Film", 381 | "runtime": 10.0, 382 | "vote_average": 0.0, 383 | "release_date": "2018-12-09", 384 | "genres": ["Comedy", "Crime"] 385 | } 386 | }, 387 | { 388 | "id": "926665", 389 | "content": "Two hikers find themselves in a reluctant cuddle.", 390 | "meta": { 391 | "title": "A Disappointing Excursion", 392 | "runtime": 4.0, 393 | "vote_average": 0.0, 394 | "release_date": "2017-01-01", 395 | "genres": ["Animation"] 396 | } 397 | }, 398 | { 399 | "id": "156994", 400 | "content": "A soldier is forced to take estrogen and wear lingerie when he's blackmailed by a violent transvestite.", 401 | "meta": { 402 | "title": "She-Man: A Story of Fixation", 403 | 
"runtime": 68.0, 404 | "vote_average": 5.1, 405 | "release_date": "1967-09-01", 406 | "genres": ["Comedy", "Drama"] 407 | } 408 | }, 409 | { 410 | "id": "769419", 411 | "content": "An ex-marine in search of some defining life direction unexpectedly faces a hurdle when his jeep breaks down in a sparsely populated Wyoming town. While waiting for parts to repair the vehicle he takes a temporary job as a rest area attendant and moves onto a ranch run by a one-legged widow with a mentally handicapped son. When an abandoned Vietnamese girl joins the group the scene is ready for confrontation and tragedy.", 412 | "meta": { 413 | "title": "Follow Your Heart", 414 | "runtime": 95.0, 415 | "vote_average": 0.0, 416 | "release_date": "1990-04-02", 417 | "genres": ["Drama", "TV Movie"] 418 | } 419 | }, 420 | { 421 | "id": "665502", 422 | "content": "2018 will be a long remember year in the World Rally Championship. Not only because the new generation WRCs are so spectacular but also because it was exciting until the last round in Australia. All four manufacturers could win at least one of the rounds and with Neuville Ogier and T\u00e4nak three drivers on three different cars had a chance to take the drivers\u2019 title when the championship arrived in Australia. Not to forget the sensational part-time comeback of S\u00e9bastien Loeb and Turkey hosting a new. Visually amazing event. This official film shows the highlights from all 13 rounds of the World Rally Championship 2018.", 423 | "meta": { 424 | "title": "WRC 2018 - FIA World Rally Championship", 425 | "runtime": 476.0, 426 | "vote_average": 0.0, 427 | "release_date": "2018-12-31", 428 | "genres": ["Documentary"] 429 | } 430 | }, 431 | { 432 | "id": "333130", 433 | "content": "Marcus Chris and Evan find themselves stuck in any normal day when without warning zombies begin to infest the city. Not knowing exactly what to do the three attempt to seek shelter at Marcus' house but first they have to get there. 
Facing the undead Marcus Chris and Evan try to survive the zombies and each other.", 434 | "meta": { 435 | "title": "Better Off Undead", 436 | "runtime": 29.0, 437 | "vote_average": 0.0, 438 | "release_date": "2007-04-17", 439 | "genres": ["Horror", "Comedy"] 440 | } 441 | } 442 | ] 443 | --------------------------------------------------------------------------------