├── tests ├── __init__.py ├── test_server_list_fail.txt ├── test_server_list.txt ├── test_index_config.json ├── test_index_state.py ├── test_client.py ├── test_rpc.py └── test_integration.py ├── distributed_faiss ├── __init__.py ├── index_state.py ├── index_cfg.py ├── rpc.py ├── client.py ├── server.py └── index.py ├── design_schema.png ├── detailed_design.png ├── scripts ├── idx_cfg.json ├── server_launcher.py └── load_data.py ├── .gitignore ├── setup.py ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /distributed_faiss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_server_list_fail.txt: -------------------------------------------------------------------------------- 1 | 4 2 | machine1234,8080 3 | machine1234,8081 4 | machine5678,8080 -------------------------------------------------------------------------------- /design_schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/distributed-faiss/HEAD/design_schema.png -------------------------------------------------------------------------------- /detailed_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/distributed-faiss/HEAD/detailed_design.png -------------------------------------------------------------------------------- /tests/test_server_list.txt: -------------------------------------------------------------------------------- 1 | 4 2 | machine1234,8080 3 | machine1234,8081 4 | machine5678,8080 5 | machine5678,8081 -------------------------------------------------------------------------------- /tests/test_index_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "index_storage_dir": "/tmp/save", 3 | "dim": "1024", 4 | "factory_type": "IVF{centroids},SQ8", 5 | "centroids": "1000" 6 | } 7 | -------------------------------------------------------------------------------- /scripts/idx_cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": "768", 3 | "faiss_factory": "flat", 4 | "centroids": "5000", 5 | "train_data_ratio": 1.0, 6 | "metric": "l2", 7 | "code_size": 64, 8 | "bits_per_vector": 8 9 | } 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Developer tools 6 | .vscode/ 7 | 8 | # Installation files 9 | *.egg-info/ 10 | eggs/ 11 | .eggs/ 12 | *.egg 13 | *.mmap 14 | -------------------------------------------------------------------------------- /tests/test_index_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | import unittest 10 | from distributed_faiss.index_state import IndexState 11 | 12 | 13 | class TestIndexState(unittest.TestCase): 14 | def test_get_aggregated_states(self): 15 | states = [IndexState.NOT_TRAINED, IndexState.TRAINED, IndexState.TRAINING] 16 | self.assertEqual(IndexState.get_aggregated_states(states), IndexState.TRAINING) 17 | 18 | states = [IndexState.TRAINED, IndexState.TRAINED, IndexState.TRAINED] 19 | self.assertEqual(IndexState.get_aggregated_states(states), IndexState.TRAINED) 20 | 21 | states = [IndexState.NOT_TRAINED, IndexState.NOT_TRAINED, IndexState.TRAINED] 22 | self.assertEqual(IndexState.get_aggregated_states(states), IndexState.NOT_TRAINED) 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | 10 | from setuptools import setup 11 | 12 | with open("README.md") as f: 13 | readme = f.read() 14 | 15 | setup( 16 | name="distributed_faiss", 17 | version="0.0.1", 18 | description="Facebook AI Research Distributed Faiss", 19 | url="", # TODO 20 | classifiers=[ 21 | "Intended Audience :: Science/Research", 22 | "License :: CC-BY-NC", 23 | "Programming Language :: Python :: 3.7", 24 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 25 | ], 26 | long_description=readme, 27 | long_description_content_type="text/markdown", 28 | setup_requires=["setuptools>=18.0"], 29 | install_requires=[ 30 | "black", 31 | "cython", 32 | "faiss-cpu>=1.7.2", 33 | "filelock", 34 | "numpy", 35 | "regex", 36 | "submitit>=1.1.5", 37 | "torch>=1.2.0", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /distributed_faiss/index_state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from enum import Enum 8 | from typing import List 9 | 10 | 11 | class IndexState(Enum): 12 | NOT_TRAINED = 1 13 | TRAINING = 2 14 | ADD = 3 15 | TRAINED = 4 16 | 17 | @staticmethod 18 | def get_aggregated_states(states: List["IndexState"]) -> "IndexState": 19 | states = set(states) 20 | assert len(states) > 0 21 | # check if all the states are consistent 22 | if len(states) == 1: 23 | return states.pop() 24 | # consider cluster to be still in training state 25 | # if at least one server is still training 26 | if IndexState.TRAINING in states: 27 | return IndexState.TRAINING 28 | # otherwise we are in a state where some servers 29 | # are trained and some are not trained. 30 | if IndexState.NOT_TRAINED in states: 31 | return IndexState.NOT_TRAINED 32 | # some nodes may be in the ADD-ing state 33 | if IndexState.ADD in states: 34 | return IndexState.ADD 35 | else: 36 | return IndexState.TRAINED 37 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import unittest 10 | import os 11 | from distributed_faiss.client import IndexClient 12 | 13 | 14 | TESTSERVERLIST_FILENAME = os.path.join(os.path.dirname(__file__), "test_server_list.txt") 15 | TESTSERVERLIST_FILENAME_FAIL = os.path.join(os.path.dirname(__file__), "test_server_list_fail.txt") 16 | 17 | 18 | class TestClient(unittest.TestCase): 19 | def test_read_server_list(self): 20 | servers = IndexClient.read_server_list(TESTSERVERLIST_FILENAME) 21 | self.assertEqual(len(servers), 4) 22 | self.assertEqual(servers[0][0], "machine1234") 23 | self.assertEqual(servers[0][1], 8080) 24 | self.assertEqual(servers[3][0], "machine5678") 25 | self.assertEqual(servers[3][1], 8081) 26 | 27 | with self.assertRaises(Exception) as context: 28 | IndexClient.read_server_list( 29 | TESTSERVERLIST_FILENAME_FAIL, 30 | total_max_timeout=0, 31 | ) 32 | self.assertEqual( 33 | ( 34 | f"4 != 3 in server list " 35 | f"{TESTSERVERLIST_FILENAME_FAIL}. Timed " 36 | "out after waiting 0.0 seconds" 37 | ), 38 | str(context.exception), 39 | ) 40 | 41 | 42 | if __name__ == "__main__": 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to distributed-faiss 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing to distributed-faiss, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /distributed_faiss/index_cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import faiss 8 | import json 9 | 10 | 11 | class IndexCfg: 12 | def __init__( 13 | self, 14 | index_builder_type: str = None, 15 | faiss_factory: str = None, 16 | dim: int = 768, 17 | train_num: int = 0, 18 | train_ratio: int = 1.0, 19 | centroids: int = 0, 20 | metric: str = "dot", 21 | nprobe: int = 1, 22 | infer_centroids=False, 23 | buffer_bsz: int = 50000, 24 | save_interval_sec: int = -1, 25 | index_storage_dir: str = None, 26 | custom_meta_id_idx: int = 0, 27 | **kwargs, 28 | ): 29 | self.index_builder_type = index_builder_type 30 | self.faiss_factory = faiss_factory 31 | self.dim = int(dim) 32 | self.train_num = train_num 33 | self.train_ratio = train_ratio 34 | self.centroids = centroids 35 | self.metric = metric 36 | self.nprobe = nprobe 37 | self.infer_centroids = infer_centroids 38 | self.buffer_bsz = buffer_bsz 39 | self.save_interval_sec = save_interval_sec 40 | self.index_storage_dir = index_storage_dir 41 | self.custom_meta_id_idx = custom_meta_id_idx 42 | self.extra = kwargs 43 | 44 | def get_metric(self): 45 | metric = self.metric 46 | if metric == "dot": 47 | faiss_metric = faiss.METRIC_INNER_PRODUCT 48 | elif metric == "l2": 49 | faiss_metric = faiss.METRIC_L2 50 | else: 51 | raise RuntimeError("Only dot and l2 metrics are supported.") 52 | return faiss_metric 53 | 54 | @classmethod 55 | def from_json(cls, json_path): 56 | with open(json_path, "r") as f: 57 | kwargs = json.load(f) 58 | return cls(**kwargs) 59 | 60 | def to_json_string(self): 61 | return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) 62 | 63 | def __repr__(self) -> str: 64 | return f"" 65 | -------------------------------------------------------------------------------- /tests/test_rpc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import unittest 8 | import _thread 9 | from multiprocessing import Pool as ProcessPool 10 | 11 | import time 12 | import torch 13 | 14 | from distributed_faiss.index_cfg import IndexCfg 15 | from distributed_faiss.index_state import IndexState 16 | from distributed_faiss.rpc import Client 17 | from distributed_faiss.server import IndexServer 18 | 19 | 20 | def add_train_data(c: Client, index_id): 21 | print("Call client ", c.id) 22 | c.add_index_data(index_id, torch.rand(100, 512).numpy(), None) 23 | 24 | 25 | def call_train(c: Client): 26 | c.async_train(0) 27 | 28 | 29 | def run_client(id): 30 | print("Run client ", id) 31 | c = Client(id, "localhost") 32 | cfg = IndexCfg(index_builder_type="flat", dim=512) 33 | idx_id = "idx:1" 34 | c.create_index(idx_id, cfg) 35 | for i in range(10): 36 | add_train_data(c, idx_id) 37 | 38 | c.async_train(idx_id) 39 | 40 | c.add_buffer_to_index(idx_id) 41 | 42 | while True: 43 | state = c.get_state(idx_id) 44 | print("Server state {}".format(state)) 45 | if state == IndexState.TRAINED: 46 | break 47 | time.sleep(2) 48 | 49 | for i in range(10): 50 | _result = c.search(idx_id, torch.rand(5, 512).numpy(), 5, True) 51 | c.close() 52 | 53 | 54 | class TestRPC(unittest.TestCase): 55 | save_dir = "/tmp/distributed_faiss/" 56 | 57 | def test_single_server_multiple_clients_threaded(self): 58 | server = IndexServer(0, index_storage_dir=self.save_dir) 59 | _thread.start_new_thread(server.start_blocking, ()) 60 | time.sleep(2) # let it start accepting clients 61 | processes = ProcessPool(processes=10) 62 | ids = list(range(10)) 63 | processes.map(run_client, ids) 64 | server.stop() 65 | 66 | @unittest.skip("Fails with ValueError: I/O operation on closed file.") 67 | def test_single_server_multiple_clients(self): 68 | server = IndexServer(0, index_storage_dir=self.save_dir) 69 | _thread.start_new_thread(server.start, ()) 70 | 71 | time.sleep(2) # let it start accepting clients 72 | 73 | clients = [] 74 | for i in range(10): 75 | c = Client(i, "localhost") 76 | clients.append(c) 77 | 78 | index_key = "lang_en" 79 | 80 | for i in range(10): 81 | [ 82 | c.add_train_data(index_key, torch.rand(100, 512).numpy(), None) 83 | # [("test_meta", c_id*100+i*10+j, None) for j in range(10)] 84 | for c_id, c in enumerate(clients) 85 | ] 86 | 87 | [c.async_train(0) for c in clients] 88 | 89 | while True: 90 | states = [c.get_state() for c in clients] 91 | print("Server states ", states) 92 | if all(s == IndexState.TRAINED for s in states): 93 | break 94 | time.sleep(2) 95 | 96 | # query 97 | results = [c.search(index_key, torch.rand(1, 512).numpy(), 5, True) for c in clients] 98 | 99 | print("Result 0", results[0]) 100 | 101 | 102 | if __name__ == "__main__": 103 | # test_single_server_multiple_clients() 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /distributed_faiss/rpc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Simplistic RPC implementation. 11 | Exposes all functions of a Server object. 12 | Uses pickle for serialization and the socket interface. 13 | 14 | Copied from https://github.com/facebookresearch/faiss/blob/master/benchs/distributed_ondisk/rpc.py 15 | """ 16 | 17 | import os 18 | import pickle 19 | import socket 20 | 21 | # default 22 | DEFAULT_PORT = 12032 23 | 24 | 25 | ######################################################################### 26 | # simple I/O functions 27 | 28 | 29 | def inline_send_handle(f, conn): 30 | st = os.fstat(f.fileno()) 31 | size = st.st_size 32 | pickle.dump(size, conn) 33 | conn.write(f.read(size)) 34 | 35 | 36 | def inline_send_string(s, conn): 37 | size = len(s) 38 | pickle.dump(size, conn) 39 | conn.write(s) 40 | 41 | 42 | class FileSock: 43 | """ 44 | wraps a socket so that it is usable by pickle/cPickle 45 | """ 46 | 47 | def __init__(self, sock): 48 | self.sock = sock 49 | self.nr = 0 50 | self.last_read_len = 0 51 | 52 | """ 53 | def write(self, buf): 54 | print("sending %d bytes ", len(buf), flush=True) 55 | self.sock.sendall(buf) 56 | """ 57 | 58 | def write(self, buf): 59 | # print("sending %d bytes"%len(buf)) 60 | # self.sock.sendall(buf) 61 | # print("...done") 62 | bs = 128 * 512 * 1024 63 | ns = 0 64 | while ns < len(buf): 65 | sent = self.sock.send(buf[ns : ns + bs]) 66 | ns += sent 67 | 68 | def read(self, bs=128 * 512 * 1024): 69 | self.nr += 1 70 | b = [] 71 | nb = 0 72 | while len(b) < bs: 73 | # print(' loop') 74 | rb = self.sock.recv(bs - nb) 75 | if not rb: 76 | break 77 | b.append(rb) 78 | nb += len(rb) 79 | 80 | # logger.info("read nb=%s", nb) 81 | 82 | self.last_read_len = nb 83 | return b"".join(b) 84 | 85 | def readline(self): 86 | # print("readline!") 87 | """may be optimized...""" 88 | s = bytes() 89 | while True: 90 | c = self.read(1) 91 | s += c 92 | if len(c) == 0 or chr(c[0]) == "\n": 93 | return s 94 | 95 | 96 | class ClientExit(Exception): 97 | pass 98 | 99 | 100 | class ServerException(Exception): 101 | pass 102 | 103 | 104 | class Client: 105 | """ 106 | Methods of the server object can be called transparently. Exceptions are 107 | re-raised. 
108 | """ 109 | 110 | def __init__(self, id, HOST, port=DEFAULT_PORT, v6=False): 111 | self.id = id 112 | socktype = socket.AF_INET6 if v6 else socket.AF_INET 113 | 114 | sock = socket.socket(socktype, socket.SOCK_STREAM) 115 | print("connecting", HOST, port, socktype) 116 | sock.connect((HOST, port)) 117 | self.sock = sock 118 | self.fs = FileSock(sock) 119 | 120 | def generic_fun(self, fname, args): 121 | # int "gen fun",fname 122 | # logger.info('Client=%s, call fname=%s', self.id, fname) 123 | pickle.dump((fname, args), self.fs, protocol=4) 124 | return self.get_result() 125 | 126 | def get_result(self): 127 | (st, ret) = pickle.load(self.fs) 128 | if st != None: 129 | raise ServerException(st) 130 | else: 131 | return ret 132 | 133 | def close(self): 134 | self.sock.shutdown(socket.SHUT_RDWR) 135 | self.sock.close() 136 | 137 | def __getattr__(self, name): 138 | return lambda *x: self.generic_fun(name, x) 139 | -------------------------------------------------------------------------------- /scripts/server_launcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import submitit 10 | import math 11 | import argparse 12 | import os 13 | import errno 14 | import time 15 | from distributed_faiss.server import IndexServer, DEFAULT_PORT 16 | 17 | """ 18 | fcntl.flock is broken for NFS. Using this workaround: 19 | https://stackoverflow.com/questions/37633951/python-locking-text-file-on-nfs 20 | """ 21 | 22 | 23 | def lockfile(target, link, timeout=300): 24 | global lock_owner 25 | poll_time = 10 26 | while timeout > 0: 27 | try: 28 | os.link(target, link) 29 | lock_owner = True 30 | break 31 | except OSError as err: 32 | if err.errno == errno.EEXIST: 33 | print("Lock unavailable. 
Waiting for 10 seconds...") 34 | time.sleep(poll_time) 35 | timeout -= poll_time 36 | else: 37 | raise err 38 | else: 39 | print("Timed out waiting for the lock.") 40 | 41 | 42 | def releaselock(link): 43 | try: 44 | if lock_owner: 45 | os.unlink(link) 46 | except OSError: 47 | print("Error:didn't possess lock.") 48 | 49 | 50 | def append_to_discovery_config_safe(discovery_config, msg): 51 | # tmp_link will be destroyed after unlink 52 | tmp_link = discovery_config + ".link" 53 | lockfile(discovery_config, tmp_link) 54 | with open(discovery_config, "a") as config: 55 | config.write(msg) 56 | releaselock(tmp_link) 57 | 58 | 59 | def run_server(discovery_config, base_port, index_storage_dir: str, load_index=False): 60 | job_env = submitit.JobEnvironment() 61 | # add local rank to avoid port conflict on the same machine 62 | port = base_port + job_env.local_rank 63 | 64 | append_to_discovery_config_safe(discovery_config, f"{job_env.hostname},{port}\r\n") 65 | 66 | server = IndexServer(job_env.global_rank, index_storage_dir) 67 | server.start_blocking(port, v6=False, load_index=load_index) 68 | return 69 | 70 | 71 | def main(): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--log-dir", required=True, help="log data dir") 74 | parser.add_argument( 75 | "--discovery-config", 76 | required=True, 77 | help="where to store info about running servers", 78 | ) 79 | parser.add_argument("--num-servers", type=int, required=True, help="how many servers to run") 80 | parser.add_argument("--save-dir", required=False, default="distributed_faiss_storage") 81 | parser.add_argument( 82 | "--num-servers-per-node", type=int, default=10, help="how many servers per node" 83 | ) 84 | parser.add_argument("--mem-gb", type=int, default=8, help="how much RAM per process") 85 | parser.add_argument("--cpus-per-node", type=int, default=32, help="how many cpus per node") 86 | parser.add_argument("--partition", required=True, help="slurm partition") 87 | parser.add_argument("--timeout-min", type=int, default=60, help="how to run for in min") 88 | parser.add_argument( 89 | "--base-port", 90 | type=int, 91 | default=DEFAULT_PORT, 92 | help="base-port + local-rank is going to be the final port", 93 | ) 94 | parser.add_argument( 95 | "--load-index", 96 | default=False, 97 | action="store_true", 98 | help="If true server will try to load index and meta from disk", 99 | ) 100 | parser.add_argument( 101 | "--comment", 102 | type=str, 103 | ) 104 | 105 | args = parser.parse_args() 106 | 107 | discovery_config = open(args.discovery_config, "w") 108 | discovery_config.write(f"{args.num_servers}\r\n") 109 | discovery_config.close() 110 | 111 | executor = submitit.AutoExecutor(folder=args.log_dir) 112 | num_nodes = math.ceil(args.num_servers / args.num_servers_per_node) 113 | cpus_per_server = math.floor(args.cpus_per_node / args.num_servers_per_node) 114 | executor.update_parameters( 115 | tasks_per_node=args.num_servers_per_node, 116 | nodes=num_nodes, 117 | timeout_min=args.timeout_min, 118 | slurm_partition=args.partition, 119 | mem_gb=args.mem_gb, 120 | cpus_per_task=cpus_per_server, 121 | comment=args.comment, 122 | ) 123 | job = executor.submit( 124 | run_server, 125 | args.discovery_config, 126 | args.base_port, 127 | args.save_dir, 128 | args.load_index, 129 | ) 130 | print(job.results()) 131 | 132 | 133 | if __name__ == "__main__": 134 | main() 135 | -------------------------------------------------------------------------------- /scripts/load_data.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | import numpy as np 10 | import faiss 11 | import time 12 | from tqdm import tqdm 13 | from pathlib import Path 14 | 15 | import unittest 16 | import torch 17 | import random 18 | import string 19 | 20 | from distributed_faiss.rpc import Client 21 | from distributed_faiss.server import IndexServer, DEFAULT_PORT 22 | from distributed_faiss.client import IndexClient 23 | from distributed_faiss.index_state import IndexState 24 | from distributed_faiss.index_cfg import IndexCfg 25 | import time 26 | 27 | import json 28 | import logging 29 | 30 | logging.basicConfig(level=4) 31 | 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--mmap", type=str, help="memmap where keys and vals are stored") 36 | parser.add_argument( 37 | "--mmap-size", type=int, help="number of items saved in the datastore memmap" 38 | ) 39 | 40 | parser.add_argument("--dimension", type=int, default=1024, help="Size of each key") 41 | parser.add_argument("--cfg", type=str, default=None, help="path to index config json") 42 | parser.add_argument("--dstore-fp16", default=False, action="store_true") 43 | parser.add_argument( 44 | "--ncentroids", 45 | type=int, 46 | default=4096, 47 | help="number of centroids faiss should learn", 48 | ) 49 | parser.add_argument( 50 | "--bs", 51 | default=1000, 52 | type=int, 53 | help="can only load a certain amount of data to memory at a time.", 54 | ) 55 | parser.add_argument("--start", default=0, type=int, help="index to start adding keys at") 56 | parser.add_argument( 57 | "--discover", 58 | type=str, 59 | help="serverlist_path", 60 | ) 61 | parser.add_argument("--load_index", action="store_true") 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def array_to_memmap(array, filename): 67 | if os.path.exists(filename): 68 | fp = np.memmap(filename, mode="r", dtype=array.dtype, shape=array.shape) 69 | return fp 70 | 71 | fp = np.memmap(filename, mode="write", dtype=array.dtype, shape=array.shape) 72 | fp[:] = array[:] # copy 73 | fp.flush() 74 | del array 75 | return fp 76 | 77 | 78 | def save_random_mmap(path, nrow, ncol, chunk_size=100000): 79 | fp = np.memmap(path, mode="write", dtype=np.float16, shape=(nrow, ncol)) 80 | 81 | for i in tqdm(range(0, nrow, chunk_size), desc=f"saving random mmap to {path}"): 82 | end = min(nrow, i + chunk_size) 83 | fp[i:end] = np.random.rand(end - i, ncol).astype(fp.dtype) 84 | fp.flush() 85 | del fp 86 | 87 | 88 | """ 89 | python scripts/load_data.py --discover discover_val.txt \ 90 | --mmap random --mmap-size 111649041 --dimension 768 \ 91 | --cfg idx_cfg.json 92 | """ 93 | 94 | 95 | def main(): 96 | args = get_args() 97 | client = IndexClient(args.discover) 98 | cfg = IndexCfg.from_json(args.cfg) 99 | index_id = "lang_en" 100 | 101 | if args.load_index: 102 | client.load_index(index_id, cfg) 103 | else: 104 | client.create_index(index_id, cfg) 105 | if args.mmap == "random": 106 | rand_path = f"random_{args.mmap_size}_{args.dimension}_fp16.mmap" 107 | if os.path.exists(rand_path): 108 | print(f"Found random mmap at {rand_path}") 109 | save_random_mmap(rand_path, args.mmap_size, args.dimension) 110 | args.dstore_fp16 = True 111 | args.mmap = rand_path 112 | keys = 
np.memmap( 113 | args.mmap, 114 | dtype=np.float16 if args.dstore_fp16 else np.float32, 115 | mode="r", 116 | shape=(args.mmap_size, args.dimension), 117 | ) 118 | num_vec = keys.shape[0] 119 | since_save = 0 120 | ids = np.arange(args.start, args.mmap_size) 121 | for i in tqdm(list(range(0, num_vec, args.bs))): 122 | end = min(i + args.bs, num_vec) 123 | emb, id = keys[i:end].copy().astype(np.float32), ids[i:end] 124 | client.add_index_data(index_id, emb, id.tolist()) 125 | since_save += 1 126 | if (since_save / client.num_indexes) >= (1e7 / args.bs): 127 | since_save = 0 128 | client.save_index() 129 | 130 | if client.get_state(index_id) == IndexState.NOT_TRAINED: 131 | client.sync_train(index_id) 132 | while client.get_state(index_id) != IndexState.TRAINED: 133 | time.sleep(1) 134 | 135 | rand_vec = torch.rand((1, 768)).numpy() 136 | client.search(rand_vec, 4, index_id) 137 | client.save_index(index_id) 138 | print(f"ntotal: {client.get_ntotal(index_id)}") 139 | 140 | rand_vec = torch.rand( 141 | ( 142 | 1, 143 | args.dimension, 144 | ) 145 | ).numpy() 146 | client.search(rand_vec, 4, index_id) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About 2 | Distributed faiss index service: a lightweight library that lets you work with FAISS indexes that don't fit into a single server's memory. It follows a simple concept of a set of index server processes running in complete isolation from each other. All the coordination is done on the client side. This simplified many-to-many client-to-server architecture is flexible and is designed specifically for research projects, as opposed to more complicated solutions that aim mostly at production usage and transactionality support. 3 | The data is sharded over several indexes on different servers in RAM. The search client aggregates results from the different servers during retrieval. The service is model-independent and operates on supplied embeddings and metadata. 4 | 5 | ### Features: 6 | * Multiple clients connect to all servers via RPC. 7 | * At indexing time, clients balance data across servers. The client sends the next available batch of embeddings to a server that is selected in a round-robin fashion. 8 | * The index client aggregates results from different servers during retrieval. It queries all the servers and uses a heap to find the final results. 9 | * The API allows sending and storing any additional metadata (e.g. raw bpe, language information, etc). 10 | * Launch servers with submitit. 11 | * Save/load the index/metadata periodically. Can restore from a stopped index state. 12 | * Supports several indexes at the same time (e.g. one index per language, or different versions of the same index). 13 | * The API tries to optimize for network bandwidth. 14 | * Flexible index configuration. 15 | 16 | 17 | 18 | ### Installation 19 | `pip install -e .` 20 | 21 | 22 | ### Testing 23 | `python -m unittest discover tests` 24 | 25 | or 26 | ```bash 27 | pip install pytest 28 | pytest tests 29 | ``` 30 | 31 | ### Code formatting 32 | `black --line-length 100 .` 33 | 34 | # Usage 35 | ## Starting the index servers 36 | distributed-faiss consists of server and client parts which are supposed to be launched as separate services. 37 | The set of server processes can be launched either by using the API or via the provided launch tool, which uses the [`submitit`](https://github.com/facebookincubator/submitit) library and works on clusters managed by the SLURM job scheduling system. 38 | 39 | 40 |
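Both launch modes rely on a shared discovery file. Its format, as written by `scripts/server_launcher.py` and read by `IndexClient.read_server_list`, is plain text: the first line is the expected number of servers, followed by one `host,port` line per server process. The test fixture `tests/test_server_list.txt` shows a four-server example:

```
4
machine1234,8080
machine1234,8081
machine5678,8080
machine5678,8081
```

Clients poll this file (with a backoff) until the announced number of servers has registered.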
41 | ## Launching servers with submitit on SLURM managed clusters 42 | Example: 43 | 44 | ```bash 45 | python scripts/server_launcher.py \ 46 | --log-dir /logs/distr-faiss/ \ 47 | --discovery-config /tmp/discover_config.txt \ 48 | --save-dir $HOME/dfaiss_data \ 49 | --num-servers 64 \ 50 | --num-servers-per-node 32 \ 51 | --timeout-min 4320 \ 52 | --mem-gb 400 \ 53 | --base-port 12033 \ 54 | --partition dev & 55 | ``` 56 | Clients can now read `/tmp/discover_config.txt` to discover servers. 57 | 58 | This launches a job running 64 servers in the background. 59 | To view the logs (which are verbose but informative), run something like: 60 | `watch 'tail /logs/distr-faiss/34785924_0_log.err'` 61 | where `34785924` is the SLURM job id you are allocated. 62 | 63 | 64 | ## Launching servers using the API 65 | You can run each index server process independently using the following API: 66 | 67 | ```python 68 | server = IndexServer(global_rank, index_storage_dir) 69 | server.start_blocking(port, load_index=True) 70 | ``` 71 | 72 | The rank of the server node is needed for reading/writing its own part of the index from/to files. Indexes are dumped to files for persistent storage. The filesystem path convention is a shared folder for the entire logical index, with each server node working in its own sub-folder inside it. 73 | `index_storage_dir` is the default location where indexes are stored. It can be overridden for each logical index by specifying this attribute in the index configuration object (see the client code examples below). 74 | When you start a server node on a specific machine and port, you need to append the `host,port` line to a discovery file which can later be used to start a client. 75 | 76 | 77 | ## Client API 78 | Each client process is supposed to work with all the server nodes and does all the data balancing among them. Client processes can be run independently of each other and work with the same set of server nodes simultaneously. 79 | 80 | ```python 81 | index_client = IndexClient(discovery_config) 82 | ``` 83 | `discovery_config` is the path to the shared filesystem file which was used to start the set of servers; it contains the (host, port) info needed to connect to all of them. 84 | 85 | ## Creating an index 86 | Each client and server node can work with multiple logical indexes (consider them fully separate tables in an SQL database). 87 | Each logical index can have its own FAISS-related configuration, filesystem location and other parameters which affect its creation logic. 88 | Example of creating a simple IVF index: 89 | 90 | ```python 91 | index_client = IndexClient(discovery_config) 92 | idx_cfg = IndexCfg( 93 | index_builder_type='ivf_simple', 94 | dim=128, 95 | train_num=10000, 96 | centroids=64, 97 | metric='dot', 98 | nprobe=12, 99 | index_storage_dir='path/to/your/index', 100 | ) 101 | index_id = 'your logical index str id' 102 | index_client.create_index(index_id, idx_cfg) 103 | ``` 104 |
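If the index for a given `index_id` was previously built and saved (see the server `--load-index` option and `save_index` below), it can be reopened instead of being rebuilt. A minimal sketch using `IndexClient.load_index` from `distributed_faiss/client.py`; the config argument is optional, and when it is omitted the client reads the config stored alongside the saved index:

```python
index_client = IndexClient(discovery_config)
# load_index returns True only if every server node loaded its part of the index.
if not index_client.load_index(index_id, cfg=idx_cfg):
    index_client.create_index(index_id, idx_cfg)
```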
105 | ## Index configuration 106 | 107 | `IndexCfg` has multiple attributes to set the FAISS index type. 108 | List of values for the `index_builder_type` attribute: 109 | - `flat`, 110 | - `ivf_simple`, 111 | - `knnlm`, corresponds to `IndexIVFPQ`, 112 | - `hnswsq`, corresponds to `IndexHNSWSQ`, 113 | - `ivfsq`, corresponds to `IndexIVFScalarQuantizer`, 114 | - `ivf_gpu`, a GPU version of IVF. 115 | 116 | Alternatively, if `index_builder_type` is not specified, one can set `faiss_factory` to a factory string, just like in the FAISS API factory call `faiss.index_factory(...)`. 117 | 118 | The following attributes define the way the index is created: 119 | - `train_num` - if specified, sets the number of samples used for index training. 120 | - `train_ratio` - the same as `train_num`, but expressed as a ratio of the total data size. 121 | 122 | Data sent for indexing will be aggregated in memory until the `train_num` threshold is exceeded. 123 | Please refer to the diagram below for the server- and client-side interactions and steps. 124 | 125 | 126 | 127 | ## Client side operations 128 | Once the index has been created, one can send batches of numpy arrays coupled with arbitrary metadata (which should be picklable): 129 | 130 | ```python 131 | index.add_index_data(index_id, vector_chunk, list_of_metadata) 132 | ``` 133 | Index training and creation are done asynchronously with respect to the add operation, so index processing may take a long time after all the data has been sent. 134 | In order to check if all server nodes have finished index building, it is recommended to use the following snippet: 135 | 136 | ```python 137 | while index.get_state(index_id) != IndexState.TRAINED: 138 | time.sleep(some_time) 139 | ``` 140 | 141 | Once the index is ready, one can query it: 142 | ```python 143 | scores, meta = index.search(query, topk=10, index_id=index_id, return_embeddings=False) 144 | ``` 145 | `query` is a batch of query vectors as a numpy array. `return_embeddings` enables returning the search result vectors in addition to the metadata; if it is set to `True`, the result tuple contains the vectors as its third element. 146 | 147 | ## Loading Data 148 | The following two commands load a medium-sized mmap into distributed-faiss in about 1 minute: 149 | 150 | First, launch 64 servers in the background: 151 | ```bash 152 | python scripts/server_launcher.py \ 153 | --log-dir /logs/distr-faiss/ \ 154 | --discovery-config /tmp/discover_config.txt \ 155 | --save-dir $HOME/dfaiss_data \ 156 | --num-servers 64 \ 157 | --num-servers-per-node 32 \ 158 | --timeout-min 4320 \ 159 | --mem-gb 400 \ 160 | --base-port 12033 \ 161 | --partition dev & 162 | ``` 163 | Once you receive your allocation, load in the data with: 164 | 165 | ```bash 166 | python scripts/load_data.py \ 167 | --discover /tmp/discover_config.txt \ 168 | --mmap $HOME/dfaiss_data/random_1000000000_768_fp16.mmap \ 169 | --mmap-size 1000000000 \ 170 | --dimension 768 \ 171 | --dstore-fp16 \ 172 | --cfg scripts/idx_cfg.json 173 | ``` 174 | 175 | Modify `scripts/load_data.py` to load other data formats.
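The loading loop in `scripts/load_data.py` reduces to a handful of `IndexClient` calls, so other formats only need a different way of producing the vectors. Below is a minimal sketch of the same flow for an in-memory numpy array; the array, batch size, and index id are illustrative placeholders, while the client calls mirror those used in `load_data.py`:

```python
import time

import numpy as np

from distributed_faiss.client import IndexClient
from distributed_faiss.index_cfg import IndexCfg
from distributed_faiss.index_state import IndexState

# Placeholder data: any (num_vectors, dim) float32 array can be fed in batches.
my_vectors = np.random.rand(100_000, 768).astype(np.float32)

client = IndexClient("/tmp/discover_config.txt")
index_id = "lang_en"
client.create_index(index_id, IndexCfg(faiss_factory="flat", dim=768))

batch_size = 1000
for start in range(0, len(my_vectors), batch_size):
    batch = my_vectors[start : start + batch_size]
    ids = list(range(start, start + len(batch)))
    # The client sends each batch to a server chosen in a round-robin fashion.
    client.add_index_data(index_id, batch, ids)

# Trigger training and wait until all servers report TRAINED.
if client.get_state(index_id) == IndexState.NOT_TRAINED:
    client.sync_train(index_id)
while client.get_state(index_id) != IndexState.TRAINED:
    time.sleep(1)

scores, meta = client.search(np.random.rand(1, 768).astype(np.float32), 4, index_id)
client.save_index(index_id)
```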
177 | 178 | # Reference 179 | Reference to cite when using `distributed-faiss` in a research paper: 180 | ``` 181 | @article{DBLP:journals/corr/abs-2112-09924, 182 | author = {Aleksandra Piktus and 183 | Fabio Petroni and 184 | Vladimir Karpukhin and 185 | Dmytro Okhonko and 186 | Samuel Broscheit and 187 | Gautier Izacard and 188 | Patrick Lewis and 189 | Barlas Oguz and 190 | Edouard Grave and 191 | Wen{-}tau Yih and 192 | Sebastian Riedel}, 193 | title = {The Web Is Your Oyster - Knowledge-Intensive {NLP} against a Very 194 | Large Web Corpus}, 195 | journal = {CoRR}, 196 | volume = {abs/2112.09924}, 197 | year = {2021}, 198 | url = {https://arxiv.org/abs/2112.09924}, 199 | eprinttype = {arXiv}, 200 | eprint = {2112.09924}, 201 | timestamp = {Tue, 04 Jan 2022 15:59:27 +0100}, 202 | biburl = {https://dblp.org/rec/journals/corr/abs-2112-09924.bib}, 203 | bibsource = {dblp computer science bibliography, https://dblp.org} 204 | } 205 | ``` 206 | 207 | You can access the paper [here](https://arxiv.org/abs/2112.09924). 208 | 209 | # License 210 | `distributed-faiss` is released under the CC-BY-NC 4.0 license. See the `LICENSE` file for details. 211 | -------------------------------------------------------------------------------- /distributed_faiss/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import itertools 10 | import logging 11 | import os 12 | import random 13 | import time 14 | from multiprocessing.dummy import Pool as ThreadPool 15 | from typing import List, Tuple 16 | 17 | import faiss 18 | import numpy as np 19 | 20 | from distributed_faiss.index_cfg import IndexCfg 21 | from . import rpc 22 | from .index_state import IndexState 23 | 24 | logger = logging.getLogger() 25 | 26 | from typing import Optional 27 | 28 | 29 | class ResultHeap: 30 | """Accumulate query results from a sliced dataset. The final result will 31 | be in self.D, self.I. Using faiss.float_maxheap_array_t()""" 32 | 33 | def __init__(self, nq, k): 34 | "nq: number of query vectors, k: number of results per query" 35 | self.I = np.zeros((nq, k), dtype="int64") 36 | self.D = np.zeros((nq, k), dtype="float32") 37 | self.nq, self.k = nq, k 38 | heaps = faiss.float_maxheap_array_t() 39 | 40 | heaps.k = k 41 | heaps.nh = nq 42 | heaps.val = faiss.swig_ptr(self.D) 43 | heaps.ids = faiss.swig_ptr(self.I) 44 | heaps.heapify() 45 | self.heaps = heaps 46 | 47 | def add_result(self, D, I): 48 | """D, I do not need to be in a particular order (heap or sorted)""" 49 | assert D.shape == (self.nq, self.k) 50 | assert I.shape == (self.nq, self.k) 51 | self.heaps.addn_with_ids(self.k, faiss.swig_ptr(D), faiss.swig_ptr(I), self.k) 52 | 53 | def finalize(self): 54 | self.heaps.reorder() 55 | 56 | 57 | class IndexClient: 58 | """ 59 | Manages a set of distance sub-indexes. The sub_indexes search a 60 | subset of the inverted lists. 
Searches are merged afterwards 61 | """ 62 | 63 | def __init__(self, server_list_path: str, cfg_path: Optional[str] = None): 64 | """connect to a series of (host, port) pairs""" 65 | machine_ports = IndexClient.read_server_list(server_list_path) 66 | self.sub_indexes = IndexClient.setup_connection(machine_ports) 67 | self.num_indexes = len(self.sub_indexes) 68 | 69 | # index_rank_to_id is a map between logical node id and sub-indexes list id 70 | # it is useful for debuggind 71 | # might be useful for a better load balancing upon index restart with failed/lagged behind nodes 72 | index_ranks = [idx.get_rank() for idx in self.sub_indexes] 73 | logger.info("index_ranks %s", index_ranks) 74 | self.index_rank_to_id = { 75 | index_rank: index_id for index_id, index_rank in enumerate(index_ranks) 76 | } 77 | logger.info("index_rank_to_id %s", self.index_rank_to_id) 78 | 79 | # pool of threads. Each thread manages one sub-index. 80 | self.pool = ThreadPool(self.num_indexes) 81 | self.verbose = False 82 | self.cur_server_ids = {} 83 | 84 | random.seed(time.time()) 85 | self.cfg = IndexCfg.from_json(cfg_path) if cfg_path is not None else None 86 | 87 | @staticmethod 88 | def read_server_list( 89 | server_list_path, 90 | initial_timeout=0.1, 91 | backoff_factor=1.5, 92 | total_max_timeout=7200, 93 | ) -> List[Tuple[str, int]]: 94 | time_waited = 0 95 | while True: 96 | with open(server_list_path) as f: 97 | res = [] # list of [(hostname, port), ...] 98 | for idx, line in enumerate(f): 99 | if idx == 0: 100 | num_servers = int(line) 101 | continue 102 | res.append((str(line.split(",")[0]), int(line.split(",")[1]))) 103 | 104 | msg = f"{num_servers} != {len(res)} in server list {server_list_path}." 105 | if num_servers != len(res): 106 | print( 107 | msg 108 | + f" Waiting {round(initial_timeout * 100) / 100} seconds for servers to load..." 
109 | ) 110 | time.sleep(initial_timeout) 111 | 112 | if time_waited + initial_timeout >= total_max_timeout: 113 | break 114 | time_waited += initial_timeout 115 | initial_timeout *= backoff_factor 116 | 117 | assert num_servers == len(res), ( 118 | msg + f" Timed out after waiting {round(time_waited * 100) / 100} seconds" 119 | ) 120 | return res 121 | 122 | @staticmethod 123 | def setup_connection(machine_ports) -> List[rpc.Client]: 124 | sub_indexes = [] 125 | for idx, machine_port in enumerate(machine_ports): 126 | sub_indexes.append(rpc.Client(idx, machine_port[0], machine_port[1], False)) 127 | return sub_indexes 128 | 129 | def drop_index(self, index_id: str): 130 | self.pool.map(lambda idx: idx.drop_index(index_id), self.sub_indexes) 131 | 132 | def save_index(self, index_id: str): 133 | self.pool.map(lambda idx: idx.save_index(index_id), self.sub_indexes) 134 | 135 | def load_index( 136 | self, 137 | index_id: str, 138 | cfg: Optional[IndexCfg] = None, 139 | force_reload: bool = True, 140 | ) -> bool: 141 | def setup_cfg(cfg: Optional[IndexCfg]): 142 | if cfg is None: 143 | config_paths = self.pool.map( 144 | lambda idx: idx.get_config_path(index_id), self.sub_indexes 145 | ) 146 | if len(config_paths) > 0 and os.path.isfile(config_paths[0]): 147 | cfg = IndexCfg.from_json(config_paths[0]) 148 | else: 149 | cfg = IndexCfg() 150 | return cfg 151 | 152 | if force_reload: 153 | logger.info("Forced index reload") 154 | self.pool.map(lambda idx: idx.drop_index(index_id), self.sub_indexes) 155 | all_loaded = self.pool.map(lambda idx: idx.load_index(index_id, cfg), self.sub_indexes) 156 | # TODO: remove cfg as client instance attribute, it should be specific index only cfg 157 | self.cfg = setup_cfg(cfg) 158 | logger.info(f"Index cfg {self.cfg}") 159 | 160 | if all(all_loaded): 161 | logger.info(f"The index {index_id} is loaded.") 162 | return True 163 | if any(all_loaded): 164 | logger.warning(f"Some server nodes can't load index: {all_loaded}") 165 | return False 166 | 167 | def create_index(self, index_id: str, cfg: Optional[IndexCfg] = None): 168 | if cfg is not None: 169 | self.cfg = cfg 170 | if self.cfg is None: 171 | self.cfg = IndexCfg() 172 | return self.pool.map(lambda idx: idx.create_index(index_id, self.cfg), self.sub_indexes) 173 | 174 | def add_index_data( 175 | self, 176 | index_id: str, 177 | embeddings: np.array, 178 | metadata: Optional[List[object]] = None, 179 | train_async_if_triggered: bool = True, 180 | ) -> None: 181 | """ 182 | Randomly select index of server to write data to first time 183 | writing to it. later use round-robin to select index server 184 | to balance the load. 
185 | """ 186 | if index_id not in self.cur_server_ids: 187 | self.cur_server_ids[index_id] = random.randint(0, self.num_indexes - 1) 188 | cur_server_id = self.cur_server_ids[index_id] 189 | self.sub_indexes[cur_server_id].add_index_data( 190 | index_id, embeddings, metadata, train_async_if_triggered 191 | ) 192 | self.cur_server_ids[index_id] = (self.cur_server_ids[index_id] + 1) % self.num_indexes 193 | 194 | def sync_train(self, index_id: str) -> None: 195 | self.pool.map(lambda idx: idx.sync_train(index_id), self.sub_indexes) 196 | 197 | def async_train(self, index_id: str): 198 | self.pool.map(lambda idx: idx.sync_train(index_id), self.sub_indexes) 199 | 200 | def search( 201 | self, query, topk: int, index_id: str, return_embeddings: bool = False 202 | ) -> Tuple[np.ndarray, List]: 203 | """Call idx.search on each rpc client and then aggregates results using heap.""" 204 | 205 | q_size = query.shape[0] 206 | maximize_metric: bool = self.cfg.metric == "dot" 207 | results = self.pool.imap( 208 | lambda idx: idx.search(index_id, query, topk, return_embeddings), self.sub_indexes 209 | ) 210 | return self._aggregate_results(results, topk, q_size, maximize_metric, return_embeddings) 211 | 212 | # TODO: make filter a generic custom function that takes meta and sample's score and return bool value 213 | def search_with_filter( 214 | self, 215 | query: np.array, 216 | top_k: int, 217 | index_id: str, 218 | filter_pos: int = -1, 219 | filter_value=None, 220 | ) -> Tuple[np.array, List[List[object]]]: 221 | 222 | # TODO: get from cfg ? 223 | filter_top_factor = 3 # search for x times more results 224 | actual_top_k = filter_top_factor * top_k if filter_pos >= 0 else top_k 225 | (scores, meta) = self.search(query, actual_top_k, index_id) 226 | 227 | if filter_pos < 0: 228 | return scores, meta 229 | 230 | def _do_filter(scores: np.array, results_meta: List[List[Tuple]]): 231 | re_query_ids = [] 232 | new_results = [] 233 | new_scores = [] 234 | 235 | for i, meta_list in enumerate(results_meta): 236 | sample_filtered_meta = [] 237 | sample_filtered_scores = [] 238 | for j, meta in enumerate(meta_list): 239 | if not meta: 240 | logger.warning("No meta for j=%d, score=%s", j, scores[i, j]) 241 | continue 242 | if len(meta) > filter_pos and meta[filter_pos] != filter_value: 243 | sample_filtered_meta.append(meta) 244 | sample_filtered_scores.append(scores[i, j]) 245 | if len(sample_filtered_meta) >= top_k: 246 | break 247 | 248 | if len(sample_filtered_meta) < top_k: 249 | re_query_ids.append(i) 250 | 251 | new_results.append(sample_filtered_meta) 252 | new_scores.append( 253 | np.concatenate([s.reshape(-1, 1) for s in sample_filtered_scores], axis=0) 254 | ) 255 | # TODO: filtered scores list may be of different lengths so we can't concatenate them here generally 256 | # new_scores = np.concatenate(new_scores, axis=0) 257 | return new_scores, new_results, re_query_ids 258 | 259 | new_scores, new_results_meta, re_query_ids = _do_filter(scores, meta) 260 | logger.info(f"{len(re_query_ids)} samples return less than {top_k} results after filtering") 261 | # TODO: search again with larger top_k for queries in re_query_ids 262 | 263 | return new_scores, new_results_meta 264 | 265 | @staticmethod 266 | def _aggregate_results( 267 | results: List[Tuple], 268 | topk: int, 269 | q_size: int, 270 | maximize_metric: bool, 271 | return_embeddings: bool, 272 | ): 273 | meta = [] 274 | embs = [] 275 | cur_idx = 0 276 | 277 | def to_matrix(l, n): 278 | return [l[i : i + n] for i in range(0, len(l), n)] 279 | 280 
| res_heap = ResultHeap(q_size, topk) 281 | for DI, MetaI, e in results: 282 | # Two hacks: 283 | # 1) for DOT, search for -D because we want to find max dot product, not min 284 | # 2) create new indexed to later map metadata to the best indexes 285 | merged_meta = list(itertools.chain(*MetaI)) 286 | meta.extend(merged_meta) 287 | if return_embeddings: 288 | merged_embs = list(itertools.chain(*e)) 289 | embs.extend(merged_embs) 290 | Ii = np.reshape(np.arange(cur_idx, cur_idx + q_size * topk), (q_size, topk)) 291 | if maximize_metric: 292 | res_heap.add_result(-DI, Ii) 293 | else: 294 | res_heap.add_result(DI, Ii) 295 | cur_idx += q_size * topk 296 | res_heap.finalize() 297 | ids = np.reshape(res_heap.I, (-1,)).tolist() 298 | selected_meta = [meta[i] for i in ids] 299 | if return_embeddings: 300 | selected_embs = [embs[i] for i in ids] 301 | 302 | return ( 303 | ( 304 | res_heap.D, 305 | to_matrix(selected_meta, res_heap.D.shape[1]), 306 | to_matrix(selected_embs, res_heap.D.shape[1]), 307 | ) 308 | if return_embeddings 309 | else (res_heap.D, to_matrix(selected_meta, res_heap.D.shape[1])) 310 | ) 311 | 312 | def get_centroids(self, index_id: str): 313 | return self.pool.map(lambda idx: idx.get_centroids(index_id), self.sub_indexes) 314 | 315 | def set_nprobe(self, index_id: str, nprobe: int): 316 | return self.pool.map(lambda idx: idx.set_nprobe(index_id, nprobe), self.sub_indexes) 317 | 318 | def get_state(self, index_id: str) -> IndexState: 319 | states = self.pool.map(lambda idx: idx.get_state(index_id), self.sub_indexes) 320 | logging.info("Index nodes states %s", states) 321 | return IndexState.get_aggregated_states(states) 322 | 323 | def add_buffer_to_index( 324 | self, 325 | index_id: str, 326 | ): 327 | self.pool.map(lambda idx: idx.add_buffer_to_index(index_id), self.sub_indexes) 328 | 329 | def get_ntotal(self, index_id: str) -> None: 330 | return sum(self.pool.map(lambda idx: idx.get_ntotal(index_id), self.sub_indexes)) 331 | 332 | def get_ids(self, index_id: str) -> set: 333 | id_set_list = self.pool.map(lambda idx: idx.get_ids(index_id), self.sub_indexes) 334 | for i in range(len(id_set_list)): 335 | logging.info("ids i=%s len=%s", i, len(id_set_list[i])) 336 | return set().union(*id_set_list) 337 | 338 | def set_omp_num_threads(self, num_threads: int) -> None: 339 | self.pool.map(lambda idx: idx.set_omp_num_threads(num_threads), self.sub_indexes) 340 | 341 | def close(self): 342 | [index_conn.close() for index_conn in self.sub_indexes] 343 | 344 | def get_num_servers(self): 345 | return self.num_indexes 346 | -------------------------------------------------------------------------------- /distributed_faiss/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import _thread 8 | import argparse 9 | import logging 10 | import os 11 | import pathlib 12 | import pickle 13 | import selectors 14 | import socket 15 | import sys 16 | import threading 17 | import traceback 18 | import types 19 | from typing import List, Tuple, Optional 20 | 21 | import numpy as np 22 | 23 | from distributed_faiss.index import Index 24 | from distributed_faiss.index_cfg import IndexCfg 25 | from distributed_faiss.index_state import IndexState 26 | from distributed_faiss.rpc import FileSock, ClientExit, DEFAULT_PORT 27 | 28 | logger = logging.getLogger() 29 | logger.setLevel(logging.INFO) 30 | if logger.hasHandlers(): 31 | logger.handlers.clear() 32 | log_formatter = logging.Formatter("[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s") 33 | console = logging.StreamHandler() 34 | console.setFormatter(log_formatter) 35 | logger.addHandler(console) 36 | 37 | 38 | class IndexServer: 39 | def __init__(self, rank: int, index_storage_dir): 40 | self.indexes = {} 41 | self.indexes_lock = threading.Lock() 42 | self.rank = rank 43 | self.socket = None 44 | logger.info(f"IndexServer saving to {index_storage_dir} , rank={rank}") 45 | self.index_storage_dir = index_storage_dir 46 | 47 | def save_index(self, index_id: str): 48 | """Save index, metadata and buffer to {storage_dir}/{index_id}/{rank}/""" 49 | with self.indexes_lock: 50 | if index_id not in self.indexes: 51 | raise RuntimeError(f"Index with id={index_id} is not initialized") 52 | index = self.indexes[index_id] 53 | index.save() 54 | 55 | def load_index(self, index_id: str = "default", cfg: IndexCfg = None) -> bool: 56 | """Load index, metadata and buffer from {storage_dir}/{index_id}/{rank}/""" 57 | 58 | index_dir = self._get_storage_dir(index_id, cfg) 59 | if cfg: 60 | cfg.index_storage_dir = index_dir 61 | 62 | with self.indexes_lock: 63 | if index_id in self.indexes: 64 | logging.info("Index already exists: %s", index_id) 65 | if cfg: 66 | self.indexes[index_id].upd_cfg(cfg) 67 | return True 68 | else: 69 | logger.info(f"Loading index from {index_dir}") 70 | index = Index.from_storage_dir(index_dir, cfg) 71 | if index: 72 | self.indexes[index_id] = index 73 | logger.info(f"Index id={index_id} has been loaded") 74 | return True 75 | else: 76 | logger.info(f"Can't load index") 77 | return False 78 | 79 | def get_ids(self, index_id: str = "default") -> set: 80 | with self.indexes_lock: 81 | index = self.indexes[index_id] 82 | return index.get_ids() 83 | 84 | def get_rank(self) -> int: 85 | return self.rank 86 | 87 | def index_loaded(self, index_id: str) -> bool: 88 | """Check if an index with a given index_id exists and if it is trained""" 89 | with self.indexes_lock: 90 | return ( 91 | index_id in self.indexes 92 | and self.indexes[index_id].get_state() == IndexState.TRAINED 93 | ) 94 | 95 | def start_blocking(self, port=DEFAULT_PORT, v6=False, load_index=False): 96 | if load_index: 97 | self.load_index() 98 | HOST = "" # Symbolic name meaning the local host 99 | socktype = socket.AF_INET6 if v6 else socket.AF_INET 100 | s = socket.socket(socktype, socket.SOCK_STREAM) 101 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 102 | logger.info("bind %s:%d", HOST, port) 103 | s.bind((HOST, port)) 104 | s.listen(10) 105 | 106 | while True: 107 | try: 108 | conn, addr = s.accept() 109 | except socket.error as e: 110 | if e[1] == "Interrupted system call": 111 | continue 112 | raise 113 | 114 | logger.info("Connected by %s", addr) 115 | tid = _thread.start_new_thread(self.exec_loop_blocking, (conn,)) 
116 | 117 | def exec_loop_blocking(self, socket): 118 | """main execution loop. Loops and handles exit states""" 119 | 120 | fs = FileSock(socket) 121 | logger.info("in exec_loop") 122 | try: 123 | while True: 124 | self.one_function_blocking(socket, fs) 125 | except ClientExit as e: 126 | logger.info("ClientExit %s", e) 127 | except socket.error as e: 128 | logger.error("socket error %s", e) 129 | except EOFError: 130 | logger.error("EOF during communication") 131 | except BaseException: 132 | # unexpected 133 | traceback.print_exc(50, sys.stderr) 134 | sys.exit(1) 135 | logger.info("exit server") 136 | 137 | def start(self, port=DEFAULT_PORT, v6=False): 138 | sel = selectors.DefaultSelector() 139 | HOST = "" 140 | socktype = socket.AF_INET6 if v6 else socket.AF_INET 141 | s = socket.socket(socktype, socket.SOCK_STREAM) 142 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 143 | 144 | logger.info("bind %s:%d", HOST, port) 145 | s.bind((HOST, port)) 146 | s.listen(10) 147 | s.setblocking(False) 148 | sel.register(s, selectors.EVENT_READ, data=None) 149 | self.socket = s 150 | while True: 151 | 152 | try: 153 | events = sel.select(timeout=None) 154 | for key, mask in events: 155 | if key.data is None: 156 | logger.info(" ") 157 | self._accept_wrapper(key.fileobj, sel) 158 | else: 159 | self._service_connection(key, mask, sel) 160 | except Exception as e: 161 | if hasattr(e, "message"): 162 | print(e.message) 163 | else: 164 | print(e) 165 | break 166 | 167 | logger.info("Server service loop ended") 168 | 169 | def one_function(self, socket: FileSock) -> int: 170 | """ 171 | - server sends result: (rid,st,ret) 172 | st = None, or exception if there was during execution 173 | ret = return value or None if st!=None 174 | """ 175 | fname = None 176 | args = None 177 | try: 178 | client_request = pickle.load(socket) 179 | read_bytes = socket.last_read_len 180 | # logger.info('read_bytes %s', read_bytes) 181 | if read_bytes > 0: 182 | (fname, args) = client_request 183 | logger.info("fname %s", fname) 184 | except EOFError: 185 | return 0 186 | 187 | st = None 188 | ret = None 189 | f = None 190 | 191 | try: 192 | f = getattr(self, fname) 193 | except AttributeError: 194 | st = AttributeError("unknown method " + fname) 195 | logger.error("unknown method %s", fname) 196 | try: 197 | # TODO: decide if a separate thread is needed 198 | if f == self.add_index_data: # same thread 199 | ret = f(*args) 200 | else: 201 | ret = f(*args) 202 | 203 | except Exception as e: 204 | st = "".join(traceback.format_tb(sys.exc_info()[2])) + str(e) 205 | logger.error("exception in method %s", f) 206 | logger.error("%s", st) 207 | 208 | try: 209 | pickle.dump((st, ret), socket, protocol=4) 210 | except EOFError: 211 | raise ClientExit("function return") 212 | 213 | return read_bytes 214 | 215 | def one_function_blocking(self, socket, fs: FileSock): 216 | try: 217 | (fname, args) = pickle.load(fs) 218 | except EOFError: 219 | raise ClientExit("read args") 220 | logger.info("executing method %s", fname) 221 | st = None 222 | ret = None 223 | try: 224 | f = getattr(self, fname) 225 | except AttributeError: 226 | st = "unknown method " + fname 227 | logger.error("unknown method %s", fname) 228 | try: 229 | ret = f(*args) 230 | except Exception as e: 231 | st = "".join(traceback.format_tb(sys.exc_info()[2])) + str(e) 232 | logger.error("exception in method: %s", traceback.print_exc(50)) 233 | 234 | logger.info("return") 235 | try: 236 | pickle.dump((st, ret), fs, protocol=4) 237 | except EOFError: 238 | raise 
ClientExit("function return") 239 | 240 | def create_index(self, index_id: str, cfg: IndexCfg): 241 | index_storage_dir = self._get_storage_dir(index_id, cfg) 242 | cfg.index_storage_dir = index_storage_dir 243 | pathlib.Path(index_storage_dir).mkdir(parents=True, exist_ok=True) 244 | logger.info(f"Set index save dir to: {index_storage_dir}") 245 | 246 | with self.indexes_lock: 247 | if index_id not in self.indexes: 248 | self.indexes[index_id] = Index(cfg) 249 | logger.info(f"Created new index {index_id}") 250 | logger.info(f"CFG: {str(cfg)}") 251 | return True 252 | logger.info(f"Index {index_id} already exists") 253 | return False 254 | 255 | def add_index_data( 256 | self, 257 | index_id: str, 258 | embeddings: np.array, 259 | metadata: Optional[List[object]] = None, 260 | train_async_if_triggered: bool = True, 261 | ): 262 | logger.info("adding embeddings idx=%s, embeddings=%s", index_id, embeddings.shape) 263 | 264 | with self.indexes_lock: 265 | index = self.indexes[index_id] 266 | index.add_batch(embeddings, metadata, train_async_if_triggered) 267 | 268 | def get_aggregated_ntotal(self, index_id: str) -> int: 269 | logger.info("getting current buffer data size idx=%s") 270 | with self.indexes_lock: 271 | index = self.indexes[index_id] 272 | return index.get_idx_data_num()[0] 273 | 274 | def stop(self): 275 | logger.info("Stopping server ...") 276 | if self.socket: 277 | self.socket.shutdown(socket.SHUT_RDWR) 278 | self.socket.close() 279 | self.socket = None 280 | 281 | for index_id in self.indexes: 282 | self.indexes[index_id].save() 283 | 284 | def get_ntotal(self, index_id: str) -> int: 285 | with self.indexes_lock: 286 | if index_id not in self.indexes: 287 | return 0 288 | index = self.indexes[index_id] 289 | index_data_num = index.get_idx_data_num() 290 | logger.info("Index id=%s data size idx=%s", index_id, index_data_num) 291 | return index_data_num[1] 292 | 293 | def drop_index(self, index_id: str): 294 | with self.indexes_lock: 295 | logger.info(f"Dropping index {index_id}") 296 | if index_id in self.indexes: 297 | del self.indexes[index_id] 298 | 299 | def sync_train(self, index_id: str): 300 | """ 301 | Warning, this method will block the main clients' service loop so all client requests will be stalled 302 | :return: 0 as of now, 303 | """ 304 | # TODO: don't block service loop, allocate a separate thread ? 
305 | logger.info(f"Sync train started") 306 | self._train_index(index_id) 307 | 308 | def async_train(self, index_id: str): 309 | class IndexTrainer(threading.Thread): 310 | def __init__(self, server: IndexServer, *args, **kwargs): 311 | super(IndexTrainer, self).__init__(*args, **kwargs) 312 | self.server = server 313 | 314 | def run(self): 315 | self.server._train_index(index_id) 316 | 317 | t = IndexTrainer(self) 318 | t.run() 319 | 320 | def search( 321 | self, index_id: str, query_batch: np.array, top_k: int, return_embeddings: bool 322 | ) -> Tuple: 323 | logger.info(f"Query idx={index_id}, query={query_batch.shape}") 324 | index = self._get_index(index_id) 325 | assert not isinstance(index, set) 326 | r = index.search(query_batch, top_k=top_k, return_embeddings=return_embeddings) 327 | return r 328 | 329 | def get_centroids(self, index_id: str): 330 | index = self._get_index(index_id) 331 | return index.get_centroids() 332 | 333 | def set_nprobe(self, index_id: str, nprobe: int): 334 | index = self._get_index(index_id) 335 | return index.set_nprobe(nprobe) 336 | 337 | def get_state(self, index_id: str): 338 | index = self._get_index(index_id) 339 | return index.get_state() 340 | 341 | def add_buffer_to_index(self, index_id: str): 342 | index = self._get_index(index_id) 343 | return index.add_buffer_to_index() 344 | 345 | def get_config_path(self, index_id: str): 346 | cfg_path = os.path.join(self.index_storage_dir, index_id, str(self.rank), "cfg.json") 347 | return cfg_path 348 | 349 | def _get_index(self, index_id: str): 350 | with self.indexes_lock: 351 | if index_id not in self.indexes: 352 | raise RuntimeError("Server has no index with id={}".format(index_id)) 353 | return self.indexes[index_id] 354 | 355 | def _accept_wrapper(self, sock, sel): 356 | conn, addr = sock.accept() # Should be ready to read 357 | logger.info("accepted connection from %s", addr) 358 | conn.setblocking(False) 359 | data = types.SimpleNamespace(addr=addr, inb=b"", outb=b"") 360 | events = selectors.EVENT_READ | selectors.EVENT_WRITE # TODO: remove write? 
361 |         sel.register(conn, events, data=data)
362 | 
363 |     def _service_connection(self, key, mask, sel):
364 |         sock = key.fileobj
365 |         data = key.data
366 | 
367 |         fs = FileSock(sock)
368 | 
369 |         if mask & selectors.EVENT_READ:
370 |             read_bytes = self.one_function(fs)
371 | 
372 |             if read_bytes == 0:
373 |                 logger.info("closing connection to %s", data.addr)
374 |                 sel.unregister(sock)
375 |                 sock.close()
376 | 
377 |     def _train_index(self, index_id: str):
378 |         index = self._get_index(index_id)
379 |         logger.info(f"Training index {index_id}")
380 |         index.train()
381 | 
382 |     def _get_storage_dir(self, index_id: str, cfg: IndexCfg):
383 |         index_storage_dir = cfg.index_storage_dir if cfg else None
384 |         if not index_storage_dir:
385 |             index_storage_dir = os.path.join(self.index_storage_dir, index_id, str(self.rank))
386 |         else:
387 |             index_storage_dir = os.path.join(index_storage_dir, str(self.rank))
388 |         return index_storage_dir
389 | 
390 | 
391 | def main():
392 |     parser = argparse.ArgumentParser()
393 | 
394 |     # server specific params
395 |     parser.add_argument("--port", default=DEFAULT_PORT, type=int, help="server port")
396 |     parser.add_argument("--ipv4", default=False, action="store_true", help="force ipv4")
397 |     # IndexServer requires a rank and a storage dir; these defaults are placeholders
398 |     parser.add_argument("--rank", default=0, type=int, help="rank of this server node")
399 |     parser.add_argument("--index-storage-dir", default=".", help="where indexes are saved")
400 |     args = parser.parse_args()
401 |     logger.info("starting server ...")
402 |     server = IndexServer(args.rank, args.index_storage_dir)
403 |     server.start(args.port, v6=not args.ipv4)
404 | 
405 | 
406 | if __name__ == "__main__":
407 |     main()
408 | 
--------------------------------------------------------------------------------
/tests/test_integration.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
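
# These integration tests exercise the client/server stack fully in-process: setUpClass
# starts four IndexServer instances (ports 1237-1240) plus a standalone server on port
# 1241, each in its own background thread via start_blocking, and the tests reach them
# through IndexClient instances built from a temporary server-discovery file.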
6 | 7 | import os 8 | import unittest 9 | import time 10 | import torch 11 | import random 12 | import string 13 | import _thread 14 | import tempfile 15 | from pathlib import Path 16 | from typing import List 17 | 18 | from distributed_faiss.index_cfg import IndexCfg 19 | from distributed_faiss.server import IndexServer 20 | from distributed_faiss.client import IndexClient 21 | from distributed_faiss.index_state import IndexState 22 | 23 | import numpy as np 24 | 25 | REPO_HOME = Path(".").parent 26 | 27 | 28 | def get_rand_meta(ndocs, nchars) -> List[str]: 29 | """Return a random string of length n""" 30 | return [ 31 | "".join(random.choices(string.ascii_uppercase + string.digits, k=nchars)) 32 | for _ in range(ndocs) 33 | ] 34 | 35 | 36 | def load_batched_data( 37 | client, 38 | index_id="lang_en", 39 | max_num_docs_per_batch=100, 40 | embed_dim=512, 41 | meta_length=5, 42 | num_batches=10, 43 | ): 44 | for i in range(num_batches): 45 | num_docs_per_batch = random.randint(1, max_num_docs_per_batch) 46 | embeddings = torch.rand(num_docs_per_batch, embed_dim).numpy() 47 | meta = get_rand_meta(num_docs_per_batch, meta_length) 48 | client.add_index_data(index_id, embeddings, meta, train_async_if_triggered=False) 49 | 50 | 51 | class TestIntegration(unittest.TestCase): 52 | # TODO: add mock tests for all the async calls 53 | 54 | @classmethod 55 | def setUpClass(self): 56 | self.multi_server_save_dir = tempfile.TemporaryDirectory() 57 | self.single_server_save_dir = tempfile.TemporaryDirectory() 58 | self.embed_dim = 512 59 | self.index_id = 0 60 | 61 | num_servers = 4 62 | self.multi_servers = [] 63 | self.multi_ports = [1237, 1238, 1239, 1240] 64 | random.seed(0) 65 | for server_id, port in enumerate(self.multi_ports): 66 | server = IndexServer(server_id, index_storage_dir=self.multi_server_save_dir.name) 67 | _thread.start_new_thread(server.start_blocking, (port,)) 68 | self.multi_servers.append(server) 69 | 70 | self.single_server_port = 1241 71 | self.single_server = IndexServer(0, index_storage_dir=self.single_server_save_dir.name) 72 | _thread.start_new_thread(self.single_server.start_blocking, (self.single_server_port,)) 73 | print("Done setting up test") 74 | 75 | @classmethod 76 | def tearDownClass(self): 77 | [s.stop() for s in self.multi_servers] 78 | self.multi_server_save_dir.cleanup() 79 | self.single_server_save_dir.cleanup() 80 | print("Done tearing down test") 81 | 82 | def setUp(self): 83 | # get unique index_id for each test 84 | self.index_id = self._testMethodName 85 | 86 | def start_server(self, port): 87 | single_server = IndexServer(0, index_storage_dir="test_tmp") 88 | _thread.start_new_thread(single_server.start_blocking, (port,)) 89 | self.servers.append(single_server) 90 | # TODO: figure out how to close 91 | 92 | @staticmethod 93 | def make_client(server_port): 94 | with tempfile.NamedTemporaryFile() as fp: 95 | fp.write(f"{1}\n".encode()) 96 | fp.write(f"localhost,{server_port}\n".encode()) 97 | fp.seek(0) 98 | fp.read() 99 | single_client = IndexClient(fp.name) 100 | return single_client 101 | 102 | @staticmethod 103 | def make_clients(num_clients, server_ports): 104 | clients = [] 105 | num_servers = len(server_ports) 106 | with tempfile.NamedTemporaryFile() as fp: 107 | fp.write(f"{num_servers}\n".encode()) 108 | for i in range(num_servers): 109 | fp.write(f"localhost,{server_ports[i]}\n".encode()) 110 | fp.seek(0) 111 | fp.read() 112 | for i in range(num_clients): 113 | client = IndexClient(fp.name) 114 | clients.append(client) 115 | return clients 116 | 
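    # The temporary discovery files written by make_client/make_clients follow a simple
    # format: the first line holds the number of servers, followed by one "host,port"
    # entry per line. For example (hosts and ports illustrative):
    #
    #   2
    #   localhost,1237
    #   localhost,1238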
117 | def test_train_num_honored(self): 118 | train_num = 10 119 | not_trained = IndexState.NOT_TRAINED 120 | cfg = IndexCfg(index_builder_type="flat", dim=self.embed_dim, train_num=train_num) 121 | client = self.make_client(self.single_server_port) 122 | client.create_index(self.index_id, cfg) 123 | 124 | def add_data(ndoc): 125 | embeddings = torch.rand(ndoc, 512).numpy() 126 | meta = get_rand_meta(ndoc, 5) 127 | client.add_index_data(self.index_id, embeddings, meta, train_async_if_triggered=False) 128 | return client.get_state(self.index_id) 129 | 130 | state = add_data(train_num - 1) 131 | assert state == not_trained, f"{state} != {not_trained} after only 9 docs" 132 | state = add_data(1) 133 | assert state != not_trained, f"{state} == {not_trained} after train_num added" 134 | 135 | results = client.search(torch.rand(4, 512).numpy(), 4, self.index_id) 136 | assert results[0].shape == (4, 4) 137 | client.save_index(self.index_id) 138 | results = client.search(torch.rand(4, 512).numpy(), 4, self.index_id) 139 | assert results[0].shape == (4, 4) 140 | client.close() 141 | # Test save/load behavior 142 | client2 = self.make_client(self.single_server_port) 143 | client2.load_index(self.index_id, cfg) 144 | assert client2.get_state(self.index_id) == IndexState.TRAINED 145 | results = client2.search(torch.rand(4, 512).numpy(), 4, self.index_id) 146 | client2.close() 147 | 148 | def test_l2_dist(self): 149 | train_num = 10 150 | cfg = IndexCfg( 151 | index_builder_type="flat", 152 | dim=self.embed_dim, 153 | metric="l2", 154 | centroids=4, 155 | infer_centroids=False, 156 | ) 157 | 158 | client = self.make_client(self.single_server_port) 159 | client.create_index(self.index_id, cfg) 160 | 161 | def add_data(ndoc): 162 | embeddings = torch.rand(ndoc, 512).numpy() 163 | meta = get_rand_meta(ndoc, 5) 164 | client.add_index_data(self.index_id, embeddings, meta, train_async_if_triggered=False) 165 | client.sync_train(self.index_id) 166 | return client.get_state(self.index_id) 167 | 168 | state = add_data(train_num) 169 | while True: 170 | state = client.get_state(self.index_id) 171 | print("Server state ", state) 172 | if state == IndexState.TRAINED: 173 | break 174 | time.sleep(2) 175 | assert state == IndexState.TRAINED 176 | _ = client.search(torch.rand(4, 512).numpy(), 4, self.index_id) 177 | 178 | # TODO: check results differ if metric='dot' 179 | client.close() 180 | 181 | def test_result_aggregation(self): 182 | 183 | mock_results = [ 184 | ( 185 | np.array([[12.1, 13.2, 13.3, 14.3]], dtype=np.float32), 186 | [[1465, 1460, 443197, 1340]], 187 | None, 188 | ), 189 | ( 190 | np.array([[8.1, 12.6, 13.1, 17.4]], dtype=np.float32), 191 | [[0, 14, 3, 1]], 192 | None, 193 | ), 194 | ] 195 | D, i_minimize = IndexClient._aggregate_results(mock_results, 4, 1, False, False) 196 | _, i_maximize = IndexClient._aggregate_results(mock_results, 4, 1, True, False) 197 | assert i_maximize != i_minimize 198 | assert i_minimize[0][0] == 0 # The smallest distance 199 | assert D[0][0] < D[0][1] 200 | 201 | assert i_maximize[0][0] == 1 # the largest distance 202 | 203 | assert 0 in i_minimize[0] 204 | 205 | def test_search_quality_same_for_multiple_clients(self): 206 | # start single flat index as souce of truth 207 | embed_dim = 512 208 | num_docs_per_query = 16 209 | num_batches = 4 210 | topk_per_search = 5 211 | meta_length = 5 212 | cfg = IndexCfg(index_builder_type="flat", dim=embed_dim) 213 | 214 | single_client = self.make_client(self.single_server_port) 215 | single_client.create_index(self.index_id, 
cfg) 216 | 217 | clients = self.make_clients(4, self.multi_ports) 218 | 219 | self.assertEqual(single_client.get_state(self.index_id), IndexState.NOT_TRAINED) 220 | 221 | for client in clients: 222 | client.create_index(self.index_id, cfg) 223 | self.assertEqual(client.get_state(self.index_id), IndexState.NOT_TRAINED) 224 | num_batches = random.randint(1, 4) 225 | for _ in range(num_batches): 226 | num_docs_per_batch = random.randint(1, 12800) 227 | embeddings = torch.rand(num_docs_per_batch, embed_dim).numpy() 228 | meta = get_rand_meta(num_docs_per_batch, meta_length) 229 | client.add_index_data( 230 | self.index_id, embeddings, meta, train_async_if_triggered=False 231 | ) 232 | single_client.add_index_data( 233 | self.index_id, embeddings, meta, train_async_if_triggered=False 234 | ) 235 | self.assertEqual(client.get_state(self.index_id), IndexState.NOT_TRAINED) 236 | 237 | # we added training data but did not start training yet 238 | self.assertEqual(client.get_state(self.index_id), IndexState.NOT_TRAINED) 239 | clients[0].sync_train(self.index_id) 240 | single_client.sync_train(self.index_id) 241 | 242 | while True: 243 | state = clients[0].get_state(self.index_id) 244 | print("Server state ", state) 245 | if state == IndexState.TRAINED: 246 | break 247 | time.sleep(2) 248 | 249 | while True: 250 | state = single_client.get_state(self.index_id) 251 | print("Server state ", state) 252 | if state == IndexState.TRAINED: 253 | break 254 | time.sleep(2) 255 | 256 | self.assertEqual( 257 | clients[0].get_ntotal(self.index_id), single_client.get_ntotal(self.index_id) 258 | ) 259 | query = torch.rand(num_docs_per_query, embed_dim).numpy() 260 | scores_aggr, meta_aggr = clients[0].search(query, topk_per_search, self.index_id) 261 | scores_single, meta_single = single_client.search(query, topk_per_search, self.index_id) 262 | self.assertTrue((scores_aggr == scores_single).all()) 263 | self.assertEqual(len(meta_aggr), len(meta_single)) 264 | self.assertEqual(meta_aggr, meta_single) 265 | single_client.close() 266 | 267 | def test_index_client_multiple_server(self): 268 | embed_dim = 512 269 | num_docs_per_batch = 12800 270 | num_docs_per_query = 16 271 | num_batches = 4 272 | topk_per_search = 5 273 | meta_length = 5 274 | 275 | clients = self.make_clients(4, self.multi_ports) 276 | 277 | cfg = IndexCfg("flat", dim=embed_dim) 278 | assert cfg.index_builder_type == "flat" 279 | 280 | for client in clients: 281 | client.create_index(self.index_id, cfg) 282 | self.assertEqual(client.get_state(self.index_id), IndexState.NOT_TRAINED) 283 | for i in range(num_batches): 284 | embeddings = torch.rand(num_docs_per_batch, embed_dim).numpy() 285 | meta = get_rand_meta(num_docs_per_batch, meta_length) 286 | client.add_index_data( 287 | self.index_id, embeddings, meta, train_async_if_triggered=False 288 | ) 289 | # we added training data but did not start training yet 290 | self.assertEqual(client.get_state(self.index_id), IndexState.NOT_TRAINED) 291 | 292 | clients[0].sync_train(self.index_id) 293 | for client in clients: 294 | for i in range(num_batches): 295 | embeddings = torch.rand(num_docs_per_batch, embed_dim).numpy() 296 | meta = get_rand_meta(num_docs_per_batch, meta_length) 297 | client.add_index_data( 298 | self.index_id, embeddings, meta, train_async_if_triggered=False 299 | ) 300 | for client in clients: 301 | while True: 302 | state = client.get_state(self.index_id) 303 | print("Server state ", state) 304 | if state == IndexState.TRAINED: 305 | break 306 | time.sleep(2) 307 | 308 | for server 
in self.multi_servers: 309 | # Make sure that data is ballanced among servers 310 | self.assertEqual( 311 | server.get_ntotal(self.index_id), 312 | 2 * num_batches * num_docs_per_batch * len(clients) / len(self.multi_servers), 313 | ) 314 | for client in clients: 315 | self.assertEqual(client.get_state(self.index_id), IndexState.TRAINED) 316 | self.assertEqual( 317 | client.get_ntotal(self.index_id), 318 | 2 * num_batches * num_docs_per_batch * len(clients), 319 | ) 320 | self.assertEqual(client.get_ntotal("wrong_id"), 0) 321 | query = torch.rand(num_docs_per_query, embed_dim).numpy() 322 | scores, meta = client.search(query, topk_per_search, self.index_id) 323 | self.assertEqual((num_docs_per_query, topk_per_search), scores.shape) 324 | self.assertEqual(num_docs_per_query, len(meta)) 325 | self.assertEqual(topk_per_search, len(meta[0])) 326 | 327 | clients[0].save_index(self.index_id) 328 | clients[0].drop_index(self.index_id) 329 | for client in clients: 330 | self.assertEqual(client.get_ntotal(self.index_id), 0) 331 | 332 | def test_config_to_file(self): 333 | train_num = 13 334 | nprobe = 16 335 | cfg = IndexCfg( 336 | index_builder_type="flat", dim=self.embed_dim, train_num=train_num, nprobe=nprobe 337 | ) 338 | 339 | client = self.make_client(self.single_server_port) 340 | client.create_index(self.index_id, cfg) 341 | 342 | def add_data(ndoc): 343 | embeddings = torch.rand(ndoc, 512).numpy() 344 | meta = get_rand_meta(ndoc, 5) 345 | client.add_index_data(self.index_id, embeddings, meta, train_async_if_triggered=False) 346 | 347 | add_data(train_num) 348 | while True: 349 | state = client.get_state(self.index_id) 350 | print("Server state ", state) 351 | if state == IndexState.TRAINED: 352 | break 353 | time.sleep(2) 354 | 355 | client.save_index(self.index_id) 356 | cfg_path = os.path.join(self.single_server_save_dir.name, self.index_id, "0", "cfg.json") 357 | print("cfg_path", cfg_path) 358 | assert os.path.isfile(cfg_path) 359 | 360 | client.close() 361 | 362 | # Load config from file 363 | client2 = self.make_client(self.single_server_port) 364 | client2.load_index(self.index_id) 365 | 366 | assert client2.cfg.index_builder_type == "flat" 367 | assert client2.cfg.dim == self.embed_dim 368 | assert client2.cfg.train_num == train_num 369 | assert client2.cfg.nprobe == nprobe 370 | client2.close() 371 | 372 | # Overwrite config at startup 373 | cfg3 = IndexCfg( 374 | index_builder_type="flat", 375 | dim=self.embed_dim, 376 | train_num=train_num, 377 | nprobe=nprobe + 1, 378 | ) 379 | client3 = self.make_client(self.single_server_port) 380 | client3.load_index(self.index_id, cfg3) 381 | assert client3.cfg.index_builder_type == "flat" 382 | assert client3.cfg.dim == self.embed_dim 383 | assert client3.cfg.train_num == train_num 384 | assert client3.cfg.nprobe == nprobe + 1 385 | client3.close() 386 | 387 | def test_get_centroids(self): 388 | index_id = "ivf_test" 389 | train_num = 10 390 | centroids = 2 391 | cfg = IndexCfg( 392 | index_builder_type="ivf_simple", 393 | dim=self.embed_dim, 394 | train_num=train_num, 395 | centroids=centroids, 396 | ) 397 | client = self.make_client(self.single_server_port) 398 | client.create_index(index_id, cfg) 399 | 400 | def add_data(ndoc): 401 | embeddings = torch.rand(ndoc, self.embed_dim).numpy() 402 | meta = get_rand_meta(ndoc, 5) 403 | client.add_index_data(index_id, embeddings, meta, train_async_if_triggered=False) 404 | return client.get_state(index_id) 405 | 406 | state = add_data(train_num) 407 | 408 | while True: 409 | print("Server state 
", state) 410 | if state == IndexState.TRAINED: 411 | break 412 | time.sleep(2) 413 | state = client.get_state(index_id) 414 | 415 | assert client.get_centroids(index_id)[0].shape == (centroids, self.embed_dim) 416 | client.close() 417 | 418 | 419 | class TestIndexCfg(unittest.TestCase): 420 | def test_from_index_cfg_from_json(self): 421 | IndexCfg.from_json(REPO_HOME.joinpath("tests/test_index_config.json")) 422 | 423 | 424 | if __name__ == "__main__": 425 | unittest.main() 426 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 
51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /distributed_faiss/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import _thread 8 | import logging 9 | import math 10 | import os 11 | import pickle 12 | import threading 13 | import time 14 | from typing import Tuple, Optional, List, Union 15 | 16 | import faiss 17 | import numpy as np 18 | 19 | from distributed_faiss.index_cfg import IndexCfg 20 | from distributed_faiss.index_state import IndexState 21 | 22 | logger = logging.getLogger() 23 | 24 | 25 | def get_quantizer(cfg: IndexCfg): 26 | metric = cfg.get_metric() 27 | if metric == faiss.METRIC_INNER_PRODUCT: 28 | quantizer = faiss.IndexFlatIP(cfg.dim) 29 | elif metric == faiss.METRIC_L2: 30 | quantizer = faiss.IndexFlatL2(cfg.dim) 31 | else: 32 | raise RuntimeError(f"Metric={metric} is not supported") 33 | return quantizer 34 | 35 | 36 | def init_faiss_ivf_simple(cfg: IndexCfg): 37 | metric = cfg.get_metric() 38 | index = faiss.IndexIVFFlat(get_quantizer(cfg), cfg.dim, cfg.centroids, metric) 39 | index.nprobe = cfg.nprobe 40 | return index 41 | 42 | 43 | def init_faiss_knnlm(cfg: IndexCfg): 44 | code_size = cfg.extra.get("code_size", 64) 45 | bits_per_vector = cfg.extra.get("bits_per_vector", 8) 46 | index = faiss.IndexIVFPQ(get_quantizer(cfg), cfg.dim, cfg.centroids, code_size, bits_per_vector) 47 | cfg.nprobe = index.nprobe 48 | return index 49 | 50 | 51 | def init_faiss_hnswsq(cfg: IndexCfg): 52 | assert cfg.get_metric() == faiss.METRIC_L2, "hnsw is supposed to work with L2 sim space" 53 | index = faiss.IndexHNSWSQ( 54 | cfg.dim, 55 | faiss.ScalarQuantizer.QT_8bit, 56 | cfg.extra.get("store_n", 128), 57 | ) 58 | index.hnsw.efSearch = cfg.nprobe 59 | index.hnsw.efConstruction = cfg.extra.get("ef_construction", 100) 60 | return index 61 | 62 | 63 | def init_faiss_ivf_scalar_qr(cfg: IndexCfg): 64 | index = faiss.IndexIVFScalarQuantizer( 65 | get_quantizer(cfg), cfg.dim, cfg.centroids, faiss.ScalarQuantizer.QT_fp16 66 | ) 67 | index.nprobe = cfg.nprobe 68 | return index 69 | 70 | 71 | def init_faiss_ivf_gpu(cfg: IndexCfg): 72 | # TODO: implement the logic of selecting specific gpus from cfg 73 | index = faiss.index_factory(cfg.dim, cfg.faiss_factory) 74 | metric = cfg.get_metric() 75 | if metric == faiss.METRIC_INNER_PRODUCT: 76 | quantizer = faiss.IndexFlatIP(cfg.dim) 77 | elif metric == faiss.METRIC_L2: 78 | quantizer = faiss.IndexFlatL2(cfg.dim) 79 | else: 80 | raise RuntimeError(f"Metric={metric} is not supported for ivf_gpu factory") 81 | 82 | index_ivf = faiss.extract_index_ivf(index) 83 | clustering_index = faiss.index_cpu_to_all_gpus(quantizer) 84 | index_ivf.clustering_index = clustering_index 85 | index.nprobe = cfg.nprobe 86 | return index 87 | 88 | 89 | def init_flat_index(cfg: IndexCfg): 90 | return get_quantizer(cfg) 91 | 92 | 93 | faiss_special_index_factories = { 94 | "flat": lambda cfg: faiss.IndexFlatIP(cfg.dim), 95 | "ivf_simple": init_faiss_ivf_simple, 96 | "knnlm": init_faiss_knnlm, 97 | "hnswsq": init_faiss_hnswsq, 98 | "ivfsq": init_faiss_ivf_scalar_qr, 99 | "ivf_gpu": init_faiss_ivf_gpu, 100 | } 101 | 102 | 103 | def get_index_files(index_storage_dir: str) -> Tuple[str, str, str, str]: 104 | index_file = os.path.join(index_storage_dir, "index.faiss") 105 | meta_file = os.path.join(index_storage_dir, "meta.pkl") 106 | buffer_file = os.path.join(index_storage_dir, "buffer.pkl") 107 | cfg_file = os.path.join(index_storage_dir, "cfg.json") 108 | return index_file, meta_file, buffer_file, cfg_file 109 | 110 | 111 | class Index: 112 | def 
__init__(self, cfg: IndexCfg): 113 | self.cfg = cfg 114 | self.embeddings_buffer = [] 115 | self.total_data = 0 116 | self.id_to_metadata = [] 117 | self.buffer_lock = threading.Lock() 118 | self.index_lock = threading.Lock() 119 | self.state = IndexState.NOT_TRAINED 120 | self.faiss_index = None 121 | 122 | self.index_save_time = time.time() 123 | self.index_saved_size = 0 124 | 125 | if cfg.save_interval_sec > 0: 126 | self._run_save_watcher() 127 | 128 | def drop_index(self): 129 | with self.buffer_lock: 130 | self.embeddings_buffer = [] 131 | self.total_data = 0 132 | self.id_to_metadata = [] 133 | 134 | with self.index_lock: 135 | self.faiss_index = None 136 | self.state = IndexState.NOT_TRAINED 137 | 138 | def add_batch( 139 | self, 140 | embeddings: np.array, 141 | metadata: Optional[List[object]], 142 | train_async_if_triggered: bool = True, 143 | ): 144 | embeddings_num = embeddings.shape[0] 145 | if not metadata: 146 | metadata = [None] * embeddings_num 147 | if embeddings_num != len(metadata): 148 | raise RuntimeError("metadata length should match the batch size of the embeddings") 149 | 150 | # TODO: check why HNSW-SQ doesn't work without it 151 | embeddings = embeddings.astype(np.float32) 152 | 153 | with self.buffer_lock: 154 | self.embeddings_buffer.append(embeddings) 155 | self.id_to_metadata.extend(metadata) 156 | self.total_data += embeddings_num 157 | total_data = self.total_data 158 | 159 | logger.info(f"The size of the buffer is {total_data}") 160 | 161 | state = self.get_state() 162 | if state == IndexState.TRAINED and total_data >= 0: 163 | self.add_buffer_to_index() # TODO: revise to avoid double state check here & inside add_buffer_to_index 164 | elif state == IndexState.NOT_TRAINED and 0 < self.cfg.train_num <= total_data: 165 | # trigger training 166 | logger.info(f"The size of the buffer is {total_data}, can start index training.") 167 | if state in [IndexState.TRAINING, IndexState.TRAINED, IndexState.ADD]: 168 | logger.info(f"Index state: {state}, skip training start") 169 | return 170 | 171 | if train_async_if_triggered: 172 | logger.info(f"Starting index training in a new thread") 173 | _thread.start_new_thread(self.train, ()) 174 | else: 175 | logger.info(f"Starting index training sync") 176 | self.train() 177 | 178 | def get_idx_data_num(self) -> Tuple[int, int]: 179 | with self.buffer_lock: 180 | buf_total = self.total_data 181 | index_total = 0 182 | with self.index_lock: 183 | if self.faiss_index: 184 | index_total = self.faiss_index.ntotal 185 | return buf_total, index_total 186 | 187 | def train(self) -> None: 188 | with self.index_lock: 189 | if self.state in [IndexState.TRAINING, IndexState.TRAINED, IndexState.ADD]: 190 | return 191 | self.state = IndexState.TRAINING 192 | cfg = self.cfg 193 | 194 | with self.buffer_lock: 195 | embeddings = self.embeddings_buffer 196 | dim = cfg.dim 197 | if dim == 0: # guess 198 | dim = embeddings[0].shape[1] 199 | cfg.dim = dim 200 | 201 | # get only part of the buffer sufficient for training 202 | if cfg.train_num > 0: 203 | train_num = cfg.train_num 204 | elif cfg.train_ratio >= 1.0: 205 | train_num = self.total_data 206 | else: 207 | train_num = int(cfg.train_ratio * self.total_data) 208 | all_data_as_np_array = np.concatenate(embeddings, axis=0) 209 | 210 | train_data = all_data_as_np_array[:train_num] 211 | np.random.shuffle(train_data) 212 | total_data_size = all_data_as_np_array.shape[0] 213 | index = self._init_faiss_index(total_data_size) 214 | 215 | logging.info(f"Created a faiss index of type 
{type(index)}") 216 | logger.info(f"Training index with array shaped {train_data.shape}") 217 | index.train(train_data) 218 | logger.info(f"Index trained") 219 | 220 | with self.index_lock: 221 | self.faiss_index = index 222 | self.state = IndexState.TRAINED 223 | self.add_buffer_to_index() 224 | 225 | def add_buffer_to_index(self) -> None: 226 | add_to_index = False 227 | with self.index_lock: 228 | if self.state == IndexState.TRAINED: 229 | add_to_index = True 230 | self.state = IndexState.ADD 231 | logging.info("Index is trained, adding data from buffer") 232 | else: 233 | logging.info("Index add is already in progress") 234 | 235 | if add_to_index: 236 | # new vectors addition happens in a separate thread to make this method non blocking and letting next 237 | # client's batches go to different server nodes without waiting for add_buffer_to_index completion 238 | _thread.start_new_thread(self._add_buffer_to_idx, ()) 239 | 240 | # TODO: overload to get faiss indexes back, not metadata 241 | def search( 242 | self, query_batch: np.array, top_k: int = 100, return_embeddings: bool = False 243 | ) -> Tuple[np.array, List[List[object]], Optional[np.array]]: 244 | logger.info(f"Searching index, queries {query_batch.shape}") 245 | 246 | with self.index_lock: 247 | if self.state != IndexState.TRAINED: 248 | raise RuntimeError(f"Server index is not trained. state: {self.state}") 249 | # Locking index for search operations 250 | # from https://github.com/facebookresearch/faiss/wiki/Threads-and-asynchronous-calls#performance-of-search: 251 | # "However it is very inefficient to call batches of queries from multiple threads, this will spawn more threads than cores." 252 | # TODO: avoid locking for single query calls (or small batches?) 253 | 254 | if return_embeddings: 255 | scores, indexes, embs = self.faiss_index.search_and_reconstruct(query_batch, top_k) 256 | else: 257 | scores, indexes = self.faiss_index.search(query_batch, top_k) 258 | embs = None 259 | 260 | top_k, n = indexes.shape 261 | with self.buffer_lock: 262 | results_meta = [ 263 | [ 264 | self.id_to_metadata[indexes[i, j]] if indexes[i, j] != -1 else None 265 | for j in range(n) 266 | ] 267 | for i in range(top_k) 268 | ] 269 | logger.info(f"Search completed, results_meta len: {len(results_meta)}") 270 | return scores, results_meta, embs 271 | 272 | def save(self) -> bool: 273 | state = self.get_state() 274 | if state == IndexState.TRAINED: 275 | return self._maybe_save(ignore_time=True) 276 | elif state == IndexState.ADD: 277 | # reset index_save_time to trigger save upon the completion of add batch 278 | logger.info("Index is in ADD state, clearing index_save_time") 279 | self.index_save_time = 0 280 | else: 281 | logger.info("Index is not trained, skip saving") 282 | return False 283 | 284 | @classmethod 285 | def from_storage_dir( 286 | cls, index_storage_dir: str, cfg: IndexCfg = None, ignore_buffer: bool = True 287 | ) -> Union[None, object]: 288 | logger.info("Deserializing index from %s", index_storage_dir) 289 | index_file, meta_file, buffer_file, cfg_file = get_index_files(index_storage_dir) 290 | 291 | if not os.path.exists(index_file): 292 | logger.info(f"No index found at {index_file}") 293 | return None 294 | else: 295 | logger.info(f"index_file={index_file} meta_file={meta_file} buffer_file={buffer_file}") 296 | 297 | index = faiss.read_index(index_file) 298 | logger.info("Loaded index of type %s and size %d", type(index), index.ntotal) 299 | 300 | if os.path.exists(meta_file): 301 | logger.info("Loading meta file 
from %s", meta_file) 302 | with open(meta_file, "rb") as reader: 303 | meta = pickle.load(reader) 304 | logger.info("metadata deserialized, size: %d", len(meta)) 305 | assert ( 306 | len(meta) >= index.ntotal 307 | ), "Deserialized meta list should be at least of faiss index size" 308 | else: 309 | raise RuntimeError("no meta file found. Can't use index.") 310 | 311 | buffer = [] 312 | if (not ignore_buffer) and buffer_file and os.path.exists(buffer_file): 313 | logger.info("Loading buffer from %s", buffer_file) 314 | with open(buffer_file, "rb") as reader: 315 | buffer = pickle.load(reader) 316 | 317 | if cfg is None: 318 | if os.path.isfile(cfg_file): 319 | cfg = IndexCfg.from_json(cfg_file) 320 | else: 321 | cfg = IndexCfg() 322 | 323 | logger.info(f"Meta length {len(meta)}") 324 | logger.info(f"Loaded index size {index.ntotal}") 325 | buffer_size = sum(v.shape[0] for v in buffer) 326 | logger.info(f"Buffer size {buffer_size}") 327 | 328 | result = cls(cfg) 329 | result.faiss_index = index 330 | result.state = IndexState.TRAINED 331 | result.upd_cfg(cfg) 332 | 333 | if len(meta) == index.ntotal + buffer_size: 334 | result.id_to_metadata = meta 335 | result.embeddings_buffer = buffer 336 | if buffer_size > 0: 337 | result.add_buffer_to_index() 338 | else: 339 | logger.warning( 340 | f"Metadata size doesn't match combined index+buffer size ({index.ntotal + buffer_size}): " 341 | f"ignoring buffer, reducing metadata to index size" 342 | ) 343 | result.id_to_metadata = meta[: index.ntotal] 344 | return result 345 | 346 | def get_centroids(self): 347 | with self.index_lock: 348 | if self.state != IndexState.TRAINED: 349 | raise RuntimeError("Server index is not trained") 350 | return self.faiss_index.quantizer.reconstruct_n(0, self.faiss_index.nlist) 351 | 352 | def set_nprobe(self, nprobe: int): 353 | self.cfg.nprobe = nprobe 354 | with self.index_lock: 355 | if self.faiss_index: 356 | self.faiss_index.nprobe = nprobe 357 | 358 | def get_state(self): 359 | with self.index_lock: 360 | return self.state 361 | 362 | def get_ids(self): 363 | id_idx = self.cfg.custom_meta_id_idx 364 | r = {meta[id_idx] for meta in self.id_to_metadata if meta} 365 | logger.info(f"Metadata Id set size {len(r)}") 366 | if self.faiss_index: 367 | with self.buffer_lock: 368 | with self.index_lock: 369 | buffer_size = sum(v.shape[0] for v in self.embeddings_buffer) 370 | if len(r) != self.faiss_index.ntotal + buffer_size: 371 | logger.warning( 372 | f"id set size mismatch, index+buffer={self.faiss_index.ntotal + buffer_size}" 373 | ) 374 | return r 375 | 376 | def upd_cfg(self, cfg: IndexCfg): 377 | self.cfg = cfg 378 | self._override_nprobe(cfg) 379 | 380 | def _init_faiss_index(self, total_data_size: int): 381 | cfg = self.cfg 382 | logger.info(f"Data size for indexing {total_data_size}") 383 | if cfg.index_builder_type: 384 | index = faiss_special_index_factories[cfg.index_builder_type](cfg) 385 | elif cfg.faiss_factory: 386 | cfg.centroids = int(cfg.centroids) 387 | if cfg.centroids == 0 or cfg.infer_centroids: 388 | cfg.centroids = self.infer_n_centroids(total_data_size) 389 | logger.info(f"Inferring cfg.centroids={cfg.centroids}") 390 | 391 | if "{centroids}" in cfg.faiss_factory: 392 | index_cfg_str = cfg.faiss_factory.format(centroids=cfg.centroids) 393 | else: 394 | index_cfg_str = cfg.faiss_factory 395 | logger.info(f"Using index factory: {index_cfg_str}") 396 | index = faiss.index_factory(cfg.dim, index_cfg_str, cfg.get_metric()) 397 | else: 398 | raise RuntimeError( 399 | "Either faiss_factory or valid 
    def get_centroids(self):
        with self.index_lock:
            if self.state != IndexState.TRAINED:
                raise RuntimeError("Server index is not trained")
            return self.faiss_index.quantizer.reconstruct_n(0, self.faiss_index.nlist)

    def set_nprobe(self, nprobe: int):
        self.cfg.nprobe = nprobe
        with self.index_lock:
            if self.faiss_index:
                self.faiss_index.nprobe = nprobe

    def get_state(self):
        with self.index_lock:
            return self.state

    def get_ids(self):
        id_idx = self.cfg.custom_meta_id_idx
        r = {meta[id_idx] for meta in self.id_to_metadata if meta}
        logger.info(f"Metadata id set size {len(r)}")
        if self.faiss_index:
            with self.buffer_lock:
                with self.index_lock:
                    buffer_size = sum(v.shape[0] for v in self.embeddings_buffer)
                    if len(r) != self.faiss_index.ntotal + buffer_size:
                        logger.warning(
                            f"Id set size mismatch, index+buffer={self.faiss_index.ntotal + buffer_size}"
                        )
        return r

    def upd_cfg(self, cfg: IndexCfg):
        self.cfg = cfg
        self._override_nprobe(cfg)

    def _init_faiss_index(self, total_data_size: int):
        cfg = self.cfg
        logger.info(f"Data size for indexing {total_data_size}")
        if cfg.index_builder_type:
            index = faiss_special_index_factories[cfg.index_builder_type](cfg)
        elif cfg.faiss_factory:
            cfg.centroids = int(cfg.centroids)
            if cfg.centroids == 0 or cfg.infer_centroids:
                cfg.centroids = self.infer_n_centroids(total_data_size)
                logger.info(f"Inferred cfg.centroids={cfg.centroids}")

            if "{centroids}" in cfg.faiss_factory:
                index_cfg_str = cfg.faiss_factory.format(centroids=cfg.centroids)
            else:
                index_cfg_str = cfg.faiss_factory
            logger.info(f"Using index factory: {index_cfg_str}")
            index = faiss.index_factory(cfg.dim, index_cfg_str, cfg.get_metric())
        else:
            raise RuntimeError(
                "Either faiss_factory or a valid index_builder_type must be specified to initialize the index"
            )
        return index
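    # Illustrative example of the factory-string expansion above (the concrete values are
    # assumptions for illustration): with cfg.faiss_factory = "IVF{centroids},SQ8",
    # dim = 1024 and centroids = 1000, the index would be built roughly as:
    #
    #   index_cfg_str = "IVF{centroids},SQ8".format(centroids=1000)   # -> "IVF1000,SQ8"
    #   index = faiss.index_factory(1024, index_cfg_str, faiss.METRIC_L2)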
overwriting save") 459 | 460 | faiss.write_index(self.faiss_index, index_file) 461 | with open(meta_file, mode="wb") as f: 462 | pickle.dump(self.id_to_metadata, f) 463 | logger.info(f"Saved index & meta to: {index_file} | {meta_file}") 464 | 465 | with open(buffer_file, mode="wb") as f: 466 | pickle.dump(self.embeddings_buffer, f) 467 | logger.info(f"Saved buffer to {buffer_file}") 468 | 469 | with open(cfg_file, mode="w") as f: 470 | f.write(self.cfg.to_json_string() + "\n") 471 | 472 | self.index_saved_size = self.faiss_index.ntotal 473 | self.index_save_time = time.time() 474 | return True 475 | 476 | def _run_save_watcher(self): 477 | def _save(idx: Index): 478 | logger.info("Started save watcher thread") 479 | while True: 480 | time.sleep(idx.cfg.save_interval_sec) 481 | logger.info("autosave attempt") 482 | saved = idx._maybe_save(ignore_time=False) 483 | logger.info(f"Index saved ={saved}") 484 | 485 | _thread.start_new_thread(_save, (self,)) 486 | 487 | def _override_nprobe(self, cfg: IndexCfg): 488 | logger.info("Overriding nprobe from cfg=%s", cfg) 489 | if not self.faiss_index: 490 | logger.warning("Faiss Index is not initialized") 491 | if hasattr(self.faiss_index, "hnsw"): 492 | self.faiss_index.hnsw.efSearch = cfg.nprobe 493 | logger.info("hnsw.efSearch=%s", self.faiss_index.hnsw.efSearch) 494 | else: 495 | self.faiss_index.nprobe = cfg.nprobe 496 | 497 | @staticmethod 498 | def infer_n_centroids(total_data_size): 499 | if total_data_size < 10e5: 500 | centroids = int(2 * math.sqrt(total_data_size)) 501 | # TODO: get 4-16 factor from cfg 502 | elif total_data_size < 10e6: 503 | centroids = 65536 504 | elif total_data_size < 10e7: 505 | centroids = 262144 506 | else: 507 | centroids = 1048576 508 | return centroids 509 | --------------------------------------------------------------------------------