├── src
├── main
│ ├── java
│ │ └── com
│ │ │ └── bj58
│ │ │ └── spider
│ │ │ ├── faiss
│ │ │ └── .gitignore
│ │ │ └── faiss4j
│ │ │ ├── IndexHelper.java
│ │ │ └── FaissIndex.java
│ └── resources
│ │ └── log4j2.xml
└── test
│ └── java
│ └── com
│ └── bj58
│ └── spider
│ └── faiss
│ └── tests
│ ├── IOTests.java
│ └── Examples.java
├── test.sh
├── swigc++2java.sh
├── gendylib.sh
├── README.md
├── pom.xml
└── swigfaiss4j.i
/src/main/java/com/bj58/spider/faiss/.gitignore:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Runs the FaissIndex demo from the assembled fat jar.
3 | # Both LD_LIBRARY_PATH and java.library.path point at the current directory,
4 | # so the native libraries are expected to sit next to this script.
5 | # Existing LD_LIBRARY_PATH entries are kept after ".".
6 | export LD_LIBRARY_PATH=".${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
7 | # Override the JVM by exporting JAVA, e.g. JAVA=/opt/soft/jdk/jdk1.8.0_66/bin/java
8 | JAVA="${JAVA:-java}"
9 | JAR="target/faiss4java-1.0.0-jar-with-dependencies.jar"
10 |
11 | # Allow core dumps so that a crash inside native code can be inspected afterwards.
12 | ulimit -c unlimited
13 | # Arguments: <vector file> <nprobe>. Expansions are quoted so paths with spaces survive.
14 | "$JAVA" -Djava.library.path="." -classpath "$JAR" com.bj58.spider.faiss4j.FaissIndex "part-00000" "128"
8 |
--------------------------------------------------------------------------------
/swigc++2java.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Regenerates the SWIG Java proxy classes and the JNI wrapper source swigfaiss4j.cpp.
3 | # Abort on the first failing command.
4 | set -e
5 |
6 | OUT_DIR=src/main/java/com/bj58/spider/faiss
7 | # Expects SWIG 4.0 (check with: swig -version). Override the binary by exporting SWIG.
8 | SWIG="${SWIG:-/usr/local/bin/swig}"
9 |
10 | # Drop previously generated sources so stale classes do not linger.
11 | rm -rf "$OUT_DIR"/*.java swigfaiss4j.cpp
12 | mkdir -p "$OUT_DIR"
13 |
14 | # -I../faiss/ must point at the faiss source tree.
15 | #"$SWIG" -c++ -java -package com.bj58.spider.faiss -o swigfaiss4j.cpp -outdir "$OUT_DIR"/ -Doverride= -I../faiss/ swigfaiss.swig
16 | "$SWIG" -c++ -java -package com.bj58.spider.faiss -o swigfaiss4j.cpp -outdir "$OUT_DIR"/ -Doverride= -I../faiss/ swigfaiss4j.i
7 |
8 |
--------------------------------------------------------------------------------
/gendylib.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Compiles the SWIG-generated swigfaiss4j.cpp into the JNI shared library libswigfaiss4j.<ext>.
3 |
4 | # mac
5 | SHAREDEXT=dylib
6 | # linux
7 | #SHAREDEXT=so
8 |
9 | OUTPUT=libswigfaiss4j.$SHAREDEXT
10 |
11 | # Remove the previous build output; this is the same file name the link step below writes.
12 | rm -f "$OUTPUT"
13 |
14 | # Change this to your own faiss installation prefix.
15 | FAISS_INSTALL_DIR=/usr/local/
16 |
17 | # -L$FAISS_INSTALL_DIR/lib: directory that contains libfaiss.so / libfaiss.dylib
18 | # -I$JAVA_HOME/include/: where jni.h is found
19 | # -I$JAVA_HOME/include/linux/ (darwin/ on mac): where jni_md.h is found
20 |
21 | # linux
22 | #CXX="g++ -std=c++11"
23 | #CXXFLAGS="-fPIC -fopenmp -m64 -Wno-unused-parameter -Wno-unused-parameter -Wno-strict-aliasing -g -O3 -Wextra -msse4 -mpopcnt -fpermissive -L$FAISS_INSTALL_DIR/lib -I$FAISS_INSTALL_DIR/include/faiss/ -I$JAVA_HOME/include/ -I$JAVA_HOME/include/linux/ "
24 | # mac
25 | CXX="/usr/local/opt/llvm/bin/clang++ -std=c++11"
26 | CXXFLAGS="-fPIC -fopenmp -m64 -Wall -g -O3 -Wextra -msse4 -mpopcnt -L$FAISS_INSTALL_DIR/lib -I$FAISS_INSTALL_DIR/include/faiss/ -I$JAVA_HOME/include/ -I$JAVA_HOME/include/darwin/"
27 |
28 | # CXX and CXXFLAGS are left unquoted on purpose: each holds several words.
29 | # -lfaiss comes after the source file because some linkers resolve libraries in
30 | # command-line order and only keep a library that an earlier input references.
31 | ${CXX} ${CXXFLAGS} swigfaiss4j.cpp -shared -o "$OUTPUT" -lfaiss
23 |
--------------------------------------------------------------------------------
/src/main/resources/log4j2.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/src/test/java/com/bj58/spider/faiss/tests/IOTests.java:
--------------------------------------------------------------------------------
1 | package com.bj58.spider.faiss.tests;
2 |
3 | import org.slf4j.Logger;
4 | import org.slf4j.LoggerFactory;
5 |
6 | import java.nio.file.Paths;
7 |
8 | /**
9 |  * Scaffolding for index read/write tests: native-library loading and a small fixed data set.
10 |  */
11 | public class IOTests {
12 |     private static final Logger log = LoggerFactory.getLogger(IOTests.class);
13 |
14 |     /**
15 |      * Loads the JNI wrapper library from the current working directory and then
16 |      * asks the JVM to load "faiss" from java.library.path.
17 |      *
18 |      * @throws UnsatisfiedLinkError if either library cannot be found or linked
19 |      */
20 |     public static void load() {
21 |         // gendylib.sh writes "libswigfaiss4j.<ext>"; mapLibraryName produces the
22 |         // platform-specific file name for "swigfaiss4j" instead of a fixed ".dylib" name.
23 |         String wrapperFile = System.mapLibraryName("swigfaiss4j");
24 |         System.load(Paths.get(".", wrapperFile).toAbsolutePath().toString());
25 |         System.loadLibrary("faiss");
26 |     }
27 |
28 |     /**
29 |      * Returns 15 three-dimensional vectors: five on each axis with coordinates 10 down to 6.
30 |      */
31 |     private static float[][] dummyData() {
32 |         return new float[][]{
33 |                 new float[]{10, 0, 0},
34 |                 new float[]{9, 0, 0},
35 |                 new float[]{8, 0, 0},
36 |                 new float[]{7, 0, 0},
37 |                 new float[]{6, 0, 0},
38 |
39 |                 new float[]{0, 10, 0},
40 |                 new float[]{0, 9, 0},
41 |                 new float[]{0, 8, 0},
42 |                 new float[]{0, 7, 0},
43 |                 new float[]{0, 6, 0},
44 |
45 |                 new float[]{0, 0, 10},
46 |                 new float[]{0, 0, 9},
47 |                 new float[]{0, 0, 8},
48 |                 new float[]{0, 0, 7},
49 |                 new float[]{0, 0, 6},
50 |         };
51 |     }
52 |
53 |     /**
54 |      * Placeholder for an index write/read round trip; no assertions are implemented yet.
55 |      */
56 |     public void testIndexWriteAndRead() {
57 |         // Target file of the intended round trip; not used until the test body is written.
58 |         String filename = "./index-1";
59 |
60 |
61 |     }
62 | }
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # faiss4java
2 |
3 | ### Introduction
4 |
5 | 原始[faiss](https://github.com/facebookresearch/faiss)只支持c++和python,本项目支持v1.5.3版本faiss java接口,主要参考[faiss4j](https://github.com/thenetcircle/faiss4j.git)。
6 | > 以下过程在Mac Mojave系统下执行
7 |
8 | ### Building faiss
9 | 1. 安装faiss源码,参考[官网](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md)
10 |
11 | ```
12 | ➜ git clone -b v1.5.3 https://github.com/facebookresearch/faiss.git
13 | ➜ cd faiss
14 | ➜ ./configure --without-cuda #注意1
15 | ➜ make all && make install
16 | ➜ make py && make py install
17 | ```
18 | 2. 测试openblas库:执行`make misc/test_blas && ./misc/test_blas`,输出类似:
19 |
20 | ```
21 | errors = ... (不用管)
22 | info=0000064b00000000
23 | Lapack uses xx-bit integers
24 | ```
25 | 出现如上结果,说明成功了,直接跳到Building faiss4java部分。
26 | 3. 如出现下面错误:
27 | ```
28 | ➜ make misc/test_blas
29 | ➜ ./misc/test_blas
30 | dyld: Library not loaded: @rpath/libopenblas_nehalemp-r0.2.19.dylib
31 | Referenced from: /Users/jiananliu/test/faiss/./misc/test_blas
32 | Reason: image not found
33 | [1] 89887 abort ./misc/test_blas
34 | ```
35 | 后面在java中System.loadLibrary("swigfaiss4j");的时候也会报上面类似的错误。
36 | 虽然上面执行./configure会显示
37 |
38 | ```
39 | ...
40 | checking for sgemm_ in -lopenblas... yes
41 | ...
42 | ```
43 |
44 | 但运行时仍然找不到对应的openblas动态库,所以需要用--with-blas参数显式指定它的路径,即
45 | ```
46 | ➜ ./configure --without-cuda --with-blas=/usr/local/opt/openblas/lib/libopenblas.dylib
47 | ➜ rm -f misc/test_blas #make clean不能删除这个
48 | ➜ make misc/test_blas
49 | ➜ ./misc/test_blas
50 | ```
51 |
52 | 对应的openblas路径可以用如下命令查找:
53 |
54 | ```
55 | ➜ brew info openblas
56 | ```
57 | 重新编译faiss源码。
58 | ### Building faiss4java
59 | 1. 注意faiss4java和faiss同级目录,确保swig版本是4.0。
60 |
61 | ```
62 | ➜ git clone https://github.com/belkov0912/faiss4java.git
63 | # swig c++转成java接口,并生成swigfaiss4j.cpp
64 | ➜ ./swigc++2java.sh
65 | # 生成动态库,mac后缀是dylib,linux是so
66 | ➜ ./gendylib.sh
67 | ```
68 |
69 | 2. 编译完java接口,用idea打开会提示如下错误:
70 |
71 | ```
72 | protected xxxclass(long cPtr, boolean cMemoryOwn) {
73 | super(swigfaissJNI.IndexReplicas_SWIGUpcast(cPtr), cMemoryOwn);
74 | swigCPtr = cPtr;
75 | }
76 |
77 | Error:java: 已在类 com.bj58.spider.faiss.xxx 中定义了构造器 xxx(long,boolean)
78 | ```
79 | 直接把提示的构造器删除就可以了(可能与swig解析c++模板类有关,不太清楚原理,欢迎大家指教)。
80 |
81 | 3. 运行com.bj58.spider.faiss.tests.Examples,需要配置java.library.path
82 |
83 | ```
84 | #VM options中加入
85 | -Djava.library.path="/xxx/faiss4j"
86 | ```
87 | 运行结果如下:
88 | ```
89 | 19-09-03 18:45:43 [main] INFO Examples:49 - is_trained = true
90 | 19-09-03 18:45:43 [main] INFO Examples:51 - ntotal = 10
91 | 19-09-03 18:45:43 [main] INFO Examples:62 - search 5 first vector of xb
92 | 19-09-03 18:45:44 [main] INFO Examples:64 - Vectors:
93 | 0 |0.00000 0.897081 0.0970035 0.0361408 0.149420
94 | 1 |0.00100000 0.286766 0.871114 0.127865 0.528833
95 | 2 |0.00200000 0.332409 0.414103 0.197723 0.603372
96 | ...
97 | ```
98 |
99 |
100 |
101 |
--------------------------------------------------------------------------------
/src/main/java/com/bj58/spider/faiss4j/IndexHelper.java:
--------------------------------------------------------------------------------
1 | package com.bj58.spider.faiss4j;
2 |
3 | import com.bj58.spider.faiss.*;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 |
7 | public class IndexHelper {
8 | private static final Logger log = LoggerFactory.getLogger(IndexHelper.class);
9 |
10 | public static String show(longlongArray a, int rows, int cols) {
11 | StringBuilder sb = new StringBuilder();
12 | for (int i = 0; i < rows; i++) {
13 | sb.append(i).append('\t').append('|');
14 | for (int j = 0; j < cols; j++) {
15 | sb.append(String.format("%5d ", a.getitem(i * cols + j)));
16 | }
17 | sb.append("\n");
18 | }
19 | return sb.toString();
20 | }
21 |
22 | public static String show(longArray a, int rows, int cols) {
23 | StringBuilder sb = new StringBuilder();
24 | for (int i = 0; i < rows; i++) {
25 | sb.append(i).append('\t').append('|');
26 | for (int j = 0; j < cols; j++) {
27 | sb.append(String.format("%5d ", a.getitem(i * cols + j)));
28 | }
29 | sb.append("\n");
30 | }
31 | return sb.toString();
32 | }
33 |
34 | public static String show(floatArray a, int rows, int cols) {
35 | StringBuilder sb = new StringBuilder();
36 | for (int i = 0; i < rows; i++) {
37 | sb.append(i).append('\t').append('|');
38 | for (int j = 0; j < cols; j++) {
39 | sb.append(String.format("%7g ", a.getitem(i * cols + j)));
40 | }
41 | sb.append("\n");
42 | }
43 | return sb.toString();
44 | }
45 |
46 | /**
47 |  * Copies a rectangular Java matrix into a newly allocated flat {@code floatArray},
48 |  * row after row, so that element (i, j) ends up at index d * i + j.
49 |  *
50 |  * @param vectors the rows to copy; must contain at least one row and all rows must have
51 |  *                the same length as the first one
52 |  * @return the flat array holding vectors.length * vectors[0].length values
53 |  * @throws IllegalArgumentException if vectors is null or empty, or if the rows differ in length
54 |  */
55 | public static floatArray makeFloatArray(float[][] vectors) {
56 |     if (vectors == null || vectors.length == 0) {
57 |         throw new IllegalArgumentException("vectors must contain at least one row");
58 |     }
59 |     int nb = vectors.length;
60 |     int d = vectors[0].length;
61 |     // Validate before allocating so that a bad row never leaves a half-filled array behind
62 |     // and a longer row is not silently cut off at d values.
63 |     for (int i = 0; i < nb; i++) {
64 |         if (vectors[i] == null || vectors[i].length != d) {
65 |             throw new IllegalArgumentException(
66 |                     "row " + i + " does not have the expected length " + d);
67 |         }
68 |     }
69 |     floatArray fa = new floatArray(d * nb);
70 |     for (int i = 0; i < nb; i++) {
71 |         float[] row = vectors[i];
72 |         for (int j = 0; j < d; j++) {
73 |             fa.setitem(d * i + j, row[j]);
74 |         }
75 |     }
76 |     return fa;
77 | }
57 |
58 | // 获取floatArray [i0 ~ i1)区间的所有行, 每行的维度是d
59 | public static float[][] getLineFromFloatArray(floatArray fa, int i0, int i1, int d) {
60 | float[][] rt = new float[i1-i0][d];
61 | for (int i=i0; i
2 |
4 |
5 | 4.0.0
6 | com.bj58.spider
7 | faiss4java
8 | 1.0.0
9 |
10 |
11 | UTF-8
12 | UTF-8
13 | 2.7.3
14 | 1.7
15 | 3.2.5
16 | 2.11
17 | ${scala.minor.version}.8
18 | 2.3.2
19 |
20 |
21 | UTF-8
22 | provided
23 | compile
24 | 2.0.10
25 | 1.0.89
26 | 1.0.2-SNAPSHOT
27 | 1.0.0-SNAPSHOT
28 | 1.0.0-SNAPSHOT
29 |
30 |
31 |
32 |
33 | org.apache.logging.log4j
34 | log4j-api
35 | 2.7
36 |
37 |
38 | commons-logging
39 | commons-logging
40 | 1.1.1
41 |
42 |
43 | org.apache.logging.log4j
44 | log4j-core
45 | 2.7
46 |
47 |
48 | org.apache.logging.log4j
49 | log4j-slf4j-impl
50 | 2.2
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | net.alchim31.maven
60 |
61 | scala-maven-plugin
62 | 3.1.5
63 |
64 |
65 |
66 |
67 |
68 | maven-assembly-plugin
69 |
70 |
71 | jar-with-dependencies
72 |
73 |
74 |
75 |
76 | net.alchim31.maven
77 | scala-maven-plugin
78 |
79 |
80 | scala-compile-first
81 | process-resources
82 |
83 | add-source
84 | compile
85 |
86 |
87 |
88 | scala-test-compile
89 |
90 | process-test-resources
91 |
92 | testCompile
93 |
94 |
95 |
96 |
97 |
98 | org.apache.maven.plugins
99 | maven-dependency-plugin
100 |
101 |
102 | copy
103 | package
104 |
105 | copy-dependencies
106 |
107 |
108 |
109 | ${project.build.directory}/lib
110 |
111 | true
112 | true
113 |
114 |
115 |
116 |
117 |
118 | org.apache.maven.plugins
119 | maven-compiler-plugin
120 | 3.6.1
121 |
122 | ${java.version}
123 | ${java.version}
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 | config-package
132 |
133 |
134 |
135 |
136 | org.apache.maven.plugins
137 | maven-antrun-plugin
138 | 1.7
139 |
140 |
141 | config-package
142 | package
143 |
144 | run
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 | true
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
/src/main/java/com/bj58/spider/faiss4j/FaissIndex.java:
--------------------------------------------------------------------------------
1 | package com.bj58.spider.faiss4j;
2 |
3 | import com.bj58.spider.faiss.*;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 |
7 | import java.io.BufferedReader;
8 | import java.io.File;
9 | import java.io.FileReader;
10 | import java.io.IOException;
11 | import java.util.ArrayList;
12 | import java.lang.Math;
13 |
14 | import static com.bj58.spider.faiss4j.IndexHelper.*;
15 |
16 | public class FaissIndex {
17 | private static final Logger logger = LoggerFactory.getLogger(FaissIndex.class);
18 | private static final boolean debug_mode = false;
19 |
20 | static {
21 | try {
22 | System.loadLibrary("swigfaiss4j");
23 | } catch (Exception ex) {
24 | throw new java.lang.RuntimeException("please make sure libs:[faiss, swigfaiss4j] in java.library.path");
25 | }
26 |
27 | logger.info("load libswigfaiss4j success");
28 | }
29 |
30 | private Index index;
31 | private floatArray xb;
32 | private int d;
33 | private int nb;
34 | private float[][] tmp_vec = null;
35 |
36 | private float[][] fvecs_read(String fileName, int min, int max) {
37 | File file = new File(fileName);
38 | int max_line;
39 | if (debug_mode) max_line = 100000; else max_line = 943139;
40 | if (max == -1)
41 | max = max_line;
42 |
43 | BufferedReader reader = null;
44 | int row = 0, column = 0;
45 | ArrayList vecs = new ArrayList<>(max);
46 | try {
47 | reader = new BufferedReader(new FileReader(file));
48 | String line;
49 | while ((line = reader.readLine()) != null) {
50 | int last_index = line.lastIndexOf(':');
51 | // String id = line.substring(0, last_index);
52 | String values = line.substring(last_index + 1);
53 | String[] sp = values.split(",");
54 | if (column == 0) column = sp.length;
55 | if (sp.length != column) {
56 | logger.warn(String.format("column:[%d] != real_columns:[%d] row:[%d] line:[%s]", column, sp.length, row, line));
57 | continue;
58 | }
59 |
60 | float[] vs = IndexHelper.toFloatArray(sp);
61 | vecs.add(vs);
62 | row += 1;
63 | }
64 | reader.close();
65 | } catch (IOException e) {
66 | e.printStackTrace();
67 | } finally {
68 | if (reader != null) {
69 | try {
70 | reader.close();
71 | } catch (IOException e1) {
72 | logger.error("close reader failed");
73 | }
74 | }
75 | }
76 |
77 | return vecs.toArray(new float[row][column]);
78 | }
79 |
80 | public void set_train_data(String file_name) {
81 | float[][] data = fvecs_read(file_name, 0, -1);
82 | this.xb = makeFloatArray(data);
83 | this.nb = data.length;
84 | this.d = data[0].length;
85 | this.tmp_vec = data;
86 | logger.info(String.format("train vecs row:[%d] columns:[%d]", nb, d));
87 | }
88 |
89 | public void printMatrixStats() {
90 | MatrixStats matrixStats = new MatrixStats(nb, d, xb.cast());
91 | logger.info("-----------------------------------");
92 | logger.info(matrixStats.getComments());
93 | logger.info("-----------------------------------");
94 | }
95 |
96 | public void create_index(boolean use_imi, int nprob) {
97 | try {
98 | int nhash;
99 | long ncentroids;
100 | Index quantizer;
101 | if (use_imi) {
102 | nhash = 2;
103 | long nbits_subq = (int)(Math.log(nb+1)/Math.log(2))/2; // good choice in general
104 | ncentroids = 1 << (nhash * nbits_subq); // total # of centroids == nlist
105 | // MultiIndexQuantizer
106 | quantizer = new MultiIndexQuantizer(d, nhash, nbits_subq);
107 | logger.info(String.format("IMI (%d,%d): %d virtual centroids (target: %d base vectors)", nhash, nbits_subq, ncentroids, nb));
108 |
109 | } else {
110 | // IndexFlatL2
111 | quantizer = new IndexFlatL2(d);
112 | ncentroids = 4096;
113 | logger.info(String.format("IF: %d virtual centroids (target: %d base vectors)", ncentroids, nb));
114 | }
115 |
116 | IndexIVFFlat index = new IndexIVFFlat(quantizer, d, ncentroids, MetricType.METRIC_L2);
117 |
118 | logger.info(String.format("index is trained: %b", index.getIs_trained()));
119 |
120 | if (use_imi) {
121 | index.setQuantizer_trains_alone('1');
122 | }
123 |
124 | // define the number of probes. 2048 is for high-dim, overkilled in practice
125 | // Use 4-1024 depending on the trade-off speed accuracy that you want
126 | index.setNprobe(nprob);
127 | logger.info(String.format("set nprob:%d", nprob));
128 |
129 | long t0 = System.currentTimeMillis();
130 | // logger.info("Vectors:\n{}", show(tb, trainData.length, dimension));
131 | index.train(nb, xb.cast());
132 | long t1 = System.currentTimeMillis();
133 | logger.info(String.format("[%d ms] Training the index", t1-t0));
134 |
135 | index.add(nb, xb.cast());
136 | long t2 = System.currentTimeMillis();
137 | logger.info(String.format("[%d ms] Adding the vectors to the index", t2-t1));
138 |
139 | this.index = index;
140 | } catch (Exception e) {
141 | logger.error("failed", e);
142 | }
143 | }
144 |
145 | public void search(float[][] query) {
146 | int k = 5;
147 | int nq = query.length;
148 | floatArray q = makeFloatArray(query);
149 | longlongArray labels = new longlongArray(k * nq);
150 | floatArray distances = new floatArray(k * nq);
151 |
152 | long t4 = System.currentTimeMillis();
153 | index.search(query.length, q.cast(), k, distances.cast(), labels.cast());
154 |
155 | long t5 = System.currentTimeMillis();
156 | logger.info(String.format("[%dms] Searching the %d nearest neighbors of %d vectors in the index", t5 - t4, k, nq));
157 | logger.info("Query results (vector ids, then distances):");
158 |
159 | logger.info("Distances:\n{}", show(distances, nq, k));
160 | logger.info("Labels:\n{}", show(labels, nq, k));
161 | }
162 |
163 | public static void main(String[] argv) {
164 |
165 | String file_name = "part-00000";
166 | int nprob = 512;
167 | if (argv.length >= 2) {
168 | file_name = argv[0];
169 | nprob = Integer.parseInt(argv[1]);
170 | }
171 | logger.info("read file:" + file_name);
172 |
173 | FaissIndex index = new FaissIndex();
174 |
175 | index.set_train_data(file_name);
176 |
177 | index.printMatrixStats();
178 |
179 | index.create_index(true, nprob);
180 |
181 | {
182 | // remember a few elements from the database as queries
183 | int i0 = 1234;
184 | int i1 = 1245;
185 |
186 | // float[][] query = Arrays.copyOfRange(index.tmp_vec, i0, i1);
187 | float[][] query = getLineFromFloatArray(index.xb, i0, i1, index.d);
188 |
189 | index.search(query);
190 | }
191 | }
192 | }
193 |
--------------------------------------------------------------------------------
/src/test/java/com/bj58/spider/faiss/tests/Examples.java:
--------------------------------------------------------------------------------
1 | package com.bj58.spider.faiss.tests;
2 |
3 | import com.bj58.spider.faiss.*;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 |
7 | import java.util.Random;
8 |
9 | import static com.bj58.spider.faiss4j.IndexHelper.*;
10 |
11 |
12 | /**
13 |  * Runnable examples for the SWIG-generated faiss bindings: flat index search,
14 |  * range search and an IVF-Flat index.
15 |  */
16 | public class Examples {
17 |     private static final Logger log = LoggerFactory.getLogger(Examples.class);
18 |
19 |     static {
20 |         // Print the library search path before loading, which helps when loadLibrary fails.
21 |         String property = System.getProperty("java.library.path");
22 |         System.out.println(property);
23 |         // System.loadLibrary("faiss");
24 |         System.loadLibrary("swigfaiss4j");
25 |         System.out.println("load libswigfaiss4j success");
26 |     }
27 |
28 |
29 |     /**
30 |      * Fills an IndexFlatL2 with random vectors and searches the neighbours of the first five.
31 |      */
32 |     public static void testFlat() {
33 |         int d = 5; // dimension
34 |         int nb = 10; // database size
35 |         int nq = 5; // nb of queries: the first nq rows of xb
36 |         int k = 4; // neighbours returned per query
37 |
38 |         try {
39 |             floatArray xb = new floatArray(d * nb);
40 |
41 |             Random rand = new Random();
42 |
43 |             for (int i = 0; i < nb; i++) {
44 |                 for (int j = 0; j < d; j++) {
45 |                     xb.setitem(d * i + j, rand.nextFloat());
46 |                 }
47 |                 // the first component of every row is replaced by i / 1000
48 |                 xb.setitem(d * i, (float) (i / 1000.0));
49 |             }
50 |
51 |             IndexFlatL2 index = new IndexFlatL2(d);
52 |             log.info("is_trained = {}", index.getIs_trained());
53 |             index.add(nb, xb.cast());
54 |             log.info("ntotal = {}", index.getNtotal());
55 |
56 |             longlongArray I = new longlongArray(k * nq);
57 |             floatArray D = new floatArray(k * nq);
58 |
59 |             log.info("search 5 first vector of xb");
60 |             index.search(nq, xb.cast(), k, D.cast(), I.cast());
61 |             log.info("Vectors:\n{}", show(xb, nb, d));
62 |             log.info("Distances:\n{}", show(D, nq, k));
63 |             log.info("I:\n{}", show(I, nq, k));
64 |         } catch (Exception e) {
65 |             log.error("failed", e);
66 |         }
67 |     }
68 |
69 |     /**
70 |      * Indexes the axis-aligned dummy data in an IndexFlatL2 and searches the three
71 |      * nearest neighbours of (0, 1, 8).
72 |      */
73 |     public void simpleTest() {
74 |         try {
75 |             float[][] data = dummyData3d(10);
76 |             int d = data[0].length;
77 |             int numberOfVector = data.length;
78 |             floatArray xb = makeFloatArray(data);
79 |             IndexFlatL2 index = new IndexFlatL2(d);
80 |             // add_with_ids is not used here; calling it on this index failed with:
81 |             // what(): Error in virtual void faiss::Index::add_with_ids(faiss::Index::idx_t, const float*, const long int*) at Index.cpp:46: add_with_ids not implemented for this type of index
82 |             index.add(numberOfVector, xb.cast());
83 |
84 |             log.info("ntotal = {}", index.getNtotal());
85 |
86 |             int resultSize = 3;
87 |             float[][] queryConds = {new float[]{0, 1, 8}};
88 |
89 |             floatArray query = makeFloatArray(queryConds);
90 |             longlongArray labels = new longlongArray(resultSize);
91 |             floatArray distances = new floatArray(resultSize);
92 |             index.search(1, query.cast(), resultSize, distances.cast(), labels.cast());
93 |
94 |             log.info("Vectors:\n{}", show(xb, numberOfVector, d));
95 |             log.info("Query:\n{}", show(query, queryConds.length, queryConds[0].length));
96 |             log.info("Distances:\n{}", show(distances, 1, resultSize));
97 |             log.info("Labels:\n{}", show(labels, 1, resultSize));
98 |         } catch (Exception e) {
99 |             log.error("failed", e);
100 |         }
101 |     }
102 |
103 |     /**
104 |      * Runs a range search with radius 0.3 around (0, 1, 8) on the dummy data.
105 |      */
106 |     public void testSearchRange() {
107 |         float[][] data = dummyData3d(20);
108 |         int d = data[0].length;
109 |         int numberOfVector = data.length;
110 |
111 |         try {
112 |             floatArray xb = makeFloatArray(data);
113 |             IndexFlatL2 index = new IndexFlatL2(d);
114 |             index.add(numberOfVector, xb.cast());
115 |
116 |             int resultSize = 4;
117 |             float[][] queryConds = {new float[]{0, 1, 8}};
118 |             floatArray query = makeFloatArray(queryConds);
119 |             int querySize = queryConds.length;
120 |
121 |             // The constructor argument is the number of queries, not the number of results.
122 |             RangeSearchResult re = new RangeSearchResult(querySize);
123 |             index.range_search(querySize, query.cast(), 0.3f, re);
124 |
125 |             longlongArray labels = longlongArray.frompointer(re.getLabels());
126 |             floatArray distances = floatArray.frompointer(re.getDistances());
127 |
128 |             log.info("Vectors:\n{}", show(xb, numberOfVector, d));
129 |             log.info("Query:\n{}", show(query, querySize, queryConds[0].length));
130 |             // NOTE(review): the two calls below read resultSize entries per query regardless of
131 |             // how many hits range_search produced; confirm the real count before trusting them.
132 |             log.info("Distances:\n{}", show(distances, querySize, resultSize));
133 |             log.info("Labels:\n{}", show(labels, querySize, resultSize));
134 |
135 |         } catch (Exception e) {
136 |             log.error("failed", e);
137 |         }
138 |     }
139 |
140 |     /**
141 |      * Trains an IndexIVFFlat on the dummy data, adds random vectors and searches the ten
142 |      * nearest neighbours of (0, 0, 8).
143 |      */
144 |     public static void egIndexIVFFlat() {
145 |         try {
146 |             float[][] data = randomData3d(200);
147 |             int dimension = data[0].length;
148 |             int numberOfVector = data.length;
149 |             int nlist = 6;
150 |             int nprobe = 2;
151 |
152 |             IndexFlatL2 quantizer = new IndexFlatL2(dimension);
153 |             IndexIVFFlat index = new IndexIVFFlat(quantizer, dimension, nlist, MetricType.METRIC_L2);
154 |
155 |             float[][] trainData = dummyData3d(5);
156 |             floatArray tb = makeFloatArray(trainData);
157 |             log.info("Vectors:\n{}", show(tb, trainData.length, dimension));
158 |             index.train(trainData.length, tb.cast());
159 |
160 |             floatArray xb = makeFloatArray(data);
161 |             index.add(numberOfVector, xb.cast());
162 |
163 |             int resultSize = 10;
164 |             float[][] queryConds = {new float[]{0, 0, 8}};
165 |
166 |             floatArray query = makeFloatArray(queryConds);
167 |             longlongArray labels = new longlongArray(resultSize);
168 |             floatArray distances = new floatArray(resultSize);
169 |
170 |             int numberOfQuery = queryConds.length;
171 |             index.setNprobe(nprobe);
172 |             index.search(numberOfQuery, query.cast(), resultSize, distances.cast(), labels.cast());
173 |
174 |             log.info("Vectors:\n{}", show(xb, numberOfVector, dimension));
175 |             log.info("Query:\n{}", show(query, queryConds.length, queryConds[0].length));
176 |             log.info("Distances:\n{}", show(distances, 1, resultSize));
177 |             log.info("Labels:\n{}", show(labels, 1, resultSize));
178 |         } catch (Exception e) {
179 |             log.error("failed", e);
180 |         }
181 |     }
182 |
183 |     /**
184 |      * Builds 3 * size vectors: for each of the three axes, size points whose coordinate on
185 |      * that axis runs from -size/2 upwards in steps of 1 and whose other coordinates are 0.
186 |      */
187 |     private static float[][] dummyData3d(int size) {
188 |         float[][] data = new float[size * 3][3];
189 |         float half = size / 2.0f;
190 |         for (int i = 0; i < 3; i++) {
191 |             for (int j = 0; j < size; j++) {
192 |                 float[] row = new float[]{0, 0, 0};
193 |                 row[i] = j - half;
194 |                 data[i * size + j] = row;
195 |             }
196 |         }
197 |         return data;
198 |     }
199 |
200 |     /**
201 |      * Builds 3 * size vectors with every coordinate drawn from rand.nextFloat() * size.
202 |      */
203 |     private static float[][] randomData3d(int size) {
204 |         float[][] data = new float[size * 3][3];
205 |         Random rand = new Random();
206 |         for (int i = 0, j = data.length; i < j; i++) {
207 |             float[] row = new float[]{rand.nextFloat() * size, rand.nextFloat() * size, rand.nextFloat() * size};
208 |             data[i] = row;
209 |         }
210 |         return data;
211 |     }
212 |
213 |     public static void main(String[] argv) {
214 |         // testFlat();
215 |         egIndexIVFFlat();
216 |     }
217 |
218 | }
204 |
--------------------------------------------------------------------------------
/swigfaiss4j.i:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | // -*- C++ -*-
9 |
10 | // This file describes the C++ to scripting-language bridge. It was written
11 | // for Lua and Python and is used here to generate Java bindings. It contains
12 | // mainly includes and a few macros. Preprocessor macros of interest:
13 | // SWIGJAVA: Java-specific code (array helper classes)
14 | // SWIGLUA: Lua-specific code
15 | // SWIGPYTHON: Python-specific code
16 | // GPU_WRAPPER: also compile interfaces for GPU.
16 |
17 | %rename(__operator_lt__) operator<;
18 | %rename(__operator_middle__) operator[];
19 | %rename(__operator_small__) operator();
20 |
21 | %module swigfaiss;
22 |
23 | #ifdef SWIGJAVA
24 | %include "arrays_java.i"
25 | //%apply int[] {int *};
26 | //%apply double[] {double *};
27 | //%apply float[] {float *};
28 | //%apply long[] {long *};
29 | %include "carrays.i"
30 | %array_class(int, intArray);
31 | %array_class(float, floatArray);
32 | %array_class(long, longArray);
33 | %array_class(long long, longlongArray);
34 | %array_class(double, doubleArray);
35 | #endif
36 |
37 | // fbode SWIG fails on warnings, so make them non fatal
38 | //#pragma SWIG nowarn=321
39 | //#pragma SWIG nowarn=403
40 | //#pragma SWIG nowarn=325
41 | //#pragma SWIG nowarn=389
42 | //#pragma SWIG nowarn=341
43 | //#pragma SWIG nowarn=512
44 |
45 |
46 | // stdint.i teaches SWIG the fixed-width integer types such as int64_t used below.
47 | %include <stdint.i>
48 | // size_t is presented to SWIG as a signed 64-bit integer.
49 | typedef int64_t size_t;
48 |
49 | #define __restrict
50 |
51 |
52 | /*******************************************************************
53 | * Copied verbatim to wrapper. Contains the C++-visible includes, and
54 | * the language includes for their respective matrix libraries.
55 | *******************************************************************/
56 |
57 | %{
58 |
59 |
60 | #include <stdint.h>
61 | #include <omp.h>
62 |
63 |
64 | #include "IndexFlat.h"
65 | #include "VectorTransform.h"
66 | #include "IndexLSH.h"
67 | #include "IndexPQ.h"
68 |
69 | #include "IndexIVF.h"
70 | #include "IndexIVFPQ.h"
71 | #include "IndexIVFFlat.h"
72 | #include "IndexScalarQuantizer.h"
73 | #include "IndexIVFSpectralHash.h"
74 | #include "ThreadedIndex.h"
75 | #include "IndexShards.h"
76 | #include "IndexReplicas.h"
77 | #include "HNSW.h"
78 | #include "IndexHNSW.h"
79 | #include "MetaIndexes.h"
80 | #include "FaissAssert.h"
81 |
82 | #include "IndexBinaryFlat.h"
83 | #include "IndexBinaryIVF.h"
84 | #include "IndexBinaryFromFloat.h"
85 | #include "IndexBinaryHNSW.h"
86 |
87 | #include "index_io.h"
88 |
89 | #include "IVFlib.h"
90 | #include "utils.h"
91 | #include "distances.h"
92 | #include "Heap.h"
93 | #include "AuxIndexStructures.h"
94 | #include "OnDiskInvertedLists.h"
95 |
96 | #include "Clustering.h"
97 |
98 | #include "hamming.h"
99 |
100 | #include "AutoTune.h"
101 |
102 |
103 |
104 | %}
105 |
106 |
107 | /*******************************************************************
108 | * Types of vectors we want to manipulate at the scripting language
109 | * level.
110 | *******************************************************************/
111 |
112 | // simplified interface for vector
113 | namespace std {
114 |
115 | // Reduced declaration of std::vector: only the members listed here are wrapped.
116 | template<typename T>
117 | class vector {
118 | public:
119 |     vector();
120 |     void push_back(T);
121 |     void clear();
122 |     T * data();
123 |     size_t size();
124 |     T at (size_t n) const;
125 |     void resize (size_t n);
126 |     void swap (vector<T> & other);
127 | };
128 | };
128 |
129 |
130 |
131 | %template(FloatVector) std::vector<float>;
132 | %template(DoubleVector) std::vector<double>;
133 | %template(ByteVector) std::vector<uint8_t>;
134 | %template(CharVector) std::vector<char>;
135 | // NOTE(hoss): Using unsigned long instead of uint64_t because OSX defines
136 | // uint64_t as unsigned long long, which SWIG is not aware of.
137 | %template(Uint64Vector) std::vector<unsigned long>;
138 | %template(LongVector) std::vector<long>;
139 | %template(IntVector) std::vector<int>;
140 | %template(VectorTransformVector) std::vector<faiss::VectorTransform*>;
141 | %template(OperatingPointVector) std::vector<faiss::OperatingPoint>;
142 | %template(InvertedListsPtrVector) std::vector<faiss::InvertedLists*>;
143 | %template(FloatVectorVector) std::vector<std::vector<float> >;
144 | %template(ByteVectorVector) std::vector<std::vector<unsigned char> >;
145 | %template(LongVectorVector) std::vector<std::vector<long> >;
146 |
147 | #ifdef GPU_WRAPPER
148 | %template(GpuResourcesVector) std::vector<faiss::gpu::GpuResources*>;
149 | #endif
150 |
151 | %include <std_string.i>
152 |
153 | // produces an error on the Mac
154 | %ignore faiss::hamming;
155 |
156 | /*******************************************************************
157 | * Parse headers
158 | *******************************************************************/
159 |
160 |
161 | %ignore *::cmp;
162 |
163 | %include "Heap.h"
164 | %include "hamming.h"
165 |
166 | int get_num_gpus();
167 |
168 | #ifdef GPU_WRAPPER
169 |
170 | %{
171 |
172 | #include "gpu/StandardGpuResources.h"
173 | #include "gpu/GpuIndicesOptions.h"
174 | #include "gpu/GpuClonerOptions.h"
175 | #include "gpu/utils/MemorySpace.h"
176 | #include "gpu/GpuIndex.h"
177 | #include "gpu/GpuIndexFlat.h"
178 | #include "gpu/GpuIndexIVF.h"
179 | #include "gpu/GpuIndexIVFPQ.h"
180 | #include "gpu/GpuIndexIVFFlat.h"
181 | #include "gpu/GpuIndexBinaryFlat.h"
182 | #include "gpu/GpuAutoTune.h"
183 | #include "gpu/GpuDistance.h"
184 |
185 | int get_num_gpus()  // GPU_WRAPPER build: forwards to faiss::gpu::getNumDevices()
186 | {
187 | return faiss::gpu::getNumDevices();
188 | }
189 |
190 | %}
191 |
192 | // causes weird wrapper bug
193 | %ignore *::getMemoryManager;
194 | %ignore *::getMemoryManagerCurrentDevice;
195 |
196 | %include "gpu/GpuResources.h"
197 | %include "gpu/StandardGpuResources.h"
198 |
199 | #else
200 |
201 | %{
202 | int get_num_gpus()  // build without GPU_WRAPPER: always reports zero GPUs
203 | {
204 | return 0;
205 | }
206 | %}
207 |
208 |
209 | #endif
210 |
211 |
212 | %include "utils.h"
213 |
214 | %include "Index.h"
215 | %include "Clustering.h"
216 |
217 | %include "distances.h"
218 |
219 | %ignore faiss::ProductQuantizer::get_centroids(size_t,size_t) const;
220 |
221 | %include "ProductQuantizer.h"
222 |
223 | %include "VectorTransform.h"
224 | %include "IndexFlat.h"
225 | %include "IndexLSH.h"
226 | %include "PolysemousTraining.h"
227 | %include "IndexPQ.h"
228 |
229 | %include "InvertedLists.h"
230 |
231 | // swiginterface.i provides the %interface_impl macro used on the next line.
232 | %include <swiginterface.i>
233 | %interface_impl(faiss::Level1Quantizer);
233 | //%interface_impl(faiss::Index);
234 | %ignore InvertedListScanner;
235 | %ignore BinaryInvertedListScanner;
236 | %include "IndexIVF.h"
237 | // NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the
238 | // non-const one.
239 | %warnfilter(509) extract_index_ivf;
240 | %include "IVFlib.h"
241 | %include "IndexScalarQuantizer.h"
242 | %include "IndexIVFSpectralHash.h"
243 | %include "HNSW.h"
244 | %include "IndexHNSW.h"
245 | %include "IndexIVFFlat.h"
246 | %include "OnDiskInvertedLists.h"
247 |
248 | %ignore faiss::IndexIVFPQ::alloc_type;
249 | %include "IndexIVFPQ.h"
250 |
251 | %include "IndexBinary.h"
252 | %include "IndexBinaryFlat.h"
253 | %include "IndexBinaryIVF.h"
254 | %include "IndexBinaryFromFloat.h"
255 | %include "IndexBinaryHNSW.h"
256 |
257 |
258 |
259 | // %ignore faiss::IndexReplicas::at(int) const;
260 |
261 | %include "ThreadedIndex.h"
262 | %template(ThreadedIndexBase) faiss::ThreadedIndex<faiss::Index>;
263 | %template(ThreadedIndexBaseBinary) faiss::ThreadedIndex<faiss::IndexBinary>;
264 |
265 | %include "IndexShards.h"
266 | %template(IndexShards) faiss::IndexShardsTemplate<faiss::Index>;
267 | %template(IndexBinaryShards) faiss::IndexShardsTemplate<faiss::IndexBinary>;
268 |
269 | %include "IndexReplicas.h"
270 | %template(IndexReplicas) faiss::IndexReplicasTemplate<faiss::Index>;
271 | %template(IndexBinaryReplicas) faiss::IndexReplicasTemplate<faiss::IndexBinary>;
272 |
273 |
274 | %include "MetaIndexes.h"
275 | %template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
276 | %template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
277 | %template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
278 | %template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;
279 |
280 | #ifdef GPU_WRAPPER
281 | // GPU index wrappers: only generated when GPU_WRAPPER is defined
282 | // quiet SWIG warnings
283 | %ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF;
284 |
285 | %include "gpu/GpuIndicesOptions.h"
286 | %include "gpu/GpuClonerOptions.h"
287 | %include "gpu/utils/MemorySpace.h"
288 | %include "gpu/GpuIndex.h"
289 | %include "gpu/GpuIndexFlat.h"
290 | %include "gpu/GpuIndexIVF.h"
291 | %include "gpu/GpuIndexIVFPQ.h"
292 | %include "gpu/GpuIndexIVFFlat.h"
293 | %include "gpu/GpuIndexBinaryFlat.h"
294 | %include "gpu/GpuDistance.h"
295 |
296 | #ifdef SWIGLUA
297 |
298 | /// in Lua, swigfaiss_gpu is known as swigfaiss
299 | %luacode {
300 | local swigfaiss = swigfaiss_gpu
301 | }
302 |
303 | #endif
304 |
305 |
306 | #endif
307 |
308 |
309 |
310 |
311 | /*******************************************************************
312 | * Lua-specific: support async execution of searches in an index
313 | * Python equivalent is just to use Python threads.
314 | *******************************************************************/
315 |
316 |
317 | #ifdef SWIGLUA
318 |
319 | %{
320 |
321 |
322 | namespace faiss {
323 |
324 | // Runs index->search() on a background pthread that the constructor starts.
325 | // The pointers passed in are stored as-is and used by the worker; join() waits for it.
326 | struct AsyncIndexSearchC {
327 |     typedef Index::idx_t idx_t;
328 |     const Index * index;
329 |     idx_t n;
330 |     const float *x;
331 |     idx_t k;
332 |     float *distances;
333 |     idx_t *labels;
334 |
335 |     // false until the worker has returned from search(); plain bool, so polling is best-effort
336 |     bool is_finished;
337 |     pthread_t thread;
338 |
339 |     AsyncIndexSearchC (const Index *index,
340 |                        idx_t n, const float *x, idx_t k,
341 |                        float *distances, idx_t *labels):
342 |         index(index), n(n), x(x), k(k), distances(distances),
343 |         labels(labels)
344 |     {
345 |         is_finished = false;
346 |         pthread_create (&thread, NULL, &AsyncIndexSearchC::callback,
347 |                         this);
348 |     }
349 |
350 |     static void *callback (void *arg)
351 |     {
352 |         AsyncIndexSearchC *aidx = (AsyncIndexSearchC *)arg;
353 |         aidx->do_search();
354 |         return NULL;
355 |     }
356 |
357 |     void do_search ()
358 |     {
359 |         index->search (n, x, k, distances, labels);
360 |         is_finished = true; // previously never set, so the flag could not report completion
361 |     }
362 |
363 |     void join ()
364 |     {
365 |         pthread_join (thread, NULL);
366 |     }
367 | };
368 | }
369 |
370 | %}
371 |
372 | // re-declare only what we need: SWIG wraps just the members listed here
373 | namespace faiss {
374 |
375 | struct AsyncIndexSearchC {
376 | typedef Index::idx_t idx_t;
377 | bool is_finished;
378 | AsyncIndexSearchC (const Index *index,
379 | idx_t n, const float *x, idx_t k,
380 | float *distances, idx_t *labels);
381 |
382 |
383 | void join ();
384 | };
385 |
386 | }
387 |
388 |
389 | #endif
390 |
391 |
392 |
393 |
394 | /*******************************************************************
395 | * downcast return of some functions so that the sub-class is used
396 | * instead of the generic upper-class.
397 | *******************************************************************/
398 | #ifdef SWIGJAVA
399 | // Java: each macro expands to one "if (...) {...} else" link that stores the downcast pointer as the jlong result.
400 | %define DOWNCAST(subclass)
401 |     if (dynamic_cast<faiss::subclass *> ($1)) {
402 |       faiss::subclass *instance_ptr = (faiss::subclass *)$1;
403 |       $result = (jlong)instance_ptr;
404 |     } else
405 | %enddef
406 | // The second argument is unused for Java; it is kept so the call sites are shared with Python.
407 | %define DOWNCAST2(subclass, longname)
408 |     if (dynamic_cast<faiss::subclass *> ($1)) {
409 |       faiss::subclass *instance_ptr = (faiss::subclass *)$1;
410 |       $result = (jlong)instance_ptr;
411 |     } else
412 | %enddef
413 | // GPU classes are looked up in namespace faiss::gpu (the Python variant mangles them as faiss__gpu__).
414 | %define DOWNCAST_GPU(subclass)
415 |     if (dynamic_cast<faiss::gpu::subclass *> ($1)) {
416 |       faiss::gpu::subclass *instance_ptr = (faiss::gpu::subclass *)$1;
417 |       $result = (jlong)instance_ptr;
418 |     } else
419 | %enddef
420 |
421 | #endif
422 |
423 | #ifdef SWIGPYTHON
424 | // Python: each macro expands to one "if (...) {...} else" link that wraps $1 with the SWIG type of the matched class.
425 | %define DOWNCAST(subclass)
426 |     if (dynamic_cast<faiss::subclass *> ($1)) {
427 |       $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner);
428 |     } else
429 | %enddef
430 | // For template typedefs the C++ name and the mangled SWIG type name differ, hence two arguments.
431 | %define DOWNCAST2(subclass, longname)
432 |     if (dynamic_cast<faiss::subclass *> ($1)) {
433 |       $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner);
434 |     } else
435 | %enddef
436 |
437 | %define DOWNCAST_GPU(subclass)
438 |     if (dynamic_cast<faiss::gpu::subclass *> ($1)) {
439 |       $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__gpu__ ## subclass,$owner);
440 |     } else
441 | %enddef
442 |
443 | #endif
444 |
445 | %newobject read_index;
446 | %newobject read_index_binary;
447 | %newobject read_VectorTransform;
448 | %newobject read_ProductQuantizer;
449 | %newobject clone_index;
450 | %newobject clone_VectorTransform;
451 | // Subclasses should appear before their parent: the first DOWNCAST that matches wins
452 | %typemap(out) faiss::Index * {
453 |     DOWNCAST2 ( IndexIDMap, IndexIDMapTemplateT_faiss__Index_t )
454 |     DOWNCAST2 ( IndexIDMap2, IndexIDMap2TemplateT_faiss__Index_t )
455 |     DOWNCAST2 ( IndexShards, IndexShardsTemplateT_faiss__Index_t )
456 |     DOWNCAST2 ( IndexReplicas, IndexReplicasTemplateT_faiss__Index_t )
457 |     DOWNCAST ( IndexIVFPQR )
458 |     DOWNCAST ( IndexIVFPQ )
459 |     DOWNCAST ( IndexIVFSpectralHash )
460 |     DOWNCAST ( IndexIVFScalarQuantizer )
461 |     DOWNCAST ( IndexIVFFlatDedup )
462 |     DOWNCAST ( IndexIVFFlat )
463 |     DOWNCAST ( IndexIVF )
464 |     DOWNCAST ( IndexFlat )
465 |     DOWNCAST ( IndexPQ )
466 |     DOWNCAST ( IndexScalarQuantizer )
467 |     DOWNCAST ( IndexLSH )
468 |     DOWNCAST ( IndexPreTransform )
469 |     DOWNCAST ( MultiIndexQuantizer )
470 |     DOWNCAST ( IndexHNSWFlat )
471 |     DOWNCAST ( IndexHNSWPQ )
472 |     DOWNCAST ( IndexHNSWSQ )
473 |     DOWNCAST ( IndexHNSW2Level )
474 |     DOWNCAST ( Index2Layer )
475 | #ifdef GPU_WRAPPER
476 |     DOWNCAST_GPU ( GpuIndexIVFPQ )
477 |     DOWNCAST_GPU ( GpuIndexIVFFlat )
478 |     DOWNCAST_GPU ( GpuIndexFlat )
479 | #endif
480 |     // default for non-recognized classes
481 |     DOWNCAST ( Index )
482 |     if ($1 == NULL) { // NULL result: same handling as the IndexBinary typemap; Lua does not need a push for nil
483 | #if defined(SWIGPYTHON)
484 |         $result = SWIG_Py_Void();
485 | #elif defined(SWIGJAVA)
486 |         $result = 0;
487 | #endif
488 |     } else {
489 |         assert(false);
490 |     }
491 | #ifdef SWIGLUA
492 |     SWIG_arg++;
493 | #endif
494 | }
495 |
496 | %typemap(out) faiss::IndexBinary * { // chain of DOWNCAST links; order matters, first match wins
497 | DOWNCAST2 ( IndexBinaryReplicas, IndexReplicasTemplateT_faiss__IndexBinary_t )
498 | DOWNCAST2 ( IndexBinaryIDMap, IndexIDMapTemplateT_faiss__IndexBinary_t )
499 | DOWNCAST2 ( IndexBinaryIDMap2, IndexIDMap2TemplateT_faiss__IndexBinary_t )
500 | DOWNCAST ( IndexBinaryIVF )
501 | DOWNCAST ( IndexBinaryFlat )
502 | DOWNCAST ( IndexBinaryFromFloat )
503 | DOWNCAST ( IndexBinaryHNSW )
504 | #ifdef GPU_WRAPPER
505 | DOWNCAST_GPU ( GpuIndexBinaryFlat )
506 | #endif
507 | // default for non-recognized classes
508 | DOWNCAST ( IndexBinary )
509 | if ($1 == NULL)
510 | {
511 | #ifdef SWIGPYTHON
512 | $result = SWIG_Py_Void();
513 | #endif
514 | #ifdef SWIGJAVA
515 | $result = 0;
516 | #endif
517 | // Lua does not need a push for nil
518 | } else {
519 | assert(false); // non-NULL pointer that matched no link: should be unreachable for a valid object
520 | }
521 | #ifdef SWIGLUA
522 | SWIG_arg++;
523 | #endif
524 | }
525 |
526 | %typemap(out) faiss::VectorTransform * { // chain of DOWNCAST links, base class last
527 | DOWNCAST (RemapDimensionsTransform)
528 | DOWNCAST (OPQMatrix)
529 | DOWNCAST (PCAMatrix)
530 | DOWNCAST (RandomRotationMatrix)
531 | DOWNCAST (LinearTransform)
532 | DOWNCAST (NormalizationTransform)
533 | DOWNCAST (CenteringTransform)
534 | DOWNCAST (VectorTransform)
535 | {
536 | assert(false); // NOTE(review): no NULL branch here, so a NULL return appears to land on this assert - confirm
537 | }
538 | #ifdef SWIGLUA
539 | SWIG_arg++;
540 | #endif
541 | }
542 |
543 | %typemap(out) faiss::InvertedLists * { // chain of DOWNCAST links, base class last
544 | DOWNCAST (ArrayInvertedLists)
545 | DOWNCAST (OnDiskInvertedLists)
546 | DOWNCAST (VStackInvertedLists)
547 | DOWNCAST (HStackInvertedLists)
548 | DOWNCAST (MaskedInvertedLists)
549 | DOWNCAST (InvertedLists)
550 | {
551 | assert(false); // NOTE(review): no NULL branch here, so a NULL return appears to land on this assert - confirm
552 | }
553 | #ifdef SWIGLUA
554 | SWIG_arg++;
555 | #endif
556 | }
557 |
558 | // Identity functions: they exist just to downcast pointers that come from elsewhere (eg. direct
559 | // access to object fields) by passing them through the %typemap(out) rules for their return type.
560 | %inline %{
561 | faiss::Index * downcast_index (faiss::Index *index)
562 | {
563 | return index;
564 | }
565 | faiss::VectorTransform * downcast_VectorTransform (faiss::VectorTransform *vt)
566 | {
567 | return vt;
568 | }
569 | faiss::IndexBinary * downcast_IndexBinary (faiss::IndexBinary *index)
570 | {
571 | return index;
572 | }
573 | faiss::InvertedLists * downcast_InvertedLists (faiss::InvertedLists *il)
574 | {
575 | return il;
576 | }
577 | %}
578 |
579 |
580 | %include "index_io.h"
581 | // %newobject: the wrapper takes ownership of the object these functions return
582 | %newobject index_factory;
583 | %newobject index_binary_factory;
584 |
585 | %include "AutoTune.h"
586 |
587 |
588 | #ifdef GPU_WRAPPER
589 |
590 | %newobject index_gpu_to_cpu;
591 | %newobject index_cpu_to_gpu;
592 | %newobject index_cpu_to_gpu_multiple;
593 |
594 | %include "gpu/GpuAutoTune.h"
595 |
596 | #endif
597 |
598 | // Python-specific: do not release GIL any more, as functions below
599 | // use the Python/C API
600 | #ifdef SWIGPYTHON
601 | %exception;
602 | #endif
603 |
604 |
605 |
606 |
607 |
608 | /*******************************************************************
609 | * Python specific: numpy array <-> C++ pointer interface
610 | *******************************************************************/
611 |
612 | #ifdef SWIGPYTHON
613 |
614 | %{
615 | // Wrap the data buffer of a numpy array in a SWIG pointer object.
616 | // Sets ValueError and returns NULL when the argument is not an array,
617 | // is not C-contiguous, or has an element type that is not handled below.
618 | PyObject *swig_ptr (PyObject *a)
619 | {
620 |     if (!PyArray_Check(a)) {
621 |         PyErr_SetString(PyExc_ValueError, "input not a numpy array");
622 |         return NULL;
623 |     }
624 |     PyArrayObject *arr = (PyArrayObject *)a;
625 |     if (!PyArray_ISCONTIGUOUS(arr)) {
626 |         PyErr_SetString(PyExc_ValueError, "array is not C-contiguous");
627 |         return NULL;
628 |     }
629 |     void *buf = PyArray_DATA(arr);
630 |     // the SWIG pointer type is selected from the element type of the array
631 |     switch (PyArray_TYPE(arr)) {
632 |     case NPY_FLOAT32:
633 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_float, 0);
634 |     case NPY_FLOAT64:
635 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_double, 0);
636 |     case NPY_INT32:
637 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_int, 0);
638 |     case NPY_UINT8:
639 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_unsigned_char, 0);
640 |     case NPY_INT8:
641 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_char, 0);
642 |     case NPY_UINT64:
643 | #ifdef SWIGWORDSIZE64
644 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_unsigned_long, 0);
645 | #else
646 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_unsigned_long_long, 0);
647 | #endif
648 |     case NPY_INT64:
649 | #ifdef SWIGWORDSIZE64
650 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_long, 0);
651 | #else
652 |         return SWIG_NewPointerObj(buf, SWIGTYPE_p_long_long, 0);
653 | #endif
654 |     default:
655 |         break;
656 |     }
657 |     PyErr_SetString(PyExc_ValueError, "did not recognize array type");
658 |     return NULL;
659 | }
660 |
661 |
662 | // Interrupt callback that asks the Python runtime whether a signal is pending.
663 | struct PythonInterruptCallback: faiss::InterruptCallback {
664 |
665 |     // true when PyErr_CheckSignals() reports -1
666 |     bool want_interrupt () override {
667 |         // hold the GIL for the duration of the Python C API call
668 |         PyGILState_STATE gil = PyGILState_Ensure();
669 |         const int rc = PyErr_CheckSignals();
670 |         PyGILState_Release(gil);
671 |         return rc == -1;
672 |     }
673 |
674 | };
675 |
676 |
677 |
678 | %}
679 |
680 |
681 | %init %{
682 | /* needed, else crash at runtime */
683 | import_array();
684 | // install the Python signal-checking callback as faiss::InterruptCallback::instance
685 | faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
686 |
687 | %}
688 |
689 | // return a pointer usable as input for functions that expect pointers
690 | PyObject *swig_ptr (PyObject *a);
691 |
692 | %define REV_SWIG_PTR(ctype, numpytype)
693 | // defines and wraps one rev_swig_ptr overload: a 1-D numpy array of the given length built on the src buffer
694 | %{
695 | PyObject * rev_swig_ptr(ctype *src, npy_intp size) {
696 | return PyArray_SimpleNewFromData(1, &size, numpytype, src);
697 | }
698 | %}
699 | // the declaration shown to SWIG uses size_t where the C definition above uses npy_intp
700 | PyObject * rev_swig_ptr(ctype *src, size_t size);
701 |
702 | %enddef
703 |
704 | REV_SWIG_PTR(float, NPY_FLOAT32);
705 | REV_SWIG_PTR(int, NPY_INT32);
706 | REV_SWIG_PTR(unsigned char, NPY_UINT8);
707 | REV_SWIG_PTR(int64_t, NPY_INT64);
708 | REV_SWIG_PTR(uint64_t, NPY_UINT64);
709 |
710 | #endif
711 |
712 |
713 |
714 | /*******************************************************************
715 | * Lua specific: Torch tensor <-> C++ pointer interface
716 | *******************************************************************/
717 |
718 | #ifdef SWIGLUA
719 |
720 |
721 | // provide a XXX_ptr function to convert Lua XXXTensor -> C++ XXX*
722 |
723 | %define TYPE_CONVERSION(ctype, tensortype)
724 |
725 | // typemap for the *_ptr_from_cdata function
726 | %typemap(in) ctype** {
727 | if(lua_type(L, $input) != 10) { // NOTE(review): 10 is presumably the LuaJIT cdata type tag - confirm
728 | fprintf(stderr, "not cdata input\n");
729 | SWIG_fail;
730 | }
731 | $1 = (ctype**)lua_topointer(L, $input);
732 | }
733 |
734 | // the C helper returns the pointer stored at x, advanced by ofs elements
735 | // SWIG and C declaration for the *_ptr_from_cdata function
736 | %{
737 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs) {
738 | return *x + ofs;
739 | }
740 | %}
741 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs);
742 |
743 | // the *_ptr function
744 | %luacode {
745 |
746 | function swigfaiss. ctype ## _ptr(tensor)
747 | assert(tensor:type() == "torch." .. # tensortype, "need a " .. # tensortype)
748 | assert(tensor:isContiguous(), "requires contiguous tensor")
749 | return swigfaiss. ctype ## _ptr_from_cdata(
750 | tensor:storage():data(),
751 | tensor:storageOffset() - 1)
752 | end
753 |
754 | }
755 |
756 | %enddef
757 |
758 | TYPE_CONVERSION (int, IntTensor)
759 | TYPE_CONVERSION (float, FloatTensor)
760 | TYPE_CONVERSION (long, LongTensor)
761 | TYPE_CONVERSION (uint64_t, LongTensor)
762 | TYPE_CONVERSION (uint8_t, ByteTensor)
763 |
764 | #endif
765 |
766 | /*******************************************************************
767 | * How should the template objects appear in the scripting language?
768 | *******************************************************************/
769 |
770 | // answer: the same as the C++ typedefs, but we still have to redefine them
771 | // NOTE(review): value types follow the typedef names; confirm the int64_t id type against the heap header
772 | %template() faiss::CMin<float, int64_t>;
773 | %template() faiss::CMin<int, int64_t>;
774 | %template() faiss::CMax<float, int64_t>;
775 | %template() faiss::CMax<int, int64_t>;
776 |
777 | %template(float_minheap_array_t) faiss::HeapArray<faiss::CMin<float, int64_t> >;
778 | %template(int_minheap_array_t) faiss::HeapArray<faiss::CMin<int, int64_t> >;
779 |
780 | %template(float_maxheap_array_t) faiss::HeapArray<faiss::CMax<float, int64_t> >;
781 | %template(int_maxheap_array_t) faiss::HeapArray<faiss::CMax<int, int64_t> >;
782 |
783 |
784 | /*******************************************************************
785 | * Expose a few basic functions
786 | *******************************************************************/
787 |
788 |
789 | void omp_set_num_threads (int num_threads); // OpenMP runtime function, declared here so SWIG wraps it
790 | int omp_get_max_threads (); // OpenMP runtime function, declared here so SWIG wraps it
791 | void *memcpy(void *dest, const void *src, size_t n); // C library memcpy, exposed for raw buffer copies
792 |
793 |
794 | /*******************************************************************
795 | * For Faiss/Pytorch interop via pointers encoded as longs
796 | *******************************************************************/
797 |
798 | %inline %{
799 | // Reinterpret an address carried in a long as a float pointer.
800 | float * cast_integer_to_float_ptr (long x) {
801 |     return reinterpret_cast<float *>(x);
802 | }
803 | // Reinterpret an address carried in a long as a long pointer.
804 | long * cast_integer_to_long_ptr (long x) {
805 |     return reinterpret_cast<long *>(x);
806 | }
807 | // Reinterpret an address carried in a long as an int pointer.
808 | int * cast_integer_to_int_ptr (long x) {
809 |     return reinterpret_cast<int *>(x);
810 | }
811 | %}
812 |
813 |
814 |
815 | /*******************************************************************
816 | * Range search interface
817 | *******************************************************************/
818 |
819 | %ignore faiss::BufferList::Buffer;
820 | %ignore faiss::RangeSearchPartialResult::QueryResult;
821 | %ignore faiss::IDSelectorBatch::set;
822 | %ignore faiss::IDSelectorBatch::bloom;
823 | // the interrupt callback instance and its lock are not exposed to the target language
824 | %ignore faiss::InterruptCallback::instance;
825 | %ignore faiss::InterruptCallback::lock;
826 | %include "AuxIndexStructures.h"
827 |
828 | %{
829 | // may be useful for lua code launched in background from shell
830 | // installs SIG_IGN for SIGTTIN (POSIX: sent to a background process that reads from the terminal)
831 | #include <signal.h>
832 | void ignore_SIGTTIN() {
833 |     signal(SIGTTIN, SIG_IGN);
834 | }
835 | %}
836 |
837 | void ignore_SIGTTIN();
838 |
839 |
840 | %inline %{
841 |
842 | // numpy misses a hash table implementation, hence this class. It
843 | // represents not found values as -1 like in the Index implementation
844 |
845 | struct MapLong2Long {
846 |     std::unordered_map<int64_t, int64_t> map;
847 |
848 |     // insert the n pairs keys[i] -> vals[i]; an existing key is overwritten
849 |     void add(size_t n, const int64_t *keys, const int64_t *vals) {
850 |         map.reserve(map.size() + n);
851 |         for (size_t i = 0; i < n; i++) {
852 |             map[keys[i]] = vals[i];
853 |         }
854 |     }
855 |
856 |     // value stored for key, or -1 when the key is absent
857 |     long search(int64_t key) {
858 |         auto it = map.find(key); // one lookup instead of count() followed by operator[]
859 |         return it == map.end() ? -1 : it->second;
860 |     }
861 |
862 |     // vals[i] = search(keys[i]) for every i in [0, n)
863 |     void search_multiple(size_t n, int64_t *keys, int64_t * vals) {
864 |         for (size_t i = 0; i < n; i++) {
865 |             vals[i] = search(keys[i]);
866 |         }
867 |     }
868 | };
869 |
870 | %}
871 |
872 | %inline %{
873 | void wait() {
874 |     // in gdb, use return to get out of this function
875 |     for (volatile int spin = 0; spin == 0; ) {} // volatile read keeps the compiler from removing the busy loop
876 | }
877 | %}
878 |
879 | // End of file...
880 |
--------------------------------------------------------------------------------