├── README.md └── mxbai_binary_quantization.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # binary-embeddings 2 | Showcase how mxbai-embed-large-v1 can be used to produce binary embedding. Binary embeddings enabled 32x storage savings and 40x faster retrieval. 3 | -------------------------------------------------------------------------------- /mxbai_binary_quantization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Binary embeddings with [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)\n", 8 | "\n", 9 | "Our model was trained to have a non-'clunky' embeddings space. This allows for quantizing the embeddings with low performance loss compared to techniques like Matryoshka. With binary embeddings, we can use the Hamming distance, which is well optimized for CPUs.\n", 10 | "\n", 11 | "In general, the approach is divided into 2 steps:\n", 12 | "\n", 13 | "1. Retrieve candidates based on Hamming distance.\n", 14 | "2. Rescore the candidates based on the dot product between the binary embedding and the floating embedding of the query.\n", 15 | "\n", 16 | "We find that we can retain ~96-99% of the performance, achieve ~40x faster retrieval, and realize 32x storage savings." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Requirement already satisfied: sentence_transformers in /home/aamir/.local/lib/python3.9/site-packages (2.6.0.dev0)\n", 29 | "Requirement already satisfied: datasets in /home/aamir/.local/lib/python3.9/site-packages (2.18.0)\n", 30 | "Requirement already satisfied: beir in /home/aamir/.local/lib/python3.9/site-packages (2.0.0)\n", 31 | "Requirement already satisfied: faiss-cpu in /home/aamir/.local/lib/python3.9/site-packages (1.8.0)\n", 32 | "Requirement already satisfied: pytrec-eval in /home/aamir/.local/lib/python3.9/site-packages (from beir) (0.5)\n", 33 | "Requirement already satisfied: elasticsearch==7.9.1 in /home/aamir/.local/lib/python3.9/site-packages (from beir) (7.9.1)\n", 34 | "Requirement already satisfied: certifi in /usr/lib/python3/dist-packages (from elasticsearch==7.9.1->beir) (2020.6.20)\n", 35 | "Requirement already satisfied: urllib3>=1.21.1 in /home/aamir/.local/lib/python3.9/site-packages (from elasticsearch==7.9.1->beir) (2.2.1)\n", 36 | "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (0.3.8)\n", 37 | "Requirement already satisfied: huggingface-hub>=0.19.4 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (0.21.4)\n", 38 | "Requirement already satisfied: tqdm>=4.62.1 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (4.66.2)\n", 39 | "Requirement already satisfied: filelock in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (3.13.1)\n", 40 | "Requirement already satisfied: pyyaml>=5.1 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (6.0.1)\n", 41 | "Requirement already satisfied: pyarrow-hotfix in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (0.6)\n", 42 | "Requirement already satisfied: pyarrow>=12.0.0 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (15.0.1)\n", 43 | "Requirement already satisfied: requests>=2.19.0 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (2.31.0)\n", 44 | "Requirement already satisfied: xxhash in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (3.4.1)\n", 45 | "Requirement already satisfied: fsspec[http]<=2024.2.0,>=2023.1.0 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (2024.2.0)\n", 46 | "Requirement already satisfied: packaging in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (23.2)\n", 47 | "Requirement already satisfied: pandas in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (2.2.1)\n", 48 | "Requirement already satisfied: aiohttp in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (3.9.1)\n", 49 | "Requirement already satisfied: multiprocess in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (0.70.16)\n", 50 | "Requirement already satisfied: numpy>=1.17 in /home/aamir/.local/lib/python3.9/site-packages (from datasets) (1.26.4)\n", 51 | "Requirement already satisfied: frozenlist>=1.1.1 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.4.1)\n", 52 | "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (4.0.3)\n", 53 | "Requirement already satisfied: multidict<7.0,>=4.5 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (6.0.5)\n", 54 | "Requirement already satisfied: yarl<2.0,>=1.0 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.9.4)\n", 55 | "Requirement already satisfied: aiosignal>=1.1.2 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (1.3.1)\n", 56 | "Requirement already satisfied: attrs>=17.3.0 in /home/aamir/.local/lib/python3.9/site-packages (from aiohttp->datasets) (23.2.0)\n", 57 | "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/aamir/.local/lib/python3.9/site-packages (from huggingface-hub>=0.19.4->datasets) (4.10.0)\n", 58 | "Requirement already satisfied: charset-normalizer<4,>=2 in /home/aamir/.local/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (3.3.2)\n", 59 | "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets) (2.10)\n", 60 | "Requirement already satisfied: transformers<5.0.0,>=4.32.0 in /home/aamir/.local/lib/python3.9/site-packages (from sentence_transformers) (4.38.2)\n", 61 | "Requirement already satisfied: scipy in /home/aamir/.local/lib/python3.9/site-packages (from sentence_transformers) (1.12.0)\n", 62 | "Requirement already satisfied: torch>=1.11.0 in /home/aamir/.local/lib/python3.9/site-packages (from sentence_transformers) (2.2.1)\n", 63 | "Requirement already satisfied: scikit-learn in /home/aamir/.local/lib/python3.9/site-packages (from sentence_transformers) (1.4.1.post1)\n", 64 | "Requirement already satisfied: Pillow in /home/aamir/.local/lib/python3.9/site-packages (from sentence_transformers) (10.2.0)\n", 65 | "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.3.1)\n", 66 | "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (8.9.2.26)\n", 67 | "Requirement already satisfied: networkx in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (3.2.1)\n", 68 | "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (11.4.5.107)\n", 69 | "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n", 70 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n", 71 | "Requirement already satisfied: sympy in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (1.12)\n", 72 | "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (10.3.2.106)\n", 73 | "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (2.19.3)\n", 74 | "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.0.106)\n", 75 | "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (11.0.2.54)\n", 76 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n", 77 | "Requirement already satisfied: triton==2.2.0 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (2.2.0)\n", 78 | "Requirement already satisfied: jinja2 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (3.1.3)\n", 79 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/aamir/.local/lib/python3.9/site-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n", 80 | "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/aamir/.local/lib/python3.9/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.11.0->sentence_transformers) (12.4.99)\n", 81 | "Requirement already satisfied: tokenizers<0.19,>=0.14 in /home/aamir/.local/lib/python3.9/site-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (0.15.2)\n", 82 | "Requirement already satisfied: regex!=2019.12.17 in /home/aamir/.local/lib/python3.9/site-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (2023.12.25)\n", 83 | "Requirement already satisfied: safetensors>=0.4.1 in /home/aamir/.local/lib/python3.9/site-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (0.4.2)\n", 84 | "Requirement already satisfied: MarkupSafe>=2.0 in /home/aamir/.local/lib/python3.9/site-packages (from jinja2->torch>=1.11.0->sentence_transformers) (2.1.5)\n", 85 | "Requirement already satisfied: python-dateutil>=2.8.2 in /home/aamir/.local/lib/python3.9/site-packages (from pandas->datasets) (2.9.0.post0)\n", 86 | "Requirement already satisfied: pytz>=2020.1 in /home/aamir/.local/lib/python3.9/site-packages (from pandas->datasets) (2024.1)\n", 87 | "Requirement already satisfied: tzdata>=2022.7 in /home/aamir/.local/lib/python3.9/site-packages (from pandas->datasets) (2024.1)\n", 88 | "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", 89 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/aamir/.local/lib/python3.9/site-packages (from scikit-learn->sentence_transformers) (3.3.0)\n", 90 | "Requirement already satisfied: joblib>=1.2.0 in /home/aamir/.local/lib/python3.9/site-packages (from scikit-learn->sentence_transformers) (1.3.2)\n", 91 | "Requirement already satisfied: mpmath>=0.19 in /home/aamir/.local/lib/python3.9/site-packages (from sympy->torch>=1.11.0->sentence_transformers) (1.3.0)\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "!pip install sentence_transformers datasets beir faiss-cpu" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 2, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stderr", 106 | "output_type": "stream", 107 | "text": [ 108 | "/home/aamir/.local/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 109 | " from .autonotebook import tqdm as notebook_tqdm\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "from sentence_transformers import SentenceTransformer\n", 115 | "from sentence_transformers import models\n", 116 | "from datasets import load_dataset\n", 117 | "from beir.retrieval.evaluation import EvaluateRetrieval\n", 118 | "from mteb import MTEB\n", 119 | "import numpy as np\n", 120 | "import faiss\n", 121 | "import time\n", 122 | "import os" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "Let's use the worlds best model xD" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 3, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "model = SentenceTransformer(\"mixedbread-ai/mxbai-embed-large-v1\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "TrecCovid is a nice benchmark, not too large, not too small, also pretty difficult." 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 4, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "task = \"mteb/trec-covid\"\n", 155 | "dataset = load_dataset(task, \"corpus\")\n", 156 | "docs_ids = dataset[\"corpus\"][\"_id\"]\n", 157 | "features = [d[\"title\"] + \" \" + d[\"text\"] for d in dataset[\"corpus\"]]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "Let's speedup the calculation by using fp16" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 5, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "_ = model.half()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "On 4xA100 it should take ~2min" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 6, 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "Embeddings computed. Shape: (171332, 1024)\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "pool = model.start_multi_process_pool()\n", 198 | "\n", 199 | "# normalize_embeddings=True will normalize the embeddings to unit length before indexing so the dot product is equal to the cosine similarity\n", 200 | "emb = model.encode_multi_process(features, pool, normalize_embeddings=True)\n", 201 | "print(\"Embeddings computed. Shape:\", emb.shape)\n", 202 | "\n", 203 | "model.stop_multi_process_pool(pool)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "FP32 Index" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "index = faiss.IndexFlatIP(emb.shape[1])\n", 220 | "index.add(emb)\n", 221 | "faiss.write_index(index, \"index_fp32.faiss\")" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": {}, 227 | "source": [ 228 | "Binary Index, convert embeddings using simple thresholding" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 8, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "Binary embeddings computed. Shape: (171332, 128)\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "bemb = np.packbits(emb > 0).reshape(emb.shape[0], -1)\n", 246 | "print(\"Binary embeddings computed. Shape:\", bemb.shape)\n", 247 | "num_dim = emb.shape[1]\n", 248 | "bindex = faiss.IndexBinaryFlat(num_dim)\n", 249 | "bindex.add(bemb)\n", 250 | "faiss.write_index_binary(bindex, \"index_binary.faiss\")" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "Compression size is ~32 as expected" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 9, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "File size of index_fp32.faiss: 701775917\n", 270 | "File size of index_binary.faiss: 21930529\n", 271 | "Compression ratio: 31.999953899880847\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "# check file size\n", 277 | "fp32_index_size = os.path.getsize(\"index_fp32.faiss\")\n", 278 | "binary_index_size = os.path.getsize(\"index_binary.faiss\")\n", 279 | "print(\"File size of index_fp32.faiss:\", fp32_index_size)\n", 280 | "print(\"File size of index_binary.faiss:\", binary_index_size)\n", 281 | "print(\"Compression ratio:\", fp32_index_size / binary_index_size)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "Some BEIR stuff for the eval later" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 10, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "qrels_df = load_dataset(task)[\"test\"]\n", 298 | "qrels = {}\n", 299 | "for row in qrels_df:\n", 300 | " qid = row['query-id']\n", 301 | " cid = row['corpus-id']\n", 302 | " \n", 303 | " if row['score'] > 0:\n", 304 | " if qid not in qrels:\n", 305 | " qrels[qid] = {}\n", 306 | " qrels[qid][cid] = int(row['score'])" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 11, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "queries = load_dataset(task, \"queries\")\n", 316 | "queries = queries.filter(lambda x: x['_id'] in qrels)\n", 317 | "\n", 318 | "query_ids = queries[\"queries\"][\"_id\"]\n", 319 | "queries = [\"Represent this sentence for searching relevant passages: \" + d[\"text\"] for d in queries[\"queries\"]]" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 12, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "model.float()\n", 329 | "query_emb = model.encode(queries, convert_to_numpy=True, normalize_embeddings=True)\n", 330 | "query_bemb = np.where(query_emb < 0, 0, 1).astype(np.uint8) # binarize\n", 331 | "query_bemb = np.packbits(query_bemb).reshape(query_bemb.shape[0], -1)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "`score_multiply` is used for helping the evaluation. It is expecting that the relevant elements have a higher score, while hamming distance gives us a lower score." 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 13, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "def faiss_search(index, queries_emb, k=[10, 100], float_embed = None, score_multiply = 1, oversample = 1):\n", 348 | " start_time = time.time()\n", 349 | " faiss_scores, faiss_doc_ids = index.search(queries_emb, max(k) * oversample)\n", 350 | " print(f\"Search took {(time.time()-start_time):.4f} sec\")\n", 351 | " \n", 352 | " query2id = {idx: qid for idx, qid in enumerate(query_ids)}\n", 353 | " doc2id = {idx: cid for idx, cid in enumerate(docs_ids)}\n", 354 | " id2doc = {cid: idx for idx, cid in enumerate(docs_ids)}\n", 355 | "\n", 356 | " faiss_results = {}\n", 357 | " for idx in range(0, len(faiss_scores)):\n", 358 | " qid = query2id[idx]\n", 359 | " doc_scores = {doc2id[doc_id]: score.item() * score_multiply for doc_id, score in zip(faiss_doc_ids[idx], faiss_scores[idx])}\n", 360 | " \n", 361 | " # Rescore\n", 362 | " if float_embed is not None:\n", 363 | " bin_doc_emb = np.asarray([index.reconstruct(id2doc[doc_id]) for doc_id in doc_scores])\n", 364 | " bin_doc_emb_unpacked = np.unpackbits(bin_doc_emb, axis=-1).astype(\"int\")\n", 365 | " \n", 366 | " scores_cont = (float_embed[idx] @ bin_doc_emb_unpacked.T)\n", 367 | " doc_scores = {doc_id: score_cont for doc_id, score_cont in zip(doc_scores, scores_cont)}\n", 368 | "\n", 369 | " faiss_results[qid] = doc_scores\n", 370 | "\n", 371 | " \n", 372 | " ndcg, map_score, recall, precision = EvaluateRetrieval.evaluate(qrels, faiss_results, k)\n", 373 | " acc = EvaluateRetrieval.evaluate_custom(qrels, faiss_results, [3, 5, 10], metric=\"acc\")\n", 374 | " print(ndcg)\n", 375 | " print(recall)\n", 376 | " print(acc)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "### Baseline: Normal exact search\n", 384 | "We mostly care about NDCG@10 here." 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 14, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "Search took 0.4522 sec\n", 397 | "{'NDCG@10': 0.75558, 'NDCG@100': 0.56317}\n", 398 | "{'Recall@10': 0.02136, 'Recall@100': 0.13842}\n", 399 | "{'Accuracy@3': 0.98, 'Accuracy@5': 1.0, 'Accuracy@10': 1.0}\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "faiss_search(index, query_emb)" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "### W/O Rescoring\n", 412 | "\n", 413 | "We loose around 53% of the performance. But its pretty fast ~30-40x faster." 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 15, 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "name": "stdout", 423 | "output_type": "stream", 424 | "text": [ 425 | "Search took 0.0146 sec\n", 426 | "{'NDCG@10': 0.72723, 'NDCG@100': 0.50933}\n", 427 | "{'Recall@10': 0.02007, 'Recall@100': 0.12191}\n", 428 | "{'Accuracy@3': 1.0, 'Accuracy@5': 1.0, 'Accuracy@10': 1.0}\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "faiss_search(bindex, query_bemb, score_multiply=-1, oversample=10)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "markdown", 438 | "metadata": {}, 439 | "source": [ 440 | "### With Rescoring\n", 441 | "\n", 442 | "Still extremely fast, with the difference that we retain 99% of the performance. We verified similar behavior for SciFact and ArguAna. Accuracy was also boosted." 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 16, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "Search took 0.0143 sec\n" 455 | ] 456 | }, 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "{'NDCG@10': 0.75496, 'NDCG@100': 0.53638}\n", 462 | "{'Recall@10': 0.02128, 'Recall@100': 0.13022}\n", 463 | "{'Accuracy@3': 1.0, 'Accuracy@5': 1.0, 'Accuracy@10': 1.0}\n" 464 | ] 465 | } 466 | ], 467 | "source": [ 468 | "faiss_search(bindex, query_bemb, float_embed=query_emb, oversample=10)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "## Conclusion\n", 476 | "\n", 477 | "Binary embedding enables extremely fast retrieval and low storage usage, at the expense of a slight performance loss, which can be mitigated by using a reranker. This has cool applications for on-device usage, large-scale developments, etc. We should also explore its potential for other tasks, such as clustering and deduplication at scale." 478 | ] 479 | } 480 | ], 481 | "metadata": { 482 | "kernelspec": { 483 | "display_name": "base", 484 | "language": "python", 485 | "name": "python3" 486 | }, 487 | "language_info": { 488 | "codemirror_mode": { 489 | "name": "ipython", 490 | "version": 3 491 | }, 492 | "file_extension": ".py", 493 | "mimetype": "text/x-python", 494 | "name": "python", 495 | "nbconvert_exporter": "python", 496 | "pygments_lexer": "ipython3", 497 | "version": "3.9.2" 498 | } 499 | }, 500 | "nbformat": 4, 501 | "nbformat_minor": 2 502 | } 503 | --------------------------------------------------------------------------------