├── .gitignore ├── LICENSE ├── firstbatch ├── __init__.py ├── algorithm │ ├── __init__.py │ ├── base.py │ ├── blueprint │ │ ├── __init__.py │ │ ├── base.py │ │ └── library.py │ ├── custom.py │ ├── factory.py │ ├── registry.py │ └── simple.py ├── async_core.py ├── client │ ├── __init__.py │ ├── async_client.py │ ├── base.py │ ├── schema.py │ ├── sync.py │ └── wrapper.py ├── constants.py ├── core.py ├── imports │ └── __init__.py ├── logger_conf.py ├── lossy │ ├── __init__.py │ ├── base.py │ ├── product.py │ └── scalar.py ├── utils.py └── vector_store │ ├── __init__.py │ ├── base.py │ ├── chroma.py │ ├── pinecone.py │ ├── qdrant.py │ ├── schema.py │ ├── supabase.py │ ├── typesense.py │ ├── utils.py │ └── weaviate.py ├── poetry.lock ├── pyproject.toml ├── readme.md └── tests ├── algorithms ├── test_algorithms.py ├── test_algorithms_async.py └── test_algorithms_vs.py ├── compression ├── test_lossy_product.py └── test_lossy_scalar.py ├── parser └── test_parser.py └── vector_store ├── test_chroma.py ├── test_pinecone.py ├── test_qdrant.py ├── test_supabase.py ├── test_typesense.py └── test_weaviate.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./DS_Store 2 | DS_Store 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 FirstBatch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /firstbatch/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.core import FirstBatch 2 | from firstbatch.async_core import AsyncFirstBatch 3 | from firstbatch.algorithm import UserAction, Signal, BatchEnum, AlgorithmLabel 4 | from firstbatch.utils import Config 5 | from firstbatch.vector_store import Pinecone, Weaviate, Chroma, TypeSense, Supabase, Qdrant, DistanceMetric 6 | __all__ = ["FirstBatch", "AsyncFirstBatch", "Pinecone", "Weaviate", "Chroma", "TypeSense", "Supabase", "Qdrant" ,"Config", 7 | "UserAction", "Signal", "BatchEnum", "AlgorithmLabel", "DistanceMetric"] 8 | -------------------------------------------------------------------------------- /firstbatch/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.algorithm.base import BaseAlgorithm 2 | from firstbatch.algorithm.factory import FactoryAlgorithm 3 | from firstbatch.algorithm.simple import SimpleAlgorithm 4 | from firstbatch.algorithm.custom import CustomAlgorithm 5 | from firstbatch.algorithm.registry import AlgorithmLabel, AlgorithmRegistry 6 | from firstbatch.algorithm.blueprint import SignalObject, Blueprint, BatchEnum, BatchType, \ 7 | Params, Signal, SignalType, SessionObject, UserAction, Vertex, Edge, DFAParser 8 | 9 | __all__ = ["BaseAlgorithm", "FactoryAlgorithm", "SimpleAlgorithm", "CustomAlgorithm", "AlgorithmLabel", 10 | "AlgorithmRegistry","Blueprint", "BatchType", "UserAction", "Vertex", "Edge", "DFAParser", 11 | "Params", "BatchEnum", "Signal", "SignalObject", "SignalType", "SessionObject"] -------------------------------------------------------------------------------- /firstbatch/algorithm/base.py: -------------------------------------------------------------------------------- 1 | """BaseAlgorithm""" 2 | from __future__ import annotations 3 | from abc import ABC 4 | from typing import Optional, Any, List 5 | from firstbatch.algorithm.blueprint import Blueprint, UserAction 6 | from firstbatch.vector_store.utils import maximal_marginal_relevance 7 | from firstbatch.vector_store.schema import BatchQueryResult, BatchQuery, QueryMetadata 8 | from firstbatch.algorithm.registry import AlgorithmLabel, AlgorithmRegistry 9 | 10 | 11 | class BaseAlgorithm(ABC): 12 | 13 | batch_size: int 14 | is_custom: bool 15 | name: str 16 | 17 | def __init_subclass__(cls, label: Optional[AlgorithmLabel] = None, **kwargs): 18 | """Automatically registers subclasses with the AlgorithmRegistry.""" 19 | super().__init_subclass__(**kwargs) 20 | if label: 21 | AlgorithmRegistry.register_algorithm(label, cls) 22 | 23 | @property 24 | def _blueprint(self) -> Blueprint: 25 | return self.__blueprint 26 | 27 | @_blueprint.setter 28 | def _blueprint(self, value): 29 | self.__blueprint = value 30 | 31 | def blueprint_step(self, state: str, action: UserAction) -> Any: 32 | """Call the step method of the _blueprint.""" 33 | return self._blueprint.step(state, action) 34 | 35 | def random_batch(self, batch: BatchQueryResult, query: BatchQuery, **kwargs) -> Any: 36 | # We can't use apply threshold with random batches 37 | if "apply_threshold" in kwargs: 38 | del kwargs["apply_threshold"] 39 | if "apply_mmr" in kwargs: 40 | del kwargs["apply_mmr"] 41 | kwargs["shuffle"] = True 42 | ids, metadata = self._apply_params(batch, query, **kwargs) 43 | return ids[:batch.batch_size], metadata[:batch.batch_size] 44 | 45 | def biased_batch(self, 
batch: BatchQueryResult, query: BatchQuery, **kwargs) \ 46 | -> Any: 47 | kwargs["shuffle"] = True 48 | ids, metadata = self._apply_params(batch, query, **kwargs) 49 | return ids[:batch.batch_size], metadata[:batch.batch_size] 50 | 51 | def sampled_batch(self, batch: BatchQueryResult, query: BatchQuery, **kwargs: Any) -> Any: 52 | kwargs["shuffle"] = True 53 | ids, metadata = self._apply_params(batch, query, **kwargs) 54 | return ids[:batch.batch_size], metadata[:batch.batch_size] 55 | 56 | @staticmethod 57 | def _apply_params(batch: BatchQueryResult, query: BatchQuery, **kwargs: Any): 58 | 59 | if len(batch.results) != len(query.queries): 60 | raise ValueError("Number of results is not equal to number of queries!") 61 | 62 | if "apply_threshold" in kwargs: 63 | if isinstance(kwargs["apply_threshold"], list): 64 | if kwargs["apply_threshold"][0]: 65 | for i in range(len(batch.results)): 66 | batch.results[i] = batch.results[i].apply_threshold(kwargs["apply_threshold"][1]) 67 | else: 68 | if kwargs["apply_threshold"] > 0: 69 | for i in range(len(batch.results)): 70 | batch.results[i] = batch.results[i].apply_threshold(kwargs["apply_threshold"]) 71 | 72 | if "apply_mmr" in kwargs: 73 | if kwargs["apply_mmr"]: 74 | i = 0 75 | for q, embeddings in zip(query.queries, batch.results): 76 | if q.embedding is None: 77 | raise ValueError("Embedding cannot be None") 78 | batch.results[i] = maximal_marginal_relevance(q.embedding, embeddings, 0.5, q.top_k_mmr) 79 | i += 1 80 | 81 | if "remove_duplicates" in kwargs: 82 | if kwargs["remove_duplicates"]: 83 | batch.remove_duplicates() 84 | 85 | batch.sort() 86 | 87 | ids: List[str] = [] 88 | metadata: List[QueryMetadata] = [] 89 | 90 | for i, result in enumerate(batch.results): 91 | k = query.queries[i].top_k 92 | if result.ids is not None and result.metadata is not None: 93 | ids += result.ids[:k] 94 | metadata += result.metadata[:k] 95 | else: 96 | raise ValueError("Result ids or metadata is None") 97 | 98 | if "shuffle" in kwargs: 99 | if kwargs["shuffle"]: 100 | import random 101 | c = list(zip(ids, metadata)) 102 | random.shuffle(c) 103 | ids, metadata = (list(x) for x in zip(*c)) 104 | 105 | return ids, metadata 106 | 107 | def _reset(self, *args, **kwargs): 108 | pass 109 | -------------------------------------------------------------------------------- /firstbatch/algorithm/blueprint/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.algorithm.blueprint.base import Blueprint, BatchType, UserAction, Vertex, Edge, DFAParser, BatchEnum, \ 2 | Params, Signal, SignalObject, SignalType, SessionObject 3 | __all__ = ["Blueprint", "BatchType", "UserAction", "Vertex", "Edge", 4 | "DFAParser", "BatchEnum", "Params", "Signal", "SignalObject", "SignalType", "SessionObject"] -------------------------------------------------------------------------------- /firstbatch/algorithm/blueprint/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | from typing import List, Dict, Union, Optional, Any, Tuple 4 | from firstbatch.vector_store.schema import Vector 5 | from dataclasses import dataclass, field 6 | from dataclasses_json import DataClassJsonMixin 7 | from enum import Enum 8 | 9 | 10 | @dataclass 11 | class Params(DataClassJsonMixin): 12 | mu: float = 0.0 13 | alpha: float = 0.0 14 | r: float = 0.0 15 | last_n: int = 0 16 | n_topics: int = 0 17 | remove_duplicates: bool = True 18 | apply_threshold: 
Tuple[bool, float] = (False, 0.0) 19 | apply_mmr: bool = False 20 | 21 | 22 | @dataclass 23 | class SignalType: 24 | label: str 25 | weight: float 26 | 27 | def to_action(self): 28 | return self.label.upper() 29 | 30 | 31 | class Signal: 32 | """Signal class""" 33 | DEFAULT = SignalType(label="default", weight=1.0) 34 | ADD_TO_CART = SignalType("ADD_TO_CART", 16) 35 | ITEM_VIEW = SignalType("ITEM_VIEW", 10) 36 | APPLY = SignalType("APPLY", 18) 37 | PURCHASE = SignalType("PURCHASE", 20) 38 | HIGHLIGHT = SignalType("HIGHLIGHT", 8) 39 | GLANCE_VIEW = SignalType("GLANCE_VIEW", 14) 40 | CAMPAIGN_CLICK = SignalType("CAMPAIGN_CLICK", 6) 41 | CATEGORY_VISIT = SignalType("CATEGORY_VISIT", 10) 42 | SHARE = SignalType("SHARE", 10) 43 | MERCHANT_VIEW = SignalType("MERCHANT_VIEW", 10) 44 | REIMBURSED = SignalType("REIMBURSED", 20) 45 | APPROVED = SignalType("APPROVED", 18) 46 | REJECTED = SignalType("REJECTED", 18) 47 | SHARE_ARTICLE = SignalType("SHARE_ARTICLE", 10) 48 | COMMENT = SignalType("COMMENT", 12) 49 | PERSPECTIVES_SWITCH = SignalType("PERSPECTIVES_SWITCH", 8) 50 | REPOST = SignalType("REPOST", 20) 51 | SUBSCRIBE = SignalType("SUBSCRIBE", 18) 52 | SHARE_PROFILE = SignalType("SHARE_PROFILE", 10) 53 | PAID_SUBSCRIBE = SignalType("PAID_SUBSCRIBE", 20) 54 | SAVE = SignalType("SAVE", 8) 55 | FOLLOW_TOPIC = SignalType("FOLLOW_TOPIC", 10) 56 | WATCH = SignalType("WATCH", 20) 57 | CLICK_LINK = SignalType("CLICK_LINK", 6) 58 | RECOMMEND = SignalType("RECOMMEND", 12) 59 | FOLLOW = SignalType("FOLLOW", 10) 60 | VISIT_PROFILE = SignalType("VISIT_PROFILE", 12) 61 | AUTO_PLAY = SignalType("AUTO_PLAY", 4) 62 | SAVE_ARTICLE = SignalType("SAVE_ARTICLE", 8) 63 | REPLAY = SignalType("REPLAY", 20) 64 | READ = SignalType("READ", 14) 65 | LIKE = SignalType("LIKE", 8) 66 | CLICK_EMAIL_LINK = SignalType("CLICK_EMAIL_LINK", 6) 67 | ADD_TO_LIST = SignalType("ADD_TO_LIST", 12) 68 | FOLLOW_AUTHOR = SignalType("FOLLOW_AUTHOR", 10) 69 | SEARCH = SignalType("SEARCH", 15) 70 | CLICK_AD = SignalType("CLICK_AD", 6.0) 71 | 72 | @staticmethod 73 | def name(value: SignalType) -> str: 74 | return value.label.upper() 75 | 76 | @classmethod 77 | def add_signals(cls, signals: List[Dict[str, Any]]) -> None: 78 | for signal in signals: 79 | signal_type = SignalType(**signal) 80 | setattr(cls, signal['label'].upper(), signal_type) 81 | 82 | @classmethod 83 | def add_new_signals_from_json(cls, json_path: str) -> None: 84 | with open(json_path, "r") as f: 85 | signals = json.load(f) 86 | cls.add_signals(signals) 87 | 88 | @classmethod 89 | def add_new_signals_from_json_string(cls, json_string: str) -> None: 90 | signals = json.loads(json_string) 91 | cls.add_signals(signals) 92 | 93 | @classmethod 94 | def length(cls): 95 | return len([attr for attr in dir(cls) if not callable(getattr(cls, attr))]) 96 | 97 | 98 | @dataclass 99 | class SessionObject(DataClassJsonMixin): 100 | id: str 101 | is_persistent: bool 102 | 103 | 104 | @dataclass 105 | class SignalObject: 106 | action: SignalType 107 | vector: Vector 108 | cid: Optional[str] 109 | timestamp: Optional[int] 110 | 111 | 112 | class BatchEnum(str, Enum): 113 | BATCH = "batch" 114 | 115 | 116 | class BatchType(Enum): 117 | PERSONALIZED = "personalized" 118 | BIASED = "biased" 119 | SAMPLED = "sampled" 120 | RANDOM = "random" 121 | 122 | 123 | class UserAction: 124 | 125 | def __init__(self, action: Union[str, SignalType, BatchEnum]): 126 | 127 | self.action_type: Union[SignalType, BatchEnum] = BatchEnum.BATCH 128 | 129 | if isinstance(action, BatchEnum): 130 | self.action_type = action 
131 | elif isinstance(action, SignalType): 132 | self.action_type = action 133 | else: 134 | self.action_type = self.parse_action(action) 135 | 136 | @staticmethod 137 | def parse_action(action: str) -> Union[SignalType, BatchEnum]: 138 | 139 | if action != "BATCH": 140 | return getattr(Signal, action) 141 | elif action == "BATCH": 142 | return BatchEnum.BATCH 143 | else: 144 | raise ValueError(f"Invalid action: {action}") 145 | 146 | 147 | @dataclass 148 | class Vertex: 149 | name: str 150 | batch_type: BatchType 151 | params: Params 152 | 153 | 154 | @dataclass 155 | class Edge: 156 | name: str 157 | edge_type: UserAction 158 | start: Vertex 159 | end: Vertex 160 | 161 | 162 | @dataclass 163 | class Blueprint: 164 | vertices: List[Vertex] = field(default_factory=list) 165 | edges: List[Edge] = field(default_factory=list) 166 | map: Dict[str, Vertex] = field(default_factory=dict) 167 | 168 | def add_vertex(self, vertex: Vertex): 169 | self.vertices.append(vertex) 170 | self.map[vertex.name] = vertex 171 | 172 | def add_edge(self, edge: Edge): 173 | self.edges.append(edge) 174 | 175 | def get_operation(self, state: str) -> Vertex: 176 | return self.map[state] 177 | 178 | def step(self, state: str, action: UserAction) -> Tuple[Vertex, BatchType, Params]: 179 | """Step function that takes a node name and a UserAction to determine the next vertex.""" 180 | try: 181 | if state == "0": 182 | vertex = self.vertices[0] 183 | else: 184 | vertex = self.map[state] 185 | except KeyError: 186 | raise ValueError(f"No vertex found for state: {state}") 187 | 188 | if not vertex: 189 | print(f"No vertex found with name: {state}") 190 | raise ValueError(f"No vertex found with name: {state}") 191 | 192 | edge = next((e for e in self.edges if e.start == vertex and e.edge_type.action_type == action.action_type), 193 | None) 194 | if edge: 195 | return edge.end, vertex.batch_type, vertex.params 196 | else: 197 | edge = next((e for e in self.edges if e.start == vertex and e.edge_type.action_type == Signal.DEFAULT), 198 | None) 199 | if edge: 200 | return edge.end, vertex.batch_type, vertex.params 201 | else: 202 | raise ValueError("No edge found for given conditions") 203 | 204 | 205 | class DFAParser: 206 | def __init__(self, data: Union[str, dict]): 207 | if isinstance(data, str): 208 | data = json.loads(data) 209 | self.data = data 210 | self.blueprint = Blueprint() 211 | 212 | def __validate_edges(self): 213 | """Validates that each vertex has at least one BatchEnum typed edge and one Signal typed edge. 214 | Also checks that the signals are covered. 
215 | """ 216 | for node in self.blueprint.vertices: 217 | related_edges = [e for e in self.blueprint.edges if e.start == node] 218 | 219 | has_batch = any(isinstance(edge.edge_type.action_type, BatchEnum) for edge in related_edges) 220 | if not has_batch: 221 | raise ValueError(f"Node {node.name} is missing a BatchEnum typed edge") 222 | 223 | action_types = [edge.edge_type.action_type for edge in related_edges if 224 | not isinstance(edge.edge_type.action_type, BatchEnum)] 225 | # Check if the Signal is covered 226 | if Signal.DEFAULT not in action_types and len(action_types) != Signal.length(): 227 | raise ValueError(f"Node {node.name} does not have covered signals") 228 | 229 | def parse(self): 230 | 231 | # Parse Signals 232 | if "signals" in self.data: 233 | Signal.add_signals(self.data["signals"]) 234 | 235 | # Parse vertices 236 | for node_data in self.data["nodes"]: 237 | vertex = Vertex( 238 | name=node_data["name"], 239 | batch_type=BatchType(node_data["batch_type"]), 240 | params=Params(**node_data["params"]) 241 | ) 242 | self.blueprint.add_vertex(vertex) 243 | 244 | # Parse edges 245 | for edge_data in self.data["edges"]: 246 | edge = Edge( 247 | name=edge_data["name"], 248 | edge_type=UserAction(edge_data["edge_type"]), 249 | start=self.blueprint.map[edge_data["start"]], 250 | end=self.blueprint.map[edge_data["end"]] 251 | ) 252 | self.blueprint.add_edge(edge) 253 | 254 | self.__validate_edges() 255 | 256 | return self.blueprint 257 | -------------------------------------------------------------------------------- /firstbatch/algorithm/blueprint/library.py: -------------------------------------------------------------------------------- 1 | """Factory Algorithms for Blueprint""" 2 | # Navigable UX 3 | 4 | # Transform your user experience into a navigable journey by leveraging user interactions. 5 | # Every user action contributes to shaping their unique experience. 6 | 7 | # Target : Anonymous sessions to navigable experiences 8 | 9 | # Algo logic description : Serving varied content until the first signal is received from the user. 10 | 11 | # Then crafting the experience sharply around those signals. Gradually providing more space for exploration if the user 12 | # keeps losing interest in current topics. 13 | 14 | # Potential KPIs : Any engagement metric, Time spent on application or Bounce rate, 15 | # Conversion rate of anonymous sessions.
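# An illustrative sketch, not part of the library: how a factory blueprint such as Navigable_UX
# (defined just below) is turned into a Blueprint and stepped. DFAParser, UserAction, BatchEnum
# and Signal come from firstbatch/algorithm/blueprint/base.py; state "0" resolves to the first
# vertex, and a signal with no explicit edge falls back to the DEFAULT edge.
#
#   from firstbatch.algorithm.blueprint import DFAParser, UserAction, BatchEnum
#
#   bp = DFAParser(Navigable_UX).parse()
#   nxt, batch_type, params = bp.step("0", UserAction(BatchEnum.BATCH))
#   # -> next vertex "Browsing"; this batch is built with Exploration's settings (random, last_n=8)
#   nxt, batch_type, params = bp.step("Browsing", UserAction("LIKE"))
#   # -> "LIKE" has no explicit edge from Browsing, so the DEFAULT edge fires and the next vertex
#   #    is "Hyper_Focus", while this batch uses Browsing's sampled parameters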
16 | 17 | Navigable_UX = '''{ 18 | "nodes": [ 19 | {"name": "Exploration", "batch_type": "random", "params": {"last_n":8}}, 20 | {"name": "Browsing", "batch_type": "sampled", "params": {"n_topics":12,"last_n":8 }}, 21 | {"name": "Discovery", "batch_type": "personalized", "params": {"r" : 0.2, "mu" : 0.5, "alpha" : 0.7, "apply_threshold": 0.3, "apply_mmr" :true}}, 22 | {"name": "Dedicated", "batch_type": "personalized", "params": {"r" : 0.1, "mu" : 0.2, "alpha" : 0.4, "apply_threshold": 0.5,"apply_mmr" :true, "last_n":5}}, 23 | {"name": "Focus", "batch_type": "personalized", "params": {"r" : 0.1, "mu" : 0.05, "alpha" : 0.1, "apply_threshold": 0.6, "apply_mmr" :false, "last_n":4}}, 24 | {"name": "Hyper_Focus", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0, "alpha" : 0, "apply_threshold": 0.7, "apply_mmr" :false, "last_n":2}} 25 | ], 26 | "edges": [ 27 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Exploration", "end": "Hyper_Focus"}, 28 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Browsing", "end": "Hyper_Focus"}, 29 | {"name": "edge3", "edge_type": "DEFAULT", "start": "Discovery", "end": "Hyper_Focus"}, 30 | {"name": "edge4", "edge_type": "DEFAULT", "start": "Dedicated", "end": "Hyper_Focus"}, 31 | {"name": "edge5", "edge_type": "DEFAULT", "start": "Focus", "end": "Hyper_Focus"}, 32 | {"name": "edge6", "edge_type": "DEFAULT", "start": "Hyper_Focus", "end": "Hyper_Focus"}, 33 | {"name": "edge7", "edge_type": "BATCH", "start": "Exploration", "end": "Browsing"}, 34 | {"name": "edge8", "edge_type": "BATCH", "start": "Browsing", "end": "Browsing"}, 35 | {"name": "edge9", "edge_type": "BATCH", "start": "Discovery", "end": "Discovery"}, 36 | {"name": "edge10", "edge_type": "BATCH", "start": "Dedicated", "end": "Discovery"}, 37 | {"name": "edge11", "edge_type": "BATCH", "start": "Focus", "end": "Dedicated"}, 38 | {"name": "edge12", "edge_type": "BATCH", "start": "Hyper_Focus", "end": "Focus"} 39 | ] 40 | }''' 41 | 42 | # Individually Crafted Recommendations 43 | 44 | # Offer users not only similar but also adjacent items in a personalized manner. 45 | 46 | # This approach allows users to discover new and relevant content on their own terms, 47 | # enhancing their exploration and satisfaction. 48 | 49 | # Target : Increase up-sell and help you to improve average order value. 50 | 51 | # Algo logic description : Making highly focused recommendations after first interaction. 52 | 53 | # But enable users to explore more items from a wider perspective to keep users within 54 | # recommendations space until they find something to add their cart. 55 | 56 | # Potential KPIs : Up-sell and cross-sell metrics. Average Order Value. Number of items per order. 
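# An illustrative sketch, not part of the library: the signal edges in these blueprints are
# matched against SignalType entries defined on the Signal class in
# firstbatch/algorithm/blueprint/base.py (e.g. ADD_TO_CART has weight 16, PURCHASE 20).
# Domain-specific signals can be registered at runtime; the "ADD_TO_WISHLIST" label below is a
# hypothetical example, not a built-in signal.
#
#   from firstbatch.algorithm.blueprint import Signal, UserAction
#
#   Signal.add_signals([{"label": "ADD_TO_WISHLIST", "weight": 9}])
#   action = UserAction("ADD_TO_WISHLIST")  # resolves to the newly registered SignalType
#
# DFAParser.parse() performs the same registration automatically when a blueprint JSON carries
# a top-level "signals" list.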
57 | 58 | 59 | Individually_Crafted_Recommendations = '''{ 60 | "nodes": [ 61 | {"name": "Recommendation", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.0, "alpha" : 0, "apply_threshold": 0.7, "apply_mmr" :false, "last_n":1}}, 62 | {"name": "Expansion", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.05, "alpha" : 0.4, "apply_threshold": 0.7, "apply_mmr" :true, "last_n":2}}, 63 | {"name": "Discovery", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.1, "alpha" : 1, "apply_threshold": 0.6, "apply_mmr" :true, "last_n":4}} 64 | ], 65 | "edges": [ 66 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Discovery", "end": "Recommendation"}, 67 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Expansion", "end": "Recommendation"}, 68 | {"name": "edge3", "edge_type": "DEFAULT", "start": "Recommendation", "end": "Recommendation"}, 69 | {"name": "edge4", "edge_type": "BATCH", "start": "Recommendation", "end": "Expansion"}, 70 | {"name": "edge5", "edge_type": "BATCH", "start": "Expansion", "end": "Discovery"}, 71 | {"name": "edge6", "edge_type": "BATCH", "start": "Discovery", "end": "Discovery"} 72 | ] 73 | }''' 74 | 75 | # Unique Journeys 76 | 77 | # Enable users to access the right content from the very beginning by 78 | # tailoring their experience based on their starting point. 79 | 80 | # Target : Shape the user journey from the very beginning. Best suited for welcome experiences based on 81 | # traffic source or seasonal campaigns, and for recurring-visitor experiences. 82 | 83 | # Algo logic description : Providing focused content starting from the first 84 | # load by utilizing user embeddings from previous sessions or adding a seasonal effect to the experience. 85 | 86 | # For example, adding the summer collection as a bias during summer. Then letting users navigate themselves 87 | # just as we do in the Navigable_UX algorithm. 88 | 89 | # Potential KPIs : Any engagement metric, Time spent before first interaction, 90 | # conversion rate of recurring visitors, Time spent on application or Bounce rate.
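# An illustrative sketch, not part of the library: this blueprint is exposed as
# AlgorithmLabel.UNIQUE_JOURNEYS (see firstbatch/algorithm/registry.py) and resolved by
# FactoryAlgorithm through the lookup table at the bottom of this file. Because the "Welcome"
# vertex below is a *biased* batch, the first batch() call must be given bias vectors and
# weights (for example, user embeddings from a previous session), otherwise async_core.py
# raises "no bias vectors provided". Here fb is assumed to be a configured AsyncFirstBatch
# instance, "my_index" is a placeholder vector DB id, and the vector/weight values and the
# 1536 embedding size are placeholders as well.
#
#   session = await fb.session(AlgorithmLabel.UNIQUE_JOURNEYS, vdbid="my_index")
#   ids, items = await fb.batch(session,
#                               bias_vectors=[[0.01] * 1536],
#                               bias_weights=[1.0])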
91 | 92 | Unique_Journeys = '''{ 93 | "nodes": [ 94 | {"name": "Welcome", "batch_type": "biased", "params": {"r" : 0.2, "mu" : 0.2, "alpha" : 0.4, "apply_threshold": 0.7, "apply_mmr" :false, "last_n":5}}, 95 | {"name": "Exploration", "batch_type": "personalized", "params": {"r" : 0.3, "mu" : 0.6, "alpha" : 0.7, "apply_threshold": 0.3, "apply_mmr" :true, "last_n":8}}, 96 | {"name": "Discovery", "batch_type": "personalized", "params": {"r" : 0.2, "mu" : 0.4, "alpha" : 0.5, "apply_threshold": 0.3, "apply_mmr" :true, "last_n":6}}, 97 | {"name": "Dedicated", "batch_type": "personalized", "params": {"r" : 0.1, "mu" : 0.2, "alpha" : 0.4, "apply_threshold": 0.5,"apply_mmr" :true, "last_n":5}}, 98 | {"name": "Focus", "batch_type": "personalized", "params": {"r" : 0.1, "mu" : 0.05, "alpha" : 0.1, "apply_threshold": 0.6, "apply_mmr" :false, "last_n":4}}, 99 | {"name": "Hyper_Focus", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0, "alpha" : 0, "apply_threshold": 0.7, "apply_mmr" :false, "last_n":2}} 100 | ], 101 | "edges": [ 102 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Welcome", "end": "Hyper_Focus"}, 103 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Exploration", "end": "Hyper_Focus"}, 104 | {"name": "edge3", "edge_type": "DEFAULT", "start": "Discovery", "end": "Hyper_Focus"}, 105 | {"name": "edge4", "edge_type": "DEFAULT", "start": "Dedicated", "end": "Hyper_Focus"}, 106 | {"name": "edge5", "edge_type": "DEFAULT", "start": "Focus", "end": "Hyper_Focus"}, 107 | {"name": "edge6", "edge_type": "DEFAULT", "start": "Hyper_Focus", "end": "Hyper_Focus"}, 108 | {"name": "edge7", "edge_type": "BATCH", "start": "Welcome", "end": "Exploration"}, 109 | {"name": "edge8", "edge_type": "BATCH", "start": "Exploration", "end": "Welcome"}, 110 | {"name": "edge9", "edge_type": "BATCH", "start": "Discovery", "end": "Exploration"}, 111 | {"name": "edge10", "edge_type": "BATCH", "start": "Dedicated", "end": "Discovery"}, 112 | {"name": "edge11", "edge_type": "BATCH", "start": "Focus", "end": "Dedicated"}, 113 | {"name": "edge12", "edge_type": "BATCH", "start": "Hyper_Focus", "end": "Focus"} 114 | ] 115 | }''' 116 | 117 | 118 | # Not User Targeting but User-Centric Promoted Content Curations 119 | 120 | # Shift away from conventional targeting techniques and embrace a user-centric approach to deliver promoted 121 | # items or ads in a captivating format. 122 | 123 | # This approach allows users to actively influence the curation of promoted content, 124 | # ensuring it aligns seamlessly with their preferences and resulting in a highly interactive and enjoyable experience. 125 | 126 | # Target : Don't force users to see irrelevant promoted content, but provide an engaging campaign discovery. 127 | 128 | # People ignore ads because targeting only pollutes feeds. Therefore the aim is improving campaign 129 | # CTR by providing the right content to the right users at the right time. 130 | 131 | # Algo logic description : Promoting content in a hyper-personalized manner by keeping curation 132 | # focused from the first interaction onward.
133 | 134 | # Potential KPIs : CTR 135 | 136 | User_Centric_Promoted_Content_Curations = '''{ 137 | "nodes": [ 138 | {"name": "Exploration", "batch_type": "sampled", "params": {"n_topics":8,"last_n": 3}}, 139 | {"name": "Curated", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.1, "alpha" : 0.3, "apply_threshold": 0.6, "apply_mmr" :false, "last_n":1}} 140 | ], 141 | "edges": [ 142 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Exploration", "end": "Curated"}, 143 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Curated", "end": "Curated"}, 144 | {"name": "edge3", "edge_type": "BATCH", "start": "Exploration", "end": "Exploration"}, 145 | {"name": "edge4", "edge_type": "BATCH", "start": "Curated", "end": "Curated"} 146 | ] 147 | }''' 148 | 149 | # User-Intent AI Agents 150 | # Empower your AI agents with real-time insights into user intentions, derived from their interactions. 151 | 152 | # This infusion of user intent brings intimacy to AI-driven experiences, 153 | # making users feel more connected and understood. 154 | 155 | # Target : Serving personal AI assistance that reflects user interactions and is not restricted to prompts. 156 | 157 | # Algo logic description : Not giving space for false navigation and keeping 158 | # the AI agent as close as possible to user intentions. 159 | 160 | # Because people are quickly demotivated by hallucinated conversations with AI. 161 | 162 | # Potential KPIs : Time spent with AI agents, Chat Rating, Conversion rate through AI agents. 163 | 164 | User_Intent_AI_Agents = '''{ 165 | "nodes": [ 166 | {"name": "Welcome", "batch_type": "biased", "params": {"r" : 0, "mu" : 0.1, "alpha" : 0.4, "apply_threshold": 0.8, "apply_mmr" :false, "last_n":12}}, 167 | {"name": "Expansion", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.2, "alpha" : 0.6, "apply_threshold": 0.6, "apply_mmr" :true, "last_n":12}}, 168 | {"name": "Exploration", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0.5, "alpha" : 0.6, "apply_threshold": 0.5,"apply_mmr" :true, "last_n":12}}, 169 | {"name": "Focus", "batch_type": "personalized", "params": {"r" : 0, "mu" : 0, "alpha" : 0, "apply_threshold": 0.8, "apply_mmr" :false, "last_n":6}} 170 | ], 171 | "edges": [ 172 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Welcome", "end": "Focus"}, 173 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Expansion", "end": "Focus"}, 174 | {"name": "edge3", "edge_type": "DEFAULT", "start": "Exploration", "end": "Focus"}, 175 | {"name": "edge4", "edge_type": "DEFAULT", "start": "Focus", "end": "Focus"}, 176 | {"name": "edge5", "edge_type": "BATCH", "start": "Welcome", "end": "Welcome"}, 177 | {"name": "edge6", "edge_type": "BATCH", "start": "Expansion", "end": "Exploration"}, 178 | {"name": "edge7", "edge_type": "BATCH", "start": "Exploration", "end": "Exploration"}, 179 | {"name": "edge8", "edge_type": "BATCH", "start": "Focus", "end": "Expansion"} 180 | ] 181 | }''' 182 | 183 | lookup = { 184 | "Unique_Journeys".upper(): Unique_Journeys, 185 | "User_Centric_Promoted_Content_Curations".upper(): User_Centric_Promoted_Content_Curations, 186 | "User_Intent_AI_Agents".upper(): User_Intent_AI_Agents, 187 | "Individually_Crafted_Recommendations".upper(): Individually_Crafted_Recommendations, 188 | "Navigable_UX".upper(): Navigable_UX 189 | } -------------------------------------------------------------------------------- /firstbatch/algorithm/custom.py:
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from firstbatch.algorithm.base import BaseAlgorithm 3 | from firstbatch.algorithm.registry import AlgorithmLabel 4 | from firstbatch.algorithm.blueprint import DFAParser 5 | from firstbatch.constants import DEFAULT_EMBEDDING_SIZE 6 | 7 | 8 | class CustomAlgorithm(BaseAlgorithm, label=AlgorithmLabel.CUSTOM): 9 | is_custom: bool = True 10 | name: str = "CUSTOM" 11 | 12 | def __init__(self, bp, batch_size: int, **kwargs): 13 | parser = DFAParser(bp) 14 | blueprint = parser.parse() 15 | self._blueprint = blueprint 16 | self.embedding_size = DEFAULT_EMBEDDING_SIZE 17 | self.batch_size = batch_size 18 | self._include_values = True 19 | if "embedding_size" in kwargs: 20 | self.embedding_size = kwargs["embedding_size"] 21 | if "include_values" in kwargs: 22 | self._include_values = kwargs["include_values"] 23 | -------------------------------------------------------------------------------- /firstbatch/algorithm/factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from firstbatch.algorithm.base import BaseAlgorithm 3 | from firstbatch.algorithm.registry import AlgorithmLabel 4 | from firstbatch.algorithm.blueprint import DFAParser 5 | from firstbatch.algorithm.blueprint.library import lookup 6 | from firstbatch.constants import DEFAULT_BATCH_SIZE, DEFAULT_EMBEDDING_SIZE 7 | 8 | 9 | class FactoryAlgorithm(BaseAlgorithm, label=AlgorithmLabel.RECOMMENDATIONS): 10 | batch_size: int = DEFAULT_BATCH_SIZE 11 | is_custom: bool = False 12 | name: str = "FACTORY" 13 | 14 | def __init__(self, label, batch_size: int, **kwargs): 15 | parser = DFAParser(lookup[label]) 16 | blueprint = parser.parse() 17 | self._blueprint = blueprint 18 | self.embedding_size = DEFAULT_EMBEDDING_SIZE 19 | self.batch_size = batch_size 20 | self._include_values = True 21 | if "embedding_size" in kwargs: 22 | self.embedding_size = kwargs["embedding_size"] 23 | if "include_values" in kwargs: 24 | self._include_values = kwargs["include_values"] 25 | -------------------------------------------------------------------------------- /firstbatch/algorithm/registry.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from enum import Enum 3 | from typing import Type, Dict, Union 4 | # from firstbatch.algorithm.base import BaseAlgorithm 5 | 6 | 7 | class AlgorithmLabel(str, Enum): 8 | SIMPLE = "SIMPLE" 9 | CUSTOM = "CUSTOM" 10 | UNIQUE_JOURNEYS = "Unique_Journeys".upper() 11 | CONTENT_CURATION = "User_Centric_Promoted_Content_Curations".upper() 12 | AI_AGENTS = "User_Intent_AI_Agents".upper() 13 | RECOMMENDATIONS = "Individually_Crafted_Recommendations".upper() 14 | NAVIGATION = "Navigable_UX".upper() 15 | 16 | 17 | class AlgorithmRegistry: 18 | _registry: Dict[str, Type["BaseAlgorithm"]] = {} # type: ignore 19 | 20 | @classmethod 21 | def register_algorithm(cls, label: Union[AlgorithmLabel, str], algo_class: Type["BaseAlgorithm"]) -> None: 22 | if isinstance(label, str): 23 | if label not in [AlgorithmLabel.SIMPLE, AlgorithmLabel.CUSTOM]: 24 | cls._registry["FACTORY"] = algo_class 25 | return 26 | label = AlgorithmLabel(label) 27 | cls._registry[label.name] = algo_class 28 | 29 | @classmethod 30 | def get_algorithm_by_label(cls, label: Union[AlgorithmLabel, str]) -> Type["BaseAlgorithm"]: 31 | """Retrieve a registered algorithm class by its label.""" 32 | if 
isinstance(label, AlgorithmLabel): 33 | label = label.name 34 | return cls._registry[label] 35 | -------------------------------------------------------------------------------- /firstbatch/algorithm/simple.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from firstbatch.algorithm.base import BaseAlgorithm 3 | from firstbatch.algorithm.registry import AlgorithmLabel 4 | from firstbatch.algorithm.blueprint import DFAParser 5 | from firstbatch.algorithm.blueprint.library import User_Centric_Promoted_Content_Curations as Simple 6 | from firstbatch.constants import DEFAULT_BATCH_SIZE, DEFAULT_EMBEDDING_SIZE 7 | 8 | 9 | class SimpleAlgorithm(BaseAlgorithm, label=AlgorithmLabel.SIMPLE): 10 | batch_size: int = DEFAULT_BATCH_SIZE 11 | is_custom: bool = False 12 | name: str = "SIMPLE" 13 | 14 | def __init__(self, batch_size: int, **kwargs): 15 | parser = DFAParser(Simple) 16 | blueprint = parser.parse() 17 | self._blueprint = blueprint 18 | self.embedding_size = DEFAULT_EMBEDDING_SIZE 19 | self.batch_size = batch_size 20 | self._include_values = True 21 | if "embedding_size" in kwargs: 22 | self.embedding_size = kwargs["embedding_size"] 23 | if "include_values" in kwargs: 24 | self._include_values = kwargs["include_values"] 25 | -------------------------------------------------------------------------------- /firstbatch/async_core.py: -------------------------------------------------------------------------------- 1 | """MasterClass for composition""" 2 | from __future__ import annotations 3 | from typing import Optional 4 | import time 5 | from collections import defaultdict 6 | import logging 7 | from firstbatch.algorithm import AlgorithmLabel, UserAction, BatchType, BatchEnum, \ 8 | BaseAlgorithm, AlgorithmRegistry, SignalType, SessionObject 9 | from firstbatch.client.schema import GetHistoryResponse 10 | from firstbatch.vector_store import ( 11 | VectorStore, 12 | FetchQuery, 13 | Query, Vector, 14 | MetadataFilter, 15 | adjust_weights, 16 | generate_batch 17 | ) 18 | from firstbatch.lossy import ScalarQuantizer 19 | from firstbatch.client import ( 20 | BatchQuery, 21 | BatchResponse, 22 | SignalObject, 23 | session_request, 24 | signal_request, 25 | history_request, 26 | random_batch_request, 27 | biased_batch_request, 28 | sampled_batch_request, 29 | update_state_request 30 | ) 31 | 32 | from firstbatch.constants import ( 33 | DEFAULT_VERBOSE, 34 | DEFAULT_HISTORY, 35 | DEFAULT_BATCH_SIZE, 36 | DEFAULT_QUANTIZER_TRAIN_SIZE, 37 | DEFAULT_QUANTIZER_TYPE, 38 | DEFAULT_TOPK_QUANT, 39 | DEFAULT_CONFIDENCE_INTERVAL_RATIO, 40 | MINIMUM_TOPK 41 | ) 42 | from firstbatch.utils import Config 43 | from firstbatch.client.async_client import AsyncFirstBatchClient 44 | from firstbatch.logger_conf import setup_logger 45 | 46 | 47 | class AsyncFirstBatch(AsyncFirstBatchClient): 48 | 49 | def __init__(self, api_key: str, config: Config): 50 | """ 51 | Initialize the FirstBatch class 52 | :param api_key: 53 | :param config: 54 | """ 55 | super().__init__(api_key) 56 | self.store: defaultdict[str, VectorStore] = defaultdict(VectorStore) 57 | self._batch_size = DEFAULT_BATCH_SIZE 58 | self._quantizer_train_size = DEFAULT_QUANTIZER_TRAIN_SIZE 59 | self._quantizer_type = DEFAULT_QUANTIZER_TYPE 60 | self._enable_history = DEFAULT_HISTORY 61 | self._verbose = DEFAULT_VERBOSE 62 | self.logger = setup_logger() 63 | self.logger.setLevel(logging.WARN) 64 | self._set_info() 65 | 66 | if config.verbose is not None: 67 | if config.verbose: 68 | 
self._verbose = config.verbose 69 | self.logger.setLevel(logging.INFO) 70 | else: 71 | self.logger.setLevel(logging.WARN) 72 | if config.batch_size is not None: 73 | self._batch_size = config.batch_size 74 | if config.quantizer_train_size is not None: 75 | self._quantizer_train_size = config.quantizer_train_size 76 | if config.quantizer_type is not None: 77 | self.logger.info("Product type quantizer is not supported yet.") 78 | # self._quantizer_type = kwargs["quantizer_type"] 79 | self._quantizer_type = "scalar" 80 | if config.enable_history is not None: 81 | self._enable_history = config.enable_history 82 | 83 | self.logger.info("Set mode to verbose") 84 | self.logger.info("Using: {}".format(self.url)) 85 | 86 | async def add_vdb(self, vdbid: str, vs: VectorStore): 87 | 88 | exists = await self._vdb_exists(vdbid) 89 | 90 | if not exists: 91 | if self._quantizer_type == "scalar": 92 | self.logger.info("VectorDB with id {} not found. Sketching new VectorDB".format(vdbid)) 93 | vs.quantizer = ScalarQuantizer(256) 94 | ts = min(int(self._quantizer_train_size/DEFAULT_TOPK_QUANT), 500) 95 | batch = generate_batch(ts, vs.embedding_size, top_k=DEFAULT_TOPK_QUANT, include_values=True) 96 | 97 | results = await vs.a_multi_search(batch) 98 | vs.train_quantizer(results.vectors()) 99 | 100 | quantized_vectors = [vs.quantize_vector(vector).vector for vector in results.vectors()] 101 | # This might 1-2 minutes with scalar quantizer 102 | await self._init_vectordb_scalar(self.api_key, vdbid, quantized_vectors, vs.quantizer.quantiles) 103 | 104 | self.store[vdbid] = vs 105 | elif self._quantizer_type == "product": 106 | raise NotImplementedError("Product quantizer not supported yet") 107 | 108 | else: 109 | raise ValueError(f"Invalid quantizer type: {self._quantizer_type}") 110 | else: 111 | self.store[vdbid] = vs 112 | 113 | async def user_embeddings(self, session: SessionObject): 114 | return await self._get_user_embeddings(session) 115 | 116 | async def session(self, algorithm: AlgorithmLabel, vdbid: str, session_id: Optional[str] = None, 117 | custom_id: Optional[str] = None): 118 | 119 | if session_id is None: 120 | if algorithm == AlgorithmLabel.SIMPLE: 121 | req = session_request(**{"algorithm": algorithm.value, "vdbid": vdbid}) 122 | 123 | elif algorithm == AlgorithmLabel.CUSTOM: 124 | req = session_request(**{"algorithm": algorithm.value, "vdbid": vdbid, "custom_id": custom_id}) 125 | 126 | else: 127 | req = session_request(**{"algorithm": "FACTORY", "vdbid": vdbid, "factory_id": algorithm.value}) 128 | else: 129 | if algorithm == AlgorithmLabel.SIMPLE: 130 | req = session_request(**{"id": session_id, "algorithm": algorithm.value, "vdbid": vdbid}) 131 | 132 | elif algorithm == AlgorithmLabel.CUSTOM: 133 | req = session_request(**{"id": session_id, "algorithm": algorithm.value, "vdbid": vdbid, 134 | "custom_id": custom_id}) 135 | 136 | else: 137 | req = session_request(**{"id": session_id, "algorithm": "FACTORY", "vdbid": vdbid, 138 | "factory_id": algorithm.value}) 139 | 140 | session = await self._create_session(req) 141 | return SessionObject(id=session.data, is_persistent=(session_id is not None)) 142 | 143 | async def add_signal(self, session: SessionObject, user_action: UserAction, cid: str): 144 | response = await self._get_session(session) 145 | vs = self.store[response.vdbid] 146 | 147 | if not isinstance(user_action.action_type, SignalType): 148 | raise ValueError(f"Invalid action type: {user_action.action_type}") 149 | 150 | fetch = FetchQuery(id=cid) 151 | result = vs.fetch(fetch) 152 
| 153 | algo_instance = self.__get_algorithm(vs.embedding_size, self._batch_size, response.algorithm, 154 | response.factory_id, response.custom_id) 155 | 156 | # Create a signal object based on user action and content id 157 | signal_obj = SignalObject(vector=result.vector, action=user_action.action_type, cid=cid, timestamp=int(time.time())) 158 | 159 | # Call blueprint_step to calculate the next state 160 | (next_state, batch_type, params) = algo_instance.blueprint_step(response.state, user_action) 161 | # Send signal 162 | resp = await self._signal(signal_request(session, next_state.name, signal_obj)) 163 | 164 | if self._enable_history: 165 | await self._add_history(history_request(session, [result.metadata.data[vs.history_field]])) 166 | 167 | return resp.success 168 | 169 | async def batch(self, session: SessionObject, batch_size: Optional[int] = None, **kwargs): 170 | response = await self._get_session(session) 171 | vs = self.store[response.vdbid] 172 | 173 | self.logger.info("Session: {} {} {}".format(response.algorithm, response.factory_id, response.state)) 174 | if batch_size is None: 175 | batch_size = self._batch_size 176 | 177 | algo_instance = self.__get_algorithm(vs.embedding_size, batch_size, response.algorithm, response.factory_id, response.custom_id) 178 | user_action = UserAction(BatchEnum.BATCH) 179 | 180 | (next_state, batch_type, params) = algo_instance.blueprint_step(response.state, user_action) 181 | 182 | self.logger.info("{} {}".format(batch_type, params.to_dict())) 183 | 184 | history = self._mock_history() 185 | if self._enable_history: 186 | history = await self._get_history(session) 187 | 188 | if batch_type == BatchType.RANDOM: 189 | query = random_batch_request(algo_instance.batch_size, vs.embedding_size, **params.to_dict()) 190 | await self._update_state(update_state_request(session, next_state.name, "RANDOM")) 191 | batch_response = await vs.a_multi_search(query) 192 | ids, batch = algo_instance.random_batch(batch_response, query, **params.to_dict()) 193 | 194 | elif batch_type == BatchType.PERSONALIZED or batch_type == BatchType.BIASED: 195 | 196 | if batch_type == BatchType.BIASED and not ("bias_vectors" in kwargs and "bias_weights" in kwargs): 197 | self.logger.info("Bias vectors and weights must be provided for biased batch.") 198 | raise ValueError("no bias vectors provided") 199 | 200 | if batch_type == BatchType.PERSONALIZED and ("bias_vectors" in kwargs and "bias_weights" in kwargs): 201 | del kwargs["bias_vectors"] 202 | del kwargs["bias_weights"] 203 | 204 | if not response.has_embeddings and batch_type == BatchType.PERSONALIZED: 205 | self.logger.info("No embeddings found for personalized batch. 
Switching to random batch.") 206 | query = random_batch_request(algo_instance.batch_size, vs.embedding_size, **{"apply_mmr": True}) 207 | await self._update_state(update_state_request(session, next_state.name, batch_type="PERSONALIZED")) 208 | batch_response = await vs.a_multi_search(query) 209 | ids, batch = algo_instance.random_batch(batch_response, query, **params.to_dict()) 210 | 211 | else: 212 | batch_response_ = await self._biased_batch(biased_batch_request(session, response.vdbid, next_state.name, 213 | params.to_dict(), **kwargs)) 214 | query = self.__query_wrapper(response.vdbid, algo_instance.batch_size, batch_response_, history, **params.to_dict()) 215 | batch = await vs.a_multi_search(query) 216 | ids, batch = algo_instance.biased_batch(batch, query, **params.to_dict()) 217 | 218 | elif batch_type == BatchType.SAMPLED: 219 | batch_response_ = await self._sampled_batch(sampled_batch_request(session=session, vdbid=response.vdbid, 220 | state=next_state.name, n_topics=params.n_topics)) 221 | query = self.__query_wrapper(response.vdbid, algo_instance.batch_size, batch_response_, history, **params.to_dict()) 222 | batch = await vs.a_multi_search(query) 223 | ids, batch = algo_instance.sampled_batch(batch, query, **params.to_dict()) 224 | 225 | else: 226 | raise ValueError(f"Invalid batch type: {next_state.batch_type}") 227 | 228 | if self._enable_history: 229 | content = [b.data[vs.history_field] for b in batch[:algo_instance.batch_size]] 230 | await self._add_history(history_request(session, content)) 231 | 232 | return ids, batch 233 | 234 | def __query_wrapper(self, vdbid: str, batch_size: int, response: BatchResponse, 235 | history: Optional[GetHistoryResponse], **kwargs): 236 | """ 237 | Wrapper for the query method. It applies the parameters from the blueprint 238 | :param vdbid: VectorDB ID, str 239 | :param batch_size: Batch size, int 240 | :param response: response from the API, BatchResponse 241 | :param history: list of content ids, List[str] 242 | :param kwargs: 243 | :return: 244 | """ 245 | 246 | topks = adjust_weights(response.weights, batch_size, max((batch_size * DEFAULT_CONFIDENCE_INTERVAL_RATIO), 1)) 247 | m_filter = MetadataFilter(name="", filter={}) 248 | # We need the vector values to apply MMR or threshold 249 | include_values = ("apply_mmr" in kwargs or "apply_threshold" in kwargs) 250 | apply_mmr = "apply_mmr" in kwargs and (kwargs["apply_mmr"] is True or kwargs["apply_mmr"] == 1) 251 | 252 | if self._enable_history: 253 | if history is None: 254 | self.logger.info("History is None, No filter will be applied.") 255 | history = GetHistoryResponse(ids=[]) 256 | 257 | if "filter" in kwargs: 258 | m_filter = self.store[vdbid].history_filter(history.ids, kwargs["filter"]) 259 | else: 260 | m_filter = self.store[vdbid].history_filter(history.ids) 261 | 262 | if apply_mmr: 263 | # increase top_k for MMR to work better 264 | qs = [Query(Vector(vec, len(vec), ""), max(topks[i], MINIMUM_TOPK) * 2, filter=m_filter, include_values=include_values) for i, vec 265 | in 266 | enumerate(response.vectors)] 267 | else: 268 | qs = [Query(Vector(vec, len(vec), ""), max(topks[i], MINIMUM_TOPK), filter=m_filter, include_values=include_values) for i, vec in 269 | enumerate(response.vectors)] 270 | return BatchQuery(qs, batch_size) 271 | 272 | def __get_algorithm(self, embedding_size: int, batch_size: int, algorithm: str, factory_id: Optional[str] = None, 273 | custom_id: Optional[str] = None) -> BaseAlgorithm: 274 | 275 | if algorithm == "SIMPLE": 276 | algo_type = 
AlgorithmRegistry.get_algorithm_by_label(algorithm) 277 | algo_instance: BaseAlgorithm = algo_type(batch_size, **{"embedding_size": embedding_size}) 278 | elif algorithm == "CUSTOM": 279 | if custom_id is None: 280 | raise ValueError("Custom algorithm id is None") 281 | bp = self._get_blueprint(custom_id) 282 | algo_type = AlgorithmRegistry.get_algorithm_by_label(algorithm) 283 | algo_instance: BaseAlgorithm = algo_type(bp, batch_size, **{"embedding_size": embedding_size}) # type: ignore 284 | elif algorithm == "FACTORY": 285 | algo_type = AlgorithmRegistry.get_algorithm_by_label(algorithm) 286 | algo_instance: BaseAlgorithm = algo_type(factory_id, batch_size, **{"embedding_size": embedding_size}) # type: ignore 287 | else: 288 | raise ValueError(f"Invalid algorithm: {algorithm}") 289 | return algo_instance 290 | 291 | 292 | -------------------------------------------------------------------------------- /firstbatch/client/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.client.sync import FirstBatchClient 2 | from firstbatch.client.async_client import AsyncFirstBatchClient 3 | from firstbatch.client.schema import Session, SampledBatchRequest, BiasedBatchRequest, SignalRequest, InitRequest,\ 4 | CreateSessionRequest, AddHistoryRequest, UpdateStateRequest, BatchResponse 5 | from firstbatch.client.wrapper import history_request, session_request, update_state_request, signal_request, \ 6 | random_batch_request, biased_batch_request, sampled_batch_request, BatchQuery, SignalObject 7 | 8 | __all__ = ["FirstBatchClient","AsyncFirstBatchClient", "Session", "SampledBatchRequest", "BiasedBatchRequest", 9 | "SignalRequest", "InitRequest", "CreateSessionRequest", "AddHistoryRequest", "UpdateStateRequest", 10 | "session_request", "signal_request", "history_request", "random_batch_request", "biased_batch_request", 11 | "sampled_batch_request", "update_state_request", "BatchQuery", "SignalObject", "BatchResponse"] 12 | -------------------------------------------------------------------------------- /firstbatch/client/async_client.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | from typing import Dict, List, Union, Any, Optional, cast 3 | from pydantic import ValidationError 4 | import hashlib 5 | from firstbatch.constants import regions, REGION_URL 6 | from firstbatch.client.schema import (APIResponse, BatchResponse, GetSessionResponse, GetHistoryResponse, 7 | AddHistoryRequest, CreateSessionRequest, SignalRequest, BiasedBatchRequest, 8 | SampledBatchRequest, UpdateStateRequest, FirstBatchAPIError) 9 | from firstbatch.client.base import BaseClient 10 | from firstbatch.algorithm.blueprint.base import SessionObject 11 | 12 | 13 | class AsyncFirstBatchClient(BaseClient): 14 | def __init__(self, api_key: str, **kwargs): 15 | self.api_key = api_key 16 | self.url = "" 17 | self.region = "" 18 | self.team_id = "" 19 | self.headers = { 20 | "x-api-key": self.api_key, 21 | "Content-Type": "application/json" 22 | } 23 | 24 | async def _post_request(self, url: str, data: Dict) -> Dict: 25 | async with httpx.AsyncClient(timeout=10) as client: 26 | response = await client.post(url, headers=self.headers, json=data) 27 | return response.json() 28 | 29 | @staticmethod 30 | def __error_handling(response: Dict, func_name: str) -> None: 31 | if response["success"] is False: 32 | raise FirstBatchAPIError(f"FirstBatch API error with code {response['status_code']} " 33 | f"in {func_name} with reason: 
{response['message']}") 34 | 35 | def __id_wrapper(self, id: Union[str, None]): 36 | if isinstance(id, str): 37 | return self.team_id + "-" + id 38 | return id 39 | 40 | @staticmethod 41 | def __session_wrapper(session: SessionObject): 42 | return session.id 43 | 44 | # For example: 45 | async def _init_vectordb_scalar(self, key: str, vdbid: str, vecs: List[List[int]], quantiles: List[float]) -> Any: 46 | m = hashlib.md5() 47 | m.update(key.encode()) 48 | hash_value = m.hexdigest() 49 | data = { 50 | "key": hash_value, 51 | "vdbid": vdbid, 52 | "mode": "scalar", 53 | "region": self.region, 54 | "quantized_vecs": vecs, 55 | "quantiles": quantiles 56 | } 57 | response = await self._post_request(self.url + "embeddings/init_vdb", data) 58 | self.__error_handling(response, "init_vectordb_scalar") 59 | try: 60 | return APIResponse(**response) 61 | except ValidationError as e: 62 | raise e 63 | 64 | async def _add_history(self, req: AddHistoryRequest) -> Any: 65 | data = { 66 | "id": self.__session_wrapper(req.session), 67 | "ids": req.ids 68 | } 69 | response = await self._post_request(self.url + "embeddings/update_history", data) 70 | self.__error_handling(response, "add_history") 71 | try: 72 | return APIResponse(**response) 73 | except ValidationError as e: 74 | raise e 75 | 76 | async def _create_session(self, req: CreateSessionRequest) -> APIResponse: 77 | data = {"id": self.__id_wrapper(req.id), "algorithm": req.algorithm, "vdbid": req.vdbid, 78 | "custom_id": req.custom_id, "factory_id": req.factory_id, "has_embeddings": req.has_embeddings} 79 | response = await self._post_request(self.url + "embeddings/create_session", data) 80 | self.__error_handling(response, "create_session") 81 | try: 82 | return APIResponse(**response) 83 | except ValidationError as e: 84 | raise e 85 | 86 | async def _update_state(self, req: UpdateStateRequest) -> APIResponse: 87 | data = {"id": self.__session_wrapper(req.session), "state": req.state, "batch_type": req.batch_type} 88 | response = await self._post_request(self.url + "embeddings/update_state", data) 89 | self.__error_handling(response, "update_state") 90 | try: 91 | return APIResponse(**response) 92 | except ValidationError as e: 93 | raise e 94 | 95 | async def _signal(self, req: SignalRequest) -> APIResponse: 96 | data = { 97 | "id": self.__session_wrapper(req.session), 98 | "vector": req.vector, 99 | "signal": req.signal, 100 | "state": req.state, 101 | "signal_label": req.signal_label 102 | } 103 | response = await self._post_request(self.url + "embeddings/signal", data) 104 | self.__error_handling(response, "signal") 105 | try: 106 | return APIResponse(**response) 107 | except ValidationError as e: 108 | raise e 109 | 110 | async def _biased_batch(self, req: BiasedBatchRequest) -> BatchResponse: 111 | data = { 112 | "id": self.__session_wrapper(req.session), 113 | "vdbid": req.vdbid, 114 | "bias_vectors": req.bias_vectors, 115 | "bias_weights": req.bias_weights, 116 | "params": req.params, 117 | "state": req.state 118 | } 119 | response = await self._post_request(self.url + "embeddings/biased_batch", data) 120 | self.__error_handling(response, "biased_batch") 121 | try: 122 | api_response = APIResponse(**response).data 123 | if isinstance(api_response, dict): 124 | vectors = api_response.get('vectors') 125 | weights = api_response.get('weights') 126 | 127 | if vectors is not None and weights is not None: 128 | vectors_casted = cast(List[List[float]], vectors) 129 | weights_casted = cast(List[float], weights) 130 | return 
BatchResponse(vectors=vectors_casted, weights=weights_casted) 131 | else: 132 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 133 | else: 134 | raise TypeError("Expected a dictionary in APIResponse.data.") 135 | except ValidationError as e: 136 | raise e 137 | 138 | async def _sampled_batch(self, req: SampledBatchRequest) -> BatchResponse: 139 | data = { 140 | "id": self.__session_wrapper(req.session), 141 | "n": req.n, 142 | "vdbid": req.vdbid, 143 | "params": req.params, 144 | "state": req.state 145 | } 146 | response = await self._post_request(self.url + "embeddings/sampled_batch", data) 147 | self.__error_handling(response, "sampled_batch") 148 | try: 149 | api_response = APIResponse(**response).data 150 | if isinstance(api_response, dict): 151 | vectors = api_response.get('vectors') 152 | weights = api_response.get('weights') 153 | 154 | if vectors is not None and weights is not None: 155 | vectors_casted = cast(List[List[float]], vectors) 156 | weights_casted = cast(List[float], weights) 157 | return BatchResponse(vectors=vectors_casted, weights=weights_casted) 158 | else: 159 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 160 | else: 161 | raise TypeError("Expected a dictionary in APIResponse.data.") 162 | except ValidationError as e: 163 | raise e 164 | 165 | async def _get_session(self, session: SessionObject) -> GetSessionResponse: 166 | data = {"id": self.__session_wrapper(session)} 167 | response = await self._post_request(self.url + "embeddings/get_session", data) 168 | self.__error_handling(response, "get_session") 169 | try: 170 | api_response = APIResponse(**response).data 171 | if isinstance(api_response, dict): 172 | state = api_response.get("state", "") 173 | algorithm = api_response.get("algorithm", "") 174 | vdbid = api_response.get("vdbid", "") 175 | has_embeddings = api_response.get("has_embeddings", "") 176 | factory_id = api_response.get("factory_id") 177 | custom_id = api_response.get("custom_id") 178 | 179 | if state and algorithm and vdbid: 180 | return GetSessionResponse( 181 | state=cast(str, state), 182 | algorithm=cast(str, algorithm), 183 | vdbid=cast(str, vdbid), 184 | has_embeddings=cast(bool, has_embeddings), 185 | factory_id=cast(Optional[str], factory_id), 186 | custom_id=cast(Optional[str], custom_id) 187 | ) 188 | else: 189 | raise ValueError("Missing mandatory keys in API response.") 190 | else: 191 | raise TypeError("Expected a dictionary in APIResponse.data.") 192 | except ValidationError as e: 193 | raise e 194 | 195 | async def _get_history(self, session: SessionObject) -> GetHistoryResponse: 196 | data = {"id": self.__session_wrapper(session)} 197 | response = await self._post_request(self.url + "embeddings/get_history", data) 198 | self.__error_handling(response, "get_history") 199 | try: 200 | api_response = APIResponse(**response).data 201 | if isinstance(api_response, dict): 202 | ids = api_response.get('ids') 203 | if ids is not None: 204 | ids_casted = cast(List[str], ids) 205 | return GetHistoryResponse(ids=ids_casted) 206 | else: 207 | raise ValueError("Missing 'ids' in API response.") 208 | else: 209 | raise TypeError("Expected a dictionary in APIResponse.data.") 210 | except ValidationError as e: 211 | raise e 212 | 213 | async def _get_user_embeddings(self, session: SessionObject, last_n: Optional[int] = None) -> BatchResponse: 214 | 215 | data = {"id": self.__session_wrapper(session), "last_n": 50} 216 | if last_n is not None: 217 | data["last_n"] = last_n 218 | 219 | response = await 
self._post_request(self.url + "embeddings/get_embeddings", data) 220 | self.__error_handling(response, "get_user_embeddings") 221 | try: 222 | api_response = APIResponse(**response).data 223 | if isinstance(api_response, dict): 224 | vectors = api_response.get('vectors') 225 | weights = api_response.get('weights') 226 | 227 | if vectors is not None and weights is not None: 228 | vectors_casted = cast(List[List[float]], vectors) 229 | weights_casted = cast(List[float], weights) 230 | return BatchResponse(vectors=vectors_casted, weights=weights_casted) 231 | else: 232 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 233 | else: 234 | raise TypeError("Expected a dictionary in APIResponse.data.") 235 | except ValidationError as e: 236 | raise e 237 | 238 | async def _vdb_exists(self, vdbid: str) -> bool: 239 | data = {"vdbid": vdbid} 240 | response = await self._post_request(self.url + "embeddings/vdb_exists", data) 241 | self.__error_handling(response, "vdb_exists") 242 | try: 243 | return response["data"] 244 | except ValidationError as e: 245 | raise e 246 | 247 | async def _get_blueprint(self, custom_id: str) -> Any: 248 | data = {"id": custom_id} 249 | response = await self._post_request(self.url + "embeddings/get_blueprint", data) 250 | self.__error_handling(response, "get_blueprint") 251 | try: 252 | return response["data"] 253 | except ValidationError as e: 254 | raise e 255 | 256 | @staticmethod 257 | def _mock_history() -> GetHistoryResponse: 258 | return GetHistoryResponse(ids=[]) 259 | 260 | def _set_info(self) -> Any: 261 | import requests 262 | response = requests.get(REGION_URL, headers=self.headers) 263 | self.__error_handling(response.json(), "team_info") 264 | 265 | try: 266 | data = response.json()["data"] 267 | except ValidationError as e: 268 | raise e 269 | 270 | self.team_id = data["teamID"] 271 | region = data["region"] 272 | try: 273 | self.url = regions[region] 274 | self.region = region 275 | except ValueError: 276 | raise ValueError("There is no such region {}".format(region)) 277 | 278 | -------------------------------------------------------------------------------- /firstbatch/client/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Any 3 | from firstbatch.client.schema import AddHistoryRequest, CreateSessionRequest, SignalRequest, BiasedBatchRequest, \ 4 | SampledBatchRequest, UpdateStateRequest 5 | from firstbatch.algorithm.blueprint.base import SessionObject 6 | 7 | 8 | class BaseClient(ABC): 9 | 10 | @abstractmethod 11 | def _init_vectordb_scalar(self, key: str, vdbid:str, vecs:List[List[int]], quantiles: List[float]) -> Any: 12 | ... 13 | 14 | @abstractmethod 15 | def _add_history(self, req: AddHistoryRequest) -> Any: 16 | ... 17 | 18 | @abstractmethod 19 | def _create_session(self, req: CreateSessionRequest) -> Any: 20 | ... 21 | 22 | @abstractmethod 23 | def _update_state(self, req: UpdateStateRequest) -> Any: 24 | ... 25 | 26 | @abstractmethod 27 | def _signal(self, req: SignalRequest) -> Any: 28 | ... 29 | 30 | @abstractmethod 31 | def _biased_batch(self, req: BiasedBatchRequest) -> Any: 32 | ... 33 | 34 | @abstractmethod 35 | def _sampled_batch(self, req: SampledBatchRequest) -> Any: 36 | ... 37 | 38 | @abstractmethod 39 | def _get_session(self, session: SessionObject) -> Any: 40 | ... 41 | 42 | @abstractmethod 43 | def _get_history(self, session: SessionObject) -> Any: 44 | ... 
45 | 46 | @abstractmethod 47 | def _get_user_embeddings(self, session: SessionObject) -> Any: 48 | ... 49 | 50 | @abstractmethod 51 | def _vdb_exists(self, vdbid: str) -> Any: 52 | ... 53 | 54 | @abstractmethod 55 | def _get_blueprint(self, custom_id: str) -> Any: 56 | ... 57 | 58 | @staticmethod 59 | @abstractmethod 60 | def _mock_history() -> Any: 61 | ... 62 | 63 | @abstractmethod 64 | def _set_info(self) -> Any: 65 | ... 66 | 67 | 68 | -------------------------------------------------------------------------------- /firstbatch/client/schema.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Optional, Dict, Union, Any 3 | from firstbatch.algorithm.registry import AlgorithmLabel 4 | from firstbatch.algorithm.blueprint.base import SessionObject 5 | from dataclasses_json import DataClassJsonMixin 6 | from dataclasses import dataclass 7 | 8 | 9 | @dataclass 10 | class Session(DataClassJsonMixin): 11 | id: str 12 | algorithm: AlgorithmLabel 13 | state: int 14 | metadata: Optional[Dict[str, Any]] = None 15 | 16 | 17 | @dataclass 18 | class InitRequest: 19 | vdbid: str 20 | vecs: List[List[int]] 21 | quantiles: List[float] 22 | key: str 23 | 24 | 25 | @dataclass 26 | class AddHistoryRequest: 27 | session: SessionObject 28 | ids: List[str] 29 | 30 | 31 | @dataclass 32 | class CreateSessionRequest: 33 | algorithm: str 34 | vdbid: str 35 | has_embeddings: bool = False 36 | custom_id : Optional[str] = None 37 | factory_id: Optional[str] = None 38 | id: Optional[str] = None 39 | 40 | 41 | @dataclass 42 | class UpdateStateRequest: 43 | session: SessionObject 44 | state: str 45 | batch_type: Optional[str] = None 46 | 47 | 48 | @dataclass 49 | class SignalRequest: 50 | session: SessionObject 51 | vector: List[float] 52 | signal: float 53 | signal_label: str 54 | state: str 55 | 56 | 57 | @dataclass 58 | class BiasedBatchRequest: 59 | session: SessionObject 60 | vdbid: str 61 | state: str 62 | bias_vectors: Optional[List[List[float]]] = None 63 | bias_weights: Optional[List[float]] = None 64 | params: Optional[Dict[str, float]] = None 65 | 66 | 67 | @dataclass 68 | class SampledBatchRequest: 69 | session: SessionObject 70 | n: int 71 | vdbid: str 72 | state: str 73 | params: Optional[Dict[str, float]] = None 74 | 75 | 76 | class APIResponse(BaseModel): 77 | success: bool 78 | code: int 79 | data: Optional[Union[str, Dict[str, Union[str, int, List[str] ,List[float], List[List[float]]]]]] 80 | message: Optional[str] = None 81 | 82 | 83 | class GetHistoryResponse(BaseModel): 84 | ids: List[str] 85 | 86 | 87 | class GetSessionResponse(BaseModel): 88 | state: str 89 | algorithm: str 90 | vdbid: str 91 | has_embeddings: bool 92 | factory_id: Optional[str] = None 93 | custom_id: Optional[str] = None 94 | 95 | 96 | class SignalResponse(BaseModel): 97 | ... 98 | 99 | 100 | class BatchResponse(BaseModel): 101 | vectors: List[List[float]] 102 | weights: List[float] 103 | 104 | 105 | class FirstBatchAPIError(Exception): 106 | ... 
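For orientation, here is a minimal sketch (not part of the package) of how the request dataclasses above are meant to pair with APIResponse parsing. The payload values ("team-123", state "0") are invented for illustration, and SessionObject is assumed to take the id/is_persistent fields that firstbatch.core passes to it.

from firstbatch.client.schema import CreateSessionRequest, UpdateStateRequest, APIResponse
from firstbatch.algorithm.blueprint.base import SessionObject

# Request body for embeddings/create_session
create_req = CreateSessionRequest(algorithm="SIMPLE", vdbid="my_vdb")

# Assumed shape of a successful reply: the session id comes back in "data"
raw = {"success": True, "code": 200, "data": "team-123"}
session = SessionObject(id=APIResponse(**raw).data, is_persistent=False)

# Subsequent calls reuse the SessionObject, e.g. to advance the blueprint state
state_req = UpdateStateRequest(session=session, state="0", batch_type="RANDOM")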
107 | -------------------------------------------------------------------------------- /firstbatch/client/sync.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from typing import List, Dict, Union, Optional, Any, cast 3 | from pydantic import ValidationError 4 | import hashlib 5 | from firstbatch.constants import regions, REGION_URL 6 | from firstbatch.client.schema import APIResponse, BatchResponse, GetSessionResponse, GetHistoryResponse 7 | from firstbatch.client.schema import AddHistoryRequest, CreateSessionRequest, SignalRequest, BiasedBatchRequest,\ 8 | SampledBatchRequest, UpdateStateRequest, FirstBatchAPIError 9 | from firstbatch.client.base import BaseClient 10 | from firstbatch.algorithm.blueprint.base import SessionObject 11 | 12 | 13 | class FirstBatchClient(BaseClient): 14 | def __init__(self, api_key: str): 15 | self.api_key = api_key 16 | self.url: str = "" 17 | self.region: str = "" 18 | self.team_id: str = "" 19 | self.headers = { 20 | "x-api-key": self.api_key, 21 | "Content-Type": "application/json" 22 | } 23 | 24 | @staticmethod 25 | def __error_check(response: requests.Response, func_name: str) -> None: 26 | if response.status_code != 200: 27 | raise FirstBatchAPIError( 28 | f"FirstBatch API error with code {response.status_code} in {func_name} with reason: {response.reason}") 29 | 30 | def __id_wrapper(self, id: Union[str, None]): 31 | if isinstance(id, str): 32 | return self.team_id + "-" + id 33 | return id 34 | 35 | @staticmethod 36 | def __session_wrapper(session: SessionObject): 37 | return session.id 38 | 39 | def _init_vectordb_scalar(self, key: str, vdbid: str, vecs: List[List[int]], quantiles: List[float]) -> Any: 40 | 41 | m = hashlib.md5() 42 | m.update(key.encode()) 43 | hash_value = m.hexdigest() 44 | 45 | data = { 46 | "key": hash_value, 47 | "vdbid": vdbid, 48 | "mode": "scalar", 49 | "region": self.region, 50 | "quantized_vecs": vecs, 51 | "quantiles": quantiles 52 | } 53 | response = requests.post(self.url + "embeddings/init_vdb", headers=self.headers, json=data) 54 | self.__error_check(response, "init_vectordb_scalar") 55 | try: 56 | return APIResponse(**response.json()) 57 | except ValidationError as e: 58 | raise e 59 | 60 | def _add_history(self, req: AddHistoryRequest) -> Any: 61 | data = { 62 | "id": self.__session_wrapper(req.session), 63 | "ids": req.ids 64 | } 65 | response = requests.post(self.url + "embeddings/update_history", headers=self.headers, json=data) 66 | self.__error_check(response, "add_history") 67 | try: 68 | return APIResponse(**response.json()) 69 | except ValidationError as e: 70 | raise e 71 | 72 | def _create_session(self, req: CreateSessionRequest) -> APIResponse: 73 | data = {"id": self.__id_wrapper(req.id), "algorithm": req.algorithm, "vdbid": req.vdbid, 74 | "custom_id": req.custom_id, "factory_id": req.factory_id, "has_embeddings": req.has_embeddings} 75 | response = requests.post(self.url + "embeddings/create_session", headers=self.headers, json=data) 76 | self.__error_check(response, "create_session") 77 | try: 78 | return APIResponse(**response.json()) 79 | except ValidationError as e: 80 | # Handle parsing errors, maybe raise a custom exception or just re-raise 81 | raise e 82 | 83 | def _update_state(self, req: UpdateStateRequest) -> APIResponse: 84 | data = {"id": self.__session_wrapper(req.session), "state":req.state, "batch_type": req.batch_type} 85 | response = requests.post(self.url + "embeddings/update_state", headers=self.headers, json=data) 86 | 
self.__error_check(response, "update_state") 87 | try: 88 | return APIResponse(**response.json()) 89 | except ValidationError as e: 90 | # Handle parsing errors, maybe raise a custom exception or just re-raise 91 | raise e 92 | 93 | def _signal(self, req: SignalRequest) -> APIResponse: 94 | data = { 95 | "id": self.__session_wrapper(req.session), 96 | "vector": req.vector, 97 | "signal": req.signal, 98 | "state": req.state, 99 | "signal_label": req.signal_label 100 | } 101 | response = requests.post(self.url + "embeddings/signal", headers=self.headers, json=data) 102 | self.__error_check(response, "signal") 103 | try: 104 | return APIResponse(**response.json()) 105 | except ValidationError as e: 106 | raise e 107 | 108 | def _biased_batch(self, req: BiasedBatchRequest) -> BatchResponse: 109 | data = { 110 | "id": self.__session_wrapper(req.session), 111 | "vdbid": req.vdbid, 112 | "bias_vectors": req.bias_vectors, 113 | "bias_weights": req.bias_weights, 114 | "params": req.params, 115 | "state": req.state 116 | } 117 | response = requests.post(self.url + "embeddings/biased_batch", headers=self.headers, json=data) 118 | self.__error_check(response, "biased_batch") 119 | 120 | try: 121 | api_response = APIResponse(**response.json()).data 122 | if isinstance(api_response, dict): 123 | vectors = api_response.get('vectors') 124 | weights = api_response.get('weights') 125 | 126 | if vectors is not None and weights is not None: 127 | vectors_casted = cast(List[List[float]], vectors) 128 | weights_casted = cast(List[float], weights) 129 | return BatchResponse(vectors=vectors_casted, weights=weights_casted) 130 | else: 131 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 132 | else: 133 | raise TypeError("Expected a dictionary in APIResponse.data.") 134 | except ValidationError as e: 135 | raise e 136 | 137 | def _sampled_batch(self, req: SampledBatchRequest) -> BatchResponse: 138 | data = { 139 | "id": self.__session_wrapper(req.session), 140 | "n": req.n, 141 | "vdbid": req.vdbid, 142 | "params": req.params, 143 | "state": req.state 144 | } 145 | response = requests.post(self.url + "embeddings/sampled_batch", headers=self.headers, json=data) 146 | self.__error_check(response, "sampled_batch") 147 | try: 148 | api_response = APIResponse(**response.json()).data 149 | if isinstance(api_response, dict): 150 | vectors = api_response.get('vectors') 151 | weights = api_response.get('weights') 152 | 153 | if vectors is not None and weights is not None: 154 | vectors_casted = cast(List[List[float]], vectors) 155 | weights_casted = cast(List[float], weights) 156 | return BatchResponse(vectors=vectors_casted, weights=weights_casted) 157 | else: 158 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 159 | else: 160 | raise TypeError("Expected a dictionary in APIResponse.data.") 161 | except ValidationError as e: 162 | raise e 163 | 164 | def _get_session(self, session: SessionObject) -> GetSessionResponse: 165 | 166 | data = {"id": self.__session_wrapper(session)} 167 | response = requests.post(self.url + "embeddings/get_session", headers=self.headers, json=data) 168 | self.__error_check(response, "get_session") 169 | 170 | try: 171 | api_response = APIResponse(**response.json()).data 172 | if isinstance(api_response, dict): 173 | state = api_response.get("state", "") 174 | algorithm = api_response.get("algorithm", "") 175 | vdbid = api_response.get("vdbid", "") 176 | has_embeddings = api_response.get("has_embeddings", "") 177 | factory_id = 
api_response.get("factory_id") 178 | custom_id = api_response.get("custom_id") 179 | 180 | if state and algorithm and vdbid: 181 | return GetSessionResponse( 182 | state=cast(str, state), 183 | algorithm=cast(str, algorithm), 184 | vdbid=cast(str, vdbid), 185 | has_embeddings=cast(bool, has_embeddings), 186 | factory_id=cast(Optional[str], factory_id), 187 | custom_id=cast(Optional[str], custom_id) 188 | ) 189 | else: 190 | raise ValueError("Missing 'state', 'algorithm' or 'vdbid' in API response.") 191 | else: 192 | raise TypeError("Expected a dictionary in APIResponse.data.") 193 | except ValidationError as e: 194 | raise e 195 | 196 | def _get_history(self, session: SessionObject) -> GetHistoryResponse: 197 | data = {"id": self.__session_wrapper(session)} 198 | response = requests.post(self.url + "embeddings/get_history", headers=self.headers, json=data) 199 | self.__error_check(response, "get_history") 200 | 201 | try: 202 | api_response = APIResponse(**response.json()).data 203 | if isinstance(api_response, dict): 204 | history_data = cast(Dict[str, List[str]], api_response) 205 | return GetHistoryResponse(**history_data) 206 | else: 207 | raise TypeError("Expected a dictionary in APIResponse.data.") 208 | except ValidationError as e: 209 | raise e 210 | 211 | def _get_user_embeddings(self, session: SessionObject, last_n: Optional[int] = None) -> BatchResponse: 212 | 213 | data = {"id": self.__session_wrapper(session), "last_n": 50} 214 | if last_n is not None: 215 | data["last_n"] = last_n 216 | 217 | response = requests.post(self.url + "embeddings/get_embeddings", headers=self.headers, json=data) 218 | self.__error_check(response, "get_user_embeddings") 219 | try: 220 | api_response = APIResponse(**response.json()).data 221 | if isinstance(api_response, dict): 222 | vectors = api_response.get('vectors') 223 | weights = api_response.get('weights') 224 | 225 | if vectors is not None and weights is not None: 226 | vectors_casted = cast(List[List[float]], vectors) 227 | weights_casted = cast(List[float], weights) 228 | return BatchResponse(vectors=vectors_casted, weights=weights_casted) 229 | else: 230 | raise ValueError("Missing 'vectors' or 'weights' in API response.") 231 | else: 232 | raise TypeError("Expected a dictionary in APIResponse.data.") 233 | except ValidationError as e: 234 | raise e 235 | 236 | def _vdb_exists(self, vdbid: str) -> bool: 237 | data = {"vdbid": vdbid} 238 | response = requests.post(self.url + "embeddings/vdb_exists", headers=self.headers, json=data) 239 | self.__error_check(response, "vdb_exists") 240 | try: 241 | return response.json()["data"] 242 | except ValidationError as e: 243 | raise e 244 | 245 | def _get_blueprint(self, custom_id: str) -> Any: 246 | data = {"id": custom_id} 247 | response = requests.post(self.url + "embeddings/get_blueprint", headers=self.headers, json=data) 248 | self.__error_check(response, "get_blueprint") 249 | try: 250 | return response.json()["data"] 251 | except ValidationError as e: 252 | raise e 253 | 254 | @staticmethod 255 | def _mock_history() -> GetHistoryResponse: 256 | return GetHistoryResponse(ids=[]) 257 | 258 | def _set_info(self) -> Any: 259 | response = requests.get(REGION_URL, headers=self.headers) 260 | self.__error_check(response, "team_info") 261 | 262 | try: 263 | data = response.json()["data"] 264 | except ValidationError as e: 265 | raise e 266 | 267 | self.team_id = data["teamID"] 268 | region = data["region"] 269 | try: 270 | self.url = regions[region] 271 | self.region = region 272 | except ValueError: 
273 | raise ValueError("There is no such region {}".format(region)) 274 | 275 | -------------------------------------------------------------------------------- /firstbatch/client/wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from firstbatch.client import BiasedBatchRequest, SignalRequest, SampledBatchRequest, CreateSessionRequest, \ 3 | AddHistoryRequest, UpdateStateRequest 4 | from firstbatch.vector_store.utils import generate_batch 5 | from firstbatch.vector_store.schema import BatchQuery 6 | from firstbatch.algorithm.blueprint import SignalObject, SessionObject 7 | from firstbatch.constants import MINIMUM_TOPK 8 | 9 | 10 | def history_request(session: SessionObject, ids: List[str]) -> AddHistoryRequest: 11 | return AddHistoryRequest(session, ids) 12 | 13 | 14 | def session_request(**kwargs) -> CreateSessionRequest: 15 | return CreateSessionRequest(**kwargs) 16 | 17 | 18 | def update_state_request(session: SessionObject, state: str, batch_type: str) -> UpdateStateRequest: 19 | return UpdateStateRequest(session=session, state=state, batch_type=batch_type) 20 | 21 | 22 | def signal_request(session: SessionObject, state: str, signal: SignalObject) -> SignalRequest: 23 | return SignalRequest(session=session, vector=signal.vector.vector, 24 | signal=signal.action.weight, state=state, signal_label=signal.action.label) 25 | 26 | 27 | def sampled_batch_request(session: SessionObject, vdbid: str, state: str, n_topics: int) -> SampledBatchRequest: 28 | return SampledBatchRequest(session=session, n=n_topics, vdbid=vdbid, state=state) 29 | 30 | 31 | def biased_batch_request(session: SessionObject, vdb_id: str, state: str, params: Dict[str, float], **kwargs) \ 32 | -> BiasedBatchRequest: 33 | if "bias_vectors" in kwargs and "bias_weights" in kwargs: 34 | return BiasedBatchRequest(session, vdb_id, state, 35 | bias_vectors=kwargs["bias_vectors"], bias_weights=kwargs["bias_weights"], 36 | params=params) 37 | else: 38 | return BiasedBatchRequest(session, vdb_id, state, params=params) 39 | 40 | 41 | def random_batch_request(batch_size: int, embedding_size: int, **kwargs) -> BatchQuery: 42 | return generate_batch(batch_size, embedding_size, top_k=MINIMUM_TOPK * 2, 43 | include_values=("apply_mmr" in kwargs or "apply_threshold" in kwargs)) 44 | -------------------------------------------------------------------------------- /firstbatch/constants.py: -------------------------------------------------------------------------------- 1 | """Constants for the firstbatch package.""" 2 | 3 | EU_CENTRAL_1 = "https://aws-eu-central-1.hollowdb.xyz/" 4 | US_WEST_1 = "https://aws-us-west-1.hollowdb.xyz/" 5 | US_EAST_1 = "https://aws-us-east-1.hollowdb.xyz/" 6 | ASIA_PACIFIC_1 = "https://aws-ap-southeast-1.hollowdb.xyz/" 7 | 8 | regions = {"us-east-1":US_EAST_1, "us-west-1":US_WEST_1, "eu-central-1":EU_CENTRAL_1, "ap-southeast-1": ASIA_PACIFIC_1} 9 | 10 | REGION_URL = "https://idp.firstbatch.xyz/v1/teams/team/get-team-information" 11 | 12 | DEFAULT_QUANTIZER_TRAIN_SIZE = 100 13 | DEFAULT_QUANTIZER_TYPE = "scalar" 14 | DEFAULT_EMBEDDING_SIZE = 1536 15 | DEFAULT_CONFIDENCE_INTERVAL_RATIO = 0.15 16 | DEFAULT_COLLECTION = "my_collection" 17 | DEFAULT_BATCH_SIZE = 10 18 | DEFAULT_KEY = "text" 19 | DEFAULT_TOPK_QUANT = 5 20 | MINIMUM_TOPK = 5 21 | DEFAULT_HISTORY = False 22 | DEFAULT_VERBOSE = False 23 | DEFAULT_HISTORY_FIELD = "id" 24 | 25 | -------------------------------------------------------------------------------- 
/firstbatch/core.py: -------------------------------------------------------------------------------- 1 | """MasterClass for composition""" 2 | from __future__ import annotations 3 | from typing import Optional 4 | import time 5 | from collections import defaultdict 6 | import logging 7 | from firstbatch.algorithm import AlgorithmLabel, UserAction, BatchType, BatchEnum, \ 8 | BaseAlgorithm, AlgorithmRegistry, SignalType, SessionObject 9 | from firstbatch.client.schema import GetHistoryResponse 10 | from firstbatch.vector_store import ( 11 | VectorStore, 12 | FetchQuery, 13 | Query, Vector, 14 | MetadataFilter, 15 | adjust_weights, 16 | generate_batch 17 | ) 18 | from firstbatch.lossy import ScalarQuantizer 19 | from firstbatch.client import ( 20 | BatchQuery, 21 | BatchResponse, 22 | SignalObject, 23 | session_request, 24 | signal_request, 25 | history_request, 26 | random_batch_request, 27 | biased_batch_request, 28 | sampled_batch_request, 29 | update_state_request 30 | ) 31 | from firstbatch.utils import Config 32 | from firstbatch.constants import ( 33 | DEFAULT_VERBOSE, 34 | DEFAULT_HISTORY, 35 | DEFAULT_BATCH_SIZE, 36 | DEFAULT_QUANTIZER_TRAIN_SIZE, 37 | DEFAULT_QUANTIZER_TYPE, 38 | DEFAULT_TOPK_QUANT, 39 | DEFAULT_CONFIDENCE_INTERVAL_RATIO, 40 | MINIMUM_TOPK 41 | ) 42 | 43 | from firstbatch.client.sync import FirstBatchClient 44 | from firstbatch.logger_conf import setup_logger 45 | 46 | 47 | class FirstBatch(FirstBatchClient): 48 | 49 | def __init__(self, api_key: str, config: Config): 50 | """ 51 | Initialize the FirstBatch class 52 | :param api_key: 53 | :param config: 54 | """ 55 | super().__init__(api_key) 56 | self.store: defaultdict[str, VectorStore] = defaultdict(VectorStore) 57 | self._batch_size = DEFAULT_BATCH_SIZE 58 | self._quantizer_train_size = DEFAULT_QUANTIZER_TRAIN_SIZE 59 | self._quantizer_type = DEFAULT_QUANTIZER_TYPE 60 | self._enable_history = DEFAULT_HISTORY 61 | self._verbose = DEFAULT_VERBOSE 62 | self.logger = setup_logger() 63 | self.logger.setLevel(logging.WARN) 64 | self._set_info() 65 | 66 | if config.verbose is not None: 67 | if config.verbose: 68 | self._verbose = config.verbose 69 | self.logger.setLevel(logging.DEBUG) 70 | else: 71 | self.logger.setLevel(logging.WARN) 72 | if config.batch_size is not None: 73 | self._batch_size = config.batch_size 74 | if config.quantizer_train_size is not None: 75 | self._quantizer_train_size = config.quantizer_train_size 76 | if config.quantizer_type is not None: 77 | self.logger.info("Product type quantizer is not supported yet.") 78 | # self._quantizer_type = kwargs["quantizer_type"] 79 | self._quantizer_type = "scalar" 80 | if config.enable_history is not None: 81 | self._enable_history = config.enable_history 82 | 83 | self.logger.info("Set mode to verbose") 84 | self.logger.info("Using: {}".format(self.url)) 85 | 86 | def add_vdb(self, vdbid: str, vs: VectorStore): 87 | 88 | exists = self._vdb_exists(vdbid) 89 | 90 | if not exists: 91 | self.logger.info("VectorDB with id {} not found. 
Sketching new VectorDB".format(vdbid)) 92 | if self._quantizer_type == "scalar": 93 | vs.quantizer = ScalarQuantizer(256) 94 | ts = min(int(self._quantizer_train_size/DEFAULT_TOPK_QUANT), 500) 95 | batch = generate_batch(ts, vs.embedding_size, top_k=DEFAULT_TOPK_QUANT, include_values=True) 96 | 97 | results = vs.multi_search(batch) 98 | vs.train_quantizer(results.vectors()) 99 | 100 | quantized_vectors = [vs.quantize_vector(vector).vector for vector in results.vectors()] 101 | # This might 1-2 minutes with scalar quantizer 102 | self._init_vectordb_scalar(self.api_key, vdbid, quantized_vectors, vs.quantizer.quantiles) 103 | 104 | self.store[vdbid] = vs 105 | elif self._quantizer_type == "product": 106 | raise NotImplementedError("Product quantizer not supported yet") 107 | 108 | else: 109 | raise ValueError(f"Invalid quantizer type: {self._quantizer_type}") 110 | else: 111 | self.store[vdbid] = vs 112 | 113 | def user_embeddings(self, session: SessionObject): 114 | return self._get_user_embeddings(session) 115 | 116 | def session(self, algorithm: AlgorithmLabel, vdbid: str, session_id: Optional[str] = None, 117 | custom_id: Optional[str] = None) -> SessionObject: 118 | 119 | if session_id is None: 120 | if algorithm == AlgorithmLabel.SIMPLE: 121 | req = session_request(**{"algorithm": algorithm.value, "vdbid": vdbid}) 122 | 123 | elif algorithm == AlgorithmLabel.CUSTOM: 124 | req = session_request(**{"algorithm": algorithm.value, "vdbid": vdbid, "custom_id": custom_id}) 125 | 126 | else: 127 | req = session_request(**{"algorithm": "FACTORY", "vdbid": vdbid, "factory_id": algorithm.value}) 128 | else: 129 | if algorithm == AlgorithmLabel.SIMPLE: 130 | req = session_request(**{"id": session_id, "algorithm": algorithm.value, "vdbid": vdbid}) 131 | 132 | elif algorithm == AlgorithmLabel.CUSTOM: 133 | req = session_request(**{"id": session_id, "algorithm": algorithm.value, "vdbid": vdbid, 134 | "custom_id": custom_id}) 135 | 136 | else: 137 | req = session_request(**{"id": session_id, "algorithm": "FACTORY", "vdbid": vdbid, 138 | "factory_id": algorithm.value}) 139 | 140 | return SessionObject(id=self._create_session(req).data, is_persistent=(session_id is not None)) 141 | 142 | def add_signal(self, session: SessionObject, user_action: UserAction, cid: str) -> None: 143 | response = self._get_session(session) 144 | 145 | vs = self.store[response.vdbid] 146 | 147 | if not isinstance(user_action.action_type, SignalType): 148 | raise ValueError(f"Invalid action type: {user_action.action_type}") 149 | 150 | fetch = FetchQuery(id=cid) 151 | result = vs.fetch(fetch) 152 | 153 | algo_instance = self.__get_algorithm(vs.embedding_size, self._batch_size, response.algorithm, response.factory_id, response.custom_id) 154 | 155 | # Create a signal object based on user action and content id 156 | signal_obj = SignalObject(vector=result.vector, action=user_action.action_type, cid=cid, timestamp=int(time.time())) 157 | 158 | # Call blueprint_step to calculate the next state 159 | (next_state, batch_type, params) = algo_instance.blueprint_step(response.state, user_action) 160 | # Send signal 161 | resp = self._signal(signal_request(session, next_state.name, signal_obj)) 162 | 163 | if resp.success: 164 | if self._enable_history: 165 | self._add_history(history_request(session, [result.metadata.data[vs.history_field]])) 166 | 167 | def batch(self, session: SessionObject, batch_size: Optional[int] = None, **kwargs): 168 | response = self._get_session(session) 169 | vs = self.store[response.vdbid] 170 | 171 | 
self.logger.info("Session: {} {} {}".format(response.algorithm, response.factory_id, response.state)) 172 | if batch_size is None: 173 | batch_size = self._batch_size 174 | 175 | algo_instance = self.__get_algorithm(vs.embedding_size, batch_size, response.algorithm, response.factory_id, response.custom_id) 176 | user_action = UserAction(BatchEnum.BATCH) 177 | 178 | (next_state, batch_type, params) = algo_instance.blueprint_step(response.state, user_action) 179 | 180 | self.logger.info("{} {}".format(batch_type, params.to_dict())) 181 | 182 | history = self._mock_history() 183 | if self._enable_history: 184 | history = self._get_history(session) 185 | 186 | if batch_type == BatchType.RANDOM: 187 | query = random_batch_request(algo_instance.batch_size, vs.embedding_size, **params.to_dict()) 188 | self._update_state(update_state_request(session, next_state.name, "RANDOM")) 189 | batch_response = vs.multi_search(query) 190 | ids, batch = algo_instance.random_batch(batch_response, query, **params.to_dict()) 191 | 192 | elif batch_type == BatchType.PERSONALIZED or batch_type == BatchType.BIASED: 193 | if batch_type == BatchType.BIASED and not ("bias_vectors" in kwargs and "bias_weights" in kwargs): 194 | self.logger.info("Bias vectors and weights must be provided for biased batch.") 195 | raise ValueError("no bias vectors provided") 196 | 197 | if batch_type == BatchType.PERSONALIZED and ("bias_vectors" in kwargs and "bias_weights" in kwargs): 198 | del kwargs["bias_vectors"] 199 | del kwargs["bias_weights"] 200 | 201 | if not response.has_embeddings and batch_type == BatchType.PERSONALIZED: 202 | self.logger.info("No embeddings found for personalized batch. Switching to random batch.") 203 | 204 | query = random_batch_request(algo_instance.batch_size, vs.embedding_size, **{"apply_mmr": True}) 205 | self._update_state(update_state_request(session, next_state.name, "PERSONALIZED")) 206 | batch_response = vs.multi_search(query) 207 | ids, batch = algo_instance.random_batch(batch_response, query, **params.to_dict()) 208 | 209 | else: 210 | batch_response_ = self._biased_batch(biased_batch_request(session, response.vdbid, next_state.name, 211 | params.to_dict(), **kwargs)) 212 | query = self.__query_wrapper(response.vdbid, algo_instance.batch_size, batch_response_, history, **params.to_dict()) 213 | batch = vs.multi_search(query) 214 | ids, batch = algo_instance.biased_batch(batch, query, **params.to_dict()) 215 | 216 | elif batch_type == BatchType.SAMPLED: 217 | batch_response_ = self._sampled_batch(sampled_batch_request(session=session, vdbid=response.vdbid, 218 | state=next_state.name, n_topics=params.n_topics)) 219 | query = self.__query_wrapper(response.vdbid, algo_instance.batch_size, batch_response_, history, **params.to_dict()) 220 | batch = vs.multi_search(query) 221 | ids, batch = algo_instance.sampled_batch(batch, query, **params.to_dict()) 222 | else: 223 | raise ValueError(f"Invalid batch type: {next_state.batch_type}") 224 | 225 | if self._enable_history: 226 | content = [b.data[vs.history_field] for b in batch[:algo_instance.batch_size]] 227 | self._add_history(history_request(session, content)) 228 | 229 | return ids, batch 230 | 231 | def __query_wrapper(self, vdbid: str, batch_size: int, response: BatchResponse, 232 | history: Optional[GetHistoryResponse], **kwargs): 233 | """ 234 | Wrapper for the query method. 
It applies the parameters from the blueprint 235 | :param vdbid: VectorDB ID, str 236 | :param batch_size: Batch size, int 237 | :param response: response from the API, BatchResponse 238 | :param history: list of content ids, GetHistoryResponse 239 | :param kwargs: 240 | :return: 241 | """ 242 | 243 | topks = adjust_weights(response.weights, batch_size, max((batch_size * DEFAULT_CONFIDENCE_INTERVAL_RATIO), 1)) 244 | m_filter = MetadataFilter(name="", filter={}) 245 | # We need the vector values to apply MMR or threshold 246 | include_values = ("apply_mmr" in kwargs or "apply_threshold" in kwargs) 247 | apply_mmr = "apply_mmr" in kwargs and (kwargs["apply_mmr"] is True or kwargs["apply_mmr"] == 1) 248 | 249 | if self._enable_history: 250 | if history is None: 251 | self.logger.info("History is None, No filter will be applied.") 252 | history = GetHistoryResponse(ids=[]) 253 | 254 | if "filter" in kwargs: 255 | m_filter = self.store[vdbid].history_filter(history.ids, kwargs["filter"]) 256 | else: 257 | m_filter = self.store[vdbid].history_filter(history.ids) 258 | 259 | qs = [Query(embedding=Vector(vec, len(vec), ""), top_k=max(topks[i], MINIMUM_TOPK), filter=m_filter, 260 | include_values=include_values) for i, vec in enumerate(response.vectors)] 261 | if apply_mmr: 262 | # increase top_k for MMR to work better 263 | for q in qs: 264 | q.top_k_mmr = q.top_k 265 | q.top_k *= 2 266 | 267 | return BatchQuery(qs, batch_size) 268 | 269 | def __get_algorithm(self, embedding_size: int, batch_size: int, algorithm: str, factory_id: Optional[str] = None, 270 | custom_id: Optional[str] = None) -> BaseAlgorithm: 271 | 272 | if algorithm == "SIMPLE": 273 | algo_type = AlgorithmRegistry.get_algorithm_by_label(algorithm) 274 | algo_instance: BaseAlgorithm = algo_type(batch_size, **{"embedding_size": embedding_size}) 275 | elif algorithm == "CUSTOM": 276 | if custom_id is None: 277 | raise ValueError("Custom algorithm id is None") 278 | bp = self._get_blueprint(custom_id) 279 | algo_type = AlgorithmRegistry.get_algorithm_by_label(algorithm) 280 | algo_instance: BaseAlgorithm = algo_type(bp, batch_size, **{"embedding_size": embedding_size}) # type: ignore 281 | elif algorithm == "FACTORY": 282 | algo_type = AlgorithmRegistry.get_algorithm_by_label(algorithm) 283 | algo_instance: BaseAlgorithm = algo_type(factory_id, batch_size, **{"embedding_size": embedding_size}) # type: ignore 284 | else: 285 | raise ValueError(f"Invalid algorithm: {algorithm}") 286 | return algo_instance 287 | 288 | 289 | -------------------------------------------------------------------------------- /firstbatch/imports/__init__.py: -------------------------------------------------------------------------------- 1 | from pydantic import ( 2 | BaseModel, 3 | Field, 4 | PrivateAttr, 5 | root_validator, 6 | validator, 7 | create_model, 8 | StrictFloat, 9 | StrictInt, 10 | StrictStr, 11 | ) 12 | from pydantic.fields import FieldInfo 13 | from pydantic import ValidationError 14 | 15 | 16 | __all__ = [ 17 | "BaseModel", 18 | "Field", 19 | "PrivateAttr", 20 | "root_validator", 21 | "validator", 22 | "create_model", 23 | "StrictFloat", 24 | "StrictInt", 25 | "StrictStr", 26 | "FieldInfo", 27 | "ValidationError", 28 | ] -------------------------------------------------------------------------------- /firstbatch/logger_conf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logger(): 5 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - 
%(message)s') 6 | logger = logging.getLogger("FirstBatchLogger") 7 | return logger 8 | -------------------------------------------------------------------------------- /firstbatch/lossy/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.lossy.scalar import ScalarQuantizer 2 | from firstbatch.lossy.product import ProductQuantizer 3 | from firstbatch.lossy.base import CompressedVector 4 | __all__ = ["ScalarQuantizer", "ProductQuantizer", "CompressedVector"] -------------------------------------------------------------------------------- /firstbatch/lossy/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, Optional 3 | from abc import ABC, abstractmethod 4 | from dataclasses_json import DataClassJsonMixin 5 | from dataclasses import dataclass 6 | from firstbatch.vector_store.schema import Vector 7 | 8 | 9 | @dataclass 10 | class CompressedVector(DataClassJsonMixin): 11 | vector: List[int] 12 | residual: Optional[List[int]] 13 | id: str 14 | 15 | 16 | class BaseLossy(ABC): 17 | """Base class for lossy compression algorithms.""" 18 | 19 | @abstractmethod 20 | def train(self, data: List[Vector]) -> None: 21 | """Train the algorithm. 22 | 23 | Args: 24 | data (Vector): Data to train the algorithm. 25 | """ 26 | ... 27 | 28 | @abstractmethod 29 | def compress(self, data: Vector) -> CompressedVector: 30 | """Compress data. 31 | 32 | Args: 33 | data (Any): Data to be compressed. 34 | 35 | Returns: 36 | Any: Compressed data. 37 | """ 38 | ... 39 | 40 | @abstractmethod 41 | def decompress(self, data: CompressedVector) -> Vector: 42 | """Decompress data. 43 | 44 | Args: 45 | data (Any): Data to be decompressed. 46 | 47 | Returns: 48 | Any: Decompressed data. 49 | """ 50 | ... 
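To make the contract above concrete, here is an illustrative round-trip (not part of the package) using the ScalarQuantizer defined in firstbatch/lossy/scalar.py. The random 8-dimensional training vectors are placeholders for vectors fetched from a store, and tdigest must be installed because ScalarQuantizer imports it at class definition time.

import random
from firstbatch.lossy import ScalarQuantizer
from firstbatch.vector_store.schema import Vector

# Placeholder training set; in practice these come from VectorStore.multi_search
train = [Vector(vector=[random.random() for _ in range(8)], dim=8, id=str(i)) for i in range(100)]

quantizer = ScalarQuantizer(levels=256)
quantizer.train(train)                       # fills quantizer.quantiles from a t-digest

original = train[0]
compressed = quantizer.compress(original)    # CompressedVector of int codes, residual=None
restored = quantizer.decompress(compressed)  # approximate reconstruction of the input

assert len(restored.vector) == original.dim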
-------------------------------------------------------------------------------- /firstbatch/lossy/product.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List, TYPE_CHECKING 3 | 4 | from firstbatch.lossy.base import BaseLossy, CompressedVector 5 | from firstbatch.vector_store.schema import Vector 6 | 7 | if TYPE_CHECKING: 8 | import nanopq 9 | 10 | 11 | class ProductQuantizer(BaseLossy): 12 | """Product quantizer algorithm.""" 13 | 14 | def __init__(self, cluster_size: int = 256, subquantizer_size: int = 32, verbose=False): 15 | try: 16 | import numpy as np 17 | import nanopq 18 | except ImportError: 19 | raise ImportError("Please install numpy && nanopq to use ProductQuantizer") 20 | 21 | self.data = None 22 | self.m = subquantizer_size 23 | self.ks = cluster_size 24 | self.__trained = False 25 | self.quantizer = nanopq.PQ(M=subquantizer_size, Ks=cluster_size, verbose=verbose) # 1024 26 | self.quantizer_residual = nanopq.PQ(M=subquantizer_size, Ks=cluster_size, verbose=verbose) 27 | 28 | def train(self, data: List[Vector]) -> None: 29 | # Encode data to PQ-codes 30 | if data[0].dim % self.m != 0: 31 | raise ValueError("input dimension must be dividable by M") 32 | 33 | if self.__trained: 34 | return None 35 | 36 | train_x = np.array([v.vector for v in data], dtype=np.float32) 37 | self.quantizer.fit(train_x) 38 | x_code = self.quantizer.encode(train_x) 39 | x = self.quantizer.decode(x_code) 40 | 41 | residuals = train_x - x 42 | self.quantizer_residual.fit(residuals) 43 | 44 | self.__trained = True 45 | 46 | # can only be used if train() has been called 47 | def compress(self, data: Vector) -> CompressedVector: 48 | if not self.__trained: 49 | raise ValueError("train() must be called before compress()") 50 | x = self.quantizer.encode(np.array(data.vector, dtype=np.float32).reshape(1, -1)) 51 | residual = self.quantizer_residual.encode(np.array(data.vector, dtype=np.float32) - self.quantizer.decode(x)) 52 | return CompressedVector(x.tolist()[0], residual.tolist()[0], data.id) 53 | 54 | def decompress(self, data: CompressedVector) -> Vector: 55 | if not self.__trained: 56 | raise ValueError("train() must be called before compress()") 57 | x = self.quantizer.decode(np.array(data.vector, dtype=np.uint16).reshape(1, -1)) 58 | residual = self.quantizer_residual.decode(np.array(data.residual, dtype=np.uint16).reshape(1, -1)) 59 | return Vector((x + residual).tolist(), len(data.vector), data.id) 60 | 61 | -------------------------------------------------------------------------------- /firstbatch/lossy/scalar.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List 3 | from firstbatch.lossy.base import BaseLossy, CompressedVector 4 | from firstbatch.vector_store.schema import Vector 5 | 6 | 7 | class ScalarQuantizer(BaseLossy): 8 | """Scalar quantizer algorithm.""" 9 | try: 10 | from tdigest import TDigest 11 | except ImportError: 12 | raise ImportError("Please install tdigest to use ScalarQuantizer") 13 | 14 | def __init__(self, levels=256): 15 | self.quantizer = self.TDigest() 16 | self.quantiles: List[float] = [] 17 | self.levels: int = levels 18 | 19 | def train(self, data: List[Vector]) -> None: 20 | scalars = Vector(vector=[], dim=0, id="") 21 | for vector in data: 22 | scalars = scalars.concat(vector) 23 | for scalar in scalars.vector: 24 | self.quantizer.update(scalar) 25 | self.quantiles = 
[self.quantizer.percentile(i * 100.0 / self.levels) for i in range(self.levels)] 26 | 27 | def __dequantize(self, qv): 28 | return [self.quantiles[val] for val in qv] 29 | 30 | def __quantize(self, v): 31 | return [self.__quantize_scalar(val) for val in v] 32 | 33 | def __quantize_scalar(self, scalar): 34 | return next((i for i, q in enumerate(self.quantiles) if scalar < q), self.levels - 1) 35 | 36 | def compress(self, data: Vector) -> CompressedVector: 37 | return CompressedVector(self.__quantize(data.vector), None, data.id) 38 | 39 | def decompress(self, data: CompressedVector) -> Vector: 40 | return Vector(self.__dequantize(data.vector), len(data.vector), data.id) 41 | -------------------------------------------------------------------------------- /firstbatch/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses_json import DataClassJsonMixin 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | 7 | @dataclass 8 | class Config(DataClassJsonMixin): 9 | """ 10 | Configuration for the vector store. 11 | """ 12 | batch_size: Optional[int] = None 13 | quantizer_train_size: Optional[int] = None 14 | quantizer_type: Optional[str] = None 15 | enable_history: Optional[bool] = None 16 | verbose: Optional[bool] = None 17 | -------------------------------------------------------------------------------- /firstbatch/vector_store/__init__.py: -------------------------------------------------------------------------------- 1 | from firstbatch.vector_store.base import VectorStore 2 | from firstbatch.vector_store.pinecone import Pinecone 3 | from firstbatch.vector_store.weaviate import Weaviate 4 | from firstbatch.vector_store.chroma import Chroma 5 | from firstbatch.vector_store.typesense import TypeSense 6 | from firstbatch.vector_store.supabase import Supabase 7 | from firstbatch.vector_store.qdrant import Qdrant 8 | from firstbatch.vector_store.schema import Query, QueryResult, SearchType, \ 9 | Vector, FetchQuery, Container, MetadataFilter, DistanceMetric 10 | from firstbatch.vector_store.utils import adjust_weights, generate_vectors, \ 11 | generate_batch, generate_query, maximal_marginal_relevance 12 | __all__ = ["Pinecone", "Weaviate", "Chroma", "TypeSense", "Supabase", "Qdrant", "VectorStore", "Query", "Container", 13 | "generate_vectors", "generate_query", "generate_batch", "maximal_marginal_relevance", "Vector", "FetchQuery", 14 | "QueryResult", "adjust_weights", "MetadataFilter", "SearchType", "DistanceMetric"] 15 | -------------------------------------------------------------------------------- /firstbatch/vector_store/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import ( 5 | Any, 6 | List, 7 | Dict, 8 | Optional, 9 | Union 10 | ) 11 | from firstbatch.lossy.base import CompressedVector 12 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult,\ 13 | QueryResult, BatchFetchQuery, FetchResult, BatchFetchResult, Vector, MetadataFilter 14 | 15 | 16 | class VectorStore(ABC): 17 | """Interface for vector store.""" 18 | 19 | @property 20 | @abstractmethod 21 | def quantizer(self): 22 | ... 23 | 24 | @quantizer.setter 25 | @abstractmethod 26 | def quantizer(self, value): 27 | ... 28 | 29 | @property 30 | @abstractmethod 31 | def embedding_size(self): 32 | ... 
33 | 34 | @embedding_size.setter 35 | @abstractmethod 36 | def embedding_size(self, value): 37 | ... 38 | 39 | @property 40 | @abstractmethod 41 | def history_field(self): 42 | ... 43 | 44 | @abstractmethod 45 | def train_quantizer(self, vectors: List[Vector]): 46 | ... 47 | 48 | @abstractmethod 49 | def quantize_vector(self, vector: Vector) -> CompressedVector: 50 | ... 51 | 52 | @abstractmethod 53 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 54 | ... 55 | 56 | @abstractmethod 57 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 58 | """Return docs most similar to query using specified search type.""" 59 | ... 60 | 61 | @abstractmethod 62 | async def asearch( 63 | self, query: Query, **kwargs: Any 64 | ) -> QueryResult: 65 | """Return docs most similar to query using specified search type.""" 66 | ... 67 | 68 | @abstractmethod 69 | def fetch( 70 | self, query: FetchQuery, **kwargs: Any 71 | ) -> FetchResult: 72 | """Return docs most similar to query using specified search type.""" 73 | ... 74 | 75 | @abstractmethod 76 | async def afetch( 77 | self, query: FetchQuery, **kwargs: Any 78 | ) -> FetchResult: 79 | """Return docs most similar to query using specified search type.""" 80 | ... 81 | 82 | @abstractmethod 83 | def multi_search( 84 | self, batch_query: BatchQuery, **kwargs: Any 85 | ) -> BatchQueryResult: 86 | """Return docs most similar to query using specified search type.""" 87 | ... 88 | 89 | @abstractmethod 90 | def multi_fetch( 91 | self, batch_query: BatchFetchQuery, **kwargs: Any 92 | ) -> BatchFetchResult: 93 | """Return docs most similar to query using specified search type.""" 94 | ... 95 | 96 | @abstractmethod 97 | async def a_multi_search( 98 | self, batch_query: BatchQuery, **kwargs: Any 99 | ) -> BatchQueryResult: 100 | """Return docs most similar to query using specified search type.""" 101 | ... 102 | 103 | @abstractmethod 104 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 105 | """Return docs most similar to query using specified search type.""" 106 | ... 
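As a usage sketch (not part of the package), the snippet below shows how calling code such as firstbatch.core is expected to drive this interface: build a history filter, wrap an embedding in a Query, and search. The store argument stands for any concrete implementation (Pinecone, Chroma, Qdrant, ...), and the all-zeros query vector is a placeholder.

from typing import List
from firstbatch.vector_store import VectorStore, Query, QueryResult, Vector


def personalized_lookup(store: VectorStore, seen_ids: List[str], dim: int, top_k: int = 10) -> QueryResult:
    # Exclude already-served content ids, then search around a placeholder embedding
    history = store.history_filter(seen_ids)
    query = Query(embedding=Vector([0.0] * dim, dim, ""), top_k=top_k,
                  filter=history, include_values=True)
    return store.search(query)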
107 | 108 | 109 | -------------------------------------------------------------------------------- /firstbatch/vector_store/chroma.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Any, Optional, List, Dict, Union, cast 3 | from functools import partial 4 | import asyncio 5 | import logging 6 | from firstbatch.constants import DEFAULT_COLLECTION, DEFAULT_EMBEDDING_SIZE, DEFAULT_HISTORY_FIELD 7 | from firstbatch.vector_store.schema import MetadataFilter 8 | from firstbatch.vector_store.base import VectorStore 9 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult,\ 10 | QueryResult, QueryMetadata, BatchFetchQuery, BatchFetchResult, FetchResult, Vector, DistanceMetric 11 | from firstbatch.lossy.base import BaseLossy, CompressedVector 12 | 13 | if TYPE_CHECKING: 14 | import chromadb 15 | import chromadb.config 16 | from chromadb.api.types import Where 17 | 18 | logger = logging.getLogger("FirstBatchLogger") 19 | 20 | 21 | class Chroma(VectorStore): 22 | 23 | def __init__( 24 | self, 25 | collection_name: str = DEFAULT_COLLECTION, 26 | persist_directory: Optional[str] = None, 27 | client_settings: Optional[chromadb.config.Settings] = None, 28 | client: Optional[Any] = None, 29 | distance_metric: Optional[DistanceMetric] = None, 30 | history_field: Optional[str] = None, 31 | embedding_size: Optional[int] = None 32 | ) -> None: 33 | try: 34 | import chromadb 35 | import chromadb.config 36 | 37 | except ImportError: 38 | raise ImportError( 39 | "Could not import chromadb python package. " 40 | "Please install it with `pip install chromadb`." 41 | ) 42 | 43 | if client is not None: 44 | self._client_settings = client_settings 45 | self._client = client 46 | self._persist_directory = persist_directory 47 | else: 48 | raise ValueError("client should be an instance of chromadb.Client, got {type(client)}") 49 | 50 | self._collection: chromadb.Collection = self._client.get_collection(collection_name) 51 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 52 | self._history_field = DEFAULT_HISTORY_FIELD if history_field is None else history_field 53 | self._distance_metric: DistanceMetric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 54 | logger.debug("Chrome vector store initialized with collection: {}".format(collection_name)) 55 | 56 | @property 57 | def quantizer(self): 58 | return self._quantizer 59 | 60 | @quantizer.setter 61 | def quantizer(self, value): 62 | self._quantizer = value 63 | 64 | @property 65 | def embedding_size(self): 66 | return self._embedding_size 67 | 68 | @embedding_size.setter 69 | def embedding_size(self, value): 70 | self._embedding_size = value 71 | 72 | @property 73 | def history_field(self): 74 | return self._history_field 75 | 76 | def train_quantizer(self, vectors: List[Vector]): 77 | if isinstance(self._quantizer, BaseLossy): 78 | self._quantizer.train(vectors) 79 | else: 80 | raise ValueError("Quantizer is not initialized or of the wrong type") 81 | 82 | def quantize_vector(self, vector: Vector) -> CompressedVector: 83 | return self._quantizer.compress(vector) 84 | 85 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 86 | return self._quantizer.decompress(vector) 87 | 88 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 89 | from chromadb.api.types import Where 90 | 91 | if query.include_values: 92 | include = 
["metadatas", "documents", "distances", "embeddings"] 93 | else: 94 | include = ["metadatas", "documents", "distances"] 95 | 96 | if query.embedding is None: 97 | raise ValueError("Query must have an embedding.") 98 | 99 | result = self._collection.query( 100 | query_embeddings=query.embedding.vector, 101 | n_results=query.top_k, 102 | where=cast(Optional[Where], query.filter.filter), 103 | include=include, # type: ignore 104 | **kwargs, 105 | ) 106 | 107 | metadatas = result.get("metadatas") 108 | ids = result.get("ids") 109 | distances = result.get("distances") 110 | if None in [metadatas, ids, distances]: 111 | raise ValueError("Query result does not contain metadatas, ids or distances.") 112 | 113 | if query.include_values: 114 | if result.get("embeddings") is None: 115 | raise ValueError("Query result does not contain embeddings.") 116 | vectors = [Vector(vec, len(query.embedding.vector), "") for vec in result["embeddings"][0]] # type: ignore 117 | else: 118 | vectors = [] 119 | 120 | metadata = [QueryMetadata(id=result["ids"][0][i], data=doc) 121 | for i, doc in enumerate(result["metadatas"][0])] # type: ignore 122 | 123 | return QueryResult(ids=result["ids"][0], scores=result["distances"][0], vectors=vectors, metadata=metadata, # type: ignore 124 | distance_metric=self._distance_metric) 125 | 126 | async def asearch( 127 | self, query: Query, **kwargs: Any 128 | ) -> QueryResult: 129 | func = partial( 130 | self.search, 131 | query=query, 132 | **kwargs, 133 | ) 134 | return await asyncio.get_event_loop().run_in_executor(None, func) 135 | 136 | def fetch(self, query: FetchQuery, **kwargs: Any) -> FetchResult: 137 | if query.id is None: 138 | raise ValueError("id must be provided for fetch query") 139 | 140 | result = self._collection.get(query.id, include=["metadatas", "documents", "embeddings"]) 141 | 142 | metadatas = result.get("metadatas") 143 | distances = result.get("embeddings") 144 | if None in [metadatas, distances]: 145 | raise ValueError("Query result does not contain metadatas, ids or distances.") 146 | 147 | m = QueryMetadata(id=query.id, data=result["metadatas"][0]) # type: ignore 148 | v = Vector(vector=result["embeddings"][0], id=query.id, dim=len(result["embeddings"][0])) # type: ignore 149 | return FetchResult(id=query.id, vector=v, metadata=m) 150 | 151 | async def afetch( 152 | self, query: FetchQuery, **kwargs: Any 153 | ) -> FetchResult: 154 | func = partial( 155 | self.fetch, 156 | query=query, 157 | **kwargs, 158 | ) 159 | return await asyncio.get_event_loop().run_in_executor(None, func) 160 | 161 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 162 | async def _async_multi_search(): 163 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 164 | results = await asyncio.gather(*coroutines) 165 | return BatchQueryResult(results, batch_query.batch_size) 166 | 167 | return asyncio.run(_async_multi_search()) 168 | 169 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 170 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 171 | results = await asyncio.gather(*coroutines) 172 | return BatchQueryResult(results, batch_query.batch_size) 173 | 174 | def multi_fetch( 175 | self, batch_query: BatchFetchQuery, **kwargs: Any 176 | ) -> BatchFetchResult: 177 | 178 | ids = [q.id for q in batch_query.fetches] 179 | result = self._collection.get(ids, include=["metadatas", "documents", "embeddings"]) 180 | fetches = [FetchResult(id=idx, 
vector=Vector(vector=result["embeddings"][i], dim=len(result["embeddings"][i])), # type: ignore 181 | metadata=QueryMetadata(id=idx, data=result["metadatas"][i])) for i, idx in enumerate(result["ids"])] # type: ignore 182 | return BatchFetchResult(batch_size=batch_query.batch_size, results=fetches) 183 | 184 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 185 | 186 | if isinstance(prev_filter, str): 187 | raise ValueError("prev_filter must be a dict for Chroma") 188 | 189 | if prev_filter is not None: 190 | if "$and" not in prev_filter: 191 | prev_filter["$and"] = [] 192 | for id in ids: 193 | prev_filter["$and"].append({self._history_field: {"$ne": id}}) 194 | 195 | return MetadataFilter(name="", filter=prev_filter) 196 | else: 197 | filter_: Dict = { 198 | "$and": [ 199 | ] 200 | } 201 | for id in ids: 202 | filter_["$and"].append({self._history_field: {"$ne": id}}) 203 | 204 | return MetadataFilter(name="", filter=filter_) -------------------------------------------------------------------------------- /firstbatch/vector_store/pinecone.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Any, Optional, List, Dict, Union 3 | import concurrent.futures 4 | from functools import partial 5 | import asyncio 6 | import logging 7 | from firstbatch.vector_store.base import VectorStore 8 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult, \ 9 | QueryResult, SearchType, QueryMetadata, BatchFetchQuery, BatchFetchResult, FetchResult, Vector, MetadataFilter, \ 10 | DistanceMetric 11 | from firstbatch.lossy.base import BaseLossy, CompressedVector 12 | from firstbatch.constants import DEFAULT_EMBEDDING_SIZE, DEFAULT_HISTORY_FIELD 13 | 14 | if TYPE_CHECKING: 15 | from pinecone import Index 16 | 17 | logger = logging.getLogger("FirstBatchLogger") 18 | 19 | 20 | class Pinecone(VectorStore): 21 | """`Pinecone` vector store.""" 22 | 23 | def __init__( 24 | self, 25 | index: Index, 26 | namespace: Optional[str] = None, 27 | distance_metric: Optional[DistanceMetric] = None, 28 | history_field: Optional[str] = None, 29 | embedding_size: Optional[int] = None 30 | ): 31 | """Initialize with Pinecone client.""" 32 | try: 33 | import pinecone 34 | except ImportError: 35 | raise ImportError( 36 | "Could not import pinecone python package. " 37 | "Please install it with `pip install pinecone-client`." 
38 | ) 39 | if not isinstance(index, pinecone.Index): 40 | raise ValueError( 41 | f"client should be an instance of pinecone.index.Index, " 42 | f"got {type(index)}" 43 | ) 44 | self._index = index 45 | self._namespace = namespace 46 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 47 | self._history_field = DEFAULT_HISTORY_FIELD if history_field is None else history_field 48 | self._distance_metric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 49 | logger.debug("Pinecone vector store initialized with namespace: {}".format(namespace)) 50 | 51 | @property 52 | def quantizer(self): 53 | return self._quantizer 54 | 55 | @quantizer.setter 56 | def quantizer(self, value): 57 | self._quantizer = value 58 | 59 | @property 60 | def embedding_size(self): 61 | return self._embedding_size 62 | 63 | @property 64 | def history_field(self): 65 | return self._history_field 66 | 67 | @embedding_size.setter 68 | def embedding_size(self, value): 69 | self._embedding_size = value 70 | 71 | def train_quantizer(self, vectors: List[Vector]): 72 | if isinstance(self._quantizer, BaseLossy): 73 | self._quantizer.train(vectors) 74 | else: 75 | raise ValueError("Quantizer is not initialized or of the wrong type") 76 | 77 | def quantize_vector(self, vector: Vector) -> CompressedVector: 78 | return self._quantizer.compress(vector) 79 | 80 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 81 | return self._quantizer.decompress(vector) 82 | 83 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 84 | """Return docs most similar to query using specified search type.""" 85 | if query.search_type == SearchType.FETCH: 86 | raise ValueError("search_type must be 'default' or 'sparse' to use search method") 87 | elif query.search_type == SearchType.SPARSE: 88 | raise NotImplementedError("Sparse search is not implemented yet.") 89 | else: 90 | result = self._index.query(query.embedding.vector, 91 | top_k=query.top_k, 92 | filter=query.filter.filter, 93 | include_metadata=query.include_metadata, 94 | include_values=query.include_values, 95 | ) 96 | ids, scores, vectors, metadata = [], [], [], [] 97 | for r in result["matches"]: 98 | ids.append(r["id"]) 99 | scores.append(r["score"]) 100 | vectors.append(Vector(vector=r["values"], dim=len(r["values"]), id=r["id"])) 101 | metadata.append(QueryMetadata(id=r["id"], data=r["metadata"])) 102 | 103 | return QueryResult(ids=ids, scores=scores, vectors=vectors, metadata=metadata, 104 | distance_metric=self._distance_metric) 105 | 106 | async def asearch( 107 | self, query: Query, **kwargs: Any 108 | ) -> QueryResult: 109 | func = partial( 110 | self.search, 111 | query=query, 112 | **kwargs, 113 | ) 114 | return await asyncio.get_event_loop().run_in_executor(None, func) 115 | 116 | def fetch( 117 | self, query: FetchQuery, **kwargs: Any 118 | ) -> FetchResult: 119 | """Return docs most similar to query using specified search type.""" 120 | assert query.id is not None, "id must be provided for fetch query" 121 | result = self._index.fetch([query.id]) 122 | fetches = [] 123 | for k, v in result["vectors"].items(): 124 | fetches.append(FetchResult(id=k, vector=Vector(vector=v["values"], 125 | dim=len(v["values"]), id=k), 126 | metadata=QueryMetadata(id=k, data=v["metadata"]))) 127 | 128 | return fetches[0] 129 | 130 | async def afetch( 131 | self, query: FetchQuery, **kwargs: Any 132 | ) -> FetchResult: 133 | func = partial( 134 | self.fetch, 135 | query=query, 136 | **kwargs, 137 | ) 138 
| return await asyncio.get_event_loop().run_in_executor(None, func) 139 | 140 | def multi_search_c( 141 | self, batch_query: BatchQuery, **kwargs: Any 142 | ) -> BatchQueryResult: 143 | with concurrent.futures.ThreadPoolExecutor() as executor: 144 | results = list(executor.map(self.search, batch_query.queries)) 145 | return BatchQueryResult(results, batch_query.batch_size) 146 | 147 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 148 | async def _async_multi_search(): 149 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 150 | results = await asyncio.gather(*coroutines) 151 | return BatchQueryResult(results, batch_query.batch_size) 152 | 153 | return asyncio.run(_async_multi_search()) 154 | 155 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 156 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 157 | results = await asyncio.gather(*coroutines) 158 | return BatchQueryResult(results, batch_query.batch_size) 159 | 160 | def multi_fetch( 161 | self, batch_query: BatchFetchQuery, **kwargs: Any 162 | ) -> BatchFetchResult: 163 | 164 | ids = [q.id for q in batch_query.fetches] 165 | result = self._index.fetch(ids) 166 | fetches = [FetchResult(id=k, vector=Vector(vector=v["values"], dim=len(v["values"])), 167 | metadata=QueryMetadata(id="", data=v["metadata"]) 168 | ) 169 | for k, v in result["vectors"].items()] 170 | return BatchFetchResult(batch_size=batch_query.batch_size, results=fetches) 171 | 172 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 173 | 174 | filter_ = { 175 | self._history_field: {"$nin": ids} 176 | } 177 | if prev_filter is not None: 178 | if isinstance(prev_filter, str): 179 | raise ValueError("prev_filter must be a dict for Pinecone") 180 | 181 | merged = prev_filter.copy() 182 | 183 | if self._history_field in prev_filter: 184 | merged[self._history_field]["$nin"] = list(set(prev_filter[self._history_field]["$nin"] + filter_[self._history_field]["$nin"])) 185 | else: 186 | merged[self._history_field] = filter_[self._history_field] 187 | 188 | for key, value in filter_.items(): 189 | if key != self._history_field and key not in merged: 190 | merged[key] = value 191 | 192 | return MetadataFilter(name="history", filter=merged) 193 | 194 | else: 195 | return MetadataFilter(name="History", filter=filter_) 196 | -------------------------------------------------------------------------------- /firstbatch/vector_store/qdrant.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Any, Optional, List, Dict, Union 3 | import concurrent.futures 4 | from functools import partial 5 | import asyncio 6 | import logging 7 | from firstbatch.vector_store.base import VectorStore 8 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult, \ 9 | QueryResult, QueryMetadata, BatchFetchQuery, BatchFetchResult, FetchResult, Vector, MetadataFilter, \ 10 | DistanceMetric 11 | from firstbatch.lossy.base import BaseLossy, CompressedVector 12 | from firstbatch.constants import DEFAULT_EMBEDDING_SIZE, DEFAULT_COLLECTION, DEFAULT_HISTORY_FIELD 13 | 14 | if TYPE_CHECKING: 15 | from qdrant_client import QdrantClient 16 | from qdrant_client.http.models import FieldCondition, MatchExcept, Filter 17 | 18 | logger = logging.getLogger("FirstBatchLogger") 19 | 20 | 21 | class 
Qdrant(VectorStore): 22 | """`Qdrant` vector store.""" 23 | 24 | def __init__( 25 | self, 26 | client: QdrantClient, 27 | collection_name: Optional[str] = None, 28 | distance_metric: Optional[DistanceMetric] = None, 29 | history_field: Optional[str] = None, 30 | embedding_size: Optional[int] = None 31 | ): 32 | """Initialize with Qdrant client.""" 33 | try: 34 | from qdrant_client import QdrantClient 35 | except ImportError: 36 | raise ImportError( 37 | "Could not import qdrant_client python package. " 38 | "Please install it with `pip install qdrant-client`." 39 | ) 40 | if not isinstance(client, QdrantClient): 41 | raise ValueError( 42 | f"client should be an instance of QdrantClient, " 43 | f"got {type(client)}" 44 | ) 45 | self._client = client 46 | self._collection_name = collection_name if collection_name is not None else DEFAULT_COLLECTION 47 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 48 | self._history_field = DEFAULT_HISTORY_FIELD if history_field is None else history_field 49 | self._distance_metric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 50 | logger.debug("Qdrant vector store initialized for collection {}".format(collection_name)) 51 | 52 | @property 53 | def quantizer(self): 54 | return self._quantizer 55 | 56 | @quantizer.setter 57 | def quantizer(self, value): 58 | self._quantizer = value 59 | 60 | @property 61 | def embedding_size(self): 62 | return self._embedding_size 63 | 64 | @embedding_size.setter 65 | def embedding_size(self, value): 66 | self._embedding_size = value 67 | 68 | @property 69 | def history_field(self): 70 | return self._history_field 71 | 72 | def train_quantizer(self, vectors: List[Vector]): 73 | if isinstance(self._quantizer, BaseLossy): 74 | self._quantizer.train(vectors) 75 | else: 76 | raise ValueError("Quantizer is not initialized or of the wrong type") 77 | 78 | def quantize_vector(self, vector: Vector) -> CompressedVector: 79 | return self._quantizer.compress(vector) 80 | 81 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 82 | return self._quantizer.decompress(vector) 83 | 84 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 85 | """Return docs most similar to query using specified search type.""" 86 | if query.filter.filter is None: 87 | query.filter.filter = {} 88 | 89 | result = self._client.search( 90 | collection_name=self._collection_name, 91 | query_vector=query.embedding.vector, 92 | limit=query.top_k, 93 | append_payload=query.include_metadata, 94 | with_vectors=query.include_values, 95 | query_filter=query.filter.filter 96 | ) 97 | 98 | ids, scores, vectors, metadata = [], [], [], [] 99 | for r in result: 100 | ids.append(str(r.id)) 101 | scores.append(r.score) 102 | if r.vector is not None: 103 | vectors.append(Vector(vector=r.vector, dim=len(r.vector), id=str(r.id))) 104 | if r.payload is not None: 105 | metadata.append(QueryMetadata(id=str(r.id), data=r.payload)) 106 | 107 | return QueryResult(ids=ids, scores=scores, vectors=vectors, metadata=metadata, 108 | distance_metric=self._distance_metric) 109 | 110 | async def asearch( 111 | self, query: Query, **kwargs: Any 112 | ) -> QueryResult: 113 | func = partial( 114 | self.search, 115 | query=query, 116 | **kwargs, 117 | ) 118 | return await asyncio.get_event_loop().run_in_executor(None, func) 119 | 120 | def fetch( 121 | self, query: FetchQuery, **kwargs: Any 122 | ) -> FetchResult: 123 | """Return docs most similar to query using specified search type.""" 124 | 
assert query.id is not None, "id must be provided for fetch query" 125 | id_: Union[str, int] 126 | try: 127 | id_ = int(query.id) 128 | except: 129 | id_ = query.id 130 | 131 | result = self._client.retrieve(collection_name=self._collection_name, ids=[id_], with_vectors=True) 132 | fetches = [] 133 | for r in result: 134 | fetches.append(FetchResult(id=str(r.id), vector=Vector(vector=r.vector, 135 | dim=len(r.vector), id=str(r.id)), 136 | metadata=QueryMetadata(id=str(r.id), data=r.payload))) 137 | 138 | return fetches[0] 139 | 140 | async def afetch( 141 | self, query: FetchQuery, **kwargs: Any 142 | ) -> FetchResult: 143 | func = partial( 144 | self.fetch, 145 | query=query, 146 | **kwargs, 147 | ) 148 | return await asyncio.get_event_loop().run_in_executor(None, func) 149 | 150 | def multi_search_c( 151 | self, batch_query: BatchQuery, **kwargs: Any 152 | ) -> BatchQueryResult: 153 | with concurrent.futures.ThreadPoolExecutor() as executor: 154 | results = list(executor.map(self.search, batch_query.queries)) 155 | return BatchQueryResult(results, batch_query.batch_size) 156 | 157 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 158 | async def _async_multi_search(): 159 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 160 | results = await asyncio.gather(*coroutines) 161 | return BatchQueryResult(results, batch_query.batch_size) 162 | 163 | return asyncio.run(_async_multi_search()) 164 | 165 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 166 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 167 | results = await asyncio.gather(*coroutines) 168 | return BatchQueryResult(results, batch_query.batch_size) 169 | 170 | def multi_fetch( 171 | self, batch_query: BatchFetchQuery, **kwargs: Any 172 | ) -> BatchFetchResult: 173 | 174 | ids:List[Union[str, int]] = [] 175 | for q in batch_query.fetches: 176 | try: 177 | ids.append(int(q.id)) 178 | except: 179 | ids.append(q.id) 180 | 181 | result = self._client.retrieve(collection_name=self._collection_name, ids=ids, with_vectors=True) 182 | fetches = [] 183 | for r in result: 184 | fetches.append(FetchResult(id=str(r.id), vector=Vector(vector=r.vector, 185 | dim=len(r.vector), id=str(r.id)), 186 | metadata=QueryMetadata(id=str(r.id), data=r.payload))) 187 | return BatchFetchResult(batch_size=batch_query.batch_size, results=fetches) 188 | 189 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 190 | 191 | from qdrant_client.http.models import FieldCondition, MatchExcept, Filter 192 | import json 193 | 194 | if prev_filter is not None: 195 | raise ValueError("Qdrant implementation currently does not support history filter with previous filter") 196 | 197 | filter_ = Filter( 198 | must= 199 | [ 200 | FieldCondition( 201 | key=self._history_field, 202 | match=MatchExcept(**{"except": ids})) 203 | ] 204 | ).model_dump_json() 205 | 206 | filter_ = json.loads(filter_) 207 | if "except_" in filter_["must"][0]["match"]: 208 | filter_["must"][0]["match"]["except"] = filter_["must"][0]["match"]["except_"] 209 | del filter_["must"][0]["match"]["except_"] 210 | 211 | return MetadataFilter(name="History", filter=filter_) 212 | -------------------------------------------------------------------------------- /firstbatch/vector_store/schema.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from 
firstbatch.imports import Field, BaseModel 3 | from dataclasses_json import DataClassJsonMixin 4 | from dataclasses import dataclass 5 | from typing import List, Optional, Union, Any, Dict 6 | from enum import Enum 7 | import numpy as np 8 | 9 | 10 | @dataclass 11 | class Vector(DataClassJsonMixin): 12 | vector: List[float] 13 | dim: int 14 | id: str = "" 15 | 16 | def concat(self, other: Vector) -> Vector: 17 | """Concatenates the current vector with another vector.""" 18 | if not isinstance(other, Vector): 19 | raise ValueError("The 'other' parameter must be an instance of Vector.") 20 | new_vector = self.vector + other.vector 21 | new_dim = self.dim + other.dim 22 | new_id = self.id + "_" + other.id # You can choose a different scheme for the new id if you want 23 | return Vector(vector=new_vector, dim=new_dim, id=new_id) 24 | 25 | 26 | @dataclass 27 | class Container: 28 | volume: Dict[str, Any] 29 | 30 | 31 | class DistanceMetric(Enum): 32 | COSINE_SIM = "cosine_sim" 33 | EUCLIDEAN_DIST = "euclidean_dist" 34 | DOT_PRODUCT = "dot_product" 35 | 36 | 37 | class SearchType(str, Enum): 38 | """Search types.""" 39 | DEFAULT = "default" 40 | SPARSE = "sparse" 41 | FETCH = "fetch" 42 | 43 | 44 | class MetadataFilter(BaseModel): 45 | """Interface for interacting with a document.""" 46 | 47 | name: str 48 | filter: Union[dict, str] = {} 49 | 50 | 51 | class QueryMetadata(BaseModel): 52 | """Interface for interacting with a document.""" 53 | id: str 54 | data: dict = Field(default_factory=dict) 55 | 56 | 57 | @dataclass 58 | class Query: 59 | """Vector store query.""" 60 | 61 | embedding: Vector 62 | top_k: int = 1 63 | top_k_mmr: int = top_k 64 | return_fields: Optional[List[str]] = None 65 | 66 | search_type: SearchType = SearchType.DEFAULT 67 | 68 | # metadata filters 69 | filter: MetadataFilter = MetadataFilter(name="") 70 | 71 | # include options 72 | include_metadata: bool = True 73 | include_values: bool = True 74 | 75 | # NOTE: currently only used by postgres hybrid search 76 | sparse_top_k: Optional[int] = None 77 | 78 | 79 | @dataclass 80 | class FetchQuery: 81 | """Vector store query.""" 82 | 83 | id: str 84 | search_type: SearchType = SearchType.FETCH 85 | return_fields: Optional[List[str]] = None 86 | 87 | 88 | @dataclass 89 | class BatchFetchQuery: 90 | """Vector store query.""" 91 | 92 | fetches: List[FetchQuery] 93 | batch_size: int = 1 94 | 95 | 96 | @dataclass 97 | class BatchQuery: 98 | """Vector store batch query.""" 99 | 100 | queries: List[Query] 101 | batch_size: int = 1 102 | 103 | def concat(self, other: BatchQuery) -> BatchQuery: 104 | """Concatenate two BatchQuery objects.""" 105 | if not isinstance(other, BatchQuery): 106 | raise ValueError("The 'other' parameter must be an instance of BatchQuery.") 107 | 108 | if not self.batch_size == other.batch_size: 109 | raise ValueError("The batch sizes must be equal.") 110 | 111 | # Extend the queries list 112 | new_queries = self.queries + other.queries 113 | 114 | return BatchQuery(batch_size=self.batch_size, queries=new_queries) 115 | 116 | 117 | @dataclass 118 | class FetchResult: 119 | """Vector store query result.""" 120 | 121 | metadata: QueryMetadata 122 | vector: Vector 123 | id: str 124 | 125 | 126 | @dataclass 127 | class QueryResult: 128 | """Vector store query result.""" 129 | 130 | ids: List[str] 131 | vectors: Optional[List[Vector]] = None 132 | metadata: Optional[List[QueryMetadata]] = None 133 | scores: Optional[List[float]] = None 134 | distance_metric: DistanceMetric = DistanceMetric.COSINE_SIM 135 | 136 | 
def to_ndarray(self) -> np.ndarray: 137 | 138 | if not self.scores: 139 | return np.array([]) 140 | 141 | if self.vectors is None: 142 | raise ValueError("Vectors must be provided to convert to ndarray.") 143 | 144 | matrix = [vec.vector for vec in self.vectors] 145 | return np.array(matrix) 146 | 147 | def apply_threshold(self, threshold: float) -> QueryResult: 148 | if not self.scores: 149 | return self 150 | 151 | avg: float = float(np.mean(np.array(self.scores).ravel()).item()) 152 | # TODO: Maybe remove this safety measure? 153 | if self.distance_metric == DistanceMetric.EUCLIDEAN_DIST: 154 | threshold = avg if threshold < avg else threshold 155 | indices_to_keep = [index for index, score in enumerate(self.scores) if score <= threshold] 156 | else: 157 | threshold = avg if threshold > avg else threshold 158 | indices_to_keep = [index for index, score in enumerate(self.scores) if score >= threshold] 159 | 160 | new_vectors = [self.vectors[i] for i in indices_to_keep] if self.vectors else None 161 | new_metadata = [self.metadata[i] for i in indices_to_keep] if self.metadata else None 162 | new_scores = [self.scores[i] for i in indices_to_keep] if self.scores else None 163 | new_ids = [self.ids[i] for i in indices_to_keep] if self.ids else None 164 | 165 | return QueryResult( 166 | vectors=new_vectors, 167 | metadata=new_metadata, 168 | scores=new_scores, 169 | ids=new_ids if new_ids else [] 170 | ) 171 | 172 | def remove_ids(self, ids: List[str]) -> QueryResult: 173 | if not self.ids: 174 | return self 175 | 176 | indices_to_keep = [index for index, id in enumerate(self.ids) if id not in ids] 177 | 178 | new_vectors = [self.vectors[i] for i in indices_to_keep] if self.vectors else None 179 | new_metadata = [self.metadata[i] for i in indices_to_keep] if self.metadata else None 180 | new_scores = [self.scores[i] for i in indices_to_keep] if self.scores else None 181 | new_ids = [self.ids[i] for i in indices_to_keep] if self.ids else None 182 | 183 | return QueryResult( 184 | vectors=new_vectors, 185 | metadata=new_metadata, 186 | scores=new_scores, 187 | ids= new_ids if new_ids else [] 188 | ) 189 | 190 | def concat(self, other: QueryResult) -> QueryResult: 191 | """Concatenate the current QueryResult with another QueryResult.""" 192 | if not isinstance(other, QueryResult): 193 | raise ValueError("The 'other' parameter must be an instance of QueryResult.") 194 | 195 | new_vectors = (self.vectors or []) + (other.vectors or []) 196 | new_metadata = (self.metadata or []) + (other.metadata or []) 197 | new_scores = (self.scores or []) + (other.scores or []) 198 | new_ids = (self.ids or []) + (other.ids or []) 199 | 200 | return QueryResult( 201 | vectors=new_vectors, 202 | metadata=new_metadata, 203 | scores=new_scores, 204 | ids=new_ids 205 | ) 206 | 207 | def non_unique_ids(self): 208 | from collections import Counter 209 | return [k for k, v in Counter(self.ids).items() if v > 1] 210 | 211 | 212 | @dataclass 213 | class BatchQueryResult: 214 | """Vector store query result.""" 215 | 216 | results: List[QueryResult] 217 | batch_size: int = 1 218 | 219 | def vectors(self) -> List[Vector]: 220 | return [v for r in self.results if r.vectors is not None for v in r.vectors] 221 | 222 | def remove_duplicates(self) -> None: 223 | """ Remove duplicate ids from the batch result """ 224 | if not self.results: 225 | return None 226 | 227 | flat = self.flatten() 228 | unique_ids = dict(zip(flat.ids, [0] * len(flat.ids))) 229 | # Initialize with the first result 230 | for result in self.results: 231 | 
selected_indices = [] 232 | for i, _id in enumerate(result.ids): 233 | if unique_ids[_id] == 0: 234 | unique_ids[_id] += 1 235 | selected_indices.append(i) 236 | result.ids = [result.ids[i] for i in selected_indices] 237 | result.metadata = [result.metadata[i] for i in selected_indices if result.metadata is not None] 238 | result.scores = [result.scores[i] for i in selected_indices if result.scores is not None] 239 | result.vectors = [result.vectors[i] for i in selected_indices if result.vectors is not None] 240 | 241 | def sort(self) -> None: 242 | if not self.results: 243 | return None 244 | 245 | for result in self.results: 246 | if result.scores is None: 247 | raise ValueError("Scores must be provided to sort.") 248 | sorted_indices = sorted(range(len(result.scores)), key=lambda k: result.scores[k], reverse=True) # type: ignore 249 | 250 | result.ids = [result.ids[i] for i in sorted_indices] 251 | result.metadata = [result.metadata[i] for i in sorted_indices if result.metadata is not None] 252 | result.scores = [result.scores[i] for i in sorted_indices] 253 | result.vectors = [result.vectors[i] for i in sorted_indices if result.vectors is not None] 254 | 255 | def flatten(self) -> QueryResult: 256 | """Flatten the batch results into a single QueryResult.""" 257 | if not self.results: 258 | return QueryResult(ids=[]) 259 | 260 | # Initialize with the first result 261 | flattened_result = self.results[0] 262 | 263 | # Iterate over the rest of the results and concatenate 264 | for result in self.results[1:]: 265 | flattened_result = flattened_result.concat(result) 266 | 267 | return flattened_result 268 | 269 | 270 | @dataclass 271 | class BatchFetchResult: 272 | """Vector store query result.""" 273 | 274 | results: List[FetchResult] 275 | batch_size: int = 1 276 | 277 | -------------------------------------------------------------------------------- /firstbatch/vector_store/supabase.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Any, Optional, List, Dict, Union 3 | from functools import partial 4 | import asyncio 5 | import logging 6 | from firstbatch.vector_store.schema import MetadataFilter 7 | from firstbatch.vector_store.base import VectorStore 8 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult,\ 9 | QueryResult, QueryMetadata, BatchFetchQuery, BatchFetchResult, FetchResult, Vector, DistanceMetric 10 | from firstbatch.lossy.base import BaseLossy, CompressedVector 11 | from firstbatch.constants import DEFAULT_COLLECTION, DEFAULT_EMBEDDING_SIZE, DEFAULT_HISTORY_FIELD 12 | 13 | if TYPE_CHECKING: 14 | from vecs import Client 15 | 16 | logger = logging.getLogger("FirstBatchLogger") 17 | 18 | 19 | class Supabase(VectorStore): 20 | 21 | def __init__( 22 | self, 23 | client: Client, 24 | collection_name: Optional[str] = None, 25 | query_name: Optional[str] = None, 26 | distance_metric: Optional[DistanceMetric] = None, 27 | history_field: Optional[str] = None, 28 | embedding_size: Optional[int] = None 29 | ) -> None: 30 | try: 31 | import vecs # noqa: F401 32 | except ImportError: 33 | raise ImportError( 34 | "Could not import supabase python package. " 35 | "Please install it with `pip install supabase`." 
36 | ) 37 | 38 | self._client = client 39 | self.collection_name = collection_name or DEFAULT_COLLECTION 40 | self.query_name = query_name or "match_documents" 41 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 42 | self._collection = client.get_or_create_collection(name=self.collection_name, dimension=self._embedding_size) 43 | self._history_field = DEFAULT_HISTORY_FIELD if history_field is None else history_field 44 | self._distance_metric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 45 | logger.debug("Supabase/PGVector initialized with collection: {}".format(collection_name)) 46 | 47 | @property 48 | def quantizer(self): 49 | return self._quantizer 50 | 51 | @quantizer.setter 52 | def quantizer(self, value): 53 | self._quantizer = value 54 | 55 | @property 56 | def embedding_size(self): 57 | return self._embedding_size 58 | 59 | @embedding_size.setter 60 | def embedding_size(self, value): 61 | self._embedding_size = value 62 | 63 | @property 64 | def history_field(self): 65 | return self._history_field 66 | 67 | def train_quantizer(self, vectors: List[Vector]): 68 | if isinstance(self._quantizer, BaseLossy): 69 | self._quantizer.train(vectors) 70 | else: 71 | raise ValueError("Quantizer is not initialized or of the wrong type") 72 | 73 | def quantize_vector(self, vector: Vector) -> CompressedVector: 74 | return self._quantizer.compress(vector) 75 | 76 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 77 | return self._quantizer.decompress(vector) 78 | 79 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 80 | 81 | if query.include_values: 82 | result = self._collection.query(query.embedding.vector, query.top_k, query.filter.filter, 83 | include_value=True, include_metadata=False) 84 | 85 | ids, scores, vectors, metadata = [], [], [], [] 86 | ids_score = {r[0]: r[1] for r in result} 87 | fetches = self._collection.fetch(list(ids_score.keys())) 88 | for r in fetches: 89 | ids.append(r[0]) 90 | scores.append(ids_score[r[0]]) 91 | if query.include_values and query.include_metadata: 92 | vectors.append(Vector(vector=r[1], dim=len(r[1]), id=r[0])) 93 | metadata.append(QueryMetadata(id=r[0], data=r[2])) 94 | else: 95 | result = self._collection.query(query.embedding.vector, query.top_k, query.filter.filter, 96 | include_value=True, include_metadata=query.include_metadata) 97 | 98 | ids, scores, vectors, metadata = [], [], [], [] 99 | for r in result: 100 | ids.append(r[0]) 101 | scores.append(r[1]) 102 | if query.include_metadata: 103 | metadata.append(QueryMetadata(id=r[0], data=r[2])) 104 | vectors.append(Vector([], 0, r[0])) 105 | 106 | return QueryResult(ids=ids, scores=scores, vectors=vectors, metadata=metadata, 107 | distance_metric=self._distance_metric) 108 | 109 | async def asearch( 110 | self, query: Query, **kwargs: Any 111 | ) -> QueryResult: 112 | func = partial( 113 | self.search, 114 | query=query, 115 | **kwargs, 116 | ) 117 | return await asyncio.get_event_loop().run_in_executor(None, func) 118 | 119 | def fetch(self, query: FetchQuery, **kwargs: Any) -> FetchResult: 120 | assert query.id is not None, "id must be provided for fetch query" 121 | result = self._collection.fetch([query.id]) 122 | m = QueryMetadata(id=query.id, data=result[0][2]) 123 | v = Vector(vector=result[0][1], id=query.id, dim=len(result[0][1])) 124 | return FetchResult(id=query.id, vector=v, metadata=m) 125 | 126 | async def afetch( 127 | self, query: FetchQuery, **kwargs: Any 128 | ) -> FetchResult: 
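# Thin async wrapper: delegates to the synchronous fetch() via the event loop's default executor.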
129 | func = partial( 130 | self.fetch, 131 | query=query, 132 | **kwargs, 133 | ) 134 | return await asyncio.get_event_loop().run_in_executor(None, func) 135 | 136 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 137 | async def _async_multi_search(): 138 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 139 | results = await asyncio.gather(*coroutines) 140 | return BatchQueryResult(results, batch_query.batch_size) 141 | 142 | return asyncio.run(_async_multi_search()) 143 | 144 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 145 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 146 | results = await asyncio.gather(*coroutines) 147 | return BatchQueryResult(results, batch_query.batch_size) 148 | 149 | def multi_fetch( 150 | self, batch_query: BatchFetchQuery, **kwargs: Any 151 | ) -> BatchFetchResult: 152 | 153 | ids = [q.id for q in batch_query.fetches] 154 | result = self._collection.fetch(ids) 155 | fetches = [FetchResult(id=idx[0], vector=Vector(vector=idx[1].tolist(), dim=len(idx[1].tolist())), 156 | metadata=QueryMetadata(id=idx[0], data=idx[2])) for i, idx in enumerate(result)] 157 | return BatchFetchResult(batch_size=batch_query.batch_size, results=fetches) 158 | 159 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 160 | 161 | if prev_filter is not None and not isinstance(prev_filter, str): 162 | if "$and" not in prev_filter: 163 | prev_filter["$and"] = [] 164 | for id in ids: 165 | prev_filter["$and"].append({self._history_field: {"$ne": id}}) 166 | 167 | return MetadataFilter(name="", filter=prev_filter) 168 | else: 169 | filter_: Dict = { 170 | "$and": [ 171 | ] 172 | } 173 | for id in ids: 174 | filter_["$and"].append({self._history_field: {"$ne": id}}) 175 | 176 | return MetadataFilter(name="History", filter=filter_) 177 | 178 | -------------------------------------------------------------------------------- /firstbatch/vector_store/typesense.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Optional, cast, List, Dict, Union 2 | from functools import partial 3 | import asyncio 4 | import logging 5 | from firstbatch.vector_store.schema import MetadataFilter 6 | from firstbatch.vector_store.base import VectorStore 7 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult, \ 8 | QueryResult, QueryMetadata, BatchFetchQuery, BatchFetchResult, FetchResult, Vector, DistanceMetric 9 | from firstbatch.constants import DEFAULT_COLLECTION, DEFAULT_EMBEDDING_SIZE 10 | from firstbatch.lossy.base import BaseLossy, CompressedVector 11 | 12 | 13 | if TYPE_CHECKING: 14 | from typesense.client import Client 15 | 16 | 17 | logger = logging.getLogger("FirstBatchLogger") 18 | 19 | 20 | class TypeSense(VectorStore): 21 | 22 | def __init__(self, 23 | client: "Client", 24 | collection_name: str = DEFAULT_COLLECTION, 25 | distance_metric: Optional[DistanceMetric] = None, 26 | history_field: Optional[str] = None, 27 | embedding_size: Optional[int] = None 28 | ) -> None: 29 | """Initialize params.""" 30 | import_err_msg = ( 31 | "`typesense` package not found, please run `pip install typesense`" 32 | ) 33 | try: 34 | import typesense # noqa: F401 35 | except ImportError: 36 | raise ImportError(import_err_msg) 37 | 38 | if client is not None: 39 | if not isinstance(client, typesense.Client): 40 | raise 
ValueError( 41 | f"client should be an instance of typesense.Client, " 42 | f"got {type(client)}" 43 | ) 44 | self._client = cast(typesense.Client, client) 45 | self._collection_name = collection_name 46 | self._collection = self._client.collections[self._collection_name] 47 | self._metadata_key = "metadata" 48 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 49 | self._history_field = "_id" if history_field is None else history_field 50 | self._distance_metric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 51 | logger.debug("TypeSense initialized with collection: {}".format(collection_name)) 52 | 53 | @property 54 | def quantizer(self): 55 | return self._quantizer 56 | 57 | @quantizer.setter 58 | def quantizer(self, value): 59 | self._quantizer = value 60 | 61 | @property 62 | def embedding_size(self): 63 | return self._embedding_size 64 | 65 | @embedding_size.setter 66 | def embedding_size(self, value): 67 | self._embedding_size = value 68 | 69 | @property 70 | def history_field(self): 71 | return self._history_field 72 | 73 | def train_quantizer(self, vectors: List[Vector]): 74 | if isinstance(self._quantizer, BaseLossy): 75 | self._quantizer.train(vectors) 76 | else: 77 | raise ValueError("Quantizer is not initialized or of the wrong type") 78 | 79 | def quantize_vector(self, vector: Vector) -> CompressedVector: 80 | return self._quantizer.compress(vector) 81 | 82 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 83 | return self._quantizer.decompress(vector) 84 | 85 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 86 | if query.filter.filter == {}: 87 | query_obj = { 88 | "q": "*", 89 | "vector_query": f'vec:({query.embedding.vector}, k:{query.top_k})', 90 | "collection": self._collection_name, 91 | } 92 | else: 93 | query_obj = { 94 | "q": "*", 95 | "vector_query": f'vec:({query.embedding.vector}, k:{query.top_k})', 96 | "collection": self._collection_name, 97 | "filter_by": str(query.filter.filter) 98 | } 99 | 100 | response = self._client.multi_search.perform( 101 | {"searches": [query_obj]}, {} 102 | ) 103 | q = QueryResult([], [], [], [], distance_metric=self._distance_metric) 104 | for hit in response["results"][0]["hits"]: 105 | document = hit["document"] 106 | metadata = {k: v for k, v in document.items() if k != 'vec'} 107 | 108 | q.metadata.append(QueryMetadata(id="", data=metadata)) # type: ignore 109 | q.vectors.append(Vector(vector=document["vec"], dim=len(document["vec"]), id=document["id"])) # type: ignore 110 | q.scores.append(hit["vector_distance"]) # type: ignore 111 | q.ids.append(document["id"]) 112 | return q 113 | 114 | async def asearch( 115 | self, query: Query, **kwargs: Any 116 | ) -> QueryResult: 117 | func = partial( 118 | self.search, 119 | query=query, 120 | **kwargs, 121 | ) 122 | return await asyncio.get_event_loop().run_in_executor(None, func) 123 | 124 | def fetch(self, query: FetchQuery, **kwargs: Any) -> FetchResult: 125 | res = self._client.collections[self._collection_name].documents[query.id].retrieve() 126 | metadata = {k: v for k, v in res.items() if k != 'vec'} 127 | vec = Vector(vector=res["vec"], id=res["id"], dim=len(res["vec"])) 128 | return FetchResult(vector=vec, metadata=QueryMetadata(id=res["id"], data=metadata), id=res["id"]) 129 | 130 | async def afetch( 131 | self, query: FetchQuery, **kwargs: Any 132 | ) -> FetchResult: 133 | func = partial( 134 | self.fetch, 135 | query=query, 136 | **kwargs, 137 | ) 138 | return await 
asyncio.get_event_loop().run_in_executor(None, func) 139 | 140 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 141 | async def _async_multi_search(): 142 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 143 | results = await asyncio.gather(*coroutines) 144 | return BatchQueryResult(results, batch_query.batch_size) 145 | 146 | return asyncio.run(_async_multi_search()) 147 | 148 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 149 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 150 | results = await asyncio.gather(*coroutines) 151 | return BatchQueryResult(results, batch_query.batch_size) 152 | 153 | def multi_fetch( 154 | self, batch_query: BatchFetchQuery, **kwargs: Any 155 | ) -> BatchFetchResult: 156 | async def _async_multi_fetch(): 157 | coroutines = [self.afetch(fetch, **kwargs) for fetch in batch_query.fetches] 158 | results = await asyncio.gather(*coroutines) 159 | return BatchFetchResult(results, batch_query.batch_size) 160 | 161 | return asyncio.run(_async_multi_fetch()) 162 | 163 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 164 | 165 | if self._history_field == "id": 166 | logger.debug("TypeSense doesn't allow filtering on id field. Try duplicating id in another field like _id.") 167 | raise ValueError("ID field error") 168 | 169 | filter_ = "{}:!=".format(self._history_field) + "[" + ",".join(ids) + "]" 170 | if prev_filter is not None and isinstance(prev_filter, str): 171 | filter_ += " && " + prev_filter 172 | return MetadataFilter(name="History", filter=filter_) 173 | -------------------------------------------------------------------------------- /firstbatch/vector_store/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for working with vectors and vectorstores.""" 2 | from math import ceil 3 | from typing import List, Union 4 | import numpy as np 5 | from firstbatch.vector_store.schema import Vector, Query, BatchQuery, QueryResult 6 | 7 | Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] 8 | 9 | 10 | def random_vec(dim): 11 | vec = np.random.random(dim) 12 | vec /= (np.linalg.norm(vec) + np.finfo(np.float64).eps) 13 | return vec 14 | 15 | 16 | def generate_vectors(dim, num_vectors): 17 | return [Vector(random_vec(dim).tolist(), dim, str(i)) for i in range(num_vectors)] 18 | 19 | 20 | def generate_query(num_vecs: int, dim: int, top_k: int, include_values: bool): 21 | for vec in generate_vectors(dim, num_vecs): 22 | yield Query(embedding=vec, top_k=top_k, top_k_mmr=int(top_k/2), include_values=include_values) 23 | 24 | 25 | def generate_batch(num_vecs: int, dim: int, top_k: int, include_values: bool): 26 | return BatchQuery(queries=list(generate_query(num_vecs, dim, top_k, include_values)), batch_size=num_vecs) 27 | 28 | 29 | def adjust_weights(weights: List[float], batch_size: float, c: float) -> List[int]: 30 | # Ensure the minimum weight is at least 1 31 | min_weight = min(weights) 32 | if min_weight < 1: 33 | diff = 1 - min_weight 34 | weights = [w + diff for w in weights] 35 | 36 | # Scale the weights so their sum is approximately batch_size 37 | current_sum = sum(weights) 38 | target_sum = batch_size 39 | if not (batch_size - c <= current_sum <= batch_size + c): 40 | scale_factor = target_sum / current_sum 41 | weights = [ceil(w * scale_factor) for w in weights] 42 | 43 | return [int(w) for w in weights] 
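# Worked example (illustrative, not from the test suite): adjust_weights first shifts the
# weights so the minimum is at least 1, then rescales and ceils them whenever their sum
# falls outside batch_size +/- c. adjust_weights([0.2, 0.5, 1.3], batch_size=10.0, c=1.0)
# shifts to [1.0, 1.3, 2.1] (sum 4.4), scales by 10 / 4.4, and ceils to [3, 3, 5];
# because of the ceil, the final sum (11) can slightly overshoot the target batch size.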
44 | 45 | 46 | def maximal_marginal_relevance( 47 | query_embedding: Vector, 48 | batch: QueryResult, 49 | lambda_mult: float = 0.5, 50 | k: int = 4, 51 | ) -> QueryResult: 52 | 53 | embeddings = batch.to_ndarray() 54 | query = np.array(query_embedding.vector) 55 | 56 | if min(k, len(embeddings)) <= 0: 57 | return batch 58 | 59 | embeddings_norm = np.linalg.norm(embeddings, axis=1) 60 | query_norm = np.linalg.norm(query) 61 | dists = (embeddings @ query) / (embeddings_norm * query_norm) 62 | minval = np.argsort(dists) 63 | indices = [minval[0]] 64 | selected :np.ndarray = np.array([embeddings[minval[0]]]) 65 | while len(indices) < min(k, len(embeddings)): 66 | best_score = -np.inf 67 | idx_to_add = -1 68 | similarity_to_selected = cosine_similarity(embeddings, selected) 69 | 70 | for i, query_score in enumerate(dists): 71 | if i in indices: 72 | continue 73 | redundant_score = max(similarity_to_selected[i]) 74 | equation_score = ( 75 | lambda_mult * query_score - (1 - lambda_mult) * redundant_score 76 | ) 77 | if equation_score > best_score: 78 | best_score = equation_score 79 | idx_to_add = i 80 | indices.append(idx_to_add) 81 | selected = np.append(selected, [embeddings[idx_to_add]], axis=0) 82 | return QueryResult( 83 | ids=[batch.ids[i] for i in indices] if batch.ids is not None else [], 84 | metadata=[batch.metadata[i] for i in indices] if batch.metadata is not None else [], 85 | scores=[batch.scores[i] for i in indices] if batch.scores is not None else [], 86 | vectors=[batch.vectors[i] for i in indices] if batch.vectors is not None else [] 87 | ) 88 | 89 | 90 | def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: 91 | """Row-wise cosine similarity between two equal-width matrices.""" 92 | if len(X) == 0 or len(Y) == 0: 93 | return np.array([]) 94 | X = np.array(X) 95 | Y = np.array(Y) 96 | if X.shape[1] != Y.shape[1]: 97 | raise ValueError( 98 | f"Number of columns in X and Y must be the same. X has shape {X.shape} " 99 | f"and Y has shape {Y.shape}." 100 | ) 101 | 102 | X_norm = np.linalg.norm(X, axis=1) 103 | Y_norm = np.linalg.norm(Y, axis=1) 104 | # Ignore divide by zero errors run time warnings as those are handled below. 
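# (Zero-norm rows produce NaN/Inf similarities here; those entries are zeroed out right after.)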
105 | with np.errstate(divide="ignore", invalid="ignore"): 106 | similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm) 107 | similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 108 | return similarity 109 | 110 | 111 | -------------------------------------------------------------------------------- /firstbatch/vector_store/weaviate.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Optional, List, Dict, Union 4 | from functools import partial 5 | import asyncio 6 | import logging 7 | from firstbatch.vector_store.base import VectorStore 8 | from firstbatch.vector_store.schema import FetchQuery, Query, BatchQuery, BatchQueryResult, \ 9 | QueryResult, SearchType, QueryMetadata, BatchFetchQuery, Vector, FetchResult, BatchFetchResult, MetadataFilter, \ 10 | DistanceMetric 11 | from firstbatch.lossy.base import BaseLossy, CompressedVector 12 | from firstbatch.constants import DEFAULT_EMBEDDING_SIZE, DEFAULT_COLLECTION, DEFAULT_HISTORY_FIELD 13 | 14 | 15 | logger = logging.getLogger("FirstBatchLogger") 16 | 17 | 18 | class Weaviate(VectorStore): 19 | """`Weaviate` vector store.""" 20 | 21 | def __init__( 22 | self, 23 | client: Any, 24 | index_name: Optional[str] = None, 25 | output_fields: Optional[List[str]] = None, 26 | distance_metric: Optional[DistanceMetric] = None, 27 | history_field: Optional[str] = None, 28 | embedding_size: Optional[int] = None 29 | ): 30 | """Initialize with Weaviate client.""" 31 | 32 | if output_fields is None: 33 | output_fields = ["text"] 34 | try: 35 | import weaviate 36 | except ImportError: 37 | raise ImportError( 38 | "Could not import weaviate python package. " 39 | "Please install it with `pip install weaviate-client`." 
40 | ) 41 | if not isinstance(client, weaviate.Client): 42 | raise ValueError( 43 | f"client should be an instance of weaviate.Client, got {type(client)}" 44 | ) 45 | self._client = client 46 | self._index_name = DEFAULT_COLLECTION if index_name is None else index_name 47 | self._output_fields = output_fields 48 | self._embedding_size = DEFAULT_EMBEDDING_SIZE if embedding_size is None else embedding_size 49 | self._history_field = DEFAULT_HISTORY_FIELD if history_field is None else history_field 50 | self._distance_metric = DistanceMetric.COSINE_SIM if distance_metric is None else distance_metric 51 | logger.debug("Weaviate initialized with index: {}".format(index_name)) 52 | 53 | @property 54 | def quantizer(self): 55 | return self._quantizer 56 | 57 | @quantizer.setter 58 | def quantizer(self, value): 59 | self._quantizer = value 60 | 61 | @property 62 | def embedding_size(self): 63 | return self._embedding_size 64 | 65 | @embedding_size.setter 66 | def embedding_size(self, value): 67 | self._embedding_size = value 68 | 69 | @property 70 | def history_field(self): 71 | return self._history_field 72 | 73 | def train_quantizer(self, vectors: List[Vector]): 74 | if isinstance(self._quantizer, BaseLossy): 75 | self._quantizer.train(vectors) 76 | else: 77 | raise ValueError("Quantizer is not initialized or of the wrong type") 78 | 79 | def quantize_vector(self, vector: Vector) -> CompressedVector: 80 | return self._quantizer.compress(vector) 81 | 82 | def dequantize_vector(self, vector: CompressedVector) -> Vector: 83 | return self._quantizer.decompress(vector) 84 | 85 | def search(self, query: Query, **kwargs: Any) -> QueryResult: 86 | """Return docs most similar to query using specified search type.""" 87 | if query.search_type == SearchType.FETCH: 88 | raise ValueError("search_type must be 'default' or 'sparse' to use search method") 89 | elif query.search_type == SearchType.SPARSE: 90 | raise NotImplementedError("Sparse search not implemented for Weaviate") 91 | else: 92 | vector = {"vector": query.embedding.vector} 93 | query_obj = self._client.query.get(self._index_name, self._output_fields) 94 | if query.filter.filter != {}: 95 | query_obj = query_obj.with_where(query.filter.filter) 96 | if kwargs.get("additional"): 97 | query_obj = query_obj.with_additional(kwargs.get("additional")) 98 | if query.include_values: 99 | query_obj = query_obj.with_additional(["vector", "distance", "id"]) 100 | else: 101 | query_obj = query_obj.with_additional(["distance", "id"]) 102 | 103 | result = query_obj.with_near_vector(vector).with_limit(query.top_k).do() 104 | if "errors" in result: 105 | raise ValueError(f"Error during query: {result['errors']}") 106 | ids, scores, vectors, metadata = [], [], [], [] 107 | for res in result["data"]["Get"][self._index_name.capitalize()]: 108 | _id = res["_additional"]["id"] 109 | ids.append(_id) 110 | m = {k: res.pop(k) for k in self._output_fields} 111 | metadata.append(QueryMetadata(id=_id, data=m)) 112 | scores.append(res["_additional"]["distance"]) 113 | if query.include_values: 114 | vectors.append(Vector(vector=res["_additional"]["vector"], 115 | id=_id, dim=len(res["_additional"]["vector"]))) 116 | else: 117 | vectors.append(Vector(vector=[], id=_id, dim=0)) 118 | 119 | return QueryResult(ids=ids, scores=scores, vectors=vectors, metadata=metadata, 120 | distance_metric=self._distance_metric) 121 | 122 | async def asearch( 123 | self, query: Query, **kwargs: Any 124 | ) -> QueryResult: 125 | func = partial( 126 | self.search, 127 | query=query, 128 | 
**kwargs, 129 | ) 130 | return await asyncio.get_event_loop().run_in_executor(None, func) 131 | 132 | def fetch( 133 | self, query: FetchQuery, **kwargs: Any 134 | ) -> FetchResult: 135 | """Return docs most similar to query using specified search type.""" 136 | # query_obj = self._client.query.get(self._index_name, self._search_properties) 137 | assert query.id is not None, "id must be provided for fetch query" 138 | data_object = self._client.data_object.get_by_id( 139 | query.id, 140 | class_name=self._index_name, 141 | with_vector=True 142 | ) 143 | m = QueryMetadata(id=query.id, data=data_object["properties"]) 144 | v = Vector(vector=data_object["vector"], id=query.id, dim=len(data_object["vector"])) 145 | return FetchResult(vector=v, metadata=m, id=query.id) 146 | 147 | async def afetch( 148 | self, query: FetchQuery, **kwargs: Any 149 | ) -> FetchResult: 150 | func = partial( 151 | self.fetch, 152 | query=query, 153 | **kwargs, 154 | ) 155 | return await asyncio.get_event_loop().run_in_executor(None, func) 156 | 157 | def multi_search(self, batch_query: BatchQuery, **kwargs: Any) -> BatchQueryResult: 158 | async def _async_multi_search(): 159 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 160 | results = await asyncio.gather(*coroutines) 161 | return BatchQueryResult(results, batch_query.batch_size) 162 | 163 | return asyncio.run(_async_multi_search()) 164 | 165 | async def a_multi_search(self, batch_query: BatchQuery, **kwargs: Any): 166 | coroutines = [self.asearch(query, **kwargs) for query in batch_query.queries] 167 | results = await asyncio.gather(*coroutines) 168 | return BatchQueryResult(results, batch_query.batch_size) 169 | 170 | def multi_fetch( 171 | self, batch_query: BatchFetchQuery, **kwargs: Any 172 | ) -> BatchFetchResult: 173 | async def _async_multi_fetch(): 174 | coroutines = [self.afetch(fetch, **kwargs) for fetch in batch_query.fetches] 175 | results = await asyncio.gather(*coroutines) 176 | return BatchFetchResult(results, batch_query.batch_size) 177 | 178 | return asyncio.run(_async_multi_fetch()) 179 | 180 | def history_filter(self, ids: List[str], prev_filter: Optional[Union[Dict, str]] = None) -> MetadataFilter: 181 | 182 | filter: Dict[str, Any] 183 | 184 | if prev_filter is not None: 185 | if isinstance(prev_filter, dict): 186 | filter = prev_filter 187 | else: 188 | raise TypeError("prev_filter must be a dictionary.") 189 | else: 190 | filter = { 191 | "operator": "And", 192 | "operands": [] 193 | } 194 | 195 | for id in ids: 196 | f = { 197 | "path": [self._history_field], 198 | "operator": "NotEqual", 199 | "valueText": id 200 | } 201 | filter["operands"].append(f) 202 | 203 | return MetadataFilter(name="History", filter=filter) 204 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "firstbatch" 3 | version = "0.1.73" 4 | description = "FirstBatch SDK for integrating user embeddings to your project. Add real-time personalization to your AI application without user data." 
5 | authors = ["andthattoo "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9,<3.13" 11 | numpy = "^1.26.0" 12 | tdigest = "^0.5.2.2" 13 | nanopq = {version="^0.2.0", optional=true} 14 | chromadb = {version="^0.4.13", optional=true} 15 | typesense = {version="^0.17.0", optional=true} 16 | weaviate-client = {version="^3.24.2", optional=true} 17 | vecs = {version="^0.4.1", optional=true} 18 | requests = "^2.31.0" 19 | pydantic = "1.10.13" 20 | httpx = "^0.25.0" 21 | dataclasses-json = "^0.6.1" 22 | pinecone-client = {version="^2.2.4", optional=true} 23 | 24 | [tool.poetry.group.test.dependencies] 25 | pytest = "^7.3.0" 26 | pytest-cov = "^4.0.0" 27 | pytest-dotenv = "^0.5.2" 28 | responses = "^0.22.0" 29 | pytest-asyncio = "^0.20.3" 30 | 31 | [tool.poetry.extras] 32 | pinecone = ["pinecone-client"] 33 | supabase = ["vecs"] 34 | weaviate = ["weaviate-client"] 35 | typesense = ["typesense"] 36 | chromadb = ["chromadb"] 37 | product = ["nanopq"] 38 | 39 | [build-system] 40 | requires = ["poetry-core"] 41 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FirstBatch SDK 2 | 3 | The FirstBatch SDK provides an interface for integrating vector databases and powering personalized AI experiences in your application. 4 | 5 | ## Key Features 6 | 7 | - Seamlessly manage user sessions with persistent IDs or temporary sessions 8 | - Send signal actions like likes, clicks, etc. to update user embeddings in real-time 9 | - Fetch personalized batches of data tailored to each user's embeddings 10 | - Support for multiple vector database integrations: Pinecone, Weaviate, etc. 11 | - Built-in algorithms for common personalization use cases 12 | - Easy configuration with Python classes and environment variables 13 | 14 | ## Getting Started 15 | 16 | ### Prerequisites 17 | 18 | - Python 3.9+ 19 | - API keys for FirstBatch and your chosen vector database 20 | 21 | ### Installation 22 | 23 | ``` 24 | pip install firstbatch 25 | ``` 26 | 27 | ## Basic Usage 28 | 29 | 1. **Initialize VectorDB of your choice** 30 | ```python 31 | api_key = os.environ["PINECONE_API_KEY"] 32 | env = os.environ["PINECONE_ENV"] 33 | 34 | pinecone.init(api_key=api_key, environment=env) 35 | index = pinecone.Index("your_index_name") 36 | 37 | # Init FirstBatch 38 | config = Config(batch_size=20) 39 | personalized = FirstBatch(api_key=os.environ["FIRSTBATCH_API_KEY"], config=config) 40 | 41 | personalized.add_vdb("my_db", Pinecone(index, embedding_size=1536)) 42 | ``` 43 | 44 | ### Personalization 45 | 46 | 2. **Create a session with an Algorithm suiting your needs** 47 | ```python 48 | session = personalized.session(algorithm=AlgorithmLabel.AI_AGENTS, vdbid="my_db") 49 | ``` 50 | 51 | 3. **Make recommendations** 52 | ```python 53 | ids, batch = personalized.batch(session) 54 | ``` 55 | 4. **Let users add signals to shape their embeddings** 56 | ```python 57 | user_pick = 0 # User liked the first content from the previous batch. 58 | personalized.add_signal(session, UserAction(Signal.LIKE), ids[user_pick]) 59 | ``` 60 | 61 | ## Support 62 | 63 | For any issues or queries contact `support@firstbatch.xyz`. 
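## Example: Qdrant

The snippet below is a minimal end-to-end sketch that mirrors the Pinecone flow above, but with the Qdrant integration. It assumes a running Qdrant instance and an already-populated collection; the URL, collection name, and embedding size are placeholders, and the `qdrant-client` package must be installed separately.

```python
import os
from qdrant_client import QdrantClient
from firstbatch import FirstBatch, Qdrant, Config, AlgorithmLabel, UserAction, Signal

# Wrap an existing Qdrant collection as a FirstBatch vector store
client = QdrantClient(os.environ["QDRANT_URL"])
vector_store = Qdrant(client=client, collection_name="my_collection", embedding_size=1536)

personalized = FirstBatch(api_key=os.environ["FIRSTBATCH_API_KEY"], config=Config(batch_size=20))
personalized.add_vdb("my_qdrant_db", vector_store)

# Create a session, fetch a batch, send a signal, then fetch an updated batch
session = personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid="my_qdrant_db")
ids, batch = personalized.batch(session)
personalized.add_signal(session, UserAction(Signal.LIKE), ids[0])
ids, batch = personalized.batch(session)
```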
64 | 65 | 66 | ## Resources 67 | 68 | - [User Embedding Guide](https://firstbatch.gitbook.io/user-embeddings/) 69 | - [SDK Documentation](https://firstbatch.gitbook.io/firstbatch-sdk/) 70 | 71 | Feel free to dive into the technicalities and leverage FirstBatch SDK for highly personalized user experiences. 72 | -------------------------------------------------------------------------------- /tests/algorithms/test_algorithms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from firstbatch import FirstBatch, Pinecone, Config, UserAction, Signal, AlgorithmLabel 3 | import pinecone 4 | import queue 5 | from firstbatch.vector_store.utils import generate_vectors 6 | import os 7 | 8 | 9 | @pytest.fixture 10 | def setup(): 11 | api_key = os.environ["PINECONE_API_KEY"] 12 | env = os.environ["PINECONE_ENV"] 13 | vdb_name = os.environ["VDB_NAME"] 14 | index_name = os.environ["INDEX_NAME"] 15 | embedding_size = int(os.environ["EMBEDDING_SIZE"]) 16 | 17 | pinecone.init(api_key=api_key, environment=env) 18 | pinecone.describe_index(index_name) 19 | index = pinecone.Index(index_name) 20 | 21 | cfg = Config(batch_size=20, quantizer_train_size=100, quantizer_type="scalar", 22 | enable_history=True, verbose=True) 23 | personalized = FirstBatch(api_key=os.environ["FIRSTBATCH_API_KEY"], config=cfg) 24 | personalized.add_vdb(vdb_name, Pinecone(index, embedding_size=embedding_size)) 25 | 26 | return personalized, vdb_name 27 | 28 | 29 | def test_simple(setup): 30 | actions = [("batch", 0), ("signal", 2), ("batch", 0)] 31 | action_queue = queue.Queue() 32 | for h in actions: 33 | action_queue.put(h) 34 | 35 | personalized, vdbid = setup 36 | session = personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid=vdbid) 37 | ids, batch = [], [] 38 | 39 | while not action_queue.empty(): 40 | a = action_queue.get() 41 | if a[0] == "batch": 42 | ids, batch = personalized.batch(session) 43 | elif a[0] == "signal": 44 | cid = a[1] 45 | personalized.add_signal(session, UserAction(Signal.LIKE), ids[cid if cid < len(ids) else len(ids)-1]) 46 | 47 | 48 | def test_w_bias_vectors(setup): 49 | 50 | actions = [("batch", 0), ("signal", 2), ("batch", 0), ("signal", 4), ("batch", 0), 51 | ("batch", 0), ("signal", 1), ("signal", 2), ("signal", 3), ("batch", 0)] 52 | action_queue = queue.Queue() 53 | for h in actions: 54 | action_queue.put(h) 55 | 56 | starting_vectors = [vec.vector for vec in generate_vectors(int(os.environ["EMBEDDING_SIZE"]), 5)] 57 | starting_weights = [1.0] * 5 58 | data = {"bias_vectors": starting_vectors, "bias_weights": starting_weights} 59 | 60 | personalized, vdbid = setup 61 | session = personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid=vdbid) 62 | ids, batch = [], [] 63 | 64 | while not action_queue.empty(): 65 | a = action_queue.get() 66 | if a[0] == "batch": 67 | ids, batch = personalized.batch(session, **data) 68 | elif a[0] == "signal": 69 | cid = a[1] 70 | personalized.add_signal(session, UserAction(Signal.LIKE), ids[cid if cid < len(ids) else len(ids)-1]) 71 | 72 | 73 | def test_factory(setup): 74 | 75 | actions = [("batch", 0), ("signal", 2), ("batch", 0), ("signal", 4), ("signal", 1), ("batch", 0), 76 | ("batch", 0), ("signal", 12), ("signal", 9)] 77 | action_queue = queue.Queue() 78 | for h in actions: 79 | action_queue.put(h) 80 | 81 | personalized, vdbid = setup 82 | session = personalized.session(algorithm=AlgorithmLabel.RECOMMENDATIONS, vdbid=vdbid) 83 | ids, batch = [], [] 84 | 85 | while not action_queue.empty(): 86 | a = 
action_queue.get() 87 | if a[0] == "batch": 88 | ids, batch = personalized.batch(session) 89 | elif a[0] == "signal": 90 | cid = a[1] 91 | personalized.add_signal(session, UserAction(Signal.ADD_TO_CART), ids[cid if cid < len(ids) else len(ids)-1]) 92 | 93 | 94 | def test_custom(setup): 95 | 96 | actions = [("batch", 0), ("signal", 2), ("batch", 0)] 97 | action_queue = queue.Queue() 98 | for h in actions: 99 | action_queue.put(h) 100 | 101 | personalized, vdbid = setup 102 | session = personalized.session(algorithm=AlgorithmLabel.CUSTOM, vdbid=vdbid, custom_id="f23a2cfe-5a38-4671-927d-0897c01a2d25") 103 | ids, batch = [], [] 104 | 105 | while not action_queue.empty(): 106 | a = action_queue.get() 107 | if a[0] == "batch": 108 | ids, batch = personalized.batch(session) 109 | elif a[0] == "signal": 110 | cid = a[1] 111 | personalized.add_signal(session, UserAction(Signal.ADD_TO_CART), ids[cid if cid < len(ids) else len(ids)-1]) 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /tests/algorithms/test_algorithms_async.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from firstbatch import AsyncFirstBatch, Pinecone, Config, UserAction, Signal, AlgorithmLabel 3 | import pinecone 4 | import queue 5 | import os 6 | 7 | 8 | @pytest.fixture 9 | def setup(): 10 | api_key = os.environ["PINECONE_API_KEY"] 11 | env = os.environ["PINECONE_ENV"] 12 | vdb_name = os.environ["VDB_NAME"] 13 | index_name = os.environ["INDEX_NAME"] 14 | embedding_size = int(os.environ["EMBEDDING_SIZE"]) 15 | 16 | pinecone.init(api_key=api_key, environment=env) 17 | pinecone.describe_index(index_name) 18 | index = pinecone.Index(index_name) 19 | 20 | config = Config(batch_size=20, quantizer_train_size=100, quantizer_type="scalar", 21 | enable_history=True, verbose=True) 22 | personalized = AsyncFirstBatch(api_key=os.environ["FIRSTBATCH_API_KEY"], config=config) 23 | return personalized, index, vdb_name, embedding_size 24 | 25 | 26 | @pytest.mark.asyncio 27 | async def test_async_simple(setup): 28 | actions = [("batch", 0), ("signal", 2), ("batch", 0)] 29 | action_queue = queue.Queue() 30 | for h in actions: 31 | action_queue.put(h) 32 | 33 | personalized, index, vdb, esize = setup 34 | await personalized.add_vdb(vdb, Pinecone(index, embedding_size=esize)) 35 | session = await personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid=vdb) 36 | ids, batch = [], [] 37 | 38 | while not action_queue.empty(): 39 | a = action_queue.get() 40 | if a[0] == "batch": 41 | ids, batch = await personalized.batch(session) 42 | elif a[0] == "signal": 43 | cid = a[1] 44 | await personalized.add_signal(session, UserAction(Signal.LIKE), ids[cid if cid < len(ids) else len(ids)-1]) 45 | 46 | -------------------------------------------------------------------------------- /tests/algorithms/test_algorithms_vs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from firstbatch import FirstBatch, Qdrant, Config, UserAction, Signal, AlgorithmLabel 3 | from qdrant_client import QdrantClient 4 | import queue 5 | from firstbatch.vector_store.utils import generate_vectors 6 | import os 7 | 8 | 9 | @pytest.fixture 10 | def setup(): 11 | embedding_size = int(os.environ["EMBEDDING_SIZE"]) 12 | vdb_name = os.environ["VDB_NAME"] 13 | client = QdrantClient(os.environ["QDRANT_URL"]) 14 | cl = Qdrant(client=client, collection_name="farcaster", embedding_size=embedding_size) 15 | 16 | cfg = 
Config(batch_size=20, quantizer_train_size=100, quantizer_type="scalar", 17 | enable_history=True, verbose=True) 18 | personalized = FirstBatch(api_key=os.environ["FIRSTBATCH_API_KEY"], config=cfg) 19 | personalized.add_vdb(vdb_name, cl) 20 | 21 | return personalized, vdb_name 22 | 23 | 24 | def test_simple(setup): 25 | actions = [("batch", 0), ("signal", 2), ("batch", 0)] 26 | action_queue = queue.Queue() 27 | for h in actions: 28 | action_queue.put(h) 29 | 30 | personalized, vdbid = setup 31 | session = personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid=vdbid) 32 | ids, batch = [], [] 33 | 34 | while not action_queue.empty(): 35 | a = action_queue.get() 36 | if a[0] == "batch": 37 | ids, batch = personalized.batch(session) 38 | elif a[0] == "signal": 39 | cid = a[1] 40 | personalized.add_signal(session, UserAction(Signal.LIKE), ids[cid if cid < len(ids) else len(ids)-1]) 41 | 42 | 43 | def test_w_bias_vectors(setup): 44 | 45 | actions = [("batch", 0), ("signal", 2), ("batch", 0), ("signal", 4), ("batch", 0), 46 | ("batch", 0), ("signal", 1), ("signal", 2), ("signal", 3), ("batch", 0)] 47 | action_queue = queue.Queue() 48 | for h in actions: 49 | action_queue.put(h) 50 | 51 | starting_vectors = [vec.vector for vec in generate_vectors(int(os.environ["EMBEDDING_SIZE"]), 5)] 52 | starting_weights = [1.0] * 5 53 | data = {"bias_vectors": starting_vectors, "bias_weights": starting_weights} 54 | 55 | personalized, vdbid = setup 56 | session = personalized.session(algorithm=AlgorithmLabel.SIMPLE, vdbid=vdbid) 57 | ids, batch = [], [] 58 | 59 | while not action_queue.empty(): 60 | a = action_queue.get() 61 | if a[0] == "batch": 62 | ids, batch = personalized.batch(session, **data) 63 | elif a[0] == "signal": 64 | cid = a[1] 65 | personalized.add_signal(session, UserAction(Signal.LIKE), ids[cid if cid < len(ids) else len(ids)-1]) 66 | 67 | def test_custom(setup): 68 | 69 | actions = [("batch", 0), ("signal", 2), ("batch", 0)] 70 | action_queue = queue.Queue() 71 | for h in actions: 72 | action_queue.put(h) 73 | 74 | personalized, vdbid = setup 75 | session = personalized.session(algorithm=AlgorithmLabel.CUSTOM, vdbid=vdbid, custom_id="f23a2cfe-5a38-4671-927d-0897c01a2d25") 76 | ids, batch = [], [] 77 | 78 | while not action_queue.empty(): 79 | a = action_queue.get() 80 | if a[0] == "batch": 81 | ids, batch = personalized.batch(session) 82 | elif a[0] == "signal": 83 | cid = a[1] 84 | personalized.add_signal(session, UserAction(Signal.ADD_TO_CART), ids[cid if cid < len(ids) else len(ids)-1]) 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /tests/compression/test_lossy_product.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import nanopq 4 | from firstbatch.lossy import ProductQuantizer, CompressedVector 5 | from firstbatch.vector_store.schema import Vector 6 | 7 | 8 | def random_vec(dim): 9 | vec = np.random.random(dim) 10 | vec /= (np.linalg.norm(vec) + np.finfo(np.float64).eps) 11 | return vec 12 | 13 | 14 | @pytest.fixture 15 | def setup_data(): 16 | data = [Vector(random_vec(1536).tolist(), 1536, str(i)) for i in range(10000)] 17 | pq = ProductQuantizer(512, 32) 18 | return data, pq 19 | 20 | 21 | def test_comp_decomp(setup_data): 22 | data, pq = setup_data 23 | pq.train(data) 24 | comp = pq.compress(data[0]) 25 | assert isinstance(comp, CompressedVector) 26 | 27 | decomp = pq.decompress(comp) 28 | assert isinstance(decomp, Vector) 29 | print("Error", 
np.sum(np.abs(np.array(decomp.vector) - np.array(data[0].vector)))) 30 | 31 | 32 | def test_reproduce(setup_data): 33 | data, pq = setup_data 34 | pq.train(data) 35 | 36 | new_pq = nanopq.PQ(32, 512) 37 | new_pq_res = nanopq.PQ(32, 512) 38 | 39 | new_pq.codewords = pq.quantizer.codewords 40 | new_pq_res.codewords = pq.quantizer_residual.codewords 41 | 42 | new_pq.Ds = pq.quantizer.Ds 43 | new_pq_res.Ds = pq.quantizer_residual.Ds 44 | 45 | comp = pq.quantizer.encode(np.array(data[0].vector, dtype=np.float32).reshape(1, -1)) 46 | comp2 = new_pq.encode(np.array(data[0].vector, dtype=np.float32).reshape(1, -1)) 47 | assert np.array_equal(comp, comp2) # codes must match element-wise 48 | 49 | comp = pq.quantizer_residual.encode(np.array(data[0].vector, dtype=np.float32).reshape(1, -1)) 50 | comp2 = new_pq_res.encode(np.array(data[0].vector, dtype=np.float32).reshape(1, -1)) 51 | assert np.array_equal(comp, comp2) 52 | -------------------------------------------------------------------------------- /tests/compression/test_lossy_scalar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from firstbatch.lossy import ScalarQuantizer, CompressedVector 3 | from firstbatch.vector_store.schema import Vector 4 | import numpy as np 5 | 6 | 7 | def random_vec(dim): 8 | vec = np.random.random(dim) 9 | vec /= (np.linalg.norm(vec) + np.finfo(np.float64).eps) 10 | return vec 11 | 12 | 13 | @pytest.fixture 14 | def setup_data(): 15 | data = [Vector(random_vec(1536).tolist(), 1536, str(i)) for i in range(1000)] 16 | pq = ScalarQuantizer(256) 17 | return data, pq 18 | 19 | 20 | def test_comp_decomp(setup_data): 21 | data, pq = setup_data 22 | pq.train(data) 23 | comp = pq.compress(data[0]) 24 | assert isinstance(comp, CompressedVector) 25 | 26 | decomp = pq.decompress(comp) 27 | assert isinstance(decomp, Vector) 28 | 29 | error = np.sum(np.abs(np.array(decomp.vector) - np.array(data[0].vector))) 30 | print("Error", error) 31 | -------------------------------------------------------------------------------- /tests/parser/test_parser.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from firstbatch.algorithm import DFAParser, UserAction, AlgorithmRegistry, AlgorithmLabel 3 | from firstbatch.algorithm.blueprint import Signal 4 | from firstbatch.algorithm.blueprint.library import lookup 5 | 6 | 7 | @pytest.fixture 8 | def setup_data(): 9 | json_str = ''' 10 | { 11 | "signals": [ 12 | {"label": "NEW_SIGNAL", "weight": 1.5} 13 | ], 14 | "nodes": [ 15 | {"name": "1", "batch_type": "biased", "params": {"mu":0.1}}, 16 | {"name": "2", "batch_type": "random", "params": {"r":0.5}}, 17 | {"name": "3", "batch_type": "sampled", "params": {}} 18 | ], 19 | "edges": [ 20 | {"name": "edge5", "edge_type": "NEW_SIGNAL", "start": "3", "end": "3"}, 21 | {"name": "edge1", "edge_type": "LIKE", "start": "1", "end": "2"}, 22 | {"name": "edge2", "edge_type": "BATCH", "start": "1", "end": "3"}, 23 | {"name": "edge3", "edge_type": "DEFAULT", "start": "1", "end": "3"}, 24 | {"name": "edge4", "edge_type": "BATCH", "start": "2", "end": "3"}, 25 | {"name": "edge5", "edge_type": "DEFAULT", "start": "2", "end": "2"}, 26 | {"name": "edge5", "edge_type": "BATCH", "start": "3", "end": "2"}, 27 | {"name": "edge5", "edge_type": "DEFAULT", "start": "3", "end": "2"} 28 | ] 29 | }''' 30 | 31 | json_str2 = '''{ 32 | "nodes": [ 33 | {"name": "0", "batch_type": "random", "params": {}}, 34 | {"name": "1", "batch_type": "biased", "params": {"mu":0.0}}, 35 | {"name": "2", "batch_type": 
"biased", "params": {"mu":0.5}}, 36 | {"name": "3", "batch_type": "biased", "params": {"mu":1.0}} 37 | ], 38 | "edges": [ 39 | {"name": "edge1", "edge_type": "BATCH", "start": "0", "end": "0"}, 40 | {"name": "edge2", "edge_type": "DEFAULT", "start": "0", "end": "1"}, 41 | {"name": "edge3", "edge_type": "DEFAULT", "start": "1", "end": "1"}, 42 | {"name": "edge4", "edge_type": "BATCH", "start": "1", "end": "2"}, 43 | {"name": "edge5", "edge_type": "DEFAULT", "start": "2", "end": "1"}, 44 | {"name": "edge5", "edge_type": "BATCH", "start": "2", "end": "3"}, 45 | {"name": "edge5", "edge_type": "BATCH", "start": "3", "end": "0"}, 46 | {"name": "edge5", "edge_type": "DEFAULT", "start": "3", "end": "1"} 47 | ] 48 | }''' 49 | 50 | d = '''{ 51 | "nodes": [ 52 | {"name": "Initial_State", "batch_type": "random", "params": {}}, 53 | {"name": "Personalized_Recommendation", "batch_type": "biased", "params": {"mu": 0.8, "alpha": 0.7, "apply_mmr": 1, "last_n": 5}}, 54 | {"name": "Personalized_Exploratory", "batch_type": "biased", "params": {"mu": 0.6, "alpha": 0.5, "apply_mmr": 1, "last_n": 5, "r": 0.1}} 55 | ], 56 | "edges": [ 57 | {"name": "edge1", "edge_type": "DEFAULT", "start": "Initial_State", "end": "Personalized_Recommendation"}, 58 | {"name": "edge2", "edge_type": "DEFAULT", "start": "Personalized_Recommendation", "end": "Personalized_Exploratory"}, 59 | {"name": "edge3", "edge_type": "DEFAULT", "start": "Personalized_Exploratory", "end": "Initial_State"}, 60 | {"name": "edge4", "edge_type": "BATCH", "end": "Initial_State", "start": "Initial_State"}, 61 | {"name": "edge5", "edge_type": "BATCH", "end": "Personalized_Recommendation", "start": "Initial_State"}, 62 | {"name": "edge6", "edge_type": "BATCH", "end": "Personalized_Exploratory", "start": "Personalized_Recommendation"} 63 | ] 64 | }''' 65 | return json_str, json_str2, d 66 | 67 | 68 | def test_factory(): 69 | for k, v in lookup.items(): 70 | parser = DFAParser(v) 71 | try: 72 | blueprint = parser.parse() 73 | except Exception as e: 74 | pytest.fail(f"{e} error with {k}") 75 | 76 | 77 | def test_signal(setup_data): 78 | # factory_id, batch_size and embedding_size were never defined in this module; the values below are assumed placeholders so the test can run. 79 | factory_id, batch_size, embedding_size = "Unique_Journeys", 10, 1536 80 | algo_type = AlgorithmRegistry.get_algorithm_by_label(AlgorithmLabel.UNIQUE_JOURNEYS) 81 | algo_instance = algo_type(factory_id, batch_size, **{"embedding_size": embedding_size}) # type: ignore 82 | 83 | blueprint = algo_instance._blueprint 84 | 85 | assert len(blueprint.vertices) == 7 86 | assert len(blueprint.edges) == 46 87 | 88 | current_vertex = 'Exploration' 89 | action = UserAction(Signal.REPOST) 90 | next_vertex, _, _ = blueprint.step(current_vertex, action) 91 | assert next_vertex.name == 'Dedicated_2' 92 | -------------------------------------------------------------------------------- /tests/vector_store/test_chroma.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import chromadb 3 | import warnings 4 | from firstbatch.vector_store import Chroma 5 | from firstbatch.vector_store.utils import generate_query, generate_batch 6 | from firstbatch.vector_store.schema import BatchQueryResult, QueryResult, FetchQuery, FetchResult, BatchFetchQuery, BatchFetchResult 7 | import os 8 | 9 | warnings.simplefilter("ignore", DeprecationWarning) 10 | warnings.simplefilter("ignore", ResourceWarning) 11 | 12 | 13 | @pytest.fixture 14 | def setup_chroma_client(): 15 | path = os.environ["CHROMA_CLIENT_PATH"] 16 | collection = "default" 17 | client = chromadb.PersistentClient(path=path) 18 | dim = 1536 19 | return Chroma(client=client, collection_name=collection, 
history_field="text"), dim 20 | 21 | 22 | def test_search(setup_chroma_client): 23 | chroma_client, dim = setup_chroma_client 24 | query = next(generate_query(1, dim, 10, True)) 25 | res = chroma_client.search(query) 26 | assert isinstance(res, QueryResult) 27 | 28 | 29 | def test_fetch(setup_chroma_client): 30 | chroma_client, dim = setup_chroma_client 31 | query = next(generate_query(1, dim, 10, False)) 32 | res = chroma_client.search(query) 33 | assert isinstance(res, QueryResult) 34 | fetch = FetchQuery(id=res.ids[0]) 35 | res = chroma_client.fetch(fetch) 36 | assert isinstance(res, FetchResult) 37 | 38 | 39 | def test_multi_search(setup_chroma_client): 40 | chroma_client, dim = setup_chroma_client 41 | batch = generate_batch(10, dim, 10, True) 42 | res = chroma_client.multi_search(batch) 43 | assert isinstance(res, BatchQueryResult) 44 | 45 | 46 | def test_multi_fetch(setup_chroma_client): 47 | chroma_client, dim = setup_chroma_client 48 | query = next(generate_query(1, dim, 10, False)) 49 | res = chroma_client.search(query) 50 | assert isinstance(res, QueryResult) 51 | ids = [id for id in res.ids] 52 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 53 | res = chroma_client.multi_fetch(bfq) 54 | assert isinstance(res, BatchFetchResult) 55 | 56 | 57 | def test_history(setup_chroma_client): 58 | """ 59 | Not implemented for Chroma 60 | """ 61 | chroma_client, dim = setup_chroma_client 62 | query = next(generate_query(1, dim, 10, False)) 63 | res = chroma_client.search(query) 64 | filt = chroma_client.history_filter([d.data[chroma_client.history_field] for d in res.metadata]) 65 | query.filter = filt 66 | res_ = chroma_client.search(query) 67 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 68 | -------------------------------------------------------------------------------- /tests/vector_store/test_pinecone.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pinecone 3 | import warnings 4 | from firstbatch.vector_store import Pinecone 5 | from firstbatch.vector_store.utils import generate_query, generate_batch 6 | from firstbatch.vector_store.schema import BatchQueryResult, QueryResult, FetchQuery, FetchResult, BatchFetchQuery, BatchFetchResult 7 | import os 8 | 9 | warnings.simplefilter("ignore", DeprecationWarning) 10 | warnings.simplefilter("ignore", ResourceWarning) 11 | 12 | 13 | @pytest.fixture 14 | def setup_pinecone_client(): 15 | api_key = os.environ["PINECONE_API_KEY"] 16 | env = os.environ["PINECONE_ENV"] 17 | index_name = os.environ["INDEX_NAME"] 18 | dim = 384 19 | pinecone.init(api_key=api_key, environment=env) 20 | pinecone.describe_index(index_name) 21 | index = pinecone.Index(index_name) 22 | return Pinecone(index=index, namespace=None), dim 23 | 24 | 25 | def test_search(setup_pinecone_client): 26 | pinecone_client, dim = setup_pinecone_client 27 | query = next(generate_query(1, dim, 10, True)) 28 | res = pinecone_client.search(query) 29 | assert isinstance(res, QueryResult) 30 | 31 | 32 | def test_fetch(setup_pinecone_client): 33 | pinecone_client, dim = setup_pinecone_client 34 | query = next(generate_query(1, dim, 10, False)) 35 | res = pinecone_client.search(query) 36 | assert isinstance(res, QueryResult) 37 | fetch = FetchQuery(id=res.ids[0]) 38 | res = pinecone_client.fetch(fetch) 39 | assert isinstance(res, FetchResult) 40 | 41 | 42 | def test_multi_search(setup_pinecone_client): 43 | pinecone_client, dim = setup_pinecone_client 44 | batch = 
generate_batch(10, dim, 10, True) 45 | res = pinecone_client.multi_search(batch) 46 | assert isinstance(res, BatchQueryResult) 47 | 48 | 49 | def test_multi_fetch(setup_pinecone_client): 50 | pinecone_client, dim = setup_pinecone_client 51 | query = next(generate_query(1, dim, 10, False)) 52 | res = pinecone_client.search(query) 53 | assert isinstance(res, QueryResult) 54 | ids = [id for id in res.ids] 55 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 56 | res = pinecone_client.multi_fetch(bfq) 57 | assert isinstance(res, BatchFetchResult) 58 | 59 | 60 | def test_history(setup_pinecone_client): 61 | pinecone_client, dim = setup_pinecone_client 62 | query = next(generate_query(1, dim, 10, False)) 63 | res = pinecone_client.search(query) 64 | filt = pinecone_client.history_filter([d.data[pinecone_client.history_field] for d in res.metadata]) 65 | query.filter = filt 66 | res_ = pinecone_client.search(query) 67 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 68 | -------------------------------------------------------------------------------- /tests/vector_store/test_qdrant.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from qdrant_client import QdrantClient 3 | from firstbatch.vector_store import Qdrant 4 | from firstbatch.vector_store.utils import generate_query, generate_batch 5 | from firstbatch.vector_store.schema import ( 6 | BatchQueryResult, QueryResult, FetchQuery, FetchResult, BatchFetchQuery, BatchFetchResult) 7 | import os 8 | 9 | @pytest.fixture 10 | def setup_typesense_client(): 11 | client = QdrantClient(os.environ["QDRANT_URL"]) 12 | return Qdrant(client=client, collection_name="farcaster") 13 | 14 | 15 | @pytest.fixture 16 | def dim(): 17 | return 1536 18 | 19 | 20 | def test_search(setup_typesense_client, dim): 21 | query = next(generate_query(1, dim, 10, True)) 22 | res = setup_typesense_client.search(query) 23 | assert isinstance(res, QueryResult) 24 | 25 | 26 | def test_fetch(setup_typesense_client, dim): 27 | query = next(generate_query(1, dim, 10, False)) 28 | res = setup_typesense_client.search(query) 29 | assert isinstance(res, QueryResult) 30 | fetch = FetchQuery(id=res.ids[0]) 31 | res = setup_typesense_client.fetch(fetch) 32 | assert isinstance(res, FetchResult) 33 | 34 | 35 | def test_multi_search(setup_typesense_client, dim): 36 | batch = generate_batch(10, dim, 10, True) 37 | res = setup_typesense_client.multi_search(batch) 38 | assert isinstance(res, BatchQueryResult) 39 | 40 | 41 | def test_multi_fetch(setup_typesense_client, dim): 42 | query = next(generate_query(1, dim, 10, False)) 43 | res = setup_typesense_client.search(query) 44 | assert isinstance(res, QueryResult) 45 | ids = [id for id in res.ids] 46 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 47 | res = setup_typesense_client.multi_fetch(bfq) 48 | assert isinstance(res, BatchFetchResult) 49 | 50 | 51 | def test_history(setup_typesense_client, dim): 52 | query = next(generate_query(1, dim, 10, False)) 53 | res = setup_typesense_client.search(query) 54 | filt = setup_typesense_client.history_filter([d.data[setup_typesense_client.history_field] for d in res.metadata]) 55 | query.filter = filt 56 | res_ = setup_typesense_client.search(query) 57 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 58 | -------------------------------------------------------------------------------- /tests/vector_store/test_supabase.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import vecs 3 | import warnings 4 | from firstbatch.vector_store import Supabase 5 | from firstbatch.vector_store.utils import generate_query, generate_batch 6 | from firstbatch.vector_store.schema import BatchQueryResult, QueryResult, FetchQuery, FetchResult, BatchFetchQuery, BatchFetchResult 7 | import os 8 | 9 | warnings.simplefilter("ignore", DeprecationWarning) 10 | warnings.simplefilter("ignore", ResourceWarning) 11 | 12 | 13 | @pytest.fixture 14 | def setup_supabase_client(): 15 | supabase_uri = os.environ["SUPABASE_URL"] 16 | client = vecs.create_client(supabase_uri) 17 | dim = 1536 18 | return Supabase(client=client, collection_name="new", query_name="match_documents"), dim 19 | 20 | 21 | def test_search(setup_supabase_client): 22 | supabase_client, dim = setup_supabase_client 23 | query = next(generate_query(1, dim, 10, True)) 24 | res = supabase_client.search(query) 25 | assert isinstance(res, QueryResult) 26 | 27 | 28 | def test_fetch(setup_supabase_client): 29 | supabase_client, dim = setup_supabase_client 30 | query = next(generate_query(1, dim, 10, False)) 31 | res = supabase_client.search(query) 32 | assert isinstance(res, QueryResult) 33 | fetch = FetchQuery(id=res.ids[0]) 34 | res = supabase_client.fetch(fetch) 35 | assert isinstance(res, FetchResult) 36 | 37 | 38 | def test_multi_search(setup_supabase_client): 39 | supabase_client, dim = setup_supabase_client 40 | batch = generate_batch(10, dim, 10, True) 41 | res = supabase_client.multi_search(batch) 42 | assert isinstance(res, BatchQueryResult) 43 | 44 | 45 | def test_multi_fetch(setup_supabase_client): 46 | supabase_client, dim = setup_supabase_client 47 | query = next(generate_query(1, dim, 10, False)) 48 | res = supabase_client.search(query) 49 | assert isinstance(res, QueryResult) 50 | ids = [id for id in res.ids] 51 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 52 | res = supabase_client.multi_fetch(bfq) 53 | assert isinstance(res, BatchFetchResult) 54 | 55 | 56 | def test_history(setup_supabase_client): 57 | supabase_client, dim = setup_supabase_client 58 | query = next(generate_query(1, dim, 10, False)) 59 | res = supabase_client.search(query) 60 | filt = supabase_client.history_filter([d.data[supabase_client.history_field] for d in res.metadata]) 61 | query.filter = filt 62 | res_ = supabase_client.search(query) 63 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 64 | -------------------------------------------------------------------------------- /tests/vector_store/test_typesense.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import typesense 3 | from firstbatch.vector_store import TypeSense 4 | from firstbatch.vector_store.utils import generate_query, generate_batch 5 | from firstbatch.vector_store.schema import ( 6 | BatchQueryResult, QueryResult, FetchQuery, FetchResult, BatchFetchQuery, BatchFetchResult) 7 | import os 8 | 9 | @pytest.fixture 10 | def setup_typesense_client(): 11 | client = typesense.Client({ 12 | 'api_key': os.environ["TYPESENSE_API_KEY"], 13 | 'nodes': [{ 14 | 'host': os.environ["TYPESENSE_URL"], 15 | 'port': '443', 16 | 'protocol': 'https' # or http if working locally 17 | }], 18 | 'connection_timeout_seconds': 2 19 | }) 20 | return TypeSense(client=client, collection_name="tiktok") 21 | 22 | 23 | @pytest.fixture 24 | def dim(): 25 | return 384 26 | 27 | 28 | def 
test_search(setup_typesense_client, dim): 29 | query = next(generate_query(1, dim, 10, True)) 30 | res = setup_typesense_client.search(query) 31 | assert isinstance(res, QueryResult) 32 | 33 | 34 | def test_fetch(setup_typesense_client, dim): 35 | query = next(generate_query(1, dim, 10, False)) 36 | res = setup_typesense_client.search(query) 37 | assert isinstance(res, QueryResult) 38 | fetch = FetchQuery(id=res.ids[0]) 39 | res = setup_typesense_client.fetch(fetch) 40 | assert isinstance(res, FetchResult) 41 | 42 | 43 | def test_multi_search(setup_typesense_client, dim): 44 | batch = generate_batch(10, dim, 10, True) 45 | res = setup_typesense_client.multi_search(batch) 46 | assert isinstance(res, BatchQueryResult) 47 | 48 | 49 | def test_multi_fetch(setup_typesense_client, dim): 50 | query = next(generate_query(1, dim, 10, False)) 51 | res = setup_typesense_client.search(query) 52 | assert isinstance(res, QueryResult) 53 | ids = [id for id in res.ids] 54 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 55 | res = setup_typesense_client.multi_fetch(bfq) 56 | assert isinstance(res, BatchFetchResult) 57 | 58 | 59 | def test_history(setup_typesense_client, dim): 60 | query = next(generate_query(1, dim, 10, False)) 61 | res = setup_typesense_client.search(query) 62 | filt = setup_typesense_client.history_filter([d.data[setup_typesense_client.history_field] for d in res.metadata]) 63 | query.filter = filt 64 | res_ = setup_typesense_client.search(query) 65 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 66 | -------------------------------------------------------------------------------- /tests/vector_store/test_weaviate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import weaviate 3 | import warnings 4 | from firstbatch.vector_store import Weaviate 5 | from firstbatch.vector_store.utils import generate_query, generate_batch 6 | from firstbatch.vector_store.schema import BatchQueryResult, QueryResult, FetchQuery, BatchFetchQuery, BatchFetchResult, FetchResult 7 | import os 8 | 9 | warnings.simplefilter("ignore", DeprecationWarning) 10 | warnings.simplefilter("ignore", ResourceWarning) 11 | 12 | 13 | @pytest.fixture 14 | def setup_weaviate_client(): 15 | auth_config = weaviate.AuthApiKey(api_key=os.environ["WEAVIATE_API_KEY"]) 16 | client = weaviate.Client( 17 | url=os.environ["WEAVIATE_URL"], 18 | auth_client_secret=auth_config, 19 | ) 20 | index_name = "default" 21 | dim = 1536 22 | return Weaviate(client=client, index_name=index_name), dim 23 | 24 | 25 | def test_search(setup_weaviate_client): 26 | weaviate_client, dim = setup_weaviate_client 27 | query = next(generate_query(1, dim, 10, True)) 28 | res = weaviate_client.search(query) 29 | assert isinstance(res, QueryResult) 30 | 31 | 32 | def test_fetch(setup_weaviate_client): 33 | weaviate_client, dim = setup_weaviate_client 34 | query = next(generate_query(1, dim, 10, True)) 35 | res = weaviate_client.search(query) 36 | assert isinstance(res, QueryResult) 37 | fetch = FetchQuery(id=res.ids[0]) 38 | res = weaviate_client.fetch(fetch) 39 | assert isinstance(res, FetchResult) 40 | 41 | 42 | def test_multi_search(setup_weaviate_client): 43 | weaviate_client, dim = setup_weaviate_client 44 | batch = generate_batch(10, dim, 10, True) 45 | res = weaviate_client.multi_search(batch) 46 | assert isinstance(res, BatchQueryResult) 47 | 48 | 49 | def test_multi_fetch(setup_weaviate_client): 50 | weaviate_client, dim = setup_weaviate_client 51 | 
query = next(generate_query(1, dim, 10, True)) 52 | res = weaviate_client.search(query) 53 | assert isinstance(res, QueryResult) 54 | ids = [id for id in res.ids] 55 | bfq = BatchFetchQuery(batch_size=10, fetches=[FetchQuery(id=id) for id in ids]) 56 | res = weaviate_client.multi_fetch(bfq) 57 | assert isinstance(res, BatchFetchResult) 58 | 59 | 60 | def test_history(setup_weaviate_client): 61 | weaviate_client, dim = setup_weaviate_client 62 | query = next(generate_query(1, dim, 10, False)) 63 | res = weaviate_client.search(query) 64 | filt = weaviate_client.history_filter([d.data[weaviate_client.history_field] for d in res.metadata]) 65 | query.filter = filt 66 | res_ = weaviate_client.search(query) 67 | assert len(set(res.ids).intersection(set(res_.ids))) == 0 68 | --------------------------------------------------------------------------------