├── .gitmodules
├── .gitignore
├── cpu
└── pom.xml
├── gpu
└── pom.xml
├── Dockerfile
├── DockerfileGpu
├── pom.xml
├── common
├── src
│ └── main
│ │ └── java
│ │ └── com
│ │ └── gameofdimension
│ │ └── faiss
│ │ └── utils
│ │ ├── JniFaissInitializer.java
│ │ └── NativeUtils.java
└── pom.xml
├── README.md
├── jni
├── Makefile
├── MakefileGpu
└── swigfaiss.swig
├── cpu-demo
├── src
│ └── main
│ │ └── java
│ │ └── com
│ │ └── gameofdimension
│ │ └── faiss
│ │ ├── tutorial
│ │ ├── OneFlat.java
│ │ ├── TwoIVFFlat.java
│ │ └── ThreeIVFPQ.java
│ │ └── utils
│ │ └── IndexHelper.java
└── pom.xml
└── gpu-demo
├── src
└── main
│ └── java
│ └── com
│ └── gameofdimension
│ └── faiss
│ ├── tutorial
│ ├── GpuOneFlat.java
│ ├── FiveMultipleGPUs.java
│ └── FourGPU.java
│ └── utils
│ └── IndexHelper.java
└── pom.xml
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "faiss"]
2 | path = faiss
3 | url = https://github.com/facebookresearch/faiss.git
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.ipr
3 | *.iws
4 | *.iml
5 | target/
6 | src/main/java/com/gameofdimension/faiss/swig/
7 | dependency-reduced-pom.xml
8 |
--------------------------------------------------------------------------------
/cpu/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 |
10 | 4.0.0
11 |
12 | cpu
13 |
14 |
15 |
--------------------------------------------------------------------------------
/gpu/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 |
10 | 4.0.0
11 |
12 | gpu
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM centos:7
2 |
3 | RUN yum install -y lapack lapack-devel
4 |
5 | # Install necessary build tools
6 | RUN yum install -y gcc-c++ make swig3
7 |
8 | RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel maven
9 |
10 | COPY . /opt/jni-faiss
11 |
12 | WORKDIR /opt/jni-faiss/faiss
13 |
14 | RUN ./configure --prefix=/usr --libdir=/usr/lib64 --without-cuda
15 | RUN make -j $(nproc)
16 | RUN make install
17 |
18 |
19 | WORKDIR /opt/jni-faiss/jni
20 |
21 | RUN make
22 |
23 | WORKDIR /opt/jni-faiss
24 |
25 | RUN mvn clean install -pl cpu -am
26 |
27 | RUN mvn clean package -pl cpu-demo -am
28 |
--------------------------------------------------------------------------------
/DockerfileGpu:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:10.0-devel-centos7
2 |
3 | RUN yum install -y lapack lapack-devel
4 |
5 | # Install necessary build tools
6 | RUN yum install -y gcc-c++ make swig3
7 |
8 | RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel maven
9 |
10 | COPY . /opt/jni-faiss
11 |
12 | WORKDIR /opt/jni-faiss/faiss
13 |
14 | RUN ./configure --prefix=/usr --libdir=/usr/lib64 --with-cuda=/usr/local/cuda-10.0
15 | RUN make -j $(nproc)
16 | RUN make install
17 |
18 | WORKDIR /opt/jni-faiss/jni
19 |
20 | RUN make -f MakefileGpu
21 |
22 | WORKDIR /opt/jni-faiss
23 |
24 | RUN mvn clean install -pl gpu -am
25 |
26 | RUN mvn clean package -pl gpu-demo -am
27 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | com.gameofdimension
7 | jni-faiss
8 | 0.0.1
9 |
10 | cpu
11 | gpu
12 | common
13 | cpu-demo
14 | gpu-demo
15 |
16 |
17 | pom
18 | 4.0.0
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/common/src/main/java/com/gameofdimension/faiss/utils/JniFaissInitializer.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.utils;
2 |
3 | import com.google.common.base.Preconditions;
4 |
5 | import java.io.IOException;
6 |
7 | /**
8 | * @author yzq, yzq@leyantech.com
9 | * @date 2020-01-28.
10 | */
11 | public class JniFaissInitializer {
12 |
13 | private static volatile boolean initialized = false;
14 |
15 | static {
16 | try {
17 | NativeUtils.loadLibraryFromJar("/_swigfaiss.so");
18 | initialized = true;
19 | } catch (IOException e) {
20 | Preconditions.checkArgument(false);
21 | }
22 | }
23 |
24 | public static boolean initialized() {
25 | return initialized;
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # jni-faiss
2 | > Linux only, so far.
3 |
4 | ## cpu
5 |
6 | - git clone https://github.com/gameofdimension/jni-faiss.git && cd jni-faiss && git submodule update --init
7 |
8 | - docker build -t jni-faiss .
9 |
10 | - docker run -it jni-faiss java -cp cpu-demo/target/cpu-demo-0.0.1.jar com.gameofdimension.faiss.tutorial.OneFlat
11 |
12 | ## gpu
13 |
14 | - git clone https://github.com/gameofdimension/jni-faiss.git && cd jni-faiss && git submodule update --init
15 |
16 | - docker build -t jni-faiss-gpu -f DockerfileGpu .
17 |
18 | - docker run --gpus 1 -it jni-faiss-gpu java -Xmx8g -cp gpu-demo/target/gpu-demo-0.0.1.jar com.gameofdimension.faiss.tutorial.GpuOneFlat
19 |
20 | ## reference
21 |
22 | - https://github.com/adamheinrich/native-utils
23 |
24 | - https://github.com/thenetcircle/faiss4j
--------------------------------------------------------------------------------
/common/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 |
10 | 4.0.0
11 |
12 | common
13 |
14 |
15 |
16 | com.google.guava
17 | guava
18 | 21.0
19 |
20 |
21 |
22 |
23 |
24 | no-uber
25 |
26 | true
27 |
28 |
29 |
30 |
31 | maven-compiler-plugin
32 | 3.5.1
33 |
34 | 1.8
35 | 1.8
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/jni/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | -include ../faiss/makefile.inc
7 |
8 | ifneq ($(strip $(NVCC)),)
9 | SWIGFLAGS += -DGPU_WRAPPER
10 | endif
11 |
12 | JAVACFLAGS = -I /usr/lib/jvm/java/include/ -I /usr/lib/jvm/java/include/linux/
13 | SWIGJAVASRC = ../cpu/src/main/java/com/gameofdimension/faiss/swig
14 |
15 | all: build
16 |
17 | # Also generates the Java proxy classes into $(SWIGJAVASRC).
18 | swigfaiss.cpp: swigfaiss.swig ../faiss/libfaiss.a
19 | mkdir -p $(SWIGJAVASRC)
20 | $(SWIG) -java -c++ -Doverride= -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
21 |
22 | swigfaiss_avx2.cpp: swigfaiss.swig ../faiss/libfaiss.a
23 | mkdir -p $(SWIGJAVASRC)
24 | $(SWIG) -java -c++ -Doverride= -module swigfaiss_avx2 -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
25 |
26 | %.o: %.cpp
27 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) $(JAVACFLAGS) \
28 | -I../ -c $< -o $@
29 |
30 | # Extension is .so even on OSX.
31 | _%.so: %.o ../faiss/libfaiss.a
32 | $(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS)
33 |
34 | build: _swigfaiss.so
35 | mkdir -p ../cpu/src/main/resources
36 | cp _swigfaiss.so ../cpu/src/main/resources/
37 |
38 | # install: build
39 | # $(PYTHON) setup.py install
40 | #
41 |
42 | clean:
43 | rm -f swigfaiss*.cpp swigfaiss*.o _swigfaiss*.so
44 | rm -rf $(SWIGJAVASRC)
45 |
46 |
47 | .PHONY: all build clean install
48 |
--------------------------------------------------------------------------------
/jni/MakefileGpu:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | -include ../faiss/makefile.inc
7 |
8 | ifneq ($(strip $(NVCC)),)
9 | SWIGFLAGS += -DGPU_WRAPPER
10 | endif
11 |
12 |
13 | JAVACFLAGS = -I /usr/lib/jvm/java/include/ -I /usr/lib/jvm/java/include/linux
14 | SWIGJAVASRC = ../gpu/src/main/java/com/gameofdimension/faiss/swig
15 |
16 | all: build
17 |
18 | # Also generates the Java proxy classes into $(SWIGJAVASRC).
19 | swigfaiss.cpp: swigfaiss.swig ../faiss/libfaiss.a
20 | mkdir -p $(SWIGJAVASRC)
21 | $(SWIG) -java -c++ -Doverride= -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
22 |
23 | swigfaiss_avx2.cpp: swigfaiss.swig ../faiss/libfaiss.a
24 | mkdir -p $(SWIGJAVASRC)
25 | $(SWIG) -java -c++ -Doverride= -module swigfaiss_avx2 -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
26 |
27 | %.o: %.cpp
28 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) $(JAVACFLAGS) \
29 | -I../ -c $< -o $@
30 |
31 | # Extension is .so even on OSX.
32 | _%.so: %.o ../faiss/libfaiss.a
33 | $(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS)
34 |
35 | build: _swigfaiss.so
36 | mkdir -p ../gpu/src/main/resources
37 | cp _swigfaiss.so ../gpu/src/main/resources/
38 |
39 | # install: build
40 | # $(PYTHON) setup.py install
41 | #
42 |
43 | clean:
44 | rm -f swigfaiss*.cpp swigfaiss*.o _swigfaiss*.so
45 | rm -rf $(SWIGJAVASRC)
46 |
47 |
48 | .PHONY: all build clean install
49 |
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/OneFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 |
6 | import com.gameofdimension.faiss.swig.IndexFlatL2;
7 | import com.gameofdimension.faiss.swig.floatArray;
8 | import com.gameofdimension.faiss.swig.longArray;
9 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
10 | import com.google.common.base.Preconditions;
11 |
12 | import java.util.Random;
13 | import java.util.logging.Logger;
14 |
15 | /**
16 | * @author yzq, yzq@leyantech.com
17 | * @date 2020-01-28.
18 | */
19 | public class OneFlat {
20 |
21 | private static Logger LOG = Logger.getLogger(OneFlat.class.getName());
22 | private static int d = 64; // dimension
23 | private static int nb = 100000; // database size
24 | private static int nq = 10000; // nb of queries
25 |
26 | private floatArray xb;
27 | private floatArray xq;
28 |
29 | private Random random;
30 | private IndexFlatL2 index;
31 |
32 | public OneFlat() {
33 | Preconditions.checkArgument(JniFaissInitializer.initialized());
34 | random = new Random(42);
35 | index = new IndexFlatL2(d);
36 | xb = makeRandomFloatArray(nb, d, random);
37 | xq = makeRandomFloatArray(nq, d, random);
38 | index.add(nb, xb.cast());
39 | LOG.info(String.format("is_trained = %s, ntotal = %d",
40 | index.getIs_trained(), index.getNtotal()));
41 | }
42 |
43 | public void sanityCheck() {
44 | int rn = 4;
45 | int qn = 5;
46 | floatArray distances = new floatArray(qn * rn);
47 | longArray indices = new longArray(qn * rn);
48 | index.search(qn, xb.cast(), rn, distances.cast(), indices.cast());
49 |
50 | LOG.info(show(distances, qn, rn));
51 | LOG.info(show(indices, qn, rn));
52 | }
53 |
54 | public void search() {
55 | int rn = 4;
56 | floatArray distances = new floatArray(nq * rn);
57 | longArray indices = new longArray(nq * rn);
58 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
59 |
60 | LOG.info(show(distances, 5, rn));
61 | LOG.info(show(indices, 5, rn));
62 | }
63 |
64 | public static void main(String[] args) {
65 | OneFlat oneFlat = new OneFlat();
66 |
67 | LOG.info("****************************************************");
68 | oneFlat.sanityCheck();
69 | LOG.info("****************************************************");
70 | oneFlat.search();
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/GpuOneFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 |
6 | import com.gameofdimension.faiss.swig.GpuIndexFlatL2;
7 | import com.gameofdimension.faiss.swig.StandardGpuResources;
8 | import com.gameofdimension.faiss.swig.floatArray;
9 | import com.gameofdimension.faiss.swig.longArray;
10 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
11 | import com.google.common.base.Preconditions;
12 |
13 | import java.util.Random;
14 | import java.util.logging.Logger;
15 |
16 | /**
17 | * @author yzq, yzq@leyantech.com
18 | * @date 2020-01-28.
19 | */
20 | public class GpuOneFlat {
21 |
22 | private static Logger LOG = Logger.getLogger(GpuOneFlat.class.getName());
23 | private static int d = 64; // dimension
24 | private static int nb = 100000; // database size
25 | private static int nq = 10000; // nb of queries
26 |
27 | private floatArray xb;
28 | private floatArray xq;
29 |
30 | private Random random;
31 | private StandardGpuResources res;
32 | private GpuIndexFlatL2 index;
33 |
34 | public GpuOneFlat() {
35 | Preconditions.checkArgument(JniFaissInitializer.initialized());
36 | random = new Random(42);
37 | res = new StandardGpuResources();
38 | index = new GpuIndexFlatL2(res, d);
39 | xb = makeRandomFloatArray(nb, d, random);
40 | xq = makeRandomFloatArray(nq, d, random);
41 | index.add(nb, xb.cast());
42 | LOG.info(String.format("is_trained = %s, ntotal = %d",
43 | index.getIs_trained(), index.getNtotal()));
44 | }
45 |
46 | public void sanityCheck() {
47 | int rn = 4;
48 | int qn = 5;
49 | floatArray distances = new floatArray(qn * rn);
50 | longArray indices = new longArray(qn * rn);
51 | index.search(qn, xb.cast(), rn, distances.cast(), indices.cast());
52 |
53 | LOG.info(show(distances, qn, rn));
54 | LOG.info(show(indices, qn, rn));
55 | }
56 |
57 | public void search() {
58 | int rn = 4;
59 | floatArray distances = new floatArray(nq * rn);
60 | longArray indices = new longArray(nq * rn);
61 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
62 |
63 | LOG.info(show(distances, 5, rn));
64 | LOG.info(show(indices, 5, rn));
65 | }
66 |
67 | public static void main(String[] args) {
68 | GpuOneFlat oneFlat = new GpuOneFlat();
69 |
70 | LOG.info("****************************************************");
71 | oneFlat.sanityCheck();
72 | LOG.info("****************************************************");
73 | oneFlat.search();
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/TwoIVFFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 |
6 | import com.gameofdimension.faiss.swig.IndexFlatL2;
7 | import com.gameofdimension.faiss.swig.IndexIVFFlat;
8 | import com.gameofdimension.faiss.swig.MetricType;
9 | import com.gameofdimension.faiss.swig.floatArray;
10 | import com.gameofdimension.faiss.swig.longArray;
11 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
12 | import com.google.common.base.Preconditions;
13 |
14 | import java.util.Random;
15 | import java.util.logging.Logger;
16 |
17 | /**
18 | * @author yzq, yzq@leyantech.com
19 | * @date 2020-01-29.
20 | */
21 | public class TwoIVFFlat {
22 |
23 | private static Logger LOG = Logger.getLogger(TwoIVFFlat.class.getName());
24 | private static int d = 64; // dimension
25 | private static int nb = 100000; // database size
26 | private static int nq = 10000; // nb of queries
27 | private static int nlist = 100;
28 |
29 | private floatArray xb;
30 | private floatArray xq;
31 |
32 | private Random random;
33 | private IndexFlatL2 quantizer;
34 | private IndexIVFFlat index;
35 |
36 | public TwoIVFFlat() {
37 |
38 | Preconditions.checkArgument(JniFaissInitializer.initialized());
39 | random = new Random(42);
40 |
41 | xb = makeRandomFloatArray(nb, d, random);
42 | xq = makeRandomFloatArray(nq, d, random);
43 |
44 | quantizer = new IndexFlatL2(d);
45 | index = new IndexIVFFlat(quantizer, d, nlist, MetricType.METRIC_L2);
46 | Preconditions.checkArgument(!index.getIs_trained());
47 | index.train(nb, xb.cast());
48 | Preconditions.checkArgument(index.getIs_trained());
49 | index.add(nb, xb.cast());
50 | }
51 |
52 |
53 | public void search() {
54 | int rn = 4;
55 | floatArray distances = new floatArray(nq * rn);
56 | longArray indices = new longArray(nq * rn);
57 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
58 |
59 | LOG.info(show(distances, 5, rn));
60 | LOG.info(show(indices, 5, rn));
61 | }
62 |
63 | public void searchAgain() {
64 | int rn = 4;
65 | floatArray distances = new floatArray(nq * rn);
66 | longArray indices = new longArray(nq * rn);
67 | index.setNprobe(10);
68 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
69 |
70 | LOG.info(show(distances, 5, rn));
71 | LOG.info(show(indices, 5, rn));
72 | }
73 |
74 | public static void main(String[] args) {
75 | TwoIVFFlat twoIVFFlat = new TwoIVFFlat();
76 |
77 | LOG.info("****************************************************");
78 | twoIVFFlat.search();
79 | LOG.info("****************************************************");
80 | twoIVFFlat.searchAgain();
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/ThreeIVFPQ.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 |
6 | import com.gameofdimension.faiss.swig.IndexFlatL2;
7 | import com.gameofdimension.faiss.swig.IndexIVFPQ;
8 | import com.gameofdimension.faiss.swig.floatArray;
9 | import com.gameofdimension.faiss.swig.longArray;
10 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
11 | import com.google.common.base.Preconditions;
12 |
13 | import java.util.Random;
14 | import java.util.logging.Logger;
15 |
16 | /**
17 | * @author yzq, yzq@leyantech.com
18 | * @date 2020-01-31.
19 | */
20 | public class ThreeIVFPQ {
21 |
22 | private static Logger LOG = Logger.getLogger(ThreeIVFPQ.class.getName());
23 | private static int d = 64; // dimension
24 | private static int nb = 100000; // database size
25 | private static int nq = 10000; // nb of queries
26 | private static int nlist = 100;
27 | private static int m = 8;
28 |
29 | private floatArray xb;
30 | private floatArray xq;
31 |
32 | private Random random;
33 | private IndexFlatL2 quantizer;
34 | private IndexIVFPQ index;
35 |
36 | public ThreeIVFPQ() {
37 |
38 | Preconditions.checkArgument(JniFaissInitializer.initialized());
39 | random = new Random(42);
40 |
41 | xb = makeRandomFloatArray(nb, d, random);
42 | xq = makeRandomFloatArray(nq, d, random);
43 |
44 | quantizer = new IndexFlatL2(d);
45 | index = new IndexIVFPQ(quantizer, d, nlist, m, 8);
46 | Preconditions.checkArgument(!index.getIs_trained());
47 | index.train(nb, xb.cast());
48 | Preconditions.checkArgument(index.getIs_trained());
49 | index.add(nb, xb.cast());
50 | }
51 |
52 |
53 | public void sanityCheck() {
54 |
55 | int rn = 4;
56 | int qn = 5;
57 | floatArray distances = new floatArray(qn * rn);
58 | longArray indices = new longArray(qn * rn);
59 | index.setNprobe(10);
60 | index.search(qn, xb.cast(), rn, distances.cast(), indices.cast());
61 |
62 | LOG.info(show(distances, qn, rn));
63 | LOG.info(show(indices, qn, rn));
64 | }
65 |
66 |
67 | public void search() {
68 |
69 | int rn = 4;
70 | floatArray distances = new floatArray(nq * rn);
71 | longArray indices = new longArray(nq * rn);
72 | index.setNprobe(10);
73 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
74 |
75 | LOG.info(show(distances, 5, rn));
76 | LOG.info(show(indices, 5, rn));
77 | }
78 |
79 | public static void main(String[] args) {
80 | ThreeIVFPQ threeIVFPQ = new ThreeIVFPQ();
81 |
82 | LOG.info("****************************************************");
83 | threeIVFPQ.sanityCheck();
84 | LOG.info("****************************************************");
85 | threeIVFPQ.search();
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/FiveMultipleGPUs.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.swig.swigfaiss_gpu.getNumDevices;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
5 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
6 |
7 | import com.gameofdimension.faiss.swig.GpuResourcesVector;
8 | import com.gameofdimension.faiss.swig.Index;
9 | import com.gameofdimension.faiss.swig.IndexFlatL2;
10 | import com.gameofdimension.faiss.swig.IntVector;
11 | import com.gameofdimension.faiss.swig.StandardGpuResources;
12 | import com.gameofdimension.faiss.swig.floatArray;
13 | import com.gameofdimension.faiss.swig.longArray;
14 | import com.gameofdimension.faiss.swig.swigfaiss_gpu;
15 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
16 | import com.google.common.base.Preconditions;
17 |
18 | import java.util.Random;
19 | import java.util.logging.Logger;
20 |
21 | /**
22 | * @author yzq, yzq@leyantech.com
23 | * @date 2020-02-01.
24 | */
25 | public class FiveMultipleGPUs {
26 |
27 | private static Logger LOG = Logger.getLogger(FiveMultipleGPUs.class.getName());
28 | private static int d = 64; // dimension
29 | private static int nb = 100000; // database size
30 | private static int nq = 10000; // nb of queries
31 |
32 | private floatArray xb;
33 | private floatArray xq;
34 |
35 | private Random random;
36 | private Index gpuIndex;
37 |
38 | public FiveMultipleGPUs() {
39 |
40 | Preconditions.checkArgument(JniFaissInitializer.initialized());
41 | int ngpus = getNumDevices();
42 | LOG.info(String.format("Number of GPUs: %d", ngpus));
43 | GpuResourcesVector res = new GpuResourcesVector();
44 | IntVector devs = new IntVector();
45 | for (int i = 0; i < ngpus; i++) {
46 | devs.push_back(i);
47 | res.push_back(new StandardGpuResources());
48 | }
49 |
50 | IndexFlatL2 cpuIndex = new IndexFlatL2(d);
51 | gpuIndex = swigfaiss_gpu.index_cpu_to_gpu_multiple(res, devs, cpuIndex);
52 |
53 | random = new Random(42);
54 | xb = makeRandomFloatArray(nb, d, random);
55 | xq = makeRandomFloatArray(nq, d, random);
56 | LOG.info(String.format("is_trained = %s", gpuIndex.getIs_trained()));
57 | gpuIndex.add(nb, xb.cast()); // add vectors to the index
58 | LOG.info(String.format("ntotal = %d", gpuIndex.getNtotal()));
59 | }
60 |
61 | public void search() {
62 | int rn = 4;
63 | floatArray distances = new floatArray(nq * rn);
64 | longArray indices = new longArray(nq * rn);
65 | gpuIndex.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
66 |
67 | LOG.info(show(distances, 5, rn));
68 | LOG.info(show(indices, 5, rn));
69 | }
70 |
71 | public static void main(String[] args) {
72 | FiveMultipleGPUs fiveMultipleGPUs = new FiveMultipleGPUs();
73 |
74 | LOG.info("****************************************************");
75 | fiveMultipleGPUs.search();
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/FourGPU.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 |
3 | import static com.gameofdimension.faiss.swig.MetricType.METRIC_L2;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
5 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
6 |
7 | import com.gameofdimension.faiss.swig.GpuIndexFlatL2;
8 | import com.gameofdimension.faiss.swig.GpuIndexIVFFlat;
9 | import com.gameofdimension.faiss.swig.StandardGpuResources;
10 | import com.gameofdimension.faiss.swig.floatArray;
11 | import com.gameofdimension.faiss.swig.longArray;
12 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
13 | import com.google.common.base.Preconditions;
14 |
15 | import java.util.Random;
16 | import java.util.logging.Logger;
17 |
18 | /**
19 | * @author yzq, yzq@leyantech.com
20 | * @date 2020-02-01.
21 | */
22 | public class FourGPU {
23 |
24 | private static Logger LOG = Logger.getLogger(FourGPU.class.getName());
25 | private static int d = 64; // dimension
26 | private static int nb = 100000; // database size
27 | private static int nq = 10000; // nb of queries
28 | private static int nlist = 100;
29 |
30 | private floatArray xb;
31 | private floatArray xq;
32 |
33 | private Random random;
34 | private StandardGpuResources res;
35 | private GpuIndexFlatL2 index;
36 | private GpuIndexIVFFlat ivfIndex;
37 |
38 | public FourGPU() {
39 | Preconditions.checkArgument(JniFaissInitializer.initialized());
40 | random = new Random(42);
41 | res = new StandardGpuResources();
42 | index = new GpuIndexFlatL2(res, d);
43 | ivfIndex = new GpuIndexIVFFlat(res, d, nlist, METRIC_L2);
44 | xb = makeRandomFloatArray(nb, d, random);
45 | xq = makeRandomFloatArray(nq, d, random);
46 | ivfIndex.train(nb, xb.cast());
47 | ivfIndex.add(nb, xb.cast());
48 | index.add(nb, xb.cast());
49 | LOG.info(String.format("is_trained = %s, ntotal = %d",
50 | index.getIs_trained(), index.getNtotal()));
51 | }
52 |
53 | public void searchFlat() {
54 | int rn = 4;
55 | floatArray distances = new floatArray(nq * rn);
56 | longArray indices = new longArray(nq * rn);
57 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
58 |
59 | LOG.info(show(distances, 5, rn));
60 | LOG.info(show(indices, 5, rn));
61 | }
62 |
63 | public void searchIvf() {
64 | Preconditions.checkArgument(ivfIndex.getIs_trained());
65 | int rn = 4;
66 | floatArray distances = new floatArray(nq * rn);
67 | longArray indices = new longArray(nq * rn);
68 | ivfIndex.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
69 |
70 | LOG.info(show(distances, 5, rn));
71 | LOG.info(show(indices, 5, rn));
72 | }
73 |
74 | public static void main(String[] args) {
75 | FourGPU fourGPU = new FourGPU();
76 |
77 | LOG.info("****************************************************");
78 | fourGPU.searchFlat();
79 | LOG.info("****************************************************");
80 | fourGPU.searchIvf();
81 | }
82 | }
83 |
--------------------------------------------------------------------------------
/cpu-demo/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 |
10 | 4.0.0
11 |
12 | cpu-demo
13 |
14 |
15 |
16 | com.gameofdimension
17 | common
18 | 0.0.1
19 |
20 |
21 | com.gameofdimension
22 | cpu
23 | 0.0.1
24 |
25 |
26 |
27 |
28 |
29 | uber
30 |
31 | true
32 |
33 |
34 |
35 |
36 | org.apache.maven.plugins
37 | maven-shade-plugin
38 | 2.3
39 |
40 |
41 | package
42 |
43 | shade
44 |
45 |
46 |
47 |
49 |
50 |
51 |
52 |
53 | *.*
54 |
55 | META-INF/*.RSA
56 | META-INF/*.DSA
57 | META-INF/*.SF
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | org.apache.maven.plugins
67 | maven-compiler-plugin
68 |
69 | 1.8
70 | 1.8
71 |
72 |
73 |
74 | org.jacoco
75 | jacoco-maven-plugin
76 |
77 |
78 |
79 |
80 |
81 | no-uber
82 |
83 |
84 |
85 | maven-compiler-plugin
86 | 3.5.1
87 |
88 | 1.8
89 | 1.8
90 |
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
/gpu-demo/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 |
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 |
10 | 4.0.0
11 |
12 | gpu-demo
13 |
14 |
15 |
16 | com.gameofdimension
17 | common
18 | 0.0.1
19 |
20 |
21 | com.gameofdimension
22 | gpu
23 | 0.0.1
24 |
25 |
26 |
27 |
28 |
29 |
30 | uber
31 |
32 | true
33 |
34 |
35 |
36 |
37 | org.apache.maven.plugins
38 | maven-shade-plugin
39 | 2.3
40 |
41 |
42 | package
43 |
44 | shade
45 |
46 |
47 |
48 |
50 |
51 |
52 |
53 |
54 | *.*
55 |
56 | META-INF/*.RSA
57 | META-INF/*.DSA
58 | META-INF/*.SF
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | org.apache.maven.plugins
68 | maven-compiler-plugin
69 |
70 | 1.8
71 | 1.8
72 |
73 |
74 |
75 | org.jacoco
76 | jacoco-maven-plugin
77 |
78 |
79 |
80 |
81 |
82 | no-uber
83 |
84 |
85 |
86 | maven-compiler-plugin
87 | 3.5.1
88 |
89 | 1.8
90 | 1.8
91 |
92 |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/utils/IndexHelper.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.utils;
2 |
3 | // copy from https://raw.githubusercontent.com/thenetcircle/faiss4j/master/src/main/java/com/thenetcircle/services/faiss4j/IndexHelper.java
4 |
5 | import com.gameofdimension.faiss.swig.floatArray;
6 | import com.gameofdimension.faiss.swig.intArray;
7 | import com.gameofdimension.faiss.swig.longArray;
8 |
9 | import java.util.Random;
10 | // import org.slf4j.Logger;
11 | // import org.slf4j.LoggerFactory;
12 |
13 | public class IndexHelper {
14 | // private static final Logger log = LoggerFactory.getLogger(IndexHelper.class);
15 |
16 | public static String show(longArray a, int rows, int cols) {
17 | StringBuilder sb = new StringBuilder();
18 | for (int i = 0; i < rows; i++) {
19 | sb.append(i).append('\t').append('|');
20 | for (int j = 0; j < cols; j++) {
21 | sb.append(String.format("%5d ", a.getitem(i * cols + j)));
22 | }
23 | sb.append("\n");
24 | }
25 | return sb.toString();
26 | }
27 |
28 | public static String show(floatArray a, int rows, int cols) {
29 | StringBuilder sb = new StringBuilder();
30 | for (int i = 0; i < rows; i++) {
31 | sb.append(i).append('\t').append('|');
32 | for (int j = 0; j < cols; j++) {
33 | sb.append(String.format("%7g ", a.getitem(i * cols + j)));
34 | }
35 | sb.append("\n");
36 | }
37 | return sb.toString();
38 | }
39 |
40 | public static floatArray makeFloatArray(float[][] vectors) {
41 | int d = vectors[0].length;
42 | int nb = vectors.length;
43 | floatArray fa = new floatArray(d * nb);
44 | for (int i = 0; i < nb; i++) {
45 | for (int j = 0; j < d; j++) {
46 | fa.setitem(d * i + j, vectors[i][j]);
47 | }
48 | }
49 | return fa;
50 | }
51 |
52 | public static longArray makeLongArray(int[] ints) {
53 | int len = ints.length;
54 | longArray la = new longArray(len);
55 | for (int i = 0; i < len; i++) {
56 | la.setitem(i, ints[i]);
57 | }
58 | return la;
59 | }
60 |
61 | public static long[] toArray(longArray c_array, int length) {
62 | return toArray(c_array, 0, length);
63 | }
64 |
65 | public static long[] toArray(longArray c_array, int start, int length) {
66 | long[] re = new long[length];
67 | for (int i = start; i < length; i++) {
68 | re[i] = c_array.getitem(i);
69 | }
70 | return re;
71 | }
72 |
73 | public static int[] toArray(intArray c_array, int length) {
74 | return toArray(c_array, 0, length);
75 | }
76 |
77 | public static int[] toArray(intArray c_array, int start, int length) {
78 | int[] re = new int[length];
79 | for (int i = start; i < length; i++) {
80 | re[i] = c_array.getitem(i);
81 | }
82 | return re;
83 | }
84 |
85 | public static float[] toArray(floatArray c_array, int length) {
86 | return toArray(c_array, 0, length);
87 | }
88 |
89 | public static float[] toArray(floatArray c_array, int start, int length) {
90 | float[] re = new float[length];
91 | for (int i = start; i < length; i++) {
92 | re[i] = c_array.getitem(i);
93 | }
94 | return re;
95 | }
96 |
97 | public static floatArray makeRandomFloatArray(int size, int dim, Random random) {
98 | floatArray array = new floatArray(size * dim);
99 | for (int i = 0; i < size; i++) {
100 | for (int j = 0; j < dim; j++) {
101 | array.setitem(i * dim + j, random.nextFloat());
102 | }
103 | array.setitem(i * dim, i / 1000.f + array.getitem(i * dim));
104 | }
105 | return array;
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/gpu-demo/src/main/java/com/gameofdimension/faiss/utils/IndexHelper.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.utils;
2 |
3 | // copy from https://raw.githubusercontent.com/thenetcircle/faiss4j/master/src/main/java/com/thenetcircle/services/faiss4j/IndexHelper.java
4 |
5 | import com.gameofdimension.faiss.swig.floatArray;
6 | import com.gameofdimension.faiss.swig.intArray;
7 | import com.gameofdimension.faiss.swig.longArray;
8 |
9 | import java.util.Random;
10 | // import org.slf4j.Logger;
11 | // import org.slf4j.LoggerFactory;
12 |
13 | public class IndexHelper {
14 | // private static final Logger log = LoggerFactory.getLogger(IndexHelper.class);
15 |
16 | public static String show(longArray a, int rows, int cols) {
17 | StringBuilder sb = new StringBuilder();
18 | for (int i = 0; i < rows; i++) {
19 | sb.append(i).append('\t').append('|');
20 | for (int j = 0; j < cols; j++) {
21 | sb.append(String.format("%5d ", a.getitem(i * cols + j)));
22 | }
23 | sb.append("\n");
24 | }
25 | return sb.toString();
26 | }
27 |
28 | public static String show(floatArray a, int rows, int cols) {
29 | StringBuilder sb = new StringBuilder();
30 | for (int i = 0; i < rows; i++) {
31 | sb.append(i).append('\t').append('|');
32 | for (int j = 0; j < cols; j++) {
33 | sb.append(String.format("%7g ", a.getitem(i * cols + j)));
34 | }
35 | sb.append("\n");
36 | }
37 | return sb.toString();
38 | }
39 |
40 | public static floatArray makeFloatArray(float[][] vectors) {
41 | int d = vectors[0].length;
42 | int nb = vectors.length;
43 | floatArray fa = new floatArray(d * nb);
44 | for (int i = 0; i < nb; i++) {
45 | for (int j = 0; j < d; j++) {
46 | fa.setitem(d * i + j, vectors[i][j]);
47 | }
48 | }
49 | return fa;
50 | }
51 |
52 | public static longArray makeLongArray(int[] ints) {
53 | int len = ints.length;
54 | longArray la = new longArray(len);
55 | for (int i = 0; i < len; i++) {
56 | la.setitem(i, ints[i]);
57 | }
58 | return la;
59 | }
60 |
61 | public static long[] toArray(longArray c_array, int length) {
62 | return toArray(c_array, 0, length);
63 | }
64 |
65 | public static long[] toArray(longArray c_array, int start, int length) {
66 | long[] re = new long[length];
67 | for (int i = start; i < length; i++) {
68 | re[i] = c_array.getitem(i);
69 | }
70 | return re;
71 | }
72 |
73 | public static int[] toArray(intArray c_array, int length) {
74 | return toArray(c_array, 0, length);
75 | }
76 |
77 | public static int[] toArray(intArray c_array, int start, int length) {
78 | int[] re = new int[length];
79 | for (int i = start; i < length; i++) {
80 | re[i] = c_array.getitem(i);
81 | }
82 | return re;
83 | }
84 |
85 | public static float[] toArray(floatArray c_array, int length) {
86 | return toArray(c_array, 0, length);
87 | }
88 |
89 | public static float[] toArray(floatArray c_array, int start, int length) {
90 | float[] re = new float[length];
91 | for (int i = start; i < length; i++) {
92 | re[i] = c_array.getitem(i);
93 | }
94 | return re;
95 | }
96 |
97 | public static floatArray makeRandomFloatArray(int size, int dim, Random random) {
98 | floatArray array = new floatArray(size * dim);
99 | for (int i = 0; i < size; i++) {
100 | for (int j = 0; j < dim; j++) {
101 | array.setitem(i * dim + j, random.nextFloat());
102 | }
103 | array.setitem(i * dim, i / 1000.f + array.getitem(i * dim));
104 | }
105 | return array;
106 | }
107 | }
108 |
--------------------------------------------------------------------------------
/common/src/main/java/com/gameofdimension/faiss/utils/NativeUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Class NativeUtils is published under the The MIT License:
3 | *
4 | * Copyright (c) 2012 Adam Heinrich
5 | *
6 | * Permission is hereby granted, free of charge, to any person obtaining a copy
7 | * of this software and associated documentation files (the "Software"), to deal
8 | * in the Software without restriction, including without limitation the rights
9 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | * copies of the Software, and to permit persons to whom the Software is
11 | * furnished to do so, subject to the following conditions:
12 | *
13 | * The above copyright notice and this permission notice shall be included in all
14 | * copies or substantial portions of the Software.
15 | *
16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | * SOFTWARE.
23 | */
24 |
25 | package com.gameofdimension.faiss.utils;
26 |
27 | import java.io.*;
28 | import java.nio.file.FileSystemNotFoundException;
29 | import java.nio.file.FileSystems;
30 | import java.nio.file.Files;
31 | import java.nio.file.ProviderNotFoundException;
32 | import java.nio.file.StandardCopyOption;
33 |
/**
 * A simple library class which helps with loading dynamic libraries stored in the JAR archive.
 * These libraries usually contain implementation of some methods in native code (using JNI - Java
 * Native Interface).
 *
 * @see <a href="http://adamheinrich.com/blog/2012/how-to-load-native-jni-library-from-jar">How
 *     to load native JNI library from JAR</a>
 * @see <a href="https://github.com/adamheinrich/native-utils">adamheinrich/native-utils</a>
 */
public class NativeUtils {

    /**
     * The minimum length a library file name must have. It originally mirrored the prefix
     * restriction of {@link File#createTempFile(String, String)} and is kept so that callers see
     * the same validation.
     */
    private static final int MIN_PREFIX_LENGTH = 3;

    /** Name prefix of the temporary directory that extracted libraries are copied into. */
    public static final String NATIVE_FOLDER_PATH_PREFIX = "nativeutils";

    /**
     * Temporary directory which will contain the extracted libraries. It is created lazily by the
     * first {@link #loadLibraryFromJar(String)} call and registered for deletion on JVM exit.
     */
    private static File temporaryDir;

    /**
     * Private constructor - this class will never be instanced.
     */
    private NativeUtils() {
    }

    /**
     * Loads a native library stored as a resource in the current JAR.
     *
     * <p>The resource is copied into a temporary directory and loaded with
     * {@link System#load(String)}. If the default file system supports the POSIX attribute view,
     * the copy is deleted right after the load attempt; otherwise it is scheduled for deletion on
     * JVM exit. The method takes a String because resource paths are "abstract", not
     * system-dependent.
     *
     * <p>This method is not synchronized: concurrent first calls may each create their own
     * temporary directory.
     *
     * @param path the absolute resource path (beginning with '/'), e.g. {@code /package/File.ext}
     * @throws IllegalArgumentException if the path is null or not absolute, or if the file name is
     *     shorter than three characters
     * @throws FileNotFoundException if the resource could not be found inside the JAR
     * @throws IOException if creating the temporary directory or copying the file fails
     */
    public static void loadLibraryFromJar(String path) throws IOException {

        if (null == path || !path.startsWith("/")) {
            throw new IllegalArgumentException("The path has to be absolute (start with '/').");
        }

        // The file name is the last path segment.
        String[] parts = path.split("/");
        String filename = (parts.length > 1) ? parts[parts.length - 1] : null;

        if (filename == null || filename.length() < MIN_PREFIX_LENGTH) {
            throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
        }

        if (temporaryDir == null) {
            temporaryDir = createTempDirectory(NATIVE_FOLDER_PATH_PREFIX);
            temporaryDir.deleteOnExit();
        }

        File temp = new File(temporaryDir, filename);

        try (InputStream is = NativeUtils.class.getResourceAsStream(path)) {
            // Test for a missing resource explicitly instead of catching the NullPointerException
            // that Files.copy would throw, so unrelated NPEs are not misreported as "not found".
            if (is == null) {
                throw new FileNotFoundException("File " + path + " was not found inside JAR.");
            }
            Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
        } catch (IOException e) {
            // Best-effort removal of a partially written copy before propagating.
            temp.delete();
            throw e;
        }

        try {
            System.load(temp.getAbsolutePath());
        } finally {
            if (isPosixCompliant()) {
                // Assume POSIX compliant file system, can be deleted after loading
                temp.delete();
            } else {
                // Assume non-POSIX, and don't delete until last file descriptor closed
                temp.deleteOnExit();
            }
        }
    }

    /**
     * Returns whether the default file system supports the {@code "posix"} file attribute view.
     *
     * @return {@code false} if it does not, or if the file system cannot be queried
     */
    private static boolean isPosixCompliant() {
        try {
            return FileSystems.getDefault()
                    .supportedFileAttributeViews()
                    .contains("posix");
        } catch (FileSystemNotFoundException
                | ProviderNotFoundException
                | SecurityException e) {
            return false;
        }
    }

    /**
     * Creates a fresh directory under {@code java.io.tmpdir} whose name starts with
     * {@code prefix}.
     *
     * <p>{@link Files#createTempDirectory(String, java.nio.file.attribute.FileAttribute[])} picks
     * an unpredictable unique name (rather than one derived from {@code System.nanoTime()}), which
     * matters because native code is later loaded from this directory.
     *
     * @throws IOException if the directory cannot be created
     */
    private static File createTempDirectory(String prefix) throws IOException {
        return Files.createTempDirectory(prefix).toFile();
    }
}
--------------------------------------------------------------------------------
/jni/swigfaiss.swig:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | // -*- C++ -*-
9 |
10 | // This file describes the C++-scripting language bridge for Lua, Python
11 | // and Java. It contains mainly includes and a few macros. There are
12 | // 4 preprocessor macros of interest:
13 | // SWIGLUA: Lua-specific code; SWIGJAVA: Java-specific code
14 | // SWIGPYTHON: Python-specific code
15 | // GPU_WRAPPER: also compile interfaces for GPU.
16 |
17 | #ifdef SWIGJAVA
18 | %include "arrays_java.i"
19 | //%apply int[] {int *};
20 | //%apply double[] {double *};
21 | //%apply float[] {float *};
22 | //%apply long[] {long *};
23 | %include "carrays.i"
24 | %array_class(int, intArray);
25 | %array_class(float, floatArray);
26 | %array_class(long, longArray);
27 | %array_class(double, doubleArray);
28 | #endif
29 |
30 | #ifdef GPU_WRAPPER
31 | %module swigfaiss_gpu;
32 | #else
33 | %module swigfaiss;
34 | #endif
35 |
36 | // %module swigfaiss;
37 |
38 | %rename(faiss_RangeSearchPartialResult_finalize) faiss::RangeSearchPartialResult::finalize();
39 | %rename(faiss_gpu_GpuIndexIVF_getQuantizer) faiss::gpu::GpuIndexIVF::getQuantizer();
40 | %ignore wait();
41 |
42 | // fbcode SWIG fails on warnings, so make them non-fatal
43 | #pragma SWIG nowarn=321
44 | #pragma SWIG nowarn=403
45 | #pragma SWIG nowarn=325
46 | #pragma SWIG nowarn=389
47 | #pragma SWIG nowarn=341
48 | #pragma SWIG nowarn=512
49 |
50 | %include
51 | typedef int64_t size_t;
52 |
53 | #define __restrict
54 |
55 |
56 | /*******************************************************************
57 | * Copied verbatim to wrapper. Contains the C++-visible includes, and
58 | * the language includes for their respective matrix libraries.
59 | *******************************************************************/
60 |
61 | %{
62 |
63 |
64 | #include
65 | #include
66 |
67 |
68 | #ifdef SWIGLUA
69 |
70 | #include
71 |
72 | extern "C" {
73 |
74 | #include |
75 | #include
76 | #undef THTensor
77 |
78 | }
79 |
80 | #endif
81 |
82 |
83 | #ifdef SWIGPYTHON
84 |
85 | #undef popcount64
86 |
87 | #define SWIG_FILE_WITH_INIT
88 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
89 | #include
90 |
91 | #endif
92 |
93 |
94 | #include
95 | #include
96 | #include
97 | #include
98 | #include
99 | #include
100 | #include
101 | #include
102 | #include
103 | #include
104 | #include
105 | #include
106 | #include
107 | #include
108 | #include
109 | #include
110 | #include
111 | #include
112 | #include
113 |
114 | #include
115 | #include
116 | #include
117 | #include
118 |
119 | #include
120 | #include
121 | #include
122 |
123 | #include
124 | #include
125 | #include
126 | #include
127 | #include
128 | #include
129 | #include
130 | #include
131 |
132 | #include
133 |
134 | #include
135 |
136 | #include
137 | #include
138 | #include
139 |
140 | #include
141 | #include
142 |
143 |
144 | %}
145 |
146 | /********************************************************
147 | * GIL manipulation and exception handling
148 | ********************************************************/
149 |
// Exception translation applied to every wrapped call.
// Python: the call runs with the GIL released (Py_BEGIN_ALLOW_THREADS ... Py_END_ALLOW_THREADS).
// faiss::FaissException becomes a RuntimeError (unless a Python error is already set), and
// std::bad_alloc becomes a MemoryError. In both cases the thread state is restored with
// PyEval_RestoreThread before SWIG_fail.
// Lua: faiss::FaissException is reported as "C++ exception: <what>" and the wrapper jumps to
// its fail label. The trailing backslash on the pushferrstring line splices it with
// `goto fail;`; it is harmless.
// NOTE(review): no SWIGJAVA %exception is defined in this file. A C++ exception thrown by a
// wrapped call is therefore presumably not translated into a Java exception. Confirm before
// relying on error handling from Java.
150 | #ifdef SWIGPYTHON
151 | // %catches(faiss::FaissException);
152 |
153 |
154 | // Python-specific: release GIL by default for all functions
155 | %exception {
156 | Py_BEGIN_ALLOW_THREADS
157 | try {
158 | $action
159 | } catch(faiss::FaissException & e) {
160 | PyEval_RestoreThread(_save);
161 |
162 | if (PyErr_Occurred()) {
163 | // some previous code already set the error type.
164 | } else {
165 | PyErr_SetString(PyExc_RuntimeError, e.what());
166 | }
167 | SWIG_fail;
168 | } catch(std::bad_alloc & ba) {
169 | PyEval_RestoreThread(_save);
170 | PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
171 | SWIG_fail;
172 | }
173 | Py_END_ALLOW_THREADS
174 | }
175 |
176 | #endif
177 |
178 | #ifdef SWIGLUA
179 |
180 | %exception {
181 | try {
182 | $action
183 | } catch(faiss::FaissException & e) {
184 | SWIG_Lua_pushferrstring(L, "C++ exception: %s", e.what()); \
185 | goto fail;
186 | }
187 | }
188 |
189 | #endif
190 |
191 |
192 | /*******************************************************************
193 | * Types of vectors we want to manipulate at the scripting language
194 | * level.
195 | *******************************************************************/
196 |
197 | // simplified interface for vector
// Only the members declared here (constructor, push_back, clear, data, size, at, resize,
// swap) are visible to SWIG, so only they are wrapped by the %template instantiations below.
// NOTE(review): the `template` line has no parameter list although `T` is used. Upstream
// this is presumably `template<typename T>`; confirm against the file on disk.
198 | namespace std {
199 |
200 | template
201 | class vector {
202 | public:
203 | vector();
204 | void push_back(T);
205 | void clear();
206 | T * data();
207 | size_t size();
208 | T at (size_t n) const;
209 | void resize (size_t n);
210 | void swap (vector & other);
211 | };
212 | };
213 |
214 |
215 |
216 | %template(FloatVector) std::vector;
217 | %template(DoubleVector) std::vector;
218 | %template(ByteVector) std::vector;
219 | %template(CharVector) std::vector;
220 | // NOTE(hoss): Using unsigned long instead of uint64_t because OSX defines
221 | // uint64_t as unsigned long long, which SWIG is not aware of.
222 | %template(Uint64Vector) std::vector;
223 | %template(LongVector) std::vector;
224 | %template(IntVector) std::vector;
225 | %template(FloatVectorVector) std::vector >;
226 | %template(ByteVectorVector) std::vector >;
227 | %template(LongVectorVector) std::vector >;
228 | %template(VectorTransformVector) std::vector;
229 | %template(OperatingPointVector) std::vector;
230 | %template(InvertedListsPtrVector) std::vector;
231 | %template(RepeatVector) std::vector;
232 |
233 | #ifdef GPU_WRAPPER
234 | %template(GpuResourcesVector) std::vector;
235 | #endif
236 |
237 | %include
238 |
239 | // produces an error on the Mac
240 | %ignore faiss::hamming;
241 |
242 | /*******************************************************************
243 | * Parse headers
244 | *******************************************************************/
245 |
246 |
247 | %ignore *::cmp;
248 |
249 | %include
250 | %include
251 |
252 | int get_num_gpus();
253 | void gpu_profiler_start();
254 | void gpu_profiler_stop();
255 | void gpu_sync_all_devices();
256 |
257 | #ifdef GPU_WRAPPER
258 |
259 | %{
260 |
261 | #include
262 | #include
263 | #include
264 | #include
265 | #include
266 | #include
267 | #include
268 | #include
269 | #include
270 | #include
271 | #include
272 | #include
273 | #include
274 | #include
275 | #include
276 |
// GPU build: thin wrappers around faiss::gpu::getNumDevices, profilerStart, profilerStop and
// synchronizeAllDevices. They back the plain declarations of get_num_gpus() etc. made for SWIG.
// NOTE(review): the void functions use `return <void expression>;`. This is legal C++, but a
// plain call would say the same thing.
277 | int get_num_gpus()
278 | {
279 | return faiss::gpu::getNumDevices();
280 | }
281 |
282 | void gpu_profiler_start()
283 | {
284 | return faiss::gpu::profilerStart();
285 | }
286 |
287 | void gpu_profiler_stop()
288 | {
289 | return faiss::gpu::profilerStop();
290 | }
291 |
292 | void gpu_sync_all_devices()
293 | {
294 | return faiss::gpu::synchronizeAllDevices();
295 | }
296 |
297 | %}
298 |
299 | // causes weird wrapper bug
300 | %ignore *::getMemoryManager;
301 | %ignore *::getMemoryManagerCurrentDevice;
302 |
303 | %include
304 | %include
305 |
306 | #else
307 |
308 | %{
// CPU-only build: stand-ins with the same signatures as the GPU wrappers, so the SWIG
// declarations of get_num_gpus() etc. still resolve. They report 0 GPUs, and the profiler and
// sync functions do nothing.
309 | int get_num_gpus()
310 | {
311 | return 0;
312 | }
313 |
314 | void gpu_profiler_start()
315 | {
316 | }
317 |
318 | void gpu_profiler_stop()
319 | {
320 | }
321 |
322 | void gpu_sync_all_devices()
323 | {
324 | }
325 | %}
326 |
327 |
328 | #endif
329 |
330 | // order matters because includes are not recursive
331 |
332 | %include
333 | %include
334 | %include
335 |
336 | %include
337 | %include
338 |
339 | %include
340 |
341 | %ignore faiss::ProductQuantizer::get_centroids(size_t,size_t) const;
342 |
343 | %include
344 |
345 | %include
346 | %include
347 | %include
348 | %include
349 | %include
350 | %include
351 | %include
352 | %ignore InvertedListScanner;
353 | %ignore BinaryInvertedListScanner;
354 | %include
355 | // NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the
356 | // non-const one.
357 | %warnfilter(509) extract_index_ivf;
358 | %include
359 | %ignore faiss::ScalarQuantizer::SQDistanceComputer;
360 | %include
361 | %include
362 | %include
363 | %include
364 | %include
365 | %include
366 | %include
367 |
368 | %include
369 | %include
370 |
371 | %ignore faiss::IndexIVFPQ::alloc_type;
372 | %include
373 | %include
374 | %include
375 |
376 | %include
377 | %include
378 | %include
379 | %include
380 | %include
381 |
382 |
383 |
384 | // %ignore faiss::IndexReplicas::at(int) const;
385 |
386 | %include
387 | %template(ThreadedIndexBase) faiss::ThreadedIndex;
388 | %template(ThreadedIndexBaseBinary) faiss::ThreadedIndex;
389 |
390 | %include
391 | %template(IndexShards) faiss::IndexShardsTemplate;
392 | %template(IndexBinaryShards) faiss::IndexShardsTemplate;
393 |
394 | %include
395 | %template(IndexReplicas) faiss::IndexReplicasTemplate;
396 | %template(IndexBinaryReplicas) faiss::IndexReplicasTemplate;
397 |
398 | %include
399 | %template(IndexIDMap) faiss::IndexIDMapTemplate;
400 | %template(IndexBinaryIDMap) faiss::IndexIDMapTemplate;
401 | %template(IndexIDMap2) faiss::IndexIDMap2Template;
402 | %template(IndexBinaryIDMap2) faiss::IndexIDMap2Template;
403 |
404 | #ifdef GPU_WRAPPER
405 |
406 | // quiet SWIG warnings
407 | %ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF;
408 |
409 | %include
410 | %include
411 | %include
412 | %include
413 | %include
414 | %include
415 | %include
416 | %include
417 | %include
418 | %include
419 | %include
420 | %include
421 |
422 | #ifdef SWIGLUA
423 |
424 | /// in Lua, swigfaiss_gpu is known as swigfaiss
425 | %luacode {
426 | local swigfaiss = swigfaiss_gpu
427 | }
428 |
429 | #endif
430 |
431 |
432 | #endif
433 |
434 |
435 |
436 |
437 | /*******************************************************************
438 | * Lua-specific: support async execution of searches in an index
439 | * Python equivalent is just to use Python threads.
440 | *******************************************************************/
441 |
442 |
443 | #ifdef SWIGLUA
444 |
445 | %{
446 |
447 |
// Lua-only helper: runs Index::search(n, x, k, distances, labels) on a background pthread.
// The constructor starts the thread immediately, passing `this` to it; join() waits for it.
// The object and the caller-owned buffers (x, distances, labels) are used from the thread,
// so they must stay alive until join() returns.
// NOTE(review): in this code, is_finished is set to false in the constructor and never set to
// true. pthread_create's return value is not checked. There is no destructor that joins the
// thread, so callers presumably must always call join().
448 | namespace faiss {
449 |
450 | struct AsyncIndexSearchC {
451 | typedef Index::idx_t idx_t;
452 | const Index * index;
453 |
454 | idx_t n;
455 | const float *x;
456 | idx_t k;
457 | float *distances;
458 | idx_t *labels;
459 |
460 | bool is_finished;
461 |
462 | pthread_t thread;
463 |
464 |
465 | AsyncIndexSearchC (const Index *index,
466 | idx_t n, const float *x, idx_t k,
467 | float *distances, idx_t *labels):
468 | index(index), n(n), x(x), k(k), distances(distances),
469 | labels(labels)
470 | {
471 | is_finished = false;
472 | pthread_create (&thread, NULL, &AsyncIndexSearchC::callback,
473 | this);
474 | }
475 |
476 | static void *callback (void *arg)
477 | {
478 | AsyncIndexSearchC *aidx = (AsyncIndexSearchC *)arg;
479 | aidx->do_search();
480 | return NULL;
481 | }
482 |
483 | void do_search ()
484 | {
485 | index->search (n, x, k, distances, labels);
486 | }
487 | void join ()
488 | {
489 | pthread_join (thread, NULL);
490 | }
491 |
492 | };
493 |
494 | }
495 |
496 | %}
497 |
498 | // re-declare only what we need
// SWIG-visible subset of AsyncIndexSearchC: only the constructor, the is_finished flag and
// join() are wrapped. The thread handle and the stored search arguments stay hidden.
499 | namespace faiss {
500 |
501 | struct AsyncIndexSearchC {
502 | typedef Index::idx_t idx_t;
503 | bool is_finished;
504 | AsyncIndexSearchC (const Index *index,
505 | idx_t n, const float *x, idx_t k,
506 | float *distances, idx_t *labels);
507 |
508 |
509 | void join ();
510 | };
511 |
512 | }
513 |
514 |
515 | #endif
516 |
517 |
518 |
519 |
520 | /*******************************************************************
521 | * downcast return of some functions so that the sub-class is used
522 | * instead of the generic upper-class.
523 | *******************************************************************/
524 |
// DOWNCAST(subclass) and DOWNCAST_GPU(subclass) expand to the head of an
// `if (...) { ... } else` chain used by the out-typemaps below. When $1 is a faiss::subclass
// (or faiss::gpu::subclass), the result is produced in that branch; otherwise control falls
// through to the next entry in the chain.
//  - Java: the pointer is cast to the subclass and stored as a raw jlong. NOTE(review): this
//    does not by itself select a Java proxy class; the Java side presumably still wraps it
//    as the declared return type. Confirm.
//  - Lua/Python: a SWIG pointer object of the subclass type is created with $owner.
// DOWNCAST2 (Lua/Python only) lets the SWIG type name differ from the class name.
// NOTE(review): the `dynamic_cast ($1)` lines have no target type between the keyword and
// the parenthesis. Presumably this should be `dynamic_cast<faiss::subclass *>`; confirm
// against the file on disk.
525 | #ifdef SWIGJAVA
526 |
527 | %define DOWNCAST(subclass)
528 | if (dynamic_cast ($1)) {
529 | faiss::subclass *instance_ptr = (faiss::subclass *)$1;
530 | $result = (jlong)instance_ptr;
531 | } else
532 | %enddef
533 |
534 | %define DOWNCAST_GPU(subclass)
535 | if (dynamic_cast ($1)) {
536 | faiss::gpu::subclass *instance_ptr = (faiss::gpu::subclass *)$1;
537 | $result = (jlong)instance_ptr;
538 | } else
539 | %enddef
540 |
541 | #endif
542 |
543 | #ifdef SWIGLUA
544 |
545 | %define DOWNCAST(subclass)
546 | if (dynamic_cast ($1)) {
547 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## subclass, $owner);
548 | } else
549 | %enddef
550 |
551 | %define DOWNCAST2(subclass, longname)
552 | if (dynamic_cast ($1)) {
553 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## longname, $owner);
554 | } else
555 | %enddef
556 |
557 | %define DOWNCAST_GPU(subclass)
558 | if (dynamic_cast ($1)) {
559 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__gpu__ ## subclass, $owner);
560 | } else
561 | %enddef
562 |
563 | #endif
564 |
565 |
566 | #ifdef SWIGPYTHON
567 |
568 | %define DOWNCAST(subclass)
569 | if (dynamic_cast ($1)) {
570 | $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner);
571 | } else
572 | %enddef
573 |
574 | %define DOWNCAST2(subclass, longname)
575 | if (dynamic_cast ($1)) {
576 | $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner);
577 | } else
578 | %enddef
579 |
580 | %define DOWNCAST_GPU(subclass)
581 | if (dynamic_cast ($1)) {
582 | $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__gpu__ ## subclass,$owner);
583 | } else
584 | %enddef
585 |
586 | #endif
587 |
588 | %newobject read_index;
589 | %newobject read_index_binary;
590 | %newobject read_VectorTransform;
591 | %newobject read_ProductQuantizer;
592 | %newobject clone_index;
593 | %newobject clone_VectorTransform;
594 |
595 | // Subclasses should appear before their parent
// Out-typemap for faiss::Index *. It tries the listed classes in order and falls back to
// DOWNCAST(Index), which matches any non-NULL index; the final assert(false) is therefore
// effectively unreachable. A NULL pointer fails every dynamic_cast test and reaches the
// `$1 == NULL` branch: Python returns None, Java returns 0, and Lua pushes nothing.
// NOTE(review): in faiss, IndexIDMap2 derives from IndexIDMap. If so, listing IndexIDMap first
// means IndexIDMap2 objects are reported as IndexIDMap, contrary to the
// "subclasses before parent" rule. Consider swapping the two lines; confirm first.
596 | %typemap(out) faiss::Index * {
597 | DOWNCAST ( IndexIDMap )
598 | DOWNCAST ( IndexIDMap2 )
599 | DOWNCAST ( IndexShards )
600 | DOWNCAST ( IndexReplicas )
601 | DOWNCAST ( IndexIVFPQR )
602 | DOWNCAST ( IndexIVFPQ )
603 | DOWNCAST ( IndexIVFSpectralHash )
604 | DOWNCAST ( IndexIVFScalarQuantizer )
605 | DOWNCAST ( IndexIVFFlatDedup )
606 | DOWNCAST ( IndexIVFFlat )
607 | DOWNCAST ( IndexIVF )
608 | DOWNCAST ( IndexFlat )
609 | DOWNCAST ( IndexPQ )
610 | DOWNCAST ( IndexScalarQuantizer )
611 | DOWNCAST ( IndexLSH )
612 | DOWNCAST ( IndexLattice )
613 | DOWNCAST ( IndexPreTransform )
614 | DOWNCAST ( MultiIndexQuantizer )
615 | DOWNCAST ( IndexHNSWFlat )
616 | DOWNCAST ( IndexHNSWPQ )
617 | DOWNCAST ( IndexHNSWSQ )
618 | DOWNCAST ( IndexHNSW2Level )
619 | DOWNCAST ( Index2Layer )
620 | #ifdef GPU_WRAPPER
621 | DOWNCAST_GPU ( GpuIndexIVFPQ )
622 | DOWNCAST_GPU ( GpuIndexIVFFlat )
623 | DOWNCAST_GPU ( GpuIndexIVFScalarQuantizer )
624 | DOWNCAST_GPU ( GpuIndexFlat )
625 | #endif
626 | // default for non-recognized classes
627 | DOWNCAST ( Index )
628 | if ($1 == NULL)
629 | {
630 | #ifdef SWIGPYTHON
631 | $result = SWIG_Py_Void();
632 | #endif
633 | #ifdef SWIGJAVA
634 | $result = 0;
635 | #endif
636 | // Lua does not need a push for nil
637 | } else {
638 | assert(false);
639 | }
640 | #ifdef SWIGLUA
641 | SWIG_arg++;
642 | #endif
643 | }
644 |
// Out-typemap for faiss::IndexBinary *. It has the same structure as the faiss::Index * typemap:
// the listed classes in order, then the DOWNCAST(IndexBinary) catch-all, then NULL handling.
// NOTE(review): as with IndexIDMap2, IndexBinaryIDMap2 presumably derives from
// IndexBinaryIDMap and is listed after it. Also, IndexBinaryShards (instantiated via
// %template in this file) is not listed, unlike IndexBinaryReplicas. Confirm whether both
// are intentional.
645 | %typemap(out) faiss::IndexBinary * {
646 | DOWNCAST ( IndexBinaryReplicas )
647 | DOWNCAST ( IndexBinaryIDMap )
648 | DOWNCAST ( IndexBinaryIDMap2 )
649 | DOWNCAST ( IndexBinaryIVF )
650 | DOWNCAST ( IndexBinaryFlat )
651 | DOWNCAST ( IndexBinaryFromFloat )
652 | DOWNCAST ( IndexBinaryHNSW )
653 | #ifdef GPU_WRAPPER
654 | DOWNCAST_GPU ( GpuIndexBinaryFlat )
655 | #endif
656 | // default for non-recognized classes
657 | DOWNCAST ( IndexBinary )
658 | if ($1 == NULL)
659 | {
660 | #ifdef SWIGPYTHON
661 | $result = SWIG_Py_Void();
662 | #endif
663 | #ifdef SWIGJAVA
664 | $result = 0;
665 | #endif
666 | // Lua does not need a push for nil
667 | } else {
668 | assert(false);
669 | }
670 | #ifdef SWIGLUA
671 | SWIG_arg++;
672 | #endif
673 | }
674 |
// Out-typemap for faiss::VectorTransform *: the listed transforms in order, then the
// DOWNCAST(VectorTransform) catch-all. The trailing `{ assert(false); }` is the final `else`.
// Unlike the Index / IndexBinary typemaps, there is no NULL branch, so a NULL pointer lands in
// assert(false).
// NOTE(review): with NDEBUG the assert compiles away and no result is set for NULL. Confirm
// whether NULL can be returned through this typemap.
675 | %typemap(out) faiss::VectorTransform * {
676 | DOWNCAST (RemapDimensionsTransform)
677 | DOWNCAST (OPQMatrix)
678 | DOWNCAST (PCAMatrix)
679 | DOWNCAST (RandomRotationMatrix)
680 | DOWNCAST (LinearTransform)
681 | DOWNCAST (NormalizationTransform)
682 | DOWNCAST (CenteringTransform)
683 | DOWNCAST (VectorTransform)
684 | {
685 | assert(false);
686 | }
687 | #ifdef SWIGLUA
688 | SWIG_arg++;
689 | #endif
690 | }
691 |
// Out-typemap for faiss::InvertedLists *: the listed implementations in order, then the
// DOWNCAST(InvertedLists) catch-all. The trailing `{ assert(false); }` is the final `else`.
// As with VectorTransform there is no NULL branch.
// NOTE(review): a NULL pointer reaches assert(false). Confirm NULL cannot be returned here.
692 | %typemap(out) faiss::InvertedLists * {
693 | DOWNCAST (ArrayInvertedLists)
694 | DOWNCAST (OnDiskInvertedLists)
695 | DOWNCAST (VStackInvertedLists)
696 | DOWNCAST (HStackInvertedLists)
697 | DOWNCAST (MaskedInvertedLists)
698 | DOWNCAST (InvertedLists)
699 | {
700 | assert(false);
701 | }
702 | #ifdef SWIGLUA
703 | SWIG_arg++;
704 | #endif
705 | }
706 |
707 | // just to downcast pointers that come from elsewhere (eg. direct
708 | // access to object fields)
// Each function returns its argument unchanged. The useful part is that the return value goes
// through the matching %typemap(out) DOWNCAST chain, which maps it to the most-derived type
// listed there.
709 | %inline %{
710 | faiss::Index * downcast_index (faiss::Index *index)
711 | {
712 | return index;
713 | }
714 | faiss::VectorTransform * downcast_VectorTransform (faiss::VectorTransform *vt)
715 | {
716 | return vt;
717 | }
718 | faiss::IndexBinary * downcast_IndexBinary (faiss::IndexBinary *index)
719 | {
720 | return index;
721 | }
722 | faiss::InvertedLists * downcast_InvertedLists (faiss::InvertedLists *il)
723 | {
724 | return il;
725 | }
726 | %}
727 |
728 | %include
729 | %include
730 | %include
731 |
732 | %newobject index_factory;
733 | %newobject index_binary_factory;
734 |
735 | %include
736 | %include
737 | %include
738 |
739 |
740 | #ifdef GPU_WRAPPER
741 |
742 | %include
743 |
744 | %newobject index_gpu_to_cpu;
745 | %newobject index_cpu_to_gpu;
746 | %newobject index_cpu_to_gpu_multiple;
747 |
748 | %include
749 |
750 | #endif
751 |
752 | // Python-specific: do not release GIL any more, as functions below
753 | // use the Python/C API
754 | #ifdef SWIGPYTHON
755 | %exception;
756 | #endif
757 |
758 |
759 |
760 |
761 |
762 | /*******************************************************************
763 | * Python specific: numpy array <-> C++ pointer interface
764 | *******************************************************************/
765 |
766 | #ifdef SWIGPYTHON
767 |
768 | %{
// Returns a SWIG pointer object that aliases the data buffer of a C-contiguous numpy array,
// typed by dtype:
//   float32 -> float*, float64 -> double*, int32 -> int*, uint8 -> unsigned char*,
//   int8 -> char*, uint64/int64 -> (unsigned) long or long long depending on SWIGWORDSIZE64.
// No copy is made, and the pointer object is created with flags 0 (not owned), so the array
// must outlive every use of the returned pointer.
// On error it sets ValueError and returns NULL: for a non-array input, a non-contiguous array,
// or any other dtype.
769 | PyObject *swig_ptr (PyObject *a)
770 | {
771 | if(!PyArray_Check(a)) {
772 | PyErr_SetString(PyExc_ValueError, "input not a numpy array");
773 | return NULL;
774 | }
775 | PyArrayObject *ao = (PyArrayObject *)a;
776 |
777 | if(!PyArray_ISCONTIGUOUS(ao)) {
778 | PyErr_SetString(PyExc_ValueError, "array is not C-contiguous");
779 | return NULL;
780 | }
781 | void * data = PyArray_DATA(ao);
782 | if(PyArray_TYPE(ao) == NPY_FLOAT32) {
783 | return SWIG_NewPointerObj(data, SWIGTYPE_p_float, 0);
784 | }
785 | if(PyArray_TYPE(ao) == NPY_FLOAT64) {
786 | return SWIG_NewPointerObj(data, SWIGTYPE_p_double, 0);
787 | }
788 | if(PyArray_TYPE(ao) == NPY_INT32) {
789 | return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0);
790 | }
791 | if(PyArray_TYPE(ao) == NPY_UINT8) {
792 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_char, 0);
793 | }
794 | if(PyArray_TYPE(ao) == NPY_INT8) {
795 | return SWIG_NewPointerObj(data, SWIGTYPE_p_char, 0);
796 | }
797 | if(PyArray_TYPE(ao) == NPY_UINT64) {
798 | #ifdef SWIGWORDSIZE64
799 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long, 0);
800 | #else
801 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long_long, 0);
802 | #endif
803 | }
804 | if(PyArray_TYPE(ao) == NPY_INT64) {
805 | #ifdef SWIGWORDSIZE64
806 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long, 0);
807 | #else
808 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long_long, 0);
809 | #endif
810 | }
811 | PyErr_SetString(PyExc_ValueError, "did not recognize array type");
812 | return NULL;
813 | }
814 |
815 |
// Interrupt hook that the module init installs as faiss::InterruptCallback::instance.
// want_interrupt() briefly acquires the GIL, runs PyErr_CheckSignals(), and reports an
// interrupt when that returns -1 (e.g. a pending Ctrl-C / KeyboardInterrupt).
816 | struct PythonInterruptCallback: faiss::InterruptCallback {
817 |
818 | bool want_interrupt () override {
819 | int err;
820 | {
821 | PyGILState_STATE gstate;
822 | gstate = PyGILState_Ensure();
823 | err = PyErr_CheckSignals();
824 | PyGILState_Release(gstate);
825 | }
826 | return err == -1;
827 | }
828 |
829 | };
830 |
831 |
832 | %}
833 |
834 |
835 | %init %{
836 | /* needed, else crash at runtime */
837 | import_array();
838 |
839 | faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
840 |
841 | %}
842 |
843 | // return a pointer usable as input for functions that expect pointers
844 | PyObject *swig_ptr (PyObject *a);
845 |
// REV_SWIG_PTR(ctype, numpytype) defines rev_swig_ptr(src, size), the reverse of swig_ptr.
// It wraps `size` elements at `src` as a 1-D numpy array of `numpytype` using
// PyArray_SimpleNewFromData, i.e. without copying. The array does not own the memory, so
// `src` must outlive it.
// NOTE(review): the C definition takes `npy_intp size` while the SWIG declaration says
// `size_t size`. Presumably this is deliberate so SWIG converts from a Python int; confirm.
846 | %define REV_SWIG_PTR(ctype, numpytype)
847 |
848 | %{
849 | PyObject * rev_swig_ptr(ctype *src, npy_intp size) {
850 | return PyArray_SimpleNewFromData(1, &size, numpytype, src);
851 | }
852 | %}
853 |
854 | PyObject * rev_swig_ptr(ctype *src, size_t size);
855 |
856 | %enddef
857 |
858 | REV_SWIG_PTR(float, NPY_FLOAT32);
859 | REV_SWIG_PTR(int, NPY_INT32);
860 | REV_SWIG_PTR(unsigned char, NPY_UINT8);
861 | REV_SWIG_PTR(int64_t, NPY_INT64);
862 | REV_SWIG_PTR(uint64_t, NPY_UINT64);
863 |
864 | #endif
865 |
866 |
867 |
868 | /*******************************************************************
869 | * Lua specific: Torch tensor <-> C++ pointer interface
870 | *******************************************************************/
871 |
872 | #ifdef SWIGLUA
873 |
874 |
875 | // provide a XXX_ptr function to convert Lua XXXTensor -> C++ XXX*
876 |
// TYPE_CONVERSION(ctype, tensortype) generates, for Lua:
//  - an in-typemap for ctype** that accepts only Lua type tag 10 (presumably LuaJIT cdata;
//    confirm). Any other input prints "not cdata input" to stderr and fails.
//  - ctype_ptr_from_cdata(x, ofs), which returns *x + ofs.
//  - a Lua function swigfaiss.ctype_ptr(tensor) that asserts the tensor type and contiguity,
//    then passes storageOffset() - 1 (Torch storage offsets appear to be 1-based).
877 | %define TYPE_CONVERSION(ctype, tensortype)
878 |
879 | // typemap for the *_ptr_from_cdata function
880 | %typemap(in) ctype** {
881 | if(lua_type(L, $input) != 10) {
882 | fprintf(stderr, "not cdata input\n");
883 | SWIG_fail;
884 | }
885 | $1 = (ctype**)lua_topointer(L, $input);
886 | }
887 |
888 |
889 | // SWIG and C declaration for the *_ptr_from_cdata function
890 | %{
891 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs) {
892 | return *x + ofs;
893 | }
894 | %}
895 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs);
896 |
897 | // the *_ptr function
898 | %luacode {
899 |
900 | function swigfaiss. ctype ## _ptr(tensor)
901 | assert(tensor:type() == "torch." .. # tensortype, "need a " .. # tensortype)
902 | assert(tensor:isContiguous(), "requires contiguous tensor")
903 | return swigfaiss. ctype ## _ptr_from_cdata(
904 | tensor:storage():data(),
905 | tensor:storageOffset() - 1)
906 | end
907 |
908 | }
909 |
910 | %enddef
911 |
912 | TYPE_CONVERSION (int, IntTensor)
913 | TYPE_CONVERSION (float, FloatTensor)
914 | TYPE_CONVERSION (long, LongTensor)
915 | TYPE_CONVERSION (uint64_t, LongTensor)
916 | TYPE_CONVERSION (uint8_t, ByteTensor)
917 |
918 | #endif
919 |
920 | /*******************************************************************
921 | * How should the template objects appear in the scripting language?
922 | *******************************************************************/
923 |
924 | // answer: the same as the C++ typedefs, but we still have to redefine them
925 |
// NOTE(review): the template argument lists below appear to have been lost
// (text inside angle brackets stripped): `faiss::CMin;` names no
// instantiation and `HeapArray >` is not valid syntax. Restore them from the
// faiss sources this file tracks. Presumably CMin/CMax take a value type
// (float or int, per the wrapper names) plus an id type — confirm.
926 | %template() faiss::CMin;
927 | %template() faiss::CMin;
928 | %template() faiss::CMax;
929 | %template() faiss::CMax;
930 |
931 | %template(float_minheap_array_t) faiss::HeapArray >;
932 | %template(int_minheap_array_t) faiss::HeapArray >;
933 |
934 | %template(float_maxheap_array_t) faiss::HeapArray >;
935 | %template(int_maxheap_array_t) faiss::HeapArray >;
936 |
937 |
/*******************************************************************
 * Expose a few basic functions
 *******************************************************************/

// OpenMP thread-count controls.
void omp_set_num_threads(int num_threads);
int omp_get_max_threads();

// Plain C memcpy, exposed as-is to the scripting side.
void *memcpy(void *dest, const void *src, size_t n);
946 |
947 |
/*******************************************************************
 * For Faiss/Pytorch interop via pointers encoded as integers
 *******************************************************************/

%inline %{
// Each helper reinterprets an integer holding a raw memory address (e.g. one
// obtained on the scripting side, per the section title) as a typed pointer
// for use with the wrapped API.
//
// The address is taken as `long long` rather than `long`: `long` is only 32
// bits on LLP64 platforms such as 64-bit Windows, and SWIG's Java module maps
// C `long` to Java `int`, so larger addresses were silently truncated.
// `long long` is at least 64 bits everywhere, and existing integer callers
// keep working because the argument type only widens.

float * cast_integer_to_float_ptr (long long x) {
    return (float*)x;
}

long * cast_integer_to_long_ptr (long long x) {
    return (long*)x;
}

int * cast_integer_to_int_ptr (long long x) {
    return (int*)x;
}

%}
966 |
967 |
968 |
969 | /*******************************************************************
970 | * Range search interface
971 | *******************************************************************/
972 |
// These %ignore directives must precede the %include that declares the
// members, so that SWIG skips them when it parses the header.
973 | %ignore faiss::BufferList::Buffer;
974 | %ignore faiss::RangeSearchPartialResult::QueryResult;
975 | %ignore faiss::IDSelectorBatch::set;
976 | %ignore faiss::IDSelectorBatch::bloom;
977 |
978 | %ignore faiss::InterruptCallback::instance;
979 | %ignore faiss::InterruptCallback::lock;
// NOTE(review): the header argument of this %include is missing (a bare
// `%include` names no file). Presumably it is the faiss header declaring
// BufferList, RangeSearchPartialResult, IDSelectorBatch and
// InterruptCallback — restore it from the faiss version this file tracks.
980 | %include
981 |
%{
// may be useful for lua code launched in background from shell

// The header name had been lost from this #include; <signal.h> declares
// signal(), SIG_IGN and the POSIX job-control signal SIGTTIN used below.
#include <signal.h>

// Set SIGTTIN to be ignored. Its default action stops a background process
// that tries to read from its controlling terminal.
void ignore_SIGTTIN() {
    signal(SIGTTIN, SIG_IGN);
}
%}

void ignore_SIGTTIN();
992 |
993 |
%inline %{

// numpy misses a hash table implementation, hence this class. It
// represents not found values as -1 like in the Index implementation

struct MapLong2Long {
    // key -> value table. The template arguments had been lost here
    // (`std::unordered_map map;` does not compile); int64_t/int64_t matches
    // the key and value types accepted by add().
    std::unordered_map<int64_t, int64_t> map;

    // Insert n (key, value) pairs; a repeated key keeps the last value given.
    void add(size_t n, const int64_t *keys, const int64_t *vals) {
        map.reserve(map.size() + n);
        for (size_t i = 0; i < n; i++) {
            map[keys[i]] = vals[i];
        }
    }

    // Value stored for `key`, or -1 if absent, using a single hash lookup.
    // NOTE(review): the return type stays `long` for interface compatibility;
    // where long is 32 bits (and in SWIG's Java mapping of `long` to `int`),
    // values outside that range are truncated.
    long search(int64_t key) {
        auto it = map.find(key);
        return it == map.end() ? -1 : it->second;
    }

    // vals[i] = value for keys[i], or -1 if absent, for i in [0, n).
    // Writes the full int64_t value instead of going through search()'s
    // `long` return, so it cannot truncate.
    void search_multiple(size_t n, int64_t *keys, int64_t * vals) {
        for (size_t i = 0; i < n; i++) {
            auto it = map.find(keys[i]);
            vals[i] = it == map.end() ? -1 : it->second;
        }
    }
};

%}
1025 |
%inline %{
// Spin forever so that a debugger can be attached. In gdb, use `return` to
// get out of this function (or `set var spin = 1` to let the loop exit).
//
// The flag is volatile because a side-effect-free infinite loop such as
// `for (int i = 0; i == 0; i += 0);` is undefined behavior in C++: the
// forward-progress rule lets the compiler assume it terminates, so an
// optimizing build may delete it. Volatile accesses are observable side
// effects, which keeps the loop well-defined.
//
// NOTE(review): with the Java target, a generated static `wait()` may clash
// with java.lang.Object.wait(); confirm how this build handles that name.
void wait() {
    volatile int spin = 0;
    while (spin == 0) {
    }
}
%}
1032 |
1033 | // End of file...
1034 |
--------------------------------------------------------------------------------
|