├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .env.example ├── .github └── workflows │ ├── publish_package_on_release.yml │ └── pull_request.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── OWNERS ├── README.md ├── fig ├── custom_case_run_test.png └── custom_dataset.png ├── install.py ├── install └── requirements_py3.11.txt ├── pyproject.toml ├── tests ├── conftest.py ├── pytest.ini ├── test_bench_runner.py ├── test_chroma.py ├── test_data_source.py ├── test_dataset.py ├── test_elasticsearch_cloud.py ├── test_models.py ├── test_rate_runner.py ├── test_redis.py ├── test_utils.py └── ut_cases.py └── vectordb_bench ├── __init__.py ├── __main__.py ├── backend ├── __init__.py ├── assembler.py ├── cases.py ├── clients │ ├── __init__.py │ ├── aliyun_elasticsearch │ │ ├── aliyun_elasticsearch.py │ │ └── config.py │ ├── aliyun_opensearch │ │ ├── aliyun_opensearch.py │ │ └── config.py │ ├── alloydb │ │ ├── alloydb.py │ │ ├── cli.py │ │ └── config.py │ ├── api.py │ ├── aws_opensearch │ │ ├── aws_opensearch.py │ │ ├── cli.py │ │ ├── config.py │ │ └── run.py │ ├── chroma │ │ ├── chroma.py │ │ └── config.py │ ├── clickhouse │ │ ├── cli.py │ │ ├── clickhouse.py │ │ └── config.py │ ├── elastic_cloud │ │ ├── config.py │ │ └── elastic_cloud.py │ ├── lancedb │ │ ├── cli.py │ │ ├── config.py │ │ └── lancedb.py │ ├── mariadb │ │ ├── cli.py │ │ ├── config.py │ │ └── mariadb.py │ ├── memorydb │ │ ├── cli.py │ │ ├── config.py │ │ └── memorydb.py │ ├── milvus │ │ ├── cli.py │ │ ├── config.py │ │ └── milvus.py │ ├── mongodb │ │ ├── config.py │ │ └── mongodb.py │ ├── pgdiskann │ │ ├── cli.py │ │ ├── config.py │ │ └── pgdiskann.py │ ├── pgvecto_rs │ │ ├── cli.py │ │ ├── config.py │ │ └── pgvecto_rs.py │ ├── pgvector │ │ ├── cli.py │ │ ├── config.py │ │ └── pgvector.py │ ├── pgvectorscale │ │ ├── cli.py │ │ ├── config.py │ │ └── pgvectorscale.py │ ├── pinecone │ │ ├── config.py │ │ └── pinecone.py │ ├── qdrant_cloud │ │ ├── cli.py │ │ ├── config.py │ │ └── qdrant_cloud.py │ ├── redis │ │ ├── cli.py │ │ ├── config.py │ │ └── redis.py │ ├── test │ │ ├── cli.py │ │ ├── config.py │ │ └── test.py │ ├── tidb │ │ ├── cli.py │ │ ├── config.py │ │ └── tidb.py │ ├── vespa │ │ ├── cli.py │ │ ├── config.py │ │ ├── util.py │ │ └── vespa.py │ ├── weaviate_cloud │ │ ├── cli.py │ │ ├── config.py │ │ └── weaviate_cloud.py │ └── zilliz_cloud │ │ ├── cli.py │ │ ├── config.py │ │ └── zilliz_cloud.py ├── data_source.py ├── dataset.py ├── result_collector.py ├── runner │ ├── __init__.py │ ├── mp_runner.py │ ├── rate_runner.py │ ├── read_write_runner.py │ ├── serial_runner.py │ └── util.py ├── task_runner.py └── utils.py ├── base.py ├── cli ├── __init__.py ├── cli.py └── vectordbbench.py ├── config-files └── sample_config.yml ├── custom └── custom_case.json ├── frontend ├── components │ ├── check_results │ │ ├── charts.py │ │ ├── data.py │ │ ├── expanderStyle.py │ │ ├── filters.py │ │ ├── footer.py │ │ ├── headerIcon.py │ │ ├── nav.py │ │ ├── priceTable.py │ │ └── stPageConfig.py │ ├── concurrent │ │ └── charts.py │ ├── custom │ │ ├── displayCustomCase.py │ │ ├── displaypPrams.py │ │ ├── getCustomConfig.py │ │ └── initStyle.py │ ├── get_results │ │ └── saveAsImage.py │ ├── run_test │ │ ├── autoRefresh.py │ │ ├── caseSelector.py │ │ ├── dbConfigSetting.py │ │ ├── dbSelector.py │ │ ├── generateTasks.py │ │ ├── hideSidebar.py │ │ ├── initStyle.py │ │ └── submitTask.py │ └── tables │ │ └── data.py ├── config │ ├── dbCaseConfigs.py │ ├── dbPrices.py │ └── styles.py ├── pages │ ├── concurrent.py │ ├── custom.py │ ├── 
quries_per_dollar.py │ ├── run_test.py │ └── tables.py ├── utils.py └── vdb_benchmark.py ├── interface.py ├── log_util.py ├── metric.py ├── models.py └── results ├── ElasticCloud ├── result_20230727_standard_elasticcloud.json └── result_20230808_standard_elasticcloud.json ├── Milvus ├── result_20230727_standard_milvus.json └── result_20230808_standard_milvus.json ├── PgVector ├── result_20230727_standard_pgvector.json └── result_20230808_standard_pgvector.json ├── Pinecone ├── result_20230727_standard_pinecone.json └── result_20230808_standard_pinecone.json ├── QdrantCloud ├── result_20230727_standard_qdrantcloud.json └── result_20230808_standard_qdrantcloud.json ├── WeaviateCloud ├── result_20230727_standard_weaviatecloud.json └── result_20230808_standard_weaviatecloud.json ├── ZillizCloud ├── result_20230727_standard_zillizcloud.json ├── result_20230808_standard_zillizcloud.json └── result_20240105_standard_202401_zillizcloud.json ├── dbPrices.json ├── getLeaderboardData.py └── leaderboard.json /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-buster as builder-image 2 | 3 | RUN apt-get update 4 | 5 | COPY ../install/requirements_py3.11.txt . 6 | RUN pip3 install -U pip 7 | RUN pip3 install --no-cache-dir -r requirements_py3.11.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | 9 | WORKDIR /opt/code 10 | ENV PYTHONPATH /opt/code/VectorDBBench 11 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile 3 | { 4 | "name": "VectorDBBench dev container", 5 | "build": { 6 | // Sets the run context to one level up instead of the .devcontainer folder. 7 | "context": "..", 8 | // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. 9 | "dockerfile": "./Dockerfile" 10 | }, 11 | "runArgs": [ 12 | "--privileged", 13 | "--cap-add=SYS_PTRACE" 14 | ], 15 | "mounts": [ 16 | // You have to make sure source directory is avaliable on your host file system. 17 | "source=${localEnv:HOME}/vectordb_bench/dataset,target=/tmp/vectordb_bench/dataset,type=bind,consistency=cached" 18 | ], 19 | "workspaceMount": "source=${localWorkspaceFolder},target=/opt/code/VectorDBBench,type=bind,consistency=cached", 20 | "workspaceFolder": "/opt/code/VectorDBBench", 21 | 22 | // Features to add to the dev container. More info: https://containers.dev/features. 23 | // "features": {}, 24 | 25 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 26 | "forwardPorts": [ 27 | 8501 28 | ], 29 | 30 | // Uncomment the next line to run commands after the container is created. 31 | // "postCreateCommand": "cat /etc/os-release", 32 | 33 | // Configure tool-specific properties. 34 | "customizations": { 35 | "vscode": { 36 | "extensions": [ 37 | "eamodio.gitlens", 38 | "ms-python.python", 39 | "ms-python.debugpy", 40 | "ms-azuretools.vscode-docker" 41 | ] 42 | } 43 | } 44 | 45 | // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. 
46 | // "remoteUser": "devcontainer" 47 | } 48 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # LOG_LEVEL= 2 | # LOG_PATH= 3 | # LOG_NAME= 4 | # TIMEZONE= 5 | 6 | # NUM_PER_BATCH= 7 | # DEFAULT_DATASET_URL= 8 | 9 | DATASET_LOCAL_DIR="/tmp/vectordb_bench/dataset" 10 | 11 | # DROP_OLD = True 12 | -------------------------------------------------------------------------------- /.github/workflows/publish_package_on_release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to TestPyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build-n-publish: 9 | name: Build and publish Python 🐍 distributions 📦 to PyPI 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Check out from Git 14 | uses: actions/checkout@v3 15 | - name: Get history and tags for SCM versioning 16 | run: | 17 | git fetch --prune --unshallow 18 | git fetch --depth=1 origin +refs/tags/*:refs/tags/* 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: 3.11 23 | - name: Install pypa/build 24 | run: >- 25 | python -m 26 | pip install 27 | build 28 | --user 29 | - name: Build a binary wheel and a source tarball 30 | run: >- 31 | python -m 32 | build 33 | --sdist 34 | --wheel 35 | --outdir dist/ 36 | . 37 | - name: Publish distribution 📦 to Test PyPI 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | with: 40 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 41 | repository-url: https://test.pypi.org/legacy/ 42 | - name: Publish distribution 📦 to PyPI 43 | if: startsWith(github.ref, 'refs/tags') 44 | uses: pypa/gh-action-pypi-publish@release/v1 45 | with: 46 | password: ${{ secrets.PYPI_API_TOKEN }} 47 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | name: Test on pull request 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | name: Run Python Tests 11 | strategy: 12 | matrix: 13 | python-version: [3.11, 3.12] 14 | os: [ubuntu-latest, windows-latest] 15 | runs-on: ${{ matrix.os }} 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | - name: Setup Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Fetch tags 26 | run: | 27 | git fetch --prune --unshallow --tags 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install -e ".[test]" 33 | 34 | - name: Run coding checks 35 | run: | 36 | make lint 37 | 38 | - name: Test with pytest 39 | run: | 40 | make unittest 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw[op] 2 | *.egg-info 3 | dist/ 4 | __pycache__ 5 | .env 6 | .data/ 7 | __MACOSX 8 | .DS_Store 9 | build/ 10 | venv/ 11 | .venv/ 12 | .idea/ 13 | results/ 14 | logs/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-buster as builder-image 2 | 3 | RUN apt-get update 4 | 5 | COPY install/requirements_py3.11.txt . 
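# The pip runs below first upgrade pip, then install requirements_py3.11.txt from
# the PyPI mirror given by -i (pypi.tuna.tsinghua.edu.cn); the resulting
# site-packages are copied into the slim runtime stage further down.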
6 | RUN pip3 install -U pip 7 | RUN pip3 install --no-cache-dir -r requirements_py3.11.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 8 | 9 | FROM python:3.11-slim-buster 10 | 11 | COPY --from=builder-image /usr/local/bin /usr/local/bin 12 | COPY --from=builder-image /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages 13 | 14 | WORKDIR /opt/code 15 | COPY . . 16 | ENV PYTHONPATH /opt/code 17 | 18 | ENTRYPOINT ["python3", "-m", "vectordb_bench"] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zilliztech 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | unittest: 2 | PYTHONPATH=`pwd` python3 -m pytest tests/test_dataset.py::TestDataSet::test_download_small -svv 3 | 4 | format: 5 | PYTHONPATH=`pwd` python3 -m black vectordb_bench 6 | PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench --fix 7 | 8 | lint: 9 | PYTHONPATH=`pwd` python3 -m black vectordb_bench --check 10 | PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench 11 | -------------------------------------------------------------------------------- /OWNERS: -------------------------------------------------------------------------------- 1 | filters: 2 | ".*": 3 | approvers: 4 | - XuanYang-cn 5 | reviewers: 6 | - XuanYang-cn 7 | -------------------------------------------------------------------------------- /fig/custom_case_run_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zilliztech/VectorDBBench/571de32b1cae8210f0ce5426981347add2d5c61d/fig/custom_case_run_test.png -------------------------------------------------------------------------------- /fig/custom_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zilliztech/VectorDBBench/571de32b1cae8210f0ce5426981347add2d5c61d/fig/custom_dataset.png -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | 6 | def docker_tag_base(): 7 | return 'vdbbench' 8 | 9 | def dockerfile_path_base(): 10 | return os.path.join('vectordb_bench/', '../Dockerfile') 11 | 12 | def docker_tag(track, algo): 13 | return docker_tag_base() + '-' + track + '-' + algo 14 | 15 | 16 | def build(tag, args, dockerfile): 17 | print('Building %s...' % tag) 18 | if args is not None and len(args) != 0: 19 | q = " ".join(["--build-arg " + x.replace(" ", "\\ ") for x in args]) 20 | else: 21 | q = "" 22 | 23 | try: 24 | command = 'docker build %s --rm -t %s -f' \ 25 | % (q, tag) 26 | command += ' %s .' % dockerfile 27 | print(command) 28 | subprocess.check_call(command, shell=True) 29 | return {tag: 'success'} 30 | except subprocess.CalledProcessError: 31 | return {tag: 'fail'} 32 | 33 | def build_multiprocess(args): 34 | return build(*args) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser( 39 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 40 | parser.add_argument( 41 | "--proc", 42 | default=1, 43 | type=int, 44 | help="the number of process to build docker images") 45 | parser.add_argument( 46 | '--track', 47 | choices=['none'], 48 | default='none' 49 | ) 50 | parser.add_argument( 51 | '--algorithm', 52 | metavar='NAME', 53 | help='build only the named algorithm image', 54 | default=None) 55 | parser.add_argument( 56 | '--dockerfile', 57 | metavar='PATH', 58 | help='build only the image from a Dockerfile path', 59 | default=None) 60 | parser.add_argument( 61 | '--build-arg', 62 | help='pass given args to all docker builds', 63 | nargs="+") 64 | args = parser.parse_args() 65 | 66 | print('Building base image...') 67 | 68 | subprocess.check_call( 69 | 'docker build \ 70 | --rm -t %s -f %s .' 
% (docker_tag_base(), dockerfile_path_base()), shell=True) 71 | 72 | print('Building end.') 73 | 74 | -------------------------------------------------------------------------------- /install/requirements_py3.11.txt: -------------------------------------------------------------------------------- 1 | grpcio==1.53.2 2 | grpcio-tools==1.53.0 3 | qdrant-client 4 | pinecone-client 5 | weaviate-client 6 | elasticsearch 7 | pgvector 8 | pgvecto_rs[psycopg3]>=0.2.1 9 | sqlalchemy 10 | redis 11 | chromadb 12 | pytz 13 | streamlit-autorefresh 14 | streamlit>=1.23.0 15 | streamlit_extras 16 | tqdm 17 | s3fs 18 | psutil 19 | polars 20 | plotly 21 | environs 22 | pydantic=int} 84 | with chrma.init(): 85 | filter_value = int(count * filter_value) 86 | test_id = np.random.randint(filter_value, count) 87 | q = embeddings[test_id] 88 | 89 | 90 | res = chrma.search_embedding( 91 | query=q, k=100, filters={"id": filter_value} 92 | ) 93 | assert ( 94 | res[0] == int(test_id) 95 | ), f"the most nearest neighbor ({res[0]}) id is not test_id ({test_id})" 96 | isFilter = True 97 | id_list = [] 98 | for id in res: 99 | id_list.append(id) 100 | if int(id) < filter_value: 101 | isFilter = False 102 | break 103 | assert isFilter, f"Filter not working, id_list: {id_list}" 104 | 105 | -------------------------------------------------------------------------------- /tests/test_data_source.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | from vectordb_bench.backend.data_source import DatasetSource 4 | from vectordb_bench.backend.cases import type2case 5 | 6 | log = logging.getLogger("vectordb_bench") 7 | 8 | class TestReader: 9 | @pytest.mark.parametrize("type_case", [ 10 | (k, v) for k, v in type2case.items() 11 | ]) 12 | def test_type_cases(self, type_case): 13 | self.per_case_test(type_case) 14 | 15 | 16 | def per_case_test(self, type_case): 17 | t, ca_cls = type_case 18 | ca = ca_cls() 19 | log.info(f"test case: {t.name}, {ca.name}") 20 | 21 | filters = ca.filter_rate 22 | ca.dataset.prepare(source=DatasetSource.AliyunOSS, filters=filters) 23 | ali_trains = ca.dataset.train_files 24 | 25 | ca.dataset.prepare(filters=filters) 26 | s3_trains = ca.dataset.train_files 27 | 28 | assert ali_trains == s3_trains 29 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.backend.dataset import Dataset 2 | import logging 3 | import pytest 4 | from pydantic import ValidationError 5 | from vectordb_bench.backend.data_source import DatasetSource 6 | 7 | 8 | log = logging.getLogger("vectordb_bench") 9 | 10 | class TestDataSet: 11 | def test_iter_dataset(self): 12 | for ds in Dataset: 13 | log.info(ds) 14 | 15 | def test_cohere(self): 16 | cohere = Dataset.COHERE.get(100_000) 17 | log.info(cohere) 18 | assert cohere.name == "Cohere" 19 | assert cohere.size == 100_000 20 | assert cohere.label == "SMALL" 21 | assert cohere.dim == 768 22 | 23 | def test_cohere_error(self): 24 | with pytest.raises(ValidationError): 25 | Dataset.COHERE.get(9999) 26 | 27 | def test_iter_cohere(self): 28 | cohere_10m = Dataset.COHERE.manager(10_000_000) 29 | cohere_10m.prepare() 30 | 31 | import time 32 | before = time.time() 33 | for i in cohere_10m: 34 | log.debug(i.head(1)) 35 | 36 | dur_iter = time.time() - before 37 | log.warning(f"iter through cohere_10m cost={dur_iter/60}min") 38 | 39 | # pytest -sv 
tests/test_dataset.py::TestDataSet::test_iter_laion 40 | def test_iter_laion(self): 41 | laion_100m = Dataset.LAION.manager(100_000_000) 42 | from vectordb_bench.backend.data_source import DatasetSource 43 | laion_100m.prepare(source=DatasetSource.AliyunOSS) 44 | 45 | import time 46 | before = time.time() 47 | for i in laion_100m: 48 | log.debug(i.head(1)) 49 | 50 | dur_iter = time.time() - before 51 | log.warning(f"iter through laion_100m cost={dur_iter/60}min") 52 | 53 | def test_download_small(self): 54 | openai_50k = Dataset.OPENAI.manager(50_000) 55 | files = [ 56 | "test.parquet", 57 | "neighbors.parquet", 58 | "neighbors_head_1p.parquet", 59 | "neighbors_tail_1p.parquet", 60 | ] 61 | 62 | file_path = openai_50k.data_dir.joinpath("test.parquet") 63 | import os 64 | 65 | DatasetSource.S3.reader().read( 66 | openai_50k.data.dir_name.lower(), 67 | files=files, 68 | local_ds_root=openai_50k.data_dir, 69 | ) 70 | 71 | os.remove(file_path) 72 | DatasetSource.AliyunOSS.reader().read( 73 | openai_50k.data.dir_name.lower(), 74 | files=files, 75 | local_ds_root=openai_50k.data_dir, 76 | ) 77 | 78 | -------------------------------------------------------------------------------- /tests/test_elasticsearch_cloud.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from vectordb_bench.models import ( 3 | DB, 4 | MetricType, 5 | ElasticsearchConfig, 6 | ) 7 | import numpy as np 8 | 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | cloud_id = "" 13 | password = "" 14 | 15 | 16 | class TestModels: 17 | def test_insert_and_search(self): 18 | assert DB.ElasticCloud.value == "Elasticsearch" 19 | assert DB.ElasticCloud.config == ElasticsearchConfig 20 | 21 | dbcls = DB.ElasticCloud.init_cls 22 | dbConfig = DB.ElasticCloud.config_cls(cloud_id=cloud_id, password=password) 23 | dbCaseConfig = DB.ElasticCloud.case_config_cls()( 24 | metric_type=MetricType.L2, efConstruction=64, M=16, num_candidates=100 25 | ) 26 | 27 | dim = 16 28 | es = dbcls( 29 | dim=dim, 30 | db_config=dbConfig.to_dict(), 31 | db_case_config=dbCaseConfig, 32 | indice="test_es_cloud", 33 | drop_old=True, 34 | ) 35 | 36 | count = 10_000 37 | filter_rate = 0.9 38 | embeddings = [[np.random.random() for _ in range(dim)] for _ in range(count)] 39 | 40 | # insert 41 | with es.init(): 42 | res = es.insert_embeddings(embeddings=embeddings, metadata=range(count)) 43 | # bulk_insert return 44 | assert ( 45 | res == count 46 | ), f"the return count of bulk insert ({res}) is not equal to count ({count})" 47 | 48 | # indice_count return 49 | es.client.indices.refresh() 50 | esCountRes = es.client.count(index=es.indice) 51 | countResCount = esCountRes.raw["count"] 52 | assert ( 53 | countResCount == count 54 | ), f"the return count of es client ({countResCount}) is not equal to count ({count})" 55 | 56 | # search 57 | with es.init(): 58 | test_id = np.random.randint(count) 59 | log.info(f"test_id: {test_id}") 60 | q = embeddings[test_id] 61 | 62 | res = es.search_embedding(query=q, k=100) 63 | log.info(f"search_results_id: {res}") 64 | assert ( 65 | res[0] == test_id 66 | ), f"the most nearest neighbor ({res[0]}) id is not test_id ({test_id})" 67 | 68 | # search with filters 69 | with es.init(): 70 | test_id = np.random.randint(count * filter_rate, count) 71 | log.info(f"test_id: {test_id}") 72 | q = embeddings[test_id] 73 | 74 | res = es.search_embedding( 75 | query=q, k=100, filters={"id": count * filter_rate} 76 | ) 77 | log.info(f"search_results_id: {res}") 78 | assert ( 79 | res[0] == 
test_id 80 | ), f"the most nearest neighbor ({res[0]}) id is not test_id ({test_id})" 81 | isFilter = True 82 | for id in res: 83 | if id < count * filter_rate: 84 | isFilter = False 85 | break 86 | assert isFilter, f"filters failed" 87 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | from vectordb_bench.models import ( 4 | TaskConfig, CaseConfig, 5 | CaseResult, TestResult, 6 | Metric, CaseType 7 | ) 8 | from vectordb_bench.backend.clients import ( 9 | DB, 10 | IndexType 11 | ) 12 | 13 | from vectordb_bench import config 14 | 15 | 16 | log = logging.getLogger("vectordb_bench") 17 | 18 | 19 | class TestModels: 20 | @pytest.mark.skip("runs locally") 21 | def test_test_result(self): 22 | result = CaseResult( 23 | task_config=TaskConfig( 24 | db=DB.Milvus, 25 | db_config=DB.Milvus.config(), 26 | db_case_config=DB.Milvus.case_config_cls(index=IndexType.Flat)(), 27 | case_config=CaseConfig(case_id=CaseType.Performance10M), 28 | ), 29 | metrics=Metric(), 30 | ) 31 | 32 | test_result = TestResult(run_id=10000, results=[result]) 33 | test_result.flush() 34 | 35 | with pytest.raises(ValueError): 36 | result = TestResult.read_file('nosuchfile.json') 37 | 38 | def test_test_result_read_write(self): 39 | result_dir = config.RESULTS_LOCAL_DIR 40 | for json_file in result_dir.rglob("result*.json"): 41 | res = TestResult.read_file(json_file) 42 | res.flush() 43 | 44 | def test_test_result_merge(self): 45 | result_dir = config.RESULTS_LOCAL_DIR 46 | all_results = [] 47 | 48 | first_result = None 49 | for json_file in result_dir.glob("*.json"): 50 | res = TestResult.read_file(json_file) 51 | 52 | for cr in res.results: 53 | all_results.append(cr) 54 | 55 | if not first_result: 56 | first_result = res 57 | 58 | tr = TestResult( 59 | run_id=first_result.run_id, 60 | task_label="standard", 61 | results=all_results, 62 | ) 63 | tr.flush() 64 | 65 | def test_test_result_display(self): 66 | result_dir = config.RESULTS_LOCAL_DIR 67 | for json_file in result_dir.rglob("result*.json"): 68 | log.info(json_file) 69 | res = TestResult.read_file(json_file) 70 | res.display() 71 | -------------------------------------------------------------------------------- /tests/test_rate_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import argparse 3 | from vectordb_bench.backend.dataset import Dataset, DatasetSource 4 | from vectordb_bench.backend.runner.rate_runner import RatedMultiThreadingInsertRunner 5 | from vectordb_bench.backend.runner.read_write_runner import ReadWriteRunner 6 | from vectordb_bench.backend.clients import DB, VectorDB 7 | from vectordb_bench.backend.clients.milvus.config import FLATConfig 8 | from vectordb_bench.backend.clients.zilliz_cloud.config import AutoIndexConfig 9 | 10 | import logging 11 | 12 | log = logging.getLogger("vectordb_bench") 13 | log.setLevel(logging.DEBUG) 14 | 15 | def get_rate_runner(db): 16 | cohere = Dataset.COHERE.manager(100_000) 17 | prepared = cohere.prepare(DatasetSource.AliyunOSS) 18 | assert prepared 19 | runner = RatedMultiThreadingInsertRunner( 20 | rate = 10, 21 | db = db, 22 | dataset = cohere, 23 | ) 24 | 25 | return runner 26 | 27 | def test_rate_runner(db, insert_rate): 28 | runner = get_rate_runner(db) 29 | 30 | _, t = runner.run_with_rate() 31 | log.info(f"insert run done, time={t}") 32 | 33 | def 
test_read_write_runner(db, insert_rate, conc: list, search_stage: Iterable[float], read_dur_after_write: int, local: bool=False): 34 | cohere = Dataset.COHERE.manager(1_000_000) 35 | if local is True: 36 | source = DatasetSource.AliyunOSS 37 | else: 38 | source = DatasetSource.S3 39 | prepared = cohere.prepare(source) 40 | assert prepared 41 | 42 | rw_runner = ReadWriteRunner( 43 | db=db, 44 | dataset=cohere, 45 | insert_rate=insert_rate, 46 | search_stage=search_stage, 47 | read_dur_after_write=read_dur_after_write, 48 | concurrencies=conc 49 | ) 50 | rw_runner.run_read_write() 51 | 52 | 53 | def get_db(db: str, config: dict) -> VectorDB: 54 | if db == DB.Milvus.name: 55 | return DB.Milvus.init_cls(dim=768, db_config=config, db_case_config=FLATConfig(metric_type="COSINE"), drop_old=True) 56 | elif db == DB.ZillizCloud.name: 57 | return DB.ZillizCloud.init_cls(dim=768, db_config=config, db_case_config=AutoIndexConfig(metric_type="COSINE"), drop_old=True) 58 | else: 59 | raise ValueError(f"unknown db: {db}") 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("-r", "--insert_rate", type=int, default="1000", help="insert entity row count per seconds, cps") 65 | parser.add_argument("-d", "--db", type=str, default=DB.Milvus.name, help="db name") 66 | parser.add_argument("-t", "--duration", type=int, default=300, help="stage search duration in seconds") 67 | parser.add_argument("--use_s3", action='store_true', help="whether to use S3 dataset") 68 | 69 | flags = parser.parse_args() 70 | 71 | # TODO read uri, user, password from .env 72 | config = { 73 | "uri": "http://localhost:19530", 74 | "user": "", 75 | "password": "", 76 | } 77 | 78 | conc = (1, 15, 50) 79 | search_stage = (0.5, 0.6, 0.7, 0.8, 0.9) 80 | 81 | db = get_db(flags.db, config) 82 | test_read_write_runner( 83 | db=db, 84 | insert_rate=flags.insert_rate, 85 | conc=conc, 86 | search_stage=search_stage, 87 | read_dur_after_write=flags.duration, 88 | local=flags.use_s3) 89 | -------------------------------------------------------------------------------- /tests/test_redis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from vectordb_bench.models import ( 3 | DB, 4 | ) 5 | from vectordb_bench.backend.clients.redis.config import RedisConfig 6 | import numpy as np 7 | 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | # Tests for Redis, assumes Redis is running on localhost:6379, can be modified by changing the dict below 12 | dict = {} 13 | dict['name'] = "redis" 14 | dict['host'] = "localhost" 15 | dict['port'] = 6379 16 | dict['password'] = "redis" 17 | 18 | 19 | 20 | class TestRedis: 21 | def test_insert_and_search(self): 22 | assert DB.Redis.value == "Redis" 23 | dbcls = DB.Redis.init_cls 24 | dbConfig = dbcls.config_cls() 25 | 26 | 27 | dim = 16 28 | rdb = dbcls( 29 | dim=dim, 30 | db_config=dict, 31 | db_case_config=None, 32 | indice="test_redis", 33 | drop_old=True, 34 | ) 35 | 36 | count = 10_000 37 | filter_value = 0.9 38 | embeddings = [[np.random.random() for _ in range(dim)] for _ in range(count)] 39 | 40 | 41 | # insert 42 | with rdb.init(): 43 | assert (rdb.conn.ping() == True), "redis client is not connected" 44 | res = rdb.insert_embeddings(embeddings=embeddings, metadata=range(count)) 45 | # bulk_insert return 46 | assert ( 47 | res[0] == count 48 | ), f"the return count of bulk insert ({res}) is not equal to count ({count})" 49 | 50 | # count entries in redis database 51 | countRes = rdb.conn.dbsize() 52 
| 53 | assert ( 54 | countRes == count 55 | ), f"the return count of redis client ({countRes}) is not equal to count ({count})" 56 | 57 | # search 58 | with rdb.init(): 59 | test_id = np.random.randint(count) 60 | #log.info(f"test_id: {test_id}") 61 | q = embeddings[test_id] 62 | 63 | res = rdb.search_embedding(query=q, k=100) 64 | #log.info(f"search_results_id: {res}") 65 | print(res) 66 | # res of format [2757, 2944, 8893, 6695, 5571, 608, 455, 3464, 1584, 1807, 8452, 4311...] 67 | assert ( 68 | res[0] == int(test_id) 69 | ), f"the most nearest neighbor ({res[0]}) id is not test_id ({str(test_id)}" 70 | 71 | # search with filters 72 | with rdb.init(): 73 | filter_value = int(count * filter_value) 74 | test_id = np.random.randint(filter_value, count) 75 | q = embeddings[test_id] 76 | 77 | 78 | res = rdb.search_embedding( 79 | query=q, k=100, filters={"metadata": filter_value} 80 | ) 81 | assert ( 82 | res[0] == int(test_id) 83 | ), f"the most nearest neighbor ({res[0]}) id is not test_id ({test_id})" 84 | isFilter = True 85 | id_list = [] 86 | for id in res: 87 | id_list.append(id) 88 | if int(id) < filter_value: 89 | isFilter = False 90 | break 91 | assert isFilter, f"filters failed, got: ({id}), expected less than ({filter_value})" 92 | 93 | #Test id filter for exact match 94 | res = rdb.search_embedding( 95 | query=q, k=100, filters={"id": 9999} 96 | ) 97 | assert ( 98 | res[0] == 9999 99 | ) 100 | 101 | #Test two filters, id and metadata 102 | res = rdb.search_embedding( 103 | query=q, k=100, filters={"metadata": filter_value, "id": 9999} 104 | ) 105 | assert ( 106 | res[0] == 9999 and len(res) == 1, f"filters failed, got: ({res[0]}), expected ({9999})" 107 | ) -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | 4 | from vectordb_bench.backend import utils 5 | from vectordb_bench.metric import calc_recall 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | class TestUtils: 10 | @pytest.mark.parametrize("testcases", [ 11 | (1, '1'), 12 | (10, '10'), 13 | (100, '100'), 14 | (1000, '1K'), 15 | (2000, '2K'), 16 | (30_000, '30K'), 17 | (400_000, '400K'), 18 | (5_000_000, '5M'), 19 | (60_000_000, '60M'), 20 | (1_000_000_000, '1B'), 21 | (1_000_000_000_000, '1000B'), 22 | ]) 23 | def test_numerize(self, testcases): 24 | t_in, expected = testcases 25 | assert expected == utils.numerize(t_in) 26 | 27 | @pytest.mark.parametrize("got_expected", [ 28 | ([1, 3, 5, 7, 9, 10], 1.0), 29 | ([11, 12, 13, 14, 15, 16], 0.0), 30 | ([1, 3, 5, 11, 12, 13], 0.5), 31 | ([1, 3, 5], 0.5), 32 | ]) 33 | def test_recall(self, got_expected): 34 | got, expected = got_expected 35 | ground_truth = [1, 3, 5, 7, 9, 10] 36 | res = calc_recall(6, ground_truth, got) 37 | log.info(f"recall: {res}, expected: {expected}") 38 | assert res == expected 39 | 40 | 41 | class TestGetFiles: 42 | @pytest.mark.parametrize("train_count", [ 43 | 1, 44 | 10, 45 | 50, 46 | 100, 47 | ]) 48 | def test_train_count(self, train_count): 49 | files = utils.compose_train_files(train_count, True) 50 | log.info(files) 51 | 52 | assert len(files) == train_count 53 | 54 | @pytest.mark.parametrize("use_shuffled", [True, False]) 55 | def test_use_shuffled(self, use_shuffled): 56 | files = utils.compose_train_files(1, use_shuffled) 57 | log.info(files) 58 | 59 | trains = [f for f in files if "train" in f] 60 | if use_shuffled: 61 | for t in trains: 62 | assert "shuffle_train" in t 63 | 
else: 64 | for t in trains: 65 | assert "shuffle" not in t 66 | assert "train" in t 67 | -------------------------------------------------------------------------------- /tests/ut_cases.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.backend.cases import ( 2 | PerformanceCase, 3 | CaseType, 4 | ) 5 | 6 | from vectordb_bench.backend.datase import Dataset, DatasetManager 7 | 8 | 9 | class Performance100K99p(PerformanceCase): 10 | case_id: CaseType = 100 11 | filter_rate: float | int | None = 0.99 12 | dataset: DatasetManager = Dataset.COHERE.manager(100_000) 13 | name: str = "Filtering Search Performance Test (100K Dataset, 768 Dim, Filter 99%)" 14 | description: str = """This case tests the search performance of a vector database with a small dataset (Cohere 100K vectors, 768 dimensions) under a high filtering rate (99% vectors), at varying parallel levels. 15 | Results will show index building time, recall, and maximum QPS.""" 16 | 17 | class Performance100K1p(PerformanceCase): 18 | case_id: CaseType = 100 19 | filter_rate: float | int | None = 0.01 20 | dataset: DatasetManager = Dataset.COHERE.manager(100_000) 21 | name: str = "Filtering Search Performance Test (100K Dataset, 768 Dim, Filter 1%)" 22 | description: str = ( 23 | """This case tests the search performance of a vector database with a small dataset (Cohere 100K vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel levels. 24 | Results will show index building time, recall, and maximum QPS.""", 25 | ) 26 | 27 | 28 | class Performance100K(PerformanceCase): 29 | case_id: CaseType = 100 30 | dataset: DatasetManager = Dataset.COHERE.manager(100_000) 31 | name: str = "Search Performance Test (100K Dataset, 768 Dim)" 32 | description: str = """This case tests the search performance of a vector database with a small dataset (Cohere 100K vectors, 768 dimensions) at varying parallel levels. 33 | Results will show index building time, recall, and maximum QPS.""" 34 | -------------------------------------------------------------------------------- /vectordb_bench/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pathlib 3 | 4 | import environs 5 | 6 | from . 
import log_util 7 | 8 | env = environs.Env() 9 | env.read_env(path=".env", recurse=False) 10 | 11 | 12 | class config: 13 | ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/" 14 | AWS_S3_URL = "assets.zilliz.com/benchmark/" 15 | 16 | LOG_LEVEL = env.str("LOG_LEVEL", "INFO") 17 | 18 | DEFAULT_DATASET_URL = env.str("DEFAULT_DATASET_URL", AWS_S3_URL) 19 | DATASET_LOCAL_DIR = env.path("DATASET_LOCAL_DIR", "/tmp/vectordb_bench/dataset") 20 | NUM_PER_BATCH = env.int("NUM_PER_BATCH", 100) 21 | 22 | DROP_OLD = env.bool("DROP_OLD", True) 23 | USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True) 24 | 25 | NUM_CONCURRENCY = env.list( 26 | "NUM_CONCURRENCY", 27 | [ 28 | 1, 29 | 5, 30 | 10, 31 | 15, 32 | 20, 33 | 25, 34 | 30, 35 | 35, 36 | 40, 37 | 45, 38 | 50, 39 | 55, 40 | 60, 41 | 65, 42 | 70, 43 | 75, 44 | 80, 45 | 85, 46 | 90, 47 | 95, 48 | 100, 49 | ], 50 | subcast=int, 51 | ) 52 | 53 | CONCURRENCY_DURATION = 30 54 | 55 | CONCURRENCY_TIMEOUT = 3600 56 | 57 | RESULTS_LOCAL_DIR = env.path( 58 | "RESULTS_LOCAL_DIR", 59 | pathlib.Path(__file__).parent.joinpath("results"), 60 | ) 61 | CONFIG_LOCAL_DIR = env.path( 62 | "CONFIG_LOCAL_DIR", 63 | pathlib.Path(__file__).parent.joinpath("config-files"), 64 | ) 65 | 66 | K_DEFAULT = 100 # default return top k nearest neighbors during search 67 | CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json") 68 | 69 | CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h 70 | LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h 71 | LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h 72 | LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d 73 | LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d 74 | 75 | LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h 76 | LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d 77 | 78 | OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h 79 | OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h 80 | OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d 81 | OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d 82 | 83 | OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h 84 | OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d 85 | 86 | def display(self) -> str: 87 | return [ 88 | i 89 | for i in inspect.getmembers(self) 90 | if not inspect.ismethod(i[1]) and not i[0].startswith("_") and "TIMEOUT" not in i[0] 91 | ] 92 | 93 | 94 | log_util.init(config.LOG_LEVEL) 95 | -------------------------------------------------------------------------------- /vectordb_bench/__main__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import subprocess 4 | import traceback 5 | 6 | from . 
import config 7 | 8 | log = logging.getLogger("vectordb_bench") 9 | 10 | 11 | def main(): 12 | log.info(f"all configs: {config().display()}") 13 | run_streamlit() 14 | 15 | 16 | def run_streamlit(): 17 | cmd = [ 18 | "streamlit", 19 | "run", 20 | f"{pathlib.Path(__file__).parent}/frontend/vdb_benchmark.py", 21 | "--logger.level", 22 | "info", 23 | "--theme.base", 24 | "light", 25 | "--theme.primaryColor", 26 | "#3670F2", 27 | "--theme.secondaryBackgroundColor", 28 | "#F0F2F6", 29 | ] 30 | log.debug(f"cmd: {cmd}") 31 | try: 32 | subprocess.run(cmd, check=True) 33 | except KeyboardInterrupt: 34 | log.info("exit streamlit...") 35 | except Exception as e: 36 | log.warning(f"exit, err={e}\nstack trace={traceback.format_exc(chain=True)}") 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /vectordb_bench/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zilliztech/VectorDBBench/571de32b1cae8210f0ce5426981347add2d5c61d/vectordb_bench/backend/__init__.py -------------------------------------------------------------------------------- /vectordb_bench/backend/assembler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from vectordb_bench.backend.clients import EmptyDBCaseConfig 4 | from vectordb_bench.backend.data_source import DatasetSource 5 | from vectordb_bench.models import TaskConfig 6 | 7 | from .cases import CaseLabel 8 | from .task_runner import CaseRunner, RunningStatus, TaskRunner 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | class Assembler: 14 | @classmethod 15 | def assemble(cls, run_id: str, task: TaskConfig, source: DatasetSource) -> CaseRunner: 16 | c_cls = task.case_config.case_id.case_cls 17 | 18 | c = c_cls(task.case_config.custom_case) 19 | if type(task.db_case_config) is not EmptyDBCaseConfig: 20 | task.db_case_config.metric_type = c.dataset.data.metric_type 21 | 22 | return CaseRunner( 23 | run_id=run_id, 24 | config=task, 25 | ca=c, 26 | status=RunningStatus.PENDING, 27 | dataset_source=source, 28 | ) 29 | 30 | @classmethod 31 | def assemble_all( 32 | cls, 33 | run_id: str, 34 | task_label: str, 35 | tasks: list[TaskConfig], 36 | source: DatasetSource, 37 | ) -> TaskRunner: 38 | """group by case type, db, and case dataset""" 39 | runners = [cls.assemble(run_id, task, source) for task in tasks] 40 | load_runners = [r for r in runners if r.ca.label == CaseLabel.Load] 41 | perf_runners = [r for r in runners if r.ca.label == CaseLabel.Performance] 42 | 43 | # group by db 44 | db2runner = {} 45 | for r in perf_runners: 46 | db = r.config.db 47 | if db not in db2runner: 48 | db2runner[db] = [] 49 | db2runner[db].append(r) 50 | 51 | # check dbclient installed 52 | for k in db2runner: 53 | _ = k.init_cls 54 | 55 | # sort by dataset size 56 | for _, runner in db2runner.items(): 57 | runner.sort(key=lambda x: x.ca.dataset.data.size) 58 | 59 | all_runners = [] 60 | all_runners.extend(load_runners) 61 | for v in db2runner.values(): 62 | all_runners.extend(v) 63 | 64 | return TaskRunner( 65 | run_id=run_id, 66 | task_label=task_label, 67 | case_runners=all_runners, 68 | ) 69 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py: -------------------------------------------------------------------------------- 1 | from ..elastic_cloud.config 
import ElasticCloudIndexConfig 2 | from ..elastic_cloud.elastic_cloud import ElasticCloud 3 | 4 | 5 | class AliyunElasticsearch(ElasticCloud): 6 | def __init__( 7 | self, 8 | dim: int, 9 | db_config: dict, 10 | db_case_config: ElasticCloudIndexConfig, 11 | indice: str = "vdb_bench_indice", # must be lowercase 12 | id_col_name: str = "id", 13 | vector_col_name: str = "vector", 14 | drop_old: bool = False, 15 | **kwargs, 16 | ): 17 | super().__init__( 18 | dim=dim, 19 | db_config=db_config, 20 | db_case_config=db_case_config, 21 | indice=indice, 22 | id_col_name=id_col_name, 23 | vector_col_name=vector_col_name, 24 | drop_old=drop_old, 25 | **kwargs, 26 | ) 27 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/aliyun_elasticsearch/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBConfig 4 | 5 | 6 | class AliyunElasticsearchConfig(DBConfig, BaseModel): 7 | #: Protocol in use to connect to the node 8 | scheme: str = "http" 9 | host: str = "" 10 | port: int = 9200 11 | user: str = "elastic" 12 | password: SecretStr 13 | 14 | def to_dict(self) -> dict: 15 | return { 16 | "hosts": [{"scheme": self.scheme, "host": self.host, "port": self.port}], 17 | "basic_auth": (self.user, self.password.get_secret_value()), 18 | } 19 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/aliyun_opensearch/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import BaseModel, SecretStr 4 | 5 | from ..api import DBCaseConfig, DBConfig, MetricType 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | class AliyunOpenSearchConfig(DBConfig, BaseModel): 11 | host: str = "" 12 | user: str = "" 13 | password: SecretStr = "" 14 | 15 | ak: str = "" 16 | sk: SecretStr = "" 17 | control_host: str = "searchengine.cn-hangzhou.aliyuncs.com" 18 | 19 | def to_dict(self) -> dict: 20 | return { 21 | "host": self.host, 22 | "user": self.user, 23 | "password": self.password.get_secret_value(), 24 | "ak": self.ak, 25 | "sk": self.sk.get_secret_value(), 26 | "control_host": self.control_host, 27 | } 28 | 29 | 30 | class AliyunOpenSearchIndexConfig(BaseModel, DBCaseConfig): 31 | metric_type: MetricType = MetricType.L2 32 | ef_construction: int = 500 33 | M: int = 100 34 | ef_search: int = 40 35 | 36 | def distance_type(self) -> str: 37 | if self.metric_type == MetricType.L2: 38 | return "SquaredEuclidean" 39 | if self.metric_type in (MetricType.IP, MetricType.COSINE): 40 | return "InnerProduct" 41 | return "SquaredEuclidean" 42 | 43 | def index_param(self) -> dict: 44 | return {} 45 | 46 | def search_param(self) -> dict: 47 | return {} 48 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/aws_opensearch/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, TypedDict, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | HNSWFlavor2, 9 | cli, 10 | click_parameter_decorators_from_typed_dict, 11 | run, 12 | ) 13 | from .. 
import DB 14 | 15 | 16 | class AWSOpenSearchTypedDict(TypedDict): 17 | host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] 18 | port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")] 19 | user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")] 20 | password: Annotated[str, click.option("--password", type=str, help="Db password")] 21 | number_of_shards: Annotated[ 22 | int, 23 | click.option("--number-of-shards", type=int, help="Number of primary shards for the index", default=1), 24 | ] 25 | number_of_replicas: Annotated[ 26 | int, 27 | click.option( 28 | "--number-of-replicas", type=int, help="Number of replica copies for each primary shard", default=1 29 | ), 30 | ] 31 | index_thread_qty: Annotated[ 32 | int, 33 | click.option( 34 | "--index-thread-qty", 35 | type=int, 36 | help="Thread count for native engine indexing", 37 | default=4, 38 | ), 39 | ] 40 | 41 | index_thread_qty_during_force_merge: Annotated[ 42 | int, 43 | click.option( 44 | "--index-thread-qty-during-force-merge", 45 | type=int, 46 | help="Thread count during force merge operations", 47 | default=4, 48 | ), 49 | ] 50 | 51 | number_of_indexing_clients: Annotated[ 52 | int, 53 | click.option( 54 | "--number-of-indexing-clients", 55 | type=int, 56 | help="Number of concurrent indexing clients", 57 | default=1, 58 | ), 59 | ] 60 | 61 | number_of_segments: Annotated[ 62 | int, 63 | click.option("--number-of-segments", type=int, help="Target number of segments after merging", default=1), 64 | ] 65 | 66 | refresh_interval: Annotated[ 67 | int, 68 | click.option( 69 | "--refresh-interval", type=str, help="How often to make new data available for search", default="60s" 70 | ), 71 | ] 72 | 73 | force_merge_enabled: Annotated[ 74 | int, 75 | click.option("--force-merge-enabled", type=bool, help="Whether to perform force merge operation", default=True), 76 | ] 77 | 78 | flush_threshold_size: Annotated[ 79 | int, 80 | click.option( 81 | "--flush-threshold-size", type=str, help="Size threshold for flushing the transaction log", default="5120mb" 82 | ), 83 | ] 84 | 85 | cb_threshold: Annotated[ 86 | int, 87 | click.option( 88 | "--cb-threshold", 89 | type=str, 90 | help="k-NN Memory circuit breaker threshold", 91 | default="50%", 92 | ), 93 | ] 94 | 95 | 96 | class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ... 
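# A hedged usage sketch, not taken from this repo's docs: assuming the
# `vectordbbench` console script and click's default command naming (which would
# expose the command below as `awsopensearch`), the options declared in the
# TypedDicts above become CLI flags, e.g.:
#
#   vectordbbench awsopensearch \
#     --host my-domain.us-east-1.es.amazonaws.com --port 443 \
#     --user admin --password '<password>' \
#     --number-of-shards 2 --number-of-replicas 1 \
#     --index-thread-qty 8 --refresh-interval 60s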
97 | 98 | 99 | @cli.command() 100 | @click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict) 101 | def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): 102 | from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig 103 | 104 | run( 105 | db=DB.AWSOpenSearch, 106 | db_config=AWSOpenSearchConfig( 107 | host=parameters["host"], 108 | port=parameters["port"], 109 | user=parameters["user"], 110 | password=SecretStr(parameters["password"]), 111 | ), 112 | db_case_config=AWSOpenSearchIndexConfig( 113 | number_of_shards=parameters["number_of_shards"], 114 | number_of_replicas=parameters["number_of_replicas"], 115 | index_thread_qty=parameters["index_thread_qty"], 116 | number_of_segments=parameters["number_of_segments"], 117 | refresh_interval=parameters["refresh_interval"], 118 | force_merge_enabled=parameters["force_merge_enabled"], 119 | flush_threshold_size=parameters["flush_threshold_size"], 120 | number_of_indexing_clients=parameters["number_of_indexing_clients"], 121 | index_thread_qty_during_force_merge=parameters["index_thread_qty_during_force_merge"], 122 | cb_threshold=parameters["cb_threshold"], 123 | ), 124 | **parameters, 125 | ) 126 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/aws_opensearch/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import Enum 3 | 4 | from pydantic import BaseModel, SecretStr 5 | 6 | from ..api import DBCaseConfig, DBConfig, MetricType 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | 11 | class AWSOpenSearchConfig(DBConfig, BaseModel): 12 | host: str = "" 13 | port: int = 443 14 | user: str = "" 15 | password: SecretStr = "" 16 | 17 | def to_dict(self) -> dict: 18 | return { 19 | "hosts": [{"host": self.host, "port": self.port}], 20 | "http_auth": (self.user, self.password.get_secret_value()), 21 | "use_ssl": True, 22 | "http_compress": True, 23 | "verify_certs": True, 24 | "ssl_assert_hostname": False, 25 | "ssl_show_warn": False, 26 | "timeout": 600, 27 | } 28 | 29 | 30 | class AWSOS_Engine(Enum): 31 | nmslib = "nmslib" 32 | faiss = "faiss" 33 | lucene = "Lucene" 34 | 35 | 36 | class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig): 37 | metric_type: MetricType = MetricType.L2 38 | engine: AWSOS_Engine = AWSOS_Engine.faiss 39 | efConstruction: int = 256 40 | efSearch: int = 256 41 | M: int = 16 42 | index_thread_qty: int | None = 4 43 | number_of_shards: int | None = 1 44 | number_of_replicas: int | None = 0 45 | number_of_segments: int | None = 1 46 | refresh_interval: str | None = "60s" 47 | force_merge_enabled: bool | None = True 48 | flush_threshold_size: str | None = "5120mb" 49 | number_of_indexing_clients: int | None = 1 50 | index_thread_qty_during_force_merge: int 51 | cb_threshold: str | None = "50%" 52 | 53 | def parse_metric(self) -> str: 54 | if self.metric_type == MetricType.IP: 55 | return "innerproduct" 56 | if self.metric_type == MetricType.COSINE: 57 | if self.engine == AWSOS_Engine.faiss: 58 | log.info( 59 | "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch", 60 | ) 61 | return "innerproduct" 62 | return "cosinesimil" 63 | return "l2" 64 | 65 | def index_param(self) -> dict: 66 | return { 67 | "name": "hnsw", 68 | "space_type": self.parse_metric(), 69 | "engine": self.engine.value, 70 | "parameters": { 71 | "ef_construction": self.efConstruction, 72 | "m": self.M, 73 | "ef_search": self.efSearch, 74 | }, 75 | } 76 | 
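    # A hedged illustration of where the dict returned by index_param() ends up
    # (field name "embedding" and dimension 768 are assumptions, not from this
    # file): it becomes the "method" block of an OpenSearch knn_vector mapping:
    #
    #   {"mappings": {"properties": {"embedding": {
    #       "type": "knn_vector", "dimension": 768, "method": <index_param()>}}}}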
77 | def search_param(self) -> dict: 78 | return {} 79 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/chroma/chroma.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from typing import Any 4 | 5 | import chromadb 6 | 7 | from ..api import DBCaseConfig, VectorDB 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | 12 | class ChromaClient(VectorDB): 13 | """Chroma client for VectorDB. 14 | To set up Chroma in docker, see https://docs.trychroma.com/usage-guide 15 | or the instructions in tests/test_chroma.py 16 | 17 | To change to running in process, modify the HttpClient() in __init__() and init(). 18 | """ 19 | 20 | def __init__( 21 | self, 22 | dim: int, 23 | db_config: dict, 24 | db_case_config: DBCaseConfig, 25 | drop_old: bool = False, 26 | **kwargs, 27 | ): 28 | self.db_config = db_config 29 | self.case_config = db_case_config 30 | self.collection_name = "example2" 31 | 32 | client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) 33 | assert client.heartbeat() is not None 34 | if drop_old: 35 | try: 36 | client.reset() # Reset the database 37 | except Exception: 38 | drop_old = False 39 | log.info(f"Chroma client drop_old collection: {self.collection_name}") 40 | 41 | @contextmanager 42 | def init(self) -> None: 43 | """create and destory connections to database. 44 | 45 | Examples: 46 | >>> with self.init(): 47 | >>> self.insert_embeddings() 48 | """ 49 | # create connection 50 | self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) 51 | 52 | self.collection = self.client.get_or_create_collection("example2") 53 | yield 54 | self.client = None 55 | self.collection = None 56 | 57 | def ready_to_search(self) -> bool: 58 | pass 59 | 60 | def optimize(self, data_size: int | None = None): 61 | pass 62 | 63 | def insert_embeddings( 64 | self, 65 | embeddings: list[list[float]], 66 | metadata: list[int], 67 | **kwargs: Any, 68 | ) -> tuple[int, Exception]: 69 | """Insert embeddings into the database. 70 | 71 | Args: 72 | embeddings(list[list[float]]): list of embeddings 73 | metadata(list[int]): list of metadata 74 | kwargs: other arguments 75 | 76 | Returns: 77 | tuple[int, Exception]: number of embeddings inserted and exception if any 78 | """ 79 | ids = [str(i) for i in metadata] 80 | metadata = [{"id": int(i)} for i in metadata] 81 | if len(embeddings) > 0: 82 | self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata) 83 | return len(embeddings), None 84 | 85 | def search_embedding( 86 | self, 87 | query: list[float], 88 | k: int = 100, 89 | filters: dict | None = None, 90 | timeout: int | None = None, 91 | **kwargs: Any, 92 | ) -> dict: 93 | """Search embeddings from the database. 
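        The benchmark passes filters of the form {"metadata": ">=10000", "id": 10000};
        only the "id" value is applied below, as a {"id": {"$gt": id_value}} where clause.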
94 | Args: 95 | embedding(list[float]): embedding to search 96 | k(int): number of results to return 97 | kwargs: other arguments 98 | 99 | Returns: 100 | Dict {ids: list[list[int]], 101 | embedding: list[list[float]] 102 | distance: list[list[float]]} 103 | """ 104 | if filters: 105 | # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} 106 | id_value = filters.get("id") 107 | results = self.collection.query( 108 | query_embeddings=query, 109 | n_results=k, 110 | where={"id": {"$gt": id_value}}, 111 | ) 112 | # return list of id's in results 113 | return [int(i) for i in results.get("ids")[0]] 114 | results = self.collection.query(query_embeddings=query, n_results=k) 115 | return [int(i) for i in results.get("ids")[0]] 116 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/chroma/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import SecretStr 2 | 3 | from ..api import DBConfig 4 | 5 | 6 | class ChromaConfig(DBConfig): 7 | password: SecretStr 8 | host: SecretStr 9 | port: int 10 | 11 | def to_dict(self) -> dict: 12 | return { 13 | "host": self.host.get_secret_value(), 14 | "port": self.port, 15 | "password": self.password.get_secret_value(), 16 | } 17 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/clickhouse/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, TypedDict, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | HNSWFlavor2, 9 | cli, 10 | click_parameter_decorators_from_typed_dict, 11 | run, 12 | ) 13 | from .. import DB 14 | from .config import ClickhouseHNSWConfig 15 | 16 | 17 | class ClickhouseTypedDict(TypedDict): 18 | password: Annotated[str, click.option("--password", type=str, help="DB password")] 19 | host: Annotated[str, click.option("--host", type=str, help="DB host", required=True)] 20 | port: Annotated[int, click.option("--port", type=int, default=8123, help="DB Port")] 21 | user: Annotated[int, click.option("--user", type=str, default="clickhouse", help="DB user")] 22 | ssl: Annotated[ 23 | bool, 24 | click.option( 25 | "--ssl/--no-ssl", 26 | is_flag=True, 27 | show_default=True, 28 | default=True, 29 | help="Enable or disable SSL for Clickhouse", 30 | ), 31 | ] 32 | ssl_ca_certs: Annotated[ 33 | str, 34 | click.option( 35 | "--ssl-ca-certs", 36 | show_default=True, 37 | help="Path to certificate authority file to use for SSL", 38 | ), 39 | ] 40 | 41 | 42 | class ClickhouseHNSWTypedDict(CommonTypedDict, ClickhouseTypedDict, HNSWFlavor2): ... 
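# A hedged usage sketch: assuming the `vectordbbench` entry point and click's
# default command naming (`clickhouse`), and assuming HNSWFlavor2 (defined in
# cli/cli.py, outside this excerpt) names its options --m, --ef-construction and
# --ef-runtime to match the parameter keys used below:
#
#   vectordbbench clickhouse --host localhost --port 8123 --user clickhouse \
#     --password '<secret>' --no-ssl --m 32 --ef-construction 128 --ef-runtime 64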
43 | 44 | 45 | @cli.command() 46 | @click_parameter_decorators_from_typed_dict(ClickhouseHNSWTypedDict) 47 | def Clickhouse(**parameters: Unpack[ClickhouseHNSWTypedDict]): 48 | from .config import ClickhouseConfig 49 | 50 | run( 51 | db=DB.Clickhouse, 52 | db_config=ClickhouseConfig( 53 | db_label=parameters["db_label"], 54 | user=parameters["user"], 55 | password=SecretStr(parameters["password"]) if parameters["password"] else None, 56 | host=parameters["host"], 57 | port=parameters["port"], 58 | ssl=parameters["ssl"], 59 | ssl_ca_certs=parameters["ssl_ca_certs"], 60 | ), 61 | db_case_config=ClickhouseHNSWConfig( 62 | M=parameters["m"], 63 | efConstruction=parameters["ef_construction"], 64 | ef=parameters["ef_runtime"], 65 | ), 66 | **parameters, 67 | ) 68 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/clickhouse/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import TypedDict 3 | 4 | from pydantic import BaseModel, SecretStr 5 | 6 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 7 | 8 | 9 | class ClickhouseConfigDict(TypedDict): 10 | user: str 11 | password: str 12 | host: str 13 | port: int 14 | database: str 15 | secure: bool 16 | 17 | 18 | class ClickhouseConfig(DBConfig): 19 | user: str = "clickhouse" 20 | password: SecretStr 21 | host: str = "localhost" 22 | port: int = 8123 23 | db_name: str = "default" 24 | secure: bool = False 25 | 26 | def to_dict(self) -> ClickhouseConfigDict: 27 | pwd_str = self.password.get_secret_value() 28 | return { 29 | "host": self.host, 30 | "port": self.port, 31 | "database": self.db_name, 32 | "user": self.user, 33 | "password": pwd_str, 34 | "secure": self.secure, 35 | } 36 | 37 | 38 | class ClickhouseIndexConfig(BaseModel, DBCaseConfig): 39 | 40 | metric_type: MetricType | None = None 41 | vector_data_type: str | None = "Float32" # Data type of vectors. Can be Float32 or Float64 or BFloat16 42 | create_index_before_load: bool = True 43 | create_index_after_load: bool = False 44 | 45 | def parse_metric(self) -> str: 46 | if not self.metric_type: 47 | return "" 48 | return self.metric_type.value 49 | 50 | def parse_metric_str(self) -> str: 51 | if self.metric_type == MetricType.L2: 52 | return "L2Distance" 53 | if self.metric_type == MetricType.COSINE: 54 | return "cosineDistance" 55 | return "cosineDistance" 56 | 57 | @abstractmethod 58 | def session_param(self): 59 | pass 60 | 61 | 62 | class ClickhouseHNSWConfig(ClickhouseIndexConfig): 63 | M: int | None # Default in ClickHouse is 32 64 | efConstruction: int | None # Default in ClickHouse is 128 65 | ef: int | None = None 66 | index: IndexType = IndexType.HNSW 67 | quantization: str | None = "bf16" # Default is bf16. Possible values are f64, f32, f16, bf16, or i8 68 | granularity: int | None = 10_000_000 # Size of the index granules.
The ClickHouse default is 10,000,000 69 | 70 | def index_param(self) -> dict: 71 | return { 72 | "vector_data_type": self.vector_data_type, 73 | "metric_type": self.parse_metric_str(), 74 | "index_type": self.index.value, 75 | "quantization": self.quantization, 76 | "granularity": self.granularity, 77 | "params": {"M": self.M, "efConstruction": self.efConstruction}, 78 | } 79 | 80 | def search_param(self) -> dict: 81 | return { 82 | "metric_type": self.parse_metric_str(), 83 | "params": {"ef": self.ef}, 84 | } 85 | 86 | def session_param(self) -> dict: 87 | return { 88 | "allow_experimental_vector_similarity_index": 1, 89 | } 90 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/elastic_cloud/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from pydantic import BaseModel, SecretStr 4 | 5 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 6 | 7 | 8 | class ElasticCloudConfig(DBConfig, BaseModel): 9 | cloud_id: SecretStr 10 | password: SecretStr 11 | 12 | def to_dict(self) -> dict: 13 | return { 14 | "cloud_id": self.cloud_id.get_secret_value(), 15 | "basic_auth": ("elastic", self.password.get_secret_value()), 16 | } 17 | 18 | 19 | class ESElementType(str, Enum): 20 | float = "float" # 4 byte 21 | byte = "byte" # 1 byte, -128 to 127 22 | 23 | 24 | class ElasticCloudIndexConfig(BaseModel, DBCaseConfig): 25 | element_type: ESElementType = ESElementType.float 26 | index: IndexType = IndexType.ES_HNSW # ES only supports 'hnsw' 27 | 28 | metric_type: MetricType | None = None 29 | efConstruction: int | None = None 30 | M: int | None = None 31 | num_candidates: int | None = None 32 | 33 | def parse_metric(self) -> str: 34 | if self.metric_type == MetricType.L2: 35 | return "l2_norm" 36 | if self.metric_type == MetricType.IP: 37 | return "dot_product" 38 | return "cosine" 39 | 40 | def index_param(self) -> dict: 41 | return { 42 | "type": "dense_vector", 43 | "index": True, 44 | "element_type": self.element_type.value, 45 | "similarity": self.parse_metric(), 46 | "index_options": { 47 | "type": self.index.value, 48 | "m": self.M, 49 | "ef_construction": self.efConstruction, 50 | }, 51 | } 52 | 53 | def search_param(self) -> dict: 54 | return { 55 | "num_candidates": self.num_candidates, 56 | } 57 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/lancedb/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | cli, 9 | click_parameter_decorators_from_typed_dict, 10 | run, 11 | ) 12 | from ..
import DB 13 | from ..api import IndexType 14 | 15 | 16 | class LanceDBTypedDict(CommonTypedDict): 17 | uri: Annotated[ 18 | str, 19 | click.option("--uri", type=str, help="URI connection string", required=True), 20 | ] 21 | token: Annotated[ 22 | str | None, 23 | click.option("--token", type=str, help="Authentication token", required=False), 24 | ] 25 | 26 | 27 | @cli.command() 28 | @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) 29 | def LanceDB(**parameters: Unpack[LanceDBTypedDict]): 30 | from .config import LanceDBConfig, _lancedb_case_config 31 | 32 | run( 33 | db=DB.LanceDB, 34 | db_config=LanceDBConfig( 35 | db_label=parameters["db_label"], 36 | uri=parameters["uri"], 37 | token=SecretStr(parameters["token"]) if parameters.get("token") else None, 38 | ), 39 | db_case_config=_lancedb_case_config.get("NONE")(), 40 | **parameters, 41 | ) 42 | 43 | 44 | @cli.command() 45 | @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) 46 | def LanceDBAutoIndex(**parameters: Unpack[LanceDBTypedDict]): 47 | from .config import LanceDBConfig, _lancedb_case_config 48 | 49 | run( 50 | db=DB.LanceDB, 51 | db_config=LanceDBConfig( 52 | db_label=parameters["db_label"], 53 | uri=parameters["uri"], 54 | token=SecretStr(parameters["token"]) if parameters.get("token") else None, 55 | ), 56 | db_case_config=_lancedb_case_config.get(IndexType.AUTOINDEX)(), 57 | **parameters, 58 | ) 59 | 60 | 61 | @cli.command() 62 | @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) 63 | def LanceDBIVFPQ(**parameters: Unpack[LanceDBTypedDict]): 64 | from .config import LanceDBConfig, _lancedb_case_config 65 | 66 | run( 67 | db=DB.LanceDB, 68 | db_config=LanceDBConfig( 69 | db_label=parameters["db_label"], 70 | uri=parameters["uri"], 71 | token=SecretStr(parameters["token"]) if parameters.get("token") else None, 72 | ), 73 | db_case_config=_lancedb_case_config.get(IndexType.IVFPQ)(), 74 | **parameters, 75 | ) 76 | 77 | 78 | @cli.command() 79 | @click_parameter_decorators_from_typed_dict(LanceDBTypedDict) 80 | def LanceDBHNSW(**parameters: Unpack[LanceDBTypedDict]): 81 | from .config import LanceDBConfig, _lancedb_case_config 82 | 83 | run( 84 | db=DB.LanceDB, 85 | db_config=LanceDBConfig( 86 | db_label=parameters["db_label"], 87 | uri=parameters["uri"], 88 | token=SecretStr(parameters["token"]) if parameters.get("token") else None, 89 | ), 90 | db_case_config=_lancedb_case_config.get(IndexType.HNSW)(), 91 | **parameters, 92 | ) 93 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/lancedb/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 4 | 5 | 6 | class LanceDBConfig(DBConfig): 7 | """LanceDB connection configuration.""" 8 | 9 | db_label: str 10 | uri: str 11 | token: SecretStr | None = None 12 | 13 | def to_dict(self) -> dict: 14 | return { 15 | "uri": self.uri, 16 | "token": self.token.get_secret_value() if self.token else None, 17 | } 18 | 19 | 20 | class LanceDBIndexConfig(BaseModel, DBCaseConfig): 21 | index: IndexType = IndexType.IVFPQ 22 | metric_type: MetricType = MetricType.L2 23 | num_partitions: int = 0 24 | num_sub_vectors: int = 0 25 | nbits: int = 8 # Must be 4 or 8 26 | sample_rate: int = 256 27 | max_iterations: int = 50 28 | 29 | def index_param(self) -> dict: 30 | if self.index not in [ 31 | IndexType.IVFPQ, 32 | IndexType.HNSW, 33 | IndexType.AUTOINDEX, 
34 | IndexType.NONE, 35 | ]: 36 | msg = f"Index type {self.index} is not supported for LanceDB!" 37 | raise ValueError(msg) 38 | 39 | # See https://lancedb.github.io/lancedb/python/python/#lancedb.table.Table.create_index 40 | params = { 41 | "metric": self.parse_metric(), 42 | "num_bits": self.nbits, 43 | "sample_rate": self.sample_rate, 44 | "max_iterations": self.max_iterations, 45 | } 46 | 47 | if self.num_partitions > 0: 48 | params["num_partitions"] = self.num_partitions 49 | if self.num_sub_vectors > 0: 50 | params["num_sub_vectors"] = self.num_sub_vectors 51 | 52 | return params 53 | 54 | def search_param(self) -> dict: 55 | pass 56 | 57 | def parse_metric(self) -> str: 58 | if self.metric_type in [MetricType.L2, MetricType.COSINE]: 59 | return self.metric_type.value.lower() 60 | if self.metric_type in [MetricType.IP, MetricType.DP]: 61 | return "dot" 62 | msg = f"Metric type {self.metric_type} is not supported for LanceDB!" 63 | raise ValueError(msg) 64 | 65 | 66 | class LanceDBNoIndexConfig(LanceDBIndexConfig): 67 | index: IndexType = IndexType.NONE 68 | 69 | def index_param(self) -> dict: 70 | return {} 71 | 72 | 73 | class LanceDBAutoIndexConfig(LanceDBIndexConfig): 74 | index: IndexType = IndexType.AUTOINDEX 75 | 76 | def index_param(self) -> dict: 77 | return {} 78 | 79 | 80 | class LanceDBHNSWIndexConfig(LanceDBIndexConfig): 81 | index: IndexType = IndexType.HNSW 82 | m: int = 0 83 | ef_construction: int = 0 84 | 85 | def index_param(self) -> dict: 86 | params = LanceDBIndexConfig.index_param(self) 87 | 88 | # See https://lancedb.github.io/lancedb/python/python/#lancedb.index.HnswSq 89 | params["index_type"] = "IVF_HNSW_SQ" 90 | if self.m > 0: 91 | params["m"] = self.m 92 | if self.ef_construction > 0: 93 | params["ef_construction"] = self.ef_construction 94 | 95 | return params 96 | 97 | 98 | _lancedb_case_config = { 99 | IndexType.IVFPQ: LanceDBIndexConfig, 100 | IndexType.AUTOINDEX: LanceDBAutoIndexConfig, 101 | IndexType.HNSW: LanceDBHNSWIndexConfig, 102 | IndexType.NONE: LanceDBNoIndexConfig, 103 | } 104 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/lancedb/lancedb.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | 4 | import lancedb 5 | import pyarrow as pa 6 | from lancedb.pydantic import LanceModel 7 | 8 | from ..api import IndexType, VectorDB 9 | from .config import LanceDBConfig, LanceDBIndexConfig 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | class VectorModel(LanceModel): 15 | id: int 16 | vector: list[float] 17 | 18 | 19 | class LanceDB(VectorDB): 20 | def __init__( 21 | self, 22 | dim: int, 23 | db_config: LanceDBConfig, 24 | db_case_config: LanceDBIndexConfig, 25 | collection_name: str = "vector_bench_test", 26 | drop_old: bool = False, 27 | **kwargs, 28 | ): 29 | self.name = "LanceDB" 30 | self.db_config = db_config 31 | self.case_config = db_case_config 32 | self.table_name = collection_name 33 | self.dim = dim 34 | self.uri = db_config["uri"] 35 | 36 | db = lancedb.connect(self.uri) 37 | 38 | if drop_old: 39 | try: 40 | db.drop_table(self.table_name) 41 | except Exception as e: 42 | log.warning(f"Failed to drop table {self.table_name}: {e}") 43 | 44 | try: 45 | db.open_table(self.table_name) 46 | except Exception: 47 | schema = pa.schema( 48 | [pa.field("id", pa.int64()), pa.field("vector", pa.list_(pa.float64(), list_size=self.dim))] 49 | ) 50 | db.create_table(self.table_name, 
schema=schema, mode="overwrite") 51 | 52 | @contextmanager 53 | def init(self): 54 | self.db = lancedb.connect(self.uri) 55 | self.table = self.db.open_table(self.table_name) 56 | yield 57 | self.db = None 58 | self.table = None 59 | 60 | def insert_embeddings( 61 | self, 62 | embeddings: list[list[float]], 63 | metadata: list[int], 64 | ) -> tuple[int, Exception | None]: 65 | try: 66 | data = [{"id": meta, "vector": emb} for meta, emb in zip(metadata, embeddings, strict=False)] 67 | self.table.add(data) 68 | return len(metadata), None 69 | except Exception as e: 70 | log.warning(f"Failed to insert data into LanceDB table ({self.table_name}), error: {e}") 71 | return 0, e 72 | 73 | def search_embedding( 74 | self, 75 | query: list[float], 76 | k: int = 100, 77 | filters: dict | None = None, 78 | ) -> list[int]: 79 | if filters: 80 | results = ( 81 | self.table.search(query) 82 | .select(["id"]) 83 | .where(f"id >= {filters['id']}", prefilter=True) 84 | .limit(k) 85 | .to_list() 86 | ) 87 | else: 88 | results = self.table.search(query).select(["id"]).limit(k).to_list() 89 | return [int(result["id"]) for result in results] 90 | 91 | def optimize(self, data_size: int | None = None): 92 | if self.table and hasattr(self, "case_config") and self.case_config.index != IndexType.NONE: 93 | log.info(f"Creating index for LanceDB table ({self.table_name})") 94 | self.table.create_index(**self.case_config.index_param()) 95 | # Better recall with IVF_PQ (though still bad) but breaks HNSW: https://github.com/lancedb/lancedb/issues/2369 96 | if self.case_config.index in (IndexType.IVFPQ, IndexType.AUTOINDEX): 97 | self.table.optimize() 98 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/mariadb/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from vectordb_bench.backend.clients import DB 7 | 8 | from ....cli.cli import ( 9 | CommonTypedDict, 10 | cli, 11 | click_parameter_decorators_from_typed_dict, 12 | run, 13 | ) 14 | 15 | 16 | class MariaDBTypedDict(CommonTypedDict): 17 | user_name: Annotated[ 18 | str, 19 | click.option( 20 | "--username", 21 | type=str, 22 | help="Username", 23 | required=True, 24 | ), 25 | ] 26 | password: Annotated[ 27 | str, 28 | click.option( 29 | "--password", 30 | type=str, 31 | help="Password", 32 | required=True, 33 | ), 34 | ] 35 | 36 | host: Annotated[ 37 | str, 38 | click.option( 39 | "--host", 40 | type=str, 41 | help="Db host", 42 | default="127.0.0.1", 43 | ), 44 | ] 45 | 46 | port: Annotated[ 47 | int, 48 | click.option( 49 | "--port", 50 | type=int, 51 | default=3306, 52 | help="Db Port", 53 | ), 54 | ] 55 | 56 | storage_engine: Annotated[ 57 | int, 58 | click.option( 59 | "--storage-engine", 60 | type=click.Choice(["InnoDB", "MyISAM"]), 61 | help="DB storage engine", 62 | required=True, 63 | ), 64 | ] 65 | 66 | 67 | class MariaDBHNSWTypedDict(MariaDBTypedDict): 68 | m: Annotated[ 69 | int | None, 70 | click.option( 71 | "--m", 72 | type=int, 73 | help="M parameter in MHNSW vector indexing", 74 | required=False, 75 | ), 76 | ] 77 | 78 | ef_search: Annotated[ 79 | int | None, 80 | click.option( 81 | "--ef-search", 82 | type=int, 83 | help="MariaDB system variable mhnsw_min_limit", 84 | required=False, 85 | ), 86 | ] 87 | 88 | max_cache_size: Annotated[ 89 | int | None, 90 | click.option( 91 | "--max-cache-size", 92 | type=int, 93 | help="MariaDB system variable 
mhnsw_max_cache_size", 94 | required=False, 95 | ), 96 | ] 97 | 98 | 99 | @cli.command() 100 | @click_parameter_decorators_from_typed_dict(MariaDBHNSWTypedDict) 101 | def MariaDBHNSW( 102 | **parameters: Unpack[MariaDBHNSWTypedDict], 103 | ): 104 | from .config import MariaDBConfig, MariaDBHNSWConfig 105 | 106 | run( 107 | db=DB.MariaDB, 108 | db_config=MariaDBConfig( 109 | db_label=parameters["db_label"], 110 | user_name=parameters["username"], 111 | password=SecretStr(parameters["password"]), 112 | host=parameters["host"], 113 | port=parameters["port"], 114 | ), 115 | db_case_config=MariaDBHNSWConfig( 116 | M=parameters["m"], 117 | ef_search=parameters["ef_search"], 118 | storage_engine=parameters["storage_engine"], 119 | max_cache_size=parameters["max_cache_size"], 120 | ), 121 | **parameters, 122 | ) 123 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/mariadb/config.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict 2 | 3 | from pydantic import BaseModel, SecretStr 4 | 5 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 6 | 7 | 8 | class MariaDBConfigDict(TypedDict): 9 | """These keys will be directly used as kwargs in mariadb connection string, 10 | so the names must match exactly mariadb API""" 11 | 12 | user: str 13 | password: str 14 | host: str 15 | port: int 16 | 17 | 18 | class MariaDBConfig(DBConfig): 19 | user_name: str = "root" 20 | password: SecretStr 21 | host: str = "127.0.0.1" 22 | port: int = 3306 23 | 24 | def to_dict(self) -> MariaDBConfigDict: 25 | pwd_str = self.password.get_secret_value() 26 | return { 27 | "host": self.host, 28 | "port": self.port, 29 | "user": self.user_name, 30 | "password": pwd_str, 31 | } 32 | 33 | 34 | class MariaDBIndexConfig(BaseModel): 35 | """Base config for MariaDB""" 36 | 37 | metric_type: MetricType | None = None 38 | 39 | def parse_metric(self) -> str: 40 | if self.metric_type == MetricType.L2: 41 | return "euclidean" 42 | if self.metric_type == MetricType.COSINE: 43 | return "cosine" 44 | msg = f"Metric type {self.metric_type} is not supported!" 45 | raise ValueError(msg) 46 | 47 | 48 | class MariaDBHNSWConfig(MariaDBIndexConfig, DBCaseConfig): 49 | M: int | None 50 | ef_search: int | None 51 | index: IndexType = IndexType.HNSW 52 | storage_engine: str = "InnoDB" 53 | max_cache_size: int | None 54 | 55 | def index_param(self) -> dict: 56 | return { 57 | "storage_engine": self.storage_engine, 58 | "metric_type": self.parse_metric(), 59 | "index_type": self.index.value, 60 | "M": self.M, 61 | "max_cache_size": self.max_cache_size, 62 | } 63 | 64 | def search_param(self) -> dict: 65 | return { 66 | "metric_type": self.parse_metric(), 67 | "ef_search": self.ef_search, 68 | } 69 | 70 | 71 | _mariadb_case_config = { 72 | IndexType.HNSW: MariaDBHNSWConfig, 73 | } 74 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/memorydb/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, TypedDict, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | HNSWFlavor2, 9 | cli, 10 | click_parameter_decorators_from_typed_dict, 11 | run, 12 | ) 13 | from .. 
import DB 14 | 15 | 16 | class MemoryDBTypedDict(TypedDict): 17 | host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] 18 | password: Annotated[str, click.option("--password", type=str, help="Db password")] 19 | port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] 20 | ssl: Annotated[ 21 | bool, 22 | click.option( 23 | "--ssl/--no-ssl", 24 | is_flag=True, 25 | show_default=True, 26 | default=True, 27 | help="Enable or disable SSL for MemoryDB", 28 | ), 29 | ] 30 | ssl_ca_certs: Annotated[ 31 | str, 32 | click.option( 33 | "--ssl-ca-certs", 34 | show_default=True, 35 | help="Path to certificate authority file to use for SSL", 36 | ), 37 | ] 38 | cmd: Annotated[ 39 | bool, 40 | click.option( 41 | "--cmd", 42 | is_flag=True, 43 | show_default=True, 44 | default=False, 45 | help=( 46 | "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance." 47 | " In production, MemoryDB only supports cluster mode (CME)" 48 | ), 49 | ), 50 | ] 51 | insert_batch_size: Annotated[ 52 | int, 53 | click.option( 54 | "--insert-batch-size", 55 | type=int, 56 | default=10, 57 | help="Batch size for inserting data. Adjust this as needed, but don't make it too big", 58 | ), 59 | ] 60 | 61 | 62 | class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ... 63 | 64 | 65 | @cli.command() 66 | @click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict) 67 | def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]): 68 | from .config import MemoryDBConfig, MemoryDBHNSWConfig 69 | 70 | run( 71 | db=DB.MemoryDB, 72 | db_config=MemoryDBConfig( 73 | db_label=parameters["db_label"], 74 | password=SecretStr(parameters["password"]) if parameters["password"] else None, 75 | host=SecretStr(parameters["host"]), 76 | port=parameters["port"], 77 | ssl=parameters["ssl"], 78 | ssl_ca_certs=parameters["ssl_ca_certs"], 79 | cmd=parameters["cmd"], 80 | ), 81 | db_case_config=MemoryDBHNSWConfig( 82 | M=parameters["m"], 83 | ef_construction=parameters["ef_construction"], 84 | ef_runtime=parameters["ef_runtime"], 85 | insert_batch_size=parameters["insert_batch_size"], 86 | ), 87 | **parameters, 88 | ) 89 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/memorydb/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 4 | 5 | 6 | class MemoryDBConfig(DBConfig): 7 | host: SecretStr 8 | password: SecretStr | None = None 9 | port: int | None = None 10 | ssl: bool | None = None 11 | cmd: bool | None = None 12 | ssl_ca_certs: str | None = None 13 | 14 | def to_dict(self) -> dict: 15 | return { 16 | "host": self.host.get_secret_value(), 17 | "port": self.port, 18 | "password": self.password.get_secret_value() if self.password else None, 19 | "ssl": self.ssl, 20 | "cmd": self.cmd, 21 | "ssl_ca_certs": self.ssl_ca_certs, 22 | } 23 | 24 | 25 | class MemoryDBIndexConfig(BaseModel, DBCaseConfig): 26 | metric_type: MetricType | None = None 27 | insert_batch_size: int | None = None 28 | 29 | def parse_metric(self) -> str: 30 | if self.metric_type == MetricType.L2: 31 | return "l2" 32 | if self.metric_type == MetricType.IP: 33 | return "ip" 34 | return "cosine" 35 | 36 | 37 | class MemoryDBHNSWConfig(MemoryDBIndexConfig): 38 | M: int | None = 16 39 | ef_construction: int | None = 64 40 | ef_runtime: int | None = 10 
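# HNSW parameters: M and ef_construction are applied at index-build time via index_param();
# ef_runtime is applied per query via search_param() below.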
41 | index: IndexType = IndexType.HNSW 42 | 43 | def index_param(self) -> dict: 44 | return { 45 | "metric": self.parse_metric(), 46 | "index_type": self.index.value, 47 | "m": self.M, 48 | "ef_construction": self.ef_construction, 49 | } 50 | 51 | def search_param(self) -> dict: 52 | return { 53 | "ef_runtime": self.ef_runtime, 54 | } 55 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/mongodb/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from pydantic import BaseModel, SecretStr 4 | 5 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 6 | 7 | 8 | class QuantizationType(Enum): 9 | NONE = "none" 10 | BINARY = "binary" 11 | SCALAR = "scalar" 12 | 13 | 14 | class MongoDBConfig(DBConfig, BaseModel): 15 | connection_string: SecretStr = "mongodb+srv://:@.heatl.mongodb.net" 16 | database: str = "vdb_bench" 17 | 18 | def to_dict(self) -> dict: 19 | return { 20 | "connection_string": self.connection_string.get_secret_value(), 21 | "database": self.database, 22 | } 23 | 24 | 25 | class MongoDBIndexConfig(BaseModel, DBCaseConfig): 26 | index: IndexType = IndexType.HNSW # MongoDB uses HNSW for vector search 27 | metric_type: MetricType = MetricType.COSINE 28 | num_candidates_ratio: int = 10 # Default numCandidates ratio for vector search 29 | quantization: QuantizationType = QuantizationType.NONE # Quantization type if applicable 30 | 31 | def parse_metric(self) -> str: 32 | if self.metric_type == MetricType.L2: 33 | return "euclidean" 34 | if self.metric_type == MetricType.IP: 35 | return "dotProduct" 36 | return "cosine" # Default to cosine similarity 37 | 38 | def index_param(self) -> dict: 39 | return { 40 | "type": "vectorSearch", 41 | "fields": [ 42 | { 43 | "type": "vector", 44 | "similarity": self.parse_metric(), 45 | "numDimensions": None, # Will be set in MongoDB class 46 | "path": "vector", # Vector field name 47 | "quantization": self.quantization.value, 48 | } 49 | ], 50 | } 51 | 52 | def search_param(self) -> dict: 53 | return {"num_candidates_ratio": self.num_candidates_ratio} 54 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pgdiskann/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Annotated, Unpack 3 | 4 | import click 5 | from pydantic import SecretStr 6 | 7 | from vectordb_bench.backend.clients import DB 8 | 9 | from ....cli.cli import ( 10 | CommonTypedDict, 11 | cli, 12 | click_parameter_decorators_from_typed_dict, 13 | run, 14 | ) 15 | 16 | 17 | class PgDiskAnnTypedDict(CommonTypedDict): 18 | user_name: Annotated[ 19 | str, 20 | click.option("--user-name", type=str, help="Db username", required=True), 21 | ] 22 | password: Annotated[ 23 | str, 24 | click.option( 25 | "--password", 26 | type=str, 27 | help="Postgres database password", 28 | default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), 29 | show_default="$POSTGRES_PASSWORD", 30 | ), 31 | ] 32 | 33 | host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] 34 | db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] 35 | max_neighbors: Annotated[ 36 | int, 37 | click.option( 38 | "--max-neighbors", 39 | type=int, 40 | help="PgDiskAnn max neighbors", 41 | ), 42 | ] 43 | l_value_ib: Annotated[ 44 | int, 45 | click.option( 46 | "--l-value-ib", 47 | type=int, 
48 | help="PgDiskAnn l_value_ib", 49 | ), 50 | ] 51 | l_value_is: Annotated[ 52 | float, 53 | click.option( 54 | "--l-value-is", 55 | type=float, 56 | help="PgDiskAnn l_value_is", 57 | ), 58 | ] 59 | maintenance_work_mem: Annotated[ 60 | str | None, 61 | click.option( 62 | "--maintenance-work-mem", 63 | type=str, 64 | help="Sets the maximum memory to be used for maintenance operations (index creation). " 65 | "Can be entered as string with unit like '64GB' or as an integer number of KB." 66 | "This will set the parameters: max_parallel_maintenance_workers," 67 | " max_parallel_workers & table(parallel_workers)", 68 | required=False, 69 | ), 70 | ] 71 | max_parallel_workers: Annotated[ 72 | int | None, 73 | click.option( 74 | "--max-parallel-workers", 75 | type=int, 76 | help="Sets the maximum number of parallel processes per maintenance operation (index creation)", 77 | required=False, 78 | ), 79 | ] 80 | 81 | 82 | @cli.command() 83 | @click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict) 84 | def PgDiskAnn( 85 | **parameters: Unpack[PgDiskAnnTypedDict], 86 | ): 87 | from .config import PgDiskANNConfig, PgDiskANNImplConfig 88 | 89 | run( 90 | db=DB.PgDiskANN, 91 | db_config=PgDiskANNConfig( 92 | db_label=parameters["db_label"], 93 | user_name=SecretStr(parameters["user_name"]), 94 | password=SecretStr(parameters["password"]), 95 | host=parameters["host"], 96 | db_name=parameters["db_name"], 97 | ), 98 | db_case_config=PgDiskANNImplConfig( 99 | max_neighbors=parameters["max_neighbors"], 100 | l_value_ib=parameters["l_value_ib"], 101 | l_value_is=parameters["l_value_is"], 102 | max_parallel_workers=parameters["max_parallel_workers"], 103 | maintenance_work_mem=parameters["maintenance_work_mem"], 104 | ), 105 | **parameters, 106 | ) 107 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pgdiskann/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections.abc import Mapping, Sequence 3 | from typing import Any, LiteralString, TypedDict 4 | 5 | from pydantic import BaseModel, SecretStr 6 | 7 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 8 | 9 | POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" 10 | 11 | 12 | class PgDiskANNConfigDict(TypedDict): 13 | """These keys will be directly used as kwargs in psycopg connection string, 14 | so the names must match exactly psycopg API""" 15 | 16 | user: str 17 | password: str 18 | host: str 19 | port: int 20 | dbname: str 21 | 22 | 23 | class PgDiskANNConfig(DBConfig): 24 | user_name: SecretStr = SecretStr("postgres") 25 | password: SecretStr 26 | host: str = "localhost" 27 | port: int = 5432 28 | db_name: str 29 | 30 | def to_dict(self) -> PgDiskANNConfigDict: 31 | user_str = self.user_name.get_secret_value() 32 | pwd_str = self.password.get_secret_value() 33 | return { 34 | "host": self.host, 35 | "port": self.port, 36 | "dbname": self.db_name, 37 | "user": user_str, 38 | "password": pwd_str, 39 | } 40 | 41 | 42 | class PgDiskANNIndexConfig(BaseModel, DBCaseConfig): 43 | metric_type: MetricType | None = None 44 | create_index_before_load: bool = False 45 | create_index_after_load: bool = True 46 | maintenance_work_mem: str | None 47 | max_parallel_workers: int | None 48 | 49 | def parse_metric(self) -> str: 50 | if self.metric_type == MetricType.L2: 51 | return "vector_l2_ops" 52 | if self.metric_type == MetricType.IP: 53 | return "vector_ip_ops" 54 | return 
"vector_cosine_ops" 55 | 56 | def parse_metric_fun_op(self) -> LiteralString: 57 | if self.metric_type == MetricType.L2: 58 | return "<->" 59 | if self.metric_type == MetricType.IP: 60 | return "<#>" 61 | return "<=>" 62 | 63 | def parse_metric_fun_str(self) -> str: 64 | if self.metric_type == MetricType.L2: 65 | return "l2_distance" 66 | if self.metric_type == MetricType.IP: 67 | return "max_inner_product" 68 | return "cosine_distance" 69 | 70 | @abstractmethod 71 | def index_param(self) -> dict: ... 72 | 73 | @abstractmethod 74 | def search_param(self) -> dict: ... 75 | 76 | @abstractmethod 77 | def session_param(self) -> dict: ... 78 | 79 | @staticmethod 80 | def _optionally_build_with_options( 81 | with_options: Mapping[str, Any], 82 | ) -> Sequence[dict[str, Any]]: 83 | """Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause""" 84 | options = [] 85 | for option_name, value in with_options.items(): 86 | if value is not None: 87 | options.append( 88 | { 89 | "option_name": option_name, 90 | "val": str(value), 91 | }, 92 | ) 93 | return options 94 | 95 | @staticmethod 96 | def _optionally_build_set_options( 97 | set_mapping: Mapping[str, Any], 98 | ) -> Sequence[dict[str, Any]]: 99 | """Walk through options, creating 'SET 'key1 = "value1";' list""" 100 | session_options = [] 101 | for setting_name, value in set_mapping.items(): 102 | if value: 103 | session_options.append( 104 | { 105 | "parameter": { 106 | "setting_name": setting_name, 107 | "val": str(value), 108 | }, 109 | }, 110 | ) 111 | return session_options 112 | 113 | 114 | class PgDiskANNImplConfig(PgDiskANNIndexConfig): 115 | index: IndexType = IndexType.DISKANN 116 | max_neighbors: int | None 117 | l_value_ib: int | None 118 | l_value_is: float | None 119 | maintenance_work_mem: str | None = None 120 | max_parallel_workers: int | None = None 121 | 122 | def index_param(self) -> dict: 123 | return { 124 | "metric": self.parse_metric(), 125 | "index_type": self.index.value, 126 | "options": { 127 | "max_neighbors": self.max_neighbors, 128 | "l_value_ib": self.l_value_ib, 129 | }, 130 | "maintenance_work_mem": self.maintenance_work_mem, 131 | "max_parallel_workers": self.max_parallel_workers, 132 | } 133 | 134 | def search_param(self) -> dict: 135 | return { 136 | "metric": self.parse_metric(), 137 | "metric_fun_op": self.parse_metric_fun_op(), 138 | } 139 | 140 | def session_param(self) -> dict: 141 | return { 142 | "diskann.l_value_is": self.l_value_is, 143 | } 144 | 145 | 146 | _pgdiskann_case_config = { 147 | IndexType.DISKANN: PgDiskANNImplConfig, 148 | } 149 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pgvectorscale/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Annotated, Unpack 3 | 4 | import click 5 | from pydantic import SecretStr 6 | 7 | from vectordb_bench.backend.clients import DB 8 | 9 | from ....cli.cli import ( 10 | CommonTypedDict, 11 | cli, 12 | click_parameter_decorators_from_typed_dict, 13 | run, 14 | ) 15 | 16 | 17 | class PgVectorScaleTypedDict(CommonTypedDict): 18 | user_name: Annotated[ 19 | str, 20 | click.option("--user-name", type=str, help="Db username", required=True), 21 | ] 22 | password: Annotated[ 23 | str, 24 | click.option( 25 | "--password", 26 | type=str, 27 | help="Postgres database password", 28 | default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), 29 | show_default="$POSTGRES_PASSWORD", 
30 | ), 31 | ] 32 | 33 | host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] 34 | db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] 35 | 36 | 37 | class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict): 38 | storage_layout: Annotated[ 39 | str, 40 | click.option( 41 | "--storage-layout", 42 | type=str, 43 | help="Streaming DiskANN storage layout", 44 | ), 45 | ] 46 | num_neighbors: Annotated[ 47 | int, 48 | click.option( 49 | "--num-neighbors", 50 | type=int, 51 | help="Streaming DiskANN num neighbors", 52 | ), 53 | ] 54 | search_list_size: Annotated[ 55 | int, 56 | click.option( 57 | "--search-list-size", 58 | type=int, 59 | help="Streaming DiskANN search list size", 60 | ), 61 | ] 62 | max_alpha: Annotated[ 63 | float, 64 | click.option( 65 | "--max-alpha", 66 | type=float, 67 | help="Streaming DiskANN max alpha", 68 | ), 69 | ] 70 | num_dimensions: Annotated[ 71 | int, 72 | click.option( 73 | "--num-dimensions", 74 | type=int, 75 | help="Streaming DiskANN num dimensions", 76 | ), 77 | ] 78 | query_search_list_size: Annotated[ 79 | int, 80 | click.option( 81 | "--query-search-list-size", 82 | type=int, 83 | help="Streaming DiskANN query search list size", 84 | ), 85 | ] 86 | query_rescore: Annotated[ 87 | int, 88 | click.option( 89 | "--query-rescore", 90 | type=int, 91 | help="Streaming DiskANN query rescore", 92 | ), 93 | ] 94 | 95 | 96 | @cli.command() 97 | @click_parameter_decorators_from_typed_dict(PgVectorScaleDiskAnnTypedDict) 98 | def PgVectorScaleDiskAnn( 99 | **parameters: Unpack[PgVectorScaleDiskAnnTypedDict], 100 | ): 101 | from .config import PgVectorScaleConfig, PgVectorScaleStreamingDiskANNConfig 102 | 103 | run( 104 | db=DB.PgVectorScale, 105 | db_config=PgVectorScaleConfig( 106 | db_label=parameters["db_label"], 107 | user_name=SecretStr(parameters["user_name"]), 108 | password=SecretStr(parameters["password"]), 109 | host=parameters["host"], 110 | db_name=parameters["db_name"], 111 | ), 112 | db_case_config=PgVectorScaleStreamingDiskANNConfig( 113 | storage_layout=parameters["storage_layout"], 114 | num_neighbors=parameters["num_neighbors"], 115 | search_list_size=parameters["search_list_size"], 116 | max_alpha=parameters["max_alpha"], 117 | num_dimensions=parameters["num_dimensions"], 118 | query_search_list_size=parameters["query_search_list_size"], 119 | query_rescore=parameters["query_rescore"], 120 | ), 121 | **parameters, 122 | ) 123 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pgvectorscale/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import LiteralString, TypedDict 3 | 4 | from pydantic import BaseModel, SecretStr 5 | 6 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 7 | 8 | POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" 9 | 10 | 11 | class PgVectorScaleConfigDict(TypedDict): 12 | """These keys will be directly used as kwargs in psycopg connection string, 13 | so the names must match exactly psycopg API""" 14 | 15 | user: str 16 | password: str 17 | host: str 18 | port: int 19 | dbname: str 20 | 21 | 22 | class PgVectorScaleConfig(DBConfig): 23 | user_name: SecretStr = SecretStr("postgres") 24 | password: SecretStr 25 | host: str = "localhost" 26 | port: int = 5432 27 | db_name: str 28 | 29 | def to_dict(self) -> PgVectorScaleConfigDict: 30 | user_str = self.user_name.get_secret_value() 31 | 
pwd_str = self.password.get_secret_value() 32 | return { 33 | "host": self.host, 34 | "port": self.port, 35 | "dbname": self.db_name, 36 | "user": user_str, 37 | "password": pwd_str, 38 | } 39 | 40 | 41 | class PgVectorScaleIndexConfig(BaseModel, DBCaseConfig): 42 | metric_type: MetricType | None = None 43 | create_index_before_load: bool = False 44 | create_index_after_load: bool = True 45 | 46 | def parse_metric(self) -> str: 47 | if self.metric_type == MetricType.COSINE: 48 | return "vector_cosine_ops" 49 | return "" 50 | 51 | def parse_metric_fun_op(self) -> LiteralString: 52 | if self.metric_type == MetricType.COSINE: 53 | return "<=>" 54 | return "" 55 | 56 | def parse_metric_fun_str(self) -> str: 57 | if self.metric_type == MetricType.COSINE: 58 | return "cosine_distance" 59 | return "" 60 | 61 | @abstractmethod 62 | def index_param(self) -> dict: ... 63 | 64 | @abstractmethod 65 | def search_param(self) -> dict: ... 66 | 67 | @abstractmethod 68 | def session_param(self) -> dict: ... 69 | 70 | 71 | class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig): 72 | index: IndexType = IndexType.STREAMING_DISKANN 73 | storage_layout: str | None 74 | num_neighbors: int | None 75 | search_list_size: int | None 76 | max_alpha: float | None 77 | num_dimensions: int | None 78 | num_bits_per_dimension: int | None 79 | query_search_list_size: int | None 80 | query_rescore: int | None 81 | 82 | def index_param(self) -> dict: 83 | return { 84 | "metric": self.parse_metric(), 85 | "index_type": self.index.value, 86 | "options": { 87 | "storage_layout": self.storage_layout, 88 | "num_neighbors": self.num_neighbors, 89 | "search_list_size": self.search_list_size, 90 | "max_alpha": self.max_alpha, 91 | "num_dimensions": self.num_dimensions, 92 | }, 93 | } 94 | 95 | def search_param(self) -> dict: 96 | return { 97 | "metric": self.parse_metric(), 98 | "metric_fun_op": self.parse_metric_fun_op(), 99 | } 100 | 101 | def session_param(self) -> dict: 102 | return { 103 | "diskann.query_search_list_size": self.query_search_list_size, 104 | "diskann.query_rescore": self.query_rescore, 105 | } 106 | 107 | 108 | _pgvectorscale_case_config = { 109 | IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig, 110 | } 111 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pinecone/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import SecretStr 2 | 3 | from ..api import DBConfig 4 | 5 | 6 | class PineconeConfig(DBConfig): 7 | api_key: SecretStr 8 | index_name: str 9 | 10 | def to_dict(self) -> dict: 11 | return { 12 | "api_key": self.api_key.get_secret_value(), 13 | "index_name": self.index_name, 14 | } 15 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/pinecone/pinecone.py: -------------------------------------------------------------------------------- 1 | """Wrapper around the Pinecone vector database over VectorDB""" 2 | 3 | import logging 4 | from contextlib import contextmanager 5 | 6 | import pinecone 7 | 8 | from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB 9 | from .config import PineconeConfig 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | PINECONE_MAX_NUM_PER_BATCH = 1000 14 | PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB 15 | 16 | 17 | class Pinecone(VectorDB): 18 | def __init__( 19 | self, 20 | dim: int, 21 | db_config: dict, 22 | db_case_config: 
DBCaseConfig, 23 | drop_old: bool = False, 24 | **kwargs, 25 | ): 26 | """Initialize wrapper around the milvus vector database.""" 27 | self.index_name = db_config.get("index_name", "") 28 | self.api_key = db_config.get("api_key", "") 29 | self.batch_size = int( 30 | min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH), 31 | ) 32 | 33 | pc = pinecone.Pinecone(api_key=self.api_key) 34 | index = pc.Index(self.index_name) 35 | 36 | if drop_old: 37 | index_stats = index.describe_index_stats() 38 | index_dim = index_stats["dimension"] 39 | if index_dim != dim: 40 | msg = f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}" 41 | raise ValueError(msg) 42 | for namespace in index_stats["namespaces"]: 43 | log.info(f"Pinecone index delete namespace: {namespace}") 44 | index.delete(delete_all=True, namespace=namespace) 45 | 46 | self._metadata_key = "meta" 47 | 48 | @classmethod 49 | def config_cls(cls) -> type[DBConfig]: 50 | return PineconeConfig 51 | 52 | @classmethod 53 | def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]: 54 | return EmptyDBCaseConfig 55 | 56 | @contextmanager 57 | def init(self): 58 | pc = pinecone.Pinecone(api_key=self.api_key) 59 | self.index = pc.Index(self.index_name) 60 | yield 61 | 62 | def optimize(self, data_size: int | None = None): 63 | pass 64 | 65 | def insert_embeddings( 66 | self, 67 | embeddings: list[list[float]], 68 | metadata: list[int], 69 | **kwargs, 70 | ) -> tuple[int, Exception]: 71 | assert len(embeddings) == len(metadata) 72 | insert_count = 0 73 | try: 74 | for batch_start_offset in range(0, len(embeddings), self.batch_size): 75 | batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) 76 | insert_datas = [] 77 | for i in range(batch_start_offset, batch_end_offset): 78 | insert_data = ( 79 | str(metadata[i]), 80 | embeddings[i], 81 | {self._metadata_key: metadata[i]}, 82 | ) 83 | insert_datas.append(insert_data) 84 | self.index.upsert(insert_datas) 85 | insert_count += batch_end_offset - batch_start_offset 86 | except Exception as e: 87 | return (insert_count, e) 88 | return (len(embeddings), None) 89 | 90 | def search_embedding( 91 | self, 92 | query: list[float], 93 | k: int = 100, 94 | filters: dict | None = None, 95 | timeout: int | None = None, 96 | ) -> list[int]: 97 | pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}} 98 | try: 99 | res = self.index.query( 100 | top_k=k, 101 | vector=query, 102 | filter=pinecone_filters, 103 | )["matches"] 104 | except Exception as e: 105 | log.warning(f"Error querying index: {e}") 106 | raise e from e 107 | return [int(one_res["id"]) for one_res in res] 108 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/qdrant_cloud/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | cli, 9 | click_parameter_decorators_from_typed_dict, 10 | run, 11 | ) 12 | from .. 
import DB 13 | 14 | 15 | class QdrantTypedDict(CommonTypedDict): 16 | url: Annotated[ 17 | str, 18 | click.option("--url", type=str, help="URL connection string", required=True), 19 | ] 20 | api_key: Annotated[ 21 | str | None, 22 | click.option("--api-key", type=str, help="API key for authentication", required=False), 23 | ] 24 | 25 | 26 | @cli.command() 27 | @click_parameter_decorators_from_typed_dict(QdrantTypedDict) 28 | def QdrantCloud(**parameters: Unpack[QdrantTypedDict]): 29 | from .config import QdrantConfig, QdrantIndexConfig 30 | 31 | config_params = { 32 | "db_label": parameters["db_label"], 33 | "url": SecretStr(parameters["url"]), 34 | } 35 | 36 | config_params["api_key"] = SecretStr(parameters["api_key"]) if parameters["api_key"] else None 37 | 38 | run( 39 | db=DB.QdrantCloud, 40 | db_config=QdrantConfig(**config_params), 41 | db_case_config=QdrantIndexConfig(), 42 | **parameters, 43 | ) 44 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/qdrant_cloud/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, MetricType 4 | 5 | 6 | # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant. 7 | class QdrantConfig(DBConfig): 8 | url: SecretStr 9 | api_key: SecretStr | None = None 10 | 11 | def to_dict(self) -> dict: 12 | api_key_value = self.api_key.get_secret_value() if self.api_key else None 13 | if api_key_value: 14 | return { 15 | "url": self.url.get_secret_value(), 16 | "api_key": api_key_value, 17 | "prefer_grpc": True, 18 | } 19 | return { 20 | "url": self.url.get_secret_value(), 21 | } 22 | 23 | 24 | class QdrantIndexConfig(BaseModel, DBCaseConfig): 25 | metric_type: MetricType | None = None 26 | 27 | def parse_metric(self) -> str: 28 | if self.metric_type == MetricType.L2: 29 | return "Euclid" 30 | 31 | if self.metric_type == MetricType.IP: 32 | return "Dot" 33 | 34 | return "Cosine" 35 | 36 | def index_param(self) -> dict: 37 | return {"distance": self.parse_metric()} 38 | 39 | def search_param(self) -> dict: 40 | return {} 41 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/redis/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, TypedDict, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | HNSWFlavor2, 9 | cli, 10 | click_parameter_decorators_from_typed_dict, 11 | run, 12 | ) 13 | from .. 
import DB 14 | from .config import RedisHNSWConfig 15 | 16 | 17 | class RedisTypedDict(TypedDict): 18 | host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] 19 | password: Annotated[str, click.option("--password", type=str, help="Db password")] 20 | port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] 21 | ssl: Annotated[ 22 | bool, 23 | click.option( 24 | "--ssl/--no-ssl", 25 | is_flag=True, 26 | show_default=True, 27 | default=True, 28 | help="Enable or disable SSL for Redis", 29 | ), 30 | ] 31 | ssl_ca_certs: Annotated[ 32 | str, 33 | click.option( 34 | "--ssl-ca-certs", 35 | show_default=True, 36 | help="Path to certificate authority file to use for SSL", 37 | ), 38 | ] 39 | cmd: Annotated[ 40 | bool, 41 | click.option( 42 | "--cmd", 43 | is_flag=True, 44 | show_default=True, 45 | default=False, 46 | help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn", 47 | ), 48 | ] 49 | 50 | 51 | class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): ... 52 | 53 | 54 | @cli.command() 55 | @click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict) 56 | def Redis(**parameters: Unpack[RedisHNSWTypedDict]): 57 | from .config import RedisConfig 58 | 59 | run( 60 | db=DB.Redis, 61 | db_config=RedisConfig( 62 | db_label=parameters["db_label"], 63 | password=SecretStr(parameters["password"]) if parameters["password"] else None, 64 | host=SecretStr(parameters["host"]), 65 | port=parameters["port"], 66 | ssl=parameters["ssl"], 67 | ssl_ca_certs=parameters["ssl_ca_certs"], 68 | cmd=parameters["cmd"], 69 | ), 70 | db_case_config=RedisHNSWConfig( 71 | M=parameters["m"], 72 | efConstruction=parameters["ef_construction"], 73 | ef=parameters["ef_runtime"], 74 | ), 75 | **parameters, 76 | ) 77 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/redis/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, IndexType, MetricType 4 | 5 | 6 | class RedisConfig(DBConfig): 7 | password: SecretStr | None = None 8 | host: SecretStr 9 | port: int | None = None 10 | 11 | def to_dict(self) -> dict: 12 | return { 13 | "host": self.host.get_secret_value(), 14 | "port": self.port, 15 | "password": self.password.get_secret_value() if self.password is not None else None, 16 | } 17 | 18 | 19 | class RedisIndexConfig(BaseModel): 20 | """Base config for milvus""" 21 | 22 | metric_type: MetricType | None = None 23 | 24 | def parse_metric(self) -> str: 25 | if not self.metric_type: 26 | return "" 27 | return self.metric_type.value 28 | 29 | 30 | class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig): 31 | M: int 32 | efConstruction: int 33 | ef: int | None = None 34 | index: IndexType = IndexType.HNSW 35 | 36 | def index_param(self) -> dict: 37 | return { 38 | "metric_type": self.parse_metric(), 39 | "index_type": self.index.value, 40 | "params": {"M": self.M, "efConstruction": self.efConstruction}, 41 | } 42 | 43 | def search_param(self) -> dict: 44 | return { 45 | "metric_type": self.parse_metric(), 46 | "params": {"ef": self.ef}, 47 | } 48 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/test/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Unpack 2 | 3 | from ....cli.cli import ( 4 | CommonTypedDict, 5 | cli, 6 | 
click_parameter_decorators_from_typed_dict, 7 | run, 8 | ) 9 | from .. import DB 10 | from ..test.config import TestConfig, TestIndexConfig 11 | 12 | 13 | class TestTypedDict(CommonTypedDict): ... 14 | 15 | 16 | @cli.command() 17 | @click_parameter_decorators_from_typed_dict(TestTypedDict) 18 | def Test(**parameters: Unpack[TestTypedDict]): 19 | run( 20 | db=DB.Test, 21 | db_config=TestConfig(db_label=parameters["db_label"]), 22 | db_case_config=TestIndexConfig(), 23 | **parameters, 24 | ) 25 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/test/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from ..api import DBCaseConfig, DBConfig, MetricType 4 | 5 | 6 | class TestConfig(DBConfig): 7 | def to_dict(self) -> dict: 8 | return {"db_label": self.db_label} 9 | 10 | 11 | class TestIndexConfig(BaseModel, DBCaseConfig): 12 | metric_type: MetricType | None = None 13 | 14 | def index_param(self) -> dict: 15 | return {} 16 | 17 | def search_param(self) -> dict: 18 | return {} 19 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/test/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections.abc import Generator 3 | from contextlib import contextmanager 4 | from typing import Any 5 | 6 | from ..api import DBCaseConfig, VectorDB 7 | 8 | log = logging.getLogger(__name__) 9 | 10 | 11 | class Test(VectorDB): 12 | def __init__( 13 | self, 14 | dim: int, 15 | db_config: dict, 16 | db_case_config: DBCaseConfig, 17 | drop_old: bool = False, 18 | **kwargs, 19 | ): 20 | self.db_config = db_config 21 | self.case_config = db_case_config 22 | 23 | log.info("Starting Test DB") 24 | 25 | @contextmanager 26 | def init(self) -> Generator[None, None, None]: 27 | """create and destroy connections to database. 28 | 29 | Examples: 30 | >>> with self.init(): 31 | >>> self.insert_embeddings() 32 | """ 33 | 34 | yield 35 | 36 | def optimize(self, data_size: int | None = None): 37 | pass 38 | 39 | def insert_embeddings( 40 | self, 41 | embeddings: list[list[float]], 42 | metadata: list[int], 43 | **kwargs: Any, 44 | ) -> tuple[int, Exception | None]: 45 | """Insert embeddings into the database. 46 | Should call self.init() first. 
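This stub stores nothing; it simply reports that every row was inserted.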
47 | """ 48 | return len(metadata), None 49 | 50 | def search_embedding( 51 | self, 52 | query: list[float], 53 | k: int = 100, 54 | filters: dict | None = None, 55 | timeout: int | None = None, 56 | **kwargs: Any, 57 | ) -> list[int]: 58 | return list(range(k)) 59 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/tidb/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from vectordb_bench.backend.clients import DB 7 | 8 | from ....cli.cli import CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run 9 | 10 | 11 | class TiDBTypedDict(CommonTypedDict): 12 | user_name: Annotated[ 13 | str, 14 | click.option( 15 | "--username", 16 | type=str, 17 | help="Username", 18 | default="root", 19 | show_default=True, 20 | required=True, 21 | ), 22 | ] 23 | password: Annotated[ 24 | str, 25 | click.option( 26 | "--password", 27 | type=str, 28 | default="", 29 | show_default=True, 30 | help="Password", 31 | ), 32 | ] 33 | host: Annotated[ 34 | str, 35 | click.option( 36 | "--host", 37 | type=str, 38 | default="127.0.0.1", 39 | show_default=True, 40 | required=True, 41 | help="Db host", 42 | ), 43 | ] 44 | port: Annotated[ 45 | int, 46 | click.option( 47 | "--port", 48 | type=int, 49 | default=4000, 50 | show_default=True, 51 | required=True, 52 | help="Db Port", 53 | ), 54 | ] 55 | db_name: Annotated[ 56 | str, 57 | click.option( 58 | "--db-name", 59 | type=str, 60 | default="test", 61 | show_default=True, 62 | required=True, 63 | help="Db name", 64 | ), 65 | ] 66 | ssl: Annotated[ 67 | bool, 68 | click.option( 69 | "--ssl/--no-ssl", 70 | default=False, 71 | show_default=True, 72 | is_flag=True, 73 | help="Enable or disable SSL, for TiDB Serverless SSL must be enabled", 74 | ), 75 | ] 76 | 77 | 78 | @cli.command() 79 | @click_parameter_decorators_from_typed_dict(TiDBTypedDict) 80 | def TiDB( 81 | **parameters: Unpack[TiDBTypedDict], 82 | ): 83 | from .config import TiDBConfig, TiDBIndexConfig 84 | 85 | run( 86 | db=DB.TiDB, 87 | db_config=TiDBConfig( 88 | db_label=parameters["db_label"], 89 | user_name=parameters["username"], 90 | password=SecretStr(parameters["password"]), 91 | host=parameters["host"], 92 | port=parameters["port"], 93 | db_name=parameters["db_name"], 94 | ssl=parameters["ssl"], 95 | ), 96 | db_case_config=TiDBIndexConfig(), 97 | **parameters, 98 | ) 99 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/tidb/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, MetricType 4 | 5 | 6 | class TiDBConfig(DBConfig): 7 | user_name: str = "root" 8 | password: SecretStr 9 | host: str = "127.0.0.1" 10 | port: int = 4000 11 | db_name: str = "test" 12 | ssl: bool = False 13 | 14 | def to_dict(self) -> dict: 15 | pwd_str = self.password.get_secret_value() 16 | return { 17 | "host": self.host, 18 | "port": self.port, 19 | "user": self.user_name, 20 | "password": pwd_str, 21 | "database": self.db_name, 22 | "ssl_verify_cert": self.ssl, 23 | "ssl_verify_identity": self.ssl, 24 | } 25 | 26 | 27 | class TiDBIndexConfig(BaseModel, DBCaseConfig): 28 | metric_type: MetricType | None = None 29 | 30 | def get_metric_fn(self) -> str: 31 | if self.metric_type == MetricType.L2: 32 | return 
"vec_l2_distance" 33 | if self.metric_type == MetricType.COSINE: 34 | return "vec_cosine_distance" 35 | msg = f"Unsupported metric type: {self.metric_type}" 36 | raise ValueError(msg) 37 | 38 | def index_param(self) -> dict: 39 | return { 40 | "metric_fn": self.get_metric_fn(), 41 | } 42 | 43 | def search_param(self) -> dict: 44 | return { 45 | "metric_fn": self.get_metric_fn(), 46 | } 47 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/vespa/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from vectordb_bench.backend.clients import DB 7 | from vectordb_bench.cli.cli import ( 8 | CommonTypedDict, 9 | HNSWFlavor1, 10 | cli, 11 | click_parameter_decorators_from_typed_dict, 12 | run, 13 | ) 14 | 15 | 16 | class VespaTypedDict(CommonTypedDict, HNSWFlavor1): 17 | uri: Annotated[ 18 | str, 19 | click.option("--uri", "-u", type=str, help="uri connection string", default="http://127.0.0.1"), 20 | ] 21 | port: Annotated[ 22 | int, 23 | click.option("--port", "-p", type=int, help="connection port", default=8080), 24 | ] 25 | quantization: Annotated[ 26 | str, click.option("--quantization", type=click.Choice(["none", "binary"], case_sensitive=False), default="none") 27 | ] 28 | 29 | 30 | @cli.command() 31 | @click_parameter_decorators_from_typed_dict(VespaTypedDict) 32 | def Vespa(**params: Unpack[VespaTypedDict]): 33 | from .config import VespaConfig, VespaHNSWConfig 34 | 35 | case_params = { 36 | "quantization_type": params["quantization"], 37 | "M": params["m"], 38 | "efConstruction": params["ef_construction"], 39 | "ef": params["ef_search"], 40 | } 41 | 42 | run( 43 | db=DB.Vespa, 44 | db_config=VespaConfig(url=SecretStr(params["uri"]), port=params["port"]), 45 | db_case_config=VespaHNSWConfig(**{k: v for k, v in case_params.items() if v}), 46 | **params, 47 | ) 48 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/vespa/config.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, TypeAlias 2 | 3 | from pydantic import BaseModel, SecretStr 4 | 5 | from ..api import DBCaseConfig, DBConfig, MetricType 6 | 7 | VespaMetric: TypeAlias = Literal["euclidean", "angular", "dotproduct", "prenormalized-angular", "hamming", "geodegrees"] 8 | 9 | VespaQuantizationType: TypeAlias = Literal["none", "binary"] 10 | 11 | 12 | class VespaConfig(DBConfig): 13 | url: SecretStr = "http://127.0.0.1" 14 | port: int = 8080 15 | 16 | def to_dict(self): 17 | return { 18 | "url": self.url.get_secret_value(), 19 | "port": self.port, 20 | } 21 | 22 | 23 | class VespaHNSWConfig(BaseModel, DBCaseConfig): 24 | metric_type: MetricType = MetricType.COSINE 25 | quantization_type: VespaQuantizationType = "none" 26 | M: int = 16 27 | efConstruction: int = 200 28 | ef: int = 100 29 | 30 | def index_param(self) -> dict: 31 | return { 32 | "distance_metric": self.parse_metric(self.metric_type), 33 | "max_links_per_node": self.M, 34 | "neighbors_to_explore_at_insert": self.efConstruction, 35 | } 36 | 37 | def search_param(self) -> dict: 38 | return {} 39 | 40 | def parse_metric(self, metric_type: MetricType) -> VespaMetric: 41 | match metric_type: 42 | case MetricType.COSINE: 43 | return "angular" 44 | case MetricType.L2: 45 | return "euclidean" 46 | case MetricType.DP | MetricType.IP: 47 | return 
"dotproduct" 48 | case MetricType.HAMMING: 49 | return "hamming" 50 | case _: 51 | raise NotImplementedError 52 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/vespa/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for supporting binary quantization 2 | 3 | From https://docs.vespa.ai/en/binarizing-vectors.html#appendix-conversion-to-int8 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | def binarize_tensor(tensor: list[float]) -> list[int]: 10 | """ 11 | Binarize a floating-point list by thresholding at zero 12 | and packing the bits into bytes. 13 | """ 14 | tensor = np.array(tensor) 15 | return np.packbits(np.where(tensor > 0, 1, 0), axis=0).astype(np.int8).tolist() 16 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/weaviate_cloud/cli.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Unpack 2 | 3 | import click 4 | from pydantic import SecretStr 5 | 6 | from ....cli.cli import ( 7 | CommonTypedDict, 8 | cli, 9 | click_parameter_decorators_from_typed_dict, 10 | run, 11 | ) 12 | from .. import DB 13 | 14 | 15 | class WeaviateTypedDict(CommonTypedDict): 16 | api_key: Annotated[ 17 | str, 18 | click.option("--api-key", type=str, help="Weaviate api key", required=True), 19 | ] 20 | url: Annotated[ 21 | str, 22 | click.option("--url", type=str, help="Weaviate url", required=True), 23 | ] 24 | 25 | 26 | @cli.command() 27 | @click_parameter_decorators_from_typed_dict(WeaviateTypedDict) 28 | def Weaviate(**parameters: Unpack[WeaviateTypedDict]): 29 | from .config import WeaviateConfig, WeaviateIndexConfig 30 | 31 | run( 32 | db=DB.WeaviateCloud, 33 | db_config=WeaviateConfig( 34 | db_label=parameters["db_label"], 35 | api_key=SecretStr(parameters["api_key"]), 36 | url=SecretStr(parameters["url"]), 37 | ), 38 | db_case_config=WeaviateIndexConfig(ef=256, efConstruction=256, maxConnections=16), 39 | **parameters, 40 | ) 41 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/weaviate_cloud/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig, MetricType 4 | 5 | 6 | class WeaviateConfig(DBConfig): 7 | url: SecretStr 8 | api_key: SecretStr 9 | 10 | def to_dict(self) -> dict: 11 | return { 12 | "url": self.url.get_secret_value(), 13 | "auth_client_secret": self.api_key.get_secret_value(), 14 | } 15 | 16 | 17 | class WeaviateIndexConfig(BaseModel, DBCaseConfig): 18 | metric_type: MetricType | None = None 19 | ef: int | None = -1 20 | efConstruction: int | None = None 21 | maxConnections: int | None = None 22 | 23 | def parse_metric(self) -> str: 24 | if self.metric_type == MetricType.L2: 25 | return "l2-squared" 26 | if self.metric_type == MetricType.IP: 27 | return "dot" 28 | return "cosine" 29 | 30 | def index_param(self) -> dict: 31 | if self.maxConnections is not None and self.efConstruction is not None: 32 | params = { 33 | "distance": self.parse_metric(), 34 | "maxConnections": self.maxConnections, 35 | "efConstruction": self.efConstruction, 36 | } 37 | else: 38 | params = {"distance": self.parse_metric()} 39 | return params 40 | 41 | def search_param(self) -> dict: 42 | return { 43 | "ef": self.ef, 44 | } 45 | 
-------------------------------------------------------------------------------- /vectordb_bench/backend/clients/zilliz_cloud/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Annotated, Unpack 3 | 4 | import click 5 | from pydantic import SecretStr 6 | 7 | from vectordb_bench.backend.clients import DB 8 | from vectordb_bench.cli.cli import ( 9 | CommonTypedDict, 10 | cli, 11 | click_parameter_decorators_from_typed_dict, 12 | run, 13 | ) 14 | 15 | 16 | class ZillizTypedDict(CommonTypedDict): 17 | uri: Annotated[ 18 | str, 19 | click.option("--uri", type=str, help="uri connection string", required=True), 20 | ] 21 | user_name: Annotated[ 22 | str, 23 | click.option("--user-name", type=str, help="Db username", required=True), 24 | ] 25 | password: Annotated[ 26 | str, 27 | click.option( 28 | "--password", 29 | type=str, 30 | help="Zilliz password", 31 | default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""), 32 | show_default="$ZILLIZ_PASSWORD", 33 | ), 34 | ] 35 | level: Annotated[ 36 | str, 37 | click.option("--level", type=str, help="Zilliz index level", required=False), 38 | ] 39 | 40 | 41 | @cli.command() 42 | @click_parameter_decorators_from_typed_dict(ZillizTypedDict) 43 | def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]): 44 | from .config import AutoIndexConfig, ZillizCloudConfig 45 | 46 | run( 47 | db=DB.ZillizCloud, 48 | db_config=ZillizCloudConfig( 49 | db_label=parameters["db_label"], 50 | uri=SecretStr(parameters["uri"]), 51 | user=parameters["user_name"], 52 | password=SecretStr(parameters["password"]), 53 | ), 54 | db_case_config=AutoIndexConfig( 55 | params={parameters["level"]}, 56 | ), 57 | **parameters, 58 | ) 59 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/zilliz_cloud/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import SecretStr 2 | 3 | from ..api import DBCaseConfig, DBConfig 4 | from ..milvus.config import IndexType, MilvusIndexConfig 5 | 6 | 7 | class ZillizCloudConfig(DBConfig): 8 | uri: SecretStr 9 | user: str 10 | password: SecretStr 11 | 12 | def to_dict(self) -> dict: 13 | return { 14 | "uri": self.uri.get_secret_value(), 15 | "user": self.user, 16 | "password": self.password.get_secret_value(), 17 | } 18 | 19 | 20 | class AutoIndexConfig(MilvusIndexConfig, DBCaseConfig): 21 | index: IndexType = IndexType.AUTOINDEX 22 | level: int = 1 23 | 24 | def index_param(self) -> dict: 25 | return { 26 | "metric_type": self.parse_metric(), 27 | "index_type": self.index.value, 28 | "params": {}, 29 | } 30 | 31 | def search_param(self) -> dict: 32 | return { 33 | "metric_type": self.parse_metric(), 34 | "params": { 35 | "level": self.level, 36 | }, 37 | } 38 | -------------------------------------------------------------------------------- /vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py: -------------------------------------------------------------------------------- 1 | """Wrapper around the ZillizCloud vector database over VectorDB""" 2 | 3 | from ..api import DBCaseConfig 4 | from ..milvus.milvus import Milvus 5 | 6 | 7 | class ZillizCloud(Milvus): 8 | def __init__( 9 | self, 10 | dim: int, 11 | db_config: dict, 12 | db_case_config: DBCaseConfig, 13 | collection_name: str = "ZillizCloudVectorDBBench", 14 | drop_old: bool = False, 15 | name: str = "ZillizCloud", 16 | **kwargs, 17 | ): 18 | super().__init__( 19 | dim=dim, 20 | 
db_config=db_config, 21 | db_case_config=db_case_config, 22 | collection_name=collection_name, 23 | drop_old=drop_old, 24 | name=name, 25 | **kwargs, 26 | ) 27 | -------------------------------------------------------------------------------- /vectordb_bench/backend/result_collector.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | from vectordb_bench.models import TestResult 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | class ResultCollector: 10 | @classmethod 11 | def collect(cls, result_dir: pathlib.Path) -> list[TestResult]: 12 | reg = "result_*.json" 13 | results_d = {} 14 | if not result_dir.exists() or len(list(result_dir.rglob(reg))) == 0: 15 | return [] 16 | 17 | for json_file in result_dir.rglob(reg): 18 | file_result = TestResult.read_file(json_file, trans_unit=True) 19 | 20 | # Group result files of the same run_id into one TestResult 21 | if file_result.run_id in results_d: 22 | results_d[file_result.run_id].results.extend(file_result.results) 23 | else: 24 | results_d[file_result.run_id] = file_result 25 | 26 | return list(results_d.values()) 27 | -------------------------------------------------------------------------------- /vectordb_bench/backend/runner/__init__.py: -------------------------------------------------------------------------------- 1 | from .mp_runner import ( 2 | MultiProcessingSearchRunner, 3 | ) 4 | from .serial_runner import SerialInsertRunner, SerialSearchRunner 5 | 6 | __all__ = [ 7 | "MultiProcessingSearchRunner", 8 | "SerialInsertRunner", 9 | "SerialSearchRunner", 10 | ] 11 | -------------------------------------------------------------------------------- /vectordb_bench/backend/runner/rate_runner.py: -------------------------------------------------------------------------------- 1 | import concurrent 2 | import logging 3 | import multiprocessing as mp 4 | import time 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from vectordb_bench import config 8 | from vectordb_bench.backend.clients import api 9 | from vectordb_bench.backend.dataset import DataSetIterator 10 | from vectordb_bench.backend.utils import time_it 11 | 12 | from .util import get_data 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | class RatedMultiThreadingInsertRunner: 18 | def __init__( 19 | self, 20 | rate: int, # numRows per second 21 | db: api.VectorDB, 22 | dataset_iter: DataSetIterator, 23 | normalize: bool = False, 24 | timeout: float | None = None, 25 | ): 26 | self.timeout = timeout if isinstance(timeout, int | float) else None 27 | self.dataset = dataset_iter 28 | self.db = db 29 | self.normalize = normalize 30 | self.insert_rate = rate 31 | self.batch_rate = rate // config.NUM_PER_BATCH 32 | 33 | def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]): 34 | db.insert_embeddings(emb, metadata) 35 | 36 | @time_it 37 | def run_with_rate(self, q: mp.Queue): 38 | with ThreadPoolExecutor(max_workers=mp.cpu_count()) as executor: 39 | executing_futures = [] 40 | 41 | @time_it 42 | def submit_by_rate() -> bool: 43 | rate = self.batch_rate 44 | for data in self.dataset: 45 | emb, metadata = get_data(data, self.normalize) 46 | executing_futures.append( 47 | executor.submit(self.send_insert_task, self.db, emb, metadata), 48 | ) 49 | rate -= 1 50 | 51 | if rate == 0: 52 | return False 53 | return rate == self.batch_rate 54 | 55 | with self.db.init(): 56 | while True: 57 | start_time = time.perf_counter() 58 | finished, elapsed_time = 
submit_by_rate() 59 | if finished is True: 60 | q.put(True, block=True) 61 | log.info(f"End of dataset, left unfinished={len(executing_futures)}") 62 | break 63 | 64 | q.put(False, block=False) 65 | wait_interval = 1 - elapsed_time if elapsed_time < 1 else 0.001 66 | 67 | try: 68 | done, not_done = concurrent.futures.wait( 69 | executing_futures, 70 | timeout=wait_interval, 71 | return_when=concurrent.futures.FIRST_EXCEPTION, 72 | ) 73 | 74 | if len(not_done) > 0: 75 | log.warning( 76 | f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] " 77 | f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round" 78 | ) 79 | executing_futures = list(not_done) 80 | else: 81 | log.debug( 82 | f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} " 83 | f"task in 1s, wait_interval={wait_interval:.2f}" 84 | ) 85 | executing_futures = [] 86 | except Exception as e: 87 | log.warning(f"task error, terminating, err={e}") 88 | q.put(None, block=True) 89 | executor.shutdown(wait=True, cancel_futures=True) 90 | raise e from e 91 | 92 | dur = time.perf_counter() - start_time 93 | if dur < 1: 94 | time.sleep(1 - dur) 95 | 96 | # wait for all tasks in executing_futures to complete 97 | if len(executing_futures) > 0: 98 | try: 99 | done, _ = concurrent.futures.wait( 100 | executing_futures, 101 | return_when=concurrent.futures.FIRST_EXCEPTION, 102 | ) 103 | except Exception as e: 104 | log.warning(f"task error, terminating, err={e}") 105 | q.put(None, block=True) 106 | executor.shutdown(wait=True, cancel_futures=True) 107 | raise e from e 108 | -------------------------------------------------------------------------------- /vectordb_bench/backend/runner/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from pandas import DataFrame 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], list[str]]: 10 | all_metadata = data_df["id"].tolist() 11 | emb_np = np.stack(data_df["emb"]) 12 | if normalize: 13 | log.debug("normalize the 100k train data") 14 | all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() 15 | else: 16 | all_embeddings = emb_np.tolist() 17 | return all_embeddings, all_metadata 18 | -------------------------------------------------------------------------------- /vectordb_bench/backend/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import wraps 3 | 4 | 5 | def numerize(n: int) -> str: 6 | """display positive number n for readability 7 | 8 | Examples: 9 | >>> numerize(1_000) 10 | '1K' 11 | >>> numerize(1_000_000_000) 12 | '1B' 13 | """ 14 | sufix2upbound = { 15 | "EMPTY": 1e3, 16 | "K": 1e6, 17 | "M": 1e9, 18 | "B": 1e12, 19 | "END": float("inf"), 20 | } 21 | 22 | display_n, sufix = n, "" 23 | for s, base in sufix2upbound.items(): 24 | # number >= 1000B will alway have sufix 'B' 25 | if s == "END": 26 | display_n = int(n / 1e9) 27 | sufix = "B" 28 | break 29 | 30 | if n < base: 31 | sufix = "" if s == "EMPTY" else s 32 | display_n = int(n / (base / 1e3)) 33 | break 34 | return f"{display_n}{sufix}" 35 | 36 | 37 | def time_it(func: any): 38 | """returns result and elapsed time""" 39 | 40 | @wraps(func) 41 | def inner(*args, **kwargs): 42 | pref = time.perf_counter() 43 | result = func(*args, **kwargs) 44 | delta = time.perf_counter() - pref 45 | return result, delta 46 | 
47 | return inner 48 | 49 | 50 | def compose_train_files(train_count: int, use_shuffled: bool) -> list[str]: 51 | prefix = "shuffle_train" if use_shuffled else "train" 52 | middle = f"of-{train_count}" 53 | surfix = "parquet" 54 | 55 | train_files = [] 56 | if train_count > 1: 57 | just_size = 2 58 | for i in range(train_count): 59 | sub_file = f"{prefix}-{str(i).rjust(just_size, '0')}-{middle}.{surfix}" 60 | train_files.append(sub_file) 61 | else: 62 | train_files.append(f"{prefix}.{surfix}") 63 | 64 | return train_files 65 | 66 | 67 | ONE_PERCENT = 0.01 68 | NINETY_NINE_PERCENT = 0.99 69 | 70 | 71 | def compose_gt_file(filters: float | str | None = None) -> str: 72 | if filters is None: 73 | return "neighbors.parquet" 74 | 75 | if filters == ONE_PERCENT: 76 | return "neighbors_head_1p.parquet" 77 | 78 | if filters == NINETY_NINE_PERCENT: 79 | return "neighbors_tail_1p.parquet" 80 | 81 | msg = f"Filters not supported: {filters}" 82 | raise ValueError(msg) 83 | -------------------------------------------------------------------------------- /vectordb_bench/base.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel as PydanticBaseModel 2 | 3 | 4 | class BaseModel(PydanticBaseModel, arbitrary_types_allowed=True): 5 | pass 6 | -------------------------------------------------------------------------------- /vectordb_bench/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zilliztech/VectorDBBench/571de32b1cae8210f0ce5426981347add2d5c61d/vectordb_bench/cli/__init__.py -------------------------------------------------------------------------------- /vectordb_bench/cli/vectordbbench.py: -------------------------------------------------------------------------------- 1 | from ..backend.clients.alloydb.cli import AlloyDBScaNN 2 | from ..backend.clients.aws_opensearch.cli import AWSOpenSearch 3 | from ..backend.clients.clickhouse.cli import Clickhouse 4 | from ..backend.clients.lancedb.cli import LanceDB 5 | from ..backend.clients.mariadb.cli import MariaDBHNSW 6 | from ..backend.clients.memorydb.cli import MemoryDB 7 | from ..backend.clients.milvus.cli import MilvusAutoIndex 8 | from ..backend.clients.pgdiskann.cli import PgDiskAnn 9 | from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat 10 | from ..backend.clients.pgvector.cli import PgVectorHNSW 11 | from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn 12 | from ..backend.clients.qdrant_cloud.cli import QdrantCloud 13 | from ..backend.clients.redis.cli import Redis 14 | from ..backend.clients.test.cli import Test 15 | from ..backend.clients.tidb.cli import TiDB 16 | from ..backend.clients.vespa.cli import Vespa 17 | from ..backend.clients.weaviate_cloud.cli import Weaviate 18 | from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex 19 | from .cli import cli 20 | 21 | cli.add_command(PgVectorHNSW) 22 | cli.add_command(PgVectoRSHNSW) 23 | cli.add_command(PgVectoRSIVFFlat) 24 | cli.add_command(Redis) 25 | cli.add_command(MemoryDB) 26 | cli.add_command(Weaviate) 27 | cli.add_command(Test) 28 | cli.add_command(ZillizAutoIndex) 29 | cli.add_command(MilvusAutoIndex) 30 | cli.add_command(AWSOpenSearch) 31 | cli.add_command(PgVectorScaleDiskAnn) 32 | cli.add_command(PgDiskAnn) 33 | cli.add_command(AlloyDBScaNN) 34 | cli.add_command(MariaDBHNSW) 35 | cli.add_command(TiDB) 36 | cli.add_command(Clickhouse) 37 | cli.add_command(Vespa) 38 | cli.add_command(LanceDB) 39 
| cli.add_command(QdrantCloud) 40 | 41 | 42 | if __name__ == "__main__": 43 | cli() 44 | -------------------------------------------------------------------------------- /vectordb_bench/config-files/sample_config.yml: -------------------------------------------------------------------------------- 1 | pgvectorhnsw: 2 | db_label: pgConfigTest 3 | user_name: vectordbbench 4 | db_name: vectordbbench 5 | host: localhost 6 | m: 16 7 | ef_construction: 128 8 | ef_search: 128 9 | milvushnsw: 10 | skip_search_serial: True 11 | case_type: Performance1536D50K 12 | uri: http://localhost:19530 13 | m: 16 14 | ef_construction: 128 15 | ef_search: 128 16 | drop_old: False 17 | load: False 18 | -------------------------------------------------------------------------------- /vectordb_bench/custom/custom_case.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "My Dataset (Performace Case)", 4 | "description": "this is a customized dataset.", 5 | "load_timeout": 36000, 6 | "optimize_timeout": 36000, 7 | "dataset_config": { 8 | "name": "My Dataset", 9 | "dir": "/my_dataset_path", 10 | "size": 1000000, 11 | "dim": 1024, 12 | "metric_type": "L2", 13 | "file_count": 1, 14 | "use_shuffled": false, 15 | "with_gt": true 16 | } 17 | } 18 | ] -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/check_results/data.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import asdict 3 | from vectordb_bench.metric import isLowerIsBetterMetric 4 | from vectordb_bench.models import CaseResult, ResultLabel 5 | 6 | 7 | def getChartData( 8 | tasks: list[CaseResult], 9 | dbNames: list[str], 10 | caseNames: list[str], 11 | ): 12 | filterTasks = getFilterTasks(tasks, dbNames, caseNames) 13 | mergedTasks, failedTasks = mergeTasks(filterTasks) 14 | return mergedTasks, failedTasks 15 | 16 | 17 | def getFilterTasks( 18 | tasks: list[CaseResult], 19 | dbNames: list[str], 20 | caseNames: list[str], 21 | ) -> list[CaseResult]: 22 | filterTasks = [ 23 | task 24 | for task in tasks 25 | if task.task_config.db_name in dbNames 26 | and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames 27 | ] 28 | return filterTasks 29 | 30 | 31 | def mergeTasks(tasks: list[CaseResult]): 32 | dbCaseMetricsMap = defaultdict(lambda: defaultdict(dict)) 33 | for task in tasks: 34 | db_name = task.task_config.db_name 35 | db = task.task_config.db.value 36 | db_label = task.task_config.db_config.db_label or "" 37 | version = task.task_config.db_config.version or "" 38 | case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case) 39 | dbCaseMetricsMap[db_name][case.name] = { 40 | "db": db, 41 | "db_label": db_label, 42 | "version": version, 43 | "metrics": mergeMetrics( 44 | dbCaseMetricsMap[db_name][case.name].get("metrics", {}), 45 | asdict(task.metrics), 46 | ), 47 | "label": getBetterLabel( 48 | dbCaseMetricsMap[db_name][case.name].get("label", ResultLabel.FAILED), 49 | task.label, 50 | ), 51 | } 52 | 53 | mergedTasks = [] 54 | failedTasks = defaultdict(lambda: defaultdict(str)) 55 | for db_name, caseMetricsMap in dbCaseMetricsMap.items(): 56 | for case_name, metricInfo in caseMetricsMap.items(): 57 | metrics = metricInfo["metrics"] 58 | db = metricInfo["db"] 59 | db_label = metricInfo["db_label"] 60 | version = metricInfo["version"] 61 | label = 
metricInfo["label"] 62 | if label == ResultLabel.NORMAL: 63 | mergedTasks.append( 64 | { 65 | "db_name": db_name, 66 | "db": db, 67 | "db_label": db_label, 68 | "version": version, 69 | "case_name": case_name, 70 | "metricsSet": set(metrics.keys()), 71 | **metrics, 72 | } 73 | ) 74 | else: 75 | failedTasks[case_name][db_name] = label 76 | 77 | return mergedTasks, failedTasks 78 | 79 | 80 | def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict: 81 | metrics = {**metrics_1} 82 | for key, value in metrics_2.items(): 83 | metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value 84 | 85 | return metrics 86 | 87 | 88 | def getBetterMetric(metric, value_1, value_2): 89 | try: 90 | if value_1 < 1e-7: 91 | return value_2 92 | if value_2 < 1e-7: 93 | return value_1 94 | return min(value_1, value_2) if isLowerIsBetterMetric(metric) else max(value_1, value_2) 95 | except Exception: 96 | return value_1 97 | 98 | 99 | def getBetterLabel(label_1: ResultLabel, label_2: ResultLabel): 100 | return label_2 if label_1 != ResultLabel.NORMAL else label_1 101 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/check_results/expanderStyle.py: -------------------------------------------------------------------------------- 1 | def initMainExpanderStyle(st): 2 | st.markdown( 3 | """""", 11 | unsafe_allow_html=True, 12 | ) 13 | 14 | 15 | def initSidebarExanderStyle(st): 16 | st.markdown( 17 | """", 15 | unsafe_allow_html=True, 16 | ) 17 | 18 | st.header("Filters") 19 | 20 | shownResults = getshownResults(results, st) 21 | showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st) 22 | 23 | shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames) 24 | 25 | return shownData, failedTasks, showCaseNames 26 | 27 | 28 | def getshownResults(results: list[TestResult], st) -> list[CaseResult]: 29 | resultSelectOptions = [ 30 | result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results 31 | ] 32 | if len(resultSelectOptions) == 0: 33 | st.write("There are no results to display. 
Please wait for the task to complete or run a new task.") 34 | return [] 35 | 36 | selectedResultSelectedOptions = st.multiselect( 37 | "Select the task results you need to analyze.", 38 | resultSelectOptions, 39 | # label_visibility="hidden", 40 | default=resultSelectOptions, 41 | ) 42 | selectedResult: list[CaseResult] = [] 43 | for option in selectedResultSelectedOptions: 44 | result = results[resultSelectOptions.index(option)].results 45 | selectedResult += result 46 | 47 | return selectedResult 48 | 49 | 50 | def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[str]]: 51 | initSidebarExanderStyle(st) 52 | allDbNames = list(set({res.task_config.db_name for res in result})) 53 | allDbNames.sort() 54 | allCases: list[Case] = [ 55 | res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result 56 | ] 57 | allCaseNameSet = set({case.name for case in allCases}) 58 | allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [ 59 | case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER 60 | ] 61 | 62 | # DB Filter 63 | dbFilterContainer = st.container() 64 | showDBNames = filterView( 65 | dbFilterContainer, 66 | "DB Filter", 67 | allDbNames, 68 | col=1, 69 | ) 70 | 71 | # Case Filter 72 | caseFilterContainer = st.container() 73 | showCaseNames = filterView( 74 | caseFilterContainer, 75 | "Case Filter", 76 | [caseName for caseName in allCaseNames], 77 | col=1, 78 | ) 79 | 80 | return showDBNames, showCaseNames 81 | 82 | 83 | def filterView(container, header, options, col, optionLables=None): 84 | selectAllState = f"{header}-select-all-state" 85 | if selectAllState not in st.session_state: 86 | st.session_state[selectAllState] = True 87 | 88 | countKeyState = f"{header}-select-all-count-key" 89 | if countKeyState not in st.session_state: 90 | st.session_state[countKeyState] = 0 91 | 92 | expander = container.expander(header, True) 93 | selectAllColumns = expander.columns(SIDEBAR_CONTROL_COLUMNS, gap="small") 94 | selectAllButton = selectAllColumns[SIDEBAR_CONTROL_COLUMNS - 2].button( 95 | "select all", 96 | key=f"{header}-select-all-button", 97 | # type="primary", 98 | ) 99 | clearAllButton = selectAllColumns[SIDEBAR_CONTROL_COLUMNS - 1].button( 100 | "clear all", 101 | key=f"{header}-clear-all-button", 102 | # type="primary", 103 | ) 104 | if selectAllButton: 105 | st.session_state[selectAllState] = True 106 | st.session_state[countKeyState] += 1 107 | if clearAllButton: 108 | st.session_state[selectAllState] = False 109 | st.session_state[countKeyState] += 1 110 | columns = expander.columns( 111 | col, 112 | gap="small", 113 | ) 114 | if optionLables is None: 115 | optionLables = options 116 | isActive = {option: st.session_state[selectAllState] for option in optionLables} 117 | for i, option in enumerate(optionLables): 118 | isActive[option] = columns[i % col].checkbox( 119 | optionLables[i], 120 | value=isActive[option], 121 | key=f"{optionLables[i]}-{st.session_state[countKeyState]}", 122 | ) 123 | 124 | return [options[i] for i, option in enumerate(optionLables) if isActive[option]] 125 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/check_results/footer.py: -------------------------------------------------------------------------------- 1 | def footer(st): 2 | text = "* All test results are from community contributors. 
If there is any ambiguity, feel free to raise an issue or make amendments on our GitHub page." 3 | st.markdown( 4 | f""" 5 |
{text}
8 | 9 | ") 8 | if navClick: 9 | switch_page("run test") 10 | 11 | 12 | def NavToQuriesPerDollar(st): 13 | st.subheader("Compare qps with price.") 14 | navClick = st.button("QP$ (Quries per Dollar)   >") 15 | if navClick: 16 | switch_page("quries_per_dollar") 17 | 18 | 19 | def NavToResults(st, key="nav-to-results"): 20 | navClick = st.button("<   Back to Results", key=key) 21 | if navClick: 22 | switch_page("vdb benchmark") 23 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/check_results/priceTable.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.backend.clients import DB 2 | import pandas as pd 3 | from collections import defaultdict 4 | import streamlit as st 5 | 6 | from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE 7 | 8 | 9 | def priceTable(container, data): 10 | dbAndLabelSet = {(d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value} 11 | 12 | dbAndLabelList = list(dbAndLabelSet) 13 | dbAndLabelList.sort() 14 | 15 | table = pd.DataFrame( 16 | [ 17 | { 18 | "DB": db, 19 | "Label": db_label, 20 | "Price per hour": DB_DBLABEL_TO_PRICE.get(db, {}).get(db_label, 0), 21 | } 22 | for db, db_label in dbAndLabelList 23 | ] 24 | ) 25 | height = len(table) * 35 + 38 26 | 27 | expander = container.expander("Price List (Editable).") 28 | editTable = expander.data_editor( 29 | table, 30 | use_container_width=True, 31 | hide_index=True, 32 | height=height, 33 | disabled=("DB", "Label"), 34 | column_config={ 35 | "Price per hour": st.column_config.NumberColumn( 36 | min_value=0, 37 | format="$ %f", 38 | ) 39 | }, 40 | ) 41 | 42 | priceMap = defaultdict(lambda: defaultdict(float)) 43 | for _, row in editTable.iterrows(): 44 | db, db_label, price = row 45 | priceMap[db][db_label] = price 46 | 47 | return priceMap 48 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/check_results/stPageConfig.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.frontend.config.styles import * 2 | 3 | 4 | def initResultsPageConfig(st): 5 | st.set_page_config( 6 | page_title=PAGE_TITLE, 7 | page_icon=FAVICON, 8 | # layout="wide", 9 | # initial_sidebar_state="collapsed", 10 | ) 11 | 12 | 13 | def initRunTestPageConfig(st): 14 | st.set_page_config( 15 | page_title=PAGE_TITLE, 16 | page_icon=FAVICON, 17 | # layout="wide", 18 | initial_sidebar_state="collapsed", 19 | ) 20 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/concurrent/charts.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.frontend.components.check_results.expanderStyle import ( 2 | initMainExpanderStyle, 3 | ) 4 | import plotly.express as px 5 | 6 | from vectordb_bench.frontend.config.styles import COLOR_MAP 7 | 8 | 9 | def drawChartsByCase(allData, showCaseNames: list[str], st, latency_type: str): 10 | initMainExpanderStyle(st) 11 | for caseName in showCaseNames: 12 | chartContainer = st.expander(caseName, True) 13 | caseDataList = [data for data in allData if data["case_name"] == caseName] 14 | data = [ 15 | { 16 | "conc_num": caseData["conc_num_list"][i], 17 | "qps": (caseData["conc_qps_list"][i] if 0 <= i < len(caseData["conc_qps_list"]) else 0), 18 | "latency_p99": ( 19 | caseData["conc_latency_p99_list"][i] * 1000 20 | if 0 <= i < 
len(caseData["conc_latency_p99_list"]) 21 | else 0 22 | ), 23 | "latency_avg": ( 24 | caseData["conc_latency_avg_list"][i] * 1000 25 | if 0 <= i < len(caseData["conc_latency_avg_list"]) 26 | else 0 27 | ), 28 | "db_name": caseData["db_name"], 29 | "db": caseData["db"], 30 | } 31 | for caseData in caseDataList 32 | for i in range(len(caseData["conc_num_list"])) 33 | ] 34 | drawChart(data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type) 35 | 36 | 37 | def getRange(metric, data, padding_multipliers): 38 | minV = min([d.get(metric, 0) for d in data]) 39 | maxV = max([d.get(metric, 0) for d in data]) 40 | padding = maxV - minV 41 | rangeV = [ 42 | minV - padding * padding_multipliers[0], 43 | maxV + padding * padding_multipliers[1], 44 | ] 45 | return rangeV 46 | 47 | 48 | def gen_title(s: str) -> str: 49 | if "latency" in s: 50 | return f'{s.replace("_", " ").title()} (ms)' 51 | else: 52 | return s.upper() 53 | 54 | 55 | def drawChart(data, st, key: str, x_metric: str = "latency_p99", y_metric: str = "qps"): 56 | if len(data) == 0: 57 | return 58 | 59 | x = x_metric 60 | xrange = getRange(x, data, [0.05, 0.1]) 61 | 62 | y = y_metric 63 | yrange = getRange(y, data, [0.2, 0.1]) 64 | 65 | color = "db" 66 | color_discrete_map = COLOR_MAP 67 | color = "db_name" 68 | color_discrete_map = None 69 | line_group = "db_name" 70 | text = "conc_num" 71 | 72 | data.sort(key=lambda a: a["conc_num"]) 73 | 74 | fig = px.line( 75 | data, 76 | x=x, 77 | y=y, 78 | color=color, 79 | color_discrete_map=color_discrete_map, 80 | line_group=line_group, 81 | text=text, 82 | markers=True, 83 | hover_data={ 84 | "conc_num": True, 85 | }, 86 | height=720, 87 | ) 88 | fig.update_xaxes(range=xrange, title_text=gen_title(x_metric)) 89 | fig.update_yaxes(range=yrange, title_text=gen_title(y_metric)) 90 | fig.update_traces(textposition="bottom right", texttemplate="conc-%{text:,.4~r}") 91 | 92 | st.plotly_chart(fig, use_container_width=True, key=key) 93 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/custom/displayCustomCase.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig 2 | 3 | 4 | def displayCustomCase(customCase: CustomCaseConfig, st, key): 5 | 6 | columns = st.columns([1, 2]) 7 | customCase.dataset_config.name = columns[0].text_input( 8 | "Name", key=f"{key}_name", value=customCase.dataset_config.name 9 | ) 10 | customCase.name = f"{customCase.dataset_config.name} (Performace Case)" 11 | customCase.dataset_config.dir = columns[1].text_input( 12 | "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir 13 | ) 14 | 15 | columns = st.columns(4) 16 | customCase.dataset_config.dim = columns[0].number_input( 17 | "dim", key=f"{key}_dim", value=customCase.dataset_config.dim 18 | ) 19 | customCase.dataset_config.size = columns[1].number_input( 20 | "size", key=f"{key}_size", value=customCase.dataset_config.size 21 | ) 22 | customCase.dataset_config.metric_type = columns[2].selectbox( 23 | "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"] 24 | ) 25 | customCase.dataset_config.file_count = columns[3].number_input( 26 | "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count 27 | ) 28 | 29 | columns = st.columns(4) 30 | customCase.dataset_config.use_shuffled = columns[0].checkbox( 31 | "use shuffled data", key=f"{key}_use_shuffled", 
value=customCase.dataset_config.use_shuffled 32 | ) 33 | customCase.dataset_config.with_gt = columns[1].checkbox( 34 | "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt 35 | ) 36 | 37 | customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description) 38 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/custom/displaypPrams.py: -------------------------------------------------------------------------------- 1 | def displayParams(st): 2 | st.markdown( 3 | """ 4 | - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. 5 | - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. 6 | - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. 7 | - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. 8 | 9 | - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. 10 | 11 | - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order. 12 | """ 13 | ) 14 | st.caption( 15 | """We recommend limiting the number of test query vectors, like 1,000.""", 16 | help=""" 17 | When conducting concurrent query tests, Vdbbench creates a large number of processes. 18 | To minimize additional communication overhead during testing, 19 | we prepare a complete set of test queries for each process, allowing them to run independently.\n 20 | However, this means that as the number of concurrent processes increases, 21 | the number of copied query vectors also increases significantly, 22 | which can place substantial pressure on memory resources. 
23 | """, 24 | ) 25 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/custom/getCustomConfig.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from pydantic import BaseModel 4 | 5 | from vectordb_bench import config 6 | 7 | 8 | class CustomDatasetConfig(BaseModel): 9 | name: str = "custom_dataset" 10 | dir: str = "" 11 | size: int = 0 12 | dim: int = 0 13 | metric_type: str = "L2" 14 | file_count: int = 1 15 | use_shuffled: bool = False 16 | with_gt: bool = True 17 | 18 | 19 | class CustomCaseConfig(BaseModel): 20 | name: str = "custom_dataset (Performace Case)" 21 | description: str = "" 22 | load_timeout: int = 36000 23 | optimize_timeout: int = 36000 24 | dataset_config: CustomDatasetConfig = CustomDatasetConfig() 25 | 26 | 27 | def get_custom_configs(): 28 | with open(config.CUSTOM_CONFIG_DIR, "r") as f: 29 | custom_configs = json.load(f) 30 | return [CustomCaseConfig(**custom_config) for custom_config in custom_configs] 31 | 32 | 33 | def save_custom_configs(custom_configs: list[CustomDatasetConfig]): 34 | with open(config.CUSTOM_CONFIG_DIR, "w") as f: 35 | json.dump([custom_config.dict() for custom_config in custom_configs], f, indent=4) 36 | 37 | 38 | def generate_custom_case(): 39 | return CustomCaseConfig() 40 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/custom/initStyle.py: -------------------------------------------------------------------------------- 1 | def initStyle(st): 2 | st.markdown( 3 | """""", 14 | unsafe_allow_html=True, 15 | ) 16 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/get_results/saveAsImage.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import streamlit as st 3 | import streamlit.components.v1 as components 4 | 5 | HTML_2_CANVAS_URL = "https://unpkg.com/html2canvas@1.4.1/dist/html2canvas.js" 6 | 7 | 8 | @st.cache_data 9 | def load_unpkg(src: str) -> str: 10 | return requests.get(src).text 11 | 12 | 13 | def getResults(container, pageName="vectordb_bench"): 14 | container.subheader("Get results") 15 | saveAsImage(container, pageName) 16 | 17 | 18 | def saveAsImage(container, pageName): 19 | html2canvasJS = load_unpkg(HTML_2_CANVAS_URL) 20 | container.write() 21 | buttonText = "Save as Image" 22 | savePDFButton = container.button(buttonText) 23 | if savePDFButton: 24 | components.html( 25 | f""" 26 | 27 | 28 | """, 50 | height=0, 51 | width=0, 52 | ) 53 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/autoRefresh.py: -------------------------------------------------------------------------------- 1 | from streamlit_autorefresh import st_autorefresh 2 | from vectordb_bench.frontend.config.styles import * 3 | 4 | 5 | def autoRefresh(): 6 | auto_refresh_count = st_autorefresh( 7 | interval=MAX_AUTO_REFRESH_INTERVAL, 8 | limit=MAX_AUTO_REFRESH_COUNT, 9 | key="streamlit-auto-refresh", 10 | ) 11 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/dbConfigSetting.py: -------------------------------------------------------------------------------- 1 | from pydantic import ValidationError 2 | from vectordb_bench.backend.clients import DB 3 | from vectordb_bench.frontend.config.styles 
import DB_CONFIG_SETTING_COLUMNS 4 | from vectordb_bench.frontend.utils import inputIsPassword 5 | 6 | 7 | def dbConfigSettings(st, activedDbList: list[DB]): 8 | expander = st.expander("Configurations for the selected databases", True) 9 | 10 | dbConfigs = {} 11 | isAllValid = True 12 | for activeDb in activedDbList: 13 | dbConfigSettingItemContainer = expander.container() 14 | dbConfig = dbConfigSettingItem(dbConfigSettingItemContainer, activeDb) 15 | try: 16 | dbConfigs[activeDb] = activeDb.config_cls(**dbConfig) 17 | except ValidationError as e: 18 | isAllValid = False 19 | errTexts = [] 20 | for err in e.raw_errors: 21 | errLocs = err.loc_tuple() 22 | errInfo = err.exc 23 | errText = f"{', '.join(errLocs)} - {errInfo}" 24 | errTexts.append(errText) 25 | 26 | dbConfigSettingItemContainer.error(f"{'; '.join(errTexts)}") 27 | 28 | return dbConfigs, isAllValid 29 | 30 | 31 | def dbConfigSettingItem(st, activeDb: DB): 32 | st.markdown( 33 | f"
{activeDb.value}
", 34 | unsafe_allow_html=True, 35 | ) 36 | columns = st.columns(DB_CONFIG_SETTING_COLUMNS) 37 | 38 | dbConfigClass = activeDb.config_cls 39 | schema = dbConfigClass.schema() 40 | property_items = schema.get("properties").items() 41 | required_fields = set(schema.get("required", [])) 42 | dbConfig = {} 43 | idx = 0 44 | 45 | # db config (unique) 46 | for key, property in property_items: 47 | if key not in dbConfigClass.common_short_configs() and key not in dbConfigClass.common_long_configs(): 48 | column = columns[idx % DB_CONFIG_SETTING_COLUMNS] 49 | idx += 1 50 | input_value = column.text_input( 51 | key, 52 | key=f"{activeDb.name}-{key}", 53 | value=property.get("default", ""), 54 | type="password" if inputIsPassword(key) else "default", 55 | placeholder="optional" if key not in required_fields else None, 56 | ) 57 | if key in required_fields or input_value: 58 | dbConfig[key] = input_value 59 | 60 | # db config (common short labels) 61 | for key in dbConfigClass.common_short_configs(): 62 | column = columns[idx % DB_CONFIG_SETTING_COLUMNS] 63 | idx += 1 64 | dbConfig[key] = column.text_input( 65 | key, 66 | key="%s-%s" % (activeDb.name, key), 67 | value="", 68 | type="default", 69 | placeholder="optional, for labeling results", 70 | ) 71 | 72 | # db config (common long text_input) 73 | for key in dbConfigClass.common_long_configs(): 74 | dbConfig[key] = st.text_area( 75 | key, 76 | key="%s-%s" % (activeDb.name, key), 77 | value="", 78 | placeholder="optional", 79 | ) 80 | return dbConfig 81 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/dbSelector.py: -------------------------------------------------------------------------------- 1 | from streamlit.runtime.media_file_storage import MediaFileStorageError 2 | from vectordb_bench.frontend.config.styles import DB_SELECTOR_COLUMNS, DB_TO_ICON 3 | from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST 4 | 5 | 6 | def dbSelector(st): 7 | st.markdown( 8 | "
", 9 | unsafe_allow_html=True, 10 | ) 11 | st.subheader("STEP 1: Select the database(s)") 12 | st.markdown( 13 | "
Choose at least one database you want to run the test for.
", 14 | unsafe_allow_html=True, 15 | ) 16 | 17 | dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small") 18 | dbIsActived = {db: False for db in DB_LIST} 19 | 20 | for i, db in enumerate(DB_LIST): 21 | column = dbContainerColumns[i % DB_SELECTOR_COLUMNS] 22 | dbIsActived[db] = column.checkbox(db.name) 23 | try: 24 | column.image(DB_TO_ICON.get(db, "")) 25 | except MediaFileStorageError: 26 | column.warning(f"{db.name} image not available") 27 | pass 28 | activedDbList = [db for db in DB_LIST if dbIsActived[db]] 29 | 30 | return activedDbList 31 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/generateTasks.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.backend.clients import DB 2 | from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig 3 | 4 | 5 | def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs): 6 | tasks = [] 7 | for db in activedDbList: 8 | for case in activedCaseList: 9 | task = TaskConfig( 10 | db=db.value, 11 | db_config=dbConfigs[db], 12 | case_config=case, 13 | db_case_config=db.case_config_cls(allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None))( 14 | **{key.value: value for key, value in allCaseConfigs[db][case].items()} 15 | ), 16 | ) 17 | tasks.append(task) 18 | 19 | return tasks 20 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/hideSidebar.py: -------------------------------------------------------------------------------- 1 | def hideSidebar(st): 2 | st.markdown( 3 | """""", 7 | unsafe_allow_html=True, 8 | ) 9 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/initStyle.py: -------------------------------------------------------------------------------- 1 | def initStyle(st): 2 | st.markdown( 3 | """""", 15 | unsafe_allow_html=True, 16 | ) 17 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/run_test/submitTask.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from vectordb_bench import config 3 | from vectordb_bench.frontend.config import styles 4 | from vectordb_bench.interface import benchmark_runner 5 | from vectordb_bench.models import TaskConfig 6 | 7 | 8 | def submitTask(st, tasks, isAllValid): 9 | st.markdown( 10 | "
", 11 | unsafe_allow_html=True, 12 | ) 13 | st.subheader("STEP 3: Task Label") 14 | st.markdown( 15 | "
This label is used to identify the test results.
", 16 | unsafe_allow_html=True, 17 | ) 18 | 19 | taskLabel = taskLabelInput(st) 20 | 21 | st.markdown( 22 | "
", 23 | unsafe_allow_html=True, 24 | ) 25 | 26 | controlPanelContainer = st.container() 27 | controlPanel(controlPanelContainer, tasks, taskLabel, isAllValid) 28 | 29 | 30 | def taskLabelInput(st): 31 | defaultTaskLabel = datetime.now().strftime("%Y%m%d%H") 32 | columns = st.columns(styles.TASK_LABEL_INPUT_COLUMNS) 33 | taskLabel = columns[0].text_input("task_label", defaultTaskLabel, label_visibility="collapsed") 34 | return taskLabel 35 | 36 | 37 | def advancedSettings(st): 38 | container = st.columns([1, 2]) 39 | index_already_exists = container[0].checkbox("Index already exists", value=False) 40 | container[1].caption("if selected, inserting and building will be skipped.") 41 | 42 | container = st.columns([1, 2]) 43 | use_aliyun = container[0].checkbox("Dataset from Aliyun (Shanghai)", value=False) 44 | container[1].caption( 45 | "if selected, the dataset will be downloaded from Aliyun OSS shanghai, default AWS S3 aws-us-west." 46 | ) 47 | 48 | container = st.columns([1, 2]) 49 | k = container[0].number_input("k", min_value=1, value=100, label_visibility="collapsed") 50 | container[1].caption("K value for number of nearest neighbors to search") 51 | 52 | container = st.columns([1, 2]) 53 | defaultconcurrentInput = ",".join(map(str, config.NUM_CONCURRENCY)) 54 | concurrentInput = container[0].text_input( 55 | "Concurrent Input", value=defaultconcurrentInput, label_visibility="collapsed" 56 | ) 57 | container[1].caption("num of concurrencies for search tests to get max-qps") 58 | return index_already_exists, use_aliyun, k, concurrentInput 59 | 60 | 61 | def controlPanel(st, tasks: list[TaskConfig], taskLabel, isAllValid): 62 | index_already_exists, use_aliyun, k, concurrentInput = advancedSettings(st) 63 | 64 | def runHandler(): 65 | benchmark_runner.set_drop_old(not index_already_exists) 66 | 67 | try: 68 | concurrentInput_list = [int(item.strip()) for item in concurrentInput.split(",")] 69 | except ValueError: 70 | st.write("please input correct number") 71 | return None 72 | 73 | for task in tasks: 74 | task.case_config.k = k 75 | task.case_config.concurrency_search_config.num_concurrency = concurrentInput_list 76 | 77 | benchmark_runner.set_download_address(use_aliyun) 78 | benchmark_runner.run(tasks, taskLabel) 79 | 80 | def stopHandler(): 81 | benchmark_runner.stop_running() 82 | 83 | isRunning = benchmark_runner.has_running() 84 | 85 | if isRunning: 86 | currentTaskId = benchmark_runner.get_current_task_id() 87 | tasksCount = benchmark_runner.get_tasks_count() 88 | text = f":running: Running Task {currentTaskId} / {tasksCount}" 89 | st.progress(currentTaskId / tasksCount, text=text) 90 | 91 | columns = st.columns(6) 92 | columns[0].button( 93 | "Run Your Test", 94 | disabled=True, 95 | on_click=runHandler, 96 | type="primary", 97 | ) 98 | columns[1].button( 99 | "Stop", 100 | on_click=stopHandler, 101 | type="primary", 102 | ) 103 | 104 | else: 105 | errorText = benchmark_runner.latest_error or "" 106 | if len(errorText) > 0: 107 | st.error(errorText) 108 | disabled = True if len(tasks) == 0 or not isAllValid else False 109 | if not isAllValid: 110 | st.error("Make sure all config is valid.") 111 | elif len(tasks) == 0: 112 | st.warning("No tests to run.") 113 | st.button( 114 | "Run Your Test", 115 | disabled=disabled, 116 | on_click=runHandler, 117 | type="primary", 118 | ) 119 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/components/tables/data.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from vectordb_bench.interface import benchmark_runner 3 | from vectordb_bench.models import CaseResult, ResultLabel 4 | import pandas as pd 5 | 6 | 7 | def getNewResults(): 8 | allResults = benchmark_runner.get_results() 9 | newResults: list[CaseResult] = [] 10 | 11 | for res in allResults: 12 | results = res.results 13 | for result in results: 14 | if result.label == ResultLabel.NORMAL: 15 | newResults.append(result) 16 | 17 | df = pd.DataFrame(formatData(newResults)) 18 | return df 19 | 20 | 21 | def formatData(caseResults: list[CaseResult]): 22 | data = [] 23 | for caseResult in caseResults: 24 | db = caseResult.task_config.db.value 25 | db_label = caseResult.task_config.db_config.db_label 26 | case_config = caseResult.task_config.case_config 27 | case = case_config.case_id.case_cls() 28 | filter_rate = case.filter_rate 29 | dataset = case.dataset.data.name 30 | metrics = asdict(caseResult.metrics) 31 | data.append( 32 | { 33 | "db": db, 34 | "db_label": db_label, 35 | "case_name": case.name, 36 | "dataset": dataset, 37 | "filter_rate": filter_rate, 38 | **metrics, 39 | } 40 | ) 41 | return data 42 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/config/dbPrices.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench import config 2 | import ujson 3 | import pathlib 4 | 5 | with open(pathlib.Path(config.RESULTS_LOCAL_DIR, "dbPrices.json")) as f: 6 | DB_DBLABEL_TO_PRICE = ujson.load(f) 7 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/config/styles.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench.models import DB 2 | 3 | # style const 4 | DB_SELECTOR_COLUMNS = 6 5 | DB_CONFIG_SETTING_COLUMNS = 3 6 | CASE_CONFIG_SETTING_COLUMNS = 4 7 | CHECKBOX_INDENT = 30 8 | TASK_LABEL_INPUT_COLUMNS = 2 9 | CHECKBOX_MAX_COLUMNS = 4 10 | DB_CONFIG_INPUT_MAX_COLUMNS = 2 11 | CASE_CONFIG_INPUT_MAX_COLUMNS = 3 12 | DB_CONFIG_INPUT_WIDTH_RADIO = 2 13 | CASE_CONFIG_INPUT_WIDTH_RADIO = 0.98 14 | CASE_INTRO_RATIO = 3 15 | SIDEBAR_CONTROL_COLUMNS = 3 16 | LEGEND_RECT_WIDTH = 24 17 | LEGEND_RECT_HEIGHT = 16 18 | LEGEND_TEXT_FONT_SIZE = 14 19 | 20 | PATTERN_SHAPES = ["", "+", "\\", "x", ".", "|", "/", "-"] 21 | 22 | 23 | def getPatternShape(i): 24 | return PATTERN_SHAPES[i % len(PATTERN_SHAPES)] 25 | 26 | 27 | # run_test page auto-refresh config 28 | MAX_AUTO_REFRESH_COUNT = 999999 29 | MAX_AUTO_REFRESH_INTERVAL = 5000 # 5s 30 | 31 | PAGE_TITLE = "VectorDB Benchmark" 32 | FAVICON = "https://assets.zilliz.com/favicon_f7f922fe27.png" 33 | HEADER_ICON = "https://assets.zilliz.com/vdb_benchmark_db790b5387.png" 34 | 35 | # RedisCloud icon: https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png 36 | # Elasticsearch icon: https://assets.zilliz.com/elasticsearch_beffeadc29.png 37 | # Chroma icon: https://assets.zilliz.com/chroma_ceb3f06ed7.png 38 | DB_TO_ICON = { 39 | DB.Milvus: "https://assets.zilliz.com/milvus_c30b0d1994.png", 40 | DB.ZillizCloud: "https://assets.zilliz.com/zilliz_5f4cc9b050.png", 41 | DB.ElasticCloud: "https://assets.zilliz.com/Elatic_Cloud_dad8d6a3a3.png", 42 | DB.Pinecone: "https://assets.zilliz.com/pinecone_94d8154979.png", 43 | DB.QdrantCloud: "https://assets.zilliz.com/qdrant_b691674fcd.png", 44 | DB.WeaviateCloud: 
"https://assets.zilliz.com/weaviate_4f6f171ebe.png", 45 | DB.PgVector: "https://assets.zilliz.com/PG_Vector_d464f2ef5f.png", 46 | DB.PgVectoRS: "https://assets.zilliz.com/PG_Vector_d464f2ef5f.png", 47 | DB.Redis: "https://assets.zilliz.com/Redis_Cloud_74b8bfef39.png", 48 | DB.Chroma: "https://assets.zilliz.com/chroma_ceb3f06ed7.png", 49 | DB.AWSOpenSearch: "https://assets.zilliz.com/opensearch_1eee37584e.jpeg", 50 | DB.TiDB: "https://img2.pingcap.com/forms/3/d/3d7fd5f9767323d6f037795704211ac44b4923d6.png", 51 | DB.Vespa: "https://vespa.ai/vespa-content/uploads/2025/01/Vespa-symbol-green-rgb.png.webp", 52 | DB.LanceDB: "https://raw.githubusercontent.com/lancedb/lancedb/main/docs/src/assets/logo.png", 53 | } 54 | 55 | # RedisCloud color: #0D6EFD 56 | # Chroma color: #FFC107 57 | COLOR_MAP = { 58 | DB.Milvus.value: "#0DCAF0", 59 | DB.ZillizCloud.value: "#0D6EFD", 60 | DB.ElasticCloud.value: "#04D6C8", 61 | DB.Pinecone.value: "#6610F2", 62 | DB.QdrantCloud.value: "#D91AD9", 63 | DB.WeaviateCloud.value: "#20C997", 64 | DB.PgVector.value: "#4C779A", 65 | DB.Redis.value: "#0D6EFD", 66 | DB.AWSOpenSearch.value: "#0DCAF0", 67 | DB.TiDB.value: "#0D6EFD", 68 | DB.Vespa.value: "#61d790", 69 | } 70 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/pages/concurrent.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from vectordb_bench.frontend.components.check_results.footer import footer 3 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 4 | from vectordb_bench.frontend.components.check_results.nav import ( 5 | NavToResults, 6 | NavToRunTest, 7 | ) 8 | from vectordb_bench.frontend.components.check_results.filters import getshownData 9 | from vectordb_bench.frontend.components.concurrent.charts import drawChartsByCase 10 | from vectordb_bench.frontend.components.get_results.saveAsImage import getResults 11 | from vectordb_bench.frontend.config.styles import FAVICON 12 | from vectordb_bench.interface import benchmark_runner 13 | from vectordb_bench.models import TestResult 14 | 15 | 16 | def main(): 17 | # set page config 18 | st.set_page_config( 19 | page_title="VDBBench Conc Perf", 20 | page_icon=FAVICON, 21 | layout="wide", 22 | # initial_sidebar_state="collapsed", 23 | ) 24 | 25 | # header 26 | drawHeaderIcon(st) 27 | 28 | allResults = benchmark_runner.get_results() 29 | 30 | def check_conc_data(res: TestResult): 31 | case_results = res.results 32 | count = 0 33 | for case_result in case_results: 34 | if len(case_result.metrics.conc_num_list) > 0: 35 | count += 1 36 | 37 | return count > 0 38 | 39 | checkedResults = [res for res in allResults if check_conc_data(res)] 40 | 41 | st.title("VectorDB Benchmark (Concurrent Performance)") 42 | 43 | # results selector 44 | resultSelectorContainer = st.sidebar.container() 45 | shownData, _, showCaseNames = getshownData(checkedResults, resultSelectorContainer) 46 | 47 | resultSelectorContainer.divider() 48 | 49 | # nav 50 | navContainer = st.sidebar.container() 51 | NavToRunTest(navContainer) 52 | NavToResults(navContainer) 53 | 54 | # save or share 55 | resultesContainer = st.sidebar.container() 56 | getResults(resultesContainer, "vectordb_bench_concurrent") 57 | 58 | # main 59 | latency_type = st.radio("Latency Type", options=["latency_p99", "latency_avg"]) 60 | drawChartsByCase(shownData, showCaseNames, st.container(), latency_type=latency_type) 61 | 62 | # footer 63 | footer(st.container()) 64 | 
65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/pages/custom.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import streamlit as st 3 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 4 | from vectordb_bench.frontend.components.custom.displayCustomCase import ( 5 | displayCustomCase, 6 | ) 7 | from vectordb_bench.frontend.components.custom.displaypPrams import displayParams 8 | from vectordb_bench.frontend.components.custom.getCustomConfig import ( 9 | CustomCaseConfig, 10 | generate_custom_case, 11 | get_custom_configs, 12 | save_custom_configs, 13 | ) 14 | from vectordb_bench.frontend.components.custom.initStyle import initStyle 15 | from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE 16 | 17 | 18 | class CustomCaseManager: 19 | customCaseItems: list[CustomCaseConfig] 20 | 21 | def __init__(self): 22 | self.customCaseItems = get_custom_configs() 23 | 24 | def addCase(self): 25 | new_custom_case = generate_custom_case() 26 | new_custom_case.dataset_config.name = f"{new_custom_case.dataset_config.name} {len(self.customCaseItems)}" 27 | self.customCaseItems += [new_custom_case] 28 | self.save() 29 | 30 | def deleteCase(self, idx: int): 31 | self.customCaseItems.pop(idx) 32 | self.save() 33 | 34 | def save(self): 35 | save_custom_configs(self.customCaseItems) 36 | 37 | 38 | def main(): 39 | st.set_page_config( 40 | page_title=PAGE_TITLE, 41 | page_icon=FAVICON, 42 | # layout="wide", 43 | # initial_sidebar_state="collapsed", 44 | ) 45 | 46 | # header 47 | drawHeaderIcon(st) 48 | 49 | # init style 50 | initStyle(st) 51 | 52 | st.title("Custom Dataset") 53 | displayParams(st) 54 | customCaseManager = CustomCaseManager() 55 | 56 | for idx, customCase in enumerate(customCaseManager.customCaseItems): 57 | expander = st.expander(customCase.dataset_config.name, expanded=True) 58 | key = f"custom_case_{idx}" 59 | displayCustomCase(customCase, expander, key=key) 60 | 61 | columns = expander.columns(8) 62 | columns[0].button( 63 | "Save", 64 | key=f"{key}_", 65 | type="secondary", 66 | on_click=lambda: customCaseManager.save(), 67 | ) 68 | columns[1].button( 69 | ":red[Delete]", 70 | key=f"{key}_delete", 71 | type="secondary", 72 | # B023 73 | on_click=partial(lambda idx: customCaseManager.deleteCase(idx), idx=idx), 74 | ) 75 | 76 | st.button( 77 | "\+ New Dataset", 78 | key="add_custom_configs", 79 | type="primary", 80 | on_click=lambda: customCaseManager.addCase(), 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/pages/quries_per_dollar.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from vectordb_bench.frontend.components.check_results.footer import footer 3 | from vectordb_bench.frontend.components.check_results.expanderStyle import ( 4 | initMainExpanderStyle, 5 | ) 6 | from vectordb_bench.frontend.components.check_results.priceTable import priceTable 7 | from vectordb_bench.frontend.components.check_results.stPageConfig import ( 8 | initResultsPageConfig, 9 | ) 10 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 11 | from vectordb_bench.frontend.components.check_results.nav import ( 12 | NavToResults, 13 | NavToRunTest, 14 | ) 15 
| from vectordb_bench.frontend.components.check_results.charts import drawMetricChart 16 | from vectordb_bench.frontend.components.check_results.filters import getshownData 17 | from vectordb_bench.frontend.components.get_results.saveAsImage import getResults 18 | 19 | from vectordb_bench.interface import benchmark_runner 20 | from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC 21 | 22 | 23 | def main(): 24 | # set page config 25 | initResultsPageConfig(st) 26 | 27 | # header 28 | drawHeaderIcon(st) 29 | 30 | allResults = benchmark_runner.get_results() 31 | 32 | st.title("Vector DB Benchmark (QP$)") 33 | 34 | # results selector 35 | resultSelectorContainer = st.sidebar.container() 36 | shownData, _, showCaseNames = getshownData(allResults, resultSelectorContainer) 37 | 38 | resultSelectorContainer.divider() 39 | 40 | # nav 41 | navContainer = st.sidebar.container() 42 | NavToRunTest(navContainer) 43 | NavToResults(navContainer) 44 | 45 | # save or share 46 | resultesContainer = st.sidebar.container() 47 | getResults(resultesContainer, "vectordb_bench_qp$") 48 | 49 | # price table 50 | initMainExpanderStyle(st) 51 | priceTableContainer = st.container() 52 | priceMap = priceTable(priceTableContainer, shownData) 53 | 54 | # charts 55 | for caseName in showCaseNames: 56 | data = [data for data in shownData if data["case_name"] == caseName] 57 | dataWithMetric = [] 58 | metric = QURIES_PER_DOLLAR_METRIC 59 | for d in data: 60 | qps = d.get("qps", 0) 61 | price = priceMap.get(d["db"], {}).get(d["db_label"], 0) 62 | if qps > 0 and price > 0: 63 | d[metric] = d["qps"] / price * 3.6 64 | dataWithMetric.append(d) 65 | if len(dataWithMetric) > 0: 66 | chartContainer = st.expander(caseName, True) 67 | key = f"{caseName}-{metric}" 68 | drawMetricChart(data, metric, chartContainer, key=key) 69 | 70 | # footer 71 | footer(st.container()) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/pages/run_test.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from vectordb_bench.frontend.components.run_test.autoRefresh import autoRefresh 3 | from vectordb_bench.frontend.components.run_test.caseSelector import caseSelector 4 | from vectordb_bench.frontend.components.run_test.dbConfigSetting import dbConfigSettings 5 | from vectordb_bench.frontend.components.run_test.dbSelector import dbSelector 6 | from vectordb_bench.frontend.components.run_test.generateTasks import generate_tasks 7 | from vectordb_bench.frontend.components.run_test.hideSidebar import hideSidebar 8 | from vectordb_bench.frontend.components.run_test.initStyle import initStyle 9 | from vectordb_bench.frontend.components.run_test.submitTask import submitTask 10 | from vectordb_bench.frontend.components.check_results.nav import NavToResults 11 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 12 | from vectordb_bench.frontend.components.check_results.stPageConfig import initRunTestPageConfig 13 | 14 | 15 | def main(): 16 | # set page config 17 | initRunTestPageConfig(st) 18 | 19 | # init style 20 | initStyle(st) 21 | 22 | # header 23 | drawHeaderIcon(st) 24 | 25 | # hide sidebar 26 | hideSidebar(st) 27 | 28 | # nav to results 29 | NavToResults(st) 30 | 31 | # header 32 | st.title("Run Your Test") 33 | # st.write("description [todo]") 34 | 35 | # select db 36 | dbSelectorContainer = st.container() 37 | activedDbList = 
dbSelector(dbSelectorContainer) 38 | 39 | # db config setting 40 | dbConfigs = {} 41 | isAllValid = True 42 | if len(activedDbList) > 0: 43 | dbConfigContainer = st.container() 44 | dbConfigs, isAllValid = dbConfigSettings(dbConfigContainer, activedDbList) 45 | 46 | # select case and set db_case_config 47 | caseSelectorContainer = st.container() 48 | activedCaseList, allCaseConfigs = caseSelector(caseSelectorContainer, activedDbList) 49 | 50 | # generate tasks 51 | tasks = generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs) if isAllValid else [] 52 | 53 | # submit 54 | submitContainer = st.container() 55 | submitTask(submitContainer, tasks, isAllValid) 56 | 57 | # nav to results 58 | NavToResults(st, key="footer-nav-to-results") 59 | 60 | # autofresh 61 | autoRefresh() 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/pages/tables.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 3 | from vectordb_bench.frontend.components.tables.data import getNewResults 4 | from vectordb_bench.frontend.config.styles import FAVICON 5 | 6 | 7 | def main(): 8 | # set page config 9 | st.set_page_config( 10 | page_title="Table", 11 | page_icon=FAVICON, 12 | layout="wide", 13 | # initial_sidebar_state="collapsed", 14 | ) 15 | 16 | # header 17 | drawHeaderIcon(st) 18 | 19 | df = getNewResults() 20 | st.dataframe(df, height=800) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | 4 | 5 | passwordKeys = ["password", "api_key"] 6 | 7 | 8 | def inputIsPassword(key: str) -> bool: 9 | return key.lower() in passwordKeys 10 | 11 | 12 | def addHorizontalLine(st): 13 | st.markdown( 14 | "
", 15 | unsafe_allow_html=True, 16 | ) 17 | 18 | 19 | def generate_random_string(length): 20 | letters = string.ascii_letters + string.digits 21 | result = "".join(random.choice(letters) for _ in range(length)) 22 | return result 23 | -------------------------------------------------------------------------------- /vectordb_bench/frontend/vdb_benchmark.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from vectordb_bench.frontend.components.check_results.footer import footer 3 | from vectordb_bench.frontend.components.check_results.stPageConfig import ( 4 | initResultsPageConfig, 5 | ) 6 | from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon 7 | from vectordb_bench.frontend.components.check_results.nav import ( 8 | NavToQuriesPerDollar, 9 | NavToRunTest, 10 | ) 11 | from vectordb_bench.frontend.components.check_results.charts import drawCharts 12 | from vectordb_bench.frontend.components.check_results.filters import getshownData 13 | from vectordb_bench.frontend.components.get_results.saveAsImage import getResults 14 | 15 | from vectordb_bench.interface import benchmark_runner 16 | 17 | 18 | def main(): 19 | # set page config 20 | initResultsPageConfig(st) 21 | 22 | # header 23 | drawHeaderIcon(st) 24 | 25 | allResults = benchmark_runner.get_results() 26 | 27 | st.title("Vector Database Benchmark") 28 | st.caption( 29 | "Except for zillizcloud-v2024.1, which was tested in _January 2024_, all other tests were completed before _August 2023_." 30 | ) 31 | st.caption("All tested milvus are in _standalone_ mode.") 32 | 33 | # results selector and filter 34 | resultSelectorContainer = st.sidebar.container() 35 | shownData, failedTasks, showCaseNames = getshownData(allResults, resultSelectorContainer) 36 | 37 | resultSelectorContainer.divider() 38 | 39 | # nav 40 | navContainer = st.sidebar.container() 41 | NavToRunTest(navContainer) 42 | NavToQuriesPerDollar(navContainer) 43 | 44 | # save or share 45 | resultesContainer = st.sidebar.container() 46 | getResults(resultesContainer, "vectordb_bench") 47 | 48 | # charts 49 | drawCharts(st, shownData, failedTasks, showCaseNames) 50 | 51 | # footer 52 | footer(st.container()) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /vectordb_bench/log_util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import config 3 | from pathlib import Path 4 | 5 | 6 | def init(log_level: str): 7 | # Create logs directory if it doesn't exist 8 | log_dir = Path("logs") 9 | log_dir.mkdir(exist_ok=True) 10 | 11 | log_config = { 12 | "version": 1, 13 | "disable_existing_loggers": False, 14 | "formatters": { 15 | "default": { 16 | "format": "%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)", 17 | }, 18 | "colorful_console": { 19 | "format": "%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)", 20 | "()": ColorfulFormatter, 21 | }, 22 | }, 23 | "handlers": { 24 | "console": { 25 | "class": "logging.StreamHandler", 26 | "formatter": "colorful_console", 27 | }, 28 | "no_color_console": { 29 | "class": "logging.StreamHandler", 30 | "formatter": "default", 31 | }, 32 | "file": { 33 | "class": "logging.handlers.RotatingFileHandler", 34 | "formatter": "default", 35 | "filename": "logs/vectordb_bench.log", 36 | "maxBytes": 10485760, # 10MB 37 | "backupCount": 5, 38 | "encoding": 
"utf8", 39 | }, 40 | }, 41 | "loggers": { 42 | "vectordb_bench": { 43 | "handlers": ["console", "file"], 44 | "level": log_level, 45 | "propagate": False, 46 | }, 47 | "no_color": { 48 | "handlers": ["no_color_console", "file"], 49 | "level": log_level, 50 | "propagate": False, 51 | }, 52 | }, 53 | "propagate": False, 54 | } 55 | 56 | config.dictConfig(log_config) 57 | 58 | 59 | class colors: 60 | HEADER = "\033[95m" 61 | INFO = "\033[92m" 62 | DEBUG = "\033[94m" 63 | WARNING = "\033[93m" 64 | ERROR = "\033[95m" 65 | CRITICAL = "\033[91m" 66 | ENDC = "\033[0m" 67 | 68 | 69 | COLORS = { 70 | "INFO": colors.INFO, 71 | "INFOM": colors.INFO, 72 | "DEBUG": colors.DEBUG, 73 | "DEBUGM": colors.DEBUG, 74 | "WARNING": colors.WARNING, 75 | "WARNINGM": colors.WARNING, 76 | "CRITICAL": colors.CRITICAL, 77 | "CRITICALM": colors.CRITICAL, 78 | "ERROR": colors.ERROR, 79 | "ERRORM": colors.ERROR, 80 | "ENDC": colors.ENDC, 81 | } 82 | 83 | 84 | class ColorFulFormatColMixin: 85 | def format_col(self, message: str, level_name: str): 86 | if level_name in COLORS: 87 | message = COLORS[level_name] + message + COLORS["ENDC"] 88 | return message 89 | 90 | 91 | class ColorfulLogRecordProxy(logging.LogRecord): 92 | def __init__(self, record: any): 93 | self._record = record 94 | msg_level = record.levelname + "M" 95 | self.msg = f"{COLORS[msg_level]}{record.msg}{COLORS['ENDC']}" 96 | self.filename = record.filename 97 | self.lineno = f"{record.lineno}" 98 | self.process = f"{record.process}" 99 | self.levelname = f"{COLORS[record.levelname]}{record.levelname}{COLORS['ENDC']}" 100 | 101 | def __getattr__(self, attr: any): 102 | if attr not in self.__dict__: 103 | return getattr(self._record, attr) 104 | return getattr(self, attr) 105 | 106 | 107 | class ColorfulFormatter(ColorFulFormatColMixin, logging.Formatter): 108 | def format(self, record: any): 109 | proxy = ColorfulLogRecordProxy(record) 110 | return super().format(proxy) 111 | -------------------------------------------------------------------------------- /vectordb_bench/metric.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass, field 3 | 4 | import numpy as np 5 | 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | @dataclass 10 | class Metric: 11 | """result metrics""" 12 | 13 | # for load cases 14 | max_load_count: int = 0 15 | 16 | # for performance cases 17 | load_duration: float = 0.0 # duration to load all dataset into DB 18 | qps: float = 0.0 19 | serial_latency_p99: float = 0.0 20 | recall: float = 0.0 21 | ndcg: float = 0.0 22 | conc_num_list: list[int] = field(default_factory=list) 23 | conc_qps_list: list[float] = field(default_factory=list) 24 | conc_latency_p99_list: list[float] = field(default_factory=list) 25 | conc_latency_avg_list: list[float] = field(default_factory=list) 26 | 27 | 28 | QURIES_PER_DOLLAR_METRIC = "QP$ (Quries per Dollar)" 29 | LOAD_DURATION_METRIC = "load_duration" 30 | SERIAL_LATENCY_P99_METRIC = "serial_latency_p99" 31 | MAX_LOAD_COUNT_METRIC = "max_load_count" 32 | QPS_METRIC = "qps" 33 | RECALL_METRIC = "recall" 34 | 35 | metric_unit_map = { 36 | LOAD_DURATION_METRIC: "s", 37 | SERIAL_LATENCY_P99_METRIC: "ms", 38 | MAX_LOAD_COUNT_METRIC: "K", 39 | QURIES_PER_DOLLAR_METRIC: "K", 40 | } 41 | 42 | lower_is_better_metrics = [ 43 | LOAD_DURATION_METRIC, 44 | SERIAL_LATENCY_P99_METRIC, 45 | ] 46 | 47 | metric_order = [ 48 | QPS_METRIC, 49 | RECALL_METRIC, 50 | LOAD_DURATION_METRIC, 51 | SERIAL_LATENCY_P99_METRIC, 52 | 
MAX_LOAD_COUNT_METRIC, 53 | ] 54 | 55 | 56 | def isLowerIsBetterMetric(metric: str) -> bool: 57 | return metric in lower_is_better_metrics 58 | 59 | 60 | def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float: 61 | recalls = np.zeros(count) 62 | for i, result in enumerate(got): 63 | if result in ground_truth: 64 | recalls[i] = 1 65 | 66 | return np.mean(recalls) 67 | 68 | 69 | def get_ideal_dcg(k: int): 70 | ideal_dcg = 0 71 | for i in range(k): 72 | ideal_dcg += 1 / np.log2(i + 2) 73 | 74 | return ideal_dcg 75 | 76 | 77 | def calc_ndcg(ground_truth: list[int], got: list[int], ideal_dcg: float) -> float: 78 | dcg = 0 79 | ground_truth = list(ground_truth) 80 | for got_id in set(got): 81 | if got_id in ground_truth: 82 | idx = ground_truth.index(got_id) 83 | dcg += 1 / np.log2(idx + 2) 84 | return dcg / ideal_dcg 85 | -------------------------------------------------------------------------------- /vectordb_bench/results/ElasticCloud/result_20230808_standard_elasticcloud.json: -------------------------------------------------------------------------------- 1 | { 2 | "run_id": "5c1e8bd468224ffda1b39b08cdc342c3", 3 | "task_label": "standard", 4 | "results": [ 5 | { 6 | "metrics": { 7 | "max_load_count": 0, 8 | "load_duration": 8671.2705, 9 | "qps": 11.2945, 10 | "serial_latency_p99": 3.6112, 11 | "recall": 0.996 12 | }, 13 | "task_config": { 14 | "db": "ElasticCloud", 15 | "db_config": { 16 | "db_label": "upTo2.5c8g", 17 | "cloud_id": "**********", 18 | "password": "**********" 19 | }, 20 | "db_case_config": { 21 | "element_type": "float", 22 | "index": "hnsw", 23 | "metric_type": "COSINE", 24 | "efConstruction": 360, 25 | "M": 30, 26 | "num_candidates": null 27 | }, 28 | "case_config": { 29 | "case_id": 10, 30 | "custom_case": {} 31 | } 32 | }, 33 | "label": ":)" 34 | }, 35 | { 36 | "metrics": { 37 | "max_load_count": 0, 38 | "load_duration": 0.0, 39 | "qps": 0.0, 40 | "serial_latency_p99": 0.0, 41 | "recall": 0.0 42 | }, 43 | "task_config": { 44 | "db": "ElasticCloud", 45 | "db_config": { 46 | "db_label": "upTo2.5c8g", 47 | "cloud_id": "**********", 48 | "password": "**********" 49 | }, 50 | "db_case_config": { 51 | "element_type": "float", 52 | "index": "hnsw", 53 | "metric_type": "COSINE", 54 | "efConstruction": 360, 55 | "M": 30, 56 | "num_candidates": null 57 | }, 58 | "case_config": { 59 | "case_id": 11, 60 | "custom_case": {} 61 | } 62 | }, 63 | "label": "?" 
64 | }, 65 | { 66 | "metrics": { 67 | "max_load_count": 0, 68 | "load_duration": 8671.2705, 69 | "qps": 17.3271, 70 | "serial_latency_p99": 3.7748, 71 | "recall": 0.9961 72 | }, 73 | "task_config": { 74 | "db": "ElasticCloud", 75 | "db_config": { 76 | "db_label": "upTo2.5c8g", 77 | "cloud_id": "**********", 78 | "password": "**********" 79 | }, 80 | "db_case_config": { 81 | "element_type": "float", 82 | "index": "hnsw", 83 | "metric_type": "COSINE", 84 | "efConstruction": 360, 85 | "M": 30, 86 | "num_candidates": null 87 | }, 88 | "case_config": { 89 | "case_id": 12, 90 | "custom_case": {} 91 | } 92 | }, 93 | "label": ":)" 94 | }, 95 | { 96 | "metrics": { 97 | "max_load_count": 0, 98 | "load_duration": 0.0, 99 | "qps": 0.0, 100 | "serial_latency_p99": 0.0, 101 | "recall": 0.0 102 | }, 103 | "task_config": { 104 | "db": "ElasticCloud", 105 | "db_config": { 106 | "db_label": "upTo2.5c8g", 107 | "cloud_id": "**********", 108 | "password": "**********" 109 | }, 110 | "db_case_config": { 111 | "element_type": "float", 112 | "index": "hnsw", 113 | "metric_type": "COSINE", 114 | "efConstruction": 360, 115 | "M": 30, 116 | "num_candidates": null 117 | }, 118 | "case_config": { 119 | "case_id": 13, 120 | "custom_case": {} 121 | } 122 | }, 123 | "label": "?" 124 | }, 125 | { 126 | "metrics": { 127 | "max_load_count": 0, 128 | "load_duration": 8671.2705, 129 | "qps": 26.26, 130 | "serial_latency_p99": 0.5561, 131 | "recall": 0.9999 132 | }, 133 | "task_config": { 134 | "db": "ElasticCloud", 135 | "db_config": { 136 | "db_label": "upTo2.5c8g", 137 | "cloud_id": "**********", 138 | "password": "**********" 139 | }, 140 | "db_case_config": { 141 | "element_type": "float", 142 | "index": "hnsw", 143 | "metric_type": "COSINE", 144 | "efConstruction": 360, 145 | "M": 30, 146 | "num_candidates": null 147 | }, 148 | "case_config": { 149 | "case_id": 14, 150 | "custom_case": {} 151 | } 152 | }, 153 | "label": ":)" 154 | }, 155 | { 156 | "metrics": { 157 | "max_load_count": 0, 158 | "load_duration": 0.0, 159 | "qps": 0.0, 160 | "serial_latency_p99": 0.0, 161 | "recall": 0.0 162 | }, 163 | "task_config": { 164 | "db": "ElasticCloud", 165 | "db_config": { 166 | "db_label": "upTo2.5c8g", 167 | "cloud_id": "**********", 168 | "password": "**********" 169 | }, 170 | "db_case_config": { 171 | "element_type": "float", 172 | "index": "hnsw", 173 | "metric_type": "COSINE", 174 | "efConstruction": 360, 175 | "M": 30, 176 | "num_candidates": null 177 | }, 178 | "case_config": { 179 | "case_id": 15, 180 | "custom_case": {} 181 | } 182 | }, 183 | "label": "?" 
184 | } 185 | ], 186 | "file_fmt": "result_{}_{}_{}.json" 187 | } -------------------------------------------------------------------------------- /vectordb_bench/results/PgVector/result_20230808_standard_pgvector.json: -------------------------------------------------------------------------------- 1 | { 2 | "run_id": "5c1e8bd468224ffda1b39b08cdc342c3", 3 | "task_label": "standard", 4 | "results": [ 5 | { 6 | "metrics": { 7 | "max_load_count": 0, 8 | "load_duration": 1380.9471, 9 | "qps": 0.8836, 10 | "serial_latency_p99": 2.523, 11 | "recall": 0.8528 12 | }, 13 | "task_config": { 14 | "db": "PgVector", 15 | "db_config": { 16 | "db_label": "2c8g", 17 | "user_name": "**********", 18 | "password": "**********", 19 | "url": "**********", 20 | "db_name": "**********" 21 | }, 22 | "db_case_config": { 23 | "index": "IVF_FLAT", 24 | "metric_type": "L2", 25 | "lists": 10, 26 | "probes": 2 27 | }, 28 | "case_config": { 29 | "case_id": 10, 30 | "custom_case": {} 31 | } 32 | }, 33 | "label": ":)" 34 | }, 35 | { 36 | "metrics": { 37 | "max_load_count": 0, 38 | "load_duration": 0.0, 39 | "qps": 0.0, 40 | "serial_latency_p99": 0.0, 41 | "recall": 0.0 42 | }, 43 | "task_config": { 44 | "db": "PgVector", 45 | "db_config": { 46 | "db_label": "2c8g", 47 | "user_name": "**********", 48 | "password": "**********", 49 | "url": "**********", 50 | "db_name": "**********" 51 | }, 52 | "db_case_config": { 53 | "metric_type": "L2", 54 | "lists": 10, 55 | "probes": 2, 56 | "index": "IVF_FLAT" 57 | }, 58 | "case_config": { 59 | "case_id": 11, 60 | "custom_case": {} 61 | } 62 | }, 63 | "label": "?" 64 | }, 65 | { 66 | "metrics": { 67 | "max_load_count": 0, 68 | "load_duration": 1380.9471, 69 | "qps": 0.8937, 70 | "serial_latency_p99": 3.7202, 71 | "recall": 0.8525 72 | }, 73 | "task_config": { 74 | "db": "PgVector", 75 | "db_config": { 76 | "db_label": "2c8g", 77 | "user_name": "**********", 78 | "password": "**********", 79 | "url": "**********", 80 | "db_name": "**********" 81 | }, 82 | "db_case_config": { 83 | "metric_type": "L2", 84 | "lists": 10, 85 | "probes": 2, 86 | "index": "IVF_FLAT" 87 | }, 88 | "case_config": { 89 | "case_id": 12, 90 | "custom_case": {} 91 | } 92 | }, 93 | "label": ":)" 94 | }, 95 | { 96 | "metrics": { 97 | "max_load_count": 0, 98 | "load_duration": 0.0, 99 | "qps": 0.0, 100 | "serial_latency_p99": 0.0, 101 | "recall": 0.0 102 | }, 103 | "task_config": { 104 | "db": "PgVector", 105 | "db_config": { 106 | "db_label": "2c8g", 107 | "user_name": "**********", 108 | "password": "**********", 109 | "url": "**********", 110 | "db_name": "**********" 111 | }, 112 | "db_case_config": { 113 | "index": "IVF_FLAT", 114 | "metric_type": "L2", 115 | "lists": 10, 116 | "probes": 2 117 | }, 118 | "case_config": { 119 | "case_id": 13, 120 | "custom_case": {} 121 | } 122 | }, 123 | "label": "?" 
124 | }, 125 | { 126 | "metrics": { 127 | "max_load_count": 0, 128 | "load_duration": 1372.7522, 129 | "qps": 1.2145, 130 | "serial_latency_p99": 3.6224, 131 | "recall": 0.7487 132 | }, 133 | "task_config": { 134 | "db": "PgVector", 135 | "db_config": { 136 | "db_label": "2c8g", 137 | "user_name": "**********", 138 | "password": "**********", 139 | "url": "**********", 140 | "db_name": "**********" 141 | }, 142 | "db_case_config": { 143 | "index": "IVF_FLAT", 144 | "metric_type": "L2", 145 | "lists": 10, 146 | "probes": 2 147 | }, 148 | "case_config": { 149 | "case_id": 14, 150 | "custom_case": {} 151 | } 152 | }, 153 | "label": ":)" 154 | }, 155 | { 156 | "metrics": { 157 | "max_load_count": 0, 158 | "load_duration": 0.0, 159 | "qps": 0.0, 160 | "serial_latency_p99": 0.0, 161 | "recall": 0.0 162 | }, 163 | "task_config": { 164 | "db": "PgVector", 165 | "db_config": { 166 | "db_label": "2c8g", 167 | "user_name": "**********", 168 | "password": "**********", 169 | "url": "**********", 170 | "db_name": "**********" 171 | }, 172 | "db_case_config": { 173 | "index": "IVF_FLAT", 174 | "metric_type": "L2", 175 | "lists": 10, 176 | "probes": 2 177 | }, 178 | "case_config": { 179 | "case_id": 15, 180 | "custom_case": {} 181 | } 182 | }, 183 | "label": "?" 184 | } 185 | ], 186 | "file_fmt": "result_{}_{}_{}.json" 187 | } 188 | -------------------------------------------------------------------------------- /vectordb_bench/results/dbPrices.json: -------------------------------------------------------------------------------- 1 | { 2 | "Milvus": {}, 3 | "ZillizCloud": { 4 | "1cu-perf": 0.159, 5 | "8cu-perf": 1.272, 6 | "1cu-cap": 0.159, 7 | "2cu-cap": 0.318 8 | }, 9 | "WeaviateCloud": { 10 | "standard": 10.1, 11 | "bus_crit": 32.6 12 | }, 13 | "ElasticCloud": { 14 | "upTo2.5c8g": 0.4793 15 | }, 16 | "QdrantCloud": { 17 | "0.5c4g-1node": 0.052, 18 | "2c8g-1node": 0.166, 19 | "4c16g-1node": 0.2852, 20 | "4c16g-5node": 1.426 21 | }, 22 | "Pinecone": { 23 | "s1.x1": 0.0973, 24 | "s1.x2": 0.194, 25 | "p1.x1": 0.0973, 26 | "p2.x1": 0.146, 27 | "p2.x1-8node": 1.168, 28 | "p1.x1-8node": 0.779, 29 | "s1.x1-2node": 0.195 30 | }, 31 | "PgVector": {} 32 | } -------------------------------------------------------------------------------- /vectordb_bench/results/getLeaderboardData.py: -------------------------------------------------------------------------------- 1 | from vectordb_bench import config 2 | import ujson 3 | import pathlib 4 | from vectordb_bench.backend.cases import CaseType 5 | from vectordb_bench.backend.clients import DB 6 | from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE 7 | from vectordb_bench.interface import benchMarkRunner 8 | from vectordb_bench.models import ResultLabel, TestResult 9 | from datetime import datetime 10 | 11 | taskLabelToCode = { 12 | ResultLabel.FAILED: -1, 13 | ResultLabel.OUTOFRANGE: -2, 14 | ResultLabel.NORMAL: 1, 15 | } 16 | 17 | 18 | def format_time(ts: float) -> str: 19 | default_standard_test_time = datetime(2023, 8, 1) 20 | t = datetime.fromtimestamp(ts) 21 | if t < default_standard_test_time: 22 | t = default_standard_test_time 23 | return t.strftime("%Y-%m") 24 | 25 | 26 | def main(): 27 | allResults: list[TestResult] = benchMarkRunner.get_results() 28 | 29 | if allResults is not None: 30 | data = [ 31 | { 32 | "db": d.task_config.db.value, 33 | "db_label": d.task_config.db_config.db_label, 34 | "db_name": d.task_config.db_name, 35 | "case": d.task_config.case_config.case_id.case_name(), 36 | "qps": d.metrics.qps, 37 | "latency": 
d.metrics.serial_latency_p99, 38 | "recall": d.metrics.recall, 39 | "label": taskLabelToCode[d.label], 40 | "note": d.task_config.db_config.note, 41 | "version": d.task_config.db_config.version, 42 | "test_time": format_time(test_result.timestamp), 43 | } 44 | for test_result in allResults 45 | if "standard" in test_result.task_label 46 | for d in test_result.results 47 | if d.task_config.case_config.case_id != CaseType.CapacityDim128 48 | and d.task_config.case_config.case_id != CaseType.CapacityDim960 49 | if d.task_config.db != DB.ZillizCloud 50 | or test_result.timestamp >= datetime(2024, 1, 1).timestamp() 51 | ] 52 | 53 | # compute qp$ 54 | for d in data: 55 | db = d["db"] 56 | db_label = d["db_label"] 57 | qps = d["qps"] 58 | price = DB_DBLABEL_TO_PRICE.get(db, {}).get(db_label, 0) 59 | d["qp$"] = (qps / price * 3600) if price > 0 else 0.0 60 | 61 | with open(pathlib.Path(config.RESULTS_LOCAL_DIR, "leaderboard.json"), "w") as f: 62 | ujson.dump(data, f) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | --------------------------------------------------------------------------------
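The QP$ figure produced by pages/quries_per_dollar.py and results/getLeaderboardData.py is derived the same way in both places: sustained QPS divided by the instance's hourly price from dbPrices.json, scaled to a per-dollar query count. Below is a minimal standalone sketch of that arithmetic; it is not a file from this repository, the helper names are invented for illustration, and the price and QPS numbers are placeholders rather than benchmark results. getLeaderboardData.py stores the raw value (qps / price * 3600), while the frontend page additionally divides by 1000 and labels the unit "K", which is why it multiplies by 3.6 instead.

# Illustrative sketch (not part of the repository) of the QP$ arithmetic used by
# pages/quries_per_dollar.py and results/getLeaderboardData.py.
# The QPS and hourly-price values below are placeholders, not measured results.

def queries_per_dollar(qps: float, price_per_hour: float) -> float:
    """Raw QP$ as stored in leaderboard.json: how many queries one dollar buys."""
    if qps <= 0 or price_per_hour <= 0:
        return 0.0
    return qps / price_per_hour * 3600  # 3600 seconds per billed hour


def queries_per_dollar_k(qps: float, price_per_hour: float) -> float:
    """QP$ in thousands, matching the "K" unit shown on the QP$ page."""
    return queries_per_dollar(qps, price_per_hour) / 1000  # same as qps / price * 3.6


if __name__ == "__main__":
    qps, price = 250.0, 0.159  # hypothetical run against a $0.159/hour tier
    print(f"{queries_per_dollar(qps, price):,.0f} queries per dollar")      # ~5,660,377
    print(f"{queries_per_dollar_k(qps, price):,.1f} K queries per dollar")  # ~5,660.4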