├── .gitmodules ├── LICENSE.txt ├── NOTICE.txt ├── README.md ├── benchmark ├── algorithms │ └── ann_elasticsearch.py ├── algos.yaml └── install │ └── Dockerfile.ann-elasticsearch ├── build.gradle ├── examples ├── image-search │ ├── README.md │ ├── images │ │ ├── 101502.jpg │ │ ├── 101503.jpg │ │ ├── 101504.jpg │ │ ├── 126200.jpg │ │ ├── 126201.jpg │ │ └── 143700.jpg │ ├── models.py │ ├── prepare.sh │ └── search_example.py ├── lib │ └── common.py └── question-answering │ ├── README.md │ ├── dataset.py │ ├── loss.py │ ├── models.py │ ├── prepare.sh │ ├── search_example.py │ └── train.py ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat └── src ├── main ├── java │ └── org │ │ └── elasticsearch │ │ ├── analysis │ │ ├── CodeAttribute.java │ │ ├── CodeAttributeImpl.java │ │ ├── IvfpqAnalyzer.java │ │ ├── IvfpqAnalyzerProvider.java │ │ └── IvfpqTokenizer.java │ │ ├── ann │ │ ├── AlgebraicOps.java │ │ ├── ArrayUtils.java │ │ ├── ExactSearch.java │ │ └── ProductQuantizer.java │ │ ├── mapper │ │ └── IvfpqFieldMapper.java │ │ ├── plugin │ │ └── AnnPlugin.java │ │ └── search │ │ ├── IvfpqQuery.java │ │ └── IvfpqQueryBuilder.java └── plugin-metadata │ └── plugin-security.policy └── test └── java └── org └── elasticsearch ├── ann ├── ExactSearchTests.java └── ProductQuantizerTests.java └── plugin └── AnnPluginIT.java /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "examples/question-answering/dataset"] 2 | path = examples/question-answering/dataset 3 | url = https://github.com/shuzi/insuranceQA.git 4 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Source code in this repository is variously licensed under the Apache License 2 | Version 2.0, an Apache compatible license, or the Elastic License. Outside of 3 | the "x-pack" folder, source code in a given file is licensed under the Apache 4 | License Version 2.0, unless otherwise noted at the beginning of the file or a 5 | LICENSE file present in the directory subtree declares a separate license. 6 | Within the "x-pack" folder, source code in a given file is licensed under the 7 | Elastic License, unless otherwise noted at the beginning of the file or a 8 | LICENSE file present in the directory subtree declares a separate license. 9 | 10 | The build produces two sets of binaries - one set that falls under the Elastic 11 | License and another set that falls under Apache License Version 2.0. The 12 | binaries that contain `-oss` in the artifact name are licensed under the Apache 13 | License Version 2.0. 14 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/NOTICE.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Elasticsearch Approximate Nearest Neighbor plugin 2 | 3 | This repository provides a product-quantization-based approximate nearest neighbor Elasticsearch plugin for searching high-dimensional dense vectors. 4 | 5 | It can be used for various purposes with neural network frameworks (such as TensorFlow or PyTorch). 6 | For example: 7 | 1. Similar image search 8 | 2. Question answering 9 | 3. Recommendation or learning to rank 10 | 11 | (See [examples](./examples)) 12 | 13 | The product quantization implementation is based on the paper "Product quantization for nearest neighbor search" by Herve Jegou, Matthijs Douze, and Cordelia Schmid. 14 | 15 | ## Installation 16 | 17 | See the full list of prebuilt versions. If you don't see a version available, see the link below for building or file a request via issues. 18 | 19 | To install, you'd run a command such as: 20 | 21 | ./bin/elasticsearch-plugin install http://${release_zip_file} 22 | 23 | After installation, you need to restart the running Elasticsearch node. 24 | 25 |
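To check that the plugin is actually loaded, you can list the installed plugins; the response should include `ann`, the plugin name declared in `build.gradle`:

    curl localhost:9200/_cat/plugins
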
26 | ## Usage 27 | 28 | This plugin provides a custom mapper type, an analyzer, and a search query. 29 | 30 | ### Create mapping 31 | 32 | Before you create a mapping, the product quantizer parameters have to be trained. 33 | See the [example](./examples/lib/common.py#L4-L10) code. 34 |
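As a minimal sketch of that training step, the parameters can be fitted with [faiss](https://github.com/facebookresearch/faiss), mirroring the linked example; `train_vectors` is assumed to be your own `(n, d)` float32 numpy array, and the resulting values plug into the `ivfpq_analyzer` settings shown below:

```python
import faiss

d, nlist, m, nbits = 256, 8, 128, 6  # nbits = 6 gives ksub = 2**6 = 64
quantizer = faiss.IndexFlatL2(d)     # coarse quantizer holding the IVF centroids
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
index.train(train_vectors)           # train_vectors: (n, d) float32 numpy array

# flatten the trained centroids into the comma-separated strings
# expected by the analyzer settings
coarse_centroids = ','.join(str(quantizer.xb.at(i)) for i in range(quantizer.xb.size()))
pq_centroids = ','.join(str(index.pq.centroids.at(i)) for i in range(index.pq.centroids.size()))
```
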
35 | PUT sample_images 36 | 37 | ```json 38 | { 39 | "settings": { 40 | "analysis": { 41 | "analyzer": { 42 | "image_analyzer": { 43 | "type": "ivfpq_analyzer", 44 | "d": 256, 45 | "m": 128, 46 | "ksub": 64, 47 | "coarseCentroids": "0.007750325836241245,0.0010391526157036424,...,0.031184080988168716", 48 | "pqCentroids": "0.00041024317033588886,0.022187601774930954,...,0.001461795181967318" 49 | } 50 | } 51 | } 52 | }, 53 | "mappings": { 54 | "vector": { 55 | "_source": { 56 | "excludes": [ 57 | "feature" 58 | ] 59 | }, 60 | "properties": { 61 | "name": { 62 | "type": "keyword" 63 | }, 64 | "feature": { 65 | "type": "ivfpq", 66 | "analyzer": "image_analyzer" 67 | } 68 | } 69 | } 70 | } 71 | } 72 | ``` 73 | 74 | 75 | ### Index vector data 76 | 77 | The following example adds or updates vector data in a specific index, making it searchable. 78 | 79 | POST sample_images/image/1 80 | 81 | ```json 82 | { 83 | "name": "1.jpg", 84 | "feature": "0.018046028912067413,0.0010425627697259188,...,0.0012223172234371305" 85 | } 86 | ``` 87 | 88 | ### Search similar vectors 89 | 90 | The `ivfpq_query` in the search request body can be combined with other Elasticsearch queries. Because scores are distances, smaller means more similar, which is why the example sorts `_score` in ascending order. 91 | 92 | POST sample_images/_search 93 | 94 | ```json 95 | { 96 | "query": { 97 | "ivfpq_query": { 98 | "query": "0.02125333994626999,0.000217707478441298,...,0.001304438104853034", 99 | "fields": ["feature"] 100 | } 101 | }, 102 | "sort": [ 103 | {"_score": {"order": "asc"}} 104 | ] 105 | } 106 | ``` 107 | 108 | ## Development 109 | 110 | If you need the plugin for an Elasticsearch version that has no prebuilt release, you can build it yourself as follows. 111 | 112 | ### 1. Build with Gradle Wrapper 113 | 114 | ./gradlew build 115 | 116 | This command generates an Elasticsearch plugin zip file. 117 | 118 | 119 | ### 2. Install with ./bin/elasticsearch-plugin 120 | 121 | ./bin/elasticsearch-plugin install file:///${path_to_generated_zip_file} 122 | -------------------------------------------------------------------------------- /benchmark/algorithms/ann_elasticsearch.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import subprocess 4 | import sys 5 | import time 6 | 7 | import elasticsearch 8 | import elasticsearch.helpers 9 | import sympy 10 | from tqdm import tqdm 11 | 12 | sys.path.append("install/lib-faiss") 13 | import numpy 14 | import sklearn.preprocessing 15 | import faiss 16 | from ann_benchmarks.algorithms.base import BaseANN 17 | 18 | 19 | class ANNElasticsearch(BaseANN): 20 | 21 | def __init__(self, metric, n_list): 22 | self.n_list = n_list 23 | self.n_probe = None 24 | self.metric = metric 25 | self.run_server() 26 | self.client = elasticsearch.Elasticsearch() 27 | 28 | def query(self, v, n): 29 | if self.metric == 'angular': 30 | v /= numpy.linalg.norm(v) 31 | response = self.client.search('vectors', body={ 32 | 'query': { 33 | 'ivfpq_query': { 34 | 'query': ','.join(map(str, v)), 35 | 'fields': ['vector'], 36 | 'nprobe': self.n_probe 37 | } 38 | }, 39 | 'sort': [ 40 | {'_score': {'order': 'asc'}}, 41 | ], 42 | 'size': n 43 | }) 44 | return [int(hit['_id']) for hit in response['hits']['hits']] 45 | 46 | def fit(self, X): 47 | if self.metric == 'angular': 48 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 49 | if X.dtype != numpy.float32: 50 | X = X.astype(numpy.float32) 51 | d = X.shape[1] 52 | m = self.get_subvector_size(d)  # number of subquantizers for the PQ 53 | quantizer = faiss.IndexFlatL2(d) 54 | index = faiss.IndexIVFPQ(quantizer, d, self.n_list, m, 8) 55 | index.train(X) 56 | coarse_centroids = [quantizer.xb.at(i) for i in range(quantizer.xb.size())] 57 | pq_centroids = [index.pq.centroids.at(i) for i in range(index.pq.centroids.size())] 58 | self.create_mapping(d, m, index.pq.ksub, coarse_centroids, pq_centroids) 59 | self.add_vectors(X) 60 | 61 | def set_query_arguments(self, n_probe): 62 | self.n_probe = n_probe 63 | 64 | def run_server(self): 65 | subprocess.call(['bash', './elasticsearch/bin/start.sh']) 66 | time.sleep(10)  # give the Elasticsearch daemon time to start 67 | 68 | def get_subvector_size(self, d, k=64): 69 | dsub = 1  # grow a divisor of d from its prime factors until it reaches k 70 | factors = sympy.ntheory.factorint(d) 71 | for p, n in sorted(factors.items(), key=lambda x: -x[0]): 72 | for i in range(n): 73 | dsub *= p 74 | if dsub >= k: 75 | return d // dsub  # stop here: a bare break would keep multiplying the remaining prime factors 76 | return d // dsub 77 | 78 | def create_mapping(self, d, m, ksub, coarse_centroids, pq_centroids): 79 | self.client.indices.create('vectors', { 80 | 'settings': { 81 | 'analysis': { 82 | 'analyzer': { 83 | 'ann_analyzer': { 84 | 'type': 'ivfpq_analyzer', 85 | 'd': d, 86 | 'm': m, 87 | 'ksub': ksub, 88 | 'coarseCentroids': ','.join(map(str, coarse_centroids)), 89 | 'pqCentroids': ','.join(map(str, pq_centroids)) 90 | } 91 | } 92 | } 93 | }, 94 | 'mappings': { 95 | 'vector': { 96 | '_source': { 97 | 'enabled': False 98 | }, 99 | 'properties': { 100 | 'vector': { 101 | 'type': 'ivfpq', 102 | 'analyzer': 'ann_analyzer' 103 | } 104 | } 105 | } 106 | } 107 | }) 108 | 109 | def add_vectors(self, vectors, batch_size=1024): 110 | for i in tqdm(range(0, vectors.shape[0], batch_size)): 111 | actions = [{ 112 | '_index': 'vectors', 113 | '_type': 'vector', 114 | '_id': i + j, 115 | '_source': { 116 | 'vector': ','.join(map(str, vector)) 117 | } 118 | } for j, vector in enumerate(vectors[i:i+batch_size])] 119 | elasticsearch.helpers.bulk(self.client, actions) 120 |
print('create index: finish:', self.client.count(index='vectors', doc_type='vector')) 121 | 122 | def __str__(self): 123 | return 'ANN-Elasticsearch(n_list=%d, n_probe=%d)' % (self.n_list, self.n_probe) 124 | -------------------------------------------------------------------------------- /benchmark/algos.yaml: -------------------------------------------------------------------------------- 1 | float: 2 | any: 3 | ann-elasticsearch: 4 | docker-tag: ann-benchmarks-ann-elasticsearch 5 | module: ann_benchmarks.algorithms.ann_elasticsearch 6 | constructor: ANNElasticsearch 7 | base-args: ["@metric"] 8 | run-groups: 9 | base: 10 | args: [[32,64,128,256,512,1024,2048,4096]] 11 | query-args: [[1, 5, 10, 50, 100, 200]] 12 | -------------------------------------------------------------------------------- /benchmark/install/Dockerfile.ann-elasticsearch: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks-faiss 2 | 3 | ENV WORK_DIR=/home/app 4 | 5 | # install dependent packages 6 | RUN apt install -y curl 7 | 8 | # install openjdk11 9 | RUN curl -s -O https://download.java.net/java/GA/jdk11/13/GPL/openjdk-11.0.1_linux-x64_bin.tar.gz && \ 10 | tar -zxvf openjdk-11.0.1_linux-x64_bin.tar.gz && \ 11 | rm openjdk-11.0.1_linux-x64_bin.tar.gz 12 | ENV JAVA_HOME=${WORK_DIR}/jdk-11.0.1 13 | 14 | # install elasticsearch 15 | ENV VERSION=6.5.4 16 | RUN curl -s -O https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-6.5.4.tar.gz && \ 17 | tar -zxvf elasticsearch-${VERSION}.tar.gz && \ 18 | rm elasticsearch-${VERSION}.tar.gz && \ 19 | mv elasticsearch-${VERSION} elasticsearch 20 | RUN useradd elasticsearch 21 | RUN echo 'su elasticsearch -c "/home/app/elasticsearch/bin/elasticsearch -d"' > elasticsearch/bin/start.sh 22 | RUN chown -R elasticsearch.elasticsearch elasticsearch 23 | 24 | # install elasticsearch-ann plugin 25 | RUN git clone https://github.com/rixwew/elasticsearch-approximate-nearest-neighbor.git 26 | RUN cd elasticsearch-approximate-nearest-neighbor && \ 27 | ./gradlew bundlePlugin && cd .. 
&& \ 28 | ./elasticsearch/bin/elasticsearch-plugin install -b \ 29 | file:///${WORK_DIR}/elasticsearch-approximate-nearest-neighbor/build/distributions/ann-${VERSION}-1.0.zip && \ 30 | rm -rf elasticsearch-approximate-nearest-neighbor 31 | 32 | # install dependent packages 33 | RUN pip install sympy elasticsearch tqdm 34 | 35 | ENTRYPOINT ["python", "run_algorithm.py"] 36 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | ext { 3 | elasticsearchVersion = "6.5.4" 4 | } 5 | repositories { 6 | mavenCentral() 7 | jcenter() 8 | } 9 | dependencies { 10 | classpath "org.elasticsearch.gradle:build-tools:$elasticsearchVersion" 11 | } 12 | } 13 | 14 | group = 'org.elasticsearch' 15 | version = "$elasticsearchVersion-1.0" 16 | 17 | apply plugin: 'java' 18 | apply plugin: 'idea' 19 | apply plugin: 'elasticsearch.esplugin' 20 | 21 | licenseFile = rootProject.file('LICENSE.txt') 22 | noticeFile = rootProject.file('NOTICE.txt') 23 | esplugin { 24 | name 'ann' 25 | description 'Approximate Nearest Neighbor Elasticsearch Plugin' 26 | classname 'org.elasticsearch.plugin.AnnPlugin' 27 | licenseFile rootProject.file('LICENSE.txt') 28 | noticeFile rootProject.file('NOTICE.txt') 29 | } 30 | 31 | -------------------------------------------------------------------------------- /examples/image-search/README.md: -------------------------------------------------------------------------------- 1 | # Similar Image Search Example 2 | 3 | ## Requirements 4 | 5 | * pytorch 1.0 6 | * torchvision 7 | * numpy 8 | * cv2 9 | * elasticsearch 10 | 11 | ## Download INRIA Holidays dataset 12 | 13 | ```bash 14 | bash prepare.sh 15 | ``` 16 | 17 | ## Search similar images using elasticsearch plugin 18 | 19 | ```bash 20 | export PYTHONPATH=$PATH_TO_SCRIPT_DIR/lib:$PYTHONPATH 21 | python search_example.py --query "./images/143700.jpg" \ 22 | --result_size 5 23 | ``` 24 | 25 | ```json 26 | { 27 | "took": 8, 28 | "timed_out": false, 29 | "_shards": { 30 | "total": 5, 31 | "successful": 5, 32 | "skipped": 0, 33 | "failed": 0 34 | }, 35 | "hits": { 36 | "total": 287, 37 | "max_score": null, 38 | "hits": [ 39 | { 40 | "_index": "images", 41 | "_type": "image", 42 | "_id": "plcrb2gBaJEWlukYw7kH", 43 | "_score": 0.24182819, 44 | "_source": { 45 | "description": "dataset/jpg/126201.jpg" 46 | }, 47 | "sort": [ 48 | 0.24182819 49 | ] 50 | }, 51 | { 52 | "_index": "images", 53 | "_type": "image", 54 | "_id": "v1crb2gBaJEWlukYhbhJ", 55 | "_score": 0.2570988, 56 | "_source": { 57 | "description": "dataset/jpg/126200.jpg" 58 | }, 59 | "sort": [ 60 | 0.2570988 61 | ] 62 | }, 63 | { 64 | "_index": "images", 65 | "_type": "image", 66 | "_id": "uVcsb2gBaJEWlukYD7pf", 67 | "_score": 0.25885358, 68 | "_source": { 69 | "description": "dataset/jpg/101504.jpg" 70 | }, 71 | "sort": [ 72 | 0.25885358 73 | ] 74 | }, 75 | { 76 | "_index": "images", 77 | "_type": "image", 78 | "_id": "e1crb2gBaJEWlukYt7mo", 79 | "_score": 0.27128482, 80 | "_source": { 81 | "description": "dataset/jpg/101502.jpg" 82 | }, 83 | "sort": [ 84 | 0.27128482 85 | ] 86 | }, 87 | { 88 | "_index": "images", 89 | "_type": "image", 90 | "_id": "WVcrb2gBaJEWlukYrrlh", 91 | "_score": 0.28078148, 92 | "_source": { 93 | "description": "dataset/jpg/101503.jpg" 94 | }, 95 | "sort": [ 96 | 0.28078148 97 | ] 98 | } 99 | ] 100 | } 101 | } 102 | ``` 103 | 104 | ### Sample Image Query 105 | 106 | <img src="./images/143700.jpg"> 107 | 108 | ### Search Result 109 | 110 | 111 | <img src="./images/126201.jpg"> 112 | <img src="./images/126200.jpg"> 113 | <img src="./images/101504.jpg"> 114 | <img src="./images/101502.jpg"> 115 | <img src="./images/101503.jpg"> 116 | 117 |
118 | -------------------------------------------------------------------------------- /examples/image-search/images/101502.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/101502.jpg -------------------------------------------------------------------------------- /examples/image-search/images/101503.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/101503.jpg -------------------------------------------------------------------------------- /examples/image-search/images/101504.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/101504.jpg -------------------------------------------------------------------------------- /examples/image-search/images/126200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/126200.jpg -------------------------------------------------------------------------------- /examples/image-search/images/126201.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/126201.jpg -------------------------------------------------------------------------------- /examples/image-search/images/143700.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/examples/image-search/images/143700.jpg -------------------------------------------------------------------------------- /examples/image-search/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | class ImageEncoder(torch.nn.Module): 6 | 7 | def __init__(self): 8 | super(ImageEncoder, self).__init__() 9 | model = torchvision.models.alexnet(True) 10 | model.classifier = torch.nn.Sequential( 11 | *list(model.classifier.children())[:-1]) 12 | self.encoder = model 13 | 14 | def forward(self, x): 15 | x = self.encoder(x) 16 | norm = x.norm(p=2, dim=1, keepdim=True) 17 | x = x.div(norm.expand_as(x)) 18 | return x 19 | -------------------------------------------------------------------------------- /examples/image-search/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download INRIA Holidays dataset 4 | mkdir -p dataset && cd dataset 5 | curl -O ftp://ftp.inrialpes.fr/pub/lear/douze/data/jpg1.tar.gz 6 | tar -xzvf jpg1.tar.gz 7 | -------------------------------------------------------------------------------- /examples/image-search/search_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import cv2 5 | 
import elasticsearch 6 | import numpy as np 7 | import torch 8 | from common import SearchClient, fit_pq_params 9 | 10 | from models import ImageEncoder 11 | 12 | 13 | def get_features_batch(encoder, images, use_cuda): 14 | with torch.no_grad(): 15 | x = torch.FloatTensor(images).transpose(1, 3) / 255 16 | if use_cuda: 17 | x = x.cuda() 18 | return encoder(x).cpu().numpy() 19 | 20 | 21 | def get_features(image_encoder, image_path_iter, use_cuda, batch_size=64): 22 | feats, names, images = list(), list(), list() 23 | for image_path in image_path_iter: 24 | name = str(image_path) 25 | names.append(name) 26 | image = cv2.imread(name) 27 | image = cv2.resize(image, (224, 224)) 28 | images.append(image) 29 | if len(images) == batch_size: 30 | feats.append(get_features_batch(image_encoder, images, use_cuda)) 31 | images = list() 32 | if len(images) > 0: 33 | feats.append(get_features_batch(image_encoder, images, use_cuda)) 34 | return names, np.concatenate(feats, axis=0) 35 | 36 | 37 | def main(query, result_size, dataset_path, nlist, m, use_cuda): 38 | image_encoder = ImageEncoder().eval() 39 | if use_cuda: 40 | image_encoder = image_encoder.cuda() 41 | es = elasticsearch.Elasticsearch() 42 | client = SearchClient(es, index_name='images', type_name='image') 43 | names, feats = get_features(image_encoder, Path(dataset_path).iterdir(), use_cuda) 44 | coarse_centroids, pq_centroids, ksub, dsub = fit_pq_params(feats, feats.shape[1], nlist, m) 45 | client.create_mapping(feats.shape[1], m, ksub, coarse_centroids, pq_centroids) 46 | client.add_vectors(names, feats) 47 | _, encoded_query = get_features(image_encoder, [query], use_cuda, batch_size=1) 48 | result = client.query(encoded_query[0], result_size) 49 | print(json.dumps(result, indent=2)) 50 | 51 | 52 | if __name__ == '__main__': 53 | import argparse 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--query', required=True) 57 | parser.add_argument('--result_size', type=int, default=5) 58 | parser.add_argument('--dataset', default='dataset/jpg') 59 | parser.add_argument('--nlist', type=int, default=8) 60 | parser.add_argument('--m', type=int, default=64) 61 | parser.add_argument('--use_cuda', type=lambda x: str(x).lower() not in ('false', '0'), default=True)  # type=bool would treat the string "False" as True 62 | args = parser.parse_args() 63 | 64 | main(query=args.query, 65 | result_size=args.result_size, 66 | dataset_path=args.dataset, 67 | nlist=args.nlist, 68 | m=args.m, 69 | use_cuda=args.use_cuda) 70 | -------------------------------------------------------------------------------- /examples/lib/common.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | 3 | 4 | def fit_pq_params(xb, d, nlist, m): 5 | quantizer = faiss.IndexFlatL2(d) 6 | index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 4) 7 | index.train(xb) 8 | coarse_centroids = [quantizer.xb.at(i) for i in range(quantizer.xb.size())] 9 | pq_centroids = [index.pq.centroids.at(i) for i in range(index.pq.centroids.size())] 10 | return coarse_centroids, pq_centroids, index.pq.ksub, index.pq.dsub 11 | 12 | 13 | class SearchClient(object): 14 | 15 | def __init__(self, client, index_name, type_name): 16 | self.client = client 17 | self.index_name = index_name 18 | self.type_name = type_name 19 | 20 | def add_vectors(self, descriptions, feats): 21 | for i in range(len(descriptions)): 22 | description = descriptions[i] 23 | feat = feats[i].tolist() 24 | doc = { 25 | 'description': description, 26 | 'vector': ','.join(map(str, feat)) 27 | } 28 | self.client.index(index=self.index_name, 29 | doc_type=self.type_name, 30 | body=doc) 31 | 32 | def query(self, feat, result_size=10): 33 | query = { 34 | 'query': { 35 | 'ivfpq_query': { 36 | 'query': ','.join(map(str, feat)), 37 | 'fields': ['vector'] 38 | } 39 | }, 40 | 'sort': {'_score': {'order': 'asc'}}, 41 | 'size': result_size 42 | } 43 | response = self.client.search(self.index_name, body=query) 44 | return response 45 | 46 | def create_mapping(self, d, m, ksub, coarse_centroids, pq_centroids): 47 | body = { 48 | 'settings': { 49 | 'analysis': { 50 | 'analyzer': { 51 | 'ann_analyzer': { 52 | 'type': 'ivfpq_analyzer', 53 | 'd': d, 54 | 'm': m, 55 | 'ksub': ksub, 56 | 'coarseCentroids': ','.join(map(str, coarse_centroids)), 57 | 'pqCentroids': ','.join(map(str, pq_centroids)) 58 | } 59 | } 60 | } 61 | }, 62 | 'mappings': { 63 | self.type_name: { 64 | '_source': { 65 | 'excludes': [ 66 | 'vector' 67 | ] 68 | }, 69 | 'properties': { 70 | 'description': { 71 | 'type': 'keyword' 72 | }, 73 | 'vector': { 74 | 'type': 'ivfpq', 75 | 'analyzer': 'ann_analyzer' 76 | } 77 | } 78 | } 79 | } 80 | } 81 | self.client.indices.create(self.index_name, body) 82 |
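# A minimal usage sketch (not part of the original module), assuming a running
# Elasticsearch node with the ann plugin installed, `names` as a list of
# descriptions and `feats` as an (n, d) float32 numpy array; the values mirror
# examples/image-search/search_example.py:
#
#     import elasticsearch
#     es = elasticsearch.Elasticsearch()
#     client = SearchClient(es, index_name='images', type_name='image')
#     coarse, pq, ksub, dsub = fit_pq_params(feats, feats.shape[1], nlist=8, m=64)
#     client.create_mapping(feats.shape[1], 64, ksub, coarse, pq)
#     client.add_vectors(names, feats)
#     response = client.query(feats[0], result_size=5)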
-------------------------------------------------------------------------------- /examples/question-answering/README.md: -------------------------------------------------------------------------------- 1 | # Question answering Example 2 | 3 | The question answering implementation is based on the paper "LSTM-based Deep Learning Models 4 | for Non-factoid Answer Selection" by Tan, dos Santos, Xiang, and Zhou. 5 | 6 | ## Requirements 7 | 8 | * pytorch 1.0 9 | * numpy 10 | * gensim 11 | * elasticsearch 12 | 13 | ## Download insurance qa data and train model 14 | 15 | ```bash 16 | bash prepare.sh 17 | python train.py 18 | ``` 19 | 20 | InsuranceQA Version 1 top-1 precision results 21 | 22 | | Model | Validation | Test1 | Test2 | 23 | |:---------------------------------|-----------:|------:|------:| 24 | | QA-LSTM basic-model, max pooling (100 epochs) | 62.2 | 63.8 | 58.8 | 25 | | QA-LSTM basic-model, max pooling (paper) | 64.3 | 63.1 | 58.0 | 26 | 27 | 28 | ## Search answers using elasticsearch plugin 29 | 30 | ```bash 31 | export PYTHONPATH=$PATH_TO_SCRIPT_DIR/lib:$PYTHONPATH 32 | python search_example.py --question "Can a Non us citizen get Life Insurance" \ 33 | --result_size 5 34 | ``` 35 | 36 | ```json 37 | { 38 | "took": 36, 39 | "timed_out": false, 40 | "_shards": { 41 | "total": 5, 42 | "successful": 5, 43 | "skipped": 0, 44 | "failed": 0 45 | }, 46 | "hits": { 47 | "total": 870, 48 | "max_score": null, 49 | "hits": [ 50 | { 51 | "_index": "answers", 52 | "_type": "answer", 53 | "_id": "o1i1f2gBaJEWlukYG7sK", 54 | "_score": 0.5443098, 55 | "_source": { 56 | "description": "a non citizen can get life insurance with most company if they have a green card or an H-1b work visa some company do require the applicant be a US citizen before allow them get a life insurance policy and some will only allow green card but not work visa contact an agent find out which company will work for your situation" 57 | }, 58 | "sort": [ 59 | 0.5443098 60 | ] 61 | }, 62 | { 63 | "_index": "answers", 64 | "_type": "answer", 65 | "_id": "81i2f2gBaJEWlukYacAb", 66 | "_score": 0.7198508, 67 | "_source": { 68 | "description": "yes there be absolutely no requirement a person be a citizen buy life insurance each company make its own decision on requirement but citizenship be not 1 them so long as you be in the country legally you can buy life insurance different ID be require different carrier but rest assure if your age and health
warrant it you can buy life insurance on yourself here in the USA love help thank you Gary Lane" 69 | }, 70 | "sort": [ 71 | 0.7198508 72 | ] 73 | }, 74 | { 75 | "_index": "answers", 76 | "_type": "answer", 77 | "_id": "0Fiyf2gBaJEWlukYdLDC", 78 | "_score": 0.75013983, 79 | "_source": { 80 | "description": "you do not have be a citizen obtain life insurance US life insurer require the propose insured must be a permanent resident of the US that mean a US citizen or a non US citizen who be a lawful permanent US resident ( green card or on certain visa type the applicant will also need have the means pay premium and have a demonstrable life insurance need i.e. generate earn income or asset protect here some insurer have develop foreign national program that can also work in situation where established US interest and tie exist plus meet some additional criterion citizen of some country may not be eligible it can be a complex area of field underwriting so much so that our firm have develop a special questionnaire help shop for coverage be sure work with a life insurance professional with experience in this area" 81 | }, 82 | "sort": [ 83 | 0.75013983 84 | ] 85 | }, 86 | { 87 | "_index": "answers", 88 | "_type": "answer", 89 | "_id": "9Fi2f2gBaJEWlukYacBr", 90 | "_score": 0.75358534, 91 | "_source": { 92 | "description": "yes a non US citizen can get life insurance with many American company it be up to the discretion of each company as to what type of citizenship or residency they will accept a green card be usually ok and many company will accept a work visa as qualification for apply for life insurance in the US get life insurance in the US as a non US citizen however almost always require have a residence in the United States" 93 | }, 94 | "sort": [ 95 | 0.75358534 96 | ] 97 | }, 98 | { 99 | "_index": "answers", 100 | "_type": "answer", 101 | "_id": "Nliwf2gBaJEWlukYIKcx", 102 | "_score": 0.7816857, 103 | "_source": { 104 | "description": "almost anyone can get life insurance the only people who can not get life insurance those who have serious health problem who fall outside the age guideline guarantee issue those who do not have any income at all even they may able to get a policy with a cap on the face amount in the us those who do not have citizenship a green card work visa" 105 | }, 106 | "sort": [ 107 | 0.7816857 108 | ] 109 | } 110 | ] 111 | } 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /examples/question-answering/dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy 4 | import torch.utils.data 5 | 6 | 7 | class Vocab(object): 8 | 9 | def __init__(self, vocab_path, lexicon, unk_surf='', thresh=5): 10 | self.vid2surf = dict() 11 | lexicon = {vocab_id for vocab_id, count in lexicon.items() if count >= thresh} 12 | with open(vocab_path, encoding='utf-8') as f: 13 | for _line in f: 14 | vocab_id, surf = _line.rstrip().split('\t') 15 | if vocab_id in lexicon: 16 | self.vid2surf[vocab_id] = surf 17 | self.vid2wid = {vocab_id: i + 1 for i, vocab_id in enumerate(self.vid2surf)} 18 | self.wid2surf = {self.vid2wid.get(vid): surf for vid, surf in self.vid2surf.items()} 19 | self.unk_surf = unk_surf 20 | self.unk_word_id = len(self.vid2wid) + 1 21 | 22 | def surfaces(self, vocab_ids): 23 | return [self.vid2surf.get(vocab_id, self.unk_surf) for vocab_id in vocab_ids] 24 | 25 | def word_ids(self, vocab_ids): 26 | return [self.vid2wid.get(vocab_id, 
self.unk_word_id) for vocab_id in vocab_ids] 27 | 28 | def __len__(self): 29 | return len(self.vid2surf) + 1 30 | 31 | 32 | class AnswerData(object): 33 | 34 | def __init__(self, data_path): 35 | self.answers = dict() 36 | self.lexicon = list() 37 | with open(data_path, encoding='utf-8') as f: 38 | for _line in f: 39 | answer_id, answer = _line.rstrip().split('\t') 40 | vocab_ids = answer.split(' ') 41 | self.answers[int(answer_id)] = vocab_ids 42 | self.lexicon.extend(vocab_ids) 43 | self.lexicon = collections.Counter(self.lexicon) 44 | 45 | 46 | class QaData(object): 47 | 48 | def __init__(self, data_path): 49 | self.questions = list() 50 | self.positive = list() 51 | self.negative = list() 52 | self.lexicon = list() 53 | with open(data_path, encoding='utf-8') as f: 54 | for _line in f: 55 | values = _line.rstrip().split('\t') 56 | if len(values) == 2: 57 | question, answer_ids = values 58 | positive_ids = list(map(int, answer_ids.split(' '))) 59 | negative_ids = list() 60 | elif len(values) == 3: 61 | answer_ids, question, pool = values 62 | positive_ids = list(map(int, answer_ids.split(' '))) 63 | negative_ids = list(filter(lambda x: x not in set(positive_ids), 64 | map(int, pool.split(' ')))) 65 | else: 66 | continue 67 | vocab_ids = question.split(' ') 68 | self.questions.append(vocab_ids) 69 | self.lexicon.extend(vocab_ids) 70 | self.positive.append(positive_ids) 71 | self.negative.append(negative_ids) 72 | self.lexicon = collections.Counter(self.lexicon) 73 | 74 | 75 | class InsuranceQaDataset(torch.utils.data.Dataset): 76 | 77 | def __init__(self, question_data, answer_data, vocab, max_length=200): 78 | self.vocab = vocab 79 | self.positive = question_data.positive 80 | self.negative = question_data.negative 81 | self.questions = list(map(self.vocab.word_ids, question_data.questions)) 82 | self.answer_map = dict() 83 | for answer_id, vids in answer_data.answers.items(): 84 | self.answer_map[answer_id] = self.vocab.word_ids(vids[:max_length]) 85 | self.answers = list(self.answer_map.values()) 86 | 87 | def __len__(self): 88 | return len(self.questions) 89 | 90 | def __getitem__(self, index): 91 | question, positive_ids, negative_ids = \ 92 | self.questions[index], self.positive[index], self.negative[index] 93 | positive = self.answer_map[positive_ids[numpy.random.randint(len(positive_ids))]] 94 | if len(negative_ids) > 0: 95 | negative = self.answer_map[negative_ids[numpy.random.randint(len(negative_ids))]] 96 | else: 97 | negative = self.answers[numpy.random.randint(len(self.answers))] 98 | return torch.LongTensor(question), \ 99 | torch.LongTensor(positive), \ 100 | torch.LongTensor(negative) 101 | 102 | def get_qa_entry(self, index): 103 | question, positive_ids, negative_ids = \ 104 | self.questions[index], self.positive[index], self.negative[index] 105 | positives = [self.answer_map[positive_id] for positive_id in positive_ids] 106 | negatives = [self.answer_map[negative_id] for negative_id in negative_ids] 107 | return question, positives, negatives 108 | 109 | @classmethod 110 | def collate(cls, batch): 111 | qs, ps, ns = zip(*batch) 112 | return torch.nn.utils.rnn.pad_sequence(qs, batch_first=True), \ 113 | torch.nn.utils.rnn.pad_sequence(ps, batch_first=True), \ 114 | torch.nn.utils.rnn.pad_sequence(ns, batch_first=True), 115 | -------------------------------------------------------------------------------- /examples/question-answering/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class 
QaLoss(torch.nn.Module): 5 | 6 | def __init__(self, margin): 7 | super().__init__() 8 | self.margin = margin 9 | 10 | def forward(self, question, positive, negative): 11 | """ 12 | max {0, margin - cosine(q, a+) + cosine(q, a-)} 13 | """ 14 | positive_sim = (question * positive).sum(1, keepdim=True)  # inputs are L2-normalized, so the dot product is cosine similarity 15 | negative_sim = (question * negative).sum(1, keepdim=True) 16 | zeros = positive_sim.data.new_zeros(*positive_sim.shape) 17 | loss = torch.cat((zeros, negative_sim - positive_sim + self.margin), dim=1) 18 | loss, _ = torch.max(loss, dim=1) 19 | return torch.mean(loss) 20 |
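# Worked example with hypothetical numbers: with margin = 0.2,
# cosine(q, a+) = 0.6 and cosine(q, a-) = 0.5 give
# max{0, 0.2 - 0.6 + 0.5} = 0.1, while cosine(q, a+) = 0.9 and
# cosine(q, a-) = 0.1 give max{0, 0.2 - 0.9 + 0.1} = 0, i.e. the loss
# vanishes once the positive answer beats the negative one by the margin.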
-------------------------------------------------------------------------------- /examples/question-answering/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 3 | 4 | 5 | class SentenceEncoder(torch.nn.Module): 6 | 7 | def __init__(self, embedding_weights, hidden_size): 8 | super().__init__() 9 | embedding_weights = torch.FloatTensor(embedding_weights) 10 | self.embedding = torch.nn.Embedding.from_pretrained(embedding_weights) 11 | self.rnn = torch.nn.LSTM(embedding_weights.shape[-1], hidden_size, 12 | batch_first=True, bidirectional=True) 13 | 14 | def forward(self, x): 15 | lengths = (-x.data.eq(0).long() + 1).sum(1)  # number of non-padding tokens per sequence 16 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 17 | _, idx_unsort = torch.sort(idx_sort, dim=0) 18 | x = x.index_select(0, idx_sort) 19 | lengths = lengths.index_select(0, idx_sort) 20 | x = self.embedding(x) 21 | x = pack_padded_sequence(x, lengths, batch_first=True) 22 | x, *_ = self.rnn(x) 23 | x, _ = pad_packed_sequence(x, batch_first=True, padding_value=float('-inf')) 24 | x, _ = torch.max(x, dim=1)  # max pooling over time steps 25 | norm = x.norm(p=2, dim=1, keepdim=True) 26 | x = x.div(norm) 27 | x = x.index_select(0, idx_unsort) 28 | return x 29 | -------------------------------------------------------------------------------- /examples/question-answering/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # download and unzip insurance qa dataset 4 | git submodule update --init --recursive 5 | 6 | # download pretrained word2vec model 7 | curl -O https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz 8 | gzip -d GoogleNews-vectors-negative300.bin.gz 9 | -------------------------------------------------------------------------------- /examples/question-answering/search_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import elasticsearch 4 | import numpy as np 5 | import torch 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | from dataset import AnswerData 9 | from common import SearchClient, fit_pq_params 10 | 11 | 12 | def get_vid_surf_mappers(vocab_data_path): 13 | vid2surf = dict() 14 | with open(vocab_data_path, encoding='utf-8') as f: 15 | for line in f: 16 | vid, surf = line.rstrip().split('\t') 17 | vid2surf[vid] = surf 18 | return vid2surf, {surf: vid for vid, surf in vid2surf.items()} 19 | 20 | 21 | def get_features(model, word_ids_list, use_cuda, batch_size=64): 22 | feats = list() 23 | for batch_i in range(0, len(word_ids_list), batch_size): 24 | batch = word_ids_list[batch_i:batch_i + batch_size] 25 | with torch.no_grad(): 26 | batch = pad_sequence([torch.LongTensor(x) for x in batch], batch_first=True) 27 | if use_cuda: 28 | batch = batch.cuda() 29 | feat = model(batch) 30 | norm = feat.norm(dim=1, keepdim=True) 31 | feat = feat.div(norm.expand_as(feat)) 32 | feats.append(feat.cpu().numpy()) 33 | return np.concatenate(feats, axis=0) 34 | 35 | 36 | def main(question, 37 | result_size, 38 | answer_data_path, 39 | vocab_data_path, 40 | model_path, 41 | nlist, 42 | m, 43 | max_length, 44 | use_cuda): 45 | # load model 46 | state = torch.load(model_path) 47 | model = state['model'].eval() 48 | if not use_cuda: 49 | model = model.cpu() 50 | else: 51 | model = model.cuda() 52 | vocab = state['vocab'] 53 | 54 | vid2surf, surf2vid = get_vid_surf_mappers(vocab_data_path) 55 | answer_data = AnswerData(answer_data_path) 56 | answer_surfaces = list() 57 | answer_wids = list() 58 | for answer_vids in answer_data.answers.values(): 59 | answer_surfaces.append(' '.join([vid2surf.get(vid) for vid in answer_vids])) 60 | answer_wids.append(vocab.word_ids(answer_vids[:max_length])) 61 | 62 | question_wids = vocab.word_ids([surf2vid.get(surf, '_unk') for surf in question.split(' ')]) 63 | es = elasticsearch.Elasticsearch() 64 | client = SearchClient(es, index_name='answers', type_name='answer') 65 | feats = get_features(model, answer_wids, use_cuda) 66 | coarse_centroids, pq_centroids, ksub, dsub = fit_pq_params(feats, feats.shape[1], nlist, m) 67 | client.create_mapping(feats.shape[1], m, ksub, coarse_centroids, pq_centroids) 68 | client.add_vectors(answer_surfaces, feats) 69 | encoded_query = get_features(model, [question_wids], use_cuda, batch_size=1) 70 | result = client.query(encoded_query[0], result_size) 71 | print(json.dumps(result, indent=2)) 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--question', required=True) 79 | parser.add_argument('--result_size', type=int, default=5) 80 | parser.add_argument('--answer_data_path', default='dataset/V1/answers.label.token_idx') 81 | parser.add_argument('--vocab_data_path', default='dataset/V1/vocabulary') 82 | parser.add_argument('--model_path', default='./model.pt') 83 | parser.add_argument('--nlist', type=int, default=64) 84 | parser.add_argument('--m', type=int, default=47) 85 | parser.add_argument('--max_length', type=int, default=200) 86 | parser.add_argument('--use_cuda', type=lambda x: str(x).lower() not in ('false', '0'), default=True)  # type=bool would treat the string "False" as True 87 | args = parser.parse_args() 88 | 89 | main(question=args.question, 90 | result_size=args.result_size, 91 | answer_data_path=args.answer_data_path, 92 | vocab_data_path=args.vocab_data_path, 93 | model_path=args.model_path, 94 | nlist=args.nlist, 95 | m=args.m, 96 | max_length=args.max_length, 97 | use_cuda=args.use_cuda) 98 | -------------------------------------------------------------------------------- /examples/question-answering/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gensim 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | from torch.nn.utils.rnn import pad_sequence 8 | from tqdm import tqdm 9 | 10 | from dataset import InsuranceQaDataset, Vocab, AnswerData, QaData 11 | from loss import QaLoss 12 | from models import SentenceEncoder 13 | 14 | 15 | def train(model, 16 | train_data_loader, 17 | test_dataset, 18 | optimizer, 19 | criterion, 20 | epoch, 21 | use_cuda, 22 | checkpoint=50): 23 | for epoch_i in range(epoch): 24 | model = model.train() 25 | total_loss = 0 26 | for question, positive, negative in tqdm(train_data_loader): 27 | if use_cuda: 28 | question, positive, negative = \ 29 | question.cuda(), positive.cuda(), negative.cuda() 30 | q_embed,
p_embed, n_embed = model(question), model(positive), model(negative) 31 | loss = criterion(q_embed, p_embed, n_embed) 32 | 33 | optimizer.zero_grad() 34 | loss.backward() 35 | optimizer.step() 36 | total_loss += loss.item() 37 | print('epoch:', epoch_i, 'loss:', total_loss / len(train_data_loader)) 38 | if (epoch_i + 1) % checkpoint == 0: 39 | test(model, test_dataset, use_cuda) 40 | 41 | 42 | def test(model, test_dataset, use_cuda): 43 | model = model.eval() 44 | accuracy = 0 45 | with torch.no_grad(): 46 | for i in tqdm(range(len(test_dataset))): 47 | question, positives, negatives = test_dataset.get_qa_entry(i) 48 | question = torch.LongTensor(question).unsqueeze(0) 49 | positives = pad_sequence([torch.LongTensor(x) for x in positives], batch_first=True) 50 | negatives = pad_sequence([torch.LongTensor(x) for x in negatives], batch_first=True) 51 | if use_cuda: 52 | question, positives, negatives = \ 53 | question.cuda(), positives.cuda(), negatives.cuda() 54 | q_embed, p_embed, n_embed = model(question), model(positives), model(negatives) 55 | 56 | p_sims = (q_embed * p_embed).sum(1, keepdim=True) 57 | n_sims = (q_embed * n_embed).sum(1, keepdim=True) 58 | result = list() 59 | result.extend([(score, 0) for score in n_sims]) 60 | result.extend([(score, 1) for score in p_sims]) 61 | accuracy += sorted(result, key=lambda x: -x[0])[0][1] 62 | print('test dataset: top1 accuracy:', accuracy / len(test_dataset)) 63 | 64 | 65 | def main(train_data_path, 66 | test_data_path, 67 | answer_data_path, 68 | vocab_data_path, 69 | embedding_data_path, 70 | batch_size, 71 | learning_rate, 72 | hidden_size, 73 | margin, 74 | epoch, 75 | save_path, 76 | pretrained_path, 77 | use_cuda): 78 | # load qa data 79 | answer_data = AnswerData(answer_data_path) 80 | train_data = QaData(train_data_path) 81 | test_data = QaData(test_data_path) 82 | 83 | # load pretrained embedding 84 | pretrained_embedding = gensim.models.KeyedVectors.load_word2vec_format( 85 | embedding_data_path, binary=True) 86 | vocab = Vocab(vocab_data_path, answer_data.lexicon + train_data.lexicon) 87 | pretrained_weights = np.zeros((len(vocab) + 1, 300)) # TODO magic number 88 | for wid, surf in vocab.wid2surf.items(): 89 | if surf in pretrained_embedding.vocab: 90 | pretrained_weights[wid] = pretrained_embedding.wv[surf] 91 | 92 | # create dataset / data loader 93 | train_dataset = InsuranceQaDataset(train_data, answer_data, vocab) 94 | train_data_loader = torch.utils.data.DataLoader(train_dataset, 95 | shuffle=True, 96 | batch_size=batch_size, 97 | collate_fn=train_dataset.collate) 98 | test_dataset = InsuranceQaDataset(test_data, answer_data, vocab) 99 | 100 | # train model 101 | if pretrained_path is not None: 102 | model = torch.load(pretrained_path)['model'] 103 | else: 104 | model = SentenceEncoder(pretrained_weights, hidden_size) 105 | optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate) 106 | criterion = QaLoss(margin=margin) 107 | if use_cuda: 108 | model = model.cuda() 109 | train(model, train_data_loader, test_dataset, optimizer, criterion, epoch, use_cuda) 110 | 111 | # save model 112 | torch.save({ 113 | 'model': model, 114 | 'vocab': vocab 115 | }, save_path) 116 | 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--train_data_path', default='dataset/V1/question.train.token_idx.label') 121 | parser.add_argument('--test_data_path', default='dataset/V1/question.dev.label.token_idx.pool') 122 | parser.add_argument('--answer_data_path', 
default='dataset/V1/answers.label.token_idx') 123 | parser.add_argument('--vocab_data_path', default='dataset/V1/vocabulary') 124 | parser.add_argument('--embedding_data_path', default='./GoogleNews-vectors-negative300.bin') 125 | parser.add_argument('--batch_size', type=int, default=128) 126 | parser.add_argument('--learning_rate', type=float, default=0.0005) 127 | parser.add_argument('--epoch', type=int, default=100) 128 | parser.add_argument('--hidden_size', type=int, default=141) 129 | parser.add_argument('--loss_margin', type=float, default=0.2) 130 | parser.add_argument('--save_path', default='./model.pt') 131 | parser.add_argument('--pretrained_path', default=None) 132 | parser.add_argument('--use_cuda', type=lambda x: str(x).lower() not in ('false', '0'), default=True)  # type=bool would treat the string "False" as True 133 | args = parser.parse_args() 134 | 135 | main(train_data_path=args.train_data_path, 136 | test_data_path=args.test_data_path, 137 | answer_data_path=args.answer_data_path, 138 | vocab_data_path=args.vocab_data_path, 139 | embedding_data_path=args.embedding_data_path, 140 | batch_size=args.batch_size, 141 | learning_rate=args.learning_rate, 142 | hidden_size=args.hidden_size, 143 | margin=args.loss_margin, 144 | epoch=args.epoch, 145 | save_path=args.save_path, 146 | pretrained_path=args.pretrained_path, 147 | use_cuda=args.use_cuda) 148 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rixwew/elasticsearch-approximate-nearest-neighbor/e99ef5367a70dff79bde2f0f5879504d4e120fd4/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.2-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS="" 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 34 | MAX_FD="maximum" 35 | 36 | warn () { 37 | echo "$*" 38 | } 39 | 40 | die () { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false').
48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ ! -x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? -ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" 
"$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Escape application args 158 | save () { 159 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 160 | echo " " 161 | } 162 | APP_ARGS=$(save "$@") 163 | 164 | # Collect all arguments for the java command, following the shell quoting and substitution rules 165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 166 | 167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 169 | cd "$(dirname "$0")" 170 | fi 171 | 172 | exec "$JAVACMD" "$@" 173 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS= 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 
78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 79 | exit /b 1 80 | 81 | :mainEnd 82 | if "%OS%"=="Windows_NT" endlocal 83 | 84 | :omega 85 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/analysis/CodeAttribute.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.analysis; 15 | 16 | import org.apache.lucene.util.Attribute; 17 | 18 | public interface CodeAttribute extends Attribute { 19 | 20 | short[] getCodes(); 21 | 22 | void setCodes(short[] codes); 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/analysis/CodeAttributeImpl.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.analysis; 15 | 16 | import org.apache.lucene.util.AttributeImpl; 17 | import org.apache.lucene.util.AttributeReflector; 18 | 19 | 20 | public class CodeAttributeImpl extends AttributeImpl implements CodeAttribute, Cloneable { 21 | 22 | private short[] codes; 23 | 24 | @Override 25 | public short[] getCodes() { 26 | return codes; 27 | } 28 | 29 | @Override 30 | public void setCodes(short[] codes) { 31 | this.codes = codes; 32 | } 33 | 34 | @Override 35 | public void clear() { 36 | codes = null; 37 | } 38 | 39 | @Override 40 | public void reflectWith(AttributeReflector reflector) { 41 | reflector.reflect(CodeAttribute.class, "codes", getCodes()); 42 | } 43 | 44 | @Override 45 | public void copyTo(AttributeImpl target) { 46 | CodeAttribute codeAttribute = (CodeAttribute) target; 47 | codeAttribute.setCodes(codes); 48 | 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/analysis/IvfpqAnalyzer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 
4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.analysis; 15 | 16 | import org.apache.lucene.analysis.Analyzer; 17 | import org.apache.lucene.analysis.Tokenizer; 18 | import org.elasticsearch.ann.ExactSearch; 19 | import org.elasticsearch.ann.ProductQuantizer; 20 | 21 | 22 | public class IvfpqAnalyzer extends Analyzer { 23 | 24 | private ExactSearch cq; 25 | 26 | private ProductQuantizer pq; 27 | 28 | public IvfpqAnalyzer(ExactSearch cq, ProductQuantizer pq) { 29 | this.cq = cq; 30 | this.pq = pq; 31 | } 32 | 33 | @Override 34 | protected TokenStreamComponents createComponents(String fieldName) { 35 | Tokenizer tokenizer = new IvfpqTokenizer(cq, pq); 36 | return new TokenStreamComponents(tokenizer); 37 | } 38 | 39 | public ProductQuantizer getProductQuantizer() { 40 | return pq; 41 | } 42 | 43 | public ExactSearch getCoarseQuantizer() { 44 | return cq; 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/analysis/IvfpqAnalyzerProvider.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
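The IvfpqAnalyzer above is a thin wrapper that carries the two trained quantizers into the tokenizer. As a minimal sketch (toy centroid values, hypothetical class name — real centroids come out of offline training), it can be constructed directly in a unit test, bypassing the index-settings plumbing of the provider below:

```java
import org.elasticsearch.analysis.IvfpqAnalyzer;
import org.elasticsearch.ann.ExactSearch;
import org.elasticsearch.ann.ProductQuantizer;

class IvfpqAnalyzerSketch {
    // d = 4 dims, 2 coarse centroids; m = 2 subvectors with ksub = 2 centroids each.
    static IvfpqAnalyzer build() {
        ExactSearch cq = new ExactSearch(4, new float[]{0, 0, 0, 0, 1, 1, 1, 1});
        ProductQuantizer pq = new ProductQuantizer(4, 2, 2,
                new float[]{0, 0, 0.25F, 0, 0, 0, 0, 0.25F});
        return new IvfpqAnalyzer(cq, pq);
    }
}
```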
13 | */ 14 | package org.elasticsearch.analysis; 15 |  16 | import org.elasticsearch.ann.ArrayUtils; 17 | import org.elasticsearch.ann.ExactSearch; 18 | import org.elasticsearch.ann.ProductQuantizer; 19 | import org.elasticsearch.common.settings.Settings; 20 | import org.elasticsearch.env.Environment; 21 | import org.elasticsearch.index.IndexSettings; 22 | import org.elasticsearch.index.analysis.AbstractIndexAnalyzerProvider; 23 |  24 | public class IvfpqAnalyzerProvider extends AbstractIndexAnalyzerProvider<IvfpqAnalyzer> { 25 |  26 | private ProductQuantizer pq; 27 |  28 | private ExactSearch cq; 29 |  30 | public IvfpqAnalyzerProvider(IndexSettings indexSettings, Environment environment, String name, 31 | Settings settings) { 32 | super(indexSettings, name, settings); 33 | loadSettings(settings); 34 | } 35 |  36 | @Override 37 | public IvfpqAnalyzer get() { 38 | return new IvfpqAnalyzer(cq, pq); 39 | } 40 |  41 | private void loadSettings(Settings settings) { 42 | int m = settings.getAsInt("m", 0); 43 | int d = settings.getAsInt("d", 0); 44 | int ksub = settings.getAsInt("ksub", 0); 45 | float[] coarseCentroids = ArrayUtils.parseFloatArrayCsv(settings.get("coarseCentroids")); 46 | float[] pqCentroids = ArrayUtils.parseFloatArrayCsv(settings.get("pqCentroids")); 47 | this.cq = new ExactSearch(d, coarseCentroids); 48 | this.pq = new ProductQuantizer(d, m, ksub, pqCentroids); 49 | } 50 |  51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/analysis/IvfpqTokenizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
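For reference, a sketch of the index settings that IvfpqAnalyzerProvider#loadSettings consumes, built programmatically; the values are illustrative stand-ins, since real centroid CSV strings come out of offline training (see examples/lib/common.py) and are far longer:

```java
import org.elasticsearch.common.settings.Settings;

class IvfpqSettingsSketch {
    static Settings analyzerSettings() {
        return Settings.builder()
                .put("d", 4)     // vector dimension
                .put("m", 2)     // number of PQ subquantizers
                .put("ksub", 2)  // centroids per subquantizer
                .put("coarseCentroids", "0,0,0,0,1,1,1,1")    // nlist = length / d = 2
                .put("pqCentroids", "0,0,0.25,0,0,0,0,0.25")  // length = m * ksub * (d / m)
                .build();
    }
}
```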
13 | */ 14 | package org.elasticsearch.analysis; 15 |  16 | import org.apache.logging.log4j.LogManager; 17 | import org.apache.logging.log4j.Logger; 18 | import org.apache.lucene.analysis.Tokenizer; 19 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; 20 | import org.elasticsearch.ann.ArrayUtils; 21 | import org.elasticsearch.ann.ExactSearch; 22 | import org.elasticsearch.ann.ProductQuantizer; 23 |  24 | import java.io.IOException; 25 |  26 | public class IvfpqTokenizer extends Tokenizer { 27 |  28 | private static final Logger LOGGER = LogManager.getLogger(IvfpqTokenizer.class); 29 |  30 | private final CharTermAttribute charTermAttribute = addAttribute(CharTermAttribute.class); 31 |  32 | private final CodeAttribute codeAttribute = addAttribute(CodeAttribute.class); 33 |  34 | private ProductQuantizer pq; 35 |  36 | private ExactSearch cq; 37 |  38 | public IvfpqTokenizer(ExactSearch cq, ProductQuantizer pq) { 39 | this.cq = cq; 40 | this.pq = pq; 41 | } 42 |  43 | @Override 44 | public boolean incrementToken() throws IOException { 45 | /* per the TokenStream contract, clear attributes before emitting a token */ clearAttributes(); StringBuilder stringBuilder = new StringBuilder(); 46 | for (int c = input.read(); c != -1; c = input.read()) { 47 | stringBuilder.append((char) c); 48 | } 49 | String value = stringBuilder.toString(); 50 | if (value.length() == 0) { 51 | return false; 52 | } 53 | float[] features = ArrayUtils.parseFloatArrayCsv(value); 54 | int coarseCenter = cq.searchNearest(features); 55 | if (coarseCenter != -1) { 56 | String coarseCenterText = String.valueOf(coarseCenter); 57 | charTermAttribute.copyBuffer(coarseCenterText.toCharArray(), 0, coarseCenterText.length()); 58 | float[] residual = cq.getResidual(coarseCenter, features); 59 | short[] codes = pq.getCodes(residual); 60 | codeAttribute.setCodes(codes); 61 | return true; 62 | } else { 63 | return false; 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/ann/AlgebraicOps.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
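A minimal sketch of how the IvfpqTokenizer above behaves when driven by hand (toy quantizers, hypothetical class name): it emits at most one token, whose term text is the id of the nearest coarse centroid and whose CodeAttribute carries the PQ codes of the residual.

```java
import java.io.StringReader;
import java.util.Arrays;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.elasticsearch.analysis.CodeAttribute;
import org.elasticsearch.analysis.IvfpqTokenizer;
import org.elasticsearch.ann.ExactSearch;
import org.elasticsearch.ann.ProductQuantizer;

class IvfpqTokenizerSketch {
    public static void main(String[] args) throws Exception {
        ExactSearch cq = new ExactSearch(4, new float[]{0, 0, 0, 0, 1, 1, 1, 1});
        ProductQuantizer pq = new ProductQuantizer(4, 2, 2,
                new float[]{0, 0, 0.25F, 0, 0, 0, 0, 0.25F});
        try (IvfpqTokenizer tokenizer = new IvfpqTokenizer(cq, pq)) {
            tokenizer.setReader(new StringReader("0.9,1.1,1.0,0.8"));
            tokenizer.reset();
            while (tokenizer.incrementToken()) {
                // Term "1": the second coarse centroid is nearest to this vector.
                String term = tokenizer.getAttribute(CharTermAttribute.class).toString();
                short[] codes = tokenizer.getAttribute(CodeAttribute.class).getCodes();
                System.out.println(term + " -> " + Arrays.toString(codes));
            }
            tokenizer.end();
        }
    }
}
```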
13 | */ 14 | package org.elasticsearch.ann; 15 |  16 |  17 | import java.util.Arrays; 18 |  19 |  20 | public class AlgebraicOps { 21 |  22 | public static void cumulativeSum(float[] x, float[] z) { 23 | float total = 0; /* keep the running sum in a float; an int accumulator would truncate */ 24 | for (int i = 0; i < x.length; ++i) { 25 | total += x[i]; 26 | z[i] = total; 27 | } 28 | } 29 |  30 | public static void square(float[] x, float[] z, int n, int d) { 31 | for (int i = 0, xoffset = 0; i < n; ++i, xoffset += d) { 32 | for (int j = 0; j < d; j++) { 33 | z[i] += x[xoffset + j] * x[xoffset + j]; 34 | } 35 | } 36 | } 37 |  38 | public static void transpose(float[] x, float[] z, int nx, int d) { 39 | for (int i = 0, xoffset = 0; i < nx; ++i, xoffset += d) { 40 | for (int j = 0, zoffset = 0; j < d; ++j, zoffset += nx) { 41 | z[zoffset + i] = x[xoffset + j]; 42 | } 43 | } 44 | } 45 |  46 | public static void multiply(float[] x, float[] y, float[] z, int nx, int ny, int d) { 47 | float[] ypart = new float[ny]; 48 | float[] zpart = new float[ny]; 49 | for (int i = 0, xoffset = 0, zoffset = 0; i < nx; ++i, xoffset += d, zoffset += ny) { 50 | for (int k = 0, yoffset = 0; k < d; ++k, yoffset += ny) { 51 | final float xval = x[xoffset + k]; 52 | System.arraycopy(y, yoffset, ypart, 0, ny); 53 | saxpy(ny, xval, ypart, zpart); 54 | } 55 | System.arraycopy(zpart, 0, z, zoffset, ny); 56 | Arrays.fill(zpart, 0f); 57 | } 58 | } 59 |  60 | public static void l2distance(float[] x, float[] y, float[] z, int nx, int d) { 61 | for (int i = 0, xoffset = 0; i < nx; ++i, xoffset += d) { 62 | for (int k = 0; k < d; ++k) { 63 | final float diff = x[xoffset + k] - y[k]; 64 | z[i] += diff * diff; 65 | } 66 | } 67 | } 68 |  69 | public static void multiplyElementwise(float[] x, float[] y, float[] z, int nx) { 70 | for (int i = 0; i < nx; ++i) { 71 | z[i] = x[i] * y[i]; 72 | } 73 | } 74 |  75 | public static int findNearest(float[] x, float[] y, int ny, int d) { 76 | int nearest = -1; 77 | float minDistance = Float.MAX_VALUE; 78 | float[] ypart = new float[d]; 79 | for (int i = 0, ioffset = 0; i < ny; ++i, ioffset += d) { 80 | System.arraycopy(y, ioffset, ypart, 0, d); 81 | float distance = 0; 82 | for (int j = 0; j < d; ++j) { 83 | final float diff = ypart[j] - x[j]; 84 | distance += diff * diff; 85 | } 86 | if (distance < minDistance) { 87 | minDistance = distance; 88 | nearest = i; 89 | } 90 | } 91 | return nearest; 92 | } 93 |  94 | private static void saxpy(int ny, float x, float[] y, float[] z) { 95 | for (int i = 0; i < ny; ++i) { 96 | z[i] += x * y[i]; 97 | } 98 | } 99 |  100 | } 101 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/ann/ArrayUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
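A quick worked example of the flat, row-major layout the AlgebraicOps helpers above assume (class name and values are illustrative only): both findNearest and l2distance treat their matrix argument as n consecutive d-dimensional vectors.

```java
import java.util.Arrays;
import org.elasticsearch.ann.AlgebraicOps;

class AlgebraicOpsSketch {
    public static void main(String[] args) {
        // Three 2-d centroids stored row-major: (0,0), (1,0), (0,1).
        float[] centroids = {0, 0, 1, 0, 0, 1};
        float[] query = {0.9F, 0.2F};
        // Squared L2 distance from each centroid to the query: {0.85, 0.05, 1.45}.
        float[] distances = new float[3];
        AlgebraicOps.l2distance(centroids, query, distances, 3, 2);
        System.out.println(Arrays.toString(distances));
        // Index of the closest centroid: 1.
        System.out.println(AlgebraicOps.findNearest(query, centroids, 3, 2));
    }
}
```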
13 | */ 14 | package org.elasticsearch.ann; 15 | 16 | import java.io.ByteArrayInputStream; 17 | import java.io.ByteArrayOutputStream; 18 | import java.io.DataInputStream; 19 | import java.io.DataOutputStream; 20 | import java.io.IOException; 21 | 22 | public class ArrayUtils { 23 | 24 | public static float[] parseFloatArrayCsv(String floatArrayText) { 25 | if (floatArrayText == null || floatArrayText.length() == 0) { 26 | return new float[]{}; 27 | } 28 | String[] texts = floatArrayText.split(","); 29 | float[] floats = new float[texts.length]; 30 | for (int i = 0; i < floats.length; ++i) { 31 | floats[i] = Float.valueOf(texts[i].trim()); 32 | } 33 | return floats; 34 | } 35 | 36 | public static byte[] encodeShortArray(short[] array) throws IOException { 37 | try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream( 38 | array.length * 2); 39 | DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) { 40 | for (short x : array) { 41 | dataOutputStream.writeShort(x); 42 | } 43 | dataOutputStream.flush(); 44 | return byteArrayOutputStream.toByteArray(); 45 | } 46 | } 47 | 48 | public static short[] decodeShortArray(byte[] array) throws IOException { 49 | short[] shorts = new short[array.length / 2]; 50 | try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(array); 51 | DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) { 52 | for (int i = 0; i < shorts.length; i++) { 53 | shorts[i] = dataInputStream.readShort(); 54 | } 55 | } 56 | return shorts; 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/ann/ExactSearch.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.ann; 15 | 16 | import java.util.Arrays; 17 | 18 | public class ExactSearch { 19 | 20 | private final int d; 21 | 22 | private final int nlist; 23 | 24 | private final float[] centroids; 25 | 26 | public ExactSearch(int d, float[] centroids) { 27 | this.d = d; 28 | this.nlist = d == 0 ? 
0 : centroids.length / d; 29 | this.centroids = centroids; 30 | } 31 |  32 | public float[] getResidual(int nearest, float[] feature) { 33 | float[] residual = new float[d]; 34 | System.arraycopy(centroids, nearest * d, residual, 0, d); 35 | for (int i = 0; i < d; ++i) { 36 | residual[i] = feature[i] - residual[i]; 37 | } 38 | return residual; 39 | } 40 |  41 | public int searchNearest(float[] feature) { 42 | return AlgebraicOps.findNearest(feature, centroids, nlist, d); 43 | } 44 |  45 | public int[] searchNearest(float[] feature, int k) { 46 | float[] distances = new float[nlist]; 47 | AlgebraicOps.l2distance(centroids, feature, distances, nlist, d); 48 | long[] encoded = new long[distances.length]; 49 | for (int i = 0; i < nlist; ++i) { 50 | encoded[i] = (((long) Float.floatToIntBits(distances[i])) << 32) | (i & 0xffffffffL); /* non-negative floats compare like their raw bit patterns, so packing the distance into the high 32 bits and the index into the low 32 bits lets a plain long sort order candidates by distance */ 51 | } 52 | Arrays.sort(encoded); 53 | int size = nlist >= k ? k : nlist; 54 | int[] result = new int[size]; 55 | for (int i = 0; i < size; ++i) { 56 | result[i] = (int) (encoded[i]); /* the low 32 bits hold the centroid index */ 57 | } 58 | return result; 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/ann/ProductQuantizer.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.ann; 15 |  16 | public class ProductQuantizer { 17 |  18 | private final int m; 19 |  20 | private final int dsub; 21 |  22 | private final int ksub; 23 |  24 | private final float[] pqCentroids; 25 |  26 | public ProductQuantizer(int d, int m, int ksub, float[] pqCentroids) { 27 | this.m = m; 28 | this.dsub = m == 0 ?
0 : d / m; 29 | this.ksub = ksub; 30 | this.pqCentroids = pqCentroids; 31 | } 32 | 33 | public float getDistance(float[] codeTable, short[] codes) { 34 | float distance = 0; 35 | for (int i = 0, offset = 0; i < codes.length; ++i, offset += ksub) { 36 | distance += codeTable[offset + codes[i]]; 37 | } 38 | return distance; 39 | } 40 | 41 | public float[] getCodeTable(float[] feature) { 42 | final float[] codeTable = new float[m * ksub]; 43 | for (int i = 0, ioffset = 0, foffset = 0, toffset = 0, subLen = ksub * dsub; 44 | i < m; ++i, ioffset += subLen, foffset += dsub, toffset += ksub) { 45 | for (int j = 0, joffset = 0; j < ksub; ++j, joffset += dsub) { 46 | for (int k = 0; k < dsub; ++k) { 47 | final float diff = feature[foffset + k] - pqCentroids[ioffset + joffset + k]; 48 | codeTable[toffset + j] += diff * diff; 49 | } 50 | } 51 | } 52 | return codeTable; 53 | } 54 | 55 | public short[] getCodes(float[] feature) { 56 | short[] codes = new short[m]; 57 | float[] fpart = new float[dsub]; 58 | float[] cpart = new float[ksub * dsub]; 59 | for (int i = 0, ioffset = 0; i < m; ++i, ioffset += dsub) { 60 | System.arraycopy(feature, ioffset, fpart, 0, dsub); 61 | System.arraycopy(pqCentroids, ksub * ioffset, cpart, 0, ksub * dsub); 62 | codes[i] = (short) AlgebraicOps.findNearest(fpart, cpart, ksub, dsub); 63 | } 64 | return codes; 65 | } 66 | 67 | } 68 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/mapper/IvfpqFieldMapper.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
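ExactSearch and ProductQuantizer above combine into the standard IVFADC flow: at index time a vector is assigned to its nearest coarse centroid and its residual is PQ-encoded; at query time a distance table is built once per probed centroid from the query's residual, so scoring each candidate costs only m table lookups. A minimal end-to-end sketch with toy, untrained centroids (class name hypothetical):

```java
import org.elasticsearch.ann.ExactSearch;
import org.elasticsearch.ann.ProductQuantizer;

class IvfadcSketch {
    public static void main(String[] args) {
        int d = 4, m = 2, ksub = 2;
        ExactSearch cq = new ExactSearch(d, new float[]{0, 0, 0, 0, 1, 1, 1, 1});
        ProductQuantizer pq = new ProductQuantizer(d, m, ksub,
                new float[]{-0.2F, 0, 0.2F, 0, 0, -0.2F, 0, 0.2F});

        // Index time: assign to the nearest coarse list, PQ-encode the residual.
        float[] doc = {0.8F, 1.2F, 1.0F, 1.1F};
        int list = cq.searchNearest(doc); // -> 1
        short[] codes = pq.getCodes(cq.getResidual(list, doc));

        // Query time: build the ADC table from the query residual, then score.
        float[] query = {0.9F, 1.0F, 1.0F, 1.0F};
        float[] table = pq.getCodeTable(cq.getResidual(list, query));
        System.out.println(pq.getDistance(table, codes)); // approx. squared L2 distance
    }
}
```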
13 | */ 14 | package org.elasticsearch.mapper; 15 |  16 | import org.apache.logging.log4j.LogManager; 17 | import org.apache.logging.log4j.Logger; 18 | import org.apache.lucene.analysis.TokenStream; 19 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; 20 | import org.apache.lucene.document.Field; 21 | import org.apache.lucene.document.StoredField; 22 | import org.apache.lucene.document.StringField; 23 | import org.apache.lucene.index.IndexableField; 24 | import org.apache.lucene.search.Query; 25 | import org.apache.lucene.util.BytesRef; 26 | import org.elasticsearch.analysis.CodeAttribute; 27 | import org.elasticsearch.ann.ArrayUtils; 28 | import org.elasticsearch.common.settings.Settings; 29 | import org.elasticsearch.common.xcontent.XContentBuilder; 30 | import org.elasticsearch.common.xcontent.XContentParser; 31 | import org.elasticsearch.index.analysis.NamedAnalyzer; 32 | import org.elasticsearch.index.mapper.FieldMapper; 33 | import org.elasticsearch.index.mapper.MappedFieldType; 34 | import org.elasticsearch.index.mapper.Mapper; 35 | import org.elasticsearch.index.mapper.MapperParsingException; 36 | import org.elasticsearch.index.mapper.ParseContext; 37 | import org.elasticsearch.index.mapper.TypeParsers; 38 | import org.elasticsearch.index.query.QueryShardContext; 39 |  40 | import java.io.IOException; 41 | import java.util.List; 42 | import java.util.Map; 43 |  44 |  45 | public class IvfpqFieldMapper extends FieldMapper { 46 |  47 | private static final Logger LOGGER = LogManager.getLogger(IvfpqFieldMapper.class); 48 |  49 | public static final String CONTENT_TYPE = "ivfpq"; 50 |  51 | protected IvfpqFieldMapper(String simpleName, MappedFieldType fieldType, MappedFieldType defaultFieldType, 52 | Settings indexSettings, MultiFields multiFields, CopyTo copyTo) { 53 | super(simpleName, fieldType, defaultFieldType, indexSettings, multiFields, copyTo); 54 | } 55 |  56 | @Override 57 | protected String contentType() { 58 | return CONTENT_TYPE; 59 | } 60 |  61 | @Override 62 | protected void parseCreateField(ParseContext context, List<IndexableField> fields) 63 | throws IOException { 64 | String value; 65 | if (context.externalValueSet()) { 66 | value = context.externalValue().toString(); 67 | } else { 68 | XContentParser parser = context.parser(); 69 | if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { 70 | value = fieldType().nullValueAsString(); 71 | } else { 72 | value = parser.textOrNull(); 73 | } 74 | } 75 | if (value == null) { 76 | return; 77 | } 78 | NamedAnalyzer indexAnalyzer = fieldType.indexAnalyzer(); 79 | try (TokenStream tokenStream = indexAnalyzer.tokenStream(name(), value)) { 80 | final CharTermAttribute charTermAttribute = 81 | tokenStream.addAttribute(CharTermAttribute.class); 82 | final CodeAttribute codeAttribute = tokenStream.addAttribute(CodeAttribute.class); 83 | tokenStream.reset(); 84 | if (tokenStream.incrementToken()) { 85 | fields.add(new StringField(name(), charTermAttribute.toString(), Field.Store.NO)); 86 | BytesRef bytes = new BytesRef(ArrayUtils.encodeShortArray(codeAttribute.getCodes())); 87 | fields.add(new StoredField(getCodesField(name()), bytes)); 88 | } 89 | tokenStream.end(); 90 | } 91 | } 92 |  93 | @Override 94 | protected void doXContentBody(XContentBuilder builder, boolean includeDefaults, Params params) 95 | throws IOException { 96 | super.doXContentBody(builder, includeDefaults, params); 97 | doXContentAnalyzers(builder, includeDefaults); 98 | } 99 |  100 | static class Defaults { 101 | static final MappedFieldType FIELD_TYPE = new
IvfpqFieldType(); 102 |  103 | static { 104 | FIELD_TYPE.freeze(); 105 | } 106 | } 107 |  108 | public static class TypeParser implements Mapper.TypeParser { 109 | @Override 110 | public Mapper.Builder<?, ?> parse(String name, Map<String, Object> node, 111 | ParserContext parserContext) throws MapperParsingException { 112 | Builder builder = new Builder(name); 113 | TypeParsers.parseTextField(builder, name, node, parserContext); 114 | return builder; 115 | } 116 | } 117 |  118 |  119 | public static class Builder extends FieldMapper.Builder<Builder, IvfpqFieldMapper> { 120 |  121 | Builder(String name) { 122 | super(name, Defaults.FIELD_TYPE, Defaults.FIELD_TYPE); 123 | builder = this; 124 | } 125 |  126 | @Override 127 | public IvfpqFieldMapper build(BuilderContext context) { 128 | setupFieldType(context); 129 | return new IvfpqFieldMapper(name, fieldType, defaultFieldType, context.indexSettings(), 130 | multiFieldsBuilder.build(this, context), copyTo); 131 | } 132 | } 133 |  134 |  135 | public static class IvfpqFieldType extends MappedFieldType { 136 | IvfpqFieldType() { 137 | } 138 |  139 | IvfpqFieldType(IvfpqFieldType ref) { 140 | super(ref); 141 | } 142 |  143 | @Override 144 | public String typeName() { 145 | return CONTENT_TYPE; 146 | } 147 |  148 | @Override 149 | public Query termQuery(Object o, QueryShardContext queryShardContext) { 150 | throw new UnsupportedOperationException(); 151 | } 152 |  153 | @Override 154 | public Query existsQuery(QueryShardContext queryShardContext) { 155 | throw new UnsupportedOperationException(); 156 | } 157 |  158 | @Override 159 | public IvfpqFieldType clone() { 160 | return new IvfpqFieldType(this); 161 | } 162 |  163 | } 164 |  165 | public static String getCodesField(String field) { 166 | return field + ".pq"; 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/plugin/AnnPlugin.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
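To make the storage layout of IvfpqFieldMapper above concrete: each vector becomes one indexed term (the coarse-centroid id, so a probed centroid selects exactly one inverted list) plus one stored binary field, named by getCodesField with a `.pq` suffix, holding the PQ codes that the search side decodes again at scoring time. A rough sketch of the equivalent raw Lucene document (class name illustrative):

```java
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ann.ArrayUtils;

class MapperOutputSketch {
    static Document toLuceneDoc(String field, int coarseCenter, short[] codes)
            throws java.io.IOException {
        Document doc = new Document();
        // Indexed, not stored: one term per vector, keyed by its coarse centroid.
        doc.add(new StringField(field, String.valueOf(coarseCenter), Field.Store.NO));
        // Stored, not indexed: the raw PQ codes, read back in IvfpqQuery's advanceExact.
        doc.add(new StoredField(field + ".pq", new BytesRef(ArrayUtils.encodeShortArray(codes))));
        return doc;
    }
}
```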
13 | */ 14 | package org.elasticsearch.plugin; 15 |  16 | import org.apache.lucene.analysis.Analyzer; 17 | import org.elasticsearch.analysis.IvfpqAnalyzerProvider; 18 | import org.elasticsearch.index.analysis.AnalyzerProvider; 19 | import org.elasticsearch.index.mapper.Mapper; 20 | import org.elasticsearch.indices.analysis.AnalysisModule; 21 | import org.elasticsearch.mapper.IvfpqFieldMapper; 22 | import org.elasticsearch.plugins.AnalysisPlugin; 23 | import org.elasticsearch.plugins.MapperPlugin; 24 | import org.elasticsearch.plugins.Plugin; 25 | import org.elasticsearch.plugins.SearchPlugin; 26 | import org.elasticsearch.search.IvfpqQueryBuilder; 27 |  28 | import java.util.ArrayList; 29 | import java.util.HashMap; 30 | import java.util.List; 31 | import java.util.Map; 32 |  33 | public class AnnPlugin extends Plugin implements AnalysisPlugin, MapperPlugin, SearchPlugin { 34 |  35 | @Override 36 | public Map<String, AnalysisModule.AnalysisProvider<AnalyzerProvider<? extends Analyzer>>> getAnalyzers() { 37 | Map<String, AnalysisModule.AnalysisProvider<AnalyzerProvider<? extends Analyzer>>> 38 | analyzers = new HashMap<>(); 39 | analyzers.put("ivfpq_analyzer", IvfpqAnalyzerProvider::new); 40 | return analyzers; 41 | } 42 |  43 | @Override 44 | public Map<String, Mapper.TypeParser> getMappers() { 45 | Map<String, Mapper.TypeParser> map = new HashMap<>(); 46 | map.put(IvfpqFieldMapper.CONTENT_TYPE, new IvfpqFieldMapper.TypeParser()); 47 | return map; 48 | } 49 |  50 | @Override 51 | public List<QuerySpec<?>> getQueries() { 52 | List<QuerySpec<?>> queries = new ArrayList<>(); 53 | queries.add(new QuerySpec<>(IvfpqQueryBuilder.NAME, IvfpqQueryBuilder::new, 54 | IvfpqQueryBuilder::fromXContent)); 55 | return queries; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/search/IvfpqQuery.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
13 | */ 14 | package org.elasticsearch.search; 15 |  16 | import org.apache.logging.log4j.LogManager; 17 | import org.apache.logging.log4j.Logger; 18 | import org.apache.lucene.analysis.Analyzer; 19 | import org.apache.lucene.document.Document; 20 | import org.apache.lucene.index.LeafReaderContext; 21 | import org.apache.lucene.index.Term; 22 | import org.apache.lucene.queries.function.FunctionScoreQuery; 23 | import org.apache.lucene.search.DisjunctionMaxQuery; 24 | import org.apache.lucene.search.DoubleValues; 25 | import org.apache.lucene.search.DoubleValuesSource; 26 | import org.apache.lucene.search.IndexSearcher; 27 | import org.apache.lucene.search.Query; 28 | import org.apache.lucene.search.TermQuery; 29 | import org.apache.lucene.util.BytesRef; 30 | import org.elasticsearch.ElasticsearchException; 31 | import org.elasticsearch.analysis.IvfpqAnalyzer; 32 | import org.elasticsearch.ann.ArrayUtils; 33 | import org.elasticsearch.ann.ExactSearch; 34 | import org.elasticsearch.ann.ProductQuantizer; 35 | import org.elasticsearch.index.analysis.NamedAnalyzer; 36 | import org.elasticsearch.index.mapper.MappedFieldType; 37 | import org.elasticsearch.index.query.QueryShardContext; 38 | import org.elasticsearch.mapper.IvfpqFieldMapper; 39 |  40 | import java.io.IOException; 41 | import java.util.ArrayList; 42 | import java.util.Collections; 43 | import java.util.HashSet; 44 | import java.util.List; 45 | import java.util.Map; 46 | import java.util.Objects; 47 |  48 | class IvfpqQuery { 49 |  50 | private static final Logger LOGGER = LogManager.getLogger(IvfpqQuery.class); 51 |  52 | private QueryShardContext context; 53 |  54 | IvfpqQuery(QueryShardContext context) { 55 | this.context = context; 56 | } 57 |  58 | Query parse(Map<String, Float> fieldNames, Object value, int nprobe) { 59 | float[] features = ArrayUtils.parseFloatArrayCsv((String) value); 60 | List<Query> fieldQueries = new ArrayList<>(); 61 | for (String field : fieldNames.keySet()) { 62 | MappedFieldType fieldMapper = context.fieldMapper(field); 63 | Analyzer analyzer = context.getSearchAnalyzer(fieldMapper); 64 | while (analyzer instanceof NamedAnalyzer) { 65 | analyzer = ((NamedAnalyzer) analyzer).analyzer(); 66 | } 67 | if (!(analyzer instanceof IvfpqAnalyzer)) { 68 | throw new ElasticsearchException("illegal analyzer: " + analyzer); 69 | } 70 | ProductQuantizer pq = ((IvfpqAnalyzer) analyzer).getProductQuantizer(); 71 | ExactSearch cq = ((IvfpqAnalyzer) analyzer).getCoarseQuantizer(); 72 | for (int nearest : cq.searchNearest(features, nprobe)) { 73 | float[] residual = cq.getResidual(nearest, features); 74 | float[] table = pq.getCodeTable(residual); 75 | Query query = new FunctionScoreQuery( 76 | new TermQuery(new Term(field, String.valueOf(nearest))), 77 | new CustomValueSource(field, pq, table)); 78 | fieldQueries.add(query); 79 | } 80 | } 81 | return new DisjunctionMaxQuery(fieldQueries, 1.0f); 82 | } 83 |  84 | private class CustomValueSource extends DoubleValuesSource { 85 |  86 | private String field; 87 |  88 | private ProductQuantizer pq; 89 |  90 | private float[] codeTable; 91 |  92 | CustomValueSource(String field, ProductQuantizer pq, float[] codeTable) { 93 | this.field = field; 94 | this.pq = pq; 95 | this.codeTable = codeTable; 96 | } 97 |  98 | @Override 99 | public DoubleValues getValues(LeafReaderContext leafReaderContext, DoubleValues scores) { 100 | return new DoubleValues() { 101 |  102 | private float value; 103 |  104 | @Override 105 | public double doubleValue() { 106 | return value; 107 | } 108 |  109 | @Override 110 | public boolean
advanceExact(int doc) throws IOException { 111 | Document document = leafReaderContext.reader() 112 | .document(doc, new HashSet<>(Collections.singletonList( 113 | IvfpqFieldMapper.getCodesField(field)))); 114 | BytesRef bytesRef = document.getBinaryValue( 115 | IvfpqFieldMapper.getCodesField(field)); 116 | if (bytesRef == null) { 117 | return false; 118 | } 119 | short[] codes = ArrayUtils.decodeShortArray(bytesRef.bytes); 120 | value = pq.getDistance(codeTable, codes); 121 | return true; 122 | } 123 | }; 124 | } 125 |  126 | @Override 127 | public boolean needsScores() { 128 | return false; 129 | } 130 |  131 | @Override 132 | public DoubleValuesSource rewrite(IndexSearcher reader) { 133 | return this; 134 | } 135 |  136 | @Override 137 | public boolean equals(Object o) { 138 | if (!(o instanceof CustomValueSource)) { 139 | return false; 140 | } 141 | CustomValueSource that = (CustomValueSource) o; return field.equals(that.field) && java.util.Arrays.equals(codeTable, that.codeTable); 142 | } 143 |  144 | @Override 145 | public int hashCode() { 146 | return Objects.hash(field, java.util.Arrays.hashCode(codeTable)); 147 | } 148 |  149 | @Override 150 | public String toString() { 151 | return "ivfpq(" + field + ")"; 152 | } 153 |  154 | @Override 155 | public boolean isCacheable(LeafReaderContext ctx) { 156 | return false; 157 | } 158 | } 159 |  160 | } 161 | -------------------------------------------------------------------------------- /src/main/java/org/elasticsearch/search/IvfpqQueryBuilder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
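The Lucene query that IvfpqQuery.parse above assembles is easiest to see schematically: one term query per probed inverted list, each wrapped in a FunctionScoreQuery whose score is the ADC distance, combined under a DisjunctionMaxQuery. A sketch with constant value sources standing in for the per-centroid code tables (centroid ids 17 and 42 are made up):

```java
import java.util.Arrays;
import org.apache.lucene.index.Term;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

class IvfpqQueryShapeSketch {
    static Query shape() {
        return new DisjunctionMaxQuery(Arrays.<Query>asList(
                new FunctionScoreQuery(new TermQuery(new Term("feature", "17")),
                        DoubleValuesSource.constant(0.05)),  // stands in for list 17's ADC source
                new FunctionScoreQuery(new TermQuery(new Term("feature", "42")),
                        DoubleValuesSource.constant(0.11))), // stands in for list 42's ADC source
                1.0f);
    }
}
```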
13 | */ 14 | package org.elasticsearch.search; 15 |  16 | import org.apache.lucene.search.Query; 17 | import org.elasticsearch.common.ParseField; 18 | import org.elasticsearch.common.io.stream.StreamInput; 19 | import org.elasticsearch.common.io.stream.StreamOutput; 20 | import org.elasticsearch.common.xcontent.DeprecationHandler; 21 | import org.elasticsearch.common.xcontent.XContentBuilder; 22 | import org.elasticsearch.common.xcontent.XContentParser; 23 | import org.elasticsearch.index.query.AbstractQueryBuilder; 24 | import org.elasticsearch.index.query.QueryShardContext; 25 |  26 | import java.io.IOException; 27 | import java.util.Map; 28 | import java.util.Objects; 29 | import java.util.TreeMap; 30 |  31 | public class IvfpqQueryBuilder extends AbstractQueryBuilder<IvfpqQueryBuilder> { 32 |  33 | public static final String NAME = "ivfpq_query"; 34 |  35 | private static final int DEFAULT_NPROBE = 3; 36 |  37 | private static final ParseField QUERY_FIELD = new ParseField("query"); 38 |  39 | private static final ParseField FIELDS_FIELD = new ParseField("fields"); 40 |  41 | private static final ParseField NPROBE_FIELD = new ParseField("nprobe"); 42 |  43 | private final Object value; 44 |  45 | private Map<String, Float> fieldsBoosts; 46 |  47 | private int nprobe; 48 |  49 | private IvfpqQueryBuilder(Object value, Map<String, Float> fieldsBoosts, int nprobe) { 50 | if (value == null) { 51 | throw new IllegalArgumentException("[" + NAME + "] requires query value"); 52 | } 53 | this.value = value; 54 | this.fieldsBoosts = fieldsBoosts; 55 | this.nprobe = nprobe; 56 | } 57 |  58 | public IvfpqQueryBuilder(StreamInput in) throws IOException { 59 | super(in); 60 | value = in.readGenericValue(); 61 | nprobe = in.readVInt(); 62 | int size = in.readVInt(); 63 | fieldsBoosts = new TreeMap<>(); 64 | for (int i = 0; i < size; i++) { 65 | fieldsBoosts.put(in.readString(), in.readFloat()); 66 | } 67 | } 68 |  69 | @Override 70 | protected void doWriteTo(StreamOutput streamOutput) throws IOException { 71 | streamOutput.writeGenericValue(value); 72 | streamOutput.writeVInt(nprobe); 73 | streamOutput.writeVInt(fieldsBoosts.size()); 74 | for (Map.Entry<String, Float> fieldsEntry : fieldsBoosts.entrySet()) { 75 | streamOutput.writeString(fieldsEntry.getKey()); 76 | streamOutput.writeFloat(fieldsEntry.getValue()); 77 | } 78 | } 79 |  80 | @Override 81 | protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { 82 | xContentBuilder.startObject(NAME); 83 | xContentBuilder.field(QUERY_FIELD.getPreferredName(), value); 84 | xContentBuilder.field(NPROBE_FIELD.getPreferredName(), nprobe); 85 | xContentBuilder.startArray(FIELDS_FIELD.getPreferredName()); 86 | for (Map.Entry<String, Float> fieldEntry : this.fieldsBoosts.entrySet()) { 87 | xContentBuilder.value(fieldEntry.getKey() + "^" + fieldEntry.getValue()); 88 | } 89 | xContentBuilder.endArray(); 90 | xContentBuilder.endObject(); 91 | } 92 |  93 | @Override 94 | protected Query doToQuery(QueryShardContext queryShardContext) { 95 | IvfpqQuery ivfpqQuery = new IvfpqQuery(queryShardContext); 96 | return ivfpqQuery.parse(fieldsBoosts, value, nprobe); 97 | } 98 |  99 | @Override 100 | protected boolean doEquals(IvfpqQueryBuilder ivfpqQueryBuilder) { 101 | return ivfpqQueryBuilder.value.equals(value) && ivfpqQueryBuilder.fieldsBoosts 102 | .equals(fieldsBoosts) && ivfpqQueryBuilder.nprobe == nprobe; 103 | } 104 |  105 | @Override 106 | protected int doHashCode() { 107 | return Objects.hash(value, fieldsBoosts, nprobe); 108 | } 109 |  110 | @Override 111 | public String getWriteableName() { 112 | return NAME; 113 | } 114 |  115 | public static IvfpqQueryBuilder
fromXContent(XContentParser parser) throws IOException { 116 | Object value = null; 117 | int nprobe = DEFAULT_NPROBE; 118 | Map<String, Float> fieldsBoosts = new TreeMap<>(); 119 | XContentParser.Token token; 120 | String currentFieldName = null; 121 | while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { 122 | if (token == XContentParser.Token.FIELD_NAME) { 123 | currentFieldName = parser.currentName(); 124 | } else if (FIELDS_FIELD 125 | .match(currentFieldName, DeprecationHandler.THROW_UNSUPPORTED_OPERATION)) { 126 | if (token == XContentParser.Token.START_ARRAY) { 127 | while (parser.nextToken() != XContentParser.Token.END_ARRAY) { 128 | parseFieldAndBoost(parser, fieldsBoosts); 129 | } 130 | } else if (token.isValue()) { 131 | parseFieldAndBoost(parser, fieldsBoosts); 132 | } 133 | } else if (token.isValue()) { 134 | if (QUERY_FIELD 135 | .match(currentFieldName, DeprecationHandler.THROW_UNSUPPORTED_OPERATION)) { 136 | value = parser.objectText(); 137 | } else if (NPROBE_FIELD.match(currentFieldName, DeprecationHandler.THROW_UNSUPPORTED_OPERATION)) { 138 | nprobe = parser.intValue(); 139 | } 140 | } 141 | } 142 | return new IvfpqQueryBuilder(value, fieldsBoosts, nprobe); 143 | } 144 |  145 | private static void parseFieldAndBoost(XContentParser parser, Map<String, Float> fieldsBoosts) 146 | throws IOException { 147 | String fField = null; 148 | float fBoost = AbstractQueryBuilder.DEFAULT_BOOST; 149 | char[] fieldText = parser.textCharacters(); 150 | int end = parser.textOffset() + parser.textLength(); 151 | for (int i = parser.textOffset(); i < end; i++) { 152 | if (fieldText[i] == '^') { 153 | int relativeLocation = i - parser.textOffset(); 154 | fField = new String(fieldText, parser.textOffset(), relativeLocation); 155 | fBoost = Float.parseFloat( 156 | new String(fieldText, i + 1, parser.textLength() - relativeLocation - 1)); 157 | break; 158 | } 159 | } 160 | if (fField == null) { 161 | fField = parser.text(); 162 | } 163 | fieldsBoosts.put(fField, fBoost); 164 | } 165 |  166 | } 167 | -------------------------------------------------------------------------------- /src/main/plugin-metadata/plugin-security.policy: -------------------------------------------------------------------------------- 1 | grant { 2 | }; 3 | -------------------------------------------------------------------------------- /src/test/java/org/elasticsearch/ann/ExactSearchTests.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License.
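Two knobs of the ivfpq_query DSL are visible only in the parser above: an optional `nprobe` (how many coarse lists to probe, default 3) and per-field boosts in the `feature^2.0` style. A hedged sketch of feeding such a body through fromXContent, assuming the 6.x XContent helpers this plugin already imports behave as below:

```java
import java.io.IOException;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.IvfpqQueryBuilder;

class IvfpqQueryDslSketch {
    static IvfpqQueryBuilder parse() throws IOException {
        String body = "{\"query\": \"0.1,0.2,0.3,0.4\", \"nprobe\": 5, \"fields\": [\"feature^2.0\"]}";
        try (XContentParser parser = XContentType.JSON.xContent().createParser(
                NamedXContentRegistry.EMPTY,
                DeprecationHandler.THROW_UNSUPPORTED_OPERATION, body)) {
            parser.nextToken(); // step onto START_OBJECT, as the enclosing query parser would
            return IvfpqQueryBuilder.fromXContent(parser);
        }
    }
}
```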
13 | */ 14 | package org.elasticsearch.ann; 15 | 16 | import org.apache.lucene.util.LuceneTestCase; 17 | 18 | public class ExactSearchTests extends LuceneTestCase { 19 | 20 | public void testGetResidual01() { 21 | float[] centroids = new float[]{0, 0.5F, 0.2F, 0.5F, 1, 1.1F}; 22 | ExactSearch exactSearch = new ExactSearch(3, centroids); 23 | float[] vector = new float[]{0.1F, 0.1F, 0.1F}; 24 | float[] residual0 = exactSearch.getResidual(0, vector); 25 | assertArrayEquals(new float[]{0.1F, -0.4F, -0.1F}, residual0, Float.MIN_NORMAL); 26 | float[] residual1 = exactSearch.getResidual(1, vector); 27 | assertArrayEquals(new float[]{-0.4F, -0.9F, -1}, residual1, Float.MIN_NORMAL); 28 | } 29 | 30 | public void testSearchNearest01() { 31 | float[] centroids = new float[]{0, 0, 0, 1, 1, 1}; 32 | ExactSearch exactSearch = new ExactSearch(3, centroids); 33 | int nearest; 34 | nearest = exactSearch.searchNearest(new float[]{0.1F, 0.3F, 0.5F}); 35 | assertEquals(0, nearest); 36 | nearest = exactSearch.searchNearest(new float[]{1.1F, 0.5F, 0.5F}); 37 | assertEquals(1, nearest); 38 | } 39 | 40 | public void testSearchNearest02() { 41 | float[] centroids = new float[]{0, 0, 0, 1, 1, 1, 2, 2, 2}; 42 | ExactSearch exactSearch = new ExactSearch(3, centroids); 43 | int[] nearest = exactSearch.searchNearest(new float[]{1.3F, 1, 0.9F}, 2); 44 | assertArrayEquals(new int[]{1, 2}, nearest); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/test/java/org/elasticsearch/ann/ProductQuantizerTests.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
13 | */ 14 | package org.elasticsearch.ann; 15 | 16 | import org.apache.lucene.util.LuceneTestCase; 17 | 18 | public class ProductQuantizerTests extends LuceneTestCase { 19 | 20 | public void testGetDistance01() { 21 | int d = 4; 22 | int m = 2; 23 | int ksub = 2; 24 | float[] pqCentroids = new float[]{0, 0, 0.25F, 0, 0, 0.25F, 0.25F, 0.25F}; 25 | ProductQuantizer pq = new ProductQuantizer(d, m, ksub, pqCentroids); 26 | float[] feature = new float[]{0.25F, 0, 0, 0.25F}; 27 | assertEquals(0.25, feature[0], Float.MIN_NORMAL); 28 | assertEquals(0, feature[1], Float.MIN_NORMAL); 29 | assertEquals(0, feature[2], Float.MIN_NORMAL); 30 | assertEquals(0.25, feature[3], Float.MIN_NORMAL); 31 | float[] table = pq.getCodeTable(feature); 32 | short[] codes; 33 | float value; 34 | 35 | codes = new short[]{0, 0}; 36 | value = pq.getDistance(table, codes); 37 | assertEquals(0.25 * 0.25, value, Float.MIN_NORMAL); 38 | 39 | codes = new short[]{0, 1}; 40 | value = pq.getDistance(table, codes); 41 | assertEquals(0.25 * 0.25 * 2, value, Float.MIN_NORMAL); 42 | 43 | codes = new short[]{1, 0}; 44 | value = pq.getDistance(table, codes); 45 | assertEquals(0, value, Float.MIN_NORMAL); 46 | 47 | codes = new short[]{1, 1}; 48 | value = pq.getDistance(table, codes); 49 | assertEquals(0.25 * 0.25, value, Float.MIN_NORMAL); 50 | } 51 | 52 | public void testGetCodes01() { 53 | int d = 4; 54 | int m = 2; 55 | int ksub = 2; 56 | float[] pqCentroids = new float[]{0, 0, 0.25F, 0, 0.25F, 0, 0, 0}; 57 | ProductQuantizer pq = new ProductQuantizer(d, m, ksub, pqCentroids); 58 | float[] feature = new float[]{0.25F, 0, 0.25F, 0}; 59 | short[] codes = pq.getCodes(feature); 60 | assertArrayEquals(new short[]{1, 0}, codes); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/test/java/org/elasticsearch/plugin/AnnPluginIT.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | package org.elasticsearch.plugin; 15 | 16 | import org.elasticsearch.test.ESIntegTestCase; 17 | 18 | 19 | public class AnnPluginIT extends ESIntegTestCase { 20 | 21 | public void testEmpty() { 22 | assertTrue(true); 23 | } 24 | } 25 | --------------------------------------------------------------------------------