├── .gitmodules ├── .gitignore ├── cpu └── pom.xml ├── gpu └── pom.xml ├── Dockerfile ├── DockerfileGpu ├── pom.xml ├── common ├── src │ └── main │ │ └── java │ │ └── com │ │ └── gameofdimension │ │ └── faiss │ │ └── utils │ │ ├── JniFaissInitializer.java │ │ └── NativeUtils.java └── pom.xml ├── README.md ├── jni ├── Makefile ├── MakefileGpu └── swigfaiss.swig ├── cpu-demo ├── src │ └── main │ │ └── java │ │ └── com │ │ └── gameofdimension │ │ └── faiss │ │ ├── tutorial │ │ ├── OneFlat.java │ │ ├── TwoIVFFlat.java │ │ └── ThreeIVFPQ.java │ │ └── utils │ │ └── IndexHelper.java └── pom.xml └── gpu-demo ├── src └── main │ └── java │ └── com │ └── gameofdimension │ └── faiss │ ├── tutorial │ ├── GpuOneFlat.java │ ├── FiveMultipleGPUs.java │ └── FourGPU.java │ └── utils │ └── IndexHelper.java └── pom.xml /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "faiss"] 2 | path = faiss 3 | url = https://github.com/facebookresearch/faiss.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.ipr 3 | *.iws 4 | *.iml 5 | target/ 6 | src/main/java/com/gameofdimension/faiss/swig/ 7 | dependency-reduced-pom.xml 8 | -------------------------------------------------------------------------------- /cpu/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | jni-faiss 7 | com.gameofdimension 8 | 0.0.1 9 | 10 | 4.0.0 11 | 12 | cpu 13 | 14 | 15 | -------------------------------------------------------------------------------- /gpu/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | jni-faiss 7 | com.gameofdimension 8 | 0.0.1 9 | 10 | 4.0.0 11 | 12 | gpu 13 | 14 | 15 | -------------------------------------------------------------------------------- /Dockerfile: 
--------------------------------------------------------------------------------
1 | FROM centos:7
2 | 
3 | RUN yum install -y lapack lapack-devel
4 | 
5 | # Install necessary build tools
6 | RUN yum install -y gcc-c++ make swig3
7 | 
8 | RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel maven
9 | 
10 | COPY . /opt/jni-faiss
11 | 
12 | WORKDIR /opt/jni-faiss/faiss
13 | 
14 | RUN ./configure --prefix=/usr --libdir=/usr/lib64 --without-cuda
15 | RUN make -j $(nproc)
16 | RUN make install
17 | 
18 | 
19 | WORKDIR /opt/jni-faiss/jni
20 | 
21 | RUN make
22 | 
23 | WORKDIR /opt/jni-faiss
24 | 
25 | RUN mvn clean install -pl cpu -am
26 | 
27 | RUN mvn clean package -pl cpu-demo -am
28 | 
--------------------------------------------------------------------------------
/DockerfileGpu:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:10.0-devel-centos7
2 | 
3 | RUN yum install -y lapack lapack-devel
4 | 
5 | # Install necessary build tools
6 | RUN yum install -y gcc-c++ make swig3
7 | 
8 | RUN yum install -y java-1.8.0-openjdk java-1.8.0-openjdk-devel maven
9 | 
10 | COPY . /opt/jni-faiss
11 | 
12 | WORKDIR /opt/jni-faiss/faiss
13 | 
14 | RUN ./configure --prefix=/usr --libdir=/usr/lib64 --with-cuda=/usr/local/cuda-10.0
15 | RUN make -j $(nproc)
16 | RUN make install
17 | 
18 | WORKDIR /opt/jni-faiss/jni
19 | 
20 | RUN make -f MakefileGpu
21 | 
22 | WORKDIR /opt/jni-faiss
23 | 
24 | RUN mvn clean install -pl gpu -am
25 | 
26 | RUN mvn clean package -pl gpu-demo -am
27 | 
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
5 | 
6 | com.gameofdimension
7 | jni-faiss
8 | 0.0.1
9 | 
10 | cpu
11 | gpu
12 | common
13 | cpu-demo
14 | gpu-demo
15 | 
16 | 
17 | pom
18 | 4.0.0
19 | 
20 | 
21 | 
22 | 
--------------------------------------------------------------------------------
/common/src/main/java/com/gameofdimension/faiss/utils/JniFaissInitializer.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.utils;
2 | 
3 | import com.google.common.base.Preconditions;
4 | 
5 | import java.io.IOException;
6 | 
7 | /**
8 |  * Loads the SWIG-generated faiss JNI library ({@code /_swigfaiss.so}) from the classpath
9 |  * once, when this class is initialized.
10 |  *
11 |  * <p>The demos call {@link #initialized()} before touching any {@code faiss.swig} class;
12 |  * that call forces this class, and therefore the library load, to initialize first.
13 |  *
14 |  * @author yzq, yzq@leyantech.com
15 |  * @date 2020-01-28.
16 |  */
17 | public class JniFaissInitializer {
18 | 
19 |   private static volatile boolean initialized = false;
20 | 
21 |   static {
22 |     try {
23 |       NativeUtils.loadLibraryFromJar("/_swigfaiss.so");
24 |       initialized = true;
25 |     } catch (IOException e) {
26 |       // Abort class initialization but keep the IOException as the cause, so the
27 |       // resulting ExceptionInInitializerError explains why the library was not loaded.
28 |       throw new IllegalStateException("failed to load native library /_swigfaiss.so", e);
29 |     }
30 |   }
31 | 
32 |   /**
33 |    * Returns {@code true} once the native library has been loaded. A failed load aborts
34 |    * class initialization instead, so a call that returns at all returns {@code true}.
35 |    */
36 |   public static boolean initialized() {
37 |     return initialized;
38 |   }
39 | }
40 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # jni-faiss
2 | > Linux only so far
3 | 
4 | ## cpu
5 | 
6 | - git clone https://github.com/gameofdimension/jni-faiss.git && cd jni-faiss && git submodule update --init
7 | 
8 | - docker build -t jni-faiss .
9 | 
10 | - docker run -it jni-faiss java -cp cpu-demo/target/cpu-demo-0.0.1.jar com.gameofdimension.faiss.tutorial.OneFlat
11 | 
12 | ## gpu
13 | 
14 | - git clone https://github.com/gameofdimension/jni-faiss.git && cd jni-faiss && git submodule update --init
15 | 
16 | - docker build -t jni-faiss-gpu -f DockerfileGpu .
17 | 
18 | - docker run --gpus 1 -it jni-faiss-gpu java -Xmx8g -cp gpu-demo/target/gpu-demo-0.0.1.jar com.gameofdimension.faiss.tutorial.GpuOneFlat
19 | 
20 | ## reference
21 | 
22 | - https://github.com/adamheinrich/native-utils
23 | 
24 | - https://github.com/thenetcircle/faiss4j
--------------------------------------------------------------------------------
/common/pom.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
5 | 
6 | jni-faiss
7 | com.gameofdimension
8 | 0.0.1
9 | 
10 | 4.0.0
11 | 
12 | common
13 | 
14 | 
15 | 
16 | com.google.guava
17 | guava
18 | 21.0
19 | 
20 | 
21 | 
22 | 
23 | 
24 | no-uber
25 | 
26 | true
27 | 
28 | 
29 | 
30 | 
31 | maven-compiler-plugin
32 | 3.5.1
33 | 
34 | 1.8
35 | 1.8
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | 
43 | 
--------------------------------------------------------------------------------
/jni/Makefile:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | -include ../faiss/makefile.inc
7 | 
8 | ifneq ($(strip $(NVCC)),)
9 | SWIGFLAGS += -DGPU_WRAPPER
10 | endif
11 | 
12 | JAVACFLAGS = -I /usr/lib/jvm/java/include/ -I /usr/lib/jvm/java/include/linux/
13 | SWIGJAVASRC = ../cpu/src/main/java/com/gameofdimension/faiss/swig
14 | 
15 | all: build
16 | 
17 | # Also generates the Java proxy classes into $(SWIGJAVASRC).
18 | swigfaiss.cpp: swigfaiss.swig ../faiss/libfaiss.a
19 | 	mkdir -p $(SWIGJAVASRC)
20 | 	$(SWIG) -java -c++ -Doverride= -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
21 | 
22 | swigfaiss_avx2.cpp: swigfaiss.swig ../faiss/libfaiss.a
23 | 	mkdir -p $(SWIGJAVASRC)
24 | 	$(SWIG) -java -c++ -Doverride= -module swigfaiss_avx2 -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
25 | 
26 | %.o: %.cpp
27 | 	$(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) $(JAVACFLAGS) \
28 | 	-I../ -c $< -o $@
29 | 
30 | # Extension is .so even on OSX.
31 | _%.so: %.o ../faiss/libfaiss.a
32 | 	$(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS)
33 | 
34 | build: _swigfaiss.so
35 | 	mkdir -p ../cpu/src/main/resources
36 | 	cp _swigfaiss.so ../cpu/src/main/resources/
37 | 
38 | # install: build
39 | # $(PYTHON) setup.py install
40 | #
41 | 
42 | clean:
43 | 	rm -f swigfaiss*.cpp swigfaiss*.o _swigfaiss*.so
44 | 	rm -rf $(SWIGJAVASRC)
45 | 
46 | 
47 | .PHONY: all build clean install
48 | 
--------------------------------------------------------------------------------
/jni/MakefileGpu:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | -include ../faiss/makefile.inc
7 | 
8 | ifneq ($(strip $(NVCC)),)
9 | SWIGFLAGS += -DGPU_WRAPPER
10 | endif
11 | 
12 | 
13 | JAVACFLAGS = -I /usr/lib/jvm/java/include/ -I /usr/lib/jvm/java/include/linux
14 | SWIGJAVASRC = ../gpu/src/main/java/com/gameofdimension/faiss/swig
15 | 
16 | all: build
17 | 
18 | # Also generates the Java proxy classes into $(SWIGJAVASRC).
19 | swigfaiss.cpp: swigfaiss.swig ../faiss/libfaiss.a
20 | 	mkdir -p $(SWIGJAVASRC)
21 | 	$(SWIG) -java -c++ -Doverride= -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
22 | 
23 | swigfaiss_avx2.cpp: swigfaiss.swig ../faiss/libfaiss.a
24 | 	mkdir -p $(SWIGJAVASRC)
25 | 	$(SWIG) -java -c++ -Doverride= -module swigfaiss_avx2 -I../ $(SWIGFLAGS) -package com.gameofdimension.faiss.swig -outdir $(SWIGJAVASRC) -o $@ $<
26 | 
27 | %.o: %.cpp
28 | 	$(CXX) $(CPPFLAGS) $(CXXFLAGS) $(CPUFLAGS) $(JAVACFLAGS) \
29 | 	-I../ -c $< -o $@
30 | 
31 | # Extension is .so even on OSX.
32 | _%.so: %.o ../faiss/libfaiss.a
33 | 	$(CXX) $(SHAREDFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS)
34 | 
35 | build: _swigfaiss.so
36 | 	mkdir -p ../gpu/src/main/resources
37 | 	cp _swigfaiss.so ../gpu/src/main/resources/
38 | 
39 | # install: build
40 | # $(PYTHON) setup.py install
41 | #
42 | 
43 | clean:
44 | 	rm -f swigfaiss*.cpp swigfaiss*.o _swigfaiss*.so
45 | 	rm -rf $(SWIGJAVASRC)
46 | 
47 | 
48 | .PHONY: all build clean install
49 | 
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/OneFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 | 
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 | 
6 | import com.gameofdimension.faiss.swig.IndexFlatL2;
7 | import com.gameofdimension.faiss.swig.floatArray;
8 | import com.gameofdimension.faiss.swig.longArray;
9 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
10 | import com.google.common.base.Preconditions;
11 | 
12 | import java.util.Random;
13 | import java.util.logging.Logger;
14 | 
15 | /** Tutorial 1: exact (brute-force) L2 search over random vectors with {@link IndexFlatL2}.
16 |  * @author yzq, yzq@leyantech.com
17 |  * @date 2020-01-28.
18 |  */
19 | public class OneFlat {
20 | 
21 |   private static final Logger LOG = Logger.getLogger(OneFlat.class.getName());
22 |   private static final int d = 64; // dimension
23 |   private static final int nb = 100000; // database size
24 |   private static final int nq = 10000; // nb of queries
25 | 
26 |   private floatArray xb;
27 |   private floatArray xq;
28 | 
29 |   private Random random;
30 |   private IndexFlatL2 index;
31 | 
32 |   public OneFlat() {
33 |     Preconditions.checkArgument(JniFaissInitializer.initialized());
34 |     random = new Random(42); // fixed seed keeps the demo output reproducible
35 |     index = new IndexFlatL2(d);
36 |     xb = makeRandomFloatArray(nb, d, random);
37 |     xq = makeRandomFloatArray(nq, d, random);
38 |     index.add(nb, xb.cast());
39 |     LOG.info(String.format("is_trained = %s, ntotal = %d",
40 |         index.getIs_trained(), index.getNtotal()));
41 |   }
42 | 
43 |   public void sanityCheck() {
44 |     int rn = 4;
45 |     int qn = 5;
46 |     floatArray distances = new floatArray(qn * rn);
47 |     longArray indices = new longArray(qn * rn);
48 |     index.search(qn, xb.cast(), rn, distances.cast(), indices.cast()); // queries = first qn database vectors
49 | 
50 |     LOG.info(show(distances, qn, rn));
51 |     LOG.info(show(indices, qn, rn));
52 |   }
53 | 
54 |   public void search() {
55 |     int rn = 4;
56 |     floatArray distances = new floatArray(nq * rn);
57 |     longArray indices = new longArray(nq * rn);
58 |     index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
59 | 
60 |     LOG.info(show(distances, 5, rn)); // only the first 5 result rows are logged
61 |     LOG.info(show(indices, 5, rn));
62 |   }
63 | 
64 |   public static void main(String[] args) {
65 |     OneFlat oneFlat = new OneFlat();
66 | 
67 |     LOG.info("****************************************************");
68 |     oneFlat.sanityCheck();
69 |     LOG.info("****************************************************");
70 |     oneFlat.search();
71 |   }
72 | }
73 | 
--------------------------------------------------------------------------------
/gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/GpuOneFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 | 
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 | 
6 | import com.gameofdimension.faiss.swig.GpuIndexFlatL2;
7 | import com.gameofdimension.faiss.swig.StandardGpuResources;
8 | import com.gameofdimension.faiss.swig.floatArray;
9 | import com.gameofdimension.faiss.swig.longArray;
10 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
11 | import com.google.common.base.Preconditions;
12 | 
13 | import java.util.Random;
14 | import java.util.logging.Logger;
15 | 
16 | /** GPU variant of tutorial 1: exact L2 search with {@link GpuIndexFlatL2}.
17 |  * @author yzq, yzq@leyantech.com
18 |  * @date 2020-01-28.
19 |  */
20 | public class GpuOneFlat {
21 | 
22 |   private static final Logger LOG = Logger.getLogger(GpuOneFlat.class.getName());
23 |   private static final int d = 64; // dimension
24 |   private static final int nb = 100000; // database size
25 |   private static final int nq = 10000; // nb of queries
26 | 
27 |   private floatArray xb;
28 |   private floatArray xq;
29 | 
30 |   private Random random;
31 |   private StandardGpuResources res; // kept in a field: index presumably refers to it without owning it
32 |   private GpuIndexFlatL2 index;
33 | 
34 |   public GpuOneFlat() {
35 |     Preconditions.checkArgument(JniFaissInitializer.initialized());
36 |     random = new Random(42); // fixed seed keeps the demo output reproducible
37 |     res = new StandardGpuResources();
38 |     index = new GpuIndexFlatL2(res, d);
39 |     xb = makeRandomFloatArray(nb, d, random);
40 |     xq = makeRandomFloatArray(nq, d, random);
41 |     index.add(nb, xb.cast());
42 |     LOG.info(String.format("is_trained = %s, ntotal = %d",
43 |         index.getIs_trained(), index.getNtotal()));
44 |   }
45 | 
46 |   public void sanityCheck() {
47 |     int rn = 4;
48 |     int qn = 5;
49 |     floatArray distances = new floatArray(qn * rn);
50 |     longArray indices = new longArray(qn * rn);
51 |     index.search(qn, xb.cast(), rn, distances.cast(), indices.cast()); // queries = first qn database vectors
52 | 
53 |     LOG.info(show(distances, qn, rn));
54 |     LOG.info(show(indices, qn, rn));
55 |   }
56 | 
57 |   public void search() {
58 |     int rn = 4;
59 |     floatArray distances = new floatArray(nq * rn);
60 |     longArray indices = new longArray(nq * rn);
61 |     index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
62 | 
63 |     LOG.info(show(distances, 5, rn)); // only the first 5 result rows are logged
64 |     LOG.info(show(indices, 5, rn));
65 |   }
66 | 
67 |   public static void main(String[] args) {
68 |     GpuOneFlat oneFlat = new GpuOneFlat();
69 | 
70 |     LOG.info("****************************************************");
71 |     oneFlat.sanityCheck();
72 |     LOG.info("****************************************************");
73 |     oneFlat.search();
74 |   }
75 | }
76 | 
--------------------------------------------------------------------------------
/cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/TwoIVFFlat.java:
--------------------------------------------------------------------------------
1 | package com.gameofdimension.faiss.tutorial;
2 | 
3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray;
4 | import static com.gameofdimension.faiss.utils.IndexHelper.show;
5 | 
6 | import com.gameofdimension.faiss.swig.IndexFlatL2;
7 | import com.gameofdimension.faiss.swig.IndexIVFFlat;
8 | import com.gameofdimension.faiss.swig.MetricType;
9 | import com.gameofdimension.faiss.swig.floatArray;
10 | import com.gameofdimension.faiss.swig.longArray;
11 | import com.gameofdimension.faiss.utils.JniFaissInitializer;
12 | import com.google.common.base.Preconditions;
13 | 
14 | import java.util.Random;
15 | import java.util.logging.Logger;
16 | 
17 | /** Tutorial 2: approximate search with {@link IndexIVFFlat}, first at the default nprobe, then at 10.
18 |  * @author yzq, yzq@leyantech.com
19 |  * @date 2020-01-29.
20 |  */
21 | public class TwoIVFFlat {
22 | 
23 |   private static final Logger LOG = Logger.getLogger(TwoIVFFlat.class.getName());
24 |   private static final int d = 64; // dimension
25 |   private static final int nb = 100000; // database size
26 |   private static final int nq = 10000; // nb of queries
27 |   private static final int nlist = 100; // number of inverted lists (IVF cells)
28 | 
29 |   private floatArray xb;
30 |   private floatArray xq;
31 | 
32 |   private Random random;
33 |   private IndexFlatL2 quantizer; // kept in a field: index presumably refers to it without owning it
34 |   private IndexIVFFlat index;
35 | 
36 |   public TwoIVFFlat() {
37 | 
38 |     Preconditions.checkArgument(JniFaissInitializer.initialized());
39 |     random = new Random(42); // fixed seed keeps the demo output reproducible
40 | 
41 |     xb = makeRandomFloatArray(nb, d, random);
42 |     xq = makeRandomFloatArray(nq, d, random);
43 | 
44 |     quantizer = new IndexFlatL2(d);
45 |     index = new IndexIVFFlat(quantizer, d, nlist, MetricType.METRIC_L2);
46 |     Preconditions.checkArgument(!index.getIs_trained());
47 |     index.train(nb, xb.cast());
48 |     Preconditions.checkArgument(index.getIs_trained());
49 |     index.add(nb, xb.cast());
50 |   }
51 | 
52 | 
53 |   public void search() {
54 |     int rn = 4;
55 |     floatArray distances = new floatArray(nq * rn);
56 |     longArray indices = new longArray(nq * rn);
57 |     index.search(nq, xq.cast(), rn, distances.cast(), indices.cast()); // uses the index's default nprobe
58 | 
59 |     LOG.info(show(distances, 5, rn)); // only the first 5 result rows are logged
60 |     LOG.info(show(indices, 5, rn));
61 |   }
62 | 
63 |   public void searchAgain() {
64 |     int rn = 4;
65 |     floatArray distances = new floatArray(nq * rn);
66 |     longArray indices = new longArray(nq * rn);
67 |     index.setNprobe(10); // visit 10 of the nlist cells per query (persists on the index)
68 |     index.search(nq, xq.cast(), rn, distances.cast(), indices.cast());
69 | 
70 |     LOG.info(show(distances, 5, rn));
71 |     LOG.info(show(indices, 5, rn));
72 |   }
73 | 
74 |   public static void main(String[] args) {
75 |     TwoIVFFlat twoIVFFlat = new TwoIVFFlat();
76 | 
77 |     LOG.info("****************************************************");
78 |     twoIVFFlat.search();
79 |     LOG.info("****************************************************");
80 |     twoIVFFlat.searchAgain();
81 |   }
82 | }
83 | 
-------------------------------------------------------------------------------- /cpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/ThreeIVFPQ.java: -------------------------------------------------------------------------------- 1 | package com.gameofdimension.faiss.tutorial; 2 | 3 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray; 4 | import static com.gameofdimension.faiss.utils.IndexHelper.show; 5 | 6 | import com.gameofdimension.faiss.swig.IndexFlatL2; 7 | import com.gameofdimension.faiss.swig.IndexIVFPQ; 8 | import com.gameofdimension.faiss.swig.floatArray; 9 | import com.gameofdimension.faiss.swig.longArray; 10 | import com.gameofdimension.faiss.utils.JniFaissInitializer; 11 | import com.google.common.base.Preconditions; 12 | 13 | import java.util.Random; 14 | import java.util.logging.Logger; 15 | 16 | /** 17 | * @author yzq, yzq@leyantech.com 18 | * @date 2020-01-31. 19 | */ 20 | public class ThreeIVFPQ { 21 | 22 | private static Logger LOG = Logger.getLogger(ThreeIVFPQ.class.getName()); 23 | private static int d = 64; // dimension 24 | private static int nb = 100000; // database size 25 | private static int nq = 10000; // nb of queries 26 | private static int nlist = 100; 27 | private static int m = 8; 28 | 29 | private floatArray xb; 30 | private floatArray xq; 31 | 32 | private Random random; 33 | private IndexFlatL2 quantizer; 34 | private IndexIVFPQ index; 35 | 36 | public ThreeIVFPQ() { 37 | 38 | Preconditions.checkArgument(JniFaissInitializer.initialized()); 39 | random = new Random(42); 40 | 41 | xb = makeRandomFloatArray(nb, d, random); 42 | xq = makeRandomFloatArray(nq, d, random); 43 | 44 | quantizer = new IndexFlatL2(d); 45 | index = new IndexIVFPQ(quantizer, d, nlist, m, 8); 46 | Preconditions.checkArgument(!index.getIs_trained()); 47 | index.train(nb, xb.cast()); 48 | Preconditions.checkArgument(index.getIs_trained()); 49 | index.add(nb, xb.cast()); 50 | } 51 | 52 | 53 | public void sanityCheck() { 54 | 
55 | int rn = 4; 56 | int qn = 5; 57 | floatArray distances = new floatArray(qn * rn); 58 | longArray indices = new longArray(qn * rn); 59 | index.setNprobe(10); 60 | index.search(qn, xb.cast(), rn, distances.cast(), indices.cast()); 61 | 62 | LOG.info(show(distances, qn, rn)); 63 | LOG.info(show(indices, qn, rn)); 64 | } 65 | 66 | 67 | public void search() { 68 | 69 | int rn = 4; 70 | floatArray distances = new floatArray(nq * rn); 71 | longArray indices = new longArray(nq * rn); 72 | index.setNprobe(10); 73 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast()); 74 | 75 | LOG.info(show(distances, 5, rn)); 76 | LOG.info(show(indices, 5, rn)); 77 | } 78 | 79 | public static void main(String[] args) { 80 | ThreeIVFPQ threeIVFPQ = new ThreeIVFPQ(); 81 | 82 | LOG.info("****************************************************"); 83 | threeIVFPQ.sanityCheck(); 84 | LOG.info("****************************************************"); 85 | threeIVFPQ.search(); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/FiveMultipleGPUs.java: -------------------------------------------------------------------------------- 1 | package com.gameofdimension.faiss.tutorial; 2 | 3 | import static com.gameofdimension.faiss.swig.swigfaiss_gpu.getNumDevices; 4 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray; 5 | import static com.gameofdimension.faiss.utils.IndexHelper.show; 6 | 7 | import com.gameofdimension.faiss.swig.GpuResourcesVector; 8 | import com.gameofdimension.faiss.swig.Index; 9 | import com.gameofdimension.faiss.swig.IndexFlatL2; 10 | import com.gameofdimension.faiss.swig.IntVector; 11 | import com.gameofdimension.faiss.swig.StandardGpuResources; 12 | import com.gameofdimension.faiss.swig.floatArray; 13 | import com.gameofdimension.faiss.swig.longArray; 14 | import com.gameofdimension.faiss.swig.swigfaiss_gpu; 15 | import 
com.gameofdimension.faiss.utils.JniFaissInitializer; 16 | import com.google.common.base.Preconditions; 17 | 18 | import java.util.Random; 19 | import java.util.logging.Logger; 20 | 21 | /** 22 | * @author yzq, yzq@leyantech.com 23 | * @date 2020-02-01. 24 | */ 25 | public class FiveMultipleGPUs { 26 | 27 | private static Logger LOG = Logger.getLogger(FiveMultipleGPUs.class.getName()); 28 | private static int d = 64; // dimension 29 | private static int nb = 100000; // database size 30 | private static int nq = 10000; // nb of queries 31 | 32 | private floatArray xb; 33 | private floatArray xq; 34 | 35 | private Random random; 36 | private Index gpuIndex; 37 | 38 | public FiveMultipleGPUs() { 39 | 40 | Preconditions.checkArgument(JniFaissInitializer.initialized()); 41 | int ngpus = getNumDevices(); 42 | LOG.info(String.format("Number of GPUs: %d", ngpus)); 43 | GpuResourcesVector res = new GpuResourcesVector(); 44 | IntVector devs = new IntVector(); 45 | for (int i = 0; i < ngpus; i++) { 46 | devs.push_back(i); 47 | res.push_back(new StandardGpuResources()); 48 | } 49 | 50 | IndexFlatL2 cpuIndex = new IndexFlatL2(d); 51 | gpuIndex = swigfaiss_gpu.index_cpu_to_gpu_multiple(res, devs, cpuIndex); 52 | 53 | random = new Random(42); 54 | xb = makeRandomFloatArray(nb, d, random); 55 | xq = makeRandomFloatArray(nq, d, random); 56 | LOG.info(String.format("is_trained = %s", gpuIndex.getIs_trained())); 57 | gpuIndex.add(nb, xb.cast()); // add vectors to the index 58 | LOG.info(String.format("ntotal = %d", gpuIndex.getNtotal())); 59 | } 60 | 61 | public void search() { 62 | int rn = 4; 63 | floatArray distances = new floatArray(nq * rn); 64 | longArray indices = new longArray(nq * rn); 65 | gpuIndex.search(nq, xq.cast(), rn, distances.cast(), indices.cast()); 66 | 67 | LOG.info(show(distances, 5, rn)); 68 | LOG.info(show(indices, 5, rn)); 69 | } 70 | 71 | public static void main(String[] args) { 72 | FiveMultipleGPUs fiveMultipleGPUs = new FiveMultipleGPUs(); 73 | 74 | 
LOG.info("****************************************************"); 75 | fiveMultipleGPUs.search(); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /gpu-demo/src/main/java/com/gameofdimension/faiss/tutorial/FourGPU.java: -------------------------------------------------------------------------------- 1 | package com.gameofdimension.faiss.tutorial; 2 | 3 | import static com.gameofdimension.faiss.swig.MetricType.METRIC_L2; 4 | import static com.gameofdimension.faiss.utils.IndexHelper.makeRandomFloatArray; 5 | import static com.gameofdimension.faiss.utils.IndexHelper.show; 6 | 7 | import com.gameofdimension.faiss.swig.GpuIndexFlatL2; 8 | import com.gameofdimension.faiss.swig.GpuIndexIVFFlat; 9 | import com.gameofdimension.faiss.swig.StandardGpuResources; 10 | import com.gameofdimension.faiss.swig.floatArray; 11 | import com.gameofdimension.faiss.swig.longArray; 12 | import com.gameofdimension.faiss.utils.JniFaissInitializer; 13 | import com.google.common.base.Preconditions; 14 | 15 | import java.util.Random; 16 | import java.util.logging.Logger; 17 | 18 | /** 19 | * @author yzq, yzq@leyantech.com 20 | * @date 2020-02-01. 
21 | */ 22 | public class FourGPU { 23 | 24 | private static Logger LOG = Logger.getLogger(FourGPU.class.getName()); 25 | private static int d = 64; // dimension 26 | private static int nb = 100000; // database size 27 | private static int nq = 10000; // nb of queries 28 | private static int nlist = 100; 29 | 30 | private floatArray xb; 31 | private floatArray xq; 32 | 33 | private Random random; 34 | private StandardGpuResources res; 35 | private GpuIndexFlatL2 index; 36 | private GpuIndexIVFFlat ivfIndex; 37 | 38 | public FourGPU() { 39 | Preconditions.checkArgument(JniFaissInitializer.initialized()); 40 | random = new Random(42); 41 | res = new StandardGpuResources(); 42 | index = new GpuIndexFlatL2(res, d); 43 | ivfIndex = new GpuIndexIVFFlat(res, d, nlist, METRIC_L2); 44 | xb = makeRandomFloatArray(nb, d, random); 45 | xq = makeRandomFloatArray(nq, d, random); 46 | ivfIndex.train(nb, xb.cast()); 47 | ivfIndex.add(nb, xb.cast()); 48 | index.add(nb, xb.cast()); 49 | LOG.info(String.format("is_trained = %s, ntotal = %d", 50 | index.getIs_trained(), index.getNtotal())); 51 | } 52 | 53 | public void searchFlat() { 54 | int rn = 4; 55 | floatArray distances = new floatArray(nq * rn); 56 | longArray indices = new longArray(nq * rn); 57 | index.search(nq, xq.cast(), rn, distances.cast(), indices.cast()); 58 | 59 | LOG.info(show(distances, 5, rn)); 60 | LOG.info(show(indices, 5, rn)); 61 | } 62 | 63 | public void searchIvf() { 64 | Preconditions.checkArgument(ivfIndex.getIs_trained()); 65 | int rn = 4; 66 | floatArray distances = new floatArray(nq * rn); 67 | longArray indices = new longArray(nq * rn); 68 | ivfIndex.search(nq, xq.cast(), rn, distances.cast(), indices.cast()); 69 | 70 | LOG.info(show(distances, 5, rn)); 71 | LOG.info(show(indices, 5, rn)); 72 | } 73 | 74 | public static void main(String[] args) { 75 | FourGPU fourGPU = new FourGPU(); 76 | 77 | LOG.info("****************************************************"); 78 | fourGPU.searchFlat(); 79 | 
LOG.info("****************************************************"); 80 | fourGPU.searchIvf(); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /cpu-demo/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | jni-faiss 7 | com.gameofdimension 8 | 0.0.1 9 | 10 | 4.0.0 11 | 12 | cpu-demo 13 | 14 | 15 | 16 | com.gameofdimension 17 | common 18 | 0.0.1 19 | 20 | 21 | com.gameofdimension 22 | cpu 23 | 0.0.1 24 | 25 | 26 | 27 | 28 | 29 | uber 30 | 31 | true 32 | 33 | 34 | 35 | 36 | org.apache.maven.plugins 37 | maven-shade-plugin 38 | 2.3 39 | 40 | 41 | package 42 | 43 | shade 44 | 45 | 46 | 47 | 49 | 50 | 51 | 52 | 53 | *.* 54 | 55 | META-INF/*.RSA 56 | META-INF/*.DSA 57 | META-INF/*.SF 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | org.apache.maven.plugins 67 | maven-compiler-plugin 68 | 69 | 1.8 70 | 1.8 71 | 72 | 73 | 74 | org.jacoco 75 | jacoco-maven-plugin 76 | 77 | 78 | 79 | 80 | 81 | no-uber 82 | 83 | 84 | 85 | maven-compiler-plugin 86 | 3.5.1 87 | 88 | 1.8 89 | 1.8 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /gpu-demo/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | jni-faiss 7 | com.gameofdimension 8 | 0.0.1 9 | 10 | 4.0.0 11 | 12 | gpu-demo 13 | 14 | 15 | 16 | com.gameofdimension 17 | common 18 | 0.0.1 19 | 20 | 21 | com.gameofdimension 22 | gpu 23 | 0.0.1 24 | 25 | 26 | 27 | 28 | 29 | 30 | uber 31 | 32 | true 33 | 34 | 35 | 36 | 37 | org.apache.maven.plugins 38 | maven-shade-plugin 39 | 2.3 40 | 41 | 42 | package 43 | 44 | shade 45 | 46 | 47 | 48 | 50 | 51 | 52 | 53 | 54 | *.* 55 | 56 | META-INF/*.RSA 57 | META-INF/*.DSA 58 | META-INF/*.SF 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | org.apache.maven.plugins 68 | maven-compiler-plugin 69 | 70 | 1.8 71 | 1.8 72 | 73 | 74 | 75 | org.jacoco 76 | jacoco-maven-plugin 77 | 78 | 79 | 80 | 81 | 
82 | no-uber 83 | 84 | 85 | 86 | maven-compiler-plugin 87 | 3.5.1 88 | 89 | 1.8 90 | 1.8 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /cpu-demo/src/main/java/com/gameofdimension/faiss/utils/IndexHelper.java: -------------------------------------------------------------------------------- 1 | package com.gameofdimension.faiss.utils; 2 | 3 | // copy from https://raw.githubusercontent.com/thenetcircle/faiss4j/master/src/main/java/com/thenetcircle/services/faiss4j/IndexHelper.java 4 | 5 | import com.gameofdimension.faiss.swig.floatArray; 6 | import com.gameofdimension.faiss.swig.intArray; 7 | import com.gameofdimension.faiss.swig.longArray; 8 | 9 | import java.util.Random; 10 | // import org.slf4j.Logger; 11 | // import org.slf4j.LoggerFactory; 12 | 13 | public class IndexHelper { 14 | // private static final Logger log = LoggerFactory.getLogger(IndexHelper.class); 15 | 16 | public static String show(longArray a, int rows, int cols) { 17 | StringBuilder sb = new StringBuilder(); 18 | for (int i = 0; i < rows; i++) { 19 | sb.append(i).append('\t').append('|'); 20 | for (int j = 0; j < cols; j++) { 21 | sb.append(String.format("%5d ", a.getitem(i * cols + j))); 22 | } 23 | sb.append("\n"); 24 | } 25 | return sb.toString(); 26 | } 27 | 28 | public static String show(floatArray a, int rows, int cols) { 29 | StringBuilder sb = new StringBuilder(); 30 | for (int i = 0; i < rows; i++) { 31 | sb.append(i).append('\t').append('|'); 32 | for (int j = 0; j < cols; j++) { 33 | sb.append(String.format("%7g ", a.getitem(i * cols + j))); 34 | } 35 | sb.append("\n"); 36 | } 37 | return sb.toString(); 38 | } 39 | 40 | public static floatArray makeFloatArray(float[][] vectors) { 41 | int d = vectors[0].length; 42 | int nb = vectors.length; 43 | floatArray fa = new floatArray(d * nb); 44 | for (int i = 0; i < nb; i++) { 45 | for (int j = 0; j < d; j++) { 46 | fa.setitem(d * i + j, vectors[i][j]); 47 | } 48 | } 49 | 
return fa; 50 | } 51 | 52 | public static longArray makeLongArray(int[] ints) { 53 | int len = ints.length; 54 | longArray la = new longArray(len); 55 | for (int i = 0; i < len; i++) { 56 | la.setitem(i, ints[i]); 57 | } 58 | return la; 59 | } 60 | 61 | public static long[] toArray(longArray c_array, int length) { 62 | return toArray(c_array, 0, length); 63 | } 64 | 65 | public static long[] toArray(longArray c_array, int start, int length) { 66 | long[] re = new long[length]; 67 | for (int i = start; i < length; i++) { 68 | re[i] = c_array.getitem(i); 69 | } 70 | return re; 71 | } 72 | 73 | public static int[] toArray(intArray c_array, int length) { 74 | return toArray(c_array, 0, length); 75 | } 76 | 77 | public static int[] toArray(intArray c_array, int start, int length) { 78 | int[] re = new int[length]; 79 | for (int i = start; i < length; i++) { 80 | re[i] = c_array.getitem(i); 81 | } 82 | return re; 83 | } 84 | 85 | public static float[] toArray(floatArray c_array, int length) { 86 | return toArray(c_array, 0, length); 87 | } 88 | 89 | public static float[] toArray(floatArray c_array, int start, int length) { 90 | float[] re = new float[length]; 91 | for (int i = start; i < length; i++) { 92 | re[i] = c_array.getitem(i); 93 | } 94 | return re; 95 | } 96 | 97 | public static floatArray makeRandomFloatArray(int size, int dim, Random random) { 98 | floatArray array = new floatArray(size * dim); 99 | for (int i = 0; i < size; i++) { 100 | for (int j = 0; j < dim; j++) { 101 | array.setitem(i * dim + j, random.nextFloat()); 102 | } 103 | array.setitem(i * dim, i / 1000.f + array.getitem(i * dim)); 104 | } 105 | return array; 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /gpu-demo/src/main/java/com/gameofdimension/faiss/utils/IndexHelper.java: -------------------------------------------------------------------------------- 1 | package com.gameofdimension.faiss.utils; 2 | 3 | // copy from 
https://raw.githubusercontent.com/thenetcircle/faiss4j/master/src/main/java/com/thenetcircle/services/faiss4j/IndexHelper.java 4 | 5 | import com.gameofdimension.faiss.swig.floatArray; 6 | import com.gameofdimension.faiss.swig.intArray; 7 | import com.gameofdimension.faiss.swig.longArray; 8 | 9 | import java.util.Random; 10 | // import org.slf4j.Logger; 11 | // import org.slf4j.LoggerFactory; 12 | /** Conversions between Java arrays and the SWIG-generated native array wrappers (floatArray, intArray, longArray), plus table formatting for debugging. */ 13 | public class IndexHelper { 14 | // private static final Logger log = LoggerFactory.getLogger(IndexHelper.class); 15 | /** Formats the first rows * cols elements of {@code a} as a row-major table: one line per row, each starting with the row index, a tab and '|'. */ 16 | public static String show(longArray a, int rows, int cols) { 17 | StringBuilder sb = new StringBuilder(); 18 | for (int i = 0; i < rows; i++) { 19 | sb.append(i).append('\t').append('|'); 20 | for (int j = 0; j < cols; j++) { 21 | sb.append(String.format("%5d ", a.getitem(i * cols + j))); 22 | } 23 | sb.append("\n"); 24 | } 25 | return sb.toString(); 26 | } 27 | /** Same table layout as {@link #show(longArray, int, int)}, formatting each float with "%7g ". */ 28 | public static String show(floatArray a, int rows, int cols) { 29 | StringBuilder sb = new StringBuilder(); 30 | for (int i = 0; i < rows; i++) { 31 | sb.append(i).append('\t').append('|'); 32 | for (int j = 0; j < cols; j++) { 33 | sb.append(String.format("%7g ", a.getitem(i * cols + j))); 34 | } 35 | sb.append("\n"); 36 | } 37 | return sb.toString(); 38 | } 39 | /** Flattens {@code vectors} row-major into a new native array of vectors.length * vectors[0].length floats (every row is assumed to be as long as vectors[0]); an empty {@code vectors} yields an empty array. */ 40 | public static floatArray makeFloatArray(float[][] vectors) { 41 | int d = vectors.length == 0 ? 0 : vectors[0].length; // vectors[0] does not exist for an empty input 42 | int nb = vectors.length; 43 | floatArray fa = new floatArray(d * nb); 44 | for (int i = 0; i < nb; i++) { 45 | for (int j = 0; j < d; j++) { 46 | fa.setitem(d * i + j, vectors[i][j]); 47 | } 48 | } 49 | return fa; 50 | } 51 | /** Copies {@code ints} into a new native longArray of the same length. */ 52 | public static longArray makeLongArray(int[] ints) { 53 | int len = ints.length; 54 | longArray la = new longArray(len); 55 | for (int i = 0; i < len; i++) { 56 | la.setitem(i, ints[i]); 57 | } 58 | return la; 59 | } 60 | /** Copies the first {@code length} elements of {@code c_array} into a new Java array. */ 61 | public static long[] toArray(longArray c_array, int length) { 62 | return toArray(c_array, 0, length); 63 | } 64 | /** Copies {@code length} elements of {@code c_array}, starting at index {@code start}, into a new Java array of size {@code length}. Indices are not bounds-checked here: the caller must ensure start + length does not exceed the native array's size. */ 65 | public static long[] toArray(longArray c_array, int
start, int length) { 66 | long[] re = new long[length]; 67 | for (int i = 0; i < length; i++) { 68 | re[i] = c_array.getitem(start + i); // result index i comes from native index start + i 69 | } 70 | return re; 71 | } 72 | /** Same as {@link #toArray(longArray, int)}, for an intArray. */ 73 | public static int[] toArray(intArray c_array, int length) { 74 | return toArray(c_array, 0, length); 75 | } 76 | /** Same as {@link #toArray(longArray, int, int)}, for an intArray. */ 77 | public static int[] toArray(intArray c_array, int start, int length) { 78 | int[] re = new int[length]; 79 | for (int i = 0; i < length; i++) { 80 | re[i] = c_array.getitem(start + i); 81 | } 82 | return re; 83 | } 84 | /** Same as {@link #toArray(longArray, int)}, for a floatArray. */ 85 | public static float[] toArray(floatArray c_array, int length) { 86 | return toArray(c_array, 0, length); 87 | } 88 | /** Same as {@link #toArray(longArray, int, int)}, for a floatArray. */ 89 | public static float[] toArray(floatArray c_array, int start, int length) { 90 | float[] re = new float[length]; 91 | for (int i = 0; i < length; i++) { 92 | re[i] = c_array.getitem(start + i); 93 | } 94 | return re; 95 | } 96 | /** Fills a new native array with size * dim values from {@code random.nextFloat()}, then adds i / 1000 to the first component of row i. */ 97 | public static floatArray makeRandomFloatArray(int size, int dim, Random random) { 98 | floatArray array = new floatArray(size * dim); 99 | for (int i = 0; i < size; i++) { 100 | for (int j = 0; j < dim; j++) { 101 | array.setitem(i * dim + j, random.nextFloat()); 102 | } 103 | array.setitem(i * dim, i / 1000.f + array.getitem(i * dim)); 104 | } 105 | return array; 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /common/src/main/java/com/gameofdimension/faiss/utils/NativeUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Class NativeUtils is published under the The MIT License: 3 | * 4 | * Copyright (c) 2012 Adam Heinrich 5 | * 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy 7 | * of this software and associated documentation files (the "Software"), to deal 8 | * in the Software without restriction, including without limitation the rights 9 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | * copies of the Software, and to permit persons to whom
the Software is 11 | * furnished to do so, subject to the following conditions: 12 | * 13 | * The above copyright notice and this permission notice shall be included in all 14 | * copies or substantial portions of the Software. 15 | * 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | * SOFTWARE. 23 | */ 24 | 25 | package com.gameofdimension.faiss.utils; 26 | 27 | import java.io.*; 28 | import java.nio.file.FileSystemNotFoundException; 29 | import java.nio.file.FileSystems; 30 | import java.nio.file.Files; 31 | import java.nio.file.ProviderNotFoundException; 32 | import java.nio.file.StandardCopyOption; 33 | 34 | /** 35 | * A simple library class which helps with loading dynamic libraries stored in the JAR archive. 36 | * These libraries usually contain implementation of some methods in native code (using JNI - Java 37 | * Native Interface). 38 | * 39 | * @see http://adamheinrich.com/blog/2012/how-to-load-native-jni-library-from-jar 40 | * @see https://github.com/adamheinrich/native-utils 41 | */ 42 | public class NativeUtils { 43 | 44 | /** 45 | * The minimum length a prefix for a file has to have according to {@link 46 | * File#createTempFile(String, String)}. 47 | */ 48 | private static final int MIN_PREFIX_LENGTH = 3; 49 | public static final String NATIVE_FOLDER_PATH_PREFIX = "nativeutils"; 50 | 51 | /** 52 | * Temporary directory which will contain the extracted native libraries (created on first use).
53 | */ 54 | private static File temporaryDir; 55 | 56 | /** 57 | * Private constructor - this class will never be instantiated 58 | */ 59 | private NativeUtils() { 60 | } 61 | 62 | /** 63 | * Loads library from current JAR archive 64 | * 65 | * The file from JAR is copied into a temporary directory and then loaded. On POSIX file systems the 66 | * temporary file is deleted right after loading, otherwise when the JVM exits. Method uses String as 67 | * filename because the pathname is "abstract", not system-dependent. 68 | * 69 | * @param path The path of file inside JAR as absolute path (beginning with '/'), e.g. 70 | * /package/File.ext 71 | * @throws IOException If temporary file creation or read/write operation fails 72 | * @throws IllegalArgumentException If the path is null or not absolute, or if the filename is shorter 73 | * than three characters (restriction of {@link File#createTempFile(java.lang.String, 74 | * java.lang.String)}). 75 | * @throws FileNotFoundException If the file could not be found inside the JAR. 76 | * @throws UnsatisfiedLinkError If the extracted library cannot be loaded by {@link System#load(String)}. 77 | */ 78 | public static void loadLibraryFromJar(String path) throws IOException { 79 | 80 | if (null == path || !path.startsWith("/")) { 81 | throw new IllegalArgumentException("The path has to be absolute (start with '/')."); 82 | } 83 | 84 | // Obtain filename from path 85 | String[] parts = path.split("/"); 86 | String filename = (parts.length > 1) ?
parts[parts.length - 1] : null; 87 | 88 | // Check if the filename is okay 89 | if (filename == null || filename.length() < MIN_PREFIX_LENGTH) { 90 | throw new IllegalArgumentException("The filename has to be at least " + MIN_PREFIX_LENGTH + " characters long."); 91 | } 92 | 93 | // Prepare temporary file 94 | if (temporaryDir == null) { 95 | temporaryDir = createTempDirectory(NATIVE_FOLDER_PATH_PREFIX); 96 | temporaryDir.deleteOnExit(); 97 | } 98 | 99 | File temp = new File(temporaryDir, filename); 100 | 101 | try (InputStream is = NativeUtils.class.getResourceAsStream(path)) { 102 | if (is == null) { // resource missing: report it explicitly instead of relying on an NPE from Files.copy 103 | throw new FileNotFoundException("File " + path + " was not found inside JAR."); 104 | } 105 | Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING); 106 | } catch (IOException e) { 107 | temp.delete(); // discard a partially written copy 108 | throw e; 109 | } 110 | 111 | try { 112 | System.load(temp.getAbsolutePath()); 113 | } finally { 114 | if (isPosixCompliant()) { 115 | // Assume POSIX compliant file system, can be deleted after loading 116 | temp.delete(); 117 | } else { 118 | // Assume non-POSIX, and don't delete until last file descriptor closed 119 | temp.deleteOnExit(); 120 | } 121 | } 122 | } 123 | /** Returns true if the default file system supports the "posix" file attribute view; lookup failures count as non-POSIX. */ 124 | private static boolean isPosixCompliant() { 125 | try { 126 | return FileSystems.getDefault() 127 | .supportedFileAttributeViews() 128 | .contains("posix"); 129 | } catch (FileSystemNotFoundException 130 | | ProviderNotFoundException 131 | | SecurityException e) { 132 | return false; 133 | } 134 | } 135 | /** Creates a new, uniquely named directory in the default temporary-file directory, with a name starting with {@code prefix}. */ 136 | private static File createTempDirectory(String prefix) throws IOException { 137 | // Files.createTempDirectory picks an unpredictable unique name and creates it atomically (prefix + nanoTime could 138 | // collide or be pre-created by another process); it throws IOException itself on failure. 139 | return Files.createTempDirectory(prefix).toFile(); 140 | } 141 | }
-------------------------------------------------------------------------------- /jni/swigfaiss.swig: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | // -*- C++ -*- 9 | 10 | // This file describes the C++-scripting language bridge for Lua, 11 | // Python and Java. It contains mainly includes and a few macros. There 12 | // are 4 preprocessor macros of interest: 13 | // SWIGLUA: Lua-specific code 14 | // SWIGPYTHON: Python-specific code 15 | // GPU_WRAPPER: also compile interfaces for GPU. 16 | // SWIGJAVA: Java-specific code (native array wrapper classes, C++ -> Java exception translation). 17 | #ifdef SWIGJAVA 18 | %include "arrays_java.i" 19 | //%apply int[] {int *}; 20 | //%apply double[] {double *}; 21 | //%apply float[] {float *}; 22 | //%apply long[] {long *}; 23 | %include "carrays.i" 24 | %array_class(int, intArray); 25 | %array_class(float, floatArray); 26 | %array_class(long, longArray); 27 | %array_class(double, doubleArray); 28 | #endif 29 | // Module name: swigfaiss_gpu when building the GPU wrapper, swigfaiss otherwise. 30 | #ifdef GPU_WRAPPER 31 | %module swigfaiss_gpu; 32 | #else 33 | %module swigfaiss; 34 | #endif 35 | 36 | // %module swigfaiss; 37 | // NOTE(review): presumably renamed/ignored to avoid clashes with methods of the generated wrapper classes (e.g. Java's Object.finalize() and Object.wait()) - confirm. 38 | %rename(faiss_RangeSearchPartialResult_finalize) faiss::RangeSearchPartialResult::finalize(); 39 | %rename(faiss_gpu_GpuIndexIVF_getQuantizer) faiss::gpu::GpuIndexIVF::getQuantizer(); 40 | %ignore wait(); 41 | 42 | // fbcode SWIG fails on warnings, so make them non-fatal 43 | #pragma SWIG nowarn=321 44 | #pragma SWIG nowarn=403 45 | #pragma SWIG nowarn=325 46 | #pragma SWIG nowarn=389 47 | #pragma SWIG nowarn=341 48 | #pragma SWIG nowarn=512 49 | // SWIG-side declarations only (outside %{ %}): size_t is parsed as int64_t and __restrict is defined away. 50 | %include 51 | typedef int64_t size_t; 52 | 53 | #define __restrict 54 | 55 | 56 | /******************************************************************* 57 | * Copied verbatim to wrapper.
Contains the C++-visible includes, and 58 | * the language includes for their respective matrix libraries. 59 | *******************************************************************/ 60 | 61 | %{ 62 | 63 | 64 | #include 65 | #include 66 | 67 | 68 | #ifdef SWIGLUA 69 | 70 | #include 71 | 72 | extern "C" { 73 | 74 | #include 75 | #include 76 | #undef THTensor 77 | 78 | } 79 | 80 | #endif 81 | 82 | 83 | #ifdef SWIGPYTHON 84 | 85 | #undef popcount64 86 | 87 | #define SWIG_FILE_WITH_INIT 88 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 89 | #include 90 | 91 | #endif 92 | 93 | 94 | #include 95 | #include 96 | #include 97 | #include 98 | #include 99 | #include 100 | #include 101 | #include 102 | #include 103 | #include 104 | #include 105 | #include 106 | #include 107 | #include 108 | #include 109 | #include 110 | #include 111 | #include 112 | #include 113 | 114 | #include 115 | #include 116 | #include 117 | #include 118 | 119 | #include 120 | #include 121 | #include 122 | 123 | #include 124 | #include 125 | #include 126 | #include 127 | #include 128 | #include 129 | #include 130 | #include 131 | 132 | #include 133 | 134 | #include 135 | 136 | #include 137 | #include 138 | #include 139 | 140 | #include 141 | #include 142 | 143 | 144 | %} 145 | 146 | /******************************************************** 147 | * GIL manipulation and exception handling 148 | ********************************************************/ 149 | 150 | #ifdef SWIGPYTHON 151 | // %catches(faiss::FaissException); 152 | 153 | 154 | // Python-specific: release GIL by default for all functions 155 | %exception { 156 | Py_BEGIN_ALLOW_THREADS 157 | try { 158 | $action 159 | } catch(faiss::FaissException & e) { 160 | PyEval_RestoreThread(_save); 161 | 162 | if (PyErr_Occurred()) { 163 | // some previous code already set the error type.
164 | } else { 165 | PyErr_SetString(PyExc_RuntimeError, e.what()); 166 | } 167 | SWIG_fail; 168 | } catch(std::bad_alloc & ba) { 169 | PyEval_RestoreThread(_save); 170 | PyErr_SetString(PyExc_MemoryError, "std::bad_alloc"); 171 | SWIG_fail; 172 | } 173 | Py_END_ALLOW_THREADS 174 | } 175 | 176 | #endif 177 | 178 | #ifdef SWIGLUA 179 | 180 | %exception { 181 | try { 182 | $action 183 | } catch(faiss::FaissException & e) { 184 | SWIG_Lua_pushferrstring(L, "C++ exception: %s", e.what()); \ 185 | goto fail; 186 | } 187 | } 188 | #endif 189 | // Java-specific: a C++ exception must not unwind through JNI frames, so rethrow it as a Java exception. 190 | #ifdef SWIGJAVA 191 | %exception { 192 | try { $action } 193 | catch(faiss::FaissException & e) { SWIG_JavaThrowException(jenv, SWIG_JavaRuntimeException, e.what()); return $null; } 194 | catch(std::bad_alloc &) { SWIG_JavaThrowException(jenv, SWIG_JavaOutOfMemoryError, "std::bad_alloc"); return $null; } 195 | } 196 | #endif 197 | /******************************************************************* 198 | * Types of vectors we want to manipulate at the scripting language 199 | * level. 200 | *******************************************************************/ 201 | // simplified interface for vector 202 | namespace std { 203 | template 204 | class vector { 205 | public: 206 | vector(); 207 | void push_back(T); 208 | void clear(); 209 | T * data(); 210 | size_t size(); 211 | T at (size_t n) const; 212 | void resize (size_t n); 213 | void swap (vector & other); 214 | }; 215 | }; 216 | %template(FloatVector) std::vector; 217 | %template(DoubleVector) std::vector; 218 | %template(ByteVector) std::vector; 219 | %template(CharVector) std::vector; 220 | // NOTE(hoss): Using unsigned long instead of uint64_t because OSX defines 221 | // uint64_t as unsigned long long, which SWIG is not aware of.
222 | %template(Uint64Vector) std::vector; 223 | %template(LongVector) std::vector; 224 | %template(IntVector) std::vector; 225 | %template(FloatVectorVector) std::vector >; 226 | %template(ByteVectorVector) std::vector >; 227 | %template(LongVectorVector) std::vector >; 228 | %template(VectorTransformVector) std::vector; 229 | %template(OperatingPointVector) std::vector; 230 | %template(InvertedListsPtrVector) std::vector; 231 | %template(RepeatVector) std::vector; 232 | 233 | #ifdef GPU_WRAPPER 234 | %template(GpuResourcesVector) std::vector; 235 | #endif 236 | 237 | %include 238 | 239 | // produces an error on the Mac 240 | %ignore faiss::hamming; 241 | 242 | /******************************************************************* 243 | * Parse headers 244 | *******************************************************************/ 245 | 246 | 247 | %ignore *::cmp; 248 | 249 | %include 250 | %include 251 | 252 | int get_num_gpus(); 253 | void gpu_profiler_start(); 254 | void gpu_profiler_stop(); 255 | void gpu_sync_all_devices(); 256 | 257 | #ifdef GPU_WRAPPER 258 | 259 | %{ 260 | 261 | #include 262 | #include 263 | #include 264 | #include 265 | #include 266 | #include 267 | #include 268 | #include 269 | #include 270 | #include 271 | #include 272 | #include 273 | #include 274 | #include 275 | #include 276 | 277 | int get_num_gpus() 278 | { 279 | return faiss::gpu::getNumDevices(); 280 | } 281 | 282 | void gpu_profiler_start() 283 | { 284 | return faiss::gpu::profilerStart(); 285 | } 286 | 287 | void gpu_profiler_stop() 288 | { 289 | return faiss::gpu::profilerStop(); 290 | } 291 | 292 | void gpu_sync_all_devices() 293 | { 294 | return faiss::gpu::synchronizeAllDevices(); 295 | } 296 | 297 | %} 298 | 299 | // causes weird wrapper bug 300 | %ignore *::getMemoryManager; 301 | %ignore *::getMemoryManagerCurrentDevice; 302 | 303 | %include 304 | %include 305 | 306 | #else 307 | // CPU-only build: stub implementations that report 0 GPUs; the profiler and sync functions do nothing. 308 | %{ 309 | int get_num_gpus() 310 | { 311 | return 0; 312 | } 313 | 314 | void
gpu_profiler_start() 315 | { 316 | } 317 | 318 | void gpu_profiler_stop() 319 | { 320 | } 321 | 322 | void gpu_sync_all_devices() 323 | { 324 | } 325 | %} 326 | 327 | 328 | #endif 329 | 330 | // order matters because includes are not recursive 331 | 332 | %include 333 | %include 334 | %include 335 | 336 | %include 337 | %include 338 | 339 | %include 340 | 341 | %ignore faiss::ProductQuantizer::get_centroids(size_t,size_t) const; 342 | 343 | %include 344 | 345 | %include 346 | %include 347 | %include 348 | %include 349 | %include 350 | %include 351 | %include 352 | %ignore InvertedListScanner; 353 | %ignore BinaryInvertedListScanner; 354 | %include 355 | // NOTE(hoss): SWIG (wrongly) believes the overloaded const version shadows the 356 | // non-const one. 357 | %warnfilter(509) extract_index_ivf; 358 | %include 359 | %ignore faiss::ScalarQuantizer::SQDistanceComputer; 360 | %include 361 | %include 362 | %include 363 | %include 364 | %include 365 | %include 366 | %include 367 | 368 | %include 369 | %include 370 | 371 | %ignore faiss::IndexIVFPQ::alloc_type; 372 | %include 373 | %include 374 | %include 375 | 376 | %include 377 | %include 378 | %include 379 | %include 380 | %include 381 | 382 | 383 | 384 | // %ignore faiss::IndexReplicas::at(int) const; 385 | 386 | %include 387 | %template(ThreadedIndexBase) faiss::ThreadedIndex; 388 | %template(ThreadedIndexBaseBinary) faiss::ThreadedIndex; 389 | 390 | %include 391 | %template(IndexShards) faiss::IndexShardsTemplate; 392 | %template(IndexBinaryShards) faiss::IndexShardsTemplate; 393 | 394 | %include 395 | %template(IndexReplicas) faiss::IndexReplicasTemplate; 396 | %template(IndexBinaryReplicas) faiss::IndexReplicasTemplate; 397 | 398 | %include 399 | %template(IndexIDMap) faiss::IndexIDMapTemplate; 400 | %template(IndexBinaryIDMap) faiss::IndexIDMapTemplate; 401 | %template(IndexIDMap2) faiss::IndexIDMap2Template; 402 | %template(IndexBinaryIDMap2) faiss::IndexIDMap2Template; 403 | 404 | #ifdef GPU_WRAPPER 405 | 406
| // quiet SWIG warnings 407 | %ignore faiss::gpu::GpuIndexIVF::GpuIndexIVF; 408 | 409 | %include 410 | %include 411 | %include 412 | %include 413 | %include 414 | %include 415 | %include 416 | %include 417 | %include 418 | %include 419 | %include 420 | %include 421 | 422 | #ifdef SWIGLUA 423 | 424 | /// in Lua, swigfaiss_gpu is known as swigfaiss 425 | %luacode { 426 | local swigfaiss = swigfaiss_gpu 427 | } 428 | 429 | #endif 430 | 431 | 432 | #endif 433 | 434 | 435 | 436 | 437 | /******************************************************************* 438 | * Lua-specific: support async execution of searches in an index 439 | * Python equivalent is just to use Python threads. 440 | *******************************************************************/ 441 | 442 | 443 | #ifdef SWIGLUA 444 | 445 | %{ 446 | 447 | 448 | namespace faiss { 449 | 450 | struct AsyncIndexSearchC { 451 | typedef Index::idx_t idx_t; 452 | const Index * index; 453 | 454 | idx_t n; 455 | const float *x; 456 | idx_t k; 457 | float *distances; 458 | idx_t *labels; 459 | 460 | bool is_finished; 461 | 462 | pthread_t thread; 463 | 464 | 465 | AsyncIndexSearchC (const Index *index, 466 | idx_t n, const float *x, idx_t k, 467 | float *distances, idx_t *labels): 468 | index(index), n(n), x(x), k(k), distances(distances), 469 | labels(labels) 470 | { 471 | is_finished = false; 472 | pthread_create (&thread, NULL, &AsyncIndexSearchC::callback, 473 | this); 474 | } 475 | 476 | static void *callback (void *arg) 477 | { 478 | AsyncIndexSearchC *aidx = (AsyncIndexSearchC *)arg; 479 | aidx->do_search(); 480 | return NULL; 481 | } 482 | 483 | void do_search () 484 | { 485 | index->search (n, x, k, distances, labels); 486 | } 487 | void join () 488 | { 489 | pthread_join (thread, NULL); 490 | } 491 | 492 | }; 493 | 494 | } 495 | 496 | %} 497 | 498 | // re-declare only what we need 499 | namespace faiss { 500 | 501 | struct AsyncIndexSearchC { 502 | typedef Index::idx_t idx_t; 503 | bool is_finished; 504 |
AsyncIndexSearchC (const Index *index, 505 | idx_t n, const float *x, idx_t k, 506 | float *distances, idx_t *labels); 507 | 508 | 509 | void join (); 510 | }; 511 | 512 | } 513 | 514 | 515 | #endif 516 | 517 | 518 | 519 | 520 | /******************************************************************* 521 | * downcast return of some functions so that the sub-class is used 522 | * instead of the generic upper-class. 523 | *******************************************************************/ 524 | // Each DOWNCAST(subclass) macro below expands to "if (dynamic_cast to subclass succeeds) { ... } else": chained in a typemap, they try the listed classes in order and the statement after the chain handles the fallthrough. 525 | #ifdef SWIGJAVA 526 | // Java: only the raw object pointer is stored into $result (as a jlong). NOTE(review): the generated Java proxy presumably still has the declared return type, so no subclass proxy is created - confirm. 527 | %define DOWNCAST(subclass) 528 | if (dynamic_cast ($1)) { 529 | faiss::subclass *instance_ptr = (faiss::subclass *)$1; 530 | $result = (jlong)instance_ptr; 531 | } else 532 | %enddef 533 | 534 | %define DOWNCAST_GPU(subclass) 535 | if (dynamic_cast ($1)) { 536 | faiss::gpu::subclass *instance_ptr = (faiss::gpu::subclass *)$1; 537 | $result = (jlong)instance_ptr; 538 | } else 539 | %enddef 540 | 541 | #endif 542 | 543 | #ifdef SWIGLUA 544 | 545 | %define DOWNCAST(subclass) 546 | if (dynamic_cast ($1)) { 547 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## subclass, $owner); 548 | } else 549 | %enddef 550 | 551 | %define DOWNCAST2(subclass, longname) 552 | if (dynamic_cast ($1)) { 553 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__ ## longname, $owner); 554 | } else 555 | %enddef 556 | 557 | %define DOWNCAST_GPU(subclass) 558 | if (dynamic_cast ($1)) { 559 | SWIG_NewPointerObj(L,$1,SWIGTYPE_p_faiss__gpu__ ## subclass, $owner); 560 | } else 561 | %enddef 562 | 563 | #endif 564 | 565 | 566 | #ifdef SWIGPYTHON 567 | 568 | %define DOWNCAST(subclass) 569 | if (dynamic_cast ($1)) { 570 | $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner); 571 | } else 572 | %enddef 573 | 574 | %define DOWNCAST2(subclass, longname) 575 | if (dynamic_cast ($1)) { 576 | $result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner); 577 | } else 578 | %enddef 579 | 580 | %define DOWNCAST_GPU(subclass) 581 | if (dynamic_cast ($1)) { 582 | $result =
SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__gpu__ ## subclass,$owner); 583 | } else 584 | %enddef 585 | 586 | #endif 587 | 588 | %newobject read_index; 589 | %newobject read_index_binary; 590 | %newobject read_VectorTransform; 591 | %newobject read_ProductQuantizer; 592 | %newobject clone_index; 593 | %newobject clone_VectorTransform; 594 | 595 | // Subclasses should appear before their parent 596 | %typemap(out) faiss::Index * { 597 | DOWNCAST ( IndexIDMap ) 598 | DOWNCAST ( IndexIDMap2 ) 599 | DOWNCAST ( IndexShards ) 600 | DOWNCAST ( IndexReplicas ) 601 | DOWNCAST ( IndexIVFPQR ) 602 | DOWNCAST ( IndexIVFPQ ) 603 | DOWNCAST ( IndexIVFSpectralHash ) 604 | DOWNCAST ( IndexIVFScalarQuantizer ) 605 | DOWNCAST ( IndexIVFFlatDedup ) 606 | DOWNCAST ( IndexIVFFlat ) 607 | DOWNCAST ( IndexIVF ) 608 | DOWNCAST ( IndexFlat ) 609 | DOWNCAST ( IndexPQ ) 610 | DOWNCAST ( IndexScalarQuantizer ) 611 | DOWNCAST ( IndexLSH ) 612 | DOWNCAST ( IndexLattice ) 613 | DOWNCAST ( IndexPreTransform ) 614 | DOWNCAST ( MultiIndexQuantizer ) 615 | DOWNCAST ( IndexHNSWFlat ) 616 | DOWNCAST ( IndexHNSWPQ ) 617 | DOWNCAST ( IndexHNSWSQ ) 618 | DOWNCAST ( IndexHNSW2Level ) 619 | DOWNCAST ( Index2Layer ) 620 | #ifdef GPU_WRAPPER 621 | DOWNCAST_GPU ( GpuIndexIVFPQ ) 622 | DOWNCAST_GPU ( GpuIndexIVFFlat ) 623 | DOWNCAST_GPU ( GpuIndexIVFScalarQuantizer ) 624 | DOWNCAST_GPU ( GpuIndexFlat ) 625 | #endif 626 | // default for non-recognized classes 627 | DOWNCAST ( Index ) 628 | if ($1 == NULL) 629 | { 630 | #ifdef SWIGPYTHON 631 | $result = SWIG_Py_Void(); 632 | #endif 633 | #ifdef SWIGJAVA 634 | $result = 0; 635 | #endif 636 | // Lua does not need a push for nil 637 | } else { 638 | assert(false); 639 | } 640 | #ifdef SWIGLUA 641 | SWIG_arg++; 642 | #endif 643 | } 644 | 645 | %typemap(out) faiss::IndexBinary * { 646 | DOWNCAST ( IndexBinaryReplicas ) 647 | DOWNCAST ( IndexBinaryIDMap ) 648 | DOWNCAST ( IndexBinaryIDMap2 ) 649 | DOWNCAST ( IndexBinaryIVF ) 650 | DOWNCAST ( IndexBinaryFlat ) 651 |
DOWNCAST ( IndexBinaryFromFloat ) 652 | DOWNCAST ( IndexBinaryHNSW ) 653 | #ifdef GPU_WRAPPER 654 | DOWNCAST_GPU ( GpuIndexBinaryFlat ) 655 | #endif 656 | // default for non-recognized classes 657 | DOWNCAST ( IndexBinary ) 658 | if ($1 == NULL) 659 | { 660 | #ifdef SWIGPYTHON 661 | $result = SWIG_Py_Void(); 662 | #endif 663 | #ifdef SWIGJAVA 664 | $result = 0; 665 | #endif 666 | // Lua does not need a push for nil 667 | } else { 668 | assert(false); 669 | } 670 | #ifdef SWIGLUA 671 | SWIG_arg++; 672 | #endif 673 | } 674 | // NOTE(review): unlike the Index / IndexBinary typemaps above, a NULL VectorTransform or InvertedLists pointer is not special-cased and falls through to assert(false). 675 | %typemap(out) faiss::VectorTransform * { 676 | DOWNCAST (RemapDimensionsTransform) 677 | DOWNCAST (OPQMatrix) 678 | DOWNCAST (PCAMatrix) 679 | DOWNCAST (RandomRotationMatrix) 680 | DOWNCAST (LinearTransform) 681 | DOWNCAST (NormalizationTransform) 682 | DOWNCAST (CenteringTransform) 683 | DOWNCAST (VectorTransform) 684 | { 685 | assert(false); 686 | } 687 | #ifdef SWIGLUA 688 | SWIG_arg++; 689 | #endif 690 | } 691 | 692 | %typemap(out) faiss::InvertedLists * { 693 | DOWNCAST (ArrayInvertedLists) 694 | DOWNCAST (OnDiskInvertedLists) 695 | DOWNCAST (VStackInvertedLists) 696 | DOWNCAST (HStackInvertedLists) 697 | DOWNCAST (MaskedInvertedLists) 698 | DOWNCAST (InvertedLists) 699 | { 700 | assert(false); 701 | } 702 | #ifdef SWIGLUA 703 | SWIG_arg++; 704 | #endif 705 | } 706 | 707 | // just to downcast pointers that come from elsewhere (eg.
direct 708 | // access to object fields) 709 | %inline %{ 710 | faiss::Index * downcast_index (faiss::Index *index) 711 | { 712 | return index; 713 | } 714 | faiss::VectorTransform * downcast_VectorTransform (faiss::VectorTransform *vt) 715 | { 716 | return vt; 717 | } 718 | faiss::IndexBinary * downcast_IndexBinary (faiss::IndexBinary *index) 719 | { 720 | return index; 721 | } 722 | faiss::InvertedLists * downcast_InvertedLists (faiss::InvertedLists *il) 723 | { 724 | return il; 725 | } 726 | %} 727 | 728 | %include 729 | %include 730 | %include 731 | 732 | %newobject index_factory; 733 | %newobject index_binary_factory; 734 | 735 | %include 736 | %include 737 | %include 738 | 739 | // GPU-only: %newobject marks the indexes returned by index_gpu_to_cpu / index_cpu_to_gpu* as owned by the caller. 740 | #ifdef GPU_WRAPPER 741 | 742 | %include 743 | 744 | %newobject index_gpu_to_cpu; 745 | %newobject index_cpu_to_gpu; 746 | %newobject index_cpu_to_gpu_multiple; 747 | 748 | %include 749 | 750 | #endif 751 | 752 | // Python-specific: do not release GIL any more, as functions below 753 | // use the Python/C API 754 | #ifdef SWIGPYTHON 755 | %exception; 756 | #endif 757 | 758 | 759 | 760 | 761 | 762 | /******************************************************************* 763 | * Python specific: numpy array <-> C++ pointer interface 764 | *******************************************************************/ 765 | 766 | #ifdef SWIGPYTHON 767 | 768 | %{ 769 | PyObject *swig_ptr (PyObject *a) 770 | { 771 | if(!PyArray_Check(a)) { 772 | PyErr_SetString(PyExc_ValueError, "input not a numpy array"); 773 | return NULL; 774 | } 775 | PyArrayObject *ao = (PyArrayObject *)a; 776 | 777 | if(!PyArray_ISCONTIGUOUS(ao)) { 778 | PyErr_SetString(PyExc_ValueError, "array is not C-contiguous"); 779 | return NULL; 780 | } 781 | void * data = PyArray_DATA(ao); 782 | if(PyArray_TYPE(ao) == NPY_FLOAT32) { 783 | return SWIG_NewPointerObj(data, SWIGTYPE_p_float, 0); 784 | } 785 | if(PyArray_TYPE(ao) == NPY_FLOAT64) { 786 | return SWIG_NewPointerObj(data, SWIGTYPE_p_double, 0); 787 | } 788 |
if(PyArray_TYPE(ao) == NPY_INT32) { 789 | return SWIG_NewPointerObj(data, SWIGTYPE_p_int, 0); 790 | } 791 | if(PyArray_TYPE(ao) == NPY_UINT8) { 792 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_char, 0); 793 | } 794 | if(PyArray_TYPE(ao) == NPY_INT8) { 795 | return SWIG_NewPointerObj(data, SWIGTYPE_p_char, 0); 796 | } 797 | if(PyArray_TYPE(ao) == NPY_UINT64) { 798 | #ifdef SWIGWORDSIZE64 799 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long, 0); 800 | #else 801 | return SWIG_NewPointerObj(data, SWIGTYPE_p_unsigned_long_long, 0); 802 | #endif 803 | } 804 | if(PyArray_TYPE(ao) == NPY_INT64) { 805 | #ifdef SWIGWORDSIZE64 806 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long, 0); 807 | #else 808 | return SWIG_NewPointerObj(data, SWIGTYPE_p_long_long, 0); 809 | #endif 810 | } 811 | PyErr_SetString(PyExc_ValueError, "did not recognize array type"); 812 | return NULL; 813 | } 814 | 815 | // Interrupt hook: checks for pending Python signals (taking the GIL) so long-running faiss calls can be interrupted; it is installed in the %init block below. 816 | struct PythonInterruptCallback: faiss::InterruptCallback { 817 | 818 | bool want_interrupt () override { 819 | int err; 820 | { 821 | PyGILState_STATE gstate; 822 | gstate = PyGILState_Ensure(); 823 | err = PyErr_CheckSignals(); 824 | PyGILState_Release(gstate); 825 | } 826 | return err == -1; 827 | } 828 | 829 | }; 830 | 831 | 832 | %} 833 | 834 | 835 | %init %{ 836 | /* needed, else crash at runtime */ 837 | import_array(); 838 | 839 | faiss::InterruptCallback::instance.reset(new PythonInterruptCallback()); 840 | 841 | %} 842 | 843 | // return a pointer usable as input for functions that expect pointers 844 | PyObject *swig_ptr (PyObject *a); 845 | 846 | %define REV_SWIG_PTR(ctype, numpytype) 847 | 848 | %{ 849 | PyObject * rev_swig_ptr(ctype *src, npy_intp size) { 850 | return PyArray_SimpleNewFromData(1, &size, numpytype, src); 851 | } 852 | %} 853 | 854 | PyObject * rev_swig_ptr(ctype *src, size_t size); 855 | 856 | %enddef 857 | 858 | REV_SWIG_PTR(float, NPY_FLOAT32); 859 | REV_SWIG_PTR(int, NPY_INT32); 860 | REV_SWIG_PTR(unsigned char, NPY_UINT8); 861 |
REV_SWIG_PTR(int64_t, NPY_INT64); 862 | REV_SWIG_PTR(uint64_t, NPY_UINT64); 863 | 864 | #endif 865 | 866 | 867 | 868 | /******************************************************************* 869 | * Lua specific: Torch tensor <-> C++ pointer interface 870 | *******************************************************************/ 871 | 872 | #ifdef SWIGLUA 873 | 874 | 875 | // provide an XXX_ptr function to convert Lua XXXTensor -> C++ XXX* 876 | 877 | %define TYPE_CONVERSION(ctype, tensortype) 878 | 879 | // typemap for the *_ptr_from_cdata function 880 | %typemap(in) ctype** { 881 | if(lua_type(L, $input) != 10) { 882 | fprintf(stderr, "not cdata input\n"); 883 | SWIG_fail; 884 | } 885 | $1 = (ctype**)lua_topointer(L, $input); 886 | } 887 | 888 | 889 | // SWIG and C declaration for the *_ptr_from_cdata function 890 | %{ 891 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs) { 892 | return *x + ofs; 893 | } 894 | %} 895 | ctype * ctype ## _ptr_from_cdata(ctype **x, long ofs); 896 | 897 | // the *_ptr function 898 | %luacode { 899 | 900 | function swigfaiss. ctype ## _ptr(tensor) 901 | assert(tensor:type() == "torch." .. # tensortype, "need a " .. # tensortype) 902 | assert(tensor:isContiguous(), "requires contiguous tensor") 903 | return swigfaiss. ctype ## _ptr_from_cdata( 904 | tensor:storage():data(), 905 | tensor:storageOffset() - 1) 906 | end 907 | 908 | } 909 | 910 | %enddef 911 | 912 | TYPE_CONVERSION (int, IntTensor) 913 | TYPE_CONVERSION (float, FloatTensor) 914 | TYPE_CONVERSION (long, LongTensor) 915 | TYPE_CONVERSION (uint64_t, LongTensor) 916 | TYPE_CONVERSION (uint8_t, ByteTensor) 917 | 918 | #endif 919 | 920 | /******************************************************************* 921 | * How should the template objects appear in the scripting language?
922 | *******************************************************************/ 923 | 924 | // answer: the same as the C++ typedefs, but we still have to redefine them 925 | 926 | %template() faiss::CMin; 927 | %template() faiss::CMin; 928 | %template() faiss::CMax; 929 | %template() faiss::CMax; 930 | 931 | %template(float_minheap_array_t) faiss::HeapArray >; 932 | %template(int_minheap_array_t) faiss::HeapArray >; 933 | 934 | %template(float_maxheap_array_t) faiss::HeapArray >; 935 | %template(int_maxheap_array_t) faiss::HeapArray >; 936 | 937 | 938 | /******************************************************************* 939 | * Expose a few basic functions 940 | *******************************************************************/ 941 | 942 | 943 | void omp_set_num_threads (int num_threads); 944 | int omp_get_max_threads (); 945 | void *memcpy(void *dest, const void *src, size_t n); 946 | 947 | 948 | /******************************************************************* 949 | * For Faiss/Pytorch interop via pointers encoded as longs 950 | *******************************************************************/ 951 | 952 | %inline %{ 953 | float * cast_integer_to_float_ptr (long x) { 954 | return (float*)x; 955 | } 956 | 957 | long * cast_integer_to_long_ptr (long x) { 958 | return (long*)x; 959 | } 960 | 961 | int * cast_integer_to_int_ptr (long x) { 962 | return (int*)x; 963 | } 964 | 965 | %} 966 | 967 | 968 | 969 | /******************************************************************* 970 | * Range search interface 971 | *******************************************************************/ 972 | 973 | %ignore faiss::BufferList::Buffer; 974 | %ignore faiss::RangeSearchPartialResult::QueryResult; 975 | %ignore faiss::IDSelectorBatch::set; 976 | %ignore faiss::IDSelectorBatch::bloom; 977 | 978 | %ignore faiss::InterruptCallback::instance; 979 | %ignore faiss::InterruptCallback::lock; 980 | %include 981 | 982 | %{ 983 | // may be useful for lua code launched in background from 
shell 984 | 985 | #include 986 | void ignore_SIGTTIN() { 987 | signal(SIGTTIN, SIG_IGN); 988 | } 989 | %} 990 | 991 | void ignore_SIGTTIN(); 992 | 993 | 994 | %inline %{ 995 | 996 | // numpy misses a hash table implementation, hence this class. It 997 | // represents not found values as -1 like in the Index implementation 998 | 999 | struct MapLong2Long { 1000 | std::unordered_map map; 1001 | 1002 | void add(size_t n, const int64_t *keys, const int64_t *vals) { 1003 | map.reserve(map.size() + n); 1004 | for (size_t i = 0; i < n; i++) { 1005 | map[keys[i]] = vals[i]; 1006 | } 1007 | } 1008 | 1009 | long search(int64_t key) { 1010 | if (map.count(key) == 0) { 1011 | return -1; 1012 | } else { 1013 | return map[key]; 1014 | } 1015 | } 1016 | 1017 | void search_multiple(size_t n, int64_t *keys, int64_t * vals) { 1018 | for (size_t i = 0; i < n; i++) { 1019 | vals[i] = search(keys[i]); 1020 | } 1021 | } 1022 | }; 1023 | 1024 | %} 1025 | 1026 | %inline %{ 1027 | void wait() { 1028 | // in gdb, use return to get out of this function 1029 | for(int i = 0; i == 0; i += 0); 1030 | } 1031 | %} 1032 | 1033 | // End of file... 1034 | --------------------------------------------------------------------------------