├── LICENSE ├── MRL.py ├── README.md ├── __init__.py ├── adanns ├── README.md ├── ablations │ ├── centroid_recall.ipynb │ ├── class_matches_per_cluster.ipynb │ ├── cluster-distribution.ipynb │ ├── plots.ipynb │ └── relative_contrast.ipynb ├── adanns-ivf-optimized.ipynb ├── adanns-ivf-unoptimized.ipynb ├── compute_metrics.ipynb ├── diskann │ ├── README.md │ └── adanns-diskann.ipynb ├── dpr-nq │ ├── README.md │ └── adanns-nq.ipynb ├── generate_nn │ ├── hnsw_exactl2.ipynb │ ├── ivf-experiments.ipynb │ └── ivfpq_opq_kmeans.ipynb └── utils.py ├── generate_embeddings ├── pytorch_inference.py └── run-inference.sh ├── images ├── accuracy-compute.png ├── adanns-opq.png ├── adanns-teaser.png ├── diskann-table.png ├── diskann-top1.png ├── encoders.png ├── flowchart.png ├── opq-1k.png └── opq-nq.png ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 RAIVN Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MRL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Type, Any, Callable, Union, List, Optional 4 | 5 | ''' 6 | Loss function for Matryoshka Representation Learning 7 | ''' 8 | 9 | class Matryoshka_CE_Loss(nn.Module): 10 | def __init__(self, relative_importance=None, **kwargs): 11 | super(Matryoshka_CE_Loss, self).__init__() 12 | self.criterion = nn.CrossEntropyLoss(**kwargs) 13 | self.relative_importance= relative_importance 14 | 15 | def forward(self, output, target): 16 | loss=0 17 | N= len(output) 18 | for i in range(N): 19 | rel = 1. 
if self.relative_importance is None else self.relative_importance[i] 20 | loss+= rel*self.criterion(output[i], target) 21 | return loss 22 | 23 | class MRL_Linear_Layer(nn.Module): 24 | def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs): 25 | super(MRL_Linear_Layer, self).__init__() 26 | self.nesting_list=nesting_list 27 | self.num_classes=num_classes # Number of classes for classification 28 | self.efficient = efficient 29 | if self.efficient: 30 | setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs)) 31 | else: 32 | for i, num_feat in enumerate(self.nesting_list): 33 | setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs)) 34 | 35 | def reset_parameters(self): 36 | if self.efficient: 37 | self.nesting_classifier_0.reset_parameters() 38 | else: 39 | for i in range(len(self.nesting_list)): 40 | getattr(self, f"nesting_classifier_{i}").reset_parameters() 41 | 42 | 43 | def forward(self, x): 44 | nesting_logits = () 45 | for i, num_feat in enumerate(self.nesting_list): 46 | if self.efficient: 47 | if self.nesting_classifier_0.bias is None: 48 | nesting_logits+= (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), ) 49 | else: 50 | nesting_logits+= (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, ) 51 | else: 52 | nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),) 53 | 54 | return nesting_logits 55 | 56 | 57 | class FixedFeatureLayer(nn.Linear): 58 | ''' 59 | For our fixed feature baseline, we just replace the classification layer with the following. 60 | It effectively just look at the first "in_features" for the classification. 61 | ''' 62 | 63 | def __init__(self, in_features, out_features, **kwargs): 64 | super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs) 65 | 66 | def forward(self, x): 67 | if not (self.bias is None): 68 | out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias 69 | else: 70 | out = torch.matmul(x[:, :self.in_features], self.weight.t()) 71 | return out 72 | 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AdANNS: A Framework for Adaptive Semantic Search](https://arxiv.org/abs/2305.19435) 2 | _Aniket Rege, Aditya Kusupati, Sharan Ranjit S, Alan Fan, Qingqing Cao, Sham M. Kakade, Prateek Jain, Ali Farhadi_ 3 | 4 | Learned representations are used in multiple downstream tasks like web-scale search & classification. However, they are flat & rigid—information is diffused across dimensions and cannot be adaptively deployed without large post-hoc overhead. We propose the use of adaptive representations to improve approximate nearest neighbour search (ANNS) and introduce a new paradigm, AdANNS, to achieve it at scale leveraging matryoshka representations (MRs). We compare AdANNS to ANNS structures built on independently trained rigid representations (RRs). 5 | 6 |
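A quick way to see what "adaptive" means here: a matryoshka representation packs coarse-to-fine information into nested prefixes of a single vector, so one 2048-d embedding can be truncated to any leading d dimensions and reused after renormalization, whereas a rigid representation needs a separately trained encoder for every d. A minimal illustration (a random vector stands in for a trained MR embedding; not part of the released code):

```python
import numpy as np

x = np.random.randn(2048).astype(np.float32)  # stand-in for one 2048-d MR embedding

# Any prefix of an MR is itself a usable representation after L2 renormalization.
for d in [16, 64, 256, 2048]:
    x_d = x[:d] / np.linalg.norm(x[:d])
    print(d, x_d.shape)
```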

7 | 8 |

9 | 
 10 | This repository contains code for AdANNS construction and inference built on top of Matryoshka Representations (MRs). The training pipeline to generate MRs and RRs can be found [here](https://github.com/RAIVNLab/MRL). The repository is organized as follows: 11 | 12 | 1. Set up 13 | 2. Inference to generate MRs and RRs 14 | 3. AdANNS Experiments 15 | 16 | 17 | ## Set Up 18 | Install the requirements file in this directory with pip. Note that a Python 3 distribution is required: 19 | ``` 20 | pip3 install -r requirements.txt 21 | ``` 22 | 23 | ## Inference on Trained Models 24 | We primarily utilize ResNet-50 MRL and Rigid encoders ("Fixed-Feature" in original MRL terminology) for the bulk of our experimentation. We also utilize trained MRL ResNet18/34/101 and ConvNeXT encoders for ablation studies. Inference on trained models to generate MR and RR embeddings used for downstream ANNS is provided in `generate_embeddings/pytorch_inference.py`, and is explained in more detail in the [original MRL repository](https://github.com/RAIVNLab/MRL). 25 | 26 | 27 | ## AdANNS 28 | `cd adanns` 29 | 30 | We provide code showcasing AdANNS in action on a simple yet powerful search data structure – IVF (AdANNS-IVF) – and on industry-default quantization – OPQ (AdANNS-OPQ) – followed by its effectiveness on modern-day ANNS composite indices like IVFOPQ (AdANNS-IVFOPQ) and DiskANN (AdANNS-DiskANN). 31 | 32 | A more detailed walkthrough of AdANNS can be found in [`adanns/`](adanns/). 33 | 34 |
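The decoupling at the heart of AdANNS-IVF can be sketched in a few lines: cluster and shortlist with a cheap low-dimensional prefix ($d_c$) of the matryoshka embedding, then rescore the shortlisted candidates with a higher-capacity prefix ($d_s$). The toy sketch below uses random arrays as stand-ins for MR embeddings and plain Faiss k-means instead of the full IVF/OPQ machinery in [`adanns/`](adanns/); all dimensions and cluster counts are illustrative.

```python
import numpy as np
import faiss

# Toy stand-ins for matryoshka embeddings; rows would come from an MRL encoder in practice.
D, d_c, d_s, n_clusters, nprobe = 2048, 64, 256, 128, 4
xb = np.random.randn(10_000, D).astype(np.float32)  # database
xq = np.random.randn(100, D).astype(np.float32)     # queries

def prefix(x, d):
    """First d dims of an MR embedding, L2-normalized (valid because MRs are nested)."""
    x = np.ascontiguousarray(x[:, :d])
    faiss.normalize_L2(x)
    return x

# 1) Construction: cluster the database with the cheap d_c prefix.
kmeans = faiss.Kmeans(d_c, n_clusters, niter=10, seed=42)
kmeans.train(prefix(xb, d_c))
_, assign = kmeans.index.search(prefix(xb, d_c), 1)      # centroid id per database vector

# 2) Search: shortlist nprobe clusters with d_c, rescore their members with d_s.
_, probe = kmeans.index.search(prefix(xq, d_c), nprobe)  # (n_queries, nprobe)
xb_s, xq_s = prefix(xb, d_s), prefix(xq, d_s)
for qi in range(len(xq)):
    cand = np.where(np.isin(assign[:, 0], probe[qi]))[0]  # database ids in the probed clusters
    top1 = cand[np.argmin(np.linalg.norm(xb_s[cand] - xq_s[qi], axis=1))]
```

In the actual notebooks the same decoupling is expressed through Faiss IVF indices with OPQ/PQ codes, so the rescoring step uses PQ lookup tables rather than exact distances.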

35 | 36 |

37 | 
 38 | ## Citation 39 | If you find this project useful in your research, please consider citing: 40 | ``` 41 | @article{rege2023adanns, 42 | title={AdANNS: A Framework for Adaptive Semantic Search}, 43 | author={Aniket Rege and Aditya Kusupati and Sharan Ranjit S and Alan Fan and Qingqing Cao and Sham Kakade and Prateek Jain and Ali Farhadi}, 44 | year={2023}, 45 | eprint={2305.19435}, 46 | archivePrefix={arXiv}, 47 | primaryClass={cs.LG} 48 | } 49 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/__init__.py -------------------------------------------------------------------------------- /adanns/README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | The directory structure is organized as follows: 3 | 1. AdANNS-IVF (see Figure 2) in an [optimized](#optimized-adanns-ivf) and [unoptimized](#unoptimized-adanns-ivf) fashion 4 | 2. [Non-Adaptive Pipeline](#non-adaptive-search) on ImageNet-1K 5 | 3. AdANNS with DiskANN, a graph-based Memory-SSD hybrid ANNS index in [diskann](diskann/README.md) 6 | 4. AdANNS with Dense Passage Retriever (DPR) on [Natural Questions](https://ai.google.com/research/NaturalQuestions) (NQ) in [dpr-nq](dpr-nq/README.md) 7 | 8 |
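The sub-directories all lean on the same primitive: a single matryoshka database can be quantized and searched at several accuracy-compute operating points without retraining the encoder. As a toy illustration in the spirit of AdANNS-OPQ (random arrays stand in for MR embeddings; the `d_s` and `M` values are illustrative, not the paper's settings):

```python
import numpy as np
import faiss

xb = np.random.randn(20_000, 2048).astype(np.float32)  # stand-in MR database
xq = np.random.randn(50, 2048).astype(np.float32)      # stand-in MR queries

# One matryoshka database, several search dims d_s, each with its own OPQ/PQ codebook
# (M sub-quantizers => M bytes per compressed vector at the default 8 bits per code).
for d_s, M in [(64, 16), (256, 32)]:
    db = np.ascontiguousarray(xb[:, :d_s])
    q = np.ascontiguousarray(xq[:, :d_s])
    faiss.normalize_L2(db)
    faiss.normalize_L2(q)

    index = faiss.index_factory(d_s, f"OPQ{M},PQ{M}")
    index.train(db)   # learns the OPQ rotation and PQ codebooks on the d_s prefix
    index.add(db)
    _, ids = index.search(q, 10)
    print(d_s, M, ids.shape)
```

See the sections and sub-directory READMEs below for the full pipelines.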

9 | 10 |

11 | 
 12 | ## Optimized AdANNS-IVF 13 | [`adanns-ivf-optimized.ipynb`](adanns-ivf-optimized.ipynb) implements _AdANNS_-IVF in an optimized fashion (native Faiss) and is drawn from a [Faiss case study](https://gist.github.com/mdouze/8c5ab227c0f7d9d7c15cf92a391dcbe5#file-demo_independent_ivf_dimension-ipynb). During index construction, $d_c$ is used for coarse quantization (centroid assignment), and all database vectors are pre-assigned to their appropriate clusters with this coarse quantizer. We also typically learn a fine quantizer (OPQ) with $d_s$ for fast distance computation. During the search phase, for a given query, nprobe clusters are shortlisted with $d_c$, and a linear scan over all shortlisted clusters is done via a PQ lookup table constructed from $d_s$. 14 | 15 | Note that the optimized version currently supports only $d_s \le d_c$ during index construction, since the centroids learned with $d_c$ are sliced to $d_s$ for the fine quantizer. 16 | 17 | ## Unoptimized AdANNS-IVF 18 | [`adanns-ivf-unoptimized.ipynb`](adanns-ivf-unoptimized.ipynb) implements a fully decoupled ANNS pipeline; $d_c < d_s$ is now possible. This is carried out in a highly unoptimized fashion: during construction, we use pre-built *k-means* indices over the database with $d_c$ and build an inverted file lookup that maps each database item to its closest cluster centroid. During search, we iterate over each centroid and perform a batched search, with $d_s$, over every queryset vector assigned to that centroid. This approach scales linearly in cost with the number of centroids, $\mathcal{O}(N_C)$, and results in higher wall-clock search times than the optimized version above. 19 | 20 | ## Non-Adaptive Search 21 |

22 | 23 |

24 | 25 | For all other experiments with Faiss (IVF-XX, MG-IVF-XX, OPQ, IVFOPQ, HNSWOPQ) on ImageNet-1K, we have a 3-stage pipeline: 26 | 1. Encoder inference to generate embeddings for database and queries: see [run-inference.sh](../generate_embeddings/run-inference.sh) 27 | 2. Index construction (database) and search (queries) on generated embeddings to generate k-NN arrays: see [generate_nn](generate_nn/) 28 | 3. Metric computation or Ablations on top of k-NN arrays: see [compute_metrics.ipynb](compute_metrics.ipynb) and [ablations](ablations/) -------------------------------------------------------------------------------- /adanns/ablations/centroid_recall.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "21d1fdb4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import faiss\n", 12 | "import torch\n", 13 | "import sys\n", 14 | "sys.path.append('../')\n", 15 | "from utils.py import load_embeddings" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "id": "cfc615b9", 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "Loaded queries: (50000, 2048)\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "root = '../../../inference_array/resnet50/'\n", 34 | "model = 'mrl' # mrl, ff\n", 35 | "dataset = '1K' # 1K, 4K, V2\n", 36 | "index_type = 'kmeans'\n", 37 | "d = 2048 # cluster construction dim\n", 38 | "\n", 39 | "_, queryset, _, _, _, _ = load_embeddings(model, dataset, d)\n", 40 | "faiss.normalize_L2(queryset)\n", 41 | "print(\"Loaded queries:\", queryset.shape)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "39d931c4", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Clusters: 1024\n", 55 | "\n", 56 | "Number of probes: 1\n", 57 | "['top1', 'top2', 'top4', 'top5', 'top10']\n", 58 | "[0.7355 0.7355 0.7355 0.7355 0.7355]\n", 59 | "[0.84186 0.84186 0.84186 0.84186 0.84186]\n", 60 | "[0.89878 0.89878 0.89878 0.89878 0.89878]\n", 61 | "[0.93288 0.93288 0.93288 0.93288 0.93288]\n", 62 | "[0.95518 0.95518 0.95518 0.95518 0.95518]\n", 63 | "[0.96946 0.96946 0.96946 0.96946 0.96946]\n", 64 | "[0.9822 0.9822 0.9822 0.9822 0.9822]\n", 65 | "[0.99078 0.99078 0.99078 0.99078 0.99078]\n", 66 | "[1. 1. 1. 1. 1.]\n", 67 | "\n", 68 | "Number of probes: 2\n", 69 | "['top1', 'top2', 'top4', 'top5', 'top10']\n", 70 | "[0.7355 0.83386 0.83386 0.83386 0.83386]\n", 71 | "[0.84186 0.92572 0.92572 0.92572 0.92572]\n", 72 | "[0.89878 0.96662 0.96662 0.96662 0.96662]\n", 73 | "[0.93288 0.98402 0.98402 0.98402 0.98402]\n", 74 | "[0.95518 0.99246 0.99246 0.99246 0.99246]\n", 75 | "[0.96946 0.99674 0.99674 0.99674 0.99674]\n", 76 | "[0.9822 0.99894 0.99894 0.99894 0.99894]\n", 77 | "[0.99078 0.99982 0.99982 0.99982 0.99982]\n", 78 | "[1. 1. 1. 1. 1.]\n", 79 | "\n", 80 | "Number of probes: 5\n", 81 | "['top1', 'top2', 'top4', 'top5', 'top10']\n", 82 | "[0.7355 0.83386 0.89734 0.91348 0.91348]\n", 83 | "[0.84186 0.92572 0.96844 0.97612 0.97612]\n", 84 | "[0.89878 0.96662 0.99118 0.99438 0.99438]\n", 85 | "[0.93288 0.98402 0.99746 0.9988 0.9988 ]\n", 86 | "[0.95518 0.99246 0.99934 0.99978 0.99978]\n", 87 | "[0.96946 0.99674 0.99984 0.9999 0.9999 ]\n", 88 | "[0.9822 0.99894 0.99994 0.99998 0.99998]\n", 89 | "[0.99078 0.99982 1. 1. 1. ]\n", 90 | "[1. 1. 1. 
1. 1.]\n", 91 | "\n", 92 | "Number of probes: 10\n", 93 | "['top1', 'top2', 'top4', 'top5', 'top10']\n", 94 | "[0.7355 0.83386 0.89734 0.91348 0.95232]\n", 95 | "[0.84186 0.92572 0.96844 0.97612 0.99114]\n", 96 | "[0.89878 0.96662 0.99118 0.99438 0.99898]\n", 97 | "[0.93288 0.98402 0.99746 0.9988 0.99986]\n", 98 | "[0.95518 0.99246 0.99934 0.99978 1. ]\n", 99 | "[0.96946 0.99674 0.99984 0.9999 1. ]\n", 100 | "[0.9822 0.99894 0.99994 0.99998 1. ]\n", 101 | "[0.99078 0.99982 1. 1. 1. ]\n", 102 | "[1. 1. 1. 1. 1.]\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "search_dim = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n", 108 | "nprobes = [1, 2, 5, 10]\n", 109 | "ncentroids = [1024]\n", 110 | "\n", 111 | "for centroid in ncentroids:\n", 112 | " print(\"Clusters: \", centroid)\n", 113 | " \n", 114 | " # Load kmeans index\n", 115 | " size = str(centroid)+'ncentroid_'+str(d)+'d'\n", 116 | " index_file = root+'index_files/'+model+dataset+'_'+index_type+'_'+size+\"_nbits8_nlist2048\"\n", 117 | " cpu_index = faiss.read_index(index_file+'.index')\n", 118 | " if torch.cuda.device_count() > 0:\n", 119 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n", 120 | " \n", 121 | " # Load and normalize centroids\n", 122 | " centroids_path = root+'kmeans/'+model+'ncentroids'+str(centroid)+\"_\"+str(d)+'d'\"_\"+dataset+'.npy'\n", 123 | " centroids = np.load(centroids_path)\n", 124 | " faiss.normalize_L2(centroids)\n", 125 | " gt = np.argsort(-queryset @ centroids.T, axis=1)\n", 126 | " \n", 127 | " topK = [1, 2, 4, 5, 10]\n", 128 | " \n", 129 | " for nprobe in nprobes:\n", 130 | " print(\"\\nNumber of probes:\", nprobe)\n", 131 | " print([f'top{k}' for k in topK])\n", 132 | " for dim in search_dim:\n", 133 | " q = np.ascontiguousarray(queryset[:, :dim])\n", 134 | " nqueries = q.shape[0]\n", 135 | " faiss.normalize_L2(q)\n", 136 | " c = np.ascontiguousarray(centroids[:, :dim])\n", 137 | " faiss.normalize_L2(c)\n", 138 | " low_d_clusters = np.argsort(-q @ c.T, axis=1)\n", 139 | " \n", 140 | " count = [0, 0, 0, 0, 0]\n", 141 | " \n", 142 | " # Iterate over all queries\n", 143 | " for i in range(nqueries):\n", 144 | " label = gt[i][0]\n", 145 | " target = low_d_clusters[i][:nprobe]\n", 146 | " for j in range(len(topK)):\n", 147 | " count[j] += label in target[:topK[j]] # increments count[j] if correct\n", 148 | "\n", 149 | " print(np.array(count) / nqueries)" 150 | ] 151 | } 152 | ], 153 | "metadata": { 154 | "kernelspec": { 155 | "display_name": "Python 3 (ipykernel)", 156 | "language": "python", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.11.3" 170 | }, 171 | "vscode": { 172 | "interpreter": { 173 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc" 174 | } 175 | } 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 5 179 | } 180 | -------------------------------------------------------------------------------- /adanns/ablations/relative_contrast.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1d746c2b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import torch\n", 12 | "import faiss\n", 13 | "import sys\n", 14 | 
"sys.path.append('../')\n", 15 | "from utils import load_embeddings" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 9, 21 | "id": "3080d42d", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "root = '../../../inference_array/resnet50/'\n", 26 | "model = \"ff\" # mrl, ff\n", 27 | "dataset = '1K' # 1K, 4K, V2" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "id": "de07a2d2", 33 | "metadata": {}, 34 | "source": [ 35 | "### In the cell below, we use the relative contrast equation as defined in Equation (1) of [On the DIfficulty of Nearest Neighbor Search](https://www.ee.columbia.edu/ln/dvmm/pubs/files/ICML_RelativeContrast.pdf).
\n", 36 | "### $C_r = \\frac{D_{mean}}{D_{min}}$
\n", 37 | "

where $C_r$ is the relative contrast of a dataset $X$, $D_{mean}$ is the expected distance of a random database sample from a query $q$, and $D_{min}$ is the expected distance to the nearest database sample from a query $q$.

" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "d4f449b0", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "if torch.cuda.device_count() > 0:\n", 48 | " device = torch.device('cuda')\n", 49 | "else:\n", 50 | " raise Exception(\"Please use a GPU! This will take very very long otherwise.\")\n", 51 | "\n", 52 | "# dlist = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n", 53 | "dlist = [2048]\n", 54 | "batch_size = 4196\n", 55 | "\n", 56 | "for d in dlist:\n", 57 | " database, queryset, db_labels, query_labels, xb, xq = load_embeddings(model, dataset, d)\n", 58 | "\n", 59 | " qy = torch.as_tensor(xq).to(device)\n", 60 | " db = torch.as_tensor(xb)\n", 61 | " \n", 62 | " num_batches = int(database.shape[0] / batch_size)\n", 63 | " final_d_min = torch.ones((qy.shape[0])).to(device) * 1e10\n", 64 | " final_d_mean = []\n", 65 | "\n", 66 | " for i in range(num_batches):\n", 67 | " db_batch = db[(i)*batch_size:(i+1)*batch_size, :].to(device)\n", 68 | " distances = torch.cdist(qy, db_batch)\n", 69 | " sorted_dist = torch.sort(distances)\n", 70 | " current_d_min = sorted_dist.values[:, 0]\n", 71 | " \n", 72 | " final_d_min = torch.min(current_d_min, final_d_min)\n", 73 | " final_d_mean.append(torch.mean(distances, axis=1).cpu().numpy())\n", 74 | " \n", 75 | " C_r = np.mean(final_d_mean) / torch.mean(final_d_min).cpu().numpy()\n", 76 | " print(f'C_r(d={d})={C_r}')" 77 | ] 78 | } 79 | ], 80 | "metadata": { 81 | "kernelspec": { 82 | "display_name": "Python 3 (ipykernel)", 83 | "language": "python", 84 | "name": "python3" 85 | }, 86 | "language_info": { 87 | "codemirror_mode": { 88 | "name": "ipython", 89 | "version": 3 90 | }, 91 | "file_extension": ".py", 92 | "mimetype": "text/x-python", 93 | "name": "python", 94 | "nbconvert_exporter": "python", 95 | "pygments_lexer": "ipython3", 96 | "version": "3.11.3" 97 | }, 98 | "vscode": { 99 | "interpreter": { 100 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc" 101 | } 102 | } 103 | }, 104 | "nbformat": 4, 105 | "nbformat_minor": 5 106 | } 107 | -------------------------------------------------------------------------------- /adanns/adanns-ivf-optimized.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c7ade3d8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import faiss\n", 11 | "import numpy as np\n", 12 | "import time\n", 13 | "import csv\n", 14 | "from os import path, makedirs\n", 15 | "\n", 16 | "import multiprocessing\n", 17 | "from multiprocessing.dummy import Pool as ThreadPool\n", 18 | "from functools import partial\n", 19 | "\n", 20 | "from faiss.contrib.ivf_tools import add_preassigned, search_preassigned" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "43f2af97", 26 | "metadata": {}, 27 | "source": [ 28 | "## AdANNS-IVF" 29 | ] 30 | }, 31 | { 32 | "attachments": {}, 33 | "cell_type": "markdown", 34 | "id": "9744b124", 35 | "metadata": {}, 36 | "source": [ 37 | "### Notation\n", 38 | "1. $D$ = Embedding Dimensionality for IVF construction and search\n", 39 | "2. $M$ = number of OPQ subquantizers. Faiss requires $D$ % $M$ == $0$. \n", 40 | "3. For AdANNS, D is decomposed to $D_{construct}$ and $D_{search}$\n", 41 | "\n", 42 | "### Miscellaneous Notes\n", 43 | "1. Rigid representations (RR) are embedded through independently trained \"fixed feature\" (FF) encoders. 
RR and FF are thus used interchangeably in documentation and code and are essentially equivalent.\n", 44 | "2. In this notebook, the AdANNS-IVF coarse quantizer uses OPQ by default for cheap distance computation, but is optional.\n", 45 | "3. AdANNS-IVF is adapted from this [Faiss Case Study](https://gist.github.com/mdouze/8c5ab227c0f7d9d7c15cf92a391dcbe5#file-demo_independent_ivf_dimension-ipynb)\n", 46 | "4. Optimized AdANNS-IVF (with Faiss) has a restriction that $D_{construct}\\geq D_{search}$. This is because we slice centroids learnt from $D_{construct}$ to learn PQ codebooks with $D_{search}$ (this is possible because they are MRs)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 9, 52 | "id": "fc82acd3", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "D = 2048 # Max d for ResNet50\n", 57 | "n_cell = 1024 # number of IVF cells, default=1024 for ImageNet-1K\n", 58 | "\n", 59 | "embeddings_root = 'path/to/embeddings' # load embeddings\n", 60 | "adanns_root = 'path/to/adanns/indices/' # store adanns indices\n", 61 | "rigid_root = 'path/to/rigid/indices/' # store rigid indices\n", 62 | "config = 'rr' # mrl, rr\n", 63 | "\n", 64 | "if config == 'mrl':\n", 65 | " config_load = 'mrl1_e0_ff2048'\n", 66 | "elif config == 'rr':\n", 67 | " config_load = 'mrl0_e0_ff2048'\n", 68 | "else:\n", 69 | " raise Exception(f\"Unsupported config {config}!\")\n", 70 | "\n", 71 | "use_mrl = config.upper() # MRL, RR\n", 72 | "\n", 73 | "db_npy = '1K_train_' + config_load + '-X.npy'\n", 74 | "query_npy = '1K_val_' + config_load + '-X.npy'" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "045780fe", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "xb = np.load(embeddings_root + db_npy)\n", 85 | "assert np.count_nonzero(np.isnan(xb)) == 0\n", 86 | "xq = np.load(embeddings_root + query_npy)\n", 87 | "\n", 88 | "query_labels = np.load(embeddings_root + \"1K_val_\" + config_load + \"-y.npy\")\n", 89 | "db_labels = np.load(embeddings_root + \"1K_train_\" + config_load + \"-y.npy\")\n", 90 | "\n", 91 | "print(\"loaded DB %s : %s\" % (db_npy, xb.shape))\n", 92 | "print(\"loaded queries %s : %s\" % (query_npy, xq.shape))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "a1cdc1ef", 98 | "metadata": {}, 99 | "source": [ 100 | "## RR2048 OPQ Dim Reduction Baseline" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 11, 106 | "id": "517fa2d4", 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "(100000, 2048)\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "db_subsampled = xb[np.random.choice(xb.shape[0], 100000, replace=False)]\n", 119 | "print(db_subsampled.shape)\n", 120 | "dim_reduce = 128" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "45fe9462", 126 | "metadata": {}, 127 | "source": [ 128 | "### SVD dim reduction + OPQ" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 12, 134 | "id": "4829ab90", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "SVD projected Database: (1281167, 128)\n", 142 | "SVD projected Queries: (50000, 128)\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "def get_SVD_mat(db_subsampled, low_dim):\n", 148 | " mat = faiss.PCAMatrix(db_subsampled.shape[1], low_dim)\n", 149 | " mat.train(db_subsampled)\n", 150 | " assert mat.is_trained\n", 151 | " return 
mat\n", 152 | "\n", 153 | "svd_mat = get_SVD_mat(db_subsampled, dim_reduce)\n", 154 | "database_svd_lowdim = svd_mat.apply(xb)\n", 155 | "print(\"SVD projected Database: \", database_svd_lowdim.shape)\n", 156 | "query_svd_lowdim = svd_mat.apply(xq)\n", 157 | "print(\"SVD projected Queries: \", query_svd_lowdim.shape)\n", 158 | "\n", 159 | "faiss.normalize_L2(database_svd_lowdim)\n", 160 | "faiss.normalize_L2(query_svd_lowdim)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 13, 166 | "id": "bd787adc", 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "Building FF_D2048_SVD128_OPQ128.faiss\n", 174 | "Train+add time: 13411.728565454483\n", 175 | "[2048, 128, 128, 0.69224]\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "for M in [128]:\n", 181 | " if not path.exists(f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss'):\n", 182 | " print(f'Building {use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss')\n", 183 | " cpu_index = faiss.index_factory(dim_reduce, f'OPQ{M},PQ{M}')\n", 184 | " start = time.time()\n", 185 | " cpu_index.train(database_svd_lowdim)\n", 186 | " cpu_index.add(database_svd_lowdim)\n", 187 | " print(\"Train+add time: \", time.time() - start)\n", 188 | " faiss.write_index(cpu_index, f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss')\n", 189 | " \n", 190 | " top1 = [xb.shape[1], dim_reduce, M]\n", 191 | " _, Ind = cpu_index.search(query_svd_lowdim, 100)\n", 192 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n", 193 | " print(top1)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "id": "2b042b5b", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "[2048, 128, 16, 0.69088]\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "for M in [128]:\n", 212 | " top1 = [xb.shape[1], dim_reduce, M]\n", 213 | " svd_opq_index = faiss.read_index(f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}.faiss')\n", 214 | " _, Ind = svd_opq_index.search(query_svd_lowdim, 100)\n", 215 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n", 216 | " print(top1)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "e9e4b00a", 222 | "metadata": {}, 223 | "source": [ 224 | "## Rigid-IVF + OPQ" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 19, 230 | "id": "944b8f15", 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Skipping build, index exists: FF_D2048+IVF1024,OPQ16.faiss\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "# Construct Rigid Index\n", 243 | "\n", 244 | "for M in [16]:\n", 245 | " for D in [2048]:\n", 246 | " database = np.ascontiguousarray(xb[:,:D], dtype=np.float32)\n", 247 | " faiss.normalize_L2(database)\n", 248 | " \n", 249 | " if M > D:\n", 250 | " continue\n", 251 | "\n", 252 | " if not path.exists(f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss'):\n", 253 | " print(f'Building {use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n", 254 | " start = time.time()\n", 255 | "\n", 256 | " index = faiss.index_factory(int(D), f'IVF{n_cell},PQ{M}')\n", 257 | "\n", 258 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index')\n", 259 | " opq = 
opq_index_pretrained.chain.at(0)\n", 260 | "\n", 261 | " db = opq.apply(database)\n", 262 | "\n", 263 | " index.train(db)\n", 264 | " index.add(db)\n", 265 | "\n", 266 | " print(\"Time: \", time.time() - start)\n", 267 | " faiss.write_index(index, f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n", 268 | " print(f'Created IVF{n_cell},OPQ{M} index with D={D}')\n", 269 | "\n", 270 | " else:\n", 271 | " print(f'Skipping build, index exists: {use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 23, 277 | "id": "baa631e4", 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "[n_cell, D, M, top1]\n", 285 | "[1024, 2048, 8, 0.64966]\n", 286 | "[1024, 2048, 16, 0.6663]\n", 287 | "[1024, 2048, 32, 0.67724]\n", 288 | "[1024, 2048, 64, 0.68588]\n" 289 | ] 290 | } 291 | ], 292 | "source": [ 293 | "# Search Rigid Index\n", 294 | "\n", 295 | "print('[n_cell, D, M, top1]')\n", 296 | "for D in [2048]:\n", 297 | " queryset = np.ascontiguousarray(xq[:,:D], dtype=np.float32)\n", 298 | " faiss.normalize_L2(queryset)\n", 299 | " for M in [8, 16, 32, 64]:\n", 300 | " if M > D:\n", 301 | " continue\n", 302 | " \n", 303 | " top1 = [n_cell, D, M]\n", 304 | " times = [n_cell, D, M]\n", 305 | "\n", 306 | " index = faiss.read_index(f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n", 307 | " \n", 308 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index')\n", 309 | " opq = opq_index_pretrained.chain.at(0)\n", 310 | "\n", 311 | " q = opq.apply(queryset)\n", 312 | "\n", 313 | " for nprobe in [1]:\n", 314 | " start = time.time()\n", 315 | " faiss.extract_index_ivf(index).nprobe = nprobe \n", 316 | " Dist, Ind = index.search(q, 100)\n", 317 | "\n", 318 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n", 319 | " times.append(time.time() - start)\n", 320 | "\n", 321 | " print(top1)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "2b0c62f6", 327 | "metadata": {}, 328 | "source": [ 329 | "## AdANNS-IVF + OPQ" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 45, 335 | "id": "b9b187df", 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "def create_adanns_indices(D_search, D_construct, M, n_cell):\n", 340 | " index_search = faiss.index_factory(D_search, f'OPQ{M},IVF{n_cell},PQ{M}')\n", 341 | " index_construct = faiss.index_factory(D_construct, f'IVF{n_cell},Flat')\n", 342 | " \n", 343 | " database = np.ascontiguousarray(xb[:,:D_construct], dtype=np.float32)\n", 344 | " faiss.normalize_L2(database)\n", 345 | "\n", 346 | " # train the full-dimensional \"construct\" coarse quantizer. IVF centroid assignments are learnt with D_construct\n", 347 | " if not path.exists(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss'):\n", 348 | " index_construct.train(database)\n", 349 | " quantizer_construct = index_construct.quantizer\n", 350 | " faiss.write_index(quantizer_construct, adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n", 351 | " else:\n", 352 | " print(\"Index exists: \", adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n", 353 | "\n", 354 | " # prepare the \"search\" coarse quantizer. 
OPQ codebooks are learnt on D_search\n", 355 | " if not path.exists(adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss'):\n", 356 | " quantizer_construct = faiss.read_index(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n", 357 | " database_search = np.ascontiguousarray(xb[:, :D_search], dtype=np.float32)\n", 358 | " centroids_search = np.ascontiguousarray(quantizer_construct.reconstruct_n(0, quantizer_construct.ntotal)[:, :D_search], dtype=np.float32)\n", 359 | " \n", 360 | " # Apply OPQ to search DB and centroids\n", 361 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D_search}_nbits8.index')\n", 362 | " print(f'Applying OPQ: 1K_opq_{M}m_d{D_search}')\n", 363 | " opq = opq_index_pretrained.chain.at(0)\n", 364 | " opq.apply(centroids_search)\n", 365 | " opq.apply(database_search)\n", 366 | " faiss.normalize_L2(database_search)\n", 367 | " \n", 368 | " index_ivf_search = faiss.downcast_index(faiss.extract_index_ivf(index_search))\n", 369 | " index_ivf_search.quantizer.add(centroids_search)\n", 370 | "\n", 371 | " index_ivf_search.train(database_search)\n", 372 | " index_search.is_trained = True\n", 373 | "\n", 374 | " # coarse quantization with the construct quantizer\n", 375 | " _, Ic = quantizer_construct.search(database, 1) # each database vector assigned to one of num_cell centroids\n", 376 | " # add operation \n", 377 | " add_preassigned(index_ivf_search, database_search, Ic.ravel())\n", 378 | "\n", 379 | " faiss.write_index(index_ivf_search, adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')\n", 380 | " else:\n", 381 | " print(\"Index exists: \", adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')\n", 382 | " \n", 383 | " print(f'Initialized construct quantizer D{D_construct}, search quantizer D{D_search}, M{M}, ncell{n_cell}')" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 46, 389 | "id": "0ad99c4a", 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "Skipping (M, d_small, d_big): (64, 2048, 64)\n", 397 | "Skipping (M, d_small, d_big): (64, 2048, 128)\n", 398 | "Skipping (M, d_small, d_big): (64, 2048, 256)\n", 399 | "Skipping (M, d_small, d_big): (64, 2048, 512)\n", 400 | "Skipping (M, d_small, d_big): (64, 2048, 1024)\n", 401 | "Index exists: case_study_decoupled/MRL_D2048+IVF1024,PQ64_big_quantizer.faiss\n", 402 | "Applying OPQ: 1K_opq_64m_d2048\n", 403 | "Initialized big quantizer D2048, small quantizer D2048, M64, ncell1024\n" 404 | ] 405 | } 406 | ], 407 | "source": [ 408 | "for D_construct in [64, 128, 256, 512, 1024, 2048]:\n", 409 | " for D_search in [2048]:\n", 410 | " for M in [64]:\n", 411 | " if M > D_search or D_search > D_construct:\n", 412 | " print(\"Skipping (M, d_search, d_construct): \", (M, D_search, D_construct))\n", 413 | " continue\n", 414 | " create_adanns_indices(D_search, D_construct, M, n_cell=1024)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "id": "5d2081b7", 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "# Preassigned Search using multiple cores\n", 425 | "\n", 426 | "USE_MULTITHREAD_SEARCH = True\n", 427 | "num_cores = multiprocessing.cpu_count()\n", 428 | "thread_batch_size = 1000\n", 429 | "\n", 430 | "# Helper function to split search on 
multiple cores\n", 431 | "def multisearch_preassigned(index, queryset, Ic, batch_iter):\n", 432 | " _, I = search_preassigned(index, \n", 433 | " queryset[thread_batch_size*batch_iter:thread_batch_size*(batch_iter+1)], \n", 434 | " 100, # Shortlist length\n", 435 | " Ic[thread_batch_size*batch_iter:thread_batch_size*(batch_iter+1)], \n", 436 | " None)\n", 437 | " return I" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 49, 443 | "id": "74680e75", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "def search_adanns_indices(D_search, D_construct, n_cell, nprobes=[1]):\n", 448 | " queryset = np.ascontiguousarray(xq[:,:D_construct], dtype=np.float32)\n", 449 | " faiss.normalize_L2(queryset)\n", 450 | " \n", 451 | " queryset_small = np.ascontiguousarray(xq[:, :D_search], dtype=np.float32)\n", 452 | " faiss.normalize_L2(queryset_small)\n", 453 | " \n", 454 | " for M in [64]:\n", 455 | " top1 = [n_cell, D_construct, D_search, M]\n", 456 | " times = [n_cell, D_construct, D_search, M]\n", 457 | " if M > D_search or D_search > D_construct:\n", 458 | " continue\n", 459 | " \n", 460 | " # print(f'MRL IVF{n_cell},PQ{M}: D{D_search} search with D{D_construct} coarse quantization')\n", 461 | " quantizer_big = faiss.read_index(adanns_root + f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_big_quantizer.faiss')\n", 462 | " index_ivf_small = faiss.read_index(adanns_root + f'MRL_Dsmall{D_search}_Dbig{D_construct}+IVF{n_cell},OPQ{M}_small_quantizer.faiss')\n", 463 | " \n", 464 | " # disable precomputed tables, because the Dc is out of sync with the \n", 465 | " # small coarse quantizer\n", 466 | " index_ivf_small.use_precomputed_table = -1\n", 467 | " index_ivf_small.precompute_table()\n", 468 | "\n", 469 | " for nprobe in nprobes:\n", 470 | " start = time.time()\n", 471 | "\n", 472 | " # coarse quantization \n", 473 | " _, Ic = quantizer_big.search(queryset, nprobe) # Ic: (50K, nprobe)\n", 474 | "\n", 475 | " # actual search \n", 476 | " index_ivf_small.nprobe = nprobe\n", 477 | " \n", 478 | " if USE_MULTITHREAD_SEARCH:\n", 479 | " pool = ThreadPool(num_cores)\n", 480 | " partial_func = partial(multisearch_preassigned, index=index_ivf_small, queryset=queryset_small, Ic=Ic)\n", 481 | " I = pool.map(partial_func, range(queryset_small.shape[0] // thread_batch_size)) # 50K queries split to (num_batches, thread_batch_size) batches\n", 482 | " pool.close()\n", 483 | " pool.join()\n", 484 | " \n", 485 | " else:\n", 486 | " _, I = search_preassigned(index_ivf_small, queryset_small, 100, Ic, None) # I: (50K, 100)\n", 487 | "\n", 488 | " top1.append((np.sum(db_labels[I[:, 0]] == query_labels)) / query_labels.shape[0])\n", 489 | " times.append(time.time()-start)\n", 490 | " \n", 491 | " if (len(top1) > 4): # ignore continued cases\n", 492 | " with open('adanns-faiss-top1-opq.csv', 'a', encoding='UTF8', newline='') as f:\n", 493 | " writer = csv.writer(f)\n", 494 | " writer.writerow(top1)\n", 495 | " with open('adanns-faiss-timing-opq.csv', 'a', encoding='UTF8', newline='') as f:\n", 496 | " writer = csv.writer(f)\n", 497 | " writer.writerow(times)\n", 498 | " print(top1)\n", 499 | " # print(times)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "markdown", 504 | "id": "3baea8e1", 505 | "metadata": {}, 506 | "source": [ 507 | "## Metric Computation" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 50, 513 | "id": "13a8b9b8", 514 | "metadata": { 515 | "scrolled": true 516 | }, 517 | "outputs": [ 518 | { 519 | "name": "stdout", 520 | "output_type": "stream", 
521 | "text": [ 522 | "['n_cell', 'D_big', 'D_small', 'M', '1probe', '4probe', '8probe']\n", 523 | "[1024, 64, 64, 64, 0.6942]\n", 524 | "[1024, 128, 64, 64, 0.69422]\n", 525 | "[1024, 128, 128, 64, 0.69584]\n", 526 | "[1024, 256, 64, 64, 0.69334]\n", 527 | "[1024, 256, 128, 64, 0.69604]\n", 528 | "[1024, 256, 256, 64, 0.69632]\n", 529 | "[1024, 512, 64, 64, 0.69418]\n", 530 | "[1024, 512, 128, 64, 0.69676]\n", 531 | "[1024, 512, 256, 64, 0.69568]\n", 532 | "[1024, 512, 512, 64, 0.6969]\n", 533 | "[1024, 1024, 64, 64, 0.69576]\n", 534 | "[1024, 1024, 128, 64, 0.69716]\n", 535 | "[1024, 1024, 256, 64, 0.69676]\n", 536 | "[1024, 1024, 512, 64, 0.69648]\n", 537 | "[1024, 1024, 1024, 64, 0.69412]\n", 538 | "[1024, 2048, 64, 64, 0.69444]\n", 539 | "[1024, 2048, 128, 64, 0.69608]\n", 540 | "[1024, 2048, 256, 64, 0.6973]\n", 541 | "[1024, 2048, 512, 64, 0.69628]\n", 542 | "[1024, 2048, 1024, 64, 0.69274]\n", 543 | "[1024, 2048, 2048, 64, 0.6899]\n" 544 | ] 545 | } 546 | ], 547 | "source": [ 548 | "header = [\"n_cell\", \"D_construct\", \"D_search\", \"M\", \"1probe\", \"4probe\", \"8probe\"]\n", 549 | "print(header)\n", 550 | "\n", 551 | "with open('adanns-faiss-top1-opq.csv', 'w', encoding='UTF8', newline='') as f:\n", 552 | " writer = csv.writer(f)\n", 553 | " writer.writerow(header)\n", 554 | " \n", 555 | "with open('adanns-faiss-timing-opq.csv', 'w', encoding='UTF8', newline='') as f:\n", 556 | " writer = csv.writer(f)\n", 557 | " writer.writerow(header)\n", 558 | " \n", 559 | "for D_construct in [64, 128, 256, 512, 1024, 2048]:\n", 560 | " for D_search in [64, 128, 256, 512, 1024, 2048]:\n", 561 | " search_adanns_indices(D_search, D_construct, n_cell=1024, nprobes=[1])" 562 | ] 563 | } 564 | ], 565 | "metadata": { 566 | "kernelspec": { 567 | "display_name": "Python 3", 568 | "language": "python", 569 | "name": "python3" 570 | }, 571 | "language_info": { 572 | "codemirror_mode": { 573 | "name": "ipython", 574 | "version": 3 575 | }, 576 | "file_extension": ".py", 577 | "mimetype": "text/x-python", 578 | "name": "python", 579 | "nbconvert_exporter": "python", 580 | "pygments_lexer": "ipython3", 581 | "version": "3.11.0 (main, Oct 26 2022, 19:06:18) [Clang 14.0.0 (clang-1400.0.29.202)]" 582 | }, 583 | "vscode": { 584 | "interpreter": { 585 | "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" 586 | } 587 | } 588 | }, 589 | "nbformat": 4, 590 | "nbformat_minor": 5 591 | } 592 | -------------------------------------------------------------------------------- /adanns/adanns-ivf-unoptimized.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "04ace433", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import faiss\n", 12 | "import time\n", 13 | "import pandas as pd\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import csv\n", 16 | "from os import path, makedirs" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "id": "e8a1adec", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "root = 'path/to/embeddings/'\n", 27 | "\n", 28 | "D = 2048\n", 29 | "D_rr_search = 16 # to load high-D database and queryset for AR with rr models\n", 30 | "\n", 31 | "method = 'adanns' # adanns, mg-ivf-rr, mg-ivf-svd\n", 32 | "dataset = '1K' # 1K, 4K, V2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "id": "f98c143d", 39 | "metadata": {}, 40 | "outputs": [], 
41 | "source": [ 42 | "def load_data_helper(config, D_load=2048):\n", 43 | " db_csv = dataset + '_train_' + config + '-X.npy'\n", 44 | " query_csv = dataset + '_val_' + config + '-X.npy'\n", 45 | " db_label_csv = dataset + '_train_' + config + '-y.npy'\n", 46 | " query_label_csv = dataset + '_val_' + config + '-y.npy'\n", 47 | " \n", 48 | " if dataset == 'V2':\n", 49 | " db_csv = \"1K_train_\" + config + '-X.npy'\n", 50 | " db_label_csv = \"1K_train_\" + config + '-y.npy'\n", 51 | "\n", 52 | " db_load = np.ascontiguousarray(np.load(root+db_csv)[:, :D_load], dtype=np.float32)\n", 53 | " qy_load = np.ascontiguousarray(np.load(root+query_csv)[:, :D_load], dtype=np.float32)\n", 54 | " db_labels = np.load(root+db_label_csv)\n", 55 | " query_labels = np.load(root+query_label_csv)\n", 56 | "\n", 57 | " faiss.normalize_L2(db_load)\n", 58 | " faiss.normalize_L2(qy_load)\n", 59 | "\n", 60 | " return db_load, qy_load, db_labels, query_labels\n", 61 | "\n", 62 | "\n", 63 | "def load_construct_data(D_construct, D_rr_svd, ncentroids):\n", 64 | " if method == 'adanns':\n", 65 | " config = f'mrl1_e0_ff{D_construct}'\n", 66 | " elif method == 'mg-ivf-rr':\n", 67 | " config = f'mrl0_e0_ff{D_construct}'\n", 68 | " elif method == 'mg-ivf-svd':\n", 69 | " config = f'mrl0_e0_rr{D_construct}_svd{D_rr_svd}'\n", 70 | " else:\n", 71 | " raise Exception(\"Unsupported ANNS method.\")\n", 72 | " db_construct, qy_construct, db_labels, query_labels = load_data_helper(config, D_construct)\n", 73 | " \n", 74 | " print(\"Cluster Contruction DB: \", db_construct.shape)\n", 75 | " print(\"Cluster Construction queries:\", qy_construct.shape)\n", 76 | " \n", 77 | " # Load kmeans index and centroids with shape (centroid, D_construct)\n", 78 | " size = str(ncentroids)+'ncentroid_'+str(D_construct)+'Dc'\n", 79 | " if dataset == 'V2': # V2 is only a test set, change to 1K\n", 80 | " dataset = '1K'\n", 81 | " index_file = root+'index_files/'+method+dataset+'_kmeans_'+size\n", 82 | "\n", 83 | " centroids_path = root+'kmeans/'+method+'ncentroids'+str(ncentroids)+\"_\"+str(D_construct)+'Dc_'+dataset+'.npy'\n", 84 | " centroids = np.load(centroids_path)\n", 85 | " print(\"Loaded centroids: \", centroids.shape, centroids_path)\n", 86 | " \n", 87 | " return db_construct, qy_construct, db_labels, query_labels, centroids, index_file\n", 88 | "\n", 89 | "def load_search_data(D_search):\n", 90 | " if method == 'adanns':\n", 91 | " config = f'mrl1_e0_ff{D_search}'\n", 92 | " elif method in ['mg-ivf-rr', 'mg-ivf-svd']:\n", 93 | " config = f'mrl0_e0_ff{D_search}'\n", 94 | " else:\n", 95 | " raise Exception(\"Unsupported ANNS method.\")\n", 96 | " db_search, qy_search, _ , _ = load_data_helper(config, D_search)\n", 97 | "\n", 98 | " return db_search, qy_search\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "id": "55ff5ca2", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def eval_cluster(val_classes, db_classes, neighbors, k): \n", 109 | " APs, topk, recall = [], [], []\n", 110 | " cluster_size = neighbors.shape[0]\n", 111 | " for i in range(cluster_size):\n", 112 | " target = val_classes[i]\n", 113 | " indices = neighbors[i][:k] # k neighbor list for ith val vector\n", 114 | " labels = db_classes[indices]\n", 115 | " matches = (labels == target)\n", 116 | " \n", 117 | " # topk\n", 118 | " hits = np.sum(matches)\n", 119 | " if hits>0:\n", 120 | " topk.append(1)\n", 121 | " else:\n", 122 | " topk.append(0)\n", 123 | " \n", 124 | " # recall\n", 125 | " 
recall.append(np.sum(matches)/1300)\n", 126 | " \n", 127 | " # precision values\n", 128 | " tps = np.cumsum(matches)\n", 129 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n", 130 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n", 131 | " \n", 132 | " return np.mean(recall), np.mean(topk), np.mean(APs)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 6, 138 | "id": "5a66d9b5", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def get_closest_centroids(queries, centroids, D_shortlist):\n", 143 | " centroid_index = faiss.IndexFlatL2(D_shortlist)\n", 144 | " xq_shortlist = np.ascontiguousarray(queries[:, :D_shortlist], dtype=np.float32)\n", 145 | " xc_shortlist = np.ascontiguousarray(centroids[:, :D_shortlist], dtype=np.float32)\n", 146 | " faiss.normalize_L2(xq_shortlist)\n", 147 | " faiss.normalize_L2(xc_shortlist)\n", 148 | " \n", 149 | " centroid_index.add(xc_shortlist)\n", 150 | " _, I = centroid_index.search(xq_shortlist, 1)\n", 151 | "\n", 152 | " return I" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "id": "b69174c7", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "['d_c', 'd_s', 'd_shortlist', 'ncentroid', 'top1', 'recall@100', 'mAP@100', 'overlap']\n", 166 | "Cluster Contruction DB: (1281167, 2048)\n", 167 | "Cluster Construction queries: (10000, 2048)\n", 168 | "\n", 169 | "Loaded kmeans index: 1K_kmeans_1024ncentroid_2048d\n", 170 | "Loaded centroids: (1024, 2048) ../../inference_array/resnet50/kmeans/mrl/ncentroids1024_2048d_1K.npy\n", 171 | "Linear scan with d = 8\n", 172 | "[2048, 8, 2048, 1024, 0.5351, 0.05091163624126155, 0.605462860936279, 0]\n", 173 | "d_c:2048, d_s: 8, ncentroid: 1024\n", 174 | "Recall@100: 0.05091163624126155\n", 175 | "Top1: 0.5351\n", 176 | "Linear scan with d = 16\n", 177 | "[2048, 32, 2048, 1024, 0.5732, 0.051922105632729296, 0.6227463073287828, 0]\n", 178 | "d_c:2048, d_s: 32, ncentroid: 1024\n", 179 | "Recall@100: 0.051922105632729296\n", 180 | "Top1: 0.5732\n", 181 | "Linear scan with d = 64\n", 182 | "[2048, 64, 2048, 1024, 0.5785, 0.05198921759119408, 0.6243867377357402, 0]\n", 183 | "d_c:2048, d_s: 64, ncentroid: 1024\n", 184 | "Recall@100: 0.05198921759119408\n", 185 | "Top1: 0.5785\n", 186 | "Linear scan with d = 128\n", 187 | "[2048, 128, 2048, 1024, 0.5802, 0.05199322657824928, 0.6247377244191619, 0]\n", 188 | "d_c:2048, d_s: 128, ncentroid: 1024\n", 189 | "Recall@100: 0.05199322657824928\n", 190 | "Top1: 0.5802\n", 191 | "Linear scan with d = 256\n", 192 | "[2048, 256, 2048, 1024, 0.5801, 0.05199830116619949, 0.6250036507917335, 0]\n", 193 | "d_c:2048, d_s: 256, ncentroid: 1024\n", 194 | "Recall@100: 0.05199830116619949\n", 195 | "Top1: 0.5801\n", 196 | "Linear scan with d = 512\n", 197 | "[2048, 512, 2048, 1024, 0.5803, 0.05199979769465816, 0.6251013861127136, 0]\n", 198 | "d_c:2048, d_s: 512, ncentroid: 1024\n", 199 | "Recall@100: 0.05199979769465816\n", 200 | "Top1: 0.5803\n", 201 | "Linear scan with d = 1024\n", 202 | "[2048, 1024, 2048, 1024, 0.5766, 0.05198958841622709, 0.6249844689120537, 0]\n", 203 | "d_c:2048, d_s: 1024, ncentroid: 1024\n", 204 | "Recall@100: 0.05198958841622709\n", 205 | "Top1: 0.5766\n", 206 | "Total Time for 8 configs = 211.981683\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "D_search_list = [8, 16, 32, 64, 128, 256, 512, 1024]\n", 212 | "ncentroids = 1024\n", 213 | "\n", 214 | "for D in [2048]:\n", 215 | " k=100\n", 216 | " 
D_rr_svd = D\n", 217 | " D_construct_list = [D]\n", 218 | " D_shortlist_list = [D]\n", 219 | "\n", 220 | " header = ['d_construct', 'd_search', 'd_shortlist', 'ncentroid', 'top1', 'recall@'+str(k), 'mAP@'+str(k)]\n", 221 | " print(header)\n", 222 | " with open('kmeans_metrics.csv', 'w', encoding='UTF8', newline='') as f:\n", 223 | " writer = csv.writer(f)\n", 224 | " writer.writerow(header)\n", 225 | "\n", 226 | " start = time.time()\n", 227 | " for D_c in D_construct_list:\n", 228 | " # Load all construction data (database, queries, centroids)\n", 229 | " xb_construct, xq_construct, db_labels, query_labels, centroids, index_file = load_construct_data(D_c, D_rr_svd, ncentroids)\n", 230 | "\n", 231 | " cpu_index = faiss.read_index(index_file+'.index')\n", 232 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n", 233 | " print(\"\\nLoaded kmeans index:\", index_file.split(\"/\")[-1])\n", 234 | "\n", 235 | " # construct lookup table of centroid --> vectors, i.e. inverted lists\n", 236 | " _, I_db = index.search(xb_construct, 1)\n", 237 | " lut_db = {}\n", 238 | " for c in np.unique(I_db):\n", 239 | " lut_db[c] = np.argwhere(I_db==c)[:,0]\n", 240 | "\n", 241 | " for D_search in D_search_list:\n", 242 | " print(\"Linear scan with D_s = \", D_search)\n", 243 | " xb_search, xq_search = load_search_data(D_search)\n", 244 | "\n", 245 | " for D_shortlist in D_shortlist_list:\n", 246 | " # Currently, D_shortlist <= D_search is supported as we slice centroids for adanns\n", 247 | " I_q = get_closest_centroids(xq_search, centroids, D_shortlist)\n", 248 | " lut_q = {}\n", 249 | "\n", 250 | " start = time.time()\n", 251 | " recall, topk, mAP = [], [], []\n", 252 | "\n", 253 | " #Iterate over all centroids assigned to each\n", 254 | " for c in np.unique(I_q):\n", 255 | " lut_q[c] = np.argwhere(I_q==c)[:,0]\n", 256 | " exact_cpu_index = faiss.IndexFlatL2(D_search)\n", 257 | "\n", 258 | " # add cluster vectors to index and search only queries that map to that cluster\n", 259 | " exact = faiss.index_cpu_to_all_gpus(exact_cpu_index)\n", 260 | " cluster_db = np.ascontiguousarray(xb_search[lut_db[c]][:, :D_search], np.float32)\n", 261 | " cluster_query = np.ascontiguousarray(xq_search[lut_q[c]][:, :D_search], np.float32)\n", 262 | " faiss.normalize_L2(cluster_db)\n", 263 | " faiss.normalize_L2(cluster_query)\n", 264 | " exact.add(cluster_db)\n", 265 | " Dist, Ind = exact.search(cluster_query, k)\n", 266 | "\n", 267 | " # replace cluster-specific indices with original database indices for eval\n", 268 | " cluster_db_labels = db_labels[lut_db[c]]\n", 269 | " cluster_query_labels = query_labels[lut_q[c]]\n", 270 | "\n", 271 | " nn_1 = Ind[:, 0]\n", 272 | " pred_1 = cluster_db_labels[nn_1]\n", 273 | " hits = np.sum(pred_1 == cluster_query_labels)\n", 274 | " topk.append(hits)\n", 275 | "\n", 276 | " rl, tk, mp = eval_cluster(cluster_query_labels, cluster_db_labels, Ind, k)\n", 277 | " recall.append(rl)\n", 278 | " mAP.append(mp)\n", 279 | " row = [D_c, D_search, D_shortlist, ncentroids, np.sum(topk)/xq_search.shape[0], np.mean(recall), np.mean(mAP)]\n", 280 | " print(row)\n", 281 | "\n", 282 | " with open('kmeans_metrics.csv', 'a', encoding='UTF8', newline='') as f:\n", 283 | " writer = csv.writer(f)\n", 284 | " writer.writerow(row)\n", 285 | " print(\"d_c:%d, d_s: %d, ncentroid: %d\" %(D_c, D_search, ncentroids))\n", 286 | " print(\"Recall@100: \", np.mean(recall))\n", 287 | " print(\"Top1: \", np.sum(topk)/xq_search.shape[0])\n", 288 | "\n", 289 | " print(\"Total Time for %d configs = %f\" % (len(D_search_list) 
* len(ncentroids) * len(D_construct_list), time.time() - start))" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.11.0" 310 | }, 311 | "vscode": { 312 | "interpreter": { 313 | "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" 314 | } 315 | } 316 | }, 317 | "nbformat": 4, 318 | "nbformat_minor": 5 319 | } 320 | -------------------------------------------------------------------------------- /adanns/compute_metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "aa3cf273", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd\n", 11 | "import numpy as np\n", 12 | "import time\n", 13 | "import os\n", 14 | "import faiss\n", 15 | "import csv" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "f4fbeb4d", 21 | "metadata": {}, 22 | "source": [ 23 | "## Configuration Variables" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "6b093168", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "D = 2048 # vector dim\n", 34 | "ROOT_DIR = '../../inference_array/resnet50/'\n", 35 | "CONFIG = 'mrl/' # ['mrl/', 'rr/']\n", 36 | "NESTING = CONFIG == 'mrl/'\n", 37 | "SEARCH_INDEX = 'ivfpq' # ['exactl2', 'ivfpq', 'opq', 'hnsw32']\n", 38 | "DATASET = '1K' # 1K, V2, 4K\n", 39 | "\n", 40 | "# Quantization Variables\n", 41 | "nbits = 8 # nbits used to represent centroid id; total possible is k* = 2**nbits\n", 42 | "nlist = 1024 # how many Voronoi cells (must be >= k*)\n", 43 | "iterator = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] # vector dim D \n", 44 | "\n", 45 | "if SEARCH_INDEX in ['ivfpq', 'opq']:\n", 46 | " M = 32 # number of sub-quantizers, i.e. compression in bytes" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "id": "9c9351db", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "def compute_mAP_recall_at_k(val_classes, db_classes, neighbors, k):\n", 57 | " \"\"\"\n", 58 | " Computes the MAP@k on neighbors with val set by seeing if nearest neighbor\n", 59 | " is in the same class as the class of the val code. 
Let m be size of val set, and n in train.\n", 60 | "\n", 61 | " val: (m x d) All the truncated vector representations of images in val set\n", 62 | " val_classes: (m x 1) class index values for each vector in the val set\n", 63 | " db_classes: (n x 1) class index values for each vector in the train set\n", 64 | " neighbors: (m x k) indices in train set of top k neighbors for each vector in val set\n", 65 | " \"\"\"\n", 66 | "\n", 67 | " \"\"\"\n", 68 | " ImageNet-1K:\n", 69 | " shape of val is: (50000, dim)\n", 70 | " shape of val_classes is: (50000, 1)\n", 71 | " shape of db_classes is: (1281167, 1)\n", 72 | " shape of neighbors is: (50000, k)\n", 73 | " \"\"\"\n", 74 | " APs, precision, recall, topk, unique_cls = [], [], [], [], []\n", 75 | " \n", 76 | " for i in range(val_classes.shape[0]): # Compute precision for each vector's list of k-nn\n", 77 | " target = val_classes[i]\n", 78 | " indices = neighbors[i, :][:k] # k neighbor list for ith val vector\n", 79 | " labels = db_classes[indices]\n", 80 | " matches = (labels == target)\n", 81 | " \n", 82 | " # Number of unique classes\n", 83 | " unique_cls.append(len(np.unique(labels)))\n", 84 | " \n", 85 | " # topk\n", 86 | " hits = np.sum(matches)\n", 87 | " if hits > 0:\n", 88 | " topk.append(1)\n", 89 | " else:\n", 90 | " topk.append(0)\n", 91 | " \n", 92 | " # true positive counts\n", 93 | " tps = np.cumsum(matches)\n", 94 | "\n", 95 | " # recall\n", 96 | " recall.append(np.sum(matches)/1300)\n", 97 | " precision.append(np.sum(matches)/k)\n", 98 | "\n", 99 | " # precision values\n", 100 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n", 101 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n", 102 | "\n", 103 | " return np.mean(APs), np.mean(precision), np.mean(recall), np.mean(topk), np.mean(unique_cls)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "id": "ab7ca0bb", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def get_k_recall_at_N(exact_gt, neighbors, k=40, N=2048):\n", 114 | " \"\"\"\n", 115 | " Computes k-Recall@N which denotes the recall of k true nearest neighbors (exact search) \n", 116 | " when N datapoints are retrieved with ANNS. 
Let q be size of query set.\n", 117 | " \n", 118 | " exact_gt: (q x k) True nearest-neighbors of query set computed with exact search\n", 119 | " neighbors: (q x N) Approximate nearest-neighbors of query set\n", 120 | " k: (1) Number of true nearest-neighbors\n", 121 | " N: (1) Number of approximate nearest-neighbors retrieved\n", 122 | " \"\"\"\n", 123 | " labels = exact_gt[:, :k] # Labels from true NN\n", 124 | " targets = neighbors\n", 125 | " num_queries = exact_gt.shape[0]\n", 126 | " count = 0\n", 127 | " for i in range(num_queries):\n", 128 | " label = labels[i]\n", 129 | " target = targets[i, :N]\n", 130 | " # Compute overlap between approximate and true nearest-neighbors\n", 131 | " count += len(list(set(label).intersection(target)))\n", 132 | " return count / (num_queries * k)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "id": "74005467", 138 | "metadata": {}, 139 | "source": [ 140 | "## Load database, query, and neighbor arrays and compute metrics" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 5, 146 | "id": "1245f078", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "def load_knn_array(dim, **kwargs):\n", 151 | " if SEARCH_INDEX in ['ivfpq', 'opq']:\n", 152 | " if (M > dim):\n", 153 | " return\n", 154 | " size = 'm'+str(M)+'_nlist'+str(nlist)+\"_nprobe\"+str(nprobe)+\"_\"\n", 155 | " elif SEARCH_INDEX == 'ivfsq':\n", 156 | " size = str(qtype)+'qtype_'\n", 157 | " elif SEARCH_INDEX == 'kmeans':\n", 158 | " size = str(nlist)+'ncentroid_'\n", 159 | " elif SEARCH_INDEX == 'ivf':\n", 160 | " size = 'nlist'+str(nlist)+\"_nprobe\"+str(nprobe)+\"_\"\n", 161 | " elif SEARCH_INDEX in ['hnsw32', 'hnswpq_M32_pq-m8','hnswpq_M32_pq-m16','hnswpq_M32_pq-m32','hnswpq_M32_pq-m64', 'hnswpq_M32_pq-m128']:\n", 162 | " size = 'efsearch'+str(nprobe)+\"_\"\n", 163 | " else:\n", 164 | " raise Exception(f\"Unsupported Search Index: {SEARCH_INDEX}\")\n", 165 | "\n", 166 | " # Load neighbors array and compute metrics\n", 167 | " neighbors_path = ROOT_DIR + \"neighbors/\" + CONFIG + SEARCH_INDEX+\"/\"+SEARCH_INDEX + \"_\" + size \\\n", 168 | " + \"2048shortlist_\" + DATASET + \"_d\"+str(dim)+\".csv\"\n", 169 | " \n", 170 | " if not os.path.exists(neighbors_path):\n", 171 | " print(neighbors_path.split(\"/\")[-1] + \" not found\")\n", 172 | " return\n", 173 | "\n", 174 | " return pd.read_csv(neighbors_path, header=None).to_numpy()\n", 175 | "\n", 176 | "\n", 177 | "def print_metrics(iterator, shortlist, metric, nprobe=1, N=2048):\n", 178 | " \"\"\"\n", 179 | " Computes and print retrieval metrics.\n", 180 | " \n", 181 | " iterator: (List) True nearest-neighbors of query set computed with exact search\n", 182 | " shortlist: (List) Number of data points retrieved (k)\n", 183 | " metric: Name of metric ['topk', 'mAP', 'precision', 'recall', 'unique_cls', 'k_recall_at_n']\n", 184 | " nprobe: Number of clusters probed during search (IVF) OR 'efSearch' for HNSW search quality\n", 185 | " N: Number of data points retrieved for k-recall@N\n", 186 | " \"\"\"\n", 187 | " # Load database and query set for nested models\n", 188 | " if NESTING:\n", 189 | " # Database: 1.2M x 1 for Imagenet-1K\n", 190 | " if DATASET == 'V2':\n", 191 | " db_labels = np.load(ROOT_DIR + \"1K_train_mrl1_e0_ff2048-y.npy\")\n", 192 | " else:\n", 193 | " db_labels = np.load(ROOT_DIR + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n", 194 | " \n", 195 | " # Query set: 50K x 1 for Imagenet-1K\n", 196 | " query_labels = np.load(ROOT_DIR + DATASET + \"_val_mrl1_e0_ff2048-y.npy\")\n", 
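As a quick, self-contained sanity check of `compute_mAP_recall_at_k` (defined above), the toy example below uses tiny synthetic arrays; note that the recall term is normalized by 1300 (roughly the number of training images per ImageNet-1K class), so it is only meaningful on the real database:

```python
import numpy as np

# 2 queries, 5 database vectors, 5 retrieved neighbors per query (synthetic indices and classes)
val_classes = np.array([[0], [1]])                 # class of each query vector
db_classes  = np.array([[0], [0], [1], [1], [2]])  # class of each database vector
neighbors   = np.array([[0, 2, 1, 4, 3],           # retrieved database indices per query
                        [2, 3, 0, 1, 4]])

mAP, precision, recall, top1, unique_cls = compute_mAP_recall_at_k(
    val_classes, db_classes, neighbors, k=5)
print(top1, precision)  # both queries have a same-class hit in their top-5 -> top1 = 1.0
```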
197 | " \n", 198 | " for dim in iterator:\n", 199 | " # Load database and query set for fixed feature models\n", 200 | " if not NESTING:\n", 201 | " db_labels = np.load(ROOT_DIR + DATASET + \"_train_mrl0_e0_ff\"+str(dim)+\"-y.npy\")\n", 202 | " query_labels = np.load(ROOT_DIR + DATASET + \"_val_mrl0_e0_ff\"+str(D)+\"-y.npy\")\n", 203 | " \n", 204 | " neighbors = load_knn_array(dim, M=M, nlist=nlist, nprobe=nprobe)\n", 205 | " \n", 206 | " for k in shortlist:\n", 207 | " if metric == 'k_recall_at_n':\n", 208 | " # Use 40-NN from Exact Search with MRL as GT\n", 209 | " if NESTING:\n", 210 | " query_labels = pd.read_csv(ROOT_DIR + f'k-recall@N_ground_truth/mrl_exactl2_2048dim_{k}shortlist_1K.csv', header=None).to_numpy()\n", 211 | " else:\n", 212 | " query_labels = pd.read_csv(ROOT_DIR + f'k-recall@N_ground_truth/rr_exactl2_{dim}dim_{k}shortlist_1K.csv', header=None).to_numpy()\n", 213 | " \n", 214 | " k_recall = (get_k_recall_at_N(query_labels, neighbors, k, N))\n", 215 | " print(f'{k}-Recall@{N} = {k_recall}')\n", 216 | " \n", 217 | " else:\n", 218 | " mAP, precision, recall, topk, unique_cls = compute_mAP_recall_at_k(query_labels, db_labels, neighbors, k)\n", 219 | " if (metric == 'topk'): print(f'topk, {dim}, {M}, {nprobe}, {topk}')\n", 220 | " elif (metric == 'mAP'): print(f'mAP, {dim}, {M}, {nprobe}, {mAP}')\n", 221 | " elif (metric == 'precision'): print(f'precision, {dim}, {M}, {nprobe}, {precision}')\n", 222 | " elif (metric == 'recall') : print(f'recall, {dim}, {M}, {nprobe}, {recall}')\n", 223 | " elif (metric == 'unique_cls'): print(f'unique_cls, {dim}, {M}, {nprobe}, {unique_cls}')\n", 224 | " else: raise Exception(\"Unsupported metric!\")" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "5f2ea629", 230 | "metadata": {}, 231 | "source": [ 232 | "## Example: Traditional Retrieval Metrics (Top-1, mAP, Recall)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 6, 238 | "id": "a336eb99", 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "Index: ivfpq\n", 246 | "metric, D, M, nprobe, value\n", 247 | "topk, 16, 8, 1, 0.6775\n", 248 | "topk, 32, 8, 1, 0.6861\n", 249 | "mAP, 16, 8, 1, 0.6306868079365078\n", 250 | "mAP, 32, 8, 1, 0.6374524079365079\n", 251 | "recall, 16, 8, 1, 0.05151807692307692\n", 252 | "recall, 32, 8, 1, 0.051838800000000004\n" 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "# Example evaluation for IVFPQ\n", 258 | "iterator = [16, 32]\n", 259 | "print(\"Index:\", SEARCH_INDEX)\n", 260 | "print(\"metric, D, M, nprobe, value\")\n", 261 | "for M in [8]:\n", 262 | " for nprobe in [1]:\n", 263 | " print_metrics(iterator, [1], 'topk', nprobe)\n", 264 | " print_metrics(iterator, [10], 'mAP', nprobe)\n", 265 | " print_metrics(iterator, [100], 'recall', nprobe)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "id": "8ce6be99", 271 | "metadata": {}, 272 | "source": [ 273 | "## Example: ANNS Metric: k-Recall@N" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 64, 279 | "id": "4d5663b9", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "k-recall@N GT: (50000, 40)\n", 287 | "40-Recall@2048 = 0.2071915\n", 288 | "k-recall@N GT: (50000, 40)\n", 289 | "40-Recall@2048 = 0.311641\n", 290 | "k-recall@N GT: (50000, 40)\n", 291 | "40-Recall@2048 = 0.377283\n", 292 | "k-recall@N GT: (50000, 40)\n", 293 | "40-Recall@2048 = 0.4137225\n" 294 | ] 295 | 
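Similarly, a minimal sketch (synthetic neighbor lists, hypothetical values) illustrating that `get_k_recall_at_N`, defined above, is simply the fraction of the k exact nearest neighbors recovered among the N retrieved candidates:

```python
import numpy as np

exact_gt  = np.array([[0, 1, 2, 3]])          # indices of the true 4-NN from exact search (one query)
neighbors = np.array([[2, 9, 0, 7, 5, 1]])    # candidate indices returned by an ANNS index

# 3 of the 4 true neighbors appear among the 6 candidates -> 4-Recall@6 = 0.75
print(get_k_recall_at_N(exact_gt, neighbors, k=4, N=6))
```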
} 296 | ], 297 | "source": [ 298 | "USE_K_RECALL_AT_N = True\n", 299 | "SEARCH_INDEX = 'hnsw32'\n", 300 | "iterator = [8, 16, 32, 64]\n", 301 | "\n", 302 | "print_metrics(iterator, [40], 'k_recall_at_n', nprobe=1, N=2048)" 303 | ] 304 | } 305 | ], 306 | "metadata": { 307 | "kernelspec": { 308 | "display_name": "Python 3 (ipykernel)", 309 | "language": "python", 310 | "name": "python3" 311 | }, 312 | "language_info": { 313 | "codemirror_mode": { 314 | "name": "ipython", 315 | "version": 3 316 | }, 317 | "file_extension": ".py", 318 | "mimetype": "text/x-python", 319 | "name": "python", 320 | "nbconvert_exporter": "python", 321 | "pygments_lexer": "ipython3", 322 | "version": "3.11.3" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 5 327 | } 328 | -------------------------------------------------------------------------------- /adanns/diskann/README.md: -------------------------------------------------------------------------------- 1 | # AdANNS-DiskANN 2 | AdANNS-DiskANN is a variant of [DiskANN](https://github.com/microsoft/DiskANN), a web-scale graph-based ANNS index capable of serving queries from both RAM and Disk (cheap SSDs). 3 | 4 |

5 | 6 |

7 | 8 | 9 | We provide a self-contained pipeline in [adanns-diskann.ipynb](adanns-diskann.ipynb) which requires a build of DiskANN provided in the [original codebase](https://github.com/microsoft/DiskANN) and summarized below: 10 | 11 | ``` 12 | sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format \ 13 | libboost-all-dev libmkl-full-dev 14 | 15 | mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j 16 | ``` 17 | 18 | ## Notebook Overview 19 | [adanns-diskann.ipynb](adanns-diskann.ipynb) is broadly organized as: 20 | 1. Data preprocessing: convert MR or RR embeddings (fp32 `np.ndarray`) to binary format 21 | 2. Generate exact-search "ground truth" used for k-recall@N 22 | 3. Build In-Memory or SSD DiskANN index on MR or RR 23 | 4. Search the built indices to generate k-NN arrays 24 | 5. Evaluate the k-NN arrays with and without reranking -------------------------------------------------------------------------------- /adanns/diskann/adanns-diskann.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "650d31f4", 6 | "metadata": {}, 7 | "source": [ 8 | "## Dataset Preparation for DiskANN" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "d6dbab11", 15 | "metadata": { 16 | "scrolled": false 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import sys\n", 22 | "import os\n", 23 | "\n", 24 | "embeddings_root = 'path/to/embeddings/'\n", 25 | "\n", 26 | "def generate_bin_data_from_ndarray(embedding_path, bin_out_path, embedding_dims):\n", 27 | " data_orig = np.load(embedding_path)\n", 28 | " for d in embedding_dims:\n", 29 | " data_sliced = data_orig[:, :d]\n", 30 | " outfile = bin_out_path+\"_d\"+str(d)+\".fbin\"\n", 31 | " print(outfile.split(\"/\")[-1])\n", 32 | " print(\"Array sliced: \", data_sliced.shape)\n", 33 | " data_sliced.astype('float32').tofile(\"temp\")\n", 34 | "\n", 35 | " num_points = data_sliced.shape[0].to_bytes(4, 'little')\n", 36 | " data_dim = data_sliced.shape[1].to_bytes(4, 'little')\n", 37 | "\n", 38 | " with open(\"temp\", \"rb\") as old, open(outfile, \"wb\") as new:\n", 39 | " new.write(num_points)\n", 40 | " new.write(data_dim)\n", 41 | " new.write(old.read())\n", 42 | " \n", 43 | " os.remove(\"temp\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "b69e102e", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n", 54 | "for d in nesting_list:\n", 55 | " generate_bin_data_from_ndarray(embeddings_root+\"1K_train_mrl0_e0_ff\"+str(d)+\"-X.npy\", \"../build/data/rr-resnet50/fbin/database\", [d])\n", 56 | " print()\n", 57 | " generate_bin_data_from_ndarray(embeddings_root+\"1K_val_mrl0_e0_ff\"+str(d)+\"-X.npy\", \"../build/data/rr-resnet50/fbin/queries\", [d])" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "5db9a1eb", 63 | "metadata": {}, 64 | "source": [ 65 | "## Generate Exact Search ground truth from queries and database" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 2, 71 | "id": "45584765", 72 | "metadata": { 73 | "scrolled": true 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "%%bash\n", 78 | "use_mrl=mr # mr or rr \n", 79 | "\n", 80 | "for d in 8 16 32 64 128 256 512 1024 2048\n", 81 | "do\n", 82 | " ./../build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 \\\n", 83 | " --base_file 
../build/data/{use_mrl}-resnet50/fbin/database_d$d.fbin \\\n", 84 | " --query_file ../build/data/{use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n", 85 | " --gt_file ../build/data/{use_mrl}-resnet50/exact_gt100/${use_mrl}_r50_queries_d$d\"\"_gt100 --K 100\n", 86 | "done" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "a21fb4fd", 92 | "metadata": {}, 93 | "source": [ 94 | "## Build DiskANN In-Memory Index" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "d5ae6f7d", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "%%bash\n", 105 | "opq_bytes=32\n", 106 | "use_mrl=mrl # mr or rr\n", 107 | "\n", 108 | "for use_mrl in mrl\n", 109 | "do\n", 110 | " for d in 32 64 128 256 512 1024 2048\n", 111 | " do\n", 112 | " echo -e \"Building index ${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d\\n\"\n", 113 | " ./../build/tests/build_memory_index --data_type float --dist_fn l2 \\\n", 114 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n", 115 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/memory-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d \\\n", 116 | " -R 64 -L 100 --alpha 1.2 --build_PQ_bytes ${opq_bytes} --use_opq\n", 117 | " done\n", 118 | "done" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "95b04524", 124 | "metadata": {}, 125 | "source": [ 126 | "## Build DiskANN SSD Index" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "4cd4d51b", 133 | "metadata": { 134 | "scrolled": true 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "%%bash\n", 139 | "opq_bytes=48\n", 140 | "use_mrl=rr\n", 141 | "reorder=disk-index-no-reorder\n", 142 | "\n", 143 | "# Disable post-hoc re-ranking by setting PQ_disk_bytes = build_PQ_bytes\n", 144 | "for opq_bytes in 32 48 64\n", 145 | "do\n", 146 | " for d in 1024\n", 147 | " do\n", 148 | " echo -e \"Building disk OPQ index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n", 149 | " ./../build/tests/build_disk_index --data_type float --dist_fn l2 \\\n", 150 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n", 151 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/disk-index-no-reorder/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n", 152 | " -R 64 -L 100 -B 0.3 -M 40 --PQ_disk_bytes $opq_bytes --build_PQ_bytes $opq_bytes --use_opq \n", 153 | " done\n", 154 | "done\n", 155 | "\n", 156 | "# Build index with implicit post-hoc full-precision reranking\n", 157 | "for opq_bytes in 32 48 64\n", 158 | "do\n", 159 | " for d in 128 1024\n", 160 | " do\n", 161 | " ./../build/tests/build_disk_index --data_type float --dist_fn l2 \\\n", 162 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n", 163 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/disk-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n", 164 | " -R 64 -L 100 -B 0.3 -M 40 --build_PQ_bytes $opq_bytes --use_opq \n", 165 | " echo -e \"Build index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n", 166 | " done\n", 167 | "done" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "cad290ff", 173 | "metadata": {}, 174 | "source": [ 175 | "## Search DiskANN Memory Index" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 1, 181 | "id": "2685a209", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "%%bash\n", 186 | "opq_bytes=32\n", 187 | "\n", 188 | "for use_mrl in rr mr\n", 189 | "do\n", 190 | " for 
d in 32 64 128 256 512 1024 2048\n", 191 | " do\n", 192 | " ./../build/tests/search_memory_index --data_type float --dist_fn l2 \\\n", 193 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/memory-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d \\\n", 194 | " --query_file ../build/data/${use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n", 195 | " --gt_file ../build/data/${use_mrl}-resnet50/exact_gt100/mrlr50_queries_d$d\"\"_gt100 \\\n", 196 | " -K 100 -L 100 --result_path ../build/data/${use_mrl}-resnet50/res/memory-index/d$d/opq${opq_bytes}\n", 197 | " echo -e \"Searched index ${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d\\n\"\n", 198 | " done\n", 199 | "done" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "e7627d51", 205 | "metadata": {}, 206 | "source": [ 207 | "## Search DiskANN SSD Index" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "35655236", 214 | "metadata": { 215 | "scrolled": true 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "%%bash\n", 220 | "opq_bytes=48\n", 221 | "use_mrl=mrl\n", 222 | "reorder=disk-index\n", 223 | "\n", 224 | "for d in 1024\n", 225 | "do\n", 226 | " for W in 2 8 16 32 # search quality\n", 227 | " do\n", 228 | " ./../build/tests/search_disk_index --data_type float --dist_fn l2 \\\n", 229 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/${reorder}/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n", 230 | " --query_file ../build/data/${use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n", 231 | " --gt_file ../build/data/${use_mrl}-resnet50/exact_gt100/mrlr50_queries_d$d\"\"_gt100 \\\n", 232 | " -K 100 -L 100 -W ${W} --num_nodes_to_cache 100000 --result_path ../build/data/${use_mrl}-resnet50/res/${reorder}/d$d/opq${opq_bytes}_W$W\n", 233 | " echo -e \"Searched index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n", 234 | " done\n", 235 | "done" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "38563959", 241 | "metadata": {}, 242 | "source": [ 243 | "# DiskANN Eval" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 1, 249 | "id": "cf835299", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "D = 2048\n", 254 | "CONFIG = 'mr' # ['mr', 'rr']\n", 255 | "NESTING = CONFIG == 'mr'\n", 256 | "DISKANN_INDEX = 'memory-index' # disk-index\n", 257 | "DATASET = '1K' # ['1K', '4K', 'V2']" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 2, 263 | "id": "e3dc60ac", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "def compute_mAP_recall_at_k(val_classes, db_classes, neighbors, k):\n", 268 | " APs = list()\n", 269 | " precision, recall, topk, majvote, unique_cls = [], [], [], [], []\n", 270 | " \n", 271 | " for i in range(val_classes.shape[0]): # Compute precision for each vector's list of k-nn\n", 272 | " target = val_classes[i]\n", 273 | " indices = neighbors[i, :][:k] # k neighbor list for ith val vector\n", 274 | " labels = db_classes[indices]\n", 275 | " matches = (labels == target)\n", 276 | " \n", 277 | " # Number of unique classes\n", 278 | " unique_cls.append(len(np.unique(labels)))\n", 279 | " \n", 280 | " # topk\n", 281 | " hits = np.sum(matches)\n", 282 | " if hits>0:\n", 283 | " topk.append(1)\n", 284 | " else:\n", 285 | " topk.append(0)\n", 286 | " \n", 287 | " # true positive counts\n", 288 | " tps = np.cumsum(matches)\n", 289 | "\n", 290 | " # recall\n", 291 | " recall.append(np.sum(matches)/1300)\n", 292 | " precision.append(np.sum(matches)/k)\n", 293 | 
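Aside on the data format: the `generate_bin_data_from_ndarray` cell earlier in this notebook writes DiskANN-style `.fbin` files as a 4-byte little-endian point count, a 4-byte dimension, then row-major float32 data. A small reader that inverts that writer can be handy for spot-checking the converted arrays; the path below is illustrative:

```python
import numpy as np

def read_fbin(path):
    """Load a DiskANN-style .fbin file written by generate_bin_data_from_ndarray above."""
    with open(path, "rb") as f:
        num_points = int.from_bytes(f.read(4), "little")
        dim = int.from_bytes(f.read(4), "little")
        data = np.fromfile(f, dtype=np.float32, count=num_points * dim)
    return data.reshape(num_points, dim)

# e.g. xb = read_fbin("../build/data/rr-resnet50/fbin/database_d64.fbin")  # illustrative path
```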
"\n", 294 | " # precision values\n", 295 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n", 296 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n", 297 | "\n", 298 | " return np.mean(APs), np.mean(precision), np.mean(recall), np.mean(topk), majvote, np.mean(unique_cls)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 3, 304 | "id": "1669e2da", 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "def print_metrics(CONFIG, nesting_list, shortlist, metric, nprobe=1):\n", 309 | " if NESTING:\n", 310 | " # Database: 1.2M x 1 for imagenet1k\n", 311 | " db_labels = np.load(embeddings_root + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n", 312 | " \n", 313 | " # Query set: 50K x 1 for imagenet1k\n", 314 | " query_labels = np.load(embeddings_root + DATASET + \"_val_mrl1_e0_ff2048-y.npy\")\n", 315 | " \n", 316 | " for dim in nesting_list:\n", 317 | " if opq > dim:\n", 318 | " continue\n", 319 | " # Load database and query set for fixed feature models\n", 320 | " if not NESTING:\n", 321 | " db_labels = np.load(embeddings_root + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n", 322 | " query_labels = np.load(embeddings_root + DATASET + \"_val_mrl0_e0_ff\"+str(D)+\"-y.npy\")\n", 323 | " \n", 324 | " for W in [32]:\n", 325 | " row = [dim, opq, W]\n", 326 | " fileName = f'/home/jupyter/DiskANN/build/data/{CONFIG}-resnet50/res/{DISKANN_INDEX}/d{dim}/opq{opq}_100_idx_uint32.bin'\n", 327 | " print(fileName)\n", 328 | " with open(fileName, 'rb') as f:\n", 329 | " data = np.fromfile(f, dtype=' 4 | 5 |


 6 | 7 | We follow the setup from the [Dense Passage Retriever](https://github.com/facebookresearch/DPR) (DPR) repo. The Wikipedia corpus contains 21 million passages, and we use the Natural Questions (NQ) dataset for the open-domain QA setting. AdANNS with DPR on NQ is organized as: 8 | 1. Index training for IP, L2, IVF, OPQ, and IVFOPQ. Currently supported: 9 | - In batches (RAM-constrained, ~2 hours build time) 10 | - One shot (~120G peak RAM usage, ~2 minutes build time) 11 | 2. DPR eval (available on request) -------------------------------------------------------------------------------- /adanns/dpr-nq/adanns-nq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "4ef1bef5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from loguru import logger\n", 11 | "import pyarrow as pa\n", 12 | "import faiss\n", 13 | "from tqdm import tqdm\n", 14 | "import numpy as np\n", 15 | "import time" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 5, 21 | "id": "80a238e4", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# Specify These\n", 26 | "config = 'MR' # MR, RR\n", 27 | "index_type = 'IVFOPQ' # IP, L2, IVF, OPQ, IVFOPQ\n", 28 | "train_batches = False # Set to True to build in batches if system RAM is insufficient for one-shot training\n", 29 | "DPR_root = '/mnt/disks/experiments/DPR/'" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 8, 35 | "id": "5861db59", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if config == 'MR':\n", 40 | "    config_name = 'dpr-nq-d768_384_192_96_48-wiki' # MR\n", 41 | "else:\n", 42 | "    config_name = 'dpr-nq-d768-wiki' # RR-768\n", 43 | "    \n", 44 | "embeddings_file = f'{DPR_root}results/embed/{config_name}.arrow'\n", 45 | "emb_data = pa.ipc.open_file(pa.memory_map(embeddings_file, \"rb\")).read_all()" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "4d1754f2", 51 | "metadata": {}, 52 | "source": [ 53 | "## Batched Index Training (RAM-constrained)\n", 54 | "Learn Exact Search Indices (with IP distance) in batches over 21M passages."
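A back-of-the-envelope check of the "~120G peak RAM usage" quoted in the README above: the sizes come from the cell outputs below (21,015,324 passages, 768-d float32 embeddings), and the factor of two is an assumption about the temporary copy made by the `np.hstack`/`reshape` step in the one-shot path:

```python
# One full-precision copy of the Wikipedia passage embedding matrix
num_passages, dim, bytes_per_float = 21_015_324, 768, 4
gb_per_copy = num_passages * dim * bytes_per_float / 1e9
print(f"{gb_per_copy:.1f} GB per copy")  # ~64.6 GB; ~2 copies alive during hstack/reshape => ~130 GB peak
```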
55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 7, 60 | "id": "dd9cb0db", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "# Train exact index with database and queries with embedding size 'dim' and write to disk\n", 65 | "def batched_train(dim): \n", 66 | " index_file = f'results/embed/exact-index/{config_name}-dim{dim}_{index_type}_batched.faiss'\n", 67 | " \n", 68 | " sub_index = faiss.IndexFlatIP(dim)\n", 69 | " faiss_index = faiss.IndexIDMap2(sub_index)\n", 70 | "\n", 71 | " total = 0\n", 72 | " for batch in tqdm(emb_data.to_batches()):\n", 73 | " batch_data = batch.to_pydict()\n", 74 | " psg_ids = np.array(batch_data[\"id\"])\n", 75 | "\n", 76 | " token_emb = np.array(batch_data[\"embedding\"], dtype=np.float32)\n", 77 | " token_emb = np.ascontiguousarray(token_emb[:, :dim]) # Shape: (8192, dim)\n", 78 | " faiss_index.add_with_ids(token_emb, psg_ids)\n", 79 | "\n", 80 | " total += len(psg_ids)\n", 81 | " if total % 1000 == 0:\n", 82 | " logger.info(f\"indexed {total} passages\")\n", 83 | "\n", 84 | " faiss.write_index(faiss_index, str(index_file))\n", 85 | "\n", 86 | "if(train_batches):\n", 87 | " batched_train(dim=768)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "a46ce428", 93 | "metadata": {}, 94 | "source": [ 95 | "## Full Training (High peak RAM Usage ~120G)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "id": "5ee63112", 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "(21015324,)\n", 109 | "(21015324, 768) float32\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "if not train_batches:\n", 115 | " psg_ids = np.array(emb_data['id'])\n", 116 | " print(psg_ids.shape) # Passage IDs\n", 117 | "\n", 118 | " # Takes ~5 min on our system\n", 119 | " token_emb = np.array(emb_data[\"embedding\"])\n", 120 | "\n", 121 | " token_emb = np.hstack(token_emb)\n", 122 | "\n", 123 | " token_emb = token_emb.reshape(21015324, -1)\n", 124 | " print(token_emb.shape, token_emb.dtype) # Token Embeddings\n", 125 | "else:\n", 126 | " raise Exception(\"Insufficient RAM to train on entire data!\")" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "id": "1adcd6c5", 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M8.faiss\n", 140 | "Adding DB: (21015324, 768)\n", 141 | "Time to build index with d=768 : 1723.532558\n", 142 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss\n", 143 | "Adding DB: (21015324, 768)\n", 144 | "Time to build index with d=768 : 1926.746812\n", 145 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M32.faiss\n", 146 | "Adding DB: (21015324, 768)\n", 147 | "Time to build index with d=768 : 2302.668599\n", 148 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M48.faiss\n", 149 | "Adding DB: (21015324, 768)\n", 150 | "Time to build index with d=768 : 2746.178078\n", 151 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M64.faiss\n", 152 | "Adding DB: (21015324, 768)\n", 153 | "Time to build index with d=768 : 2214.350993\n", 154 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M96.faiss\n", 155 | "Adding DB: (21015324, 768)\n", 156 | "Time to build index with d=768 : 2294.547853\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "ncell=10 # Number of IVF cells\n", 162 | "dims=[768] # Embedding dims to train indices over\n", 163 | 
"Ms=[8, 16, 32, 48, 64, 96] # Number of PQ sub-quantizers for IVF+OPQ\n", 164 | "\n", 165 | "for M in Ms:\n", 166 | " for dim in dims:\n", 167 | " if M > dim or dim%M!=0:\n", 168 | " print(\"Skipping (d,M) : (%d, %d)\" %(dim, M))\n", 169 | " continue\n", 170 | " \n", 171 | " token_emb_sliced = np.ascontiguousarray(token_emb[:, :dim])\n", 172 | " faiss.normalize_L2(token_emb_sliced)\n", 173 | " print(\"Adding DB: \", token_emb_sliced.shape)\n", 174 | " print(f'Generating {index_type} index on config: {config_name}')\n", 175 | " \n", 176 | " tic = time.time()\n", 177 | " # Flat L2 Index\n", 178 | " if index_type == 'IP':\n", 179 | " index_file = f'results/embed/IP/{config_name}-dim{dim}_IP.faiss'\n", 180 | " sub_index = faiss.IndexFlatIP(dim)\n", 181 | " faiss_index = faiss.IndexIDMap2(sub_index)\n", 182 | "\n", 183 | " elif index_type == 'L2':\n", 184 | " index_file = f'results/embed/L2/{config_name}-dim{dim}_L2.faiss'\n", 185 | " sub_index = faiss.IndexFlatL2(dim)\n", 186 | " faiss_index = faiss.IndexIDMap2(sub_index)\n", 187 | "\n", 188 | " elif index_type == 'IVF':\n", 189 | " index_file = f'results/embed/IVF/{config_name}-dim{dim}_IVF_ncell{ncell}.faiss'\n", 190 | " quantizer = faiss.IndexFlatL2(dim)\n", 191 | " faiss_index = faiss.IndexIVFFlat(quantizer, dim, ncell)\n", 192 | " faiss_index.train(token_emb_sliced)\n", 193 | " \n", 194 | " elif index_type == 'OPQ':\n", 195 | " index_file = f'results/embed/OPQ/{config_name}-dim{dim}_OPQ_M{M}_nbits8.faiss'\n", 196 | " opq_train_db_indices = np.random.choice(token_emb_sliced.shape[0], 500000, replace=False)\n", 197 | " opq_train_db = token_emb_sliced[opq_train_db_indices]\n", 198 | " sub_index = faiss.index_factory(dim, f\"OPQ{M},PQ{M}x{8}\")\n", 199 | " faiss_index = faiss.IndexIDMap2(sub_index)\n", 200 | " faiss_index.train(opq_train_db)\n", 201 | "\n", 202 | " elif index_type == 'IVFOPQ':\n", 203 | " index_file = f'results/embed/IVFOPQ/{config_name}-dim{dim}_IVFOPQ_cell{ncell}_M{M}_nbits8.faiss'\n", 204 | " sub_index = faiss.index_factory(dim, f\"OPQ{M},IVF{ncell},PQ{M}x{8}\")\n", 205 | " faiss_index = faiss.IndexIDMap2(sub_index)\n", 206 | " faiss_index.train(token_emb_sliced)\n", 207 | " \n", 208 | " faiss_index.add_with_ids(token_emb_sliced, psg_ids)\n", 209 | " faiss.write_index(faiss_index, str(index_file))\n", 210 | " toc = time.time()\n", 211 | " \n", 212 | " print(\"Generated \", index_file)\n", 213 | " print(\"Time to build index with d=%d : %f\" %(dim, toc-tic))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "daef008c", 219 | "metadata": {}, 220 | "source": [ 221 | "# Search (restart kernel for memory)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 2, 227 | "id": "dbb3919e", 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "2023-05-21 20:04:27.733 | INFO | __main__:batch_eval_dataset:94 - init Retriever from model_ckpt=ckpt/dpr-nq-d768_384_192_96_48\n", 235 | "2023-05-21 20:04:35.608 | INFO | __main__:batch_eval_dataset:100 - loading index_file=results/embed/IVFOPQ/dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss...\n", 236 | "2023-05-21 20:04:41.547 | INFO | __main__:batch_eval_dataset:107 - loading passage_db_file=data/psgs-w100.lmdb...\n", 237 | "2023-05-21 20:04:41.758 | INFO | __main__:batch_eval_dataset:114 - loading QA pairs from qas-data/nq-test.csv\n", 238 | "2023-05-21 20:04:41.803 | INFO | __main__:batch_eval_dataset:119 - computing query embeddings...\n", 239 | "2023-05-21 20:04:41.804 | 
INFO | __main__:batch_eval_dataset:121 - begin searching max(top_k)=200 passage for 3610 question...\n", 240 | "search 1668.9 queries/s, checking answers: 100%|██████████| 8/8 [10:52<00:00, 81.59s/it]\n", 241 | "2023-05-21 20:15:34.546 | INFO | __main__:batch_eval_dataset:154 - #total examples: 3610\n", 242 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@1:0.26204986149584486 correct_samples:946\n", 243 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@5:0.46204986149584487 correct_samples:1668\n", 244 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@20:0.5997229916897507 correct_samples:2165\n", 245 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@100:0.714404432132964 correct_samples:2579\n", 246 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@200:0.7484764542936289 correct_samples:2702\n", 247 | "Finished Processing!\n", 248 | "\n", 249 | "2023-05-21 20:15:40.640 | INFO | __main__:batch_eval_dataset:94 - init Retriever from model_ckpt=ckpt/dpr-nq-d768_384_192_96_48\n", 250 | "2023-05-21 20:15:43.152 | INFO | __main__:batch_eval_dataset:100 - loading index_file=results/embed/IVFOPQ/dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M32.faiss...\n", 251 | "2023-05-21 20:15:50.740 | INFO | __main__:batch_eval_dataset:107 - loading passage_db_file=data/psgs-w100.lmdb...\n", 252 | "2023-05-21 20:15:50.815 | INFO | __main__:batch_eval_dataset:114 - loading QA pairs from qas-data/nq-test.csv\n", 253 | "2023-05-21 20:15:50.851 | INFO | __main__:batch_eval_dataset:119 - computing query embeddings...\n", 254 | "2023-05-21 20:15:50.852 | INFO | __main__:batch_eval_dataset:121 - begin searching max(top_k)=200 passage for 3610 question...\n", 255 | "search 1811.9 queries/s, checking answers: 100%|██████████| 8/8 [06:27<00:00, 48.38s/it]\n", 256 | "2023-05-21 20:22:17.956 | INFO | __main__:batch_eval_dataset:154 - #total examples: 3610\n", 257 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@1:0.33407202216066484 correct_samples:1206\n", 258 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@5:0.531578947368421 correct_samples:1919\n", 259 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@20:0.6493074792243767 correct_samples:2344\n", 260 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@100:0.7401662049861496 correct_samples:2672\n", 261 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@200:0.7678670360110803 correct_samples:2772\n", 262 | "Finished Processing!\n", 263 | "\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "%%bash\n", 269 | "split=test\n", 270 | "ds=nq\n", 271 | "\n", 272 | "# Change these\n", 273 | "d=768\n", 274 | "index_type=IVFOPQ\n", 275 | "config_name=dpr-nq-d768_384_192_96_48-wiki\n", 276 | "\n", 277 | "# Modify index_file name to the one built above\n", 278 | "for M in 16 32\n", 279 | "do\n", 280 | " for d in 768\n", 281 | " do\n", 282 | " python rtr/cli/eval_retriever.py \\\n", 283 | " --passage_db_file data/psgs-w100.lmdb \\\n", 284 | " --model_ckpt ckpt/{config_name} \\\n", 285 | " --index_file results/embed/${index_type}/${config_name}-dim${d}_${index_type}_cell${ncell}_M${M}.faiss \\\n", 286 | " --dataset_file qas-data/${ds}-${split}.csv \\\n", 287 | " --save_file results/json/reader-${config_name}-${ds}-${split}-dim${d}.jsonl \\\n", 288 | " --batch_size 512 \\\n", 289 | " 
--max_question_len 200 \\\n", 290 | " --embedding_size ${d} \\\n", 291 | " --metrics_file results/metrics.json \\\n", 292 | " --binary False \\\n", 293 | " 2>&1 | tee results/logs/eval-${config_name}-${ds}-${split}-dim${d}.log\n", 294 | " echo -e \"Finished Processing!\\n\"\n", 295 | " done\n", 296 | "done" 297 | ] 298 | } 299 | ], 300 | "metadata": { 301 | "kernelspec": { 302 | "display_name": "Python 3 (ipykernel)", 303 | "language": "python", 304 | "name": "python3" 305 | }, 306 | "language_info": { 307 | "codemirror_mode": { 308 | "name": "ipython", 309 | "version": 3 310 | }, 311 | "file_extension": ".py", 312 | "mimetype": "text/x-python", 313 | "name": "python", 314 | "nbconvert_exporter": "python", 315 | "pygments_lexer": "ipython3", 316 | "version": "3.10.8" 317 | } 318 | }, 319 | "nbformat": 4, 320 | "nbformat_minor": 5 321 | } 322 | -------------------------------------------------------------------------------- /adanns/generate_nn/hnsw_exactl2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "cc5c9e14", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import faiss\n", 12 | "import time\n", 13 | "import pandas as pd\n", 14 | "from os import path, makedirs\n", 15 | "import torch\n", 16 | "import sys\n", 17 | "sys.path.append('../')\n", 18 | "from utils import load_embeddings" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "20116c14", 24 | "metadata": {}, 25 | "source": [ 26 | "## Configuration Variables" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "3be50514", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "root = '../../../inference_array/resnet50/' # path to database and queryset\n", 37 | "D = 2048 # embedding dim\n", 38 | "hnsw_max_neighbors = 32 # M for HNSW, default=32\n", 39 | "pq_num_subvectors = 32 # m for HNSW+PQ\n", 40 | "\n", 41 | "model = 'mrl' # mrl, rr\n", 42 | "dataset = '1K' # 1K, 4K, V2\n", 43 | "index_type = 'hnsw32' # exactl2, hnsw32, #'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(pq_num_subvectors)\n", 44 | "\n", 45 | "k = 2048 # shortlist length, default is set to the max supported by FAISS\n", 46 | "nesting_list = [8, 16, 32, 64] # embedding dim to loop over" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "f4401b8c", 52 | "metadata": {}, 53 | "source": [ 54 | "## FAISS Index Building and NN Search" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "5fd3e4f5", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "if index_type == 'exactl2' and torch.cuda.device_count() > 0:\n", 65 | " use_gpu = True # GPU inference for exact search\n", 66 | "else:\n", 67 | " use_gpu = False # GPU inference for HNSW is currently not supported by FAISS" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "f862bf36", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def get_nn(index_type, nesting_list, m=8):\n", 78 | " for retrieval_dim in nesting_list:\n", 79 | " if retrieval_dim > D:\n", 80 | " continue\n", 81 | " \n", 82 | " if index_type == 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m) and retrieval_dim < m:\n", 83 | " continue\n", 84 | " \n", 85 | " if not path.isdir(root+'index_files/'+model+'/'):\n", 86 | " makedirs(root+'index_files/'+model+'/')\n", 87 | " index_file = 
root+'index_files/'+model+'/'+dataset+'_'+str(retrieval_dim)+'dim_'+index_type+'.index'\n", 88 | "\n", 89 | " _, _, _, _, xb, xq = load_embeddings(model, dataset, retrieval_dim)\n", 90 | "\n", 91 | " # Load or build index\n", 92 | " if path.exists(index_file): # Load index\n", 93 | " print(\"Loading index file: \" + index_file.split(\"/\")[-1])\n", 94 | " cpu_index = faiss.read_index(index_file)\n", 95 | " \n", 96 | " else: # Build index\n", 97 | " print(\"Generating index file: \" + index_file)\n", 98 | "\n", 99 | " d = xb.shape[1] # dimension\n", 100 | "\n", 101 | " start = time.time()\n", 102 | " if index_type == 'exactl2':\n", 103 | " print(\"Building Exact L2 Index\")\n", 104 | " cpu_index = faiss.IndexFlatL2(d) # build the index\n", 105 | " elif index_type == 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m):\n", 106 | " print(\"Building D%d + HNSW%d + PQ%d Index\" % (d, hnsw_max_neighbors, m))\n", 107 | " cpu_index = faiss.IndexHNSWPQ(d, m, hnsw_max_neighbors)\n", 108 | " cpu_index.train(xb)\n", 109 | " elif index_type == f'hnsw{hnsw_max_neighbors}':\n", 110 | " print(\"Building HNSW%d Index\" % hnsw_max_neighbors)\n", 111 | " cpu_index = faiss.IndexHNSWFlat(d, hnsw_max_neighbors)\n", 112 | " else:\n", 113 | " raise Exception(f\"Unsupported Index: {index_type}\")\n", 114 | " \n", 115 | " cpu_index.add(xb) # add vectors to the index\n", 116 | " faiss.write_index(cpu_index, index_file)\n", 117 | " print(\"GPU Index build time= %0.3f sec\" % (time.time() - start))\n", 118 | "\n", 119 | " if use_gpu:\n", 120 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n", 121 | " else:\n", 122 | " index = cpu_index\n", 123 | " \n", 124 | " # Iterate over efSearch (HNSW search probes)\n", 125 | " efsearchlist = [16]\n", 126 | " for efsearch in efsearchlist:\n", 127 | " start = time.time()\n", 128 | " if index_type in ['hnsw32', 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m)]:\n", 129 | " index.hnsw.efSearch = efsearch\n", 130 | " print(\"Searching with Efsearch =\", index.hnsw.efSearch)\n", 131 | " Dist, Ind = index.search(xq, k)\n", 132 | " # print(\"GPU %d-NN search time= %f sec\" % (k, time.time() - start))\n", 133 | " if not path.isdir(root+\"neighbors/\"+model+'/'+index_type):\n", 134 | " makedirs(root+\"neighbors/\"+model+'/'+index_type)\n", 135 | " nn_dir = root+\"neighbors/\"+model+'/'+index_type+\"/\"+index_type+'_efsearch'+str(efsearch)+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(retrieval_dim)+\".csv\"\n", 136 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n", 137 | " \n", 138 | " del index, Dist, Ind" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "id": "f5efe3d2", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_8dim_hnsw32.index\n", 152 | "Building HNSW32 Index\n", 153 | "GPU Index build time= 57.573 sec\n", 154 | "Searching with Efsearch = 16\n", 155 | "13010622\n", 156 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_16dim_hnsw32.index\n", 157 | "Building HNSW32 Index\n", 158 | "GPU Index build time= 71.297 sec\n", 159 | "Searching with Efsearch = 16\n", 160 | "15721346\n", 161 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_32dim_hnsw32.index\n", 162 | "Building HNSW32 Index\n", 163 | "GPU Index build time= 74.260 sec\n", 164 | "Searching with Efsearch = 16\n", 165 | "16834028\n", 166 | "Generating index file: 
../../../inference_array/resnet50/index_files/mrl1K_64dim_hnsw32.index\n", 167 | "Building HNSW32 Index\n", 168 | "GPU Index build time= 77.627 sec\n", 169 | "Searching with Efsearch = 16\n", 170 | "17557581\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "nesting_list = [8, 16, 32, 64]\n", 176 | "get_nn(index_type, nesting_list)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 5, 182 | "id": "87ea3f22", 183 | "metadata": { 184 | "scrolled": false 185 | }, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "Loading index file: 1K_8dim_hnswpq_M32_pq-m8.index\n", 192 | "Searching with Efsearch = 16\n", 193 | "Loading index file: 1K_16dim_hnswpq_M32_pq-m8.index\n", 194 | "Searching with Efsearch = 16\n", 195 | "Loading index file: 1K_32dim_hnswpq_M32_pq-m8.index\n", 196 | "Searching with Efsearch = 16\n", 197 | "Loading index file: 1K_64dim_hnswpq_M32_pq-m8.index\n", 198 | "Searching with Efsearch = 16\n", 199 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_16dim_hnswpq_M32_pq-m16.index\n", 200 | "Building D16 + HNSW32 + PQ16 Index\n", 201 | "GPU Index build time= 180.430 sec\n", 202 | "Searching with Efsearch = 16\n", 203 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_32dim_hnswpq_M32_pq-m16.index\n", 204 | "Building D32 + HNSW32 + PQ16 Index\n", 205 | "GPU Index build time= 180.994 sec\n", 206 | "Searching with Efsearch = 16\n", 207 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_64dim_hnswpq_M32_pq-m16.index\n", 208 | "Building D64 + HNSW32 + PQ16 Index\n", 209 | "GPU Index build time= 182.374 sec\n", 210 | "Searching with Efsearch = 16\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "# k = 40 to generate Exact Ground Truth for 40-Recall@2048\n", 216 | "# nesting_list = [D] # fixed embedding dimension for RR models\n", 217 | "pq_m_values = [8, 16] # loop over PQ m values\n", 218 | "\n", 219 | "for m in pq_m_values:\n", 220 | " index_type = 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m)\n", 221 | " get_nn(index_type, nesting_list, m)" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "jupytext": { 227 | "cell_metadata_filter": "-all", 228 | "notebook_metadata_filter": "-all" 229 | }, 230 | "kernelspec": { 231 | "display_name": "Python 3 (ipykernel)", 232 | "language": "python", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | "version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.11.3" 246 | }, 247 | "vscode": { 248 | "interpreter": { 249 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc" 250 | } 251 | } 252 | }, 253 | "nbformat": 4, 254 | "nbformat_minor": 5 255 | } 256 | -------------------------------------------------------------------------------- /adanns/generate_nn/ivf-experiments.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "18c0eff4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import faiss\n", 12 | "import time\n", 13 | "import torch\n", 14 | "import pandas as pd\n", 15 | "from os import path, makedirs\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import sys\n", 18 | 
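For reference, the HNSW build-and-search pattern used by `hnsw_exactl2.ipynb` above, reduced to a self-contained sketch on random vectors; the sizes and the `efSearch` sweep are illustrative, not the notebook's exact settings:

```python
import faiss
import numpy as np

d, n_db, n_q = 64, 100_000, 1_000                  # illustrative sizes
xb = np.random.rand(n_db, d).astype('float32')     # database vectors
xq = np.random.rand(n_q, d).astype('float32')      # query vectors

index = faiss.IndexHNSWFlat(d, 32)                 # M = 32 max neighbors per node, as in the notebook
index.add(xb)                                      # graph construction; HNSW-Flat needs no train() step

for efsearch in [16, 64, 256]:                     # larger efSearch -> better recall, slower queries
    index.hnsw.efSearch = efsearch
    dist, ind = index.search(xq, 10)               # top-10 approximate neighbors per query
```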
"sys.path.append('../')\n", 19 | "from utils import load_embeddings" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "44133407", 25 | "metadata": {}, 26 | "source": [ 27 | "## CONFIG" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "id": "99918b2f", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "model = 'mrl' # mrl, rr\n", 38 | "arch = 'resnet50' # resnet18, resnet34, resnet50, resnet101, mobilenetv2\n", 39 | "root = f'../../../inference_array/{arch}'\n", 40 | "dataset = '1K' # 1K, 4K, V2, inat" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "id": "6a04810b", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "index_type = 'ivf' # ivfpq, ivfsq, lsh, pq, ivf, opq\n", 51 | "\n", 52 | "hnsw_max_neighbors = 32 # 8, 32\n", 53 | "\n", 54 | "k = 2048 # shortlist length, default set to max supported by FAISS\n", 55 | "\n", 56 | "nesting_list_dict = {\n", 57 | " 'vgg19': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],\n", 58 | " 'resnet18': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],\n", 59 | " 'resnet34': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],\n", 60 | " 'resnet50': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048],\n", 61 | " 'resnet101': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048],\n", 62 | " 'mobilenetv2': [10, 20, 40, 80, 160, 320, 640, 1280],\n", 63 | " 'convnext_tiny': [12, 24, 48, 96, 192, 384, 768]\n", 64 | "}\n", 65 | "\n", 66 | "nesting_list = nesting_list_dict[arch]\n", 67 | "\n", 68 | "ivf_configs = {\n", 69 | " 'C1': {'nlist': 2048},\n", 70 | " 'C2': {'nlist': 512},\n", 71 | " 'C3': {'nlist': 128},\n", 72 | " 'C4': {'nlist': 4096},\n", 73 | " 'C5': {'nlist': 8192},\n", 74 | " 'C6': {'nlist': 256},\n", 75 | " 'C7': {'nlist': 1024}\n", 76 | "}\n", 77 | "\n", 78 | "config_id = 'C7'\n", 79 | " \n", 80 | "# nbits = ivfpq_configs[config_id]['nbits'] # nbits used to represent centroid id; total possible is k* = 2**nbits\n", 81 | "\n", 82 | "nlist = ivf_configs[config_id]['nlist'] # how many Voronoi cells (must be >= k*)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "id": "91f822b8", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "if index_type == 'exactl2':\n", 93 | " use_gpu = 1 # GPU inference for exact search\n", 94 | "else:\n", 95 | " use_gpu = 0\n", 96 | "\n", 97 | "if use_gpu and faiss.get_num_gpus() == 0:\n", 98 | " raise Exception(\"GPU search is enabled but no GPU was found.\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "67db9767", 104 | "metadata": {}, 105 | "source": [ 106 | "## IVF Experiments" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 7, 112 | "id": "4ca6a7d7", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "*************************************************\n", 120 | "FF-8 IVF \n", 121 | "\n", 122 | "*************************************************\n", 123 | "Loaded DB: 1K_train_mrl0_e0_ff8-X.npy\n", 124 | "DB size: 39.10 MB\n", 125 | "DB: (1281167, 8), Q: (50000, 8)\n", 126 | "Indexing database: (1281167, 8)\n", 127 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist256_d8\n", 128 | "Training IVF Index: d=8, nlist=256\n", 129 | "IVF Index train time= 0.141 sec\n", 130 | "NProbe: 1\n", 131 | "Searching queries: (50000, 8)\n", 132 | "GPU 2048-NN search time= 2.962470 sec\n", 133 | "\n", 134 | "Writing NN csv to 
../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist256_nprobe1_2048shortlist_1K_d8.csv\n", 135 | "\n", 136 | "*************************************************\n", 137 | "Indexing database: (1281167, 8)\n", 138 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist512_d8\n", 139 | "Training IVF Index: d=8, nlist=512\n", 140 | "IVF Index train time= 0.272 sec\n", 141 | "NProbe: 1\n", 142 | "Searching queries: (50000, 8)\n", 143 | "GPU 2048-NN search time= 2.063317 sec\n", 144 | "\n", 145 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist512_nprobe1_2048shortlist_1K_d8.csv\n", 146 | "\n", 147 | "*************************************************\n", 148 | "Indexing database: (1281167, 8)\n", 149 | "Loading index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist1024_d8\n", 150 | "NProbe: 1\n", 151 | "Searching queries: (50000, 8)\n", 152 | "GPU 2048-NN search time= 0.369938 sec\n", 153 | "\n", 154 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist1024_nprobe1_2048shortlist_1K_d8.csv\n", 155 | "\n", 156 | "*************************************************\n", 157 | "Indexing database: (1281167, 8)\n", 158 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist2048_d8\n", 159 | "Training IVF Index: d=8, nlist=2048\n", 160 | "IVF Index train time= 6.959 sec\n", 161 | "NProbe: 1\n", 162 | "Searching queries: (50000, 8)\n", 163 | "GPU 2048-NN search time= 1.799566 sec\n", 164 | "\n", 165 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist2048_nprobe1_2048shortlist_1K_d8.csv\n", 166 | "\n", 167 | "*************************************************\n", 168 | "Indexing database: (1281167, 8)\n", 169 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist4096_d8\n", 170 | "Training IVF Index: d=8, nlist=4096\n", 171 | "IVF Index train time= 30.232 sec\n", 172 | "NProbe: 1\n", 173 | "Searching queries: (50000, 8)\n", 174 | "GPU 2048-NN search time= 1.340459 sec\n", 175 | "\n", 176 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist4096_nprobe1_2048shortlist_1K_d8.csv\n", 177 | "\n", 178 | "*************************************************\n", 179 | "Indexing database: (1281167, 8)\n", 180 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist8192_d8\n", 181 | "Training IVF Index: d=8, nlist=8192\n", 182 | "IVF Index train time= 74.361 sec\n", 183 | "NProbe: 1\n", 184 | "Searching queries: (50000, 8)\n", 185 | "GPU 2048-NN search time= 1.129385 sec\n", 186 | "\n", 187 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist8192_nprobe1_2048shortlist_1K_d8.csv\n", 188 | "\n", 189 | "*************************************************\n", 190 | "*************************************************\n", 191 | "FF-16 IVF \n", 192 | "\n", 193 | "*************************************************\n", 194 | "Loaded DB: 1K_train_mrl0_e0_ff16-X.npy\n", 195 | "DB size: 78.20 MB\n", 196 | "DB: (1281167, 16), Q: (50000, 16)\n", 197 | "Indexing database: (1281167, 16)\n", 198 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist256_d16\n", 199 | "Training IVF Index: d=16, nlist=256\n", 200 | "IVF Index train time= 0.247 sec\n", 201 | "NProbe: 1\n", 202 | "Searching queries: (50000, 16)\n", 203 | "GPU 2048-NN search time= 6.624912 sec\n", 204 | 
"\n", 205 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist256_nprobe1_2048shortlist_1K_d16.csv\n", 206 | "\n", 207 | "*************************************************\n", 208 | "Indexing database: (1281167, 16)\n", 209 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist512_d16\n", 210 | "Training IVF Index: d=16, nlist=512\n", 211 | "IVF Index train time= 0.319 sec\n", 212 | "NProbe: 1\n", 213 | "Searching queries: (50000, 16)\n", 214 | "GPU 2048-NN search time= 2.174936 sec\n", 215 | "\n", 216 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist512_nprobe1_2048shortlist_1K_d16.csv\n", 217 | "\n", 218 | "*************************************************\n", 219 | "Indexing database: (1281167, 16)\n", 220 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist1024_d16\n", 221 | "Training IVF Index: d=16, nlist=1024\n", 222 | "IVF Index train time= 0.973 sec\n", 223 | "NProbe: 1\n", 224 | "Searching queries: (50000, 16)\n", 225 | "GPU 2048-NN search time= 1.328670 sec\n", 226 | "\n", 227 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n", 228 | "\n", 229 | "*************************************************\n", 230 | "Indexing database: (1281167, 16)\n", 231 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist2048_d16\n", 232 | "Training IVF Index: d=16, nlist=2048\n", 233 | "IVF Index train time= 3.577 sec\n", 234 | "NProbe: 1\n", 235 | "Searching queries: (50000, 16)\n", 236 | "GPU 2048-NN search time= 1.002080 sec\n", 237 | "\n", 238 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist2048_nprobe1_2048shortlist_1K_d16.csv\n", 239 | "\n", 240 | "*************************************************\n", 241 | "Indexing database: (1281167, 16)\n", 242 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist4096_d16\n", 243 | "Training IVF Index: d=16, nlist=4096\n", 244 | "IVF Index train time= 17.197 sec\n", 245 | "NProbe: 1\n", 246 | "Searching queries: (50000, 16)\n", 247 | "GPU 2048-NN search time= 0.797544 sec\n", 248 | "\n", 249 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist4096_nprobe1_2048shortlist_1K_d16.csv\n", 250 | "\n", 251 | "*************************************************\n", 252 | "Indexing database: (1281167, 16)\n", 253 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist8192_d16\n", 254 | "Training IVF Index: d=16, nlist=8192\n", 255 | "IVF Index train time= 35.940 sec\n", 256 | "NProbe: 1\n", 257 | "Searching queries: (50000, 16)\n", 258 | "GPU 2048-NN search time= 0.695007 sec\n", 259 | "\n", 260 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist8192_nprobe1_2048shortlist_1K_d16.csv\n", 261 | "\n", 262 | "*************************************************\n", 263 | "*************************************************\n", 264 | "FF-512 IVF \n", 265 | "\n", 266 | "*************************************************\n", 267 | "Loaded DB: 1K_train_mrl0_e0_ff512-X.npy\n", 268 | "DB size: 2502.28 MB\n", 269 | "DB: (1281167, 512), Q: (50000, 512)\n", 270 | "Indexing database: (1281167, 512)\n", 271 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist256_d512\n", 272 | "Training IVF Index: d=512, 
nlist=256\n", 273 | "IVF Index train time= 1.085 sec\n", 274 | "NProbe: 1\n", 275 | "Searching queries: (50000, 512)\n", 276 | "GPU 2048-NN search time= 10.056288 sec\n", 277 | "\n", 278 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist256_nprobe1_2048shortlist_1K_d512.csv\n", 279 | "\n", 280 | "*************************************************\n", 281 | "Indexing database: (1281167, 512)\n", 282 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist512_d512\n", 283 | "Training IVF Index: d=512, nlist=512\n", 284 | "IVF Index train time= 2.630 sec\n", 285 | "NProbe: 1\n", 286 | "Searching queries: (50000, 512)\n", 287 | "GPU 2048-NN search time= 5.309681 sec\n", 288 | "\n", 289 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist512_nprobe1_2048shortlist_1K_d512.csv\n", 290 | "\n", 291 | "*************************************************\n", 292 | "Indexing database: (1281167, 512)\n", 293 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist1024_d512\n", 294 | "Training IVF Index: d=512, nlist=1024\n", 295 | "IVF Index train time= 8.684 sec\n", 296 | "NProbe: 1\n", 297 | "Searching queries: (50000, 512)\n", 298 | "GPU 2048-NN search time= 2.672392 sec\n", 299 | "\n", 300 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n", 301 | "\n", 302 | "*************************************************\n", 303 | "Indexing database: (1281167, 512)\n", 304 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist2048_d512\n", 305 | "Training IVF Index: d=512, nlist=2048\n", 306 | "IVF Index train time= 35.140 sec\n", 307 | "NProbe: 1\n", 308 | "Searching queries: (50000, 512)\n", 309 | "GPU 2048-NN search time= 1.662383 sec\n", 310 | "\n", 311 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist2048_nprobe1_2048shortlist_1K_d512.csv\n", 312 | "\n", 313 | "*************************************************\n", 314 | "Indexing database: (1281167, 512)\n", 315 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist4096_d512\n", 316 | "Training IVF Index: d=512, nlist=4096\n", 317 | "IVF Index train time= 130.656 sec\n", 318 | "NProbe: 1\n", 319 | "Searching queries: (50000, 512)\n", 320 | "GPU 2048-NN search time= 1.486253 sec\n", 321 | "\n", 322 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist4096_nprobe1_2048shortlist_1K_d512.csv\n", 323 | "\n", 324 | "*************************************************\n", 325 | "Indexing database: (1281167, 512)\n", 326 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist8192_d512\n", 327 | "Training IVF Index: d=512, nlist=8192\n" 328 | ] 329 | }, 330 | { 331 | "name": "stdout", 332 | "output_type": "stream", 333 | "text": [ 334 | "IVF Index train time= 397.688 sec\n", 335 | "NProbe: 1\n", 336 | "Searching queries: (50000, 512)\n", 337 | "GPU 2048-NN search time= 2.997688 sec\n", 338 | "\n", 339 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist8192_nprobe1_2048shortlist_1K_d512.csv\n", 340 | "\n", 341 | "*************************************************\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "Ds = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n", 347 | "num_cell_list = [256, 512, 1024, 2048, 4096, 8192] # IVF 
number of clusters\n", 348 | "nprobes = [1, 2, 4]\n", 349 | "\n", 350 | "for D in Ds:\n", 351 | " \n", 352 | " database, queryset, _, _, _, _ = load_embeddings(model, dataset, D, arch=arch)\n", 353 | "\n", 354 | " num_cell_list = [256, 512, 1024, 2048, 4096, 8192] # IVF num of clusters\n", 355 | " for num_cell in num_cell_list:\n", 356 | " if not path.isdir(root+'index_files/'+model+'/'):\n", 357 | " makedirs(root+'index_files/'+model+'/')\n", 358 | "\n", 359 | " print(\"Indexing database: \", database.shape)\n", 360 | " d = database.shape[1]\n", 361 | " index_file = root+'index_files/'+model+'/'+dataset+'_ivf'+'_nlist'+str(num_cell)+\"_d\"+str(d)\n", 362 | " \n", 363 | " # Load or build index\n", 364 | " if path.exists(index_file+'.index'):\n", 365 | " print(\"Loading index file: \" + index_file)\n", 366 | " cpu_index = faiss.read_index(index_file+'.index')\n", 367 | " else:\n", 368 | " if index_type == 'ivf':\n", 369 | " print(\"Generating index file: \" + index_file)\n", 370 | " # Generate IVF Index File\n", 371 | " quantizer = faiss.IndexFlatL2(d) # L2 quantizer to assign vectors to Voronoi cells\n", 372 | " cpu_index = faiss.IndexIVFFlat(quantizer, d, num_cell)\n", 373 | " print(f\"Training IVF Index: d={d}, nlist={num_cell}\")\n", 374 | " ivf_start = time.time()\n", 375 | " cpu_index.train(database)\n", 376 | " ivf_end = time.time()\n", 377 | " print(\"IVF Index train time= %0.3f sec\" % (ivf_end - ivf_start))\n", 378 | " elif index_type == 'exactl2':\n", 379 | " print(\"Building Exact L2 Index\")\n", 380 | " cpu_index = faiss.IndexFlatL2(d)\n", 381 | " cpu_index.add(database) # add vectors to the index\n", 382 | " faiss.write_index(cpu_index, index_file+'.index')\n", 383 | " if use_gpu:\n", 384 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n", 385 | " else:\n", 386 | " index = cpu_index\n", 387 | " for nprobe in nprobes: # nprobe for IVF\n", 388 | " print(f\"NProbe: {nprobe}\")\n", 389 | " print(\"Searching queries: \", queryset.shape)\n", 390 | " start = time.time()\n", 391 | " index.nprobe = nprobe\n", 392 | " Dist, Ind = index.search(queryset, k) # k -> shortlist length\n", 393 | " end = time.time() - start\n", 394 | " print(\"%d-NN search time= %f sec\\n\" % (k, end))\n", 395 | " if not path.isdir(root+\"neighbors/\"+model):\n", 396 | " makedirs(root+\"neighbors/\"+model)\n", 397 | " nn_dir = root+\"neighbors/\"+model+index_type+\"_nlist\"+str(num_cell)+f\"_nprobe{nprobe}_\"+str(k)+\"shortlist_\"+dataset+f\"_d{d}.csv\"\n", 398 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n", 399 | " print(\"Writing NN csv to %s\\n\" % (nn_dir))\n", 400 | " del index\n", 401 | " print(\"*************************************************\")" 402 | ] 403 | } 404 | ], 405 | "metadata": { 406 | "kernelspec": { 407 | "display_name": "Python 3 (ipykernel)", 408 | "language": "python", 409 | "name": "python3" 410 | }, 411 | "language_info": { 412 | "codemirror_mode": { 413 | "name": "ipython", 414 | "version": 3 415 | }, 416 | "file_extension": ".py", 417 | "mimetype": "text/x-python", 418 | "name": "python", 419 | "nbconvert_exporter": "python", 420 | "pygments_lexer": "ipython3", 421 | "version": "3.10.7" 422 | } 423 | }, 424 | "nbformat": 4, 425 | "nbformat_minor": 5 426 | } 427 | -------------------------------------------------------------------------------- /adanns/generate_nn/ivfpq_opq_kmeans.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "cc5c9e14", 7 | 
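The IVF sweep above only dumps 2048-NN shortlist CSVs to disk; retrieval accuracy is computed downstream (e.g., in compute_metrics.ipynb). As a rough, self-contained sketch of that step — the label-file names below are assumptions modeled on the `{dataset}_{split}_{config}-y.npy` convention used by adanns/utils.py, and the shortlist path is one of those printed in the log above — top-1 accuracy can be read off a single shortlist like this:

```python
import numpy as np
import pandas as pd

root = '../../../inference_array/resnet50/'

# Label arrays (assumed file names, following the {dataset}_{split}_{config}-y.npy convention)
db_labels = np.load(root + '1K_train_mrl0_e0_ff16-y.npy').ravel()     # (1281167,)
query_labels = np.load(root + '1K_val_mrl0_e0_ff16-y.npy').ravel()    # (50000,)

# One of the shortlists written above: one row of database ids per query
nn_csv = root + 'neighbors/ff-ivf/16/ivf_nlist1024_nprobe1_2048shortlist_1K_d16.csv'
neighbors = pd.read_csv(nn_csv, header=None).to_numpy()               # (50000, 2048)

# Top-1: does the label of the first retrieved database vector match the query label?
top1 = (db_labels[neighbors[:, 0]] == query_labels).mean()
print(f"Top-1 accuracy: {100 * top1:.2f}%")
```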
"metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import faiss\n", 12 | "import time\n", 13 | "import pandas as pd\n", 14 | "from os import path, makedirs\n", 15 | "import sys\n", 16 | "\n", 17 | "sys.path.append('../')\n", 18 | "from utils import get_duplicate_dist" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "20116c14", 24 | "metadata": {}, 25 | "source": [ 26 | "## Configuration Variables" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "id": "3be50514", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "root = '../../../inference_array/resnet50/'\n", 37 | "model = 'mrl/' # mrl/, rr/\n", 38 | "\n", 39 | "dataset = '1K' # 1K, 4K, V2\n", 40 | "index_type = 'ivfpq' # ivfpq, ivfsq, opq, kmeans, pq\n", 41 | "use_svd = False\n", 42 | " \n", 43 | "use_gpu = 0 # GPU inference for exact search. Disable for CPU search\n", 44 | "if use_gpu and faiss.get_num_gpus() == 0:\n", 45 | " raise Exception(\"GPU search is enabled but no GPU was found.\")" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "f4401b8c", 51 | "metadata": {}, 52 | "source": [ 53 | "## FAISS Index Building and NN Search" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "id": "67862f80", 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Loading database: 1K_train_mrl1_e0_ff2048-X.npy\n", 67 | "Loading queries: 1K_val_mrl1_e0_ff2048-X.npy\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "if model == \"mrl/\":\n", 73 | " config = 'mrl1_e0_ff2048'\n", 74 | "elif model == 'rr/':\n", 75 | " config = 'mrl0_e0_ff'\n", 76 | "else: \n", 77 | " raise Exception(\"Unsupported pretrained model.\")\n", 78 | "\n", 79 | "db_npy = dataset + '_train_' + config + '-X.npy'\n", 80 | "query_npy = dataset + '_val_' + config + '-X.npy'\n", 81 | "\n", 82 | "# ImageNetv2 is only a test set; set database to ImageNet-1K\n", 83 | "if dataset == 'V2':\n", 84 | " db_npy = '1K_train_' + config + '-X.npy'\n", 85 | "\n", 86 | "print(\"Loading database: \", db_npy)\n", 87 | "print(\"Loading queries: \", query_npy)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "21ee5576", 93 | "metadata": {}, 94 | "source": [ 95 | "## SVD Ablation" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "id": "cca2a7aa", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "def dim_reduce_pca(full_database, low_dim):\n", 106 | " subsampled_db = np.ascontiguousarray(full_database[0::10], dtype=np.float32)\n", 107 | " \n", 108 | " mat = faiss.PCAMatrix (subsampled_db.shape[1], low_dim)\n", 109 | " mat.train(subsampled_db)\n", 110 | " assert mat.is_trained\n", 111 | " \n", 112 | " return mat\n", 113 | "\n", 114 | "def get_svd_data(svd_low_dim):\n", 115 | " print(\"Using PCA for low-dim projection with d= \", svd_low_dim)\n", 116 | "\n", 117 | " database = np.load(root+db_npy)\n", 118 | " queryset = np.load(root+query_npy)\n", 119 | "\n", 120 | " print(\"Original Database: \", database.shape)\n", 121 | " svd_mat = dim_reduce_pca(database, svd_low_dim)\n", 122 | " database_svd_lowdim = svd_mat.apply(database)\n", 123 | " print(\"Low-d Database: \", database_svd_lowdim.shape)\n", 124 | " query_svd_lowdim = svd_mat.apply(queryset)\n", 125 | "\n", 126 | " faiss.normalize_L2(database_svd_lowdim)\n", 127 | " faiss.normalize_L2(query_svd_lowdim)\n", 128 | " \n", 129 | " return database_svd_lowdim, query_svd_lowdim\n", 130 | "\n", 
131 | "if use_svd:\n", 132 | " svd_low_dim = 1024\n", 133 | " db_npy_svd = dataset+'_train_'+config+\"_svd\"+str(svd_low_dim)+\"-X.npy\"\n", 134 | " query_npy_svd = dataset+'_val_'+config+\"_svd\"+str(svd_low_dim)+\"-X.npy\"\n", 135 | " database_svd_lowdim, query_svd_lowdim = get_svd_data(svd_low_dim)\n", 136 | " \n", 137 | " if not path.exists(root+db_npy_svd):\n", 138 | " np.save(root+db_npy_svd, database_svd_lowdim)\n", 139 | " if not path.exists(root+query_npy_svd):\n", 140 | " np.save(root+query_npy_svd, query_svd_lowdim)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "id": "5bdb1ac4", 146 | "metadata": {}, 147 | "source": [ 148 | "## Full Precision Embedding Loading (MRL)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "id": "31543d33", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "if model == 'mrl/':\n", 159 | " if use_svd:\n", 160 | " xb = np.load(root+db_npy_svd)\n", 161 | " assert np.count_nonzero(np.isnan(xb)) == 0\n", 162 | " xq = np.load(root+query_npy_svd)\n", 163 | " else:\n", 164 | " xb = np.load(root+db_npy)\n", 165 | " assert np.count_nonzero(np.isnan(xb)) == 0\n", 166 | " xq = np.load(root+query_npy)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "id": "2067c840", 172 | "metadata": {}, 173 | "source": [ 174 | "## Database Indexing and Search" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "id": "60e69cb5", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "k = 2048 # nearest neighbors shortlist length, default = 2048 = max value supported by FAISS\n", 185 | "#for any pq index, iterator is m. For kmeans, iterator is number of centroids\n", 186 | "iterator = [8, 16, 32, 64] # M for PQ\n", 187 | "Ds = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] # embedding dimensionality\n", 188 | "\n", 189 | "# IVF index specific params\n", 190 | "nprobes = [1] # number of search probes (cells to search)\n", 191 | "nbits = 8 # nbits used to represent centroid id; total possible is k* = 2**nbits\n", 192 | "nlist = 1024 # number of cells (must be >= k*)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "id": "a798dca2", 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "# Return index file name of given dimensionality based on index type\n", 203 | "def get_index_file(dim):\n", 204 | " if index_type in ['ivfpq', 'opq', 'pq']:\n", 205 | " size = 'm'+str(dim)+\"_d\"+str(D)\n", 206 | " elif index_type == 'kmeans':\n", 207 | " size = str(dim)+'ncentroid_'+str(D)+'d'\n", 208 | " elif index_type == 'kmeans-pca':\n", 209 | " size = str(dim)+'ncentroid_'+str(svd_low_dim)+'pcadim'\n", 210 | " else:\n", 211 | " raise Exception(\"Unsupported Index!\")\n", 212 | "\n", 213 | " index_file = root+'index_files/'+model+index_type+\"/\"+dataset+'_'+index_type+'_'+size\n", 214 | " if index_type in ['ivfpq', 'ivfsq']:\n", 215 | " index_file += \"_nbits\"+str(nbits)+'_nlist'+str(nlist)+'.index'\n", 216 | " \n", 217 | " return index_file" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 8, 223 | "id": "87ea3f22", 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "Indexing database (1281167, 8)\n", 231 | "Generating index file: 1K_ivfpq_m8_d8_nbits8_nlist1024.index\n", 232 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d8_nbits8_nlist1024.index\n", 233 | "Searching 
queries: (50000, 8)\n", 234 | "nprobe: 1\n", 235 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d8.csv\n", 236 | "\n", 237 | "Indexing database (1281167, 16)\n", 238 | "Generating index file: 1K_ivfpq_m8_d16_nbits8_nlist1024.index\n", 239 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d16_nbits8_nlist1024.index\n", 240 | "Searching queries: (50000, 16)\n", 241 | "nprobe: 1\n", 242 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n", 243 | "\n", 244 | "Indexing database (1281167, 16)\n", 245 | "Generating index file: 1K_ivfpq_m16_d16_nbits8_nlist1024.index\n", 246 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d16_nbits8_nlist1024.index\n", 247 | "Searching queries: (50000, 16)\n", 248 | "nprobe: 1\n", 249 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n", 250 | "\n", 251 | "Indexing database (1281167, 32)\n", 252 | "Generating index file: 1K_ivfpq_m8_d32_nbits8_nlist1024.index\n", 253 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d32_nbits8_nlist1024.index\n", 254 | "Searching queries: (50000, 32)\n", 255 | "nprobe: 1\n", 256 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n", 257 | "\n", 258 | "Indexing database (1281167, 32)\n", 259 | "Generating index file: 1K_ivfpq_m16_d32_nbits8_nlist1024.index\n", 260 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d32_nbits8_nlist1024.index\n", 261 | "Searching queries: (50000, 32)\n", 262 | "nprobe: 1\n", 263 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n", 264 | "\n", 265 | "Indexing database (1281167, 32)\n", 266 | "Generating index file: 1K_ivfpq_m32_d32_nbits8_nlist1024.index\n", 267 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d32_nbits8_nlist1024.index\n", 268 | "Searching queries: (50000, 32)\n", 269 | "nprobe: 1\n", 270 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n", 271 | "\n", 272 | "Indexing database (1281167, 64)\n", 273 | "Generating index file: 1K_ivfpq_m8_d64_nbits8_nlist1024.index\n", 274 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d64_nbits8_nlist1024.index\n", 275 | "Searching queries: (50000, 64)\n", 276 | "nprobe: 1\n", 277 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n", 278 | "\n", 279 | "Indexing database (1281167, 64)\n", 280 | "Generating index file: 1K_ivfpq_m16_d64_nbits8_nlist1024.index\n", 281 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d64_nbits8_nlist1024.index\n", 282 | "Searching queries: (50000, 64)\n", 283 | "nprobe: 1\n", 284 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n", 285 | "\n", 286 | "Indexing database (1281167, 64)\n", 287 | "Generating index file: 1K_ivfpq_m32_d64_nbits8_nlist1024.index\n", 288 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d64_nbits8_nlist1024.index\n", 289 | "Searching queries: (50000, 64)\n", 290 | "nprobe: 1\n", 291 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n", 292 | "\n", 293 | "Indexing database (1281167, 64)\n", 294 | "Generating index file: 1K_ivfpq_m64_d64_nbits8_nlist1024.index\n", 295 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d64_nbits8_nlist1024.index\n", 296 | "Searching queries: (50000, 64)\n", 297 | "nprobe: 1\n", 298 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n", 299 | 
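A quick aside on the storage side of this IVF-PQ sweep: with nbits=8, each database vector is stored as m one-byte sub-quantizer codes, so the index footprint is driven by m rather than by the embedding dimensionality d. A back-of-the-envelope sketch (it ignores the coarse quantizer, stored ids, and other bookkeeping):

```python
# ImageNet-1K database: raw float32 size vs. PQ code size, assuming nbits=8.
n, nbits = 1_281_167, 8
for d, m in [(2048, 64), (2048, 8), (16, 8)]:
    raw_mb  = n * d * 4 / 2**20            # float32 embeddings
    code_mb = n * m * nbits / 8 / 2**20    # PQ codes only
    print(f"d={d:4d}, m={m:2d}: {raw_mb:8.1f} MB raw -> {code_mb:6.1f} MB codes ({raw_mb / code_mb:.0f}x smaller)")
```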
"\n", 300 | "Indexing database (1281167, 128)\n", 301 | "Generating index file: 1K_ivfpq_m8_d128_nbits8_nlist1024.index\n", 302 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d128_nbits8_nlist1024.index\n", 303 | "Searching queries: (50000, 128)\n", 304 | "nprobe: 1\n", 305 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n", 306 | "\n", 307 | "Indexing database (1281167, 128)\n", 308 | "Generating index file: 1K_ivfpq_m16_d128_nbits8_nlist1024.index\n", 309 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d128_nbits8_nlist1024.index\n", 310 | "Searching queries: (50000, 128)\n", 311 | "nprobe: 1\n", 312 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n", 313 | "\n", 314 | "Indexing database (1281167, 128)\n", 315 | "Generating index file: 1K_ivfpq_m32_d128_nbits8_nlist1024.index\n", 316 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d128_nbits8_nlist1024.index\n", 317 | "Searching queries: (50000, 128)\n", 318 | "nprobe: 1\n", 319 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n", 320 | "\n", 321 | "Indexing database (1281167, 128)\n", 322 | "Generating index file: 1K_ivfpq_m64_d128_nbits8_nlist1024.index\n", 323 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d128_nbits8_nlist1024.index\n", 324 | "Searching queries: (50000, 128)\n", 325 | "nprobe: 1\n", 326 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n", 327 | "\n", 328 | "Indexing database (1281167, 256)\n", 329 | "Generating index file: 1K_ivfpq_m8_d256_nbits8_nlist1024.index\n", 330 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d256_nbits8_nlist1024.index\n", 331 | "Searching queries: (50000, 256)\n", 332 | "nprobe: 1\n", 333 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n", 334 | "\n", 335 | "Indexing database (1281167, 256)\n", 336 | "Generating index file: 1K_ivfpq_m16_d256_nbits8_nlist1024.index\n", 337 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d256_nbits8_nlist1024.index\n", 338 | "Searching queries: (50000, 256)\n", 339 | "nprobe: 1\n", 340 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n", 341 | "\n", 342 | "Indexing database (1281167, 256)\n", 343 | "Generating index file: 1K_ivfpq_m32_d256_nbits8_nlist1024.index\n", 344 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d256_nbits8_nlist1024.index\n", 345 | "Searching queries: (50000, 256)\n", 346 | "nprobe: 1\n", 347 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n", 348 | "\n", 349 | "Indexing database (1281167, 256)\n", 350 | "Generating index file: 1K_ivfpq_m64_d256_nbits8_nlist1024.index\n", 351 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d256_nbits8_nlist1024.index\n", 352 | "Searching queries: (50000, 256)\n", 353 | "nprobe: 1\n", 354 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n", 355 | "\n", 356 | "Indexing database (1281167, 512)\n", 357 | "Generating index file: 1K_ivfpq_m8_d512_nbits8_nlist1024.index\n", 358 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d512_nbits8_nlist1024.index\n", 359 | "Searching queries: (50000, 512)\n", 360 | "nprobe: 1\n", 361 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n", 362 | "\n", 363 | "Indexing database (1281167, 512)\n", 364 | "Generating 
index file: 1K_ivfpq_m16_d512_nbits8_nlist1024.index\n", 365 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d512_nbits8_nlist1024.index\n", 366 | "Searching queries: (50000, 512)\n", 367 | "nprobe: 1\n", 368 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n", 369 | "\n", 370 | "Indexing database (1281167, 512)\n", 371 | "Generating index file: 1K_ivfpq_m32_d512_nbits8_nlist1024.index\n", 372 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d512_nbits8_nlist1024.index\n", 373 | "Searching queries: (50000, 512)\n", 374 | "nprobe: 1\n", 375 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n", 376 | "\n", 377 | "Indexing database (1281167, 512)\n", 378 | "Generating index file: 1K_ivfpq_m64_d512_nbits8_nlist1024.index\n", 379 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d512_nbits8_nlist1024.index\n", 380 | "Searching queries: (50000, 512)\n", 381 | "nprobe: 1\n", 382 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n", 383 | "\n", 384 | "Indexing database (1281167, 1024)\n", 385 | "Generating index file: 1K_ivfpq_m8_d1024_nbits8_nlist1024.index\n", 386 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d1024_nbits8_nlist1024.index\n", 387 | "Searching queries: (50000, 1024)\n", 388 | "nprobe: 1\n", 389 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n", 390 | "\n", 391 | "Indexing database (1281167, 1024)\n", 392 | "Generating index file: 1K_ivfpq_m16_d1024_nbits8_nlist1024.index\n", 393 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d1024_nbits8_nlist1024.index\n", 394 | "Searching queries: (50000, 1024)\n", 395 | "nprobe: 1\n", 396 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n", 397 | "\n", 398 | "Indexing database (1281167, 1024)\n", 399 | "Generating index file: 1K_ivfpq_m32_d1024_nbits8_nlist1024.index\n", 400 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d1024_nbits8_nlist1024.index\n", 401 | "Searching queries: (50000, 1024)\n", 402 | "nprobe: 1\n", 403 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n", 404 | "\n", 405 | "Indexing database (1281167, 1024)\n", 406 | "Generating index file: 1K_ivfpq_m64_d1024_nbits8_nlist1024.index\n", 407 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d1024_nbits8_nlist1024.index\n", 408 | "Searching queries: (50000, 1024)\n", 409 | "nprobe: 1\n", 410 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n", 411 | "\n", 412 | "Indexing database (1281167, 2048)\n", 413 | "Generating index file: 1K_ivfpq_m8_d2048_nbits8_nlist1024.index\n", 414 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d2048_nbits8_nlist1024.index\n" 415 | ] 416 | }, 417 | { 418 | "name": "stdout", 419 | "output_type": "stream", 420 | "text": [ 421 | "Searching queries: (50000, 2048)\n", 422 | "nprobe: 1\n", 423 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n", 424 | "\n", 425 | "Indexing database (1281167, 2048)\n", 426 | "Generating index file: 1K_ivfpq_m16_d2048_nbits8_nlist1024.index\n", 427 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d2048_nbits8_nlist1024.index\n", 428 | "Searching queries: (50000, 2048)\n", 429 | "nprobe: 1\n", 430 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n", 431 | "\n", 432 | 
"Indexing database (1281167, 2048)\n", 433 | "Generating index file: 1K_ivfpq_m32_d2048_nbits8_nlist1024.index\n", 434 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d2048_nbits8_nlist1024.index\n", 435 | "Searching queries: (50000, 2048)\n", 436 | "nprobe: 1\n", 437 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n", 438 | "\n", 439 | "Indexing database (1281167, 2048)\n", 440 | "Generating index file: 1K_ivfpq_m64_d2048_nbits8_nlist1024.index\n", 441 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d2048_nbits8_nlist1024.index\n", 442 | "Searching queries: (50000, 2048)\n", 443 | "nprobe: 1\n", 444 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n", 445 | "\n" 446 | ] 447 | } 448 | ], 449 | "source": [ 450 | "# Iterate over embedding dimensionality of database and queries\n", 451 | "for D in Ds:\n", 452 | " # Iterate over subquantizers (PQ) or num_centroids (kmeans)\n", 453 | " for dim in iterator:\n", 454 | " if (index_type in ['pq', 'ivfpq', 'opq']):\n", 455 | " m = dim # number of sub-vectors to divide d into\n", 456 | " if D % m != 0 or m > D:\n", 457 | " continue\n", 458 | " \n", 459 | " if not path.isdir(root+'index_files/'+model+'/'):\n", 460 | " makedirs(root+'index_files/'+model+'/')\n", 461 | " \n", 462 | " index_file = get_index_file(dim) \n", 463 | " \n", 464 | " if model != 'mrl/':\n", 465 | " db_npy = dataset + '_train_' + config +str(D)+ '-X.npy'\n", 466 | " print(\"Loading database: \", db_npy)\n", 467 | " xb = np.load(root+db_npy)\n", 468 | " \n", 469 | " database = np.ascontiguousarray(xb[:,:D], dtype=np.float32)\n", 470 | " faiss.normalize_L2(database)\n", 471 | " print(\"Indexing database \", database.shape)\n", 472 | " print(\"Generating index file: \" + index_file.split(\"/\")[-1])\n", 473 | "\n", 474 | " # Load or build index\n", 475 | " if path.exists(index_file):\n", 476 | " print(\"Loading index file: \" + index_file)\n", 477 | " cpu_index = faiss.read_index(index_file)\n", 478 | "\n", 479 | " else:\n", 480 | " if (index_type == 'pq'):\n", 481 | " cpu_index = faiss.IndexPQ(D, dim, nbits)\n", 482 | " cpu_index.train(database)\n", 483 | " \n", 484 | " elif index_type in ['ivfpq', 'opq']:\n", 485 | " quantizer = faiss.IndexFlatL2(D) # L2 quantizer to assign vectors to Voronoi cells\n", 486 | " cpu_index = faiss.IndexIVFPQ(quantizer, D, nlist, m, nbits)\n", 487 | "\n", 488 | " # Learn a transformation of the embedding space with Optimized PQ\n", 489 | " if index_type == 'opq':\n", 490 | " opq_matrix = faiss.OPQMatrix(D, dim)\n", 491 | " cpu_index = faiss.IndexPreTransform (opq_matrix, cpu_index)\n", 492 | "\n", 493 | " print(\"Training %s Index: d=%d, m=%d, nbits=%d, nlist=%d\" %(index_type,D,m,nbits,nlist))\n", 494 | " ivfpq_start = time.time()\n", 495 | " cpu_index.train(database)\n", 496 | " ivfpq_end = time.time()\n", 497 | " print(\"%s Index train time= %0.3f sec\" % (index_type, ivfpq_end - ivfpq_start))\n", 498 | " \n", 499 | " elif index_type == 'ivfsq':\n", 500 | " quantizer = faiss.IndexFlatL2(D)\n", 501 | " cpu_index = faiss.IndexIVFScalarQuantizer(quantizer, D, nlist, qtype)\n", 502 | " print(\"Training IVFSQ Index: d=%d, qtype=%s, nlist=%d\" %(d,qtype,nlist))\n", 503 | " ivfsq_start = time.time()\n", 504 | " cpu_index.train(database)\n", 505 | " ivfsq_end = time.time()\n", 506 | " print(\"IVFSQ Index train time= %0.3f sec\" % (ivfsq_end - ivfsq_start))\n", 507 | " \n", 508 | "\n", 509 | " elif index_type in ['kmeans', 'kmeans-pca']:\n", 510 | " 
ncentroids = dim\n", 511 | " kmeans = faiss.Kmeans(D, ncentroids, verbose=True)\n", 512 | " if use_svd: \n", 513 | " database = np.load(root+db_npy_svd)\n", 514 | " \n", 515 | " print(\"Learning %s index with d=%d and k=%d\" %(index_type, D, ncentroids))\n", 516 | " kmeans_start = time.time()\n", 517 | " kmeans.train(database)\n", 518 | " kmeans_end = time.time()\n", 519 | " print(\"%s index train time= %0.3f sec\" % (index_type, kmeans_end - kmeans_start)) \n", 520 | " cpu_index = kmeans.index\n", 521 | " centroids_path = root+'kmeans/'+model+index_type+'_ncentroids'+str(ncentroids)+\"_\"+str(D)+'d'\"_\"+dataset+'.npy'\n", 522 | " print(\"Saving centroids: \", kmeans.centroids.shape)\n", 523 | " with open(centroids_path, 'wb') as f:\n", 524 | " np.save(f, kmeans.centroids)\n", 525 | "\n", 526 | " \n", 527 | " # add database embeddings to the index and save to disk \n", 528 | " cpu_index.add(database) \n", 529 | " faiss.write_index(cpu_index, index_file)\n", 530 | "\n", 531 | " if use_gpu:\n", 532 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n", 533 | " print(\"Moved to GPU\")\n", 534 | " else:\n", 535 | " index = cpu_index\n", 536 | "\n", 537 | " if model != 'mrl/':\n", 538 | " if use_svd:\n", 539 | " query_npy = query_npy_svd\n", 540 | " else:\n", 541 | " query_npy = dataset + '_val_' + config +str(D)+ '-X.npy'\n", 542 | " print(\"Loading queries: \", query_npy)\n", 543 | " xq = np.load(root+query_npy)\n", 544 | " queryset = np.ascontiguousarray(xq[:,:D], dtype=np.float32)\n", 545 | " faiss.normalize_L2(queryset)\n", 546 | " \n", 547 | " # kmeans indices are used downstream by adanns.ipynb \n", 548 | " if index_type not in ['kmeans', 'kmeans-pca']:\n", 549 | " print(\"Searching queries: \", queryset.shape)\n", 550 | " # Loop over search probes/ beam width\n", 551 | " for nprobe in nprobes:\n", 552 | " if index_type != 'pq':\n", 553 | " index.nprobe=nprobe\n", 554 | " print(\"nprobe: \", index.nprobe)\n", 555 | " \n", 556 | " Dist, Ind = index.search(queryset, k)\n", 557 | "\n", 558 | " if index_type == 'pq':\n", 559 | " nn_dir = root+\"neighbors/\"+model+index_type+\"/\"+index_type+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(D)+\".csv\"\n", 560 | " else:\n", 561 | " nn_dir = root+\"neighbors/\"+model+index_type+\"/\"+index_type+\"_m\"+str(dim)+'_nlist'+str(nlist)+'_nprobe'+str(nprobe)+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(D)+\".csv\"\n", 562 | " if not path.isdir(root+\"neighbors/\"+model+index_type+\"/\"):\n", 563 | " makedirs(root+\"neighbors/\"+model+index_type+\"/\")\n", 564 | " \n", 565 | " print(nn_dir.split(\"/\")[-1]+\"\\n\")\n", 566 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n", 567 | " \n", 568 | " # Optional test for duplicate distances in high M regimes\n", 569 | " #get_duplicate_dist(Ind, Dist)\n", 570 | " \n", 571 | " del Dist, Ind\n", 572 | " del index" 573 | ] 574 | } 575 | ], 576 | "metadata": { 577 | "jupytext": { 578 | "cell_metadata_filter": "-all", 579 | "notebook_metadata_filter": "-all" 580 | }, 581 | "kernelspec": { 582 | "display_name": "Python 3 (ipykernel)", 583 | "language": "python", 584 | "name": "python3" 585 | }, 586 | "language_info": { 587 | "codemirror_mode": { 588 | "name": "ipython", 589 | "version": 3 590 | }, 591 | "file_extension": ".py", 592 | "mimetype": "text/x-python", 593 | "name": "python", 594 | "nbconvert_exporter": "python", 595 | "pygments_lexer": "ipython3", 596 | "version": "3.11.3" 597 | }, 598 | "vscode": { 599 | "interpreter": { 600 | "hash": 
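On the OPQ branch in the cell above: wrapping the IVFPQ index in an IndexPreTransform with an OPQMatrix should be roughly equivalent to FAISS's index_factory shorthand, which is a more compact way to express the same structure. A hypothetical sketch, with parameter values mirroring this notebook's defaults:

```python
import faiss

D, m, nlist = 2048, 64, 1024  # embedding dim, sub-quantizers, IVF cells (this notebook's defaults)

# Roughly equivalent to OPQMatrix(D, m) wrapped around IndexIVFPQ(IndexFlatL2(D), D, nlist, m, 8)
index = faiss.index_factory(D, f"OPQ{m},IVF{nlist},PQ{m}")
```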
"51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc" 601 | } 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 5 606 | } 607 | -------------------------------------------------------------------------------- /adanns/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import faiss 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | valid_models = ['mrl', 'rr'] 7 | valid_datasets = ['1K', '4K', 'A' ,'R' ,'O', 'V2'] # ImageNet versions 8 | 9 | def load_embeddings(model, dataset, embedding_dim, arch='resnet50', using_gcloud=True): 10 | if model == 'mrl': 11 | config = 'mrl1_e0_ff2048' 12 | elif model == 'rr': # using rigid representation 13 | config = f'mrl0_e0_ff{embedding_dim}' 14 | else: 15 | raise ValueError(f'Model must be in {valid_models}') 16 | 17 | 18 | if dataset not in valid_datasets: 19 | raise ValueError(f'Dataset must be in {valid_datasets}') 20 | if dataset == 'V2': # ImageNetv2 is only a test set; set database to ImageNet-1K 21 | dataset = '1K' 22 | 23 | 24 | if using_gcloud: 25 | root = f'../../../inference_array/{arch}/' 26 | else: # using local machine 27 | root = f'../../inference_array/{arch}/' 28 | 29 | 30 | db_npy = dataset + '_train_' + config + '-X.npy' 31 | query_npy = dataset + '_val_' + config + '-X.npy' 32 | db_label_npy = dataset + '_train_' + config + '-y.npy' 33 | query_label_npy = dataset + '_val_' + config + '-y.npy' 34 | 35 | database = np.load(root + db_npy) 36 | queryset = np.load(root + query_npy) 37 | db_labels = np.load(root + db_label_npy) 38 | query_labels = np.load(root + query_label_npy) 39 | 40 | faiss.normalize_L2(database) 41 | faiss.normalize_L2(queryset) 42 | 43 | xb = np.ascontiguousarray(database[:, :embedding_dim], dtype=np.float32) 44 | xq = np.ascontiguousarray(queryset[:, :embedding_dim], dtype=np.float32) 45 | 46 | faiss.normalize_L2(xb) 47 | faiss.normalize_L2(xq) 48 | 49 | return database, queryset, db_labels, query_labels, xb, xq 50 | 51 | # Find Duplicate distances for low M (M=1 is k-means) on searched faiss index 52 | def get_duplicate_dist(Ind, Dist): 53 | k = 100 54 | duplicates = [] 55 | for i in range(50000): 56 | indices = Ind[i, :][:k] 57 | distances = Dist[i, :][:k] 58 | unique_distances = np.unique(distances, return_counts=1)[1] 59 | unique_distances = unique_distances[unique_distances != 1] # remove all 1s (i.e. unique distances) 60 | duplicates = np.append(duplicates, unique_distances) 61 | 62 | hist = plt.hist(duplicates, bins='auto') 63 | plt.title(model.split("/")[0]+", D="+str(D)+ ", M=" +str(dim)) 64 | plt.xlabel("Number of 100-NN with same neighbor distance values") 65 | plt.show() 66 | print(duplicates[duplicates > 2].sum()) 67 | 68 | # Normalize embeddings 69 | def normalize_embeddings(embeddings, dtype): 70 | if dtype == 'float32': 71 | print(np.linalg.norm(embeddings)) 72 | faiss.normalize_L2(embeddings) 73 | print(np.linalg.norm(embeddings)) 74 | return embeddings 75 | else: 76 | pass -------------------------------------------------------------------------------- /generate_embeddings/pytorch_inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code to evaluate MRL models on different validation benchmarks. 
3 | ''' 4 | import sys 5 | sys.path.append("../") # adding root folder to the path 6 | 7 | import torch 8 | import torchvision 9 | from torchvision import transforms 10 | from torchvision.models import * 11 | from torchvision import datasets 12 | from tqdm import tqdm 13 | 14 | from MRL import * 15 | from imagenetv2_pytorch import ImageNetV2Dataset 16 | from argparse import ArgumentParser 17 | from utils import * 18 | 19 | # nesting list is by default from 8 to 2048 in powers of 2, can be modified from here. 20 | BATCH_SIZE = 256 21 | IMG_SIZE = 256 22 | CENTER_CROP_SIZE = 224 23 | #NESTING_LIST=[2**i for i in range(3, 12)] 24 | ROOT="../../IMAGENET/" # path to IN1K 25 | DATASET_ROOT="../../datasets/" # 26 | 27 | parser=ArgumentParser() 28 | 29 | # model args 30 | parser.add_argument('--efficient', action='store_true', help='Efficient Flag') 31 | parser.add_argument('--mrl', action='store_true', help='To use MRL') 32 | parser.add_argument('--rep_size', type=int, default=2048, help='Rep. size for fixed feature model') 33 | parser.add_argument('--path', type=str, required=True, help='Path to .pt model checkpoint') 34 | parser.add_argument('--old_ckpt', action='store_true', help='To use our trained checkpoints') 35 | parser.add_argument('--workers', type=int, default=12, help='num workers for dataloader') 36 | parser.add_argument('--model_arch', type=str, default='resnet50', help='Loaded model arch') 37 | # dataset/eval args 38 | parser.add_argument('--tta', action='store_true', help='Test Time Augmentation Flag') 39 | parser.add_argument('--dataset', type=str, default='V1', help='Benchmarks') 40 | parser.add_argument('--save_logits', action='store_true', help='To save logits for model analysis') 41 | parser.add_argument('--save_softmax', action='store_true', help='To save softmax_probs for model analysis') 42 | parser.add_argument('--save_gt', action='store_true', help='To save ground truth for model analysis') 43 | parser.add_argument('--save_predictions', action='store_true', help='To save predicted labels for model analysis') 44 | # retrieval args 45 | parser.add_argument('--retrieval', action='store_true', help='flag for image retrieval array dumps') 46 | parser.add_argument('--random_sample_dim', type=int, default=4202000, help='number of random samples to slice from retrieval database') 47 | parser.add_argument('--retrieval_array_path', default='', help='path to save database and query arrays for retrieval', type=str) 48 | 49 | 50 | args = parser.parse_args() 51 | 52 | if args.model_arch == 'convnext_tiny': 53 | convnext_tiny = create_model( 54 | 'convnext_tiny', 55 | pretrained=False, 56 | num_classes=1000, 57 | drop_path_rate=0.1, 58 | layer_scale_init_value=1e-6, 59 | head_init_scale=1.0 60 | ) 61 | 62 | model_arch_dict = { 63 | 'vgg19': {'model': vgg19(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]}, 64 | 'resnet18': {'model': resnet18(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512]}, 65 | 'resnet34': {'model': resnet34(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512]}, 66 | 'resnet50': {'model': resnet50(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048]}, 67 | 'resnet101': {'model': resnet101(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048]}, 68 | 'mobilenetv2': {'model': mobilenet_v2(False), 'nest_list': [10, 20, 40, 80, 160, 320, 640, 1280]}, 69 | 'convnext_tiny': {'model': convnext_tiny, 'nest_list': [12, 24, 48, 96, 192, 384, 768]}, 70 | } 71 | 72 | model = model_arch_dict[args.model_arch]['model'] 73 | print(model) 74 
| 75 | if not args.old_ckpt: 76 | if args.mrl: 77 | if args.model_arch == 'mobilenetv2': 78 | model.classifier[1] = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], num_classes=1000, efficient=args.efficient) 79 | elif args.model_arch == 'convnext_tiny': 80 | model.head = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], num_classes=1000, efficient=args.efficient) 81 | else: 82 | model.fc = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], efficient=args.efficient) 83 | else: 84 | model.fc=FixedFeatureLayer(args.rep_size, 1000) # RR model 85 | else: 86 | if args.mrl: 87 | model = load_from_old_ckpt(model, args.efficient, model_arch_dict[args.model_arch]['nest_list']) 88 | else: 89 | model.fc=FixedFeatureLayer(args.rep_size, 1000) 90 | 91 | print(model.fc) 92 | apply_blurpool(model) 93 | model.load_state_dict(get_ckpt(args.path, args.model_arch)) # Since our models have a torch DDP wrapper, we modify keys to exclude first 7 chars (".module"). 94 | model = model.cuda() 95 | model.eval() 96 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 97 | test_transform = transforms.Compose([ 98 | transforms.Resize(IMG_SIZE), 99 | transforms.CenterCrop(CENTER_CROP_SIZE), 100 | transforms.ToTensor(), 101 | normalize]) 102 | 103 | # Model Eval 104 | if not args.retrieval: 105 | if args.dataset == 'V2': 106 | print("Loading Robustness Dataset") 107 | dataset = ImageNetV2Dataset("matched-frequency", transform=test_transform) 108 | elif args.dataset == '4K': 109 | train_path = DATASET_ROOT+"imagenet-4k/train/" 110 | test_path = DATASET_ROOT+"imagenet-4k/test/" 111 | train_dataset = datasets.ImageFolder(train_path, transform=test_transform) 112 | test_dataset = datasets.ImageFolder(test_path, transform=test_transform) 113 | else: 114 | print("Loading Imagenet 1K val set") 115 | dataset = torchvision.datasets.ImageFolder(ROOT+'val/', transform=test_transform) 116 | 117 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False) 118 | 119 | if args.mrl: 120 | _, top1_acc, top5_acc, total_time, num_images, m_score_dict, softmax_probs, gt, logits = evaluate_model( 121 | model, dataloader, show_progress_bar=True, nesting_list=model_arch_dict[args.model_arch]['nest_list'], tta=args.tta, imagenetA=args.dataset == 'A', imagenetR=args.dataset == 'R') 122 | else: 123 | _, top1_acc, top5_acc, total_time, num_images, m_score_dict, softmax_probs, gt, logits = evaluate_model( 124 | model, dataloader, show_progress_bar=True, nesting_list=None, tta=args.tta, imagenetA=args.dataset == 'A', imagenetR=args.dataset == 'R') 125 | 126 | tqdm.write('Evaluated {} images'.format(num_images)) 127 | confidence, predictions = torch.max(softmax_probs, dim=-1) 128 | if args.mrl: 129 | for i, nesting in enumerate(model_arch_dict[args.model_arch]['nest_list']): 130 | print("Rep. Size", "\t", nesting, "\n") 131 | tqdm.write(' Top-1 accuracy for {} : {:.2f}'.format(nesting, 100.0 * top1_acc[nesting])) 132 | tqdm.write(' Top-5 accuracy for {} : {:.2f}'.format(nesting, 100.0 * top5_acc[nesting])) 133 | tqdm.write(' Total time: {:.1f} (average time per image: {:.2f} ms)'.format(total_time, 1000.0 * total_time / num_images)) 134 | else: 135 | print("Rep. 
Size", "\t", args.rep_size, "\n") 136 | tqdm.write(' Evaluated {} images'.format(num_images)) 137 | tqdm.write(' Top-1 accuracy: {:.2f}%'.format(100.0 * top1_acc)) 138 | tqdm.write(' Top-5 accuracy: {:.2f}%'.format(100.0 * top5_acc)) 139 | tqdm.write(' Total time: {:.1f} (average time per image: {:.2f} ms)'.format(total_time, 1000.0 * total_time / num_images)) 140 | 141 | 142 | # saving torch tensor for model analysis... 143 | if args.save_logits or args.save_softmax or args.save_predictions: 144 | save_string = f"mrl={args.mrl}_efficient={args.efficient}_dataset={args.dataset}_tta={args.tta}" 145 | if args.save_logits: 146 | torch.save(logits, save_string+"_logits.pth") 147 | if args.save_predictions: 148 | torch.save(predictions, save_string+"_predictions.pth") 149 | if args.save_softmax: 150 | torch.save(softmax_probs, save_string+"_softmax.pth") 151 | 152 | if args.save_gt: 153 | torch.save(gt, f"gt_dataset={args.dataset}.pth") 154 | 155 | 156 | # Image Retrieval Inference 157 | else: 158 | if args.dataset == '1K': 159 | train_dataset = datasets.ImageFolder(ROOT+"train/", transform=test_transform) 160 | test_dataset = datasets.ImageFolder(ROOT+"val/", transform=test_transform) 161 | elif args.dataset == 'V2': 162 | train_dataset = None # V2 has only a test set 163 | test_dataset = ImageNetV2Dataset("matched-frequency", transform=test_transform) 164 | elif args.dataset == '4K': 165 | train_path = DATASET_ROOT+"imagenet-4k/train/" 166 | test_path = DATASET_ROOT+"imagenet-4k/test/" 167 | train_dataset = datasets.ImageFolder(train_path, transform=test_transform) 168 | test_dataset = datasets.ImageFolder(test_path, transform=test_transform) 169 | else: 170 | print("Error: unsupported dataset!") 171 | 172 | database_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False) 173 | queryset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False) 174 | 175 | config = args.model_arch+ "/" + args.dataset + "_val_mrl" + str(int(args.mrl)) + "_e" + str(int(args.efficient)) + "_rr" + str(int(args.rep_size)) 176 | print("Retrieval Config: " + config) 177 | generate_retrieval_data(model, queryset_loader, config, args.random_sample_dim, args.rep_size, args.retrieval_array_path, args.model_arch) 178 | 179 | if train_dataset is not None: 180 | config = args.model_arch+ "/" + args.dataset + "_train_mrl" + str(int(args.mrl)) + "_e" + str(int(args.efficient)) + "_rr" + str(int(args.rep_size)) 181 | print("Retrieval Config: " + config) 182 | generate_retrieval_data(model, database_loader, config, args.random_sample_dim, args.rep_size, args.retrieval_array_path, args.model_arch) 183 | -------------------------------------------------------------------------------- /generate_embeddings/run-inference.sh: -------------------------------------------------------------------------------- 1 | config=RR # RR or MRL 2 | 3 | if config=RR 4 | then 5 | for dim in 8 16 32 64 128 256 512 1024 2048 6 | do 7 | echo "Generating embeddings on RR-$dim" 8 | python pytorch_inference.py --retrieval --path=path/to/rr/model/weights.pt \ 9 | --model_arch='resnet50' --retrieval_array_path=rr_output_dir/ --dataset=1K --rep_size=$dim 10 | done 11 | else 12 | echo "Generating embeddings on MRL" 13 | python pytorch_inference.py --retrieval --path=path/to/mrl/model/weights.pt \ 14 | --model_arch='resnet50' --retrieval_array_path=mrl_output_dir/ --dataset=1K --mrl 15 | fi 16 | 
-------------------------------------------------------------------------------- /images/accuracy-compute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/accuracy-compute.png -------------------------------------------------------------------------------- /images/adanns-opq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/adanns-opq.png -------------------------------------------------------------------------------- /images/adanns-teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/adanns-teaser.png -------------------------------------------------------------------------------- /images/diskann-table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/diskann-table.png -------------------------------------------------------------------------------- /images/diskann-top1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/diskann-top1.png -------------------------------------------------------------------------------- /images/encoders.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/encoders.png -------------------------------------------------------------------------------- /images/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/flowchart.png -------------------------------------------------------------------------------- /images/opq-1k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/opq-1k.png -------------------------------------------------------------------------------- /images/opq-nq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/opq-nq.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | terminaltables 2 | pytorch_pfn_extras 3 | fastargs 4 | matplotlib 5 | sklearn 6 | imgcat 7 | pandas 8 | assertpy 9 | tqdm 10 | psutil 11 | webdataset 12 | torchmetrics 13 | git+https://github.com/modestyachts/ImageNetV2_pytorch 14 | grad-cam 15 | scikit-image 16 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.cuda.amp import autocast 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | from torchvision.models import * 7 | from 
tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | def get_ckpt(path, arch): 12 | ckpt=path 13 | ckpt = torch.load(ckpt, map_location='cpu') 14 | if arch == 'convnext_tiny': 15 | ckpt = ckpt['model'] 16 | plain_ckpt={} 17 | for k in ckpt.keys(): 18 | plain_ckpt[k[7:]] = ckpt[k] # remove the 'module' portion of key if model is Pytorch DDP 19 | return plain_ckpt 20 | 21 | 22 | class BlurPoolConv2d(torch.nn.Module): 23 | def __init__(self, conv): 24 | super().__init__() 25 | default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0 26 | filt = default_filter.repeat(conv.in_channels, 1, 1, 1) 27 | self.conv = conv 28 | self.register_buffer('blur_filter', filt) 29 | 30 | def forward(self, x): 31 | blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1), 32 | groups=self.conv.in_channels, bias=None) 33 | return self.conv.forward(blurred) 34 | 35 | 36 | def apply_blurpool(mod: torch.nn.Module): 37 | for (name, child) in mod.named_children(): 38 | if isinstance(child, torch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16): 39 | setattr(mod, name, BlurPoolConv2d(child)) 40 | else: apply_blurpool(child) 41 | 42 | 43 | ''' 44 | Retrieval utility methods. 45 | ''' 46 | activation = {} 47 | fwd_pass_x_list = [] 48 | fwd_pass_y_list = [] 49 | path_list = [] 50 | 51 | def get_activation(name): 52 | """ 53 | Get the activation from an intermediate point in the network. 54 | :param name: layer whose activation is to be returned 55 | :return: activation of layer 56 | """ 57 | def hook(model, input, output): 58 | activation[name] = output.detach() 59 | return hook 60 | 61 | 62 | def append_feature_vector_to_list(activation, label, rep_size, path): 63 | """ 64 | Append the feature vector to a list to later write to disk. 65 | :param activation: image feature vector from network 66 | :param label: ground truth label 67 | :param rep_size: representation size to be stored 68 | """ 69 | for i in range (activation.shape[0]): 70 | x = activation[i].cpu().detach().numpy() 71 | y = label[i].cpu().detach().numpy() 72 | fwd_pass_y_list.append(y) 73 | fwd_pass_x_list.append(x[:rep_size]) 74 | 75 | def dump_feature_vector_array_lists(config_name, rep_size, random_sample_dim, output_path): 76 | """ 77 | Save the database and query vector array lists to disk. 78 | :param config_name: config to specify during file write 79 | :param rep_size: representation size for fixed feature model 80 | :param random_sample_dim: to write a subset of database if required, e.g. 
to train an SVM on 100K samples 81 | :param output_path: path to dump database and query arrays after inference 82 | """ 83 | 84 | # save X (n x 2048), y (n x 1) to disk, where n = num_samples 85 | X_fwd_pass = np.asarray(fwd_pass_x_list, dtype=np.float32) 86 | y_fwd_pass = np.asarray(fwd_pass_y_list, dtype=np.uint16).reshape(-1,1) 87 | 88 | if random_sample_dim < X_fwd_pass.shape[0]: 89 | random_indices = np.random.choice(X_fwd_pass.shape[0], size=random_sample_dim, replace=False) 90 | random_X = X_fwd_pass[random_indices, :] 91 | random_y = y_fwd_pass[random_indices, :] 92 | print("Writing random samples to disk with dim [%d x 2048] " % random_sample_dim) 93 | else: 94 | random_X = X_fwd_pass 95 | random_y = y_fwd_pass 96 | print("Writing %s to disk with dim [%d x %d]" % (str(config_name)+"_X", X_fwd_pass.shape[0], rep_size)) 97 | 98 | print("Unique entries: ", len(np.unique(random_y))) 99 | np.save(output_path+str(config_name)+'-X.npy', random_X) 100 | np.save(output_path+str(config_name)+'-y.npy', random_y) 101 | 102 | 103 | def generate_retrieval_data(model, data_loader, config, random_sample_dim, rep_size, output_path, model_arch): 104 | """ 105 | Iterate over data in dataloader, get feature vector from model inference, and save to array to dump to disk. 106 | :param model: ResNet50 model loaded from disk 107 | :param data_loader: loader for database or query set 108 | :param config: name of configuration for writing arrays to disk 109 | :param random_sample_dim: to write a subset of database if required, e.g. to train an SVM on 100K samples 110 | :param rep_size: representation size for fixed feature model 111 | :param output_path: path to dump database and query arrays after inference 112 | """ 113 | model.eval() 114 | if model_arch == 'vgg19': 115 | model.classifier[5].register_forward_hook(get_activation('avgpool')) 116 | elif model_arch == 'mobilenetv2': 117 | model.classifier[0].register_forward_hook(get_activation('avgpool')) 118 | else: 119 | model.avgpool.register_forward_hook(get_activation('avgpool')) 120 | print("Dataloader len: ", len(data_loader)) 121 | 122 | with torch.no_grad(): 123 | with autocast(): 124 | for i_batch, (images, target) in enumerate(data_loader): 125 | output = model(images.cuda()) 126 | path = None 127 | append_feature_vector_to_list(activation['avgpool'].squeeze(), target.cuda(), rep_size, path) 128 | if (i_batch) % int(len(data_loader)/5) == 0: 129 | print("Finished processing: %f %%" % (i_batch / len(data_loader) * 100)) 130 | dump_feature_vector_array_lists(config, rep_size, random_sample_dim, output_path) 131 | 132 | # re-initialize empty lists 133 | global fwd_pass_x_list 134 | global fwd_pass_y_list 135 | global path_list 136 | fwd_pass_x_list = [] 137 | fwd_pass_y_list = [] 138 | path_list = [] 139 | 140 | ''' 141 | Load pretrained models saved with old notation. 142 | ''' 143 | class SingleHeadNestedLinear(nn.Linear): 144 | """ 145 | Class for MRL-E model. 
146 | """ 147 | 148 | def __init__(self, nesting_list: List, num_classes=1000, **kwargs): 149 | super(SingleHeadNestedLinear, self).__init__(nesting_list[-1], num_classes, **kwargs) 150 | self.nesting_list=nesting_list 151 | self.num_classes=num_classes # Number of classes for classification 152 | 153 | def forward(self, x): 154 | nesting_logits = () 155 | for i, num_feat in enumerate(self.nesting_list): 156 | if not (self.bias is None): 157 | logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t()) + self.bias 158 | else: 159 | logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t()) 160 | nesting_logits+= (logit,) 161 | return nesting_logits 162 | 163 | class MultiHeadNestedLinear(nn.Module): 164 | """ 165 | Class for MRL model. 166 | """ 167 | def __init__(self, nesting_list: List, num_classes=1000, **kwargs): 168 | super(MultiHeadNestedLinear, self).__init__() 169 | self.nesting_list=nesting_list 170 | self.num_classes=num_classes # Number of classes for classification 171 | for i, num_feat in enumerate(self.nesting_list): 172 | setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs)) 173 | 174 | def forward(self, x): 175 | nesting_logits = () 176 | for i, num_feat in enumerate(self.nesting_list): 177 | nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),) 178 | return nesting_logits 179 | 180 | def load_from_old_ckpt(model, efficient, nesting_list): 181 | if efficient: 182 | model.fc=SingleHeadNestedLinear(nesting_list) 183 | else: 184 | model.fc=MultiHeadNestedLinear(nesting_list) 185 | 186 | return model 187 | --------------------------------------------------------------------------------