├── .gitattributes ├── CMakeLists.txt ├── README.md ├── cmake ├── CompilerSettings.cmake ├── ProjectBoost.cmake ├── ProjectFaiss.cmake └── ProjectPistache.cmake ├── docker ├── Dockerfile └── README.md ├── docs ├── .Doxyfile.in ├── CMakeLists.txt ├── Code-Style.md ├── UserManual.md └── config.md ├── scripts ├── README.md ├── git-lfs.deb.sh ├── install.sh ├── install_cuda.sh ├── install_deps.sh ├── monitor.sh └── pack.sh ├── src ├── CMakeLists.txt ├── common │ ├── CMakeLists.txt │ ├── configParams.cpp │ ├── configParams.h │ ├── easylog++.h │ ├── easylogging++.cpp │ ├── easylogging++.h │ ├── error.h │ ├── json.h │ ├── memusage.h │ └── version.h ├── libRestServer │ ├── CMakeLists.txt │ ├── RequestHandler.cpp │ ├── RequestHandler.h │ ├── RestServer.cpp │ ├── RestServer.h │ ├── SearchProcessor.cpp │ └── SearchProcessor.h ├── libSearch │ ├── CMakeLists.txt │ ├── FaissInterface.cpp │ └── FaissInterface.h └── main.cpp └── test ├── CMakeLists.txt ├── FaissCPUSearch.cpp ├── FaissGPUSearch.cpp ├── FaissLoadTest.cpp ├── HNSWSearch.cpp ├── python-test ├── add.py ├── query.py ├── queryRange.py ├── querydays.py ├── reconfig.py ├── remove.py └── removeRange.py ├── sift1M.cpp ├── testRemove.cpp └── testSearchRange.cpp /.gitattributes: -------------------------------------------------------------------------------- 1 | deps/libopenblas.tar.gz filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.0.2) 2 | project (Searcher) 3 | 4 | # The project version number. 
5 | set(VERSION_MAJOR 0 CACHE STRING "Project major version number.") 6 | set(VERSION_MINOR 1 CACHE STRING "Project minor version number.") 7 | set(VERSION_PATCH 0 CACHE STRING "Project patch version number.") 8 | mark_as_advanced(VERSION_MAJOR VERSION_MINOR VERSION_PATCH) 9 | message(STATUS "Build ${CMAKE_PROJECT_NAME} ${VERSION_MAJOR}.${VERSION_MINOR}.${VERSION_PATCH}") 10 | 11 | add_definitions(-DVERSION_MAJOR=${VERSION_MAJOR}) 12 | add_definitions(-DVERSION_MINOR=${VERSION_MINOR}) 13 | add_definitions(-DVERSION_PATCH=${VERSION_PATCH}) 14 | 15 | # Cmake scripts directory 16 | set(VERIFIER_CMAKE_DIR "${CMAKE_CURRENT_LIST_DIR}/cmake" CACHE PATH "The path to the cmake directory") 17 | list(APPEND CMAKE_MODULE_PATH ${VERIFIER_CMAKE_DIR}) 18 | 19 | # Options 20 | OPTION(BUILD_TEST "Build Tests." ON) 21 | OPTION(BUILD_DOCS "Build documentation" ON) 22 | enable_testing() 23 | 24 | # Find cuda 25 | find_package(CUDA QUIET) 26 | include_directories("${CUDA_INCLUDE_DIRS}") 27 | 28 | # Include cmake scripts 29 | include(CompilerSettings) 30 | include(ProjectPistache) 31 | include(ProjectFaiss) 32 | 33 | # Set binary path 34 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin) 35 | add_subdirectory (src) 36 | add_subdirectory (test) 37 | add_subdirectory (docs) 38 | 39 | # Install 40 | install(PROGRAMS ${CMAKE_BINARY_DIR}/bin/searcher DESTINATION bin) 41 | # install(DIRECTORY model DESTINATION /etc/searcher) 42 | # install(FILES config.json DESTINATION /etc/searcher) 43 | 44 | # cpack config, make package 45 | set(CPACK_PACKAGE_VERSION_MAJOR ${VERSION_MAJOR}) 46 | set(CPACK_PACKAGE_VERSION_MINOR ${VERSION_MINOR}) 47 | set(CPACK_PACKAGE_VERSION_PATCH ${VERSION_PATCH}) 48 | set(CPACK_GENERATOR "DEB") 49 | set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) 50 | set(CPACK_DEBIAN_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) 51 | set(CPACK_DEBIAN_PACKAGE_ARCHITECTURE "amd64") 52 | set(CPACK_PACKAGE_CONTACT "support@xxx.cn") 53 | set(CPACK_DEBIAN_PACKAGE_DEPENDS "libc6, libstdc++6, 
libopencv-dev(>=2.4.0), libopencv-dev(<=2.4.13)") 54 | include(CPack) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于faiss的检索服务 2 | 3 | 4 | 5 | - [基于faiss的检索服务](#基于faiss的检索服务) 6 | - [1 概述](#1-概述) 7 | - [2 使用](#2-使用) 8 | - [2.1 安装依赖](#21-安装依赖) 9 | - [2.2 编译运行](#22-编译运行) 10 | - [2.3 版本发布](#23-版本发布) 11 | - [3 文档](#3-文档) 12 | - [3.1 项目文档](#31-项目文档) 13 | - [4 依赖项目](#4-依赖项目) 14 | 15 | 16 | 17 | ## 1 概述 18 | 19 | **一个基于 faiss 的每日构建索引的检索服务** 20 | 21 | 主要功能包括: 22 | 23 | - 添加向量至索引结构 24 | - 检索向量 25 | - topK Search 26 | - Approximate Nearest Neighbor Search 27 | - 按日期检索 28 | - 删除指定id向量或范围内向量 29 | - 重新配置 30 | 31 | |目录|说明| 32 | |:--:|:--:| 33 | |cmake|cmake脚本文件| 34 | |docs|项目相关文档| 35 | |scripts|项目相关的脚本| 36 | |src|项目源码(按模块组织)| 37 | |test|项目测试使用的脚本| 38 | 39 | ## 2 使用 40 | 41 | ### 2.1 安装依赖 42 | 43 | ```bash 44 | # sudo apt-get install libopenblas-dev 45 | # clone project 46 | $ git clone https://github.com/FlYWMe/SearchServer.git 47 | $ cd SearchServer 48 | ``` 49 | 50 | ### 2.2 编译运行 51 | 52 | ```bash 53 | # require cmake3 54 | $ mkdir build 55 | $ cd build 56 | $ cmake .. 
57 | $ make -j"$(nproc)" 58 | 59 | # run 60 | $ ./bin/queryServer 61 | 62 | # unit test 63 | $ ./bin/FaissCPUSearch # CPU flat暴力搜索采用最大堆实现 64 | $ ./bin/FaissGPUSearch # GPU flat warpSlect 65 | $ ./bin/HNSWSearch 66 | $ ./bin/sift1M # datasets:http://corpus-texmex.irisa.fr/ sift1M GPU flat recall @1:99.19% @10:1 @100:1 67 | 68 | # function test 69 | $ ./bin/FaissLoadTest 70 | $ ./bin/testRemove 71 | $ ./bin/testSearchRange 72 | ``` 73 | 74 | **PS:** 编译要求gcc版本4.9以上,当配置文件不存在时(首次运行),会使用默认的配置参数,并在当前目录下自动生成配置文件`config.json` 75 | 76 | ### 2.3 版本发布 77 | 78 | - 修改版本号:修改最外层`CMakeLists.txt`文件中的`VERSION_MAJOR.VERSION_MINOR.VERSION_PATCH`值为相应版本号 79 | - 合并分支:发起PR将当前分支合并到develop分支,测试通过后合并develop到master分支。 80 | - 添加标签:以相应的版本号新建标签。 81 | 82 | >**PS**:[版本号命名规则](https://semver.org/) 83 | 84 | ## 3 文档 85 | 86 | ### 3.1 项目文档 87 | 88 | ```bash 89 | # 生成文档位于build/documents 90 | $ make doc 91 | ``` 92 | 93 | 1. [代码规范](docs/Code-Style.md) 94 | 1. [API说明](docs/API.md) 95 | 1. [配置说明](docs/config.md) 96 | 97 | ## 4 依赖项目 98 | 99 | - [`oktal/pistache`](https://github.com/oktal/pistache) 100 | - [`boost`](http://www.boost.org/) 101 | - [`nlohmann/json`](https://github.com/nlohmann/json) 102 | - [`muflihun/easyloggingpp`](https://github.com/muflihun/easyloggingpp) 103 | -------------------------------------------------------------------------------- /cmake/CompilerSettings.cmake: -------------------------------------------------------------------------------- 1 | # Clang seeks to be command-line compatible with GCC as much as possible, so 2 | # most of our compiler settings are common between GCC and Clang. 
3 | # 4 | # These settings then end up spanning all POSIX platforms (Linux, OS X, BSD, etc) 5 | 6 | # Use ccache if available 7 | find_program(CCACHE_FOUND ccache) 8 | if(CCACHE_FOUND) 9 | set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache) 10 | set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache) 11 | message("Using ccache") 12 | endif(CCACHE_FOUND) 13 | 14 | # Find cuda 15 | # find_package(CUDA QUIET) 16 | # if(NOT CUDA_FOUND) 17 | # set(USE_CUDA OFF) 18 | # message(STATUS "Build ${CMAKE_PROJECT_NAME} without CUDA support.") 19 | # elseif(CUDA_FOUND) 20 | # message(STATUS "Build ${CMAKE_PROJECT_NAME} with CUDA : " ${CUDA_VERSION}) 21 | # endif() 22 | 23 | if (("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")) 24 | 25 | # set default build type Release 26 | if(NOT CMAKE_BUILD_TYPE) 27 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING 28 | "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." FORCE) 29 | endif(NOT CMAKE_BUILD_TYPE) 30 | 31 | # Use ISO C++11 standard language. 32 | include(CheckCXXCompilerFlag) 33 | CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORTS_CXX11) 34 | CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORTS_CXX0X) 35 | if(COMPILER_SUPPORTS_CXX11) 36 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -fPIC -fopenmp") 37 | elseif(COMPILER_SUPPORTS_CXX0X) 38 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x -fPIC -fopenmp") 39 | else() 40 | message(FATAL_ERROR "The compiler ${CMAKE_CXX_COMPILER} has no C++11 support. Please use a different C++ compiler.") 41 | endif() 42 | add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) 43 | 44 | # Enables all the warnings about constructions that some users consider questionable, 45 | # and that are easy to avoid. Also enable some extra warning flags that are not 46 | # enabled by -Wall. 
Finally, treat at warnings-as-errors, which forces developers 47 | # to fix warnings as they arise, so they don't accumulate "to be fixed later". 48 | add_compile_options(-Wall) 49 | add_compile_options(-Wno-unused-variable) 50 | add_compile_options(-Wunused-parameter) 51 | add_compile_options(-Wno-unused-function) 52 | add_compile_options(-Wextra) 53 | #add_compile_options(-Werror) 54 | 55 | # Disable warnings about unknown pragmas (which is enabled by -Wall). 56 | add_compile_options(-Wno-unknown-pragmas) 57 | 58 | add_compile_options(-fno-omit-frame-pointer) 59 | 60 | # Configuration-specific compiler settings. 61 | set(CMAKE_CXX_FLAGS_DEBUG "-Og -g") 62 | set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG") 63 | set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") 64 | set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g") 65 | 66 | option(USE_LD_GOLD "Use GNU gold linker" ON) 67 | if (USE_LD_GOLD) 68 | execute_process(COMMAND ${CMAKE_C_COMPILER} -fuse-ld=gold -Wl,--version ERROR_QUIET OUTPUT_VARIABLE LD_VERSION) 69 | if ("${LD_VERSION}" MATCHES "GNU gold") 70 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=gold") 71 | set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fuse-ld=gold") 72 | endif () 73 | endif () 74 | 75 | # Additional GCC-specific compiler settings. 76 | if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") 77 | 78 | # Check that we've got GCC 4.8 or newer. 79 | execute_process( 80 | COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION) 81 | if (NOT (GCC_VERSION VERSION_GREATER 4.8 OR GCC_VERSION VERSION_EQUAL 4.8)) 82 | message(FATAL_ERROR "${PROJECT_NAME} requires g++ 4.8 or greater.") 83 | endif () 84 | 85 | # Strong stack protection was only added in GCC 4.9. 86 | # Use it if we have the option to do so. 
87 | # See https://lwn.net/Articles/584225/ 88 | if (GCC_VERSION VERSION_GREATER 4.9 OR GCC_VERSION VERSION_EQUAL 4.9) 89 | add_compile_options(-fstack-protector-strong) 90 | add_compile_options(-fstack-protector) 91 | endif() 92 | 93 | # Additional Clang-specific compiler settings. 94 | elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") 95 | 96 | add_compile_options(-fstack-protector) 97 | 98 | # Enable strong stack protection only on Mac and only for OS X Yosemite 99 | # or newer (AppleClang 7.0+). We should be able to re-enable this setting 100 | # on non-Apple Clang as well, if we can work out what expression to use for 101 | # the version detection. 102 | 103 | # The fact that the version-reporting for AppleClang loses the original 104 | # Clang versioning is rather annoying. Ideally we could just have 105 | # a single cross-platform "if version >= 3.4.1" check. 106 | # 107 | # There is debug text in the else clause below, to help us work out what 108 | # such an expression should be, if we can get this running on a Trusty box 109 | # with Clang. Greg Colvin previously replicated the issue there too. 110 | # 111 | # See https://github.com/ethereum/webthree-umbrella/issues/594 112 | 113 | if (APPLE) 114 | if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 7.0) 115 | add_compile_options(-fstack-protector-strong) 116 | endif() 117 | else() 118 | message(WARNING "CMAKE_CXX_COMPILER_VERSION = ${CMAKE_CXX_COMPILER_VERSION}") 119 | endif() 120 | 121 | # Some Linux-specific Clang settings. We don't want these for OS X. 122 | if ("${CMAKE_SYSTEM_NAME}" MATCHES "Linux") 123 | 124 | # Tell Boost that we're using Clang's libc++. Not sure exactly why we need to do. 
125 | add_definitions(-DBOOST_ASIO_HAS_CLANG_LIBCXX) 126 | 127 | # Use fancy colors in the compiler diagnostics 128 | add_compile_options(-fcolor-diagnostics) 129 | endif() 130 | endif() 131 | 132 | # The major alternative compiler to GCC/Clang is Microsoft's Visual C++ compiler, only available on Windows. 133 | elseif (MSVC) 134 | 135 | add_compile_options(/MP) # enable parallel compilation 136 | add_compile_options(/EHsc) # specify Exception Handling Model in msvc 137 | add_compile_options(/WX) # enable warnings-as-errors 138 | add_compile_options(/wd4068) # disable unknown pragma warning (4068) 139 | add_compile_options(/wd4996) # disable unsafe function warning (4996) 140 | add_compile_options(/wd4503) # disable decorated name length exceeded, name was truncated (4503) 141 | add_compile_options(/wd4267) # disable conversion from 'size_t' to 'type', possible loss of data (4267) 142 | add_compile_options(/wd4180) # disable qualifier applied to function type has no meaning; ignored (4180) 143 | add_compile_options(/wd4290) # disable C++ exception specification ignored except to indicate a function is not __declspec(nothrow) (4290) 144 | add_compile_options(/wd4297) # disable 's function assumed not to throw an exception but does (4297) 145 | add_compile_options(/wd4244) # disable conversion from 'type1' to 'type2', possible loss of data (4244) 146 | add_compile_options(/wd4800) # disable forcing value to bool 'true' or 'false' (performance warning) (4800) 147 | add_compile_options(-D_WIN32_WINNT=0x0600) # declare Windows Vista API requirement 148 | add_compile_options(-DNOMINMAX) # undefine windows.h MAX && MIN macros cause it cause conflicts with std::min && std::max functions 149 | add_compile_options(-DMINIUPNP_STATICLIB) # define miniupnp static library 150 | 151 | # Always use Release variant of C++ runtime. 152 | # We don't want to provide Debug variants of all dependencies. Some default 153 | # flags set by CMake must be tweaked. 
154 | string(REPLACE "/MDd" "/MD" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") 155 | string(REPLACE "/D_DEBUG" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") 156 | string(REPLACE "/RTC1" "" CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG}") 157 | string(REPLACE "/MDd" "/MD" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") 158 | string(REPLACE "/D_DEBUG" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") 159 | string(REPLACE "/RTC1" "" CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}") 160 | set_property(GLOBAL PROPERTY DEBUG_CONFIGURATIONS OFF) 161 | 162 | # disable empty object file warning 163 | set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /ignore:4221") 164 | # warning LNK4075: ignoring '/EDITANDCONTINUE' due to '/SAFESEH' specification 165 | # warning LNK4099: pdb was not found with lib 166 | # stack size 16MB 167 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /ignore:4099,4075 /STACK:16777216") 168 | 169 | # If you don't have GCC, Clang or VC++ then you are on your own. Good luck! 
170 | else () 171 | message(WARNING "Your compiler is not tested, if you run into any issues, we'd welcome any patches.") 172 | endif () 173 | 174 | if (SANITIZE) 175 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer -fsanitize=${SANITIZE}") 176 | if (${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") 177 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize-blacklist=${CMAKE_SOURCE_DIR}/sanitizer-blacklist.txt") 178 | endif() 179 | endif() 180 | 181 | if (PROFILING AND (("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang"))) 182 | set(CMAKE_CXX_FLAGS "-g ${CMAKE_CXX_FLAGS}") 183 | set(CMAKE_C_FLAGS "-g ${CMAKE_C_FLAGS}") 184 | # add_definitions(-DETH_PROFILING_GPERF) 185 | set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -lprofiler") 186 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -lprofiler") 187 | endif () 188 | 189 | if (COVERAGE) 190 | set(CMAKE_CXX_FLAGS "-g --coverage ${CMAKE_CXX_FLAGS}") 191 | set(CMAKE_C_FLAGS "-g --coverage ${CMAKE_C_FLAGS}") 192 | set(CMAKE_SHARED_LINKER_FLAGS "--coverage ${CMAKE_SHARED_LINKER_FLAGS}") 193 | set(CMAKE_EXE_LINKER_FLAGS "--coverage ${CMAKE_EXE_LINKER_FLAGS}") 194 | find_program(LCOV_TOOL lcov) 195 | message(STATUS "lcov tool: ${LCOV_TOOL}") 196 | if (LCOV_TOOL) 197 | add_custom_target(coverage.info 198 | COMMAND ${LCOV_TOOL} -o ${CMAKE_BINARY_DIR}/coverage.info -c -d ${CMAKE_BINARY_DIR} 199 | COMMAND ${LCOV_TOOL} -o ${CMAKE_BINARY_DIR}/coverage.info -r ${CMAKE_BINARY_DIR}/coverage.info '/usr*' '${CMAKE_BINARY_DIR}/deps/*' '${CMAKE_SOURCE_DIR}/deps/*' 200 | ) 201 | endif() 202 | endif () 203 | -------------------------------------------------------------------------------- /cmake/ProjectBoost.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | include(GNUInstallDirs) 3 | 4 | set(BOOST_CXXFLAGS "cxxflags=-fPIC") 5 | if (WIN32) 6 | set(BOOST_BOOTSTRAP_COMMAND bootstrap.bat) 7 | 
set(BOOST_BUILD_TOOL b2.exe) 8 | set(BOOST_LIBRARY_SUFFIX -vc140-mt-1_63.lib) 9 | else() 10 | set(BOOST_BOOTSTRAP_COMMAND ./bootstrap.sh) 11 | set(BOOST_BUILD_TOOL ./b2) 12 | set(BOOST_LIBRARY_SUFFIX .a) 13 | # if (${BUILD_SHARED_LIBS}) 14 | # set(BOOST_CXXFLAGS "cxxflags=-fPIC") 15 | # endif() 16 | endif() 17 | 18 | ExternalProject_Add(boost 19 | PREFIX ${CMAKE_BINARY_DIR}/deps 20 | DOWNLOAD_DIR ${CMAKE_SOURCE_DIR}/deps 21 | DOWNLOAD_NO_PROGRESS 1 22 | URL https://dl.bintray.com/boostorg/release/1.65.1/source/boost_1_65_1.tar.gz 23 | URL_HASH SHA256=a13de2c8fbad635e6ba9c8f8714a0e6b4264b60a29b964b940a22554705b6b60 24 | BUILD_IN_SOURCE 1 25 | CONFIGURE_COMMAND ${BOOST_BOOTSTRAP_COMMAND} 26 | BUILD_COMMAND ${BOOST_BUILD_TOOL} install 27 | ${BOOST_CXXFLAGS} 28 | threading=multi 29 | link=static 30 | variant=release 31 | address-model=64 32 | --prefix= 33 | --with-chrono 34 | --with-date_time 35 | --with-system 36 | --with-filesystem 37 | --with-random 38 | --with-regex 39 | --with-test 40 | --with-thread 41 | --with-serialization 42 | LOG_DOWNLOAD 1 43 | LOG_CONFIGURE 1 44 | LOG_BUILD 1 45 | LOG_INSTALL 1 46 | INSTALL_COMMAND "" 47 | ) 48 | 49 | ExternalProject_Get_Property(boost SOURCE_DIR INSTALL_DIR) 50 | set(BOOST_INCLUDE_DIR ${INSTALL_DIR}/include) 51 | set(BOOST_LIB_DIR ${INSTALL_DIR}/lib) 52 | unset(BUILD_DIR) 53 | 54 | add_library(Boost::Chrono STATIC IMPORTED) 55 | set_property(TARGET Boost::Chrono PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_chrono${BOOST_LIBRARY_SUFFIX}) 56 | add_dependencies(Boost::Chrono boost) 57 | 58 | add_library(Boost::DataTime STATIC IMPORTED) 59 | set_property(TARGET Boost::DataTime PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_date_time${BOOST_LIBRARY_SUFFIX}) 60 | add_dependencies(Boost::DataTime boost) 61 | 62 | # add_library(Boost::Regex STATIC IMPORTED) 63 | # set_property(TARGET Boost::Regex PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_regex${BOOST_LIBRARY_SUFFIX}) 64 | # add_dependencies(Boost::Regex 
boost) 65 | 66 | add_library(Boost::System STATIC IMPORTED) 67 | set_property(TARGET Boost::System PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_system${BOOST_LIBRARY_SUFFIX}) 68 | add_dependencies(Boost::System boost) 69 | 70 | add_library(Boost::Filesystem STATIC IMPORTED) 71 | set_property(TARGET Boost::Filesystem PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_filesystem${BOOST_LIBRARY_SUFFIX}) 72 | set_property(TARGET Boost::Filesystem PROPERTY INTERFACE_LINK_LIBRARIES Boost::System) 73 | add_dependencies(Boost::Filesystem boost) 74 | 75 | # add_library(Boost::Random STATIC IMPORTED) 76 | # set_property(TARGET Boost::Random PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_random${BOOST_LIBRARY_SUFFIX}) 77 | # add_dependencies(Boost::Random boost) 78 | 79 | # add_library(Boost::UnitTestFramework STATIC IMPORTED) 80 | # set_property(TARGET Boost::UnitTestFramework PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_unit_test_framework${BOOST_LIBRARY_SUFFIX}) 81 | # add_dependencies(Boost::UnitTestFramework boost) 82 | 83 | add_library(Boost::Thread STATIC IMPORTED) 84 | set_property(TARGET Boost::Thread PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_thread${BOOST_LIBRARY_SUFFIX}) 85 | set_property(TARGET Boost::Thread PROPERTY INTERFACE_LINK_LIBRARIES Boost::Chrono Boost::DataTime) 86 | # set_property(TARGET Boost::Thread PROPERTY INTERFACE_LINK_LIBRARIES Boost::Chrono Boost::DataTime Boost::Regex) 87 | add_dependencies(Boost::Thread boost) 88 | 89 | add_library(Boost::Serialization STATIC IMPORTED) 90 | set_property(TARGET Boost::Serialization PROPERTY IMPORTED_LOCATION ${BOOST_LIB_DIR}/libboost_serialization${BOOST_LIBRARY_SUFFIX}) 91 | # set_property(TARGET Boost::Serialization PROPERTY INTERFACE_LINK_LIBRARIES) 92 | add_dependencies(Boost::Serialization boost) 93 | 94 | unset(INSTALL_DIR) 95 | unset(SOURCE_DIR) 96 | -------------------------------------------------------------------------------- /cmake/ProjectFaiss.cmake: 
-------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | if (${CMAKE_SYSTEM_NAME} STREQUAL "Emscripten") 4 | set(FAISS_CMAKE_COMMAND emcmake cmake) 5 | else() 6 | set(FAISS_CMAKE_COMMAND ${CMAKE_COMMAND}) 7 | endif() 8 | 9 | ExternalProject_Add(faiss 10 | PREFIX ${CMAKE_BINARY_DIR}/deps 11 | DOWNLOAD_DIR ${CMAKE_SOURCE_DIR}/deps 12 | DOWNLOAD_NO_PROGRESS 1 13 | DOWNLOAD_NAME faiss-1.0.tar.gz 14 | URL https://github.com/facebookresearch/faiss/archive/v1.2.1.tar.gz 15 | URL_HASH SHA256=0a8d629f86ee728c9c9cd72527027c09fc390963dd3cbdd9675eb577873a2695 16 | BUILD_IN_SOURCE 1 17 | # GIT_REPOSITORY https://github.com/facebookresearch/faiss.git 18 | # GIT_TAG master 19 | # UPDATE_COMMAND "" 20 | CMAKE_COMMAND ${FAISS_CMAKE_COMMAND} 21 | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX= 22 | -DBUILD_TUTORIAL=OFF 23 | -DBUILD_TEST=OFF 24 | -DBUILD_WITH_GPU=${CUDA_FOUND} 25 | -DCMAKE_BUILD_TYPE=Release 26 | LOG_DOWNLOAD 1 27 | LOG_CONFIGURE 1 28 | LOG_BUILD 1 29 | LOG_INSTALL 1 30 | # BUILD_COMMAND "" 31 | INSTALL_COMMAND cmake -E copy_directory lib /lib 32 | ) 33 | 34 | ExternalProject_Get_Property(faiss INSTALL_DIR SOURCE_DIR) 35 | 36 | 37 | add_library(Faiss::CPU STATIC IMPORTED) 38 | set(FAISS_LIBRARY ${INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}faiss${CMAKE_STATIC_LIBRARY_SUFFIX}) 39 | set(FAISS_INCLUDE_DIR ${INSTALL_DIR}/include) 40 | set_property(TARGET Faiss::CPU PROPERTY IMPORTED_LOCATION ${FAISS_LIBRARY}) 41 | set_property(TARGET Faiss::CPU PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${FAISS_INCLUDE_DIR}) 42 | add_dependencies(Faiss::CPU faiss) 43 | 44 | add_library(Faiss::GPU STATIC IMPORTED) 45 | set(FAISS_GPU_LIBRARY ${INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gpufaiss${CMAKE_STATIC_LIBRARY_SUFFIX}) 46 | set_property(TARGET Faiss::GPU PROPERTY IMPORTED_LOCATION ${FAISS_GPU_LIBRARY}) 47 | add_dependencies(Faiss::GPU faiss) 48 | 49 | file(MAKE_DIRECTORY ${FAISS_INCLUDE_DIR}/faiss ${FAISS_INCLUDE_DIR}/faiss/gpu) # 
Must exist. 50 | ADD_CUSTOM_TARGET(install_header ALL COMMAND ${CMAKE_COMMAND} -E copy ${SOURCE_DIR}/*.h ${INSTALL_DIR}/include/faiss 51 | COMMAND ${CMAKE_COMMAND} -E copy_directory ${SOURCE_DIR}/gpu ${INSTALL_DIR}/include/faiss/gpu) 52 | add_dependencies(install_header faiss) 53 | unset(INSTALL_DIR) 54 | unset(SOURCE_DIR) 55 | 56 | -------------------------------------------------------------------------------- /cmake/ProjectPistache.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | if (${CMAKE_SYSTEM_NAME} STREQUAL "Emscripten") 4 | set(PISTACHE_CMAKE_COMMAND emcmake cmake) 5 | else() 6 | set(PISTACHE_CMAKE_COMMAND ${CMAKE_COMMAND}) 7 | endif() 8 | 9 | ExternalProject_Add(pistache 10 | PREFIX ${CMAKE_BINARY_DIR}/deps 11 | DOWNLOAD_DIR ${CMAKE_SOURCE_DIR}/deps 12 | # SOURCE_DIR ${CMAKE_SOURCE_DIR}/deps/pistache 13 | GIT_REPOSITORY https://github.com/bxq2011hust/pistache.git 14 | GIT_TAG master 15 | UPDATE_COMMAND "" 16 | CMAKE_COMMAND ${PISTACHE_CMAKE_COMMAND} 17 | CMAKE_ARGS -DCMAKE_INSTALL_PREFIX= 18 | # Build static lib but suitable to be included in a shared lib. 19 | -DBUILD_SHARED_LIBS=OFF 20 | -DCMAKE_POSITION_INDEPENDENT_CODE=${BUILD_SHARED_LIBS} 21 | -DCMAKE_BUILD_TYPE=Release 22 | BUILD_COMMAND "" 23 | # INSTALL_COMMAND "" 24 | LOG_CONFIGURE 1 25 | LOG_BUILD 1 26 | LOG_INSTALL 1 27 | LOG_DOWNLOAD 1 28 | ) 29 | 30 | # Create Pistache imported library 31 | ExternalProject_Get_Property(pistache INSTALL_DIR) 32 | add_library(Pistache STATIC IMPORTED) 33 | set(PISTACHE_LIBRARY ${INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}pistache${CMAKE_STATIC_LIBRARY_SUFFIX}) 34 | set(PISTACHE_INCLUDE_DIR ${INSTALL_DIR}/include) 35 | file(MAKE_DIRECTORY ${PISTACHE_INCLUDE_DIR}) # Must exist. 
36 | set_property(TARGET Pistache PROPERTY IMPORTED_LOCATION ${PISTACHE_LIBRARY}) 37 | set_property(TARGET Pistache PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${PISTACHE_INCLUDE_DIR}) 38 | add_dependencies(Pistache pistache) 39 | unset(INSTALL_DIR) 40 | 41 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y cmake git software-properties-common && \ 5 | mkdir -p /opt/searcher && \ 6 | cd /opt/searcher && mkdir source && cd source && \ 7 | git clone http://USER:PASSWORD@gitlab.oceanai.com.cn/verifier/searcher.git && \ 8 | cd searcher && \ 9 | apt-get install -y libopenblas-dev libboost-dev && \ 10 | mkdir build && \ 11 | cd build && \ 12 | cmake .. && \ 13 | make -j"$(nproc)" && \ 14 | cd /opt/searcher && \ 15 | cp /opt/searcher/source/searcher/build/bin/queryServer . && \ 16 | apt-get -y remove cmake git software-properties-common && \ 17 | apt-get autoremove -y && \ 18 | apt-get clean && \ 19 | rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ 20 | cd /opt/searcher && rm -rf ./source 21 | 22 | EXPOSE 2333 23 | 24 | CMD cd /opt/searcher && ./queryServer 25 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # usage 2 | 3 | ## Before you start 4 | 5 | 将docker文件夹中的Dockerfile文件中的git clone后面的地址中的USERNAME和PASSWORD改为你的远程仓库用户名和密码。 6 | 7 | ## Prerequisites 8 | 9 | - [docker-17](https://docs.docker.com/install/linux/docker-ce/ubuntu/#install-docker-ce-1) 10 | - [nvidia-docker](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)) 11 | 12 | ## GPU 13 | 14 | ```bash 15 | #current path is SearchServer/docker 16 | # build image 17 | $ docker build -t searcher:gpu -f Dockerfile . 
18 | # start container 19 | $ nvidia-docker run -p 2333:2333 searcher:gpu 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Building docs script 2 | # Requirements: 3 | # sudo apt-get install doxygen graphviz 4 | # [Build doxygen from CMake script](https://stackoverflow.com/questions/34878276/build-doxygen-from-cmake-script) 5 | 6 | if(NOT BUILD_DOCS) 7 | return() 8 | endif() 9 | 10 | find_package(Doxygen QUIET) 11 | 12 | if(DOXYGEN_FOUND) 13 | # additional config 14 | set(doxyfile_in ${CMAKE_CURRENT_SOURCE_DIR}/.Doxyfile.in) 15 | set(doxyfile ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile) 16 | configure_file(${doxyfile_in} ${doxyfile} @ONLY) 17 | 18 | # Adding docs target 19 | add_custom_target(doc COMMAND ${DOXYGEN_EXECUTABLE} ${doxyfile} 20 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 21 | COMMENT "Generating documentation with Doxygen..." VERBATIM) 22 | else() 23 | message(WARNING "Doxygen is needed to build the documentation. Please install doxygen and graphviz") 24 | endif() -------------------------------------------------------------------------------- /docs/Code-Style.md: -------------------------------------------------------------------------------- 1 | # 规范 2 | 3 | 4 | 5 | - [规范](#规范) 6 | - [1 C++代码规范(Follow Google)](#1-c代码规范follow-google) 7 | - [2 Git提交规范](#2-git提交规范) 8 | - [3 版本号管理](#3-版本号管理) 9 | - [4 API](#4-api) 10 | 11 | 12 | 13 | ## 1 C++代码规范(Follow Google) 14 | 15 | 1. 源文件名/变量命名:小驼峰命名法,一般不使用缩写 16 | 2. 类名使用大驼峰命名法 17 | 3. 命名空间一律小写 18 | 4. 代码注释与文档使用[Doxygen][Doxygen] 19 | 20 | 参考[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html) 21 | 参考[Google C++ Style Guide(cn)](http://zh-google-styleguide.readthedocs.io/en/latest/) 22 | 参考[doxygen/manual cpp](http://www.stack.nl/~dimitri/doxygen/manual/docblocks.html#cppblock) 23 | 24 | ## 2 Git提交规范 25 | 26 | 1. 
每个人在自己的分支开发,开发完成后自测,自测后提交。提交信息命名为`add,fix,update`,添加功能时使用`add`,修复bug时使用`fix`,更新已有功能时或不符合前两种时使用`update` 27 | ```bash 28 | # 例如 29 | git commit -m "fix http包长度bug" 30 | git commit -m "add base64解码" 31 | git commit -m "update 更新README.md" 32 | 33 | # 推送到远程仓库 34 | $ git push 35 | ``` 36 | 37 | 2. 发起`PR`经`Review`之后合并入`develop`分支 38 | 39 | ## 3 版本号管理 40 | 41 | 1. 使用[语义版本][Semantic Versioning],版本号`MAJOR.MINOR.PATCH` 42 | 43 | |Version|Detial| 44 | |:-|:-| 45 | |MAJOR version| when you make incompatible API changes| 46 | |MINOR version| when you add functionality in a backwards-compatible manner| 47 | |PATCH version| when you make backwards-compatible bug fixes| 48 | **Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.** 49 | 50 | ## 4 API 51 | 52 | [Microsoft REST API Guidelines](https://github.com/Microsoft/api-guidelines/blob/vNext/Guidelines.md) 53 | 54 | [Semantic Versioning]:https://semver.org/ 55 | [Doxygen]:https://www.stack.nl/~dimitri/doxygen/manual/docblocks.html 56 | 57 | -------------------------------------------------------------------------------- /docs/UserManual.md: -------------------------------------------------------------------------------- 1 | # 用户手册 2 | 3 | 4 | 5 | - [用户手册](#用户手册) 6 | - [1 文件格式说明](#1-文件格式说明) 7 | - [2 API说明](#2-api说明) 8 | - [2.1 添加向量至索引](#21-添加向量至索引) 9 | - [2.1.1 请求](#211-请求) 10 | - [2.1.2 返回](#212-返回) 11 | - [2.1.3 成功请求返回值示例](#213-成功请求返回值示例) 12 | - [2.1.4 失败请求返回值示例](#214-失败请求返回值示例) 13 | - [2.1.5 当前 API 特有的 ERROR_MESSAGE](#215-当前-api-特有的-error_message) 14 | - [2.2 查询向量](#22-查询向量) 15 | - [2.2.1 请求](#221-请求) 16 | - [2.2.2 返回](#222-返回) 17 | - [2.2.3 成功请求返回值示例](#223-成功请求返回值示例) 18 | - [2.2.4 失败请求返回值示例](#224-失败请求返回值示例) 19 | - [2.2.5 当前 API 特有的 ERROR_MESSAGE](#225-当前-api-特有的-error_message) 20 | - [2.3 删除向量标识](#23-删除向量标识) 21 | - [2.3.1 请求](#231-请求) 22 | - [2.3.2 返回](#232-返回) 23 | - [2.3.3 成功请求返回值示例](#233-成功请求返回值示例) 24 | - [2.3.4 失败请求返回值示例](#234-失败请求返回值示例) 25 | - [2.3.5 当前 API 
特有的 ERROR_MESSAGE](#235-当前-api-特有的-error_message) 26 | - [2.4 删除范围内的向量标识](#24-删除范围内的向量标识) 27 | - [2.4.1 请求](#241-请求) 28 | - [2.4.2 返回](#242-返回) 29 | - [2.4.3 成功请求返回值示例](#243-成功请求返回值示例) 30 | - [2.4.4 失败请求返回值示例](#244-失败请求返回值示例) 31 | - [2.4.5 当前 API 特有的 ERROR_MESSAGE](#245-当前-api-特有的-error_message) 32 | - [2.5 重新配置和加载索引文件](#25-重新配置和加载索引文件) 33 | - [2.5.1 请求](#251-请求) 34 | - [2.5.2 返回](#252-返回) 35 | - [2.5.3 成功请求返回值示例](#253-成功请求返回值示例) 36 | - [2.5.4 失败请求返回值示例](#254-失败请求返回值示例) 37 | - [2.5.5 当前 API 特有的 ERROR_MESSAGE](#255-当前-api-特有的-error_message) 38 | - [2.6 查询范围内向量](#26-查询范围内向量) 39 | - [2.6.1 请求](#261-请求) 40 | - [2.6.2 返回](#262-返回) 41 | - [2.6.3 成功请求返回值示例](#263-成功请求返回值示例) 42 | - [2.6.4 失败请求返回值示例](#264-失败请求返回值示例) 43 | - [2.6.5 当前 API 特有的 ERROR_MESSAGE](#265-当前-api-特有的-error_message) 44 | - [2.7 按照日期查询向量](#27-按照日期查询向量) 45 | - [2.7.1 请求](#271-请求) 46 | - [2.7.2 返回](#272-返回) 47 | - [2.7.3 成功请求返回值示例](#273-成功请求返回值示例) 48 | - [2.7.4 失败请求返回值示例](#274-失败请求返回值示例) 49 | - [2.7.5 当前 API 特有的 ERROR_MESSAGE](#275-当前-api-特有的-error_message) 50 | - [2.8 通用的 ERROR_MESSAGE](#28-通用的-error_message) 51 | 52 | 53 | 54 | ## 1 文件格式说明 55 | 56 | 文件中前20个字节为文件头,前12个字节顺序存储向量维数、向量个数、文件格式版本号,存储类型为 `unsigned int`,文件头最后8个字节为空。 57 | 58 | 文件头后面顺序存储 `unsigned int`类型的向量id,以及对应的 `vector`类型的特征向量feature。 59 | 60 | | 名称 | 类型 | 参数说明 | 61 | | :----: | :--: | :------: | 62 | | header[0] | `unsigned int` | 向量维数 | 63 | | header[1] | `unsigned int` | 向量个数 | 64 | | header[2] | `unsigned int` | 文件格式版本号 | 65 | | id | `unsigned int` | 向量唯一标识 | 66 | | feature | `vector` | 特征向量 | 67 | 68 | ## 2 API说明 69 | 70 | ## 2.1 添加向量至索引 71 | 72 | ### 2.1.1 请求 73 | 74 | 调用地址:http://127.0.0.1/add 75 | 请求方式:POST 76 | 请求类型:application/json 77 | 78 | | 是否必选 | 参数名 | 类型 | 参数说明 | 79 | | :------: | :----: | :--: | :------: | 80 | | 必选 | ntotal | `Int` | 需要添加的向量个数 | 81 | | 必选 | data | `Object` | 向量唯一标识及对应特征向量 | 82 | 83 | **请求参数结构示例** 84 | 85 | ```json 86 | { 87 | "ntotal": 100, 88 | "data": 89 | { 90 | "0": [0.1,...,0.5], 91 | ..., 92 | "99": 
[0.1,...,0.5] 93 | } 94 | } 95 | ``` 96 | 97 | ### 2.1.2 返回 98 | 99 | 返回类型:JSON 100 | 101 | | 参数名 | 类型 | 参数说明 | 102 | | :----: | :--: | :------: | 103 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 104 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 105 | 106 | ### 2.1.3 成功请求返回值示例 107 | 108 | ```json 109 | { 110 | "time_used": 50, 111 | "error_message": "ADD_SUCCESS" 112 | } 113 | ``` 114 | 115 | ### 2.1.4 失败请求返回值示例 116 | 117 | ```json 118 | { 119 | "time_used": 0, 120 | "error_message": "INVALID_NTOTAL" 121 | } 122 | ``` 123 | 124 | ### 2.1.5 当前 API 特有的 ERROR_MESSAGE 125 | 126 | | HTTP状态代码 | 错误信息 | 说明 | 127 | | :----: | :--: | :------: | 128 | | 400 | INVALID_NTOTAL | 参数ntotal格式不正确 | 129 | | 400 | INVALID_DATA | 参数data格式不正确 | 130 | | 400 | EMPTY_DATA | 参数data中没有向量 | 131 | 132 | ## 2.2 查询向量 133 | 134 | ### 2.2.1 请求 135 | 136 | 调用地址:http://127.0.0.1/search 137 | 请求方式:POST 138 | 请求类型:application/json 139 | 140 | | 是否必选 | 参数名 | 类型 | 参数说明 | 141 | | :------: | :----: | :--: | :------: | 142 | | 必选 | qtotal | `Int` | 需要查询的向量个数 | 143 | | 必选 | topk | `Int` | 查询最近邻的个数 | 144 | | 必选 | queries | `Object` | 需要查询的向量标识及对应特征向量 | 145 | 146 | **请求参数结构示例** 147 | 148 | ```json 149 | { 150 | "qtotal": 10, 151 | "topk": 5, 152 | "queries": 153 | { 154 | "q0": [0.1,...,0.5], 155 | ..., 156 | "q9": [0.1,...,0.5] 157 | } 158 | } 159 | ``` 160 | 161 | ### 2.2.2 返回 162 | 163 | 返回类型:JSON 164 | 165 | | 参数名 | 类型 | 参数说明 | 166 | | :----: | :--: | :------: | 167 | | result | `Object` | 每个查询向量的k个最近邻距离及标识 | 168 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 169 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 170 | 171 | ### 2.2.3 成功请求返回值示例 172 | 173 | ```json 174 | { 175 | "result": 176 | { 177 | "q0": {"distance": [77.55, 78.39], "labels": [80, 1]}, 178 | "q1": {"distance": [67.94, 71.17], "labels": [70, 49]}, 179 | ..., 180 | "q9": {"distance": [73.63, 73.78], "labels": [47, 28]} 181 | } 182 | "time_used": 50, 183 | "error_message": "SEARCH_SUCCESS" 184 | } 185 | ``` 186 | 187 | ### 
2.2.4 失败请求返回值示例 188 | 189 | ```json 190 | { 191 | "time_used": 0, 192 | "error_message": "EMPTY_QUERIES" 193 | } 194 | ``` 195 | 196 | ### 2.2.5 当前 API 特有的 ERROR_MESSAGE 197 | 198 | | HTTP状态代码 | 错误信息 | 说明 | 199 | | :----: | :--: | :------: | 200 | | 400 | INVALID_QTOTAL | 参数qtotal格式不正确 | 201 | | 400 | INVALID_TOPK | 参数topk格式不正确 | 202 | | 400 | INVALID_QUERIES | 参数queries格式不正确 | 203 | | 400 | EMPTY_QUERIES | 参数queries中没有向量 | 204 | 205 | ## 2.3 删除向量标识 206 | 207 | ### 2.3.1 请求 208 | 209 | 调用地址:http://127.0.0.1/delete 210 | 请求方式:POST 211 | 请求类型:application/json 212 | 213 | | 是否必选 | 参数名 | 类型 | 参数说明 | 214 | | :------: | :----: | :--: | :------: | 215 | | 必选 | ids | `Array` | 需要删除的向量标识列表 | 216 | 217 | **请求参数结构示例** 218 | 219 | ```json 220 | { 221 | "ids": [1,...,5] 222 | } 223 | ``` 224 | 225 | ### 2.3.2 返回 226 | 227 | 返回类型:JSON 228 | 229 | | 参数名 | 类型 | 参数说明 | 230 | | :----: | :--: | :------: | 231 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 232 | | ndelete | `Int` | 成功删除的向量个数 | 233 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 234 | 235 | ### 2.3.3 成功请求返回值示例 236 | 237 | ```json 238 | { 239 | "time_used": 10, 240 | "ndelete": 10, 241 | "error_message": "DELETE_SUCCESS" 242 | } 243 | ``` 244 | 245 | ### 2.3.4 失败请求返回值示例 246 | 247 | ```json 248 | { 249 | "time_used": 0, 250 | "ndelete": 0, 251 | "error_message": "INVALID_IDS" 252 | } 253 | ``` 254 | 255 | ### 2.3.5 当前 API 特有的 ERROR_MESSAGE 256 | 257 | | HTTP状态代码 | 错误信息 | 说明 | 258 | | :----: | :--: | :------: | 259 | | 400 | INVALID_IDS | 参数ids格式不正确 | 260 | | 400 | EMPTY_DATA | 参数ids为空 | 261 | 262 | ## 2.4 删除范围内的向量标识 263 | 264 | ### 2.4.1 请求 265 | 266 | 调用地址:http://127.0.0.1/deleteRange 267 | 请求方式:POST 268 | 请求类型:application/json 269 | 270 | | 是否必选 | 参数名 | 类型 | 参数说明 | 271 | | :------: | :----: | :--: | :------: | 272 | | 必选 | start | `Int` | 需要删除的起始向量标识 | 273 | | 必选 | end | `Int` | 需要删除的结尾向量标识 | 274 | 275 | **请求参数结构示例** 276 | 277 | ```json 278 | { 279 | "start": 1, 280 | "end": 5 281 | } 282 | ``` 283 | 284 | ### 
2.4.2 返回 285 | 286 | 返回类型:JSON 287 | 288 | | 参数名 | 类型 | 参数说明 | 289 | | :----: | :--: | :------: | 290 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 291 | | ndelete | `Int` | 成功删除的向量个数 | 292 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 293 | 294 | ### 2.4.3 成功请求返回值示例 295 | 296 | ```json 297 | { 298 | "time_used": 10, 299 | "ndelete": 10, 300 | "error_message": "DELETERANGE_SUCCESS" 301 | } 302 | ``` 303 | 304 | ### 2.4.4 失败请求返回值示例 305 | 306 | ```json 307 | { 308 | "time_used": 0, 309 | "ndelete": 0, 310 | "error_message": "INVALID_RANGE" 311 | } 312 | ``` 313 | 314 | ### 2.4.5 当前 API 特有的 ERROR_MESSAGE 315 | 316 | | HTTP状态代码 | 错误信息 | 说明 | 317 | | :----: | :--: | :------: | 318 | | 400 | INVALID_RANGE | 参数start或end格式不正确 | 319 | 320 | ## 2.5 重新配置和加载索引文件 321 | 322 | ### 2.5.1 请求 323 | 324 | 调用地址:http://127.0.0.1/reconfig 325 | 请求方式:POST 326 | 请求类型:application/json 327 | 328 | | 是否必选 | 参数名 | 类型 | 参数说明 | 329 | | :------: | :----: | :--: | :------: | 330 | | 必选 | reconfigFilePath | `String` | 需要重新加载的配置文件路径 | 331 | 332 | **请求参数结构示例** 333 | 334 | ```json 335 | { 336 | "reconfigFilePath": "config.json" 337 | } 338 | ``` 339 | 340 | ### 2.5.2 返回 341 | 342 | 返回类型:JSON 343 | 344 | | 参数名 | 类型 | 参数说明 | 345 | | :----: | :--: | :------: | 346 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 347 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 348 | 349 | ### 2.5.3 成功请求返回值示例 350 | 351 | ```json 352 | { 353 | "time_used": 350, 354 | "error_message": "RECONFIG_SUCCESS" 355 | } 356 | ``` 357 | 358 | ### 2.5.4 失败请求返回值示例 359 | 360 | ```json 361 | { 362 | "time_used": 0, 363 | "error_message": "INVALID_RECONFIG_FILEPATH" 364 | } 365 | ``` 366 | 367 | ### 2.5.5 当前 API 特有的 ERROR_MESSAGE 368 | 369 | | HTTP状态代码 | 错误信息 | 说明 | 370 | | :----: | :--: | :------: | 371 | | 400 | INVALID_RECONFIG_FILEPATH | 参数reconfigFilePath格式不正确 | 372 | | 400 | RELOAD_FAIL | 重新加载数据失败 | 373 | | 400 | READD_FAIL | 重新构建索引结构失败 | 374 | 375 | ## 2.6 查询范围内向量 376 | 377 | ### 2.6.1 请求 378 | 379 | 
调用地址:http://127.0.0.1/searchRange 380 | 请求方式:POST 381 | 请求类型:application/json 382 | 383 | | 是否必选 | 参数名 | 类型 | 参数说明 | 384 | | :------: | :----: | :--: | :------: | 385 | | 必选 | nq | `Int` | 需要查询的向量个数 | 386 | | 必选 | radius | `Number` | 查询最近邻的范围阈值 | 387 | | 必选 | queries | `Object` | 需要查询的向量标识及对应特征向量 | 388 | 389 | **请求参数结构示例** 390 | 391 | ```json 392 | { 393 | "nq": 10, 394 | "radius": 0.5, 395 | "queries": 396 | { 397 | "q0": [0.1,...,0.5], 398 | ..., 399 | "q9": [0.1,...,0.5] 400 | } 401 | } 402 | ``` 403 | 404 | ### 2.6.2 返回 405 | 406 | 返回类型:JSON 407 | 408 | | 参数名 | 类型 | 参数说明 | 409 | | :----: | :--: | :------: | 410 | | result | `Object` | 每个查询向量的最近邻距离及标识 | 411 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 412 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 413 | 414 | ### 2.6.3 成功请求返回值示例 415 | 416 | ```json 417 | { 418 | "result": 419 | { 420 | "q0": {"distance": [77.55, 78.39], "labels": [80, 1]}, 421 | "q1": {"distance": [67.94, 71.17], "labels": [70, 49]}, 422 | ..., 423 | "q9": {"distance": [73.63, 73.78], "labels": [47, 28]} 424 | } 425 | "time_used": 50, 426 | "error_message": "SEARCH_RANGE_SUCCESS" 427 | } 428 | ``` 429 | 430 | ### 2.6.4 失败请求返回值示例 431 | 432 | ```json 433 | { 434 | "time_used": 0, 435 | "error_message": "EMPTY_QUERIES" 436 | } 437 | ``` 438 | 439 | ### 2.6.5 当前 API 特有的 ERROR_MESSAGE 440 | 441 | | HTTP状态代码 | 错误信息 | 说明 | 442 | | :----: | :--: | :------: | 443 | | 400 | INVALID_NQ | 参数nq格式不正确 | 444 | | 400 | INVALID_RADIUS | 参数radius格式不正确 | 445 | | 400 | INVALID_QUERIES | 参数queries格式不正确 | 446 | | 400 | EMPTY_QUERIES | 参数queries中没有向量 | 447 | 448 | ## 2.7 按照日期查询向量 449 | 450 | ### 2.7.1 请求 451 | 452 | 调用地址:http://127.0.0.1/searchDays 453 | 请求方式:POST 454 | 请求类型:application/json 455 | 456 | | 是否必选 | 参数名 | 类型 | 参数说明 | 457 | | :------: | :----: | :--: | :------: | 458 | | 必选 | qtotal | `Int` | 需要查询的向量个数 | 459 | | 必选 | topk | `Number` | 查询最近邻的个数 | 460 | | 必选 | queries | `Object` | 需要查询的向量标识及对应特征向量 | 461 | | 必选 | days | `List` | 需要查询的日期列表 | 462 | 
463 | **请求参数结构示例** 464 | 465 | ```json 466 | { 467 | "qtotal": 10, 468 | "topk": 5, 469 | "queries": 470 | { 471 | "q0": [0.1,...,0.5], 472 | ..., 473 | "q9": [0.1,...,0.5] 474 | }, 475 | "days":["20180424","20180425"] 476 | } 477 | ``` 478 | 479 | ### 2.7.2 返回 480 | 481 | 返回类型:JSON 482 | 483 | | 参数名 | 类型 | 参数说明 | 484 | | :----: | :--: | :------: | 485 | | result | `Object` | 每个查询向量的最近邻距离及标识 | 486 | | time_used | `Int` | 整个请求所花费的时间,单位为毫秒 | 487 | | error_message | `String` | 当请求失败时返回错误信息,请求成功时返回请求结果。 | 488 | 489 | ### 2.7.3 成功请求返回值示例 490 | 491 | ```json 492 | { 493 | "result": 494 | { 495 | "q0": {"distance": [77.55, 78.39], "labels": [80, 1]}, 496 | "q1": {"distance": [67.94, 71.17], "labels": [70, 49]}, 497 | ..., 498 | "q9": {"distance": [73.63, 73.78], "labels": [47, 28]} 499 | } 500 | "time_used": 50, 501 | "error_message": "SEARCHDAYS_SUCCESS" 502 | } 503 | ``` 504 | 505 | ### 2.7.4 失败请求返回值示例 506 | 507 | ```json 508 | { 509 | "time_used": 0, 510 | "error_message": "SEARCHDAYS_FAIL" 511 | } 512 | ``` 513 | 514 | ### 2.7.5 当前 API 特有的 ERROR_MESSAGE 515 | 516 | | HTTP状态代码 | 错误信息 | 说明 | 517 | | :----: | :--: | :------: | 518 | | 400 | INVALID_QTOTAL | 参数qtotal格式不正确 | 519 | | 400 | INVALID_DAYS | 参数days格式不正确 | 520 | | 400 | INVALID_QUERIES | 参数queries格式不正确 | 521 | | 400 | EMPTY_QUERIES | 参数queries中没有向量 | 522 | | 400 | INVALID_TOPK | 参数topk格式不正确 | 523 | | 400 | NOTEXIST_INDEX | 要查询日期的索引文件不存在 | 524 | 525 | ## 2.8 通用的 ERROR_MESSAGE 526 | 527 | | HTTP状态代码 | 错误信息 | 说明 | 528 | | :----: | :--: | :------: | 529 | | 400 | MISSING_ARGUMENTS | 缺少某个必选参数 | 530 | | 400 | INVALID_DIMENSION | 参数data中向量维数不匹配 | -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # 配置说明 2 | 3 | |选项|参数类型|说明| 4 | |:--|:--|:--| 5 | |listenip|ipv4|| 6 | |port|int|| 7 | |httpThreads|int|处理http请求| 8 | |logConfigParams|object|日志级别开关配置| 9 | |dataFilePath|string|需要加载的数据文件路径| 10 | 
|dimension|int|特征维数| 11 | |searchFactory|string|faiss的索引结构类型| 12 | |usegpu|bool|是否使用GPU| 13 | 14 | ```json 15 | { 16 | "httpParams": { 17 | "httpThreads": 2, 18 | "listenIP": "0.0.0.0", 19 | "port": 2333 20 | }, 21 | "logConfigParams": { 22 | "debugEnabled": false, 23 | "infoEnabled": true, 24 | "warningEnabled": false 25 | }, 26 | "searchToolParams": { 27 | "dataFilePath": "data.bin", 28 | "dimension": 512, 29 | "searchFactory": "IDMap,Flat", 30 | "usegpu": true 31 | } 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | ## scripts文件说明 2 | 3 | |脚本|说明| 4 | |:--:|:--:| 5 | |install_deps.sh|依赖安装脚本(OpenBlas)| 6 | |monitor.sh|监控重启脚本| 7 | |queryServer_install.run|用户软件安装包| 8 | |install_cuda.sh|安装cuda脚本| 9 | |pack.sh|打包发布脚本| 10 | 11 | ## 打包脚本 12 | 13 | ```bash 14 | # Current path is 'queryServer' 15 | ./scripts/pack.sh 16 | # After this script is executed, a installer will be generate in 'release' 17 | ``` 18 | 19 | ## 用户安装脚本 20 | 21 | 打包完成后,将release文件夹中的queryServer_install.run软件包交给用户。 22 | 用户直接执行./queryServer_install.run即可安装。 23 | 安装完成后插入加密锁,进入queryServer目录,执行./queryServer即可启动服务。 24 | -------------------------------------------------------------------------------- /scripts/git-lfs.deb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | unknown_os () 4 | { 5 | echo "Unfortunately, your operating system distribution and version are not supported by this script." 6 | echo 7 | echo "You can override the OS detection by setting os= and dist= prior to running this script." 
8 | echo "You can find a list of supported OSes and distributions on our website: https://packagecloud.io/docs#os_distro_version" 9 | echo 10 | echo "For example, to force Ubuntu Trusty: os=ubuntu dist=trusty ./script.sh" 11 | echo 12 | echo "Please email support@packagecloud.io and let us know if you run into any issues." 13 | exit 1 14 | } 15 | 16 | curl_check () 17 | { 18 | echo "Checking for curl..." 19 | if command -v curl > /dev/null; then 20 | echo "Detected curl..." 21 | else 22 | echo "Installing curl..." 23 | apt-get install -q -y curl 24 | fi 25 | } 26 | 27 | install_debian_keyring () 28 | { 29 | if [ "${os}" = "debian" ]; then 30 | echo "Installing debian-archive-keyring which is needed for installing " 31 | echo "apt-transport-https on many Debian systems." 32 | apt-get install -y debian-archive-keyring &> /dev/null 33 | fi 34 | } 35 | 36 | 37 | detect_os () 38 | { 39 | if [[ ( -z "${os}" ) && ( -z "${dist}" ) ]]; then 40 | # some systems dont have lsb-release yet have the lsb_release binary and 41 | # vice-versa 42 | if [ -e /etc/lsb-release ]; then 43 | . /etc/lsb-release 44 | 45 | if [ "${ID}" = "raspbian" ]; then 46 | os=${ID} 47 | dist=`cut --delimiter='.' -f1 /etc/debian_version` 48 | else 49 | os=${DISTRIB_ID} 50 | dist=${DISTRIB_CODENAME} 51 | 52 | if [ -z "$dist" ]; then 53 | dist=${DISTRIB_RELEASE} 54 | fi 55 | fi 56 | 57 | elif [ `which lsb_release 2>/dev/null` ]; then 58 | dist=`lsb_release -c | cut -f2` 59 | os=`lsb_release -i | cut -f2 | awk '{ print tolower($1) }'` 60 | 61 | elif [ -e /etc/debian_version ]; then 62 | # some Debians have jessie/sid in their /etc/debian_version 63 | # while others have '6.0.7' 64 | os=`cat /etc/issue | head -1 | awk '{ print tolower($1) }'` 65 | if grep -q '/' /etc/debian_version; then 66 | dist=`cut --delimiter='/' -f1 /etc/debian_version` 67 | else 68 | dist=`cut --delimiter='.' 
-f1 /etc/debian_version` 69 | fi 70 | 71 | else 72 | unknown_os 73 | fi 74 | fi 75 | 76 | if [ -z "$dist" ]; then 77 | unknown_os 78 | fi 79 | 80 | # remove whitespace from OS and dist name 81 | os="${os// /}" 82 | dist="${dist// /}" 83 | 84 | echo "Detected operating system as $os/$dist." 85 | } 86 | 87 | main () 88 | { 89 | detect_os 90 | curl_check 91 | 92 | # Need to first run apt-get update so that apt-transport-https can be 93 | # installed 94 | echo -n "Running apt-get update... " 95 | apt-get update &> /dev/null 96 | echo "done." 97 | 98 | # Install the debian-archive-keyring package on debian systems so that 99 | # apt-transport-https can be installed next 100 | install_debian_keyring 101 | 102 | echo -n "Installing apt-transport-https... " 103 | apt-get install -y apt-transport-https &> /dev/null 104 | echo "done." 105 | 106 | 107 | gpg_key_url="https://packagecloud.io/github/git-lfs/gpgkey" 108 | apt_config_url="https://packagecloud.io/install/repositories/github/git-lfs/config_file.list?os=${os}&dist=${dist}&source=script" 109 | 110 | apt_source_path="/etc/apt/sources.list.d/github_git-lfs.list" 111 | 112 | echo -n "Installing $apt_source_path..." 113 | 114 | # create an apt config file for this repository 115 | curl -sSf "${apt_config_url}" > $apt_source_path 116 | curl_exit_code=$? 117 | 118 | if [ "$curl_exit_code" = "22" ]; then 119 | echo 120 | echo 121 | echo -n "Unable to download repo config from: " 122 | echo "${apt_config_url}" 123 | echo 124 | echo "This usually happens if your operating system is not supported by " 125 | echo "packagecloud.io, or this script's OS detection failed." 126 | echo 127 | echo "You can override the OS detection by setting os= and dist= prior to running this script." 
128 | echo "You can find a list of supported OSes and distributions on our website: https://packagecloud.io/docs#os_distro_version" 129 | echo 130 | echo "For example, to force Ubuntu Trusty: os=ubuntu dist=trusty ./script.sh" 131 | echo 132 | echo "If you are running a supported OS, please email support@packagecloud.io and report this." 133 | [ -e $apt_source_path ] && rm $apt_source_path 134 | exit 1 135 | elif [ "$curl_exit_code" = "35" -o "$curl_exit_code" = "60" ]; then 136 | echo "curl is unable to connect to packagecloud.io over TLS when running: " 137 | echo " curl ${apt_config_url}" 138 | echo "This is usually due to one of two things:" 139 | echo 140 | echo " 1.) Missing CA root certificates (make sure the ca-certificates package is installed)" 141 | echo " 2.) An old version of libssl. Try upgrading libssl on your system to a more recent version" 142 | echo 143 | echo "Contact support@packagecloud.io with information about your system for help." 144 | [ -e $apt_source_path ] && rm $apt_source_path 145 | exit 1 146 | elif [ "$curl_exit_code" -gt "0" ]; then 147 | echo 148 | echo "Unable to run: " 149 | echo " curl ${apt_config_url}" 150 | echo 151 | echo "Double check your curl installation and try again." 152 | [ -e $apt_source_path ] && rm $apt_source_path 153 | exit 1 154 | else 155 | echo "done." 156 | fi 157 | 158 | echo -n "Importing packagecloud gpg key... " 159 | # import the gpg key 160 | curl -L "${gpg_key_url}" 2> /dev/null | apt-key add - &>/dev/null 161 | echo "done." 162 | 163 | echo -n "Running apt-get update... " 164 | # update apt on this system 165 | apt-get update &> /dev/null 166 | echo "done." 167 | 168 | echo 169 | echo "The repository is setup! You can now install packages." 
170 | } 171 | 172 | main 173 | 174 | -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROGRAM_NAME="queryServer" 4 | PACKAGE_NAME="${PROGRAM_NAME}.tar.gz" 5 | MD5=0 6 | 7 | check() 8 | { 9 | 10 | if [ ! -e $PACKAGE_NAME ]; then 11 | echo "ERROR - can not find package file ${PACKAGE_NAME} !"; 12 | exit 1; 13 | fi 14 | md5=$(md5sum $PACKAGE_NAME) 15 | md5=${md5:0:32} 16 | if [ "$MD5" != "$md5" ];then 17 | echo "check md5 failed! exit now!"; 18 | exit 1; 19 | fi 20 | } 21 | 22 | extract() 23 | { 24 | echo "INFO - extracting ...." 25 | line=`awk '/===Binary files===/{print NR}' $0 | tail -n1` 26 | line=`expr $line + 1` 27 | tail -n +$line $0 | base64 -d >${PACKAGE_NAME} 28 | 29 | check 30 | 31 | tar zxf $PACKAGE_NAME 32 | cd $PROGRAM_NAME 33 | chmod a+x ./* 34 | echo "INFO - unpack finished!" 35 | } 36 | 37 | install() 38 | { 39 | 40 | echo "INFO - installing ..." 41 | 42 | sh ./scripts/install_deps.sh 43 | 44 | echo "INFO - install dependace finished!" 45 | 46 | sh ./scripts/install_cuda.sh 47 | 48 | } 49 | 50 | post_install() 51 | { 52 | #rm -rf ./deps 53 | rm -rf ./scripts 54 | 55 | rm ../$PACKAGE_NAME 56 | 57 | if [ -e "cuda-repo-ubuntu1404_8.0.61-1_amd64.deb" ]; then 58 | echo "----------------have cuda-repo-ubuntu1404" 59 | rm cuda-repo-ubuntu1404_8.0.61-1_amd64.deb 60 | fi 61 | 62 | echo "INFO - Installation finished!" 63 | } 64 | 65 | setup_monitor() 66 | { 67 | cp ./scripts/monitor.sh ./ 68 | MONITOR="monitor" 69 | if [ ! -e ${MONITOR} ];then 70 | touch ${MONITOR} 71 | fi 72 | echo "* * * * * /bin/bash $(pwd)/monitor.sh $(pwd)" > ${MONITOR} 73 | crontab ${MONITOR} 74 | echo "INFO - Setup monitor done!" 
75 | } 76 | 77 | main() 78 | { 79 | extract 80 | install 81 | setup_monitor 82 | post_install 83 | } 84 | 85 | main 86 | exit 0 87 | -------------------------------------------------------------------------------- /scripts/install_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #------------------------------------------------------------------------------ 5 | # 6 | # - git clone --recursive 7 | # - ./install_deps.sh 8 | # - cmake && make 9 | # 10 | 11 | set -e 12 | 13 | # Check for 'uname' and abort if it is not available. 14 | uname -v > /dev/null 2>&1 || { echo >&2 "ERROR - verifier requires 'uname' to identify the platform."; exit 1; } 15 | # 是否有GPU 16 | lspci|grep -i NVIDIA > /dev/null 2>&1 || { echo >&2 "ERROR - cannot find GPU !"; exit 1; } 17 | 18 | case $(uname -s) in 19 | 20 | #------------------------------------------------------------------------------ 21 | # Linux 22 | #------------------------------------------------------------------------------ 23 | Linux) 24 | 25 | # Detect if sudo is needed. 26 | if [ $(id -u) != 0 ]; then 27 | SUDO="sudo" 28 | fi 29 | 30 | #------------------------------------------------------------------------------ 31 | # Arch Linux 32 | #------------------------------------------------------------------------------ 33 | 34 | if [ -f "/etc/arch-release" ]; then 35 | 36 | echo "Installing CUDA and CuDNN on Arch Linux." 37 | 38 | # The majority of our dependencies can be found in the 39 | # Arch Linux official repositories. 40 | # See https://wiki.archlinux.org/index.php/Official_repositories 41 | 42 | #TODO 43 | 44 | elif [ -f "/etc/os-release" ]; then 45 | 46 | DISTRO_NAME=$(. /etc/os-release; echo $NAME) 47 | case $DISTRO_NAME in 48 | 49 | Debian*) 50 | echo "Installing CUDA and CuDNN on Debian Linux." 51 | # TODO 52 | ;; 53 | 54 | Fedora) 55 | echo "Installing CUDA and CuDNN on Fedora Linux." 
56 | # TODO 57 | ;; 58 | 59 | #------------------------------------------------------------------------------ 60 | # Ubuntu 61 | # 62 | # TODO - I wonder whether all of the Ubuntu-variants need some special 63 | # treatment? 64 | # 65 | # TODO - We should also test this code on Ubuntu Server, Ubuntu Snappy Core 66 | # and Ubuntu Phone. 67 | # 68 | # TODO - Our Ubuntu build is only working for amd64 and i386 processors. 69 | # It would be good to add armel, armhf and arm64. 70 | # See https://github.com/ethereum/webthree-umbrella/issues/228. 71 | #------------------------------------------------------------------------------ 72 | Ubuntu|LinuxMint) 73 | 74 | # if [[ -f "/usr/local/cuda-8.0" ]]; then 75 | # echo "INFO - cuda already installed!"; 76 | # exit 1; 77 | # fi 78 | 79 | # echo "INFO - Installing CUDA and cudnn on Ubuntu." 80 | echo "INFO - Installing CUDA on Ubuntu." 81 | # install nvidia driver and cuda 82 | UBUNTU_VERSION=$(. /etc/os-release; echo $VERSION_ID) 83 | BIT="x86_64" 84 | DOWNLOAD_PREFIX="http://developer.download.nvidia.com/compute/cuda/repos/" 85 | # echo $UBUNTU_VERSION; 86 | 87 | case $UBUNTU_VERSION in 88 | 89 | 14.04) 90 | DOWNLOAD_LINK="ubuntu1404/${BIT}/cuda-repo-ubuntu1404_8.0.61-1_amd64.deb" 91 | ;; 92 | 93 | 16.04) 94 | DOWNLOAD_LINK="ubuntu1604/${BIT}/cuda-repo-ubuntu1604_8.0.44-1_amd64.deb" 95 | ;; 96 | *) 97 | echo "ERROR - Cannot find support for current Ubuntu version" 98 | exit 1 99 | ;; 100 | esac 101 | 102 | DOWNLOAD_LINK=$DOWNLOAD_PREFIX$DOWNLOAD_LINK 103 | 104 | if [ -z "$(dpkg -l | grep cuda-repo)" ]; then 105 | wget -nv $DOWNLOAD_LINK 106 | $SUDO dpkg -i ${DOWNLOAD_LINK##*/} 107 | fi 108 | 109 | if [ -z "$(dpkg -l | grep cuda-drivers)" ]; then 110 | $SUDO apt-get update 111 | $SUDO apt-get install -y cuda 112 | fi 113 | 114 | if [ -e "${DOWNLOAD_LINK##*/}" ]; then 115 | $SUDO rm ${DOWNLOAD_LINK##*/} 116 | fi 117 | 118 | echo "INFO - Install CUDA finished." 
119 | 120 | # CUDNN_PACKAGE="cudnn-8.0-linux-x64-v5.1.tar" 121 | # CUDNN_PATH="./deps" 122 | # # install cudnn 123 | # if [ ! -e "${CUDNN_PATH}/${CUDNN_PACKAGE}" ]; 124 | # then 125 | # echo "ERROR - can not find cudnn tar file ${CUDNN_PATH}/${CUDNN_PACKAGE}!"; 126 | # exit 1; 127 | # fi 128 | # cd $CUDNN_PATH 129 | # tar -xf $CUDNN_PACKAGE 130 | # $SUDO cp cuda/include/* /usr/local/cuda/include/ 131 | # $SUDO cp cuda/lib64/* /usr/local/cuda/lib64/ 132 | # rm -rf ./cuda 133 | # if [ -z "$(cat /etc/profile | grep /usr/local/cuda/lib64)" ]; then 134 | # echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/cuda/lib64" | $SUDO tee -a /etc/profile 135 | # fi 136 | # echo "INFO - Install cudnn finished." 137 | 138 | ;; 139 | 140 | CentOS*) 141 | echo "Installing CUDA and CuDNN on CentOS." 142 | # Enable EPEL repo that contains leveldb-devel 143 | 144 | # TODO 145 | 146 | ;; 147 | 148 | *) 149 | echo "Unsupported Linux distribution: $DISTRO_NAME." 150 | exit 1 151 | ;; 152 | 153 | esac 154 | 155 | elif [ -f "/etc/alpine-release" ]; then 156 | # Alpine Linux 157 | echo "Installing CUDA and CuDNN on Alpine Linux." 158 | #TODO 159 | 160 | else 161 | 162 | case $(lsb_release -is) in 163 | 164 | #------------------------------------------------------------------------------ 165 | # Other (unknown) Linux 166 | # Major and medium distros which we are missing would include Mint, CentOS, 167 | # RHEL, Raspbian, Cygwin, OpenWrt, gNewSense, Trisquel and SteamOS. 168 | #------------------------------------------------------------------------------ 169 | *) 170 | #other Linux 171 | echo "ERROR - Unsupported or unidentified Linux distro." 172 | exit 1 173 | ;; 174 | esac 175 | fi 176 | ;; 177 | 178 | #------------------------------------------------------------------------------ 179 | # Other platform (not Linux, FreeBSD or macOS). 180 | # Not sure what might end up here? 181 | # Maybe OpenBSD, NetBSD, AIX, Solaris, HP-UX? 
182 | #------------------------------------------------------------------------------ 183 | *) 184 | #other 185 | echo "ERROR - Unsupported or unidentified operating system." 186 | ;; 187 | esac 188 | -------------------------------------------------------------------------------- /scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | set -e 6 | 7 | 8 | main() 9 | { 10 | mkdir deps 11 | cd ./deps 12 | tar -zxf libopenblas.tar.gz && \ 13 | sudo dpkg -i libblas-common_3.6.0-2ubuntu2_amd64.deb && \ 14 | sudo dpkg -i libopenblas-base_0.2.18-1ubuntu1_amd64.deb && \ 15 | sudo dpkg -i libopenblas-dev_0.2.18-1ubuntu1_amd64.deb 16 | echo "INFO - install openblas dependence finished!" 17 | } 18 | main $# $@ 19 | -------------------------------------------------------------------------------- /scripts/monitor.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | source /etc/profile 3 | # check every 30 seconds 4 | # crontab -e 5 | # * * * * * bash /home/hadoop/queryServer/scripts/monitor.sh 6 | # * * * * * sleep 30; bash /home/hadoop/queryServer/scripts/monitor.sh 7 | 8 | # 9 | PROCESS_NAME="queryServer" 10 | PROCESS_PATH=$1 # 程序绝对路径 11 | START_PROCESS="./${PROCESS_NAME}" 12 | PROCESS_PARAMS=">>queryServer.log &" 13 | 14 | 15 | # 函数: CheckProcess 16 | # 功能: 检查一个进程是否存在 17 | # 参数: $1 --- 要检查的进程名称 18 | # 返回: 指定进程数量 19 | #------------------------------------------------------------------------------ 20 | CheckProcess() 21 | { 22 | #$PROCESS_NUM获取指定进程名的数目,为1返回0,表示正常,不为1返回1,表示有错误,需要重新启动 23 | PROCESS_NUM=`ps -A | grep "$1" | grep -v "grep" | wc -l` 24 | return $PROCESS_NUM 25 | } 26 | 27 | # 检查实例是否已经存在 28 | CheckProcess "${PROCESS_NAME}" 29 | CheckRet=$? 
30 | if [ $CheckRet -eq 0 ]; then 31 | # 任意你需要执行的操作 32 | /bin/echo "$(date '+%Y-%m-%d %T') : Restarting ${PROCESS_NAME} ...">> ${PROCESS_PATH}/restart.log 33 | cd ${PROCESS_PATH}; 34 | (${START_PROCESS}) >> ${PROCESS_PATH}/restart.log 2>&1 & 35 | fi 36 | -------------------------------------------------------------------------------- /scripts/pack.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | PROJECT_DIR=$(pwd) 6 | RELEASE_DIR="${PROJECT_DIR}/release" 7 | EXECUTABLE_DIR="${PROJECT_DIR}/build/bin" 8 | EXECUTABLE_FILE="${EXECUTABLE_DIR}/queryServer" 9 | TARGET_DIR="${RELEASE_DIR}/queryServer" 10 | SCRIPT_PATH="${PROJECT_DIR}/scripts" 11 | FAISS_FILE="${PROJECT_DIR}/deps/faiss-1.0.tar.gz" 12 | BLAS_FILE="${PROJECT_DIR}/deps/libopenblas.tar.gz" 13 | BUILD_TYPE=0 # 0 is debug, 1 is release 14 | 15 | pre_check() 16 | { 17 | if [ -e ${RELEASE_DIR} ]; then 18 | rm -rf ${RELEASE_DIR} 19 | fi 20 | 21 | if [ ! -d ${TARGET_DIR} ]; then 22 | mkdir -p ${TARGET_DIR} 23 | fi 24 | 25 | cd ${TARGET_DIR} 26 | 27 | if [ -e "queryServer.tar.gz" ]; then 28 | rm queryServer.tar.gz 29 | fi 30 | 31 | if [ ! -e $EXECUTABLE_FILE ]; then 32 | echo "ERROR - can not find executable file ${EXECUTABLE_FILE}!"; 33 | exit 1; 34 | fi 35 | 36 | if [ ! -e "${SCRIPT_PATH}/install_deps.sh" ]; then 37 | echo "ERROR - can not find install dependence scripts!"; 38 | exit 1; 39 | fi 40 | 41 | if [ ! -e "${SCRIPT_PATH}/install_cuda.sh" ]; then 42 | echo "ERROR - can not find install cuda scripts!"; 43 | exit 1; 44 | fi 45 | 46 | if [ ! -e "${SCRIPT_PATH}/install.sh" ]; then 47 | echo "ERROR - can not find install scripts!"; 48 | exit 1; 49 | fi 50 | 51 | if [ ! -e "${FAISS_FILE}" ]; then 52 | echo "ERROR - can not find faiss file ${FAISS_FILE}!"; 53 | exit 1; 54 | fi 55 | if [ ! -e "${BLAS_FILE}" ]; then 56 | echo "ERROR - can not find openblas file ${BLAS_FILE}!"; 57 | exit 1; 58 | fi 59 | echo "Check files done!" 
60 | } 61 | 62 | make_dir() 63 | { 64 | mkdir $TARGET_DIR/scripts 65 | echo "Gennerate directories done!" 66 | } 67 | 68 | 69 | copy_files() 70 | { 71 | cp $EXECUTABLE_FILE $TARGET_DIR 72 | cp ${SCRIPT_PATH}/install_deps.sh $TARGET_DIR/scripts 73 | cp ${SCRIPT_PATH}/install_cuda.sh $TARGET_DIR/scripts 74 | cp ${SCRIPT_PATH}/monitor.sh $TARGET_DIR/scripts 75 | cp ${SCRIPT_PATH}/install.sh ${RELEASE_DIR} 76 | mkdir -p ${TARGET_DIR}/deps/ 77 | cp ${FAISS_FILE} ${TARGET_DIR}/deps/ 78 | cp ${BLAS_FILE} ${TARGET_DIR}/deps/ 79 | } 80 | 81 | compress() 82 | { 83 | cd ${RELEASE_DIR} 84 | tar zcvf queryServer.tar.gz queryServer 85 | echo "Compress done!" 86 | } 87 | 88 | gen_md5() 89 | { 90 | md5=$(md5sum queryServer.tar.gz) 91 | md5=${md5:0:32} 92 | sed -i "s/MD5=0/MD5=${md5}/" install.sh 93 | echo "Generate md5 done!" 94 | } 95 | 96 | gen_prog() 97 | { 98 | echo "===Binary files===" >> install.sh 99 | base64 queryServer.tar.gz >> install.sh 100 | 101 | mv install.sh queryServer_install.run 102 | chmod u+x queryServer_install.run 103 | 104 | echo "Generate installer done!" 105 | } 106 | 107 | clean() 108 | { 109 | if [ -e "${RELEASE_DIR}/queryServer.tar.gz" ]; then 110 | echo "clean...." 111 | rm -rf ${TARGET_DIR} 112 | rm $RELEASE_DIR/queryServer.tar.gz 113 | fi 114 | } 115 | 116 | main() 117 | { 118 | if [ $1 != 0 ];then 119 | BUILD_TYPE=1 120 | fi 121 | if [ $BUILD_TYPE == 1 ]; 122 | then 123 | echo "Build type : release." 
124 | else 125 | echo "Build type : debug" 126 | fi 127 | 128 | pre_check 129 | make_dir 130 | 131 | copy_files 132 | compress 133 | gen_md5 134 | gen_prog 135 | clean 136 | } 137 | 138 | main $# $@ 139 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | include_directories(${CMAKE_SOURCE_DIR}/src 3 | ${PISTACHE_INCLUDE_DIR}) 4 | 5 | add_subdirectory(common) 6 | add_subdirectory(libRestServer) 7 | add_subdirectory(libSearch) 8 | 9 | add_executable(searcher main.cpp) 10 | target_include_directories(searcher PRIVATE ./) 11 | target_link_libraries(searcher restserver) 12 | 13 | -------------------------------------------------------------------------------- /src/common/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRC_LIST "*.cpp") 2 | file(GLOB HEADERS "*.h") 3 | 4 | add_library(common STATIC ${SRC_LIST} ${HEADERS}) 5 | -------------------------------------------------------------------------------- /src/common/configParams.cpp: -------------------------------------------------------------------------------- 1 | #include "configParams.h" 2 | #include "json.h" 3 | #include "easylog++.h" 4 | #include "version.h" 5 | #include 6 | #include 7 | #include 8 | 9 | INITIALIZE_EASYLOGGINGPP 10 | 11 | using namespace dev; 12 | using namespace std; 13 | using json = nlohmann::json; 14 | 15 | #define CONFIG_WARN "Config file " 16 | 17 | // void rolloutHandler(const char *filename, size_t) 18 | // { 19 | // stringstream ss; 20 | // static map s_mlogIndex; 21 | // map::iterator iter = s_mlogIndex.find(filename); 22 | // if (iter != s_mlogIndex.end()) 23 | // { 24 | // ss << filename << "." << iter->second++; 25 | // s_mlogIndex[filename] = iter->second++; 26 | // } 27 | // else 28 | // { 29 | // ss << filename << "." 
<< 0; 30 | // s_mlogIndex[filename] = 0; 31 | // } 32 | // boost::filesystem::rename(filename, ss.str().c_str()); 33 | // } 34 | 35 | ConfigParams::ConfigParams(string const &configPath) 36 | { 37 | cpuCores = thread::hardware_concurrency(); 38 | loadConfig(configPath); 39 | initEasylogging(logConfigParams); 40 | versionMajor = VERIFIER_VERSION_MAJOR; 41 | versionMinor = VERIFIER_VERSION_MINOR; 42 | versionPatch = VERIFIER_VERSION_PATCH; 43 | 44 | #ifdef NDEBUG 45 | buildType = " Release"; 46 | #else 47 | buildType = " Debug "; 48 | #endif 49 | } 50 | 51 | void ConfigParams::loadConfig(string const &configPath) 52 | { 53 | ifstream in(configPath); 54 | json config; 55 | bool missConfigfile = true; 56 | if (in.is_open()) 57 | { 58 | try 59 | { 60 | in >> config; 61 | missConfigfile = false; 62 | } 63 | catch (std::exception &e) 64 | { 65 | LOG(ERROR) << e.what(); 66 | } 67 | in.close(); 68 | } 69 | 70 | if (missConfigfile) 71 | { 72 | LOG(INFO) << "Use default config."; 73 | if (!generateDefaultConfig(configPath)) 74 | LOG(ERROR) << "Generate default config fail."; 75 | return; 76 | } 77 | 78 | // config http params 79 | if (config["httpParams"].is_object()) 80 | { 81 | if (config["httpParams"]["port"].is_number_unsigned()) 82 | port = config["httpParams"]["port"].get(); 83 | else 84 | LOG(WARNING) << CONFIG_WARN << "Missing port."; 85 | 86 | if (config["httpParams"]["listenIP"].is_string()) 87 | ip = config["httpParams"]["listenIP"].get(); 88 | else 89 | LOG(WARNING) << CONFIG_WARN << "Missing listenIP."; 90 | 91 | if (config["httpParams"]["httpThreads"].is_number_unsigned()) 92 | threads = config["httpParams"]["httpThreads"].get(); 93 | else 94 | LOG(WARNING) << CONFIG_WARN << "Missing httpThreads."; 95 | } 96 | else 97 | LOG(WARNING) << CONFIG_WARN << "Missing httpParams."; 98 | 99 | // config search tool 100 | if (config["searchToolParams"].is_object()) 101 | { 102 | if (config["searchToolParams"]["searchFactory"].is_string()) 103 | searchFactory = 
config["searchToolParams"]["searchFactory"].get(); 104 | if (config["searchToolParams"]["dimension"].is_number_unsigned()) 105 | dimension = config["searchToolParams"]["dimension"].get(); 106 | if (config["searchToolParams"]["dataFilePath"].is_string()) 107 | dataFilePath = config["searchToolParams"]["dataFilePath"].get(); 108 | if (config["searchToolParams"]["usegpu"].is_boolean()) 109 | usegpu = config["searchToolParams"]["usegpu"].get(); 110 | } 111 | else 112 | LOG(WARNING) << CONFIG_WARN << "Missing searchToolParams."; 113 | 114 | // config log 115 | if (config["logConfigParams"].is_object()) 116 | { 117 | if (config["logConfigParams"]["debugEnabled"].is_boolean()) 118 | logConfigParams.debugEnabled = config["logConfigParams"]["debugEnabled"].get(); 119 | if (config["logConfigParams"]["warningEnabled"].is_boolean()) 120 | logConfigParams.warningEnabled = config["logConfigParams"]["warningEnabled"].get(); 121 | if (config["logConfigParams"]["infoEnabled"].is_boolean()) 122 | logConfigParams.infoEnabled = config["logConfigParams"]["infoEnabled"].get(); 123 | } 124 | else 125 | LOG(WARNING) << CONFIG_WARN << "Missing logConfigParams."; 126 | } 127 | 128 | bool ConfigParams::generateDefaultConfig(string const &configPath) const 129 | { 130 | json httpConfig; 131 | httpConfig["listenIP"] = ip; 132 | httpConfig["port"] = port; 133 | httpConfig["httpThreads"] = threads; 134 | 135 | json searchConfig; 136 | searchConfig["searchFactory"] = searchFactory; 137 | searchConfig["dimension"] = dimension; 138 | searchConfig["dataFilePath"] = dataFilePath; 139 | searchConfig["usegpu"] = usegpu; 140 | 141 | json logParamsConfig; 142 | logParamsConfig["debugEnabled"] = logConfigParams.debugEnabled; 143 | logParamsConfig["warningEnabled"] = logConfigParams.warningEnabled; 144 | logParamsConfig["infoEnabled"] = logConfigParams.infoEnabled; 145 | 146 | json config; 147 | config["httpParams"] = httpConfig; 148 | config["searchToolParams"] = searchConfig; 149 | 
config["logConfigParams"] = logParamsConfig; 150 | 151 | ofstream os(configPath); 152 | if (!os.is_open()) 153 | return false; 154 | os << std::setw(4) << config; 155 | os.close(); 156 | return true; 157 | } 158 | 159 | //日志配置文件放到log目录 160 | void ConfigParams::initEasylogging(const LogConfigParams &logConfig) const 161 | { 162 | el::Loggers::addFlag(el::LoggingFlag::MultiLoggerSupport); // Enables support for multiple loggers 163 | el::Loggers::addFlag(el::LoggingFlag::StrictLogFileSizeCheck); 164 | 165 | el::Configurations conf; 166 | conf.set(el::Level::Global, el::ConfigurationType::Enabled, to_string(true)); 167 | conf.set(el::Level::Trace, el::ConfigurationType::Enabled, to_string(logConfig.traceEnabled)); 168 | conf.set(el::Level::Debug, el::ConfigurationType::Enabled, to_string(logConfig.debugEnabled)); 169 | conf.set(el::Level::Fatal, el::ConfigurationType::Enabled, to_string(logConfig.fatalEnabled)); 170 | conf.set(el::Level::Error, el::ConfigurationType::Enabled, to_string(logConfig.errorEnabled)); 171 | conf.set(el::Level::Warning, el::ConfigurationType::Enabled, to_string(logConfig.warningEnabled)); 172 | conf.set(el::Level::Info, el::ConfigurationType::Enabled, to_string(logConfig.infoEnabled)); 173 | conf.set(el::Level::Verbose, el::ConfigurationType::Enabled, to_string(logConfig.verboseEnabled)); 174 | 175 | conf.set(el::Level::Global, el::ConfigurationType::ToFile, to_string(true)); 176 | conf.set(el::Level::Global, el::ConfigurationType::ToStandardOutput, to_string(false)); 177 | conf.set(el::Level::Global, el::ConfigurationType::Format, "%level|%datetime{%Y-%M-%d %H:%m:%s}|%msg"); 178 | conf.set(el::Level::Global, el::ConfigurationType::MillisecondsWidth, to_string(3)); 179 | conf.set(el::Level::Global, el::ConfigurationType::PerformanceTracking, to_string(false)); 180 | conf.set(el::Level::Global, el::ConfigurationType::MaxLogFileSize, to_string(209715200)); // 200MB 181 | conf.set(el::Level::Global, el::ConfigurationType::LogFlushThreshold, 
to_string(100)); // flush after every 100 logs 182 | 183 | if (logConfig.debugEnabled) 184 | { 185 | conf.set(el::Level::Debug, el::ConfigurationType::Filename, "logs/debug_log_%datetime{%Y%M%d%H}.log"); 186 | conf.set(el::Level::Debug, el::ConfigurationType::Format, "%level|%datetime|%file|%func|%line|%msg"); 187 | } 188 | if (logConfig.errorEnabled) 189 | { 190 | conf.set(el::Level::Error, el::ConfigurationType::Filename, "logs/error_log_%datetime{%Y%M%d%H}.log"); 191 | conf.set(el::Level::Error, el::ConfigurationType::Format, "%level|%datetime|%file|%func|%line|%msg"); 192 | } 193 | if (logConfig.infoEnabled) 194 | { 195 | conf.set(el::Level::Info, el::ConfigurationType::Filename, "logs/info_log_%datetime{%Y%M%d%H}.log"); 196 | } 197 | el::Loggers::reconfigureAllLoggers(conf); 198 | 199 | // log file rollout 200 | // el::Helpers::installPreRollOutCallback(rolloutHandler); 201 | } 202 | 203 | void ConfigParams::printParams() const 204 | { 205 | stringstream ss; 206 | ss << "===================================" << endl; 207 | ss << ">>>>>> Version " 208 | << versionMajor << "." << versionMinor << "." 
<< versionPatch 209 | << buildType << " <<<<<<" << endl; 210 | ss << "===================================" << endl; 211 | ss << "=========== Http Config ===========" << endl; 212 | ss << "CPU Cores : " << cpuCores << endl; 213 | ss << "Listen IP : " << ip << endl; 214 | ss << "Listen Port : " << port << endl; 215 | ss << "Http Threads: " << threads << endl; 216 | ss << "====================================" << endl; 217 | ss << "======== Search Parameters =========" << endl; 218 | ss << "searchFactory : " << searchFactory << endl; 219 | ss << "dimension : " << dimension << endl; 220 | ss << "usegpu : " << usegpu << endl; 221 | ss << "===================================="; 222 | cout << ss.str() << endl; 223 | LOG(INFO) << endl 224 | << ss.str(); 225 | } 226 | -------------------------------------------------------------------------------- /src/common/configParams.h: -------------------------------------------------------------------------------- 1 | #ifndef CONFIG_PARAMS_H_ 2 | #define CONFIG_PARAMS_H_ 3 | #include 4 | 5 | namespace dev 6 | { 7 | 8 | /** 9 | * @brief prase config parameters from json config file 10 | * 11 | */ 12 | struct ConfigParams 13 | { 14 | /// @brief constructor 15 | explicit ConfigParams(std::string const &configPath); 16 | /// @brief print configure parameters 17 | void printParams() const; 18 | 19 | int port = 2333; 20 | int threads = 2; 21 | int cpuCores = 0; 22 | std::string ip = "0.0.0.0"; 23 | std::string searchFactory = "IDMap,Flat"; 24 | unsigned int dimension = 256; 25 | std::string dataFilePath = "data.bin"; 26 | bool usegpu = true; 27 | 28 | int versionMajor; 29 | int versionMinor; 30 | int versionPatch; 31 | std::string buildType; 32 | struct LogConfigParams 33 | { 34 | bool traceEnabled = false; 35 | bool debugEnabled = false; 36 | bool fatalEnabled = false; 37 | bool errorEnabled = true; 38 | bool warningEnabled = false; 39 | bool infoEnabled = true; 40 | bool verboseEnabled = false; 41 | } logConfigParams; 42 | 43 | private: 
44 | /** 45 | * @brief load config 46 | * 47 | * @param configPath path of configure file 48 | */ 49 | void loadConfig(std::string const &configPath); 50 | 51 | /** 52 | * @brief generate default configure file when configure file don't exist 53 | * 54 | * @param configPath path of default configure file 55 | * @return true success 56 | * @return false fail 57 | */ 58 | bool generateDefaultConfig(std::string const &configPath) const; 59 | 60 | /** 61 | * @brief initialize log configure 62 | * 63 | * @param logConfig log configure parameter 64 | */ 65 | void initEasylogging(const LogConfigParams &logConfig) const; 66 | }; // struct ConfigParams 67 | } // namespace dev 68 | 69 | #endif //CONFIG_PARAMS_H_ 70 | -------------------------------------------------------------------------------- /src/common/easylog++.h: -------------------------------------------------------------------------------- 1 | #ifndef EASYLOGPP_H_ 2 | #define EASYLOGPP_H_ 3 | 4 | #define ELPP_NO_DEFAULT_LOG_FILE 5 | #define ELPP_THREAD_SAFE 6 | 7 | #include "easylogging++.h" 8 | 9 | #undef LOG 10 | #define LOG(LEVEL) CLOG(LEVEL, "default", "fileLogger") 11 | 12 | #endif //EASYLOGPP_H_ 13 | -------------------------------------------------------------------------------- /src/common/error.h: -------------------------------------------------------------------------------- 1 | #ifndef QUERY_ERROR_H_ 2 | #define QUERY_ERROR_H_ 3 | 4 | #define INVALID_NTOTAL "INVALID_NTOTAL" 5 | #define INVALID_DATA "INVALID_DATA" 6 | #define EMPTY_DATA "EMPTY_DATA" 7 | #define INVALID_DIMENSION "INVALID_DIMENSION" 8 | #define MISSING_ARGUMENTS "MISSING_ARGUMENTS" 9 | #define INVALID_TOPK "INVALID_TOPK" 10 | #define INVALID_QTOTAL "INVALID_QTOTAL" 11 | #define EMPTY_QUERIES "EMPTY_QUERIES" 12 | #define INVALID_QUERIES "INVALID_QUERIES" 13 | #define INVALID_IDS "INVALID_IDS" 14 | #define INVALID_RANGE "INVALID_RANGE" 15 | #define INVALID_RECONFIG_FILEPATH "INVALID_RECONFIG_FILEPATH" 16 | #define RELOAD_FAIL "RELOAD_FAIL" 
17 | #define READD_FAIL "READD_FAIL" 18 | #define INVALID_NQ "INVALID_NQ" 19 | #define INVALID_RADIUS "INVALID_RADIUS" 20 | #define INVALID_DAYS "INVALID_DAYS" 21 | #define NOTEXIST_INDEX "NOTEXIST_INDEX" 22 | 23 | #define BAD_REQUEST "400" 24 | #define OK "200" 25 | 26 | #define ADD_SUCCESS "ADD_SUCCESS" 27 | #define ADD_FAIL "ADD_FAIL" 28 | 29 | #define SEARCH_SUCCESS "SEARCH_SUCCESS" 30 | #define SEARCH_FAIL "SEARCH_FAIL" 31 | 32 | #define DELETE_SUCCESS "DELETE_SUCCESS" 33 | #define DELETE_FAIL "DELETE_FAIL" 34 | 35 | #define DELETERANGE_SUCCESS "DELETERANGE_SUCCESS" 36 | #define DELETERANGE_FAIL "DELETERANGE_FAIL" 37 | 38 | #define RECONFIG_SUCCESS "RECONFIG_SUCCESS" 39 | 40 | #define SEARCH_RANGE_SUCCESS "SEARCH_RANGE_SUCCESS" 41 | #define SEARCH_RANGE_FAIL "SEARCH_RANGE_FAIL" 42 | 43 | #define SEARCHDAYS_SUCCESS "SEARCHDAYS_SUCCESS" 44 | #define SEARCHDAYS_FAIL "SEARCHDAYS_FAIL" 45 | 46 | #endif //QUERY_ERROR_H_ 47 | -------------------------------------------------------------------------------- /src/common/memusage.h: -------------------------------------------------------------------------------- 1 | #ifndef MEMUSAGE_H__ 2 | #define MEMUSAGE_H__ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | // taken from http://stackoverflow.com/questions/669438/how-to-get-memory-usage-at-run-time-in-c 11 | void process_mem_usage(double *vm_usage, double *resident_set) 12 | { 13 | using std::ios_base; 14 | using std::ifstream; 15 | using std::string; 16 | 17 | *vm_usage = 0.0; 18 | *resident_set = 0.0; 19 | 20 | // 'file' stat seems to give the most reliable results 21 | // 22 | ifstream stat_stream("/proc/self/stat",ios_base::in); 23 | 24 | // dummy vars for leading entries in stat that we don't care about 25 | // 26 | string pid, comm, state, ppid, pgrp, session, tty_nr; 27 | string tpgid, flags, minflt, cminflt, majflt, cmajflt; 28 | string utime, stime, cutime, cstime, priority, nice; 29 | string O, itrealvalue, starttime; 30 | 31 | // 
the two fields we want 32 | // 33 | unsigned long vsize; 34 | long rss; 35 | 36 | stat_stream >> pid >> comm >> state >> ppid >> pgrp >> session >> tty_nr 37 | >> tpgid >> flags >> minflt >> cminflt >> majflt >> cmajflt 38 | >> utime >> stime >> cutime >> cstime >> priority >> nice 39 | >> O >> itrealvalue >> starttime >> vsize >> rss; // don't care about the rest 40 | 41 | stat_stream.close(); 42 | 43 | long page_size_kb = sysconf(_SC_PAGE_SIZE) / 1024; // in case x86-64 is configured to use 2MB pages 44 | *vm_usage = vsize / 1024.0; 45 | *resident_set = rss * page_size_kb; 46 | } 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /src/common/version.h: -------------------------------------------------------------------------------- 1 | #ifndef VERSION_H_ 2 | #define VERSION_H_ 3 | 4 | #ifdef VERSION_MAJOR 5 | #define VERIFIER_VERSION_MAJOR VERSION_MAJOR 6 | #else 7 | #define VERIFIER_VERSION_MAJOR 0 8 | #endif 9 | 10 | #ifdef VERSION_MINOR 11 | #define VERIFIER_VERSION_MINOR VERSION_MINOR 12 | #else 13 | #define VERIFIER_VERSION_MINOR 0 14 | #endif // DVERSION_MINOR 15 | 16 | #ifdef VERSION_PATCH 17 | #define VERIFIER_VERSION_PATCH VERSION_PATCH 18 | #else 19 | #define VERIFIER_VERSION_PATCH 0 20 | #endif 21 | 22 | #endif // VERSION_H_ -------------------------------------------------------------------------------- /src/libRestServer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRC_LIST "*.cpp") 2 | file(GLOB HEADERS "*.h") 3 | 4 | add_library(restserver STATIC ${SRC_LIST} ${HEADERS}) 5 | 6 | target_include_directories(restserver PRIVATE ..) 
7 | target_link_libraries(restserver Pistache common search) 8 | add_dependencies(restserver Pistache) 9 | # include_directories(../common) -------------------------------------------------------------------------------- /src/libRestServer/RequestHandler.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "RequestHandler.h" 5 | #include "common/easylog++.h" 6 | #include "libSearch/FaissInterface.h" 7 | #include "common/error.h" 8 | 9 | using namespace std; 10 | using namespace dev; 11 | using json = nlohmann::json; 12 | 13 | RequestHandler::RequestHandler(ConfigParams *_cp) : cp(_cp) 14 | { 15 | search.reset(new faissSearch(_cp->searchFactory, _cp->dimension, _cp->usegpu, true)); 16 | searchdays.reset(new faissSearch(cp->searchFactory, cp->dimension, cp->usegpu)); 17 | search_processor.reset(new SearchProcessor(_cp)); 18 | search_processor->RegistAPI(this); 19 | vector ids; 20 | vector features; 21 | if (!search->load(_cp->dataFilePath, ids, features)) 22 | { 23 | LOG(ERROR) << "load data failed." << endl; 24 | return; 25 | } 26 | else 27 | { 28 | LOG(INFO) << "load data success." << endl; 29 | bool res = search->add_with_ids((idx_t)ids.size(), features.data(), ids.data()); 30 | if (!res) 31 | { 32 | LOG(ERROR) << "add loaded data failed." << endl; 33 | return; 34 | } 35 | LOG(INFO) << "add loaded data success." 
<< endl; 36 | } 37 | } 38 | 39 | RequestHandler::~RequestHandler() {} 40 | 41 | string RequestHandler::FillResponse(const string &interface, json &response, const char *errorMsg, const TIMEPOINTS &timePoints) 42 | { 43 | response["time_used"] = (chrono::duration_cast(CURRENT_SYS_TIMEPOINT - timePoints.front())).count(); 44 | if (errorMsg) 45 | response["error_message"] = errorMsg; 46 | WriteLog(interface, response, timePoints); 47 | return response.dump(); 48 | } 49 | 50 | void RequestHandler::WriteLog(const string &interface, const json &resp, const TIMEPOINTS &timePoints) 51 | { 52 | char timeStr[10]; 53 | auto timeLong = chrono::duration_cast(CURRENT_SYS_TIMEPOINT - timePoints.front()); 54 | auto tp = chrono::system_clock::to_time_t(timePoints.front()); 55 | strftime(timeStr, sizeof(timeStr), "%H:%M:%S", std::localtime(&tp)); 56 | LOG(INFO) << "|Req Time: " << timeStr << left << "|Interface: " << interface << "|Time Used: " << timeLong.count() << " ms" 57 | << "|Result: " << resp << "|indexNum: " << search->ntotal << endl; 58 | } 59 | 60 | void RequestHandler::Reconfig(string const &request, string &responseBody, string &httpCode) 61 | { 62 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 63 | string interfaceName("/reconfig"); 64 | json value = json::parse(request); 65 | json resp; 66 | bool res = search->reset(); 67 | responseBody = FillResponse(interfaceName, resp, RECONFIG_SUCCESS, timepoints); 68 | httpCode = OK; 69 | } 70 | 71 | void RequestHandler::Add(string const &request, string &responseBody, string &httpCode) 72 | { 73 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 74 | string interfaceName("/add"); 75 | json value = json::parse(request); 76 | json resp; 77 | if (value["ntotal"].is_null() || value["data"].is_null()) 78 | { 79 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 80 | httpCode = BAD_REQUEST; 81 | return; 82 | } 83 | if (!value["ntotal"].is_number()) 84 | { 85 | responseBody = FillResponse(interfaceName, resp, 
INVALID_NTOTAL); 86 | httpCode = BAD_REQUEST; 87 | return; 88 | } 89 | if (!value["data"].is_object()) 90 | { 91 | responseBody = FillResponse(interfaceName, resp, INVALID_DATA); 92 | httpCode = BAD_REQUEST; 93 | return; 94 | } 95 | int n = value["ntotal"]; 96 | vector ids(n, 0); 97 | unsigned int dim = cp->dimension; 98 | vector features; 99 | features.reserve(n * dim); 100 | json object = value["data"]; 101 | if (object.size() == 0) 102 | { 103 | responseBody = FillResponse(interfaceName, resp, EMPTY_DATA); 104 | httpCode = BAD_REQUEST; 105 | return; 106 | } 107 | unordered_map queries = object; 108 | int count = 0; 109 | for (auto &query : queries) 110 | { 111 | ids[count] = stol(query.first); 112 | if (query.second.size() != dim) 113 | { 114 | responseBody = FillResponse(interfaceName, resp, INVALID_DIMENSION); 115 | httpCode = BAD_REQUEST; 116 | return; 117 | } 118 | features.insert(features.end(), query.second.begin(), query.second.end()); 119 | count++; 120 | } 121 | 122 | bool res = search->add_with_ids(n, features.data(), ids.data()); 123 | responseBody = FillResponse(interfaceName, resp, res ? 
ADD_SUCCESS : ADD_FAIL, timepoints); 124 | httpCode = OK; 125 | } 126 | 127 | void RequestHandler::Query(string const &request, string &responseBody, string &httpCode) 128 | { 129 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 130 | string interfaceName("/search"); 131 | json value = json::parse(request); 132 | json resp = json::object(); 133 | if (value["topk"].is_null() || value["qtotal"].is_null() || value["queries"].is_null()) 134 | { 135 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 136 | httpCode = BAD_REQUEST; 137 | return; 138 | } 139 | if (!value["topk"].is_number()) 140 | { 141 | responseBody = FillResponse(interfaceName, resp, INVALID_TOPK); 142 | httpCode = BAD_REQUEST; 143 | return; 144 | } 145 | if (!value["qtotal"].is_number()) 146 | { 147 | responseBody = FillResponse(interfaceName, resp, INVALID_QTOTAL); 148 | httpCode = BAD_REQUEST; 149 | return; 150 | } 151 | if (!value["queries"].is_object()) 152 | { 153 | responseBody = FillResponse(interfaceName, resp, INVALID_QUERIES); 154 | httpCode = BAD_REQUEST; 155 | return; 156 | } 157 | unsigned int dim = cp->dimension; 158 | unsigned int n = value["qtotal"]; 159 | idx_t k = value["topk"]; 160 | vector features; 161 | features.reserve(n * dim); 162 | vector ids(n); 163 | json object = value["queries"]; 164 | if (object.size() == 0 || n == 0) 165 | { 166 | responseBody = FillResponse(interfaceName, resp, EMPTY_QUERIES); 167 | httpCode = BAD_REQUEST; 168 | return; 169 | } 170 | unordered_map> queries = object; 171 | int count = 0; 172 | for (auto &query : queries) 173 | { 174 | ids[count] = query.first; 175 | if (query.second.size() != dim) 176 | { 177 | responseBody = FillResponse(interfaceName, resp, INVALID_DIMENSION); 178 | httpCode = BAD_REQUEST; 179 | return; 180 | } 181 | features.insert(features.end(), query.second.begin(), query.second.end()); 182 | count++; 183 | } 184 | 185 | vector resDistance(n * k, 0); 186 | vector resLabels(n * k, 0); 187 | bool res = 
search->search(n, features.data(), k, resDistance.data(), resLabels.data()); 188 | if (search->ntotal == 0) 189 | { 190 | resDistance.assign(n * k, -1); 191 | } 192 | json queryMap; 193 | json resultMap; 194 | vector dis(k, 0); 195 | vector label(k, 0); 196 | for (size_t i = 0; i < resDistance.size(); i++) 197 | { 198 | resDistance[i] = resDistance[i] * 0.5 + 0.5; 199 | } 200 | for (size_t i = 0; i < n; i++) 201 | { 202 | label.assign(resLabels.begin() + i * k, resLabels.begin() + i * k + k); 203 | queryMap["labels"] = label; 204 | dis.assign(resDistance.begin() + i * k, resDistance.begin() + i * k + k); 205 | queryMap["distance"] = dis; 206 | resultMap[ids[i]] = queryMap; 207 | } 208 | resp["result"] = resultMap; 209 | responseBody = FillResponse(interfaceName, resp, res ? SEARCH_SUCCESS : SEARCH_FAIL, timepoints); 210 | httpCode = OK; 211 | } 212 | 213 | void RequestHandler::QueryRange(string const &request, string &responseBody, string &httpCode) 214 | { 215 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 216 | string interfaceName("/searchRange"); 217 | json value = json::parse(request); 218 | json resp = json::object(); 219 | if (value["nq"].is_null() || value["radius"].is_null() || value["queries"].is_null()) 220 | { 221 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 222 | httpCode = BAD_REQUEST; 223 | return; 224 | } 225 | if (!value["nq"].is_number()) 226 | { 227 | responseBody = FillResponse(interfaceName, resp, INVALID_NQ); 228 | httpCode = BAD_REQUEST; 229 | return; 230 | } 231 | 232 | if (!value["radius"].is_number()) 233 | { 234 | responseBody = FillResponse(interfaceName, resp, INVALID_RADIUS); 235 | httpCode = BAD_REQUEST; 236 | return; 237 | } 238 | if (!value["queries"].is_object()) 239 | { 240 | responseBody = FillResponse(interfaceName, resp, INVALID_QUERIES); 241 | httpCode = BAD_REQUEST; 242 | return; 243 | } 244 | float radius = value["radius"]; 245 | idx_t nq = value["nq"]; 246 | unsigned int dim = cp->dimension; 
247 | vector features; 248 | features.reserve(nq * dim); 249 | vector ids(nq); 250 | json object = value["queries"]; 251 | if (object.size() == 0 || nq == 0) 252 | { 253 | responseBody = FillResponse(interfaceName, resp, EMPTY_QUERIES); 254 | httpCode = BAD_REQUEST; 255 | return; 256 | } 257 | unordered_map> queries = object; 258 | int count = 0; 259 | for (auto &query : queries) 260 | { 261 | ids[count] = query.first; 262 | if (query.second.size() != dim) 263 | { 264 | responseBody = FillResponse(interfaceName, resp, INVALID_DIMENSION); 265 | httpCode = BAD_REQUEST; 266 | return; 267 | } 268 | features.insert(features.end(), query.second.begin(), query.second.end()); 269 | count++; 270 | } 271 | 272 | faiss::RangeSearchResult *result = new faiss::RangeSearchResult(nq); 273 | bool res = search->search_range(nq, features.data(), radius, result); 274 | json queryMap; 275 | json resultMap; 276 | size_t num = 0; 277 | for (idx_t i = 0; i < nq; i++) 278 | { 279 | num = result->lims[i + 1] - result->lims[i]; 280 | vector dis(num, 0); 281 | vector label(num, 0); 282 | label.assign(result->labels + result->lims[i], result->labels + result->lims[i + 1]); 283 | dis.assign(result->distances + result->lims[i], result->distances + result->lims[i + 1]); 284 | queryMap["labels"] = label; 285 | queryMap["distance"] = dis; 286 | resultMap[ids[i]] = queryMap; 287 | } 288 | resp["result"] = resultMap; 289 | 290 | responseBody = FillResponse(interfaceName, resp, res ? 
SEARCH_RANGE_SUCCESS : SEARCH_RANGE_FAIL, timepoints); 291 | httpCode = OK; 292 | } 293 | 294 | void RequestHandler::Remove(string const &request, string &responseBody, string &httpCode) 295 | { 296 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 297 | string interfaceName("/delete"); 298 | json value = json::parse(request); 299 | json resp; 300 | if (value["ids"].is_null()) 301 | { 302 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 303 | httpCode = BAD_REQUEST; 304 | return; 305 | } 306 | if (!value["ids"].is_array()) 307 | { 308 | responseBody = FillResponse(interfaceName, resp, INVALID_IDS); 309 | httpCode = BAD_REQUEST; 310 | return; 311 | } 312 | json list = value["ids"]; 313 | if (list.size() == 0) 314 | { 315 | responseBody = FillResponse(interfaceName, resp, EMPTY_DATA); 316 | httpCode = BAD_REQUEST; 317 | return; 318 | } 319 | vector ids = list; 320 | bool res = true; 321 | long nremove = 0; 322 | long location = 0; 323 | long sum = 0; 324 | for (size_t i = 0; i < ids.size(); i++) 325 | { 326 | if (i == ids.size() - 1) 327 | { 328 | location = 1; 329 | if (i == 0) 330 | location = 2; 331 | } 332 | if (!search->remove_ids(faiss::IDSelectorRange(ids[i], ids[i] + 1), nremove, location)) 333 | { 334 | res = false; 335 | break; 336 | } 337 | location = -1; 338 | sum += nremove; 339 | } 340 | resp["ndelete"] = sum; 341 | responseBody = FillResponse(interfaceName, resp, res ? 
DELETE_SUCCESS : DELETE_FAIL, timepoints); 342 | httpCode = OK; 343 | } 344 | 345 | void RequestHandler::RemoveRange(string const &request, string &responseBody, string &httpCode) 346 | { 347 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 348 | string interfaceName("/deleteRange"); 349 | json value = json::parse(request); 350 | json resp; 351 | if (value["start"].is_null() || value["end"].is_null()) 352 | { 353 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 354 | httpCode = BAD_REQUEST; 355 | return; 356 | } 357 | if (!value["start"].is_number() || !value["end"].is_number()) 358 | { 359 | responseBody = FillResponse(interfaceName, resp, INVALID_RANGE); 360 | httpCode = BAD_REQUEST; 361 | return; 362 | } 363 | long st = value["start"]; 364 | long ed = value["end"]; 365 | bool res = true; 366 | long nremove = 0; 367 | res = search->remove_ids_range(faiss::IDSelectorRange(st, ed + 1), nremove); 368 | 369 | resp["ndelete"] = nremove; 370 | responseBody = FillResponse(interfaceName, resp, res ? 
DELETERANGE_SUCCESS : DELETERANGE_FAIL, timepoints); 371 | httpCode = OK; 372 | } 373 | void RequestHandler::QueryDays(string const &request, string &responseBody, string &httpCode) 374 | { 375 | TIMEPOINTS timepoints(1, CURRENT_SYS_TIMEPOINT); 376 | string interfaceName("/searchDays"); 377 | json value = json::parse(request); 378 | json resp = json::object(); 379 | if (value["topk"].is_null() || value["qtotal"].is_null() || value["queries"].is_null() || value["days"].is_null()) 380 | { 381 | responseBody = FillResponse(interfaceName, resp, MISSING_ARGUMENTS); 382 | httpCode = BAD_REQUEST; 383 | return; 384 | } 385 | if (!value["topk"].is_number()) 386 | { 387 | responseBody = FillResponse(interfaceName, resp, INVALID_TOPK); 388 | httpCode = BAD_REQUEST; 389 | return; 390 | } 391 | if (!value["qtotal"].is_number()) 392 | { 393 | responseBody = FillResponse(interfaceName, resp, INVALID_QTOTAL); 394 | httpCode = BAD_REQUEST; 395 | return; 396 | } 397 | if (!value["queries"].is_object()) 398 | { 399 | responseBody = FillResponse(interfaceName, resp, INVALID_QUERIES); 400 | httpCode = BAD_REQUEST; 401 | return; 402 | } 403 | if (!value["days"].is_array()) 404 | { 405 | responseBody = FillResponse(interfaceName, resp, INVALID_DAYS); 406 | httpCode = BAD_REQUEST; 407 | return; 408 | } 409 | unsigned int dim = cp->dimension; 410 | unsigned int n = value["qtotal"]; 411 | idx_t k = value["topk"]; 412 | vector features; 413 | features.reserve(n * dim); 414 | vector ids(n); 415 | vector days = value["days"]; 416 | unsigned int daysNum = days.size(); 417 | 418 | json object = value["queries"]; 419 | if (object.size() == 0 || n == 0) 420 | { 421 | responseBody = FillResponse(interfaceName, resp, EMPTY_QUERIES); 422 | httpCode = BAD_REQUEST; 423 | return; 424 | } 425 | unordered_map> queries = object; 426 | int count = 0; 427 | for (auto &query : queries) 428 | { 429 | ids[count] = query.first; 430 | if (query.second.size() != dim) 431 | { 432 | responseBody = 
FillResponse(interfaceName, resp, INVALID_DIMENSION); 433 | httpCode = BAD_REQUEST; 434 | return; 435 | } 436 | features.insert(features.end(), query.second.begin(), query.second.end()); 437 | count++; 438 | } 439 | 440 | vector resDistance(n * k, 0); 441 | vector resLabels(n * k, 0); 442 | 443 | list::iterator it = recentIndex.begin(); 444 | IndexPair st = *it; 445 | it++; 446 | IndexPair md = *it; 447 | it++; 448 | IndexPair ed = *it; 449 | string filePath; 450 | bool res; 451 | vector> ndis(n); 452 | vector> nlab(n); 453 | for (size_t i = 0; i < n; i++) 454 | { 455 | ndis[i].reserve(k * daysNum); 456 | nlab[i].reserve(k * daysNum); 457 | } 458 | for (size_t i = 0; i < daysNum; i++) 459 | { 460 | if (days[i] == st.date) 461 | searchdays = st.index; 462 | else if (days[i] == md.date) 463 | searchdays = md.index; 464 | else if (days[i] == ed.date) 465 | searchdays = ed.index; 466 | else 467 | { 468 | filePath = "../indexFile/" + days[i] + ".faissIndex"; 469 | if ((access(filePath.c_str(), 0)) == -1) 470 | { 471 | responseBody = FillResponse(interfaceName, resp, NOTEXIST_INDEX); 472 | httpCode = BAD_REQUEST; 473 | return; 474 | } 475 | searchdays->read_index(filePath.c_str()); 476 | } 477 | res = searchdays->search(n, features.data(), k, resDistance.data(), resLabels.data()); 478 | if (!res) 479 | { 480 | responseBody = FillResponse(interfaceName, resp, SEARCHDAYS_FAIL); 481 | httpCode = OK; 482 | return; 483 | } 484 | for (size_t i = 0; i < n; i++) 485 | { 486 | nlab[i].insert(nlab[i].end(), resLabels.begin() + i * k, resLabels.begin() + i * k + k); 487 | ndis[i].insert(ndis[i].end(), resDistance.begin() + i * k, resDistance.begin() + i * k + k); 488 | } 489 | } 490 | 491 | json queryMap; 492 | json resultMap; 493 | vector kdis(k, 0); 494 | vector klabel(k, 0); 495 | Ids_lab_pair ipair; 496 | ids_lab.resize(k * daysNum); 497 | for (size_t i = 0; i < n; i++) 498 | { 499 | for (size_t j = 0; j < ndis[i].size(); j++) 500 | { 501 | ipair.dis = ndis[i][j]; 502 | 
ipair.label = nlab[i][j]; 503 | ids_lab[j] = (ipair); 504 | } 505 | sort(ids_lab.begin(), ids_lab.end(), disSort); 506 | for (idx_t i = 0; i < k; i++) 507 | { 508 | klabel[i] = ids_lab[i].label; 509 | kdis[i] = ids_lab[i].dis; 510 | } 511 | 512 | queryMap["labels"] = klabel; 513 | queryMap["distance"] = kdis; 514 | resultMap[ids[i]] = queryMap; 515 | } 516 | 517 | resp["result"] = resultMap; 518 | responseBody = FillResponse(interfaceName, resp, SEARCHDAYS_SUCCESS, timepoints); 519 | httpCode = OK; 520 | } 521 | -------------------------------------------------------------------------------- /src/libRestServer/RequestHandler.h: -------------------------------------------------------------------------------- 1 | #ifndef REQUEST_HANDLER_H_ 2 | #define REQUEST_HANDLER_H_ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "common/json.h" 11 | #include "common/configParams.h" 12 | #include "libSearch/FaissInterface.h" 13 | #include "SearchProcessor.h" 14 | #define CURRENT_SYS_TIMEPOINT std::chrono::system_clock::now() 15 | #define CURRENT_STEADY_TIMEPOINT std::chrono::steady_clock::now() 16 | typedef std::vector TIMEPOINTS; 17 | 18 | namespace dev 19 | { 20 | /** 21 | * @brief handle request 22 | * 23 | */ 24 | class RequestHandler 25 | { 26 | 27 | public: 28 | struct IndexPair 29 | { 30 | std::string date; 31 | std::shared_ptr index; 32 | }; 33 | std::list recentIndex; 34 | 35 | struct Ids_lab_pair 36 | { 37 | float dis; 38 | long label; 39 | }; 40 | std::vector ids_lab; 41 | 42 | RequestHandler(dev::ConfigParams *_cp); 43 | ~RequestHandler(); 44 | /// @brief Add vectors to the index 45 | void Add(std::string const &request, std::string &response, string &httpCode); 46 | /// @brief for each query vector, find its k nearest neighbors in the database 47 | void Query(std::string const &request, std::string &response, string &httpCode); 48 | /// @brief for each query vector, find all vectors with distance < radius 49 | 
void QueryRange(std::string const &request, std::string &response, string &httpCode); 50 | /// @brief remove id from index 51 | void Remove(std::string const &request, std::string &response, string &httpCode); 52 | /// @brief remove ids 53 | void RemoveRange(std::string const &request, std::string &response, string &httpCode); 54 | /// @brief reload data 55 | void Reconfig(std::string const &request, std::string &response, string &httpCode); 56 | void QueryDays(std::string const &request, std::string &response, string &httpCode); 57 | static bool disSort(Ids_lab_pair &a, Ids_lab_pair &b) { return (a.dis < b.dis); } 58 | 59 | std::shared_ptr search = nullptr; 60 | std::shared_ptr searchdays = nullptr; 61 | /** 62 | * @brief generate response string 63 | * 64 | * @param interface name 65 | * @param response json 66 | * @param errorMsg 67 | * @param timePoints 68 | * @return std::string 69 | */ 70 | std::string FillResponse(const std::string &interface, nlohmann::json &response, const char *errorMsg = NULL, const TIMEPOINTS &timePoints = TIMEPOINTS(1, CURRENT_SYS_TIMEPOINT)); 71 | 72 | private: 73 | void WriteLog(const std::string &interface, const nlohmann::json &resp, const TIMEPOINTS &timePoints); 74 | const dev::ConfigParams *cp = nullptr; 75 | std::shared_ptr search_processor = nullptr; 76 | 77 | }; // class RequestHandler 78 | } // namespace dev 79 | 80 | #endif //REQUEST_HANDLER_H_ -------------------------------------------------------------------------------- /src/libRestServer/RestServer.cpp: -------------------------------------------------------------------------------- 1 | #include "RestServer.h" 2 | 3 | using namespace std; 4 | using namespace Pistache; 5 | 6 | namespace dev 7 | { 8 | 9 | void handleReady(const Rest::Request &, Http::ResponseWriter response) 10 | { 11 | response.send(Http::Code::Ok, "ready"); 12 | } 13 | } 14 | 15 | RestServer::RestServer(dev::ConfigParams *_cp) 16 | : cp(_cp), httpEndpoint(new Http::Endpoint(Address(cp->ip, 
Port(cp->port)))), 17 | handler(new dev::RequestHandler(_cp)) 18 | { 19 | init(cp->threads); 20 | } 21 | 22 | void RestServer::init(size_t thr) 23 | { 24 | auto opts = Http::Endpoint::options() 25 | .threads(thr) 26 | .flags(Tcp::Options::InstallSignalHandler | Tcp::Options::ReuseAddr); 27 | httpEndpoint->init(opts); 28 | setupRoutes(); 29 | } 30 | 31 | void RestServer::setupRoutes() 32 | { 33 | using namespace Rest; 34 | Routes::Post(router, "/add", Routes::bind(&RestServer::AddData, this)); 35 | Routes::Post(router, "/search", Routes::bind(&RestServer::SearchData, this)); 36 | Routes::Post(router, "/searchRange", Routes::bind(&RestServer::SearchRange, this)); 37 | Routes::Post(router, "/searchDays", Routes::bind(&RestServer::SearchDays, this)); 38 | Routes::Post(router, "/delete", Routes::bind(&RestServer::DeleteData, this)); 39 | Routes::Post(router, "/deleteRange", Routes::bind(&RestServer::DeleteRange, this)); 40 | Routes::Post(router, "/reconfig", Routes::bind(&RestServer::ReconfigData, this)); 41 | Routes::Get(router, "/ready", Routes::bind(&dev::handleReady)); 42 | } 43 | 44 | void RestServer::AddData(const Rest::Request &request, Http::ResponseWriter response) 45 | { 46 | string responseBody; 47 | string httpCode; 48 | handler->Add(request.body(), responseBody, httpCode); 49 | if (httpCode == "200") 50 | response.send(Http::Code::Ok, responseBody); 51 | else 52 | response.send(Http::Code::Bad_Request, responseBody); 53 | } 54 | 55 | void RestServer::SearchData(const Rest::Request &request, Http::ResponseWriter response) 56 | { 57 | string responseBody; 58 | string httpCode; 59 | handler->Query(request.body(), responseBody, httpCode); 60 | if (httpCode == "200") 61 | response.send(Http::Code::Ok, responseBody); 62 | else 63 | response.send(Http::Code::Bad_Request, responseBody); 64 | } 65 | 66 | void RestServer::SearchRange(const Rest::Request &request, Http::ResponseWriter response) 67 | { 68 | string responseBody; 69 | string httpCode; 70 | 
handler->QueryRange(request.body(), responseBody, httpCode); 71 | if (httpCode == "200") 72 | response.send(Http::Code::Ok, responseBody); 73 | else 74 | response.send(Http::Code::Bad_Request, responseBody); 75 | } 76 | 77 | void RestServer::SearchDays(const Rest::Request &request, Http::ResponseWriter response) 78 | { 79 | string responseBody; 80 | string httpCode; 81 | handler->QueryDays(request.body(), responseBody, httpCode); 82 | if (httpCode == "200") 83 | response.send(Http::Code::Ok, responseBody); 84 | else 85 | response.send(Http::Code::Bad_Request, responseBody); 86 | } 87 | 88 | void RestServer::DeleteData(const Rest::Request &request, Http::ResponseWriter response) 89 | { 90 | string responseBody; 91 | string httpCode; 92 | handler->Remove(request.body(), responseBody, httpCode); 93 | if (httpCode == "200") 94 | response.send(Http::Code::Ok, responseBody); 95 | else 96 | response.send(Http::Code::Bad_Request, responseBody); 97 | } 98 | 99 | void RestServer::DeleteRange(const Rest::Request &request, Http::ResponseWriter response) 100 | { 101 | string responseBody; 102 | string httpCode; 103 | handler->RemoveRange(request.body(), responseBody, httpCode); 104 | if (httpCode == "200") 105 | response.send(Http::Code::Ok, responseBody); 106 | else 107 | response.send(Http::Code::Bad_Request, responseBody); 108 | } 109 | 110 | void RestServer::ReconfigData(const Rest::Request &request, Http::ResponseWriter response) 111 | { 112 | string responseBody; 113 | string httpCode; 114 | handler->Reconfig(request.body(), responseBody, httpCode); 115 | if (httpCode == "200") 116 | response.send(Http::Code::Ok, responseBody); 117 | else 118 | response.send(Http::Code::Bad_Request, responseBody); 119 | } -------------------------------------------------------------------------------- /src/libRestServer/RestServer.h: -------------------------------------------------------------------------------- 1 | #ifndef REST_SERVER_H_ 2 | #define REST_SERVER_H_ 3 | #include 4 | 
#include 5 | 6 | #include 7 | #include 8 | #include 9 | #include "common/configParams.h" 10 | #include "RequestHandler.h" 11 | 12 | using namespace Pistache; 13 | 14 | /** 15 | * @brief http server 16 | * 17 | */ 18 | class RestServer 19 | { 20 | public: 21 | RestServer(dev::ConfigParams *_cp); 22 | /** 23 | * @brief init httpEndpoint 24 | * 25 | * @param thr http threads 26 | */ 27 | void init(size_t thr = 2); 28 | 29 | void start() 30 | { 31 | httpEndpoint->setHandler(router.handler()); 32 | std::cout << "Restserver is started." << std::endl; 33 | httpEndpoint->serve(); 34 | } 35 | 36 | void shutdown() 37 | { 38 | httpEndpoint->shutdown(); 39 | } 40 | 41 | private: 42 | /// @brief set routes, according request url and routes to decide call witch function 43 | void setupRoutes(); 44 | void AddData(const Rest::Request &request, Http::ResponseWriter response); 45 | void SearchData(const Rest::Request &request, Http::ResponseWriter response); 46 | void SearchRange(const Rest::Request &request, Http::ResponseWriter response); 47 | void SearchDays(const Rest::Request &request, Http::ResponseWriter response); 48 | void DeleteData(const Rest::Request &request, Http::ResponseWriter response); 49 | void DeleteRange(const Rest::Request &request, Http::ResponseWriter response); 50 | void ReconfigData(const Rest::Request &request, Http::ResponseWriter response); 51 | 52 | typedef std::mutex Lock; 53 | typedef std::lock_guard Guard; 54 | dev::ConfigParams *cp; 55 | std::shared_ptr httpEndpoint; 56 | Rest::Router router; 57 | std::unique_ptr handler; 58 | }; 59 | 60 | #endif //REST_SERVER_H_ -------------------------------------------------------------------------------- /src/libRestServer/SearchProcessor.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "RequestHandler.h" 5 | #include "common/easylog++.h" 6 | #include "SearchProcessor.h" 7 | #include "common/error.h" 8 | 9 | using namespace 
std; 10 | using namespace dev; 11 | using json = nlohmann::json; 12 | 13 | string SearchProcessor::getDate() 14 | { 15 | time_t nowtime = time(NULL); 16 | char tmp[64]; 17 | strftime(tmp, sizeof(tmp), "%Y%m%d", localtime(&nowtime)); 18 | string date(tmp); 19 | return date; 20 | } 21 | 22 | void SearchProcessor::doWork() 23 | { 24 | while (!exitThread) 25 | { 26 | std::this_thread::sleep_for(std::chrono::seconds(1)); 27 | string filePath; 28 | string today = getDate(); 29 | time_t now = time(0); 30 | tm *ltm = localtime(&now); 31 | 32 | string current = std::to_string(ltm->tm_hour) + ":" + std::to_string(ltm->tm_min) + ":" + std::to_string(ltm->tm_sec); 33 | RequestHandler::IndexPair idxp; 34 | idxp.date = "null"; 35 | idxp.index = nullptr; 36 | handler->recentIndex.push_back(idxp); 37 | handler->recentIndex.push_back(idxp); 38 | handler->recentIndex.push_back(idxp); 39 | 40 | if (current == "0:0:0") 41 | { 42 | filePath = "../indexFile/" + today + ".faissIndex"; 43 | searchprocessor = handler->search; 44 | searchprocessor->write_index(filePath.c_str()); 45 | 46 | if (handler->recentIndex.size() < 3) 47 | { 48 | idxp.date = today; 49 | idxp.index = searchprocessor; 50 | handler->recentIndex.push_back(idxp); 51 | } 52 | else if (handler->recentIndex.size() == 3) 53 | { 54 | handler->recentIndex.erase(handler->recentIndex.begin()); 55 | idxp.date = today; 56 | idxp.index = searchprocessor; 57 | handler->recentIndex.push_back(idxp); 58 | } 59 | searchprocessor.reset(new faissSearch(searchMethod, d, usegpu)); 60 | handler->search = searchprocessor; 61 | } 62 | } 63 | } 64 | 65 | SearchProcessor::SearchProcessor(dev::ConfigParams *_cp) : exitThread(false) 66 | { 67 | searchMethod = _cp->searchFactory; 68 | usegpu = _cp->usegpu; 69 | d = _cp->dimension; 70 | searchprocessor.reset(new faissSearch(searchMethod, d, usegpu)); 71 | try 72 | { 73 | threads = std::thread(&SearchProcessor::doWork, this); 74 | } 75 | catch (std::exception &e) 76 | { 77 | LOG(ERROR) << "FaceData 
thread init failed!"; 78 | LOG(ERROR) << e.what(); 79 | } 80 | catch (...) 81 | { 82 | LOG(ERROR) << "FaceData thread init failed!"; 83 | } 84 | } 85 | 86 | SearchProcessor::~SearchProcessor() 87 | { 88 | exitThread = true; 89 | if (threads.joinable()) 90 | { 91 | threads.join(); 92 | } 93 | } 94 | 95 | void SearchProcessor::RegistAPI(RequestHandler *handler_) 96 | { 97 | handler = handler_; 98 | } -------------------------------------------------------------------------------- /src/libRestServer/SearchProcessor.h: -------------------------------------------------------------------------------- 1 | #ifndef SEARCH_PROCESS_H_ 2 | #define SEARCH_PROCESS_H_ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "common/easylog++.h" 12 | #include "common/configParams.h" 13 | 14 | namespace dev 15 | { 16 | class RequestHandler; 17 | } 18 | namespace dev 19 | { 20 | 21 | class SearchProcessor 22 | { 23 | public: 24 | SearchProcessor(dev::ConfigParams *_cp); 25 | ~SearchProcessor(); 26 | void RegistAPI(RequestHandler *handler_); 27 | std::string getDate(); 28 | void doWork(); 29 | 30 | std::thread threads; 31 | std::string current; 32 | std::atomic exitThread; 33 | std::string searchMethod; 34 | bool usegpu; 35 | int d; 36 | std::shared_ptr searchprocessor; 37 | RequestHandler *handler; 38 | }; 39 | } 40 | #endif //SEARCH_PROCESS_H_ 41 | -------------------------------------------------------------------------------- /src/libSearch/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SRC_LIST "*.cpp") 2 | file(GLOB HEADERS "*.h") 3 | 4 | add_library(search STATIC ${SRC_LIST} ${HEADERS}) 5 | target_include_directories(search PRIVATE .. 
${FAISS_INCLUDE_DIR}) 6 | 7 | if(NOT CUDA_FOUND) 8 | message(STATUS "Build ${CMAKE_PROJECT_NAME} without CUDA support.") 9 | include_directories(${FAISS_INCLUDE_DIR}) 10 | target_link_libraries(search Faiss::CPU blas lapack) 11 | elseif(CUDA_FOUND) 12 | add_definitions(-DCUDA_VERSION=${CUDA_VERSION_STRING}) 13 | message(STATUS "Build ${CMAKE_PROJECT_NAME} with CUDA : " ${CUDA_VERSION}) 14 | include_directories(${FAISS_INCLUDE_DIR} ${CUDA_INCLUDE_DIRS}) 15 | target_link_libraries(search Faiss::GPU Faiss::CPU blas lapack ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES}) 16 | endif() 17 | 18 | add_dependencies(search Faiss::CPU) 19 | -------------------------------------------------------------------------------- /src/libSearch/FaissInterface.cpp: -------------------------------------------------------------------------------- 1 | #include "common/easylog++.h" 2 | #include "FaissInterface.h" 3 | 4 | #ifdef CUDA_VERSION 5 | #include "faiss/gpu/GpuAutoTune.h" 6 | #include "faiss/gpu/StandardGpuResources.h" 7 | #include "faiss/gpu/utils/DeviceUtils.h" 8 | #endif 9 | 10 | using namespace std; 11 | using namespace dev; 12 | 13 | faissSearch::faissSearch(const string &indexKey, const int dimension, bool useGPU, bool initGpuResources, faiss::MetricType metric) 14 | { 15 | faissIndex.reset(faiss::index_factory(dimension, indexKey.c_str(), metric)); 16 | is_trained = faissIndex->is_trained; 17 | usegpu = useGPU; 18 | ntotal = faissIndex->ntotal; 19 | dim = dimension; 20 | if (useGPU && initGpuResources) 21 | { 22 | #ifdef CUDA_VERSION 23 | ngpus = faiss::gpu::getNumDevices(); 24 | for (int i = 0; i < ngpus; i++) 25 | { 26 | res.push_back(new faiss::gpu::StandardGpuResources); 27 | devs.push_back(i); 28 | } 29 | options->indicesOptions = faiss::gpu::INDICES_64_BIT; 30 | options->useFloat16CoarseQuantizer = false; 31 | options->useFloat16 = false; 32 | options->usePrecomputed = false; 33 | options->reserveVecs = 0; 34 | options->storeTransposed = false; 35 | options->verbose = true; 
36 | initGpuResources = true; 37 | #else 38 | LOG(WARNING) << "This release doesn't support GPU search"; 39 | #endif 40 | } 41 | if (usegpu) 42 | { 43 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get())); 44 | } 45 | } 46 | 47 | bool faissSearch::reset() 48 | { 49 | try 50 | { 51 | faissIndex->reset(); 52 | ntotal = faissIndex->ntotal; 53 | } 54 | catch (std::exception &e) 55 | { 56 | LOG(ERROR) << e.what(); 57 | return false; 58 | } 59 | return true; 60 | } 61 | 62 | bool faissSearch::load(const string &filePath, unordered_map> &data) 63 | { 64 | ifstream in(filePath, ifstream::binary | ifstream::in); 65 | if (!in.is_open()) 66 | { 67 | return false; 68 | } 69 | int header[5] = {0}; 70 | in.read((char *)header, sizeof(header)); 71 | int dimension = header[0]; 72 | unsigned int count = header[1]; 73 | int fileFormatVersion = header[2]; 74 | if (fileFormatVersion != 1 || dim != dimension) 75 | { 76 | in.close(); 77 | return false; 78 | } 79 | unsigned int id = 0; 80 | for (size_t i = 0; i < count; i++) 81 | { 82 | vector feature(dim, 0); 83 | in.read((char *)&id, sizeof(int)); 84 | in.read((char *)feature.data(), dim * sizeof(float)); 85 | data[static_cast(id)] = move(feature); 86 | } 87 | in.close(); 88 | return true; 89 | } 90 | 91 | bool faissSearch::load(const string &filePath, vector &ids, vector &data) 92 | { 93 | ifstream in(filePath, ifstream::binary | ifstream::in); 94 | if (!in.is_open()) 95 | { 96 | return false; 97 | } 98 | int header[5] = {0}; 99 | in.read((char *)header, sizeof(header)); 100 | int dimension = header[0]; 101 | unsigned int count = header[1]; 102 | int fileFormatVersion = header[2]; 103 | if (fileFormatVersion != 1 || dim != dimension) 104 | { 105 | in.close(); 106 | return false; 107 | } 108 | data.clear(); 109 | data.resize(count * dim); 110 | ids.clear(); 111 | ids.resize(count); 112 | unsigned int id = 0; 113 | vector feature(dim); 114 | for (size_t i = 0; i < count; i++) 115 | { 116 | 117 | 
in.read((char *)&id, sizeof(int)); 118 | ids[i] = static_cast(id); 119 | in.read((char *)&data[i * dim], dim * sizeof(float)); 120 | } 121 | in.close(); 122 | return true; 123 | } 124 | 125 | bool faissSearch::write_index(const char *filePath) 126 | { 127 | try 128 | { 129 | if (usegpu) 130 | { 131 | faissIndex.reset(faiss::gpu::index_gpu_to_cpu(faissIndex.get())); 132 | } 133 | faiss::write_index(faissIndex.get(), filePath); 134 | if (usegpu) 135 | { 136 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 137 | } 138 | } 139 | catch (std::exception &e) 140 | { 141 | LOG(ERROR) << e.what(); 142 | if (usegpu) 143 | { 144 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 145 | } 146 | return false; 147 | } 148 | return true; 149 | } 150 | 151 | bool faissSearch::read_index(const char *filePath) 152 | { 153 | try 154 | { 155 | faissIndex.reset(faiss::read_index(filePath)); 156 | if (usegpu) 157 | { 158 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 159 | } 160 | ntotal = faissIndex->ntotal; 161 | } 162 | catch (std::exception &e) 163 | { 164 | LOG(ERROR) << e.what(); 165 | return false; 166 | } 167 | return true; 168 | } 169 | 170 | void faissSearch::train(idx_t n, const float *data) 171 | { 172 | if (!is_trained) 173 | faissIndex->train(n, data); 174 | is_trained = faissIndex->is_trained; 175 | } 176 | 177 | void faissSearch::add(const vector> &data) 178 | { 179 | for (auto item : data) 180 | faissIndex->add(1, item.data()); 181 | ntotal += data.size(); 182 | } 183 | 184 | void faissSearch::add(idx_t n, const vector &data) 185 | { 186 | 187 | faissIndex->add(n, data.data()); 188 | ntotal += n; 189 | } 190 | 191 | bool faissSearch::add_with_ids(idx_t n, const float *xdata, const long *xids) 192 | { 193 | try 194 | { 195 | faissIndex->add_with_ids(n, xdata, xids); 196 | ntotal += n; 197 | } 198 | catch (std::exception &e) 199 | { 
200 | LOG(ERROR) << e.what(); 201 | return false; 202 | } 203 | return true; 204 | } 205 | 206 | bool faissSearch::search(idx_t n, const float *data, idx_t k, float *distances, long *labels) const 207 | { 208 | try 209 | { 210 | faissIndex->search(n, data, k, distances, labels); 211 | } 212 | catch (std::exception &e) 213 | { 214 | LOG(ERROR) << e.what(); 215 | return false; 216 | } 217 | return true; 218 | } 219 | 220 | bool faissSearch::search_range(idx_t n, const float *data, float radius, faiss::RangeSearchResult *result) 221 | { 222 | try 223 | { 224 | if (usegpu) 225 | { 226 | faissIndex.reset(faiss::gpu::index_gpu_to_cpu(faissIndex.get())); 227 | } 228 | faissIndex->range_search(n, data, radius, result); 229 | if (usegpu) 230 | { 231 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 232 | } 233 | } 234 | catch (std::exception &e) 235 | { 236 | LOG(ERROR) << e.what(); 237 | if (usegpu) 238 | { 239 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 240 | } 241 | return false; 242 | } 243 | return true; 244 | } 245 | 246 | bool faissSearch::remove_ids(const faiss::IDSelector &sel, long &nremove, long &location) 247 | { 248 | if (usegpu && (location == 0 || location == 2)) 249 | { 250 | faissIndex.reset(faiss::gpu::index_gpu_to_cpu(faissIndex.get())); 251 | } 252 | try 253 | { 254 | nremove = faissIndex->remove_ids(sel); 255 | ntotal -= nremove; 256 | if ((location == 1 || location == 2) && usegpu) 257 | { 258 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 259 | } 260 | } 261 | catch (std::exception &e) 262 | { 263 | LOG(ERROR) << e.what(); 264 | if (usegpu) 265 | { 266 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 267 | } 268 | return false; 269 | } 270 | return true; 271 | } 272 | 273 | bool faissSearch::remove_ids_range(const faiss::IDSelector &sel, long &nremove) 274 
| { 275 | if (usegpu) 276 | { 277 | faissIndex.reset(faiss::gpu::index_gpu_to_cpu(faissIndex.get())); 278 | } 279 | try 280 | { 281 | nremove = faissIndex->remove_ids(sel); 282 | ntotal -= nremove; 283 | if (usegpu) 284 | { 285 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 286 | } 287 | } 288 | catch (std::exception &e) 289 | { 290 | LOG(ERROR) << e.what(); 291 | if (usegpu) 292 | { 293 | faissIndex.reset(faiss::gpu::index_cpu_to_gpu_multiple(res, devs, faissIndex.get(), options)); 294 | } 295 | return false; 296 | } 297 | return true; 298 | } 299 | 300 | bool faissSearch::index_display() 301 | { 302 | try 303 | { 304 | faissIndex->display(); 305 | return true; 306 | } 307 | catch (std::exception &e) 308 | { 309 | LOG(ERROR) << e.what(); 310 | return false; 311 | } 312 | } 313 | -------------------------------------------------------------------------------- /src/libSearch/FaissInterface.h: -------------------------------------------------------------------------------- 1 | #ifndef FAISS_SEARCH_INTERFACE_H_ 2 | #define FAISS_SEARCH_INTERFACE_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "faiss/AutoTune.h" 10 | #include "faiss/AuxIndexStructures.h" 11 | #include "faiss/gpu/GpuAutoTune.h" 12 | #include "faiss/index_io.h" 13 | 14 | namespace dev 15 | { 16 | using std::string; 17 | using std::unordered_map; 18 | using std::vector; 19 | typedef long idx_t; 20 | 21 | class faissSearch 22 | { 23 | public: 24 | int dim; ///< vector dimension 25 | idx_t ntotal; ///< total nb of indexed vectors 26 | bool is_trained; ///< set if the Index does not require training, or if training is done already 27 | 28 | faissSearch(const string &indexKey, const int d, bool useGPU = true, bool initGpuResources = false, faiss::MetricType metric = faiss::METRIC_INNER_PRODUCT); 29 | virtual bool reset(); 30 | /** 31 | * @brief Load dataset from disk 32 | * 33 | * @param filePath dataset file path 34 | * 
@param data input dataset 35 | */ 36 | virtual bool load(const string &filePath, unordered_map> &data); 37 | 38 | virtual bool load(const string &filePath, vector &ids, vector &data); 39 | /** 40 | * @brief write index to the file 41 | * 42 | * @param filePath the path of the index file 43 | */ 44 | virtual bool write_index(const char *filePath); 45 | /** 46 | * @brief read index from the file 47 | * 48 | * @param filePath the path of the index file 49 | */ 50 | virtual bool read_index( const char *filePath); 51 | /** 52 | * @brief Perform training on a representative set of vectors 53 | * 54 | * @param data training vectors, size n *d 55 | */ 56 | virtual void train(idx_t n, const float *data); 57 | /** 58 | * @brief Add n vectors of dimension d to the index. 59 | * 60 | * @param data input matrix, size n * d 61 | */ 62 | virtual void add(const vector> &data); 63 | virtual void add(idx_t n, const vector &data); 64 | /** 65 | * @brief Same as add, but stores xids instead of sequential ids. 
66 | * 67 | * @param data input matrix, size n * d 68 | * @param if ids is not empty ids for the vectors 69 | */ 70 | virtual bool add_with_ids(idx_t n, const float *xdata, const long *xids); 71 | /** 72 | * @brief for each query vector, find its k nearest neighbors in the database 73 | * 74 | * @param n queries size 75 | * @param data query vectors 76 | * @param k top k nearest neighbors 77 | * @param distances top k nearest distances 78 | * @param labels neighbors of the queries 79 | */ 80 | virtual bool search(idx_t n, const float *data, idx_t k, float *distances, long *labels) const; 81 | virtual bool search_range(idx_t n, const float *x, float radius, faiss::RangeSearchResult *result); 82 | virtual bool remove_ids(const faiss::IDSelector &sel, long &nremove, long &location); 83 | virtual bool remove_ids_range(const faiss::IDSelector &sel, long &nremove); 84 | virtual bool index_display(); 85 | 86 | private: 87 | std::shared_ptr faissIndex = nullptr; 88 | int ngpus = 0; 89 | bool usegpu = true; 90 | vector res; 91 | vector devs; 92 | faiss::gpu::GpuMultipleClonerOptions *options = new faiss::gpu::GpuMultipleClonerOptions(); 93 | }; 94 | } 95 | 96 | #endif //FAISS_SEARCH_INTERFACE_H_ -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "common/easylog++.h" 2 | #include "libRestServer/RestServer.h" 3 | 4 | using namespace std; 5 | using namespace dev; 6 | 7 | void help() 8 | { 9 | cout << "Usage verifier [config.json]" << endl; 10 | } 11 | 12 | int main(int argc, char *argv[]) 13 | { 14 | string configPath("config.json"); 15 | if (argc == 2) 16 | { 17 | configPath = argv[1]; 18 | cout << "Config file : " << configPath << endl; 19 | } 20 | else if (argc == 1) 21 | { 22 | cout << "Config file : config.json" << endl; 23 | } 24 | else 25 | { 26 | help(); 27 | return -1; 28 | } 29 | 30 | std::shared_ptr cp = nullptr; 31 | try 32 | { 
33 | cp.reset(new ConfigParams(configPath)); 34 | } 35 | catch (std::exception &e) 36 | { 37 | cerr << e.what() << endl; 38 | return -1; 39 | } 40 | 41 | cp->printParams(); 42 | 43 | try 44 | { 45 | // start rest-http-server 46 | RestServer restServer(cp.get()); 47 | restServer.start(); 48 | restServer.shutdown(); 49 | } 50 | catch (std::exception &e) 51 | { 52 | LOG(ERROR) << e.what(); 53 | cerr << e.what(); 54 | return -1; 55 | } 56 | return 0; 57 | } 58 | -------------------------------------------------------------------------------- /test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if(BUILD_TEST) 2 | 3 | include_directories(${CMAKE_SOURCE_DIR}/src) 4 | file(GLOB testSrcs ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 5 | 6 | foreach(source ${testSrcs}) 7 | get_filename_component(name ${source} NAME_WE) 8 | 9 | # target 10 | add_executable(${name} ${source}) 11 | target_link_libraries(${name} search common) 12 | 13 | # Install 14 | # install(TARGETS ${name} DESTINATION bin) 15 | 16 | # Unit test 17 | add_test("test_${name}" ${CMAKE_BINARY_DIR}/bin/${name}) 18 | endforeach(source) 19 | 20 | endif(BUILD_TEST) 21 | -------------------------------------------------------------------------------- /test/FaissCPUSearch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "common/easylog++.h" 7 | #include "libSearch/FaissInterface.h" 8 | 9 | INITIALIZE_EASYLOGGINGPP 10 | 11 | using namespace std; 12 | using namespace dev; 13 | 14 | int main() 15 | { 16 | int d = 256; // dimension 17 | int nb = 1000000; // database size 18 | int nq = 5; // nb of queries 19 | int k = 4; 20 | 21 | // this is typically the fastest one. 
22 | string searchMethod("IDMap,Flat"); 23 | 24 | // these ones have better memory usage 25 | // string searchMethod("IVF64,Flat"); 26 | // string searchMethod("PQ32"); 27 | // string searchMethod("PCA80,Flat"); 28 | // string searchMethod("IVF64,PQ8+16"); 29 | // string searchMethod("IVF64,PQ32"); 30 | // string searchMethod("IMI2x8,PQ32"); 31 | // string searchMethod("OPQ16_64,IMI2x8,PQ8+16"); 32 | 33 | vector xb(d * nb, 0); 34 | vector xid(nb, 0); 35 | vector xq(d * nq, 0); 36 | vector sumb(nb, 0); 37 | vector sumq(nq, 0); 38 | for (int i = 0; i < nb; i++) 39 | { 40 | for (int j = 0; j < d; j++) 41 | { 42 | xb[d * i + j] = drand48(); 43 | sumb[i] += xb[d * i + j] * xb[d * i + j]; 44 | } 45 | sumb[i] = sqrt(sumb[i]); 46 | xid[i] = i; 47 | } 48 | for (int i = 0; i < nb; i++) 49 | { 50 | for (int j = 0; j < d; j++) 51 | { 52 | xb[d * i + j] /= sumb[i]; 53 | } 54 | } 55 | for (int i = 0; i < nq; i++) 56 | { 57 | for (int j = 0; j < d; j++) 58 | { 59 | xq[d * i + j] = drand48(); 60 | sumq[i] += xq[d * i + j] * xq[d * i + j]; 61 | } 62 | sumq[i] = sqrt(sumq[i]); 63 | } 64 | for (int i = 0; i < nq; i++) 65 | { 66 | for (int j = 0; j < d; j++) 67 | { 68 | xq[d * i + j] /= sumq[i]; 69 | } 70 | } 71 | 72 | shared_ptr index(new faissSearch(searchMethod, d, false)); 73 | cout << "Search " << nq << " from " << nb << " | Use " << searchMethod << endl; 74 | 75 | chrono::system_clock::time_point t1 = chrono::system_clock::now(); 76 | // index->train(nb, xb.data()); 77 | chrono::system_clock::time_point t2 = chrono::system_clock::now(); 78 | cout << "|||Training time : " << (chrono::duration_cast(t2 - t1)).count() << " ms|||" << endl; 79 | cout << "is_trained = " << (index->is_trained ? 
"true" : "false") << endl; 80 | // assert(index->is_trained); 81 | // index->add(nb, xb); // add vectors to the index 82 | index->add_with_ids(nb, xb.data(), xid.data()); // add vectors to the index 83 | cout << "ntotal = " << index->ntotal << endl; 84 | chrono::system_clock::time_point t3 = chrono::system_clock::now(); 85 | cout << "|||Adding time : " << (chrono::duration_cast(t3 - t2)).count() << " ms|||" << endl; 86 | 87 | { // sanity check: search 5 first vectors of xb 88 | vector I(k * nq, 0); 89 | vector D(k * nq, 0); 90 | vector xb5(xb.begin(), xb.begin() + nq * d); 91 | index->search(nq, xb5.data(), k, D.data(), I.data()); 92 | 93 | // print results 94 | cout << "I=" << endl; 95 | for (int i = 0; i < nq; i++) 96 | { 97 | for (int j = 0; j < k; j++) 98 | cout << setprecision(3) << setw(5) << I[i * k + j] << " "; 99 | cout << endl; 100 | } 101 | 102 | cout << "D=" << endl; 103 | for (int i = 0; i < nq; i++) 104 | { 105 | for (int j = 0; j < k; j++) 106 | cout << setw(5) << D[i * k + j] << " "; 107 | cout << endl; 108 | } 109 | } 110 | 111 | { // search xq 112 | 113 | vector I(k * nq, 0); 114 | vector D(k * nq, 0); 115 | chrono::system_clock::time_point t4 = chrono::system_clock::now(); 116 | index->search(nq, xq.data(), k, D.data(), I.data()); 117 | chrono::system_clock::time_point t5 = chrono::system_clock::now(); 118 | cout << "|||Searching time: " << (chrono::duration_cast(t5 - t4)).count() << " ms|||" << endl; 119 | 120 | // print results 121 | cout << "I (5 first results)=" << endl; 122 | for (int i = 0; i < nq; i++) 123 | { 124 | for (int j = 0; j < k; j++) 125 | cout << setw(8) << I[i * k + j] << " "; 126 | cout << endl; 127 | } 128 | 129 | cout << "D (5 first results)=" << endl; 130 | for (int i = 0; i < nq; i++) 131 | { 132 | for (int j = 0; j < k; j++) 133 | cout << setw(8) << D[i * k + j] << " "; 134 | cout << endl; 135 | } 136 | } 137 | 138 | return 0; 139 | } -------------------------------------------------------------------------------- 
/test/FaissGPUSearch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "common/easylog++.h" 7 | #include "libSearch/FaissInterface.h" 8 | 9 | INITIALIZE_EASYLOGGINGPP 10 | 11 | using namespace std; 12 | using namespace dev; 13 | 14 | int main() 15 | { 16 | int d = 256; // dimension 17 | int nb = 1000000; // database size 18 | int nq = 5; // nb of queries 19 | int k = 4; 20 | 21 | // this is typically the fastest one. 22 | string searchMethod("IDMap,Flat"); 23 | 24 | // these ones have better memory usage 25 | // string searchMethod("IVF64,Flat"); 26 | // string searchMethod("PQ32"); 27 | // string searchMethod("PCA80,Flat"); 28 | // string searchMethod("IVF64,PQ8+16"); 29 | // string searchMethod("IVF64,PQ32"); 30 | // string searchMethod("IMI2x8,PQ32"); 31 | // string searchMethod("OPQ16_64,IMI2x8,PQ8+16"); 32 | 33 | vector xb(d * nb, 0); 34 | vector xid(nb, 0); 35 | vector xq(d * nq, 0); 36 | vector sumb(nb, 0); 37 | vector sumq(nq, 0); 38 | for (int i = 0; i < nb; i++) 39 | { 40 | for (int j = 0; j < d; j++) 41 | { 42 | xb[d * i + j] = drand48(); 43 | sumb[i] += xb[d * i + j] * xb[d * i + j]; 44 | } 45 | sumb[i] = sqrt(sumb[i]); 46 | xid[i] = i; 47 | } 48 | for (int i = 0; i < nb; i++) 49 | { 50 | for (int j = 0; j < d; j++) 51 | { 52 | xb[d * i + j] /= sumb[i]; 53 | } 54 | } 55 | for (int i = 0; i < nq; i++) 56 | { 57 | for (int j = 0; j < d; j++) 58 | { 59 | xq[d * i + j] = drand48(); 60 | sumq[i] += xq[d * i + j] * xq[d * i + j]; 61 | } 62 | sumq[i] = sqrt(sumq[i]); 63 | } 64 | for (int i = 0; i < nq; i++) 65 | { 66 | for (int j = 0; j < d; j++) 67 | { 68 | xq[d * i + j] /= sumq[i]; 69 | } 70 | } 71 | 72 | shared_ptr index(new faissSearch(searchMethod, d, true, true)); 73 | cout << "Search " << nq << " from " << nb << " | Use " << searchMethod << endl; 74 | 75 | chrono::system_clock::time_point t1 = chrono::system_clock::now(); 76 | // 
index->train(nb, xb.data()); 77 | chrono::system_clock::time_point t2 = chrono::system_clock::now(); 78 | cout << "|||Training time : " << (chrono::duration_cast(t2 - t1)).count() << " ms|||" << endl; 79 | cout << "is_trained = " << (index->is_trained ? "true" : "false") << endl; 80 | // assert(index->is_trained); 81 | // index->add(nb, xb); // add vectors to the index 82 | index->add_with_ids(nb, xb.data(), xid.data()); // add vectors to the index 83 | cout << "ntotal = " << index->ntotal << endl; 84 | chrono::system_clock::time_point t3 = chrono::system_clock::now(); 85 | cout << "|||Adding time : " << (chrono::duration_cast(t3 - t2)).count() << " ms|||" << endl; 86 | 87 | { // sanity check: search 5 first vectors of xb 88 | vector I(k * nq, 0); 89 | vector D(k * nq, 0); 90 | vector xb5(xb.begin(), xb.begin() + nq * d); 91 | index->search(nq, xb5.data(), k, D.data(), I.data()); 92 | 93 | // print results 94 | cout << "I=" << endl; 95 | for (int i = 0; i < nq; i++) 96 | { 97 | for (int j = 0; j < k; j++) 98 | cout << setprecision(3) << setw(5) << I[i * k + j] << " "; 99 | cout << endl; 100 | } 101 | 102 | cout << "D=" << endl; 103 | for (int i = 0; i < nq; i++) 104 | { 105 | for (int j = 0; j < k; j++) 106 | cout << setw(5) << D[i * k + j] << " "; 107 | cout << endl; 108 | } 109 | } 110 | 111 | { // search xq 112 | 113 | vector I(k * nq, 0); 114 | vector D(k * nq, 0); 115 | chrono::system_clock::time_point t4 = chrono::system_clock::now(); 116 | index->search(nq, xq.data(), k, D.data(), I.data()); 117 | chrono::system_clock::time_point t5 = chrono::system_clock::now(); 118 | cout << "|||Searching time: " << (chrono::duration_cast(t5 - t4)).count() << " ms|||" << endl; 119 | 120 | // print results 121 | cout << "I (5 first results)=" << endl; 122 | for (int i = 0; i < nq; i++) 123 | { 124 | for (int j = 0; j < k; j++) 125 | cout << setw(8) << I[i * k + j] << " "; 126 | cout << endl; 127 | } 128 | 129 | cout << "D (5 first results)=" << endl; 130 | for (int i = 0; i 
< nq; i++) 131 | { 132 | for (int j = 0; j < k; j++) 133 | cout << setw(8) << D[i * k + j] << " "; 134 | cout << endl; 135 | } 136 | } 137 | 138 | return 0; 139 | } -------------------------------------------------------------------------------- /test/FaissLoadTest.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "common/easylog++.h" 9 | #include "libSearch/FaissInterface.h" 10 | 11 | INITIALIZE_EASYLOGGINGPP 12 | 13 | using namespace std; 14 | using namespace dev; 15 | 16 | bool writeToFile(vector &data, string &filePath, int d, int n) 17 | { 18 | int header[5] = {0}; 19 | ofstream out(filePath, ofstream::out | ofstream::binary); 20 | if (!out.is_open()) 21 | { 22 | return false; 23 | } 24 | header[0] = d; 25 | header[1] = n; 26 | header[2] = 1; 27 | out.write((char *)header, sizeof(header)); 28 | for (int i = 0; i < header[1]; i++) 29 | { 30 | out.write((char *)&i, sizeof(int)); 31 | out.write((char *)&data[i * header[0]], header[0] * sizeof(float)); 32 | } 33 | out.close(); 34 | return true; 35 | } 36 | 37 | int main() 38 | { 39 | int d = 256; // dimension 40 | int nb = 200000; // database size 41 | int nq = 10000; // nb of queries 42 | 43 | string searchMethod("Flat"); 44 | 45 | vector xb(d * nb, 0); 46 | vector xq(d * nq, 0); 47 | for (int i = 0; i < nb; i++) 48 | { 49 | for (int j = 0; j < d; j++) 50 | xb[d * i + j] = drand48(); 51 | xb[d * i] += i / 1000.; 52 | } 53 | 54 | for (int i = 0; i < nq; i++) 55 | { 56 | for (int j = 0; j < d; j++) 57 | xq[d * i + j] = drand48(); 58 | xq[d * i] += i / 1000.; 59 | } 60 | shared_ptr index(new faissSearch(searchMethod, d, true, true)); 61 | 62 | string fileName = "data.bin"; 63 | if (!writeToFile(xb, fileName, d, nb)) 64 | { 65 | cout << "data write to file failed." << endl; 66 | return 0; 67 | } 68 | cout << "data write to file done." 
<< endl; 69 | unordered_map> umap; 70 | if (!index->load(fileName, umap)) 71 | { 72 | cout << "load data with umap is failed." << endl; 73 | return 0; 74 | } 75 | for (unsigned int i = 0; i < umap.size(); i++) 76 | { 77 | for (int j = 0; j < d; j++) 78 | { 79 | if (xb[i * d + j] != umap[i][j]) 80 | { 81 | cout << "number " << i << "data is wrong!" << endl; 82 | return -1; 83 | } 84 | } 85 | } 86 | cout << "load data with umap done." << endl; 87 | 88 | vector ids; 89 | vector features; 90 | if (!index->load(fileName, ids, features)) 91 | { 92 | cout << "load data without umap failed." << endl; 93 | return 0; 94 | } 95 | for (size_t i = 0; i < features.size(); i++) 96 | { 97 | if (xb[i] != features[i]) 98 | { 99 | cout << "number " << i << "data is wrong!" << endl; 100 | return -1; 101 | } 102 | } 103 | cout << "load data without umap done." << endl; 104 | 105 | return 0; 106 | } -------------------------------------------------------------------------------- /test/HNSWSearch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "common/easylog++.h" 7 | #include "libSearch/FaissInterface.h" 8 | 9 | INITIALIZE_EASYLOGGINGPP 10 | 11 | using namespace std; 12 | using namespace dev; 13 | 14 | int main() 15 | { 16 | int d = 512; // dimension 17 | int nb = 1000000; // database size 18 | int nq = 1; // nb of queries 19 | 20 | // this is typically the fastest one. 
21 | string searchMethod("HNSW32"); 22 | 23 | // these ones have better memory usage 24 | // string searchMethod("Flat"); 25 | // string searchMethod("PQ32"); 26 | // string searchMethod("PCA80,Flat"); 27 | // string searchMethod("IVF64,PQ8+16"); 28 | // string searchMethod("IVF64,PQ32"); 29 | // string searchMethod("IMI2x8,PQ32"); 30 | // string searchMethod("OPQ16_64,IMI2x8,PQ8+16"); 31 | 32 | vector xb(d * nb, 0); 33 | vector xq(d * nq, 0); 34 | for (int i = 0; i < nb; i++) 35 | { 36 | for (int j = 0; j < d; j++) 37 | xb[d * i + j] = drand48(); 38 | xb[d * i] += i / 1000.; 39 | } 40 | 41 | for (int i = 0; i < nq; i++) 42 | { 43 | for (int j = 0; j < d; j++) 44 | xq[d * i + j] = drand48(); 45 | xq[d * i] += i / 1000.; 46 | } 47 | shared_ptr index(new faissSearch(searchMethod, d ,false)); 48 | cout << "Search " << nq << " from " << nb << " | Use " << searchMethod << endl; 49 | 50 | chrono::system_clock::time_point t1 = chrono::system_clock::now(); 51 | //index->train(nb, xb.data()); 52 | chrono::system_clock::time_point t2 = chrono::system_clock::now(); 53 | cout << "|||Training time : " << (chrono::duration_cast(t2 - t1)).count() << " ms|||" << endl; 54 | cout << "is_trained = " << (index->is_trained ? 
"true" : "false") << endl; 55 | //assert(index->is_trained); 56 | index->add(nb,xb); // add vectors to the index 57 | cout << "ntotal = " << index->ntotal << endl; 58 | chrono::system_clock::time_point t3 = chrono::system_clock::now(); 59 | cout << "|||Adding time : " << (chrono::duration_cast(t3 - t2)).count() << " ms|||" << endl; 60 | 61 | int k = 4; 62 | 63 | { // sanity check: search 5 first vectors of xb 64 | vector I(k * 5, 0); 65 | vector D(k * 5, 0); 66 | vector xb5(xb.begin(), xb.begin() + 5 * d); 67 | index->search(5, xb5.data(), k, D.data(), I.data()); 68 | chrono::system_clock::time_point t4 = chrono::system_clock::now(); 69 | cout << "|||Searching time : " << (chrono::duration_cast(t4 - t3)).count() << " ms|||" << endl; 70 | 71 | // print results 72 | cout << "I=" << endl; 73 | for (int i = 0; i < 5; i++) 74 | { 75 | for (int j = 0; j < k; j++) 76 | cout << setprecision(3) << setw(5) << I[i * k + j] << " "; 77 | cout << endl; 78 | } 79 | 80 | cout << "D=" << endl; 81 | for (int i = 0; i < 5; i++) 82 | { 83 | for (int j = 0; j < k; j++) 84 | cout << setw(5) << D[i * k + j]; 85 | cout << endl; 86 | } 87 | } 88 | 89 | 90 | return 0; 91 | } 92 | -------------------------------------------------------------------------------- /test/python-test/add.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | 13 | def main(): 14 | ip = '127.0.0.1' 15 | port = '2333' 16 | api = '/add' 17 | length = len(sys.argv) 18 | if length == 2: 19 | ip = sys.argv[1] 20 | elif length == 3: 21 | ip = sys.argv[1] 22 | port = sys.argv[2] 23 | elif length > 3: 24 | print 'Usage: python add.py [ip] [port]' 25 | return 26 | ntotal = 20000 27 | dim = 256 28 | features = [[0 for col in range(dim)] for row in range(ntotal)] 29 | dic = {} 30 | for i in range(0, ntotal): 
31 | dic[i] = list(np.random.rand(dim)) 32 | 33 | address = 'http://' + ip + ':' + port 34 | url = address + api 35 | headers = {'Content-Type': 'application/json'} 36 | 37 | data = {'ntotal': ntotal, 'data': dic} 38 | 39 | start = time.clock() 40 | r = requests.post(url, data=json.dumps(data)) 41 | elapsed = (time.clock() - start) 42 | result = r.json() 43 | print result 44 | print 'Time used:', elapsed * 1000 45 | return 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /test/python-test/query.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | def main(): 13 | ip = '127.0.0.1' 14 | port = '2333' 15 | api = '/search' 16 | length = len(sys.argv) 17 | if length == 2: 18 | ip = sys.argv[1] 19 | elif length == 3: 20 | ip = sys.argv[1] 21 | port = sys.argv[2] 22 | elif length > 3: 23 | print 'Usage: python query.py [ip] [port]' 24 | return 25 | qtotal = 3 26 | dim = 256 27 | k = 2 28 | dic = {} 29 | features = [[0 for col in range(dim)] for row in range(qtotal)] 30 | 31 | for i in range(0,qtotal): 32 | dic['q' + str(i)] = list(np.random.rand(dim)) 33 | 34 | address = 'http://' + ip + ':' + port 35 | url = address + api 36 | headers = {'Content-Type': 'application/json'} 37 | data = { 38 | 'qtotal': qtotal, 39 | 'queries':dic, 40 | 'topk': k 41 | } 42 | 43 | start = time.clock() 44 | r = requests.post(url, data=json.dumps(data)) 45 | elapsed = (time.clock() - start) 46 | result = r.json() 47 | print result 48 | print 'Time used:', elapsed * 1000 49 | return 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /test/python-test/queryRange.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | 13 | def main(): 14 | ip = '127.0.0.1' 15 | port = '2333' 16 | api = '/searchRange' 17 | length = len(sys.argv) 18 | if length == 2: 19 | ip = sys.argv[1] 20 | elif length == 3: 21 | ip = sys.argv[1] 22 | port = sys.argv[2] 23 | elif length > 3: 24 | print 'Usage: python queryRange.py [ip] [port]' 25 | return 26 | qtotal = 1 27 | dim = 256 28 | r = 73 29 | dic = {} 30 | features = [[0 for col in range(dim)] for row in range(qtotal)] 31 | for i in range(0, qtotal): 32 | dic['q' + str(i)] = list(np.random.rand(dim)) 33 | 34 | address = 'http://' + ip + ':' + port 35 | url = address + api 36 | headers = {'Content-Type': 'application/json'} 37 | data = {'nq': qtotal, 'queries': dic, 'radius': r} 38 | 39 | start = time.clock() 40 | r = requests.post(url, data=json.dumps(data)) 41 | elapsed = (time.clock() - start) 42 | result = r.json() 43 | print result 44 | print 'Time used:', elapsed * 1000 45 | return 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /test/python-test/querydays.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | def main(): 13 | ip = '127.0.0.1' 14 | port = '2333' 15 | api = '/searchDays' 16 | length = len(sys.argv) 17 | if length == 2: 18 | ip = sys.argv[1] 19 | elif length == 3: 20 | ip = sys.argv[1] 21 | port = sys.argv[2] 22 | elif length > 3: 23 | print 'Usage: python searchDays.py [ip] [port]' 24 | return 25 | qtotal = 1 26 | dim = 256 27 | k = 4 28 | dic = {} 29 | 
features = [[0 for col in range(dim)] for row in range(qtotal)] 30 | days = ["20180424"] 31 | 32 | for i in range(0,qtotal): 33 | dic['q' + str(i)] = list(np.random.rand(dim)) 34 | 35 | address = 'http://' + ip + ':' + port 36 | url = address + api 37 | headers = {'Content-Type': 'application/json'} 38 | data = { 39 | 'qtotal': qtotal, 40 | 'queries':dic, 41 | 'topk': k, 42 | 'days':days 43 | } 44 | 45 | start = time.clock() 46 | r = requests.post(url, data=json.dumps(data)) 47 | elapsed = (time.clock() - start) 48 | result = r.json() 49 | print result 50 | print 'Time used:', elapsed * 1000 51 | return 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /test/python-test/reconfig.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | 13 | def main(): 14 | ip = '127.0.0.1' 15 | port = '2333' 16 | api = '/reconfig' 17 | length = len(sys.argv) 18 | if length == 2: 19 | ip = sys.argv[1] 20 | elif length == 3: 21 | ip = sys.argv[1] 22 | port = sys.argv[2] 23 | elif length > 3: 24 | print 'Usage: python reconfig.py [ip] [port]' 25 | return 26 | 27 | address = 'http://' + ip + ':' + port 28 | url = address + api 29 | headers = {'Content-Type': 'application/json'} 30 | 31 | data = {'reconfigFilePath': "config.json"} 32 | 33 | start = time.clock() 34 | r = requests.post(url, data=json.dumps(data)) 35 | elapsed = (time.clock() - start) 36 | result = r.json() 37 | print result 38 | print 'Time used:', elapsed * 1000 39 | return 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /test/python-test/remove.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | 13 | def main(): 14 | ip = '127.0.0.1' 15 | port = '2333' 16 | api = '/delete' 17 | length = len(sys.argv) 18 | if length == 2: 19 | ip = sys.argv[1] 20 | elif length == 3: 21 | ip = sys.argv[1] 22 | port = sys.argv[2] 23 | elif length > 3: 24 | print 'Usage: python remove.py [ip] [port]' 25 | return 26 | ntotal = 100 27 | ids = list(range(0, ntotal + 1)) 28 | address = 'http://' + ip + ':' + port 29 | url = address + api 30 | headers = {'Content-Type': 'application/json'} 31 | 32 | data = {'ids': ids} 33 | 34 | start = time.clock() 35 | r = requests.post(url, data=json.dumps(data)) 36 | elapsed = (time.clock() - start) 37 | result = r.json() 38 | print result 39 | print 'Time used:', elapsed * 1000 40 | return 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /test/python-test/removeRange.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import requests 6 | import base64 7 | import json 8 | import time 9 | import random 10 | import numpy as np 11 | 12 | 13 | def main(): 14 | ip = '127.0.0.1' 15 | port = '2333' 16 | api = '/deleteRange' 17 | length = len(sys.argv) 18 | if length == 2: 19 | ip = sys.argv[1] 20 | elif length == 3: 21 | ip = sys.argv[1] 22 | port = sys.argv[2] 23 | elif length > 3: 24 | print 'Usage: python removeRange.py [ip] [port]' 25 | return 26 | 27 | address = 'http://' + ip + ':' + port 28 | url = address + api 29 | headers = {'Content-Type': 'application/json'} 30 | 31 | data = {'start': 0, 'end': 100} 32 | 33 | start = time.clock() 34 | r = requests.post(url, data=json.dumps(data)) 35 | elapsed = (time.clock() - start) 36 | result = r.json() 37 | print result 38 | 
print 'Time used:', elapsed * 1000 39 | return 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /test/sift1M.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD+Patents license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | // Copyright 2004-present Facebook. All Rights Reserved 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | 24 | #include "libSearch/FaissInterface.h" 25 | 26 | /** 27 | * To run this demo, please download the ANN_SIFT1M dataset from 28 | * 29 | * http://corpus-texmex.irisa.fr/ 30 | * 31 | * and unzip it to the sudirectory sift1M. 32 | **/ 33 | 34 | /***************************************************** 35 | * I/O functions for fvecs and ivecs 36 | *****************************************************/ 37 | 38 | float *fvecs_read(const char *fname, 39 | size_t *d_out, size_t *n_out) 40 | { 41 | FILE *f = fopen(fname, "r"); 42 | if (!f) 43 | { 44 | fprintf(stderr, "could not open %s\n", fname); 45 | perror(""); 46 | abort(); 47 | } 48 | int d; 49 | fread(&d, 1, sizeof(int), f); 50 | assert((d > 0 && d < 1000000) || !"unreasonable dimension"); 51 | fseek(f, 0, SEEK_SET); 52 | struct stat st; 53 | fstat(fileno(f), &st); 54 | size_t sz = st.st_size; 55 | assert(sz % ((d + 1) * 4) == 0 || !"weird file size"); 56 | size_t n = sz / ((d + 1) * 4); 57 | 58 | *d_out = d; 59 | *n_out = n; 60 | float *x = new float[n * (d + 1)]; 61 | size_t nr = fread(x, sizeof(float), n * (d + 1), f); 62 | assert(nr == n * (d + 1) || !"could not read whole file"); 63 | 64 | // shift array to remove row headers 65 | for (size_t i = 0; i < n; i++) 66 | memmove(x + i * 
d, x + 1 + i * (d + 1), d * sizeof(*x)); 67 | 68 | fclose(f); 69 | return x; 70 | } 71 | 72 | // not very clean, but works as long as sizeof(int) == sizeof(float) 73 | int *ivecs_read(const char *fname, size_t *d_out, size_t *n_out) 74 | { 75 | return (int *)fvecs_read(fname, d_out, n_out); 76 | } 77 | 78 | double elapsed() 79 | { 80 | struct timeval tv; 81 | gettimeofday(&tv, nullptr); 82 | return tv.tv_sec + tv.tv_usec * 1e-6; 83 | } 84 | 85 | int main() 86 | { 87 | double t0 = elapsed(); 88 | 89 | // this is typically the fastest one. 90 | const char *index_key = "IDMap,Flat"; 91 | 92 | // these ones have better memory usage 93 | // const char *index_key = "Flat"; 94 | // const char *index_key = "PQ32"; 95 | // const char *index_key = "PCA80,Flat"; 96 | // const char *index_key = "IVF4096,PQ8+16"; 97 | // const char *index_key = "IVF4096,PQ32"; 98 | // const char *index_key = "IMI2x8,PQ32"; 99 | // const char *index_key = "IMI2x8,PQ8+16"; 100 | // const char *index_key = "OPQ16_64,IMI2x8,PQ8+16"; 101 | 102 | faiss::Index *index; 103 | 104 | size_t d; 105 | 106 | { 107 | printf("[%.3f s] Loading train set\n", elapsed() - t0); 108 | 109 | size_t nt; 110 | float *xt = fvecs_read("/home/lyf/data/sift/sift_learn.fvecs", &d, &nt); 111 | 112 | printf("[%.3f s] Preparing index \"%s\" d=%ld\n", 113 | elapsed() - t0, index_key, d); 114 | index = faiss::index_factory(d, index_key); 115 | 116 | printf("[%.3f s] Training on %ld vectors\n", elapsed() - t0, nt); 117 | 118 | // index->train(nt, xt); 119 | delete[] xt; 120 | } 121 | 122 | { 123 | printf("[%.3f s] Loading database\n", elapsed() - t0); 124 | 125 | size_t nb, d2; 126 | float *xb = fvecs_read("/home/lyf/data/sift/sift_base.fvecs", &d2, &nb); 127 | assert(d == d2 || !"dataset does not have same dimension as train set"); 128 | 129 | printf("[%.3f s] Indexing database, size %ld*%ld\n", 130 | elapsed() - t0, nb, d); 131 | std::vector xid(nb, 0); 132 | for (size_t i = 0; i < nb; i++) 133 | { 134 | xid[i] = i; 135 | } 136 
| index->add_with_ids(nb, xb, xid.data()); 137 | 138 | delete[] xb; 139 | } 140 | 141 | size_t nq; 142 | float *xq; 143 | 144 | { 145 | printf("[%.3f s] Loading queries\n", elapsed() - t0); 146 | 147 | size_t d2; 148 | xq = fvecs_read("/home/lyf/data/sift/sift_query.fvecs", &d2, &nq); 149 | assert(d == d2 || !"query does not have same dimension as train set"); 150 | } 151 | 152 | size_t k; // nb of results per query in the GT 153 | faiss::Index::idx_t *gt; // nq * k matrix of ground-truth nearest-neighbors 154 | 155 | { 156 | printf("[%.3f s] Loading ground truth for %ld queries\n", 157 | elapsed() - t0, nq); 158 | 159 | // load ground-truth and convert int to long 160 | size_t nq2; 161 | int *gt_int = ivecs_read("/home/lyf/data/sift/sift_groundtruth.ivecs", &k, &nq2); 162 | assert(nq2 == nq || !"incorrect nb of ground truth entries"); 163 | 164 | gt = new faiss::Index::idx_t[k * nq]; 165 | for (size_t i = 0; i < k * nq; i++) 166 | { 167 | gt[i] = gt_int[i]; 168 | } 169 | delete[] gt_int; 170 | } 171 | 172 | // Result of the auto-tuning 173 | std::string selected_params; 174 | 175 | { // run auto-tuning 176 | 177 | printf("[%.3f s] Preparing auto-tune criterion 1-recall at 1 " 178 | "criterion, with k=%ld nq=%ld\n", 179 | elapsed() - t0, k, nq); //k=100 nq=10000 180 | 181 | faiss::OneRecallAtRCriterion crit(nq, 1); 182 | crit.set_groundtruth(k, nullptr, gt); 183 | crit.nnn = k; // by default, the criterion will request only 1 NN 184 | 185 | printf("[%.3f s] Preparing auto-tune parameters\n", elapsed() - t0); 186 | 187 | faiss::ParameterSpace params; 188 | params.initialize(index); 189 | 190 | printf("[%.3f s] Auto-tuning over %ld parameters (%ld combinations)\n", 191 | elapsed() - t0, params.parameter_ranges.size(), 192 | params.n_combinations()); 193 | 194 | faiss::OperatingPoints ops; 195 | params.explore(index, nq, xq, crit, &ops); 196 | 197 | printf("[%.3f s] Found the following operating points: \n", 198 | elapsed() - t0); 199 | 200 | ops.display(); 201 | 202 
| // keep the first parameter that obtains > 0.5 1-recall@1 203 | for (size_t i = 0; i < ops.optimal_pts.size(); i++) 204 | { 205 | if (ops.optimal_pts[i].perf > 0.5) 206 | { 207 | selected_params = ops.optimal_pts[i].key; 208 | break; 209 | } 210 | } 211 | assert(selected_params.size() >= 0 || 212 | !"could not find good enough op point"); 213 | } 214 | 215 | { // Use the found configuration to perform a search 216 | 217 | faiss::ParameterSpace params; 218 | 219 | printf("[%.3f s] Setting parameter configuration \"%s\" on index\n", 220 | elapsed() - t0, selected_params.c_str()); 221 | 222 | params.set_index_parameters(index, selected_params.c_str()); 223 | 224 | printf("[%.3f s] Perform a search on %ld queries\n", 225 | elapsed() - t0, nq); 226 | 227 | // output buffers 228 | faiss::Index::idx_t *I = new faiss::Index::idx_t[nq * k]; 229 | float *D = new float[nq * k]; 230 | 231 | index->search(nq, xq, k, D, I); 232 | 233 | printf("[%.3f s] Compute recalls\n", elapsed() - t0); 234 | 235 | // evaluate result by hand. 
236 | int n_1 = 0, n_10 = 0, n_100 = 0; 237 | for (size_t i = 0; i < nq; i++) 238 | { 239 | int gt_nn = gt[i * k]; 240 | for (size_t j = 0; j < k; j++) 241 | { 242 | if (I[i * k + j] == gt_nn) 243 | { 244 | if (j < 1) 245 | n_1++; 246 | if (j < 10) 247 | n_10++; 248 | if (j < 100) 249 | n_100++; 250 | } 251 | } 252 | } 253 | printf("R@1 = %.4f\n", n_1 / float(nq)); 254 | printf("R@10 = %.4f\n", n_10 / float(nq)); 255 | printf("R@100 = %.4f\n", n_100 / float(nq)); 256 | } 257 | 258 | delete[] xq; 259 | delete[] gt; 260 | delete index; 261 | return 0; 262 | } 263 | -------------------------------------------------------------------------------- /test/testRemove.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "common/easylog++.h" 7 | #include "common/memusage.h" 8 | #include "libSearch/FaissInterface.h" 9 | 10 | INITIALIZE_EASYLOGGINGPP 11 | 12 | using namespace std; 13 | using namespace dev; 14 | 15 | int main() 16 | { 17 | int d = 256; // dimension 18 | int nb = 1000; // database size 19 | int nq = 1; // nb of queries 20 | string searchMethod("IDMap,Flat"); 21 | 22 | vector xb(d * nb, 0); 23 | vector xq(d * nq, 0); 24 | vector xid(nb, 0); 25 | vector sumb(nb, 0); 26 | vector sumq(nq, 0); 27 | for (int i = 0; i < nb; i++) 28 | { 29 | for (int j = 0; j < d; j++) 30 | { 31 | xb[d * i + j] = drand48(); 32 | sumb[i] += xb[d * i + j] * xb[d * i + j]; 33 | } 34 | sumb[i] = sqrt(sumb[i]); 35 | xid[i] = i; 36 | } 37 | for (int i = 0; i < nb; i++) 38 | { 39 | for (int j = 0; j < d; j++) 40 | { 41 | xb[d * i + j] /= sumb[i]; 42 | } 43 | } 44 | for (int i = 0; i < nq; i++) 45 | { 46 | for (int j = 0; j < d; j++) 47 | { 48 | xq[d * i + j] = drand48(); 49 | sumq[i] += xq[d * i + j] * xq[d * i + j]; 50 | } 51 | sumq[i] = sqrt(sumq[i]); 52 | } 53 | for (int i = 0; i < nq; i++) 54 | { 55 | for (int j = 0; j < d; j++) 56 | { 57 | xq[d * i + j] /= sumq[i]; 58 | } 59 | } 
60 | 61 | shared_ptr index(new faissSearch(searchMethod, d, true, true)); 62 | //index->train(nb, xb.data()); 63 | chrono::system_clock::time_point t1 = chrono::system_clock::now(); 64 | index->add_with_ids(nb, xb.data(), xid.data()); // add vectors to the index 65 | chrono::system_clock::time_point t2 = chrono::system_clock::now(); 66 | cout << "|||Adding time : " << (chrono::duration_cast(t2 - t1)).count() << " ms|||" << endl; 67 | cout << "n:" << index->ntotal << endl; 68 | 69 | int k = 5; 70 | { // search xq 71 | vector I(k * nq, 0); 72 | vector D(k * nq, 0); 73 | vector xb1(xb.begin(), xb.begin() + nq * d); 74 | chrono::system_clock::time_point t4 = chrono::system_clock::now(); 75 | 76 | index->search(nq, xb1.data(), k, D.data(), I.data()); 77 | 78 | chrono::system_clock::time_point t5 = chrono::system_clock::now(); 79 | cout << "|||Searching time: " << (chrono::duration_cast(t5 - t4)).count() << " ms|||" << endl; 80 | 81 | // print results 82 | cout << "I=" << endl; 83 | for (int i = 0; i < nq; i++) 84 | { 85 | for (int j = 0; j < k; j++) 86 | cout << setprecision(3) << setw(8) << I[i * k + j] << " "; 87 | cout << endl; 88 | } 89 | 90 | cout << "D=" << endl; 91 | for (int i = 0; i < nq; i++) 92 | { 93 | for (int j = 0; j < k; j++) 94 | cout << setw(8) << D[i * k + j] << " "; 95 | cout << endl; 96 | } 97 | } 98 | 99 | { 100 | vector I(k * nq, 0); 101 | vector D(k * nq, 0); 102 | vector xb1(xb.begin(), xb.begin() + nq * d); 103 | faiss::IDSelectorRange ids(int(nb / 2), nb); 104 | vector idlist(100, 0); 105 | long nremove = 0; 106 | long loc = 0; 107 | 108 | chrono::system_clock::time_point t6 = chrono::system_clock::now(); 109 | index->remove_ids_range(ids, nremove); 110 | 111 | for (size_t i = 0; i < idlist.size(); i++) 112 | { 113 | if (i == idlist.size() - 1) 114 | loc = 1; 115 | faiss::IDSelectorRange id(i, i + 1); 116 | index->remove_ids(id, nremove, loc); 117 | loc = -1; 118 | } 119 | 120 | chrono::system_clock::time_point t7 = 
chrono::system_clock::now(); 121 | 122 | cout << "|||Removeing time: " << (chrono::duration_cast(t7 - t6)).count() << " ms|||" << endl; 123 | cout << "n:" << index->ntotal << endl; 124 | chrono::system_clock::time_point t8 = chrono::system_clock::now(); 125 | index->search(nq, xb1.data(), k, D.data(), I.data()); 126 | chrono::system_clock::time_point t9 = chrono::system_clock::now(); 127 | cout << "|||Searching time: " << (chrono::duration_cast(t9 - t8)).count() << " ms|||" << endl; 128 | 129 | // print results 130 | cout << "I=" << endl; 131 | for (int i = 0; i < nq; i++) 132 | { 133 | for (int j = 0; j < k; j++) 134 | cout << setprecision(3) << setw(8) << I[i * k + j] << " "; 135 | cout << endl; 136 | } 137 | 138 | cout << "D=" << endl; 139 | for (int i = 0; i < nq; i++) 140 | { 141 | for (int j = 0; j < k; j++) 142 | cout << setw(8) << D[i * k + j] << " "; 143 | cout << endl; 144 | } 145 | } 146 | double vm, rss; 147 | process_mem_usage(&vm, &rss); 148 | vm /= double(1024 * 1024); 149 | rss /= double(1024 * 1024); 150 | printf("done | VM %.1fgb | RSS %.1fgb \n", vm, rss); 151 | return 0; 152 | } -------------------------------------------------------------------------------- /test/testSearchRange.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "common/easylog++.h" 7 | #include "common/memusage.h" 8 | #include "libSearch/FaissInterface.h" 9 | #include 10 | #include 11 | #include 12 | 13 | INITIALIZE_EASYLOGGINGPP 14 | 15 | using namespace std; 16 | using namespace dev; 17 | 18 | int main() 19 | { 20 | int d = 256; // dimension 21 | int nb = 1000000; // database size 22 | int nq = 1; // nb of queries 23 | string searchMethod("IDMap,Flat"); 24 | 25 | vector xb(d * nb, 0); 26 | vector xq(d * nq, 0); 27 | vector xid(nb, 0); 28 | vector sumb(nb, 0); 29 | vector sumq(nq, 0); 30 | for (int i = 0; i < nb; i++) 31 | { 32 | for (int j = 0; j < d; j++) 33 | { 
34 | xb[d * i + j] = drand48(); 35 | sumb[i] += xb[d * i + j] * xb[d * i + j]; 36 | } 37 | sumb[i] = sqrt(sumb[i]); 38 | xid[i] = i; 39 | } 40 | for (int i = 0; i < nb; i++) 41 | { 42 | for (int j = 0; j < d; j++) 43 | { 44 | xb[d * i + j] /= sumb[i]; 45 | } 46 | } 47 | for (int i = 0; i < nq; i++) 48 | { 49 | for (int j = 0; j < d; j++) 50 | { 51 | xq[d * i + j] = drand48(); 52 | sumq[i] += xq[d * i + j] * xq[d * i + j]; 53 | } 54 | sumq[i] = sqrt(sumq[i]); 55 | } 56 | for (int i = 0; i < nq; i++) 57 | { 58 | for (int j = 0; j < d; j++) 59 | { 60 | xq[d * i + j] /= sumq[i]; 61 | } 62 | } 63 | 64 | shared_ptr index(new faissSearch(searchMethod, d, true, true)); 65 | chrono::system_clock::time_point t1 = chrono::system_clock::now(); 66 | index->add_with_ids(nb, xb.data(), xid.data()); // add vectors to the index 67 | chrono::system_clock::time_point t2 = chrono::system_clock::now(); 68 | cout << "|||Adding time : " << (chrono::duration_cast(t2 - t1)).count() << " ms|||" << endl; 69 | cout << "n:" << index->ntotal << endl; 70 | 71 | int k = 5; 72 | { // search xq 73 | vector I(k * nq, 0); 74 | vector D(k * nq, 0); 75 | vector xb1(xb.begin(), xb.begin() + nq * d); 76 | chrono::system_clock::time_point t4 = chrono::system_clock::now(); 77 | 78 | index->search(nq, xb1.data(), k, D.data(), I.data()); 79 | 80 | chrono::system_clock::time_point t5 = chrono::system_clock::now(); 81 | cout << "|||Searching time: " << (chrono::duration_cast(t5 - t4)).count() << " ms|||" << endl; 82 | 83 | // print results 84 | cout << "I=" << endl; 85 | for (int i = 0; i < nq; i++) 86 | { 87 | for (int j = 0; j < k; j++) 88 | cout << setprecision(3) << setw(8) << I[i * k + j] << " "; 89 | cout << endl; 90 | } 91 | 92 | cout << "D=" << endl; 93 | for (int i = 0; i < nq; i++) 94 | { 95 | for (int j = 0; j < k; j++) 96 | cout << setw(8) << D[i * k + j] << " "; 97 | cout << endl; 98 | } 99 | } 100 | const char *dir = "../indexFile"; 101 | if (access(dir, 0) == -1) 102 | { 103 | if (mkdir(dir, 
0777) == 0) 104 | { 105 | cout << "mkdir success." << endl; 106 | } 107 | } 108 | 109 | // write to faissIndex file 110 | time_t nowtime = time(NULL); 111 | char tmp[64]; 112 | strftime(tmp, sizeof(tmp), "%Y%m%d", localtime(&nowtime)); 113 | string date(tmp); 114 | const string filePath = "../indexFile/" + date + ".faissIndex"; 115 | chrono::system_clock::time_point t8 = chrono::system_clock::now(); 116 | index->write_index(filePath.c_str()); 117 | chrono::system_clock::time_point t9 = chrono::system_clock::now(); 118 | cout << "|||writing index time: " << (chrono::duration_cast(t9 - t8)).count() << " ms|||" << endl; 119 | 120 | shared_ptr new_index(new faissSearch(searchMethod, d, false)); 121 | chrono::system_clock::time_point t10 = chrono::system_clock::now(); 122 | new_index->read_index(filePath.c_str()); 123 | chrono::system_clock::time_point t11 = chrono::system_clock::now(); 124 | cout << "|||reading index time: " << (chrono::duration_cast(t11 - t10)).count() << " ms|||" << endl; 125 | 126 | { 127 | vector I(k * nq, 0); 128 | vector D(k * nq, 0); 129 | vector xb1(xb.begin(), xb.begin() + nq * d); 130 | faiss::RangeSearchResult *result = new faiss::RangeSearchResult(nq); 131 | float radius = 0.81; 132 | 133 | chrono::system_clock::time_point t6 = chrono::system_clock::now(); 134 | new_index->search_range(nq, xb1.data(), radius, result); 135 | chrono::system_clock::time_point t7 = chrono::system_clock::now(); 136 | cout << "|||SearchRange time: " << (chrono::duration_cast(t7 - t6)).count() << " ms|||" << endl; 137 | cout << "n:" << new_index->ntotal << endl; 138 | 139 | // print results 140 | for (int i = 0; i < nq; i++) 141 | { 142 | cout << "query " << i << " results:" << endl; 143 | for (size_t j = result->lims[i]; j < result->lims[i + 1]; j++) 144 | { 145 | cout << result->labels[j] << ": " << result->distances[j] << " "; 146 | } 147 | } 148 | } 149 | 150 | double vm, rss; 151 | process_mem_usage(&vm, &rss); 152 | vm /= double(1024 * 1024); 153 | rss /= 
double(1024 * 1024); 154 | printf("\ndone | VM %.1fgb | RSS %.1fgb \n", vm, rss); 155 | return 0; 156 | } --------------------------------------------------------------------------------