├── src
    ├── main
    │   ├── java
    │   │   └── com
    │   │   │   └── bj58
    │   │   │       └── spider
    │   │   │           ├── faiss
    │   │   │               └── .gitignore
    │   │   │           └── faiss4j
    │   │   │               ├── IndexHelper.java
    │   │   │               └── FaissIndex.java
    │   └── resources
    │   │   └── log4j2.xml
    └── test
    │   └── java
    │       └── com
    │           └── bj58
    │               └── spider
    │                   └── faiss
    │                       └── tests
    │                           ├── IOTests.java
    │                           └── Examples.java
├── test.sh
├── swigc++2java.sh
├── gendylib.sh
├── README.md
├── pom.xml
└── swigfaiss4j.i


/src/main/java/com/bj58/spider/faiss/.gitignore:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------


/test.sh:
--------------------------------------------------------------------------------
export LD_LIBRARY_PATH="."
JAVA="java"
#JAVA="/opt/soft/jdk/jdk1.8.0_66/bin/java"
JAR="target/faiss4java-1.0.0-jar-with-dependencies.jar"

ulimit -c unlimited
# arguments: input vector file, nprobe
$JAVA -Djava.library.path="." -classpath "$JAR" com.bj58.spider.faiss4j.FaissIndex "part-00000" "128"
--------------------------------------------------------------------------------


/swigc++2java.sh:
--------------------------------------------------------------------------------
rm -rf src/main/java/com/bj58/spider/faiss/*.java swigfaiss4j.cpp
mkdir -p src/main/java/com/bj58/spider/faiss/
# requires swig -version == 4.0
# -I../faiss: must point to the faiss source directory
#/usr/local/bin/swig -c++ -java -package com.bj58.spider.faiss -o swigfaiss4j.cpp -outdir src/main/java/com/bj58/spider/faiss/ -Doverride= -I../faiss/ swigfaiss.swig
/usr/local/bin/swig -c++ -java -package com.bj58.spider.faiss -o swigfaiss4j.cpp -outdir src/main/java/com/bj58/spider/faiss/ -Doverride= -I../faiss/ swigfaiss4j.i

--------------------------------------------------------------------------------


/gendylib.sh:
--------------------------------------------------------------------------------
# mac
SHAREDEXT=dylib
# linux
#SHAREDEXT=so

# remove the previous build output (the library is named libswigfaiss4j.<ext>, see the last line)
rm -f libswigfaiss4j.$SHAREDEXT

# change this to your own faiss install prefix
FAISS_INSTALL_DIR=/usr/local/

# -L$FAISS_INSTALL_DIR/lib: directory containing libfaiss.so/.dylib
# -I$JAVA_HOME/include/: location of jni.h
# -I$JAVA_HOME/include/linux/ (include/darwin/ on macOS): location of jni_md.h

# linux
#CXX="g++ -std=c++11"
#CXXFLAGS="-fPIC -fopenmp -m64 -Wno-unused-parameter -Wno-strict-aliasing -g -O3 -Wextra -msse4 -mpopcnt -fpermissive -L$FAISS_INSTALL_DIR/lib -I$FAISS_INSTALL_DIR/include/faiss/ -I$JAVA_HOME/include/ -I$JAVA_HOME/include/linux/ "
# mac
CXX="/usr/local/opt/llvm/bin/clang++ -std=c++11"
CXXFLAGS="-fPIC -fopenmp -m64 -Wall -g -O3 -Wextra -msse4 -mpopcnt -L$FAISS_INSTALL_DIR/lib -I$FAISS_INSTALL_DIR/include/faiss/ -I$JAVA_HOME/include/ -I$JAVA_HOME/include/darwin/"

# list -lfaiss after the source file so that GNU ld (Linux) resolves the faiss symbols it needs
${CXX} ${CXXFLAGS} swigfaiss4j.cpp -shared -o libswigfaiss4j.$SHAREDEXT -lfaiss
--------------------------------------------------------------------------------


/src/main/resources/log4j2.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------


/src/test/java/com/bj58/spider/faiss/tests/IOTests.java:
--------------------------------------------------------------------------------
package com.bj58.spider.faiss.tests;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.file.Paths;

public class IOTests {
    private static final Logger log = LoggerFactory.getLogger(IOTests.class);

    public static void load() {
        // gendylib.sh writes libswigfaiss4j.<ext>, so load that file by its full name
        System.load(Paths.get("./libswigfaiss4j.dylib").toAbsolutePath().toString());
        System.loadLibrary("faiss");
    }

    private static float[][] dummyData() {
        return new float[][]{
                new float[]{10, 0, 0},
                new float[]{9, 0, 0},
                new float[]{8, 0, 0},
                new float[]{7, 0, 0},
                new float[]{6, 0, 0},

                new float[]{0, 10, 0},
                new float[]{0, 9, 0},
                new float[]{0, 8, 0},
                new float[]{0, 7, 0},
                new float[]{0, 6, 0},

                new float[]{0, 0, 10},
                new float[]{0, 0, 9},
                new float[]{0, 0, 8},
                new float[]{0, 0, 7},
                new float[]{0, 0, 6},
        };
    }

    public void testIndexWriteAndRead() {
        String filename = "./index-1";


    }
}
--------------------------------------------------------------------------------


/README.md:
--------------------------------------------------------------------------------
# faiss4java

### Introduction

The original [faiss](https://github.com/facebookresearch/faiss) only supports C++ and Python. This project provides a Java interface for faiss v1.5.3 and is largely based on [faiss4j](https://github.com/thenetcircle/faiss4j.git).
> The steps below were carried out on macOS Mojave.

### Building faiss
1. Build and install faiss from source; see the [official install guide](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md).

```
➜ git clone -b v1.5.3 https://github.com/facebookresearch/faiss.git
➜ cd faiss
➜ ./configure --without-cuda # note 1
➜ make all && make install
➜ make py && make py install
```
2. Test the OpenBLAS library by building and running `misc/test_blas` (the same commands as in step 3). A successful run prints something like:

```
errors = ... (can be ignored)
info=0000064b00000000
Lapack uses xx-bit integers
```
If you see output like this, the test passed; skip straight to the Building faiss4java section.
3. If you get the following error instead:
```
➜ make misc/test_blas
➜ ./misc/test_blas
dyld: Library not loaded: @rpath/libopenblas_nehalemp-r0.2.19.dylib
  Referenced from: /Users/jiananliu/test/faiss/./misc/test_blas
  Reason: image not found
[1] 89887 abort ./misc/test_blas
```
then a similar error will also appear later, when Java calls `System.loadLibrary("swigfaiss4j");`.
This happens even though running `./configure` above reports:

```
...
checking for sgemm_ in -lopenblas... yes
...
```

So you need to pass the `--with-blas` option explicitly:
```
➜ ./configure --without-cuda --with-blas=/usr/local/opt/openblas/lib/libopenblas.dylib
➜ rm -f misc/test_blas # make clean does not delete this file
➜ make misc/test_blas
➜ ./misc/test_blas
```

You can look up the corresponding OpenBLAS path with:

```
➜ brew info openblas
```
Then rebuild faiss from source.
### Building faiss4java
1. Clone faiss4java into the same parent directory as faiss (the two must be siblings), and make sure your SWIG version is 4.0.

```
➜ git clone https://github.com/belkov0912/faiss4java.git
# use SWIG to generate the Java interface from the C++ headers; this also produces swigfaiss4j.cpp
➜ ./swigc++2java.sh
# build the shared library: the suffix is .dylib on macOS and .so on Linux
➜ ./gendylib.sh
```

2. After the Java interface has been generated, opening the project in IntelliJ IDEA reports an error like this:

```
protected xxxclass(long cPtr, boolean cMemoryOwn) {
  super(swigfaissJNI.IndexReplicas_SWIGUpcast(cPtr), cMemoryOwn);
  swigCPtr = cPtr;
}

Error:java: constructor xxx(long,boolean) is already defined in class com.bj58.spider.faiss.xxx
```
Simply delete the duplicate constructor the error points to. (This is probably related to how SWIG parses C++ template classes; I don't fully understand the cause, so suggestions are welcome.)

3. To run com.bj58.spider.faiss.tests.Examples, configure java.library.path:

```
# add to VM options
-Djava.library.path="/xxx/faiss4j"
```
The output looks like this:
```
19-09-03 18:45:43 [main] INFO Examples:49 - is_trained = true
19-09-03 18:45:43 [main] INFO Examples:51 - ntotal = 10
19-09-03 18:45:43 [main] INFO Examples:62 - search 5 first vector of xb
19-09-03 18:45:44 [main] INFO Examples:64 - Vectors:
0 |0.00000 0.897081 0.0970035 0.0361408 0.149420
1 |0.00100000 0.286766 0.871114 0.127865 0.528833
2 |0.00200000 0.332409 0.414103 0.197723 0.603372
...
```



--------------------------------------------------------------------------------


/src/main/java/com/bj58/spider/faiss4j/IndexHelper.java:
--------------------------------------------------------------------------------
package com.bj58.spider.faiss4j;

import com.bj58.spider.faiss.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class IndexHelper {
    private static final Logger log = LoggerFactory.getLogger(IndexHelper.class);

    public static String show(longlongArray a, int rows, int cols) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append(i).append('\t').append('|');
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%5d ", a.getitem(i * cols + j)));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    public static String show(longArray a, int rows, int cols) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append(i).append('\t').append('|');
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%5d ", a.getitem(i * cols + j)));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    public static String show(floatArray a, int rows, int cols) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < rows; i++) {
            sb.append(i).append('\t').append('|');
            for (int j = 0; j < cols; j++) {
                sb.append(String.format("%7g ", a.getitem(i * cols + j)));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    // Pack vectors into a row-major floatArray: element j of vector i is stored at index d * i + j
    public static floatArray makeFloatArray(float[][] vectors) {
        int d = vectors[0].length;
        int nb = vectors.length;
        floatArray fa = new floatArray(d * nb);
        for (int i = 0; i < nb; i++) {
            for (int j = 0; j < d; j++) {
                fa.setitem(d * i + j, vectors[i][j]);
            }
        }
        return fa;
    }

    // Get all rows in [i0, i1) of a floatArray; each row has dimension d
    public static float[][] getLineFromFloatArray(floatArray fa, int i0, int i1, int d) {
        float[][] rt = new float[i1-i0][d];
        for (int i=i0; i

4.0.0
com.bj58.spider
faiss4java
1.0.0

UTF-8
UTF-8
2.7.3
1.7
3.2.5
2.11
${scala.minor.version}.8
2.3.2

UTF-8
provided
compile
2.0.10
1.0.89
1.0.2-SNAPSHOT
1.0.0-SNAPSHOT
1.0.0-SNAPSHOT

org.apache.logging.log4j
log4j-api
2.7

commons-logging
commons-logging
1.1.1

org.apache.logging.log4j
log4j-core
2.7

org.apache.logging.log4j
log4j-slf4j-impl
2.2

net.alchim31.maven
scala-maven-plugin
3.1.5

maven-assembly-plugin
jar-with-dependencies

net.alchim31.maven
scala-maven-plugin
scala-compile-first
process-resources
add-source
compile
scala-test-compile
process-test-resources
testCompile

org.apache.maven.plugins
maven-dependency-plugin
copy
package
copy-dependencies
${project.build.directory}/lib
true
true

org.apache.maven.plugins
maven-compiler-plugin
3.6.1
${java.version}
${java.version}

config-package
org.apache.maven.plugins
maven-antrun-plugin
1.7
config-package
package
run

true
--------------------------------------------------------------------------------
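IndexHelper.makeFloatArray packs vectors into one flat, row-major buffer (element j of vector i sits at index d * i + j), and Index.search fills row-major distances and labels buffers of length nq * k. The following is a minimal sketch of that layout using plain Java arrays instead of the SWIG floatArray/longlongArray wrappers, so it runs without the native library. The class and method names (RowMajorSketch, flatten, rows, searchL2) are hypothetical: rows only mirrors the intent of getLineFromFloatArray, whose body is cut off above, and searchL2 is a brute-force reference for what an IndexFlatL2 search returns, reporting squared L2 distances as faiss's IndexFlatL2 does.

```java
import java.util.Arrays;

// Hypothetical pure-Java helpers that mirror the row-major layout used by IndexHelper.
class RowMajorSketch {

    // Flatten an nb x d matrix into one row-major array: element (i, j) goes to index d * i + j.
    static float[] flatten(float[][] vectors) {
        int nb = vectors.length;
        int d = vectors[0].length;
        float[] flat = new float[nb * d];
        for (int i = 0; i < nb; i++) {
            System.arraycopy(vectors[i], 0, flat, i * d, d);
        }
        return flat;
    }

    // Copy rows [i0, i1) out of a row-major array whose rows have dimension d.
    static float[][] rows(float[] flat, int i0, int i1, int d) {
        float[][] out = new float[i1 - i0][];
        for (int i = i0; i < i1; i++) {
            out[i - i0] = Arrays.copyOfRange(flat, i * d, (i + 1) * d);
        }
        return out;
    }

    // Brute-force k-NN with squared L2 distance. Results for query q are written to
    // distances[q * k .. q * k + k) and labels[q * k .. q * k + k), sorted by ascending distance.
    // Slots that cannot be filled (nb < k) keep label -1 and distance +infinity.
    static void searchL2(float[] xb, int nb, float[] xq, int nq, int d, int k,
                         float[] distances, long[] labels) {
        for (int q = 0; q < nq; q++) {
            float[] best = new float[k];
            long[] bestIds = new long[k];
            Arrays.fill(best, Float.POSITIVE_INFINITY);
            Arrays.fill(bestIds, -1L);
            for (int i = 0; i < nb; i++) {
                float dist = 0f;
                for (int j = 0; j < d; j++) {
                    float diff = xq[q * d + j] - xb[i * d + j];
                    dist += diff * diff;
                }
                // Find the insertion point in the sorted top-k list.
                int pos = k;
                while (pos > 0 && dist < best[pos - 1]) pos--;
                if (pos < k) {
                    // Shift worse entries one slot to the right, dropping the last one.
                    System.arraycopy(best, pos, best, pos + 1, k - pos - 1);
                    System.arraycopy(bestIds, pos, bestIds, pos + 1, k - pos - 1);
                    best[pos] = dist;
                    bestIds[pos] = i;
                }
            }
            System.arraycopy(best, 0, distances, q * k, k);
            System.arraycopy(bestIds, 0, labels, q * k, k);
        }
    }

    public static void main(String[] args) {
        float[][] data = {{10, 0, 0}, {9, 0, 0}, {0, 10, 0}, {0, 0, 10}};
        int d = 3, k = 2;
        float[] xb = flatten(data);
        float[][] query = rows(xb, 1, 2, d); // reuse row 1 of the database as the query
        float[] distances = new float[query.length * k];
        long[] labels = new long[query.length * k];
        searchL2(xb, data.length, flatten(query), query.length, d, k, distances, labels);
        System.out.println(Arrays.toString(labels) + " " + Arrays.toString(distances));
    }
}
```

Running main prints `[1, 0] [0.0, 1.0]`: the query is row 1 itself (squared distance 0), followed by row 0 at squared distance 1. A real faiss index replaces the O(nq * nb * d) scan in searchL2, but the shapes of the input and output buffers stay the same.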


/src/main/java/com/bj58/spider/faiss4j/FaissIndex.java:
--------------------------------------------------------------------------------
package com.bj58.spider.faiss4j;

import com.bj58.spider.faiss.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.lang.Math;

import static com.bj58.spider.faiss4j.IndexHelper.*;

public class FaissIndex {
    private static final Logger logger = LoggerFactory.getLogger(FaissIndex.class);
    private static final boolean debug_mode = false;

    static {
        try {
            System.loadLibrary("swigfaiss4j");
        } catch (UnsatisfiedLinkError ex) {
            // System.loadLibrary reports a missing library with an Error, not an Exception
            throw new java.lang.RuntimeException("please make sure libs:[faiss, swigfaiss4j] in java.library.path", ex);
        }

        logger.info("load libswigfaiss4j success");
    }

    private Index index;
    // Keep a Java reference to the coarse quantizer: the IVF index only points to it, and if the
    // Java proxy were garbage-collected its native object could be freed while still in use.
    private Index quantizer;
    private floatArray xb;
    private int d;
    private int nb;
    private float[][] tmp_vec = null;

    private float[][] fvecs_read(String fileName, int min, int max) {
        File file = new File(fileName);
        int max_line;
        if (debug_mode) max_line = 100000; else max_line = 943139;
        if (max == -1)
            max = max_line;

        BufferedReader reader = null;
        int row = 0, column = 0;
        ArrayList<float[]> vecs = new ArrayList<>(max);
        try {
            reader = new BufferedReader(new FileReader(file));
            String line;
            while ((line = reader.readLine()) != null) {
                // each line is expected to look like "<id>:<v1>,<v2>,...": keep only the values after the last ':'
                int last_index = line.lastIndexOf(':');
                // String id = line.substring(0, last_index);
                String values = line.substring(last_index + 1);
                String[] sp = values.split(",");
                if (column == 0) column = sp.length;
                if (sp.length != column) {
                    logger.warn(String.format("column:[%d] != real_columns:[%d] row:[%d] line:[%s]", column, sp.length, row, line));
                    continue;
                }

                float[] vs = IndexHelper.toFloatArray(sp);
                vecs.add(vs);
                row += 1;
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                    logger.error("close reader failed");
                }
            }
        }

        return vecs.toArray(new float[row][column]);
    }

    public void set_train_data(String file_name) {
        float[][] data = fvecs_read(file_name, 0, -1);
        this.xb = makeFloatArray(data);
        this.nb = data.length;
        this.d = data[0].length;
        this.tmp_vec = data;
        logger.info(String.format("train vecs row:[%d] columns:[%d]", nb, d));
    }

    public void printMatrixStats() {
        MatrixStats matrixStats = new MatrixStats(nb, d, xb.cast());
        logger.info("-----------------------------------");
        logger.info(matrixStats.getComments());
        logger.info("-----------------------------------");
    }

    public void create_index(boolean use_imi, int nprob) {
        try {
            int nhash;
            long ncentroids;
            Index quantizer;
            if (use_imi) {
                nhash = 2;
                long nbits_subq = (int)(Math.log(nb+1)/Math.log(2))/2;  // good choice in general
                ncentroids = 1L << (nhash * nbits_subq);  // total # of centroids == nlist
                // MultiIndexQuantizer
                quantizer = new MultiIndexQuantizer(d, nhash, nbits_subq);
                logger.info(String.format("IMI (%d,%d): %d virtual centroids (target: %d base vectors)", nhash, nbits_subq, ncentroids, nb));

            } else {
                // IndexFlatL2
                quantizer = new IndexFlatL2(d);
                ncentroids = 4096;
                logger.info(String.format("IF: %d virtual centroids (target: %d base vectors)", ncentroids, nb));
            }

            IndexIVFFlat index = new IndexIVFFlat(quantizer, d, ncentroids, MetricType.METRIC_L2);
            this.quantizer = quantizer;

            logger.info(String.format("index is trained: %b", index.getIs_trained()));

            if (use_imi) {
                // quantizer_trains_alone is a C++ char used as a small integer flag (1 = pass the
                // training set straight to quantizer.train()); the literal '1' would be code 49
                index.setQuantizer_trains_alone((char) 1);
            }

            // define the number of probes. 2048 is for high-dim, overkill in practice
            // Use 4-1024 depending on the speed/accuracy trade-off that you want
            index.setNprobe(nprob);
            logger.info(String.format("set nprob:%d", nprob));

            long t0 = System.currentTimeMillis();
            // logger.info("Vectors:\n{}", show(tb, trainData.length, dimension));
            index.train(nb, xb.cast());
            long t1 = System.currentTimeMillis();
            logger.info(String.format("[%d ms] Training the index", t1-t0));

            index.add(nb, xb.cast());
            long t2 = System.currentTimeMillis();
            logger.info(String.format("[%d ms] Adding the vectors to the index", t2-t1));

            this.index = index;
        } catch (Exception e) {
            logger.error("failed", e);
        }
    }

    public void search(float[][] query) {
        int k = 5;
        int nq = query.length;
        floatArray q = makeFloatArray(query);
        longlongArray labels = new longlongArray(k * nq);
        floatArray distances = new floatArray(k * nq);

        long t4 = System.currentTimeMillis();
        index.search(nq, q.cast(), k, distances.cast(), labels.cast());

        long t5 = System.currentTimeMillis();
        logger.info(String.format("[%dms] Searching the %d nearest neighbors of %d vectors in the index", t5 - t4, k, nq));
        logger.info("Query results (distances, then vector ids):");

        logger.info("Distances:\n{}", show(distances, nq, k));
        logger.info("Labels:\n{}", show(labels, nq, k));
    }

    public static void main(String[] argv) {

        String file_name = "part-00000";
        int nprob = 512;
        if (argv.length >= 2) {
            file_name = argv[0];
            nprob = Integer.parseInt(argv[1]);
        }
        logger.info("read file:" + file_name);

        FaissIndex index = new FaissIndex();

        index.set_train_data(file_name);

        index.printMatrixStats();

        index.create_index(true, nprob);

        {
            // remember a few elements from the database as queries
            int i0 = 1234;
            int i1 = 1245;

            // float[][] query = Arrays.copyOfRange(index.tmp_vec, i0, i1);
            float[][] query = getLineFromFloatArray(index.xb, i0, i1, index.d);

            index.search(query);
        }
    }
}
--------------------------------------------------------------------------------


/src/test/java/com/bj58/spider/faiss/tests/Examples.java:
--------------------------------------------------------------------------------
package com.bj58.spider.faiss.tests;

import com.bj58.spider.faiss.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Random;

import static com.bj58.spider.faiss4j.IndexHelper.*;


public class Examples {
    private static final Logger log = LoggerFactory.getLogger(Examples.class);

    static {
        // print where the JVM looks for native libraries, then load the JNI wrapper
        String property = System.getProperty("java.library.path");
        System.out.println(property);
        // System.loadLibrary("faiss");
        System.loadLibrary("swigfaiss4j");
        System.out.println("load libswigfaiss4j success");
    }


    public static void testFlat() {
        int d = 5;      // dimension
        int nb = 10;    // database size
        int nq = 10000; // nb of queries

        // float[] xb = new float[d * nb];
        // float[] xq = new float[d * nq];
        try {
            floatArray xb = new floatArray(d * nb);

            Random rand = new Random();

            for (int i = 0; i < nb; i++) {
                for (int j = 0; j < d; j++) {
                    // xb[d * i + j] = rand.nextFloat();
                    xb.setitem(d * i + j, rand.nextFloat());
                }
                // xb[d * i] += i / 1000.;
                xb.setitem(d * i, (float) (i / 1000.0));
            }

            IndexFlatL2 index = new IndexFlatL2(d);
            log.info("is_trained = {}", index.getIs_trained());
            index.add(nb, xb.cast());
            log.info("ntotal = {}", index.getNtotal());


            {
                int k = 4;
                longlongArray I = new longlongArray(k * 5);
                floatArray D = new floatArray(k * 5);

                log.info("search 5 first vector of xb");
                index.search(5, xb.cast(), k, D.cast(), I.cast());
                log.info("Vectors:\n{}", show(xb, nb, d));
                log.info("Distances:\n{}", show(D, 5, k));
                log.info("I:\n{}", show(I, 5, k));
            }
        } catch (Exception e) {
            log.error("failed", e);
        }
    }

    public void simpleTest() {
        try {
            float[][] data = dummyData3d(10);
            int d = data[0].length;
            int numberOfVector = data.length;
            floatArray xb = makeFloatArray(data);
            longArray ids = makeLongArray(new int[]{0, 1, 2});
            IndexFlatL2 index = new IndexFlatL2(d);
            //what(): Error in virtual void faiss::Index::add_with_ids(faiss::Index::idx_t, const float*, const long int*) at Index.cpp:46: add_with_ids not implemented for this type of index
            // IndexFlatL2 does not support add_with_ids; wrapping it in an IndexIDMap is the usual way to assign custom ids

            // index.add_with_ids(3, xb.cast(), ids.cast());
            index.add(numberOfVector, xb.cast());

            log.info("ntotal = {}", index.getNtotal());

            {
                int resultSize = 3;
                float[][] queryConds = {new float[]{0, 1, 8}};

                floatArray query = makeFloatArray(queryConds);
                longlongArray labels = new longlongArray(resultSize);
                floatArray distances = new floatArray(resultSize);
                index.search(1, query.cast(), resultSize, distances.cast(), labels.cast());

                log.info("Vectors:\n{}", show(xb, numberOfVector, d));
                log.info("Query:\n{}", show(query, queryConds.length, queryConds[0].length));
                log.info("Distances:\n{}", show(distances, 1, resultSize));
                log.info("Labels:\n{}", show(labels, 1, resultSize));
            }
        } catch (Exception e) {
            log.error("failed", e);
        }
    }

    public void testSearchRange() {
        float[][] data = dummyData3d(20);
        int d = data[0].length;
        int numberOfVector = data.length;

        try {
            floatArray xb = makeFloatArray(data);
            IndexFlatL2 index = new IndexFlatL2(d);
            index.add(numberOfVector, xb.cast());

            {
                int resultSize = 4;
                float[][] queryConds = {new float[]{0, 1, 8}};
                floatArray query = makeFloatArray(queryConds);

                int querySize = queryConds.length;
                // the constructor argument is the number of queries (nq), not the number of results
                RangeSearchResult re = new RangeSearchResult(querySize);
                index.range_search(querySize, query.cast(), 0.3f, re);

                // range search returns a variable number of hits per query: the hits of query i are
                // labels[lims[i] .. lims[i+1]), so the querySize x resultSize view below is only valid
                // when at least resultSize hits were found
                longlongArray labels = longlongArray.frompointer(re.getLabels());
                floatArray distances = floatArray.frompointer(re.getDistances());

                log.info("Vectors:\n{}", show(xb, numberOfVector, d));
                log.info("Query:\n{}", show(query, querySize, queryConds[0].length));
                log.info("Distances:\n{}", show(distances, querySize, resultSize));
                log.info("Labels:\n{}", show(labels, querySize, resultSize));
            }

        } catch (Exception e) {
            log.error("failed", e);
        }
    }

    public static void egIndexIVFFlat() {
        try {
            float[][] data = randomData3d(200);
            int dimension = data[0].length;
            int numberOfVector = data.length;
            int nlist = 6;
            int nprobe = 2;

            IndexFlatL2 quantizer = new IndexFlatL2(dimension);
            IndexIVFFlat index = new IndexIVFFlat(quantizer, dimension, nlist, MetricType.METRIC_L2);

            float[][] trainData = dummyData3d(5);
            floatArray tb = makeFloatArray(trainData);
            log.info("Vectors:\n{}", show(tb, trainData.length, dimension));
            index.train(trainData.length, tb.cast());

            floatArray xb = makeFloatArray(data);
            index.add(numberOfVector, xb.cast());

            int resultSize = 10;
            float[][] queryConds = {new float[]{0, 0, 8}};

            floatArray query = makeFloatArray(queryConds);
            longlongArray labels = new longlongArray(resultSize);
            floatArray distances = new floatArray(resultSize);

            int numberOfQuery = queryConds.length;
            index.setNprobe(nprobe);
            index.search(numberOfQuery, query.cast(), resultSize, distances.cast(), labels.cast());

            log.info("Vectors:\n{}", show(xb, numberOfVector, dimension));
            log.info("Query:\n{}", show(query, queryConds.length, queryConds[0].length));
            log.info("Distances:\n{}", show(distances, 1, resultSize));
            log.info("Labels:\n{}", show(labels, 1, resultSize));
        } catch (Exception e) {
            log.error("failed", e);
        }
    }

    // 3 * size vectors: for each axis i, points at offsets j - size/2 (j = 0..size-1) along that axis
    private static float[][] dummyData3d(int size) {
        float[][] data = new float[size * 3][3];
        float half = size / 2.0f;
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < size; j++) {
                float[] row = new float[]{0, 0, 0};
                row[i] = j - half;
                data[i * size + j] = row;
            }
        }
        return data;
    }

    // 3 * size vectors whose coordinates are drawn uniformly from [0, size)
    private static float[][] randomData3d(int size) {
        float[][] data = new float[size * 3][3];
        Random rand = new Random();
        for (int i = 0; i < data.length; i++) {
            float[] row = new float[]{rand.nextFloat() * size, rand.nextFloat() * size, rand.nextFloat() * size};
            data[i] = row;
        }
        return data;
    }

    public static void main(String[] argv) {
        // testFlat();
        egIndexIVFFlat();
    }

}
--------------------------------------------------------------------------------


/swigfaiss4j.i:
--------------------------------------------------------------------------------
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- C++ -*-

// This file describes the C++-scripting language bridge for Lua, Python
// and (in this project) Java. It contains mainly includes and a few macros.
// There are 4 preprocessor macros of interest:
//   SWIGLUA: Lua-specific code
//   SWIGPYTHON: Python-specific code
//   SWIGJAVA: Java-specific code
//   GPU_WRAPPER: also compile interfaces for GPU.

%rename(__operator_lt__) operator<;
%rename(__operator_middle__) operator[];
%rename(__operator_small__) operator();

%module swigfaiss;

#ifdef SWIGJAVA
%include "arrays_java.i"
//%apply int[] {int *};
//%apply double[] {double *};
//%apply float[] {float *};
//%apply long[] {long *};
%include "carrays.i"
%array_class(int, intArray);
%array_class(float, floatArray);
%array_class(long, longArray);
%array_class(long long, longlongArray);
%array_class(double, doubleArray);
#endif

// fbcode SWIG fails on warnings, so make them non fatal
//#pragma SWIG nowarn=321
//#pragma SWIG nowarn=403
//#pragma SWIG nowarn=325
//#pragma SWIG nowarn=389
//#pragma SWIG nowarn=341
//#pragma SWIG nowarn=512


%include
typedef int64_t size_t;

#define __restrict


/*******************************************************************
 * Copied verbatim to wrapper. Contains the C++-visible includes, and
 * the language includes for their respective matrix libraries.
 *******************************************************************/

%{


#include
#include


#include "IndexFlat.h"
#include "VectorTransform.h"
#include "IndexLSH.h"
#include "IndexPQ.h"

#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "IndexIVFFlat.h"
#include "IndexScalarQuantizer.h"
#include "IndexIVFSpectralHash.h"
#include "ThreadedIndex.h"
#include "IndexShards.h"
#include "IndexReplicas.h"
#include "HNSW.h"
#include "IndexHNSW.h"
#include "MetaIndexes.h"
#include "FaissAssert.h"

#include "IndexBinaryFlat.h"
#include "IndexBinaryIVF.h"
#include "IndexBinaryFromFloat.h"
#include "IndexBinaryHNSW.h"

#include "index_io.h"

#include "IVFlib.h"
#include "utils.h"
#include "distances.h"
#include "Heap.h"
#include "AuxIndexStructures.h"
#include "OnDiskInvertedLists.h"

#include "Clustering.h"

#include "hamming.h"

#include "AutoTune.h"


%}


/*******************************************************************
 * Types of vectors we want to manipulate at the scripting language
 * level.
 *******************************************************************/

// simplified interface for vector
namespace std {

    template
    class vector {
    public:
        vector();
        void push_back(T);
        void clear();
        T * data();
        size_t size();
        T at (size_t n) const;
        void resize (size_t n);
        void swap (vector & other);
    };
};


%template(FloatVector) std::vector;
%template(DoubleVector) std::vector;
%template(ByteVector) std::vector;
%template(CharVector) std::vector;
// NOTE(hoss): Using unsigned long instead of uint64_t because OSX defines
//   uint64_t as unsigned long long, which SWIG is not aware of.
%template(Uint64Vector) std::vector;
%template(LongVector) std::vector;
%template(IntVector) std::vector;
%template(VectorTransformVector) std::vector;
%template(OperatingPointVector) std::vector;
%template(InvertedListsPtrVector) std::vector;
%template(FloatVectorVector) std::vector >;
%template(ByteVectorVector) std::vector >;
%template(LongVectorVector) std::vector >;

#ifdef GPU_WRAPPER
%template(GpuResourcesVector) std::vector;
#endif

%include

// produces an error on the Mac
%ignore faiss::hamming;

/*******************************************************************
 * Parse headers
 *******************************************************************/


%ignore *::cmp;

%include "Heap.h"
%include "hamming.h"

int get_num_gpus();

#ifdef GPU_WRAPPER

%{

#include "gpu/StandardGpuResources.h"
#include "gpu/GpuIndicesOptions.h"
#include "gpu/GpuClonerOptions.h"
#include "gpu/utils/MemorySpace.h"
#include "gpu/GpuIndex.h"
#include "gpu/GpuIndexFlat.h"
#include "gpu/GpuIndexIVF.h"
#include "gpu/GpuIndexIVFPQ.h"
#include "gpu/GpuIndexIVFFlat.h"
#include "gpu/GpuIndexBinaryFlat.h"
#include "gpu/GpuAutoTune.h"
#include "gpu/GpuDistance.h"

int get_num_gpus()
{
    return faiss::gpu::getNumDevices();
}

%}

// causes weird wrapper bug
%ignore *::getMemoryManager;
%ignore *::getMemoryManagerCurrentDevice;

%include "gpu/GpuResources.h"
%include "gpu/StandardGpuResources.h"

#else

%{
int get_num_gpus()
{
    return 0;
}
%}


#endif


%include "utils.h"

%include "Index.h"
%include "Clustering.h"

%include "distances.h"

%ignore faiss::ProductQuantizer::get_centroids(size_t,size_t) const;

%include "ProductQuantizer.h"

%include "VectorTransform.h"
%include "IndexFlat.h"
%include "IndexLSH.h"
%include "PolysemousTraining.h"
%include "IndexPQ.h"

%include "InvertedLists.h"

%include
%interface_impl(faiss::Level1Quantizer);
//%interface_impl(faiss::Index);
%ignore InvertedListScanner;
%ignore BinaryInvertedListScanner;
%include "IndexIVF.h"
// NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the
// non-const one.
%warnfilter(509) extract_index_ivf;
%include "IVFlib.h"
%include "IndexScalarQuantizer.h"
%include "IndexIVFSpectralHash.h"
%include "HNSW.h"
%include "IndexHNSW.h"
%include "IndexIVFFlat.h"
%include "OnDiskInvertedLists.h"

%ignore faiss::IndexIVFPQ::alloc_type;
%include "IndexIVFPQ.h"

%include "IndexBinary.h"
%include "IndexBinaryFlat.h"
%include "IndexBinaryIVF.h"
%include "IndexBinaryFromFloat.h"
%include "IndexBinaryHNSW.h"


// %ignore faiss::IndexReplicas::at(int) const;

%include "ThreadedIndex.h"
%template(ThreadedIndexBase) faiss::ThreadedIndex;
%template(ThreadedIndexBaseBinary) faiss::ThreadedIndex;

%include "IndexShards.h"
%template(IndexShards) faiss::IndexShardsTemplate;
%template(IndexBinaryShards) faiss::IndexShardsTemplate;

%include "IndexReplicas.h"
%template(IndexReplicas) faiss::IndexReplicasTemplate;
%template(IndexBinaryReplicas) faiss::IndexReplicasTemplate;


%include "MetaIndexes.h"
%template(IndexIDMap) faiss::IndexIDMapTemplate;
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate;
%template(IndexIDMap2) faiss::IndexIDMap2Template;
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template;

#ifdef GPU_WRAPPER

// quiet SWIG warnings
%ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF;

%include "gpu/GpuIndicesOptions.h"
%include "gpu/GpuClonerOptions.h"
%include "gpu/utils/MemorySpace.h"
%include "gpu/GpuIndex.h"
%include "gpu/GpuIndexFlat.h"
%include "gpu/GpuIndexIVF.h"
%include "gpu/GpuIndexIVFPQ.h"
%include "gpu/GpuIndexIVFFlat.h"
%include "gpu/GpuIndexBinaryFlat.h"
%include "gpu/GpuDistance.h"

#ifdef SWIGLUA

/// in Lua, swigfaiss_gpu is known as swigfaiss
%luacode {
local swigfaiss = swigfaiss_gpu
}

#endif


#endif


/*******************************************************************
 * Lua-specific: support async execution of searches in an index
 * Python equivalent is just to use Python threads.
 *******************************************************************/


#ifdef SWIGLUA

%{


namespace faiss {

struct AsyncIndexSearchC {
    typedef Index::idx_t idx_t;
    const Index * index;

    idx_t n;
    const float *x;
    idx_t k;
    float *distances;
    idx_t *labels;

    bool is_finished;

    pthread_t thread;


    AsyncIndexSearchC (const Index *index,
                       idx_t n, const float *x, idx_t k,
                       float *distances, idx_t *labels):
        index(index), n(n), x(x), k(k), distances(distances),
        labels(labels)
    {
        is_finished = false;
        pthread_create (&thread, NULL, &AsyncIndexSearchC::callback,
                        this);
    }

    static void *callback (void *arg)
    {
        AsyncIndexSearchC *aidx = (AsyncIndexSearchC *)arg;
        aidx->do_search();
        return NULL;
    }

    void do_search ()
    {
        index->search (n, x, k, distances, labels);
    }
    void join ()
    {
        pthread_join (thread, NULL);
    }

};

}

%}

// re-declare only what we need
namespace faiss {

struct AsyncIndexSearchC {
    typedef Index::idx_t idx_t;
    bool is_finished;
    AsyncIndexSearchC (const Index *index,
                       idx_t n, const float *x, idx_t k,
                       float *distances, idx_t *labels);


    void join ();
};

}


#endif


/*******************************************************************
 * downcast return of some functions so that the sub-class is used
 * instead of the generic upper-class.
 *******************************************************************/
#ifdef SWIGJAVA

%define DOWNCAST(subclass)
    if (dynamic_cast ($1)) {
      faiss::subclass *instance_ptr = (faiss::subclass *)$1;
      $result = (jlong)instance_ptr;
    } else
%enddef

%define DOWNCAST2(subclass, longname)
    if (dynamic_cast ($1)) {
      faiss::subclass *instance_ptr = (faiss::subclass *)$1;
      $result = (jlong)instance_ptr;
    } else
%enddef

%define DOWNCAST_GPU(subclass)
    if (dynamic_cast ($1)) {
      faiss::subclass *instance_ptr = (faiss::subclass *)$1;
      $result = (jlong)instance_ptr;
    } else
%enddef

#endif

#ifdef SWIGPYTHON

%define DOWNCAST(subclass)
    if (dynamic_cast ($1)) {
      $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner);
    } else
%enddef

%define DOWNCAST2(subclass, longname)
    if (dynamic_cast ($1)) {
      $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner);
    } else
%enddef

%define DOWNCAST_GPU(subclass)
    if (dynamic_cast ($1)) {
      $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__gpu__ ## subclass,$owner);
    } else
%enddef

#endif

%newobject read_index;
%newobject read_index_binary;
%newobject read_VectorTransform;
%newobject read_ProductQuantizer;
%newobject clone_index;
%newobject clone_VectorTransform;

// Subclasses should appear before their parent
%typemap(out) faiss::Index * {
    DOWNCAST2 ( IndexIDMap, IndexIDMapTemplateT_faiss__Index_t )
    DOWNCAST2 ( IndexIDMap2, IndexIDMap2TemplateT_faiss__Index_t )
    DOWNCAST2 ( IndexShards, IndexShardsTemplateT_faiss__Index_t )
    DOWNCAST2 ( IndexReplicas, IndexReplicasTemplateT_faiss__Index_t )
    DOWNCAST ( IndexIVFPQR )
    DOWNCAST ( IndexIVFPQ )
DOWNCAST ( IndexIVFSpectralHash ) 461 | DOWNCAST ( IndexIVFScalarQuantizer ) 462 | DOWNCAST ( IndexIVFFlatDedup ) 463 | DOWNCAST ( IndexIVFFlat ) 464 | DOWNCAST ( IndexIVF ) 465 | DOWNCAST ( IndexFlat ) 466 | DOWNCAST ( IndexPQ ) 467 | DOWNCAST ( IndexScalarQuantizer ) 468 | DOWNCAST ( IndexLSH ) 469 | DOWNCAST ( IndexPreTransform ) 470 | DOWNCAST ( MultiIndexQuantizer ) 471 | DOWNCAST ( IndexHNSWFlat ) 472 | DOWNCAST ( IndexHNSWPQ ) 473 | DOWNCAST ( IndexHNSWSQ ) 474 | DOWNCAST ( IndexHNSW2Level ) 475 | DOWNCAST ( Index2Layer ) 476 | #ifdef GPU_WRAPPER 477 | DOWNCAST_GPU ( GpuIndexIVFPQ ) 478 | DOWNCAST_GPU ( GpuIndexIVFFlat ) 479 | DOWNCAST_GPU ( GpuIndexFlat ) 480 | #endif 481 | // default for non-recognized classes 482 | DOWNCAST ( Index ) 483 | if ($1 == NULL) 484 | { 485 | #ifdef SWIGJAVA 486 | $result = 0; 487 | #endif 488 | } else { 489 | assert(false); 490 | } 491 | #ifdef SWIGLUA 492 | SWIG_arg++; 493 | #endif 494 | } 495 | 496 | %typemap(out) faiss::IndexBinary * { 497 | DOWNCAST2 ( IndexBinaryReplicas, IndexReplicasTemplateT_faiss__IndexBinary_t ) 498 | DOWNCAST2 ( IndexBinaryIDMap, IndexIDMapTemplateT_faiss__IndexBinary_t ) 499 | DOWNCAST2 ( IndexBinaryIDMap2, IndexIDMap2TemplateT_faiss__IndexBinary_t ) 500 | DOWNCAST ( IndexBinaryIVF ) 501 | DOWNCAST ( IndexBinaryFlat ) 502 | DOWNCAST ( IndexBinaryFromFloat ) 503 | DOWNCAST ( IndexBinaryHNSW ) 504 | #ifdef GPU_WRAPPER 505 | DOWNCAST_GPU ( GpuIndexBinaryFlat ) 506 | #endif 507 | // default for non-recognized classes 508 | DOWNCAST ( IndexBinary ) 509 | if ($1 == NULL) 510 | { 511 | #ifdef SWIGPYTHON 512 | $result = SWIG_Py_Void(); 513 | #endif 514 | #ifdef SWIGJAVA 515 | $result = 0; 516 | #endif 517 | // Lua does not need a push for nil 518 | } else { 519 | assert(false); 520 | } 521 | #ifdef SWIGLUA 522 | SWIG_arg++; 523 | #endif 524 | } 525 | 526 | %typemap(out) faiss::VectorTransform * { 527 | DOWNCAST (RemapDimensionsTransform) 528 | DOWNCAST (OPQMatrix) 529 | DOWNCAST (PCAMatrix) 530 | DOWNCAST 
(RandomRotationMatrix) 531 | DOWNCAST (LinearTransform) 532 | DOWNCAST (NormalizationTransform) 533 | DOWNCAST (CenteringTransform) 534 | DOWNCAST (VectorTransform) 535 | { 536 | assert(false); 537 | } 538 | #ifdef SWIGLUA 539 | SWIG_arg++; 540 | #endif 541 | } 542 | 543 | %typemap(out) faiss::InvertedLists * { 544 | DOWNCAST (ArrayInvertedLists) 545 | DOWNCAST (OnDiskInvertedLists) 546 | DOWNCAST (VStackInvertedLists) 547 | DOWNCAST (HStackInvertedLists) 548 | DOWNCAST (MaskedInvertedLists) 549 | DOWNCAST (InvertedLists) 550 | { 551 | assert(false); 552 | } 553 | #ifdef SWIGLUA 554 | SWIG_arg++; 555 | #endif 556 | } 557 | 558 | // just to downcast pointers that come from elsewhere (eg. direct 559 | // access to object fields) 560 | %inline %{ 561 | faiss::Index * downcast_index (faiss::Index *index) 562 | { 563 | return index; 564 | } 565 | faiss::VectorTransform * downcast_VectorTransform (faiss::VectorTransform *vt) 566 | { 567 | return vt; 568 | } 569 | faiss::IndexBinary * downcast_IndexBinary (faiss::IndexBinary *index) 570 | { 571 | return index; 572 | } 573 | faiss::InvertedLists * downcast_InvertedLists (faiss::InvertedLists *il) 574 | { 575 | return il; 576 | } 577 | %} 578 | 579 | 580 | %include "index_io.h" 581 | 582 | %newobject index_factory; 583 | %newobject index_binary_factory; 584 | 585 | %include "AutoTune.h" 586 | 587 | 588 | #ifdef GPU_WRAPPER 589 | 590 | %newobject index_gpu_to_cpu; 591 | %newobject index_cpu_to_gpu; 592 | %newobject index_cpu_to_gpu_multiple; 593 | 594 | %include "gpu/GpuAutoTune.h" 595 | 596 | #endif 597 | 598 | // Python-specific: do not release GIL any more, as functions below 599 | // use the Python/C API 600 | #ifdef SWIGPYTHON 601 | %exception; 602 | #endif 603 | 604 | 605 | 606 | 607 | 608 | /******************************************************************* 609 | * Python specific: numpy array <-> C++ pointer interface 610 | *******************************************************************/ 611 | 612 | #ifdef 
SWIGPYTHON 613 | 614 | %{ 615 | PyObject *swig_ptr (PyObject *a) 616 | { 617 | if(!PyArray_Check(a)) { 618 | PyErr_SetString(PyExc_ValueError, "input not a numpy array"); 619 | return NULL; 620 | } 621 | PyArrayObject *ao = (PyArrayObject *)a; 622 | 623 | if(!PyArray_ISCONTIGUOUS(ao)) { 624 | PyErr_SetString(PyExc_ValueError, "array is not C-contiguous"); 625 | return NULL; 626 | } 627 | void * data = PyArray_DATA(ao); 628 | if(PyArray_TYPE(ao) == NPY_FLOAT32) { 629 | return SWIG_NewPointerObj(data, SWIGTYPE_p_float, 0); 630 | } 631 | if(PyArray_TYPE(ao) == NPY_FLOAT64) { 632 | return SWIG_NewPointerObj(data, SWIGTYPE_p_double, 0); 633 | } 634 | if(PyArray_TYPE(ao) == NPY_INT32) { 635 | return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0); 636 | } 637 | if(PyArray_TYPE(ao) == NPY_UINT8) { 638 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_char, 0); 639 | } 640 | if(PyArray_TYPE(ao) == NPY_INT8) { 641 | return SWIG_NewPointerObj(data, SWIGTYPE_p_char, 0); 642 | } 643 | if(PyArray_TYPE(ao) == NPY_UINT64) { 644 | #ifdef SWIGWORDSIZE64 645 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long, 0); 646 | #else 647 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long_long, 0); 648 | #endif 649 | } 650 | if(PyArray_TYPE(ao) == NPY_INT64) { 651 | #ifdef SWIGWORDSIZE64 652 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long, 0); 653 | #else 654 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long_long, 0); 655 | #endif 656 | } 657 | PyErr_SetString(PyExc_ValueError, "did not recognize array type"); 658 | return NULL; 659 | } 660 | 661 | 662 | struct PythonInterruptCallback: faiss::InterruptCallback { 663 | 664 | bool want_interrupt () override { 665 | int err; 666 | { 667 | PyGILState_STATE gstate; 668 | gstate = PyGILState_Ensure(); 669 | err = PyErr_CheckSignals(); 670 | PyGILState_Release(gstate); 671 | } 672 | return err == -1; 673 | } 674 | 675 | }; 676 | 677 | 678 | %} 679 | 680 | 681 | %init %{ 682 | /* needed, else crash at runtime */ 683 | 
import_array(); 684 | 685 | faiss::InterruptCallback::instance.reset(new PythonInterruptCallback()); 686 | 687 | %} 688 | 689 | // return a pointer usable as input for functions that expect pointers 690 | PyObject *swig_ptr (PyObject *a); 691 | 692 | %define REV_SWIG_PTR(ctype, numpytype) 693 | 694 | %{ 695 | PyObject * rev_swig_ptr(ctype *src, npy_intp size) { 696 | return PyArray_SimpleNewFromData(1, &size, numpytype, src); 697 | } 698 | %} 699 | 700 | PyObject * rev_swig_ptr(ctype *src, size_t size); 701 | 702 | %enddef 703 | 704 | REV_SWIG_PTR(float, NPY_FLOAT32); 705 | REV_SWIG_PTR(int, NPY_INT32); 706 | REV_SWIG_PTR(unsigned char, NPY_UINT8); 707 | REV_SWIG_PTR(int64_t, NPY_INT64); 708 | REV_SWIG_PTR(uint64_t, NPY_UINT64); 709 | 710 | #endif 711 | 712 | 713 | 714 | /******************************************************************* 715 | * Lua specific: Torch tensor <-> C++ pointer interface 716 | *******************************************************************/ 717 | 718 | #ifdef SWIGLUA 719 | 720 | 721 | // provide a XXX_ptr function to convert Lua XXXTensor -> C++ XXX* 722 | 723 | %define TYPE_CONVERSION(ctype, tensortype) 724 | 725 | // typemap for the *_ptr_from_cdata function 726 | %typemap(in) ctype** { 727 | if(lua_type(L, $input) != 10) { 728 | fprintf(stderr, "not cdata input\n"); 729 | SWIG_fail; 730 | } 731 | $1 = (ctype**)lua_topointer(L, $input); 732 | } 733 | 734 | 735 | // SWIG and C declaration for the *_ptr_from_cdata function 736 | %{ 737 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs) { 738 | return *x + ofs; 739 | } 740 | %} 741 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs); 742 | 743 | // the *_ptr function 744 | %luacode { 745 | 746 | function swigfaiss. ctype ## _ptr(tensor) 747 | assert(tensor:type() == "torch." .. # tensortype, "need a " .. # tensortype) 748 | assert(tensor:isContiguous(), "requires contiguous tensor") 749 | return swigfaiss. 
ctype ## _ptr_from_cdata( 750 | tensor:storage():data(), 751 | tensor:storageOffset() - 1) 752 | end 753 | 754 | } 755 | 756 | %enddef 757 | 758 | TYPE_CONVERSION (int, IntTensor) 759 | TYPE_CONVERSION (float, FloatTensor) 760 | TYPE_CONVERSION (long, LongTensor) 761 | TYPE_CONVERSION (uint64_t, LongTensor) 762 | TYPE_CONVERSION (uint8_t, ByteTensor) 763 | 764 | #endif 765 | 766 | /******************************************************************* 767 | * How should the template objects apprear in the scripting language? 768 | *******************************************************************/ 769 | 770 | // answer: the same as the C++ typedefs, but we still have to redefine them 771 | 772 | %template() faiss::CMin; 773 | %template() faiss::CMin; 774 | %template() faiss::CMax; 775 | %template() faiss::CMax; 776 | 777 | %template(float_minheap_array_t) faiss::HeapArray >; 778 | %template(int_minheap_array_t) faiss::HeapArray >; 779 | 780 | %template(float_maxheap_array_t) faiss::HeapArray >; 781 | %template(int_maxheap_array_t) faiss::HeapArray >; 782 | 783 | 784 | /******************************************************************* 785 | * Expose a few basic functions 786 | *******************************************************************/ 787 | 788 | 789 | void omp_set_num_threads (int num_threads); 790 | int omp_get_max_threads (); 791 | void *memcpy(void *dest, const void *src, size_t n); 792 | 793 | 794 | /******************************************************************* 795 | * For Faiss/Pytorch interop via pointers encoded as longs 796 | *******************************************************************/ 797 | 798 | %inline %{ 799 | float * cast_integer_to_float_ptr (long x) { 800 | return (float*)x; 801 | } 802 | 803 | long * cast_integer_to_long_ptr (long x) { 804 | return (long*)x; 805 | } 806 | 807 | int * cast_integer_to_int_ptr (long x) { 808 | return (int*)x; 809 | } 810 | 811 | %} 812 | 813 | 814 | 815 | 
/******************************************************************* 816 | * Range search interface 817 | *******************************************************************/ 818 | 819 | %ignore faiss::BufferList::Buffer; 820 | %ignore faiss::RangeSearchPartialResult::QueryResult; 821 | %ignore faiss::IDSelectorBatch::set; 822 | %ignore faiss::IDSelectorBatch::bloom; 823 | 824 | %ignore faiss::InterruptCallback::instance; 825 | %ignore faiss::InterruptCallback::lock; 826 | %include "AuxIndexStructures.h" 827 | 828 | %{ 829 | // may be useful for lua code launched in background from shell 830 | 831 | #include 832 | void ignore_SIGTTIN() { 833 | signal(SIGTTIN, SIG_IGN); 834 | } 835 | %} 836 | 837 | void ignore_SIGTTIN(); 838 | 839 | 840 | %inline %{ 841 | 842 | // numpy misses a hash table implementation, hence this class. It 843 | // represents not found values as -1 like in the Index implementation 844 | 845 | struct MapLong2Long { 846 | std::unordered_map map; 847 | 848 | void add(size_t n, const int64_t *keys, const int64_t *vals) { 849 | map.reserve(map.size() + n); 850 | for (size_t i = 0; i < n; i++) { 851 | map[keys[i]] = vals[i]; 852 | } 853 | } 854 | 855 | long search(int64_t key) { 856 | if (map.count(key) == 0) { 857 | return -1; 858 | } else { 859 | return map[key]; 860 | } 861 | } 862 | 863 | void search_multiple(size_t n, int64_t *keys, int64_t * vals) { 864 | for (size_t i = 0; i < n; i++) { 865 | vals[i] = search(keys[i]); 866 | } 867 | } 868 | }; 869 | 870 | %} 871 | 872 | %inline %{ 873 | void wait() { 874 | // in gdb, use return to get out of this function 875 | for(int i = 0; i == 0; i += 0); 876 | } 877 | %} 878 | 879 | // End of file... 880 | --------------------------------------------------------------------------------
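In `swigfaiss4j.i`, each `%typemap(out)` expands its `DOWNCAST` entries into a chain of `if (dynamic_cast<...>($1)) { ... } else` tests, so the first type that matches wins. That is why the file asks for subclasses to appear before their parent: a `dynamic_cast` to a base class also succeeds for every derived object. The stand-alone sketch below illustrates the effect with hypothetical classes `Base`, `Mid`, and `Leaf` (not faiss types, only stand-ins for a chain such as `Index` -> `IndexIVF` -> `IndexIVFPQ`).

```cpp
#include <cassert>
#include <string>

// Hypothetical three-level hierarchy; a virtual destructor makes it
// polymorphic so dynamic_cast can perform downcasts.
struct Base { virtual ~Base() = default; };
struct Mid : Base {};
struct Leaf : Mid {};

// Same shape as the DOWNCAST chain: most-derived type tested first,
// generic type last, and a null input falls through every test.
std::string classify_ordered(const Base *p) {
    if (dynamic_cast<const Leaf *>(p)) return "Leaf";
    if (dynamic_cast<const Mid *>(p)) return "Mid";
    if (dynamic_cast<const Base *>(p)) return "Base";
    return "null";
}

// Parent tested before child: Mid also matches Leaf objects,
// so the Leaf branch can never be reached.
std::string classify_parent_first(const Base *p) {
    if (dynamic_cast<const Mid *>(p)) return "Mid";
    if (dynamic_cast<const Leaf *>(p)) return "Leaf";
    if (dynamic_cast<const Base *>(p)) return "Base";
    return "null";
}
```

With the parent-first order, a `Leaf` is reported as `Mid`, which in the SWIG wrapper would mean exposing a less specific proxy class than the object really has.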
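The `MapLong2Long` helper in `swigfaiss4j.i` is a thin wrapper over `std::unordered_map` that returns -1 for keys that were never added, mirroring how faiss reports missing results. The sketch below is a stand-alone copy so it compiles outside SWIG; it assumes `int64_t` keys and values (as in the `add()` signature), returns `int64_t` rather than `long` from `search()`, and does a single `find()` lookup instead of `count()` followed by `operator[]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>

// Stand-alone sketch of MapLong2Long, assuming int64_t keys and values.
struct MapLong2Long {
    std::unordered_map<std::int64_t, std::int64_t> map;

    // Insert n key/value pairs; a repeated key keeps the last value.
    void add(std::size_t n, const std::int64_t *keys, const std::int64_t *vals) {
        map.reserve(map.size() + n);
        for (std::size_t i = 0; i < n; i++) {
            map[keys[i]] = vals[i];
        }
    }

    // Value stored for key, or -1 when the key is absent.
    std::int64_t search(std::int64_t key) const {
        auto it = map.find(key);
        return it == map.end() ? -1 : it->second;
    }

    // Batched lookup: vals[i] = search(keys[i]) for i in [0, n).
    void search_multiple(std::size_t n, const std::int64_t *keys,
                         std::int64_t *vals) const {
        for (std::size_t i = 0; i < n; i++) {
            vals[i] = search(keys[i]);
        }
    }
};
```

Because -1 doubles as the "not found" marker, this convention only works if -1 is never stored as a real value.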
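The Lua-only `AsyncIndexSearchC` in `swigfaiss4j.i` starts `index->search` on a separate thread from its constructor and exposes `join()` to wait for the results. The same pattern can be sketched in portable C++ with `std::thread` in place of `pthread_create`/`pthread_join`. `FakeIndex` and `AsyncSearch` are hypothetical names: `FakeIndex` only mimics the shape of a faiss `search(n, x, k, distances, labels)` call (n queries, k results each, row-major outputs) and is not faiss code.

```cpp
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical stand-in for an index: each query i gets k results with
// label i and distance 0. Only the call shape resembles faiss.
struct FakeIndex {
    using idx_t = std::int64_t;
    void search(idx_t n, const float *x, idx_t k,
                float *distances, idx_t *labels) const {
        (void)x;  // the fake ignores the query vectors
        for (idx_t i = 0; i < n; i++) {
            for (idx_t j = 0; j < k; j++) {
                distances[i * k + j] = 0.0f;
                labels[i * k + j] = i;
            }
        }
    }
};

// Launch the search in the constructor, let the caller do other work,
// then call join() before reading distances/labels.
template <class IndexT>
class AsyncSearch {
public:
    using idx_t = typename IndexT::idx_t;
    AsyncSearch(const IndexT *index, idx_t n, const float *x, idx_t k,
                float *distances, idx_t *labels)
        : thread_([=] { index->search(n, x, k, distances, labels); }) {}

    void join() {
        if (thread_.joinable()) thread_.join();
    }

    // A joinable std::thread must not be destroyed, so join here too.
    ~AsyncSearch() { join(); }

private:
    std::thread thread_;
};
```

As with the pthread version, the query and output buffers must stay alive until `join()` returns.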
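The `cast_integer_to_*_ptr` helpers in `swigfaiss4j.i` turn an address passed across the language boundary as a `long` (for example a tensor data pointer) back into a typed pointer. That only works where `long` is pointer-sized, which holds on LP64 platforms such as 64-bit Linux and macOS but not on 64-bit Windows, where `long` is 32 bits. A minimal sketch of the same round trip using `std::intptr_t`; the two function names are hypothetical and not part of the interface file.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical counterparts of cast_integer_to_float_ptr that use
// std::intptr_t, an integer type wide enough to hold an object pointer.
inline std::intptr_t float_ptr_to_integer(float *p) {
    return reinterpret_cast<std::intptr_t>(p);
}

inline float *integer_to_float_ptr(std::intptr_t x) {
    return reinterpret_cast<float *>(x);
}
```

Converting a pointer to `std::intptr_t` and back to the same pointer type yields the original pointer, so the buffer can be read through the recovered address.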