├── .gitignore ├── executor ├── __init__.py ├── commons.py ├── postgres_indexer.py ├── hnswlib_searcher.py ├── hnswpsql.py └── postgreshandler.py ├── tests ├── docker_args.txt ├── docker-compose.yml ├── conftest.py ├── unit │ └── test_basic.py └── integration │ ├── test_sync.py │ └── test_hnsw_psql.py ├── config.yml ├── requirements.txt ├── .pre-commit-config.yaml ├── manifest.yml ├── .github └── workflows │ ├── ci.yml │ └── cd.yml ├── Dockerfile └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | .pytest_cache/ 3 | .jina/ 4 | -------------------------------------------------------------------------------- /executor/__init__.py: -------------------------------------------------------------------------------- 1 | from .hnswpsql import HNSWPostgresIndexer 2 | -------------------------------------------------------------------------------- /tests/docker_args.txt: -------------------------------------------------------------------------------- 1 | --uses-with dry_run=True startup_sync=False -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | jtype: HNSWPostgresIndexer 2 | metas: 3 | py_modules: 4 | - executor/__init__.py 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | psycopg2-binary>=2.8.6 2 | pytest-mock 3 | hnswlib>=0.6.0 4 | bidict 5 | pre-commit 6 | backports.zoneinfo 7 | tqdm 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | types: [python] 7 | args: 8 | - -S -------------------------------------------------------------------------------- /manifest.yml: -------------------------------------------------------------------------------- 1 | manifest_version: 1 2 | name: HNSWPostgresIndexer 3 | description: A complete indexer based on HNSW and PostgreSQL 4 | url: https://github.com/jina-ai/executor-hnsw-postgres 5 | keywords: [hnswlib, postgres, indexer, ann, hnsw] 6 | -------------------------------------------------------------------------------- /tests/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | psql: 4 | image: postgres:13.2 5 | ports: 6 | - "5432:5432" 7 | expose: 8 | - 10000-60000 9 | environment: 10 | - POSTGRES_PASSWORD=123456 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | linting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: install black 10 | run: pip install black==22.3.0 11 | - name: Black check formatting 12 | run: black --check -S . 
13 | call-external: 14 | uses: jina-ai/workflows-executors/.github/workflows/ci.yml@master -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | release: 8 | types: 9 | - created 10 | workflow_dispatch: 11 | # pull_request: 12 | # uncomment the above to test CD in a PR 13 | 14 | jobs: 15 | call-external: 16 | uses: jina-ai/workflows-executors/.github/workflows/cd.yml@master 17 | with: 18 | event_name: ${{ github.event_name }} 19 | secrets: 20 | jinahub_token: ${{ secrets.JINAHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # DO NOT DELETE THIS DOCKERFILE 2 | # required because hnsw requires a specific gcc 3 | 4 | FROM jinaai/jina:3-py37-perf 5 | 6 | COPY . /workspace 7 | WORKDIR /workspace 8 | 9 | # install GCC compiler 10 | RUN apt-get update && apt-get install --no-install-recommends -y build-essential git \ 11 | && rm -rf /var/lib/apt/lists/* 12 | 13 | # install the third-party requirements 14 | RUN pip install --compile --no-cache-dir \ 15 | -r requirements.txt 16 | 17 | ENTRYPOINT ["jina", "executor", "--uses", "config.yml"] 18 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import pytest 6 | from jina import Document 7 | 8 | 9 | @pytest.fixture() 10 | def docker_compose(request): 11 | os.system( 12 | f'docker-compose -f {request.param} --project-directory . up --build -d ' 13 | f'--remove-orphans' 14 | ) 15 | time.sleep(5) 16 | yield 17 | os.system( 18 | f'docker-compose -f {request.param} --project-directory . 
down ' 19 | f'--remove-orphans' 20 | ) 21 | 22 | 23 | @pytest.fixture() 24 | def get_documents(): 25 | def get_documents_inner(nr=10, index_start=0, emb_size=7): 26 | random_batch = np.random.random([nr, emb_size]).astype(np.float32) 27 | for i in range(index_start, nr + index_start): 28 | d = Document() 29 | d.id = f'aa{i}' # to test it supports non-int ids 30 | d.embedding = random_batch[i - index_start] 31 | yield d 32 | 33 | return get_documents_inner 34 | 35 | 36 | @pytest.fixture() 37 | def runtime_args(): 38 | return {'shard_id': 0, 'replica_id': 0, 'shards': 1} 39 | -------------------------------------------------------------------------------- /tests/unit/test_basic.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os.path 3 | 4 | import pytest 5 | from executor.hnswpsql import HNSWPostgresIndexer, HnswlibSearcher, PostgreSQLStorage 6 | from jina import DocumentArray 7 | 8 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 9 | compose_yml = os.path.abspath(os.path.join( 10 | cur_dir, '..', 'docker-compose.yml')) 11 | 12 | 13 | def test_basic(runtime_args): 14 | indexer = HNSWPostgresIndexer( 15 | dry_run=True, startup_sync=False, runtime_args=runtime_args 16 | ) 17 | assert isinstance(indexer._vec_indexer, HnswlibSearcher) 18 | assert isinstance(indexer._kv_indexer, PostgreSQLStorage) 19 | assert indexer._init_kwargs is not None 20 | status = indexer.status() 21 | assert status['psql_docs'] is None 22 | assert status['hnsw_docs'] == 0.0 # protobuf converts ints to floats 23 | assert datetime.datetime.fromisoformat( 24 | status['last_sync'] 25 | ) == datetime.datetime.fromtimestamp(0, datetime.timezone.utc) 26 | 27 | 28 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 29 | def test_docker(docker_compose, get_documents, runtime_args): 30 | emb_size = 10 31 | 32 | docs = DocumentArray(get_documents(emb_size=emb_size)) 33 | 34 | indexer = HNSWPostgresIndexer(dim=emb_size, runtime_args=runtime_args) 35 | assert isinstance(indexer._vec_indexer, HnswlibSearcher) 36 | assert isinstance(indexer._kv_indexer, PostgreSQLStorage) 37 | assert indexer._init_kwargs is not None 38 | # test for empty sync from psql 39 | indexer.sync({}) 40 | 41 | indexer.index(docs, {}) 42 | 43 | search_docs = DocumentArray(get_documents( 44 | index_start=len(docs), emb_size=emb_size)) 45 | indexer.search(search_docs, {}) 46 | assert len(search_docs[0].matches) == 0 47 | 48 | indexer.sync({}) 49 | indexer.search(search_docs, {}) 50 | assert len(search_docs[0].matches) > 0 51 | 52 | indexer.clear() 53 | status = indexer.status() 54 | assert status['psql_docs'] == 0 55 | assert status['hnsw_docs'] == 0 56 | -------------------------------------------------------------------------------- /tests/integration/test_sync.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os.path 3 | import time 4 | 5 | import pytest 6 | from executor.hnswpsql import HNSWPostgresIndexer 7 | from jina import DocumentArray, Flow 8 | 9 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 10 | compose_yml = os.path.abspath(os.path.join( 11 | cur_dir, '..', 'docker-compose.yml')) 12 | 13 | METRIC = 'cosine' 14 | 15 | SHARDS = 2 16 | 17 | 18 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 19 | def test_sync(docker_compose, get_documents): 20 | def verify_status(f, expected_size_min): 21 | results = f.post('/status', None, 
return_responses=True) 22 | status_results = results[0].parameters["__results__"] 23 | assert len(status_results.values()) == SHARDS 24 | nr_hnsw_docs = sum(v['hnsw_docs'] for v in status_results.values()) 25 | for status in status_results.values(): 26 | psql_docs = int(status['psql_docs']) 27 | assert psql_docs >= expected_size_min 28 | assert int(nr_hnsw_docs) >= expected_size_min 29 | last_sync_raw = list(status_results.values())[0]['last_sync'] 30 | last_sync_timestamp = datetime.datetime.fromisoformat(last_sync_raw) 31 | return nr_hnsw_docs, last_sync_timestamp 32 | 33 | emb_size = 10 34 | nr_docs_batch = 3 35 | nr_runs = 4 36 | 37 | uses_with = {'dim': emb_size, 'sync_interval': 5} 38 | search_docs = DocumentArray( 39 | get_documents(index_start=nr_docs_batch * 40 | (nr_runs + 1), emb_size=emb_size) 41 | ) 42 | 43 | f = Flow().add( 44 | uses=HNSWPostgresIndexer, 45 | uses_with=uses_with, 46 | shards=SHARDS, 47 | polling='all' 48 | ) 49 | 50 | with f: 51 | nr_indexed_docs, last_sync_timestamp = verify_status(f, 0) 52 | 53 | for i in range(nr_runs): 54 | docs = get_documents( 55 | nr=nr_docs_batch, index_start=i * nr_docs_batch, emb_size=emb_size 56 | ) 57 | 58 | f.post('/index', docs) 59 | 60 | got_updated_docs = False 61 | for _ in range(50): 62 | search_docs = f.post( 63 | '/search', search_docs) 64 | assert len(search_docs[0].matches) >= nr_indexed_docs 65 | nr_indexed_docs, last_sync_timestamp = verify_status( 66 | f, nr_indexed_docs) 67 | if nr_indexed_docs == (i + 1) * nr_docs_batch: 68 | got_updated_docs = True 69 | break 70 | time.sleep(0.2) 71 | assert got_updated_docs 72 | -------------------------------------------------------------------------------- /executor/commons.py: -------------------------------------------------------------------------------- 1 | from typing import Generator,Tuple,Union,Optional,BinaryIO, TextIO 2 | import numpy as np 3 | import os 4 | import sys 5 | from tqdm import tqdm 6 | 7 | GENERATOR_TYPE = Generator[ 8 | Tuple[str, Union[np.ndarray, bytes], Optional[bytes]], None, None 9 | ] 10 | 11 | EMPTY_BYTES = b'' 12 | 13 | BYTE_PADDING = 4 14 | DUMP_DTYPE = np.float64 15 | 16 | 17 | def export_dump_streaming( 18 | path: str, 19 | shards: int, 20 | size: int, 21 | data: GENERATOR_TYPE, 22 | logger, 23 | ): 24 | """Export the data to a path, based on sharding, 25 | :param path: path to dump 26 | :param shards: the nr of shards this pea is part of 27 | :param size: total amount of entries 28 | :param data: the generator of the data (ids, vectors, metadata) 29 | """ 30 | logger.info(f'Dumping {size} docs to {path} for {shards} shards') 31 | _handle_dump(data, path, shards, size) 32 | 33 | 34 | def _handle_dump( 35 | data: GENERATOR_TYPE, 36 | path: str, 37 | shards: int, 38 | size: int, 39 | ): 40 | if not os.path.exists(path): 41 | os.makedirs(path) 42 | 43 | # directory must be empty to be safe 44 | if not os.listdir(path): 45 | size_per_shard = size // shards 46 | extra = size % shards 47 | shard_range = list(range(shards)) 48 | for shard_id in shard_range: 49 | if shard_id == shard_range[-1]: 50 | size_this_shard = size_per_shard + extra 51 | else: 52 | size_this_shard = size_per_shard 53 | _write_shard_data(data, path, shard_id, size_this_shard) 54 | else: 55 | raise Exception( 56 | f'path for dump {path} contains data. Please empty. Not dumping...' 
57 |         )
58 | 
59 | 
60 | def _write_shard_data(
61 |     data: GENERATOR_TYPE,
62 |     path: str,
63 |     shard_id: int,
64 |     size_this_shard: int,
65 | ):
66 |     shard_path = os.path.join(path, str(shard_id))
67 |     shard_docs_written = 0
68 |     os.makedirs(shard_path)
69 |     vectors_fp, metas_fp, ids_fp = _get_file_paths(shard_path)
70 |     with open(vectors_fp, 'wb') as vectors_fh, open(metas_fp, 'wb') as metas_fh, open(
71 |         ids_fp, 'w'
72 |     ) as ids_fh:
73 |         progress = tqdm(total=size_this_shard)
74 |         while shard_docs_written < size_this_shard:
75 |             _write_shard_files(data, ids_fh, metas_fh, vectors_fh)
76 |             shard_docs_written += 1
77 |             progress.update(1)
78 |         progress.close()
79 | 
80 | 
81 | def _write_shard_files(
82 |     data: GENERATOR_TYPE,
83 |     ids_fh: TextIO,
84 |     metas_fh: BinaryIO,
85 |     vectors_fh: BinaryIO,
86 | ):
87 |     id_, vec, meta = next(data)
88 |     # need to ensure compatibility to read time
89 |     if vec is None:
90 |         vec = EMPTY_BYTES
91 |     if isinstance(vec, np.ndarray):
92 |         if vec.dtype != DUMP_DTYPE:
93 |             vec = vec.astype(DUMP_DTYPE)
94 |         vec = vec.tobytes()
95 |     vectors_fh.write(len(vec).to_bytes(BYTE_PADDING, sys.byteorder) + vec)
96 |     if meta is None:
97 |         meta = EMPTY_BYTES
98 |     metas_fh.write(len(meta).to_bytes(BYTE_PADDING, sys.byteorder) + meta)
99 |     ids_fh.write(id_ + '\n')
100 | 
101 | def _get_file_paths(shard_path: str):
102 |     vectors_fp = os.path.join(shard_path, 'vectors')
103 |     metas_fp = os.path.join(shard_path, 'metas')
104 |     ids_fp = os.path.join(shard_path, 'ids')
105 |     return vectors_fp, metas_fp, ids_fp
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌟 HNSW + PostgreSQL Indexer
2 | 
3 | HNSWPostgresIndexer is a production-ready, scalable Indexer for the Jina neural search framework.
4 | 
5 | It combines the reliability of PostgreSQL with the speed and efficiency of the HNSWlib nearest neighbor library.
6 | 
7 | It thus provides all the CRUD operations expected of a database system, while also offering fast and reliable vector lookup.
8 | 
9 | **Requires** a running PostgreSQL database service. For quick testing, you can run a containerized version locally with:
10 | 
11 | `docker run -e POSTGRES_PASSWORD=123456 -p 127.0.0.1:5432:5432/tcp postgres:13.2`
12 | 
13 | ## Syncing between PSQL and HNSW
14 | 
15 | By default, all data is stored in a PSQL database (as defined in the arguments).
16 | In order to add your data to / build an HNSW index, you need to manually call the `/sync` endpoint.
17 | This iterates through the data you have stored and adds it to the HNSW index.
18 | By default, this is done incrementally, on top of whatever data the HNSW index already has.
19 | If you want to completely rebuild the index, use the parameter `rebuild`, like so:
20 | 
21 | ```python
22 | flow.post(on='/sync', parameters={'rebuild': True})
23 | ```
24 | 
25 | At start-up time, the data from PSQL is synced into HNSW automatically.
26 | You can disable this with:
27 | 
28 | ```python
29 | Flow().add(
30 |     uses='jinahub://HNSWPostgresIndexer',
31 |     uses_with={'startup_sync': False}
32 | )
33 | ```
34 | 
35 | ### Automatic background syncing
36 | 
37 | **⚠ WARNING: Experimental feature**
38 | 
39 | Optionally, you can enable automatic background syncing of the data into HNSW.
40 | This creates a background thread, running alongside the main operations, that regularly performs the synchronization.
41 | This can be done with the `sync_interval` constructor argument, like so:
42 | 
43 | ```python
44 | Flow().add(
45 |     uses='jinahub://HNSWPostgresIndexer',
46 |     uses_with={'sync_interval': 5}
47 | )
48 | ```
49 | 
50 | The `sync_interval` argument accepts an integer representing the number of seconds to wait between synchronization attempts.
51 | This should be adjusted based on your specific data volume.
52 | For the duration of the background sync, the HNSW index will be locked to avoid invalid state, so searching will be queued.
53 | The same applies during search operations: the index is locked, and background syncing will be queued until the search completes.
54 | 
55 | ## CRUD operations
56 | 
57 | You can perform all the usual operations on the respective endpoints:
58 | 
59 | - `/index`. Add new data to PostgreSQL.
60 | - `/search`. Query the HNSW index with your Documents.
61 | - `/update`. Update documents in PostgreSQL.
62 | - `/delete`. Delete documents in PostgreSQL.
63 | 
64 | **Note**: `/delete` only performs soft-deletion by default.
65 | This is done in order not to break the look-up of a Document's id after a search.
66 | For a hard delete, add `'soft_delete': False` to the `parameters` of the delete request.
67 | You might also want to perform a cleanup after a full rebuild of the HNSW index, by calling `/cleanup`.
68 | 
69 | ## Status endpoint
70 | 
71 | You can also get information about the status of your data via the `/status` endpoint.
72 | This returns a `dict` containing the relevant information.
73 | The information can be accessed via the following keys in the `parameters.__results__` of a full flow response:
74 | 
75 | - `'psql_docs'`: number of Documents stored in the PSQL database (includes entries that have been "soft-deleted")
76 | - `'hnsw_docs'`: the number of Documents indexed in the HNSW index
77 | - `'last_sync'`: the time of the last synchronization of PSQL into HNSW
78 | - `'shard_id'`: the shard number
79 | 
80 | In a sharded environment (`shards>1`) you will get one `dict` from each shard.
81 | Each shard will have its own `'hnsw_docs'`, `'last_sync'`, `'shard_id'`, but they will all report the same `'psql_docs'`
82 | (the PSQL database is shared by all your shards).
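For illustration, with two shards the `__results__` mapping might look like this (hypothetical values; the exact `executor.../shard-.../rep-...` keys depend on your Flow):

```python
{
    'executor0/shard-0/rep-0': {'psql_docs': 1000.0, 'hnsw_docs': 500.0,
                                'last_sync': '1970-01-01T00:00:00+00:00', 'shard_id': 0.0},
    'executor0/shard-1/rep-0': {'psql_docs': 1000.0, 'hnsw_docs': 500.0,
                                'last_sync': '1970-01-01T00:00:00+00:00', 'shard_id': 1.0},
}
```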
83 | You need to sum the `'hnsw_docs'` across these dictionaries, like so 84 | 85 | ```python 86 | results = f.post('/status', None, return_responses=True) 87 | status_results = results[0].parameters["__results__"] 88 | total_hnsw_docs = sum(v['hnsw_docs'] for v in status_results.values()) 89 | ``` 90 | -------------------------------------------------------------------------------- /tests/integration/test_hnsw_psql.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os.path 3 | from collections import OrderedDict 4 | from typing import Dict 5 | 6 | import pytest 7 | from executor.hnswpsql import HNSWPostgresIndexer 8 | from jina import DocumentArray, Flow, Executor, requests 9 | from jina.logging.profile import TimeContext 10 | 11 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 12 | compose_yml = os.path.abspath(os.path.join( 13 | cur_dir, '..', 'docker-compose.yml')) 14 | 15 | METRIC = 'cosine' 16 | 17 | 18 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 19 | def test_basic_integration(docker_compose, get_documents): 20 | emb_size = 10 21 | nr_docs = 299 22 | 23 | docs = DocumentArray(get_documents(nr=nr_docs, emb_size=emb_size)) 24 | 25 | f = Flow().add( 26 | uses=HNSWPostgresIndexer, 27 | uses_with={ 28 | 'dim': emb_size, 29 | }, 30 | shards=1, 31 | # this will lead to warnings on PSQL for clashing ids 32 | # but required in order for the query request is sent 33 | # to all the shards 34 | polling='all', 35 | ) 36 | 37 | with f: 38 | # test for empty sync from psql 39 | f.post( 40 | '/sync', 41 | ) 42 | results = f.post('/status', None, return_responses=True) 43 | status_results = results[0].parameters["__results__"] 44 | first_hnsw_docs = sum(v['hnsw_docs'] for v in status_results.values()) 45 | for a_status in status_results.values(): 46 | assert int(a_status['psql_docs']) == 0 47 | assert int(first_hnsw_docs) == 0 48 | 49 | status = list(status_results.values())[0]['last_sync'] 50 | last_sync_timestamp = datetime.datetime.fromisoformat(status) 51 | 52 | f.post('/index', docs) 53 | 54 | search_docs = DocumentArray( 55 | get_documents(index_start=len(docs), emb_size=emb_size) 56 | ) 57 | 58 | search_docs = f.post('/search', search_docs) 59 | assert len(search_docs[0].matches) == 0 60 | 61 | f.post( 62 | '/sync', 63 | ) 64 | search_docs = f.post('/search', search_docs) 65 | assert len(search_docs[0].matches) > 0 66 | 67 | results2 = f.post('/status', None, return_responses=True) 68 | status_results2 = results2[0].parameters["__results__"] 69 | new_hnsw_docs = sum(v['hnsw_docs'] for v in status_results2.values()) 70 | assert new_hnsw_docs > first_hnsw_docs 71 | assert int(new_hnsw_docs) == len(docs) 72 | for a_status in status_results2.values(): 73 | assert int(a_status['psql_docs']) == len(docs) 74 | status2 = list(status_results2.values())[0]['last_sync'] 75 | last_sync = datetime.datetime.fromisoformat(status2) 76 | assert last_sync > last_sync_timestamp 77 | 78 | 79 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 80 | @pytest.mark.parametrize('nr_docs', [100]) 81 | @pytest.mark.parametrize('nr_search_docs', [10]) 82 | @pytest.mark.parametrize('emb_size', [10]) 83 | def test_replicas_integration( 84 | docker_compose, get_documents, nr_docs, nr_search_docs, emb_size, benchmark=False 85 | ): 86 | LIMIT = 10 87 | NR_SHARDS = 2 88 | # FIXME: rolling_update is deprecated in latest jina core, then replicas > 1 cannot pass the test. 
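    # replicas is pinned to 1 below because rolling_update was removed from jina core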
89 | NR_REPLICAS = 1 90 | docs = get_documents(nr=nr_docs, emb_size=emb_size) 91 | 92 | uses_with = {'dim': emb_size, 'limit': LIMIT, 'mute_unique_warnings': True} 93 | 94 | f = Flow().add( 95 | name='indexer', 96 | uses=HNSWPostgresIndexer, 97 | uses_with=uses_with, 98 | shards=NR_SHARDS, 99 | replicas=NR_REPLICAS, 100 | # this will lead to warnings on PSQL for clashing ids 101 | # but required in order for the query request is sent 102 | # to all the shards 103 | polling='all', 104 | timeout_ready=-1, 105 | ) 106 | 107 | with f: 108 | results = f.post('/status', return_responses=True) 109 | status_results = results[0].parameters["__results__"] 110 | for a_status in status_results.values(): 111 | assert int(a_status['psql_docs']) == 0 112 | hnsw_docs = sum(v['hnsw_docs'] for v in status_results.values()) 113 | assert int(hnsw_docs) == 0 114 | 115 | request_size = 100 116 | if benchmark: 117 | request_size = 1000 118 | 119 | with TimeContext(f'indexing {nr_docs}'): 120 | f.post('/index', docs, request_size=request_size) 121 | 122 | status = f.post( 123 | '/status', return_responses=True)[0].parameters["__results__"] 124 | for a_status in status.values(): 125 | assert int(a_status['psql_docs']) == nr_docs 126 | assert int(a_status['hnsw_docs']) == 0 127 | 128 | search_docs = DocumentArray( 129 | get_documents(index_start=nr_docs, 130 | nr=nr_search_docs, emb_size=emb_size) 131 | ) 132 | 133 | if not benchmark: 134 | search_docs = f.post('/search', search_docs) 135 | assert len(search_docs[0].matches) == 0 136 | 137 | # NOTE: "rolling_update" is remove, refer to https://github.com/jina-ai/jina/pull/4517 138 | # with TimeContext(f'rolling update {NR_REPLICAS} replicas x {NR_SHARDS} shards'): 139 | # f.rolling_update(deployment_name='indexer', uses_with=uses_with) 140 | 141 | f.post( 142 | '/sync', 143 | ) 144 | 145 | results2 = f.post('/status', return_responses=True) 146 | status_results2 = results2[0].parameters["__results__"] 147 | for a_status in status_results2.values(): 148 | assert int(a_status['psql_docs']) == nr_docs 149 | hnsw_docs2 = sum(v['hnsw_docs'] for v in status_results2.values()) 150 | assert int(hnsw_docs2) == nr_docs 151 | 152 | with TimeContext(f'search with {nr_search_docs}'): 153 | search_docs = f.post('/search', search_docs) 154 | assert len(search_docs[0].matches) == NR_SHARDS * LIMIT 155 | # FIXME(core): see https://github.com/jina-ai/executor-hnsw-postgres/pull/7 156 | if benchmark: 157 | f.post('/clear') 158 | 159 | 160 | def in_docker(): 161 | """Returns: True if running in a Docker container, else False""" 162 | try: 163 | with open('/proc/1/cgroup', 'rt') as ifh: 164 | if 'docker' in ifh.read(): 165 | print('in docker, skipping benchmark') 166 | return True 167 | return False 168 | except: 169 | return False 170 | 171 | 172 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 173 | def test_benchmark_basic(docker_compose, get_documents): 174 | docs = [1_000, 10_000, 100_000, 1_000_000] 175 | if in_docker() or ('GITHUB_WORKFLOW' in os.environ): 176 | docs.pop() 177 | for nr_docs in docs: 178 | test_replicas_integration( 179 | docker_compose=docker_compose, 180 | nr_docs=nr_docs, 181 | get_documents=get_documents, 182 | nr_search_docs=10, 183 | emb_size=128, 184 | benchmark=True, 185 | ) 186 | 187 | 188 | @pytest.mark.parametrize('docker_compose', [compose_yml], indirect=['docker_compose']) 189 | def test_integration_cleanup(docker_compose, get_documents): 190 | emb_size = 10 191 | docs = DocumentArray(get_documents(nr=100, 
emb_size=emb_size)) 192 | 193 | uses_with = { 194 | 'dim': emb_size, 195 | } 196 | 197 | f = Flow().add( 198 | name='indexer', 199 | uses=HNSWPostgresIndexer, 200 | uses_with=uses_with, 201 | ) 202 | 203 | with f: 204 | f.post('/index', docs) 205 | results = f.post('/status', None, return_responses=True) 206 | status_results = results[0].parameters["__results__"] 207 | for a_status in status_results.values(): 208 | assert int(a_status['psql_docs']) == len(docs) 209 | 210 | # default to soft delete 211 | f.delete(docs) 212 | results2 = f.post('/status', None, return_responses=True) 213 | status_results2 = results2[0].parameters["__results__"] 214 | for a_status in status_results2.values(): 215 | assert int(a_status['psql_docs']) == len(docs) 216 | 217 | f.post(on='/cleanup') 218 | results3 = f.post('/status', None, return_responses=True) 219 | status_results3 = results3[0].parameters["__results__"] 220 | for a_status in status_results3.values(): 221 | assert int(a_status['psql_docs']) == 0 222 | 223 | 224 | # TODO test with update. same ids, diff embeddings, assert embeddings in match has 225 | # changed 226 | -------------------------------------------------------------------------------- /executor/postgres_indexer.py: -------------------------------------------------------------------------------- 1 | __copyright__ = 'Copyright (c) 2021 Jina AI Limited. All rights reserved.' 2 | __license__ = 'Apache-2.0' 3 | 4 | import datetime 5 | from datetime import timezone 6 | from typing import Dict 7 | 8 | try: 9 | from zoneinfo import ZoneInfo 10 | except ImportError: 11 | from backports.zoneinfo import ZoneInfo 12 | 13 | import numpy as np 14 | from jina import Document, DocumentArray 15 | from jina.logging.logger import JinaLogger 16 | 17 | from .commons import export_dump_streaming # this is for local testing 18 | from .postgreshandler import PostgreSQLHandler 19 | 20 | 21 | def doc_without_embedding(d: Document): 22 | new_doc = Document(d, copy=True) 23 | new_doc.embedding = None 24 | return new_doc.to_bytes() 25 | 26 | 27 | class PostgreSQLStorage: 28 | """:class:`PostgreSQLStorage` PostgreSQL-based Storage Indexer.""" 29 | 30 | def __init__( 31 | self, 32 | hostname: str = '127.0.0.1', 33 | port: int = 5432, 34 | username: str = 'postgres', 35 | password: str = '123456', 36 | database: str = 'postgres', 37 | table: str = 'default_table', 38 | max_connections=5, 39 | traversal_paths: str = '@r', 40 | return_embeddings: bool = True, 41 | dry_run: bool = False, 42 | partitions: int = 128, 43 | dump_dtype: type = np.float64, 44 | mute_unique_warnings: bool = False, 45 | *args, 46 | **kwargs, 47 | ): 48 | """ 49 | Initialize the PostgreSQLStorage. 50 | 51 | :param hostname: hostname of the machine 52 | :param port: the port 53 | :param username: the username to authenticate 54 | :param password: the password to authenticate 55 | :param database: the database name 56 | :param table: the table name to use 57 | :param return_embeddings: whether to return embeddings on search or 58 | not 59 | :param dry_run: If True, no database connection will be build. 
60 | :param partitions: the number of shards to distribute 61 | the data (used when rolling update on Searcher side) 62 | :param mute_unique_warnings: whether to mute warnings about unique 63 | ids constraint failing (useful when indexing with shards and 64 | polling = 'all') 65 | """ 66 | self.default_traversal_paths = traversal_paths 67 | self.hostname = hostname 68 | self.port = port 69 | self.username = username 70 | self.password = password 71 | self.database = database 72 | self.table = table 73 | self.logger = JinaLogger('psql_indexer') 74 | self.partitions = partitions 75 | self.handler = PostgreSQLHandler( 76 | hostname=self.hostname, 77 | port=self.port, 78 | username=self.username, 79 | password=self.password, 80 | database=self.database, 81 | table=self.table, 82 | max_connections=max_connections, 83 | dry_run=dry_run, 84 | partitions=partitions, 85 | dump_dtype=dump_dtype, 86 | mute_unique_warnings=mute_unique_warnings, 87 | ) 88 | self.default_return_embeddings = return_embeddings 89 | 90 | @property 91 | def dump_dtype(self): 92 | return self.handler.dump_dtype 93 | 94 | @property 95 | def size(self): 96 | """Obtain the size of the table 97 | 98 | .. # noqa: DAR201 99 | """ 100 | return self.handler.get_size() 101 | 102 | @property 103 | def snapshot_size(self): 104 | """Obtain the size of the table 105 | 106 | .. # noqa: DAR201 107 | """ 108 | return self.handler.get_snapshot_size() 109 | 110 | def add(self, docs: DocumentArray, parameters: Dict, **kwargs): 111 | """Add Documents to Postgres 112 | 113 | :param docs: list of Documents 114 | :param parameters: parameters to the request 115 | """ 116 | if docs is None: 117 | return 118 | traversal_paths = parameters.get( 119 | 'traversal_paths', self.default_traversal_paths 120 | ) 121 | self.handler.add(docs[traversal_paths]) 122 | 123 | def update(self, docs: DocumentArray, parameters: Dict, **kwargs): 124 | """Updated document from the database. 125 | 126 | :param docs: list of Documents 127 | :param parameters: parameters to the request 128 | """ 129 | if docs is None: 130 | return 131 | traversal_paths = parameters.get( 132 | 'traversal_paths', self.default_traversal_paths 133 | ) 134 | self.handler.update(docs[traversal_paths]) 135 | 136 | def cleanup(self, **kwargs): 137 | """ 138 | Full deletion of the entries that 139 | have been marked for soft-deletion 140 | """ 141 | self.handler.cleanup() 142 | 143 | def delete(self, docs: DocumentArray, parameters: Dict, **kwargs): 144 | """Delete document from the database. 
145 | 146 | NOTE: This is a soft-deletion, required by the snapshotting 147 | mechanism in the PSQLFaissCompound 148 | 149 | For a real delete, use the /cleanup endpoint 150 | 151 | :param docs: list of Documents 152 | :param parameters: parameters to the request 153 | """ 154 | if docs is None: 155 | return 156 | traversal_paths = parameters.get( 157 | 'traversal_paths', self.default_traversal_paths 158 | ) 159 | soft_delete = parameters.get('soft_delete', False) 160 | self.handler.delete(docs[traversal_paths], soft_delete) 161 | 162 | def dump(self, parameters: Dict, **kwargs): 163 | """Dump the index 164 | 165 | :param parameters: a dictionary containing the parameters for the dump 166 | """ 167 | path = parameters.get('dump_path') 168 | if path is None: 169 | self.logger.error(f'No "dump_path" provided for {self}') 170 | 171 | shards = int(parameters.get('shards')) 172 | if shards is None: 173 | self.logger.error(f'No "shards" provided for {self}') 174 | 175 | include_metas = parameters.get('include_metas', True) 176 | 177 | export_dump_streaming( 178 | path, 179 | shards=shards, 180 | size=self.size, 181 | data=self.handler.get_generator(include_metas=include_metas), 182 | logger=self.logger, 183 | ) 184 | 185 | def close(self) -> None: 186 | """ 187 | Close the connections in the connection pool 188 | """ 189 | # TODO perhaps store next_shard_to_use? 190 | self.handler.close() 191 | 192 | def search(self, docs: DocumentArray, parameters: Dict, **kwargs): 193 | """Get the Documents by the ids of the docs in the DocArray 194 | 195 | :param docs: the DocumentArray to search 196 | with (they only need to have the `.id` set) 197 | :param parameters: the parameters to this request 198 | """ 199 | if docs is None: 200 | return 201 | traversal_paths = parameters.get( 202 | 'traversal_paths', self.default_traversal_paths 203 | ) 204 | 205 | self.handler.search( 206 | docs[traversal_paths], 207 | return_embeddings=parameters.get( 208 | 'return_embeddings', self.default_return_embeddings 209 | ), 210 | ) 211 | 212 | def snapshot(self, **kwargs): 213 | """ 214 | Create a snapshot duplicate of the current table 215 | """ 216 | # TODO argument with table name, database location 217 | # maybe send to another PSQL instance to avoid perf hit? 218 | self.handler.snapshot() 219 | 220 | def get_snapshot(self, shard_id: int, total_shards: int): 221 | """Get the data meant out of the snapshot, distributed 222 | to this shard id, out of X total shards, based on the virtual 223 | shards allocated. 224 | """ 225 | if self.snapshot_size > 0: 226 | shards_to_get = self._vshards_to_get( 227 | shard_id, total_shards, self.partitions 228 | ) 229 | 230 | return self.handler.get_snapshot(shards_to_get) 231 | else: 232 | self.logger.warning('Not data in PSQL db snapshot. 
Nothing to export...') 233 | return None 234 | 235 | @staticmethod 236 | def _vshards_to_get(shard_id, total_shards, virtual_shards): 237 | if shard_id > total_shards - 1: 238 | raise ValueError('shard_id should be 0-indexed out of range(total_shards)') 239 | vshards = list(range(virtual_shards)) 240 | vshard_part = ( 241 | virtual_shards // total_shards 242 | ) # nr of virtual shards given to one shard 243 | vshard_remainder = virtual_shards % total_shards 244 | if shard_id == total_shards - 1: 245 | shards_to_get = vshards[ 246 | shard_id 247 | * vshard_part : ((shard_id + 1) * vshard_part + vshard_remainder) 248 | ] 249 | else: 250 | shards_to_get = vshards[ 251 | shard_id * vshard_part : (shard_id + 1) * vshard_part 252 | ] 253 | return [str(shard_id) for shard_id in shards_to_get] 254 | 255 | def _get_delta(self, shard_id, total_shards, timestamp: datetime.datetime): 256 | """ 257 | Get the rows that have changed since the last timestamp, per shard 258 | """ 259 | # we assume all db timestamps are UTC +00 260 | try: 261 | timestamp = timestamp.astimezone(ZoneInfo('UTC')) 262 | except ValueError: 263 | pass # year 0 if timestamp is min 264 | if self.size > 0: 265 | shards_to_get = self._vshards_to_get( 266 | shard_id, total_shards, self.partitions 267 | ) 268 | 269 | return self.handler._get_delta(shards_to_get, timestamp) 270 | else: 271 | self.logger.warning('No data in PSQL to sync into HNSW. Skipping') 272 | return None 273 | 274 | @property 275 | def last_snapshot_timestamp(self) -> datetime.datetime: 276 | """ 277 | Get the timestamp of the snapshot 278 | """ 279 | return next(self.handler._get_snapshot_timestamp()) 280 | 281 | @property 282 | def last_timestamp(self) -> datetime.datetime: 283 | """ 284 | Get the latest timestamp of the data 285 | """ 286 | return next(self.handler._get_data_timestamp()) 287 | 288 | def clear(self, **kwargs): 289 | """ 290 | Full deletion of the entries (hard-delete) 291 | :param kwargs: 292 | :return: 293 | """ 294 | self.handler.clear() 295 | 296 | @property 297 | def initialized(self, **kwargs): 298 | """ 299 | Whether the PSQL connection has been initialized 300 | """ 301 | return hasattr(self.handler, 'postgreSQL_pool') 302 | -------------------------------------------------------------------------------- /executor/hnswlib_searcher.py: -------------------------------------------------------------------------------- 1 | __copyright__ = "Copyright (c) 2021 Jina AI Limited. All rights reserved." 2 | __license__ = "Apache-2.0" 3 | 4 | import json 5 | from datetime import datetime, timezone 6 | from typing import Dict, Iterable, Optional, Generator, Tuple, List 7 | 8 | import hnswlib 9 | import numpy as np 10 | from bidict import bidict 11 | from jina import DocumentArray, Document 12 | from jina.logging.logger import JinaLogger 13 | 14 | GENERATOR_DELTA = Generator[ 15 | Tuple[str, Optional[np.ndarray], Optional[datetime]], None, None 16 | ] 17 | 18 | HNSW_TYPE = np.float32 19 | DEFAULT_METRIC = 'cosine' 20 | 21 | 22 | class HnswlibSearcher: 23 | """Hnswlib powered vector indexer. 24 | 25 | This indexer uses the HNSW algorithm to index and search for vectors. It does not 26 | require training, and can be built up incrementally. 
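    A minimal usage sketch (illustrative only; it assumes `numpy` and `jina`
    are available, and the 128-dimensional random vectors are made up for
    this example):

        import numpy as np
        from jina import Document, DocumentArray

        searcher = HnswlibSearcher(dim=128, metric='cosine', limit=5)
        searcher.index(DocumentArray([Document(embedding=np.random.random(128))]))
        queries = DocumentArray([Document(embedding=np.random.random(128))])
        searcher.search(queries)  # attaches matches carrying the id and score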
27 | """ 28 | 29 | def __init__( 30 | self, 31 | limit: int = 10, 32 | metric: str = DEFAULT_METRIC, 33 | dim: int = 0, 34 | max_elements: int = 1_000_000, 35 | ef_construction: int = 200, 36 | ef_query: int = 50, 37 | max_connection: int = 16, 38 | dump_path: Optional[str] = None, 39 | traversal_paths: str = '@r', 40 | is_distance: bool = True, 41 | last_timestamp: datetime = datetime.fromtimestamp(0, timezone.utc), 42 | num_threads: int = -1, 43 | *args, 44 | **kwargs, 45 | ): 46 | """ 47 | :param limit: Number of results to get for each query document in search 48 | :param metric: Distance metric type, can be 'euclidean', 'inner_product', 49 | or 'cosine' 50 | :param dim: The dimensionality of vectors to index 51 | :param max_elements: Maximum number of elements (vectors) to index 52 | :param ef_construction: The construction time/accuracy trade-off 53 | :param ef_query: The query time accuracy/speed trade-off. High is more 54 | accurate but slower 55 | :param max_connection: The maximum number of outgoing connections in the 56 | graph (the "M" parameter) 57 | :param dump_path: The path to the directory from where to load, and where to 58 | save the index state 59 | :param traversal_paths: The default traversal path on docs (used for 60 | indexing, search and update), e.g. '@r', '@c', '@r,c' 61 | :param is_distance: Boolean flag that describes if distance metric need to 62 | be reinterpreted as similarities. 63 | :param last_timestamp: the last time we synced into this HNSW index 64 | :param num_threads: nr of threads to use during indexing. -1 is default 65 | """ 66 | self.limit = limit 67 | self.metric = metric 68 | self.dim = dim 69 | self.max_elements = max_elements 70 | self.traversal_paths = traversal_paths 71 | self.ef_construction = ef_construction 72 | self.ef_query = ef_query 73 | self.max_connection = max_connection 74 | self.dump_path = dump_path 75 | self.is_distance = is_distance 76 | self.last_timestamp = last_timestamp 77 | self.num_threads = num_threads 78 | 79 | self.logger = JinaLogger(self.__class__.__name__) 80 | self._index = hnswlib.Index(space=self.metric_type, dim=self.dim) 81 | 82 | # TODO(Cristian): decide whether to keep this for eventual dump loading 83 | dump_path = self.dump_path or kwargs.get('runtime_args', {}).get( 84 | 'dump_path', None 85 | ) 86 | if dump_path is not None: 87 | self.logger.info('Starting to build HnswlibSearcher from dump data') 88 | 89 | self._index.load_index( 90 | f'{self.dump_path}/index.bin', max_elements=self.max_elements 91 | ) 92 | with open(f'{self.dump_path}/ids.json', 'r') as f: 93 | self._ids_to_inds = bidict(json.load(f)) 94 | 95 | else: 96 | self._init_empty_index() 97 | 98 | self._index.set_ef(self.ef_query) 99 | 100 | def _init_empty_index(self): 101 | self._index.init_index( 102 | max_elements=self.max_elements, 103 | ef_construction=self.ef_construction, 104 | M=self.max_connection, 105 | ) 106 | self._ids_to_inds = bidict() 107 | 108 | def search( 109 | self, docs: Optional[DocumentArray] = None, parameters: Dict = {}, **kwargs 110 | ): 111 | """Attach matches to the Documents in `docs`, each match containing only the 112 | `id` of the matched document and the `score`. 113 | 114 | :param docs: An array of `Documents` that should have the `embedding` property 115 | of the same dimension as vectors in the index 116 | :param parameters: Dictionary with optional parameters that can be used to 117 | override the parameters set at initialization. Supported keys are 118 | `traversal_paths`, `limit` and `ef_query`. 
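        For example, to retrieve at most 5 matches with a higher-recall query
        setting (the values are illustrative):

            searcher.search(docs, parameters={'limit': 5, 'ef_query': 100})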
119 | """ 120 | if docs is None: 121 | return 122 | 123 | traversal_paths = parameters.get('traversal_paths', self.traversal_paths) 124 | docs_search = docs[traversal_paths] 125 | if len(docs_search) == 0: 126 | return 127 | 128 | ef_query = parameters.get('ef_query', self.ef_query) 129 | limit = int(parameters.get('limit', self.limit)) 130 | 131 | self._index.set_ef(ef_query) 132 | 133 | if limit > len(self._ids_to_inds): 134 | limit = len(self._ids_to_inds) 135 | 136 | embeddings_search = docs_search.embeddings 137 | if embeddings_search.shape[1] != self.dim: 138 | raise ValueError( 139 | 'Query documents have embeddings with dimension' 140 | f' {embeddings_search.shape[1]}, which does not match the dimension ' 141 | f'of' 142 | f' the index ({self.dim})' 143 | ) 144 | 145 | indices, dists = self._index.knn_query(docs_search.embeddings, k=limit) 146 | 147 | for i, (indices_i, dists_i) in enumerate(zip(indices, dists)): 148 | for idx, dist in zip(indices_i, dists_i): 149 | match = Document(id=self._ids_to_inds.inverse[idx]) 150 | if self.is_distance: 151 | match.scores[self.metric].value = dist 152 | elif self.metric in ["inner_product", "cosine"]: 153 | match.scores[self.metric].value = 1 - dist 154 | elif self.metric == 'euclidean': 155 | match.scores[self.metric].value = 1 / (1 + dist) 156 | else: 157 | match.scores[self.metric].value = dist 158 | docs_search[i].matches.append(match) 159 | 160 | def index( 161 | self, docs: Optional[DocumentArray] = None, parameters: Dict = {}, **kwargs 162 | ): 163 | """Index the Documents' embeddings. If the document is already in index, it 164 | will be updated. 165 | 166 | :param docs: Documents whose `embedding` to index. 167 | :param parameters: Dictionary with optional parameters that can be used to 168 | override the parameters set at initialization. The only supported key is 169 | `traversal_paths`. 170 | """ 171 | traversal_paths = parameters.get('traversal_paths', self.traversal_paths) 172 | if docs is None: 173 | return 174 | 175 | docs_to_index = docs[traversal_paths] 176 | if len(docs_to_index) == 0: 177 | return 178 | 179 | embeddings = docs_to_index.embeddings 180 | if embeddings.shape[-1] != self.dim: 181 | raise ValueError( 182 | f'Attempted to index vectors with dimension' 183 | f' {embeddings.shape[-1]}, but dimension of index is {self.dim}' 184 | ) 185 | 186 | ids = docs_to_index[:, 'id'] 187 | index_size = self._index.element_count 188 | docs_inds = [] 189 | for id in ids: 190 | if id not in self._ids_to_inds: 191 | docs_inds.append(index_size) 192 | index_size += 1 193 | else: 194 | self.logger.info(f'Document with id {id} already in index, updating.') 195 | docs_inds.append(self._ids_to_inds[id]) 196 | self._add(embeddings, ids, docs_inds) 197 | 198 | def _add(self, embeddings, ids, docs_inds: Optional[List[int]] = None): 199 | if docs_inds is None: 200 | docs_inds = list( 201 | range(self._index.element_count, self._index.element_count + len(ids)) 202 | ) 203 | self._index.add_items(embeddings, ids=docs_inds, num_threads=self.num_threads) 204 | self._ids_to_inds.update({_id: ind for _id, ind in zip(ids, docs_inds)}) 205 | 206 | def update( 207 | self, docs: Optional[DocumentArray] = None, parameters: Dict = {}, **kwargs 208 | ): 209 | """Update the Documents' embeddings. If a Document is not already present in 210 | the index, it will get ignored, and a warning will be raised. 211 | 212 | :param docs: Documents whose `embedding` to update. 
213 | :param parameters: Dictionary with optional parameters that can be used to 214 | override the parameters set at initialization. The only supported key is 215 | `traversal_paths`. 216 | """ 217 | traversal_paths = parameters.get('traversal_paths', self.traversal_paths) 218 | if docs is None: 219 | return 220 | 221 | docs_to_update = docs[traversal_paths] 222 | if len(docs_to_update) == 0: 223 | return 224 | 225 | # TODO(Cristian): don't recreate DA if the ids all exist. 226 | # we are punishing everyone the same 227 | # instead of rewarding people that send updated with existing ids 228 | doc_inds, docs_filtered = [], [] 229 | for doc in docs_to_update: 230 | if doc.id not in self._ids_to_inds: 231 | self.logger.warning( 232 | f'Attempting to update document with id {doc.id} which is not' 233 | ' indexed, skipping. To add documents to index, use the /index' 234 | ' endpoint' 235 | ) 236 | else: 237 | docs_filtered.append(doc) 238 | doc_inds.append(self._ids_to_inds[doc.id]) 239 | docs_filtered = DocumentArray(docs_filtered) 240 | 241 | embeddings = docs_filtered.embeddings 242 | if embeddings.shape[-1] != self.dim: 243 | raise ValueError( 244 | f'Attempted to update vectors with dimension' 245 | f' {embeddings.shape[-1]}, but dimension of index is {self.dim}' 246 | ) 247 | 248 | self._index.add_items(embeddings, ids=doc_inds, num_threads=self.num_threads) 249 | 250 | def delete(self, parameters: Dict, **kwargs): 251 | """Delete entries from the index by id 252 | 253 | :param parameters: parameters to the request. Should contain the list of ids 254 | of entries (Documents) to delete under the `ids` key 255 | """ 256 | deleted_ids = parameters.get('ids', []) 257 | 258 | for _id in set(deleted_ids).intersection(self._ids_to_inds.keys()): 259 | ind = self._ids_to_inds[_id] 260 | self._index.mark_deleted(ind) 261 | del self._ids_to_inds[_id] 262 | 263 | def dump(self, parameters: Dict = {}, **kwargs): 264 | """Save the index and document ids. 265 | 266 | The index and ids will be saved separately for each shard. 267 | 268 | :param parameters: Dictionary with optional parameters that can be used to 269 | override the parameters set at initialization. The only supported key is 270 | `dump_path`. 271 | """ 272 | 273 | dump_path = parameters.get('dump_path', self.dump_path) 274 | if dump_path is None: 275 | raise ValueError( 276 | 'The `dump_path` must be provided to save the indexer state.' 277 | ) 278 | 279 | self._index.save_index(f'{dump_path}/index.bin') 280 | with open(f'{dump_path}/ids.json', 'w') as f: 281 | json.dump(dict(self._ids_to_inds), f) 282 | 283 | def clear(self, **kwargs): 284 | """Clear the index of all entries.""" 285 | self._index = hnswlib.Index(space=self.metric_type, dim=self.dim) 286 | self._init_empty_index() 287 | self._index.set_ef(self.ef_query) 288 | 289 | def status(self) -> Dict: 290 | """Return the status information about the indexer. 291 | 292 | The status will contain information on the total number of indexed and deleted 293 | documents, and on the number of (searchable) documents currently in the index. 
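        The returned dict has the form (illustrative values):

            {'count_deleted': 0, 'count_indexed': 10, 'count_active': 10}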
294 | """ 295 | 296 | status = { 297 | 'count_deleted': self._index.element_count - len(self._ids_to_inds), 298 | 'count_indexed': self._index.element_count, 299 | 'count_active': self.size, 300 | } 301 | return status 302 | 303 | @property 304 | def size(self): 305 | return len(self._ids_to_inds) 306 | 307 | @property 308 | def metric_type(self): 309 | if self.metric == 'euclidean': 310 | metric_type = 'l2' 311 | elif self.metric == 'cosine': 312 | metric_type = 'cosine' 313 | elif self.metric == 'inner_product': 314 | metric_type = 'ip' 315 | 316 | if self.metric not in ['euclidean', 'cosine', 'inner_product']: 317 | self.logger.warning( 318 | f'Invalid distance metric {self.metric} for HNSW index construction! ' 319 | 'Default to euclidean distance' 320 | ) 321 | metric_type = DEFAULT_METRIC 322 | 323 | return metric_type 324 | 325 | def sync(self, delta: GENERATOR_DELTA): 326 | if delta is None: 327 | self.logger.warning('No data received in HNSW.sync. Skipping...') 328 | return 329 | 330 | for doc_id, vec_array, doc_timestamp in delta: 331 | idx = self._ids_to_inds.get(doc_id) 332 | 333 | if (vec_array is not None) and np.isnan(vec_array).any(): 334 | self.logger.error( 335 | f'NaN value contained in the embedding of doc {doc_id}' 336 | ) 337 | continue 338 | 339 | # TODO: performance improvements possible 340 | # instead of creating new Ds and DAs individually 341 | # we can can batch 342 | if idx is None: 343 | if vec_array is None: 344 | continue 345 | vec = vec_array.astype(HNSW_TYPE) 346 | 347 | self._add([vec], [doc_id]) 348 | elif vec_array is None: 349 | self.delete({'ids': [doc_id]}) 350 | else: 351 | vec = vec_array.reshape(1, -1).astype(HNSW_TYPE) 352 | da = DocumentArray(Document(id=doc_id, embedding=vec)) 353 | self.update(da) 354 | 355 | if doc_timestamp > self.last_timestamp: 356 | self.last_timestamp = doc_timestamp 357 | 358 | def index_sync(self, iterator: GENERATOR_DELTA, batch_size=100) -> None: 359 | # there might be new operations on PSQL in the meantime 360 | timestamp = datetime.now(timezone.utc) 361 | if iterator is None: 362 | self.logger.warning('No data received in HNSW.sync. Skipping...') 363 | return 364 | 365 | this_batch_size = 0 366 | # batching 367 | this_batch_embeds = np.zeros((batch_size, self.dim), dtype=HNSW_TYPE) 368 | this_batch_ids = [] 369 | 370 | while True: 371 | try: 372 | doc_id, vec_array, _ = next(iterator) 373 | if vec_array is None or np.isnan(vec_array).any(): 374 | self.logger.error( 375 | f'The doc {doc_id} does not contain embedding, or NaN values are found.' 376 | ) 377 | continue 378 | 379 | vec = vec_array.astype(HNSW_TYPE) 380 | this_batch_embeds[this_batch_size] = vec 381 | this_batch_ids.append(doc_id) 382 | this_batch_size += 1 383 | 384 | if this_batch_size == batch_size: 385 | # do it 386 | # we don't send the 0s 387 | self._add(this_batch_embeds[:this_batch_size], this_batch_ids) 388 | this_batch_size = 0 389 | this_batch_ids = [] 390 | except StopIteration: 391 | if this_batch_size > 0: 392 | self._add(this_batch_embeds[:this_batch_size], this_batch_ids) 393 | break 394 | 395 | self.last_timestamp = timestamp 396 | -------------------------------------------------------------------------------- /executor/hnswpsql.py: -------------------------------------------------------------------------------- 1 | __copyright__ = "Copyright (c) 2021 Jina AI Limited. All rights reserved." 
2 | __license__ = "Apache-2.0" 3 | 4 | import copy 5 | import inspect 6 | import threading 7 | import time 8 | import traceback 9 | from contextlib import nullcontext 10 | from datetime import datetime, timezone 11 | from threading import Thread 12 | from typing import Optional, Tuple, Dict, Union 13 | 14 | import numpy as np 15 | from jina import Executor, requests, DocumentArray, Document 16 | from jina.logging.logger import JinaLogger 17 | 18 | from .hnswlib_searcher import HnswlibSearcher, DEFAULT_METRIC 19 | from .postgres_indexer import PostgreSQLStorage 20 | 21 | 22 | def _get_method_args(): 23 | frame = inspect.currentframe().f_back 24 | keys, _, _, values = inspect.getargvalues(frame) 25 | kwargs = {} 26 | for key in keys: 27 | if key != 'self': 28 | kwargs[key] = values[key] 29 | return kwargs 30 | 31 | 32 | class HNSWPostgresIndexer(Executor): 33 | """ 34 | Production-ready, scalable Indexer for the Jina neural search framework. 35 | 36 | Combines the reliability of PostgreSQL with the speed and efficiency of the 37 | HNSWlib nearest neighbor library. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | total_shards: Optional[int] = None, 43 | startup_sync: bool = True, 44 | sync_interval: Optional[int] = None, 45 | limit: int = 10, 46 | metric: str = DEFAULT_METRIC, 47 | dim: int = 0, 48 | max_elements: int = 1_000_000, 49 | ef_construction: int = 400, 50 | ef_query: int = 50, 51 | max_connection: int = 64, 52 | is_distance: bool = True, 53 | num_threads: int = -1, 54 | traversal_paths: str = '@r', 55 | hostname: str = '127.0.0.1', 56 | port: int = 5432, 57 | username: str = 'postgres', 58 | password: str = '123456', 59 | database: str = 'postgres', 60 | table: str = 'default_table', 61 | return_embeddings: bool = True, 62 | dry_run: bool = False, 63 | partitions: int = 128, 64 | mute_unique_warnings: bool = False, 65 | **kwargs, 66 | ): 67 | """ 68 | :param startup_sync: whether to sync from PSQL into HNSW on start-up 69 | :param total_shards: the total nr of shards that this shard is part of. 70 | :param sync_interval: the interval between automatic background PSQL-HNSW 71 | syncs 72 | (if None, sync will be turned off) 73 | :param limit: (HNSW) Number of results to get for each query document in 74 | search 75 | :param metric: (HNSW) Distance metric type. Can be 'euclidean', 76 | 'inner_product', or 'cosine' 77 | :param dim: (HNSW) dimensionality of vectors to index 78 | :param max_elements: (HNSW) maximum number of elements (vectors) to index 79 | :param ef_construction: (HNSW) construction time/accuracy trade-off 80 | :param ef_query: (HNSW) query time accuracy/speed trade-off. High is more 81 | accurate but slower 82 | :param max_connection: (HNSW) The maximum number of outgoing connections in 83 | the graph (the "M" parameter) 84 | :param is_distance: (HNSW) if distance metric needs to be reinterpreted as 85 | similarity 86 | :param last_timestamp: (HNSW) the last time we synced into this HNSW index 87 | :param num_threads: (HNSW) nr of threads to use during indexing. -1 is default 88 | :param traversal_paths: (PSQL) default traversal paths on docs 89 | (used for indexing, delete and update), e.g. 
'@r', '@c', '@r,c' 90 | :param hostname: (PSQL) hostname of the machine 91 | :param port: (PSQL) the port 92 | :param username: (PSQL) the username to authenticate 93 | :param password: (PSQL) the password to authenticate 94 | :param database: (PSQL) the database 95 | :param table: (PSQL) the table to use 96 | :param return_embeddings: (PSQL) whether to return embeddings on 97 | search 98 | :param dry_run: (PSQL) If True, no database connection will be built 99 | :param partitions: (PSQL) the number of shards to distribute 100 | the data (used when syncing into HNSW) 101 | :param mute_unique_warnings: (PSQL) whether to mute warnings about unique 102 | ids constraint failing (useful when indexing with shards and polling = 'all') 103 | 104 | NOTE: 105 | 106 | - `total_shards` is REQUIRED in k8s, since there 107 | `runtime_args.parallel` is always 1 108 | - some arguments are passed to the inner classes. They are documented 109 | here for easier reference 110 | """ 111 | super().__init__(**kwargs) 112 | self.logger = JinaLogger( 113 | getattr(self.metas, 'name', self.__class__.__name__)) 114 | 115 | # TODO is there a way to improve this? 116 | # done because we want to have the args exposed in hub website 117 | # but we want to avoid having to manually pass every arg to the classes 118 | self._init_kwargs = _get_method_args() 119 | self._init_kwargs.update(kwargs) 120 | self.sync_interval = sync_interval 121 | self.lock = nullcontext() 122 | 123 | self._prepare_shards(total_shards) 124 | 125 | self._kv_indexer: Optional[PostgreSQLStorage] = None 126 | self._vec_indexer: Optional[HnswlibSearcher] = None 127 | 128 | ( 129 | self._kv_indexer, 130 | self._vec_indexer, 131 | ) = self._init_executors(self._init_kwargs) 132 | if startup_sync: 133 | self._sync() 134 | 135 | if self.sync_interval: 136 | self.lock = threading.Lock() 137 | self.stop_sync_thread = False 138 | self._start_auto_sync() 139 | 140 | def _prepare_shards(self, total_shards): 141 | warning_issued = False 142 | if total_shards is None: 143 | self.total_shards = getattr(self.runtime_args, 'shards', None) 144 | else: 145 | self.total_shards = total_shards 146 | if self.total_shards is None: 147 | self.logger.warning( 148 | 'total_shards was None. ' 149 | 'Setting it to 1 to allow non-sharded syncing. ' 150 | 'This can happen when running Executor outside a Flow or on k8s' 151 | ) 152 | self.total_shards = 1 153 | warning_issued = True 154 | 155 | if not hasattr(self.runtime_args, 'shard_id'): 156 | self.runtime_args.shard_id = 0 157 | if not warning_issued: 158 | self.logger.warning( 159 | 'shard_id was None. ' 160 | 'setting it to 1 to allow non-sharded syncing. ' 161 | 'This can happen when running the Executor outside a Flow' 162 | ) 163 | 164 | else: 165 | # shards is passed as str from Flow.add in yaml 166 | self.total_shards = int(self.total_shards) 167 | 168 | @requests(on='/sync') 169 | def sync(self, parameters: Dict, **kwargs): 170 | """ 171 | Perform a sync between PSQL and HNSW 172 | 173 | :param parameters: dictionary with options for sync 174 | 175 | Keys accepted: 176 | 177 | - 'rebuild' (bool): whether to rebuild HNSW or do 178 | incremental syncing 179 | - 'timestamp' (str): ISO-formatted timestamp string. 
Time 180 | from which to get data for syncing into HNSW 181 | - 'batch_size' (int): The batch size for indexing in HNSW 182 | """ 183 | self._sync(**parameters) 184 | 185 | def _sync( 186 | self, 187 | rebuild: bool = False, 188 | timestamp: str = None, 189 | batch_size: int = 100, 190 | **kwargs, 191 | ): 192 | timestamp: Optional[datetime] = self._compute_timestamp_for_sync( 193 | timestamp, rebuild 194 | ) 195 | if timestamp is None: 196 | return 197 | 198 | iterator = self._kv_indexer._get_delta( 199 | shard_id=self.runtime_args.shard_id, 200 | total_shards=self.total_shards, 201 | timestamp=timestamp, 202 | ) 203 | 204 | # we prevent race conditions with search 205 | with self.lock: 206 | if rebuild or self._vec_indexer.size == 0: 207 | # call with just indexing 208 | self._vec_indexer = HnswlibSearcher(**self._init_kwargs) 209 | self._vec_indexer.index_sync(iterator, batch_size) 210 | self.logger.info( 211 | f'Rebuilt HNSW index with {self._vec_indexer.size} docs' 212 | ) 213 | 214 | else: 215 | prev_size = self._vec_indexer.size 216 | self._vec_indexer.sync(iterator) 217 | if prev_size != self._vec_indexer.size: 218 | self.logger.info( 219 | f'Synced HNSW index from {prev_size} docs to ' 220 | f'{self._vec_indexer.size}' 221 | ) 222 | else: 223 | self.logger.info( 224 | f'Performed empty sync. HNSW index size is still {prev_size}' 225 | ) 226 | 227 | def _init_executors( 228 | self, _init_kwargs 229 | ) -> Tuple[PostgreSQLStorage, HnswlibSearcher]: 230 | kv_indexer = PostgreSQLStorage(dump_dtype=np.float32, **_init_kwargs) 231 | vec_indexer = HnswlibSearcher(**_init_kwargs) 232 | return kv_indexer, vec_indexer 233 | 234 | @requests(on='/index') 235 | def index(self, docs: DocumentArray, parameters: Dict, **kwargs): 236 | """Index new documents 237 | 238 | NOTE: PSQL has a uniqueness constraint on ID 239 | 240 | :param docs: the Documents to index 241 | :param parameters: dictionary with options for indexing 242 | 243 | Keys accepted: 244 | 245 | - 'traversal_paths' (str): traversal path for the docs 246 | """ 247 | self._kv_indexer.add(docs, parameters, **kwargs) 248 | 249 | @requests(on='/update') 250 | def update(self, docs: DocumentArray, parameters: Dict, **kwargs): 251 | """Update existing documents 252 | 253 | :param docs: the Documents to update 254 | :param parameters: dictionary with options for updating 255 | 256 | Keys accepted: 257 | 258 | - 'traversal_paths' (str): traversal path for the docs 259 | """ 260 | self._kv_indexer.update(docs, parameters, **kwargs) 261 | 262 | @requests(on='/delete') 263 | def delete(self, docs: DocumentArray, parameters: Dict, **kwargs): 264 | """Delete existing documents, by id 265 | 266 | :param docs: the Documents to delete 267 | :param parameters: dictionary with options for deleting 268 | 269 | Keys accepted: 270 | 271 | - 'traversal_paths' (str): traversal path for the docs 272 | - 'soft_delete' (bool, default `True`): whether to perform soft delete 273 | (doc is marked as empty but still exists in db, for retrieval purposes) 274 | """ 275 | if 'soft_delete' not in parameters: 276 | parameters['soft_delete'] = True 277 | 278 | self._kv_indexer.delete(docs, parameters, **kwargs) 279 | 280 | @requests(on='/clear') 281 | def clear(self, **kwargs): 282 | """ 283 | Delete all data from PSQL and HNSW 284 | 285 | """ 286 | if self._kv_indexer.initialized: 287 | self._kv_indexer.clear() 288 | self._vec_indexer = HnswlibSearcher(**self._init_kwargs) 289 | self._vec_indexer.clear() 290 | assert self._kv_indexer.size == 0 291 | assert 
self._vec_indexer.size == 0 292 | 293 | @requests(on='/status') 294 | def status(self, **kwargs): 295 | """ 296 | Get information on status of this Indexer inside a dictionary. 297 | In the flow return it will be exposed under the `parameters.__results__` with 298 | a key for each executor. 299 | 300 | e.g. {'__results__': {'executor0/shard-0/rep-0': {'hnsw_docs': 0.0, 301 | 'last_sync': '1970-01-01T00:00:00+00:00', 302 | 'psql_docs': 0.0, 303 | 'shard_id': 0.0}, 304 | 'executor0/shard-1/rep-0': {'hnsw_docs': 0.0, 305 | 'last_sync': '1970-01-01T00:00:00+00:00', 306 | 'psql_docs': 0.0, 307 | 'shard_id': 1.0}}} 308 | 309 | :return: Dictionary with keys 'psql_docs', 'hnsw_docs', 310 | 'last_sync', 'shard_id' 311 | """ 312 | psql_docs = None 313 | hnsw_docs = None 314 | last_sync = None 315 | 316 | if self._kv_indexer and self._kv_indexer.initialized: 317 | psql_docs = self._kv_indexer.size 318 | else: 319 | self.logger.warning(f'PSQL connection has not been initialized') 320 | 321 | if self._vec_indexer: 322 | hnsw_docs = self._vec_indexer.size 323 | last_sync = self._vec_indexer.last_timestamp 324 | last_sync = last_sync.isoformat() 325 | else: 326 | self.logger.warning(f'HNSW index has not been initialized') 327 | 328 | status = { 329 | 'psql_docs': psql_docs, 330 | 'hnsw_docs': hnsw_docs, 331 | 'last_sync': last_sync, 332 | 'shard_id': self.runtime_args.shard_id, 333 | } 334 | return status 335 | 336 | @requests(on='/search') 337 | def search(self, docs: 'DocumentArray', parameters: Dict = None, **kwargs): 338 | """Search the vec embeddings in HNSW and then lookup the metadata in PSQL 339 | 340 | The `HNSWSearcher` attaches matches to the `Documents` sent as 341 | inputs with the id of the match, and its embedding. 342 | Then, the `PostgreSQLStorage` retrieves the full metadata 343 | (original text or image blob) and attaches 344 | those to the Document. You receive back the full Documents as matches 345 | to your search Documents. 346 | 347 | :param docs: `Document` with `.embedding` the same shape as the 348 | `Documents` stored in the `HNSW` index. The ids of the `Documents` 349 | stored in `HNSW` need to exist in the PSQL. 350 | Otherwise you will not get back the original metadata. 351 | :param parameters: dictionary for parameters for the search operation 352 | 353 | 354 | - 'traversal_paths' (str): traversal paths for the docs 355 | - 'limit' (int): nr of matches to get per Document 356 | - 'ef_query' (int): query time accuracy/speed trade-off. High is more 357 | accurate but slower 358 | """ 359 | if self._kv_indexer and self._vec_indexer: 360 | # we prevent race conditions with sync 361 | with self.lock: 362 | self._vec_indexer.search(docs, parameters) 363 | 364 | kv_parameters = copy.deepcopy(parameters) 365 | kv_parameters['traversal_paths'] = ','.join( 366 | [ 367 | path + 'm' 368 | for path in kv_parameters.get('traversal_paths', '@r').split(',') 369 | ] 370 | ) 371 | self._kv_indexer.search(docs, kv_parameters) 372 | else: 373 | self.logger.warning( 374 | 'Indexers have not been initialized. 
Empty results') 375 | return 376 | 377 | @requests(on='/cleanup') 378 | def cleanup(self, **kwargs): 379 | """ 380 | Completely remove the entries in PSQL that have been 381 | soft-deleted (via the /delete endpoint) 382 | """ 383 | if self._kv_indexer: 384 | self._kv_indexer.cleanup() 385 | else: 386 | self.logger.warning(f'PSQL has not been initialized') 387 | 388 | def _start_auto_sync(self): 389 | self.sync_thread = Thread(target=self._sync_loop, daemon=False) 390 | self.sync_thread.start() 391 | 392 | def close(self) -> None: 393 | if hasattr(self, 'sync_thread'): 394 | # wait for sync thread to finish 395 | self.stop_sync_thread = True 396 | try: 397 | self.sync_thread.join() 398 | except Exception as e: 399 | self.logger.warning(f'Error when stopping sync thread: {e}') 400 | 401 | def _sync_loop(self): 402 | try: 403 | self.logger.warning(f'started sync thread') 404 | while True: 405 | self._sync(rebuild=False, timestamp=None, batch_size=100) 406 | self.logger.info(f'sync thread: Completed sync') 407 | time.sleep(self.sync_interval) 408 | if self.stop_sync_thread: 409 | self.logger.info(f'Exiting sync thread') 410 | return 411 | except Exception as e: 412 | self.logger.error(f'Sync thread failed: {e}') 413 | self.logger.error(traceback.format_exc()) 414 | 415 | def _compute_timestamp_for_sync( 416 | self, timestamp: Union[datetime, str], rebuild: bool 417 | ) -> Optional[datetime]: 418 | if timestamp is None: 419 | if rebuild: 420 | # we assume all db timestamps are UTC +00 421 | timestamp = datetime.fromtimestamp(0, timezone.utc) 422 | elif self._vec_indexer.last_timestamp: 423 | timestamp = self._vec_indexer.last_timestamp 424 | else: 425 | self.logger.error( 426 | f'No timestamp provided in parameters: ' 427 | f'and vec_indexer.last_timestamp' 428 | f'was None. Cannot do sync' 429 | ) 430 | return None 431 | else: 432 | timestamp = datetime.fromisoformat(timestamp) 433 | 434 | return timestamp 435 | -------------------------------------------------------------------------------- /executor/postgreshandler.py: -------------------------------------------------------------------------------- 1 | __copyright__ = "Copyright (c) 2021 Jina AI Limited. All rights reserved." 2 | __license__ = "Apache-2.0" 3 | 4 | import datetime 5 | import hashlib 6 | import copy as cp 7 | from contextlib import contextmanager 8 | from typing import Generator, List, Optional, Tuple 9 | 10 | import numpy as np 11 | import psycopg2 12 | import psycopg2.extras 13 | from jina import Document, DocumentArray 14 | from jina.logging.logger import JinaLogger 15 | from psycopg2 import pool # noqa: F401 16 | 17 | 18 | def doc_without_embedding(d: Document): 19 | new_doc = Document(d, copy=True) 20 | new_doc.embedding = None 21 | return new_doc.to_bytes() 22 | 23 | 24 | SCHEMA_VERSION = 2 25 | SCHEMA_VERSIONS_TABLE_NAME = 'schema_versions' 26 | 27 | 28 | class PostgreSQLHandler: 29 | """ 30 | Postgres Handler to connect to the database and 31 | can apply add, update, delete and query. 
32 | 33 | :param hostname: hostname of the machine 34 | :param port: the port 35 | :param username: the username to authenticate 36 | :param password: the password to authenticate 37 | :param database: the database name 38 | :param table: the table name to use 39 | :param dry_run: If True, no database connection will be built 40 | :param partitions: the number of shards to 41 | distribute the data (used when syncing into HNSW) 42 | :param mute_unique_warnings: whether to mute warnings about unique 43 | ids constraint failing (useful when indexing with shards and 44 | polling = 'all') 45 | :param args: other arguments 46 | :param kwargs: other keyword arguments 47 | """ 48 | 49 | def __init__( 50 | self, 51 | hostname: str = '127.0.0.1', 52 | port: int = 5432, 53 | username: str = 'default_name', 54 | password: str = 'default_pwd', 55 | database: str = 'postgres', 56 | table: Optional[str] = 'default_table', 57 | max_connections: int = 10, 58 | dump_dtype: type = np.float64, 59 | dry_run: bool = False, 60 | partitions: int = 128, 61 | mute_unique_warnings: bool = False, 62 | *args, 63 | **kwargs, 64 | ): 65 | super().__init__(*args, **kwargs) 66 | self.logger = JinaLogger('psq_handler') 67 | self.table = table 68 | self.dump_dtype = dump_dtype 69 | self.partitions = partitions 70 | self.snapshot_table = 'snapshot' 71 | self.mute_unique_warnings = mute_unique_warnings 72 | 73 | if not dry_run: 74 | self.postgreSQL_pool = psycopg2.pool.SimpleConnectionPool( 75 | 1, 76 | max_connections, 77 | user=username, 78 | password=password, 79 | database=database, 80 | host=hostname, 81 | port=port, 82 | ) 83 | self._init_table() 84 | else: 85 | self.logger.info( 86 | 'PSQL started in dry run mode. Will not connect to ' 87 | 'PSQL service. Needs to be restarted to connect ' 88 | 'again, with `dry_run=False`' 89 | ) 90 | 91 | def _init_table(self): 92 | """ 93 | Use the table if it already exists, or create it if it doesn't. 94 | 95 | The table stores doc_id, embedding, doc, shard and last_updated.
96 | """ 97 | self._create_schema_version() 98 | 99 | if self._table_exists(): 100 | self._assert_table_schema_version() 101 | self.logger.info('Using existing table') 102 | else: 103 | self._create_table() 104 | 105 | def _execute_sql_gracefully(self, statement, data=tuple()): 106 | with self.get_connection() as connection: 107 | records = None 108 | try: 109 | cursor = connection.cursor() 110 | if data: 111 | cursor.execute(statement, data) 112 | else: 113 | cursor.execute(statement) 114 | if cursor.rowcount: 115 | try: 116 | records = cursor.fetchall() 117 | except psycopg2.ProgrammingError: 118 | # some queries will not have results but still have rowcount 119 | pass 120 | except psycopg2.errors.UniqueViolation as error: 121 | self.logger.debug(f'Error while executing {statement}: {error}.') 122 | 123 | connection.commit() 124 | 125 | return records 126 | 127 | def _create_schema_version(self): 128 | self._execute_sql_gracefully( 129 | f'''CREATE TABLE IF NOT EXISTS {SCHEMA_VERSIONS_TABLE_NAME} ( 130 | table_name varchar, 131 | schema_version integer 132 | );''' 133 | ) 134 | 135 | def _create_table(self): 136 | self._execute_sql_gracefully( 137 | f'''CREATE TABLE IF NOT EXISTS {self.table} ( 138 | doc_id VARCHAR PRIMARY KEY, 139 | embedding BYTEA, 140 | doc BYTEA, 141 | shard int, 142 | last_updated timestamp with time zone default current_timestamp 143 | ); 144 | INSERT INTO {SCHEMA_VERSIONS_TABLE_NAME} VALUES (%s, %s);''', 145 | (self.table, SCHEMA_VERSION), 146 | ) 147 | 148 | def _table_exists(self): 149 | return self._execute_sql_gracefully( 150 | 'SELECT EXISTS' 151 | '(' 152 | 'SELECT * FROM information_schema.tables ' 153 | 'WHERE table_name=%s' 154 | ')', 155 | (self.table,), 156 | )[0][0] 157 | 158 | def _assert_table_schema_version(self): 159 | with self.get_connection() as connection: 160 | cursor = connection.cursor() 161 | cursor.execute( 162 | f'SELECT schema_version FROM ' 163 | f'{SCHEMA_VERSIONS_TABLE_NAME} ' 164 | f'WHERE table_name=%s;', 165 | (self.table,), 166 | ) 167 | result = cursor.fetchone() 168 | if result: 169 | if result[0] != SCHEMA_VERSION: 170 | raise RuntimeError( 171 | f'The schema versions of the database ' 172 | f'(version {result[0]}) and the Executor ' 173 | f'(version {SCHEMA_VERSION}) do not match. ' 174 | f'Please migrate your data to the latest ' 175 | f'version or use an Executor version with a ' 176 | f'matching schema version.' 177 | ) 178 | else: 179 | raise RuntimeError( 180 | f'The schema versions of the database ' 181 | f'(NO version number) and the Executor ' 182 | f'(version {SCHEMA_VERSION}) do not match.' 183 | f'Please migrate your data to the latest version.' 184 | ) 185 | 186 | def add(self, docs: DocumentArray, *args, **kwargs): 187 | """Insert the documents into the database. 
188 | 189 | :param docs: list of Documents to insert 190 | :param args: other positional arguments 191 | :param kwargs: other keyword arguments 192 | 193 | NOTE: if the unique-ID constraint is violated, the whole transaction 194 | is rolled back and a warning is logged (see `mute_unique_warnings`) 195 | """ 196 | with self.get_connection() as connection: 197 | cursor = connection.cursor() 198 | try: 199 | psycopg2.extras.execute_batch( 200 | cursor, 201 | f'INSERT INTO {self.table} ' 202 | f'(doc_id, embedding, doc, shard, last_updated) ' 203 | f'VALUES (%s, %s, %s, %s, current_timestamp)', 204 | [ 205 | ( 206 | doc.id, 207 | doc.embedding.astype(self.dump_dtype).tobytes() 208 | if doc.embedding is not None 209 | else None, 210 | doc_without_embedding(doc), 211 | self._get_next_shard(doc.id), 212 | ) 213 | for doc in docs 214 | ], 215 | ) 216 | except psycopg2.errors.UniqueViolation as e: 217 | if not self.mute_unique_warnings: 218 | self.logger.warning( 219 | f'Document already exists in PSQL database.' 220 | f' {e}. Skipping entire transaction...' 221 | ) 222 | connection.rollback() 223 | connection.commit() 224 | 225 | def update(self, docs: DocumentArray, *args, **kwargs): 226 | """Update existing documents in the database. 227 | 228 | :param docs: list of Documents 229 | :param args: other arguments 230 | :param kwargs: other keyword arguments 231 | :return record: List of Document's id after update 232 | """ 233 | with self.get_connection() as connection: 234 | cursor = connection.cursor() 235 | psycopg2.extras.execute_batch( 236 | cursor, 237 | f'UPDATE {self.table}\ 238 | SET embedding = %s,\ 239 | doc = %s,\ 240 | last_updated = current_timestamp \ 241 | WHERE doc_id = %s', 242 | [ 243 | ( 244 | doc.embedding.astype(self.dump_dtype).tobytes(), 245 | doc_without_embedding(doc), 246 | doc.id, 247 | ) 248 | for doc in docs 249 | ], 250 | ) 251 | connection.commit() 252 | 253 | def cleanup(self): 254 | """ 255 | Full deletion of the entries that 256 | have been marked for soft-deletion 257 | """ 258 | with self.get_connection() as connection: 259 | cursor = connection.cursor() 260 | cursor.execute( 261 | f'DELETE FROM {self.table} WHERE doc is NULL', 262 | ) 263 | connection.commit() 264 | 265 | def delete(self, docs: DocumentArray, soft_delete=False, *args, **kwargs): 266 | """Delete documents from the database. 267 | 268 | NOTE: This can be a soft-deletion, required by the snapshotting 269 | mechanism in the PSQLFaissCompound 270 | 271 | For a real delete, use the /cleanup endpoint 272 | 273 | :param docs: list of Documents 274 | :param args: other arguments 275 | :param soft_delete: if True, mark the documents as deleted (embedding and doc set to NULL) instead of removing the rows 276 | :param kwargs: other keyword arguments 277 | :return record: List of Document's id after deletion 278 | """ 279 | with self.get_connection() as connection: 280 | cursor = connection.cursor() 281 | if soft_delete: 282 | self.logger.warning( 283 | 'Performing soft-delete.
Use /cleanup or a hard ' 284 | 'delete to delete the records' 285 | ) 286 | psycopg2.extras.execute_batch( 287 | cursor, 288 | f'UPDATE {self.table} ' 289 | f'SET embedding = NULL, ' 290 | f'doc = NULL, ' 291 | f'last_updated = current_timestamp ' 292 | f'WHERE doc_id = %s;', 293 | [(doc.id,) for doc in docs], 294 | ) 295 | else: 296 | psycopg2.extras.execute_batch( 297 | cursor, 298 | f'DELETE FROM {self.table} WHERE doc_id = %s;', 299 | [(doc.id,) for doc in docs], 300 | ) 301 | connection.commit() 302 | 303 | def close(self): 304 | self.postgreSQL_pool.closeall() 305 | 306 | @contextmanager 307 | def get_connection(self): 308 | """A ContextManager for quickly getting a cursor""" 309 | try: 310 | conn = self._get_connection() 311 | with conn: # ensure commit or rollback 312 | yield conn 313 | except: 314 | raise 315 | finally: 316 | self._close_connection(conn) 317 | 318 | def search(self, docs: DocumentArray, return_embeddings: bool = True, **kwargs): 319 | """Use the Postgres db as a key-value engine, 320 | returning the metadata of a document id""" 321 | if return_embeddings: 322 | embeddings_field = ', embedding ' 323 | else: 324 | embeddings_field = '' 325 | with self.get_connection() as connection: 326 | cursor = connection.cursor() 327 | for doc in docs: 328 | # retrieve metadata 329 | cursor.execute( 330 | f'SELECT doc {embeddings_field} FROM {self.table} WHERE doc_id = %s;', 331 | (doc.id,), 332 | ) 333 | result = cursor.fetchone() 334 | 335 | scores = cp.deepcopy(doc.scores) 336 | 337 | data = bytes(result[0]) 338 | retrieved_doc = Document.from_bytes(data) 339 | if return_embeddings and result[1] is not None: 340 | embedding = np.frombuffer(result[1], dtype=self.dump_dtype) 341 | retrieved_doc.embedding = embedding 342 | 343 | # update meta data 344 | doc._data = retrieved_doc._data 345 | # update scores 346 | doc.scores = scores 347 | 348 | def _close_connection(self, connection): 349 | # restore it to the pool 350 | self.postgreSQL_pool.putconn(connection, close=False) 351 | 352 | def _get_connection(self): 353 | # by default psycopg2 is not auto-committing 354 | # this means we can have rollbacks 355 | # and maintain ACID-ity 356 | connection = self.postgreSQL_pool.getconn() 357 | connection.autocommit = False 358 | return connection 359 | 360 | def get_size(self): 361 | with self.get_connection() as connection: 362 | cursor = connection.cursor() 363 | cursor.execute(f'SELECT COUNT(*) FROM {self.table}') 364 | records = cursor.fetchall() 365 | return records[0][0] 366 | 367 | def _get_next_shard(self, doc_id: str): 368 | sha = hashlib.sha256() 369 | sha.update(bytes(doc_id, 'utf-8')) 370 | return int(sha.hexdigest(), 16) % self.partitions 371 | 372 | def snapshot(self): 373 | """ 374 | Saves the state of the data table in a new table 375 | 376 | Required to be done in two steps because 377 | 1. create table like ... doesn't include data 378 | 2. insert into .. (select ...) 
doesn't include primary key definitions 379 | """ 380 | with self.get_connection() as connection: 381 | try: 382 | cursor = connection.cursor() 383 | cursor.execute( 384 | f'drop table if exists {self.snapshot_table}; ' 385 | f'create table {self.snapshot_table} ' 386 | f'(like {self.table} including all);' 387 | ) 388 | connection.commit() 389 | cursor = connection.cursor() 390 | cursor.execute( 391 | f'insert into {self.snapshot_table} (select * from {self.table});' 392 | ) 393 | connection.commit() 394 | self.logger.info('Successfully created snapshot') 395 | except (Exception, psycopg2.Error) as error: 396 | self.logger.error(f'Error snapshotting: {error}') 397 | connection.rollback() 398 | 399 | def get_snapshot(self, shards_to_get: List[str]): 400 | """ 401 | Get the data from the snapshot, for a specific range of virtual shards 402 | """ 403 | shards_quoted = tuple(int(shard) for shard in shards_to_get) 404 | with self.get_connection() as connection: 405 | try: 406 | cursor = connection.cursor('snapshot') 407 | cursor.itersize = 10000 408 | cursor.execute( 409 | f'SELECT doc_id, embedding from {self.snapshot_table} ' 410 | f'WHERE shard in %s ' 411 | f'ORDER BY doc_id', 412 | (shards_quoted,), 413 | ) 414 | for rec in cursor: 415 | vec = ( 416 | np.frombuffer(rec[1], dtype=self.dump_dtype) 417 | if rec[1] is not None 418 | else None 419 | ) 420 | yield rec[0], vec 421 | except (Exception, psycopg2.Error) as error: 422 | self.logger.error(f'Error importing snapshot: {error}') 423 | connection.rollback() 424 | connection.commit() 425 | 426 | def get_generator( 427 | self, include_metas=True 428 | ) -> Generator[Tuple[str, bytes, Optional[bytes]], None, None]: 429 | with self.get_connection() as connection: 430 | cursor = connection.cursor('generator') # server-side cursor 431 | cursor.itersize = 10000 432 | if include_metas: 433 | cursor.execute( 434 | f'SELECT doc_id, embedding, doc FROM {self.table} ORDER BY doc_id' 435 | ) 436 | for rec in cursor: 437 | yield rec[0], np.frombuffer( 438 | rec[1], dtype=self.dump_dtype # decode with the same dtype used when storing 439 | ) if rec[1] is not None else None, rec[2] 440 | else: 441 | cursor.execute( 442 | f'SELECT doc_id, embedding FROM {self.table} ORDER BY doc_id' 443 | ) 444 | for rec in cursor: 445 | yield rec[0], np.frombuffer( 446 | rec[1], dtype=self.dump_dtype 447 | ) if rec[1] is not None else None, None 448 | 449 | def _get_snapshot_timestamp(self): 450 | """Get the timestamp of the snapshot""" 451 | with self.get_connection() as connection: 452 | cursor = connection.cursor() 453 | try: 454 | cursor.execute(f'SELECT MAX(last_updated) FROM {self.snapshot_table}') 455 | for rec in cursor: 456 | yield rec[0] 457 | except Exception as e: 458 | self.logger.error(f'Could not obtain timestamp from snapshot: {e}') 459 | 460 | def _get_data_timestamp(self): 461 | """Get the timestamp of the data""" 462 | with self.get_connection() as connection: 463 | cursor = connection.cursor() 464 | try: 465 | cursor.execute(f'SELECT MAX(last_updated) FROM {self.table}') 466 | for rec in cursor: 467 | yield rec[0] 468 | except Exception as e: 469 | self.logger.error(f'Could not obtain timestamp from data: {e}') 470 | 471 | def _get_delta( 472 | self, shards_to_get, timestamp 473 | ) -> Generator[Tuple[str, bytes, datetime.datetime], None, None]: 474 | with self.get_connection() as connection: 475 | cursor = connection.cursor('generator') # server-side cursor 476 | cursor.itersize = 10000 477 | shards_quoted = tuple(int(shard) for shard in shards_to_get) 478 | cursor.execute( 479 | f'SELECT doc_id, embedding, last_updated ' 480 | f'from
{self.table} ' 481 | f'WHERE shard in %s ' 482 | f'and last_updated > %s ' 483 | f'ORDER BY doc_id', 484 | (shards_quoted, timestamp), 485 | ) 486 | for rec in cursor: 487 | second_val = ( 488 | np.frombuffer(rec[1], dtype=self.dump_dtype) 489 | if rec[1] is not None 490 | else None 491 | ) 492 | 493 | yield rec[0], second_val, rec[2] 494 | 495 | def get_snapshot_size(self): 496 | """ 497 | Get the size of the snapshot, if it exists. 498 | else 0 499 | """ 500 | with self.get_connection() as connection: 501 | try: 502 | cursor = connection.cursor() 503 | cursor.execute(f'SELECT COUNT(*) FROM {self.snapshot_table}') 504 | records = cursor.fetchall() 505 | return records[0][0] 506 | except Exception as e: 507 | self.logger.warning(f'Could not get size of snapshot: {e}') 508 | 509 | return 0 510 | 511 | def clear(self): 512 | """ 513 | Full hard-deletion of the entries 514 | :return: 515 | """ 516 | with self.get_connection() as connection: 517 | cursor = connection.cursor() 518 | cursor.execute(f'DELETE FROM {self.table}') 519 | connection.commit() 520 | 521 | @property 522 | def initialized(self, **kwargs): 523 | """ 524 | Whether the PSQL connection has been initialized 525 | """ 526 | return hasattr(self, 'postgreSQL_pool') 527 | --------------------------------------------------------------------------------
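Usage sketch: a minimal example of how the endpoints above fit together, assuming a running PostgreSQL instance. The credentials, database, table name and the 128-dim random embeddings below are illustrative placeholders, not values from this repository; they must match your PSQL setup and the dimensionality the HNSW side of the executor is configured with. The flow is: /index writes Documents to PSQL only, /sync pulls the new or changed rows from PSQL into the in-memory HNSW index, and /search queries HNSW and then hydrates the matches' metadata from PSQL.

import numpy as np
from jina import Document, DocumentArray, Flow

from executor import HNSWPostgresIndexer

f = Flow().add(
    uses=HNSWPostgresIndexer,
    uses_with={
        'hostname': '127.0.0.1',  # placeholder PSQL connection details
        'port': 5432,
        'username': 'postgres',
        'password': 'example_password',
        'database': 'postgres',
        'table': 'demo_table',
        'startup_sync': False,  # sync is triggered explicitly below
    },
)

index_docs = DocumentArray(
    Document(id=f'doc-{i}', embedding=np.random.random(128).astype(np.float32))
    for i in range(1000)
)
query = DocumentArray(
    [Document(embedding=np.random.random(128).astype(np.float32))]
)

with f:
    f.post(on='/index', inputs=index_docs)  # stored in PSQL only
    f.post(on='/sync')  # copy the delta from PSQL into HNSW
    results = f.post(on='/search', inputs=query, parameters={'limit': 5})
    for match in results[0].matches:
        print(match.id)  # matches carry the full metadata retrieved from PSQL

Calling /status afterwards reports 'psql_docs', 'hnsw_docs' and 'last_sync' per shard, which is a quick way to verify that the sync actually happened.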