├── LICENSE
├── MRL.py
├── README.md
├── __init__.py
├── adanns
│   ├── README.md
│   ├── ablations
│   │   ├── centroid_recall.ipynb
│   │   ├── class_matches_per_cluster.ipynb
│   │   ├── cluster-distribution.ipynb
│   │   ├── plots.ipynb
│   │   └── relative_contrast.ipynb
│   ├── adanns-ivf-optimized.ipynb
│   ├── adanns-ivf-unoptimized.ipynb
│   ├── compute_metrics.ipynb
│   ├── diskann
│   │   ├── README.md
│   │   └── adanns-diskann.ipynb
│   ├── dpr-nq
│   │   ├── README.md
│   │   └── adanns-nq.ipynb
│   ├── generate_nn
│   │   ├── hnsw_exactl2.ipynb
│   │   ├── ivf-experiments.ipynb
│   │   └── ivfpq_opq_kmeans.ipynb
│   └── utils.py
├── generate_embeddings
│   ├── pytorch_inference.py
│   └── run-inference.sh
├── images
│   ├── accuracy-compute.png
│   ├── adanns-opq.png
│   ├── adanns-teaser.png
│   ├── diskann-table.png
│   ├── diskann-top1.png
│   ├── encoders.png
│   ├── flowchart.png
│   ├── opq-1k.png
│   └── opq-nq.png
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 RAIVN Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MRL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Type, Any, Callable, Union, List, Optional
4 |
5 | '''
6 | Loss function for Matryoshka Representation Learning
7 | '''
8 |
9 | class Matryoshka_CE_Loss(nn.Module):
10 | def __init__(self, relative_importance=None, **kwargs):
11 | super(Matryoshka_CE_Loss, self).__init__()
12 | self.criterion = nn.CrossEntropyLoss(**kwargs)
13 | self.relative_importance= relative_importance
14 |
15 | def forward(self, output, target):
16 | loss=0
17 | N= len(output)
18 | for i in range(N):
19 | rel = 1. if self.relative_importance is None else self.relative_importance[i]
20 | loss+= rel*self.criterion(output[i], target)
21 | return loss
22 |
23 | class MRL_Linear_Layer(nn.Module):
24 | def __init__(self, nesting_list: List, num_classes=1000, efficient=False, **kwargs):
25 | super(MRL_Linear_Layer, self).__init__()
26 | self.nesting_list=nesting_list
27 | self.num_classes=num_classes # Number of classes for classification
28 | self.efficient = efficient
29 | if self.efficient:
30 | setattr(self, f"nesting_classifier_{0}", nn.Linear(nesting_list[-1], self.num_classes, **kwargs))
31 | else:
32 | for i, num_feat in enumerate(self.nesting_list):
33 | setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))
34 |
35 | def reset_parameters(self):
36 | if self.efficient:
37 | self.nesting_classifier_0.reset_parameters()
38 | else:
39 | for i in range(len(self.nesting_list)):
40 | getattr(self, f"nesting_classifier_{i}").reset_parameters()
41 |
42 |
43 | def forward(self, x):
44 | nesting_logits = ()
45 | for i, num_feat in enumerate(self.nesting_list):
46 | if self.efficient:
47 | if self.nesting_classifier_0.bias is None:
48 | nesting_logits+= (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()), )
49 | else:
50 | nesting_logits+= (torch.matmul(x[:, :num_feat], (self.nesting_classifier_0.weight[:, :num_feat]).t()) + self.nesting_classifier_0.bias, )
51 | else:
52 | nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
53 |
54 | return nesting_logits
55 |
56 |
57 | class FixedFeatureLayer(nn.Linear):
58 | '''
59 | For our fixed feature baseline, we just replace the classification layer with the following.
60 |     It effectively just looks at the first "in_features" dimensions for classification.
61 | '''
62 |
63 | def __init__(self, in_features, out_features, **kwargs):
64 | super(FixedFeatureLayer, self).__init__(in_features, out_features, **kwargs)
65 |
66 | def forward(self, x):
67 | if not (self.bias is None):
68 | out = torch.matmul(x[:, :self.in_features], self.weight.t()) + self.bias
69 | else:
70 | out = torch.matmul(x[:, :self.in_features], self.weight.t())
71 | return out
72 |
73 |
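74 | # Hypothetical usage sketch (the batch size, class count, and nesting list below
75 | # are illustrative assumptions): the MRL head returns one logit tensor per nesting
76 | # dimension, and the Matryoshka CE loss sums the cross-entropy terms over all granularities.
77 | if __name__ == "__main__":
78 |     nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
79 |     head = MRL_Linear_Layer(nesting_list, num_classes=1000, efficient=True)
80 |     criterion = Matryoshka_CE_Loss()
81 |
82 |     features = torch.randn(4, 2048)          # stand-in for backbone features
83 |     targets = torch.randint(0, 1000, (4,))
84 |     logits = head(features)                  # tuple of len(nesting_list) logit tensors
85 |     loss = criterion(logits, targets)
86 |     loss.backward()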
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AdANNS: A Framework for Adaptive Semantic Search](https://arxiv.org/abs/2305.19435)
2 | _Aniket Rege, Aditya Kusupati, Sharan Ranjit S, Alan Fan, Qingqing Cao, Sham M. Kakade, Prateek Jain, Ali Farhadi_
3 |
4 | Learned representations are used in multiple downstream tasks like web-scale search & classification. However, they are flat & rigid—information is diffused across dimensions and cannot be adaptively deployed without large post-hoc overhead. We propose the use of adaptive representations to improve approximate nearest neighbour search (ANNS) and introduce a new paradigm, AdANNS, to achieve it at scale leveraging matryoshka representations (MRs). We compare AdANNS to ANNS structures built on independently trained rigid representations (RRs).
5 |
6 |
7 |
8 |
9 |
10 | This repository contains code for AdANNS construction and inference built on top of Matryoshka Representations (MRs). The training pipeline to generate MRs and RRs can be found [here](https://github.com/RAIVNLab/MRL). The repository is organized as follows:
11 |
12 | 1. Set up
13 | 2. Inference to generate MRs and RRs
14 | 3. AdANNS Experiments
15 |
16 |
17 | ## Set Up
18 | Install the requirements file in this directory with pip. Note that a Python 3 distribution is required:
19 | ```
20 | pip3 install -r requirements.txt
21 | ```
22 |
23 | ## Inference on Trained Models
24 | We primarily utilize ResNet-50 MRL and Rigid encoders ("Fixed-Feature" in original MRL terminology) for the bulk of our experimentation. We also utilize trained MRL ResNet18/34/101 and ConvNeXT encoders for ablation studies. Inference on trained models to generate the MR and RR embeddings used for downstream ANNS is provided in `generate_embeddings/pytorch_inference.py`, and is explained in more detail in the [original MRL repository](https://github.com/RAIVNLab/MRL).
25 |
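As a minimal, hypothetical illustration (the embedding file name follows the convention used in the notebooks here; the truncation dimension is an arbitrary choice), an MR embedding is adaptively deployed by truncating it to a nested dimensionality and re-normalizing before any index is built:

```python
import numpy as np
import faiss

# 2048-d MR embeddings produced by generate_embeddings/pytorch_inference.py (assumed path)
xb = np.load('1K_train_mrl1_e0_ff2048-X.npy').astype(np.float32)

# The first d dimensions of an MR embedding are themselves a usable representation
d = 64
xb_d = np.ascontiguousarray(xb[:, :d])
faiss.normalize_L2(xb_d)  # re-normalize after truncation, as done throughout adanns/
```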
26 |
27 | ## AdANNS
28 | `cd adanns`
29 |
30 | We provide code showcasing AdANNS in action on a simple yet powerful search data structure – IVF (AdANNS-IVF) – and on industry-default quantization – OPQ (AdANNS-OPQ) – followed by its effectiveness on modern-day ANNS composite indices like IVFOPQ (AdANNS-IVFOPQ) and DiskANN (AdANNS-DiskANN).
31 |
32 | A more detailed walkthrough of AdANNS can be found in [`adanns/`](adanns/)
33 |
34 |
35 |
36 |
37 |
38 | ## Citation
39 | If you find this project useful in your research, please consider citing:
40 | ```
41 | @article{rege2023adanns,
42 | title={AdANNS: A Framework for Adaptive Semantic Search},
43 | author={Aniket Rege and Aditya Kusupati and Sharan Ranjit S and Alan Fan and Qingqing Cao and Sham Kakade and Prateek Jain and Ali Farhadi},
44 | year={2023},
45 | eprint={2305.19435},
46 | archivePrefix={arXiv},
47 | primaryClass={cs.LG}
48 | }
49 | ```
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/__init__.py
--------------------------------------------------------------------------------
/adanns/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | The directory structure is organized as follows:
3 | 1. AdANNS-IVF (see Figure 2) in an [optimized](#optimized-adanns-ivf) and [unoptimized](#unoptimized-adanns-ivf) fashion
4 | 2. [Non-Adaptive Pipeline](#non-adaptive-search) on ImageNet-1K
5 | 3. AdANNS with DiskANN, a graph-based memory-SSD hybrid ANNS index, in [diskann](diskann/README.md)
6 | 4. AdANNS with Dense Passage Retriever (DPR) on [Natural Questions](https://ai.google.com/research/NaturalQuestions) (NQ) in [dpr-nq](dpr-nq/README.md)
7 |
8 |
9 |
10 |
11 |
12 | ## Optimized AdANNS-IVF
13 | [`adanns-ivf-optimized.ipynb`](adanns-ivf-optimized.ipynb) implements _AdANNS_-IVF in an optimized fashion (native Faiss) and is drawn from a [Faiss case study](https://gist.github.com/mdouze/8c5ab227c0f7d9d7c15cf92a391dcbe5#file-demo_independent_ivf_dimension-ipynb). During index construction, $d_c$ is used for coarse quantization (centroid assignment), and all database vectors are pre-assigned to their appropriate clusters with this coarse quantizer. We also typically learn a fine quantizer (OPQ) with $d_s$ for fast distance computation. During the search phase, for a given query, nprobe clusters are shortlisted with $d_c$, and a linear scan over all shortlisted clusters is performed via a PQ lookup table constructed from $d_s$.
14 |
15 | Note that the optimized version currently supports only $d_c \ge d_s$ during index construction, because centroids learned with $d_c$ are sliced down to $d_s$.
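As a rough sketch (not a drop-in replacement for the notebook), the snippet below mirrors the optimized flow using plain PQ instead of OPQ for brevity, with placeholder embedding paths and assumed values for $d_c$, $d_s$, `n_cell`, and `M`.

```python
import faiss
import numpy as np
from faiss.contrib.ivf_tools import add_preassigned, search_preassigned

d_c, d_s, n_cell, M, k = 2048, 256, 1024, 64, 100   # requires d_s <= d_c and d_s % M == 0

def slice_and_norm(x, d):
    # For MRs, the first d dimensions are themselves a valid representation
    x = np.ascontiguousarray(x[:, :d], dtype=np.float32)
    faiss.normalize_L2(x)
    return x

xb = np.load('path/to/db_embeddings.npy')     # (N, 2048) matryoshka database embeddings
xq = np.load('path/to/query_embeddings.npy')  # (Q, 2048) matryoshka query embeddings
xb_c, xq_c = slice_and_norm(xb, d_c), slice_and_norm(xq, d_c)
xb_s, xq_s = slice_and_norm(xb, d_s), slice_and_norm(xq, d_s)

# 1) Learn the IVF centroids (coarse quantizer) with d_c
index_c = faiss.index_factory(d_c, f'IVF{n_cell},Flat')
index_c.train(xb_c)
quantizer_c = index_c.quantizer

# 2) Build the d_s search index; its coarse quantizer reuses the d_c centroids sliced to d_s
index_s = faiss.index_factory(d_s, f'IVF{n_cell},PQ{M}')
centroids_s = quantizer_c.reconstruct_n(0, quantizer_c.ntotal)[:, :d_s].copy()
index_s.quantizer.add(centroids_s)
index_s.train(xb_s)                           # PQ codebooks are learned on d_s vectors

# 3) Pre-assign database vectors to cells with d_c, then add their d_s codes
_, assign = quantizer_c.search(xb_c, 1)
add_preassigned(index_s, xb_s, assign.ravel())

# 4) Search: shortlist nprobe cells with d_c, scan them with the d_s PQ lookup tables
nprobe = 4
index_s.nprobe = nprobe
_, cells = quantizer_c.search(xq_c, nprobe)
dist, ids = search_preassigned(index_s, xq_s, k, cells, None)
```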
16 |
17 | ## Unoptimized AdANNS-IVF
18 | [`adanns-ivf-unoptimized.ipynb`](adanns-ivf-unoptimized.ipynb) implements a fully-decoupled ANNS pipeline; $d_c < d_s$ is now possible. This is carried out in a highly unoptimized fashion: during construction, we utilize *k-means* indices pre-constructed over the database with $d_c$ and build an inverted-file lookup that maps each database item to its closest cluster centroid. At search time, we iterate over each centroid and perform a batched search, with $d_s$, over every query vector assigned to that centroid. This approach scales linearly in cost with the number of centroids, $\mathcal{O}(N_C)$, and results in higher wall-clock search times than the optimized version above; a sketch follows.
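The snippet below is a minimal sketch of this decoupled pipeline under assumed paths and dimensions; it runs k-means from scratch rather than loading a pre-built index as the notebook does.

```python
import faiss
import numpy as np

d_c, d_s, n_cell, k = 64, 2048, 1024, 100   # d_c < d_s is allowed here

def slice_and_norm(x, d):
    x = np.ascontiguousarray(x[:, :d], dtype=np.float32)
    faiss.normalize_L2(x)
    return x

xb = np.load('path/to/db_embeddings.npy')
xq = np.load('path/to/query_embeddings.npy')
xb_c, xq_c = slice_and_norm(xb, d_c), slice_and_norm(xq, d_c)
xb_s, xq_s = slice_and_norm(xb, d_s), slice_and_norm(xq, d_s)

# Cluster the database with d_c and assign every database/query vector to its nearest centroid
kmeans = faiss.Kmeans(d_c, n_cell, niter=20)
kmeans.train(xb_c)
_, db_assign = kmeans.index.search(xb_c, 1)
_, q_assign = kmeans.index.search(xq_c, 1)
db_assign, q_assign = db_assign.ravel(), q_assign.ravel()

# For each cluster, run an exact (IndexFlatL2) scan over its members with the d_s vectors
neighbors = np.full((xq.shape[0], k), -1, dtype=np.int64)
for c in np.unique(q_assign):
    db_ids = np.where(db_assign == c)[0]
    q_ids = np.where(q_assign == c)[0]
    flat = faiss.IndexFlatL2(d_s)
    flat.add(xb_s[db_ids])
    _, I = flat.search(xq_s[q_ids], min(k, len(db_ids)))
    neighbors[q_ids, :I.shape[1]] = db_ids[I]   # map cluster-local ids back to database ids
```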
19 |
20 | ## Non-Adaptive Search
21 |
22 |
23 |
24 |
25 | For all other experiments with Faiss (IVF-XX, MG-IVF-XX, OPQ, IVFOPQ, HNSWOPQ) on ImageNet-1K, we have a 3-stage pipeline:
26 | 1. Encoder inference to generate embeddings for database and queries: see [run-inference.sh](../generate_embeddings/run-inference.sh)
27 | 2. Index construction (database) and search (queries) on generated embeddings to generate k-NN arrays: see [generate_nn](generate_nn/)
28 | 3. Metric computation or Ablations on top of k-NN arrays: see [compute_metrics.ipynb](compute_metrics.ipynb) and [ablations](ablations/)
--------------------------------------------------------------------------------
/adanns/ablations/centroid_recall.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "21d1fdb4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import faiss\n",
12 | "import torch\n",
13 | "import sys\n",
14 | "sys.path.append('../')\n",
15 | "from utils.py import load_embeddings"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 2,
21 | "id": "cfc615b9",
22 | "metadata": {},
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "Loaded queries: (50000, 2048)\n"
29 | ]
30 | }
31 | ],
32 | "source": [
33 | "root = '../../../inference_array/resnet50/'\n",
34 | "model = 'mrl' # mrl, ff\n",
35 | "dataset = '1K' # 1K, 4K, V2\n",
36 | "index_type = 'kmeans'\n",
37 | "d = 2048 # cluster construction dim\n",
38 | "\n",
39 | "_, queryset, _, _, _, _ = load_embeddings(model, dataset, d)\n",
40 | "faiss.normalize_L2(queryset)\n",
41 | "print(\"Loaded queries:\", queryset.shape)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "id": "39d931c4",
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "name": "stdout",
52 | "output_type": "stream",
53 | "text": [
54 | "Clusters: 1024\n",
55 | "\n",
56 | "Number of probes: 1\n",
57 | "['top1', 'top2', 'top4', 'top5', 'top10']\n",
58 | "[0.7355 0.7355 0.7355 0.7355 0.7355]\n",
59 | "[0.84186 0.84186 0.84186 0.84186 0.84186]\n",
60 | "[0.89878 0.89878 0.89878 0.89878 0.89878]\n",
61 | "[0.93288 0.93288 0.93288 0.93288 0.93288]\n",
62 | "[0.95518 0.95518 0.95518 0.95518 0.95518]\n",
63 | "[0.96946 0.96946 0.96946 0.96946 0.96946]\n",
64 | "[0.9822 0.9822 0.9822 0.9822 0.9822]\n",
65 | "[0.99078 0.99078 0.99078 0.99078 0.99078]\n",
66 | "[1. 1. 1. 1. 1.]\n",
67 | "\n",
68 | "Number of probes: 2\n",
69 | "['top1', 'top2', 'top4', 'top5', 'top10']\n",
70 | "[0.7355 0.83386 0.83386 0.83386 0.83386]\n",
71 | "[0.84186 0.92572 0.92572 0.92572 0.92572]\n",
72 | "[0.89878 0.96662 0.96662 0.96662 0.96662]\n",
73 | "[0.93288 0.98402 0.98402 0.98402 0.98402]\n",
74 | "[0.95518 0.99246 0.99246 0.99246 0.99246]\n",
75 | "[0.96946 0.99674 0.99674 0.99674 0.99674]\n",
76 | "[0.9822 0.99894 0.99894 0.99894 0.99894]\n",
77 | "[0.99078 0.99982 0.99982 0.99982 0.99982]\n",
78 | "[1. 1. 1. 1. 1.]\n",
79 | "\n",
80 | "Number of probes: 5\n",
81 | "['top1', 'top2', 'top4', 'top5', 'top10']\n",
82 | "[0.7355 0.83386 0.89734 0.91348 0.91348]\n",
83 | "[0.84186 0.92572 0.96844 0.97612 0.97612]\n",
84 | "[0.89878 0.96662 0.99118 0.99438 0.99438]\n",
85 | "[0.93288 0.98402 0.99746 0.9988 0.9988 ]\n",
86 | "[0.95518 0.99246 0.99934 0.99978 0.99978]\n",
87 | "[0.96946 0.99674 0.99984 0.9999 0.9999 ]\n",
88 | "[0.9822 0.99894 0.99994 0.99998 0.99998]\n",
89 | "[0.99078 0.99982 1. 1. 1. ]\n",
90 | "[1. 1. 1. 1. 1.]\n",
91 | "\n",
92 | "Number of probes: 10\n",
93 | "['top1', 'top2', 'top4', 'top5', 'top10']\n",
94 | "[0.7355 0.83386 0.89734 0.91348 0.95232]\n",
95 | "[0.84186 0.92572 0.96844 0.97612 0.99114]\n",
96 | "[0.89878 0.96662 0.99118 0.99438 0.99898]\n",
97 | "[0.93288 0.98402 0.99746 0.9988 0.99986]\n",
98 | "[0.95518 0.99246 0.99934 0.99978 1. ]\n",
99 | "[0.96946 0.99674 0.99984 0.9999 1. ]\n",
100 | "[0.9822 0.99894 0.99994 0.99998 1. ]\n",
101 | "[0.99078 0.99982 1. 1. 1. ]\n",
102 | "[1. 1. 1. 1. 1.]\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "search_dim = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
108 | "nprobes = [1, 2, 5, 10]\n",
109 | "ncentroids = [1024]\n",
110 | "\n",
111 | "for centroid in ncentroids:\n",
112 | " print(\"Clusters: \", centroid)\n",
113 | " \n",
114 | " # Load kmeans index\n",
115 | " size = str(centroid)+'ncentroid_'+str(d)+'d'\n",
116 | " index_file = root+'index_files/'+model+dataset+'_'+index_type+'_'+size+\"_nbits8_nlist2048\"\n",
117 | " cpu_index = faiss.read_index(index_file+'.index')\n",
118 | " if torch.cuda.device_count() > 0:\n",
119 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n",
120 | " \n",
121 | " # Load and normalize centroids\n",
122 | " centroids_path = root+'kmeans/'+model+'ncentroids'+str(centroid)+\"_\"+str(d)+'d'\"_\"+dataset+'.npy'\n",
123 | " centroids = np.load(centroids_path)\n",
124 | " faiss.normalize_L2(centroids)\n",
125 | " gt = np.argsort(-queryset @ centroids.T, axis=1)\n",
126 | " \n",
127 | " topK = [1, 2, 4, 5, 10]\n",
128 | " \n",
129 | " for nprobe in nprobes:\n",
130 | " print(\"\\nNumber of probes:\", nprobe)\n",
131 | " print([f'top{k}' for k in topK])\n",
132 | " for dim in search_dim:\n",
133 | " q = np.ascontiguousarray(queryset[:, :dim])\n",
134 | " nqueries = q.shape[0]\n",
135 | " faiss.normalize_L2(q)\n",
136 | " c = np.ascontiguousarray(centroids[:, :dim])\n",
137 | " faiss.normalize_L2(c)\n",
138 | " low_d_clusters = np.argsort(-q @ c.T, axis=1)\n",
139 | " \n",
140 | " count = [0, 0, 0, 0, 0]\n",
141 | " \n",
142 | " # Iterate over all queries\n",
143 | " for i in range(nqueries):\n",
144 | " label = gt[i][0]\n",
145 | " target = low_d_clusters[i][:nprobe]\n",
146 | " for j in range(len(topK)):\n",
147 | " count[j] += label in target[:topK[j]] # increments count[j] if correct\n",
148 | "\n",
149 | " print(np.array(count) / nqueries)"
150 | ]
151 | }
152 | ],
153 | "metadata": {
154 | "kernelspec": {
155 | "display_name": "Python 3 (ipykernel)",
156 | "language": "python",
157 | "name": "python3"
158 | },
159 | "language_info": {
160 | "codemirror_mode": {
161 | "name": "ipython",
162 | "version": 3
163 | },
164 | "file_extension": ".py",
165 | "mimetype": "text/x-python",
166 | "name": "python",
167 | "nbconvert_exporter": "python",
168 | "pygments_lexer": "ipython3",
169 | "version": "3.11.3"
170 | },
171 | "vscode": {
172 | "interpreter": {
173 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc"
174 | }
175 | }
176 | },
177 | "nbformat": 4,
178 | "nbformat_minor": 5
179 | }
180 |
--------------------------------------------------------------------------------
/adanns/ablations/relative_contrast.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "1d746c2b",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import torch\n",
12 | "import faiss\n",
13 | "import sys\n",
14 | "sys.path.append('../')\n",
15 | "from utils import load_embeddings"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 9,
21 | "id": "3080d42d",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "root = '../../../inference_array/resnet50/'\n",
26 | "model = \"ff\" # mrl, ff\n",
27 | "dataset = '1K' # 1K, 4K, V2"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "id": "de07a2d2",
33 | "metadata": {},
34 | "source": [
35 | "### In the cell below, we use the relative contrast equation as defined in Equation (1) of [On the DIfficulty of Nearest Neighbor Search](https://www.ee.columbia.edu/ln/dvmm/pubs/files/ICML_RelativeContrast.pdf).
\n",
36 | "### $C_r = \\frac{D_{mean}}{D_{min}}$
\n",
37 | " where $C_r$ is the relative contrast of a dataset $X$, $D_{mean}$ is the expected distance of a random database sample from a query $q$, and $D_{min}$ is the expected distance to the nearest database sample from a query $q$.
"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "id": "d4f449b0",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "if torch.cuda.device_count() > 0:\n",
48 | " device = torch.device('cuda')\n",
49 | "else:\n",
50 | " raise Exception(\"Please use a GPU! This will take very very long otherwise.\")\n",
51 | "\n",
52 | "# dlist = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
53 | "dlist = [2048]\n",
54 | "batch_size = 4196\n",
55 | "\n",
56 | "for d in dlist:\n",
57 | " database, queryset, db_labels, query_labels, xb, xq = load_embeddings(model, dataset, d)\n",
58 | "\n",
59 | " qy = torch.as_tensor(xq).to(device)\n",
60 | " db = torch.as_tensor(xb)\n",
61 | " \n",
62 | " num_batches = int(database.shape[0] / batch_size)\n",
63 | " final_d_min = torch.ones((qy.shape[0])).to(device) * 1e10\n",
64 | " final_d_mean = []\n",
65 | "\n",
66 | " for i in range(num_batches):\n",
67 | " db_batch = db[(i)*batch_size:(i+1)*batch_size, :].to(device)\n",
68 | " distances = torch.cdist(qy, db_batch)\n",
69 | " sorted_dist = torch.sort(distances)\n",
70 | " current_d_min = sorted_dist.values[:, 0]\n",
71 | " \n",
72 | " final_d_min = torch.min(current_d_min, final_d_min)\n",
73 | " final_d_mean.append(torch.mean(distances, axis=1).cpu().numpy())\n",
74 | " \n",
75 | " C_r = np.mean(final_d_mean) / torch.mean(final_d_min).cpu().numpy()\n",
76 | " print(f'C_r(d={d})={C_r}')"
77 | ]
78 | }
79 | ],
80 | "metadata": {
81 | "kernelspec": {
82 | "display_name": "Python 3 (ipykernel)",
83 | "language": "python",
84 | "name": "python3"
85 | },
86 | "language_info": {
87 | "codemirror_mode": {
88 | "name": "ipython",
89 | "version": 3
90 | },
91 | "file_extension": ".py",
92 | "mimetype": "text/x-python",
93 | "name": "python",
94 | "nbconvert_exporter": "python",
95 | "pygments_lexer": "ipython3",
96 | "version": "3.11.3"
97 | },
98 | "vscode": {
99 | "interpreter": {
100 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc"
101 | }
102 | }
103 | },
104 | "nbformat": 4,
105 | "nbformat_minor": 5
106 | }
107 |
--------------------------------------------------------------------------------
/adanns/adanns-ivf-optimized.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "c7ade3d8",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import faiss\n",
11 | "import numpy as np\n",
12 | "import time\n",
13 | "import csv\n",
14 | "from os import path, makedirs\n",
15 | "\n",
16 | "import multiprocessing\n",
17 | "from multiprocessing.dummy import Pool as ThreadPool\n",
18 | "from functools import partial\n",
19 | "\n",
20 | "from faiss.contrib.ivf_tools import add_preassigned, search_preassigned"
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "id": "43f2af97",
26 | "metadata": {},
27 | "source": [
28 | "## AdANNS-IVF"
29 | ]
30 | },
31 | {
32 | "attachments": {},
33 | "cell_type": "markdown",
34 | "id": "9744b124",
35 | "metadata": {},
36 | "source": [
37 | "### Notation\n",
38 | "1. $D$ = Embedding Dimensionality for IVF construction and search\n",
39 | "2. $M$ = number of OPQ subquantizers. Faiss requires $D$ % $M$ == $0$. \n",
40 | "3. For AdANNS, D is decomposed to $D_{construct}$ and $D_{search}$\n",
41 | "\n",
42 | "### Miscellaneous Notes\n",
43 | "1. Rigid representations (RR) are embedded through independently trained \"fixed feature\" (FF) encoders. RR and FF are thus used interchangeably in documentation and code and are essentially equivalent.\n",
44 | "2. In this notebook, the AdANNS-IVF coarse quantizer uses OPQ by default for cheap distance computation, but is optional.\n",
45 | "3. AdANNS-IVF is adapted from this [Faiss Case Study](https://gist.github.com/mdouze/8c5ab227c0f7d9d7c15cf92a391dcbe5#file-demo_independent_ivf_dimension-ipynb)\n",
46 | "4. Optimized AdANNS-IVF (with Faiss) has a restriction that $D_{construct}\\geq D_{search}$. This is because we slice centroids learnt from $D_{construct}$ to learn PQ codebooks with $D_{search}$ (this is possible because they are MRs)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 9,
52 | "id": "fc82acd3",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "D = 2048 # Max d for ResNet50\n",
57 | "n_cell = 1024 # number of IVF cells, default=1024 for ImageNet-1K\n",
58 | "\n",
59 | "embeddings_root = 'path/to/embeddings' # load embeddings\n",
60 | "adanns_root = 'path/to/adanns/indices/' # store adanns indices\n",
61 | "rigid_root = 'path/to/rigid/indices/' # store rigid indices\n",
62 | "config = 'rr' # mrl, rr\n",
63 | "\n",
64 | "if config == 'mrl':\n",
65 | " config_load = 'mrl1_e0_ff2048'\n",
66 | "elif config == 'rr':\n",
67 | " config_load = 'mrl0_e0_ff2048'\n",
68 | "else:\n",
69 | " raise Exception(f\"Unsupported config {config}!\")\n",
70 | "\n",
71 | "use_mrl = config.upper() # MRL, RR\n",
72 | "\n",
73 | "db_npy = '1K_train_' + config_load + '-X.npy'\n",
74 | "query_npy = '1K_val_' + config_load + '-X.npy'"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "045780fe",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "xb = np.load(embeddings_root + db_npy)\n",
85 | "assert np.count_nonzero(np.isnan(xb)) == 0\n",
86 | "xq = np.load(embeddings_root + query_npy)\n",
87 | "\n",
88 | "query_labels = np.load(embeddings_root + \"1K_val_\" + config_load + \"-y.npy\")\n",
89 | "db_labels = np.load(embeddings_root + \"1K_train_\" + config_load + \"-y.npy\")\n",
90 | "\n",
91 | "print(\"loaded DB %s : %s\" % (db_npy, xb.shape))\n",
92 | "print(\"loaded queries %s : %s\" % (query_npy, xq.shape))"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "a1cdc1ef",
98 | "metadata": {},
99 | "source": [
100 | "## RR2048 OPQ Dim Reduction Baseline"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 11,
106 | "id": "517fa2d4",
107 | "metadata": {},
108 | "outputs": [
109 | {
110 | "name": "stdout",
111 | "output_type": "stream",
112 | "text": [
113 | "(100000, 2048)\n"
114 | ]
115 | }
116 | ],
117 | "source": [
118 | "db_subsampled = xb[np.random.choice(xb.shape[0], 100000, replace=False)]\n",
119 | "print(db_subsampled.shape)\n",
120 | "dim_reduce = 128"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "id": "45fe9462",
126 | "metadata": {},
127 | "source": [
128 | "### SVD dim reduction + OPQ"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 12,
134 | "id": "4829ab90",
135 | "metadata": {},
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | "SVD projected Database: (1281167, 128)\n",
142 | "SVD projected Queries: (50000, 128)\n"
143 | ]
144 | }
145 | ],
146 | "source": [
147 | "def get_SVD_mat(db_subsampled, low_dim):\n",
148 | " mat = faiss.PCAMatrix(db_subsampled.shape[1], low_dim)\n",
149 | " mat.train(db_subsampled)\n",
150 | " assert mat.is_trained\n",
151 | " return mat\n",
152 | "\n",
153 | "svd_mat = get_SVD_mat(db_subsampled, dim_reduce)\n",
154 | "database_svd_lowdim = svd_mat.apply(xb)\n",
155 | "print(\"SVD projected Database: \", database_svd_lowdim.shape)\n",
156 | "query_svd_lowdim = svd_mat.apply(xq)\n",
157 | "print(\"SVD projected Queries: \", query_svd_lowdim.shape)\n",
158 | "\n",
159 | "faiss.normalize_L2(database_svd_lowdim)\n",
160 | "faiss.normalize_L2(query_svd_lowdim)"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 13,
166 | "id": "bd787adc",
167 | "metadata": {},
168 | "outputs": [
169 | {
170 | "name": "stdout",
171 | "output_type": "stream",
172 | "text": [
173 | "Building FF_D2048_SVD128_OPQ128.faiss\n",
174 | "Train+add time: 13411.728565454483\n",
175 | "[2048, 128, 128, 0.69224]\n"
176 | ]
177 | }
178 | ],
179 | "source": [
180 | "for M in [128]:\n",
181 | " if not path.exists(f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss'):\n",
182 | " print(f'Building {use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss')\n",
183 | " cpu_index = faiss.index_factory(dim_reduce, f'OPQ{M},PQ{M}')\n",
184 | " start = time.time()\n",
185 | " cpu_index.train(database_svd_lowdim)\n",
186 | " cpu_index.add(database_svd_lowdim)\n",
187 | " print(\"Train+add time: \", time.time() - start)\n",
188 | " faiss.write_index(cpu_index, f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss')\n",
189 | " \n",
190 | " top1 = [xb.shape[1], dim_reduce, M]\n",
191 | " _, Ind = cpu_index.search(query_svd_lowdim, 100)\n",
192 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n",
193 | " print(top1)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": null,
199 | "id": "2b042b5b",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "name": "stdout",
204 | "output_type": "stream",
205 | "text": [
206 | "[2048, 128, 16, 0.69088]\n"
207 | ]
208 | }
209 | ],
210 | "source": [
211 | "for M in [128]:\n",
212 | " top1 = [xb.shape[1], dim_reduce, M]\n",
213 | " svd_opq_index = faiss.read_index(f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}.faiss')\n",
214 | " _, Ind = svd_opq_index.search(query_svd_lowdim, 100)\n",
215 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n",
216 | " print(top1)"
217 | ]
218 | },
219 | {
220 | "cell_type": "markdown",
221 | "id": "e9e4b00a",
222 | "metadata": {},
223 | "source": [
224 | "## Rigid-IVF + OPQ"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 19,
230 | "id": "944b8f15",
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "Skipping build, index exists: FF_D2048+IVF1024,OPQ16.faiss\n"
238 | ]
239 | }
240 | ],
241 | "source": [
242 | "# Construct Rigid Index\n",
243 | "\n",
244 | "for M in [16]:\n",
245 | " for D in [2048]:\n",
246 | " database = np.ascontiguousarray(xb[:,:D], dtype=np.float32)\n",
247 | " faiss.normalize_L2(database)\n",
248 | " \n",
249 | " if M > D:\n",
250 | " continue\n",
251 | "\n",
252 | " if not path.exists(f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss'):\n",
253 | " print(f'Building {use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n",
254 | " start = time.time()\n",
255 | "\n",
256 | " index = faiss.index_factory(int(D), f'IVF{n_cell},PQ{M}')\n",
257 | "\n",
258 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index')\n",
259 | " opq = opq_index_pretrained.chain.at(0)\n",
260 | "\n",
261 | " db = opq.apply(database)\n",
262 | "\n",
263 | " index.train(db)\n",
264 | " index.add(db)\n",
265 | "\n",
266 | " print(\"Time: \", time.time() - start)\n",
267 | " faiss.write_index(index, f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n",
268 | " print(f'Created IVF{n_cell},OPQ{M} index with D={D}')\n",
269 | "\n",
270 | " else:\n",
271 | " print(f'Skipping build, index exists: {use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 23,
277 | "id": "baa631e4",
278 | "metadata": {},
279 | "outputs": [
280 | {
281 | "name": "stdout",
282 | "output_type": "stream",
283 | "text": [
284 | "[n_cell, D, M, top1]\n",
285 | "[1024, 2048, 8, 0.64966]\n",
286 | "[1024, 2048, 16, 0.6663]\n",
287 | "[1024, 2048, 32, 0.67724]\n",
288 | "[1024, 2048, 64, 0.68588]\n"
289 | ]
290 | }
291 | ],
292 | "source": [
293 | "# Search Rigid Index\n",
294 | "\n",
295 | "print('[n_cell, D, M, top1]')\n",
296 | "for D in [2048]:\n",
297 | " queryset = np.ascontiguousarray(xq[:,:D], dtype=np.float32)\n",
298 | " faiss.normalize_L2(queryset)\n",
299 | " for M in [8, 16, 32, 64]:\n",
300 | " if M > D:\n",
301 | " continue\n",
302 | " \n",
303 | " top1 = [n_cell, D, M]\n",
304 | " times = [n_cell, D, M]\n",
305 | "\n",
306 | " index = faiss.read_index(f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss')\n",
307 | " \n",
308 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index')\n",
309 | " opq = opq_index_pretrained.chain.at(0)\n",
310 | "\n",
311 | " q = opq.apply(queryset)\n",
312 | "\n",
313 | " for nprobe in [1]:\n",
314 | " start = time.time()\n",
315 | " faiss.extract_index_ivf(index).nprobe = nprobe \n",
316 | " Dist, Ind = index.search(q, 100)\n",
317 | "\n",
318 | " top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])\n",
319 | " times.append(time.time() - start)\n",
320 | "\n",
321 | " print(top1)"
322 | ]
323 | },
324 | {
325 | "cell_type": "markdown",
326 | "id": "2b0c62f6",
327 | "metadata": {},
328 | "source": [
329 | "## AdANNS-IVF + OPQ"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 45,
335 | "id": "b9b187df",
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "def create_adanns_indices(D_search, D_construct, M, n_cell):\n",
340 | " index_search = faiss.index_factory(D_search, f'OPQ{M},IVF{n_cell},PQ{M}')\n",
341 | " index_construct = faiss.index_factory(D_construct, f'IVF{n_cell},Flat')\n",
342 | " \n",
343 | " database = np.ascontiguousarray(xb[:,:D_construct], dtype=np.float32)\n",
344 | " faiss.normalize_L2(database)\n",
345 | "\n",
346 | " # train the full-dimensional \"construct\" coarse quantizer. IVF centroid assignments are learnt with D_construct\n",
347 | " if not path.exists(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss'):\n",
348 | " index_construct.train(database)\n",
349 | " quantizer_construct = index_construct.quantizer\n",
350 | " faiss.write_index(quantizer_construct, adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n",
351 | " else:\n",
352 | " print(\"Index exists: \", adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n",
353 | "\n",
354 | " # prepare the \"search\" coarse quantizer. OPQ codebooks are learnt on D_search\n",
355 | " if not path.exists(adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss'):\n",
356 | " quantizer_construct = faiss.read_index(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')\n",
357 | " database_search = np.ascontiguousarray(xb[:, :D_search], dtype=np.float32)\n",
358 | " centroids_search = np.ascontiguousarray(quantizer_construct.reconstruct_n(0, quantizer_construct.ntotal)[:, :D_search], dtype=np.float32)\n",
359 | " \n",
360 | " # Apply OPQ to search DB and centroids\n",
361 | " opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D_search}_nbits8.index')\n",
362 | " print(f'Applying OPQ: 1K_opq_{M}m_d{D_search}')\n",
363 | " opq = opq_index_pretrained.chain.at(0)\n",
364 | " opq.apply(centroids_search)\n",
365 | " opq.apply(database_search)\n",
366 | " faiss.normalize_L2(database_search)\n",
367 | " \n",
368 | " index_ivf_search = faiss.downcast_index(faiss.extract_index_ivf(index_search))\n",
369 | " index_ivf_search.quantizer.add(centroids_search)\n",
370 | "\n",
371 | " index_ivf_search.train(database_search)\n",
372 | " index_search.is_trained = True\n",
373 | "\n",
374 | " # coarse quantization with the construct quantizer\n",
375 | " _, Ic = quantizer_construct.search(database, 1) # each database vector assigned to one of num_cell centroids\n",
376 | " # add operation \n",
377 | " add_preassigned(index_ivf_search, database_search, Ic.ravel())\n",
378 | "\n",
379 | " faiss.write_index(index_ivf_search, adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')\n",
380 | " else:\n",
381 | " print(\"Index exists: \", adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')\n",
382 | " \n",
383 | " print(f'Initialized construct quantizer D{D_construct}, search quantizer D{D_search}, M{M}, ncell{n_cell}')"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 46,
389 | "id": "0ad99c4a",
390 | "metadata": {},
391 | "outputs": [
392 | {
393 | "name": "stdout",
394 | "output_type": "stream",
395 | "text": [
396 | "Skipping (M, d_small, d_big): (64, 2048, 64)\n",
397 | "Skipping (M, d_small, d_big): (64, 2048, 128)\n",
398 | "Skipping (M, d_small, d_big): (64, 2048, 256)\n",
399 | "Skipping (M, d_small, d_big): (64, 2048, 512)\n",
400 | "Skipping (M, d_small, d_big): (64, 2048, 1024)\n",
401 | "Index exists: case_study_decoupled/MRL_D2048+IVF1024,PQ64_big_quantizer.faiss\n",
402 | "Applying OPQ: 1K_opq_64m_d2048\n",
403 | "Initialized big quantizer D2048, small quantizer D2048, M64, ncell1024\n"
404 | ]
405 | }
406 | ],
407 | "source": [
408 | "for D_construct in [64, 128, 256, 512, 1024, 2048]:\n",
409 | " for D_search in [2048]:\n",
410 | " for M in [64]:\n",
411 | " if M > D_search or D_search > D_construct:\n",
412 | " print(\"Skipping (M, d_search, d_construct): \", (M, D_search, D_construct))\n",
413 | " continue\n",
414 | " create_adanns_indices(D_search, D_construct, M, n_cell=1024)"
415 | ]
416 | },
417 | {
418 | "cell_type": "code",
419 | "execution_count": null,
420 | "id": "5d2081b7",
421 | "metadata": {},
422 | "outputs": [],
423 | "source": [
424 | "# Preassigned Search using multiple cores\n",
425 | "\n",
426 | "USE_MULTITHREAD_SEARCH = True\n",
427 | "num_cores = multiprocessing.cpu_count()\n",
428 | "thread_batch_size = 1000\n",
429 | "\n",
430 | "# Helper function to split search on multiple cores\n",
431 | "def multisearch_preassigned(index, queryset, Ic, batch_iter):\n",
432 | " _, I = search_preassigned(index, \n",
433 | " queryset[thread_batch_size*batch_iter:thread_batch_size*(batch_iter+1)], \n",
434 | " 100, # Shortlist length\n",
435 | " Ic[thread_batch_size*batch_iter:thread_batch_size*(batch_iter+1)], \n",
436 | " None)\n",
437 | " return I"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": 49,
443 | "id": "74680e75",
444 | "metadata": {},
445 | "outputs": [],
446 | "source": [
447 | "def search_adanns_indices(D_search, D_construct, n_cell, nprobes=[1]):\n",
448 | " queryset = np.ascontiguousarray(xq[:,:D_construct], dtype=np.float32)\n",
449 | " faiss.normalize_L2(queryset)\n",
450 | " \n",
451 | " queryset_small = np.ascontiguousarray(xq[:, :D_search], dtype=np.float32)\n",
452 | " faiss.normalize_L2(queryset_small)\n",
453 | " \n",
454 | " for M in [64]:\n",
455 | " top1 = [n_cell, D_construct, D_search, M]\n",
456 | " times = [n_cell, D_construct, D_search, M]\n",
457 | " if M > D_search or D_search > D_construct:\n",
458 | " continue\n",
459 | " \n",
460 | " # print(f'MRL IVF{n_cell},PQ{M}: D{D_search} search with D{D_construct} coarse quantization')\n",
461 | " quantizer_big = faiss.read_index(adanns_root + f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_big_quantizer.faiss')\n",
462 | " index_ivf_small = faiss.read_index(adanns_root + f'MRL_Dsmall{D_search}_Dbig{D_construct}+IVF{n_cell},OPQ{M}_small_quantizer.faiss')\n",
463 | " \n",
464 | " # disable precomputed tables, because the Dc is out of sync with the \n",
465 | " # small coarse quantizer\n",
466 | " index_ivf_small.use_precomputed_table = -1\n",
467 | " index_ivf_small.precompute_table()\n",
468 | "\n",
469 | " for nprobe in nprobes:\n",
470 | " start = time.time()\n",
471 | "\n",
472 | " # coarse quantization \n",
473 | " _, Ic = quantizer_big.search(queryset, nprobe) # Ic: (50K, nprobe)\n",
474 | "\n",
475 | " # actual search \n",
476 | " index_ivf_small.nprobe = nprobe\n",
477 | " \n",
478 | " if USE_MULTITHREAD_SEARCH:\n",
479 | " pool = ThreadPool(num_cores)\n",
480 | " partial_func = partial(multisearch_preassigned, index=index_ivf_small, queryset=queryset_small, Ic=Ic)\n",
481 | " I = pool.map(partial_func, range(queryset_small.shape[0] // thread_batch_size)) # 50K queries split to (num_batches, thread_batch_size) batches\n",
482 | " pool.close()\n",
483 | " pool.join()\n",
484 | " \n",
485 | " else:\n",
486 | " _, I = search_preassigned(index_ivf_small, queryset_small, 100, Ic, None) # I: (50K, 100)\n",
487 | "\n",
488 | " top1.append((np.sum(db_labels[I[:, 0]] == query_labels)) / query_labels.shape[0])\n",
489 | " times.append(time.time()-start)\n",
490 | " \n",
491 | " if (len(top1) > 4): # ignore continued cases\n",
492 | " with open('adanns-faiss-top1-opq.csv', 'a', encoding='UTF8', newline='') as f:\n",
493 | " writer = csv.writer(f)\n",
494 | " writer.writerow(top1)\n",
495 | " with open('adanns-faiss-timing-opq.csv', 'a', encoding='UTF8', newline='') as f:\n",
496 | " writer = csv.writer(f)\n",
497 | " writer.writerow(times)\n",
498 | " print(top1)\n",
499 | " # print(times)"
500 | ]
501 | },
502 | {
503 | "cell_type": "markdown",
504 | "id": "3baea8e1",
505 | "metadata": {},
506 | "source": [
507 | "## Metric Computation"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": 50,
513 | "id": "13a8b9b8",
514 | "metadata": {
515 | "scrolled": true
516 | },
517 | "outputs": [
518 | {
519 | "name": "stdout",
520 | "output_type": "stream",
521 | "text": [
522 | "['n_cell', 'D_big', 'D_small', 'M', '1probe', '4probe', '8probe']\n",
523 | "[1024, 64, 64, 64, 0.6942]\n",
524 | "[1024, 128, 64, 64, 0.69422]\n",
525 | "[1024, 128, 128, 64, 0.69584]\n",
526 | "[1024, 256, 64, 64, 0.69334]\n",
527 | "[1024, 256, 128, 64, 0.69604]\n",
528 | "[1024, 256, 256, 64, 0.69632]\n",
529 | "[1024, 512, 64, 64, 0.69418]\n",
530 | "[1024, 512, 128, 64, 0.69676]\n",
531 | "[1024, 512, 256, 64, 0.69568]\n",
532 | "[1024, 512, 512, 64, 0.6969]\n",
533 | "[1024, 1024, 64, 64, 0.69576]\n",
534 | "[1024, 1024, 128, 64, 0.69716]\n",
535 | "[1024, 1024, 256, 64, 0.69676]\n",
536 | "[1024, 1024, 512, 64, 0.69648]\n",
537 | "[1024, 1024, 1024, 64, 0.69412]\n",
538 | "[1024, 2048, 64, 64, 0.69444]\n",
539 | "[1024, 2048, 128, 64, 0.69608]\n",
540 | "[1024, 2048, 256, 64, 0.6973]\n",
541 | "[1024, 2048, 512, 64, 0.69628]\n",
542 | "[1024, 2048, 1024, 64, 0.69274]\n",
543 | "[1024, 2048, 2048, 64, 0.6899]\n"
544 | ]
545 | }
546 | ],
547 | "source": [
548 | "header = [\"n_cell\", \"D_construct\", \"D_search\", \"M\", \"1probe\", \"4probe\", \"8probe\"]\n",
549 | "print(header)\n",
550 | "\n",
551 | "with open('adanns-faiss-top1-opq.csv', 'w', encoding='UTF8', newline='') as f:\n",
552 | " writer = csv.writer(f)\n",
553 | " writer.writerow(header)\n",
554 | " \n",
555 | "with open('adanns-faiss-timing-opq.csv', 'w', encoding='UTF8', newline='') as f:\n",
556 | " writer = csv.writer(f)\n",
557 | " writer.writerow(header)\n",
558 | " \n",
559 | "for D_construct in [64, 128, 256, 512, 1024, 2048]:\n",
560 | " for D_search in [64, 128, 256, 512, 1024, 2048]:\n",
561 | " search_adanns_indices(D_search, D_construct, n_cell=1024, nprobes=[1])"
562 | ]
563 | }
564 | ],
565 | "metadata": {
566 | "kernelspec": {
567 | "display_name": "Python 3",
568 | "language": "python",
569 | "name": "python3"
570 | },
571 | "language_info": {
572 | "codemirror_mode": {
573 | "name": "ipython",
574 | "version": 3
575 | },
576 | "file_extension": ".py",
577 | "mimetype": "text/x-python",
578 | "name": "python",
579 | "nbconvert_exporter": "python",
580 | "pygments_lexer": "ipython3",
581 | "version": "3.11.0 (main, Oct 26 2022, 19:06:18) [Clang 14.0.0 (clang-1400.0.29.202)]"
582 | },
583 | "vscode": {
584 | "interpreter": {
585 | "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a"
586 | }
587 | }
588 | },
589 | "nbformat": 4,
590 | "nbformat_minor": 5
591 | }
592 |
--------------------------------------------------------------------------------
/adanns/adanns-ivf-unoptimized.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "04ace433",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import faiss\n",
12 | "import time\n",
13 | "import pandas as pd\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import csv\n",
16 | "from os import path, makedirs"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "id": "e8a1adec",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "root = 'path/to/embeddings/'\n",
27 | "\n",
28 | "D = 2048\n",
29 | "D_rr_search = 16 # to load high-D database and queryset for AR with rr models\n",
30 | "\n",
31 | "method = 'adanns' # adanns, mg-ivf-rr, mg-ivf-svd\n",
32 | "dataset = '1K' # 1K, 4K, V2"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 4,
38 | "id": "f98c143d",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "def load_data_helper(config, D_load=2048):\n",
43 | " db_csv = dataset + '_train_' + config + '-X.npy'\n",
44 | " query_csv = dataset + '_val_' + config + '-X.npy'\n",
45 | " db_label_csv = dataset + '_train_' + config + '-y.npy'\n",
46 | " query_label_csv = dataset + '_val_' + config + '-y.npy'\n",
47 | " \n",
48 | " if dataset == 'V2':\n",
49 | " db_csv = \"1K_train_\" + config + '-X.npy'\n",
50 | " db_label_csv = \"1K_train_\" + config + '-y.npy'\n",
51 | "\n",
52 | " db_load = np.ascontiguousarray(np.load(root+db_csv)[:, :D_load], dtype=np.float32)\n",
53 | " qy_load = np.ascontiguousarray(np.load(root+query_csv)[:, :D_load], dtype=np.float32)\n",
54 | " db_labels = np.load(root+db_label_csv)\n",
55 | " query_labels = np.load(root+query_label_csv)\n",
56 | "\n",
57 | " faiss.normalize_L2(db_load)\n",
58 | " faiss.normalize_L2(qy_load)\n",
59 | "\n",
60 | " return db_load, qy_load, db_labels, query_labels\n",
61 | "\n",
62 | "\n",
63 | "def load_construct_data(D_construct, D_rr_svd, ncentroids):\n",
64 | " if method == 'adanns':\n",
65 | " config = f'mrl1_e0_ff{D_construct}'\n",
66 | " elif method == 'mg-ivf-rr':\n",
67 | " config = f'mrl0_e0_ff{D_construct}'\n",
68 | " elif method == 'mg-ivf-svd':\n",
69 | " config = f'mrl0_e0_rr{D_construct}_svd{D_rr_svd}'\n",
70 | " else:\n",
71 | " raise Exception(\"Unsupported ANNS method.\")\n",
72 | " db_construct, qy_construct, db_labels, query_labels = load_data_helper(config, D_construct)\n",
73 | " \n",
74 | " print(\"Cluster Contruction DB: \", db_construct.shape)\n",
75 | " print(\"Cluster Construction queries:\", qy_construct.shape)\n",
76 | " \n",
77 | " # Load kmeans index and centroids with shape (centroid, D_construct)\n",
78 | " size = str(ncentroids)+'ncentroid_'+str(D_construct)+'Dc'\n",
79 | " if dataset == 'V2': # V2 is only a test set, change to 1K\n",
80 | " dataset = '1K'\n",
81 | " index_file = root+'index_files/'+method+dataset+'_kmeans_'+size\n",
82 | "\n",
83 | " centroids_path = root+'kmeans/'+method+'ncentroids'+str(ncentroids)+\"_\"+str(D_construct)+'Dc_'+dataset+'.npy'\n",
84 | " centroids = np.load(centroids_path)\n",
85 | " print(\"Loaded centroids: \", centroids.shape, centroids_path)\n",
86 | " \n",
87 | " return db_construct, qy_construct, db_labels, query_labels, centroids, index_file\n",
88 | "\n",
89 | "def load_search_data(D_search):\n",
90 | " if method == 'adanns':\n",
91 | " config = f'mrl1_e0_ff{D_search}'\n",
92 | " elif method in ['mg-ivf-rr', 'mg-ivf-svd']:\n",
93 | " config = f'mrl0_e0_ff{D_search}'\n",
94 | " else:\n",
95 | " raise Exception(\"Unsupported ANNS method.\")\n",
96 | " db_search, qy_search, _ , _ = load_data_helper(config, D_search)\n",
97 | "\n",
98 | " return db_search, qy_search\n"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 5,
104 | "id": "55ff5ca2",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "def eval_cluster(val_classes, db_classes, neighbors, k): \n",
109 | " APs, topk, recall = [], [], []\n",
110 | " cluster_size = neighbors.shape[0]\n",
111 | " for i in range(cluster_size):\n",
112 | " target = val_classes[i]\n",
113 | " indices = neighbors[i][:k] # k neighbor list for ith val vector\n",
114 | " labels = db_classes[indices]\n",
115 | " matches = (labels == target)\n",
116 | " \n",
117 | " # topk\n",
118 | " hits = np.sum(matches)\n",
119 | " if hits>0:\n",
120 | " topk.append(1)\n",
121 | " else:\n",
122 | " topk.append(0)\n",
123 | " \n",
124 | " # recall\n",
125 | " recall.append(np.sum(matches)/1300)\n",
126 | " \n",
127 | " # precision values\n",
128 | " tps = np.cumsum(matches)\n",
129 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n",
130 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n",
131 | " \n",
132 | " return np.mean(recall), np.mean(topk), np.mean(APs)"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 6,
138 | "id": "5a66d9b5",
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "def get_closest_centroids(queries, centroids, D_shortlist):\n",
143 | " centroid_index = faiss.IndexFlatL2(D_shortlist)\n",
144 | " xq_shortlist = np.ascontiguousarray(queries[:, :D_shortlist], dtype=np.float32)\n",
145 | " xc_shortlist = np.ascontiguousarray(centroids[:, :D_shortlist], dtype=np.float32)\n",
146 | " faiss.normalize_L2(xq_shortlist)\n",
147 | " faiss.normalize_L2(xc_shortlist)\n",
148 | " \n",
149 | " centroid_index.add(xc_shortlist)\n",
150 | " _, I = centroid_index.search(xq_shortlist, 1)\n",
151 | "\n",
152 | " return I"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 6,
158 | "id": "b69174c7",
159 | "metadata": {},
160 | "outputs": [
161 | {
162 | "name": "stdout",
163 | "output_type": "stream",
164 | "text": [
165 | "['d_c', 'd_s', 'd_shortlist', 'ncentroid', 'top1', 'recall@100', 'mAP@100', 'overlap']\n",
166 | "Cluster Contruction DB: (1281167, 2048)\n",
167 | "Cluster Construction queries: (10000, 2048)\n",
168 | "\n",
169 | "Loaded kmeans index: 1K_kmeans_1024ncentroid_2048d\n",
170 | "Loaded centroids: (1024, 2048) ../../inference_array/resnet50/kmeans/mrl/ncentroids1024_2048d_1K.npy\n",
171 | "Linear scan with d = 8\n",
172 | "[2048, 8, 2048, 1024, 0.5351, 0.05091163624126155, 0.605462860936279, 0]\n",
173 | "d_c:2048, d_s: 8, ncentroid: 1024\n",
174 | "Recall@100: 0.05091163624126155\n",
175 | "Top1: 0.5351\n",
176 | "Linear scan with d = 16\n",
177 | "[2048, 32, 2048, 1024, 0.5732, 0.051922105632729296, 0.6227463073287828, 0]\n",
178 | "d_c:2048, d_s: 32, ncentroid: 1024\n",
179 | "Recall@100: 0.051922105632729296\n",
180 | "Top1: 0.5732\n",
181 | "Linear scan with d = 64\n",
182 | "[2048, 64, 2048, 1024, 0.5785, 0.05198921759119408, 0.6243867377357402, 0]\n",
183 | "d_c:2048, d_s: 64, ncentroid: 1024\n",
184 | "Recall@100: 0.05198921759119408\n",
185 | "Top1: 0.5785\n",
186 | "Linear scan with d = 128\n",
187 | "[2048, 128, 2048, 1024, 0.5802, 0.05199322657824928, 0.6247377244191619, 0]\n",
188 | "d_c:2048, d_s: 128, ncentroid: 1024\n",
189 | "Recall@100: 0.05199322657824928\n",
190 | "Top1: 0.5802\n",
191 | "Linear scan with d = 256\n",
192 | "[2048, 256, 2048, 1024, 0.5801, 0.05199830116619949, 0.6250036507917335, 0]\n",
193 | "d_c:2048, d_s: 256, ncentroid: 1024\n",
194 | "Recall@100: 0.05199830116619949\n",
195 | "Top1: 0.5801\n",
196 | "Linear scan with d = 512\n",
197 | "[2048, 512, 2048, 1024, 0.5803, 0.05199979769465816, 0.6251013861127136, 0]\n",
198 | "d_c:2048, d_s: 512, ncentroid: 1024\n",
199 | "Recall@100: 0.05199979769465816\n",
200 | "Top1: 0.5803\n",
201 | "Linear scan with d = 1024\n",
202 | "[2048, 1024, 2048, 1024, 0.5766, 0.05198958841622709, 0.6249844689120537, 0]\n",
203 | "d_c:2048, d_s: 1024, ncentroid: 1024\n",
204 | "Recall@100: 0.05198958841622709\n",
205 | "Top1: 0.5766\n",
206 | "Total Time for 8 configs = 211.981683\n"
207 | ]
208 | }
209 | ],
210 | "source": [
211 | "D_search_list = [8, 16, 32, 64, 128, 256, 512, 1024]\n",
212 | "ncentroids = 1024\n",
213 | "\n",
214 | "for D in [2048]:\n",
215 | " k=100\n",
216 | " D_rr_svd = D\n",
217 | " D_construct_list = [D]\n",
218 | " D_shortlist_list = [D]\n",
219 | "\n",
220 | " header = ['d_construct', 'd_search', 'd_shortlist', 'ncentroid', 'top1', 'recall@'+str(k), 'mAP@'+str(k)]\n",
221 | " print(header)\n",
222 | " with open('kmeans_metrics.csv', 'w', encoding='UTF8', newline='') as f:\n",
223 | " writer = csv.writer(f)\n",
224 | " writer.writerow(header)\n",
225 | "\n",
226 | " start = time.time()\n",
227 | " for D_c in D_construct_list:\n",
228 | " # Load all construction data (database, queries, centroids)\n",
229 | " xb_construct, xq_construct, db_labels, query_labels, centroids, index_file = load_construct_data(D_c, D_rr_svd, ncentroids)\n",
230 | "\n",
231 | " cpu_index = faiss.read_index(index_file+'.index')\n",
232 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n",
233 | " print(\"\\nLoaded kmeans index:\", index_file.split(\"/\")[-1])\n",
234 | "\n",
235 | " # construct lookup table of centroid --> vectors, i.e. inverted lists\n",
236 | " _, I_db = index.search(xb_construct, 1)\n",
237 | " lut_db = {}\n",
238 | " for c in np.unique(I_db):\n",
239 | " lut_db[c] = np.argwhere(I_db==c)[:,0]\n",
240 | "\n",
241 | " for D_search in D_search_list:\n",
242 | " print(\"Linear scan with D_s = \", D_search)\n",
243 | " xb_search, xq_search = load_search_data(D_search)\n",
244 | "\n",
245 | " for D_shortlist in D_shortlist_list:\n",
246 | " # Currently, D_shortlist <= D_search is supported as we slice centroids for adanns\n",
247 | " I_q = get_closest_centroids(xq_search, centroids, D_shortlist)\n",
248 | " lut_q = {}\n",
249 | "\n",
250 | " start = time.time()\n",
251 | " recall, topk, mAP = [], [], []\n",
252 | "\n",
253 | " #Iterate over all centroids assigned to each\n",
254 | " for c in np.unique(I_q):\n",
255 | " lut_q[c] = np.argwhere(I_q==c)[:,0]\n",
256 | " exact_cpu_index = faiss.IndexFlatL2(D_search)\n",
257 | "\n",
258 | " # add cluster vectors to index and search only queries that map to that cluster\n",
259 | " exact = faiss.index_cpu_to_all_gpus(exact_cpu_index)\n",
260 | " cluster_db = np.ascontiguousarray(xb_search[lut_db[c]][:, :D_search], np.float32)\n",
261 | " cluster_query = np.ascontiguousarray(xq_search[lut_q[c]][:, :D_search], np.float32)\n",
262 | " faiss.normalize_L2(cluster_db)\n",
263 | " faiss.normalize_L2(cluster_query)\n",
264 | " exact.add(cluster_db)\n",
265 | " Dist, Ind = exact.search(cluster_query, k)\n",
266 | "\n",
267 | " # replace cluster-specific indices with original database indices for eval\n",
268 | " cluster_db_labels = db_labels[lut_db[c]]\n",
269 | " cluster_query_labels = query_labels[lut_q[c]]\n",
270 | "\n",
271 | " nn_1 = Ind[:, 0]\n",
272 | " pred_1 = cluster_db_labels[nn_1]\n",
273 | " hits = np.sum(pred_1 == cluster_query_labels)\n",
274 | " topk.append(hits)\n",
275 | "\n",
276 | " rl, tk, mp = eval_cluster(cluster_query_labels, cluster_db_labels, Ind, k)\n",
277 | " recall.append(rl)\n",
278 | " mAP.append(mp)\n",
279 | " row = [D_c, D_search, D_shortlist, ncentroids, np.sum(topk)/xq_search.shape[0], np.mean(recall), np.mean(mAP)]\n",
280 | " print(row)\n",
281 | "\n",
282 | " with open('kmeans_metrics.csv', 'a', encoding='UTF8', newline='') as f:\n",
283 | " writer = csv.writer(f)\n",
284 | " writer.writerow(row)\n",
285 | " print(\"d_c:%d, d_s: %d, ncentroid: %d\" %(D_c, D_search, ncentroids))\n",
286 | " print(\"Recall@100: \", np.mean(recall))\n",
287 | " print(\"Top1: \", np.sum(topk)/xq_search.shape[0])\n",
288 | "\n",
289 | " print(\"Total Time for %d configs = %f\" % (len(D_search_list) * len(ncentroids) * len(D_construct_list), time.time() - start))"
290 | ]
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "Python 3",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.11.0"
310 | },
311 | "vscode": {
312 | "interpreter": {
313 | "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a"
314 | }
315 | }
316 | },
317 | "nbformat": 4,
318 | "nbformat_minor": 5
319 | }
320 |
--------------------------------------------------------------------------------
/adanns/compute_metrics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "aa3cf273",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import pandas as pd\n",
11 | "import numpy as np\n",
12 | "import time\n",
13 | "import os\n",
14 | "import faiss\n",
15 | "import csv"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "f4fbeb4d",
21 | "metadata": {},
22 | "source": [
23 | "## Configuration Variables"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "6b093168",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "D = 2048 # vector dim\n",
34 | "ROOT_DIR = '../../inference_array/resnet50/'\n",
35 | "CONFIG = 'mrl/' # ['mrl/', 'rr/']\n",
36 | "NESTING = CONFIG == 'mrl/'\n",
37 | "SEARCH_INDEX = 'ivfpq' # ['exactl2', 'ivfpq', 'opq', 'hnsw32']\n",
38 | "DATASET = '1K' # 1K, V2, 4K\n",
39 | "\n",
40 | "# Quantization Variables\n",
41 | "nbits = 8 # nbits used to represent centroid id; total possible is k* = 2**nbits\n",
42 | "nlist = 1024 # how many Voronoi cells (must be >= k*)\n",
43 | "iterator = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] # vector dim D \n",
44 | "\n",
45 | "if SEARCH_INDEX in ['ivfpq', 'opq']:\n",
46 | " M = 32 # number of sub-quantizers, i.e. compression in bytes"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "id": "9c9351db",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "def compute_mAP_recall_at_k(val_classes, db_classes, neighbors, k):\n",
57 | " \"\"\"\n",
58 | " Computes the MAP@k on neighbors with val set by seeing if nearest neighbor\n",
59 | " is in the same class as the class of the val code. Let m be size of val set, and n in train.\n",
60 | "\n",
61 | " val: (m x d) All the truncated vector representations of images in val set\n",
62 | " val_classes: (m x 1) class index values for each vector in the val set\n",
63 | " db_classes: (n x 1) class index values for each vector in the train set\n",
64 | " neighbors: (m x k) indices in train set of top k neighbors for each vector in val set\n",
65 | " \"\"\"\n",
66 | "\n",
67 | " \"\"\"\n",
68 | " ImageNet-1K:\n",
69 | " shape of val is: (50000, dim)\n",
70 | " shape of val_classes is: (50000, 1)\n",
71 | " shape of db_classes is: (1281167, 1)\n",
72 | " shape of neighbors is: (50000, k)\n",
73 | " \"\"\"\n",
74 | " APs, precision, recall, topk, unique_cls = [], [], [], [], []\n",
75 | " \n",
76 | " for i in range(val_classes.shape[0]): # Compute precision for each vector's list of k-nn\n",
77 | " target = val_classes[i]\n",
78 | " indices = neighbors[i, :][:k] # k neighbor list for ith val vector\n",
79 | " labels = db_classes[indices]\n",
80 | " matches = (labels == target)\n",
81 | " \n",
82 | " # Number of unique classes\n",
83 | " unique_cls.append(len(np.unique(labels)))\n",
84 | " \n",
85 | " # topk\n",
86 | " hits = np.sum(matches)\n",
87 | " if hits > 0:\n",
88 | " topk.append(1)\n",
89 | " else:\n",
90 | " topk.append(0)\n",
91 | " \n",
92 | " # true positive counts\n",
93 | " tps = np.cumsum(matches)\n",
94 | "\n",
95 | " # recall\n",
96 | " recall.append(np.sum(matches)/1300)\n",
97 | " precision.append(np.sum(matches)/k)\n",
98 | "\n",
99 | " # precision values\n",
100 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n",
101 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n",
102 | "\n",
103 | " return np.mean(APs), np.mean(precision), np.mean(recall), np.mean(topk), np.mean(unique_cls)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "id": "ab7ca0bb",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "def get_k_recall_at_N(exact_gt, neighbors, k=40, N=2048):\n",
114 | " \"\"\"\n",
115 | " Computes k-Recall@N which denotes the recall of k true nearest neighbors (exact search) \n",
116 | " when N datapoints are retrieved with ANNS. Let q be size of query set.\n",
117 | " \n",
118 | " exact_gt: (q x k) True nearest-neighbors of query set computed with exact search\n",
119 | " neighbors: (q x N) Approximate nearest-neighbors of query set\n",
120 | " k: (1) Number of true nearest-neighbors\n",
121 | " N: (1) Number of approximate nearest-neighbors retrieved\n",
122 | " \"\"\"\n",
123 | " labels = exact_gt[:, :k] # Labels from true NN\n",
124 | " targets = neighbors\n",
125 | " num_queries = exact_gt.shape[0]\n",
126 | " count = 0\n",
127 | " for i in range(num_queries):\n",
128 | " label = labels[i]\n",
129 | " target = targets[i, :N]\n",
130 | " # Compute overlap between approximate and true nearest-neighbors\n",
131 | " count += len(list(set(label).intersection(target)))\n",
132 | " return count / (num_queries * k)"
133 | ]
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "id": "74005467",
138 | "metadata": {},
139 | "source": [
140 | "## Load database, query, and neighbor arrays and compute metrics"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 5,
146 | "id": "1245f078",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "def load_knn_array(dim, **kwargs):\n",
151 | " if SEARCH_INDEX in ['ivfpq', 'opq']:\n",
152 | " if (M > dim):\n",
153 | " return\n",
154 | " size = 'm'+str(M)+'_nlist'+str(nlist)+\"_nprobe\"+str(nprobe)+\"_\"\n",
155 | " elif SEARCH_INDEX == 'ivfsq':\n",
156 | " size = str(qtype)+'qtype_'\n",
157 | " elif SEARCH_INDEX == 'kmeans':\n",
158 | " size = str(nlist)+'ncentroid_'\n",
159 | " elif SEARCH_INDEX == 'ivf':\n",
160 | " size = 'nlist'+str(nlist)+\"_nprobe\"+str(nprobe)+\"_\"\n",
161 | " elif SEARCH_INDEX in ['hnsw32', 'hnswpq_M32_pq-m8','hnswpq_M32_pq-m16','hnswpq_M32_pq-m32','hnswpq_M32_pq-m64', 'hnswpq_M32_pq-m128']:\n",
162 | " size = 'efsearch'+str(nprobe)+\"_\"\n",
163 | " else:\n",
164 | " raise Exception(f\"Unsupported Search Index: {SEARCH_INDEX}\")\n",
165 | "\n",
166 | " # Load neighbors array and compute metrics\n",
167 | " neighbors_path = ROOT_DIR + \"neighbors/\" + CONFIG + SEARCH_INDEX+\"/\"+SEARCH_INDEX + \"_\" + size \\\n",
168 | " + \"2048shortlist_\" + DATASET + \"_d\"+str(dim)+\".csv\"\n",
169 | " \n",
170 | " if not os.path.exists(neighbors_path):\n",
171 | " print(neighbors_path.split(\"/\")[-1] + \" not found\")\n",
172 | " return\n",
173 | "\n",
174 | " return pd.read_csv(neighbors_path, header=None).to_numpy()\n",
175 | "\n",
176 | "\n",
177 | "def print_metrics(iterator, shortlist, metric, nprobe=1, N=2048):\n",
178 | " \"\"\"\n",
179 | " Computes and print retrieval metrics.\n",
180 | " \n",
181 | " iterator: (List) True nearest-neighbors of query set computed with exact search\n",
182 | " shortlist: (List) Number of data points retrieved (k)\n",
183 | " metric: Name of metric ['topk', 'mAP', 'precision', 'recall', 'unique_cls', 'k_recall_at_n']\n",
184 | " nprobe: Number of clusters probed during search (IVF) OR 'efSearch' for HNSW search quality\n",
185 | " N: Number of data points retrieved for k-recall@N\n",
186 | " \"\"\"\n",
187 | " # Load database and query set for nested models\n",
188 | " if NESTING:\n",
189 | " # Database: 1.2M x 1 for Imagenet-1K\n",
190 | " if DATASET == 'V2':\n",
191 | " db_labels = np.load(ROOT_DIR + \"1K_train_mrl1_e0_ff2048-y.npy\")\n",
192 | " else:\n",
193 | " db_labels = np.load(ROOT_DIR + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n",
194 | " \n",
195 | " # Query set: 50K x 1 for Imagenet-1K\n",
196 | " query_labels = np.load(ROOT_DIR + DATASET + \"_val_mrl1_e0_ff2048-y.npy\")\n",
197 | " \n",
198 | " for dim in iterator:\n",
199 | " # Load database and query set for fixed feature models\n",
200 | " if not NESTING:\n",
201 | " db_labels = np.load(ROOT_DIR + DATASET + \"_train_mrl0_e0_ff\"+str(dim)+\"-y.npy\")\n",
202 | " query_labels = np.load(ROOT_DIR + DATASET + \"_val_mrl0_e0_ff\"+str(D)+\"-y.npy\")\n",
203 | " \n",
204 | " neighbors = load_knn_array(dim, M=M, nlist=nlist, nprobe=nprobe)\n",
205 | " \n",
206 | " for k in shortlist:\n",
207 | " if metric == 'k_recall_at_n':\n",
208 | " # Use 40-NN from Exact Search with MRL as GT\n",
209 | " if NESTING:\n",
210 | " query_labels = pd.read_csv(ROOT_DIR + f'k-recall@N_ground_truth/mrl_exactl2_2048dim_{k}shortlist_1K.csv', header=None).to_numpy()\n",
211 | " else:\n",
212 | " query_labels = pd.read_csv(ROOT_DIR + f'k-recall@N_ground_truth/rr_exactl2_{dim}dim_{k}shortlist_1K.csv', header=None).to_numpy()\n",
213 | " \n",
214 | " k_recall = (get_k_recall_at_N(query_labels, neighbors, k, N))\n",
215 | " print(f'{k}-Recall@{N} = {k_recall}')\n",
216 | " \n",
217 | " else:\n",
218 | " mAP, precision, recall, topk, unique_cls = compute_mAP_recall_at_k(query_labels, db_labels, neighbors, k)\n",
219 | " if (metric == 'topk'): print(f'topk, {dim}, {M}, {nprobe}, {topk}')\n",
220 | " elif (metric == 'mAP'): print(f'mAP, {dim}, {M}, {nprobe}, {mAP}')\n",
221 | " elif (metric == 'precision'): print(f'precision, {dim}, {M}, {nprobe}, {precision}')\n",
222 | " elif (metric == 'recall') : print(f'recall, {dim}, {M}, {nprobe}, {recall}')\n",
223 | " elif (metric == 'unique_cls'): print(f'unique_cls, {dim}, {M}, {nprobe}, {unique_cls}')\n",
224 | " else: raise Exception(\"Unsupported metric!\")"
225 | ]
226 | },
227 | {
228 | "cell_type": "markdown",
229 | "id": "5f2ea629",
230 | "metadata": {},
231 | "source": [
232 | "## Example: Traditional Retrieval Metrics (Top-1, mAP, Recall)"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": 6,
238 | "id": "a336eb99",
239 | "metadata": {},
240 | "outputs": [
241 | {
242 | "name": "stdout",
243 | "output_type": "stream",
244 | "text": [
245 | "Index: ivfpq\n",
246 | "metric, D, M, nprobe, value\n",
247 | "topk, 16, 8, 1, 0.6775\n",
248 | "topk, 32, 8, 1, 0.6861\n",
249 | "mAP, 16, 8, 1, 0.6306868079365078\n",
250 | "mAP, 32, 8, 1, 0.6374524079365079\n",
251 | "recall, 16, 8, 1, 0.05151807692307692\n",
252 | "recall, 32, 8, 1, 0.051838800000000004\n"
253 | ]
254 | }
255 | ],
256 | "source": [
257 | "# Example evaluation for IVFPQ\n",
258 | "iterator = [16, 32]\n",
259 | "print(\"Index:\", SEARCH_INDEX)\n",
260 | "print(\"metric, D, M, nprobe, value\")\n",
261 | "for M in [8]:\n",
262 | " for nprobe in [1]:\n",
263 | " print_metrics(iterator, [1], 'topk', nprobe)\n",
264 | " print_metrics(iterator, [10], 'mAP', nprobe)\n",
265 | " print_metrics(iterator, [100], 'recall', nprobe)"
266 | ]
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "id": "8ce6be99",
271 | "metadata": {},
272 | "source": [
273 | "## Example: ANNS Metric: k-Recall@N"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": 64,
279 | "id": "4d5663b9",
280 | "metadata": {},
281 | "outputs": [
282 | {
283 | "name": "stdout",
284 | "output_type": "stream",
285 | "text": [
286 | "k-recall@N GT: (50000, 40)\n",
287 | "40-Recall@2048 = 0.2071915\n",
288 | "k-recall@N GT: (50000, 40)\n",
289 | "40-Recall@2048 = 0.311641\n",
290 | "k-recall@N GT: (50000, 40)\n",
291 | "40-Recall@2048 = 0.377283\n",
292 | "k-recall@N GT: (50000, 40)\n",
293 | "40-Recall@2048 = 0.4137225\n"
294 | ]
295 | }
296 | ],
297 | "source": [
298 | "USE_K_RECALL_AT_N = True\n",
299 | "SEARCH_INDEX = 'hnsw32'\n",
300 | "iterator = [8, 16, 32, 64]\n",
301 | "\n",
302 | "print_metrics(iterator, [40], 'krecall', nprobe=1, N=2048)"
303 | ]
304 | }
305 | ],
306 | "metadata": {
307 | "kernelspec": {
308 | "display_name": "Python 3 (ipykernel)",
309 | "language": "python",
310 | "name": "python3"
311 | },
312 | "language_info": {
313 | "codemirror_mode": {
314 | "name": "ipython",
315 | "version": 3
316 | },
317 | "file_extension": ".py",
318 | "mimetype": "text/x-python",
319 | "name": "python",
320 | "nbconvert_exporter": "python",
321 | "pygments_lexer": "ipython3",
322 | "version": "3.11.3"
323 | }
324 | },
325 | "nbformat": 4,
326 | "nbformat_minor": 5
327 | }
328 |
--------------------------------------------------------------------------------
/adanns/diskann/README.md:
--------------------------------------------------------------------------------
1 | # AdANNS-DiskANN
2 | AdANNS-DiskANN is a variant of [DiskANN](https://github.com/microsoft/DiskANN), a web-scale graph-based ANNS index capable of serving queries from both RAM and Disk (cheap SSDs).
3 |
4 |
5 |
6 |
7 |
8 |
9 | We provide a self-contained pipeline in [adanns-diskann.ipynb](adanns-diskann.ipynb), which requires a build of DiskANN from the [original codebase](https://github.com/microsoft/DiskANN); the build steps are summarized below:
10 |
11 | ```
12 | sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format \
13 | libboost-all-dev libmkl-full-dev
14 |
15 | mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
16 | ```
17 |
18 | ## Notebook Overview
19 | [adanns-diskann.ipynb](adanns-diskann.ipynb) is broadly organized as:
20 | 1. Data preprocessing: convert MR or RR embeddings (fp32 `np.ndarray`) to DiskANN's binary `.fbin` format (see the sketch after this list)
21 | 2. Generate exact-search "ground truth" used for k-recall@N
22 | 3. Build In-Memory or SSD DiskANN index on MR or RR
23 | 4. Search the built indices to generate k-NN arrays
24 | 5. Evaluate the k-NN arrays with and without reranking
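25 | 
26 | For reference, the `.fbin` layout written in step 1 is a 4-byte little-endian point count, a 4-byte little-endian dimension, and then the row-major float32 payload. A minimal conversion sketch (the helper name and file paths below are illustrative, not part of the pipeline):
27 | 
28 | ```
29 | import numpy as np
30 | 
31 | def to_fbin(embeddings, out_path):
32 |     # Write an (N, d) array in DiskANN's .fbin layout:
33 |     # uint32 point count, uint32 dimension, then row-major float32 data.
34 |     emb = np.ascontiguousarray(embeddings, dtype=np.float32)
35 |     n, d = emb.shape
36 |     with open(out_path, 'wb') as f:
37 |         f.write(int(n).to_bytes(4, 'little'))  # 4-byte point count
38 |         f.write(int(d).to_bytes(4, 'little'))  # 4-byte dimension
39 |         f.write(emb.tobytes())                 # raw float32 data
40 | 
41 | # e.g., slice a 2048-d MR database down to d=64 before writing:
42 | # to_fbin(np.load('1K_train_mrl1_e0_ff2048-X.npy')[:, :64], 'database_d64.fbin')
43 | ```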
--------------------------------------------------------------------------------
/adanns/diskann/adanns-diskann.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "650d31f4",
6 | "metadata": {},
7 | "source": [
8 | "## Dataset Preparation for DiskANN"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "d6dbab11",
15 | "metadata": {
16 | "scrolled": false
17 | },
18 | "outputs": [],
19 | "source": [
20 | "import numpy as np\n",
21 | "import sys\n",
22 | "import os\n",
23 | "\n",
24 | "embeddings_root = 'path/to/embeddings/'\n",
25 | "\n",
26 | "def generate_bin_data_from_ndarray(embedding_path, bin_out_path, embedding_dims):\n",
27 | " data_orig = np.load(embedding_path)\n",
28 | " for d in embedding_dims:\n",
29 | " data_sliced = data_orig[:, :d]\n",
30 | " outfile = bin_out_path+\"_d\"+str(d)+\".fbin\"\n",
31 | " print(outfile.split(\"/\")[-1])\n",
32 | " print(\"Array sliced: \", data_sliced.shape)\n",
33 | " data_sliced.astype('float32').tofile(\"temp\")\n",
34 | "\n",
35 | " num_points = data_sliced.shape[0].to_bytes(4, 'little')\n",
36 | " data_dim = data_sliced.shape[1].to_bytes(4, 'little')\n",
37 | "\n",
38 | " with open(\"temp\", \"rb\") as old, open(outfile, \"wb\") as new:\n",
39 | " new.write(num_points)\n",
40 | " new.write(data_dim)\n",
41 | " new.write(old.read())\n",
42 | " \n",
43 | " os.remove(\"temp\")"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "b69e102e",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "nesting_list = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
54 | "for d in nesting_list:\n",
55 | " generate_bin_data_from_ndarray(embeddings_root+\"1K_train_mrl0_e0_ff\"+str(d)+\"-X.npy\", \"../build/data/rr-resnet50/fbin/database\", [d])\n",
56 | " print()\n",
57 | " generate_bin_data_from_ndarray(embeddings_root+\"1K_val_mrl0_e0_ff\"+str(d)+\"-X.npy\", \"../build/data/rr-resnet50/fbin/queries\", [d])"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "id": "5db9a1eb",
63 | "metadata": {},
64 | "source": [
65 | "## Generate Exact Search ground truth from queries and database"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 2,
71 | "id": "45584765",
72 | "metadata": {
73 | "scrolled": true
74 | },
75 | "outputs": [],
76 | "source": [
77 | "%%bash\n",
78 | "use_mrl=mr # mr or rr \n",
79 | "\n",
80 | "for d in 8 16 32 64 128 256 512 1024 2048\n",
81 | "do\n",
82 | " ./../build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 \\\n",
83 | " --base_file ../build/data/{use_mrl}-resnet50/fbin/database_d$d.fbin \\\n",
84 | " --query_file ../build/data/{use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n",
85 | " --gt_file ../build/data/{use_mrl}-resnet50/exact_gt100/${use_mrl}_r50_queries_d$d\"\"_gt100 --K 100\n",
86 | "done"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "id": "a21fb4fd",
92 | "metadata": {},
93 | "source": [
94 | "## Build DiskANN In-Memory Index"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "d5ae6f7d",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "%%bash\n",
105 | "opq_bytes=32\n",
106 | "use_mrl=mrl # mr or rr\n",
107 | "\n",
108 | "for use_mrl in mrl\n",
109 | "do\n",
110 | " for d in 32 64 128 256 512 1024 2048\n",
111 | " do\n",
112 | " echo -e \"Building index ${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d\\n\"\n",
113 | " ./../build/tests/build_memory_index --data_type float --dist_fn l2 \\\n",
114 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n",
115 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/memory-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d \\\n",
116 | " -R 64 -L 100 --alpha 1.2 --build_PQ_bytes ${opq_bytes} --use_opq\n",
117 | " done\n",
118 | "done"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "95b04524",
124 | "metadata": {},
125 | "source": [
126 | "## Build DiskANN SSD Index"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "id": "4cd4d51b",
133 | "metadata": {
134 | "scrolled": true
135 | },
136 | "outputs": [],
137 | "source": [
138 | "%%bash\n",
139 | "opq_bytes=48\n",
140 | "use_mrl=rr\n",
141 | "reorder=disk-index-no-reorder\n",
142 | "\n",
143 | "# Disable post-hoc re-ranking by setting PQ_disk_bytes = build_PQ_bytes\n",
144 | "for opq_bytes in 32 48 64\n",
145 | "do\n",
146 | " for d in 1024\n",
147 | " do\n",
148 | " echo -e \"Building disk OPQ index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n",
149 | " ./../build/tests/build_disk_index --data_type float --dist_fn l2 \\\n",
150 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n",
151 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/disk-index-no-reorder/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n",
152 | " -R 64 -L 100 -B 0.3 -M 40 --PQ_disk_bytes $opq_bytes --build_PQ_bytes $opq_bytes --use_opq \n",
153 | " done\n",
154 | "done\n",
155 | "\n",
156 | "# Build index with implicit post-hoc full-precision reranking\n",
157 | "for opq_bytes in 32 48 64\n",
158 | "do\n",
159 | " for d in 128 1024\n",
160 | " do\n",
161 | " ./../build/tests/build_disk_index --data_type float --dist_fn l2 \\\n",
162 | " --data_path ../build/data/${use_mrl}-resnet50/fbin/database_d$d.fbin \\\n",
163 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/disk-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n",
164 | " -R 64 -L 100 -B 0.3 -M 40 --build_PQ_bytes $opq_bytes --use_opq \n",
165 | " echo -e \"Build index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n",
166 | " done\n",
167 | "done"
168 | ]
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "id": "cad290ff",
173 | "metadata": {},
174 | "source": [
175 | "## Search DiskANN Memory Index"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 1,
181 | "id": "2685a209",
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "%%bash\n",
186 | "opq_bytes=32\n",
187 | "\n",
188 | "for use_mrl in rr mr\n",
189 | "do\n",
190 | " for d in 32 64 128 256 512 1024 2048\n",
191 | " do\n",
192 | " ./../build/tests/search_memory_index --data_type float --dist_fn l2 \\\n",
193 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/memory-index/${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d \\\n",
194 | " --query_file ../build/data/${use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n",
195 | " --gt_file ../build/data/${use_mrl}-resnet50/exact_gt100/mrlr50_queries_d$d\"\"_gt100 \\\n",
196 | " -K 100 -L 100 --result_path ../build/data/${use_mrl}-resnet50/res/memory-index/d$d/opq${opq_bytes}\n",
197 | " echo -e \"Searched index ${use_mrl}1K_opq${opq_bytes}_R64_L100_A1.2_d$d\\n\"\n",
198 | " done\n",
199 | "done"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "id": "e7627d51",
205 | "metadata": {},
206 | "source": [
207 | "## Search DiskANN SSD Index"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "id": "35655236",
214 | "metadata": {
215 | "scrolled": true
216 | },
217 | "outputs": [],
218 | "source": [
219 | "%%bash\n",
220 | "opq_bytes=48\n",
221 | "use_mrl=mrl\n",
222 | "reorder=disk-index\n",
223 | "\n",
224 | "for d in 1024\n",
225 | "do\n",
226 | " for W in 2 8 16 32 # search quality\n",
227 | " do\n",
228 | " ./../build/tests/search_disk_index --data_type float --dist_fn l2 \\\n",
229 | " --index_path_prefix ../build/data/${use_mrl}-resnet50/${reorder}/${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d \\\n",
230 | " --query_file ../build/data/${use_mrl}-resnet50/fbin/queries_d$d.fbin \\\n",
231 | " --gt_file ../build/data/${use_mrl}-resnet50/exact_gt100/mrlr50_queries_d$d\"\"_gt100 \\\n",
232 | " -K 100 -L 100 -W ${W} --num_nodes_to_cache 100000 --result_path ../build/data/${use_mrl}-resnet50/res/${reorder}/d$d/opq${opq_bytes}_W$W\n",
233 | " echo -e \"Searched index ${use_mrl}1K_opq${opq_bytes}_R64_L100_B0.3_d$d\\n\"\n",
234 | " done\n",
235 | "done"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "id": "38563959",
241 | "metadata": {},
242 | "source": [
243 | "# DiskANN Eval"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 1,
249 | "id": "cf835299",
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "D = 2048\n",
254 | "CONFIG = 'mr' # ['mr', 'rr']\n",
255 | "NESTING = CONFIG == 'mr'\n",
256 | "DISKANN_INDEX = 'memory-index' # disk-index\n",
257 | "DATASET = '1K' # ['1K', '4K', 'V2']"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 2,
263 | "id": "e3dc60ac",
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "def compute_mAP_recall_at_k(val_classes, db_classes, neighbors, k):\n",
268 | " APs = list()\n",
269 | " precision, recall, topk, majvote, unique_cls = [], [], [], [], []\n",
270 | " \n",
271 | " for i in range(val_classes.shape[0]): # Compute precision for each vector's list of k-nn\n",
272 | " target = val_classes[i]\n",
273 | " indices = neighbors[i, :][:k] # k neighbor list for ith val vector\n",
274 | " labels = db_classes[indices]\n",
275 | " matches = (labels == target)\n",
276 | " \n",
277 | " # Number of unique classes\n",
278 | " unique_cls.append(len(np.unique(labels)))\n",
279 | " \n",
280 | " # topk\n",
281 | " hits = np.sum(matches)\n",
282 | " if hits>0:\n",
283 | " topk.append(1)\n",
284 | " else:\n",
285 | " topk.append(0)\n",
286 | " \n",
287 | " # true positive counts\n",
288 | " tps = np.cumsum(matches)\n",
289 | "\n",
290 | " # recall\n",
291 | " recall.append(np.sum(matches)/1300)\n",
292 | " precision.append(np.sum(matches)/k)\n",
293 | "\n",
294 | " # precision values\n",
295 | " precs = tps.astype(float) / np.arange(1, k + 1, 1)\n",
296 | " APs.append(np.sum(precs[matches.squeeze()]) / k)\n",
297 | "\n",
298 | " return np.mean(APs), np.mean(precision), np.mean(recall), np.mean(topk), majvote, np.mean(unique_cls)"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 3,
304 | "id": "1669e2da",
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "def print_metrics(CONFIG, nesting_list, shortlist, metric, nprobe=1):\n",
309 | " if NESTING:\n",
310 | " # Database: 1.2M x 1 for imagenet1k\n",
311 | " db_labels = np.load(embeddings_root + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n",
312 | " \n",
313 | " # Query set: 50K x 1 for imagenet1k\n",
314 | " query_labels = np.load(embeddings_root + DATASET + \"_val_mrl1_e0_ff2048-y.npy\")\n",
315 | " \n",
316 | " for dim in nesting_list:\n",
317 | " if opq > dim:\n",
318 | " continue\n",
319 | " # Load database and query set for fixed feature models\n",
320 | " if not NESTING:\n",
321 | " db_labels = np.load(embeddings_root + DATASET + \"_train_mrl1_e0_ff2048-y.npy\")\n",
322 | " query_labels = np.load(embeddings_root + DATASET + \"_val_mrl0_e0_ff\"+str(D)+\"-y.npy\")\n",
323 | " \n",
324 | " for W in [32]:\n",
325 | " row = [dim, opq, W]\n",
326 | " fileName = f'/home/jupyter/DiskANN/build/data/{CONFIG}-resnet50/res/{DISKANN_INDEX}/d{dim}/opq{opq}_100_idx_uint32.bin'\n",
327 | " print(fileName)\n",
328 | " with open(fileName, 'rb') as f:\n",
329 | " data = np.fromfile(f, dtype='
--------------------------------------------------------------------------------
/adanns/dpr-nq/README.md:
--------------------------------------------------------------------------------
4 | 
5 |
6 |
7 | We follow the setup of the [Dense Passage Retriever](https://github.com/facebookresearch/DPR) (DPR) repo: a Wikipedia corpus of 21 million passages paired with the Natural Questions (NQ) dataset in the open-domain QA setting. AdANNS with DPR on NQ is organized as:
8 | 1. Index training for IP, L2, IVF, OPQ, and IVF+OPQ (a minimal FAISS sketch follows this list), currently supported in two modes:
9 |     - In batches (for RAM-constrained systems, ~2 hours build time)
10 |     - One shot (~120G peak RAM usage, ~2 minutes build time)
11 | 2. DPR eval (available on request)
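12 | 
13 | A minimal sketch of the index-training step with FAISS (the random array below is only a stand-in for the 21M DPR passage embeddings; the actual pipeline lives in [adanns-nq.ipynb](adanns-nq.ipynb)):
14 | 
15 | ```
16 | import faiss
17 | import numpy as np
18 | 
19 | dim, ncell, M = 768, 10, 32                           # embedding dim, IVF cells, OPQ sub-quantizers
20 | emb = np.random.rand(100_000, dim).astype('float32')  # stand-in for the DPR passage embeddings
21 | ids = np.arange(emb.shape[0], dtype='int64')          # stand-in passage ids
22 | faiss.normalize_L2(emb)
23 | 
24 | # OPQ rotation -> IVF coarse quantizer -> 8-bit PQ codes, wrapped so passages keep their ids
25 | index = faiss.IndexIDMap2(faiss.index_factory(dim, f'OPQ{M},IVF{ncell},PQ{M}x8'))
26 | index.train(emb)
27 | index.add_with_ids(emb, ids)
28 | faiss.write_index(index, f'dpr-nq-dim{dim}_IVFOPQ_cell{ncell}_M{M}.faiss')
29 | ```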
--------------------------------------------------------------------------------
/adanns/dpr-nq/adanns-nq.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "id": "4ef1bef5",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from loguru import logger\n",
11 | "import pyarrow as pa\n",
12 | "import faiss\n",
13 | "from tqdm import tqdm\n",
14 | "import numpy as np\n",
15 | "import time"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 5,
21 | "id": "80a238e4",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "# Specify These\n",
26 | "config = 'MR' # MR, RR\n",
27 | "index_type = 'IVFOPQ' # IP, L2, IVF, OPQ, IVFOPQ\n",
28 | "train_batches = False # Set to True if system has sufficient RAM\n",
29 | "DPR_root = '/mnt/disks/experiments/DPR/'"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 8,
35 | "id": "5861db59",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "if config == 'MR':\n",
40 | " config_name = 'dpr-nq-d768_384_192_96_48-wiki' # MR\n",
41 | "else:\n",
42 | " config_name = 'dpr-nq-d768-wiki' # RR-768\n",
43 | " \n",
44 | "embeddings_file = f'{DPR_root}results/embed/{config_name}.arrow'\n",
45 | "emb_data = pa.ipc.open_file(pa.memory_map(embeddings_file, \"rb\")).read_all()"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "4d1754f2",
51 | "metadata": {},
52 | "source": [
53 | "## Batched Index Training (RAM-constrained)\n",
54 | "Learn Exact Search Indices (with IP distance) in batches over 21M passages."
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 7,
60 | "id": "dd9cb0db",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "# Train exact index with database and queries with embedding size 'dim' and write to disk\n",
65 | "def batched_train(dim): \n",
66 | " index_file = f'results/embed/exact-index/{config_name}-dim{dim}_{index_type}_batched.faiss'\n",
67 | " \n",
68 | " sub_index = faiss.IndexFlatIP(dim)\n",
69 | " faiss_index = faiss.IndexIDMap2(sub_index)\n",
70 | "\n",
71 | " total = 0\n",
72 | " for batch in tqdm(emb_data.to_batches()):\n",
73 | " batch_data = batch.to_pydict()\n",
74 | " psg_ids = np.array(batch_data[\"id\"])\n",
75 | "\n",
76 | " token_emb = np.array(batch_data[\"embedding\"], dtype=np.float32)\n",
77 | " token_emb = np.ascontiguousarray(token_emb[:, :dim]) # Shape: (8192, dim)\n",
78 | " faiss_index.add_with_ids(token_emb, psg_ids)\n",
79 | "\n",
80 | " total += len(psg_ids)\n",
81 | " if total % 1000 == 0:\n",
82 | " logger.info(f\"indexed {total} passages\")\n",
83 | "\n",
84 | " faiss.write_index(faiss_index, str(index_file))\n",
85 | "\n",
86 | "if(train_batches):\n",
87 | " batched_train(dim=768)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "a46ce428",
93 | "metadata": {},
94 | "source": [
95 | "## Full Training (High peak RAM Usage ~120G)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 3,
101 | "id": "5ee63112",
102 | "metadata": {},
103 | "outputs": [
104 | {
105 | "name": "stdout",
106 | "output_type": "stream",
107 | "text": [
108 | "(21015324,)\n",
109 | "(21015324, 768) float32\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "if not train_batches:\n",
115 | " psg_ids = np.array(emb_data['id'])\n",
116 | " print(psg_ids.shape) # Passage IDs\n",
117 | "\n",
118 | " # Takes ~5 min on our system\n",
119 | " token_emb = np.array(emb_data[\"embedding\"])\n",
120 | "\n",
121 | " token_emb = np.hstack(token_emb)\n",
122 | "\n",
123 | " token_emb = token_emb.reshape(21015324, -1)\n",
124 | " print(token_emb.shape, token_emb.dtype) # Token Embeddings\n",
125 | "else:\n",
126 | " raise Exception(\"Insufficient RAM to train on entire data!\")"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 6,
132 | "id": "1adcd6c5",
133 | "metadata": {},
134 | "outputs": [
135 | {
136 | "name": "stdout",
137 | "output_type": "stream",
138 | "text": [
139 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M8.faiss\n",
140 | "Adding DB: (21015324, 768)\n",
141 | "Time to build index with d=768 : 1723.532558\n",
142 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss\n",
143 | "Adding DB: (21015324, 768)\n",
144 | "Time to build index with d=768 : 1926.746812\n",
145 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M32.faiss\n",
146 | "Adding DB: (21015324, 768)\n",
147 | "Time to build index with d=768 : 2302.668599\n",
148 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M48.faiss\n",
149 | "Adding DB: (21015324, 768)\n",
150 | "Time to build index with d=768 : 2746.178078\n",
151 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M64.faiss\n",
152 | "Adding DB: (21015324, 768)\n",
153 | "Time to build index with d=768 : 2214.350993\n",
154 | "Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M96.faiss\n",
155 | "Adding DB: (21015324, 768)\n",
156 | "Time to build index with d=768 : 2294.547853\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "ncell=10 # Number of IVF cells\n",
162 | "dims=[768] # Embedding dims to train indices over\n",
163 | "Ms=[8, 16, 32, 48, 64, 96] # Number of PQ sub-quantizers for IVF+OPQ\n",
164 | "\n",
165 | "for M in Ms:\n",
166 | " for dim in dims:\n",
167 | " if M > dim or dim%M!=0:\n",
168 | " print(\"Skipping (d,M) : (%d, %d)\" %(dim, M))\n",
169 | " continue\n",
170 | " \n",
171 | " token_emb_sliced = np.ascontiguousarray(token_emb[:, :dim])\n",
172 | " faiss.normalize_L2(token_emb_sliced)\n",
173 | " print(\"Adding DB: \", token_emb_sliced.shape)\n",
174 | " print(f'Generating {index_type} index on config: {config_name}')\n",
175 | " \n",
176 | " tic = time.time()\n",
177 | " # Flat L2 Index\n",
178 | " if index_type == 'IP':\n",
179 | " index_file = f'results/embed/IP/{config_name}-dim{dim}_IP.faiss'\n",
180 | " sub_index = faiss.IndexFlatIP(dim)\n",
181 | " faiss_index = faiss.IndexIDMap2(sub_index)\n",
182 | "\n",
183 | " elif index_type == 'L2':\n",
184 | " index_file = f'results/embed/L2/{config_name}-dim{dim}_L2.faiss'\n",
185 | " sub_index = faiss.IndexFlatL2(dim)\n",
186 | " faiss_index = faiss.IndexIDMap2(sub_index)\n",
187 | "\n",
188 | " elif index_type == 'IVF':\n",
189 | " index_file = f'results/embed/IVF/{config_name}-dim{dim}_IVF_ncell{ncell}.faiss'\n",
190 | " quantizer = faiss.IndexFlatL2(dim)\n",
191 | " faiss_index = faiss.IndexIVFFlat(quantizer, dim, ncell)\n",
192 | " faiss_index.train(token_emb_sliced)\n",
193 | " \n",
194 | " elif index_type == 'OPQ':\n",
195 | " index_file = f'results/embed/OPQ/{config_name}-dim{dim}_OPQ_M{M}_nbits8.faiss'\n",
196 | " opq_train_db_indices = np.random.choice(token_emb_sliced.shape[0], 500000, replace=False)\n",
197 | " opq_train_db = token_emb_sliced[opq_train_db_indices]\n",
198 | " sub_index = faiss.index_factory(dim, f\"OPQ{M},PQ{M}x{8}\")\n",
199 | " faiss_index = faiss.IndexIDMap2(sub_index)\n",
200 | " faiss_index.train(opq_train_db)\n",
201 | "\n",
202 | " elif index_type == 'IVFOPQ':\n",
203 | " index_file = f'results/embed/IVFOPQ/{config_name}-dim{dim}_IVFOPQ_cell{ncell}_M{M}_nbits8.faiss'\n",
204 | " sub_index = faiss.index_factory(dim, f\"OPQ{M},IVF{ncell},PQ{M}x{8}\")\n",
205 | " faiss_index = faiss.IndexIDMap2(sub_index)\n",
206 | " faiss_index.train(token_emb_sliced)\n",
207 | " \n",
208 | " faiss_index.add_with_ids(token_emb_sliced, psg_ids)\n",
209 | " faiss.write_index(faiss_index, str(index_file))\n",
210 | " toc = time.time()\n",
211 | " \n",
212 | " print(\"Generated \", index_file)\n",
213 | " print(\"Time to build index with d=%d : %f\" %(dim, toc-tic))"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "id": "daef008c",
219 | "metadata": {},
220 | "source": [
221 | "# Search (restart kernel for memory)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 2,
227 | "id": "dbb3919e",
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "2023-05-21 20:04:27.733 | INFO | __main__:batch_eval_dataset:94 - init Retriever from model_ckpt=ckpt/dpr-nq-d768_384_192_96_48\n",
235 | "2023-05-21 20:04:35.608 | INFO | __main__:batch_eval_dataset:100 - loading index_file=results/embed/IVFOPQ/dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss...\n",
236 | "2023-05-21 20:04:41.547 | INFO | __main__:batch_eval_dataset:107 - loading passage_db_file=data/psgs-w100.lmdb...\n",
237 | "2023-05-21 20:04:41.758 | INFO | __main__:batch_eval_dataset:114 - loading QA pairs from qas-data/nq-test.csv\n",
238 | "2023-05-21 20:04:41.803 | INFO | __main__:batch_eval_dataset:119 - computing query embeddings...\n",
239 | "2023-05-21 20:04:41.804 | INFO | __main__:batch_eval_dataset:121 - begin searching max(top_k)=200 passage for 3610 question...\n",
240 | "search 1668.9 queries/s, checking answers: 100%|██████████| 8/8 [10:52<00:00, 81.59s/it]\n",
241 | "2023-05-21 20:15:34.546 | INFO | __main__:batch_eval_dataset:154 - #total examples: 3610\n",
242 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@1:0.26204986149584486 correct_samples:946\n",
243 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@5:0.46204986149584487 correct_samples:1668\n",
244 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@20:0.5997229916897507 correct_samples:2165\n",
245 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@100:0.714404432132964 correct_samples:2579\n",
246 | "2023-05-21 20:15:34.567 | INFO | __main__:batch_eval_dataset:166 - precision@200:0.7484764542936289 correct_samples:2702\n",
247 | "Finished Processing!\n",
248 | "\n",
249 | "2023-05-21 20:15:40.640 | INFO | __main__:batch_eval_dataset:94 - init Retriever from model_ckpt=ckpt/dpr-nq-d768_384_192_96_48\n",
250 | "2023-05-21 20:15:43.152 | INFO | __main__:batch_eval_dataset:100 - loading index_file=results/embed/IVFOPQ/dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M32.faiss...\n",
251 | "2023-05-21 20:15:50.740 | INFO | __main__:batch_eval_dataset:107 - loading passage_db_file=data/psgs-w100.lmdb...\n",
252 | "2023-05-21 20:15:50.815 | INFO | __main__:batch_eval_dataset:114 - loading QA pairs from qas-data/nq-test.csv\n",
253 | "2023-05-21 20:15:50.851 | INFO | __main__:batch_eval_dataset:119 - computing query embeddings...\n",
254 | "2023-05-21 20:15:50.852 | INFO | __main__:batch_eval_dataset:121 - begin searching max(top_k)=200 passage for 3610 question...\n",
255 | "search 1811.9 queries/s, checking answers: 100%|██████████| 8/8 [06:27<00:00, 48.38s/it]\n",
256 | "2023-05-21 20:22:17.956 | INFO | __main__:batch_eval_dataset:154 - #total examples: 3610\n",
257 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@1:0.33407202216066484 correct_samples:1206\n",
258 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@5:0.531578947368421 correct_samples:1919\n",
259 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@20:0.6493074792243767 correct_samples:2344\n",
260 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@100:0.7401662049861496 correct_samples:2672\n",
261 | "2023-05-21 20:22:17.959 | INFO | __main__:batch_eval_dataset:166 - precision@200:0.7678670360110803 correct_samples:2772\n",
262 | "Finished Processing!\n",
263 | "\n"
264 | ]
265 | }
266 | ],
267 | "source": [
268 | "%%bash\n",
269 | "split=test\n",
270 | "ds=nq\n",
271 | "\n",
272 | "# Change these\n",
273 | "d=768\n",
274 | "index_type=IVFOPQ\n",
275 | "config_name=dpr-nq-d768_384_192_96_48-wiki\n",
276 | "\n",
277 | "# Modify index_file name to the one built above\n",
278 | "for M in 16 32\n",
279 | "do\n",
280 | " for d in 768\n",
281 | " do\n",
282 | " python rtr/cli/eval_retriever.py \\\n",
283 | " --passage_db_file data/psgs-w100.lmdb \\\n",
284 | " --model_ckpt ckpt/{config_name} \\\n",
285 | " --index_file results/embed/${index_type}/${config_name}-dim${d}_${index_type}_cell${ncell}_M${M}.faiss \\\n",
286 | " --dataset_file qas-data/${ds}-${split}.csv \\\n",
287 | " --save_file results/json/reader-${config_name}-${ds}-${split}-dim${d}.jsonl \\\n",
288 | " --batch_size 512 \\\n",
289 | " --max_question_len 200 \\\n",
290 | " --embedding_size ${d} \\\n",
291 | " --metrics_file results/metrics.json \\\n",
292 | " --binary False \\\n",
293 | " 2>&1 | tee results/logs/eval-${config_name}-${ds}-${split}-dim${d}.log\n",
294 | " echo -e \"Finished Processing!\\n\"\n",
295 | " done\n",
296 | "done"
297 | ]
298 | }
299 | ],
300 | "metadata": {
301 | "kernelspec": {
302 | "display_name": "Python 3 (ipykernel)",
303 | "language": "python",
304 | "name": "python3"
305 | },
306 | "language_info": {
307 | "codemirror_mode": {
308 | "name": "ipython",
309 | "version": 3
310 | },
311 | "file_extension": ".py",
312 | "mimetype": "text/x-python",
313 | "name": "python",
314 | "nbconvert_exporter": "python",
315 | "pygments_lexer": "ipython3",
316 | "version": "3.10.8"
317 | }
318 | },
319 | "nbformat": 4,
320 | "nbformat_minor": 5
321 | }
322 |
--------------------------------------------------------------------------------
/adanns/generate_nn/hnsw_exactl2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "cc5c9e14",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import faiss\n",
12 | "import time\n",
13 | "import pandas as pd\n",
14 | "from os import path, makedirs\n",
15 | "import torch\n",
16 | "import sys\n",
17 | "sys.path.append('../')\n",
18 | "from utils import load_embeddings"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "id": "20116c14",
24 | "metadata": {},
25 | "source": [
26 | "## Configuration Variables"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "id": "3be50514",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "root = '../../../inference_array/resnet50/' # path to database and queryset\n",
37 | "D = 2048 # embedding dim\n",
38 | "hnsw_max_neighbors = 32 # M for HNSW, default=32\n",
39 | "pq_num_subvectors = 32 # m for HNSW+PQ\n",
40 | "\n",
41 | "model = 'mrl' # mrl, rr\n",
42 | "dataset = '1K' # 1K, 4K, V2\n",
43 | "index_type = 'hnsw32' # exactl2, hnsw32, #'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(pq_num_subvectors)\n",
44 | "\n",
45 | "k = 2048 # shortlist length, default is set to the max supported by FAISS\n",
46 | "nesting_list = [8, 16, 32, 64] # embedding dim to loop over"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "f4401b8c",
52 | "metadata": {},
53 | "source": [
54 | "## FAISS Index Building and NN Search"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 3,
60 | "id": "5fd3e4f5",
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "if index_type == 'exactl2' and torch.cuda.device_count() > 0:\n",
65 | " use_gpu = True # GPU inference for exact search\n",
66 | "else:\n",
67 | " use_gpu = False # GPU inference for HNSW is currently not supported by FAISS"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 4,
73 | "id": "f862bf36",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def get_nn(index_type, nesting_list, m=8):\n",
78 | " for retrieval_dim in nesting_list:\n",
79 | " if retrieval_dim > D:\n",
80 | " continue\n",
81 | " \n",
82 | " if index_type == 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m) and retrieval_dim < m:\n",
83 | " continue\n",
84 | " \n",
85 | " if not path.isdir(root+'index_files/'+model+'/'):\n",
86 | " makedirs(root+'index_files/'+model+'/')\n",
87 | " index_file = root+'index_files/'+model+'/'+dataset+'_'+str(retrieval_dim)+'dim_'+index_type+'.index'\n",
88 | "\n",
89 | " _, _, _, _, xb, xq = load_embeddings(model, dataset, retrieval_dim)\n",
90 | "\n",
91 | " # Load or build index\n",
92 | " if path.exists(index_file): # Load index\n",
93 | " print(\"Loading index file: \" + index_file.split(\"/\")[-1])\n",
94 | " cpu_index = faiss.read_index(index_file)\n",
95 | " \n",
96 | " else: # Build index\n",
97 | " print(\"Generating index file: \" + index_file)\n",
98 | "\n",
99 | " d = xb.shape[1] # dimension\n",
100 | "\n",
101 | " start = time.time()\n",
102 | " if index_type == 'exactl2':\n",
103 | " print(\"Building Exact L2 Index\")\n",
104 | " cpu_index = faiss.IndexFlatL2(d) # build the index\n",
105 | " elif index_type == 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m):\n",
106 | " print(\"Building D%d + HNSW%d + PQ%d Index\" % (d, hnsw_max_neighbors, m))\n",
107 | " cpu_index = faiss.IndexHNSWPQ(d, m, hnsw_max_neighbors)\n",
108 | " cpu_index.train(xb)\n",
109 | " elif index_type == f'hnsw{hnsw_max_neighbors}':\n",
110 | " print(\"Building HNSW%d Index\" % hnsw_max_neighbors)\n",
111 | " cpu_index = faiss.IndexHNSWFlat(d, hnsw_max_neighbors)\n",
112 | " else:\n",
113 | " raise Exception(f\"Unsupported Index: {index_type}\")\n",
114 | " \n",
115 | " cpu_index.add(xb) # add vectors to the index\n",
116 | " faiss.write_index(cpu_index, index_file)\n",
117 | " print(\"GPU Index build time= %0.3f sec\" % (time.time() - start))\n",
118 | "\n",
119 | " if use_gpu:\n",
120 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n",
121 | " else:\n",
122 | " index = cpu_index\n",
123 | " \n",
124 | " # Iterate over efSearch (HNSW search probes)\n",
125 | " efsearchlist = [16]\n",
126 | " for efsearch in efsearchlist:\n",
127 | " start = time.time()\n",
128 | " if index_type in ['hnsw32', 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m)]:\n",
129 | " index.hnsw.efSearch = efsearch\n",
130 | " print(\"Searching with Efsearch =\", index.hnsw.efSearch)\n",
131 | " Dist, Ind = index.search(xq, k)\n",
132 | " # print(\"GPU %d-NN search time= %f sec\" % (k, time.time() - start))\n",
133 | " if not path.isdir(root+\"neighbors/\"+model+'/'+index_type):\n",
134 | " makedirs(root+\"neighbors/\"+model+'/'+index_type)\n",
135 | " nn_dir = root+\"neighbors/\"+model+'/'+index_type+\"/\"+index_type+'_efsearch'+str(efsearch)+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(retrieval_dim)+\".csv\"\n",
136 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n",
137 | " \n",
138 | " del index, Dist, Ind"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "id": "f5efe3d2",
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "name": "stdout",
149 | "output_type": "stream",
150 | "text": [
151 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_8dim_hnsw32.index\n",
152 | "Building HNSW32 Index\n",
153 | "GPU Index build time= 57.573 sec\n",
154 | "Searching with Efsearch = 16\n",
155 | "13010622\n",
156 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_16dim_hnsw32.index\n",
157 | "Building HNSW32 Index\n",
158 | "GPU Index build time= 71.297 sec\n",
159 | "Searching with Efsearch = 16\n",
160 | "15721346\n",
161 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_32dim_hnsw32.index\n",
162 | "Building HNSW32 Index\n",
163 | "GPU Index build time= 74.260 sec\n",
164 | "Searching with Efsearch = 16\n",
165 | "16834028\n",
166 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl1K_64dim_hnsw32.index\n",
167 | "Building HNSW32 Index\n",
168 | "GPU Index build time= 77.627 sec\n",
169 | "Searching with Efsearch = 16\n",
170 | "17557581\n"
171 | ]
172 | }
173 | ],
174 | "source": [
175 | "nesting_list = [8, 16, 32, 64]\n",
176 | "get_nn(index_type, nesting_list)"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 5,
182 | "id": "87ea3f22",
183 | "metadata": {
184 | "scrolled": false
185 | },
186 | "outputs": [
187 | {
188 | "name": "stdout",
189 | "output_type": "stream",
190 | "text": [
191 | "Loading index file: 1K_8dim_hnswpq_M32_pq-m8.index\n",
192 | "Searching with Efsearch = 16\n",
193 | "Loading index file: 1K_16dim_hnswpq_M32_pq-m8.index\n",
194 | "Searching with Efsearch = 16\n",
195 | "Loading index file: 1K_32dim_hnswpq_M32_pq-m8.index\n",
196 | "Searching with Efsearch = 16\n",
197 | "Loading index file: 1K_64dim_hnswpq_M32_pq-m8.index\n",
198 | "Searching with Efsearch = 16\n",
199 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_16dim_hnswpq_M32_pq-m16.index\n",
200 | "Building D16 + HNSW32 + PQ16 Index\n",
201 | "GPU Index build time= 180.430 sec\n",
202 | "Searching with Efsearch = 16\n",
203 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_32dim_hnswpq_M32_pq-m16.index\n",
204 | "Building D32 + HNSW32 + PQ16 Index\n",
205 | "GPU Index build time= 180.994 sec\n",
206 | "Searching with Efsearch = 16\n",
207 | "Generating index file: ../../../inference_array/resnet50/index_files/mrl/1K_64dim_hnswpq_M32_pq-m16.index\n",
208 | "Building D64 + HNSW32 + PQ16 Index\n",
209 | "GPU Index build time= 182.374 sec\n",
210 | "Searching with Efsearch = 16\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "# k = 40 to generate Exact Ground Truth for 40-Recall@2048\n",
216 | "# nesting_list = [D] # fixed embedding dimension for RR models\n",
217 | "pq_m_values = [8, 16] # loop over PQ m values\n",
218 | "\n",
219 | "for m in pq_m_values:\n",
220 | " index_type = 'hnswpq_M'+str(hnsw_max_neighbors)+'_pq-m'+str(m)\n",
221 | " get_nn(index_type, nesting_list, m)"
222 | ]
223 | }
224 | ],
225 | "metadata": {
226 | "jupytext": {
227 | "cell_metadata_filter": "-all",
228 | "notebook_metadata_filter": "-all"
229 | },
230 | "kernelspec": {
231 | "display_name": "Python 3 (ipykernel)",
232 | "language": "python",
233 | "name": "python3"
234 | },
235 | "language_info": {
236 | "codemirror_mode": {
237 | "name": "ipython",
238 | "version": 3
239 | },
240 | "file_extension": ".py",
241 | "mimetype": "text/x-python",
242 | "name": "python",
243 | "nbconvert_exporter": "python",
244 | "pygments_lexer": "ipython3",
245 | "version": "3.11.3"
246 | },
247 | "vscode": {
248 | "interpreter": {
249 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc"
250 | }
251 | }
252 | },
253 | "nbformat": 4,
254 | "nbformat_minor": 5
255 | }
256 |
--------------------------------------------------------------------------------
/adanns/generate_nn/ivf-experiments.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "18c0eff4",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import faiss\n",
12 | "import time\n",
13 | "import torch\n",
14 | "import pandas as pd\n",
15 | "from os import path, makedirs\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import sys\n",
18 | "sys.path.append('../')\n",
19 | "from utils import load_embeddings"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "44133407",
25 | "metadata": {},
26 | "source": [
27 | "## CONFIG"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 3,
33 | "id": "99918b2f",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "model = 'mrl' # mrl, rr\n",
38 | "arch = 'resnet50' # resnet18, resnet34, resnet50, resnet101, mobilenetv2\n",
39 | "root = f'../../../inference_array/{arch}'\n",
40 | "dataset = '1K' # 1K, 4K, V2, inat"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 4,
46 | "id": "6a04810b",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "index_type = 'ivf' # ivfpq, ivfsq, lsh, pq, ivf, opq\n",
51 | "\n",
52 | "hnsw_max_neighbors = 32 # 8, 32\n",
53 | "\n",
54 | "k = 2048 # shortlist length, default set to max supported by FAISS\n",
55 | "\n",
56 | "nesting_list_dict = {\n",
57 | " 'vgg19': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],\n",
58 | " 'resnet18': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],\n",
59 | " 'resnet34': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],\n",
60 | " 'resnet50': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048],\n",
61 | " 'resnet101': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048],\n",
62 | " 'mobilenetv2': [10, 20, 40, 80, 160, 320, 640, 1280],\n",
63 | " 'convnext_tiny': [12, 24, 48, 96, 192, 384, 768]\n",
64 | "}\n",
65 | "\n",
66 | "nesting_list = nesting_list_dict[arch]\n",
67 | "\n",
68 | "ivf_configs = {\n",
69 | " 'C1': {'nlist': 2048},\n",
70 | " 'C2': {'nlist': 512},\n",
71 | " 'C3': {'nlist': 128},\n",
72 | " 'C4': {'nlist': 4096},\n",
73 | " 'C5': {'nlist': 8192},\n",
74 | " 'C6': {'nlist': 256},\n",
75 | " 'C7': {'nlist': 1024}\n",
76 | "}\n",
77 | "\n",
78 | "config_id = 'C7'\n",
79 | " \n",
80 | "# nbits = ivfpq_configs[config_id]['nbits'] # nbits used to represent centroid id; total possible is k* = 2**nbits\n",
81 | "\n",
82 | "nlist = ivf_configs[config_id]['nlist'] # how many Voronoi cells (must be >= k*)"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 5,
88 | "id": "91f822b8",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "if index_type == 'exactl2':\n",
93 | " use_gpu = 1 # GPU inference for exact search\n",
94 | "else:\n",
95 | " use_gpu = 0\n",
96 | "\n",
97 | "if use_gpu and faiss.get_num_gpus() == 0:\n",
98 | " raise Exception(\"GPU search is enabled but no GPU was found.\")"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "id": "67db9767",
104 | "metadata": {},
105 | "source": [
106 | "## IVF Experiments"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 7,
112 | "id": "4ca6a7d7",
113 | "metadata": {},
114 | "outputs": [
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "*************************************************\n",
120 | "FF-8 IVF \n",
121 | "\n",
122 | "*************************************************\n",
123 | "Loaded DB: 1K_train_mrl0_e0_ff8-X.npy\n",
124 | "DB size: 39.10 MB\n",
125 | "DB: (1281167, 8), Q: (50000, 8)\n",
126 | "Indexing database: (1281167, 8)\n",
127 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist256_d8\n",
128 | "Training IVF Index: d=8, nlist=256\n",
129 | "IVF Index train time= 0.141 sec\n",
130 | "NProbe: 1\n",
131 | "Searching queries: (50000, 8)\n",
132 | "GPU 2048-NN search time= 2.962470 sec\n",
133 | "\n",
134 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist256_nprobe1_2048shortlist_1K_d8.csv\n",
135 | "\n",
136 | "*************************************************\n",
137 | "Indexing database: (1281167, 8)\n",
138 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist512_d8\n",
139 | "Training IVF Index: d=8, nlist=512\n",
140 | "IVF Index train time= 0.272 sec\n",
141 | "NProbe: 1\n",
142 | "Searching queries: (50000, 8)\n",
143 | "GPU 2048-NN search time= 2.063317 sec\n",
144 | "\n",
145 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist512_nprobe1_2048shortlist_1K_d8.csv\n",
146 | "\n",
147 | "*************************************************\n",
148 | "Indexing database: (1281167, 8)\n",
149 | "Loading index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist1024_d8\n",
150 | "NProbe: 1\n",
151 | "Searching queries: (50000, 8)\n",
152 | "GPU 2048-NN search time= 0.369938 sec\n",
153 | "\n",
154 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist1024_nprobe1_2048shortlist_1K_d8.csv\n",
155 | "\n",
156 | "*************************************************\n",
157 | "Indexing database: (1281167, 8)\n",
158 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist2048_d8\n",
159 | "Training IVF Index: d=8, nlist=2048\n",
160 | "IVF Index train time= 6.959 sec\n",
161 | "NProbe: 1\n",
162 | "Searching queries: (50000, 8)\n",
163 | "GPU 2048-NN search time= 1.799566 sec\n",
164 | "\n",
165 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist2048_nprobe1_2048shortlist_1K_d8.csv\n",
166 | "\n",
167 | "*************************************************\n",
168 | "Indexing database: (1281167, 8)\n",
169 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist4096_d8\n",
170 | "Training IVF Index: d=8, nlist=4096\n",
171 | "IVF Index train time= 30.232 sec\n",
172 | "NProbe: 1\n",
173 | "Searching queries: (50000, 8)\n",
174 | "GPU 2048-NN search time= 1.340459 sec\n",
175 | "\n",
176 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist4096_nprobe1_2048shortlist_1K_d8.csv\n",
177 | "\n",
178 | "*************************************************\n",
179 | "Indexing database: (1281167, 8)\n",
180 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/8/1K_ivf_nlist8192_d8\n",
181 | "Training IVF Index: d=8, nlist=8192\n",
182 | "IVF Index train time= 74.361 sec\n",
183 | "NProbe: 1\n",
184 | "Searching queries: (50000, 8)\n",
185 | "GPU 2048-NN search time= 1.129385 sec\n",
186 | "\n",
187 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/8/ivf_nlist8192_nprobe1_2048shortlist_1K_d8.csv\n",
188 | "\n",
189 | "*************************************************\n",
190 | "*************************************************\n",
191 | "FF-16 IVF \n",
192 | "\n",
193 | "*************************************************\n",
194 | "Loaded DB: 1K_train_mrl0_e0_ff16-X.npy\n",
195 | "DB size: 78.20 MB\n",
196 | "DB: (1281167, 16), Q: (50000, 16)\n",
197 | "Indexing database: (1281167, 16)\n",
198 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist256_d16\n",
199 | "Training IVF Index: d=16, nlist=256\n",
200 | "IVF Index train time= 0.247 sec\n",
201 | "NProbe: 1\n",
202 | "Searching queries: (50000, 16)\n",
203 | "GPU 2048-NN search time= 6.624912 sec\n",
204 | "\n",
205 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist256_nprobe1_2048shortlist_1K_d16.csv\n",
206 | "\n",
207 | "*************************************************\n",
208 | "Indexing database: (1281167, 16)\n",
209 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist512_d16\n",
210 | "Training IVF Index: d=16, nlist=512\n",
211 | "IVF Index train time= 0.319 sec\n",
212 | "NProbe: 1\n",
213 | "Searching queries: (50000, 16)\n",
214 | "GPU 2048-NN search time= 2.174936 sec\n",
215 | "\n",
216 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist512_nprobe1_2048shortlist_1K_d16.csv\n",
217 | "\n",
218 | "*************************************************\n",
219 | "Indexing database: (1281167, 16)\n",
220 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist1024_d16\n",
221 | "Training IVF Index: d=16, nlist=1024\n",
222 | "IVF Index train time= 0.973 sec\n",
223 | "NProbe: 1\n",
224 | "Searching queries: (50000, 16)\n",
225 | "GPU 2048-NN search time= 1.328670 sec\n",
226 | "\n",
227 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n",
228 | "\n",
229 | "*************************************************\n",
230 | "Indexing database: (1281167, 16)\n",
231 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist2048_d16\n",
232 | "Training IVF Index: d=16, nlist=2048\n",
233 | "IVF Index train time= 3.577 sec\n",
234 | "NProbe: 1\n",
235 | "Searching queries: (50000, 16)\n",
236 | "GPU 2048-NN search time= 1.002080 sec\n",
237 | "\n",
238 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist2048_nprobe1_2048shortlist_1K_d16.csv\n",
239 | "\n",
240 | "*************************************************\n",
241 | "Indexing database: (1281167, 16)\n",
242 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist4096_d16\n",
243 | "Training IVF Index: d=16, nlist=4096\n",
244 | "IVF Index train time= 17.197 sec\n",
245 | "NProbe: 1\n",
246 | "Searching queries: (50000, 16)\n",
247 | "GPU 2048-NN search time= 0.797544 sec\n",
248 | "\n",
249 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist4096_nprobe1_2048shortlist_1K_d16.csv\n",
250 | "\n",
251 | "*************************************************\n",
252 | "Indexing database: (1281167, 16)\n",
253 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/16/1K_ivf_nlist8192_d16\n",
254 | "Training IVF Index: d=16, nlist=8192\n",
255 | "IVF Index train time= 35.940 sec\n",
256 | "NProbe: 1\n",
257 | "Searching queries: (50000, 16)\n",
258 | "GPU 2048-NN search time= 0.695007 sec\n",
259 | "\n",
260 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/16/ivf_nlist8192_nprobe1_2048shortlist_1K_d16.csv\n",
261 | "\n",
262 | "*************************************************\n",
263 | "*************************************************\n",
264 | "FF-512 IVF \n",
265 | "\n",
266 | "*************************************************\n",
267 | "Loaded DB: 1K_train_mrl0_e0_ff512-X.npy\n",
268 | "DB size: 2502.28 MB\n",
269 | "DB: (1281167, 512), Q: (50000, 512)\n",
270 | "Indexing database: (1281167, 512)\n",
271 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist256_d512\n",
272 | "Training IVF Index: d=512, nlist=256\n",
273 | "IVF Index train time= 1.085 sec\n",
274 | "NProbe: 1\n",
275 | "Searching queries: (50000, 512)\n",
276 | "GPU 2048-NN search time= 10.056288 sec\n",
277 | "\n",
278 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist256_nprobe1_2048shortlist_1K_d512.csv\n",
279 | "\n",
280 | "*************************************************\n",
281 | "Indexing database: (1281167, 512)\n",
282 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist512_d512\n",
283 | "Training IVF Index: d=512, nlist=512\n",
284 | "IVF Index train time= 2.630 sec\n",
285 | "NProbe: 1\n",
286 | "Searching queries: (50000, 512)\n",
287 | "GPU 2048-NN search time= 5.309681 sec\n",
288 | "\n",
289 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist512_nprobe1_2048shortlist_1K_d512.csv\n",
290 | "\n",
291 | "*************************************************\n",
292 | "Indexing database: (1281167, 512)\n",
293 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist1024_d512\n",
294 | "Training IVF Index: d=512, nlist=1024\n",
295 | "IVF Index train time= 8.684 sec\n",
296 | "NProbe: 1\n",
297 | "Searching queries: (50000, 512)\n",
298 | "GPU 2048-NN search time= 2.672392 sec\n",
299 | "\n",
300 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n",
301 | "\n",
302 | "*************************************************\n",
303 | "Indexing database: (1281167, 512)\n",
304 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist2048_d512\n",
305 | "Training IVF Index: d=512, nlist=2048\n",
306 | "IVF Index train time= 35.140 sec\n",
307 | "NProbe: 1\n",
308 | "Searching queries: (50000, 512)\n",
309 | "GPU 2048-NN search time= 1.662383 sec\n",
310 | "\n",
311 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist2048_nprobe1_2048shortlist_1K_d512.csv\n",
312 | "\n",
313 | "*************************************************\n",
314 | "Indexing database: (1281167, 512)\n",
315 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist4096_d512\n",
316 | "Training IVF Index: d=512, nlist=4096\n",
317 | "IVF Index train time= 130.656 sec\n",
318 | "NProbe: 1\n",
319 | "Searching queries: (50000, 512)\n",
320 | "GPU 2048-NN search time= 1.486253 sec\n",
321 | "\n",
322 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist4096_nprobe1_2048shortlist_1K_d512.csv\n",
323 | "\n",
324 | "*************************************************\n",
325 | "Indexing database: (1281167, 512)\n",
326 | "Generating index file: ../../../inference_array/resnet50/index_files/ff-ivf/512/1K_ivf_nlist8192_d512\n",
327 | "Training IVF Index: d=512, nlist=8192\n"
328 | ]
329 | },
330 | {
331 | "name": "stdout",
332 | "output_type": "stream",
333 | "text": [
334 | "IVF Index train time= 397.688 sec\n",
335 | "NProbe: 1\n",
336 | "Searching queries: (50000, 512)\n",
337 | "GPU 2048-NN search time= 2.997688 sec\n",
338 | "\n",
339 | "Writing NN csv to ../../../inference_array/resnet50/neighbors/ff-ivf/512/ivf_nlist8192_nprobe1_2048shortlist_1K_d512.csv\n",
340 | "\n",
341 | "*************************************************\n"
342 | ]
343 | }
344 | ],
345 | "source": [
346 | "Ds = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
347 | "num_cell_list = [256, 512, 1024, 2048, 4096, 8192] # IVF number of clusters\n",
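    "# Larger nlist: higher k-means training cost but smaller cells, so nprobe=1 search is generally faster (see logs above)\n",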
348 | "nprobes = [1, 2, 4]\n",
349 | "\n",
350 | "for D in Ds:\n",
351 | " \n",
352 | " database, queryset, _, _, _, _ = load_embeddings(model, dataset, D, arch=arch)\n",
353 | "\n",
354 | " num_cell_list = [256, 512, 1024, 2048, 4096, 8192] # IVF num of clusters\n",
355 | " for num_cell in num_cell_list:\n",
356 | " if not path.isdir(root+'index_files/'+model+'/'):\n",
357 | " makedirs(root+'index_files/'+model+'/')\n",
358 | "\n",
359 | " print(\"Indexing database: \", database.shape)\n",
360 | " d = database.shape[1]\n",
361 | " index_file = root+'index_files/'+model+'/'+dataset+'_ivf'+'_nlist'+str(num_cell)+\"_d\"+str(d)\n",
362 | " \n",
363 | " # Load or build index\n",
364 | " if path.exists(index_file+'.index'):\n",
365 | "            print(\"Loading index file: \" + index_file)\n",
366 | " cpu_index = faiss.read_index(index_file+'.index')\n",
367 | " else:\n",
368 | " if index_type == 'ivf':\n",
369 | "                print(\"Generating index file: \" + index_file)\n",
370 | " # Generate IVF Index File\n",
371 | " quantizer = faiss.IndexFlatL2(d) # L2 quantizer to assign vectors to Voronoi cells\n",
372 | " cpu_index = faiss.IndexIVFFlat(quantizer, d, num_cell)\n",
373 | "                print(f\"Training IVF Index: d={d}, nlist={num_cell}\")\n",
374 | " ivf_start = time.time()\n",
375 | " cpu_index.train(database)\n",
376 | " ivf_end = time.time()\n",
377 | "                print(\"IVF Index train time= %0.3f sec\" % (ivf_end - ivf_start))\n",
378 | " elif index_type == 'exactl2':\n",
379 | "                print(\"Building Exact L2 Index\")\n",
380 | " cpu_index = faiss.IndexFlatL2(d)\n",
381 | " cpu_index.add(database) # add vectors to the index\n",
382 | " faiss.write_index(cpu_index, index_file+'.index')\n",
383 | " if use_gpu:\n",
384 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n",
385 | " else:\n",
386 | " index = cpu_index\n",
387 | " for nprobe in nprobes: # nprobe for IVF\n",
388 | "            print(f\"NProbe: {nprobe}\")\n",
389 | " print(\"Searching queries: \", queryset.shape)\n",
390 | " start = time.time()\n",
391 | " index.nprobe = nprobe\n",
392 | "            Dist, Ind = index.search(queryset, k) # k -> shortlist length\n",
393 | " end = time.time() - start\n",
394 | "            print(\"%d-NN search time= %f sec\\n\" % (k, end))\n",
395 | " if not path.isdir(root+\"neighbors/\"+model):\n",
396 | " makedirs(root+\"neighbors/\"+model)\n",
397 | "            nn_dir = root+\"neighbors/\"+model+index_type+\"_nlist\"+str(num_cell)+f\"_nprobe{nprobe}_\"+str(k)+\"shortlist_\"+dataset+f\"_d{d}.csv\"\n",
398 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n",
399 | "            print(\"Writing NN csv to %s\\n\" % (nn_dir))\n",
400 | " del index\n",
401 | " print(\"*************************************************\")"
402 | ]
403 | }
404 | ],
405 | "metadata": {
406 | "kernelspec": {
407 | "display_name": "Python 3 (ipykernel)",
408 | "language": "python",
409 | "name": "python3"
410 | },
411 | "language_info": {
412 | "codemirror_mode": {
413 | "name": "ipython",
414 | "version": 3
415 | },
416 | "file_extension": ".py",
417 | "mimetype": "text/x-python",
418 | "name": "python",
419 | "nbconvert_exporter": "python",
420 | "pygments_lexer": "ipython3",
421 | "version": "3.10.7"
422 | }
423 | },
424 | "nbformat": 4,
425 | "nbformat_minor": 5
426 | }
427 |
--------------------------------------------------------------------------------
/adanns/generate_nn/ivfpq_opq_kmeans.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "cc5c9e14",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "import faiss\n",
12 | "import time\n",
13 | "import pandas as pd\n",
14 | "from os import path, makedirs\n",
15 | "import sys\n",
16 | "\n",
17 | "sys.path.append('../')\n",
18 | "from utils import get_duplicate_dist"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "id": "20116c14",
24 | "metadata": {},
25 | "source": [
26 | "## Configuration Variables"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "id": "3be50514",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "root = '../../../inference_array/resnet50/'\n",
37 | "model = 'mrl/' # mrl/, rr/\n",
38 | "\n",
39 | "dataset = '1K' # 1K, 4K, V2\n",
40 | "index_type = 'ivfpq' # ivfpq, ivfsq, opq, kmeans, pq\n",
41 | "use_svd = False\n",
42 | " \n",
43 | "use_gpu = 0 # GPU inference for exact search. Disable for CPU search\n",
44 | "if use_gpu and faiss.get_num_gpus() == 0:\n",
45 | " raise Exception(\"GPU search is enabled but no GPU was found.\")"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "f4401b8c",
51 | "metadata": {},
52 | "source": [
53 | "## FAISS Index Building and NN Search"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 3,
59 | "id": "67862f80",
60 | "metadata": {},
61 | "outputs": [
62 | {
63 | "name": "stdout",
64 | "output_type": "stream",
65 | "text": [
66 | "Loading database: 1K_train_mrl1_e0_ff2048-X.npy\n",
67 | "Loading queries: 1K_val_mrl1_e0_ff2048-X.npy\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "if model == \"mrl/\":\n",
73 | " config = 'mrl1_e0_ff2048'\n",
74 | "elif model == 'rr/':\n",
75 | " config = 'mrl0_e0_ff'\n",
76 | "else: \n",
77 | " raise Exception(\"Unsupported pretrained model.\")\n",
78 | "\n",
79 | "db_npy = dataset + '_train_' + config + '-X.npy'\n",
80 | "query_npy = dataset + '_val_' + config + '-X.npy'\n",
81 | "\n",
82 | "# ImageNetv2 is only a test set; set database to ImageNet-1K\n",
83 | "if dataset == 'V2':\n",
84 | " db_npy = '1K_train_' + config + '-X.npy'\n",
85 | "\n",
86 | "print(\"Loading database: \", db_npy)\n",
87 | "print(\"Loading queries: \", query_npy)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "21ee5576",
93 | "metadata": {},
94 | "source": [
95 | "## SVD Ablation"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 4,
101 | "id": "cca2a7aa",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "def dim_reduce_pca(full_database, low_dim):\n",
106 | " subsampled_db = np.ascontiguousarray(full_database[0::10], dtype=np.float32)\n",
107 | " \n",
108 | " mat = faiss.PCAMatrix (subsampled_db.shape[1], low_dim)\n",
109 | " mat.train(subsampled_db)\n",
110 | " assert mat.is_trained\n",
111 | " \n",
112 | " return mat\n",
113 | "\n",
114 | "def get_svd_data(svd_low_dim):\n",
115 | " print(\"Using PCA for low-dim projection with d= \", svd_low_dim)\n",
116 | "\n",
117 | " database = np.load(root+db_npy)\n",
118 | " queryset = np.load(root+query_npy)\n",
119 | "\n",
120 | " print(\"Original Database: \", database.shape)\n",
121 | " svd_mat = dim_reduce_pca(database, svd_low_dim)\n",
122 | " database_svd_lowdim = svd_mat.apply(database)\n",
123 | " print(\"Low-d Database: \", database_svd_lowdim.shape)\n",
124 | " query_svd_lowdim = svd_mat.apply(queryset)\n",
125 | "\n",
126 | " faiss.normalize_L2(database_svd_lowdim)\n",
127 | " faiss.normalize_L2(query_svd_lowdim)\n",
128 | " \n",
129 | " return database_svd_lowdim, query_svd_lowdim\n",
130 | "\n",
131 | "if use_svd:\n",
132 | " svd_low_dim = 1024\n",
133 | " db_npy_svd = dataset+'_train_'+config+\"_svd\"+str(svd_low_dim)+\"-X.npy\"\n",
134 | " query_npy_svd = dataset+'_val_'+config+\"_svd\"+str(svd_low_dim)+\"-X.npy\"\n",
135 | " database_svd_lowdim, query_svd_lowdim = get_svd_data(svd_low_dim)\n",
136 | " \n",
137 | " if not path.exists(root+db_npy_svd):\n",
138 | " np.save(root+db_npy_svd, database_svd_lowdim)\n",
139 | " if not path.exists(root+query_npy_svd):\n",
140 | " np.save(root+query_npy_svd, query_svd_lowdim)"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "id": "5bdb1ac4",
146 | "metadata": {},
147 | "source": [
148 | "## Full Precision Embedding Loading (MRL)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 5,
154 | "id": "31543d33",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "if model == 'mrl/':\n",
159 | " if use_svd:\n",
160 | " xb = np.load(root+db_npy_svd)\n",
161 | " assert np.count_nonzero(np.isnan(xb)) == 0\n",
162 | " xq = np.load(root+query_npy_svd)\n",
163 | " else:\n",
164 | " xb = np.load(root+db_npy)\n",
165 | " assert np.count_nonzero(np.isnan(xb)) == 0\n",
166 | " xq = np.load(root+query_npy)"
167 | ]
168 | },
169 | {
170 | "cell_type": "markdown",
171 | "id": "2067c840",
172 | "metadata": {},
173 | "source": [
174 | "## Database Indexing and Search"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 6,
180 | "id": "60e69cb5",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "k = 2048 # nearest neighbors shortlist length, default = 2048 = max value supported by FAISS\n",
185 | "# For any PQ index, the iterator is m (number of sub-quantizers); for k-means, it is the number of centroids\n",
186 | "iterator = [8, 16, 32, 64] # M for PQ\n",
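    "# With nbits=8 each sub-vector is encoded in 1 byte, so a PQ/IVFPQ code costs m bytes per vector (e.g. m=32 -> 32 B vs. 4*d B for float32)\n",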
187 | "Ds = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] # embedding dimensionality\n",
188 | "\n",
189 | "# IVF index specific params\n",
190 | "nprobes = [1] # number of search probes (cells to search)\n",
191 | "nbits = 8 # nbits used to represent centroid id; total possible is k* = 2**nbits\n",
192 | "nlist = 1024 # number of cells (must be >= k*)"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 7,
198 | "id": "a798dca2",
199 | "metadata": {},
200 | "outputs": [],
201 | "source": [
202 | "# Return index file name of given dimensionality based on index type\n",
203 | "def get_index_file(dim):\n",
204 | " if index_type in ['ivfpq', 'opq', 'pq']:\n",
205 | " size = 'm'+str(dim)+\"_d\"+str(D)\n",
206 | " elif index_type == 'kmeans':\n",
207 | " size = str(dim)+'ncentroid_'+str(D)+'d'\n",
208 | " elif index_type == 'kmeans-pca':\n",
209 | " size = str(dim)+'ncentroid_'+str(svd_low_dim)+'pcadim'\n",
210 | " else:\n",
211 | " raise Exception(\"Unsupported Index!\")\n",
212 | "\n",
213 | " index_file = root+'index_files/'+model+index_type+\"/\"+dataset+'_'+index_type+'_'+size\n",
214 | " if index_type in ['ivfpq', 'ivfsq']:\n",
215 | " index_file += \"_nbits\"+str(nbits)+'_nlist'+str(nlist)+'.index'\n",
216 | " \n",
217 | " return index_file"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 8,
223 | "id": "87ea3f22",
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "name": "stdout",
228 | "output_type": "stream",
229 | "text": [
230 | "Indexing database (1281167, 8)\n",
231 | "Generating index file: 1K_ivfpq_m8_d8_nbits8_nlist1024.index\n",
232 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d8_nbits8_nlist1024.index\n",
233 | "Searching queries: (50000, 8)\n",
234 | "nprobe: 1\n",
235 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d8.csv\n",
236 | "\n",
237 | "Indexing database (1281167, 16)\n",
238 | "Generating index file: 1K_ivfpq_m8_d16_nbits8_nlist1024.index\n",
239 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d16_nbits8_nlist1024.index\n",
240 | "Searching queries: (50000, 16)\n",
241 | "nprobe: 1\n",
242 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n",
243 | "\n",
244 | "Indexing database (1281167, 16)\n",
245 | "Generating index file: 1K_ivfpq_m16_d16_nbits8_nlist1024.index\n",
246 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d16_nbits8_nlist1024.index\n",
247 | "Searching queries: (50000, 16)\n",
248 | "nprobe: 1\n",
249 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d16.csv\n",
250 | "\n",
251 | "Indexing database (1281167, 32)\n",
252 | "Generating index file: 1K_ivfpq_m8_d32_nbits8_nlist1024.index\n",
253 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d32_nbits8_nlist1024.index\n",
254 | "Searching queries: (50000, 32)\n",
255 | "nprobe: 1\n",
256 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n",
257 | "\n",
258 | "Indexing database (1281167, 32)\n",
259 | "Generating index file: 1K_ivfpq_m16_d32_nbits8_nlist1024.index\n",
260 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d32_nbits8_nlist1024.index\n",
261 | "Searching queries: (50000, 32)\n",
262 | "nprobe: 1\n",
263 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n",
264 | "\n",
265 | "Indexing database (1281167, 32)\n",
266 | "Generating index file: 1K_ivfpq_m32_d32_nbits8_nlist1024.index\n",
267 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d32_nbits8_nlist1024.index\n",
268 | "Searching queries: (50000, 32)\n",
269 | "nprobe: 1\n",
270 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d32.csv\n",
271 | "\n",
272 | "Indexing database (1281167, 64)\n",
273 | "Generating index file: 1K_ivfpq_m8_d64_nbits8_nlist1024.index\n",
274 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d64_nbits8_nlist1024.index\n",
275 | "Searching queries: (50000, 64)\n",
276 | "nprobe: 1\n",
277 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n",
278 | "\n",
279 | "Indexing database (1281167, 64)\n",
280 | "Generating index file: 1K_ivfpq_m16_d64_nbits8_nlist1024.index\n",
281 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d64_nbits8_nlist1024.index\n",
282 | "Searching queries: (50000, 64)\n",
283 | "nprobe: 1\n",
284 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n",
285 | "\n",
286 | "Indexing database (1281167, 64)\n",
287 | "Generating index file: 1K_ivfpq_m32_d64_nbits8_nlist1024.index\n",
288 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d64_nbits8_nlist1024.index\n",
289 | "Searching queries: (50000, 64)\n",
290 | "nprobe: 1\n",
291 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n",
292 | "\n",
293 | "Indexing database (1281167, 64)\n",
294 | "Generating index file: 1K_ivfpq_m64_d64_nbits8_nlist1024.index\n",
295 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d64_nbits8_nlist1024.index\n",
296 | "Searching queries: (50000, 64)\n",
297 | "nprobe: 1\n",
298 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d64.csv\n",
299 | "\n",
300 | "Indexing database (1281167, 128)\n",
301 | "Generating index file: 1K_ivfpq_m8_d128_nbits8_nlist1024.index\n",
302 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d128_nbits8_nlist1024.index\n",
303 | "Searching queries: (50000, 128)\n",
304 | "nprobe: 1\n",
305 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n",
306 | "\n",
307 | "Indexing database (1281167, 128)\n",
308 | "Generating index file: 1K_ivfpq_m16_d128_nbits8_nlist1024.index\n",
309 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d128_nbits8_nlist1024.index\n",
310 | "Searching queries: (50000, 128)\n",
311 | "nprobe: 1\n",
312 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n",
313 | "\n",
314 | "Indexing database (1281167, 128)\n",
315 | "Generating index file: 1K_ivfpq_m32_d128_nbits8_nlist1024.index\n",
316 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d128_nbits8_nlist1024.index\n",
317 | "Searching queries: (50000, 128)\n",
318 | "nprobe: 1\n",
319 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n",
320 | "\n",
321 | "Indexing database (1281167, 128)\n",
322 | "Generating index file: 1K_ivfpq_m64_d128_nbits8_nlist1024.index\n",
323 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d128_nbits8_nlist1024.index\n",
324 | "Searching queries: (50000, 128)\n",
325 | "nprobe: 1\n",
326 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d128.csv\n",
327 | "\n",
328 | "Indexing database (1281167, 256)\n",
329 | "Generating index file: 1K_ivfpq_m8_d256_nbits8_nlist1024.index\n",
330 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d256_nbits8_nlist1024.index\n",
331 | "Searching queries: (50000, 256)\n",
332 | "nprobe: 1\n",
333 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n",
334 | "\n",
335 | "Indexing database (1281167, 256)\n",
336 | "Generating index file: 1K_ivfpq_m16_d256_nbits8_nlist1024.index\n",
337 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d256_nbits8_nlist1024.index\n",
338 | "Searching queries: (50000, 256)\n",
339 | "nprobe: 1\n",
340 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n",
341 | "\n",
342 | "Indexing database (1281167, 256)\n",
343 | "Generating index file: 1K_ivfpq_m32_d256_nbits8_nlist1024.index\n",
344 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d256_nbits8_nlist1024.index\n",
345 | "Searching queries: (50000, 256)\n",
346 | "nprobe: 1\n",
347 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n",
348 | "\n",
349 | "Indexing database (1281167, 256)\n",
350 | "Generating index file: 1K_ivfpq_m64_d256_nbits8_nlist1024.index\n",
351 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d256_nbits8_nlist1024.index\n",
352 | "Searching queries: (50000, 256)\n",
353 | "nprobe: 1\n",
354 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d256.csv\n",
355 | "\n",
356 | "Indexing database (1281167, 512)\n",
357 | "Generating index file: 1K_ivfpq_m8_d512_nbits8_nlist1024.index\n",
358 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d512_nbits8_nlist1024.index\n",
359 | "Searching queries: (50000, 512)\n",
360 | "nprobe: 1\n",
361 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n",
362 | "\n",
363 | "Indexing database (1281167, 512)\n",
364 | "Generating index file: 1K_ivfpq_m16_d512_nbits8_nlist1024.index\n",
365 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d512_nbits8_nlist1024.index\n",
366 | "Searching queries: (50000, 512)\n",
367 | "nprobe: 1\n",
368 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n",
369 | "\n",
370 | "Indexing database (1281167, 512)\n",
371 | "Generating index file: 1K_ivfpq_m32_d512_nbits8_nlist1024.index\n",
372 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d512_nbits8_nlist1024.index\n",
373 | "Searching queries: (50000, 512)\n",
374 | "nprobe: 1\n",
375 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n",
376 | "\n",
377 | "Indexing database (1281167, 512)\n",
378 | "Generating index file: 1K_ivfpq_m64_d512_nbits8_nlist1024.index\n",
379 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d512_nbits8_nlist1024.index\n",
380 | "Searching queries: (50000, 512)\n",
381 | "nprobe: 1\n",
382 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d512.csv\n",
383 | "\n",
384 | "Indexing database (1281167, 1024)\n",
385 | "Generating index file: 1K_ivfpq_m8_d1024_nbits8_nlist1024.index\n",
386 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d1024_nbits8_nlist1024.index\n",
387 | "Searching queries: (50000, 1024)\n",
388 | "nprobe: 1\n",
389 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n",
390 | "\n",
391 | "Indexing database (1281167, 1024)\n",
392 | "Generating index file: 1K_ivfpq_m16_d1024_nbits8_nlist1024.index\n",
393 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d1024_nbits8_nlist1024.index\n",
394 | "Searching queries: (50000, 1024)\n",
395 | "nprobe: 1\n",
396 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n",
397 | "\n",
398 | "Indexing database (1281167, 1024)\n",
399 | "Generating index file: 1K_ivfpq_m32_d1024_nbits8_nlist1024.index\n",
400 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d1024_nbits8_nlist1024.index\n",
401 | "Searching queries: (50000, 1024)\n",
402 | "nprobe: 1\n",
403 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n",
404 | "\n",
405 | "Indexing database (1281167, 1024)\n",
406 | "Generating index file: 1K_ivfpq_m64_d1024_nbits8_nlist1024.index\n",
407 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d1024_nbits8_nlist1024.index\n",
408 | "Searching queries: (50000, 1024)\n",
409 | "nprobe: 1\n",
410 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d1024.csv\n",
411 | "\n",
412 | "Indexing database (1281167, 2048)\n",
413 | "Generating index file: 1K_ivfpq_m8_d2048_nbits8_nlist1024.index\n",
414 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m8_d2048_nbits8_nlist1024.index\n"
415 | ]
416 | },
417 | {
418 | "name": "stdout",
419 | "output_type": "stream",
420 | "text": [
421 | "Searching queries: (50000, 2048)\n",
422 | "nprobe: 1\n",
423 | "ivfpq_m8_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n",
424 | "\n",
425 | "Indexing database (1281167, 2048)\n",
426 | "Generating index file: 1K_ivfpq_m16_d2048_nbits8_nlist1024.index\n",
427 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m16_d2048_nbits8_nlist1024.index\n",
428 | "Searching queries: (50000, 2048)\n",
429 | "nprobe: 1\n",
430 | "ivfpq_m16_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n",
431 | "\n",
432 | "Indexing database (1281167, 2048)\n",
433 | "Generating index file: 1K_ivfpq_m32_d2048_nbits8_nlist1024.index\n",
434 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m32_d2048_nbits8_nlist1024.index\n",
435 | "Searching queries: (50000, 2048)\n",
436 | "nprobe: 1\n",
437 | "ivfpq_m32_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n",
438 | "\n",
439 | "Indexing database (1281167, 2048)\n",
440 | "Generating index file: 1K_ivfpq_m64_d2048_nbits8_nlist1024.index\n",
441 | "Loading index file: ../../../inference_array/resnet50/index_files/mrl/ivfpq/1K_ivfpq_m64_d2048_nbits8_nlist1024.index\n",
442 | "Searching queries: (50000, 2048)\n",
443 | "nprobe: 1\n",
444 | "ivfpq_m64_nlist1024_nprobe1_2048shortlist_1K_d2048.csv\n",
445 | "\n"
446 | ]
447 | }
448 | ],
449 | "source": [
450 | "# Iterate over embedding dimensionality of database and queries\n",
451 | "for D in Ds:\n",
452 | " # Iterate over subquantizers (PQ) or num_centroids (kmeans)\n",
453 | " for dim in iterator:\n",
454 | " if (index_type in ['pq', 'ivfpq', 'opq']):\n",
455 | " m = dim # number of sub-vectors to divide d into\n",
456 | " if D % m != 0 or m > D:\n",
457 | " continue\n",
458 | " \n",
459 | " if not path.isdir(root+'index_files/'+model+'/'):\n",
460 | " makedirs(root+'index_files/'+model+'/')\n",
461 | " \n",
462 | " index_file = get_index_file(dim) \n",
463 | " \n",
464 | " if model != 'mrl/':\n",
465 | " db_npy = dataset + '_train_' + config +str(D)+ '-X.npy'\n",
466 | " print(\"Loading database: \", db_npy)\n",
467 | " xb = np.load(root+db_npy)\n",
468 | " \n",
469 | " database = np.ascontiguousarray(xb[:,:D], dtype=np.float32)\n",
470 | " faiss.normalize_L2(database)\n",
471 | " print(\"Indexing database \", database.shape)\n",
472 | " print(\"Generating index file: \" + index_file.split(\"/\")[-1])\n",
473 | "\n",
474 | " # Load or build index\n",
475 | " if path.exists(index_file):\n",
476 | " print(\"Loading index file: \" + index_file)\n",
477 | " cpu_index = faiss.read_index(index_file)\n",
478 | "\n",
479 | " else:\n",
480 | " if (index_type == 'pq'):\n",
481 | " cpu_index = faiss.IndexPQ(D, dim, nbits)\n",
482 | " cpu_index.train(database)\n",
483 | " \n",
484 | " elif index_type in ['ivfpq', 'opq']:\n",
485 | " quantizer = faiss.IndexFlatL2(D) # L2 quantizer to assign vectors to Voronoi cells\n",
486 | " cpu_index = faiss.IndexIVFPQ(quantizer, D, nlist, m, nbits)\n",
487 | "\n",
488 | " # Learn a transformation of the embedding space with Optimized PQ\n",
489 | " if index_type == 'opq':\n",
490 | " opq_matrix = faiss.OPQMatrix(D, dim)\n",
491 | " cpu_index = faiss.IndexPreTransform (opq_matrix, cpu_index)\n",
492 | "\n",
493 | " print(\"Training %s Index: d=%d, m=%d, nbits=%d, nlist=%d\" %(index_type,D,m,nbits,nlist))\n",
494 | " ivfpq_start = time.time()\n",
495 | " cpu_index.train(database)\n",
496 | " ivfpq_end = time.time()\n",
497 | " print(\"%s Index train time= %0.3f sec\" % (index_type, ivfpq_end - ivfpq_start))\n",
498 | " \n",
499 | " elif index_type == 'ivfsq':\n",
500 | " quantizer = faiss.IndexFlatL2(D)\n",
501 | " cpu_index = faiss.IndexIVFScalarQuantizer(quantizer, D, nlist, qtype)\n",
502 | "                print(\"Training IVFSQ Index: d=%d, qtype=%s, nlist=%d\" % (D, qtype, nlist))\n",
503 | " ivfsq_start = time.time()\n",
504 | " cpu_index.train(database)\n",
505 | " ivfsq_end = time.time()\n",
506 | " print(\"IVFSQ Index train time= %0.3f sec\" % (ivfsq_end - ivfsq_start))\n",
507 | " \n",
508 | "\n",
509 | " elif index_type in ['kmeans', 'kmeans-pca']:\n",
510 | " ncentroids = dim\n",
511 | " kmeans = faiss.Kmeans(D, ncentroids, verbose=True)\n",
512 | " if use_svd: \n",
513 | " database = np.load(root+db_npy_svd)\n",
514 | " \n",
515 | " print(\"Learning %s index with d=%d and k=%d\" %(index_type, D, ncentroids))\n",
516 | " kmeans_start = time.time()\n",
517 | " kmeans.train(database)\n",
518 | " kmeans_end = time.time()\n",
519 | " print(\"%s index train time= %0.3f sec\" % (index_type, kmeans_end - kmeans_start)) \n",
520 | " cpu_index = kmeans.index\n",
521 | " centroids_path = root+'kmeans/'+model+index_type+'_ncentroids'+str(ncentroids)+\"_\"+str(D)+'d'\"_\"+dataset+'.npy'\n",
522 | " print(\"Saving centroids: \", kmeans.centroids.shape)\n",
523 | " with open(centroids_path, 'wb') as f:\n",
524 | " np.save(f, kmeans.centroids)\n",
525 | "\n",
526 | " \n",
527 | " # add database embeddings to the index and save to disk \n",
528 | " cpu_index.add(database) \n",
529 | " faiss.write_index(cpu_index, index_file)\n",
530 | "\n",
531 | " if use_gpu:\n",
532 | " index = faiss.index_cpu_to_all_gpus(cpu_index)\n",
533 | " print(\"Moved to GPU\")\n",
534 | " else:\n",
535 | " index = cpu_index\n",
536 | "\n",
537 | " if model != 'mrl/':\n",
538 | " if use_svd:\n",
539 | " query_npy = query_npy_svd\n",
540 | " else:\n",
541 | " query_npy = dataset + '_val_' + config +str(D)+ '-X.npy'\n",
542 | " print(\"Loading queries: \", query_npy)\n",
543 | " xq = np.load(root+query_npy)\n",
544 | " queryset = np.ascontiguousarray(xq[:,:D], dtype=np.float32)\n",
545 | " faiss.normalize_L2(queryset)\n",
546 | " \n",
547 | " # kmeans indices are used downstream by adanns.ipynb \n",
548 | " if index_type not in ['kmeans', 'kmeans-pca']:\n",
549 | " print(\"Searching queries: \", queryset.shape)\n",
550 | " # Loop over search probes/ beam width\n",
551 | " for nprobe in nprobes:\n",
552 | " if index_type != 'pq':\n",
553 | " index.nprobe=nprobe\n",
554 | " print(\"nprobe: \", index.nprobe)\n",
555 | " \n",
556 | " Dist, Ind = index.search(queryset, k)\n",
557 | "\n",
558 | " if index_type == 'pq':\n",
559 | " nn_dir = root+\"neighbors/\"+model+index_type+\"/\"+index_type+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(D)+\".csv\"\n",
560 | " else:\n",
561 | " nn_dir = root+\"neighbors/\"+model+index_type+\"/\"+index_type+\"_m\"+str(dim)+'_nlist'+str(nlist)+'_nprobe'+str(nprobe)+\"_\"+str(k)+\"shortlist_\"+dataset+\"_d\"+str(D)+\".csv\"\n",
562 | " if not path.isdir(root+\"neighbors/\"+model+index_type+\"/\"):\n",
563 | " makedirs(root+\"neighbors/\"+model+index_type+\"/\")\n",
564 | " \n",
565 | " print(nn_dir.split(\"/\")[-1]+\"\\n\")\n",
566 | " pd.DataFrame(Ind).to_csv(nn_dir, header=None, index=None)\n",
567 | " \n",
568 | " # Optional test for duplicate distances in high M regimes\n",
569 | " #get_duplicate_dist(Ind, Dist)\n",
570 | " \n",
571 | " del Dist, Ind\n",
572 | " del index"
573 | ]
574 | }
575 | ],
576 | "metadata": {
577 | "jupytext": {
578 | "cell_metadata_filter": "-all",
579 | "notebook_metadata_filter": "-all"
580 | },
581 | "kernelspec": {
582 | "display_name": "Python 3 (ipykernel)",
583 | "language": "python",
584 | "name": "python3"
585 | },
586 | "language_info": {
587 | "codemirror_mode": {
588 | "name": "ipython",
589 | "version": 3
590 | },
591 | "file_extension": ".py",
592 | "mimetype": "text/x-python",
593 | "name": "python",
594 | "nbconvert_exporter": "python",
595 | "pygments_lexer": "ipython3",
596 | "version": "3.11.3"
597 | },
598 | "vscode": {
599 | "interpreter": {
600 | "hash": "51ae9d60c33a8ae5621576c9f7a44d174a8f6e30fb616100a36dfd42ed0f76dc"
601 | }
602 | }
603 | },
604 | "nbformat": 4,
605 | "nbformat_minor": 5
606 | }
607 |
--------------------------------------------------------------------------------
/adanns/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import faiss
3 | import torch
4 | import matplotlib.pyplot as plt
5 |
6 | valid_models = ['mrl', 'rr']
7 | valid_datasets = ['1K', '4K', 'A' ,'R' ,'O', 'V2'] # ImageNet versions
8 |
9 | def load_embeddings(model, dataset, embedding_dim, arch='resnet50', using_gcloud=True):
10 | if model == 'mrl':
11 | config = 'mrl1_e0_ff2048'
12 | elif model == 'rr': # using rigid representation
13 | config = f'mrl0_e0_ff{embedding_dim}'
14 | else:
15 | raise ValueError(f'Model must be in {valid_models}')
16 |
17 |
18 | if dataset not in valid_datasets:
19 | raise ValueError(f'Dataset must be in {valid_datasets}')
20 | if dataset == 'V2': # ImageNetv2 is only a test set; set database to ImageNet-1K
21 | dataset = '1K'
22 |
23 |
24 | if using_gcloud:
25 | root = f'../../../inference_array/{arch}/'
26 | else: # using local machine
27 | root = f'../../inference_array/{arch}/'
28 |
29 |
30 | db_npy = dataset + '_train_' + config + '-X.npy'
31 | query_npy = dataset + '_val_' + config + '-X.npy'
32 | db_label_npy = dataset + '_train_' + config + '-y.npy'
33 | query_label_npy = dataset + '_val_' + config + '-y.npy'
34 |
35 | database = np.load(root + db_npy)
36 | queryset = np.load(root + query_npy)
37 | db_labels = np.load(root + db_label_npy)
38 | query_labels = np.load(root + query_label_npy)
39 |
40 | faiss.normalize_L2(database)
41 | faiss.normalize_L2(queryset)
42 |
43 | xb = np.ascontiguousarray(database[:, :embedding_dim], dtype=np.float32)
44 | xq = np.ascontiguousarray(queryset[:, :embedding_dim], dtype=np.float32)
45 |
46 | faiss.normalize_L2(xb)
47 | faiss.normalize_L2(xq)
48 |
49 | return database, queryset, db_labels, query_labels, xb, xq
50 |
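# Minimal usage sketch (illustrative, not part of the original pipeline): assumes the
# .npy embedding files referenced above exist under the chosen root. Loads 64-d MRL
# embeddings and runs an exact (flat) L2 search with FAISS over the truncated,
# normalized vectors, returning a k-length shortlist of database ids per query.
def example_exact_search(k=100):
    _, _, db_labels, query_labels, xb, xq = load_embeddings('mrl', '1K', 64)
    index = faiss.IndexFlatL2(xb.shape[1])  # exact L2 index over the 64-d slices
    index.add(xb)
    dist, ind = index.search(xq, k)         # ind[i] holds the k nearest database ids for query i
    return dist, ind
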
51 | # Find Duplicate distances for low M (M=1 is k-means) on searched faiss index
52 | def get_duplicate_dist(Ind, Dist, model='', D=None, dim=None):
53 |     k = 100
54 |     duplicates = np.array([])
55 |     for i in range(Ind.shape[0]):
56 |         indices = Ind[i, :k]
57 |         distances = Dist[i, :k]
58 |         unique_counts = np.unique(distances, return_counts=True)[1]
59 |         unique_counts = unique_counts[unique_counts != 1]  # keep only distance values that occur more than once
60 |         duplicates = np.append(duplicates, unique_counts)
61 |
62 |     plt.hist(duplicates, bins='auto')
63 |     plt.title(model.split("/")[0] + ", D=" + str(D) + ", M=" + str(dim))
64 |     plt.xlabel("Number of 100-NN with same neighbor distance values")
65 |     plt.show()
66 |     print(duplicates[duplicates > 2].sum())
67 |
68 | # Normalize embeddings
69 | def normalize_embeddings(embeddings, dtype):
70 |     if dtype == 'float32':
71 |         print(np.linalg.norm(embeddings))
72 |         faiss.normalize_L2(embeddings)  # normalized in place for float32
73 |         print(np.linalg.norm(embeddings))
74 |     return embeddings
--------------------------------------------------------------------------------
/generate_embeddings/pytorch_inference.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code to evaluate MRL models on different validation benchmarks.
3 | '''
4 | import sys
5 | sys.path.append("../") # adding root folder to the path
6 |
7 | import torch
8 | import torchvision
9 | from torchvision import transforms
10 | from torchvision.models import *
11 | from torchvision import datasets
12 | from tqdm import tqdm
13 |
14 | from MRL import *
15 | from imagenetv2_pytorch import ImageNetV2Dataset
16 | from argparse import ArgumentParser
17 | from utils import *
18 |
19 | # nesting list is by default from 8 to 2048 in powers of 2, can be modified from here.
20 | BATCH_SIZE = 256
21 | IMG_SIZE = 256
22 | CENTER_CROP_SIZE = 224
23 | #NESTING_LIST=[2**i for i in range(3, 12)]
24 | ROOT="../../IMAGENET/" # path to IN1K
25 | DATASET_ROOT="../../datasets/" # path to other datasets (e.g. ImageNet-4K)
26 |
27 | parser=ArgumentParser()
28 |
29 | # model args
30 | parser.add_argument('--efficient', action='store_true', help='Efficient Flag')
31 | parser.add_argument('--mrl', action='store_true', help='To use MRL')
32 | parser.add_argument('--rep_size', type=int, default=2048, help='Rep. size for fixed feature model')
33 | parser.add_argument('--path', type=str, required=True, help='Path to .pt model checkpoint')
34 | parser.add_argument('--old_ckpt', action='store_true', help='To use our trained checkpoints')
35 | parser.add_argument('--workers', type=int, default=12, help='num workers for dataloader')
36 | parser.add_argument('--model_arch', type=str, default='resnet50', help='Loaded model arch')
37 | # dataset/eval args
38 | parser.add_argument('--tta', action='store_true', help='Test Time Augmentation Flag')
39 | parser.add_argument('--dataset', type=str, default='V1', help='Benchmarks')
40 | parser.add_argument('--save_logits', action='store_true', help='To save logits for model analysis')
41 | parser.add_argument('--save_softmax', action='store_true', help='To save softmax_probs for model analysis')
42 | parser.add_argument('--save_gt', action='store_true', help='To save ground truth for model analysis')
43 | parser.add_argument('--save_predictions', action='store_true', help='To save predicted labels for model analysis')
44 | # retrieval args
45 | parser.add_argument('--retrieval', action='store_true', help='flag for image retrieval array dumps')
46 | parser.add_argument('--random_sample_dim', type=int, default=4202000, help='number of random samples to slice from retrieval database')
47 | parser.add_argument('--retrieval_array_path', default='', help='path to save database and query arrays for retrieval', type=str)
48 |
49 |
50 | args = parser.parse_args()
51 |
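# Example invocations (illustrative; checkpoint and output paths are placeholders):
#   ImageNet-1K val classification eval of an MRL ResNet-50:
#     python pytorch_inference.py --mrl --path path/to/mrl_checkpoint.pt --dataset V1
#   Dump retrieval embeddings for a fixed-feature (RR) model at rep_size=256:
#     python pytorch_inference.py --retrieval --rep_size 256 --path path/to/rr_checkpoint.pt \
#       --dataset 1K --retrieval_array_path output_dir/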
52 | if args.model_arch == 'convnext_tiny':
53 | convnext_tiny = create_model(
54 | 'convnext_tiny',
55 | pretrained=False,
56 | num_classes=1000,
57 | drop_path_rate=0.1,
58 | layer_scale_init_value=1e-6,
59 | head_init_scale=1.0
60 | )
61 |
62 | model_arch_dict = {
63 | 'vgg19': {'model': vgg19(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]},
64 | 'resnet18': {'model': resnet18(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512]},
65 | 'resnet34': {'model': resnet34(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512]},
66 | 'resnet50': {'model': resnet50(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048]},
67 | 'resnet101': {'model': resnet101(False), 'nest_list': [8, 16, 32, 64, 128, 256, 512, 1024, 2048]},
68 | 'mobilenetv2': {'model': mobilenet_v2(False), 'nest_list': [10, 20, 40, 80, 160, 320, 640, 1280]},
69 | 'convnext_tiny': {'model': convnext_tiny, 'nest_list': [12, 24, 48, 96, 192, 384, 768]},
70 | }
71 |
72 | model = model_arch_dict[args.model_arch]['model']
73 | print(model)
74 |
75 | if not args.old_ckpt:
76 | if args.mrl:
77 | if args.model_arch == 'mobilenetv2':
78 | model.classifier[1] = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], num_classes=1000, efficient=args.efficient)
79 | elif args.model_arch == 'convnext_tiny':
80 | model.head = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], num_classes=1000, efficient=args.efficient)
81 | else:
82 | model.fc = MRL_Linear_Layer(model_arch_dict[args.model_arch]['nest_list'], efficient=args.efficient)
83 | else:
84 | model.fc=FixedFeatureLayer(args.rep_size, 1000) # RR model
85 | else:
86 | if args.mrl:
87 | model = load_from_old_ckpt(model, args.efficient, model_arch_dict[args.model_arch]['nest_list'])
88 | else:
89 | model.fc=FixedFeatureLayer(args.rep_size, 1000)
90 |
91 | print(model.fc)
92 | apply_blurpool(model)
93 | model.load_state_dict(get_ckpt(args.path, args.model_arch)) # Since our models have a torch DDP wrapper, we modify keys to exclude first 7 chars (".module").
94 | model = model.cuda()
95 | model.eval()
96 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
97 | test_transform = transforms.Compose([
98 | transforms.Resize(IMG_SIZE),
99 | transforms.CenterCrop(CENTER_CROP_SIZE),
100 | transforms.ToTensor(),
101 | normalize])
102 |
103 | # Model Eval
104 | if not args.retrieval:
105 | if args.dataset == 'V2':
106 | print("Loading Robustness Dataset")
107 | dataset = ImageNetV2Dataset("matched-frequency", transform=test_transform)
108 | elif args.dataset == '4K':
109 | train_path = DATASET_ROOT+"imagenet-4k/train/"
110 | test_path = DATASET_ROOT+"imagenet-4k/test/"
111 | train_dataset = datasets.ImageFolder(train_path, transform=test_transform)
112 | test_dataset = datasets.ImageFolder(test_path, transform=test_transform)
113 | else:
114 | print("Loading Imagenet 1K val set")
115 | dataset = torchvision.datasets.ImageFolder(ROOT+'val/', transform=test_transform)
116 |
117 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False)
118 |
119 | if args.mrl:
120 | _, top1_acc, top5_acc, total_time, num_images, m_score_dict, softmax_probs, gt, logits = evaluate_model(
121 | model, dataloader, show_progress_bar=True, nesting_list=model_arch_dict[args.model_arch]['nest_list'], tta=args.tta, imagenetA=args.dataset == 'A', imagenetR=args.dataset == 'R')
122 | else:
123 | _, top1_acc, top5_acc, total_time, num_images, m_score_dict, softmax_probs, gt, logits = evaluate_model(
124 | model, dataloader, show_progress_bar=True, nesting_list=None, tta=args.tta, imagenetA=args.dataset == 'A', imagenetR=args.dataset == 'R')
125 |
126 | tqdm.write('Evaluated {} images'.format(num_images))
127 | confidence, predictions = torch.max(softmax_probs, dim=-1)
128 | if args.mrl:
129 | for i, nesting in enumerate(model_arch_dict[args.model_arch]['nest_list']):
130 | print("Rep. Size", "\t", nesting, "\n")
131 | tqdm.write(' Top-1 accuracy for {} : {:.2f}'.format(nesting, 100.0 * top1_acc[nesting]))
132 | tqdm.write(' Top-5 accuracy for {} : {:.2f}'.format(nesting, 100.0 * top5_acc[nesting]))
133 | tqdm.write(' Total time: {:.1f} (average time per image: {:.2f} ms)'.format(total_time, 1000.0 * total_time / num_images))
134 | else:
135 | print("Rep. Size", "\t", args.rep_size, "\n")
136 | tqdm.write(' Evaluated {} images'.format(num_images))
137 | tqdm.write(' Top-1 accuracy: {:.2f}%'.format(100.0 * top1_acc))
138 | tqdm.write(' Top-5 accuracy: {:.2f}%'.format(100.0 * top5_acc))
139 | tqdm.write(' Total time: {:.1f} (average time per image: {:.2f} ms)'.format(total_time, 1000.0 * total_time / num_images))
140 |
141 |
142 | # saving torch tensor for model analysis...
143 | if args.save_logits or args.save_softmax or args.save_predictions:
144 | save_string = f"mrl={args.mrl}_efficient={args.efficient}_dataset={args.dataset}_tta={args.tta}"
145 | if args.save_logits:
146 | torch.save(logits, save_string+"_logits.pth")
147 | if args.save_predictions:
148 | torch.save(predictions, save_string+"_predictions.pth")
149 | if args.save_softmax:
150 | torch.save(softmax_probs, save_string+"_softmax.pth")
151 |
152 | if args.save_gt:
153 | torch.save(gt, f"gt_dataset={args.dataset}.pth")
154 |
155 |
156 | # Image Retrieval Inference
157 | else:
158 | if args.dataset == '1K':
159 | train_dataset = datasets.ImageFolder(ROOT+"train/", transform=test_transform)
160 | test_dataset = datasets.ImageFolder(ROOT+"val/", transform=test_transform)
161 | elif args.dataset == 'V2':
162 | train_dataset = None # V2 has only a test set
163 | test_dataset = ImageNetV2Dataset("matched-frequency", transform=test_transform)
164 | elif args.dataset == '4K':
165 | train_path = DATASET_ROOT+"imagenet-4k/train/"
166 | test_path = DATASET_ROOT+"imagenet-4k/test/"
167 | train_dataset = datasets.ImageFolder(train_path, transform=test_transform)
168 | test_dataset = datasets.ImageFolder(test_path, transform=test_transform)
169 | else:
170 |         raise Exception("Error: unsupported dataset!")
171 |
172 | database_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False)
173 | queryset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=args.workers, shuffle=False)
174 |
175 | config = args.model_arch+ "/" + args.dataset + "_val_mrl" + str(int(args.mrl)) + "_e" + str(int(args.efficient)) + "_rr" + str(int(args.rep_size))
176 | print("Retrieval Config: " + config)
177 | generate_retrieval_data(model, queryset_loader, config, args.random_sample_dim, args.rep_size, args.retrieval_array_path, args.model_arch)
178 |
179 | if train_dataset is not None:
180 | config = args.model_arch+ "/" + args.dataset + "_train_mrl" + str(int(args.mrl)) + "_e" + str(int(args.efficient)) + "_rr" + str(int(args.rep_size))
181 | print("Retrieval Config: " + config)
182 | generate_retrieval_data(model, database_loader, config, args.random_sample_dim, args.rep_size, args.retrieval_array_path, args.model_arch)
183 |
--------------------------------------------------------------------------------
/generate_embeddings/run-inference.sh:
--------------------------------------------------------------------------------
1 | config=RR # RR or MRL
2 |
3 | if [ "$config" = "RR" ]
4 | then
5 | for dim in 8 16 32 64 128 256 512 1024 2048
6 | do
7 | echo "Generating embeddings on RR-$dim"
8 | python pytorch_inference.py --retrieval --path=path/to/rr/model/weights.pt \
9 | --model_arch='resnet50' --retrieval_array_path=rr_output_dir/ --dataset=1K --rep_size=$dim
10 | done
11 | else
12 | echo "Generating embeddings on MRL"
13 | python pytorch_inference.py --retrieval --path=path/to/mrl/model/weights.pt \
14 | --model_arch='resnet50' --retrieval_array_path=mrl_output_dir/ --dataset=1K --mrl
15 | fi
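# Note: a single pass suffices for MRL because every nested dimensionality (8-2048)
# is sliced from the same 2048-d embedding downstream; only the RR baseline needs one
# inference pass per rep_size.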
16 |
--------------------------------------------------------------------------------
/images/accuracy-compute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/accuracy-compute.png
--------------------------------------------------------------------------------
/images/adanns-opq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/adanns-opq.png
--------------------------------------------------------------------------------
/images/adanns-teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/adanns-teaser.png
--------------------------------------------------------------------------------
/images/diskann-table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/diskann-table.png
--------------------------------------------------------------------------------
/images/diskann-top1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/diskann-top1.png
--------------------------------------------------------------------------------
/images/encoders.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/encoders.png
--------------------------------------------------------------------------------
/images/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/flowchart.png
--------------------------------------------------------------------------------
/images/opq-1k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/opq-1k.png
--------------------------------------------------------------------------------
/images/opq-nq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/AdANNS/6bad9e4a4aa1345d484c49dd547797849941295a/images/opq-nq.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | terminaltables
2 | pytorch_pfn_extras
3 | fastargs
4 | matplotlib
5 | scikit-learn
6 | imgcat
7 | pandas
8 | assertpy
9 | tqdm
10 | psutil
11 | webdataset
12 | torchmetrics
13 | git+https://github.com/modestyachts/ImageNetV2_pytorch
14 | grad-cam
15 | scikit-image
16 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.cuda.amp import autocast
5 | from typing import Type, Any, Callable, Union, List, Optional
6 | from torchvision.models import *
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 |
11 | def get_ckpt(path, arch):
12 | ckpt=path
13 | ckpt = torch.load(ckpt, map_location='cpu')
14 | if arch == 'convnext_tiny':
15 | ckpt = ckpt['model']
16 | plain_ckpt={}
17 | for k in ckpt.keys():
18 | plain_ckpt[k[7:]] = ckpt[k] # remove the 'module' portion of key if model is Pytorch DDP
19 | return plain_ckpt
20 |
21 |
22 | class BlurPoolConv2d(torch.nn.Module):
23 | def __init__(self, conv):
24 | super().__init__()
25 | default_filter = torch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
26 | filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
27 | self.conv = conv
28 | self.register_buffer('blur_filter', filt)
29 |
30 | def forward(self, x):
31 | blurred = F.conv2d(x, self.blur_filter, stride=1, padding=(1, 1),
32 | groups=self.conv.in_channels, bias=None)
33 | return self.conv.forward(blurred)
34 |
35 |
36 | def apply_blurpool(mod: torch.nn.Module):
37 | for (name, child) in mod.named_children():
38 | if isinstance(child, torch.nn.Conv2d) and (np.max(child.stride) > 1 and child.in_channels >= 16):
39 | setattr(mod, name, BlurPoolConv2d(child))
40 | else: apply_blurpool(child)
41 |
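# Illustrative usage (mirrors generate_embeddings/pytorch_inference.py): wrap the strided
# convolutions of a torchvision ResNet-50 with BlurPool before loading a checkpoint that
# was trained with blur-pooling, so the module structure and state_dict keys line up:
#   model = resnet50(False)
#   apply_blurpool(model)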
42 |
43 | '''
44 | Retrieval utility methods.
45 | '''
46 | activation = {}
47 | fwd_pass_x_list = []
48 | fwd_pass_y_list = []
49 | path_list = []
50 |
51 | def get_activation(name):
52 | """
53 | Get the activation from an intermediate point in the network.
54 | :param name: layer whose activation is to be returned
55 | :return: activation of layer
56 | """
57 | def hook(model, input, output):
58 | activation[name] = output.detach()
59 | return hook
60 |
61 |
62 | def append_feature_vector_to_list(activation, label, rep_size, path):
63 | """
64 | Append the feature vector to a list to later write to disk.
65 | :param activation: image feature vector from network
66 | :param label: ground truth label
67 | :param rep_size: representation size to be stored
68 | """
69 | for i in range (activation.shape[0]):
70 | x = activation[i].cpu().detach().numpy()
71 | y = label[i].cpu().detach().numpy()
72 | fwd_pass_y_list.append(y)
73 | fwd_pass_x_list.append(x[:rep_size])
74 |
75 | def dump_feature_vector_array_lists(config_name, rep_size, random_sample_dim, output_path):
76 | """
77 | Save the database and query vector array lists to disk.
78 | :param config_name: config to specify during file write
79 | :param rep_size: representation size for fixed feature model
80 | :param random_sample_dim: to write a subset of database if required, e.g. to train an SVM on 100K samples
81 | :param output_path: path to dump database and query arrays after inference
82 | """
83 |
84 | # save X (n x 2048), y (n x 1) to disk, where n = num_samples
85 | X_fwd_pass = np.asarray(fwd_pass_x_list, dtype=np.float32)
86 | y_fwd_pass = np.asarray(fwd_pass_y_list, dtype=np.uint16).reshape(-1,1)
87 |
88 | if random_sample_dim < X_fwd_pass.shape[0]:
89 | random_indices = np.random.choice(X_fwd_pass.shape[0], size=random_sample_dim, replace=False)
90 | random_X = X_fwd_pass[random_indices, :]
91 | random_y = y_fwd_pass[random_indices, :]
92 | print("Writing random samples to disk with dim [%d x 2048] " % random_sample_dim)
93 | else:
94 | random_X = X_fwd_pass
95 | random_y = y_fwd_pass
96 | print("Writing %s to disk with dim [%d x %d]" % (str(config_name)+"_X", X_fwd_pass.shape[0], rep_size))
97 |
98 | print("Unique entries: ", len(np.unique(random_y)))
99 | np.save(output_path+str(config_name)+'-X.npy', random_X)
100 | np.save(output_path+str(config_name)+'-y.npy', random_y)
101 |
102 |
103 | def generate_retrieval_data(model, data_loader, config, random_sample_dim, rep_size, output_path, model_arch):
104 | """
105 | Iterate over data in dataloader, get feature vector from model inference, and save to array to dump to disk.
106 | :param model: ResNet50 model loaded from disk
107 | :param data_loader: loader for database or query set
108 | :param config: name of configuration for writing arrays to disk
109 | :param random_sample_dim: to write a subset of database if required, e.g. to train an SVM on 100K samples
110 | :param rep_size: representation size for fixed feature model
111 | :param output_path: path to dump database and query arrays after inference
112 | """
113 | model.eval()
114 | if model_arch == 'vgg19':
115 | model.classifier[5].register_forward_hook(get_activation('avgpool'))
116 | elif model_arch == 'mobilenetv2':
117 | model.classifier[0].register_forward_hook(get_activation('avgpool'))
118 | else:
119 | model.avgpool.register_forward_hook(get_activation('avgpool'))
120 | print("Dataloader len: ", len(data_loader))
121 |
122 | with torch.no_grad():
123 | with autocast():
124 | for i_batch, (images, target) in enumerate(data_loader):
125 | output = model(images.cuda())
126 | path = None
127 | append_feature_vector_to_list(activation['avgpool'].squeeze(), target.cuda(), rep_size, path)
128 | if (i_batch) % int(len(data_loader)/5) == 0:
129 | print("Finished processing: %f %%" % (i_batch / len(data_loader) * 100))
130 | dump_feature_vector_array_lists(config, rep_size, random_sample_dim, output_path)
131 |
132 | # re-initialize empty lists
133 | global fwd_pass_x_list
134 | global fwd_pass_y_list
135 | global path_list
136 | fwd_pass_x_list = []
137 | fwd_pass_y_list = []
138 | path_list = []
139 |
140 | '''
141 | Load pretrained models saved with old notation.
142 | '''
143 | class SingleHeadNestedLinear(nn.Linear):
144 | """
145 | Class for MRL-E model.
146 | """
147 |
148 | def __init__(self, nesting_list: List, num_classes=1000, **kwargs):
149 | super(SingleHeadNestedLinear, self).__init__(nesting_list[-1], num_classes, **kwargs)
150 | self.nesting_list=nesting_list
151 | self.num_classes=num_classes # Number of classes for classification
152 |
153 | def forward(self, x):
154 | nesting_logits = ()
155 | for i, num_feat in enumerate(self.nesting_list):
156 | if not (self.bias is None):
157 | logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t()) + self.bias
158 | else:
159 | logit = torch.matmul(x[:, :num_feat], (self.weight[:, :num_feat]).t())
160 | nesting_logits+= (logit,)
161 | return nesting_logits
162 |
163 | class MultiHeadNestedLinear(nn.Module):
164 | """
165 | Class for MRL model.
166 | """
167 | def __init__(self, nesting_list: List, num_classes=1000, **kwargs):
168 | super(MultiHeadNestedLinear, self).__init__()
169 | self.nesting_list=nesting_list
170 | self.num_classes=num_classes # Number of classes for classification
171 | for i, num_feat in enumerate(self.nesting_list):
172 | setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))
173 |
174 | def forward(self, x):
175 | nesting_logits = ()
176 | for i, num_feat in enumerate(self.nesting_list):
177 | nesting_logits += (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)
178 | return nesting_logits
179 |
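# Shape note (illustrative): for input x of shape (B, nesting_list[-1]), both nested
# heads return a tuple with one (B, num_classes) logit tensor per entry of nesting_list,
# ordered from the smallest to the largest representation size.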
180 | def load_from_old_ckpt(model, efficient, nesting_list):
181 | if efficient:
182 | model.fc=SingleHeadNestedLinear(nesting_list)
183 | else:
184 | model.fc=MultiHeadNestedLinear(nesting_list)
185 |
186 | return model
187 |
--------------------------------------------------------------------------------