├── data ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── io.cpython-38.pyc │ │ └── __init__.cpython-38.pyc │ └── io.py ├── README.md ├── ivf.py └── rabitq.py ├── bin └── README.md ├── results └── README.md ├── technical_report.pdf ├── script ├── index.sh └── search.sh ├── src ├── index.cpp ├── search.cpp ├── space.h ├── fast_scan.h ├── utils.h ├── matrix.h └── ivf_rabitq.h ├── README.md └── LICENSE /data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/README.md: -------------------------------------------------------------------------------- 1 | This folder holds the compiled binaries built by the scripts in `./script/`. -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | This folder stores the time-accuracy trade-off results. -------------------------------------------------------------------------------- /technical_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoj0017/RaBitQ/HEAD/technical_report.pdf -------------------------------------------------------------------------------- /data/utils/__pycache__/io.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoj0017/RaBitQ/HEAD/data/utils/__pycache__/io.cpython-38.pyc -------------------------------------------------------------------------------- /data/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaoj0017/RaBitQ/HEAD/data/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /script/index.sh: 
-------------------------------------------------------------------------------- 1 | # Compile src/index.cpp and build the IVF-RaBitQ index for ${data}. 2 | C=4096 3 | data='sift' 4 | D=128 5 | B=128 6 | source='./data' 7 | 8 | g++ -o ./bin/index_${data} ./src/index.cpp -I ./src/ -O3 -march=core-avx2 -D BB=${B} -D DIM=${D} -D numC=${C} -D B_QUERY=4 -D SCAN 9 | 10 | ./bin/index_${data} -d $data -s "$source/$data/" -------------------------------------------------------------------------------- /script/search.sh: -------------------------------------------------------------------------------- 1 | # Compile src/search.cpp and run the queries; the search binary appends its log under ./results/${data}/. 2 | # Use mkdir -p: ./results already exists in the repository, and re-runs must not fail on existing folders. 3 | 4 | source='./data' 5 | data='sift' 6 | C=4096 7 | B=128 8 | D=128 9 | k=100 10 | 11 | g++ -march=core-avx2 -Ofast -o ./bin/search_${data} ./src/search.cpp -I ./src/ -D BB=${B} -D DIM=${D} -D numC=${C} -D B_QUERY=4 -D FAST_SCAN 12 | 13 | result_path=./results 14 | mkdir -p "${result_path}" 15 | 16 | res="${result_path}/${data}/" 17 | 18 | mkdir -p "${res}" 19 | 20 | ./bin/search_${data} -d ${data} -r "${res}" -k ${k} -s "$source/$data/" 21 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Prerequisites 3 | * Python == 3.8, numpy == 1.20.3, faiss == 1.7.4, tqdm 4 | 5 | # Datasets 6 | 7 | The tested datasets are available at https://www.cse.cuhk.edu.hk/systems/hash/gqr/datasets.html. The SIFT dataset, which includes the groundtruth file, can be downloaded from ftp://ftp.irisa.fr/local/texmex/corpus/sift.tar.gz. 8 | 9 | ## Reproduction 10 | 1. Download a dataset. The `.fvecs` data format is described at http://corpus-texmex.irisa.fr/. 11 | 12 | 2. Unzip the dataset. 13 | 14 | 3. Build the IVF index for the dataset. 15 | 16 | ```shell 17 | python ivf.py 18 | ``` 19 | 20 | 4. Run the index phase of RaBitQ for the dataset.
21 | 22 | ```shell 23 | python rabitq.py 24 | ``` -------------------------------------------------------------------------------- /data/ivf.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import faiss 4 | import struct 5 | import os 6 | from utils.io import * 7 | 8 | source = './' 9 | 10 | if __name__ == '__main__': 11 | 12 | dataset = 'sift' 13 | print(f"Clustering - {dataset}") 14 | # path 15 | path = os.path.join(source, dataset) 16 | data_path = os.path.join(path, f'{dataset}_base.fvecs') 17 | X = read_fvecs(data_path) 18 | D = X.shape[1] 19 | K = 4096 20 | centroids_path = os.path.join(path, f'{dataset}_centroid_{K}.fvecs') 21 | dist_to_centroid_path = os.path.join(path, f'{dataset}_dist_to_centroid_{K}.fvecs') 22 | cluster_id_path = os.path.join(path, f'{dataset}_cluster_id_{K}.ivecs') 23 | 24 | # cluster data vectors 25 | index = faiss.index_factory(D, f"IVF{K},Flat") 26 | index.verbose = True 27 | index.train(X) 28 | centroids = index.quantizer.reconstruct_n(0, index.nlist) 29 | dist_to_centroid, cluster_id = index.quantizer.search(X, 1) 30 | dist_to_centroid = dist_to_centroid ** 0.5 31 | 32 | to_fvecs(dist_to_centroid_path, dist_to_centroid) 33 | to_ivecs(cluster_id_path, cluster_id) 34 | to_fvecs(centroids_path, centroids) 35 | -------------------------------------------------------------------------------- /data/utils/io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import struct 3 | from tqdm import tqdm 4 | 5 | def read_fvecs(filename, c_contiguous=True): 6 | print(f"Reading from {filename}.") 7 | fv = np.fromfile(filename, dtype=np.float32) 8 | if fv.size == 0: 9 | return np.zeros((0, 0)) 10 | dim = fv.view(np.int32)[0] 11 | assert dim > 0 12 | fv = fv.reshape(-1, 1 + dim) 13 | if not all(fv.view(np.int32)[:, 0] == dim): 14 | raise IOError("Non-uniform vector sizes in " + filename) 15 | fv = fv[:, 1:] 16 | if 
c_contiguous: 17 | fv = fv.copy() 18 | return fv 19 | 20 | def read_ivecs(filename, c_contiguous=True): 21 | fv = np.fromfile(filename, dtype=np.int32) 22 | if fv.size == 0: 23 | return np.zeros((0, 0)) 24 | dim = fv.view(np.int32)[0] 25 | assert dim > 0 26 | fv = fv.reshape(-1, 1 + dim) 27 | if not all(fv.view(np.int32)[:, 0] == dim): 28 | raise IOError("Non-uniform vector sizes in " + filename) 29 | fv = fv[:, 1:] 30 | if c_contiguous: 31 | fv = fv.copy() 32 | return fv 33 | 34 | def to_fvecs(filename, data): 35 | print(f"Writing File - {filename}") 36 | with open(filename, 'wb') as fp: 37 | for y in tqdm(data): 38 | d = struct.pack('I', len(y)) 39 | fp.write(d) 40 | for x in y: 41 | a = struct.pack('f', x) 42 | fp.write(a) 43 | 44 | def to_Ivecs(filename, data): 45 | print(f"Writing File - {filename}") 46 | with open(filename, 'wb') as fp: 47 | for y in tqdm(data): 48 | d = struct.pack('I', len(y)) 49 | fp.write(d) 50 | for x in y: 51 | a = struct.pack('Q', x) 52 | fp.write(a) 53 | 54 | def to_ivecs(filename, data): 55 | print(f"Writing File - {filename}") 56 | with open(filename, 'wb') as fp: 57 | for y in data: 58 | d = struct.pack('I', len(y)) 59 | fp.write(d) 60 | for x in y: 61 | a = struct.pack('I', x) 62 | fp.write(a) -------------------------------------------------------------------------------- /src/index.cpp: -------------------------------------------------------------------------------- 1 | #define EIGEN_DONT_PARALLELIZE 2 | #define USE_AVX2 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "matrix.h" 11 | #include "utils.h" 12 | #include "ivf_rabitq.h" 13 | 14 | using namespace std; 15 | 16 | int main(int argc, char * argv[]) { 17 | 18 | const struct option longopts[] ={ 19 | // General Parameter 20 | {"help", no_argument, 0, 'h'}, 21 | 22 | // Indexing Path 23 | {"dataset", required_argument, 0, 'd'}, 24 | {"source", required_argument, 0, 's'}, 25 | }; 26 | 27 | int ind; 28 | int iarg = 0; 29 | opterr 
= 1; //getopt error message (off: 0) 30 | 31 | char dataset[256]=""; 32 | char source[256]=""; 33 | 34 | while(iarg != -1){ 35 | iarg = getopt_long(argc, argv, "d:s:", longopts, &ind); 36 | switch (iarg){ 37 | case 'd': 38 | if(optarg){ 39 | strcpy(dataset, optarg); 40 | } 41 | break; 42 | case 's': 43 | if(optarg){ 44 | strcpy(source, optarg); 45 | } 46 | break; 47 | } 48 | } 49 | 50 | 51 | // ============================================================================================================== 52 | // Load Data 53 | char data_path[256] = ""; 54 | char index_path[256] = ""; 55 | char centroid_path[256] = ""; 56 | char x0_path[256] = ""; 57 | char dist_to_centroid_path[256] = ""; 58 | char cluster_id_path[256] = ""; 59 | char binary_path[256] = ""; 60 | 61 | sprintf(data_path, "%s%s_base.fvecs", source, dataset); 62 | Matrix X(data_path); 63 | 64 | sprintf(centroid_path, "%sRandCentroid_C%d_B%d.fvecs", source, numC, BB); 65 | Matrix C(centroid_path); 66 | 67 | sprintf(x0_path, "%sx0_C%d_B%d.fvecs", source, numC, BB); 68 | Matrix x0(x0_path); 69 | 70 | sprintf(dist_to_centroid_path, "%s%s_dist_to_centroid_%d.fvecs", source, dataset, numC); 71 | Matrix dist_to_centroid(dist_to_centroid_path); 72 | 73 | sprintf(cluster_id_path, "%s%s_cluster_id_%d.ivecs", source, dataset, numC); 74 | Matrix cluster_id(cluster_id_path); 75 | 76 | sprintf(binary_path, "%sRandNet_C%d_B%d.Ivecs", source, numC, BB); 77 | Matrix binary(binary_path); 78 | 79 | sprintf(index_path, "%sivfrabitq%d_B%d.index", source, numC, BB); 80 | std::cerr << "Loading Succeed!" 
<< std::endl << std::endl; 81 | // ============================================================================================================== 82 | 83 | IVFRN ivf(X, C, dist_to_centroid, x0, cluster_id, binary); 84 | 85 | ivf.save(index_path); 86 | 87 | return 0; 88 | } 89 | -------------------------------------------------------------------------------- /data/rabitq.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import struct 4 | import time 5 | import os 6 | from utils.io import * 7 | from tqdm import tqdm 8 | 9 | source = './' 10 | datasets = ['sift'] 11 | 12 | def Orthogonal(D): 13 | G = np.random.randn(D, D).astype('float32') 14 | Q, _ = np.linalg.qr(G) 15 | return Q 16 | 17 | def GenerateBinaryCode(X, P): 18 | XP = np.dot(X, P) 19 | binary_XP = (XP > 0) 20 | X0 = np.sum(XP * (2 * binary_XP - 1) / D ** 0.5, axis=1, keepdims=True) / np.linalg.norm(XP, axis=1, keepdims=True) 21 | return binary_XP, X0 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | for dataset in datasets: 27 | # path 28 | path = os.path.join(source, dataset) 29 | data_path = os.path.join(path, f'{dataset}_base.fvecs') 30 | 31 | C = 4096 32 | centroids_path = os.path.join(path, f'{dataset}_centroid_{C}.fvecs') 33 | dist_to_centroid_path = os.path.join(path, f'{dataset}_dist_to_centroid_{C}.fvecs') 34 | cluster_id_path = os.path.join(path, f'{dataset}_cluster_id_{C}.ivecs') 35 | 36 | X = read_fvecs(data_path) 37 | centroids = read_fvecs(centroids_path) 38 | cluster_id = read_ivecs(cluster_id_path) 39 | 40 | D = X.shape[1] 41 | B = (D + 63) // 64 * 64 42 | MAX_BD = max(D, B) 43 | 44 | projection_path = os.path.join(path, f'P_C{C}_B{B}.fvecs') 45 | randomized_centroid_path = os.path.join(path, f'RandCentroid_C{C}_B{B}.fvecs') 46 | RN_path = os.path.join(path, f'RandNet_C{C}_B{B}.Ivecs') 47 | x0_path = os.path.join(path, f'x0_C{C}_B{B}.fvecs') 48 | 49 | X_pad = np.pad(X, ((0, 0), (0, MAX_BD-D)), 'constant') 50 | centroids_pad 
= np.pad(centroids, ((0, 0), (0, MAX_BD-D)), 'constant') 51 | np.random.seed(0) 52 | 53 | # The inverse of an orthogonal matrix equals to its transpose. 54 | P = Orthogonal(MAX_BD) 55 | P = P.T 56 | 57 | cluster_id=np.squeeze(cluster_id) 58 | XP = np.dot(X_pad, P) 59 | CP = np.dot(centroids_pad, P) 60 | XP = XP - CP[cluster_id] 61 | bin_XP = (XP > 0) 62 | 63 | # The inner product between the data vector and the quantized data vector, i.e., <\bar o, o>. 64 | x0 = np.sum(XP[ : , :B] * (2 * bin_XP[ : , :B] - 1) / B ** 0.5, axis=1, keepdims=True) / np.linalg.norm(XP, axis=1, keepdims=True) 65 | 66 | # To remove illy defined x0 67 | # np.linalg.norm(XP, axis=1, keepdims=True) = 0 indicates that its estimated distance based on our method has no error. 68 | # Thus, it should be good to set x0 as any finite non-zero number. 69 | x0[~np.isfinite(x0)] = 0.8 70 | 71 | bin_XP = bin_XP[:, :B].flatten() 72 | uint64_XP = np.packbits(bin_XP.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64) 73 | uint64_XP = uint64_XP.reshape(-1, B >> 6) 74 | 75 | # Output 76 | to_fvecs(randomized_centroid_path, CP) 77 | to_Ivecs(RN_path , uint64_XP) 78 | to_fvecs(x0_path , x0) 79 | to_fvecs(projection_path , P) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [SIGMOD 2024] RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search 2 | 3 | ## News and Updates 4 | 5 | * **A library with more practical implementation techniques about RaBitQ is released at the [RaBitQ-Library](https://github.com/VectorDB-NTU/RaBitQ-Library).** 6 | 7 | 8 | * A new blog - [Quantization in The Counterintuitive High-Dimensional Space](https://dev.to/gaoj0017/quantization-in-the-counterintuitive-high-dimensional-space-4feg) - to provide the key insights behind RaBitQ and its extension (corresponding to optimized approaches to binary 
quantization and scalar quantization respectively). 9 | 10 | --- 11 | 12 | We are happy to answer any questions regarding the RaBitQ project. Please feel free to email us at *jianyang.gao [at] ntu.edu.sg* and *c.long [at] ntu.edu.sg*. 13 | 14 | ## Organization 15 | * The index phase of RaBitQ: `./data/rabitq.py`. 16 | * The query phase of RaBitQ: 17 | * `./src/ivf_rabitq.h` includes the general workflow. 18 | * `./src/space.h` includes the bitwise implementation of RaBitQ. 19 | * `./src/fast_scan.h` includes the SIMD-based implementation of RaBitQ. 20 | * Comments are provided in these files. 21 | 22 | 23 | ## Prerequisites 24 | * Eigen == 3.4.0 25 | 1. Download the Eigen library from https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz. 26 | 2. Unzip it and move the `Eigen` folder to `./src/`. 27 | 28 | --- 29 | ## Reproduction 30 | 31 | 1. Download and preprocess the datasets. Detailed instructions can be found in `./data/README.md`. 32 | 33 | 2. Index the datasets. 34 | ```sh 35 | ./script/index.sh 36 | ``` 37 | 3. Run the queries on the datasets. The results are written to `./results/`. 38 | ```sh 39 | ./script/search.sh 40 | ``` 41 | 42 | --- 43 | 44 | If you use our work in your research, please cite our paper with the following BibTeX entry. 45 | 46 | ``` 47 | @article{10.1145/3654970, 48 | author = {Gao, Jianyang and Long, Cheng}, 49 | title = {RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search}, 50 | year = {2024}, 51 | issue_date = {June 2024}, 52 | publisher = {Association for Computing Machinery}, 53 | address = {New York, NY, USA}, 54 | volume = {2}, 55 | number = {3}, 56 | url = {https://doi.org/10.1145/3654970}, 57 | doi = {10.1145/3654970}, 58 | abstract = {Searching for approximate nearest neighbors (ANN) in the high-dimensional Euclidean space is a pivotal problem.
Recently, with the help of fast SIMD-based implementations, Product Quantization (PQ) and its variants can often efficiently and accurately estimate the distances between the vectors and have achieved great success in the in-memory ANN search. Despite their empirical success, we note that these methods do not have a theoretical error bound and are observed to fail disastrously on some real-world datasets. Motivated by this, we propose a new randomized quantization method named RaBitQ, which quantizes D-dimensional vectors into D-bit strings. RaBitQ guarantees a sharp theoretical error bound and provides good empirical accuracy at the same time. In addition, we introduce efficient implementations of RaBitQ, supporting to estimate the distances with bitwise operations or SIMD-based operations. Extensive experiments on real-world datasets confirm that (1) our method outperforms PQ and its variants in terms of accuracy-efficiency trade-off by a clear margin and (2) its empirical performance is well-aligned with our theoretical analysis.}, 59 | journal = {Proc. ACM Manag. Data}, 60 | month = may, 61 | articleno = {167}, 62 | numpages = {27}, 63 | keywords = {Johnson-Lindenstrauss transformation, approximate nearest neighbor search, quantization} 64 | } 65 | ``` 66 | 67 | If our work helps your system, please reference our paper as follows. 68 | 69 | ``` 70 | Jianyang Gao and Cheng Long. 2024. RaBitQ: Quantizing High-Dimensional Vectors with a Theoretical Error Bound for Approximate Nearest Neighbor Search. Proc. ACM Manag. Data 2, 3, Article 167 (June 2024), 27 pages.
https://doi.org/10.1145/3654970 71 | ``` 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /src/search.cpp: -------------------------------------------------------------------------------- 1 | #define EIGEN_DONT_PARALLELIZE 2 | #define USE_AVX2 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace std; 14 | 15 | const int MAXK = 100; 16 | 17 | long double rotation_time=0; 18 | 19 | template 20 | void test(const Matrix &Q, const Matrix &RandQ, const Matrix &X, const Matrix &G, 21 | const IVFRN &ivf, int k){ 22 | float sys_t, usr_t, usr_t_sum = 0, total_time=0, search_time=0; 23 | struct rusage run_start, run_end; 24 | 25 | // ======================================================================== 26 | // Search Parameter 27 | vector nprobes; 28 | nprobes.push_back(300); 29 | // ======================================================================== 30 | 31 | for(auto nprobe:nprobes){ 32 | float total_time=0; 33 | float total_ratio=0; 34 | int correct = 0; 35 | 36 | for(int i=0;i Q(query_path); 114 | 115 | char data_path[256] = ""; 116 | sprintf(data_path, "%s%s_base.fvecs", source, dataset); 117 | Matrix X(data_path); 118 | 119 | char groundtruth_path[256] = ""; 120 | sprintf(groundtruth_path, "%s%s_groundtruth.ivecs", source, dataset); 121 | Matrix G(groundtruth_path); 122 | 123 | char transformation_path[256] = ""; 124 | sprintf(transformation_path, "%sP_C%d_B%d.fvecs", source, numC, BB); 125 | Matrix P(transformation_path); 126 | 127 | char index_path[256] = ""; 128 | sprintf(index_path, "%sivfrabitq%d_B%d.index", source, numC, BB); 129 | std::cerr << index_path << std::endl; 130 | #if defined(FAST_SCAN) 131 | char result_file_view[256] = ""; 132 | sprintf(result_file_view, "%s%s_ivfrabitq%d_B%d_fast_scan.log", result_path, dataset, numC, BB); 133 | #elif defined(SCAN) 134 | char result_file_view[256] = ""; 135 | 
sprintf(result_file_view, "%s%s_ivfrabitq%d_B%d_scan.log", result_path, dataset, numC, BB); 136 | #endif 137 | std::cerr << "Loading Succeed!" << std::endl; 138 | // ================================================================================================================================ 139 | 140 | 141 | freopen(result_file_view,"a",stdout); 142 | 143 | IVFRN ivf; 144 | ivf.load(index_path); 145 | 146 | float sys_t, usr_t, usr_t_sum = 0, total_time=0, search_time=0; 147 | struct rusage run_start, run_end; 148 | GetCurTime( &run_start); 149 | 150 | Matrix RandQ(Q.n, BB, Q); 151 | RandQ = mul(RandQ, P); 152 | 153 | GetCurTime( &run_end); 154 | GetTime( &run_start, &run_end, &usr_t, &sys_t); 155 | rotation_time = usr_t * 1e6 / Q.n; 156 | 157 | test(Q, RandQ, X, G, ivf, subk); 158 | 159 | return 0; 160 | } 161 | -------------------------------------------------------------------------------- /src/space.h: -------------------------------------------------------------------------------- 1 | 2 | #pragma once 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 10 | #define PORTABLE_ALIGN64 __attribute__((aligned(64))) 11 | #include "matrix.h" 12 | #include "utils.h" 13 | #include 14 | 15 | 16 | template 17 | class Space{ 18 | public: 19 | // ================================================================================================ 20 | // ******************** 21 | // Binary Operation 22 | // ******************** 23 | inline static uint32_t popcount(u_int64_t *d); 24 | inline static uint32_t ip_bin_bin(uint64_t * q, uint64_t * d); 25 | inline static uint32_t ip_byte_bin(uint64_t *q, uint64_t *d); 26 | inline static void transpose_bin(uint8_t *q, uint64_t *tq); 27 | 28 | // ================================================================================================ 29 | inline static void range(float* q, float* c, float &vl, float &vr); 30 | inline static void 
quantize(uint8_t *result, float *q, float* c, float *u, float max_entry, float width, uint32_t & sum_q); 31 | inline static uint32_t sum(uint8_t* d); 32 | Space(){}; 33 | ~Space(){}; 34 | }; 35 | 36 | 37 | // ============================================================== 38 | // inner product between binary strings 39 | // ============================================================== 40 | template 41 | inline uint32_t Space::ip_bin_bin(uint64_t * q, uint64_t * d){ 42 | uint64_t ret = 0; 43 | for(int i = 0; i < B / 64; i ++){ 44 | ret += __builtin_popcountll((*d) & (*q)); 45 | q ++; 46 | d ++; 47 | } 48 | return ret; 49 | } 50 | 51 | // ============================================================== 52 | // popcount (a.k.a, bitcount) 53 | // ============================================================== 54 | template 55 | inline uint32_t Space::popcount(u_int64_t *d){ 56 | uint64_t ret = 0; 57 | for(int i = 0; i < B / 64; i ++){ 58 | ret += __builtin_popcountll((*d)); 59 | d ++; 60 | } 61 | return ret; 62 | } 63 | 64 | // ============================================================== 65 | // inner product between a decomposed byte string q 66 | // and a binary string d 67 | // ============================================================== 68 | template 69 | uint32_t Space::ip_byte_bin(uint64_t *q, uint64_t *d){ 70 | uint64_t ret = 0; 71 | for(int i = 0; i < B_QUERY; i++){ 72 | ret += (ip_bin_bin(q, d) << i); 73 | q += (B / 64); 74 | } 75 | return ret; 76 | } 77 | 78 | // ============================================================== 79 | // decompose the quantized query vector into B_q binary vector 80 | // ============================================================== 81 | template 82 | void Space::transpose_bin(uint8_t *q, uint64_t *tq){ 83 | for(int i=0;i(q)); 85 | v = _mm256_slli_epi32(v, (8-B_QUERY)); 86 | for(int j=0;j 100 | void Space::range(float* q, float* c, float& vl, float &vr){ 101 | vl = +1e20; 102 | vr = -1e20; 103 | for(int i=0;i vr)vr = tmp; 107 | q 
++; 108 | c ++; 109 | } 110 | } 111 | 112 | // ============================================================== 113 | // quantize the query vector with uniform scalar quantization 114 | // ============================================================== 115 | template 116 | void Space::quantize(uint8_t *result, float *q, float* c, float * u, float vl, float width, uint32_t &sum_q){ 117 | float one_over_width = 1.0 / width; 118 | uint8_t *ptr_res = result; 119 | uint32_t sum = 0; 120 | for(int i=0;i 133 | inline float sqr_dist(float *d, float *q) { 134 | float PORTABLE_ALIGN32 TmpRes[8] = {0, 0, 0, 0, 0, 0, 0, 0}; 135 | constexpr uint32_t num_blk16 = L >> 4; 136 | constexpr uint32_t l = L & 0b1111; 137 | 138 | __m256 diff, v1, v2; 139 | __m256 sum = _mm256_set1_ps(0); 140 | for(int i=0;i 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #define lowbit(x) (x&(-x)) 13 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 14 | #define PORTABLE_ALIGN64 __attribute__((aligned(64))) 15 | 16 | using namespace std; 17 | 18 | // ============================================================== 19 | // look up the tables for a packed batch of 32 quantization codes 20 | // ============================================================== 21 | template 22 | inline void accumulate_one_block(uint8_t* codes, uint8_t* LUT, uint16_t* result){ 23 | __m256i low_mask = _mm256_set1_epi8(0xf); 24 | __m256i accu[4]; 25 | for(int i=0;i<4;i++){ 26 | accu[i] = _mm256_setzero_si256(); 27 | } 28 | 29 | constexpr uint32_t M = B / 4; 30 | 31 | for(int m=0;m 64 | inline void accumulate(uint32_t nblk, uint8_t* codes, uint8_t* LUT, uint16_t* result){ 65 | for(int i=0;i(codes, LUT, result); 67 | codes += 32 * B / 8; 68 | result += 32; 69 | } 70 | } 71 | 72 | 73 | // ============================================================== 74 | // prepare the look-up-table from the quantized query vector 75 | // ============================================================== 76 | template 77 | 
inline void pack_LUT(uint8_t* byte_query, uint8_t* LUT){ 78 | constexpr uint32_t M = B / 4; 79 | constexpr uint32_t pos[16]={ 80 | 3 /*0000*/, 3/*0001*/, 2/*0010*/, 3/*0011*/, 81 | 1 /*0100*/, 3/*0101*/, 2/*0110*/, 3/*0111*/, 82 | 0 /*1000*/, 3/*1001*/, 2/*1010*/, 3/*1011*/, 83 | 1 /*1100*/, 3/*1101*/, 2/*1110*/, 3/*1111*/, 84 | }; 85 | for(int i=0;i 97 | inline void get_matrix_column(T* src, size_t m, size_t n, int64_t i, int64_t j, TA& dest) { 98 | for (int64_t k = 0; k < dest.size(); k++) { 99 | if (k + i >= 0 && k + i < m) { 100 | dest[k] = src[(k + i) * n + j]; 101 | } 102 | else { 103 | dest[k] = 0; 104 | } 105 | } 106 | } 107 | 108 | // ============================================================== 109 | // pack 32 quantization codes in a batch from the quantization 110 | // codes represented by a sequence of uint8_t variables 111 | // ============================================================== 112 | template 113 | void pack_codes(const uint8_t* codes, uint32_t ncode, uint8_t* blocks){ 114 | 115 | uint32_t ncode_pad = (ncode + 31) / 32 * 32; 116 | constexpr uint32_t M = B / 4; 117 | const uint8_t bbs = 32; 118 | memset(blocks, 0, ncode_pad * M / 2); 119 | 120 | const uint8_t perm0[16] = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; 121 | uint8_t* codes2 = blocks; 122 | for(int blk=0;blk c, c0, c1; 126 | get_matrix_column(codes, ncode, M / 2, blk, m / 2, c); 127 | for (int j = 0; j < 32; j++) { 128 | c0[j] = c[j] & 15; 129 | c1[j] = c[j] >> 4; 130 | } 131 | for (int j = 0; j < 16; j++) { 132 | uint8_t d0, d1; 133 | d0 = c0[perm0[j]] | (c0[perm0[j] + 16] << 4); 134 | d1 = c1[perm0[j]] | (c1[perm0[j] + 16] << 4); 135 | codes2[j] = d0; 136 | codes2[j + 16] = d1; 137 | } 138 | codes2 += 32; 139 | } 140 | } 141 | } 142 | 143 | // ============================================================== 144 | // pack 32 quantization codes in a batch from the quantization 145 | // codes represented by a sequence of uint64_t variables 146 | // 
============================================================== 147 | template 148 | void pack_codes(const uint64_t* binary_code, uint32_t ncode, uint8_t* blocks){ 149 | uint32_t ncode_pad = (ncode + 31) / 32 * 32; 150 | memset(blocks, 0, ncode_pad * sizeof(uint8_t)); 151 | 152 | uint8_t * binary_code_8bit = new uint8_t [ncode_pad * B / 8]; 153 | memcpy(binary_code_8bit, binary_code, ncode * B / 64 * sizeof(uint64_t)); 154 | 155 | for(int i=0;i> 4); 163 | uint8_t y = (v & 15); 164 | binary_code_8bit[i] = (y << 4 | x); 165 | } 166 | pack_codes(binary_code_8bit, ncode, blocks); 167 | delete [] binary_code_8bit; 168 | } 169 | -------------------------------------------------------------------------------- /src/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #ifndef WIN32 8 | #include 9 | #endif 10 | 11 | typedef std::pair Result; 12 | typedef std::priority_queue ResultHeap; 13 | 14 | namespace Detail 15 | { 16 | double constexpr sqrtNewtonRaphson(double x, double curr, double prev) 17 | { 18 | return curr == prev 19 | ? curr 20 | : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr); 21 | } 22 | } 23 | 24 | /* 25 | * Constexpr version of the square root 26 | * Return value: 27 | * - For a finite and non-negative value of "x", returns an approximation for the square root of "x" 28 | * - Otherwise, returns NaN 29 | */ 30 | double constexpr const_sqrt(double x) 31 | { 32 | return x >= 0 && x < std::numeric_limits::infinity() 33 | ? 
Detail::sqrtNewtonRaphson(x, x, 0) 34 | : std::numeric_limits::quiet_NaN(); 35 | } 36 | 37 | void print_binary(uint64_t v){ 38 | for(int i=0;i<64;i++){ 39 | std::cerr << ((v >> (63 - i)) & 1); 40 | } 41 | } 42 | 43 | void print_binary(uint8_t v){ 44 | for(int i=0;i<8;i++){ 45 | std::cerr << ((v >> (7 - i)) & 1); 46 | } 47 | } 48 | 49 | inline uint32_t reverseBits(uint32_t n) { 50 | n = (n >> 1) & 0x55555555 | (n << 1) & 0xaaaaaaaa; 51 | n = (n >> 2) & 0x33333333 | (n << 2) & 0xcccccccc; 52 | n = (n >> 4) & 0x0f0f0f0f | (n << 4) & 0xf0f0f0f0; 53 | n = (n >> 8) & 0x00ff00ff | (n << 8) & 0xff00ff00; 54 | n = (n >> 16) & 0x0000ffff | (n << 16) & 0xffff0000; 55 | return n; 56 | } 57 | 58 | ResultHeap getGroundtruth(const Matrix & X, const Matrix & Q, size_t query, 59 | unsigned* groundtruth, size_t k){ 60 | ResultHeap ret; 61 | for(int i=0;i &Q, const Matrix &X, const Matrix &G, ResultHeap KNNs){ 69 | ResultHeap gt; 70 | int k = KNNs.size(); 71 | for(int i=0;i 1e-5){ 78 | ret += std::sqrt(KNNs.top().first / gt.top().first); 79 | valid_k ++; 80 | } 81 | gt.pop(); 82 | KNNs.pop(); 83 | } 84 | if(valid_k == 0) return 1.0 * k; 85 | return ret / valid_k * k; 86 | } 87 | 88 | int getRecall(ResultHeap & result, ResultHeap & gt){ 89 | int correct=0; 90 | 91 | std::unordered_set g; 92 | int ret = 0; 93 | 94 | while (gt.size()) { 95 | g.insert(gt.top().second); 96 | //std::cerr << "ID - " << gt.top().second << " dist - " << gt.top().first << std::endl; 97 | gt.pop(); 98 | } 99 | 100 | while (result.size()) { 101 | //std::cerr << "ID - " << result.top().second << " dist - " << result.top().first << std::endl; 102 | if (g.find(result.top().second) != g.end()) { 103 | ret++; 104 | } 105 | result.pop(); 106 | } 107 | 108 | return ret; 109 | } 110 | 111 | #ifndef WIN32 112 | void GetCurTime( rusage* curTime) 113 | { 114 | int ret = getrusage( RUSAGE_SELF, curTime); 115 | if( ret != 0) 116 | { 117 | fprintf( stderr, "The running time info couldn't be collected successfully.\n"); 118 | 
//FreeData( 2); 119 | exit( 0); 120 | } 121 | } 122 | 123 | /* 124 | * GetTime is used to get the 'float' format time from the start and end rusage structure. 125 | * 126 | * @Param timeStart, timeEnd indicate the two time points. 127 | * @Param userTime, sysTime get back the time information. 128 | * 129 | * @Return void. 130 | */ 131 | void GetTime( struct rusage* timeStart, struct rusage* timeEnd, float* userTime, float* sysTime) 132 | { 133 | (*userTime) = ((float)(timeEnd->ru_utime.tv_sec - timeStart->ru_utime.tv_sec)) + 134 | ((float)(timeEnd->ru_utime.tv_usec - timeStart->ru_utime.tv_usec)) * 1e-6; 135 | (*sysTime) = ((float)(timeEnd->ru_stime.tv_sec - timeStart->ru_stime.tv_sec)) + 136 | ((float)(timeEnd->ru_stime.tv_usec - timeStart->ru_stime.tv_usec)) * 1e-6; 137 | } 138 | 139 | #endif 140 | 141 | #if defined(_WIN32) 142 | #include 143 | #include 144 | 145 | #elif defined(__unix__) || defined(__unix) || defined(unix) || (defined(__APPLE__) && defined(__MACH__)) 146 | 147 | #include 148 | #include 149 | 150 | #if defined(__APPLE__) && defined(__MACH__) 151 | #include 152 | 153 | #elif (defined(_AIX) || defined(__TOS__AIX__)) || (defined(__sun__) || defined(__sun) || defined(sun) && (defined(__SVR4) || defined(__svr4__))) 154 | #include 155 | #include 156 | 157 | #elif defined(__linux__) || defined(__linux) || defined(linux) || defined(__gnu_linux__) 158 | 159 | #endif 160 | 161 | #else 162 | #error "Cannot define getPeakRSS( ) or getCurrentRSS( ) for an unknown OS." 163 | #endif 164 | 165 | 166 | /** 167 | * Returns the peak (maximum so far) resident set size (physical 168 | * memory use) measured in bytes, or zero if the value cannot be 169 | * determined on this OS. 
170 | */ 171 | size_t getPeakRSS() { 172 | #if defined(_WIN32) 173 | /* Windows -------------------------------------------------- */ 174 | PROCESS_MEMORY_COUNTERS info; 175 | GetProcessMemoryInfo(GetCurrentProcess(), &info, sizeof(info)); 176 | return (size_t)info.PeakWorkingSetSize; 177 | 178 | #elif (defined(_AIX) || defined(__TOS__AIX__)) || (defined(__sun__) || defined(__sun) || defined(sun) && (defined(__SVR4) || defined(__svr4__))) 179 | /* AIX and Solaris ------------------------------------------ */ 180 | struct psinfo psinfo; 181 | int fd = -1; 182 | if ((fd = open("/proc/self/psinfo", O_RDONLY)) == -1) 183 | return (size_t)0L; /* Can't open? */ 184 | if (read(fd, &psinfo, sizeof(psinfo)) != sizeof(psinfo)) 185 | { 186 | close(fd); 187 | return (size_t)0L; /* Can't read? */ 188 | } 189 | close(fd); 190 | return (size_t)(psinfo.pr_rssize * 1024L); 191 | 192 | #elif defined(__unix__) || defined(__unix) || defined(unix) || (defined(__APPLE__) && defined(__MACH__)) 193 | /* BSD, Linux, and OSX -------------------------------------- */ 194 | struct rusage rusage; 195 | getrusage(RUSAGE_SELF, &rusage); 196 | #if defined(__APPLE__) && defined(__MACH__) 197 | return (size_t)rusage.ru_maxrss; 198 | #else 199 | return (size_t) (rusage.ru_maxrss * 1024L); 200 | #endif 201 | 202 | #else 203 | /* Unknown OS ----------------------------------------------- */ 204 | return (size_t)0L; /* Unsupported. */ 205 | #endif 206 | } 207 | 208 | 209 | /** 210 | * Returns the current resident set size (physical memory use) measured 211 | * in bytes, or zero if the value cannot be determined on this OS. 
212 | */ 213 | size_t getCurrentRSS() { 214 | #if defined(_WIN32) 215 | /* Windows -------------------------------------------------- */ 216 | PROCESS_MEMORY_COUNTERS info; 217 | GetProcessMemoryInfo(GetCurrentProcess(), &info, sizeof(info)); 218 | return (size_t)info.WorkingSetSize; 219 | 220 | #elif defined(__APPLE__) && defined(__MACH__) 221 | /* OSX ------------------------------------------------------ */ 222 | struct mach_task_basic_info info; 223 | mach_msg_type_number_t infoCount = MACH_TASK_BASIC_INFO_COUNT; 224 | if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, 225 | (task_info_t)&info, &infoCount) != KERN_SUCCESS) 226 | return (size_t)0L; /* Can't access? */ 227 | return (size_t)info.resident_size; 228 | 229 | #elif defined(__linux__) || defined(__linux) || defined(linux) || defined(__gnu_linux__) 230 | /* Linux ---------------------------------------------------- */ 231 | long rss = 0L; 232 | FILE *fp = NULL; 233 | if ((fp = fopen("/proc/self/statm", "r")) == NULL) 234 | return (size_t) 0L; /* Can't open? */ 235 | if (fscanf(fp, "%*s%ld", &rss) != 1) { 236 | fclose(fp); 237 | return (size_t) 0L; /* Can't read? */ 238 | } 239 | fclose(fp); 240 | return (size_t) rss * (size_t) sysconf(_SC_PAGESIZE); 241 | 242 | #else 243 | /* AIX, BSD, Solaris, and Unknown OS ------------------------ */ 244 | return (size_t)0L; /* Unsupported. 
*/ 245 | #endif 246 | } -------------------------------------------------------------------------------- /src/matrix.h: -------------------------------------------------------------------------------- 1 | 2 | #pragma once 3 | #ifndef MATRIX_HPP_ 4 | #define MATRIX_HPP_ 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | template 15 | class Matrix 16 | { 17 | private: 18 | 19 | public: 20 | T* data; 21 | size_t n; 22 | size_t d; 23 | 24 | // Construction 25 | Matrix(); // Default 26 | Matrix(size_t n, size_t d); // Fixed size 27 | Matrix(const Matrix &X); // Deep Copy 28 | Matrix(size_t n, size_t d, const Matrix &X); // Fixed size with a filling matrix. 29 | Matrix(size_t n, size_t d, const Matrix &X, size_t *id); // Submatrix with given row numbers. 30 | Matrix(const Matrix &X, const Matrix &ID); // Submatrix 31 | Matrix(const Matrix &X, const size_t id); // row 32 | Matrix(char * data_file_path); // IO 33 | Matrix(size_t n); // ID 34 | 35 | // Deconstruction 36 | ~Matrix(){ 37 | delete [] data; 38 | } 39 | 40 | // Serialization 41 | void serialize(FILE * fp); 42 | void deserialize(FILE * fp); 43 | 44 | Matrix & operator = (const Matrix &X){ 45 | delete [] data; 46 | n = X.n; 47 | d = X.d; 48 | data = new T [n*d]; 49 | memcpy(data, X.data, sizeof(T) * n * d); 50 | return *this; 51 | } 52 | 53 | // Linear Algebra 54 | void add(size_t a, const Matrix &B, size_t b); 55 | void div(size_t a, T c); 56 | void mul(const Matrix &A, Matrix &result) const; 57 | void copy(size_t r, const Matrix &A, size_t ra); // copy the ra th row of A to r th row 58 | float dist(const Matrix &A, size_t a, const Matrix &B, size_t b) const; 59 | float dist(size_t a, const Matrix &B, size_t b) const; 60 | 61 | size_t scalar(){ 62 | return data[0]; 63 | } 64 | 65 | bool empty(){ 66 | if(n == 0)return 1; 67 | return 0; 68 | } 69 | 70 | // Experiment and Debug 71 | void print(); 72 | void reset(); 73 | 74 | }; 75 | 76 | template 77 | 
Matrix::Matrix(){ 78 | n = 0; 79 | d = 0; 80 | data = NULL; 81 | } 82 | 83 | template 84 | Matrix::Matrix(const Matrix& X){ 85 | n = X.n; 86 | d = X.d; 87 | data = new T[n * d]; 88 | memcpy(data, X.data, sizeof(T) * n * d); 89 | } 90 | 91 | template 92 | Matrix::Matrix(size_t _n,size_t _d){ 93 | n = _n; 94 | d = _d; 95 | data = new T [n * d + 10]; 96 | memset(data, 0, (n * d + 10) * sizeof(T)); 97 | } 98 | 99 | template 100 | Matrix::Matrix(size_t _n,size_t _d, const Matrix &X){ 101 | n = _n; 102 | d = _d; 103 | data = new T [n * d + 10]; 104 | memset(data, 0, (n * d + 10) * sizeof(T)); 105 | for(int i=0;i 111 | Matrix::Matrix(size_t _n,size_t _d, const Matrix &X, size_t *id){ 112 | n = _n; 113 | d = _d; 114 | data = new T [n * d]; 115 | for(size_t i=0;i 121 | Matrix::Matrix(const Matrix &X, const Matrix &id){ 122 | n = id.n; 123 | d = X.d; 124 | data = new T [n * d]; 125 | for(size_t i=0;i 131 | Matrix::Matrix(const Matrix &X, const size_t id){ 132 | n = 1; 133 | d = X.d; 134 | data = new T [n * d]; 135 | for(size_t i=0;i 141 | Matrix::Matrix(char *data_file_path){ 142 | n = 0; 143 | d = 0; 144 | data = NULL; 145 | printf("%s\n",data_file_path); 146 | std::ifstream in(data_file_path, std::ios::binary); 147 | if (!in.is_open()) { 148 | std::cout << "open file error" << std::endl; 149 | exit(-1); 150 | } 151 | in.read((char*)&d, 4); 152 | 153 | std::cerr << "Dimensionality - " << d < 170 | Matrix::Matrix(size_t _n){ 171 | n = _n; 172 | d = 1; 173 | data = new T [n]; 174 | for(size_t i=0;i 178 | void Matrix::print(){ 179 | for(size_t i=0;i<2;i++){ 180 | std::cout << "("; 181 | for(size_t j=0;j 189 | void Matrix::reset(){ 190 | memset(data, 0, sizeof(T) * n * d); 191 | } 192 | 193 | template 194 | float Matrix::dist(const Matrix &A, size_t a, const Matrix &B, size_t b)const{ 195 | float dist = 0; 196 | float *ptra = A.data + a * d; 197 | float *ptrb = B.data + b * d; 198 | 199 | for(int i=0;i 210 | float Matrix::dist(size_t a, const Matrix &B, size_t b)const{ 211 | 
float dist = 0; 212 | for(size_t i=0;i 219 | void Matrix::add(size_t a, const Matrix &B, size_t b){ 220 | for(size_t i=0;i 226 | void Matrix::div(size_t a, T c){ 227 | for(size_t i=0;i 233 | void Matrix::mul(const Matrix &A, Matrix &result)const{ 234 | //result.reset(); 235 | result.n = n; 236 | result.d = A.d; 237 | for(size_t i=0;i 249 | Matrix mul(const Matrix &A, const Matrix &B){ 250 | 251 | std::cerr << "Matrix Multiplication - " << A.n << " " << A.d << " " << B.d << std::endl; 252 | Eigen::MatrixXd _A(A.n, A.d); 253 | Eigen::MatrixXd _B(B.n, B.d); 254 | Eigen::MatrixXd _C(A.n, B.d); 255 | 256 | for(int i=0;i result(A.n, B.d); 267 | 268 | for(int i=0;i 276 | void Matrix::serialize(FILE * fp){ 277 | fwrite(&n, sizeof(size_t), 1, fp); 278 | fwrite(&d, sizeof(size_t), 1, fp); 279 | size_t size = sizeof(T); 280 | fwrite(&size, sizeof(size_t), 1, fp); 281 | fwrite(data, size, n * d, fp); 282 | } 283 | 284 | template 285 | void Matrix::deserialize(FILE * fp){ 286 | fread(&n, sizeof(size_t), 1, fp); 287 | fread(&d, sizeof(size_t), 1, fp); 288 | //std::cerr << n << " " << d << std::endl; 289 | assert(n <= 1000000000); 290 | assert(d <= 2000); 291 | 292 | size_t size = sizeof(T); 293 | fread(&size, sizeof(size_t), 1, fp); 294 | data = new T [n * d]; 295 | fread(data, size, n * d, fp); 296 | } 297 | 298 | double normalize(float *x, unsigned D){ 299 | Eigen::VectorXd v(D); 300 | for(int i=0;i 10 | #include 11 | #include 12 | #include 13 | #include "matrix.h" 14 | #include "utils.h" 15 | #include "space.h" 16 | #include "fast_scan.h" 17 | 18 | template 19 | class IVFRN{ 20 | private: 21 | public: 22 | struct Factor{ 23 | float sqr_x; 24 | float error; 25 | float factor_ppc; 26 | float factor_ip; 27 | }; 28 | 29 | Factor * fac; 30 | static constexpr float fac_norm = const_sqrt(1.0 * B); 31 | static constexpr float max_x1 = 1.9 / const_sqrt(1.0 * B-1.0); 32 | 33 | static Space space; 34 | 35 | uint32_t N; // the number of data vectors 36 | uint32_t C; // the number of 
clusters 37 | 38 | uint32_t* start; // the start point of a cluster 39 | uint32_t* packed_start; // the start point of a cluster (packed with batch of 32) 40 | uint32_t* len; // the length of a cluster 41 | uint32_t* id; // N of size_t the ids of the objects in a cluster 42 | float * dist_to_c; // N of floats distance to the centroids (not the squared distance) 43 | float * u; // B of floats random numbers sampled from the uniform distribution [0,1] 44 | 45 | uint64_t * binary_code; // (B / 64) * N of 64-bit uint64_t 46 | uint8_t * packed_code; // packed code with the batch size of 32 vectors 47 | 48 | 49 | float * x0; // N of floats in the Random Net algorithm 50 | float * centroid; // N * B floats (not N * D), note that the centroids should be randomized 51 | float * data; // N * D floats, note that the datas are not randomized 52 | 53 | IVFRN(); 54 | IVFRN(const Matrix &X, const Matrix &_centroids, const Matrix &dist_to_centroid, 55 | const Matrix &_x0, const Matrix &cluster_id, const Matrix &binary); 56 | ~IVFRN(); 57 | 58 | ResultHeap search(float* query, float* rd_query, uint32_t k, uint32_t nprobe, float distK = std::numeric_limits::max()) const; 59 | 60 | static void scan(ResultHeap &KNNs, float &distK, uint32_t k, \ 61 | uint64_t *quant_query, uint64_t *ptr_binary_code, uint32_t len, Factor *ptr_fac, \ 62 | const float sqr_y, const float vl, const float width, const float sumq,\ 63 | float *query, float *data, uint32_t *id); 64 | 65 | static void fast_scan(ResultHeap &KNNs, float &distK, uint32_t k, \ 66 | uint8_t *LUT, uint8_t *packed_code, uint32_t len, Factor *ptr_fac, \ 67 | const float sqr_y, const float vl, const float width, const float sumq,\ 68 | float *query, float *data, uint32_t *id); 69 | 70 | void save(char* filename); 71 | void load(char* filename); 72 | }; 73 | 74 | // scan impl 75 | template 76 | void IVFRN::scan(ResultHeap &KNNs, float &distK, uint32_t k, \ 77 | uint64_t *quant_query, uint64_t *ptr_binary_code, uint32_t len, Factor 
*ptr_fac, \ 78 | const float sqr_y, const float vl, const float width, const float sumq, \ 79 | float *query, float *data, uint32_t *id){ 80 | 81 | constexpr int SIZE = 32; 82 | float y = std::sqrt(sqr_y); 83 | float res[SIZE]; 84 | float *ptr_res = &res[0]; 85 | int it = len / SIZE; 86 | 87 | for(int i=0;i sqr_x) + sqr_y + ptr_fac -> factor_ppc * vl + (space.ip_byte_bin(quant_query, ptr_binary_code) * 2 -sumq) * (ptr_fac -> factor_ip) * width; 91 | float error_bound = y * (ptr_fac -> error); 92 | *ptr_res = tmp_dist - error_bound; 93 | ptr_binary_code += B / 64; 94 | ptr_fac ++; 95 | ptr_res ++; 96 | } 97 | 98 | ptr_res = &res[0]; 99 | for(int j=0;j(query, data); 103 | if(gt_dist < distK){ 104 | KNNs.emplace(gt_dist, *id); 105 | if(KNNs.size() > k) KNNs.pop(); 106 | if(KNNs.size() == k)distK = KNNs.top().first; 107 | } 108 | } 109 | data += D; 110 | ptr_res++; 111 | id++; 112 | } 113 | } 114 | 115 | ptr_res = &res[0]; 116 | for(int i=it * SIZE;i sqr_x) + sqr_y + ptr_fac -> factor_ppc * vl + (space.ip_byte_bin(quant_query, ptr_binary_code) * 2 -sumq) * (ptr_fac -> factor_ip) * width; 118 | float error_bound = y * (ptr_fac -> error); 119 | *ptr_res = tmp_dist - error_bound; 120 | ptr_binary_code += B / 64; 121 | ptr_fac ++; 122 | ptr_res ++; 123 | } 124 | 125 | ptr_res = &res[0]; 126 | for(int i=it * SIZE;i(query, data); 129 | if(gt_dist < distK){ 130 | KNNs.emplace(gt_dist, *id); 131 | if(KNNs.size() > k) KNNs.pop(); 132 | if(KNNs.size() == k)distK = KNNs.top().first; 133 | } 134 | } 135 | data += D; 136 | ptr_res++; 137 | id++; 138 | } 139 | } 140 | 141 | template 142 | void IVFRN::fast_scan(ResultHeap &KNNs, float &distK, uint32_t k, \ 143 | uint8_t *LUT, uint8_t *packed_code, uint32_t len, Factor *ptr_fac, \ 144 | const float sqr_y, const float vl, const float width, const float sumq, \ 145 | float *query, float *data, uint32_t *id){ 146 | 147 | for(int i=0;i((SIZE / 32), packed_code, LUT, result); 161 | packed_code += SIZE * B / 8; 162 | 163 | for(int i=0;i 
sqr_x) + sqr_y + ptr_fac -> factor_ppc * vl + (result[i]-sumq) * (ptr_fac -> factor_ip) * width; 165 | float error_bound = y * (ptr_fac -> error); 166 | *ptr_low_dist = tmp_dist - error_bound; 167 | ptr_fac ++; 168 | ptr_low_dist ++; 169 | } 170 | ptr_low_dist = &low_dist[0]; 171 | for(int j=0;j(query, data); 175 | // cerr << *ptr_low_dist << " " << gt_dist << endl; 176 | if(gt_dist < distK){ 177 | KNNs.emplace(gt_dist, *id); 178 | if(KNNs.size() > k) KNNs.pop(); 179 | if(KNNs.size() == k)distK = KNNs.top().first; 180 | } 181 | } 182 | data += D; 183 | ptr_low_dist++; 184 | id++; 185 | } 186 | } 187 | 188 | { 189 | float low_dist[SIZE]; 190 | float *ptr_low_dist = &low_dist[0]; 191 | uint16_t PORTABLE_ALIGN32 result[SIZE]; 192 | accumulate(nblk_remain, packed_code, LUT, result); 193 | 194 | for(int i=0;i sqr_x) + sqr_y + ptr_fac -> factor_ppc * vl + (result[i] - sumq) * ptr_fac -> factor_ip * width; 196 | float error_bound = y * (ptr_fac -> error); 197 | 198 | // *********************************************************************************************** 199 | *ptr_low_dist = tmp_dist - error_bound; 200 | ptr_fac ++; 201 | ptr_low_dist ++; 202 | } 203 | ptr_low_dist = &low_dist[0]; 204 | for(int i=0;i(query, data); 208 | if(gt_dist < distK){ 209 | KNNs.emplace(gt_dist, *id); 210 | if(KNNs.size() > k) KNNs.pop(); 211 | if(KNNs.size() == k)distK = KNNs.top().first; 212 | } 213 | } 214 | data += D; 215 | ptr_low_dist++; 216 | id++; 217 | } 218 | } 219 | } 220 | 221 | // search impl 222 | template 223 | ResultHeap IVFRN::search(float* query, float* rd_query, uint32_t k, uint32_t nprobe, float distK) const{ 224 | // The default value of distK is +inf 225 | ResultHeap KNNs; 226 | // =========================================================================================================== 227 | // Find out the nearest N_{probe} centroids to the query vector. 
228 | Result centroid_dist[numC]; 229 | float * ptr_c = centroid; 230 | for(int i=0;i(rd_query, ptr_c); 232 | centroid_dist[i].second = i; 233 | ptr_c += B; 234 | } 235 | std::partial_sort(centroid_dist, centroid_dist + nprobe, centroid_dist + numC); 236 | 237 | // =========================================================================================================== 238 | // Scan the first nprobe clusters. 239 | Result *ptr_centroid_dist = (¢roid_dist[0]); 240 | uint8_t PORTABLE_ALIGN64 byte_query[B]; 241 | 242 | for(int pb=0;pb second; 244 | float sqr_y = ptr_centroid_dist -> first; 245 | ptr_centroid_dist ++; 246 | 247 | // ======================================================================================================= 248 | // Preprocess the residual query and the quantized query 249 | float vl, vr; 250 | space.range(rd_query, centroid + c * B, vl, vr); 251 | float width = (vr - vl) / ((1 << B_QUERY) - 1); 252 | uint32_t sum_q = 0; 253 | space.quantize(byte_query, rd_query, centroid + c * B, u, vl, width, sum_q); 254 | 255 | #if defined(SCAN) // Binary String Representation 256 | uint64_t PORTABLE_ALIGN32 quant_query[B_QUERY * B / 64]; 257 | memset(quant_query, 0, sizeof(quant_query)); 258 | space.transpose_bin(byte_query, quant_query); 259 | #elif defined(FAST_SCAN) // Look-Up-Table Representation 260 | uint8_t PORTABLE_ALIGN32 LUT[B / 4 * 16]; 261 | pack_LUT(byte_query, LUT); 262 | #endif 263 | 264 | #if defined(SCAN) 265 | scan(KNNs, distK, k,\ 266 | quant_query, binary_code + 1ull * start[c] * (B / 64), len[c], fac + start[c], \ 267 | sqr_y, vl, width, sum_q,\ 268 | query, data + 1ull * start[c] * D, id + start[c]); 269 | #elif defined(FAST_SCAN) 270 | fast_scan(KNNs, distK, k, \ 271 | LUT, packed_code + packed_start[c], len[c], fac + start[c], \ 272 | sqr_y, vl, width, sum_q,\ 273 | query, data + start[c] * D, id + start[c]); 274 | #endif 275 | } 276 | return KNNs; 277 | } 278 | 279 | 280 | 281 | 282 | // 
============================================================================================================================== 283 | // Save and Load Functions 284 | template 285 | void IVFRN::save(char * filename){ 286 | std::ofstream output(filename, std::ios::binary); 287 | 288 | uint32_t d = D; 289 | uint32_t b = B; 290 | output.write((char *) &N, sizeof(uint32_t)); 291 | output.write((char *) &d, sizeof(uint32_t)); 292 | output.write((char *) &C, sizeof(uint32_t)); 293 | output.write((char *) &b, sizeof(uint32_t)); 294 | 295 | output.write((char *) start , C * sizeof(uint32_t)); 296 | output.write((char *) len , C * sizeof(uint32_t)); 297 | output.write((char *) id , N * sizeof(uint32_t)); 298 | output.write((char *) dist_to_c , N * sizeof(float)); 299 | output.write((char *) x0 , N * sizeof(float)); 300 | 301 | output.write((char *) centroid, C * B * sizeof(float)); 302 | output.write((char *) data, 1ull * N * D * sizeof(float)); 303 | output.write((char *) binary_code, 1ull * N * B / 64 * sizeof(uint64_t)); 304 | 305 | output.close(); 306 | std::cerr << "Saved!" 
<< std::endl; 307 | } 308 | 309 | // load impl 310 | template 311 | void IVFRN::load(char * filename){ 312 | std::ifstream input(filename, std::ios::binary); 313 | //std::cerr << filename << std::endl; 314 | 315 | if (!input.is_open()) 316 | throw std::runtime_error("Cannot open file"); 317 | 318 | uint32_t d; 319 | uint32_t b; 320 | input.read((char *) &N, sizeof(uint32_t)); 321 | input.read((char *) &d, sizeof(uint32_t)); 322 | input.read((char *) &C, sizeof(uint32_t)); 323 | input.read((char *) &b, sizeof(uint32_t)); 324 | 325 | std::cerr << d << std::endl; 326 | assert(d == D); 327 | assert(b == B); 328 | 329 | u = new float [B]; 330 | #if defined(RANDOM_QUERY_QUANTIZATION) 331 | std::random_device rd; 332 | std::mt19937 gen(rd()); 333 | std::uniform_real_distribution<> uniform(0.0, 1.0); 334 | for(int i=0;i(aligned_alloc(256, N * B / 64 * sizeof(uint64_t))); 343 | 344 | start = new uint32_t [C]; 345 | len = new uint32_t [C]; 346 | id = new uint32_t [N]; 347 | dist_to_c = new float [N]; 348 | x0 = new float [N]; 349 | 350 | fac = new Factor[N]; 351 | 352 | input.read((char *) start , C * sizeof(uint32_t)); 353 | input.read((char *) len , C * sizeof(uint32_t)); 354 | input.read((char *) id , N * sizeof(uint32_t)); 355 | input.read((char *) dist_to_c , N * sizeof(float)); 356 | input.read((char *) x0 , N * sizeof(float)); 357 | 358 | input.read((char *) centroid , C * B * sizeof(float)); 359 | input.read((char *) data, 1ull * N * D * sizeof(float)); 360 | input.read((char *) binary_code, 1ull * N * B / 64 * sizeof(uint64_t)); 361 | 362 | #if defined(FAST_SCAN) 363 | packed_start = new uint32_t [C]; 364 | int cur = 0; 365 | for(int i=0;i(aligned_alloc(32, cur * sizeof(uint8_t))); 370 | for(int i=0;i(binary_code + 1ull * start[i] * (B / 64), len[i], packed_code + 1ull * packed_start[i]); 372 | } 373 | #else 374 | packed_start = NULL; 375 | packed_code = NULL; 376 | #endif 377 | for(int i=0;i 391 | IVFRN::IVFRN(){ 392 | N = C = 0; 393 | start = len = id = NULL; 394 
| x0 = dist_to_c = centroid = data = NULL; 395 | binary_code = NULL; 396 | fac = NULL; 397 | u = NULL; 398 | } 399 | 400 | template 401 | IVFRN::IVFRN(const Matrix &X, const Matrix &_centroids, const Matrix &dist_to_centroid, 402 | const Matrix &_x0, const Matrix &cluster_id, const Matrix &binary){ 403 | fac=NULL; 404 | u = NULL; 405 | 406 | N = X.n; 407 | C = _centroids.n; 408 | 409 | // check uint64_t 410 | assert(B % 64 == 0); 411 | assert(B >= D); 412 | 413 | start = new uint32_t [C]; 414 | len = new uint32_t [C]; 415 | id = new uint32_t [N]; 416 | dist_to_c = new float [N]; 417 | x0 = new float [N]; 418 | 419 | memset(len, 0, C * sizeof(uint32_t)); 420 | for(int i=0;i 454 | IVFRN::~IVFRN(){ 455 | if(id != NULL) delete [] id; 456 | if(dist_to_c != NULL) delete [] dist_to_c; 457 | if(len != NULL) delete [] len; 458 | if(start != NULL) delete [] start; 459 | if(x0 != NULL) delete [] x0; 460 | if(data != NULL) delete [] data; 461 | if(fac != NULL) delete [] fac; 462 | if(u != NULL) delete [] u; 463 | if(binary_code != NULL) std::free(binary_code); 464 | // if(pack_codes != NULL) std::free(pack_codes); 465 | if(centroid != NULL) std::free(centroid); 466 | } 467 | --------------------------------------------------------------------------------