├── .prettierignore ├── tests ├── data │ └── git-stub ├── basic-config.js ├── very-big-config.js ├── testworker.js ├── worker-thread-test.js ├── smalltest.js ├── smalltest-manhattan.js └── basictests.js ├── .eslintignore ├── index.js ├── .gitignore ├── .npmignore ├── addon.cc ├── binding.gyp ├── .travis.yml ├── .eslintrc.js ├── package.json ├── tools ├── build-word-index-db.js ├── w2v-to-json.c └── test-word-index-db.js ├── Makefile ├── annoyindexwrapper.h ├── kissrandom.h ├── README.md ├── annoyindexwrapper.cc └── annoylib.h /.prettierignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/data/git-stub: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.eslintignore: -------------------------------------------------------------------------------- 1 | index.js 2 | node_modules 3 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | var annoyAddon = require('bindings')('addon'); 2 | module.exports = annoyAddon.Annoy; 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | tests/data/*.bin 4 | tests/data/*.annoy 5 | tests/data/*.json 6 | tools/w2v-to-json 7 | build 8 | *.swp 9 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | tests/data/*.bin 4 | tests/data/*.annoy* 5 | tests/data/*.json 6 | tests/data/*.db 7 | tools/w2v-to-json 8 | build 9 | .npmignore 10 | 
-------------------------------------------------------------------------------- /addon.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "annoyindexwrapper.h" 3 | 4 | void InitAll(v8::Local exports) { 5 | AnnoyIndexWrapper::Init(exports); 6 | } 7 | 8 | NODE_MODULE_INIT() { 9 | InitAll(exports); 10 | } 11 | -------------------------------------------------------------------------------- /binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "addon", 5 | "sources": [ "addon.cc", "annoyindexwrapper.cc" ], 6 | "include_dirs": [ 7 | " { 15 | // return the result to the main thread 16 | const result = annoyIndex.getNNsByItem(id, 10, -1, false) 17 | parentPort.postMessage(result); 18 | }); -------------------------------------------------------------------------------- /tests/worker-thread-test.js: -------------------------------------------------------------------------------- 1 | var test = require('tape'); 2 | const { StaticPool } = require('node-worker-threads-pool') 3 | 4 | var workerPath = __dirname + '/testworker.js'; 5 | 6 | test('Worker thread test', workerThreadTest); 7 | 8 | async function workerThreadTest(t) { 9 | const workerPool = new StaticPool({ 10 | size: 2, 11 | task: workerPath 12 | }); 13 | 14 | const idsToLookUp = [0, 1] 15 | const indexLookups = idsToLookUp.map(id => workerPool.exec(id)) 16 | console.log("Lookup results: ", (await Promise.all(indexLookups))) 17 | workerPool.destroy() 18 | t.end() 19 | } 20 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "annoy", 3 | "version": "4.0.0", 4 | "description": "Node bindings for Annoy, an efficient Approximate Nearest Neighbors implementation written in C++.", 5 | "main": "index.js", 6 | "gypfile": true, 7 | 
"scripts": { 8 | "test": "make test" 9 | }, 10 | "repository": { 11 | "type": "git", 12 | "url": "git@github.com:jimkang/annoy-node.git" 13 | }, 14 | "keywords": [ 15 | "annoy", 16 | "approximate", 17 | "nearest", 18 | "neighbor", 19 | "search", 20 | "vector", 21 | "machine learning" 22 | ], 23 | "author": "Jim Kang", 24 | "license": "MIT", 25 | "bugs": { 26 | "url": "https://github.com/jimkang/annoy-node/issues" 27 | }, 28 | "homepage": "https://github.com/jimkang/annoy-node", 29 | "devDependencies": { 30 | "assert-no-error": "^1.0.0", 31 | "d3-queue": "^3.0.3", 32 | "level-sublevel": "^6.6.4", 33 | "ndjson": "^1.5.0", 34 | "node-worker-threads-pool": "^1.5.1", 35 | "tape": "^5.6.6", 36 | "through2": "^2.0.1" 37 | }, 38 | "dependencies": { 39 | "bindings": "^1.2.1", 40 | "level": "^6.0.0", 41 | "nan": "^2.14.0" 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /tools/build-word-index-db.js: -------------------------------------------------------------------------------- 1 | /* global process */ 2 | 3 | var fs = require('fs'); 4 | var ndjson = require('ndjson'); 5 | var level = require('level'); 6 | var Sublevel = require('level-sublevel'); 7 | var through2 = require('through2'); 8 | var queue = require('d3-queue').queue; 9 | 10 | if (process.argv.length < 4) { 11 | console.log( 12 | 'Usage: node tools/build-word-index-db.js ' 13 | ); 14 | process.exit(); 15 | } 16 | 17 | const vectorJSONPath = process.argv[2]; 18 | const dbPath = process.argv[3]; 19 | 20 | var db = Sublevel(level(dbPath)); 21 | var indexesForWords = db.sublevel('indexes'); 22 | var wordsForIndexes = db.sublevel('words'); 23 | 24 | var vectorCount = 0; 25 | 26 | fs 27 | .createReadStream(vectorJSONPath) 28 | .pipe(ndjson.parse({ strict: false })) 29 | .pipe(through2({ objectMode: true }, addToDb)) 30 | .on('end', closeDb); 31 | 32 | function addToDb(wordVectorPair, enc, done) { 33 | var q = queue(); 34 | q.defer(indexesForWords.put, 
wordVectorPair.word, vectorCount); 35 | q.defer(wordsForIndexes.put, vectorCount, wordVectorPair.word); 36 | q.await(incrementCount); 37 | 38 | function incrementCount(error) { 39 | if (error) { 40 | throw error; 41 | } else { 42 | vectorCount += 1; 43 | done(); 44 | } 45 | } 46 | } 47 | 48 | function closeDb() { 49 | db.close(logDone); 50 | 51 | function logDone(error) { 52 | if (error) { 53 | console.error(error); 54 | } else { 55 | console.log('Done building db.'); 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TOOLS_CC = g++ 2 | TOOLS_CFLAGS = -lm -pthread -Ofast -march=native -Wall -funroll-loops -Wno-unused-result -std=c++11 3 | 4 | TESTDATADIR = tests/data 5 | 6 | build-wrapper: 7 | node-gyp rebuild 8 | 9 | test: tests/data/text8-vector.json 10 | node tests/smalltest.js 11 | node tests/worker-thread-test.js 12 | node tests/smalltest-manhattan.js 13 | node tests/basictests.js basic-config.js 14 | 15 | big-test: tests/data/GoogleNews-vectors-negative300.json 16 | node tests/basictests.js very-big-config.js 17 | 18 | tests/data/text8-vector.bin: 19 | wget https://github.com/jimkang/nearest-neighbor-test-data/raw/master/text8-vector.bin -O $(TESTDATADIR)/text8-vector.bin 20 | 21 | # If this fails, you end up with a 0-byte json file. 22 | # When you run this target again, make will see that the 23 | # file exists, then quit. To get it to run, delete 24 | # tests/data/GoogleNews-vectors-negative300.json 25 | # See README about where to get the bin file used in this 26 | # target. 
27 | tests/data/GoogleNews-vectors-negative300.json: tools/w2v-to-json 28 | ./tools/w2v-to-json "$(TESTDATADIR)/GoogleNews-vectors-negative300.bin" tests/data/GoogleNews-vectors-negative300.json 29 | 30 | tools/w2v-to-json: 31 | $(TOOLS_CC) tools/w2v-to-json.c -o tools/w2v-to-json $(TOOLS_CFLAGS) 32 | 33 | tests/data/text8-vector.json: tests/data/text8-vector.bin tools/w2v-to-json 34 | ./tools/w2v-to-json tests/data/text8-vector.bin tests/data/text8-vector.json 35 | 36 | test-word-index-db: 37 | node tools/test-word-index-db.js tests/data/GoogleNews-vectors-negative300.json tests/data/word-index-google-news.db tests/data/very-big-test.annoy 38 | 39 | pushall: 40 | git push origin master && npm publish 41 | 42 | prettier: 43 | prettier --single-quote --write "**/*.js" 44 | -------------------------------------------------------------------------------- /tools/w2v-to-json.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | const long long max_w = 2000; 5 | 6 | void w2vToJSON(char *binPath, char *jsonPath) { 7 | 8 | FILE *fi = fopen(binPath, "rb"); 9 | // FILE *fWordIndex = fopen(wordIndexPath, "wb"); 10 | FILE *fJson = fopen(jsonPath, "wb"); 11 | 12 | long long words, dimensions; 13 | fscanf(fi, "%lld", &words); 14 | fscanf(fi, "%lld", &dimensions); 15 | fscanf(fi, "%*[ ]"); 16 | fscanf(fi, "%*[\n]"); 17 | 18 | // *numberOfDimensionsPerVector = dimensions; 19 | 20 | char word[max_w]; 21 | char ch; 22 | float value; 23 | int b, a; 24 | 25 | // Start outer array. 
26 | // fprintf(fJson, "%s\n,", "["); 27 | 28 | for (b = 0; b < words; b++) { 29 | if(feof(fi)) 30 | break; 31 | 32 | word[0] = 0; 33 | fscanf(fi, "%[^ ]", word); 34 | fscanf(fi, "%c", &ch); 35 | 36 | // fprintf(fWordIndex, "%s:%d\n", word, b); 37 | fprintf(fJson, "{\"word\": \"%s\", \"vector\": [", word); 38 | // fprintf(fJson, "%s\n", "["); 39 | 40 | for (a = 0; a < dimensions; a++) { 41 | fread(&value, sizeof(float), 1, fi); 42 | if (a < dimensions - 1) { 43 | fprintf(fJson, "\"%f\",", value); 44 | } 45 | else { 46 | fprintf(fJson, "\"%f\"", value); 47 | } 48 | } 49 | 50 | fprintf(fJson, "%s\n", "]}"); 51 | 52 | fscanf(fi, "%*[\n]"); 53 | } 54 | 55 | // End outer array. 56 | // fprintf(fJson, "%s\n,", "]"); 57 | 58 | fclose(fi); 59 | fclose(fJson); 60 | } 61 | 62 | int main(int argc, char **argv) { 63 | if (argc < 3) { 64 | printf("Usage: w2v-to-json \n"); 65 | return -1; 66 | } 67 | 68 | char *binPath = argv[1]; 69 | char *jsonPath = argv[2]; 70 | 71 | w2vToJSON(binPath, jsonPath); 72 | } 73 | -------------------------------------------------------------------------------- /annoyindexwrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef ANNOYINDEXWRAPPER_H 2 | #define ANNOYINDEXWRAPPER_H 3 | 4 | #include 5 | #include "annoylib.h" 6 | 7 | class AnnoyIndexWrapper : public Nan::ObjectWrap { 8 | public: 9 | static void Init(v8::Local exports); 10 | int getDimensions(); 11 | AnnoyIndexInterface *annoyIndex; 12 | 13 | private: 14 | explicit AnnoyIndexWrapper(int dimensions, const char *metricString); 15 | virtual ~AnnoyIndexWrapper(); 16 | 17 | static void New(const Nan::FunctionCallbackInfo& info); 18 | static void AddItem(const Nan::FunctionCallbackInfo& info); 19 | static void Build(const Nan::FunctionCallbackInfo& info); 20 | static void Save(const Nan::FunctionCallbackInfo& info); 21 | static void Load(const Nan::FunctionCallbackInfo& info); 22 | static void Unload(const Nan::FunctionCallbackInfo& info); 23 | 
static void GetItem(const Nan::FunctionCallbackInfo& info); 24 | static void GetNNSByVector(const Nan::FunctionCallbackInfo& info); 25 | static void GetNNSByItem(const Nan::FunctionCallbackInfo& info); 26 | static void GetNItems(const Nan::FunctionCallbackInfo& info); 27 | static void GetDistance(const Nan::FunctionCallbackInfo& info); 28 | 29 | static Nan::Persistent constructor; 30 | static bool getFloatArrayParam(const Nan::FunctionCallbackInfo& info, 31 | int paramIndex, float *vec); 32 | static void setNNReturnValues( 33 | int numberOfNeighbors, bool includeDistances, 34 | const std::vector& nnIndexes, const std::vector& distances, 35 | const Nan::FunctionCallbackInfo& info); 36 | static void getSupplementaryGetNNsParams( 37 | const Nan::FunctionCallbackInfo& info, 38 | int& numberOfNeighbors, int& searchK, bool& includeDistances); 39 | 40 | int annoyDimensions; 41 | }; 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /tests/smalltest.js: -------------------------------------------------------------------------------- 1 | /* global __dirname */ 2 | 3 | var test = require('tape'); 4 | var Annoy = require('../index'); 5 | 6 | var annoyPath = __dirname + '/data/test.annoy'; 7 | 8 | test('Add test', addTest); 9 | test('Load test', loadTest); 10 | 11 | function addTest(t) { 12 | var obj = new Annoy(10, 'Angular'); 13 | 14 | obj.addItem(0, [-5.0, -4.5, -3.2, -2.8, -2.1, -1.5, -0.34, 0, 3.7, 6]); 15 | obj.addItem(1, [5.0, 4.5, 3.2, 2.8, 2.1, 1.5, 0.34, 0, -3.7, -6]); 16 | obj.addItem(2, [0, 0, 0, 0, 0, -1, -1, -0.2, 0.1, 0.8]); 17 | 18 | t.equal(obj.getNItems(), 3, 'Index has all the added items.'); 19 | 20 | obj.build(); 21 | t.ok(obj.save(annoyPath), 'Saved successfully.'); 22 | obj.unload(); 23 | t.end(); 24 | } 25 | 26 | function loadTest(t) { 27 | var obj2 = new Annoy(10, 'Angular'); 28 | var loadResult = obj2.load(annoyPath); 29 | t.ok(loadResult, 'Loads successfully.'); 30 | 31 | if (loadResult) { 32 | 
t.equal(obj2.getNItems(), 3, 'Number of items in index is correct.'); 33 | 34 | t.equal( 35 | obj2.getDistance(0, 1), 36 | 2.0, 37 | 'getDistance calculates correct distance between items 0 and 1.' 38 | ); 39 | 40 | var v1 = obj2.getItem(0); 41 | var v2 = obj2.getItem(1); 42 | 43 | var sum = []; 44 | for (var i = 0; i < v1.length; ++i) { 45 | sum.push(v1[i] + v2[i]); 46 | } 47 | // console.log('Sum:', sum); 48 | var neighbors = obj2.getNNsByVector(sum, 10, -1, false); 49 | t.ok(Array.isArray(neighbors), 'getNNsByVector result is an array.'); 50 | // console.log('Nearest neighbors to sum', neighbors); 51 | 52 | var nnResult = obj2.getNNsByVector(sum, 10, -1, true); 53 | checkNeighborsAndDistancesResult(nnResult); 54 | 55 | var neighborsByItem = obj2.getNNsByItem(1, 10, -1, false); 56 | t.ok(Array.isArray(neighborsByItem), 'NN by item result is an array.'); 57 | var nnResultByItem = obj2.getNNsByItem(1, 10, -1, true); 58 | checkNeighborsAndDistancesResult(nnResultByItem); 59 | } 60 | 61 | t.end(); 62 | 63 | function checkNeighborsAndDistancesResult(result) { 64 | t.equal(typeof result, 'object', 'NN result is an object.'); 65 | t.ok(Array.isArray(result.neighbors), 'NN result has a neighbors array.'); 66 | t.ok(Array.isArray(result.distances), 'NN result has a distances array.'); 67 | // console.log('Nearest neighbors to sum with distances', result); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /tools/test-word-index-db.js: -------------------------------------------------------------------------------- 1 | /* global process */ 2 | 3 | var fs = require('fs'); 4 | var ndjson = require('ndjson'); 5 | var level = require('level'); 6 | var Sublevel = require('level-sublevel'); 7 | var through2 = require('through2'); 8 | var test = require('tape'); 9 | var assertNoError = require('assert-no-error'); 10 | var Annoy = require('../index'); 11 | 12 | if (process.argv.length < 5) { 13 | console.log( 14 | 'Usage: node 
tools/test-word-index-db.js ' 15 | ); 16 | process.exit(); 17 | } 18 | 19 | const vectorJSONPath = process.argv[2]; 20 | const dbPath = process.argv[3]; 21 | const annoyPath = process.argv[4]; 22 | 23 | var db = Sublevel(level(dbPath)); 24 | var indexesForWords = db.sublevel('indexes'); 25 | var wordsForIndexes = db.sublevel('words'); 26 | var annoyIndex = new Annoy(300, 'Euclidean'); 27 | annoyIndex.load(annoyPath); 28 | 29 | var vectorCount = 0; 30 | 31 | fs 32 | .createReadStream(vectorJSONPath) 33 | .pipe(ndjson.parse({ strict: false })) 34 | .pipe(through2({ objectMode: true }, runTestOnPair)) 35 | .on('end', closeDb); 36 | 37 | function runTestOnPair(wordVectorPair, enc, done) { 38 | test(wordVectorPair.word + '/' + vectorCount, testPair); 39 | 40 | function testPair(t) { 41 | indexesForWords.get(wordVectorPair.word, checkIndex); 42 | 43 | function checkIndex(error, index) { 44 | assertNoError(t.ok, error, 'No error while getting index for word.'); 45 | t.equal(parseInt(index, 10), vectorCount, 'Index matches index in JSON.'); 46 | var vector = annoyIndex.getItem(index); 47 | t.deepEqual( 48 | vector.map(toFixedDecimalString), 49 | wordVectorPair.vector, 50 | 'Vector for index matches vector for same index in JSON.' 
51 | ); 52 | 53 | wordsForIndexes.get(index, checkWord); 54 | } 55 | 56 | function checkWord(error, word) { 57 | assertNoError(t.ok, error, 'No error while getting word for index.'); 58 | t.equal(word, wordVectorPair.word, 'Word matches word in JSON.'); 59 | 60 | vectorCount += 1; 61 | 62 | t.end(); 63 | done(); 64 | } 65 | } 66 | } 67 | 68 | function closeDb() { 69 | db.close(logDone); 70 | 71 | function logDone(error) { 72 | if (error) { 73 | console.error(error); 74 | } else { 75 | console.log('Done testing db.'); 76 | } 77 | } 78 | } 79 | 80 | function toFixedDecimalString(num) { 81 | return num.toFixed(6); 82 | } 83 | -------------------------------------------------------------------------------- /tests/smalltest-manhattan.js: -------------------------------------------------------------------------------- 1 | /* global __dirname */ 2 | 3 | var test = require('tape'); 4 | var Annoy = require('../index'); 5 | 6 | var annoyPath = __dirname + '/data/test.annoy'; 7 | 8 | items = [ 9 | [-5.0, -4.5, -3.2, -2.8, -2.1, -1.5, -0.34, 0, 3.7, 6], 10 | [5.0, 4.5, 3.2, 2.8, 2.1, 1.5, 0.34, 0, -3.7, -6], 11 | [0, 0, 0, 0, 0, -1, -1, -0.2, 0.1, 0.8] 12 | ]; 13 | 14 | test('Add test', addTest); 15 | test('Load test', loadTest); 16 | 17 | function addTest(t) { 18 | var obj = new Annoy(10, 'Manhattan'); 19 | 20 | obj.addItem(0, items[0]); 21 | obj.addItem(1, items[1]); 22 | obj.addItem(2, items[2]); 23 | 24 | t.equal(obj.getNItems(), 3, 'Index has all the added items.'); 25 | 26 | obj.build(); 27 | t.ok(obj.save(annoyPath), 'Saved successfully.'); 28 | obj.unload(); 29 | t.end(); 30 | } 31 | 32 | function loadTest(t) { 33 | var obj2 = new Annoy(10, 'Manhattan'); 34 | var loadResult = obj2.load(annoyPath); 35 | t.ok(loadResult, 'Loads successfully.'); 36 | 37 | if (loadResult) { 38 | t.equal(obj2.getNItems(), 3, 'Number of items in index is correct.'); 39 | 40 | var dist = 0; 41 | for (var i = 0; i < items[0].length; i++) { 42 | dist += Math.abs(items[0][i] - items[1][i]); 43 | 
} 44 | 45 | t.equal( 46 | obj2.getDistance(0, 1).toPrecision(2), 47 | dist.toPrecision(2), 48 | 'getDistance calculates correct distance between items 0 and 1.' 49 | ); 50 | 51 | var v1 = obj2.getItem(0); 52 | var v2 = obj2.getItem(1); 53 | 54 | var sum = []; 55 | for (var i = 0; i < v1.length; ++i) { 56 | sum.push(v1[i] + v2[i]); 57 | } 58 | // console.log('Sum:', sum); 59 | var neighbors = obj2.getNNsByVector(sum, 10, -1, false); 60 | t.ok(Array.isArray(neighbors), 'getNNsByVector result is an array.'); 61 | // console.log('Nearest neighbors to sum', neighbors); 62 | 63 | var nnResult = obj2.getNNsByVector(sum, 10, -1, true); 64 | checkNeighborsAndDistancesResult(nnResult); 65 | 66 | var neighborsByItem = obj2.getNNsByItem(1, 10, -1, false); 67 | t.ok(Array.isArray(neighborsByItem), 'NN by item result is an array.'); 68 | var nnResultByItem = obj2.getNNsByItem(1, 10, -1, true); 69 | checkNeighborsAndDistancesResult(nnResultByItem); 70 | } 71 | 72 | t.end(); 73 | 74 | function checkNeighborsAndDistancesResult(result) { 75 | t.equal(typeof result, 'object', 'NN result is an object.'); 76 | t.ok(Array.isArray(result.neighbors), 'NN result has a neighbors array.'); 77 | t.ok(Array.isArray(result.distances), 'NN result has a distances array.'); 78 | // console.log('Nearest neighbors to sum with distances', result); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /kissrandom.h: -------------------------------------------------------------------------------- 1 | #ifndef KISSRANDOM_H 2 | #define KISSRANDOM_H 3 | 4 | #if defined(_MSC_VER) && _MSC_VER == 1500 5 | typedef unsigned __int32 uint32_t; 6 | typedef unsigned __int32 uint64_t; 7 | #else 8 | #include 9 | #endif 10 | 11 | // KISS = "keep it simple, stupid", but high quality random number generator 12 | // http://www0.cs.ucl.ac.uk/staff/d.jones/GoodPracticeRNG.pdf -> "Use a good RNG and build it into your code" 13 | // 
http://mathforum.org/kb/message.jspa?messageID=6627731 14 | // https://de.wikipedia.org/wiki/KISS_(Zufallszahlengenerator) 15 | 16 | // 32 bit KISS 17 | struct Kiss32Random { 18 | uint32_t x; 19 | uint32_t y; 20 | uint32_t z; 21 | uint32_t c; 22 | 23 | // seed must be != 0 24 | Kiss32Random(uint32_t seed = 123456789) { 25 | x = seed; 26 | y = 362436000; 27 | z = 521288629; 28 | c = 7654321; 29 | } 30 | 31 | uint32_t kiss() { 32 | // Linear congruence generator 33 | x = 69069 * x + 12345; 34 | 35 | // Xor shift 36 | y ^= y << 13; 37 | y ^= y >> 17; 38 | y ^= y << 5; 39 | 40 | // Multiply-with-carry 41 | uint64_t t = 698769069ULL * z + c; 42 | c = t >> 32; 43 | z = (uint32_t) t; 44 | 45 | return x + y + z; 46 | } 47 | inline int flip() { 48 | // Draw random 0 or 1 49 | return kiss() & 1; 50 | } 51 | inline size_t index(size_t n) { 52 | // Draw random integer between 0 and n-1 where n is at most the number of data points you have 53 | return kiss() % n; 54 | } 55 | inline void set_seed(uint32_t seed) { 56 | x = seed; 57 | } 58 | }; 59 | 60 | // 64 bit KISS. 
Use this if you have more than about 2^24 data points ("big data" ;) ) 61 | struct Kiss64Random { 62 | uint64_t x; 63 | uint64_t y; 64 | uint64_t z; 65 | uint64_t c; 66 | 67 | // seed must be != 0 68 | Kiss64Random(uint64_t seed = 1234567890987654321ULL) { 69 | x = seed; 70 | y = 362436362436362436ULL; 71 | z = 1066149217761810ULL; 72 | c = 123456123456123456ULL; 73 | } 74 | 75 | uint64_t kiss() { 76 | // Linear congruence generator 77 | z = 6906969069LL*z+1234567; 78 | 79 | // Xor shift 80 | y ^= (y<<13); 81 | y ^= (y>>17); 82 | y ^= (y<<43); 83 | 84 | // Multiply-with-carry (uint128_t t = (2^58 + 1) * x + c; c = t >> 64; x = (uint64_t) t) 85 | uint64_t t = (x<<58)+c; 86 | c = (x>>6); 87 | x += t; 88 | c += (x'); 10 | process.exit(); 11 | } 12 | 13 | var config = require('./' + process.argv[2]); 14 | 15 | const vectorJSONPath = __dirname + '/data/' + config.vectorJSONFile; 16 | const annoyIndexPath = __dirname + '/data/' + config.annoyFile; 17 | const dimensions = config.dimensions; 18 | 19 | var indexesForWords = {}; 20 | var wordsForIndexes = {}; 21 | var vectorCount = 0; 22 | 23 | test('Adding vectors to Annoy', addTest); 24 | test('Using vectors from Annoy', usingTest); 25 | 26 | function addTest(t) { 27 | var annoyIndex = new Annoy(dimensions, 'Euclidean'); 28 | 29 | fs 30 | .createReadStream(vectorJSONPath) 31 | .pipe(ndjson.parse({ strict: false })) 32 | .on('data', addToAnnoy) 33 | .on('end', checkAdded); 34 | 35 | function addToAnnoy(wordVectorPair) { 36 | indexesForWords[wordVectorPair.word] = vectorCount; 37 | wordsForIndexes[vectorCount] = wordVectorPair.word; 38 | annoyIndex.addItem(vectorCount, wordVectorPair.vector); 39 | // process.stdout.write('+'); 40 | vectorCount += 1; 41 | // console.log(wordVectorPair.word, wordVectorPair.vector); 42 | } 43 | 44 | function checkAdded() { 45 | t.ok(vectorCount > 0, 'More than one vector was added to the index.'); 46 | t.equal( 47 | annoyIndex.getNItems(), 48 | vectorCount, 49 | "The index's total vector count 
is correct." 50 | ); 51 | annoyIndex.build(); 52 | t.ok(annoyIndex.save(annoyIndexPath), 'Saved successfully.'); 53 | annoyIndex.unload(); 54 | t.end(); 55 | } 56 | } 57 | 58 | function usingTest(t) { 59 | var annoyIndex = new Annoy(dimensions, 'Euclidean'); 60 | t.ok(annoyIndex.load(annoyIndexPath), 'Loaded successfully.'); 61 | 62 | t.equal( 63 | annoyIndex.getNItems(), 64 | vectorCount, 65 | "The loaded index's total vector count is correct: " + vectorCount 66 | ); 67 | 68 | t.equal( 69 | annoyIndex 70 | .getDistance( 71 | indexesForWords[config.lookupWord1], 72 | indexesForWords[config.lookupWord2] 73 | ) 74 | .toPrecision(7), 75 | config.distanceBetweenWord1And2.toString(), 76 | 'getDistance calculates correct distance between items for ' + 77 | config.lookupWord1 + 78 | ' and ' + 79 | config.lookupWord2 80 | ); 81 | 82 | var v1 = annoyIndex.getItem(indexesForWords[config.lookupWord1]); 83 | var v2 = annoyIndex.getItem(indexesForWords[config.lookupWord2]); 84 | checkVector(v1); 85 | checkVector(v2); 86 | 87 | var sumVector = []; 88 | for (var i = 0; i < v1.length; ++i) { 89 | sumVector.push(v1[i] + v2[i]); 90 | } 91 | 92 | // console.log('Sum:', sumVector); 93 | var nnResult = annoyIndex.getNNsByVector(sumVector, 100, -1, true); 94 | // console.log('Neighbors and distances:', nnResult); 95 | 96 | checkNNResult('nnResult', nnResult); 97 | 98 | // console.log('Third closest neighbor:', wordsForIndexes[nnResult.neighbors[2]]); 99 | 100 | var nnResultByItem = annoyIndex.getNNsByItem( 101 | indexesForWords[config.indexLookupWord], 102 | 100, 103 | -1, 104 | true 105 | ); 106 | checkNNResult('nnResultByItem', nnResultByItem); 107 | 108 | t.end(); 109 | 110 | function checkNNResult(resultName, result) { 111 | t.equal(typeof result, 'object', resultName + ' is an object.'); 112 | t.ok( 113 | Array.isArray(result.neighbors), 114 | resultName + ' has a neighbors array.' 
115 | ); 116 | t.equal( 117 | result.neighbors.length, 118 | 100, 119 | 'Correct number of neighbors is returned.' 120 | ); 121 | t.ok( 122 | result.neighbors.every(val => typeof val === 'number'), 123 | 'Neighbors contains all numbers.' 124 | ); 125 | 126 | t.ok( 127 | Array.isArray(result.distances), 128 | resultName + ' has a distances array.' 129 | ); 130 | t.equal( 131 | result.distances.length, 132 | 100, 133 | 'Correct number of distances is returned.' 134 | ); 135 | t.ok( 136 | result.distances.every(val => typeof val === 'number'), 137 | 'Distances contains all numbers.' 138 | ); 139 | } 140 | 141 | function checkVector(vector) { 142 | // console.log(vector); 143 | t.equal( 144 | vector.length, 145 | dimensions, 146 | 'Vector has correct number of dimensions.' 147 | ); 148 | t.ok( 149 | vector.every(val => typeof val === 'number'), 150 | 'Vector contains all numbers.' 151 | ); 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | annoy-node 2 | ================== 3 | 4 | [![Build Status](https://travis-ci.org/jimkang/annoy-node.svg?branch=master)](https://travis-ci.org/jimkang/annoy-node) 5 | 6 | Node bindings for [Annoy](https://github.com/spotify/annoy), an efficient Approximate Nearest Neighbors implementation written in C++. 7 | 8 | Version 4.0.0 requires Node 14 or Node 16 and does not yet work on 18. 9 | 10 | Status: Tests pass, including one that loads 3 million vectors, but API coverage is not complete. Run on OS X and Linux with Node 8, 10, and 12. Not tried on Windows yet. Support for Node 6.3 and 4.6 ended at version 2.0.1 of this package. 11 | 12 | All of the [Python API](https://github.com/spotify/annoy#full-python-api) methods are implemented. The names are camel cased, JavaScript-style. 
13 | 14 | - `addItem` 15 | - `build` 16 | - `save` 17 | - `load` 18 | - `unload` 19 | - `getItem` 20 | - `getNNsByVector` 21 | - `getNNsByItem` 22 | - `getNItems` 23 | - `getDistance` 24 | 25 | There are a few minor differences in behavior: 26 | 27 | - If you set the "include distances" param (the fourth param) when calling `getNNsByVector` and `getNNsByItem`, rather than returning a 2D array containing the neighbors and distances, it will return an object with the properties `neighbors` and `distances`, each of which is an array. 28 | - `get_item_vector` in the Python API is called `getItem` here. 29 | 30 | Installation 31 | ------------ 32 | 33 | On Linux, if you don't already have Python 2.7 and g++ 4.8, you need to install them. Here's how you do it on Ubuntu: 34 | 35 | (sudo) apt-get install python2.7 36 | (sudo) apt-get install g++-4.8 37 | npm config set python /path/to/executable/python2.7 38 | 39 | Then, symlink g++ somewhere it can be found: 40 | 41 | ln -s /usr/bin/g++-4.8 /usr/local/bin/g++ 42 | 43 | On OS X, they should already be there.
44 | 45 | Then: 46 | 47 | npm install annoy 48 | 49 | Usage 50 | ----- 51 | 52 | var Annoy = require('annoy'); 53 | var annoyPath = 'test.annoy'; 54 | var annoyIndex1 = new Annoy(10, 'Angular'); 55 | annoyIndex1.addItem(0, [-5.0, -4.5, -3.2, -2.8, -2.1, -1.5, -0.34, 0, 3.7, 6]); 56 | annoyIndex1.addItem(1, [5.0, 4.5, 3.2, 2.8, 2.1, 1.5, 0.34, 0, -3.7, -6]); 57 | annoyIndex1.addItem(2, [0, 0, 0, 0, 0, -1, -1, -0.2, 0.1, 0.8]); 58 | annoyIndex1.build(); 59 | annoyIndex1.save(annoyPath); 60 | 61 | read(); 62 | 63 | function read() { 64 | var annoyIndex2 = new Annoy(10, 'Angular'); 65 | 66 | if (annoyIndex2.load(annoyPath)) { 67 | var v1 = annoyIndex2.getItem(0); 68 | var v2 = annoyIndex2.getItem(1); 69 | console.log('Gotten vectors:', v1, v2); 70 | var sum = []; 71 | for (var i = 0; i < v1.length; ++i) { 72 | sum.push(v1[i] + v2[i]); 73 | } 74 | 75 | var neighbors = annoyIndex2.getNNsByVector(sum, 10, -1, false); 76 | console.log('Nearest neighbors to sum', neighbors); 77 | 78 | var neighborsAndDistances = annoyIndex2.getNNsByVector(sum, 10, -1, true); 79 | console.log('Nearest neighbors to sum with distances', neighborsAndDistances); 80 | } 81 | } 82 | 83 | Development 84 | ------------ 85 | 86 | npm install -g node-gyp 87 | node-gyp rebuild 88 | 89 | Run `eslint .` and `make prettier` before committing. 90 | 91 | Tests 92 | ----- 93 | 94 | Run tests with `make test`. 95 | 96 | You can also run tests individually: 97 | 98 | - This is a short baseline test: `node tests/smalltest.js` 99 | - This is a test that uses 70K 200-dimension vectors: `node tests/basictests.js basic-config.js` 100 | 101 | There is also a `big-test` target that is not a dependency of the `test` target. It loads about 3 million 300-dimension vectors. It takes about six minutes to run on good-for-2016 hardware. Before you can run it, you need to download [GoogleNews-vectors-negative300.bin](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing) to `tests/data`.
102 | 103 | Then, you can run `make tests/data/GoogleNews-vectors-negative300.json`, which takes a while, and gets the test data ready for the big test. (See comment about running that in the Makefile.) Then, `make big-test`. 104 | 105 | Contributors 106 | ------------ 107 | 108 | Thanks to: 109 | 110 | - [mbuszka](https://github.com/mbuszka) for [updating the wrapper to the latest Annoy (with Manhattan distance) and updating the random number generator](https://github.com/jimkang/annoy-node/pull/4). 111 | - [aaaton](https://github.com/aaaton) for [updating the example code so that it works](https://github.com/jimkang/annoy-node/pull/1). 112 | - [kornesh](https://github.com/kornesh) for [updating annoylib.h](https://github.com/jimkang/annoy-node/pull/10) to match the [Annoy of 2020-07-19](https://github.com/spotify/annoy/commit/7f2562add33eeb217dcdc755520c201aefc1b021). 113 | - [Benjaminrivard](https://github.com/Benjaminrivard) for [updating the wrapper for Node 14](https://github.com/jimkang/annoy-node/pull/13) and testing the thread support. 114 | - [S4N0I](https://github.com/S4N0I) for adding context-awareness so that it can be used in threads. 115 | 116 | License 117 | ------- 118 | 119 | The MIT License (MIT) 120 | 121 | Copyright (c) 2016 Jim Kang 122 | 123 | Permission is hereby granted, free of charge, to any person obtaining a copy 124 | of this software and associated documentation files (the "Software"), to deal 125 | in the Software without restriction, including without limitation the rights 126 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 127 | copies of the Software, and to permit persons to whom the Software is 128 | furnished to do so, subject to the following conditions: 129 | 130 | The above copyright notice and this permission notice shall be included in 131 | all copies or substantial portions of the Software. 
132 | 133 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 134 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 135 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 136 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 137 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 138 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 139 | THE SOFTWARE. 140 | -------------------------------------------------------------------------------- /annoyindexwrapper.cc: -------------------------------------------------------------------------------- 1 | #include "annoyindexwrapper.h" 2 | #include "kissrandom.h" 3 | #include 4 | 5 | using namespace v8; 6 | using namespace Nan; 7 | 8 | Nan::Persistent AnnoyIndexWrapper::constructor; 9 | 10 | AnnoyIndexWrapper::AnnoyIndexWrapper(int dimensions, const char *metricString) : 11 | annoyDimensions(dimensions) { 12 | 13 | if (strcmp(metricString, "Angular") == 0) { 14 | annoyIndex = new AnnoyIndex(dimensions); 15 | } 16 | else if (strcmp(metricString, "Manhattan") == 0) { 17 | annoyIndex = new AnnoyIndex(dimensions); 18 | } 19 | else { 20 | annoyIndex = new AnnoyIndex(dimensions); 21 | } 22 | } 23 | 24 | AnnoyIndexWrapper::~AnnoyIndexWrapper() { 25 | delete annoyIndex; 26 | } 27 | 28 | void AnnoyIndexWrapper::Init(v8::Local exports) { 29 | Nan::HandleScope scope; 30 | 31 | // Prepare constructor template 32 | v8::Local tpl = Nan::New(New); 33 | tpl->SetClassName(Nan::New("Annoy").ToLocalChecked()); 34 | tpl->InstanceTemplate()->SetInternalFieldCount(2); 35 | 36 | // Prototype 37 | // Nan::SetPrototypeMethod(tpl, "value", GetValue); 38 | // Nan::SetPrototypeMethod(tpl, "plusOne", PlusOne); 39 | // Nan::SetPrototypeMethod(tpl, "multiply", Multiply); 40 | Nan::SetPrototypeMethod(tpl, "addItem", AddItem); 41 | Nan::SetPrototypeMethod(tpl, "build", Build); 42 | 
Nan::SetPrototypeMethod(tpl, "save", Save); 43 | Nan::SetPrototypeMethod(tpl, "load", Load); 44 | Nan::SetPrototypeMethod(tpl, "unload", Unload); 45 | Nan::SetPrototypeMethod(tpl, "getItem", GetItem); 46 | Nan::SetPrototypeMethod(tpl, "getNNsByVector", GetNNSByVector); 47 | Nan::SetPrototypeMethod(tpl, "getNNsByItem", GetNNSByItem); 48 | Nan::SetPrototypeMethod(tpl, "getNItems", GetNItems); 49 | Nan::SetPrototypeMethod(tpl, "getDistance", GetDistance); 50 | 51 | constructor.Reset(Nan::GetFunction(tpl).ToLocalChecked()); 52 | Nan::Set(exports, Nan::New("Annoy").ToLocalChecked(), Nan::GetFunction(tpl).ToLocalChecked()); 53 | } 54 | 55 | void AnnoyIndexWrapper::New(const Nan::FunctionCallbackInfo& info) { 56 | 57 | if (info.IsConstructCall()) { 58 | // Invoked as constructor: `new AnnoyIndexWrapper(...)` 59 | double dimensions = info[0]->IsUndefined() ? 0 : info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 60 | Local metricString; 61 | 62 | if (!info[1]->IsUndefined()) { 63 | Nan::MaybeLocal s = Nan::To(info[1]); 64 | if (!s.IsEmpty()) { 65 | metricString = s.ToLocalChecked(); 66 | } 67 | } 68 | 69 | AnnoyIndexWrapper* obj = new AnnoyIndexWrapper( 70 | (int)dimensions, *Nan::Utf8String(metricString) 71 | ); 72 | obj->Wrap(info.This()); 73 | info.GetReturnValue().Set(info.This()); 74 | } 75 | } 76 | 77 | void AnnoyIndexWrapper::AddItem(const Nan::FunctionCallbackInfo& info) { 78 | // Get out object. 79 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 80 | // Get out index. 81 | int index = info[0]->IsUndefined() ? 1 : info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 82 | // Get out array. 83 | float vec[obj->getDimensions()]; 84 | if (getFloatArrayParam(info, 1, vec)) { 85 | obj->annoyIndex->add_item(index, vec); 86 | } 87 | } 88 | 89 | void AnnoyIndexWrapper::Build(const Nan::FunctionCallbackInfo& info) { 90 | // Get out object. 91 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 92 | // Get out numberOfTrees. 
93 | int numberOfTrees = info[0]->IsUndefined() ? 1 : info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 94 | // printf("%s\n", "Calling build"); 95 | obj->annoyIndex->build(numberOfTrees); 96 | } 97 | 98 | void AnnoyIndexWrapper::Save(const Nan::FunctionCallbackInfo& info) { 99 | bool result = false; 100 | 101 | // Get out object. 102 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 103 | // Get out file path. 104 | if (!info[0]->IsUndefined()) { 105 | Nan::MaybeLocal maybeStr = Nan::To(info[0]); 106 | v8::Local str; 107 | if (maybeStr.ToLocal(&str)) { 108 | result = obj->annoyIndex->save(*Nan::Utf8String(str)); 109 | } 110 | } 111 | info.GetReturnValue().Set(Nan::New(result)); 112 | } 113 | 114 | void AnnoyIndexWrapper::Load(const Nan::FunctionCallbackInfo& info) { 115 | bool result = false; 116 | // Get out object. 117 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 118 | // Get out file path. 119 | if (!info[0]->IsUndefined()) { 120 | Nan::MaybeLocal maybeStr = Nan::To(info[0]); 121 | v8::Local str; 122 | if (maybeStr.ToLocal(&str)) { 123 | result = obj->annoyIndex->load(*Nan::Utf8String(str)); 124 | } 125 | } 126 | info.GetReturnValue().Set(Nan::New(result)); 127 | } 128 | 129 | void AnnoyIndexWrapper::Unload(const Nan::FunctionCallbackInfo& info) { 130 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 131 | obj->annoyIndex->unload(); 132 | } 133 | 134 | void AnnoyIndexWrapper::GetItem(const Nan::FunctionCallbackInfo& info) { 135 | Nan::HandleScope scope; 136 | 137 | // Get out object. 138 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 139 | 140 | // Get out index. 141 | int index = info[0]->IsUndefined() ? 1 : info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 142 | 143 | // Get the vector. 144 | int length = obj->getDimensions(); 145 | float vec[length]; 146 | obj->annoyIndex->get_item(index, vec); 147 | 148 | // Allocate the return array. 
149 | Local results = Nan::New(length); 150 | for (int i = 0; i < length; ++i) { 151 | // printf("Adding to array: %f\n", vec[i]); 152 | Nan::Set(results, i, Nan::New(vec[i])); 153 | } 154 | 155 | info.GetReturnValue().Set(results); 156 | } 157 | 158 | void AnnoyIndexWrapper::GetDistance(const Nan::FunctionCallbackInfo& info) { 159 | // Get out object. 160 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 161 | 162 | // Get out indexes. 163 | int indexA = info[0]->IsUndefined() ? 0 : info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 164 | int indexB = info[1]->IsUndefined() ? 0 : info[1]->NumberValue(Nan::GetCurrentContext()).FromJust(); 165 | 166 | // Return the distances. 167 | info.GetReturnValue().Set(obj->annoyIndex->get_distance(indexA, indexB)); 168 | } 169 | 170 | void AnnoyIndexWrapper::GetNNSByVector(const Nan::FunctionCallbackInfo& info) { 171 | Nan::HandleScope scope; 172 | 173 | int numberOfNeighbors, searchK; 174 | bool includeDistances; 175 | getSupplementaryGetNNsParams(info, numberOfNeighbors, searchK, includeDistances); 176 | 177 | // Get out object. 178 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 179 | // Get out array. 180 | float vec[obj->getDimensions()]; 181 | if (!getFloatArrayParam(info, 0, vec)) { 182 | return; 183 | } 184 | 185 | std::vector nnIndexes; 186 | std::vector distances; 187 | std::vector *distancesPtr = nullptr; 188 | 189 | if (includeDistances) { 190 | distancesPtr = &distances; 191 | } 192 | 193 | // Make the call. 194 | obj->annoyIndex->get_nns_by_vector( 195 | vec, numberOfNeighbors, searchK, &nnIndexes, distancesPtr 196 | ); 197 | 198 | setNNReturnValues(numberOfNeighbors, includeDistances, nnIndexes, distances, info); 199 | } 200 | 201 | void AnnoyIndexWrapper::GetNNSByItem(const Nan::FunctionCallbackInfo& info) { 202 | Nan::HandleScope scope; 203 | 204 | // Get out object. 
205 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 206 | 207 | if (info[0]->IsUndefined()) { 208 | return; 209 | } 210 | 211 | // Get out params. 212 | int index = info[0]->NumberValue(Nan::GetCurrentContext()).FromJust(); 213 | int numberOfNeighbors, searchK; 214 | bool includeDistances; 215 | getSupplementaryGetNNsParams(info, numberOfNeighbors, searchK, includeDistances); 216 | 217 | std::vector nnIndexes; 218 | std::vector distances; 219 | std::vector *distancesPtr = nullptr; 220 | 221 | if (includeDistances) { 222 | distancesPtr = &distances; 223 | } 224 | 225 | // Make the call. 226 | obj->annoyIndex->get_nns_by_item( 227 | index, numberOfNeighbors, searchK, &nnIndexes, distancesPtr 228 | ); 229 | 230 | setNNReturnValues(numberOfNeighbors, includeDistances, nnIndexes, distances, info); 231 | } 232 | 233 | void AnnoyIndexWrapper::getSupplementaryGetNNsParams( 234 | const Nan::FunctionCallbackInfo& info, 235 | int& numberOfNeighbors, int& searchK, bool& includeDistances) { 236 | 237 | v8::Local context = Nan::GetCurrentContext(); 238 | 239 | // Get out number of neighbors. 240 | numberOfNeighbors = info[1]->IsUndefined() ? 1 : info[1]->NumberValue(context).FromJust(); 241 | 242 | // Get out searchK. 243 | searchK = info[2]->IsUndefined() ? -1 : info[2]->NumberValue(context).FromJust(); 244 | 245 | // Get out include distances flag. 246 | includeDistances = info[3]->IsUndefined() ? false : Nan::To(info[3]).FromJust(); 247 | } 248 | 249 | void AnnoyIndexWrapper::setNNReturnValues( 250 | int numberOfNeighbors, bool includeDistances, 251 | const std::vector& nnIndexes, const std::vector& distances, 252 | const Nan::FunctionCallbackInfo& info) { 253 | 254 | // Allocate the neighbors array. 
255 | Local jsNNIndexes = Nan::New(numberOfNeighbors); 256 | for (int i = 0; i < numberOfNeighbors; ++i) { 257 | // printf("Adding to neighbors array: %d\n", nnIndexes[i]); 258 | Nan::Set(jsNNIndexes, i, Nan::New(nnIndexes[i])); 259 | } 260 | 261 | Local jsResultObject; 262 | Local jsDistancesArray; 263 | 264 | if (includeDistances) { 265 | // Allocate the distances array. 266 | jsDistancesArray = Nan::New(numberOfNeighbors); 267 | for (int i = 0; i < numberOfNeighbors; ++i) { 268 | // printf("Adding to distances array: %f\n", distances[i]); 269 | Nan::Set(jsDistancesArray, i, Nan::New(distances[i])); 270 | } 271 | 272 | jsResultObject = Nan::New(); 273 | Nan::Set(jsResultObject, Nan::New("neighbors").ToLocalChecked(), jsNNIndexes); 274 | Nan::Set(jsResultObject, Nan::New("distances").ToLocalChecked(), jsDistancesArray); 275 | } 276 | else { 277 | jsResultObject = jsNNIndexes.As(); 278 | } 279 | 280 | info.GetReturnValue().Set(jsResultObject); 281 | } 282 | 283 | void AnnoyIndexWrapper::GetNItems(const Nan::FunctionCallbackInfo& info) { 284 | // Get out object. 285 | AnnoyIndexWrapper* obj = ObjectWrap::Unwrap(info.Holder()); 286 | Local count = Nan::New(obj->annoyIndex->get_n_items()); 287 | info.GetReturnValue().Set(count); 288 | } 289 | 290 | // Returns true if it was able to get items out of the array. false, if not. 291 | bool AnnoyIndexWrapper::getFloatArrayParam( 292 | const Nan::FunctionCallbackInfo& info, int paramIndex, float *vec) { 293 | 294 | bool succeeded = false; 295 | 296 | Local val; 297 | if (info[paramIndex]->IsArray()) { 298 | // TODO: Make sure it really is OK to use Local instead of Handle here. 
299 | Local jsArray = Local::Cast(info[paramIndex]); 300 | Local val; 301 | for (unsigned int i = 0; i < jsArray->Length(); i++) { 302 | val = Nan::Get(jsArray, i).ToLocalChecked(); 303 | // printf("Adding item to array: %f\n", (float)val->NumberValue(Nan::GetCurrentContext()).FromJust()); 304 | vec[i] = (float)val->NumberValue(Nan::GetCurrentContext()).FromJust(); 305 | } 306 | succeeded = true; 307 | } 308 | 309 | return succeeded; 310 | } 311 | 312 | int AnnoyIndexWrapper::getDimensions() { 313 | return annoyDimensions; 314 | } 315 | 316 | -------------------------------------------------------------------------------- /annoylib.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2013 Spotify AB 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 | // use this file except in compliance with the License. You may obtain a copy of 5 | // the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 | // License for the specific language governing permissions and limitations under 13 | // the License. 
14 | 15 | 16 | #ifndef ANNOYLIB_H 17 | #define ANNOYLIB_H 18 | 19 | #include 20 | #include 21 | #ifndef _MSC_VER 22 | #include 23 | #endif 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | #if defined(_MSC_VER) && _MSC_VER == 1500 31 | typedef unsigned char uint8_t; 32 | typedef signed __int32 int32_t; 33 | typedef unsigned __int64 uint64_t; 34 | typedef signed __int64 int64_t; 35 | #else 36 | #include 37 | #endif 38 | 39 | #if defined(_MSC_VER) || defined(__MINGW32__) 40 | // a bit hacky, but override some definitions to support 64 bit 41 | #define off_t int64_t 42 | #define lseek_getsize(fd) _lseeki64(fd, 0, SEEK_END) 43 | #ifndef NOMINMAX 44 | #define NOMINMAX 45 | #endif 46 | #include "mman.h" 47 | #include 48 | #else 49 | #include 50 | #define lseek_getsize(fd) lseek(fd, 0, SEEK_END) 51 | #endif 52 | 53 | #include 54 | #include 55 | #include 56 | #include 57 | #include 58 | #include 59 | #include 60 | 61 | #ifdef _MSC_VER 62 | // Needed for Visual Studio to disable runtime checks for memcpy 63 | #pragma runtime_checks("s", off) 64 | #endif 65 | 66 | // This allows others to supply their own logger / error printer without 67 | // requiring Annoy to import their headers. See RcppAnnoy for a use case. 68 | #ifndef __ERROR_PRINTER_OVERRIDE__ 69 | #define showUpdate(...) { fprintf(stderr, __VA_ARGS__ ); } 70 | #else 71 | #define showUpdate(...)
{ __ERROR_PRINTER_OVERRIDE__( __VA_ARGS__ ); } 72 | #endif 73 | 74 | // Portable alloc definition, cf Writing R Extensions, Section 1.6.4 75 | #ifdef __GNUC__ 76 | // Includes GCC, clang and Intel compilers 77 | # undef alloca 78 | # define alloca(x) __builtin_alloca((x)) 79 | #elif defined(__sun) || defined(_AIX) 80 | // this is necessary (and sufficient) for Solaris 10 and AIX 6: 81 | # include 82 | #endif 83 | // Logs "msg: strerror(errno) (errno)" via showUpdate and, when `error` is non-NULL, stores a malloc'd copy of that text in *error; the caller is responsible for free()ing it. 84 | inline void set_error_from_errno(char **error, const char* msg) { 85 | showUpdate("%s: %s (%d)\n", msg, strerror(errno), errno); 86 | if (error) { 87 | *error = (char *)malloc(256); // TODO: win doesn't support snprintf. NOTE(review): fixed 256-byte buffer filled by an unbounded sprintf below - a long msg plus strerror text would overflow it; the malloc result is unchecked, and errno is re-read after showUpdate (fprintf) / malloc may have changed it. 88 | sprintf(*error, "%s: %s (%d)", msg, strerror(errno), errno); 89 | } 90 | } 91 | // Logs msg via showUpdate and, when `error` is non-NULL, stores a malloc'd copy of msg in *error; the caller is responsible for free()ing it. NOTE(review): the malloc result is not checked before strcpy. 92 | inline void set_error_from_string(char **error, const char* msg) { 93 | showUpdate("%s\n", msg); 94 | if (error) { 95 | *error = (char *)malloc(strlen(msg) + 1); 96 | strcpy(*error, msg); 97 | } 98 | } 99 | 100 | // We let the v array in the Node struct take whatever space is needed, so this is a mostly insignificant number. 101 | // Compilers need *some* size defined for the v array, and some memory checking tools will flag for buffer overruns if this is set too low.
102 | #define V_ARRAY_SIZE 65536 103 | 104 | #ifndef _MSC_VER 105 | #define popcount __builtin_popcountll 106 | #else // See #293, #358 107 | #define isnan(x) _isnan(x) 108 | #define popcount cole_popcount 109 | #endif 110 | 111 | #if !defined(NO_MANUAL_VECTORIZATION) && defined(__GNUC__) && (__GNUC__ >6) && defined(__AVX512F__) // See #402 112 | #define USE_AVX512 113 | #elif !defined(NO_MANUAL_VECTORIZATION) && defined(__AVX__) && defined (__SSE__) && defined(__SSE2__) && defined(__SSE3__) 114 | #define USE_AVX 115 | #else 116 | #endif 117 | 118 | #if defined(USE_AVX) || defined(USE_AVX512) 119 | #if defined(_MSC_VER) 120 | #include 121 | #elif defined(__GNUC__) 122 | #include 123 | #endif 124 | #endif 125 | 126 | 127 | using std::vector; 128 | using std::pair; 129 | using std::numeric_limits; 130 | using std::make_pair; 131 | 132 | inline void* remap_memory(void* _ptr, int _fd, size_t old_size, size_t new_size) { 133 | #ifdef __linux__ 134 | _ptr = mremap(_ptr, old_size, new_size, MREMAP_MAYMOVE); 135 | #else 136 | munmap(_ptr, old_size); 137 | #ifdef MAP_POPULATE 138 | _ptr = mmap(_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0); 139 | #else 140 | _ptr = mmap(_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0); 141 | #endif 142 | #endif 143 | return _ptr; 144 | } 145 | 146 | namespace { 147 | 148 | template 149 | inline Node* get_node_ptr(const void* _nodes, const size_t _s, const S i) { 150 | return (Node*)((uint8_t *)_nodes + (_s * i)); 151 | } 152 | 153 | template 154 | inline T dot(const T* x, const T* y, int f) { 155 | T s = 0; 156 | for (int z = 0; z < f; z++) { 157 | s += (*x) * (*y); 158 | x++; 159 | y++; 160 | } 161 | return s; 162 | } 163 | 164 | template 165 | inline T manhattan_distance(const T* x, const T* y, int f) { 166 | T d = 0.0; 167 | for (int i = 0; i < f; i++) 168 | d += fabs(x[i] - y[i]); 169 | return d; 170 | } 171 | 172 | template 173 | inline T euclidean_distance(const T* x, const T* y, int f) { 174 | 
// Don't use dot-product: avoid catastrophic cancellation in #314. 175 | T d = 0.0; 176 | for (int i = 0; i < f; ++i) { 177 | const T tmp=*x - *y; 178 | d += tmp * tmp; 179 | ++x; 180 | ++y; 181 | } 182 | return d; 183 | } 184 | 185 | #ifdef USE_AVX 186 | // Horizontal single sum of 256bit vector. 187 | inline float hsum256_ps_avx(__m256 v) { 188 | const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v)); 189 | const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); 190 | const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); 191 | return _mm_cvtss_f32(x32); 192 | } 193 | 194 | template<> 195 | inline float dot(const float* x, const float *y, int f) { 196 | float result = 0; 197 | if (f > 7) { 198 | __m256 d = _mm256_setzero_ps(); 199 | for (; f > 7; f -= 8) { 200 | d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y))); 201 | x += 8; 202 | y += 8; 203 | } 204 | // Sum all floats in dot register. 205 | result += hsum256_ps_avx(d); 206 | } 207 | // Don't forget the remaining values. 208 | for (; f > 0; f--) { 209 | result += *x * *y; 210 | x++; 211 | y++; 212 | } 213 | return result; 214 | } 215 | 216 | template<> 217 | inline float manhattan_distance(const float* x, const float* y, int f) { 218 | float result = 0; 219 | int i = f; 220 | if (f > 7) { 221 | __m256 manhattan = _mm256_setzero_ps(); 222 | __m256 minus_zero = _mm256_set1_ps(-0.0f); 223 | for (; i > 7; i -= 8) { 224 | const __m256 x_minus_y = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)); 225 | const __m256 distance = _mm256_andnot_ps(minus_zero, x_minus_y); // Absolute value of x_minus_y (forces sign bit to zero) 226 | manhattan = _mm256_add_ps(manhattan, distance); 227 | x += 8; 228 | y += 8; 229 | } 230 | // Sum all floats in manhattan register. 231 | result = hsum256_ps_avx(manhattan); 232 | } 233 | // Don't forget the remaining values. 
234 | for (; i > 0; i--) { 235 | result += fabsf(*x - *y); 236 | x++; 237 | y++; 238 | } 239 | return result; 240 | } 241 | 242 | template<> 243 | inline float euclidean_distance(const float* x, const float* y, int f) { 244 | float result=0; 245 | if (f > 7) { 246 | __m256 d = _mm256_setzero_ps(); 247 | for (; f > 7; f -= 8) { 248 | const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)); 249 | d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX... 250 | x += 8; 251 | y += 8; 252 | } 253 | // Sum all floats in dot register. 254 | result = hsum256_ps_avx(d); 255 | } 256 | // Don't forget the remaining values. 257 | for (; f > 0; f--) { 258 | float tmp = *x - *y; 259 | result += tmp * tmp; 260 | x++; 261 | y++; 262 | } 263 | return result; 264 | } 265 | 266 | #endif 267 | 268 | #ifdef USE_AVX512 269 | template<> 270 | inline float dot(const float* x, const float *y, int f) { 271 | float result = 0; 272 | if (f > 15) { 273 | __m512 d = _mm512_setzero_ps(); 274 | for (; f > 15; f -= 16) { 275 | //AVX512F includes FMA 276 | d = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), d); 277 | x += 16; 278 | y += 16; 279 | } 280 | // Sum all floats in dot register. 281 | result += _mm512_reduce_add_ps(d); 282 | } 283 | // Don't forget the remaining values. 284 | for (; f > 0; f--) { 285 | result += *x * *y; 286 | x++; 287 | y++; 288 | } 289 | return result; 290 | } 291 | 292 | template<> 293 | inline float manhattan_distance(const float* x, const float* y, int f) { 294 | float result = 0; 295 | int i = f; 296 | if (f > 15) { 297 | __m512 manhattan = _mm512_setzero_ps(); 298 | for (; i > 15; i -= 16) { 299 | const __m512 x_minus_y = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y)); 300 | manhattan = _mm512_add_ps(manhattan, _mm512_abs_ps(x_minus_y)); 301 | x += 16; 302 | y += 16; 303 | } 304 | // Sum all floats in manhattan register. 
305 | result = _mm512_reduce_add_ps(manhattan); 306 | } 307 | // Don't forget the remaining values. 308 | for (; i > 0; i--) { 309 | result += fabsf(*x - *y); 310 | x++; 311 | y++; 312 | } 313 | return result; 314 | } 315 | 316 | template<> 317 | inline float euclidean_distance(const float* x, const float* y, int f) { 318 | float result=0; 319 | if (f > 15) { 320 | __m512 d = _mm512_setzero_ps(); 321 | for (; f > 15; f -= 16) { 322 | const __m512 diff = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y)); 323 | d = _mm512_fmadd_ps(diff, diff, d); 324 | x += 16; 325 | y += 16; 326 | } 327 | // Sum all floats in dot register. 328 | result = _mm512_reduce_add_ps(d); 329 | } 330 | // Don't forget the remaining values. 331 | for (; f > 0; f--) { 332 | float tmp = *x - *y; 333 | result += tmp * tmp; 334 | x++; 335 | y++; 336 | } 337 | return result; 338 | } 339 | 340 | #endif 341 | 342 | 343 | template 344 | inline T get_norm(T* v, int f) { 345 | return sqrt(dot(v, v, f)); 346 | } 347 | 348 | template 349 | inline void two_means(const vector& nodes, int f, Random& random, bool cosine, Node* p, Node* q) { 350 | /* 351 | This algorithm is a huge heuristic. Empirically it works really well, but I 352 | can't motivate it well. The basic idea is to keep two centroids and assign 353 | points to either one of them. We weight each centroid by the number of points 354 | assigned to it, so to balance it. 
355 | */ 356 | static int iteration_steps = 200; 357 | size_t count = nodes.size(); 358 | 359 | size_t i = random.index(count); 360 | size_t j = random.index(count-1); 361 | j += (j >= i); // ensure that i != j 362 | 363 | Distance::template copy_node(p, nodes[i], f); 364 | Distance::template copy_node(q, nodes[j], f); 365 | 366 | if (cosine) { Distance::template normalize(p, f); Distance::template normalize(q, f); } 367 | Distance::init_node(p, f); 368 | Distance::init_node(q, f); 369 | 370 | int ic = 1, jc = 1; 371 | for (int l = 0; l < iteration_steps; l++) { 372 | size_t k = random.index(count); 373 | T di = ic * Distance::distance(p, nodes[k], f), 374 | dj = jc * Distance::distance(q, nodes[k], f); 375 | T norm = cosine ? get_norm(nodes[k]->v, f) : 1; 376 | if (!(norm > T(0))) { 377 | continue; 378 | } 379 | if (di < dj) { 380 | for (int z = 0; z < f; z++) 381 | p->v[z] = (p->v[z] * ic + nodes[k]->v[z] / norm) / (ic + 1); 382 | Distance::init_node(p, f); 383 | ic++; 384 | } else if (dj < di) { 385 | for (int z = 0; z < f; z++) 386 | q->v[z] = (q->v[z] * jc + nodes[k]->v[z] / norm) / (jc + 1); 387 | Distance::init_node(q, f); 388 | jc++; 389 | } 390 | } 391 | } 392 | } // namespace 393 | 394 | struct Base { 395 | template 396 | static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) { 397 | // Override this in specific metric structs below if you need to do any pre-processing 398 | // on the entire set of nodes passed into this index. 399 | } 400 | 401 | template 402 | static inline void zero_value(Node* dest) { 403 | // Initialize any fields that require sane defaults within this node. 
404 | } 405 | 406 | template 407 | static inline void copy_node(Node* dest, const Node* source, const int f) { 408 | memcpy(dest->v, source->v, f * sizeof(T)); 409 | } 410 | 411 | template 412 | static inline void normalize(Node* node, int f) { 413 | T norm = get_norm(node->v, f); 414 | if (norm > 0) { 415 | for (int z = 0; z < f; z++) 416 | node->v[z] /= norm; 417 | } 418 | } 419 | }; 420 | 421 | struct Angular : Base { 422 | template 423 | struct Node { 424 | /* 425 | * We store a binary tree where each node has two things 426 | * - A vector associated with it 427 | * - Two children 428 | * All nodes occupy the same amount of memory 429 | * All nodes with n_descendants == 1 are leaf nodes. 430 | * A memory optimization is that for nodes with 2 <= n_descendants <= K, 431 | * we skip the vector. Instead we store a list of all descendants. K is 432 | * determined by the number of items that fits in the space of the vector. 433 | * For nodes with n_descendants == 1 the vector is a data point. 434 | * For nodes with n_descendants > K the vector is the normal of the split plane. 435 | * Note that we can't really do sizeof(node) because we cheat and allocate 436 | * more memory to be able to fit the vector outside 437 | */ 438 | S n_descendants; 439 | union { 440 | S children[2]; // Will possibly store more than 2 441 | T norm; 442 | }; 443 | T v[V_ARRAY_SIZE]; 444 | }; 445 | template 446 | static inline T distance(const Node* x, const Node* y, int f) { 447 | // want to calculate (a/|a| - b/|b|)^2 448 | // = a^2 / a^2 + b^2 / b^2 - 2ab/|a||b| 449 | // = 2 - 2cos 450 | T pp = x->norm ? x->norm : dot(x->v, x->v, f); // For backwards compatibility reasons, we need to fall back and compute the norm here 451 | T qq = y->norm ? 
y->norm : dot(y->v, y->v, f); 452 | T pq = dot(x->v, y->v, f); 453 | T ppqq = pp * qq; 454 | if (ppqq > 0) return 2.0 - 2.0 * pq / sqrt(ppqq); 455 | else return 2.0; // cos is 0 456 | } 457 | template 458 | static inline T margin(const Node* n, const T* y, int f) { 459 | return dot(n->v, y, f); 460 | } 461 | template 462 | static inline bool side(const Node* n, const T* y, int f, Random& random) { 463 | T dot = margin(n, y, f); 464 | if (dot != 0) 465 | return (dot > 0); 466 | else 467 | return (bool)random.flip(); 468 | } 469 | template 470 | static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { 471 | Node* p = (Node*)alloca(s); 472 | Node* q = (Node*)alloca(s); 473 | two_means >(nodes, f, random, true, p, q); 474 | for (int z = 0; z < f; z++) 475 | n->v[z] = p->v[z] - q->v[z]; 476 | Base::normalize >(n, f); 477 | } 478 | template 479 | static inline T normalized_distance(T distance) { 480 | // Used when requesting distances from Python layer 481 | // Turns out sometimes the squared distance is -0.0 482 | // so we have to make sure it's a positive number. 483 | return sqrt(std::max(distance, T(0))); 484 | } 485 | template 486 | static inline T pq_distance(T distance, T margin, int child_nr) { 487 | if (child_nr == 0) 488 | margin = -margin; 489 | return std::min(distance, margin); 490 | } 491 | template 492 | static inline T pq_initial_value() { 493 | return numeric_limits::infinity(); 494 | } 495 | template 496 | static inline void init_node(Node* n, int f) { 497 | n->norm = dot(n->v, n->v, f); 498 | } 499 | static const char* name() { 500 | return "angular"; 501 | } 502 | }; 503 | 504 | 505 | struct DotProduct : Angular { 506 | template 507 | struct Node { 508 | /* 509 | * This is an extension of the Angular node with an extra attribute for the scaled norm. 
510 | */ 511 | S n_descendants; 512 | S children[2]; // Will possibly store more than 2 513 | T dot_factor; 514 | T v[V_ARRAY_SIZE]; 515 | }; 516 | 517 | static const char* name() { 518 | return "dot"; 519 | } 520 | template 521 | static inline T distance(const Node* x, const Node* y, int f) { 522 | return -dot(x->v, y->v, f); 523 | } 524 | 525 | template 526 | static inline void zero_value(Node* dest) { 527 | dest->dot_factor = 0; 528 | } 529 | 530 | template 531 | static inline void init_node(Node* n, int f) { 532 | } 533 | 534 | template 535 | static inline void copy_node(Node* dest, const Node* source, const int f) { 536 | memcpy(dest->v, source->v, f * sizeof(T)); 537 | dest->dot_factor = source->dot_factor; 538 | } 539 | 540 | template 541 | static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { 542 | Node* p = (Node*)alloca(s); 543 | Node* q = (Node*)alloca(s); 544 | DotProduct::zero_value(p); 545 | DotProduct::zero_value(q); 546 | two_means >(nodes, f, random, true, p, q); 547 | for (int z = 0; z < f; z++) 548 | n->v[z] = p->v[z] - q->v[z]; 549 | n->dot_factor = p->dot_factor - q->dot_factor; 550 | DotProduct::normalize >(n, f); 551 | } 552 | 553 | template 554 | static inline void normalize(Node* node, int f) { 555 | T norm = sqrt(dot(node->v, node->v, f) + pow(node->dot_factor, 2)); 556 | if (norm > 0) { 557 | for (int z = 0; z < f; z++) 558 | node->v[z] /= norm; 559 | node->dot_factor /= norm; 560 | } 561 | } 562 | 563 | template 564 | static inline T margin(const Node* n, const T* y, int f) { 565 | return dot(n->v, y, f) + (n->dot_factor * n->dot_factor); 566 | } 567 | 568 | template 569 | static inline bool side(const Node* n, const T* y, int f, Random& random) { 570 | T dot = margin(n, y, f); 571 | if (dot != 0) 572 | return (dot > 0); 573 | else 574 | return (bool)random.flip(); 575 | } 576 | 577 | template 578 | static inline T normalized_distance(T distance) { 579 | return -distance; 580 | } 581 | 582 | 
template 583 | static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) { 584 | // This uses a method from Microsoft Research for transforming inner product spaces to cosine/angular-compatible spaces. 585 | // (Bachrach et al., 2014, see https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf) 586 | 587 | // Step one: compute the norm of each vector and store that in its extra dimension (f-1) 588 | for (S i = 0; i < node_count; i++) { 589 | Node* node = get_node_ptr(nodes, _s, i); 590 | T norm = sqrt(dot(node->v, node->v, f)); 591 | if (isnan(norm)) norm = 0; 592 | node->dot_factor = norm; 593 | } 594 | 595 | // Step two: find the maximum norm 596 | T max_norm = 0; 597 | for (S i = 0; i < node_count; i++) { 598 | Node* node = get_node_ptr(nodes, _s, i); 599 | if (node->dot_factor > max_norm) { 600 | max_norm = node->dot_factor; 601 | } 602 | } 603 | 604 | // Step three: set each vector's extra dimension to sqrt(max_norm^2 - norm^2) 605 | for (S i = 0; i < node_count; i++) { 606 | Node* node = get_node_ptr(nodes, _s, i); 607 | T node_norm = node->dot_factor; 608 | 609 | T dot_factor = sqrt(pow(max_norm, static_cast(2.0)) - pow(node_norm, static_cast(2.0))); 610 | if (isnan(dot_factor)) dot_factor = 0; 611 | 612 | node->dot_factor = dot_factor; 613 | } 614 | } 615 | }; 616 | 617 | struct Hamming : Base { 618 | template 619 | struct Node { 620 | S n_descendants; 621 | S children[2]; 622 | T v[V_ARRAY_SIZE]; 623 | }; 624 | 625 | static const size_t max_iterations = 20; 626 | 627 | template 628 | static inline T pq_distance(T distance, T margin, int child_nr) { 629 | return distance - (margin != (unsigned int) child_nr); 630 | } 631 | 632 | template 633 | static inline T pq_initial_value() { 634 | return numeric_limits::max(); 635 | } 636 | template 637 | static inline int cole_popcount(T v) { 638 | // Note: Only used with MSVC 9, which lacks intrinsics and fails to 639 | // calculate std::bitset::count for 
v > 32bit. Uses the generalized 640 | // approach by Eric Cole. 641 | // See https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64 642 | v = v - ((v >> 1) & (T)~(T)0/3); 643 | v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3); 644 | v = (v + (v >> 4)) & (T)~(T)0/255*15; 645 | return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * 8; 646 | } 647 | template 648 | static inline T distance(const Node* x, const Node* y, int f) { 649 | size_t dist = 0; 650 | for (int i = 0; i < f; i++) { 651 | dist += popcount(x->v[i] ^ y->v[i]); 652 | } 653 | return dist; 654 | } 655 | template 656 | static inline bool margin(const Node* n, const T* y, int f) { 657 | static const size_t n_bits = sizeof(T) * 8; 658 | T chunk = n->v[0] / n_bits; 659 | return (y[chunk] & (static_cast(1) << (n_bits - 1 - (n->v[0] % n_bits)))) != 0; 660 | } 661 | template 662 | static inline bool side(const Node* n, const T* y, int f, Random& random) { 663 | return margin(n, y, f); 664 | } 665 | template 666 | static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { 667 | size_t cur_size = 0; 668 | size_t i = 0; 669 | int dim = f * 8 * sizeof(T); 670 | for (; i < max_iterations; i++) { 671 | // choose random position to split at 672 | n->v[0] = random.index(dim); 673 | cur_size = 0; 674 | for (typename vector*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) { 675 | if (margin(n, (*it)->v, f)) { 676 | cur_size++; 677 | } 678 | } 679 | if (cur_size > 0 && cur_size < nodes.size()) { 680 | break; 681 | } 682 | } 683 | // brute-force search for splitting coordinate 684 | if (i == max_iterations) { 685 | int j = 0; 686 | for (; j < dim; j++) { 687 | n->v[0] = j; 688 | cur_size = 0; 689 | for (typename vector*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) { 690 | if (margin(n, (*it)->v, f)) { 691 | cur_size++; 692 | } 693 | } 694 | if (cur_size > 0 && cur_size < nodes.size()) { 695 | break; 696 | } 697 | } 698 | } 699 | } 700 | 
template 701 | static inline T normalized_distance(T distance) { 702 | return distance; 703 | } 704 | template 705 | static inline void init_node(Node* n, int f) { 706 | } 707 | static const char* name() { 708 | return "hamming"; 709 | } 710 | }; 711 | 712 | 713 | struct Minkowski : Base { 714 | template 715 | struct Node { 716 | S n_descendants; 717 | T a; // need an extra constant term to determine the offset of the plane 718 | S children[2]; 719 | T v[V_ARRAY_SIZE]; 720 | }; 721 | template 722 | static inline T margin(const Node* n, const T* y, int f) { 723 | return n->a + dot(n->v, y, f); 724 | } 725 | template 726 | static inline bool side(const Node* n, const T* y, int f, Random& random) { 727 | T dot = margin(n, y, f); 728 | if (dot != 0) 729 | return (dot > 0); 730 | else 731 | return (bool)random.flip(); 732 | } 733 | template 734 | static inline T pq_distance(T distance, T margin, int child_nr) { 735 | if (child_nr == 0) 736 | margin = -margin; 737 | return std::min(distance, margin); 738 | } 739 | template 740 | static inline T pq_initial_value() { 741 | return numeric_limits::infinity(); 742 | } 743 | }; 744 | 745 | 746 | struct Euclidean : Minkowski { 747 | template 748 | static inline T distance(const Node* x, const Node* y, int f) { 749 | return euclidean_distance(x->v, y->v, f); 750 | } 751 | template 752 | static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { 753 | Node* p = (Node*)alloca(s); 754 | Node* q = (Node*)alloca(s); 755 | two_means >(nodes, f, random, false, p, q); 756 | 757 | for (int z = 0; z < f; z++) 758 | n->v[z] = p->v[z] - q->v[z]; 759 | Base::normalize >(n, f); 760 | n->a = 0.0; 761 | for (int z = 0; z < f; z++) 762 | n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2; 763 | } 764 | template 765 | static inline T normalized_distance(T distance) { 766 | return sqrt(std::max(distance, T(0))); 767 | } 768 | template 769 | static inline void init_node(Node* n, int f) { 770 | } 771 | static const 
char* name() { 772 | return "euclidean"; 773 | } 774 | 775 | }; 776 | 777 | struct Manhattan : Minkowski { 778 | template 779 | static inline T distance(const Node* x, const Node* y, int f) { 780 | return manhattan_distance(x->v, y->v, f); 781 | } 782 | template 783 | static inline void create_split(const vector*>& nodes, int f, size_t s, Random& random, Node* n) { 784 | Node* p = (Node*)alloca(s); 785 | Node* q = (Node*)alloca(s); 786 | two_means >(nodes, f, random, false, p, q); 787 | 788 | for (int z = 0; z < f; z++) 789 | n->v[z] = p->v[z] - q->v[z]; 790 | Base::normalize >(n, f); 791 | n->a = 0.0; 792 | for (int z = 0; z < f; z++) 793 | n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2; 794 | } 795 | template 796 | static inline T normalized_distance(T distance) { 797 | return std::max(distance, T(0)); 798 | } 799 | template 800 | static inline void init_node(Node* n, int f) { 801 | } 802 | static const char* name() { 803 | return "manhattan"; 804 | } 805 | }; 806 | 807 | template 808 | class AnnoyIndexInterface { 809 | public: 810 | // Note that the methods with an **error argument will allocate memory and write the pointer to that string if error is non-NULL 811 | virtual ~AnnoyIndexInterface() {}; 812 | virtual bool add_item(S item, const T* w, char** error=NULL) = 0; 813 | virtual bool build(int q, char** error=NULL) = 0; 814 | virtual bool unbuild(char** error=NULL) = 0; 815 | virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0; 816 | virtual void unload() = 0; 817 | virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0; 818 | virtual T get_distance(S i, S j) const = 0; 819 | virtual void get_nns_by_item(S item, size_t n, int search_k, vector* result, vector* distances) const = 0; 820 | virtual void get_nns_by_vector(const T* w, size_t n, int search_k, vector* result, vector* distances) const = 0; 821 | virtual S get_n_items() const = 0; 822 | virtual S get_n_trees() const = 0; 823 | virtual void 
verbose(bool v) = 0; 824 | virtual void get_item(S item, T* v) const = 0; 825 | virtual void set_seed(int q) = 0; 826 | virtual bool on_disk_build(const char* filename, char** error=NULL) = 0; 827 | }; 828 | 829 | template 830 | class AnnoyIndex : public AnnoyIndexInterface { 831 | /* 832 | * We use random projection to build a forest of binary trees of all items. 833 | * Basically just split the hyperspace into two sides by a hyperplane, 834 | * then recursively split each of those subtrees etc. 835 | * We create a tree like this q times. The default q is determined automatically 836 | * in such a way that we at most use 2x as much memory as the vectors take. 837 | */ 838 | public: 839 | typedef Distance D; 840 | typedef typename D::template Node Node; 841 | 842 | protected: 843 | const int _f; 844 | size_t _s; 845 | S _n_items; 846 | Random _random; 847 | void* _nodes; // Could either be mmapped, or point to a memory buffer that we reallocate 848 | S _n_nodes; 849 | S _nodes_size; 850 | vector _roots; 851 | S _K; 852 | bool _loaded; 853 | bool _verbose; 854 | int _fd; 855 | bool _on_disk; 856 | bool _built; 857 | public: 858 | 859 | AnnoyIndex(int f) : _f(f), _random() { 860 | _s = offsetof(Node, v) + _f * sizeof(T); // Size of each node 861 | _verbose = false; 862 | _built = false; 863 | _K = (S) (((size_t) (_s - offsetof(Node, children))) / sizeof(S)); // Max number of descendants to fit into node 864 | reinitialize(); // Reset everything 865 | } 866 | ~AnnoyIndex() { 867 | unload(); 868 | } 869 | 870 | int get_f() const { 871 | return _f; 872 | } 873 | 874 | bool add_item(S item, const T* w, char** error=NULL) { 875 | return add_item_impl(item, w, error); 876 | } 877 | 878 | template 879 | bool add_item_impl(S item, const W& w, char** error=NULL) { 880 | if (_loaded) { 881 | set_error_from_string(error, "You can't add an item to a loaded index"); 882 | return false; 883 | } 884 | _allocate_size(item + 1); 885 | Node* n = _get(item); 886 | 887 | 
D::zero_value(n); 888 | 889 | n->children[0] = 0; 890 | n->children[1] = 0; 891 | n->n_descendants = 1; 892 | 893 | for (int z = 0; z < _f; z++) 894 | n->v[z] = w[z]; 895 | 896 | D::init_node(n, _f); 897 | 898 | if (item >= _n_items) 899 | _n_items = item + 1; 900 | 901 | return true; 902 | } 903 | 904 | bool on_disk_build(const char* file, char** error=NULL) { 905 | _on_disk = true; 906 | _fd = open(file, O_RDWR | O_CREAT | O_TRUNC, (int) 0600); 907 | if (_fd == -1) { 908 | set_error_from_errno(error, "Unable to open"); 909 | _fd = 0; 910 | return false; 911 | } 912 | _nodes_size = 1; 913 | if (ftruncate(_fd, _s * _nodes_size) == -1) { 914 | set_error_from_errno(error, "Unable to truncate"); 915 | return false; 916 | } 917 | #ifdef MAP_POPULATE 918 | _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0); 919 | #else 920 | _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0); 921 | #endif 922 | return true; 923 | } 924 | 925 | bool build(int q, char** error=NULL) { 926 | if (_loaded) { 927 | set_error_from_string(error, "You can't build a loaded index"); 928 | return false; 929 | } 930 | 931 | if (_built) { 932 | set_error_from_string(error, "You can't build a built index"); 933 | return false; 934 | } 935 | 936 | D::template preprocess(_nodes, _s, _n_items, _f); 937 | 938 | _n_nodes = _n_items; 939 | while (1) { 940 | if (q == -1 && _n_nodes >= _n_items * 2) 941 | break; 942 | if (q != -1 && _roots.size() >= (size_t)q) 943 | break; 944 | if (_verbose) showUpdate("pass %zd...\n", _roots.size()); 945 | 946 | vector indices; 947 | for (S i = 0; i < _n_items; i++) { 948 | if (_get(i)->n_descendants >= 1) // Issue #223 949 | indices.push_back(i); 950 | } 951 | 952 | _roots.push_back(_make_tree(indices, true)); 953 | } 954 | 955 | // Also, copy the roots into the last segment of the array 956 | // This way we can load them faster without reading the whole file 957 | _allocate_size(_n_nodes 
+ (S)_roots.size()); 958 | for (size_t i = 0; i < _roots.size(); i++) 959 | memcpy(_get(_n_nodes + (S)i), _get(_roots[i]), _s); 960 | _n_nodes += _roots.size(); 961 | 962 | if (_verbose) showUpdate("has %d nodes\n", _n_nodes); 963 | 964 | if (_on_disk) { 965 | _nodes = remap_memory(_nodes, _fd, _s * _nodes_size, _s * _n_nodes); 966 | if (ftruncate(_fd, _s * _n_nodes)) { 967 | // TODO: this probably creates an index in a corrupt state... not sure what to do 968 | set_error_from_errno(error, "Unable to truncate"); 969 | return false; 970 | } 971 | _nodes_size = _n_nodes; 972 | } 973 | _built = true; 974 | return true; 975 | } 976 | 977 | bool unbuild(char** error=NULL) { 978 | if (_loaded) { 979 | set_error_from_string(error, "You can't unbuild a loaded index"); 980 | return false; 981 | } 982 | 983 | _roots.clear(); 984 | _n_nodes = _n_items; 985 | _built = false; 986 | 987 | return true; 988 | } 989 | 990 | bool save(const char* filename, bool prefault=false, char** error=NULL) { 991 | if (!_built) { 992 | set_error_from_string(error, "You can't save an index that hasn't been built"); 993 | return false; 994 | } 995 | if (_on_disk) { 996 | return true; 997 | } else { 998 | // Delete file if it already exists (See issue #335) 999 | unlink(filename); 1000 | 1001 | FILE *f = fopen(filename, "wb"); 1002 | if (f == NULL) { 1003 | set_error_from_errno(error, "Unable to open"); 1004 | return false; 1005 | } 1006 | 1007 | if (fwrite(_nodes, _s, _n_nodes, f) != (size_t) _n_nodes) { 1008 | set_error_from_errno(error, "Unable to write"); 1009 | return false; 1010 | } 1011 | 1012 | if (fclose(f) == EOF) { 1013 | set_error_from_errno(error, "Unable to close"); 1014 | return false; 1015 | } 1016 | 1017 | unload(); 1018 | return load(filename, prefault, error); 1019 | } 1020 | } 1021 | 1022 | void reinitialize() { 1023 | _fd = 0; 1024 | _nodes = NULL; 1025 | _loaded = false; 1026 | _n_items = 0; 1027 | _n_nodes = 0; 1028 | _nodes_size = 0; 1029 | _on_disk = false; 1030 | 
_roots.clear(); 1031 | } 1032 | 1033 | void unload() { 1034 | if (_on_disk && _fd) { 1035 | close(_fd); 1036 | munmap(_nodes, _s * _nodes_size); 1037 | } else { 1038 | if (_fd) { 1039 | // we have mmapped data 1040 | close(_fd); 1041 | munmap(_nodes, _n_nodes * _s); 1042 | } else if (_nodes) { 1043 | // We have heap allocated data 1044 | free(_nodes); 1045 | } 1046 | } 1047 | reinitialize(); 1048 | if (_verbose) showUpdate("unloaded\n"); 1049 | } 1050 | 1051 | bool load(const char* filename, bool prefault=false, char** error=NULL) { 1052 | _fd = open(filename, O_RDONLY, (int)0400); 1053 | if (_fd == -1) { 1054 | set_error_from_errno(error, "Unable to open"); 1055 | _fd = 0; 1056 | return false; 1057 | } 1058 | off_t size = lseek_getsize(_fd); 1059 | if (size == -1) { 1060 | set_error_from_errno(error, "Unable to get size"); 1061 | return false; 1062 | } else if (size == 0) { 1063 | set_error_from_errno(error, "Size of file is zero"); 1064 | return false; 1065 | } else if (size % _s) { 1066 | // Something is fishy with this index! 1067 | set_error_from_errno(error, "Index size is not a multiple of vector size. 
Ensure you are opening using the same metric you used to create the index."); 1068 | return false; 1069 | } 1070 | 1071 | int flags = MAP_SHARED; 1072 | if (prefault) { 1073 | #ifdef MAP_POPULATE 1074 | flags |= MAP_POPULATE; 1075 | #else 1076 | showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform"); 1077 | #endif 1078 | } 1079 | _nodes = (Node*)mmap(0, size, PROT_READ, flags, _fd, 0); 1080 | _n_nodes = (S)(size / _s); 1081 | 1082 | // Find the roots by scanning the end of the file and taking the nodes with most descendants 1083 | _roots.clear(); 1084 | S m = -1; 1085 | for (S i = _n_nodes - 1; i >= 0; i--) { 1086 | S k = _get(i)->n_descendants; 1087 | if (m == -1 || k == m) { 1088 | _roots.push_back(i); 1089 | m = k; 1090 | } else { 1091 | break; 1092 | } 1093 | } 1094 | // hacky fix: since the last root precedes the copy of all roots, delete it 1095 | if (_roots.size() > 1 && _get(_roots.front())->children[0] == _get(_roots.back())->children[0]) 1096 | _roots.pop_back(); 1097 | _loaded = true; 1098 | _built = true; 1099 | _n_items = m; 1100 | if (_verbose) showUpdate("found %lu roots with degree %d\n", _roots.size(), m); 1101 | return true; 1102 | } 1103 | 1104 | T get_distance(S i, S j) const { 1105 | return D::normalized_distance(D::distance(_get(i), _get(j), _f)); 1106 | } 1107 | 1108 | void get_nns_by_item(S item, size_t n, int search_k, vector* result, vector* distances) const { 1109 | // TODO: handle OOB 1110 | const Node* m = _get(item); 1111 | _get_all_nns(m->v, n, search_k, result, distances); 1112 | } 1113 | 1114 | void get_nns_by_vector(const T* w, size_t n, int search_k, vector* result, vector* distances) const { 1115 | _get_all_nns(w, n, search_k, result, distances); 1116 | } 1117 | 1118 | S get_n_items() const { 1119 | return _n_items; 1120 | } 1121 | 1122 | S get_n_trees() const { 1123 | return (S)_roots.size(); 1124 | } 1125 | 1126 | void verbose(bool v) { 1127 | _verbose = v; 1128 | } 1129 | 1130 | void get_item(S 
item, T* v) const { 1131 | // TODO: handle OOB 1132 | Node* m = _get(item); 1133 | memcpy(v, m->v, (_f) * sizeof(T)); 1134 | } 1135 | 1136 | void set_seed(int seed) { 1137 | _random.set_seed(seed); 1138 | } 1139 | 1140 | protected: 1141 | void _allocate_size(S n) { 1142 | if (n > _nodes_size) { 1143 | const double reallocation_factor = 1.3; 1144 | S new_nodes_size = std::max(n, (S) ((_nodes_size + 1) * reallocation_factor)); 1145 | void *old = _nodes; 1146 | 1147 | if (_on_disk) { 1148 | int rc = ftruncate(_fd, _s * new_nodes_size); 1149 | if (_verbose && rc) showUpdate("File truncation error\n"); 1150 | _nodes = remap_memory(_nodes, _fd, _s * _nodes_size, _s * new_nodes_size); 1151 | } else { 1152 | _nodes = realloc(_nodes, _s * new_nodes_size); 1153 | memset((char *) _nodes + (_nodes_size * _s) / sizeof(char), 0, (new_nodes_size - _nodes_size) * _s); 1154 | } 1155 | 1156 | _nodes_size = new_nodes_size; 1157 | if (_verbose) showUpdate("Reallocating to %d nodes: old_address=%p, new_address=%p\n", new_nodes_size, old, _nodes); 1158 | } 1159 | } 1160 | 1161 | inline Node* _get(const S i) const { 1162 | return get_node_ptr(_nodes, _s, i); 1163 | } 1164 | 1165 | S _make_tree(const vector& indices, bool is_root) { 1166 | // The basic rule is that if we have <= _K items, then it's a leaf node, otherwise it's a split node. 1167 | // There's some regrettable complications caused by the problem that root nodes have to be "special": 1168 | // 1. We identify root nodes by the arguable logic that _n_items == n->n_descendants, regardless of how many descendants they actually have 1169 | // 2. Root nodes with only 1 child need to be a "dummy" parent 1170 | // 3. 
Due to the _n_items "hack", we need to be careful with the cases where _n_items <= _K or _n_items > _K 1171 | if (indices.size() == 1 && !is_root) 1172 | return indices[0]; 1173 | 1174 | if (indices.size() <= (size_t)_K && (!is_root || (size_t)_n_items <= (size_t)_K || indices.size() == 1)) { 1175 | _allocate_size(_n_nodes + 1); 1176 | S item = _n_nodes++; 1177 | Node* m = _get(item); 1178 | m->n_descendants = is_root ? _n_items : (S)indices.size(); 1179 | 1180 | // Using std::copy instead of a loop seems to resolve issues #3 and #13, 1181 | // probably because gcc 4.8 goes overboard with optimizations. 1182 | // Using memcpy instead of std::copy for MSVC compatibility. #235 1183 | // Only copy when necessary to avoid crash in MSVC 9. #293 1184 | if (!indices.empty()) 1185 | memcpy(m->children, &indices[0], indices.size() * sizeof(S)); 1186 | return item; 1187 | } 1188 | 1189 | vector children; 1190 | for (size_t i = 0; i < indices.size(); i++) { 1191 | S j = indices[i]; 1192 | Node* n = _get(j); 1193 | if (n) 1194 | children.push_back(n); 1195 | } 1196 | 1197 | vector children_indices[2]; 1198 | Node* m = (Node*)alloca(_s); 1199 | D::create_split(children, _f, _s, _random, m); 1200 | 1201 | for (size_t i = 0; i < indices.size(); i++) { 1202 | S j = indices[i]; 1203 | Node* n = _get(j); 1204 | if (n) { 1205 | bool side = D::side(m, n->v, _f, _random); 1206 | children_indices[side].push_back(j); 1207 | } else { 1208 | showUpdate("No node for index %d?\n", j); 1209 | } 1210 | } 1211 | 1212 | // If we didn't find a hyperplane, just randomize sides as a last option 1213 | while (children_indices[0].size() == 0 || children_indices[1].size() == 0) { 1214 | if (_verbose) 1215 | showUpdate("\tNo hyperplane found (left has %ld children, right has %ld children)\n", 1216 | children_indices[0].size(), children_indices[1].size()); 1217 | if (_verbose && indices.size() > 100000) 1218 | showUpdate("Failed splitting %lu items\n", indices.size()); 1219 | 1220 | 
children_indices[0].clear(); 1221 | children_indices[1].clear(); 1222 | 1223 | // Set the vector to 0.0 1224 | for (int z = 0; z < _f; z++) 1225 | m->v[z] = 0; 1226 | 1227 | for (size_t i = 0; i < indices.size(); i++) { 1228 | S j = indices[i]; 1229 | // Just randomize... 1230 | children_indices[_random.flip()].push_back(j); 1231 | } 1232 | } 1233 | 1234 | int flip = (children_indices[0].size() > children_indices[1].size()); 1235 | 1236 | m->n_descendants = is_root ? _n_items : (S)indices.size(); 1237 | for (int side = 0; side < 2; side++) { 1238 | // run _make_tree for the smallest child first (for cache locality) 1239 | m->children[side^flip] = _make_tree(children_indices[side^flip], false); 1240 | } 1241 | 1242 | _allocate_size(_n_nodes + 1); 1243 | S item = _n_nodes++; 1244 | memcpy(_get(item), m, _s); 1245 | 1246 | return item; 1247 | } 1248 | 1249 | void _get_all_nns(const T* v, size_t n, int search_k, vector* result, vector* distances) const { 1250 | Node* v_node = (Node *)alloca(_s); 1251 | D::template zero_value(v_node); 1252 | memcpy(v_node->v, v, sizeof(T) * _f); 1253 | D::init_node(v_node, _f); 1254 | 1255 | std::priority_queue > q; 1256 | 1257 | if (search_k == -1) { 1258 | search_k = n * _roots.size(); 1259 | } 1260 | 1261 | for (size_t i = 0; i < _roots.size(); i++) { 1262 | q.push(make_pair(Distance::template pq_initial_value(), _roots[i])); 1263 | } 1264 | 1265 | std::vector nns; 1266 | while (nns.size() < (size_t)search_k && !q.empty()) { 1267 | const pair& top = q.top(); 1268 | T d = top.first; 1269 | S i = top.second; 1270 | Node* nd = _get(i); 1271 | q.pop(); 1272 | if (nd->n_descendants == 1 && i < _n_items) { 1273 | nns.push_back(i); 1274 | } else if (nd->n_descendants <= _K) { 1275 | const S* dst = nd->children; 1276 | nns.insert(nns.end(), dst, &dst[nd->n_descendants]); 1277 | } else { 1278 | T margin = D::margin(nd, v, _f); 1279 | q.push(make_pair(D::pq_distance(d, margin, 1), static_cast(nd->children[1]))); 1280 | 
q.push(make_pair(D::pq_distance(d, margin, 0), static_cast(nd->children[0]))); 1281 | } 1282 | } 1283 | 1284 | // Get distances for all items 1285 | // To avoid calculating distance multiple times for any items, sort by id 1286 | std::sort(nns.begin(), nns.end()); 1287 | vector > nns_dist; 1288 | S last = -1; 1289 | for (size_t i = 0; i < nns.size(); i++) { 1290 | S j = nns[i]; 1291 | if (j == last) 1292 | continue; 1293 | last = j; 1294 | if (_get(j)->n_descendants == 1) // This is only to guard a really obscure case, #284 1295 | nns_dist.push_back(make_pair(D::distance(v_node, _get(j), _f), j)); 1296 | } 1297 | 1298 | size_t m = nns_dist.size(); 1299 | size_t p = n < m ? n : m; // Return this many items 1300 | std::partial_sort(nns_dist.begin(), nns_dist.begin() + p, nns_dist.end()); 1301 | for (size_t i = 0; i < p; i++) { 1302 | if (distances) 1303 | distances->push_back(D::normalized_distance(nns_dist[i].first)); 1304 | result->push_back(nns_dist[i].second); 1305 | } 1306 | } 1307 | }; 1308 | 1309 | #endif 1310 | // vim: tabstop=2 shiftwidth=2 --------------------------------------------------------------------------------