├── tests ├── __init__.py ├── unit_test │ ├── __init__.py │ └── test_factory.py └── integration_test │ ├── __init__.py │ ├── test_pinecone.py │ └── test_qdrant.py ├── vectordbs ├── __init__.py ├── factory.py ├── types.py └── providers │ ├── zilliz_datastore.py │ ├── pinecone_datastore.py │ ├── qdrant_datastore.py │ ├── weaviate_datastore.py │ ├── redis_datastore.py │ └── milvus_datastore.py ├── .gitignore ├── poetry.toml ├── pyproject.toml ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vectordbs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit_test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration_test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | **/__pycache__ 3 | .vscode -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | -------------------------------------------------------------------------------- /tests/unit_test/test_factory.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from vectordbs import factory 3 | 4 | class TestFactory(unittest.TestCase): 5 | def test_factory(self): 6 | pinecone = factory.get_datastore("pinecone") 7 | assert pinecone -------------------------------------------------------------------------------- /tests/integration_test/test_pinecone.py: -------------------------------------------------------------------------------- 1 | from vectordbs import factory 2 | import unittest 3 | 4 | class TestPineconeIntegration(unittest.TestCase): 5 | def test_pinecone_simple(): 6 | pinecone_db = factory.get_datastore("pinecone") 7 | 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "vectordbs" 3 | version = "0.1.0" 4 | description = "A python library that supports all vector databases specifically for LLM apps and frameworks" 5 | authors = ["Timothy Chen "] 6 | license = "Apache 2.0" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | weaviate-client = "^3.15.5" 12 | pinecone-client = "^2.2.1" 13 | chromadb = "^0.3.21" 14 | pymilvus = "^2.2.4" 15 | loguru = "^0.6.0" 16 | tenacity = "^8.2.2" 17 | envclasses = "^0.3.1" 18 | 19 | 20 | [build-system] 21 | requires = ["poetry-core"] 22 | build-backend = "poetry.core.masonry.api" 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vectordbs 2 | 3 | Vectordbs is a python library that supports all vector databases specifically for LLM apps and frameworks. 
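
A quick sketch of the intended interface (illustrative only; it assumes the chosen provider's environment variables, e.g. `PINECONE_API_KEY`, `PINECONE_ENVIRONMENT` and `PINECONE_INDEX`, are already set):

```python
from vectordbs import factory
from vectordbs.types import VectorStoreData, VectorStoreQuery

# Pick a provider by name; each provider reads its configuration from environment variables.
store = factory.get_datastore("pinecone")

# Store a record with a precomputed embedding (the dimension must match the target index;
# 1536 is used throughout this repo).
store.add([VectorStoreData(id="doc-1", data={"text": "hello world"}, embedding=[0.1] * 1536)])

# Query by embedding and return the single closest match.
result = store.query(VectorStoreQuery(query_embedding=[0.1] * 1536, similarity_top_k=1))
```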
4 | 
5 | Currently, LLM frameworks and apps (LlamaIndex, OpenAI, LangChain, BabyAGI, etc.) each reimplement support for every vector database, which leads to a lot of redundant work. This library aims to let existing and new frameworks avoid reinventing that wheel.
6 | 
7 | ## Credits
8 | 
9 | The initial starting point of the project is based on existing work in LangChain, LlamaIndex, chatgpt-retrieval-plugin and BabyAGI. That is also the point of this project: to help existing and future projects avoid repeating the same work.
10 | 
11 | ## List of Supported Vector Databases
12 | 
13 | 
14 | | Vector Database | Supported |
15 | |-----------------|-----------|
16 | | Chroma          | :x:       |
17 | | Pinecone        | :x:       |
18 | | Weaviate        | :x:       |
19 | | Qdrant          | :x:       |
20 | | Milvus          | :x:       |
21 | | Redis           | :x:       |
22 | | Opensearch      | :x:       |
23 | | Postgres        | :x:       |
24 | 
25 | 
26 | ## Install
27 | 
28 | ```
29 | pip install vectordbs
30 | ```
31 | 
--------------------------------------------------------------------------------
/vectordbs/factory.py:
--------------------------------------------------------------------------------
1 | import os
2 | from vectordbs.types import VectorStore
3 | 
4 | def get_datastore(datastore: str) -> VectorStore:
5 |     assert datastore is not None
6 | 
7 |     match datastore:
8 |         case "pinecone":
9 |             from vectordbs.providers.pinecone_datastore import PineconeDataStore
10 | 
11 |             return PineconeDataStore()
12 |         case "weaviate":
13 |             from vectordbs.providers.weaviate_datastore import WeaviateDataStore
14 | 
15 |             return WeaviateDataStore()
16 |         case "milvus":
17 |             from vectordbs.providers.milvus_datastore import MilvusDataStore
18 | 
19 |             return MilvusDataStore()
20 |         case "zilliz":
21 |             from vectordbs.providers.zilliz_datastore import ZillizDataStore
22 | 
23 |             return ZillizDataStore()
24 |         case "redis":
25 |             from vectordbs.providers.redis_datastore import RedisDataStore
26 | 
27 |             return RedisDataStore.init()
28 |         case "qdrant":
29 |             from vectordbs.providers.qdrant_datastore import QdrantDataStore
30 | 
31 |             return QdrantDataStore()
32 |         case _:
33 |             raise ValueError(f"Unsupported vector database: {datastore}")
34 | 
--------------------------------------------------------------------------------
/vectordbs/types.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
3 | from enum import Enum
4 | from abc import ABC, abstractmethod
5 | from pydantic import BaseModel
6 | 
7 | @dataclass
8 | class DocumentChunk:
9 |     document_id: str
10 |     text: str
11 |     vector: List[float]
12 | 
13 | @dataclass
14 | class DocumentMetadataFilter:
15 |     field_name: str
16 |     gte: Optional[int] = None
17 |     lte: Optional[int] = None
18 | 
19 | @dataclass
20 | class DocumentChunkWithScore(DocumentChunk):
21 |     score: float
22 | 
23 | @dataclass
24 | class QueryResult:
25 |     data: Optional[List[Any]] = None
26 |     similarities: Optional[List[float]] = None
27 |     ids: Optional[List[str]] = None
28 | 
29 | @dataclass
30 | class QueryWithEmbedding:
31 |     text: str
32 |     vector: List[float]
33 | 
34 | @dataclass
35 | class VectorStoreData:
36 |     id: str
37 |     data: dict
38 |     embedding: List[float]
39 | 
40 | class VectorStoreQueryMode(str, Enum):
41 |     """Vector store query mode."""
42 | 
43 |     DEFAULT = "default"
44 |     SPARSE = "sparse"
45 |     HYBRID = "hybrid"
46 | 
47 | @dataclass
48 | class VectorStoreQuery:
49 |     """Vector store query."""
50 | 
51 |     # dense embedding
52 | 
query_embedding: Optional[List[float]] = None 53 | similarity_top_k: int = 1 54 | ids: Optional[List[str]] = None 55 | query_str: Optional[str] = None 56 | 57 | # NOTE: current mode 58 | mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT 59 | 60 | # NOTE: only for hybrid search (0 for bm25, 1 for vector search) 61 | alpha: Optional[float] = None 62 | 63 | class VectorStore(ABC): 64 | """Abstract vector store class.""" 65 | 66 | @abstractmethod 67 | def add( 68 | self, 69 | datas: List[VectorStoreData], 70 | ) -> List[str]: 71 | """Add embedding results to vector store.""" 72 | ... 73 | 74 | @abstractmethod 75 | def delete(self, ids: List[str]) -> None: 76 | """Delete doc.""" 77 | ... 78 | 79 | @abstractmethod 80 | def query( 81 | self, 82 | query: VectorStoreQuery, 83 | ) -> QueryResult: 84 | """Query vector store.""" 85 | ... 86 | -------------------------------------------------------------------------------- /tests/integration_test/test_qdrant.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Optional 3 | 4 | import pytest 5 | 6 | from vectordbs.types import ( 7 | DocumentChunk, 8 | DocumentMetadataFilter, 9 | QueryWithEmbedding, 10 | DocumentMetadata, 11 | ) 12 | from vectordbs.providers.qdrant_datastore import QdrantDataStore, QdrantOptions 13 | 14 | def create_document_chunk(id: Optional[str] = None) -> DocumentChunk: 15 | return DocumentChunk( 16 | id=id or str(uuid.uuid4()), 17 | text="sample text", 18 | metadata=DocumentMetadata( 19 | document_id="1", 20 | source="source", 21 | source_id="source_id", 22 | author="author", 23 | ), 24 | embedding=[0.1] * 1536, 25 | ) 26 | 27 | @pytest.fixture 28 | def qdrant_datastore(): 29 | options = QdrantOptions() 30 | datastore = QdrantDataStore(options, recreate_collection=True) 31 | yield datastore 32 | datastore.delete(filter=None) 33 | 34 | 35 | def test_upsert(qdrant_datastore): 36 | document_chunk = create_document_chunk() 37 | result = qdrant_datastore.upsert({document_chunk.id: [document_chunk]}) 38 | assert len(result) == 1 39 | assert result[0] == document_chunk.id 40 | 41 | 42 | def test_query(qdrant_datastore): 43 | document_chunk = create_document_chunk() 44 | qdrant_datastore.upsert({document_chunk.id: [document_chunk]}) 45 | 46 | query = QueryWithEmbedding( 47 | query="test query", 48 | embedding=[0.1] * 1536, 49 | filter=DocumentMetadataFilter(document_id=document_chunk.metadata.document_id), 50 | top_k=5, 51 | ) 52 | 53 | results = qdrant_datastore.query([query]) 54 | assert len(results) == 1 55 | assert len(results[0].results) >= 1 56 | assert results[0].results[0].id == document_chunk.id 57 | 58 | 59 | def test_delete(qdrant_datastore): 60 | document_chunk = create_document_chunk() 61 | qdrant_datastore.upsert({document_chunk.id: [document_chunk]}) 62 | 63 | deleted = qdrant_datastore.delete(ids=[document_chunk.id]) 64 | assert deleted 65 | 66 | query = QueryWithEmbedding( 67 | query="test query", 68 | embedding=[0.1] * 1536, 69 | filter=DocumentMetadataFilter(document_id=document_chunk.metadata.document_id), 70 | top_k=5, 71 | ) 72 | 73 | results = qdrant_datastore.query([query]) 74 | assert len(results) == 1 75 | assert len(results[0].results) == 0 76 | -------------------------------------------------------------------------------- /vectordbs/providers/zilliz_datastore.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional 4 | from pymilvus import ( 5 | 
connections, 6 | ) 7 | from uuid import uuid4 8 | 9 | from datastore.providers.milvus_datastore import ( 10 | MilvusDataStore, 11 | ) 12 | 13 | 14 | ZILLIZ_COLLECTION = os.environ.get("ZILLIZ_COLLECTION") or "c" + uuid4().hex 15 | ZILLIZ_URI = os.environ.get("ZILLIZ_URI") 16 | ZILLIZ_USER = os.environ.get("ZILLIZ_USER") 17 | ZILLIZ_PASSWORD = os.environ.get("ZILLIZ_PASSWORD") 18 | ZILLIZ_USE_SECURITY = False if ZILLIZ_PASSWORD is None else True 19 | 20 | ZILLIZ_CONSISTENCY_LEVEL = os.environ.get("ZILLIZ_CONSISTENCY_LEVEL") 21 | 22 | class ZillizDataStore(MilvusDataStore): 23 | def __init__(self, create_new: Optional[bool] = False): 24 | """Create a Zilliz DataStore. 25 | 26 | The Zilliz Datastore allows for storing your indexes and metadata within a Zilliz Cloud instance. 27 | 28 | Args: 29 | create_new (Optional[bool], optional): Whether to overwrite if collection already exists. Defaults to True. 30 | """ 31 | # Overwrite the default consistency level by MILVUS_CONSISTENCY_LEVEL 32 | self._consistency_level = ZILLIZ_CONSISTENCY_LEVEL or "Bounded" 33 | self._create_connection() 34 | 35 | self._create_collection(ZILLIZ_COLLECTION, create_new) # type: ignore 36 | self._create_index() 37 | 38 | def _create_connection(self): 39 | # Check if the connection already exists 40 | try: 41 | i = [ 42 | connections.get_connection_addr(x[0]) 43 | for x in connections.list_connections() 44 | ].index({"address": ZILLIZ_URI, "user": ZILLIZ_USER}) 45 | self.alias = connections.list_connections()[i][0] 46 | except ValueError: 47 | # Connect to the Zilliz instance using the passed in Environment variables 48 | self.alias = uuid4().hex 49 | connections.connect(alias=self.alias, uri=ZILLIZ_URI, user=ZILLIZ_USER, password=ZILLIZ_PASSWORD, secure=ZILLIZ_USE_SECURITY) # type: ignore 50 | self._print_info("Connect to zilliz cloud server") 51 | 52 | def _create_index(self): 53 | try: 54 | # If no index on the collection, create one 55 | if len(self.col.indexes) == 0: 56 | self.index_params = {"metric_type": "IP", "index_type": "AUTOINDEX", "params": {}} 57 | self.col.create_index("embedding", index_params=self.index_params) 58 | 59 | self.col.load() 60 | self.search_params = {"metric_type": "IP", "params": {}} 61 | except Exception as e: 62 | self._print_err("Failed to create index, error: {}".format(e)) 63 | 64 | 65 | -------------------------------------------------------------------------------- /vectordbs/providers/pinecone_datastore.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional 3 | import pinecone 4 | import asyncio 5 | from pydantic import BaseSettings, Field 6 | 7 | from vectordbs.types import VectorStore, VectorStoreData, VectorStoreQuery, VectorStoreQueryResult 8 | 9 | # Set the batch size for upserting vectors to Pinecone 10 | UPSERT_BATCH_SIZE = 100 11 | 12 | class PineconeOptions(BaseSettings): 13 | api_key: str = Field(..., env="PINECONE_API_KEY") 14 | environment: str = Field(..., env="PINECONE_ENVIRONMENT") 15 | index: str = Field(..., env="PINECONE_INDEX") 16 | 17 | class PineconeDataStore(VectorStore): 18 | def __init__(self, options: PineconeOptions): 19 | 20 | # Initialize Pinecone with the API key and environment 21 | # NOTE: Do we need a singleton to make sure we only init once? 
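        # A sketch of one possible answer (not implemented here): keep a hypothetical
        # module-level flag such as `_PINECONE_INITIALIZED` and skip the call below when
        # it is already set. In practice pinecone.init() only stores module-level
        # configuration, so calling it from every instance is wasteful but harmless.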
22 | pinecone.init(api_key=options.api_key, environment=options.environment) 23 | 24 | # Will raise if index doesn't exist 25 | self.index = pinecone.Index(options.index) 26 | 27 | async def _upsert(self, datas: List[VectorStoreData]) -> List[str]: 28 | """ 29 | Takes in a dict from document id to list of document chunks and inserts them into the index. 30 | Return a list of document ids. 31 | """ 32 | # Initialize a list of ids to return 33 | doc_ids: List[str] = [] 34 | # Initialize a list of vectors to upsert 35 | vectors = [] 36 | # Loop through the dict items 37 | for data in datas: 38 | # Append the id to the ids list 39 | doc_ids.append(data.id) 40 | vector = (chunk.id, chunk.embedding, chunk.metadata) 41 | vectors.append(vector) 42 | 43 | # Split the vectors list into batches of the specified size 44 | batches = [ 45 | vectors[i : i + UPSERT_BATCH_SIZE] 46 | for i in range(0, len(vectors), UPSERT_BATCH_SIZE) 47 | ] 48 | # Upsert each batch to Pinecone 49 | for batch in batches: 50 | try: 51 | print(f"Upserting batch of size {len(batch)}") 52 | self.index.upsert(vectors=batch) 53 | print(f"Upserted batch successfully") 54 | except Exception as e: 55 | print(f"Error upserting batch: {e}") 56 | raise e 57 | 58 | return doc_ids 59 | 60 | async def _query( 61 | self, 62 | queries: List[VectorStoreQuery], 63 | ) -> List[VectorStoreQueryResult()]: 64 | """ 65 | Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. 66 | """ 67 | 68 | # Define a helper coroutine that performs a single query and returns a QueryResult 69 | async def _single_query(query: VectorStoreQuery) -> VectorStoreQueryResult(): 70 | #print(f"Query: {query.query}") 71 | 72 | # Convert the metadata filter object to a dict with pinecone filter expressions 73 | pinecone_filter = self._get_pinecone_filter(query.filter) 74 | 75 | try: 76 | # Query the index with the query embedding, filter, and top_k 77 | query_response = self.index.query( 78 | # namespace=namespace, 79 | top_k=query.top_k, 80 | vector=query.embedding, 81 | filter=pinecone_filter, 82 | include_metadata=True, 83 | ) 84 | except Exception as e: 85 | print(f"Error querying index: {e}") 86 | raise e 87 | 88 | query_results: List[DocumentChunkWithScore] = [] 89 | for result in query_response.matches: 90 | score = result.score 91 | metadata = result.metadata 92 | # Remove document id and text from metadata and store it in a new variable 93 | metadata_without_text = ( 94 | {key: value for key, value in metadata.items() if key != "text"} 95 | if metadata 96 | else None 97 | ) 98 | 99 | # If the source is not a valid Source in the Source enum, set it to None 100 | if ( 101 | metadata_without_text 102 | and "source" in metadata_without_text 103 | and metadata_without_text["source"] not in Source.__members__ 104 | ): 105 | metadata_without_text["source"] = None 106 | 107 | # Create a document chunk with score object with the result data 108 | result = DocumentChunkWithScore( 109 | id=result.id, 110 | score=score, 111 | text=metadata["text"] if metadata and "text" in metadata else None, 112 | metadata=metadata_without_text, 113 | ) 114 | query_results.append(result) 115 | return QueryResult(query=query.query, results=query_results) 116 | 117 | # Use asyncio.gather to run multiple _single_query coroutines concurrently and collect their results 118 | results: List[QueryResult] = await asyncio.gather( 119 | *[_single_query(query) for query in queries] 120 | ) 121 | 122 | return results 123 | 124 | 
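    # NOTE: `_get_pinecone_filter` (used by `_query` above) is expected to translate the
    # query's metadata filter into Pinecone's metadata filter syntax, which uses
    # MongoDB-style operators, e.g. {"source": {"$eq": "email"}} or
    # {"created_at": {"$gte": 1680000000}}. That shape is an assumption here; the helper
    # is not defined in this file.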
async def delete( 125 | self, ids: List[str]) -> bool: 126 | """ 127 | Removes vectors by ids. 128 | """ 129 | try: 130 | print(f"Deleting vectors with ids {ids}") 131 | self.index.delete(ids=ids) # type: ignore 132 | print(f"Deleted vectors with ids successfully") 133 | except Exception as e: 134 | print(f"Error deleting vectors with ids: {e}") 135 | raise e 136 | 137 | return True -------------------------------------------------------------------------------- /vectordbs/providers/qdrant_datastore.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import Dict, List, Optional 4 | 5 | from grpc._channel import _InactiveRpcError 6 | from qdrant_client.http.exceptions import UnexpectedResponse 7 | from qdrant_client.http.models import PayloadSchemaType 8 | 9 | from vectordbs.datastore import DataStore 10 | from vectordbs.types import ( 11 | DocumentChunk, 12 | DocumentMetadataFilter, 13 | QueryResult, 14 | QueryWithEmbedding, 15 | DocumentChunkWithScore, 16 | ) 17 | from qdrant_client.http import models as rest 18 | 19 | import qdrant_client 20 | 21 | from pydantic import BaseSettings, Field 22 | from services.date import to_unix_timestamp 23 | 24 | 25 | class QdrantOptions(BaseSettings): 26 | url: str = Field(..., env="QDRANT_URL", default="http://localhost") 27 | port: int = Field(..., env="QDRANT_PORT", default=6333) 28 | grpc_port: int = Field(..., env="QDRANT_GRPC_PORT", default=6334) 29 | collection: str = Field(..., env="QDRANT_COLLECTION", default="document_chunks") 30 | api_key: str = Field(..., env="QDRANT_API_KEY") 31 | vector_size: int = Field(1536) 32 | distance: str = Field("Cosine") 33 | 34 | class QdrantDataStore(DataStore): 35 | UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") 36 | 37 | def __init__( 38 | self, 39 | options: QdrantOptions, 40 | recreate_collection: bool = False 41 | ): 42 | """ 43 | Args: 44 | collection_name: Name of the collection to be used 45 | vector_size: Size of the embedding stored in a collection 46 | distance: 47 | Any of "Cosine" / "Euclid" / "Dot". Distance function to measure 48 | similarity 49 | """ 50 | self.client = qdrant_client.QdrantClient( 51 | url=options.url, 52 | port=options.port, 53 | grpc_port=options.grpc_port, 54 | api_key=options.api_key, 55 | prefer_grpc=True, 56 | timeout=10, 57 | ) 58 | self.collection_name = options.collection 59 | 60 | # Set up the collection so the points might be inserted or queried 61 | self._set_up_collection(vector_size, distance, recreate_collection) 62 | 63 | async def upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: 64 | """ 65 | Takes in a list of document chunks and inserts them into the database. 66 | Return a list of document ids. 67 | """ 68 | points = [ 69 | self._convert_document_chunk_to_point(chunk) 70 | for _, chunks in chunks.items() 71 | for chunk in chunks 72 | ] 73 | self.client.upsert( 74 | collection_name=self.collection_name, 75 | points=points, # type: ignore 76 | wait=True, 77 | ) 78 | return list(chunks.keys()) 79 | 80 | async def query( 81 | self, 82 | queries: List[QueryWithEmbedding], 83 | ) -> List[QueryResult]: 84 | """ 85 | Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. 
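        Example (illustrative only; assumes a reachable Qdrant instance and
        1536-dimensional vectors, mirroring the integration tests):

            results = await datastore.query([
                QueryWithEmbedding(query="sample text", embedding=[0.1] * 1536, top_k=5)
            ])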
86 | """ 87 | search_requests = [ 88 | self._convert_query_to_search_request(query) for query in queries 89 | ] 90 | results = self.client.search_batch( 91 | collection_name=self.collection_name, 92 | requests=search_requests, 93 | ) 94 | return [ 95 | QueryResult( 96 | query=query.query, 97 | results=[ 98 | self._convert_scored_point_to_document_chunk_with_score(point) 99 | for point in result 100 | ], 101 | ) 102 | for query, result in zip(queries, results) 103 | ] 104 | 105 | async def delete( 106 | self, 107 | ids: Optional[List[str]] = None, 108 | filter: Optional[DocumentMetadataFilter] = None, 109 | ) -> bool: 110 | """ 111 | Removes vectors by ids, filter, or everything in the datastore. 112 | Returns whether the operation was successful. 113 | """ 114 | if ids is None and filter is None: 115 | raise ValueError("Please provide one of the parameters: ids or filter.") 116 | 117 | points_selector = self._convert_metadata_filter_to_qdrant_filter(filter, ids) 118 | 119 | response = self.client.delete( 120 | collection_name=self.collection_name, 121 | points_selector=points_selector, # type: ignore 122 | ) 123 | return "COMPLETED" == response.status 124 | 125 | 126 | def _convert_document_chunk_to_point( 127 | self, document_chunk: DocumentChunk 128 | ) -> rest.PointStruct: 129 | created_at = ( 130 | to_unix_timestamp(document_chunk.metadata.created_at) 131 | if document_chunk.metadata.created_at is not None 132 | else None 133 | ) 134 | return rest.PointStruct( 135 | id=self._create_document_chunk_id(document_chunk.id), 136 | vector=document_chunk.embedding, # type: ignore 137 | payload={ 138 | "id": document_chunk.id, 139 | "text": document_chunk.text, 140 | "metadata": document_chunk.metadata.dict(), 141 | "created_at": created_at, 142 | }, 143 | ) 144 | 145 | def _create_document_chunk_id(self, external_id: Optional[str]) -> str: 146 | if external_id is None: 147 | return uuid.uuid4().hex 148 | return uuid.uuid5(self.UUID_NAMESPACE, external_id).hex 149 | 150 | def _convert_query_to_search_request( 151 | self, query: QueryWithEmbedding 152 | ) -> rest.SearchRequest: 153 | return rest.SearchRequest( 154 | vector=query.embedding, 155 | filter=self._convert_metadata_filter_to_qdrant_filter(query.filter), 156 | limit=query.top_k, # type: ignore 157 | with_payload=True, 158 | with_vector=False, 159 | ) 160 | 161 | def _convert_metadata_filter_to_qdrant_filter( 162 | self, 163 | metadata_filter: Optional[DocumentMetadataFilter] = None, 164 | ids: Optional[List[str]] = None, 165 | ) -> Optional[rest.Filter]: 166 | if metadata_filter is None and ids is None: 167 | return None 168 | 169 | must_conditions, should_conditions = [], [] 170 | 171 | # Filtering by document ids 172 | if ids and len(ids) > 0: 173 | for document_id in ids: 174 | should_conditions.append( 175 | rest.FieldCondition( 176 | key="metadata.document_id", 177 | match=rest.MatchValue(value=document_id), 178 | ) 179 | ) 180 | 181 | # Equality filters for the payload attributes 182 | if metadata_filter: 183 | meta_attributes_keys = { 184 | "document_id": "metadata.document_id", 185 | "source": "metadata.source", 186 | "source_id": "metadata.source_id", 187 | "author": "metadata.author", 188 | } 189 | 190 | for meta_attr_name, payload_key in meta_attributes_keys.items(): 191 | attr_value = getattr(metadata_filter, meta_attr_name) 192 | if attr_value is None: 193 | continue 194 | 195 | must_conditions.append( 196 | rest.FieldCondition( 197 | key=payload_key, match=rest.MatchValue(value=attr_value) 198 | ) 199 | ) 200 | 201 | # 
Date filters use range filtering 202 | start_date = metadata_filter.start_date 203 | end_date = metadata_filter.end_date 204 | if start_date or end_date: 205 | gte_filter = ( 206 | to_unix_timestamp(start_date) if start_date is not None else None 207 | ) 208 | lte_filter = ( 209 | to_unix_timestamp(end_date) if end_date is not None else None 210 | ) 211 | must_conditions.append( 212 | rest.FieldCondition( 213 | key="created_at", 214 | range=rest.Range( 215 | gte=gte_filter, 216 | lte=lte_filter, 217 | ), 218 | ) 219 | ) 220 | if 0 == len(must_conditions) and 0 == len(should_conditions): 221 | return None 222 | return rest.Filter(must=must_conditions, should=should_conditions) 223 | 224 | def _convert_scored_point_to_document_chunk_with_score( 225 | self, scored_point: rest.ScoredPoint 226 | ) -> DocumentChunkWithScore: 227 | payload = scored_point.payload or {} 228 | return DocumentChunkWithScore( 229 | id=payload.get("id"), 230 | text=scored_point.payload.get("text"), # type: ignore 231 | metadata=scored_point.payload.get("metadata"), # type: ignore 232 | embedding=scored_point.vector, # type: ignore 233 | score=scored_point.score, 234 | ) 235 | 236 | def _set_up_collection( 237 | self, vector_size: int, distance: str, recreate_collection: bool 238 | ): 239 | distance = rest.Distance[distance.upper()] 240 | 241 | if recreate_collection: 242 | self._recreate_collection(distance, vector_size) 243 | 244 | try: 245 | collection_info = self.client.get_collection(self.collection_name) 246 | current_distance = collection_info.config.params.vectors.distance # type: ignore 247 | current_vector_size = collection_info.config.params.vectors.size # type: ignore 248 | 249 | if current_distance != distance: 250 | raise ValueError( 251 | f"Collection '{self.collection_name}' already exists in Qdrant, " 252 | f"but it is configured with a similarity '{current_distance.name}'. " 253 | f"If you want to use that collection, but with a different " 254 | f"similarity, please set `recreate_collection=True` argument." 255 | ) 256 | 257 | if current_vector_size != vector_size: 258 | raise ValueError( 259 | f"Collection '{self.collection_name}' already exists in Qdrant, " 260 | f"but it is configured with a vector size '{current_vector_size}'. " 261 | f"If you want to use that collection, but with a different " 262 | f"vector size, please set `recreate_collection=True` argument." 
263 | ) 264 | except (UnexpectedResponse, _InactiveRpcError): 265 | self._recreate_collection(distance, vector_size) 266 | 267 | def _recreate_collection(self, distance: rest.Distance, vector_size: int): 268 | self.client.recreate_collection( 269 | self.collection_name, 270 | vectors_config=rest.VectorParams( 271 | size=vector_size, 272 | distance=distance, 273 | ), 274 | ) 275 | 276 | # Create the payload index for the document_id metadata attribute, as it is 277 | # used to delete the document related entries 278 | self.client.create_payload_index( 279 | self.collection_name, 280 | field_name="metadata.document_id", 281 | field_type=PayloadSchemaType.KEYWORD, 282 | ) 283 | self.client.create_payload_index( 284 | self.collection_name, 285 | field_name="created_at", 286 | field_schema=PayloadSchemaType.INTEGER, 287 | ) 288 | 289 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /vectordbs/providers/weaviate_datastore.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | import asyncio 3 | from typing import Dict, List, Optional 4 | from loguru import logger 5 | from weaviate import Client 6 | import weaviate 7 | import os 8 | import uuid 9 | 10 | from weaviate.util import generate_uuid5 11 | 12 | from datastore.datastore import DataStore 13 | from models.models import ( 14 | DocumentChunk, 15 | DocumentChunkMetadata, 16 | DocumentMetadataFilter, 17 | QueryResult, 18 | QueryWithEmbedding, 19 | DocumentChunkWithScore, 20 | Source, 21 | ) 22 | 23 | 24 | WEAVIATE_HOST = os.environ.get("WEAVIATE_HOST", "http://127.0.0.1") 25 | WEAVIATE_PORT = os.environ.get("WEAVIATE_PORT", "8080") 26 | WEAVIATE_USERNAME = os.environ.get("WEAVIATE_USERNAME", None) 27 | WEAVIATE_PASSWORD = os.environ.get("WEAVIATE_PASSWORD", None) 28 | WEAVIATE_SCOPES = os.environ.get("WEAVIATE_SCOPES", None) 29 | WEAVIATE_INDEX = os.environ.get("WEAVIATE_INDEX", "OpenAIDocument") 30 | 31 | WEAVIATE_BATCH_SIZE = int(os.environ.get("WEAVIATE_BATCH_SIZE", 20)) 32 | WEAVIATE_BATCH_DYNAMIC = os.environ.get("WEAVIATE_BATCH_DYNAMIC", False) 33 | WEAVIATE_BATCH_TIMEOUT_RETRIES = int(os.environ.get("WEAVIATE_TIMEOUT_RETRIES", 3)) 34 | WEAVIATE_BATCH_NUM_WORKERS = int(os.environ.get("WEAVIATE_BATCH_NUM_WORKERS", 1)) 35 | 36 | SCHEMA = { 37 | "class": WEAVIATE_INDEX, 38 | "description": "The main class", 39 | "properties": [ 40 | { 41 | "name": "chunk_id", 42 | "dataType": ["string"], 43 | "description": "The chunk id", 44 | }, 45 | { 46 | "name": "document_id", 47 | "dataType": ["string"], 48 | "description": "The document id", 49 | }, 50 | { 51 | "name": "text", 52 | "dataType": ["text"], 53 | "description": "The chunk's text", 54 | }, 55 | { 56 | "name": "source", 57 | "dataType": ["string"], 58 | "description": "The source of the data", 59 | }, 60 | { 61 | "name": "source_id", 62 | "dataType": ["string"], 63 | "description": "The source id", 64 | }, 65 | { 66 | "name": "url", 67 | "dataType": ["string"], 68 | "description": "The source url", 69 | }, 70 | { 71 | "name": "created_at", 72 | "dataType": ["date"], 73 | "description": "Creation date of document", 74 | }, 75 | { 76 | "name": "author", 77 | "dataType": ["string"], 78 | "description": "Document author", 79 | }, 80 | ], 81 | } 82 | 83 | 84 | def extract_schema_properties(schema): 85 | properties = schema["properties"] 86 | 87 | return {property["name"] for property in properties} 88 | 89 | 90 | class WeaviateDataStore(DataStore): 91 | def handle_errors(self, results: Optional[List[dict]]) -> List[str]: 92 | if not self or not results: 93 | return [] 94 | 95 | error_messages = [] 96 | for result in results: 97 | if ( 98 | "result" not in result 99 | or "errors" not in result["result"] 100 | or "error" not in result["result"]["errors"] 101 | ): 102 | continue 103 | for message in result["result"]["errors"]["error"]: 104 | error_messages.append(message["message"]) 105 | logger.exception(message["message"]) 106 | 107 | return error_messages 108 | 109 | def __init__(self): 110 | auth_credentials = self._build_auth_credentials() 111 | 112 | url = f"{WEAVIATE_HOST}:{WEAVIATE_PORT}" 113 | 114 | logger.debug( 115 | f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}" 116 | ) 117 | self.client = Client(url, auth_client_secret=auth_credentials) 118 | self.client.batch.configure( 119 | 
batch_size=WEAVIATE_BATCH_SIZE, 120 | dynamic=WEAVIATE_BATCH_DYNAMIC, # type: ignore 121 | callback=self.handle_errors, # type: ignore 122 | timeout_retries=WEAVIATE_BATCH_TIMEOUT_RETRIES, 123 | num_workers=WEAVIATE_BATCH_NUM_WORKERS, 124 | ) 125 | 126 | if self.client.schema.contains(SCHEMA): 127 | current_schema = self.client.schema.get(WEAVIATE_INDEX) 128 | current_schema_properties = extract_schema_properties(current_schema) 129 | 130 | logger.debug( 131 | f"Found index {WEAVIATE_INDEX} with properties {current_schema_properties}" 132 | ) 133 | logger.debug("Will reuse this schema") 134 | else: 135 | new_schema_properties = extract_schema_properties(SCHEMA) 136 | logger.debug( 137 | f"Creating index {WEAVIATE_INDEX} with properties {new_schema_properties}" 138 | ) 139 | self.client.schema.create_class(SCHEMA) 140 | 141 | @staticmethod 142 | def _build_auth_credentials(): 143 | if WEAVIATE_USERNAME and WEAVIATE_PASSWORD: 144 | return weaviate.auth.AuthClientPassword( 145 | WEAVIATE_USERNAME, WEAVIATE_PASSWORD, WEAVIATE_SCOPES 146 | ) 147 | else: 148 | return None 149 | 150 | async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: 151 | """ 152 | Takes in a list of list of document chunks and inserts them into the database. 153 | Return a list of document ids. 154 | """ 155 | doc_ids = [] 156 | 157 | with self.client.batch as batch: 158 | for doc_id, doc_chunks in chunks.items(): 159 | logger.debug(f"Upserting {doc_id} with {len(doc_chunks)} chunks") 160 | for doc_chunk in doc_chunks: 161 | # we generate a uuid regardless of the format of the document_id because 162 | # weaviate needs a uuid to store each document chunk and 163 | # a document chunk cannot share the same uuid 164 | doc_uuid = generate_uuid5(doc_chunk, WEAVIATE_INDEX) 165 | metadata = doc_chunk.metadata 166 | doc_chunk_dict = doc_chunk.dict() 167 | doc_chunk_dict.pop("metadata") 168 | for key, value in metadata.dict().items(): 169 | doc_chunk_dict[key] = value 170 | doc_chunk_dict["chunk_id"] = doc_chunk_dict.pop("id") 171 | doc_chunk_dict["source"] = ( 172 | doc_chunk_dict.pop("source").value 173 | if doc_chunk_dict["source"] 174 | else None 175 | ) 176 | embedding = doc_chunk_dict.pop("embedding") 177 | 178 | batch.add_data_object( 179 | uuid=doc_uuid, 180 | data_object=doc_chunk_dict, 181 | class_name=WEAVIATE_INDEX, 182 | vector=embedding, 183 | ) 184 | 185 | doc_ids.append(doc_id) 186 | batch.flush() 187 | return doc_ids 188 | 189 | async def _query( 190 | self, 191 | queries: List[QueryWithEmbedding], 192 | ) -> List[QueryResult]: 193 | """ 194 | Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. 
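        Example (illustrative only; assumes a reachable Weaviate instance with the
        OpenAIDocument class, the default WEAVIATE_INDEX, already created in __init__):

            results = await store._query([
                QueryWithEmbedding(query="sample text", embedding=[0.1] * 1536, top_k=3)
            ])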
195 | """ 196 | 197 | async def _single_query(query: QueryWithEmbedding) -> QueryResult: 198 | logger.debug(f"Query: {query.query}") 199 | if not hasattr(query, "filter") or not query.filter: 200 | result = ( 201 | self.client.query.get( 202 | WEAVIATE_INDEX, 203 | [ 204 | "chunk_id", 205 | "document_id", 206 | "text", 207 | "source", 208 | "source_id", 209 | "url", 210 | "created_at", 211 | "author", 212 | ], 213 | ) 214 | .with_hybrid(query=query.query, alpha=0.5, vector=query.embedding) 215 | .with_limit(query.top_k) # type: ignore 216 | .with_additional(["score", "vector"]) 217 | .do() 218 | ) 219 | else: 220 | filters_ = self.build_filters(query.filter) 221 | result = ( 222 | self.client.query.get( 223 | WEAVIATE_INDEX, 224 | [ 225 | "chunk_id", 226 | "document_id", 227 | "text", 228 | "source", 229 | "source_id", 230 | "url", 231 | "created_at", 232 | "author", 233 | ], 234 | ) 235 | .with_hybrid(query=query.query, alpha=0.5, vector=query.embedding) 236 | .with_where(filters_) 237 | .with_limit(query.top_k) # type: ignore 238 | .with_additional(["score", "vector"]) 239 | .do() 240 | ) 241 | 242 | query_results: List[DocumentChunkWithScore] = [] 243 | response = result["data"]["Get"][WEAVIATE_INDEX] 244 | 245 | for resp in response: 246 | result = DocumentChunkWithScore( 247 | id=resp["chunk_id"], 248 | text=resp["text"], 249 | embedding=resp["_additional"]["vector"], 250 | score=resp["_additional"]["score"], 251 | metadata=DocumentChunkMetadata( 252 | document_id=resp["document_id"] if resp["document_id"] else "", 253 | source=Source(resp["source"]), 254 | source_id=resp["source_id"], 255 | url=resp["url"], 256 | created_at=resp["created_at"], 257 | author=resp["author"], 258 | ), 259 | ) 260 | query_results.append(result) 261 | return QueryResult(query=query.query, results=query_results) 262 | 263 | return await asyncio.gather(*[_single_query(query) for query in queries]) 264 | 265 | async def delete( 266 | self, 267 | ids: Optional[List[str]] = None, 268 | filter: Optional[DocumentMetadataFilter] = None, 269 | delete_all: Optional[bool] = None, 270 | ) -> bool: 271 | # TODO 272 | """ 273 | Removes vectors by ids, filter, or everything in the datastore. 274 | Returns whether the operation was successful. 
275 | """ 276 | if delete_all: 277 | logger.debug(f"Deleting all vectors in index {WEAVIATE_INDEX}") 278 | self.client.schema.delete_all() 279 | return True 280 | 281 | if ids: 282 | operands = [ 283 | {"path": ["document_id"], "operator": "Equal", "valueString": id} 284 | for id in ids 285 | ] 286 | 287 | where_clause = {"operator": "Or", "operands": operands} 288 | 289 | logger.debug(f"Deleting vectors from index {WEAVIATE_INDEX} with ids {ids}") 290 | result = self.client.batch.delete_objects( 291 | class_name=WEAVIATE_INDEX, where=where_clause, output="verbose" 292 | ) 293 | 294 | if not bool(result["results"]["successful"]): 295 | logger.debug( 296 | f"Failed to delete the following objects: {result['results']['objects']}" 297 | ) 298 | 299 | if filter: 300 | where_clause = self.build_filters(filter) 301 | 302 | logger.debug( 303 | f"Deleting vectors from index {WEAVIATE_INDEX} with filter {where_clause}" 304 | ) 305 | result = self.client.batch.delete_objects( 306 | class_name=WEAVIATE_INDEX, where=where_clause 307 | ) 308 | 309 | if not bool(result["results"]["successful"]): 310 | logger.debug( 311 | f"Failed to delete the following objects: {result['results']['objects']}" 312 | ) 313 | 314 | return True 315 | 316 | @staticmethod 317 | def build_filters(filter): 318 | if filter.source: 319 | filter.source = filter.source.value 320 | 321 | operands = [] 322 | filter_conditions = { 323 | "source": { 324 | "operator": "Equal", 325 | "value": "query.filter.source.value", 326 | "value_key": "valueString", 327 | }, 328 | "start_date": {"operator": "GreaterThanEqual", "value_key": "valueDate"}, 329 | "end_date": {"operator": "LessThanEqual", "value_key": "valueDate"}, 330 | "default": {"operator": "Equal", "value_key": "valueString"}, 331 | } 332 | 333 | for attr, value in filter.__dict__.items(): 334 | if value is not None: 335 | filter_condition = filter_conditions.get( 336 | attr, filter_conditions["default"] 337 | ) 338 | value_key = filter_condition["value_key"] 339 | 340 | operand = { 341 | "path": [ 342 | attr 343 | if not (attr == "start_date" or attr == "end_date") 344 | else "created_at" 345 | ], 346 | "operator": filter_condition["operator"], 347 | value_key: value, 348 | } 349 | 350 | operands.append(operand) 351 | 352 | return {"operator": "And", "operands": operands} 353 | 354 | @staticmethod 355 | def _is_valid_weaviate_id(candidate_id: str) -> bool: 356 | """ 357 | Check if candidate_id is a valid UUID for weaviate's use 358 | 359 | Weaviate supports UUIDs of version 3, 4 and 5. This function checks if the candidate_id is a valid UUID of one of these versions. 360 | See https://weaviate.io/developers/weaviate/more-resources/faq#q-are-there-restrictions-on-uuid-formatting-do-i-have-to-adhere-to-any-standards 361 | for more information. 
362 | """ 363 | acceptable_version = [3, 4, 5] 364 | 365 | try: 366 | result = uuid.UUID(candidate_id) 367 | if result.version not in acceptable_version: 368 | return False 369 | else: 370 | return True 371 | except ValueError: 372 | return False 373 | -------------------------------------------------------------------------------- /vectordbs/providers/redis_datastore.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | import re 5 | import json 6 | import redis.asyncio as redis 7 | import numpy as np 8 | 9 | from redis.commands.search.query import Query as RediSearchQuery 10 | from redis.commands.search.indexDefinition import IndexDefinition, IndexType 11 | from redis.commands.search.field import ( 12 | TagField, 13 | TextField, 14 | NumericField, 15 | VectorField, 16 | ) 17 | from typing import Dict, List, Optional 18 | from datastore.datastore import DataStore 19 | from models.models import ( 20 | DocumentChunk, 21 | DocumentMetadataFilter, 22 | DocumentChunkWithScore, 23 | DocumentMetadataFilter, 24 | QueryResult, 25 | QueryWithEmbedding, 26 | ) 27 | from services.date import to_unix_timestamp 28 | 29 | # Read environment variables for Redis 30 | REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") 31 | REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379)) 32 | REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") 33 | REDIS_INDEX_NAME = os.environ.get("REDIS_INDEX_NAME", "index") 34 | REDIS_DOC_PREFIX = os.environ.get("REDIS_DOC_PREFIX", "doc") 35 | REDIS_DISTANCE_METRIC = os.environ.get("REDIS_DISTANCE_METRIC", "COSINE") 36 | REDIS_INDEX_TYPE = os.environ.get("REDIS_INDEX_TYPE", "FLAT") 37 | assert REDIS_INDEX_TYPE in ("FLAT", "HNSW") 38 | 39 | # OpenAI Ada Embeddings Dimension 40 | VECTOR_DIMENSION = 1536 41 | 42 | # RediSearch constants 43 | REDIS_REQUIRED_MODULES = [ 44 | {"name": "search", "ver": 20600}, 45 | {"name": "ReJSON", "ver": 20404} 46 | ] 47 | REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]") 48 | REDIS_SEARCH_SCHEMA = { 49 | "document_id": TagField("$.document_id", as_name="document_id"), 50 | "metadata": { 51 | # "source_id": TagField("$.metadata.source_id", as_name="source_id"), 52 | "source": TagField("$.metadata.source", as_name="source"), 53 | # "author": TextField("$.metadata.author", as_name="author"), 54 | # "created_at": NumericField("$.metadata.created_at", as_name="created_at"), 55 | }, 56 | "embedding": VectorField( 57 | "$.embedding", 58 | REDIS_INDEX_TYPE, 59 | { 60 | "TYPE": "FLOAT64", 61 | "DIM": VECTOR_DIMENSION, 62 | "DISTANCE_METRIC": REDIS_DISTANCE_METRIC, 63 | }, 64 | as_name="embedding", 65 | ), 66 | } 67 | 68 | # Helper functions 69 | def unpack_schema(d: dict): 70 | for v in d.values(): 71 | if isinstance(v, dict): 72 | yield from unpack_schema(v) 73 | else: 74 | yield v 75 | 76 | async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]): 77 | 78 | installed_modules = (await client.info()).get("modules", []) 79 | installed_modules = {module["name"]: module for module in installed_modules} 80 | for module in modules: 81 | if module["name"] not in installed_modules or int(installed_modules[module["name"]]["ver"]) < int(module["ver"]): 82 | error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. 
" \ 83 | "Please refer to Redis Stack docs: https://redis.io/docs/stack/" 84 | logging.error(error_message) 85 | raise ValueError(error_message) 86 | 87 | 88 | 89 | class RedisDataStore(DataStore): 90 | def __init__(self, client: redis.Redis): 91 | self.client = client 92 | # Init default metadata with sentinel values in case the document written has no metadata 93 | self._default_metadata = { 94 | field: "_null_" for field in REDIS_SEARCH_SCHEMA["metadata"] 95 | } 96 | 97 | ### Redis Helper Methods ### 98 | 99 | @classmethod 100 | async def init(cls): 101 | """ 102 | Setup the index if it does not exist. 103 | """ 104 | try: 105 | # Connect to the Redis Client 106 | logging.info("Connecting to Redis") 107 | client = redis.Redis( 108 | host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD 109 | ) 110 | except Exception as e: 111 | logging.error(f"Error setting up Redis: {e}") 112 | raise e 113 | 114 | await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES) 115 | 116 | try: 117 | # Check for existence of RediSearch Index 118 | await client.ft(REDIS_INDEX_NAME).info() 119 | logging.info(f"RediSearch index {REDIS_INDEX_NAME} already exists") 120 | except: 121 | # Create the RediSearch Index 122 | logging.info(f"Creating new RediSearch index {REDIS_INDEX_NAME}") 123 | definition = IndexDefinition( 124 | prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON 125 | ) 126 | fields = list(unpack_schema(REDIS_SEARCH_SCHEMA)) 127 | await client.ft(REDIS_INDEX_NAME).create_index( 128 | fields=fields, definition=definition 129 | ) 130 | return cls(client) 131 | 132 | @staticmethod 133 | def _redis_key(document_id: str, chunk_id: str) -> str: 134 | """ 135 | Create the JSON key for document chunks in Redis. 136 | 137 | Args: 138 | document_id (str): Document Identifier 139 | chunk_id (str): Chunk Identifier 140 | 141 | Returns: 142 | str: JSON key string. 143 | """ 144 | return f"doc:{document_id}:chunk:{chunk_id}" 145 | 146 | @staticmethod 147 | def _escape(value: str) -> str: 148 | """ 149 | Escape filter value. 150 | 151 | Args: 152 | value (str): Value to escape. 153 | 154 | Returns: 155 | str: Escaped filter value for RediSearch. 156 | """ 157 | 158 | def escape_symbol(match) -> str: 159 | value = match.group(0) 160 | return f"\\{value}" 161 | 162 | return REDIS_DEFAULT_ESCAPED_CHARS.sub(escape_symbol, value) 163 | 164 | def _get_redis_chunk(self, chunk: DocumentChunk) -> dict: 165 | """ 166 | Convert DocumentChunk into a JSON object for storage 167 | in Redis. 168 | 169 | Args: 170 | chunk (DocumentChunk): Chunk of a Document. 171 | 172 | Returns: 173 | dict: JSON object for storage in Redis. 174 | """ 175 | # Convert chunk -> dict 176 | data = chunk.__dict__ 177 | metadata = chunk.metadata.__dict__ 178 | data["chunk_id"] = data.pop("id") 179 | 180 | # Prep Redis Metadata 181 | redis_metadata = dict(self._default_metadata) 182 | if metadata: 183 | for field, value in metadata.items(): 184 | if value: 185 | if field == "created_at": 186 | redis_metadata[field] = to_unix_timestamp(value) # type: ignore 187 | else: 188 | redis_metadata[field] = value 189 | data["metadata"] = redis_metadata 190 | return data 191 | 192 | def _get_redis_query(self, query: QueryWithEmbedding) -> RediSearchQuery: 193 | """ 194 | Convert a QueryWithEmbedding into a RediSearchQuery. 195 | 196 | Args: 197 | query (QueryWithEmbedding): Search query. 198 | 199 | Returns: 200 | RediSearchQuery: Query for RediSearch. 
201 | """ 202 | filter_str: str = "" 203 | 204 | # RediSearch field type to query string 205 | def _typ_to_str(typ, field, value) -> str: # type: ignore 206 | if isinstance(typ, TagField): 207 | return f"@{field}:{{{self._escape(value)}}} " 208 | elif isinstance(typ, TextField): 209 | return f"@{field}:{self._escape(value)} " 210 | elif isinstance(typ, NumericField): 211 | num = to_unix_timestamp(value) 212 | match field: 213 | case "start_date": 214 | return f"@{field}:[{num} +inf] " 215 | case "end_date": 216 | return f"@{field}:[-inf {num}] " 217 | 218 | # Build filter 219 | if query.filter: 220 | for field, value in query.filter.__dict__.items(): 221 | if not value: 222 | continue 223 | if field in REDIS_SEARCH_SCHEMA: 224 | filter_str += _typ_to_str(REDIS_SEARCH_SCHEMA[field], field, value) 225 | elif field in REDIS_SEARCH_SCHEMA["metadata"]: 226 | if field == "source": # handle the enum 227 | value = value.value 228 | filter_str += _typ_to_str( 229 | REDIS_SEARCH_SCHEMA["metadata"][field], field, value 230 | ) 231 | elif field in ["start_date", "end_date"]: 232 | filter_str += _typ_to_str( 233 | REDIS_SEARCH_SCHEMA["metadata"]["created_at"], field, value 234 | ) 235 | 236 | # Postprocess filter string 237 | filter_str = filter_str.strip() 238 | filter_str = filter_str if filter_str else "*" 239 | 240 | # Prepare query string 241 | query_str = ( 242 | f"({filter_str})=>[KNN {query.top_k} @embedding $embedding as score]" 243 | ) 244 | return ( 245 | RediSearchQuery(query_str) 246 | .sort_by("score") 247 | .paging(0, query.top_k) 248 | .dialect(2) 249 | ) 250 | 251 | async def _redis_delete(self, keys: List[str]): 252 | """ 253 | Delete a list of keys from Redis. 254 | 255 | Args: 256 | keys (List[str]): List of keys to delete. 257 | """ 258 | # Delete the keys 259 | await asyncio.gather(*[self.client.delete(key) for key in keys]) 260 | 261 | ####### 262 | 263 | async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: 264 | """ 265 | Takes in a list of list of document chunks and inserts them into the database. 266 | Return a list of document ids. 267 | """ 268 | # Initialize a list of ids to return 269 | doc_ids: List[str] = [] 270 | 271 | # Loop through the dict items 272 | for doc_id, chunk_list in chunks.items(): 273 | 274 | # Append the id to the ids list 275 | doc_ids.append(doc_id) 276 | 277 | # Write chunks in a pipelines 278 | async with self.client.pipeline(transaction=False) as pipe: 279 | for chunk in chunk_list: 280 | key = self._redis_key(doc_id, chunk.id) 281 | data = self._get_redis_chunk(chunk) 282 | await pipe.json().set(key, "$", data) 283 | await pipe.execute() 284 | 285 | return doc_ids 286 | 287 | async def _query( 288 | self, 289 | queries: List[QueryWithEmbedding], 290 | ) -> List[QueryResult]: 291 | """ 292 | Takes in a list of queries with embeddings and filters and 293 | returns a list of query results with matching document chunks and scores. 
294 | """ 295 | # Prepare query responses and results object 296 | results: List[QueryResult] = [] 297 | 298 | # Gather query results in a pipeline 299 | logging.info(f"Gathering {len(queries)} query results", flush=True) 300 | for query in queries: 301 | 302 | logging.info(f"Query: {query.query}") 303 | query_results: List[DocumentChunkWithScore] = [] 304 | 305 | # Extract Redis query 306 | redis_query: RediSearchQuery = self._get_redis_query(query) 307 | embedding = np.array(query.embedding, dtype=np.float64).tobytes() 308 | 309 | # Perform vector search 310 | query_response = await self.client.ft(REDIS_INDEX_NAME).search( 311 | redis_query, {"embedding": embedding} 312 | ) 313 | 314 | # Iterate through the most similar documents 315 | for doc in query_response.docs: 316 | # Load JSON data 317 | doc_json = json.loads(doc.json) 318 | # Create document chunk object with score 319 | result = DocumentChunkWithScore( 320 | id=doc_json["metadata"]["document_id"], 321 | score=doc.score, 322 | text=doc_json["text"], 323 | metadata=doc_json["metadata"] 324 | ) 325 | query_results.append(result) 326 | 327 | # Add to overall results 328 | results.append(QueryResult(query=query.query, results=query_results)) 329 | 330 | return results 331 | 332 | async def _find_keys(self, pattern: str) -> List[str]: 333 | return [key async for key in self.client.scan_iter(pattern)] 334 | 335 | async def delete( 336 | self, 337 | ids: Optional[List[str]] = None, 338 | filter: Optional[DocumentMetadataFilter] = None, 339 | delete_all: Optional[bool] = None, 340 | ) -> bool: 341 | """ 342 | Removes vectors by ids, filter, or everything in the datastore. 343 | Returns whether the operation was successful. 344 | """ 345 | # Delete all vectors from the index if delete_all is True 346 | if delete_all: 347 | try: 348 | logging.info(f"Deleting all documents from index") 349 | await self.client.ft(REDIS_INDEX_NAME).dropindex(True) 350 | logging.info(f"Deleted all documents successfully") 351 | return True 352 | except Exception as e: 353 | logging.info(f"Error deleting all documents: {e}") 354 | raise e 355 | 356 | # Delete by filter 357 | if filter: 358 | # TODO - extend this to work with other metadata filters? 
359 | if filter.document_id: 360 | try: 361 | keys = await self._find_keys( 362 | f"{REDIS_DOC_PREFIX}:{filter.document_id}:*" 363 | ) 364 | await self._redis_delete(keys) 365 | logging.info(f"Deleted document {filter.document_id} successfully") 366 | except Exception as e: 367 | logging.info(f"Error deleting document {filter.document_id}: {e}") 368 | raise e 369 | 370 | # Delete by explicit ids (Redis keys) 371 | if ids: 372 | try: 373 | logging.info(f"Deleting document ids {ids}") 374 | keys = [] 375 | # find all keys associated with the document ids 376 | for document_id in ids: 377 | doc_keys = await self._find_keys( 378 | pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*" 379 | ) 380 | keys.extend(doc_keys) 381 | # delete all keys 382 | logging.info(f"Deleting {len(keys)} keys from Redis") 383 | await self._redis_delete(keys) 384 | except Exception as e: 385 | logging.info(f"Error deleting ids: {e}") 386 | raise e 387 | 388 | return True 389 | -------------------------------------------------------------------------------- /vectordbs/providers/milvus_datastore.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import asyncio 4 | 5 | from typing import Dict, List, Optional 6 | from pymilvus import ( 7 | Collection, 8 | connections, 9 | utility, 10 | FieldSchema, 11 | DataType, 12 | CollectionSchema, 13 | MilvusException, 14 | ) 15 | from uuid import uuid4 16 | 17 | 18 | from services.date import to_unix_timestamp 19 | from datastore.datastore import DataStore 20 | from models.models import ( 21 | DocumentChunk, 22 | DocumentChunkMetadata, 23 | Source, 24 | DocumentMetadataFilter, 25 | QueryResult, 26 | QueryWithEmbedding, 27 | DocumentChunkWithScore, 28 | ) 29 | 30 | MILVUS_COLLECTION = os.environ.get("MILVUS_COLLECTION") or "c" + uuid4().hex 31 | MILVUS_HOST = os.environ.get("MILVUS_HOST") or "localhost" 32 | MILVUS_PORT = os.environ.get("MILVUS_PORT") or 19530 33 | MILVUS_USER = os.environ.get("MILVUS_USER") 34 | MILVUS_PASSWORD = os.environ.get("MILVUS_PASSWORD") 35 | MILVUS_USE_SECURITY = False if MILVUS_PASSWORD is None else True 36 | 37 | MILVUS_INDEX_PARAMS = os.environ.get("MILVUS_INDEX_PARAMS") 38 | MILVUS_SEARCH_PARAMS = os.environ.get("MILVUS_SEARCH_PARAMS") 39 | MILVUS_CONSISTENCY_LEVEL = os.environ.get("MILVUS_CONSISTENCY_LEVEL") 40 | 41 | UPSERT_BATCH_SIZE = 100 42 | OUTPUT_DIM = 1536 43 | EMBEDDING_FIELD = "embedding" 44 | 45 | 46 | class Required: 47 | pass 48 | 49 | # The fields names that we are going to be storing within Milvus, the field declaration for schema creation, and the default value 50 | SCHEMA_V1 = [ 51 | ( 52 | "pk", 53 | FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), 54 | Required, 55 | ), 56 | ( 57 | EMBEDDING_FIELD, 58 | FieldSchema(name=EMBEDDING_FIELD, dtype=DataType.FLOAT_VECTOR, dim=OUTPUT_DIM), 59 | Required, 60 | ), 61 | ( 62 | "text", 63 | FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), 64 | Required, 65 | ), 66 | ( 67 | "document_id", 68 | FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=65535), 69 | "", 70 | ), 71 | ( 72 | "source_id", 73 | FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=65535), 74 | "", 75 | ), 76 | ( 77 | "id", 78 | FieldSchema( 79 | name="id", 80 | dtype=DataType.VARCHAR, 81 | max_length=65535, 82 | ), 83 | "", 84 | ), 85 | ( 86 | "source", 87 | FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=65535), 88 | "", 89 | ), 90 | ("url", FieldSchema(name="url", 
dtype=DataType.VARCHAR, max_length=65535), ""), 91 | ("created_at", FieldSchema(name="created_at", dtype=DataType.INT64), -1), 92 | ( 93 | "author", 94 | FieldSchema(name="author", dtype=DataType.VARCHAR, max_length=65535), 95 | "", 96 | ), 97 | ] 98 | 99 | # V2 schema, remomve the "pk" field 100 | SCHEMA_V2 = SCHEMA_V1[1:] 101 | SCHEMA_V2[4][1].is_primary = True 102 | 103 | 104 | class MilvusDataStore(DataStore): 105 | def __init__( 106 | self, 107 | create_new: Optional[bool] = False, 108 | consistency_level: str = "Bounded", 109 | ): 110 | """Create a Milvus DataStore. 111 | 112 | The Milvus Datastore allows for storing your indexes and metadata within a Milvus instance. 113 | 114 | Args: 115 | create_new (Optional[bool], optional): Whether to overwrite if collection already exists. Defaults to True. 116 | consistency_level(str, optional): Specify the collection consistency level. 117 | Defaults to "Bounded" for search performance. 118 | Set to "Strong" in test cases for result validation. 119 | """ 120 | # Overwrite the default consistency level by MILVUS_CONSISTENCY_LEVEL 121 | self._consistency_level = MILVUS_CONSISTENCY_LEVEL or consistency_level 122 | self._create_connection() 123 | 124 | self._create_collection(MILVUS_COLLECTION, create_new) # type: ignore 125 | self._create_index() 126 | 127 | def _print_info(self, msg): 128 | # TODO: logger 129 | print(msg) 130 | 131 | def _print_err(self, msg): 132 | # TODO: logger 133 | print(msg) 134 | 135 | def _get_schema(self): 136 | return SCHEMA_V1 if self._schema_ver == "V1" else SCHEMA_V2 137 | 138 | def _create_connection(self): 139 | try: 140 | self.alias = "" 141 | # Check if the connection already exists 142 | for x in connections.list_connections(): 143 | addr = connections.get_connection_addr(x[0]) 144 | if x[1] and ('address' in addr) and (addr['address'] == "{}:{}".format(MILVUS_HOST, MILVUS_PORT)): 145 | self.alias = x[0] 146 | self._print_info("Reuse connection to Milvus server '{}:{}' with alias '{:s}'" 147 | .format(MILVUS_HOST, MILVUS_PORT, self.alias)) 148 | break 149 | 150 | # Connect to the Milvus instance using the passed in Environment variables 151 | if len(self.alias) == 0: 152 | self.alias = uuid4().hex 153 | connections.connect( 154 | alias=self.alias, 155 | host=MILVUS_HOST, 156 | port=MILVUS_PORT, 157 | user=MILVUS_USER, # type: ignore 158 | password=MILVUS_PASSWORD, # type: ignore 159 | secure=MILVUS_USE_SECURITY, 160 | ) 161 | self._print_info("Create connection to Milvus server '{}:{}' with alias '{:s}'" 162 | .format(MILVUS_HOST, MILVUS_PORT, self.alias)) 163 | except Exception as e: 164 | self._print_err("Failed to create connection to Milvus server '{}:{}', error: {}" 165 | .format(MILVUS_HOST, MILVUS_PORT, e)) 166 | 167 | def _create_collection(self, collection_name, create_new: bool) -> None: 168 | """Create a collection based on environment and passed in variables. 169 | 170 | Args: 171 | create_new (bool): Whether to overwrite if collection already exists. 
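Note:
    New collections are created with the V2 schema (no auto-generated "pk"
    field; "id" is the primary key). For pre-existing collections the schema
    version is detected by checking whether "id" is the primary field.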
172 | """ 173 | try: 174 | self._schema_ver = "V1" 175 | # If the collection exists and create_new is True, drop the existing collection 176 | if utility.has_collection(collection_name, using=self.alias) and create_new: 177 | utility.drop_collection(collection_name, using=self.alias) 178 | 179 | # Check if the collection doesnt exist 180 | if utility.has_collection(collection_name, using=self.alias) is False: 181 | # If it doesnt exist use the field params from init to create a new schem 182 | schema = [field[1] for field in SCHEMA_V2] 183 | schema = CollectionSchema(schema) 184 | # Use the schema to create a new collection 185 | self.col = Collection( 186 | collection_name, 187 | schema=schema, 188 | using=self.alias, 189 | consistency_level=self._consistency_level, 190 | ) 191 | self._schema_ver = "V2" 192 | self._print_info("Create Milvus collection '{}' with schema {} and consistency level {}" 193 | .format(collection_name, self._schema_ver, self._consistency_level)) 194 | else: 195 | # If the collection exists, point to it 196 | self.col = Collection( 197 | collection_name, using=self.alias 198 | ) # type: ignore 199 | # Which sechma is used 200 | for field in self.col.schema.fields: 201 | if field.name == "id" and field.is_primary: 202 | self._schema_ver = "V2" 203 | break 204 | self._print_info("Milvus collection '{}' already exists with schema {}" 205 | .format(collection_name, self._schema_ver)) 206 | except Exception as e: 207 | self._print_err("Failed to create collection '{}', error: {}".format(collection_name, e)) 208 | 209 | def _create_index(self): 210 | # TODO: verify index/search params passed by os.environ 211 | self.index_params = MILVUS_INDEX_PARAMS or None 212 | self.search_params = MILVUS_SEARCH_PARAMS or None 213 | try: 214 | # If no index on the collection, create one 215 | if len(self.col.indexes) == 0: 216 | if self.index_params is not None: 217 | # Convert the string format to JSON format parameters passed by MILVUS_INDEX_PARAMS 218 | self.index_params = json.loads(self.index_params) 219 | self._print_info("Create Milvus index: {}".format(self.index_params)) 220 | # Create an index on the 'embedding' field with the index params found in init 221 | self.col.create_index(EMBEDDING_FIELD, index_params=self.index_params) 222 | else: 223 | # If no index param supplied, to first create an HNSW index for Milvus 224 | try: 225 | i_p = { 226 | "metric_type": "IP", 227 | "index_type": "HNSW", 228 | "params": {"M": 8, "efConstruction": 64}, 229 | } 230 | self._print_info("Attempting creation of Milvus '{}' index".format(i_p["index_type"])) 231 | self.col.create_index(EMBEDDING_FIELD, index_params=i_p) 232 | self.index_params = i_p 233 | self._print_info("Creation of Milvus '{}' index successful".format(i_p["index_type"])) 234 | # If create fails, most likely due to being Zilliz Cloud instance, try to create an AutoIndex 235 | except MilvusException: 236 | self._print_info("Attempting creation of Milvus default index") 237 | i_p = {"metric_type": "IP", "index_type": "AUTOINDEX", "params": {}} 238 | self.col.create_index(EMBEDDING_FIELD, index_params=i_p) 239 | self.index_params = i_p 240 | self._print_info("Creation of Milvus default index successful") 241 | # If an index already exists, grab its params 242 | else: 243 | # How about if the first index is not vector index? 
244 | for index in self.col.indexes: 245 | idx = index.to_dict() 246 | if idx["field"] == EMBEDDING_FIELD: 247 | self._print_info("Index already exists: {}".format(idx)) 248 | self.index_params = idx['index_param'] 249 | break 250 | 251 | self.col.load() 252 | 253 | if self.search_params is not None: 254 | # Convert the string format to JSON format parameters passed by MILVUS_SEARCH_PARAMS 255 | self.search_params = json.loads(self.search_params) 256 | else: 257 | # The default search params 258 | metric_type = "IP" 259 | if "metric_type" in self.index_params: 260 | metric_type = self.index_params["metric_type"] 261 | default_search_params = { 262 | "IVF_FLAT": {"metric_type": metric_type, "params": {"nprobe": 10}}, 263 | "IVF_SQ8": {"metric_type": metric_type, "params": {"nprobe": 10}}, 264 | "IVF_PQ": {"metric_type": metric_type, "params": {"nprobe": 10}}, 265 | "HNSW": {"metric_type": metric_type, "params": {"ef": 10}}, 266 | "RHNSW_FLAT": {"metric_type": metric_type, "params": {"ef": 10}}, 267 | "RHNSW_SQ": {"metric_type": metric_type, "params": {"ef": 10}}, 268 | "RHNSW_PQ": {"metric_type": metric_type, "params": {"ef": 10}}, 269 | "IVF_HNSW": {"metric_type": metric_type, "params": {"nprobe": 10, "ef": 10}}, 270 | "ANNOY": {"metric_type": metric_type, "params": {"search_k": 10}}, 271 | "AUTOINDEX": {"metric_type": metric_type, "params": {}}, 272 | } 273 | # Set the search params 274 | self.search_params = default_search_params[self.index_params["index_type"]] 275 | self._print_info("Milvus search parameters: {}".format(self.search_params)) 276 | except Exception as e: 277 | self._print_err("Failed to create index, error: {}".format(e)) 278 | 279 | async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: 280 | """Upsert chunks into the datastore. 281 | 282 | Args: 283 | chunks (Dict[str, List[DocumentChunk]]): A list of DocumentChunks to insert 284 | 285 | Raises: 286 | e: Error in upserting data. 287 | 288 | Returns: 289 | List[str]: The document_id's that were inserted. 
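Illustrative call (chunk objects are placeholders):
    doc_ids = await store._upsert({"doc-1": [chunk_a, chunk_b]})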
290 | """ 291 | try: 292 | # The doc id's to return for the upsert 293 | doc_ids: List[str] = [] 294 | # List to collect all the insert data, skip the "pk" for schema V1 295 | offset = 1 if self._schema_ver == "V1" else 0 296 | insert_data = [[] for _ in range(len(self._get_schema()) - offset)] 297 | 298 | # Go through each document chunklist and grab the data 299 | for doc_id, chunk_list in chunks.items(): 300 | # Append the doc_id to the list we are returning 301 | doc_ids.append(doc_id) 302 | # Examine each chunk in the chunklist 303 | for chunk in chunk_list: 304 | # Extract data from the chunk 305 | list_of_data = self._get_values(chunk) 306 | # Check if the data is valid 307 | if list_of_data is not None: 308 | # Append each field to the insert_data 309 | for x in range(len(insert_data)): 310 | insert_data[x].append(list_of_data[x]) 311 | # Slice up our insert data into batches 312 | batches = [ 313 | insert_data[i : i + UPSERT_BATCH_SIZE] 314 | for i in range(0, len(insert_data), UPSERT_BATCH_SIZE) 315 | ] 316 | 317 | # Attempt to insert each batch into our collection 318 | # batch data can work with both V1 and V2 schema 319 | for batch in batches: 320 | if len(batch[0]) != 0: 321 | try: 322 | self._print_info(f"Upserting batch of size {len(batch[0])}") 323 | self.col.insert(batch) 324 | self._print_info(f"Upserted batch successfully") 325 | except Exception as e: 326 | self._print_err(f"Failed to insert batch records, error: {e}") 327 | raise e 328 | 329 | # This setting perfoms flushes after insert. Small insert == bad to use 330 | # self.col.flush() 331 | return doc_ids 332 | except Exception as e: 333 | self._print_err("Failed to insert records, error: {}".format(e)) 334 | return [] 335 | 336 | 337 | def _get_values(self, chunk: DocumentChunk) -> List[any] | None: # type: ignore 338 | """Convert the chunk into a list of values to insert whose indexes align with fields. 339 | 340 | Args: 341 | chunk (DocumentChunk): The chunk to convert. 342 | 343 | Returns: 344 | List (any): The values to insert. 345 | """ 346 | # Convert DocumentChunk and its sub models to dict 347 | values = chunk.dict() 348 | # Unpack the metadata into the same dict 349 | meta = values.pop("metadata") 350 | values.update(meta) 351 | 352 | # Convert date to int timestamp form 353 | if values["created_at"]: 354 | values["created_at"] = to_unix_timestamp(values["created_at"]) 355 | 356 | # If source exists, change from Source object to the string value it holds 357 | if values["source"]: 358 | values["source"] = values["source"].value 359 | # List to collect data we will return 360 | ret = [] 361 | # Grab data responding to each field, excluding the hidden auto pk field for schema V1 362 | offset = 1 if self._schema_ver == "V1" else 0 363 | for key, _, default in self._get_schema()[offset:]: 364 | # Grab the data at the key and default to our defaults set in init 365 | x = values.get(key) or default 366 | # If one of our required fields is missing, ignore the entire entry 367 | if x is Required: 368 | self._print_info("Chunk " + values["id"] + " missing " + key + " skipping") 369 | return None 370 | # Add the corresponding value if it passes the tests 371 | ret.append(x) 372 | return ret 373 | 374 | async def _query( 375 | self, 376 | queries: List[QueryWithEmbedding], 377 | ) -> List[QueryResult]: 378 | """Query the QueryWithEmbedding against the MilvusDocumentSearch 379 | 380 | Search the embedding and its filter in the collection. 
381 | 382 | Args: 383 | queries (List[QueryWithEmbedding]): The list of searches to perform. 384 | 385 | Returns: 386 | List[QueryResult]: Results for each search. 387 | """ 388 | # Async to perform the query, adapted from pinecone implementation 389 | async def _single_query(query: QueryWithEmbedding) -> QueryResult: 390 | try: 391 | filter = None 392 | # Set the filter to expression that is valid for Milvus 393 | if query.filter is not None: 394 | # Either a valid filter or None will be returned 395 | filter = self._get_filter(query.filter) 396 | 397 | # Perform our search 398 | return_from = 2 if self._schema_ver == "V1" else 1 399 | res = self.col.search( 400 | data=[query.embedding], 401 | anns_field=EMBEDDING_FIELD, 402 | param=self.search_params, 403 | limit=query.top_k, 404 | expr=filter, 405 | output_fields=[ 406 | field[0] for field in self._get_schema()[return_from:] 407 | ], # Ignoring pk, embedding 408 | ) 409 | # Results that will hold our DocumentChunkWithScores 410 | results = [] 411 | # Parse every result for our search 412 | for hit in res[0]: # type: ignore 413 | # The distance score for the search result, falls under DocumentChunkWithScore 414 | score = hit.score 415 | # Our metadata info, falls under DocumentChunkMetadata 416 | metadata = {} 417 | # Grab the values that correspond to our fields, ignore pk and embedding. 418 | for x in [field[0] for field in self._get_schema()[return_from:]]: 419 | metadata[x] = hit.entity.get(x) 420 | # If the source isn't valid, convert to None 421 | if metadata["source"] not in Source.__members__: 422 | metadata["source"] = None 423 | # Text falls under the DocumentChunk 424 | text = metadata.pop("text") 425 | # Id falls under the DocumentChunk 426 | ids = metadata.pop("id") 427 | chunk = DocumentChunkWithScore( 428 | id=ids, 429 | score=score, 430 | text=text, 431 | metadata=DocumentChunkMetadata(**metadata), 432 | ) 433 | results.append(chunk) 434 | 435 | # TODO: decide on doing queries to grab the embedding itself, slows down performance as double query occurs 436 | 437 | return QueryResult(query=query.query, results=results) 438 | except Exception as e: 439 | self._print_err("Failed to query, error: {}".format(e)) 440 | return QueryResult(query=query.query, results=[]) 441 | 442 | results: List[QueryResult] = await asyncio.gather( 443 | *[_single_query(query) for query in queries] 444 | ) 445 | return results 446 | 447 | async def delete( 448 | self, 449 | ids: Optional[List[str]] = None, 450 | filter: Optional[DocumentMetadataFilter] = None, 451 | delete_all: Optional[bool] = None, 452 | ) -> bool: 453 | """Delete the entities based either on the chunk_id of the vector, 454 | 455 | Args: 456 | ids (Optional[List[str]], optional): The document_ids to delete. Defaults to None. 457 | filter (Optional[DocumentMetadataFilter], optional): The filter to delete by. Defaults to None. 458 | delete_all (Optional[bool], optional): Whether to drop the collection and recreate it. Defaults to None. 
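Returns:
    bool: True once the requested deletions have been applied (when delete_all
    is set, the collection is dropped and recreated).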
459 | """ 460 | # If deleting all, drop and create the new collection 461 | if delete_all: 462 | coll_name = self.col.name 463 | self._print_info("Delete the entire collection {} and create new one".format(coll_name)) 464 | # Release the collection from memory 465 | self.col.release() 466 | # Drop the collection 467 | self.col.drop() 468 | # Recreate the new collection 469 | self._create_collection(coll_name, True) 470 | self._create_index() 471 | return True 472 | 473 | # Keep track of how many we have deleted for later printing 474 | delete_count = 0 475 | batch_size = 100 476 | pk_name = "pk" if self._schema_ver == "V1" else "id" 477 | try: 478 | # According to the api design, the ids is a list of document_id, 479 | # document_id is not primary key, use query+delete to workaround, 480 | # in future version we can delete by expression 481 | if (ids is not None) and len(ids) > 0: 482 | # Add quotation marks around the string format id 483 | ids = ['"' + str(id) + '"' for id in ids] 484 | # Query for the pk's of entries that match id's 485 | ids = self.col.query(f"document_id in [{','.join(ids)}]") 486 | # Convert to list of pks 487 | pks = [str(entry[pk_name]) for entry in ids] # type: ignore 488 | # for schema V2, the "id" is varchar, rewrite the expression 489 | if self._schema_ver != "V1": 490 | pks = ['"' + pk + '"' for pk in pks] 491 | 492 | # Delete by ids batch by batch(avoid too long expression) 493 | self._print_info("Apply {:d} deletions to schema {:s}".format(len(pks), self._schema_ver)) 494 | while len(pks) > 0: 495 | batch_pks = pks[:batch_size] 496 | pks = pks[batch_size:] 497 | # Delete the entries batch by batch 498 | res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") 499 | # Increment our deleted count 500 | delete_count += int(res.delete_count) # type: ignore 501 | except Exception as e: 502 | self._print_err("Failed to delete by ids, error: {}".format(e)) 503 | 504 | try: 505 | # Check if empty filter 506 | if filter is not None: 507 | # Convert filter to milvus expression 508 | filter = self._get_filter(filter) # type: ignore 509 | # Check if there is anything to filter 510 | if len(filter) != 0: # type: ignore 511 | # Query for the pk's of entries that match filter 512 | res = self.col.query(filter) # type: ignore 513 | # Convert to list of pks 514 | pks = [str(entry[pk_name]) for entry in res] # type: ignore 515 | # for schema V2, the "id" is varchar, rewrite the expression 516 | if self._schema_ver != "V1": 517 | pks = ['"' + pk + '"' for pk in pks] 518 | # Check to see if there are valid pk's to delete, delete batch by batch(avoid too long expression) 519 | while len(pks) > 0: # type: ignore 520 | batch_pks = pks[:batch_size] 521 | pks = pks[batch_size:] 522 | # Delete the entries batch by batch 523 | res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") # type: ignore 524 | # Increment our delete count 525 | delete_count += int(res.delete_count) # type: ignore 526 | except Exception as e: 527 | self._print_err("Failed to delete by filter, error: {}".format(e)) 528 | 529 | self._print_info("{:d} records deleted".format(delete_count)) 530 | 531 | # This setting performs flushes after delete. Small delete == bad to use 532 | # self.col.flush() 533 | 534 | return True 535 | 536 | def _get_filter(self, filter: DocumentMetadataFilter) -> Optional[str]: 537 | """Converts a DocumentMetdataFilter to the expression that Milvus takes. 538 | 539 | Args: 540 | filter (DocumentMetadataFilter): The Filter to convert to Milvus expression. 
541 | 542 | Returns: 543 | Optional[str]: The filter if valid, otherwise None. 544 | """ 545 | filters = [] 546 | # Go through all the fields and their values 547 | for field, value in filter.dict().items(): 548 | # Check if the Value is empty 549 | if value is not None: 550 | # Convert start_date to int and add greater than or equal logic 551 | if field == "start_date": 552 | filters.append( 553 | "(created_at >= " + str(to_unix_timestamp(value)) + ")" 554 | ) 555 | # Convert end_date to int and add less than or equal logic 556 | elif field == "end_date": 557 | filters.append( 558 | "(created_at <= " + str(to_unix_timestamp(value)) + ")" 559 | ) 560 | # Convert Source to its string value and check equivalency 561 | elif field == "source": 562 | filters.append("(" + field + ' == "' + str(value.value) + '")') 563 | # Check equivalency of rest of string fields 564 | else: 565 | filters.append("(" + field + ' == "' + str(value) + '")') 566 | # Join all our expressions with `and`` 567 | return " and ".join(filters) 568 | --------------------------------------------------------------------------------