├── requirements.txt ├── pyfaiss ├── __init__.py ├── faiss_utils │ ├── __init__.py │ └── FAISS.md ├── faiss_search.py └── train_index.py ├── .gitignore ├── Dockerfile.faiss ├── setup.py ├── Dockerfile.conda └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyfaiss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyfaiss/faiss_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | pyfaiss/faiss_utils/_swigfaiss.so 3 | pyfaiss/faiss_utils/swigfaiss.py 4 | -------------------------------------------------------------------------------- /Dockerfile.faiss: -------------------------------------------------------------------------------- 1 | from conda3:1.0 2 | 3 | # RUN conda install pytorch -y 4 | RUN conda install faiss-cpu -c pytorch -y 5 | 6 | CMD [ "/bin/bash" ] 7 | -------------------------------------------------------------------------------- /pyfaiss/faiss_utils/FAISS.md: -------------------------------------------------------------------------------- 1 | ## Install Faiss 2 | [download](https://github.com/facebookresearch/faiss) faiss library and [install](https://github.com/facebookresearch/faiss/blob/master/INSTALL). 3 | 4 | ## Copy Files 5 | after **build** and **install** faiss, copy three files:**faiss.py, swigfaiss.py, _swigfaiss.so** to current path. 6 | ``` 7 | cp build_files_path . 8 | ``` 9 | 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | try: 5 | from setuptools import setup, find_packages 6 | except ImportError: 7 | from distutils.core import setup 8 | 9 | 10 | setup( 11 | name='pyfaiss', 12 | version='1.0', 13 | description= 'a python api for faiss search and add.', 14 | url = 'http://gitlab.benditoutiao.com/bdtt/pyfaiss', 15 | author = 'Fisher', 16 | author_email = '992049896@qq.com', 17 | classifiers=[ 'Programming Language :: Python :: 2.7',], 18 | include_package_data=True, 19 | packages = find_packages() 20 | ) 21 | 22 | -------------------------------------------------------------------------------- /Dockerfile.conda: -------------------------------------------------------------------------------- 1 | from ubuntu:16.04 2 | RUN apt-get update && apt-get install -y --no-install-recommends \ 3 | bzip2 \ 4 | g++ \ 5 | git \ 6 | graphviz \ 7 | libgl1-mesa-glx \ 8 | libhdf5-dev \ 9 | openmpi-bin \ 10 | wget && \ 11 | rm -rf /var/lib/apt/lists/* 12 | 13 | RUN sed -i 's/archive.ubuntu.com/mirrors.ustc.edu.cn/g' /etc/apt/sources.list 14 | RUN apt-get update 15 | 16 | ADD ./Anaconda3-2019.10-Linux-x86_64.sh ./anaconda.sh 17 | 18 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 19 | ENV PATH /opt/conda/bin:$PATH 20 | RUN /bin/bash ./anaconda.sh -b -p /opt/conda && rm ./anaconda.sh && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && echo "conda activate base" >> ~/.bashrc && find /opt/conda/ -follow -type f -name '*.a' -delete && find /opt/conda/ -follow -type f -name '*.js.map' -delete && /opt/conda/bin/conda clean -afy 21 | 22 | 23 | CMD [ "/bin/bash" ] 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PyFaiss 2 | ![](https://readthedocs.org/projects/pygorithm/badge/?version=latest) ![](https://img.shields.io/badge/python%20-%202.7-brightgreen.svg) 3 | ======== 4 | 5 | **NOTE:To use this package must ensure has installed faiss lib** 6 | 7 | ## `Install` 8 | - 使用该模块时,在项目的**[requirements.txt]**中添加 9 | 并使用命令行安装: 10 | 11 | ` /usr/local/bin/pip install -r requirements.txt` 12 | 13 | - 或者通过单独安装 14 | 15 | ## `Include` 16 | + [train_index] 17 | 18 | 训练faiss `index` 模块**TrainIndex**初始参数详解: 19 | 20 | class TrainIndex(kwargs): 21 | - `files`: (type `list`) 文档向量文件列表,注意:不是文件名,而是文档向量的绝对路径. 22 | - `vpath`: 文档向量读取路径,读取的数据应满足一定的前缀或者后缀规则(`prefix`, `suffix`). 23 | - `prefix`:文档向量文件名前缀,如 `vector_**`. 24 | - `suffix`:文档向量文件名后缀,如 `**_vector`. 25 | - `dpath`: `index`训练结果存储路径. 26 | - `iname`: `index`训练结果存储名称,默认为(`index.index`). 27 | - `direct`: 实例化`TrainIndex`时,直接使用默认提取的向量文件进行训练生成`index`,并返回. 28 | - `fpath`: 安装`faiss` 产生的编译文件(faiss.py, swigfaiss.py, \_swigfaiss.so)存储路径,该模块运行必须导入这些相关文件. 29 | 30 | + [faiss_search] 31 | 32 | class FaissSearch(kwargs) 33 | 两种传入方式,1. 直接传入index; 2. 传入index存储路径: 34 | - `index`:直接传入`index`文件; 35 | - `ipath`:读取`index`的路径(包含index的文件名,如:`/home/xxx/x/x.index`); 36 | - `fpath`:安装`faiss` 产生的编译文件(faiss.py, swigfaiss.py, \_swigfaiss.so)存储路径,该模块运行必须导入这些相关文件. 37 | 38 | faiss搜索相关接口 39 | - `add()`: 40 | - `add_one()`: 41 | - `search()`: 42 | - `search_many()`: 43 | 44 | 45 | ## `Basic Usage` 46 | ```python 47 | >>> from pyfaiss.train_index import TrainIndex 48 | >>> vpath = u'/home/user/' 49 | >>> prefix= u'vector_' 50 | >>> fpath = '/home/user/' 51 | >>> trainer = TrainIndex(vpath=vpath, prefix=prefix, fpath=fpath) 52 | >>> files = trainer.files[:2] 53 | >>> print files 54 | >>> index = trainer.train(files) 55 | >>> print index.ntotal 56 | 211897 57 | ``` 58 | 59 | ## `Faiss install by Docker` 60 | ``` 61 | # build conda3 image 62 | docker build -f=Dockerfile.conda -t conda3:1.0 . 63 | 64 | # build faiss image 65 | docker build -f=Dockerfile.faiss -t conda-faiss:1.0 . 66 | ``` 67 | -------------------------------------------------------------------------------- /pyfaiss/faiss_search.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # coding=utf8 3 | 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | #from .faiss_utils import faiss 9 | 10 | class FaissSearch(object): 11 | """ 12 | intent to seach similar news from faiss index. 13 | 14 | @param:`index_path`: which index to search. 15 | @param:`callback` : future extension. 16 | get vector func, different index different func. 17 | """ 18 | def __init__(self, **kwargs): 19 | self.index = None 20 | self.callback = None 21 | if kwargs.has_key('index'): 22 | self.index = kwargs['index'] 23 | if not self.index and all(k in kwargs for k in ['ipath', 'fpath']): 24 | fpath = kwargs.get('fpath') 25 | try: 26 | import faiss 27 | except ImportError: 28 | sys.path.insert(0, fpath) 29 | import faiss 30 | ipath = kwargs.get('ipath') 31 | self.index = faiss.read_index(ipath) 32 | 33 | assert self.index is not None 34 | 35 | @staticmethod 36 | def read_index(ipath, fpath='./'): 37 | """ 38 | @param: fpath: faiss file path. 39 | """ 40 | try: 41 | import faiss 42 | except ImportError: 43 | sys.path.insert(0, fpath) 44 | import faiss 45 | return faiss.read_index(ipath) 46 | 47 | @staticmethod 48 | def write_index(index, fpath='./', savepath='./', index_name='index.index'): 49 | try: 50 | import faiss 51 | except ImportError: 52 | sys.path.insert(0, fpath) 53 | import faiss 54 | if not os.path.exists(savepath): 55 | os.makedirs(savepath) 56 | index_name = savepath + index_name 57 | faiss.write_index(index, index_name) 58 | 59 | def add_one(self, item_vector, **kwargs): 60 | """ 61 | only to add vector to index. 62 | @param: `text` : type `str`, search text. 63 | @param: **kwargs : other params 64 | `try_time` : try the most time to get vector server. 65 | `id`:type : `ndarray` or `list`, which to add in faiss 66 | index. 67 | `add_with_ids`: flag, indicate whether to add with user- 68 | defined ids. 69 | """ 70 | if not all(k in kwargs for k in ('id', 'add_with_ids')): 71 | import warnings 72 | warnings.warn( 73 | """ 74 | if use faiss `add_with_ids` functions, field `id`, 75 | `add_with_ids` must in parameters kwargs at the same 76 | time. 77 | """) 78 | 79 | if not isinstance(item_vector, np.ndarray): 80 | raise ValueError('vector must `numpy.ndarray` type, not %s.' % type(item_vector)) 81 | if len(item_vector.shape) < 2: 82 | item_vector = item_vector.reshape([1,item_vector.shape[0]]).astype('float32') 83 | if not kwargs.get('add_with_ids', False): 84 | self.index.add(item_vector) 85 | return 86 | if not kwargs.has_key('id'): 87 | raise KeyError('use add_with_ids must provied `id`') 88 | id = kwargs.get('id') 89 | if not isinstance(id, np.ndarray): 90 | id = np.array(id).astype('int') 91 | 92 | self.index.add_with_ids(item_vector, id) 93 | return 94 | 95 | def add_many(self, item_vectors, **kwargs): 96 | """ 97 | add clust vectors to index. 98 | """ 99 | if not all(k in kwargs for k in ('ids', 'add_with_ids')): 100 | import warnings 101 | warnings.warn( 102 | """ 103 | if use faiss `add_with_ids` functions, field `ids`, 104 | `add_with_ids` must in parameters kwargs at the same 105 | time. 106 | """) 107 | if not isinstance(item_vectors, np.ndarray): 108 | raise ValueError('vector must `numpy.ndarray` type, not %s.' % type(item_vectors)) 109 | assert len(item_vectors.shape) == 2 110 | if not kwargs.get('add_with_ids', False): 111 | self.index.add(item_vectors) 112 | return 113 | if not kwargs.has_key('ids'): 114 | raise KeyError('use add_with_ids must provied `ids`') 115 | ids = kwargs.get('ids') 116 | if not isinstance(ids, np.ndarray): 117 | ids = np.array(ids).astype('int') 118 | assert len(ids.shape) == 1 119 | assert item_vectors.shape[0] == ids.shape[0] 120 | 121 | self.index.add_with_ids(item_vectors, ids) 122 | return 123 | 124 | def search_one(self, item_vector, **kwargs): 125 | """ 126 | @param: `text` : type `str`, search text. 127 | @param: **kwargs : other params 128 | `k`:return most 'k' nearnest search result. 129 | `try_time`:try the most time to get vector server. 130 | `add`: flag, after search add new vector to index, then 131 | return the search result. 132 | """ 133 | if not isinstance(item_vector, np.ndarray): 134 | raise ValueError('vector must `numpy.ndarray` type, not %s.' % type(item_vector)) 135 | if len(item_vector.shape) < 2: 136 | item_vector = item_vector.reshape([1,item_vector.shape[0]]).astype('float32') 137 | add = False 138 | search_len = kwargs.get('k', 100) 139 | D, I = self.index.search(item_vector, search_len) 140 | if u'add' in kwargs: 141 | add = kwargs.get('add') 142 | if add: 143 | self.add_one(item_vector, **kwargs) 144 | return (D, I) 145 | 146 | def search_many(self, item_vectors, **kwargs): 147 | """ 148 | @param: `texts` : type `list`, search text list. 149 | @param: **kwargs : other params 150 | `k`:return most 'k' nearnest search result. 151 | `try_time`:try the most time to get vector server 152 | """ 153 | 154 | if not isinstance(item_vectors, np.ndarray): 155 | raise ValueError('vector must `numpy.ndarray` type, not %s.' % type(item_vectors)) 156 | assert len(item_vectors.shape) == 2 157 | add_many = False 158 | search_len = kwargs.get('k', 100) 159 | D, I = self.index.search(item_vectors, search_len) 160 | if u'add_many' in kwargs: 161 | add_many = kwargs.get('add_many') 162 | if add_many: 163 | self.add_many(item_vectors, **kwargs) 164 | return (D, I) 165 | 166 | def remove_ids(self, ids): 167 | """ 168 | intent to remove ids from index. 169 | """ 170 | pass 171 | 172 | 173 | -------------------------------------------------------------------------------- /pyfaiss/train_index.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | """ 5 | pyfaiss.train_faiss. 6 | ---------------------------------------------------- 7 | this module intend to train faiss index object. 8 | input vector path or vector files, and index output 9 | path. Then return the trained index object to dest- 10 | ination path. 11 | ---------------------------------------------------- 12 | """ 13 | 14 | import os 15 | import sys 16 | import time 17 | import logging 18 | import warnings 19 | import numpy as np 20 | 21 | from mylogging.namedlogger import NamedLogger 22 | 23 | #from faiss_utils import faiss 24 | 25 | class TrainIndex(object): 26 | def __init__(self, **kwargs): 27 | """ 28 | :param files: from which files to get vectors. 29 | :param vpath: from where to get all vector files, note this files 30 | name should startswith `vector_`. 31 | :param dpath: where trained index object to save. 32 | :param iname: trained index saved as iname. 33 | :param fpath: faiss built files path. from where to import faiss 34 | modules. 35 | :param direct:when instance TrainIndex object direct to train. 36 | """ 37 | self.files = [] 38 | self.vpath = '' 39 | self.dpath = '' 40 | self.iname = '' 41 | self.fpath = '' 42 | self.direct = False 43 | self.index = None 44 | if not any([u'files' in kwargs, u'vpath' in kwargs]): 45 | Error = 'Must need anyone key in (`files`, `vpath`)' 46 | raise KeyError(Error) 47 | 48 | if u'files' in kwargs: 49 | self.files = kwargs.get('files') 50 | if u'vpath' in kwargs: 51 | self.vpath = kwargs.get('vpath') 52 | self.prefix= kwargs.get('prefix', '\1') 53 | self.suffix= kwargs.get('suffix', '\1') 54 | self.fix = {'prefix':self.prefix} or {'suffix':self.suffix} 55 | if self.vpath: 56 | if self.files: 57 | Warn = ('Input `files` and `vpath` at the same time, input ' 58 | 'files list will be replaced by vpath select files.') 59 | warnings.warn(Warn) 60 | self.files = self._get_files() 61 | if u'fpath' in kwargs: 62 | self.fpath = kwargs.get('fpath') 63 | 64 | #: index object save path 65 | if u'dpath' in kwargs: 66 | self.dpath = kwargs.get('dpath') 67 | if u'iname' not in kwargs: 68 | Warn = ('NO `iname` input, the index object will save as ' 69 | 'default name [index.index]') 70 | warnings.warn(Warn) 71 | self.iname = kwargs.get('iname', 'index.index') 72 | self.logger = kwargs.get('logger', '') or logging.getLogger(__file__) 73 | if u'direct' in kwargs: 74 | self.direct = kwargs.get('direct') 75 | if self.direct: 76 | self.index = self.train(self.files) 77 | 78 | 79 | def _get_files(self): 80 | """ select files from vpath """ 81 | 82 | if 'prefix' in self.fix: 83 | _files = [file for file in os.listdir(self.vpath) if \ 84 | file.startswith(self.fix['prefix'])] 85 | else: 86 | _files = [file for file in os.listdir(self.vpath) if \ 87 | file.endswith(self.fix['suffix'])] 88 | files = [os.path.join(self.vpath, _) for _ in _files] 89 | return files 90 | 91 | @staticmethod 92 | def _generate_npvectors(files): 93 | """ 94 | intend from vector files to generate numpy matrixs. 95 | NOTE:this files not only filenames, is a list about 96 | file absoult path. 97 | 98 | :param files: type `list`: vector files absoulte path. 99 | @rtype: numpy matrix. 100 | """ 101 | count = 0 102 | allvectors = list() 103 | for file in files: 104 | with open(file, 'r') as fopen: 105 | vectorlines = fopen.readlines() 106 | count += len(vectorlines) 107 | allvectors.extend(vectorlines) 108 | npvectors = np.zeros(shape=[count, 102]).astype('float32') 109 | for i, line in enumerate(allvectors): 110 | l = np.array([float(_) for _ in line.strip().split(' ')]).astype('float32') 111 | npvectors[i] = l 112 | 113 | return npvectors 114 | 115 | def train(self, files): 116 | """ train index """ 117 | 118 | npvectors = self._generate_npvectors(files) 119 | 120 | #: npvectors[0:2] save docvector ids' number, 121 | #: and [2:] save docvetor true infos. 122 | vectors = npvectors[:, 2:].astype('float32') 123 | vectorids = npvectors[:, :2] 124 | vector_ids= vectorids[:, 0].astype('int') * 10000 + vectorids[:, 1].astype('int') 125 | 126 | try: 127 | from faiss_utils import faiss 128 | except ImportError: 129 | if not self.fpath: 130 | Error = ('Can not import faiss, need to put faiss build files to ' 131 | './faiss_utils or input ref fpath.') 132 | raise ImportError(Error) 133 | sys.path.insert(0, self.fpath) 134 | import faiss 135 | 136 | index = faiss.index_factory(100, "PCA80, IVF1024, Flat") 137 | start = time.time() 138 | self.logger.info('start to train index') 139 | index.train(vectors) 140 | index.add_with_ids(vectors, vector_ids) 141 | 142 | if self.dpath: 143 | self.logger.info('write index to %r as name: %r' % 144 | (self.dpath, self.iname)) 145 | dfile = os.path.join(self.dpath, self.iname) 146 | faiss.write_index(index, dfile) 147 | 148 | self.logger.info('train index cost time:{0}'.format(time.time()-start)) 149 | return index 150 | 151 | if __name__ == '__main__': 152 | logger = NamedLogger('index_train') 153 | trainer = TrainIndex(vpath='/home/yulianghua/gitlab/wechat_marking_system/title_vectors', 154 | prefix='vector_', 155 | fpath = '/home/yulianghua/github/faiss', 156 | direct=True, 157 | #files = ['/home/yulianghua/gitlab/news_extract/wechat/vectors/vector_20170819'], 158 | #dpath='/home/yulianghua', 159 | #iname='test.index', 160 | logger=logger) 161 | index = trainer.index 162 | 163 | #files = trainer.files[:2] 164 | #print files 165 | #index = trainer.train(files) 166 | print 'index ntotal', index.ntotal 167 | 168 | 169 | --------------------------------------------------------------------------------